mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-19 09:50:15 +01:00
🔒 feat: Add MCP server domain restrictions for remote transports (#11013)
* 🔒 feat: Add MCP server domain restrictions for remote transports * 🔒 feat: Implement comprehensive MCP error handling and domain validation - Added `handleMCPError` function to centralize error responses for domain restrictions and inspection failures. - Introduced custom error classes: `MCPDomainNotAllowedError` and `MCPInspectionFailedError` for better error management. - Updated MCP server controllers to utilize the new error handling mechanism. - Enhanced domain validation logic in `createMCPTools` and `createMCPTool` functions to prevent operations on disallowed domains. - Added tests for runtime domain validation scenarios to ensure correct behavior. * chore: import order * 🔒 feat: Enhance domain validation in MCP tools with user role-based restrictions - Integrated `getAppConfig` to fetch allowed domains based on user roles in `createMCPTools` and `createMCPTool` functions. - Removed the deprecated `getAllowedDomains` method from `MCPServersRegistry`. - Updated tests to verify domain restrictions are applied correctly based on user roles. - Ensured that domain validation logic is consistent and efficient across tool creation processes. * 🔒 test: Refactor MCP tests to utilize configurable app settings - Introduced a mock for `getAppConfig` to enhance test flexibility. - Removed redundant mock definition to streamline test setup. - Ensured tests are aligned with the latest domain validation logic. --------- Co-authored-by: Atef Bellaaj <slalom.bellaaj@external.daimlertruck.com> Co-authored-by: Danny Avila <danny@librechat.ai>
This commit is contained in:
parent
98294755ee
commit
95a69df70e
19 changed files with 815 additions and 75 deletions
|
|
@ -348,10 +348,10 @@ Anchor pattern: \\ue202turn{N}{type}{index} where N=turn number, type=search|new
|
||||||
/** Placeholder used for UI purposes */
|
/** Placeholder used for UI purposes */
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (
|
const serverConfig = serverName
|
||||||
serverName &&
|
? await getMCPServersRegistry().getServerConfig(serverName, user)
|
||||||
(await getMCPServersRegistry().getServerConfig(serverName, user)) == undefined
|
: null;
|
||||||
) {
|
if (!serverConfig) {
|
||||||
logger.warn(
|
logger.warn(
|
||||||
`MCP server "${serverName}" for "${toolName}" tool is not configured${agent?.id != null && agent.id ? ` but attached to "${agent.id}"` : ''}`,
|
`MCP server "${serverName}" for "${toolName}" tool is not configured${agent?.id != null && agent.id ? ` but attached to "${agent.id}"` : ''}`,
|
||||||
);
|
);
|
||||||
|
|
@ -362,6 +362,7 @@ Anchor pattern: \\ue202turn{N}{type}{index} where N=turn number, type=search|new
|
||||||
{
|
{
|
||||||
type: 'all',
|
type: 'all',
|
||||||
serverName,
|
serverName,
|
||||||
|
config: serverConfig,
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
continue;
|
continue;
|
||||||
|
|
@ -372,6 +373,7 @@ Anchor pattern: \\ue202turn{N}{type}{index} where N=turn number, type=search|new
|
||||||
type: 'single',
|
type: 'single',
|
||||||
toolKey: tool,
|
toolKey: tool,
|
||||||
serverName,
|
serverName,
|
||||||
|
config: serverConfig,
|
||||||
});
|
});
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
@ -435,6 +437,7 @@ Anchor pattern: \\ue202turn{N}{type}{index} where N=turn number, type=search|new
|
||||||
model: agent?.model ?? model,
|
model: agent?.model ?? model,
|
||||||
serverName: config.serverName,
|
serverName: config.serverName,
|
||||||
provider: agent?.provider ?? endpoint,
|
provider: agent?.provider ?? endpoint,
|
||||||
|
config: config.config,
|
||||||
};
|
};
|
||||||
|
|
||||||
if (config.type === 'all' && toolConfigs.length === 1) {
|
if (config.type === 'all' && toolConfigs.length === 1) {
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,54 @@
|
||||||
* @import { MCPServerDocument } from 'librechat-data-provider'
|
* @import { MCPServerDocument } from 'librechat-data-provider'
|
||||||
*/
|
*/
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const {
|
||||||
|
isMCPDomainNotAllowedError,
|
||||||
|
isMCPInspectionFailedError,
|
||||||
|
MCPErrorCodes,
|
||||||
|
} = require('@librechat/api');
|
||||||
const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider');
|
const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider');
|
||||||
const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config');
|
const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config');
|
||||||
const { getMCPManager, getMCPServersRegistry } = require('~/config');
|
const { getMCPManager, getMCPServersRegistry } = require('~/config');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles MCP-specific errors and sends appropriate HTTP responses.
|
||||||
|
* @param {Error} error - The error to handle
|
||||||
|
* @param {import('express').Response} res - Express response object
|
||||||
|
* @returns {import('express').Response | null} Response if handled, null if not an MCP error
|
||||||
|
*/
|
||||||
|
function handleMCPError(error, res) {
|
||||||
|
if (isMCPDomainNotAllowedError(error)) {
|
||||||
|
return res.status(error.statusCode).json({
|
||||||
|
error: error.code,
|
||||||
|
message: error.message,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isMCPInspectionFailedError(error)) {
|
||||||
|
return res.status(error.statusCode).json({
|
||||||
|
error: error.code,
|
||||||
|
message: error.message,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback for legacy string-based error handling (backwards compatibility)
|
||||||
|
if (error.message?.startsWith(MCPErrorCodes.DOMAIN_NOT_ALLOWED)) {
|
||||||
|
return res.status(403).json({
|
||||||
|
error: MCPErrorCodes.DOMAIN_NOT_ALLOWED,
|
||||||
|
message: error.message.replace(/^MCP_DOMAIN_NOT_ALLOWED\s*:\s*/i, ''),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (error.message?.startsWith(MCPErrorCodes.INSPECTION_FAILED)) {
|
||||||
|
return res.status(400).json({
|
||||||
|
error: MCPErrorCodes.INSPECTION_FAILED,
|
||||||
|
message: error.message,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get all MCP tools available to the user
|
* Get all MCP tools available to the user
|
||||||
*/
|
*/
|
||||||
|
|
@ -175,11 +219,9 @@ const createMCPServerController = async (req, res) => {
|
||||||
});
|
});
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('[createMCPServer]', error);
|
logger.error('[createMCPServer]', error);
|
||||||
if (error.message?.startsWith('MCP_INSPECTION_FAILED')) {
|
const mcpErrorResponse = handleMCPError(error, res);
|
||||||
return res.status(400).json({
|
if (mcpErrorResponse) {
|
||||||
error: 'MCP_INSPECTION_FAILED',
|
return mcpErrorResponse;
|
||||||
message: error.message,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
res.status(500).json({ message: error.message });
|
res.status(500).json({ message: error.message });
|
||||||
}
|
}
|
||||||
|
|
@ -235,11 +277,9 @@ const updateMCPServerController = async (req, res) => {
|
||||||
res.status(200).json(parsedConfig);
|
res.status(200).json(parsedConfig);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('[updateMCPServer]', error);
|
logger.error('[updateMCPServer]', error);
|
||||||
if (error.message?.startsWith('MCP_INSPECTION_FAILED:')) {
|
const mcpErrorResponse = handleMCPError(error, res);
|
||||||
return res.status(400).json({
|
if (mcpErrorResponse) {
|
||||||
error: 'MCP_INSPECTION_FAILED',
|
return mcpErrorResponse;
|
||||||
message: error.message,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
res.status(500).json({ message: error.message });
|
res.status(500).json({ message: error.message });
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,26 +12,36 @@ const mockRegistryInstance = {
|
||||||
removeServer: jest.fn(),
|
removeServer: jest.fn(),
|
||||||
};
|
};
|
||||||
|
|
||||||
jest.mock('@librechat/api', () => ({
|
jest.mock('@librechat/api', () => {
|
||||||
...jest.requireActual('@librechat/api'),
|
const actual = jest.requireActual('@librechat/api');
|
||||||
MCPOAuthHandler: {
|
return {
|
||||||
initiateOAuthFlow: jest.fn(),
|
...actual,
|
||||||
getFlowState: jest.fn(),
|
MCPOAuthHandler: {
|
||||||
completeOAuthFlow: jest.fn(),
|
initiateOAuthFlow: jest.fn(),
|
||||||
generateFlowId: jest.fn(),
|
getFlowState: jest.fn(),
|
||||||
},
|
completeOAuthFlow: jest.fn(),
|
||||||
MCPTokenStorage: {
|
generateFlowId: jest.fn(),
|
||||||
storeTokens: jest.fn(),
|
},
|
||||||
getClientInfoAndMetadata: jest.fn(),
|
MCPTokenStorage: {
|
||||||
getTokens: jest.fn(),
|
storeTokens: jest.fn(),
|
||||||
deleteUserTokens: jest.fn(),
|
getClientInfoAndMetadata: jest.fn(),
|
||||||
},
|
getTokens: jest.fn(),
|
||||||
getUserMCPAuthMap: jest.fn(),
|
deleteUserTokens: jest.fn(),
|
||||||
generateCheckAccess: jest.fn(() => (req, res, next) => next()),
|
},
|
||||||
MCPServersRegistry: {
|
getUserMCPAuthMap: jest.fn(),
|
||||||
getInstance: () => mockRegistryInstance,
|
generateCheckAccess: jest.fn(() => (req, res, next) => next()),
|
||||||
},
|
MCPServersRegistry: {
|
||||||
}));
|
getInstance: () => mockRegistryInstance,
|
||||||
|
},
|
||||||
|
// Error handling utilities (from @librechat/api mcp/errors)
|
||||||
|
isMCPDomainNotAllowedError: (error) => error?.code === 'MCP_DOMAIN_NOT_ALLOWED',
|
||||||
|
isMCPInspectionFailedError: (error) => error?.code === 'MCP_INSPECTION_FAILED',
|
||||||
|
MCPErrorCodes: {
|
||||||
|
DOMAIN_NOT_ALLOWED: 'MCP_DOMAIN_NOT_ALLOWED',
|
||||||
|
INSPECTION_FAILED: 'MCP_INSPECTION_FAILED',
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
jest.mock('@librechat/data-schemas', () => ({
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
logger: {
|
logger: {
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ const {
|
||||||
const {
|
const {
|
||||||
sendEvent,
|
sendEvent,
|
||||||
MCPOAuthHandler,
|
MCPOAuthHandler,
|
||||||
|
isMCPDomainAllowed,
|
||||||
normalizeServerName,
|
normalizeServerName,
|
||||||
convertWithResolvedRefs,
|
convertWithResolvedRefs,
|
||||||
} = require('@librechat/api');
|
} = require('@librechat/api');
|
||||||
|
|
@ -21,13 +22,14 @@ const {
|
||||||
isAssistantsEndpoint,
|
isAssistantsEndpoint,
|
||||||
} = require('librechat-data-provider');
|
} = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
getMCPManager,
|
|
||||||
getFlowStateManager,
|
|
||||||
getOAuthReconnectionManager,
|
getOAuthReconnectionManager,
|
||||||
getMCPServersRegistry,
|
getMCPServersRegistry,
|
||||||
|
getFlowStateManager,
|
||||||
|
getMCPManager,
|
||||||
} = require('~/config');
|
} = require('~/config');
|
||||||
const { findToken, createToken, updateToken } = require('~/models');
|
const { findToken, createToken, updateToken } = require('~/models');
|
||||||
const { reinitMCPServer } = require('./Tools/mcp');
|
const { reinitMCPServer } = require('./Tools/mcp');
|
||||||
|
const { getAppConfig } = require('./Config');
|
||||||
const { getLogStores } = require('~/cache');
|
const { getLogStores } = require('~/cache');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -222,10 +224,34 @@ async function reconnectServer({ res, user, index, signal, serverName, userMCPAu
|
||||||
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
|
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
|
||||||
* @param {number} [params.index]
|
* @param {number} [params.index]
|
||||||
* @param {AbortSignal} [params.signal]
|
* @param {AbortSignal} [params.signal]
|
||||||
|
* @param {import('@librechat/api').ParsedServerConfig} [params.config]
|
||||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||||
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
|
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
|
||||||
*/
|
*/
|
||||||
async function createMCPTools({ res, user, index, signal, serverName, provider, userMCPAuthMap }) {
|
async function createMCPTools({
|
||||||
|
res,
|
||||||
|
user,
|
||||||
|
index,
|
||||||
|
signal,
|
||||||
|
config,
|
||||||
|
provider,
|
||||||
|
serverName,
|
||||||
|
userMCPAuthMap,
|
||||||
|
}) {
|
||||||
|
// Early domain validation before reconnecting server (avoid wasted work on disallowed domains)
|
||||||
|
// Use getAppConfig() to support per-user/role domain restrictions
|
||||||
|
const serverConfig =
|
||||||
|
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
|
||||||
|
if (serverConfig?.url) {
|
||||||
|
const appConfig = await getAppConfig({ role: user?.role });
|
||||||
|
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
|
||||||
|
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
|
||||||
|
if (!isDomainAllowed) {
|
||||||
|
logger.warn(`[MCP][${serverName}] Domain not allowed, skipping all tools`);
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const result = await reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap });
|
const result = await reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap });
|
||||||
if (!result || !result.tools) {
|
if (!result || !result.tools) {
|
||||||
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
|
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
|
||||||
|
|
@ -241,6 +267,7 @@ async function createMCPTools({ res, user, index, signal, serverName, provider,
|
||||||
userMCPAuthMap,
|
userMCPAuthMap,
|
||||||
availableTools: result.availableTools,
|
availableTools: result.availableTools,
|
||||||
toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`,
|
toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`,
|
||||||
|
config: serverConfig,
|
||||||
});
|
});
|
||||||
if (toolInstance) {
|
if (toolInstance) {
|
||||||
serverTools.push(toolInstance);
|
serverTools.push(toolInstance);
|
||||||
|
|
@ -262,6 +289,7 @@ async function createMCPTools({ res, user, index, signal, serverName, provider,
|
||||||
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
|
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
|
||||||
* @param {LCAvailableTools} [params.availableTools]
|
* @param {LCAvailableTools} [params.availableTools]
|
||||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||||
|
* @param {import('@librechat/api').ParsedServerConfig} [params.config]
|
||||||
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
|
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
|
||||||
*/
|
*/
|
||||||
async function createMCPTool({
|
async function createMCPTool({
|
||||||
|
|
@ -273,9 +301,24 @@ async function createMCPTool({
|
||||||
provider,
|
provider,
|
||||||
userMCPAuthMap,
|
userMCPAuthMap,
|
||||||
availableTools,
|
availableTools,
|
||||||
|
config,
|
||||||
}) {
|
}) {
|
||||||
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
|
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
|
||||||
|
|
||||||
|
// Runtime domain validation: check if the server's domain is still allowed
|
||||||
|
// Use getAppConfig() to support per-user/role domain restrictions
|
||||||
|
const serverConfig =
|
||||||
|
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
|
||||||
|
if (serverConfig?.url) {
|
||||||
|
const appConfig = await getAppConfig({ role: user?.role });
|
||||||
|
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
|
||||||
|
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
|
||||||
|
if (!isDomainAllowed) {
|
||||||
|
logger.warn(`[MCP][${serverName}] Domain no longer allowed, skipping tool: ${toolName}`);
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/** @type {LCTool | undefined} */
|
/** @type {LCTool | undefined} */
|
||||||
let toolDefinition = availableTools?.[toolKey]?.function;
|
let toolDefinition = availableTools?.[toolKey]?.function;
|
||||||
if (!toolDefinition) {
|
if (!toolDefinition) {
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,4 @@
|
||||||
const { logger } = require('@librechat/data-schemas');
|
// Mock all dependencies - define mocks before imports
|
||||||
const { MCPOAuthHandler } = require('@librechat/api');
|
|
||||||
const { CacheKeys } = require('librechat-data-provider');
|
|
||||||
const {
|
|
||||||
createMCPTool,
|
|
||||||
createMCPTools,
|
|
||||||
getMCPSetupData,
|
|
||||||
checkOAuthFlowStatus,
|
|
||||||
getServerConnectionStatus,
|
|
||||||
} = require('./MCP');
|
|
||||||
|
|
||||||
// Mock all dependencies
|
// Mock all dependencies
|
||||||
jest.mock('@librechat/data-schemas', () => ({
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
logger: {
|
logger: {
|
||||||
|
|
@ -43,22 +33,46 @@ jest.mock('@librechat/agents', () => ({
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
// Create mock registry instance
|
||||||
const mockRegistryInstance = {
|
const mockRegistryInstance = {
|
||||||
getOAuthServers: jest.fn(() => Promise.resolve(new Set())),
|
getOAuthServers: jest.fn(() => Promise.resolve(new Set())),
|
||||||
getAllServerConfigs: jest.fn(() => Promise.resolve({})),
|
getAllServerConfigs: jest.fn(() => Promise.resolve({})),
|
||||||
|
getServerConfig: jest.fn(() => Promise.resolve(null)),
|
||||||
};
|
};
|
||||||
|
|
||||||
jest.mock('@librechat/api', () => ({
|
// Create isMCPDomainAllowed mock that can be configured per-test
|
||||||
MCPOAuthHandler: {
|
const mockIsMCPDomainAllowed = jest.fn(() => Promise.resolve(true));
|
||||||
generateFlowId: jest.fn(),
|
|
||||||
},
|
const mockGetAppConfig = jest.fn(() => Promise.resolve({}));
|
||||||
sendEvent: jest.fn(),
|
|
||||||
normalizeServerName: jest.fn((name) => name),
|
jest.mock('@librechat/api', () => {
|
||||||
convertWithResolvedRefs: jest.fn((params) => params),
|
// Access mock via getter to avoid hoisting issues
|
||||||
MCPServersRegistry: {
|
return {
|
||||||
getInstance: () => mockRegistryInstance,
|
MCPOAuthHandler: {
|
||||||
},
|
generateFlowId: jest.fn(),
|
||||||
}));
|
},
|
||||||
|
sendEvent: jest.fn(),
|
||||||
|
normalizeServerName: jest.fn((name) => name),
|
||||||
|
convertWithResolvedRefs: jest.fn((params) => params),
|
||||||
|
get isMCPDomainAllowed() {
|
||||||
|
return mockIsMCPDomainAllowed;
|
||||||
|
},
|
||||||
|
MCPServersRegistry: {
|
||||||
|
getInstance: () => mockRegistryInstance,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { MCPOAuthHandler } = require('@librechat/api');
|
||||||
|
const { CacheKeys } = require('librechat-data-provider');
|
||||||
|
const {
|
||||||
|
createMCPTool,
|
||||||
|
createMCPTools,
|
||||||
|
getMCPSetupData,
|
||||||
|
checkOAuthFlowStatus,
|
||||||
|
getServerConnectionStatus,
|
||||||
|
} = require('./MCP');
|
||||||
|
|
||||||
jest.mock('librechat-data-provider', () => ({
|
jest.mock('librechat-data-provider', () => ({
|
||||||
CacheKeys: {
|
CacheKeys: {
|
||||||
|
|
@ -80,7 +94,9 @@ jest.mock('librechat-data-provider', () => ({
|
||||||
|
|
||||||
jest.mock('./Config', () => ({
|
jest.mock('./Config', () => ({
|
||||||
loadCustomConfig: jest.fn(),
|
loadCustomConfig: jest.fn(),
|
||||||
getAppConfig: jest.fn(),
|
get getAppConfig() {
|
||||||
|
return mockGetAppConfig;
|
||||||
|
},
|
||||||
}));
|
}));
|
||||||
|
|
||||||
jest.mock('~/config', () => ({
|
jest.mock('~/config', () => ({
|
||||||
|
|
@ -692,6 +708,18 @@ describe('User parameter passing tests', () => {
|
||||||
createFlowWithHandler: jest.fn(),
|
createFlowWithHandler: jest.fn(),
|
||||||
failFlow: jest.fn(),
|
failFlow: jest.fn(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Reset domain validation mock to default (allow all)
|
||||||
|
mockIsMCPDomainAllowed.mockReset();
|
||||||
|
mockIsMCPDomainAllowed.mockResolvedValue(true);
|
||||||
|
|
||||||
|
// Reset registry mocks
|
||||||
|
mockRegistryInstance.getServerConfig.mockReset();
|
||||||
|
mockRegistryInstance.getServerConfig.mockResolvedValue(null);
|
||||||
|
|
||||||
|
// Reset getAppConfig mock to default (no restrictions)
|
||||||
|
mockGetAppConfig.mockReset();
|
||||||
|
mockGetAppConfig.mockResolvedValue({});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('createMCPTools', () => {
|
describe('createMCPTools', () => {
|
||||||
|
|
@ -887,6 +915,229 @@ describe('User parameter passing tests', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('Runtime domain validation', () => {
|
||||||
|
it('should skip tool creation when domain is not allowed', async () => {
|
||||||
|
const mockUser = { id: 'domain-test-user', role: 'user' };
|
||||||
|
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||||
|
|
||||||
|
// Mock server config with URL (remote server)
|
||||||
|
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||||
|
url: 'https://disallowed-domain.com/sse',
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mock getAppConfig to return domain restrictions
|
||||||
|
mockGetAppConfig.mockResolvedValue({
|
||||||
|
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mock domain validation to return false (domain not allowed)
|
||||||
|
mockIsMCPDomainAllowed.mockResolvedValueOnce(false);
|
||||||
|
|
||||||
|
const result = await createMCPTool({
|
||||||
|
res: mockRes,
|
||||||
|
user: mockUser,
|
||||||
|
toolKey: 'test-tool::test-server',
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
availableTools: {
|
||||||
|
'test-tool::test-server': {
|
||||||
|
function: {
|
||||||
|
description: 'Test tool',
|
||||||
|
parameters: { type: 'object', properties: {} },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should return undefined for disallowed domain
|
||||||
|
expect(result).toBeUndefined();
|
||||||
|
|
||||||
|
// Should not call reinitMCPServer since domain check failed
|
||||||
|
expect(mockReinitMCPServer).not.toHaveBeenCalled();
|
||||||
|
|
||||||
|
// Verify getAppConfig was called with user role
|
||||||
|
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' });
|
||||||
|
|
||||||
|
// Verify domain validation was called with correct parameters
|
||||||
|
expect(mockIsMCPDomainAllowed).toHaveBeenCalledWith(
|
||||||
|
{ url: 'https://disallowed-domain.com/sse' },
|
||||||
|
['allowed-domain.com'],
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow tool creation when domain is allowed', async () => {
|
||||||
|
const mockUser = { id: 'domain-test-user', role: 'admin' };
|
||||||
|
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||||
|
|
||||||
|
// Mock server config with URL (remote server)
|
||||||
|
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||||
|
url: 'https://allowed-domain.com/sse',
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mock getAppConfig to return domain restrictions
|
||||||
|
mockGetAppConfig.mockResolvedValue({
|
||||||
|
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mock domain validation to return true (domain allowed)
|
||||||
|
mockIsMCPDomainAllowed.mockResolvedValueOnce(true);
|
||||||
|
|
||||||
|
const availableTools = {
|
||||||
|
'test-tool::test-server': {
|
||||||
|
function: {
|
||||||
|
description: 'Test tool',
|
||||||
|
parameters: { type: 'object', properties: {} },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const result = await createMCPTool({
|
||||||
|
res: mockRes,
|
||||||
|
user: mockUser,
|
||||||
|
toolKey: 'test-tool::test-server',
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
availableTools,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should create tool successfully
|
||||||
|
expect(result).toBeDefined();
|
||||||
|
|
||||||
|
// Verify getAppConfig was called with user role
|
||||||
|
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'admin' });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should skip domain validation for stdio transports (no URL)', async () => {
|
||||||
|
const mockUser = { id: 'stdio-test-user' };
|
||||||
|
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||||
|
|
||||||
|
// Mock server config without URL (stdio transport)
|
||||||
|
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||||
|
command: 'npx',
|
||||||
|
args: ['@modelcontextprotocol/server'],
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mock getAppConfig (should not be called for stdio)
|
||||||
|
mockGetAppConfig.mockResolvedValue({
|
||||||
|
mcpSettings: { allowedDomains: ['restricted-domain.com'] },
|
||||||
|
});
|
||||||
|
|
||||||
|
const availableTools = {
|
||||||
|
'test-tool::test-server': {
|
||||||
|
function: {
|
||||||
|
description: 'Test tool',
|
||||||
|
parameters: { type: 'object', properties: {} },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const result = await createMCPTool({
|
||||||
|
res: mockRes,
|
||||||
|
user: mockUser,
|
||||||
|
toolKey: 'test-tool::test-server',
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
availableTools,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should create tool successfully without domain check
|
||||||
|
expect(result).toBeDefined();
|
||||||
|
|
||||||
|
// Should not call getAppConfig or isMCPDomainAllowed for stdio transport (no URL)
|
||||||
|
expect(mockGetAppConfig).not.toHaveBeenCalled();
|
||||||
|
expect(mockIsMCPDomainAllowed).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return empty array from createMCPTools when domain is not allowed', async () => {
|
||||||
|
const mockUser = { id: 'domain-test-user', role: 'user' };
|
||||||
|
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||||
|
|
||||||
|
// Mock server config with URL (remote server)
|
||||||
|
const serverConfig = { url: 'https://disallowed-domain.com/sse' };
|
||||||
|
mockRegistryInstance.getServerConfig.mockResolvedValue(serverConfig);
|
||||||
|
|
||||||
|
// Mock getAppConfig to return domain restrictions
|
||||||
|
mockGetAppConfig.mockResolvedValue({
|
||||||
|
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mock domain validation to return false (domain not allowed)
|
||||||
|
mockIsMCPDomainAllowed.mockResolvedValueOnce(false);
|
||||||
|
|
||||||
|
const result = await createMCPTools({
|
||||||
|
res: mockRes,
|
||||||
|
user: mockUser,
|
||||||
|
serverName: 'test-server',
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
config: serverConfig,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should return empty array for disallowed domain
|
||||||
|
expect(result).toEqual([]);
|
||||||
|
|
||||||
|
// Should not call reinitMCPServer since domain check failed early
|
||||||
|
expect(mockReinitMCPServer).not.toHaveBeenCalled();
|
||||||
|
|
||||||
|
// Verify getAppConfig was called with user role
|
||||||
|
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should use user role when fetching domain restrictions', async () => {
|
||||||
|
const adminUser = { id: 'admin-user', role: 'admin' };
|
||||||
|
const regularUser = { id: 'regular-user', role: 'user' };
|
||||||
|
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||||
|
|
||||||
|
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||||
|
url: 'https://some-domain.com/sse',
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mock different responses based on role
|
||||||
|
mockGetAppConfig
|
||||||
|
.mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['admin-allowed.com'] } })
|
||||||
|
.mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['user-allowed.com'] } });
|
||||||
|
|
||||||
|
mockIsMCPDomainAllowed.mockResolvedValue(true);
|
||||||
|
|
||||||
|
const availableTools = {
|
||||||
|
'test-tool::test-server': {
|
||||||
|
function: {
|
||||||
|
description: 'Test tool',
|
||||||
|
parameters: { type: 'object', properties: {} },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Call with admin user
|
||||||
|
await createMCPTool({
|
||||||
|
res: mockRes,
|
||||||
|
user: adminUser,
|
||||||
|
toolKey: 'test-tool::test-server',
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
availableTools,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Reset and call with regular user
|
||||||
|
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||||
|
url: 'https://some-domain.com/sse',
|
||||||
|
});
|
||||||
|
|
||||||
|
await createMCPTool({
|
||||||
|
res: mockRes,
|
||||||
|
user: regularUser,
|
||||||
|
toolKey: 'test-tool::test-server',
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
availableTools,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Verify getAppConfig was called with correct roles
|
||||||
|
expect(mockGetAppConfig).toHaveBeenNthCalledWith(1, { role: 'admin' });
|
||||||
|
expect(mockGetAppConfig).toHaveBeenNthCalledWith(2, { role: 'user' });
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('User parameter integrity', () => {
|
describe('User parameter integrity', () => {
|
||||||
it('should preserve user object properties through the call chain', async () => {
|
it('should preserve user object properties through the call chain', async () => {
|
||||||
const complexUser = {
|
const complexUser = {
|
||||||
|
|
|
||||||
|
|
@ -14,8 +14,9 @@ async function initializeMCPs() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize MCPServersRegistry first (required for MCPManager)
|
// Initialize MCPServersRegistry first (required for MCPManager)
|
||||||
|
// Pass allowedDomains from mcpSettings for domain validation
|
||||||
try {
|
try {
|
||||||
createMCPServersRegistry(mongoose);
|
createMCPServersRegistry(mongoose, appConfig?.mcpSettings?.allowedDomains);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('[MCP] Failed to initialize MCPServersRegistry:', error);
|
logger.error('[MCP] Failed to initialize MCPServersRegistry:', error);
|
||||||
throw error;
|
throw error;
|
||||||
|
|
|
||||||
|
|
@ -308,6 +308,8 @@ export default function MCPServerDialog({
|
||||||
const axiosError = error as any;
|
const axiosError = error as any;
|
||||||
if (axiosError.response?.data?.error === 'MCP_INSPECTION_FAILED') {
|
if (axiosError.response?.data?.error === 'MCP_INSPECTION_FAILED') {
|
||||||
errorMessage = localize('com_ui_mcp_server_connection_failed');
|
errorMessage = localize('com_ui_mcp_server_connection_failed');
|
||||||
|
} else if (axiosError.response?.data?.error === 'MCP_DOMAIN_NOT_ALLOWED') {
|
||||||
|
errorMessage = localize('com_ui_mcp_domain_not_allowed');
|
||||||
} else if (axiosError.response?.data?.error) {
|
} else if (axiosError.response?.data?.error) {
|
||||||
errorMessage = axiosError.response.data.error;
|
errorMessage = axiosError.response.data.error;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1049,6 +1049,7 @@
|
||||||
"com_ui_mcp_configure_server": "Configure {{0}}",
|
"com_ui_mcp_configure_server": "Configure {{0}}",
|
||||||
"com_ui_mcp_configure_server_description": "Configure custom variables for {{0}}",
|
"com_ui_mcp_configure_server_description": "Configure custom variables for {{0}}",
|
||||||
"com_ui_mcp_dialog_title": "Configure Variables for {{serverName}}. Server Status: {{status}}",
|
"com_ui_mcp_dialog_title": "Configure Variables for {{serverName}}. Server Status: {{status}}",
|
||||||
|
"com_ui_mcp_domain_not_allowed": "The MCP server domain is not in the allowed domains list. Please contact your administrator.",
|
||||||
"com_ui_mcp_enter_var": "Enter value for {{0}}",
|
"com_ui_mcp_enter_var": "Enter value for {{0}}",
|
||||||
"com_ui_mcp_init_failed": "Failed to initialize MCP server",
|
"com_ui_mcp_init_failed": "Failed to initialize MCP server",
|
||||||
"com_ui_mcp_initialize": "Initialize",
|
"com_ui_mcp_initialize": "Initialize",
|
||||||
|
|
|
||||||
|
|
@ -184,6 +184,16 @@ actions:
|
||||||
- 'librechat.ai'
|
- 'librechat.ai'
|
||||||
- 'google.com'
|
- 'google.com'
|
||||||
|
|
||||||
|
# MCP Server domain restrictions for remote transports (SSE, WebSocket, HTTP)
|
||||||
|
# Stdio transports (local processes) are not restricted.
|
||||||
|
# If not configured, all domains are allowed (permissive default).
|
||||||
|
# Supports wildcards: '*.example.com' matches 'api.example.com', 'staging.example.com', etc.
|
||||||
|
# mcpSettings:
|
||||||
|
# allowedDomains:
|
||||||
|
# - 'localhost'
|
||||||
|
# - '*.example.com'
|
||||||
|
# - 'trusted-mcp-provider.com'
|
||||||
|
|
||||||
# Example MCP Servers Object Structure
|
# Example MCP Servers Object Structure
|
||||||
# mcpServers:
|
# mcpServers:
|
||||||
# everything:
|
# everything:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,10 @@
|
||||||
/* eslint-disable @typescript-eslint/ban-ts-comment */
|
/* eslint-disable @typescript-eslint/ban-ts-comment */
|
||||||
import { isEmailDomainAllowed, isActionDomainAllowed } from './domain';
|
import {
|
||||||
|
isEmailDomainAllowed,
|
||||||
|
isActionDomainAllowed,
|
||||||
|
extractMCPServerDomain,
|
||||||
|
isMCPDomainAllowed,
|
||||||
|
} from './domain';
|
||||||
|
|
||||||
describe('isEmailDomainAllowed', () => {
|
describe('isEmailDomainAllowed', () => {
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
|
|
@ -213,3 +218,209 @@ describe('isActionDomainAllowed', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('extractMCPServerDomain', () => {
|
||||||
|
afterEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('URL extraction', () => {
|
||||||
|
it('should extract domain from HTTPS URL', () => {
|
||||||
|
const config = { url: 'https://api.example.com/sse' };
|
||||||
|
expect(extractMCPServerDomain(config)).toBe('api.example.com');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should extract domain from HTTP URL', () => {
|
||||||
|
const config = { url: 'http://api.example.com/sse' };
|
||||||
|
expect(extractMCPServerDomain(config)).toBe('api.example.com');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should extract domain from WebSocket URL', () => {
|
||||||
|
const config = { url: 'wss://ws.example.com' };
|
||||||
|
expect(extractMCPServerDomain(config)).toBe('ws.example.com');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle URL with port', () => {
|
||||||
|
const config = { url: 'https://localhost:3001/sse' };
|
||||||
|
expect(extractMCPServerDomain(config)).toBe('localhost');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should strip www prefix', () => {
|
||||||
|
const config = { url: 'https://www.example.com/api' };
|
||||||
|
expect(extractMCPServerDomain(config)).toBe('example.com');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle URL with path and query parameters', () => {
|
||||||
|
const config = { url: 'https://api.example.com/v1/sse?token=abc' };
|
||||||
|
expect(extractMCPServerDomain(config)).toBe('api.example.com');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('stdio transports (no URL)', () => {
|
||||||
|
it('should return null for stdio transport with command only', () => {
|
||||||
|
const config = { command: 'npx', args: ['-y', '@modelcontextprotocol/server-puppeteer'] };
|
||||||
|
expect(extractMCPServerDomain(config)).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return null when url is undefined', () => {
|
||||||
|
const config = { command: 'node', args: ['server.js'] };
|
||||||
|
expect(extractMCPServerDomain(config)).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return null for empty object', () => {
|
||||||
|
const config = {};
|
||||||
|
expect(extractMCPServerDomain(config)).toBeNull();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('invalid URLs', () => {
|
||||||
|
it('should return null for invalid URL format', () => {
|
||||||
|
const config = { url: 'not-a-valid-url' };
|
||||||
|
expect(extractMCPServerDomain(config)).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return null for empty URL string', () => {
|
||||||
|
const config = { url: '' };
|
||||||
|
expect(extractMCPServerDomain(config)).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return null for non-string url', () => {
|
||||||
|
const config = { url: 12345 };
|
||||||
|
expect(extractMCPServerDomain(config)).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return null for null url', () => {
|
||||||
|
const config = { url: null };
|
||||||
|
expect(extractMCPServerDomain(config)).toBeNull();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('isMCPDomainAllowed', () => {
|
||||||
|
afterEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('stdio transports (always allowed)', () => {
|
||||||
|
it('should allow stdio transport regardless of allowlist', async () => {
|
||||||
|
const config = { command: 'npx', args: ['-y', '@modelcontextprotocol/server-puppeteer'] };
|
||||||
|
expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow stdio transport even with empty allowlist', async () => {
|
||||||
|
const config = { command: 'node', args: ['server.js'] };
|
||||||
|
expect(await isMCPDomainAllowed(config, [])).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow stdio transport when no URL present', async () => {
|
||||||
|
const config = {};
|
||||||
|
expect(await isMCPDomainAllowed(config, ['restricted.com'])).toBe(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('permissive defaults (no restrictions)', () => {
|
||||||
|
it('should allow all domains when allowedDomains is null', async () => {
|
||||||
|
const config = { url: 'https://any-domain.com/sse' };
|
||||||
|
expect(await isMCPDomainAllowed(config, null)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow all domains when allowedDomains is undefined', async () => {
|
||||||
|
const config = { url: 'https://any-domain.com/sse' };
|
||||||
|
expect(await isMCPDomainAllowed(config, undefined)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow all domains when allowedDomains is empty array', async () => {
|
||||||
|
const config = { url: 'https://any-domain.com/sse' };
|
||||||
|
expect(await isMCPDomainAllowed(config, [])).toBe(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('exact domain matching', () => {
|
||||||
|
const allowedDomains = ['example.com', 'localhost', 'trusted-mcp.com'];
|
||||||
|
|
||||||
|
it('should allow exact domain match', async () => {
|
||||||
|
const config = { url: 'https://example.com/api' };
|
||||||
|
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow localhost', async () => {
|
||||||
|
const config = { url: 'http://localhost:3001/sse' };
|
||||||
|
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject non-allowed domain', async () => {
|
||||||
|
const config = { url: 'https://malicious.com/sse' };
|
||||||
|
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject subdomain when only parent is allowed', async () => {
|
||||||
|
const config = { url: 'https://api.example.com/sse' };
|
||||||
|
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('wildcard domain matching', () => {
|
||||||
|
const allowedDomains = ['*.example.com', 'localhost'];
|
||||||
|
|
||||||
|
it('should allow subdomain with wildcard', async () => {
|
||||||
|
const config = { url: 'https://api.example.com/sse' };
|
||||||
|
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow any subdomain with wildcard', async () => {
|
||||||
|
const config = { url: 'https://staging.example.com/sse' };
|
||||||
|
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow base domain with wildcard', async () => {
|
||||||
|
const config = { url: 'https://example.com/sse' };
|
||||||
|
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow nested subdomain with wildcard', async () => {
|
||||||
|
const config = { url: 'https://deep.nested.example.com/sse' };
|
||||||
|
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject different domain even with wildcard', async () => {
|
||||||
|
const config = { url: 'https://api.other.com/sse' };
|
||||||
|
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('case insensitivity', () => {
|
||||||
|
it('should match domains case-insensitively', async () => {
|
||||||
|
const config = { url: 'https://EXAMPLE.COM/sse' };
|
||||||
|
expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should match with uppercase in allowlist', async () => {
|
||||||
|
const config = { url: 'https://example.com/sse' };
|
||||||
|
expect(await isMCPDomainAllowed(config, ['EXAMPLE.COM'])).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should match with mixed case', async () => {
|
||||||
|
const config = { url: 'https://Api.Example.Com/sse' };
|
||||||
|
expect(await isMCPDomainAllowed(config, ['*.example.com'])).toBe(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('www prefix handling', () => {
|
||||||
|
it('should strip www prefix from URL before matching', async () => {
|
||||||
|
const config = { url: 'https://www.example.com/sse' };
|
||||||
|
expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should match www in allowlist to non-www URL', async () => {
|
||||||
|
const config = { url: 'https://example.com/sse' };
|
||||||
|
expect(await isMCPDomainAllowed(config, ['www.example.com'])).toBe(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('invalid URL handling', () => {
|
||||||
|
it('should allow config with invalid URL (treated as stdio)', async () => {
|
||||||
|
const config = { url: 'not-a-valid-url' };
|
||||||
|
expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
|
||||||
|
|
@ -96,3 +96,45 @@ export async function isActionDomainAllowed(
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts domain from MCP server config URL.
|
||||||
|
* Returns null for stdio transports (no URL) or invalid URLs.
|
||||||
|
* @param config - MCP server configuration (accepts any config with optional url field)
|
||||||
|
*/
|
||||||
|
export function extractMCPServerDomain(config: Record<string, unknown>): string | null {
|
||||||
|
const url = config.url;
|
||||||
|
// Stdio transports don't have URLs - always allowed
|
||||||
|
if (!url || typeof url !== 'string') {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const parsedUrl = new URL(url);
|
||||||
|
return parsedUrl.hostname.replace(/^www\./i, '');
|
||||||
|
} catch {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates MCP server domain against allowedDomains.
|
||||||
|
* Reuses isActionDomainAllowed for consistent validation logic.
|
||||||
|
* Stdio transports (no URL) are always allowed.
|
||||||
|
* @param config - MCP server configuration with optional url field
|
||||||
|
* @param allowedDomains - List of allowed domains (with wildcard support)
|
||||||
|
*/
|
||||||
|
export async function isMCPDomainAllowed(
|
||||||
|
config: Record<string, unknown>,
|
||||||
|
allowedDomains?: string[] | null,
|
||||||
|
): Promise<boolean> {
|
||||||
|
const domain = extractMCPServerDomain(config);
|
||||||
|
|
||||||
|
// Stdio transports don't have domains - always allowed
|
||||||
|
if (!domain) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reuse existing validation logic (includes wildcard support)
|
||||||
|
return isActionDomainAllowed(domain, allowedDomains);
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ export * from './mcp/connection';
|
||||||
export * from './mcp/oauth';
|
export * from './mcp/oauth';
|
||||||
export * from './mcp/auth';
|
export * from './mcp/auth';
|
||||||
export * from './mcp/zod';
|
export * from './mcp/zod';
|
||||||
|
export * from './mcp/errors';
|
||||||
/* Utilities */
|
/* Utilities */
|
||||||
export * from './mcp/utils';
|
export * from './mcp/utils';
|
||||||
export * from './utils';
|
export * from './utils';
|
||||||
|
|
|
||||||
61
packages/api/src/mcp/errors.ts
Normal file
61
packages/api/src/mcp/errors.ts
Normal file
|
|
@ -0,0 +1,61 @@
|
||||||
|
/**
|
||||||
|
* MCP-specific error classes
|
||||||
|
*/
|
||||||
|
|
||||||
|
export const MCPErrorCodes = {
|
||||||
|
DOMAIN_NOT_ALLOWED: 'MCP_DOMAIN_NOT_ALLOWED',
|
||||||
|
INSPECTION_FAILED: 'MCP_INSPECTION_FAILED',
|
||||||
|
} as const;
|
||||||
|
|
||||||
|
export type MCPErrorCode = (typeof MCPErrorCodes)[keyof typeof MCPErrorCodes];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Custom error for MCP domain restriction violations.
|
||||||
|
* Thrown when a user attempts to connect to an MCP server whose domain is not in the allowlist.
|
||||||
|
*/
|
||||||
|
export class MCPDomainNotAllowedError extends Error {
|
||||||
|
public readonly code = MCPErrorCodes.DOMAIN_NOT_ALLOWED;
|
||||||
|
public readonly statusCode = 403;
|
||||||
|
public readonly domain: string;
|
||||||
|
|
||||||
|
constructor(domain: string) {
|
||||||
|
super(`Domain "${domain}" is not allowed`);
|
||||||
|
this.name = 'MCPDomainNotAllowedError';
|
||||||
|
this.domain = domain;
|
||||||
|
Object.setPrototypeOf(this, MCPDomainNotAllowedError.prototype);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Custom error for MCP server inspection failures.
|
||||||
|
* Thrown when attempting to connect/inspect an MCP server fails.
|
||||||
|
*/
|
||||||
|
export class MCPInspectionFailedError extends Error {
|
||||||
|
public readonly code = MCPErrorCodes.INSPECTION_FAILED;
|
||||||
|
public readonly statusCode = 400;
|
||||||
|
public readonly serverName: string;
|
||||||
|
|
||||||
|
constructor(serverName: string, cause?: Error) {
|
||||||
|
super(`Failed to connect to MCP server "${serverName}"`);
|
||||||
|
this.name = 'MCPInspectionFailedError';
|
||||||
|
this.serverName = serverName;
|
||||||
|
if (cause) {
|
||||||
|
this.cause = cause;
|
||||||
|
}
|
||||||
|
Object.setPrototypeOf(this, MCPInspectionFailedError.prototype);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Type guard to check if an error is an MCPDomainNotAllowedError
|
||||||
|
*/
|
||||||
|
export function isMCPDomainNotAllowedError(error: unknown): error is MCPDomainNotAllowedError {
|
||||||
|
return error instanceof MCPDomainNotAllowedError;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Type guard to check if an error is an MCPInspectionFailedError
|
||||||
|
*/
|
||||||
|
export function isMCPInspectionFailedError(error: unknown): error is MCPInspectionFailedError {
|
||||||
|
return error instanceof MCPInspectionFailedError;
|
||||||
|
}
|
||||||
|
|
@ -2,7 +2,9 @@ import { Constants } from 'librechat-data-provider';
|
||||||
import type { JsonSchemaType } from '@librechat/data-schemas';
|
import type { JsonSchemaType } from '@librechat/data-schemas';
|
||||||
import type { MCPConnection } from '~/mcp/connection';
|
import type { MCPConnection } from '~/mcp/connection';
|
||||||
import type * as t from '~/mcp/types';
|
import type * as t from '~/mcp/types';
|
||||||
|
import { isMCPDomainAllowed, extractMCPServerDomain } from '~/auth/domain';
|
||||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||||
|
import { MCPDomainNotAllowedError } from '~/mcp/errors';
|
||||||
import { detectOAuthRequirement } from '~/mcp/oauth';
|
import { detectOAuthRequirement } from '~/mcp/oauth';
|
||||||
import { isEnabled } from '~/utils';
|
import { isEnabled } from '~/utils';
|
||||||
|
|
||||||
|
|
@ -24,13 +26,22 @@ export class MCPServerInspector {
|
||||||
* @param serverName - The name of the server (used for tool function naming)
|
* @param serverName - The name of the server (used for tool function naming)
|
||||||
* @param rawConfig - The raw server configuration
|
* @param rawConfig - The raw server configuration
|
||||||
* @param connection - The MCP connection
|
* @param connection - The MCP connection
|
||||||
|
* @param allowedDomains - Optional list of allowed domains for remote transports
|
||||||
* @returns A fully processed and enriched configuration with server metadata
|
* @returns A fully processed and enriched configuration with server metadata
|
||||||
*/
|
*/
|
||||||
public static async inspect(
|
public static async inspect(
|
||||||
serverName: string,
|
serverName: string,
|
||||||
rawConfig: t.MCPOptions,
|
rawConfig: t.MCPOptions,
|
||||||
connection?: MCPConnection,
|
connection?: MCPConnection,
|
||||||
|
allowedDomains?: string[] | null,
|
||||||
): Promise<t.ParsedServerConfig> {
|
): Promise<t.ParsedServerConfig> {
|
||||||
|
// Validate domain against allowlist BEFORE attempting connection
|
||||||
|
const isDomainAllowed = await isMCPDomainAllowed(rawConfig, allowedDomains);
|
||||||
|
if (!isDomainAllowed) {
|
||||||
|
const domain = extractMCPServerDomain(rawConfig);
|
||||||
|
throw new MCPDomainNotAllowedError(domain ?? 'unknown');
|
||||||
|
}
|
||||||
|
|
||||||
const start = Date.now();
|
const start = Date.now();
|
||||||
const inspector = new MCPServerInspector(serverName, rawConfig, connection);
|
const inspector = new MCPServerInspector(serverName, rawConfig, connection);
|
||||||
await inspector.inspectServer();
|
await inspector.inspectServer();
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import { logger } from '@librechat/data-schemas';
|
import { logger } from '@librechat/data-schemas';
|
||||||
import type { IServerConfigsRepositoryInterface } from './ServerConfigsRepositoryInterface';
|
import type { IServerConfigsRepositoryInterface } from './ServerConfigsRepositoryInterface';
|
||||||
import type * as t from '~/mcp/types';
|
import type * as t from '~/mcp/types';
|
||||||
|
import { MCPInspectionFailedError, isMCPDomainNotAllowedError } from '~/mcp/errors';
|
||||||
import { ServerConfigsCacheFactory } from './cache/ServerConfigsCacheFactory';
|
import { ServerConfigsCacheFactory } from './cache/ServerConfigsCacheFactory';
|
||||||
import { MCPServerInspector } from './MCPServerInspector';
|
import { MCPServerInspector } from './MCPServerInspector';
|
||||||
import { ServerConfigsDB } from './db/ServerConfigsDB';
|
import { ServerConfigsDB } from './db/ServerConfigsDB';
|
||||||
|
|
@ -20,14 +21,19 @@ export class MCPServersRegistry {
|
||||||
|
|
||||||
private readonly dbConfigsRepo: IServerConfigsRepositoryInterface;
|
private readonly dbConfigsRepo: IServerConfigsRepositoryInterface;
|
||||||
private readonly cacheConfigsRepo: IServerConfigsRepositoryInterface;
|
private readonly cacheConfigsRepo: IServerConfigsRepositoryInterface;
|
||||||
|
private readonly allowedDomains?: string[] | null;
|
||||||
|
|
||||||
constructor(mongoose: typeof import('mongoose')) {
|
constructor(mongoose: typeof import('mongoose'), allowedDomains?: string[] | null) {
|
||||||
this.dbConfigsRepo = new ServerConfigsDB(mongoose);
|
this.dbConfigsRepo = new ServerConfigsDB(mongoose);
|
||||||
this.cacheConfigsRepo = ServerConfigsCacheFactory.create('App', false);
|
this.cacheConfigsRepo = ServerConfigsCacheFactory.create('App', false);
|
||||||
|
this.allowedDomains = allowedDomains;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Creates and initializes the singleton MCPServersRegistry instance */
|
/** Creates and initializes the singleton MCPServersRegistry instance */
|
||||||
public static createInstance(mongoose: typeof import('mongoose')): MCPServersRegistry {
|
public static createInstance(
|
||||||
|
mongoose: typeof import('mongoose'),
|
||||||
|
allowedDomains?: string[] | null,
|
||||||
|
): MCPServersRegistry {
|
||||||
if (!mongoose) {
|
if (!mongoose) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
'MCPServersRegistry creation failed: mongoose instance is required for database operations. ' +
|
'MCPServersRegistry creation failed: mongoose instance is required for database operations. ' +
|
||||||
|
|
@ -39,7 +45,7 @@ export class MCPServersRegistry {
|
||||||
return MCPServersRegistry.instance;
|
return MCPServersRegistry.instance;
|
||||||
}
|
}
|
||||||
logger.info('[MCPServersRegistry] Creating new instance');
|
logger.info('[MCPServersRegistry] Creating new instance');
|
||||||
MCPServersRegistry.instance = new MCPServersRegistry(mongoose);
|
MCPServersRegistry.instance = new MCPServersRegistry(mongoose, allowedDomains);
|
||||||
return MCPServersRegistry.instance;
|
return MCPServersRegistry.instance;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -80,10 +86,19 @@ export class MCPServersRegistry {
|
||||||
const configRepo = this.getConfigRepository(storageLocation);
|
const configRepo = this.getConfigRepository(storageLocation);
|
||||||
let parsedConfig: t.ParsedServerConfig;
|
let parsedConfig: t.ParsedServerConfig;
|
||||||
try {
|
try {
|
||||||
parsedConfig = await MCPServerInspector.inspect(serverName, config);
|
parsedConfig = await MCPServerInspector.inspect(
|
||||||
|
serverName,
|
||||||
|
config,
|
||||||
|
undefined,
|
||||||
|
this.allowedDomains,
|
||||||
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error);
|
logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error);
|
||||||
throw new Error(`MCP_INSPECTION_FAILED: Failed to connect to MCP server "${serverName}"`);
|
// Preserve domain-specific error for better error handling
|
||||||
|
if (isMCPDomainNotAllowedError(error)) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
throw new MCPInspectionFailedError(serverName, error as Error);
|
||||||
}
|
}
|
||||||
return await configRepo.add(serverName, parsedConfig, userId);
|
return await configRepo.add(serverName, parsedConfig, userId);
|
||||||
}
|
}
|
||||||
|
|
@ -113,10 +128,19 @@ export class MCPServersRegistry {
|
||||||
|
|
||||||
let parsedConfig: t.ParsedServerConfig;
|
let parsedConfig: t.ParsedServerConfig;
|
||||||
try {
|
try {
|
||||||
parsedConfig = await MCPServerInspector.inspect(serverName, configForInspection);
|
parsedConfig = await MCPServerInspector.inspect(
|
||||||
|
serverName,
|
||||||
|
configForInspection,
|
||||||
|
undefined,
|
||||||
|
this.allowedDomains,
|
||||||
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error);
|
logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error);
|
||||||
throw new Error(`MCP_INSPECTION_FAILED: Failed to connect to MCP server "${serverName}"`);
|
// Preserve domain-specific error for better error handling
|
||||||
|
if (isMCPDomainNotAllowedError(error)) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
throw new MCPInspectionFailedError(serverName, error as Error);
|
||||||
}
|
}
|
||||||
await configRepo.update(serverName, parsedConfig, userId);
|
await configRepo.update(serverName, parsedConfig, userId);
|
||||||
return parsedConfig;
|
return parsedConfig;
|
||||||
|
|
|
||||||
|
|
@ -224,18 +224,38 @@ describe('MCPServersInitializer', () => {
|
||||||
it('should process all server configs through inspector', async () => {
|
it('should process all server configs through inspector', async () => {
|
||||||
await MCPServersInitializer.initialize(testConfigs);
|
await MCPServersInitializer.initialize(testConfigs);
|
||||||
|
|
||||||
// Verify all configs were processed by inspector (without connection parameter)
|
// Verify all configs were processed by inspector
|
||||||
|
// Signature: inspect(serverName, rawConfig, connection?, allowedDomains?)
|
||||||
expect(mockInspect).toHaveBeenCalledTimes(5);
|
expect(mockInspect).toHaveBeenCalledTimes(5);
|
||||||
expect(mockInspect).toHaveBeenCalledWith('disabled_server', testConfigs.disabled_server);
|
expect(mockInspect).toHaveBeenCalledWith(
|
||||||
expect(mockInspect).toHaveBeenCalledWith('oauth_server', testConfigs.oauth_server);
|
'disabled_server',
|
||||||
expect(mockInspect).toHaveBeenCalledWith('file_tools_server', testConfigs.file_tools_server);
|
testConfigs.disabled_server,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
|
);
|
||||||
|
expect(mockInspect).toHaveBeenCalledWith(
|
||||||
|
'oauth_server',
|
||||||
|
testConfigs.oauth_server,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
|
);
|
||||||
|
expect(mockInspect).toHaveBeenCalledWith(
|
||||||
|
'file_tools_server',
|
||||||
|
testConfigs.file_tools_server,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
|
);
|
||||||
expect(mockInspect).toHaveBeenCalledWith(
|
expect(mockInspect).toHaveBeenCalledWith(
|
||||||
'search_tools_server',
|
'search_tools_server',
|
||||||
testConfigs.search_tools_server,
|
testConfigs.search_tools_server,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
);
|
);
|
||||||
expect(mockInspect).toHaveBeenCalledWith(
|
expect(mockInspect).toHaveBeenCalledWith(
|
||||||
'remote_no_oauth_server',
|
'remote_no_oauth_server',
|
||||||
testConfigs.remote_no_oauth_server,
|
testConfigs.remote_no_oauth_server,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -849,6 +849,11 @@ export const configSchema = z.object({
|
||||||
includedTools: z.array(z.string()).optional(),
|
includedTools: z.array(z.string()).optional(),
|
||||||
filteredTools: z.array(z.string()).optional(),
|
filteredTools: z.array(z.string()).optional(),
|
||||||
mcpServers: MCPServersSchema.optional(),
|
mcpServers: MCPServersSchema.optional(),
|
||||||
|
mcpSettings: z
|
||||||
|
.object({
|
||||||
|
allowedDomains: z.array(z.string()).optional(),
|
||||||
|
})
|
||||||
|
.optional(),
|
||||||
interface: interfaceSchema,
|
interface: interfaceSchema,
|
||||||
turnstile: turnstileSchema.optional(),
|
turnstile: turnstileSchema.optional(),
|
||||||
fileStrategy: fileSourceSchema.default(FileSources.local),
|
fileStrategy: fileSourceSchema.default(FileSources.local),
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,8 @@ export const AppService = async (params?: {
|
||||||
|
|
||||||
const availableTools = systemTools;
|
const availableTools = systemTools;
|
||||||
|
|
||||||
const mcpConfig = config.mcpServers || null;
|
const mcpServersConfig = config.mcpServers || null;
|
||||||
|
const mcpSettings = config.mcpSettings || null;
|
||||||
const registration = config.registration ?? configDefaults.registration;
|
const registration = config.registration ?? configDefaults.registration;
|
||||||
const interfaceConfig = await loadDefaultInterface({ config, configDefaults });
|
const interfaceConfig = await loadDefaultInterface({ config, configDefaults });
|
||||||
const turnstileConfig = loadTurnstileConfig(config, configDefaults);
|
const turnstileConfig = loadTurnstileConfig(config, configDefaults);
|
||||||
|
|
@ -74,7 +75,8 @@ export const AppService = async (params?: {
|
||||||
speech,
|
speech,
|
||||||
balance,
|
balance,
|
||||||
transactions,
|
transactions,
|
||||||
mcpConfig,
|
mcpConfig: mcpServersConfig,
|
||||||
|
mcpSettings,
|
||||||
webSearch,
|
webSearch,
|
||||||
fileStrategy,
|
fileStrategy,
|
||||||
registration,
|
registration,
|
||||||
|
|
|
||||||
|
|
@ -82,6 +82,8 @@ export interface AppConfig {
|
||||||
speech?: TCustomConfig['speech'];
|
speech?: TCustomConfig['speech'];
|
||||||
/** MCP server configuration */
|
/** MCP server configuration */
|
||||||
mcpConfig?: TCustomConfig['mcpServers'] | null;
|
mcpConfig?: TCustomConfig['mcpServers'] | null;
|
||||||
|
/** MCP settings (domain allowlist, etc.) */
|
||||||
|
mcpSettings?: TCustomConfig['mcpSettings'] | null;
|
||||||
/** File configuration */
|
/** File configuration */
|
||||||
fileConfig?: TFileConfig;
|
fileConfig?: TFileConfig;
|
||||||
/** Secure image links configuration */
|
/** Secure image links configuration */
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue