mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-17 00:40:14 +01:00
👤 fix: Missing User Placeholder Fields for MCP Services (#9824)
This commit is contained in:
parent
57f8b333bc
commit
4f3683fd9a
6 changed files with 388 additions and 38 deletions
|
|
@ -1,8 +1,13 @@
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { SerpAPI } = require('@langchain/community/tools/serpapi');
|
const { SerpAPI } = require('@langchain/community/tools/serpapi');
|
||||||
const { Calculator } = require('@langchain/community/tools/calculator');
|
const { Calculator } = require('@langchain/community/tools/calculator');
|
||||||
const { mcpToolPattern, loadWebSearchAuth, checkAccess } = require('@librechat/api');
|
|
||||||
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
|
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
|
||||||
|
const {
|
||||||
|
checkAccess,
|
||||||
|
createSafeUser,
|
||||||
|
mcpToolPattern,
|
||||||
|
loadWebSearchAuth,
|
||||||
|
} = require('@librechat/api');
|
||||||
const {
|
const {
|
||||||
Tools,
|
Tools,
|
||||||
Constants,
|
Constants,
|
||||||
|
|
@ -410,6 +415,7 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
|
||||||
/** MCP server tools are initialized sequentially by server */
|
/** MCP server tools are initialized sequentially by server */
|
||||||
let index = -1;
|
let index = -1;
|
||||||
const failedMCPServers = new Set();
|
const failedMCPServers = new Set();
|
||||||
|
const safeUser = createSafeUser(options.req?.user);
|
||||||
for (const [serverName, toolConfigs] of Object.entries(requestedMCPTools)) {
|
for (const [serverName, toolConfigs] of Object.entries(requestedMCPTools)) {
|
||||||
index++;
|
index++;
|
||||||
/** @type {LCAvailableTools} */
|
/** @type {LCAvailableTools} */
|
||||||
|
|
@ -420,14 +426,14 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const mcpParams = {
|
const mcpParams = {
|
||||||
res: options.res,
|
|
||||||
userId: user,
|
|
||||||
index,
|
index,
|
||||||
serverName: config.serverName,
|
|
||||||
userMCPAuthMap,
|
|
||||||
model: agent?.model ?? model,
|
|
||||||
provider: agent?.provider ?? endpoint,
|
|
||||||
signal,
|
signal,
|
||||||
|
user: safeUser,
|
||||||
|
userMCPAuthMap,
|
||||||
|
res: options.res,
|
||||||
|
model: agent?.model ?? model,
|
||||||
|
serverName: config.serverName,
|
||||||
|
provider: agent?.provider ?? endpoint,
|
||||||
};
|
};
|
||||||
|
|
||||||
if (config.type === 'all' && toolConfigs.length === 1) {
|
if (config.type === 'all' && toolConfigs.length === 1) {
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,12 @@
|
||||||
const { Router } = require('express');
|
const { Router } = require('express');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||||
const { MCPOAuthHandler, MCPTokenStorage, getUserMCPAuthMap } = require('@librechat/api');
|
const {
|
||||||
|
createSafeUser,
|
||||||
|
MCPOAuthHandler,
|
||||||
|
MCPTokenStorage,
|
||||||
|
getUserMCPAuthMap,
|
||||||
|
} = require('@librechat/api');
|
||||||
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
|
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
|
||||||
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
||||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||||
|
|
@ -335,9 +340,9 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => {
|
||||||
router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const { serverName } = req.params;
|
const { serverName } = req.params;
|
||||||
const userId = req.user?.id;
|
const user = createSafeUser(req.user);
|
||||||
|
|
||||||
if (!userId) {
|
if (!user.id) {
|
||||||
return res.status(401).json({ error: 'User not authenticated' });
|
return res.status(401).json({ error: 'User not authenticated' });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -351,7 +356,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
await mcpManager.disconnectUserConnection(userId, serverName);
|
await mcpManager.disconnectUserConnection(user.id, serverName);
|
||||||
logger.info(
|
logger.info(
|
||||||
`[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`,
|
`[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`,
|
||||||
);
|
);
|
||||||
|
|
@ -360,14 +365,14 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||||
let userMCPAuthMap;
|
let userMCPAuthMap;
|
||||||
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
|
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
|
||||||
userMCPAuthMap = await getUserMCPAuthMap({
|
userMCPAuthMap = await getUserMCPAuthMap({
|
||||||
userId,
|
userId: user.id,
|
||||||
servers: [serverName],
|
servers: [serverName],
|
||||||
findPluginAuthsByKeys,
|
findPluginAuthsByKeys,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = await reinitMCPServer({
|
const result = await reinitMCPServer({
|
||||||
userId,
|
user,
|
||||||
serverName,
|
serverName,
|
||||||
userMCPAuthMap,
|
userMCPAuthMap,
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -153,7 +153,7 @@ function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
|
||||||
/**
|
/**
|
||||||
* @param {Object} params
|
* @param {Object} params
|
||||||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||||
* @param {string} params.userId - The user ID from the request object.
|
* @param {IUser} params.user - The user from the request object.
|
||||||
* @param {string} params.serverName
|
* @param {string} params.serverName
|
||||||
* @param {AbortSignal} params.signal
|
* @param {AbortSignal} params.signal
|
||||||
* @param {string} params.model
|
* @param {string} params.model
|
||||||
|
|
@ -161,9 +161,9 @@ function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
|
||||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||||
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
|
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
|
||||||
*/
|
*/
|
||||||
async function reconnectServer({ res, userId, index, signal, serverName, userMCPAuthMap }) {
|
async function reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap }) {
|
||||||
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
|
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
|
||||||
const flowId = `${userId}:${serverName}:${Date.now()}`;
|
const flowId = `${user.id}:${serverName}:${Date.now()}`;
|
||||||
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
|
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
|
||||||
const stepId = 'step_oauth_login_' + serverName;
|
const stepId = 'step_oauth_login_' + serverName;
|
||||||
const toolCall = {
|
const toolCall = {
|
||||||
|
|
@ -192,7 +192,7 @@ async function reconnectServer({ res, userId, index, signal, serverName, userMCP
|
||||||
flowManager,
|
flowManager,
|
||||||
});
|
});
|
||||||
return await reinitMCPServer({
|
return await reinitMCPServer({
|
||||||
userId,
|
user,
|
||||||
signal,
|
signal,
|
||||||
serverName,
|
serverName,
|
||||||
oauthStart,
|
oauthStart,
|
||||||
|
|
@ -212,7 +212,7 @@ async function reconnectServer({ res, userId, index, signal, serverName, userMCP
|
||||||
*
|
*
|
||||||
* @param {Object} params
|
* @param {Object} params
|
||||||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||||
* @param {string} params.userId - The user ID from the request object.
|
* @param {IUser} params.user - The user from the request object.
|
||||||
* @param {string} params.serverName
|
* @param {string} params.serverName
|
||||||
* @param {string} params.model
|
* @param {string} params.model
|
||||||
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
|
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
|
||||||
|
|
@ -221,16 +221,8 @@ async function reconnectServer({ res, userId, index, signal, serverName, userMCP
|
||||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||||
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
|
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
|
||||||
*/
|
*/
|
||||||
async function createMCPTools({
|
async function createMCPTools({ res, user, index, signal, serverName, provider, userMCPAuthMap }) {
|
||||||
res,
|
const result = await reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap });
|
||||||
userId,
|
|
||||||
index,
|
|
||||||
signal,
|
|
||||||
serverName,
|
|
||||||
provider,
|
|
||||||
userMCPAuthMap,
|
|
||||||
}) {
|
|
||||||
const result = await reconnectServer({ res, userId, index, signal, serverName, userMCPAuthMap });
|
|
||||||
if (!result || !result.tools) {
|
if (!result || !result.tools) {
|
||||||
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
|
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
|
||||||
return;
|
return;
|
||||||
|
|
@ -240,7 +232,7 @@ async function createMCPTools({
|
||||||
for (const tool of result.tools) {
|
for (const tool of result.tools) {
|
||||||
const toolInstance = await createMCPTool({
|
const toolInstance = await createMCPTool({
|
||||||
res,
|
res,
|
||||||
userId,
|
user,
|
||||||
provider,
|
provider,
|
||||||
userMCPAuthMap,
|
userMCPAuthMap,
|
||||||
availableTools: result.availableTools,
|
availableTools: result.availableTools,
|
||||||
|
|
@ -258,7 +250,7 @@ async function createMCPTools({
|
||||||
* Creates a single tool from the specified MCP Server via `toolKey`.
|
* Creates a single tool from the specified MCP Server via `toolKey`.
|
||||||
* @param {Object} params
|
* @param {Object} params
|
||||||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||||
* @param {string} params.userId - The user ID from the request object.
|
* @param {IUser} params.user - The user from the request object.
|
||||||
* @param {string} params.toolKey - The toolKey for the tool.
|
* @param {string} params.toolKey - The toolKey for the tool.
|
||||||
* @param {string} params.model - The model for the tool.
|
* @param {string} params.model - The model for the tool.
|
||||||
* @param {number} [params.index]
|
* @param {number} [params.index]
|
||||||
|
|
@ -270,7 +262,7 @@ async function createMCPTools({
|
||||||
*/
|
*/
|
||||||
async function createMCPTool({
|
async function createMCPTool({
|
||||||
res,
|
res,
|
||||||
userId,
|
user,
|
||||||
index,
|
index,
|
||||||
signal,
|
signal,
|
||||||
toolKey,
|
toolKey,
|
||||||
|
|
@ -288,7 +280,7 @@ async function createMCPTool({
|
||||||
);
|
);
|
||||||
const result = await reconnectServer({
|
const result = await reconnectServer({
|
||||||
res,
|
res,
|
||||||
userId,
|
user,
|
||||||
index,
|
index,
|
||||||
signal,
|
signal,
|
||||||
serverName,
|
serverName,
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,45 @@
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { MCPOAuthHandler } = require('@librechat/api');
|
const { MCPOAuthHandler } = require('@librechat/api');
|
||||||
const { CacheKeys } = require('librechat-data-provider');
|
const { CacheKeys } = require('librechat-data-provider');
|
||||||
const { getMCPSetupData, checkOAuthFlowStatus, getServerConnectionStatus } = require('./MCP');
|
const {
|
||||||
|
createMCPTool,
|
||||||
|
createMCPTools,
|
||||||
|
getMCPSetupData,
|
||||||
|
checkOAuthFlowStatus,
|
||||||
|
getServerConnectionStatus,
|
||||||
|
} = require('./MCP');
|
||||||
|
|
||||||
// Mock all dependencies
|
// Mock all dependencies
|
||||||
jest.mock('@librechat/data-schemas', () => ({
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
logger: {
|
logger: {
|
||||||
debug: jest.fn(),
|
debug: jest.fn(),
|
||||||
error: jest.fn(),
|
error: jest.fn(),
|
||||||
|
info: jest.fn(),
|
||||||
|
warn: jest.fn(),
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('@langchain/core/tools', () => ({
|
||||||
|
tool: jest.fn((fn, config) => {
|
||||||
|
const toolInstance = { _call: fn, ...config };
|
||||||
|
return toolInstance;
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('@librechat/agents', () => ({
|
||||||
|
Providers: {
|
||||||
|
VERTEXAI: 'vertexai',
|
||||||
|
GOOGLE: 'google',
|
||||||
|
},
|
||||||
|
StepTypes: {
|
||||||
|
TOOL_CALLS: 'tool_calls',
|
||||||
|
},
|
||||||
|
GraphEvents: {
|
||||||
|
ON_RUN_STEP_DELTA: 'on_run_step_delta',
|
||||||
|
ON_RUN_STEP: 'on_run_step',
|
||||||
|
},
|
||||||
|
Constants: {
|
||||||
|
CONTENT_AND_ARTIFACT: 'content_and_artifact',
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
|
@ -15,12 +47,27 @@ jest.mock('@librechat/api', () => ({
|
||||||
MCPOAuthHandler: {
|
MCPOAuthHandler: {
|
||||||
generateFlowId: jest.fn(),
|
generateFlowId: jest.fn(),
|
||||||
},
|
},
|
||||||
|
sendEvent: jest.fn(),
|
||||||
|
normalizeServerName: jest.fn((name) => name),
|
||||||
|
convertWithResolvedRefs: jest.fn((params) => params),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
jest.mock('librechat-data-provider', () => ({
|
jest.mock('librechat-data-provider', () => ({
|
||||||
CacheKeys: {
|
CacheKeys: {
|
||||||
FLOWS: 'flows',
|
FLOWS: 'flows',
|
||||||
},
|
},
|
||||||
|
Constants: {
|
||||||
|
USE_PRELIM_RESPONSE_MESSAGE_ID: 'prelim_response_id',
|
||||||
|
mcp_delimiter: '::',
|
||||||
|
mcp_prefix: 'mcp_',
|
||||||
|
},
|
||||||
|
ContentTypes: {
|
||||||
|
TEXT: 'text',
|
||||||
|
},
|
||||||
|
isAssistantsEndpoint: jest.fn(() => false),
|
||||||
|
Time: {
|
||||||
|
TWO_MINUTES: 120000,
|
||||||
|
},
|
||||||
}));
|
}));
|
||||||
|
|
||||||
jest.mock('./Config', () => ({
|
jest.mock('./Config', () => ({
|
||||||
|
|
@ -44,8 +91,11 @@ jest.mock('~/models', () => ({
|
||||||
updateToken: jest.fn(),
|
updateToken: jest.fn(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
jest.mock('./Tools/mcp', () => ({
|
||||||
|
reinitMCPServer: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
describe('tests for the new helper functions used by the MCP connection status endpoints', () => {
|
describe('tests for the new helper functions used by the MCP connection status endpoints', () => {
|
||||||
let mockLoadCustomConfig;
|
|
||||||
let mockGetMCPManager;
|
let mockGetMCPManager;
|
||||||
let mockGetFlowStateManager;
|
let mockGetFlowStateManager;
|
||||||
let mockGetLogStores;
|
let mockGetLogStores;
|
||||||
|
|
@ -54,7 +104,6 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks();
|
jest.clearAllMocks();
|
||||||
|
|
||||||
mockLoadCustomConfig = require('./Config').loadCustomConfig;
|
|
||||||
mockGetMCPManager = require('~/config').getMCPManager;
|
mockGetMCPManager = require('~/config').getMCPManager;
|
||||||
mockGetFlowStateManager = require('~/config').getFlowStateManager;
|
mockGetFlowStateManager = require('~/config').getFlowStateManager;
|
||||||
mockGetLogStores = require('~/cache').getLogStores;
|
mockGetLogStores = require('~/cache').getLogStores;
|
||||||
|
|
@ -567,3 +616,275 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('User parameter passing tests', () => {
|
||||||
|
let mockReinitMCPServer;
|
||||||
|
let mockGetFlowStateManager;
|
||||||
|
let mockGetLogStores;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
mockReinitMCPServer = require('./Tools/mcp').reinitMCPServer;
|
||||||
|
mockGetFlowStateManager = require('~/config').getFlowStateManager;
|
||||||
|
mockGetLogStores = require('~/cache').getLogStores;
|
||||||
|
|
||||||
|
// Setup default mocks
|
||||||
|
mockGetLogStores.mockReturnValue({});
|
||||||
|
mockGetFlowStateManager.mockReturnValue({
|
||||||
|
createFlowWithHandler: jest.fn(),
|
||||||
|
failFlow: jest.fn(),
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('createMCPTools', () => {
|
||||||
|
it('should pass user parameter to reinitMCPServer when calling reconnectServer internally', async () => {
|
||||||
|
const mockUser = { id: 'test-user-123', name: 'Test User' };
|
||||||
|
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||||
|
const mockSignal = new AbortController().signal;
|
||||||
|
|
||||||
|
mockReinitMCPServer.mockResolvedValue({
|
||||||
|
tools: [{ name: 'test-tool' }],
|
||||||
|
availableTools: {
|
||||||
|
'test-tool::test-server': {
|
||||||
|
function: {
|
||||||
|
description: 'Test tool',
|
||||||
|
parameters: { type: 'object', properties: {} },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
await createMCPTools({
|
||||||
|
res: mockRes,
|
||||||
|
user: mockUser,
|
||||||
|
serverName: 'test-server',
|
||||||
|
provider: 'openai',
|
||||||
|
signal: mockSignal,
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Verify reinitMCPServer was called with the user
|
||||||
|
expect(mockReinitMCPServer).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
user: mockUser,
|
||||||
|
serverName: 'test-server',
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
expect(mockReinitMCPServer.mock.calls[0][0].user).toBe(mockUser);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw error if user is not provided', async () => {
|
||||||
|
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||||
|
|
||||||
|
mockReinitMCPServer.mockResolvedValue({
|
||||||
|
tools: [],
|
||||||
|
availableTools: {},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Call without user should throw error
|
||||||
|
await expect(
|
||||||
|
createMCPTools({
|
||||||
|
res: mockRes,
|
||||||
|
user: undefined,
|
||||||
|
serverName: 'test-server',
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
}),
|
||||||
|
).rejects.toThrow("Cannot read properties of undefined (reading 'id')");
|
||||||
|
|
||||||
|
// Verify reinitMCPServer was not called due to early error
|
||||||
|
expect(mockReinitMCPServer).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('createMCPTool', () => {
|
||||||
|
it('should pass user parameter to reinitMCPServer when tool not in cache', async () => {
|
||||||
|
const mockUser = { id: 'test-user-456', email: 'test@example.com' };
|
||||||
|
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||||
|
const mockSignal = new AbortController().signal;
|
||||||
|
|
||||||
|
mockReinitMCPServer.mockResolvedValue({
|
||||||
|
availableTools: {
|
||||||
|
'test-tool::test-server': {
|
||||||
|
function: {
|
||||||
|
description: 'Test tool',
|
||||||
|
parameters: { type: 'object', properties: {} },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Call without availableTools to trigger reinit
|
||||||
|
await createMCPTool({
|
||||||
|
res: mockRes,
|
||||||
|
user: mockUser,
|
||||||
|
toolKey: 'test-tool::test-server',
|
||||||
|
provider: 'openai',
|
||||||
|
signal: mockSignal,
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
availableTools: undefined, // Force reinit
|
||||||
|
});
|
||||||
|
|
||||||
|
// Verify reinitMCPServer was called with the user
|
||||||
|
expect(mockReinitMCPServer).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
user: mockUser,
|
||||||
|
serverName: 'test-server',
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
expect(mockReinitMCPServer.mock.calls[0][0].user).toBe(mockUser);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not call reinitMCPServer when tool is in cache', async () => {
|
||||||
|
const mockUser = { id: 'test-user-789' };
|
||||||
|
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||||
|
|
||||||
|
const availableTools = {
|
||||||
|
'test-tool::test-server': {
|
||||||
|
function: {
|
||||||
|
description: 'Cached tool',
|
||||||
|
parameters: { type: 'object', properties: {} },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
await createMCPTool({
|
||||||
|
res: mockRes,
|
||||||
|
user: mockUser,
|
||||||
|
toolKey: 'test-tool::test-server',
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
availableTools: availableTools,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Verify reinitMCPServer was NOT called since tool was in cache
|
||||||
|
expect(mockReinitMCPServer).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('reinitMCPServer (via reconnectServer)', () => {
|
||||||
|
it('should always receive user parameter when called from createMCPTools', async () => {
|
||||||
|
const mockUser = { id: 'user-001', role: 'admin' };
|
||||||
|
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||||
|
|
||||||
|
// Track all calls to reinitMCPServer
|
||||||
|
const reinitCalls = [];
|
||||||
|
mockReinitMCPServer.mockImplementation((params) => {
|
||||||
|
reinitCalls.push(params);
|
||||||
|
return Promise.resolve({
|
||||||
|
tools: [{ name: 'tool1' }, { name: 'tool2' }],
|
||||||
|
availableTools: {
|
||||||
|
'tool1::server1': { function: { description: 'Tool 1', parameters: {} } },
|
||||||
|
'tool2::server1': { function: { description: 'Tool 2', parameters: {} } },
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
await createMCPTools({
|
||||||
|
res: mockRes,
|
||||||
|
user: mockUser,
|
||||||
|
serverName: 'server1',
|
||||||
|
provider: 'anthropic',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Verify all calls to reinitMCPServer had the user
|
||||||
|
expect(reinitCalls.length).toBeGreaterThan(0);
|
||||||
|
reinitCalls.forEach((call) => {
|
||||||
|
expect(call.user).toBe(mockUser);
|
||||||
|
expect(call.user.id).toBe('user-001');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should always receive user parameter when called from createMCPTool', async () => {
|
||||||
|
const mockUser = { id: 'user-002', permissions: ['read', 'write'] };
|
||||||
|
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||||
|
|
||||||
|
// Track all calls to reinitMCPServer
|
||||||
|
const reinitCalls = [];
|
||||||
|
mockReinitMCPServer.mockImplementation((params) => {
|
||||||
|
reinitCalls.push(params);
|
||||||
|
return Promise.resolve({
|
||||||
|
availableTools: {
|
||||||
|
'my-tool::my-server': {
|
||||||
|
function: { description: 'My Tool', parameters: {} },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
await createMCPTool({
|
||||||
|
res: mockRes,
|
||||||
|
user: mockUser,
|
||||||
|
toolKey: 'my-tool::my-server',
|
||||||
|
provider: 'google',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
availableTools: undefined, // Force reinit
|
||||||
|
});
|
||||||
|
|
||||||
|
// Verify the call to reinitMCPServer had the user
|
||||||
|
expect(reinitCalls.length).toBe(1);
|
||||||
|
expect(reinitCalls[0].user).toBe(mockUser);
|
||||||
|
expect(reinitCalls[0].user.id).toBe('user-002');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('User parameter integrity', () => {
|
||||||
|
it('should preserve user object properties through the call chain', async () => {
|
||||||
|
const complexUser = {
|
||||||
|
id: 'complex-user',
|
||||||
|
name: 'John Doe',
|
||||||
|
email: 'john@example.com',
|
||||||
|
metadata: { subscription: 'premium', settings: { theme: 'dark' } },
|
||||||
|
};
|
||||||
|
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||||
|
|
||||||
|
let capturedUser = null;
|
||||||
|
mockReinitMCPServer.mockImplementation((params) => {
|
||||||
|
capturedUser = params.user;
|
||||||
|
return Promise.resolve({
|
||||||
|
tools: [{ name: 'test' }],
|
||||||
|
availableTools: {
|
||||||
|
'test::server': { function: { description: 'Test', parameters: {} } },
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
await createMCPTools({
|
||||||
|
res: mockRes,
|
||||||
|
user: complexUser,
|
||||||
|
serverName: 'server',
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Verify the complete user object was passed
|
||||||
|
expect(capturedUser).toEqual(complexUser);
|
||||||
|
expect(capturedUser.id).toBe('complex-user');
|
||||||
|
expect(capturedUser.metadata.subscription).toBe('premium');
|
||||||
|
expect(capturedUser.metadata.settings.theme).toBe('dark');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw error when user is null', async () => {
|
||||||
|
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||||
|
|
||||||
|
mockReinitMCPServer.mockResolvedValue({
|
||||||
|
tools: [],
|
||||||
|
availableTools: {},
|
||||||
|
});
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
createMCPTools({
|
||||||
|
res: mockRes,
|
||||||
|
user: null,
|
||||||
|
serverName: 'test-server',
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
}),
|
||||||
|
).rejects.toThrow("Cannot read properties of null (reading 'id')");
|
||||||
|
|
||||||
|
// Verify reinitMCPServer was not called due to early error
|
||||||
|
expect(mockReinitMCPServer).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ const { getLogStores } = require('~/cache');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {Object} params
|
* @param {Object} params
|
||||||
* @param {string} params.userId
|
* @param {IUser} params.user - The user from the request object.
|
||||||
* @param {string} params.serverName - The name of the MCP server
|
* @param {string} params.serverName - The name of the MCP server
|
||||||
* @param {boolean} params.returnOnOAuth - Whether to initiate OAuth and return, or wait for OAuth flow to finish
|
* @param {boolean} params.returnOnOAuth - Whether to initiate OAuth and return, or wait for OAuth flow to finish
|
||||||
* @param {AbortSignal} [params.signal] - The abort signal to handle cancellation.
|
* @param {AbortSignal} [params.signal] - The abort signal to handle cancellation.
|
||||||
|
|
@ -18,7 +18,7 @@ const { getLogStores } = require('~/cache');
|
||||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||||
*/
|
*/
|
||||||
async function reinitMCPServer({
|
async function reinitMCPServer({
|
||||||
userId,
|
user,
|
||||||
signal,
|
signal,
|
||||||
forceNew,
|
forceNew,
|
||||||
serverName,
|
serverName,
|
||||||
|
|
@ -51,7 +51,7 @@ async function reinitMCPServer({
|
||||||
|
|
||||||
try {
|
try {
|
||||||
userConnection = await mcpManager.getUserConnection({
|
userConnection = await mcpManager.getUserConnection({
|
||||||
user: { id: userId },
|
user,
|
||||||
signal,
|
signal,
|
||||||
forceNew,
|
forceNew,
|
||||||
oauthStart,
|
oauthStart,
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import { extractEnvVariable } from 'librechat-data-provider';
|
import { extractEnvVariable } from 'librechat-data-provider';
|
||||||
import type { TUser, MCPOptions } from 'librechat-data-provider';
|
import type { TUser, MCPOptions } from 'librechat-data-provider';
|
||||||
|
import type { IUser } from '@librechat/data-schemas';
|
||||||
import type { RequestBody } from '~/types';
|
import type { RequestBody } from '~/types';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -26,6 +27,31 @@ const ALLOWED_USER_FIELDS = [
|
||||||
'termsAccepted',
|
'termsAccepted',
|
||||||
] as const;
|
] as const;
|
||||||
|
|
||||||
|
type AllowedUserField = (typeof ALLOWED_USER_FIELDS)[number];
|
||||||
|
type SafeUser = Pick<IUser, AllowedUserField>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a safe user object containing only allowed fields.
|
||||||
|
* Optimized for performance while maintaining type safety.
|
||||||
|
*
|
||||||
|
* @param user - The user object to extract safe fields from
|
||||||
|
* @returns A new object containing only allowed fields
|
||||||
|
*/
|
||||||
|
export function createSafeUser(user: IUser | null | undefined): Partial<SafeUser> {
|
||||||
|
if (!user) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
const safeUser: Partial<SafeUser> = {};
|
||||||
|
for (const field of ALLOWED_USER_FIELDS) {
|
||||||
|
if (field in user) {
|
||||||
|
safeUser[field] = user[field];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return safeUser;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* List of allowed request body fields that can be used in header placeholders.
|
* List of allowed request body fields that can be used in header placeholders.
|
||||||
* These are common fields from the request body that are safe to expose in headers.
|
* These are common fields from the request body that are safe to expose in headers.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue