From d04da60b3be26cbab283ceb640718142cd946c32 Mon Sep 17 00:00:00 2001 From: Federico Ruggi Date: Wed, 17 Sep 2025 22:49:36 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AB=20feat:=20MCP=20OAuth=20Auto-Recon?= =?UTF-8?q?nect=20(#9646)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add oauth reconnect tracker * add connection tracker to mcp manager * reconnect oauth mcp servers function * call reconnection in auth controller * make sure to check connection in panel * wait for isConnected * add const for poll interval * add logging to tryReconnect * check expiration * check mcp manager is not null * check mcp manager is not null * add test for reconnecting mcp server * unify logic inside OAuthReconnectionManager * test reconnection manager, adjust * chore: reorder import statements in index.js * chore: imports * chore: imports * chore: imports * chore: imports * chore: imports * chore: imports and use types explicitly --------- Co-authored-by: Danny Avila --- api/config/index.js | 4 +- api/server/controllers/AuthController.js | 11 +- api/server/index.js | 7 +- api/server/routes/mcp.js | 8 +- api/server/services/MCP.js | 19 +- api/server/services/MCP.spec.js | 56 ++++ .../initializeOAuthReconnectManager.js | 26 ++ .../src/components/SidePanel/MCP/MCPPanel.tsx | 27 +- packages/api/src/index.ts | 1 + .../oauth/OAuthReconnectionManager.test.ts | 294 ++++++++++++++++++ .../src/mcp/oauth/OAuthReconnectionManager.ts | 163 ++++++++++ .../oauth/OAuthReconnectionTracker.test.ts | 181 +++++++++++ .../src/mcp/oauth/OAuthReconnectionTracker.ts | 46 +++ 13 files changed, 830 insertions(+), 13 deletions(-) create mode 100644 api/server/services/initializeOAuthReconnectManager.js create mode 100644 packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts create mode 100644 packages/api/src/mcp/oauth/OAuthReconnectionManager.ts create mode 100644 packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts create mode 100644 packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts diff --git a/api/config/index.js b/api/config/index.js index 2ffcf1cdf..0ddbee166 100644 --- a/api/config/index.js +++ b/api/config/index.js @@ -1,6 +1,6 @@ -const { MCPManager, FlowStateManager } = require('@librechat/api'); const { EventSource } = require('eventsource'); const { Time } = require('librechat-data-provider'); +const { MCPManager, FlowStateManager, OAuthReconnectionManager } = require('@librechat/api'); const logger = require('./winston'); global.EventSource = EventSource; @@ -26,4 +26,6 @@ module.exports = { createMCPManager: MCPManager.createInstance, getMCPManager: MCPManager.getInstance, getFlowStateManager, + createOAuthReconnectionManager: OAuthReconnectionManager.createInstance, + getOAuthReconnectionManager: OAuthReconnectionManager.getInstance, }; diff --git a/api/server/controllers/AuthController.js b/api/server/controllers/AuthController.js index 08d649932..e1d0977f1 100644 --- a/api/server/controllers/AuthController.js +++ b/api/server/controllers/AuthController.js @@ -11,8 +11,9 @@ const { registerUser, } = require('~/server/services/AuthService'); const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models'); -const { getOpenIdConfig } = require('~/strategies'); const { getGraphApiToken } = require('~/server/services/GraphTokenService'); +const { getOAuthReconnectionManager } = require('~/config'); +const { getOpenIdConfig } = require('~/strategies'); const registrationController = async (req, res) => { try { @@ -107,6 +108,14 @@ const refreshController = async (req, res) => { if (session && session.expiration > new Date()) { const token = await setAuthTokens(userId, res, session); + + // trigger OAuth MCP server reconnection asynchronously (best effort) + void getOAuthReconnectionManager() + .reconnectServers(userId) + .catch((err) => { + logger.error('Error reconnecting OAuth MCP servers:', err); + }); + res.status(200).send({ token, user }); } else if (req?.query?.retry) { // Retrying from a refresh token request that failed (401) diff --git a/api/server/index.js b/api/server/index.js index a8f9ac24e..e458b0349 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -12,6 +12,7 @@ const { logger } = require('@librechat/data-schemas'); const mongoSanitize = require('express-mongo-sanitize'); const { isEnabled, ErrorController } = require('@librechat/api'); const { connectDb, indexSync } = require('~/db'); +const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager'); const createValidateImageRequest = require('./middleware/validateImageRequest'); const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies'); const { updateInterfacePermissions } = require('~/models/interface'); @@ -154,7 +155,7 @@ const startServer = async () => { res.send(updatedIndexHtml); }); - app.listen(port, host, () => { + app.listen(port, host, async () => { if (host === '0.0.0.0') { logger.info( `Server listening on all interfaces at port ${port}. Use http://localhost:${port} to access it`, @@ -163,7 +164,9 @@ const startServer = async () => { logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`); } - initializeMCPs().then(() => checkMigrations()); + await initializeMCPs(); + await initializeOAuthReconnectManager(); + await checkMigrations(); }); }; diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index d41cc6d73..bff919158 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -1,12 +1,12 @@ const { Router } = require('express'); const { logger } = require('@librechat/data-schemas'); +const { CacheKeys, Constants } = require('librechat-data-provider'); const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api'); +const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config'); const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP'); const { findToken, updateToken, createToken, deleteTokens } = require('~/models'); const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache'); const { getUserPluginAuthValue } = require('~/server/services/PluginService'); -const { CacheKeys, Constants } = require('librechat-data-provider'); -const { getMCPManager, getFlowStateManager } = require('~/config'); const { reinitMCPServer } = require('~/server/services/Tools/mcp'); const { requireJwtAuth } = require('~/server/middleware'); const { findPluginAuthsByKeys } = require('~/models'); @@ -144,6 +144,10 @@ router.get('/:serverName/oauth/callback', async (req, res) => { `[MCP OAuth] Successfully reconnected ${serverName} for user ${flowState.userId}`, ); + // clear any reconnection attempts + const oauthReconnectionManager = getOAuthReconnectionManager(); + oauthReconnectionManager.clearReconnection(flowState.userId, serverName); + const tools = await userConnection.fetchTools(); await updateMCPUserTools({ userId: flowState.userId, diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index 3521f19ab..bc32dabeb 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -20,8 +20,8 @@ const { ContentTypes, isAssistantsEndpoint, } = require('librechat-data-provider'); +const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config'); const { findToken, createToken, updateToken } = require('~/models'); -const { getMCPManager, getFlowStateManager } = require('~/config'); const { getCachedTools, getAppConfig } = require('./Config'); const { reinitMCPServer } = require('./Tools/mcp'); const { getLogStores } = require('~/cache'); @@ -538,13 +538,20 @@ async function getServerConnectionStatus( const baseConnectionState = getConnectionState(); let finalConnectionState = baseConnectionState; + // connection state overrides specific to OAuth servers if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) { - const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName); - - if (hasFailedFlow) { - finalConnectionState = 'error'; - } else if (hasActiveFlow) { + // check if server is actively being reconnected + const oauthReconnectionManager = getOAuthReconnectionManager(); + if (oauthReconnectionManager.isReconnecting(userId, serverName)) { finalConnectionState = 'connecting'; + } else { + const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName); + + if (hasFailedFlow) { + finalConnectionState = 'error'; + } else if (hasActiveFlow) { + finalConnectionState = 'connecting'; + } } } diff --git a/api/server/services/MCP.spec.js b/api/server/services/MCP.spec.js index 3751c8a88..cb90ace24 100644 --- a/api/server/services/MCP.spec.js +++ b/api/server/services/MCP.spec.js @@ -31,6 +31,7 @@ jest.mock('./Config', () => ({ jest.mock('~/config', () => ({ getMCPManager: jest.fn(), getFlowStateManager: jest.fn(), + getOAuthReconnectionManager: jest.fn(), })); jest.mock('~/cache', () => ({ @@ -48,6 +49,7 @@ describe('tests for the new helper functions used by the MCP connection status e let mockGetMCPManager; let mockGetFlowStateManager; let mockGetLogStores; + let mockGetOAuthReconnectionManager; beforeEach(() => { jest.clearAllMocks(); @@ -56,6 +58,7 @@ describe('tests for the new helper functions used by the MCP connection status e mockGetMCPManager = require('~/config').getMCPManager; mockGetFlowStateManager = require('~/config').getFlowStateManager; mockGetLogStores = require('~/cache').getLogStores; + mockGetOAuthReconnectionManager = require('~/config').getOAuthReconnectionManager; }); describe('getMCPSetupData', () => { @@ -354,6 +357,12 @@ describe('tests for the new helper functions used by the MCP connection status e const userConnections = new Map(); const oauthServers = new Set([mockServerName]); + // Mock OAuthReconnectionManager + const mockOAuthReconnectionManager = { + isReconnecting: jest.fn(() => false), + }; + mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager); + const result = await getServerConnectionStatus( mockUserId, mockServerName, @@ -370,6 +379,12 @@ describe('tests for the new helper functions used by the MCP connection status e const userConnections = new Map(); const oauthServers = new Set([mockServerName]); + // Mock OAuthReconnectionManager + const mockOAuthReconnectionManager = { + isReconnecting: jest.fn(() => false), + }; + mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager); + // Mock flow state to return failed flow const mockFlowManager = { getFlowState: jest.fn(() => ({ @@ -401,6 +416,12 @@ describe('tests for the new helper functions used by the MCP connection status e const userConnections = new Map(); const oauthServers = new Set([mockServerName]); + // Mock OAuthReconnectionManager + const mockOAuthReconnectionManager = { + isReconnecting: jest.fn(() => false), + }; + mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager); + // Mock flow state to return active flow const mockFlowManager = { getFlowState: jest.fn(() => ({ @@ -432,6 +453,12 @@ describe('tests for the new helper functions used by the MCP connection status e const userConnections = new Map(); const oauthServers = new Set([mockServerName]); + // Mock OAuthReconnectionManager + const mockOAuthReconnectionManager = { + isReconnecting: jest.fn(() => false), + }; + mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager); + // Mock flow state to return no flow const mockFlowManager = { getFlowState: jest.fn(() => null), @@ -454,6 +481,35 @@ describe('tests for the new helper functions used by the MCP connection status e }); }); + it('should return connecting state when OAuth server is reconnecting', async () => { + const appConnections = new Map(); + const userConnections = new Map(); + const oauthServers = new Set([mockServerName]); + + // Mock OAuthReconnectionManager to return true for isReconnecting + const mockOAuthReconnectionManager = { + isReconnecting: jest.fn(() => true), + }; + mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager); + + const result = await getServerConnectionStatus( + mockUserId, + mockServerName, + appConnections, + userConnections, + oauthServers, + ); + + expect(result).toEqual({ + requiresOAuth: true, + connectionState: 'connecting', + }); + expect(mockOAuthReconnectionManager.isReconnecting).toHaveBeenCalledWith( + mockUserId, + mockServerName, + ); + }); + it('should not check OAuth flow status when server is connected', async () => { const mockFlowManager = { getFlowState: jest.fn(), diff --git a/api/server/services/initializeOAuthReconnectManager.js b/api/server/services/initializeOAuthReconnectManager.js new file mode 100644 index 000000000..a3af3a736 --- /dev/null +++ b/api/server/services/initializeOAuthReconnectManager.js @@ -0,0 +1,26 @@ +const { logger } = require('@librechat/data-schemas'); +const { CacheKeys } = require('librechat-data-provider'); +const { createOAuthReconnectionManager, getFlowStateManager } = require('~/config'); +const { findToken, updateToken, createToken, deleteTokens } = require('~/models'); +const { getLogStores } = require('~/cache'); + +/** + * Initialize OAuth reconnect manager + */ +async function initializeOAuthReconnectManager() { + try { + const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS)); + const tokenMethods = { + findToken, + updateToken, + createToken, + deleteTokens, + }; + await createOAuthReconnectionManager(flowManager, tokenMethods); + logger.info(`OAuth reconnect manager initialized successfully.`); + } catch (error) { + logger.error('Failed to initialize OAuth reconnect manager:', error); + } +} + +module.exports = initializeOAuthReconnectManager; diff --git a/client/src/components/SidePanel/MCP/MCPPanel.tsx b/client/src/components/SidePanel/MCP/MCPPanel.tsx index 7beee34d0..adcb81e0e 100644 --- a/client/src/components/SidePanel/MCP/MCPPanel.tsx +++ b/client/src/components/SidePanel/MCP/MCPPanel.tsx @@ -1,4 +1,4 @@ -import React, { useState, useMemo, useCallback } from 'react'; +import React, { useState, useMemo, useCallback, useEffect } from 'react'; import { ChevronLeft, Trash2 } from 'lucide-react'; import { useQueryClient } from '@tanstack/react-query'; import { Button, useToastContext } from '@librechat/client'; @@ -12,6 +12,8 @@ import { useLocalize, useMCPConnectionStatus } from '~/hooks'; import { useGetStartupConfig } from '~/data-provider'; import MCPPanelSkeleton from './MCPPanelSkeleton'; +const POLL_FOR_CONNECTION_STATUS_INTERVAL = 2_000; // ms + function MCPPanelContent() { const localize = useLocalize(); const queryClient = useQueryClient(); @@ -26,6 +28,29 @@ function MCPPanelContent() { null, ); + // Check if any connections are in 'connecting' state + const hasConnectingServers = useMemo(() => { + if (!connectionStatus) { + return false; + } + return Object.values(connectionStatus).some( + (status) => status?.connectionState === 'connecting', + ); + }, [connectionStatus]); + + // Set up polling when servers are connecting + useEffect(() => { + if (!hasConnectingServers) { + return; + } + + const intervalId = setInterval(() => { + queryClient.invalidateQueries([QueryKeys.mcpConnectionStatus]); + }, POLL_FOR_CONNECTION_STATUS_INTERVAL); + + return () => clearInterval(intervalId); + }, [hasConnectingServers, queryClient]); + const updateUserPluginsMutation = useUpdateUserPluginsMutation({ onSuccess: async () => { showToast({ message: localize('com_nav_mcp_vars_updated'), status: 'success' }); diff --git a/packages/api/src/index.ts b/packages/api/src/index.ts index e9cbfd1ff..6cfdc9bcc 100644 --- a/packages/api/src/index.ts +++ b/packages/api/src/index.ts @@ -14,6 +14,7 @@ export * from './utils'; export * from './db/utils'; /* OAuth */ export * from './oauth'; +export * from './mcp/oauth/OAuthReconnectionManager'; /* Crypto */ export * from './crypto'; /* Flow */ diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts b/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts new file mode 100644 index 000000000..78fedb9c3 --- /dev/null +++ b/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts @@ -0,0 +1,294 @@ +import { TokenMethods } from '@librechat/data-schemas'; +import { FlowStateManager, MCPConnection, MCPOAuthTokens, MCPOptions } from '../..'; +import { MCPManager } from '../MCPManager'; +import { OAuthReconnectionManager } from './OAuthReconnectionManager'; +import { OAuthReconnectionTracker } from './OAuthReconnectionTracker'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, +})); + +jest.mock('../MCPManager'); + +describe('OAuthReconnectionManager', () => { + let flowManager: jest.Mocked>; + let tokenMethods: jest.Mocked; + let mockMCPManager: jest.Mocked; + let reconnectionManager: OAuthReconnectionManager; + + beforeEach(() => { + jest.clearAllMocks(); + + // Reset singleton instance + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (OAuthReconnectionManager as any).instance = null; + + // Setup mock flow manager + flowManager = { + createFlow: jest.fn(), + completeFlow: jest.fn(), + failFlow: jest.fn(), + deleteFlow: jest.fn(), + getFlow: jest.fn(), + } as unknown as jest.Mocked>; + + // Setup mock token methods + tokenMethods = { + findToken: jest.fn(), + createToken: jest.fn(), + updateToken: jest.fn(), + deleteToken: jest.fn(), + } as unknown as jest.Mocked; + + // Setup mock MCP Manager + mockMCPManager = { + getOAuthServers: jest.fn(), + getUserConnection: jest.fn(), + getUserConnections: jest.fn(), + disconnectUserConnection: jest.fn(), + getRawConfig: jest.fn(), + } as unknown as jest.Mocked; + + (MCPManager.getInstance as jest.Mock).mockReturnValue(mockMCPManager); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + describe('Singleton Pattern', () => { + it('should create instance successfully', async () => { + const instance = await OAuthReconnectionManager.createInstance(flowManager, tokenMethods); + expect(instance).toBeInstanceOf(OAuthReconnectionManager); + }); + + it('should throw error when creating instance twice', async () => { + await OAuthReconnectionManager.createInstance(flowManager, tokenMethods); + await expect( + OAuthReconnectionManager.createInstance(flowManager, tokenMethods), + ).rejects.toThrow('OAuthReconnectionManager already initialized'); + }); + + it('should throw error when getting instance before creation', () => { + expect(() => OAuthReconnectionManager.getInstance()).toThrow( + 'OAuthReconnectionManager not initialized', + ); + }); + }); + + describe('isReconnecting', () => { + let reconnectionTracker: OAuthReconnectionTracker; + beforeEach(async () => { + reconnectionTracker = new OAuthReconnectionTracker(); + reconnectionManager = await OAuthReconnectionManager.createInstance( + flowManager, + tokenMethods, + reconnectionTracker, + ); + }); + + it('should return true when server is actively reconnecting', () => { + const userId = 'user-123'; + const serverName = 'test-server'; + + expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(false); + + reconnectionTracker.setActive(userId, serverName); + const result = reconnectionManager.isReconnecting(userId, serverName); + expect(result).toBe(true); + }); + + it('should return false when server is not reconnecting', () => { + const userId = 'user-123'; + const serverName = 'test-server'; + + const result = reconnectionManager.isReconnecting(userId, serverName); + expect(result).toBe(false); + }); + }); + + describe('clearReconnection', () => { + let reconnectionTracker: OAuthReconnectionTracker; + beforeEach(async () => { + reconnectionTracker = new OAuthReconnectionTracker(); + reconnectionManager = await OAuthReconnectionManager.createInstance( + flowManager, + tokenMethods, + reconnectionTracker, + ); + }); + + it('should clear both failed and active reconnection states', () => { + const userId = 'user-123'; + const serverName = 'test-server'; + + reconnectionTracker.setFailed(userId, serverName); + reconnectionTracker.setActive(userId, serverName); + + reconnectionManager.clearReconnection(userId, serverName); + + expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(false); + expect(reconnectionTracker.isFailed(userId, serverName)).toBe(false); + expect(reconnectionTracker.isActive(userId, serverName)).toBe(false); + }); + }); + + describe('reconnectServers', () => { + let reconnectionTracker: OAuthReconnectionTracker; + beforeEach(async () => { + reconnectionTracker = new OAuthReconnectionTracker(); + reconnectionManager = await OAuthReconnectionManager.createInstance( + flowManager, + tokenMethods, + reconnectionTracker, + ); + }); + + it('should reconnect eligible servers', async () => { + const userId = 'user-123'; + const oauthServers = new Set(['server1', 'server2', 'server3']); + mockMCPManager.getOAuthServers.mockReturnValue(oauthServers); + + // server1: has failed reconnection + reconnectionTracker.setFailed(userId, 'server1'); + + // server2: already connected + const mockConnection = { + isConnected: jest.fn().mockResolvedValue(true), + }; + const userConnections = new Map([['server2', mockConnection]]); + mockMCPManager.getUserConnections.mockReturnValue( + userConnections as unknown as Map, + ); + + // server3: has valid token and not connected + tokenMethods.findToken.mockImplementation(async ({ identifier }) => { + if (identifier === 'mcp:server3') { + return { + userId, + identifier, + expiresAt: new Date(Date.now() + 3600000), // 1 hour from now + } as unknown as MCPOAuthTokens; + } + return null; + }); + + // Mock successful reconnection + const mockNewConnection = { + isConnected: jest.fn().mockResolvedValue(true), + disconnect: jest.fn(), + }; + mockMCPManager.getUserConnection.mockResolvedValue( + mockNewConnection as unknown as MCPConnection, + ); + mockMCPManager.getRawConfig.mockReturnValue({ initTimeout: 5000 } as unknown as MCPOptions); + + await reconnectionManager.reconnectServers(userId); + + // Verify server3 was marked as active + expect(reconnectionTracker.isActive(userId, 'server3')).toBe(true); + + // Wait for async tryReconnect to complete + await new Promise((resolve) => setTimeout(resolve, 100)); + + // Verify reconnection was attempted for server3 + expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith({ + serverName: 'server3', + user: { id: userId }, + flowManager, + tokenMethods, + forceNew: false, + connectionTimeout: 5000, + returnOnOAuth: true, + }); + + // Verify successful reconnection cleared the states + expect(reconnectionTracker.isFailed(userId, 'server3')).toBe(false); + expect(reconnectionTracker.isActive(userId, 'server3')).toBe(false); + }); + + it('should handle failed reconnection attempts', async () => { + const userId = 'user-123'; + const oauthServers = new Set(['server1']); + mockMCPManager.getOAuthServers.mockReturnValue(oauthServers); + + // server1: has valid token + tokenMethods.findToken.mockResolvedValue({ + userId, + identifier: 'mcp:server1', + expiresAt: new Date(Date.now() + 3600000), + } as unknown as MCPOAuthTokens); + + // Mock failed connection + mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed')); + mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions); + + await reconnectionManager.reconnectServers(userId); + + // Wait for async tryReconnect to complete + await new Promise((resolve) => setTimeout(resolve, 100)); + + // Verify failure handling + expect(reconnectionTracker.isFailed(userId, 'server1')).toBe(true); + expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false); + expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1'); + }); + + it('should not reconnect servers with expired tokens', async () => { + const userId = 'user-123'; + const oauthServers = new Set(['server1']); + mockMCPManager.getOAuthServers.mockReturnValue(oauthServers); + + // server1: has expired token + tokenMethods.findToken.mockResolvedValue({ + userId, + identifier: 'mcp:server1', + expiresAt: new Date(Date.now() - 3600000), // 1 hour ago + } as unknown as MCPOAuthTokens); + + await reconnectionManager.reconnectServers(userId); + + // Verify no reconnection attempt was made + expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false); + expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled(); + }); + + it('should handle connection that returns but is not connected', async () => { + const userId = 'user-123'; + const oauthServers = new Set(['server1']); + mockMCPManager.getOAuthServers.mockReturnValue(oauthServers); + + tokenMethods.findToken.mockResolvedValue({ + userId, + identifier: 'mcp:server1', + expiresAt: new Date(Date.now() + 3600000), + } as unknown as MCPOAuthTokens); + + // Mock connection that returns but is not connected + const mockConnection = { + isConnected: jest.fn().mockResolvedValue(false), + disconnect: jest.fn(), + }; + mockMCPManager.getUserConnection.mockResolvedValue( + mockConnection as unknown as MCPConnection, + ); + mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions); + + await reconnectionManager.reconnectServers(userId); + + // Wait for async tryReconnect to complete + await new Promise((resolve) => setTimeout(resolve, 100)); + + // Verify failure handling + expect(mockConnection.disconnect).toHaveBeenCalled(); + expect(reconnectionTracker.isFailed(userId, 'server1')).toBe(true); + expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false); + expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1'); + }); + }); +}); diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts new file mode 100644 index 000000000..48b751dfa --- /dev/null +++ b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts @@ -0,0 +1,163 @@ +import { logger } from '@librechat/data-schemas'; +import type { TokenMethods } from '@librechat/data-schemas'; +import type { TUser } from 'librechat-data-provider'; +import type { MCPOAuthTokens } from './types'; +import { OAuthReconnectionTracker } from './OAuthReconnectionTracker'; +import { FlowStateManager } from '~/flow/manager'; +import { MCPManager } from '~/mcp/MCPManager'; + +const DEFAULT_CONNECTION_TIMEOUT_MS = 10_000; // ms + +export class OAuthReconnectionManager { + private static instance: OAuthReconnectionManager | null = null; + + protected readonly flowManager: FlowStateManager; + protected readonly tokenMethods: TokenMethods; + + private readonly reconnectionsTracker: OAuthReconnectionTracker; + + public static getInstance(): OAuthReconnectionManager { + if (!OAuthReconnectionManager.instance) { + throw new Error('OAuthReconnectionManager not initialized'); + } + return OAuthReconnectionManager.instance; + } + + public static async createInstance( + flowManager: FlowStateManager, + tokenMethods: TokenMethods, + reconnections?: OAuthReconnectionTracker, + ): Promise { + if (OAuthReconnectionManager.instance != null) { + throw new Error('OAuthReconnectionManager already initialized'); + } + + const manager = new OAuthReconnectionManager(flowManager, tokenMethods, reconnections); + OAuthReconnectionManager.instance = manager; + + return manager; + } + + public constructor( + flowManager: FlowStateManager, + tokenMethods: TokenMethods, + reconnections?: OAuthReconnectionTracker, + ) { + this.flowManager = flowManager; + this.tokenMethods = tokenMethods; + this.reconnectionsTracker = reconnections ?? new OAuthReconnectionTracker(); + } + + public isReconnecting(userId: string, serverName: string): boolean { + return this.reconnectionsTracker.isActive(userId, serverName); + } + + public async reconnectServers(userId: string) { + const mcpManager = MCPManager.getInstance(); + + // 1. derive the servers to reconnect + const serversToReconnect = []; + for (const serverName of mcpManager.getOAuthServers() ?? []) { + const canReconnect = await this.canReconnect(userId, serverName); + if (canReconnect) { + serversToReconnect.push(serverName); + } + } + + // 2. mark the servers as reconnecting + for (const serverName of serversToReconnect) { + this.reconnectionsTracker.setActive(userId, serverName); + } + + // 3. attempt to reconnect the servers + for (const serverName of serversToReconnect) { + void this.tryReconnect(userId, serverName); + } + } + + public clearReconnection(userId: string, serverName: string) { + this.reconnectionsTracker.removeFailed(userId, serverName); + this.reconnectionsTracker.removeActive(userId, serverName); + } + + private async tryReconnect(userId: string, serverName: string) { + const mcpManager = MCPManager.getInstance(); + + const logPrefix = `[tryReconnectOAuthMCPServer][User: ${userId}][${serverName}]`; + + logger.info(`${logPrefix} Attempting reconnection`); + + const config = mcpManager.getRawConfig(serverName); + + const cleanupOnFailedReconnect = () => { + this.reconnectionsTracker.setFailed(userId, serverName); + this.reconnectionsTracker.removeActive(userId, serverName); + mcpManager.disconnectUserConnection(userId, serverName); + }; + + try { + // attempt to get connection (this will use existing tokens and refresh if needed) + const connection = await mcpManager.getUserConnection({ + serverName, + user: { id: userId } as TUser, + flowManager: this.flowManager, + tokenMethods: this.tokenMethods, + // don't force new connection, let it reuse existing or create new as needed + forceNew: false, + // set a reasonable timeout for reconnection attempts + connectionTimeout: config?.initTimeout ?? DEFAULT_CONNECTION_TIMEOUT_MS, + // don't trigger OAuth flow during reconnection + returnOnOAuth: true, + }); + + if (connection && (await connection.isConnected())) { + logger.info(`${logPrefix} Successfully reconnected`); + this.clearReconnection(userId, serverName); + } else { + logger.warn(`${logPrefix} Failed to reconnect`); + await connection?.disconnect(); + cleanupOnFailedReconnect(); + } + } catch (error) { + logger.warn(`${logPrefix} Failed to reconnect: ${error}`); + cleanupOnFailedReconnect(); + } + } + + private async canReconnect(userId: string, serverName: string) { + const mcpManager = MCPManager.getInstance(); + + // if the server has failed reconnection, don't attempt to reconnect + if (this.reconnectionsTracker.isFailed(userId, serverName)) { + return false; + } + + // if the server is already connected, don't attempt to reconnect + const existingConnections = mcpManager.getUserConnections(userId); + if (existingConnections?.has(serverName)) { + const isConnected = await existingConnections.get(serverName)?.isConnected(); + if (isConnected) { + return false; + } + } + + // if the server has no tokens for the user, don't attempt to reconnect + const accessToken = await this.tokenMethods.findToken({ + userId, + type: 'mcp_oauth', + identifier: `mcp:${serverName}`, + }); + if (accessToken == null) { + return false; + } + + // if the token has expired, don't attempt to reconnect + const now = new Date(); + if (accessToken.expiresAt && accessToken.expiresAt < now) { + return false; + } + + // …otherwise, we're good to go with the reconnect attempt + return true; + } +} diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts b/packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts new file mode 100644 index 000000000..2a4516dd4 --- /dev/null +++ b/packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts @@ -0,0 +1,181 @@ +import { OAuthReconnectionTracker } from './OAuthReconnectionTracker'; + +describe('OAuthReconnectTracker', () => { + let tracker: OAuthReconnectionTracker; + const userId = 'user123'; + const serverName = 'test-server'; + const anotherServer = 'another-server'; + + beforeEach(() => { + tracker = new OAuthReconnectionTracker(); + }); + + describe('setFailed', () => { + it('should record a failed reconnection attempt', () => { + tracker.setFailed(userId, serverName); + expect(tracker.isFailed(userId, serverName)).toBe(true); + }); + + it('should track multiple servers for the same user', () => { + tracker.setFailed(userId, serverName); + tracker.setFailed(userId, anotherServer); + + expect(tracker.isFailed(userId, serverName)).toBe(true); + expect(tracker.isFailed(userId, anotherServer)).toBe(true); + }); + + it('should track different users independently', () => { + const anotherUserId = 'user456'; + + tracker.setFailed(userId, serverName); + tracker.setFailed(anotherUserId, serverName); + + expect(tracker.isFailed(userId, serverName)).toBe(true); + expect(tracker.isFailed(anotherUserId, serverName)).toBe(true); + }); + }); + + describe('isFailed', () => { + it('should return false when no failed attempt is recorded', () => { + expect(tracker.isFailed(userId, serverName)).toBe(false); + }); + + it('should return true after a failed attempt is recorded', () => { + tracker.setFailed(userId, serverName); + expect(tracker.isFailed(userId, serverName)).toBe(true); + }); + + it('should return false for a different server even after another server failed', () => { + tracker.setFailed(userId, serverName); + expect(tracker.isFailed(userId, anotherServer)).toBe(false); + }); + }); + + describe('removeFailed', () => { + it('should clear a failed reconnect record', () => { + tracker.setFailed(userId, serverName); + expect(tracker.isFailed(userId, serverName)).toBe(true); + + tracker.removeFailed(userId, serverName); + expect(tracker.isFailed(userId, serverName)).toBe(false); + }); + + it('should only clear the specific server for the user', () => { + tracker.setFailed(userId, serverName); + tracker.setFailed(userId, anotherServer); + + tracker.removeFailed(userId, serverName); + + expect(tracker.isFailed(userId, serverName)).toBe(false); + expect(tracker.isFailed(userId, anotherServer)).toBe(true); + }); + + it('should handle clearing non-existent records gracefully', () => { + expect(() => tracker.removeFailed(userId, serverName)).not.toThrow(); + }); + }); + + describe('setActive', () => { + it('should mark a server as reconnecting', () => { + tracker.setActive(userId, serverName); + expect(tracker.isActive(userId, serverName)).toBe(true); + }); + + it('should track multiple reconnecting servers', () => { + tracker.setActive(userId, serverName); + tracker.setActive(userId, anotherServer); + + expect(tracker.isActive(userId, serverName)).toBe(true); + expect(tracker.isActive(userId, anotherServer)).toBe(true); + }); + }); + + describe('isActive', () => { + it('should return false when server is not reconnecting', () => { + expect(tracker.isActive(userId, serverName)).toBe(false); + }); + + it('should return true when server is marked as reconnecting', () => { + tracker.setActive(userId, serverName); + expect(tracker.isActive(userId, serverName)).toBe(true); + }); + + it('should handle non-existent user gracefully', () => { + expect(tracker.isActive('non-existent-user', serverName)).toBe(false); + }); + }); + + describe('removeActive', () => { + it('should clear reconnecting state for a server', () => { + tracker.setActive(userId, serverName); + expect(tracker.isActive(userId, serverName)).toBe(true); + + tracker.removeActive(userId, serverName); + expect(tracker.isActive(userId, serverName)).toBe(false); + }); + + it('should only clear specific server state', () => { + tracker.setActive(userId, serverName); + tracker.setActive(userId, anotherServer); + + tracker.removeActive(userId, serverName); + + expect(tracker.isActive(userId, serverName)).toBe(false); + expect(tracker.isActive(userId, anotherServer)).toBe(true); + }); + + it('should handle clearing non-existent state gracefully', () => { + expect(() => tracker.removeActive(userId, serverName)).not.toThrow(); + }); + }); + + describe('cleanup behavior', () => { + it('should clean up empty user sets for failed reconnects', () => { + tracker.setFailed(userId, serverName); + tracker.removeFailed(userId, serverName); + + // Record and clear another user to ensure internal cleanup + const anotherUserId = 'user456'; + tracker.setFailed(anotherUserId, serverName); + + // Original user should still be able to reconnect + expect(tracker.isFailed(userId, serverName)).toBe(false); + }); + + it('should clean up empty user sets for active reconnections', () => { + tracker.setActive(userId, serverName); + tracker.removeActive(userId, serverName); + + // Mark another user to ensure internal cleanup + const anotherUserId = 'user456'; + tracker.setActive(anotherUserId, serverName); + + // Original user should not be reconnecting + expect(tracker.isActive(userId, serverName)).toBe(false); + }); + }); + + describe('combined state management', () => { + it('should handle both failed and reconnecting states independently', () => { + // Mark as reconnecting + tracker.setActive(userId, serverName); + expect(tracker.isActive(userId, serverName)).toBe(true); + expect(tracker.isFailed(userId, serverName)).toBe(false); + + // Record failed attempt + tracker.setFailed(userId, serverName); + expect(tracker.isActive(userId, serverName)).toBe(true); + expect(tracker.isFailed(userId, serverName)).toBe(true); + + // Clear reconnecting state + tracker.removeActive(userId, serverName); + expect(tracker.isActive(userId, serverName)).toBe(false); + expect(tracker.isFailed(userId, serverName)).toBe(true); + + // Clear failed state + tracker.removeFailed(userId, serverName); + expect(tracker.isActive(userId, serverName)).toBe(false); + expect(tracker.isFailed(userId, serverName)).toBe(false); + }); + }); +}); diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts b/packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts new file mode 100644 index 000000000..f18decd1a --- /dev/null +++ b/packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts @@ -0,0 +1,46 @@ +export class OAuthReconnectionTracker { + // Map of userId -> Set of serverNames that have failed reconnection + private failed: Map> = new Map(); + // Map of userId -> Set of serverNames that are actively reconnecting + private active: Map> = new Map(); + + public isFailed(userId: string, serverName: string): boolean { + return this.failed.get(userId)?.has(serverName) ?? false; + } + + public isActive(userId: string, serverName: string): boolean { + return this.active.get(userId)?.has(serverName) ?? false; + } + + public setFailed(userId: string, serverName: string): void { + if (!this.failed.has(userId)) { + this.failed.set(userId, new Set()); + } + + this.failed.get(userId)?.add(serverName); + } + + public setActive(userId: string, serverName: string): void { + if (!this.active.has(userId)) { + this.active.set(userId, new Set()); + } + + this.active.get(userId)?.add(serverName); + } + + public removeFailed(userId: string, serverName: string): void { + const userServers = this.failed.get(userId); + userServers?.delete(serverName); + if (userServers?.size === 0) { + this.failed.delete(userId); + } + } + + public removeActive(userId: string, serverName: string): void { + const userServers = this.active.get(userId); + userServers?.delete(serverName); + if (userServers?.size === 0) { + this.active.delete(userId); + } + } +}