♻️ refactor: On-demand MCP connections: remove proactive reconnect, default to available (#11839)

* feat: Implement reconnection staggering and backoff jitter for MCP connections

- Enhanced the reconnection logic in OAuthReconnectionManager to stagger reconnection attempts for multiple servers, reducing the risk of connection storms.
- Introduced a backoff delay with random jitter in MCPConnection to improve reconnection behavior during network issues.
- Updated the ConnectionsRepository to handle multiple server connections concurrently with a defined concurrency limit.

Added tests to ensure the new reconnection strategy works as intended.

* refactor: Update MCP server query configuration for improved data freshness

- Reduced stale time from 5 minutes to 30 seconds to ensure quicker updates on server initialization.
- Enabled refetching on window focus and mount to enhance data accuracy during user interactions.

* ♻️ refactor: On-demand MCP connections; remove proactive reconnection, default to available

  - Remove reconnectServers() from refresh controller (connection storm root cause)
  - Stop gating server selection on connection status; add to selection immediately
  - Render agent panel tools from DB cache, not live connection status
  - Fall back to cached tools when server initialization fails; only stop early when an OAuth flow is required
  - Remove unused batchToggleServers()
  - Reduce useMCPServersQuery staleTime from 5min to 30s, enable refetchOnMount/WindowFocus

* refactor: Optimize MCP tool initialization and server connection logic

- Adjusted tool initialization to only occur if no cached tools are available, improving efficiency.
- Updated comments for clarity on server connection and tool fetching processes.
- Removed unnecessary connection status checks during server selection to streamline the user experience.
This commit is contained in:
Danny Avila 2026-02-17 22:33:57 -05:00 committed by GitHub
parent dbf8cd40d3
commit 3bf715e05e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 101 additions and 63 deletions

View file

@ -18,7 +18,6 @@ const {
findUser,
} = require('~/models');
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
const { getOAuthReconnectionManager } = require('~/config');
const { getOpenIdConfig } = require('~/strategies');
const registrationController = async (req, res) => {
@ -166,17 +165,6 @@ const refreshController = async (req, res) => {
if (session && session.expiration > new Date()) {
const token = await setAuthTokens(userId, res, session);
// trigger OAuth MCP server reconnection asynchronously (best effort)
try {
void getOAuthReconnectionManager()
.reconnectServers(userId)
.catch((err) => {
logger.error('[refreshController] Error reconnecting OAuth MCP servers:', err);
});
} catch (err) {
logger.warn(`[refreshController] Cannot attempt OAuth MCP servers reconnection:`, err);
}
res.status(200).send({ token, user });
} else if (req?.query?.retry) {
// Retrying from a refresh token request that failed (401)

View file

@ -46,7 +46,7 @@ export default function MCPTools({
return null;
}
if (serverInfo.isConnected) {
if (serverInfo?.tools?.length && serverInfo.tools.length > 0) {
return (
<MCPTool key={`${serverInfo.serverName}-${agentId}`} serverInfo={serverInfo} />
);

View file

@ -96,17 +96,17 @@ function MCPToolSelectDialog({
await new Promise((resolve) => setTimeout(resolve, 500));
}
// Then initialize server if needed
// Only initialize if no cached tools exist; skip if tools are already available from DB
const serverInfo = mcpServersMap.get(serverName);
if (!serverInfo?.isConnected) {
if (!serverInfo?.tools?.length) {
const result = await initializeServer(serverName);
if (result?.success && result.oauthRequired && result.oauthUrl) {
if (result?.oauthRequired && result.oauthUrl) {
setIsInitializing(null);
return;
return; // OAuth flow must complete first
}
}
// Finally, add tools to form
// Add tools to form (refetches from backend's persisted cache)
await addToolsToForm(serverName);
setIsInitializing(null);
} catch (error) {

View file

@ -12,10 +12,10 @@ export const useMCPServersQuery = <TData = t.MCPServersListResponse>(
[QueryKeys.mcpServers],
() => dataService.getMCPServers(),
{
staleTime: 1000 * 60 * 5, // 5 minutes - data stays fresh longer
refetchOnWindowFocus: false,
staleTime: 30 * 1000, // 30 seconds — short enough to pick up servers that finish initializing after first load
refetchOnWindowFocus: true,
refetchOnReconnect: false,
refetchOnMount: false,
refetchOnMount: true,
retry: false,
...config,
},

View file

@ -433,33 +433,6 @@ export function useMCPServerManager({
[startupConfig?.interface?.mcpServers?.placeholder, localize],
);
const batchToggleServers = useCallback(
(serverNames: string[]) => {
const connectedServers: string[] = [];
const disconnectedServers: string[] = [];
serverNames.forEach((serverName) => {
if (isInitializing(serverName)) {
return;
}
const serverStatus = connectionStatus?.[serverName];
if (serverStatus?.connectionState === 'connected') {
connectedServers.push(serverName);
} else {
disconnectedServers.push(serverName);
}
});
setMCPValues(connectedServers);
disconnectedServers.forEach((serverName) => {
initializeServer(serverName);
});
},
[connectionStatus, setMCPValues, initializeServer, isInitializing],
);
const toggleServerSelection = useCallback(
(serverName: string) => {
if (isInitializing(serverName)) {
@ -473,15 +446,10 @@ export function useMCPServerManager({
const filteredValues = currentValues.filter((name) => name !== serverName);
setMCPValues(filteredValues);
} else {
const serverStatus = connectionStatus?.[serverName];
if (serverStatus?.connectionState === 'connected') {
setMCPValues([...currentValues, serverName]);
} else {
initializeServer(serverName);
}
setMCPValues([...currentValues, serverName]);
}
},
[mcpValues, setMCPValues, connectionStatus, initializeServer, isInitializing],
[mcpValues, setMCPValues, isInitializing],
);
const handleConfigSave = useCallback(
@ -677,7 +645,6 @@ export function useMCPServerManager({
isPinned,
setIsPinned,
placeholderText,
batchToggleServers,
toggleServerSelection,
localize,

View file

@ -4,6 +4,8 @@ import { MCPConnection } from './connection';
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
import type * as t from './types';
const CONNECT_CONCURRENCY = 3;
/**
* Manages MCP connections with lazy loading and reconnection.
* Maintains a pool of connections and handles connection lifecycle management.
@ -84,9 +86,17 @@ export class ConnectionsRepository {
/** Gets or creates connections for multiple servers concurrently */
async getMany(serverNames: string[]): Promise<Map<string, MCPConnection>> {
const connectionPromises = serverNames.map(async (name) => [name, await this.get(name)]);
const connections = await Promise.all(connectionPromises);
return new Map((connections as [string, MCPConnection][]).filter((v) => !!v[1]));
const results: [string, MCPConnection | null][] = [];
for (let i = 0; i < serverNames.length; i += CONNECT_CONCURRENCY) {
const batch = serverNames.slice(i, i + CONNECT_CONCURRENCY);
const batchResults = await Promise.all(
batch.map(
async (name): Promise<[string, MCPConnection | null]> => [name, await this.get(name)],
),
);
results.push(...batchResults);
}
return new Map(results.filter((v): v is [string, MCPConnection] => v[1] != null));
}
/** Returns all currently loaded connections without creating new ones */

View file

@ -559,7 +559,11 @@ export class MCPConnection extends EventEmitter {
}
this.isReconnecting = true;
const backoffDelay = (attempt: number) => Math.min(1000 * Math.pow(2, attempt), 30000);
/**
 * Delay in milliseconds to wait before reconnect attempt number `attempt`.
 * Exponential backoff (1000 ms doubled per attempt, capped at 30000 ms) plus a
 * random whole-millisecond jitter in the range [0, 1000).
 */
const backoffDelay = (attempt: number) => {
  const cappedMs = Math.min(2 ** attempt * 1000, 30000);
  // Jitter is added after the cap, so the result can be as high as 30999 ms.
  const jitterMs = Math.floor(Math.random() * 1000);
  return cappedMs + jitterMs;
};
try {
while (

View file

@ -336,6 +336,69 @@ describe('OAuthReconnectionManager', () => {
});
});
// Verifies that reconnectServers() spreads its reconnection attempts over time:
// one attempt is made right away and the remaining ones only after timers advance.
describe('reconnection staggering', () => {
let reconnectionTracker: OAuthReconnectionTracker;
beforeEach(async () => {
// Fake timers let the test flush pending timers with jest.runAllTimersAsync()
// instead of waiting in real time.
jest.useFakeTimers();
// Fresh tracker and manager instance for every test.
reconnectionTracker = new OAuthReconnectionTracker();
reconnectionManager = await OAuthReconnectionManager.createInstance(
flowManager,
tokenMethods,
reconnectionTracker,
);
});
afterEach(() => {
// Switch back to real timers once each test has finished.
jest.useRealTimers();
});
it('should stagger reconnection attempts for multiple servers', async () => {
const userId = 'user-123';
// The registry reports three OAuth servers.
const oauthServers = new Set(['server1', 'server2', 'server3']);
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
// All servers have valid tokens and are not connected
// (every token lookup resolves to a token that expires one hour from now).
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
return {
userId,
identifier,
expiresAt: new Date(Date.now() + 3600000),
} as unknown as MCPOAuthTokens;
});
// Every getUserConnection() call resolves to this same stub, whose isConnected() resolves to true.
const mockNewConnection = {
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn(),
};
mockMCPManager.getUserConnection.mockResolvedValue(
mockNewConnection as unknown as MCPConnection,
);
// An empty config object is returned for every server.
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
{} as unknown as MCPOptions,
);
await reconnectionManager.reconnectServers(userId);
// Only the first server should have been attempted immediately
expect(mockMCPManager.getUserConnection).toHaveBeenCalledTimes(1);
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
expect.objectContaining({ serverName: 'server1' }),
);
// After advancing all timers, all servers should have been attempted
await jest.runAllTimersAsync();
expect(mockMCPManager.getUserConnection).toHaveBeenCalledTimes(3);
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
expect.objectContaining({ serverName: 'server2' }),
);
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
expect.objectContaining({ serverName: 'server3' }),
);
});
});
describe('reconnection timeout behavior', () => {
let reconnectionTracker: OAuthReconnectionTracker;

View file

@ -7,6 +7,7 @@ import { MCPManager } from '~/mcp/MCPManager';
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
const DEFAULT_CONNECTION_TIMEOUT_MS = 10_000; // ms
const RECONNECT_STAGGER_MS = 500; // ms between each server reconnection
export class OAuthReconnectionManager {
private static instance: OAuthReconnectionManager | null = null;
@ -84,9 +85,14 @@ export class OAuthReconnectionManager {
this.reconnectionsTracker.setActive(userId, serverName);
}
// 3. attempt to reconnect the servers
for (const serverName of serversToReconnect) {
void this.tryReconnect(userId, serverName);
// 3. attempt to reconnect the servers with staggered delays to avoid connection storms
for (let i = 0; i < serversToReconnect.length; i++) {
const serverName = serversToReconnect[i];
if (i === 0) {
void this.tryReconnect(userId, serverName);
} else {
setTimeout(() => void this.tryReconnect(userId, serverName), i * RECONNECT_STAGGER_MS);
}
}
}