mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-02-20 17:34:10 +01:00
🛡️ fix: Implement TOCTOU-Safe SSRF Protection for Actions and MCP (#11722)
* refactor: better SSRF Protection in Action and Tool Services - Added `createSSRFSafeAgents` function to create HTTP/HTTPS agents that block connections to private/reserved IP addresses, enhancing security against SSRF attacks. - Updated `createActionTool` to accept a `useSSRFProtection` parameter, allowing the use of SSRF-safe agents during tool execution. - Modified `processRequiredActions` and `loadAgentTools` to utilize the new SSRF protection feature based on allowed domains configuration. - Introduced `resolveHostnameSSRF` function to validate resolved IPs against private ranges, preventing potential SSRF vulnerabilities. - Enhanced tests for domain resolution and private IP detection to ensure robust SSRF protection mechanisms are in place. * feat: Implement SSRF protection in MCP connections - Added `createSSRFSafeUndiciConnect` function to provide SSRF-safe DNS lookup options for undici agents. - Updated `MCPConnection`, `MCPConnectionFactory`, and `ConnectionsRepository` to include `useSSRFProtection` parameter, enabling SSRF protection based on server configuration. - Enhanced `MCPManager` and `UserConnectionManager` to utilize SSRF protection when establishing connections. - Updated tests to validate the integration of SSRF protection across various components, ensuring robust security measures are in place. * refactor: WS MCPConnection with SSRF protection and async transport construction - Added `resolveHostnameSSRF` to validate WebSocket hostnames against private IP addresses, enhancing SSRF protection. - Updated `constructTransport` method to be asynchronous, ensuring proper handling of SSRF checks before establishing connections. - Improved error handling for WebSocket transport to prevent connections to potentially unsafe addresses. * test: Enhance ActionRequest tests for SSRF-safe agent passthrough - Added tests to verify that httpAgent and httpsAgent are correctly passed to axios.create when provided in ActionRequest. - Included scenarios to ensure agents are not included when no options are specified. - Enhanced coverage for POST requests to confirm agent passthrough functionality. - Improved overall test robustness for SSRF protection in ActionRequest execution.
This commit is contained in:
parent
d6b6f191f7
commit
924be3b647
21 changed files with 567 additions and 53 deletions
|
|
@ -73,6 +73,7 @@ export class ConnectionsRepository {
|
|||
{
|
||||
serverName,
|
||||
serverConfig,
|
||||
useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(),
|
||||
},
|
||||
this.oauthOpts,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ export class MCPConnectionFactory {
|
|||
protected readonly serverConfig: t.MCPOptions;
|
||||
protected readonly logPrefix: string;
|
||||
protected readonly useOAuth: boolean;
|
||||
protected readonly useSSRFProtection: boolean;
|
||||
|
||||
// OAuth-related properties (only set when useOAuth is true)
|
||||
protected readonly userId?: string;
|
||||
|
|
@ -72,6 +73,7 @@ export class MCPConnectionFactory {
|
|||
serverConfig: this.serverConfig,
|
||||
userId: this.userId,
|
||||
oauthTokens,
|
||||
useSSRFProtection: this.useSSRFProtection,
|
||||
});
|
||||
|
||||
const oauthHandler = async () => {
|
||||
|
|
@ -146,6 +148,7 @@ export class MCPConnectionFactory {
|
|||
serverConfig: this.serverConfig,
|
||||
userId: this.userId,
|
||||
oauthTokens: null,
|
||||
useSSRFProtection: this.useSSRFProtection,
|
||||
});
|
||||
|
||||
unauthConnection.on('oauthRequired', () => {
|
||||
|
|
@ -189,6 +192,7 @@ export class MCPConnectionFactory {
|
|||
});
|
||||
this.serverName = basic.serverName;
|
||||
this.useOAuth = !!oauth?.useOAuth;
|
||||
this.useSSRFProtection = basic.useSSRFProtection === true;
|
||||
this.connectionTimeout = oauth?.connectionTimeout;
|
||||
this.logPrefix = oauth?.user
|
||||
? `[MCP][${basic.serverName}][${oauth.user.id}]`
|
||||
|
|
@ -213,6 +217,7 @@ export class MCPConnectionFactory {
|
|||
serverConfig: this.serverConfig,
|
||||
userId: this.userId,
|
||||
oauthTokens,
|
||||
useSSRFProtection: this.useSSRFProtection,
|
||||
});
|
||||
|
||||
let cleanupOAuthHandlers: (() => void) | null = null;
|
||||
|
|
|
|||
|
|
@ -102,7 +102,8 @@ export class MCPManager extends UserConnectionManager {
|
|||
serverConfig.requiresOAuth || (serverConfig as t.ParsedServerConfig).oauthMetadata,
|
||||
);
|
||||
|
||||
const basic: t.BasicConnectionOptions = { serverName, serverConfig };
|
||||
const useSSRFProtection = MCPServersRegistry.getInstance().shouldEnableSSRFProtection();
|
||||
const basic: t.BasicConnectionOptions = { serverName, serverConfig, useSSRFProtection };
|
||||
|
||||
if (!useOAuth) {
|
||||
const result = await MCPConnectionFactory.discoverTools(basic);
|
||||
|
|
|
|||
|
|
@ -117,6 +117,7 @@ export abstract class UserConnectionManager {
|
|||
{
|
||||
serverName: serverName,
|
||||
serverConfig: config,
|
||||
useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(),
|
||||
},
|
||||
{
|
||||
useOAuth: true,
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ jest.mock('../connection');
|
|||
const mockRegistryInstance = {
|
||||
getServerConfig: jest.fn(),
|
||||
getAllServerConfigs: jest.fn(),
|
||||
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
|
||||
};
|
||||
|
||||
jest.mock('../registry/MCPServersRegistry', () => ({
|
||||
|
|
@ -108,6 +109,7 @@ describe('ConnectionsRepository', () => {
|
|||
{
|
||||
serverName: 'server1',
|
||||
serverConfig: mockServerConfigs.server1,
|
||||
useSSRFProtection: false,
|
||||
},
|
||||
undefined,
|
||||
);
|
||||
|
|
@ -129,6 +131,7 @@ describe('ConnectionsRepository', () => {
|
|||
{
|
||||
serverName: 'server1',
|
||||
serverConfig: mockServerConfigs.server1,
|
||||
useSSRFProtection: false,
|
||||
},
|
||||
undefined,
|
||||
);
|
||||
|
|
@ -167,6 +170,7 @@ describe('ConnectionsRepository', () => {
|
|||
{
|
||||
serverName: 'server1',
|
||||
serverConfig: configWithCachedAt,
|
||||
useSSRFProtection: false,
|
||||
},
|
||||
undefined,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ describe('MCPConnectionFactory', () => {
|
|||
serverConfig: mockServerConfig,
|
||||
userId: undefined,
|
||||
oauthTokens: null,
|
||||
useSSRFProtection: false,
|
||||
});
|
||||
expect(mockConnectionInstance.connect).toHaveBeenCalled();
|
||||
});
|
||||
|
|
@ -125,6 +126,7 @@ describe('MCPConnectionFactory', () => {
|
|||
serverConfig: mockServerConfig,
|
||||
userId: 'user123',
|
||||
oauthTokens: mockTokens,
|
||||
useSSRFProtection: false,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -184,6 +186,7 @@ describe('MCPConnectionFactory', () => {
|
|||
serverConfig: mockServerConfig,
|
||||
userId: 'user123',
|
||||
oauthTokens: null,
|
||||
useSSRFProtection: false,
|
||||
});
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith(
|
||||
expect.stringContaining('No existing tokens found or error loading tokens'),
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ const mockRegistryInstance = {
|
|||
getServerConfig: jest.fn(),
|
||||
getAllServerConfigs: jest.fn(),
|
||||
getOAuthServers: jest.fn(),
|
||||
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
|
||||
};
|
||||
|
||||
jest.mock('~/mcp/registry/MCPServersRegistry', () => ({
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import type {
|
|||
import type { MCPOAuthTokens } from './oauth/types';
|
||||
import { withTimeout } from '~/utils/promise';
|
||||
import type * as t from './types';
|
||||
import { createSSRFSafeUndiciConnect, resolveHostnameSSRF } from '~/auth';
|
||||
import { sanitizeUrlForLogging } from './utils';
|
||||
import { mcpConfig } from './mcpConfig';
|
||||
|
||||
|
|
@ -213,6 +214,7 @@ interface MCPConnectionParams {
|
|||
serverConfig: t.MCPOptions;
|
||||
userId?: string;
|
||||
oauthTokens?: MCPOAuthTokens | null;
|
||||
useSSRFProtection?: boolean;
|
||||
}
|
||||
|
||||
export class MCPConnection extends EventEmitter {
|
||||
|
|
@ -233,6 +235,7 @@ export class MCPConnection extends EventEmitter {
|
|||
private oauthTokens?: MCPOAuthTokens | null;
|
||||
private requestHeaders?: Record<string, string> | null;
|
||||
private oauthRequired = false;
|
||||
private readonly useSSRFProtection: boolean;
|
||||
iconPath?: string;
|
||||
timeout?: number;
|
||||
url?: string;
|
||||
|
|
@ -263,6 +266,7 @@ export class MCPConnection extends EventEmitter {
|
|||
this.options = params.serverConfig;
|
||||
this.serverName = params.serverName;
|
||||
this.userId = params.userId;
|
||||
this.useSSRFProtection = params.useSSRFProtection === true;
|
||||
this.iconPath = params.serverConfig.iconPath;
|
||||
this.timeout = params.serverConfig.timeout;
|
||||
this.lastPingTime = Date.now();
|
||||
|
|
@ -301,6 +305,7 @@ export class MCPConnection extends EventEmitter {
|
|||
getHeaders: () => Record<string, string> | null | undefined,
|
||||
timeout?: number,
|
||||
): (input: UndiciRequestInfo, init?: UndiciRequestInit) => Promise<UndiciResponse> {
|
||||
const ssrfConnect = this.useSSRFProtection ? createSSRFSafeUndiciConnect() : undefined;
|
||||
return function customFetch(
|
||||
input: UndiciRequestInfo,
|
||||
init?: UndiciRequestInit,
|
||||
|
|
@ -310,6 +315,7 @@ export class MCPConnection extends EventEmitter {
|
|||
const agent = new Agent({
|
||||
bodyTimeout: effectiveTimeout,
|
||||
headersTimeout: effectiveTimeout,
|
||||
...(ssrfConnect != null ? { connect: ssrfConnect } : {}),
|
||||
});
|
||||
if (!requestHeaders) {
|
||||
return undiciFetch(input, { ...init, dispatcher: agent });
|
||||
|
|
@ -342,7 +348,7 @@ export class MCPConnection extends EventEmitter {
|
|||
logger.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`);
|
||||
}
|
||||
|
||||
private constructTransport(options: t.MCPOptions): Transport {
|
||||
private async constructTransport(options: t.MCPOptions): Promise<Transport> {
|
||||
try {
|
||||
let type: t.MCPOptions['type'];
|
||||
if (isStdioOptions(options)) {
|
||||
|
|
@ -378,6 +384,15 @@ export class MCPConnection extends EventEmitter {
|
|||
throw new Error('Invalid options for websocket transport.');
|
||||
}
|
||||
this.url = options.url;
|
||||
if (this.useSSRFProtection) {
|
||||
const wsHostname = new URL(options.url).hostname;
|
||||
const isSSRF = await resolveHostnameSSRF(wsHostname);
|
||||
if (isSSRF) {
|
||||
throw new Error(
|
||||
`SSRF protection: WebSocket host "${wsHostname}" resolved to a private/reserved IP address`,
|
||||
);
|
||||
}
|
||||
}
|
||||
return new WebSocketClientTransport(new URL(options.url));
|
||||
|
||||
case 'sse': {
|
||||
|
|
@ -402,6 +417,7 @@ export class MCPConnection extends EventEmitter {
|
|||
* The connect timeout is extended because proxies may delay initial response.
|
||||
*/
|
||||
const sseTimeout = this.timeout || SSE_CONNECT_TIMEOUT;
|
||||
const ssrfConnect = this.useSSRFProtection ? createSSRFSafeUndiciConnect() : undefined;
|
||||
const transport = new SSEClientTransport(url, {
|
||||
requestInit: {
|
||||
/** User/OAuth headers override SSE defaults */
|
||||
|
|
@ -420,6 +436,7 @@ export class MCPConnection extends EventEmitter {
|
|||
/** Extended keep-alive for long-lived SSE connections */
|
||||
keepAliveTimeout: sseTimeout,
|
||||
keepAliveMaxTimeout: sseTimeout * 2,
|
||||
...(ssrfConnect != null ? { connect: ssrfConnect } : {}),
|
||||
});
|
||||
return undiciFetch(url, {
|
||||
...init,
|
||||
|
|
@ -629,7 +646,7 @@ export class MCPConnection extends EventEmitter {
|
|||
}
|
||||
}
|
||||
|
||||
this.transport = this.constructTransport(this.options);
|
||||
this.transport = await this.constructTransport(this.options);
|
||||
this.setupTransportDebugHandlers();
|
||||
|
||||
const connectTimeout = this.options.initTimeout ?? 120000;
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ export class MCPServerInspector {
|
|||
private readonly serverName: string,
|
||||
private readonly config: t.ParsedServerConfig,
|
||||
private connection: MCPConnection | undefined,
|
||||
private readonly useSSRFProtection: boolean = false,
|
||||
) {}
|
||||
|
||||
/**
|
||||
|
|
@ -42,8 +43,9 @@ export class MCPServerInspector {
|
|||
throw new MCPDomainNotAllowedError(domain ?? 'unknown');
|
||||
}
|
||||
|
||||
const useSSRFProtection = !Array.isArray(allowedDomains) || allowedDomains.length === 0;
|
||||
const start = Date.now();
|
||||
const inspector = new MCPServerInspector(serverName, rawConfig, connection);
|
||||
const inspector = new MCPServerInspector(serverName, rawConfig, connection, useSSRFProtection);
|
||||
await inspector.inspectServer();
|
||||
inspector.config.initDuration = Date.now() - start;
|
||||
return inspector.config;
|
||||
|
|
@ -59,6 +61,7 @@ export class MCPServerInspector {
|
|||
this.connection = await MCPConnectionFactory.create({
|
||||
serverName: this.serverName,
|
||||
serverConfig: this.config,
|
||||
useSSRFProtection: this.useSSRFProtection,
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -77,6 +77,15 @@ export class MCPServersRegistry {
|
|||
return MCPServersRegistry.instance;
|
||||
}
|
||||
|
||||
public getAllowedDomains(): string[] | null | undefined {
|
||||
return this.allowedDomains;
|
||||
}
|
||||
|
||||
/** Returns true when no explicit allowedDomains allowlist is configured, enabling SSRF TOCTOU protection */
|
||||
public shouldEnableSSRFProtection(): boolean {
|
||||
return !Array.isArray(this.allowedDomains) || this.allowedDomains.length === 0;
|
||||
}
|
||||
|
||||
public async getServerConfig(
|
||||
serverName: string,
|
||||
userId?: string,
|
||||
|
|
|
|||
|
|
@ -276,6 +276,7 @@ describe('MCPServerInspector', () => {
|
|||
expect(MCPConnectionFactory.create).toHaveBeenCalledWith({
|
||||
serverName: 'test_server',
|
||||
serverConfig: expect.objectContaining({ type: 'stdio', command: 'node' }),
|
||||
useSSRFProtection: true,
|
||||
});
|
||||
|
||||
// Verify temporary connection was disconnected
|
||||
|
|
|
|||
|
|
@ -166,6 +166,7 @@ export type AddServerResult = {
|
|||
export interface BasicConnectionOptions {
|
||||
serverName: string;
|
||||
serverConfig: MCPOptions;
|
||||
useSSRFProtection?: boolean;
|
||||
}
|
||||
|
||||
export interface OAuthConnectionOptions {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue