From f7ac449ca47c017b3a51b481acb3dcbc67381d4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B3n=20Levy?= Date: Tue, 3 Mar 2026 00:27:36 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=8C=20fix:=20Resolve=20MCP=20OAuth=20f?= =?UTF-8?q?low=20state=20race=20condition=20(#11941)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🔌 fix: Resolve MCP OAuth flow state race condition The OAuth callback arrives before the flow state is stored because `createFlow()` returns a long-running Promise that only resolves on flow COMPLETION, not when the initial PENDING state is persisted. Calling it fire-and-forget with `.catch(() => {})` meant the redirect happened before the state existed, causing "Flow state not found" errors. Changes: - Add `initFlow()` to FlowStateManager that stores PENDING state and returns immediately, decoupling state persistence from monitoring - Await `initFlow()` before emitting the OAuth redirect so the callback always finds existing state - Keep `createFlow()` in the background for monitoring, but log warnings instead of silently swallowing errors - Increase FLOWS cache TTL from 3 minutes to 10 minutes to give users more time to complete OAuth consent screens Co-Authored-By: Claude Opus 4.6 * 🔌 refactor: Revert FLOWS cache TTL change The race condition fix (initFlow) is sufficient on its own. TTL configurability should be a separate enhancement via librechat.yaml mcpSettings rather than a hardcoded increase. Co-Authored-By: Claude Opus 4.6 * 🔌 fix: Address PR review — restore FLOWS TTL, fix blocking-path race, clean up dead args - Restore FLOWS cache TTL to 10 minutes (was silently dropped back to 3) - Add initFlow before oauthStart in blocking handleOAuthRequired path to guarantee state persistence before any redirect - Pass {} to createFlow metadata arg (dead after initFlow writes state) - Downgrade background monitor .catch from logger.warn to logger.debug - Replace process.nextTick with Promise.resolve in test (correct semantics) - Add initFlow TTL assertion test - Add blocking-path ordering test (initFlow → oauthStart → createFlow) Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Claude Opus 4.6 --- api/cache/getLogStores.js | 2 +- packages/api/src/flow/manager.test.ts | 66 ++++- packages/api/src/flow/manager.ts | 18 ++ packages/api/src/mcp/MCPConnectionFactory.ts | 24 +- .../__tests__/MCPConnectionFactory.test.ts | 251 +++++++++++++++++- 5 files changed, 345 insertions(+), 16 deletions(-) diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js index 3089192196..70eb681e53 100644 --- a/api/cache/getLogStores.js +++ b/api/cache/getLogStores.js @@ -47,7 +47,7 @@ const namespaces = { [CacheKeys.MODEL_QUERIES]: standardCache(CacheKeys.MODEL_QUERIES), [CacheKeys.AUDIO_RUNS]: standardCache(CacheKeys.AUDIO_RUNS, Time.TEN_MINUTES), [CacheKeys.MESSAGES]: standardCache(CacheKeys.MESSAGES, Time.ONE_MINUTE), - [CacheKeys.FLOWS]: standardCache(CacheKeys.FLOWS, Time.ONE_MINUTE * 3), + [CacheKeys.FLOWS]: standardCache(CacheKeys.FLOWS, Time.ONE_MINUTE * 10), [CacheKeys.OPENID_EXCHANGED_TOKENS]: standardCache( CacheKeys.OPENID_EXCHANGED_TOKENS, Time.TEN_MINUTES, diff --git a/packages/api/src/flow/manager.test.ts b/packages/api/src/flow/manager.test.ts index a419f4aeab..b34dcbafab 100644 --- a/packages/api/src/flow/manager.test.ts +++ b/packages/api/src/flow/manager.test.ts @@ -24,7 +24,6 @@ class MockKeyv { return this.store.get(key); } - // eslint-disable-next-line @typescript-eslint/no-unused-vars async set(key: string, value: FlowState, _ttl?: number): Promise { this.store.set(key, value); return true; @@ -160,6 +159,71 @@ describe('FlowStateManager', () => { }, 15000); }); + describe('initFlow', () => { + const flowId = 'init-test-flow'; + const type = 'test-type'; + const flowKey = `${type}:${flowId}`; + + it('stores a PENDING flow state in the cache', async () => { + await flowManager.initFlow(flowId, type, { serverName: 'test' }); + + const state = await store.get(flowKey); + expect(state).toBeDefined(); + expect(state!.status).toBe('PENDING'); + expect(state!.type).toBe(type); + expect(state!.metadata).toEqual({ serverName: 'test' }); + expect(state!.createdAt).toBeGreaterThan(0); + }); + + it('overwrites an existing flow state', async () => { + await store.set(flowKey, { + type, + status: 'COMPLETED', + metadata: { old: true }, + createdAt: Date.now() - 10000, + }); + + await flowManager.initFlow(flowId, type, { new: true }); + + const state = await store.get(flowKey); + expect(state!.status).toBe('PENDING'); + expect(state!.metadata).toEqual({ new: true }); + }); + + it('allows createFlow to find and monitor the pre-stored state', async () => { + // initFlow stores the PENDING state + await flowManager.initFlow(flowId, type, { preStored: true }); + + // createFlow should find the existing state and start monitoring + const flowPromise = flowManager.createFlow(flowId, type); + + // Complete the flow so the monitor resolves + await new Promise((resolve) => setTimeout(resolve, 500)); + await flowManager.completeFlow(flowId, type, 'success'); + + const result = await flowPromise; + expect(result).toBe('success'); + }, 15000); + + it('passes the configured TTL to keyv.set', async () => { + const setSpy = jest.spyOn(store, 'set'); + + await flowManager.initFlow(flowId, type, { serverName: 'test' }); + + expect(setSpy).toHaveBeenCalledWith( + flowKey, + expect.objectContaining({ status: 'PENDING' }), + 30000, + ); + }); + + it('propagates store write failures', async () => { + jest.spyOn(store, 'set').mockRejectedValueOnce(new Error('Store write failed')); + + await expect(flowManager.initFlow(flowId, type)).rejects.toThrow('Store write failed'); + }); + }); + describe('deleteFlow', () => { const flowId = 'test-flow-123'; const type = 'test-type'; diff --git a/packages/api/src/flow/manager.ts b/packages/api/src/flow/manager.ts index 2e2731a2d4..4f9023a3d7 100644 --- a/packages/api/src/flow/manager.ts +++ b/packages/api/src/flow/manager.ts @@ -88,6 +88,24 @@ export class FlowStateManager { return normalizedExpiresAt < Date.now(); } + /** + * Stores initial PENDING flow state without starting the monitor loop. + * Use this when you need to guarantee the state is persisted before + * performing an action (e.g., an OAuth redirect), then call createFlow() + * separately to start monitoring for completion. + */ + async initFlow(flowId: string, type: string, metadata: FlowMetadata = {}): Promise { + const flowKey = this.getFlowKey(flowId, type); + const initialState: FlowState = { + type, + status: 'PENDING', + metadata, + createdAt: Date.now(), + }; + logger.debug(`[${flowKey}] Storing initial flow state`); + await this.keyv.set(flowKey, initialState, this.ttl); + } + /** * Creates a new flow and waits for its completion */ diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index a8f631614d..2bf02c7a3f 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -328,9 +328,14 @@ export class MCPConnectionFactory { await this.flowManager!.deleteFlow(newFlowId, 'mcp_oauth'); } - this.flowManager!.createFlow(newFlowId, 'mcp_oauth', flowMetadata, this.signal).catch( - () => {}, - ); + // Store flow state BEFORE redirecting so the callback can find it + await this.flowManager!.initFlow(newFlowId, 'mcp_oauth', flowMetadata); + + // Start monitoring in background — createFlow will find the existing PENDING state + // written by initFlow above, so metadata arg is unused (pass {} to make that explicit) + this.flowManager!.createFlow(newFlowId, 'mcp_oauth', {}, this.signal).catch((error) => { + logger.debug(`${this.logPrefix} OAuth flow monitor ended`, error); + }); if (this.oauthStart) { logger.info(`${this.logPrefix} OAuth flow started, issuing authorization URL`); @@ -512,6 +517,9 @@ export class MCPConnectionFactory { this.serverConfig.oauth, ); + // Store flow state BEFORE redirecting so the callback can find it + await this.flowManager.initFlow(newFlowId, 'mcp_oauth', flowMetadata as FlowMetadata); + if (typeof this.oauthStart === 'function') { logger.info(`${this.logPrefix} OAuth flow started, issued authorization URL to user`); await this.oauthStart(authorizationUrl); @@ -521,13 +529,9 @@ export class MCPConnectionFactory { ); } - /** Tokens from the new flow */ - const tokens = await this.flowManager.createFlow( - newFlowId, - 'mcp_oauth', - flowMetadata as FlowMetadata, - this.signal, - ); + // createFlow will find the existing PENDING state written by initFlow above, + // so metadata arg is unused (pass {} to make that explicit) + const tokens = await this.flowManager.createFlow(newFlowId, 'mcp_oauth', {}, this.signal); if (typeof this.oauthEnd === 'function') { await this.oauthEnd(); } diff --git a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts index 263c84357a..c8b3a4b04f 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts @@ -45,6 +45,7 @@ describe('MCPConnectionFactory', () => { } as t.MCPOptions; mockFlowManager = { + initFlow: jest.fn().mockResolvedValue(undefined), createFlow: jest.fn(), createFlowWithHandler: jest.fn(), getFlowState: jest.fn(), @@ -233,7 +234,8 @@ describe('MCPConnectionFactory', () => { }; mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData); - mockFlowManager.createFlow.mockRejectedValue(new Error('Timeout expected')); + // createFlow runs as a background monitor — simulate it staying pending + mockFlowManager.createFlow.mockReturnValue(new Promise(() => {})); mockConnectionInstance.isConnected.mockResolvedValue(false); let oauthRequiredHandler: (data: Record) => Promise; @@ -261,6 +263,18 @@ describe('MCPConnectionFactory', () => { {}, undefined, ); + + // initFlow must be awaited BEFORE the redirect to guarantee state is stored + expect(mockFlowManager.initFlow).toHaveBeenCalledWith( + 'flow123', + 'mcp_oauth', + mockFlowData.flowMetadata, + ); + const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0]; + const oauthStartCallOrder = (oauthOptions.oauthStart as jest.Mock).mock + .invocationCallOrder[0]; + expect(initCallOrder).toBeLessThan(oauthStartCallOrder); + expect(oauthOptions.oauthStart).toHaveBeenCalledWith('https://auth.example.com'); expect(mockConnectionInstance.emit).toHaveBeenCalledWith( 'oauthFailed', @@ -317,6 +331,223 @@ describe('MCPConnectionFactory', () => { ); }); + it('should emit oauthFailed when initFlow fails to store flow state (returnOnOAuth)', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: { + ...mockServerConfig, + url: 'https://api.example.com', + type: 'sse' as const, + } as t.SSEOptions, + }; + + const oauthOptions = { + useOAuth: true as const, + user: mockUser, + flowManager: mockFlowManager, + returnOnOAuth: true, + oauthStart: jest.fn(), + tokenMethods: { + findToken: jest.fn(), + createToken: jest.fn(), + updateToken: jest.fn(), + deleteTokens: jest.fn(), + }, + }; + + const mockFlowData = { + authorizationUrl: 'https://auth.example.com', + flowId: 'flow123', + flowMetadata: { serverName: 'test-server', userId: 'user123' }, + }; + + mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData); + mockFlowManager.initFlow.mockRejectedValue(new Error('Store write failed')); + mockConnectionInstance.isConnected.mockResolvedValue(false); + + let oauthRequiredHandler: (data: Record) => Promise; + mockConnectionInstance.on.mockImplementation((event, handler) => { + if (event === 'oauthRequired') { + oauthRequiredHandler = handler as (data: Record) => Promise; + } + return mockConnectionInstance; + }); + + try { + await MCPConnectionFactory.create(basicOptions, oauthOptions); + } catch { + // Expected to fail + } + + await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' }); + + // initFlow failed, so oauthStart should NOT have been called (redirect never happens) + expect(oauthOptions.oauthStart).not.toHaveBeenCalled(); + // createFlow should NOT have been called since initFlow failed first + expect(mockFlowManager.createFlow).not.toHaveBeenCalled(); + expect(mockConnectionInstance.emit).toHaveBeenCalledWith( + 'oauthFailed', + expect.objectContaining({ message: 'OAuth initiation failed' }), + ); + expect(mockLogger.error).toHaveBeenCalled(); + }); + + it('should log warnings when background createFlow monitor rejects (returnOnOAuth)', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: { + ...mockServerConfig, + url: 'https://api.example.com', + type: 'sse' as const, + } as t.SSEOptions, + }; + + const oauthOptions = { + useOAuth: true as const, + user: mockUser, + flowManager: mockFlowManager, + returnOnOAuth: true, + oauthStart: jest.fn(), + tokenMethods: { + findToken: jest.fn(), + createToken: jest.fn(), + updateToken: jest.fn(), + deleteTokens: jest.fn(), + }, + }; + + const mockFlowData = { + authorizationUrl: 'https://auth.example.com', + flowId: 'flow123', + flowMetadata: { serverName: 'test-server', userId: 'user123' }, + }; + + mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData); + // Simulate the background monitor timing out + mockFlowManager.createFlow.mockRejectedValue(new Error('mcp_oauth flow timed out')); + mockConnectionInstance.isConnected.mockResolvedValue(false); + + let oauthRequiredHandler: (data: Record) => Promise; + mockConnectionInstance.on.mockImplementation((event, handler) => { + if (event === 'oauthRequired') { + oauthRequiredHandler = handler as (data: Record) => Promise; + } + return mockConnectionInstance; + }); + + try { + await MCPConnectionFactory.create(basicOptions, oauthOptions); + } catch { + // Expected to fail + } + + await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' }); + + // Allow the .catch handler on createFlow to execute + await Promise.resolve(); + + // initFlow should have succeeded and redirect should have happened + expect(mockFlowManager.initFlow).toHaveBeenCalled(); + expect(oauthOptions.oauthStart).toHaveBeenCalledWith('https://auth.example.com'); + // The background monitor error should be logged, not silently swallowed + expect(mockLogger.debug).toHaveBeenCalledWith( + expect.stringContaining('OAuth flow monitor ended'), + expect.any(Error), + ); + }); + + it('should call initFlow before createFlow in blocking OAuth path (non-returnOnOAuth)', async () => { + const sseConfig = { + ...mockServerConfig, + url: 'https://api.example.com', + type: 'sse' as const, + } as t.SSEOptions; + + const basicOptions = { + serverName: 'test-server', + serverConfig: sseConfig, + }; + + const oauthOptions = { + useOAuth: true as const, + user: mockUser, + flowManager: mockFlowManager, + oauthStart: jest.fn(), + oauthEnd: jest.fn(), + tokenMethods: { + findToken: jest.fn(), + createToken: jest.fn(), + updateToken: jest.fn(), + deleteTokens: jest.fn(), + }, + }; + + const mockFlowData = { + authorizationUrl: 'https://auth.example.com', + flowId: 'flow123', + flowMetadata: { + serverName: 'test-server', + userId: 'user123', + serverUrl: 'https://api.example.com', + state: 'random-state', + clientInfo: { client_id: 'client123' }, + metadata: { token_endpoint: 'https://auth.example.com/token' }, + }, + }; + + const mockTokens: MCPOAuthTokens = { + access_token: 'access123', + refresh_token: 'refresh123', + token_type: 'Bearer', + obtained_at: Date.now(), + }; + + // processMCPEnv must return config with url so handleOAuthRequired proceeds + mockProcessMCPEnv.mockReturnValue(sseConfig); + mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData); + mockMCPOAuthHandler.generateFlowId.mockReturnValue('flow123'); + mockFlowManager.getFlowState.mockResolvedValue(null); + mockFlowManager.createFlow.mockResolvedValue(mockTokens); + mockConnectionInstance.isConnected.mockResolvedValue(false); + + let oauthRequiredHandler: (data: Record) => Promise; + mockConnectionInstance.on.mockImplementation((event, handler) => { + if (event === 'oauthRequired') { + oauthRequiredHandler = handler as (data: Record) => Promise; + } + return mockConnectionInstance; + }); + + try { + await MCPConnectionFactory.create(basicOptions, oauthOptions); + } catch { + // Expected to fail + } + + await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' }); + + // initFlow must be called BEFORE oauthStart and createFlow + expect(mockFlowManager.initFlow).toHaveBeenCalledWith( + 'flow123', + 'mcp_oauth', + mockFlowData.flowMetadata, + ); + const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0]; + const oauthStartCallOrder = (oauthOptions.oauthStart as jest.Mock).mock + .invocationCallOrder[0]; + const createCallOrder = mockFlowManager.createFlow.mock.invocationCallOrder[0]; + expect(initCallOrder).toBeLessThan(oauthStartCallOrder); + expect(initCallOrder).toBeLessThan(createCallOrder); + + // createFlow should receive {} since initFlow already persisted metadata + expect(mockFlowManager.createFlow).toHaveBeenCalledWith( + 'flow123', + 'mcp_oauth', + {}, + undefined, + ); + }); + it('should delete stale flow and create new OAuth flow when existing flow is COMPLETED', async () => { const basicOptions = { serverName: 'test-server', @@ -358,7 +589,8 @@ describe('MCPConnectionFactory', () => { }); mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData); mockFlowManager.deleteFlow.mockResolvedValue(true); - mockFlowManager.createFlow.mockRejectedValue(new Error('Timeout expected')); + // createFlow runs as a background monitor — simulate it staying pending + mockFlowManager.createFlow.mockReturnValue(new Promise(() => {})); mockConnectionInstance.isConnected.mockResolvedValue(false); let oauthRequiredHandler: (data: Record) => Promise; @@ -379,16 +611,27 @@ describe('MCPConnectionFactory', () => { expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('user123:test-server', 'mcp_oauth'); + // initFlow must be called after deleteFlow and before createFlow const deleteCallOrder = mockFlowManager.deleteFlow.mock.invocationCallOrder[0]; + const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0]; const createCallOrder = mockFlowManager.createFlow.mock.invocationCallOrder[0]; - expect(deleteCallOrder).toBeLessThan(createCallOrder); + expect(deleteCallOrder).toBeLessThan(initCallOrder); + expect(initCallOrder).toBeLessThan(createCallOrder); - expect(mockFlowManager.createFlow).toHaveBeenCalledWith( + expect(mockFlowManager.initFlow).toHaveBeenCalledWith( 'user123:test-server', 'mcp_oauth', expect.objectContaining({ codeVerifier: 'new-code-verifier-xyz', }), + ); + + // createFlow finds the existing PENDING state written by initFlow, + // so metadata arg is unused (passed as {}) + expect(mockFlowManager.createFlow).toHaveBeenCalledWith( + 'user123:test-server', + 'mcp_oauth', + {}, undefined, ); });