🔒 fix: SSRF Protection and Domain Handling in MCP Server Config (#11234)

* 🔒 fix: Enhance SSRF Protection and Domain Handling in MCP Server Configuration

- Updated the `extractMCPServerDomain` function to return the full origin (protocol://hostname:port) for improved protocol/port matching against allowed domains.
- Enhanced tests for `isMCPDomainAllowed` to validate domain access for internal hostnames and .local TLDs, ensuring proper SSRF protection.
- Added detailed comments in the configuration file to clarify security measures regarding allowed domains and internal target access.

* refactor: Domain Validation for WebSocket Protocols in Action and MCP Handling

- Added comprehensive tests to validate handling of WebSocket URLs in `isActionDomainAllowed` and `isMCPDomainAllowed` functions, ensuring that WebSocket protocols are rejected for OpenAPI Actions while allowed for MCP.
- Updated domain validation logic to support HTTP, HTTPS, WS, and WSS protocols, enhancing security and compliance with specifications.
- Refactored `parseDomainSpec` to improve protocol recognition and validation, ensuring robust handling of domain specifications.
- Introduced detailed comments to clarify the purpose and security implications of domain validation functions.
This commit is contained in:
Danny Avila 2026-01-06 13:04:52 -05:00 committed by GitHub
parent a7645f4705
commit 3b41e392ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 212 additions and 49 deletions

View file

@ -177,22 +177,29 @@ registration:
# userMax: 50
# userWindowInMinutes: 60 # Rate limit window for conversation imports per user
# Example Actions Object Structure
# Agent Actions domain restrictions (OpenAPI spec validation)
# SECURITY: If not configured, SSRF targets are blocked (localhost, private IPs, .internal/.local TLDs).
# To allow internal targets, you MUST explicitly add them to allowedDomains.
# Supports wildcards: '*.example.com' and protocol/port restrictions: 'https://api.example.com:8443'
actions:
allowedDomains:
- 'swapi.dev'
- 'librechat.ai'
- 'google.com'
# - 'http://10.225.26.25:7894' # Internal IP with protocol/port (uncomment if needed)
# MCP Server domain restrictions for remote transports (SSE, WebSocket, HTTP)
# Stdio transports (local processes) are not restricted.
# If not configured, all domains are allowed (permissive default).
# SECURITY: If not configured, SSRF targets are blocked (localhost, private IPs, .internal/.local TLDs).
# To allow internal targets like host.docker.internal, you MUST explicitly add them to allowedDomains.
# Supports wildcards: '*.example.com' matches 'api.example.com', 'staging.example.com', etc.
# Supports protocol/port restrictions: 'https://api.example.com:8443' restricts to specific protocol/port.
# mcpSettings:
# allowedDomains:
# - 'localhost'
# - '*.example.com'
# - 'trusted-mcp-provider.com'
# - 'host.docker.internal' # Docker host access (required for Docker setups)
# - 'localhost' # Local development
# - '*.example.com' # Wildcard subdomain
# - 'https://secure.api.com' # Protocol-restricted
# - 'http://internal:8080' # Protocol and port restricted
# Example MCP Servers Object Structure
# mcpServers:

View file

@ -341,6 +341,32 @@ describe('isActionDomainAllowed', () => {
// Protocol and Port Restrictions (Recommendation #2)
describe('protocol and port restrictions', () => {
describe('OpenAPI Actions reject WebSocket protocols', () => {
it('should reject ws:// URLs (not part of OpenAPI spec)', async () => {
expect(await isActionDomainAllowed('ws://example.com', ['example.com'])).toBe(false);
expect(await isActionDomainAllowed('ws://example.com', null)).toBe(false);
});
it('should reject wss:// URLs (not part of OpenAPI spec)', async () => {
expect(await isActionDomainAllowed('wss://example.com', ['example.com'])).toBe(false);
expect(await isActionDomainAllowed('wss://example.com', null)).toBe(false);
});
it('should reject WebSocket URLs even if explicitly in allowedDomains', async () => {
expect(await isActionDomainAllowed('wss://ws.example.com', ['wss://ws.example.com'])).toBe(
false,
);
expect(await isActionDomainAllowed('ws://ws.example.com', ['ws://ws.example.com'])).toBe(
false,
);
});
it('should allow only HTTP/HTTPS for OpenAPI Actions', async () => {
expect(await isActionDomainAllowed('http://example.com', ['example.com'])).toBe(true);
expect(await isActionDomainAllowed('https://example.com', ['example.com'])).toBe(true);
});
});
describe('protocol-only restrictions', () => {
const httpsOnlyDomains = ['https://api.example.com', 'https://secure.test.com'];
@ -437,35 +463,40 @@ describe('extractMCPServerDomain', () => {
jest.clearAllMocks();
});
describe('URL extraction', () => {
it('should extract domain from HTTPS URL', () => {
describe('URL extraction (returns full origin for protocol/port matching)', () => {
it('should extract full origin from HTTPS URL', () => {
const config = { url: 'https://api.example.com/sse' };
expect(extractMCPServerDomain(config)).toBe('api.example.com');
expect(extractMCPServerDomain(config)).toBe('https://api.example.com');
});
it('should extract domain from HTTP URL', () => {
it('should extract full origin from HTTP URL', () => {
const config = { url: 'http://api.example.com/sse' };
expect(extractMCPServerDomain(config)).toBe('api.example.com');
expect(extractMCPServerDomain(config)).toBe('http://api.example.com');
});
it('should extract domain from WebSocket URL', () => {
it('should extract full origin from WebSocket URL', () => {
const config = { url: 'wss://ws.example.com' };
expect(extractMCPServerDomain(config)).toBe('ws.example.com');
expect(extractMCPServerDomain(config)).toBe('wss://ws.example.com');
});
it('should handle URL with port', () => {
it('should include port in origin when specified', () => {
const config = { url: 'https://localhost:3001/sse' };
expect(extractMCPServerDomain(config)).toBe('localhost');
expect(extractMCPServerDomain(config)).toBe('https://localhost:3001');
});
it('should strip www prefix', () => {
it('should include port for non-default ports', () => {
const config = { url: 'http://host.docker.internal:8044/mcp' };
expect(extractMCPServerDomain(config)).toBe('http://host.docker.internal:8044');
});
it('should preserve www prefix in origin (matching handles www normalization)', () => {
const config = { url: 'https://www.example.com/api' };
expect(extractMCPServerDomain(config)).toBe('example.com');
expect(extractMCPServerDomain(config)).toBe('https://www.example.com');
});
it('should handle URL with path and query parameters', () => {
it('should strip path and query parameters', () => {
const config = { url: 'https://api.example.com/v1/sse?token=abc' };
expect(extractMCPServerDomain(config)).toBe('api.example.com');
expect(extractMCPServerDomain(config)).toBe('https://api.example.com');
});
});
@ -637,4 +668,92 @@ describe('isMCPDomainAllowed', () => {
expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true);
});
});
describe('Docker/internal hostname handling (SSRF protection)', () => {
it('should block host.docker.internal without allowedDomains (ends with .internal)', async () => {
const config = { url: 'http://host.docker.internal:8044/mcp' };
expect(await isMCPDomainAllowed(config, null)).toBe(false);
expect(await isMCPDomainAllowed(config, undefined)).toBe(false);
expect(await isMCPDomainAllowed(config, [])).toBe(false);
});
it('should allow host.docker.internal when explicitly in allowedDomains', async () => {
const config = { url: 'http://host.docker.internal:8044/mcp' };
expect(await isMCPDomainAllowed(config, ['host.docker.internal'])).toBe(true);
});
it('should allow host.docker.internal with protocol/port restriction', async () => {
const config = { url: 'http://host.docker.internal:8044/mcp' };
expect(await isMCPDomainAllowed(config, ['http://host.docker.internal:8044'])).toBe(true);
});
it('should reject host.docker.internal with wrong protocol restriction', async () => {
const config = { url: 'http://host.docker.internal:8044/mcp' };
expect(await isMCPDomainAllowed(config, ['https://host.docker.internal:8044'])).toBe(false);
});
it('should reject host.docker.internal with wrong port restriction', async () => {
const config = { url: 'http://host.docker.internal:8044/mcp' };
expect(await isMCPDomainAllowed(config, ['http://host.docker.internal:9000'])).toBe(false);
});
it('should block .local TLD without allowedDomains', async () => {
const config = { url: 'http://myserver.local/mcp' };
expect(await isMCPDomainAllowed(config, null)).toBe(false);
});
it('should allow .local TLD when explicitly in allowedDomains', async () => {
const config = { url: 'http://myserver.local/mcp' };
expect(await isMCPDomainAllowed(config, ['myserver.local'])).toBe(true);
});
});
describe('protocol/port matching with full origin extraction', () => {
it('should match unrestricted allowedDomain against full origin', async () => {
// When allowedDomain has no protocol/port, it should match any protocol/port
const config = { url: 'https://api.example.com:8443/sse' };
expect(await isMCPDomainAllowed(config, ['api.example.com'])).toBe(true);
});
it('should enforce protocol restriction from allowedDomain', async () => {
const config = { url: 'http://api.example.com/sse' };
expect(await isMCPDomainAllowed(config, ['https://api.example.com'])).toBe(false);
expect(await isMCPDomainAllowed(config, ['http://api.example.com'])).toBe(true);
});
it('should enforce port restriction from allowedDomain', async () => {
const config = { url: 'https://api.example.com:8443/sse' };
expect(await isMCPDomainAllowed(config, ['https://api.example.com:8443'])).toBe(true);
expect(await isMCPDomainAllowed(config, ['https://api.example.com:443'])).toBe(false);
});
});
describe('WebSocket URL handling (MCP supports ws/wss)', () => {
it('should allow WebSocket URL when hostname is in allowedDomains', async () => {
const config = { url: 'wss://ws.example.com/mcp' };
expect(await isMCPDomainAllowed(config, ['ws.example.com'])).toBe(true);
});
it('should allow WebSocket URL with protocol restriction', async () => {
const config = { url: 'wss://ws.example.com/mcp' };
expect(await isMCPDomainAllowed(config, ['wss://ws.example.com'])).toBe(true);
});
it('should reject WebSocket URL with wrong protocol restriction', async () => {
const config = { url: 'wss://ws.example.com/mcp' };
expect(await isMCPDomainAllowed(config, ['ws://ws.example.com'])).toBe(false);
});
it('should allow ws:// URL when hostname is in allowedDomains', async () => {
const config = { url: 'ws://localhost:8080/mcp' };
expect(await isMCPDomainAllowed(config, ['localhost'])).toBe(true);
});
it('should allow all MCP protocols (http, https, ws, wss)', async () => {
expect(await isMCPDomainAllowed({ url: 'http://example.com' }, ['example.com'])).toBe(true);
expect(await isMCPDomainAllowed({ url: 'https://example.com' }, ['example.com'])).toBe(true);
expect(await isMCPDomainAllowed({ url: 'ws://example.com' }, ['example.com'])).toBe(true);
expect(await isMCPDomainAllowed({ url: 'wss://example.com' }, ['example.com'])).toBe(true);
});
});
});

View file

@ -129,23 +129,37 @@ export function isSSRFTarget(hostname: string): boolean {
return false;
}
/** Supported protocols for domain validation (HTTP, HTTPS, WebSocket) */
type SupportedProtocol = 'http:' | 'https:' | 'ws:' | 'wss:';
/**
* Parsed domain specification including protocol and port constraints.
*/
interface ParsedDomainSpec {
hostname: string;
protocol?: 'http:' | 'https:' | null; // null means any protocol
port?: string | null; // null means any port
protocol: SupportedProtocol | null; // null means any protocol allowed
port: string | null; // null means any port allowed
explicitPort: boolean; // true if port was explicitly specified in original string
isWildcard: boolean;
}
/** Checks if a string starts with a recognized protocol prefix */
function hasRecognizedProtocol(domain: string): boolean {
return (
domain.startsWith('http://') ||
domain.startsWith('https://') ||
domain.startsWith('ws://') ||
domain.startsWith('wss://')
);
}
/**
* Parses a domain specification into its components.
* Supports formats:
* - `example.com` (any protocol, any port)
* - `https://example.com` (https only, any port)
* - `https://example.com:443` (https only, port 443)
* - `wss://ws.example.com` (secure WebSocket only)
* - `*.example.com` (wildcard subdomain)
* @param domain - Domain specification string
* @returns ParsedDomainSpec or null if invalid
@ -154,17 +168,17 @@ function parseDomainSpec(domain: string): ParsedDomainSpec | null {
try {
let normalizedDomain = domain.toLowerCase().trim();
// Early return for obviously invalid formats
if (normalizedDomain === 'http://' || normalizedDomain === 'https://') {
// Early return for obviously invalid formats (protocol-only strings)
const emptyProtocols = ['http://', 'https://', 'ws://', 'wss://'];
if (emptyProtocols.includes(normalizedDomain)) {
return null;
}
// Check for wildcard prefix before parsing
const isWildcard = normalizedDomain.startsWith('*.');
// Check if it has a protocol
const hasProtocol =
normalizedDomain.startsWith('http://') || normalizedDomain.startsWith('https://');
// Check if it has a recognized protocol (http, https, ws, wss)
const hasProtocol = hasRecognizedProtocol(normalizedDomain);
// Check if port was explicitly specified (e.g., :443, :8080)
// Need to check before URL parsing because URL normalizes default ports
@ -180,7 +194,7 @@ function parseDomainSpec(domain: string): ParsedDomainSpec | null {
const url = new URL(normalizedDomain);
// Additional validation that hostname isn't just protocol
if (!url.hostname || url.hostname === 'http:' || url.hostname === 'https:') {
if (!url.hostname || emptyProtocols.some((p) => url.hostname === p.replace('://', ''))) {
return null;
}
@ -188,7 +202,7 @@ function parseDomainSpec(domain: string): ParsedDomainSpec | null {
return {
hostname,
protocol: hasProtocol ? (url.protocol as 'http:' | 'https:') : null,
protocol: hasProtocol ? (url.protocol as SupportedProtocol) : null,
// Use the explicitly specified port, or null if no port was specified
port: explicitPort ? explicitPortValue : null,
explicitPort,
@ -211,30 +225,29 @@ function hostnameMatches(inputHostname: string, allowedSpec: ParsedDomainSpec):
return inputHostname === allowedSpec.hostname;
}
/** Protocol sets for different use cases */
const HTTP_PROTOCOLS: SupportedProtocol[] = ['http:', 'https:'];
const MCP_PROTOCOLS: SupportedProtocol[] = ['http:', 'https:', 'ws:', 'wss:'];
/**
* Checks if the given domain is allowed.
* SECURITY: When no allowedDomains is configured, blocks SSRF-prone targets
* (private IPs, localhost, metadata services) to prevent attacks.
* When allowedDomains IS configured, admins can explicitly allow internal targets if needed.
*
* Supports protocol and port restrictions in allowedDomains:
* - `example.com` - allows any protocol/port
* - `https://example.com` - allows only HTTPS on default port
* - `https://example.com:8443` - allows only HTTPS on port 8443
*
* Core domain validation logic with configurable protocol support.
* SECURITY: When no allowedDomains is configured, blocks SSRF-prone targets.
* @param domain - The domain to check (can include protocol/port)
* @param allowedDomains - List of allowed domain patterns
* @param supportedProtocols - Protocols to accept (others are rejected)
*/
export async function isActionDomainAllowed(
domain?: string | null,
allowedDomains?: string[] | null,
async function isDomainAllowedCore(
domain: string,
allowedDomains: string[] | null | undefined,
supportedProtocols: SupportedProtocol[],
): Promise<boolean> {
if (!domain || typeof domain !== 'string') {
const inputSpec = parseDomainSpec(domain);
if (!inputSpec) {
return false;
}
const inputSpec = parseDomainSpec(domain);
if (!inputSpec) {
// SECURITY: Reject unsupported protocols (e.g., WebSocket for OpenAPI Actions)
if (inputSpec.protocol !== null && !supportedProtocols.includes(inputSpec.protocol)) {
return false;
}
@ -254,6 +267,11 @@ export async function isActionDomainAllowed(
continue;
}
// Skip allowedDomains with unsupported protocols for this context
if (allowedSpec.protocol !== null && !supportedProtocols.includes(allowedSpec.protocol)) {
continue;
}
// Check hostname match (with wildcard support)
if (!hostnameMatches(inputSpec.hostname, allowedSpec)) {
continue;
@ -283,7 +301,24 @@ export async function isActionDomainAllowed(
}
/**
* Extracts domain from MCP server config URL.
* Validates domain for OpenAPI Agent Actions (HTTP/HTTPS only).
* SECURITY: WebSocket protocols are NOT allowed per OpenAPI specification.
* @param domain - The domain to check (can include protocol/port)
* @param allowedDomains - List of allowed domain patterns
*/
export async function isActionDomainAllowed(
domain?: string | null,
allowedDomains?: string[] | null,
): Promise<boolean> {
if (!domain || typeof domain !== 'string') {
return false;
}
return isDomainAllowedCore(domain, allowedDomains, HTTP_PROTOCOLS);
}
/**
* Extracts full domain spec (protocol://hostname:port) from MCP server config URL.
* Returns the full origin for proper protocol/port matching against allowedDomains.
* Returns null for stdio transports (no URL) or invalid URLs.
* @param config - MCP server configuration (accepts any config with optional url field)
*/
@ -296,7 +331,9 @@ export function extractMCPServerDomain(config: Record<string, unknown>): string
try {
const parsedUrl = new URL(url);
return parsedUrl.hostname.replace(/^www\./i, '');
// Return full origin (protocol://hostname:port) for proper domain validation
// This allows admins to restrict by protocol/port in allowedDomains
return parsedUrl.origin;
} catch {
return null;
}
@ -304,7 +341,7 @@ export function extractMCPServerDomain(config: Record<string, unknown>): string
/**
* Validates MCP server domain against allowedDomains.
* Reuses isActionDomainAllowed for consistent validation logic.
* Supports HTTP, HTTPS, WS, and WSS protocols (per MCP specification).
* Stdio transports (no URL) are always allowed.
* @param config - MCP server configuration with optional url field
* @param allowedDomains - List of allowed domains (with wildcard support)
@ -320,6 +357,6 @@ export async function isMCPDomainAllowed(
return true;
}
// Reuse existing validation logic (includes wildcard support)
return isActionDomainAllowed(domain, allowedDomains);
// Use MCP_PROTOCOLS (HTTP/HTTPS/WS/WSS) for MCP server validation
return isDomainAllowedCore(domain, allowedDomains, MCP_PROTOCOLS);
}