🔒 feat: Add MCP server domain restrictions for remote transports (#11013)

* 🔒 feat: Add MCP server domain restrictions for remote transports

* 🔒 feat: Implement comprehensive MCP error handling and domain validation

- Added `handleMCPError` function to centralize error responses for domain restrictions and inspection failures.
- Introduced custom error classes: `MCPDomainNotAllowedError` and `MCPInspectionFailedError` for better error management.
- Updated MCP server controllers to utilize the new error handling mechanism.
- Enhanced domain validation logic in `createMCPTools` and `createMCPTool` functions to prevent operations on disallowed domains.
- Added tests for runtime domain validation scenarios to ensure correct behavior.

* chore: import order

* 🔒 feat: Enhance domain validation in MCP tools with user role-based restrictions

- Integrated `getAppConfig` to fetch allowed domains based on user roles in `createMCPTools` and `createMCPTool` functions.
- Removed the deprecated `getAllowedDomains` method from `MCPServersRegistry`.
- Updated tests to verify domain restrictions are applied correctly based on user roles.
- Ensured that domain validation logic is consistent and efficient across tool creation processes.

* 🔒 test: Refactor MCP tests to utilize configurable app settings

- Introduced a mock for `getAppConfig` to enhance test flexibility.
- Removed redundant mock definition to streamline test setup.
- Ensured tests are aligned with the latest domain validation logic.

---------

Co-authored-by: Atef Bellaaj <slalom.bellaaj@external.daimlertruck.com>
Co-authored-by: Danny Avila <danny@librechat.ai>
This commit is contained in:
Atef Bellaaj 2025-12-18 19:57:49 +01:00 committed by GitHub
parent 98294755ee
commit 95a69df70e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 815 additions and 75 deletions

View file

@ -10,6 +10,7 @@ const {
const {
sendEvent,
MCPOAuthHandler,
isMCPDomainAllowed,
normalizeServerName,
convertWithResolvedRefs,
} = require('@librechat/api');
@ -21,13 +22,14 @@ const {
isAssistantsEndpoint,
} = require('librechat-data-provider');
const {
getMCPManager,
getFlowStateManager,
getOAuthReconnectionManager,
getMCPServersRegistry,
getFlowStateManager,
getMCPManager,
} = require('~/config');
const { findToken, createToken, updateToken } = require('~/models');
const { reinitMCPServer } = require('./Tools/mcp');
const { getAppConfig } = require('./Config');
const { getLogStores } = require('~/cache');
/**
@ -222,10 +224,34 @@ async function reconnectServer({ res, user, index, signal, serverName, userMCPAu
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
* @param {number} [params.index]
* @param {AbortSignal} [params.signal]
* @param {import('@librechat/api').ParsedServerConfig} [params.config]
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
*/
async function createMCPTools({ res, user, index, signal, serverName, provider, userMCPAuthMap }) {
async function createMCPTools({
res,
user,
index,
signal,
config,
provider,
serverName,
userMCPAuthMap,
}) {
// Early domain validation before reconnecting server (avoid wasted work on disallowed domains)
// Use getAppConfig() to support per-user/role domain restrictions
const serverConfig =
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
if (serverConfig?.url) {
const appConfig = await getAppConfig({ role: user?.role });
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
if (!isDomainAllowed) {
logger.warn(`[MCP][${serverName}] Domain not allowed, skipping all tools`);
return [];
}
}
const result = await reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap });
if (!result || !result.tools) {
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
@ -241,6 +267,7 @@ async function createMCPTools({ res, user, index, signal, serverName, provider,
userMCPAuthMap,
availableTools: result.availableTools,
toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`,
config: serverConfig,
});
if (toolInstance) {
serverTools.push(toolInstance);
@ -262,6 +289,7 @@ async function createMCPTools({ res, user, index, signal, serverName, provider,
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
* @param {LCAvailableTools} [params.availableTools]
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @param {import('@librechat/api').ParsedServerConfig} [params.config]
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
*/
async function createMCPTool({
@ -273,9 +301,24 @@ async function createMCPTool({
provider,
userMCPAuthMap,
availableTools,
config,
}) {
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
// Runtime domain validation: check if the server's domain is still allowed
// Use getAppConfig() to support per-user/role domain restrictions
const serverConfig =
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
if (serverConfig?.url) {
const appConfig = await getAppConfig({ role: user?.role });
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
if (!isDomainAllowed) {
logger.warn(`[MCP][${serverName}] Domain no longer allowed, skipping tool: ${toolName}`);
return undefined;
}
}
/** @type {LCTool | undefined} */
let toolDefinition = availableTools?.[toolKey]?.function;
if (!toolDefinition) {