🚦 feat: Auto-reinitialize MCP Servers on Request (#9226)

This commit is contained in:
Danny Avila 2025-08-23 03:27:05 -04:00
parent ac608ded46
commit c827fdd10e
No known key found for this signature in database
GPG key ID: BF31EEB2C5CA0956
28 changed files with 871 additions and 312 deletions

View file

@ -33,18 +33,13 @@ const {
bedrockInputSchema,
removeNullishValues,
} = require('librechat-data-provider');
const {
findPluginAuthsByKeys,
getFormattedMemories,
deleteMemory,
setMemory,
} = require('~/models');
const { getMCPAuthMap, checkCapability, hasCustomUserVars } = require('~/server/services/Config');
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { getFormattedMemories, deleteMemory, setMemory } = require('~/models');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { getProviderConfig } = require('~/server/services/Endpoints');
const { checkCapability } = require('~/server/services/Config');
const BaseClient = require('~/app/clients/BaseClient');
const { getRoleByName } = require('~/models/Role');
const { loadAgent } = require('~/models/Agent');
@ -615,6 +610,7 @@ class AgentClient extends BaseClient {
await this.chatCompletion({
payload,
onProgress: opts.onProgress,
userMCPAuthMap: opts.userMCPAuthMap,
abortController: opts.abortController,
});
return this.contentParts;
@ -747,7 +743,13 @@ class AgentClient extends BaseClient {
return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate;
}
async chatCompletion({ payload, abortController = null }) {
/**
* @param {object} params
* @param {string | ChatCompletionMessageParam[]} params.payload
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @param {AbortController} [params.abortController]
*/
async chatCompletion({ payload, userMCPAuthMap, abortController = null }) {
/** @type {Partial<GraphRunnableConfig>} */
let config;
/** @type {ReturnType<createRun>} */
@ -903,21 +905,9 @@ class AgentClient extends BaseClient {
run.Graph.contentData = contentData;
}
try {
if (await hasCustomUserVars()) {
config.configurable.userMCPAuthMap = await getMCPAuthMap({
tools: agent.tools,
userId: this.options.req.user.id,
findPluginAuthsByKeys,
});
}
} catch (err) {
logger.error(
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent ${agent.id}`,
err,
);
if (userMCPAuthMap != null) {
config.configurable.userMCPAuthMap = userMCPAuthMap;
}
await run.processStream({ messages }, config, {
keepContent: i !== 0,
tokenCounter: createTokenCounter(this.getEncoding()),

View file

@ -9,6 +9,24 @@ const {
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
const { saveMessage } = require('~/models');
function createCloseHandler(abortController) {
return function (manual) {
if (!manual) {
logger.debug('[AgentController] Request closed');
}
if (!abortController) {
return;
} else if (abortController.signal.aborted) {
return;
} else if (abortController.requestCompleted) {
return;
}
abortController.abort();
logger.debug('[AgentController] Request aborted on close');
};
}
const AgentController = async (req, res, next, initializeClient, addTitle) => {
let {
text,
@ -31,7 +49,6 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
let userMessagePromise;
let getAbortData;
let client = null;
// Initialize as an array
let cleanupHandlers = [];
const newConvo = !conversationId;
@ -62,9 +79,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
// Create a function to handle final cleanup
const performCleanup = () => {
logger.debug('[AgentController] Performing cleanup');
// Make sure cleanupHandlers is an array before iterating
if (Array.isArray(cleanupHandlers)) {
// Execute all cleanup handlers
for (const handler of cleanupHandlers) {
try {
if (typeof handler === 'function') {
@ -105,8 +120,33 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
};
try {
/** @type {{ client: TAgentClient }} */
const result = await initializeClient({ req, res, endpointOption });
let prelimAbortController = new AbortController();
const prelimCloseHandler = createCloseHandler(prelimAbortController);
res.on('close', prelimCloseHandler);
const removePrelimHandler = (manual) => {
try {
prelimCloseHandler(manual);
res.removeListener('close', prelimCloseHandler);
} catch (e) {
logger.error('[AgentController] Error removing close listener', e);
}
};
cleanupHandlers.push(removePrelimHandler);
/** @type {{ client: TAgentClient; userMCPAuthMap?: Record<string, Record<string, string>> }} */
const result = await initializeClient({
req,
res,
endpointOption,
signal: prelimAbortController.signal,
});
if (prelimAbortController.signal?.aborted) {
prelimAbortController = null;
throw new Error('Request was aborted before initialization could complete');
} else {
prelimAbortController = null;
removePrelimHandler(true);
cleanupHandlers.pop();
}
client = result.client;
// Register client with finalization registry if available
@ -138,22 +178,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
};
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
// Simple handler to avoid capturing scope
const closeHandler = () => {
logger.debug('[AgentController] Request closed');
if (!abortController) {
return;
} else if (abortController.signal.aborted) {
return;
} else if (abortController.requestCompleted) {
return;
}
abortController.abort();
logger.debug('[AgentController] Request aborted on close');
};
const closeHandler = createCloseHandler(abortController);
res.on('close', closeHandler);
cleanupHandlers.push(() => {
try {
@ -175,6 +200,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
abortController,
overrideParentMessageId,
isEdited: !!editedContent,
userMCPAuthMap: result.userMCPAuthMap,
responseMessageId: editedResponseMessageId,
progressOptions: {
res,

View file

@ -11,6 +11,7 @@ jest.mock('@librechat/api', () => ({
completeOAuthFlow: jest.fn(),
generateFlowId: jest.fn(),
},
getUserMCPAuthMap: jest.fn(),
}));
jest.mock('@librechat/data-schemas', () => ({
@ -37,6 +38,7 @@ jest.mock('~/models', () => ({
updateToken: jest.fn(),
createToken: jest.fn(),
deleteTokens: jest.fn(),
findPluginAuthsByKeys: jest.fn(),
}));
jest.mock('~/server/services/Config', () => ({
@ -71,6 +73,10 @@ jest.mock('~/server/middleware', () => ({
requireJwtAuth: (req, res, next) => next(),
}));
jest.mock('~/server/services/Tools/mcp', () => ({
reinitMCPServer: jest.fn(),
}));
describe('MCP Routes', () => {
let app;
let mongoServer;
@ -682,6 +688,13 @@ describe('MCP Routes', () => {
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getFlowStateManager.mockReturnValue({});
require('~/cache').getLogStores.mockReturnValue({});
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
success: true,
message: "MCP server 'oauth-server' ready for OAuth authentication",
serverName: 'oauth-server',
oauthRequired: true,
oauthUrl: 'https://oauth.example.com/auth',
});
const response = await request(app).post('/api/mcp/oauth-server/reinitialize');
@ -706,6 +719,7 @@ describe('MCP Routes', () => {
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getFlowStateManager.mockReturnValue({});
require('~/cache').getLogStores.mockReturnValue({});
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue(null);
const response = await request(app).post('/api/mcp/error-server/reinitialize');
@ -769,6 +783,14 @@ describe('MCP Routes', () => {
setCachedTools.mockResolvedValue();
updateMCPUserTools.mockResolvedValue();
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
success: true,
message: "MCP server 'test-server' reinitialized successfully",
serverName: 'test-server',
oauthRequired: false,
oauthUrl: null,
});
const response = await request(app).post('/api/mcp/test-server/reinitialize');
expect(response.status).toBe(200);
@ -783,14 +805,6 @@ describe('MCP Routes', () => {
'test-user-id',
'test-server',
);
expect(updateMCPUserTools).toHaveBeenCalledWith({
userId: 'test-user-id',
serverName: 'test-server',
tools: [
{ name: 'tool1', description: 'Test tool 1', inputSchema: { type: 'object' } },
{ name: 'tool2', description: 'Test tool 2', inputSchema: { type: 'object' } },
],
});
});
it('should handle server with custom user variables', async () => {
@ -812,9 +826,14 @@ describe('MCP Routes', () => {
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getFlowStateManager.mockReturnValue({});
require('~/cache').getLogStores.mockReturnValue({});
require('~/server/services/PluginService').getUserPluginAuthValue.mockResolvedValue(
'api-key-value',
);
require('@librechat/api').getUserMCPAuthMap.mockResolvedValue({
'mcp:test-server': {
API_KEY: 'api-key-value',
},
});
require('~/models').findPluginAuthsByKeys.mockResolvedValue([
{ key: 'API_KEY', value: 'api-key-value' },
]);
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
@ -822,13 +841,23 @@ describe('MCP Routes', () => {
setCachedTools.mockResolvedValue();
updateMCPUserTools.mockResolvedValue();
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
success: true,
message: "MCP server 'test-server' reinitialized successfully",
serverName: 'test-server',
oauthRequired: false,
oauthUrl: null,
});
const response = await request(app).post('/api/mcp/test-server/reinitialize');
expect(response.status).toBe(200);
expect(response.body.success).toBe(true);
expect(
require('~/server/services/PluginService').getUserPluginAuthValue,
).toHaveBeenCalledWith('test-user-id', 'API_KEY', false);
expect(require('@librechat/api').getUserMCPAuthMap).toHaveBeenCalledWith({
userId: 'test-user-id',
servers: ['test-server'],
findPluginAuthsByKeys: require('~/models').findPluginAuthsByKeys,
});
});
});

View file

@ -1,13 +1,15 @@
const { logger } = require('@librechat/data-schemas');
const { MCPOAuthHandler } = require('@librechat/api');
const { Router } = require('express');
const { logger } = require('@librechat/data-schemas');
const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api');
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
const { requireJwtAuth } = require('~/server/middleware');
const { findPluginAuthsByKeys } = require('~/models');
const { getLogStores } = require('~/cache');
const router = Router();
@ -302,107 +304,39 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
});
}
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
await mcpManager.disconnectUserConnection(user.id, serverName);
logger.info(
`[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`,
);
let customUserVars = {};
/** @type {Record<string, Record<string, string>> | undefined} */
let userMCPAuthMap;
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
for (const varName of Object.keys(serverConfig.customUserVars)) {
try {
const value = await getUserPluginAuthValue(user.id, varName, false);
customUserVars[varName] = value;
} catch (err) {
logger.error(`[MCP Reinitialize] Error fetching ${varName} for user ${user.id}:`, err);
}
}
}
let userConnection = null;
let oauthRequired = false;
let oauthUrl = null;
try {
userConnection = await mcpManager.getUserConnection({
user,
serverName,
flowManager,
customUserVars,
tokenMethods: {
findToken,
updateToken,
createToken,
deleteTokens,
},
returnOnOAuth: true,
oauthStart: async (authURL) => {
logger.info(`[MCP Reinitialize] OAuth URL received: ${authURL}`);
oauthUrl = authURL;
oauthRequired = true;
},
});
logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`);
} catch (err) {
logger.info(`[MCP Reinitialize] getUserConnection threw error: ${err.message}`);
logger.info(
`[MCP Reinitialize] OAuth state - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
);
const isOAuthError =
err.message?.includes('OAuth') ||
err.message?.includes('authentication') ||
err.message?.includes('401');
const isOAuthFlowInitiated = err.message === 'OAuth flow initiated - return early';
if (isOAuthError || oauthRequired || isOAuthFlowInitiated) {
logger.info(
`[MCP Reinitialize] OAuth required for ${serverName} (isOAuthError: ${isOAuthError}, oauthRequired: ${oauthRequired}, isOAuthFlowInitiated: ${isOAuthFlowInitiated})`,
);
oauthRequired = true;
} else {
logger.error(
`[MCP Reinitialize] Error initializing MCP server ${serverName} for user:`,
err,
);
return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' });
}
}
if (userConnection && !oauthRequired) {
const tools = await userConnection.fetchTools();
await updateMCPUserTools({
userMCPAuthMap = await getUserMCPAuthMap({
userId: user.id,
serverName,
tools,
servers: [serverName],
findPluginAuthsByKeys,
});
}
logger.debug(
`[MCP Reinitialize] Sending response for ${serverName} - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
);
const result = await reinitMCPServer({
req,
serverName,
userMCPAuthMap,
});
const getResponseMessage = () => {
if (oauthRequired) {
return `MCP server '${serverName}' ready for OAuth authentication`;
}
if (userConnection) {
return `MCP server '${serverName}' reinitialized successfully`;
}
return `Failed to reinitialize MCP server '${serverName}'`;
};
if (!result) {
return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' });
}
const { success, message, oauthRequired, oauthUrl } = result;
res.json({
success: Boolean((userConnection && !oauthRequired) || (oauthRequired && oauthUrl)),
message: getResponseMessage(),
success,
message,
oauthUrl,
serverName,
oauthRequired,
oauthUrl,
});
} catch (error) {
logger.error('[MCP Reinitialize] Unexpected error', error);

View file

@ -26,7 +26,7 @@ const ToolCacheKeys = {
* @param {string[]} [options.roleIds] - Role IDs for role-based tools
* @param {string[]} [options.groupIds] - Group IDs for group-based tools
* @param {boolean} [options.includeGlobal=true] - Whether to include global tools
* @returns {Promise<Object|null>} The available tools object or null if not cached
* @returns {Promise<LCAvailableTools|null>} The available tools object or null if not cached
*/
async function getCachedTools(options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
@ -41,13 +41,13 @@ async function getCachedTools(options = {}) {
// Future implementation will merge tools from multiple sources
// based on user permissions, roles, and groups
if (userId) {
// Check if we have pre-computed effective tools for this user
/** @type {LCAvailableTools | null} Check if we have pre-computed effective tools for this user */
const effectiveTools = await cache.get(ToolCacheKeys.EFFECTIVE(userId));
if (effectiveTools) {
return effectiveTools;
}
// Otherwise, compute from individual sources
/** @type {LCAvailableTools | null} Otherwise, compute from individual sources */
const toolSources = [];
if (includeGlobal) {

View file

@ -1,5 +1,4 @@
const { logger } = require('@librechat/data-schemas');
const { isEnabled, getUserMCPAuthMap } = require('@librechat/api');
const { isEnabled } = require('@librechat/api');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const { normalizeEndpointName } = require('~/server/utils');
const loadCustomConfig = require('./loadCustomConfig');
@ -53,31 +52,6 @@ const getCustomEndpointConfig = async (endpoint) => {
);
};
/**
* @param {Object} params
* @param {string} params.userId
* @param {GenericTool[]} [params.tools]
* @param {import('@librechat/data-schemas').PluginAuthMethods['findPluginAuthsByKeys']} params.findPluginAuthsByKeys
* @returns {Promise<Record<string, Record<string, string>> | undefined>}
*/
async function getMCPAuthMap({ userId, tools, findPluginAuthsByKeys }) {
try {
if (!tools || tools.length === 0) {
return;
}
return await getUserMCPAuthMap({
tools,
userId,
findPluginAuthsByKeys,
});
} catch (err) {
logger.error(
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent`,
err,
);
}
}
/**
* @returns {Promise<boolean>}
*/
@ -88,7 +62,6 @@ async function hasCustomUserVars() {
}
module.exports = {
getMCPAuthMap,
getCustomConfig,
getBalanceConfig,
hasCustomUserVars,

View file

@ -9,7 +9,7 @@ const { getLogStores } = require('~/cache');
* @param {string} params.userId - User ID
* @param {string} params.serverName - MCP server name
* @param {Array} params.tools - Array of tool objects from MCP server
* @returns {Promise<void>}
* @returns {Promise<LCAvailableTools>}
*/
async function updateMCPUserTools({ userId, serverName, tools }) {
try {
@ -39,6 +39,7 @@ async function updateMCPUserTools({ userId, serverName, tools }) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(`[MCP Cache] Updated ${tools.length} tools for ${serverName} user ${userId}`);
return userTools;
} catch (error) {
logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error);
throw error;

View file

@ -30,7 +30,13 @@ const { getModelMaxTokens } = require('~/utils');
* @param {TEndpointOption} [params.endpointOption]
* @param {Set<string>} [params.allowedProviders]
* @param {boolean} [params.isInitialAgent]
* @returns {Promise<Agent & { tools: StructuredTool[], attachments: Array<MongoFile>, toolContextMap: Record<string, unknown>, maxContextTokens: number }>}
* @returns {Promise<Agent & {
* tools: StructuredTool[],
* attachments: Array<MongoFile>,
* toolContextMap: Record<string, unknown>,
* maxContextTokens: number,
* userMCPAuthMap?: Record<string, Record<string, string>>
* }>}
*/
const initializeAgent = async ({
req,
@ -91,16 +97,19 @@ const initializeAgent = async ({
});
const provider = agent.provider;
const { tools: structuredTools, toolContextMap } =
(await loadTools?.({
req,
res,
provider,
agentId: agent.id,
tools: agent.tools,
model: agent.model,
tool_resources,
})) ?? {};
const {
tools: structuredTools,
toolContextMap,
userMCPAuthMap,
} = (await loadTools?.({
req,
res,
provider,
agentId: agent.id,
tools: agent.tools,
model: agent.model,
tool_resources,
})) ?? {};
agent.endpoint = provider;
const { getOptions, overrideProvider } = await getProviderConfig(provider);
@ -189,6 +198,7 @@ const initializeAgent = async ({
tools,
attachments,
resendFiles,
userMCPAuthMap,
toolContextMap,
useLegacyContent: !!options.useLegacyContent,
maxContextTokens: Math.round((agentMaxContextTokens - maxTokens) * 0.9),

View file

@ -19,7 +19,10 @@ const AgentClient = require('~/server/controllers/agents/client');
const { getAgent } = require('~/models/Agent');
const { logViolation } = require('~/cache');
function createToolLoader() {
/**
* @param {AbortSignal} signal
*/
function createToolLoader(signal) {
/**
* @param {object} params
* @param {ServerRequest} params.req
@ -29,7 +32,11 @@ function createToolLoader() {
* @param {string} params.provider
* @param {string} params.model
* @param {AgentToolResources} params.tool_resources
* @returns {Promise<{ tools: StructuredTool[], toolContextMap: Record<string, unknown> } | undefined>}
* @returns {Promise<{
* tools: StructuredTool[],
* toolContextMap: Record<string, unknown>,
* userMCPAuthMap?: Record<string, Record<string, string>>
* } | undefined>}
*/
return async function loadTools({ req, res, agentId, tools, provider, model, tool_resources }) {
const agent = { id: agentId, tools, provider, model };
@ -38,6 +45,7 @@ function createToolLoader() {
req,
res,
agent,
signal,
tool_resources,
});
} catch (error) {
@ -46,7 +54,7 @@ function createToolLoader() {
};
}
const initializeClient = async ({ req, res, endpointOption }) => {
const initializeClient = async ({ req, res, signal, endpointOption }) => {
if (!endpointOption) {
throw new Error('Endpoint option not provided');
}
@ -92,7 +100,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
/** @type {Set<string>} */
const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders);
const loadTools = createToolLoader();
const loadTools = createToolLoader(signal);
/** @type {Array<MongoFile>} */
const requestFiles = req.body.files ?? [];
/** @type {string} */
@ -111,6 +119,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
});
const agent_ids = primaryConfig.agent_ids;
let userMCPAuthMap = primaryConfig.userMCPAuthMap;
if (agent_ids?.length) {
for (const agentId of agent_ids) {
const agent = await getAgent({ id: agentId });
@ -140,6 +149,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
endpointOption,
allowedProviders,
});
Object.assign(userMCPAuthMap, config.userMCPAuthMap ?? {});
agentConfigs.set(agentId, config);
}
}
@ -188,7 +198,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
: EModelEndpoint.agents,
});
return { client };
return { client, userMCPAuthMap };
};
module.exports = { initializeClient };

View file

@ -1,7 +1,12 @@
const { z } = require('zod');
const { tool } = require('@langchain/core/tools');
const { logger } = require('@librechat/data-schemas');
const { Constants: AgentConstants, Providers, GraphEvents } = require('@librechat/agents');
const {
Providers,
StepTypes,
GraphEvents,
Constants: AgentConstants,
} = require('@librechat/agents');
const {
sendEvent,
MCPOAuthHandler,
@ -11,14 +16,14 @@ const {
const {
Time,
CacheKeys,
StepTypes,
Constants,
ContentTypes,
isAssistantsEndpoint,
} = require('librechat-data-provider');
const { getCachedTools, loadCustomConfig } = require('./Config');
const { findToken, createToken, updateToken } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { getCachedTools, loadCustomConfig } = require('./Config');
const { reinitMCPServer } = require('./Tools/mcp');
const { getLogStores } = require('~/cache');
/**
@ -26,16 +31,13 @@ const { getLogStores } = require('~/cache');
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.stepId - The ID of the step in the flow.
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
* @param {string} params.loginFlowId - The ID of the login flow.
* @param {FlowStateManager<any>} params.flowManager - The flow manager instance.
*/
function createOAuthStart({ res, stepId, toolCall, loginFlowId, flowManager, signal }) {
function createRunStepDeltaEmitter({ res, stepId, toolCall }) {
/**
* Creates a function to handle OAuth login requests.
* @param {string} authURL - The URL to redirect the user for OAuth authentication.
* @returns {Promise<boolean>} Returns true to indicate the event was sent successfully.
* @returns {void}
*/
return async function (authURL) {
return function (authURL) {
/** @type {{ id: string; delta: AgentToolCallDelta }} */
const data = {
id: stepId,
@ -46,17 +48,54 @@ function createOAuthStart({ res, stepId, toolCall, loginFlowId, flowManager, sig
expires_at: Date.now() + Time.TWO_MINUTES,
},
};
/** Used to ensure the handler (use of `sendEvent`) is only invoked once */
await flowManager.createFlowWithHandler(
loginFlowId,
'oauth_login',
async () => {
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
logger.debug('Sent OAuth login request to client');
return true;
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
};
}
/**
* @param {object} params
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.runId - The Run ID, i.e. message ID
* @param {string} params.stepId - The ID of the step in the flow.
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
* @param {number} [params.index]
*/
function createRunStepEmitter({ res, runId, stepId, toolCall, index }) {
return function () {
/** @type {import('@librechat/agents').RunStep} */
const data = {
runId: runId ?? Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
id: stepId,
type: StepTypes.TOOL_CALLS,
index: index ?? 0,
stepDetails: {
type: StepTypes.TOOL_CALLS,
tool_calls: [toolCall],
},
signal,
);
};
sendEvent(res, { event: GraphEvents.ON_RUN_STEP, data });
};
}
/**
* Creates a function used to ensure the flow handler is only invoked once
* @param {object} params
* @param {string} params.flowId - The ID of the login flow.
* @param {FlowStateManager<any>} params.flowManager - The flow manager instance.
* @param {(authURL: string) => void} [params.callback]
*/
function createOAuthStart({ flowId, flowManager, callback }) {
/**
* Creates a function to handle OAuth login requests.
* @param {string} authURL - The URL to redirect the user for OAuth authentication.
* @returns {Promise<boolean>} Returns true to indicate the event was sent successfully.
*/
return async function (authURL) {
await flowManager.createFlowWithHandler(flowId, 'oauth_login', async () => {
callback?.(authURL);
logger.debug('Sent OAuth login request to client');
return true;
});
};
}
@ -99,23 +138,166 @@ function createAbortHandler({ userId, serverName, toolName, flowManager }) {
}
/**
* Creates a general tool for an entire action set.
* @param {Object} params
* @param {() => void} params.runStepEmitter
* @param {(authURL: string) => void} params.runStepDeltaEmitter
* @returns {(authURL: string) => void}
*/
function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
return function (authURL) {
runStepEmitter();
runStepDeltaEmitter(authURL);
};
}
/**
* @param {Object} params
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.serverName
* @param {AbortSignal} params.signal
* @param {string} params.model
* @param {number} [params.index]
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
*/
async function reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap }) {
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
const flowId = `${req.user?.id}:${serverName}:${Date.now()}`;
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
const stepId = 'step_oauth_login_' + serverName;
const toolCall = {
id: flowId,
name: serverName,
type: 'tool_call_chunk',
};
const runStepEmitter = createRunStepEmitter({
res,
index,
runId,
stepId,
toolCall,
});
const runStepDeltaEmitter = createRunStepDeltaEmitter({
res,
stepId,
toolCall,
});
const callback = createOAuthCallback({ runStepEmitter, runStepDeltaEmitter });
const oauthStart = createOAuthStart({
res,
flowId,
callback,
flowManager,
});
return await reinitMCPServer({
req,
signal,
serverName,
oauthStart,
flowManager,
userMCPAuthMap,
forceNew: true,
returnOnOAuth: false,
connectionTimeout: Time.TWO_MINUTES,
});
}
/**
* Creates all tools from the specified MCP Server via `toolKey`.
*
* @param {Object} params - The parameters for loading action sets.
* This function assumes tools could not be aggregated from the cache of tool definitions,
* i.e. `availableTools`, and will reinitialize the MCP server to ensure all tools are generated.
*
* @param {Object} params
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.serverName
* @param {string} params.model
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
* @param {number} [params.index]
* @param {AbortSignal} [params.signal]
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
*/
async function createMCPTools({ req, res, index, signal, serverName, provider, userMCPAuthMap }) {
const result = await reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap });
if (!result || !result.tools) {
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
return;
}
const serverTools = [];
for (const tool of result.tools) {
const toolInstance = await createMCPTool({
req,
res,
provider,
userMCPAuthMap,
availableTools: result.availableTools,
toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`,
});
if (toolInstance) {
serverTools.push(toolInstance);
}
}
return serverTools;
}
/**
* Creates a single tool from the specified MCP Server via `toolKey`.
* @param {Object} params
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.toolKey - The toolKey for the tool.
* @param {import('@librechat/agents').Providers | EModelEndpoint} params.provider - The provider for the tool.
* @param {string} params.model - The model for the tool.
* @param {number} [params.index]
* @param {AbortSignal} [params.signal]
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
* @param {LCAvailableTools} [params.availableTools]
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
*/
async function createMCPTool({ req, res, toolKey, provider: _provider }) {
const availableTools = await getCachedTools({ userId: req.user?.id, includeGlobal: true });
const toolDefinition = availableTools?.[toolKey]?.function;
async function createMCPTool({
req,
res,
index,
signal,
toolKey,
provider,
userMCPAuthMap,
availableTools: tools,
}) {
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
const availableTools =
tools ?? (await getCachedTools({ userId: req.user?.id, includeGlobal: true }));
/** @type {LCTool | undefined} */
let toolDefinition = availableTools?.[toolKey]?.function;
if (!toolDefinition) {
logger.error(`Tool ${toolKey} not found in available tools`);
return null;
logger.warn(
`[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`,
);
const result = await reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap });
toolDefinition = result?.availableTools?.[toolKey]?.function;
}
if (!toolDefinition) {
logger.warn(`[MCP][${serverName}][${toolName}] Tool definition not found, cannot create tool.`);
return;
}
return createToolInstance({
res,
provider,
toolName,
serverName,
toolDefinition,
});
}
function createToolInstance({ res, toolName, serverName, toolDefinition, provider: _provider }) {
/** @type {LCTool} */
const { description, parameters } = toolDefinition;
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
@ -128,16 +310,8 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
schema = z.object({ input: z.string().optional() });
}
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
const normalizedToolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`;
if (!req.user?.id) {
logger.error(
`[MCP][${serverName}][${toolName}] User ID not found on request. Cannot create tool.`,
);
throw new Error(`User ID not found on request. Cannot create tool for ${toolKey}.`);
}
/** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise<unknown>} */
const _call = async (toolArguments, config) => {
const userId = config?.configurable?.user?.id || config?.configurable?.user_id;
@ -154,14 +328,16 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
const provider = (config?.metadata?.provider || _provider)?.toLowerCase();
const { args: _args, stepId, ...toolCall } = config.toolCall ?? {};
const loginFlowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`;
const oauthStart = createOAuthStart({
const flowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`;
const runStepDeltaEmitter = createRunStepDeltaEmitter({
res,
stepId,
toolCall,
loginFlowId,
});
const oauthStart = createOAuthStart({
flowId,
flowManager,
signal: derivedSignal,
callback: runStepDeltaEmitter,
});
const oauthEnd = createOAuthEnd({
res,
@ -207,7 +383,7 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
return result;
} catch (error) {
logger.error(
`[MCP][User: ${userId}][${serverName}] Error calling "${toolName}" MCP tool:`,
`[MCP][${serverName}][${toolName}][User: ${userId}] Error calling MCP tool:`,
error,
);
@ -220,12 +396,12 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
if (isOAuthError) {
throw new Error(
`OAuth authentication required for ${serverName}. Please check the server logs for the authentication URL.`,
`[MCP][${serverName}][${toolName}] OAuth authentication required. Please check the server logs for the authentication URL.`,
);
}
throw new Error(
`"${toolKey}" tool call failed${error?.message ? `: ${error?.message}` : '.'}`,
`[MCP][${serverName}][${toolName}] tool call failed${error?.message ? `: ${error?.message}` : '.'}`,
);
} finally {
// Clean up abort handler to prevent memory leaks
@ -380,6 +556,7 @@ async function getServerConnectionStatus(
module.exports = {
createMCPTool,
createMCPTools,
getMCPSetupData,
checkOAuthFlowStatus,
getServerConnectionStatus,

View file

@ -1,9 +1,9 @@
const fs = require('fs');
const path = require('path');
const { sleep } = require('@librechat/agents');
const { getToolkitKey } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { zodToJsonSchema } = require('zod-to-json-schema');
const { getToolkitKey, getUserMCPAuthMap } = require('@librechat/api');
const { Calculator } = require('@langchain/community/tools/calculator');
const { tool: toolFn, Tool, DynamicStructuredTool } = require('@langchain/core/tools');
const {
@ -33,12 +33,17 @@ const {
toolkits,
} = require('~/app/clients/tools');
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
const { getEndpointsConfig, getCachedTools } = require('~/server/services/Config');
const {
getEndpointsConfig,
hasCustomUserVars,
getCachedTools,
} = require('~/server/services/Config');
const { createOnSearchResults } = require('~/server/services/Tools/search');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { recordUsage } = require('~/server/services/Threads');
const { loadTools } = require('~/app/clients/tools/util');
const { redactMessage } = require('~/config/parsers');
const { findPluginAuthsByKeys } = require('~/models');
/**
* Loads and formats tools from the specified tool directory.
@ -469,11 +474,12 @@ async function processRequiredActions(client, requiredActions) {
* @param {Object} params - Run params containing user and request information.
* @param {ServerRequest} params.req - The request object.
* @param {ServerResponse} params.res - The request object.
* @param {AbortSignal} params.signal
* @param {Pick<Agent, 'id' | 'provider' | 'model' | 'tools'} params.agent - The agent to load tools for.
* @param {string | undefined} [params.openAIApiKey] - The OpenAI API key.
* @returns {Promise<{ tools?: StructuredTool[] }>} The agent tools.
* @returns {Promise<{ tools?: StructuredTool[]; userMCPAuthMap?: Record<string, Record<string, string>> }>} The agent tools.
*/
async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) {
async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIApiKey }) {
if (!agent.tools || agent.tools.length === 0) {
return {};
} else if (agent.tools && agent.tools.length === 1 && agent.tools[0] === AgentCapabilities.ocr) {
@ -523,8 +529,20 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
webSearchCallbacks = createOnSearchResults(res);
}
/** @type {Record<string, Record<string, string>>} */
let userMCPAuthMap;
if (await hasCustomUserVars()) {
userMCPAuthMap = await getUserMCPAuthMap({
tools: agent.tools,
userId: req.user.id,
findPluginAuthsByKeys,
});
}
const { loadedTools, toolContextMap } = await loadTools({
agent,
signal,
userMCPAuthMap,
functions: true,
user: req.user.id,
tools: _agentTools,
@ -588,6 +606,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
if (!checkCapability(AgentCapabilities.actions)) {
return {
tools: agentTools,
userMCPAuthMap,
toolContextMap,
};
}
@ -599,6 +618,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
}
return {
tools: agentTools,
userMCPAuthMap,
toolContextMap,
};
}
@ -707,6 +727,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
return {
tools: agentTools,
toolContextMap,
userMCPAuthMap,
};
}

View file

@ -0,0 +1,142 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { findToken, createToken, updateToken, deleteTokens } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { updateMCPUserTools } = require('~/server/services/Config');
const { getLogStores } = require('~/cache');
/**
* @param {Object} params
* @param {ServerRequest} params.req
* @param {string} params.serverName - The name of the MCP server
* @param {boolean} params.returnOnOAuth - Whether to initiate OAuth and return, or wait for OAuth flow to finish
* @param {AbortSignal} [params.signal] - The abort signal to handle cancellation.
* @param {boolean} [params.forceNew]
* @param {number} [params.connectionTimeout]
* @param {FlowStateManager<any>} [params.flowManager]
* @param {(authURL: string) => Promise<boolean>} [params.oauthStart]
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
*/
async function reinitMCPServer({
req,
signal,
forceNew,
serverName,
userMCPAuthMap,
connectionTimeout,
returnOnOAuth = true,
oauthStart: _oauthStart,
flowManager: _flowManager,
}) {
/** @type {MCPConnection | null} */
let userConnection = null;
/** @type {LCAvailableTools | null} */
let availableTools = null;
/** @type {ReturnType<MCPConnection['fetchTools']> | null} */
let tools = null;
let oauthRequired = false;
let oauthUrl = null;
try {
const customUserVars = userMCPAuthMap?.[`${Constants.mcp_prefix}${serverName}`];
const flowManager = _flowManager ?? getFlowStateManager(getLogStores(CacheKeys.FLOWS));
const mcpManager = getMCPManager();
const oauthStart =
_oauthStart ??
(async (authURL) => {
logger.info(`[MCP Reinitialize] OAuth URL received: ${authURL}`);
oauthUrl = authURL;
oauthRequired = true;
});
try {
userConnection = await mcpManager.getUserConnection({
user: req.user,
signal,
forceNew,
oauthStart,
serverName,
flowManager,
returnOnOAuth,
customUserVars,
connectionTimeout,
tokenMethods: {
findToken,
updateToken,
createToken,
deleteTokens,
},
});
logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`);
} catch (err) {
logger.info(`[MCP Reinitialize] getUserConnection threw error: ${err.message}`);
logger.info(
`[MCP Reinitialize] OAuth state - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
);
const isOAuthError =
err.message?.includes('OAuth') ||
err.message?.includes('authentication') ||
err.message?.includes('401');
const isOAuthFlowInitiated = err.message === 'OAuth flow initiated - return early';
if (isOAuthError || oauthRequired || isOAuthFlowInitiated) {
logger.info(
`[MCP Reinitialize] OAuth required for ${serverName} (isOAuthError: ${isOAuthError}, oauthRequired: ${oauthRequired}, isOAuthFlowInitiated: ${isOAuthFlowInitiated})`,
);
oauthRequired = true;
} else {
logger.error(
`[MCP Reinitialize] Error initializing MCP server ${serverName} for user:`,
err,
);
}
}
if (userConnection && !oauthRequired) {
tools = await userConnection.fetchTools();
availableTools = await updateMCPUserTools({
userId: req.user.id,
serverName,
tools,
});
}
logger.debug(
`[MCP Reinitialize] Sending response for ${serverName} - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
);
const getResponseMessage = () => {
if (oauthRequired) {
return `MCP server '${serverName}' ready for OAuth authentication`;
}
if (userConnection) {
return `MCP server '${serverName}' reinitialized successfully`;
}
return `Failed to reinitialize MCP server '${serverName}'`;
};
const result = {
availableTools,
success: Boolean((userConnection && !oauthRequired) || (oauthRequired && oauthUrl)),
message: getResponseMessage(),
oauthRequired,
serverName,
oauthUrl,
tools,
};
logger.debug(`[MCP Reinitialize] Response for ${serverName}:`, result);
return result;
} catch (error) {
logger.error(
'[MCP Reinitialize] Error loading MCP Tools, servers may still be initializing:',
error,
);
}
}
module.exports = {
reinitMCPServer,
};