mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-23 03:40:14 +01:00
🔒 feat: Add MCP server domain restrictions for remote transports (#11013)
* 🔒 feat: Add MCP server domain restrictions for remote transports * 🔒 feat: Implement comprehensive MCP error handling and domain validation - Added `handleMCPError` function to centralize error responses for domain restrictions and inspection failures. - Introduced custom error classes: `MCPDomainNotAllowedError` and `MCPInspectionFailedError` for better error management. - Updated MCP server controllers to utilize the new error handling mechanism. - Enhanced domain validation logic in `createMCPTools` and `createMCPTool` functions to prevent operations on disallowed domains. - Added tests for runtime domain validation scenarios to ensure correct behavior. * chore: import order * 🔒 feat: Enhance domain validation in MCP tools with user role-based restrictions - Integrated `getAppConfig` to fetch allowed domains based on user roles in `createMCPTools` and `createMCPTool` functions. - Removed the deprecated `getAllowedDomains` method from `MCPServersRegistry`. - Updated tests to verify domain restrictions are applied correctly based on user roles. - Ensured that domain validation logic is consistent and efficient across tool creation processes. * 🔒 test: Refactor MCP tests to utilize configurable app settings - Introduced a mock for `getAppConfig` to enhance test flexibility. - Removed redundant mock definition to streamline test setup. - Ensured tests are aligned with the latest domain validation logic. --------- Co-authored-by: Atef Bellaaj <slalom.bellaaj@external.daimlertruck.com> Co-authored-by: Danny Avila <danny@librechat.ai>
This commit is contained in:
parent
98294755ee
commit
95a69df70e
19 changed files with 815 additions and 75 deletions
|
|
@ -6,10 +6,54 @@
|
|||
* @import { MCPServerDocument } from 'librechat-data-provider'
|
||||
*/
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
isMCPDomainNotAllowedError,
|
||||
isMCPInspectionFailedError,
|
||||
MCPErrorCodes,
|
||||
} = require('@librechat/api');
|
||||
const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider');
|
||||
const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config');
|
||||
const { getMCPManager, getMCPServersRegistry } = require('~/config');
|
||||
|
||||
/**
|
||||
* Handles MCP-specific errors and sends appropriate HTTP responses.
|
||||
* @param {Error} error - The error to handle
|
||||
* @param {import('express').Response} res - Express response object
|
||||
* @returns {import('express').Response | null} Response if handled, null if not an MCP error
|
||||
*/
|
||||
function handleMCPError(error, res) {
|
||||
if (isMCPDomainNotAllowedError(error)) {
|
||||
return res.status(error.statusCode).json({
|
||||
error: error.code,
|
||||
message: error.message,
|
||||
});
|
||||
}
|
||||
|
||||
if (isMCPInspectionFailedError(error)) {
|
||||
return res.status(error.statusCode).json({
|
||||
error: error.code,
|
||||
message: error.message,
|
||||
});
|
||||
}
|
||||
|
||||
// Fallback for legacy string-based error handling (backwards compatibility)
|
||||
if (error.message?.startsWith(MCPErrorCodes.DOMAIN_NOT_ALLOWED)) {
|
||||
return res.status(403).json({
|
||||
error: MCPErrorCodes.DOMAIN_NOT_ALLOWED,
|
||||
message: error.message.replace(/^MCP_DOMAIN_NOT_ALLOWED\s*:\s*/i, ''),
|
||||
});
|
||||
}
|
||||
|
||||
if (error.message?.startsWith(MCPErrorCodes.INSPECTION_FAILED)) {
|
||||
return res.status(400).json({
|
||||
error: MCPErrorCodes.INSPECTION_FAILED,
|
||||
message: error.message,
|
||||
});
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all MCP tools available to the user
|
||||
*/
|
||||
|
|
@ -175,11 +219,9 @@ const createMCPServerController = async (req, res) => {
|
|||
});
|
||||
} catch (error) {
|
||||
logger.error('[createMCPServer]', error);
|
||||
if (error.message?.startsWith('MCP_INSPECTION_FAILED')) {
|
||||
return res.status(400).json({
|
||||
error: 'MCP_INSPECTION_FAILED',
|
||||
message: error.message,
|
||||
});
|
||||
const mcpErrorResponse = handleMCPError(error, res);
|
||||
if (mcpErrorResponse) {
|
||||
return mcpErrorResponse;
|
||||
}
|
||||
res.status(500).json({ message: error.message });
|
||||
}
|
||||
|
|
@ -235,11 +277,9 @@ const updateMCPServerController = async (req, res) => {
|
|||
res.status(200).json(parsedConfig);
|
||||
} catch (error) {
|
||||
logger.error('[updateMCPServer]', error);
|
||||
if (error.message?.startsWith('MCP_INSPECTION_FAILED:')) {
|
||||
return res.status(400).json({
|
||||
error: 'MCP_INSPECTION_FAILED',
|
||||
message: error.message,
|
||||
});
|
||||
const mcpErrorResponse = handleMCPError(error, res);
|
||||
if (mcpErrorResponse) {
|
||||
return mcpErrorResponse;
|
||||
}
|
||||
res.status(500).json({ message: error.message });
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,26 +12,36 @@ const mockRegistryInstance = {
|
|||
removeServer: jest.fn(),
|
||||
};
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
...jest.requireActual('@librechat/api'),
|
||||
MCPOAuthHandler: {
|
||||
initiateOAuthFlow: jest.fn(),
|
||||
getFlowState: jest.fn(),
|
||||
completeOAuthFlow: jest.fn(),
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
MCPTokenStorage: {
|
||||
storeTokens: jest.fn(),
|
||||
getClientInfoAndMetadata: jest.fn(),
|
||||
getTokens: jest.fn(),
|
||||
deleteUserTokens: jest.fn(),
|
||||
},
|
||||
getUserMCPAuthMap: jest.fn(),
|
||||
generateCheckAccess: jest.fn(() => (req, res, next) => next()),
|
||||
MCPServersRegistry: {
|
||||
getInstance: () => mockRegistryInstance,
|
||||
},
|
||||
}));
|
||||
jest.mock('@librechat/api', () => {
|
||||
const actual = jest.requireActual('@librechat/api');
|
||||
return {
|
||||
...actual,
|
||||
MCPOAuthHandler: {
|
||||
initiateOAuthFlow: jest.fn(),
|
||||
getFlowState: jest.fn(),
|
||||
completeOAuthFlow: jest.fn(),
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
MCPTokenStorage: {
|
||||
storeTokens: jest.fn(),
|
||||
getClientInfoAndMetadata: jest.fn(),
|
||||
getTokens: jest.fn(),
|
||||
deleteUserTokens: jest.fn(),
|
||||
},
|
||||
getUserMCPAuthMap: jest.fn(),
|
||||
generateCheckAccess: jest.fn(() => (req, res, next) => next()),
|
||||
MCPServersRegistry: {
|
||||
getInstance: () => mockRegistryInstance,
|
||||
},
|
||||
// Error handling utilities (from @librechat/api mcp/errors)
|
||||
isMCPDomainNotAllowedError: (error) => error?.code === 'MCP_DOMAIN_NOT_ALLOWED',
|
||||
isMCPInspectionFailedError: (error) => error?.code === 'MCP_INSPECTION_FAILED',
|
||||
MCPErrorCodes: {
|
||||
DOMAIN_NOT_ALLOWED: 'MCP_DOMAIN_NOT_ALLOWED',
|
||||
INSPECTION_FAILED: 'MCP_INSPECTION_FAILED',
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ const {
|
|||
const {
|
||||
sendEvent,
|
||||
MCPOAuthHandler,
|
||||
isMCPDomainAllowed,
|
||||
normalizeServerName,
|
||||
convertWithResolvedRefs,
|
||||
} = require('@librechat/api');
|
||||
|
|
@ -21,13 +22,14 @@ const {
|
|||
isAssistantsEndpoint,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
getMCPManager,
|
||||
getFlowStateManager,
|
||||
getOAuthReconnectionManager,
|
||||
getMCPServersRegistry,
|
||||
getFlowStateManager,
|
||||
getMCPManager,
|
||||
} = require('~/config');
|
||||
const { findToken, createToken, updateToken } = require('~/models');
|
||||
const { reinitMCPServer } = require('./Tools/mcp');
|
||||
const { getAppConfig } = require('./Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
|
|
@ -222,10 +224,34 @@ async function reconnectServer({ res, user, index, signal, serverName, userMCPAu
|
|||
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
|
||||
* @param {number} [params.index]
|
||||
* @param {AbortSignal} [params.signal]
|
||||
* @param {import('@librechat/api').ParsedServerConfig} [params.config]
|
||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function createMCPTools({ res, user, index, signal, serverName, provider, userMCPAuthMap }) {
|
||||
async function createMCPTools({
|
||||
res,
|
||||
user,
|
||||
index,
|
||||
signal,
|
||||
config,
|
||||
provider,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
}) {
|
||||
// Early domain validation before reconnecting server (avoid wasted work on disallowed domains)
|
||||
// Use getAppConfig() to support per-user/role domain restrictions
|
||||
const serverConfig =
|
||||
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
|
||||
if (serverConfig?.url) {
|
||||
const appConfig = await getAppConfig({ role: user?.role });
|
||||
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
|
||||
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
|
||||
if (!isDomainAllowed) {
|
||||
logger.warn(`[MCP][${serverName}] Domain not allowed, skipping all tools`);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
const result = await reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap });
|
||||
if (!result || !result.tools) {
|
||||
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
|
||||
|
|
@ -241,6 +267,7 @@ async function createMCPTools({ res, user, index, signal, serverName, provider,
|
|||
userMCPAuthMap,
|
||||
availableTools: result.availableTools,
|
||||
toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`,
|
||||
config: serverConfig,
|
||||
});
|
||||
if (toolInstance) {
|
||||
serverTools.push(toolInstance);
|
||||
|
|
@ -262,6 +289,7 @@ async function createMCPTools({ res, user, index, signal, serverName, provider,
|
|||
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
|
||||
* @param {LCAvailableTools} [params.availableTools]
|
||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||
* @param {import('@librechat/api').ParsedServerConfig} [params.config]
|
||||
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function createMCPTool({
|
||||
|
|
@ -273,9 +301,24 @@ async function createMCPTool({
|
|||
provider,
|
||||
userMCPAuthMap,
|
||||
availableTools,
|
||||
config,
|
||||
}) {
|
||||
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
|
||||
|
||||
// Runtime domain validation: check if the server's domain is still allowed
|
||||
// Use getAppConfig() to support per-user/role domain restrictions
|
||||
const serverConfig =
|
||||
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
|
||||
if (serverConfig?.url) {
|
||||
const appConfig = await getAppConfig({ role: user?.role });
|
||||
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
|
||||
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
|
||||
if (!isDomainAllowed) {
|
||||
logger.warn(`[MCP][${serverName}] Domain no longer allowed, skipping tool: ${toolName}`);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/** @type {LCTool | undefined} */
|
||||
let toolDefinition = availableTools?.[toolKey]?.function;
|
||||
if (!toolDefinition) {
|
||||
|
|
|
|||
|
|
@ -1,14 +1,4 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const {
|
||||
createMCPTool,
|
||||
createMCPTools,
|
||||
getMCPSetupData,
|
||||
checkOAuthFlowStatus,
|
||||
getServerConnectionStatus,
|
||||
} = require('./MCP');
|
||||
|
||||
// Mock all dependencies - define mocks before imports
|
||||
// Mock all dependencies
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
|
|
@ -43,22 +33,46 @@ jest.mock('@librechat/agents', () => ({
|
|||
},
|
||||
}));
|
||||
|
||||
// Create mock registry instance
|
||||
const mockRegistryInstance = {
|
||||
getOAuthServers: jest.fn(() => Promise.resolve(new Set())),
|
||||
getAllServerConfigs: jest.fn(() => Promise.resolve({})),
|
||||
getServerConfig: jest.fn(() => Promise.resolve(null)),
|
||||
};
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
MCPOAuthHandler: {
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
sendEvent: jest.fn(),
|
||||
normalizeServerName: jest.fn((name) => name),
|
||||
convertWithResolvedRefs: jest.fn((params) => params),
|
||||
MCPServersRegistry: {
|
||||
getInstance: () => mockRegistryInstance,
|
||||
},
|
||||
}));
|
||||
// Create isMCPDomainAllowed mock that can be configured per-test
|
||||
const mockIsMCPDomainAllowed = jest.fn(() => Promise.resolve(true));
|
||||
|
||||
const mockGetAppConfig = jest.fn(() => Promise.resolve({}));
|
||||
|
||||
jest.mock('@librechat/api', () => {
|
||||
// Access mock via getter to avoid hoisting issues
|
||||
return {
|
||||
MCPOAuthHandler: {
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
sendEvent: jest.fn(),
|
||||
normalizeServerName: jest.fn((name) => name),
|
||||
convertWithResolvedRefs: jest.fn((params) => params),
|
||||
get isMCPDomainAllowed() {
|
||||
return mockIsMCPDomainAllowed;
|
||||
},
|
||||
MCPServersRegistry: {
|
||||
getInstance: () => mockRegistryInstance,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const {
|
||||
createMCPTool,
|
||||
createMCPTools,
|
||||
getMCPSetupData,
|
||||
checkOAuthFlowStatus,
|
||||
getServerConnectionStatus,
|
||||
} = require('./MCP');
|
||||
|
||||
jest.mock('librechat-data-provider', () => ({
|
||||
CacheKeys: {
|
||||
|
|
@ -80,7 +94,9 @@ jest.mock('librechat-data-provider', () => ({
|
|||
|
||||
jest.mock('./Config', () => ({
|
||||
loadCustomConfig: jest.fn(),
|
||||
getAppConfig: jest.fn(),
|
||||
get getAppConfig() {
|
||||
return mockGetAppConfig;
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
|
|
@ -692,6 +708,18 @@ describe('User parameter passing tests', () => {
|
|||
createFlowWithHandler: jest.fn(),
|
||||
failFlow: jest.fn(),
|
||||
});
|
||||
|
||||
// Reset domain validation mock to default (allow all)
|
||||
mockIsMCPDomainAllowed.mockReset();
|
||||
mockIsMCPDomainAllowed.mockResolvedValue(true);
|
||||
|
||||
// Reset registry mocks
|
||||
mockRegistryInstance.getServerConfig.mockReset();
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue(null);
|
||||
|
||||
// Reset getAppConfig mock to default (no restrictions)
|
||||
mockGetAppConfig.mockReset();
|
||||
mockGetAppConfig.mockResolvedValue({});
|
||||
});
|
||||
|
||||
describe('createMCPTools', () => {
|
||||
|
|
@ -887,6 +915,229 @@ describe('User parameter passing tests', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('Runtime domain validation', () => {
|
||||
it('should skip tool creation when domain is not allowed', async () => {
|
||||
const mockUser = { id: 'domain-test-user', role: 'user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// Mock server config with URL (remote server)
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
url: 'https://disallowed-domain.com/sse',
|
||||
});
|
||||
|
||||
// Mock getAppConfig to return domain restrictions
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
|
||||
});
|
||||
|
||||
// Mock domain validation to return false (domain not allowed)
|
||||
mockIsMCPDomainAllowed.mockResolvedValueOnce(false);
|
||||
|
||||
const result = await createMCPTool({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools: {
|
||||
'test-tool::test-server': {
|
||||
function: {
|
||||
description: 'Test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Should return undefined for disallowed domain
|
||||
expect(result).toBeUndefined();
|
||||
|
||||
// Should not call reinitMCPServer since domain check failed
|
||||
expect(mockReinitMCPServer).not.toHaveBeenCalled();
|
||||
|
||||
// Verify getAppConfig was called with user role
|
||||
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' });
|
||||
|
||||
// Verify domain validation was called with correct parameters
|
||||
expect(mockIsMCPDomainAllowed).toHaveBeenCalledWith(
|
||||
{ url: 'https://disallowed-domain.com/sse' },
|
||||
['allowed-domain.com'],
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow tool creation when domain is allowed', async () => {
|
||||
const mockUser = { id: 'domain-test-user', role: 'admin' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// Mock server config with URL (remote server)
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
url: 'https://allowed-domain.com/sse',
|
||||
});
|
||||
|
||||
// Mock getAppConfig to return domain restrictions
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
|
||||
});
|
||||
|
||||
// Mock domain validation to return true (domain allowed)
|
||||
mockIsMCPDomainAllowed.mockResolvedValueOnce(true);
|
||||
|
||||
const availableTools = {
|
||||
'test-tool::test-server': {
|
||||
function: {
|
||||
description: 'Test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createMCPTool({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
|
||||
// Should create tool successfully
|
||||
expect(result).toBeDefined();
|
||||
|
||||
// Verify getAppConfig was called with user role
|
||||
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'admin' });
|
||||
});
|
||||
|
||||
it('should skip domain validation for stdio transports (no URL)', async () => {
|
||||
const mockUser = { id: 'stdio-test-user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// Mock server config without URL (stdio transport)
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
command: 'npx',
|
||||
args: ['@modelcontextprotocol/server'],
|
||||
});
|
||||
|
||||
// Mock getAppConfig (should not be called for stdio)
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpSettings: { allowedDomains: ['restricted-domain.com'] },
|
||||
});
|
||||
|
||||
const availableTools = {
|
||||
'test-tool::test-server': {
|
||||
function: {
|
||||
description: 'Test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createMCPTool({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
|
||||
// Should create tool successfully without domain check
|
||||
expect(result).toBeDefined();
|
||||
|
||||
// Should not call getAppConfig or isMCPDomainAllowed for stdio transport (no URL)
|
||||
expect(mockGetAppConfig).not.toHaveBeenCalled();
|
||||
expect(mockIsMCPDomainAllowed).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return empty array from createMCPTools when domain is not allowed', async () => {
|
||||
const mockUser = { id: 'domain-test-user', role: 'user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// Mock server config with URL (remote server)
|
||||
const serverConfig = { url: 'https://disallowed-domain.com/sse' };
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue(serverConfig);
|
||||
|
||||
// Mock getAppConfig to return domain restrictions
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
|
||||
});
|
||||
|
||||
// Mock domain validation to return false (domain not allowed)
|
||||
mockIsMCPDomainAllowed.mockResolvedValueOnce(false);
|
||||
|
||||
const result = await createMCPTools({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
serverName: 'test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
config: serverConfig,
|
||||
});
|
||||
|
||||
// Should return empty array for disallowed domain
|
||||
expect(result).toEqual([]);
|
||||
|
||||
// Should not call reinitMCPServer since domain check failed early
|
||||
expect(mockReinitMCPServer).not.toHaveBeenCalled();
|
||||
|
||||
// Verify getAppConfig was called with user role
|
||||
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' });
|
||||
});
|
||||
|
||||
it('should use user role when fetching domain restrictions', async () => {
|
||||
const adminUser = { id: 'admin-user', role: 'admin' };
|
||||
const regularUser = { id: 'regular-user', role: 'user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
url: 'https://some-domain.com/sse',
|
||||
});
|
||||
|
||||
// Mock different responses based on role
|
||||
mockGetAppConfig
|
||||
.mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['admin-allowed.com'] } })
|
||||
.mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['user-allowed.com'] } });
|
||||
|
||||
mockIsMCPDomainAllowed.mockResolvedValue(true);
|
||||
|
||||
const availableTools = {
|
||||
'test-tool::test-server': {
|
||||
function: {
|
||||
description: 'Test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Call with admin user
|
||||
await createMCPTool({
|
||||
res: mockRes,
|
||||
user: adminUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
|
||||
// Reset and call with regular user
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
url: 'https://some-domain.com/sse',
|
||||
});
|
||||
|
||||
await createMCPTool({
|
||||
res: mockRes,
|
||||
user: regularUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
|
||||
// Verify getAppConfig was called with correct roles
|
||||
expect(mockGetAppConfig).toHaveBeenNthCalledWith(1, { role: 'admin' });
|
||||
expect(mockGetAppConfig).toHaveBeenNthCalledWith(2, { role: 'user' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('User parameter integrity', () => {
|
||||
it('should preserve user object properties through the call chain', async () => {
|
||||
const complexUser = {
|
||||
|
|
|
|||
|
|
@ -14,8 +14,9 @@ async function initializeMCPs() {
|
|||
}
|
||||
|
||||
// Initialize MCPServersRegistry first (required for MCPManager)
|
||||
// Pass allowedDomains from mcpSettings for domain validation
|
||||
try {
|
||||
createMCPServersRegistry(mongoose);
|
||||
createMCPServersRegistry(mongoose, appConfig?.mcpSettings?.allowedDomains);
|
||||
} catch (error) {
|
||||
logger.error('[MCP] Failed to initialize MCPServersRegistry:', error);
|
||||
throw error;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue