diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index f194f361d3..3a50c0f6cf 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -311,6 +311,60 @@ describe('MCP Routes', () => { expect(response.headers.location).toBe(`${basePath}/oauth/error?error=access_denied`); }); + describe('OAuth error callback failFlow', () => { + it('should fail the flow when OAuth error is received with valid CSRF cookie', async () => { + const flowId = 'test-user-id:test-server'; + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue({ status: 'PENDING', createdAt: Date.now() }), + failFlow: jest.fn().mockResolvedValue(true), + }; + + getLogStores.mockReturnValueOnce({}); + require('~/config').getFlowStateManager.mockReturnValueOnce(mockFlowManager); + MCPOAuthHandler.resolveStateToFlowId.mockResolvedValueOnce(flowId); + + const csrfToken = generateTestCsrfToken(flowId); + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + error: 'invalid_client', + state: flowId, + }); + const basePath = getBasePath(); + + expect(response.status).toBe(302); + expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_client`); + expect(mockFlowManager.failFlow).toHaveBeenCalledWith( + flowId, + 'mcp_oauth', + 'invalid_client', + ); + }); + + it('should NOT fail the flow when OAuth error is received without cookies (DoS prevention)', async () => { + const flowId = 'test-user-id:test-server'; + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue({ status: 'PENDING', createdAt: Date.now() }), + failFlow: jest.fn(), + }; + + getLogStores.mockReturnValueOnce({}); + require('~/config').getFlowStateManager.mockReturnValueOnce(mockFlowManager); + MCPOAuthHandler.resolveStateToFlowId.mockResolvedValueOnce(flowId); + + const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ + error: 'invalid_client', + state: flowId, + }); + const basePath = getBasePath(); + + expect(response.status).toBe(302); + expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_client`); + expect(mockFlowManager.failFlow).not.toHaveBeenCalled(); + }); + }); + it('should redirect to error page when code is missing', async () => { const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ state: 'test-user-id:test-server', diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index 488c2fa43a..d6c3b5290c 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -351,6 +351,7 @@ export class MCPConnectionFactory { config?.oauth_headers ?? {}, config?.oauth, this.allowedDomains, + // Only reuse stored client when deleteTokens is available for stale-client cleanup this.tokenMethods?.deleteTokens ? this.tokenMethods.findToken : undefined, ); @@ -491,7 +492,11 @@ export class MCPConnectionFactory { }); } - /** Determines if an error indicates the OAuth client registration was rejected */ + /** + * Checks whether an error indicates the OAuth client registration was rejected. + * Includes RFC 6749 §5.2 standard codes (`invalid_client`, `unauthorized_client`) + * and known vendor-specific patterns (Okta: `client_id mismatch`, Auth0: `client not found`). + */ static isClientRejection(error: unknown): boolean { if (!error || typeof error !== 'object') { return false; @@ -608,7 +613,7 @@ export class MCPConnectionFactory { tokens, clientInfo: flowMeta?.clientInfo, metadata: flowMeta?.metadata, - reusedStoredClient: flowMeta?.reusedStoredClient === true, + reusedStoredClient, }; } diff --git a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts index f5c5fc771b..5983710f65 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts @@ -5,7 +5,7 @@ import type { MCPOAuthTokens } from '~/mcp/oauth'; import type * as t from '~/mcp/types'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; import { MCPConnection } from '~/mcp/connection'; -import { MCPOAuthHandler } from '~/mcp/oauth'; +import { MCPOAuthHandler, MCPTokenStorage } from '~/mcp/oauth'; import { processMCPEnv } from '~/utils'; jest.mock('~/mcp/connection'); @@ -24,6 +24,7 @@ const mockLogger = logger as jest.Mocked; const mockProcessMCPEnv = processMCPEnv as jest.MockedFunction; const mockMCPConnection = MCPConnection as jest.MockedClass; const mockMCPOAuthHandler = MCPOAuthHandler as jest.Mocked; +const mockMCPTokenStorage = MCPTokenStorage as jest.Mocked; describe('MCPConnectionFactory', () => { let mockUser: IUser | undefined; @@ -293,6 +294,78 @@ describe('MCPConnectionFactory', () => { ); }); + it('should clear stale client registration when returnOnOAuth flow fails with client rejection', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: { + ...mockServerConfig, + url: 'https://api.example.com', + type: 'sse' as const, + } as t.SSEOptions, + }; + + const deleteTokensSpy = jest.fn().mockResolvedValue({ acknowledged: true, deletedCount: 1 }); + const oauthOptions = { + useOAuth: true as const, + user: mockUser, + flowManager: mockFlowManager, + returnOnOAuth: true, + oauthStart: jest.fn(), + tokenMethods: { + findToken: jest.fn(), + createToken: jest.fn(), + updateToken: jest.fn(), + deleteTokens: deleteTokensSpy, + }, + }; + + const mockFlowData = { + authorizationUrl: 'https://auth.example.com', + flowId: 'flow123', + flowMetadata: { + serverName: 'test-server', + userId: 'user123', + serverUrl: 'https://api.example.com', + state: 'random-state', + clientInfo: { client_id: 'stale-client' }, + reusedStoredClient: true, + }, + }; + + mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData); + mockMCPTokenStorage.deleteClientRegistration.mockResolvedValue(undefined); + // createFlow rejects with invalid_client — simulating stale client rejection + mockFlowManager.createFlow.mockRejectedValue(new Error('invalid_client')); + mockConnectionInstance.isConnected.mockResolvedValue(false); + + let oauthRequiredHandler: (data: Record) => Promise; + mockConnectionInstance.on.mockImplementation((event, handler) => { + if (event === 'oauthRequired') { + oauthRequiredHandler = handler as (data: Record) => Promise; + } + return mockConnectionInstance; + }); + + try { + await MCPConnectionFactory.create(basicOptions, oauthOptions); + } catch { + // Expected + } + + await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' }); + + // Wait for the background .catch() handler to run + await new Promise((r) => setTimeout(r, 50)); + + // deleteClientRegistration should have been called via clearStaleClientIfRejected + expect(mockMCPTokenStorage.deleteClientRegistration).toHaveBeenCalledWith( + expect.objectContaining({ + userId: 'user123', + serverName: 'test-server', + }), + ); + }); + it('should skip new OAuth flow initiation when a PENDING flow already exists (returnOnOAuth)', async () => { const basicOptions = { serverName: 'test-server', diff --git a/packages/api/src/mcp/__tests__/MCPOAuthClientRegistrationReuse.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthClientRegistrationReuse.test.ts index 38b2d13fb7..75cf4147b2 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthClientRegistrationReuse.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthClientRegistrationReuse.test.ts @@ -333,7 +333,7 @@ describe('MCPOAuthHandler - client registration reuse on reconnection', () => { await MCPTokenStorage.deleteClientRegistration({ userId: 'user-1', serverName: 'test-server', - deleteTokens: tokenStore.deleteToken, + deleteTokens: tokenStore.deleteTokens, }); // Second attempt (retry after failure): should do a fresh DCR diff --git a/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts b/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts index 770c1148e0..34f968d4e2 100644 --- a/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts +++ b/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts @@ -474,6 +474,25 @@ export class InMemoryTokenStore { this.tokens.delete(this.key(filter)); }; + deleteTokens = async (query: { + userId?: string; + type?: string; + identifier?: string; + }): Promise<{ acknowledged: boolean; deletedCount: number }> => { + let deletedCount = 0; + for (const [key, token] of this.tokens.entries()) { + const match = + (!query.userId || token.userId === query.userId) && + (!query.type || token.type === query.type) && + (!query.identifier || token.identifier === query.identifier); + if (match) { + this.tokens.delete(key); + deletedCount++; + } + } + return { acknowledged: true, deletedCount }; + }; + getAll(): InMemoryToken[] { return [...this.tokens.values()]; } diff --git a/packages/api/src/mcp/oauth/handler.ts b/packages/api/src/mcp/oauth/handler.ts index 883ebb123f..ccca5b1945 100644 --- a/packages/api/src/mcp/oauth/handler.ts +++ b/packages/api/src/mcp/oauth/handler.ts @@ -514,9 +514,9 @@ export class MCPOAuthHandler { .redirect_uris?.[0]; const storedIssuer = typeof existing.clientMetadata?.issuer === 'string' - ? existing.clientMetadata.issuer + ? existing.clientMetadata.issuer.replace(/\/+$/, '') : null; - const currentIssuer = metadata.issuer ?? authServerUrl.toString(); + const currentIssuer = (metadata.issuer ?? authServerUrl.toString()).replace(/\/+$/, ''); if (!storedRedirectUri || storedRedirectUri !== redirectUri) { logger.debug(