💫 feat: MCP OAuth Auto-Reconnect (#9646)

* add oauth reconnect tracker

* add connection tracker to mcp manager

* reconnect oauth mcp servers function

* call reconnection in auth controller

* make sure to check connection in panel

* wait for isConnected

* add const for poll interval

* add logging to tryReconnect

* check expiration

* check mcp manager is not null

* check mcp manager is not null

* add test for reconnecting mcp server

* unify logic inside OAuthReconnectionManager

* test reconnection manager, adjust

* chore: reorder import statements in index.js

* chore: imports

* chore: imports

* chore: imports

* chore: imports

* chore: imports

* chore: imports and use types explicitly

---------

Co-authored-by: Danny Avila <danny@librechat.ai>
This commit is contained in:
Federico Ruggi 2025-09-17 22:49:36 +02:00 committed by GitHub
parent 0e94d97bfb
commit d04da60b3b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 830 additions and 13 deletions

View file

@ -14,6 +14,7 @@ export * from './utils';
export * from './db/utils';
/* OAuth */
export * from './oauth';
export * from './mcp/oauth/OAuthReconnectionManager';
/* Crypto */
export * from './crypto';
/* Flow */

View file

@ -0,0 +1,294 @@
import { TokenMethods } from '@librechat/data-schemas';
import { FlowStateManager, MCPConnection, MCPOAuthTokens, MCPOptions } from '../..';
import { MCPManager } from '../MCPManager';
import { OAuthReconnectionManager } from './OAuthReconnectionManager';
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
}));
jest.mock('../MCPManager');
describe('OAuthReconnectionManager', () => {
let flowManager: jest.Mocked<FlowStateManager<null>>;
let tokenMethods: jest.Mocked<TokenMethods>;
let mockMCPManager: jest.Mocked<MCPManager>;
let reconnectionManager: OAuthReconnectionManager;
beforeEach(() => {
jest.clearAllMocks();
// Reset singleton instance
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(OAuthReconnectionManager as any).instance = null;
// Setup mock flow manager
flowManager = {
createFlow: jest.fn(),
completeFlow: jest.fn(),
failFlow: jest.fn(),
deleteFlow: jest.fn(),
getFlow: jest.fn(),
} as unknown as jest.Mocked<FlowStateManager<null>>;
// Setup mock token methods
tokenMethods = {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteToken: jest.fn(),
} as unknown as jest.Mocked<TokenMethods>;
// Setup mock MCP Manager
mockMCPManager = {
getOAuthServers: jest.fn(),
getUserConnection: jest.fn(),
getUserConnections: jest.fn(),
disconnectUserConnection: jest.fn(),
getRawConfig: jest.fn(),
} as unknown as jest.Mocked<MCPManager>;
(MCPManager.getInstance as jest.Mock).mockReturnValue(mockMCPManager);
});
afterEach(() => {
jest.clearAllMocks();
});
describe('Singleton Pattern', () => {
it('should create instance successfully', async () => {
const instance = await OAuthReconnectionManager.createInstance(flowManager, tokenMethods);
expect(instance).toBeInstanceOf(OAuthReconnectionManager);
});
it('should throw error when creating instance twice', async () => {
await OAuthReconnectionManager.createInstance(flowManager, tokenMethods);
await expect(
OAuthReconnectionManager.createInstance(flowManager, tokenMethods),
).rejects.toThrow('OAuthReconnectionManager already initialized');
});
it('should throw error when getting instance before creation', () => {
expect(() => OAuthReconnectionManager.getInstance()).toThrow(
'OAuthReconnectionManager not initialized',
);
});
});
describe('isReconnecting', () => {
let reconnectionTracker: OAuthReconnectionTracker;
beforeEach(async () => {
reconnectionTracker = new OAuthReconnectionTracker();
reconnectionManager = await OAuthReconnectionManager.createInstance(
flowManager,
tokenMethods,
reconnectionTracker,
);
});
it('should return true when server is actively reconnecting', () => {
const userId = 'user-123';
const serverName = 'test-server';
expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(false);
reconnectionTracker.setActive(userId, serverName);
const result = reconnectionManager.isReconnecting(userId, serverName);
expect(result).toBe(true);
});
it('should return false when server is not reconnecting', () => {
const userId = 'user-123';
const serverName = 'test-server';
const result = reconnectionManager.isReconnecting(userId, serverName);
expect(result).toBe(false);
});
});
describe('clearReconnection', () => {
let reconnectionTracker: OAuthReconnectionTracker;
beforeEach(async () => {
reconnectionTracker = new OAuthReconnectionTracker();
reconnectionManager = await OAuthReconnectionManager.createInstance(
flowManager,
tokenMethods,
reconnectionTracker,
);
});
it('should clear both failed and active reconnection states', () => {
const userId = 'user-123';
const serverName = 'test-server';
reconnectionTracker.setFailed(userId, serverName);
reconnectionTracker.setActive(userId, serverName);
reconnectionManager.clearReconnection(userId, serverName);
expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(false);
expect(reconnectionTracker.isFailed(userId, serverName)).toBe(false);
expect(reconnectionTracker.isActive(userId, serverName)).toBe(false);
});
});
describe('reconnectServers', () => {
let reconnectionTracker: OAuthReconnectionTracker;
beforeEach(async () => {
reconnectionTracker = new OAuthReconnectionTracker();
reconnectionManager = await OAuthReconnectionManager.createInstance(
flowManager,
tokenMethods,
reconnectionTracker,
);
});
it('should reconnect eligible servers', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1', 'server2', 'server3']);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
// server1: has failed reconnection
reconnectionTracker.setFailed(userId, 'server1');
// server2: already connected
const mockConnection = {
isConnected: jest.fn().mockResolvedValue(true),
};
const userConnections = new Map([['server2', mockConnection]]);
mockMCPManager.getUserConnections.mockReturnValue(
userConnections as unknown as Map<string, MCPConnection>,
);
// server3: has valid token and not connected
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
if (identifier === 'mcp:server3') {
return {
userId,
identifier,
expiresAt: new Date(Date.now() + 3600000), // 1 hour from now
} as unknown as MCPOAuthTokens;
}
return null;
});
// Mock successful reconnection
const mockNewConnection = {
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn(),
};
mockMCPManager.getUserConnection.mockResolvedValue(
mockNewConnection as unknown as MCPConnection,
);
mockMCPManager.getRawConfig.mockReturnValue({ initTimeout: 5000 } as unknown as MCPOptions);
await reconnectionManager.reconnectServers(userId);
// Verify server3 was marked as active
expect(reconnectionTracker.isActive(userId, 'server3')).toBe(true);
// Wait for async tryReconnect to complete
await new Promise((resolve) => setTimeout(resolve, 100));
// Verify reconnection was attempted for server3
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith({
serverName: 'server3',
user: { id: userId },
flowManager,
tokenMethods,
forceNew: false,
connectionTimeout: 5000,
returnOnOAuth: true,
});
// Verify successful reconnection cleared the states
expect(reconnectionTracker.isFailed(userId, 'server3')).toBe(false);
expect(reconnectionTracker.isActive(userId, 'server3')).toBe(false);
});
it('should handle failed reconnection attempts', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
// server1: has valid token
tokenMethods.findToken.mockResolvedValue({
userId,
identifier: 'mcp:server1',
expiresAt: new Date(Date.now() + 3600000),
} as unknown as MCPOAuthTokens);
// Mock failed connection
mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed'));
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
await reconnectionManager.reconnectServers(userId);
// Wait for async tryReconnect to complete
await new Promise((resolve) => setTimeout(resolve, 100));
// Verify failure handling
expect(reconnectionTracker.isFailed(userId, 'server1')).toBe(true);
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
});
it('should not reconnect servers with expired tokens', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
// server1: has expired token
tokenMethods.findToken.mockResolvedValue({
userId,
identifier: 'mcp:server1',
expiresAt: new Date(Date.now() - 3600000), // 1 hour ago
} as unknown as MCPOAuthTokens);
await reconnectionManager.reconnectServers(userId);
// Verify no reconnection attempt was made
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled();
});
it('should handle connection that returns but is not connected', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
tokenMethods.findToken.mockResolvedValue({
userId,
identifier: 'mcp:server1',
expiresAt: new Date(Date.now() + 3600000),
} as unknown as MCPOAuthTokens);
// Mock connection that returns but is not connected
const mockConnection = {
isConnected: jest.fn().mockResolvedValue(false),
disconnect: jest.fn(),
};
mockMCPManager.getUserConnection.mockResolvedValue(
mockConnection as unknown as MCPConnection,
);
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
await reconnectionManager.reconnectServers(userId);
// Wait for async tryReconnect to complete
await new Promise((resolve) => setTimeout(resolve, 100));
// Verify failure handling
expect(mockConnection.disconnect).toHaveBeenCalled();
expect(reconnectionTracker.isFailed(userId, 'server1')).toBe(true);
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
});
});
});

