⛈️ fix: MCP Reconnection Storm Prevention with Circuit Breaker, Backoff, and Tool Stubs (#12162)

* fix: MCP reconnection stability - circuit breaker, throttling, and cooldown retry

* Comment and logging cleanup

* fix broken tests
This commit is contained in:
matt burnett 2026-03-10 11:21:36 -07:00 committed by GitHub
parent cfbe812d63
commit ad5c51f62b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 736 additions and 38 deletions

View file

@ -65,6 +65,9 @@ export abstract class UserConnectionManager {
const userServerMap = this.userConnections.get(userId);
let connection = forceNew ? undefined : userServerMap?.get(serverName);
if (forceNew) {
MCPConnection.clearCooldown(serverName);
}
const now = Date.now();
// Check if user is idle

View file

@ -559,3 +559,242 @@ describe('extractSSEErrorMessage', () => {
});
});
});
/**
* Tests for circuit breaker logic.
*
* Uses standalone implementations that mirror the static/private circuit breaker
* methods in MCPConnection. Same approach as the error detection tests above.
*/
describe('MCPConnection Circuit Breaker', () => {
/** 5 rapid connect/disconnect cycles within 60s trigger a 30s cooldown */
const CB_MAX_CYCLES = 5;
const CB_CYCLE_WINDOW_MS = 60_000;
const CB_CYCLE_COOLDOWN_MS = 30_000;
/** 3 failed rounds within 120s trigger exponential backoff (30s-300s) */
const CB_MAX_FAILED_ROUNDS = 3;
const CB_FAILED_WINDOW_MS = 120_000;
const CB_BASE_BACKOFF_MS = 30_000;
const CB_MAX_BACKOFF_MS = 300_000;
/** Mutable breaker state; mirrors the per-server shape tracked in MCPConnection. */
interface CircuitBreakerState {
  cycleCount: number;
  cycleWindowStart: number;
  cooldownUntil: number;
  failedRounds: number;
  failedWindowStart: number;
  failedBackoffUntil: number;
}
/** Fresh breaker state: no cycles, no failures, both windows anchored at "now". */
function createCB(): CircuitBreakerState {
  const startedAt = Date.now();
  return {
    cycleCount: 0,
    cycleWindowStart: startedAt,
    cooldownUntil: 0,
    failedRounds: 0,
    failedWindowStart: startedAt,
    failedBackoffUntil: 0,
  };
}
/** The circuit is open while either the cycle cooldown or the failure backoff lies in the future. */
function isCircuitOpen(cb: CircuitBreakerState): boolean {
  const current = Date.now();
  const coolingDown = current < cb.cooldownUntil;
  const backingOff = current < cb.failedBackoffUntil;
  return coolingDown || backingOff;
}
/** Counts one connect/disconnect cycle; opens a 30s cooldown once 5 land inside a 60s window. */
function recordCycle(cb: CircuitBreakerState): void {
  const current = Date.now();
  if (current - cb.cycleWindowStart > CB_CYCLE_WINDOW_MS) {
    // The observation window lapsed — start counting from scratch.
    cb.cycleCount = 0;
    cb.cycleWindowStart = current;
  }
  cb.cycleCount += 1;
  if (cb.cycleCount < CB_MAX_CYCLES) {
    return;
  }
  // Threshold hit: open the cooldown and reset the window for the next round.
  cb.cooldownUntil = current + CB_CYCLE_COOLDOWN_MS;
  cb.cycleCount = 0;
  cb.cycleWindowStart = current;
}
/** Counts one failed connection round; from the 3rd failure on, backs off exponentially (capped). */
function recordFailedRound(cb: CircuitBreakerState): void {
  const current = Date.now();
  if (current - cb.failedWindowStart > CB_FAILED_WINDOW_MS) {
    cb.failedRounds = 0;
    cb.failedWindowStart = current;
  }
  cb.failedRounds += 1;
  if (cb.failedRounds < CB_MAX_FAILED_ROUNDS) {
    return;
  }
  // Doubles per extra failure: 30s, 60s, 120s, ..., capped at 300s.
  const exponent = cb.failedRounds - CB_MAX_FAILED_ROUNDS;
  const backoff = Math.min(CB_BASE_BACKOFF_MS * 2 ** exponent, CB_MAX_BACKOFF_MS);
  cb.failedBackoffUntil = current + backoff;
}
/** Clears failure tracking (as after a successful connect); cycle tracking is untouched. */
function resetFailedRounds(cb: CircuitBreakerState): void {
  cb.failedBackoffUntil = 0;
  cb.failedRounds = 0;
  cb.failedWindowStart = Date.now();
}
// All circuit-breaker timing below runs on Jest fake timers so window and
// cooldown arithmetic is deterministic.
beforeEach(() => {
  jest.useFakeTimers();
});
afterEach(() => {
  jest.useRealTimers();
});
describe('cycle tracking', () => {
  it('should not trigger cooldown for fewer than 5 cycles', () => {
    const now = Date.now();
    jest.setSystemTime(now);
    const cb = createCB();
    // One short of CB_MAX_CYCLES: breaker must stay closed.
    for (let i = 0; i < CB_MAX_CYCLES - 1; i++) {
      recordCycle(cb);
    }
    expect(isCircuitOpen(cb)).toBe(false);
  });
  it('should trigger 30s cooldown after 5 cycles within 60s', () => {
    const now = Date.now();
    jest.setSystemTime(now);
    const cb = createCB();
    for (let i = 0; i < CB_MAX_CYCLES; i++) {
      recordCycle(cb);
    }
    expect(isCircuitOpen(cb)).toBe(true);
    // Still open 1s before the 30s cooldown deadline...
    jest.advanceTimersByTime(29_000);
    expect(isCircuitOpen(cb)).toBe(true);
    // ...and closed exactly when it elapses.
    jest.advanceTimersByTime(1_000);
    expect(isCircuitOpen(cb)).toBe(false);
  });
  it('should reset cycle count when window expires', () => {
    const now = Date.now();
    jest.setSystemTime(now);
    const cb = createCB();
    for (let i = 0; i < CB_MAX_CYCLES - 1; i++) {
      recordCycle(cb);
    }
    // Let the 60s window lapse; the next cycle starts a fresh count.
    jest.advanceTimersByTime(CB_CYCLE_WINDOW_MS + 1);
    recordCycle(cb);
    expect(isCircuitOpen(cb)).toBe(false);
  });
});
describe('failed round tracking', () => {
  it('should not trigger backoff for fewer than 3 failures', () => {
    const now = Date.now();
    jest.setSystemTime(now);
    const cb = createCB();
    // One short of CB_MAX_FAILED_ROUNDS: breaker must stay closed.
    for (let i = 0; i < CB_MAX_FAILED_ROUNDS - 1; i++) {
      recordFailedRound(cb);
    }
    expect(isCircuitOpen(cb)).toBe(false);
  });
  it('should trigger 30s backoff after 3 failures within 120s', () => {
    const now = Date.now();
    jest.setSystemTime(now);
    const cb = createCB();
    for (let i = 0; i < CB_MAX_FAILED_ROUNDS; i++) {
      recordFailedRound(cb);
    }
    expect(isCircuitOpen(cb)).toBe(true);
    // Backoff lapses exactly at the base 30s mark.
    jest.advanceTimersByTime(CB_BASE_BACKOFF_MS);
    expect(isCircuitOpen(cb)).toBe(false);
  });
  it('should use exponential backoff based on failure count', () => {
    jest.setSystemTime(Date.now());
    const cb = createCB();
    for (let i = 0; i < 3; i++) {
      recordFailedRound(cb);
    }
    // Backoff doubles per extra failure: 30s, 60s, 120s, 240s...
    expect(cb.failedBackoffUntil - Date.now()).toBe(30_000);
    recordFailedRound(cb);
    expect(cb.failedBackoffUntil - Date.now()).toBe(60_000);
    recordFailedRound(cb);
    expect(cb.failedBackoffUntil - Date.now()).toBe(120_000);
    recordFailedRound(cb);
    expect(cb.failedBackoffUntil - Date.now()).toBe(240_000);
    // capped at 300s
    recordFailedRound(cb);
    expect(cb.failedBackoffUntil - Date.now()).toBe(300_000);
  });
  it('should reset failed window when window expires', () => {
    const now = Date.now();
    jest.setSystemTime(now);
    const cb = createCB();
    recordFailedRound(cb);
    recordFailedRound(cb);
    // After the 120s window lapses the earlier failures no longer count.
    jest.advanceTimersByTime(CB_FAILED_WINDOW_MS + 1);
    recordFailedRound(cb);
    expect(isCircuitOpen(cb)).toBe(false);
  });
});
describe('resetFailedRounds', () => {
  it('should clear failed round state on successful connection', () => {
    const now = Date.now();
    jest.setSystemTime(now);
    const cb = createCB();
    for (let i = 0; i < CB_MAX_FAILED_ROUNDS; i++) {
      recordFailedRound(cb);
    }
    expect(isCircuitOpen(cb)).toBe(true);
    // A successful connection wipes failure tracking entirely.
    resetFailedRounds(cb);
    expect(isCircuitOpen(cb)).toBe(false);
    expect(cb.failedRounds).toBe(0);
    expect(cb.failedBackoffUntil).toBe(0);
  });
});
describe('clearCooldown (registry deletion)', () => {
  it('should allow connections after clearing circuit breaker state', () => {
    const now = Date.now();
    jest.setSystemTime(now);
    // Mirrors MCPConnection.clearCooldown(): deleting the registry entry
    // discards all breaker state, so a fresh entry starts closed.
    const registry = new Map<string, CircuitBreakerState>();
    const serverName = 'test-server';
    const cb = createCB();
    registry.set(serverName, cb);
    for (let i = 0; i < CB_MAX_CYCLES; i++) {
      recordCycle(cb);
    }
    expect(isCircuitOpen(cb)).toBe(true);
    registry.delete(serverName);
    const newCb = createCB();
    expect(isCircuitOpen(newCb)).toBe(false);
  });
});
});

