mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-09-22 06:00:56 +02:00
🚦 feat: Auto-reinitialize MCP Servers on Request (#9226)
This commit is contained in:
parent
ac608ded46
commit
c827fdd10e
28 changed files with 871 additions and 312 deletions
|
@ -3,7 +3,7 @@ const { SerpAPI } = require('@langchain/community/tools/serpapi');
|
||||||
const { Calculator } = require('@langchain/community/tools/calculator');
|
const { Calculator } = require('@langchain/community/tools/calculator');
|
||||||
const { mcpToolPattern, loadWebSearchAuth } = require('@librechat/api');
|
const { mcpToolPattern, loadWebSearchAuth } = require('@librechat/api');
|
||||||
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
|
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
|
||||||
const { Tools, EToolResources, replaceSpecialVars } = require('librechat-data-provider');
|
const { Tools, Constants, EToolResources, replaceSpecialVars } = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
availableTools,
|
availableTools,
|
||||||
manifestToolMap,
|
manifestToolMap,
|
||||||
|
@ -24,9 +24,9 @@ const {
|
||||||
const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
|
const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
|
||||||
const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch');
|
const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch');
|
||||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||||
|
const { createMCPTool, createMCPTools } = require('~/server/services/MCP');
|
||||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||||
const { getCachedTools } = require('~/server/services/Config');
|
const { getCachedTools } = require('~/server/services/Config');
|
||||||
const { createMCPTool } = require('~/server/services/MCP');
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values.
|
* Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values.
|
||||||
|
@ -123,6 +123,8 @@ const getAuthFields = (toolKey) => {
|
||||||
*
|
*
|
||||||
* @param {object} object
|
* @param {object} object
|
||||||
* @param {string} object.user
|
* @param {string} object.user
|
||||||
|
* @param {Record<string, Record<string, string>>} [object.userMCPAuthMap]
|
||||||
|
* @param {AbortSignal} [object.signal]
|
||||||
* @param {Pick<Agent, 'id' | 'provider' | 'model'>} [object.agent]
|
* @param {Pick<Agent, 'id' | 'provider' | 'model'>} [object.agent]
|
||||||
* @param {string} [object.model]
|
* @param {string} [object.model]
|
||||||
* @param {EModelEndpoint} [object.endpoint]
|
* @param {EModelEndpoint} [object.endpoint]
|
||||||
|
@ -137,7 +139,9 @@ const loadTools = async ({
|
||||||
user,
|
user,
|
||||||
agent,
|
agent,
|
||||||
model,
|
model,
|
||||||
|
signal,
|
||||||
endpoint,
|
endpoint,
|
||||||
|
userMCPAuthMap,
|
||||||
tools = [],
|
tools = [],
|
||||||
options = {},
|
options = {},
|
||||||
functions = true,
|
functions = true,
|
||||||
|
@ -231,6 +235,7 @@ const loadTools = async ({
|
||||||
/** @type {Record<string, string>} */
|
/** @type {Record<string, string>} */
|
||||||
const toolContextMap = {};
|
const toolContextMap = {};
|
||||||
const cachedTools = (await getCachedTools({ userId: user, includeGlobal: true })) ?? {};
|
const cachedTools = (await getCachedTools({ userId: user, includeGlobal: true })) ?? {};
|
||||||
|
const requestedMCPTools = {};
|
||||||
|
|
||||||
for (const tool of tools) {
|
for (const tool of tools) {
|
||||||
if (tool === Tools.execute_code) {
|
if (tool === Tools.execute_code) {
|
||||||
|
@ -299,14 +304,35 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
|
||||||
};
|
};
|
||||||
continue;
|
continue;
|
||||||
} else if (tool && cachedTools && mcpToolPattern.test(tool)) {
|
} else if (tool && cachedTools && mcpToolPattern.test(tool)) {
|
||||||
requestedTools[tool] = async () =>
|
const [toolName, serverName] = tool.split(Constants.mcp_delimiter);
|
||||||
|
if (toolName === Constants.mcp_all) {
|
||||||
|
const currentMCPGenerator = async (index) =>
|
||||||
|
createMCPTools({
|
||||||
|
req: options.req,
|
||||||
|
res: options.res,
|
||||||
|
index,
|
||||||
|
serverName,
|
||||||
|
userMCPAuthMap,
|
||||||
|
model: agent?.model ?? model,
|
||||||
|
provider: agent?.provider ?? endpoint,
|
||||||
|
signal,
|
||||||
|
});
|
||||||
|
requestedMCPTools[serverName] = [currentMCPGenerator];
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const currentMCPGenerator = async (index) =>
|
||||||
createMCPTool({
|
createMCPTool({
|
||||||
|
index,
|
||||||
req: options.req,
|
req: options.req,
|
||||||
res: options.res,
|
res: options.res,
|
||||||
toolKey: tool,
|
toolKey: tool,
|
||||||
|
userMCPAuthMap,
|
||||||
model: agent?.model ?? model,
|
model: agent?.model ?? model,
|
||||||
provider: agent?.provider ?? endpoint,
|
provider: agent?.provider ?? endpoint,
|
||||||
|
signal,
|
||||||
});
|
});
|
||||||
|
requestedMCPTools[serverName] = requestedMCPTools[serverName] || [];
|
||||||
|
requestedMCPTools[serverName].push(currentMCPGenerator);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -346,6 +372,34 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
|
||||||
}
|
}
|
||||||
|
|
||||||
const loadedTools = (await Promise.all(toolPromises)).flatMap((plugin) => plugin || []);
|
const loadedTools = (await Promise.all(toolPromises)).flatMap((plugin) => plugin || []);
|
||||||
|
const mcpToolPromises = [];
|
||||||
|
/** MCP server tools are initialized sequentially by server */
|
||||||
|
let index = -1;
|
||||||
|
for (const [serverName, generators] of Object.entries(requestedMCPTools)) {
|
||||||
|
index++;
|
||||||
|
for (const generator of generators) {
|
||||||
|
try {
|
||||||
|
if (generator && generators.length === 1) {
|
||||||
|
mcpToolPromises.push(
|
||||||
|
generator(index).catch((error) => {
|
||||||
|
logger.error(`Error loading ${serverName} tools:`, error);
|
||||||
|
return null;
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const mcpTool = await generator(index);
|
||||||
|
if (Array.isArray(mcpTool)) {
|
||||||
|
loadedTools.push(...mcpTool);
|
||||||
|
} else if (mcpTool) {
|
||||||
|
loadedTools.push(mcpTool);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`Error loading MCP tool for server ${serverName}:`, error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
loadedTools.push(...(await Promise.all(mcpToolPromises)).flatMap((plugin) => plugin || []));
|
||||||
return { loadedTools, toolContextMap };
|
return { loadedTools, toolContextMap };
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ const mongoose = require('mongoose');
|
||||||
const crypto = require('node:crypto');
|
const crypto = require('node:crypto');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { ResourceType, SystemRoles, Tools, actionDelimiter } = require('librechat-data-provider');
|
const { ResourceType, SystemRoles, Tools, actionDelimiter } = require('librechat-data-provider');
|
||||||
const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_delimiter } =
|
const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_all, mcp_delimiter } =
|
||||||
require('librechat-data-provider').Constants;
|
require('librechat-data-provider').Constants;
|
||||||
const {
|
const {
|
||||||
removeAgentFromAllProjects,
|
removeAgentFromAllProjects,
|
||||||
|
@ -78,6 +78,7 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||||
tools.push(Tools.web_search);
|
tools.push(Tools.web_search);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const addedServers = new Set();
|
||||||
if (mcpServers.size > 0) {
|
if (mcpServers.size > 0) {
|
||||||
for (const toolName of Object.keys(availableTools)) {
|
for (const toolName of Object.keys(availableTools)) {
|
||||||
if (!toolName.includes(mcp_delimiter)) {
|
if (!toolName.includes(mcp_delimiter)) {
|
||||||
|
@ -85,9 +86,17 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||||
}
|
}
|
||||||
const mcpServer = toolName.split(mcp_delimiter)?.[1];
|
const mcpServer = toolName.split(mcp_delimiter)?.[1];
|
||||||
if (mcpServer && mcpServers.has(mcpServer)) {
|
if (mcpServer && mcpServers.has(mcpServer)) {
|
||||||
|
addedServers.add(mcpServer);
|
||||||
tools.push(toolName);
|
tools.push(toolName);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (const mcpServer of mcpServers) {
|
||||||
|
if (addedServers.has(mcpServer)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const instructions = req.body.promptPrefix;
|
const instructions = req.body.promptPrefix;
|
||||||
|
|
|
@ -33,18 +33,13 @@ const {
|
||||||
bedrockInputSchema,
|
bedrockInputSchema,
|
||||||
removeNullishValues,
|
removeNullishValues,
|
||||||
} = require('librechat-data-provider');
|
} = require('librechat-data-provider');
|
||||||
const {
|
|
||||||
findPluginAuthsByKeys,
|
|
||||||
getFormattedMemories,
|
|
||||||
deleteMemory,
|
|
||||||
setMemory,
|
|
||||||
} = require('~/models');
|
|
||||||
const { getMCPAuthMap, checkCapability, hasCustomUserVars } = require('~/server/services/Config');
|
|
||||||
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
|
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
|
||||||
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
|
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
|
||||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||||
|
const { getFormattedMemories, deleteMemory, setMemory } = require('~/models');
|
||||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||||
const { getProviderConfig } = require('~/server/services/Endpoints');
|
const { getProviderConfig } = require('~/server/services/Endpoints');
|
||||||
|
const { checkCapability } = require('~/server/services/Config');
|
||||||
const BaseClient = require('~/app/clients/BaseClient');
|
const BaseClient = require('~/app/clients/BaseClient');
|
||||||
const { getRoleByName } = require('~/models/Role');
|
const { getRoleByName } = require('~/models/Role');
|
||||||
const { loadAgent } = require('~/models/Agent');
|
const { loadAgent } = require('~/models/Agent');
|
||||||
|
@ -615,6 +610,7 @@ class AgentClient extends BaseClient {
|
||||||
await this.chatCompletion({
|
await this.chatCompletion({
|
||||||
payload,
|
payload,
|
||||||
onProgress: opts.onProgress,
|
onProgress: opts.onProgress,
|
||||||
|
userMCPAuthMap: opts.userMCPAuthMap,
|
||||||
abortController: opts.abortController,
|
abortController: opts.abortController,
|
||||||
});
|
});
|
||||||
return this.contentParts;
|
return this.contentParts;
|
||||||
|
@ -747,7 +743,13 @@ class AgentClient extends BaseClient {
|
||||||
return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate;
|
return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate;
|
||||||
}
|
}
|
||||||
|
|
||||||
async chatCompletion({ payload, abortController = null }) {
|
/**
|
||||||
|
* @param {object} params
|
||||||
|
* @param {string | ChatCompletionMessageParam[]} params.payload
|
||||||
|
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||||
|
* @param {AbortController} [params.abortController]
|
||||||
|
*/
|
||||||
|
async chatCompletion({ payload, userMCPAuthMap, abortController = null }) {
|
||||||
/** @type {Partial<GraphRunnableConfig>} */
|
/** @type {Partial<GraphRunnableConfig>} */
|
||||||
let config;
|
let config;
|
||||||
/** @type {ReturnType<createRun>} */
|
/** @type {ReturnType<createRun>} */
|
||||||
|
@ -903,21 +905,9 @@ class AgentClient extends BaseClient {
|
||||||
run.Graph.contentData = contentData;
|
run.Graph.contentData = contentData;
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
if (userMCPAuthMap != null) {
|
||||||
if (await hasCustomUserVars()) {
|
config.configurable.userMCPAuthMap = userMCPAuthMap;
|
||||||
config.configurable.userMCPAuthMap = await getMCPAuthMap({
|
|
||||||
tools: agent.tools,
|
|
||||||
userId: this.options.req.user.id,
|
|
||||||
findPluginAuthsByKeys,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
} catch (err) {
|
|
||||||
logger.error(
|
|
||||||
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent ${agent.id}`,
|
|
||||||
err,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
await run.processStream({ messages }, config, {
|
await run.processStream({ messages }, config, {
|
||||||
keepContent: i !== 0,
|
keepContent: i !== 0,
|
||||||
tokenCounter: createTokenCounter(this.getEncoding()),
|
tokenCounter: createTokenCounter(this.getEncoding()),
|
||||||
|
|
|
@ -9,6 +9,24 @@ const {
|
||||||
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
|
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
|
||||||
const { saveMessage } = require('~/models');
|
const { saveMessage } = require('~/models');
|
||||||
|
|
||||||
|
function createCloseHandler(abortController) {
|
||||||
|
return function (manual) {
|
||||||
|
if (!manual) {
|
||||||
|
logger.debug('[AgentController] Request closed');
|
||||||
|
}
|
||||||
|
if (!abortController) {
|
||||||
|
return;
|
||||||
|
} else if (abortController.signal.aborted) {
|
||||||
|
return;
|
||||||
|
} else if (abortController.requestCompleted) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
abortController.abort();
|
||||||
|
logger.debug('[AgentController] Request aborted on close');
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||||
let {
|
let {
|
||||||
text,
|
text,
|
||||||
|
@ -31,7 +49,6 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||||
let userMessagePromise;
|
let userMessagePromise;
|
||||||
let getAbortData;
|
let getAbortData;
|
||||||
let client = null;
|
let client = null;
|
||||||
// Initialize as an array
|
|
||||||
let cleanupHandlers = [];
|
let cleanupHandlers = [];
|
||||||
|
|
||||||
const newConvo = !conversationId;
|
const newConvo = !conversationId;
|
||||||
|
@ -62,9 +79,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||||
// Create a function to handle final cleanup
|
// Create a function to handle final cleanup
|
||||||
const performCleanup = () => {
|
const performCleanup = () => {
|
||||||
logger.debug('[AgentController] Performing cleanup');
|
logger.debug('[AgentController] Performing cleanup');
|
||||||
// Make sure cleanupHandlers is an array before iterating
|
|
||||||
if (Array.isArray(cleanupHandlers)) {
|
if (Array.isArray(cleanupHandlers)) {
|
||||||
// Execute all cleanup handlers
|
|
||||||
for (const handler of cleanupHandlers) {
|
for (const handler of cleanupHandlers) {
|
||||||
try {
|
try {
|
||||||
if (typeof handler === 'function') {
|
if (typeof handler === 'function') {
|
||||||
|
@ -105,8 +120,33 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||||
};
|
};
|
||||||
|
|
||||||
try {
|
try {
|
||||||
/** @type {{ client: TAgentClient }} */
|
let prelimAbortController = new AbortController();
|
||||||
const result = await initializeClient({ req, res, endpointOption });
|
const prelimCloseHandler = createCloseHandler(prelimAbortController);
|
||||||
|
res.on('close', prelimCloseHandler);
|
||||||
|
const removePrelimHandler = (manual) => {
|
||||||
|
try {
|
||||||
|
prelimCloseHandler(manual);
|
||||||
|
res.removeListener('close', prelimCloseHandler);
|
||||||
|
} catch (e) {
|
||||||
|
logger.error('[AgentController] Error removing close listener', e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
cleanupHandlers.push(removePrelimHandler);
|
||||||
|
/** @type {{ client: TAgentClient; userMCPAuthMap?: Record<string, Record<string, string>> }} */
|
||||||
|
const result = await initializeClient({
|
||||||
|
req,
|
||||||
|
res,
|
||||||
|
endpointOption,
|
||||||
|
signal: prelimAbortController.signal,
|
||||||
|
});
|
||||||
|
if (prelimAbortController.signal?.aborted) {
|
||||||
|
prelimAbortController = null;
|
||||||
|
throw new Error('Request was aborted before initialization could complete');
|
||||||
|
} else {
|
||||||
|
prelimAbortController = null;
|
||||||
|
removePrelimHandler(true);
|
||||||
|
cleanupHandlers.pop();
|
||||||
|
}
|
||||||
client = result.client;
|
client = result.client;
|
||||||
|
|
||||||
// Register client with finalization registry if available
|
// Register client with finalization registry if available
|
||||||
|
@ -138,22 +178,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||||
};
|
};
|
||||||
|
|
||||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||||
|
const closeHandler = createCloseHandler(abortController);
|
||||||
// Simple handler to avoid capturing scope
|
|
||||||
const closeHandler = () => {
|
|
||||||
logger.debug('[AgentController] Request closed');
|
|
||||||
if (!abortController) {
|
|
||||||
return;
|
|
||||||
} else if (abortController.signal.aborted) {
|
|
||||||
return;
|
|
||||||
} else if (abortController.requestCompleted) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
abortController.abort();
|
|
||||||
logger.debug('[AgentController] Request aborted on close');
|
|
||||||
};
|
|
||||||
|
|
||||||
res.on('close', closeHandler);
|
res.on('close', closeHandler);
|
||||||
cleanupHandlers.push(() => {
|
cleanupHandlers.push(() => {
|
||||||
try {
|
try {
|
||||||
|
@ -175,6 +200,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||||
abortController,
|
abortController,
|
||||||
overrideParentMessageId,
|
overrideParentMessageId,
|
||||||
isEdited: !!editedContent,
|
isEdited: !!editedContent,
|
||||||
|
userMCPAuthMap: result.userMCPAuthMap,
|
||||||
responseMessageId: editedResponseMessageId,
|
responseMessageId: editedResponseMessageId,
|
||||||
progressOptions: {
|
progressOptions: {
|
||||||
res,
|
res,
|
||||||
|
|
|
@ -11,6 +11,7 @@ jest.mock('@librechat/api', () => ({
|
||||||
completeOAuthFlow: jest.fn(),
|
completeOAuthFlow: jest.fn(),
|
||||||
generateFlowId: jest.fn(),
|
generateFlowId: jest.fn(),
|
||||||
},
|
},
|
||||||
|
getUserMCPAuthMap: jest.fn(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
jest.mock('@librechat/data-schemas', () => ({
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
|
@ -37,6 +38,7 @@ jest.mock('~/models', () => ({
|
||||||
updateToken: jest.fn(),
|
updateToken: jest.fn(),
|
||||||
createToken: jest.fn(),
|
createToken: jest.fn(),
|
||||||
deleteTokens: jest.fn(),
|
deleteTokens: jest.fn(),
|
||||||
|
findPluginAuthsByKeys: jest.fn(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
jest.mock('~/server/services/Config', () => ({
|
jest.mock('~/server/services/Config', () => ({
|
||||||
|
@ -71,6 +73,10 @@ jest.mock('~/server/middleware', () => ({
|
||||||
requireJwtAuth: (req, res, next) => next(),
|
requireJwtAuth: (req, res, next) => next(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Tools/mcp', () => ({
|
||||||
|
reinitMCPServer: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
describe('MCP Routes', () => {
|
describe('MCP Routes', () => {
|
||||||
let app;
|
let app;
|
||||||
let mongoServer;
|
let mongoServer;
|
||||||
|
@ -682,6 +688,13 @@ describe('MCP Routes', () => {
|
||||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
require('~/config').getFlowStateManager.mockReturnValue({});
|
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||||
require('~/cache').getLogStores.mockReturnValue({});
|
require('~/cache').getLogStores.mockReturnValue({});
|
||||||
|
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
|
||||||
|
success: true,
|
||||||
|
message: "MCP server 'oauth-server' ready for OAuth authentication",
|
||||||
|
serverName: 'oauth-server',
|
||||||
|
oauthRequired: true,
|
||||||
|
oauthUrl: 'https://oauth.example.com/auth',
|
||||||
|
});
|
||||||
|
|
||||||
const response = await request(app).post('/api/mcp/oauth-server/reinitialize');
|
const response = await request(app).post('/api/mcp/oauth-server/reinitialize');
|
||||||
|
|
||||||
|
@ -706,6 +719,7 @@ describe('MCP Routes', () => {
|
||||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
require('~/config').getFlowStateManager.mockReturnValue({});
|
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||||
require('~/cache').getLogStores.mockReturnValue({});
|
require('~/cache').getLogStores.mockReturnValue({});
|
||||||
|
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue(null);
|
||||||
|
|
||||||
const response = await request(app).post('/api/mcp/error-server/reinitialize');
|
const response = await request(app).post('/api/mcp/error-server/reinitialize');
|
||||||
|
|
||||||
|
@ -769,6 +783,14 @@ describe('MCP Routes', () => {
|
||||||
setCachedTools.mockResolvedValue();
|
setCachedTools.mockResolvedValue();
|
||||||
updateMCPUserTools.mockResolvedValue();
|
updateMCPUserTools.mockResolvedValue();
|
||||||
|
|
||||||
|
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
|
||||||
|
success: true,
|
||||||
|
message: "MCP server 'test-server' reinitialized successfully",
|
||||||
|
serverName: 'test-server',
|
||||||
|
oauthRequired: false,
|
||||||
|
oauthUrl: null,
|
||||||
|
});
|
||||||
|
|
||||||
const response = await request(app).post('/api/mcp/test-server/reinitialize');
|
const response = await request(app).post('/api/mcp/test-server/reinitialize');
|
||||||
|
|
||||||
expect(response.status).toBe(200);
|
expect(response.status).toBe(200);
|
||||||
|
@ -783,14 +805,6 @@ describe('MCP Routes', () => {
|
||||||
'test-user-id',
|
'test-user-id',
|
||||||
'test-server',
|
'test-server',
|
||||||
);
|
);
|
||||||
expect(updateMCPUserTools).toHaveBeenCalledWith({
|
|
||||||
userId: 'test-user-id',
|
|
||||||
serverName: 'test-server',
|
|
||||||
tools: [
|
|
||||||
{ name: 'tool1', description: 'Test tool 1', inputSchema: { type: 'object' } },
|
|
||||||
{ name: 'tool2', description: 'Test tool 2', inputSchema: { type: 'object' } },
|
|
||||||
],
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle server with custom user variables', async () => {
|
it('should handle server with custom user variables', async () => {
|
||||||
|
@ -812,9 +826,14 @@ describe('MCP Routes', () => {
|
||||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
require('~/config').getFlowStateManager.mockReturnValue({});
|
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||||
require('~/cache').getLogStores.mockReturnValue({});
|
require('~/cache').getLogStores.mockReturnValue({});
|
||||||
require('~/server/services/PluginService').getUserPluginAuthValue.mockResolvedValue(
|
require('@librechat/api').getUserMCPAuthMap.mockResolvedValue({
|
||||||
'api-key-value',
|
'mcp:test-server': {
|
||||||
);
|
API_KEY: 'api-key-value',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
require('~/models').findPluginAuthsByKeys.mockResolvedValue([
|
||||||
|
{ key: 'API_KEY', value: 'api-key-value' },
|
||||||
|
]);
|
||||||
|
|
||||||
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
|
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
|
||||||
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
|
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
|
||||||
|
@ -822,13 +841,23 @@ describe('MCP Routes', () => {
|
||||||
setCachedTools.mockResolvedValue();
|
setCachedTools.mockResolvedValue();
|
||||||
updateMCPUserTools.mockResolvedValue();
|
updateMCPUserTools.mockResolvedValue();
|
||||||
|
|
||||||
|
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
|
||||||
|
success: true,
|
||||||
|
message: "MCP server 'test-server' reinitialized successfully",
|
||||||
|
serverName: 'test-server',
|
||||||
|
oauthRequired: false,
|
||||||
|
oauthUrl: null,
|
||||||
|
});
|
||||||
|
|
||||||
const response = await request(app).post('/api/mcp/test-server/reinitialize');
|
const response = await request(app).post('/api/mcp/test-server/reinitialize');
|
||||||
|
|
||||||
expect(response.status).toBe(200);
|
expect(response.status).toBe(200);
|
||||||
expect(response.body.success).toBe(true);
|
expect(response.body.success).toBe(true);
|
||||||
expect(
|
expect(require('@librechat/api').getUserMCPAuthMap).toHaveBeenCalledWith({
|
||||||
require('~/server/services/PluginService').getUserPluginAuthValue,
|
userId: 'test-user-id',
|
||||||
).toHaveBeenCalledWith('test-user-id', 'API_KEY', false);
|
servers: ['test-server'],
|
||||||
|
findPluginAuthsByKeys: require('~/models').findPluginAuthsByKeys,
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,15 @@
|
||||||
const { logger } = require('@librechat/data-schemas');
|
|
||||||
const { MCPOAuthHandler } = require('@librechat/api');
|
|
||||||
const { Router } = require('express');
|
const { Router } = require('express');
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api');
|
||||||
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
||||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||||
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
|
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
|
||||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||||
|
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
|
||||||
const { requireJwtAuth } = require('~/server/middleware');
|
const { requireJwtAuth } = require('~/server/middleware');
|
||||||
|
const { findPluginAuthsByKeys } = require('~/models');
|
||||||
const { getLogStores } = require('~/cache');
|
const { getLogStores } = require('~/cache');
|
||||||
|
|
||||||
const router = Router();
|
const router = Router();
|
||||||
|
@ -302,107 +304,39 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
|
||||||
const flowManager = getFlowStateManager(flowsCache);
|
|
||||||
|
|
||||||
await mcpManager.disconnectUserConnection(user.id, serverName);
|
await mcpManager.disconnectUserConnection(user.id, serverName);
|
||||||
logger.info(
|
logger.info(
|
||||||
`[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`,
|
`[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`,
|
||||||
);
|
);
|
||||||
|
|
||||||
let customUserVars = {};
|
/** @type {Record<string, Record<string, string>> | undefined} */
|
||||||
|
let userMCPAuthMap;
|
||||||
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
|
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
|
||||||
for (const varName of Object.keys(serverConfig.customUserVars)) {
|
userMCPAuthMap = await getUserMCPAuthMap({
|
||||||
try {
|
userId: user.id,
|
||||||
const value = await getUserPluginAuthValue(user.id, varName, false);
|
servers: [serverName],
|
||||||
customUserVars[varName] = value;
|
findPluginAuthsByKeys,
|
||||||
} catch (err) {
|
});
|
||||||
logger.error(`[MCP Reinitialize] Error fetching ${varName} for user ${user.id}:`, err);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let userConnection = null;
|
const result = await reinitMCPServer({
|
||||||
let oauthRequired = false;
|
req,
|
||||||
let oauthUrl = null;
|
|
||||||
|
|
||||||
try {
|
|
||||||
userConnection = await mcpManager.getUserConnection({
|
|
||||||
user,
|
|
||||||
serverName,
|
serverName,
|
||||||
flowManager,
|
userMCPAuthMap,
|
||||||
customUserVars,
|
|
||||||
tokenMethods: {
|
|
||||||
findToken,
|
|
||||||
updateToken,
|
|
||||||
createToken,
|
|
||||||
deleteTokens,
|
|
||||||
},
|
|
||||||
returnOnOAuth: true,
|
|
||||||
oauthStart: async (authURL) => {
|
|
||||||
logger.info(`[MCP Reinitialize] OAuth URL received: ${authURL}`);
|
|
||||||
oauthUrl = authURL;
|
|
||||||
oauthRequired = true;
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
|
|
||||||
logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`);
|
if (!result) {
|
||||||
} catch (err) {
|
|
||||||
logger.info(`[MCP Reinitialize] getUserConnection threw error: ${err.message}`);
|
|
||||||
logger.info(
|
|
||||||
`[MCP Reinitialize] OAuth state - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
|
|
||||||
);
|
|
||||||
|
|
||||||
const isOAuthError =
|
|
||||||
err.message?.includes('OAuth') ||
|
|
||||||
err.message?.includes('authentication') ||
|
|
||||||
err.message?.includes('401');
|
|
||||||
|
|
||||||
const isOAuthFlowInitiated = err.message === 'OAuth flow initiated - return early';
|
|
||||||
|
|
||||||
if (isOAuthError || oauthRequired || isOAuthFlowInitiated) {
|
|
||||||
logger.info(
|
|
||||||
`[MCP Reinitialize] OAuth required for ${serverName} (isOAuthError: ${isOAuthError}, oauthRequired: ${oauthRequired}, isOAuthFlowInitiated: ${isOAuthFlowInitiated})`,
|
|
||||||
);
|
|
||||||
oauthRequired = true;
|
|
||||||
} else {
|
|
||||||
logger.error(
|
|
||||||
`[MCP Reinitialize] Error initializing MCP server ${serverName} for user:`,
|
|
||||||
err,
|
|
||||||
);
|
|
||||||
return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' });
|
return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' });
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (userConnection && !oauthRequired) {
|
const { success, message, oauthRequired, oauthUrl } = result;
|
||||||
const tools = await userConnection.fetchTools();
|
|
||||||
await updateMCPUserTools({
|
|
||||||
userId: user.id,
|
|
||||||
serverName,
|
|
||||||
tools,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
`[MCP Reinitialize] Sending response for ${serverName} - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
|
|
||||||
);
|
|
||||||
|
|
||||||
const getResponseMessage = () => {
|
|
||||||
if (oauthRequired) {
|
|
||||||
return `MCP server '${serverName}' ready for OAuth authentication`;
|
|
||||||
}
|
|
||||||
if (userConnection) {
|
|
||||||
return `MCP server '${serverName}' reinitialized successfully`;
|
|
||||||
}
|
|
||||||
return `Failed to reinitialize MCP server '${serverName}'`;
|
|
||||||
};
|
|
||||||
|
|
||||||
res.json({
|
res.json({
|
||||||
success: Boolean((userConnection && !oauthRequired) || (oauthRequired && oauthUrl)),
|
success,
|
||||||
message: getResponseMessage(),
|
message,
|
||||||
|
oauthUrl,
|
||||||
serverName,
|
serverName,
|
||||||
oauthRequired,
|
oauthRequired,
|
||||||
oauthUrl,
|
|
||||||
});
|
});
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('[MCP Reinitialize] Unexpected error', error);
|
logger.error('[MCP Reinitialize] Unexpected error', error);
|
||||||
|
|
|
@ -26,7 +26,7 @@ const ToolCacheKeys = {
|
||||||
* @param {string[]} [options.roleIds] - Role IDs for role-based tools
|
* @param {string[]} [options.roleIds] - Role IDs for role-based tools
|
||||||
* @param {string[]} [options.groupIds] - Group IDs for group-based tools
|
* @param {string[]} [options.groupIds] - Group IDs for group-based tools
|
||||||
* @param {boolean} [options.includeGlobal=true] - Whether to include global tools
|
* @param {boolean} [options.includeGlobal=true] - Whether to include global tools
|
||||||
* @returns {Promise<Object|null>} The available tools object or null if not cached
|
* @returns {Promise<LCAvailableTools|null>} The available tools object or null if not cached
|
||||||
*/
|
*/
|
||||||
async function getCachedTools(options = {}) {
|
async function getCachedTools(options = {}) {
|
||||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||||
|
@ -41,13 +41,13 @@ async function getCachedTools(options = {}) {
|
||||||
// Future implementation will merge tools from multiple sources
|
// Future implementation will merge tools from multiple sources
|
||||||
// based on user permissions, roles, and groups
|
// based on user permissions, roles, and groups
|
||||||
if (userId) {
|
if (userId) {
|
||||||
// Check if we have pre-computed effective tools for this user
|
/** @type {LCAvailableTools | null} Check if we have pre-computed effective tools for this user */
|
||||||
const effectiveTools = await cache.get(ToolCacheKeys.EFFECTIVE(userId));
|
const effectiveTools = await cache.get(ToolCacheKeys.EFFECTIVE(userId));
|
||||||
if (effectiveTools) {
|
if (effectiveTools) {
|
||||||
return effectiveTools;
|
return effectiveTools;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, compute from individual sources
|
/** @type {LCAvailableTools | null} Otherwise, compute from individual sources */
|
||||||
const toolSources = [];
|
const toolSources = [];
|
||||||
|
|
||||||
if (includeGlobal) {
|
if (includeGlobal) {
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { isEnabled } = require('@librechat/api');
|
||||||
const { isEnabled, getUserMCPAuthMap } = require('@librechat/api');
|
|
||||||
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
|
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
|
||||||
const { normalizeEndpointName } = require('~/server/utils');
|
const { normalizeEndpointName } = require('~/server/utils');
|
||||||
const loadCustomConfig = require('./loadCustomConfig');
|
const loadCustomConfig = require('./loadCustomConfig');
|
||||||
|
@ -53,31 +52,6 @@ const getCustomEndpointConfig = async (endpoint) => {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
|
||||||
* @param {Object} params
|
|
||||||
* @param {string} params.userId
|
|
||||||
* @param {GenericTool[]} [params.tools]
|
|
||||||
* @param {import('@librechat/data-schemas').PluginAuthMethods['findPluginAuthsByKeys']} params.findPluginAuthsByKeys
|
|
||||||
* @returns {Promise<Record<string, Record<string, string>> | undefined>}
|
|
||||||
*/
|
|
||||||
async function getMCPAuthMap({ userId, tools, findPluginAuthsByKeys }) {
|
|
||||||
try {
|
|
||||||
if (!tools || tools.length === 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
return await getUserMCPAuthMap({
|
|
||||||
tools,
|
|
||||||
userId,
|
|
||||||
findPluginAuthsByKeys,
|
|
||||||
});
|
|
||||||
} catch (err) {
|
|
||||||
logger.error(
|
|
||||||
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent`,
|
|
||||||
err,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @returns {Promise<boolean>}
|
* @returns {Promise<boolean>}
|
||||||
*/
|
*/
|
||||||
|
@ -88,7 +62,6 @@ async function hasCustomUserVars() {
|
||||||
}
|
}
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
getMCPAuthMap,
|
|
||||||
getCustomConfig,
|
getCustomConfig,
|
||||||
getBalanceConfig,
|
getBalanceConfig,
|
||||||
hasCustomUserVars,
|
hasCustomUserVars,
|
||||||
|
|
|
@ -9,7 +9,7 @@ const { getLogStores } = require('~/cache');
|
||||||
* @param {string} params.userId - User ID
|
* @param {string} params.userId - User ID
|
||||||
* @param {string} params.serverName - MCP server name
|
* @param {string} params.serverName - MCP server name
|
||||||
* @param {Array} params.tools - Array of tool objects from MCP server
|
* @param {Array} params.tools - Array of tool objects from MCP server
|
||||||
* @returns {Promise<void>}
|
* @returns {Promise<LCAvailableTools>}
|
||||||
*/
|
*/
|
||||||
async function updateMCPUserTools({ userId, serverName, tools }) {
|
async function updateMCPUserTools({ userId, serverName, tools }) {
|
||||||
try {
|
try {
|
||||||
|
@ -39,6 +39,7 @@ async function updateMCPUserTools({ userId, serverName, tools }) {
|
||||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||||
await cache.delete(CacheKeys.TOOLS);
|
await cache.delete(CacheKeys.TOOLS);
|
||||||
logger.debug(`[MCP Cache] Updated ${tools.length} tools for ${serverName} user ${userId}`);
|
logger.debug(`[MCP Cache] Updated ${tools.length} tools for ${serverName} user ${userId}`);
|
||||||
|
return userTools;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error);
|
logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error);
|
||||||
throw error;
|
throw error;
|
||||||
|
|
|
@ -30,7 +30,13 @@ const { getModelMaxTokens } = require('~/utils');
|
||||||
* @param {TEndpointOption} [params.endpointOption]
|
* @param {TEndpointOption} [params.endpointOption]
|
||||||
* @param {Set<string>} [params.allowedProviders]
|
* @param {Set<string>} [params.allowedProviders]
|
||||||
* @param {boolean} [params.isInitialAgent]
|
* @param {boolean} [params.isInitialAgent]
|
||||||
* @returns {Promise<Agent & { tools: StructuredTool[], attachments: Array<MongoFile>, toolContextMap: Record<string, unknown>, maxContextTokens: number }>}
|
* @returns {Promise<Agent & {
|
||||||
|
* tools: StructuredTool[],
|
||||||
|
* attachments: Array<MongoFile>,
|
||||||
|
* toolContextMap: Record<string, unknown>,
|
||||||
|
* maxContextTokens: number,
|
||||||
|
* userMCPAuthMap?: Record<string, Record<string, string>>
|
||||||
|
* }>}
|
||||||
*/
|
*/
|
||||||
const initializeAgent = async ({
|
const initializeAgent = async ({
|
||||||
req,
|
req,
|
||||||
|
@ -91,8 +97,11 @@ const initializeAgent = async ({
|
||||||
});
|
});
|
||||||
|
|
||||||
const provider = agent.provider;
|
const provider = agent.provider;
|
||||||
const { tools: structuredTools, toolContextMap } =
|
const {
|
||||||
(await loadTools?.({
|
tools: structuredTools,
|
||||||
|
toolContextMap,
|
||||||
|
userMCPAuthMap,
|
||||||
|
} = (await loadTools?.({
|
||||||
req,
|
req,
|
||||||
res,
|
res,
|
||||||
provider,
|
provider,
|
||||||
|
@ -189,6 +198,7 @@ const initializeAgent = async ({
|
||||||
tools,
|
tools,
|
||||||
attachments,
|
attachments,
|
||||||
resendFiles,
|
resendFiles,
|
||||||
|
userMCPAuthMap,
|
||||||
toolContextMap,
|
toolContextMap,
|
||||||
useLegacyContent: !!options.useLegacyContent,
|
useLegacyContent: !!options.useLegacyContent,
|
||||||
maxContextTokens: Math.round((agentMaxContextTokens - maxTokens) * 0.9),
|
maxContextTokens: Math.round((agentMaxContextTokens - maxTokens) * 0.9),
|
||||||
|
|
|
@ -19,7 +19,10 @@ const AgentClient = require('~/server/controllers/agents/client');
|
||||||
const { getAgent } = require('~/models/Agent');
|
const { getAgent } = require('~/models/Agent');
|
||||||
const { logViolation } = require('~/cache');
|
const { logViolation } = require('~/cache');
|
||||||
|
|
||||||
function createToolLoader() {
|
/**
|
||||||
|
* @param {AbortSignal} signal
|
||||||
|
*/
|
||||||
|
function createToolLoader(signal) {
|
||||||
/**
|
/**
|
||||||
* @param {object} params
|
* @param {object} params
|
||||||
* @param {ServerRequest} params.req
|
* @param {ServerRequest} params.req
|
||||||
|
@ -29,7 +32,11 @@ function createToolLoader() {
|
||||||
* @param {string} params.provider
|
* @param {string} params.provider
|
||||||
* @param {string} params.model
|
* @param {string} params.model
|
||||||
* @param {AgentToolResources} params.tool_resources
|
* @param {AgentToolResources} params.tool_resources
|
||||||
* @returns {Promise<{ tools: StructuredTool[], toolContextMap: Record<string, unknown> } | undefined>}
|
* @returns {Promise<{
|
||||||
|
* tools: StructuredTool[],
|
||||||
|
* toolContextMap: Record<string, unknown>,
|
||||||
|
* userMCPAuthMap?: Record<string, Record<string, string>>
|
||||||
|
* } | undefined>}
|
||||||
*/
|
*/
|
||||||
return async function loadTools({ req, res, agentId, tools, provider, model, tool_resources }) {
|
return async function loadTools({ req, res, agentId, tools, provider, model, tool_resources }) {
|
||||||
const agent = { id: agentId, tools, provider, model };
|
const agent = { id: agentId, tools, provider, model };
|
||||||
|
@ -38,6 +45,7 @@ function createToolLoader() {
|
||||||
req,
|
req,
|
||||||
res,
|
res,
|
||||||
agent,
|
agent,
|
||||||
|
signal,
|
||||||
tool_resources,
|
tool_resources,
|
||||||
});
|
});
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
@ -46,7 +54,7 @@ function createToolLoader() {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
const initializeClient = async ({ req, res, endpointOption }) => {
|
const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
||||||
if (!endpointOption) {
|
if (!endpointOption) {
|
||||||
throw new Error('Endpoint option not provided');
|
throw new Error('Endpoint option not provided');
|
||||||
}
|
}
|
||||||
|
@ -92,7 +100,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||||
/** @type {Set<string>} */
|
/** @type {Set<string>} */
|
||||||
const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders);
|
const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders);
|
||||||
|
|
||||||
const loadTools = createToolLoader();
|
const loadTools = createToolLoader(signal);
|
||||||
/** @type {Array<MongoFile>} */
|
/** @type {Array<MongoFile>} */
|
||||||
const requestFiles = req.body.files ?? [];
|
const requestFiles = req.body.files ?? [];
|
||||||
/** @type {string} */
|
/** @type {string} */
|
||||||
|
@ -111,6 +119,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||||
});
|
});
|
||||||
|
|
||||||
const agent_ids = primaryConfig.agent_ids;
|
const agent_ids = primaryConfig.agent_ids;
|
||||||
|
let userMCPAuthMap = primaryConfig.userMCPAuthMap;
|
||||||
if (agent_ids?.length) {
|
if (agent_ids?.length) {
|
||||||
for (const agentId of agent_ids) {
|
for (const agentId of agent_ids) {
|
||||||
const agent = await getAgent({ id: agentId });
|
const agent = await getAgent({ id: agentId });
|
||||||
|
@ -140,6 +149,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||||
endpointOption,
|
endpointOption,
|
||||||
allowedProviders,
|
allowedProviders,
|
||||||
});
|
});
|
||||||
|
Object.assign(userMCPAuthMap, config.userMCPAuthMap ?? {});
|
||||||
agentConfigs.set(agentId, config);
|
agentConfigs.set(agentId, config);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -188,7 +198,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||||
: EModelEndpoint.agents,
|
: EModelEndpoint.agents,
|
||||||
});
|
});
|
||||||
|
|
||||||
return { client };
|
return { client, userMCPAuthMap };
|
||||||
};
|
};
|
||||||
|
|
||||||
module.exports = { initializeClient };
|
module.exports = { initializeClient };
|
||||||
|
|
|
@ -1,7 +1,12 @@
|
||||||
const { z } = require('zod');
|
const { z } = require('zod');
|
||||||
const { tool } = require('@langchain/core/tools');
|
const { tool } = require('@langchain/core/tools');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { Constants: AgentConstants, Providers, GraphEvents } = require('@librechat/agents');
|
const {
|
||||||
|
Providers,
|
||||||
|
StepTypes,
|
||||||
|
GraphEvents,
|
||||||
|
Constants: AgentConstants,
|
||||||
|
} = require('@librechat/agents');
|
||||||
const {
|
const {
|
||||||
sendEvent,
|
sendEvent,
|
||||||
MCPOAuthHandler,
|
MCPOAuthHandler,
|
||||||
|
@ -11,14 +16,14 @@ const {
|
||||||
const {
|
const {
|
||||||
Time,
|
Time,
|
||||||
CacheKeys,
|
CacheKeys,
|
||||||
StepTypes,
|
|
||||||
Constants,
|
Constants,
|
||||||
ContentTypes,
|
ContentTypes,
|
||||||
isAssistantsEndpoint,
|
isAssistantsEndpoint,
|
||||||
} = require('librechat-data-provider');
|
} = require('librechat-data-provider');
|
||||||
|
const { getCachedTools, loadCustomConfig } = require('./Config');
|
||||||
const { findToken, createToken, updateToken } = require('~/models');
|
const { findToken, createToken, updateToken } = require('~/models');
|
||||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||||
const { getCachedTools, loadCustomConfig } = require('./Config');
|
const { reinitMCPServer } = require('./Tools/mcp');
|
||||||
const { getLogStores } = require('~/cache');
|
const { getLogStores } = require('~/cache');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -26,16 +31,13 @@ const { getLogStores } = require('~/cache');
|
||||||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||||
* @param {string} params.stepId - The ID of the step in the flow.
|
* @param {string} params.stepId - The ID of the step in the flow.
|
||||||
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
|
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
|
||||||
* @param {string} params.loginFlowId - The ID of the login flow.
|
|
||||||
* @param {FlowStateManager<any>} params.flowManager - The flow manager instance.
|
|
||||||
*/
|
*/
|
||||||
function createOAuthStart({ res, stepId, toolCall, loginFlowId, flowManager, signal }) {
|
function createRunStepDeltaEmitter({ res, stepId, toolCall }) {
|
||||||
/**
|
/**
|
||||||
* Creates a function to handle OAuth login requests.
|
|
||||||
* @param {string} authURL - The URL to redirect the user for OAuth authentication.
|
* @param {string} authURL - The URL to redirect the user for OAuth authentication.
|
||||||
* @returns {Promise<boolean>} Returns true to indicate the event was sent successfully.
|
* @returns {void}
|
||||||
*/
|
*/
|
||||||
return async function (authURL) {
|
return function (authURL) {
|
||||||
/** @type {{ id: string; delta: AgentToolCallDelta }} */
|
/** @type {{ id: string; delta: AgentToolCallDelta }} */
|
||||||
const data = {
|
const data = {
|
||||||
id: stepId,
|
id: stepId,
|
||||||
|
@ -46,17 +48,54 @@ function createOAuthStart({ res, stepId, toolCall, loginFlowId, flowManager, sig
|
||||||
expires_at: Date.now() + Time.TWO_MINUTES,
|
expires_at: Date.now() + Time.TWO_MINUTES,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
/** Used to ensure the handler (use of `sendEvent`) is only invoked once */
|
|
||||||
await flowManager.createFlowWithHandler(
|
|
||||||
loginFlowId,
|
|
||||||
'oauth_login',
|
|
||||||
async () => {
|
|
||||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
|
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param {object} params
|
||||||
|
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||||
|
* @param {string} params.runId - The Run ID, i.e. message ID
|
||||||
|
* @param {string} params.stepId - The ID of the step in the flow.
|
||||||
|
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
|
||||||
|
* @param {number} [params.index]
|
||||||
|
*/
|
||||||
|
function createRunStepEmitter({ res, runId, stepId, toolCall, index }) {
|
||||||
|
return function () {
|
||||||
|
/** @type {import('@librechat/agents').RunStep} */
|
||||||
|
const data = {
|
||||||
|
runId: runId ?? Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
|
||||||
|
id: stepId,
|
||||||
|
type: StepTypes.TOOL_CALLS,
|
||||||
|
index: index ?? 0,
|
||||||
|
stepDetails: {
|
||||||
|
type: StepTypes.TOOL_CALLS,
|
||||||
|
tool_calls: [toolCall],
|
||||||
|
},
|
||||||
|
};
|
||||||
|
sendEvent(res, { event: GraphEvents.ON_RUN_STEP, data });
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a function used to ensure the flow handler is only invoked once
|
||||||
|
* @param {object} params
|
||||||
|
* @param {string} params.flowId - The ID of the login flow.
|
||||||
|
* @param {FlowStateManager<any>} params.flowManager - The flow manager instance.
|
||||||
|
* @param {(authURL: string) => void} [params.callback]
|
||||||
|
*/
|
||||||
|
function createOAuthStart({ flowId, flowManager, callback }) {
|
||||||
|
/**
|
||||||
|
* Creates a function to handle OAuth login requests.
|
||||||
|
* @param {string} authURL - The URL to redirect the user for OAuth authentication.
|
||||||
|
* @returns {Promise<boolean>} Returns true to indicate the event was sent successfully.
|
||||||
|
*/
|
||||||
|
return async function (authURL) {
|
||||||
|
await flowManager.createFlowWithHandler(flowId, 'oauth_login', async () => {
|
||||||
|
callback?.(authURL);
|
||||||
logger.debug('Sent OAuth login request to client');
|
logger.debug('Sent OAuth login request to client');
|
||||||
return true;
|
return true;
|
||||||
},
|
});
|
||||||
signal,
|
|
||||||
);
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,23 +138,166 @@ function createAbortHandler({ userId, serverName, toolName, flowManager }) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a general tool for an entire action set.
|
* @param {Object} params
|
||||||
|
* @param {() => void} params.runStepEmitter
|
||||||
|
* @param {(authURL: string) => void} params.runStepDeltaEmitter
|
||||||
|
* @returns {(authURL: string) => void}
|
||||||
|
*/
|
||||||
|
function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
|
||||||
|
return function (authURL) {
|
||||||
|
runStepEmitter();
|
||||||
|
runStepDeltaEmitter(authURL);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param {Object} params
|
||||||
|
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
|
||||||
|
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||||
|
* @param {string} params.serverName
|
||||||
|
* @param {AbortSignal} params.signal
|
||||||
|
* @param {string} params.model
|
||||||
|
* @param {number} [params.index]
|
||||||
|
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||||
|
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
|
||||||
|
*/
|
||||||
|
async function reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap }) {
|
||||||
|
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
|
||||||
|
const flowId = `${req.user?.id}:${serverName}:${Date.now()}`;
|
||||||
|
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
|
||||||
|
const stepId = 'step_oauth_login_' + serverName;
|
||||||
|
const toolCall = {
|
||||||
|
id: flowId,
|
||||||
|
name: serverName,
|
||||||
|
type: 'tool_call_chunk',
|
||||||
|
};
|
||||||
|
|
||||||
|
const runStepEmitter = createRunStepEmitter({
|
||||||
|
res,
|
||||||
|
index,
|
||||||
|
runId,
|
||||||
|
stepId,
|
||||||
|
toolCall,
|
||||||
|
});
|
||||||
|
const runStepDeltaEmitter = createRunStepDeltaEmitter({
|
||||||
|
res,
|
||||||
|
stepId,
|
||||||
|
toolCall,
|
||||||
|
});
|
||||||
|
const callback = createOAuthCallback({ runStepEmitter, runStepDeltaEmitter });
|
||||||
|
const oauthStart = createOAuthStart({
|
||||||
|
res,
|
||||||
|
flowId,
|
||||||
|
callback,
|
||||||
|
flowManager,
|
||||||
|
});
|
||||||
|
return await reinitMCPServer({
|
||||||
|
req,
|
||||||
|
signal,
|
||||||
|
serverName,
|
||||||
|
oauthStart,
|
||||||
|
flowManager,
|
||||||
|
userMCPAuthMap,
|
||||||
|
forceNew: true,
|
||||||
|
returnOnOAuth: false,
|
||||||
|
connectionTimeout: Time.TWO_MINUTES,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates all tools from the specified MCP Server via `toolKey`.
|
||||||
*
|
*
|
||||||
* @param {Object} params - The parameters for loading action sets.
|
* This function assumes tools could not be aggregated from the cache of tool definitions,
|
||||||
|
* i.e. `availableTools`, and will reinitialize the MCP server to ensure all tools are generated.
|
||||||
|
*
|
||||||
|
* @param {Object} params
|
||||||
|
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
|
||||||
|
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||||
|
* @param {string} params.serverName
|
||||||
|
* @param {string} params.model
|
||||||
|
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
|
||||||
|
* @param {number} [params.index]
|
||||||
|
* @param {AbortSignal} [params.signal]
|
||||||
|
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||||
|
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
|
||||||
|
*/
|
||||||
|
async function createMCPTools({ req, res, index, signal, serverName, provider, userMCPAuthMap }) {
|
||||||
|
const result = await reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap });
|
||||||
|
if (!result || !result.tools) {
|
||||||
|
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const serverTools = [];
|
||||||
|
for (const tool of result.tools) {
|
||||||
|
const toolInstance = await createMCPTool({
|
||||||
|
req,
|
||||||
|
res,
|
||||||
|
provider,
|
||||||
|
userMCPAuthMap,
|
||||||
|
availableTools: result.availableTools,
|
||||||
|
toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`,
|
||||||
|
});
|
||||||
|
if (toolInstance) {
|
||||||
|
serverTools.push(toolInstance);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return serverTools;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a single tool from the specified MCP Server via `toolKey`.
|
||||||
|
* @param {Object} params
|
||||||
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
|
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
|
||||||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||||
* @param {string} params.toolKey - The toolKey for the tool.
|
* @param {string} params.toolKey - The toolKey for the tool.
|
||||||
* @param {import('@librechat/agents').Providers | EModelEndpoint} params.provider - The provider for the tool.
|
|
||||||
* @param {string} params.model - The model for the tool.
|
* @param {string} params.model - The model for the tool.
|
||||||
|
* @param {number} [params.index]
|
||||||
|
* @param {AbortSignal} [params.signal]
|
||||||
|
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
|
||||||
|
* @param {LCAvailableTools} [params.availableTools]
|
||||||
|
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||||
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
|
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
|
||||||
*/
|
*/
|
||||||
async function createMCPTool({ req, res, toolKey, provider: _provider }) {
|
async function createMCPTool({
|
||||||
const availableTools = await getCachedTools({ userId: req.user?.id, includeGlobal: true });
|
req,
|
||||||
const toolDefinition = availableTools?.[toolKey]?.function;
|
res,
|
||||||
|
index,
|
||||||
|
signal,
|
||||||
|
toolKey,
|
||||||
|
provider,
|
||||||
|
userMCPAuthMap,
|
||||||
|
availableTools: tools,
|
||||||
|
}) {
|
||||||
|
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
|
||||||
|
const availableTools =
|
||||||
|
tools ?? (await getCachedTools({ userId: req.user?.id, includeGlobal: true }));
|
||||||
|
/** @type {LCTool | undefined} */
|
||||||
|
let toolDefinition = availableTools?.[toolKey]?.function;
|
||||||
if (!toolDefinition) {
|
if (!toolDefinition) {
|
||||||
logger.error(`Tool ${toolKey} not found in available tools`);
|
logger.warn(
|
||||||
return null;
|
`[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`,
|
||||||
|
);
|
||||||
|
const result = await reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap });
|
||||||
|
toolDefinition = result?.availableTools?.[toolKey]?.function;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!toolDefinition) {
|
||||||
|
logger.warn(`[MCP][${serverName}][${toolName}] Tool definition not found, cannot create tool.`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return createToolInstance({
|
||||||
|
res,
|
||||||
|
provider,
|
||||||
|
toolName,
|
||||||
|
serverName,
|
||||||
|
toolDefinition,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function createToolInstance({ res, toolName, serverName, toolDefinition, provider: _provider }) {
|
||||||
/** @type {LCTool} */
|
/** @type {LCTool} */
|
||||||
const { description, parameters } = toolDefinition;
|
const { description, parameters } = toolDefinition;
|
||||||
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
|
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
|
||||||
|
@ -128,16 +310,8 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
|
||||||
schema = z.object({ input: z.string().optional() });
|
schema = z.object({ input: z.string().optional() });
|
||||||
}
|
}
|
||||||
|
|
||||||
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
|
|
||||||
const normalizedToolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`;
|
const normalizedToolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`;
|
||||||
|
|
||||||
if (!req.user?.id) {
|
|
||||||
logger.error(
|
|
||||||
`[MCP][${serverName}][${toolName}] User ID not found on request. Cannot create tool.`,
|
|
||||||
);
|
|
||||||
throw new Error(`User ID not found on request. Cannot create tool for ${toolKey}.`);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise<unknown>} */
|
/** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise<unknown>} */
|
||||||
const _call = async (toolArguments, config) => {
|
const _call = async (toolArguments, config) => {
|
||||||
const userId = config?.configurable?.user?.id || config?.configurable?.user_id;
|
const userId = config?.configurable?.user?.id || config?.configurable?.user_id;
|
||||||
|
@ -154,14 +328,16 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
|
||||||
const provider = (config?.metadata?.provider || _provider)?.toLowerCase();
|
const provider = (config?.metadata?.provider || _provider)?.toLowerCase();
|
||||||
|
|
||||||
const { args: _args, stepId, ...toolCall } = config.toolCall ?? {};
|
const { args: _args, stepId, ...toolCall } = config.toolCall ?? {};
|
||||||
const loginFlowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`;
|
const flowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`;
|
||||||
const oauthStart = createOAuthStart({
|
const runStepDeltaEmitter = createRunStepDeltaEmitter({
|
||||||
res,
|
res,
|
||||||
stepId,
|
stepId,
|
||||||
toolCall,
|
toolCall,
|
||||||
loginFlowId,
|
});
|
||||||
|
const oauthStart = createOAuthStart({
|
||||||
|
flowId,
|
||||||
flowManager,
|
flowManager,
|
||||||
signal: derivedSignal,
|
callback: runStepDeltaEmitter,
|
||||||
});
|
});
|
||||||
const oauthEnd = createOAuthEnd({
|
const oauthEnd = createOAuthEnd({
|
||||||
res,
|
res,
|
||||||
|
@ -207,7 +383,7 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
|
||||||
return result;
|
return result;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(
|
logger.error(
|
||||||
`[MCP][User: ${userId}][${serverName}] Error calling "${toolName}" MCP tool:`,
|
`[MCP][${serverName}][${toolName}][User: ${userId}] Error calling MCP tool:`,
|
||||||
error,
|
error,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -220,12 +396,12 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
|
||||||
|
|
||||||
if (isOAuthError) {
|
if (isOAuthError) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`OAuth authentication required for ${serverName}. Please check the server logs for the authentication URL.`,
|
`[MCP][${serverName}][${toolName}] OAuth authentication required. Please check the server logs for the authentication URL.`,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`"${toolKey}" tool call failed${error?.message ? `: ${error?.message}` : '.'}`,
|
`[MCP][${serverName}][${toolName}] tool call failed${error?.message ? `: ${error?.message}` : '.'}`,
|
||||||
);
|
);
|
||||||
} finally {
|
} finally {
|
||||||
// Clean up abort handler to prevent memory leaks
|
// Clean up abort handler to prevent memory leaks
|
||||||
|
@ -380,6 +556,7 @@ async function getServerConnectionStatus(
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
createMCPTool,
|
createMCPTool,
|
||||||
|
createMCPTools,
|
||||||
getMCPSetupData,
|
getMCPSetupData,
|
||||||
checkOAuthFlowStatus,
|
checkOAuthFlowStatus,
|
||||||
getServerConnectionStatus,
|
getServerConnectionStatus,
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
const fs = require('fs');
|
const fs = require('fs');
|
||||||
const path = require('path');
|
const path = require('path');
|
||||||
const { sleep } = require('@librechat/agents');
|
const { sleep } = require('@librechat/agents');
|
||||||
const { getToolkitKey } = require('@librechat/api');
|
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { zodToJsonSchema } = require('zod-to-json-schema');
|
const { zodToJsonSchema } = require('zod-to-json-schema');
|
||||||
|
const { getToolkitKey, getUserMCPAuthMap } = require('@librechat/api');
|
||||||
const { Calculator } = require('@langchain/community/tools/calculator');
|
const { Calculator } = require('@langchain/community/tools/calculator');
|
||||||
const { tool: toolFn, Tool, DynamicStructuredTool } = require('@langchain/core/tools');
|
const { tool: toolFn, Tool, DynamicStructuredTool } = require('@langchain/core/tools');
|
||||||
const {
|
const {
|
||||||
|
@ -33,12 +33,17 @@ const {
|
||||||
toolkits,
|
toolkits,
|
||||||
} = require('~/app/clients/tools');
|
} = require('~/app/clients/tools');
|
||||||
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
|
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
|
||||||
const { getEndpointsConfig, getCachedTools } = require('~/server/services/Config');
|
const {
|
||||||
|
getEndpointsConfig,
|
||||||
|
hasCustomUserVars,
|
||||||
|
getCachedTools,
|
||||||
|
} = require('~/server/services/Config');
|
||||||
const { createOnSearchResults } = require('~/server/services/Tools/search');
|
const { createOnSearchResults } = require('~/server/services/Tools/search');
|
||||||
const { isActionDomainAllowed } = require('~/server/services/domains');
|
const { isActionDomainAllowed } = require('~/server/services/domains');
|
||||||
const { recordUsage } = require('~/server/services/Threads');
|
const { recordUsage } = require('~/server/services/Threads');
|
||||||
const { loadTools } = require('~/app/clients/tools/util');
|
const { loadTools } = require('~/app/clients/tools/util');
|
||||||
const { redactMessage } = require('~/config/parsers');
|
const { redactMessage } = require('~/config/parsers');
|
||||||
|
const { findPluginAuthsByKeys } = require('~/models');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads and formats tools from the specified tool directory.
|
* Loads and formats tools from the specified tool directory.
|
||||||
|
@ -469,11 +474,12 @@ async function processRequiredActions(client, requiredActions) {
|
||||||
* @param {Object} params - Run params containing user and request information.
|
* @param {Object} params - Run params containing user and request information.
|
||||||
* @param {ServerRequest} params.req - The request object.
|
* @param {ServerRequest} params.req - The request object.
|
||||||
* @param {ServerResponse} params.res - The request object.
|
* @param {ServerResponse} params.res - The request object.
|
||||||
|
* @param {AbortSignal} params.signal
|
||||||
* @param {Pick<Agent, 'id' | 'provider' | 'model' | 'tools'} params.agent - The agent to load tools for.
|
* @param {Pick<Agent, 'id' | 'provider' | 'model' | 'tools'} params.agent - The agent to load tools for.
|
||||||
* @param {string | undefined} [params.openAIApiKey] - The OpenAI API key.
|
* @param {string | undefined} [params.openAIApiKey] - The OpenAI API key.
|
||||||
* @returns {Promise<{ tools?: StructuredTool[] }>} The agent tools.
|
* @returns {Promise<{ tools?: StructuredTool[]; userMCPAuthMap?: Record<string, Record<string, string>> }>} The agent tools.
|
||||||
*/
|
*/
|
||||||
async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) {
|
async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIApiKey }) {
|
||||||
if (!agent.tools || agent.tools.length === 0) {
|
if (!agent.tools || agent.tools.length === 0) {
|
||||||
return {};
|
return {};
|
||||||
} else if (agent.tools && agent.tools.length === 1 && agent.tools[0] === AgentCapabilities.ocr) {
|
} else if (agent.tools && agent.tools.length === 1 && agent.tools[0] === AgentCapabilities.ocr) {
|
||||||
|
@ -523,8 +529,20 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
|
||||||
webSearchCallbacks = createOnSearchResults(res);
|
webSearchCallbacks = createOnSearchResults(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** @type {Record<string, Record<string, string>>} */
|
||||||
|
let userMCPAuthMap;
|
||||||
|
if (await hasCustomUserVars()) {
|
||||||
|
userMCPAuthMap = await getUserMCPAuthMap({
|
||||||
|
tools: agent.tools,
|
||||||
|
userId: req.user.id,
|
||||||
|
findPluginAuthsByKeys,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
const { loadedTools, toolContextMap } = await loadTools({
|
const { loadedTools, toolContextMap } = await loadTools({
|
||||||
agent,
|
agent,
|
||||||
|
signal,
|
||||||
|
userMCPAuthMap,
|
||||||
functions: true,
|
functions: true,
|
||||||
user: req.user.id,
|
user: req.user.id,
|
||||||
tools: _agentTools,
|
tools: _agentTools,
|
||||||
|
@ -588,6 +606,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
|
||||||
if (!checkCapability(AgentCapabilities.actions)) {
|
if (!checkCapability(AgentCapabilities.actions)) {
|
||||||
return {
|
return {
|
||||||
tools: agentTools,
|
tools: agentTools,
|
||||||
|
userMCPAuthMap,
|
||||||
toolContextMap,
|
toolContextMap,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -599,6 +618,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
|
||||||
}
|
}
|
||||||
return {
|
return {
|
||||||
tools: agentTools,
|
tools: agentTools,
|
||||||
|
userMCPAuthMap,
|
||||||
toolContextMap,
|
toolContextMap,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -707,6 +727,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
|
||||||
return {
|
return {
|
||||||
tools: agentTools,
|
tools: agentTools,
|
||||||
toolContextMap,
|
toolContextMap,
|
||||||
|
userMCPAuthMap,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
142
api/server/services/Tools/mcp.js
Normal file
142
api/server/services/Tools/mcp.js
Normal file
|
@ -0,0 +1,142 @@
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||||
|
const { findToken, createToken, updateToken, deleteTokens } = require('~/models');
|
||||||
|
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||||
|
const { updateMCPUserTools } = require('~/server/services/Config');
|
||||||
|
const { getLogStores } = require('~/cache');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param {Object} params
|
||||||
|
* @param {ServerRequest} params.req
|
||||||
|
* @param {string} params.serverName - The name of the MCP server
|
||||||
|
* @param {boolean} params.returnOnOAuth - Whether to initiate OAuth and return, or wait for OAuth flow to finish
|
||||||
|
* @param {AbortSignal} [params.signal] - The abort signal to handle cancellation.
|
||||||
|
* @param {boolean} [params.forceNew]
|
||||||
|
* @param {number} [params.connectionTimeout]
|
||||||
|
* @param {FlowStateManager<any>} [params.flowManager]
|
||||||
|
* @param {(authURL: string) => Promise<boolean>} [params.oauthStart]
|
||||||
|
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||||
|
*/
|
||||||
|
async function reinitMCPServer({
|
||||||
|
req,
|
||||||
|
signal,
|
||||||
|
forceNew,
|
||||||
|
serverName,
|
||||||
|
userMCPAuthMap,
|
||||||
|
connectionTimeout,
|
||||||
|
returnOnOAuth = true,
|
||||||
|
oauthStart: _oauthStart,
|
||||||
|
flowManager: _flowManager,
|
||||||
|
}) {
|
||||||
|
/** @type {MCPConnection | null} */
|
||||||
|
let userConnection = null;
|
||||||
|
/** @type {LCAvailableTools | null} */
|
||||||
|
let availableTools = null;
|
||||||
|
/** @type {ReturnType<MCPConnection['fetchTools']> | null} */
|
||||||
|
let tools = null;
|
||||||
|
let oauthRequired = false;
|
||||||
|
let oauthUrl = null;
|
||||||
|
try {
|
||||||
|
const customUserVars = userMCPAuthMap?.[`${Constants.mcp_prefix}${serverName}`];
|
||||||
|
const flowManager = _flowManager ?? getFlowStateManager(getLogStores(CacheKeys.FLOWS));
|
||||||
|
const mcpManager = getMCPManager();
|
||||||
|
|
||||||
|
const oauthStart =
|
||||||
|
_oauthStart ??
|
||||||
|
(async (authURL) => {
|
||||||
|
logger.info(`[MCP Reinitialize] OAuth URL received: ${authURL}`);
|
||||||
|
oauthUrl = authURL;
|
||||||
|
oauthRequired = true;
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
userConnection = await mcpManager.getUserConnection({
|
||||||
|
user: req.user,
|
||||||
|
signal,
|
||||||
|
forceNew,
|
||||||
|
oauthStart,
|
||||||
|
serverName,
|
||||||
|
flowManager,
|
||||||
|
returnOnOAuth,
|
||||||
|
customUserVars,
|
||||||
|
connectionTimeout,
|
||||||
|
tokenMethods: {
|
||||||
|
findToken,
|
||||||
|
updateToken,
|
||||||
|
createToken,
|
||||||
|
deleteTokens,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`);
|
||||||
|
} catch (err) {
|
||||||
|
logger.info(`[MCP Reinitialize] getUserConnection threw error: ${err.message}`);
|
||||||
|
logger.info(
|
||||||
|
`[MCP Reinitialize] OAuth state - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
|
||||||
|
);
|
||||||
|
|
||||||
|
const isOAuthError =
|
||||||
|
err.message?.includes('OAuth') ||
|
||||||
|
err.message?.includes('authentication') ||
|
||||||
|
err.message?.includes('401');
|
||||||
|
|
||||||
|
const isOAuthFlowInitiated = err.message === 'OAuth flow initiated - return early';
|
||||||
|
|
||||||
|
if (isOAuthError || oauthRequired || isOAuthFlowInitiated) {
|
||||||
|
logger.info(
|
||||||
|
`[MCP Reinitialize] OAuth required for ${serverName} (isOAuthError: ${isOAuthError}, oauthRequired: ${oauthRequired}, isOAuthFlowInitiated: ${isOAuthFlowInitiated})`,
|
||||||
|
);
|
||||||
|
oauthRequired = true;
|
||||||
|
} else {
|
||||||
|
logger.error(
|
||||||
|
`[MCP Reinitialize] Error initializing MCP server ${serverName} for user:`,
|
||||||
|
err,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (userConnection && !oauthRequired) {
|
||||||
|
tools = await userConnection.fetchTools();
|
||||||
|
availableTools = await updateMCPUserTools({
|
||||||
|
userId: req.user.id,
|
||||||
|
serverName,
|
||||||
|
tools,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
`[MCP Reinitialize] Sending response for ${serverName} - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
|
||||||
|
);
|
||||||
|
|
||||||
|
const getResponseMessage = () => {
|
||||||
|
if (oauthRequired) {
|
||||||
|
return `MCP server '${serverName}' ready for OAuth authentication`;
|
||||||
|
}
|
||||||
|
if (userConnection) {
|
||||||
|
return `MCP server '${serverName}' reinitialized successfully`;
|
||||||
|
}
|
||||||
|
return `Failed to reinitialize MCP server '${serverName}'`;
|
||||||
|
};
|
||||||
|
|
||||||
|
const result = {
|
||||||
|
availableTools,
|
||||||
|
success: Boolean((userConnection && !oauthRequired) || (oauthRequired && oauthUrl)),
|
||||||
|
message: getResponseMessage(),
|
||||||
|
oauthRequired,
|
||||||
|
serverName,
|
||||||
|
oauthUrl,
|
||||||
|
tools,
|
||||||
|
};
|
||||||
|
logger.debug(`[MCP Reinitialize] Response for ${serverName}:`, result);
|
||||||
|
return result;
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(
|
||||||
|
'[MCP Reinitialize] Error loading MCP Tools, servers may still be initializing:',
|
||||||
|
error,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
reinitMCPServer,
|
||||||
|
};
|
|
@ -1115,6 +1115,18 @@
|
||||||
* @memberof typedefs
|
* @memberof typedefs
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @exports MCPConnection
|
||||||
|
* @typedef {import('@librechat/api').MCPConnection} MCPConnection
|
||||||
|
* @memberof typedefs
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @exports LCFunctionTool
|
||||||
|
* @typedef {import('@librechat/api').LCFunctionTool} LCFunctionTool
|
||||||
|
* @memberof typedefs
|
||||||
|
*/
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @exports FlowStateManager
|
* @exports FlowStateManager
|
||||||
* @typedef {import('@librechat/api').FlowStateManager} FlowStateManager
|
* @typedef {import('@librechat/api').FlowStateManager} FlowStateManager
|
||||||
|
@ -1825,6 +1837,7 @@
|
||||||
* @param {object} opts - Options for the completion
|
* @param {object} opts - Options for the completion
|
||||||
* @param {onTokenProgress} opts.onProgress - Callback function to handle token progress
|
* @param {onTokenProgress} opts.onProgress - Callback function to handle token progress
|
||||||
* @param {AbortController} opts.abortController - AbortController instance
|
* @param {AbortController} opts.abortController - AbortController instance
|
||||||
|
* @param {Record<string, Record<string, string>>} [opts.userMCPAuthMap]
|
||||||
* @returns {Promise<string>}
|
* @returns {Promise<string>}
|
||||||
* @memberof typedefs
|
* @memberof typedefs
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -230,15 +230,19 @@ export default function useChatFunctions({
|
||||||
|
|
||||||
const responseMessageId =
|
const responseMessageId =
|
||||||
editedMessageId ??
|
editedMessageId ??
|
||||||
(latestMessage?.messageId && isRegenerate ? latestMessage?.messageId + '_' : null) ??
|
(latestMessage?.messageId && isRegenerate
|
||||||
|
? latestMessage.messageId.replace(/_+$/, '') + '_'
|
||||||
|
: null) ??
|
||||||
null;
|
null;
|
||||||
|
const initialResponseId =
|
||||||
|
responseMessageId ?? `${isRegenerate ? messageId : intermediateId}`.replace(/_+$/, '') + '_';
|
||||||
|
|
||||||
const initialResponse: TMessage = {
|
const initialResponse: TMessage = {
|
||||||
sender: responseSender,
|
sender: responseSender,
|
||||||
text: '',
|
text: '',
|
||||||
endpoint: endpoint ?? '',
|
endpoint: endpoint ?? '',
|
||||||
parentMessageId: isRegenerate ? messageId : intermediateId,
|
parentMessageId: isRegenerate ? messageId : intermediateId,
|
||||||
messageId: responseMessageId ?? `${isRegenerate ? messageId : intermediateId}_`,
|
messageId: initialResponseId,
|
||||||
thread_id,
|
thread_id,
|
||||||
conversationId,
|
conversationId,
|
||||||
unfinished: false,
|
unfinished: false,
|
||||||
|
|
|
@ -182,7 +182,7 @@ export default function useEventHandlers({
|
||||||
const { token } = useAuthContext();
|
const { token } = useAuthContext();
|
||||||
|
|
||||||
const contentHandler = useContentHandler({ setMessages, getMessages });
|
const contentHandler = useContentHandler({ setMessages, getMessages });
|
||||||
const stepHandler = useStepHandler({
|
const { stepHandler, clearStepMaps } = useStepHandler({
|
||||||
setMessages,
|
setMessages,
|
||||||
getMessages,
|
getMessages,
|
||||||
announcePolite,
|
announcePolite,
|
||||||
|
@ -806,6 +806,7 @@ export default function useEventHandlers({
|
||||||
);
|
);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
clearStepMaps,
|
||||||
stepHandler,
|
stepHandler,
|
||||||
syncHandler,
|
syncHandler,
|
||||||
finalHandler,
|
finalHandler,
|
||||||
|
|
|
@ -62,6 +62,7 @@ export default function useSSE(
|
||||||
} = chatHelpers;
|
} = chatHelpers;
|
||||||
|
|
||||||
const {
|
const {
|
||||||
|
clearStepMaps,
|
||||||
stepHandler,
|
stepHandler,
|
||||||
syncHandler,
|
syncHandler,
|
||||||
finalHandler,
|
finalHandler,
|
||||||
|
@ -101,6 +102,7 @@ export default function useSSE(
|
||||||
payload = removeNullishValues(payload) as TPayload;
|
payload = removeNullishValues(payload) as TPayload;
|
||||||
|
|
||||||
let textIndex = null;
|
let textIndex = null;
|
||||||
|
clearStepMaps();
|
||||||
|
|
||||||
const sse = new SSE(payloadData.server, {
|
const sse = new SSE(payloadData.server, {
|
||||||
payload: JSON.stringify(payload),
|
payload: JSON.stringify(payload),
|
||||||
|
|
|
@ -1,5 +1,11 @@
|
||||||
import { useCallback, useRef } from 'react';
|
import { useCallback, useRef } from 'react';
|
||||||
import { StepTypes, ContentTypes, ToolCallTypes, getNonEmptyValue } from 'librechat-data-provider';
|
import {
|
||||||
|
Constants,
|
||||||
|
StepTypes,
|
||||||
|
ContentTypes,
|
||||||
|
ToolCallTypes,
|
||||||
|
getNonEmptyValue,
|
||||||
|
} from 'librechat-data-provider';
|
||||||
import type {
|
import type {
|
||||||
Agents,
|
Agents,
|
||||||
TMessage,
|
TMessage,
|
||||||
|
@ -178,11 +184,12 @@ export default function useStepHandler({
|
||||||
return { ...message, content: updatedContent as TMessageContentParts[] };
|
return { ...message, content: updatedContent as TMessageContentParts[] };
|
||||||
};
|
};
|
||||||
|
|
||||||
return useCallback(
|
const stepHandler = useCallback(
|
||||||
({ event, data }: TStepEvent, submission: EventSubmission) => {
|
({ event, data }: TStepEvent, submission: EventSubmission) => {
|
||||||
const messages = getMessages() || [];
|
const messages = getMessages() || [];
|
||||||
const { userMessage } = submission;
|
const { userMessage } = submission;
|
||||||
setIsSubmitting(true);
|
setIsSubmitting(true);
|
||||||
|
let parentMessageId = userMessage.messageId;
|
||||||
|
|
||||||
const currentTime = Date.now();
|
const currentTime = Date.now();
|
||||||
if (currentTime - lastAnnouncementTimeRef.current > MESSAGE_UPDATE_INTERVAL) {
|
if (currentTime - lastAnnouncementTimeRef.current > MESSAGE_UPDATE_INTERVAL) {
|
||||||
|
@ -197,7 +204,11 @@ export default function useStepHandler({
|
||||||
|
|
||||||
if (event === 'on_run_step') {
|
if (event === 'on_run_step') {
|
||||||
const runStep = data as Agents.RunStep;
|
const runStep = data as Agents.RunStep;
|
||||||
const responseMessageId = runStep.runId ?? '';
|
let responseMessageId = runStep.runId ?? '';
|
||||||
|
if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) {
|
||||||
|
responseMessageId = submission?.initialResponse?.messageId ?? '';
|
||||||
|
parentMessageId = submission?.initialResponse?.parentMessageId ?? '';
|
||||||
|
}
|
||||||
if (!responseMessageId) {
|
if (!responseMessageId) {
|
||||||
console.warn('No message id found in run step event');
|
console.warn('No message id found in run step event');
|
||||||
return;
|
return;
|
||||||
|
@ -211,7 +222,7 @@ export default function useStepHandler({
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
...responseMessage,
|
...responseMessage,
|
||||||
parentMessageId: userMessage.messageId,
|
parentMessageId,
|
||||||
conversationId: userMessage.conversationId,
|
conversationId: userMessage.conversationId,
|
||||||
messageId: responseMessageId,
|
messageId: responseMessageId,
|
||||||
content: initialContent,
|
content: initialContent,
|
||||||
|
@ -246,14 +257,18 @@ export default function useStepHandler({
|
||||||
|
|
||||||
messageMap.current.set(responseMessageId, updatedResponse);
|
messageMap.current.set(responseMessageId, updatedResponse);
|
||||||
const updatedMessages = messages.map((msg) =>
|
const updatedMessages = messages.map((msg) =>
|
||||||
msg.messageId === runStep.runId ? updatedResponse : msg,
|
msg.messageId === responseMessageId ? updatedResponse : msg,
|
||||||
);
|
);
|
||||||
|
|
||||||
setMessages(updatedMessages);
|
setMessages(updatedMessages);
|
||||||
}
|
}
|
||||||
} else if (event === 'on_agent_update') {
|
} else if (event === 'on_agent_update') {
|
||||||
const { agent_update } = data as Agents.AgentUpdate;
|
const { agent_update } = data as Agents.AgentUpdate;
|
||||||
const responseMessageId = agent_update.runId || '';
|
let responseMessageId = agent_update.runId || '';
|
||||||
|
if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) {
|
||||||
|
responseMessageId = submission?.initialResponse?.messageId ?? '';
|
||||||
|
parentMessageId = submission?.initialResponse?.parentMessageId ?? '';
|
||||||
|
}
|
||||||
if (!responseMessageId) {
|
if (!responseMessageId) {
|
||||||
console.warn('No message id found in agent update event');
|
console.warn('No message id found in agent update event');
|
||||||
return;
|
return;
|
||||||
|
@ -271,7 +286,11 @@ export default function useStepHandler({
|
||||||
} else if (event === 'on_message_delta') {
|
} else if (event === 'on_message_delta') {
|
||||||
const messageDelta = data as Agents.MessageDeltaEvent;
|
const messageDelta = data as Agents.MessageDeltaEvent;
|
||||||
const runStep = stepMap.current.get(messageDelta.id);
|
const runStep = stepMap.current.get(messageDelta.id);
|
||||||
const responseMessageId = runStep?.runId ?? '';
|
let responseMessageId = runStep?.runId ?? '';
|
||||||
|
if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) {
|
||||||
|
responseMessageId = submission?.initialResponse?.messageId ?? '';
|
||||||
|
parentMessageId = submission?.initialResponse?.parentMessageId ?? '';
|
||||||
|
}
|
||||||
|
|
||||||
if (!runStep || !responseMessageId) {
|
if (!runStep || !responseMessageId) {
|
||||||
console.warn('No run step or runId found for message delta event');
|
console.warn('No run step or runId found for message delta event');
|
||||||
|
@ -299,7 +318,11 @@ export default function useStepHandler({
|
||||||
} else if (event === 'on_reasoning_delta') {
|
} else if (event === 'on_reasoning_delta') {
|
||||||
const reasoningDelta = data as Agents.ReasoningDeltaEvent;
|
const reasoningDelta = data as Agents.ReasoningDeltaEvent;
|
||||||
const runStep = stepMap.current.get(reasoningDelta.id);
|
const runStep = stepMap.current.get(reasoningDelta.id);
|
||||||
const responseMessageId = runStep?.runId ?? '';
|
let responseMessageId = runStep?.runId ?? '';
|
||||||
|
if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) {
|
||||||
|
responseMessageId = submission?.initialResponse?.messageId ?? '';
|
||||||
|
parentMessageId = submission?.initialResponse?.parentMessageId ?? '';
|
||||||
|
}
|
||||||
|
|
||||||
if (!runStep || !responseMessageId) {
|
if (!runStep || !responseMessageId) {
|
||||||
console.warn('No run step or runId found for reasoning delta event');
|
console.warn('No run step or runId found for reasoning delta event');
|
||||||
|
@ -327,7 +350,11 @@ export default function useStepHandler({
|
||||||
} else if (event === 'on_run_step_delta') {
|
} else if (event === 'on_run_step_delta') {
|
||||||
const runStepDelta = data as Agents.RunStepDeltaEvent;
|
const runStepDelta = data as Agents.RunStepDeltaEvent;
|
||||||
const runStep = stepMap.current.get(runStepDelta.id);
|
const runStep = stepMap.current.get(runStepDelta.id);
|
||||||
const responseMessageId = runStep?.runId ?? '';
|
let responseMessageId = runStep?.runId ?? '';
|
||||||
|
if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) {
|
||||||
|
responseMessageId = submission?.initialResponse?.messageId ?? '';
|
||||||
|
parentMessageId = submission?.initialResponse?.parentMessageId ?? '';
|
||||||
|
}
|
||||||
|
|
||||||
if (!runStep || !responseMessageId) {
|
if (!runStep || !responseMessageId) {
|
||||||
console.warn('No run step or runId found for run step delta event');
|
console.warn('No run step or runId found for run step delta event');
|
||||||
|
@ -366,7 +393,7 @@ export default function useStepHandler({
|
||||||
|
|
||||||
messageMap.current.set(responseMessageId, updatedResponse);
|
messageMap.current.set(responseMessageId, updatedResponse);
|
||||||
const updatedMessages = messages.map((msg) =>
|
const updatedMessages = messages.map((msg) =>
|
||||||
msg.messageId === runStep.runId ? updatedResponse : msg,
|
msg.messageId === responseMessageId ? updatedResponse : msg,
|
||||||
);
|
);
|
||||||
|
|
||||||
setMessages(updatedMessages);
|
setMessages(updatedMessages);
|
||||||
|
@ -377,7 +404,11 @@ export default function useStepHandler({
|
||||||
const { id: stepId } = result;
|
const { id: stepId } = result;
|
||||||
|
|
||||||
const runStep = stepMap.current.get(stepId);
|
const runStep = stepMap.current.get(stepId);
|
||||||
const responseMessageId = runStep?.runId ?? '';
|
let responseMessageId = runStep?.runId ?? '';
|
||||||
|
if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) {
|
||||||
|
responseMessageId = submission?.initialResponse?.messageId ?? '';
|
||||||
|
parentMessageId = submission?.initialResponse?.parentMessageId ?? '';
|
||||||
|
}
|
||||||
|
|
||||||
if (!runStep || !responseMessageId) {
|
if (!runStep || !responseMessageId) {
|
||||||
console.warn('No run step or runId found for completed tool call event');
|
console.warn('No run step or runId found for completed tool call event');
|
||||||
|
@ -399,7 +430,7 @@ export default function useStepHandler({
|
||||||
|
|
||||||
messageMap.current.set(responseMessageId, updatedResponse);
|
messageMap.current.set(responseMessageId, updatedResponse);
|
||||||
const updatedMessages = messages.map((msg) =>
|
const updatedMessages = messages.map((msg) =>
|
||||||
msg.messageId === runStep.runId ? updatedResponse : msg,
|
msg.messageId === responseMessageId ? updatedResponse : msg,
|
||||||
);
|
);
|
||||||
|
|
||||||
setMessages(updatedMessages);
|
setMessages(updatedMessages);
|
||||||
|
@ -414,4 +445,11 @@ export default function useStepHandler({
|
||||||
},
|
},
|
||||||
[getMessages, setIsSubmitting, lastAnnouncementTimeRef, announcePolite, setMessages],
|
[getMessages, setIsSubmitting, lastAnnouncementTimeRef, announcePolite, setMessages],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const clearStepMaps = useCallback(() => {
|
||||||
|
toolCallIdMap.current.clear();
|
||||||
|
messageMap.current.clear();
|
||||||
|
stepMap.current.clear();
|
||||||
|
}, []);
|
||||||
|
return { stepHandler, clearStepMaps };
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/* MCP */
|
/* MCP */
|
||||||
export * from './mcp/MCPManager';
|
export * from './mcp/MCPManager';
|
||||||
|
export * from './mcp/connection';
|
||||||
export * from './mcp/oauth';
|
export * from './mcp/oauth';
|
||||||
export * from './mcp/auth';
|
export * from './mcp/auth';
|
||||||
export * from './mcp/zod';
|
export * from './mcp/zod';
|
||||||
|
|
|
@ -28,6 +28,7 @@ export class MCPConnectionFactory {
|
||||||
protected readonly oauthStart?: (authURL: string) => Promise<void>;
|
protected readonly oauthStart?: (authURL: string) => Promise<void>;
|
||||||
protected readonly oauthEnd?: () => Promise<void>;
|
protected readonly oauthEnd?: () => Promise<void>;
|
||||||
protected readonly returnOnOAuth?: boolean;
|
protected readonly returnOnOAuth?: boolean;
|
||||||
|
protected readonly connectionTimeout?: number;
|
||||||
|
|
||||||
/** Creates a new MCP connection with optional OAuth support */
|
/** Creates a new MCP connection with optional OAuth support */
|
||||||
static async create(
|
static async create(
|
||||||
|
@ -47,6 +48,7 @@ export class MCPConnectionFactory {
|
||||||
});
|
});
|
||||||
this.serverName = basic.serverName;
|
this.serverName = basic.serverName;
|
||||||
this.useOAuth = !!oauth?.useOAuth;
|
this.useOAuth = !!oauth?.useOAuth;
|
||||||
|
this.connectionTimeout = oauth?.connectionTimeout;
|
||||||
this.logPrefix = oauth?.user
|
this.logPrefix = oauth?.user
|
||||||
? `[MCP][${basic.serverName}][${oauth.user.id}]`
|
? `[MCP][${basic.serverName}][${oauth.user.id}]`
|
||||||
: `[MCP][${basic.serverName}]`;
|
: `[MCP][${basic.serverName}]`;
|
||||||
|
@ -82,8 +84,9 @@ export class MCPConnectionFactory {
|
||||||
if (!this.tokenMethods?.findToken) return null;
|
if (!this.tokenMethods?.findToken) return null;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
const flowId = MCPOAuthHandler.generateFlowId(this.userId!, this.serverName);
|
||||||
const tokens = await this.flowManager!.createFlowWithHandler(
|
const tokens = await this.flowManager!.createFlowWithHandler(
|
||||||
`tokens:${this.userId}:${this.serverName}`,
|
flowId,
|
||||||
'mcp_get_tokens',
|
'mcp_get_tokens',
|
||||||
async () => {
|
async () => {
|
||||||
return await MCPTokenStorage.getTokens({
|
return await MCPTokenStorage.getTokens({
|
||||||
|
@ -203,7 +206,7 @@ export class MCPConnectionFactory {
|
||||||
|
|
||||||
/** Attempts to establish connection with timeout handling */
|
/** Attempts to establish connection with timeout handling */
|
||||||
protected async attemptToConnect(connection: MCPConnection): Promise<void> {
|
protected async attemptToConnect(connection: MCPConnection): Promise<void> {
|
||||||
const connectTimeout = this.serverConfig.initTimeout ?? 30000;
|
const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 30000;
|
||||||
const connectionTimeout = new Promise<void>((_, reject) =>
|
const connectionTimeout = new Promise<void>((_, reject) =>
|
||||||
setTimeout(
|
setTimeout(
|
||||||
() => reject(new Error(`Connection timeout after ${connectTimeout}ms`)),
|
() => reject(new Error(`Connection timeout after ${connectTimeout}ms`)),
|
||||||
|
@ -347,6 +350,7 @@ export class MCPConnectionFactory {
|
||||||
newFlowId,
|
newFlowId,
|
||||||
'mcp_oauth',
|
'mcp_oauth',
|
||||||
flowMetadata as FlowMetadata,
|
flowMetadata as FlowMetadata,
|
||||||
|
this.signal,
|
||||||
);
|
);
|
||||||
if (typeof this.oauthEnd === 'function') {
|
if (typeof this.oauthEnd === 'function') {
|
||||||
await this.oauthEnd();
|
await this.oauthEnd();
|
||||||
|
|
|
@ -1,13 +1,8 @@
|
||||||
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
|
|
||||||
import { logger } from '@librechat/data-schemas';
|
import { logger } from '@librechat/data-schemas';
|
||||||
import type { TokenMethods } from '@librechat/data-schemas';
|
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
|
||||||
import type { TUser } from 'librechat-data-provider';
|
|
||||||
import type { FlowStateManager } from '~/flow/manager';
|
|
||||||
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
|
||||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||||
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
|
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
|
||||||
import { MCPConnection } from './connection';
|
import { MCPConnection } from './connection';
|
||||||
import type { RequestBody } from '~/types';
|
|
||||||
import type * as t from './types';
|
import type * as t from './types';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -44,8 +39,9 @@ export abstract class UserConnectionManager {
|
||||||
|
|
||||||
/** Gets or creates a connection for a specific user */
|
/** Gets or creates a connection for a specific user */
|
||||||
public async getUserConnection({
|
public async getUserConnection({
|
||||||
user,
|
|
||||||
serverName,
|
serverName,
|
||||||
|
forceNew,
|
||||||
|
user,
|
||||||
flowManager,
|
flowManager,
|
||||||
customUserVars,
|
customUserVars,
|
||||||
requestBody,
|
requestBody,
|
||||||
|
@ -54,25 +50,18 @@ export abstract class UserConnectionManager {
|
||||||
oauthEnd,
|
oauthEnd,
|
||||||
signal,
|
signal,
|
||||||
returnOnOAuth = false,
|
returnOnOAuth = false,
|
||||||
|
connectionTimeout,
|
||||||
}: {
|
}: {
|
||||||
user: TUser;
|
|
||||||
serverName: string;
|
serverName: string;
|
||||||
flowManager: FlowStateManager<MCPOAuthTokens | null>;
|
forceNew?: boolean;
|
||||||
customUserVars?: Record<string, string>;
|
} & Omit<t.OAuthConnectionOptions, 'useOAuth'>): Promise<MCPConnection> {
|
||||||
requestBody?: RequestBody;
|
|
||||||
tokenMethods?: TokenMethods;
|
|
||||||
oauthStart?: (authURL: string) => Promise<void>;
|
|
||||||
oauthEnd?: () => Promise<void>;
|
|
||||||
signal?: AbortSignal;
|
|
||||||
returnOnOAuth?: boolean;
|
|
||||||
}): Promise<MCPConnection> {
|
|
||||||
const userId = user.id;
|
const userId = user.id;
|
||||||
if (!userId) {
|
if (!userId) {
|
||||||
throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`);
|
throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`);
|
||||||
}
|
}
|
||||||
|
|
||||||
const userServerMap = this.userConnections.get(userId);
|
const userServerMap = this.userConnections.get(userId);
|
||||||
let connection = userServerMap?.get(serverName);
|
let connection = forceNew ? undefined : userServerMap?.get(serverName);
|
||||||
const now = Date.now();
|
const now = Date.now();
|
||||||
|
|
||||||
// Check if user is idle
|
// Check if user is idle
|
||||||
|
@ -131,6 +120,7 @@ export abstract class UserConnectionManager {
|
||||||
oauthEnd: oauthEnd,
|
oauthEnd: oauthEnd,
|
||||||
returnOnOAuth: returnOnOAuth,
|
returnOnOAuth: returnOnOAuth,
|
||||||
requestBody: requestBody,
|
requestBody: requestBody,
|
||||||
|
connectionTimeout: connectionTimeout,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,7 @@ describe('getUserMCPAuthMap', () => {
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
const tools = testCases.map((testCase) =>
|
const toolInstances = testCases.map((testCase) =>
|
||||||
createMockTool(testCase.normalizedToolName, testCase.originalName),
|
createMockTool(testCase.normalizedToolName, testCase.originalName),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ describe('getUserMCPAuthMap', () => {
|
||||||
|
|
||||||
await getUserMCPAuthMap({
|
await getUserMCPAuthMap({
|
||||||
userId: 'user123',
|
userId: 'user123',
|
||||||
tools,
|
toolInstances,
|
||||||
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -69,7 +69,7 @@ describe('getUserMCPAuthMap', () => {
|
||||||
|
|
||||||
describe('Edge Cases', () => {
|
describe('Edge Cases', () => {
|
||||||
it('should return empty object when no tools have mcpRawServerName', async () => {
|
it('should return empty object when no tools have mcpRawServerName', async () => {
|
||||||
const tools = [
|
const toolInstances = [
|
||||||
createMockTool('regular_tool', undefined, false),
|
createMockTool('regular_tool', undefined, false),
|
||||||
createMockTool('another_tool', undefined, false),
|
createMockTool('another_tool', undefined, false),
|
||||||
createMockTool('test_mcp_Server_no_raw_name', undefined),
|
createMockTool('test_mcp_Server_no_raw_name', undefined),
|
||||||
|
@ -77,7 +77,7 @@ describe('getUserMCPAuthMap', () => {
|
||||||
|
|
||||||
const result = await getUserMCPAuthMap({
|
const result = await getUserMCPAuthMap({
|
||||||
userId: 'user123',
|
userId: 'user123',
|
||||||
tools,
|
toolInstances,
|
||||||
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -104,14 +104,14 @@ describe('getUserMCPAuthMap', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle database errors gracefully', async () => {
|
it('should handle database errors gracefully', async () => {
|
||||||
const tools = [createMockTool('test_mcp_Server1', 'Server1')];
|
const toolInstances = [createMockTool('test_mcp_Server1', 'Server1')];
|
||||||
const dbError = new Error('Database connection failed');
|
const dbError = new Error('Database connection failed');
|
||||||
|
|
||||||
mockGetPluginAuthMap.mockRejectedValue(dbError);
|
mockGetPluginAuthMap.mockRejectedValue(dbError);
|
||||||
|
|
||||||
const result = await getUserMCPAuthMap({
|
const result = await getUserMCPAuthMap({
|
||||||
userId: 'user123',
|
userId: 'user123',
|
||||||
tools,
|
toolInstances,
|
||||||
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -119,18 +119,119 @@ describe('getUserMCPAuthMap', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle non-Error exceptions gracefully', async () => {
|
it('should handle non-Error exceptions gracefully', async () => {
|
||||||
const tools = [createMockTool('test_mcp_Server1', 'Server1')];
|
const toolInstances = [createMockTool('test_mcp_Server1', 'Server1')];
|
||||||
|
|
||||||
mockGetPluginAuthMap.mockRejectedValue('String error');
|
mockGetPluginAuthMap.mockRejectedValue('String error');
|
||||||
|
|
||||||
const result = await getUserMCPAuthMap({
|
const result = await getUserMCPAuthMap({
|
||||||
userId: 'user123',
|
userId: 'user123',
|
||||||
tools,
|
toolInstances,
|
||||||
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(result).toEqual({});
|
expect(result).toEqual({});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should handle mixed null/undefined values in tools array', async () => {
|
||||||
|
const tools = [
|
||||||
|
'test_mcp_Server1',
|
||||||
|
null,
|
||||||
|
'test_mcp_Server2',
|
||||||
|
undefined,
|
||||||
|
'regular_tool',
|
||||||
|
'test_mcp_Server3',
|
||||||
|
];
|
||||||
|
|
||||||
|
mockGetPluginAuthMap.mockResolvedValue({
|
||||||
|
mcp_Server1: { API_KEY: 'key1' },
|
||||||
|
mcp_Server2: { API_KEY: 'key2' },
|
||||||
|
mcp_Server3: { API_KEY: 'key3' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await getUserMCPAuthMap({
|
||||||
|
userId: 'user123',
|
||||||
|
tools: tools as (string | undefined)[],
|
||||||
|
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(mockGetPluginAuthMap).toHaveBeenCalledWith({
|
||||||
|
userId: 'user123',
|
||||||
|
pluginKeys: ['mcp_Server1', 'mcp_Server2', 'mcp_Server3'],
|
||||||
|
throwError: false,
|
||||||
|
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
mcp_Server1: { API_KEY: 'key1' },
|
||||||
|
mcp_Server2: { API_KEY: 'key2' },
|
||||||
|
mcp_Server3: { API_KEY: 'key3' },
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle mixed null/undefined values in servers array', async () => {
|
||||||
|
const servers = ['Server1', null, 'Server2', undefined, 'Server3'];
|
||||||
|
|
||||||
|
mockGetPluginAuthMap.mockResolvedValue({
|
||||||
|
mcp_Server1: { API_KEY: 'key1' },
|
||||||
|
mcp_Server2: { API_KEY: 'key2' },
|
||||||
|
mcp_Server3: { API_KEY: 'key3' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await getUserMCPAuthMap({
|
||||||
|
userId: 'user123',
|
||||||
|
servers: servers as (string | undefined)[],
|
||||||
|
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(mockGetPluginAuthMap).toHaveBeenCalledWith({
|
||||||
|
userId: 'user123',
|
||||||
|
pluginKeys: ['mcp_Server1', 'mcp_Server2', 'mcp_Server3'],
|
||||||
|
throwError: false,
|
||||||
|
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
mcp_Server1: { API_KEY: 'key1' },
|
||||||
|
mcp_Server2: { API_KEY: 'key2' },
|
||||||
|
mcp_Server3: { API_KEY: 'key3' },
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle mixed null/undefined values in toolInstances array', async () => {
|
||||||
|
const toolInstances = [
|
||||||
|
createMockTool('test_mcp_Server1', 'Server1'),
|
||||||
|
null,
|
||||||
|
createMockTool('test_mcp_Server2', 'Server2'),
|
||||||
|
undefined,
|
||||||
|
createMockTool('regular_tool', undefined, false),
|
||||||
|
createMockTool('test_mcp_Server3', 'Server3'),
|
||||||
|
];
|
||||||
|
|
||||||
|
mockGetPluginAuthMap.mockResolvedValue({
|
||||||
|
mcp_Server1: { API_KEY: 'key1' },
|
||||||
|
mcp_Server2: { API_KEY: 'key2' },
|
||||||
|
mcp_Server3: { API_KEY: 'key3' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await getUserMCPAuthMap({
|
||||||
|
userId: 'user123',
|
||||||
|
toolInstances: toolInstances as (GenericTool | null)[],
|
||||||
|
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(mockGetPluginAuthMap).toHaveBeenCalledWith({
|
||||||
|
userId: 'user123',
|
||||||
|
pluginKeys: ['mcp_Server1', 'mcp_Server2', 'mcp_Server3'],
|
||||||
|
throwError: false,
|
||||||
|
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
mcp_Server1: { API_KEY: 'key1' },
|
||||||
|
mcp_Server2: { API_KEY: 'key2' },
|
||||||
|
mcp_Server3: { API_KEY: 'key3' },
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('Integration', () => {
|
describe('Integration', () => {
|
||||||
|
@ -138,7 +239,7 @@ describe('getUserMCPAuthMap', () => {
|
||||||
const originalServerName = 'Connector: Company';
|
const originalServerName = 'Connector: Company';
|
||||||
const toolName = 'test_auth_mcp_Connector__Company';
|
const toolName = 'test_auth_mcp_Connector__Company';
|
||||||
|
|
||||||
const tools = [createMockTool(toolName, originalServerName)];
|
const toolInstances = [createMockTool(toolName, originalServerName)];
|
||||||
|
|
||||||
const mockCustomUserVars = {
|
const mockCustomUserVars = {
|
||||||
'mcp_Connector: Company': {
|
'mcp_Connector: Company': {
|
||||||
|
@ -151,7 +252,7 @@ describe('getUserMCPAuthMap', () => {
|
||||||
|
|
||||||
const result = await getUserMCPAuthMap({
|
const result = await getUserMCPAuthMap({
|
||||||
userId: 'user123',
|
userId: 'user123',
|
||||||
tools,
|
toolInstances,
|
||||||
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -7,33 +7,56 @@ import { getPluginAuthMap } from '~/agents/auth';
|
||||||
export async function getUserMCPAuthMap({
|
export async function getUserMCPAuthMap({
|
||||||
userId,
|
userId,
|
||||||
tools,
|
tools,
|
||||||
|
servers,
|
||||||
|
toolInstances,
|
||||||
findPluginAuthsByKeys,
|
findPluginAuthsByKeys,
|
||||||
}: {
|
}: {
|
||||||
userId: string;
|
userId: string;
|
||||||
tools: GenericTool[] | undefined;
|
tools?: (string | undefined)[];
|
||||||
|
servers?: (string | undefined)[];
|
||||||
|
toolInstances?: (GenericTool | null)[];
|
||||||
findPluginAuthsByKeys: PluginAuthMethods['findPluginAuthsByKeys'];
|
findPluginAuthsByKeys: PluginAuthMethods['findPluginAuthsByKeys'];
|
||||||
}) {
|
}) {
|
||||||
if (!tools || tools.length === 0) {
|
let allMcpCustomUserVars: Record<string, Record<string, string>> = {};
|
||||||
return {};
|
let mcpPluginKeysToFetch: string[] = [];
|
||||||
}
|
try {
|
||||||
|
|
||||||
const uniqueMcpServers = new Set<string>();
|
const uniqueMcpServers = new Set<string>();
|
||||||
|
|
||||||
for (const tool of tools) {
|
if (servers != null && servers.length) {
|
||||||
|
for (const serverName of servers) {
|
||||||
|
if (!serverName) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
uniqueMcpServers.add(`${Constants.mcp_prefix}${serverName}`);
|
||||||
|
}
|
||||||
|
} else if (tools != null && tools.length) {
|
||||||
|
for (const toolName of tools) {
|
||||||
|
if (!toolName) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const delimiterIndex = toolName.indexOf(Constants.mcp_delimiter);
|
||||||
|
if (delimiterIndex === -1) continue;
|
||||||
|
const mcpServer = toolName.slice(delimiterIndex + Constants.mcp_delimiter.length);
|
||||||
|
if (!mcpServer) continue;
|
||||||
|
uniqueMcpServers.add(`${Constants.mcp_prefix}${mcpServer}`);
|
||||||
|
}
|
||||||
|
} else if (toolInstances != null && toolInstances.length) {
|
||||||
|
for (const tool of toolInstances) {
|
||||||
|
if (!tool) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
const mcpTool = tool as GenericTool & { mcpRawServerName?: string };
|
const mcpTool = tool as GenericTool & { mcpRawServerName?: string };
|
||||||
if (mcpTool.mcpRawServerName) {
|
if (mcpTool.mcpRawServerName) {
|
||||||
uniqueMcpServers.add(`${Constants.mcp_prefix}${mcpTool.mcpRawServerName}`);
|
uniqueMcpServers.add(`${Constants.mcp_prefix}${mcpTool.mcpRawServerName}`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (uniqueMcpServers.size === 0) {
|
if (uniqueMcpServers.size === 0) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
const mcpPluginKeysToFetch = Array.from(uniqueMcpServers);
|
mcpPluginKeysToFetch = Array.from(uniqueMcpServers);
|
||||||
|
|
||||||
let allMcpCustomUserVars: Record<string, Record<string, string>> = {};
|
|
||||||
try {
|
|
||||||
allMcpCustomUserVars = await getPluginAuthMap({
|
allMcpCustomUserVars = await getPluginAuthMap({
|
||||||
userId,
|
userId,
|
||||||
pluginKeys: mcpPluginKeysToFetch,
|
pluginKeys: mcpPluginKeysToFetch,
|
||||||
|
|
|
@ -446,7 +446,7 @@ export class MCPConnection extends EventEmitter {
|
||||||
const serverUrl = this.url;
|
const serverUrl = this.url;
|
||||||
logger.debug(`${this.getLogPrefix()} Server URL for OAuth: ${serverUrl}`);
|
logger.debug(`${this.getLogPrefix()} Server URL for OAuth: ${serverUrl}`);
|
||||||
|
|
||||||
const oauthTimeout = this.options.initTimeout ?? 60000;
|
const oauthTimeout = this.options.initTimeout ?? 60000 * 2;
|
||||||
/** Promise that will resolve when OAuth is handled */
|
/** Promise that will resolve when OAuth is handled */
|
||||||
const oauthHandledPromise = new Promise<void>((resolve, reject) => {
|
const oauthHandledPromise = new Promise<void>((resolve, reject) => {
|
||||||
let timeoutId: NodeJS.Timeout | null = null;
|
let timeoutId: NodeJS.Timeout | null = null;
|
||||||
|
|
|
@ -134,4 +134,5 @@ export interface OAuthConnectionOptions {
|
||||||
oauthStart?: (authURL: string) => Promise<void>;
|
oauthStart?: (authURL: string) => Promise<void>;
|
||||||
oauthEnd?: () => Promise<void>;
|
oauthEnd?: () => Promise<void>;
|
||||||
returnOnOAuth?: boolean;
|
returnOnOAuth?: boolean;
|
||||||
|
connectionTimeout?: number;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import { AuthType, Constants, EToolResources } from 'librechat-data-provider';
|
import { AuthType, Constants, EToolResources } from 'librechat-data-provider';
|
||||||
import type { TCustomConfig, TPlugin, FunctionTool } from 'librechat-data-provider';
|
import type { TCustomConfig, TPlugin } from 'librechat-data-provider';
|
||||||
|
import { LCAvailableTools, LCFunctionTool } from '~/mcp/types';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Filters out duplicate plugins from the list of plugins.
|
* Filters out duplicate plugins from the list of plugins.
|
||||||
|
@ -60,7 +61,7 @@ export function convertMCPToolToPlugin({
|
||||||
customConfig,
|
customConfig,
|
||||||
}: {
|
}: {
|
||||||
toolKey: string;
|
toolKey: string;
|
||||||
toolData: FunctionTool;
|
toolData: LCFunctionTool;
|
||||||
customConfig?: Partial<TCustomConfig> | null;
|
customConfig?: Partial<TCustomConfig> | null;
|
||||||
}): TPlugin | undefined {
|
}): TPlugin | undefined {
|
||||||
if (!toolData.function || !toolKey.includes(Constants.mcp_delimiter)) {
|
if (!toolData.function || !toolKey.includes(Constants.mcp_delimiter)) {
|
||||||
|
@ -112,7 +113,7 @@ export function convertMCPToolsToPlugins({
|
||||||
functionTools,
|
functionTools,
|
||||||
customConfig,
|
customConfig,
|
||||||
}: {
|
}: {
|
||||||
functionTools?: Record<string, FunctionTool>;
|
functionTools?: LCAvailableTools;
|
||||||
customConfig?: Partial<TCustomConfig> | null;
|
customConfig?: Partial<TCustomConfig> | null;
|
||||||
}): TPlugin[] | undefined {
|
}): TPlugin[] | undefined {
|
||||||
if (!functionTools || typeof functionTools !== 'object') {
|
if (!functionTools || typeof functionTools !== 'object') {
|
||||||
|
|
|
@ -1525,6 +1525,8 @@ export enum Constants {
|
||||||
CONFIG_VERSION = '1.2.8',
|
CONFIG_VERSION = '1.2.8',
|
||||||
/** Standard value for the first message's `parentMessageId` value, to indicate no parent exists. */
|
/** Standard value for the first message's `parentMessageId` value, to indicate no parent exists. */
|
||||||
NO_PARENT = '00000000-0000-0000-0000-000000000000',
|
NO_PARENT = '00000000-0000-0000-0000-000000000000',
|
||||||
|
/** Standard value to use whatever the submission prelim. `responseMessageId` is */
|
||||||
|
USE_PRELIM_RESPONSE_MESSAGE_ID = 'USE_PRELIM_RESPONSE_MESSAGE_ID',
|
||||||
/** Standard value for the initial conversationId before a request is sent */
|
/** Standard value for the initial conversationId before a request is sent */
|
||||||
NEW_CONVO = 'new',
|
NEW_CONVO = 'new',
|
||||||
/** Standard value for the temporary conversationId after a request is sent and before the server responds */
|
/** Standard value for the temporary conversationId after a request is sent and before the server responds */
|
||||||
|
@ -1551,6 +1553,8 @@ export enum Constants {
|
||||||
mcp_delimiter = '_mcp_',
|
mcp_delimiter = '_mcp_',
|
||||||
/** Prefix for MCP plugins */
|
/** Prefix for MCP plugins */
|
||||||
mcp_prefix = 'mcp_',
|
mcp_prefix = 'mcp_',
|
||||||
|
/** Unique value to indicate all MCP servers */
|
||||||
|
mcp_all = 'sys__all__sys',
|
||||||
/** Placeholder Agent ID for Ephemeral Agents */
|
/** Placeholder Agent ID for Ephemeral Agents */
|
||||||
EPHEMERAL_AGENT_ID = 'ephemeral',
|
EPHEMERAL_AGENT_ID = 'ephemeral',
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue