👤 fix: Missing User Placeholder Fields for MCP Services (#9824)

This commit is contained in:
Danny Avila 2025-09-24 22:48:38 -04:00 committed by GitHub
parent 57f8b333bc
commit 4f3683fd9a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 388 additions and 38 deletions

View file

@ -1,8 +1,13 @@
const { logger } = require('@librechat/data-schemas');
const { SerpAPI } = require('@langchain/community/tools/serpapi');
const { Calculator } = require('@langchain/community/tools/calculator');
const { mcpToolPattern, loadWebSearchAuth, checkAccess } = require('@librechat/api');
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
const {
checkAccess,
createSafeUser,
mcpToolPattern,
loadWebSearchAuth,
} = require('@librechat/api');
const {
Tools,
Constants,
@ -410,6 +415,7 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
/** MCP server tools are initialized sequentially by server */
let index = -1;
const failedMCPServers = new Set();
const safeUser = createSafeUser(options.req?.user);
for (const [serverName, toolConfigs] of Object.entries(requestedMCPTools)) {
index++;
/** @type {LCAvailableTools} */
@ -420,14 +426,14 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
continue;
}
const mcpParams = {
res: options.res,
userId: user,
index,
serverName: config.serverName,
userMCPAuthMap,
model: agent?.model ?? model,
provider: agent?.provider ?? endpoint,
signal,
user: safeUser,
userMCPAuthMap,
res: options.res,
model: agent?.model ?? model,
serverName: config.serverName,
provider: agent?.provider ?? endpoint,
};
if (config.type === 'all' && toolConfigs.length === 1) {

View file

@ -1,7 +1,12 @@
const { Router } = require('express');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { MCPOAuthHandler, MCPTokenStorage, getUserMCPAuthMap } = require('@librechat/api');
const {
createSafeUser,
MCPOAuthHandler,
MCPTokenStorage,
getUserMCPAuthMap,
} = require('@librechat/api');
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
@ -335,9 +340,9 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => {
router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
try {
const { serverName } = req.params;
const userId = req.user?.id;
const user = createSafeUser(req.user);
if (!userId) {
if (!user.id) {
return res.status(401).json({ error: 'User not authenticated' });
}
@ -351,7 +356,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
});
}
await mcpManager.disconnectUserConnection(userId, serverName);
await mcpManager.disconnectUserConnection(user.id, serverName);
logger.info(
`[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`,
);
@ -360,14 +365,14 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
let userMCPAuthMap;
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
userMCPAuthMap = await getUserMCPAuthMap({
userId,
userId: user.id,
servers: [serverName],
findPluginAuthsByKeys,
});
}
const result = await reinitMCPServer({
userId,
user,
serverName,
userMCPAuthMap,
});

View file

@ -153,7 +153,7 @@ function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
/**
* @param {Object} params
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.userId - The user ID from the request object.
* @param {IUser} params.user - The user from the request object.
* @param {string} params.serverName
* @param {AbortSignal} params.signal
* @param {string} params.model
@ -161,9 +161,9 @@ function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
*/
async function reconnectServer({ res, userId, index, signal, serverName, userMCPAuthMap }) {
async function reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap }) {
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
const flowId = `${userId}:${serverName}:${Date.now()}`;
const flowId = `${user.id}:${serverName}:${Date.now()}`;
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
const stepId = 'step_oauth_login_' + serverName;
const toolCall = {
@ -192,7 +192,7 @@ async function reconnectServer({ res, userId, index, signal, serverName, userMCP
flowManager,
});
return await reinitMCPServer({
userId,
user,
signal,
serverName,
oauthStart,
@ -212,7 +212,7 @@ async function reconnectServer({ res, userId, index, signal, serverName, userMCP
*
* @param {Object} params
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.userId - The user ID from the request object.
* @param {IUser} params.user - The user from the request object.
* @param {string} params.serverName
* @param {string} params.model
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
@ -221,16 +221,8 @@ async function reconnectServer({ res, userId, index, signal, serverName, userMCP
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
*/
async function createMCPTools({
res,
userId,
index,
signal,
serverName,
provider,
userMCPAuthMap,
}) {
const result = await reconnectServer({ res, userId, index, signal, serverName, userMCPAuthMap });
async function createMCPTools({ res, user, index, signal, serverName, provider, userMCPAuthMap }) {
const result = await reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap });
if (!result || !result.tools) {
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
return;
@ -240,7 +232,7 @@ async function createMCPTools({
for (const tool of result.tools) {
const toolInstance = await createMCPTool({
res,
userId,
user,
provider,
userMCPAuthMap,
availableTools: result.availableTools,
@ -258,7 +250,7 @@ async function createMCPTools({
* Creates a single tool from the specified MCP Server via `toolKey`.
* @param {Object} params
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.userId - The user ID from the request object.
* @param {IUser} params.user - The user from the request object.
* @param {string} params.toolKey - The toolKey for the tool.
* @param {string} params.model - The model for the tool.
* @param {number} [params.index]
@ -270,7 +262,7 @@ async function createMCPTools({
*/
async function createMCPTool({
res,
userId,
user,
index,
signal,
toolKey,
@ -288,7 +280,7 @@ async function createMCPTool({
);
const result = await reconnectServer({
res,
userId,
user,
index,
signal,
serverName,

View file

@ -1,13 +1,45 @@
const { logger } = require('@librechat/data-schemas');
const { MCPOAuthHandler } = require('@librechat/api');
const { CacheKeys } = require('librechat-data-provider');
const { getMCPSetupData, checkOAuthFlowStatus, getServerConnectionStatus } = require('./MCP');
const {
createMCPTool,
createMCPTools,
getMCPSetupData,
checkOAuthFlowStatus,
getServerConnectionStatus,
} = require('./MCP');
// Mock all dependencies
jest.mock('@librechat/data-schemas', () => ({
logger: {
debug: jest.fn(),
error: jest.fn(),
info: jest.fn(),
warn: jest.fn(),
},
}));
jest.mock('@langchain/core/tools', () => ({
tool: jest.fn((fn, config) => {
const toolInstance = { _call: fn, ...config };
return toolInstance;
}),
}));
jest.mock('@librechat/agents', () => ({
Providers: {
VERTEXAI: 'vertexai',
GOOGLE: 'google',
},
StepTypes: {
TOOL_CALLS: 'tool_calls',
},
GraphEvents: {
ON_RUN_STEP_DELTA: 'on_run_step_delta',
ON_RUN_STEP: 'on_run_step',
},
Constants: {
CONTENT_AND_ARTIFACT: 'content_and_artifact',
},
}));
@ -15,12 +47,27 @@ jest.mock('@librechat/api', () => ({
MCPOAuthHandler: {
generateFlowId: jest.fn(),
},
sendEvent: jest.fn(),
normalizeServerName: jest.fn((name) => name),
convertWithResolvedRefs: jest.fn((params) => params),
}));
jest.mock('librechat-data-provider', () => ({
CacheKeys: {
FLOWS: 'flows',
},
Constants: {
USE_PRELIM_RESPONSE_MESSAGE_ID: 'prelim_response_id',
mcp_delimiter: '::',
mcp_prefix: 'mcp_',
},
ContentTypes: {
TEXT: 'text',
},
isAssistantsEndpoint: jest.fn(() => false),
Time: {
TWO_MINUTES: 120000,
},
}));
jest.mock('./Config', () => ({
@ -44,8 +91,11 @@ jest.mock('~/models', () => ({
updateToken: jest.fn(),
}));
jest.mock('./Tools/mcp', () => ({
reinitMCPServer: jest.fn(),
}));
describe('tests for the new helper functions used by the MCP connection status endpoints', () => {
let mockLoadCustomConfig;
let mockGetMCPManager;
let mockGetFlowStateManager;
let mockGetLogStores;
@ -54,7 +104,6 @@ describe('tests for the new helper functions used by the MCP connection status e
beforeEach(() => {
jest.clearAllMocks();
mockLoadCustomConfig = require('./Config').loadCustomConfig;
mockGetMCPManager = require('~/config').getMCPManager;
mockGetFlowStateManager = require('~/config').getFlowStateManager;
mockGetLogStores = require('~/cache').getLogStores;
@ -567,3 +616,275 @@ describe('tests for the new helper functions used by the MCP connection status e
});
});
});
describe('User parameter passing tests', () => {
let mockReinitMCPServer;
let mockGetFlowStateManager;
let mockGetLogStores;
beforeEach(() => {
jest.clearAllMocks();
mockReinitMCPServer = require('./Tools/mcp').reinitMCPServer;
mockGetFlowStateManager = require('~/config').getFlowStateManager;
mockGetLogStores = require('~/cache').getLogStores;
// Setup default mocks
mockGetLogStores.mockReturnValue({});
mockGetFlowStateManager.mockReturnValue({
createFlowWithHandler: jest.fn(),
failFlow: jest.fn(),
});
});
describe('createMCPTools', () => {
it('should pass user parameter to reinitMCPServer when calling reconnectServer internally', async () => {
const mockUser = { id: 'test-user-123', name: 'Test User' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
const mockSignal = new AbortController().signal;
mockReinitMCPServer.mockResolvedValue({
tools: [{ name: 'test-tool' }],
availableTools: {
'test-tool::test-server': {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
},
},
},
});
await createMCPTools({
res: mockRes,
user: mockUser,
serverName: 'test-server',
provider: 'openai',
signal: mockSignal,
userMCPAuthMap: {},
});
// Verify reinitMCPServer was called with the user
expect(mockReinitMCPServer).toHaveBeenCalledWith(
expect.objectContaining({
user: mockUser,
serverName: 'test-server',
}),
);
expect(mockReinitMCPServer.mock.calls[0][0].user).toBe(mockUser);
});
it('should throw error if user is not provided', async () => {
const mockRes = { write: jest.fn(), flush: jest.fn() };
mockReinitMCPServer.mockResolvedValue({
tools: [],
availableTools: {},
});
// Call without user should throw error
await expect(
createMCPTools({
res: mockRes,
user: undefined,
serverName: 'test-server',
provider: 'openai',
userMCPAuthMap: {},
}),
).rejects.toThrow("Cannot read properties of undefined (reading 'id')");
// Verify reinitMCPServer was not called due to early error
expect(mockReinitMCPServer).not.toHaveBeenCalled();
});
});
describe('createMCPTool', () => {
it('should pass user parameter to reinitMCPServer when tool not in cache', async () => {
const mockUser = { id: 'test-user-456', email: 'test@example.com' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
const mockSignal = new AbortController().signal;
mockReinitMCPServer.mockResolvedValue({
availableTools: {
'test-tool::test-server': {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
},
},
},
});
// Call without availableTools to trigger reinit
await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'test-tool::test-server',
provider: 'openai',
signal: mockSignal,
userMCPAuthMap: {},
availableTools: undefined, // Force reinit
});
// Verify reinitMCPServer was called with the user
expect(mockReinitMCPServer).toHaveBeenCalledWith(
expect.objectContaining({
user: mockUser,
serverName: 'test-server',
}),
);
expect(mockReinitMCPServer.mock.calls[0][0].user).toBe(mockUser);
});
it('should not call reinitMCPServer when tool is in cache', async () => {
const mockUser = { id: 'test-user-789' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
const availableTools = {
'test-tool::test-server': {
function: {
description: 'Cached tool',
parameters: { type: 'object', properties: {} },
},
},
};
await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'test-tool::test-server',
provider: 'openai',
userMCPAuthMap: {},
availableTools: availableTools,
});
// Verify reinitMCPServer was NOT called since tool was in cache
expect(mockReinitMCPServer).not.toHaveBeenCalled();
});
});
describe('reinitMCPServer (via reconnectServer)', () => {
it('should always receive user parameter when called from createMCPTools', async () => {
const mockUser = { id: 'user-001', role: 'admin' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// Track all calls to reinitMCPServer
const reinitCalls = [];
mockReinitMCPServer.mockImplementation((params) => {
reinitCalls.push(params);
return Promise.resolve({
tools: [{ name: 'tool1' }, { name: 'tool2' }],
availableTools: {
'tool1::server1': { function: { description: 'Tool 1', parameters: {} } },
'tool2::server1': { function: { description: 'Tool 2', parameters: {} } },
},
});
});
await createMCPTools({
res: mockRes,
user: mockUser,
serverName: 'server1',
provider: 'anthropic',
userMCPAuthMap: {},
});
// Verify all calls to reinitMCPServer had the user
expect(reinitCalls.length).toBeGreaterThan(0);
reinitCalls.forEach((call) => {
expect(call.user).toBe(mockUser);
expect(call.user.id).toBe('user-001');
});
});
it('should always receive user parameter when called from createMCPTool', async () => {
const mockUser = { id: 'user-002', permissions: ['read', 'write'] };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// Track all calls to reinitMCPServer
const reinitCalls = [];
mockReinitMCPServer.mockImplementation((params) => {
reinitCalls.push(params);
return Promise.resolve({
availableTools: {
'my-tool::my-server': {
function: { description: 'My Tool', parameters: {} },
},
},
});
});
await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'my-tool::my-server',
provider: 'google',
userMCPAuthMap: {},
availableTools: undefined, // Force reinit
});
// Verify the call to reinitMCPServer had the user
expect(reinitCalls.length).toBe(1);
expect(reinitCalls[0].user).toBe(mockUser);
expect(reinitCalls[0].user.id).toBe('user-002');
});
});
describe('User parameter integrity', () => {
it('should preserve user object properties through the call chain', async () => {
const complexUser = {
id: 'complex-user',
name: 'John Doe',
email: 'john@example.com',
metadata: { subscription: 'premium', settings: { theme: 'dark' } },
};
const mockRes = { write: jest.fn(), flush: jest.fn() };
let capturedUser = null;
mockReinitMCPServer.mockImplementation((params) => {
capturedUser = params.user;
return Promise.resolve({
tools: [{ name: 'test' }],
availableTools: {
'test::server': { function: { description: 'Test', parameters: {} } },
},
});
});
await createMCPTools({
res: mockRes,
user: complexUser,
serverName: 'server',
provider: 'openai',
userMCPAuthMap: {},
});
// Verify the complete user object was passed
expect(capturedUser).toEqual(complexUser);
expect(capturedUser.id).toBe('complex-user');
expect(capturedUser.metadata.subscription).toBe('premium');
expect(capturedUser.metadata.settings.theme).toBe('dark');
});
it('should throw error when user is null', async () => {
const mockRes = { write: jest.fn(), flush: jest.fn() };
mockReinitMCPServer.mockResolvedValue({
tools: [],
availableTools: {},
});
await expect(
createMCPTools({
res: mockRes,
user: null,
serverName: 'test-server',
provider: 'openai',
userMCPAuthMap: {},
}),
).rejects.toThrow("Cannot read properties of null (reading 'id')");
// Verify reinitMCPServer was not called due to early error
expect(mockReinitMCPServer).not.toHaveBeenCalled();
});
});
});

View file

@ -7,7 +7,7 @@ const { getLogStores } = require('~/cache');
/**
* @param {Object} params
* @param {string} params.userId
* @param {IUser} params.user - The user from the request object.
* @param {string} params.serverName - The name of the MCP server
* @param {boolean} params.returnOnOAuth - Whether to initiate OAuth and return, or wait for OAuth flow to finish
* @param {AbortSignal} [params.signal] - The abort signal to handle cancellation.
@ -18,7 +18,7 @@ const { getLogStores } = require('~/cache');
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
*/
async function reinitMCPServer({
userId,
user,
signal,
forceNew,
serverName,
@ -51,7 +51,7 @@ async function reinitMCPServer({
try {
userConnection = await mcpManager.getUserConnection({
user: { id: userId },
user,
signal,
forceNew,
oauthStart,

View file

@ -1,5 +1,6 @@
import { extractEnvVariable } from 'librechat-data-provider';
import type { TUser, MCPOptions } from 'librechat-data-provider';
import type { IUser } from '@librechat/data-schemas';
import type { RequestBody } from '~/types';
/**
@ -26,6 +27,31 @@ const ALLOWED_USER_FIELDS = [
'termsAccepted',
] as const;
type AllowedUserField = (typeof ALLOWED_USER_FIELDS)[number];
type SafeUser = Pick<IUser, AllowedUserField>;
/**
* Creates a safe user object containing only allowed fields.
* Optimized for performance while maintaining type safety.
*
* @param user - The user object to extract safe fields from
* @returns A new object containing only allowed fields
*/
export function createSafeUser(user: IUser | null | undefined): Partial<SafeUser> {
if (!user) {
return {};
}
const safeUser: Partial<SafeUser> = {};
for (const field of ALLOWED_USER_FIELDS) {
if (field in user) {
safeUser[field] = user[field];
}
}
return safeUser;
}
/**
* List of allowed request body fields that can be used in header placeholders.
* These are common fields from the request body that are safe to expose in headers.