View file

@ -207,6 +207,7 @@ describe('MCPConnection Agent lifecycle streamable-http', () => {
});
afterEach(async () => {
MCPConnection.clearCooldown('test');
await safeDisconnect(conn);
conn = null;
jest.restoreAllMocks();
@ -366,6 +367,7 @@ describe('MCPConnection Agent lifecycle SSE', () => {
});
afterEach(async () => {
MCPConnection.clearCooldown('test-sse');
await safeDisconnect(conn);
conn = null;
jest.restoreAllMocks();
@ -453,6 +455,7 @@ describe('Regression: old per-request Agent pattern leaks agents', () => {
});
afterEach(async () => {
MCPConnection.clearCooldown('test-regression');
await safeDisconnect(conn);
conn = null;
jest.restoreAllMocks();
@ -675,6 +678,7 @@ describe('MCPConnection SSE GET stream recovery integration', () => {
});
afterEach(async () => {
MCPConnection.clearCooldown('test-sse-recovery');
await safeDisconnect(conn);
conn = null;
jest.restoreAllMocks();

View file

@ -71,6 +71,25 @@ const FIVE_MINUTES = 5 * 60 * 1000;
const DEFAULT_TIMEOUT = 60000;
/** SSE connections through proxies may need longer initial handshake time */
const SSE_CONNECT_TIMEOUT = 120000;
const DEFAULT_INIT_TIMEOUT = 30000;
interface CircuitBreakerState {
cycleCount: number;
cycleWindowStart: number;
cooldownUntil: number;
failedRounds: number;
failedWindowStart: number;
failedBackoffUntil: number;
}
const CB_MAX_CYCLES = 5;
const CB_CYCLE_WINDOW_MS = 60_000;
const CB_CYCLE_COOLDOWN_MS = 30_000;
const CB_MAX_FAILED_ROUNDS = 3;
const CB_FAILED_WINDOW_MS = 120_000;
const CB_BASE_BACKOFF_MS = 30_000;
const CB_MAX_BACKOFF_MS = 300_000;
/** Default body timeout for Streamable HTTP GET SSE streams that idle between server pushes */
const DEFAULT_SSE_READ_TIMEOUT = FIVE_MINUTES;
@ -274,6 +293,80 @@ export class MCPConnection extends EventEmitter {
*/
public readonly createdAt: number;
private static circuitBreakers: Map<string, CircuitBreakerState> = new Map();
public static clearCooldown(serverName: string): void {
MCPConnection.circuitBreakers.delete(serverName);
logger.debug(`[MCP][${serverName}] Circuit breaker state cleared`);
}
/** Returns this server's circuit-breaker state, lazily creating a fresh entry in the registry. */
private getCircuitBreaker(): CircuitBreakerState {
  const existing = MCPConnection.circuitBreakers.get(this.serverName);
  if (existing) {
    return existing;
  }
  const startedAt = Date.now();
  const fresh: CircuitBreakerState = {
    cycleCount: 0,
    cycleWindowStart: startedAt,
    cooldownUntil: 0,
    failedRounds: 0,
    failedWindowStart: startedAt,
    failedBackoffUntil: 0,
  };
  MCPConnection.circuitBreakers.set(this.serverName, fresh);
  return fresh;
}
/** True while either the cycle cooldown or the failure backoff deadline lies in the future. */
private isCircuitOpen(): boolean {
  const { cooldownUntil, failedBackoffUntil } = this.getCircuitBreaker();
  const current = Date.now();
  return current < cooldownUntil || current < failedBackoffUntil;
}
/**
 * Records one connect/disconnect cycle. Once CB_MAX_CYCLES land inside the
 * CB_CYCLE_WINDOW_MS window, opens a CB_CYCLE_COOLDOWN_MS cooldown and resets
 * the counter for the next round.
 */
private recordCycle(): void {
  const state = this.getCircuitBreaker();
  const current = Date.now();
  if (current - state.cycleWindowStart > CB_CYCLE_WINDOW_MS) {
    // Window lapsed — start counting afresh.
    state.cycleCount = 0;
    state.cycleWindowStart = current;
  }
  state.cycleCount += 1;
  if (state.cycleCount < CB_MAX_CYCLES) {
    return;
  }
  state.cooldownUntil = current + CB_CYCLE_COOLDOWN_MS;
  state.cycleCount = 0;
  state.cycleWindowStart = current;
  logger.warn(
    `${this.getLogPrefix()} Circuit breaker: too many cycles, cooling down for ${CB_CYCLE_COOLDOWN_MS}ms`,
  );
}
/**
 * Records one failed connection round. From CB_MAX_FAILED_ROUNDS onward the
 * backoff doubles per extra failure (30s, 60s, ...) up to CB_MAX_BACKOFF_MS.
 */
private recordFailedRound(): void {
  const state = this.getCircuitBreaker();
  const current = Date.now();
  if (current - state.failedWindowStart > CB_FAILED_WINDOW_MS) {
    // Window lapsed — earlier failures no longer count.
    state.failedRounds = 0;
    state.failedWindowStart = current;
  }
  state.failedRounds += 1;
  if (state.failedRounds < CB_MAX_FAILED_ROUNDS) {
    return;
  }
  const exponent = state.failedRounds - CB_MAX_FAILED_ROUNDS;
  const backoff = Math.min(CB_BASE_BACKOFF_MS * 2 ** exponent, CB_MAX_BACKOFF_MS);
  state.failedBackoffUntil = current + backoff;
  logger.warn(
    `${this.getLogPrefix()} Circuit breaker: too many failures, backing off for ${backoff}ms`,
  );
}
/** Clears failure tracking after a successful connect; cycle tracking is untouched. */
private resetFailedRounds(): void {
  const state = this.getCircuitBreaker();
  state.failedBackoffUntil = 0;
  state.failedRounds = 0;
  state.failedWindowStart = Date.now();
}
setRequestHeaders(headers: Record<string, string> | null): void {
if (!headers) {
return;
@ -686,6 +779,12 @@ export class MCPConnection extends EventEmitter {
return;
}
if (this.isCircuitOpen()) {
this.connectionState = 'error';
this.emit('connectionChange', 'error');
throw new Error(`${this.getLogPrefix()} Circuit breaker is open, connection attempt blocked`);
}
this.emit('connectionChange', 'connecting');
this.connectPromise = (async () => {
@ -703,7 +802,7 @@ export class MCPConnection extends EventEmitter {
this.transport = await runOutsideTracing(() => this.constructTransport(this.options));
this.patchTransportSend();
const connectTimeout = this.options.initTimeout ?? 120000;
const connectTimeout = this.options.initTimeout ?? DEFAULT_INIT_TIMEOUT;
await runOutsideTracing(() =>
withTimeout(
this.client.connect(this.transport!),
@ -716,6 +815,7 @@ export class MCPConnection extends EventEmitter {
this.connectionState = 'connected';
this.emit('connectionChange', 'connected');
this.reconnectAttempts = 0;
this.resetFailedRounds();
} catch (error) {
// Check if it's a rate limit error - stop immediately to avoid making it worse
if (this.isRateLimitError(error)) {
@ -817,6 +917,7 @@ export class MCPConnection extends EventEmitter {
this.connectionState = 'error';
this.emit('connectionChange', 'error');
this.recordFailedRound();
throw error;
} finally {
this.connectPromise = null;
@ -866,7 +967,8 @@ export class MCPConnection extends EventEmitter {
async connect(): Promise<void> {
try {
await this.disconnect();
// preserve cycle tracking across reconnects so the circuit breaker can detect rapid cycling
await this.disconnect(false);
await this.connectClient();
if (!(await this.isConnected())) {
throw new Error('Connection not established');
@ -906,7 +1008,7 @@ export class MCPConnection extends EventEmitter {
isTransient,
} = extractSSEErrorMessage(error);
if (errorCode === 404) {
if (errorCode === 400 || errorCode === 404 || errorCode === 405) {
const hasSession =
'sessionId' in transport &&
(transport as { sessionId?: string }).sessionId != null &&
@ -914,14 +1016,14 @@ export class MCPConnection extends EventEmitter {
if (!hasSession && errorMessage.toLowerCase().includes('failed to open sse stream')) {
logger.warn(
`${this.getLogPrefix()} SSE stream not available (404), no session. Ignoring.`,
`${this.getLogPrefix()} SSE stream not available (${errorCode}), no session. Ignoring.`,
);
return;
}
if (hasSession) {
logger.warn(
`${this.getLogPrefix()} 404 with active session — session lost, triggering reconnection.`,
`${this.getLogPrefix()} ${errorCode} with active session — session lost, triggering reconnection.`,
);
}
}
@ -992,7 +1094,7 @@ export class MCPConnection extends EventEmitter {
await Promise.all(closing);
}
public async disconnect(): Promise<void> {
public async disconnect(resetCycleTracking = true): Promise<void> {
try {
if (this.transport) {
await this.client.close();
@ -1006,6 +1108,9 @@ export class MCPConnection extends EventEmitter {
this.emit('connectionChange', 'disconnected');
} finally {
this.connectPromise = null;
if (!resetCycleTracking) {
this.recordCycle();
}
}
}

View file

@ -253,17 +253,21 @@ describe('OAuthReconnectionManager', () => {
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
});
it('should not reconnect servers with expired tokens', async () => {
it('should not reconnect servers with expired tokens and no refresh token', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
// server1: has expired token
tokenMethods.findToken.mockResolvedValue({
userId,
identifier: 'mcp:server1',
expiresAt: new Date(Date.now() - 3600000), // 1 hour ago
} as unknown as MCPOAuthTokens);
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
if (identifier === 'mcp:server1') {
return {
userId,
identifier,
expiresAt: new Date(Date.now() - 3600000),
} as unknown as MCPOAuthTokens;
}
return null;
});
await reconnectionManager.reconnectServers(userId);
@ -272,6 +276,87 @@ describe('OAuthReconnectionManager', () => {
expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled();
});
it('should reconnect servers with expired access token but valid refresh token', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
if (identifier === 'mcp:server1') {
return {
userId,
identifier,
expiresAt: new Date(Date.now() - 3600000),
} as unknown as MCPOAuthTokens;
}
if (identifier === 'mcp:server1:refresh') {
return {
userId,
identifier,
} as unknown as MCPOAuthTokens;
}
return null;
});
const mockNewConnection = {
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn(),
};
mockMCPManager.getUserConnection.mockResolvedValue(
mockNewConnection as unknown as MCPConnection,
);
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
{} as unknown as MCPOptions,
);
await reconnectionManager.reconnectServers(userId);
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(true);
await new Promise((resolve) => setTimeout(resolve, 100));
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
expect.objectContaining({ serverName: 'server1' }),
);
});
it('should reconnect when access token is TTL-deleted but refresh token exists', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
if (identifier === 'mcp:server1:refresh') {
return {
userId,
identifier,
} as unknown as MCPOAuthTokens;
}
return null;
});
const mockNewConnection = {
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn(),
};
mockMCPManager.getUserConnection.mockResolvedValue(
mockNewConnection as unknown as MCPConnection,
);
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
{} as unknown as MCPOptions,
);
await reconnectionManager.reconnectServers(userId);
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(true);
await new Promise((resolve) => setTimeout(resolve, 100));
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
expect.objectContaining({ serverName: 'server1' }),
);
});
it('should handle connection that returns but is not connected', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
@ -336,6 +421,69 @@ describe('OAuthReconnectionManager', () => {
});
});
describe('reconnectServer', () => {
  let reconnectionTracker: OAuthReconnectionTracker;
  beforeEach(async () => {
    // Fresh tracker + manager per test so failure/active state never leaks across cases.
    reconnectionTracker = new OAuthReconnectionTracker();
    reconnectionManager = await OAuthReconnectionManager.createInstance(
      flowManager,
      tokenMethods,
      reconnectionTracker,
    );
  });
  it('should return true on successful reconnection', async () => {
    const userId = 'user-123';
    const serverName = 'server1';
    const mockConnection = {
      isConnected: jest.fn().mockResolvedValue(true),
      disconnect: jest.fn(),
    };
    mockMCPManager.getUserConnection.mockResolvedValue(
      mockConnection as unknown as MCPConnection,
    );
    (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
      {} as unknown as MCPOptions,
    );
    const result = await reconnectionManager.reconnectServer(userId, serverName);
    expect(result).toBe(true);
  });
  it('should return false on failed reconnection', async () => {
    const userId = 'user-123';
    const serverName = 'server1';
    // getUserConnection rejecting marks the attempt failed inside tryReconnect.
    mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed'));
    (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
      {} as unknown as MCPOptions,
    );
    const result = await reconnectionManager.reconnectServer(userId, serverName);
    expect(result).toBe(false);
  });
  it('should return false when MCPManager is not available', async () => {
    const userId = 'user-123';
    const serverName = 'server1';
    // Force re-creation so a throwing getInstance() is observed by the new manager.
    (OAuthReconnectionManager as unknown as { instance: null }).instance = null;
    (MCPManager.getInstance as jest.Mock).mockImplementation(() => {
      throw new Error('MCPManager has not been initialized.');
    });
    const managerWithoutMCP = await OAuthReconnectionManager.createInstance(
      flowManager,
      tokenMethods,
      reconnectionTracker,
    );
    const result = await managerWithoutMCP.reconnectServer(userId, serverName);
    expect(result).toBe(false);
  });
});
describe('reconnection staggering', () => {
let reconnectionTracker: OAuthReconnectionTracker;

View file

@ -96,6 +96,24 @@ export class OAuthReconnectionManager {
}
}
/**
* Attempts to reconnect a single OAuth MCP server.
* @returns true if reconnection succeeded, false otherwise.
*/
/**
 * Attempts to reconnect a single OAuth MCP server.
 * @returns true if reconnection succeeded, false otherwise.
 */
public async reconnectServer(userId: string, serverName: string): Promise<boolean> {
  if (this.mcpManager == null) {
    // Without an MCPManager there is nothing to reconnect through.
    return false;
  }
  // Mark the attempt as in-flight before starting.
  // NOTE(review): assumes tryReconnect manages the failed/active flags on completion — verify.
  this.reconnectionsTracker.setActive(userId, serverName);
  let succeeded: boolean;
  try {
    await this.tryReconnect(userId, serverName);
    // tryReconnect reports failure via the tracker rather than by throwing.
    succeeded = !this.reconnectionsTracker.isFailed(userId, serverName);
  } catch {
    succeeded = false;
  }
  return succeeded;
}
public clearReconnection(userId: string, serverName: string) {
this.reconnectionsTracker.removeFailed(userId, serverName);
this.reconnectionsTracker.removeActive(userId, serverName);
@ -174,23 +192,31 @@ export class OAuthReconnectionManager {
}
}
// if the server has no tokens for the user, don't attempt to reconnect
// if the server has a valid (non-expired) access token, allow reconnect
const accessToken = await this.tokenMethods.findToken({
userId,
type: 'mcp_oauth',
identifier: `mcp:${serverName}`,
});
if (accessToken == null) {
if (accessToken != null) {
const now = new Date();
if (!accessToken.expiresAt || accessToken.expiresAt >= now) {
return true;
}
}
// if the access token is expired or TTL-deleted, fall back to refresh token
const refreshToken = await this.tokenMethods.findToken({
userId,
type: 'mcp_oauth',
identifier: `mcp:${serverName}:refresh`,
});
if (refreshToken == null) {
return false;
}
// if the token has expired, don't attempt to reconnect
const now = new Date();
if (accessToken.expiresAt && accessToken.expiresAt < now) {
return false;
}
// …otherwise, we're good to go with the reconnect attempt
return true;
}
}

View file

@ -397,6 +397,101 @@ describe('OAuthReconnectTracker', () => {
});
});
describe('cooldown-based retry', () => {
  // Fake timers make the progressive cooldown schedule (5m/10m/20m/30m) deterministic.
  beforeEach(() => {
    jest.useFakeTimers();
  });
  afterEach(() => {
    jest.useRealTimers();
  });
  it('should return true from isFailed within first cooldown period (5 min)', () => {
    const now = Date.now();
    jest.setSystemTime(now);
    tracker.setFailed(userId, serverName);
    expect(tracker.isFailed(userId, serverName)).toBe(true);
    jest.advanceTimersByTime(4 * 60 * 1000);
    expect(tracker.isFailed(userId, serverName)).toBe(true);
  });
  it('should return false from isFailed after first cooldown elapses (5 min)', () => {
    const now = Date.now();
    jest.setSystemTime(now);
    tracker.setFailed(userId, serverName);
    expect(tracker.isFailed(userId, serverName)).toBe(true);
    jest.advanceTimersByTime(5 * 60 * 1000);
    expect(tracker.isFailed(userId, serverName)).toBe(false);
  });
  it('should use progressive cooldown schedule (5m, 10m, 20m, 30m)', () => {
    const now = Date.now();
    jest.setSystemTime(now);
    // First failure: 5 min cooldown
    tracker.setFailed(userId, serverName);
    jest.advanceTimersByTime(5 * 60 * 1000);
    expect(tracker.isFailed(userId, serverName)).toBe(false);
    // Second failure: 10 min cooldown
    tracker.setFailed(userId, serverName);
    jest.advanceTimersByTime(9 * 60 * 1000);
    expect(tracker.isFailed(userId, serverName)).toBe(true);
    jest.advanceTimersByTime(1 * 60 * 1000);
    expect(tracker.isFailed(userId, serverName)).toBe(false);
    // Third failure: 20 min cooldown
    tracker.setFailed(userId, serverName);
    jest.advanceTimersByTime(19 * 60 * 1000);
    expect(tracker.isFailed(userId, serverName)).toBe(true);
    jest.advanceTimersByTime(1 * 60 * 1000);
    expect(tracker.isFailed(userId, serverName)).toBe(false);
    // Fourth failure: 30 min cooldown
    tracker.setFailed(userId, serverName);
    jest.advanceTimersByTime(29 * 60 * 1000);
    expect(tracker.isFailed(userId, serverName)).toBe(true);
    jest.advanceTimersByTime(1 * 60 * 1000);
    expect(tracker.isFailed(userId, serverName)).toBe(false);
  });
  it('should cap cooldown at 30 min for attempts beyond 4', () => {
    const now = Date.now();
    jest.setSystemTime(now);
    // Rack up five expired failures; the schedule index must stay clamped at 30 min.
    for (let i = 0; i < 5; i++) {
      tracker.setFailed(userId, serverName);
      jest.advanceTimersByTime(30 * 60 * 1000);
    }
    tracker.setFailed(userId, serverName);
    jest.advanceTimersByTime(29 * 60 * 1000);
    expect(tracker.isFailed(userId, serverName)).toBe(true);
    jest.advanceTimersByTime(1 * 60 * 1000);
    expect(tracker.isFailed(userId, serverName)).toBe(false);
  });
  it('should fully reset metadata on removeFailed', () => {
    const now = Date.now();
    jest.setSystemTime(now);
    tracker.setFailed(userId, serverName);
    tracker.setFailed(userId, serverName);
    tracker.setFailed(userId, serverName);
    // removeFailed discards the attempt counter, not just the failed flag...
    tracker.removeFailed(userId, serverName);
    expect(tracker.isFailed(userId, serverName)).toBe(false);
    // ...so the next failure starts back at the 5 min cooldown.
    tracker.setFailed(userId, serverName);
    jest.advanceTimersByTime(5 * 60 * 1000);
    expect(tracker.isFailed(userId, serverName)).toBe(false);
  });
});
describe('timestamp tracking edge cases', () => {
beforeEach(() => {
jest.useFakeTimers();

View file

@ -1,6 +1,12 @@
/** Per-server failure bookkeeping used to drive progressive reconnection cooldowns. */
interface FailedMeta {
  // Count of recorded failures; reset only by removeFailed, not by cooldown expiry.
  attempts: number;
  // Epoch ms of the most recent failure; the cooldown is measured from here.
  lastFailedAt: number;
}
/** Progressive cooldown schedule: 5m, 10m, 20m, then capped at 30m for later attempts. */
const COOLDOWN_SCHEDULE_MS = [5 * 60 * 1000, 10 * 60 * 1000, 20 * 60 * 1000, 30 * 60 * 1000];
export class OAuthReconnectionTracker {
/** Map of userId -> Set of serverNames that have failed reconnection */
private failed: Map<string, Set<string>> = new Map();
private failedMeta: Map<string, Map<string, FailedMeta>> = new Map();
/** Map of userId -> Set of serverNames that are actively reconnecting */
private active: Map<string, Set<string>> = new Map();
/** Map of userId:serverName -> timestamp when reconnection started */
@ -9,7 +15,17 @@ export class OAuthReconnectionTracker {
private readonly RECONNECTION_TIMEOUT_MS = 3 * 60 * 1000; // 3 minutes
/**
 * Returns true while `serverName` is still inside its failure cooldown window
 * for `userId`. The cooldown length comes from COOLDOWN_SCHEDULE_MS, indexed by
 * the attempt count (clamped to the last, 30-minute entry). Once the cooldown
 * elapses this returns false so a retry may proceed; the attempt counter is
 * intentionally preserved — only removeFailed() resets it.
 *
 * (As rendered, the span still contained the superseded `this.failed` Set
 * lookup from the old implementation; this emits the metadata-based version only.)
 */
public isFailed(userId: string, serverName: string): boolean {
  const meta = this.failedMeta.get(userId)?.get(serverName);
  if (!meta) {
    return false;
  }
  // attempts >= 1 whenever meta exists, so the index is always in range.
  const idx = Math.min(meta.attempts - 1, COOLDOWN_SCHEDULE_MS.length - 1);
  const cooldown = COOLDOWN_SCHEDULE_MS[idx];
  return Date.now() - meta.lastFailedAt < cooldown;
}
/** Check if server is in the active set (original simple check) */
@ -48,11 +64,15 @@ export class OAuthReconnectionTracker {
}
public setFailed(userId: string, serverName: string): void {
if (!this.failed.has(userId)) {
this.failed.set(userId, new Set());
if (!this.failedMeta.has(userId)) {
this.failedMeta.set(userId, new Map());
}
this.failed.get(userId)?.add(serverName);
const userMap = this.failedMeta.get(userId)!;
const existing = userMap.get(serverName);
userMap.set(serverName, {
attempts: (existing?.attempts ?? 0) + 1,
lastFailedAt: Date.now(),
});
}
public setActive(userId: string, serverName: string): void {
@ -68,10 +88,10 @@ export class OAuthReconnectionTracker {
}
public removeFailed(userId: string, serverName: string): void {
const userServers = this.failed.get(userId);
userServers?.delete(serverName);
if (userServers?.size === 0) {
this.failed.delete(userId);
const userMap = this.failedMeta.get(userId);
userMap?.delete(serverName);
if (userMap?.size === 0) {
this.failedMeta.delete(userId);
}
}
@ -94,7 +114,7 @@ export class OAuthReconnectionTracker {
activeTimestamps: number;
} {
return {
usersWithFailedServers: this.failed.size,
usersWithFailedServers: this.failedMeta.size,
usersWithActiveReconnections: this.active.size,
activeTimestamps: this.activeTimestamps.size,
};