⛈️ fix: MCP Reconnection Storm Prevention with Circuit Breaker, Backoff, and Tool Stubs (#12162)

* fix: MCP reconnection stability - circuit breaker, throttling, and cooldown retry

* Comment and logging cleanup

* fix broken tests
This commit is contained in:
matt burnett 2026-03-10 11:21:36 -07:00 committed by GitHub
parent cfbe812d63
commit ad5c51f62b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 736 additions and 38 deletions

View file

@ -253,17 +253,21 @@ describe('OAuthReconnectionManager', () => {
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
});
it('should not reconnect servers with expired tokens', async () => {
it('should not reconnect servers with expired tokens and no refresh token', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
// server1: has expired token
tokenMethods.findToken.mockResolvedValue({
userId,
identifier: 'mcp:server1',
expiresAt: new Date(Date.now() - 3600000), // 1 hour ago
} as unknown as MCPOAuthTokens);
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
if (identifier === 'mcp:server1') {
return {
userId,
identifier,
expiresAt: new Date(Date.now() - 3600000),
} as unknown as MCPOAuthTokens;
}
return null;
});
await reconnectionManager.reconnectServers(userId);
@ -272,6 +276,87 @@ describe('OAuthReconnectionManager', () => {
expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled();
});
// An expired access token must not block reconnection while a refresh token is still stored.
it('should reconnect servers with expired access token but valid refresh token', async () => {
  const userId = 'user-123';
  (mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(new Set(['server1']));

  // Token store: access token expired one hour ago, refresh token present.
  tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
    switch (identifier) {
      case 'mcp:server1':
        return {
          userId,
          identifier,
          expiresAt: new Date(Date.now() - 3600000),
        } as unknown as MCPOAuthTokens;
      case 'mcp:server1:refresh':
        return { userId, identifier } as unknown as MCPOAuthTokens;
      default:
        return null;
    }
  });

  // The manager hands back a connection that reports itself as connected.
  const liveConnection = {
    isConnected: jest.fn().mockResolvedValue(true),
    disconnect: jest.fn(),
  };
  mockMCPManager.getUserConnection.mockResolvedValue(liveConnection as unknown as MCPConnection);
  (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
    {} as unknown as MCPOptions,
  );

  await reconnectionManager.reconnectServers(userId);

  // The server is flagged active as soon as reconnectServers resolves.
  const flaggedActive = reconnectionTracker.isActive(userId, 'server1');
  expect(flaggedActive).toBe(true);

  // Give the background reconnection attempt time to run.
  await new Promise((resolve) => setTimeout(resolve, 100));

  expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
    expect.objectContaining({ serverName: 'server1' }),
  );
});
// With no access token record at all, a stored refresh token is enough to attempt reconnection.
it('should reconnect when access token is TTL-deleted but refresh token exists', async () => {
  const userId = 'user-123';
  (mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(new Set(['server1']));

  // Token store: only the refresh token lookup yields a record.
  tokenMethods.findToken.mockImplementation(async ({ identifier }) =>
    identifier === 'mcp:server1:refresh'
      ? ({ userId, identifier } as unknown as MCPOAuthTokens)
      : null,
  );

  // The manager hands back a connection that reports itself as connected.
  const liveConnection = {
    isConnected: jest.fn().mockResolvedValue(true),
    disconnect: jest.fn(),
  };
  mockMCPManager.getUserConnection.mockResolvedValue(liveConnection as unknown as MCPConnection);
  (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
    {} as unknown as MCPOptions,
  );

  await reconnectionManager.reconnectServers(userId);

  // The server is flagged active as soon as reconnectServers resolves.
  const flaggedActive = reconnectionTracker.isActive(userId, 'server1');
  expect(flaggedActive).toBe(true);

  // Give the background reconnection attempt time to run.
  await new Promise((resolve) => setTimeout(resolve, 100));

  expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
    expect.objectContaining({ serverName: 'server1' }),
  );
});
it('should handle connection that returns but is not connected', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
@ -336,6 +421,69 @@ describe('OAuthReconnectionManager', () => {
});
});
// Covers the boolean contract of reconnectServer: true on success, false on any failure.
describe('reconnectServer', () => {
  let reconnectionTracker: OAuthReconnectionTracker;

  /** Makes the registry resolve an empty server config for any lookup. */
  const stubEmptyServerConfig = (): void => {
    (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
      {} as unknown as MCPOptions,
    );
  };

  beforeEach(async () => {
    // Fresh tracker and manager for every test so no state carries over.
    reconnectionTracker = new OAuthReconnectionTracker();
    reconnectionManager = await OAuthReconnectionManager.createInstance(
      flowManager,
      tokenMethods,
      reconnectionTracker,
    );
  });

  it('should return true on successful reconnection', async () => {
    const userId = 'user-123';
    const serverName = 'server1';
    const connectedStub = {
      isConnected: jest.fn().mockResolvedValue(true),
      disconnect: jest.fn(),
    };
    mockMCPManager.getUserConnection.mockResolvedValue(connectedStub as unknown as MCPConnection);
    stubEmptyServerConfig();

    const outcome = await reconnectionManager.reconnectServer(userId, serverName);

    expect(outcome).toBe(true);
  });

  it('should return false on failed reconnection', async () => {
    const userId = 'user-123';
    const serverName = 'server1';
    mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed'));
    stubEmptyServerConfig();

    const outcome = await reconnectionManager.reconnectServer(userId, serverName);

    expect(outcome).toBe(false);
  });

  it('should return false when MCPManager is not available', async () => {
    const userId = 'user-123';
    const serverName = 'server1';

    // Drop the singleton and make MCPManager lookup throw so the new instance has no manager.
    (OAuthReconnectionManager as unknown as { instance: null }).instance = null;
    (MCPManager.getInstance as jest.Mock).mockImplementation(() => {
      throw new Error('MCPManager has not been initialized.');
    });

    const managerWithoutMCP = await OAuthReconnectionManager.createInstance(
      flowManager,
      tokenMethods,
      reconnectionTracker,
    );

    const outcome = await managerWithoutMCP.reconnectServer(userId, serverName);

    expect(outcome).toBe(false);
  });
});
describe('reconnection staggering', () => {
let reconnectionTracker: OAuthReconnectionTracker;

View file

@ -96,6 +96,24 @@ export class OAuthReconnectionManager {
}
}
/**
 * Attempts to reconnect a single OAuth MCP server for the given user.
 *
 * The server is flagged as actively reconnecting before the attempt starts.
 * Success is judged by asking the tracker afterwards whether the server is
 * marked as failed, so the result mirrors whatever state `tryReconnect` left
 * behind.
 *
 * @param userId - User whose connection should be re-established.
 * @param serverName - Name of the MCP server to reconnect.
 * @returns true if reconnection succeeded, false otherwise (including when no
 * MCP manager is available or the attempt throws).
 */
public async reconnectServer(userId: string, serverName: string): Promise<boolean> {
  if (this.mcpManager == null) {
    return false;
  }

  this.reconnectionsTracker.setActive(userId, serverName);
  try {
    await this.tryReconnect(userId, serverName);
    return !this.reconnectionsTracker.isFailed(userId, serverName);
  } catch {
    // tryReconnect threw before finishing its own bookkeeping (assumed — its body
    // is not visible here). Do not leave the server flagged as active, and make
    // sure the failure is recorded so the tracker state matches the `false` result.
    // The isFailed guard avoids counting the same failure twice.
    if (!this.reconnectionsTracker.isFailed(userId, serverName)) {
      this.reconnectionsTracker.setFailed(userId, serverName);
    }
    this.reconnectionsTracker.removeActive(userId, serverName);
    return false;
  }
}
public clearReconnection(userId: string, serverName: string) {
this.reconnectionsTracker.removeFailed(userId, serverName);
this.reconnectionsTracker.removeActive(userId, serverName);
@ -174,23 +192,31 @@ export class OAuthReconnectionManager {
}
}
// if the server has no tokens for the user, don't attempt to reconnect
// if the server has a valid (non-expired) access token, allow reconnect
const accessToken = await this.tokenMethods.findToken({
userId,
type: 'mcp_oauth',
identifier: `mcp:${serverName}`,
});
if (accessToken == null) {
if (accessToken != null) {
const now = new Date();
if (!accessToken.expiresAt || accessToken.expiresAt >= now) {
return true;
}
}
// if the access token is expired or TTL-deleted, fall back to refresh token
const refreshToken = await this.tokenMethods.findToken({
userId,
type: 'mcp_oauth',
identifier: `mcp:${serverName}:refresh`,
});
if (refreshToken == null) {
return false;
}
// if the token has expired, don't attempt to reconnect
const now = new Date();
if (accessToken.expiresAt && accessToken.expiresAt < now) {
return false;
}
// …otherwise, we're good to go with the reconnect attempt
return true;
}
}

