🚦 feat: Auto-reinitialize MCP Servers on Request (#9226)

This commit is contained in:
Danny Avila 2025-08-23 03:27:05 -04:00
parent ac608ded46
commit c827fdd10e
No known key found for this signature in database
GPG key ID: BF31EEB2C5CA0956
28 changed files with 871 additions and 312 deletions

View file

@ -3,7 +3,7 @@ const { SerpAPI } = require('@langchain/community/tools/serpapi');
const { Calculator } = require('@langchain/community/tools/calculator'); const { Calculator } = require('@langchain/community/tools/calculator');
const { mcpToolPattern, loadWebSearchAuth } = require('@librechat/api'); const { mcpToolPattern, loadWebSearchAuth } = require('@librechat/api');
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents'); const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
const { Tools, EToolResources, replaceSpecialVars } = require('librechat-data-provider'); const { Tools, Constants, EToolResources, replaceSpecialVars } = require('librechat-data-provider');
const { const {
availableTools, availableTools,
manifestToolMap, manifestToolMap,
@ -24,9 +24,9 @@ const {
const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch'); const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch');
const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { createMCPTool, createMCPTools } = require('~/server/services/MCP');
const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { getCachedTools } = require('~/server/services/Config'); const { getCachedTools } = require('~/server/services/Config');
const { createMCPTool } = require('~/server/services/MCP');
/** /**
* Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values. * Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values.
@ -123,6 +123,8 @@ const getAuthFields = (toolKey) => {
* *
* @param {object} object * @param {object} object
* @param {string} object.user * @param {string} object.user
* @param {Record<string, Record<string, string>>} [object.userMCPAuthMap]
* @param {AbortSignal} [object.signal]
* @param {Pick<Agent, 'id' | 'provider' | 'model'>} [object.agent] * @param {Pick<Agent, 'id' | 'provider' | 'model'>} [object.agent]
* @param {string} [object.model] * @param {string} [object.model]
* @param {EModelEndpoint} [object.endpoint] * @param {EModelEndpoint} [object.endpoint]
@ -137,7 +139,9 @@ const loadTools = async ({
user, user,
agent, agent,
model, model,
signal,
endpoint, endpoint,
userMCPAuthMap,
tools = [], tools = [],
options = {}, options = {},
functions = true, functions = true,
@ -231,6 +235,7 @@ const loadTools = async ({
/** @type {Record<string, string>} */ /** @type {Record<string, string>} */
const toolContextMap = {}; const toolContextMap = {};
const cachedTools = (await getCachedTools({ userId: user, includeGlobal: true })) ?? {}; const cachedTools = (await getCachedTools({ userId: user, includeGlobal: true })) ?? {};
const requestedMCPTools = {};
for (const tool of tools) { for (const tool of tools) {
if (tool === Tools.execute_code) { if (tool === Tools.execute_code) {
@ -299,14 +304,35 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
}; };
continue; continue;
} else if (tool && cachedTools && mcpToolPattern.test(tool)) { } else if (tool && cachedTools && mcpToolPattern.test(tool)) {
requestedTools[tool] = async () => const [toolName, serverName] = tool.split(Constants.mcp_delimiter);
if (toolName === Constants.mcp_all) {
const currentMCPGenerator = async (index) =>
createMCPTools({
req: options.req,
res: options.res,
index,
serverName,
userMCPAuthMap,
model: agent?.model ?? model,
provider: agent?.provider ?? endpoint,
signal,
});
requestedMCPTools[serverName] = [currentMCPGenerator];
continue;
}
const currentMCPGenerator = async (index) =>
createMCPTool({ createMCPTool({
index,
req: options.req, req: options.req,
res: options.res, res: options.res,
toolKey: tool, toolKey: tool,
userMCPAuthMap,
model: agent?.model ?? model, model: agent?.model ?? model,
provider: agent?.provider ?? endpoint, provider: agent?.provider ?? endpoint,
signal,
}); });
requestedMCPTools[serverName] = requestedMCPTools[serverName] || [];
requestedMCPTools[serverName].push(currentMCPGenerator);
continue; continue;
} }
@ -346,6 +372,34 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
} }
const loadedTools = (await Promise.all(toolPromises)).flatMap((plugin) => plugin || []); const loadedTools = (await Promise.all(toolPromises)).flatMap((plugin) => plugin || []);
const mcpToolPromises = [];
/** MCP server tools are initialized sequentially by server */
let index = -1;
for (const [serverName, generators] of Object.entries(requestedMCPTools)) {
index++;
for (const generator of generators) {
try {
if (generator && generators.length === 1) {
mcpToolPromises.push(
generator(index).catch((error) => {
logger.error(`Error loading ${serverName} tools:`, error);
return null;
}),
);
continue;
}
const mcpTool = await generator(index);
if (Array.isArray(mcpTool)) {
loadedTools.push(...mcpTool);
} else if (mcpTool) {
loadedTools.push(mcpTool);
}
} catch (error) {
logger.error(`Error loading MCP tool for server ${serverName}:`, error);
}
}
}
loadedTools.push(...(await Promise.all(mcpToolPromises)).flatMap((plugin) => plugin || []));
return { loadedTools, toolContextMap }; return { loadedTools, toolContextMap };
}; };

View file

@ -2,7 +2,7 @@ const mongoose = require('mongoose');
const crypto = require('node:crypto'); const crypto = require('node:crypto');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { ResourceType, SystemRoles, Tools, actionDelimiter } = require('librechat-data-provider'); const { ResourceType, SystemRoles, Tools, actionDelimiter } = require('librechat-data-provider');
const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_delimiter } = const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_all, mcp_delimiter } =
require('librechat-data-provider').Constants; require('librechat-data-provider').Constants;
const { const {
removeAgentFromAllProjects, removeAgentFromAllProjects,
@ -78,6 +78,7 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
tools.push(Tools.web_search); tools.push(Tools.web_search);
} }
const addedServers = new Set();
if (mcpServers.size > 0) { if (mcpServers.size > 0) {
for (const toolName of Object.keys(availableTools)) { for (const toolName of Object.keys(availableTools)) {
if (!toolName.includes(mcp_delimiter)) { if (!toolName.includes(mcp_delimiter)) {
@ -85,9 +86,17 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
} }
const mcpServer = toolName.split(mcp_delimiter)?.[1]; const mcpServer = toolName.split(mcp_delimiter)?.[1];
if (mcpServer && mcpServers.has(mcpServer)) { if (mcpServer && mcpServers.has(mcpServer)) {
addedServers.add(mcpServer);
tools.push(toolName); tools.push(toolName);
} }
} }
for (const mcpServer of mcpServers) {
if (addedServers.has(mcpServer)) {
continue;
}
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
}
} }
const instructions = req.body.promptPrefix; const instructions = req.body.promptPrefix;

View file

@ -33,18 +33,13 @@ const {
bedrockInputSchema, bedrockInputSchema,
removeNullishValues, removeNullishValues,
} = require('librechat-data-provider'); } = require('librechat-data-provider');
const {
findPluginAuthsByKeys,
getFormattedMemories,
deleteMemory,
setMemory,
} = require('~/models');
const { getMCPAuthMap, checkCapability, hasCustomUserVars } = require('~/server/services/Config');
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts'); const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent'); const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { getFormattedMemories, deleteMemory, setMemory } = require('~/models');
const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { getProviderConfig } = require('~/server/services/Endpoints'); const { getProviderConfig } = require('~/server/services/Endpoints');
const { checkCapability } = require('~/server/services/Config');
const BaseClient = require('~/app/clients/BaseClient'); const BaseClient = require('~/app/clients/BaseClient');
const { getRoleByName } = require('~/models/Role'); const { getRoleByName } = require('~/models/Role');
const { loadAgent } = require('~/models/Agent'); const { loadAgent } = require('~/models/Agent');
@ -615,6 +610,7 @@ class AgentClient extends BaseClient {
await this.chatCompletion({ await this.chatCompletion({
payload, payload,
onProgress: opts.onProgress, onProgress: opts.onProgress,
userMCPAuthMap: opts.userMCPAuthMap,
abortController: opts.abortController, abortController: opts.abortController,
}); });
return this.contentParts; return this.contentParts;
@ -747,7 +743,13 @@ class AgentClient extends BaseClient {
return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate; return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate;
} }
async chatCompletion({ payload, abortController = null }) { /**
* @param {object} params
* @param {string | ChatCompletionMessageParam[]} params.payload
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @param {AbortController} [params.abortController]
*/
async chatCompletion({ payload, userMCPAuthMap, abortController = null }) {
/** @type {Partial<GraphRunnableConfig>} */ /** @type {Partial<GraphRunnableConfig>} */
let config; let config;
/** @type {ReturnType<createRun>} */ /** @type {ReturnType<createRun>} */
@ -903,21 +905,9 @@ class AgentClient extends BaseClient {
run.Graph.contentData = contentData; run.Graph.contentData = contentData;
} }
try { if (userMCPAuthMap != null) {
if (await hasCustomUserVars()) { config.configurable.userMCPAuthMap = userMCPAuthMap;
config.configurable.userMCPAuthMap = await getMCPAuthMap({
tools: agent.tools,
userId: this.options.req.user.id,
findPluginAuthsByKeys,
});
}
} catch (err) {
logger.error(
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent ${agent.id}`,
err,
);
} }
await run.processStream({ messages }, config, { await run.processStream({ messages }, config, {
keepContent: i !== 0, keepContent: i !== 0,
tokenCounter: createTokenCounter(this.getEncoding()), tokenCounter: createTokenCounter(this.getEncoding()),

View file

@ -9,6 +9,24 @@ const {
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup'); const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
const { saveMessage } = require('~/models'); const { saveMessage } = require('~/models');
function createCloseHandler(abortController) {
return function (manual) {
if (!manual) {
logger.debug('[AgentController] Request closed');
}
if (!abortController) {
return;
} else if (abortController.signal.aborted) {
return;
} else if (abortController.requestCompleted) {
return;
}
abortController.abort();
logger.debug('[AgentController] Request aborted on close');
};
}
const AgentController = async (req, res, next, initializeClient, addTitle) => { const AgentController = async (req, res, next, initializeClient, addTitle) => {
let { let {
text, text,
@ -31,7 +49,6 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
let userMessagePromise; let userMessagePromise;
let getAbortData; let getAbortData;
let client = null; let client = null;
// Initialize as an array
let cleanupHandlers = []; let cleanupHandlers = [];
const newConvo = !conversationId; const newConvo = !conversationId;
@ -62,9 +79,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
// Create a function to handle final cleanup // Create a function to handle final cleanup
const performCleanup = () => { const performCleanup = () => {
logger.debug('[AgentController] Performing cleanup'); logger.debug('[AgentController] Performing cleanup');
// Make sure cleanupHandlers is an array before iterating
if (Array.isArray(cleanupHandlers)) { if (Array.isArray(cleanupHandlers)) {
// Execute all cleanup handlers
for (const handler of cleanupHandlers) { for (const handler of cleanupHandlers) {
try { try {
if (typeof handler === 'function') { if (typeof handler === 'function') {
@ -105,8 +120,33 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
}; };
try { try {
/** @type {{ client: TAgentClient }} */ let prelimAbortController = new AbortController();
const result = await initializeClient({ req, res, endpointOption }); const prelimCloseHandler = createCloseHandler(prelimAbortController);
res.on('close', prelimCloseHandler);
const removePrelimHandler = (manual) => {
try {
prelimCloseHandler(manual);
res.removeListener('close', prelimCloseHandler);
} catch (e) {
logger.error('[AgentController] Error removing close listener', e);
}
};
cleanupHandlers.push(removePrelimHandler);
/** @type {{ client: TAgentClient; userMCPAuthMap?: Record<string, Record<string, string>> }} */
const result = await initializeClient({
req,
res,
endpointOption,
signal: prelimAbortController.signal,
});
if (prelimAbortController.signal?.aborted) {
prelimAbortController = null;
throw new Error('Request was aborted before initialization could complete');
} else {
prelimAbortController = null;
removePrelimHandler(true);
cleanupHandlers.pop();
}
client = result.client; client = result.client;
// Register client with finalization registry if available // Register client with finalization registry if available
@ -138,22 +178,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
}; };
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData); const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
const closeHandler = createCloseHandler(abortController);
// Simple handler to avoid capturing scope
const closeHandler = () => {
logger.debug('[AgentController] Request closed');
if (!abortController) {
return;
} else if (abortController.signal.aborted) {
return;
} else if (abortController.requestCompleted) {
return;
}
abortController.abort();
logger.debug('[AgentController] Request aborted on close');
};
res.on('close', closeHandler); res.on('close', closeHandler);
cleanupHandlers.push(() => { cleanupHandlers.push(() => {
try { try {
@ -175,6 +200,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
abortController, abortController,
overrideParentMessageId, overrideParentMessageId,
isEdited: !!editedContent, isEdited: !!editedContent,
userMCPAuthMap: result.userMCPAuthMap,
responseMessageId: editedResponseMessageId, responseMessageId: editedResponseMessageId,
progressOptions: { progressOptions: {
res, res,

View file

@ -11,6 +11,7 @@ jest.mock('@librechat/api', () => ({
completeOAuthFlow: jest.fn(), completeOAuthFlow: jest.fn(),
generateFlowId: jest.fn(), generateFlowId: jest.fn(),
}, },
getUserMCPAuthMap: jest.fn(),
})); }));
jest.mock('@librechat/data-schemas', () => ({ jest.mock('@librechat/data-schemas', () => ({
@ -37,6 +38,7 @@ jest.mock('~/models', () => ({
updateToken: jest.fn(), updateToken: jest.fn(),
createToken: jest.fn(), createToken: jest.fn(),
deleteTokens: jest.fn(), deleteTokens: jest.fn(),
findPluginAuthsByKeys: jest.fn(),
})); }));
jest.mock('~/server/services/Config', () => ({ jest.mock('~/server/services/Config', () => ({
@ -71,6 +73,10 @@ jest.mock('~/server/middleware', () => ({
requireJwtAuth: (req, res, next) => next(), requireJwtAuth: (req, res, next) => next(),
})); }));
jest.mock('~/server/services/Tools/mcp', () => ({
reinitMCPServer: jest.fn(),
}));
describe('MCP Routes', () => { describe('MCP Routes', () => {
let app; let app;
let mongoServer; let mongoServer;
@ -682,6 +688,13 @@ describe('MCP Routes', () => {
require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getFlowStateManager.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue({});
require('~/cache').getLogStores.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({});
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
success: true,
message: "MCP server 'oauth-server' ready for OAuth authentication",
serverName: 'oauth-server',
oauthRequired: true,
oauthUrl: 'https://oauth.example.com/auth',
});
const response = await request(app).post('/api/mcp/oauth-server/reinitialize'); const response = await request(app).post('/api/mcp/oauth-server/reinitialize');
@ -706,6 +719,7 @@ describe('MCP Routes', () => {
require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getFlowStateManager.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue({});
require('~/cache').getLogStores.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({});
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue(null);
const response = await request(app).post('/api/mcp/error-server/reinitialize'); const response = await request(app).post('/api/mcp/error-server/reinitialize');
@ -769,6 +783,14 @@ describe('MCP Routes', () => {
setCachedTools.mockResolvedValue(); setCachedTools.mockResolvedValue();
updateMCPUserTools.mockResolvedValue(); updateMCPUserTools.mockResolvedValue();
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
success: true,
message: "MCP server 'test-server' reinitialized successfully",
serverName: 'test-server',
oauthRequired: false,
oauthUrl: null,
});
const response = await request(app).post('/api/mcp/test-server/reinitialize'); const response = await request(app).post('/api/mcp/test-server/reinitialize');
expect(response.status).toBe(200); expect(response.status).toBe(200);
@ -783,14 +805,6 @@ describe('MCP Routes', () => {
'test-user-id', 'test-user-id',
'test-server', 'test-server',
); );
expect(updateMCPUserTools).toHaveBeenCalledWith({
userId: 'test-user-id',
serverName: 'test-server',
tools: [
{ name: 'tool1', description: 'Test tool 1', inputSchema: { type: 'object' } },
{ name: 'tool2', description: 'Test tool 2', inputSchema: { type: 'object' } },
],
});
}); });
it('should handle server with custom user variables', async () => { it('should handle server with custom user variables', async () => {
@ -812,9 +826,14 @@ describe('MCP Routes', () => {
require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getFlowStateManager.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue({});
require('~/cache').getLogStores.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({});
require('~/server/services/PluginService').getUserPluginAuthValue.mockResolvedValue( require('@librechat/api').getUserMCPAuthMap.mockResolvedValue({
'api-key-value', 'mcp:test-server': {
); API_KEY: 'api-key-value',
},
});
require('~/models').findPluginAuthsByKeys.mockResolvedValue([
{ key: 'API_KEY', value: 'api-key-value' },
]);
const { getCachedTools, setCachedTools } = require('~/server/services/Config'); const { getCachedTools, setCachedTools } = require('~/server/services/Config');
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache'); const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
@ -822,13 +841,23 @@ describe('MCP Routes', () => {
setCachedTools.mockResolvedValue(); setCachedTools.mockResolvedValue();
updateMCPUserTools.mockResolvedValue(); updateMCPUserTools.mockResolvedValue();
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
success: true,
message: "MCP server 'test-server' reinitialized successfully",
serverName: 'test-server',
oauthRequired: false,
oauthUrl: null,
});
const response = await request(app).post('/api/mcp/test-server/reinitialize'); const response = await request(app).post('/api/mcp/test-server/reinitialize');
expect(response.status).toBe(200); expect(response.status).toBe(200);
expect(response.body.success).toBe(true); expect(response.body.success).toBe(true);
expect( expect(require('@librechat/api').getUserMCPAuthMap).toHaveBeenCalledWith({
require('~/server/services/PluginService').getUserPluginAuthValue, userId: 'test-user-id',
).toHaveBeenCalledWith('test-user-id', 'API_KEY', false); servers: ['test-server'],
findPluginAuthsByKeys: require('~/models').findPluginAuthsByKeys,
});
}); });
}); });

View file

@ -1,13 +1,15 @@
const { logger } = require('@librechat/data-schemas');
const { MCPOAuthHandler } = require('@librechat/api');
const { Router } = require('express'); const { Router } = require('express');
const { logger } = require('@librechat/data-schemas');
const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api');
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP'); const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models'); const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache'); const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { CacheKeys, Constants } = require('librechat-data-provider'); const { CacheKeys, Constants } = require('librechat-data-provider');
const { getMCPManager, getFlowStateManager } = require('~/config'); const { getMCPManager, getFlowStateManager } = require('~/config');
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
const { requireJwtAuth } = require('~/server/middleware'); const { requireJwtAuth } = require('~/server/middleware');
const { findPluginAuthsByKeys } = require('~/models');
const { getLogStores } = require('~/cache'); const { getLogStores } = require('~/cache');
const router = Router(); const router = Router();
@ -302,107 +304,39 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
}); });
} }
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
await mcpManager.disconnectUserConnection(user.id, serverName); await mcpManager.disconnectUserConnection(user.id, serverName);
logger.info( logger.info(
`[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`, `[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`,
); );
let customUserVars = {}; /** @type {Record<string, Record<string, string>> | undefined} */
let userMCPAuthMap;
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') { if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
for (const varName of Object.keys(serverConfig.customUserVars)) { userMCPAuthMap = await getUserMCPAuthMap({
try {
const value = await getUserPluginAuthValue(user.id, varName, false);
customUserVars[varName] = value;
} catch (err) {
logger.error(`[MCP Reinitialize] Error fetching ${varName} for user ${user.id}:`, err);
}
}
}
let userConnection = null;
let oauthRequired = false;
let oauthUrl = null;
try {
userConnection = await mcpManager.getUserConnection({
user,
serverName,
flowManager,
customUserVars,
tokenMethods: {
findToken,
updateToken,
createToken,
deleteTokens,
},
returnOnOAuth: true,
oauthStart: async (authURL) => {
logger.info(`[MCP Reinitialize] OAuth URL received: ${authURL}`);
oauthUrl = authURL;
oauthRequired = true;
},
});
logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`);
} catch (err) {
logger.info(`[MCP Reinitialize] getUserConnection threw error: ${err.message}`);
logger.info(
`[MCP Reinitialize] OAuth state - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
);
const isOAuthError =
err.message?.includes('OAuth') ||
err.message?.includes('authentication') ||
err.message?.includes('401');
const isOAuthFlowInitiated = err.message === 'OAuth flow initiated - return early';
if (isOAuthError || oauthRequired || isOAuthFlowInitiated) {
logger.info(
`[MCP Reinitialize] OAuth required for ${serverName} (isOAuthError: ${isOAuthError}, oauthRequired: ${oauthRequired}, isOAuthFlowInitiated: ${isOAuthFlowInitiated})`,
);
oauthRequired = true;
} else {
logger.error(
`[MCP Reinitialize] Error initializing MCP server ${serverName} for user:`,
err,
);
return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' });
}
}
if (userConnection && !oauthRequired) {
const tools = await userConnection.fetchTools();
await updateMCPUserTools({
userId: user.id, userId: user.id,
serverName, servers: [serverName],
tools, findPluginAuthsByKeys,
}); });
} }
logger.debug( const result = await reinitMCPServer({
`[MCP Reinitialize] Sending response for ${serverName} - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`, req,
); serverName,
userMCPAuthMap,
});
const getResponseMessage = () => { if (!result) {
if (oauthRequired) { return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' });
return `MCP server '${serverName}' ready for OAuth authentication`; }
}
if (userConnection) { const { success, message, oauthRequired, oauthUrl } = result;
return `MCP server '${serverName}' reinitialized successfully`;
}
return `Failed to reinitialize MCP server '${serverName}'`;
};
res.json({ res.json({
success: Boolean((userConnection && !oauthRequired) || (oauthRequired && oauthUrl)), success,
message: getResponseMessage(), message,
oauthUrl,
serverName, serverName,
oauthRequired, oauthRequired,
oauthUrl,
}); });
} catch (error) { } catch (error) {
logger.error('[MCP Reinitialize] Unexpected error', error); logger.error('[MCP Reinitialize] Unexpected error', error);

View file

@ -26,7 +26,7 @@ const ToolCacheKeys = {
* @param {string[]} [options.roleIds] - Role IDs for role-based tools * @param {string[]} [options.roleIds] - Role IDs for role-based tools
* @param {string[]} [options.groupIds] - Group IDs for group-based tools * @param {string[]} [options.groupIds] - Group IDs for group-based tools
* @param {boolean} [options.includeGlobal=true] - Whether to include global tools * @param {boolean} [options.includeGlobal=true] - Whether to include global tools
* @returns {Promise<Object|null>} The available tools object or null if not cached * @returns {Promise<LCAvailableTools|null>} The available tools object or null if not cached
*/ */
async function getCachedTools(options = {}) { async function getCachedTools(options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE); const cache = getLogStores(CacheKeys.CONFIG_STORE);
@ -41,13 +41,13 @@ async function getCachedTools(options = {}) {
// Future implementation will merge tools from multiple sources // Future implementation will merge tools from multiple sources
// based on user permissions, roles, and groups // based on user permissions, roles, and groups
if (userId) { if (userId) {
// Check if we have pre-computed effective tools for this user /** @type {LCAvailableTools | null} Check if we have pre-computed effective tools for this user */
const effectiveTools = await cache.get(ToolCacheKeys.EFFECTIVE(userId)); const effectiveTools = await cache.get(ToolCacheKeys.EFFECTIVE(userId));
if (effectiveTools) { if (effectiveTools) {
return effectiveTools; return effectiveTools;
} }
// Otherwise, compute from individual sources /** @type {LCAvailableTools | null} Otherwise, compute from individual sources */
const toolSources = []; const toolSources = [];
if (includeGlobal) { if (includeGlobal) {

View file

@ -1,5 +1,4 @@
const { logger } = require('@librechat/data-schemas'); const { isEnabled } = require('@librechat/api');
const { isEnabled, getUserMCPAuthMap } = require('@librechat/api');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const { normalizeEndpointName } = require('~/server/utils'); const { normalizeEndpointName } = require('~/server/utils');
const loadCustomConfig = require('./loadCustomConfig'); const loadCustomConfig = require('./loadCustomConfig');
@ -53,31 +52,6 @@ const getCustomEndpointConfig = async (endpoint) => {
); );
}; };
/**
* @param {Object} params
* @param {string} params.userId
* @param {GenericTool[]} [params.tools]
* @param {import('@librechat/data-schemas').PluginAuthMethods['findPluginAuthsByKeys']} params.findPluginAuthsByKeys
* @returns {Promise<Record<string, Record<string, string>> | undefined>}
*/
async function getMCPAuthMap({ userId, tools, findPluginAuthsByKeys }) {
try {
if (!tools || tools.length === 0) {
return;
}
return await getUserMCPAuthMap({
tools,
userId,
findPluginAuthsByKeys,
});
} catch (err) {
logger.error(
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent`,
err,
);
}
}
/** /**
* @returns {Promise<boolean>} * @returns {Promise<boolean>}
*/ */
@ -88,7 +62,6 @@ async function hasCustomUserVars() {
} }
module.exports = { module.exports = {
getMCPAuthMap,
getCustomConfig, getCustomConfig,
getBalanceConfig, getBalanceConfig,
hasCustomUserVars, hasCustomUserVars,

View file

@ -9,7 +9,7 @@ const { getLogStores } = require('~/cache');
* @param {string} params.userId - User ID * @param {string} params.userId - User ID
* @param {string} params.serverName - MCP server name * @param {string} params.serverName - MCP server name
* @param {Array} params.tools - Array of tool objects from MCP server * @param {Array} params.tools - Array of tool objects from MCP server
* @returns {Promise<void>} * @returns {Promise<LCAvailableTools>}
*/ */
async function updateMCPUserTools({ userId, serverName, tools }) { async function updateMCPUserTools({ userId, serverName, tools }) {
try { try {
@ -39,6 +39,7 @@ async function updateMCPUserTools({ userId, serverName, tools }) {
const cache = getLogStores(CacheKeys.CONFIG_STORE); const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS); await cache.delete(CacheKeys.TOOLS);
logger.debug(`[MCP Cache] Updated ${tools.length} tools for ${serverName} user ${userId}`); logger.debug(`[MCP Cache] Updated ${tools.length} tools for ${serverName} user ${userId}`);
return userTools;
} catch (error) { } catch (error) {
logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error); logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error);
throw error; throw error;

View file

@ -30,7 +30,13 @@ const { getModelMaxTokens } = require('~/utils');
* @param {TEndpointOption} [params.endpointOption] * @param {TEndpointOption} [params.endpointOption]
* @param {Set<string>} [params.allowedProviders] * @param {Set<string>} [params.allowedProviders]
* @param {boolean} [params.isInitialAgent] * @param {boolean} [params.isInitialAgent]
* @returns {Promise<Agent & { tools: StructuredTool[], attachments: Array<MongoFile>, toolContextMap: Record<string, unknown>, maxContextTokens: number }>} * @returns {Promise<Agent & {
* tools: StructuredTool[],
* attachments: Array<MongoFile>,
* toolContextMap: Record<string, unknown>,
* maxContextTokens: number,
* userMCPAuthMap?: Record<string, Record<string, string>>
* }>}
*/ */
const initializeAgent = async ({ const initializeAgent = async ({
req, req,
@ -91,16 +97,19 @@ const initializeAgent = async ({
}); });
const provider = agent.provider; const provider = agent.provider;
const { tools: structuredTools, toolContextMap } = const {
(await loadTools?.({ tools: structuredTools,
req, toolContextMap,
res, userMCPAuthMap,
provider, } = (await loadTools?.({
agentId: agent.id, req,
tools: agent.tools, res,
model: agent.model, provider,
tool_resources, agentId: agent.id,
})) ?? {}; tools: agent.tools,
model: agent.model,
tool_resources,
})) ?? {};
agent.endpoint = provider; agent.endpoint = provider;
const { getOptions, overrideProvider } = await getProviderConfig(provider); const { getOptions, overrideProvider } = await getProviderConfig(provider);
@ -189,6 +198,7 @@ const initializeAgent = async ({
tools, tools,
attachments, attachments,
resendFiles, resendFiles,
userMCPAuthMap,
toolContextMap, toolContextMap,
useLegacyContent: !!options.useLegacyContent, useLegacyContent: !!options.useLegacyContent,
maxContextTokens: Math.round((agentMaxContextTokens - maxTokens) * 0.9), maxContextTokens: Math.round((agentMaxContextTokens - maxTokens) * 0.9),

View file

@ -19,7 +19,10 @@ const AgentClient = require('~/server/controllers/agents/client');
const { getAgent } = require('~/models/Agent'); const { getAgent } = require('~/models/Agent');
const { logViolation } = require('~/cache'); const { logViolation } = require('~/cache');
function createToolLoader() { /**
* @param {AbortSignal} signal
*/
function createToolLoader(signal) {
/** /**
* @param {object} params * @param {object} params
* @param {ServerRequest} params.req * @param {ServerRequest} params.req
@ -29,7 +32,11 @@ function createToolLoader() {
* @param {string} params.provider * @param {string} params.provider
* @param {string} params.model * @param {string} params.model
* @param {AgentToolResources} params.tool_resources * @param {AgentToolResources} params.tool_resources
* @returns {Promise<{ tools: StructuredTool[], toolContextMap: Record<string, unknown> } | undefined>} * @returns {Promise<{
* tools: StructuredTool[],
* toolContextMap: Record<string, unknown>,
* userMCPAuthMap?: Record<string, Record<string, string>>
* } | undefined>}
*/ */
return async function loadTools({ req, res, agentId, tools, provider, model, tool_resources }) { return async function loadTools({ req, res, agentId, tools, provider, model, tool_resources }) {
const agent = { id: agentId, tools, provider, model }; const agent = { id: agentId, tools, provider, model };
@ -38,6 +45,7 @@ function createToolLoader() {
req, req,
res, res,
agent, agent,
signal,
tool_resources, tool_resources,
}); });
} catch (error) { } catch (error) {
@ -46,7 +54,7 @@ function createToolLoader() {
}; };
} }
const initializeClient = async ({ req, res, endpointOption }) => { const initializeClient = async ({ req, res, signal, endpointOption }) => {
if (!endpointOption) { if (!endpointOption) {
throw new Error('Endpoint option not provided'); throw new Error('Endpoint option not provided');
} }
@ -92,7 +100,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
/** @type {Set<string>} */ /** @type {Set<string>} */
const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders); const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders);
const loadTools = createToolLoader(); const loadTools = createToolLoader(signal);
/** @type {Array<MongoFile>} */ /** @type {Array<MongoFile>} */
const requestFiles = req.body.files ?? []; const requestFiles = req.body.files ?? [];
/** @type {string} */ /** @type {string} */
@ -111,6 +119,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
}); });
const agent_ids = primaryConfig.agent_ids; const agent_ids = primaryConfig.agent_ids;
let userMCPAuthMap = primaryConfig.userMCPAuthMap;
if (agent_ids?.length) { if (agent_ids?.length) {
for (const agentId of agent_ids) { for (const agentId of agent_ids) {
const agent = await getAgent({ id: agentId }); const agent = await getAgent({ id: agentId });
@ -140,6 +149,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
endpointOption, endpointOption,
allowedProviders, allowedProviders,
}); });
Object.assign(userMCPAuthMap, config.userMCPAuthMap ?? {});
agentConfigs.set(agentId, config); agentConfigs.set(agentId, config);
} }
} }
@ -188,7 +198,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
: EModelEndpoint.agents, : EModelEndpoint.agents,
}); });
return { client }; return { client, userMCPAuthMap };
}; };
module.exports = { initializeClient }; module.exports = { initializeClient };

