mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-03-12 11:02:37 +01:00
⛈️ fix: MCP Reconnection Storm Prevention with Circuit Breaker, Backoff, and Tool Stubs (#12162)
* fix: MCP reconnection stability - circuit breaker, throttling, and cooldown retry * Comment and logging cleanup * fix broken tests
This commit is contained in:
parent
cfbe812d63
commit
ad5c51f62b
9 changed files with 736 additions and 38 deletions
|
|
@ -34,6 +34,39 @@ const { reinitMCPServer } = require('./Tools/mcp');
|
|||
const { getAppConfig } = require('./Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
const lastReconnectAttempts = new Map();
|
||||
const RECONNECT_THROTTLE_MS = 10_000;
|
||||
|
||||
const missingToolCache = new Map();
|
||||
const MISSING_TOOL_TTL_MS = 10_000;
|
||||
|
||||
const unavailableMsg =
|
||||
"This tool's MCP server is temporarily unavailable. Please try again shortly.";
|
||||
|
||||
/**
|
||||
* @param {string} toolName
|
||||
* @param {string} serverName
|
||||
*/
|
||||
function createUnavailableToolStub(toolName, serverName) {
|
||||
const normalizedToolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`;
|
||||
const _call = async () => unavailableMsg;
|
||||
const toolInstance = tool(_call, {
|
||||
schema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
input: { type: 'string', description: 'Input for the tool' },
|
||||
},
|
||||
required: [],
|
||||
},
|
||||
name: normalizedToolKey,
|
||||
description: unavailableMsg,
|
||||
responseFormat: AgentConstants.CONTENT_AND_ARTIFACT,
|
||||
});
|
||||
toolInstance.mcp = true;
|
||||
toolInstance.mcpRawServerName = serverName;
|
||||
return toolInstance;
|
||||
}
|
||||
|
||||
function isEmptyObjectSchema(jsonSchema) {
|
||||
return (
|
||||
jsonSchema != null &&
|
||||
|
|
@ -211,6 +244,16 @@ async function reconnectServer({
|
|||
logger.debug(
|
||||
`[MCP][reconnectServer] serverName: ${serverName}, user: ${user?.id}, hasUserMCPAuthMap: ${!!userMCPAuthMap}`,
|
||||
);
|
||||
|
||||
const throttleKey = `${user.id}:${serverName}`;
|
||||
const now = Date.now();
|
||||
const lastAttempt = lastReconnectAttempts.get(throttleKey) ?? 0;
|
||||
if (now - lastAttempt < RECONNECT_THROTTLE_MS) {
|
||||
logger.debug(`[MCP][reconnectServer] Throttled reconnect for ${serverName}`);
|
||||
return null;
|
||||
}
|
||||
lastReconnectAttempts.set(throttleKey, now);
|
||||
|
||||
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
|
||||
const flowId = `${user.id}:${serverName}:${Date.now()}`;
|
||||
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
|
||||
|
|
@ -267,7 +310,7 @@ async function reconnectServer({
|
|||
userMCPAuthMap,
|
||||
forceNew: true,
|
||||
returnOnOAuth: false,
|
||||
connectionTimeout: Time.TWO_MINUTES,
|
||||
connectionTimeout: Time.THIRTY_SECONDS,
|
||||
});
|
||||
} finally {
|
||||
// Clean up abort handler to prevent memory leaks
|
||||
|
|
@ -332,7 +375,7 @@ async function createMCPTools({
|
|||
});
|
||||
if (!result || !result.tools) {
|
||||
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
|
||||
return;
|
||||
return [];
|
||||
}
|
||||
|
||||
const serverTools = [];
|
||||
|
|
@ -402,6 +445,14 @@ async function createMCPTool({
|
|||
/** @type {LCTool | undefined} */
|
||||
let toolDefinition = availableTools?.[toolKey]?.function;
|
||||
if (!toolDefinition) {
|
||||
const cachedAt = missingToolCache.get(toolKey);
|
||||
if (cachedAt && Date.now() - cachedAt < MISSING_TOOL_TTL_MS) {
|
||||
logger.debug(
|
||||
`[MCP][${serverName}][${toolName}] Tool in negative cache, returning unavailable stub.`,
|
||||
);
|
||||
return createUnavailableToolStub(toolName, serverName);
|
||||
}
|
||||
|
||||
logger.warn(
|
||||
`[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`,
|
||||
);
|
||||
|
|
@ -415,11 +466,17 @@ async function createMCPTool({
|
|||
streamId,
|
||||
});
|
||||
toolDefinition = result?.availableTools?.[toolKey]?.function;
|
||||
|
||||
if (!toolDefinition) {
|
||||
missingToolCache.set(toolKey, Date.now());
|
||||
}
|
||||
}
|
||||
|
||||
if (!toolDefinition) {
|
||||
logger.warn(`[MCP][${serverName}][${toolName}] Tool definition not found, cannot create tool.`);
|
||||
return;
|
||||
logger.warn(
|
||||
`[MCP][${serverName}][${toolName}] Tool definition not found, returning unavailable stub.`,
|
||||
);
|
||||
return createUnavailableToolStub(toolName, serverName);
|
||||
}
|
||||
|
||||
return createToolInstance({
|
||||
|
|
@ -720,4 +777,5 @@ module.exports = {
|
|||
getMCPSetupData,
|
||||
checkOAuthFlowStatus,
|
||||
getServerConnectionStatus,
|
||||
createUnavailableToolStub,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -65,6 +65,9 @@ export abstract class UserConnectionManager {
|
|||
|
||||
const userServerMap = this.userConnections.get(userId);
|
||||
let connection = forceNew ? undefined : userServerMap?.get(serverName);
|
||||
if (forceNew) {
|
||||
MCPConnection.clearCooldown(serverName);
|
||||
}
|
||||
const now = Date.now();
|
||||
|
||||
// Check if user is idle
|
||||
|
|
|
|||
|
|
@ -559,3 +559,242 @@ describe('extractSSEErrorMessage', () => {
|
|||
});
|
||||
});
|
||||
});
|
||||
|
||||
/**
|
||||
* Tests for circuit breaker logic.
|
||||
*
|
||||
* Uses standalone implementations that mirror the static/private circuit breaker
|
||||
* methods in MCPConnection. Same approach as the error detection tests above.
|
||||
*/
|
||||
describe('MCPConnection Circuit Breaker', () => {
|
||||
/** 5 cycles within 60s triggers a 30s cooldown */
|
||||
const CB_MAX_CYCLES = 5;
|
||||
const CB_CYCLE_WINDOW_MS = 60_000;
|
||||
const CB_CYCLE_COOLDOWN_MS = 30_000;
|
||||
|
||||
/** 3 failed rounds within 120s triggers exponential backoff (30s - 300s) */
|
||||
const CB_MAX_FAILED_ROUNDS = 3;
|
||||
const CB_FAILED_WINDOW_MS = 120_000;
|
||||
const CB_BASE_BACKOFF_MS = 30_000;
|
||||
const CB_MAX_BACKOFF_MS = 300_000;
|
||||
|
||||
interface CircuitBreakerState {
|
||||
cycleCount: number;
|
||||
cycleWindowStart: number;
|
||||
cooldownUntil: number;
|
||||
failedRounds: number;
|
||||
failedWindowStart: number;
|
||||
failedBackoffUntil: number;
|
||||
}
|
||||
|
||||
function createCB(): CircuitBreakerState {
|
||||
return {
|
||||
cycleCount: 0,
|
||||
cycleWindowStart: Date.now(),
|
||||
cooldownUntil: 0,
|
||||
failedRounds: 0,
|
||||
failedWindowStart: Date.now(),
|
||||
failedBackoffUntil: 0,
|
||||
};
|
||||
}
|
||||
|
||||
function isCircuitOpen(cb: CircuitBreakerState): boolean {
|
||||
const now = Date.now();
|
||||
return now < cb.cooldownUntil || now < cb.failedBackoffUntil;
|
||||
}
|
||||
|
||||
function recordCycle(cb: CircuitBreakerState): void {
|
||||
const now = Date.now();
|
||||
if (now - cb.cycleWindowStart > CB_CYCLE_WINDOW_MS) {
|
||||
cb.cycleCount = 0;
|
||||
cb.cycleWindowStart = now;
|
||||
}
|
||||
cb.cycleCount++;
|
||||
if (cb.cycleCount >= CB_MAX_CYCLES) {
|
||||
cb.cooldownUntil = now + CB_CYCLE_COOLDOWN_MS;
|
||||
cb.cycleCount = 0;
|
||||
cb.cycleWindowStart = now;
|
||||
}
|
||||
}
|
||||
|
||||
function recordFailedRound(cb: CircuitBreakerState): void {
|
||||
const now = Date.now();
|
||||
if (now - cb.failedWindowStart > CB_FAILED_WINDOW_MS) {
|
||||
cb.failedRounds = 0;
|
||||
cb.failedWindowStart = now;
|
||||
}
|
||||
cb.failedRounds++;
|
||||
if (cb.failedRounds >= CB_MAX_FAILED_ROUNDS) {
|
||||
const backoff = Math.min(
|
||||
CB_BASE_BACKOFF_MS * Math.pow(2, cb.failedRounds - CB_MAX_FAILED_ROUNDS),
|
||||
CB_MAX_BACKOFF_MS,
|
||||
);
|
||||
cb.failedBackoffUntil = now + backoff;
|
||||
}
|
||||
}
|
||||
|
||||
function resetFailedRounds(cb: CircuitBreakerState): void {
|
||||
cb.failedRounds = 0;
|
||||
cb.failedWindowStart = Date.now();
|
||||
cb.failedBackoffUntil = 0;
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
jest.useFakeTimers();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.useRealTimers();
|
||||
});
|
||||
|
||||
describe('cycle tracking', () => {
|
||||
it('should not trigger cooldown for fewer than 5 cycles', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
const cb = createCB();
|
||||
for (let i = 0; i < CB_MAX_CYCLES - 1; i++) {
|
||||
recordCycle(cb);
|
||||
}
|
||||
expect(isCircuitOpen(cb)).toBe(false);
|
||||
});
|
||||
|
||||
it('should trigger 30s cooldown after 5 cycles within 60s', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
const cb = createCB();
|
||||
for (let i = 0; i < CB_MAX_CYCLES; i++) {
|
||||
recordCycle(cb);
|
||||
}
|
||||
expect(isCircuitOpen(cb)).toBe(true);
|
||||
|
||||
jest.advanceTimersByTime(29_000);
|
||||
expect(isCircuitOpen(cb)).toBe(true);
|
||||
|
||||
jest.advanceTimersByTime(1_000);
|
||||
expect(isCircuitOpen(cb)).toBe(false);
|
||||
});
|
||||
|
||||
it('should reset cycle count when window expires', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
const cb = createCB();
|
||||
for (let i = 0; i < CB_MAX_CYCLES - 1; i++) {
|
||||
recordCycle(cb);
|
||||
}
|
||||
|
||||
jest.advanceTimersByTime(CB_CYCLE_WINDOW_MS + 1);
|
||||
|
||||
recordCycle(cb);
|
||||
expect(isCircuitOpen(cb)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('failed round tracking', () => {
|
||||
it('should not trigger backoff for fewer than 3 failures', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
const cb = createCB();
|
||||
for (let i = 0; i < CB_MAX_FAILED_ROUNDS - 1; i++) {
|
||||
recordFailedRound(cb);
|
||||
}
|
||||
expect(isCircuitOpen(cb)).toBe(false);
|
||||
});
|
||||
|
||||
it('should trigger 30s backoff after 3 failures within 120s', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
const cb = createCB();
|
||||
for (let i = 0; i < CB_MAX_FAILED_ROUNDS; i++) {
|
||||
recordFailedRound(cb);
|
||||
}
|
||||
expect(isCircuitOpen(cb)).toBe(true);
|
||||
|
||||
jest.advanceTimersByTime(CB_BASE_BACKOFF_MS);
|
||||
expect(isCircuitOpen(cb)).toBe(false);
|
||||
});
|
||||
|
||||
it('should use exponential backoff based on failure count', () => {
|
||||
jest.setSystemTime(Date.now());
|
||||
|
||||
const cb = createCB();
|
||||
|
||||
for (let i = 0; i < 3; i++) {
|
||||
recordFailedRound(cb);
|
||||
}
|
||||
expect(cb.failedBackoffUntil - Date.now()).toBe(30_000);
|
||||
|
||||
recordFailedRound(cb);
|
||||
expect(cb.failedBackoffUntil - Date.now()).toBe(60_000);
|
||||
|
||||
recordFailedRound(cb);
|
||||
expect(cb.failedBackoffUntil - Date.now()).toBe(120_000);
|
||||
|
||||
recordFailedRound(cb);
|
||||
expect(cb.failedBackoffUntil - Date.now()).toBe(240_000);
|
||||
|
||||
// capped at 300s
|
||||
recordFailedRound(cb);
|
||||
expect(cb.failedBackoffUntil - Date.now()).toBe(300_000);
|
||||
});
|
||||
|
||||
it('should reset failed window when window expires', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
const cb = createCB();
|
||||
recordFailedRound(cb);
|
||||
recordFailedRound(cb);
|
||||
|
||||
jest.advanceTimersByTime(CB_FAILED_WINDOW_MS + 1);
|
||||
|
||||
recordFailedRound(cb);
|
||||
expect(isCircuitOpen(cb)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('resetFailedRounds', () => {
|
||||
it('should clear failed round state on successful connection', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
const cb = createCB();
|
||||
for (let i = 0; i < CB_MAX_FAILED_ROUNDS; i++) {
|
||||
recordFailedRound(cb);
|
||||
}
|
||||
expect(isCircuitOpen(cb)).toBe(true);
|
||||
|
||||
resetFailedRounds(cb);
|
||||
expect(isCircuitOpen(cb)).toBe(false);
|
||||
expect(cb.failedRounds).toBe(0);
|
||||
expect(cb.failedBackoffUntil).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearCooldown (registry deletion)', () => {
|
||||
it('should allow connections after clearing circuit breaker state', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
const registry = new Map<string, CircuitBreakerState>();
|
||||
const serverName = 'test-server';
|
||||
|
||||
const cb = createCB();
|
||||
registry.set(serverName, cb);
|
||||
|
||||
for (let i = 0; i < CB_MAX_CYCLES; i++) {
|
||||
recordCycle(cb);
|
||||
}
|
||||
expect(isCircuitOpen(cb)).toBe(true);
|
||||
|
||||
registry.delete(serverName);
|
||||
|
||||
const newCb = createCB();
|
||||
expect(isCircuitOpen(newCb)).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -207,6 +207,7 @@ describe('MCPConnection Agent lifecycle – streamable-http', () => {
|
|||
});
|
||||
|
||||
afterEach(async () => {
|
||||
MCPConnection.clearCooldown('test');
|
||||
await safeDisconnect(conn);
|
||||
conn = null;
|
||||
jest.restoreAllMocks();
|
||||
|
|
@ -366,6 +367,7 @@ describe('MCPConnection Agent lifecycle – SSE', () => {
|
|||
});
|
||||
|
||||
afterEach(async () => {
|
||||
MCPConnection.clearCooldown('test-sse');
|
||||
await safeDisconnect(conn);
|
||||
conn = null;
|
||||
jest.restoreAllMocks();
|
||||
|
|
@ -453,6 +455,7 @@ describe('Regression: old per-request Agent pattern leaks agents', () => {
|
|||
});
|
||||
|
||||
afterEach(async () => {
|
||||
MCPConnection.clearCooldown('test-regression');
|
||||
await safeDisconnect(conn);
|
||||
conn = null;
|
||||
jest.restoreAllMocks();
|
||||
|
|
@ -675,6 +678,7 @@ describe('MCPConnection SSE GET stream recovery – integration', () => {
|
|||
});
|
||||
|
||||
afterEach(async () => {
|
||||
MCPConnection.clearCooldown('test-sse-recovery');
|
||||
await safeDisconnect(conn);
|
||||
conn = null;
|
||||
jest.restoreAllMocks();
|
||||
|
|
|
|||
|
|
@ -71,6 +71,25 @@ const FIVE_MINUTES = 5 * 60 * 1000;
|
|||
const DEFAULT_TIMEOUT = 60000;
|
||||
/** SSE connections through proxies may need longer initial handshake time */
|
||||
const SSE_CONNECT_TIMEOUT = 120000;
|
||||
const DEFAULT_INIT_TIMEOUT = 30000;
|
||||
|
||||
interface CircuitBreakerState {
|
||||
cycleCount: number;
|
||||
cycleWindowStart: number;
|
||||
cooldownUntil: number;
|
||||
failedRounds: number;
|
||||
failedWindowStart: number;
|
||||
failedBackoffUntil: number;
|
||||
}
|
||||
|
||||
const CB_MAX_CYCLES = 5;
|
||||
const CB_CYCLE_WINDOW_MS = 60_000;
|
||||
const CB_CYCLE_COOLDOWN_MS = 30_000;
|
||||
|
||||
const CB_MAX_FAILED_ROUNDS = 3;
|
||||
const CB_FAILED_WINDOW_MS = 120_000;
|
||||
const CB_BASE_BACKOFF_MS = 30_000;
|
||||
const CB_MAX_BACKOFF_MS = 300_000;
|
||||
/** Default body timeout for Streamable HTTP GET SSE streams that idle between server pushes */
|
||||
const DEFAULT_SSE_READ_TIMEOUT = FIVE_MINUTES;
|
||||
|
||||
|
|
@ -274,6 +293,80 @@ export class MCPConnection extends EventEmitter {
|
|||
*/
|
||||
public readonly createdAt: number;
|
||||
|
||||
private static circuitBreakers: Map<string, CircuitBreakerState> = new Map();
|
||||
|
||||
public static clearCooldown(serverName: string): void {
|
||||
MCPConnection.circuitBreakers.delete(serverName);
|
||||
logger.debug(`[MCP][${serverName}] Circuit breaker state cleared`);
|
||||
}
|
||||
|
||||
private getCircuitBreaker(): CircuitBreakerState {
|
||||
let cb = MCPConnection.circuitBreakers.get(this.serverName);
|
||||
if (!cb) {
|
||||
cb = {
|
||||
cycleCount: 0,
|
||||
cycleWindowStart: Date.now(),
|
||||
cooldownUntil: 0,
|
||||
failedRounds: 0,
|
||||
failedWindowStart: Date.now(),
|
||||
failedBackoffUntil: 0,
|
||||
};
|
||||
MCPConnection.circuitBreakers.set(this.serverName, cb);
|
||||
}
|
||||
return cb;
|
||||
}
|
||||
|
||||
private isCircuitOpen(): boolean {
|
||||
const cb = this.getCircuitBreaker();
|
||||
const now = Date.now();
|
||||
return now < cb.cooldownUntil || now < cb.failedBackoffUntil;
|
||||
}
|
||||
|
||||
private recordCycle(): void {
|
||||
const cb = this.getCircuitBreaker();
|
||||
const now = Date.now();
|
||||
if (now - cb.cycleWindowStart > CB_CYCLE_WINDOW_MS) {
|
||||
cb.cycleCount = 0;
|
||||
cb.cycleWindowStart = now;
|
||||
}
|
||||
cb.cycleCount++;
|
||||
if (cb.cycleCount >= CB_MAX_CYCLES) {
|
||||
cb.cooldownUntil = now + CB_CYCLE_COOLDOWN_MS;
|
||||
cb.cycleCount = 0;
|
||||
cb.cycleWindowStart = now;
|
||||
logger.warn(
|
||||
`${this.getLogPrefix()} Circuit breaker: too many cycles, cooling down for ${CB_CYCLE_COOLDOWN_MS}ms`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private recordFailedRound(): void {
|
||||
const cb = this.getCircuitBreaker();
|
||||
const now = Date.now();
|
||||
if (now - cb.failedWindowStart > CB_FAILED_WINDOW_MS) {
|
||||
cb.failedRounds = 0;
|
||||
cb.failedWindowStart = now;
|
||||
}
|
||||
cb.failedRounds++;
|
||||
if (cb.failedRounds >= CB_MAX_FAILED_ROUNDS) {
|
||||
const backoff = Math.min(
|
||||
CB_BASE_BACKOFF_MS * Math.pow(2, cb.failedRounds - CB_MAX_FAILED_ROUNDS),
|
||||
CB_MAX_BACKOFF_MS,
|
||||
);
|
||||
cb.failedBackoffUntil = now + backoff;
|
||||
logger.warn(
|
||||
`${this.getLogPrefix()} Circuit breaker: too many failures, backing off for ${backoff}ms`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private resetFailedRounds(): void {
|
||||
const cb = this.getCircuitBreaker();
|
||||
cb.failedRounds = 0;
|
||||
cb.failedWindowStart = Date.now();
|
||||
cb.failedBackoffUntil = 0;
|
||||
}
|
||||
|
||||
setRequestHeaders(headers: Record<string, string> | null): void {
|
||||
if (!headers) {
|
||||
return;
|
||||
|
|
@ -686,6 +779,12 @@ export class MCPConnection extends EventEmitter {
|
|||
return;
|
||||
}
|
||||
|
||||
if (this.isCircuitOpen()) {
|
||||
this.connectionState = 'error';
|
||||
this.emit('connectionChange', 'error');
|
||||
throw new Error(`${this.getLogPrefix()} Circuit breaker is open, connection attempt blocked`);
|
||||
}
|
||||
|
||||
this.emit('connectionChange', 'connecting');
|
||||
|
||||
this.connectPromise = (async () => {
|
||||
|
|
@ -703,7 +802,7 @@ export class MCPConnection extends EventEmitter {
|
|||
this.transport = await runOutsideTracing(() => this.constructTransport(this.options));
|
||||
this.patchTransportSend();
|
||||
|
||||
const connectTimeout = this.options.initTimeout ?? 120000;
|
||||
const connectTimeout = this.options.initTimeout ?? DEFAULT_INIT_TIMEOUT;
|
||||
await runOutsideTracing(() =>
|
||||
withTimeout(
|
||||
this.client.connect(this.transport!),
|
||||
|
|
@ -716,6 +815,7 @@ export class MCPConnection extends EventEmitter {
|
|||
this.connectionState = 'connected';
|
||||
this.emit('connectionChange', 'connected');
|
||||
this.reconnectAttempts = 0;
|
||||
this.resetFailedRounds();
|
||||
} catch (error) {
|
||||
// Check if it's a rate limit error - stop immediately to avoid making it worse
|
||||
if (this.isRateLimitError(error)) {
|
||||
|
|
@ -817,6 +917,7 @@ export class MCPConnection extends EventEmitter {
|
|||
|
||||
this.connectionState = 'error';
|
||||
this.emit('connectionChange', 'error');
|
||||
this.recordFailedRound();
|
||||
throw error;
|
||||
} finally {
|
||||
this.connectPromise = null;
|
||||
|
|
@ -866,7 +967,8 @@ export class MCPConnection extends EventEmitter {
|
|||
|
||||
async connect(): Promise<void> {
|
||||
try {
|
||||
await this.disconnect();
|
||||
// preserve cycle tracking across reconnects so the circuit breaker can detect rapid cycling
|
||||
await this.disconnect(false);
|
||||
await this.connectClient();
|
||||
if (!(await this.isConnected())) {
|
||||
throw new Error('Connection not established');
|
||||
|
|
@ -906,7 +1008,7 @@ export class MCPConnection extends EventEmitter {
|
|||
isTransient,
|
||||
} = extractSSEErrorMessage(error);
|
||||
|
||||
if (errorCode === 404) {
|
||||
if (errorCode === 400 || errorCode === 404 || errorCode === 405) {
|
||||
const hasSession =
|
||||
'sessionId' in transport &&
|
||||
(transport as { sessionId?: string }).sessionId != null &&
|
||||
|
|
@ -914,14 +1016,14 @@ export class MCPConnection extends EventEmitter {
|
|||
|
||||
if (!hasSession && errorMessage.toLowerCase().includes('failed to open sse stream')) {
|
||||
logger.warn(
|
||||
`${this.getLogPrefix()} SSE stream not available (404), no session. Ignoring.`,
|
||||
`${this.getLogPrefix()} SSE stream not available (${errorCode}), no session. Ignoring.`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (hasSession) {
|
||||
logger.warn(
|
||||
`${this.getLogPrefix()} 404 with active session — session lost, triggering reconnection.`,
|
||||
`${this.getLogPrefix()} ${errorCode} with active session — session lost, triggering reconnection.`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -992,7 +1094,7 @@ export class MCPConnection extends EventEmitter {
|
|||
await Promise.all(closing);
|
||||
}
|
||||
|
||||
public async disconnect(): Promise<void> {
|
||||
public async disconnect(resetCycleTracking = true): Promise<void> {
|
||||
try {
|
||||
if (this.transport) {
|
||||
await this.client.close();
|
||||
|
|
@ -1006,6 +1108,9 @@ export class MCPConnection extends EventEmitter {
|
|||
this.emit('connectionChange', 'disconnected');
|
||||
} finally {
|
||||
this.connectPromise = null;
|
||||
if (!resetCycleTracking) {
|
||||
this.recordCycle();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -253,17 +253,21 @@ describe('OAuthReconnectionManager', () => {
|
|||
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
|
||||
});
|
||||
|
||||
it('should not reconnect servers with expired tokens', async () => {
|
||||
it('should not reconnect servers with expired tokens and no refresh token', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
|
||||
// server1: has expired token
|
||||
tokenMethods.findToken.mockResolvedValue({
|
||||
userId,
|
||||
identifier: 'mcp:server1',
|
||||
expiresAt: new Date(Date.now() - 3600000), // 1 hour ago
|
||||
} as unknown as MCPOAuthTokens);
|
||||
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
|
||||
if (identifier === 'mcp:server1') {
|
||||
return {
|
||||
userId,
|
||||
identifier,
|
||||
expiresAt: new Date(Date.now() - 3600000),
|
||||
} as unknown as MCPOAuthTokens;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
|
|
@ -272,6 +276,87 @@ describe('OAuthReconnectionManager', () => {
|
|||
expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should reconnect servers with expired access token but valid refresh token', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
|
||||
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
|
||||
if (identifier === 'mcp:server1') {
|
||||
return {
|
||||
userId,
|
||||
identifier,
|
||||
expiresAt: new Date(Date.now() - 3600000),
|
||||
} as unknown as MCPOAuthTokens;
|
||||
}
|
||||
if (identifier === 'mcp:server1:refresh') {
|
||||
return {
|
||||
userId,
|
||||
identifier,
|
||||
} as unknown as MCPOAuthTokens;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
const mockNewConnection = {
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
disconnect: jest.fn(),
|
||||
};
|
||||
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||
mockNewConnection as unknown as MCPConnection,
|
||||
);
|
||||
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
|
||||
{} as unknown as MCPOptions,
|
||||
);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(true);
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ serverName: 'server1' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should reconnect when access token is TTL-deleted but refresh token exists', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
|
||||
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
|
||||
if (identifier === 'mcp:server1:refresh') {
|
||||
return {
|
||||
userId,
|
||||
identifier,
|
||||
} as unknown as MCPOAuthTokens;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
const mockNewConnection = {
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
disconnect: jest.fn(),
|
||||
};
|
||||
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||
mockNewConnection as unknown as MCPConnection,
|
||||
);
|
||||
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
|
||||
{} as unknown as MCPOptions,
|
||||
);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(true);
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ serverName: 'server1' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle connection that returns but is not connected', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
|
|
@ -336,6 +421,69 @@ describe('OAuthReconnectionManager', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('reconnectServer', () => {
|
||||
let reconnectionTracker: OAuthReconnectionTracker;
|
||||
beforeEach(async () => {
|
||||
reconnectionTracker = new OAuthReconnectionTracker();
|
||||
reconnectionManager = await OAuthReconnectionManager.createInstance(
|
||||
flowManager,
|
||||
tokenMethods,
|
||||
reconnectionTracker,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return true on successful reconnection', async () => {
|
||||
const userId = 'user-123';
|
||||
const serverName = 'server1';
|
||||
|
||||
const mockConnection = {
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
disconnect: jest.fn(),
|
||||
};
|
||||
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||
mockConnection as unknown as MCPConnection,
|
||||
);
|
||||
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
|
||||
{} as unknown as MCPOptions,
|
||||
);
|
||||
|
||||
const result = await reconnectionManager.reconnectServer(userId, serverName);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false on failed reconnection', async () => {
|
||||
const userId = 'user-123';
|
||||
const serverName = 'server1';
|
||||
|
||||
mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed'));
|
||||
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
|
||||
{} as unknown as MCPOptions,
|
||||
);
|
||||
|
||||
const result = await reconnectionManager.reconnectServer(userId, serverName);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false when MCPManager is not available', async () => {
|
||||
const userId = 'user-123';
|
||||
const serverName = 'server1';
|
||||
|
||||
(OAuthReconnectionManager as unknown as { instance: null }).instance = null;
|
||||
(MCPManager.getInstance as jest.Mock).mockImplementation(() => {
|
||||
throw new Error('MCPManager has not been initialized.');
|
||||
});
|
||||
|
||||
const managerWithoutMCP = await OAuthReconnectionManager.createInstance(
|
||||
flowManager,
|
||||
tokenMethods,
|
||||
reconnectionTracker,
|
||||
);
|
||||
|
||||
const result = await managerWithoutMCP.reconnectServer(userId, serverName);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('reconnection staggering', () => {
|
||||
let reconnectionTracker: OAuthReconnectionTracker;
|
||||
|
||||
|
|
|
|||
|
|
@ -96,6 +96,24 @@ export class OAuthReconnectionManager {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Attempts to reconnect a single OAuth MCP server.
|
||||
* @returns true if reconnection succeeded, false otherwise.
|
||||
*/
|
||||
public async reconnectServer(userId: string, serverName: string): Promise<boolean> {
|
||||
if (this.mcpManager == null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
this.reconnectionsTracker.setActive(userId, serverName);
|
||||
try {
|
||||
await this.tryReconnect(userId, serverName);
|
||||
return !this.reconnectionsTracker.isFailed(userId, serverName);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public clearReconnection(userId: string, serverName: string) {
|
||||
this.reconnectionsTracker.removeFailed(userId, serverName);
|
||||
this.reconnectionsTracker.removeActive(userId, serverName);
|
||||
|
|
@ -174,23 +192,31 @@ export class OAuthReconnectionManager {
|
|||
}
|
||||
}
|
||||
|
||||
// if the server has no tokens for the user, don't attempt to reconnect
|
||||
// if the server has a valid (non-expired) access token, allow reconnect
|
||||
const accessToken = await this.tokenMethods.findToken({
|
||||
userId,
|
||||
type: 'mcp_oauth',
|
||||
identifier: `mcp:${serverName}`,
|
||||
});
|
||||
if (accessToken == null) {
|
||||
|
||||
if (accessToken != null) {
|
||||
const now = new Date();
|
||||
if (!accessToken.expiresAt || accessToken.expiresAt >= now) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// if the access token is expired or TTL-deleted, fall back to refresh token
|
||||
const refreshToken = await this.tokenMethods.findToken({
|
||||
userId,
|
||||
type: 'mcp_oauth',
|
||||
identifier: `mcp:${serverName}:refresh`,
|
||||
});
|
||||
|
||||
if (refreshToken == null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// if the token has expired, don't attempt to reconnect
|
||||
const now = new Date();
|
||||
if (accessToken.expiresAt && accessToken.expiresAt < now) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// …otherwise, we're good to go with the reconnect attempt
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -397,6 +397,101 @@ describe('OAuthReconnectTracker', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('cooldown-based retry', () => {
|
||||
beforeEach(() => {
|
||||
jest.useFakeTimers();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.useRealTimers();
|
||||
});
|
||||
|
||||
it('should return true from isFailed within first cooldown period (5 min)', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
tracker.setFailed(userId, serverName);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
|
||||
jest.advanceTimersByTime(4 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false from isFailed after first cooldown elapses (5 min)', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
tracker.setFailed(userId, serverName);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
|
||||
jest.advanceTimersByTime(5 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
});
|
||||
|
||||
it('should use progressive cooldown schedule (5m, 10m, 20m, 30m)', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
// First failure: 5 min cooldown
|
||||
tracker.setFailed(userId, serverName);
|
||||
jest.advanceTimersByTime(5 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
|
||||
// Second failure: 10 min cooldown
|
||||
tracker.setFailed(userId, serverName);
|
||||
jest.advanceTimersByTime(9 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
jest.advanceTimersByTime(1 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
|
||||
// Third failure: 20 min cooldown
|
||||
tracker.setFailed(userId, serverName);
|
||||
jest.advanceTimersByTime(19 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
jest.advanceTimersByTime(1 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
|
||||
// Fourth failure: 30 min cooldown
|
||||
tracker.setFailed(userId, serverName);
|
||||
jest.advanceTimersByTime(29 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
jest.advanceTimersByTime(1 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
});
|
||||
|
||||
it('should cap cooldown at 30 min for attempts beyond 4', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
for (let i = 0; i < 5; i++) {
|
||||
tracker.setFailed(userId, serverName);
|
||||
jest.advanceTimersByTime(30 * 60 * 1000);
|
||||
}
|
||||
|
||||
tracker.setFailed(userId, serverName);
|
||||
jest.advanceTimersByTime(29 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
jest.advanceTimersByTime(1 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
});
|
||||
|
||||
it('should fully reset metadata on removeFailed', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
tracker.setFailed(userId, serverName);
|
||||
tracker.setFailed(userId, serverName);
|
||||
tracker.setFailed(userId, serverName);
|
||||
|
||||
tracker.removeFailed(userId, serverName);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
|
||||
tracker.setFailed(userId, serverName);
|
||||
jest.advanceTimersByTime(5 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('timestamp tracking edge cases', () => {
|
||||
beforeEach(() => {
|
||||
jest.useFakeTimers();
|
||||
|
|
|
|||
|
|
@ -1,6 +1,12 @@
|
|||
/** Bookkeeping for one user/server pair's reconnection failures, used to drive the progressive retry cooldown. */
interface FailedMeta {
  // Count of consecutive failed reconnection attempts; indexes into COOLDOWN_SCHEDULE_MS.
  attempts: number;
  // Epoch-ms timestamp of the most recent failure (basis for the cooldown window).
  lastFailedAt: number;
}

// Progressive cooldowns for attempts 1..4: 5m, 10m, 20m, 30m.
// Attempts beyond the schedule length reuse the last entry (30m cap).
const COOLDOWN_SCHEDULE_MS = [5 * 60 * 1000, 10 * 60 * 1000, 20 * 60 * 1000, 30 * 60 * 1000];
|
||||
|
||||
export class OAuthReconnectionTracker {
|
||||
/** Map of userId -> Set of serverNames that have failed reconnection */
|
||||
private failed: Map<string, Set<string>> = new Map();
|
||||
private failedMeta: Map<string, Map<string, FailedMeta>> = new Map();
|
||||
/** Map of userId -> Set of serverNames that are actively reconnecting */
|
||||
private active: Map<string, Set<string>> = new Map();
|
||||
/** Map of userId:serverName -> timestamp when reconnection started */
|
||||
|
|
@ -9,7 +15,17 @@ export class OAuthReconnectionTracker {
|
|||
private readonly RECONNECTION_TIMEOUT_MS = 3 * 60 * 1000; // 3 minutes
|
||||
|
||||
public isFailed(userId: string, serverName: string): boolean {
|
||||
return this.failed.get(userId)?.has(serverName) ?? false;
|
||||
const meta = this.failedMeta.get(userId)?.get(serverName);
|
||||
if (!meta) {
|
||||
return false;
|
||||
}
|
||||
const idx = Math.min(meta.attempts - 1, COOLDOWN_SCHEDULE_MS.length - 1);
|
||||
const cooldown = COOLDOWN_SCHEDULE_MS[idx];
|
||||
const elapsed = Date.now() - meta.lastFailedAt;
|
||||
if (elapsed >= cooldown) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/** Check if server is in the active set (original simple check) */
|
||||
|
|
@ -48,11 +64,15 @@ export class OAuthReconnectionTracker {
|
|||
}
|
||||
|
||||
public setFailed(userId: string, serverName: string): void {
|
||||
if (!this.failed.has(userId)) {
|
||||
this.failed.set(userId, new Set());
|
||||
if (!this.failedMeta.has(userId)) {
|
||||
this.failedMeta.set(userId, new Map());
|
||||
}
|
||||
|
||||
this.failed.get(userId)?.add(serverName);
|
||||
const userMap = this.failedMeta.get(userId)!;
|
||||
const existing = userMap.get(serverName);
|
||||
userMap.set(serverName, {
|
||||
attempts: (existing?.attempts ?? 0) + 1,
|
||||
lastFailedAt: Date.now(),
|
||||
});
|
||||
}
|
||||
|
||||
public setActive(userId: string, serverName: string): void {
|
||||
|
|
@ -68,10 +88,10 @@ export class OAuthReconnectionTracker {
|
|||
}
|
||||
|
||||
public removeFailed(userId: string, serverName: string): void {
|
||||
const userServers = this.failed.get(userId);
|
||||
userServers?.delete(serverName);
|
||||
if (userServers?.size === 0) {
|
||||
this.failed.delete(userId);
|
||||
const userMap = this.failedMeta.get(userId);
|
||||
userMap?.delete(serverName);
|
||||
if (userMap?.size === 0) {
|
||||
this.failedMeta.delete(userId);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -94,7 +114,7 @@ export class OAuthReconnectionTracker {
|
|||
activeTimestamps: number;
|
||||
} {
|
||||
return {
|
||||
usersWithFailedServers: this.failed.size,
|
||||
usersWithFailedServers: this.failedMeta.size,
|
||||
usersWithActiveReconnections: this.active.size,
|
||||
activeTimestamps: this.activeTimestamps.size,
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue