🛡️ fix: Implement TOCTOU-Safe SSRF Protection for Actions and MCP (#11722)

* refactor: better SSRF Protection in Action and Tool Services

- Added `createSSRFSafeAgents` function to create HTTP/HTTPS agents that block connections to private/reserved IP addresses, enhancing security against SSRF attacks.
- Updated `createActionTool` to accept a `useSSRFProtection` parameter, allowing the use of SSRF-safe agents during tool execution.
- Modified `processRequiredActions` and `loadAgentTools` to utilize the new SSRF protection feature based on allowed domains configuration.
- Introduced `resolveHostnameSSRF` function to validate resolved IPs against private ranges, preventing potential SSRF vulnerabilities.
- Enhanced tests for domain resolution and private IP detection to ensure robust SSRF protection mechanisms are in place.

* feat: Implement SSRF protection in MCP connections

- Added `createSSRFSafeUndiciConnect` function to provide SSRF-safe DNS lookup options for undici agents.
- Updated `MCPConnection`, `MCPConnectionFactory`, and `ConnectionsRepository` to include `useSSRFProtection` parameter, enabling SSRF protection based on server configuration.
- Enhanced `MCPManager` and `UserConnectionManager` to utilize SSRF protection when establishing connections.
- Updated tests to validate the integration of SSRF protection across various components, ensuring robust security measures are in place.

* refactor: WS MCPConnection with SSRF protection and async transport construction

- Added `resolveHostnameSSRF` to validate WebSocket hostnames against private IP addresses, enhancing SSRF protection.
- Updated `constructTransport` method to be asynchronous, ensuring proper handling of SSRF checks before establishing connections.
- Improved error handling for WebSocket transport to prevent connections to potentially unsafe addresses.

* test: Enhance ActionRequest tests for SSRF-safe agent passthrough

- Added tests to verify that httpAgent and httpsAgent are correctly passed to axios.create when provided in ActionRequest.
- Included scenarios to ensure agents are not included when no options are specified.
- Enhanced coverage for POST requests to confirm agent passthrough functionality.
- Improved overall test robustness for SSRF protection in ActionRequest execution.
This commit is contained in:
Danny Avila 2026-02-11 22:09:58 -05:00 committed by GitHub
parent d6b6f191f7
commit 924be3b647
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 567 additions and 53 deletions

View file

@ -0,0 +1,113 @@
jest.mock('node:dns', () => {
const actual = jest.requireActual('node:dns');
return {
...actual,
lookup: jest.fn(),
};
});
import dns from 'node:dns';
import { createSSRFSafeAgents, createSSRFSafeUndiciConnect } from './agent';
type LookupCallback = (err: NodeJS.ErrnoException | null, address: string, family: number) => void;
const mockedDnsLookup = dns.lookup as jest.MockedFunction<typeof dns.lookup>;
function mockDnsResult(address: string, family: number): void {
mockedDnsLookup.mockImplementation(((
_hostname: string,
_options: unknown,
callback: LookupCallback,
) => {
callback(null, address, family);
}) as never);
}
function mockDnsError(err: NodeJS.ErrnoException): void {
mockedDnsLookup.mockImplementation(((
_hostname: string,
_options: unknown,
callback: LookupCallback,
) => {
callback(err, '', 0);
}) as never);
}
describe('createSSRFSafeAgents', () => {
afterEach(() => {
jest.clearAllMocks();
});
it('should return httpAgent and httpsAgent', () => {
const agents = createSSRFSafeAgents();
expect(agents.httpAgent).toBeDefined();
expect(agents.httpsAgent).toBeDefined();
});
it('should patch httpAgent createConnection to inject SSRF lookup', () => {
const agents = createSSRFSafeAgents();
const internal = agents.httpAgent as unknown as {
createConnection: (opts: Record<string, unknown>) => unknown;
};
expect(internal.createConnection).toBeInstanceOf(Function);
});
});
describe('createSSRFSafeUndiciConnect', () => {
afterEach(() => {
jest.clearAllMocks();
});
it('should return an object with a lookup function', () => {
const connect = createSSRFSafeUndiciConnect();
expect(connect).toHaveProperty('lookup');
expect(connect.lookup).toBeInstanceOf(Function);
});
it('lookup should block private IPs', async () => {
mockDnsResult('10.0.0.1', 4);
const connect = createSSRFSafeUndiciConnect();
const result = await new Promise<{ err: NodeJS.ErrnoException | null }>((resolve) => {
connect.lookup('evil.example.com', {}, (err) => {
resolve({ err });
});
});
expect(result.err).toBeTruthy();
expect(result.err!.code).toBe('ESSRF');
});
it('lookup should allow public IPs', async () => {
mockDnsResult('93.184.216.34', 4);
const connect = createSSRFSafeUndiciConnect();
const result = await new Promise<{ err: NodeJS.ErrnoException | null; address: string }>(
(resolve) => {
connect.lookup('example.com', {}, (err, address) => {
resolve({ err, address: address as string });
});
},
);
expect(result.err).toBeNull();
expect(result.address).toBe('93.184.216.34');
});
it('lookup should forward DNS errors', async () => {
const dnsError = Object.assign(new Error('ENOTFOUND'), {
code: 'ENOTFOUND',
}) as NodeJS.ErrnoException;
mockDnsError(dnsError);
const connect = createSSRFSafeUndiciConnect();
const result = await new Promise<{ err: NodeJS.ErrnoException | null }>((resolve) => {
connect.lookup('nonexistent.example.com', {}, (err) => {
resolve({ err });
});
});
expect(result.err).toBeTruthy();
expect(result.err!.code).toBe('ENOTFOUND');
});
});

View file

@ -0,0 +1,61 @@
import dns from 'node:dns';
import http from 'node:http';
import https from 'node:https';
import type { LookupFunction } from 'node:net';
import { isPrivateIP } from './domain';
/** DNS lookup wrapper that blocks resolution to private/reserved IP addresses */
const ssrfSafeLookup: LookupFunction = (hostname, options, callback) => {
dns.lookup(hostname, options, (err, address, family) => {
if (err) {
callback(err, '', 0);
return;
}
if (typeof address === 'string' && isPrivateIP(address)) {
const ssrfError = Object.assign(
new Error(`SSRF protection: ${hostname} resolved to blocked address ${address}`),
{ code: 'ESSRF' },
) as NodeJS.ErrnoException;
callback(ssrfError, address, family as number);
return;
}
callback(null, address as string, family as number);
});
};
/** Internal agent shape exposing createConnection (exists at runtime but not in TS types) */
type AgentInternal = {
createConnection: (options: Record<string, unknown>, oncreate?: unknown) => unknown;
};
/** Patches an agent instance to inject SSRF-safe DNS lookup at connect time */
function withSSRFProtection<T extends http.Agent>(agent: T): T {
const internal = agent as unknown as AgentInternal;
const origCreate = internal.createConnection.bind(agent);
internal.createConnection = (options: Record<string, unknown>, oncreate?: unknown) => {
options.lookup = ssrfSafeLookup;
return origCreate(options, oncreate);
};
return agent;
}
/**
* Creates HTTP and HTTPS agents that block TCP connections to private/reserved IP addresses.
* Provides TOCTOU-safe SSRF protection by validating the resolved IP at connect time,
* preventing DNS rebinding attacks where a hostname resolves to a public IP during
* pre-validation but to a private IP when the actual connection is made.
*/
export function createSSRFSafeAgents(): { httpAgent: http.Agent; httpsAgent: https.Agent } {
return {
httpAgent: withSSRFProtection(new http.Agent()),
httpsAgent: withSSRFProtection(new https.Agent()),
};
}
/**
* Returns undici-compatible `connect` options with SSRF-safe DNS lookup.
* Pass the result as the `connect` property when constructing an undici `Agent`.
*/
export function createSSRFSafeUndiciConnect(): { lookup: LookupFunction } {
return { lookup: ssrfSafeLookup };
}

View file

@ -1,12 +1,21 @@
/* eslint-disable @typescript-eslint/ban-ts-comment */
jest.mock('node:dns/promises', () => ({
lookup: jest.fn(),
}));
import { lookup } from 'node:dns/promises';
import {
extractMCPServerDomain,
isActionDomainAllowed,
isEmailDomainAllowed,
isMCPDomainAllowed,
isPrivateIP,
isSSRFTarget,
resolveHostnameSSRF,
} from './domain';
const mockedLookup = lookup as jest.MockedFunction<typeof lookup>;
describe('isEmailDomainAllowed', () => {
afterEach(() => {
jest.clearAllMocks();
@ -192,7 +201,154 @@ describe('isSSRFTarget', () => {
});
});
describe('isPrivateIP', () => {
describe('IPv4 private ranges', () => {
it('should detect loopback addresses', () => {
expect(isPrivateIP('127.0.0.1')).toBe(true);
expect(isPrivateIP('127.255.255.255')).toBe(true);
});
it('should detect 10.x.x.x private range', () => {
expect(isPrivateIP('10.0.0.1')).toBe(true);
expect(isPrivateIP('10.255.255.255')).toBe(true);
});
it('should detect 172.16-31.x.x private range', () => {
expect(isPrivateIP('172.16.0.1')).toBe(true);
expect(isPrivateIP('172.31.255.255')).toBe(true);
expect(isPrivateIP('172.15.0.1')).toBe(false);
expect(isPrivateIP('172.32.0.1')).toBe(false);
});
it('should detect 192.168.x.x private range', () => {
expect(isPrivateIP('192.168.0.1')).toBe(true);
expect(isPrivateIP('192.168.255.255')).toBe(true);
});
it('should detect 169.254.x.x link-local range', () => {
expect(isPrivateIP('169.254.169.254')).toBe(true);
expect(isPrivateIP('169.254.0.1')).toBe(true);
});
it('should detect 0.0.0.0', () => {
expect(isPrivateIP('0.0.0.0')).toBe(true);
});
it('should allow public IPs', () => {
expect(isPrivateIP('8.8.8.8')).toBe(false);
expect(isPrivateIP('1.1.1.1')).toBe(false);
expect(isPrivateIP('93.184.216.34')).toBe(false);
});
});
describe('IPv6 private ranges', () => {
it('should detect loopback', () => {
expect(isPrivateIP('::1')).toBe(true);
expect(isPrivateIP('::')).toBe(true);
expect(isPrivateIP('[::1]')).toBe(true);
});
it('should detect unique local (fc/fd) and link-local (fe80)', () => {
expect(isPrivateIP('fc00::1')).toBe(true);
expect(isPrivateIP('fd00::1')).toBe(true);
expect(isPrivateIP('fe80::1')).toBe(true);
});
});
describe('IPv4-mapped IPv6 addresses', () => {
it('should detect private IPs in IPv4-mapped IPv6 form', () => {
expect(isPrivateIP('::ffff:169.254.169.254')).toBe(true);
expect(isPrivateIP('::ffff:127.0.0.1')).toBe(true);
expect(isPrivateIP('::ffff:10.0.0.1')).toBe(true);
expect(isPrivateIP('::ffff:192.168.1.1')).toBe(true);
});
it('should allow public IPs in IPv4-mapped IPv6 form', () => {
expect(isPrivateIP('::ffff:8.8.8.8')).toBe(false);
expect(isPrivateIP('::ffff:93.184.216.34')).toBe(false);
});
});
});
describe('resolveHostnameSSRF', () => {
afterEach(() => {
jest.clearAllMocks();
});
it('should detect domains that resolve to private IPs (nip.io bypass)', async () => {
mockedLookup.mockResolvedValueOnce([{ address: '169.254.169.254', family: 4 }] as never);
expect(await resolveHostnameSSRF('169.254.169.254.nip.io')).toBe(true);
});
it('should detect domains that resolve to loopback', async () => {
mockedLookup.mockResolvedValueOnce([{ address: '127.0.0.1', family: 4 }] as never);
expect(await resolveHostnameSSRF('loopback.example.com')).toBe(true);
});
it('should detect when any resolved address is private', async () => {
mockedLookup.mockResolvedValueOnce([
{ address: '93.184.216.34', family: 4 },
{ address: '10.0.0.1', family: 4 },
] as never);
expect(await resolveHostnameSSRF('dual.example.com')).toBe(true);
});
it('should allow domains that resolve to public IPs', async () => {
mockedLookup.mockResolvedValueOnce([{ address: '93.184.216.34', family: 4 }] as never);
expect(await resolveHostnameSSRF('example.com')).toBe(false);
});
it('should skip literal IPv4 addresses (handled by isSSRFTarget)', async () => {
expect(await resolveHostnameSSRF('169.254.169.254')).toBe(false);
expect(mockedLookup).not.toHaveBeenCalled();
});
it('should skip literal IPv6 addresses', async () => {
expect(await resolveHostnameSSRF('::1')).toBe(false);
expect(mockedLookup).not.toHaveBeenCalled();
});
it('should fail open on DNS resolution failure', async () => {
mockedLookup.mockRejectedValueOnce(new Error('ENOTFOUND'));
expect(await resolveHostnameSSRF('nonexistent.example.com')).toBe(false);
});
});
describe('isActionDomainAllowed - DNS resolution SSRF protection', () => {
afterEach(() => {
jest.clearAllMocks();
});
it('should block domains resolving to cloud metadata IP (169.254.169.254)', async () => {
mockedLookup.mockResolvedValueOnce([{ address: '169.254.169.254', family: 4 }] as never);
expect(await isActionDomainAllowed('169.254.169.254.nip.io', null)).toBe(false);
});
it('should block domains resolving to private 10.x range', async () => {
mockedLookup.mockResolvedValueOnce([{ address: '10.0.0.5', family: 4 }] as never);
expect(await isActionDomainAllowed('internal.attacker.com', null)).toBe(false);
});
it('should block domains resolving to 172.16.x range', async () => {
mockedLookup.mockResolvedValueOnce([{ address: '172.16.0.1', family: 4 }] as never);
expect(await isActionDomainAllowed('docker.attacker.com', null)).toBe(false);
});
it('should allow domains resolving to public IPs when no allowlist', async () => {
mockedLookup.mockResolvedValueOnce([{ address: '93.184.216.34', family: 4 }] as never);
expect(await isActionDomainAllowed('example.com', null)).toBe(true);
});
it('should not perform DNS check when allowedDomains is configured', async () => {
expect(await isActionDomainAllowed('example.com', ['example.com'])).toBe(true);
expect(mockedLookup).not.toHaveBeenCalled();
});
});
describe('isActionDomainAllowed', () => {
beforeEach(() => {
mockedLookup.mockResolvedValue([{ address: '93.184.216.34', family: 4 }] as never);
});
afterEach(() => {
jest.clearAllMocks();
});
@ -541,6 +697,9 @@ describe('extractMCPServerDomain', () => {
});
describe('isMCPDomainAllowed', () => {
beforeEach(() => {
mockedLookup.mockResolvedValue([{ address: '93.184.216.34', family: 4 }] as never);
});
afterEach(() => {
jest.clearAllMocks();
});

View file

@ -1,3 +1,5 @@
import { lookup } from 'node:dns/promises';
/**
* @param email
* @param allowedDomains
@ -22,6 +24,88 @@ export function isEmailDomainAllowed(email: string, allowedDomains?: string[] |
return allowedDomains.some((allowedDomain) => allowedDomain?.toLowerCase() === domain);
}
/** Checks if IPv4 octets fall within private, reserved, or link-local ranges */
function isPrivateIPv4(a: number, b: number, c: number): boolean {
if (a === 127) {
return true;
}
if (a === 10) {
return true;
}
if (a === 172 && b >= 16 && b <= 31) {
return true;
}
if (a === 192 && b === 168) {
return true;
}
if (a === 169 && b === 254) {
return true;
}
if (a === 0 && b === 0 && c === 0) {
return true;
}
return false;
}
/**
* Checks if an IP address belongs to a private, reserved, or link-local range.
* Handles IPv4, IPv6, and IPv4-mapped IPv6 addresses (::ffff:A.B.C.D).
*/
export function isPrivateIP(ip: string): boolean {
const normalized = ip.toLowerCase().trim();
const mappedMatch = normalized.match(/^::ffff:(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/);
if (mappedMatch) {
const [, a, b, c] = mappedMatch.map(Number);
return isPrivateIPv4(a, b, c);
}
const ipv4Match = normalized.match(/^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/);
if (ipv4Match) {
const [, a, b, c] = ipv4Match.map(Number);
return isPrivateIPv4(a, b, c);
}
const ipv6 = normalized.replace(/^\[|\]$/g, '');
if (
ipv6 === '::1' ||
ipv6 === '::' ||
ipv6.startsWith('fc') ||
ipv6.startsWith('fd') ||
ipv6.startsWith('fe80')
) {
return true;
}
return false;
}
/**
* Resolves a hostname via DNS and checks if any resolved address is a private/reserved IP.
* Detects DNS-based SSRF bypasses (e.g., nip.io wildcard DNS, attacker-controlled nameservers).
* Fails open: returns false if DNS resolution fails, since hostname-only checks still apply
* and the actual HTTP request would also fail.
*/
export async function resolveHostnameSSRF(hostname: string): Promise<boolean> {
const normalizedHost = hostname.toLowerCase().trim();
if (/^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/.test(normalizedHost)) {
return false;
}
const ipv6Check = normalizedHost.replace(/^\[|\]$/g, '');
if (ipv6Check.includes(':')) {
return false;
}
try {
const addresses = await lookup(hostname, { all: true });
return addresses.some((entry) => isPrivateIP(entry.address));
} catch {
return false;
}
}
/**
* SSRF Protection: Checks if a hostname/IP is a potentially dangerous internal target.
* Blocks private IPs, localhost, cloud metadata IPs, and common internal hostnames.
@ -31,7 +115,6 @@ export function isEmailDomainAllowed(email: string, allowedDomains?: string[] |
export function isSSRFTarget(hostname: string): boolean {
const normalizedHost = hostname.toLowerCase().trim();
// Block localhost variations
if (
normalizedHost === 'localhost' ||
normalizedHost === 'localhost.localdomain' ||
@ -40,51 +123,7 @@ export function isSSRFTarget(hostname: string): boolean {
return true;
}
// Check if it's an IP address and block private/internal ranges
const ipv4Match = normalizedHost.match(/^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/);
if (ipv4Match) {
const [, a, b, c] = ipv4Match.map(Number);
// 127.0.0.0/8 - Loopback
if (a === 127) {
return true;
}
// 10.0.0.0/8 - Private
if (a === 10) {
return true;
}
// 172.16.0.0/12 - Private (172.16.x.x - 172.31.x.x)
if (a === 172 && b >= 16 && b <= 31) {
return true;
}
// 192.168.0.0/16 - Private
if (a === 192 && b === 168) {
return true;
}
// 169.254.0.0/16 - Link-local (includes cloud metadata 169.254.169.254)
if (a === 169 && b === 254) {
return true;
}
// 0.0.0.0 - Special
if (a === 0 && b === 0 && c === 0) {
return true;
}
}
// IPv6 loopback and private ranges
const ipv6Normalized = normalizedHost.replace(/^\[|\]$/g, ''); // Remove brackets if present
if (
ipv6Normalized === '::1' ||
ipv6Normalized === '::' ||
ipv6Normalized.startsWith('fc') || // fc00::/7 - Unique local
ipv6Normalized.startsWith('fd') || // fd00::/8 - Unique local
ipv6Normalized.startsWith('fe80') // fe80::/10 - Link-local
) {
if (isPrivateIP(normalizedHost)) {
return true;
}
@ -257,6 +296,10 @@ async function isDomainAllowedCore(
if (isSSRFTarget(inputSpec.hostname)) {
return false;
}
/** SECURITY: Resolve hostname and block if it points to a private/reserved IP */
if (await resolveHostnameSSRF(inputSpec.hostname)) {
return false;
}
return true;
}

View file

@ -1,3 +1,4 @@
export * from './domain';
export * from './openid';
export * from './exchange';
export * from './agent';

View file

@ -73,6 +73,7 @@ export class ConnectionsRepository {
{
serverName,
serverConfig,
useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(),
},
this.oauthOpts,
);

View file

@ -29,6 +29,7 @@ export class MCPConnectionFactory {
protected readonly serverConfig: t.MCPOptions;
protected readonly logPrefix: string;
protected readonly useOAuth: boolean;
protected readonly useSSRFProtection: boolean;
// OAuth-related properties (only set when useOAuth is true)
protected readonly userId?: string;
@ -72,6 +73,7 @@ export class MCPConnectionFactory {
serverConfig: this.serverConfig,
userId: this.userId,
oauthTokens,
useSSRFProtection: this.useSSRFProtection,
});
const oauthHandler = async () => {
@ -146,6 +148,7 @@ export class MCPConnectionFactory {
serverConfig: this.serverConfig,
userId: this.userId,
oauthTokens: null,
useSSRFProtection: this.useSSRFProtection,
});
unauthConnection.on('oauthRequired', () => {
@ -189,6 +192,7 @@ export class MCPConnectionFactory {
});
this.serverName = basic.serverName;
this.useOAuth = !!oauth?.useOAuth;
this.useSSRFProtection = basic.useSSRFProtection === true;
this.connectionTimeout = oauth?.connectionTimeout;
this.logPrefix = oauth?.user
? `[MCP][${basic.serverName}][${oauth.user.id}]`
@ -213,6 +217,7 @@ export class MCPConnectionFactory {
serverConfig: this.serverConfig,
userId: this.userId,
oauthTokens,
useSSRFProtection: this.useSSRFProtection,
});
let cleanupOAuthHandlers: (() => void) | null = null;

View file

@ -102,7 +102,8 @@ export class MCPManager extends UserConnectionManager {
serverConfig.requiresOAuth || (serverConfig as t.ParsedServerConfig).oauthMetadata,
);
const basic: t.BasicConnectionOptions = { serverName, serverConfig };
const useSSRFProtection = MCPServersRegistry.getInstance().shouldEnableSSRFProtection();
const basic: t.BasicConnectionOptions = { serverName, serverConfig, useSSRFProtection };
if (!useOAuth) {
const result = await MCPConnectionFactory.discoverTools(basic);

View file

@ -117,6 +117,7 @@ export abstract class UserConnectionManager {
{
serverName: serverName,
serverConfig: config,
useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(),
},
{
useOAuth: true,

View file

@ -24,6 +24,7 @@ jest.mock('../connection');
const mockRegistryInstance = {
getServerConfig: jest.fn(),
getAllServerConfigs: jest.fn(),
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
};
jest.mock('../registry/MCPServersRegistry', () => ({
@ -108,6 +109,7 @@ describe('ConnectionsRepository', () => {
{
serverName: 'server1',
serverConfig: mockServerConfigs.server1,
useSSRFProtection: false,
},
undefined,
);
@ -129,6 +131,7 @@ describe('ConnectionsRepository', () => {
{
serverName: 'server1',
serverConfig: mockServerConfigs.server1,
useSSRFProtection: false,
},
undefined,
);
@ -167,6 +170,7 @@ describe('ConnectionsRepository', () => {
{
serverName: 'server1',
serverConfig: configWithCachedAt,
useSSRFProtection: false,
},
undefined,
);

View file

@ -84,6 +84,7 @@ describe('MCPConnectionFactory', () => {
serverConfig: mockServerConfig,
userId: undefined,
oauthTokens: null,
useSSRFProtection: false,
});
expect(mockConnectionInstance.connect).toHaveBeenCalled();
});
@ -125,6 +126,7 @@ describe('MCPConnectionFactory', () => {
serverConfig: mockServerConfig,
userId: 'user123',
oauthTokens: mockTokens,
useSSRFProtection: false,
});
});
});
@ -184,6 +186,7 @@ describe('MCPConnectionFactory', () => {
serverConfig: mockServerConfig,
userId: 'user123',
oauthTokens: null,
useSSRFProtection: false,
});
expect(mockLogger.debug).toHaveBeenCalledWith(
expect.stringContaining('No existing tokens found or error loading tokens'),

View file

@ -33,6 +33,7 @@ const mockRegistryInstance = {
getServerConfig: jest.fn(),
getAllServerConfigs: jest.fn(),
getOAuthServers: jest.fn(),
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
};
jest.mock('~/mcp/registry/MCPServersRegistry', () => ({

View file

@ -20,6 +20,7 @@ import type {
import type { MCPOAuthTokens } from './oauth/types';
import { withTimeout } from '~/utils/promise';
import type * as t from './types';
import { createSSRFSafeUndiciConnect, resolveHostnameSSRF } from '~/auth';
import { sanitizeUrlForLogging } from './utils';
import { mcpConfig } from './mcpConfig';
@ -213,6 +214,7 @@ interface MCPConnectionParams {
serverConfig: t.MCPOptions;
userId?: string;
oauthTokens?: MCPOAuthTokens | null;
useSSRFProtection?: boolean;
}
export class MCPConnection extends EventEmitter {
@ -233,6 +235,7 @@ export class MCPConnection extends EventEmitter {
private oauthTokens?: MCPOAuthTokens | null;
private requestHeaders?: Record<string, string> | null;
private oauthRequired = false;
private readonly useSSRFProtection: boolean;
iconPath?: string;
timeout?: number;
url?: string;
@ -263,6 +266,7 @@ export class MCPConnection extends EventEmitter {
this.options = params.serverConfig;
this.serverName = params.serverName;
this.userId = params.userId;
this.useSSRFProtection = params.useSSRFProtection === true;
this.iconPath = params.serverConfig.iconPath;
this.timeout = params.serverConfig.timeout;
this.lastPingTime = Date.now();
@ -301,6 +305,7 @@ export class MCPConnection extends EventEmitter {
getHeaders: () => Record<string, string> | null | undefined,
timeout?: number,
): (input: UndiciRequestInfo, init?: UndiciRequestInit) => Promise<UndiciResponse> {
const ssrfConnect = this.useSSRFProtection ? createSSRFSafeUndiciConnect() : undefined;
return function customFetch(
input: UndiciRequestInfo,
init?: UndiciRequestInit,
@ -310,6 +315,7 @@ export class MCPConnection extends EventEmitter {
const agent = new Agent({
bodyTimeout: effectiveTimeout,
headersTimeout: effectiveTimeout,
...(ssrfConnect != null ? { connect: ssrfConnect } : {}),
});
if (!requestHeaders) {
return undiciFetch(input, { ...init, dispatcher: agent });
@ -342,7 +348,7 @@ export class MCPConnection extends EventEmitter {
logger.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`);
}
private constructTransport(options: t.MCPOptions): Transport {
private async constructTransport(options: t.MCPOptions): Promise<Transport> {
try {
let type: t.MCPOptions['type'];
if (isStdioOptions(options)) {
@ -378,6 +384,15 @@ export class MCPConnection extends EventEmitter {
throw new Error('Invalid options for websocket transport.');
}
this.url = options.url;
if (this.useSSRFProtection) {
const wsHostname = new URL(options.url).hostname;
const isSSRF = await resolveHostnameSSRF(wsHostname);
if (isSSRF) {
throw new Error(
`SSRF protection: WebSocket host "${wsHostname}" resolved to a private/reserved IP address`,
);
}
}
return new WebSocketClientTransport(new URL(options.url));
case 'sse': {
@ -402,6 +417,7 @@ export class MCPConnection extends EventEmitter {
* The connect timeout is extended because proxies may delay initial response.
*/
const sseTimeout = this.timeout || SSE_CONNECT_TIMEOUT;
const ssrfConnect = this.useSSRFProtection ? createSSRFSafeUndiciConnect() : undefined;
const transport = new SSEClientTransport(url, {
requestInit: {
/** User/OAuth headers override SSE defaults */
@ -420,6 +436,7 @@ export class MCPConnection extends EventEmitter {
/** Extended keep-alive for long-lived SSE connections */
keepAliveTimeout: sseTimeout,
keepAliveMaxTimeout: sseTimeout * 2,
...(ssrfConnect != null ? { connect: ssrfConnect } : {}),
});
return undiciFetch(url, {
...init,
@ -629,7 +646,7 @@ export class MCPConnection extends EventEmitter {
}
}
this.transport = this.constructTransport(this.options);
this.transport = await this.constructTransport(this.options);
this.setupTransportDebugHandlers();
const connectTimeout = this.options.initTimeout ?? 120000;

View file

@ -18,6 +18,7 @@ export class MCPServerInspector {
private readonly serverName: string,
private readonly config: t.ParsedServerConfig,
private connection: MCPConnection | undefined,
private readonly useSSRFProtection: boolean = false,
) {}
/**
@ -42,8 +43,9 @@ export class MCPServerInspector {
throw new MCPDomainNotAllowedError(domain ?? 'unknown');
}
const useSSRFProtection = !Array.isArray(allowedDomains) || allowedDomains.length === 0;
const start = Date.now();
const inspector = new MCPServerInspector(serverName, rawConfig, connection);
const inspector = new MCPServerInspector(serverName, rawConfig, connection, useSSRFProtection);
await inspector.inspectServer();
inspector.config.initDuration = Date.now() - start;
return inspector.config;
@ -59,6 +61,7 @@ export class MCPServerInspector {
this.connection = await MCPConnectionFactory.create({
serverName: this.serverName,
serverConfig: this.config,
useSSRFProtection: this.useSSRFProtection,
});
}

View file

@ -77,6 +77,15 @@ export class MCPServersRegistry {
return MCPServersRegistry.instance;
}
public getAllowedDomains(): string[] | null | undefined {
return this.allowedDomains;
}
/** Returns true when no explicit allowedDomains allowlist is configured, enabling SSRF TOCTOU protection */
public shouldEnableSSRFProtection(): boolean {
return !Array.isArray(this.allowedDomains) || this.allowedDomains.length === 0;
}
public async getServerConfig(
serverName: string,
userId?: string,

View file

@ -276,6 +276,7 @@ describe('MCPServerInspector', () => {
expect(MCPConnectionFactory.create).toHaveBeenCalledWith({
serverName: 'test_server',
serverConfig: expect.objectContaining({ type: 'stdio', command: 'node' }),
useSSRFProtection: true,
});
// Verify temporary connection was disconnected

View file

@ -166,6 +166,7 @@ export type AddServerResult = {
export interface BasicConnectionOptions {
serverName: string;
serverConfig: MCPOptions;
useSSRFProtection?: boolean;
}
export interface OAuthConnectionOptions {