🔐 fix: MCP OAuth Token Persistence Race Condition and Refresh Auth Method (#9773)

* set supported endpoint auth method when token_url exists

* persist tokens immediately

* add token storage validation tests
This commit is contained in:
Sean McGrath 2025-09-24 01:35:56 +12:00 committed by GitHub
parent 91e49d82aa
commit f61e057f7f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 122 additions and 3 deletions

View file

@ -11,6 +11,9 @@ jest.mock('@librechat/api', () => ({
completeOAuthFlow: jest.fn(),
generateFlowId: jest.fn(),
},
MCPTokenStorage: {
storeTokens: jest.fn(),
},
getUserMCPAuthMap: jest.fn(),
}));
@ -234,7 +237,7 @@ describe('MCP Routes', () => {
});
describe('GET /:serverName/oauth/callback', () => {
const { MCPOAuthHandler } = require('@librechat/api');
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
const { getLogStores } = require('~/cache');
it('should redirect to error page when OAuth error is received', async () => {
@ -280,6 +283,7 @@ describe('MCP Routes', () => {
it('should handle OAuth callback successfully', async () => {
const mockFlowManager = {
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
};
const mockFlowState = {
serverName: 'test-server',
@ -295,6 +299,7 @@ describe('MCP Routes', () => {
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
@ -332,11 +337,24 @@ describe('MCP Routes', () => {
'test-auth-code',
mockFlowManager,
);
expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith(
expect.objectContaining({
userId: 'test-user-id',
serverName: 'test-server',
tokens: mockTokens,
clientInfo: mockFlowState.clientInfo,
metadata: mockFlowState.metadata,
}),
);
const storeInvocation = MCPTokenStorage.storeTokens.mock.invocationCallOrder[0];
const connectInvocation = mockMcpManager.getUserConnection.mock.invocationCallOrder[0];
expect(storeInvocation).toBeLessThan(connectInvocation);
expect(mockFlowManager.completeFlow).toHaveBeenCalledWith(
'tool-flow-123',
'mcp_oauth',
mockTokens,
);
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
});
it('should redirect to error page when callback processing fails', async () => {
@ -354,6 +372,7 @@ describe('MCP Routes', () => {
it('should handle system-level OAuth completion', async () => {
const mockFlowManager = {
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
};
const mockFlowState = {
serverName: 'test-server',
@ -369,6 +388,7 @@ describe('MCP Routes', () => {
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
@ -379,11 +399,13 @@ describe('MCP Routes', () => {
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
});
it('should handle reconnection failure after OAuth', async () => {
const mockFlowManager = {
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
};
const mockFlowState = {
serverName: 'test-server',
@ -399,6 +421,7 @@ describe('MCP Routes', () => {
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
@ -418,6 +441,46 @@ describe('MCP Routes', () => {
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
expect(MCPTokenStorage.storeTokens).toHaveBeenCalled();
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
});
it('should redirect to error page if token storage fails', async () => {
const mockFlowManager = {
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
};
const mockFlowState = {
serverName: 'test-server',
userId: 'test-user-id',
metadata: { toolFlowId: 'tool-flow-123' },
clientInfo: {},
codeVerifier: 'test-verifier',
};
const mockTokens = {
access_token: 'test-access-token',
refresh_token: 'test-refresh-token',
};
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockRejectedValue(new Error('store failed'));
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const mockMcpManager = {
getUserConnection: jest.fn(),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
});
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/error?error=callback_failed');
expect(mockMcpManager.getUserConnection).not.toHaveBeenCalled();
});
});
@ -1143,7 +1206,11 @@ describe('MCP Routes', () => {
describe('GET /:serverName/oauth/callback - Edge Cases', () => {
it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => {
const { MCPOAuthHandler } = require('@librechat/api');
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
const mockTokens = {
access_token: 'edge-access-token',
refresh_token: 'edge-refresh-token',
};
MCPOAuthHandler.getFlowState = jest.fn().mockResolvedValue({
id: 'test-flow-id',
userId: 'test-user-id',
@ -1155,6 +1222,8 @@ describe('MCP Routes', () => {
clientInfo: {},
codeVerifier: 'test-verifier',
});
MCPOAuthHandler.completeOAuthFlow = jest.fn().mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
const mockFlowManager = {
completeFlow: jest.fn(),
@ -1179,6 +1248,11 @@ describe('MCP Routes', () => {
it('should handle null cached tools in OAuth callback (triggers || {} fallback)', async () => {
const { getCachedTools } = require('~/server/services/Config');
getCachedTools.mockResolvedValue(null);
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
const mockTokens = {
access_token: 'edge-access-token',
refresh_token: 'edge-refresh-token',
};
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({
@ -1191,6 +1265,15 @@ describe('MCP Routes', () => {
completeFlow: jest.fn(),
};
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
MCPOAuthHandler.getFlowState.mockResolvedValue({
serverName: 'test-server',
userId: 'test-user-id',
metadata: { serverUrl: 'https://example.com', oauth: {} },
clientInfo: {},
codeVerifier: 'test-verifier',
});
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
const mockMcpManager = {
getUserConnection: jest.fn().mockResolvedValue({