mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-17 00:40:14 +01:00
🔐 fix: MCP OAuth Token Persistence Race Condition and Refresh Auth Method (#9773)
* set supported endpoint auth method when token_url exists * persist tokens immediately * add token storage validation tests
This commit is contained in:
parent
91e49d82aa
commit
f61e057f7f
3 changed files with 122 additions and 3 deletions
|
|
@ -11,6 +11,9 @@ jest.mock('@librechat/api', () => ({
|
||||||
completeOAuthFlow: jest.fn(),
|
completeOAuthFlow: jest.fn(),
|
||||||
generateFlowId: jest.fn(),
|
generateFlowId: jest.fn(),
|
||||||
},
|
},
|
||||||
|
MCPTokenStorage: {
|
||||||
|
storeTokens: jest.fn(),
|
||||||
|
},
|
||||||
getUserMCPAuthMap: jest.fn(),
|
getUserMCPAuthMap: jest.fn(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
|
@ -234,7 +237,7 @@ describe('MCP Routes', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('GET /:serverName/oauth/callback', () => {
|
describe('GET /:serverName/oauth/callback', () => {
|
||||||
const { MCPOAuthHandler } = require('@librechat/api');
|
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
|
||||||
const { getLogStores } = require('~/cache');
|
const { getLogStores } = require('~/cache');
|
||||||
|
|
||||||
it('should redirect to error page when OAuth error is received', async () => {
|
it('should redirect to error page when OAuth error is received', async () => {
|
||||||
|
|
@ -280,6 +283,7 @@ describe('MCP Routes', () => {
|
||||||
it('should handle OAuth callback successfully', async () => {
|
it('should handle OAuth callback successfully', async () => {
|
||||||
const mockFlowManager = {
|
const mockFlowManager = {
|
||||||
completeFlow: jest.fn().mockResolvedValue(),
|
completeFlow: jest.fn().mockResolvedValue(),
|
||||||
|
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||||
};
|
};
|
||||||
const mockFlowState = {
|
const mockFlowState = {
|
||||||
serverName: 'test-server',
|
serverName: 'test-server',
|
||||||
|
|
@ -295,6 +299,7 @@ describe('MCP Routes', () => {
|
||||||
|
|
||||||
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||||
|
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||||
getLogStores.mockReturnValue({});
|
getLogStores.mockReturnValue({});
|
||||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||||
|
|
||||||
|
|
@ -332,11 +337,24 @@ describe('MCP Routes', () => {
|
||||||
'test-auth-code',
|
'test-auth-code',
|
||||||
mockFlowManager,
|
mockFlowManager,
|
||||||
);
|
);
|
||||||
|
expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
userId: 'test-user-id',
|
||||||
|
serverName: 'test-server',
|
||||||
|
tokens: mockTokens,
|
||||||
|
clientInfo: mockFlowState.clientInfo,
|
||||||
|
metadata: mockFlowState.metadata,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
const storeInvocation = MCPTokenStorage.storeTokens.mock.invocationCallOrder[0];
|
||||||
|
const connectInvocation = mockMcpManager.getUserConnection.mock.invocationCallOrder[0];
|
||||||
|
expect(storeInvocation).toBeLessThan(connectInvocation);
|
||||||
expect(mockFlowManager.completeFlow).toHaveBeenCalledWith(
|
expect(mockFlowManager.completeFlow).toHaveBeenCalledWith(
|
||||||
'tool-flow-123',
|
'tool-flow-123',
|
||||||
'mcp_oauth',
|
'mcp_oauth',
|
||||||
mockTokens,
|
mockTokens,
|
||||||
);
|
);
|
||||||
|
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should redirect to error page when callback processing fails', async () => {
|
it('should redirect to error page when callback processing fails', async () => {
|
||||||
|
|
@ -354,6 +372,7 @@ describe('MCP Routes', () => {
|
||||||
it('should handle system-level OAuth completion', async () => {
|
it('should handle system-level OAuth completion', async () => {
|
||||||
const mockFlowManager = {
|
const mockFlowManager = {
|
||||||
completeFlow: jest.fn().mockResolvedValue(),
|
completeFlow: jest.fn().mockResolvedValue(),
|
||||||
|
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||||
};
|
};
|
||||||
const mockFlowState = {
|
const mockFlowState = {
|
||||||
serverName: 'test-server',
|
serverName: 'test-server',
|
||||||
|
|
@ -369,6 +388,7 @@ describe('MCP Routes', () => {
|
||||||
|
|
||||||
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||||
|
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||||
getLogStores.mockReturnValue({});
|
getLogStores.mockReturnValue({});
|
||||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||||
|
|
||||||
|
|
@ -379,11 +399,13 @@ describe('MCP Routes', () => {
|
||||||
|
|
||||||
expect(response.status).toBe(302);
|
expect(response.status).toBe(302);
|
||||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||||
|
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle reconnection failure after OAuth', async () => {
|
it('should handle reconnection failure after OAuth', async () => {
|
||||||
const mockFlowManager = {
|
const mockFlowManager = {
|
||||||
completeFlow: jest.fn().mockResolvedValue(),
|
completeFlow: jest.fn().mockResolvedValue(),
|
||||||
|
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||||
};
|
};
|
||||||
const mockFlowState = {
|
const mockFlowState = {
|
||||||
serverName: 'test-server',
|
serverName: 'test-server',
|
||||||
|
|
@ -399,6 +421,7 @@ describe('MCP Routes', () => {
|
||||||
|
|
||||||
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||||
|
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||||
getLogStores.mockReturnValue({});
|
getLogStores.mockReturnValue({});
|
||||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||||
|
|
||||||
|
|
@ -418,6 +441,46 @@ describe('MCP Routes', () => {
|
||||||
|
|
||||||
expect(response.status).toBe(302);
|
expect(response.status).toBe(302);
|
||||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||||
|
expect(MCPTokenStorage.storeTokens).toHaveBeenCalled();
|
||||||
|
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should redirect to error page if token storage fails', async () => {
|
||||||
|
const mockFlowManager = {
|
||||||
|
completeFlow: jest.fn().mockResolvedValue(),
|
||||||
|
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||||
|
};
|
||||||
|
const mockFlowState = {
|
||||||
|
serverName: 'test-server',
|
||||||
|
userId: 'test-user-id',
|
||||||
|
metadata: { toolFlowId: 'tool-flow-123' },
|
||||||
|
clientInfo: {},
|
||||||
|
codeVerifier: 'test-verifier',
|
||||||
|
};
|
||||||
|
const mockTokens = {
|
||||||
|
access_token: 'test-access-token',
|
||||||
|
refresh_token: 'test-refresh-token',
|
||||||
|
};
|
||||||
|
|
||||||
|
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||||
|
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||||
|
MCPTokenStorage.storeTokens.mockRejectedValue(new Error('store failed'));
|
||||||
|
getLogStores.mockReturnValue({});
|
||||||
|
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||||
|
|
||||||
|
const mockMcpManager = {
|
||||||
|
getUserConnection: jest.fn(),
|
||||||
|
};
|
||||||
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
|
|
||||||
|
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||||
|
code: 'test-auth-code',
|
||||||
|
state: 'test-flow-id',
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(302);
|
||||||
|
expect(response.headers.location).toBe('/oauth/error?error=callback_failed');
|
||||||
|
expect(mockMcpManager.getUserConnection).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -1143,7 +1206,11 @@ describe('MCP Routes', () => {
|
||||||
|
|
||||||
describe('GET /:serverName/oauth/callback - Edge Cases', () => {
|
describe('GET /:serverName/oauth/callback - Edge Cases', () => {
|
||||||
it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => {
|
it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => {
|
||||||
const { MCPOAuthHandler } = require('@librechat/api');
|
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
|
||||||
|
const mockTokens = {
|
||||||
|
access_token: 'edge-access-token',
|
||||||
|
refresh_token: 'edge-refresh-token',
|
||||||
|
};
|
||||||
MCPOAuthHandler.getFlowState = jest.fn().mockResolvedValue({
|
MCPOAuthHandler.getFlowState = jest.fn().mockResolvedValue({
|
||||||
id: 'test-flow-id',
|
id: 'test-flow-id',
|
||||||
userId: 'test-user-id',
|
userId: 'test-user-id',
|
||||||
|
|
@ -1155,6 +1222,8 @@ describe('MCP Routes', () => {
|
||||||
clientInfo: {},
|
clientInfo: {},
|
||||||
codeVerifier: 'test-verifier',
|
codeVerifier: 'test-verifier',
|
||||||
});
|
});
|
||||||
|
MCPOAuthHandler.completeOAuthFlow = jest.fn().mockResolvedValue(mockTokens);
|
||||||
|
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||||
|
|
||||||
const mockFlowManager = {
|
const mockFlowManager = {
|
||||||
completeFlow: jest.fn(),
|
completeFlow: jest.fn(),
|
||||||
|
|
@ -1179,6 +1248,11 @@ describe('MCP Routes', () => {
|
||||||
it('should handle null cached tools in OAuth callback (triggers || {} fallback)', async () => {
|
it('should handle null cached tools in OAuth callback (triggers || {} fallback)', async () => {
|
||||||
const { getCachedTools } = require('~/server/services/Config');
|
const { getCachedTools } = require('~/server/services/Config');
|
||||||
getCachedTools.mockResolvedValue(null);
|
getCachedTools.mockResolvedValue(null);
|
||||||
|
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
|
||||||
|
const mockTokens = {
|
||||||
|
access_token: 'edge-access-token',
|
||||||
|
refresh_token: 'edge-refresh-token',
|
||||||
|
};
|
||||||
|
|
||||||
const mockFlowManager = {
|
const mockFlowManager = {
|
||||||
getFlowState: jest.fn().mockResolvedValue({
|
getFlowState: jest.fn().mockResolvedValue({
|
||||||
|
|
@ -1191,6 +1265,15 @@ describe('MCP Routes', () => {
|
||||||
completeFlow: jest.fn(),
|
completeFlow: jest.fn(),
|
||||||
};
|
};
|
||||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||||
|
MCPOAuthHandler.getFlowState.mockResolvedValue({
|
||||||
|
serverName: 'test-server',
|
||||||
|
userId: 'test-user-id',
|
||||||
|
metadata: { serverUrl: 'https://example.com', oauth: {} },
|
||||||
|
clientInfo: {},
|
||||||
|
codeVerifier: 'test-verifier',
|
||||||
|
});
|
||||||
|
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||||
|
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||||
|
|
||||||
const mockMcpManager = {
|
const mockMcpManager = {
|
||||||
getUserConnection: jest.fn().mockResolvedValue({
|
getUserConnection: jest.fn().mockResolvedValue({
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
const { Router } = require('express');
|
const { Router } = require('express');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||||
const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api');
|
const { MCPOAuthHandler, MCPTokenStorage, getUserMCPAuthMap } = require('@librechat/api');
|
||||||
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
|
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
|
||||||
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
||||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||||
|
|
@ -130,6 +130,41 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||||
const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager);
|
const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager);
|
||||||
logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route');
|
logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route');
|
||||||
|
|
||||||
|
/** Persist tokens immediately so reconnection uses fresh credentials */
|
||||||
|
if (flowState?.userId && tokens) {
|
||||||
|
try {
|
||||||
|
await MCPTokenStorage.storeTokens({
|
||||||
|
userId: flowState.userId,
|
||||||
|
serverName,
|
||||||
|
tokens,
|
||||||
|
createToken,
|
||||||
|
updateToken,
|
||||||
|
findToken,
|
||||||
|
clientInfo: flowState.clientInfo,
|
||||||
|
metadata: flowState.metadata,
|
||||||
|
});
|
||||||
|
logger.debug('[MCP OAuth] Stored OAuth tokens prior to reconnection', {
|
||||||
|
serverName,
|
||||||
|
userId: flowState.userId,
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('[MCP OAuth] Failed to store OAuth tokens after callback', error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clear any cached `mcp_get_tokens` flow result so subsequent lookups
|
||||||
|
* re-fetch the freshly stored credentials instead of returning stale nulls.
|
||||||
|
*/
|
||||||
|
if (typeof flowManager?.deleteFlow === 'function') {
|
||||||
|
try {
|
||||||
|
await flowManager.deleteFlow(flowId, 'mcp_get_tokens');
|
||||||
|
} catch (error) {
|
||||||
|
logger.warn('[MCP OAuth] Failed to clear cached token flow state', error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const mcpManager = getMCPManager(flowState.userId);
|
const mcpManager = getMCPManager(flowState.userId);
|
||||||
logger.debug(`[MCP OAuth] Attempting to reconnect ${serverName} with new OAuth tokens`);
|
logger.debug(`[MCP OAuth] Attempting to reconnect ${serverName} with new OAuth tokens`);
|
||||||
|
|
|
||||||
|
|
@ -504,6 +504,7 @@ export class MCPOAuthHandler {
|
||||||
let authMethods: string[] | undefined;
|
let authMethods: string[] | undefined;
|
||||||
if (config?.token_url) {
|
if (config?.token_url) {
|
||||||
tokenUrl = config.token_url;
|
tokenUrl = config.token_url;
|
||||||
|
authMethods = config.token_endpoint_auth_methods_supported;
|
||||||
} else if (!metadata.serverUrl) {
|
} else if (!metadata.serverUrl) {
|
||||||
throw new Error('No token URL available for refresh');
|
throw new Error('No token URL available for refresh');
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue