diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index f194f361d3..9e3ed7a351 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -311,6 +311,87 @@ describe('MCP Routes', () => { expect(response.headers.location).toBe(`${basePath}/oauth/error?error=access_denied`); }); + describe('OAuth error callback failFlow', () => { + it('should fail the flow when OAuth error is received with valid CSRF cookie', async () => { + const flowId = 'test-user-id:test-server'; + const mockFlowManager = { + failFlow: jest.fn().mockResolvedValue(true), + }; + + getLogStores.mockReturnValueOnce({}); + require('~/config').getFlowStateManager.mockReturnValueOnce(mockFlowManager); + MCPOAuthHandler.resolveStateToFlowId.mockResolvedValueOnce(flowId); + + const csrfToken = generateTestCsrfToken(flowId); + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + error: 'invalid_client', + state: flowId, + }); + const basePath = getBasePath(); + + expect(response.status).toBe(302); + expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_client`); + expect(mockFlowManager.failFlow).toHaveBeenCalledWith( + flowId, + 'mcp_oauth', + 'invalid_client', + ); + }); + + it('should fail the flow when OAuth error is received with valid session cookie', async () => { + const flowId = 'test-user-id:test-server'; + const mockFlowManager = { + failFlow: jest.fn().mockResolvedValue(true), + }; + + getLogStores.mockReturnValueOnce({}); + require('~/config').getFlowStateManager.mockReturnValueOnce(mockFlowManager); + MCPOAuthHandler.resolveStateToFlowId.mockResolvedValueOnce(flowId); + + const sessionToken = generateTestCsrfToken('test-user-id'); + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_session=${sessionToken}`]) + .query({ + error: 'invalid_client', + state: flowId, + }); + const basePath = getBasePath(); + + expect(response.status).toBe(302); + expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_client`); + expect(mockFlowManager.failFlow).toHaveBeenCalledWith( + flowId, + 'mcp_oauth', + 'invalid_client', + ); + }); + + it('should NOT fail the flow when OAuth error is received without cookies (DoS prevention)', async () => { + const flowId = 'test-user-id:test-server'; + const mockFlowManager = { + failFlow: jest.fn(), + }; + + getLogStores.mockReturnValueOnce({}); + require('~/config').getFlowStateManager.mockReturnValueOnce(mockFlowManager); + MCPOAuthHandler.resolveStateToFlowId.mockResolvedValueOnce(flowId); + + const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ + error: 'invalid_client', + state: flowId, + }); + const basePath = getBasePath(); + + expect(response.status).toBe(302); + expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_client`); + expect(mockFlowManager.failFlow).not.toHaveBeenCalled(); + }); + }); + it('should redirect to error page when code is missing', async () => { const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ state: 'test-user-id:test-server', diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index c6496ad4b4..b747e6f5ed 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -149,6 +149,29 @@ router.get('/:serverName/oauth/callback', async (req, res) => { if (oauthError) { logger.error('[MCP OAuth] OAuth error received', { error: oauthError }); + // Gate failFlow behind callback validation to prevent DoS via leaked state + if (state && typeof state === 'string') { + try { + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + const flowId = await MCPOAuthHandler.resolveStateToFlowId(state, flowManager); + if (flowId) { + const flowParts = flowId.split(':'); + const [flowUserId] = flowParts; + const hasCsrf = validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH); + const hasSession = !hasCsrf && validateOAuthSession(req, flowUserId); + if (hasCsrf || hasSession) { + await flowManager.failFlow(flowId, 'mcp_oauth', String(oauthError)); + logger.debug('[MCP OAuth] Marked flow as FAILED with OAuth error', { + flowId, + error: oauthError, + }); + } + } + } catch (err) { + logger.debug('[MCP OAuth] Could not mark flow as failed', err); + } + } return res.redirect( `${basePath}/oauth/error?error=${encodeURIComponent(String(oauthError))}`, ); diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index dbb44740a9..ccff184d4d 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -23,7 +23,7 @@ const { getFlowStateManager, getMCPManager, } = require('~/config'); -const { findToken, createToken, updateToken } = require('~/models'); +const { findToken, createToken, updateToken, deleteTokens } = require('~/models'); const { getGraphApiToken } = require('./GraphTokenService'); const { reinitMCPServer } = require('./Tools/mcp'); const { getAppConfig } = require('./Config'); @@ -644,6 +644,7 @@ function createToolInstance({ findToken, createToken, updateToken, + deleteTokens, }, oauthStart, oauthEnd, diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index eb62514a4e..c517140b8a 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -351,10 +351,13 @@ export class MCPConnectionFactory { config?.oauth_headers ?? {}, config?.oauth, this.allowedDomains, + // Only reuse stored client when deleteTokens is available for stale-client cleanup + this.tokenMethods?.deleteTokens ? this.tokenMethods.findToken : undefined, ); if (existingFlow) { - const oldState = (existingFlow.metadata as MCPOAuthFlowMetadata)?.state; + const oldMeta = existingFlow.metadata as MCPOAuthFlowMetadata | undefined; + const oldState = oldMeta?.state; await this.flowManager!.deleteFlow(newFlowId, 'mcp_oauth'); if (oldState) { await MCPOAuthHandler.deleteStateMapping(oldState, this.flowManager!); @@ -368,9 +371,12 @@ export class MCPConnectionFactory { // Start monitoring in background — createFlow will find the existing PENDING state // written by initFlow above, so metadata arg is unused (pass {} to make that explicit) - this.flowManager!.createFlow(newFlowId, 'mcp_oauth', {}, this.signal).catch((error) => { - logger.debug(`${this.logPrefix} OAuth flow monitor ended`, error); - }); + this.flowManager!.createFlow(newFlowId, 'mcp_oauth', {}, this.signal).catch( + async (error) => { + logger.debug(`${this.logPrefix} OAuth flow monitor ended`, error); + await this.clearStaleClientIfRejected(flowMetadata.reusedStoredClient, error); + }, + ); if (this.oauthStart) { logger.info(`${this.logPrefix} OAuth flow started, issuing authorization URL`); @@ -412,7 +418,7 @@ export class MCPConnectionFactory { if (result?.tokens) { connection.emit('oauthHandled'); } else { - // OAuth failed, emit oauthFailed to properly reject the promise + await this.clearStaleClientIfRejected(result?.reusedStoredClient, result?.error); logger.warn(`${this.logPrefix} OAuth failed, emitting oauthFailed event`); connection.emit('oauthFailed', new Error('OAuth authentication failed')); } @@ -466,6 +472,49 @@ export class MCPConnectionFactory { } } + /** Clears stored client registration if the error indicates client rejection */ + private async clearStaleClientIfRejected( + reusedStoredClient: boolean | undefined, + error: unknown, + ): Promise { + if (!reusedStoredClient || !this.tokenMethods?.deleteTokens) { + return; + } + if (!MCPConnectionFactory.isClientRejection(error)) { + return; + } + await MCPTokenStorage.deleteClientRegistration({ + userId: this.userId!, + serverName: this.serverName, + deleteTokens: this.tokenMethods.deleteTokens, + }).catch((err) => { + logger.warn(`${this.logPrefix} Failed to clear stale client registration`, err); + }); + } + + /** + * Checks whether an error indicates the OAuth client registration was rejected. + * Includes RFC 6749 §5.2 standard codes (`invalid_client`, `unauthorized_client`) + * and known vendor-specific patterns (Okta: `client_id mismatch`, Auth0: `client not found`, + * generic: `unknown client`). + */ + static isClientRejection(error: unknown): boolean { + if (!error || typeof error !== 'object') { + return false; + } + if ('message' in error && typeof error.message === 'string') { + const msg = error.message.toLowerCase(); + return ( + msg.includes('invalid_client') || + msg.includes('unauthorized_client') || + msg.includes('client_id mismatch') || + msg.includes('client not found') || + msg.includes('unknown client') + ); + } + return false; + } + // Determines if an error indicates OAuth authentication is required private isOAuthError(error: unknown): boolean { if (!error || typeof error !== 'object') { @@ -505,6 +554,8 @@ export class MCPConnectionFactory { tokens: MCPOAuthTokens | null; clientInfo?: OAuthClientInformation; metadata?: OAuthMetadata; + reusedStoredClient?: boolean; + error?: unknown; } | null> { const serverUrl = (this.serverConfig as t.SSEOptions | t.StreamableHTTPOptions).url; logger.debug( @@ -519,6 +570,8 @@ export class MCPConnectionFactory { return null; } + let reusedStoredClient = false; + try { logger.debug(`${this.logPrefix} Checking for existing OAuth flow for ${this.serverName}...`); @@ -549,6 +602,7 @@ export class MCPConnectionFactory { await this.oauthStart(storedAuthUrl); } + reusedStoredClient = flowMeta?.reusedStoredClient === true; const tokens = await this.flowManager.createFlow(flowId, 'mcp_oauth', {}, this.signal); if (typeof this.oauthEnd === 'function') { await this.oauthEnd(); @@ -560,6 +614,7 @@ export class MCPConnectionFactory { tokens, clientInfo: flowMeta?.clientInfo, metadata: flowMeta?.metadata, + reusedStoredClient, }; } @@ -615,8 +670,11 @@ export class MCPConnectionFactory { this.serverConfig.oauth_headers ?? {}, this.serverConfig.oauth, this.allowedDomains, + this.tokenMethods?.deleteTokens ? this.tokenMethods.findToken : undefined, ); + reusedStoredClient = flowMetadata.reusedStoredClient === true; + // Store flow state BEFORE redirecting so the callback can find it const metadataWithUrl = { ...flowMetadata, authorizationUrl }; await this.flowManager.initFlow(newFlowId, 'mcp_oauth', metadataWithUrl); @@ -639,18 +697,15 @@ export class MCPConnectionFactory { } logger.info(`${this.logPrefix} OAuth flow completed, tokens received for ${this.serverName}`); - /** Client information from the flow metadata */ - const clientInfo = flowMetadata?.clientInfo; - const metadata = flowMetadata?.metadata; - return { tokens, - clientInfo, - metadata, + clientInfo: flowMetadata.clientInfo, + metadata: flowMetadata.metadata, + reusedStoredClient, }; } catch (error) { logger.error(`${this.logPrefix} Failed to complete OAuth flow for ${this.serverName}`, error); - return null; + return { tokens: null, reusedStoredClient, error }; } } } diff --git a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts index 326b77789e..d90ca5b345 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts @@ -5,7 +5,7 @@ import type { MCPOAuthTokens } from '~/mcp/oauth'; import type * as t from '~/mcp/types'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; import { MCPConnection } from '~/mcp/connection'; -import { MCPOAuthHandler } from '~/mcp/oauth'; +import { MCPOAuthHandler, MCPTokenStorage } from '~/mcp/oauth'; import { processMCPEnv } from '~/utils'; jest.mock('~/mcp/connection'); @@ -24,6 +24,7 @@ const mockLogger = logger as jest.Mocked; const mockProcessMCPEnv = processMCPEnv as jest.MockedFunction; const mockMCPConnection = MCPConnection as jest.MockedClass; const mockMCPOAuthHandler = MCPOAuthHandler as jest.Mocked; +const mockMCPTokenStorage = MCPTokenStorage as jest.Mocked; describe('MCPConnectionFactory', () => { let mockUser: IUser | undefined; @@ -270,6 +271,7 @@ describe('MCPConnectionFactory', () => { {}, undefined, undefined, + oauthOptions.tokenMethods.findToken, ); // initFlow must be awaited BEFORE the redirect to guarantee state is stored @@ -292,6 +294,78 @@ describe('MCPConnectionFactory', () => { ); }); + it('should clear stale client registration when returnOnOAuth flow fails with client rejection', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: { + ...mockServerConfig, + url: 'https://api.example.com', + type: 'sse' as const, + } as t.SSEOptions, + }; + + const deleteTokensSpy = jest.fn().mockResolvedValue({ acknowledged: true, deletedCount: 1 }); + const oauthOptions = { + useOAuth: true as const, + user: mockUser, + flowManager: mockFlowManager, + returnOnOAuth: true, + oauthStart: jest.fn(), + tokenMethods: { + findToken: jest.fn(), + createToken: jest.fn(), + updateToken: jest.fn(), + deleteTokens: deleteTokensSpy, + }, + }; + + const mockFlowData = { + authorizationUrl: 'https://auth.example.com', + flowId: 'flow123', + flowMetadata: { + serverName: 'test-server', + userId: 'user123', + serverUrl: 'https://api.example.com', + state: 'random-state', + clientInfo: { client_id: 'stale-client' }, + reusedStoredClient: true, + }, + }; + + mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData); + mockMCPTokenStorage.deleteClientRegistration.mockResolvedValue(undefined); + // createFlow rejects with invalid_client — simulating stale client rejection + mockFlowManager.createFlow.mockRejectedValue(new Error('invalid_client')); + mockConnectionInstance.isConnected.mockResolvedValue(false); + + let oauthRequiredHandler: (data: Record) => Promise; + mockConnectionInstance.on.mockImplementation((event, handler) => { + if (event === 'oauthRequired') { + oauthRequiredHandler = handler as (data: Record) => Promise; + } + return mockConnectionInstance; + }); + + try { + await MCPConnectionFactory.create(basicOptions, oauthOptions); + } catch { + // Expected + } + + await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' }); + + // Drain microtasks so the background .catch() handler completes + await new Promise((r) => setImmediate(r)); + + // deleteClientRegistration should have been called via clearStaleClientIfRejected + expect(mockMCPTokenStorage.deleteClientRegistration).toHaveBeenCalledWith( + expect.objectContaining({ + userId: 'user123', + serverName: 'test-server', + }), + ); + }); + it('should skip new OAuth flow initiation when a PENDING flow already exists (returnOnOAuth)', async () => { const basicOptions = { serverName: 'test-server', diff --git a/packages/api/src/mcp/__tests__/MCPOAuthClientRegistrationReuse.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthClientRegistrationReuse.test.ts new file mode 100644 index 0000000000..75cf4147b2 --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPOAuthClientRegistrationReuse.test.ts @@ -0,0 +1,508 @@ +/** + * Tests for MCP OAuth client registration reuse on reconnection. + * + * Documents the client_id mismatch bug in horizontally scaled deployments: + * + * When LibreChat runs with multiple replicas (e.g., 3 behind a load balancer), + * each replica independently calls registerClient() on the OAuth server's /register + * endpoint, getting a different client_id. The check-then-act race between the + * PENDING flow check and storing the flow state means that even with a shared + * Redis-backed flow store, replicas slip through before any has stored PENDING: + * + * Replica A: getFlowState() → null → initiateOAuthFlow() → registers client_A + * Replica B: getFlowState() → null → initiateOAuthFlow() → registers client_B + * Replica A: initFlow(metadata with client_A) → stored in Redis + * Replica B: initFlow(metadata with client_B) → OVERWRITES in Redis + * User completes OAuth in browser with client_A in the URL + * Callback reads Redis → finds client_B → token exchange fails: "client_id mismatch" + * + * The fix stabilizes reconnection flows: before calling registerClient(), check + * MongoDB (shared across replicas) for an existing client registration from a prior + * successful OAuth flow and reuse it. This eliminates redundant /register calls on + * reconnection. Note: the first-time concurrent auth race is NOT addressed by this + * fix and would require a distributed lock (e.g., Redis SETNX) around registration. + */ + +import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js'; +import type { OAuthTestServer } from './helpers/oauthTestServer'; +import { InMemoryTokenStore, createOAuthMCPServer } from './helpers/oauthTestServer'; +import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; +import { MCPOAuthHandler, MCPTokenStorage } from '~/mcp/oauth'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, + getTenantId: jest.fn(), + SYSTEM_TENANT_ID: '__SYSTEM__', + encryptV2: jest.fn(async (val: string) => `enc:${val}`), + decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), +})); + +jest.mock('~/auth', () => ({ + createSSRFSafeUndiciConnect: jest.fn(() => undefined), + resolveHostnameSSRF: jest.fn(async () => false), + isSSRFTarget: jest.fn(async () => false), + isOAuthUrlAllowed: jest.fn(() => true), +})); + +jest.mock('~/mcp/mcpConfig', () => ({ + mcpConfig: { CONNECTION_CHECK_TTL: 0, USER_CONNECTION_IDLE_TIMEOUT: 30 * 60 * 1000 }, +})); + +describe('MCPOAuthHandler - client registration reuse on reconnection', () => { + let server: OAuthTestServer | undefined; + let originalDomainServer: string | undefined; + + beforeEach(() => { + originalDomainServer = process.env.DOMAIN_SERVER; + process.env.DOMAIN_SERVER = 'http://localhost:3080'; + }); + + afterEach(async () => { + if (originalDomainServer !== undefined) { + process.env.DOMAIN_SERVER = originalDomainServer; + } else { + delete process.env.DOMAIN_SERVER; + } + if (server) { + await server.close(); + server = undefined; + } + jest.clearAllMocks(); + }); + + describe('Race condition reproduction: concurrent replicas re-register', () => { + it('should produce duplicate client registrations when two replicas initiate flows concurrently', async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + + const [resultA, resultB] = await Promise.all([ + MCPOAuthHandler.initiateOAuthFlow('test-server', server.url, 'user-1', {}), + MCPOAuthHandler.initiateOAuthFlow('test-server', server.url, 'user-1', {}), + ]); + + expect(resultA.authorizationUrl).toBeTruthy(); + expect(resultB.authorizationUrl).toBeTruthy(); + expect(server.registeredClients.size).toBe(2); + + const clientA = resultA.flowMetadata.clientInfo?.client_id; + const clientB = resultB.flowMetadata.clientInfo?.client_id; + expect(clientA).not.toBe(clientB); + }); + + it('should re-register on every sequential initiateOAuthFlow call (reconnections)', async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + + await MCPOAuthHandler.initiateOAuthFlow('test-server', server.url, 'user-1', {}); + expect(server.registeredClients.size).toBe(1); + + await MCPOAuthHandler.initiateOAuthFlow('test-server', server.url, 'user-1', {}); + expect(server.registeredClients.size).toBe(2); + + await MCPOAuthHandler.initiateOAuthFlow('test-server', server.url, 'user-1', {}); + expect(server.registeredClients.size).toBe(3); + }); + }); + + describe('Client reuse via findToken on reconnection', () => { + it('should reuse an existing client registration when findToken returns stored client info', async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + const tokenStore = new InMemoryTokenStore(); + + const firstResult = await MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + server.url, + 'user-1', + {}, + undefined, + undefined, + tokenStore.findToken, + ); + expect(server.registeredClients.size).toBe(1); + const firstClientId = firstResult.flowMetadata.clientInfo?.client_id; + + await MCPTokenStorage.storeTokens({ + userId: 'user-1', + serverName: 'test-server', + tokens: { access_token: 'test-token', token_type: 'Bearer' }, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + findToken: tokenStore.findToken, + clientInfo: firstResult.flowMetadata.clientInfo, + metadata: firstResult.flowMetadata.metadata, + }); + + const secondResult = await MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + server.url, + 'user-1', + {}, + undefined, + undefined, + tokenStore.findToken, + ); + + expect(server.registeredClients.size).toBe(1); + expect(secondResult.flowMetadata.clientInfo?.client_id).toBe(firstClientId); + expect(secondResult.flowMetadata.reusedStoredClient).toBe(true); + }); + + it('should reuse the same client when two reconnections fire concurrently with pre-seeded token', async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + const tokenStore = new InMemoryTokenStore(); + + const initialResult = await MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + server.url, + 'user-1', + {}, + undefined, + undefined, + tokenStore.findToken, + ); + const storedClientId = initialResult.flowMetadata.clientInfo?.client_id; + + await MCPTokenStorage.storeTokens({ + userId: 'user-1', + serverName: 'test-server', + tokens: { access_token: 'test-token', token_type: 'Bearer' }, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + findToken: tokenStore.findToken, + clientInfo: initialResult.flowMetadata.clientInfo, + metadata: initialResult.flowMetadata.metadata, + }); + + const [resultA, resultB] = await Promise.all([ + MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + server.url, + 'user-1', + {}, + undefined, + undefined, + tokenStore.findToken, + ), + MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + server.url, + 'user-1', + {}, + undefined, + undefined, + tokenStore.findToken, + ), + ]); + + // Both should reuse the stored client — only the initial registration should exist + expect(server.registeredClients.size).toBe(1); + expect(resultA.flowMetadata.clientInfo?.client_id).toBe(storedClientId); + expect(resultB.flowMetadata.clientInfo?.client_id).toBe(storedClientId); + }); + + it('should re-register when stored redirect_uri differs from current', async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + const tokenStore = new InMemoryTokenStore(); + + await MCPTokenStorage.storeTokens({ + userId: 'user-1', + serverName: 'test-server', + tokens: { access_token: 'old-token', token_type: 'Bearer' }, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + findToken: tokenStore.findToken, + clientInfo: { + client_id: 'old-client-id', + client_secret: 'old-secret', + redirect_uris: ['http://old-domain.com/api/mcp/test-server/oauth/callback'], + } as OAuthClientInformation & { redirect_uris: string[] }, + }); + + const result = await MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + server.url, + 'user-1', + {}, + undefined, + undefined, + tokenStore.findToken, + ); + + expect(server.registeredClients.size).toBe(1); + expect(result.flowMetadata.clientInfo?.client_id).not.toBe('old-client-id'); + }); + + it('should re-register when stored client has empty redirect_uris', async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + const tokenStore = new InMemoryTokenStore(); + + await MCPTokenStorage.storeTokens({ + userId: 'user-1', + serverName: 'test-server', + tokens: { access_token: 'old-token', token_type: 'Bearer' }, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + findToken: tokenStore.findToken, + clientInfo: { + client_id: 'empty-redirect-client', + client_secret: 'secret', + redirect_uris: [], + } as OAuthClientInformation & { redirect_uris: string[] }, + }); + + const result = await MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + server.url, + 'user-1', + {}, + undefined, + undefined, + tokenStore.findToken, + ); + + // Should NOT reuse the client with empty redirect_uris — must re-register + expect(server.registeredClients.size).toBe(1); + expect(result.flowMetadata.clientInfo?.client_id).not.toBe('empty-redirect-client'); + }); + + it('should fall back to registration when findToken lookup throws', async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + const failingFindToken = jest.fn().mockRejectedValue(new Error('DB connection lost')); + + const result = await MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + server.url, + 'user-1', + {}, + undefined, + undefined, + failingFindToken, + ); + + expect(server.registeredClients.size).toBe(1); + expect(result.flowMetadata.clientInfo?.client_id).toBeTruthy(); + }); + + it('should not reuse a stale client on retry after a failed flow', async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + const tokenStore = new InMemoryTokenStore(); + + // Seed a stored client with a client_id that the OAuth server doesn't recognize, + // but with matching issuer and redirect_uri so reuse logic accepts it + const serverIssuer = `http://127.0.0.1:${server.port}`; + await MCPTokenStorage.storeTokens({ + userId: 'user-1', + serverName: 'test-server', + tokens: { access_token: 'old-token', token_type: 'Bearer' }, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + findToken: tokenStore.findToken, + clientInfo: { + client_id: 'stale-client-that-oauth-server-deleted', + client_secret: 'stale-secret', + redirect_uris: ['http://localhost:3080/api/mcp/test-server/oauth/callback'], + } as OAuthClientInformation & { redirect_uris: string[] }, + metadata: { + issuer: serverIssuer, + authorization_endpoint: `${serverIssuer}/authorize`, + token_endpoint: `${serverIssuer}/token`, + }, + }); + + // First attempt: reuses the stale client (this is expected — we don't know it's stale yet) + const firstResult = await MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + server.url, + 'user-1', + {}, + undefined, + undefined, + tokenStore.findToken, + ); + expect(firstResult.flowMetadata.clientInfo?.client_id).toBe( + 'stale-client-that-oauth-server-deleted', + ); + expect(firstResult.flowMetadata.reusedStoredClient).toBe(true); + expect(server.registeredClients.size).toBe(0); + + // Simulate what MCPConnectionFactory does on failure when reusedStoredClient is set: + // clear the stored client registration so the next attempt does a fresh DCR + await MCPTokenStorage.deleteClientRegistration({ + userId: 'user-1', + serverName: 'test-server', + deleteTokens: tokenStore.deleteTokens, + }); + + // Second attempt (retry after failure): should do a fresh DCR + const secondResult = await MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + server.url, + 'user-1', + {}, + undefined, + undefined, + tokenStore.findToken, + ); + + expect(server.registeredClients.size).toBe(1); + expect(secondResult.flowMetadata.clientInfo?.client_id).not.toBe( + 'stale-client-that-oauth-server-deleted', + ); + expect(secondResult.flowMetadata.reusedStoredClient).toBeUndefined(); + }); + + it('should re-register when stored client was issued by a different authorization server', async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + const tokenStore = new InMemoryTokenStore(); + + // Seed a stored client that was registered with a different issuer + await MCPTokenStorage.storeTokens({ + userId: 'user-1', + serverName: 'test-server', + tokens: { access_token: 'old-token', token_type: 'Bearer' }, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + findToken: tokenStore.findToken, + clientInfo: { + client_id: 'old-issuer-client', + client_secret: 'secret', + redirect_uris: ['http://localhost:3080/api/mcp/test-server/oauth/callback'], + } as OAuthClientInformation & { redirect_uris: string[] }, + metadata: { + issuer: 'https://old-auth-server.example.com', + authorization_endpoint: 'https://old-auth-server.example.com/authorize', + token_endpoint: 'https://old-auth-server.example.com/token', + }, + }); + + const result = await MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + server.url, + 'user-1', + {}, + undefined, + undefined, + tokenStore.findToken, + ); + + // Should have registered a NEW client because the issuer changed + expect(server.registeredClients.size).toBe(1); + expect(result.flowMetadata.clientInfo?.client_id).not.toBe('old-issuer-client'); + expect(result.flowMetadata.reusedStoredClient).toBeUndefined(); + }); + + it('should not call getClientInfoAndMetadata when findToken is not provided', async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + const spy = jest.spyOn(MCPTokenStorage, 'getClientInfoAndMetadata'); + + await MCPOAuthHandler.initiateOAuthFlow('test-server', server.url, 'user-1', {}); + + expect(spy).not.toHaveBeenCalled(); + spy.mockRestore(); + }); + }); + + describe('isClientRejection', () => { + it('should detect invalid_client errors', () => { + expect(MCPConnectionFactory.isClientRejection(new Error('invalid_client'))).toBe(true); + expect( + MCPConnectionFactory.isClientRejection( + new Error('OAuth token exchange failed: invalid_client'), + ), + ).toBe(true); + }); + + it('should detect unauthorized_client errors', () => { + expect(MCPConnectionFactory.isClientRejection(new Error('unauthorized_client'))).toBe(true); + }); + + it('should detect client_id mismatch errors', () => { + expect( + MCPConnectionFactory.isClientRejection( + new Error('Token exchange rejected: client_id mismatch'), + ), + ).toBe(true); + }); + + it('should detect client not found errors', () => { + expect(MCPConnectionFactory.isClientRejection(new Error('client not found'))).toBe(true); + expect(MCPConnectionFactory.isClientRejection(new Error('unknown client'))).toBe(true); + }); + + it('should not match unrelated errors', () => { + expect(MCPConnectionFactory.isClientRejection(new Error('timeout'))).toBe(false); + expect(MCPConnectionFactory.isClientRejection(new Error('Flow state not found'))).toBe(false); + expect(MCPConnectionFactory.isClientRejection(new Error('user denied access'))).toBe(false); + expect(MCPConnectionFactory.isClientRejection(null)).toBe(false); + expect(MCPConnectionFactory.isClientRejection(undefined)).toBe(false); + }); + }); + + describe('Token exchange with enforced client_id', () => { + it('should reject token exchange when client_id does not match registered client', async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000, enforceClientId: true }); + + // Register a real client via DCR + const regRes = await fetch(`${server.url}register`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }), + }); + const registered = (await regRes.json()) as { client_id: string }; + + // Get an auth code bound to the registered client_id + const authRes = await fetch( + `${server.url}authorize?redirect_uri=http://localhost/callback&state=s1&client_id=${registered.client_id}`, + { redirect: 'manual' }, + ); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code'); + + // Try to exchange the code with a DIFFERENT (stale) client_id + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}&client_id=stale-client-id`, + }); + + expect(tokenRes.status).toBe(401); + const body = (await tokenRes.json()) as { error: string; error_description?: string }; + expect(body.error).toBe('invalid_client'); + + // Verify isClientRejection would match this error + const errorMsg = body.error_description ?? body.error; + expect(MCPConnectionFactory.isClientRejection(new Error(errorMsg))).toBe(true); + }); + + it('should accept token exchange when client_id matches', async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000, enforceClientId: true }); + + const regRes = await fetch(`${server.url}register`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }), + }); + const registered = (await regRes.json()) as { client_id: string }; + + const authRes = await fetch( + `${server.url}authorize?redirect_uri=http://localhost/callback&state=s1&client_id=${registered.client_id}`, + { redirect: 'manual' }, + ); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code'); + + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}&client_id=${registered.client_id}`, + }); + + expect(tokenRes.status).toBe(200); + const body = (await tokenRes.json()) as { access_token: string }; + expect(body.access_token).toBeTruthy(); + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/handler.test.ts b/packages/api/src/mcp/__tests__/handler.test.ts index 31665ce8f7..87de316d17 100644 --- a/packages/api/src/mcp/__tests__/handler.test.ts +++ b/packages/api/src/mcp/__tests__/handler.test.ts @@ -20,6 +20,12 @@ jest.mock('@modelcontextprotocol/sdk/client/auth.js', () => ({ exchangeAuthorization: jest.fn(), })); +jest.mock('../../mcp/oauth/tokens', () => ({ + MCPTokenStorage: { + getClientInfoAndMetadata: jest.fn(), + }, +})); + import { startAuthorization, discoverAuthorizationServerMetadata, @@ -27,6 +33,7 @@ import { registerClient, exchangeAuthorization, } from '@modelcontextprotocol/sdk/client/auth.js'; +import { MCPTokenStorage } from '../../mcp/oauth/tokens'; import { FlowStateManager } from '../../flow/manager'; const mockStartAuthorization = startAuthorization as jest.MockedFunction; @@ -42,6 +49,10 @@ const mockRegisterClient = registerClient as jest.MockedFunction; +const mockGetClientInfoAndMetadata = + MCPTokenStorage.getClientInfoAndMetadata as jest.MockedFunction< + typeof MCPTokenStorage.getClientInfoAndMetadata + >; describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { const mockServerName = 'test-server'; @@ -1391,6 +1402,348 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { }); }); + describe('Client Registration Reuse', () => { + const originalFetch = global.fetch; + const mockFetch = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + global.fetch = mockFetch as unknown as typeof fetch; + mockFetch.mockResolvedValue({ ok: true, json: async () => ({}) } as Response); + process.env.DOMAIN_SERVER = 'http://localhost:3080'; + }); + + afterAll(() => { + global.fetch = originalFetch; + }); + + const mockFindToken = jest.fn(); + + it('should reuse existing client registration when findToken is provided and client exists', async () => { + const existingClientInfo = { + client_id: 'existing-client-id', + client_secret: 'existing-client-secret', + redirect_uris: ['http://localhost:3080/api/mcp/test-server/oauth/callback'], + token_endpoint_auth_method: 'client_secret_basic', + }; + + mockGetClientInfoAndMetadata.mockResolvedValueOnce({ + clientInfo: existingClientInfo, + clientMetadata: { issuer: 'https://example.com' }, + }); + + // Mock resource metadata discovery to fail + mockDiscoverOAuthProtectedResourceMetadata.mockRejectedValueOnce( + new Error('No resource metadata'), + ); + + // Mock authorization server metadata discovery + mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce({ + issuer: 'https://example.com', + authorization_endpoint: 'https://example.com/authorize', + token_endpoint: 'https://example.com/token', + registration_endpoint: 'https://example.com/register', + response_types_supported: ['code'], + jwks_uri: 'https://example.com/.well-known/jwks.json', + subject_types_supported: ['public'], + id_token_signing_alg_values_supported: ['RS256'], + } as AuthorizationServerMetadata); + + mockStartAuthorization.mockResolvedValueOnce({ + authorizationUrl: new URL('https://example.com/authorize?client_id=existing-client-id'), + codeVerifier: 'test-code-verifier', + }); + + const result = await MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://example.com/mcp', + 'user-123', + {}, + undefined, + undefined, + mockFindToken, + ); + + // Should NOT have called registerClient since we reused the existing one + expect(mockRegisterClient).not.toHaveBeenCalled(); + + // Should have used the existing client info for startAuthorization + expect(mockStartAuthorization).toHaveBeenCalledWith( + 'https://example.com/mcp', + expect.objectContaining({ + clientInformation: existingClientInfo, + }), + ); + + expect(result.authorizationUrl).toBeDefined(); + expect(result.flowId).toBeDefined(); + }); + + it('should register a new client when findToken is provided but no existing registration found', async () => { + mockGetClientInfoAndMetadata.mockResolvedValueOnce(null); + + // Mock resource metadata discovery to fail + mockDiscoverOAuthProtectedResourceMetadata.mockRejectedValueOnce( + new Error('No resource metadata'), + ); + + // Mock authorization server metadata discovery + mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce({ + issuer: 'https://example.com', + authorization_endpoint: 'https://example.com/authorize', + token_endpoint: 'https://example.com/token', + registration_endpoint: 'https://example.com/register', + response_types_supported: ['code'], + jwks_uri: 'https://example.com/.well-known/jwks.json', + subject_types_supported: ['public'], + id_token_signing_alg_values_supported: ['RS256'], + } as AuthorizationServerMetadata); + + mockRegisterClient.mockResolvedValueOnce({ + client_id: 'new-client-id', + client_secret: 'new-client-secret', + redirect_uris: ['http://localhost:3080/api/mcp/test-server/oauth/callback'], + logo_uri: undefined, + tos_uri: undefined, + }); + + mockStartAuthorization.mockResolvedValueOnce({ + authorizationUrl: new URL('https://example.com/authorize?client_id=new-client-id'), + codeVerifier: 'test-code-verifier', + }); + + await MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://example.com/mcp', + 'user-123', + {}, + undefined, + undefined, + mockFindToken, + ); + + // Should have called registerClient since no existing registration was found + expect(mockRegisterClient).toHaveBeenCalled(); + }); + + it('should register a new client when findToken is not provided', async () => { + // Mock resource metadata discovery to fail + mockDiscoverOAuthProtectedResourceMetadata.mockRejectedValueOnce( + new Error('No resource metadata'), + ); + + // Mock authorization server metadata discovery + mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce({ + issuer: 'https://example.com', + authorization_endpoint: 'https://example.com/authorize', + token_endpoint: 'https://example.com/token', + registration_endpoint: 'https://example.com/register', + response_types_supported: ['code'], + jwks_uri: 'https://example.com/.well-known/jwks.json', + subject_types_supported: ['public'], + id_token_signing_alg_values_supported: ['RS256'], + } as AuthorizationServerMetadata); + + mockRegisterClient.mockResolvedValueOnce({ + client_id: 'new-client-id', + client_secret: 'new-client-secret', + redirect_uris: ['http://localhost:3080/api/mcp/test-server/oauth/callback'], + logo_uri: undefined, + tos_uri: undefined, + }); + + mockStartAuthorization.mockResolvedValueOnce({ + authorizationUrl: new URL('https://example.com/authorize?client_id=new-client-id'), + codeVerifier: 'test-code-verifier', + }); + + // No findToken passed + await MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://example.com/mcp', + 'user-123', + {}, + undefined, + ); + + // Should NOT have tried to look up existing registration + expect(mockGetClientInfoAndMetadata).not.toHaveBeenCalled(); + + // Should have called registerClient + expect(mockRegisterClient).toHaveBeenCalled(); + }); + + it('should fall back to registration when getClientInfoAndMetadata throws', async () => { + mockGetClientInfoAndMetadata.mockRejectedValueOnce(new Error('DB error')); + + // Mock resource metadata discovery to fail + mockDiscoverOAuthProtectedResourceMetadata.mockRejectedValueOnce( + new Error('No resource metadata'), + ); + + // Mock authorization server metadata discovery + mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce({ + issuer: 'https://example.com', + authorization_endpoint: 'https://example.com/authorize', + token_endpoint: 'https://example.com/token', + registration_endpoint: 'https://example.com/register', + response_types_supported: ['code'], + jwks_uri: 'https://example.com/.well-known/jwks.json', + subject_types_supported: ['public'], + id_token_signing_alg_values_supported: ['RS256'], + } as AuthorizationServerMetadata); + + mockRegisterClient.mockResolvedValueOnce({ + client_id: 'new-client-id', + client_secret: 'new-client-secret', + redirect_uris: ['http://localhost:3080/api/mcp/test-server/oauth/callback'], + logo_uri: undefined, + tos_uri: undefined, + }); + + mockStartAuthorization.mockResolvedValueOnce({ + authorizationUrl: new URL('https://example.com/authorize?client_id=new-client-id'), + codeVerifier: 'test-code-verifier', + }); + + await MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://example.com/mcp', + 'user-123', + {}, + undefined, + undefined, + mockFindToken, + ); + + // Should have fallen back to registerClient + expect(mockRegisterClient).toHaveBeenCalled(); + }); + + it('should re-register when stored redirect_uri differs from current configuration', async () => { + const existingClientInfo = { + client_id: 'existing-client-id', + client_secret: 'existing-client-secret', + redirect_uris: ['http://old-domain.com/api/mcp/test-server/oauth/callback'], + token_endpoint_auth_method: 'client_secret_basic', + }; + + mockGetClientInfoAndMetadata.mockResolvedValueOnce({ + clientInfo: existingClientInfo, + clientMetadata: {}, + }); + + mockDiscoverOAuthProtectedResourceMetadata.mockRejectedValueOnce( + new Error('No resource metadata'), + ); + + mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce({ + issuer: 'https://example.com', + authorization_endpoint: 'https://example.com/authorize', + token_endpoint: 'https://example.com/token', + registration_endpoint: 'https://example.com/register', + response_types_supported: ['code'], + jwks_uri: 'https://example.com/.well-known/jwks.json', + subject_types_supported: ['public'], + id_token_signing_alg_values_supported: ['RS256'], + } as AuthorizationServerMetadata); + + mockRegisterClient.mockResolvedValueOnce({ + client_id: 'new-client-id', + client_secret: 'new-client-secret', + redirect_uris: ['http://localhost:3080/api/mcp/test-server/oauth/callback'], + logo_uri: undefined, + tos_uri: undefined, + }); + + mockStartAuthorization.mockResolvedValueOnce({ + authorizationUrl: new URL('https://example.com/authorize?client_id=new-client-id'), + codeVerifier: 'test-code-verifier', + }); + + await MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://example.com/mcp', + 'user-123', + {}, + undefined, + undefined, + mockFindToken, + ); + + expect(mockRegisterClient).toHaveBeenCalled(); + expect(mockStartAuthorization).toHaveBeenCalledWith( + 'https://example.com/mcp', + expect.objectContaining({ + clientInformation: expect.objectContaining({ + client_id: 'new-client-id', + }), + }), + ); + }); + + it('should re-register when stored client has empty redirect_uris', async () => { + const existingClientInfo = { + client_id: 'empty-redirect-client', + client_secret: 'secret', + redirect_uris: [], + }; + + mockGetClientInfoAndMetadata.mockResolvedValueOnce({ + clientInfo: existingClientInfo, + clientMetadata: {}, + }); + + mockDiscoverOAuthProtectedResourceMetadata.mockRejectedValueOnce( + new Error('No resource metadata'), + ); + + mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce({ + issuer: 'https://example.com', + authorization_endpoint: 'https://example.com/authorize', + token_endpoint: 'https://example.com/token', + registration_endpoint: 'https://example.com/register', + response_types_supported: ['code'], + jwks_uri: 'https://example.com/.well-known/jwks.json', + subject_types_supported: ['public'], + id_token_signing_alg_values_supported: ['RS256'], + } as AuthorizationServerMetadata); + + mockRegisterClient.mockResolvedValueOnce({ + client_id: 'new-client-id', + client_secret: 'new-client-secret', + redirect_uris: ['http://localhost:3080/api/mcp/test-server/oauth/callback'], + logo_uri: undefined, + tos_uri: undefined, + }); + + mockStartAuthorization.mockResolvedValueOnce({ + authorizationUrl: new URL('https://example.com/authorize?client_id=new-client-id'), + codeVerifier: 'test-code-verifier', + }); + + await MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://example.com/mcp', + 'user-123', + {}, + undefined, + undefined, + mockFindToken, + ); + + expect(mockRegisterClient).toHaveBeenCalled(); + expect(mockStartAuthorization).toHaveBeenCalledWith( + 'https://example.com/mcp', + expect.objectContaining({ + clientInformation: expect.objectContaining({ + client_id: 'new-client-id', + }), + }), + ); + }); + }); + describe('Fallback OAuth Metadata (Legacy Server Support)', () => { const originalFetch = global.fetch; const mockFetch = jest.fn(); diff --git a/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts b/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts index 3b68b2ded4..34f968d4e2 100644 --- a/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts +++ b/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts @@ -59,6 +59,8 @@ export interface OAuthTestServerOptions { issueRefreshTokens?: boolean; refreshTokenTTLMs?: number; rotateRefreshTokens?: boolean; + /** When true, /token validates client_id against the registered client that initiated /authorize */ + enforceClientId?: boolean; } export interface OAuthTestServer { @@ -81,6 +83,17 @@ async function readRequestBody(req: http.IncomingMessage): Promise { return Buffer.concat(chunks).toString(); } +function parseBasicAuth( + header: string | undefined, +): { clientId: string; clientSecret: string } | null { + if (!header?.startsWith('Basic ')) { + return null; + } + const decoded = Buffer.from(header.slice(6), 'base64').toString(); + const [clientId, clientSecret] = decoded.split(':'); + return clientId ? { clientId, clientSecret: clientSecret ?? '' } : null; +} + function parseTokenRequest(body: string, contentType: string | undefined): URLSearchParams | null { if (contentType?.includes('application/x-www-form-urlencoded')) { return new URLSearchParams(body); @@ -100,6 +113,7 @@ export async function createOAuthMCPServer( issueRefreshTokens = false, refreshTokenTTLMs = 365 * 24 * 60 * 60 * 1000, rotateRefreshTokens = false, + enforceClientId = false, } = options; const sessions = new Map(); @@ -107,7 +121,10 @@ export async function createOAuthMCPServer( const tokenIssueTimes = new Map(); const issuedRefreshTokens = new Map(); const refreshTokenIssueTimes = new Map(); - const authCodes = new Map(); + const authCodes = new Map< + string, + { codeChallenge?: string; codeChallengeMethod?: string; clientId?: string } + >(); const registeredClients = new Map(); let port = 0; @@ -155,7 +172,8 @@ export async function createOAuthMCPServer( const code = randomUUID(); const codeChallenge = url.searchParams.get('code_challenge') ?? undefined; const codeChallengeMethod = url.searchParams.get('code_challenge_method') ?? undefined; - authCodes.set(code, { codeChallenge, codeChallengeMethod }); + const clientId = url.searchParams.get('client_id') ?? undefined; + authCodes.set(code, { codeChallenge, codeChallengeMethod, clientId }); const redirectUri = url.searchParams.get('redirect_uri') ?? ''; const state = url.searchParams.get('state') ?? ''; res.writeHead(302, { @@ -202,6 +220,23 @@ export async function createOAuthMCPServer( } } + if (enforceClientId && codeData.clientId) { + const requestClientId = + params.get('client_id') ?? parseBasicAuth(req.headers.authorization)?.clientId; + if (!requestClientId || !registeredClients.has(requestClientId)) { + res.writeHead(401, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_client' })); + return; + } + if (requestClientId !== codeData.clientId) { + res.writeHead(401, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ error: 'invalid_client', error_description: 'client_id mismatch' }), + ); + return; + } + } + authCodes.delete(code); const accessToken = randomUUID(); @@ -439,6 +474,25 @@ export class InMemoryTokenStore { this.tokens.delete(this.key(filter)); }; + deleteTokens = async (query: { + userId?: string; + type?: string; + identifier?: string; + }): Promise<{ acknowledged: boolean; deletedCount: number }> => { + let deletedCount = 0; + for (const [key, token] of this.tokens.entries()) { + const match = + (!query.userId || token.userId === query.userId) && + (!query.type || token.type === query.type) && + (!query.identifier || token.identifier === query.identifier); + if (match) { + this.tokens.delete(key); + deletedCount++; + } + } + return { acknowledged: true, deletedCount }; + }; + getAll(): InMemoryToken[] { return [...this.tokens.values()]; } diff --git a/packages/api/src/mcp/oauth/handler.ts b/packages/api/src/mcp/oauth/handler.ts index e128dec308..ccca5b1945 100644 --- a/packages/api/src/mcp/oauth/handler.ts +++ b/packages/api/src/mcp/oauth/handler.ts @@ -10,6 +10,7 @@ import { discoverOAuthProtectedResourceMetadata, } from '@modelcontextprotocol/sdk/client/auth.js'; import { TokenExchangeMethodEnum, type MCPOptions } from 'librechat-data-provider'; +import type { TokenMethods } from '@librechat/data-schemas'; import type { FlowStateManager } from '~/flow/manager'; import type { OAuthClientInformation, @@ -25,6 +26,7 @@ import { inferClientAuthMethod, } from './methods'; import { isSSRFTarget, resolveHostnameSSRF, isOAuthUrlAllowed } from '~/auth'; +import { MCPTokenStorage } from './tokens'; import { sanitizeUrlForLogging } from '~/mcp/utils'; /** Type for the OAuth metadata from the SDK */ @@ -368,6 +370,7 @@ export class MCPOAuthHandler { oauthHeaders: Record, config?: MCPOptions['oauth'], allowedDomains?: string[] | null, + findToken?: TokenMethods['findToken'], ): Promise<{ authorizationUrl: string; flowId: string; flowMetadata: MCPOAuthFlowMetadata }> { logger.debug( `[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${sanitizeUrlForLogging(serverUrl)}`, @@ -494,18 +497,62 @@ export class MCPOAuthHandler { ); const redirectUri = this.getDefaultRedirectUri(serverName); - logger.debug(`[MCPOAuth] Registering OAuth client with redirect URI: ${redirectUri}`); + logger.debug(`[MCPOAuth] Resolving OAuth client with redirect URI: ${redirectUri}`); - const clientInfo = await this.registerOAuthClient( - authServerUrl.toString(), - metadata, - oauthHeaders, - resourceMetadata, - redirectUri, - config?.token_exchange_method, - ); + let clientInfo: OAuthClientInformation | undefined; + let reusedStoredClient = false; - logger.debug(`[MCPOAuth] Client registered with ID: ${clientInfo.client_id}`); + if (findToken) { + try { + const existing = await MCPTokenStorage.getClientInfoAndMetadata({ + userId, + serverName, + findToken, + }); + if (existing?.clientInfo?.client_id) { + const storedRedirectUri = (existing.clientInfo as OAuthClientInformation) + .redirect_uris?.[0]; + const storedIssuer = + typeof existing.clientMetadata?.issuer === 'string' + ? existing.clientMetadata.issuer.replace(/\/+$/, '') + : null; + const currentIssuer = (metadata.issuer ?? authServerUrl.toString()).replace(/\/+$/, ''); + + if (!storedRedirectUri || storedRedirectUri !== redirectUri) { + logger.debug( + `[MCPOAuth] Stored redirect_uri "${storedRedirectUri}" differs from current "${redirectUri}", will re-register`, + ); + } else if (!storedIssuer || storedIssuer !== currentIssuer) { + logger.debug( + `[MCPOAuth] Issuer mismatch (stored: ${storedIssuer ?? 'none'}, current: ${currentIssuer}), will re-register`, + ); + } else { + logger.debug( + `[MCPOAuth] Reusing existing client registration: ${existing.clientInfo.client_id}`, + ); + clientInfo = existing.clientInfo; + reusedStoredClient = true; + } + } + } catch (error) { + logger.warn( + `[MCPOAuth] Failed to look up existing client registration, falling back to new registration`, + { error, serverName, userId }, + ); + } + } + + if (!clientInfo) { + clientInfo = await this.registerOAuthClient( + authServerUrl.toString(), + metadata, + oauthHeaders, + resourceMetadata, + redirectUri, + config?.token_exchange_method, + ); + logger.debug(`[MCPOAuth] Client registered with ID: ${clientInfo.client_id}`); + } /** Authorization Scope */ const scope = @@ -575,6 +622,7 @@ export class MCPOAuthHandler { metadata, resourceMetadata, ...(Object.keys(oauthHeaders).length > 0 && { oauthHeaders }), + ...(reusedStoredClient && { reusedStoredClient }), }; logger.debug( diff --git a/packages/api/src/mcp/oauth/tokens.ts b/packages/api/src/mcp/oauth/tokens.ts index 1e31a64511..61b442ca8c 100644 --- a/packages/api/src/mcp/oauth/tokens.ts +++ b/packages/api/src/mcp/oauth/tokens.ts @@ -476,6 +476,26 @@ export class MCPTokenStorage { }; } + /** Deletes only the stored client registration for a specific user and server */ + static async deleteClientRegistration({ + userId, + serverName, + deleteTokens, + }: { + userId: string; + serverName: string; + deleteTokens: TokenMethods['deleteTokens']; + }): Promise { + const identifier = `mcp:${serverName}`; + await deleteTokens({ + userId, + type: 'mcp_oauth_client', + identifier: `${identifier}:client`, + }); + const logPrefix = this.getLogPrefix(userId, serverName); + logger.debug(`${logPrefix} Cleared stored client registration`); + } + /** * Deletes all OAuth-related tokens for a specific user and server */ diff --git a/packages/api/src/mcp/oauth/types.ts b/packages/api/src/mcp/oauth/types.ts index bc5f53f60c..ee8ce2d76d 100644 --- a/packages/api/src/mcp/oauth/types.ts +++ b/packages/api/src/mcp/oauth/types.ts @@ -91,6 +91,8 @@ export interface MCPOAuthFlowMetadata extends FlowMetadata { authorizationUrl?: string; /** Custom headers for OAuth token exchange, persisted at flow initiation for the callback. */ oauthHeaders?: Record; + /** True when the flow reused a stored client registration from a prior successful OAuth flow */ + reusedStoredClient?: boolean; } export interface MCPOAuthTokens extends OAuthTokens {