🔒 feat: Add MCP server domain restrictions for remote transports (#11013)

* 🔒 feat: Add MCP server domain restrictions for remote transports

* 🔒 feat: Implement comprehensive MCP error handling and domain validation

- Added `handleMCPError` function to centralize error responses for domain restrictions and inspection failures.
- Introduced custom error classes: `MCPDomainNotAllowedError` and `MCPInspectionFailedError` for better error management.
- Updated MCP server controllers to utilize the new error handling mechanism.
- Enhanced domain validation logic in `createMCPTools` and `createMCPTool` functions to prevent operations on disallowed domains.
- Added tests for runtime domain validation scenarios to ensure correct behavior.

* chore: import order

* 🔒 feat: Enhance domain validation in MCP tools with user role-based restrictions

- Integrated `getAppConfig` to fetch allowed domains based on user roles in `createMCPTools` and `createMCPTool` functions.
- Removed the deprecated `getAllowedDomains` method from `MCPServersRegistry`.
- Updated tests to verify domain restrictions are applied correctly based on user roles.
- Ensured that domain validation logic is consistent and efficient across tool creation processes.

* 🔒 test: Refactor MCP tests to utilize configurable app settings

- Introduced a mock for `getAppConfig` to enhance test flexibility.
- Removed redundant mock definition to streamline test setup.
- Ensured tests are aligned with the latest domain validation logic.

---------

Co-authored-by: Atef Bellaaj <slalom.bellaaj@external.daimlertruck.com>
Co-authored-by: Danny Avila <danny@librechat.ai>
This commit is contained in:
Atef Bellaaj 2025-12-18 19:57:49 +01:00 committed by GitHub
parent 98294755ee
commit 95a69df70e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 815 additions and 75 deletions

View file

@ -1,14 +1,4 @@
const { logger } = require('@librechat/data-schemas');
const { MCPOAuthHandler } = require('@librechat/api');
const { CacheKeys } = require('librechat-data-provider');
const {
createMCPTool,
createMCPTools,
getMCPSetupData,
checkOAuthFlowStatus,
getServerConnectionStatus,
} = require('./MCP');
// Mock all dependencies - define mocks before imports
// Mock all dependencies
jest.mock('@librechat/data-schemas', () => ({
logger: {
@ -43,22 +33,46 @@ jest.mock('@librechat/agents', () => ({
},
}));
// Create mock registry instance
const mockRegistryInstance = {
getOAuthServers: jest.fn(() => Promise.resolve(new Set())),
getAllServerConfigs: jest.fn(() => Promise.resolve({})),
getServerConfig: jest.fn(() => Promise.resolve(null)),
};
jest.mock('@librechat/api', () => ({
MCPOAuthHandler: {
generateFlowId: jest.fn(),
},
sendEvent: jest.fn(),
normalizeServerName: jest.fn((name) => name),
convertWithResolvedRefs: jest.fn((params) => params),
MCPServersRegistry: {
getInstance: () => mockRegistryInstance,
},
}));
// Create isMCPDomainAllowed mock that can be configured per-test
const mockIsMCPDomainAllowed = jest.fn(() => Promise.resolve(true));
const mockGetAppConfig = jest.fn(() => Promise.resolve({}));
jest.mock('@librechat/api', () => {
// Access mock via getter to avoid hoisting issues
return {
MCPOAuthHandler: {
generateFlowId: jest.fn(),
},
sendEvent: jest.fn(),
normalizeServerName: jest.fn((name) => name),
convertWithResolvedRefs: jest.fn((params) => params),
get isMCPDomainAllowed() {
return mockIsMCPDomainAllowed;
},
MCPServersRegistry: {
getInstance: () => mockRegistryInstance,
},
};
});
const { logger } = require('@librechat/data-schemas');
const { MCPOAuthHandler } = require('@librechat/api');
const { CacheKeys } = require('librechat-data-provider');
const {
createMCPTool,
createMCPTools,
getMCPSetupData,
checkOAuthFlowStatus,
getServerConnectionStatus,
} = require('./MCP');
jest.mock('librechat-data-provider', () => ({
CacheKeys: {
@ -80,7 +94,9 @@ jest.mock('librechat-data-provider', () => ({
jest.mock('./Config', () => ({
loadCustomConfig: jest.fn(),
getAppConfig: jest.fn(),
get getAppConfig() {
return mockGetAppConfig;
},
}));
jest.mock('~/config', () => ({
@ -692,6 +708,18 @@ describe('User parameter passing tests', () => {
createFlowWithHandler: jest.fn(),
failFlow: jest.fn(),
});
// Reset domain validation mock to default (allow all)
mockIsMCPDomainAllowed.mockReset();
mockIsMCPDomainAllowed.mockResolvedValue(true);
// Reset registry mocks
mockRegistryInstance.getServerConfig.mockReset();
mockRegistryInstance.getServerConfig.mockResolvedValue(null);
// Reset getAppConfig mock to default (no restrictions)
mockGetAppConfig.mockReset();
mockGetAppConfig.mockResolvedValue({});
});
describe('createMCPTools', () => {
@ -887,6 +915,229 @@ describe('User parameter passing tests', () => {
});
});
describe('Runtime domain validation', () => {
it('should skip tool creation when domain is not allowed', async () => {
const mockUser = { id: 'domain-test-user', role: 'user' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// Mock server config with URL (remote server)
mockRegistryInstance.getServerConfig.mockResolvedValue({
url: 'https://disallowed-domain.com/sse',
});
// Mock getAppConfig to return domain restrictions
mockGetAppConfig.mockResolvedValue({
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
});
// Mock domain validation to return false (domain not allowed)
mockIsMCPDomainAllowed.mockResolvedValueOnce(false);
const result = await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'test-tool::test-server',
provider: 'openai',
userMCPAuthMap: {},
availableTools: {
'test-tool::test-server': {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
},
},
},
});
// Should return undefined for disallowed domain
expect(result).toBeUndefined();
// Should not call reinitMCPServer since domain check failed
expect(mockReinitMCPServer).not.toHaveBeenCalled();
// Verify getAppConfig was called with user role
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' });
// Verify domain validation was called with correct parameters
expect(mockIsMCPDomainAllowed).toHaveBeenCalledWith(
{ url: 'https://disallowed-domain.com/sse' },
['allowed-domain.com'],
);
});
it('should allow tool creation when domain is allowed', async () => {
const mockUser = { id: 'domain-test-user', role: 'admin' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// Mock server config with URL (remote server)
mockRegistryInstance.getServerConfig.mockResolvedValue({
url: 'https://allowed-domain.com/sse',
});
// Mock getAppConfig to return domain restrictions
mockGetAppConfig.mockResolvedValue({
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
});
// Mock domain validation to return true (domain allowed)
mockIsMCPDomainAllowed.mockResolvedValueOnce(true);
const availableTools = {
'test-tool::test-server': {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
},
},
};
const result = await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'test-tool::test-server',
provider: 'openai',
userMCPAuthMap: {},
availableTools,
});
// Should create tool successfully
expect(result).toBeDefined();
// Verify getAppConfig was called with user role
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'admin' });
});
it('should skip domain validation for stdio transports (no URL)', async () => {
const mockUser = { id: 'stdio-test-user' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// Mock server config without URL (stdio transport)
mockRegistryInstance.getServerConfig.mockResolvedValue({
command: 'npx',
args: ['@modelcontextprotocol/server'],
});
// Mock getAppConfig (should not be called for stdio)
mockGetAppConfig.mockResolvedValue({
mcpSettings: { allowedDomains: ['restricted-domain.com'] },
});
const availableTools = {
'test-tool::test-server': {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
},
},
};
const result = await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'test-tool::test-server',
provider: 'openai',
userMCPAuthMap: {},
availableTools,
});
// Should create tool successfully without domain check
expect(result).toBeDefined();
// Should not call getAppConfig or isMCPDomainAllowed for stdio transport (no URL)
expect(mockGetAppConfig).not.toHaveBeenCalled();
expect(mockIsMCPDomainAllowed).not.toHaveBeenCalled();
});
it('should return empty array from createMCPTools when domain is not allowed', async () => {
const mockUser = { id: 'domain-test-user', role: 'user' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// Mock server config with URL (remote server)
const serverConfig = { url: 'https://disallowed-domain.com/sse' };
mockRegistryInstance.getServerConfig.mockResolvedValue(serverConfig);
// Mock getAppConfig to return domain restrictions
mockGetAppConfig.mockResolvedValue({
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
});
// Mock domain validation to return false (domain not allowed)
mockIsMCPDomainAllowed.mockResolvedValueOnce(false);
const result = await createMCPTools({
res: mockRes,
user: mockUser,
serverName: 'test-server',
provider: 'openai',
userMCPAuthMap: {},
config: serverConfig,
});
// Should return empty array for disallowed domain
expect(result).toEqual([]);
// Should not call reinitMCPServer since domain check failed early
expect(mockReinitMCPServer).not.toHaveBeenCalled();
// Verify getAppConfig was called with user role
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' });
});
it('should use user role when fetching domain restrictions', async () => {
const adminUser = { id: 'admin-user', role: 'admin' };
const regularUser = { id: 'regular-user', role: 'user' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
mockRegistryInstance.getServerConfig.mockResolvedValue({
url: 'https://some-domain.com/sse',
});
// Mock different responses based on role
mockGetAppConfig
.mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['admin-allowed.com'] } })
.mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['user-allowed.com'] } });
mockIsMCPDomainAllowed.mockResolvedValue(true);
const availableTools = {
'test-tool::test-server': {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
},
},
};
// Call with admin user
await createMCPTool({
res: mockRes,
user: adminUser,
toolKey: 'test-tool::test-server',
provider: 'openai',
userMCPAuthMap: {},
availableTools,
});
// Reset and call with regular user
mockRegistryInstance.getServerConfig.mockResolvedValue({
url: 'https://some-domain.com/sse',
});
await createMCPTool({
res: mockRes,
user: regularUser,
toolKey: 'test-tool::test-server',
provider: 'openai',
userMCPAuthMap: {},
availableTools,
});
// Verify getAppConfig was called with correct roles
expect(mockGetAppConfig).toHaveBeenNthCalledWith(1, { role: 'admin' });
expect(mockGetAppConfig).toHaveBeenNthCalledWith(2, { role: 'user' });
});
});
describe('User parameter integrity', () => {
it('should preserve user object properties through the call chain', async () => {
const complexUser = {