🛡️ fix: Implement TOCTOU-Safe SSRF Protection for Actions and MCP (#11722)

* refactor: better SSRF Protection in Action and Tool Services

- Added `createSSRFSafeAgents` function to create HTTP/HTTPS agents that block connections to private/reserved IP addresses, enhancing security against SSRF attacks.
- Updated `createActionTool` to accept a `useSSRFProtection` parameter, allowing the use of SSRF-safe agents during tool execution.
- Modified `processRequiredActions` and `loadAgentTools` to utilize the new SSRF protection feature based on allowed domains configuration.
- Introduced `resolveHostnameSSRF` function to validate resolved IPs against private ranges, preventing potential SSRF vulnerabilities.
- Enhanced tests for domain resolution and private IP detection to ensure robust SSRF protection mechanisms are in place.

* feat: Implement SSRF protection in MCP connections

- Added `createSSRFSafeUndiciConnect` function to provide SSRF-safe DNS lookup options for undici agents.
- Updated `MCPConnection`, `MCPConnectionFactory`, and `ConnectionsRepository` to include `useSSRFProtection` parameter, enabling SSRF protection based on server configuration.
- Enhanced `MCPManager` and `UserConnectionManager` to utilize SSRF protection when establishing connections.
- Updated tests to validate the integration of SSRF protection across various components, ensuring robust security measures are in place.

* refactor: WS MCPConnection with SSRF protection and async transport construction

- Added `resolveHostnameSSRF` to validate WebSocket hostnames against private IP addresses, enhancing SSRF protection.
- Updated `constructTransport` method to be asynchronous, ensuring proper handling of SSRF checks before establishing connections.
- Improved error handling for WebSocket transport to prevent connections to potentially unsafe addresses.

* test: Enhance ActionRequest tests for SSRF-safe agent passthrough

- Added tests to verify that httpAgent and httpsAgent are correctly passed to axios.create when provided in ActionRequest.
- Included scenarios to ensure agents are not included when no options are specified.
- Enhanced coverage for POST requests to confirm agent passthrough functionality.
- Improved overall test robustness for SSRF protection in ActionRequest execution.
This commit is contained in:
Danny Avila 2026-02-11 22:09:58 -05:00 committed by GitHub
parent d6b6f191f7
commit 924be3b647
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 567 additions and 53 deletions

View file

@ -73,6 +73,7 @@ export class ConnectionsRepository {
{
serverName,
serverConfig,
useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(),
},
this.oauthOpts,
);

View file

@ -29,6 +29,7 @@ export class MCPConnectionFactory {
protected readonly serverConfig: t.MCPOptions;
protected readonly logPrefix: string;
protected readonly useOAuth: boolean;
protected readonly useSSRFProtection: boolean;
// OAuth-related properties (only set when useOAuth is true)
protected readonly userId?: string;
@ -72,6 +73,7 @@ export class MCPConnectionFactory {
serverConfig: this.serverConfig,
userId: this.userId,
oauthTokens,
useSSRFProtection: this.useSSRFProtection,
});
const oauthHandler = async () => {
@ -146,6 +148,7 @@ export class MCPConnectionFactory {
serverConfig: this.serverConfig,
userId: this.userId,
oauthTokens: null,
useSSRFProtection: this.useSSRFProtection,
});
unauthConnection.on('oauthRequired', () => {
@ -189,6 +192,7 @@ export class MCPConnectionFactory {
});
this.serverName = basic.serverName;
this.useOAuth = !!oauth?.useOAuth;
this.useSSRFProtection = basic.useSSRFProtection === true;
this.connectionTimeout = oauth?.connectionTimeout;
this.logPrefix = oauth?.user
? `[MCP][${basic.serverName}][${oauth.user.id}]`
@ -213,6 +217,7 @@ export class MCPConnectionFactory {
serverConfig: this.serverConfig,
userId: this.userId,
oauthTokens,
useSSRFProtection: this.useSSRFProtection,
});
let cleanupOAuthHandlers: (() => void) | null = null;

View file

@ -102,7 +102,8 @@ export class MCPManager extends UserConnectionManager {
serverConfig.requiresOAuth || (serverConfig as t.ParsedServerConfig).oauthMetadata,
);
const basic: t.BasicConnectionOptions = { serverName, serverConfig };
const useSSRFProtection = MCPServersRegistry.getInstance().shouldEnableSSRFProtection();
const basic: t.BasicConnectionOptions = { serverName, serverConfig, useSSRFProtection };
if (!useOAuth) {
const result = await MCPConnectionFactory.discoverTools(basic);

View file

@ -117,6 +117,7 @@ export abstract class UserConnectionManager {
{
serverName: serverName,
serverConfig: config,
useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(),
},
{
useOAuth: true,

View file

@ -24,6 +24,7 @@ jest.mock('../connection');
const mockRegistryInstance = {
getServerConfig: jest.fn(),
getAllServerConfigs: jest.fn(),
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
};
jest.mock('../registry/MCPServersRegistry', () => ({
@ -108,6 +109,7 @@ describe('ConnectionsRepository', () => {
{
serverName: 'server1',
serverConfig: mockServerConfigs.server1,
useSSRFProtection: false,
},
undefined,
);
@ -129,6 +131,7 @@ describe('ConnectionsRepository', () => {
{
serverName: 'server1',
serverConfig: mockServerConfigs.server1,
useSSRFProtection: false,
},
undefined,
);
@ -167,6 +170,7 @@ describe('ConnectionsRepository', () => {
{
serverName: 'server1',
serverConfig: configWithCachedAt,
useSSRFProtection: false,
},
undefined,
);

View file

@ -84,6 +84,7 @@ describe('MCPConnectionFactory', () => {
serverConfig: mockServerConfig,
userId: undefined,
oauthTokens: null,
useSSRFProtection: false,
});
expect(mockConnectionInstance.connect).toHaveBeenCalled();
});
@ -125,6 +126,7 @@ describe('MCPConnectionFactory', () => {
serverConfig: mockServerConfig,
userId: 'user123',
oauthTokens: mockTokens,
useSSRFProtection: false,
});
});
});
@ -184,6 +186,7 @@ describe('MCPConnectionFactory', () => {
serverConfig: mockServerConfig,
userId: 'user123',
oauthTokens: null,
useSSRFProtection: false,
});
expect(mockLogger.debug).toHaveBeenCalledWith(
expect.stringContaining('No existing tokens found or error loading tokens'),

View file

@ -33,6 +33,7 @@ const mockRegistryInstance = {
getServerConfig: jest.fn(),
getAllServerConfigs: jest.fn(),
getOAuthServers: jest.fn(),
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
};
jest.mock('~/mcp/registry/MCPServersRegistry', () => ({

View file

@ -20,6 +20,7 @@ import type {
import type { MCPOAuthTokens } from './oauth/types';
import { withTimeout } from '~/utils/promise';
import type * as t from './types';
import { createSSRFSafeUndiciConnect, resolveHostnameSSRF } from '~/auth';
import { sanitizeUrlForLogging } from './utils';
import { mcpConfig } from './mcpConfig';
@ -213,6 +214,7 @@ interface MCPConnectionParams {
serverConfig: t.MCPOptions;
userId?: string;
oauthTokens?: MCPOAuthTokens | null;
useSSRFProtection?: boolean;
}
export class MCPConnection extends EventEmitter {
@ -233,6 +235,7 @@ export class MCPConnection extends EventEmitter {
private oauthTokens?: MCPOAuthTokens | null;
private requestHeaders?: Record<string, string> | null;
private oauthRequired = false;
private readonly useSSRFProtection: boolean;
iconPath?: string;
timeout?: number;
url?: string;
@ -263,6 +266,7 @@ export class MCPConnection extends EventEmitter {
this.options = params.serverConfig;
this.serverName = params.serverName;
this.userId = params.userId;
this.useSSRFProtection = params.useSSRFProtection === true;
this.iconPath = params.serverConfig.iconPath;
this.timeout = params.serverConfig.timeout;
this.lastPingTime = Date.now();
@ -301,6 +305,7 @@ export class MCPConnection extends EventEmitter {
getHeaders: () => Record<string, string> | null | undefined,
timeout?: number,
): (input: UndiciRequestInfo, init?: UndiciRequestInit) => Promise<UndiciResponse> {
const ssrfConnect = this.useSSRFProtection ? createSSRFSafeUndiciConnect() : undefined;
return function customFetch(
input: UndiciRequestInfo,
init?: UndiciRequestInit,
@ -310,6 +315,7 @@ export class MCPConnection extends EventEmitter {
const agent = new Agent({
bodyTimeout: effectiveTimeout,
headersTimeout: effectiveTimeout,
...(ssrfConnect != null ? { connect: ssrfConnect } : {}),
});
if (!requestHeaders) {
return undiciFetch(input, { ...init, dispatcher: agent });
@ -342,7 +348,7 @@ export class MCPConnection extends EventEmitter {
logger.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`);
}
private constructTransport(options: t.MCPOptions): Transport {
private async constructTransport(options: t.MCPOptions): Promise<Transport> {
try {
let type: t.MCPOptions['type'];
if (isStdioOptions(options)) {
@ -378,6 +384,15 @@ export class MCPConnection extends EventEmitter {
throw new Error('Invalid options for websocket transport.');
}
this.url = options.url;
if (this.useSSRFProtection) {
const wsHostname = new URL(options.url).hostname;
const isSSRF = await resolveHostnameSSRF(wsHostname);
if (isSSRF) {
throw new Error(
`SSRF protection: WebSocket host "${wsHostname}" resolved to a private/reserved IP address`,
);
}
}
return new WebSocketClientTransport(new URL(options.url));
case 'sse': {
@ -402,6 +417,7 @@ export class MCPConnection extends EventEmitter {
* The connect timeout is extended because proxies may delay initial response.
*/
const sseTimeout = this.timeout || SSE_CONNECT_TIMEOUT;
const ssrfConnect = this.useSSRFProtection ? createSSRFSafeUndiciConnect() : undefined;
const transport = new SSEClientTransport(url, {
requestInit: {
/** User/OAuth headers override SSE defaults */
@ -420,6 +436,7 @@ export class MCPConnection extends EventEmitter {
/** Extended keep-alive for long-lived SSE connections */
keepAliveTimeout: sseTimeout,
keepAliveMaxTimeout: sseTimeout * 2,
...(ssrfConnect != null ? { connect: ssrfConnect } : {}),
});
return undiciFetch(url, {
...init,
@ -629,7 +646,7 @@ export class MCPConnection extends EventEmitter {
}
}
this.transport = this.constructTransport(this.options);
this.transport = await this.constructTransport(this.options);
this.setupTransportDebugHandlers();
const connectTimeout = this.options.initTimeout ?? 120000;

View file

@ -18,6 +18,7 @@ export class MCPServerInspector {
private readonly serverName: string,
private readonly config: t.ParsedServerConfig,
private connection: MCPConnection | undefined,
private readonly useSSRFProtection: boolean = false,
) {}
/**
@ -42,8 +43,9 @@ export class MCPServerInspector {
throw new MCPDomainNotAllowedError(domain ?? 'unknown');
}
const useSSRFProtection = !Array.isArray(allowedDomains) || allowedDomains.length === 0;
const start = Date.now();
const inspector = new MCPServerInspector(serverName, rawConfig, connection);
const inspector = new MCPServerInspector(serverName, rawConfig, connection, useSSRFProtection);
await inspector.inspectServer();
inspector.config.initDuration = Date.now() - start;
return inspector.config;
@ -59,6 +61,7 @@ export class MCPServerInspector {
this.connection = await MCPConnectionFactory.create({
serverName: this.serverName,
serverConfig: this.config,
useSSRFProtection: this.useSSRFProtection,
});
}

View file

@ -77,6 +77,15 @@ export class MCPServersRegistry {
return MCPServersRegistry.instance;
}
public getAllowedDomains(): string[] | null | undefined {
return this.allowedDomains;
}
/** Returns true when no explicit allowedDomains allowlist is configured, enabling SSRF TOCTOU protection */
public shouldEnableSSRFProtection(): boolean {
return !Array.isArray(this.allowedDomains) || this.allowedDomains.length === 0;
}
public async getServerConfig(
serverName: string,
userId?: string,

View file

@ -276,6 +276,7 @@ describe('MCPServerInspector', () => {
expect(MCPConnectionFactory.create).toHaveBeenCalledWith({
serverName: 'test_server',
serverConfig: expect.objectContaining({ type: 'stdio', command: 'node' }),
useSSRFProtection: true,
});
// Verify temporary connection was disconnected

View file

@ -166,6 +166,7 @@ export type AddServerResult = {
export interface BasicConnectionOptions {
serverName: string;
serverConfig: MCPOptions;
useSSRFProtection?: boolean;
}
export interface OAuthConnectionOptions {