diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js index 132f6f4686..5e96726a46 100644 --- a/api/server/services/ActionService.js +++ b/api/server/services/ActionService.js @@ -8,6 +8,7 @@ const { logAxiosError, refreshAccessToken, GenerationJobManager, + createSSRFSafeAgents, } = require('@librechat/api'); const { Time, @@ -133,6 +134,7 @@ async function loadActionSets(searchParams) { * @param {import('zod').ZodTypeAny | undefined} [params.zodSchema] - The Zod schema for tool input validation/definition * @param {{ oauth_client_id?: string; oauth_client_secret?: string; }} params.encrypted - The encrypted values for the action. * @param {string | null} [params.streamId] - The stream ID for resumable streams. + * @param {boolean} [params.useSSRFProtection] - When true, uses SSRF-safe HTTP agents that validate resolved IPs at connect time. * @returns { Promise unknown}> } An object with `_call` method to execute the tool input. */ async function createActionTool({ @@ -145,7 +147,9 @@ async function createActionTool({ description, encrypted, streamId = null, + useSSRFProtection = false, }) { + const ssrfAgents = useSSRFProtection ? createSSRFSafeAgents() : undefined; /** @type {(toolInput: Object | string, config: GraphRunnableConfig) => Promise} */ const _call = async (toolInput, config) => { try { @@ -324,7 +328,7 @@ async function createActionTool({ } } - const response = await preparedExecutor.execute(); + const response = await preparedExecutor.execute(ssrfAgents); if (typeof response.data === 'object') { return JSON.stringify(response.data); diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index fe7a0f40c2..7f8c1d0460 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -338,6 +338,7 @@ async function processRequiredActions(client, requiredActions) { } // We've already decrypted the metadata, so we can pass it directly + const _allowedDomains = appConfig?.actions?.allowedDomains; tool = await createActionTool({ userId: client.req.user.id, res: client.res, @@ -345,6 +346,7 @@ async function processRequiredActions(client, requiredActions) { requestBuilder, // Note: intentionally not passing zodSchema, name, and description for assistants API encrypted, // Pass the encrypted values for OAuth flow + useSSRFProtection: !Array.isArray(_allowedDomains) || _allowedDomains.length === 0, }); if (!tool) { logger.warn( @@ -1064,6 +1066,7 @@ async function loadAgentTools({ const zodSchema = zodSchemas[functionName]; if (requestBuilder) { + const _allowedDomains = appConfig?.actions?.allowedDomains; const tool = await createActionTool({ userId: req.user.id, res, @@ -1074,6 +1077,7 @@ async function loadAgentTools({ name: toolName, description: functionSig.description, streamId, + useSSRFProtection: !Array.isArray(_allowedDomains) || _allowedDomains.length === 0, }); if (!tool) { @@ -1372,6 +1376,7 @@ async function loadActionToolsForExecution({ requestBuilder, name: toolName, description: functionSig?.description ?? '', + useSSRFProtection: !Array.isArray(allowedDomains) || allowedDomains.length === 0, }); if (!tool) { diff --git a/packages/api/src/auth/agent.spec.ts b/packages/api/src/auth/agent.spec.ts new file mode 100644 index 0000000000..9ab2a9aaf9 --- /dev/null +++ b/packages/api/src/auth/agent.spec.ts @@ -0,0 +1,113 @@ +jest.mock('node:dns', () => { + const actual = jest.requireActual('node:dns'); + return { + ...actual, + lookup: jest.fn(), + }; +}); + +import dns from 'node:dns'; +import { createSSRFSafeAgents, createSSRFSafeUndiciConnect } from './agent'; + +type LookupCallback = (err: NodeJS.ErrnoException | null, address: string, family: number) => void; + +const mockedDnsLookup = dns.lookup as jest.MockedFunction; + +function mockDnsResult(address: string, family: number): void { + mockedDnsLookup.mockImplementation((( + _hostname: string, + _options: unknown, + callback: LookupCallback, + ) => { + callback(null, address, family); + }) as never); +} + +function mockDnsError(err: NodeJS.ErrnoException): void { + mockedDnsLookup.mockImplementation((( + _hostname: string, + _options: unknown, + callback: LookupCallback, + ) => { + callback(err, '', 0); + }) as never); +} + +describe('createSSRFSafeAgents', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should return httpAgent and httpsAgent', () => { + const agents = createSSRFSafeAgents(); + expect(agents.httpAgent).toBeDefined(); + expect(agents.httpsAgent).toBeDefined(); + }); + + it('should patch httpAgent createConnection to inject SSRF lookup', () => { + const agents = createSSRFSafeAgents(); + const internal = agents.httpAgent as unknown as { + createConnection: (opts: Record) => unknown; + }; + expect(internal.createConnection).toBeInstanceOf(Function); + }); +}); + +describe('createSSRFSafeUndiciConnect', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should return an object with a lookup function', () => { + const connect = createSSRFSafeUndiciConnect(); + expect(connect).toHaveProperty('lookup'); + expect(connect.lookup).toBeInstanceOf(Function); + }); + + it('lookup should block private IPs', async () => { + mockDnsResult('10.0.0.1', 4); + const connect = createSSRFSafeUndiciConnect(); + + const result = await new Promise<{ err: NodeJS.ErrnoException | null }>((resolve) => { + connect.lookup('evil.example.com', {}, (err) => { + resolve({ err }); + }); + }); + + expect(result.err).toBeTruthy(); + expect(result.err!.code).toBe('ESSRF'); + }); + + it('lookup should allow public IPs', async () => { + mockDnsResult('93.184.216.34', 4); + const connect = createSSRFSafeUndiciConnect(); + + const result = await new Promise<{ err: NodeJS.ErrnoException | null; address: string }>( + (resolve) => { + connect.lookup('example.com', {}, (err, address) => { + resolve({ err, address: address as string }); + }); + }, + ); + + expect(result.err).toBeNull(); + expect(result.address).toBe('93.184.216.34'); + }); + + it('lookup should forward DNS errors', async () => { + const dnsError = Object.assign(new Error('ENOTFOUND'), { + code: 'ENOTFOUND', + }) as NodeJS.ErrnoException; + mockDnsError(dnsError); + const connect = createSSRFSafeUndiciConnect(); + + const result = await new Promise<{ err: NodeJS.ErrnoException | null }>((resolve) => { + connect.lookup('nonexistent.example.com', {}, (err) => { + resolve({ err }); + }); + }); + + expect(result.err).toBeTruthy(); + expect(result.err!.code).toBe('ENOTFOUND'); + }); +}); diff --git a/packages/api/src/auth/agent.ts b/packages/api/src/auth/agent.ts new file mode 100644 index 0000000000..2442aa20fa --- /dev/null +++ b/packages/api/src/auth/agent.ts @@ -0,0 +1,61 @@ +import dns from 'node:dns'; +import http from 'node:http'; +import https from 'node:https'; +import type { LookupFunction } from 'node:net'; +import { isPrivateIP } from './domain'; + +/** DNS lookup wrapper that blocks resolution to private/reserved IP addresses */ +const ssrfSafeLookup: LookupFunction = (hostname, options, callback) => { + dns.lookup(hostname, options, (err, address, family) => { + if (err) { + callback(err, '', 0); + return; + } + if (typeof address === 'string' && isPrivateIP(address)) { + const ssrfError = Object.assign( + new Error(`SSRF protection: ${hostname} resolved to blocked address ${address}`), + { code: 'ESSRF' }, + ) as NodeJS.ErrnoException; + callback(ssrfError, address, family as number); + return; + } + callback(null, address as string, family as number); + }); +}; + +/** Internal agent shape exposing createConnection (exists at runtime but not in TS types) */ +type AgentInternal = { + createConnection: (options: Record, oncreate?: unknown) => unknown; +}; + +/** Patches an agent instance to inject SSRF-safe DNS lookup at connect time */ +function withSSRFProtection(agent: T): T { + const internal = agent as unknown as AgentInternal; + const origCreate = internal.createConnection.bind(agent); + internal.createConnection = (options: Record, oncreate?: unknown) => { + options.lookup = ssrfSafeLookup; + return origCreate(options, oncreate); + }; + return agent; +} + +/** + * Creates HTTP and HTTPS agents that block TCP connections to private/reserved IP addresses. + * Provides TOCTOU-safe SSRF protection by validating the resolved IP at connect time, + * preventing DNS rebinding attacks where a hostname resolves to a public IP during + * pre-validation but to a private IP when the actual connection is made. + */ +export function createSSRFSafeAgents(): { httpAgent: http.Agent; httpsAgent: https.Agent } { + return { + httpAgent: withSSRFProtection(new http.Agent()), + httpsAgent: withSSRFProtection(new https.Agent()), + }; +} + +/** + * Returns undici-compatible `connect` options with SSRF-safe DNS lookup. + * Pass the result as the `connect` property when constructing an undici `Agent`. + */ +export function createSSRFSafeUndiciConnect(): { lookup: LookupFunction } { + return { lookup: ssrfSafeLookup }; +} diff --git a/packages/api/src/auth/domain.spec.ts b/packages/api/src/auth/domain.spec.ts index a2b4c42cd7..5f6187c9b4 100644 --- a/packages/api/src/auth/domain.spec.ts +++ b/packages/api/src/auth/domain.spec.ts @@ -1,12 +1,21 @@ /* eslint-disable @typescript-eslint/ban-ts-comment */ +jest.mock('node:dns/promises', () => ({ + lookup: jest.fn(), +})); + +import { lookup } from 'node:dns/promises'; import { extractMCPServerDomain, isActionDomainAllowed, isEmailDomainAllowed, isMCPDomainAllowed, + isPrivateIP, isSSRFTarget, + resolveHostnameSSRF, } from './domain'; +const mockedLookup = lookup as jest.MockedFunction; + describe('isEmailDomainAllowed', () => { afterEach(() => { jest.clearAllMocks(); @@ -192,7 +201,154 @@ describe('isSSRFTarget', () => { }); }); +describe('isPrivateIP', () => { + describe('IPv4 private ranges', () => { + it('should detect loopback addresses', () => { + expect(isPrivateIP('127.0.0.1')).toBe(true); + expect(isPrivateIP('127.255.255.255')).toBe(true); + }); + + it('should detect 10.x.x.x private range', () => { + expect(isPrivateIP('10.0.0.1')).toBe(true); + expect(isPrivateIP('10.255.255.255')).toBe(true); + }); + + it('should detect 172.16-31.x.x private range', () => { + expect(isPrivateIP('172.16.0.1')).toBe(true); + expect(isPrivateIP('172.31.255.255')).toBe(true); + expect(isPrivateIP('172.15.0.1')).toBe(false); + expect(isPrivateIP('172.32.0.1')).toBe(false); + }); + + it('should detect 192.168.x.x private range', () => { + expect(isPrivateIP('192.168.0.1')).toBe(true); + expect(isPrivateIP('192.168.255.255')).toBe(true); + }); + + it('should detect 169.254.x.x link-local range', () => { + expect(isPrivateIP('169.254.169.254')).toBe(true); + expect(isPrivateIP('169.254.0.1')).toBe(true); + }); + + it('should detect 0.0.0.0', () => { + expect(isPrivateIP('0.0.0.0')).toBe(true); + }); + + it('should allow public IPs', () => { + expect(isPrivateIP('8.8.8.8')).toBe(false); + expect(isPrivateIP('1.1.1.1')).toBe(false); + expect(isPrivateIP('93.184.216.34')).toBe(false); + }); + }); + + describe('IPv6 private ranges', () => { + it('should detect loopback', () => { + expect(isPrivateIP('::1')).toBe(true); + expect(isPrivateIP('::')).toBe(true); + expect(isPrivateIP('[::1]')).toBe(true); + }); + + it('should detect unique local (fc/fd) and link-local (fe80)', () => { + expect(isPrivateIP('fc00::1')).toBe(true); + expect(isPrivateIP('fd00::1')).toBe(true); + expect(isPrivateIP('fe80::1')).toBe(true); + }); + }); + + describe('IPv4-mapped IPv6 addresses', () => { + it('should detect private IPs in IPv4-mapped IPv6 form', () => { + expect(isPrivateIP('::ffff:169.254.169.254')).toBe(true); + expect(isPrivateIP('::ffff:127.0.0.1')).toBe(true); + expect(isPrivateIP('::ffff:10.0.0.1')).toBe(true); + expect(isPrivateIP('::ffff:192.168.1.1')).toBe(true); + }); + + it('should allow public IPs in IPv4-mapped IPv6 form', () => { + expect(isPrivateIP('::ffff:8.8.8.8')).toBe(false); + expect(isPrivateIP('::ffff:93.184.216.34')).toBe(false); + }); + }); +}); + +describe('resolveHostnameSSRF', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should detect domains that resolve to private IPs (nip.io bypass)', async () => { + mockedLookup.mockResolvedValueOnce([{ address: '169.254.169.254', family: 4 }] as never); + expect(await resolveHostnameSSRF('169.254.169.254.nip.io')).toBe(true); + }); + + it('should detect domains that resolve to loopback', async () => { + mockedLookup.mockResolvedValueOnce([{ address: '127.0.0.1', family: 4 }] as never); + expect(await resolveHostnameSSRF('loopback.example.com')).toBe(true); + }); + + it('should detect when any resolved address is private', async () => { + mockedLookup.mockResolvedValueOnce([ + { address: '93.184.216.34', family: 4 }, + { address: '10.0.0.1', family: 4 }, + ] as never); + expect(await resolveHostnameSSRF('dual.example.com')).toBe(true); + }); + + it('should allow domains that resolve to public IPs', async () => { + mockedLookup.mockResolvedValueOnce([{ address: '93.184.216.34', family: 4 }] as never); + expect(await resolveHostnameSSRF('example.com')).toBe(false); + }); + + it('should skip literal IPv4 addresses (handled by isSSRFTarget)', async () => { + expect(await resolveHostnameSSRF('169.254.169.254')).toBe(false); + expect(mockedLookup).not.toHaveBeenCalled(); + }); + + it('should skip literal IPv6 addresses', async () => { + expect(await resolveHostnameSSRF('::1')).toBe(false); + expect(mockedLookup).not.toHaveBeenCalled(); + }); + + it('should fail open on DNS resolution failure', async () => { + mockedLookup.mockRejectedValueOnce(new Error('ENOTFOUND')); + expect(await resolveHostnameSSRF('nonexistent.example.com')).toBe(false); + }); +}); + +describe('isActionDomainAllowed - DNS resolution SSRF protection', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should block domains resolving to cloud metadata IP (169.254.169.254)', async () => { + mockedLookup.mockResolvedValueOnce([{ address: '169.254.169.254', family: 4 }] as never); + expect(await isActionDomainAllowed('169.254.169.254.nip.io', null)).toBe(false); + }); + + it('should block domains resolving to private 10.x range', async () => { + mockedLookup.mockResolvedValueOnce([{ address: '10.0.0.5', family: 4 }] as never); + expect(await isActionDomainAllowed('internal.attacker.com', null)).toBe(false); + }); + + it('should block domains resolving to 172.16.x range', async () => { + mockedLookup.mockResolvedValueOnce([{ address: '172.16.0.1', family: 4 }] as never); + expect(await isActionDomainAllowed('docker.attacker.com', null)).toBe(false); + }); + + it('should allow domains resolving to public IPs when no allowlist', async () => { + mockedLookup.mockResolvedValueOnce([{ address: '93.184.216.34', family: 4 }] as never); + expect(await isActionDomainAllowed('example.com', null)).toBe(true); + }); + + it('should not perform DNS check when allowedDomains is configured', async () => { + expect(await isActionDomainAllowed('example.com', ['example.com'])).toBe(true); + expect(mockedLookup).not.toHaveBeenCalled(); + }); +}); + describe('isActionDomainAllowed', () => { + beforeEach(() => { + mockedLookup.mockResolvedValue([{ address: '93.184.216.34', family: 4 }] as never); + }); afterEach(() => { jest.clearAllMocks(); }); @@ -541,6 +697,9 @@ describe('extractMCPServerDomain', () => { }); describe('isMCPDomainAllowed', () => { + beforeEach(() => { + mockedLookup.mockResolvedValue([{ address: '93.184.216.34', family: 4 }] as never); + }); afterEach(() => { jest.clearAllMocks(); }); diff --git a/packages/api/src/auth/domain.ts b/packages/api/src/auth/domain.ts index 5d9fc51d02..f2e86875d4 100644 --- a/packages/api/src/auth/domain.ts +++ b/packages/api/src/auth/domain.ts @@ -1,3 +1,5 @@ +import { lookup } from 'node:dns/promises'; + /** * @param email * @param allowedDomains @@ -22,6 +24,88 @@ export function isEmailDomainAllowed(email: string, allowedDomains?: string[] | return allowedDomains.some((allowedDomain) => allowedDomain?.toLowerCase() === domain); } +/** Checks if IPv4 octets fall within private, reserved, or link-local ranges */ +function isPrivateIPv4(a: number, b: number, c: number): boolean { + if (a === 127) { + return true; + } + if (a === 10) { + return true; + } + if (a === 172 && b >= 16 && b <= 31) { + return true; + } + if (a === 192 && b === 168) { + return true; + } + if (a === 169 && b === 254) { + return true; + } + if (a === 0 && b === 0 && c === 0) { + return true; + } + return false; +} + +/** + * Checks if an IP address belongs to a private, reserved, or link-local range. + * Handles IPv4, IPv6, and IPv4-mapped IPv6 addresses (::ffff:A.B.C.D). + */ +export function isPrivateIP(ip: string): boolean { + const normalized = ip.toLowerCase().trim(); + + const mappedMatch = normalized.match(/^::ffff:(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/); + if (mappedMatch) { + const [, a, b, c] = mappedMatch.map(Number); + return isPrivateIPv4(a, b, c); + } + + const ipv4Match = normalized.match(/^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/); + if (ipv4Match) { + const [, a, b, c] = ipv4Match.map(Number); + return isPrivateIPv4(a, b, c); + } + + const ipv6 = normalized.replace(/^\[|\]$/g, ''); + if ( + ipv6 === '::1' || + ipv6 === '::' || + ipv6.startsWith('fc') || + ipv6.startsWith('fd') || + ipv6.startsWith('fe80') + ) { + return true; + } + + return false; +} + +/** + * Resolves a hostname via DNS and checks if any resolved address is a private/reserved IP. + * Detects DNS-based SSRF bypasses (e.g., nip.io wildcard DNS, attacker-controlled nameservers). + * Fails open: returns false if DNS resolution fails, since hostname-only checks still apply + * and the actual HTTP request would also fail. + */ +export async function resolveHostnameSSRF(hostname: string): Promise { + const normalizedHost = hostname.toLowerCase().trim(); + + if (/^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/.test(normalizedHost)) { + return false; + } + + const ipv6Check = normalizedHost.replace(/^\[|\]$/g, ''); + if (ipv6Check.includes(':')) { + return false; + } + + try { + const addresses = await lookup(hostname, { all: true }); + return addresses.some((entry) => isPrivateIP(entry.address)); + } catch { + return false; + } +} + /** * SSRF Protection: Checks if a hostname/IP is a potentially dangerous internal target. * Blocks private IPs, localhost, cloud metadata IPs, and common internal hostnames. @@ -31,7 +115,6 @@ export function isEmailDomainAllowed(email: string, allowedDomains?: string[] | export function isSSRFTarget(hostname: string): boolean { const normalizedHost = hostname.toLowerCase().trim(); - // Block localhost variations if ( normalizedHost === 'localhost' || normalizedHost === 'localhost.localdomain' || @@ -40,51 +123,7 @@ export function isSSRFTarget(hostname: string): boolean { return true; } - // Check if it's an IP address and block private/internal ranges - const ipv4Match = normalizedHost.match(/^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/); - if (ipv4Match) { - const [, a, b, c] = ipv4Match.map(Number); - - // 127.0.0.0/8 - Loopback - if (a === 127) { - return true; - } - - // 10.0.0.0/8 - Private - if (a === 10) { - return true; - } - - // 172.16.0.0/12 - Private (172.16.x.x - 172.31.x.x) - if (a === 172 && b >= 16 && b <= 31) { - return true; - } - - // 192.168.0.0/16 - Private - if (a === 192 && b === 168) { - return true; - } - - // 169.254.0.0/16 - Link-local (includes cloud metadata 169.254.169.254) - if (a === 169 && b === 254) { - return true; - } - - // 0.0.0.0 - Special - if (a === 0 && b === 0 && c === 0) { - return true; - } - } - - // IPv6 loopback and private ranges - const ipv6Normalized = normalizedHost.replace(/^\[|\]$/g, ''); // Remove brackets if present - if ( - ipv6Normalized === '::1' || - ipv6Normalized === '::' || - ipv6Normalized.startsWith('fc') || // fc00::/7 - Unique local - ipv6Normalized.startsWith('fd') || // fd00::/8 - Unique local - ipv6Normalized.startsWith('fe80') // fe80::/10 - Link-local - ) { + if (isPrivateIP(normalizedHost)) { return true; } @@ -257,6 +296,10 @@ async function isDomainAllowedCore( if (isSSRFTarget(inputSpec.hostname)) { return false; } + /** SECURITY: Resolve hostname and block if it points to a private/reserved IP */ + if (await resolveHostnameSSRF(inputSpec.hostname)) { + return false; + } return true; } diff --git a/packages/api/src/auth/index.ts b/packages/api/src/auth/index.ts index d15d94aad2..392605ef50 100644 --- a/packages/api/src/auth/index.ts +++ b/packages/api/src/auth/index.ts @@ -1,3 +1,4 @@ export * from './domain'; export * from './openid'; export * from './exchange'; +export * from './agent'; diff --git a/packages/api/src/mcp/ConnectionsRepository.ts b/packages/api/src/mcp/ConnectionsRepository.ts index e2c48c88ab..49d0799085 100644 --- a/packages/api/src/mcp/ConnectionsRepository.ts +++ b/packages/api/src/mcp/ConnectionsRepository.ts @@ -73,6 +73,7 @@ export class ConnectionsRepository { { serverName, serverConfig, + useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(), }, this.oauthOpts, ); diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index bcc63b7500..748cd0a967 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -29,6 +29,7 @@ export class MCPConnectionFactory { protected readonly serverConfig: t.MCPOptions; protected readonly logPrefix: string; protected readonly useOAuth: boolean; + protected readonly useSSRFProtection: boolean; // OAuth-related properties (only set when useOAuth is true) protected readonly userId?: string; @@ -72,6 +73,7 @@ export class MCPConnectionFactory { serverConfig: this.serverConfig, userId: this.userId, oauthTokens, + useSSRFProtection: this.useSSRFProtection, }); const oauthHandler = async () => { @@ -146,6 +148,7 @@ export class MCPConnectionFactory { serverConfig: this.serverConfig, userId: this.userId, oauthTokens: null, + useSSRFProtection: this.useSSRFProtection, }); unauthConnection.on('oauthRequired', () => { @@ -189,6 +192,7 @@ export class MCPConnectionFactory { }); this.serverName = basic.serverName; this.useOAuth = !!oauth?.useOAuth; + this.useSSRFProtection = basic.useSSRFProtection === true; this.connectionTimeout = oauth?.connectionTimeout; this.logPrefix = oauth?.user ? `[MCP][${basic.serverName}][${oauth.user.id}]` @@ -213,6 +217,7 @@ export class MCPConnectionFactory { serverConfig: this.serverConfig, userId: this.userId, oauthTokens, + useSSRFProtection: this.useSSRFProtection, }); let cleanupOAuthHandlers: (() => void) | null = null; diff --git a/packages/api/src/mcp/MCPManager.ts b/packages/api/src/mcp/MCPManager.ts index 211382c032..cab495774a 100644 --- a/packages/api/src/mcp/MCPManager.ts +++ b/packages/api/src/mcp/MCPManager.ts @@ -102,7 +102,8 @@ export class MCPManager extends UserConnectionManager { serverConfig.requiresOAuth || (serverConfig as t.ParsedServerConfig).oauthMetadata, ); - const basic: t.BasicConnectionOptions = { serverName, serverConfig }; + const useSSRFProtection = MCPServersRegistry.getInstance().shouldEnableSSRFProtection(); + const basic: t.BasicConnectionOptions = { serverName, serverConfig, useSSRFProtection }; if (!useOAuth) { const result = await MCPConnectionFactory.discoverTools(basic); diff --git a/packages/api/src/mcp/UserConnectionManager.ts b/packages/api/src/mcp/UserConnectionManager.ts index 25fc753d6b..e5d94689a0 100644 --- a/packages/api/src/mcp/UserConnectionManager.ts +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -117,6 +117,7 @@ export abstract class UserConnectionManager { { serverName: serverName, serverConfig: config, + useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(), }, { useOAuth: true, diff --git a/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts b/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts index e722b38375..4240ba12d6 100644 --- a/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts +++ b/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts @@ -24,6 +24,7 @@ jest.mock('../connection'); const mockRegistryInstance = { getServerConfig: jest.fn(), getAllServerConfigs: jest.fn(), + shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), }; jest.mock('../registry/MCPServersRegistry', () => ({ @@ -108,6 +109,7 @@ describe('ConnectionsRepository', () => { { serverName: 'server1', serverConfig: mockServerConfigs.server1, + useSSRFProtection: false, }, undefined, ); @@ -129,6 +131,7 @@ describe('ConnectionsRepository', () => { { serverName: 'server1', serverConfig: mockServerConfigs.server1, + useSSRFProtection: false, }, undefined, ); @@ -167,6 +170,7 @@ describe('ConnectionsRepository', () => { { serverName: 'server1', serverConfig: configWithCachedAt, + useSSRFProtection: false, }, undefined, ); diff --git a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts index 0986188e04..9f824bce23 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts @@ -84,6 +84,7 @@ describe('MCPConnectionFactory', () => { serverConfig: mockServerConfig, userId: undefined, oauthTokens: null, + useSSRFProtection: false, }); expect(mockConnectionInstance.connect).toHaveBeenCalled(); }); @@ -125,6 +126,7 @@ describe('MCPConnectionFactory', () => { serverConfig: mockServerConfig, userId: 'user123', oauthTokens: mockTokens, + useSSRFProtection: false, }); }); }); @@ -184,6 +186,7 @@ describe('MCPConnectionFactory', () => { serverConfig: mockServerConfig, userId: 'user123', oauthTokens: null, + useSSRFProtection: false, }); expect(mockLogger.debug).toHaveBeenCalledWith( expect.stringContaining('No existing tokens found or error loading tokens'), diff --git a/packages/api/src/mcp/__tests__/MCPManager.test.ts b/packages/api/src/mcp/__tests__/MCPManager.test.ts index caeb9176d3..bf63a6af3c 100644 --- a/packages/api/src/mcp/__tests__/MCPManager.test.ts +++ b/packages/api/src/mcp/__tests__/MCPManager.test.ts @@ -33,6 +33,7 @@ const mockRegistryInstance = { getServerConfig: jest.fn(), getAllServerConfigs: jest.fn(), getOAuthServers: jest.fn(), + shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), }; jest.mock('~/mcp/registry/MCPServersRegistry', () => ({ diff --git a/packages/api/src/mcp/connection.ts b/packages/api/src/mcp/connection.ts index b954a2e839..74891dbd15 100644 --- a/packages/api/src/mcp/connection.ts +++ b/packages/api/src/mcp/connection.ts @@ -20,6 +20,7 @@ import type { import type { MCPOAuthTokens } from './oauth/types'; import { withTimeout } from '~/utils/promise'; import type * as t from './types'; +import { createSSRFSafeUndiciConnect, resolveHostnameSSRF } from '~/auth'; import { sanitizeUrlForLogging } from './utils'; import { mcpConfig } from './mcpConfig'; @@ -213,6 +214,7 @@ interface MCPConnectionParams { serverConfig: t.MCPOptions; userId?: string; oauthTokens?: MCPOAuthTokens | null; + useSSRFProtection?: boolean; } export class MCPConnection extends EventEmitter { @@ -233,6 +235,7 @@ export class MCPConnection extends EventEmitter { private oauthTokens?: MCPOAuthTokens | null; private requestHeaders?: Record | null; private oauthRequired = false; + private readonly useSSRFProtection: boolean; iconPath?: string; timeout?: number; url?: string; @@ -263,6 +266,7 @@ export class MCPConnection extends EventEmitter { this.options = params.serverConfig; this.serverName = params.serverName; this.userId = params.userId; + this.useSSRFProtection = params.useSSRFProtection === true; this.iconPath = params.serverConfig.iconPath; this.timeout = params.serverConfig.timeout; this.lastPingTime = Date.now(); @@ -301,6 +305,7 @@ export class MCPConnection extends EventEmitter { getHeaders: () => Record | null | undefined, timeout?: number, ): (input: UndiciRequestInfo, init?: UndiciRequestInit) => Promise { + const ssrfConnect = this.useSSRFProtection ? createSSRFSafeUndiciConnect() : undefined; return function customFetch( input: UndiciRequestInfo, init?: UndiciRequestInit, @@ -310,6 +315,7 @@ export class MCPConnection extends EventEmitter { const agent = new Agent({ bodyTimeout: effectiveTimeout, headersTimeout: effectiveTimeout, + ...(ssrfConnect != null ? { connect: ssrfConnect } : {}), }); if (!requestHeaders) { return undiciFetch(input, { ...init, dispatcher: agent }); @@ -342,7 +348,7 @@ export class MCPConnection extends EventEmitter { logger.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`); } - private constructTransport(options: t.MCPOptions): Transport { + private async constructTransport(options: t.MCPOptions): Promise { try { let type: t.MCPOptions['type']; if (isStdioOptions(options)) { @@ -378,6 +384,15 @@ export class MCPConnection extends EventEmitter { throw new Error('Invalid options for websocket transport.'); } this.url = options.url; + if (this.useSSRFProtection) { + const wsHostname = new URL(options.url).hostname; + const isSSRF = await resolveHostnameSSRF(wsHostname); + if (isSSRF) { + throw new Error( + `SSRF protection: WebSocket host "${wsHostname}" resolved to a private/reserved IP address`, + ); + } + } return new WebSocketClientTransport(new URL(options.url)); case 'sse': { @@ -402,6 +417,7 @@ export class MCPConnection extends EventEmitter { * The connect timeout is extended because proxies may delay initial response. */ const sseTimeout = this.timeout || SSE_CONNECT_TIMEOUT; + const ssrfConnect = this.useSSRFProtection ? createSSRFSafeUndiciConnect() : undefined; const transport = new SSEClientTransport(url, { requestInit: { /** User/OAuth headers override SSE defaults */ @@ -420,6 +436,7 @@ export class MCPConnection extends EventEmitter { /** Extended keep-alive for long-lived SSE connections */ keepAliveTimeout: sseTimeout, keepAliveMaxTimeout: sseTimeout * 2, + ...(ssrfConnect != null ? { connect: ssrfConnect } : {}), }); return undiciFetch(url, { ...init, @@ -629,7 +646,7 @@ export class MCPConnection extends EventEmitter { } } - this.transport = this.constructTransport(this.options); + this.transport = await this.constructTransport(this.options); this.setupTransportDebugHandlers(); const connectTimeout = this.options.initTimeout ?? 120000; diff --git a/packages/api/src/mcp/registry/MCPServerInspector.ts b/packages/api/src/mcp/registry/MCPServerInspector.ts index 2263c10422..50da9cdc25 100644 --- a/packages/api/src/mcp/registry/MCPServerInspector.ts +++ b/packages/api/src/mcp/registry/MCPServerInspector.ts @@ -18,6 +18,7 @@ export class MCPServerInspector { private readonly serverName: string, private readonly config: t.ParsedServerConfig, private connection: MCPConnection | undefined, + private readonly useSSRFProtection: boolean = false, ) {} /** @@ -42,8 +43,9 @@ export class MCPServerInspector { throw new MCPDomainNotAllowedError(domain ?? 'unknown'); } + const useSSRFProtection = !Array.isArray(allowedDomains) || allowedDomains.length === 0; const start = Date.now(); - const inspector = new MCPServerInspector(serverName, rawConfig, connection); + const inspector = new MCPServerInspector(serverName, rawConfig, connection, useSSRFProtection); await inspector.inspectServer(); inspector.config.initDuration = Date.now() - start; return inspector.config; @@ -59,6 +61,7 @@ export class MCPServerInspector { this.connection = await MCPConnectionFactory.create({ serverName: this.serverName, serverConfig: this.config, + useSSRFProtection: this.useSSRFProtection, }); } diff --git a/packages/api/src/mcp/registry/MCPServersRegistry.ts b/packages/api/src/mcp/registry/MCPServersRegistry.ts index 801b3957a0..0264a8ed7a 100644 --- a/packages/api/src/mcp/registry/MCPServersRegistry.ts +++ b/packages/api/src/mcp/registry/MCPServersRegistry.ts @@ -77,6 +77,15 @@ export class MCPServersRegistry { return MCPServersRegistry.instance; } + public getAllowedDomains(): string[] | null | undefined { + return this.allowedDomains; + } + + /** Returns true when no explicit allowedDomains allowlist is configured, enabling SSRF TOCTOU protection */ + public shouldEnableSSRFProtection(): boolean { + return !Array.isArray(this.allowedDomains) || this.allowedDomains.length === 0; + } + public async getServerConfig( serverName: string, userId?: string, diff --git a/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts b/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts index 72bf57857e..42dc4d2005 100644 --- a/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts +++ b/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts @@ -276,6 +276,7 @@ describe('MCPServerInspector', () => { expect(MCPConnectionFactory.create).toHaveBeenCalledWith({ serverName: 'test_server', serverConfig: expect.objectContaining({ type: 'stdio', command: 'node' }), + useSSRFProtection: true, }); // Verify temporary connection was disconnected diff --git a/packages/api/src/mcp/types/index.ts b/packages/api/src/mcp/types/index.ts index 46447c6687..270131036b 100644 --- a/packages/api/src/mcp/types/index.ts +++ b/packages/api/src/mcp/types/index.ts @@ -166,6 +166,7 @@ export type AddServerResult = { export interface BasicConnectionOptions { serverName: string; serverConfig: MCPOptions; + useSSRFProtection?: boolean; } export interface OAuthConnectionOptions { diff --git a/packages/data-provider/specs/actions.spec.ts b/packages/data-provider/specs/actions.spec.ts index 08942d5505..59f068586d 100644 --- a/packages/data-provider/specs/actions.spec.ts +++ b/packages/data-provider/specs/actions.spec.ts @@ -459,6 +459,82 @@ describe('ActionRequest', () => { await expect(actionRequest.execute()).rejects.toThrow('Unsupported HTTP method: invalid'); }); + describe('SSRF-safe agent passthrough', () => { + beforeEach(() => { + mockedAxios.get.mockResolvedValue({ data: { success: true } }); + mockedAxios.post.mockResolvedValue({ data: { success: true } }); + }); + + it('should pass httpAgent and httpsAgent to axios.create when provided', async () => { + const mockHttpAgent = { keepAlive: true }; + const mockHttpsAgent = { keepAlive: true }; + + const actionRequest = new ActionRequest( + 'https://example.com', + '/test', + 'GET', + 'testOp', + false, + 'application/json', + ); + const executor = actionRequest.createExecutor(); + executor.setParams({ key: 'value' }); + await executor.execute({ httpAgent: mockHttpAgent, httpsAgent: mockHttpsAgent }); + + expect(mockedAxios.create).toHaveBeenCalledWith( + expect.objectContaining({ + httpAgent: mockHttpAgent, + httpsAgent: mockHttpsAgent, + maxRedirects: 0, + }), + ); + }); + + it('should not include agent keys when no options are provided', async () => { + const actionRequest = new ActionRequest( + 'https://example.com', + '/test', + 'GET', + 'testOp', + false, + 'application/json', + ); + const executor = actionRequest.createExecutor(); + executor.setParams({ key: 'value' }); + await executor.execute(); + + const createArg = mockedAxios.create.mock.calls[ + mockedAxios.create.mock.calls.length - 1 + ][0] as Record; + expect(createArg).not.toHaveProperty('httpAgent'); + expect(createArg).not.toHaveProperty('httpsAgent'); + }); + + it('should pass agents through for POST requests', async () => { + const mockAgent = { ssrf: true }; + + const actionRequest = new ActionRequest( + 'https://example.com', + '/test', + 'POST', + 'testOp', + false, + 'application/json', + ); + const executor = actionRequest.createExecutor(); + executor.setParams({ body: 'data' }); + await executor.execute({ httpAgent: mockAgent, httpsAgent: mockAgent }); + + expect(mockedAxios.create).toHaveBeenCalledWith( + expect.objectContaining({ + httpAgent: mockAgent, + httpsAgent: mockAgent, + }), + ); + expect(mockedAxios.post).toHaveBeenCalled(); + }); + }); + describe('ActionRequest Concurrent Execution', () => { beforeEach(() => { jest.clearAllMocks(); diff --git a/packages/data-provider/src/actions.ts b/packages/data-provider/src/actions.ts index c7566e479f..53c9e8ae1c 100644 --- a/packages/data-provider/src/actions.ts +++ b/packages/data-provider/src/actions.ts @@ -283,7 +283,7 @@ class RequestExecutor { return this; } - async execute() { + async execute(options?: { httpAgent?: unknown; httpsAgent?: unknown }) { const url = createURL(this.config.domain, this.path); const headers: Record = { ...this.authHeaders, @@ -300,10 +300,15 @@ class RequestExecutor { * * By setting maxRedirects: 0, we prevent this attack vector. * The action will receive the redirect response (3xx) instead of following it. + * + * SECURITY: When httpAgent/httpsAgent are provided (SSRF-safe agents), they validate + * the DNS-resolved IP at TCP connect time, preventing TOCTOU DNS rebinding attacks. */ const axios = _axios.create({ maxRedirects: 0, - validateStatus: (status) => status >= 200 && status < 400, // Accept 3xx but don't follow + validateStatus: (status) => status >= 200 && status < 400, + ...(options?.httpAgent != null ? { httpAgent: options.httpAgent } : {}), + ...(options?.httpsAgent != null ? { httpsAgent: options.httpsAgent } : {}), }); // Initialize separate containers for query and body parameters.