diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index 4b86101425..8adb43f945 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -14,7 +14,6 @@ const { buildImageToolContext, buildWebSearchContext, } = require('@librechat/api'); -const { getMCPServersRegistry } = require('~/config'); const { Tools, Constants, @@ -39,12 +38,13 @@ const { createGeminiImageTool, createOpenAIImageTools, } = require('../'); -const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); +const { createMCPTool, createMCPTools, resolveConfigServers } = require('~/server/services/MCP'); const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch'); +const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); const { getUserPluginAuthValue } = require('~/server/services/PluginService'); -const { createMCPTool, createMCPTools } = require('~/server/services/MCP'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { getMCPServerTools } = require('~/server/services/Config'); +const { getMCPServersRegistry } = require('~/config'); const { getRoleByName } = require('~/models'); /** @@ -256,6 +256,12 @@ const loadTools = async ({ const toolContextMap = {}; const requestedMCPTools = {}; + /** Resolve config-source servers for the current user/tenant context */ + let configServers; + if (tools.some((tool) => tool && mcpToolPattern.test(tool))) { + configServers = await resolveConfigServers(options.req); + } + for (const tool of tools) { if (tool === Tools.execute_code) { requestedTools[tool] = async () => { @@ -341,7 +347,7 @@ const loadTools = async ({ continue; } const serverConfig = serverName - ? await getMCPServersRegistry().getServerConfig(serverName, user) + ? await getMCPServersRegistry().getServerConfig(serverName, user, configServers) : null; if (!serverConfig) { logger.warn( @@ -419,6 +425,7 @@ const loadTools = async ({ let index = -1; const failedMCPServers = new Set(); const safeUser = createSafeUser(options.req?.user); + for (const [serverName, toolConfigs] of Object.entries(requestedMCPTools)) { index++; /** @type {LCAvailableTools} */ @@ -433,6 +440,7 @@ const loadTools = async ({ signal, user: safeUser, userMCPAuthMap, + configServers, res: options.res, streamId: options.req?._resumableStreamId || null, model: agent?.model ?? model, diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 47a10165e3..d6795a4be9 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -50,6 +50,7 @@ const { const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const { createContextHandlers } = require('~/app/clients/prompts'); +const { resolveConfigServers } = require('~/server/services/MCP'); const { getMCPServerTools } = require('~/server/services/Config'); const BaseClient = require('~/app/clients/BaseClient'); const { getMCPManager } = require('~/config'); @@ -377,6 +378,9 @@ class AgentClient extends BaseClient { */ const ephemeralAgent = this.options.req.body.ephemeralAgent; const mcpManager = getMCPManager(); + + const configServers = await resolveConfigServers(this.options.req); + await Promise.all( allAgents.map(({ agent, agentId }) => applyContextToAgent({ @@ -384,6 +388,7 @@ class AgentClient extends BaseClient { agentId, logger, mcpManager, + configServers, sharedRunContext, ephemeralAgent: agentId === this.options.agent.id ? ephemeralAgent : undefined, }), diff --git a/api/server/controllers/agents/client.test.js b/api/server/controllers/agents/client.test.js index 41a806f66d..1595f652f7 100644 --- a/api/server/controllers/agents/client.test.js +++ b/api/server/controllers/agents/client.test.js @@ -22,6 +22,10 @@ jest.mock('~/server/services/Config', () => ({ getMCPServerTools: jest.fn(), })); +jest.mock('~/server/services/MCP', () => ({ + resolveConfigServers: jest.fn().mockResolvedValue({}), +})); + jest.mock('~/models', () => ({ getAgent: jest.fn(), getRoleByName: jest.fn(), @@ -1315,7 +1319,7 @@ describe('AgentClient - titleConvo', () => { }); // Verify formatInstructionsForContext was called with correct server names - expect(mockFormatInstructions).toHaveBeenCalledWith(['server1', 'server2']); + expect(mockFormatInstructions).toHaveBeenCalledWith(['server1', 'server2'], {}); // Verify the instructions do NOT contain [object Promise] expect(client.options.agent.instructions).not.toContain('[object Promise]'); @@ -1355,10 +1359,10 @@ describe('AgentClient - titleConvo', () => { }); // Verify formatInstructionsForContext was called with ephemeral server names - expect(mockFormatInstructions).toHaveBeenCalledWith([ - 'ephemeral-server1', - 'ephemeral-server2', - ]); + expect(mockFormatInstructions).toHaveBeenCalledWith( + ['ephemeral-server1', 'ephemeral-server2'], + {}, + ); // Verify no [object Promise] in instructions expect(client.options.agent.instructions).not.toContain('[object Promise]'); diff --git a/api/server/controllers/mcp.js b/api/server/controllers/mcp.js index 729f01da9d..e31bb93bc6 100644 --- a/api/server/controllers/mcp.js +++ b/api/server/controllers/mcp.js @@ -14,6 +14,7 @@ const { isMCPInspectionFailedError, } = require('@librechat/api'); const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider'); +const { resolveConfigServers, resolveAllMcpConfigs } = require('~/server/services/MCP'); const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config'); const { getMCPManager, getMCPServersRegistry } = require('~/config'); @@ -57,7 +58,7 @@ function handleMCPError(error, res) { } /** - * Get all MCP tools available to the user + * Get all MCP tools available to the user. */ const getMCPTools = async (req, res) => { try { @@ -67,10 +68,10 @@ const getMCPTools = async (req, res) => { return res.status(401).json({ message: 'Unauthorized' }); } - const mcpConfig = await getMCPServersRegistry().getAllServerConfigs(userId); - const configuredServers = mcpConfig ? Object.keys(mcpConfig) : []; + const mcpConfig = await resolveAllMcpConfigs(userId, req.user); + const configuredServers = Object.keys(mcpConfig); - if (!mcpConfig || Object.keys(mcpConfig).length == 0) { + if (!configuredServers.length) { return res.status(200).json({ servers: {} }); } @@ -115,14 +116,11 @@ const getMCPTools = async (req, res) => { try { const serverTools = serverToolsMap.get(serverName); - // Get server config once const serverConfig = mcpConfig[serverName]; - const rawServerConfig = await getMCPServersRegistry().getServerConfig(serverName, userId); - // Initialize server object with all server-level data const server = { name: serverName, - icon: rawServerConfig?.iconPath || '', + icon: serverConfig?.iconPath || '', authenticated: true, authConfig: [], tools: [], @@ -183,7 +181,7 @@ const getMCPServersList = async (req, res) => { return res.status(401).json({ message: 'Unauthorized' }); } - const serverConfigs = await getMCPServersRegistry().getAllServerConfigs(userId); + const serverConfigs = await resolveAllMcpConfigs(userId, req.user); return res.json(redactAllServerSecrets(serverConfigs)); } catch (error) { logger.error('[getMCPServersList]', error); @@ -237,7 +235,12 @@ const getMCPServerById = async (req, res) => { if (!serverName) { return res.status(400).json({ message: 'Server name is required' }); } - const parsedConfig = await getMCPServersRegistry().getServerConfig(serverName, userId); + const configServers = await resolveConfigServers(req); + const parsedConfig = await getMCPServersRegistry().getServerConfig( + serverName, + userId, + configServers, + ); if (!parsedConfig) { return res.status(404).json({ message: 'MCP server not found' }); diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index 1ad8cac087..f194f361d3 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -18,6 +18,7 @@ const mockRegistryInstance = { getServerConfig: jest.fn(), getOAuthServers: jest.fn(), getAllServerConfigs: jest.fn(), + ensureConfigServers: jest.fn().mockResolvedValue({}), addServer: jest.fn(), updateServer: jest.fn(), removeServer: jest.fn(), @@ -58,6 +59,7 @@ jest.mock('@librechat/api', () => { }); jest.mock('@librechat/data-schemas', () => ({ + getTenantId: jest.fn(), logger: { debug: jest.fn(), info: jest.fn(), @@ -93,14 +95,18 @@ jest.mock('~/server/services/Config', () => ({ getCachedTools: jest.fn(), getMCPServerTools: jest.fn(), loadCustomConfig: jest.fn(), + getAppConfig: jest.fn().mockResolvedValue({ mcpConfig: {} }), })); jest.mock('~/server/services/Config/mcp', () => ({ updateMCPServerTools: jest.fn(), })); +const mockResolveAllMcpConfigs = jest.fn().mockResolvedValue({}); jest.mock('~/server/services/MCP', () => ({ getMCPSetupData: jest.fn(), + resolveConfigServers: jest.fn().mockResolvedValue({}), + resolveAllMcpConfigs: (...args) => mockResolveAllMcpConfigs(...args), getServerConnectionStatus: jest.fn(), })); @@ -579,6 +585,112 @@ describe('MCP Routes', () => { ); }); + it('should use oauthHeaders from flow state when present', async () => { + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue({ status: 'PENDING' }), + completeFlow: jest.fn().mockResolvedValue(), + deleteFlow: jest.fn().mockResolvedValue(true), + }; + const mockFlowState = { + serverName: 'test-server', + userId: 'test-user-id', + metadata: { toolFlowId: 'tool-flow-123' }, + clientInfo: {}, + codeVerifier: 'test-verifier', + oauthHeaders: { 'X-Custom-Auth': 'header-value' }, + }; + const mockTokens = { access_token: 'tok', refresh_token: 'ref' }; + + MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); + MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); + MCPTokenStorage.storeTokens.mockResolvedValue(); + getLogStores.mockReturnValue({}); + require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + require('~/config').getOAuthReconnectionManager.mockReturnValue({ + clearReconnection: jest.fn(), + }); + require('~/config').getMCPManager.mockReturnValue({ + getUserConnection: jest.fn().mockResolvedValue({ + fetchTools: jest.fn().mockResolvedValue([]), + }), + }); + const { getCachedTools, setCachedTools } = require('~/server/services/Config'); + getCachedTools.mockResolvedValue({}); + setCachedTools.mockResolvedValue(); + + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ code: 'auth-code', state: flowId }); + + expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith( + flowId, + 'auth-code', + mockFlowManager, + { 'X-Custom-Auth': 'header-value' }, + ); + expect(mockRegistryInstance.getServerConfig).not.toHaveBeenCalled(); + }); + + it('should fall back to registry oauth_headers when flow state lacks them', async () => { + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue({ status: 'PENDING' }), + completeFlow: jest.fn().mockResolvedValue(), + deleteFlow: jest.fn().mockResolvedValue(true), + }; + const mockFlowState = { + serverName: 'test-server', + userId: 'test-user-id', + metadata: { toolFlowId: 'tool-flow-123' }, + clientInfo: {}, + codeVerifier: 'test-verifier', + }; + const mockTokens = { access_token: 'tok', refresh_token: 'ref' }; + + MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); + MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); + MCPTokenStorage.storeTokens.mockResolvedValue(); + mockRegistryInstance.getServerConfig.mockResolvedValue({ + oauth_headers: { 'X-Registry-Header': 'from-registry' }, + }); + getLogStores.mockReturnValue({}); + require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + require('~/config').getOAuthReconnectionManager.mockReturnValue({ + clearReconnection: jest.fn(), + }); + require('~/config').getMCPManager.mockReturnValue({ + getUserConnection: jest.fn().mockResolvedValue({ + fetchTools: jest.fn().mockResolvedValue([]), + }), + }); + const { getCachedTools, setCachedTools } = require('~/server/services/Config'); + getCachedTools.mockResolvedValue({}); + setCachedTools.mockResolvedValue(); + + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ code: 'auth-code', state: flowId }); + + expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith( + flowId, + 'auth-code', + mockFlowManager, + { 'X-Registry-Header': 'from-registry' }, + ); + expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith( + 'test-server', + 'test-user-id', + undefined, + ); + }); + it('should redirect to error page when callback processing fails', async () => { MCPOAuthHandler.getFlowState.mockRejectedValue(new Error('Callback error')); const flowId = 'test-user-id:test-server'; @@ -1350,19 +1462,10 @@ describe('MCP Routes', () => { }, }); - expect(getMCPSetupData).toHaveBeenCalledWith('test-user-id'); + expect(getMCPSetupData).toHaveBeenCalledWith('test-user-id', expect.any(Object)); expect(getServerConnectionStatus).toHaveBeenCalledTimes(2); }); - it('should return 404 when MCP config is not found', async () => { - getMCPSetupData.mockRejectedValue(new Error('MCP config not found')); - - const response = await request(app).get('/api/mcp/connection/status'); - - expect(response.status).toBe(404); - expect(response.body).toEqual({ error: 'MCP config not found' }); - }); - it('should return 500 when connection status check fails', async () => { getMCPSetupData.mockRejectedValue(new Error('Database error')); @@ -1437,15 +1540,6 @@ describe('MCP Routes', () => { }); }); - it('should return 404 when MCP config is not found', async () => { - getMCPSetupData.mockRejectedValue(new Error('MCP config not found')); - - const response = await request(app).get('/api/mcp/connection/status/test-server'); - - expect(response.status).toBe(404); - expect(response.body).toEqual({ error: 'MCP config not found' }); - }); - it('should return 500 when connection status check fails', async () => { getMCPSetupData.mockRejectedValue(new Error('Database connection failed')); @@ -1704,7 +1798,7 @@ describe('MCP Routes', () => { }, }; - mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockServerConfigs); + mockResolveAllMcpConfigs.mockResolvedValue(mockServerConfigs); const response = await request(app).get('/api/mcp/servers'); @@ -1721,11 +1815,14 @@ describe('MCP Routes', () => { }); expect(response.body['server-1'].headers).toBeUndefined(); expect(response.body['server-2'].headers).toBeUndefined(); - expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith('test-user-id'); + expect(mockResolveAllMcpConfigs).toHaveBeenCalledWith( + 'test-user-id', + expect.objectContaining({ id: 'test-user-id' }), + ); }); it('should return empty object when no servers are configured', async () => { - mockRegistryInstance.getAllServerConfigs.mockResolvedValue({}); + mockResolveAllMcpConfigs.mockResolvedValue({}); const response = await request(app).get('/api/mcp/servers'); @@ -1749,7 +1846,7 @@ describe('MCP Routes', () => { }); it('should return 500 when server config retrieval fails', async () => { - mockRegistryInstance.getAllServerConfigs.mockRejectedValue(new Error('Database error')); + mockResolveAllMcpConfigs.mockRejectedValue(new Error('Database error')); const response = await request(app).get('/api/mcp/servers'); @@ -1939,11 +2036,12 @@ describe('MCP Routes', () => { expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith( 'test-server', 'test-user-id', + {}, ); }); it('should return 404 when server not found', async () => { - mockRegistryInstance.getServerConfig.mockResolvedValue(null); + mockRegistryInstance.getServerConfig.mockResolvedValue(undefined); const response = await request(app).get('/api/mcp/servers/non-existent-server'); diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index d6d7ed5ea0..c6496ad4b4 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -1,5 +1,5 @@ const { Router } = require('express'); -const { logger } = require('@librechat/data-schemas'); +const { logger, getTenantId } = require('@librechat/data-schemas'); const { CacheKeys, Constants, @@ -36,7 +36,11 @@ const { getFlowStateManager, getMCPManager, } = require('~/config'); -const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP'); +const { + getServerConnectionStatus, + resolveConfigServers, + getMCPSetupData, +} = require('~/server/services/MCP'); const { requireJwtAuth, canAccessMCPServerResource } = require('~/server/middleware'); const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { updateMCPServerTools } = require('~/server/services/Config/mcp'); @@ -101,7 +105,8 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async return res.status(400).json({ error: 'Invalid flow state' }); } - const oauthHeaders = await getOAuthHeaders(serverName, userId); + const configServers = await resolveConfigServers(req); + const oauthHeaders = await getOAuthHeaders(serverName, userId, configServers); const { authorizationUrl, flowId: oauthFlowId, @@ -233,7 +238,14 @@ router.get('/:serverName/oauth/callback', async (req, res) => { } logger.debug('[MCP OAuth] Completing OAuth flow'); - const oauthHeaders = await getOAuthHeaders(serverName, flowState.userId); + if (!flowState.oauthHeaders) { + logger.warn( + '[MCP OAuth] oauthHeaders absent from flow state — config-source server oauth_headers will be empty', + { serverName, flowId }, + ); + } + const oauthHeaders = + flowState.oauthHeaders ?? (await getOAuthHeaders(serverName, flowState.userId)); const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager, oauthHeaders); logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route'); @@ -497,7 +509,12 @@ router.post( logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`); const mcpManager = getMCPManager(); - const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id); + const configServers = await resolveConfigServers(req); + const serverConfig = await getMCPServersRegistry().getServerConfig( + serverName, + user.id, + configServers, + ); if (!serverConfig) { return res.status(404).json({ error: `MCP server '${serverName}' not found in configuration`, @@ -522,6 +539,8 @@ router.post( const result = await reinitMCPServer({ user, serverName, + serverConfig, + configServers, userMCPAuthMap, }); @@ -564,6 +583,7 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => { const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData( user.id, + { role: user.role, tenantId: getTenantId() }, ); const connectionStatus = {}; @@ -593,9 +613,6 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => { connectionStatus, }); } catch (error) { - if (error.message === 'MCP config not found') { - return res.status(404).json({ error: error.message }); - } logger.error('[MCP Connection Status] Failed to get connection status', error); res.status(500).json({ error: 'Failed to get connection status' }); } @@ -616,6 +633,7 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) => const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData( user.id, + { role: user.role, tenantId: getTenantId() }, ); if (!mcpConfig[serverName]) { @@ -640,9 +658,6 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) => requiresOAuth: serverStatus.requiresOAuth, }); } catch (error) { - if (error.message === 'MCP config not found') { - return res.status(404).json({ error: error.message }); - } logger.error( `[MCP Per-Server Status] Failed to get connection status for ${req.params.serverName}`, error, @@ -664,7 +679,12 @@ router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, a return res.status(401).json({ error: 'User not authenticated' }); } - const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id); + const configServers = await resolveConfigServers(req); + const serverConfig = await getMCPServersRegistry().getServerConfig( + serverName, + user.id, + configServers, + ); if (!serverConfig) { return res.status(404).json({ error: `MCP server '${serverName}' not found in configuration`, @@ -703,8 +723,12 @@ router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, a } }); -async function getOAuthHeaders(serverName, userId) { - const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, userId); +async function getOAuthHeaders(serverName, userId, configServers) { + const serverConfig = await getMCPServersRegistry().getServerConfig( + serverName, + userId, + configServers, + ); return serverConfig?.oauth_headers ?? {}; } diff --git a/api/server/services/Config/__tests__/invalidateConfigCaches.spec.js b/api/server/services/Config/__tests__/invalidateConfigCaches.spec.js index df21786f05..49e94bc081 100644 --- a/api/server/services/Config/__tests__/invalidateConfigCaches.spec.js +++ b/api/server/services/Config/__tests__/invalidateConfigCaches.spec.js @@ -32,12 +32,14 @@ jest.mock('../getCachedTools', () => ({ invalidateCachedTools: mockInvalidateCachedTools, })); +const mockClearMcpConfigCache = jest.fn().mockResolvedValue(undefined); jest.mock('@librechat/api', () => ({ createAppConfigService: jest.fn(() => ({ getAppConfig: jest.fn().mockResolvedValue({ availableTools: {} }), clearAppConfigCache: mockClearAppConfigCache, clearOverrideCache: mockClearOverrideCache, })), + clearMcpConfigCache: mockClearMcpConfigCache, })); // ── Tests ────────────────────────────────────────────────────────────── diff --git a/api/server/services/Config/app.js b/api/server/services/Config/app.js index c0180fdb12..7530ca1031 100644 --- a/api/server/services/Config/app.js +++ b/api/server/services/Config/app.js @@ -1,6 +1,6 @@ const { CacheKeys } = require('librechat-data-provider'); -const { createAppConfigService } = require('@librechat/api'); const { AppService, logger } = require('@librechat/data-schemas'); +const { createAppConfigService, clearMcpConfigCache } = require('@librechat/api'); const { setCachedTools, invalidateCachedTools } = require('./getCachedTools'); const { loadAndFormatTools } = require('~/server/services/start/tools'); const loadCustomConfig = require('./loadCustomConfig'); @@ -42,7 +42,7 @@ async function clearEndpointConfigCache() { /** * Invalidate all config-related caches after an admin config mutation. * Clears the base config, per-principal override caches, tool caches, - * and the endpoints config cache. + * the endpoints config cache, and the MCP config-source server cache. * @param {string} [tenantId] - Optional tenant ID to scope override cache clearing. */ async function invalidateConfigCaches(tenantId) { @@ -51,12 +51,14 @@ async function invalidateConfigCaches(tenantId) { clearOverrideCache(tenantId), invalidateCachedTools({ invalidateGlobal: true }), clearEndpointConfigCache(), + clearMcpConfigCache(), ]); const labels = [ 'clearAppConfigCache', 'clearOverrideCache', 'invalidateCachedTools', 'clearEndpointConfigCache', + 'clearMcpConfigCache', ]; for (let i = 0; i < results.length; i++) { if (results[i].status === 'rejected') { diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index d765d335aa..dbb44740a9 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -1,5 +1,5 @@ const { tool } = require('@langchain/core/tools'); -const { logger } = require('@librechat/data-schemas'); +const { logger, getTenantId } = require('@librechat/data-schemas'); const { Providers, StepTypes, @@ -54,6 +54,53 @@ function evictStale(map, ttl) { const unavailableMsg = "This tool's MCP server is temporarily unavailable. Please try again shortly."; +/** + * Resolves config-source MCP servers from admin Config overrides for the current + * request context. Returns the parsed configs keyed by server name. + * @param {import('express').Request} req - Express request with user context + * @returns {Promise>} + */ +async function resolveConfigServers(req) { + try { + const registry = getMCPServersRegistry(); + const user = req?.user; + const appConfig = await getAppConfig({ + role: user?.role, + tenantId: getTenantId(), + userId: user?.id, + }); + return await registry.ensureConfigServers(appConfig?.mcpConfig || {}); + } catch (error) { + logger.warn( + '[resolveConfigServers] Failed to resolve config servers, degrading to empty:', + error, + ); + return {}; + } +} + +/** + * Resolves config-source servers and merges all server configs (YAML + config + user DB) + * for the given user context. Shared helper for controllers needing the full merged config. + * @param {string} userId + * @param {{ id?: string, role?: string }} [user] + * @returns {Promise>} + */ +async function resolveAllMcpConfigs(userId, user) { + const registry = getMCPServersRegistry(); + const appConfig = await getAppConfig({ role: user?.role, tenantId: getTenantId(), userId }); + let configServers = {}; + try { + configServers = await registry.ensureConfigServers(appConfig?.mcpConfig || {}); + } catch (error) { + logger.warn( + '[resolveAllMcpConfigs] Config server resolution failed, continuing without:', + error, + ); + } + return await registry.getAllServerConfigs(userId, configServers); +} + /** * @param {string} toolName * @param {string} serverName @@ -249,6 +296,7 @@ async function reconnectServer({ index, signal, serverName, + configServers, userMCPAuthMap, streamId = null, }) { @@ -317,6 +365,7 @@ async function reconnectServer({ user, signal, serverName, + configServers, oauthStart, flowManager, userMCPAuthMap, @@ -359,13 +408,12 @@ async function createMCPTools({ config, provider, serverName, + configServers, userMCPAuthMap, streamId = null, }) { - // Early domain validation before reconnecting server (avoid wasted work on disallowed domains) - // Use getAppConfig() to support per-user/role domain restrictions const serverConfig = - config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id)); + config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id, configServers)); if (serverConfig?.url) { const appConfig = await getAppConfig({ role: user?.role, tenantId: user?.tenantId }); const allowedDomains = appConfig?.mcpSettings?.allowedDomains; @@ -382,6 +430,7 @@ async function createMCPTools({ index, signal, serverName, + configServers, userMCPAuthMap, streamId, }); @@ -401,6 +450,7 @@ async function createMCPTools({ user, provider, userMCPAuthMap, + configServers, streamId, availableTools: result.availableTools, toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`, @@ -440,14 +490,13 @@ async function createMCPTool({ userMCPAuthMap, availableTools, config, + configServers, streamId = null, }) { const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter); - // Runtime domain validation: check if the server's domain is still allowed - // Use getAppConfig() to support per-user/role domain restrictions const serverConfig = - config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id)); + config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id, configServers)); if (serverConfig?.url) { const appConfig = await getAppConfig({ role: user?.role, tenantId: user?.tenantId }); const allowedDomains = appConfig?.mcpSettings?.allowedDomains; @@ -478,6 +527,7 @@ async function createMCPTool({ index, signal, serverName, + configServers, userMCPAuthMap, streamId, }); @@ -501,6 +551,7 @@ async function createMCPTool({ provider, toolName, serverName, + serverConfig, toolDefinition, streamId, }); @@ -510,13 +561,14 @@ function createToolInstance({ res, toolName, serverName, + serverConfig: capturedServerConfig, toolDefinition, - provider: _provider, + provider: capturedProvider, streamId = null, }) { /** @type {LCTool} */ const { description, parameters } = toolDefinition; - const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE; + const isGoogle = capturedProvider === Providers.VERTEXAI || capturedProvider === Providers.GOOGLE; let schema = parameters ? normalizeJsonSchema(resolveJsonSchemaRefs(parameters)) : null; @@ -545,7 +597,7 @@ function createToolInstance({ const flowManager = getFlowStateManager(flowsCache); derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined; const mcpManager = getMCPManager(userId); - const provider = (config?.metadata?.provider || _provider)?.toLowerCase(); + const provider = (config?.metadata?.provider || capturedProvider)?.toLowerCase(); const { args: _args, stepId, ...toolCall } = config.toolCall ?? {}; const flowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`; @@ -577,6 +629,7 @@ function createToolInstance({ const result = await mcpManager.callTool({ serverName, + serverConfig: capturedServerConfig, toolName, provider, toolArguments, @@ -644,30 +697,36 @@ function createToolInstance({ } /** - * Get MCP setup data including config, connections, and OAuth servers + * Get MCP setup data including config, connections, and OAuth servers. + * Resolves config-source servers from admin Config overrides when tenant context is available. * @param {string} userId - The user ID + * @param {{ role?: string, tenantId?: string }} [options] - Optional role/tenant context * @returns {Object} Object containing mcpConfig, appConnections, userConnections, and oauthServers */ -async function getMCPSetupData(userId) { - const mcpConfig = await getMCPServersRegistry().getAllServerConfigs(userId); - - if (!mcpConfig) { - throw new Error('MCP config not found'); - } +async function getMCPSetupData(userId, options = {}) { + const registry = getMCPServersRegistry(); + const { role, tenantId } = options; + const appConfig = await getAppConfig({ role, tenantId, userId }); + const configServers = await registry.ensureConfigServers(appConfig?.mcpConfig || {}); + const mcpConfig = await registry.getAllServerConfigs(userId, configServers); const mcpManager = getMCPManager(userId); /** @type {Map} */ let appConnections = new Map(); try { - // Use getLoaded() instead of getAll() to avoid forcing connection creation + // Use getLoaded() instead of getAll() to avoid forcing connection creation. // getAll() creates connections for all servers, which is problematic for servers - // that require user context (e.g., those with {{LIBRECHAT_USER_ID}} placeholders) + // that require user context (e.g., those with {{LIBRECHAT_USER_ID}} placeholders). appConnections = (await mcpManager.appConnections?.getLoaded()) || new Map(); } catch (error) { logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error); } const userConnections = mcpManager.getUserConnections(userId) || new Map(); - const oauthServers = await getMCPServersRegistry().getOAuthServers(userId); + const oauthServers = new Set( + Object.entries(mcpConfig) + .filter(([, config]) => config.requiresOAuth) + .map(([name]) => name), + ); return { mcpConfig, @@ -789,6 +848,8 @@ module.exports = { createMCPTool, createMCPTools, getMCPSetupData, + resolveConfigServers, + resolveAllMcpConfigs, checkOAuthFlowStatus, getServerConnectionStatus, createUnavailableToolStub, diff --git a/api/server/services/MCP.spec.js b/api/server/services/MCP.spec.js index 14a9ef90ed..c9925827f8 100644 --- a/api/server/services/MCP.spec.js +++ b/api/server/services/MCP.spec.js @@ -14,6 +14,7 @@ const mockRegistryInstance = { getOAuthServers: jest.fn(() => Promise.resolve(new Set())), getAllServerConfigs: jest.fn(() => Promise.resolve({})), getServerConfig: jest.fn(() => Promise.resolve(null)), + ensureConfigServers: jest.fn(() => Promise.resolve({})), }; // Create isMCPDomainAllowed mock that can be configured per-test @@ -113,38 +114,43 @@ describe('tests for the new helper functions used by the MCP connection status e }); it('should successfully return MCP setup data', async () => { - mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfig); + const mockConfigWithOAuth = { + server1: { type: 'stdio' }, + server2: { type: 'http', requiresOAuth: true }, + }; + mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfigWithOAuth); const mockAppConnections = new Map([['server1', { status: 'connected' }]]); const mockUserConnections = new Map([['server2', { status: 'disconnected' }]]); - const mockOAuthServers = new Set(['server2']); const mockMCPManager = { appConnections: { getLoaded: jest.fn(() => Promise.resolve(mockAppConnections)) }, getUserConnections: jest.fn(() => mockUserConnections), }; mockGetMCPManager.mockReturnValue(mockMCPManager); - mockRegistryInstance.getOAuthServers.mockResolvedValue(mockOAuthServers); const result = await getMCPSetupData(mockUserId); - expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith(mockUserId); + expect(mockRegistryInstance.ensureConfigServers).toHaveBeenCalled(); + expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith( + mockUserId, + expect.any(Object), + ); expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId); expect(mockMCPManager.appConnections.getLoaded).toHaveBeenCalled(); expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId); - expect(mockRegistryInstance.getOAuthServers).toHaveBeenCalledWith(mockUserId); - expect(result).toEqual({ - mcpConfig: mockConfig, - appConnections: mockAppConnections, - userConnections: mockUserConnections, - oauthServers: mockOAuthServers, - }); + expect(result.mcpConfig).toEqual(mockConfigWithOAuth); + expect(result.appConnections).toEqual(mockAppConnections); + expect(result.userConnections).toEqual(mockUserConnections); + expect(result.oauthServers).toEqual(new Set(['server2'])); }); - it('should throw error when MCP config not found', async () => { - mockRegistryInstance.getAllServerConfigs.mockResolvedValue(null); - await expect(getMCPSetupData(mockUserId)).rejects.toThrow('MCP config not found'); + it('should return empty data when no servers are configured', async () => { + mockRegistryInstance.getAllServerConfigs.mockResolvedValue({}); + const result = await getMCPSetupData(mockUserId); + expect(result.mcpConfig).toEqual({}); + expect(result.oauthServers).toEqual(new Set()); }); it('should handle null values from MCP manager gracefully', async () => { diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index 838de906fe..c11843cb69 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -60,6 +60,7 @@ const { manifestToolMap, toolkits } = require('~/app/clients/tools/manifest'); const { createOnSearchResults } = require('~/server/services/Tools/search'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { reinitMCPServer } = require('~/server/services/Tools/mcp'); +const { resolveConfigServers } = require('~/server/services/MCP'); const { recordUsage } = require('~/server/services/Threads'); const { loadTools } = require('~/app/clients/tools/util'); const { redactMessage } = require('~/config/parsers'); @@ -514,6 +515,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to const flowsCache = getLogStores(CacheKeys.FLOWS); const flowManager = getFlowStateManager(flowsCache); + const configServers = await resolveConfigServers(req); const pendingOAuthServers = new Set(); const createOAuthEmitter = (serverName) => { @@ -579,6 +581,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to oauthStart, flowManager, serverName, + configServers, userMCPAuthMap, }); @@ -666,6 +669,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to const result = await reinitMCPServer({ user: req.user, serverName, + configServers, userMCPAuthMap, flowManager, returnOnOAuth: false, diff --git a/api/server/services/Tools/mcp.js b/api/server/services/Tools/mcp.js index 7589043e10..f1ebcf9796 100644 --- a/api/server/services/Tools/mcp.js +++ b/api/server/services/Tools/mcp.js @@ -25,11 +25,13 @@ async function reinitMCPServer({ signal, forceNew, serverName, + configServers, userMCPAuthMap, connectionTimeout, returnOnOAuth = true, oauthStart: _oauthStart, flowManager: _flowManager, + serverConfig: providedConfig, }) { /** @type {MCPConnection | null} */ let connection = null; @@ -42,13 +44,28 @@ async function reinitMCPServer({ try { const registry = getMCPServersRegistry(); - const serverConfig = await registry.getServerConfig(serverName, user?.id); + const serverConfig = + providedConfig ?? (await registry.getServerConfig(serverName, user?.id, configServers)); if (serverConfig?.inspectionFailed) { + if (serverConfig.source === 'config') { + logger.info( + `[MCP Reinitialize] Config-source server ${serverName} has inspectionFailed — retry handled by config cache`, + ); + return { + availableTools: null, + success: false, + message: `MCP server '${serverName}' is still unreachable`, + oauthRequired: false, + serverName, + oauthUrl: null, + tools: null, + }; + } logger.info( `[MCP Reinitialize] Server ${serverName} had failed inspection, attempting reinspection`, ); try { - const storageLocation = serverConfig.dbId ? 'DB' : 'CACHE'; + const storageLocation = serverConfig.source === 'user' ? 'DB' : 'CACHE'; await registry.reinspectServer(serverName, storageLocation, user?.id); logger.info(`[MCP Reinitialize] Reinspection succeeded for server: ${serverName}`); } catch (reinspectError) { @@ -93,6 +110,7 @@ async function reinitMCPServer({ returnOnOAuth, customUserVars, connectionTimeout, + serverConfig, }); logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`); @@ -125,6 +143,7 @@ async function reinitMCPServer({ oauthStart, customUserVars, connectionTimeout, + configServers, }); if (discoveryResult.tools && discoveryResult.tools.length > 0) { diff --git a/api/server/services/__tests__/MCP.spec.js b/api/server/services/__tests__/MCP.spec.js new file mode 100644 index 0000000000..39e99d54ac --- /dev/null +++ b/api/server/services/__tests__/MCP.spec.js @@ -0,0 +1,131 @@ +const mockRegistry = { + ensureConfigServers: jest.fn(), + getAllServerConfigs: jest.fn(), +}; + +jest.mock('~/config', () => ({ + getMCPServersRegistry: jest.fn(() => mockRegistry), + getMCPManager: jest.fn(), + getFlowStateManager: jest.fn(), + getOAuthReconnectionManager: jest.fn(), +})); + +jest.mock('@librechat/data-schemas', () => ({ + getTenantId: jest.fn(() => 'tenant-1'), + logger: { debug: jest.fn(), info: jest.fn(), warn: jest.fn(), error: jest.fn() }, +})); + +jest.mock('~/server/services/Config', () => ({ + getAppConfig: jest.fn(), + setCachedTools: jest.fn(), + getCachedTools: jest.fn(), + getMCPServerTools: jest.fn(), + loadCustomConfig: jest.fn(), +})); + +jest.mock('~/cache', () => ({ getLogStores: jest.fn() })); +jest.mock('~/models', () => ({ + findToken: jest.fn(), + createToken: jest.fn(), + updateToken: jest.fn(), +})); +jest.mock('~/server/services/GraphTokenService', () => ({ + getGraphApiToken: jest.fn(), +})); +jest.mock('~/server/services/Tools/mcp', () => ({ + reinitMCPServer: jest.fn(), +})); + +const { getAppConfig } = require('~/server/services/Config'); +const { resolveConfigServers, resolveAllMcpConfigs } = require('../MCP'); + +describe('resolveConfigServers', () => { + beforeEach(() => jest.clearAllMocks()); + + it('resolves config servers for the current request context', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: { srv: { url: 'http://a' } } }); + mockRegistry.ensureConfigServers.mockResolvedValue({ srv: { name: 'srv' } }); + + const result = await resolveConfigServers({ user: { id: 'u1', role: 'admin' } }); + + expect(result).toEqual({ srv: { name: 'srv' } }); + expect(getAppConfig).toHaveBeenCalledWith( + expect.objectContaining({ role: 'admin', userId: 'u1' }), + ); + expect(mockRegistry.ensureConfigServers).toHaveBeenCalledWith({ srv: { url: 'http://a' } }); + }); + + it('returns {} when ensureConfigServers throws', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: { srv: {} } }); + mockRegistry.ensureConfigServers.mockRejectedValue(new Error('inspect failed')); + + const result = await resolveConfigServers({ user: { id: 'u1' } }); + + expect(result).toEqual({}); + }); + + it('returns {} when getAppConfig throws', async () => { + getAppConfig.mockRejectedValue(new Error('db timeout')); + + const result = await resolveConfigServers({ user: { id: 'u1' } }); + + expect(result).toEqual({}); + }); + + it('passes empty mcpConfig when appConfig has none', async () => { + getAppConfig.mockResolvedValue({}); + mockRegistry.ensureConfigServers.mockResolvedValue({}); + + await resolveConfigServers({ user: { id: 'u1' } }); + + expect(mockRegistry.ensureConfigServers).toHaveBeenCalledWith({}); + }); +}); + +describe('resolveAllMcpConfigs', () => { + beforeEach(() => jest.clearAllMocks()); + + it('merges config servers with base servers', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: { cfg_srv: {} } }); + mockRegistry.ensureConfigServers.mockResolvedValue({ cfg_srv: { name: 'cfg_srv' } }); + mockRegistry.getAllServerConfigs.mockResolvedValue({ + cfg_srv: { name: 'cfg_srv' }, + yaml_srv: { name: 'yaml_srv' }, + }); + + const result = await resolveAllMcpConfigs('u1', { id: 'u1', role: 'user' }); + + expect(result).toEqual({ + cfg_srv: { name: 'cfg_srv' }, + yaml_srv: { name: 'yaml_srv' }, + }); + expect(mockRegistry.getAllServerConfigs).toHaveBeenCalledWith('u1', { + cfg_srv: { name: 'cfg_srv' }, + }); + }); + + it('continues with empty configServers when ensureConfigServers fails', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: { srv: {} } }); + mockRegistry.ensureConfigServers.mockRejectedValue(new Error('inspect failed')); + mockRegistry.getAllServerConfigs.mockResolvedValue({ yaml_srv: { name: 'yaml_srv' } }); + + const result = await resolveAllMcpConfigs('u1', { id: 'u1' }); + + expect(result).toEqual({ yaml_srv: { name: 'yaml_srv' } }); + expect(mockRegistry.getAllServerConfigs).toHaveBeenCalledWith('u1', {}); + }); + + it('propagates getAllServerConfigs failures', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: {} }); + mockRegistry.ensureConfigServers.mockResolvedValue({}); + mockRegistry.getAllServerConfigs.mockRejectedValue(new Error('redis down')); + + await expect(resolveAllMcpConfigs('u1', { id: 'u1' })).rejects.toThrow('redis down'); + }); + + it('propagates getAppConfig failures', async () => { + getAppConfig.mockRejectedValue(new Error('mongo down')); + + await expect(resolveAllMcpConfigs('u1', { id: 'u1' })).rejects.toThrow('mongo down'); + }); +}); diff --git a/api/server/services/__tests__/ToolService.spec.js b/api/server/services/__tests__/ToolService.spec.js index a468a88eb3..6e06804280 100644 --- a/api/server/services/__tests__/ToolService.spec.js +++ b/api/server/services/__tests__/ToolService.spec.js @@ -64,6 +64,9 @@ jest.mock('~/models', () => ({ jest.mock('~/config', () => ({ getFlowStateManager: jest.fn(() => ({})), })); +jest.mock('~/server/services/MCP', () => ({ + resolveConfigServers: jest.fn().mockResolvedValue({}), +})); jest.mock('~/cache', () => ({ getLogStores: jest.fn(() => ({})), })); diff --git a/packages/api/src/agents/context.spec.ts b/packages/api/src/agents/context.spec.ts index c5358209c7..1d995a52bb 100644 --- a/packages/api/src/agents/context.spec.ts +++ b/packages/api/src/agents/context.spec.ts @@ -154,10 +154,10 @@ describe('Agent Context Utilities', () => { ); expect(result).toBe(instructions); - expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith([ - 'server1', - 'server2', - ]); + expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith( + ['server1', 'server2'], + undefined, + ); expect(mockLogger.debug).toHaveBeenCalledWith( '[AgentContext] Fetched MCP instructions for servers:', ['server1', 'server2'], @@ -345,9 +345,10 @@ describe('Agent Context Utilities', () => { logger: mockLogger, }); - expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith([ - 'ephemeral-server', - ]); + expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith( + ['ephemeral-server'], + undefined, + ); expect(agent.instructions).toContain('Ephemeral MCP'); }); @@ -375,7 +376,10 @@ describe('Agent Context Utilities', () => { logger: mockLogger, }); - expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith(['agent-server']); + expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith( + ['agent-server'], + undefined, + ); }); it('should work without agentId', async () => { diff --git a/packages/api/src/agents/context.ts b/packages/api/src/agents/context.ts index ebae2e0f9f..c526fd13fe 100644 --- a/packages/api/src/agents/context.ts +++ b/packages/api/src/agents/context.ts @@ -1,8 +1,9 @@ -import { DynamicStructuredTool } from '@langchain/core/tools'; import { Constants } from 'librechat-data-provider'; +import { DynamicStructuredTool } from '@langchain/core/tools'; import type { Agent, TEphemeralAgent } from 'librechat-data-provider'; import type { LCTool } from '@librechat/agents'; import type { Logger } from 'winston'; +import type { ParsedServerConfig } from '~/mcp/types'; import type { MCPManager } from '~/mcp/MCPManager'; /** @@ -63,12 +64,16 @@ export async function getMCPInstructionsForServers( mcpServers: string[], mcpManager: MCPManager, logger?: Logger, + configServers?: Record, ): Promise { if (!mcpServers.length) { return ''; } try { - const mcpInstructions = await mcpManager.formatInstructionsForContext(mcpServers); + const mcpInstructions = await mcpManager.formatInstructionsForContext( + mcpServers, + configServers, + ); if (mcpInstructions && logger) { logger.debug('[AgentContext] Fetched MCP instructions for servers:', mcpServers); } @@ -125,6 +130,7 @@ export async function applyContextToAgent({ ephemeralAgent, agentId, logger, + configServers, }: { agent: AgentWithTools; sharedRunContext: string; @@ -132,12 +138,18 @@ export async function applyContextToAgent({ ephemeralAgent?: TEphemeralAgent; agentId?: string; logger?: Logger; + configServers?: Record; }): Promise { const baseInstructions = agent.instructions || ''; try { const mcpServers = ephemeralAgent?.mcp?.length ? ephemeralAgent.mcp : extractMCPServers(agent); - const mcpInstructions = await getMCPInstructionsForServers(mcpServers, mcpManager, logger); + const mcpInstructions = await getMCPInstructionsForServers( + mcpServers, + mcpManager, + logger, + configServers, + ); agent.instructions = buildAgentInstructions({ sharedRunContext, diff --git a/packages/api/src/index.ts b/packages/api/src/index.ts index 5ccf6b0124..7a04b8e74a 100644 --- a/packages/api/src/index.ts +++ b/packages/api/src/index.ts @@ -14,6 +14,7 @@ export * from './mcp/oauth'; export * from './mcp/auth'; export * from './mcp/zod'; export * from './mcp/errors'; +export * from './mcp/cache'; /* Utilities */ export * from './mcp/utils'; export * from './utils'; diff --git a/packages/api/src/mcp/ConnectionsRepository.ts b/packages/api/src/mcp/ConnectionsRepository.ts index 6313faa8d4..79976b1199 100644 --- a/packages/api/src/mcp/ConnectionsRepository.ts +++ b/packages/api/src/mcp/ConnectionsRepository.ts @@ -2,7 +2,7 @@ import { logger } from '@librechat/data-schemas'; import type * as t from './types'; import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; -import { hasCustomUserVars } from './utils'; +import { hasCustomUserVars, isUserSourced } from './utils'; import { MCPConnection } from './connection'; const CONNECT_CONCURRENCY = 3; @@ -82,7 +82,7 @@ export class ConnectionsRepository { { serverName, serverConfig, - dbSourced: !!(serverConfig as t.ParsedServerConfig).dbId, + dbSourced: isUserSourced(serverConfig as t.ParsedServerConfig), useSSRFProtection: registry.shouldEnableSSRFProtection(), allowedDomains: registry.getAllowedDomains(), }, diff --git a/packages/api/src/mcp/MCPManager.ts b/packages/api/src/mcp/MCPManager.ts index 935307fa49..12227de39f 100644 --- a/packages/api/src/mcp/MCPManager.ts +++ b/packages/api/src/mcp/MCPManager.ts @@ -18,6 +18,7 @@ import { preProcessGraphTokens } from '~/utils/graph'; import { formatToolContent } from './parsers'; import { MCPConnection } from './connection'; import { processMCPEnv } from '~/utils/env'; +import { isUserSourced } from './utils'; /** * Centralized manager for MCP server connections and tool execution. @@ -53,6 +54,8 @@ export class MCPManager extends UserConnectionManager { user?: IUser; forceNew?: boolean; flowManager?: FlowStateManager; + /** Pre-resolved config for config-source servers not in YAML/DB */ + serverConfig?: t.ParsedServerConfig; } & Omit, ): Promise { //the get method checks if the config is still valid as app level @@ -91,6 +94,7 @@ export class MCPManager extends UserConnectionManager { const serverConfig = await MCPServersRegistry.getInstance().getServerConfig( serverName, user?.id, + args.configServers, ); if (!serverConfig) { @@ -103,7 +107,7 @@ export class MCPManager extends UserConnectionManager { const registry = MCPServersRegistry.getInstance(); const useSSRFProtection = registry.shouldEnableSSRFProtection(); const allowedDomains = registry.getAllowedDomains(); - const dbSourced = !!serverConfig.dbId; + const dbSourced = isUserSourced(serverConfig); const basic: t.BasicConnectionOptions = { dbSourced, serverName, @@ -193,9 +197,15 @@ export class MCPManager extends UserConnectionManager { * @param serverNames Optional array of server names. If not provided or empty, returns all servers. * @returns Object mapping server names to their instructions */ - private async getInstructions(serverNames?: string[]): Promise> { + private async getInstructions( + serverNames?: string[], + configServers?: Record, + ): Promise> { const instructions: Record = {}; - const configs = await MCPServersRegistry.getInstance().getAllServerConfigs(); + const configs = await MCPServersRegistry.getInstance().getAllServerConfigs( + undefined, + configServers, + ); for (const [serverName, config] of Object.entries(configs)) { if (config.serverInstructions != null) { instructions[serverName] = config.serverInstructions as string; @@ -210,9 +220,11 @@ export class MCPManager extends UserConnectionManager { * @param serverNames Optional array of server names to include. If not provided, includes all servers. * @returns Formatted instructions string ready for context injection */ - public async formatInstructionsForContext(serverNames?: string[]): Promise { - /** Instructions for specified servers or all stored instructions */ - const instructionsToInclude = await this.getInstructions(serverNames); + public async formatInstructionsForContext( + serverNames?: string[], + configServers?: Record, + ): Promise { + const instructionsToInclude = await this.getInstructions(serverNames, configServers); if (Object.keys(instructionsToInclude).length === 0) { return ''; @@ -248,6 +260,7 @@ Please follow these instructions when using tools from the respective MCP server async callTool({ user, serverName, + serverConfig: providedConfig, toolName, provider, toolArguments, @@ -262,6 +275,8 @@ Please follow these instructions when using tools from the respective MCP server }: { user?: IUser; serverName: string; + /** Pre-resolved config from tool creation context — avoids readThrough TTL and cross-tenant issues */ + serverConfig?: t.ParsedServerConfig; toolName: string; provider: t.Provider; toolArguments?: Record; @@ -292,6 +307,7 @@ Please follow these instructions when using tools from the respective MCP server signal: options?.signal, customUserVars, requestBody, + serverConfig: providedConfig, }); if (!(await connection.isConnected())) { @@ -302,8 +318,16 @@ Please follow these instructions when using tools from the respective MCP server ); } - const rawConfig = await MCPServersRegistry.getInstance().getServerConfig(serverName, userId); - const isDbSourced = !!rawConfig?.dbId; + const rawConfig = + providedConfig ?? + (await MCPServersRegistry.getInstance().getServerConfig(serverName, userId)); + if (!rawConfig) { + throw new McpError( + ErrorCode.InvalidRequest, + `${logPrefix} Configuration for server "${serverName}" not found.`, + ); + } + const isDbSourced = isUserSourced(rawConfig); /** Pre-process Graph token placeholders (async) before the synchronous processMCPEnv pass */ const graphProcessedConfig = isDbSourced diff --git a/packages/api/src/mcp/UserConnectionManager.ts b/packages/api/src/mcp/UserConnectionManager.ts index 2e9d5be467..760f84c75e 100644 --- a/packages/api/src/mcp/UserConnectionManager.ts +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -4,6 +4,7 @@ import type * as t from './types'; import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; +import { isUserSourced } from './utils'; import { MCPConnection } from './connection'; import { mcpConfig } from './mcpConfig'; @@ -38,6 +39,8 @@ export abstract class UserConnectionManager { opts: { serverName: string; forceNew?: boolean; + /** Pre-resolved config for config-source servers not in YAML/DB */ + serverConfig?: t.ParsedServerConfig; } & Omit, ): Promise { const { serverName, forceNew, user } = opts; @@ -85,9 +88,11 @@ export abstract class UserConnectionManager { signal, returnOnOAuth = false, connectionTimeout, + serverConfig: providedConfig, }: { serverName: string; forceNew?: boolean; + serverConfig?: t.ParsedServerConfig; } & Omit, userId: string, ): Promise { @@ -98,7 +103,9 @@ export abstract class UserConnectionManager { ); } - const config = await MCPServersRegistry.getInstance().getServerConfig(serverName, userId); + const config = + providedConfig ?? + (await MCPServersRegistry.getInstance().getServerConfig(serverName, userId)); const userServerMap = this.userConnections.get(userId); let connection = forceNew ? undefined : userServerMap?.get(serverName); @@ -158,7 +165,7 @@ export abstract class UserConnectionManager { { serverConfig: config, serverName: serverName, - dbSourced: !!config.dbId, + dbSourced: isUserSourced(config), useSSRFProtection: registry.shouldEnableSSRFProtection(), allowedDomains: registry.getAllowedDomains(), }, diff --git a/packages/api/src/mcp/__tests__/utils.test.ts b/packages/api/src/mcp/__tests__/utils.test.ts index c244205b99..b9c2a31fa5 100644 --- a/packages/api/src/mcp/__tests__/utils.test.ts +++ b/packages/api/src/mcp/__tests__/utils.test.ts @@ -3,6 +3,7 @@ import { normalizeServerName, redactAllServerSecrets, redactServerSecrets, + isUserSourced, } from '~/mcp/utils'; import type { ParsedServerConfig } from '~/mcp/types'; @@ -273,3 +274,29 @@ describe('redactAllServerSecrets', () => { expect((redacted['server-c'] as Record).command).toBeUndefined(); }); }); + +describe('isUserSourced', () => { + it('returns false when source is yaml', () => { + expect(isUserSourced({ source: 'yaml' })).toBe(false); + }); + + it('returns false when source is config', () => { + expect(isUserSourced({ source: 'config' })).toBe(false); + }); + + it('returns true when source is user', () => { + expect(isUserSourced({ source: 'user' })).toBe(true); + }); + + it('falls back to dbId when source is undefined — dbId present means user-sourced', () => { + expect(isUserSourced({ source: undefined, dbId: 'abc123' })).toBe(true); + }); + + it('falls back to dbId when source is undefined — no dbId means trusted', () => { + expect(isUserSourced({ source: undefined, dbId: undefined })).toBe(false); + }); + + it('returns false when both source and dbId are absent (pre-upgrade YAML server)', () => { + expect(isUserSourced({})).toBe(false); + }); +}); diff --git a/packages/api/src/mcp/cache.ts b/packages/api/src/mcp/cache.ts new file mode 100644 index 0000000000..e68ef42b3c --- /dev/null +++ b/packages/api/src/mcp/cache.ts @@ -0,0 +1,43 @@ +import { logger } from '@librechat/data-schemas'; +import { MCPServersRegistry } from './registry/MCPServersRegistry'; +import { MCPManager } from './MCPManager'; + +/** + * Clears config-source MCP server inspection cache so servers are re-inspected on next access. + * Best-effort disconnection of app-level connections for evicted servers. + * + * User-level connections (used by config-source servers) are cleaned up lazily via + * the stale-check mechanism on the next tool call — this is an accepted design tradeoff + * since iterating all active user sessions is expensive and config mutations are rare. + */ +export async function clearMcpConfigCache(): Promise { + let registry: MCPServersRegistry; + try { + registry = MCPServersRegistry.getInstance(); + } catch { + return; + } + + let evictedServers: string[]; + try { + evictedServers = await registry.invalidateConfigCache(); + } catch (error) { + logger.error('[clearMcpConfigCache] Failed to invalidate config cache:', error); + return; + } + + if (!evictedServers.length) { + return; + } + + try { + const mcpManager = MCPManager.getInstance(); + if (mcpManager?.appConnections) { + await Promise.allSettled( + evictedServers.map((serverName) => mcpManager.appConnections!.disconnect(serverName)), + ); + } + } catch { + // MCPManager not yet initialized — connections cleaned up lazily + } +} diff --git a/packages/api/src/mcp/oauth/handler.ts b/packages/api/src/mcp/oauth/handler.ts index 873af5c66d..e128dec308 100644 --- a/packages/api/src/mcp/oauth/handler.ts +++ b/packages/api/src/mcp/oauth/handler.ts @@ -467,6 +467,7 @@ export class MCPOAuthHandler { codeVerifier, clientInfo, metadata, + ...(Object.keys(oauthHeaders).length > 0 && { oauthHeaders }), }; logger.debug( @@ -573,6 +574,7 @@ export class MCPOAuthHandler { clientInfo, metadata, resourceMetadata, + ...(Object.keys(oauthHeaders).length > 0 && { oauthHeaders }), }; logger.debug( diff --git a/packages/api/src/mcp/oauth/types.ts b/packages/api/src/mcp/oauth/types.ts index 2138b4a782..bc5f53f60c 100644 --- a/packages/api/src/mcp/oauth/types.ts +++ b/packages/api/src/mcp/oauth/types.ts @@ -89,6 +89,8 @@ export interface MCPOAuthFlowMetadata extends FlowMetadata { metadata?: OAuthMetadata; resourceMetadata?: OAuthProtectedResourceMetadata; authorizationUrl?: string; + /** Custom headers for OAuth token exchange, persisted at flow initiation for the callback. */ + oauthHeaders?: Record; } export interface MCPOAuthTokens extends OAuthTokens { diff --git a/packages/api/src/mcp/registry/MCPServerInspector.ts b/packages/api/src/mcp/registry/MCPServerInspector.ts index 7f31211680..f064fbb7e5 100644 --- a/packages/api/src/mcp/registry/MCPServerInspector.ts +++ b/packages/api/src/mcp/registry/MCPServerInspector.ts @@ -4,9 +4,9 @@ import type { MCPConnection } from '~/mcp/connection'; import type * as t from '~/mcp/types'; import { isMCPDomainAllowed, extractMCPServerDomain } from '~/auth/domain'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; +import { hasCustomUserVars, isUserSourced } from '~/mcp/utils'; import { MCPDomainNotAllowedError } from '~/mcp/errors'; import { detectOAuthRequirement } from '~/mcp/oauth'; -import { hasCustomUserVars } from '~/mcp/utils'; import { isEnabled } from '~/utils'; /** @@ -73,7 +73,7 @@ export class MCPServerInspector { this.connection = await MCPConnectionFactory.create({ serverConfig: this.config, serverName: this.serverName, - dbSourced: !!this.config.dbId, + dbSourced: isUserSourced(this.config), useSSRFProtection: this.useSSRFProtection, allowedDomains: this.allowedDomains, }); diff --git a/packages/api/src/mcp/registry/MCPServersRegistry.ts b/packages/api/src/mcp/registry/MCPServersRegistry.ts index b9c1eb66f5..6c98a6b8dd 100644 --- a/packages/api/src/mcp/registry/MCPServersRegistry.ts +++ b/packages/api/src/mcp/registry/MCPServersRegistry.ts @@ -1,28 +1,48 @@ import { Keyv } from 'keyv'; +import { createHash } from 'crypto'; import { logger } from '@librechat/data-schemas'; import type { IServerConfigsRepositoryInterface } from './ServerConfigsRepositoryInterface'; import type * as t from '~/mcp/types'; -import { ServerConfigsCacheFactory, APP_CACHE_NAMESPACE } from './cache/ServerConfigsCacheFactory'; +import { + ServerConfigsCacheFactory, + APP_CACHE_NAMESPACE, + CONFIG_CACHE_NAMESPACE, +} from './cache/ServerConfigsCacheFactory'; import { MCPInspectionFailedError, isMCPDomainNotAllowedError } from '~/mcp/errors'; import { MCPServerInspector } from './MCPServerInspector'; import { ServerConfigsDB } from './db/ServerConfigsDB'; import { cacheConfig } from '~/cache/cacheConfig'; +import { withTimeout } from '~/utils'; + +/** How long a failure stub is considered fresh before re-attempting inspection (5 minutes). */ +const CONFIG_STUB_RETRY_MS = 5 * 60 * 1000; + +const CONFIG_SERVER_INIT_TIMEOUT_MS = (() => { + const raw = process.env.MCP_INIT_TIMEOUT_MS; + if (raw == null) { + return 30_000; + } + const parsed = parseInt(raw, 10); + return Number.isFinite(parsed) && parsed > 0 ? parsed : 30_000; +})(); /** * Central registry for managing MCP server configurations. * Authoritative source of truth for all MCP servers provided by LibreChat. * - * Uses a two-repository architecture: - * - Cache Repository: Stores YAML-defined configs loaded at startup (in-memory or Redis-backed) - * - DB Repository: Stores dynamic configs created at runtime (not yet implemented) + * Uses a three-layer architecture: + * - YAML Cache (cacheConfigsRepo): Operator-defined configs loaded at startup (in-memory or Redis) + * - Config Cache (configCacheRepo): Admin-defined configs from Config overrides, lazily initialized + * - DB Repository (dbConfigsRepo): User-provided configs created at runtime (MongoDB + ACL) * - * Query priority: Cache configs are checked first, then DB configs. + * Query priority: YAML cache → Config cache → DB. */ export class MCPServersRegistry { private static instance: MCPServersRegistry; private readonly dbConfigsRepo: IServerConfigsRepositoryInterface; private readonly cacheConfigsRepo: IServerConfigsRepositoryInterface; + private readonly configCacheRepo: IServerConfigsRepositoryInterface; private readonly allowedDomains?: string[] | null; private readonly readThroughCache: Keyv; private readonly readThroughCacheAll: Keyv>; @@ -31,9 +51,20 @@ export class MCPServersRegistry { Promise> >(); + /** Tracks in-flight config server initializations to prevent duplicate work. */ + private readonly pendingConfigInits = new Map< + string, + Promise + >(); + + /** Memoized YAML server names — set once after boot-time init, never changes. */ + private yamlServerNames: Set | null = null; + private yamlServerNamesPromise: Promise> | null = null; + constructor(mongoose: typeof import('mongoose'), allowedDomains?: string[] | null) { this.dbConfigsRepo = new ServerConfigsDB(mongoose); this.cacheConfigsRepo = ServerConfigsCacheFactory.create(APP_CACHE_NAMESPACE, false); + this.configCacheRepo = ServerConfigsCacheFactory.create(CONFIG_CACHE_NAMESPACE, false); this.allowedDomains = allowedDomains; const ttl = cacheConfig.MCP_REGISTRY_CACHE_TTL; @@ -86,22 +117,29 @@ export class MCPServersRegistry { return !Array.isArray(this.allowedDomains) || this.allowedDomains.length === 0; } + /** + * Returns the config for a single server. When `configServers` is provided, config-source + * servers are resolved from it directly (no global state, no cross-tenant race). + */ public async getServerConfig( serverName: string, userId?: string, + configServers?: Record, ): Promise { + if (configServers?.[serverName]) { + return configServers[serverName]; + } + const cacheKey = this.getReadThroughCacheKey(serverName, userId); if (await this.readThroughCache.has(cacheKey)) { return await this.readThroughCache.get(cacheKey); } - // First we check if any config exist with the cache - // Yaml config are pre loaded to the cache - const configFromCache = await this.cacheConfigsRepo.get(serverName); - if (configFromCache) { - await this.readThroughCache.set(cacheKey, configFromCache); - return configFromCache; + const configFromYaml = await this.cacheConfigsRepo.get(serverName); + if (configFromYaml) { + await this.readThroughCache.set(cacheKey, configFromYaml); + return configFromYaml; } const configFromDB = await this.dbConfigsRepo.get(serverName, userId); @@ -109,7 +147,30 @@ export class MCPServersRegistry { return configFromDB; } - public async getAllServerConfigs(userId?: string): Promise> { + /** + * Returns all server configs visible to the given user. + * YAML and Config tiers are mutually exclusive by design (`ensureConfigServers` filters + * YAML names), so the spread order only matters for User DB (highest priority) overriding both. + */ + public async getAllServerConfigs( + userId?: string, + configServers?: Record, + ): Promise> { + if (configServers == null || !Object.keys(configServers).length) { + return this.getBaseServerConfigs(userId); + } + const base = await this.getBaseServerConfigs(userId); + return { ...configServers, ...base }; + } + + /** + * Returns YAML + user-DB server configs, cached via `readThroughCacheAll`. + * Always called by `getAllServerConfigs` so the DB query is amortized across + * requests within the TTL window regardless of whether `configServers` is present. + */ + private async getBaseServerConfigs( + userId?: string, + ): Promise> { const cacheKey = userId ?? '__no_user__'; if (await this.readThroughCacheAll.has(cacheKey)) { @@ -121,7 +182,7 @@ export class MCPServersRegistry { return pending; } - const fetchPromise = this.fetchAllServerConfigs(cacheKey, userId); + const fetchPromise = this.fetchBaseServerConfigs(cacheKey, userId); this.pendingGetAllPromises.set(cacheKey, fetchPromise); try { @@ -131,7 +192,7 @@ export class MCPServersRegistry { } } - private async fetchAllServerConfigs( + private async fetchBaseServerConfigs( cacheKey: string, userId?: string, ): Promise> { @@ -155,7 +216,8 @@ export class MCPServersRegistry { userId?: string, ): Promise { const configRepo = this.getConfigRepository(storageLocation); - const stubConfig: t.ParsedServerConfig = { ...config, inspectionFailed: true }; + const source: t.MCPServerSource = storageLocation === 'CACHE' ? 'yaml' : 'user'; + const stubConfig: t.ParsedServerConfig = { ...config, inspectionFailed: true, source }; const result = await configRepo.add(serverName, stubConfig, userId); await this.readThroughCache.delete(this.getReadThroughCacheKey(serverName, userId)); await this.readThroughCache.delete(this.getReadThroughCacheKey(serverName)); @@ -179,13 +241,16 @@ export class MCPServersRegistry { ); } catch (error) { logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error); - // Preserve domain-specific error for better error handling if (isMCPDomainNotAllowedError(error)) { throw error; } throw new MCPInspectionFailedError(serverName, error as Error); } - return await configRepo.add(serverName, parsedConfig, userId); + const tagged = { + ...parsedConfig, + source: (storageLocation === 'CACHE' ? 'yaml' : 'user') as t.MCPServerSource, + }; + return await configRepo.add(serverName, tagged, userId); } /** @@ -267,7 +332,6 @@ export class MCPServersRegistry { ); } catch (error) { logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error); - // Preserve domain-specific error for better error handling if (isMCPDomainNotAllowedError(error)) { throw error; } @@ -277,8 +341,180 @@ export class MCPServersRegistry { return parsedConfig; } - // TODO: This is currently used to determine if a server requires OAuth. However, this info can - // can be determined through config.requiresOAuth. Refactor usages and remove this method. + /** + * Ensures that config-source MCP servers (from admin Config overrides) are initialized. + * Identifies servers in `resolvedMcpConfig` that are not from YAML, lazily initializes + * any not yet in the config cache, and returns their parsed configs. + * + * Config cache keys are scoped by a hash of the raw config to prevent cross-tenant + * cache poisoning when two tenants define a server with the same name but different configs. + */ + public async ensureConfigServers( + resolvedMcpConfig: Record, + ): Promise> { + if (!resolvedMcpConfig || Object.keys(resolvedMcpConfig).length === 0) { + return {}; + } + + const yamlNames = await this.getYamlServerNames(); + const configServerEntries = Object.entries(resolvedMcpConfig).filter( + ([name]) => !yamlNames.has(name), + ); + + if (configServerEntries.length === 0) { + return {}; + } + + const result: Record = {}; + + const settled = await Promise.allSettled( + configServerEntries.map(async ([serverName, rawConfig]) => { + const parsed = await this.ensureSingleConfigServer(serverName, rawConfig); + if (parsed) { + result[serverName] = parsed; + } + }), + ); + for (const outcome of settled) { + if (outcome.status === 'rejected') { + logger.error('[MCPServersRegistry][ensureConfigServers] Unexpected error:', outcome.reason); + } + } + + return result; + } + + /** + * Ensures a single config-source server is initialized. + * Cache key is scoped by config hash to prevent cross-tenant poisoning. + * Deduplicates concurrent init requests for the same server+config. + * Stale failure stubs are retried after `CONFIG_STUB_RETRY_MS` to recover from transient errors. + */ + private async ensureSingleConfigServer( + serverName: string, + rawConfig: t.MCPOptions, + ): Promise { + const cacheKey = this.configCacheKey(serverName, rawConfig); + + const cached = await this.configCacheRepo.get(cacheKey); + if (cached) { + const isStaleStub = + cached.inspectionFailed && Date.now() - (cached.updatedAt ?? 0) > CONFIG_STUB_RETRY_MS; + if (!isStaleStub) { + return cached; + } + logger.info(`[MCP][config][${serverName}] Retrying stale failure stub`); + } + + const pending = this.pendingConfigInits.get(cacheKey); + if (pending) { + return pending; + } + + const initPromise = this.lazyInitConfigServer(cacheKey, serverName, rawConfig); + this.pendingConfigInits.set(cacheKey, initPromise); + + try { + return await initPromise; + } finally { + this.pendingConfigInits.delete(cacheKey); + } + } + + /** + * Lazily initializes a config-source MCP server: inspects capabilities/tools, then + * stores the parsed config in the config cache with `source: 'config'`. + */ + private async lazyInitConfigServer( + cacheKey: string, + serverName: string, + rawConfig: t.MCPOptions, + ): Promise { + const prefix = `[MCP][config][${serverName}]`; + logger.info(`${prefix} Lazy-initializing config-source server`); + + try { + const inspected = await withTimeout( + MCPServerInspector.inspect(serverName, rawConfig, undefined, this.allowedDomains), + CONFIG_SERVER_INIT_TIMEOUT_MS, + `${prefix} Server initialization timed out`, + ); + + const parsedConfig: t.ParsedServerConfig = { ...inspected, source: 'config' }; + await this.upsertConfigCache(cacheKey, parsedConfig); + + logger.info( + `${prefix} Initialized: tools=${parsedConfig.tools ?? 'N/A'}, ` + + `duration=${parsedConfig.initDuration ?? 'N/A'}ms`, + ); + return parsedConfig; + } catch (error) { + logger.error(`${prefix} Failed to initialize:`, error); + + const stubConfig: t.ParsedServerConfig = { + ...rawConfig, + inspectionFailed: true, + source: 'config', + updatedAt: Date.now(), + }; + try { + await this.upsertConfigCache(cacheKey, stubConfig); + logger.info(`${prefix} Stored stub config for recovery`); + } catch (cacheError) { + logger.error( + `${prefix} Failed to store stub config (will retry on next request):`, + cacheError, + ); + } + return stubConfig; + } + } + + /** + * Writes a config to `configCacheRepo` using the atomic upsert operation. + * Safe for cross-process races — the underlying cache handles add-or-update internally. + */ + private async upsertConfigCache(cacheKey: string, config: t.ParsedServerConfig): Promise { + await this.configCacheRepo.upsert(cacheKey, config); + } + + /** + * Clears the config-source server cache, forcing re-inspection on next access. + * Called when admin config overrides change (e.g., mcpServers mutation). + * + * @returns Names of servers that were evicted from the config cache. + * Callers should disconnect active connections for these servers. + */ + public async invalidateConfigCache(): Promise { + const allCached = await this.configCacheRepo.getAll(); + const evictedNames = [ + ...new Set( + Object.keys(allCached).map((key) => { + const lastColon = key.lastIndexOf(':'); + return lastColon > 0 ? key.slice(0, lastColon) : key; + }), + ), + ]; + + await Promise.all([ + this.configCacheRepo.reset(), + // Only clear readThroughCacheAll (merged results that may include stale config servers). + // readThroughCache (individual YAML/user lookups) is unaffected by config mutations. + this.readThroughCacheAll.clear(), + ]); + + if (evictedNames.length > 0) { + logger.info( + `[MCPServersRegistry] Config server cache invalidated, evicted: ${evictedNames.join(', ')}`, + ); + } + return evictedNames; + } + + // TODO: Refactor callers to use config.requiresOAuth directly instead of this method. + // Known gap: config-source OAuth servers are not included here because callers + // (OAuthReconnectionManager, UserController) lack request context to resolve configServers. + // Config-source OAuth auto-reconnection and uninstall cleanup require a separate mechanism. public async getOAuthServers(userId?: string): Promise> { const allServers = await this.getAllServerConfigs(userId); const oauthServers = Object.entries(allServers).filter(([, config]) => config.requiresOAuth); @@ -287,8 +523,11 @@ export class MCPServersRegistry { public async reset(): Promise { await this.cacheConfigsRepo.reset(); + await this.configCacheRepo.reset(); await this.readThroughCache.clear(); await this.readThroughCacheAll.clear(); + this.yamlServerNames = null; + this.yamlServerNamesPromise = null; } public async removeServer( @@ -316,4 +555,48 @@ export class MCPServersRegistry { private getReadThroughCacheKey(serverName: string, userId?: string): string { return userId ? `${serverName}::${userId}` : serverName; } + + /** + * Returns memoized YAML server names. Populated lazily on first call after boot/reset. + * YAML servers don't change after boot, so this avoids repeated `getAll()` calls. + * Uses promise deduplication to prevent concurrent cold-start double-fetch. + */ + private getYamlServerNames(): Promise> { + if (this.yamlServerNames) { + return Promise.resolve(this.yamlServerNames); + } + if (this.yamlServerNamesPromise) { + return this.yamlServerNamesPromise; + } + this.yamlServerNamesPromise = this.cacheConfigsRepo + .getAll() + .then((configs) => { + this.yamlServerNames = new Set(Object.keys(configs)); + this.yamlServerNamesPromise = null; + return this.yamlServerNames; + }) + .catch((err) => { + this.yamlServerNamesPromise = null; + throw err; + }); + return this.yamlServerNamesPromise; + } + + /** + * Produces a config-cache key scoped by server name AND a hash of the raw config. + * Prevents cross-tenant cache poisoning when two tenants define the same server name + * with different configurations. + */ + private configCacheKey(serverName: string, rawConfig: t.MCPOptions): string { + const sorted = JSON.stringify(rawConfig, (_key, value: unknown) => { + if (value !== null && typeof value === 'object' && !Array.isArray(value)) { + return Object.fromEntries( + Object.entries(value as Record).sort(([a], [b]) => a.localeCompare(b)), + ); + } + return value; + }); + const hash = createHash('sha256').update(sorted).digest('hex').slice(0, 16); + return `${serverName}:${hash}`; + } } diff --git a/packages/api/src/mcp/registry/ServerConfigsRepositoryInterface.ts b/packages/api/src/mcp/registry/ServerConfigsRepositoryInterface.ts index 1c913dd1a3..4bf0fdd615 100644 --- a/packages/api/src/mcp/registry/ServerConfigsRepositoryInterface.ts +++ b/packages/api/src/mcp/registry/ServerConfigsRepositoryInterface.ts @@ -9,6 +9,9 @@ export interface IServerConfigsRepositoryInterface { //ACL Entry check if update is possible update(serverName: string, config: ParsedServerConfig, userId?: string): Promise; + /** Atomic add-or-update without requiring callers to inspect error messages. */ + upsert(serverName: string, config: ParsedServerConfig, userId?: string): Promise; + //ACL Entry check if remove is possible remove(serverName: string, userId?: string): Promise; diff --git a/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts b/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts index f0ab75c9b4..2012f82e31 100644 --- a/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts +++ b/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts @@ -321,12 +321,12 @@ describe('MCPServerInspector', () => { const result = await MCPServerInspector.inspect('test_server', rawConfig); // Verify factory was called to create connection - expect(MCPConnectionFactory.create).toHaveBeenCalledWith({ - serverName: 'test_server', - serverConfig: expect.objectContaining({ type: 'stdio', command: 'node' }), - useSSRFProtection: true, - dbSourced: false, - }); + expect(MCPConnectionFactory.create).toHaveBeenCalledWith( + expect.objectContaining({ + serverName: 'test_server', + serverConfig: expect.objectContaining({ type: 'stdio', command: 'node' }), + }), + ); // Verify temporary connection was disconnected expect(tempMockConnection.disconnect).toHaveBeenCalled(); diff --git a/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts b/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts index 8891120717..a20c09705f 100644 --- a/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts +++ b/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts @@ -112,8 +112,8 @@ describe('MCPServersRegistry', () => { const userConfigBefore = await registry.getServerConfig('user_server'); const allConfigsBefore = await registry.getAllServerConfigs(); - expect(appConfigBefore).toEqual(testParsedConfig); - expect(userConfigBefore).toEqual(testParsedConfig); + expect(appConfigBefore).toEqual(expect.objectContaining(testParsedConfig)); + expect(userConfigBefore).toEqual(expect.objectContaining(testParsedConfig)); expect(Object.keys(allConfigsBefore)).toHaveLength(2); // Reset everything @@ -250,22 +250,18 @@ describe('MCPServersRegistry', () => { }); it('should use different cache keys for different userIds', async () => { - // Spy on the cache repository get method + await registry['cacheConfigsRepo'].add('test_server', testParsedConfig); const cacheRepoGetSpy = jest.spyOn(registry['cacheConfigsRepo'], 'get'); - // First call without userId await registry.getServerConfig('test_server'); expect(cacheRepoGetSpy).toHaveBeenCalledTimes(1); - // Call with userId - should be a different cache key, so hits repository again await registry.getServerConfig('test_server', 'user123'); expect(cacheRepoGetSpy).toHaveBeenCalledTimes(2); - // Repeat call with same userId - should hit read-through cache await registry.getServerConfig('test_server', 'user123'); - expect(cacheRepoGetSpy).toHaveBeenCalledTimes(2); // Still 2 + expect(cacheRepoGetSpy).toHaveBeenCalledTimes(2); - // Call with different userId - should hit repository await registry.getServerConfig('test_server', 'user456'); expect(cacheRepoGetSpy).toHaveBeenCalledTimes(3); }); diff --git a/packages/api/src/mcp/registry/__tests__/ensureConfigServers.test.ts b/packages/api/src/mcp/registry/__tests__/ensureConfigServers.test.ts new file mode 100644 index 0000000000..70eb2f75c4 --- /dev/null +++ b/packages/api/src/mcp/registry/__tests__/ensureConfigServers.test.ts @@ -0,0 +1,328 @@ +import type * as t from '~/mcp/types'; +import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; +import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector'; + +jest.mock('~/mcp/registry/MCPServerInspector'); +jest.mock('~/mcp/registry/db/ServerConfigsDB', () => ({ + ServerConfigsDB: jest.fn().mockImplementation(() => ({ + get: jest.fn().mockResolvedValue(undefined), + getAll: jest.fn().mockResolvedValue({}), + add: jest.fn().mockResolvedValue(undefined), + update: jest.fn().mockResolvedValue(undefined), + upsert: jest.fn().mockResolvedValue(undefined), + remove: jest.fn().mockResolvedValue(undefined), + reset: jest.fn().mockResolvedValue(undefined), + })), +})); + +const FIXED_TIME = 1699564800000; + +const mockMongoose = {} as typeof import('mongoose'); + +const sseConfig: t.MCPOptions = { + type: 'sse', + url: 'https://mcp.example.com/sse', +} as unknown as t.MCPOptions; + +const altSseConfig: t.MCPOptions = { + type: 'sse', + url: 'https://mcp.other-tenant.com/sse', +} as unknown as t.MCPOptions; + +const yamlConfig: t.MCPOptions = { + type: 'stdio', + command: 'node', + args: ['tools.js'], +} as unknown as t.MCPOptions; + +function makeParsedConfig(overrides: Partial = {}): t.ParsedServerConfig { + return { + type: 'sse', + url: 'https://mcp.example.com/sse', + requiresOAuth: false, + tools: 'tool_a, tool_b', + capabilities: '{}', + initDuration: 42, + ...overrides, + } as unknown as t.ParsedServerConfig; +} + +describe('MCPServersRegistry — ensureConfigServers', () => { + let registry: MCPServersRegistry; + let inspectSpy: jest.SpyInstance; + + beforeAll(() => { + jest.useFakeTimers(); + jest.setSystemTime(new Date(FIXED_TIME)); + }); + + afterAll(() => { + jest.useRealTimers(); + }); + + beforeEach(async () => { + (MCPServersRegistry as unknown as { instance: undefined }).instance = undefined; + MCPServersRegistry.createInstance(mockMongoose); + registry = MCPServersRegistry.getInstance(); + + inspectSpy = jest + .spyOn(MCPServerInspector, 'inspect') + .mockImplementation(async (_serverName: string, rawConfig: t.MCPOptions) => + makeParsedConfig(rawConfig as unknown as Partial), + ); + + await registry.reset(); + }); + + afterEach(() => { + inspectSpy.mockClear(); + }); + + it('should return empty for empty input', async () => { + expect(await registry.ensureConfigServers({})).toEqual({}); + }); + + it('should return empty for null/undefined input', async () => { + expect( + await registry.ensureConfigServers(null as unknown as Record), + ).toEqual({}); + expect( + await registry.ensureConfigServers(undefined as unknown as Record), + ).toEqual({}); + }); + + it('should exclude YAML servers from config-source detection', async () => { + await registry.addServer('yaml_server', yamlConfig, 'CACHE'); + + const result = await registry.ensureConfigServers({ + yaml_server: yamlConfig, + config_server: sseConfig, + }); + + expect(result).toHaveProperty('config_server'); + expect(result).not.toHaveProperty('yaml_server'); + }); + + it('should return empty when all servers are YAML', async () => { + await registry.addServer('yaml_a', yamlConfig, 'CACHE'); + await registry.addServer('yaml_b', yamlConfig, 'CACHE'); + inspectSpy.mockClear(); + + const result = await registry.ensureConfigServers({ + yaml_a: yamlConfig, + yaml_b: yamlConfig, + }); + + expect(result).toEqual({}); + expect(inspectSpy).not.toHaveBeenCalled(); + }); + + it('should lazy-initialize a config-source server and tag source as config', async () => { + const result = await registry.ensureConfigServers({ my_server: sseConfig }); + + expect(result).toHaveProperty('my_server'); + expect(result.my_server.source).toBe('config'); + expect(inspectSpy).toHaveBeenCalledTimes(1); + expect(inspectSpy).toHaveBeenCalledWith('my_server', sseConfig, undefined, undefined); + }); + + it('should return cached result on second call without re-inspecting', async () => { + await registry.ensureConfigServers({ my_server: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(1); + + const result2 = await registry.ensureConfigServers({ my_server: sseConfig }); + expect(result2).toHaveProperty('my_server'); + expect(result2.my_server.source).toBe('config'); + expect(inspectSpy).toHaveBeenCalledTimes(1); + }); + + it('should store inspectionFailed stub on inspection failure', async () => { + inspectSpy.mockRejectedValueOnce(new Error('connection refused')); + + const result = await registry.ensureConfigServers({ bad_server: sseConfig }); + + expect(result).toHaveProperty('bad_server'); + expect(result.bad_server.inspectionFailed).toBe(true); + expect(result.bad_server.source).toBe('config'); + }); + + it('should return stub from cache on repeated failure without re-inspecting', async () => { + inspectSpy.mockRejectedValueOnce(new Error('connection refused')); + await registry.ensureConfigServers({ bad_server: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(1); + + const result2 = await registry.ensureConfigServers({ bad_server: sseConfig }); + expect(result2.bad_server.inspectionFailed).toBe(true); + expect(inspectSpy).toHaveBeenCalledTimes(1); + }); + + it('should retry stale failure stub after CONFIG_STUB_RETRY_MS', async () => { + inspectSpy.mockRejectedValueOnce(new Error('transient DNS failure')); + await registry.ensureConfigServers({ flaky_server: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(1); + + jest.setSystemTime(new Date(FIXED_TIME + 6 * 60 * 1000)); + + const result = await registry.ensureConfigServers({ flaky_server: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(2); + expect(result.flaky_server.inspectionFailed).toBeUndefined(); + expect(result.flaky_server.source).toBe('config'); + + jest.setSystemTime(new Date(FIXED_TIME)); + }); + + describe('cross-tenant isolation', () => { + it('should use different cache keys for same server name with different configs', async () => { + inspectSpy.mockClear(); + const resultA = await registry.ensureConfigServers({ shared_name: sseConfig }); + expect(resultA.shared_name.source).toBe('config'); + expect(inspectSpy).toHaveBeenCalledTimes(1); + + const resultB = await registry.ensureConfigServers({ shared_name: altSseConfig }); + expect(resultB.shared_name.source).toBe('config'); + expect(inspectSpy).toHaveBeenCalledTimes(2); + }); + + it('should return tenant-A config for tenant-A and tenant-B config for tenant-B', async () => { + const resultA = await registry.ensureConfigServers({ srv: sseConfig }); + const resultB = await registry.ensureConfigServers({ srv: altSseConfig }); + + expect((resultA.srv as unknown as { url: string }).url).toBe('https://mcp.example.com/sse'); + expect((resultB.srv as unknown as { url: string }).url).toBe( + 'https://mcp.other-tenant.com/sse', + ); + }); + }); + + describe('concurrent deduplication', () => { + it('should only inspect once for multiple parallel calls with the same config', async () => { + inspectSpy.mockClear(); + // Fire two calls simultaneously — both see cache miss, but only one should inspect + const [r1, r2] = await Promise.all([ + registry.ensureConfigServers({ dedup_srv: sseConfig }), + registry.ensureConfigServers({ dedup_srv: sseConfig }), + ]); + + expect(r1.dedup_srv).toBeDefined(); + expect(r2.dedup_srv).toBeDefined(); + expect(inspectSpy).toHaveBeenCalledTimes(1); + + // Subsequent call must NOT re-inspect (cached) + inspectSpy.mockClear(); + await registry.ensureConfigServers({ dedup_srv: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(0); + }); + }); + + describe('merge order', () => { + it('should merge YAML → config → user with correct precedence in getAllServerConfigs', async () => { + await registry.addServer('yaml_srv', yamlConfig, 'CACHE'); + + const configServers = await registry.ensureConfigServers({ config_srv: sseConfig }); + + const all = await registry.getAllServerConfigs(undefined, configServers); + expect(all).toHaveProperty('yaml_srv'); + expect(all).toHaveProperty('config_srv'); + expect(all.yaml_srv.source).toBe('yaml'); + expect(all.config_srv.source).toBe('config'); + }); + + it('should let config servers appear alongside user DB servers', async () => { + const mockDbConfigs = { + user_srv: makeParsedConfig({ source: 'user', dbId: 'abc123' }), + }; + jest.spyOn(registry['dbConfigsRepo'], 'getAll').mockResolvedValue(mockDbConfigs); + + const configServers = await registry.ensureConfigServers({ config_srv: sseConfig }); + const all = await registry.getAllServerConfigs('user-1', configServers); + + expect(all).toHaveProperty('config_srv'); + expect(all).toHaveProperty('user_srv'); + expect(all.config_srv.source).toBe('config'); + expect(all.user_srv.source).toBe('user'); + }); + }); + + describe('invalidateConfigCache', () => { + it('should clear config cache and force re-inspection on next call', async () => { + await registry.ensureConfigServers({ my_server: sseConfig }); + inspectSpy.mockClear(); + + await registry.invalidateConfigCache(); + + await registry.ensureConfigServers({ my_server: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(1); + }); + + it('should return evicted server names', async () => { + await registry.ensureConfigServers({ srv_a: sseConfig, srv_b: altSseConfig }); + const evicted = await registry.invalidateConfigCache(); + expect(evicted.length).toBeGreaterThan(0); + }); + + it('should return empty array when nothing is cached', async () => { + const evicted = await registry.invalidateConfigCache(); + expect(evicted).toEqual([]); + }); + }); + + describe('getServerConfig with configServers', () => { + it('should return config-source server when configServers is passed', async () => { + const configServers = await registry.ensureConfigServers({ config_srv: sseConfig }); + const config = await registry.getServerConfig('config_srv', undefined, configServers); + expect(config).toBeDefined(); + expect(config?.source).toBe('config'); + }); + + it('should return config-source server with userId when configServers is passed', async () => { + const configServers = await registry.ensureConfigServers({ config_srv: sseConfig }); + const config = await registry.getServerConfig('config_srv', 'user-123', configServers); + expect(config).toBeDefined(); + expect(config?.source).toBe('config'); + }); + + it('should return undefined for config-source server without configServers (tenant isolation)', async () => { + await registry.ensureConfigServers({ config_srv: sseConfig }); + const config = await registry.getServerConfig('config_srv'); + expect(config).toBeUndefined(); + }); + + it('should return correct config after invalidation and re-init', async () => { + const configServers1 = await registry.ensureConfigServers({ config_srv: sseConfig }); + expect(await registry.getServerConfig('config_srv', undefined, configServers1)).toBeDefined(); + + await registry.invalidateConfigCache(); + + const configServers2 = await registry.ensureConfigServers({ config_srv: sseConfig }); + const config = await registry.getServerConfig('config_srv', undefined, configServers2); + expect(config).toBeDefined(); + expect(config?.source).toBe('config'); + }); + + it('should not cross-contaminate between tenant configServers maps', async () => { + const tenantA = await registry.ensureConfigServers({ srv: sseConfig }); + const tenantB = await registry.ensureConfigServers({ srv: altSseConfig }); + + const configA = await registry.getServerConfig('srv', undefined, tenantA); + const configB = await registry.getServerConfig('srv', undefined, tenantB); + + expect((configA as unknown as { url: string }).url).toBe('https://mcp.example.com/sse'); + expect((configB as unknown as { url: string }).url).toBe('https://mcp.other-tenant.com/sse'); + }); + }); + + describe('source tagging', () => { + it('should tag CACHE-stored servers as yaml', async () => { + await registry.addServer('yaml_srv', yamlConfig, 'CACHE'); + const config = await registry.getServerConfig('yaml_srv'); + expect(config?.source).toBe('yaml'); + }); + + it('should tag stubs as yaml when stored in CACHE', async () => { + await registry.addServerStub('stub_srv', yamlConfig, 'CACHE'); + const config = await registry.getServerConfig('stub_srv'); + expect(config?.source).toBe('yaml'); + expect(config?.inspectionFailed).toBe(true); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts index b9549629d6..ebe19b59e3 100644 --- a/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts @@ -16,11 +16,17 @@ export type ServerConfigsCache = */ export const APP_CACHE_NAMESPACE = 'App' as const; +/** Namespace for admin-defined config-override MCP server inspection results. */ +export const CONFIG_CACHE_NAMESPACE = 'Config' as const; + +/** Namespaces that use the aggregate-key optimization to avoid SCAN+N-GETs stalls. */ +const AGGREGATE_KEY_NAMESPACES = new Set([APP_CACHE_NAMESPACE, CONFIG_CACHE_NAMESPACE]); + /** * Factory for creating the appropriate ServerConfigsCache implementation based on * deployment mode and namespace. * - * The {@link APP_CACHE_NAMESPACE} namespace uses {@link ServerConfigsCacheRedisAggregateKey} + * Namespaces in {@link AGGREGATE_KEY_NAMESPACES} use {@link ServerConfigsCacheRedisAggregateKey} * when Redis is enabled — storing all configs under a single key so `getAll()` is one GET * instead of SCAN + N GETs. Cross-instance visibility is preserved: reinspection results * propagate through Redis automatically. @@ -32,8 +38,8 @@ export class ServerConfigsCacheFactory { /** * Create a ServerConfigsCache instance. * - * @param namespace - The namespace for the cache. {@link APP_CACHE_NAMESPACE} uses - * aggregate-key Redis storage (or in-memory when Redis is disabled). + * @param namespace - The namespace for the cache. Namespaces in {@link AGGREGATE_KEY_NAMESPACES} + * use aggregate-key Redis storage (or in-memory when Redis is disabled). * @param leaderOnly - Whether write operations should only be performed by the leader. * @returns ServerConfigsCache instance */ @@ -42,7 +48,7 @@ export class ServerConfigsCacheFactory { return new ServerConfigsCacheInMemory(); } - if (namespace === APP_CACHE_NAMESPACE) { + if (AGGREGATE_KEY_NAMESPACES.has(namespace)) { return new ServerConfigsCacheRedisAggregateKey(namespace, leaderOnly); } diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts index 384c477756..5a7fd35b9f 100644 --- a/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts @@ -28,6 +28,10 @@ export class ServerConfigsCacheInMemory { this.cache.set(serverName, { ...config, updatedAt: Date.now() }); } + public async upsert(serverName: string, config: ParsedServerConfig): Promise { + this.cache.set(serverName, { ...config, updatedAt: Date.now() }); + } + public async remove(serverName: string): Promise { if (!this.cache.delete(serverName)) { throw new Error(`Failed to remove server "${serverName}" in cache.`); diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts index d3154baf73..af1316056d 100644 --- a/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts @@ -52,6 +52,12 @@ export class ServerConfigsCacheRedis this.successCheck(`update ${this.namespace} server "${serverName}"`, success); } + public async upsert(serverName: string, config: ParsedServerConfig): Promise { + if (this.leaderOnly) await this.leaderCheck(`upsert ${this.namespace} MCP servers`); + const success = await this.cache.set(serverName, { ...config, updatedAt: Date.now() }); + this.successCheck(`upsert ${this.namespace} server "${serverName}"`, success); + } + public async remove(serverName: string): Promise { if (this.leaderOnly) await this.leaderCheck(`remove ${this.namespace} MCP servers`); const success = await this.cache.delete(serverName); diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts index e67c1a4a84..5fc32bd7aa 100644 --- a/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts @@ -53,8 +53,11 @@ export class ServerConfigsCacheRedisAggregateKey /** Milliseconds since epoch. 0 = epoch = always expired on first check. */ private localSnapshotExpiry = 0; + private readonly namespace: string; + constructor(namespace: string, leaderOnly: boolean) { super(leaderOnly); + this.namespace = namespace; this.cache = standardCache(`${this.PREFIX}::Servers::${namespace}`); } @@ -125,7 +128,7 @@ export class ServerConfigsCacheRedisAggregateKey const storedConfig = { ...config, updatedAt: Date.now() }; const newAll = { ...all, [serverName]: storedConfig }; const success = await this.cache.set(AGGREGATE_KEY, newAll); - this.successCheck(`add App server "${serverName}"`, success); + this.successCheck(`add ${this.namespace} server "${serverName}"`, success); return { serverName, config: storedConfig }; }); } @@ -142,7 +145,18 @@ export class ServerConfigsCacheRedisAggregateKey } const newAll = { ...all, [serverName]: { ...config, updatedAt: Date.now() } }; const success = await this.cache.set(AGGREGATE_KEY, newAll); - this.successCheck(`update App server "${serverName}"`, success); + this.successCheck(`update ${this.namespace} server "${serverName}"`, success); + }); + } + + public async upsert(serverName: string, config: ParsedServerConfig): Promise { + if (this.leaderOnly) await this.leaderCheck('upsert MCP servers'); + return this.withWriteLock(async () => { + this.invalidateLocalSnapshot(); + const all = await this.getAll(); + const newAll = { ...all, [serverName]: { ...config, updatedAt: Date.now() } }; + const success = await this.cache.set(AGGREGATE_KEY, newAll); + this.successCheck(`upsert ${this.namespace} server "${serverName}"`, success); }); } @@ -156,7 +170,7 @@ export class ServerConfigsCacheRedisAggregateKey } const { [serverName]: _, ...newAll } = all; const success = await this.cache.set(AGGREGATE_KEY, newAll); - this.successCheck(`remove App server "${serverName}"`, success); + this.successCheck(`remove ${this.namespace} server "${serverName}"`, success); }); } @@ -171,7 +185,7 @@ export class ServerConfigsCacheRedisAggregateKey */ public override async reset(): Promise { if (this.leaderOnly) { - await this.leaderCheck('reset App MCP servers cache'); + await this.leaderCheck(`reset ${this.namespace} MCP servers cache`); } await this.cache.delete(AGGREGATE_KEY); this.invalidateLocalSnapshot(); diff --git a/packages/api/src/mcp/registry/db/ServerConfigsDB.ts b/packages/api/src/mcp/registry/db/ServerConfigsDB.ts index 9981f6b00b..b1649c66ca 100644 --- a/packages/api/src/mcp/registry/db/ServerConfigsDB.ts +++ b/packages/api/src/mcp/registry/db/ServerConfigsDB.ts @@ -220,6 +220,25 @@ export class ServerConfigsDB implements IServerConfigsRepositoryInterface { await this._dbMethods.updateMCPServer(serverName, { config: configToSave }); } + /** + * Atomic add-or-update. For DB-backed servers this delegates to update since + * DB servers are always created via the explicit add() flow with ACL setup. + * Config-source servers should use configCacheRepo, not dbConfigsRepo. + */ + public async upsert( + serverName: string, + config: ParsedServerConfig, + userId?: string, + ): Promise { + if (!userId) { + throw new Error( + `[ServerConfigsDB.upsert] User ID is required for DB-backed MCP server upsert of "${serverName}". ` + + 'Config-source servers should use configCacheRepo, not dbConfigsRepo.', + ); + } + return this.update(serverName, config, userId); + } + /** * Deletes an MCP server and removes all associated ACL entries. * @param serverName - The serverName of the server to remove @@ -411,6 +430,7 @@ export class ServerConfigsDB implements IServerConfigsRepositoryInterface { const config: ParsedServerConfig = { ...serverDBDoc.config, dbId: (serverDBDoc._id as Types.ObjectId).toString(), + source: 'user', updatedAt: serverDBDoc.updatedAt?.getTime(), }; return await this.decryptConfig(config); diff --git a/packages/api/src/mcp/types/index.ts b/packages/api/src/mcp/types/index.ts index 6cb5e02f0b..32c2787165 100644 --- a/packages/api/src/mcp/types/index.ts +++ b/packages/api/src/mcp/types/index.ts @@ -144,6 +144,14 @@ export type ImageFormatter = (item: ImageContent) => FormattedContent; export type FormattedToolResponse = FormattedContentResult; +/** + * Origin of an MCP server definition. + * - `'yaml'` — operator-defined in librechat.yaml, full trust, boot-time init + * - `'config'` — admin-defined via Config override, full trust, lazy init + * - `'user'` — user-provided via UI, sandboxed (restricted placeholder resolution) + */ +export type MCPServerSource = 'yaml' | 'config' | 'user'; + export type ParsedServerConfig = MCPOptions & { url?: string; requiresOAuth?: boolean; @@ -154,6 +162,8 @@ export type ParsedServerConfig = MCPOptions & { initDuration?: number; updatedAt?: number; dbId?: string; + /** Origin of this server definition — determines trust level and placeholder resolution */ + source?: MCPServerSource; /** True if access is only via agent (not directly shared with user) */ consumeOnly?: boolean; /** True when inspection failed at startup; the server is known but not fully initialized */ @@ -202,6 +212,8 @@ export interface ToolDiscoveryOptions { customUserVars?: Record; requestBody?: RequestBody; connectionTimeout?: number; + /** Pre-resolved config-source servers for tenant-scoped lookup */ + configServers?: Record; } export interface ToolDiscoveryResult { diff --git a/packages/api/src/mcp/utils.ts b/packages/api/src/mcp/utils.ts index db89cffada..653a96d5bd 100644 --- a/packages/api/src/mcp/utils.ts +++ b/packages/api/src/mcp/utils.ts @@ -8,6 +8,15 @@ export function hasCustomUserVars(config: Pick 0; } +/** + * Determines whether a server config is user-sourced (sandboxed placeholder resolution). + * When `source` is set, it is authoritative. When absent (pre-upgrade cached configs), + * falls back to the legacy `dbId` heuristic for backward compatibility. + */ +export function isUserSourced(config: Pick): boolean { + return config.source != null ? config.source === 'user' : !!config.dbId; +} + /** * Allowlist-based sanitization for API responses. Only explicitly listed fields are included; * new fields added to ParsedServerConfig are excluded by default until allowlisted here. @@ -31,6 +40,8 @@ export function redactServerSecrets(config: ParsedServerConfig): Partial