From 8780a7816503f3bc5dfe236a978e8bd36e720269 Mon Sep 17 00:00:00 2001 From: "Theo N. Truong" <644650+nhtruong@users.noreply.github.com> Date: Wed, 13 Aug 2025 09:45:06 -0600 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor:=20MCPManager=20f?= =?UTF-8?q?or=20Scalability,=20Fix=20App-Level=20Detection,=20Add=20Lazy?= =?UTF-8?q?=20Connections=20(#8930)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: MCP Connection management overhaul - Making MCPManager manageable Refactor the monolithic MCPManager into focused, single-responsibility classes: • MCPServersRegistry: Server configuration discovery and metadata management • UserConnectionManager: Manages user-level connections • ConnectionsRepository: Low-level connection pool with lazy loading • MCPConnectionFactory: Handles MCP connection creation with OAuth support New Features: • Lazy loading of app-level connections for horizontal scaling • Automatic reconnection for app-level connections • Enhanced OAuth detection with explicit requiresOAuth flag • Centralized MCP configuration management Bug Fixes: • App-level connection detection in MCPManager.callTool • MCP Connection Reinitialization route behavior Optimizations: • MCPConnection.isConnected() caching to reduce overhead • Concurrent server metadata retrieval instead of sequential This refactoring addresses scalability bottlenecks and improves reliability while maintaining backward compatibility with existing configurations. * feat: Enabled import order in eslint. * # Moved tests to __tests__ folder # added tests for MCPServersRegistry.ts * # Add unit tests for ConnectionsRepository functionality * # Add unit tests for MCPConnectionFactory functionality * # Reorganize MCP connection tests and improve error handling * # reordering imports * # Update testPathIgnorePatterns in jest.config.mjs to exclude development TypeScript files * # removed mcp/manager.ts --- .env.example | 13 + .github/CONTRIBUTING.md | 4 +- .gitignore | 1 + api/config/index.js | 19 +- api/server/routes/__tests__/mcp.spec.js | 295 ++--- api/server/routes/mcp.js | 32 +- api/server/services/initializeMCPs.js | 29 +- eslint.config.mjs | 57 +- packages/api/jest.config.mjs | 2 +- packages/api/src/index.ts | 2 +- packages/api/src/mcp/ConnectionsRepository.ts | 87 ++ packages/api/src/mcp/MCPConnectionFactory.ts | 384 ++++++ packages/api/src/mcp/MCPManager.ts | 263 ++++ packages/api/src/mcp/MCPServersRegistry.ts | 200 +++ packages/api/src/mcp/UserConnectionManager.ts | 236 ++++ .../__tests__/ConnectionsRepository.test.ts | 212 +++ .../__tests__/MCPConnectionFactory.test.ts | 347 +++++ .../mcp/__tests__/MCPServersRegistry.test.ts | 287 +++++ .../api/src/mcp/{ => __tests__}/auth.test.ts | 2 +- .../__tests__/detectOAuth.integration.dev.ts | 76 ++ .../MCPServersRegistry.parsedConfigs.yml | 74 ++ .../MCPServersRegistry.rawConfigs.yml | 53 + .../mcp/{oauth => __tests__}/handler.test.ts | 2 +- .../api/src/mcp/{ => __tests__}/mcp.spec.ts | 0 .../api/src/mcp/{ => __tests__}/utils.test.ts | 2 +- .../api/src/mcp/{ => __tests__}/zod.spec.ts | 2 +- packages/api/src/mcp/connection.ts | 71 +- packages/api/src/mcp/manager.ts | 1143 ----------------- packages/api/src/mcp/mcpConfig.ts | 11 + packages/api/src/mcp/oauth/detectOAuth.ts | 120 ++ packages/api/src/mcp/oauth/index.ts | 1 + packages/data-provider/src/mcp.ts | 12 + 32 files changed, 2571 insertions(+), 1468 deletions(-) create mode 100644 packages/api/src/mcp/ConnectionsRepository.ts create mode 100644 packages/api/src/mcp/MCPConnectionFactory.ts create mode 100644 packages/api/src/mcp/MCPManager.ts create mode 100644 packages/api/src/mcp/MCPServersRegistry.ts create mode 100644 packages/api/src/mcp/UserConnectionManager.ts create mode 100644 packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts create mode 100644 packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts create mode 100644 packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts rename packages/api/src/mcp/{ => __tests__}/auth.test.ts (99%) create mode 100644 packages/api/src/mcp/__tests__/detectOAuth.integration.dev.ts create mode 100644 packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.parsedConfigs.yml create mode 100644 packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.rawConfigs.yml rename packages/api/src/mcp/{oauth => __tests__}/handler.test.ts (99%) rename packages/api/src/mcp/{ => __tests__}/mcp.spec.ts (100%) rename packages/api/src/mcp/{ => __tests__}/utils.test.ts (95%) rename packages/api/src/mcp/{ => __tests__}/zod.spec.ts (99%) delete mode 100644 packages/api/src/mcp/manager.ts create mode 100644 packages/api/src/mcp/mcpConfig.ts create mode 100644 packages/api/src/mcp/oauth/detectOAuth.ts diff --git a/.env.example b/.env.example index d0435c746..819b0dfab 100644 --- a/.env.example +++ b/.env.example @@ -698,3 +698,16 @@ OPENWEATHER_API_KEY= # JINA_API_KEY=your_jina_api_key # or # COHERE_API_KEY=your_cohere_api_key + +#======================# +# MCP Configuration # +#======================# + +# Treat 401/403 responses as OAuth requirement when no oauth metadata found +# MCP_OAUTH_ON_AUTH_ERROR=true + +# Timeout for OAuth detection requests in milliseconds +# MCP_OAUTH_DETECTION_TIMEOUT=5000 + +# Cache connection status checks for this many milliseconds to avoid expensive verification +# MCP_CONNECTION_CHECK_TTL=60000 diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 207aa17e6..ad0a75ab9 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -147,7 +147,7 @@ Apply the following naming conventions to branches, labels, and other Git-relate ## 8. Module Import Conventions - `npm` packages first, - - from shortest line (top) to longest (bottom) + - from longest line (top) to shortest (bottom) - Followed by typescript types (pertains to data-provider and client workspaces) - longest line (top) to shortest (bottom) @@ -157,6 +157,8 @@ Apply the following naming conventions to branches, labels, and other Git-relate - longest line (top) to shortest (bottom) - imports with alias `~` treated the same as relative import with respect to line length +**Note:** ESLint will automatically enforce these import conventions when you run `npm run lint --fix` or through pre-commit hooks. + --- Please ensure that you adapt this summary to fit the specific context and nuances of your project. diff --git a/.gitignore b/.gitignore index 38b9bbc9e..079690550 100644 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,4 @@ helm/**/.values.yaml /.openai/ /.tabnine/ /.codeium +*.local.md diff --git a/api/config/index.js b/api/config/index.js index 2e69e8711..2ffcf1cdf 100644 --- a/api/config/index.js +++ b/api/config/index.js @@ -1,27 +1,13 @@ +const { MCPManager, FlowStateManager } = require('@librechat/api'); const { EventSource } = require('eventsource'); const { Time } = require('librechat-data-provider'); -const { MCPManager, FlowStateManager } = require('@librechat/api'); const logger = require('./winston'); global.EventSource = EventSource; /** @type {MCPManager} */ -let mcpManager = null; let flowManager = null; -/** - * @param {string} [userId] - Optional user ID, to avoid disconnecting the current user. - * @returns {MCPManager} - */ -function getMCPManager(userId) { - if (!mcpManager) { - mcpManager = MCPManager.getInstance(); - } else { - mcpManager.checkIdleConnections(userId); - } - return mcpManager; -} - /** * @param {Keyv} flowsCache * @returns {FlowStateManager} @@ -37,6 +23,7 @@ function getFlowStateManager(flowsCache) { module.exports = { logger, - getMCPManager, + createMCPManager: MCPManager.createInstance, + getMCPManager: MCPManager.getInstance, getFlowStateManager, }; diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index 272b9f723..243711d76 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -1,7 +1,7 @@ -const express = require('express'); +const { MongoMemoryServer } = require('mongodb-memory-server'); const request = require('supertest'); const mongoose = require('mongoose'); -const { MongoMemoryServer } = require('mongodb-memory-server'); +const express = require('express'); jest.mock('@librechat/api', () => ({ MCPOAuthHandler: { @@ -494,12 +494,9 @@ describe('MCP Routes', () => { }); it('should return 500 when token retrieval throws an unexpected error', async () => { - const mockFlowManager = { - getFlowState: jest.fn().mockRejectedValue(new Error('Database connection failed')), - }; - - getLogStores.mockReturnValue({}); - require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + getLogStores.mockImplementation(() => { + throw new Error('Database connection failed'); + }); const response = await request(app).get('/api/mcp/oauth/tokens/test-user-id:error-flow'); @@ -563,8 +560,8 @@ describe('MCP Routes', () => { }); describe('POST /oauth/cancel/:serverName', () => { - const { getLogStores } = require('~/cache'); const { MCPOAuthHandler } = require('@librechat/api'); + const { getLogStores } = require('~/cache'); it('should cancel OAuth flow successfully', async () => { const mockFlowManager = { @@ -644,15 +641,15 @@ describe('MCP Routes', () => { }); describe('POST /:serverName/reinitialize', () => { - const { loadCustomConfig } = require('~/server/services/Config'); - const { getUserPluginAuthValue } = require('~/server/services/PluginService'); - it('should return 404 when server is not found in configuration', async () => { - loadCustomConfig.mockResolvedValue({ - mcpServers: { - 'other-server': {}, - }, - }); + const mockMcpManager = { + getRawConfig: jest.fn().mockReturnValue(null), + disconnectUserConnection: jest.fn().mockResolvedValue(), + }; + + require('~/config').getMCPManager.mockReturnValue(mockMcpManager); + require('~/config').getFlowStateManager.mockReturnValue({}); + require('~/cache').getLogStores.mockReturnValue({}); const response = await request(app).post('/api/mcp/non-existent-server/reinitialize'); @@ -663,16 +660,11 @@ describe('MCP Routes', () => { }); it('should handle OAuth requirement during reinitialize', async () => { - loadCustomConfig.mockResolvedValue({ - mcpServers: { - 'oauth-server': { - customUserVars: {}, - }, - }, - }); - const mockMcpManager = { - disconnectServer: jest.fn().mockResolvedValue(), + getRawConfig: jest.fn().mockReturnValue({ + customUserVars: {}, + }), + disconnectUserConnection: jest.fn().mockResolvedValue(), mcpConfigs: {}, getUserConnection: jest.fn().mockImplementation(async ({ oauthStart }) => { if (oauthStart) { @@ -690,7 +682,7 @@ describe('MCP Routes', () => { expect(response.status).toBe(200); expect(response.body).toEqual({ - success: 'https://oauth.example.com/auth', + success: true, message: "MCP server 'oauth-server' ready for OAuth authentication", serverName: 'oauth-server', oauthRequired: true, @@ -699,14 +691,9 @@ describe('MCP Routes', () => { }); it('should return 500 when reinitialize fails with non-OAuth error', async () => { - loadCustomConfig.mockResolvedValue({ - mcpServers: { - 'error-server': {}, - }, - }); - const mockMcpManager = { - disconnectServer: jest.fn().mockResolvedValue(), + getRawConfig: jest.fn().mockReturnValue({}), + disconnectUserConnection: jest.fn().mockResolvedValue(), mcpConfigs: {}, getUserConnection: jest.fn().mockRejectedValue(new Error('Connection failed')), }; @@ -724,7 +711,13 @@ describe('MCP Routes', () => { }); it('should return 500 when unexpected error occurs', async () => { - loadCustomConfig.mockRejectedValue(new Error('Config loading failed')); + const mockMcpManager = { + getRawConfig: jest.fn().mockImplementation(() => { + throw new Error('Config loading failed'); + }), + }; + + require('~/config').getMCPManager.mockReturnValue(mockMcpManager); const response = await request(app).post('/api/mcp/test-server/reinitialize'); @@ -747,29 +740,17 @@ describe('MCP Routes', () => { expect(response.body).toEqual({ error: 'User not authenticated' }); }); - it('should handle errors when fetching custom user variables', async () => { - loadCustomConfig.mockResolvedValue({ - mcpServers: { - 'test-server': { - customUserVars: { - API_KEY: 'test-key-var', - SECRET_TOKEN: 'test-secret-var', - }, - }, - }, - }); - - getUserPluginAuthValue - .mockResolvedValueOnce('test-api-key-value') - .mockRejectedValueOnce(new Error('Database error')); - + it('should successfully reinitialize server and cache tools', async () => { const mockUserConnection = { - fetchTools: jest.fn().mockResolvedValue([]), + fetchTools: jest.fn().mockResolvedValue([ + { name: 'tool1', description: 'Test tool 1', inputSchema: { type: 'object' } }, + { name: 'tool2', description: 'Test tool 2', inputSchema: { type: 'object' } }, + ]), }; const mockMcpManager = { - disconnectServer: jest.fn().mockResolvedValue(), - mcpConfigs: {}, + getRawConfig: jest.fn().mockReturnValue({ endpoint: 'http://test-server.com' }), + disconnectUserConnection: jest.fn().mockResolvedValue(), getUserConnection: jest.fn().mockResolvedValue(mockUserConnection), }; @@ -784,38 +765,54 @@ describe('MCP Routes', () => { const response = await request(app).post('/api/mcp/test-server/reinitialize'); expect(response.status).toBe(200); - expect(response.body.success).toBe(true); + expect(response.body).toEqual({ + success: true, + message: "MCP server 'test-server' reinitialized successfully", + serverName: 'test-server', + oauthRequired: false, + oauthUrl: null, + }); + expect(mockMcpManager.disconnectUserConnection).toHaveBeenCalledWith( + 'test-user-id', + 'test-server', + ); + expect(setCachedTools).toHaveBeenCalled(); }); - it('should return failure message when reinitialize completely fails', async () => { - loadCustomConfig.mockResolvedValue({ - mcpServers: { - 'test-server': {}, - }, - }); + it('should handle server with custom user variables', async () => { + const mockUserConnection = { + fetchTools: jest.fn().mockResolvedValue([]), + }; const mockMcpManager = { - disconnectServer: jest.fn().mockResolvedValue(), - mcpConfigs: {}, - getUserConnection: jest.fn().mockResolvedValue(null), + getRawConfig: jest.fn().mockReturnValue({ + endpoint: 'http://test-server.com', + customUserVars: { + API_KEY: 'some-env-var', + }, + }), + disconnectUserConnection: jest.fn().mockResolvedValue(), + getUserConnection: jest.fn().mockResolvedValue(mockUserConnection), }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getFlowStateManager.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({}); + require('~/server/services/PluginService').getUserPluginAuthValue.mockResolvedValue( + 'api-key-value', + ); const { getCachedTools, setCachedTools } = require('~/server/services/Config'); - const { Constants } = require('librechat-data-provider'); - getCachedTools.mockResolvedValue({ - [`existing-tool${Constants.mcp_delimiter}test-server`]: { type: 'function' }, - }); + getCachedTools.mockResolvedValue({}); setCachedTools.mockResolvedValue(); const response = await request(app).post('/api/mcp/test-server/reinitialize'); expect(response.status).toBe(200); - expect(response.body.success).toBe(false); - expect(response.body.message).toBe("Failed to reinitialize MCP server 'test-server'"); + expect(response.body.success).toBe(true); + expect( + require('~/server/services/PluginService').getUserPluginAuthValue, + ).toHaveBeenCalledWith('test-user-id', 'API_KEY', false); }); }); @@ -984,21 +981,19 @@ describe('MCP Routes', () => { }); describe('GET /:serverName/auth-values', () => { - const { loadCustomConfig } = require('~/server/services/Config'); const { getUserPluginAuthValue } = require('~/server/services/PluginService'); it('should return auth value flags for server', async () => { - loadCustomConfig.mockResolvedValue({ - mcpServers: { - 'test-server': { - customUserVars: { - API_KEY: 'some-env-var', - SECRET_TOKEN: 'another-env-var', - }, + const mockMcpManager = { + getRawConfig: jest.fn().mockReturnValue({ + customUserVars: { + API_KEY: 'some-env-var', + SECRET_TOKEN: 'another-env-var', }, - }, - }); + }), + }; + require('~/config').getMCPManager.mockReturnValue(mockMcpManager); getUserPluginAuthValue.mockResolvedValueOnce('some-api-key-value').mockResolvedValueOnce(''); const response = await request(app).get('/api/mcp/test-server/auth-values'); @@ -1017,11 +1012,11 @@ describe('MCP Routes', () => { }); it('should return 404 when server is not found in configuration', async () => { - loadCustomConfig.mockResolvedValue({ - mcpServers: { - 'other-server': {}, - }, - }); + const mockMcpManager = { + getRawConfig: jest.fn().mockReturnValue(null), + }; + + require('~/config').getMCPManager.mockReturnValue(mockMcpManager); const response = await request(app).get('/api/mcp/non-existent-server/auth-values'); @@ -1032,16 +1027,15 @@ describe('MCP Routes', () => { }); it('should handle errors when checking auth values', async () => { - loadCustomConfig.mockResolvedValue({ - mcpServers: { - 'test-server': { - customUserVars: { - API_KEY: 'some-env-var', - }, + const mockMcpManager = { + getRawConfig: jest.fn().mockReturnValue({ + customUserVars: { + API_KEY: 'some-env-var', }, - }, - }); + }), + }; + require('~/config').getMCPManager.mockReturnValue(mockMcpManager); getUserPluginAuthValue.mockRejectedValue(new Error('Database error')); const response = await request(app).get('/api/mcp/test-server/auth-values'); @@ -1057,7 +1051,13 @@ describe('MCP Routes', () => { }); it('should return 500 when auth values check throws unexpected error', async () => { - loadCustomConfig.mockRejectedValue(new Error('Config loading failed')); + const mockMcpManager = { + getRawConfig: jest.fn().mockImplementation(() => { + throw new Error('Config loading failed'); + }), + }; + + require('~/config').getMCPManager.mockReturnValue(mockMcpManager); const response = await request(app).get('/api/mcp/test-server/auth-values'); @@ -1066,14 +1066,13 @@ describe('MCP Routes', () => { }); it('should handle customUserVars that is not an object', async () => { - const { loadCustomConfig } = require('~/server/services/Config'); - loadCustomConfig.mockResolvedValue({ - mcpServers: { - 'test-server': { - customUserVars: 'not-an-object', - }, - }, - }); + const mockMcpManager = { + getRawConfig: jest.fn().mockReturnValue({ + customUserVars: 'not-an-object', + }), + }; + + require('~/config').getMCPManager.mockReturnValue(mockMcpManager); const response = await request(app).get('/api/mcp/test-server/auth-values'); @@ -1097,98 +1096,6 @@ describe('MCP Routes', () => { }); }); - describe('POST /:serverName/reinitialize - Tool Deletion Coverage', () => { - it('should handle null cached tools during reinitialize (triggers || {} fallback)', async () => { - const { loadCustomConfig, getCachedTools } = require('~/server/services/Config'); - - const mockUserConnection = { - fetchTools: jest.fn().mockResolvedValue([{ name: 'new-tool', description: 'A new tool' }]), - }; - - const mockMcpManager = { - getUserConnection: jest.fn().mockResolvedValue(mockUserConnection), - disconnectServer: jest.fn(), - initializeServer: jest.fn(), - mcpConfigs: {}, - }; - require('~/config').getMCPManager.mockReturnValue(mockMcpManager); - - loadCustomConfig.mockResolvedValue({ - mcpServers: { - 'test-server': { env: { API_KEY: 'test-key' } }, - }, - }); - - getCachedTools.mockResolvedValue(null); - - const response = await request(app).post('/api/mcp/test-server/reinitialize').expect(200); - - expect(response.body).toEqual({ - message: "MCP server 'test-server' reinitialized successfully", - success: true, - oauthRequired: false, - oauthUrl: null, - serverName: 'test-server', - }); - }); - - it('should delete existing cached tools during successful reinitialize', async () => { - const { - loadCustomConfig, - getCachedTools, - setCachedTools, - } = require('~/server/services/Config'); - - const mockUserConnection = { - fetchTools: jest.fn().mockResolvedValue([{ name: 'new-tool', description: 'A new tool' }]), - }; - - const mockMcpManager = { - getUserConnection: jest.fn().mockResolvedValue(mockUserConnection), - disconnectServer: jest.fn(), - initializeServer: jest.fn(), - mcpConfigs: {}, - }; - require('~/config').getMCPManager.mockReturnValue(mockMcpManager); - - loadCustomConfig.mockResolvedValue({ - mcpServers: { - 'test-server': { env: { API_KEY: 'test-key' } }, - }, - }); - - const existingTools = { - 'old-tool_mcp_test-server': { type: 'function' }, - 'other-tool_mcp_other-server': { type: 'function' }, - }; - getCachedTools.mockResolvedValue(existingTools); - - const response = await request(app).post('/api/mcp/test-server/reinitialize').expect(200); - - expect(response.body).toEqual({ - message: "MCP server 'test-server' reinitialized successfully", - success: true, - oauthRequired: false, - oauthUrl: null, - serverName: 'test-server', - }); - - expect(setCachedTools).toHaveBeenCalledWith( - expect.objectContaining({ - 'new-tool_mcp_test-server': expect.any(Object), - 'other-tool_mcp_other-server': { type: 'function' }, - }), - { userId: 'test-user-id' }, - ); - expect(setCachedTools).toHaveBeenCalledWith( - expect.not.objectContaining({ - 'old-tool_mcp_test-server': expect.anything(), - }), - { userId: 'test-user-id' }, - ); - }); - }); - describe('GET /:serverName/oauth/callback - Edge Cases', () => { it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => { const { MCPOAuthHandler } = require('@librechat/api'); diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index a725cf666..5931ca02f 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -1,11 +1,11 @@ -const { Router } = require('express'); const { logger } = require('@librechat/data-schemas'); const { MCPOAuthHandler } = require('@librechat/api'); -const { CacheKeys, Constants } = require('librechat-data-provider'); -const { findToken, updateToken, createToken, deleteTokens } = require('~/models'); -const { setCachedTools, getCachedTools, loadCustomConfig } = require('~/server/services/Config'); +const { Router } = require('express'); const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP'); +const { findToken, updateToken, createToken, deleteTokens } = require('~/models'); +const { setCachedTools, getCachedTools } = require('~/server/services/Config'); const { getUserPluginAuthValue } = require('~/server/services/PluginService'); +const { CacheKeys, Constants } = require('librechat-data-provider'); const { getMCPManager, getFlowStateManager } = require('~/config'); const { requireJwtAuth } = require('~/server/middleware'); const { getLogStores } = require('~/cache'); @@ -315,9 +315,9 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`); - const printConfig = false; - const config = await loadCustomConfig(printConfig); - if (!config || !config.mcpServers || !config.mcpServers[serverName]) { + const mcpManager = getMCPManager(); + const serverConfig = mcpManager.getRawConfig(serverName); + if (!serverConfig) { return res.status(404).json({ error: `MCP server '${serverName}' not found in configuration`, }); @@ -325,13 +325,12 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { const flowsCache = getLogStores(CacheKeys.FLOWS); const flowManager = getFlowStateManager(flowsCache); - const mcpManager = getMCPManager(); - await mcpManager.disconnectServer(serverName); - logger.info(`[MCP Reinitialize] Disconnected existing server: ${serverName}`); + await mcpManager.disconnectUserConnection(user.id, serverName); + logger.info( + `[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`, + ); - const serverConfig = config.mcpServers[serverName]; - mcpManager.mcpConfigs[serverName] = serverConfig; let customUserVars = {}; if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') { for (const varName of Object.keys(serverConfig.customUserVars)) { @@ -437,7 +436,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { }; res.json({ - success: (userConnection && !oauthRequired) || (oauthRequired && oauthUrl), + success: Boolean((userConnection && !oauthRequired) || (oauthRequired && oauthUrl)), message: getResponseMessage(), serverName, oauthRequired, @@ -551,15 +550,14 @@ router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => { return res.status(401).json({ error: 'User not authenticated' }); } - const printConfig = false; - const config = await loadCustomConfig(printConfig); - if (!config || !config.mcpServers || !config.mcpServers[serverName]) { + const mcpManager = getMCPManager(); + const serverConfig = mcpManager.getRawConfig(serverName); + if (!serverConfig) { return res.status(404).json({ error: `MCP server '${serverName}' not found in configuration`, }); } - const serverConfig = config.mcpServers[serverName]; const pluginKey = `${Constants.mcp_prefix}${serverName}`; const authValueFlags = {}; diff --git a/api/server/services/initializeMCPs.js b/api/server/services/initializeMCPs.js index 18edb2449..40c75e1b0 100644 --- a/api/server/services/initializeMCPs.js +++ b/api/server/services/initializeMCPs.js @@ -1,8 +1,7 @@ const { logger } = require('@librechat/data-schemas'); -const { CacheKeys } = require('librechat-data-provider'); -const { findToken, updateToken, createToken, deleteTokens } = require('~/models'); -const { getMCPManager, getFlowStateManager } = require('~/config'); const { getCachedTools, setCachedTools } = require('./Config'); +const { CacheKeys } = require('librechat-data-provider'); +const { createMCPManager } = require('~/config'); const { getLogStores } = require('~/cache'); /** @@ -31,33 +30,19 @@ async function initializeMCPs(app) { } logger.info('Initializing MCP servers...'); - const mcpManager = getMCPManager(); - const flowsCache = getLogStores(CacheKeys.FLOWS); - const flowManager = flowsCache ? getFlowStateManager(flowsCache) : null; + const mcpManager = await createMCPManager(mcpServers); try { - await mcpManager.initializeMCPs({ - mcpServers: filteredServers, - flowManager, - tokenMethods: { - findToken, - updateToken, - createToken, - deleteTokens, - }, - }); - delete app.locals.mcpConfig; - const availableTools = await getCachedTools(); + const cachedTools = await getCachedTools(); - if (!availableTools) { + if (!cachedTools) { logger.warn('No available tools found in cache during MCP initialization'); return; } - const toolsCopy = { ...availableTools }; - await mcpManager.mapAvailableTools(toolsCopy, flowManager); - await setCachedTools(toolsCopy, { isGlobal: true }); + const mcpTools = mcpManager.getAppToolFunctions(); + await setCachedTools({ ...cachedTools, ...mcpTools }, { isGlobal: true }); const cache = getLogStores(CacheKeys.CONFIG_STORE); await cache.delete(CacheKeys.TOOLS); diff --git a/eslint.config.mjs b/eslint.config.mjs index b60230304..f53e8cc69 100644 --- a/eslint.config.mjs +++ b/eslint.config.mjs @@ -2,11 +2,11 @@ import { fileURLToPath } from 'node:url'; import path from 'node:path'; import typescriptEslintEslintPlugin from '@typescript-eslint/eslint-plugin'; import { fixupConfigRules, fixupPluginRules } from '@eslint/compat'; -// import perfectionist from 'eslint-plugin-perfectionist'; +import perfectionist from 'eslint-plugin-perfectionist'; import reactHooks from 'eslint-plugin-react-hooks'; -import prettier from 'eslint-plugin-prettier'; import tsParser from '@typescript-eslint/parser'; import importPlugin from 'eslint-plugin-import'; +import prettier from 'eslint-plugin-prettier'; import { FlatCompat } from '@eslint/eslintrc'; import jsxA11Y from 'eslint-plugin-jsx-a11y'; import i18next from 'eslint-plugin-i18next'; @@ -62,7 +62,7 @@ export default [ 'jsx-a11y': fixupPluginRules(jsxA11Y), 'import/parsers': tsParser, i18next, - // perfectionist, + perfectionist, prettier: fixupPluginRules(prettier), }, @@ -140,32 +140,31 @@ export default [ 'react/prop-types': 'off', 'react/display-name': 'off', - // 'perfectionist/sort-imports': [ - // 'error', - // { - // type: 'line-length', - // order: 'desc', - // newlinesBetween: 'never', - // customGroups: { - // value: { - // react: ['^react$'], - // // react: ['^react$', '^fs', '^zod', '^path'], - // local: ['^(\\.{1,2}|~)/', '^librechat-data-provider'], - // }, - // }, - // groups: [ - // 'react', - // 'builtin', - // 'external', - // ['builtin-type', 'external-type'], - // ['internal-type'], - // 'local', - // ['parent', 'sibling', 'index'], - // 'object', - // 'unknown', - // ], - // }, - // ], + 'perfectionist/sort-imports': [ + 'error', + { + type: 'line-length', + order: 'desc', + newlinesBetween: 'never', + customGroups: { + value: { + react: ['^react$'], + local: ['^(\\.{1,2}|~)/', '^librechat-data-provider'], + }, + }, + groups: [ + 'react', + 'builtin', + 'external', + ['builtin-type', 'external-type'], + ['internal-type'], + 'local', + ['parent', 'sibling', 'index'], + 'object', + 'unknown', + ], + }, + ], // 'perfectionist/sort-named-imports': [ // 'error', diff --git a/packages/api/jest.config.mjs b/packages/api/jest.config.mjs index eb6be102d..1533a3d21 100644 --- a/packages/api/jest.config.mjs +++ b/packages/api/jest.config.mjs @@ -1,7 +1,7 @@ export default { collectCoverageFrom: ['src/**/*.{js,jsx,ts,tsx}', '!/node_modules/'], coveragePathIgnorePatterns: ['/node_modules/', '/dist/'], - testPathIgnorePatterns: ['/node_modules/', '/dist/'], + testPathIgnorePatterns: ['/node_modules/', '/dist/', '\\.dev\\.ts$'], coverageReporters: ['text', 'cobertura'], testResultsProcessor: 'jest-junit', moduleNameMapper: { diff --git a/packages/api/src/index.ts b/packages/api/src/index.ts index 21f16734c..fb8e63a00 100644 --- a/packages/api/src/index.ts +++ b/packages/api/src/index.ts @@ -1,5 +1,5 @@ /* MCP */ -export * from './mcp/manager'; +export * from './mcp/MCPManager'; export * from './mcp/oauth'; export * from './mcp/auth'; export * from './mcp/zod'; diff --git a/packages/api/src/mcp/ConnectionsRepository.ts b/packages/api/src/mcp/ConnectionsRepository.ts new file mode 100644 index 000000000..7bfb95ad2 --- /dev/null +++ b/packages/api/src/mcp/ConnectionsRepository.ts @@ -0,0 +1,87 @@ +import { logger } from '@librechat/data-schemas'; +import { MCPConnectionFactory, OAuthConnectionOptions } from '~/mcp/MCPConnectionFactory'; +import { MCPConnection } from './connection'; +import type * as t from './types'; + +/** + * Manages MCP connections with lazy loading and reconnection. + * Maintains a pool of connections and handles connection lifecycle management. + */ +export class ConnectionsRepository { + protected readonly serverConfigs: Record; + protected connections: Map = new Map(); + protected oauthOpts: OAuthConnectionOptions | undefined; + + constructor(serverConfigs: t.MCPServers, oauthOpts?: OAuthConnectionOptions) { + this.serverConfigs = serverConfigs; + this.oauthOpts = oauthOpts; + } + + /** Checks whether this repository can connect to a specific server */ + has(serverName: string): boolean { + return !!this.serverConfigs[serverName]; + } + + /** Gets or creates a connection for the specified server with lazy loading */ + async get(serverName: string): Promise { + const existingConnection = this.connections.get(serverName); + if (existingConnection && (await existingConnection.isConnected())) return existingConnection; + else await this.disconnect(serverName); + + const connection = await MCPConnectionFactory.create( + { + serverName, + serverConfig: this.getServerConfig(serverName), + }, + this.oauthOpts, + ); + + this.connections.set(serverName, connection); + return connection; + } + + /** Gets or creates connections for multiple servers concurrently */ + async getMany(serverNames: string[]): Promise> { + const connectionPromises = serverNames.map(async (name) => [name, await this.get(name)]); + const connections = await Promise.all(connectionPromises); + return new Map(connections as [string, MCPConnection][]); + } + + /** Returns all currently loaded connections without creating new ones */ + async getLoaded(): Promise> { + return this.getMany(Array.from(this.connections.keys())); + } + + /** Gets or creates connections for all configured servers */ + async getAll(): Promise> { + return this.getMany(Object.keys(this.serverConfigs)); + } + + /** Disconnects and removes a specific server connection from the pool */ + disconnect(serverName: string): Promise { + const connection = this.connections.get(serverName); + if (!connection) return Promise.resolve(); + this.connections.delete(serverName); + return connection.disconnect().catch((err) => { + logger.error(`${this.prefix(serverName)} Error disconnecting`, err); + }); + } + + /** Disconnects all active connections and returns array of disconnect promises */ + disconnectAll(): Promise[] { + const serverNames = Array.from(this.connections.keys()); + return serverNames.map((serverName) => this.disconnect(serverName)); + } + + // Retrieves server configuration by name or throws if not found + protected getServerConfig(serverName: string): t.MCPOptions { + const serverConfig = this.serverConfigs[serverName]; + if (serverConfig) return serverConfig; + throw new Error(`${this.prefix(serverName)} Server not found in configuration`); + } + + // Returns formatted log prefix for server messages + protected prefix(serverName: string): string { + return `[MCP][${serverName}]`; + } +} diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts new file mode 100644 index 000000000..f657f0255 --- /dev/null +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -0,0 +1,384 @@ +import { logger } from '@librechat/data-schemas'; +import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js'; +import type { TokenMethods } from '@librechat/data-schemas'; +import type { TUser } from 'librechat-data-provider'; +import type { MCPOAuthTokens, MCPOAuthFlowMetadata } from '~/mcp/oauth'; +import type { FlowStateManager } from '~/flow/manager'; +import type { FlowMetadata } from '~/flow/types'; +import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth'; +import { MCPConnection } from './connection'; +import { processMCPEnv } from '~/utils'; +import type * as t from './types'; + +export interface BasicConnectionOptions { + serverName: string; + serverConfig: t.MCPOptions; +} + +export interface OAuthConnectionOptions { + useOAuth: true; + user: TUser; + customUserVars?: Record; + flowManager: FlowStateManager; + tokenMethods?: TokenMethods; + signal?: AbortSignal; + oauthStart?: (authURL: string) => Promise; + oauthEnd?: () => Promise; + returnOnOAuth?: boolean; +} + +/** + * Factory for creating MCP connections with optional OAuth authentication. + * Handles OAuth flows, token management, and connection retry logic. + * NOTE: Much of the OAuth logic was extracted from the old MCPManager class as is. + */ +export class MCPConnectionFactory { + protected readonly serverName: string; + protected readonly serverConfig: t.MCPOptions; + protected readonly logPrefix: string; + protected readonly useOAuth: boolean; + + // OAuth-related properties (only set when useOAuth is true) + protected readonly userId?: string; + protected readonly flowManager?: FlowStateManager; + protected readonly tokenMethods?: TokenMethods; + protected readonly signal?: AbortSignal; + protected readonly oauthStart?: (authURL: string) => Promise; + protected readonly oauthEnd?: () => Promise; + protected readonly returnOnOAuth?: boolean; + + /** Creates a new MCP connection with optional OAuth support */ + static async create( + basic: BasicConnectionOptions, + oauth?: OAuthConnectionOptions, + ): Promise { + const factory = new this(basic, oauth); + return factory.createConnection(); + } + + protected constructor(basic: BasicConnectionOptions, oauth?: OAuthConnectionOptions) { + this.serverConfig = processMCPEnv(basic.serverConfig, oauth?.user, oauth?.customUserVars); + this.serverName = basic.serverName; + this.useOAuth = !!oauth?.useOAuth; + this.logPrefix = oauth?.user + ? `[MCP][${basic.serverName}][${oauth.user.id}]` + : `[MCP][${basic.serverName}]`; + + if (oauth?.useOAuth) { + this.userId = oauth.user.id; + this.flowManager = oauth.flowManager; + this.tokenMethods = oauth.tokenMethods; + this.signal = oauth.signal; + this.oauthStart = oauth.oauthStart; + this.oauthEnd = oauth.oauthEnd; + this.returnOnOAuth = oauth.returnOnOAuth; + } + } + + /** Creates the base MCP connection with OAuth tokens */ + protected async createConnection(): Promise { + const oauthTokens = this.useOAuth ? await this.getOAuthTokens() : null; + const connection = new MCPConnection({ + serverName: this.serverName, + serverConfig: this.serverConfig, + userId: this.userId, + oauthTokens, + }); + + if (this.useOAuth) this.handleOAuthEvents(connection); + await this.attemptToConnect(connection); + return connection; + } + + /** Retrieves existing OAuth tokens from storage or returns null */ + protected async getOAuthTokens(): Promise { + if (!this.tokenMethods?.findToken) return null; + + try { + const tokens = await this.flowManager!.createFlowWithHandler( + `tokens:${this.userId}:${this.serverName}`, + 'mcp_get_tokens', + async () => { + return await MCPTokenStorage.getTokens({ + userId: this.userId!, + serverName: this.serverName, + findToken: this.tokenMethods!.findToken!, + createToken: this.tokenMethods!.createToken, + updateToken: this.tokenMethods!.updateToken, + refreshTokens: this.createRefreshTokensFunction(), + }); + }, + this.signal, + ); + + if (tokens) logger.info(`${this.logPrefix} Loaded OAuth tokens`); + return tokens; + } catch (error) { + logger.debug(`${this.logPrefix} No existing tokens found or error loading tokens`, error); + return null; + } + } + + /** Creates a function to refresh OAuth tokens when they expire */ + protected createRefreshTokensFunction(): ( + refreshToken: string, + metadata: { + userId: string; + serverName: string; + identifier: string; + clientInfo?: OAuthClientInformation; + }, + ) => Promise { + return async (refreshToken, metadata) => { + return await MCPOAuthHandler.refreshOAuthTokens( + refreshToken, + { + serverUrl: (this.serverConfig as t.SSEOptions | t.StreamableHTTPOptions).url, + serverName: metadata.serverName, + clientInfo: metadata.clientInfo, + }, + this.serverConfig.oauth, + ); + }; + } + + /** Sets up OAuth event handlers for the connection */ + protected handleOAuthEvents(connection: MCPConnection): void { + connection.on('oauthRequired', async (data) => { + logger.info(`${this.logPrefix} oauthRequired event received`); + + // If we just want to initiate OAuth and return, handle it differently + if (this.returnOnOAuth) { + try { + const config = this.serverConfig; + const { authorizationUrl, flowId, flowMetadata } = + await MCPOAuthHandler.initiateOAuthFlow( + this.serverName, + data.serverUrl || '', + this.userId!, + config?.oauth, + ); + + // Create the flow state so the OAuth callback can find it + // We spawn this in the background without waiting for it + this.flowManager!.createFlow(flowId, 'mcp_oauth', flowMetadata).catch(() => { + // The OAuth callback will resolve this flow, so we expect it to timeout here + // which is fine - we just need the flow state to exist + }); + + if (this.oauthStart) { + logger.info(`${this.logPrefix} OAuth flow started, issuing authorization URL`); + await this.oauthStart(authorizationUrl); + } + + // Emit oauthFailed to signal that connection should not proceed + // but OAuth was successfully initiated + connection.emit('oauthFailed', new Error('OAuth flow initiated - return early')); + return; + } catch (error) { + logger.error(`${this.logPrefix} Failed to initiate OAuth flow`, error); + connection.emit('oauthFailed', new Error('OAuth initiation failed')); + return; + } + } + + // Normal OAuth handling - wait for completion + const result = await this.handleOAuthRequired(); + + if (result?.tokens && this.tokenMethods?.createToken) { + try { + connection.setOAuthTokens(result.tokens); + await MCPTokenStorage.storeTokens({ + userId: this.userId!, + serverName: this.serverName, + tokens: result.tokens, + createToken: this.tokenMethods.createToken, + updateToken: this.tokenMethods.updateToken, + findToken: this.tokenMethods.findToken, + clientInfo: result.clientInfo, + }); + logger.info(`${this.logPrefix} OAuth tokens saved to storage`); + } catch (error) { + logger.error(`${this.logPrefix} Failed to save OAuth tokens to storage`, error); + } + } + + // Only emit oauthHandled if we actually got tokens (OAuth succeeded) + if (result?.tokens) { + connection.emit('oauthHandled'); + } else { + // OAuth failed, emit oauthFailed to properly reject the promise + logger.warn(`${this.logPrefix} OAuth failed, emitting oauthFailed event`); + connection.emit('oauthFailed', new Error('OAuth authentication failed')); + } + }); + } + + /** Attempts to establish connection with timeout handling */ + protected async attemptToConnect(connection: MCPConnection): Promise { + const connectTimeout = this.serverConfig.initTimeout ?? 30000; + const connectionTimeout = new Promise((_, reject) => + setTimeout( + () => reject(new Error(`Connection timeout after ${connectTimeout}ms`)), + connectTimeout, + ), + ); + const connectionAttempt = this.connectTo(connection); + await Promise.race([connectionAttempt, connectionTimeout]); + + if (await connection.isConnected()) return; + logger.error(`${this.logPrefix} Failed to establish connection.`); + } + + // Handles connection attempts with retry logic and OAuth error handling + private async connectTo(connection: MCPConnection): Promise { + const maxAttempts = 3; + let attempts = 0; + let oauthHandled = false; + + while (attempts < maxAttempts) { + try { + await connection.connect(); + if (await connection.isConnected()) { + return; + } + throw new Error('Connection attempt succeeded but status is not connected'); + } catch (error) { + attempts++; + + if (this.useOAuth && this.isOAuthError(error)) { + // Only handle OAuth if this is a user connection (has oauthStart handler) + if (this.oauthStart && !oauthHandled) { + const errorWithFlag = error as (Error & { isOAuthError?: boolean }) | undefined; + if (errorWithFlag?.isOAuthError) { + oauthHandled = true; + logger.info(`${this.logPrefix} Handling OAuth`); + await this.handleOAuthRequired(); + } + } + // Don't retry on OAuth errors - just throw + logger.info(`${this.logPrefix} OAuth required, stopping connection attempts`); + throw error; + } + + if (attempts === maxAttempts) { + logger.error(`${this.logPrefix} Failed to connect after ${maxAttempts} attempts`, error); + throw error; + } + await new Promise((resolve) => setTimeout(resolve, 2000 * attempts)); + } + } + } + + // Determines if an error indicates OAuth authentication is required + private isOAuthError(error: unknown): boolean { + if (!error || typeof error !== 'object') { + return false; + } + + // Check for SSE error with 401 status + if ('message' in error && typeof error.message === 'string') { + return error.message.includes('401') || error.message.includes('Non-200 status code (401)'); + } + + // Check for error code + if ('code' in error) { + const code = (error as { code?: number }).code; + return code === 401 || code === 403; + } + + return false; + } + + /** Manages OAuth flow initiation and completion */ + protected async handleOAuthRequired(): Promise<{ + tokens: MCPOAuthTokens | null; + clientInfo?: OAuthClientInformation; + } | null> { + const serverUrl = (this.serverConfig as t.SSEOptions | t.StreamableHTTPOptions).url; + logger.debug(`${this.logPrefix} \`handleOAuthRequired\` called with serverUrl: ${serverUrl}`); + + if (!this.flowManager || !serverUrl) { + logger.error( + `${this.logPrefix} OAuth required but flow manager not available or server URL missing for ${this.serverName}`, + ); + logger.warn(`${this.logPrefix} Please configure OAuth credentials for ${this.serverName}`); + return null; + } + + try { + logger.debug(`${this.logPrefix} Checking for existing OAuth flow for ${this.serverName}...`); + + /** Flow ID to check if a flow already exists */ + const flowId = MCPOAuthHandler.generateFlowId(this.userId!, this.serverName); + + /** Check if there's already an ongoing OAuth flow for this flowId */ + const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth'); + if (existingFlow && existingFlow.status === 'PENDING') { + logger.debug( + `${this.logPrefix} OAuth flow already exists for ${flowId}, waiting for completion`, + ); + /** Tokens from existing flow to complete */ + const tokens = await this.flowManager.createFlow(flowId, 'mcp_oauth'); + if (typeof this.oauthEnd === 'function') { + await this.oauthEnd(); + } + logger.info( + `${this.logPrefix} OAuth flow completed, tokens received for ${this.serverName}`, + ); + + /** Client information from the existing flow metadata */ + const existingMetadata = existingFlow.metadata as unknown as MCPOAuthFlowMetadata; + const clientInfo = existingMetadata?.clientInfo; + + return { tokens, clientInfo }; + } + + logger.debug(`${this.logPrefix} Initiating new OAuth flow for ${this.serverName}...`); + const { + authorizationUrl, + flowId: newFlowId, + flowMetadata, + } = await MCPOAuthHandler.initiateOAuthFlow( + this.serverName, + serverUrl, + this.userId!, + this.serverConfig.oauth, + ); + + if (typeof this.oauthStart === 'function') { + logger.info(`${this.logPrefix} OAuth flow started, issued authorization URL to user`); + await this.oauthStart(authorizationUrl); + } else { + logger.info(` +═══════════════════════════════════════════════════════════════════════ +Please visit the following URL to authenticate: + +${authorizationUrl} + +${this.logPrefix} Flow ID: ${newFlowId} +═══════════════════════════════════════════════════════════════════════ +`); + } + + /** Tokens from the new flow */ + const tokens = await this.flowManager.createFlow( + newFlowId, + 'mcp_oauth', + flowMetadata as FlowMetadata, + ); + if (typeof this.oauthEnd === 'function') { + await this.oauthEnd(); + } + logger.info(`${this.logPrefix} OAuth flow completed, tokens received for ${this.serverName}`); + + /** Client information from the flow metadata */ + const clientInfo = flowMetadata?.clientInfo; + + return { tokens, clientInfo }; + } catch (error) { + logger.error(`${this.logPrefix} Failed to complete OAuth flow for ${this.serverName}`, error); + return null; + } + } +} diff --git a/packages/api/src/mcp/MCPManager.ts b/packages/api/src/mcp/MCPManager.ts new file mode 100644 index 000000000..9173ec58a --- /dev/null +++ b/packages/api/src/mcp/MCPManager.ts @@ -0,0 +1,263 @@ +import { CallToolResultSchema, ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js'; +import { logger } from '@librechat/data-schemas'; +import pick from 'lodash/pick'; +import type { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.js'; +import type { TokenMethods } from '@librechat/data-schemas'; +import type { TUser } from 'librechat-data-provider'; +import type { FlowStateManager } from '~/flow/manager'; +import type { MCPOAuthTokens } from '~/mcp/oauth'; +import { UserConnectionManager } from '~/mcp/UserConnectionManager'; +import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; +import { formatToolContent } from './parsers'; +import { MCPConnection } from './connection'; +import { CONSTANTS } from './enum'; +import type * as t from './types'; + +/** + * Centralized manager for MCP server connections and tool execution. + * Extends UserConnectionManager to handle both app-level and user-specific connections. + */ +export class MCPManager extends UserConnectionManager { + private static instance: MCPManager | null; + // Connections shared by all users. + private appConnections: ConnectionsRepository | null = null; + + /** Creates and initializes the singleton MCPManager instance */ + public static async createInstance(configs: t.MCPServers): Promise { + if (MCPManager.instance) throw new Error('MCPManager has already been initialized.'); + MCPManager.instance = new MCPManager(configs); + await MCPManager.instance.initialize(); + return MCPManager.instance; + } + + /** Returns the singleton MCPManager instance */ + public static getInstance(): MCPManager { + if (!MCPManager.instance) throw new Error('MCPManager has not been initialized.'); + return MCPManager.instance; + } + + /** Initializes the MCPManager by setting up server registry and app connections */ + public async initialize() { + await this.serversRegistry.initialize(); + this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs!); + } + + /** Returns all app-level connections */ + public async getAllConnections(): Promise> { + return this.appConnections!.getAll(); + } + + /** Get servers that require OAuth */ + public getOAuthServers(): Set { + return this.serversRegistry.oauthServers!; + } + + /** Returns all available tool functions from app-level connections */ + public getAppToolFunctions(): t.LCAvailableTools { + return this.serversRegistry.toolFunctions!; + } + + /** + * Get instructions for MCP servers + * @param serverNames Optional array of server names. If not provided or empty, returns all servers. + * @returns Object mapping server names to their instructions + */ + public getInstructions(serverNames?: string[]): Record { + const instructions = this.serversRegistry.serverInstructions!; + if (!serverNames) return instructions; + return pick(instructions, serverNames); + } + + /** + * Format MCP server instructions for injection into context + * @param serverNames Optional array of server names to include. If not provided, includes all servers. + * @returns Formatted instructions string ready for context injection + */ + public formatInstructionsForContext(serverNames?: string[]): string { + /** Instructions for specified servers or all stored instructions */ + const instructionsToInclude = this.getInstructions(serverNames); + + if (Object.keys(instructionsToInclude).length === 0) { + return ''; + } + + // Format instructions for context injection + const formattedInstructions = Object.entries(instructionsToInclude) + .map(([serverName, instructions]) => { + return `## ${serverName} MCP Server Instructions + +${instructions}`; + }) + .join('\n\n'); + + return `# MCP Server Instructions + +The following MCP servers are available with their specific instructions: + +${formattedInstructions} + +Please follow these instructions when using tools from the respective MCP servers.`; + } + + /** Loads tools from all app-level connections into the manifest. */ + public async loadManifestTools({ + serverToolsCallback, + getServerTools, + }: { + flowManager: FlowStateManager; + serverToolsCallback?: (serverName: string, tools: t.LCManifestTool[]) => Promise; + getServerTools?: (serverName: string) => Promise; + }): Promise { + const mcpTools: t.LCManifestTool[] = []; + const connections = await this.appConnections!.getAll(); + for (const [serverName, connection] of connections.entries()) { + try { + if (!(await connection.isConnected())) { + logger.warn( + `[MCP][${serverName}] Connection not available for ${serverName} manifest tools.`, + ); + if (typeof getServerTools !== 'function') { + logger.warn( + `[MCP][${serverName}] No \`getServerTools\` function provided, skipping tool loading.`, + ); + continue; + } + const serverTools = await getServerTools(serverName); + if (serverTools && serverTools.length > 0) { + logger.info(`[MCP][${serverName}] Loaded tools from cache for manifest`); + mcpTools.push(...serverTools); + } + continue; + } + + const tools = await connection.fetchTools(); + const serverTools: t.LCManifestTool[] = []; + for (const tool of tools) { + const pluginKey = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`; + + const config = this.serversRegistry.parsedConfigs[serverName]; + const manifestTool: t.LCManifestTool = { + name: tool.name, + pluginKey, + description: tool.description ?? '', + icon: connection.iconPath, + authConfig: config?.customUserVars + ? Object.entries(config.customUserVars).map(([key, value]) => ({ + authField: key, + label: value.title || key, + description: value.description || '', + })) + : undefined, + }; + if (config?.chatMenu === false) { + manifestTool.chatMenu = false; + } + mcpTools.push(manifestTool); + serverTools.push(manifestTool); + } + if (typeof serverToolsCallback === 'function') { + await serverToolsCallback(serverName, serverTools); + } + } catch (error) { + logger.error(`[MCP][${serverName}] Error fetching tools for manifest:`, error); + } + } + + return mcpTools; + } + + /** + * Calls a tool on an MCP server, using either a user-specific connection + * (if userId is provided) or an app-level connection. Updates the last activity timestamp + * for user-specific connections upon successful call initiation. + */ + async callTool({ + user, + serverName, + toolName, + provider, + toolArguments, + options, + tokenMethods, + flowManager, + oauthStart, + oauthEnd, + customUserVars, + }: { + user?: TUser; + serverName: string; + toolName: string; + provider: t.Provider; + toolArguments?: Record; + options?: RequestOptions; + tokenMethods?: TokenMethods; + customUserVars?: Record; + flowManager: FlowStateManager; + oauthStart?: (authURL: string) => Promise; + oauthEnd?: () => Promise; + }): Promise { + /** User-specific connection */ + let connection: MCPConnection | undefined; + const userId = user?.id; + const logPrefix = userId ? `[MCP][User: ${userId}][${serverName}]` : `[MCP][${serverName}]`; + + try { + if (!this.appConnections?.has(serverName) && userId && user) { + this.updateUserLastActivity(userId); + /** Get or create user-specific connection */ + connection = await this.getUserConnection({ + user, + serverName, + flowManager, + tokenMethods, + oauthStart, + oauthEnd, + signal: options?.signal, + customUserVars, + }); + } else { + /** App-level connection */ + connection = await this.appConnections!.get(serverName); + if (!connection) { + throw new McpError( + ErrorCode.InvalidRequest, + `${logPrefix} No app-level connection found. Cannot execute tool ${toolName}.`, + ); + } + } + + if (!(await connection.isConnected())) { + /** May happen if getUserConnection failed silently or app connection dropped */ + throw new McpError( + ErrorCode.InternalError, // Use InternalError for connection issues + `${logPrefix} Connection is not active. Cannot execute tool ${toolName}.`, + ); + } + + const result = await connection.client.request( + { + method: 'tools/call', + params: { + name: toolName, + arguments: toolArguments, + }, + }, + CallToolResultSchema, + { + timeout: connection.timeout, + ...options, + }, + ); + if (userId) { + this.updateUserLastActivity(userId); + } + this.checkIdleConnections(); + return formatToolContent(result as t.MCPToolCallResponse, provider); + } catch (error) { + // Log with context and re-throw or handle as needed + logger.error(`${logPrefix}[${toolName}] Tool call failed`, error); + // Rethrowing allows the caller (createMCPTool) to handle the final user message + throw error; + } + } +} diff --git a/packages/api/src/mcp/MCPServersRegistry.ts b/packages/api/src/mcp/MCPServersRegistry.ts new file mode 100644 index 000000000..b2ce2ed0a --- /dev/null +++ b/packages/api/src/mcp/MCPServersRegistry.ts @@ -0,0 +1,200 @@ +import { logger } from '@librechat/data-schemas'; +import mapValues from 'lodash/mapValues'; +import pickBy from 'lodash/pickBy'; +import pick from 'lodash/pick'; +import type { JsonSchemaType } from '~/types'; +import type * as t from '~/mcp/types'; +import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; +import { detectOAuthRequirement } from '~/mcp/oauth'; +import { type MCPConnection } from './connection'; +import { processMCPEnv } from '~/utils'; +import { CONSTANTS } from '~/mcp/enum'; + +type ParsedServerConfig = t.MCPOptions & { + url?: string; + requiresOAuth?: boolean; + oauthMetadata?: Record | null; + capabilities?: string; + tools?: string; +}; + +/** + * Manages MCP server configurations and metadata discovery. + * Fetches server capabilities, OAuth requirements, and tool definitions for registry. + * Determines which servers are for app-level connections. + * Has its own connections repository. All connections are disconnected after initialization. + */ +export class MCPServersRegistry { + private initialized: boolean = false; + private connections: ConnectionsRepository; + + public readonly rawConfigs: t.MCPServers; + public readonly parsedConfigs: Record; + + public oauthServers: Set | null = null; + public serverInstructions: Record | null = null; + public toolFunctions: t.LCAvailableTools | null = null; + public appServerConfigs: t.MCPServers | null = null; + + constructor(configs: t.MCPServers) { + this.rawConfigs = configs; + this.parsedConfigs = mapValues(configs, (con) => processMCPEnv(con)); + this.connections = new ConnectionsRepository(configs); + } + + /** Initializes all startup-enabled servers by gathering their metadata asynchronously */ + public async initialize() { + if (this.initialized) return; + this.initialized = true; + + const serverNames = Object.keys(this.parsedConfigs); + + await Promise.allSettled(serverNames.map((serverName) => this.gatherServerInfo(serverName))); + + this.setOAuthServers(); + this.setServerInstructions(); + this.setAppServerConfigs(); + await this.setAppToolFunctions(); + + this.connections.disconnectAll(); + } + + // Fetches all metadata for a single server in parallel + private async gatherServerInfo(serverName: string) { + try { + await Promise.allSettled([ + this.fetchOAuthRequirement(serverName).catch((error) => + logger.error(`${this.prefix(serverName)} Failed to fetch OAuth requirement:`, error), + ), + this.fetchServerInstructions(serverName).catch((error) => + logger.error(`${this.prefix(serverName)} Failed to fetch server instructions:`, error), + ), + this.fetchServerCapabilities(serverName).catch((error) => + logger.error(`${this.prefix(serverName)} Failed to fetch server capabilities:`, error), + ), + ]); + + this.logUpdatedConfig(serverName); + } catch (error) { + logger.error(`${this.prefix(serverName)} Failed to initialize server:`, error); + } + } + + // Sets app-level server configs (startup enabled, non-OAuth servers) + private setAppServerConfigs() { + const appServers = Object.keys( + pickBy( + this.parsedConfigs, + (config) => config.startup !== false && config.requiresOAuth === false, + ), + ); + this.appServerConfigs = pick(this.rawConfigs, appServers); + } + + // Creates set of server names that require OAuth authentication + private setOAuthServers() { + if (this.oauthServers) return this.oauthServers; + this.oauthServers = new Set( + Object.keys(pickBy(this.parsedConfigs, (config) => config.requiresOAuth)), + ); + return this.oauthServers; + } + + // Collects server instructions from all configured servers + private setServerInstructions() { + this.serverInstructions = mapValues( + pickBy(this.parsedConfigs, (config) => config.serverInstructions), + (config) => config.serverInstructions as string, + ); + } + + // Builds registry of all available tool functions from loaded connections + private async setAppToolFunctions() { + const connections = (await this.connections.getLoaded()).entries(); + const allToolFunctions: t.LCAvailableTools = {}; + for (const [serverName, conn] of connections) { + try { + const toolFunctions = await this.getToolFunctions(serverName, conn); + Object.assign(allToolFunctions, toolFunctions); + } catch (error) { + logger.error(`${this.prefix(serverName)} Error fetching tool functions:`, error); + } + } + this.toolFunctions = allToolFunctions; + } + + // Converts server tools to LibreChat-compatible tool functions format + private async getToolFunctions( + serverName: string, + conn: MCPConnection, + ): Promise { + const { tools } = await conn.client.listTools(); + + const toolFunctions: t.LCAvailableTools = {}; + tools.forEach((tool) => { + const name = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`; + toolFunctions[name] = { + type: 'function', + ['function']: { + name, + description: tool.description, + parameters: tool.inputSchema as JsonSchemaType, + }, + }; + }); + + return toolFunctions; + } + + // Determines if server requires OAuth if not already specified in the config + private async fetchOAuthRequirement(serverName: string) { + const config = this.parsedConfigs[serverName]; + if (config.requiresOAuth != null) return; + if (config.url == null) return (config.requiresOAuth = false); + + const result = await detectOAuthRequirement(config.url); + config.requiresOAuth = result.requiresOAuth; + config.oauthMetadata = result.metadata; + } + + // Retrieves server instructions from MCP server if enabled in the config + private async fetchServerInstructions(serverName: string) { + const config = this.parsedConfigs[serverName]; + if (!config.serverInstructions) return; + if (typeof config.serverInstructions === 'string') return; + + const conn = await this.connections.get(serverName); + config.serverInstructions = conn.client.getInstructions(); + if (!config.serverInstructions) { + logger.warn(`${this.prefix(serverName)} No server instructions available`); + } + } + + // Fetches server capabilities and available tools list + private async fetchServerCapabilities(serverName: string) { + const config = this.parsedConfigs[serverName]; + const conn = await this.connections.get(serverName); + const capabilities = conn.client.getServerCapabilities(); + if (!capabilities) return; + config.capabilities = JSON.stringify(capabilities); + if (!capabilities.tools) return; + const tools = await conn.client.listTools(); + config.tools = tools.tools.map((tool) => tool.name).join(', '); + } + + // Logs server configuration summary after initialization + private logUpdatedConfig(serverName: string) { + const prefix = this.prefix(serverName); + const config = this.parsedConfigs[serverName]; + logger.info(`${prefix} URL: ${config.url ?? 'N/A'}`); + logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`); + logger.info(`${prefix} Capabilities: ${config.capabilities}`); + logger.info(`${prefix} Tools: ${config.tools}`); + logger.info(`${prefix} Server Instructions: ${config.serverInstructions ?? 'None'}`); + } + + // Returns formatted log prefix for server messages + private prefix(serverName: string): string { + return `[MCP][${serverName}]`; + } +} diff --git a/packages/api/src/mcp/UserConnectionManager.ts b/packages/api/src/mcp/UserConnectionManager.ts new file mode 100644 index 000000000..0590c8211 --- /dev/null +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -0,0 +1,236 @@ +import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js'; +import { logger } from '@librechat/data-schemas'; +import type { TokenMethods } from '@librechat/data-schemas'; +import type { TUser } from 'librechat-data-provider'; +import type { FlowStateManager } from '~/flow/manager'; +import type { MCPOAuthTokens } from '~/mcp/oauth'; +import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; +import { MCPServersRegistry } from '~/mcp/MCPServersRegistry'; +import { MCPConnection } from './connection'; +import type * as t from './types'; + +/** + * Abstract base class for managing user-specific MCP connections with lifecycle management. + * Only meant to be extended by MCPManager. + * Much of the logic was move here from the old MCPManager to make it more manageable. + * User connections will soon be ephemeral and not cached anymore: + * https://github.com/danny-avila/LibreChat/discussions/8790 + */ +export abstract class UserConnectionManager { + protected readonly serversRegistry: MCPServersRegistry; + protected userConnections: Map> = new Map(); + /** Last activity timestamp for users (not per server) */ + protected userLastActivity: Map = new Map(); + protected readonly USER_CONNECTION_IDLE_TIMEOUT = 15 * 60 * 1000; // 15 minutes (TODO: make configurable) + + constructor(serverConfigs: t.MCPServers) { + this.serversRegistry = new MCPServersRegistry(serverConfigs); + } + + /** fetches am MCP Server config from the registry */ + public getRawConfig(serverName: string): t.MCPOptions | undefined { + return this.serversRegistry.rawConfigs[serverName]; + } + + /** Updates the last activity timestamp for a user */ + protected updateUserLastActivity(userId: string): void { + const now = Date.now(); + this.userLastActivity.set(userId, now); + logger.debug( + `[MCP][User: ${userId}] Updated last activity timestamp: ${new Date(now).toISOString()}`, + ); + } + + /** Gets or creates a connection for a specific user */ + public async getUserConnection({ + user, + serverName, + flowManager, + customUserVars, + tokenMethods, + oauthStart, + oauthEnd, + signal, + returnOnOAuth = false, + }: { + user: TUser; + serverName: string; + flowManager: FlowStateManager; + customUserVars?: Record; + tokenMethods?: TokenMethods; + oauthStart?: (authURL: string) => Promise; + oauthEnd?: () => Promise; + signal?: AbortSignal; + returnOnOAuth?: boolean; + }): Promise { + const userId = user.id; + if (!userId) { + throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`); + } + + const userServerMap = this.userConnections.get(userId); + let connection = userServerMap?.get(serverName); + const now = Date.now(); + + // Check if user is idle + const lastActivity = this.userLastActivity.get(userId); + if (lastActivity && now - lastActivity > this.USER_CONNECTION_IDLE_TIMEOUT) { + logger.info(`[MCP][User: ${userId}] User idle for too long. Disconnecting all connections.`); + // Disconnect all user connections + try { + await this.disconnectUserConnections(userId); + } catch (err) { + logger.error(`[MCP][User: ${userId}] Error disconnecting idle connections:`, err); + } + connection = undefined; // Force creation of a new connection + } else if (connection) { + if (await connection.isConnected()) { + logger.debug(`[MCP][User: ${userId}][${serverName}] Reusing active connection`); + this.updateUserLastActivity(userId); + return connection; + } else { + // Connection exists but is not connected, attempt to remove potentially stale entry + logger.warn( + `[MCP][User: ${userId}][${serverName}] Found existing but disconnected connection object. Cleaning up.`, + ); + this.removeUserConnection(userId, serverName); // Clean up maps + connection = undefined; + } + } + + // If no valid connection exists, create a new one + if (!connection) { + logger.info(`[MCP][User: ${userId}][${serverName}] Establishing new connection`); + } + + const config = this.serversRegistry.parsedConfigs[serverName]; + if (!config) { + throw new McpError( + ErrorCode.InvalidRequest, + `[MCP][User: ${userId}] Configuration for server "${serverName}" not found.`, + ); + } + + try { + connection = await MCPConnectionFactory.create( + { + serverName: serverName, + serverConfig: config, + }, + { + useOAuth: true, + user: user, + customUserVars: customUserVars, + flowManager: flowManager, + tokenMethods: tokenMethods, + signal: signal, + oauthStart: oauthStart, + oauthEnd: oauthEnd, + returnOnOAuth: returnOnOAuth, + }, + ); + + if (!(await connection?.isConnected())) { + throw new Error('Failed to establish connection after initialization attempt.'); + } + + if (!this.userConnections.has(userId)) { + this.userConnections.set(userId, new Map()); + } + this.userConnections.get(userId)?.set(serverName, connection); + + logger.info(`[MCP][User: ${userId}][${serverName}] Connection successfully established`); + // Update timestamp on creation + this.updateUserLastActivity(userId); + return connection; + } catch (error) { + logger.error(`[MCP][User: ${userId}][${serverName}] Failed to establish connection`, error); + // Ensure partial connection state is cleaned up if initialization fails + await connection?.disconnect().catch((disconnectError) => { + logger.error( + `[MCP][User: ${userId}][${serverName}] Error during cleanup after failed connection`, + disconnectError, + ); + }); + // Ensure cleanup even if connection attempt fails + this.removeUserConnection(userId, serverName); + throw error; // Re-throw the error to the caller + } + } + + /** Returns all connections for a specific user */ + public getUserConnections(userId: string) { + return this.userConnections.get(userId); + } + + /** Removes a specific user connection entry */ + protected removeUserConnection(userId: string, serverName: string): void { + const userMap = this.userConnections.get(userId); + if (userMap) { + userMap.delete(serverName); + if (userMap.size === 0) { + this.userConnections.delete(userId); + // Only remove user activity timestamp if all connections are gone + this.userLastActivity.delete(userId); + } + } + + logger.debug(`[MCP][User: ${userId}][${serverName}] Removed connection entry.`); + } + + /** Disconnects and removes a specific user connection */ + public async disconnectUserConnection(userId: string, serverName: string): Promise { + const userMap = this.userConnections.get(userId); + const connection = userMap?.get(serverName); + if (connection) { + logger.info(`[MCP][User: ${userId}][${serverName}] Disconnecting...`); + await connection.disconnect(); + this.removeUserConnection(userId, serverName); + } + } + + /** Disconnects and removes all connections for a specific user */ + public async disconnectUserConnections(userId: string): Promise { + const userMap = this.userConnections.get(userId); + const disconnectPromises: Promise[] = []; + if (userMap) { + logger.info(`[MCP][User: ${userId}] Disconnecting all servers...`); + const userServers = Array.from(userMap.keys()); + for (const serverName of userServers) { + disconnectPromises.push( + this.disconnectUserConnection(userId, serverName).catch((error) => { + logger.error( + `[MCP][User: ${userId}][${serverName}] Error during disconnection:`, + error, + ); + }), + ); + } + await Promise.allSettled(disconnectPromises); + // Ensure user activity timestamp is removed + this.userLastActivity.delete(userId); + logger.info(`[MCP][User: ${userId}] All connections processed for disconnection.`); + } + } + + /** Check for and disconnect idle connections */ + protected checkIdleConnections(currentUserId?: string): void { + const now = Date.now(); + + // Iterate through all users to check for idle ones + for (const [userId, lastActivity] of this.userLastActivity.entries()) { + if (currentUserId && currentUserId === userId) { + continue; + } + if (now - lastActivity > this.USER_CONNECTION_IDLE_TIMEOUT) { + logger.info( + `[MCP][User: ${userId}] User idle for too long. Disconnecting all connections...`, + ); + // Disconnect all user connections asynchronously (fire and forget) + this.disconnectUserConnections(userId).catch((err) => + logger.error(`[MCP][User: ${userId}] Error disconnecting idle connections:`, err), + ); + } + } + } +} diff --git a/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts b/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts new file mode 100644 index 000000000..dd71466e9 --- /dev/null +++ b/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts @@ -0,0 +1,212 @@ +import { logger } from '@librechat/data-schemas'; +import { ConnectionsRepository } from '../ConnectionsRepository'; +import { MCPConnectionFactory } from '../MCPConnectionFactory'; +import { MCPConnection } from '../connection'; +import type * as t from '../types'; + +// Mock external dependencies +jest.mock('@librechat/data-schemas', () => ({ + logger: { + error: jest.fn(), + }, +})); + +jest.mock('../MCPConnectionFactory', () => ({ + MCPConnectionFactory: { + create: jest.fn(), + }, +})); + +jest.mock('../connection'); + +const mockLogger = logger as jest.Mocked; + +describe('ConnectionsRepository', () => { + let repository: ConnectionsRepository; + let mockServerConfigs: t.MCPServers; + let mockConnection: jest.Mocked; + + beforeEach(() => { + mockServerConfigs = { + server1: { url: 'http://localhost:3001' }, + server2: { command: 'test-command', args: ['--test'] }, + server3: { url: 'ws://localhost:8080', type: 'websocket' }, + }; + + mockConnection = { + isConnected: jest.fn().mockResolvedValue(true), + disconnect: jest.fn().mockResolvedValue(undefined), + } as unknown as jest.Mocked; + + (MCPConnectionFactory.create as jest.Mock).mockResolvedValue(mockConnection); + + repository = new ConnectionsRepository(mockServerConfigs); + + jest.clearAllMocks(); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + describe('has', () => { + it('should return true for existing server', () => { + expect(repository.has('server1')).toBe(true); + }); + + it('should return false for non-existing server', () => { + expect(repository.has('nonexistent')).toBe(false); + }); + }); + + describe('get', () => { + it('should return existing connected connection', async () => { + mockConnection.isConnected.mockResolvedValue(true); + repository['connections'].set('server1', mockConnection); + + const result = await repository.get('server1'); + + expect(result).toBe(mockConnection); + expect(MCPConnectionFactory.create).not.toHaveBeenCalled(); + }); + + it('should create new connection if none exists', async () => { + const result = await repository.get('server1'); + + expect(result).toBe(mockConnection); + expect(MCPConnectionFactory.create).toHaveBeenCalledWith( + { + serverName: 'server1', + serverConfig: mockServerConfigs.server1, + }, + undefined, + ); + expect(repository['connections'].get('server1')).toBe(mockConnection); + }); + + it('should create new connection if existing connection is not connected', async () => { + const oldConnection = { + isConnected: jest.fn().mockResolvedValue(false), + disconnect: jest.fn().mockResolvedValue(undefined), + } as unknown as jest.Mocked; + repository['connections'].set('server1', oldConnection); + + const result = await repository.get('server1'); + + expect(result).toBe(mockConnection); + expect(oldConnection.disconnect).toHaveBeenCalled(); + expect(MCPConnectionFactory.create).toHaveBeenCalledWith( + { + serverName: 'server1', + serverConfig: mockServerConfigs.server1, + }, + undefined, + ); + }); + + it('should throw error for non-existent server configuration', async () => { + await expect(repository.get('nonexistent')).rejects.toThrow( + '[MCP][nonexistent] Server not found in configuration', + ); + }); + + it('should handle MCPConnectionFactory.create errors', async () => { + const createError = new Error('Connection creation failed'); + (MCPConnectionFactory.create as jest.Mock).mockRejectedValue(createError); + + await expect(repository.get('server1')).rejects.toThrow('Connection creation failed'); + }); + }); + + describe('getMany', () => { + it('should return connections for multiple servers', async () => { + const result = await repository.getMany(['server1', 'server3']); + + expect(result).toBeInstanceOf(Map); + expect(result.size).toBe(2); + expect(result.get('server1')).toBe(mockConnection); + expect(result.get('server3')).toBe(mockConnection); + }); + }); + + describe('getLoaded', () => { + it('should return connections for loaded servers only', async () => { + // Load one connection + await repository.get('server1'); + + const result = await repository.getLoaded(); + + expect(result).toBeInstanceOf(Map); + expect(result.size).toBe(1); + expect(result.get('server1')).toBe(mockConnection); + }); + + it('should return empty map when no connections are loaded', async () => { + const result = await repository.getLoaded(); + + expect(result).toBeInstanceOf(Map); + expect(result.size).toBe(0); + }); + }); + + describe('getAll', () => { + it('should return connections for all configured servers', async () => { + const result = await repository.getAll(); + + expect(result).toBeInstanceOf(Map); + expect(result.size).toBe(3); + expect(result.get('server1')).toBe(mockConnection); + expect(result.get('server2')).toBe(mockConnection); + expect(result.get('server3')).toBe(mockConnection); + }); + }); + + describe('disconnect', () => { + it('should disconnect and remove existing connection', async () => { + repository['connections'].set('server1', mockConnection); + + await repository.disconnect('server1'); + + expect(mockConnection.disconnect).toHaveBeenCalled(); + expect(repository['connections'].has('server1')).toBe(false); + }); + + it('should handle disconnect error gracefully', async () => { + const disconnectError = new Error('Disconnect failed'); + mockConnection.disconnect.mockRejectedValue(disconnectError); + repository['connections'].set('server1', mockConnection); + + await repository.disconnect('server1'); + + expect(mockConnection.disconnect).toHaveBeenCalled(); + expect(repository['connections'].has('server1')).toBe(false); + expect(mockLogger.error).toHaveBeenCalledWith( + '[MCP][server1] Error disconnecting', + disconnectError, + ); + }); + }); + + describe('disconnectAll', () => { + it('should disconnect all active connections', () => { + const mockConnection1 = { + disconnect: jest.fn().mockResolvedValue(undefined), + } as unknown as jest.Mocked; + const mockConnection2 = { + disconnect: jest.fn().mockResolvedValue(undefined), + } as unknown as jest.Mocked; + const mockConnection3 = { + disconnect: jest.fn().mockResolvedValue(undefined), + } as unknown as jest.Mocked; + + repository['connections'].set('server1', mockConnection1); + repository['connections'].set('server2', mockConnection2); + repository['connections'].set('server3', mockConnection3); + + const promises = repository.disconnectAll(); + + expect(promises).toHaveLength(3); + expect(Array.isArray(promises)).toBe(true); + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts new file mode 100644 index 000000000..9ca447f92 --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts @@ -0,0 +1,347 @@ +import { logger } from '@librechat/data-schemas'; +import type { TUser } from 'librechat-data-provider'; +import type { FlowStateManager } from '~/flow/manager'; +import type { MCPOAuthTokens } from '~/mcp/oauth'; +import { MCPConnectionFactory } from '../MCPConnectionFactory'; +import { MCPOAuthHandler } from '~/mcp/oauth'; +import { MCPConnection } from '../connection'; +import { processMCPEnv } from '~/utils'; +import type * as t from '../types'; + +jest.mock('../connection'); +jest.mock('~/mcp/oauth'); +jest.mock('~/utils'); +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, +})); + +const mockLogger = logger as jest.Mocked; +const mockProcessMCPEnv = processMCPEnv as jest.MockedFunction; +const mockMCPConnection = MCPConnection as jest.MockedClass; +const mockMCPOAuthHandler = MCPOAuthHandler as jest.Mocked; + +describe('MCPConnectionFactory', () => { + let mockUser: TUser; + let mockServerConfig: t.MCPOptions; + let mockFlowManager: jest.Mocked>; + let mockConnectionInstance: jest.Mocked; + + beforeEach(() => { + jest.clearAllMocks(); + mockUser = { + id: 'user123', + email: 'test@example.com', + } as TUser; + + mockServerConfig = { + command: 'node', + args: ['server.js'], + initTimeout: 5000, + } as t.MCPOptions; + + mockFlowManager = { + createFlow: jest.fn(), + createFlowWithHandler: jest.fn(), + getFlowState: jest.fn(), + } as unknown as jest.Mocked>; + + mockConnectionInstance = { + connect: jest.fn(), + isConnected: jest.fn(), + setOAuthTokens: jest.fn(), + on: jest.fn().mockReturnValue(mockConnectionInstance), + emit: jest.fn(), + } as unknown as jest.Mocked; + + mockMCPConnection.mockImplementation(() => mockConnectionInstance); + mockProcessMCPEnv.mockReturnValue(mockServerConfig); + }); + + describe('static create method', () => { + it('should create a basic connection without OAuth', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: mockServerConfig, + }; + + mockConnectionInstance.isConnected.mockResolvedValue(true); + + const connection = await MCPConnectionFactory.create(basicOptions); + + expect(connection).toBe(mockConnectionInstance); + expect(mockProcessMCPEnv).toHaveBeenCalledWith(mockServerConfig, undefined, undefined); + expect(mockMCPConnection).toHaveBeenCalledWith({ + serverName: 'test-server', + serverConfig: mockServerConfig, + userId: undefined, + oauthTokens: null, + }); + expect(mockConnectionInstance.connect).toHaveBeenCalled(); + }); + + it('should create a connection with OAuth', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: mockServerConfig, + }; + + const oauthOptions = { + useOAuth: true as const, + user: mockUser, + flowManager: mockFlowManager, + tokenMethods: { + findToken: jest.fn(), + createToken: jest.fn(), + updateToken: jest.fn(), + deleteTokens: jest.fn(), + }, + }; + + const mockTokens: MCPOAuthTokens = { + access_token: 'access123', + refresh_token: 'refresh123', + token_type: 'Bearer', + obtained_at: Date.now(), + }; + + mockFlowManager.createFlowWithHandler.mockResolvedValue(mockTokens); + mockConnectionInstance.isConnected.mockResolvedValue(true); + + const connection = await MCPConnectionFactory.create(basicOptions, oauthOptions); + + expect(connection).toBe(mockConnectionInstance); + expect(mockProcessMCPEnv).toHaveBeenCalledWith(mockServerConfig, mockUser, undefined); + expect(mockMCPConnection).toHaveBeenCalledWith({ + serverName: 'test-server', + serverConfig: mockServerConfig, + userId: 'user123', + oauthTokens: mockTokens, + }); + }); + }); + + describe('OAuth token handling', () => { + it('should return null when no findToken method is provided', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: mockServerConfig, + }; + + const oauthOptions = { + useOAuth: true as const, + user: mockUser, + flowManager: mockFlowManager, + tokenMethods: { + findToken: undefined as unknown as () => Promise, + createToken: jest.fn(), + updateToken: jest.fn(), + deleteTokens: jest.fn(), + }, + }; + + mockConnectionInstance.isConnected.mockResolvedValue(true); + + await MCPConnectionFactory.create(basicOptions, oauthOptions); + + expect(mockFlowManager.createFlowWithHandler).not.toHaveBeenCalled(); + }); + + it('should handle token retrieval errors gracefully', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: mockServerConfig, + }; + + const oauthOptions = { + useOAuth: true as const, + user: mockUser, + flowManager: mockFlowManager, + tokenMethods: { + findToken: jest.fn(), + createToken: jest.fn(), + updateToken: jest.fn(), + deleteTokens: jest.fn(), + }, + }; + + mockFlowManager.createFlowWithHandler.mockRejectedValue(new Error('Token fetch failed')); + mockConnectionInstance.isConnected.mockResolvedValue(true); + + const connection = await MCPConnectionFactory.create(basicOptions, oauthOptions); + + expect(connection).toBe(mockConnectionInstance); + expect(mockMCPConnection).toHaveBeenCalledWith({ + serverName: 'test-server', + serverConfig: mockServerConfig, + userId: 'user123', + oauthTokens: null, + }); + expect(mockLogger.debug).toHaveBeenCalledWith( + expect.stringContaining('No existing tokens found or error loading tokens'), + expect.any(Error), + ); + }); + }); + + describe('OAuth event handling', () => { + it('should handle oauthRequired event for returnOnOAuth scenario', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: { + ...mockServerConfig, + url: 'https://api.example.com', + type: 'sse' as const, + } as t.SSEOptions, + }; + + const oauthOptions = { + useOAuth: true as const, + user: mockUser, + flowManager: mockFlowManager, + returnOnOAuth: true, + oauthStart: jest.fn(), + tokenMethods: { + findToken: jest.fn(), + createToken: jest.fn(), + updateToken: jest.fn(), + deleteTokens: jest.fn(), + }, + }; + + const mockFlowData = { + authorizationUrl: 'https://auth.example.com', + flowId: 'flow123', + flowMetadata: { + serverName: 'test-server', + userId: 'user123', + serverUrl: 'https://api.example.com', + state: 'random-state', + clientInfo: { client_id: 'client123' }, + }, + }; + + mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData); + mockFlowManager.createFlow.mockRejectedValue(new Error('Timeout expected')); + mockConnectionInstance.isConnected.mockResolvedValue(false); + + let oauthRequiredHandler: (data: Record) => Promise; + mockConnectionInstance.on.mockImplementation((event, handler) => { + if (event === 'oauthRequired') { + oauthRequiredHandler = handler as (data: Record) => Promise; + } + return mockConnectionInstance; + }); + + try { + await MCPConnectionFactory.create(basicOptions, oauthOptions); + } catch { + // Expected to fail due to connection not established + } + + expect(oauthRequiredHandler!).toBeDefined(); + + await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' }); + + expect(mockMCPOAuthHandler.initiateOAuthFlow).toHaveBeenCalledWith( + 'test-server', + 'https://api.example.com', + 'user123', + undefined, + ); + expect(oauthOptions.oauthStart).toHaveBeenCalledWith('https://auth.example.com'); + expect(mockConnectionInstance.emit).toHaveBeenCalledWith( + 'oauthFailed', + expect.objectContaining({ + message: 'OAuth flow initiated - return early', + }), + ); + }); + }); + + describe('connection retry logic', () => { + it('should establish connection successfully', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: mockServerConfig, // Use default 5000ms timeout + }; + + mockConnectionInstance.connect.mockResolvedValue(undefined); + mockConnectionInstance.isConnected.mockResolvedValue(true); + + const connection = await MCPConnectionFactory.create(basicOptions); + + expect(connection).toBe(mockConnectionInstance); + expect(mockConnectionInstance.connect).toHaveBeenCalledTimes(1); + }); + + it('should handle OAuth errors during connection attempts', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: mockServerConfig, + }; + + const oauthOptions = { + useOAuth: true as const, + user: mockUser, + flowManager: mockFlowManager, + oauthStart: jest.fn(), + tokenMethods: { + findToken: jest.fn(), + createToken: jest.fn(), + updateToken: jest.fn(), + deleteTokens: jest.fn(), + }, + }; + + const oauthError = new Error('Non-200 status code (401)'); + (oauthError as unknown as Record).isOAuthError = true; + + mockConnectionInstance.connect.mockRejectedValue(oauthError); + mockConnectionInstance.isConnected.mockResolvedValue(false); + + await expect(MCPConnectionFactory.create(basicOptions, oauthOptions)).rejects.toThrow( + 'Non-200 status code (401)', + ); + + expect(mockLogger.info).toHaveBeenCalledWith( + expect.stringContaining('OAuth required, stopping connection attempts'), + ); + }); + }); + + describe('isOAuthError method', () => { + it('should identify OAuth errors by message content', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: mockServerConfig, + }; + + const oauthOptions = { + useOAuth: true as const, + user: mockUser, + flowManager: mockFlowManager, + tokenMethods: { + findToken: jest.fn(), + createToken: jest.fn(), + updateToken: jest.fn(), + deleteTokens: jest.fn(), + }, + }; + + const error401 = new Error('401 Unauthorized'); + + mockConnectionInstance.connect.mockRejectedValue(error401); + mockConnectionInstance.isConnected.mockResolvedValue(false); + + await expect(MCPConnectionFactory.create(basicOptions, oauthOptions)).rejects.toThrow('401'); + expect(mockLogger.info).toHaveBeenCalledWith( + expect.stringContaining('OAuth required, stopping connection attempts'), + ); + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts b/packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts new file mode 100644 index 000000000..63fdc72ff --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts @@ -0,0 +1,287 @@ +import { readFileSync } from 'fs'; +import { join } from 'path'; +import { logger } from '@librechat/data-schemas'; +import { load as yamlLoad } from 'js-yaml'; +import { ConnectionsRepository } from '../ConnectionsRepository'; +import { MCPServersRegistry } from '../MCPServersRegistry'; +import { detectOAuthRequirement } from '~/mcp/oauth'; +import { MCPConnection } from '../connection'; +import type * as t from '../types'; + +// Mock external dependencies +jest.mock('../oauth/detectOAuth'); +jest.mock('../ConnectionsRepository'); +jest.mock('../connection'); +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, +})); + +// Mock processMCPEnv to verify it's called and adds a processed marker +jest.mock('~/utils', () => ({ + ...jest.requireActual('~/utils'), + processMCPEnv: jest.fn((config) => ({ + ...config, + _processed: true, // Simple marker to verify processing occurred + })), +})); + +const mockDetectOAuthRequirement = detectOAuthRequirement as jest.MockedFunction< + typeof detectOAuthRequirement +>; +const mockLogger = logger as jest.Mocked; + +describe('MCPServersRegistry - Initialize Function', () => { + let rawConfigs: t.MCPServers; + let expectedParsedConfigs: Record; + let mockConnectionsRepo: jest.Mocked; + let mockConnections: Map>; + + beforeEach(() => { + // Load fixtures + const rawConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.rawConfigs.yml'); + const parsedConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.parsedConfigs.yml'); + + rawConfigs = yamlLoad(readFileSync(rawConfigsPath, 'utf8')) as t.MCPServers; + expectedParsedConfigs = yamlLoad(readFileSync(parsedConfigsPath, 'utf8')) as Record< + string, + any + >; + + // Setup mock connections + mockConnections = new Map(); + const serverNames = Object.keys(rawConfigs); + + serverNames.forEach((serverName) => { + const mockConnection = { + client: { + listTools: jest.fn(), + getInstructions: jest.fn(), + getServerCapabilities: jest.fn(), + }, + } as unknown as jest.Mocked; + + // Setup mock responses based on expected configs + const expectedConfig = expectedParsedConfigs[serverName]; + + // Mock listTools response + if (expectedConfig.tools) { + const toolNames = expectedConfig.tools.split(', '); + const tools = toolNames.map((name: string) => ({ + name, + description: `Description for ${name}`, + inputSchema: { + type: 'object', + properties: { + input: { type: 'string' }, + }, + }, + })); + mockConnection.client.listTools.mockResolvedValue({ tools }); + } else { + mockConnection.client.listTools.mockResolvedValue({ tools: [] }); + } + + // Mock getInstructions response + if (expectedConfig.serverInstructions) { + mockConnection.client.getInstructions.mockReturnValue(expectedConfig.serverInstructions); + } else { + mockConnection.client.getInstructions.mockReturnValue(null); + } + + // Mock getServerCapabilities response + if (expectedConfig.capabilities) { + const capabilities = JSON.parse(expectedConfig.capabilities); + mockConnection.client.getServerCapabilities.mockReturnValue(capabilities); + } else { + mockConnection.client.getServerCapabilities.mockReturnValue(null); + } + + mockConnections.set(serverName, mockConnection); + }); + + // Setup ConnectionsRepository mock + mockConnectionsRepo = { + get: jest.fn(), + getLoaded: jest.fn(), + disconnectAll: jest.fn(), + } as unknown as jest.Mocked; + + mockConnectionsRepo.get.mockImplementation((serverName: string) => + Promise.resolve(mockConnections.get(serverName)!), + ); + + mockConnectionsRepo.getLoaded.mockResolvedValue(mockConnections); + + (ConnectionsRepository as jest.Mock).mockImplementation(() => mockConnectionsRepo); + + // Setup OAuth detection mock with deterministic results + mockDetectOAuthRequirement.mockImplementation((url: string) => { + const oauthResults: Record = { + 'https://api.github.com/mcp': { + requiresOAuth: true, + metadata: { + authorization_url: 'https://github.com/login/oauth/authorize', + token_url: 'https://github.com/login/oauth/access_token', + }, + }, + 'https://api.disabled.com/mcp': { + requiresOAuth: false, + metadata: null, + }, + 'https://api.public.com/mcp': { + requiresOAuth: false, + metadata: null, + }, + }; + + return Promise.resolve(oauthResults[url] || { requiresOAuth: false, metadata: null }); + }); + + // Clear all mocks + jest.clearAllMocks(); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + describe('initialize() method', () => { + it('should only run initialization once', async () => { + const registry = new MCPServersRegistry(rawConfigs); + + await registry.initialize(); + await registry.initialize(); // Second call should not re-run + + // Verify that connections are only requested for servers that need them + // (servers with serverInstructions=true or all servers for capabilities) + expect(mockConnectionsRepo.get).toHaveBeenCalled(); + }); + + it('should set all public properties correctly after initialization', async () => { + const registry = new MCPServersRegistry(rawConfigs); + + // Verify initial state + expect(registry.oauthServers).toBeNull(); + expect(registry.serverInstructions).toBeNull(); + expect(registry.toolFunctions).toBeNull(); + expect(registry.appServerConfigs).toBeNull(); + + await registry.initialize(); + + // Test oauthServers Set + expect(registry.oauthServers).toBeInstanceOf(Set); + expect(registry.oauthServers).toEqual( + new Set(['oauth_server', 'oauth_predefined', 'oauth_startup_enabled']), + ); + + // Test serverInstructions + expect(registry.serverInstructions).toEqual({ + oauth_server: 'GitHub MCP server instructions', + stdio_server: 'Follow these instructions for stdio server', + non_oauth_server: 'Public API instructions', + }); + + // Test appServerConfigs (startup enabled, non-OAuth servers only) + expect(registry.appServerConfigs).toEqual({ + stdio_server: rawConfigs.stdio_server, + websocket_server: rawConfigs.websocket_server, + non_oauth_server: rawConfigs.non_oauth_server, + }); + + // Test toolFunctions (only 2 servers have tools: oauth_server has 1, stdio_server has 2) + const expectedToolFunctions = { + get_repository_mcp_oauth_server: { + type: 'function', + function: { + name: 'get_repository_mcp_oauth_server', + description: 'Description for get_repository', + parameters: { type: 'object', properties: { input: { type: 'string' } } }, + }, + }, + file_read_mcp_stdio_server: { + type: 'function', + function: { + name: 'file_read_mcp_stdio_server', + description: 'Description for file_read', + parameters: { type: 'object', properties: { input: { type: 'string' } } }, + }, + }, + file_write_mcp_stdio_server: { + type: 'function', + function: { + name: 'file_write_mcp_stdio_server', + description: 'Description for file_write', + parameters: { type: 'object', properties: { input: { type: 'string' } } }, + }, + }, + }; + expect(registry.toolFunctions).toEqual(expectedToolFunctions); + }); + + it('should handle errors gracefully and continue initialization', async () => { + const registry = new MCPServersRegistry(rawConfigs); + + // Make one server throw an error + mockDetectOAuthRequirement.mockRejectedValueOnce(new Error('OAuth detection failed')); + + await registry.initialize(); + + // Should still initialize successfully + expect(registry.oauthServers).toBeInstanceOf(Set); + expect(registry.toolFunctions).toBeDefined(); + + // Error should be logged + expect(mockLogger.error).toHaveBeenCalledWith( + expect.stringContaining('[MCP][oauth_server] Failed to fetch OAuth requirement:'), + expect.any(Error), + ); + }); + + it('should disconnect all connections after initialization', async () => { + const registry = new MCPServersRegistry(rawConfigs); + + await registry.initialize(); + + expect(mockConnectionsRepo.disconnectAll).toHaveBeenCalledTimes(1); + }); + + it('should log configuration updates for each server', async () => { + const registry = new MCPServersRegistry(rawConfigs); + + await registry.initialize(); + + const serverNames = Object.keys(rawConfigs); + serverNames.forEach((serverName) => { + expect(mockLogger.info).toHaveBeenCalledWith( + expect.stringContaining(`[MCP][${serverName}] URL:`), + ); + expect(mockLogger.info).toHaveBeenCalledWith( + expect.stringContaining(`[MCP][${serverName}] OAuth Required:`), + ); + expect(mockLogger.info).toHaveBeenCalledWith( + expect.stringContaining(`[MCP][${serverName}] Capabilities:`), + ); + expect(mockLogger.info).toHaveBeenCalledWith( + expect.stringContaining(`[MCP][${serverName}] Tools:`), + ); + expect(mockLogger.info).toHaveBeenCalledWith( + expect.stringContaining(`[MCP][${serverName}] Server Instructions:`), + ); + }); + }); + + it('should have parsedConfigs matching the expected fixture after initialization', async () => { + const registry = new MCPServersRegistry(rawConfigs); + + await registry.initialize(); + + // Compare the actual parsedConfigs against the expected fixture + expect(registry.parsedConfigs).toEqual(expectedParsedConfigs); + }); + }); +}); diff --git a/packages/api/src/mcp/auth.test.ts b/packages/api/src/mcp/__tests__/auth.test.ts similarity index 99% rename from packages/api/src/mcp/auth.test.ts rename to packages/api/src/mcp/__tests__/auth.test.ts index 7bfb40ae9..5d3793cfb 100644 --- a/packages/api/src/mcp/auth.test.ts +++ b/packages/api/src/mcp/__tests__/auth.test.ts @@ -1,7 +1,7 @@ import type { PluginAuthMethods } from '@librechat/data-schemas'; import type { GenericTool } from '@librechat/agents'; import { getPluginAuthMap } from '~/agents/auth'; -import { getUserMCPAuthMap } from './auth'; +import { getUserMCPAuthMap } from '../auth'; jest.mock('~/agents/auth', () => ({ getPluginAuthMap: jest.fn(), diff --git a/packages/api/src/mcp/__tests__/detectOAuth.integration.dev.ts b/packages/api/src/mcp/__tests__/detectOAuth.integration.dev.ts new file mode 100644 index 000000000..7881b9bd6 --- /dev/null +++ b/packages/api/src/mcp/__tests__/detectOAuth.integration.dev.ts @@ -0,0 +1,76 @@ +// Integration tests for OAuth detection against real public MCP servers +// These tests verify the actual behavior against live endpoints +// +// DEVELOPMENT ONLY: This file is excluded from the test suite (.dev.ts extension) +// Use this for development and debugging OAuth detection behavior +// +// To run manually from packages/api directory: +// npx jest --testMatch="**/detectOAuth.integration.dev.ts" + +import { detectOAuthRequirement } from '~/mcp/oauth'; + +describe('OAuth Detection Integration Tests', () => { + const NETWORK_TIMEOUT = 10000; + + interface TestServer { + name: string; + url: string; + expectedOAuth: boolean; + expectedMethod: string; + withMeta: boolean; + } + + const testServers: TestServer[] = [ + { + name: 'GitHub Copilot MCP Server', + url: 'https://api.githubcopilot.com/mcp', + expectedOAuth: true, + expectedMethod: '401-challenge-metadata', + withMeta: true, + }, + { + name: 'GitHub API (401 without metadata)', + url: 'https://api.github.com/user', + expectedOAuth: true, + expectedMethod: 'no-metadata-found', + withMeta: false, + }, + { + name: 'Stytch Todo MCP Server', + url: 'https://mcp-stytch-consumer-todo-list.maxwell-gerber42.workers.dev', + expectedOAuth: true, + expectedMethod: 'protected-resource-metadata', + withMeta: true, + }, + { + name: 'HTTPBin (Non-OAuth)', + url: 'https://httpbin.org', + expectedOAuth: false, + expectedMethod: 'no-metadata-found', + withMeta: false, + }, + { + name: 'Unreachable Server', + url: 'https://definitely-not-a-real-server-12345.com', + expectedOAuth: false, + expectedMethod: 'no-metadata-found', + withMeta: false, + }, + ]; + + describe('detectOAuthRequirement integration', () => { + testServers.forEach((server) => { + it( + `should handle ${server.name}`, + async () => { + const result = await detectOAuthRequirement(server.url); + + expect(result.requiresOAuth).toBe(server.expectedOAuth); + expect(result.method).toBe(server.expectedMethod); + expect(result.metadata == null).toBe(!server.withMeta); + }, + NETWORK_TIMEOUT, + ); + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.parsedConfigs.yml b/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.parsedConfigs.yml new file mode 100644 index 000000000..55ddbf4ed --- /dev/null +++ b/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.parsedConfigs.yml @@ -0,0 +1,74 @@ +# Expected parsed MCP server configurations after running initialize() +# These represent the expected state of parsedConfigs after all fetch functions complete + +oauth_server: + _processed: true + type: "streamable-http" + url: "https://api.github.com/mcp" + headers: + Authorization: "Bearer {{GITHUB_TOKEN}}" + serverInstructions: "GitHub MCP server instructions" + requiresOAuth: true + oauthMetadata: + authorization_url: "https://github.com/login/oauth/authorize" + token_url: "https://github.com/login/oauth/access_token" + capabilities: '{"tools":{"listChanged":true},"resources":{},"prompts":{}}' + tools: "get_repository" + +oauth_predefined: + _processed: true + type: "sse" + url: "https://api.example.com/sse" + requiresOAuth: true + oauthMetadata: + authorization_url: "https://example.com/oauth/authorize" + token_url: "https://example.com/oauth/token" + capabilities: '{"tools":{},"resources":{},"prompts":{}}' + tools: "" + +stdio_server: + _processed: true + command: "node" + args: ["server.js"] + env: + API_KEY: "${TEST_API_KEY}" + startup: true + serverInstructions: "Follow these instructions for stdio server" + requiresOAuth: false + capabilities: '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{}}' + tools: "file_read, file_write" + +websocket_server: + _processed: true + type: "websocket" + url: "ws://localhost:3001/mcp" + startup: true + requiresOAuth: false + oauthMetadata: null + capabilities: '{"tools":{},"resources":{},"prompts":{}}' + tools: "" + +disabled_server: + _processed: true + type: "streamable-http" + url: "https://api.disabled.com/mcp" + startup: false + requiresOAuth: false + oauthMetadata: null + +non_oauth_server: + _processed: true + type: "streamable-http" + url: "https://api.public.com/mcp" + requiresOAuth: false + serverInstructions: "Public API instructions" + capabilities: '{"tools":{},"resources":{},"prompts":{}}' + tools: "" + +oauth_startup_enabled: + _processed: true + type: "sse" + url: "https://api.oauth-startup.com/sse" + requiresOAuth: true + capabilities: '{"tools":{},"resources":{},"prompts":{}}' + tools: "" \ No newline at end of file diff --git a/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.rawConfigs.yml b/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.rawConfigs.yml new file mode 100644 index 000000000..907dfaa96 --- /dev/null +++ b/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.rawConfigs.yml @@ -0,0 +1,53 @@ +# Raw MCP server configurations used as input to MCPServersRegistry constructor +# These configs test different code paths in the initialization process + +# Test OAuth detection with URL - should trigger fetchOAuthRequirement +oauth_server: + type: "streamable-http" + url: "https://api.github.com/mcp" + headers: + Authorization: "Bearer {{GITHUB_TOKEN}}" + serverInstructions: true + +# Test OAuth already specified - should skip OAuth detection +oauth_predefined: + type: "sse" + url: "https://api.example.com/sse" + requiresOAuth: true + oauthMetadata: + authorization_url: "https://example.com/oauth/authorize" + token_url: "https://example.com/oauth/token" + +# Test stdio server without URL - should set requiresOAuth to false +stdio_server: + command: "node" + args: ["server.js"] + env: + API_KEY: "${TEST_API_KEY}" + startup: true + serverInstructions: "Follow these instructions for stdio server" + +# Test websocket server with capabilities but no tools +websocket_server: + type: "websocket" + url: "ws://localhost:3001/mcp" + startup: true + +# Test server with startup disabled - should not be included in appServerConfigs +disabled_server: + type: "streamable-http" + url: "https://api.disabled.com/mcp" + startup: false + +# Test non-OAuth server - should be included in appServerConfigs +non_oauth_server: + type: "streamable-http" + url: "https://api.public.com/mcp" + requiresOAuth: false + serverInstructions: true + +# Test server with OAuth but startup enabled - should not be in appServerConfigs +oauth_startup_enabled: + type: "sse" + url: "https://api.oauth-startup.com/sse" + requiresOAuth: true \ No newline at end of file diff --git a/packages/api/src/mcp/oauth/handler.test.ts b/packages/api/src/mcp/__tests__/handler.test.ts similarity index 99% rename from packages/api/src/mcp/oauth/handler.test.ts rename to packages/api/src/mcp/__tests__/handler.test.ts index c65ffc151..e4a4075aa 100644 --- a/packages/api/src/mcp/oauth/handler.test.ts +++ b/packages/api/src/mcp/__tests__/handler.test.ts @@ -1,5 +1,5 @@ -import { MCPOAuthHandler } from './handler'; import type { MCPOptions } from 'librechat-data-provider'; +import { MCPOAuthHandler } from '~/mcp/oauth'; jest.mock('@librechat/data-schemas', () => ({ logger: { diff --git a/packages/api/src/mcp/mcp.spec.ts b/packages/api/src/mcp/__tests__/mcp.spec.ts similarity index 100% rename from packages/api/src/mcp/mcp.spec.ts rename to packages/api/src/mcp/__tests__/mcp.spec.ts diff --git a/packages/api/src/mcp/utils.test.ts b/packages/api/src/mcp/__tests__/utils.test.ts similarity index 95% rename from packages/api/src/mcp/utils.test.ts rename to packages/api/src/mcp/__tests__/utils.test.ts index bc5d0ba7d..716a230eb 100644 --- a/packages/api/src/mcp/utils.test.ts +++ b/packages/api/src/mcp/__tests__/utils.test.ts @@ -1,4 +1,4 @@ -import { normalizeServerName } from './utils'; +import { normalizeServerName } from '../utils'; describe('normalizeServerName', () => { it('should not modify server names that already match the pattern', () => { diff --git a/packages/api/src/mcp/zod.spec.ts b/packages/api/src/mcp/__tests__/zod.spec.ts similarity index 99% rename from packages/api/src/mcp/zod.spec.ts rename to packages/api/src/mcp/__tests__/zod.spec.ts index 7eb9c162a..fbc72b7e6 100644 --- a/packages/api/src/mcp/zod.spec.ts +++ b/packages/api/src/mcp/__tests__/zod.spec.ts @@ -2,7 +2,7 @@ // zod.spec.ts import { z } from 'zod'; import type { JsonSchemaType } from '~/types'; -import { resolveJsonSchemaRefs, convertJsonSchemaToZod, convertWithResolvedRefs } from './zod'; +import { resolveJsonSchemaRefs, convertJsonSchemaToZod, convertWithResolvedRefs } from '../zod'; describe('convertJsonSchemaToZod', () => { describe('primitive types', () => { diff --git a/packages/api/src/mcp/connection.ts b/packages/api/src/mcp/connection.ts index 52189e2ed..7099205e6 100644 --- a/packages/api/src/mcp/connection.ts +++ b/packages/api/src/mcp/connection.ts @@ -1,17 +1,18 @@ import { EventEmitter } from 'events'; -import { logger } from '@librechat/data-schemas'; -import { Client } from '@modelcontextprotocol/sdk/client/index.js'; -import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; import { StdioClientTransport, getDefaultEnvironment, } from '@modelcontextprotocol/sdk/client/stdio.js'; -import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js'; -import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; +import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js'; +import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js'; +import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { logger } from '@librechat/data-schemas'; import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; import type { MCPOAuthTokens } from './oauth/types'; +import { mcpConfig } from './mcpConfig'; import type * as t from './types'; function isStdioOptions(options: t.MCPOptions): options is t.StdioOptions { @@ -56,9 +57,17 @@ function isStreamableHTTPOptions(options: t.MCPOptions): options is t.Streamable } const FIVE_MINUTES = 5 * 60 * 1000; + +interface MCPConnectionParams { + serverName: string; + serverConfig: t.MCPOptions; + userId?: string; + oauthTokens?: MCPOAuthTokens | null; +} + export class MCPConnection extends EventEmitter { - private static instance: MCPConnection | null = null; public client: Client; + private options: t.MCPOptions; private transport: Transport | null = null; // Make this nullable private connectionState: t.ConnectionState = 'disconnected'; private connectPromise: Promise | null = null; @@ -70,26 +79,23 @@ export class MCPConnection extends EventEmitter { private reconnectAttempts = 0; private readonly userId?: string; private lastPingTime: number; + private lastConnectionCheckAt: number = 0; private oauthTokens?: MCPOAuthTokens | null; private oauthRequired = false; iconPath?: string; timeout?: number; url?: string; - constructor( - serverName: string, - private readonly options: t.MCPOptions, - userId?: string, - oauthTokens?: MCPOAuthTokens | null, - ) { + constructor(params: MCPConnectionParams) { super(); - this.serverName = serverName; - this.userId = userId; - this.iconPath = options.iconPath; - this.timeout = options.timeout; + this.options = params.serverConfig; + this.serverName = params.serverName; + this.userId = params.userId; + this.iconPath = params.serverConfig.iconPath; + this.timeout = params.serverConfig.timeout; this.lastPingTime = Date.now(); - if (oauthTokens) { - this.oauthTokens = oauthTokens; + if (params.oauthTokens) { + this.oauthTokens = params.oauthTokens; } this.client = new Client( { @@ -110,28 +116,6 @@ export class MCPConnection extends EventEmitter { return `[MCP]${userPart}[${this.serverName}]`; } - public static getInstance( - serverName: string, - options: t.MCPOptions, - userId?: string, - ): MCPConnection { - if (!MCPConnection.instance) { - MCPConnection.instance = new MCPConnection(serverName, options, userId); - } - return MCPConnection.instance; - } - - public static getExistingInstance(): MCPConnection | null { - return MCPConnection.instance; - } - - public static async destroyInstance(): Promise { - if (MCPConnection.instance) { - await MCPConnection.instance.disconnect(); - MCPConnection.instance = null; - } - } - private emitError(error: unknown, errorContext: string): void { const errorMessage = error instanceof Error ? error.message : String(error); logger.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`); @@ -589,6 +573,13 @@ export class MCPConnection extends EventEmitter { return false; } + // If we recently checked, skip expensive verification + const now = Date.now(); + if (now - this.lastConnectionCheckAt < mcpConfig.CONNECTION_CHECK_TTL) { + return true; + } + this.lastConnectionCheckAt = now; + try { // Try ping first as it's the lightest check await this.client.ping(); diff --git a/packages/api/src/mcp/manager.ts b/packages/api/src/mcp/manager.ts deleted file mode 100644 index 6db52ec9e..000000000 --- a/packages/api/src/mcp/manager.ts +++ /dev/null @@ -1,1143 +0,0 @@ -import { logger } from '@librechat/data-schemas'; -import { CallToolResultSchema, ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js'; -import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js'; -import type { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.js'; -import type { TokenMethods } from '@librechat/data-schemas'; -import type { TUser } from 'librechat-data-provider'; -import type { MCPOAuthTokens, MCPOAuthFlowMetadata } from './oauth/types'; -import type { FlowStateManager } from '~/flow/manager'; -import type { JsonSchemaType } from '~/types/zod'; -import type { FlowMetadata } from '~/flow/types'; -import type * as t from './types'; -import { CONSTANTS, isSystemUserId } from './enum'; -import { MCPOAuthHandler } from './oauth/handler'; -import { MCPTokenStorage } from './oauth/tokens'; -import { formatToolContent } from './parsers'; -import { MCPConnection } from './connection'; -import { processMCPEnv } from '~/utils/env'; - -export class MCPManager { - private static instance: MCPManager | null = null; - /** App-level connections initialized at startup */ - private connections: Map = new Map(); - /** User-specific connections initialized on demand */ - private userConnections: Map> = new Map(); - /** Last activity timestamp for users (not per server) */ - private userLastActivity: Map = new Map(); - private readonly USER_CONNECTION_IDLE_TIMEOUT = 15 * 60 * 1000; // 15 minutes (TODO: make configurable) - private mcpConfigs: t.MCPServers = {}; - /** Store MCP server instructions */ - private serverInstructions: Map = new Map(); - /** Track servers that required OAuth at startup */ - private oauthServers: Set = new Set(); - - public static getInstance(): MCPManager { - if (!MCPManager.instance) { - MCPManager.instance = new MCPManager(); - } - // Check for idle connections when getInstance is called - MCPManager.instance.checkIdleConnections(); - return MCPManager.instance; - } - - /** Stores configs and initializes app-level connections */ - public async initializeMCPs({ - mcpServers, - flowManager, - tokenMethods, - }: { - mcpServers: t.MCPServers; - flowManager: FlowStateManager; - tokenMethods?: TokenMethods; - }): Promise { - this.mcpConfigs = mcpServers; - - if (!flowManager) { - logger.info('[MCP] No flow manager provided, OAuth will not be available'); - } - - if (!tokenMethods) { - logger.info('[MCP] No token methods provided, token persistence will not be available'); - } - const entries = Object.entries(mcpServers); - const initializedServers = new Set(); - const connectionResults = await Promise.allSettled( - entries.map(async ([serverName, config], i) => { - try { - await this.initializeMCP({ - serverName, - config, - flowManager, - tokenMethods, - }); - initializedServers.add(i); - } catch (error) { - logger.error(`[MCP][${serverName}] Initialization failed`, error); - } - }), - ); - - const failedConnections = connectionResults.filter( - (result): result is PromiseRejectedResult => result.status === 'rejected', - ); - - logger.info( - `[MCP] Initialized ${initializedServers.size}/${entries.length} app-level server(s)`, - ); - - if (failedConnections.length > 0) { - logger.warn( - `[MCP] ${failedConnections.length}/${entries.length} app-level server(s) failed to initialize`, - ); - } - - entries.forEach(([serverName], index) => { - if (initializedServers.has(index)) { - logger.info(`[MCP][${serverName}] ✓ Initialized`); - } else { - logger.info(`[MCP][${serverName}] ✗ Failed`); - } - }); - - if (initializedServers.size === entries.length) { - logger.info('[MCP] All app-level servers initialized successfully'); - } else if (initializedServers.size === 0) { - logger.warn('[MCP] No app-level servers initialized'); - } - } - - /** Initializes a single MCP server connection (app-level) */ - public async initializeMCP({ - serverName, - config, - flowManager, - tokenMethods, - }: { - serverName: string; - config: t.MCPOptions; - flowManager: FlowStateManager; - tokenMethods?: TokenMethods; - }): Promise { - const processedConfig = processMCPEnv(config); - let tokens: MCPOAuthTokens | null = null; - if (tokenMethods?.findToken) { - try { - /** Refresh function for app-level connections */ - const refreshTokensFunction = async ( - refreshToken: string, - metadata: { - userId: string; - serverName: string; - identifier: string; - clientInfo?: OAuthClientInformation; - }, - ) => { - const serverUrl = (processedConfig as t.SSEOptions | t.StreamableHTTPOptions).url; - return await MCPOAuthHandler.refreshOAuthTokens( - refreshToken, - { - serverName: metadata.serverName, - serverUrl, - clientInfo: metadata.clientInfo, - }, - processedConfig.oauth, - ); - }; - - /** Flow state to prevent concurrent token operations */ - const tokenFlowId = `tokens:${CONSTANTS.SYSTEM_USER_ID}:${serverName}`; - tokens = await flowManager.createFlowWithHandler( - tokenFlowId, - 'mcp_get_tokens', - async () => { - return await MCPTokenStorage.getTokens({ - userId: CONSTANTS.SYSTEM_USER_ID, - serverName, - findToken: tokenMethods.findToken, - refreshTokens: refreshTokensFunction, - createToken: tokenMethods.createToken, - updateToken: tokenMethods.updateToken, - }); - }, - ); - } catch { - logger.debug(`[MCP][${serverName}] No existing tokens found`); - } - } - if (tokens) { - logger.info(`[MCP][${serverName}] Loaded OAuth tokens`); - } - const connection = new MCPConnection(serverName, processedConfig, undefined, tokens); - logger.info(`[MCP][${serverName}] Setting up OAuth event listener`); - connection.on('oauthRequired', async () => { - logger.debug(`[MCP][${serverName}] oauthRequired event received`); - - this.oauthServers.add(serverName); - - // Skip OAuth at startup - let connection fail gracefully - logger.info(`[MCP][${serverName}] OAuth required, skipping at startup`); - connection.emit('oauthFailed', new Error('OAuth authentication skipped at startup')); - }); - try { - const connectTimeout = processedConfig.initTimeout ?? 30000; - const connectionTimeout = new Promise((_, reject) => - setTimeout( - () => reject(new Error(`Connection timeout after ${connectTimeout}ms`)), - connectTimeout, - ), - ); - const connectionAttempt = this.initializeServer({ - connection, - logPrefix: `[MCP][${serverName}]`, - flowManager, - handleOAuth: false, - }); - await Promise.race([connectionAttempt, connectionTimeout]); - if (await connection.isConnected()) { - this.connections.set(serverName, connection); - - /** Unified `serverInstructions` configuration */ - const configInstructions = processedConfig.serverInstructions; - if (configInstructions !== undefined) { - if (typeof configInstructions === 'string') { - this.serverInstructions.set(serverName, configInstructions); - logger.info( - `[MCP][${serverName}] Custom instructions stored for context inclusion: ${configInstructions}`, - ); - } else if (configInstructions === true) { - /** Server-provided instructions */ - const serverInstructions = connection.client.getInstructions(); - - if (serverInstructions) { - this.serverInstructions.set(serverName, serverInstructions); - logger.info( - `[MCP][${serverName}] Server instructions stored for context inclusion: ${serverInstructions}`, - ); - } else { - logger.info( - `[MCP][${serverName}] serverInstructions=true but no server instructions available`, - ); - } - } else { - logger.info( - `[MCP][${serverName}] Instructions explicitly disabled (serverInstructions=false)`, - ); - } - } else { - logger.info( - `[MCP][${serverName}] Instructions not included (serverInstructions not configured)`, - ); - } - - const serverCapabilities = connection.client.getServerCapabilities(); - logger.info(`[MCP][${serverName}] Capabilities: ${JSON.stringify(serverCapabilities)}`); - - if (serverCapabilities?.tools) { - const tools = await connection.client.listTools(); - if (tools.tools.length) { - logger.info( - `[MCP][${serverName}] Available tools: ${tools.tools - .map((tool) => tool.name) - .join(', ')}`, - ); - } - } - logger.info(`[MCP][${serverName}] ✓ Initialized`); - } else { - logger.info(`[MCP][${serverName}] ✗ Failed`); - } - } catch (error) { - logger.error(`[MCP][${serverName}] Initialization failed`, error); - throw error; - } - } - - /** Generic server initialization logic */ - private async initializeServer({ - connection, - logPrefix, - flowManager, - handleOAuth = true, - }: { - connection: MCPConnection; - logPrefix: string; - flowManager: FlowStateManager; - handleOAuth?: boolean; - }): Promise { - const maxAttempts = 3; - let attempts = 0; - /** Whether OAuth has been handled by the connection */ - let oauthHandled = false; - - while (attempts < maxAttempts) { - try { - await connection.connect(); - if (await connection.isConnected()) { - return; - } - throw new Error('Connection attempt succeeded but status is not connected'); - } catch (error) { - attempts++; - - if (this.isOAuthError(error)) { - // Only handle OAuth if requested (not already handled by event listener) - if (handleOAuth) { - /** Check if OAuth was already handled by the connection */ - const errorWithFlag = error as (Error & { isOAuthError?: boolean }) | undefined; - if (!oauthHandled && errorWithFlag?.isOAuthError) { - oauthHandled = true; - logger.info(`${logPrefix} Handling OAuth`); - const serverUrl = connection.url; - if (serverUrl) { - await this.handleOAuthRequired({ - serverName: connection.serverName, - serverUrl, - flowManager, - }); - } - } else { - logger.info(`${logPrefix} OAuth already handled by connection`); - } - } - // Don't retry on OAuth errors - just throw - logger.info(`${logPrefix} OAuth required, stopping connection attempts`); - throw error; - } - - if (attempts === maxAttempts) { - logger.error(`${logPrefix} Failed to connect after ${maxAttempts} attempts`, error); - throw error; // Re-throw the last error - } - await new Promise((resolve) => setTimeout(resolve, 2000 * attempts)); - } - } - } - - private isOAuthError(error: unknown): boolean { - if (!error || typeof error !== 'object') { - return false; - } - - // Check for SSE error with 401 status - if ('message' in error && typeof error.message === 'string') { - return error.message.includes('401') || error.message.includes('Non-200 status code (401)'); - } - - // Check for error code - if ('code' in error) { - const code = (error as { code?: number }).code; - return code === 401 || code === 403; - } - - return false; - } - - /** Check for and disconnect idle connections */ - private checkIdleConnections(currentUserId?: string): void { - const now = Date.now(); - - // Iterate through all users to check for idle ones - for (const [userId, lastActivity] of this.userLastActivity.entries()) { - if (currentUserId && currentUserId === userId) { - continue; - } - if (now - lastActivity > this.USER_CONNECTION_IDLE_TIMEOUT) { - logger.info( - `[MCP][User: ${userId}] User idle for too long. Disconnecting all connections...`, - ); - // Disconnect all user connections asynchronously (fire and forget) - this.disconnectUserConnections(userId).catch((err) => - logger.error(`[MCP][User: ${userId}] Error disconnecting idle connections:`, err), - ); - } - } - } - - /** Updates the last activity timestamp for a user */ - private updateUserLastActivity(userId: string): void { - const now = Date.now(); - this.userLastActivity.set(userId, now); - logger.debug( - `[MCP][User: ${userId}] Updated last activity timestamp: ${new Date(now).toISOString()}`, - ); - } - - /** Gets or creates a connection for a specific user */ - public async getUserConnection({ - user, - serverName, - flowManager, - customUserVars, - tokenMethods, - oauthStart, - oauthEnd, - signal, - returnOnOAuth = false, - }: { - user: TUser; - serverName: string; - flowManager: FlowStateManager; - customUserVars?: Record; - tokenMethods?: TokenMethods; - oauthStart?: (authURL: string) => Promise; - oauthEnd?: () => Promise; - signal?: AbortSignal; - returnOnOAuth?: boolean; - }): Promise { - const userId = user.id; - if (!userId) { - throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`); - } - - const userServerMap = this.userConnections.get(userId); - let connection = userServerMap?.get(serverName); - const now = Date.now(); - - // Check if user is idle - const lastActivity = this.userLastActivity.get(userId); - if (lastActivity && now - lastActivity > this.USER_CONNECTION_IDLE_TIMEOUT) { - logger.info(`[MCP][User: ${userId}] User idle for too long. Disconnecting all connections.`); - // Disconnect all user connections - try { - await this.disconnectUserConnections(userId); - } catch (err) { - logger.error(`[MCP][User: ${userId}] Error disconnecting idle connections:`, err); - } - connection = undefined; // Force creation of a new connection - } else if (connection) { - if (await connection.isConnected()) { - logger.debug(`[MCP][User: ${userId}][${serverName}] Reusing active connection`); - this.updateUserLastActivity(userId); - return connection; - } else { - // Connection exists but is not connected, attempt to remove potentially stale entry - logger.warn( - `[MCP][User: ${userId}][${serverName}] Found existing but disconnected connection object. Cleaning up.`, - ); - this.removeUserConnection(userId, serverName); // Clean up maps - connection = undefined; - } - } - - // If no valid connection exists, create a new one - if (!connection) { - logger.info(`[MCP][User: ${userId}][${serverName}] Establishing new connection`); - } - - let config = this.mcpConfigs[serverName]; - if (!config) { - throw new McpError( - ErrorCode.InvalidRequest, - `[MCP][User: ${userId}] Configuration for server "${serverName}" not found.`, - ); - } - - config = { ...(processMCPEnv(config, user, customUserVars) ?? {}) }; - /** If no in-memory tokens, tokens from persistent storage */ - let tokens: MCPOAuthTokens | null = null; - if (tokenMethods?.findToken) { - try { - /** Refresh function for user-specific connections */ - const refreshTokensFunction = async ( - refreshToken: string, - metadata: { - userId: string; - serverName: string; - identifier: string; - clientInfo?: OAuthClientInformation; - }, - ) => { - /** URL from config since connection doesn't exist yet */ - const serverUrl = (config as t.SSEOptions | t.StreamableHTTPOptions).url; - return await MCPOAuthHandler.refreshOAuthTokens( - refreshToken, - { - serverName: metadata.serverName, - serverUrl, - clientInfo: metadata.clientInfo, - }, - config.oauth, - ); - }; - - /** Flow state to prevent concurrent token operations */ - const tokenFlowId = `tokens:${userId}:${serverName}`; - tokens = await flowManager.createFlowWithHandler( - tokenFlowId, - 'mcp_get_tokens', - async () => { - return await MCPTokenStorage.getTokens({ - userId, - serverName, - findToken: tokenMethods.findToken, - refreshTokens: refreshTokensFunction, - createToken: tokenMethods.createToken, - updateToken: tokenMethods.updateToken, - }); - }, - signal, - ); - } catch (error) { - logger.error( - `[MCP][User: ${userId}][${serverName}] Error loading OAuth tokens from storage`, - error, - ); - } - } - - if (tokens) { - logger.info(`[MCP][User: ${userId}][${serverName}] Loaded OAuth tokens`); - } - - connection = new MCPConnection(serverName, config, userId, tokens); - - connection.on('oauthRequired', async (data) => { - logger.info(`[MCP][User: ${userId}][${serverName}] oauthRequired event received`); - - // If we just want to initiate OAuth and return, handle it differently - if (returnOnOAuth) { - try { - const config = this.mcpConfigs[serverName]; - const { authorizationUrl, flowId, flowMetadata } = - await MCPOAuthHandler.initiateOAuthFlow( - serverName, - data.serverUrl || '', - userId, - config?.oauth, - ); - - // Create the flow state so the OAuth callback can find it - // We spawn this in the background without waiting for it - flowManager.createFlow(flowId, 'mcp_oauth', flowMetadata).catch(() => { - // The OAuth callback will resolve this flow, so we expect it to timeout here - // which is fine - we just need the flow state to exist - }); - - if (oauthStart) { - logger.info( - `[MCP][User: ${userId}][${serverName}] OAuth flow started, issuing authorization URL`, - ); - await oauthStart(authorizationUrl); - } - - // Emit oauthFailed to signal that connection should not proceed - // but OAuth was successfully initiated - connection?.emit('oauthFailed', new Error('OAuth flow initiated - return early')); - return; - } catch (error) { - logger.error( - `[MCP][User: ${userId}][${serverName}] Failed to initiate OAuth flow`, - error, - ); - connection?.emit('oauthFailed', new Error('OAuth initiation failed')); - return; - } - } - - // Normal OAuth handling - wait for completion - const result = await this.handleOAuthRequired({ - ...data, - flowManager, - oauthStart, - oauthEnd, - }); - - if (result?.tokens && tokenMethods?.createToken) { - try { - connection?.setOAuthTokens(result.tokens); - await MCPTokenStorage.storeTokens({ - userId, - serverName, - tokens: result.tokens, - createToken: tokenMethods.createToken, - updateToken: tokenMethods.updateToken, - findToken: tokenMethods.findToken, - clientInfo: result.clientInfo, - }); - logger.info(`[MCP][User: ${userId}][${serverName}] OAuth tokens saved to storage`); - } catch (error) { - logger.error( - `[MCP][User: ${userId}][${serverName}] Failed to save OAuth tokens to storage`, - error, - ); - } - } - - // Only emit oauthHandled if we actually got tokens (OAuth succeeded) - if (result?.tokens) { - connection?.emit('oauthHandled'); - } else { - // OAuth failed, emit oauthFailed to properly reject the promise - logger.warn( - `[MCP][User: ${userId}][${serverName}] OAuth failed, emitting oauthFailed event`, - ); - connection?.emit('oauthFailed', new Error('OAuth authentication failed')); - } - }); - - try { - const connectTimeout = config.initTimeout ?? 30000; - const connectionTimeout = new Promise((_, reject) => - setTimeout( - () => reject(new Error(`Connection timeout after ${connectTimeout}ms`)), - connectTimeout, - ), - ); - const connectionAttempt = this.initializeServer({ - connection, - logPrefix: `[MCP][User: ${userId}][${serverName}]`, - flowManager, - }); - await Promise.race([connectionAttempt, connectionTimeout]); - - if (!(await connection?.isConnected())) { - throw new Error('Failed to establish connection after initialization attempt.'); - } - - if (!this.userConnections.has(userId)) { - this.userConnections.set(userId, new Map()); - } - this.userConnections.get(userId)?.set(serverName, connection); - - logger.info(`[MCP][User: ${userId}][${serverName}] Connection successfully established`); - // Update timestamp on creation - this.updateUserLastActivity(userId); - return connection; - } catch (error) { - logger.error(`[MCP][User: ${userId}][${serverName}] Failed to establish connection`, error); - // Ensure partial connection state is cleaned up if initialization fails - await connection?.disconnect().catch((disconnectError) => { - logger.error( - `[MCP][User: ${userId}][${serverName}] Error during cleanup after failed connection`, - disconnectError, - ); - }); - // Ensure cleanup even if connection attempt fails - this.removeUserConnection(userId, serverName); - throw error; // Re-throw the error to the caller - } - } - - /** Removes a specific user connection entry */ - private removeUserConnection(userId: string, serverName: string): void { - const userMap = this.userConnections.get(userId); - if (userMap) { - userMap.delete(serverName); - if (userMap.size === 0) { - this.userConnections.delete(userId); - // Only remove user activity timestamp if all connections are gone - this.userLastActivity.delete(userId); - } - } - - logger.debug(`[MCP][User: ${userId}][${serverName}] Removed connection entry.`); - } - - /** Disconnects and removes a specific user connection */ - public async disconnectUserConnection(userId: string, serverName: string): Promise { - const userMap = this.userConnections.get(userId); - const connection = userMap?.get(serverName); - if (connection) { - logger.info(`[MCP][User: ${userId}][${serverName}] Disconnecting...`); - await connection.disconnect(); - this.removeUserConnection(userId, serverName); - } - } - - /** Disconnects and removes all connections for a specific user */ - public async disconnectUserConnections(userId: string): Promise { - const userMap = this.userConnections.get(userId); - const disconnectPromises: Promise[] = []; - if (userMap) { - logger.info(`[MCP][User: ${userId}] Disconnecting all servers...`); - const userServers = Array.from(userMap.keys()); - for (const serverName of userServers) { - disconnectPromises.push( - this.disconnectUserConnection(userId, serverName).catch((error) => { - logger.error( - `[MCP][User: ${userId}][${serverName}] Error during disconnection:`, - error, - ); - }), - ); - } - await Promise.allSettled(disconnectPromises); - // Ensure user activity timestamp is removed - this.userLastActivity.delete(userId); - logger.info(`[MCP][User: ${userId}] All connections processed for disconnection.`); - } - } - - /** Returns the app-level connection (used for mapping tools, etc.) */ - public getConnection(serverName: string): MCPConnection | undefined { - return this.connections.get(serverName); - } - - /** Returns all app-level connections */ - public getAllConnections(): Map { - return this.connections; - } - - /** Attempts to reconnect an app-level connection if it's disconnected */ - private async isConnectionActive({ - serverName, - connection, - flowManager, - skipReconnect = false, - }: { - serverName: string; - connection: MCPConnection; - flowManager: FlowStateManager; - skipReconnect?: boolean; - }): Promise { - if (await connection.isConnected()) { - return true; - } - - if (skipReconnect) { - logger.warn( - `[MCP][${serverName}] App-level connection is disconnected, skipping reconnection attempt`, - ); - return false; - } - - logger.warn( - `[MCP][${serverName}] App-level connection disconnected, attempting to reconnect...`, - ); - - try { - const config = this.mcpConfigs[serverName]; - if (!config) { - logger.error(`[MCP][${serverName}] Configuration not found for reconnection`); - return false; - } - - await this.initializeServer({ - connection, - logPrefix: `[MCP][${serverName}]`, - flowManager, - }); - - if (await connection.isConnected()) { - logger.info(`[MCP][${serverName}] App-level connection successfully reconnected`); - return true; - } else { - logger.warn(`[MCP][${serverName}] App-level connection reconnection failed`); - return false; - } - } catch (error) { - logger.error(`[MCP][${serverName}] Error during app-level connection reconnection:`, error); - return false; - } - } - - /** - * Maps available tools from all app-level connections into the provided object. - * The object is modified in place. - */ - public async mapAvailableTools( - availableTools: t.LCAvailableTools, - flowManager: FlowStateManager, - ): Promise { - for (const [serverName, connection] of this.connections.entries()) { - try { - /** Attempt to ensure connection is active, with reconnection if needed */ - const isActive = await this.isConnectionActive({ serverName, connection, flowManager }); - if (!isActive) { - logger.warn(`[MCP][${serverName}] Connection not available. Skipping tool mapping.`); - continue; - } - - const tools = await connection.fetchTools(); - for (const tool of tools) { - const name = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`; - availableTools[name] = { - type: 'function', - ['function']: { - name, - description: tool.description, - parameters: tool.inputSchema as JsonSchemaType, - }, - }; - } - } catch (error) { - logger.warn(`[MCP][${serverName}] Error fetching tools`, error); - } - } - } - - /** - * Loads tools from all app-level connections into the manifest. - */ - public async loadManifestTools({ - flowManager, - serverToolsCallback, - getServerTools, - }: { - flowManager: FlowStateManager; - serverToolsCallback?: (serverName: string, tools: t.LCManifestTool[]) => Promise; - getServerTools?: (serverName: string) => Promise; - }): Promise { - const mcpTools: t.LCManifestTool[] = []; - for (const [serverName, connection] of this.connections.entries()) { - try { - /** Attempt to ensure connection is active, with reconnection if needed */ - const isActive = await this.isConnectionActive({ - serverName, - connection, - flowManager, - skipReconnect: true, - }); - if (!isActive) { - logger.warn( - `[MCP][${serverName}] Connection not available for ${serverName} manifest tools.`, - ); - if (typeof getServerTools !== 'function') { - logger.warn( - `[MCP][${serverName}] No \`getServerTools\` function provided, skipping tool loading.`, - ); - continue; - } - const serverTools = await getServerTools(serverName); - if (serverTools && serverTools.length > 0) { - logger.info(`[MCP][${serverName}] Loaded tools from cache for manifest`); - mcpTools.push(...serverTools); - } - continue; - } - - const tools = await connection.fetchTools(); - const serverTools: t.LCManifestTool[] = []; - for (const tool of tools) { - const pluginKey = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`; - - const config = this.mcpConfigs[serverName]; - const manifestTool: t.LCManifestTool = { - name: tool.name, - pluginKey, - description: tool.description ?? '', - icon: connection.iconPath, - authConfig: config?.customUserVars - ? Object.entries(config.customUserVars).map(([key, value]) => ({ - authField: key, - label: value.title || key, - description: value.description || '', - })) - : undefined, - }; - if (config?.chatMenu === false) { - manifestTool.chatMenu = false; - } - mcpTools.push(manifestTool); - serverTools.push(manifestTool); - } - if (typeof serverToolsCallback === 'function') { - await serverToolsCallback(serverName, serverTools); - } - } catch (error) { - logger.error(`[MCP][${serverName}] Error fetching tools for manifest:`, error); - } - } - - return mcpTools; - } - - /** - * Calls a tool on an MCP server, using either a user-specific connection - * (if userId is provided) or an app-level connection. Updates the last activity timestamp - * for user-specific connections upon successful call initiation. - */ - async callTool({ - user, - serverName, - toolName, - provider, - toolArguments, - options, - tokenMethods, - flowManager, - oauthStart, - oauthEnd, - customUserVars, - }: { - user?: TUser; - serverName: string; - toolName: string; - provider: t.Provider; - toolArguments?: Record; - options?: RequestOptions; - tokenMethods?: TokenMethods; - customUserVars?: Record; - flowManager: FlowStateManager; - oauthStart?: (authURL: string) => Promise; - oauthEnd?: () => Promise; - }): Promise { - /** User-specific connection */ - let connection: MCPConnection | undefined; - const userId = user?.id; - const logPrefix = userId ? `[MCP][User: ${userId}][${serverName}]` : `[MCP][${serverName}]`; - - try { - if (userId && user) { - this.updateUserLastActivity(userId); - /** Get or create user-specific connection */ - connection = await this.getUserConnection({ - user, - serverName, - flowManager, - tokenMethods, - oauthStart, - oauthEnd, - signal: options?.signal, - customUserVars, - }); - } else { - /** App-level connection */ - connection = this.connections.get(serverName); - if (!connection) { - throw new McpError( - ErrorCode.InvalidRequest, - `${logPrefix} No app-level connection found. Cannot execute tool ${toolName}.`, - ); - } - } - - if (!(await connection.isConnected())) { - /** May happen if getUserConnection failed silently or app connection dropped */ - throw new McpError( - ErrorCode.InternalError, // Use InternalError for connection issues - `${logPrefix} Connection is not active. Cannot execute tool ${toolName}.`, - ); - } - - const result = await connection.client.request( - { - method: 'tools/call', - params: { - name: toolName, - arguments: toolArguments, - }, - }, - CallToolResultSchema, - { - timeout: connection.timeout, - ...options, - }, - ); - if (userId) { - this.updateUserLastActivity(userId); - } - this.checkIdleConnections(); - return formatToolContent(result as t.MCPToolCallResponse, provider); - } catch (error) { - // Log with context and re-throw or handle as needed - logger.error(`${logPrefix}[${toolName}] Tool call failed`, error); - // Rethrowing allows the caller (createMCPTool) to handle the final user message - throw error; - } - } - - /** Disconnects a specific app-level server */ - public async disconnectServer(serverName: string): Promise { - const connection = this.connections.get(serverName); - if (connection) { - logger.info(`[MCP][${serverName}] Disconnecting...`); - await connection.disconnect(); - this.connections.delete(serverName); - } - } - - /** Disconnects all app-level and user-level connections */ - public async disconnectAll(): Promise { - logger.info('[MCP] Disconnecting all app-level and user-level connections...'); - - const userDisconnectPromises = Array.from(this.userConnections.keys()).map((userId) => - this.disconnectUserConnections(userId), - ); - await Promise.allSettled(userDisconnectPromises); - this.userLastActivity.clear(); - - // Disconnect all app-level connections - const appDisconnectPromises = Array.from(this.connections.values()).map((connection) => - connection.disconnect().catch((error) => { - logger.error(`[MCP][${connection.serverName}] Error during disconnectAll:`, error); - }), - ); - await Promise.allSettled(appDisconnectPromises); - this.connections.clear(); - - logger.info('[MCP] All connections processed for disconnection.'); - } - - /** Destroys the singleton instance and disconnects all connections */ - public static async destroyInstance(): Promise { - if (MCPManager.instance) { - await MCPManager.instance.disconnectAll(); - MCPManager.instance = null; - logger.info('[MCP] Manager instance destroyed.'); - } - } - - /** - * Get instructions for MCP servers - * @param serverNames Optional array of server names. If not provided or empty, returns all servers. - * @returns Object mapping server names to their instructions - */ - public getInstructions(serverNames?: string[]): Record { - const instructions: Record = {}; - - if (!serverNames || serverNames.length === 0) { - // Return all instructions if no specific servers requested - for (const [serverName, serverInstructions] of this.serverInstructions.entries()) { - instructions[serverName] = serverInstructions; - } - } else { - // Return instructions for specific servers - for (const serverName of serverNames) { - const serverInstructions = this.serverInstructions.get(serverName); - if (serverInstructions) { - instructions[serverName] = serverInstructions; - } - } - } - - return instructions; - } - - /** - * Format MCP server instructions for injection into context - * @param serverNames Optional array of server names to include. If not provided, includes all servers. - * @returns Formatted instructions string ready for context injection - */ - public formatInstructionsForContext(serverNames?: string[]): string { - /** Instructions for specified servers or all stored instructions */ - const instructionsToInclude = this.getInstructions(serverNames); - - if (Object.keys(instructionsToInclude).length === 0) { - return ''; - } - - // Format instructions for context injection - const formattedInstructions = Object.entries(instructionsToInclude) - .map(([serverName, instructions]) => { - return `## ${serverName} MCP Server Instructions - -${instructions}`; - }) - .join('\n\n'); - - return `# MCP Server Instructions - -The following MCP servers are available with their specific instructions: - -${formattedInstructions} - -Please follow these instructions when using tools from the respective MCP servers.`; - } - - /** Handles OAuth authentication requirements */ - private async handleOAuthRequired({ - serverName, - serverUrl, - flowManager, - userId = CONSTANTS.SYSTEM_USER_ID, - oauthStart, - oauthEnd, - }: { - serverName: string; - flowManager: FlowStateManager; - userId?: string; - serverUrl?: string; - oauthStart?: (authURL: string) => Promise; - oauthEnd?: () => Promise; - }): Promise<{ tokens: MCPOAuthTokens | null; clientInfo?: OAuthClientInformation } | null> { - const userPart = isSystemUserId(userId) ? '' : `[User: ${userId}]`; - const logPrefix = `[MCP]${userPart}[${serverName}]`; - logger.debug(`${logPrefix} \`handleOAuthRequired\` called with serverUrl: ${serverUrl}`); - - if (!flowManager || !serverUrl) { - logger.error( - `${logPrefix} OAuth required but flow manager not available or server URL missing for ${serverName}`, - ); - logger.warn(`${logPrefix} Please configure OAuth credentials for ${serverName}`); - return null; - } - - try { - const config = this.mcpConfigs[serverName]; - logger.debug(`${logPrefix} Checking for existing OAuth flow for ${serverName}...`); - - /** Flow ID to check if a flow already exists */ - const flowId = MCPOAuthHandler.generateFlowId(userId, serverName); - - /** Check if there's already an ongoing OAuth flow for this flowId */ - const existingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth'); - if (existingFlow && existingFlow.status === 'PENDING') { - logger.debug( - `${logPrefix} OAuth flow already exists for ${flowId}, waiting for completion`, - ); - /** Tokens from existing flow to complete */ - const tokens = await flowManager.createFlow(flowId, 'mcp_oauth'); - if (typeof oauthEnd === 'function') { - await oauthEnd(); - } - logger.info(`${logPrefix} OAuth flow completed, tokens received for ${serverName}`); - - /** Client information from the existing flow metadata */ - const existingMetadata = existingFlow.metadata as unknown as MCPOAuthFlowMetadata; - const clientInfo = existingMetadata?.clientInfo; - - return { tokens, clientInfo }; - } - - logger.debug(`${logPrefix} Initiating new OAuth flow for ${serverName}...`); - const { - authorizationUrl, - flowId: newFlowId, - flowMetadata, - } = await MCPOAuthHandler.initiateOAuthFlow(serverName, serverUrl, userId, config?.oauth); - - if (typeof oauthStart === 'function') { - logger.info(`${logPrefix} OAuth flow started, issued authorization URL to user`); - await oauthStart(authorizationUrl); - } else { - logger.info(` -═══════════════════════════════════════════════════════════════════════ -Please visit the following URL to authenticate: - -${authorizationUrl} - -${logPrefix} Flow ID: ${newFlowId} -═══════════════════════════════════════════════════════════════════════ -`); - } - - /** Tokens from the new flow */ - const tokens = await flowManager.createFlow( - newFlowId, - 'mcp_oauth', - flowMetadata as FlowMetadata, - ); - if (typeof oauthEnd === 'function') { - await oauthEnd(); - } - logger.info(`${logPrefix} OAuth flow completed, tokens received for ${serverName}`); - - /** Client information from the flow metadata */ - const clientInfo = flowMetadata?.clientInfo; - - return { tokens, clientInfo }; - } catch (error) { - logger.error(`${logPrefix} Failed to complete OAuth flow for ${serverName}`, error); - return null; - } - } - public getUserConnections(userId: string) { - return this.userConnections.get(userId); - } - - /** Get servers that require OAuth */ - public getOAuthServers(): Set { - return this.oauthServers; - } -} diff --git a/packages/api/src/mcp/mcpConfig.ts b/packages/api/src/mcp/mcpConfig.ts new file mode 100644 index 000000000..4d0f2b8ca --- /dev/null +++ b/packages/api/src/mcp/mcpConfig.ts @@ -0,0 +1,11 @@ +import { math, isEnabled } from '~/utils'; + +/** + * Centralized configuration for MCP-related environment variables. + * Provides typed access to MCP settings with default values. + */ +export const mcpConfig = { + OAUTH_ON_AUTH_ERROR: isEnabled(process.env.MCP_OAUTH_ON_AUTH_ERROR ?? true), + OAUTH_DETECTION_TIMEOUT: math(process.env.MCP_OAUTH_DETECTION_TIMEOUT ?? 5000), + CONNECTION_CHECK_TTL: math(process.env.MCP_CONNECTION_CHECK_TTL ?? 60000), +}; diff --git a/packages/api/src/mcp/oauth/detectOAuth.ts b/packages/api/src/mcp/oauth/detectOAuth.ts new file mode 100644 index 000000000..f22f5e4cd --- /dev/null +++ b/packages/api/src/mcp/oauth/detectOAuth.ts @@ -0,0 +1,120 @@ +// ATTENTION: If you modify OAuth detection logic in this file, run the integration tests to verify: +// npx jest --testMatch="**/detectOAuth.integration.dev.ts" (from packages/api directory) +// +// These tests are excluded from CI because they make live HTTP requests to external services, +// which could cause flaky builds due to network issues or changes in third-party endpoints. +// Manual testing ensures the OAuth detection still works against real MCP servers. + +import { discoverOAuthProtectedResourceMetadata } from '@modelcontextprotocol/sdk/client/auth.js'; +import { mcpConfig } from '../mcpConfig'; + +export interface OAuthDetectionResult { + requiresOAuth: boolean; + method: 'protected-resource-metadata' | '401-challenge-metadata' | 'no-metadata-found'; + metadata?: Record | null; +} + +/** + * Detects if an MCP server requires OAuth authentication using proactive discovery methods. + * + * This function implements a comprehensive OAuth detection strategy: + * 1. Standard Protected Resource Metadata (RFC 9728) - checks /.well-known/oauth-protected-resource + * 2. 401 Challenge Method - checks WWW-Authenticate header for resource_metadata URL + * 3. Optional fallback: treat any 401/403 response as OAuth requirement (if MCP_OAUTH_ON_AUTH_ERROR=true) + * + * @param serverUrl - The MCP server URL to check for OAuth requirements + * @returns Promise - OAuth requirement details + */ +export async function detectOAuthRequirement(serverUrl: string): Promise { + const protectedResourceResult = await checkProtectedResourceMetadata(serverUrl); + if (protectedResourceResult) return protectedResourceResult; + + const challengeResult = await check401ChallengeMetadata(serverUrl); + if (challengeResult) return challengeResult; + + const fallbackResult = await checkAuthErrorFallback(serverUrl); + if (fallbackResult) return fallbackResult; + + // No OAuth detected + return { + requiresOAuth: false, + method: 'no-metadata-found', + metadata: null, + }; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// ------------------------ Private helper functions for OAuth detection -------------------------// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Checks for OAuth using standard protected resource metadata (RFC 9728) +async function checkProtectedResourceMetadata( + serverUrl: string, +): Promise { + try { + const resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl); + + if (!resourceMetadata?.authorization_servers?.length) return null; + + return { + requiresOAuth: true, + method: 'protected-resource-metadata', + metadata: resourceMetadata, + }; + } catch { + return null; + } +} + +// Checks for OAuth using 401 challenge with resource metadata URL +async function check401ChallengeMetadata(serverUrl: string): Promise { + try { + const response = await fetch(serverUrl, { + method: 'HEAD', + signal: AbortSignal.timeout(mcpConfig.OAUTH_DETECTION_TIMEOUT), + }); + + if (response.status !== 401) return null; + + const wwwAuth = response.headers.get('www-authenticate'); + const metadataUrl = wwwAuth?.match(/resource_metadata="([^"]+)"/)?.[1]; + if (!metadataUrl) return null; + + const metadataResponse = await fetch(metadataUrl, { + signal: AbortSignal.timeout(mcpConfig.OAUTH_DETECTION_TIMEOUT), + }); + const metadata = await metadataResponse.json(); + + if (!metadata?.authorization_servers?.length) return null; + + return { + requiresOAuth: true, + method: '401-challenge-metadata', + metadata, + }; + } catch { + return null; + } +} + +// Fallback method: treats any auth error as OAuth requirement if configured +async function checkAuthErrorFallback(serverUrl: string): Promise { + try { + if (!mcpConfig.OAUTH_ON_AUTH_ERROR) return null; + + const response = await fetch(serverUrl, { + method: 'HEAD', + signal: AbortSignal.timeout(mcpConfig.OAUTH_DETECTION_TIMEOUT), + }); + + if (response.status !== 401 && response.status !== 403) return null; + + return { + requiresOAuth: true, + method: 'no-metadata-found', + metadata: null, + }; + } catch { + return null; + } +} diff --git a/packages/api/src/mcp/oauth/index.ts b/packages/api/src/mcp/oauth/index.ts index ff82bd92f..d9c75071e 100644 --- a/packages/api/src/mcp/oauth/index.ts +++ b/packages/api/src/mcp/oauth/index.ts @@ -1,3 +1,4 @@ export * from './types'; export * from './handler'; export * from './tokens'; +export * from './detectOAuth'; diff --git a/packages/data-provider/src/mcp.ts b/packages/data-provider/src/mcp.ts index b244f7fe5..e908d9f90 100644 --- a/packages/data-provider/src/mcp.ts +++ b/packages/data-provider/src/mcp.ts @@ -3,6 +3,13 @@ import { TokenExchangeMethodEnum } from './types/agents'; import { extractEnvVariable } from './utils'; const BaseOptionsSchema = z.object({ + /** + * Controls whether the MCP server is initialized during application startup. + * - true (default): Server is initialized during app startup and included in app-level connections + * - false: Skips initialization at startup and excludes from app-level connections - useful for servers + * requiring manual authentication (e.g., GitHub PAT tokens) that need to be configured through the UI after startup + */ + startup: z.boolean().optional(), iconPath: z.string().optional(), timeout: z.number().optional(), initTimeout: z.number().optional(), @@ -15,6 +22,11 @@ const BaseOptionsSchema = z.object({ * - string: Use custom instructions (overrides server-provided) */ serverInstructions: z.union([z.boolean(), z.string()]).optional(), + /** + * Whether this server requires OAuth authentication + * If not specified, will be auto-detected during construction + */ + requiresOAuth: z.boolean().optional(), /** * OAuth configuration for SSE and Streamable HTTP transports * - Optional: OAuth can be auto-discovered on 401 responses