View file

@ -1,7 +1,12 @@
const { z } = require('zod'); const { z } = require('zod');
const { tool } = require('@langchain/core/tools'); const { tool } = require('@langchain/core/tools');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { Constants: AgentConstants, Providers, GraphEvents } = require('@librechat/agents'); const {
Providers,
StepTypes,
GraphEvents,
Constants: AgentConstants,
} = require('@librechat/agents');
const { const {
sendEvent, sendEvent,
MCPOAuthHandler, MCPOAuthHandler,
@ -11,14 +16,14 @@ const {
const { const {
Time, Time,
CacheKeys, CacheKeys,
StepTypes,
Constants, Constants,
ContentTypes, ContentTypes,
isAssistantsEndpoint, isAssistantsEndpoint,
} = require('librechat-data-provider'); } = require('librechat-data-provider');
const { getCachedTools, loadCustomConfig } = require('./Config');
const { findToken, createToken, updateToken } = require('~/models'); const { findToken, createToken, updateToken } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config'); const { getMCPManager, getFlowStateManager } = require('~/config');
const { getCachedTools, loadCustomConfig } = require('./Config'); const { reinitMCPServer } = require('./Tools/mcp');
const { getLogStores } = require('~/cache'); const { getLogStores } = require('~/cache');
/** /**
@ -26,16 +31,13 @@ const { getLogStores } = require('~/cache');
* @param {ServerResponse} params.res - The Express response object for sending events. * @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.stepId - The ID of the step in the flow. * @param {string} params.stepId - The ID of the step in the flow.
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information. * @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
* @param {string} params.loginFlowId - The ID of the login flow.
* @param {FlowStateManager<any>} params.flowManager - The flow manager instance.
*/ */
function createOAuthStart({ res, stepId, toolCall, loginFlowId, flowManager, signal }) { function createRunStepDeltaEmitter({ res, stepId, toolCall }) {
/** /**
* Creates a function to handle OAuth login requests.
* @param {string} authURL - The URL to redirect the user for OAuth authentication. * @param {string} authURL - The URL to redirect the user for OAuth authentication.
* @returns {Promise<boolean>} Returns true to indicate the event was sent successfully. * @returns {void}
*/ */
return async function (authURL) { return function (authURL) {
/** @type {{ id: string; delta: AgentToolCallDelta }} */ /** @type {{ id: string; delta: AgentToolCallDelta }} */
const data = { const data = {
id: stepId, id: stepId,
@ -46,17 +48,54 @@ function createOAuthStart({ res, stepId, toolCall, loginFlowId, flowManager, sig
expires_at: Date.now() + Time.TWO_MINUTES, expires_at: Date.now() + Time.TWO_MINUTES,
}, },
}; };
/** Used to ensure the handler (use of `sendEvent`) is only invoked once */ sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
await flowManager.createFlowWithHandler( };
loginFlowId, }
'oauth_login',
async () => { /**
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data }); * @param {object} params
logger.debug('Sent OAuth login request to client'); * @param {ServerResponse} params.res - The Express response object for sending events.
return true; * @param {string} params.runId - The Run ID, i.e. message ID
* @param {string} params.stepId - The ID of the step in the flow.
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
* @param {number} [params.index]
*/
function createRunStepEmitter({ res, runId, stepId, toolCall, index }) {
return function () {
/** @type {import('@librechat/agents').RunStep} */
const data = {
runId: runId ?? Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
id: stepId,
type: StepTypes.TOOL_CALLS,
index: index ?? 0,
stepDetails: {
type: StepTypes.TOOL_CALLS,
tool_calls: [toolCall],
}, },
signal, };
); sendEvent(res, { event: GraphEvents.ON_RUN_STEP, data });
};
}
/**
* Creates a function used to ensure the flow handler is only invoked once
* @param {object} params
* @param {string} params.flowId - The ID of the login flow.
* @param {FlowStateManager<any>} params.flowManager - The flow manager instance.
* @param {(authURL: string) => void} [params.callback]
*/
function createOAuthStart({ flowId, flowManager, callback }) {
/**
* Creates a function to handle OAuth login requests.
* @param {string} authURL - The URL to redirect the user for OAuth authentication.
* @returns {Promise<boolean>} Returns true to indicate the event was sent successfully.
*/
return async function (authURL) {
await flowManager.createFlowWithHandler(flowId, 'oauth_login', async () => {
callback?.(authURL);
logger.debug('Sent OAuth login request to client');
return true;
});
}; };
} }
@ -99,23 +138,166 @@ function createAbortHandler({ userId, serverName, toolName, flowManager }) {
} }
/** /**
* Creates a general tool for an entire action set. * @param {Object} params
* @param {() => void} params.runStepEmitter
* @param {(authURL: string) => void} params.runStepDeltaEmitter
* @returns {(authURL: string) => void}
*/
function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
return function (authURL) {
runStepEmitter();
runStepDeltaEmitter(authURL);
};
}
/**
* @param {Object} params
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.serverName
* @param {AbortSignal} params.signal
* @param {string} params.model
* @param {number} [params.index]
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
*/
async function reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap }) {
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
const flowId = `${req.user?.id}:${serverName}:${Date.now()}`;
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
const stepId = 'step_oauth_login_' + serverName;
const toolCall = {
id: flowId,
name: serverName,
type: 'tool_call_chunk',
};
const runStepEmitter = createRunStepEmitter({
res,
index,
runId,
stepId,
toolCall,
});
const runStepDeltaEmitter = createRunStepDeltaEmitter({
res,
stepId,
toolCall,
});
const callback = createOAuthCallback({ runStepEmitter, runStepDeltaEmitter });
const oauthStart = createOAuthStart({
res,
flowId,
callback,
flowManager,
});
return await reinitMCPServer({
req,
signal,
serverName,
oauthStart,
flowManager,
userMCPAuthMap,
forceNew: true,
returnOnOAuth: false,
connectionTimeout: Time.TWO_MINUTES,
});
}
/**
* Creates all tools from the specified MCP Server via `toolKey`.
* *
* @param {Object} params - The parameters for loading action sets. * This function assumes tools could not be aggregated from the cache of tool definitions,
* i.e. `availableTools`, and will reinitialize the MCP server to ensure all tools are generated.
*
* @param {Object} params
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.serverName
* @param {string} params.model
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
* @param {number} [params.index]
* @param {AbortSignal} [params.signal]
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
*/
async function createMCPTools({ req, res, index, signal, serverName, provider, userMCPAuthMap }) {
const result = await reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap });
if (!result || !result.tools) {
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
return;
}
const serverTools = [];
for (const tool of result.tools) {
const toolInstance = await createMCPTool({
req,
res,
provider,
userMCPAuthMap,
availableTools: result.availableTools,
toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`,
});
if (toolInstance) {
serverTools.push(toolInstance);
}
}
return serverTools;
}
/**
* Creates a single tool from the specified MCP Server via `toolKey`.
* @param {Object} params
* @param {ServerRequest} params.req - The Express request object, containing user/request info. * @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {ServerResponse} params.res - The Express response object for sending events. * @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.toolKey - The toolKey for the tool. * @param {string} params.toolKey - The toolKey for the tool.
* @param {import('@librechat/agents').Providers | EModelEndpoint} params.provider - The provider for the tool.
* @param {string} params.model - The model for the tool. * @param {string} params.model - The model for the tool.
* @param {number} [params.index]
* @param {AbortSignal} [params.signal]
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
* @param {LCAvailableTools} [params.availableTools]
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input. * @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
*/ */
async function createMCPTool({ req, res, toolKey, provider: _provider }) { async function createMCPTool({
const availableTools = await getCachedTools({ userId: req.user?.id, includeGlobal: true }); req,
const toolDefinition = availableTools?.[toolKey]?.function; res,
index,
signal,
toolKey,
provider,
userMCPAuthMap,
availableTools: tools,
}) {
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
const availableTools =
tools ?? (await getCachedTools({ userId: req.user?.id, includeGlobal: true }));
/** @type {LCTool | undefined} */
let toolDefinition = availableTools?.[toolKey]?.function;
if (!toolDefinition) { if (!toolDefinition) {
logger.error(`Tool ${toolKey} not found in available tools`); logger.warn(
return null; `[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`,
);
const result = await reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap });
toolDefinition = result?.availableTools?.[toolKey]?.function;
} }
if (!toolDefinition) {
logger.warn(`[MCP][${serverName}][${toolName}] Tool definition not found, cannot create tool.`);
return;
}
return createToolInstance({
res,
provider,
toolName,
serverName,
toolDefinition,
});
}
function createToolInstance({ res, toolName, serverName, toolDefinition, provider: _provider }) {
/** @type {LCTool} */ /** @type {LCTool} */
const { description, parameters } = toolDefinition; const { description, parameters } = toolDefinition;
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE; const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
@ -128,16 +310,8 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
schema = z.object({ input: z.string().optional() }); schema = z.object({ input: z.string().optional() });
} }
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
const normalizedToolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`; const normalizedToolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`;
if (!req.user?.id) {
logger.error(
`[MCP][${serverName}][${toolName}] User ID not found on request. Cannot create tool.`,
);
throw new Error(`User ID not found on request. Cannot create tool for ${toolKey}.`);
}
/** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise<unknown>} */ /** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise<unknown>} */
const _call = async (toolArguments, config) => { const _call = async (toolArguments, config) => {
const userId = config?.configurable?.user?.id || config?.configurable?.user_id; const userId = config?.configurable?.user?.id || config?.configurable?.user_id;
@ -154,14 +328,16 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
const provider = (config?.metadata?.provider || _provider)?.toLowerCase(); const provider = (config?.metadata?.provider || _provider)?.toLowerCase();
const { args: _args, stepId, ...toolCall } = config.toolCall ?? {}; const { args: _args, stepId, ...toolCall } = config.toolCall ?? {};
const loginFlowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`; const flowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`;
const oauthStart = createOAuthStart({ const runStepDeltaEmitter = createRunStepDeltaEmitter({
res, res,
stepId, stepId,
toolCall, toolCall,
loginFlowId, });
const oauthStart = createOAuthStart({
flowId,
flowManager, flowManager,
signal: derivedSignal, callback: runStepDeltaEmitter,
}); });
const oauthEnd = createOAuthEnd({ const oauthEnd = createOAuthEnd({
res, res,
@ -207,7 +383,7 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
return result; return result;
} catch (error) { } catch (error) {
logger.error( logger.error(
`[MCP][User: ${userId}][${serverName}] Error calling "${toolName}" MCP tool:`, `[MCP][${serverName}][${toolName}][User: ${userId}] Error calling MCP tool:`,
error, error,
); );
@ -220,12 +396,12 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
if (isOAuthError) { if (isOAuthError) {
throw new Error( throw new Error(
`OAuth authentication required for ${serverName}. Please check the server logs for the authentication URL.`, `[MCP][${serverName}][${toolName}] OAuth authentication required. Please check the server logs for the authentication URL.`,
); );
} }
throw new Error( throw new Error(
`"${toolKey}" tool call failed${error?.message ? `: ${error?.message}` : '.'}`, `[MCP][${serverName}][${toolName}] tool call failed${error?.message ? `: ${error?.message}` : '.'}`,
); );
} finally { } finally {
// Clean up abort handler to prevent memory leaks // Clean up abort handler to prevent memory leaks
@ -380,6 +556,7 @@ async function getServerConnectionStatus(
module.exports = { module.exports = {
createMCPTool, createMCPTool,
createMCPTools,
getMCPSetupData, getMCPSetupData,
checkOAuthFlowStatus, checkOAuthFlowStatus,
getServerConnectionStatus, getServerConnectionStatus,

View file

@ -1,9 +1,9 @@
const fs = require('fs'); const fs = require('fs');
const path = require('path'); const path = require('path');
const { sleep } = require('@librechat/agents'); const { sleep } = require('@librechat/agents');
const { getToolkitKey } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { zodToJsonSchema } = require('zod-to-json-schema'); const { zodToJsonSchema } = require('zod-to-json-schema');
const { getToolkitKey, getUserMCPAuthMap } = require('@librechat/api');
const { Calculator } = require('@langchain/community/tools/calculator'); const { Calculator } = require('@langchain/community/tools/calculator');
const { tool: toolFn, Tool, DynamicStructuredTool } = require('@langchain/core/tools'); const { tool: toolFn, Tool, DynamicStructuredTool } = require('@langchain/core/tools');
const { const {
@ -33,12 +33,17 @@ const {
toolkits, toolkits,
} = require('~/app/clients/tools'); } = require('~/app/clients/tools');
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process'); const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
const { getEndpointsConfig, getCachedTools } = require('~/server/services/Config'); const {
getEndpointsConfig,
hasCustomUserVars,
getCachedTools,
} = require('~/server/services/Config');
const { createOnSearchResults } = require('~/server/services/Tools/search'); const { createOnSearchResults } = require('~/server/services/Tools/search');
const { isActionDomainAllowed } = require('~/server/services/domains'); const { isActionDomainAllowed } = require('~/server/services/domains');
const { recordUsage } = require('~/server/services/Threads'); const { recordUsage } = require('~/server/services/Threads');
const { loadTools } = require('~/app/clients/tools/util'); const { loadTools } = require('~/app/clients/tools/util');
const { redactMessage } = require('~/config/parsers'); const { redactMessage } = require('~/config/parsers');
const { findPluginAuthsByKeys } = require('~/models');
/** /**
* Loads and formats tools from the specified tool directory. * Loads and formats tools from the specified tool directory.
@ -469,11 +474,12 @@ async function processRequiredActions(client, requiredActions) {
* @param {Object} params - Run params containing user and request information. * @param {Object} params - Run params containing user and request information.
* @param {ServerRequest} params.req - The request object. * @param {ServerRequest} params.req - The request object.
* @param {ServerResponse} params.res - The request object. * @param {ServerResponse} params.res - The request object.
* @param {AbortSignal} params.signal
* @param {Pick<Agent, 'id' | 'provider' | 'model' | 'tools'} params.agent - The agent to load tools for. * @param {Pick<Agent, 'id' | 'provider' | 'model' | 'tools'} params.agent - The agent to load tools for.
* @param {string | undefined} [params.openAIApiKey] - The OpenAI API key. * @param {string | undefined} [params.openAIApiKey] - The OpenAI API key.
* @returns {Promise<{ tools?: StructuredTool[] }>} The agent tools. * @returns {Promise<{ tools?: StructuredTool[]; userMCPAuthMap?: Record<string, Record<string, string>> }>} The agent tools.
*/ */
async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) { async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIApiKey }) {
if (!agent.tools || agent.tools.length === 0) { if (!agent.tools || agent.tools.length === 0) {
return {}; return {};
} else if (agent.tools && agent.tools.length === 1 && agent.tools[0] === AgentCapabilities.ocr) { } else if (agent.tools && agent.tools.length === 1 && agent.tools[0] === AgentCapabilities.ocr) {
@ -523,8 +529,20 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
webSearchCallbacks = createOnSearchResults(res); webSearchCallbacks = createOnSearchResults(res);
} }
/** @type {Record<string, Record<string, string>>} */
let userMCPAuthMap;
if (await hasCustomUserVars()) {
userMCPAuthMap = await getUserMCPAuthMap({
tools: agent.tools,
userId: req.user.id,
findPluginAuthsByKeys,
});
}
const { loadedTools, toolContextMap } = await loadTools({ const { loadedTools, toolContextMap } = await loadTools({
agent, agent,
signal,
userMCPAuthMap,
functions: true, functions: true,
user: req.user.id, user: req.user.id,
tools: _agentTools, tools: _agentTools,
@ -588,6 +606,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
if (!checkCapability(AgentCapabilities.actions)) { if (!checkCapability(AgentCapabilities.actions)) {
return { return {
tools: agentTools, tools: agentTools,
userMCPAuthMap,
toolContextMap, toolContextMap,
}; };
} }
@ -599,6 +618,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
} }
return { return {
tools: agentTools, tools: agentTools,
userMCPAuthMap,
toolContextMap, toolContextMap,
}; };
} }
@ -707,6 +727,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
return { return {
tools: agentTools, tools: agentTools,
toolContextMap, toolContextMap,
userMCPAuthMap,
}; };
} }

View file

@ -0,0 +1,142 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { findToken, createToken, updateToken, deleteTokens } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { updateMCPUserTools } = require('~/server/services/Config');
const { getLogStores } = require('~/cache');
/**
* @param {Object} params
* @param {ServerRequest} params.req
* @param {string} params.serverName - The name of the MCP server
* @param {boolean} params.returnOnOAuth - Whether to initiate OAuth and return, or wait for OAuth flow to finish
* @param {AbortSignal} [params.signal] - The abort signal to handle cancellation.
* @param {boolean} [params.forceNew]
* @param {number} [params.connectionTimeout]
* @param {FlowStateManager<any>} [params.flowManager]
* @param {(authURL: string) => Promise<boolean>} [params.oauthStart]
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
*/
async function reinitMCPServer({
req,
signal,
forceNew,
serverName,
userMCPAuthMap,
connectionTimeout,
returnOnOAuth = true,
oauthStart: _oauthStart,
flowManager: _flowManager,
}) {
/** @type {MCPConnection | null} */
let userConnection = null;
/** @type {LCAvailableTools | null} */
let availableTools = null;
/** @type {ReturnType<MCPConnection['fetchTools']> | null} */
let tools = null;
let oauthRequired = false;
let oauthUrl = null;
try {
const customUserVars = userMCPAuthMap?.[`${Constants.mcp_prefix}${serverName}`];
const flowManager = _flowManager ?? getFlowStateManager(getLogStores(CacheKeys.FLOWS));
const mcpManager = getMCPManager();
const oauthStart =
_oauthStart ??
(async (authURL) => {
logger.info(`[MCP Reinitialize] OAuth URL received: ${authURL}`);
oauthUrl = authURL;
oauthRequired = true;
});
try {
userConnection = await mcpManager.getUserConnection({
user: req.user,
signal,
forceNew,
oauthStart,
serverName,
flowManager,
returnOnOAuth,
customUserVars,
connectionTimeout,
tokenMethods: {
findToken,
updateToken,
createToken,
deleteTokens,
},
});
logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`);
} catch (err) {
logger.info(`[MCP Reinitialize] getUserConnection threw error: ${err.message}`);
logger.info(
`[MCP Reinitialize] OAuth state - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
);
const isOAuthError =
err.message?.includes('OAuth') ||
err.message?.includes('authentication') ||
err.message?.includes('401');
const isOAuthFlowInitiated = err.message === 'OAuth flow initiated - return early';
if (isOAuthError || oauthRequired || isOAuthFlowInitiated) {
logger.info(
`[MCP Reinitialize] OAuth required for ${serverName} (isOAuthError: ${isOAuthError}, oauthRequired: ${oauthRequired}, isOAuthFlowInitiated: ${isOAuthFlowInitiated})`,
);
oauthRequired = true;
} else {
logger.error(
`[MCP Reinitialize] Error initializing MCP server ${serverName} for user:`,
err,
);
}
}
if (userConnection && !oauthRequired) {
tools = await userConnection.fetchTools();
availableTools = await updateMCPUserTools({
userId: req.user.id,
serverName,
tools,
});
}
logger.debug(
`[MCP Reinitialize] Sending response for ${serverName} - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
);
const getResponseMessage = () => {
if (oauthRequired) {
return `MCP server '${serverName}' ready for OAuth authentication`;
}
if (userConnection) {
return `MCP server '${serverName}' reinitialized successfully`;
}
return `Failed to reinitialize MCP server '${serverName}'`;
};
const result = {
availableTools,
success: Boolean((userConnection && !oauthRequired) || (oauthRequired && oauthUrl)),
message: getResponseMessage(),
oauthRequired,
serverName,
oauthUrl,
tools,
};
logger.debug(`[MCP Reinitialize] Response for ${serverName}:`, result);
return result;
} catch (error) {
logger.error(
'[MCP Reinitialize] Error loading MCP Tools, servers may still be initializing:',
error,
);
}
}
module.exports = {
reinitMCPServer,
};

View file

@ -1115,6 +1115,18 @@
* @memberof typedefs * @memberof typedefs
*/ */
/**
* @exports MCPConnection
* @typedef {import('@librechat/api').MCPConnection} MCPConnection
* @memberof typedefs
*/
/**
* @exports LCFunctionTool
* @typedef {import('@librechat/api').LCFunctionTool} LCFunctionTool
* @memberof typedefs
*/
/** /**
* @exports FlowStateManager * @exports FlowStateManager
* @typedef {import('@librechat/api').FlowStateManager} FlowStateManager * @typedef {import('@librechat/api').FlowStateManager} FlowStateManager
@ -1825,6 +1837,7 @@
* @param {object} opts - Options for the completion * @param {object} opts - Options for the completion
* @param {onTokenProgress} opts.onProgress - Callback function to handle token progress * @param {onTokenProgress} opts.onProgress - Callback function to handle token progress
* @param {AbortController} opts.abortController - AbortController instance * @param {AbortController} opts.abortController - AbortController instance
* @param {Record<string, Record<string, string>>} [opts.userMCPAuthMap]
* @returns {Promise<string>} * @returns {Promise<string>}
* @memberof typedefs * @memberof typedefs
*/ */

View file

@ -230,15 +230,19 @@ export default function useChatFunctions({
const responseMessageId = const responseMessageId =
editedMessageId ?? editedMessageId ??
(latestMessage?.messageId && isRegenerate ? latestMessage?.messageId + '_' : null) ?? (latestMessage?.messageId && isRegenerate
? latestMessage.messageId.replace(/_+$/, '') + '_'
: null) ??
null; null;
const initialResponseId =
responseMessageId ?? `${isRegenerate ? messageId : intermediateId}`.replace(/_+$/, '') + '_';
const initialResponse: TMessage = { const initialResponse: TMessage = {
sender: responseSender, sender: responseSender,
text: '', text: '',
endpoint: endpoint ?? '', endpoint: endpoint ?? '',
parentMessageId: isRegenerate ? messageId : intermediateId, parentMessageId: isRegenerate ? messageId : intermediateId,
messageId: responseMessageId ?? `${isRegenerate ? messageId : intermediateId}_`, messageId: initialResponseId,
thread_id, thread_id,
conversationId, conversationId,
unfinished: false, unfinished: false,

View file

@ -182,7 +182,7 @@ export default function useEventHandlers({
const { token } = useAuthContext(); const { token } = useAuthContext();
const contentHandler = useContentHandler({ setMessages, getMessages }); const contentHandler = useContentHandler({ setMessages, getMessages });
const stepHandler = useStepHandler({ const { stepHandler, clearStepMaps } = useStepHandler({
setMessages, setMessages,
getMessages, getMessages,
announcePolite, announcePolite,
@ -806,6 +806,7 @@ export default function useEventHandlers({
); );
return { return {
clearStepMaps,
stepHandler, stepHandler,
syncHandler, syncHandler,
finalHandler, finalHandler,

View file

@ -62,6 +62,7 @@ export default function useSSE(
} = chatHelpers; } = chatHelpers;
const { const {
clearStepMaps,
stepHandler, stepHandler,
syncHandler, syncHandler,
finalHandler, finalHandler,
@ -101,6 +102,7 @@ export default function useSSE(
payload = removeNullishValues(payload) as TPayload; payload = removeNullishValues(payload) as TPayload;
let textIndex = null; let textIndex = null;
clearStepMaps();
const sse = new SSE(payloadData.server, { const sse = new SSE(payloadData.server, {
payload: JSON.stringify(payload), payload: JSON.stringify(payload),

View file

@ -1,5 +1,11 @@
import { useCallback, useRef } from 'react'; import { useCallback, useRef } from 'react';
import { StepTypes, ContentTypes, ToolCallTypes, getNonEmptyValue } from 'librechat-data-provider'; import {
Constants,
StepTypes,
ContentTypes,
ToolCallTypes,
getNonEmptyValue,
} from 'librechat-data-provider';
import type { import type {
Agents, Agents,
TMessage, TMessage,
@ -178,11 +184,12 @@ export default function useStepHandler({
return { ...message, content: updatedContent as TMessageContentParts[] }; return { ...message, content: updatedContent as TMessageContentParts[] };
}; };
return useCallback( const stepHandler = useCallback(
({ event, data }: TStepEvent, submission: EventSubmission) => { ({ event, data }: TStepEvent, submission: EventSubmission) => {
const messages = getMessages() || []; const messages = getMessages() || [];
const { userMessage } = submission; const { userMessage } = submission;
setIsSubmitting(true); setIsSubmitting(true);
let parentMessageId = userMessage.messageId;
const currentTime = Date.now(); const currentTime = Date.now();
if (currentTime - lastAnnouncementTimeRef.current > MESSAGE_UPDATE_INTERVAL) { if (currentTime - lastAnnouncementTimeRef.current > MESSAGE_UPDATE_INTERVAL) {
@ -197,7 +204,11 @@ export default function useStepHandler({
if (event === 'on_run_step') { if (event === 'on_run_step') {
const runStep = data as Agents.RunStep; const runStep = data as Agents.RunStep;
const responseMessageId = runStep.runId ?? ''; let responseMessageId = runStep.runId ?? '';
if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) {
responseMessageId = submission?.initialResponse?.messageId ?? '';
parentMessageId = submission?.initialResponse?.parentMessageId ?? '';
}
if (!responseMessageId) { if (!responseMessageId) {
console.warn('No message id found in run step event'); console.warn('No message id found in run step event');
return; return;
@ -211,7 +222,7 @@ export default function useStepHandler({
response = { response = {
...responseMessage, ...responseMessage,
parentMessageId: userMessage.messageId, parentMessageId,
conversationId: userMessage.conversationId, conversationId: userMessage.conversationId,
messageId: responseMessageId, messageId: responseMessageId,
content: initialContent, content: initialContent,
@ -246,14 +257,18 @@ export default function useStepHandler({
messageMap.current.set(responseMessageId, updatedResponse); messageMap.current.set(responseMessageId, updatedResponse);
const updatedMessages = messages.map((msg) => const updatedMessages = messages.map((msg) =>
msg.messageId === runStep.runId ? updatedResponse : msg, msg.messageId === responseMessageId ? updatedResponse : msg,
); );
setMessages(updatedMessages); setMessages(updatedMessages);
} }
} else if (event === 'on_agent_update') { } else if (event === 'on_agent_update') {
const { agent_update } = data as Agents.AgentUpdate; const { agent_update } = data as Agents.AgentUpdate;
const responseMessageId = agent_update.runId || ''; let responseMessageId = agent_update.runId || '';
if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) {
responseMessageId = submission?.initialResponse?.messageId ?? '';
parentMessageId = submission?.initialResponse?.parentMessageId ?? '';
}
if (!responseMessageId) { if (!responseMessageId) {
console.warn('No message id found in agent update event'); console.warn('No message id found in agent update event');
return; return;
@ -271,7 +286,11 @@ export default function useStepHandler({
} else if (event === 'on_message_delta') { } else if (event === 'on_message_delta') {
const messageDelta = data as Agents.MessageDeltaEvent; const messageDelta = data as Agents.MessageDeltaEvent;
const runStep = stepMap.current.get(messageDelta.id); const runStep = stepMap.current.get(messageDelta.id);
const responseMessageId = runStep?.runId ?? ''; let responseMessageId = runStep?.runId ?? '';
if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) {
responseMessageId = submission?.initialResponse?.messageId ?? '';
parentMessageId = submission?.initialResponse?.parentMessageId ?? '';
}
if (!runStep || !responseMessageId) { if (!runStep || !responseMessageId) {
console.warn('No run step or runId found for message delta event'); console.warn('No run step or runId found for message delta event');
@ -299,7 +318,11 @@ export default function useStepHandler({
} else if (event === 'on_reasoning_delta') { } else if (event === 'on_reasoning_delta') {
const reasoningDelta = data as Agents.ReasoningDeltaEvent; const reasoningDelta = data as Agents.ReasoningDeltaEvent;
const runStep = stepMap.current.get(reasoningDelta.id); const runStep = stepMap.current.get(reasoningDelta.id);
const responseMessageId = runStep?.runId ?? ''; let responseMessageId = runStep?.runId ?? '';
if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) {
responseMessageId = submission?.initialResponse?.messageId ?? '';
parentMessageId = submission?.initialResponse?.parentMessageId ?? '';
}
if (!runStep || !responseMessageId) { if (!runStep || !responseMessageId) {
console.warn('No run step or runId found for reasoning delta event'); console.warn('No run step or runId found for reasoning delta event');
@ -327,7 +350,11 @@ export default function useStepHandler({
} else if (event === 'on_run_step_delta') { } else if (event === 'on_run_step_delta') {
const runStepDelta = data as Agents.RunStepDeltaEvent; const runStepDelta = data as Agents.RunStepDeltaEvent;
const runStep = stepMap.current.get(runStepDelta.id); const runStep = stepMap.current.get(runStepDelta.id);
const responseMessageId = runStep?.runId ?? ''; let responseMessageId = runStep?.runId ?? '';
if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) {
responseMessageId = submission?.initialResponse?.messageId ?? '';
parentMessageId = submission?.initialResponse?.parentMessageId ?? '';
}
if (!runStep || !responseMessageId) { if (!runStep || !responseMessageId) {
console.warn('No run step or runId found for run step delta event'); console.warn('No run step or runId found for run step delta event');
@ -366,7 +393,7 @@ export default function useStepHandler({
messageMap.current.set(responseMessageId, updatedResponse); messageMap.current.set(responseMessageId, updatedResponse);
const updatedMessages = messages.map((msg) => const updatedMessages = messages.map((msg) =>
msg.messageId === runStep.runId ? updatedResponse : msg, msg.messageId === responseMessageId ? updatedResponse : msg,
); );
setMessages(updatedMessages); setMessages(updatedMessages);
@ -377,7 +404,11 @@ export default function useStepHandler({
const { id: stepId } = result; const { id: stepId } = result;
const runStep = stepMap.current.get(stepId); const runStep = stepMap.current.get(stepId);
const responseMessageId = runStep?.runId ?? ''; let responseMessageId = runStep?.runId ?? '';
if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) {
responseMessageId = submission?.initialResponse?.messageId ?? '';
parentMessageId = submission?.initialResponse?.parentMessageId ?? '';
}
if (!runStep || !responseMessageId) { if (!runStep || !responseMessageId) {
console.warn('No run step or runId found for completed tool call event'); console.warn('No run step or runId found for completed tool call event');
@ -399,7 +430,7 @@ export default function useStepHandler({
messageMap.current.set(responseMessageId, updatedResponse); messageMap.current.set(responseMessageId, updatedResponse);
const updatedMessages = messages.map((msg) => const updatedMessages = messages.map((msg) =>
msg.messageId === runStep.runId ? updatedResponse : msg, msg.messageId === responseMessageId ? updatedResponse : msg,
); );
setMessages(updatedMessages); setMessages(updatedMessages);
@ -414,4 +445,11 @@ export default function useStepHandler({
}, },
[getMessages, setIsSubmitting, lastAnnouncementTimeRef, announcePolite, setMessages], [getMessages, setIsSubmitting, lastAnnouncementTimeRef, announcePolite, setMessages],
); );
const clearStepMaps = useCallback(() => {
toolCallIdMap.current.clear();
messageMap.current.clear();
stepMap.current.clear();
}, []);
return { stepHandler, clearStepMaps };
} }

View file

@ -1,5 +1,6 @@
/* MCP */ /* MCP */
export * from './mcp/MCPManager'; export * from './mcp/MCPManager';
export * from './mcp/connection';
export * from './mcp/oauth'; export * from './mcp/oauth';
export * from './mcp/auth'; export * from './mcp/auth';
export * from './mcp/zod'; export * from './mcp/zod';

View file

@ -28,6 +28,7 @@ export class MCPConnectionFactory {
protected readonly oauthStart?: (authURL: string) => Promise<void>; protected readonly oauthStart?: (authURL: string) => Promise<void>;
protected readonly oauthEnd?: () => Promise<void>; protected readonly oauthEnd?: () => Promise<void>;
protected readonly returnOnOAuth?: boolean; protected readonly returnOnOAuth?: boolean;
protected readonly connectionTimeout?: number;
/** Creates a new MCP connection with optional OAuth support */ /** Creates a new MCP connection with optional OAuth support */
static async create( static async create(
@ -47,6 +48,7 @@ export class MCPConnectionFactory {
}); });
this.serverName = basic.serverName; this.serverName = basic.serverName;
this.useOAuth = !!oauth?.useOAuth; this.useOAuth = !!oauth?.useOAuth;
this.connectionTimeout = oauth?.connectionTimeout;
this.logPrefix = oauth?.user this.logPrefix = oauth?.user
? `[MCP][${basic.serverName}][${oauth.user.id}]` ? `[MCP][${basic.serverName}][${oauth.user.id}]`
: `[MCP][${basic.serverName}]`; : `[MCP][${basic.serverName}]`;
@ -82,8 +84,9 @@ export class MCPConnectionFactory {
if (!this.tokenMethods?.findToken) return null; if (!this.tokenMethods?.findToken) return null;
try { try {
const flowId = MCPOAuthHandler.generateFlowId(this.userId!, this.serverName);
const tokens = await this.flowManager!.createFlowWithHandler( const tokens = await this.flowManager!.createFlowWithHandler(
`tokens:${this.userId}:${this.serverName}`, flowId,
'mcp_get_tokens', 'mcp_get_tokens',
async () => { async () => {
return await MCPTokenStorage.getTokens({ return await MCPTokenStorage.getTokens({
@ -203,7 +206,7 @@ export class MCPConnectionFactory {
/** Attempts to establish connection with timeout handling */ /** Attempts to establish connection with timeout handling */
protected async attemptToConnect(connection: MCPConnection): Promise<void> { protected async attemptToConnect(connection: MCPConnection): Promise<void> {
const connectTimeout = this.serverConfig.initTimeout ?? 30000; const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 30000;
const connectionTimeout = new Promise<void>((_, reject) => const connectionTimeout = new Promise<void>((_, reject) =>
setTimeout( setTimeout(
() => reject(new Error(`Connection timeout after ${connectTimeout}ms`)), () => reject(new Error(`Connection timeout after ${connectTimeout}ms`)),
@ -347,6 +350,7 @@ export class MCPConnectionFactory {
newFlowId, newFlowId,
'mcp_oauth', 'mcp_oauth',
flowMetadata as FlowMetadata, flowMetadata as FlowMetadata,
this.signal,
); );
if (typeof this.oauthEnd === 'function') { if (typeof this.oauthEnd === 'function') {
await this.oauthEnd(); await this.oauthEnd();

View file

@ -1,13 +1,8 @@
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
import { logger } from '@librechat/data-schemas'; import { logger } from '@librechat/data-schemas';
import type { TokenMethods } from '@librechat/data-schemas'; import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
import type { TUser } from 'librechat-data-provider';
import type { FlowStateManager } from '~/flow/manager';
import type { MCPOAuthTokens } from '~/mcp/oauth';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry'; import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
import { MCPConnection } from './connection'; import { MCPConnection } from './connection';
import type { RequestBody } from '~/types';
import type * as t from './types'; import type * as t from './types';
/** /**
@ -44,8 +39,9 @@ export abstract class UserConnectionManager {
/** Gets or creates a connection for a specific user */ /** Gets or creates a connection for a specific user */
public async getUserConnection({ public async getUserConnection({
user,
serverName, serverName,
forceNew,
user,
flowManager, flowManager,
customUserVars, customUserVars,
requestBody, requestBody,
@ -54,25 +50,18 @@ export abstract class UserConnectionManager {
oauthEnd, oauthEnd,
signal, signal,
returnOnOAuth = false, returnOnOAuth = false,
connectionTimeout,
}: { }: {
user: TUser;
serverName: string; serverName: string;
flowManager: FlowStateManager<MCPOAuthTokens | null>; forceNew?: boolean;
customUserVars?: Record<string, string>; } & Omit<t.OAuthConnectionOptions, 'useOAuth'>): Promise<MCPConnection> {
requestBody?: RequestBody;
tokenMethods?: TokenMethods;
oauthStart?: (authURL: string) => Promise<void>;
oauthEnd?: () => Promise<void>;
signal?: AbortSignal;
returnOnOAuth?: boolean;
}): Promise<MCPConnection> {
const userId = user.id; const userId = user.id;
if (!userId) { if (!userId) {
throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`); throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`);
} }
const userServerMap = this.userConnections.get(userId); const userServerMap = this.userConnections.get(userId);
let connection = userServerMap?.get(serverName); let connection = forceNew ? undefined : userServerMap?.get(serverName);
const now = Date.now(); const now = Date.now();
// Check if user is idle // Check if user is idle
@ -131,6 +120,7 @@ export abstract class UserConnectionManager {
oauthEnd: oauthEnd, oauthEnd: oauthEnd,
returnOnOAuth: returnOnOAuth, returnOnOAuth: returnOnOAuth,
requestBody: requestBody, requestBody: requestBody,
connectionTimeout: connectionTimeout,
}, },
); );

View file

@ -45,7 +45,7 @@ describe('getUserMCPAuthMap', () => {
}, },
]; ];
const tools = testCases.map((testCase) => const toolInstances = testCases.map((testCase) =>
createMockTool(testCase.normalizedToolName, testCase.originalName), createMockTool(testCase.normalizedToolName, testCase.originalName),
); );
@ -54,7 +54,7 @@ describe('getUserMCPAuthMap', () => {
await getUserMCPAuthMap({ await getUserMCPAuthMap({
userId: 'user123', userId: 'user123',
tools, toolInstances,
findPluginAuthsByKeys: mockFindPluginAuthsByKeys, findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
}); });
@ -69,7 +69,7 @@ describe('getUserMCPAuthMap', () => {
describe('Edge Cases', () => { describe('Edge Cases', () => {
it('should return empty object when no tools have mcpRawServerName', async () => { it('should return empty object when no tools have mcpRawServerName', async () => {
const tools = [ const toolInstances = [
createMockTool('regular_tool', undefined, false), createMockTool('regular_tool', undefined, false),
createMockTool('another_tool', undefined, false), createMockTool('another_tool', undefined, false),
createMockTool('test_mcp_Server_no_raw_name', undefined), createMockTool('test_mcp_Server_no_raw_name', undefined),
@ -77,7 +77,7 @@ describe('getUserMCPAuthMap', () => {
const result = await getUserMCPAuthMap({ const result = await getUserMCPAuthMap({
userId: 'user123', userId: 'user123',
tools, toolInstances,
findPluginAuthsByKeys: mockFindPluginAuthsByKeys, findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
}); });
@ -104,14 +104,14 @@ describe('getUserMCPAuthMap', () => {
}); });
it('should handle database errors gracefully', async () => { it('should handle database errors gracefully', async () => {
const tools = [createMockTool('test_mcp_Server1', 'Server1')]; const toolInstances = [createMockTool('test_mcp_Server1', 'Server1')];
const dbError = new Error('Database connection failed'); const dbError = new Error('Database connection failed');
mockGetPluginAuthMap.mockRejectedValue(dbError); mockGetPluginAuthMap.mockRejectedValue(dbError);
const result = await getUserMCPAuthMap({ const result = await getUserMCPAuthMap({
userId: 'user123', userId: 'user123',
tools, toolInstances,
findPluginAuthsByKeys: mockFindPluginAuthsByKeys, findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
}); });
@ -119,18 +119,119 @@ describe('getUserMCPAuthMap', () => {
}); });
it('should handle non-Error exceptions gracefully', async () => { it('should handle non-Error exceptions gracefully', async () => {
const tools = [createMockTool('test_mcp_Server1', 'Server1')]; const toolInstances = [createMockTool('test_mcp_Server1', 'Server1')];
mockGetPluginAuthMap.mockRejectedValue('String error'); mockGetPluginAuthMap.mockRejectedValue('String error');
const result = await getUserMCPAuthMap({ const result = await getUserMCPAuthMap({
userId: 'user123', userId: 'user123',
tools, toolInstances,
findPluginAuthsByKeys: mockFindPluginAuthsByKeys, findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
}); });
expect(result).toEqual({}); expect(result).toEqual({});
}); });
it('should handle mixed null/undefined values in tools array', async () => {
const tools = [
'test_mcp_Server1',
null,
'test_mcp_Server2',
undefined,
'regular_tool',
'test_mcp_Server3',
];
mockGetPluginAuthMap.mockResolvedValue({
mcp_Server1: { API_KEY: 'key1' },
mcp_Server2: { API_KEY: 'key2' },
mcp_Server3: { API_KEY: 'key3' },
});
const result = await getUserMCPAuthMap({
userId: 'user123',
tools: tools as (string | undefined)[],
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
});
expect(mockGetPluginAuthMap).toHaveBeenCalledWith({
userId: 'user123',
pluginKeys: ['mcp_Server1', 'mcp_Server2', 'mcp_Server3'],
throwError: false,
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
});
expect(result).toEqual({
mcp_Server1: { API_KEY: 'key1' },
mcp_Server2: { API_KEY: 'key2' },
mcp_Server3: { API_KEY: 'key3' },
});
});
it('should handle mixed null/undefined values in servers array', async () => {
const servers = ['Server1', null, 'Server2', undefined, 'Server3'];
mockGetPluginAuthMap.mockResolvedValue({
mcp_Server1: { API_KEY: 'key1' },
mcp_Server2: { API_KEY: 'key2' },
mcp_Server3: { API_KEY: 'key3' },
});
const result = await getUserMCPAuthMap({
userId: 'user123',
servers: servers as (string | undefined)[],
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
});
expect(mockGetPluginAuthMap).toHaveBeenCalledWith({
userId: 'user123',
pluginKeys: ['mcp_Server1', 'mcp_Server2', 'mcp_Server3'],
throwError: false,
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
});
expect(result).toEqual({
mcp_Server1: { API_KEY: 'key1' },
mcp_Server2: { API_KEY: 'key2' },
mcp_Server3: { API_KEY: 'key3' },
});
});
it('should handle mixed null/undefined values in toolInstances array', async () => {
const toolInstances = [
createMockTool('test_mcp_Server1', 'Server1'),
null,
createMockTool('test_mcp_Server2', 'Server2'),
undefined,
createMockTool('regular_tool', undefined, false),
createMockTool('test_mcp_Server3', 'Server3'),
];
mockGetPluginAuthMap.mockResolvedValue({
mcp_Server1: { API_KEY: 'key1' },
mcp_Server2: { API_KEY: 'key2' },
mcp_Server3: { API_KEY: 'key3' },
});
const result = await getUserMCPAuthMap({
userId: 'user123',
toolInstances: toolInstances as (GenericTool | null)[],
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
});
expect(mockGetPluginAuthMap).toHaveBeenCalledWith({
userId: 'user123',
pluginKeys: ['mcp_Server1', 'mcp_Server2', 'mcp_Server3'],
throwError: false,
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
});
expect(result).toEqual({
mcp_Server1: { API_KEY: 'key1' },
mcp_Server2: { API_KEY: 'key2' },
mcp_Server3: { API_KEY: 'key3' },
});
});
}); });
describe('Integration', () => { describe('Integration', () => {
@ -138,7 +239,7 @@ describe('getUserMCPAuthMap', () => {
const originalServerName = 'Connector: Company'; const originalServerName = 'Connector: Company';
const toolName = 'test_auth_mcp_Connector__Company'; const toolName = 'test_auth_mcp_Connector__Company';
const tools = [createMockTool(toolName, originalServerName)]; const toolInstances = [createMockTool(toolName, originalServerName)];
const mockCustomUserVars = { const mockCustomUserVars = {
'mcp_Connector: Company': { 'mcp_Connector: Company': {
@ -151,7 +252,7 @@ describe('getUserMCPAuthMap', () => {
const result = await getUserMCPAuthMap({ const result = await getUserMCPAuthMap({
userId: 'user123', userId: 'user123',
tools, toolInstances,
findPluginAuthsByKeys: mockFindPluginAuthsByKeys, findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
}); });

View file

@ -7,33 +7,56 @@ import { getPluginAuthMap } from '~/agents/auth';
export async function getUserMCPAuthMap({ export async function getUserMCPAuthMap({
userId, userId,
tools, tools,
servers,
toolInstances,
findPluginAuthsByKeys, findPluginAuthsByKeys,
}: { }: {
userId: string; userId: string;
tools: GenericTool[] | undefined; tools?: (string | undefined)[];
servers?: (string | undefined)[];
toolInstances?: (GenericTool | null)[];
findPluginAuthsByKeys: PluginAuthMethods['findPluginAuthsByKeys']; findPluginAuthsByKeys: PluginAuthMethods['findPluginAuthsByKeys'];
}) { }) {
if (!tools || tools.length === 0) {
return {};
}
const uniqueMcpServers = new Set<string>();
for (const tool of tools) {
const mcpTool = tool as GenericTool & { mcpRawServerName?: string };
if (mcpTool.mcpRawServerName) {
uniqueMcpServers.add(`${Constants.mcp_prefix}${mcpTool.mcpRawServerName}`);
}
}
if (uniqueMcpServers.size === 0) {
return {};
}
const mcpPluginKeysToFetch = Array.from(uniqueMcpServers);
let allMcpCustomUserVars: Record<string, Record<string, string>> = {}; let allMcpCustomUserVars: Record<string, Record<string, string>> = {};
let mcpPluginKeysToFetch: string[] = [];
try { try {
const uniqueMcpServers = new Set<string>();
if (servers != null && servers.length) {
for (const serverName of servers) {
if (!serverName) {
continue;
}
uniqueMcpServers.add(`${Constants.mcp_prefix}${serverName}`);
}
} else if (tools != null && tools.length) {
for (const toolName of tools) {
if (!toolName) {
continue;
}
const delimiterIndex = toolName.indexOf(Constants.mcp_delimiter);
if (delimiterIndex === -1) continue;
const mcpServer = toolName.slice(delimiterIndex + Constants.mcp_delimiter.length);
if (!mcpServer) continue;
uniqueMcpServers.add(`${Constants.mcp_prefix}${mcpServer}`);
}
} else if (toolInstances != null && toolInstances.length) {
for (const tool of toolInstances) {
if (!tool) {
continue;
}
const mcpTool = tool as GenericTool & { mcpRawServerName?: string };
if (mcpTool.mcpRawServerName) {
uniqueMcpServers.add(`${Constants.mcp_prefix}${mcpTool.mcpRawServerName}`);
}
}
}
if (uniqueMcpServers.size === 0) {
return {};
}
mcpPluginKeysToFetch = Array.from(uniqueMcpServers);
allMcpCustomUserVars = await getPluginAuthMap({ allMcpCustomUserVars = await getPluginAuthMap({
userId, userId,
pluginKeys: mcpPluginKeysToFetch, pluginKeys: mcpPluginKeysToFetch,

View file

@ -446,7 +446,7 @@ export class MCPConnection extends EventEmitter {
const serverUrl = this.url; const serverUrl = this.url;
logger.debug(`${this.getLogPrefix()} Server URL for OAuth: ${serverUrl}`); logger.debug(`${this.getLogPrefix()} Server URL for OAuth: ${serverUrl}`);
const oauthTimeout = this.options.initTimeout ?? 60000; const oauthTimeout = this.options.initTimeout ?? 60000 * 2;
/** Promise that will resolve when OAuth is handled */ /** Promise that will resolve when OAuth is handled */
const oauthHandledPromise = new Promise<void>((resolve, reject) => { const oauthHandledPromise = new Promise<void>((resolve, reject) => {
let timeoutId: NodeJS.Timeout | null = null; let timeoutId: NodeJS.Timeout | null = null;

View file

@ -134,4 +134,5 @@ export interface OAuthConnectionOptions {
oauthStart?: (authURL: string) => Promise<void>; oauthStart?: (authURL: string) => Promise<void>;
oauthEnd?: () => Promise<void>; oauthEnd?: () => Promise<void>;
returnOnOAuth?: boolean; returnOnOAuth?: boolean;
connectionTimeout?: number;
} }

View file

@ -1,5 +1,6 @@
import { AuthType, Constants, EToolResources } from 'librechat-data-provider'; import { AuthType, Constants, EToolResources } from 'librechat-data-provider';
import type { TCustomConfig, TPlugin, FunctionTool } from 'librechat-data-provider'; import type { TCustomConfig, TPlugin } from 'librechat-data-provider';
import { LCAvailableTools, LCFunctionTool } from '~/mcp/types';
/** /**
* Filters out duplicate plugins from the list of plugins. * Filters out duplicate plugins from the list of plugins.
@ -60,7 +61,7 @@ export function convertMCPToolToPlugin({
customConfig, customConfig,
}: { }: {
toolKey: string; toolKey: string;
toolData: FunctionTool; toolData: LCFunctionTool;
customConfig?: Partial<TCustomConfig> | null; customConfig?: Partial<TCustomConfig> | null;
}): TPlugin | undefined { }): TPlugin | undefined {
if (!toolData.function || !toolKey.includes(Constants.mcp_delimiter)) { if (!toolData.function || !toolKey.includes(Constants.mcp_delimiter)) {
@ -112,7 +113,7 @@ export function convertMCPToolsToPlugins({
functionTools, functionTools,
customConfig, customConfig,
}: { }: {
functionTools?: Record<string, FunctionTool>; functionTools?: LCAvailableTools;
customConfig?: Partial<TCustomConfig> | null; customConfig?: Partial<TCustomConfig> | null;
}): TPlugin[] | undefined { }): TPlugin[] | undefined {
if (!functionTools || typeof functionTools !== 'object') { if (!functionTools || typeof functionTools !== 'object') {

View file

@ -1525,6 +1525,8 @@ export enum Constants {
CONFIG_VERSION = '1.2.8', CONFIG_VERSION = '1.2.8',
/** Standard value for the first message's `parentMessageId` value, to indicate no parent exists. */ /** Standard value for the first message's `parentMessageId` value, to indicate no parent exists. */
NO_PARENT = '00000000-0000-0000-0000-000000000000', NO_PARENT = '00000000-0000-0000-0000-000000000000',
/** Standard value to use whatever the submission prelim. `responseMessageId` is */
USE_PRELIM_RESPONSE_MESSAGE_ID = 'USE_PRELIM_RESPONSE_MESSAGE_ID',
/** Standard value for the initial conversationId before a request is sent */ /** Standard value for the initial conversationId before a request is sent */
NEW_CONVO = 'new', NEW_CONVO = 'new',
/** Standard value for the temporary conversationId after a request is sent and before the server responds */ /** Standard value for the temporary conversationId after a request is sent and before the server responds */
@ -1551,6 +1553,8 @@ export enum Constants {
mcp_delimiter = '_mcp_', mcp_delimiter = '_mcp_',
/** Prefix for MCP plugins */ /** Prefix for MCP plugins */
mcp_prefix = 'mcp_', mcp_prefix = 'mcp_',
/** Unique value to indicate all MCP servers */
mcp_all = 'sys__all__sys',
/** Placeholder Agent ID for Ephemeral Agents */ /** Placeholder Agent ID for Ephemeral Agents */
EPHEMERAL_AGENT_ID = 'ephemeral', EPHEMERAL_AGENT_ID = 'ephemeral',
} }