From f61e057f7f1c3060e5dd304fe08d484388ea197c Mon Sep 17 00:00:00 2001 From: Sean McGrath Date: Wed, 24 Sep 2025 01:35:56 +1200 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=90=20fix:=20MCP=20OAuth=20Token=20Per?= =?UTF-8?q?sistence=20Race=20Condition=20and=20Refresh=20Auth=20Method=20(?= =?UTF-8?q?#9773)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * set supported endpoint auth method when token_url exists * persist tokens immediately * add token storage validation tests --- api/server/routes/__tests__/mcp.spec.js | 87 ++++++++++++++++++++++++- api/server/routes/mcp.js | 37 ++++++++++- packages/api/src/mcp/oauth/handler.ts | 1 + 3 files changed, 122 insertions(+), 3 deletions(-) diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index 1eb55224f5..0df28d7b10 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -11,6 +11,9 @@ jest.mock('@librechat/api', () => ({ completeOAuthFlow: jest.fn(), generateFlowId: jest.fn(), }, + MCPTokenStorage: { + storeTokens: jest.fn(), + }, getUserMCPAuthMap: jest.fn(), })); @@ -234,7 +237,7 @@ describe('MCP Routes', () => { }); describe('GET /:serverName/oauth/callback', () => { - const { MCPOAuthHandler } = require('@librechat/api'); + const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api'); const { getLogStores } = require('~/cache'); it('should redirect to error page when OAuth error is received', async () => { @@ -280,6 +283,7 @@ describe('MCP Routes', () => { it('should handle OAuth callback successfully', async () => { const mockFlowManager = { completeFlow: jest.fn().mockResolvedValue(), + deleteFlow: jest.fn().mockResolvedValue(true), }; const mockFlowState = { serverName: 'test-server', @@ -295,6 +299,7 @@ describe('MCP Routes', () => { MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); + MCPTokenStorage.storeTokens.mockResolvedValue(); getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); @@ -332,11 +337,24 @@ describe('MCP Routes', () => { 'test-auth-code', mockFlowManager, ); + expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith( + expect.objectContaining({ + userId: 'test-user-id', + serverName: 'test-server', + tokens: mockTokens, + clientInfo: mockFlowState.clientInfo, + metadata: mockFlowState.metadata, + }), + ); + const storeInvocation = MCPTokenStorage.storeTokens.mock.invocationCallOrder[0]; + const connectInvocation = mockMcpManager.getUserConnection.mock.invocationCallOrder[0]; + expect(storeInvocation).toBeLessThan(connectInvocation); expect(mockFlowManager.completeFlow).toHaveBeenCalledWith( 'tool-flow-123', 'mcp_oauth', mockTokens, ); + expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens'); }); it('should redirect to error page when callback processing fails', async () => { @@ -354,6 +372,7 @@ describe('MCP Routes', () => { it('should handle system-level OAuth completion', async () => { const mockFlowManager = { completeFlow: jest.fn().mockResolvedValue(), + deleteFlow: jest.fn().mockResolvedValue(true), }; const mockFlowState = { serverName: 'test-server', @@ -369,6 +388,7 @@ describe('MCP Routes', () => { MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); + MCPTokenStorage.storeTokens.mockResolvedValue(); getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); @@ -379,11 +399,13 @@ describe('MCP Routes', () => { expect(response.status).toBe(302); expect(response.headers.location).toBe('/oauth/success?serverName=test-server'); + expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens'); }); it('should handle reconnection failure after OAuth', async () => { const mockFlowManager = { completeFlow: jest.fn().mockResolvedValue(), + deleteFlow: jest.fn().mockResolvedValue(true), }; const mockFlowState = { serverName: 'test-server', @@ -399,6 +421,7 @@ describe('MCP Routes', () => { MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); + MCPTokenStorage.storeTokens.mockResolvedValue(); getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); @@ -418,6 +441,46 @@ describe('MCP Routes', () => { expect(response.status).toBe(302); expect(response.headers.location).toBe('/oauth/success?serverName=test-server'); + expect(MCPTokenStorage.storeTokens).toHaveBeenCalled(); + expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens'); + }); + + it('should redirect to error page if token storage fails', async () => { + const mockFlowManager = { + completeFlow: jest.fn().mockResolvedValue(), + deleteFlow: jest.fn().mockResolvedValue(true), + }; + const mockFlowState = { + serverName: 'test-server', + userId: 'test-user-id', + metadata: { toolFlowId: 'tool-flow-123' }, + clientInfo: {}, + codeVerifier: 'test-verifier', + }; + const mockTokens = { + access_token: 'test-access-token', + refresh_token: 'test-refresh-token', + }; + + MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); + MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); + MCPTokenStorage.storeTokens.mockRejectedValue(new Error('store failed')); + getLogStores.mockReturnValue({}); + require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + + const mockMcpManager = { + getUserConnection: jest.fn(), + }; + require('~/config').getMCPManager.mockReturnValue(mockMcpManager); + + const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ + code: 'test-auth-code', + state: 'test-flow-id', + }); + + expect(response.status).toBe(302); + expect(response.headers.location).toBe('/oauth/error?error=callback_failed'); + expect(mockMcpManager.getUserConnection).not.toHaveBeenCalled(); }); }); @@ -1143,7 +1206,11 @@ describe('MCP Routes', () => { describe('GET /:serverName/oauth/callback - Edge Cases', () => { it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => { - const { MCPOAuthHandler } = require('@librechat/api'); + const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api'); + const mockTokens = { + access_token: 'edge-access-token', + refresh_token: 'edge-refresh-token', + }; MCPOAuthHandler.getFlowState = jest.fn().mockResolvedValue({ id: 'test-flow-id', userId: 'test-user-id', @@ -1155,6 +1222,8 @@ describe('MCP Routes', () => { clientInfo: {}, codeVerifier: 'test-verifier', }); + MCPOAuthHandler.completeOAuthFlow = jest.fn().mockResolvedValue(mockTokens); + MCPTokenStorage.storeTokens.mockResolvedValue(); const mockFlowManager = { completeFlow: jest.fn(), @@ -1179,6 +1248,11 @@ describe('MCP Routes', () => { it('should handle null cached tools in OAuth callback (triggers || {} fallback)', async () => { const { getCachedTools } = require('~/server/services/Config'); getCachedTools.mockResolvedValue(null); + const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api'); + const mockTokens = { + access_token: 'edge-access-token', + refresh_token: 'edge-refresh-token', + }; const mockFlowManager = { getFlowState: jest.fn().mockResolvedValue({ @@ -1191,6 +1265,15 @@ describe('MCP Routes', () => { completeFlow: jest.fn(), }; require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + MCPOAuthHandler.getFlowState.mockResolvedValue({ + serverName: 'test-server', + userId: 'test-user-id', + metadata: { serverUrl: 'https://example.com', oauth: {} }, + clientInfo: {}, + codeVerifier: 'test-verifier', + }); + MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); + MCPTokenStorage.storeTokens.mockResolvedValue(); const mockMcpManager = { getUserConnection: jest.fn().mockResolvedValue({ diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index d4aa225f23..9182a2a0cd 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -1,7 +1,7 @@ const { Router } = require('express'); const { logger } = require('@librechat/data-schemas'); const { CacheKeys, Constants } = require('librechat-data-provider'); -const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api'); +const { MCPOAuthHandler, MCPTokenStorage, getUserMCPAuthMap } = require('@librechat/api'); const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config'); const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP'); const { findToken, updateToken, createToken, deleteTokens } = require('~/models'); @@ -130,6 +130,41 @@ router.get('/:serverName/oauth/callback', async (req, res) => { const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager); logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route'); + /** Persist tokens immediately so reconnection uses fresh credentials */ + if (flowState?.userId && tokens) { + try { + await MCPTokenStorage.storeTokens({ + userId: flowState.userId, + serverName, + tokens, + createToken, + updateToken, + findToken, + clientInfo: flowState.clientInfo, + metadata: flowState.metadata, + }); + logger.debug('[MCP OAuth] Stored OAuth tokens prior to reconnection', { + serverName, + userId: flowState.userId, + }); + } catch (error) { + logger.error('[MCP OAuth] Failed to store OAuth tokens after callback', error); + throw error; + } + + /** + * Clear any cached `mcp_get_tokens` flow result so subsequent lookups + * re-fetch the freshly stored credentials instead of returning stale nulls. + */ + if (typeof flowManager?.deleteFlow === 'function') { + try { + await flowManager.deleteFlow(flowId, 'mcp_get_tokens'); + } catch (error) { + logger.warn('[MCP OAuth] Failed to clear cached token flow state', error); + } + } + } + try { const mcpManager = getMCPManager(flowState.userId); logger.debug(`[MCP OAuth] Attempting to reconnect ${serverName} with new OAuth tokens`); diff --git a/packages/api/src/mcp/oauth/handler.ts b/packages/api/src/mcp/oauth/handler.ts index a3fb144e60..a96dae8442 100644 --- a/packages/api/src/mcp/oauth/handler.ts +++ b/packages/api/src/mcp/oauth/handler.ts @@ -504,6 +504,7 @@ export class MCPOAuthHandler { let authMethods: string[] | undefined; if (config?.token_url) { tokenUrl = config.token_url; + authMethods = config.token_endpoint_auth_methods_supported; } else if (!metadata.serverUrl) { throw new Error('No token URL available for refresh'); } else {