mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-02-19 08:58:09 +01:00
♻️ refactor: On-demand MCP connections: remove proactive reconnect, default to available (#11839)
* feat: Implement reconnection staggering and backoff jitter for MCP connections
- Enhanced the reconnection logic in OAuthReconnectionManager to stagger reconnection attempts for multiple servers, reducing the risk of connection storms.
- Introduced a backoff delay with random jitter in MCPConnection to improve reconnection behavior during network issues.
- Updated the ConnectionsRepository to handle multiple server connections concurrently with a defined concurrency limit.
- Added tests to ensure the new reconnection strategy works as intended.
* refactor: Update MCP server query configuration for improved data freshness
- Reduced stale time from 5 minutes to 30 seconds to ensure quicker updates on server initialization.
- Enabled refetching on window focus and mount to enhance data accuracy during user interactions.
* ♻️ refactor: On-demand MCP connections; remove proactive reconnection, default to available
- Remove reconnectServers() from refresh controller (connection storm root cause)
- Stop gating server selection on connection status; add to selection immediately
- Render agent panel tools from DB cache, not live connection status
- Proceed to cached tools on init failure (only gate on OAuth)
- Remove unused batchToggleServers()
- Reduce useMCPServersQuery staleTime from 5min to 30s; enable refetchOnMount and refetchOnWindowFocus
* refactor: Optimize MCP tool initialization and server connection logic
- Adjusted tool initialization to only occur if no cached tools are available, improving efficiency.
- Updated comments for clarity on server connection and tool fetching processes.
- Removed unnecessary connection status checks during server selection to streamline the user experience.
This commit is contained in:
parent
dbf8cd40d3
commit
3bf715e05e
9 changed files with 101 additions and 63 deletions
|
|
@ -18,7 +18,6 @@ const {
|
|||
findUser,
|
||||
} = require('~/models');
|
||||
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
|
||||
const { getOAuthReconnectionManager } = require('~/config');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
|
||||
const registrationController = async (req, res) => {
|
||||
|
|
@ -166,17 +165,6 @@ const refreshController = async (req, res) => {
|
|||
if (session && session.expiration > new Date()) {
|
||||
const token = await setAuthTokens(userId, res, session);
|
||||
|
||||
// trigger OAuth MCP server reconnection asynchronously (best effort)
|
||||
try {
|
||||
void getOAuthReconnectionManager()
|
||||
.reconnectServers(userId)
|
||||
.catch((err) => {
|
||||
logger.error('[refreshController] Error reconnecting OAuth MCP servers:', err);
|
||||
});
|
||||
} catch (err) {
|
||||
logger.warn(`[refreshController] Cannot attempt OAuth MCP servers reconnection:`, err);
|
||||
}
|
||||
|
||||
res.status(200).send({ token, user });
|
||||
} else if (req?.query?.retry) {
|
||||
// Retrying from a refresh token request that failed (401)
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ export default function MCPTools({
|
|||
return null;
|
||||
}
|
||||
|
||||
if (serverInfo.isConnected) {
|
||||
if (serverInfo?.tools?.length && serverInfo.tools.length > 0) {
|
||||
return (
|
||||
<MCPTool key={`${serverInfo.serverName}-${agentId}`} serverInfo={serverInfo} />
|
||||
);
|
||||
|
|
|
|||
|
|
@ -96,17 +96,17 @@ function MCPToolSelectDialog({
|
|||
await new Promise((resolve) => setTimeout(resolve, 500));
|
||||
}
|
||||
|
||||
// Then initialize server if needed
|
||||
// Only initialize if no cached tools exist; skip if tools are already available from DB
|
||||
const serverInfo = mcpServersMap.get(serverName);
|
||||
if (!serverInfo?.isConnected) {
|
||||
if (!serverInfo?.tools?.length) {
|
||||
const result = await initializeServer(serverName);
|
||||
if (result?.success && result.oauthRequired && result.oauthUrl) {
|
||||
if (result?.oauthRequired && result.oauthUrl) {
|
||||
setIsInitializing(null);
|
||||
return;
|
||||
return; // OAuth flow must complete first
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, add tools to form
|
||||
// Add tools to form (refetches from backend's persisted cache)
|
||||
await addToolsToForm(serverName);
|
||||
setIsInitializing(null);
|
||||
} catch (error) {
|
||||
|
|
|
|||
|
|
@ -12,10 +12,10 @@ export const useMCPServersQuery = <TData = t.MCPServersListResponse>(
|
|||
[QueryKeys.mcpServers],
|
||||
() => dataService.getMCPServers(),
|
||||
{
|
||||
staleTime: 1000 * 60 * 5, // 5 minutes - data stays fresh longer
|
||||
refetchOnWindowFocus: false,
|
||||
staleTime: 30 * 1000, // 30 seconds — short enough to pick up servers that finish initializing after first load
|
||||
refetchOnWindowFocus: true,
|
||||
refetchOnReconnect: false,
|
||||
refetchOnMount: false,
|
||||
refetchOnMount: true,
|
||||
retry: false,
|
||||
...config,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -433,33 +433,6 @@ export function useMCPServerManager({
|
|||
[startupConfig?.interface?.mcpServers?.placeholder, localize],
|
||||
);
|
||||
|
||||
const batchToggleServers = useCallback(
|
||||
(serverNames: string[]) => {
|
||||
const connectedServers: string[] = [];
|
||||
const disconnectedServers: string[] = [];
|
||||
|
||||
serverNames.forEach((serverName) => {
|
||||
if (isInitializing(serverName)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const serverStatus = connectionStatus?.[serverName];
|
||||
if (serverStatus?.connectionState === 'connected') {
|
||||
connectedServers.push(serverName);
|
||||
} else {
|
||||
disconnectedServers.push(serverName);
|
||||
}
|
||||
});
|
||||
|
||||
setMCPValues(connectedServers);
|
||||
|
||||
disconnectedServers.forEach((serverName) => {
|
||||
initializeServer(serverName);
|
||||
});
|
||||
},
|
||||
[connectionStatus, setMCPValues, initializeServer, isInitializing],
|
||||
);
|
||||
|
||||
const toggleServerSelection = useCallback(
|
||||
(serverName: string) => {
|
||||
if (isInitializing(serverName)) {
|
||||
|
|
@ -473,15 +446,10 @@ export function useMCPServerManager({
|
|||
const filteredValues = currentValues.filter((name) => name !== serverName);
|
||||
setMCPValues(filteredValues);
|
||||
} else {
|
||||
const serverStatus = connectionStatus?.[serverName];
|
||||
if (serverStatus?.connectionState === 'connected') {
|
||||
setMCPValues([...currentValues, serverName]);
|
||||
} else {
|
||||
initializeServer(serverName);
|
||||
}
|
||||
setMCPValues([...currentValues, serverName]);
|
||||
}
|
||||
},
|
||||
[mcpValues, setMCPValues, connectionStatus, initializeServer, isInitializing],
|
||||
[mcpValues, setMCPValues, isInitializing],
|
||||
);
|
||||
|
||||
const handleConfigSave = useCallback(
|
||||
|
|
@ -677,7 +645,6 @@ export function useMCPServerManager({
|
|||
isPinned,
|
||||
setIsPinned,
|
||||
placeholderText,
|
||||
batchToggleServers,
|
||||
toggleServerSelection,
|
||||
localize,
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import { MCPConnection } from './connection';
|
|||
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
|
||||
import type * as t from './types';
|
||||
|
||||
const CONNECT_CONCURRENCY = 3;
|
||||
|
||||
/**
|
||||
* Manages MCP connections with lazy loading and reconnection.
|
||||
* Maintains a pool of connections and handles connection lifecycle management.
|
||||
|
|
@ -84,9 +86,17 @@ export class ConnectionsRepository {
|
|||
|
||||
/** Gets or creates connections for multiple servers concurrently */
|
||||
async getMany(serverNames: string[]): Promise<Map<string, MCPConnection>> {
|
||||
const connectionPromises = serverNames.map(async (name) => [name, await this.get(name)]);
|
||||
const connections = await Promise.all(connectionPromises);
|
||||
return new Map((connections as [string, MCPConnection][]).filter((v) => !!v[1]));
|
||||
const results: [string, MCPConnection | null][] = [];
|
||||
for (let i = 0; i < serverNames.length; i += CONNECT_CONCURRENCY) {
|
||||
const batch = serverNames.slice(i, i + CONNECT_CONCURRENCY);
|
||||
const batchResults = await Promise.all(
|
||||
batch.map(
|
||||
async (name): Promise<[string, MCPConnection | null]> => [name, await this.get(name)],
|
||||
),
|
||||
);
|
||||
results.push(...batchResults);
|
||||
}
|
||||
return new Map(results.filter((v): v is [string, MCPConnection] => v[1] != null));
|
||||
}
|
||||
|
||||
/** Returns all currently loaded connections without creating new ones */
|
||||
|
|
|
|||
|
|
@ -559,7 +559,11 @@ export class MCPConnection extends EventEmitter {
|
|||
}
|
||||
|
||||
this.isReconnecting = true;
|
||||
const backoffDelay = (attempt: number) => Math.min(1000 * Math.pow(2, attempt), 30000);
|
||||
const backoffDelay = (attempt: number) => {
|
||||
const base = Math.min(1000 * Math.pow(2, attempt), 30000);
|
||||
const jitter = Math.floor(Math.random() * 1000); // up to 1s of random jitter
|
||||
return base + jitter;
|
||||
};
|
||||
|
||||
try {
|
||||
while (
|
||||
|
|
|
|||
|
|
@ -336,6 +336,69 @@ describe('OAuthReconnectionManager', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('reconnection staggering', () => {
|
||||
let reconnectionTracker: OAuthReconnectionTracker;
|
||||
|
||||
beforeEach(async () => {
|
||||
jest.useFakeTimers();
|
||||
reconnectionTracker = new OAuthReconnectionTracker();
|
||||
reconnectionManager = await OAuthReconnectionManager.createInstance(
|
||||
flowManager,
|
||||
tokenMethods,
|
||||
reconnectionTracker,
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.useRealTimers();
|
||||
});
|
||||
|
||||
it('should stagger reconnection attempts for multiple servers', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1', 'server2', 'server3']);
|
||||
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
|
||||
// All servers have valid tokens and are not connected
|
||||
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
|
||||
return {
|
||||
userId,
|
||||
identifier,
|
||||
expiresAt: new Date(Date.now() + 3600000),
|
||||
} as unknown as MCPOAuthTokens;
|
||||
});
|
||||
|
||||
const mockNewConnection = {
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
disconnect: jest.fn(),
|
||||
};
|
||||
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||
mockNewConnection as unknown as MCPConnection,
|
||||
);
|
||||
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
|
||||
{} as unknown as MCPOptions,
|
||||
);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
// Only the first server should have been attempted immediately
|
||||
expect(mockMCPManager.getUserConnection).toHaveBeenCalledTimes(1);
|
||||
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ serverName: 'server1' }),
|
||||
);
|
||||
|
||||
// After advancing all timers, all servers should have been attempted
|
||||
await jest.runAllTimersAsync();
|
||||
|
||||
expect(mockMCPManager.getUserConnection).toHaveBeenCalledTimes(3);
|
||||
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ serverName: 'server2' }),
|
||||
);
|
||||
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ serverName: 'server3' }),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('reconnection timeout behavior', () => {
|
||||
let reconnectionTracker: OAuthReconnectionTracker;
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import { MCPManager } from '~/mcp/MCPManager';
|
|||
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
|
||||
|
||||
const DEFAULT_CONNECTION_TIMEOUT_MS = 10_000; // ms
|
||||
const RECONNECT_STAGGER_MS = 500; // ms between each server reconnection
|
||||
|
||||
export class OAuthReconnectionManager {
|
||||
private static instance: OAuthReconnectionManager | null = null;
|
||||
|
|
@ -84,9 +85,14 @@ export class OAuthReconnectionManager {
|
|||
this.reconnectionsTracker.setActive(userId, serverName);
|
||||
}
|
||||
|
||||
// 3. attempt to reconnect the servers
|
||||
for (const serverName of serversToReconnect) {
|
||||
void this.tryReconnect(userId, serverName);
|
||||
// 3. attempt to reconnect the servers with staggered delays to avoid connection storms
|
||||
for (let i = 0; i < serversToReconnect.length; i++) {
|
||||
const serverName = serversToReconnect[i];
|
||||
if (i === 0) {
|
||||
void this.tryReconnect(userId, serverName);
|
||||
} else {
|
||||
setTimeout(() => void this.tryReconnect(userId, serverName), i * RECONNECT_STAGGER_MS);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue