🔒 feat: Add MCP server domain restrictions for remote transports (#11013)

* 🔒 feat: Add MCP server domain restrictions for remote transports

* 🔒 feat: Implement comprehensive MCP error handling and domain validation

- Added `handleMCPError` function to centralize error responses for domain restrictions and inspection failures.
- Introduced custom error classes: `MCPDomainNotAllowedError` and `MCPInspectionFailedError` for better error management.
- Updated MCP server controllers to utilize the new error handling mechanism.
- Enhanced domain validation logic in `createMCPTools` and `createMCPTool` functions to prevent operations on disallowed domains.
- Added tests for runtime domain validation scenarios to ensure correct behavior.

* chore: import order

* 🔒 feat: Enhance domain validation in MCP tools with user role-based restrictions

- Integrated `getAppConfig` to fetch allowed domains based on user roles in `createMCPTools` and `createMCPTool` functions.
- Removed the deprecated `getAllowedDomains` method from `MCPServersRegistry`.
- Updated tests to verify domain restrictions are applied correctly based on user roles.
- Ensured that domain validation logic is consistent and efficient across tool creation processes.

* 🔒 test: Refactor MCP tests to utilize configurable app settings

- Introduced a mock for `getAppConfig` to enhance test flexibility.
- Removed redundant mock definition to streamline test setup.
- Ensured tests are aligned with the latest domain validation logic.

---------

Co-authored-by: Atef Bellaaj <slalom.bellaaj@external.daimlertruck.com>
Co-authored-by: Danny Avila <danny@librechat.ai>
This commit is contained in:
Atef Bellaaj 2025-12-18 19:57:49 +01:00 committed by GitHub
parent 98294755ee
commit 95a69df70e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 815 additions and 75 deletions

View file

@ -0,0 +1,61 @@
/**
* MCP-specific error classes
*/
export const MCPErrorCodes = {
DOMAIN_NOT_ALLOWED: 'MCP_DOMAIN_NOT_ALLOWED',
INSPECTION_FAILED: 'MCP_INSPECTION_FAILED',
} as const;
export type MCPErrorCode = (typeof MCPErrorCodes)[keyof typeof MCPErrorCodes];
/**
* Custom error for MCP domain restriction violations.
* Thrown when a user attempts to connect to an MCP server whose domain is not in the allowlist.
*/
export class MCPDomainNotAllowedError extends Error {
public readonly code = MCPErrorCodes.DOMAIN_NOT_ALLOWED;
public readonly statusCode = 403;
public readonly domain: string;
constructor(domain: string) {
super(`Domain "${domain}" is not allowed`);
this.name = 'MCPDomainNotAllowedError';
this.domain = domain;
Object.setPrototypeOf(this, MCPDomainNotAllowedError.prototype);
}
}
/**
* Custom error for MCP server inspection failures.
* Thrown when attempting to connect/inspect an MCP server fails.
*/
export class MCPInspectionFailedError extends Error {
public readonly code = MCPErrorCodes.INSPECTION_FAILED;
public readonly statusCode = 400;
public readonly serverName: string;
constructor(serverName: string, cause?: Error) {
super(`Failed to connect to MCP server "${serverName}"`);
this.name = 'MCPInspectionFailedError';
this.serverName = serverName;
if (cause) {
this.cause = cause;
}
Object.setPrototypeOf(this, MCPInspectionFailedError.prototype);
}
}
/**
* Type guard to check if an error is an MCPDomainNotAllowedError
*/
export function isMCPDomainNotAllowedError(error: unknown): error is MCPDomainNotAllowedError {
return error instanceof MCPDomainNotAllowedError;
}
/**
* Type guard to check if an error is an MCPInspectionFailedError
*/
export function isMCPInspectionFailedError(error: unknown): error is MCPInspectionFailedError {
return error instanceof MCPInspectionFailedError;
}

View file

@ -2,7 +2,9 @@ import { Constants } from 'librechat-data-provider';
import type { JsonSchemaType } from '@librechat/data-schemas';
import type { MCPConnection } from '~/mcp/connection';
import type * as t from '~/mcp/types';
import { isMCPDomainAllowed, extractMCPServerDomain } from '~/auth/domain';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { MCPDomainNotAllowedError } from '~/mcp/errors';
import { detectOAuthRequirement } from '~/mcp/oauth';
import { isEnabled } from '~/utils';
@ -24,13 +26,22 @@ export class MCPServerInspector {
* @param serverName - The name of the server (used for tool function naming)
* @param rawConfig - The raw server configuration
* @param connection - The MCP connection
* @param allowedDomains - Optional list of allowed domains for remote transports
* @returns A fully processed and enriched configuration with server metadata
*/
public static async inspect(
serverName: string,
rawConfig: t.MCPOptions,
connection?: MCPConnection,
allowedDomains?: string[] | null,
): Promise<t.ParsedServerConfig> {
// Validate domain against allowlist BEFORE attempting connection
const isDomainAllowed = await isMCPDomainAllowed(rawConfig, allowedDomains);
if (!isDomainAllowed) {
const domain = extractMCPServerDomain(rawConfig);
throw new MCPDomainNotAllowedError(domain ?? 'unknown');
}
const start = Date.now();
const inspector = new MCPServerInspector(serverName, rawConfig, connection);
await inspector.inspectServer();

View file

@ -1,6 +1,7 @@
import { logger } from '@librechat/data-schemas';
import type { IServerConfigsRepositoryInterface } from './ServerConfigsRepositoryInterface';
import type * as t from '~/mcp/types';
import { MCPInspectionFailedError, isMCPDomainNotAllowedError } from '~/mcp/errors';
import { ServerConfigsCacheFactory } from './cache/ServerConfigsCacheFactory';
import { MCPServerInspector } from './MCPServerInspector';
import { ServerConfigsDB } from './db/ServerConfigsDB';
@ -20,14 +21,19 @@ export class MCPServersRegistry {
private readonly dbConfigsRepo: IServerConfigsRepositoryInterface;
private readonly cacheConfigsRepo: IServerConfigsRepositoryInterface;
private readonly allowedDomains?: string[] | null;
constructor(mongoose: typeof import('mongoose')) {
constructor(mongoose: typeof import('mongoose'), allowedDomains?: string[] | null) {
this.dbConfigsRepo = new ServerConfigsDB(mongoose);
this.cacheConfigsRepo = ServerConfigsCacheFactory.create('App', false);
this.allowedDomains = allowedDomains;
}
/** Creates and initializes the singleton MCPServersRegistry instance */
public static createInstance(mongoose: typeof import('mongoose')): MCPServersRegistry {
public static createInstance(
mongoose: typeof import('mongoose'),
allowedDomains?: string[] | null,
): MCPServersRegistry {
if (!mongoose) {
throw new Error(
'MCPServersRegistry creation failed: mongoose instance is required for database operations. ' +
@ -39,7 +45,7 @@ export class MCPServersRegistry {
return MCPServersRegistry.instance;
}
logger.info('[MCPServersRegistry] Creating new instance');
MCPServersRegistry.instance = new MCPServersRegistry(mongoose);
MCPServersRegistry.instance = new MCPServersRegistry(mongoose, allowedDomains);
return MCPServersRegistry.instance;
}
@ -80,10 +86,19 @@ export class MCPServersRegistry {
const configRepo = this.getConfigRepository(storageLocation);
let parsedConfig: t.ParsedServerConfig;
try {
parsedConfig = await MCPServerInspector.inspect(serverName, config);
parsedConfig = await MCPServerInspector.inspect(
serverName,
config,
undefined,
this.allowedDomains,
);
} catch (error) {
logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error);
throw new Error(`MCP_INSPECTION_FAILED: Failed to connect to MCP server "${serverName}"`);
// Preserve domain-specific error for better error handling
if (isMCPDomainNotAllowedError(error)) {
throw error;
}
throw new MCPInspectionFailedError(serverName, error as Error);
}
return await configRepo.add(serverName, parsedConfig, userId);
}
@ -113,10 +128,19 @@ export class MCPServersRegistry {
let parsedConfig: t.ParsedServerConfig;
try {
parsedConfig = await MCPServerInspector.inspect(serverName, configForInspection);
parsedConfig = await MCPServerInspector.inspect(
serverName,
configForInspection,
undefined,
this.allowedDomains,
);
} catch (error) {
logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error);
throw new Error(`MCP_INSPECTION_FAILED: Failed to connect to MCP server "${serverName}"`);
// Preserve domain-specific error for better error handling
if (isMCPDomainNotAllowedError(error)) {
throw error;
}
throw new MCPInspectionFailedError(serverName, error as Error);
}
await configRepo.update(serverName, parsedConfig, userId);
return parsedConfig;

View file

@ -224,18 +224,38 @@ describe('MCPServersInitializer', () => {
it('should process all server configs through inspector', async () => {
await MCPServersInitializer.initialize(testConfigs);
// Verify all configs were processed by inspector (without connection parameter)
// Verify all configs were processed by inspector
// Signature: inspect(serverName, rawConfig, connection?, allowedDomains?)
expect(mockInspect).toHaveBeenCalledTimes(5);
expect(mockInspect).toHaveBeenCalledWith('disabled_server', testConfigs.disabled_server);
expect(mockInspect).toHaveBeenCalledWith('oauth_server', testConfigs.oauth_server);
expect(mockInspect).toHaveBeenCalledWith('file_tools_server', testConfigs.file_tools_server);
expect(mockInspect).toHaveBeenCalledWith(
'disabled_server',
testConfigs.disabled_server,
undefined,
undefined,
);
expect(mockInspect).toHaveBeenCalledWith(
'oauth_server',
testConfigs.oauth_server,
undefined,
undefined,
);
expect(mockInspect).toHaveBeenCalledWith(
'file_tools_server',
testConfigs.file_tools_server,
undefined,
undefined,
);
expect(mockInspect).toHaveBeenCalledWith(
'search_tools_server',
testConfigs.search_tools_server,
undefined,
undefined,
);
expect(mockInspect).toHaveBeenCalledWith(
'remote_no_oauth_server',
testConfigs.remote_no_oauth_server,
undefined,
undefined,
);
});