From ce7e6edad8b87367fc61340658ef08d4b74acb96 Mon Sep 17 00:00:00 2001 From: "Theo N. Truong" <644650+nhtruong@users.noreply.github.com> Date: Fri, 31 Oct 2025 13:00:21 -0600 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=84=20refactor:=20MCP=20Registry=20Sys?= =?UTF-8?q?tem=20with=20Distributed=20Caching=20(#10191)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: Restructure MCP registry system with caching - Split MCPServersRegistry into modular components: - MCPServerInspector: handles server inspection and health checks - MCPServersInitializer: manages server initialization logic - MCPServersRegistry: simplified registry coordination - Add distributed caching layer: - ServerConfigsCacheRedis: Redis-backed configuration cache - ServerConfigsCacheInMemory: in-memory fallback cache - RegistryStatusCache: distributed leader election state - Add promise utilities (withTimeout) replacing Promise.race patterns - Add comprehensive cache integration tests for all cache implementations - Remove unused MCPManager.getAllToolFunctions method * fix: Update OAuth flow to include user-specific headers * chore: Update Jest configuration to ignore additional test files - Added patterns to ignore files ending with .helper.ts and .helper.d.ts in testPathIgnorePatterns for cleaner test runs. * fix: oauth headers in callback * chore: Update Jest testPathIgnorePatterns to exclude helper files - Modified testPathIgnorePatterns in package.json to ignore files ending with .helper.ts and .helper.d.ts for cleaner test execution. * ci: update test mocks --------- Co-authored-by: Danny Avila --- .github/workflows/cache-integration-tests.yml | 9 + api/server/controllers/UserController.js | 12 +- api/server/controllers/mcp.js | 3 +- api/server/routes/__tests__/mcp.spec.js | 120 ++-- api/server/routes/config.js | 10 +- api/server/routes/mcp.js | 22 +- api/server/services/MCP.js | 3 +- api/server/services/MCP.spec.js | 13 +- api/server/services/initializeMCPs.js | 2 +- packages/api/jest.config.mjs | 10 +- packages/api/package.json | 5 +- packages/api/src/index.ts | 1 + packages/api/src/mcp/MCPConnectionFactory.ts | 12 +- packages/api/src/mcp/MCPManager.ts | 76 +-- packages/api/src/mcp/MCPServersRegistry.ts | 230 ------- packages/api/src/mcp/UserConnectionManager.ts | 14 +- .../api/src/mcp/__tests__/MCPManager.test.ts | 282 ++++++++- .../mcp/__tests__/MCPServersRegistry.test.ts | 595 ------------------ .../MCPServersRegistry.parsedConfigs.yml | 67 -- .../MCPServersRegistry.rawConfigs.yml | 53 -- packages/api/src/mcp/connection.ts | 13 +- .../oauth/OAuthReconnectionManager.test.ts | 37 +- .../src/mcp/oauth/OAuthReconnectionManager.ts | 5 +- .../src/mcp/registry/MCPServerInspector.ts | 123 ++++ .../src/mcp/registry/MCPServersInitializer.ts | 96 +++ .../src/mcp/registry/MCPServersRegistry.ts | 91 +++ .../__tests__/MCPServerInspector.test.ts | 338 ++++++++++ ...rversInitializer.cache_integration.spec.ts | 301 +++++++++ .../__tests__/MCPServersInitializer.test.ts | 292 +++++++++ ...PServersRegistry.cache_integration.spec.ts | 227 +++++++ .../__tests__/MCPServersRegistry.test.ts | 175 ++++++ .../__tests__/mcpConnectionsMock.helper.ts | 55 ++ .../mcp/registry/cache/BaseRegistryCache.ts | 26 + .../mcp/registry/cache/RegistryStatusCache.ts | 37 ++ .../cache/ServerConfigsCacheFactory.ts | 31 + .../cache/ServerConfigsCacheInMemory.ts | 46 ++ .../registry/cache/ServerConfigsCacheRedis.ts | 80 +++ ...istryStatusCache.cache_integration.spec.ts | 73 +++ .../ServerConfigsCacheFactory.test.ts | 70 +++ .../ServerConfigsCacheInMemory.test.ts | 173 +++++ ...onfigsCacheRedis.cache_integration.spec.ts | 278 ++++++++ packages/api/src/mcp/types/index.ts | 2 + packages/api/src/utils/index.ts | 1 + packages/api/src/utils/promise.spec.ts | 115 ++++ packages/api/src/utils/promise.ts | 42 ++ 45 files changed, 3116 insertions(+), 1150 deletions(-) delete mode 100644 packages/api/src/mcp/MCPServersRegistry.ts delete mode 100644 packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts delete mode 100644 packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.parsedConfigs.yml delete mode 100644 packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.rawConfigs.yml create mode 100644 packages/api/src/mcp/registry/MCPServerInspector.ts create mode 100644 packages/api/src/mcp/registry/MCPServersInitializer.ts create mode 100644 packages/api/src/mcp/registry/MCPServersRegistry.ts create mode 100644 packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts create mode 100644 packages/api/src/mcp/registry/__tests__/MCPServersInitializer.cache_integration.spec.ts create mode 100644 packages/api/src/mcp/registry/__tests__/MCPServersInitializer.test.ts create mode 100644 packages/api/src/mcp/registry/__tests__/MCPServersRegistry.cache_integration.spec.ts create mode 100644 packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts create mode 100644 packages/api/src/mcp/registry/__tests__/mcpConnectionsMock.helper.ts create mode 100644 packages/api/src/mcp/registry/cache/BaseRegistryCache.ts create mode 100644 packages/api/src/mcp/registry/cache/RegistryStatusCache.ts create mode 100644 packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts create mode 100644 packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts create mode 100644 packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts create mode 100644 packages/api/src/mcp/registry/cache/__tests__/RegistryStatusCache.cache_integration.spec.ts create mode 100644 packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts create mode 100644 packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts create mode 100644 packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.cache_integration.spec.ts create mode 100644 packages/api/src/utils/promise.spec.ts create mode 100644 packages/api/src/utils/promise.ts diff --git a/.github/workflows/cache-integration-tests.yml b/.github/workflows/cache-integration-tests.yml index f7ac638282..bdd3f2e83d 100644 --- a/.github/workflows/cache-integration-tests.yml +++ b/.github/workflows/cache-integration-tests.yml @@ -9,6 +9,7 @@ on: paths: - 'packages/api/src/cache/**' - 'packages/api/src/cluster/**' + - 'packages/api/src/mcp/**' - 'redis-config/**' - '.github/workflows/cache-integration-tests.yml' @@ -77,6 +78,14 @@ jobs: REDIS_URI: redis://127.0.0.1:6379 run: npm run test:cache-integration:cluster + - name: Run mcp integration tests + working-directory: packages/api + env: + NODE_ENV: test + USE_REDIS: true + REDIS_URI: redis://127.0.0.1:6379 + run: npm run test:cache-integration:mcp + - name: Stop Redis Cluster if: always() working-directory: redis-config diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index 31295387ed..b488864a93 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -28,6 +28,7 @@ const { getMCPManager, getFlowStateManager } = require('~/config'); const { getAppConfig } = require('~/server/services/Config'); const { deleteToolCalls } = require('~/models/ToolCall'); const { getLogStores } = require('~/cache'); +const { mcpServersRegistry } = require('@librechat/api'); const getUserController = async (req, res) => { const appConfig = await getAppConfig({ role: req.user?.role }); @@ -198,7 +199,7 @@ const updateUserPluginsController = async (req, res) => { // If auth was updated successfully, disconnect MCP sessions as they might use these credentials if (pluginKey.startsWith(Constants.mcp_prefix)) { try { - const mcpManager = getMCPManager(user.id); + const mcpManager = getMCPManager(); if (mcpManager) { // Extract server name from pluginKey (format: "mcp_") const serverName = pluginKey.replace(Constants.mcp_prefix, ''); @@ -295,10 +296,11 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => { } const serverName = pluginKey.replace(Constants.mcp_prefix, ''); - const mcpManager = getMCPManager(userId); - const serverConfig = mcpManager.getRawConfig(serverName) ?? appConfig?.mcpServers?.[serverName]; - - if (!mcpManager.getOAuthServers().has(serverName)) { + const serverConfig = + (await mcpServersRegistry.getServerConfig(serverName, userId)) ?? + appConfig?.mcpServers?.[serverName]; + const oauthServers = await mcpServersRegistry.getOAuthServers(); + if (!oauthServers.has(serverName)) { // this server does not use OAuth, so nothing to do here as well return; } diff --git a/api/server/controllers/mcp.js b/api/server/controllers/mcp.js index 9e520d392e..e113b01f17 100644 --- a/api/server/controllers/mcp.js +++ b/api/server/controllers/mcp.js @@ -10,6 +10,7 @@ const { getAppConfig, } = require('~/server/services/Config'); const { getMCPManager } = require('~/config'); +const { mcpServersRegistry } = require('@librechat/api'); /** * Get all MCP tools available to the user @@ -65,7 +66,7 @@ const getMCPTools = async (req, res) => { // Get server config once const serverConfig = appConfig.mcpConfig[serverName]; - const rawServerConfig = mcpManager.getRawConfig(serverName); + const rawServerConfig = await mcpServersRegistry.getServerConfig(serverName, userId); // Initialize server object with all server-level data const server = { diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index 8ae92cdd3d..43e086f7b3 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -15,6 +15,10 @@ jest.mock('@librechat/api', () => ({ storeTokens: jest.fn(), }, getUserMCPAuthMap: jest.fn(), + mcpServersRegistry: { + getServerConfig: jest.fn(), + getOAuthServers: jest.fn(), + }, })); jest.mock('@librechat/data-schemas', () => ({ @@ -115,7 +119,7 @@ describe('MCP Routes', () => { }); describe('GET /:serverName/oauth/initiate', () => { - const { MCPOAuthHandler } = require('@librechat/api'); + const { MCPOAuthHandler, mcpServersRegistry } = require('@librechat/api'); const { getLogStores } = require('~/cache'); it('should initiate OAuth flow successfully', async () => { @@ -128,13 +132,9 @@ describe('MCP Routes', () => { }), }; - const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({}), - }; - getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); - require('~/config').getMCPManager.mockReturnValue(mockMcpManager); + mcpServersRegistry.getServerConfig.mockResolvedValue({}); MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({ authorizationUrl: 'https://oauth.example.com/auth', @@ -288,6 +288,7 @@ describe('MCP Routes', () => { }); it('should handle OAuth callback successfully', async () => { + const { mcpServersRegistry } = require('@librechat/api'); const mockFlowManager = { completeFlow: jest.fn().mockResolvedValue(), deleteFlow: jest.fn().mockResolvedValue(true), @@ -307,6 +308,7 @@ describe('MCP Routes', () => { MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); MCPTokenStorage.storeTokens.mockResolvedValue(); + mcpServersRegistry.getServerConfig.mockResolvedValue({}); getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); @@ -321,7 +323,6 @@ describe('MCP Routes', () => { }; const mockMcpManager = { getUserConnection: jest.fn().mockResolvedValue(mockUserConnection), - getRawConfig: jest.fn().mockReturnValue({}), }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); @@ -379,6 +380,7 @@ describe('MCP Routes', () => { }); it('should handle system-level OAuth completion', async () => { + const { mcpServersRegistry } = require('@librechat/api'); const mockFlowManager = { completeFlow: jest.fn().mockResolvedValue(), deleteFlow: jest.fn().mockResolvedValue(true), @@ -398,14 +400,10 @@ describe('MCP Routes', () => { MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); MCPTokenStorage.storeTokens.mockResolvedValue(); + mcpServersRegistry.getServerConfig.mockResolvedValue({}); getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); - const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({}), - }; - require('~/config').getMCPManager.mockReturnValue(mockMcpManager); - const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ code: 'test-auth-code', state: 'test-flow-id', @@ -417,6 +415,7 @@ describe('MCP Routes', () => { }); it('should handle reconnection failure after OAuth', async () => { + const { mcpServersRegistry } = require('@librechat/api'); const mockFlowManager = { completeFlow: jest.fn().mockResolvedValue(), deleteFlow: jest.fn().mockResolvedValue(true), @@ -436,12 +435,12 @@ describe('MCP Routes', () => { MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); MCPTokenStorage.storeTokens.mockResolvedValue(); + mcpServersRegistry.getServerConfig.mockResolvedValue({}); getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); const mockMcpManager = { getUserConnection: jest.fn().mockRejectedValue(new Error('Reconnection failed')), - getRawConfig: jest.fn().mockReturnValue({}), }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); @@ -461,6 +460,7 @@ describe('MCP Routes', () => { }); it('should redirect to error page if token storage fails', async () => { + const { mcpServersRegistry } = require('@librechat/api'); const mockFlowManager = { completeFlow: jest.fn().mockResolvedValue(), deleteFlow: jest.fn().mockResolvedValue(true), @@ -480,6 +480,7 @@ describe('MCP Routes', () => { MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); MCPTokenStorage.storeTokens.mockRejectedValue(new Error('store failed')); + mcpServersRegistry.getServerConfig.mockResolvedValue({}); getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); @@ -730,12 +731,14 @@ describe('MCP Routes', () => { }); describe('POST /:serverName/reinitialize', () => { + const { mcpServersRegistry } = require('@librechat/api'); + it('should return 404 when server is not found in configuration', async () => { const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue(null), disconnectUserConnection: jest.fn().mockResolvedValue(), }; + mcpServersRegistry.getServerConfig.mockResolvedValue(null); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getFlowStateManager.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({}); @@ -750,9 +753,6 @@ describe('MCP Routes', () => { it('should handle OAuth requirement during reinitialize', async () => { const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({ - customUserVars: {}, - }), disconnectUserConnection: jest.fn().mockResolvedValue(), mcpConfigs: {}, getUserConnection: jest.fn().mockImplementation(async ({ oauthStart }) => { @@ -763,6 +763,9 @@ describe('MCP Routes', () => { }), }; + mcpServersRegistry.getServerConfig.mockResolvedValue({ + customUserVars: {}, + }); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getFlowStateManager.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({}); @@ -788,12 +791,12 @@ describe('MCP Routes', () => { it('should return 500 when reinitialize fails with non-OAuth error', async () => { const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({}), disconnectUserConnection: jest.fn().mockResolvedValue(), mcpConfigs: {}, getUserConnection: jest.fn().mockRejectedValue(new Error('Connection failed')), }; + mcpServersRegistry.getServerConfig.mockResolvedValue({}); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getFlowStateManager.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({}); @@ -809,11 +812,12 @@ describe('MCP Routes', () => { it('should return 500 when unexpected error occurs', async () => { const mockMcpManager = { - getRawConfig: jest.fn().mockImplementation(() => { - throw new Error('Config loading failed'); - }), + disconnectUserConnection: jest.fn(), }; + mcpServersRegistry.getServerConfig.mockImplementation(() => { + throw new Error('Config loading failed'); + }); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); const response = await request(app).post('/api/mcp/test-server/reinitialize'); @@ -846,11 +850,11 @@ describe('MCP Routes', () => { }; const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({ endpoint: 'http://test-server.com' }), disconnectUserConnection: jest.fn().mockResolvedValue(), getUserConnection: jest.fn().mockResolvedValue(mockUserConnection), }; + mcpServersRegistry.getServerConfig.mockResolvedValue({ endpoint: 'http://test-server.com' }); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getFlowStateManager.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({}); @@ -891,16 +895,16 @@ describe('MCP Routes', () => { }; const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({ - endpoint: 'http://test-server.com', - customUserVars: { - API_KEY: 'some-env-var', - }, - }), disconnectUserConnection: jest.fn().mockResolvedValue(), getUserConnection: jest.fn().mockResolvedValue(mockUserConnection), }; + mcpServersRegistry.getServerConfig.mockResolvedValue({ + endpoint: 'http://test-server.com', + customUserVars: { + API_KEY: 'some-env-var', + }, + }); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getFlowStateManager.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({}); @@ -1105,17 +1109,17 @@ describe('MCP Routes', () => { describe('GET /:serverName/auth-values', () => { const { getUserPluginAuthValue } = require('~/server/services/PluginService'); + const { mcpServersRegistry } = require('@librechat/api'); it('should return auth value flags for server', async () => { - const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({ - customUserVars: { - API_KEY: 'some-env-var', - SECRET_TOKEN: 'another-env-var', - }, - }), - }; + const mockMcpManager = {}; + mcpServersRegistry.getServerConfig.mockResolvedValue({ + customUserVars: { + API_KEY: 'some-env-var', + SECRET_TOKEN: 'another-env-var', + }, + }); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); getUserPluginAuthValue.mockResolvedValueOnce('some-api-key-value').mockResolvedValueOnce(''); @@ -1135,10 +1139,9 @@ describe('MCP Routes', () => { }); it('should return 404 when server is not found in configuration', async () => { - const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue(null), - }; + const mockMcpManager = {}; + mcpServersRegistry.getServerConfig.mockResolvedValue(null); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); const response = await request(app).get('/api/mcp/non-existent-server/auth-values'); @@ -1150,14 +1153,13 @@ describe('MCP Routes', () => { }); it('should handle errors when checking auth values', async () => { - const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({ - customUserVars: { - API_KEY: 'some-env-var', - }, - }), - }; + const mockMcpManager = {}; + mcpServersRegistry.getServerConfig.mockResolvedValue({ + customUserVars: { + API_KEY: 'some-env-var', + }, + }); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); getUserPluginAuthValue.mockRejectedValue(new Error('Database error')); @@ -1174,12 +1176,11 @@ describe('MCP Routes', () => { }); it('should return 500 when auth values check throws unexpected error', async () => { - const mockMcpManager = { - getRawConfig: jest.fn().mockImplementation(() => { - throw new Error('Config loading failed'); - }), - }; + const mockMcpManager = {}; + mcpServersRegistry.getServerConfig.mockImplementation(() => { + throw new Error('Config loading failed'); + }); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); const response = await request(app).get('/api/mcp/test-server/auth-values'); @@ -1189,12 +1190,11 @@ describe('MCP Routes', () => { }); it('should handle customUserVars that is not an object', async () => { - const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({ - customUserVars: 'not-an-object', - }), - }; + const mockMcpManager = {}; + mcpServersRegistry.getServerConfig.mockResolvedValue({ + customUserVars: 'not-an-object', + }); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); const response = await request(app).get('/api/mcp/test-server/auth-values'); @@ -1221,7 +1221,7 @@ describe('MCP Routes', () => { describe('GET /:serverName/oauth/callback - Edge Cases', () => { it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => { - const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api'); + const { MCPOAuthHandler, MCPTokenStorage, mcpServersRegistry } = require('@librechat/api'); const mockTokens = { access_token: 'edge-access-token', refresh_token: 'edge-refresh-token', @@ -1239,6 +1239,7 @@ describe('MCP Routes', () => { }); MCPOAuthHandler.completeOAuthFlow = jest.fn().mockResolvedValue(mockTokens); MCPTokenStorage.storeTokens.mockResolvedValue(); + mcpServersRegistry.getServerConfig.mockResolvedValue({}); const mockFlowManager = { completeFlow: jest.fn(), @@ -1249,7 +1250,6 @@ describe('MCP Routes', () => { getUserConnection: jest.fn().mockResolvedValue({ fetchTools: jest.fn().mockResolvedValue([]), }), - getRawConfig: jest.fn().mockReturnValue({}), }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); @@ -1264,7 +1264,7 @@ describe('MCP Routes', () => { it('should handle null cached tools in OAuth callback (triggers || {} fallback)', async () => { const { getCachedTools } = require('~/server/services/Config'); getCachedTools.mockResolvedValue(null); - const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api'); + const { MCPOAuthHandler, MCPTokenStorage, mcpServersRegistry } = require('@librechat/api'); const mockTokens = { access_token: 'edge-access-token', refresh_token: 'edge-refresh-token', @@ -1290,6 +1290,7 @@ describe('MCP Routes', () => { }); MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); MCPTokenStorage.storeTokens.mockResolvedValue(); + mcpServersRegistry.getServerConfig.mockResolvedValue({}); const mockMcpManager = { getUserConnection: jest.fn().mockResolvedValue({ @@ -1297,7 +1298,6 @@ describe('MCP Routes', () => { .fn() .mockResolvedValue([{ name: 'test-tool', description: 'Test tool' }]), }), - getRawConfig: jest.fn().mockReturnValue({}), }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); diff --git a/api/server/routes/config.js b/api/server/routes/config.js index bae5f764b0..f1d2332047 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -12,6 +12,7 @@ const { getAppConfig } = require('~/server/services/Config/app'); const { getProjectByName } = require('~/models/Project'); const { getMCPManager } = require('~/config'); const { getLogStores } = require('~/cache'); +const { mcpServersRegistry } = require('@librechat/api'); const router = express.Router(); const emailLoginEnabled = @@ -125,7 +126,7 @@ router.get('/', async function (req, res) { payload.minPasswordLength = minPasswordLength; } - const getMCPServers = () => { + const getMCPServers = async () => { try { if (appConfig?.mcpConfig == null) { return; @@ -134,9 +135,8 @@ router.get('/', async function (req, res) { if (!mcpManager) { return; } - const mcpServers = mcpManager.getAllServers(); + const mcpServers = await mcpServersRegistry.getAllServerConfigs(); if (!mcpServers) return; - const oauthServers = mcpManager.getOAuthServers(); for (const serverName in mcpServers) { if (!payload.mcpServers) { payload.mcpServers = {}; @@ -145,7 +145,7 @@ router.get('/', async function (req, res) { payload.mcpServers[serverName] = removeNullishValues({ startup: serverConfig?.startup, chatMenu: serverConfig?.chatMenu, - isOAuth: oauthServers?.has(serverName), + isOAuth: serverConfig.requiresOAuth, customUserVars: serverConfig?.customUserVars, }); } @@ -154,7 +154,7 @@ router.get('/', async function (req, res) { } }; - getMCPServers(); + await getMCPServers(); const webSearchConfig = appConfig?.webSearch; if ( webSearchConfig != null && diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index 9b66b10e52..8d6d91e8d9 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -6,6 +6,7 @@ const { MCPOAuthHandler, MCPTokenStorage, getUserMCPAuthMap, + mcpServersRegistry, } = require('@librechat/api'); const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config'); const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP'); @@ -61,11 +62,12 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => { return res.status(400).json({ error: 'Invalid flow state' }); } + const oauthHeaders = await getOAuthHeaders(serverName, userId); const { authorizationUrl, flowId: oauthFlowId } = await MCPOAuthHandler.initiateOAuthFlow( serverName, serverUrl, userId, - getOAuthHeaders(serverName), + oauthHeaders, oauthConfig, ); @@ -133,12 +135,8 @@ router.get('/:serverName/oauth/callback', async (req, res) => { }); logger.debug('[MCP OAuth] Completing OAuth flow'); - const tokens = await MCPOAuthHandler.completeOAuthFlow( - flowId, - code, - flowManager, - getOAuthHeaders(serverName), - ); + const oauthHeaders = await getOAuthHeaders(serverName, flowState.userId); + const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager, oauthHeaders); logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route'); /** Persist tokens immediately so reconnection uses fresh credentials */ @@ -356,7 +354,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`); const mcpManager = getMCPManager(); - const serverConfig = mcpManager.getRawConfig(serverName); + const serverConfig = await mcpServersRegistry.getServerConfig(serverName, user.id); if (!serverConfig) { return res.status(404).json({ error: `MCP server '${serverName}' not found in configuration`, @@ -505,8 +503,7 @@ router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => { return res.status(401).json({ error: 'User not authenticated' }); } - const mcpManager = getMCPManager(); - const serverConfig = mcpManager.getRawConfig(serverName); + const serverConfig = await mcpServersRegistry.getServerConfig(serverName, user.id); if (!serverConfig) { return res.status(404).json({ error: `MCP server '${serverName}' not found in configuration`, @@ -545,9 +542,8 @@ router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => { } }); -function getOAuthHeaders(serverName) { - const mcpManager = getMCPManager(); - const serverConfig = mcpManager.getRawConfig(serverName); +async function getOAuthHeaders(serverName, userId) { + const serverConfig = await mcpServersRegistry.getServerConfig(serverName, userId); return serverConfig?.oauth_headers ?? {}; } diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index b7975b12fa..e91e5e7904 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -25,6 +25,7 @@ const { findToken, createToken, updateToken } = require('~/models'); const { reinitMCPServer } = require('./Tools/mcp'); const { getAppConfig } = require('./Config'); const { getLogStores } = require('~/cache'); +const { mcpServersRegistry } = require('@librechat/api'); /** * @param {object} params @@ -450,7 +451,7 @@ async function getMCPSetupData(userId) { logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error); } const userConnections = mcpManager.getUserConnections(userId) || new Map(); - const oauthServers = mcpManager.getOAuthServers(); + const oauthServers = await mcpServersRegistry.getOAuthServers(); return { mcpConfig, diff --git a/api/server/services/MCP.spec.js b/api/server/services/MCP.spec.js index 7b192995e3..18857c4893 100644 --- a/api/server/services/MCP.spec.js +++ b/api/server/services/MCP.spec.js @@ -50,6 +50,9 @@ jest.mock('@librechat/api', () => ({ sendEvent: jest.fn(), normalizeServerName: jest.fn((name) => name), convertWithResolvedRefs: jest.fn((params) => params), + mcpServersRegistry: { + getOAuthServers: jest.fn(() => Promise.resolve(new Set())), + }, })); jest.mock('librechat-data-provider', () => ({ @@ -100,6 +103,7 @@ describe('tests for the new helper functions used by the MCP connection status e let mockGetFlowStateManager; let mockGetLogStores; let mockGetOAuthReconnectionManager; + let mockMcpServersRegistry; beforeEach(() => { jest.clearAllMocks(); @@ -108,6 +112,7 @@ describe('tests for the new helper functions used by the MCP connection status e mockGetFlowStateManager = require('~/config').getFlowStateManager; mockGetLogStores = require('~/cache').getLogStores; mockGetOAuthReconnectionManager = require('~/config').getOAuthReconnectionManager; + mockMcpServersRegistry = require('@librechat/api').mcpServersRegistry; }); describe('getMCPSetupData', () => { @@ -125,8 +130,8 @@ describe('tests for the new helper functions used by the MCP connection status e mockGetMCPManager.mockReturnValue({ appConnections: { getAll: jest.fn(() => new Map()) }, getUserConnections: jest.fn(() => new Map()), - getOAuthServers: jest.fn(() => new Set()), }); + mockMcpServersRegistry.getOAuthServers.mockResolvedValue(new Set()); }); it('should successfully return MCP setup data', async () => { @@ -139,9 +144,9 @@ describe('tests for the new helper functions used by the MCP connection status e const mockMCPManager = { appConnections: { getAll: jest.fn(() => mockAppConnections) }, getUserConnections: jest.fn(() => mockUserConnections), - getOAuthServers: jest.fn(() => mockOAuthServers), }; mockGetMCPManager.mockReturnValue(mockMCPManager); + mockMcpServersRegistry.getOAuthServers.mockResolvedValue(mockOAuthServers); const result = await getMCPSetupData(mockUserId); @@ -149,7 +154,7 @@ describe('tests for the new helper functions used by the MCP connection status e expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId); expect(mockMCPManager.appConnections.getAll).toHaveBeenCalled(); expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId); - expect(mockMCPManager.getOAuthServers).toHaveBeenCalled(); + expect(mockMcpServersRegistry.getOAuthServers).toHaveBeenCalled(); expect(result).toEqual({ mcpConfig: mockConfig.mcpServers, @@ -170,9 +175,9 @@ describe('tests for the new helper functions used by the MCP connection status e const mockMCPManager = { appConnections: { getAll: jest.fn(() => null) }, getUserConnections: jest.fn(() => null), - getOAuthServers: jest.fn(() => new Set()), }; mockGetMCPManager.mockReturnValue(mockMCPManager); + mockMcpServersRegistry.getOAuthServers.mockResolvedValue(new Set()); const result = await getMCPSetupData(mockUserId); diff --git a/api/server/services/initializeMCPs.js b/api/server/services/initializeMCPs.js index 397fc85202..7fdb128683 100644 --- a/api/server/services/initializeMCPs.js +++ b/api/server/services/initializeMCPs.js @@ -15,7 +15,7 @@ async function initializeMCPs() { const mcpManager = await createMCPManager(mcpServers); try { - const mcpTools = mcpManager.getAppToolFunctions() || {}; + const mcpTools = (await mcpManager.getAppToolFunctions()) || {}; await mergeAppTools(mcpTools); logger.info( diff --git a/packages/api/jest.config.mjs b/packages/api/jest.config.mjs index 1533a3d213..10fa4554e4 100644 --- a/packages/api/jest.config.mjs +++ b/packages/api/jest.config.mjs @@ -1,7 +1,13 @@ export default { collectCoverageFrom: ['src/**/*.{js,jsx,ts,tsx}', '!/node_modules/'], coveragePathIgnorePatterns: ['/node_modules/', '/dist/'], - testPathIgnorePatterns: ['/node_modules/', '/dist/', '\\.dev\\.ts$'], + testPathIgnorePatterns: [ + '/node_modules/', + '/dist/', + '\\.dev\\.ts$', + '\\.helper\\.ts$', + '\\.helper\\.d\\.ts$', + ], coverageReporters: ['text', 'cobertura'], testResultsProcessor: 'jest-junit', moduleNameMapper: { @@ -18,4 +24,4 @@ export default { // }, restoreMocks: true, testTimeout: 15000, -}; \ No newline at end of file +}; diff --git a/packages/api/package.json b/packages/api/package.json index 4d333082a3..86c2d3f42a 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -18,10 +18,11 @@ "build:dev": "npm run clean && NODE_ENV=development rollup -c --bundleConfigAsCjs", "build:watch": "NODE_ENV=development rollup -c -w --bundleConfigAsCjs", "build:watch:prod": "rollup -c -w --bundleConfigAsCjs", - "test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.\"", - "test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.\"", + "test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"", + "test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"", "test:cache-integration:core": "jest --testPathPattern=\"src/cache/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false", "test:cache-integration:cluster": "jest --testPathPattern=\"src/cluster/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false --runInBand", + "test:cache-integration:mcp": "jest --testPathPattern=\"src/mcp/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false", "verify": "npm run test:ci", "b:clean": "bun run rimraf dist", "b:build": "bun run b:clean && bun run rollup -c --silent --bundleConfigAsCjs", diff --git a/packages/api/src/index.ts b/packages/api/src/index.ts index e839a335a4..02d09797d3 100644 --- a/packages/api/src/index.ts +++ b/packages/api/src/index.ts @@ -3,6 +3,7 @@ export * from './cdn'; /* Auth */ export * from './auth'; /* MCP */ +export * from './mcp/registry/MCPServersRegistry'; export * from './mcp/MCPManager'; export * from './mcp/connection'; export * from './mcp/oauth'; diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index 5f4447b2bd..4425788cc9 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -9,6 +9,7 @@ import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth'; import { sanitizeUrlForLogging } from './utils'; import { MCPConnection } from './connection'; import { processMCPEnv } from '~/utils'; +import { withTimeout } from '~/utils/promise'; /** * Factory for creating MCP connections with optional OAuth authentication. @@ -231,14 +232,11 @@ export class MCPConnectionFactory { /** Attempts to establish connection with timeout handling */ protected async attemptToConnect(connection: MCPConnection): Promise { const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 30000; - const connectionTimeout = new Promise((_, reject) => - setTimeout( - () => reject(new Error(`Connection timeout after ${connectTimeout}ms`)), - connectTimeout, - ), + await withTimeout( + this.connectTo(connection), + connectTimeout, + `Connection timeout after ${connectTimeout}ms`, ); - const connectionAttempt = this.connectTo(connection); - await Promise.race([connectionAttempt, connectionTimeout]); if (await connection.isConnected()) return; logger.error(`${this.logPrefix} Failed to establish connection.`); diff --git a/packages/api/src/mcp/MCPManager.ts b/packages/api/src/mcp/MCPManager.ts index c6bfe77b8f..1e0d483f17 100644 --- a/packages/api/src/mcp/MCPManager.ts +++ b/packages/api/src/mcp/MCPManager.ts @@ -5,11 +5,14 @@ import type { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.j import type { TokenMethods } from '@librechat/data-schemas'; import type { FlowStateManager } from '~/flow/manager'; import type { TUser } from 'librechat-data-provider'; -import type { MCPOAuthTokens } from '~/mcp/oauth'; +import type { MCPOAuthTokens } from './oauth'; import type { RequestBody } from '~/types'; import type * as t from './types'; -import { UserConnectionManager } from '~/mcp/UserConnectionManager'; -import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; +import { UserConnectionManager } from './UserConnectionManager'; +import { ConnectionsRepository } from './ConnectionsRepository'; +import { MCPServerInspector } from './registry/MCPServerInspector'; +import { MCPServersInitializer } from './registry/MCPServersInitializer'; +import { mcpServersRegistry as registry } from './registry/MCPServersRegistry'; import { formatToolContent } from './parsers'; import { MCPConnection } from './connection'; import { processMCPEnv } from '~/utils/env'; @@ -24,8 +27,8 @@ export class MCPManager extends UserConnectionManager { /** Creates and initializes the singleton MCPManager instance */ public static async createInstance(configs: t.MCPServers): Promise { if (MCPManager.instance) throw new Error('MCPManager has already been initialized.'); - MCPManager.instance = new MCPManager(configs); - await MCPManager.instance.initialize(); + MCPManager.instance = new MCPManager(); + await MCPManager.instance.initialize(configs); return MCPManager.instance; } @@ -36,9 +39,10 @@ export class MCPManager extends UserConnectionManager { } /** Initializes the MCPManager by setting up server registry and app connections */ - public async initialize() { - await this.serversRegistry.initialize(); - this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs); + public async initialize(configs: t.MCPServers) { + await MCPServersInitializer.initialize(configs); + const appConfigs = await registry.sharedAppServers.getAll(); + this.appConnections = new ConnectionsRepository(appConfigs); } /** Retrieves an app-level or user-specific connection based on provided arguments */ @@ -62,36 +66,18 @@ export class MCPManager extends UserConnectionManager { } } - /** Get servers that require OAuth */ - public getOAuthServers(): Set { - return this.serversRegistry.oauthServers; - } - - /** Get all servers */ - public getAllServers(): t.MCPServers { - return this.serversRegistry.rawConfigs; - } - /** Returns all available tool functions from app-level connections */ - public getAppToolFunctions(): t.LCAvailableTools { - return this.serversRegistry.toolFunctions; + public async getAppToolFunctions(): Promise { + const toolFunctions: t.LCAvailableTools = {}; + const configs = await registry.getAllServerConfigs(); + for (const config of Object.values(configs)) { + if (config.toolFunctions != null) { + Object.assign(toolFunctions, config.toolFunctions); + } + } + return toolFunctions; } - /** Returns all available tool functions from all connections available to user */ - public async getAllToolFunctions(userId: string): Promise { - const allToolFunctions: t.LCAvailableTools = this.getAppToolFunctions(); - const userConnections = this.getUserConnections(userId); - if (!userConnections || userConnections.size === 0) { - return allToolFunctions; - } - - for (const [serverName, connection] of userConnections.entries()) { - const toolFunctions = await this.serversRegistry.getToolFunctions(serverName, connection); - Object.assign(allToolFunctions, toolFunctions); - } - - return allToolFunctions; - } /** Returns all available tool functions from all connections available to user */ public async getServerToolFunctions( userId: string, @@ -99,7 +85,7 @@ export class MCPManager extends UserConnectionManager { ): Promise { try { if (this.appConnections?.has(serverName)) { - return this.serversRegistry.getToolFunctions( + return MCPServerInspector.getToolFunctions( serverName, await this.appConnections.get(serverName), ); @@ -113,7 +99,7 @@ export class MCPManager extends UserConnectionManager { return null; } - return this.serversRegistry.getToolFunctions(serverName, userConnections.get(serverName)!); + return MCPServerInspector.getToolFunctions(serverName, userConnections.get(serverName)!); } catch (error) { logger.warn( `[getServerToolFunctions] Error getting tool functions for server ${serverName}`, @@ -128,8 +114,14 @@ export class MCPManager extends UserConnectionManager { * @param serverNames Optional array of server names. If not provided or empty, returns all servers. * @returns Object mapping server names to their instructions */ - public getInstructions(serverNames?: string[]): Record { - const instructions = this.serversRegistry.serverInstructions; + private async getInstructions(serverNames?: string[]): Promise> { + const instructions: Record = {}; + const configs = await registry.getAllServerConfigs(); + for (const [serverName, config] of Object.entries(configs)) { + if (config.serverInstructions != null) { + instructions[serverName] = config.serverInstructions as string; + } + } if (!serverNames) return instructions; return pick(instructions, serverNames); } @@ -139,9 +131,9 @@ export class MCPManager extends UserConnectionManager { * @param serverNames Optional array of server names to include. If not provided, includes all servers. * @returns Formatted instructions string ready for context injection */ - public formatInstructionsForContext(serverNames?: string[]): string { + public async formatInstructionsForContext(serverNames?: string[]): Promise { /** Instructions for specified servers or all stored instructions */ - const instructionsToInclude = this.getInstructions(serverNames); + const instructionsToInclude = await this.getInstructions(serverNames); if (Object.keys(instructionsToInclude).length === 0) { return ''; @@ -225,7 +217,7 @@ Please follow these instructions when using tools from the respective MCP server ); } - const rawConfig = this.getRawConfig(serverName) as t.MCPOptions; + const rawConfig = (await registry.getServerConfig(serverName, userId)) as t.MCPOptions; const currentOptions = processMCPEnv({ user, options: rawConfig, diff --git a/packages/api/src/mcp/MCPServersRegistry.ts b/packages/api/src/mcp/MCPServersRegistry.ts deleted file mode 100644 index 668ad7d2c0..0000000000 --- a/packages/api/src/mcp/MCPServersRegistry.ts +++ /dev/null @@ -1,230 +0,0 @@ -import mapValues from 'lodash/mapValues'; -import { logger } from '@librechat/data-schemas'; -import { Constants } from 'librechat-data-provider'; -import type { JsonSchemaType } from '@librechat/data-schemas'; -import type { MCPConnection } from '~/mcp/connection'; -import type * as t from '~/mcp/types'; -import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; -import { detectOAuthRequirement } from '~/mcp/oauth'; -import { sanitizeUrlForLogging } from '~/mcp/utils'; -import { processMCPEnv, isEnabled } from '~/utils'; - -const DEFAULT_MCP_INIT_TIMEOUT_MS = 30_000; - -function getMCPInitTimeout(): number { - return process.env.MCP_INIT_TIMEOUT_MS != null - ? parseInt(process.env.MCP_INIT_TIMEOUT_MS) - : DEFAULT_MCP_INIT_TIMEOUT_MS; -} - -/** - * Manages MCP server configurations and metadata discovery. - * Fetches server capabilities, OAuth requirements, and tool definitions for registry. - * Determines which servers are for app-level connections. - * Has its own connections repository. All connections are disconnected after initialization. - */ -export class MCPServersRegistry { - private initialized: boolean = false; - private connections: ConnectionsRepository; - private initTimeoutMs: number; - - public readonly rawConfigs: t.MCPServers; - public readonly parsedConfigs: Record; - - public oauthServers: Set = new Set(); - public serverInstructions: Record = {}; - public toolFunctions: t.LCAvailableTools = {}; - public appServerConfigs: t.MCPServers = {}; - - constructor(configs: t.MCPServers) { - this.rawConfigs = configs; - this.parsedConfigs = mapValues(configs, (con) => processMCPEnv({ options: con })); - this.connections = new ConnectionsRepository(configs); - this.initTimeoutMs = getMCPInitTimeout(); - } - - /** Initializes all startup-enabled servers by gathering their metadata asynchronously */ - public async initialize(): Promise { - if (this.initialized) return; - this.initialized = true; - - const serverNames = Object.keys(this.parsedConfigs); - - await Promise.allSettled( - serverNames.map((serverName) => this.initializeServerWithTimeout(serverName)), - ); - } - - /** Wraps server initialization with a timeout to prevent hanging */ - private async initializeServerWithTimeout(serverName: string): Promise { - let timeoutId: NodeJS.Timeout | null = null; - - try { - await Promise.race([ - this.initializeServer(serverName), - new Promise((_, reject) => { - timeoutId = setTimeout(() => { - reject(new Error('Server initialization timed out')); - }, this.initTimeoutMs); - }), - ]); - } catch (error) { - logger.warn(`${this.prefix(serverName)} Server initialization failed:`, error); - throw error; - } finally { - if (timeoutId != null) { - clearTimeout(timeoutId); - } - } - } - - /** Initializes a single server with all its metadata and adds it to appropriate collections */ - private async initializeServer(serverName: string): Promise { - const start = Date.now(); - - const config = this.parsedConfigs[serverName]; - - // 1. Detect OAuth requirements if not already specified - try { - await this.fetchOAuthRequirement(serverName); - - if (config.startup !== false && !config.requiresOAuth) { - await Promise.allSettled([ - this.fetchServerInstructions(serverName).catch((error) => - logger.warn(`${this.prefix(serverName)} Failed to fetch server instructions:`, error), - ), - this.fetchServerCapabilities(serverName).catch((error) => - logger.warn(`${this.prefix(serverName)} Failed to fetch server capabilities:`, error), - ), - ]); - } - } catch (error) { - logger.warn(`${this.prefix(serverName)} Failed to initialize server:`, error); - } - - // 2. Fetch tool functions for this server if a connection was established - const getToolFunctions = async (): Promise => { - try { - const loadedConns = await this.connections.getLoaded(); - const conn = loadedConns.get(serverName); - if (conn == null) { - return null; - } - return this.getToolFunctions(serverName, conn); - } catch (error) { - logger.warn(`${this.prefix(serverName)} Error fetching tool functions:`, error); - return null; - } - }; - const toolFunctions = await getToolFunctions(); - - // 3. Disconnect this server's connection if it was established (fire-and-forget) - void this.connections.disconnect(serverName); - - // 4. Side effects - // 4.1 Add to OAuth servers if needed - if (config.requiresOAuth) { - this.oauthServers.add(serverName); - } - // 4.2 Add server instructions if available - if (config.serverInstructions != null) { - this.serverInstructions[serverName] = config.serverInstructions as string; - } - // 4.3 Add to app server configs if eligible (startup enabled, non-OAuth servers) - if (config.startup !== false && config.requiresOAuth === false) { - this.appServerConfigs[serverName] = this.rawConfigs[serverName]; - } - // 4.4 Add tool functions if available - if (toolFunctions != null) { - Object.assign(this.toolFunctions, toolFunctions); - } - - const duration = Date.now() - start; - this.logUpdatedConfig(serverName, duration); - } - - /** Converts server tools to LibreChat-compatible tool functions format */ - public async getToolFunctions( - serverName: string, - conn: MCPConnection, - ): Promise { - const { tools }: t.MCPToolListResponse = await conn.client.listTools(); - - const toolFunctions: t.LCAvailableTools = {}; - tools.forEach((tool) => { - const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`; - toolFunctions[name] = { - type: 'function', - ['function']: { - name, - description: tool.description, - parameters: tool.inputSchema as JsonSchemaType, - }, - }; - }); - - return toolFunctions; - } - - /** Determines if server requires OAuth if not already specified in the config */ - private async fetchOAuthRequirement(serverName: string): Promise { - const config = this.parsedConfigs[serverName]; - if (config.requiresOAuth != null) return config.requiresOAuth; - if (config.url == null) return (config.requiresOAuth = false); - if (config.startup === false) return (config.requiresOAuth = false); - - const result = await detectOAuthRequirement(config.url); - config.requiresOAuth = result.requiresOAuth; - config.oauthMetadata = result.metadata; - return config.requiresOAuth; - } - - /** Retrieves server instructions from MCP server if enabled in the config */ - private async fetchServerInstructions(serverName: string): Promise { - const config = this.parsedConfigs[serverName]; - if (!config.serverInstructions) return; - - // If it's a string that's not "true", it's a custom instruction - if (typeof config.serverInstructions === 'string' && !isEnabled(config.serverInstructions)) { - return; - } - - // Fetch from server if true (boolean) or "true" (string) - const conn = await this.connections.get(serverName); - config.serverInstructions = conn.client.getInstructions(); - if (!config.serverInstructions) { - logger.warn(`${this.prefix(serverName)} No server instructions available`); - } - } - - /** Fetches server capabilities and available tools list */ - private async fetchServerCapabilities(serverName: string): Promise { - const config = this.parsedConfigs[serverName]; - const conn = await this.connections.get(serverName); - const capabilities = conn.client.getServerCapabilities(); - if (!capabilities) return; - config.capabilities = JSON.stringify(capabilities); - if (!capabilities.tools) return; - const tools = await conn.client.listTools(); - config.tools = tools.tools.map((tool) => tool.name).join(', '); - } - - // Logs server configuration summary after initialization - private logUpdatedConfig(serverName: string, initDuration: number): void { - const prefix = this.prefix(serverName); - const config = this.parsedConfigs[serverName]; - logger.info(`${prefix} -------------------------------------------------┐`); - logger.info(`${prefix} URL: ${config.url ? sanitizeUrlForLogging(config.url) : 'N/A'}`); - logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`); - logger.info(`${prefix} Capabilities: ${config.capabilities}`); - logger.info(`${prefix} Tools: ${config.tools}`); - logger.info(`${prefix} Server Instructions: ${config.serverInstructions}`); - logger.info(`${prefix} Initialized in: ${initDuration}ms`); - logger.info(`${prefix} -------------------------------------------------┘`); - } - - // Returns formatted log prefix for server messages - private prefix(serverName: string): string { - return `[MCP][${serverName}]`; - } -} diff --git a/packages/api/src/mcp/UserConnectionManager.ts b/packages/api/src/mcp/UserConnectionManager.ts index 7f5862b2a8..21c177dc7c 100644 --- a/packages/api/src/mcp/UserConnectionManager.ts +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -1,7 +1,7 @@ import { logger } from '@librechat/data-schemas'; import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; -import { MCPServersRegistry } from '~/mcp/MCPServersRegistry'; +import { mcpServersRegistry as serversRegistry } from '~/mcp/registry/MCPServersRegistry'; import { MCPConnection } from './connection'; import type * as t from './types'; import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; @@ -14,7 +14,6 @@ import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; * https://github.com/danny-avila/LibreChat/discussions/8790 */ export abstract class UserConnectionManager { - protected readonly serversRegistry: MCPServersRegistry; // Connections shared by all users. public appConnections: ConnectionsRepository | null = null; // Connections per userId -> serverName -> connection @@ -23,15 +22,6 @@ export abstract class UserConnectionManager { protected userLastActivity: Map = new Map(); protected readonly USER_CONNECTION_IDLE_TIMEOUT = 15 * 60 * 1000; // 15 minutes (TODO: make configurable) - constructor(serverConfigs: t.MCPServers) { - this.serversRegistry = new MCPServersRegistry(serverConfigs); - } - - /** fetches am MCP Server config from the registry */ - public getRawConfig(serverName: string): t.MCPOptions | undefined { - return this.serversRegistry.rawConfigs[serverName]; - } - /** Updates the last activity timestamp for a user */ protected updateUserLastActivity(userId: string): void { const now = Date.now(); @@ -106,7 +96,7 @@ export abstract class UserConnectionManager { logger.info(`[MCP][User: ${userId}][${serverName}] Establishing new connection`); } - const config = this.serversRegistry.parsedConfigs[serverName]; + const config = await serversRegistry.getServerConfig(serverName, userId); if (!config) { throw new McpError( ErrorCode.InvalidRequest, diff --git a/packages/api/src/mcp/__tests__/MCPManager.test.ts b/packages/api/src/mcp/__tests__/MCPManager.test.ts index 4d60a16954..ff0ba8ad3b 100644 --- a/packages/api/src/mcp/__tests__/MCPManager.test.ts +++ b/packages/api/src/mcp/__tests__/MCPManager.test.ts @@ -1,7 +1,9 @@ import { logger } from '@librechat/data-schemas'; import type * as t from '~/mcp/types'; import { MCPManager } from '~/mcp/MCPManager'; -import { MCPServersRegistry } from '~/mcp/MCPServersRegistry'; +import { mcpServersRegistry } from '~/mcp/registry/MCPServersRegistry'; +import { MCPServersInitializer } from '~/mcp/registry/MCPServersInitializer'; +import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector'; import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; import { MCPConnection } from '../connection'; @@ -15,7 +17,24 @@ jest.mock('@librechat/data-schemas', () => ({ }, })); -jest.mock('~/mcp/MCPServersRegistry'); +jest.mock('~/mcp/registry/MCPServersRegistry', () => ({ + mcpServersRegistry: { + sharedAppServers: { + getAll: jest.fn(), + }, + getServerConfig: jest.fn(), + getAllServerConfigs: jest.fn(), + getOAuthServers: jest.fn(), + }, +})); + +jest.mock('~/mcp/registry/MCPServersInitializer', () => ({ + MCPServersInitializer: { + initialize: jest.fn(), + }, +})); + +jest.mock('~/mcp/registry/MCPServerInspector'); jest.mock('~/mcp/ConnectionsRepository'); const mockLogger = logger as jest.Mocked; @@ -28,20 +47,12 @@ describe('MCPManager', () => { // Reset MCPManager singleton state (MCPManager as unknown as { instance: null }).instance = null; jest.clearAllMocks(); - }); - function mockRegistry( - registryConfig: Partial, - ): jest.MockedClass { - const mock = { - initialize: jest.fn().mockResolvedValue(undefined), - getToolFunctions: jest.fn().mockResolvedValue(null), - ...registryConfig, - }; - return (MCPServersRegistry as jest.MockedClass).mockImplementation( - () => mock as unknown as MCPServersRegistry, - ); - } + // Set up default mock implementations + (MCPServersInitializer.initialize as jest.Mock).mockResolvedValue(undefined); + (mcpServersRegistry.sharedAppServers.getAll as jest.Mock).mockResolvedValue({}); + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({}); + }); function mockAppConnections( appConnectionsConfig: Partial, @@ -66,12 +77,229 @@ describe('MCPManager', () => { }; } + describe('getAppToolFunctions', () => { + it('should return empty object when no servers have tool functions', async () => { + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({ + server1: { type: 'stdio', command: 'test', args: [] }, + server2: { type: 'stdio', command: 'test2', args: [] }, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.getAppToolFunctions(); + + expect(result).toEqual({}); + }); + + it('should collect tool functions from multiple servers', async () => { + const toolFunctions1 = { + tool1_mcp_server1: { + type: 'function' as const, + function: { + name: 'tool1_mcp_server1', + description: 'Tool 1', + parameters: { type: 'object' as const }, + }, + }, + }; + + const toolFunctions2 = { + tool2_mcp_server2: { + type: 'function' as const, + function: { + name: 'tool2_mcp_server2', + description: 'Tool 2', + parameters: { type: 'object' as const }, + }, + }, + }; + + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({ + server1: { + type: 'stdio', + command: 'test', + args: [], + toolFunctions: toolFunctions1, + }, + server2: { + type: 'stdio', + command: 'test2', + args: [], + toolFunctions: toolFunctions2, + }, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.getAppToolFunctions(); + + expect(result).toEqual({ + ...toolFunctions1, + ...toolFunctions2, + }); + }); + + it('should handle servers with null or undefined toolFunctions', async () => { + const toolFunctions1 = { + tool1_mcp_server1: { + type: 'function' as const, + function: { + name: 'tool1_mcp_server1', + description: 'Tool 1', + parameters: { type: 'object' as const }, + }, + }, + }; + + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({ + server1: { + type: 'stdio', + command: 'test', + args: [], + toolFunctions: toolFunctions1, + }, + server2: { + type: 'stdio', + command: 'test2', + args: [], + toolFunctions: null, + }, + server3: { + type: 'stdio', + command: 'test3', + args: [], + }, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.getAppToolFunctions(); + + expect(result).toEqual(toolFunctions1); + }); + }); + + describe('formatInstructionsForContext', () => { + it('should return empty string when no servers have instructions', async () => { + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({ + server1: { type: 'stdio', command: 'test', args: [] }, + server2: { type: 'stdio', command: 'test2', args: [] }, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.formatInstructionsForContext(); + + expect(result).toBe(''); + }); + + it('should format instructions from multiple servers', async () => { + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({ + github: { + type: 'sse', + url: 'https://api.github.com', + serverInstructions: 'Use GitHub API with care', + }, + files: { + type: 'stdio', + command: 'node', + args: ['files.js'], + serverInstructions: 'Only read/write files in allowed directories', + }, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.formatInstructionsForContext(); + + expect(result).toContain('# MCP Server Instructions'); + expect(result).toContain('## github MCP Server Instructions'); + expect(result).toContain('Use GitHub API with care'); + expect(result).toContain('## files MCP Server Instructions'); + expect(result).toContain('Only read/write files in allowed directories'); + }); + + it('should filter instructions by server names when provided', async () => { + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({ + github: { + type: 'sse', + url: 'https://api.github.com', + serverInstructions: 'Use GitHub API with care', + }, + files: { + type: 'stdio', + command: 'node', + args: ['files.js'], + serverInstructions: 'Only read/write files in allowed directories', + }, + database: { + type: 'stdio', + command: 'node', + args: ['db.js'], + serverInstructions: 'Be careful with database operations', + }, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.formatInstructionsForContext(['github', 'database']); + + expect(result).toContain('## github MCP Server Instructions'); + expect(result).toContain('Use GitHub API with care'); + expect(result).toContain('## database MCP Server Instructions'); + expect(result).toContain('Be careful with database operations'); + expect(result).not.toContain('files'); + expect(result).not.toContain('Only read/write files in allowed directories'); + }); + + it('should handle servers with null or undefined instructions', async () => { + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({ + github: { + type: 'sse', + url: 'https://api.github.com', + serverInstructions: 'Use GitHub API with care', + }, + files: { + type: 'stdio', + command: 'node', + args: ['files.js'], + serverInstructions: null, + }, + database: { + type: 'stdio', + command: 'node', + args: ['db.js'], + }, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.formatInstructionsForContext(); + + expect(result).toContain('## github MCP Server Instructions'); + expect(result).toContain('Use GitHub API with care'); + expect(result).not.toContain('files'); + expect(result).not.toContain('database'); + }); + + it('should return empty string when filtered servers have no instructions', async () => { + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({ + github: { + type: 'sse', + url: 'https://api.github.com', + serverInstructions: 'Use GitHub API with care', + }, + files: { + type: 'stdio', + command: 'node', + args: ['files.js'], + }, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.formatInstructionsForContext(['files']); + + expect(result).toBe(''); + }); + }); + describe('getServerToolFunctions', () => { it('should catch and handle errors gracefully', async () => { - mockRegistry({ - getToolFunctions: jest.fn(() => { - throw new Error('Connection failed'); - }), + (MCPServerInspector.getToolFunctions as jest.Mock) = jest.fn(() => { + throw new Error('Connection failed'); }); mockAppConnections({ @@ -90,9 +318,7 @@ describe('MCPManager', () => { }); it('should catch synchronous errors from getUserConnections', async () => { - mockRegistry({ - getToolFunctions: jest.fn().mockResolvedValue({}), - }); + (MCPServerInspector.getToolFunctions as jest.Mock) = jest.fn().mockResolvedValue({}); mockAppConnections({ has: jest.fn().mockReturnValue(false), @@ -126,9 +352,9 @@ describe('MCPManager', () => { }, }; - mockRegistry({ - getToolFunctions: jest.fn().mockResolvedValue(expectedTools), - }); + (MCPServerInspector.getToolFunctions as jest.Mock) = jest + .fn() + .mockResolvedValue(expectedTools); mockAppConnections({ has: jest.fn().mockReturnValue(true), @@ -145,10 +371,8 @@ describe('MCPManager', () => { it('should include specific server name in error messages', async () => { const specificServerName = 'github_mcp_server'; - mockRegistry({ - getToolFunctions: jest.fn(() => { - throw new Error('Server specific error'); - }), + (MCPServerInspector.getToolFunctions as jest.Mock) = jest.fn(() => { + throw new Error('Server specific error'); }); mockAppConnections({ diff --git a/packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts b/packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts deleted file mode 100644 index ade8eab32c..0000000000 --- a/packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts +++ /dev/null @@ -1,595 +0,0 @@ -import { join } from 'path'; -import { readFileSync } from 'fs'; -import { load as yamlLoad } from 'js-yaml'; -import { logger } from '@librechat/data-schemas'; -import type { OAuthDetectionResult } from '~/mcp/oauth/detectOAuth'; -import type * as t from '~/mcp/types'; -import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; -import { MCPServersRegistry } from '~/mcp/MCPServersRegistry'; -import { detectOAuthRequirement } from '~/mcp/oauth'; -import { MCPConnection } from '~/mcp/connection'; - -// Mock external dependencies -jest.mock('../oauth/detectOAuth'); -jest.mock('../ConnectionsRepository'); -jest.mock('../connection'); -jest.mock('@librechat/data-schemas', () => ({ - logger: { - info: jest.fn(), - warn: jest.fn(), - error: jest.fn(), - debug: jest.fn(), - }, -})); - -// Mock processMCPEnv to verify it's called and adds a processed marker -jest.mock('~/utils', () => ({ - ...jest.requireActual('~/utils'), - processMCPEnv: jest.fn(({ options }) => ({ - ...options, - _processed: true, // Simple marker to verify processing occurred - })), -})); - -const mockDetectOAuthRequirement = detectOAuthRequirement as jest.MockedFunction< - typeof detectOAuthRequirement ->; -const mockLogger = logger as jest.Mocked; - -describe('MCPServersRegistry - Initialize Function', () => { - let rawConfigs: t.MCPServers; - let expectedParsedConfigs: Record; - let mockConnectionsRepo: jest.Mocked; - let mockConnections: Map>; - - beforeEach(() => { - // Load fixtures - const rawConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.rawConfigs.yml'); - const parsedConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.parsedConfigs.yml'); - - rawConfigs = yamlLoad(readFileSync(rawConfigsPath, 'utf8')) as t.MCPServers; - expectedParsedConfigs = yamlLoad(readFileSync(parsedConfigsPath, 'utf8')) as Record< - string, - t.ParsedServerConfig - >; - - // Setup mock connections - mockConnections = new Map(); - const serverNames = Object.keys(rawConfigs); - - serverNames.forEach((serverName) => { - const mockClient = { - listTools: jest.fn(), - getInstructions: jest.fn(), - getServerCapabilities: jest.fn(), - }; - const mockConnection = { - client: mockClient, - } as unknown as jest.Mocked; - - // Setup mock responses based on expected configs - const expectedConfig = expectedParsedConfigs[serverName]; - - // Mock listTools response - if (expectedConfig.tools) { - const toolNames = expectedConfig.tools.split(', '); - const tools = toolNames.map((name: string) => ({ - name, - description: `Description for ${name}`, - inputSchema: { - type: 'object' as const, - properties: { - input: { type: 'string' }, - }, - }, - })); - (mockClient.listTools as jest.Mock).mockResolvedValue({ tools }); - } else { - (mockClient.listTools as jest.Mock).mockResolvedValue({ tools: [] }); - } - - // Mock getInstructions response - if (expectedConfig.serverInstructions) { - (mockClient.getInstructions as jest.Mock).mockReturnValue( - expectedConfig.serverInstructions as string, - ); - } else { - (mockClient.getInstructions as jest.Mock).mockReturnValue(undefined); - } - - // Mock getServerCapabilities response - if (expectedConfig.capabilities) { - const capabilities = JSON.parse(expectedConfig.capabilities) as Record; - (mockClient.getServerCapabilities as jest.Mock).mockReturnValue(capabilities); - } else { - (mockClient.getServerCapabilities as jest.Mock).mockReturnValue(undefined); - } - - mockConnections.set(serverName, mockConnection); - }); - - // Setup ConnectionsRepository mock - mockConnectionsRepo = { - get: jest.fn(), - getLoaded: jest.fn(), - disconnectAll: jest.fn(), - disconnect: jest.fn().mockResolvedValue(undefined), - } as unknown as jest.Mocked; - - mockConnectionsRepo.get.mockImplementation((serverName: string) => { - const connection = mockConnections.get(serverName); - if (!connection) { - throw new Error(`Connection not found for server: ${serverName}`); - } - return Promise.resolve(connection); - }); - - mockConnectionsRepo.getLoaded.mockResolvedValue(mockConnections); - - (ConnectionsRepository as jest.Mock).mockImplementation(() => mockConnectionsRepo); - - // Setup OAuth detection mock with deterministic results - mockDetectOAuthRequirement.mockImplementation((url: string) => { - const oauthResults: Record = { - 'https://api.github.com/mcp': { - requiresOAuth: true, - method: 'protected-resource-metadata', - metadata: { - authorization_url: 'https://github.com/login/oauth/authorize', - token_url: 'https://github.com/login/oauth/access_token', - }, - }, - 'https://api.disabled.com/mcp': { - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }, - 'https://api.public.com/mcp': { - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }, - }; - - return Promise.resolve( - oauthResults[url] || { requiresOAuth: false, method: 'no-metadata-found', metadata: null }, - ); - }); - - // Clear all mocks - jest.clearAllMocks(); - }); - - afterEach(() => { - delete process.env.MCP_INIT_TIMEOUT_MS; - jest.clearAllMocks(); - }); - - describe('initialize() method', () => { - it('should only run initialization once', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - await registry.initialize(); - await registry.initialize(); // Second call should not re-run - - // Verify that connections are only requested for servers that need them - // (servers with serverInstructions=true or all servers for capabilities) - expect(mockConnectionsRepo.get).toHaveBeenCalled(); - }); - - it('should set all public properties correctly after initialization', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - // Verify initial state - expect(registry.oauthServers.size).toBe(0); - expect(registry.serverInstructions).toEqual({}); - expect(registry.toolFunctions).toEqual({}); - expect(registry.appServerConfigs).toEqual({}); - - await registry.initialize(); - - // Test oauthServers Set - expect(registry.oauthServers).toEqual( - new Set(['oauth_server', 'oauth_predefined', 'oauth_startup_enabled']), - ); - - // Test serverInstructions - OAuth servers keep their original boolean value, non-OAuth fetch actual strings - expect(registry.serverInstructions).toEqual({ - stdio_server: 'Follow these instructions for stdio server', - oauth_server: true, - non_oauth_server: 'Public API instructions', - }); - - // Test appServerConfigs (startup enabled, non-OAuth servers only) - expect(registry.appServerConfigs).toEqual({ - stdio_server: rawConfigs.stdio_server, - websocket_server: rawConfigs.websocket_server, - non_oauth_server: rawConfigs.non_oauth_server, - }); - - // Test toolFunctions (only non-OAuth servers get their tools fetched during initialization) - const expectedToolFunctions = { - file_read_mcp_stdio_server: { - type: 'function', - function: { - name: 'file_read_mcp_stdio_server', - description: 'Description for file_read', - parameters: { type: 'object', properties: { input: { type: 'string' } } }, - }, - }, - file_write_mcp_stdio_server: { - type: 'function', - function: { - name: 'file_write_mcp_stdio_server', - description: 'Description for file_write', - parameters: { type: 'object', properties: { input: { type: 'string' } } }, - }, - }, - }; - expect(registry.toolFunctions).toEqual(expectedToolFunctions); - }); - - it('should handle errors gracefully and continue initialization of other servers', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - // Make one specific server throw an error during OAuth detection - mockDetectOAuthRequirement.mockImplementation((url: string) => { - if (url === 'https://api.github.com/mcp') { - return Promise.reject(new Error('OAuth detection failed')); - } - // Return normal responses for other servers - const oauthResults: Record = { - 'https://api.disabled.com/mcp': { - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }, - 'https://api.public.com/mcp': { - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }, - }; - return Promise.resolve( - oauthResults[url] ?? { - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }, - ); - }); - - await registry.initialize(); - - // Should still initialize successfully for other servers - expect(registry.oauthServers).toBeInstanceOf(Set); - expect(registry.toolFunctions).toBeDefined(); - - // The failed server should not be in oauthServers (since it failed OAuth detection) - expect(registry.oauthServers.has('oauth_server')).toBe(false); - - // But other servers should still be processed successfully - expect(registry.appServerConfigs).toHaveProperty('stdio_server'); - expect(registry.appServerConfigs).toHaveProperty('non_oauth_server'); - - // Error should be logged as a warning at the higher level - expect(mockLogger.warn).toHaveBeenCalledWith( - expect.stringContaining('[MCP][oauth_server] Failed to initialize server:'), - expect.any(Error), - ); - }); - - it('should disconnect individual connections after each server initialization', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - await registry.initialize(); - - // Verify disconnect was called for each server during initialization - // All servers attempt to connect during initialization for metadata gathering - const serverNames = Object.keys(rawConfigs); - expect(mockConnectionsRepo.disconnect).toHaveBeenCalledTimes(serverNames.length); - }); - - it('should log configuration updates for each startup-enabled server', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - await registry.initialize(); - - const serverNames = Object.keys(rawConfigs); - serverNames.forEach((serverName) => { - expect(mockLogger.info).toHaveBeenCalledWith( - expect.stringContaining(`[MCP][${serverName}] URL:`), - ); - expect(mockLogger.info).toHaveBeenCalledWith( - expect.stringContaining(`[MCP][${serverName}] OAuth Required:`), - ); - expect(mockLogger.info).toHaveBeenCalledWith( - expect.stringContaining(`[MCP][${serverName}] Capabilities:`), - ); - expect(mockLogger.info).toHaveBeenCalledWith( - expect.stringContaining(`[MCP][${serverName}] Tools:`), - ); - expect(mockLogger.info).toHaveBeenCalledWith( - expect.stringContaining(`[MCP][${serverName}] Server Instructions:`), - ); - }); - }); - - it('should have parsedConfigs matching the expected fixture after initialization', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - await registry.initialize(); - - // Compare the actual parsedConfigs against the expected fixture - expect(registry.parsedConfigs).toEqual(expectedParsedConfigs); - }); - - it('should handle serverInstructions as string "true" correctly and fetch from server', async () => { - // Create test config with serverInstructions as string "true" - const testConfig: t.MCPServers = { - test_server_string_true: { - type: 'stdio', - args: [], - command: 'test-command', - serverInstructions: 'true', // Simulating string "true" from YAML parsing - }, - test_server_custom_string: { - type: 'stdio', - args: [], - command: 'test-command', - serverInstructions: 'Custom instructions here', - }, - test_server_bool_true: { - type: 'stdio', - args: [], - command: 'test-command', - serverInstructions: true, - }, - }; - - const registry = new MCPServersRegistry(testConfig); - - // Setup mock connection for servers that should fetch - const mockClient = { - listTools: jest.fn().mockResolvedValue({ tools: [] }), - getInstructions: jest.fn().mockReturnValue('Fetched instructions from server'), - getServerCapabilities: jest.fn().mockReturnValue({ tools: {} }), - }; - const mockConnection = { - client: mockClient, - } as unknown as jest.Mocked; - - mockConnectionsRepo.get.mockResolvedValue(mockConnection); - mockConnectionsRepo.getLoaded.mockResolvedValue( - new Map([ - ['test_server_string_true', mockConnection], - ['test_server_bool_true', mockConnection], - ]), - ); - mockDetectOAuthRequirement.mockResolvedValue({ - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }); - - await registry.initialize(); - - // Verify that string "true" was treated as fetch-from-server - expect(registry.parsedConfigs['test_server_string_true'].serverInstructions).toBe( - 'Fetched instructions from server', - ); - - // Verify that custom string was kept as-is - expect(registry.parsedConfigs['test_server_custom_string'].serverInstructions).toBe( - 'Custom instructions here', - ); - - // Verify that boolean true also fetched from server - expect(registry.parsedConfigs['test_server_bool_true'].serverInstructions).toBe( - 'Fetched instructions from server', - ); - - // Verify getInstructions was called for both "true" cases - expect(mockClient.getInstructions).toHaveBeenCalledTimes(2); - }); - - it('should use Promise.allSettled for individual server initialization', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - // Spy on Promise.allSettled to verify it's being used - const allSettledSpy = jest.spyOn(Promise, 'allSettled'); - - await registry.initialize(); - - // Verify Promise.allSettled was called with an array of server initialization promises - expect(allSettledSpy).toHaveBeenCalledWith(expect.arrayContaining([expect.any(Promise)])); - - // Verify it was called with the correct number of server promises - const serverNames = Object.keys(rawConfigs); - expect(allSettledSpy).toHaveBeenCalledWith( - expect.arrayContaining(new Array(serverNames.length).fill(expect.any(Promise))), - ); - - allSettledSpy.mockRestore(); - }); - - it('should isolate server failures and not affect other servers', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - // Make multiple servers fail in different ways - mockConnectionsRepo.get.mockImplementation((serverName: string) => { - if (serverName === 'stdio_server') { - // First server fails - throw new Error('Connection failed for stdio_server'); - } - if (serverName === 'websocket_server') { - // Second server fails - throw new Error('Connection failed for websocket_server'); - } - // Other servers succeed - const connection = mockConnections.get(serverName); - if (!connection) { - throw new Error(`Connection not found for server: ${serverName}`); - } - return Promise.resolve(connection); - }); - - await registry.initialize(); - - // Despite failures, initialization should complete - expect(registry.oauthServers).toBeInstanceOf(Set); - expect(registry.toolFunctions).toBeDefined(); - - // Successful servers should still be processed - expect(registry.appServerConfigs).toHaveProperty('non_oauth_server'); - - // Failed servers should not crash the whole initialization - expect(mockLogger.warn).toHaveBeenCalledWith( - expect.stringContaining('[MCP][stdio_server] Failed to fetch server capabilities:'), - expect.any(Error), - ); - expect(mockLogger.warn).toHaveBeenCalledWith( - expect.stringContaining('[MCP][websocket_server] Failed to fetch server capabilities:'), - expect.any(Error), - ); - }); - - it('should properly clean up connections even when some servers fail', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - // Track disconnect failures but suppress unhandled rejections - const disconnectErrors: Error[] = []; - mockConnectionsRepo.disconnect.mockImplementation((serverName: string) => { - if (serverName === 'stdio_server') { - const error = new Error('Disconnect failed'); - disconnectErrors.push(error); - return Promise.reject(error).catch(() => {}); // Suppress unhandled rejection - } - return Promise.resolve(); - }); - - await registry.initialize(); - - // Should still attempt to disconnect all servers during initialization - const serverNames = Object.keys(rawConfigs); - expect(mockConnectionsRepo.disconnect).toHaveBeenCalledTimes(serverNames.length); - expect(disconnectErrors).toHaveLength(1); - }); - - it('should timeout individual server initialization after configured timeout', async () => { - const timeout = 2000; - // Create registry with a short timeout for testing - process.env.MCP_INIT_TIMEOUT_MS = `${timeout}`; - - const registry = new MCPServersRegistry(rawConfigs); - - // Make one server hang indefinitely during OAuth detection - mockDetectOAuthRequirement.mockImplementation((url: string) => { - if (url === 'https://api.github.com/mcp') { - // Slow init - return new Promise((res) => setTimeout(res, timeout * 2)); - } - // Return normal responses for other servers - return Promise.resolve({ - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }); - }); - - const start = Date.now(); - await registry.initialize(); - const duration = Date.now() - start; - - // Should complete within reasonable time despite one server hanging - // Allow some buffer for test execution overhead - expect(duration).toBeLessThan(timeout * 1.5); - - // The timeout should prevent the hanging server from blocking initialization - // Other servers should still be processed successfully - expect(registry.appServerConfigs).toHaveProperty('stdio_server'); - expect(registry.appServerConfigs).toHaveProperty('non_oauth_server'); - }, 10_000); // 10 second Jest timeout - - it('should skip tool function fetching if connection was not established', async () => { - const testConfig: t.MCPServers = { - server_with_connection: { - type: 'stdio', - args: [], - command: 'test-command', - }, - server_without_connection: { - type: 'stdio', - args: [], - command: 'failing-command', - }, - }; - - const registry = new MCPServersRegistry(testConfig); - - const mockClient = { - listTools: jest.fn().mockResolvedValue({ - tools: [ - { - name: 'test_tool', - description: 'Test tool', - inputSchema: { type: 'object', properties: {} }, - }, - ], - }), - getInstructions: jest.fn().mockReturnValue(undefined), - getServerCapabilities: jest.fn().mockReturnValue({ tools: {} }), - }; - const mockConnection = { - client: mockClient, - } as unknown as jest.Mocked; - - mockConnectionsRepo.get.mockImplementation((serverName: string) => { - if (serverName === 'server_with_connection') { - return Promise.resolve(mockConnection); - } - throw new Error('Connection failed'); - }); - - // Mock getLoaded to return connections map - the real implementation returns all loaded connections at once - mockConnectionsRepo.getLoaded.mockResolvedValue( - new Map([['server_with_connection', mockConnection]]), - ); - - mockDetectOAuthRequirement.mockResolvedValue({ - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }); - - await registry.initialize(); - - expect(registry.toolFunctions).toHaveProperty('test_tool_mcp_server_with_connection'); - expect(Object.keys(registry.toolFunctions)).toHaveLength(1); - }); - - it('should handle getLoaded returning empty map gracefully', async () => { - const testConfig: t.MCPServers = { - test_server: { - type: 'stdio', - args: [], - command: 'test-command', - }, - }; - - const registry = new MCPServersRegistry(testConfig); - - mockConnectionsRepo.get.mockRejectedValue(new Error('All connections failed')); - mockConnectionsRepo.getLoaded.mockResolvedValue(new Map()); - mockDetectOAuthRequirement.mockResolvedValue({ - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }); - - await registry.initialize(); - - expect(registry.toolFunctions).toEqual({}); - }); - }); -}); diff --git a/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.parsedConfigs.yml b/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.parsedConfigs.yml deleted file mode 100644 index 71b3e01d22..0000000000 --- a/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.parsedConfigs.yml +++ /dev/null @@ -1,67 +0,0 @@ -# Expected parsed MCP server configurations after running initialize() -# These represent the expected state of parsedConfigs after all fetch functions complete - -oauth_server: - _processed: true - type: "streamable-http" - url: "https://api.github.com/mcp" - headers: - Authorization: "Bearer {{GITHUB_TOKEN}}" - serverInstructions: true - requiresOAuth: true - oauthMetadata: - authorization_url: "https://github.com/login/oauth/authorize" - token_url: "https://github.com/login/oauth/access_token" - -oauth_predefined: - _processed: true - type: "sse" - url: "https://api.example.com/sse" - requiresOAuth: true - oauthMetadata: - authorization_url: "https://example.com/oauth/authorize" - token_url: "https://example.com/oauth/token" - -stdio_server: - _processed: true - command: "node" - args: ["server.js"] - env: - API_KEY: "${TEST_API_KEY}" - startup: true - serverInstructions: "Follow these instructions for stdio server" - requiresOAuth: false - capabilities: '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{}}' - tools: "file_read, file_write" - -websocket_server: - _processed: true - type: "websocket" - url: "ws://localhost:3001/mcp" - startup: true - requiresOAuth: false - oauthMetadata: null - capabilities: '{"tools":{},"resources":{},"prompts":{}}' - tools: "" - -disabled_server: - _processed: true - requiresOAuth: false - type: "streamable-http" - url: "https://api.disabled.com/mcp" - startup: false - -non_oauth_server: - _processed: true - type: "streamable-http" - url: "https://api.public.com/mcp" - requiresOAuth: false - serverInstructions: "Public API instructions" - capabilities: '{"tools":{},"resources":{},"prompts":{}}' - tools: "" - -oauth_startup_enabled: - _processed: true - type: "sse" - url: "https://api.oauth-startup.com/sse" - requiresOAuth: true \ No newline at end of file diff --git a/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.rawConfigs.yml b/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.rawConfigs.yml deleted file mode 100644 index 907dfaa96b..0000000000 --- a/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.rawConfigs.yml +++ /dev/null @@ -1,53 +0,0 @@ -# Raw MCP server configurations used as input to MCPServersRegistry constructor -# These configs test different code paths in the initialization process - -# Test OAuth detection with URL - should trigger fetchOAuthRequirement -oauth_server: - type: "streamable-http" - url: "https://api.github.com/mcp" - headers: - Authorization: "Bearer {{GITHUB_TOKEN}}" - serverInstructions: true - -# Test OAuth already specified - should skip OAuth detection -oauth_predefined: - type: "sse" - url: "https://api.example.com/sse" - requiresOAuth: true - oauthMetadata: - authorization_url: "https://example.com/oauth/authorize" - token_url: "https://example.com/oauth/token" - -# Test stdio server without URL - should set requiresOAuth to false -stdio_server: - command: "node" - args: ["server.js"] - env: - API_KEY: "${TEST_API_KEY}" - startup: true - serverInstructions: "Follow these instructions for stdio server" - -# Test websocket server with capabilities but no tools -websocket_server: - type: "websocket" - url: "ws://localhost:3001/mcp" - startup: true - -# Test server with startup disabled - should not be included in appServerConfigs -disabled_server: - type: "streamable-http" - url: "https://api.disabled.com/mcp" - startup: false - -# Test non-OAuth server - should be included in appServerConfigs -non_oauth_server: - type: "streamable-http" - url: "https://api.public.com/mcp" - requiresOAuth: false - serverInstructions: true - -# Test server with OAuth but startup enabled - should not be in appServerConfigs -oauth_startup_enabled: - type: "sse" - url: "https://api.oauth-startup.com/sse" - requiresOAuth: true \ No newline at end of file diff --git a/packages/api/src/mcp/connection.ts b/packages/api/src/mcp/connection.ts index ad1b3e32aa..7e75acf751 100644 --- a/packages/api/src/mcp/connection.ts +++ b/packages/api/src/mcp/connection.ts @@ -18,6 +18,7 @@ import type { Response as UndiciResponse, } from 'undici'; import type { MCPOAuthTokens } from './oauth/types'; +import { withTimeout } from '~/utils/promise'; import type * as t from './types'; import { sanitizeUrlForLogging } from './utils'; import { mcpConfig } from './mcpConfig'; @@ -457,15 +458,11 @@ export class MCPConnection extends EventEmitter { this.setupTransportDebugHandlers(); const connectTimeout = this.options.initTimeout ?? 120000; - await Promise.race([ + await withTimeout( this.client.connect(this.transport), - new Promise((_resolve, reject) => - setTimeout( - () => reject(new Error(`Connection timeout after ${connectTimeout}ms`)), - connectTimeout, - ), - ), - ]); + connectTimeout, + `Connection timeout after ${connectTimeout}ms`, + ); this.connectionState = 'connected'; this.emit('connectionChange', 'connected'); diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts b/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts index d2295191cf..f9a3c7ab73 100644 --- a/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts +++ b/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts @@ -1,6 +1,7 @@ import { TokenMethods } from '@librechat/data-schemas'; import { FlowStateManager, MCPConnection, MCPOAuthTokens, MCPOptions } from '../..'; import { MCPManager } from '../MCPManager'; +import { mcpServersRegistry } from '../../mcp/registry/MCPServersRegistry'; import { OAuthReconnectionManager } from './OAuthReconnectionManager'; import { OAuthReconnectionTracker } from './OAuthReconnectionTracker'; @@ -14,6 +15,12 @@ jest.mock('@librechat/data-schemas', () => ({ })); jest.mock('../MCPManager'); +jest.mock('../../mcp/registry/MCPServersRegistry', () => ({ + mcpServersRegistry: { + getServerConfig: jest.fn(), + getOAuthServers: jest.fn(), + }, +})); describe('OAuthReconnectionManager', () => { let flowManager: jest.Mocked>; @@ -51,10 +58,10 @@ describe('OAuthReconnectionManager', () => { getUserConnection: jest.fn(), getUserConnections: jest.fn(), disconnectUserConnection: jest.fn(), - getRawConfig: jest.fn(), } as unknown as jest.Mocked; (MCPManager.getInstance as jest.Mock).mockReturnValue(mockMCPManager); + (mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue({}); }); afterEach(() => { @@ -152,7 +159,7 @@ describe('OAuthReconnectionManager', () => { it('should reconnect eligible servers', async () => { const userId = 'user-123'; const oauthServers = new Set(['server1', 'server2', 'server3']); - mockMCPManager.getOAuthServers.mockReturnValue(oauthServers); + (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); // server1: has failed reconnection reconnectionTracker.setFailed(userId, 'server1'); @@ -186,7 +193,9 @@ describe('OAuthReconnectionManager', () => { mockMCPManager.getUserConnection.mockResolvedValue( mockNewConnection as unknown as MCPConnection, ); - mockMCPManager.getRawConfig.mockReturnValue({ initTimeout: 5000 } as unknown as MCPOptions); + (mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue({ + initTimeout: 5000, + } as unknown as MCPOptions); await reconnectionManager.reconnectServers(userId); @@ -215,7 +224,7 @@ describe('OAuthReconnectionManager', () => { it('should handle failed reconnection attempts', async () => { const userId = 'user-123'; const oauthServers = new Set(['server1']); - mockMCPManager.getOAuthServers.mockReturnValue(oauthServers); + (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); // server1: has valid token tokenMethods.findToken.mockResolvedValue({ @@ -226,7 +235,9 @@ describe('OAuthReconnectionManager', () => { // Mock failed connection mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed')); - mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions); + (mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue( + {} as unknown as MCPOptions, + ); await reconnectionManager.reconnectServers(userId); @@ -242,7 +253,7 @@ describe('OAuthReconnectionManager', () => { it('should not reconnect servers with expired tokens', async () => { const userId = 'user-123'; const oauthServers = new Set(['server1']); - mockMCPManager.getOAuthServers.mockReturnValue(oauthServers); + (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); // server1: has expired token tokenMethods.findToken.mockResolvedValue({ @@ -261,7 +272,7 @@ describe('OAuthReconnectionManager', () => { it('should handle connection that returns but is not connected', async () => { const userId = 'user-123'; const oauthServers = new Set(['server1']); - mockMCPManager.getOAuthServers.mockReturnValue(oauthServers); + (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); tokenMethods.findToken.mockResolvedValue({ userId, @@ -277,7 +288,9 @@ describe('OAuthReconnectionManager', () => { mockMCPManager.getUserConnection.mockResolvedValue( mockConnection as unknown as MCPConnection, ); - mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions); + (mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue( + {} as unknown as MCPOptions, + ); await reconnectionManager.reconnectServers(userId); @@ -359,7 +372,7 @@ describe('OAuthReconnectionManager', () => { it('should not attempt to reconnect servers that have timed out during reconnection', async () => { const userId = 'user-123'; const oauthServers = new Set(['server1', 'server2']); - mockMCPManager.getOAuthServers.mockReturnValue(oauthServers); + (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); const now = Date.now(); jest.setSystemTime(now); @@ -414,7 +427,7 @@ describe('OAuthReconnectionManager', () => { const userId = 'user-123'; const serverName = 'server1'; const oauthServers = new Set([serverName]); - mockMCPManager.getOAuthServers.mockReturnValue(oauthServers); + (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); const now = Date.now(); jest.setSystemTime(now); @@ -428,7 +441,9 @@ describe('OAuthReconnectionManager', () => { // First reconnect attempt - will fail mockMCPManager.getUserConnection.mockRejectedValueOnce(new Error('Connection failed')); - mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions); + (mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue( + {} as unknown as MCPOptions, + ); await reconnectionManager.reconnectServers(userId); await jest.runAllTimersAsync(); diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts index 09abb2b048..25edec7f3a 100644 --- a/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts +++ b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts @@ -5,6 +5,7 @@ import type { MCPOAuthTokens } from './types'; import { OAuthReconnectionTracker } from './OAuthReconnectionTracker'; import { FlowStateManager } from '~/flow/manager'; import { MCPManager } from '~/mcp/MCPManager'; +import { mcpServersRegistry } from '~/mcp/registry/MCPServersRegistry'; const DEFAULT_CONNECTION_TIMEOUT_MS = 10_000; // ms @@ -72,7 +73,7 @@ export class OAuthReconnectionManager { // 1. derive the servers to reconnect const serversToReconnect = []; - for (const serverName of this.mcpManager.getOAuthServers()) { + for (const serverName of await mcpServersRegistry.getOAuthServers()) { const canReconnect = await this.canReconnect(userId, serverName); if (canReconnect) { serversToReconnect.push(serverName); @@ -104,7 +105,7 @@ export class OAuthReconnectionManager { logger.info(`${logPrefix} Attempting reconnection`); - const config = this.mcpManager.getRawConfig(serverName); + const config = await mcpServersRegistry.getServerConfig(serverName, userId); const cleanupOnFailedReconnect = () => { this.reconnectionsTracker.setFailed(userId, serverName); diff --git a/packages/api/src/mcp/registry/MCPServerInspector.ts b/packages/api/src/mcp/registry/MCPServerInspector.ts new file mode 100644 index 0000000000..3ae51d7b36 --- /dev/null +++ b/packages/api/src/mcp/registry/MCPServerInspector.ts @@ -0,0 +1,123 @@ +import { Constants } from 'librechat-data-provider'; +import type { JsonSchemaType } from '@librechat/data-schemas'; +import type { MCPConnection } from '~/mcp/connection'; +import type * as t from '~/mcp/types'; +import { detectOAuthRequirement } from '~/mcp/oauth'; +import { isEnabled } from '~/utils'; +import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; + +/** + * Inspects MCP servers to discover their metadata, capabilities, and tools. + * Connects to servers and populates configuration with OAuth requirements, + * server instructions, capabilities, and available tools. + */ +export class MCPServerInspector { + private constructor( + private readonly serverName: string, + private readonly config: t.ParsedServerConfig, + private connection: MCPConnection | undefined, + ) {} + + /** + * Inspects a server and returns an enriched configuration with metadata. + * Detects OAuth requirements and fetches server capabilities. + * @param serverName - The name of the server (used for tool function naming) + * @param rawConfig - The raw server configuration + * @param connection - The MCP connection + * @returns A fully processed and enriched configuration with server metadata + */ + public static async inspect( + serverName: string, + rawConfig: t.MCPOptions, + connection?: MCPConnection, + ): Promise { + const start = Date.now(); + const inspector = new MCPServerInspector(serverName, rawConfig, connection); + await inspector.inspectServer(); + inspector.config.initDuration = Date.now() - start; + return inspector.config; + } + + private async inspectServer(): Promise { + await this.detectOAuth(); + + if (this.config.startup !== false && !this.config.requiresOAuth) { + let tempConnection = false; + if (!this.connection) { + tempConnection = true; + this.connection = await MCPConnectionFactory.create({ + serverName: this.serverName, + serverConfig: this.config, + }); + } + + await Promise.allSettled([ + this.fetchServerInstructions(), + this.fetchServerCapabilities(), + this.fetchToolFunctions(), + ]); + + if (tempConnection) await this.connection.disconnect(); + } + } + + private async detectOAuth(): Promise { + if (this.config.requiresOAuth != null) return; + if (this.config.url == null || this.config.startup === false) { + this.config.requiresOAuth = false; + return; + } + + const result = await detectOAuthRequirement(this.config.url); + this.config.requiresOAuth = result.requiresOAuth; + this.config.oauthMetadata = result.metadata; + } + + private async fetchServerInstructions(): Promise { + if (isEnabled(this.config.serverInstructions)) { + this.config.serverInstructions = this.connection!.client.getInstructions(); + } + } + + private async fetchServerCapabilities(): Promise { + const capabilities = this.connection!.client.getServerCapabilities(); + this.config.capabilities = JSON.stringify(capabilities); + const tools = await this.connection!.client.listTools(); + this.config.tools = tools.tools.map((tool) => tool.name).join(', '); + } + + private async fetchToolFunctions(): Promise { + this.config.toolFunctions = await MCPServerInspector.getToolFunctions( + this.serverName, + this.connection!, + ); + } + + /** + * Converts server tools to LibreChat-compatible tool functions format. + * @param serverName - The name of the server + * @param connection - The MCP connection + * @returns Tool functions formatted for LibreChat + */ + public static async getToolFunctions( + serverName: string, + connection: MCPConnection, + ): Promise { + const { tools }: t.MCPToolListResponse = await connection.client.listTools(); + + const toolFunctions: t.LCAvailableTools = {}; + tools.forEach((tool) => { + const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`; + toolFunctions[name] = { + type: 'function', + ['function']: { + name, + description: tool.description, + parameters: tool.inputSchema as JsonSchemaType, + }, + }; + }); + + return toolFunctions; + } +} diff --git a/packages/api/src/mcp/registry/MCPServersInitializer.ts b/packages/api/src/mcp/registry/MCPServersInitializer.ts new file mode 100644 index 0000000000..f29cd6769f --- /dev/null +++ b/packages/api/src/mcp/registry/MCPServersInitializer.ts @@ -0,0 +1,96 @@ +import { registryStatusCache as statusCache } from './cache/RegistryStatusCache'; +import { isLeader } from '~/cluster'; +import { withTimeout } from '~/utils'; +import { logger } from '@librechat/data-schemas'; +import { MCPServerInspector } from './MCPServerInspector'; +import { ParsedServerConfig } from '~/mcp/types'; +import { sanitizeUrlForLogging } from '~/mcp/utils'; +import type * as t from '~/mcp/types'; +import { mcpServersRegistry as registry } from './MCPServersRegistry'; + +const MCP_INIT_TIMEOUT_MS = + process.env.MCP_INIT_TIMEOUT_MS != null ? parseInt(process.env.MCP_INIT_TIMEOUT_MS) : 30_000; + +/** + * Handles initialization of MCP servers at application startup with distributed coordination. + * In cluster environments, ensures only the leader node performs initialization while followers wait. + * Connects to each configured MCP server, inspects capabilities and tools, then caches the results. + * Categorizes servers as either shared app servers (auto-started) or shared user servers (OAuth/on-demand). + * Uses a timeout mechanism to prevent hanging on unresponsive servers during initialization. + */ +export class MCPServersInitializer { + /** + * Initializes MCP servers with distributed leader-follower coordination. + * + * Design rationale: + * - Handles leader crash scenarios: If the leader crashes during initialization, all followers + * will independently attempt initialization after a 3-second delay. The first to become leader + * will complete the initialization. + * - Only the leader performs the actual initialization work (reset caches, inspect servers). + * When complete, the leader signals completion via `statusCache`, allowing followers to proceed. + * - Followers wait and poll `statusCache` until the leader finishes, ensuring only one node + * performs the expensive initialization operations. + */ + public static async initialize(rawConfigs: t.MCPServers): Promise { + if (await statusCache.isInitialized()) return; + + if (await isLeader()) { + // Leader performs initialization + await statusCache.reset(); + await registry.reset(); + const serverNames = Object.keys(rawConfigs); + await Promise.allSettled( + serverNames.map((serverName) => + withTimeout( + MCPServersInitializer.initializeServer(serverName, rawConfigs[serverName]), + MCP_INIT_TIMEOUT_MS, + `${MCPServersInitializer.prefix(serverName)} Server initialization timed out`, + logger.error, + ), + ), + ); + await statusCache.setInitialized(true); + } else { + // Followers try again after a delay if not initialized + await new Promise((resolve) => setTimeout(resolve, 3000)); + await this.initialize(rawConfigs); + } + } + + /** Initializes a single server with all its metadata and adds it to appropriate collections */ + private static async initializeServer( + serverName: string, + rawConfig: t.MCPOptions, + ): Promise { + try { + const config = await MCPServerInspector.inspect(serverName, rawConfig); + + if (config.startup === false || config.requiresOAuth) { + await registry.sharedUserServers.add(serverName, config); + } else { + await registry.sharedAppServers.add(serverName, config); + } + MCPServersInitializer.logParsedConfig(serverName, config); + } catch (error) { + logger.error(`${MCPServersInitializer.prefix(serverName)} Failed to initialize:`, error); + } + } + + // Logs server configuration summary after initialization + private static logParsedConfig(serverName: string, config: ParsedServerConfig): void { + const prefix = MCPServersInitializer.prefix(serverName); + logger.info(`${prefix} -------------------------------------------------┐`); + logger.info(`${prefix} URL: ${config.url ? sanitizeUrlForLogging(config.url) : 'N/A'}`); + logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`); + logger.info(`${prefix} Capabilities: ${config.capabilities}`); + logger.info(`${prefix} Tools: ${config.tools}`); + logger.info(`${prefix} Server Instructions: ${config.serverInstructions}`); + logger.info(`${prefix} Initialized in: ${config.initDuration ?? 'N/A'}ms`); + logger.info(`${prefix} -------------------------------------------------┘`); + } + + // Returns formatted log prefix for server messages + private static prefix(serverName: string): string { + return `[MCP][${serverName}]`; + } +} diff --git a/packages/api/src/mcp/registry/MCPServersRegistry.ts b/packages/api/src/mcp/registry/MCPServersRegistry.ts new file mode 100644 index 0000000000..8c6ef13e9c --- /dev/null +++ b/packages/api/src/mcp/registry/MCPServersRegistry.ts @@ -0,0 +1,91 @@ +import type * as t from '~/mcp/types'; +import { + ServerConfigsCacheFactory, + type ServerConfigsCache, +} from './cache/ServerConfigsCacheFactory'; + +/** + * Central registry for managing MCP server configurations across different scopes and users. + * Maintains three categories of server configurations: + * - Shared App Servers: Auto-started servers available to all users (initialized at startup) + * - Shared User Servers: User-scope servers that require OAuth or on-demand startup + * - Private User Servers: Per-user configurations dynamically added during runtime + * + * Provides a unified interface for retrieving server configs with proper fallback hierarchy: + * checks shared app servers first, then shared user servers, then private user servers. + * Handles server lifecycle operations including adding, removing, and querying configurations. + */ +class MCPServersRegistry { + public readonly sharedAppServers = ServerConfigsCacheFactory.create('App', true); + public readonly sharedUserServers = ServerConfigsCacheFactory.create('User', true); + private readonly privateUserServers: Map = new Map(); + + public async addPrivateUserServer( + userId: string, + serverName: string, + config: t.ParsedServerConfig, + ): Promise { + if (!this.privateUserServers.has(userId)) { + const cache = ServerConfigsCacheFactory.create(`User(${userId})`, false); + this.privateUserServers.set(userId, cache); + } + await this.privateUserServers.get(userId)!.add(serverName, config); + } + + public async updatePrivateUserServer( + userId: string, + serverName: string, + config: t.ParsedServerConfig, + ): Promise { + const userCache = this.privateUserServers.get(userId); + if (!userCache) throw new Error(`No private servers found for user "${userId}".`); + await userCache.update(serverName, config); + } + + public async removePrivateUserServer(userId: string, serverName: string): Promise { + await this.privateUserServers.get(userId)?.remove(serverName); + } + + public async getServerConfig( + serverName: string, + userId?: string, + ): Promise { + const sharedAppServer = await this.sharedAppServers.get(serverName); + if (sharedAppServer) return sharedAppServer; + + const sharedUserServer = await this.sharedUserServers.get(serverName); + if (sharedUserServer) return sharedUserServer; + + const privateUserServer = await this.privateUserServers.get(userId)?.get(serverName); + if (privateUserServer) return privateUserServer; + + return undefined; + } + + public async getAllServerConfigs(userId?: string): Promise> { + return { + ...(await this.sharedAppServers.getAll()), + ...(await this.sharedUserServers.getAll()), + ...((await this.privateUserServers.get(userId)?.getAll()) ?? {}), + }; + } + + // TODO: This is currently used to determine if a server requires OAuth. However, this info can + // can be determined through config.requiresOAuth. Refactor usages and remove this method. + public async getOAuthServers(userId?: string): Promise> { + const allServers = await this.getAllServerConfigs(userId); + const oauthServers = Object.entries(allServers).filter(([, config]) => config.requiresOAuth); + return new Set(oauthServers.map(([name]) => name)); + } + + public async reset(): Promise { + await this.sharedAppServers.reset(); + await this.sharedUserServers.reset(); + for (const cache of this.privateUserServers.values()) { + await cache.reset(); + } + this.privateUserServers.clear(); + } +} + +export const mcpServersRegistry = new MCPServersRegistry(); diff --git a/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts b/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts new file mode 100644 index 0000000000..0e4a6ebbe9 --- /dev/null +++ b/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts @@ -0,0 +1,338 @@ +import type { MCPConnection } from '~/mcp/connection'; +import type * as t from '~/mcp/types'; +import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector'; +import { detectOAuthRequirement } from '~/mcp/oauth'; +import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; +import { createMockConnection } from './mcpConnectionsMock.helper'; + +// Mock external dependencies +jest.mock('../../oauth/detectOAuth'); +jest.mock('../../MCPConnectionFactory'); + +const mockDetectOAuthRequirement = detectOAuthRequirement as jest.MockedFunction< + typeof detectOAuthRequirement +>; + +describe('MCPServerInspector', () => { + let mockConnection: jest.Mocked; + + beforeEach(() => { + mockConnection = createMockConnection('test_server'); + jest.clearAllMocks(); + }); + + describe('inspect()', () => { + it('should process env and fetch all metadata for non-OAuth stdio server with serverInstructions=true', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: true, + }; + + mockDetectOAuthRequirement.mockResolvedValue({ + requiresOAuth: false, + method: 'no-metadata-found', + }); + + const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + expect(result).toEqual({ + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: 'instructions for test_server', + requiresOAuth: false, + capabilities: + '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}', + tools: 'listFiles', + toolFunctions: { + listFiles_mcp_test_server: expect.objectContaining({ + type: 'function', + function: expect.objectContaining({ + name: 'listFiles_mcp_test_server', + }), + }), + }, + initDuration: expect.any(Number), + }); + }); + + it('should detect OAuth and skip capabilities fetch for streamable-http server', async () => { + const rawConfig: t.MCPOptions = { + type: 'streamable-http', + url: 'https://api.example.com/mcp', + }; + + mockDetectOAuthRequirement.mockResolvedValue({ + requiresOAuth: true, + method: 'protected-resource-metadata', + }); + + const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + expect(result).toEqual({ + type: 'streamable-http', + url: 'https://api.example.com/mcp', + requiresOAuth: true, + oauthMetadata: undefined, + initDuration: expect.any(Number), + }); + }); + + it('should skip capabilities fetch when startup=false', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + command: 'node', + args: ['server.js'], + startup: false, + }; + + const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + expect(result).toEqual({ + type: 'stdio', + command: 'node', + args: ['server.js'], + startup: false, + requiresOAuth: false, + initDuration: expect.any(Number), + }); + }); + + it('should keep custom serverInstructions string and not fetch from server', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: 'Custom instructions here', + }; + + mockDetectOAuthRequirement.mockResolvedValue({ + requiresOAuth: false, + method: 'no-metadata-found', + }); + + const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + expect(result).toEqual({ + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: 'Custom instructions here', + requiresOAuth: false, + capabilities: + '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}', + tools: 'listFiles', + toolFunctions: expect.any(Object), + initDuration: expect.any(Number), + }); + }); + + it('should handle serverInstructions as string "true" and fetch from server', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: 'true', // String "true" from YAML + }; + + mockDetectOAuthRequirement.mockResolvedValue({ + requiresOAuth: false, + method: 'no-metadata-found', + }); + + const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + expect(result).toEqual({ + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: 'instructions for test_server', + requiresOAuth: false, + capabilities: + '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}', + tools: 'listFiles', + toolFunctions: expect.any(Object), + initDuration: expect.any(Number), + }); + }); + + it('should handle predefined requiresOAuth without detection', async () => { + const rawConfig: t.MCPOptions = { + type: 'sse', + url: 'https://api.example.com/sse', + requiresOAuth: true, + }; + + const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + expect(result).toEqual({ + type: 'sse', + url: 'https://api.example.com/sse', + requiresOAuth: true, + initDuration: expect.any(Number), + }); + }); + + it('should fetch capabilities when server has no tools', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + command: 'node', + args: ['server.js'], + }; + + mockDetectOAuthRequirement.mockResolvedValue({ + requiresOAuth: false, + method: 'no-metadata-found', + }); + + // Mock server with no tools + mockConnection.client.listTools = jest.fn().mockResolvedValue({ tools: [] }); + + const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + expect(result).toEqual({ + type: 'stdio', + command: 'node', + args: ['server.js'], + requiresOAuth: false, + capabilities: + '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}', + tools: '', + toolFunctions: {}, + initDuration: expect.any(Number), + }); + }); + + it('should create temporary connection when no connection is provided', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: true, + }; + + const tempMockConnection = createMockConnection('test_server'); + (MCPConnectionFactory.create as jest.Mock).mockResolvedValue(tempMockConnection); + + mockDetectOAuthRequirement.mockResolvedValue({ + requiresOAuth: false, + method: 'no-metadata-found', + }); + + const result = await MCPServerInspector.inspect('test_server', rawConfig); + + // Verify factory was called to create connection + expect(MCPConnectionFactory.create).toHaveBeenCalledWith({ + serverName: 'test_server', + serverConfig: expect.objectContaining({ type: 'stdio', command: 'node' }), + }); + + // Verify temporary connection was disconnected + expect(tempMockConnection.disconnect).toHaveBeenCalled(); + + // Verify result is correct + expect(result).toEqual({ + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: 'instructions for test_server', + requiresOAuth: false, + capabilities: + '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}', + tools: 'listFiles', + toolFunctions: expect.any(Object), + initDuration: expect.any(Number), + }); + }); + + it('should not create temporary connection when connection is provided', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: true, + }; + + mockDetectOAuthRequirement.mockResolvedValue({ + requiresOAuth: false, + method: 'no-metadata-found', + }); + + await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + // Verify factory was NOT called + expect(MCPConnectionFactory.create).not.toHaveBeenCalled(); + + // Verify provided connection was NOT disconnected + expect(mockConnection.disconnect).not.toHaveBeenCalled(); + }); + }); + + describe('getToolFunctions()', () => { + it('should convert MCP tools to LibreChat tool functions format', async () => { + mockConnection.client.listTools = jest.fn().mockResolvedValue({ + tools: [ + { + name: 'file_read', + description: 'Read a file', + inputSchema: { + type: 'object', + properties: { path: { type: 'string' } }, + }, + }, + { + name: 'file_write', + description: 'Write a file', + inputSchema: { + type: 'object', + properties: { + path: { type: 'string' }, + content: { type: 'string' }, + }, + }, + }, + ], + }); + + const result = await MCPServerInspector.getToolFunctions('my_server', mockConnection); + + expect(result).toEqual({ + file_read_mcp_my_server: { + type: 'function', + function: { + name: 'file_read_mcp_my_server', + description: 'Read a file', + parameters: { + type: 'object', + properties: { path: { type: 'string' } }, + }, + }, + }, + file_write_mcp_my_server: { + type: 'function', + function: { + name: 'file_write_mcp_my_server', + description: 'Write a file', + parameters: { + type: 'object', + properties: { + path: { type: 'string' }, + content: { type: 'string' }, + }, + }, + }, + }, + }); + }); + + it('should handle empty tools list', async () => { + mockConnection.client.listTools = jest.fn().mockResolvedValue({ tools: [] }); + + const result = await MCPServerInspector.getToolFunctions('my_server', mockConnection); + + expect(result).toEqual({}); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/__tests__/MCPServersInitializer.cache_integration.spec.ts b/packages/api/src/mcp/registry/__tests__/MCPServersInitializer.cache_integration.spec.ts new file mode 100644 index 0000000000..820cdfa54e --- /dev/null +++ b/packages/api/src/mcp/registry/__tests__/MCPServersInitializer.cache_integration.spec.ts @@ -0,0 +1,301 @@ +import { expect } from '@playwright/test'; +import type * as t from '~/mcp/types'; +import type { MCPConnection } from '~/mcp/connection'; + +// Mock isLeader to always return true to avoid lock contention during parallel operations +jest.mock('~/cluster', () => ({ + ...jest.requireActual('~/cluster'), + isLeader: jest.fn().mockResolvedValue(true), +})); + +describe('MCPServersInitializer Redis Integration Tests', () => { + let MCPServersInitializer: typeof import('../MCPServersInitializer').MCPServersInitializer; + let registry: typeof import('../MCPServersRegistry').mcpServersRegistry; + let registryStatusCache: typeof import('../cache/RegistryStatusCache').registryStatusCache; + let MCPServerInspector: typeof import('../MCPServerInspector').MCPServerInspector; + let MCPConnectionFactory: typeof import('~/mcp/MCPConnectionFactory').MCPConnectionFactory; + let keyvRedisClient: Awaited['keyvRedisClient']; + let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection; + let leaderInstance: InstanceType; + + const testConfigs: t.MCPServers = { + disabled_server: { + type: 'stdio', + command: 'node', + args: ['disabled.js'], + startup: false, + }, + oauth_server: { + type: 'streamable-http', + url: 'https://api.example.com/mcp', + }, + file_tools_server: { + type: 'stdio', + command: 'node', + args: ['tools.js'], + }, + search_tools_server: { + type: 'stdio', + command: 'node', + args: ['instructions.js'], + }, + }; + + const testParsedConfigs: Record = { + disabled_server: { + type: 'stdio', + command: 'node', + args: ['disabled.js'], + startup: false, + requiresOAuth: false, + }, + oauth_server: { + type: 'streamable-http', + url: 'https://api.example.com/mcp', + requiresOAuth: true, + }, + file_tools_server: { + type: 'stdio', + command: 'node', + args: ['tools.js'], + requiresOAuth: false, + serverInstructions: 'Instructions for file_tools_server', + tools: 'file_read, file_write', + capabilities: '{"tools":{"listChanged":true}}', + toolFunctions: { + file_read_mcp_file_tools_server: { + type: 'function', + function: { + name: 'file_read_mcp_file_tools_server', + description: 'Read a file', + parameters: { type: 'object' }, + }, + }, + }, + }, + search_tools_server: { + type: 'stdio', + command: 'node', + args: ['instructions.js'], + requiresOAuth: false, + serverInstructions: 'Instructions for search_tools_server', + capabilities: '{"tools":{"listChanged":true}}', + tools: 'search', + toolFunctions: { + search_mcp_search_tools_server: { + type: 'function', + function: { + name: 'search_mcp_search_tools_server', + description: 'Search tool', + parameters: { type: 'object' }, + }, + }, + }, + }, + }; + + beforeAll(async () => { + // Set up environment variables for Redis (only if not already set) + process.env.USE_REDIS = process.env.USE_REDIS ?? 'true'; + process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379'; + process.env.REDIS_KEY_PREFIX = + process.env.REDIS_KEY_PREFIX ?? 'MCPServersInitializer-IntegrationTest'; + + // Import modules after setting env vars + const initializerModule = await import('../MCPServersInitializer'); + const registryModule = await import('../MCPServersRegistry'); + const statusCacheModule = await import('../cache/RegistryStatusCache'); + const inspectorModule = await import('../MCPServerInspector'); + const connectionFactoryModule = await import('~/mcp/MCPConnectionFactory'); + const redisClients = await import('~/cache/redisClients'); + const leaderElectionModule = await import('~/cluster/LeaderElection'); + + MCPServersInitializer = initializerModule.MCPServersInitializer; + registry = registryModule.mcpServersRegistry; + registryStatusCache = statusCacheModule.registryStatusCache; + MCPServerInspector = inspectorModule.MCPServerInspector; + MCPConnectionFactory = connectionFactoryModule.MCPConnectionFactory; + keyvRedisClient = redisClients.keyvRedisClient; + LeaderElection = leaderElectionModule.LeaderElection; + + // Ensure Redis is connected + if (!keyvRedisClient) throw new Error('Redis client is not initialized'); + + // Wait for Redis to be ready + if (!keyvRedisClient.isOpen) await keyvRedisClient.connect(); + + // Become leader so we can perform write operations + leaderInstance = new LeaderElection(); + const isLeader = await leaderInstance.isLeader(); + expect(isLeader).toBe(true); + }); + + beforeEach(async () => { + // Ensure we're still the leader + const isLeader = await leaderInstance.isLeader(); + if (!isLeader) { + throw new Error('Lost leader status before test'); + } + + // Mock MCPServerInspector.inspect to return parsed config + jest.spyOn(MCPServerInspector, 'inspect').mockImplementation(async (serverName: string) => { + return { + ...testParsedConfigs[serverName], + _processedByInspector: true, + } as unknown as t.ParsedServerConfig; + }); + + // Mock MCPConnection + const mockConnection = { + disconnect: jest.fn().mockResolvedValue(undefined), + } as unknown as jest.Mocked; + + // Mock MCPConnectionFactory + jest.spyOn(MCPConnectionFactory, 'create').mockResolvedValue(mockConnection); + + // Reset caches before each test + await registryStatusCache.reset(); + await registry.reset(); + }); + + afterEach(async () => { + // Clean up: clear all test keys from Redis + if (keyvRedisClient) { + const pattern = '*MCPServersInitializer-IntegrationTest*'; + if ('scanIterator' in keyvRedisClient) { + for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) { + await keyvRedisClient.del(key); + } + } + } + + jest.restoreAllMocks(); + }); + + afterAll(async () => { + // Resign as leader + if (leaderInstance) await leaderInstance.resign(); + + // Close Redis connection + if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect(); + }); + + describe('initialize()', () => { + it('should reset registry and status cache before initialization', async () => { + // Pre-populate registry with some old servers + await registry.sharedAppServers.add('old_app_server', testParsedConfigs.file_tools_server); + await registry.sharedUserServers.add('old_user_server', testParsedConfigs.oauth_server); + + // Initialize with new configs (this should reset first) + await MCPServersInitializer.initialize(testConfigs); + + // Verify old servers are gone + expect(await registry.sharedAppServers.get('old_app_server')).toBeUndefined(); + expect(await registry.sharedUserServers.get('old_user_server')).toBeUndefined(); + + // Verify new servers are present + expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined(); + expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined(); + expect(await registryStatusCache.isInitialized()).toBe(true); + }); + + it('should skip initialization if already initialized', async () => { + // First initialization + await MCPServersInitializer.initialize(testConfigs); + + // Clear mock calls + jest.clearAllMocks(); + + // Second initialization should skip due to static flag + await MCPServersInitializer.initialize(testConfigs); + + // Verify inspect was not called again + expect(MCPServerInspector.inspect).not.toHaveBeenCalled(); + }); + + it('should add disabled servers to sharedUserServers', async () => { + await MCPServersInitializer.initialize(testConfigs); + + const disabledServer = await registry.sharedUserServers.get('disabled_server'); + expect(disabledServer).toBeDefined(); + expect(disabledServer).toMatchObject({ + ...testParsedConfigs.disabled_server, + _processedByInspector: true, + }); + }); + + it('should add OAuth servers to sharedUserServers', async () => { + await MCPServersInitializer.initialize(testConfigs); + + const oauthServer = await registry.sharedUserServers.get('oauth_server'); + expect(oauthServer).toBeDefined(); + expect(oauthServer).toMatchObject({ + ...testParsedConfigs.oauth_server, + _processedByInspector: true, + }); + }); + + it('should add enabled non-OAuth servers to sharedAppServers', async () => { + await MCPServersInitializer.initialize(testConfigs); + + const fileToolsServer = await registry.sharedAppServers.get('file_tools_server'); + expect(fileToolsServer).toBeDefined(); + expect(fileToolsServer).toMatchObject({ + ...testParsedConfigs.file_tools_server, + _processedByInspector: true, + }); + + const searchToolsServer = await registry.sharedAppServers.get('search_tools_server'); + expect(searchToolsServer).toBeDefined(); + expect(searchToolsServer).toMatchObject({ + ...testParsedConfigs.search_tools_server, + _processedByInspector: true, + }); + }); + + it('should successfully initialize all servers', async () => { + await MCPServersInitializer.initialize(testConfigs); + + // Verify all servers were added to appropriate registries + expect(await registry.sharedUserServers.get('disabled_server')).toBeDefined(); + expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined(); + expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined(); + expect(await registry.sharedAppServers.get('search_tools_server')).toBeDefined(); + }); + + it('should handle inspection failures gracefully', async () => { + // Mock inspection failure for one server + jest.spyOn(MCPServerInspector, 'inspect').mockImplementation(async (serverName: string) => { + if (serverName === 'file_tools_server') { + throw new Error('Inspection failed'); + } + return { + ...testParsedConfigs[serverName], + _processedByInspector: true, + } as unknown as t.ParsedServerConfig; + }); + + await MCPServersInitializer.initialize(testConfigs); + + // Verify other servers were still processed + const disabledServer = await registry.sharedUserServers.get('disabled_server'); + expect(disabledServer).toBeDefined(); + + const oauthServer = await registry.sharedUserServers.get('oauth_server'); + expect(oauthServer).toBeDefined(); + + const searchToolsServer = await registry.sharedAppServers.get('search_tools_server'); + expect(searchToolsServer).toBeDefined(); + + // Verify file_tools_server was not added (due to inspection failure) + const fileToolsServer = await registry.sharedAppServers.get('file_tools_server'); + expect(fileToolsServer).toBeUndefined(); + }); + + it('should set initialized status after completion', async () => { + await MCPServersInitializer.initialize(testConfigs); + + expect(await registryStatusCache.isInitialized()).toBe(true); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/__tests__/MCPServersInitializer.test.ts b/packages/api/src/mcp/registry/__tests__/MCPServersInitializer.test.ts new file mode 100644 index 0000000000..2ce8d09d93 --- /dev/null +++ b/packages/api/src/mcp/registry/__tests__/MCPServersInitializer.test.ts @@ -0,0 +1,292 @@ +import { logger } from '@librechat/data-schemas'; +import * as t from '~/mcp/types'; +import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; +import { MCPServersInitializer } from '~/mcp/registry/MCPServersInitializer'; +import { MCPConnection } from '~/mcp/connection'; +import { registryStatusCache } from '~/mcp/registry/cache/RegistryStatusCache'; +import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector'; +import { mcpServersRegistry as registry } from '~/mcp/registry/MCPServersRegistry'; + +// Mock external dependencies +jest.mock('../../MCPConnectionFactory'); +jest.mock('../../connection'); +jest.mock('../../registry/MCPServerInspector'); +jest.mock('~/cluster', () => ({ + isLeader: jest.fn().mockResolvedValue(true), +})); +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, +})); + +const mockLogger = logger as jest.Mocked; +const mockInspect = MCPServerInspector.inspect as jest.MockedFunction< + typeof MCPServerInspector.inspect +>; + +describe('MCPServersInitializer', () => { + let mockConnection: jest.Mocked; + + const testConfigs: t.MCPServers = { + disabled_server: { + type: 'stdio', + command: 'node', + args: ['disabled.js'], + startup: false, + }, + oauth_server: { + type: 'streamable-http', + url: 'https://api.example.com/mcp', + }, + file_tools_server: { + type: 'stdio', + command: 'node', + args: ['tools.js'], + }, + search_tools_server: { + type: 'stdio', + command: 'node', + args: ['instructions.js'], + }, + }; + + const testParsedConfigs: Record = { + disabled_server: { + type: 'stdio', + command: 'node', + args: ['disabled.js'], + startup: false, + requiresOAuth: false, + }, + oauth_server: { + type: 'streamable-http', + url: 'https://api.example.com/mcp', + requiresOAuth: true, + }, + file_tools_server: { + type: 'stdio', + command: 'node', + args: ['tools.js'], + requiresOAuth: false, + serverInstructions: 'Instructions for file_tools_server', + tools: 'file_read, file_write', + capabilities: '{"tools":{"listChanged":true}}', + toolFunctions: { + file_read_mcp_file_tools_server: { + type: 'function', + function: { + name: 'file_read_mcp_file_tools_server', + description: 'Read a file', + parameters: { type: 'object' }, + }, + }, + }, + }, + search_tools_server: { + type: 'stdio', + command: 'node', + args: ['instructions.js'], + requiresOAuth: false, + serverInstructions: 'Instructions for search_tools_server', + capabilities: '{"tools":{"listChanged":true}}', + tools: 'search', + toolFunctions: { + search_mcp_search_tools_server: { + type: 'function', + function: { + name: 'search_mcp_search_tools_server', + description: 'Search tool', + parameters: { type: 'object' }, + }, + }, + }, + }, + }; + + beforeEach(async () => { + // Setup MCPConnection mock + mockConnection = { + disconnect: jest.fn().mockResolvedValue(undefined), + } as unknown as jest.Mocked; + + // Setup MCPConnectionFactory mock + (MCPConnectionFactory.create as jest.Mock).mockResolvedValue(mockConnection); + + // Mock MCPServerInspector.inspect to return parsed config + mockInspect.mockImplementation(async (serverName: string) => { + return { + ...testParsedConfigs[serverName], + _processedByInspector: true, + } as unknown as t.ParsedServerConfig; + }); + + // Reset caches before each test + await registryStatusCache.reset(); + await registry.sharedAppServers.reset(); + await registry.sharedUserServers.reset(); + jest.clearAllMocks(); + }); + + afterEach(() => { + jest.restoreAllMocks(); + }); + + describe('initialize()', () => { + it('should reset registry and status cache before initialization', async () => { + // Pre-populate registry with some old servers + await registry.sharedAppServers.add('old_app_server', testParsedConfigs.file_tools_server); + await registry.sharedUserServers.add('old_user_server', testParsedConfigs.oauth_server); + + // Initialize with new configs (this should reset first) + await MCPServersInitializer.initialize(testConfigs); + + // Verify old servers are gone + expect(await registry.sharedAppServers.get('old_app_server')).toBeUndefined(); + expect(await registry.sharedUserServers.get('old_user_server')).toBeUndefined(); + + // Verify new servers are present + expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined(); + expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined(); + expect(await registryStatusCache.isInitialized()).toBe(true); + }); + + it('should skip initialization if already initialized (Redis flag)', async () => { + // First initialization + await MCPServersInitializer.initialize(testConfigs); + + jest.clearAllMocks(); + + // Second initialization should skip due to Redis cache flag + await MCPServersInitializer.initialize(testConfigs); + + expect(mockInspect).not.toHaveBeenCalled(); + }); + + it('should process all server configs through inspector', async () => { + await MCPServersInitializer.initialize(testConfigs); + + // Verify all configs were processed by inspector (without connection parameter) + expect(mockInspect).toHaveBeenCalledTimes(4); + expect(mockInspect).toHaveBeenCalledWith('disabled_server', testConfigs.disabled_server); + expect(mockInspect).toHaveBeenCalledWith('oauth_server', testConfigs.oauth_server); + expect(mockInspect).toHaveBeenCalledWith('file_tools_server', testConfigs.file_tools_server); + expect(mockInspect).toHaveBeenCalledWith( + 'search_tools_server', + testConfigs.search_tools_server, + ); + }); + + it('should add disabled servers to sharedUserServers', async () => { + await MCPServersInitializer.initialize(testConfigs); + + const disabledServer = await registry.sharedUserServers.get('disabled_server'); + expect(disabledServer).toBeDefined(); + expect(disabledServer).toMatchObject({ + ...testParsedConfigs.disabled_server, + _processedByInspector: true, + }); + }); + + it('should add OAuth servers to sharedUserServers', async () => { + await MCPServersInitializer.initialize(testConfigs); + + const oauthServer = await registry.sharedUserServers.get('oauth_server'); + expect(oauthServer).toBeDefined(); + expect(oauthServer).toMatchObject({ + ...testParsedConfigs.oauth_server, + _processedByInspector: true, + }); + }); + + it('should add enabled non-OAuth servers to sharedAppServers', async () => { + await MCPServersInitializer.initialize(testConfigs); + + const fileToolsServer = await registry.sharedAppServers.get('file_tools_server'); + expect(fileToolsServer).toBeDefined(); + expect(fileToolsServer).toMatchObject({ + ...testParsedConfigs.file_tools_server, + _processedByInspector: true, + }); + + const searchToolsServer = await registry.sharedAppServers.get('search_tools_server'); + expect(searchToolsServer).toBeDefined(); + expect(searchToolsServer).toMatchObject({ + ...testParsedConfigs.search_tools_server, + _processedByInspector: true, + }); + }); + + it('should successfully initialize all servers', async () => { + await MCPServersInitializer.initialize(testConfigs); + + // Verify all servers were added to appropriate registries + expect(await registry.sharedUserServers.get('disabled_server')).toBeDefined(); + expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined(); + expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined(); + expect(await registry.sharedAppServers.get('search_tools_server')).toBeDefined(); + }); + + it('should handle inspection failures gracefully', async () => { + // Mock inspection failure for one server + mockInspect.mockImplementation(async (serverName: string) => { + if (serverName === 'file_tools_server') { + throw new Error('Inspection failed'); + } + return { + ...testParsedConfigs[serverName], + _processedByInspector: true, + } as unknown as t.ParsedServerConfig; + }); + + await MCPServersInitializer.initialize(testConfigs); + + // Verify other servers were still processed + const disabledServer = await registry.sharedUserServers.get('disabled_server'); + expect(disabledServer).toBeDefined(); + + const oauthServer = await registry.sharedUserServers.get('oauth_server'); + expect(oauthServer).toBeDefined(); + + const searchToolsServer = await registry.sharedAppServers.get('search_tools_server'); + expect(searchToolsServer).toBeDefined(); + + // Verify file_tools_server was not added (due to inspection failure) + const fileToolsServer = await registry.sharedAppServers.get('file_tools_server'); + expect(fileToolsServer).toBeUndefined(); + }); + + it('should log server configuration after initialization', async () => { + await MCPServersInitializer.initialize(testConfigs); + + // Verify logging occurred for each server + expect(mockLogger.info).toHaveBeenCalledWith( + expect.stringContaining('[MCP][disabled_server]'), + ); + expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining('[MCP][oauth_server]')); + expect(mockLogger.info).toHaveBeenCalledWith( + expect.stringContaining('[MCP][file_tools_server]'), + ); + }); + + it('should use Promise.allSettled for parallel server initialization', async () => { + const allSettledSpy = jest.spyOn(Promise, 'allSettled'); + + await MCPServersInitializer.initialize(testConfigs); + + expect(allSettledSpy).toHaveBeenCalledWith(expect.arrayContaining([expect.any(Promise)])); + expect(allSettledSpy).toHaveBeenCalledTimes(1); + + allSettledSpy.mockRestore(); + }); + + it('should set initialized status after completion', async () => { + await MCPServersInitializer.initialize(testConfigs); + + expect(await registryStatusCache.isInitialized()).toBe(true); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.cache_integration.spec.ts b/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.cache_integration.spec.ts new file mode 100644 index 0000000000..68e9291d46 --- /dev/null +++ b/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.cache_integration.spec.ts @@ -0,0 +1,227 @@ +import { expect } from '@playwright/test'; +import type * as t from '~/mcp/types'; + +/** + * Integration tests for MCPServersRegistry using Redis-backed cache. + * For unit tests using in-memory cache, see MCPServersRegistry.test.ts + */ +describe('MCPServersRegistry Redis Integration Tests', () => { + let registry: typeof import('../MCPServersRegistry').mcpServersRegistry; + let keyvRedisClient: Awaited['keyvRedisClient']; + let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection; + let leaderInstance: InstanceType; + + const testParsedConfig: t.ParsedServerConfig = { + type: 'stdio', + command: 'node', + args: ['tools.js'], + requiresOAuth: false, + serverInstructions: 'Instructions for file_tools_server', + tools: 'file_read, file_write', + capabilities: '{"tools":{"listChanged":true}}', + toolFunctions: { + file_read_mcp_file_tools_server: { + type: 'function', + function: { + name: 'file_read_mcp_file_tools_server', + description: 'Read a file', + parameters: { type: 'object' }, + }, + }, + }, + }; + + beforeAll(async () => { + // Set up environment variables for Redis (only if not already set) + process.env.USE_REDIS = process.env.USE_REDIS ?? 'true'; + process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379'; + process.env.REDIS_KEY_PREFIX = + process.env.REDIS_KEY_PREFIX ?? 'MCPServersRegistry-IntegrationTest'; + + // Import modules after setting env vars + const registryModule = await import('../MCPServersRegistry'); + const redisClients = await import('~/cache/redisClients'); + const leaderElectionModule = await import('~/cluster/LeaderElection'); + + registry = registryModule.mcpServersRegistry; + keyvRedisClient = redisClients.keyvRedisClient; + LeaderElection = leaderElectionModule.LeaderElection; + + // Ensure Redis is connected + if (!keyvRedisClient) throw new Error('Redis client is not initialized'); + + // Wait for Redis to be ready + if (!keyvRedisClient.isOpen) await keyvRedisClient.connect(); + + // Become leader so we can perform write operations + leaderInstance = new LeaderElection(); + const isLeader = await leaderInstance.isLeader(); + expect(isLeader).toBe(true); + }); + + afterEach(async () => { + // Clean up: reset registry to clear all test data + await registry.reset(); + + // Also clean up any remaining test keys from Redis + if (keyvRedisClient) { + const pattern = '*MCPServersRegistry-IntegrationTest*'; + if ('scanIterator' in keyvRedisClient) { + for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) { + await keyvRedisClient.del(key); + } + } + } + }); + + afterAll(async () => { + // Resign as leader + if (leaderInstance) await leaderInstance.resign(); + + // Close Redis connection + if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect(); + }); + + describe('private user servers', () => { + it('should add and remove private user server', async () => { + const userId = 'user123'; + const serverName = 'private_server'; + + // Add private user server + await registry.addPrivateUserServer(userId, serverName, testParsedConfig); + + // Verify server was added + const retrievedConfig = await registry.getServerConfig(serverName, userId); + expect(retrievedConfig).toEqual(testParsedConfig); + + // Remove private user server + await registry.removePrivateUserServer(userId, serverName); + + // Verify server was removed + const configAfterRemoval = await registry.getServerConfig(serverName, userId); + expect(configAfterRemoval).toBeUndefined(); + }); + + it('should throw error when adding duplicate private user server', async () => { + const userId = 'user123'; + const serverName = 'private_server'; + + await registry.addPrivateUserServer(userId, serverName, testParsedConfig); + await expect( + registry.addPrivateUserServer(userId, serverName, testParsedConfig), + ).rejects.toThrow( + 'Server "private_server" already exists in cache. Use update() to modify existing configs.', + ); + }); + + it('should update an existing private user server', async () => { + const userId = 'user123'; + const serverName = 'private_server'; + const updatedConfig: t.ParsedServerConfig = { + type: 'stdio', + command: 'python', + args: ['updated.py'], + requiresOAuth: true, + }; + + // Add private user server + await registry.addPrivateUserServer(userId, serverName, testParsedConfig); + + // Update the server config + await registry.updatePrivateUserServer(userId, serverName, updatedConfig); + + // Verify server was updated + const retrievedConfig = await registry.getServerConfig(serverName, userId); + expect(retrievedConfig).toEqual(updatedConfig); + }); + + it('should throw error when updating non-existent server', async () => { + const userId = 'user123'; + const serverName = 'private_server'; + + // Add a user cache first + await registry.addPrivateUserServer(userId, 'other_server', testParsedConfig); + + await expect( + registry.updatePrivateUserServer(userId, serverName, testParsedConfig), + ).rejects.toThrow( + 'Server "private_server" does not exist in cache. Use add() to create new configs.', + ); + }); + + it('should throw error when updating server for non-existent user', async () => { + const userId = 'nonexistent_user'; + const serverName = 'private_server'; + + await expect( + registry.updatePrivateUserServer(userId, serverName, testParsedConfig), + ).rejects.toThrow('No private servers found for user "nonexistent_user".'); + }); + }); + + describe('getAllServerConfigs', () => { + it('should return correct servers based on userId', async () => { + // Add servers to all three caches + await registry.sharedAppServers.add('app_server', testParsedConfig); + await registry.sharedUserServers.add('user_server', testParsedConfig); + await registry.addPrivateUserServer('abc', 'abc_private_server', testParsedConfig); + await registry.addPrivateUserServer('xyz', 'xyz_private_server', testParsedConfig); + + // Without userId: should return only shared app + shared user servers + const configsNoUser = await registry.getAllServerConfigs(); + expect(Object.keys(configsNoUser)).toHaveLength(2); + expect(configsNoUser).toHaveProperty('app_server'); + expect(configsNoUser).toHaveProperty('user_server'); + + // With userId 'abc': should return shared app + shared user + abc's private servers + const configsAbc = await registry.getAllServerConfigs('abc'); + expect(Object.keys(configsAbc)).toHaveLength(3); + expect(configsAbc).toHaveProperty('app_server'); + expect(configsAbc).toHaveProperty('user_server'); + expect(configsAbc).toHaveProperty('abc_private_server'); + + // With userId 'xyz': should return shared app + shared user + xyz's private servers + const configsXyz = await registry.getAllServerConfigs('xyz'); + expect(Object.keys(configsXyz)).toHaveLength(3); + expect(configsXyz).toHaveProperty('app_server'); + expect(configsXyz).toHaveProperty('user_server'); + expect(configsXyz).toHaveProperty('xyz_private_server'); + }); + }); + + describe('reset', () => { + it('should clear all servers from all caches (shared app, shared user, and private user)', async () => { + const userId = 'user123'; + + // Add servers to all three caches + await registry.sharedAppServers.add('app_server', testParsedConfig); + await registry.sharedUserServers.add('user_server', testParsedConfig); + await registry.addPrivateUserServer(userId, 'private_server', testParsedConfig); + + // Verify all servers are accessible before reset + const appConfigBefore = await registry.getServerConfig('app_server'); + const userConfigBefore = await registry.getServerConfig('user_server'); + const privateConfigBefore = await registry.getServerConfig('private_server', userId); + const allConfigsBefore = await registry.getAllServerConfigs(userId); + + expect(appConfigBefore).toEqual(testParsedConfig); + expect(userConfigBefore).toEqual(testParsedConfig); + expect(privateConfigBefore).toEqual(testParsedConfig); + expect(Object.keys(allConfigsBefore)).toHaveLength(3); + + // Reset everything + await registry.reset(); + + // Verify all servers are cleared after reset + const appConfigAfter = await registry.getServerConfig('app_server'); + const userConfigAfter = await registry.getServerConfig('user_server'); + const privateConfigAfter = await registry.getServerConfig('private_server', userId); + const allConfigsAfter = await registry.getAllServerConfigs(userId); + + expect(appConfigAfter).toBeUndefined(); + expect(userConfigAfter).toBeUndefined(); + expect(privateConfigAfter).toBeUndefined(); + expect(Object.keys(allConfigsAfter)).toHaveLength(0); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts b/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts new file mode 100644 index 0000000000..db4b40a46b --- /dev/null +++ b/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts @@ -0,0 +1,175 @@ +import * as t from '~/mcp/types'; +import { mcpServersRegistry as registry } from '~/mcp/registry/MCPServersRegistry'; + +/** + * Unit tests for MCPServersRegistry using in-memory cache. + * For integration tests using Redis-backed cache, see MCPServersRegistry.cache_integration.spec.ts + */ +describe('MCPServersRegistry', () => { + const testParsedConfig: t.ParsedServerConfig = { + type: 'stdio', + command: 'node', + args: ['tools.js'], + requiresOAuth: false, + serverInstructions: 'Instructions for file_tools_server', + tools: 'file_read, file_write', + capabilities: '{"tools":{"listChanged":true}}', + toolFunctions: { + file_read_mcp_file_tools_server: { + type: 'function', + function: { + name: 'file_read_mcp_file_tools_server', + description: 'Read a file', + parameters: { type: 'object' }, + }, + }, + }, + }; + + beforeEach(async () => { + await registry.reset(); + }); + + describe('private user servers', () => { + it('should add and remove private user server', async () => { + const userId = 'user123'; + const serverName = 'private_server'; + + // Add private user server + await registry.addPrivateUserServer(userId, serverName, testParsedConfig); + + // Verify server was added + const retrievedConfig = await registry.getServerConfig(serverName, userId); + expect(retrievedConfig).toEqual(testParsedConfig); + + // Remove private user server + await registry.removePrivateUserServer(userId, serverName); + + // Verify server was removed + const configAfterRemoval = await registry.getServerConfig(serverName, userId); + expect(configAfterRemoval).toBeUndefined(); + }); + + it('should throw error when adding duplicate private user server', async () => { + const userId = 'user123'; + const serverName = 'private_server'; + + await registry.addPrivateUserServer(userId, serverName, testParsedConfig); + await expect( + registry.addPrivateUserServer(userId, serverName, testParsedConfig), + ).rejects.toThrow( + 'Server "private_server" already exists in cache. Use update() to modify existing configs.', + ); + }); + + it('should update an existing private user server', async () => { + const userId = 'user123'; + const serverName = 'private_server'; + const updatedConfig: t.ParsedServerConfig = { + type: 'stdio', + command: 'python', + args: ['updated.py'], + requiresOAuth: true, + }; + + // Add private user server + await registry.addPrivateUserServer(userId, serverName, testParsedConfig); + + // Update the server config + await registry.updatePrivateUserServer(userId, serverName, updatedConfig); + + // Verify server was updated + const retrievedConfig = await registry.getServerConfig(serverName, userId); + expect(retrievedConfig).toEqual(updatedConfig); + }); + + it('should throw error when updating non-existent server', async () => { + const userId = 'user123'; + const serverName = 'private_server'; + + // Add a user cache first + await registry.addPrivateUserServer(userId, 'other_server', testParsedConfig); + + await expect( + registry.updatePrivateUserServer(userId, serverName, testParsedConfig), + ).rejects.toThrow( + 'Server "private_server" does not exist in cache. Use add() to create new configs.', + ); + }); + + it('should throw error when updating server for non-existent user', async () => { + const userId = 'nonexistent_user'; + const serverName = 'private_server'; + + await expect( + registry.updatePrivateUserServer(userId, serverName, testParsedConfig), + ).rejects.toThrow('No private servers found for user "nonexistent_user".'); + }); + }); + + describe('getAllServerConfigs', () => { + it('should return correct servers based on userId', async () => { + // Add servers to all three caches + await registry.sharedAppServers.add('app_server', testParsedConfig); + await registry.sharedUserServers.add('user_server', testParsedConfig); + await registry.addPrivateUserServer('abc', 'abc_private_server', testParsedConfig); + await registry.addPrivateUserServer('xyz', 'xyz_private_server', testParsedConfig); + + // Without userId: should return only shared app + shared user servers + const configsNoUser = await registry.getAllServerConfigs(); + expect(Object.keys(configsNoUser)).toHaveLength(2); + expect(configsNoUser).toHaveProperty('app_server'); + expect(configsNoUser).toHaveProperty('user_server'); + + // With userId 'abc': should return shared app + shared user + abc's private servers + const configsAbc = await registry.getAllServerConfigs('abc'); + expect(Object.keys(configsAbc)).toHaveLength(3); + expect(configsAbc).toHaveProperty('app_server'); + expect(configsAbc).toHaveProperty('user_server'); + expect(configsAbc).toHaveProperty('abc_private_server'); + + // With userId 'xyz': should return shared app + shared user + xyz's private servers + const configsXyz = await registry.getAllServerConfigs('xyz'); + expect(Object.keys(configsXyz)).toHaveLength(3); + expect(configsXyz).toHaveProperty('app_server'); + expect(configsXyz).toHaveProperty('user_server'); + expect(configsXyz).toHaveProperty('xyz_private_server'); + }); + }); + + describe('reset', () => { + it('should clear all servers from all caches (shared app, shared user, and private user)', async () => { + const userId = 'user123'; + + // Add servers to all three caches + await registry.sharedAppServers.add('app_server', testParsedConfig); + await registry.sharedUserServers.add('user_server', testParsedConfig); + await registry.addPrivateUserServer(userId, 'private_server', testParsedConfig); + + // Verify all servers are accessible before reset + const appConfigBefore = await registry.getServerConfig('app_server'); + const userConfigBefore = await registry.getServerConfig('user_server'); + const privateConfigBefore = await registry.getServerConfig('private_server', userId); + const allConfigsBefore = await registry.getAllServerConfigs(userId); + + expect(appConfigBefore).toEqual(testParsedConfig); + expect(userConfigBefore).toEqual(testParsedConfig); + expect(privateConfigBefore).toEqual(testParsedConfig); + expect(Object.keys(allConfigsBefore)).toHaveLength(3); + + // Reset everything + await registry.reset(); + + // Verify all servers are cleared after reset + const appConfigAfter = await registry.getServerConfig('app_server'); + const userConfigAfter = await registry.getServerConfig('user_server'); + const privateConfigAfter = await registry.getServerConfig('private_server', userId); + const allConfigsAfter = await registry.getAllServerConfigs(userId); + + expect(appConfigAfter).toBeUndefined(); + expect(userConfigAfter).toBeUndefined(); + expect(privateConfigAfter).toBeUndefined(); + expect(Object.keys(allConfigsAfter)).toHaveLength(0); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/__tests__/mcpConnectionsMock.helper.ts b/packages/api/src/mcp/registry/__tests__/mcpConnectionsMock.helper.ts new file mode 100644 index 0000000000..74bc83425d --- /dev/null +++ b/packages/api/src/mcp/registry/__tests__/mcpConnectionsMock.helper.ts @@ -0,0 +1,55 @@ +import type { MCPConnection } from '~/mcp/connection'; + +/** + * Creates a single mock MCP connection for testing. + * The connection has a client with mocked methods that return server-specific data. + * @param serverName - Name of the server to create mock connection for + * @returns Mocked MCPConnection instance + */ +export function createMockConnection(serverName: string): jest.Mocked { + const mockClient = { + getInstructions: jest.fn().mockReturnValue(`instructions for ${serverName}`), + getServerCapabilities: jest.fn().mockReturnValue({ + tools: { listChanged: true }, + resources: { listChanged: true }, + prompts: { get: `getPrompts for ${serverName}` }, + }), + listTools: jest.fn().mockResolvedValue({ + tools: [ + { + name: 'listFiles', + description: `Description for ${serverName}'s listFiles tool`, + inputSchema: { + type: 'object', + properties: { + input: { type: 'string' }, + }, + }, + }, + ], + }), + }; + + return { + client: mockClient, + disconnect: jest.fn().mockResolvedValue(undefined), + } as unknown as jest.Mocked; +} + +/** + * Creates mock MCP connections for testing. + * Each connection has a client with mocked methods that return server-specific data. + * @param serverNames - Array of server names to create mock connections for + * @returns Map of server names to mocked MCPConnection instances + */ +export function createMockConnectionsMap( + serverNames: string[], +): Map> { + const mockConnections = new Map>(); + + serverNames.forEach((serverName) => { + mockConnections.set(serverName, createMockConnection(serverName)); + }); + + return mockConnections; +} diff --git a/packages/api/src/mcp/registry/cache/BaseRegistryCache.ts b/packages/api/src/mcp/registry/cache/BaseRegistryCache.ts new file mode 100644 index 0000000000..1d2266fc6d --- /dev/null +++ b/packages/api/src/mcp/registry/cache/BaseRegistryCache.ts @@ -0,0 +1,26 @@ +import type Keyv from 'keyv'; +import { isLeader } from '~/cluster'; + +/** + * Base class for MCP registry caches that require distributed leader coordination. + * Provides helper methods for leader-only operations and success validation. + * All concrete implementations must provide their own Keyv cache instance. + */ +export abstract class BaseRegistryCache { + protected readonly PREFIX = 'MCP::ServersRegistry'; + protected abstract readonly cache: Keyv; + + protected async leaderCheck(action: string): Promise { + if (!(await isLeader())) throw new Error(`Only leader can ${action}.`); + } + + protected successCheck(action: string, success: boolean): true { + if (!success) throw new Error(`Failed to ${action} in cache.`); + return true; + } + + public async reset(): Promise { + await this.leaderCheck(`reset ${this.cache.namespace} cache`); + await this.cache.clear(); + } +} diff --git a/packages/api/src/mcp/registry/cache/RegistryStatusCache.ts b/packages/api/src/mcp/registry/cache/RegistryStatusCache.ts new file mode 100644 index 0000000000..2a8fc72213 --- /dev/null +++ b/packages/api/src/mcp/registry/cache/RegistryStatusCache.ts @@ -0,0 +1,37 @@ +import { standardCache } from '~/cache'; +import { BaseRegistryCache } from './BaseRegistryCache'; + +// Status keys +const INITIALIZED = 'INITIALIZED'; + +/** + * Cache for tracking MCP Servers Registry metadata and status across distributed instances. + * Uses Redis-backed storage to coordinate state between leader and follower nodes. + * Currently, tracks initialization status to ensure only the leader performs initialization + * while followers wait for completion. Designed to be extended with additional registry + * metadata as needed (e.g., last update timestamps, version info, health status). + * This cache is only meant to be used internally by registry management components. + */ +class RegistryStatusCache extends BaseRegistryCache { + protected readonly cache = standardCache(`${this.PREFIX}::Status`); + + public async isInitialized(): Promise { + return (await this.get(INITIALIZED)) === true; + } + + public async setInitialized(value: boolean): Promise { + await this.set(INITIALIZED, value); + } + + private async get(key: string): Promise { + return this.cache.get(key); + } + + private async set(key: string, value: string | number | boolean, ttl?: number): Promise { + await this.leaderCheck('set MCP Servers Registry status'); + const success = await this.cache.set(key, value, ttl); + this.successCheck(`set status key "${key}"`, success); + } +} + +export const registryStatusCache = new RegistryStatusCache(); diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts new file mode 100644 index 0000000000..72c664d844 --- /dev/null +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts @@ -0,0 +1,31 @@ +import { cacheConfig } from '~/cache'; +import { ServerConfigsCacheInMemory } from './ServerConfigsCacheInMemory'; +import { ServerConfigsCacheRedis } from './ServerConfigsCacheRedis'; + +export type ServerConfigsCache = ServerConfigsCacheInMemory | ServerConfigsCacheRedis; + +/** + * Factory for creating the appropriate ServerConfigsCache implementation based on deployment mode. + * Automatically selects between in-memory and Redis-backed storage depending on USE_REDIS config. + * In single-instance mode (USE_REDIS=false), returns lightweight in-memory cache. + * In cluster mode (USE_REDIS=true), returns Redis-backed cache with distributed coordination. + * Provides a unified interface regardless of the underlying storage mechanism. + */ +export class ServerConfigsCacheFactory { + /** + * Create a ServerConfigsCache instance. + * Returns Redis implementation if Redis is configured, otherwise in-memory implementation. + * + * @param owner - The owner of the cache (e.g., 'user', 'global') - only used for Redis namespacing + * @param leaderOnly - Whether operations should only be performed by the leader (only applies to Redis) + * @returns ServerConfigsCache instance + */ + static create(owner: string, leaderOnly: boolean): ServerConfigsCache { + if (cacheConfig.USE_REDIS) { + return new ServerConfigsCacheRedis(owner, leaderOnly); + } + + // In-memory mode uses a simple Map - doesn't need owner/namespace + return new ServerConfigsCacheInMemory(); + } +} diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts new file mode 100644 index 0000000000..1dd2385053 --- /dev/null +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts @@ -0,0 +1,46 @@ +import { ParsedServerConfig } from '~/mcp/types'; + +/** + * In-memory implementation of MCP server configurations cache for single-instance deployments. + * Uses a native JavaScript Map for fast, local storage without Redis dependencies. + * Suitable for development environments or single-server production deployments. + * Does not require leader checks or distributed coordination since data is instance-local. + * Data is lost on server restart and not shared across multiple server instances. + */ +export class ServerConfigsCacheInMemory { + private readonly cache: Map = new Map(); + + public async add(serverName: string, config: ParsedServerConfig): Promise { + if (this.cache.has(serverName)) + throw new Error( + `Server "${serverName}" already exists in cache. Use update() to modify existing configs.`, + ); + this.cache.set(serverName, config); + } + + public async update(serverName: string, config: ParsedServerConfig): Promise { + if (!this.cache.has(serverName)) + throw new Error( + `Server "${serverName}" does not exist in cache. Use add() to create new configs.`, + ); + this.cache.set(serverName, config); + } + + public async remove(serverName: string): Promise { + if (!this.cache.delete(serverName)) { + throw new Error(`Failed to remove server "${serverName}" in cache.`); + } + } + + public async get(serverName: string): Promise { + return this.cache.get(serverName); + } + + public async getAll(): Promise> { + return Object.fromEntries(this.cache); + } + + public async reset(): Promise { + this.cache.clear(); + } +} diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts new file mode 100644 index 0000000000..a2e025736c --- /dev/null +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts @@ -0,0 +1,80 @@ +import type Keyv from 'keyv'; +import { fromPairs } from 'lodash'; +import { standardCache, keyvRedisClient } from '~/cache'; +import { ParsedServerConfig } from '~/mcp/types'; +import { BaseRegistryCache } from './BaseRegistryCache'; + +/** + * Redis-backed implementation of MCP server configurations cache for distributed deployments. + * Stores server configs in Redis with namespace isolation by owner (App, User, or specific user ID). + * Enables data sharing across multiple server instances in a cluster environment. + * Supports optional leader-only write operations to prevent race conditions during initialization. + * Data persists across server restarts and is accessible from any instance in the cluster. + */ +export class ServerConfigsCacheRedis extends BaseRegistryCache { + protected readonly cache: Keyv; + private readonly owner: string; + private readonly leaderOnly: boolean; + + constructor(owner: string, leaderOnly: boolean) { + super(); + this.owner = owner; + this.leaderOnly = leaderOnly; + this.cache = standardCache(`${this.PREFIX}::Servers::${owner}`); + } + + public async add(serverName: string, config: ParsedServerConfig): Promise { + if (this.leaderOnly) await this.leaderCheck(`add ${this.owner} MCP servers`); + const exists = await this.cache.has(serverName); + if (exists) + throw new Error( + `Server "${serverName}" already exists in cache. Use update() to modify existing configs.`, + ); + const success = await this.cache.set(serverName, config); + this.successCheck(`add ${this.owner} server "${serverName}"`, success); + } + + public async update(serverName: string, config: ParsedServerConfig): Promise { + if (this.leaderOnly) await this.leaderCheck(`update ${this.owner} MCP servers`); + const exists = await this.cache.has(serverName); + if (!exists) + throw new Error( + `Server "${serverName}" does not exist in cache. Use add() to create new configs.`, + ); + const success = await this.cache.set(serverName, config); + this.successCheck(`update ${this.owner} server "${serverName}"`, success); + } + + public async remove(serverName: string): Promise { + if (this.leaderOnly) await this.leaderCheck(`remove ${this.owner} MCP servers`); + const success = await this.cache.delete(serverName); + this.successCheck(`remove ${this.owner} server "${serverName}"`, success); + } + + public async get(serverName: string): Promise { + return this.cache.get(serverName); + } + + public async getAll(): Promise> { + // Use Redis SCAN iterator directly (non-blocking, production-ready) + // Note: Keyv uses a single colon ':' between namespace and key, even if GLOBAL_PREFIX_SEPARATOR is '::' + const pattern = `*${this.cache.namespace}:*`; + const entries: Array<[string, ParsedServerConfig]> = []; + + // Use scanIterator from Redis client + if (keyvRedisClient && 'scanIterator' in keyvRedisClient) { + for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) { + // Extract the actual key name (last part after final colon) + // Full key format: "prefix::namespace:keyName" + const lastColonIndex = key.lastIndexOf(':'); + const keyName = key.substring(lastColonIndex + 1); + const value = await this.cache.get(keyName); + if (value) { + entries.push([keyName, value as ParsedServerConfig]); + } + } + } + + return fromPairs(entries); + } +} diff --git a/packages/api/src/mcp/registry/cache/__tests__/RegistryStatusCache.cache_integration.spec.ts b/packages/api/src/mcp/registry/cache/__tests__/RegistryStatusCache.cache_integration.spec.ts new file mode 100644 index 0000000000..643e7c27df --- /dev/null +++ b/packages/api/src/mcp/registry/cache/__tests__/RegistryStatusCache.cache_integration.spec.ts @@ -0,0 +1,73 @@ +import { expect } from '@playwright/test'; + +describe('RegistryStatusCache Integration Tests', () => { + let registryStatusCache: typeof import('../RegistryStatusCache').registryStatusCache; + let keyvRedisClient: Awaited['keyvRedisClient']; + let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection; + let leaderInstance: InstanceType; + + beforeAll(async () => { + // Set up environment variables for Redis (only if not already set) + process.env.USE_REDIS = process.env.USE_REDIS ?? 'true'; + process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379'; + process.env.REDIS_KEY_PREFIX = + process.env.REDIS_KEY_PREFIX ?? 'RegistryStatusCache-IntegrationTest'; + + // Import modules after setting env vars + const statusCacheModule = await import('../RegistryStatusCache'); + const redisClients = await import('~/cache/redisClients'); + const leaderElectionModule = await import('~/cluster/LeaderElection'); + + registryStatusCache = statusCacheModule.registryStatusCache; + keyvRedisClient = redisClients.keyvRedisClient; + LeaderElection = leaderElectionModule.LeaderElection; + + // Ensure Redis is connected + if (!keyvRedisClient) throw new Error('Redis client is not initialized'); + + // Wait for Redis to be ready + if (!keyvRedisClient.isOpen) await keyvRedisClient.connect(); + + // Become leader so we can perform write operations + leaderInstance = new LeaderElection(); + const isLeader = await leaderInstance.isLeader(); + expect(isLeader).toBe(true); + }); + + afterEach(async () => { + // Clean up: clear all test keys from Redis + if (keyvRedisClient) { + const pattern = '*RegistryStatusCache-IntegrationTest*'; + if ('scanIterator' in keyvRedisClient) { + for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) { + await keyvRedisClient.del(key); + } + } + } + }); + + afterAll(async () => { + // Resign as leader + if (leaderInstance) await leaderInstance.resign(); + + // Close Redis connection + if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect(); + }); + + describe('Initialization status tracking', () => { + it('should return false for isInitialized when not set', async () => { + const initialized = await registryStatusCache.isInitialized(); + expect(initialized).toBe(false); + }); + + it('should set and get initialized status', async () => { + await registryStatusCache.setInitialized(true); + const initialized = await registryStatusCache.isInitialized(); + expect(initialized).toBe(true); + + await registryStatusCache.setInitialized(false); + const uninitialized = await registryStatusCache.isInitialized(); + expect(uninitialized).toBe(false); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts new file mode 100644 index 0000000000..d1e0a0d486 --- /dev/null +++ b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts @@ -0,0 +1,70 @@ +import { ServerConfigsCacheFactory } from '../ServerConfigsCacheFactory'; +import { ServerConfigsCacheInMemory } from '../ServerConfigsCacheInMemory'; +import { ServerConfigsCacheRedis } from '../ServerConfigsCacheRedis'; +import { cacheConfig } from '~/cache'; + +// Mock the cache implementations +jest.mock('../ServerConfigsCacheInMemory'); +jest.mock('../ServerConfigsCacheRedis'); + +// Mock the cache config module +jest.mock('~/cache', () => ({ + cacheConfig: { + USE_REDIS: false, + }, +})); + +describe('ServerConfigsCacheFactory', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + describe('create()', () => { + it('should return ServerConfigsCacheRedis when USE_REDIS is true', () => { + // Arrange + cacheConfig.USE_REDIS = true; + + // Act + const cache = ServerConfigsCacheFactory.create('TestOwner', true); + + // Assert + expect(cache).toBeInstanceOf(ServerConfigsCacheRedis); + expect(ServerConfigsCacheRedis).toHaveBeenCalledWith('TestOwner', true); + }); + + it('should return ServerConfigsCacheInMemory when USE_REDIS is false', () => { + // Arrange + cacheConfig.USE_REDIS = false; + + // Act + const cache = ServerConfigsCacheFactory.create('TestOwner', false); + + // Assert + expect(cache).toBeInstanceOf(ServerConfigsCacheInMemory); + expect(ServerConfigsCacheInMemory).toHaveBeenCalled(); + }); + + it('should pass correct parameters to ServerConfigsCacheRedis', () => { + // Arrange + cacheConfig.USE_REDIS = true; + + // Act + ServerConfigsCacheFactory.create('App', true); + + // Assert + expect(ServerConfigsCacheRedis).toHaveBeenCalledWith('App', true); + }); + + it('should create ServerConfigsCacheInMemory without parameters when USE_REDIS is false', () => { + // Arrange + cacheConfig.USE_REDIS = false; + + // Act + ServerConfigsCacheFactory.create('User', false); + + // Assert + // In-memory cache doesn't use owner/leaderOnly parameters + expect(ServerConfigsCacheInMemory).toHaveBeenCalledWith(); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts new file mode 100644 index 0000000000..e2033d0ba8 --- /dev/null +++ b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts @@ -0,0 +1,173 @@ +import { expect } from '@playwright/test'; +import { ParsedServerConfig } from '~/mcp/types'; + +describe('ServerConfigsCacheInMemory Integration Tests', () => { + let ServerConfigsCacheInMemory: typeof import('../ServerConfigsCacheInMemory').ServerConfigsCacheInMemory; + let cache: InstanceType< + typeof import('../ServerConfigsCacheInMemory').ServerConfigsCacheInMemory + >; + + // Test data + const mockConfig1: ParsedServerConfig = { + command: 'node', + args: ['server1.js'], + env: { TEST: 'value1' }, + }; + + const mockConfig2: ParsedServerConfig = { + command: 'python', + args: ['server2.py'], + env: { TEST: 'value2' }, + }; + + const mockConfig3: ParsedServerConfig = { + command: 'node', + args: ['server3.js'], + url: 'http://localhost:3000', + requiresOAuth: true, + }; + + beforeAll(async () => { + // Import modules + const cacheModule = await import('../ServerConfigsCacheInMemory'); + ServerConfigsCacheInMemory = cacheModule.ServerConfigsCacheInMemory; + }); + + beforeEach(() => { + // Create a fresh instance for each test + cache = new ServerConfigsCacheInMemory(); + }); + + describe('add and get operations', () => { + it('should add and retrieve a server config', async () => { + await cache.add('server1', mockConfig1); + const result = await cache.get('server1'); + expect(result).toEqual(mockConfig1); + }); + + it('should return undefined for non-existent server', async () => { + const result = await cache.get('non-existent'); + expect(result).toBeUndefined(); + }); + + it('should throw error when adding duplicate server', async () => { + await cache.add('server1', mockConfig1); + await expect(cache.add('server1', mockConfig2)).rejects.toThrow( + 'Server "server1" already exists in cache. Use update() to modify existing configs.', + ); + }); + + it('should handle multiple server configs', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + await cache.add('server3', mockConfig3); + + const result1 = await cache.get('server1'); + const result2 = await cache.get('server2'); + const result3 = await cache.get('server3'); + + expect(result1).toEqual(mockConfig1); + expect(result2).toEqual(mockConfig2); + expect(result3).toEqual(mockConfig3); + }); + }); + + describe('getAll operation', () => { + it('should return empty object when no servers exist', async () => { + const result = await cache.getAll(); + expect(result).toEqual({}); + }); + + it('should return all server configs', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + await cache.add('server3', mockConfig3); + + const result = await cache.getAll(); + expect(result).toEqual({ + server1: mockConfig1, + server2: mockConfig2, + server3: mockConfig3, + }); + }); + + it('should reflect updates in getAll', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + let result = await cache.getAll(); + expect(Object.keys(result).length).toBe(2); + + await cache.add('server3', mockConfig3); + result = await cache.getAll(); + expect(Object.keys(result).length).toBe(3); + expect(result.server3).toEqual(mockConfig3); + }); + }); + + describe('update operation', () => { + it('should update an existing server config', async () => { + await cache.add('server1', mockConfig1); + expect(await cache.get('server1')).toEqual(mockConfig1); + + await cache.update('server1', mockConfig2); + const result = await cache.get('server1'); + expect(result).toEqual(mockConfig2); + }); + + it('should throw error when updating non-existent server', async () => { + await expect(cache.update('non-existent', mockConfig1)).rejects.toThrow( + 'Server "non-existent" does not exist in cache. Use add() to create new configs.', + ); + }); + + it('should reflect updates in getAll', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + await cache.update('server1', mockConfig3); + const result = await cache.getAll(); + expect(result.server1).toEqual(mockConfig3); + expect(result.server2).toEqual(mockConfig2); + }); + }); + + describe('remove operation', () => { + it('should remove an existing server config', async () => { + await cache.add('server1', mockConfig1); + expect(await cache.get('server1')).toEqual(mockConfig1); + + await cache.remove('server1'); + expect(await cache.get('server1')).toBeUndefined(); + }); + + it('should throw error when removing non-existent server', async () => { + await expect(cache.remove('non-existent')).rejects.toThrow( + 'Failed to remove server "non-existent" in cache.', + ); + }); + + it('should remove server from getAll results', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + let result = await cache.getAll(); + expect(Object.keys(result).length).toBe(2); + + await cache.remove('server1'); + result = await cache.getAll(); + expect(Object.keys(result).length).toBe(1); + expect(result.server1).toBeUndefined(); + expect(result.server2).toEqual(mockConfig2); + }); + + it('should allow re-adding a removed server', async () => { + await cache.add('server1', mockConfig1); + await cache.remove('server1'); + await cache.add('server1', mockConfig3); + + const result = await cache.get('server1'); + expect(result).toEqual(mockConfig3); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.cache_integration.spec.ts b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.cache_integration.spec.ts new file mode 100644 index 0000000000..7e139dc5be --- /dev/null +++ b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.cache_integration.spec.ts @@ -0,0 +1,278 @@ +import { expect } from '@playwright/test'; +import { ParsedServerConfig } from '~/mcp/types'; + +describe('ServerConfigsCacheRedis Integration Tests', () => { + let ServerConfigsCacheRedis: typeof import('../ServerConfigsCacheRedis').ServerConfigsCacheRedis; + let keyvRedisClient: Awaited['keyvRedisClient']; + let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection; + let checkIsLeader: () => Promise; + let cache: InstanceType; + + // Test data + const mockConfig1: ParsedServerConfig = { + command: 'node', + args: ['server1.js'], + env: { TEST: 'value1' }, + }; + + const mockConfig2: ParsedServerConfig = { + command: 'python', + args: ['server2.py'], + env: { TEST: 'value2' }, + }; + + const mockConfig3: ParsedServerConfig = { + command: 'node', + args: ['server3.js'], + url: 'http://localhost:3000', + requiresOAuth: true, + }; + + beforeAll(async () => { + // Set up environment variables for Redis (only if not already set) + process.env.USE_REDIS = process.env.USE_REDIS ?? 'true'; + process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379'; + process.env.REDIS_KEY_PREFIX = + process.env.REDIS_KEY_PREFIX ?? 'ServerConfigsCacheRedis-IntegrationTest'; + + // Import modules after setting env vars + const cacheModule = await import('../ServerConfigsCacheRedis'); + const redisClients = await import('~/cache/redisClients'); + const leaderElectionModule = await import('~/cluster/LeaderElection'); + const clusterModule = await import('~/cluster'); + + ServerConfigsCacheRedis = cacheModule.ServerConfigsCacheRedis; + keyvRedisClient = redisClients.keyvRedisClient; + LeaderElection = leaderElectionModule.LeaderElection; + checkIsLeader = clusterModule.isLeader; + + // Ensure Redis is connected + if (!keyvRedisClient) throw new Error('Redis client is not initialized'); + + // Wait for Redis to be ready + if (!keyvRedisClient.isOpen) await keyvRedisClient.connect(); + + // Clear any existing leader key to ensure clean state + await keyvRedisClient.del(LeaderElection.LEADER_KEY); + + // Become leader so we can perform write operations (using default election instance) + const isLeader = await checkIsLeader(); + expect(isLeader).toBe(true); + }); + + beforeEach(() => { + // Create a fresh instance for each test with leaderOnly=true + cache = new ServerConfigsCacheRedis('test-user', true); + }); + + afterEach(async () => { + // Clean up: clear all test keys from Redis + if (keyvRedisClient) { + const pattern = '*ServerConfigsCacheRedis-IntegrationTest*'; + if ('scanIterator' in keyvRedisClient) { + for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) { + await keyvRedisClient.del(key); + } + } + } + }); + + afterAll(async () => { + // Clear leader key to allow other tests to become leader + if (keyvRedisClient) await keyvRedisClient.del(LeaderElection.LEADER_KEY); + + // Close Redis connection + if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect(); + }); + + describe('add and get operations', () => { + it('should add and retrieve a server config', async () => { + await cache.add('server1', mockConfig1); + const result = await cache.get('server1'); + expect(result).toEqual(mockConfig1); + }); + + it('should return undefined for non-existent server', async () => { + const result = await cache.get('non-existent'); + expect(result).toBeUndefined(); + }); + + it('should throw error when adding duplicate server', async () => { + await cache.add('server1', mockConfig1); + await expect(cache.add('server1', mockConfig2)).rejects.toThrow( + 'Server "server1" already exists in cache. Use update() to modify existing configs.', + ); + }); + + it('should handle multiple server configs', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + await cache.add('server3', mockConfig3); + + const result1 = await cache.get('server1'); + const result2 = await cache.get('server2'); + const result3 = await cache.get('server3'); + + expect(result1).toEqual(mockConfig1); + expect(result2).toEqual(mockConfig2); + expect(result3).toEqual(mockConfig3); + }); + + it('should isolate caches by owner namespace', async () => { + const userCache = new ServerConfigsCacheRedis('user1', true); + const globalCache = new ServerConfigsCacheRedis('global', true); + + await userCache.add('server1', mockConfig1); + await globalCache.add('server1', mockConfig2); + + const userResult = await userCache.get('server1'); + const globalResult = await globalCache.get('server1'); + + expect(userResult).toEqual(mockConfig1); + expect(globalResult).toEqual(mockConfig2); + }); + }); + + describe('getAll operation', () => { + it('should return empty object when no servers exist', async () => { + const result = await cache.getAll(); + expect(result).toEqual({}); + }); + + it('should return all server configs', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + await cache.add('server3', mockConfig3); + + const result = await cache.getAll(); + expect(result).toEqual({ + server1: mockConfig1, + server2: mockConfig2, + server3: mockConfig3, + }); + }); + + it('should reflect updates in getAll', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + let result = await cache.getAll(); + expect(Object.keys(result).length).toBe(2); + + await cache.add('server3', mockConfig3); + result = await cache.getAll(); + expect(Object.keys(result).length).toBe(3); + expect(result.server3).toEqual(mockConfig3); + }); + + it('should only return configs for the specific owner', async () => { + const userCache = new ServerConfigsCacheRedis('user1', true); + const globalCache = new ServerConfigsCacheRedis('global', true); + + await userCache.add('server1', mockConfig1); + await userCache.add('server2', mockConfig2); + await globalCache.add('server3', mockConfig3); + + const userResult = await userCache.getAll(); + const globalResult = await globalCache.getAll(); + + expect(Object.keys(userResult).length).toBe(2); + expect(Object.keys(globalResult).length).toBe(1); + expect(userResult.server1).toEqual(mockConfig1); + expect(userResult.server3).toBeUndefined(); + expect(globalResult.server3).toEqual(mockConfig3); + }); + }); + + describe('update operation', () => { + it('should update an existing server config', async () => { + await cache.add('server1', mockConfig1); + expect(await cache.get('server1')).toEqual(mockConfig1); + + await cache.update('server1', mockConfig2); + const result = await cache.get('server1'); + expect(result).toEqual(mockConfig2); + }); + + it('should throw error when updating non-existent server', async () => { + await expect(cache.update('non-existent', mockConfig1)).rejects.toThrow( + 'Server "non-existent" does not exist in cache. Use add() to create new configs.', + ); + }); + + it('should reflect updates in getAll', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + await cache.update('server1', mockConfig3); + const result = await cache.getAll(); + expect(result.server1).toEqual(mockConfig3); + expect(result.server2).toEqual(mockConfig2); + }); + + it('should only update in the specific owner namespace', async () => { + const userCache = new ServerConfigsCacheRedis('user1', true); + const globalCache = new ServerConfigsCacheRedis('global', true); + + await userCache.add('server1', mockConfig1); + await globalCache.add('server1', mockConfig2); + + await userCache.update('server1', mockConfig3); + + expect(await userCache.get('server1')).toEqual(mockConfig3); + expect(await globalCache.get('server1')).toEqual(mockConfig2); + }); + }); + + describe('remove operation', () => { + it('should remove an existing server config', async () => { + await cache.add('server1', mockConfig1); + expect(await cache.get('server1')).toEqual(mockConfig1); + + await cache.remove('server1'); + expect(await cache.get('server1')).toBeUndefined(); + }); + + it('should throw error when removing non-existent server', async () => { + await expect(cache.remove('non-existent')).rejects.toThrow( + 'Failed to remove test-user server "non-existent"', + ); + }); + + it('should remove server from getAll results', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + let result = await cache.getAll(); + expect(Object.keys(result).length).toBe(2); + + await cache.remove('server1'); + result = await cache.getAll(); + expect(Object.keys(result).length).toBe(1); + expect(result.server1).toBeUndefined(); + expect(result.server2).toEqual(mockConfig2); + }); + + it('should allow re-adding a removed server', async () => { + await cache.add('server1', mockConfig1); + await cache.remove('server1'); + await cache.add('server1', mockConfig3); + + const result = await cache.get('server1'); + expect(result).toEqual(mockConfig3); + }); + + it('should only remove from the specific owner namespace', async () => { + const userCache = new ServerConfigsCacheRedis('user1', true); + const globalCache = new ServerConfigsCacheRedis('global', true); + + await userCache.add('server1', mockConfig1); + await globalCache.add('server1', mockConfig2); + + await userCache.remove('server1'); + + expect(await userCache.get('server1')).toBeUndefined(); + expect(await globalCache.get('server1')).toEqual(mockConfig2); + }); + }); +}); diff --git a/packages/api/src/mcp/types/index.ts b/packages/api/src/mcp/types/index.ts index 5cf003b9f5..6e445e26ad 100644 --- a/packages/api/src/mcp/types/index.ts +++ b/packages/api/src/mcp/types/index.ts @@ -151,6 +151,8 @@ export type ParsedServerConfig = MCPOptions & { oauthMetadata?: Record | null; capabilities?: string; tools?: string; + toolFunctions?: LCAvailableTools; + initDuration?: number; }; export interface BasicConnectionOptions { diff --git a/packages/api/src/utils/index.ts b/packages/api/src/utils/index.ts index 9fd3b01885..85c99d108f 100644 --- a/packages/api/src/utils/index.ts +++ b/packages/api/src/utils/index.ts @@ -10,6 +10,7 @@ export * from './key'; export * from './llm'; export * from './math'; export * from './openid'; +export * from './promise'; export * from './sanitizeTitle'; export * from './tempChatRetention'; export * from './text'; diff --git a/packages/api/src/utils/promise.spec.ts b/packages/api/src/utils/promise.spec.ts new file mode 100644 index 0000000000..c43c8bf739 --- /dev/null +++ b/packages/api/src/utils/promise.spec.ts @@ -0,0 +1,115 @@ +import { withTimeout } from './promise'; + +describe('withTimeout', () => { + beforeEach(() => { + jest.clearAllTimers(); + }); + + it('should resolve when promise completes before timeout', async () => { + const promise = Promise.resolve('success'); + const result = await withTimeout(promise, 1000); + expect(result).toBe('success'); + }); + + it('should reject when promise rejects before timeout', async () => { + const promise = Promise.reject(new Error('test error')); + await expect(withTimeout(promise, 1000)).rejects.toThrow('test error'); + }); + + it('should timeout when promise takes too long', async () => { + const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000)); + await expect(withTimeout(promise, 100, 'Custom timeout message')).rejects.toThrow( + 'Custom timeout message', + ); + }); + + it('should use default error message when none provided', async () => { + const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000)); + await expect(withTimeout(promise, 100)).rejects.toThrow('Operation timed out after 100ms'); + }); + + it('should clear timeout when promise resolves', async () => { + const clearTimeoutSpy = jest.spyOn(global, 'clearTimeout'); + const promise = Promise.resolve('fast'); + + await withTimeout(promise, 1000); + + expect(clearTimeoutSpy).toHaveBeenCalled(); + clearTimeoutSpy.mockRestore(); + }); + + it('should clear timeout when promise rejects', async () => { + const clearTimeoutSpy = jest.spyOn(global, 'clearTimeout'); + const promise = Promise.reject(new Error('fail')); + + await expect(withTimeout(promise, 1000)).rejects.toThrow('fail'); + + expect(clearTimeoutSpy).toHaveBeenCalled(); + clearTimeoutSpy.mockRestore(); + }); + + it('should handle multiple concurrent timeouts', async () => { + const promise1 = Promise.resolve('first'); + const promise2 = new Promise((resolve) => setTimeout(() => resolve('second'), 50)); + const promise3 = new Promise((resolve) => setTimeout(() => resolve('third'), 2000)); + + const [result1, result2] = await Promise.all([ + withTimeout(promise1, 1000), + withTimeout(promise2, 1000), + ]); + + expect(result1).toBe('first'); + expect(result2).toBe('second'); + + await expect(withTimeout(promise3, 100)).rejects.toThrow('Operation timed out after 100ms'); + }); + + it('should work with async functions', async () => { + const asyncFunction = async () => { + await new Promise((resolve) => setTimeout(resolve, 10)); + return 'async result'; + }; + + const result = await withTimeout(asyncFunction(), 1000); + expect(result).toBe('async result'); + }); + + it('should work with any return type', async () => { + const numberPromise = Promise.resolve(42); + const objectPromise = Promise.resolve({ key: 'value' }); + const arrayPromise = Promise.resolve([1, 2, 3]); + + expect(await withTimeout(numberPromise, 1000)).toBe(42); + expect(await withTimeout(objectPromise, 1000)).toEqual({ key: 'value' }); + expect(await withTimeout(arrayPromise, 1000)).toEqual([1, 2, 3]); + }); + + it('should call logger when timeout occurs', async () => { + const loggerMock = jest.fn(); + const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000)); + const errorMessage = 'Custom timeout with logger'; + + await expect(withTimeout(promise, 100, errorMessage, loggerMock)).rejects.toThrow(errorMessage); + + expect(loggerMock).toHaveBeenCalledTimes(1); + expect(loggerMock).toHaveBeenCalledWith(errorMessage, expect.any(Error)); + }); + + it('should not call logger when promise resolves', async () => { + const loggerMock = jest.fn(); + const promise = Promise.resolve('success'); + + const result = await withTimeout(promise, 1000, 'Should not timeout', loggerMock); + + expect(result).toBe('success'); + expect(loggerMock).not.toHaveBeenCalled(); + }); + + it('should work without logger parameter', async () => { + const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000)); + + await expect(withTimeout(promise, 100, 'No logger provided')).rejects.toThrow( + 'No logger provided', + ); + }); +}); diff --git a/packages/api/src/utils/promise.ts b/packages/api/src/utils/promise.ts new file mode 100644 index 0000000000..72719a3ff0 --- /dev/null +++ b/packages/api/src/utils/promise.ts @@ -0,0 +1,42 @@ +/** + * Wraps a promise with a timeout. If the promise doesn't resolve/reject within + * the specified time, it will be rejected with a timeout error. + * + * @param promise - The promise to wrap with a timeout + * @param timeoutMs - Timeout duration in milliseconds + * @param errorMessage - Custom error message for timeout (optional) + * @param logger - Optional logger function to log timeout errors (e.g., console.warn, logger.warn) + * @returns Promise that resolves/rejects with the original promise or times out + * + * @example + * ```typescript + * const result = await withTimeout( + * fetchData(), + * 5000, + * 'Failed to fetch data within 5 seconds', + * console.warn + * ); + * ``` + */ +export async function withTimeout( + promise: Promise, + timeoutMs: number, + errorMessage?: string, + logger?: (message: string, error: Error) => void, +): Promise { + let timeoutId: NodeJS.Timeout; + + const timeoutPromise = new Promise((_, reject) => { + timeoutId = setTimeout(() => { + const error = new Error(errorMessage ?? `Operation timed out after ${timeoutMs}ms`); + if (logger) logger(error.message, error); + reject(error); + }, timeoutMs); + }); + + try { + return await Promise.race([promise, timeoutPromise]); + } finally { + clearTimeout(timeoutId!); + } +}