View file

@ -0,0 +1,163 @@
import { logger } from '@librechat/data-schemas';
import type { TokenMethods } from '@librechat/data-schemas';
import type { TUser } from 'librechat-data-provider';
import type { MCPOAuthTokens } from './types';
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
import { FlowStateManager } from '~/flow/manager';
import { MCPManager } from '~/mcp/MCPManager';
const DEFAULT_CONNECTION_TIMEOUT_MS = 10_000; // ms
export class OAuthReconnectionManager {
private static instance: OAuthReconnectionManager | null = null;
protected readonly flowManager: FlowStateManager<MCPOAuthTokens | null>;
protected readonly tokenMethods: TokenMethods;
private readonly reconnectionsTracker: OAuthReconnectionTracker;
public static getInstance(): OAuthReconnectionManager {
if (!OAuthReconnectionManager.instance) {
throw new Error('OAuthReconnectionManager not initialized');
}
return OAuthReconnectionManager.instance;
}
public static async createInstance(
flowManager: FlowStateManager<MCPOAuthTokens | null>,
tokenMethods: TokenMethods,
reconnections?: OAuthReconnectionTracker,
): Promise<OAuthReconnectionManager> {
if (OAuthReconnectionManager.instance != null) {
throw new Error('OAuthReconnectionManager already initialized');
}
const manager = new OAuthReconnectionManager(flowManager, tokenMethods, reconnections);
OAuthReconnectionManager.instance = manager;
return manager;
}
public constructor(
flowManager: FlowStateManager<MCPOAuthTokens | null>,
tokenMethods: TokenMethods,
reconnections?: OAuthReconnectionTracker,
) {
this.flowManager = flowManager;
this.tokenMethods = tokenMethods;
this.reconnectionsTracker = reconnections ?? new OAuthReconnectionTracker();
}
public isReconnecting(userId: string, serverName: string): boolean {
return this.reconnectionsTracker.isActive(userId, serverName);
}
public async reconnectServers(userId: string) {
const mcpManager = MCPManager.getInstance();
// 1. derive the servers to reconnect
const serversToReconnect = [];
for (const serverName of mcpManager.getOAuthServers() ?? []) {
const canReconnect = await this.canReconnect(userId, serverName);
if (canReconnect) {
serversToReconnect.push(serverName);
}
}
// 2. mark the servers as reconnecting
for (const serverName of serversToReconnect) {
this.reconnectionsTracker.setActive(userId, serverName);
}
// 3. attempt to reconnect the servers
for (const serverName of serversToReconnect) {
void this.tryReconnect(userId, serverName);
}
}
public clearReconnection(userId: string, serverName: string) {
this.reconnectionsTracker.removeFailed(userId, serverName);
this.reconnectionsTracker.removeActive(userId, serverName);
}
private async tryReconnect(userId: string, serverName: string) {
const mcpManager = MCPManager.getInstance();
const logPrefix = `[tryReconnectOAuthMCPServer][User: ${userId}][${serverName}]`;
logger.info(`${logPrefix} Attempting reconnection`);
const config = mcpManager.getRawConfig(serverName);
const cleanupOnFailedReconnect = () => {
this.reconnectionsTracker.setFailed(userId, serverName);
this.reconnectionsTracker.removeActive(userId, serverName);
mcpManager.disconnectUserConnection(userId, serverName);
};
try {
// attempt to get connection (this will use existing tokens and refresh if needed)
const connection = await mcpManager.getUserConnection({
serverName,
user: { id: userId } as TUser,
flowManager: this.flowManager,
tokenMethods: this.tokenMethods,
// don't force new connection, let it reuse existing or create new as needed
forceNew: false,
// set a reasonable timeout for reconnection attempts
connectionTimeout: config?.initTimeout ?? DEFAULT_CONNECTION_TIMEOUT_MS,
// don't trigger OAuth flow during reconnection
returnOnOAuth: true,
});
if (connection && (await connection.isConnected())) {
logger.info(`${logPrefix} Successfully reconnected`);
this.clearReconnection(userId, serverName);
} else {
logger.warn(`${logPrefix} Failed to reconnect`);
await connection?.disconnect();
cleanupOnFailedReconnect();
}
} catch (error) {
logger.warn(`${logPrefix} Failed to reconnect: ${error}`);
cleanupOnFailedReconnect();
}
}
private async canReconnect(userId: string, serverName: string) {
const mcpManager = MCPManager.getInstance();
// if the server has failed reconnection, don't attempt to reconnect
if (this.reconnectionsTracker.isFailed(userId, serverName)) {
return false;
}
// if the server is already connected, don't attempt to reconnect
const existingConnections = mcpManager.getUserConnections(userId);
if (existingConnections?.has(serverName)) {
const isConnected = await existingConnections.get(serverName)?.isConnected();
if (isConnected) {
return false;
}
}
// if the server has no tokens for the user, don't attempt to reconnect
const accessToken = await this.tokenMethods.findToken({
userId,
type: 'mcp_oauth',
identifier: `mcp:${serverName}`,
});
if (accessToken == null) {
return false;
}
// if the token has expired, don't attempt to reconnect
const now = new Date();
if (accessToken.expiresAt && accessToken.expiresAt < now) {
return false;
}
// …otherwise, we're good to go with the reconnect attempt
return true;
}
}

View file

@ -0,0 +1,181 @@
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
describe('OAuthReconnectTracker', () => {
let tracker: OAuthReconnectionTracker;
const userId = 'user123';
const serverName = 'test-server';
const anotherServer = 'another-server';
beforeEach(() => {
tracker = new OAuthReconnectionTracker();
});
describe('setFailed', () => {
it('should record a failed reconnection attempt', () => {
tracker.setFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(true);
});
it('should track multiple servers for the same user', () => {
tracker.setFailed(userId, serverName);
tracker.setFailed(userId, anotherServer);
expect(tracker.isFailed(userId, serverName)).toBe(true);
expect(tracker.isFailed(userId, anotherServer)).toBe(true);
});
it('should track different users independently', () => {
const anotherUserId = 'user456';
tracker.setFailed(userId, serverName);
tracker.setFailed(anotherUserId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(true);
expect(tracker.isFailed(anotherUserId, serverName)).toBe(true);
});
});
describe('isFailed', () => {
it('should return false when no failed attempt is recorded', () => {
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
it('should return true after a failed attempt is recorded', () => {
tracker.setFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(true);
});
it('should return false for a different server even after another server failed', () => {
tracker.setFailed(userId, serverName);
expect(tracker.isFailed(userId, anotherServer)).toBe(false);
});
});
describe('removeFailed', () => {
it('should clear a failed reconnect record', () => {
tracker.setFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(true);
tracker.removeFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
it('should only clear the specific server for the user', () => {
tracker.setFailed(userId, serverName);
tracker.setFailed(userId, anotherServer);
tracker.removeFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(false);
expect(tracker.isFailed(userId, anotherServer)).toBe(true);
});
it('should handle clearing non-existent records gracefully', () => {
expect(() => tracker.removeFailed(userId, serverName)).not.toThrow();
});
});
describe('setActive', () => {
it('should mark a server as reconnecting', () => {
tracker.setActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(true);
});
it('should track multiple reconnecting servers', () => {
tracker.setActive(userId, serverName);
tracker.setActive(userId, anotherServer);
expect(tracker.isActive(userId, serverName)).toBe(true);
expect(tracker.isActive(userId, anotherServer)).toBe(true);
});
});
describe('isActive', () => {
it('should return false when server is not reconnecting', () => {
expect(tracker.isActive(userId, serverName)).toBe(false);
});
it('should return true when server is marked as reconnecting', () => {
tracker.setActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(true);
});
it('should handle non-existent user gracefully', () => {
expect(tracker.isActive('non-existent-user', serverName)).toBe(false);
});
});
describe('removeActive', () => {
it('should clear reconnecting state for a server', () => {
tracker.setActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(true);
tracker.removeActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(false);
});
it('should only clear specific server state', () => {
tracker.setActive(userId, serverName);
tracker.setActive(userId, anotherServer);
tracker.removeActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(false);
expect(tracker.isActive(userId, anotherServer)).toBe(true);
});
it('should handle clearing non-existent state gracefully', () => {
expect(() => tracker.removeActive(userId, serverName)).not.toThrow();
});
});
describe('cleanup behavior', () => {
it('should clean up empty user sets for failed reconnects', () => {
tracker.setFailed(userId, serverName);
tracker.removeFailed(userId, serverName);
// Record and clear another user to ensure internal cleanup
const anotherUserId = 'user456';
tracker.setFailed(anotherUserId, serverName);
// Original user should still be able to reconnect
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
it('should clean up empty user sets for active reconnections', () => {
tracker.setActive(userId, serverName);
tracker.removeActive(userId, serverName);
// Mark another user to ensure internal cleanup
const anotherUserId = 'user456';
tracker.setActive(anotherUserId, serverName);
// Original user should not be reconnecting
expect(tracker.isActive(userId, serverName)).toBe(false);
});
});
describe('combined state management', () => {
it('should handle both failed and reconnecting states independently', () => {
// Mark as reconnecting
tracker.setActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(true);
expect(tracker.isFailed(userId, serverName)).toBe(false);
// Record failed attempt
tracker.setFailed(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(true);
expect(tracker.isFailed(userId, serverName)).toBe(true);
// Clear reconnecting state
tracker.removeActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(false);
expect(tracker.isFailed(userId, serverName)).toBe(true);
// Clear failed state
tracker.removeFailed(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(false);
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
});
});

View file

@ -0,0 +1,46 @@
export class OAuthReconnectionTracker {
// Map of userId -> Set of serverNames that have failed reconnection
private failed: Map<string, Set<string>> = new Map();
// Map of userId -> Set of serverNames that are actively reconnecting
private active: Map<string, Set<string>> = new Map();
public isFailed(userId: string, serverName: string): boolean {
return this.failed.get(userId)?.has(serverName) ?? false;
}
public isActive(userId: string, serverName: string): boolean {
return this.active.get(userId)?.has(serverName) ?? false;
}
public setFailed(userId: string, serverName: string): void {
if (!this.failed.has(userId)) {
this.failed.set(userId, new Set());
}
this.failed.get(userId)?.add(serverName);
}
public setActive(userId: string, serverName: string): void {
if (!this.active.has(userId)) {
this.active.set(userId, new Set());
}
this.active.get(userId)?.add(serverName);
}
public removeFailed(userId: string, serverName: string): void {
const userServers = this.failed.get(userId);
userServers?.delete(serverName);
if (userServers?.size === 0) {
this.failed.delete(userId);
}
}
public removeActive(userId: string, serverName: string): void {
const userServers = this.active.get(userId);
userServers?.delete(serverName);
if (userServers?.size === 0) {
this.active.delete(userId);
}
}
}