mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-03-15 12:16:33 +01:00
⛈️ fix: MCP Reconnection Storm Prevention with Circuit Breaker, Backoff, and Tool Stubs (#12162)
* fix: MCP reconnection stability - circuit breaker, throttling, and cooldown retry
* Comment and logging cleanup
* fix broken tests
This commit is contained in:
parent
cfbe812d63
commit
ad5c51f62b
9 changed files with 736 additions and 38 deletions
|
|
@ -253,17 +253,21 @@ describe('OAuthReconnectionManager', () => {
|
|||
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
|
||||
});
|
||||
|
||||
it('should not reconnect servers with expired tokens', async () => {
|
||||
it('should not reconnect servers with expired tokens and no refresh token', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
|
||||
// server1: has expired token
|
||||
tokenMethods.findToken.mockResolvedValue({
|
||||
userId,
|
||||
identifier: 'mcp:server1',
|
||||
expiresAt: new Date(Date.now() - 3600000), // 1 hour ago
|
||||
} as unknown as MCPOAuthTokens);
|
||||
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
|
||||
if (identifier === 'mcp:server1') {
|
||||
return {
|
||||
userId,
|
||||
identifier,
|
||||
expiresAt: new Date(Date.now() - 3600000),
|
||||
} as unknown as MCPOAuthTokens;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
|
|
@ -272,6 +276,87 @@ describe('OAuthReconnectionManager', () => {
|
|||
expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should reconnect servers with expired access token but valid refresh token', async () => {
  const userId = 'user-123';
  const serverName = 'server1';
  const oneHourMs = 3600000;

  (mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(new Set([serverName]));
  (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
    {} as unknown as MCPOptions,
  );

  // The access token exists but expired an hour ago; a refresh token is still stored.
  tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
    switch (identifier) {
      case 'mcp:server1':
        return {
          userId,
          identifier,
          expiresAt: new Date(Date.now() - oneHourMs),
        } as unknown as MCPOAuthTokens;
      case 'mcp:server1:refresh':
        return { userId, identifier } as unknown as MCPOAuthTokens;
      default:
        return null;
    }
  });

  const connectedStub = {
    isConnected: jest.fn().mockResolvedValue(true),
    disconnect: jest.fn(),
  };
  mockMCPManager.getUserConnection.mockResolvedValue(connectedStub as unknown as MCPConnection);

  await reconnectionManager.reconnectServers(userId);

  // The server must be flagged as actively reconnecting as soon as reconnectServers resolves.
  expect(reconnectionTracker.isActive(userId, serverName)).toBe(true);

  // Give the pending reconnection attempt time to run (presumably it is scheduled asynchronously).
  await new Promise((resolve) => setTimeout(resolve, 100));

  expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
    expect.objectContaining({ serverName: 'server1' }),
  );
});
|
||||
|
||||
it('should reconnect when access token is TTL-deleted but refresh token exists', async () => {
  const userId = 'user-123';
  const serverName = 'server1';

  (mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(new Set([serverName]));
  (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
    {} as unknown as MCPOptions,
  );

  // No access token record at all: only the refresh-token lookup yields a result.
  tokenMethods.findToken.mockImplementation(async ({ identifier }) =>
    identifier === 'mcp:server1:refresh'
      ? ({ userId, identifier } as unknown as MCPOAuthTokens)
      : null,
  );

  const connectedStub = {
    isConnected: jest.fn().mockResolvedValue(true),
    disconnect: jest.fn(),
  };
  mockMCPManager.getUserConnection.mockResolvedValue(connectedStub as unknown as MCPConnection);

  await reconnectionManager.reconnectServers(userId);

  // The server must be flagged as actively reconnecting as soon as reconnectServers resolves.
  expect(reconnectionTracker.isActive(userId, serverName)).toBe(true);

  // Give the pending reconnection attempt time to run (presumably it is scheduled asynchronously).
  await new Promise((resolve) => setTimeout(resolve, 100));

  expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
    expect.objectContaining({ serverName: 'server1' }),
  );
});
|
||||
|
||||
it('should handle connection that returns but is not connected', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
|
|
@ -336,6 +421,69 @@ describe('OAuthReconnectionManager', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('reconnectServer', () => {
  const userId = 'user-123';
  const serverName = 'server1';
  let reconnectionTracker: OAuthReconnectionTracker;

  /** Creates a manager wired to the tracker built for the current test. */
  const createManager = () =>
    OAuthReconnectionManager.createInstance(flowManager, tokenMethods, reconnectionTracker);

  /** Makes the registry resolve an empty server config for any lookup. */
  const stubEmptyServerConfig = (): void => {
    (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
      {} as unknown as MCPOptions,
    );
  };

  beforeEach(async () => {
    reconnectionTracker = new OAuthReconnectionTracker();
    reconnectionManager = await createManager();
  });

  it('should return true on successful reconnection', async () => {
    const connectedStub = {
      isConnected: jest.fn().mockResolvedValue(true),
      disconnect: jest.fn(),
    };
    mockMCPManager.getUserConnection.mockResolvedValue(connectedStub as unknown as MCPConnection);
    stubEmptyServerConfig();

    const outcome = await reconnectionManager.reconnectServer(userId, serverName);

    expect(outcome).toBe(true);
  });

  it('should return false on failed reconnection', async () => {
    mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed'));
    stubEmptyServerConfig();

    const outcome = await reconnectionManager.reconnectServer(userId, serverName);

    expect(outcome).toBe(false);
  });

  it('should return false when MCPManager is not available', async () => {
    // Clear the stored instance (presumably a cached singleton) so createInstance
    // builds a fresh manager while MCPManager.getInstance is throwing.
    (OAuthReconnectionManager as unknown as { instance: null }).instance = null;
    (MCPManager.getInstance as jest.Mock).mockImplementation(() => {
      throw new Error('MCPManager has not been initialized.');
    });

    const managerWithoutMCP = await createManager();
    const outcome = await managerWithoutMCP.reconnectServer(userId, serverName);

    expect(outcome).toBe(false);
  });
});
|
||||
|
||||
describe('reconnection staggering', () => {
|
||||
let reconnectionTracker: OAuthReconnectionTracker;
|
||||
|
||||
|
|
|
|||
|
|
@ -96,6 +96,24 @@ export class OAuthReconnectionManager {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Runs one reconnection attempt for a single OAuth MCP server on behalf of a user.
 *
 * The server is marked active in the tracker, then the attempt is delegated to
 * `tryReconnect`. Success is judged by the tracker not reporting the server as failed
 * once the attempt has finished.
 *
 * @param userId - User the connection belongs to.
 * @param serverName - Name of the MCP server to reconnect.
 * @returns true if the attempt finished without the server being flagged as failed;
 *          false if no MCP manager is available, the server is flagged as failed,
 *          or the attempt threw.
 */
public async reconnectServer(userId: string, serverName: string): Promise<boolean> {
  // Without an MCP manager there is nothing to reconnect through.
  if (this.mcpManager == null) {
    return false;
  }

  this.reconnectionsTracker.setActive(userId, serverName);

  let succeeded = false;
  try {
    await this.tryReconnect(userId, serverName);
    succeeded = !this.reconnectionsTracker.isFailed(userId, serverName);
  } catch {
    // Any error during the attempt is reported to the caller as a plain failure.
    succeeded = false;
  }
  return succeeded;
}
|
||||
|
||||
public clearReconnection(userId: string, serverName: string) {
|
||||
this.reconnectionsTracker.removeFailed(userId, serverName);
|
||||
this.reconnectionsTracker.removeActive(userId, serverName);
|
||||
|
|
@ -174,23 +192,31 @@ export class OAuthReconnectionManager {
|
|||
}
|
||||
}
|
||||
|
||||
// if the server has no tokens for the user, don't attempt to reconnect
|
||||
// if the server has a valid (non-expired) access token, allow reconnect
|
||||
const accessToken = await this.tokenMethods.findToken({
|
||||
userId,
|
||||
type: 'mcp_oauth',
|
||||
identifier: `mcp:${serverName}`,
|
||||
});
|
||||
if (accessToken == null) {
|
||||
|
||||
if (accessToken != null) {
|
||||
const now = new Date();
|
||||
if (!accessToken.expiresAt || accessToken.expiresAt >= now) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// if the access token is expired or TTL-deleted, fall back to refresh token
|
||||
const refreshToken = await this.tokenMethods.findToken({
|
||||
userId,
|
||||
type: 'mcp_oauth',
|
||||
identifier: `mcp:${serverName}:refresh`,
|
||||
});
|
||||
|
||||
if (refreshToken == null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// if the token has expired, don't attempt to reconnect
|
||||
const now = new Date();
|
||||
if (accessToken.expiresAt && accessToken.expiresAt < now) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// …otherwise, we're good to go with the reconnect attempt
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -397,6 +397,101 @@ describe('OAuthReconnectTracker', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('cooldown-based retry', () => {
  /** Converts whole minutes to milliseconds. */
  const minutes = (count: number): number => count * 60 * 1000;

  /** Pins the fake clock to the current time so later advances are measured from here. */
  const pinClock = (): void => {
    jest.setSystemTime(Date.now());
  };

  /** Moves the fake clock forward by the given number of minutes. */
  const advanceMinutes = (count: number): void => {
    jest.advanceTimersByTime(minutes(count));
  };

  /** Records one more failure for the server under test. */
  const fail = (): void => {
    tracker.setFailed(userId, serverName);
  };

  /** Asserts the tracker's current failed state for the server under test. */
  const expectFailed = (expected: boolean): void => {
    expect(tracker.isFailed(userId, serverName)).toBe(expected);
  };

  beforeEach(() => {
    jest.useFakeTimers();
  });

  afterEach(() => {
    jest.useRealTimers();
  });

  it('should return true from isFailed within first cooldown period (5 min)', () => {
    pinClock();

    fail();
    expectFailed(true);

    advanceMinutes(4);
    expectFailed(true);
  });

  it('should return false from isFailed after first cooldown elapses (5 min)', () => {
    pinClock();

    fail();
    expectFailed(true);

    advanceMinutes(5);
    expectFailed(false);
  });

  it('should use progressive cooldown schedule (5m, 10m, 20m, 30m)', () => {
    pinClock();

    // First failure: only the end of the 5 minute cooldown is checked.
    fail();
    advanceMinutes(5);
    expectFailed(false);

    // Second to fourth failures: still failed one minute before the cooldown ends,
    // cleared exactly when it ends.
    for (const cooldownMinutes of [10, 20, 30]) {
      fail();
      advanceMinutes(cooldownMinutes - 1);
      expectFailed(true);
      advanceMinutes(1);
      expectFailed(false);
    }
  });

  it('should cap cooldown at 30 min for attempts beyond 4', () => {
    pinClock();

    // Five failures, each followed by 30 minutes of waiting.
    Array.from({ length: 5 }).forEach(() => {
      fail();
      advanceMinutes(30);
    });

    // The sixth failure must still use a 30 minute cooldown.
    fail();
    advanceMinutes(29);
    expectFailed(true);
    advanceMinutes(1);
    expectFailed(false);
  });

  it('should fully reset metadata on removeFailed', () => {
    pinClock();

    fail();
    fail();
    fail();

    tracker.removeFailed(userId, serverName);
    expectFailed(false);

    // After the reset, a new failure is back on the first (5 minute) cooldown.
    fail();
    advanceMinutes(5);
    expectFailed(false);
  });
});
|
||||
|
||||
describe('timestamp tracking edge cases', () => {
|
||||
beforeEach(() => {
|
||||
jest.useFakeTimers();
|
||||
|
|
|
|||
|
|
@ -1,6 +1,12 @@
|
|||
/** Failure bookkeeping stored per user and server by OAuthReconnectionTracker. */
interface FailedMeta {
  /** Number of failures recorded so far; setFailed increments it on every call. */
  attempts: number;
  /** Date.now() timestamp (epoch milliseconds) of the most recent recorded failure. */
  lastFailedAt: number;
}

/**
 * Cooldown in milliseconds applied after the Nth recorded failure (index N - 1):
 * 5, 10, 20, then 30 minutes. isFailed reuses the last entry for any later failure.
 * Declared readonly so the shared schedule cannot be mutated by accident.
 */
const COOLDOWN_SCHEDULE_MS: readonly number[] = [
  5 * 60 * 1000,
  10 * 60 * 1000,
  20 * 60 * 1000,
  30 * 60 * 1000,
];
|
||||
|
||||
export class OAuthReconnectionTracker {
|
||||
/** Map of userId -> Set of serverNames that have failed reconnection */
|
||||
private failed: Map<string, Set<string>> = new Map();
|
||||
/** Map of userId -> (serverName -> failure metadata); isFailed reads it to apply the retry cooldown. */
private failedMeta: Map<string, Map<string, FailedMeta>> = new Map();
|
||||
/** Map of userId -> Set of serverNames that are actively reconnecting */
|
||||
private active: Map<string, Set<string>> = new Map();
|
||||
/** Map of userId:serverName -> timestamp when reconnection started */
|
||||
|
|
@ -9,7 +15,17 @@ export class OAuthReconnectionTracker {
|
|||
private readonly RECONNECTION_TIMEOUT_MS = 3 * 60 * 1000; // 3 minutes
|
||||
|
||||
/**
 * Reports whether a server is still inside its post-failure cooldown window.
 *
 * The cooldown length is picked from COOLDOWN_SCHEDULE_MS by the number of recorded
 * failures; once the attempts exceed the schedule, its last entry is reused. The
 * metadata is not cleared here, so a later setFailed continues the attempt count.
 *
 * @param userId - User the server connection belongs to.
 * @param serverName - Name of the MCP server.
 * @returns true while the cooldown has not yet elapsed; false when no failure is
 *          recorded or the cooldown is over (a retry may be attempted).
 */
public isFailed(userId: string, serverName: string): boolean {
  const meta = this.failedMeta.get(userId)?.get(serverName);
  if (meta == null) {
    return false;
  }

  // Clamp on both ends: attempts past the schedule reuse the last entry, and a
  // non-positive attempt count falls back to the first entry instead of indexing
  // out of range (which would make the comparison below never succeed).
  const lastIndex = COOLDOWN_SCHEDULE_MS.length - 1;
  const index = Math.max(0, Math.min(meta.attempts - 1, lastIndex));
  const cooldownMs = COOLDOWN_SCHEDULE_MS[index] ?? 0;

  const elapsedMs = Date.now() - meta.lastFailedAt;
  return elapsedMs < cooldownMs;
}
|
||||
|
||||
/** Check if server is in the active set (original simple check) */
|
||||
|
|
@ -48,11 +64,15 @@ export class OAuthReconnectionTracker {
|
|||
}
|
||||
|
||||
/**
 * Records a failed reconnection attempt for a user's server.
 *
 * Increments the stored attempt count (starting at 1 for the first failure) and stamps
 * the failure with the current time, which restarts the cooldown evaluated by isFailed.
 *
 * @param userId - User the server connection belongs to.
 * @param serverName - Name of the MCP server that failed to reconnect.
 */
public setFailed(userId: string, serverName: string): void {
  // Get-or-create the per-user map so no non-null assertion is needed afterwards.
  let userMap = this.failedMeta.get(userId);
  if (userMap == null) {
    userMap = new Map<string, FailedMeta>();
    this.failedMeta.set(userId, userMap);
  }

  const previousAttempts = userMap.get(serverName)?.attempts ?? 0;
  userMap.set(serverName, {
    attempts: previousAttempts + 1,
    lastFailedAt: Date.now(),
  });
}
|
||||
|
||||
public setActive(userId: string, serverName: string): void {
|
||||
|
|
@ -68,10 +88,10 @@ export class OAuthReconnectionTracker {
|
|||
}
|
||||
|
||||
/**
 * Clears all failure metadata for a user's server, so the next recorded failure
 * starts again from the first cooldown step.
 *
 * The per-user entry is dropped once its last server is removed, keeping the outer
 * map free of empty inner maps.
 *
 * @param userId - User the server connection belongs to.
 * @param serverName - Name of the MCP server whose failure state is cleared.
 */
public removeFailed(userId: string, serverName: string): void {
  const userMap = this.failedMeta.get(userId);
  if (userMap == null) {
    return;
  }

  userMap.delete(serverName);
  if (userMap.size === 0) {
    this.failedMeta.delete(userId);
  }
}
|
||||
|
||||
|
|
@ -94,7 +114,7 @@ export class OAuthReconnectionTracker {
|
|||
activeTimestamps: number;
|
||||
} {
|
||||
return {
|
||||
usersWithFailedServers: this.failed.size,
|
||||
usersWithFailedServers: this.failedMeta.size,
|
||||
usersWithActiveReconnections: this.active.size,
|
||||
activeTimestamps: this.activeTimestamps.size,
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue