From ad5c51f62b321bd6f714e946ab08190616c09ecf Mon Sep 17 00:00:00 2001 From: matt burnett Date: Tue, 10 Mar 2026 11:21:36 -0700 Subject: [PATCH] =?UTF-8?q?=E2=9B=88=EF=B8=8F=20fix:=20MCP=20Reconnection?= =?UTF-8?q?=20Storm=20Prevention=20with=20Circuit=20Breaker,=20Backoff,=20?= =?UTF-8?q?and=20Tool=20Stubs=20(#12162)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: MCP reconnection stability - circuit breaker, throttling, and cooldown retry * Comment and logging cleanup * fix broken tests --- api/server/services/MCP.js | 66 ++++- packages/api/src/mcp/UserConnectionManager.ts | 3 + .../src/mcp/__tests__/MCPConnection.test.ts | 239 ++++++++++++++++++ .../MCPConnectionAgentLifecycle.test.ts | 4 + packages/api/src/mcp/connection.ts | 117 ++++++++- .../oauth/OAuthReconnectionManager.test.ts | 162 +++++++++++- .../src/mcp/oauth/OAuthReconnectionManager.ts | 44 +++- .../oauth/OAuthReconnectionTracker.test.ts | 95 +++++++ .../src/mcp/oauth/OAuthReconnectionTracker.ts | 44 +++- 9 files changed, 736 insertions(+), 38 deletions(-) diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index ad1f9f5cc3..4f8cdc8195 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -34,6 +34,39 @@ const { reinitMCPServer } = require('./Tools/mcp'); const { getAppConfig } = require('./Config'); const { getLogStores } = require('~/cache'); +const lastReconnectAttempts = new Map(); +const RECONNECT_THROTTLE_MS = 10_000; + +const missingToolCache = new Map(); +const MISSING_TOOL_TTL_MS = 10_000; + +const unavailableMsg = + "This tool's MCP server is temporarily unavailable. Please try again shortly."; + +/** + * @param {string} toolName + * @param {string} serverName + */ +function createUnavailableToolStub(toolName, serverName) { + const normalizedToolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`; + const _call = async () => unavailableMsg; + const toolInstance = tool(_call, { + schema: { + type: 'object', + properties: { + input: { type: 'string', description: 'Input for the tool' }, + }, + required: [], + }, + name: normalizedToolKey, + description: unavailableMsg, + responseFormat: AgentConstants.CONTENT_AND_ARTIFACT, + }); + toolInstance.mcp = true; + toolInstance.mcpRawServerName = serverName; + return toolInstance; +} + function isEmptyObjectSchema(jsonSchema) { return ( jsonSchema != null && @@ -211,6 +244,16 @@ async function reconnectServer({ logger.debug( `[MCP][reconnectServer] serverName: ${serverName}, user: ${user?.id}, hasUserMCPAuthMap: ${!!userMCPAuthMap}`, ); + + const throttleKey = `${user.id}:${serverName}`; + const now = Date.now(); + const lastAttempt = lastReconnectAttempts.get(throttleKey) ?? 0; + if (now - lastAttempt < RECONNECT_THROTTLE_MS) { + logger.debug(`[MCP][reconnectServer] Throttled reconnect for ${serverName}`); + return null; + } + lastReconnectAttempts.set(throttleKey, now); + const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID; const flowId = `${user.id}:${serverName}:${Date.now()}`; const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS)); @@ -267,7 +310,7 @@ async function reconnectServer({ userMCPAuthMap, forceNew: true, returnOnOAuth: false, - connectionTimeout: Time.TWO_MINUTES, + connectionTimeout: Time.THIRTY_SECONDS, }); } finally { // Clean up abort handler to prevent memory leaks @@ -332,7 +375,7 @@ async function createMCPTools({ }); if (!result || !result.tools) { logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`); - return; + return []; } const serverTools = []; @@ -402,6 +445,14 @@ async function createMCPTool({ /** @type {LCTool | undefined} */ let toolDefinition = availableTools?.[toolKey]?.function; if (!toolDefinition) { + const cachedAt = missingToolCache.get(toolKey); + if (cachedAt && Date.now() - cachedAt < MISSING_TOOL_TTL_MS) { + logger.debug( + `[MCP][${serverName}][${toolName}] Tool in negative cache, returning unavailable stub.`, + ); + return createUnavailableToolStub(toolName, serverName); + } + logger.warn( `[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`, ); @@ -415,11 +466,17 @@ async function createMCPTool({ streamId, }); toolDefinition = result?.availableTools?.[toolKey]?.function; + + if (!toolDefinition) { + missingToolCache.set(toolKey, Date.now()); + } } if (!toolDefinition) { - logger.warn(`[MCP][${serverName}][${toolName}] Tool definition not found, cannot create tool.`); - return; + logger.warn( + `[MCP][${serverName}][${toolName}] Tool definition not found, returning unavailable stub.`, + ); + return createUnavailableToolStub(toolName, serverName); } return createToolInstance({ @@ -720,4 +777,5 @@ module.exports = { getMCPSetupData, checkOAuthFlowStatus, getServerConnectionStatus, + createUnavailableToolStub, }; diff --git a/packages/api/src/mcp/UserConnectionManager.ts b/packages/api/src/mcp/UserConnectionManager.ts index c0ecd18fe2..0828b1720a 100644 --- a/packages/api/src/mcp/UserConnectionManager.ts +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -65,6 +65,9 @@ export abstract class UserConnectionManager { const userServerMap = this.userConnections.get(userId); let connection = forceNew ? undefined : userServerMap?.get(serverName); + if (forceNew) { + MCPConnection.clearCooldown(serverName); + } const now = Date.now(); // Check if user is idle diff --git a/packages/api/src/mcp/__tests__/MCPConnection.test.ts b/packages/api/src/mcp/__tests__/MCPConnection.test.ts index 4cca1b3316..5cb5606d57 100644 --- a/packages/api/src/mcp/__tests__/MCPConnection.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnection.test.ts @@ -559,3 +559,242 @@ describe('extractSSEErrorMessage', () => { }); }); }); + +/** + * Tests for circuit breaker logic. + * + * Uses standalone implementations that mirror the static/private circuit breaker + * methods in MCPConnection. Same approach as the error detection tests above. + */ +describe('MCPConnection Circuit Breaker', () => { + /** 5 cycles within 60s triggers a 30s cooldown */ + const CB_MAX_CYCLES = 5; + const CB_CYCLE_WINDOW_MS = 60_000; + const CB_CYCLE_COOLDOWN_MS = 30_000; + + /** 3 failed rounds within 120s triggers exponential backoff (30s - 300s) */ + const CB_MAX_FAILED_ROUNDS = 3; + const CB_FAILED_WINDOW_MS = 120_000; + const CB_BASE_BACKOFF_MS = 30_000; + const CB_MAX_BACKOFF_MS = 300_000; + + interface CircuitBreakerState { + cycleCount: number; + cycleWindowStart: number; + cooldownUntil: number; + failedRounds: number; + failedWindowStart: number; + failedBackoffUntil: number; + } + + function createCB(): CircuitBreakerState { + return { + cycleCount: 0, + cycleWindowStart: Date.now(), + cooldownUntil: 0, + failedRounds: 0, + failedWindowStart: Date.now(), + failedBackoffUntil: 0, + }; + } + + function isCircuitOpen(cb: CircuitBreakerState): boolean { + const now = Date.now(); + return now < cb.cooldownUntil || now < cb.failedBackoffUntil; + } + + function recordCycle(cb: CircuitBreakerState): void { + const now = Date.now(); + if (now - cb.cycleWindowStart > CB_CYCLE_WINDOW_MS) { + cb.cycleCount = 0; + cb.cycleWindowStart = now; + } + cb.cycleCount++; + if (cb.cycleCount >= CB_MAX_CYCLES) { + cb.cooldownUntil = now + CB_CYCLE_COOLDOWN_MS; + cb.cycleCount = 0; + cb.cycleWindowStart = now; + } + } + + function recordFailedRound(cb: CircuitBreakerState): void { + const now = Date.now(); + if (now - cb.failedWindowStart > CB_FAILED_WINDOW_MS) { + cb.failedRounds = 0; + cb.failedWindowStart = now; + } + cb.failedRounds++; + if (cb.failedRounds >= CB_MAX_FAILED_ROUNDS) { + const backoff = Math.min( + CB_BASE_BACKOFF_MS * Math.pow(2, cb.failedRounds - CB_MAX_FAILED_ROUNDS), + CB_MAX_BACKOFF_MS, + ); + cb.failedBackoffUntil = now + backoff; + } + } + + function resetFailedRounds(cb: CircuitBreakerState): void { + cb.failedRounds = 0; + cb.failedWindowStart = Date.now(); + cb.failedBackoffUntil = 0; + } + + beforeEach(() => { + jest.useFakeTimers(); + }); + + afterEach(() => { + jest.useRealTimers(); + }); + + describe('cycle tracking', () => { + it('should not trigger cooldown for fewer than 5 cycles', () => { + const now = Date.now(); + jest.setSystemTime(now); + + const cb = createCB(); + for (let i = 0; i < CB_MAX_CYCLES - 1; i++) { + recordCycle(cb); + } + expect(isCircuitOpen(cb)).toBe(false); + }); + + it('should trigger 30s cooldown after 5 cycles within 60s', () => { + const now = Date.now(); + jest.setSystemTime(now); + + const cb = createCB(); + for (let i = 0; i < CB_MAX_CYCLES; i++) { + recordCycle(cb); + } + expect(isCircuitOpen(cb)).toBe(true); + + jest.advanceTimersByTime(29_000); + expect(isCircuitOpen(cb)).toBe(true); + + jest.advanceTimersByTime(1_000); + expect(isCircuitOpen(cb)).toBe(false); + }); + + it('should reset cycle count when window expires', () => { + const now = Date.now(); + jest.setSystemTime(now); + + const cb = createCB(); + for (let i = 0; i < CB_MAX_CYCLES - 1; i++) { + recordCycle(cb); + } + + jest.advanceTimersByTime(CB_CYCLE_WINDOW_MS + 1); + + recordCycle(cb); + expect(isCircuitOpen(cb)).toBe(false); + }); + }); + + describe('failed round tracking', () => { + it('should not trigger backoff for fewer than 3 failures', () => { + const now = Date.now(); + jest.setSystemTime(now); + + const cb = createCB(); + for (let i = 0; i < CB_MAX_FAILED_ROUNDS - 1; i++) { + recordFailedRound(cb); + } + expect(isCircuitOpen(cb)).toBe(false); + }); + + it('should trigger 30s backoff after 3 failures within 120s', () => { + const now = Date.now(); + jest.setSystemTime(now); + + const cb = createCB(); + for (let i = 0; i < CB_MAX_FAILED_ROUNDS; i++) { + recordFailedRound(cb); + } + expect(isCircuitOpen(cb)).toBe(true); + + jest.advanceTimersByTime(CB_BASE_BACKOFF_MS); + expect(isCircuitOpen(cb)).toBe(false); + }); + + it('should use exponential backoff based on failure count', () => { + jest.setSystemTime(Date.now()); + + const cb = createCB(); + + for (let i = 0; i < 3; i++) { + recordFailedRound(cb); + } + expect(cb.failedBackoffUntil - Date.now()).toBe(30_000); + + recordFailedRound(cb); + expect(cb.failedBackoffUntil - Date.now()).toBe(60_000); + + recordFailedRound(cb); + expect(cb.failedBackoffUntil - Date.now()).toBe(120_000); + + recordFailedRound(cb); + expect(cb.failedBackoffUntil - Date.now()).toBe(240_000); + + // capped at 300s + recordFailedRound(cb); + expect(cb.failedBackoffUntil - Date.now()).toBe(300_000); + }); + + it('should reset failed window when window expires', () => { + const now = Date.now(); + jest.setSystemTime(now); + + const cb = createCB(); + recordFailedRound(cb); + recordFailedRound(cb); + + jest.advanceTimersByTime(CB_FAILED_WINDOW_MS + 1); + + recordFailedRound(cb); + expect(isCircuitOpen(cb)).toBe(false); + }); + }); + + describe('resetFailedRounds', () => { + it('should clear failed round state on successful connection', () => { + const now = Date.now(); + jest.setSystemTime(now); + + const cb = createCB(); + for (let i = 0; i < CB_MAX_FAILED_ROUNDS; i++) { + recordFailedRound(cb); + } + expect(isCircuitOpen(cb)).toBe(true); + + resetFailedRounds(cb); + expect(isCircuitOpen(cb)).toBe(false); + expect(cb.failedRounds).toBe(0); + expect(cb.failedBackoffUntil).toBe(0); + }); + }); + + describe('clearCooldown (registry deletion)', () => { + it('should allow connections after clearing circuit breaker state', () => { + const now = Date.now(); + jest.setSystemTime(now); + + const registry = new Map(); + const serverName = 'test-server'; + + const cb = createCB(); + registry.set(serverName, cb); + + for (let i = 0; i < CB_MAX_CYCLES; i++) { + recordCycle(cb); + } + expect(isCircuitOpen(cb)).toBe(true); + + registry.delete(serverName); + + const newCb = createCB(); + expect(isCircuitOpen(newCb)).toBe(false); + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/MCPConnectionAgentLifecycle.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionAgentLifecycle.test.ts index 14e0694558..281bd590db 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionAgentLifecycle.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionAgentLifecycle.test.ts @@ -207,6 +207,7 @@ describe('MCPConnection Agent lifecycle – streamable-http', () => { }); afterEach(async () => { + MCPConnection.clearCooldown('test'); await safeDisconnect(conn); conn = null; jest.restoreAllMocks(); @@ -366,6 +367,7 @@ describe('MCPConnection Agent lifecycle – SSE', () => { }); afterEach(async () => { + MCPConnection.clearCooldown('test-sse'); await safeDisconnect(conn); conn = null; jest.restoreAllMocks(); @@ -453,6 +455,7 @@ describe('Regression: old per-request Agent pattern leaks agents', () => { }); afterEach(async () => { + MCPConnection.clearCooldown('test-regression'); await safeDisconnect(conn); conn = null; jest.restoreAllMocks(); @@ -675,6 +678,7 @@ describe('MCPConnection SSE GET stream recovery – integration', () => { }); afterEach(async () => { + MCPConnection.clearCooldown('test-sse-recovery'); await safeDisconnect(conn); conn = null; jest.restoreAllMocks(); diff --git a/packages/api/src/mcp/connection.ts b/packages/api/src/mcp/connection.ts index 83f1af1824..cac0a4afc5 100644 --- a/packages/api/src/mcp/connection.ts +++ b/packages/api/src/mcp/connection.ts @@ -71,6 +71,25 @@ const FIVE_MINUTES = 5 * 60 * 1000; const DEFAULT_TIMEOUT = 60000; /** SSE connections through proxies may need longer initial handshake time */ const SSE_CONNECT_TIMEOUT = 120000; +const DEFAULT_INIT_TIMEOUT = 30000; + +interface CircuitBreakerState { + cycleCount: number; + cycleWindowStart: number; + cooldownUntil: number; + failedRounds: number; + failedWindowStart: number; + failedBackoffUntil: number; +} + +const CB_MAX_CYCLES = 5; +const CB_CYCLE_WINDOW_MS = 60_000; +const CB_CYCLE_COOLDOWN_MS = 30_000; + +const CB_MAX_FAILED_ROUNDS = 3; +const CB_FAILED_WINDOW_MS = 120_000; +const CB_BASE_BACKOFF_MS = 30_000; +const CB_MAX_BACKOFF_MS = 300_000; /** Default body timeout for Streamable HTTP GET SSE streams that idle between server pushes */ const DEFAULT_SSE_READ_TIMEOUT = FIVE_MINUTES; @@ -274,6 +293,80 @@ export class MCPConnection extends EventEmitter { */ public readonly createdAt: number; + private static circuitBreakers: Map = new Map(); + + public static clearCooldown(serverName: string): void { + MCPConnection.circuitBreakers.delete(serverName); + logger.debug(`[MCP][${serverName}] Circuit breaker state cleared`); + } + + private getCircuitBreaker(): CircuitBreakerState { + let cb = MCPConnection.circuitBreakers.get(this.serverName); + if (!cb) { + cb = { + cycleCount: 0, + cycleWindowStart: Date.now(), + cooldownUntil: 0, + failedRounds: 0, + failedWindowStart: Date.now(), + failedBackoffUntil: 0, + }; + MCPConnection.circuitBreakers.set(this.serverName, cb); + } + return cb; + } + + private isCircuitOpen(): boolean { + const cb = this.getCircuitBreaker(); + const now = Date.now(); + return now < cb.cooldownUntil || now < cb.failedBackoffUntil; + } + + private recordCycle(): void { + const cb = this.getCircuitBreaker(); + const now = Date.now(); + if (now - cb.cycleWindowStart > CB_CYCLE_WINDOW_MS) { + cb.cycleCount = 0; + cb.cycleWindowStart = now; + } + cb.cycleCount++; + if (cb.cycleCount >= CB_MAX_CYCLES) { + cb.cooldownUntil = now + CB_CYCLE_COOLDOWN_MS; + cb.cycleCount = 0; + cb.cycleWindowStart = now; + logger.warn( + `${this.getLogPrefix()} Circuit breaker: too many cycles, cooling down for ${CB_CYCLE_COOLDOWN_MS}ms`, + ); + } + } + + private recordFailedRound(): void { + const cb = this.getCircuitBreaker(); + const now = Date.now(); + if (now - cb.failedWindowStart > CB_FAILED_WINDOW_MS) { + cb.failedRounds = 0; + cb.failedWindowStart = now; + } + cb.failedRounds++; + if (cb.failedRounds >= CB_MAX_FAILED_ROUNDS) { + const backoff = Math.min( + CB_BASE_BACKOFF_MS * Math.pow(2, cb.failedRounds - CB_MAX_FAILED_ROUNDS), + CB_MAX_BACKOFF_MS, + ); + cb.failedBackoffUntil = now + backoff; + logger.warn( + `${this.getLogPrefix()} Circuit breaker: too many failures, backing off for ${backoff}ms`, + ); + } + } + + private resetFailedRounds(): void { + const cb = this.getCircuitBreaker(); + cb.failedRounds = 0; + cb.failedWindowStart = Date.now(); + cb.failedBackoffUntil = 0; + } + setRequestHeaders(headers: Record | null): void { if (!headers) { return; @@ -686,6 +779,12 @@ export class MCPConnection extends EventEmitter { return; } + if (this.isCircuitOpen()) { + this.connectionState = 'error'; + this.emit('connectionChange', 'error'); + throw new Error(`${this.getLogPrefix()} Circuit breaker is open, connection attempt blocked`); + } + this.emit('connectionChange', 'connecting'); this.connectPromise = (async () => { @@ -703,7 +802,7 @@ export class MCPConnection extends EventEmitter { this.transport = await runOutsideTracing(() => this.constructTransport(this.options)); this.patchTransportSend(); - const connectTimeout = this.options.initTimeout ?? 120000; + const connectTimeout = this.options.initTimeout ?? DEFAULT_INIT_TIMEOUT; await runOutsideTracing(() => withTimeout( this.client.connect(this.transport!), @@ -716,6 +815,7 @@ export class MCPConnection extends EventEmitter { this.connectionState = 'connected'; this.emit('connectionChange', 'connected'); this.reconnectAttempts = 0; + this.resetFailedRounds(); } catch (error) { // Check if it's a rate limit error - stop immediately to avoid making it worse if (this.isRateLimitError(error)) { @@ -817,6 +917,7 @@ export class MCPConnection extends EventEmitter { this.connectionState = 'error'; this.emit('connectionChange', 'error'); + this.recordFailedRound(); throw error; } finally { this.connectPromise = null; @@ -866,7 +967,8 @@ export class MCPConnection extends EventEmitter { async connect(): Promise { try { - await this.disconnect(); + // preserve cycle tracking across reconnects so the circuit breaker can detect rapid cycling + await this.disconnect(false); await this.connectClient(); if (!(await this.isConnected())) { throw new Error('Connection not established'); @@ -906,7 +1008,7 @@ export class MCPConnection extends EventEmitter { isTransient, } = extractSSEErrorMessage(error); - if (errorCode === 404) { + if (errorCode === 400 || errorCode === 404 || errorCode === 405) { const hasSession = 'sessionId' in transport && (transport as { sessionId?: string }).sessionId != null && @@ -914,14 +1016,14 @@ export class MCPConnection extends EventEmitter { if (!hasSession && errorMessage.toLowerCase().includes('failed to open sse stream')) { logger.warn( - `${this.getLogPrefix()} SSE stream not available (404), no session. Ignoring.`, + `${this.getLogPrefix()} SSE stream not available (${errorCode}), no session. Ignoring.`, ); return; } if (hasSession) { logger.warn( - `${this.getLogPrefix()} 404 with active session — session lost, triggering reconnection.`, + `${this.getLogPrefix()} ${errorCode} with active session — session lost, triggering reconnection.`, ); } } @@ -992,7 +1094,7 @@ export class MCPConnection extends EventEmitter { await Promise.all(closing); } - public async disconnect(): Promise { + public async disconnect(resetCycleTracking = true): Promise { try { if (this.transport) { await this.client.close(); @@ -1006,6 +1108,9 @@ export class MCPConnection extends EventEmitter { this.emit('connectionChange', 'disconnected'); } finally { this.connectPromise = null; + if (!resetCycleTracking) { + this.recordCycle(); + } } } diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts b/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts index d3447eaeb8..d889da4f2f 100644 --- a/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts +++ b/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts @@ -253,17 +253,21 @@ describe('OAuthReconnectionManager', () => { expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1'); }); - it('should not reconnect servers with expired tokens', async () => { + it('should not reconnect servers with expired tokens and no refresh token', async () => { const userId = 'user-123'; const oauthServers = new Set(['server1']); (mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); - // server1: has expired token - tokenMethods.findToken.mockResolvedValue({ - userId, - identifier: 'mcp:server1', - expiresAt: new Date(Date.now() - 3600000), // 1 hour ago - } as unknown as MCPOAuthTokens); + tokenMethods.findToken.mockImplementation(async ({ identifier }) => { + if (identifier === 'mcp:server1') { + return { + userId, + identifier, + expiresAt: new Date(Date.now() - 3600000), + } as unknown as MCPOAuthTokens; + } + return null; + }); await reconnectionManager.reconnectServers(userId); @@ -272,6 +276,87 @@ describe('OAuthReconnectionManager', () => { expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled(); }); + it('should reconnect servers with expired access token but valid refresh token', async () => { + const userId = 'user-123'; + const oauthServers = new Set(['server1']); + (mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); + + tokenMethods.findToken.mockImplementation(async ({ identifier }) => { + if (identifier === 'mcp:server1') { + return { + userId, + identifier, + expiresAt: new Date(Date.now() - 3600000), + } as unknown as MCPOAuthTokens; + } + if (identifier === 'mcp:server1:refresh') { + return { + userId, + identifier, + } as unknown as MCPOAuthTokens; + } + return null; + }); + + const mockNewConnection = { + isConnected: jest.fn().mockResolvedValue(true), + disconnect: jest.fn(), + }; + mockMCPManager.getUserConnection.mockResolvedValue( + mockNewConnection as unknown as MCPConnection, + ); + (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue( + {} as unknown as MCPOptions, + ); + + await reconnectionManager.reconnectServers(userId); + + expect(reconnectionTracker.isActive(userId, 'server1')).toBe(true); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith( + expect.objectContaining({ serverName: 'server1' }), + ); + }); + + it('should reconnect when access token is TTL-deleted but refresh token exists', async () => { + const userId = 'user-123'; + const oauthServers = new Set(['server1']); + (mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); + + tokenMethods.findToken.mockImplementation(async ({ identifier }) => { + if (identifier === 'mcp:server1:refresh') { + return { + userId, + identifier, + } as unknown as MCPOAuthTokens; + } + return null; + }); + + const mockNewConnection = { + isConnected: jest.fn().mockResolvedValue(true), + disconnect: jest.fn(), + }; + mockMCPManager.getUserConnection.mockResolvedValue( + mockNewConnection as unknown as MCPConnection, + ); + (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue( + {} as unknown as MCPOptions, + ); + + await reconnectionManager.reconnectServers(userId); + + expect(reconnectionTracker.isActive(userId, 'server1')).toBe(true); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith( + expect.objectContaining({ serverName: 'server1' }), + ); + }); + it('should handle connection that returns but is not connected', async () => { const userId = 'user-123'; const oauthServers = new Set(['server1']); @@ -336,6 +421,69 @@ describe('OAuthReconnectionManager', () => { }); }); + describe('reconnectServer', () => { + let reconnectionTracker: OAuthReconnectionTracker; + beforeEach(async () => { + reconnectionTracker = new OAuthReconnectionTracker(); + reconnectionManager = await OAuthReconnectionManager.createInstance( + flowManager, + tokenMethods, + reconnectionTracker, + ); + }); + + it('should return true on successful reconnection', async () => { + const userId = 'user-123'; + const serverName = 'server1'; + + const mockConnection = { + isConnected: jest.fn().mockResolvedValue(true), + disconnect: jest.fn(), + }; + mockMCPManager.getUserConnection.mockResolvedValue( + mockConnection as unknown as MCPConnection, + ); + (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue( + {} as unknown as MCPOptions, + ); + + const result = await reconnectionManager.reconnectServer(userId, serverName); + expect(result).toBe(true); + }); + + it('should return false on failed reconnection', async () => { + const userId = 'user-123'; + const serverName = 'server1'; + + mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed')); + (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue( + {} as unknown as MCPOptions, + ); + + const result = await reconnectionManager.reconnectServer(userId, serverName); + expect(result).toBe(false); + }); + + it('should return false when MCPManager is not available', async () => { + const userId = 'user-123'; + const serverName = 'server1'; + + (OAuthReconnectionManager as unknown as { instance: null }).instance = null; + (MCPManager.getInstance as jest.Mock).mockImplementation(() => { + throw new Error('MCPManager has not been initialized.'); + }); + + const managerWithoutMCP = await OAuthReconnectionManager.createInstance( + flowManager, + tokenMethods, + reconnectionTracker, + ); + + const result = await managerWithoutMCP.reconnectServer(userId, serverName); + expect(result).toBe(false); + }); + }); + describe('reconnection staggering', () => { let reconnectionTracker: OAuthReconnectionTracker; diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts index f14c4abf15..7afe992772 100644 --- a/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts +++ b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts @@ -96,6 +96,24 @@ export class OAuthReconnectionManager { } } + /** + * Attempts to reconnect a single OAuth MCP server. + * @returns true if reconnection succeeded, false otherwise. + */ + public async reconnectServer(userId: string, serverName: string): Promise { + if (this.mcpManager == null) { + return false; + } + + this.reconnectionsTracker.setActive(userId, serverName); + try { + await this.tryReconnect(userId, serverName); + return !this.reconnectionsTracker.isFailed(userId, serverName); + } catch { + return false; + } + } + public clearReconnection(userId: string, serverName: string) { this.reconnectionsTracker.removeFailed(userId, serverName); this.reconnectionsTracker.removeActive(userId, serverName); @@ -174,23 +192,31 @@ export class OAuthReconnectionManager { } } - // if the server has no tokens for the user, don't attempt to reconnect + // if the server has a valid (non-expired) access token, allow reconnect const accessToken = await this.tokenMethods.findToken({ userId, type: 'mcp_oauth', identifier: `mcp:${serverName}`, }); - if (accessToken == null) { + + if (accessToken != null) { + const now = new Date(); + if (!accessToken.expiresAt || accessToken.expiresAt >= now) { + return true; + } + } + + // if the access token is expired or TTL-deleted, fall back to refresh token + const refreshToken = await this.tokenMethods.findToken({ + userId, + type: 'mcp_oauth', + identifier: `mcp:${serverName}:refresh`, + }); + + if (refreshToken == null) { return false; } - // if the token has expired, don't attempt to reconnect - const now = new Date(); - if (accessToken.expiresAt && accessToken.expiresAt < now) { - return false; - } - - // …otherwise, we're good to go with the reconnect attempt return true; } } diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts b/packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts index 68ac1d027e..206fe96ef1 100644 --- a/packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts +++ b/packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts @@ -397,6 +397,101 @@ describe('OAuthReconnectTracker', () => { }); }); + describe('cooldown-based retry', () => { + beforeEach(() => { + jest.useFakeTimers(); + }); + + afterEach(() => { + jest.useRealTimers(); + }); + + it('should return true from isFailed within first cooldown period (5 min)', () => { + const now = Date.now(); + jest.setSystemTime(now); + + tracker.setFailed(userId, serverName); + expect(tracker.isFailed(userId, serverName)).toBe(true); + + jest.advanceTimersByTime(4 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(true); + }); + + it('should return false from isFailed after first cooldown elapses (5 min)', () => { + const now = Date.now(); + jest.setSystemTime(now); + + tracker.setFailed(userId, serverName); + expect(tracker.isFailed(userId, serverName)).toBe(true); + + jest.advanceTimersByTime(5 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(false); + }); + + it('should use progressive cooldown schedule (5m, 10m, 20m, 30m)', () => { + const now = Date.now(); + jest.setSystemTime(now); + + // First failure: 5 min cooldown + tracker.setFailed(userId, serverName); + jest.advanceTimersByTime(5 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(false); + + // Second failure: 10 min cooldown + tracker.setFailed(userId, serverName); + jest.advanceTimersByTime(9 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(true); + jest.advanceTimersByTime(1 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(false); + + // Third failure: 20 min cooldown + tracker.setFailed(userId, serverName); + jest.advanceTimersByTime(19 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(true); + jest.advanceTimersByTime(1 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(false); + + // Fourth failure: 30 min cooldown + tracker.setFailed(userId, serverName); + jest.advanceTimersByTime(29 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(true); + jest.advanceTimersByTime(1 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(false); + }); + + it('should cap cooldown at 30 min for attempts beyond 4', () => { + const now = Date.now(); + jest.setSystemTime(now); + + for (let i = 0; i < 5; i++) { + tracker.setFailed(userId, serverName); + jest.advanceTimersByTime(30 * 60 * 1000); + } + + tracker.setFailed(userId, serverName); + jest.advanceTimersByTime(29 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(true); + jest.advanceTimersByTime(1 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(false); + }); + + it('should fully reset metadata on removeFailed', () => { + const now = Date.now(); + jest.setSystemTime(now); + + tracker.setFailed(userId, serverName); + tracker.setFailed(userId, serverName); + tracker.setFailed(userId, serverName); + + tracker.removeFailed(userId, serverName); + expect(tracker.isFailed(userId, serverName)).toBe(false); + + tracker.setFailed(userId, serverName); + jest.advanceTimersByTime(5 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(false); + }); + }); + describe('timestamp tracking edge cases', () => { beforeEach(() => { jest.useFakeTimers(); diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts b/packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts index 9f6ef4abd3..504ea7d43a 100644 --- a/packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts +++ b/packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts @@ -1,6 +1,12 @@ +interface FailedMeta { + attempts: number; + lastFailedAt: number; +} + +const COOLDOWN_SCHEDULE_MS = [5 * 60 * 1000, 10 * 60 * 1000, 20 * 60 * 1000, 30 * 60 * 1000]; + export class OAuthReconnectionTracker { - /** Map of userId -> Set of serverNames that have failed reconnection */ - private failed: Map> = new Map(); + private failedMeta: Map> = new Map(); /** Map of userId -> Set of serverNames that are actively reconnecting */ private active: Map> = new Map(); /** Map of userId:serverName -> timestamp when reconnection started */ @@ -9,7 +15,17 @@ export class OAuthReconnectionTracker { private readonly RECONNECTION_TIMEOUT_MS = 3 * 60 * 1000; // 3 minutes public isFailed(userId: string, serverName: string): boolean { - return this.failed.get(userId)?.has(serverName) ?? false; + const meta = this.failedMeta.get(userId)?.get(serverName); + if (!meta) { + return false; + } + const idx = Math.min(meta.attempts - 1, COOLDOWN_SCHEDULE_MS.length - 1); + const cooldown = COOLDOWN_SCHEDULE_MS[idx]; + const elapsed = Date.now() - meta.lastFailedAt; + if (elapsed >= cooldown) { + return false; + } + return true; } /** Check if server is in the active set (original simple check) */ @@ -48,11 +64,15 @@ export class OAuthReconnectionTracker { } public setFailed(userId: string, serverName: string): void { - if (!this.failed.has(userId)) { - this.failed.set(userId, new Set()); + if (!this.failedMeta.has(userId)) { + this.failedMeta.set(userId, new Map()); } - - this.failed.get(userId)?.add(serverName); + const userMap = this.failedMeta.get(userId)!; + const existing = userMap.get(serverName); + userMap.set(serverName, { + attempts: (existing?.attempts ?? 0) + 1, + lastFailedAt: Date.now(), + }); } public setActive(userId: string, serverName: string): void { @@ -68,10 +88,10 @@ export class OAuthReconnectionTracker { } public removeFailed(userId: string, serverName: string): void { - const userServers = this.failed.get(userId); - userServers?.delete(serverName); - if (userServers?.size === 0) { - this.failed.delete(userId); + const userMap = this.failedMeta.get(userId); + userMap?.delete(serverName); + if (userMap?.size === 0) { + this.failedMeta.delete(userId); } } @@ -94,7 +114,7 @@ export class OAuthReconnectionTracker { activeTimestamps: number; } { return { - usersWithFailedServers: this.failed.size, + usersWithFailedServers: this.failedMeta.size, usersWithActiveReconnections: this.active.size, activeTimestamps: this.activeTimestamps.size, };