View file

@ -397,6 +397,101 @@ describe('OAuthReconnectTracker', () => {
});
});
// Verifies that a failed server becomes retryable again once its cooldown window has passed.
describe('cooldown-based retry', () => {
  const MINUTE_MS = 60 * 1000;

  beforeEach(() => {
    jest.useFakeTimers();
  });

  afterEach(() => {
    jest.useRealTimers();
  });

  it('should return true from isFailed within first cooldown period (5 min)', () => {
    jest.setSystemTime(Date.now());

    tracker.setFailed(userId, serverName);
    expect(tracker.isFailed(userId, serverName)).toBe(true);

    // Four minutes in: still inside the first window.
    jest.advanceTimersByTime(4 * MINUTE_MS);
    expect(tracker.isFailed(userId, serverName)).toBe(true);
  });

  it('should return false from isFailed after first cooldown elapses (5 min)', () => {
    jest.setSystemTime(Date.now());

    tracker.setFailed(userId, serverName);
    expect(tracker.isFailed(userId, serverName)).toBe(true);

    // Exactly five minutes in: the first window is over.
    jest.advanceTimersByTime(5 * MINUTE_MS);
    expect(tracker.isFailed(userId, serverName)).toBe(false);
  });

  it('should use progressive cooldown schedule (5m, 10m, 20m, 30m)', () => {
    jest.setSystemTime(Date.now());

    // Failure #1 -> 5 minute window
    tracker.setFailed(userId, serverName);
    jest.advanceTimersByTime(5 * MINUTE_MS);
    expect(tracker.isFailed(userId, serverName)).toBe(false);

    // Failure #2 -> 10 minute window
    tracker.setFailed(userId, serverName);
    jest.advanceTimersByTime(9 * MINUTE_MS);
    expect(tracker.isFailed(userId, serverName)).toBe(true);
    jest.advanceTimersByTime(1 * MINUTE_MS);
    expect(tracker.isFailed(userId, serverName)).toBe(false);

    // Failure #3 -> 20 minute window
    tracker.setFailed(userId, serverName);
    jest.advanceTimersByTime(19 * MINUTE_MS);
    expect(tracker.isFailed(userId, serverName)).toBe(true);
    jest.advanceTimersByTime(1 * MINUTE_MS);
    expect(tracker.isFailed(userId, serverName)).toBe(false);

    // Failure #4 -> 30 minute window
    tracker.setFailed(userId, serverName);
    jest.advanceTimersByTime(29 * MINUTE_MS);
    expect(tracker.isFailed(userId, serverName)).toBe(true);
    jest.advanceTimersByTime(1 * MINUTE_MS);
    expect(tracker.isFailed(userId, serverName)).toBe(false);
  });

  it('should cap cooldown at 30 min for attempts beyond 4', () => {
    jest.setSystemTime(Date.now());

    // Burn through five failures, waiting out a full 30 minutes after each.
    for (let attempt = 1; attempt <= 5; attempt++) {
      tracker.setFailed(userId, serverName);
      jest.advanceTimersByTime(30 * MINUTE_MS);
    }

    // The sixth failure must still use the 30 minute window.
    tracker.setFailed(userId, serverName);
    jest.advanceTimersByTime(29 * MINUTE_MS);
    expect(tracker.isFailed(userId, serverName)).toBe(true);
    jest.advanceTimersByTime(1 * MINUTE_MS);
    expect(tracker.isFailed(userId, serverName)).toBe(false);
  });

  it('should fully reset metadata on removeFailed', () => {
    jest.setSystemTime(Date.now());

    // Accumulate three failures, then wipe them.
    tracker.setFailed(userId, serverName);
    tracker.setFailed(userId, serverName);
    tracker.setFailed(userId, serverName);
    tracker.removeFailed(userId, serverName);
    expect(tracker.isFailed(userId, serverName)).toBe(false);

    // A new failure after the reset is back on the 5 minute window.
    tracker.setFailed(userId, serverName);
    jest.advanceTimersByTime(5 * MINUTE_MS);
    expect(tracker.isFailed(userId, serverName)).toBe(false);
  });
});
describe('timestamp tracking edge cases', () => {
beforeEach(() => {
jest.useFakeTimers();

View file

@ -1,6 +1,12 @@
/** Failure bookkeeping stored per user/server pair; read by isFailed to decide whether a cooldown is still running. */
interface FailedMeta {
/** Count of recorded failures; setFailed stores the previous count plus one (starting at 1). */
attempts: number;
/** Date.now() value in milliseconds captured when the most recent failure was recorded. */
lastFailedAt: number;
}
/**
 * Cooldown durations in milliseconds: 5, 10, 20 and 30 minutes.
 * isFailed indexes this by (attempts - 1), clamped to the last entry, so every
 * failure from the fourth onward uses the 30 minute window.
 */
const COOLDOWN_SCHEDULE_MS = [5 * 60 * 1000, 10 * 60 * 1000, 20 * 60 * 1000, 30 * 60 * 1000];
export class OAuthReconnectionTracker {
/** Map of userId -> Set of serverNames that have failed reconnection */
private failed: Map<string, Set<string>> = new Map();
private failedMeta: Map<string, Map<string, FailedMeta>> = new Map();
/** Map of userId -> Set of serverNames that are actively reconnecting */
private active: Map<string, Set<string>> = new Map();
/** Map of userId:serverName -> timestamp when reconnection started */
@ -9,7 +15,17 @@ export class OAuthReconnectionTracker {
private readonly RECONNECTION_TIMEOUT_MS = 3 * 60 * 1000; // 3 minutes
public isFailed(userId: string, serverName: string): boolean {
return this.failed.get(userId)?.has(serverName) ?? false;
const meta = this.failedMeta.get(userId)?.get(serverName);
if (!meta) {
return false;
}
const idx = Math.min(meta.attempts - 1, COOLDOWN_SCHEDULE_MS.length - 1);
const cooldown = COOLDOWN_SCHEDULE_MS[idx];
const elapsed = Date.now() - meta.lastFailedAt;
if (elapsed >= cooldown) {
return false;
}
return true;
}
/** Check if server is in the active set (original simple check) */
@ -48,11 +64,15 @@ export class OAuthReconnectionTracker {
}
public setFailed(userId: string, serverName: string): void {
if (!this.failed.has(userId)) {
this.failed.set(userId, new Set());
if (!this.failedMeta.has(userId)) {
this.failedMeta.set(userId, new Map());
}
this.failed.get(userId)?.add(serverName);
const userMap = this.failedMeta.get(userId)!;
const existing = userMap.get(serverName);
userMap.set(serverName, {
attempts: (existing?.attempts ?? 0) + 1,
lastFailedAt: Date.now(),
});
}
public setActive(userId: string, serverName: string): void {
@ -68,10 +88,10 @@ export class OAuthReconnectionTracker {
}
public removeFailed(userId: string, serverName: string): void {
const userServers = this.failed.get(userId);
userServers?.delete(serverName);
if (userServers?.size === 0) {
this.failed.delete(userId);
const userMap = this.failedMeta.get(userId);
userMap?.delete(serverName);
if (userMap?.size === 0) {
this.failedMeta.delete(userId);
}
}
@ -94,7 +114,7 @@ export class OAuthReconnectionTracker {
activeTimestamps: number;
} {
return {
usersWithFailedServers: this.failed.size,
usersWithFailedServers: this.failedMeta.size,
usersWithActiveReconnections: this.active.size,
activeTimestamps: this.activeTimestamps.size,
};