💫 feat: MCP OAuth Auto-Reconnect (#9646)

* add oauth reconnect tracker

* add connection tracker to mcp manager

* reconnect oauth mcp servers function

* call reconnection in auth controller

* make sure to check connection in panel

* wait for isConnected

* add const for poll interval

* add logging to tryReconnect

* check expiration

* check mcp manager is not null

* check mcp manager is not null

* add test for reconnecting mcp server

* unify logic inside OAuthReconnectionManager

* test reconnection manager, adjust

* chore: reorder import statements in index.js

* chore: imports

* chore: imports

* chore: imports

* chore: imports

* chore: imports

* chore: imports and use types explicitly

---------

Co-authored-by: Danny Avila <danny@librechat.ai>
This commit is contained in:
Federico Ruggi 2025-09-17 22:49:36 +02:00 committed by GitHub
parent 0e94d97bfb
commit d04da60b3b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 830 additions and 13 deletions

View file

@ -20,8 +20,8 @@ const {
ContentTypes,
isAssistantsEndpoint,
} = require('librechat-data-provider');
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
const { findToken, createToken, updateToken } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { getCachedTools, getAppConfig } = require('./Config');
const { reinitMCPServer } = require('./Tools/mcp');
const { getLogStores } = require('~/cache');
@ -538,13 +538,20 @@ async function getServerConnectionStatus(
const baseConnectionState = getConnectionState();
let finalConnectionState = baseConnectionState;
// connection state overrides specific to OAuth servers
if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
if (hasFailedFlow) {
finalConnectionState = 'error';
} else if (hasActiveFlow) {
// check if server is actively being reconnected
const oauthReconnectionManager = getOAuthReconnectionManager();
if (oauthReconnectionManager.isReconnecting(userId, serverName)) {
finalConnectionState = 'connecting';
} else {
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
if (hasFailedFlow) {
finalConnectionState = 'error';
} else if (hasActiveFlow) {
finalConnectionState = 'connecting';
}
}
}

View file

@ -31,6 +31,7 @@ jest.mock('./Config', () => ({
jest.mock('~/config', () => ({
getMCPManager: jest.fn(),
getFlowStateManager: jest.fn(),
getOAuthReconnectionManager: jest.fn(),
}));
jest.mock('~/cache', () => ({
@ -48,6 +49,7 @@ describe('tests for the new helper functions used by the MCP connection status e
let mockGetMCPManager;
let mockGetFlowStateManager;
let mockGetLogStores;
let mockGetOAuthReconnectionManager;
beforeEach(() => {
jest.clearAllMocks();
@ -56,6 +58,7 @@ describe('tests for the new helper functions used by the MCP connection status e
mockGetMCPManager = require('~/config').getMCPManager;
mockGetFlowStateManager = require('~/config').getFlowStateManager;
mockGetLogStores = require('~/cache').getLogStores;
mockGetOAuthReconnectionManager = require('~/config').getOAuthReconnectionManager;
});
describe('getMCPSetupData', () => {
@ -354,6 +357,12 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
@ -370,6 +379,12 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
// Mock flow state to return failed flow
const mockFlowManager = {
getFlowState: jest.fn(() => ({
@ -401,6 +416,12 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
// Mock flow state to return active flow
const mockFlowManager = {
getFlowState: jest.fn(() => ({
@ -432,6 +453,12 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
// Mock flow state to return no flow
const mockFlowManager = {
getFlowState: jest.fn(() => null),
@ -454,6 +481,35 @@ describe('tests for the new helper functions used by the MCP connection status e
});
});
it('should return connecting state when OAuth server is reconnecting', async () => {
const appConnections = new Map();
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager to return true for isReconnecting
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => true),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
appConnections,
userConnections,
oauthServers,
);
expect(result).toEqual({
requiresOAuth: true,
connectionState: 'connecting',
});
expect(mockOAuthReconnectionManager.isReconnecting).toHaveBeenCalledWith(
mockUserId,
mockServerName,
);
});
it('should not check OAuth flow status when server is connected', async () => {
const mockFlowManager = {
getFlowState: jest.fn(),

View file

@ -0,0 +1,26 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const { createOAuthReconnectionManager, getFlowStateManager } = require('~/config');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { getLogStores } = require('~/cache');
/**
* Initialize OAuth reconnect manager
*/
async function initializeOAuthReconnectManager() {
try {
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
const tokenMethods = {
findToken,
updateToken,
createToken,
deleteTokens,
};
await createOAuthReconnectionManager(flowManager, tokenMethods);
logger.info(`OAuth reconnect manager initialized successfully.`);
} catch (error) {
logger.error('Failed to initialize OAuth reconnect manager:', error);
}
}
module.exports = initializeOAuthReconnectManager;