mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-03-16 12:46:34 +01:00
🗝️ fix: Exempt Admin-Trusted Domains from MCP OAuth Validation (#12255)
* fix: exempt allowedDomains from MCP OAuth SSRF checks (#12254) The SSRF guard in validateOAuthUrl was context-blind — it blocked private/internal OAuth endpoints even for admin-trusted MCP servers listed in mcpSettings.allowedDomains. Add isHostnameAllowed() to domain.ts and skip SSRF checks in validateOAuthUrl when the OAuth endpoint hostname matches an allowed domain. * refactor: thread allowedDomains through MCP connection stack Pass allowedDomains from MCPServersRegistry through BasicConnectionOptions, MCPConnectionFactory, and into MCPOAuthHandler method calls so the OAuth layer can exempt admin-trusted domains from SSRF validation. * test: add allowedDomains bypass tests and fix registry mocks Add isHostnameAllowed unit tests (exact, wildcard, case-insensitive, private IPs). Add MCPOAuthSecurity tests covering the allowedDomains bypass for initiateOAuthFlow, refreshOAuthTokens, and revokeOAuthToken. Update registry mocks to include getAllowedDomains. * fix: enforce protocol/port constraints in OAuth allowedDomains bypass Replace isHostnameAllowed (hostname-only check) with isOAuthUrlAllowed which parses the full OAuth URL and matches against allowedDomains entries including protocol and explicit port constraints — mirroring isDomainAllowedCore's allowlist logic. Prevents a port-scoped entry like 'https://auth.internal:8443' from also exempting other ports. * test: cover auto-discovery and branch-3 refresh paths with allowedDomains Add three new integration tests using a real OAuth test server: - auto-discovered OAuth endpoints allowed when server IP is in allowedDomains - auto-discovered endpoints rejected when allowedDomains doesn't match - refreshOAuthTokens branch 3 (no clientInfo/config) with allowedDomains bypass Also rename describe block from ephemeral issue number to durable name. * docs: explain intentional absence of allowedDomains in completeOAuthFlow Prevents future contributors from assuming a missing parameter during security audits — URLs are pre-validated during initiateOAuthFlow. * test: update initiateOAuthFlow assertion for allowedDomains parameter * perf: avoid redundant URL parse for admin-trusted OAuth endpoints Move isOAuthUrlAllowed check before the hostname extraction so admin-trusted URLs short-circuit with a single URL parse instead of two. The hostname extraction (new URL) is now deferred to the SSRF-check path where it's actually needed.
This commit is contained in:
parent
8e8fb01d18
commit
acd07e8085
15 changed files with 432 additions and 18 deletions
|
|
@ -370,6 +370,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
|||
serverConfig.oauth?.revocation_endpoint_auth_methods_supported ??
|
||||
clientMetadata.revocation_endpoint_auth_methods_supported;
|
||||
const oauthHeaders = serverConfig.oauth_headers ?? {};
|
||||
const allowedDomains = getMCPServersRegistry().getAllowedDomains();
|
||||
|
||||
if (tokens?.access_token) {
|
||||
try {
|
||||
|
|
@ -385,6 +386,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
|||
revocationEndpointAuthMethodsSupported,
|
||||
},
|
||||
oauthHeaders,
|
||||
allowedDomains,
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error(`Error revoking OAuth access token for ${serverName}:`, error);
|
||||
|
|
@ -405,6 +407,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
|||
revocationEndpointAuthMethodsSupported,
|
||||
},
|
||||
oauthHeaders,
|
||||
allowedDomains,
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error);
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import {
|
|||
extractMCPServerDomain,
|
||||
isActionDomainAllowed,
|
||||
isEmailDomainAllowed,
|
||||
isOAuthUrlAllowed,
|
||||
isMCPDomainAllowed,
|
||||
isPrivateIP,
|
||||
isSSRFTarget,
|
||||
|
|
@ -1211,6 +1212,96 @@ describe('isMCPDomainAllowed', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('isOAuthUrlAllowed', () => {
|
||||
it('should return false when allowedDomains is null/undefined/empty', () => {
|
||||
expect(isOAuthUrlAllowed('https://example.com/token', null)).toBe(false);
|
||||
expect(isOAuthUrlAllowed('https://example.com/token', undefined)).toBe(false);
|
||||
expect(isOAuthUrlAllowed('https://example.com/token', [])).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false for unparseable URLs', () => {
|
||||
expect(isOAuthUrlAllowed('not-a-url', ['example.com'])).toBe(false);
|
||||
});
|
||||
|
||||
it('should match exact hostnames', () => {
|
||||
expect(isOAuthUrlAllowed('https://example.com/token', ['example.com'])).toBe(true);
|
||||
expect(isOAuthUrlAllowed('https://other.com/token', ['example.com'])).toBe(false);
|
||||
});
|
||||
|
||||
it('should match wildcard subdomains', () => {
|
||||
expect(isOAuthUrlAllowed('https://api.example.com/token', ['*.example.com'])).toBe(true);
|
||||
expect(isOAuthUrlAllowed('https://deep.nested.example.com/token', ['*.example.com'])).toBe(
|
||||
true,
|
||||
);
|
||||
expect(isOAuthUrlAllowed('https://example.com/token', ['*.example.com'])).toBe(true);
|
||||
expect(isOAuthUrlAllowed('https://other.com/token', ['*.example.com'])).toBe(false);
|
||||
});
|
||||
|
||||
it('should be case-insensitive', () => {
|
||||
expect(isOAuthUrlAllowed('https://EXAMPLE.COM/token', ['example.com'])).toBe(true);
|
||||
expect(isOAuthUrlAllowed('https://example.com/token', ['EXAMPLE.COM'])).toBe(true);
|
||||
});
|
||||
|
||||
it('should match private/internal URLs when hostname is in allowedDomains', () => {
|
||||
expect(isOAuthUrlAllowed('http://localhost:8080/token', ['localhost'])).toBe(true);
|
||||
expect(isOAuthUrlAllowed('http://10.0.0.1/token', ['10.0.0.1'])).toBe(true);
|
||||
expect(
|
||||
isOAuthUrlAllowed('http://host.docker.internal:8044/token', ['host.docker.internal']),
|
||||
).toBe(true);
|
||||
expect(isOAuthUrlAllowed('http://myserver.local/token', ['*.local'])).toBe(true);
|
||||
});
|
||||
|
||||
it('should match internal URLs with wildcard patterns', () => {
|
||||
expect(isOAuthUrlAllowed('https://auth.company.internal/token', ['*.company.internal'])).toBe(
|
||||
true,
|
||||
);
|
||||
expect(isOAuthUrlAllowed('https://company.internal/token', ['*.company.internal'])).toBe(true);
|
||||
});
|
||||
|
||||
it('should not match when hostname is absent from allowedDomains', () => {
|
||||
expect(isOAuthUrlAllowed('http://10.0.0.1/token', ['192.168.1.1'])).toBe(false);
|
||||
expect(isOAuthUrlAllowed('http://localhost/token', ['host.docker.internal'])).toBe(false);
|
||||
});
|
||||
|
||||
describe('protocol and port constraint enforcement', () => {
|
||||
it('should enforce protocol when allowedDomains specifies one', () => {
|
||||
expect(isOAuthUrlAllowed('https://auth.internal/token', ['https://auth.internal'])).toBe(
|
||||
true,
|
||||
);
|
||||
expect(isOAuthUrlAllowed('http://auth.internal/token', ['https://auth.internal'])).toBe(
|
||||
false,
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow any protocol when allowedDomains has bare hostname', () => {
|
||||
expect(isOAuthUrlAllowed('http://auth.internal/token', ['auth.internal'])).toBe(true);
|
||||
expect(isOAuthUrlAllowed('https://auth.internal/token', ['auth.internal'])).toBe(true);
|
||||
});
|
||||
|
||||
it('should enforce port when allowedDomains specifies one', () => {
|
||||
expect(
|
||||
isOAuthUrlAllowed('https://auth.internal:8443/token', ['https://auth.internal:8443']),
|
||||
).toBe(true);
|
||||
expect(
|
||||
isOAuthUrlAllowed('https://auth.internal:6379/token', ['https://auth.internal:8443']),
|
||||
).toBe(false);
|
||||
expect(isOAuthUrlAllowed('https://auth.internal/token', ['https://auth.internal:8443'])).toBe(
|
||||
false,
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow any port when allowedDomains has no explicit port', () => {
|
||||
expect(isOAuthUrlAllowed('https://auth.internal:8443/token', ['auth.internal'])).toBe(true);
|
||||
expect(isOAuthUrlAllowed('https://auth.internal:22/token', ['auth.internal'])).toBe(true);
|
||||
});
|
||||
|
||||
it('should reject wrong port even when hostname matches (prevents port-scanning)', () => {
|
||||
expect(isOAuthUrlAllowed('http://10.0.0.1:6379/token', ['http://10.0.0.1:8080'])).toBe(false);
|
||||
expect(isOAuthUrlAllowed('http://10.0.0.1:25/token', ['http://10.0.0.1:8080'])).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateEndpointURL', () => {
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
|
|
|||
|
|
@ -500,6 +500,52 @@ export async function isMCPDomainAllowed(
|
|||
return isDomainAllowedCore(domain, allowedDomains, MCP_PROTOCOLS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks whether an OAuth URL matches any entry in the MCP allowedDomains list,
|
||||
* honoring protocol and port constraints when specified by the admin.
|
||||
*
|
||||
* Mirrors the allowlist-matching logic of {@link isDomainAllowedCore} (hostname,
|
||||
* protocol, and explicit-port checks) but is synchronous — no DNS resolution is
|
||||
* needed because the caller is deciding whether to *skip* the subsequent
|
||||
* SSRF/DNS checks, not replace them.
|
||||
*
|
||||
* @remarks `parseDomainSpec` normalizes `www.` prefixes, so both the input URL
|
||||
* and allowedDomains entries starting with `www.` are matched without that prefix.
|
||||
*/
|
||||
export function isOAuthUrlAllowed(url: string, allowedDomains?: string[] | null): boolean {
|
||||
if (!Array.isArray(allowedDomains) || allowedDomains.length === 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const inputSpec = parseDomainSpec(url);
|
||||
if (!inputSpec) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const allowedDomain of allowedDomains) {
|
||||
const allowedSpec = parseDomainSpec(allowedDomain);
|
||||
if (!allowedSpec) {
|
||||
continue;
|
||||
}
|
||||
if (!hostnameMatches(inputSpec.hostname, allowedSpec)) {
|
||||
continue;
|
||||
}
|
||||
if (allowedSpec.protocol !== null) {
|
||||
if (inputSpec.protocol === null || inputSpec.protocol !== allowedSpec.protocol) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (allowedSpec.explicitPort) {
|
||||
if (!inputSpec.explicitPort || inputSpec.port !== allowedSpec.port) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/** Matches ErrorTypes.INVALID_BASE_URL — string literal avoids build-time dependency on data-provider */
|
||||
const INVALID_BASE_URL_TYPE = 'invalid_base_url';
|
||||
|
||||
|
|
|
|||
|
|
@ -77,12 +77,14 @@ export class ConnectionsRepository {
|
|||
await this.disconnect(serverName);
|
||||
}
|
||||
}
|
||||
const registry = MCPServersRegistry.getInstance();
|
||||
const connection = await MCPConnectionFactory.create(
|
||||
{
|
||||
serverName,
|
||||
serverConfig,
|
||||
dbSourced: !!(serverConfig as t.ParsedServerConfig).dbId,
|
||||
useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(),
|
||||
useSSRFProtection: registry.shouldEnableSSRFProtection(),
|
||||
allowedDomains: registry.getAllowedDomains(),
|
||||
},
|
||||
this.oauthOpts,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ export class MCPConnectionFactory {
|
|||
protected readonly logPrefix: string;
|
||||
protected readonly useOAuth: boolean;
|
||||
protected readonly useSSRFProtection: boolean;
|
||||
protected readonly allowedDomains?: string[] | null;
|
||||
|
||||
// OAuth-related properties (only set when useOAuth is true)
|
||||
protected readonly userId?: string;
|
||||
|
|
@ -197,6 +198,7 @@ export class MCPConnectionFactory {
|
|||
this.serverName = basic.serverName;
|
||||
this.useOAuth = !!oauth?.useOAuth;
|
||||
this.useSSRFProtection = basic.useSSRFProtection === true;
|
||||
this.allowedDomains = basic.allowedDomains;
|
||||
this.connectionTimeout = oauth?.connectionTimeout;
|
||||
this.logPrefix = oauth?.user
|
||||
? `[MCP][${basic.serverName}][${oauth.user.id}]`
|
||||
|
|
@ -297,6 +299,7 @@ export class MCPConnectionFactory {
|
|||
},
|
||||
this.serverConfig.oauth_headers ?? {},
|
||||
this.serverConfig.oauth,
|
||||
this.allowedDomains,
|
||||
);
|
||||
};
|
||||
}
|
||||
|
|
@ -340,6 +343,7 @@ export class MCPConnectionFactory {
|
|||
this.userId!,
|
||||
config?.oauth_headers ?? {},
|
||||
config?.oauth,
|
||||
this.allowedDomains,
|
||||
);
|
||||
|
||||
if (existingFlow) {
|
||||
|
|
@ -603,6 +607,7 @@ export class MCPConnectionFactory {
|
|||
this.userId!,
|
||||
this.serverConfig.oauth_headers ?? {},
|
||||
this.serverConfig.oauth,
|
||||
this.allowedDomains,
|
||||
);
|
||||
|
||||
// Store flow state BEFORE redirecting so the callback can find it
|
||||
|
|
|
|||
|
|
@ -100,13 +100,16 @@ export class MCPManager extends UserConnectionManager {
|
|||
|
||||
const useOAuth = Boolean(serverConfig.requiresOAuth || serverConfig.oauthMetadata);
|
||||
|
||||
const useSSRFProtection = MCPServersRegistry.getInstance().shouldEnableSSRFProtection();
|
||||
const registry = MCPServersRegistry.getInstance();
|
||||
const useSSRFProtection = registry.shouldEnableSSRFProtection();
|
||||
const allowedDomains = registry.getAllowedDomains();
|
||||
const dbSourced = !!serverConfig.dbId;
|
||||
const basic: t.BasicConnectionOptions = {
|
||||
dbSourced,
|
||||
serverName,
|
||||
serverConfig,
|
||||
useSSRFProtection,
|
||||
allowedDomains,
|
||||
};
|
||||
|
||||
if (!useOAuth) {
|
||||
|
|
|
|||
|
|
@ -153,12 +153,14 @@ export abstract class UserConnectionManager {
|
|||
logger.info(`[MCP][User: ${userId}][${serverName}] Establishing new connection`);
|
||||
|
||||
try {
|
||||
const registry = MCPServersRegistry.getInstance();
|
||||
connection = await MCPConnectionFactory.create(
|
||||
{
|
||||
serverConfig: config,
|
||||
serverName: serverName,
|
||||
dbSourced: !!config.dbId,
|
||||
useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(),
|
||||
useSSRFProtection: registry.shouldEnableSSRFProtection(),
|
||||
allowedDomains: registry.getAllowedDomains(),
|
||||
},
|
||||
{
|
||||
useOAuth: true,
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ const mockRegistryInstance = {
|
|||
getServerConfig: jest.fn(),
|
||||
getAllServerConfigs: jest.fn(),
|
||||
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
|
||||
getAllowedDomains: jest.fn().mockReturnValue(null),
|
||||
};
|
||||
|
||||
jest.mock('../registry/MCPServersRegistry', () => ({
|
||||
|
|
@ -110,6 +111,7 @@ describe('ConnectionsRepository', () => {
|
|||
serverName: 'server1',
|
||||
serverConfig: mockServerConfigs.server1,
|
||||
useSSRFProtection: false,
|
||||
allowedDomains: null,
|
||||
dbSourced: false,
|
||||
},
|
||||
undefined,
|
||||
|
|
@ -133,6 +135,7 @@ describe('ConnectionsRepository', () => {
|
|||
serverName: 'server1',
|
||||
serverConfig: mockServerConfigs.server1,
|
||||
useSSRFProtection: false,
|
||||
allowedDomains: null,
|
||||
dbSourced: false,
|
||||
},
|
||||
undefined,
|
||||
|
|
@ -173,6 +176,7 @@ describe('ConnectionsRepository', () => {
|
|||
serverName: 'server1',
|
||||
serverConfig: configWithCachedAt,
|
||||
useSSRFProtection: false,
|
||||
allowedDomains: null,
|
||||
dbSourced: false,
|
||||
},
|
||||
undefined,
|
||||
|
|
|
|||
|
|
@ -269,6 +269,7 @@ describe('MCPConnectionFactory', () => {
|
|||
'user123',
|
||||
{},
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
|
||||
// initFlow must be awaited BEFORE the redirect to guarantee state is stored
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ const mockRegistryInstance = {
|
|||
getAllServerConfigs: jest.fn(),
|
||||
getOAuthServers: jest.fn(),
|
||||
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
|
||||
getAllowedDomains: jest.fn().mockReturnValue(null),
|
||||
};
|
||||
|
||||
jest.mock('~/mcp/registry/MCPServersRegistry', () => ({
|
||||
|
|
|
|||
|
|
@ -82,6 +82,7 @@ describe('MCP OAuth Race Condition Fixes', () => {
|
|||
.mockReturnValue({
|
||||
getServerConfig: jest.fn().mockResolvedValue(mockConfig),
|
||||
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
|
||||
getAllowedDomains: jest.fn().mockReturnValue(null),
|
||||
});
|
||||
|
||||
const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory');
|
||||
|
|
@ -147,6 +148,7 @@ describe('MCP OAuth Race Condition Fixes', () => {
|
|||
.mockReturnValue({
|
||||
getServerConfig: jest.fn().mockResolvedValue(mockConfig),
|
||||
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
|
||||
getAllowedDomains: jest.fn().mockReturnValue(null),
|
||||
});
|
||||
|
||||
const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory');
|
||||
|
|
|
|||
|
|
@ -7,6 +7,9 @@
|
|||
*
|
||||
* 2. redirect_uri manipulation — validates that user-supplied redirect_uri
|
||||
* is ignored in favor of the server-controlled default.
|
||||
*
|
||||
* 3. allowedDomains SSRF exemption — validates that admin-configured allowedDomains
|
||||
* exempts trusted domains from SSRF checks, including auto-discovery paths.
|
||||
*/
|
||||
|
||||
import * as http from 'http';
|
||||
|
|
@ -226,3 +229,214 @@ describe('MCP OAuth redirect_uri enforcement', () => {
|
|||
expect(authUrl.searchParams.get('redirect_uri')).not.toBe(attackerRedirectUri);
|
||||
});
|
||||
});
|
||||
|
||||
describe('MCP OAuth allowedDomains SSRF exemption for admin-trusted hosts', () => {
|
||||
it('should allow private authorization_url when hostname is in allowedDomains', async () => {
|
||||
const result = await MCPOAuthHandler.initiateOAuthFlow(
|
||||
'internal-server',
|
||||
'https://speedy-mcp.company.com/',
|
||||
'user-1',
|
||||
{},
|
||||
{
|
||||
authorization_url: 'http://10.0.0.1/authorize',
|
||||
token_url: 'http://10.0.0.1/token',
|
||||
client_id: 'client',
|
||||
client_secret: 'secret',
|
||||
},
|
||||
['10.0.0.1'],
|
||||
);
|
||||
|
||||
expect(result.authorizationUrl).toContain('10.0.0.1/authorize');
|
||||
});
|
||||
|
||||
it('should allow private token_url when hostname matches wildcard allowedDomains', async () => {
|
||||
const result = await MCPOAuthHandler.initiateOAuthFlow(
|
||||
'internal-server',
|
||||
'https://speedy-mcp.company.com/',
|
||||
'user-1',
|
||||
{},
|
||||
{
|
||||
authorization_url: 'https://auth.company.internal/authorize',
|
||||
token_url: 'https://auth.company.internal/token',
|
||||
client_id: 'client',
|
||||
client_secret: 'secret',
|
||||
},
|
||||
['*.company.internal'],
|
||||
);
|
||||
|
||||
expect(result.authorizationUrl).toContain('auth.company.internal/authorize');
|
||||
});
|
||||
|
||||
it('should still reject private URLs when allowedDomains does not match', async () => {
|
||||
await expect(
|
||||
MCPOAuthHandler.initiateOAuthFlow(
|
||||
'test-server',
|
||||
'https://mcp.example.com/',
|
||||
'user-1',
|
||||
{},
|
||||
{
|
||||
authorization_url: 'http://169.254.169.254/authorize',
|
||||
token_url: 'https://auth.example.com/token',
|
||||
client_id: 'client',
|
||||
client_secret: 'secret',
|
||||
},
|
||||
['safe.example.com'],
|
||||
),
|
||||
).rejects.toThrow(/targets a blocked address/);
|
||||
});
|
||||
|
||||
it('should still reject when allowedDomains is empty', async () => {
|
||||
await expect(
|
||||
MCPOAuthHandler.initiateOAuthFlow(
|
||||
'test-server',
|
||||
'https://mcp.example.com/',
|
||||
'user-1',
|
||||
{},
|
||||
{
|
||||
authorization_url: 'http://10.0.0.1/authorize',
|
||||
token_url: 'https://auth.example.com/token',
|
||||
client_id: 'client',
|
||||
client_secret: 'secret',
|
||||
},
|
||||
[],
|
||||
),
|
||||
).rejects.toThrow(/targets a blocked address/);
|
||||
});
|
||||
|
||||
it('should allow private revocationEndpoint when hostname is in allowedDomains', async () => {
|
||||
const mockFetch = jest.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
status: 200,
|
||||
} as Response);
|
||||
const originalFetch = global.fetch;
|
||||
global.fetch = mockFetch;
|
||||
|
||||
try {
|
||||
await MCPOAuthHandler.revokeOAuthToken(
|
||||
'internal-server',
|
||||
'some-token',
|
||||
'access',
|
||||
{
|
||||
serverUrl: 'https://internal.corp.net/',
|
||||
clientId: 'client',
|
||||
clientSecret: 'secret',
|
||||
revocationEndpoint: 'http://10.0.0.1/revoke',
|
||||
},
|
||||
{},
|
||||
['10.0.0.1'],
|
||||
);
|
||||
|
||||
expect(mockFetch).toHaveBeenCalled();
|
||||
} finally {
|
||||
global.fetch = originalFetch;
|
||||
}
|
||||
});
|
||||
|
||||
it('should allow localhost token_url in refreshOAuthTokens when localhost is in allowedDomains', async () => {
|
||||
const mockFetch = jest.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
access_token: 'new-access-token',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
}),
|
||||
} as Response);
|
||||
const originalFetch = global.fetch;
|
||||
global.fetch = mockFetch;
|
||||
|
||||
try {
|
||||
const tokens = await MCPOAuthHandler.refreshOAuthTokens(
|
||||
'old-refresh-token',
|
||||
{
|
||||
serverName: 'local-server',
|
||||
serverUrl: 'http://localhost:8080/',
|
||||
clientInfo: {
|
||||
client_id: 'client-id',
|
||||
client_secret: 'client-secret',
|
||||
redirect_uris: ['http://localhost:3080/callback'],
|
||||
},
|
||||
},
|
||||
{},
|
||||
{
|
||||
token_url: 'http://localhost:8080/token',
|
||||
client_id: 'client-id',
|
||||
client_secret: 'client-secret',
|
||||
},
|
||||
['localhost'],
|
||||
);
|
||||
|
||||
expect(tokens.access_token).toBe('new-access-token');
|
||||
expect(mockFetch).toHaveBeenCalled();
|
||||
} finally {
|
||||
global.fetch = originalFetch;
|
||||
}
|
||||
});
|
||||
|
||||
describe('auto-discovery path with allowedDomains', () => {
|
||||
let discoveryServer: OAuthTestServer;
|
||||
|
||||
beforeEach(async () => {
|
||||
discoveryServer = await createOAuthMCPServer({
|
||||
tokenTTLMs: 60000,
|
||||
issueRefreshTokens: true,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await discoveryServer.close();
|
||||
});
|
||||
|
||||
it('should allow auto-discovered OAuth endpoints when server IP is in allowedDomains', async () => {
|
||||
const result = await MCPOAuthHandler.initiateOAuthFlow(
|
||||
'discovery-server',
|
||||
discoveryServer.url,
|
||||
'user-1',
|
||||
{},
|
||||
undefined,
|
||||
['127.0.0.1'],
|
||||
);
|
||||
|
||||
expect(result.authorizationUrl).toContain('127.0.0.1');
|
||||
expect(result.flowId).toBeTruthy();
|
||||
});
|
||||
|
||||
it('should reject auto-discovered endpoints when allowedDomains does not cover server IP', async () => {
|
||||
await expect(
|
||||
MCPOAuthHandler.initiateOAuthFlow(
|
||||
'discovery-server',
|
||||
discoveryServer.url,
|
||||
'user-1',
|
||||
{},
|
||||
undefined,
|
||||
['safe.example.com'],
|
||||
),
|
||||
).rejects.toThrow(/targets a blocked address/);
|
||||
});
|
||||
|
||||
it('should allow auto-discovered token_url in refreshOAuthTokens branch 3 (no clientInfo/config)', async () => {
|
||||
const code = await discoveryServer.getAuthCode();
|
||||
const tokenRes = await fetch(`${discoveryServer.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
const initial = (await tokenRes.json()) as {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
};
|
||||
|
||||
const tokens = await MCPOAuthHandler.refreshOAuthTokens(
|
||||
initial.refresh_token,
|
||||
{
|
||||
serverName: 'discovery-refresh-server',
|
||||
serverUrl: discoveryServer.url,
|
||||
},
|
||||
{},
|
||||
undefined,
|
||||
['127.0.0.1'],
|
||||
);
|
||||
|
||||
expect(tokens.access_token).toBeTruthy();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ import {
|
|||
selectRegistrationAuthMethod,
|
||||
inferClientAuthMethod,
|
||||
} from './methods';
|
||||
import { isSSRFTarget, resolveHostnameSSRF } from '~/auth';
|
||||
import { isSSRFTarget, resolveHostnameSSRF, isOAuthUrlAllowed } from '~/auth';
|
||||
import { sanitizeUrlForLogging } from '~/mcp/utils';
|
||||
|
||||
/** Type for the OAuth metadata from the SDK */
|
||||
|
|
@ -123,6 +123,7 @@ export class MCPOAuthHandler {
|
|||
private static async discoverMetadata(
|
||||
serverUrl: string,
|
||||
oauthHeaders: Record<string, string>,
|
||||
allowedDomains?: string[] | null,
|
||||
): Promise<{
|
||||
metadata: OAuthMetadata;
|
||||
resourceMetadata?: OAuthProtectedResourceMetadata;
|
||||
|
|
@ -146,7 +147,7 @@ export class MCPOAuthHandler {
|
|||
|
||||
if (resourceMetadata?.authorization_servers?.length) {
|
||||
const discoveredAuthServer = resourceMetadata.authorization_servers[0];
|
||||
await this.validateOAuthUrl(discoveredAuthServer, 'authorization_server');
|
||||
await this.validateOAuthUrl(discoveredAuthServer, 'authorization_server', allowedDomains);
|
||||
authServerUrl = new URL(discoveredAuthServer);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Found authorization server from resource metadata: ${authServerUrl}`,
|
||||
|
|
@ -206,11 +207,17 @@ export class MCPOAuthHandler {
|
|||
const endpointChecks: Promise<void>[] = [];
|
||||
if (metadata.registration_endpoint) {
|
||||
endpointChecks.push(
|
||||
this.validateOAuthUrl(metadata.registration_endpoint, 'registration_endpoint'),
|
||||
this.validateOAuthUrl(
|
||||
metadata.registration_endpoint,
|
||||
'registration_endpoint',
|
||||
allowedDomains,
|
||||
),
|
||||
);
|
||||
}
|
||||
if (metadata.token_endpoint) {
|
||||
endpointChecks.push(this.validateOAuthUrl(metadata.token_endpoint, 'token_endpoint'));
|
||||
endpointChecks.push(
|
||||
this.validateOAuthUrl(metadata.token_endpoint, 'token_endpoint', allowedDomains),
|
||||
);
|
||||
}
|
||||
if (endpointChecks.length > 0) {
|
||||
await Promise.all(endpointChecks);
|
||||
|
|
@ -360,6 +367,7 @@ export class MCPOAuthHandler {
|
|||
userId: string,
|
||||
oauthHeaders: Record<string, string>,
|
||||
config?: MCPOptions['oauth'],
|
||||
allowedDomains?: string[] | null,
|
||||
): Promise<{ authorizationUrl: string; flowId: string; flowMetadata: MCPOAuthFlowMetadata }> {
|
||||
logger.debug(
|
||||
`[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${sanitizeUrlForLogging(serverUrl)}`,
|
||||
|
|
@ -375,8 +383,8 @@ export class MCPOAuthHandler {
|
|||
logger.debug(`[MCPOAuth] Using pre-configured OAuth settings for ${serverName}`);
|
||||
|
||||
await Promise.all([
|
||||
this.validateOAuthUrl(config.authorization_url, 'authorization_url'),
|
||||
this.validateOAuthUrl(config.token_url, 'token_url'),
|
||||
this.validateOAuthUrl(config.authorization_url, 'authorization_url', allowedDomains),
|
||||
this.validateOAuthUrl(config.token_url, 'token_url', allowedDomains),
|
||||
]);
|
||||
|
||||
const skipCodeChallengeCheck =
|
||||
|
|
@ -477,6 +485,7 @@ export class MCPOAuthHandler {
|
|||
const { metadata, resourceMetadata, authServerUrl } = await this.discoverMetadata(
|
||||
serverUrl,
|
||||
oauthHeaders,
|
||||
allowedDomains,
|
||||
);
|
||||
|
||||
logger.debug(
|
||||
|
|
@ -588,7 +597,11 @@ export class MCPOAuthHandler {
|
|||
}
|
||||
|
||||
/**
|
||||
* Completes the OAuth flow by exchanging the authorization code for tokens
|
||||
* Completes the OAuth flow by exchanging the authorization code for tokens.
|
||||
*
|
||||
* `allowedDomains` is intentionally absent: all URLs used here (serverUrl,
|
||||
* token_endpoint) originate from {@link MCPOAuthFlowMetadata} that was
|
||||
* SSRF-validated during {@link initiateOAuthFlow}. No new URL resolution occurs.
|
||||
*/
|
||||
static async completeOAuthFlow(
|
||||
flowId: string,
|
||||
|
|
@ -692,8 +705,20 @@ export class MCPOAuthHandler {
|
|||
return randomBytes(32).toString('base64url');
|
||||
}
|
||||
|
||||
/** Validates an OAuth URL is not targeting a private/internal address */
|
||||
private static async validateOAuthUrl(url: string, fieldName: string): Promise<void> {
|
||||
/**
|
||||
* Validates an OAuth URL is not targeting a private/internal address.
|
||||
* Skipped when the full URL (hostname + protocol + port) matches an admin-trusted
|
||||
* allowedDomains entry, honoring protocol/port constraints when the admin specifies them.
|
||||
*/
|
||||
private static async validateOAuthUrl(
|
||||
url: string,
|
||||
fieldName: string,
|
||||
allowedDomains?: string[] | null,
|
||||
): Promise<void> {
|
||||
if (isOAuthUrlAllowed(url, allowedDomains)) {
|
||||
return;
|
||||
}
|
||||
|
||||
let hostname: string;
|
||||
try {
|
||||
hostname = new URL(url).hostname;
|
||||
|
|
@ -799,6 +824,7 @@ export class MCPOAuthHandler {
|
|||
metadata: { serverName: string; serverUrl?: string; clientInfo?: OAuthClientInformation },
|
||||
oauthHeaders: Record<string, string>,
|
||||
config?: MCPOptions['oauth'],
|
||||
allowedDomains?: string[] | null,
|
||||
): Promise<MCPOAuthTokens> {
|
||||
logger.debug(`[MCPOAuth] Refreshing tokens for ${metadata.serverName}`);
|
||||
|
||||
|
|
@ -824,7 +850,7 @@ export class MCPOAuthHandler {
|
|||
let tokenUrl: string;
|
||||
let authMethods: string[] | undefined;
|
||||
if (config?.token_url) {
|
||||
await this.validateOAuthUrl(config.token_url, 'token_url');
|
||||
await this.validateOAuthUrl(config.token_url, 'token_url', allowedDomains);
|
||||
tokenUrl = config.token_url;
|
||||
authMethods = config.token_endpoint_auth_methods_supported;
|
||||
} else if (!metadata.serverUrl) {
|
||||
|
|
@ -851,7 +877,7 @@ export class MCPOAuthHandler {
|
|||
tokenUrl = oauthMetadata.token_endpoint;
|
||||
authMethods = oauthMetadata.token_endpoint_auth_methods_supported;
|
||||
}
|
||||
await this.validateOAuthUrl(tokenUrl, 'token_url');
|
||||
await this.validateOAuthUrl(tokenUrl, 'token_url', allowedDomains);
|
||||
}
|
||||
|
||||
const body = new URLSearchParams({
|
||||
|
|
@ -928,7 +954,7 @@ export class MCPOAuthHandler {
|
|||
if (config?.token_url && config?.client_id) {
|
||||
logger.debug(`[MCPOAuth] Using pre-configured OAuth settings for token refresh`);
|
||||
|
||||
await this.validateOAuthUrl(config.token_url, 'token_url');
|
||||
await this.validateOAuthUrl(config.token_url, 'token_url', allowedDomains);
|
||||
const tokenUrl = new URL(config.token_url);
|
||||
|
||||
const body = new URLSearchParams({
|
||||
|
|
@ -1026,7 +1052,7 @@ export class MCPOAuthHandler {
|
|||
} else {
|
||||
tokenUrl = new URL(oauthMetadata.token_endpoint);
|
||||
}
|
||||
await this.validateOAuthUrl(tokenUrl.href, 'token_url');
|
||||
await this.validateOAuthUrl(tokenUrl.href, 'token_url', allowedDomains);
|
||||
|
||||
const body = new URLSearchParams({
|
||||
grant_type: 'refresh_token',
|
||||
|
|
@ -1075,9 +1101,14 @@ export class MCPOAuthHandler {
|
|||
revocationEndpointAuthMethodsSupported?: string[];
|
||||
},
|
||||
oauthHeaders: Record<string, string> = {},
|
||||
allowedDomains?: string[] | null,
|
||||
): Promise<void> {
|
||||
if (metadata.revocationEndpoint != null) {
|
||||
await this.validateOAuthUrl(metadata.revocationEndpoint, 'revocation_endpoint');
|
||||
await this.validateOAuthUrl(
|
||||
metadata.revocationEndpoint,
|
||||
'revocation_endpoint',
|
||||
allowedDomains,
|
||||
);
|
||||
}
|
||||
const revokeUrl: URL =
|
||||
metadata.revocationEndpoint != null
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ export class MCPServerInspector {
|
|||
private readonly config: t.ParsedServerConfig,
|
||||
private connection: MCPConnection | undefined,
|
||||
private readonly useSSRFProtection: boolean = false,
|
||||
private readonly allowedDomains?: string[] | null,
|
||||
) {}
|
||||
|
||||
/**
|
||||
|
|
@ -46,7 +47,13 @@ export class MCPServerInspector {
|
|||
|
||||
const useSSRFProtection = !Array.isArray(allowedDomains) || allowedDomains.length === 0;
|
||||
const start = Date.now();
|
||||
const inspector = new MCPServerInspector(serverName, rawConfig, connection, useSSRFProtection);
|
||||
const inspector = new MCPServerInspector(
|
||||
serverName,
|
||||
rawConfig,
|
||||
connection,
|
||||
useSSRFProtection,
|
||||
allowedDomains,
|
||||
);
|
||||
await inspector.inspectServer();
|
||||
inspector.config.initDuration = Date.now() - start;
|
||||
return inspector.config;
|
||||
|
|
@ -68,6 +75,7 @@ export class MCPServerInspector {
|
|||
serverName: this.serverName,
|
||||
dbSourced: !!this.config.dbId,
|
||||
useSSRFProtection: this.useSSRFProtection,
|
||||
allowedDomains: this.allowedDomains,
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -169,6 +169,7 @@ export interface BasicConnectionOptions {
|
|||
serverName: string;
|
||||
serverConfig: MCPOptions;
|
||||
useSSRFProtection?: boolean;
|
||||
allowedDomains?: string[] | null;
|
||||
/** When true, only resolve customUserVars in processMCPEnv (for DB-stored servers) */
|
||||
dbSourced?: boolean;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue