From acd07e80852f6b931a4459372981b5d3db8082da Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 23:03:12 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=97=9D=EF=B8=8F=20fix:=20Exempt=20Admin-T?= =?UTF-8?q?rusted=20Domains=20from=20MCP=20OAuth=20Validation=20(#12255)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: exempt allowedDomains from MCP OAuth SSRF checks (#12254) The SSRF guard in validateOAuthUrl was context-blind — it blocked private/internal OAuth endpoints even for admin-trusted MCP servers listed in mcpSettings.allowedDomains. Add isHostnameAllowed() to domain.ts and skip SSRF checks in validateOAuthUrl when the OAuth endpoint hostname matches an allowed domain. * refactor: thread allowedDomains through MCP connection stack Pass allowedDomains from MCPServersRegistry through BasicConnectionOptions, MCPConnectionFactory, and into MCPOAuthHandler method calls so the OAuth layer can exempt admin-trusted domains from SSRF validation. * test: add allowedDomains bypass tests and fix registry mocks Add isHostnameAllowed unit tests (exact, wildcard, case-insensitive, private IPs). Add MCPOAuthSecurity tests covering the allowedDomains bypass for initiateOAuthFlow, refreshOAuthTokens, and revokeOAuthToken. Update registry mocks to include getAllowedDomains. * fix: enforce protocol/port constraints in OAuth allowedDomains bypass Replace isHostnameAllowed (hostname-only check) with isOAuthUrlAllowed which parses the full OAuth URL and matches against allowedDomains entries including protocol and explicit port constraints — mirroring isDomainAllowedCore's allowlist logic. Prevents a port-scoped entry like 'https://auth.internal:8443' from also exempting other ports. * test: cover auto-discovery and branch-3 refresh paths with allowedDomains Add three new integration tests using a real OAuth test server: - auto-discovered OAuth endpoints allowed when server IP is in allowedDomains - auto-discovered endpoints rejected when allowedDomains doesn't match - refreshOAuthTokens branch 3 (no clientInfo/config) with allowedDomains bypass Also rename describe block from ephemeral issue number to durable name. * docs: explain intentional absence of allowedDomains in completeOAuthFlow Prevents future contributors from assuming a missing parameter during security audits — URLs are pre-validated during initiateOAuthFlow. * test: update initiateOAuthFlow assertion for allowedDomains parameter * perf: avoid redundant URL parse for admin-trusted OAuth endpoints Move isOAuthUrlAllowed check before the hostname extraction so admin-trusted URLs short-circuit with a single URL parse instead of two. The hostname extraction (new URL) is now deferred to the SSRF-check path where it's actually needed. --- api/server/controllers/UserController.js | 3 + packages/api/src/auth/domain.spec.ts | 91 ++++++++ packages/api/src/auth/domain.ts | 46 ++++ packages/api/src/mcp/ConnectionsRepository.ts | 4 +- packages/api/src/mcp/MCPConnectionFactory.ts | 5 + packages/api/src/mcp/MCPManager.ts | 5 +- packages/api/src/mcp/UserConnectionManager.ts | 4 +- .../__tests__/ConnectionsRepository.test.ts | 4 + .../__tests__/MCPConnectionFactory.test.ts | 1 + .../api/src/mcp/__tests__/MCPManager.test.ts | 1 + .../__tests__/MCPOAuthRaceCondition.test.ts | 2 + .../mcp/__tests__/MCPOAuthSecurity.test.ts | 214 ++++++++++++++++++ packages/api/src/mcp/oauth/handler.ts | 59 +++-- .../src/mcp/registry/MCPServerInspector.ts | 10 +- packages/api/src/mcp/types/index.ts | 1 + 15 files changed, 432 insertions(+), 18 deletions(-) diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index b3160bb3d3..6d5df0ac8d 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -370,6 +370,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => { serverConfig.oauth?.revocation_endpoint_auth_methods_supported ?? clientMetadata.revocation_endpoint_auth_methods_supported; const oauthHeaders = serverConfig.oauth_headers ?? {}; + const allowedDomains = getMCPServersRegistry().getAllowedDomains(); if (tokens?.access_token) { try { @@ -385,6 +386,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => { revocationEndpointAuthMethodsSupported, }, oauthHeaders, + allowedDomains, ); } catch (error) { logger.error(`Error revoking OAuth access token for ${serverName}:`, error); @@ -405,6 +407,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => { revocationEndpointAuthMethodsSupported, }, oauthHeaders, + allowedDomains, ); } catch (error) { logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error); diff --git a/packages/api/src/auth/domain.spec.ts b/packages/api/src/auth/domain.spec.ts index a7140528a9..88a7c98160 100644 --- a/packages/api/src/auth/domain.spec.ts +++ b/packages/api/src/auth/domain.spec.ts @@ -8,6 +8,7 @@ import { extractMCPServerDomain, isActionDomainAllowed, isEmailDomainAllowed, + isOAuthUrlAllowed, isMCPDomainAllowed, isPrivateIP, isSSRFTarget, @@ -1211,6 +1212,96 @@ describe('isMCPDomainAllowed', () => { }); }); +describe('isOAuthUrlAllowed', () => { + it('should return false when allowedDomains is null/undefined/empty', () => { + expect(isOAuthUrlAllowed('https://example.com/token', null)).toBe(false); + expect(isOAuthUrlAllowed('https://example.com/token', undefined)).toBe(false); + expect(isOAuthUrlAllowed('https://example.com/token', [])).toBe(false); + }); + + it('should return false for unparseable URLs', () => { + expect(isOAuthUrlAllowed('not-a-url', ['example.com'])).toBe(false); + }); + + it('should match exact hostnames', () => { + expect(isOAuthUrlAllowed('https://example.com/token', ['example.com'])).toBe(true); + expect(isOAuthUrlAllowed('https://other.com/token', ['example.com'])).toBe(false); + }); + + it('should match wildcard subdomains', () => { + expect(isOAuthUrlAllowed('https://api.example.com/token', ['*.example.com'])).toBe(true); + expect(isOAuthUrlAllowed('https://deep.nested.example.com/token', ['*.example.com'])).toBe( + true, + ); + expect(isOAuthUrlAllowed('https://example.com/token', ['*.example.com'])).toBe(true); + expect(isOAuthUrlAllowed('https://other.com/token', ['*.example.com'])).toBe(false); + }); + + it('should be case-insensitive', () => { + expect(isOAuthUrlAllowed('https://EXAMPLE.COM/token', ['example.com'])).toBe(true); + expect(isOAuthUrlAllowed('https://example.com/token', ['EXAMPLE.COM'])).toBe(true); + }); + + it('should match private/internal URLs when hostname is in allowedDomains', () => { + expect(isOAuthUrlAllowed('http://localhost:8080/token', ['localhost'])).toBe(true); + expect(isOAuthUrlAllowed('http://10.0.0.1/token', ['10.0.0.1'])).toBe(true); + expect( + isOAuthUrlAllowed('http://host.docker.internal:8044/token', ['host.docker.internal']), + ).toBe(true); + expect(isOAuthUrlAllowed('http://myserver.local/token', ['*.local'])).toBe(true); + }); + + it('should match internal URLs with wildcard patterns', () => { + expect(isOAuthUrlAllowed('https://auth.company.internal/token', ['*.company.internal'])).toBe( + true, + ); + expect(isOAuthUrlAllowed('https://company.internal/token', ['*.company.internal'])).toBe(true); + }); + + it('should not match when hostname is absent from allowedDomains', () => { + expect(isOAuthUrlAllowed('http://10.0.0.1/token', ['192.168.1.1'])).toBe(false); + expect(isOAuthUrlAllowed('http://localhost/token', ['host.docker.internal'])).toBe(false); + }); + + describe('protocol and port constraint enforcement', () => { + it('should enforce protocol when allowedDomains specifies one', () => { + expect(isOAuthUrlAllowed('https://auth.internal/token', ['https://auth.internal'])).toBe( + true, + ); + expect(isOAuthUrlAllowed('http://auth.internal/token', ['https://auth.internal'])).toBe( + false, + ); + }); + + it('should allow any protocol when allowedDomains has bare hostname', () => { + expect(isOAuthUrlAllowed('http://auth.internal/token', ['auth.internal'])).toBe(true); + expect(isOAuthUrlAllowed('https://auth.internal/token', ['auth.internal'])).toBe(true); + }); + + it('should enforce port when allowedDomains specifies one', () => { + expect( + isOAuthUrlAllowed('https://auth.internal:8443/token', ['https://auth.internal:8443']), + ).toBe(true); + expect( + isOAuthUrlAllowed('https://auth.internal:6379/token', ['https://auth.internal:8443']), + ).toBe(false); + expect(isOAuthUrlAllowed('https://auth.internal/token', ['https://auth.internal:8443'])).toBe( + false, + ); + }); + + it('should allow any port when allowedDomains has no explicit port', () => { + expect(isOAuthUrlAllowed('https://auth.internal:8443/token', ['auth.internal'])).toBe(true); + expect(isOAuthUrlAllowed('https://auth.internal:22/token', ['auth.internal'])).toBe(true); + }); + + it('should reject wrong port even when hostname matches (prevents port-scanning)', () => { + expect(isOAuthUrlAllowed('http://10.0.0.1:6379/token', ['http://10.0.0.1:8080'])).toBe(false); + expect(isOAuthUrlAllowed('http://10.0.0.1:25/token', ['http://10.0.0.1:8080'])).toBe(false); + }); + }); +}); + describe('validateEndpointURL', () => { afterEach(() => { jest.clearAllMocks(); diff --git a/packages/api/src/auth/domain.ts b/packages/api/src/auth/domain.ts index fabe2502ff..f4f9f5f04e 100644 --- a/packages/api/src/auth/domain.ts +++ b/packages/api/src/auth/domain.ts @@ -500,6 +500,52 @@ export async function isMCPDomainAllowed( return isDomainAllowedCore(domain, allowedDomains, MCP_PROTOCOLS); } +/** + * Checks whether an OAuth URL matches any entry in the MCP allowedDomains list, + * honoring protocol and port constraints when specified by the admin. + * + * Mirrors the allowlist-matching logic of {@link isDomainAllowedCore} (hostname, + * protocol, and explicit-port checks) but is synchronous — no DNS resolution is + * needed because the caller is deciding whether to *skip* the subsequent + * SSRF/DNS checks, not replace them. + * + * @remarks `parseDomainSpec` normalizes `www.` prefixes, so both the input URL + * and allowedDomains entries starting with `www.` are matched without that prefix. + */ +export function isOAuthUrlAllowed(url: string, allowedDomains?: string[] | null): boolean { + if (!Array.isArray(allowedDomains) || allowedDomains.length === 0) { + return false; + } + + const inputSpec = parseDomainSpec(url); + if (!inputSpec) { + return false; + } + + for (const allowedDomain of allowedDomains) { + const allowedSpec = parseDomainSpec(allowedDomain); + if (!allowedSpec) { + continue; + } + if (!hostnameMatches(inputSpec.hostname, allowedSpec)) { + continue; + } + if (allowedSpec.protocol !== null) { + if (inputSpec.protocol === null || inputSpec.protocol !== allowedSpec.protocol) { + continue; + } + } + if (allowedSpec.explicitPort) { + if (!inputSpec.explicitPort || inputSpec.port !== allowedSpec.port) { + continue; + } + } + return true; + } + + return false; +} + /** Matches ErrorTypes.INVALID_BASE_URL — string literal avoids build-time dependency on data-provider */ const INVALID_BASE_URL_TYPE = 'invalid_base_url'; diff --git a/packages/api/src/mcp/ConnectionsRepository.ts b/packages/api/src/mcp/ConnectionsRepository.ts index 970e7ea4b9..6313faa8d4 100644 --- a/packages/api/src/mcp/ConnectionsRepository.ts +++ b/packages/api/src/mcp/ConnectionsRepository.ts @@ -77,12 +77,14 @@ export class ConnectionsRepository { await this.disconnect(serverName); } } + const registry = MCPServersRegistry.getInstance(); const connection = await MCPConnectionFactory.create( { serverName, serverConfig, dbSourced: !!(serverConfig as t.ParsedServerConfig).dbId, - useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(), + useSSRFProtection: registry.shouldEnableSSRFProtection(), + allowedDomains: registry.getAllowedDomains(), }, this.oauthOpts, ); diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index 0fc86e0315..b5b3d61bf0 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -30,6 +30,7 @@ export class MCPConnectionFactory { protected readonly logPrefix: string; protected readonly useOAuth: boolean; protected readonly useSSRFProtection: boolean; + protected readonly allowedDomains?: string[] | null; // OAuth-related properties (only set when useOAuth is true) protected readonly userId?: string; @@ -197,6 +198,7 @@ export class MCPConnectionFactory { this.serverName = basic.serverName; this.useOAuth = !!oauth?.useOAuth; this.useSSRFProtection = basic.useSSRFProtection === true; + this.allowedDomains = basic.allowedDomains; this.connectionTimeout = oauth?.connectionTimeout; this.logPrefix = oauth?.user ? `[MCP][${basic.serverName}][${oauth.user.id}]` @@ -297,6 +299,7 @@ export class MCPConnectionFactory { }, this.serverConfig.oauth_headers ?? {}, this.serverConfig.oauth, + this.allowedDomains, ); }; } @@ -340,6 +343,7 @@ export class MCPConnectionFactory { this.userId!, config?.oauth_headers ?? {}, config?.oauth, + this.allowedDomains, ); if (existingFlow) { @@ -603,6 +607,7 @@ export class MCPConnectionFactory { this.userId!, this.serverConfig.oauth_headers ?? {}, this.serverConfig.oauth, + this.allowedDomains, ); // Store flow state BEFORE redirecting so the callback can find it diff --git a/packages/api/src/mcp/MCPManager.ts b/packages/api/src/mcp/MCPManager.ts index 6fdf45c27a..afb6c68796 100644 --- a/packages/api/src/mcp/MCPManager.ts +++ b/packages/api/src/mcp/MCPManager.ts @@ -100,13 +100,16 @@ export class MCPManager extends UserConnectionManager { const useOAuth = Boolean(serverConfig.requiresOAuth || serverConfig.oauthMetadata); - const useSSRFProtection = MCPServersRegistry.getInstance().shouldEnableSSRFProtection(); + const registry = MCPServersRegistry.getInstance(); + const useSSRFProtection = registry.shouldEnableSSRFProtection(); + const allowedDomains = registry.getAllowedDomains(); const dbSourced = !!serverConfig.dbId; const basic: t.BasicConnectionOptions = { dbSourced, serverName, serverConfig, useSSRFProtection, + allowedDomains, }; if (!useOAuth) { diff --git a/packages/api/src/mcp/UserConnectionManager.ts b/packages/api/src/mcp/UserConnectionManager.ts index 76523fc0fc..2e9d5be467 100644 --- a/packages/api/src/mcp/UserConnectionManager.ts +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -153,12 +153,14 @@ export abstract class UserConnectionManager { logger.info(`[MCP][User: ${userId}][${serverName}] Establishing new connection`); try { + const registry = MCPServersRegistry.getInstance(); connection = await MCPConnectionFactory.create( { serverConfig: config, serverName: serverName, dbSourced: !!config.dbId, - useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(), + useSSRFProtection: registry.shouldEnableSSRFProtection(), + allowedDomains: registry.getAllowedDomains(), }, { useOAuth: true, diff --git a/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts b/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts index 98e15eca18..7a93960765 100644 --- a/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts +++ b/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts @@ -25,6 +25,7 @@ const mockRegistryInstance = { getServerConfig: jest.fn(), getAllServerConfigs: jest.fn(), shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), + getAllowedDomains: jest.fn().mockReturnValue(null), }; jest.mock('../registry/MCPServersRegistry', () => ({ @@ -110,6 +111,7 @@ describe('ConnectionsRepository', () => { serverName: 'server1', serverConfig: mockServerConfigs.server1, useSSRFProtection: false, + allowedDomains: null, dbSourced: false, }, undefined, @@ -133,6 +135,7 @@ describe('ConnectionsRepository', () => { serverName: 'server1', serverConfig: mockServerConfigs.server1, useSSRFProtection: false, + allowedDomains: null, dbSourced: false, }, undefined, @@ -173,6 +176,7 @@ describe('ConnectionsRepository', () => { serverName: 'server1', serverConfig: configWithCachedAt, useSSRFProtection: false, + allowedDomains: null, dbSourced: false, }, undefined, diff --git a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts index bceb23b246..23bfa89d56 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts @@ -269,6 +269,7 @@ describe('MCPConnectionFactory', () => { 'user123', {}, undefined, + undefined, ); // initFlow must be awaited BEFORE the redirect to guarantee state is stored diff --git a/packages/api/src/mcp/__tests__/MCPManager.test.ts b/packages/api/src/mcp/__tests__/MCPManager.test.ts index bf63a6af3c..dd1ead0dd9 100644 --- a/packages/api/src/mcp/__tests__/MCPManager.test.ts +++ b/packages/api/src/mcp/__tests__/MCPManager.test.ts @@ -34,6 +34,7 @@ const mockRegistryInstance = { getAllServerConfigs: jest.fn(), getOAuthServers: jest.fn(), shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), + getAllowedDomains: jest.fn().mockReturnValue(null), }; jest.mock('~/mcp/registry/MCPServersRegistry', () => ({ diff --git a/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts index 85febb3ece..cb6187ab45 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts @@ -82,6 +82,7 @@ describe('MCP OAuth Race Condition Fixes', () => { .mockReturnValue({ getServerConfig: jest.fn().mockResolvedValue(mockConfig), shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), + getAllowedDomains: jest.fn().mockReturnValue(null), }); const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory'); @@ -147,6 +148,7 @@ describe('MCP OAuth Race Condition Fixes', () => { .mockReturnValue({ getServerConfig: jest.fn().mockResolvedValue(mockConfig), shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), + getAllowedDomains: jest.fn().mockReturnValue(null), }); const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory'); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts index a5188e24b0..a2d0440d42 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts @@ -7,6 +7,9 @@ * * 2. redirect_uri manipulation — validates that user-supplied redirect_uri * is ignored in favor of the server-controlled default. + * + * 3. allowedDomains SSRF exemption — validates that admin-configured allowedDomains + * exempts trusted domains from SSRF checks, including auto-discovery paths. */ import * as http from 'http'; @@ -226,3 +229,214 @@ describe('MCP OAuth redirect_uri enforcement', () => { expect(authUrl.searchParams.get('redirect_uri')).not.toBe(attackerRedirectUri); }); }); + +describe('MCP OAuth allowedDomains SSRF exemption for admin-trusted hosts', () => { + it('should allow private authorization_url when hostname is in allowedDomains', async () => { + const result = await MCPOAuthHandler.initiateOAuthFlow( + 'internal-server', + 'https://speedy-mcp.company.com/', + 'user-1', + {}, + { + authorization_url: 'http://10.0.0.1/authorize', + token_url: 'http://10.0.0.1/token', + client_id: 'client', + client_secret: 'secret', + }, + ['10.0.0.1'], + ); + + expect(result.authorizationUrl).toContain('10.0.0.1/authorize'); + }); + + it('should allow private token_url when hostname matches wildcard allowedDomains', async () => { + const result = await MCPOAuthHandler.initiateOAuthFlow( + 'internal-server', + 'https://speedy-mcp.company.com/', + 'user-1', + {}, + { + authorization_url: 'https://auth.company.internal/authorize', + token_url: 'https://auth.company.internal/token', + client_id: 'client', + client_secret: 'secret', + }, + ['*.company.internal'], + ); + + expect(result.authorizationUrl).toContain('auth.company.internal/authorize'); + }); + + it('should still reject private URLs when allowedDomains does not match', async () => { + await expect( + MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://mcp.example.com/', + 'user-1', + {}, + { + authorization_url: 'http://169.254.169.254/authorize', + token_url: 'https://auth.example.com/token', + client_id: 'client', + client_secret: 'secret', + }, + ['safe.example.com'], + ), + ).rejects.toThrow(/targets a blocked address/); + }); + + it('should still reject when allowedDomains is empty', async () => { + await expect( + MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://mcp.example.com/', + 'user-1', + {}, + { + authorization_url: 'http://10.0.0.1/authorize', + token_url: 'https://auth.example.com/token', + client_id: 'client', + client_secret: 'secret', + }, + [], + ), + ).rejects.toThrow(/targets a blocked address/); + }); + + it('should allow private revocationEndpoint when hostname is in allowedDomains', async () => { + const mockFetch = jest.fn().mockResolvedValue({ + ok: true, + status: 200, + } as Response); + const originalFetch = global.fetch; + global.fetch = mockFetch; + + try { + await MCPOAuthHandler.revokeOAuthToken( + 'internal-server', + 'some-token', + 'access', + { + serverUrl: 'https://internal.corp.net/', + clientId: 'client', + clientSecret: 'secret', + revocationEndpoint: 'http://10.0.0.1/revoke', + }, + {}, + ['10.0.0.1'], + ); + + expect(mockFetch).toHaveBeenCalled(); + } finally { + global.fetch = originalFetch; + } + }); + + it('should allow localhost token_url in refreshOAuthTokens when localhost is in allowedDomains', async () => { + const mockFetch = jest.fn().mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: 'new-access-token', + token_type: 'Bearer', + expires_in: 3600, + }), + } as Response); + const originalFetch = global.fetch; + global.fetch = mockFetch; + + try { + const tokens = await MCPOAuthHandler.refreshOAuthTokens( + 'old-refresh-token', + { + serverName: 'local-server', + serverUrl: 'http://localhost:8080/', + clientInfo: { + client_id: 'client-id', + client_secret: 'client-secret', + redirect_uris: ['http://localhost:3080/callback'], + }, + }, + {}, + { + token_url: 'http://localhost:8080/token', + client_id: 'client-id', + client_secret: 'client-secret', + }, + ['localhost'], + ); + + expect(tokens.access_token).toBe('new-access-token'); + expect(mockFetch).toHaveBeenCalled(); + } finally { + global.fetch = originalFetch; + } + }); + + describe('auto-discovery path with allowedDomains', () => { + let discoveryServer: OAuthTestServer; + + beforeEach(async () => { + discoveryServer = await createOAuthMCPServer({ + tokenTTLMs: 60000, + issueRefreshTokens: true, + }); + }); + + afterEach(async () => { + await discoveryServer.close(); + }); + + it('should allow auto-discovered OAuth endpoints when server IP is in allowedDomains', async () => { + const result = await MCPOAuthHandler.initiateOAuthFlow( + 'discovery-server', + discoveryServer.url, + 'user-1', + {}, + undefined, + ['127.0.0.1'], + ); + + expect(result.authorizationUrl).toContain('127.0.0.1'); + expect(result.flowId).toBeTruthy(); + }); + + it('should reject auto-discovered endpoints when allowedDomains does not cover server IP', async () => { + await expect( + MCPOAuthHandler.initiateOAuthFlow( + 'discovery-server', + discoveryServer.url, + 'user-1', + {}, + undefined, + ['safe.example.com'], + ), + ).rejects.toThrow(/targets a blocked address/); + }); + + it('should allow auto-discovered token_url in refreshOAuthTokens branch 3 (no clientInfo/config)', async () => { + const code = await discoveryServer.getAuthCode(); + const tokenRes = await fetch(`${discoveryServer.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const initial = (await tokenRes.json()) as { + access_token: string; + refresh_token: string; + }; + + const tokens = await MCPOAuthHandler.refreshOAuthTokens( + initial.refresh_token, + { + serverName: 'discovery-refresh-server', + serverUrl: discoveryServer.url, + }, + {}, + undefined, + ['127.0.0.1'], + ); + + expect(tokens.access_token).toBeTruthy(); + }); + }); +}); diff --git a/packages/api/src/mcp/oauth/handler.ts b/packages/api/src/mcp/oauth/handler.ts index 8d863bfe79..0a9154ff35 100644 --- a/packages/api/src/mcp/oauth/handler.ts +++ b/packages/api/src/mcp/oauth/handler.ts @@ -24,7 +24,7 @@ import { selectRegistrationAuthMethod, inferClientAuthMethod, } from './methods'; -import { isSSRFTarget, resolveHostnameSSRF } from '~/auth'; +import { isSSRFTarget, resolveHostnameSSRF, isOAuthUrlAllowed } from '~/auth'; import { sanitizeUrlForLogging } from '~/mcp/utils'; /** Type for the OAuth metadata from the SDK */ @@ -123,6 +123,7 @@ export class MCPOAuthHandler { private static async discoverMetadata( serverUrl: string, oauthHeaders: Record, + allowedDomains?: string[] | null, ): Promise<{ metadata: OAuthMetadata; resourceMetadata?: OAuthProtectedResourceMetadata; @@ -146,7 +147,7 @@ export class MCPOAuthHandler { if (resourceMetadata?.authorization_servers?.length) { const discoveredAuthServer = resourceMetadata.authorization_servers[0]; - await this.validateOAuthUrl(discoveredAuthServer, 'authorization_server'); + await this.validateOAuthUrl(discoveredAuthServer, 'authorization_server', allowedDomains); authServerUrl = new URL(discoveredAuthServer); logger.debug( `[MCPOAuth] Found authorization server from resource metadata: ${authServerUrl}`, @@ -206,11 +207,17 @@ export class MCPOAuthHandler { const endpointChecks: Promise[] = []; if (metadata.registration_endpoint) { endpointChecks.push( - this.validateOAuthUrl(metadata.registration_endpoint, 'registration_endpoint'), + this.validateOAuthUrl( + metadata.registration_endpoint, + 'registration_endpoint', + allowedDomains, + ), ); } if (metadata.token_endpoint) { - endpointChecks.push(this.validateOAuthUrl(metadata.token_endpoint, 'token_endpoint')); + endpointChecks.push( + this.validateOAuthUrl(metadata.token_endpoint, 'token_endpoint', allowedDomains), + ); } if (endpointChecks.length > 0) { await Promise.all(endpointChecks); @@ -360,6 +367,7 @@ export class MCPOAuthHandler { userId: string, oauthHeaders: Record, config?: MCPOptions['oauth'], + allowedDomains?: string[] | null, ): Promise<{ authorizationUrl: string; flowId: string; flowMetadata: MCPOAuthFlowMetadata }> { logger.debug( `[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${sanitizeUrlForLogging(serverUrl)}`, @@ -375,8 +383,8 @@ export class MCPOAuthHandler { logger.debug(`[MCPOAuth] Using pre-configured OAuth settings for ${serverName}`); await Promise.all([ - this.validateOAuthUrl(config.authorization_url, 'authorization_url'), - this.validateOAuthUrl(config.token_url, 'token_url'), + this.validateOAuthUrl(config.authorization_url, 'authorization_url', allowedDomains), + this.validateOAuthUrl(config.token_url, 'token_url', allowedDomains), ]); const skipCodeChallengeCheck = @@ -477,6 +485,7 @@ export class MCPOAuthHandler { const { metadata, resourceMetadata, authServerUrl } = await this.discoverMetadata( serverUrl, oauthHeaders, + allowedDomains, ); logger.debug( @@ -588,7 +597,11 @@ export class MCPOAuthHandler { } /** - * Completes the OAuth flow by exchanging the authorization code for tokens + * Completes the OAuth flow by exchanging the authorization code for tokens. + * + * `allowedDomains` is intentionally absent: all URLs used here (serverUrl, + * token_endpoint) originate from {@link MCPOAuthFlowMetadata} that was + * SSRF-validated during {@link initiateOAuthFlow}. No new URL resolution occurs. */ static async completeOAuthFlow( flowId: string, @@ -692,8 +705,20 @@ export class MCPOAuthHandler { return randomBytes(32).toString('base64url'); } - /** Validates an OAuth URL is not targeting a private/internal address */ - private static async validateOAuthUrl(url: string, fieldName: string): Promise { + /** + * Validates an OAuth URL is not targeting a private/internal address. + * Skipped when the full URL (hostname + protocol + port) matches an admin-trusted + * allowedDomains entry, honoring protocol/port constraints when the admin specifies them. + */ + private static async validateOAuthUrl( + url: string, + fieldName: string, + allowedDomains?: string[] | null, + ): Promise { + if (isOAuthUrlAllowed(url, allowedDomains)) { + return; + } + let hostname: string; try { hostname = new URL(url).hostname; @@ -799,6 +824,7 @@ export class MCPOAuthHandler { metadata: { serverName: string; serverUrl?: string; clientInfo?: OAuthClientInformation }, oauthHeaders: Record, config?: MCPOptions['oauth'], + allowedDomains?: string[] | null, ): Promise { logger.debug(`[MCPOAuth] Refreshing tokens for ${metadata.serverName}`); @@ -824,7 +850,7 @@ export class MCPOAuthHandler { let tokenUrl: string; let authMethods: string[] | undefined; if (config?.token_url) { - await this.validateOAuthUrl(config.token_url, 'token_url'); + await this.validateOAuthUrl(config.token_url, 'token_url', allowedDomains); tokenUrl = config.token_url; authMethods = config.token_endpoint_auth_methods_supported; } else if (!metadata.serverUrl) { @@ -851,7 +877,7 @@ export class MCPOAuthHandler { tokenUrl = oauthMetadata.token_endpoint; authMethods = oauthMetadata.token_endpoint_auth_methods_supported; } - await this.validateOAuthUrl(tokenUrl, 'token_url'); + await this.validateOAuthUrl(tokenUrl, 'token_url', allowedDomains); } const body = new URLSearchParams({ @@ -928,7 +954,7 @@ export class MCPOAuthHandler { if (config?.token_url && config?.client_id) { logger.debug(`[MCPOAuth] Using pre-configured OAuth settings for token refresh`); - await this.validateOAuthUrl(config.token_url, 'token_url'); + await this.validateOAuthUrl(config.token_url, 'token_url', allowedDomains); const tokenUrl = new URL(config.token_url); const body = new URLSearchParams({ @@ -1026,7 +1052,7 @@ export class MCPOAuthHandler { } else { tokenUrl = new URL(oauthMetadata.token_endpoint); } - await this.validateOAuthUrl(tokenUrl.href, 'token_url'); + await this.validateOAuthUrl(tokenUrl.href, 'token_url', allowedDomains); const body = new URLSearchParams({ grant_type: 'refresh_token', @@ -1075,9 +1101,14 @@ export class MCPOAuthHandler { revocationEndpointAuthMethodsSupported?: string[]; }, oauthHeaders: Record = {}, + allowedDomains?: string[] | null, ): Promise { if (metadata.revocationEndpoint != null) { - await this.validateOAuthUrl(metadata.revocationEndpoint, 'revocation_endpoint'); + await this.validateOAuthUrl( + metadata.revocationEndpoint, + 'revocation_endpoint', + allowedDomains, + ); } const revokeUrl: URL = metadata.revocationEndpoint != null diff --git a/packages/api/src/mcp/registry/MCPServerInspector.ts b/packages/api/src/mcp/registry/MCPServerInspector.ts index a477d9b412..7f31211680 100644 --- a/packages/api/src/mcp/registry/MCPServerInspector.ts +++ b/packages/api/src/mcp/registry/MCPServerInspector.ts @@ -20,6 +20,7 @@ export class MCPServerInspector { private readonly config: t.ParsedServerConfig, private connection: MCPConnection | undefined, private readonly useSSRFProtection: boolean = false, + private readonly allowedDomains?: string[] | null, ) {} /** @@ -46,7 +47,13 @@ export class MCPServerInspector { const useSSRFProtection = !Array.isArray(allowedDomains) || allowedDomains.length === 0; const start = Date.now(); - const inspector = new MCPServerInspector(serverName, rawConfig, connection, useSSRFProtection); + const inspector = new MCPServerInspector( + serverName, + rawConfig, + connection, + useSSRFProtection, + allowedDomains, + ); await inspector.inspectServer(); inspector.config.initDuration = Date.now() - start; return inspector.config; @@ -68,6 +75,7 @@ export class MCPServerInspector { serverName: this.serverName, dbSourced: !!this.config.dbId, useSSRFProtection: this.useSSRFProtection, + allowedDomains: this.allowedDomains, }); } diff --git a/packages/api/src/mcp/types/index.ts b/packages/api/src/mcp/types/index.ts index bbdabb4428..0af10c7399 100644 --- a/packages/api/src/mcp/types/index.ts +++ b/packages/api/src/mcp/types/index.ts @@ -169,6 +169,7 @@ export interface BasicConnectionOptions { serverName: string; serverConfig: MCPOptions; useSSRFProtection?: boolean; + allowedDomains?: string[] | null; /** When true, only resolve customUserVars in processMCPEnv (for DB-stored servers) */ dbSourced?: boolean; }