🔒 fix: SSRF Protection and Domain Handling in MCP Server Config (#11234)

* 🔒 fix: Enhance SSRF Protection and Domain Handling in MCP Server Configuration

- Updated the `extractMCPServerDomain` function to return the full origin (protocol://hostname:port) for improved protocol/port matching against allowed domains.
- Enhanced tests for `isMCPDomainAllowed` to validate domain access for internal hostnames and .local TLDs, ensuring proper SSRF protection.
- Added detailed comments in the configuration file to clarify security measures regarding allowed domains and internal target access.

* refactor: Domain Validation for WebSocket Protocols in Action and MCP Handling

- Added comprehensive tests to validate handling of WebSocket URLs in `isActionDomainAllowed` and `isMCPDomainAllowed` functions, ensuring that WebSocket protocols are rejected for OpenAPI Actions while allowed for MCP.
- Updated domain validation logic to support HTTP, HTTPS, WS, and WSS protocols, enhancing security and compliance with specifications.
- Refactored `parseDomainSpec` to improve protocol recognition and validation, ensuring robust handling of domain specifications.
- Introduced detailed comments to clarify the purpose and security implications of domain validation functions.
This commit is contained in:
Danny Avila 2026-01-06 13:04:52 -05:00 committed by GitHub
parent a7645f4705
commit 3b41e392ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 212 additions and 49 deletions

View file

@ -129,23 +129,37 @@ export function isSSRFTarget(hostname: string): boolean {
return false;
}
/** Supported protocols for domain validation (HTTP, HTTPS, WebSocket) */
type SupportedProtocol = 'http:' | 'https:' | 'ws:' | 'wss:';
/**
* Parsed domain specification including protocol and port constraints.
*/
interface ParsedDomainSpec {
hostname: string;
protocol?: 'http:' | 'https:' | null; // null means any protocol
port?: string | null; // null means any port
protocol: SupportedProtocol | null; // null means any protocol allowed
port: string | null; // null means any port allowed
explicitPort: boolean; // true if port was explicitly specified in original string
isWildcard: boolean;
}
/** Checks if a string starts with a recognized protocol prefix */
function hasRecognizedProtocol(domain: string): boolean {
return (
domain.startsWith('http://') ||
domain.startsWith('https://') ||
domain.startsWith('ws://') ||
domain.startsWith('wss://')
);
}
/**
* Parses a domain specification into its components.
* Supports formats:
* - `example.com` (any protocol, any port)
* - `https://example.com` (https only, any port)
* - `https://example.com:443` (https only, port 443)
* - `wss://ws.example.com` (secure WebSocket only)
* - `*.example.com` (wildcard subdomain)
* @param domain - Domain specification string
* @returns ParsedDomainSpec or null if invalid
@ -154,17 +168,17 @@ function parseDomainSpec(domain: string): ParsedDomainSpec | null {
try {
let normalizedDomain = domain.toLowerCase().trim();
// Early return for obviously invalid formats
if (normalizedDomain === 'http://' || normalizedDomain === 'https://') {
// Early return for obviously invalid formats (protocol-only strings)
const emptyProtocols = ['http://', 'https://', 'ws://', 'wss://'];
if (emptyProtocols.includes(normalizedDomain)) {
return null;
}
// Check for wildcard prefix before parsing
const isWildcard = normalizedDomain.startsWith('*.');
// Check if it has a protocol
const hasProtocol =
normalizedDomain.startsWith('http://') || normalizedDomain.startsWith('https://');
// Check if it has a recognized protocol (http, https, ws, wss)
const hasProtocol = hasRecognizedProtocol(normalizedDomain);
// Check if port was explicitly specified (e.g., :443, :8080)
// Need to check before URL parsing because URL normalizes default ports
@ -180,7 +194,7 @@ function parseDomainSpec(domain: string): ParsedDomainSpec | null {
const url = new URL(normalizedDomain);
// Additional validation that hostname isn't just protocol
if (!url.hostname || url.hostname === 'http:' || url.hostname === 'https:') {
if (!url.hostname || emptyProtocols.some((p) => url.hostname === p.replace('://', ''))) {
return null;
}
@ -188,7 +202,7 @@ function parseDomainSpec(domain: string): ParsedDomainSpec | null {
return {
hostname,
protocol: hasProtocol ? (url.protocol as 'http:' | 'https:') : null,
protocol: hasProtocol ? (url.protocol as SupportedProtocol) : null,
// Use the explicitly specified port, or null if no port was specified
port: explicitPort ? explicitPortValue : null,
explicitPort,
@ -211,30 +225,29 @@ function hostnameMatches(inputHostname: string, allowedSpec: ParsedDomainSpec):
return inputHostname === allowedSpec.hostname;
}
/** Protocol sets for different use cases */
const HTTP_PROTOCOLS: SupportedProtocol[] = ['http:', 'https:'];
const MCP_PROTOCOLS: SupportedProtocol[] = ['http:', 'https:', 'ws:', 'wss:'];
/**
* Checks if the given domain is allowed.
* SECURITY: When no allowedDomains is configured, blocks SSRF-prone targets
* (private IPs, localhost, metadata services) to prevent attacks.
* When allowedDomains IS configured, admins can explicitly allow internal targets if needed.
*
* Supports protocol and port restrictions in allowedDomains:
* - `example.com` - allows any protocol/port
* - `https://example.com` - allows only HTTPS on default port
* - `https://example.com:8443` - allows only HTTPS on port 8443
*
* Core domain validation logic with configurable protocol support.
* SECURITY: When no allowedDomains is configured, blocks SSRF-prone targets.
* @param domain - The domain to check (can include protocol/port)
* @param allowedDomains - List of allowed domain patterns
* @param supportedProtocols - Protocols to accept (others are rejected)
*/
export async function isActionDomainAllowed(
domain?: string | null,
allowedDomains?: string[] | null,
async function isDomainAllowedCore(
domain: string,
allowedDomains: string[] | null | undefined,
supportedProtocols: SupportedProtocol[],
): Promise<boolean> {
if (!domain || typeof domain !== 'string') {
const inputSpec = parseDomainSpec(domain);
if (!inputSpec) {
return false;
}
const inputSpec = parseDomainSpec(domain);
if (!inputSpec) {
// SECURITY: Reject unsupported protocols (e.g., WebSocket for OpenAPI Actions)
if (inputSpec.protocol !== null && !supportedProtocols.includes(inputSpec.protocol)) {
return false;
}
@ -254,6 +267,11 @@ export async function isActionDomainAllowed(
continue;
}
// Skip allowedDomains with unsupported protocols for this context
if (allowedSpec.protocol !== null && !supportedProtocols.includes(allowedSpec.protocol)) {
continue;
}
// Check hostname match (with wildcard support)
if (!hostnameMatches(inputSpec.hostname, allowedSpec)) {
continue;
@ -283,7 +301,24 @@ export async function isActionDomainAllowed(
}
/**
* Extracts domain from MCP server config URL.
* Validates domain for OpenAPI Agent Actions (HTTP/HTTPS only).
* SECURITY: WebSocket protocols are NOT allowed per OpenAPI specification.
* @param domain - The domain to check (can include protocol/port)
* @param allowedDomains - List of allowed domain patterns
*/
export async function isActionDomainAllowed(
domain?: string | null,
allowedDomains?: string[] | null,
): Promise<boolean> {
if (!domain || typeof domain !== 'string') {
return false;
}
return isDomainAllowedCore(domain, allowedDomains, HTTP_PROTOCOLS);
}
/**
* Extracts full domain spec (protocol://hostname:port) from MCP server config URL.
* Returns the full origin for proper protocol/port matching against allowedDomains.
* Returns null for stdio transports (no URL) or invalid URLs.
* @param config - MCP server configuration (accepts any config with optional url field)
*/
@ -296,7 +331,9 @@ export function extractMCPServerDomain(config: Record<string, unknown>): string
try {
const parsedUrl = new URL(url);
return parsedUrl.hostname.replace(/^www\./i, '');
// Return full origin (protocol://hostname:port) for proper domain validation
// This allows admins to restrict by protocol/port in allowedDomains
return parsedUrl.origin;
} catch {
return null;
}
@ -304,7 +341,7 @@ export function extractMCPServerDomain(config: Record<string, unknown>): string
/**
* Validates MCP server domain against allowedDomains.
* Reuses isActionDomainAllowed for consistent validation logic.
* Supports HTTP, HTTPS, WS, and WSS protocols (per MCP specification).
* Stdio transports (no URL) are always allowed.
* @param config - MCP server configuration with optional url field
* @param allowedDomains - List of allowed domains (with wildcard support)
@ -320,6 +357,6 @@ export async function isMCPDomainAllowed(
return true;
}
// Reuse existing validation logic (includes wildcard support)
return isActionDomainAllowed(domain, allowedDomains);
// Use MCP_PROTOCOLS (HTTP/HTTPS/WS/WSS) for MCP server validation
return isDomainAllowedCore(domain, allowedDomains, MCP_PROTOCOLS);
}