From 96870e0da01531be1106e2d007203a9748c747c9 Mon Sep 17 00:00:00 2001
From: Danny Avila
Date: Sun, 21 Sep 2025 22:58:19 -0400
Subject: [PATCH] ⏳ refactor: MCP OAuth Polling with Gradual Backoff and Timeout Handling (#9752)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* refactor: Implement gradual backoff polling for OAuth connection status with timeout handling

* refactor: Enhance OAuth polling with gradual backoff and timeout handling; update reconnection tracking

* refactor: reconnection timeout behavior in OAuthReconnectionManager and OAuthReconnectionTracker

- Implement tests to verify reconnection timeout handling, including tracking of reconnection states and cleanup of timed-out entries.
- Enhance existing methods in OAuthReconnectionManager and OAuthReconnectionTracker to support timeout checks and cleanup logic.
- Ensure proper handling of multiple servers with different timeout periods and edge cases for active states.

* chore: remove comment

* refactor: Enforce strict 3-minute OAuth timeout with updated polling intervals and improved timeout handling

* refactor: Remove unused polling logic and prevent duplicate polling for servers in MCP server manager

* refactor: Update localization key for no memories message in MemoryViewer

* refactor: Improve MCP tool initialization by handling server failures

- Introduced a mechanism to track failed MCP servers, preventing retries for unavailable servers.
- Added logging for failed tool creation attempts to enhance debugging and monitoring.

* refactor: Update reconnection timeout to enforce a strict 3-minute limit

* ci: Update reconnection timeout tests to reflect a strict 3-minute limit

* ci: Update reconnection timeout tests to enforce a strict 3-minute limit

* chore: Remove unused MCP connection timeout message
---
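Note (editorial; text after the `---` marker is ignored by `git am`): the
backoff schedule introduced in useMCPServerManager.ts below is designed so the
poll intervals sum exactly to the 3-minute OAuth budget. A standalone
TypeScript sketch that mirrors getPollInterval and maxAttempts from the diff
and checks the arithmetic:

    // Mirrors the schedule in the useMCPServerManager.ts hunk below.
    const getPollInterval = (attempt: number): number => {
      if (attempt < 12) return 5000; // first minute: 12 polls at 5s
      if (attempt < 22) return 6000; // second minute: 10 polls at 6s
      return 7500; // final minute: 8 polls at 7.5s
    };

    const maxAttempts = 30;
    let totalMs = 0;
    for (let attempt = 0; attempt < maxAttempts; attempt++) {
      totalMs += getPollInterval(attempt);
    }
    console.log(totalMs); // 180000 ms, i.e. exactly the 3-minute OAUTH_TIMEOUT_MS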
 api/app/clients/tools/util/handleTools.js          |   9 +
 .../src/components/SidePanel/MCP/MCPPanel.tsx      |  27 +-
 .../SidePanel/Memories/MemoryViewer.tsx            |   2 +-
 client/src/hooks/MCP/useMCPServerManager.ts        |  88 +++++-
 client/src/locales/en/translation.json             |   2 +-
 .../oauth/OAuthReconnectionManager.test.ts         | 144 +++++++++
 .../src/mcp/oauth/OAuthReconnectionManager.ts      |   8 +-
 .../oauth/OAuthReconnectionTracker.test.ts         | 274 ++++++++++++++++++
 .../src/mcp/oauth/OAuthReconnectionTracker.ts      |  47 ++-
 9 files changed, 560 insertions(+), 41 deletions(-)

diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js
index 8bd4a46cfe..efaa80cfc8 100644
--- a/api/app/clients/tools/util/handleTools.js
+++ b/api/app/clients/tools/util/handleTools.js
@@ -409,12 +409,16 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
   const mcpToolPromises = [];
   /** MCP server tools are initialized sequentially by server */
   let index = -1;
+  const failedMCPServers = new Set();
   for (const [serverName, toolConfigs] of Object.entries(requestedMCPTools)) {
     index++;
     /** @type {LCAvailableTools} */
     let availableTools;
     for (const config of toolConfigs) {
       try {
+        if (failedMCPServers.has(serverName)) {
+          continue;
+        }
         const mcpParams = {
           res: options.res,
           userId: user,
@@ -458,6 +462,11 @@
           loadedTools.push(...mcpTool);
         } else if (mcpTool) {
           loadedTools.push(mcpTool);
+        } else {
+          failedMCPServers.add(serverName);
+          logger.warn(
+            `MCP tool creation failed for "${config.toolKey}", server may be unavailable or unauthenticated.`,
+          );
         }
       } catch (error) {
         logger.error(`Error loading MCP tool for server ${serverName}:`, error);
diff --git a/client/src/components/SidePanel/MCP/MCPPanel.tsx b/client/src/components/SidePanel/MCP/MCPPanel.tsx
index 5bb6d18b81..8f79023897 100644
--- a/client/src/components/SidePanel/MCP/MCPPanel.tsx
+++ b/client/src/components/SidePanel/MCP/MCPPanel.tsx
@@ -1,4 +1,4 @@
-import React, { useState, useMemo, useCallback, useEffect } from 'react';
+import React, { useState, useMemo, useCallback } from 'react';
 import { ChevronLeft, Trash2 } from 'lucide-react';
 import { useQueryClient } from '@tanstack/react-query';
 import { Button, useToastContext } from '@librechat/client';
@@ -12,8 +12,6 @@ import { useLocalize, useMCPConnectionStatus } from '~/hooks';
 import { useGetStartupConfig } from '~/data-provider';
 import MCPPanelSkeleton from './MCPPanelSkeleton';
 
-const POLL_FOR_CONNECTION_STATUS_INTERVAL = 2_000; // ms
-
 function MCPPanelContent() {
   const localize = useLocalize();
   const queryClient = useQueryClient();
@@ -28,29 +26,6 @@ function MCPPanelContent() {
     null,
   );
 
-  // Check if any connections are in 'connecting' state
-  const hasConnectingServers = useMemo(() => {
-    if (!connectionStatus) {
-      return false;
-    }
-    return Object.values(connectionStatus).some(
-      (status) => status?.connectionState === 'connecting',
-    );
-  }, [connectionStatus]);
-
-  // Set up polling when servers are connecting
-  useEffect(() => {
-    if (!hasConnectingServers) {
-      return;
-    }
-
-    const intervalId = setInterval(() => {
-      queryClient.invalidateQueries([QueryKeys.mcpConnectionStatus]);
-    }, POLL_FOR_CONNECTION_STATUS_INTERVAL);
-
-    return () => clearInterval(intervalId);
-  }, [hasConnectingServers, queryClient]);
-
   const updateUserPluginsMutation = useUpdateUserPluginsMutation({
     onSuccess: async () => {
       showToast({ message: localize('com_nav_mcp_vars_updated'), status: 'success' });
diff --git a/client/src/components/SidePanel/Memories/MemoryViewer.tsx b/client/src/components/SidePanel/Memories/MemoryViewer.tsx
index 473b1e06b4..befb19523c 100644
--- a/client/src/components/SidePanel/Memories/MemoryViewer.tsx
+++ b/client/src/components/SidePanel/Memories/MemoryViewer.tsx
@@ -362,7 +362,7 @@ export default function MemoryViewer() {
               <td
                 colSpan={hasUpdateAccess ? 2 : 1}
                 className="h-24 text-center text-sm text-text-secondary"
               >
-                {localize('com_ui_no_data')}
+                {localize('com_ui_no_memories')}
               </td>
             )}
diff --git a/client/src/hooks/MCP/useMCPServerManager.ts b/client/src/hooks/MCP/useMCPServerManager.ts
index 440e3bb14a..2e05569710 100644
--- a/client/src/hooks/MCP/useMCPServerManager.ts
+++ b/client/src/hooks/MCP/useMCPServerManager.ts
@@ -129,7 +129,7 @@ export function useMCPServerManager({ conversationId }: { conversationId?: strin
     (serverName: string) => {
       const state = serverStates[serverName];
       if (state?.pollInterval) {
-        clearInterval(state.pollInterval);
+        clearTimeout(state.pollInterval);
       }
       updateServerState(serverName, {
         isInitializing: false,
@@ -144,8 +144,53 @@ export function useMCPServerManager({ conversationId }: { conversationId?: strin
   const startServerPolling = useCallback(
     (serverName: string) => {
-      const pollInterval = setInterval(async () => {
+      // Prevent duplicate polling for the same server
+      const existingState = serverStates[serverName];
+      if (existingState?.pollInterval) {
+        console.debug(`[MCP Manager] Polling already active for ${serverName}, skipping duplicate`);
+        return;
+      }
+
+      let pollAttempts = 0;
+      let timeoutId: NodeJS.Timeout | null = null;
+
+      /** OAuth typically completes in 5 seconds to 3 minutes
+       * We enforce a strict 3-minute timeout with gradual backoff
+       */
+      const getPollInterval = (attempt: number): number => {
+        if (attempt < 12) return 5000; // First minute: every 5s (12 polls)
+        if (attempt < 22) return 6000; // Second minute: every 6s (10 polls)
+        return 7500; // Final minute: every 7.5s (8 polls)
+      };
+
+      const maxAttempts = 30; // Exactly 3 minutes (180 seconds) total
+      const OAUTH_TIMEOUT_MS = 180000; // 3 minutes in milliseconds
+
+      const pollOnce = async () => {
         try {
+          pollAttempts++;
+          const state = serverStates[serverName];
+
+          /** Stop polling after 3 minutes or max attempts */
+          const elapsedTime = state?.oauthStartTime
+            ? Date.now() - state.oauthStartTime
+            : pollAttempts * 5000; // Rough estimate if no start time
+
+          if (pollAttempts > maxAttempts || elapsedTime > OAUTH_TIMEOUT_MS) {
+            console.warn(
+              `[MCP Manager] OAuth timeout for ${serverName} after ${(elapsedTime / 1000).toFixed(0)}s (attempt ${pollAttempts})`,
+            );
+            showToast({
+              message: localize('com_ui_mcp_oauth_timeout', { 0: serverName }),
+              status: 'error',
+            });
+            if (timeoutId) {
+              clearTimeout(timeoutId);
+            }
+            cleanupServerState(serverName);
+            return;
+          }
+
           await queryClient.refetchQueries([QueryKeys.mcpConnectionStatus]);
 
           const freshConnectionData = queryClient.getQueryData([
@@ -153,11 +198,12 @@
           ]) as any;
           const freshConnectionStatus = freshConnectionData?.connectionStatus || {};
 
-          const state = serverStates[serverName];
           const serverStatus = freshConnectionStatus[serverName];
 
           if (serverStatus?.connectionState === 'connected') {
-            clearInterval(pollInterval);
+            if (timeoutId) {
+              clearTimeout(timeoutId);
+            }
 
             showToast({
               message: localize('com_ui_mcp_authenticated_success', { 0: serverName }),
@@ -179,12 +225,15 @@
             return;
           }
 
-          if (state?.oauthStartTime && Date.now() - state.oauthStartTime > 180000) {
+          // Check for OAuth timeout (should align with maxAttempts)
+          if (state?.oauthStartTime && Date.now() - state.oauthStartTime > OAUTH_TIMEOUT_MS) {
            showToast({
              message: localize('com_ui_mcp_oauth_timeout', { 0: serverName }),
              status: 'error',
            });
-            clearInterval(pollInterval);
+            if (timeoutId) {
+              clearTimeout(timeoutId);
+            }
             cleanupServerState(serverName);
             return;
           }
@@ -194,19 +243,38 @@
             message: localize('com_ui_mcp_init_failed'),
             status: 'error',
           });
-          if (timeoutId) {
-          clearInterval(pollInterval);
+          if (timeoutId) {
+            clearTimeout(timeoutId);
+          }
           cleanupServerState(serverName);
           return;
         }
+
+        // Schedule next poll with smart intervals based on OAuth timing
+        const nextInterval = getPollInterval(pollAttempts);
+
+        // Log progress periodically
+        if (pollAttempts % 5 === 0 || pollAttempts <= 2) {
+          console.debug(
+            `[MCP Manager] Polling ${serverName} attempt ${pollAttempts}/${maxAttempts}, next in ${nextInterval / 1000}s`,
+          );
+        }
+
+        timeoutId = setTimeout(pollOnce, nextInterval);
+        updateServerState(serverName, { pollInterval: timeoutId });
       } catch (error) {
         console.error(`[MCP Manager] Error polling server ${serverName}:`, error);
-        clearInterval(pollInterval);
+        if (timeoutId) {
+          clearTimeout(timeoutId);
+        }
         cleanupServerState(serverName);
         return;
       }
-      }, 3500);
+      };
 
-      updateServerState(serverName, { pollInterval });
+      // Start the first poll
+      timeoutId = setTimeout(pollOnce, getPollInterval(0));
+      updateServerState(serverName, { pollInterval: timeoutId });
     },
     [
       queryClient,
diff --git a/client/src/locales/en/translation.json b/client/src/locales/en/translation.json
index a92c1887ca..7c2e9e7f3e 100644
--- a/client/src/locales/en/translation.json
+++ b/client/src/locales/en/translation.json
@@ -1027,8 +1027,8 @@
   "com_ui_no_categories": "No categories available",
   "com_ui_no_category": "No category",
   "com_ui_no_changes": "No changes were made",
-  "com_ui_no_data": "something needs to go here. was empty",
   "com_ui_no_individual_access": "No individual users or groups have access to this agent",
+  "com_ui_no_memories": "No memories. Create them manually or prompt the AI to remember something",
   "com_ui_no_personalization_available": "No personalization options are currently available",
   "com_ui_no_read_access": "You don't have permission to view memories",
   "com_ui_no_results_found": "No results found",
diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts b/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts
index 8f18df2f5d..d2295191cf 100644
--- a/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts
+++ b/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts
@@ -319,4 +319,148 @@ describe('OAuthReconnectionManager', () => {
       expect(mockMCPManager.disconnectUserConnection).not.toHaveBeenCalled();
     });
   });
+
+  describe('reconnection timeout behavior', () => {
+    let reconnectionTracker: OAuthReconnectionTracker;
+
+    beforeEach(async () => {
+      jest.useFakeTimers();
+      reconnectionTracker = new OAuthReconnectionTracker();
+      reconnectionManager = await OAuthReconnectionManager.createInstance(
+        flowManager,
+        tokenMethods,
+        reconnectionTracker,
+      );
+    });
+
+    afterEach(() => {
+      jest.useRealTimers();
+    });
+
+    it('should handle timed out reconnections via isReconnecting check', () => {
+      const userId = 'user-123';
+      const serverName = 'test-server';
+      const now = Date.now();
+      jest.setSystemTime(now);
+
+      // Set server as reconnecting
+      reconnectionTracker.setActive(userId, serverName);
+      expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(true);
+
+      // Advance time by 2 minutes 59 seconds - should still be reconnecting
+      jest.advanceTimersByTime(2 * 60 * 1000 + 59 * 1000);
+      expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(true);
+
+      // Advance time by 2 more seconds (total 3 minutes 1 second) - should be auto-cleaned
+      jest.advanceTimersByTime(2000);
+      expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(false);
+    });
+
+    it('should not attempt to reconnect servers that have timed out during reconnection', async () => {
+      const userId = 'user-123';
+      const oauthServers = new Set(['server1', 'server2']);
+      mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
+
+      const now = Date.now();
+      jest.setSystemTime(now);
+
+      // Set server1 as having been reconnecting well past the 3-minute timeout
+      reconnectionTracker.setActive(userId, 'server1');
+      jest.advanceTimersByTime(6 * 60 * 1000); // 6 minutes
+
+      // server2: has valid token and not connected
+      tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
+        if (identifier === 'mcp:server2') {
+          return {
+            userId,
+            identifier,
+            expiresAt: new Date(Date.now() + 3600000),
+          } as unknown as MCPOAuthTokens;
+        }
+        return null;
+      });
+
+      // Mock successful reconnection
+      const mockNewConnection = {
+        isConnected: jest.fn().mockResolvedValue(true),
+        disconnect: jest.fn(),
+      };
+      mockMCPManager.getUserConnection.mockResolvedValue(
+        mockNewConnection as unknown as MCPConnection,
+      );
+
+      await reconnectionManager.reconnectServers(userId);
+
+      // server1 should still be in active set, just not eligible for reconnection
+      expect(reconnectionTracker.isActive(userId, 'server1')).toBe(true);
+      expect(reconnectionTracker.isStillReconnecting(userId, 'server1')).toBe(false);
+
+      // Only server2 should be marked as reconnecting
+      expect(reconnectionTracker.isActive(userId, 'server2')).toBe(true);
+
+      // Wait for async reconnection using runAllTimersAsync
+      await jest.runAllTimersAsync();
+
+      // Verify only server2 was reconnected
+      expect(mockMCPManager.getUserConnection).toHaveBeenCalledTimes(1);
+      expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
+        expect.objectContaining({
+          serverName: 'server2',
+        }),
+      );
+    });
+
+    it('should properly track reconnection time for multiple sequential reconnect attempts', async () => {
+      const userId = 'user-123';
+      const serverName = 'server1';
+      const oauthServers = new Set([serverName]);
+      mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
+
+      const now = Date.now();
+      jest.setSystemTime(now);
+
+      // Setup valid token
+      tokenMethods.findToken.mockResolvedValue({
+        userId,
+        identifier: `mcp:${serverName}`,
+        expiresAt: new Date(Date.now() + 3600000),
+      } as unknown as MCPOAuthTokens);
+
+      // First reconnect attempt - will fail
+      mockMCPManager.getUserConnection.mockRejectedValueOnce(new Error('Connection failed'));
+      mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
+
+      await reconnectionManager.reconnectServers(userId);
+      await jest.runAllTimersAsync();
+
+      // Server should be marked as failed
+      expect(reconnectionTracker.isFailed(userId, serverName)).toBe(true);
+      expect(reconnectionTracker.isActive(userId, serverName)).toBe(false);
+
+      // Clear failed state to allow another attempt
+      reconnectionManager.clearReconnection(userId, serverName);
+
+      // Advance time by 3 minutes
+      jest.advanceTimersByTime(3 * 60 * 1000);
+
+      // Second reconnect attempt - will succeed
+      const mockConnection = {
+        isConnected: jest.fn().mockResolvedValue(true),
+      };
+      mockMCPManager.getUserConnection.mockResolvedValue(
+        mockConnection as unknown as MCPConnection,
+      );
+
+      await reconnectionManager.reconnectServers(userId);
+
+      // Server should be marked as active with new timestamp
+      expect(reconnectionTracker.isActive(userId, serverName)).toBe(true);
+
+      await jest.runAllTimersAsync();
+
+      // After successful reconnection, should be cleared
+      expect(reconnectionTracker.isActive(userId, serverName)).toBe(false);
+      expect(reconnectionTracker.isFailed(userId, serverName)).toBe(false);
+    });
+  });
 });
diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts
index b819403a60..9e84ef1483 100644
--- a/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts
+++ b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts
@@ -56,7 +56,9 @@ export class OAuthReconnectionManager {
   }
 
   public isReconnecting(userId: string, serverName: string): boolean {
-    return this.reconnectionsTracker.isActive(userId, serverName);
+    // Clean up if timed out, then return whether still reconnecting
+    this.reconnectionsTracker.cleanupIfTimedOut(userId, serverName);
+    return this.reconnectionsTracker.isStillReconnecting(userId, serverName);
   }
 
   public async reconnectServers(userId: string) {
@@ -149,6 +151,10 @@
       return false;
     }
 
+    if (this.reconnectionsTracker.isActive(userId, serverName)) {
+      return false;
+    }
+
     // if the server is already connected, don't attempt to reconnect
     const existingConnections = this.mcpManager.getUserConnections(userId);
     if (existingConnections?.has(serverName)) {
diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts b/packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts
index 2a4516dd47..68ac1d027e 100644
--- a/packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts
+++ b/packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts
@@ -178,4 +178,278 @@ describe('OAuthReconnectTracker', () => {
     expect(tracker.isFailed(userId, serverName)).toBe(false);
    });
  });
+
+  describe('timeout behavior', () => {
+    beforeEach(() => {
+      jest.useFakeTimers();
+    });
+
+    afterEach(() => {
+      jest.useRealTimers();
+    });
+
+    it('should track timestamp when setting active state', () => {
+      const now = Date.now();
+      jest.setSystemTime(now);
+
+      tracker.setActive(userId, serverName);
+      expect(tracker.isActive(userId, serverName)).toBe(true);
+
+      // Verify timestamp was recorded (implementation detail tested via timeout behavior)
+      jest.advanceTimersByTime(1000); // 1 second
+      expect(tracker.isActive(userId, serverName)).toBe(true);
+    });
+
+    it('should handle timeout checking with isStillReconnecting', () => {
+      const now = Date.now();
+      jest.setSystemTime(now);
+
+      tracker.setActive(userId, serverName);
+      expect(tracker.isStillReconnecting(userId, serverName)).toBe(true);
+
+      // Advance time by 2 minutes 59 seconds - should still be reconnecting
+      jest.advanceTimersByTime(2 * 60 * 1000 + 59 * 1000);
+      expect(tracker.isStillReconnecting(userId, serverName)).toBe(true);
+
+      // Advance time by 2 more seconds (total 3 minutes 1 second) - should not be still reconnecting
+      jest.advanceTimersByTime(2000);
+      expect(tracker.isStillReconnecting(userId, serverName)).toBe(false);
+
+      // But isActive should still return true (simple check)
+      expect(tracker.isActive(userId, serverName)).toBe(true);
+    });
+
+    it('should handle multiple servers with different timeout periods', () => {
+      const now = Date.now();
+      jest.setSystemTime(now);
+
+      // Set server1 as active
+      tracker.setActive(userId, serverName);
+      expect(tracker.isActive(userId, serverName)).toBe(true);
+
+      // Advance 3 minutes
+      jest.advanceTimersByTime(3 * 60 * 1000);
+
+      // Set server2 as active
+      tracker.setActive(userId, anotherServer);
+      expect(tracker.isActive(userId, anotherServer)).toBe(true);
+      expect(tracker.isActive(userId, serverName)).toBe(true);
+
+      // Advance 2 more minutes + 1ms (server1 at 5 min + 1ms, server2 at 2 min + 1ms)
+      jest.advanceTimersByTime(2 * 60 * 1000 + 1);
+      expect(tracker.isStillReconnecting(userId, serverName)).toBe(false); // server1 timed out
+      expect(tracker.isStillReconnecting(userId, anotherServer)).toBe(true); // server2 still active
+
+      // Advance 3 more minutes (server2 at 5 min + 1ms)
+      jest.advanceTimersByTime(3 * 60 * 1000);
+      expect(tracker.isStillReconnecting(userId, anotherServer)).toBe(false); // server2 timed out
+    });
+
+    it('should clear timestamp when removing active state', () => {
+      const now = Date.now();
+      jest.setSystemTime(now);
+
+      tracker.setActive(userId, serverName);
+      expect(tracker.isActive(userId, serverName)).toBe(true);
+
+      tracker.removeActive(userId, serverName);
+      expect(tracker.isActive(userId, serverName)).toBe(false);
+
+      // Set active again and verify new timestamp is used
+      jest.advanceTimersByTime(3 * 60 * 1000);
+      tracker.setActive(userId, serverName);
+      expect(tracker.isActive(userId, serverName)).toBe(true);
+
+      // Advance 4 more minutes from new timestamp - should still be active
+      jest.advanceTimersByTime(4 * 60 * 1000);
+      expect(tracker.isActive(userId, serverName)).toBe(true);
+    });
+
+    it('should properly cleanup after timeout occurs', () => {
+      const now = Date.now();
+      jest.setSystemTime(now);
+
+      tracker.setActive(userId, serverName);
+      tracker.setActive(userId, anotherServer);
+      expect(tracker.isActive(userId, serverName)).toBe(true);
+      expect(tracker.isActive(userId, anotherServer)).toBe(true);
+
+      // Advance past timeout
+      jest.advanceTimersByTime(6 * 60 * 1000);
+
+      // Both should still be in active set but not "still reconnecting"
+      expect(tracker.isActive(userId, serverName)).toBe(true);
+      expect(tracker.isActive(userId, anotherServer)).toBe(true);
+      expect(tracker.isStillReconnecting(userId, serverName)).toBe(false);
+      expect(tracker.isStillReconnecting(userId, anotherServer)).toBe(false);
+
+      // Cleanup both
+      expect(tracker.cleanupIfTimedOut(userId, serverName)).toBe(true);
+      expect(tracker.cleanupIfTimedOut(userId, anotherServer)).toBe(true);
+
+      // Now they should be removed from active set
+      expect(tracker.isActive(userId, serverName)).toBe(false);
+      expect(tracker.isActive(userId, anotherServer)).toBe(false);
+    });
+
+    it('should handle timeout check for non-existent entries gracefully', () => {
+      const now = Date.now();
+      jest.setSystemTime(now);
+
+      // Check non-existent entry
+      expect(tracker.isActive('non-existent', 'non-existent')).toBe(false);
+      expect(tracker.isStillReconnecting('non-existent', 'non-existent')).toBe(false);
+
+      // Set and then manually remove
+      tracker.setActive(userId, serverName);
+      tracker.removeActive(userId, serverName);
+
+      // Advance time and check - should not throw
+      jest.advanceTimersByTime(6 * 60 * 1000);
+      expect(tracker.isActive(userId, serverName)).toBe(false);
+      expect(tracker.isStillReconnecting(userId, serverName)).toBe(false);
+    });
+  });
+
+  describe('isStillReconnecting', () => {
+    it('should return true for active entries within timeout', () => {
+      jest.useFakeTimers();
+      const now = Date.now();
+      jest.setSystemTime(now);
+
+      tracker.setActive(userId, serverName);
+      expect(tracker.isStillReconnecting(userId, serverName)).toBe(true);
+
+      // Still within timeout
+      jest.advanceTimersByTime(3 * 60 * 1000);
+      expect(tracker.isStillReconnecting(userId, serverName)).toBe(true);
+
+      jest.useRealTimers();
+    });
+
+    it('should return false for timed out entries', () => {
+      jest.useFakeTimers();
+      const now = Date.now();
+      jest.setSystemTime(now);
+
+      tracker.setActive(userId, serverName);
+
+      // Advance past timeout
+      jest.advanceTimersByTime(6 * 60 * 1000);
+
+      // Should not be still reconnecting
+      expect(tracker.isStillReconnecting(userId, serverName)).toBe(false);
+
+      // But isActive should still return true (simple check)
+      expect(tracker.isActive(userId, serverName)).toBe(true);
+
+      jest.useRealTimers();
+    });
+
+    it('should return false for non-existent entries', () => {
+      expect(tracker.isStillReconnecting('non-existent', 'non-existent')).toBe(false);
+      expect(tracker.isStillReconnecting(userId, serverName)).toBe(false);
+    });
+  });
+
+  describe('cleanupIfTimedOut', () => {
+    beforeEach(() => {
+      jest.useFakeTimers();
+    });
+
+    afterEach(() => {
+      jest.useRealTimers();
+    });
+
+    it('should cleanup timed out entries and return true', () => {
+      const now = Date.now();
+      jest.setSystemTime(now);
+
+      tracker.setActive(userId, serverName);
+      expect(tracker.isActive(userId, serverName)).toBe(true);
+
+      // Advance past timeout
+      jest.advanceTimersByTime(6 * 60 * 1000);
+
+      // Cleanup should return true and remove the entry
+      const wasCleanedUp = tracker.cleanupIfTimedOut(userId, serverName);
+      expect(wasCleanedUp).toBe(true);
+      expect(tracker.isActive(userId, serverName)).toBe(false);
+    });
+
+    it('should not cleanup active entries and return false', () => {
+      const now = Date.now();
+      jest.setSystemTime(now);
+
+      tracker.setActive(userId, serverName);
+
+      // Within timeout period
+      jest.advanceTimersByTime(3 * 60 * 1000);
+
+      const wasCleanedUp = tracker.cleanupIfTimedOut(userId, serverName);
+      expect(wasCleanedUp).toBe(false);
+      expect(tracker.isActive(userId, serverName)).toBe(true);
+    });
+
+    it('should return false for non-existent entries', () => {
+      const wasCleanedUp = tracker.cleanupIfTimedOut('non-existent', 'non-existent');
+      expect(wasCleanedUp).toBe(false);
+    });
+  });
+
+  describe('timestamp tracking edge cases', () => {
+    beforeEach(() => {
+      jest.useFakeTimers();
+    });
+
+    afterEach(() => {
+      jest.useRealTimers();
+    });
+
+    it('should update timestamp when setting active on already active server', () => {
+      const now = Date.now();
+      jest.setSystemTime(now);
+
+      tracker.setActive(userId, serverName);
+      expect(tracker.isActive(userId, serverName)).toBe(true);
+
+      // Advance 3 minutes
+      jest.advanceTimersByTime(3 * 60 * 1000);
+      expect(tracker.isActive(userId, serverName)).toBe(true);
+
+      // Set active again - should reset timestamp
+      tracker.setActive(userId, serverName);
+
+      // Advance 4 more minutes from reset (total 7 minutes from start)
+      jest.advanceTimersByTime(4 * 60 * 1000);
+      // Should still be active since timestamp was reset at 3 minutes
+      expect(tracker.isActive(userId, serverName)).toBe(true);
+
+      // Advance 2 more minutes (6 minutes from reset)
+      jest.advanceTimersByTime(2 * 60 * 1000);
+      // Should not be still reconnecting (timed out)
+      expect(tracker.isStillReconnecting(userId, serverName)).toBe(false);
+    });
+
+    it('should handle same server for different users independently', () => {
+      const anotherUserId = 'user456';
+      const now = Date.now();
+      jest.setSystemTime(now);
+
+      tracker.setActive(userId, serverName);
+
+      // Advance 3 minutes
+      jest.advanceTimersByTime(3 * 60 * 1000);
+
+      tracker.setActive(anotherUserId, serverName);
+
+      // Advance 3 more minutes
+      jest.advanceTimersByTime(3 * 60 * 1000);
+
+      // First user's connection should be timed out
+      expect(tracker.isStillReconnecting(userId, serverName)).toBe(false);
+      // Second user's connection should still be reconnecting
+      expect(tracker.isStillReconnecting(anotherUserId, serverName)).toBe(true);
+    });
+  });
 });
diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts b/packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts
index f18decd1ab..b65f8ad115 100644
--- a/packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts
+++ b/packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts
@@ -1,17 +1,52 @@
 export class OAuthReconnectionTracker {
-  // Map of userId -> Set of serverNames that have failed reconnection
+  /** Map of userId -> Set of serverNames that have failed reconnection */
   private failed: Map<string, Set<string>> = new Map();
-  // Map of userId -> Set of serverNames that are actively reconnecting
+  /** Map of userId -> Set of serverNames that are actively reconnecting */
   private active: Map<string, Set<string>> = new Map();
+  /** Map of userId:serverName -> timestamp when reconnection started */
+  private activeTimestamps: Map<string, number> = new Map();
+  /** Maximum time (ms) a server can be in reconnecting state before auto-cleanup */
+  private readonly RECONNECTION_TIMEOUT_MS = 3 * 60 * 1000; // 3 minutes
 
   public isFailed(userId: string, serverName: string): boolean {
     return this.failed.get(userId)?.has(serverName) ?? false;
   }
 
+  /** Check if server is in the active set (original simple check) */
   public isActive(userId: string, serverName: string): boolean {
     return this.active.get(userId)?.has(serverName) ?? false;
   }
 
+  /** Check if server is still reconnecting (considers timeout) */
+  public isStillReconnecting(userId: string, serverName: string): boolean {
+    if (!this.isActive(userId, serverName)) {
+      return false;
+    }
+
+    const key = `${userId}:${serverName}`;
+    const startTime = this.activeTimestamps.get(key);
+
+    // If there's a timestamp and it has timed out, it's not still reconnecting
+    if (startTime && Date.now() - startTime > this.RECONNECTION_TIMEOUT_MS) {
+      return false;
+    }
+
+    return true;
+  }
+
+  /** Clean up server if it has timed out - returns true if cleanup was performed */
+  public cleanupIfTimedOut(userId: string, serverName: string): boolean {
+    const key = `${userId}:${serverName}`;
+    const startTime = this.activeTimestamps.get(key);
+
+    if (startTime && Date.now() - startTime > this.RECONNECTION_TIMEOUT_MS) {
+      this.removeActive(userId, serverName);
+      return true;
+    }
+
+    return false;
+  }
+
   public setFailed(userId: string, serverName: string): void {
     if (!this.failed.has(userId)) {
       this.failed.set(userId, new Set());
@@ -26,6 +61,10 @@ export class OAuthReconnectionTracker {
     }
 
     this.active.get(userId)?.add(serverName);
+
+    /** Track when reconnection started */
+    const key = `${userId}:${serverName}`;
+    this.activeTimestamps.set(key, Date.now());
   }
 
   public removeFailed(userId: string, serverName: string): void {
@@ -42,5 +81,9 @@ export class OAuthReconnectionTracker {
     if (userServers?.size === 0) {
       this.active.delete(userId);
     }
+
+    /** Clear timestamp tracking */
+    const key = `${userId}:${serverName}`;
+    this.activeTimestamps.delete(key);
   }
 }