From 95a69df70eaa79d1c3605a1b705bfc9905c38d1b Mon Sep 17 00:00:00 2001 From: Atef Bellaaj Date: Thu, 18 Dec 2025 19:57:49 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=92=20feat:=20Add=20MCP=20server=20dom?= =?UTF-8?q?ain=20restrictions=20for=20remote=20transports=20(#11013)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🔒 feat: Add MCP server domain restrictions for remote transports * 🔒 feat: Implement comprehensive MCP error handling and domain validation - Added `handleMCPError` function to centralize error responses for domain restrictions and inspection failures. - Introduced custom error classes: `MCPDomainNotAllowedError` and `MCPInspectionFailedError` for better error management. - Updated MCP server controllers to utilize the new error handling mechanism. - Enhanced domain validation logic in `createMCPTools` and `createMCPTool` functions to prevent operations on disallowed domains. - Added tests for runtime domain validation scenarios to ensure correct behavior. * chore: import order * 🔒 feat: Enhance domain validation in MCP tools with user role-based restrictions - Integrated `getAppConfig` to fetch allowed domains based on user roles in `createMCPTools` and `createMCPTool` functions. - Removed the deprecated `getAllowedDomains` method from `MCPServersRegistry`. - Updated tests to verify domain restrictions are applied correctly based on user roles. - Ensured that domain validation logic is consistent and efficient across tool creation processes. * 🔒 test: Refactor MCP tests to utilize configurable app settings - Introduced a mock for `getAppConfig` to enhance test flexibility. - Removed redundant mock definition to streamline test setup. - Ensured tests are aligned with the latest domain validation logic. --------- Co-authored-by: Atef Bellaaj Co-authored-by: Danny Avila --- api/app/clients/tools/util/handleTools.js | 11 +- api/server/controllers/mcp.js | 60 +++- api/server/routes/__tests__/mcp.spec.js | 50 +-- api/server/services/MCP.js | 49 ++- api/server/services/MCP.spec.js | 297 ++++++++++++++++-- api/server/services/initializeMCPs.js | 3 +- .../SidePanel/MCPBuilder/MCPServerDialog.tsx | 2 + client/src/locales/en/translation.json | 1 + librechat.example.yaml | 10 + packages/api/src/auth/domain.spec.ts | 213 ++++++++++++- packages/api/src/auth/domain.ts | 42 +++ packages/api/src/index.ts | 1 + packages/api/src/mcp/errors.ts | 61 ++++ .../src/mcp/registry/MCPServerInspector.ts | 11 + .../src/mcp/registry/MCPServersRegistry.ts | 38 ++- .../__tests__/MCPServersInitializer.test.ts | 28 +- packages/data-provider/src/config.ts | 5 + packages/data-schemas/src/app/service.ts | 6 +- packages/data-schemas/src/types/app.ts | 2 + 19 files changed, 815 insertions(+), 75 deletions(-) create mode 100644 packages/api/src/mcp/errors.ts diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index 15ccd38129..bae7255d97 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -348,10 +348,10 @@ Anchor pattern: \\ue202turn{N}{type}{index} where N=turn number, type=search|new /** Placeholder used for UI purposes */ continue; } - if ( - serverName && - (await getMCPServersRegistry().getServerConfig(serverName, user)) == undefined - ) { + const serverConfig = serverName + ? await getMCPServersRegistry().getServerConfig(serverName, user) + : null; + if (!serverConfig) { logger.warn( `MCP server "${serverName}" for "${toolName}" tool is not configured${agent?.id != null && agent.id ? ` but attached to "${agent.id}"` : ''}`, ); @@ -362,6 +362,7 @@ Anchor pattern: \\ue202turn{N}{type}{index} where N=turn number, type=search|new { type: 'all', serverName, + config: serverConfig, }, ]; continue; @@ -372,6 +373,7 @@ Anchor pattern: \\ue202turn{N}{type}{index} where N=turn number, type=search|new type: 'single', toolKey: tool, serverName, + config: serverConfig, }); continue; } @@ -435,6 +437,7 @@ Anchor pattern: \\ue202turn{N}{type}{index} where N=turn number, type=search|new model: agent?.model ?? model, serverName: config.serverName, provider: agent?.provider ?? endpoint, + config: config.config, }; if (config.type === 'all' && toolConfigs.length === 1) { diff --git a/api/server/controllers/mcp.js b/api/server/controllers/mcp.js index 1afd7095a6..e5dfff61ca 100644 --- a/api/server/controllers/mcp.js +++ b/api/server/controllers/mcp.js @@ -6,10 +6,54 @@ * @import { MCPServerDocument } from 'librechat-data-provider' */ const { logger } = require('@librechat/data-schemas'); +const { + isMCPDomainNotAllowedError, + isMCPInspectionFailedError, + MCPErrorCodes, +} = require('@librechat/api'); const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider'); const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config'); const { getMCPManager, getMCPServersRegistry } = require('~/config'); +/** + * Handles MCP-specific errors and sends appropriate HTTP responses. + * @param {Error} error - The error to handle + * @param {import('express').Response} res - Express response object + * @returns {import('express').Response | null} Response if handled, null if not an MCP error + */ +function handleMCPError(error, res) { + if (isMCPDomainNotAllowedError(error)) { + return res.status(error.statusCode).json({ + error: error.code, + message: error.message, + }); + } + + if (isMCPInspectionFailedError(error)) { + return res.status(error.statusCode).json({ + error: error.code, + message: error.message, + }); + } + + // Fallback for legacy string-based error handling (backwards compatibility) + if (error.message?.startsWith(MCPErrorCodes.DOMAIN_NOT_ALLOWED)) { + return res.status(403).json({ + error: MCPErrorCodes.DOMAIN_NOT_ALLOWED, + message: error.message.replace(/^MCP_DOMAIN_NOT_ALLOWED\s*:\s*/i, ''), + }); + } + + if (error.message?.startsWith(MCPErrorCodes.INSPECTION_FAILED)) { + return res.status(400).json({ + error: MCPErrorCodes.INSPECTION_FAILED, + message: error.message, + }); + } + + return null; +} + /** * Get all MCP tools available to the user */ @@ -175,11 +219,9 @@ const createMCPServerController = async (req, res) => { }); } catch (error) { logger.error('[createMCPServer]', error); - if (error.message?.startsWith('MCP_INSPECTION_FAILED')) { - return res.status(400).json({ - error: 'MCP_INSPECTION_FAILED', - message: error.message, - }); + const mcpErrorResponse = handleMCPError(error, res); + if (mcpErrorResponse) { + return mcpErrorResponse; } res.status(500).json({ message: error.message }); } @@ -235,11 +277,9 @@ const updateMCPServerController = async (req, res) => { res.status(200).json(parsedConfig); } catch (error) { logger.error('[updateMCPServer]', error); - if (error.message?.startsWith('MCP_INSPECTION_FAILED:')) { - return res.status(400).json({ - error: 'MCP_INSPECTION_FAILED', - message: error.message, - }); + const mcpErrorResponse = handleMCPError(error, res); + if (mcpErrorResponse) { + return mcpErrorResponse; } res.status(500).json({ message: error.message }); } diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index af038ba8d6..1da1e0aa86 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -12,26 +12,36 @@ const mockRegistryInstance = { removeServer: jest.fn(), }; -jest.mock('@librechat/api', () => ({ - ...jest.requireActual('@librechat/api'), - MCPOAuthHandler: { - initiateOAuthFlow: jest.fn(), - getFlowState: jest.fn(), - completeOAuthFlow: jest.fn(), - generateFlowId: jest.fn(), - }, - MCPTokenStorage: { - storeTokens: jest.fn(), - getClientInfoAndMetadata: jest.fn(), - getTokens: jest.fn(), - deleteUserTokens: jest.fn(), - }, - getUserMCPAuthMap: jest.fn(), - generateCheckAccess: jest.fn(() => (req, res, next) => next()), - MCPServersRegistry: { - getInstance: () => mockRegistryInstance, - }, -})); +jest.mock('@librechat/api', () => { + const actual = jest.requireActual('@librechat/api'); + return { + ...actual, + MCPOAuthHandler: { + initiateOAuthFlow: jest.fn(), + getFlowState: jest.fn(), + completeOAuthFlow: jest.fn(), + generateFlowId: jest.fn(), + }, + MCPTokenStorage: { + storeTokens: jest.fn(), + getClientInfoAndMetadata: jest.fn(), + getTokens: jest.fn(), + deleteUserTokens: jest.fn(), + }, + getUserMCPAuthMap: jest.fn(), + generateCheckAccess: jest.fn(() => (req, res, next) => next()), + MCPServersRegistry: { + getInstance: () => mockRegistryInstance, + }, + // Error handling utilities (from @librechat/api mcp/errors) + isMCPDomainNotAllowedError: (error) => error?.code === 'MCP_DOMAIN_NOT_ALLOWED', + isMCPInspectionFailedError: (error) => error?.code === 'MCP_INSPECTION_FAILED', + MCPErrorCodes: { + DOMAIN_NOT_ALLOWED: 'MCP_DOMAIN_NOT_ALLOWED', + INSPECTION_FAILED: 'MCP_INSPECTION_FAILED', + }, + }; +}); jest.mock('@librechat/data-schemas', () => ({ logger: { diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index d63adc9822..72db447d3d 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -10,6 +10,7 @@ const { const { sendEvent, MCPOAuthHandler, + isMCPDomainAllowed, normalizeServerName, convertWithResolvedRefs, } = require('@librechat/api'); @@ -21,13 +22,14 @@ const { isAssistantsEndpoint, } = require('librechat-data-provider'); const { - getMCPManager, - getFlowStateManager, getOAuthReconnectionManager, getMCPServersRegistry, + getFlowStateManager, + getMCPManager, } = require('~/config'); const { findToken, createToken, updateToken } = require('~/models'); const { reinitMCPServer } = require('./Tools/mcp'); +const { getAppConfig } = require('./Config'); const { getLogStores } = require('~/cache'); /** @@ -222,10 +224,34 @@ async function reconnectServer({ res, user, index, signal, serverName, userMCPAu * @param {Providers | EModelEndpoint} params.provider - The provider for the tool. * @param {number} [params.index] * @param {AbortSignal} [params.signal] + * @param {import('@librechat/api').ParsedServerConfig} [params.config] * @param {Record>} [params.userMCPAuthMap] * @returns { Promise unknown}>> } An object with `_call` method to execute the tool input. */ -async function createMCPTools({ res, user, index, signal, serverName, provider, userMCPAuthMap }) { +async function createMCPTools({ + res, + user, + index, + signal, + config, + provider, + serverName, + userMCPAuthMap, +}) { + // Early domain validation before reconnecting server (avoid wasted work on disallowed domains) + // Use getAppConfig() to support per-user/role domain restrictions + const serverConfig = + config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id)); + if (serverConfig?.url) { + const appConfig = await getAppConfig({ role: user?.role }); + const allowedDomains = appConfig?.mcpSettings?.allowedDomains; + const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains); + if (!isDomainAllowed) { + logger.warn(`[MCP][${serverName}] Domain not allowed, skipping all tools`); + return []; + } + } + const result = await reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap }); if (!result || !result.tools) { logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`); @@ -241,6 +267,7 @@ async function createMCPTools({ res, user, index, signal, serverName, provider, userMCPAuthMap, availableTools: result.availableTools, toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`, + config: serverConfig, }); if (toolInstance) { serverTools.push(toolInstance); @@ -262,6 +289,7 @@ async function createMCPTools({ res, user, index, signal, serverName, provider, * @param {Providers | EModelEndpoint} params.provider - The provider for the tool. * @param {LCAvailableTools} [params.availableTools] * @param {Record>} [params.userMCPAuthMap] + * @param {import('@librechat/api').ParsedServerConfig} [params.config] * @returns { Promise unknown}> } An object with `_call` method to execute the tool input. */ async function createMCPTool({ @@ -273,9 +301,24 @@ async function createMCPTool({ provider, userMCPAuthMap, availableTools, + config, }) { const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter); + // Runtime domain validation: check if the server's domain is still allowed + // Use getAppConfig() to support per-user/role domain restrictions + const serverConfig = + config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id)); + if (serverConfig?.url) { + const appConfig = await getAppConfig({ role: user?.role }); + const allowedDomains = appConfig?.mcpSettings?.allowedDomains; + const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains); + if (!isDomainAllowed) { + logger.warn(`[MCP][${serverName}] Domain no longer allowed, skipping tool: ${toolName}`); + return undefined; + } + } + /** @type {LCTool | undefined} */ let toolDefinition = availableTools?.[toolKey]?.function; if (!toolDefinition) { diff --git a/api/server/services/MCP.spec.js b/api/server/services/MCP.spec.js index 835dd7e29e..cb2f0081a3 100644 --- a/api/server/services/MCP.spec.js +++ b/api/server/services/MCP.spec.js @@ -1,14 +1,4 @@ -const { logger } = require('@librechat/data-schemas'); -const { MCPOAuthHandler } = require('@librechat/api'); -const { CacheKeys } = require('librechat-data-provider'); -const { - createMCPTool, - createMCPTools, - getMCPSetupData, - checkOAuthFlowStatus, - getServerConnectionStatus, -} = require('./MCP'); - +// Mock all dependencies - define mocks before imports // Mock all dependencies jest.mock('@librechat/data-schemas', () => ({ logger: { @@ -43,22 +33,46 @@ jest.mock('@librechat/agents', () => ({ }, })); +// Create mock registry instance const mockRegistryInstance = { getOAuthServers: jest.fn(() => Promise.resolve(new Set())), getAllServerConfigs: jest.fn(() => Promise.resolve({})), + getServerConfig: jest.fn(() => Promise.resolve(null)), }; -jest.mock('@librechat/api', () => ({ - MCPOAuthHandler: { - generateFlowId: jest.fn(), - }, - sendEvent: jest.fn(), - normalizeServerName: jest.fn((name) => name), - convertWithResolvedRefs: jest.fn((params) => params), - MCPServersRegistry: { - getInstance: () => mockRegistryInstance, - }, -})); +// Create isMCPDomainAllowed mock that can be configured per-test +const mockIsMCPDomainAllowed = jest.fn(() => Promise.resolve(true)); + +const mockGetAppConfig = jest.fn(() => Promise.resolve({})); + +jest.mock('@librechat/api', () => { + // Access mock via getter to avoid hoisting issues + return { + MCPOAuthHandler: { + generateFlowId: jest.fn(), + }, + sendEvent: jest.fn(), + normalizeServerName: jest.fn((name) => name), + convertWithResolvedRefs: jest.fn((params) => params), + get isMCPDomainAllowed() { + return mockIsMCPDomainAllowed; + }, + MCPServersRegistry: { + getInstance: () => mockRegistryInstance, + }, + }; +}); + +const { logger } = require('@librechat/data-schemas'); +const { MCPOAuthHandler } = require('@librechat/api'); +const { CacheKeys } = require('librechat-data-provider'); +const { + createMCPTool, + createMCPTools, + getMCPSetupData, + checkOAuthFlowStatus, + getServerConnectionStatus, +} = require('./MCP'); jest.mock('librechat-data-provider', () => ({ CacheKeys: { @@ -80,7 +94,9 @@ jest.mock('librechat-data-provider', () => ({ jest.mock('./Config', () => ({ loadCustomConfig: jest.fn(), - getAppConfig: jest.fn(), + get getAppConfig() { + return mockGetAppConfig; + }, })); jest.mock('~/config', () => ({ @@ -692,6 +708,18 @@ describe('User parameter passing tests', () => { createFlowWithHandler: jest.fn(), failFlow: jest.fn(), }); + + // Reset domain validation mock to default (allow all) + mockIsMCPDomainAllowed.mockReset(); + mockIsMCPDomainAllowed.mockResolvedValue(true); + + // Reset registry mocks + mockRegistryInstance.getServerConfig.mockReset(); + mockRegistryInstance.getServerConfig.mockResolvedValue(null); + + // Reset getAppConfig mock to default (no restrictions) + mockGetAppConfig.mockReset(); + mockGetAppConfig.mockResolvedValue({}); }); describe('createMCPTools', () => { @@ -887,6 +915,229 @@ describe('User parameter passing tests', () => { }); }); + describe('Runtime domain validation', () => { + it('should skip tool creation when domain is not allowed', async () => { + const mockUser = { id: 'domain-test-user', role: 'user' }; + const mockRes = { write: jest.fn(), flush: jest.fn() }; + + // Mock server config with URL (remote server) + mockRegistryInstance.getServerConfig.mockResolvedValue({ + url: 'https://disallowed-domain.com/sse', + }); + + // Mock getAppConfig to return domain restrictions + mockGetAppConfig.mockResolvedValue({ + mcpSettings: { allowedDomains: ['allowed-domain.com'] }, + }); + + // Mock domain validation to return false (domain not allowed) + mockIsMCPDomainAllowed.mockResolvedValueOnce(false); + + const result = await createMCPTool({ + res: mockRes, + user: mockUser, + toolKey: 'test-tool::test-server', + provider: 'openai', + userMCPAuthMap: {}, + availableTools: { + 'test-tool::test-server': { + function: { + description: 'Test tool', + parameters: { type: 'object', properties: {} }, + }, + }, + }, + }); + + // Should return undefined for disallowed domain + expect(result).toBeUndefined(); + + // Should not call reinitMCPServer since domain check failed + expect(mockReinitMCPServer).not.toHaveBeenCalled(); + + // Verify getAppConfig was called with user role + expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' }); + + // Verify domain validation was called with correct parameters + expect(mockIsMCPDomainAllowed).toHaveBeenCalledWith( + { url: 'https://disallowed-domain.com/sse' }, + ['allowed-domain.com'], + ); + }); + + it('should allow tool creation when domain is allowed', async () => { + const mockUser = { id: 'domain-test-user', role: 'admin' }; + const mockRes = { write: jest.fn(), flush: jest.fn() }; + + // Mock server config with URL (remote server) + mockRegistryInstance.getServerConfig.mockResolvedValue({ + url: 'https://allowed-domain.com/sse', + }); + + // Mock getAppConfig to return domain restrictions + mockGetAppConfig.mockResolvedValue({ + mcpSettings: { allowedDomains: ['allowed-domain.com'] }, + }); + + // Mock domain validation to return true (domain allowed) + mockIsMCPDomainAllowed.mockResolvedValueOnce(true); + + const availableTools = { + 'test-tool::test-server': { + function: { + description: 'Test tool', + parameters: { type: 'object', properties: {} }, + }, + }, + }; + + const result = await createMCPTool({ + res: mockRes, + user: mockUser, + toolKey: 'test-tool::test-server', + provider: 'openai', + userMCPAuthMap: {}, + availableTools, + }); + + // Should create tool successfully + expect(result).toBeDefined(); + + // Verify getAppConfig was called with user role + expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'admin' }); + }); + + it('should skip domain validation for stdio transports (no URL)', async () => { + const mockUser = { id: 'stdio-test-user' }; + const mockRes = { write: jest.fn(), flush: jest.fn() }; + + // Mock server config without URL (stdio transport) + mockRegistryInstance.getServerConfig.mockResolvedValue({ + command: 'npx', + args: ['@modelcontextprotocol/server'], + }); + + // Mock getAppConfig (should not be called for stdio) + mockGetAppConfig.mockResolvedValue({ + mcpSettings: { allowedDomains: ['restricted-domain.com'] }, + }); + + const availableTools = { + 'test-tool::test-server': { + function: { + description: 'Test tool', + parameters: { type: 'object', properties: {} }, + }, + }, + }; + + const result = await createMCPTool({ + res: mockRes, + user: mockUser, + toolKey: 'test-tool::test-server', + provider: 'openai', + userMCPAuthMap: {}, + availableTools, + }); + + // Should create tool successfully without domain check + expect(result).toBeDefined(); + + // Should not call getAppConfig or isMCPDomainAllowed for stdio transport (no URL) + expect(mockGetAppConfig).not.toHaveBeenCalled(); + expect(mockIsMCPDomainAllowed).not.toHaveBeenCalled(); + }); + + it('should return empty array from createMCPTools when domain is not allowed', async () => { + const mockUser = { id: 'domain-test-user', role: 'user' }; + const mockRes = { write: jest.fn(), flush: jest.fn() }; + + // Mock server config with URL (remote server) + const serverConfig = { url: 'https://disallowed-domain.com/sse' }; + mockRegistryInstance.getServerConfig.mockResolvedValue(serverConfig); + + // Mock getAppConfig to return domain restrictions + mockGetAppConfig.mockResolvedValue({ + mcpSettings: { allowedDomains: ['allowed-domain.com'] }, + }); + + // Mock domain validation to return false (domain not allowed) + mockIsMCPDomainAllowed.mockResolvedValueOnce(false); + + const result = await createMCPTools({ + res: mockRes, + user: mockUser, + serverName: 'test-server', + provider: 'openai', + userMCPAuthMap: {}, + config: serverConfig, + }); + + // Should return empty array for disallowed domain + expect(result).toEqual([]); + + // Should not call reinitMCPServer since domain check failed early + expect(mockReinitMCPServer).not.toHaveBeenCalled(); + + // Verify getAppConfig was called with user role + expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' }); + }); + + it('should use user role when fetching domain restrictions', async () => { + const adminUser = { id: 'admin-user', role: 'admin' }; + const regularUser = { id: 'regular-user', role: 'user' }; + const mockRes = { write: jest.fn(), flush: jest.fn() }; + + mockRegistryInstance.getServerConfig.mockResolvedValue({ + url: 'https://some-domain.com/sse', + }); + + // Mock different responses based on role + mockGetAppConfig + .mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['admin-allowed.com'] } }) + .mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['user-allowed.com'] } }); + + mockIsMCPDomainAllowed.mockResolvedValue(true); + + const availableTools = { + 'test-tool::test-server': { + function: { + description: 'Test tool', + parameters: { type: 'object', properties: {} }, + }, + }, + }; + + // Call with admin user + await createMCPTool({ + res: mockRes, + user: adminUser, + toolKey: 'test-tool::test-server', + provider: 'openai', + userMCPAuthMap: {}, + availableTools, + }); + + // Reset and call with regular user + mockRegistryInstance.getServerConfig.mockResolvedValue({ + url: 'https://some-domain.com/sse', + }); + + await createMCPTool({ + res: mockRes, + user: regularUser, + toolKey: 'test-tool::test-server', + provider: 'openai', + userMCPAuthMap: {}, + availableTools, + }); + + // Verify getAppConfig was called with correct roles + expect(mockGetAppConfig).toHaveBeenNthCalledWith(1, { role: 'admin' }); + expect(mockGetAppConfig).toHaveBeenNthCalledWith(2, { role: 'user' }); + }); + }); + describe('User parameter integrity', () => { it('should preserve user object properties through the call chain', async () => { const complexUser = { diff --git a/api/server/services/initializeMCPs.js b/api/server/services/initializeMCPs.js index e4306245bb..c964b2f292 100644 --- a/api/server/services/initializeMCPs.js +++ b/api/server/services/initializeMCPs.js @@ -14,8 +14,9 @@ async function initializeMCPs() { } // Initialize MCPServersRegistry first (required for MCPManager) + // Pass allowedDomains from mcpSettings for domain validation try { - createMCPServersRegistry(mongoose); + createMCPServersRegistry(mongoose, appConfig?.mcpSettings?.allowedDomains); } catch (error) { logger.error('[MCP] Failed to initialize MCPServersRegistry:', error); throw error; diff --git a/client/src/components/SidePanel/MCPBuilder/MCPServerDialog.tsx b/client/src/components/SidePanel/MCPBuilder/MCPServerDialog.tsx index b4da42482b..6ae065ceee 100644 --- a/client/src/components/SidePanel/MCPBuilder/MCPServerDialog.tsx +++ b/client/src/components/SidePanel/MCPBuilder/MCPServerDialog.tsx @@ -308,6 +308,8 @@ export default function MCPServerDialog({ const axiosError = error as any; if (axiosError.response?.data?.error === 'MCP_INSPECTION_FAILED') { errorMessage = localize('com_ui_mcp_server_connection_failed'); + } else if (axiosError.response?.data?.error === 'MCP_DOMAIN_NOT_ALLOWED') { + errorMessage = localize('com_ui_mcp_domain_not_allowed'); } else if (axiosError.response?.data?.error) { errorMessage = axiosError.response.data.error; } diff --git a/client/src/locales/en/translation.json b/client/src/locales/en/translation.json index 5d8c67c92c..0a3c6f7b68 100644 --- a/client/src/locales/en/translation.json +++ b/client/src/locales/en/translation.json @@ -1049,6 +1049,7 @@ "com_ui_mcp_configure_server": "Configure {{0}}", "com_ui_mcp_configure_server_description": "Configure custom variables for {{0}}", "com_ui_mcp_dialog_title": "Configure Variables for {{serverName}}. Server Status: {{status}}", + "com_ui_mcp_domain_not_allowed": "The MCP server domain is not in the allowed domains list. Please contact your administrator.", "com_ui_mcp_enter_var": "Enter value for {{0}}", "com_ui_mcp_init_failed": "Failed to initialize MCP server", "com_ui_mcp_initialize": "Initialize", diff --git a/librechat.example.yaml b/librechat.example.yaml index 2d0cb80abd..4c27fe6ec9 100644 --- a/librechat.example.yaml +++ b/librechat.example.yaml @@ -184,6 +184,16 @@ actions: - 'librechat.ai' - 'google.com' +# MCP Server domain restrictions for remote transports (SSE, WebSocket, HTTP) +# Stdio transports (local processes) are not restricted. +# If not configured, all domains are allowed (permissive default). +# Supports wildcards: '*.example.com' matches 'api.example.com', 'staging.example.com', etc. +# mcpSettings: +# allowedDomains: +# - 'localhost' +# - '*.example.com' +# - 'trusted-mcp-provider.com' + # Example MCP Servers Object Structure # mcpServers: # everything: diff --git a/packages/api/src/auth/domain.spec.ts b/packages/api/src/auth/domain.spec.ts index 4f6c25ec51..02ca9767d3 100644 --- a/packages/api/src/auth/domain.spec.ts +++ b/packages/api/src/auth/domain.spec.ts @@ -1,5 +1,10 @@ /* eslint-disable @typescript-eslint/ban-ts-comment */ -import { isEmailDomainAllowed, isActionDomainAllowed } from './domain'; +import { + isEmailDomainAllowed, + isActionDomainAllowed, + extractMCPServerDomain, + isMCPDomainAllowed, +} from './domain'; describe('isEmailDomainAllowed', () => { afterEach(() => { @@ -213,3 +218,209 @@ describe('isActionDomainAllowed', () => { }); }); }); + +describe('extractMCPServerDomain', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + describe('URL extraction', () => { + it('should extract domain from HTTPS URL', () => { + const config = { url: 'https://api.example.com/sse' }; + expect(extractMCPServerDomain(config)).toBe('api.example.com'); + }); + + it('should extract domain from HTTP URL', () => { + const config = { url: 'http://api.example.com/sse' }; + expect(extractMCPServerDomain(config)).toBe('api.example.com'); + }); + + it('should extract domain from WebSocket URL', () => { + const config = { url: 'wss://ws.example.com' }; + expect(extractMCPServerDomain(config)).toBe('ws.example.com'); + }); + + it('should handle URL with port', () => { + const config = { url: 'https://localhost:3001/sse' }; + expect(extractMCPServerDomain(config)).toBe('localhost'); + }); + + it('should strip www prefix', () => { + const config = { url: 'https://www.example.com/api' }; + expect(extractMCPServerDomain(config)).toBe('example.com'); + }); + + it('should handle URL with path and query parameters', () => { + const config = { url: 'https://api.example.com/v1/sse?token=abc' }; + expect(extractMCPServerDomain(config)).toBe('api.example.com'); + }); + }); + + describe('stdio transports (no URL)', () => { + it('should return null for stdio transport with command only', () => { + const config = { command: 'npx', args: ['-y', '@modelcontextprotocol/server-puppeteer'] }; + expect(extractMCPServerDomain(config)).toBeNull(); + }); + + it('should return null when url is undefined', () => { + const config = { command: 'node', args: ['server.js'] }; + expect(extractMCPServerDomain(config)).toBeNull(); + }); + + it('should return null for empty object', () => { + const config = {}; + expect(extractMCPServerDomain(config)).toBeNull(); + }); + }); + + describe('invalid URLs', () => { + it('should return null for invalid URL format', () => { + const config = { url: 'not-a-valid-url' }; + expect(extractMCPServerDomain(config)).toBeNull(); + }); + + it('should return null for empty URL string', () => { + const config = { url: '' }; + expect(extractMCPServerDomain(config)).toBeNull(); + }); + + it('should return null for non-string url', () => { + const config = { url: 12345 }; + expect(extractMCPServerDomain(config)).toBeNull(); + }); + + it('should return null for null url', () => { + const config = { url: null }; + expect(extractMCPServerDomain(config)).toBeNull(); + }); + }); +}); + +describe('isMCPDomainAllowed', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + describe('stdio transports (always allowed)', () => { + it('should allow stdio transport regardless of allowlist', async () => { + const config = { command: 'npx', args: ['-y', '@modelcontextprotocol/server-puppeteer'] }; + expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true); + }); + + it('should allow stdio transport even with empty allowlist', async () => { + const config = { command: 'node', args: ['server.js'] }; + expect(await isMCPDomainAllowed(config, [])).toBe(true); + }); + + it('should allow stdio transport when no URL present', async () => { + const config = {}; + expect(await isMCPDomainAllowed(config, ['restricted.com'])).toBe(true); + }); + }); + + describe('permissive defaults (no restrictions)', () => { + it('should allow all domains when allowedDomains is null', async () => { + const config = { url: 'https://any-domain.com/sse' }; + expect(await isMCPDomainAllowed(config, null)).toBe(true); + }); + + it('should allow all domains when allowedDomains is undefined', async () => { + const config = { url: 'https://any-domain.com/sse' }; + expect(await isMCPDomainAllowed(config, undefined)).toBe(true); + }); + + it('should allow all domains when allowedDomains is empty array', async () => { + const config = { url: 'https://any-domain.com/sse' }; + expect(await isMCPDomainAllowed(config, [])).toBe(true); + }); + }); + + describe('exact domain matching', () => { + const allowedDomains = ['example.com', 'localhost', 'trusted-mcp.com']; + + it('should allow exact domain match', async () => { + const config = { url: 'https://example.com/api' }; + expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true); + }); + + it('should allow localhost', async () => { + const config = { url: 'http://localhost:3001/sse' }; + expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true); + }); + + it('should reject non-allowed domain', async () => { + const config = { url: 'https://malicious.com/sse' }; + expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(false); + }); + + it('should reject subdomain when only parent is allowed', async () => { + const config = { url: 'https://api.example.com/sse' }; + expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(false); + }); + }); + + describe('wildcard domain matching', () => { + const allowedDomains = ['*.example.com', 'localhost']; + + it('should allow subdomain with wildcard', async () => { + const config = { url: 'https://api.example.com/sse' }; + expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true); + }); + + it('should allow any subdomain with wildcard', async () => { + const config = { url: 'https://staging.example.com/sse' }; + expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true); + }); + + it('should allow base domain with wildcard', async () => { + const config = { url: 'https://example.com/sse' }; + expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true); + }); + + it('should allow nested subdomain with wildcard', async () => { + const config = { url: 'https://deep.nested.example.com/sse' }; + expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true); + }); + + it('should reject different domain even with wildcard', async () => { + const config = { url: 'https://api.other.com/sse' }; + expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(false); + }); + }); + + describe('case insensitivity', () => { + it('should match domains case-insensitively', async () => { + const config = { url: 'https://EXAMPLE.COM/sse' }; + expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true); + }); + + it('should match with uppercase in allowlist', async () => { + const config = { url: 'https://example.com/sse' }; + expect(await isMCPDomainAllowed(config, ['EXAMPLE.COM'])).toBe(true); + }); + + it('should match with mixed case', async () => { + const config = { url: 'https://Api.Example.Com/sse' }; + expect(await isMCPDomainAllowed(config, ['*.example.com'])).toBe(true); + }); + }); + + describe('www prefix handling', () => { + it('should strip www prefix from URL before matching', async () => { + const config = { url: 'https://www.example.com/sse' }; + expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true); + }); + + it('should match www in allowlist to non-www URL', async () => { + const config = { url: 'https://example.com/sse' }; + expect(await isMCPDomainAllowed(config, ['www.example.com'])).toBe(true); + }); + }); + + describe('invalid URL handling', () => { + it('should allow config with invalid URL (treated as stdio)', async () => { + const config = { url: 'not-a-valid-url' }; + expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true); + }); + }); +}); diff --git a/packages/api/src/auth/domain.ts b/packages/api/src/auth/domain.ts index 00bcf91787..851d3678dc 100644 --- a/packages/api/src/auth/domain.ts +++ b/packages/api/src/auth/domain.ts @@ -96,3 +96,45 @@ export async function isActionDomainAllowed( return false; } + +/** + * Extracts domain from MCP server config URL. + * Returns null for stdio transports (no URL) or invalid URLs. + * @param config - MCP server configuration (accepts any config with optional url field) + */ +export function extractMCPServerDomain(config: Record): string | null { + const url = config.url; + // Stdio transports don't have URLs - always allowed + if (!url || typeof url !== 'string') { + return null; + } + + try { + const parsedUrl = new URL(url); + return parsedUrl.hostname.replace(/^www\./i, ''); + } catch { + return null; + } +} + +/** + * Validates MCP server domain against allowedDomains. + * Reuses isActionDomainAllowed for consistent validation logic. + * Stdio transports (no URL) are always allowed. + * @param config - MCP server configuration with optional url field + * @param allowedDomains - List of allowed domains (with wildcard support) + */ +export async function isMCPDomainAllowed( + config: Record, + allowedDomains?: string[] | null, +): Promise { + const domain = extractMCPServerDomain(config); + + // Stdio transports don't have domains - always allowed + if (!domain) { + return true; + } + + // Reuse existing validation logic (includes wildcard support) + return isActionDomainAllowed(domain, allowedDomains); +} diff --git a/packages/api/src/index.ts b/packages/api/src/index.ts index 6350247a69..067d0a1e07 100644 --- a/packages/api/src/index.ts +++ b/packages/api/src/index.ts @@ -9,6 +9,7 @@ export * from './mcp/connection'; export * from './mcp/oauth'; export * from './mcp/auth'; export * from './mcp/zod'; +export * from './mcp/errors'; /* Utilities */ export * from './mcp/utils'; export * from './utils'; diff --git a/packages/api/src/mcp/errors.ts b/packages/api/src/mcp/errors.ts new file mode 100644 index 0000000000..21b249db30 --- /dev/null +++ b/packages/api/src/mcp/errors.ts @@ -0,0 +1,61 @@ +/** + * MCP-specific error classes + */ + +export const MCPErrorCodes = { + DOMAIN_NOT_ALLOWED: 'MCP_DOMAIN_NOT_ALLOWED', + INSPECTION_FAILED: 'MCP_INSPECTION_FAILED', +} as const; + +export type MCPErrorCode = (typeof MCPErrorCodes)[keyof typeof MCPErrorCodes]; + +/** + * Custom error for MCP domain restriction violations. + * Thrown when a user attempts to connect to an MCP server whose domain is not in the allowlist. + */ +export class MCPDomainNotAllowedError extends Error { + public readonly code = MCPErrorCodes.DOMAIN_NOT_ALLOWED; + public readonly statusCode = 403; + public readonly domain: string; + + constructor(domain: string) { + super(`Domain "${domain}" is not allowed`); + this.name = 'MCPDomainNotAllowedError'; + this.domain = domain; + Object.setPrototypeOf(this, MCPDomainNotAllowedError.prototype); + } +} + +/** + * Custom error for MCP server inspection failures. + * Thrown when attempting to connect/inspect an MCP server fails. + */ +export class MCPInspectionFailedError extends Error { + public readonly code = MCPErrorCodes.INSPECTION_FAILED; + public readonly statusCode = 400; + public readonly serverName: string; + + constructor(serverName: string, cause?: Error) { + super(`Failed to connect to MCP server "${serverName}"`); + this.name = 'MCPInspectionFailedError'; + this.serverName = serverName; + if (cause) { + this.cause = cause; + } + Object.setPrototypeOf(this, MCPInspectionFailedError.prototype); + } +} + +/** + * Type guard to check if an error is an MCPDomainNotAllowedError + */ +export function isMCPDomainNotAllowedError(error: unknown): error is MCPDomainNotAllowedError { + return error instanceof MCPDomainNotAllowedError; +} + +/** + * Type guard to check if an error is an MCPInspectionFailedError + */ +export function isMCPInspectionFailedError(error: unknown): error is MCPInspectionFailedError { + return error instanceof MCPInspectionFailedError; +} diff --git a/packages/api/src/mcp/registry/MCPServerInspector.ts b/packages/api/src/mcp/registry/MCPServerInspector.ts index d7807e6c95..2263c10422 100644 --- a/packages/api/src/mcp/registry/MCPServerInspector.ts +++ b/packages/api/src/mcp/registry/MCPServerInspector.ts @@ -2,7 +2,9 @@ import { Constants } from 'librechat-data-provider'; import type { JsonSchemaType } from '@librechat/data-schemas'; import type { MCPConnection } from '~/mcp/connection'; import type * as t from '~/mcp/types'; +import { isMCPDomainAllowed, extractMCPServerDomain } from '~/auth/domain'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; +import { MCPDomainNotAllowedError } from '~/mcp/errors'; import { detectOAuthRequirement } from '~/mcp/oauth'; import { isEnabled } from '~/utils'; @@ -24,13 +26,22 @@ export class MCPServerInspector { * @param serverName - The name of the server (used for tool function naming) * @param rawConfig - The raw server configuration * @param connection - The MCP connection + * @param allowedDomains - Optional list of allowed domains for remote transports * @returns A fully processed and enriched configuration with server metadata */ public static async inspect( serverName: string, rawConfig: t.MCPOptions, connection?: MCPConnection, + allowedDomains?: string[] | null, ): Promise { + // Validate domain against allowlist BEFORE attempting connection + const isDomainAllowed = await isMCPDomainAllowed(rawConfig, allowedDomains); + if (!isDomainAllowed) { + const domain = extractMCPServerDomain(rawConfig); + throw new MCPDomainNotAllowedError(domain ?? 'unknown'); + } + const start = Date.now(); const inspector = new MCPServerInspector(serverName, rawConfig, connection); await inspector.inspectServer(); diff --git a/packages/api/src/mcp/registry/MCPServersRegistry.ts b/packages/api/src/mcp/registry/MCPServersRegistry.ts index eb1ee1d3d6..9c097270b5 100644 --- a/packages/api/src/mcp/registry/MCPServersRegistry.ts +++ b/packages/api/src/mcp/registry/MCPServersRegistry.ts @@ -1,6 +1,7 @@ import { logger } from '@librechat/data-schemas'; import type { IServerConfigsRepositoryInterface } from './ServerConfigsRepositoryInterface'; import type * as t from '~/mcp/types'; +import { MCPInspectionFailedError, isMCPDomainNotAllowedError } from '~/mcp/errors'; import { ServerConfigsCacheFactory } from './cache/ServerConfigsCacheFactory'; import { MCPServerInspector } from './MCPServerInspector'; import { ServerConfigsDB } from './db/ServerConfigsDB'; @@ -20,14 +21,19 @@ export class MCPServersRegistry { private readonly dbConfigsRepo: IServerConfigsRepositoryInterface; private readonly cacheConfigsRepo: IServerConfigsRepositoryInterface; + private readonly allowedDomains?: string[] | null; - constructor(mongoose: typeof import('mongoose')) { + constructor(mongoose: typeof import('mongoose'), allowedDomains?: string[] | null) { this.dbConfigsRepo = new ServerConfigsDB(mongoose); this.cacheConfigsRepo = ServerConfigsCacheFactory.create('App', false); + this.allowedDomains = allowedDomains; } /** Creates and initializes the singleton MCPServersRegistry instance */ - public static createInstance(mongoose: typeof import('mongoose')): MCPServersRegistry { + public static createInstance( + mongoose: typeof import('mongoose'), + allowedDomains?: string[] | null, + ): MCPServersRegistry { if (!mongoose) { throw new Error( 'MCPServersRegistry creation failed: mongoose instance is required for database operations. ' + @@ -39,7 +45,7 @@ export class MCPServersRegistry { return MCPServersRegistry.instance; } logger.info('[MCPServersRegistry] Creating new instance'); - MCPServersRegistry.instance = new MCPServersRegistry(mongoose); + MCPServersRegistry.instance = new MCPServersRegistry(mongoose, allowedDomains); return MCPServersRegistry.instance; } @@ -80,10 +86,19 @@ export class MCPServersRegistry { const configRepo = this.getConfigRepository(storageLocation); let parsedConfig: t.ParsedServerConfig; try { - parsedConfig = await MCPServerInspector.inspect(serverName, config); + parsedConfig = await MCPServerInspector.inspect( + serverName, + config, + undefined, + this.allowedDomains, + ); } catch (error) { logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error); - throw new Error(`MCP_INSPECTION_FAILED: Failed to connect to MCP server "${serverName}"`); + // Preserve domain-specific error for better error handling + if (isMCPDomainNotAllowedError(error)) { + throw error; + } + throw new MCPInspectionFailedError(serverName, error as Error); } return await configRepo.add(serverName, parsedConfig, userId); } @@ -113,10 +128,19 @@ export class MCPServersRegistry { let parsedConfig: t.ParsedServerConfig; try { - parsedConfig = await MCPServerInspector.inspect(serverName, configForInspection); + parsedConfig = await MCPServerInspector.inspect( + serverName, + configForInspection, + undefined, + this.allowedDomains, + ); } catch (error) { logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error); - throw new Error(`MCP_INSPECTION_FAILED: Failed to connect to MCP server "${serverName}"`); + // Preserve domain-specific error for better error handling + if (isMCPDomainNotAllowedError(error)) { + throw error; + } + throw new MCPInspectionFailedError(serverName, error as Error); } await configRepo.update(serverName, parsedConfig, userId); return parsedConfig; diff --git a/packages/api/src/mcp/registry/__tests__/MCPServersInitializer.test.ts b/packages/api/src/mcp/registry/__tests__/MCPServersInitializer.test.ts index c5eb1b7171..255ef20760 100644 --- a/packages/api/src/mcp/registry/__tests__/MCPServersInitializer.test.ts +++ b/packages/api/src/mcp/registry/__tests__/MCPServersInitializer.test.ts @@ -224,18 +224,38 @@ describe('MCPServersInitializer', () => { it('should process all server configs through inspector', async () => { await MCPServersInitializer.initialize(testConfigs); - // Verify all configs were processed by inspector (without connection parameter) + // Verify all configs were processed by inspector + // Signature: inspect(serverName, rawConfig, connection?, allowedDomains?) expect(mockInspect).toHaveBeenCalledTimes(5); - expect(mockInspect).toHaveBeenCalledWith('disabled_server', testConfigs.disabled_server); - expect(mockInspect).toHaveBeenCalledWith('oauth_server', testConfigs.oauth_server); - expect(mockInspect).toHaveBeenCalledWith('file_tools_server', testConfigs.file_tools_server); + expect(mockInspect).toHaveBeenCalledWith( + 'disabled_server', + testConfigs.disabled_server, + undefined, + undefined, + ); + expect(mockInspect).toHaveBeenCalledWith( + 'oauth_server', + testConfigs.oauth_server, + undefined, + undefined, + ); + expect(mockInspect).toHaveBeenCalledWith( + 'file_tools_server', + testConfigs.file_tools_server, + undefined, + undefined, + ); expect(mockInspect).toHaveBeenCalledWith( 'search_tools_server', testConfigs.search_tools_server, + undefined, + undefined, ); expect(mockInspect).toHaveBeenCalledWith( 'remote_no_oauth_server', testConfigs.remote_no_oauth_server, + undefined, + undefined, ); }); diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index 9609b8de3f..d21a64ab6a 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -849,6 +849,11 @@ export const configSchema = z.object({ includedTools: z.array(z.string()).optional(), filteredTools: z.array(z.string()).optional(), mcpServers: MCPServersSchema.optional(), + mcpSettings: z + .object({ + allowedDomains: z.array(z.string()).optional(), + }) + .optional(), interface: interfaceSchema, turnstile: turnstileSchema.optional(), fileStrategy: fileSourceSchema.default(FileSources.local), diff --git a/packages/data-schemas/src/app/service.ts b/packages/data-schemas/src/app/service.ts index aef2472d5f..e15a27e0b0 100644 --- a/packages/data-schemas/src/app/service.ts +++ b/packages/data-schemas/src/app/service.ts @@ -60,7 +60,8 @@ export const AppService = async (params?: { const availableTools = systemTools; - const mcpConfig = config.mcpServers || null; + const mcpServersConfig = config.mcpServers || null; + const mcpSettings = config.mcpSettings || null; const registration = config.registration ?? configDefaults.registration; const interfaceConfig = await loadDefaultInterface({ config, configDefaults }); const turnstileConfig = loadTurnstileConfig(config, configDefaults); @@ -74,7 +75,8 @@ export const AppService = async (params?: { speech, balance, transactions, - mcpConfig, + mcpConfig: mcpServersConfig, + mcpSettings, webSearch, fileStrategy, registration, diff --git a/packages/data-schemas/src/types/app.ts b/packages/data-schemas/src/types/app.ts index 751e6a81d0..9157fabd44 100644 --- a/packages/data-schemas/src/types/app.ts +++ b/packages/data-schemas/src/types/app.ts @@ -82,6 +82,8 @@ export interface AppConfig { speech?: TCustomConfig['speech']; /** MCP server configuration */ mcpConfig?: TCustomConfig['mcpServers'] | null; + /** MCP settings (domain allowlist, etc.) */ + mcpSettings?: TCustomConfig['mcpSettings'] | null; /** File configuration */ fileConfig?: TFileConfig; /** Secure image links configuration */