🚦 feat: Auto-reinitialize MCP Servers on Request (#9226)

This commit is contained in:
Danny Avila 2025-08-23 03:27:05 -04:00
parent ac608ded46
commit c827fdd10e
No known key found for this signature in database
GPG key ID: BF31EEB2C5CA0956
28 changed files with 871 additions and 312 deletions

View file

@ -26,7 +26,7 @@ const ToolCacheKeys = {
* @param {string[]} [options.roleIds] - Role IDs for role-based tools
* @param {string[]} [options.groupIds] - Group IDs for group-based tools
* @param {boolean} [options.includeGlobal=true] - Whether to include global tools
* @returns {Promise<Object|null>} The available tools object or null if not cached
* @returns {Promise<LCAvailableTools|null>} The available tools object or null if not cached
*/
async function getCachedTools(options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
@ -41,13 +41,13 @@ async function getCachedTools(options = {}) {
// Future implementation will merge tools from multiple sources
// based on user permissions, roles, and groups
if (userId) {
// Check if we have pre-computed effective tools for this user
/** @type {LCAvailableTools | null} Check if we have pre-computed effective tools for this user */
const effectiveTools = await cache.get(ToolCacheKeys.EFFECTIVE(userId));
if (effectiveTools) {
return effectiveTools;
}
// Otherwise, compute from individual sources
/** @type {LCAvailableTools | null} Otherwise, compute from individual sources */
const toolSources = [];
if (includeGlobal) {

View file

@ -1,5 +1,4 @@
const { logger } = require('@librechat/data-schemas');
const { isEnabled, getUserMCPAuthMap } = require('@librechat/api');
const { isEnabled } = require('@librechat/api');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const { normalizeEndpointName } = require('~/server/utils');
const loadCustomConfig = require('./loadCustomConfig');
@ -53,31 +52,6 @@ const getCustomEndpointConfig = async (endpoint) => {
);
};
/**
* @param {Object} params
* @param {string} params.userId
* @param {GenericTool[]} [params.tools]
* @param {import('@librechat/data-schemas').PluginAuthMethods['findPluginAuthsByKeys']} params.findPluginAuthsByKeys
* @returns {Promise<Record<string, Record<string, string>> | undefined>}
*/
async function getMCPAuthMap({ userId, tools, findPluginAuthsByKeys }) {
try {
if (!tools || tools.length === 0) {
return;
}
return await getUserMCPAuthMap({
tools,
userId,
findPluginAuthsByKeys,
});
} catch (err) {
logger.error(
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent`,
err,
);
}
}
/**
* @returns {Promise<boolean>}
*/
@ -88,7 +62,6 @@ async function hasCustomUserVars() {
}
module.exports = {
getMCPAuthMap,
getCustomConfig,
getBalanceConfig,
hasCustomUserVars,

View file

@ -9,7 +9,7 @@ const { getLogStores } = require('~/cache');
* @param {string} params.userId - User ID
* @param {string} params.serverName - MCP server name
* @param {Array} params.tools - Array of tool objects from MCP server
* @returns {Promise<void>}
* @returns {Promise<LCAvailableTools>}
*/
async function updateMCPUserTools({ userId, serverName, tools }) {
try {
@ -39,6 +39,7 @@ async function updateMCPUserTools({ userId, serverName, tools }) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(`[MCP Cache] Updated ${tools.length} tools for ${serverName} user ${userId}`);
return userTools;
} catch (error) {
logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error);
throw error;

View file

@ -30,7 +30,13 @@ const { getModelMaxTokens } = require('~/utils');
* @param {TEndpointOption} [params.endpointOption]
* @param {Set<string>} [params.allowedProviders]
* @param {boolean} [params.isInitialAgent]
* @returns {Promise<Agent & { tools: StructuredTool[], attachments: Array<MongoFile>, toolContextMap: Record<string, unknown>, maxContextTokens: number }>}
* @returns {Promise<Agent & {
* tools: StructuredTool[],
* attachments: Array<MongoFile>,
* toolContextMap: Record<string, unknown>,
* maxContextTokens: number,
* userMCPAuthMap?: Record<string, Record<string, string>>
* }>}
*/
const initializeAgent = async ({
req,
@ -91,16 +97,19 @@ const initializeAgent = async ({
});
const provider = agent.provider;
const { tools: structuredTools, toolContextMap } =
(await loadTools?.({
req,
res,
provider,
agentId: agent.id,
tools: agent.tools,
model: agent.model,
tool_resources,
})) ?? {};
const {
tools: structuredTools,
toolContextMap,
userMCPAuthMap,
} = (await loadTools?.({
req,
res,
provider,
agentId: agent.id,
tools: agent.tools,
model: agent.model,
tool_resources,
})) ?? {};
agent.endpoint = provider;
const { getOptions, overrideProvider } = await getProviderConfig(provider);
@ -189,6 +198,7 @@ const initializeAgent = async ({
tools,
attachments,
resendFiles,
userMCPAuthMap,
toolContextMap,
useLegacyContent: !!options.useLegacyContent,
maxContextTokens: Math.round((agentMaxContextTokens - maxTokens) * 0.9),

View file

@ -19,7 +19,10 @@ const AgentClient = require('~/server/controllers/agents/client');
const { getAgent } = require('~/models/Agent');
const { logViolation } = require('~/cache');
function createToolLoader() {
/**
* @param {AbortSignal} signal
*/
function createToolLoader(signal) {
/**
* @param {object} params
* @param {ServerRequest} params.req
@ -29,7 +32,11 @@ function createToolLoader() {
* @param {string} params.provider
* @param {string} params.model
* @param {AgentToolResources} params.tool_resources
* @returns {Promise<{ tools: StructuredTool[], toolContextMap: Record<string, unknown> } | undefined>}
* @returns {Promise<{
* tools: StructuredTool[],
* toolContextMap: Record<string, unknown>,
* userMCPAuthMap?: Record<string, Record<string, string>>
* } | undefined>}
*/
return async function loadTools({ req, res, agentId, tools, provider, model, tool_resources }) {
const agent = { id: agentId, tools, provider, model };
@ -38,6 +45,7 @@ function createToolLoader() {
req,
res,
agent,
signal,
tool_resources,
});
} catch (error) {
@ -46,7 +54,7 @@ function createToolLoader() {
};
}
const initializeClient = async ({ req, res, endpointOption }) => {
const initializeClient = async ({ req, res, signal, endpointOption }) => {
if (!endpointOption) {
throw new Error('Endpoint option not provided');
}
@ -92,7 +100,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
/** @type {Set<string>} */
const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders);
const loadTools = createToolLoader();
const loadTools = createToolLoader(signal);
/** @type {Array<MongoFile>} */
const requestFiles = req.body.files ?? [];
/** @type {string} */
@ -111,6 +119,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
});
const agent_ids = primaryConfig.agent_ids;
let userMCPAuthMap = primaryConfig.userMCPAuthMap;
if (agent_ids?.length) {
for (const agentId of agent_ids) {
const agent = await getAgent({ id: agentId });
@ -140,6 +149,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
endpointOption,
allowedProviders,
});
Object.assign(userMCPAuthMap, config.userMCPAuthMap ?? {});
agentConfigs.set(agentId, config);
}
}
@ -188,7 +198,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
: EModelEndpoint.agents,
});
return { client };
return { client, userMCPAuthMap };
};
module.exports = { initializeClient };

View file

@ -1,7 +1,12 @@
const { z } = require('zod');
const { tool } = require('@langchain/core/tools');
const { logger } = require('@librechat/data-schemas');
const { Constants: AgentConstants, Providers, GraphEvents } = require('@librechat/agents');
const {
Providers,
StepTypes,
GraphEvents,
Constants: AgentConstants,
} = require('@librechat/agents');
const {
sendEvent,
MCPOAuthHandler,
@ -11,14 +16,14 @@ const {
const {
Time,
CacheKeys,
StepTypes,
Constants,
ContentTypes,
isAssistantsEndpoint,
} = require('librechat-data-provider');
const { getCachedTools, loadCustomConfig } = require('./Config');
const { findToken, createToken, updateToken } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { getCachedTools, loadCustomConfig } = require('./Config');
const { reinitMCPServer } = require('./Tools/mcp');
const { getLogStores } = require('~/cache');
/**
@ -26,16 +31,13 @@ const { getLogStores } = require('~/cache');
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.stepId - The ID of the step in the flow.
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
* @param {string} params.loginFlowId - The ID of the login flow.
* @param {FlowStateManager<any>} params.flowManager - The flow manager instance.
*/
function createOAuthStart({ res, stepId, toolCall, loginFlowId, flowManager, signal }) {
function createRunStepDeltaEmitter({ res, stepId, toolCall }) {
/**
* Creates a function to handle OAuth login requests.
* @param {string} authURL - The URL to redirect the user for OAuth authentication.
* @returns {Promise<boolean>} Returns true to indicate the event was sent successfully.
* @returns {void}
*/
return async function (authURL) {
return function (authURL) {
/** @type {{ id: string; delta: AgentToolCallDelta }} */
const data = {
id: stepId,
@ -46,17 +48,54 @@ function createOAuthStart({ res, stepId, toolCall, loginFlowId, flowManager, sig
expires_at: Date.now() + Time.TWO_MINUTES,
},
};
/** Used to ensure the handler (use of `sendEvent`) is only invoked once */
await flowManager.createFlowWithHandler(
loginFlowId,
'oauth_login',
async () => {
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
logger.debug('Sent OAuth login request to client');
return true;
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
};
}
/**
* @param {object} params
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.runId - The Run ID, i.e. message ID
* @param {string} params.stepId - The ID of the step in the flow.
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
* @param {number} [params.index]
*/
function createRunStepEmitter({ res, runId, stepId, toolCall, index }) {
return function () {
/** @type {import('@librechat/agents').RunStep} */
const data = {
runId: runId ?? Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
id: stepId,
type: StepTypes.TOOL_CALLS,
index: index ?? 0,
stepDetails: {
type: StepTypes.TOOL_CALLS,
tool_calls: [toolCall],
},
signal,
);
};
sendEvent(res, { event: GraphEvents.ON_RUN_STEP, data });
};
}
/**
* Creates a function used to ensure the flow handler is only invoked once
* @param {object} params
* @param {string} params.flowId - The ID of the login flow.
* @param {FlowStateManager<any>} params.flowManager - The flow manager instance.
* @param {(authURL: string) => void} [params.callback]
*/
function createOAuthStart({ flowId, flowManager, callback }) {
/**
* Creates a function to handle OAuth login requests.
* @param {string} authURL - The URL to redirect the user for OAuth authentication.
* @returns {Promise<boolean>} Returns true to indicate the event was sent successfully.
*/
return async function (authURL) {
await flowManager.createFlowWithHandler(flowId, 'oauth_login', async () => {
callback?.(authURL);
logger.debug('Sent OAuth login request to client');
return true;
});
};
}
@ -99,23 +138,166 @@ function createAbortHandler({ userId, serverName, toolName, flowManager }) {
}
/**
* Creates a general tool for an entire action set.
* @param {Object} params
* @param {() => void} params.runStepEmitter
* @param {(authURL: string) => void} params.runStepDeltaEmitter
* @returns {(authURL: string) => void}
*/
function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
return function (authURL) {
runStepEmitter();
runStepDeltaEmitter(authURL);
};
}
/**
* @param {Object} params
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.serverName
* @param {AbortSignal} params.signal
* @param {string} params.model
* @param {number} [params.index]
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
*/
async function reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap }) {
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
const flowId = `${req.user?.id}:${serverName}:${Date.now()}`;
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
const stepId = 'step_oauth_login_' + serverName;
const toolCall = {
id: flowId,
name: serverName,
type: 'tool_call_chunk',
};
const runStepEmitter = createRunStepEmitter({
res,
index,
runId,
stepId,
toolCall,
});
const runStepDeltaEmitter = createRunStepDeltaEmitter({
res,
stepId,
toolCall,
});
const callback = createOAuthCallback({ runStepEmitter, runStepDeltaEmitter });
const oauthStart = createOAuthStart({
res,
flowId,
callback,
flowManager,
});
return await reinitMCPServer({
req,
signal,
serverName,
oauthStart,
flowManager,
userMCPAuthMap,
forceNew: true,
returnOnOAuth: false,
connectionTimeout: Time.TWO_MINUTES,
});
}
/**
* Creates all tools from the specified MCP Server via `toolKey`.
*
* @param {Object} params - The parameters for loading action sets.
* This function assumes tools could not be aggregated from the cache of tool definitions,
* i.e. `availableTools`, and will reinitialize the MCP server to ensure all tools are generated.
*
* @param {Object} params
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.serverName
* @param {string} params.model
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
* @param {number} [params.index]
* @param {AbortSignal} [params.signal]
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
*/
async function createMCPTools({ req, res, index, signal, serverName, provider, userMCPAuthMap }) {
const result = await reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap });
if (!result || !result.tools) {
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
return;
}
const serverTools = [];
for (const tool of result.tools) {
const toolInstance = await createMCPTool({
req,
res,
provider,
userMCPAuthMap,
availableTools: result.availableTools,
toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`,
});
if (toolInstance) {
serverTools.push(toolInstance);
}
}
return serverTools;
}
/**
* Creates a single tool from the specified MCP Server via `toolKey`.
* @param {Object} params
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.toolKey - The toolKey for the tool.
* @param {import('@librechat/agents').Providers | EModelEndpoint} params.provider - The provider for the tool.
* @param {string} params.model - The model for the tool.
* @param {number} [params.index]
* @param {AbortSignal} [params.signal]
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
* @param {LCAvailableTools} [params.availableTools]
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
*/
async function createMCPTool({ req, res, toolKey, provider: _provider }) {
const availableTools = await getCachedTools({ userId: req.user?.id, includeGlobal: true });
const toolDefinition = availableTools?.[toolKey]?.function;
async function createMCPTool({
req,
res,
index,
signal,
toolKey,
provider,
userMCPAuthMap,
availableTools: tools,
}) {
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
const availableTools =
tools ?? (await getCachedTools({ userId: req.user?.id, includeGlobal: true }));
/** @type {LCTool | undefined} */
let toolDefinition = availableTools?.[toolKey]?.function;
if (!toolDefinition) {
logger.error(`Tool ${toolKey} not found in available tools`);
return null;
logger.warn(
`[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`,
);
const result = await reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap });
toolDefinition = result?.availableTools?.[toolKey]?.function;
}
if (!toolDefinition) {
logger.warn(`[MCP][${serverName}][${toolName}] Tool definition not found, cannot create tool.`);
return;
}
return createToolInstance({
res,
provider,
toolName,
serverName,
toolDefinition,
});
}
function createToolInstance({ res, toolName, serverName, toolDefinition, provider: _provider }) {
/** @type {LCTool} */
const { description, parameters } = toolDefinition;
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
@ -128,16 +310,8 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
schema = z.object({ input: z.string().optional() });
}
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
const normalizedToolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`;
if (!req.user?.id) {
logger.error(
`[MCP][${serverName}][${toolName}] User ID not found on request. Cannot create tool.`,
);
throw new Error(`User ID not found on request. Cannot create tool for ${toolKey}.`);
}
/** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise<unknown>} */
const _call = async (toolArguments, config) => {
const userId = config?.configurable?.user?.id || config?.configurable?.user_id;
@ -154,14 +328,16 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
const provider = (config?.metadata?.provider || _provider)?.toLowerCase();
const { args: _args, stepId, ...toolCall } = config.toolCall ?? {};
const loginFlowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`;
const oauthStart = createOAuthStart({
const flowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`;
const runStepDeltaEmitter = createRunStepDeltaEmitter({
res,
stepId,
toolCall,
loginFlowId,
});
const oauthStart = createOAuthStart({
flowId,
flowManager,
signal: derivedSignal,
callback: runStepDeltaEmitter,
});
const oauthEnd = createOAuthEnd({
res,
@ -207,7 +383,7 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
return result;
} catch (error) {
logger.error(
`[MCP][User: ${userId}][${serverName}] Error calling "${toolName}" MCP tool:`,
`[MCP][${serverName}][${toolName}][User: ${userId}] Error calling MCP tool:`,
error,
);
@ -220,12 +396,12 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
if (isOAuthError) {
throw new Error(
`OAuth authentication required for ${serverName}. Please check the server logs for the authentication URL.`,
`[MCP][${serverName}][${toolName}] OAuth authentication required. Please check the server logs for the authentication URL.`,
);
}
throw new Error(
`"${toolKey}" tool call failed${error?.message ? `: ${error?.message}` : '.'}`,
`[MCP][${serverName}][${toolName}] tool call failed${error?.message ? `: ${error?.message}` : '.'}`,
);
} finally {
// Clean up abort handler to prevent memory leaks
@ -380,6 +556,7 @@ async function getServerConnectionStatus(
module.exports = {
createMCPTool,
createMCPTools,
getMCPSetupData,
checkOAuthFlowStatus,
getServerConnectionStatus,

View file

@ -1,9 +1,9 @@
const fs = require('fs');
const path = require('path');
const { sleep } = require('@librechat/agents');
const { getToolkitKey } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { zodToJsonSchema } = require('zod-to-json-schema');
const { getToolkitKey, getUserMCPAuthMap } = require('@librechat/api');
const { Calculator } = require('@langchain/community/tools/calculator');
const { tool: toolFn, Tool, DynamicStructuredTool } = require('@langchain/core/tools');
const {
@ -33,12 +33,17 @@ const {
toolkits,
} = require('~/app/clients/tools');
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
const { getEndpointsConfig, getCachedTools } = require('~/server/services/Config');
const {
getEndpointsConfig,
hasCustomUserVars,
getCachedTools,
} = require('~/server/services/Config');
const { createOnSearchResults } = require('~/server/services/Tools/search');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { recordUsage } = require('~/server/services/Threads');
const { loadTools } = require('~/app/clients/tools/util');
const { redactMessage } = require('~/config/parsers');
const { findPluginAuthsByKeys } = require('~/models');
/**
* Loads and formats tools from the specified tool directory.
@ -469,11 +474,12 @@ async function processRequiredActions(client, requiredActions) {
* @param {Object} params - Run params containing user and request information.
* @param {ServerRequest} params.req - The request object.
* @param {ServerResponse} params.res - The request object.
* @param {AbortSignal} params.signal
* @param {Pick<Agent, 'id' | 'provider' | 'model' | 'tools'} params.agent - The agent to load tools for.
* @param {string | undefined} [params.openAIApiKey] - The OpenAI API key.
* @returns {Promise<{ tools?: StructuredTool[] }>} The agent tools.
* @returns {Promise<{ tools?: StructuredTool[]; userMCPAuthMap?: Record<string, Record<string, string>> }>} The agent tools.
*/
async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) {
async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIApiKey }) {
if (!agent.tools || agent.tools.length === 0) {
return {};
} else if (agent.tools && agent.tools.length === 1 && agent.tools[0] === AgentCapabilities.ocr) {
@ -523,8 +529,20 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
webSearchCallbacks = createOnSearchResults(res);
}
/** @type {Record<string, Record<string, string>>} */
let userMCPAuthMap;
if (await hasCustomUserVars()) {
userMCPAuthMap = await getUserMCPAuthMap({
tools: agent.tools,
userId: req.user.id,
findPluginAuthsByKeys,
});
}
const { loadedTools, toolContextMap } = await loadTools({
agent,
signal,
userMCPAuthMap,
functions: true,
user: req.user.id,
tools: _agentTools,
@ -588,6 +606,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
if (!checkCapability(AgentCapabilities.actions)) {
return {
tools: agentTools,
userMCPAuthMap,
toolContextMap,
};
}
@ -599,6 +618,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
}
return {
tools: agentTools,
userMCPAuthMap,
toolContextMap,
};
}
@ -707,6 +727,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
return {
tools: agentTools,
toolContextMap,
userMCPAuthMap,
};
}

View file

@ -0,0 +1,142 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { findToken, createToken, updateToken, deleteTokens } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { updateMCPUserTools } = require('~/server/services/Config');
const { getLogStores } = require('~/cache');
/**
* @param {Object} params
* @param {ServerRequest} params.req
* @param {string} params.serverName - The name of the MCP server
* @param {boolean} params.returnOnOAuth - Whether to initiate OAuth and return, or wait for OAuth flow to finish
* @param {AbortSignal} [params.signal] - The abort signal to handle cancellation.
* @param {boolean} [params.forceNew]
* @param {number} [params.connectionTimeout]
* @param {FlowStateManager<any>} [params.flowManager]
* @param {(authURL: string) => Promise<boolean>} [params.oauthStart]
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
*/
async function reinitMCPServer({
req,
signal,
forceNew,
serverName,
userMCPAuthMap,
connectionTimeout,
returnOnOAuth = true,
oauthStart: _oauthStart,
flowManager: _flowManager,
}) {
/** @type {MCPConnection | null} */
let userConnection = null;
/** @type {LCAvailableTools | null} */
let availableTools = null;
/** @type {ReturnType<MCPConnection['fetchTools']> | null} */
let tools = null;
let oauthRequired = false;
let oauthUrl = null;
try {
const customUserVars = userMCPAuthMap?.[`${Constants.mcp_prefix}${serverName}`];
const flowManager = _flowManager ?? getFlowStateManager(getLogStores(CacheKeys.FLOWS));
const mcpManager = getMCPManager();
const oauthStart =
_oauthStart ??
(async (authURL) => {
logger.info(`[MCP Reinitialize] OAuth URL received: ${authURL}`);
oauthUrl = authURL;
oauthRequired = true;
});
try {
userConnection = await mcpManager.getUserConnection({
user: req.user,
signal,
forceNew,
oauthStart,
serverName,
flowManager,
returnOnOAuth,
customUserVars,
connectionTimeout,
tokenMethods: {
findToken,
updateToken,
createToken,
deleteTokens,
},
});
logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`);
} catch (err) {
logger.info(`[MCP Reinitialize] getUserConnection threw error: ${err.message}`);
logger.info(
`[MCP Reinitialize] OAuth state - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
);
const isOAuthError =
err.message?.includes('OAuth') ||
err.message?.includes('authentication') ||
err.message?.includes('401');
const isOAuthFlowInitiated = err.message === 'OAuth flow initiated - return early';
if (isOAuthError || oauthRequired || isOAuthFlowInitiated) {
logger.info(
`[MCP Reinitialize] OAuth required for ${serverName} (isOAuthError: ${isOAuthError}, oauthRequired: ${oauthRequired}, isOAuthFlowInitiated: ${isOAuthFlowInitiated})`,
);
oauthRequired = true;
} else {
logger.error(
`[MCP Reinitialize] Error initializing MCP server ${serverName} for user:`,
err,
);
}
}
if (userConnection && !oauthRequired) {
tools = await userConnection.fetchTools();
availableTools = await updateMCPUserTools({
userId: req.user.id,
serverName,
tools,
});
}
logger.debug(
`[MCP Reinitialize] Sending response for ${serverName} - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
);
const getResponseMessage = () => {
if (oauthRequired) {
return `MCP server '${serverName}' ready for OAuth authentication`;
}
if (userConnection) {
return `MCP server '${serverName}' reinitialized successfully`;
}
return `Failed to reinitialize MCP server '${serverName}'`;
};
const result = {
availableTools,
success: Boolean((userConnection && !oauthRequired) || (oauthRequired && oauthUrl)),
message: getResponseMessage(),
oauthRequired,
serverName,
oauthUrl,
tools,
};
logger.debug(`[MCP Reinitialize] Response for ${serverName}:`, result);
return result;
} catch (error) {
logger.error(
'[MCP Reinitialize] Error loading MCP Tools, servers may still be initializing:',
error,
);
}
}
module.exports = {
reinitMCPServer,
};