Merge branch 'main' into fix-image_gen_oai-with-nova-models

This commit is contained in:
Peter 2026-02-18 08:55:05 +01:00 committed by GitHub
commit cbb6b1a7d5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
211 changed files with 10746 additions and 1656 deletions

View file

@ -47,6 +47,10 @@ TRUST_PROXY=1
# password policies.
# MIN_PASSWORD_LENGTH=8
# When enabled, the app will continue running after encountering uncaught exceptions
# instead of exiting the process. Not recommended for production unless necessary.
# CONTINUE_ON_UNCAUGHT_EXCEPTION=false
#===============#
# JSON Logging #
#===============#
@ -131,7 +135,7 @@ PROXY=
#============#
ANTHROPIC_API_KEY=user_provided
# ANTHROPIC_MODELS=claude-opus-4-6,claude-opus-4-20250514,claude-sonnet-4-20250514,claude-3-7-sonnet-20250219,claude-3-5-sonnet-20241022,claude-3-5-haiku-20241022,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307
# ANTHROPIC_MODELS=claude-sonnet-4-6,claude-opus-4-6,claude-opus-4-20250514,claude-sonnet-4-20250514,claude-3-7-sonnet-20250219,claude-3-5-sonnet-20241022,claude-3-5-haiku-20241022,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307
# ANTHROPIC_REVERSE_PROXY=
# Set to true to use Anthropic models through Google Vertex AI instead of direct API
@ -166,8 +170,8 @@ ANTHROPIC_API_KEY=user_provided
# BEDROCK_AWS_SESSION_TOKEN=someSessionToken
# Note: This example list is not meant to be exhaustive. If omitted, all known, supported model IDs will be included for you.
# BEDROCK_AWS_MODELS=anthropic.claude-opus-4-6-v1,anthropic.claude-3-5-sonnet-20240620-v1:0,meta.llama3-1-8b-instruct-v1:0
# Cross-region inference model IDs: us.anthropic.claude-opus-4-6-v1,global.anthropic.claude-opus-4-6-v1
# BEDROCK_AWS_MODELS=anthropic.claude-sonnet-4-6,anthropic.claude-opus-4-6-v1,anthropic.claude-3-5-sonnet-20240620-v1:0,meta.llama3-1-8b-instruct-v1:0
# Cross-region inference model IDs: us.anthropic.claude-sonnet-4-6,us.anthropic.claude-opus-4-6-v1,global.anthropic.claude-opus-4-6-v1
# See all Bedrock model IDs here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns
@ -748,8 +752,10 @@ HELP_AND_FAQ_URL=https://librechat.ai
# REDIS_PING_INTERVAL=300
# Force specific cache namespaces to use in-memory storage even when Redis is enabled
# Comma-separated list of CacheKeys (e.g., ROLES,MESSAGES)
# FORCED_IN_MEMORY_CACHE_NAMESPACES=ROLES,MESSAGES
# Comma-separated list of CacheKeys
# Defaults to CONFIG_STORE,APP_CONFIG so YAML-derived config stays per-container (safe for blue/green deployments)
# Set to empty string to force all namespaces through Redis: FORCED_IN_MEMORY_CACHE_NAMESPACES=
# FORCED_IN_MEMORY_CACHE_NAMESPACES=CONFIG_STORE,APP_CONFIG
# Leader Election Configuration (for multi-instance deployments with Redis)
# Duration in seconds that the leader lease is valid before it expires (default: 25)

3
.gitignore vendored
View file

@ -30,6 +30,9 @@ coverage
config/translations/stores/*
client/src/localization/languages/*_missing_keys.json
# Turborepo
.turbo
# Compiled Dirs (http://nodejs.org/api/addons.html)
build/
dist/

View file

@ -55,6 +55,7 @@ const banViolation = async (req, res, errorMessage) => {
res.clearCookie('refreshToken');
res.clearCookie('openid_access_token');
res.clearCookie('openid_id_token');
res.clearCookie('openid_user_id');
res.clearCookie('token_provider');

View file

@ -37,6 +37,7 @@ const namespaces = {
[CacheKeys.ROLES]: standardCache(CacheKeys.ROLES),
[CacheKeys.APP_CONFIG]: standardCache(CacheKeys.APP_CONFIG),
[CacheKeys.CONFIG_STORE]: standardCache(CacheKeys.CONFIG_STORE),
[CacheKeys.TOOL_CACHE]: standardCache(CacheKeys.TOOL_CACHE),
[CacheKeys.PENDING_REQ]: standardCache(CacheKeys.PENDING_REQ),
[CacheKeys.ENCODED_DOMAINS]: new Keyv({ store: keyvMongo, namespace: CacheKeys.ENCODED_DOMAINS }),
[CacheKeys.ABORT_KEYS]: standardCache(CacheKeys.ABORT_KEYS, Time.TEN_MINUTES),

View file

@ -40,6 +40,10 @@ if (!cached) {
cached = global.mongoose = { conn: null, promise: null };
}
mongoose.connection.on('error', (err) => {
logger.error('[connectDb] MongoDB connection error:', err);
});
async function connectDb() {
if (cached.conn && cached.conn?._readyState === 1) {
return cached.conn;

View file

@ -26,7 +26,7 @@ async function batchResetMeiliFlags(collection) {
try {
while (hasMore) {
const docs = await collection
.find({ expiredAt: null, _meiliIndex: true }, { projection: { _id: 1 } })
.find({ expiredAt: null, _meiliIndex: { $ne: false } }, { projection: { _id: 1 } })
.limit(BATCH_SIZE)
.toArray();

View file

@ -265,8 +265,8 @@ describe('batchResetMeiliFlags', () => {
const result = await batchResetMeiliFlags(testCollection);
// Only one document has _meiliIndex: true
expect(result).toBe(1);
// both documents should be updated
expect(result).toBe(2);
});
it('should handle mixed document states correctly', async () => {
@ -275,16 +275,18 @@ describe('batchResetMeiliFlags', () => {
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: false },
{ _id: new mongoose.Types.ObjectId(), expiredAt: new Date(), _meiliIndex: true },
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: null },
{ _id: new mongoose.Types.ObjectId(), expiredAt: null },
]);
const result = await batchResetMeiliFlags(testCollection);
expect(result).toBe(2);
expect(result).toBe(4);
const flaggedDocs = await testCollection
.find({ expiredAt: null, _meiliIndex: false })
.toArray();
expect(flaggedDocs).toHaveLength(3); // 2 were updated, 1 was already false
expect(flaggedDocs).toHaveLength(5); // 4 were updated, 1 was already false
});
});

View file

@ -124,10 +124,15 @@ module.exports = {
updateOperation,
{
new: true,
upsert: true,
upsert: metadata?.noUpsert !== true,
},
);
if (!conversation) {
logger.debug('[saveConvo] Conversation not found, skipping update');
return null;
}
return conversation.toObject();
} catch (error) {
logger.error('[saveConvo] Error saving conversation', error);

View file

@ -106,6 +106,47 @@ describe('Conversation Operations', () => {
expect(result.conversationId).toBe(newConversationId);
});
it('should not create a conversation when noUpsert is true and conversation does not exist', async () => {
const nonExistentId = uuidv4();
const result = await saveConvo(
mockReq,
{ conversationId: nonExistentId, title: 'Ghost Title' },
{ noUpsert: true },
);
expect(result).toBeNull();
const dbConvo = await Conversation.findOne({ conversationId: nonExistentId });
expect(dbConvo).toBeNull();
});
it('should update an existing conversation when noUpsert is true', async () => {
await saveConvo(mockReq, mockConversationData);
const result = await saveConvo(
mockReq,
{ conversationId: mockConversationData.conversationId, title: 'Updated Title' },
{ noUpsert: true },
);
expect(result).not.toBeNull();
expect(result.title).toBe('Updated Title');
expect(result.conversationId).toBe(mockConversationData.conversationId);
});
it('should still upsert by default when noUpsert is not provided', async () => {
const newId = uuidv4();
const result = await saveConvo(mockReq, {
conversationId: newId,
title: 'New Conversation',
endpoint: EModelEndpoint.openAI,
});
expect(result).not.toBeNull();
expect(result.conversationId).toBe(newId);
expect(result.title).toBe('New Conversation');
});
it('should handle unsetFields metadata', async () => {
const metadata = {
unsetFields: { someField: 1 },
@ -122,7 +163,6 @@ describe('Conversation Operations', () => {
describe('isTemporary conversation handling', () => {
it('should save a conversation with expiredAt when isTemporary is true', async () => {
// Mock app config with 24 hour retention
mockReq.config.interfaceConfig.temporaryChatRetention = 24;
mockReq.body = { isTemporary: true };
@ -135,7 +175,6 @@ describe('Conversation Operations', () => {
expect(result.expiredAt).toBeDefined();
expect(result.expiredAt).toBeInstanceOf(Date);
// Verify expiredAt is approximately 24 hours in the future
const expectedExpirationTime = new Date(beforeSave.getTime() + 24 * 60 * 60 * 1000);
const actualExpirationTime = new Date(result.expiredAt);
@ -157,7 +196,6 @@ describe('Conversation Operations', () => {
});
it('should save a conversation without expiredAt when isTemporary is not provided', async () => {
// No isTemporary in body
mockReq.body = {};
const result = await saveConvo(mockReq, mockConversationData);
@ -167,7 +205,6 @@ describe('Conversation Operations', () => {
});
it('should use custom retention period from config', async () => {
// Mock app config with 48 hour retention
mockReq.config.interfaceConfig.temporaryChatRetention = 48;
mockReq.body = { isTemporary: true };

View file

@ -176,6 +176,7 @@ const tokenValues = Object.assign(
'claude-opus-4-5': { prompt: 5, completion: 25 },
'claude-opus-4-6': { prompt: 5, completion: 25 },
'claude-sonnet-4': { prompt: 3, completion: 15 },
'claude-sonnet-4-6': { prompt: 3, completion: 15 },
'command-r': { prompt: 0.5, completion: 1.5 },
'command-r-plus': { prompt: 3, completion: 15 },
'command-text': { prompt: 1.5, completion: 2.0 },
@ -309,6 +310,7 @@ const cacheTokenValues = {
'claude-3-haiku': { write: 0.3, read: 0.03 },
'claude-haiku-4-5': { write: 1.25, read: 0.1 },
'claude-sonnet-4': { write: 3.75, read: 0.3 },
'claude-sonnet-4-6': { write: 3.75, read: 0.3 },
'claude-opus-4': { write: 18.75, read: 1.5 },
'claude-opus-4-5': { write: 6.25, read: 0.5 },
'claude-opus-4-6': { write: 6.25, read: 0.5 },
@ -337,6 +339,7 @@ const cacheTokenValues = {
*/
const premiumTokenValues = {
'claude-opus-4-6': { threshold: 200000, prompt: 10, completion: 37.5 },
'claude-sonnet-4-6': { threshold: 200000, prompt: 6, completion: 22.5 },
};
/**

View file

@ -44,14 +44,14 @@
"@google/genai": "^1.19.0",
"@keyv/redis": "^4.3.3",
"@langchain/core": "^0.3.80",
"@librechat/agents": "^3.1.38",
"@librechat/agents": "^3.1.50",
"@librechat/api": "*",
"@librechat/data-schemas": "*",
"@microsoft/microsoft-graph-client": "^3.0.7",
"@modelcontextprotocol/sdk": "^1.26.0",
"@node-saml/passport-saml": "^5.1.0",
"@smithy/node-http-handler": "^4.4.5",
"axios": "^1.12.1",
"axios": "^1.13.5",
"bcryptjs": "^2.4.3",
"compression": "^1.8.1",
"connect-redis": "^8.1.0",

View file

@ -18,7 +18,6 @@ const {
findUser,
} = require('~/models');
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
const { getOAuthReconnectionManager } = require('~/config');
const { getOpenIdConfig } = require('~/strategies');
const registrationController = async (req, res) => {
@ -79,7 +78,12 @@ const refreshController = async (req, res) => {
try {
const openIdConfig = getOpenIdConfig();
const tokenset = await openIdClient.refreshTokenGrant(openIdConfig, refreshToken);
const refreshParams = process.env.OPENID_SCOPE ? { scope: process.env.OPENID_SCOPE } : {};
const tokenset = await openIdClient.refreshTokenGrant(
openIdConfig,
refreshToken,
refreshParams,
);
const claims = tokenset.claims();
const { user, error, migration } = await findOpenIDUser({
findUser,
@ -161,17 +165,6 @@ const refreshController = async (req, res) => {
if (session && session.expiration > new Date()) {
const token = await setAuthTokens(userId, res, session);
// trigger OAuth MCP server reconnection asynchronously (best effort)
try {
void getOAuthReconnectionManager()
.reconnectServers(userId)
.catch((err) => {
logger.error('[refreshController] Error reconnecting OAuth MCP servers:', err);
});
} catch (err) {
logger.warn(`[refreshController] Cannot attempt OAuth MCP servers reconnection:`, err);
}
res.status(200).send({ token, user });
} else if (req?.query?.retry) {
// Retrying from a refresh token request that failed (401)

View file

@ -8,7 +8,7 @@ const { getLogStores } = require('~/cache');
const getAvailablePluginsController = async (req, res) => {
try {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cache = getLogStores(CacheKeys.TOOL_CACHE);
const cachedPlugins = await cache.get(CacheKeys.PLUGINS);
if (cachedPlugins) {
res.status(200).json(cachedPlugins);
@ -63,7 +63,7 @@ const getAvailableTools = async (req, res) => {
logger.warn('[getAvailableTools] User ID not found in request');
return res.status(401).json({ message: 'Unauthorized' });
}
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cache = getLogStores(CacheKeys.TOOL_CACHE);
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));

View file

@ -1,3 +1,4 @@
const { CacheKeys } = require('librechat-data-provider');
const { getCachedTools, getAppConfig } = require('~/server/services/Config');
const { getLogStores } = require('~/cache');
@ -63,6 +64,28 @@ describe('PluginController', () => {
});
});
describe('cache namespace', () => {
it('getAvailablePluginsController should use TOOL_CACHE namespace', async () => {
mockCache.get.mockResolvedValue([]);
await getAvailablePluginsController(mockReq, mockRes);
expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE);
});
it('getAvailableTools should use TOOL_CACHE namespace', async () => {
mockCache.get.mockResolvedValue([]);
await getAvailableTools(mockReq, mockRes);
expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE);
});
it('should NOT use CONFIG_STORE namespace for tool/plugin operations', async () => {
mockCache.get.mockResolvedValue([]);
await getAvailablePluginsController(mockReq, mockRes);
await getAvailableTools(mockReq, mockRes);
const allCalls = getLogStores.mock.calls.flat();
expect(allCalls).not.toContain(CacheKeys.CONFIG_STORE);
});
});
describe('getAvailablePluginsController', () => {
it('should use filterUniquePlugins to remove duplicate plugins', async () => {
// Add plugins with duplicates to availableTools

View file

@ -36,6 +36,7 @@ const {
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
const { getMCPManager, getFlowStateManager, getMCPServersRegistry } = require('~/config');
const { invalidateCachedTools } = require('~/server/services/Config/getCachedTools');
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
const { processDeleteRequest } = require('~/server/services/Files/process');
const { getAppConfig } = require('~/server/services/Config');
@ -215,6 +216,7 @@ const updateUserPluginsController = async (req, res) => {
`[updateUserPluginsController] Attempting disconnect of MCP server "${serverName}" for user ${user.id} after plugin auth update.`,
);
await mcpManager.disconnectUserConnection(user.id, serverName);
await invalidateCachedTools({ userId: user.id, serverName });
}
} catch (disconnectError) {
logger.error(

View file

@ -20,7 +20,6 @@ jest.mock('@librechat/agents', () => ({
getMessageId: jest.fn(),
ToolEndHandler: jest.fn(),
handleToolCalls: jest.fn(),
ChatModelStreamHandler: jest.fn(),
}));
jest.mock('~/server/services/Files/Citations', () => ({

View file

@ -30,9 +30,6 @@ jest.mock('@librechat/agents', () => ({
messages: [],
indexTokenCountMap: {},
}),
ChatModelStreamHandler: jest.fn().mockImplementation(() => ({
handle: jest.fn(),
})),
}));
jest.mock('@librechat/api', () => ({

View file

@ -34,9 +34,6 @@ jest.mock('@librechat/agents', () => ({
messages: [],
indexTokenCountMap: {},
}),
ChatModelStreamHandler: jest.fn().mockImplementation(() => ({
handle: jest.fn(),
})),
}));
jest.mock('@librechat/api', () => ({

View file

@ -1,22 +1,13 @@
const { nanoid } = require('nanoid');
const { Constants } = require('@librechat/agents');
const { logger } = require('@librechat/data-schemas');
const { Constants, EnvVar, GraphEvents, ToolEndHandler } = require('@librechat/agents');
const { Tools, StepTypes, FileContext, ErrorTypes } = require('librechat-data-provider');
const {
sendEvent,
GenerationJobManager,
writeAttachmentEvent,
createToolExecuteHandler,
} = require('@librechat/api');
const { Tools, StepTypes, FileContext, ErrorTypes } = require('librechat-data-provider');
const {
EnvVar,
Providers,
GraphEvents,
getMessageId,
ToolEndHandler,
handleToolCalls,
ChatModelStreamHandler,
} = require('@librechat/agents');
const { processFileCitations } = require('~/server/services/Files/Citations');
const { processCodeOutput } = require('~/server/services/Files/Code/process');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
@ -57,8 +48,6 @@ class ModelEndHandler {
let errorMessage;
try {
const agentContext = graph.getAgentContext(metadata);
const isGoogle = agentContext.provider === Providers.GOOGLE;
const streamingDisabled = !!agentContext.clientOptions?.disableStreaming;
if (data?.output?.additional_kwargs?.stop_reason === 'refusal') {
const info = { ...data.output.additional_kwargs };
errorMessage = JSON.stringify({
@ -73,21 +62,6 @@ class ModelEndHandler {
});
}
const toolCalls = data?.output?.tool_calls;
let hasUnprocessedToolCalls = false;
if (Array.isArray(toolCalls) && toolCalls.length > 0 && graph?.toolCallStepIds?.has) {
try {
hasUnprocessedToolCalls = toolCalls.some(
(tc) => tc?.id && !graph.toolCallStepIds.has(tc.id),
);
} catch {
hasUnprocessedToolCalls = false;
}
}
if (isGoogle || streamingDisabled || hasUnprocessedToolCalls) {
await handleToolCalls(toolCalls, metadata, graph);
}
const usage = data?.output?.usage_metadata;
if (!usage) {
return this.finalize(errorMessage);
@ -98,38 +72,6 @@ class ModelEndHandler {
}
this.collectedUsage.push(usage);
if (!streamingDisabled) {
return this.finalize(errorMessage);
}
if (!data.output.content) {
return this.finalize(errorMessage);
}
const stepKey = graph.getStepKey(metadata);
const message_id = getMessageId(stepKey, graph) ?? '';
if (message_id) {
await graph.dispatchRunStep(stepKey, {
type: StepTypes.MESSAGE_CREATION,
message_creation: {
message_id,
},
});
}
const stepId = graph.getStepIdByKey(stepKey);
const content = data.output.content;
if (typeof content === 'string') {
await graph.dispatchMessageDelta(stepId, {
content: [
{
type: 'text',
text: content,
},
],
});
} else if (content.every((c) => c.type?.startsWith('text'))) {
await graph.dispatchMessageDelta(stepId, {
content,
});
}
} catch (error) {
logger.error('Error handling model end event:', error);
return this.finalize(errorMessage);
@ -200,7 +142,6 @@ function getDefaultHandlers({
const handlers = {
[GraphEvents.CHAT_MODEL_END]: new ModelEndHandler(collectedUsage),
[GraphEvents.TOOL_END]: new ToolEndHandler(toolEndCallback, logger),
[GraphEvents.CHAT_MODEL_STREAM]: new ChatModelStreamHandler(),
[GraphEvents.ON_RUN_STEP]: {
/**
* Handle ON_RUN_STEP event.
@ -209,6 +150,7 @@ function getDefaultHandlers({
* @param {GraphRunnableConfig['configurable']} [metadata] The runnable metadata.
*/
handle: async (event, data, metadata) => {
aggregateContent({ event, data });
if (data?.stepDetails.type === StepTypes.TOOL_CALLS) {
await emitEvent(res, streamId, { event, data });
} else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
@ -227,7 +169,6 @@ function getDefaultHandlers({
},
});
}
aggregateContent({ event, data });
},
},
[GraphEvents.ON_RUN_STEP_DELTA]: {
@ -238,6 +179,7 @@ function getDefaultHandlers({
* @param {GraphRunnableConfig['configurable']} [metadata] The runnable metadata.
*/
handle: async (event, data, metadata) => {
aggregateContent({ event, data });
if (data?.delta.type === StepTypes.TOOL_CALLS) {
await emitEvent(res, streamId, { event, data });
} else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
@ -245,7 +187,6 @@ function getDefaultHandlers({
} else if (!metadata?.hide_sequential_outputs) {
await emitEvent(res, streamId, { event, data });
}
aggregateContent({ event, data });
},
},
[GraphEvents.ON_RUN_STEP_COMPLETED]: {
@ -256,6 +197,7 @@ function getDefaultHandlers({
* @param {GraphRunnableConfig['configurable']} [metadata] The runnable metadata.
*/
handle: async (event, data, metadata) => {
aggregateContent({ event, data });
if (data?.result != null) {
await emitEvent(res, streamId, { event, data });
} else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
@ -263,7 +205,6 @@ function getDefaultHandlers({
} else if (!metadata?.hide_sequential_outputs) {
await emitEvent(res, streamId, { event, data });
}
aggregateContent({ event, data });
},
},
[GraphEvents.ON_MESSAGE_DELTA]: {
@ -274,12 +215,12 @@ function getDefaultHandlers({
* @param {GraphRunnableConfig['configurable']} [metadata] The runnable metadata.
*/
handle: async (event, data, metadata) => {
aggregateContent({ event, data });
if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
await emitEvent(res, streamId, { event, data });
} else if (!metadata?.hide_sequential_outputs) {
await emitEvent(res, streamId, { event, data });
}
aggregateContent({ event, data });
},
},
[GraphEvents.ON_REASONING_DELTA]: {
@ -290,12 +231,12 @@ function getDefaultHandlers({
* @param {GraphRunnableConfig['configurable']} [metadata] The runnable metadata.
*/
handle: async (event, data, metadata) => {
aggregateContent({ event, data });
if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
await emitEvent(res, streamId, { event, data });
} else if (!metadata?.hide_sequential_outputs) {
await emitEvent(res, streamId, { event, data });
}
aggregateContent({ event, data });
},
},
};

View file

@ -6,18 +6,22 @@ const {
Tokenizer,
checkAccess,
buildToolSet,
logAxiosError,
sanitizeTitle,
logToolError,
payloadParser,
resolveHeaders,
createSafeUser,
initializeAgent,
getBalanceConfig,
getProviderConfig,
omitTitleOptions,
memoryInstructions,
applyContextToAgent,
createTokenCounter,
GenerationJobManager,
getTransactionsConfig,
createMemoryProcessor,
createMultiAgentMapper,
filterMalformedContentParts,
} = require('@librechat/api');
const {
@ -25,9 +29,7 @@ const {
Providers,
TitleMethod,
formatMessage,
labelContentByAgent,
formatAgentMessages,
getTokenCountForMessage,
createMetadataAggregator,
} = require('@librechat/agents');
const {
@ -39,7 +41,6 @@ const {
PermissionTypes,
isAgentsEndpoint,
isEphemeralAgentId,
bedrockInputSchema,
removeNullishValues,
} = require('librechat-data-provider');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
@ -52,183 +53,6 @@ const { loadAgent } = require('~/models/Agent');
const { getMCPManager } = require('~/config');
const db = require('~/models');
const omitTitleOptions = new Set([
'stream',
'thinking',
'streaming',
'clientOptions',
'thinkingConfig',
'thinkingBudget',
'includeThoughts',
'maxOutputTokens',
'additionalModelRequestFields',
]);
/**
* @param {ServerRequest} req
* @param {Agent} agent
* @param {string} endpoint
*/
const payloadParser = ({ req, agent, endpoint }) => {
if (isAgentsEndpoint(endpoint)) {
return { model: undefined };
} else if (endpoint === EModelEndpoint.bedrock) {
const parsedValues = bedrockInputSchema.parse(agent.model_parameters);
if (parsedValues.thinking == null) {
parsedValues.thinking = false;
}
return parsedValues;
}
return req.body.endpointOption.model_parameters;
};
function createTokenCounter(encoding) {
return function (message) {
const countTokens = (text) => Tokenizer.getTokenCount(text, encoding);
return getTokenCountForMessage(message, countTokens);
};
}
function logToolError(graph, error, toolId) {
logAxiosError({
error,
message: `[api/server/controllers/agents/client.js #chatCompletion] Tool Error "${toolId}"`,
});
}
/** Regex pattern to match agent ID suffix (____N) */
const AGENT_SUFFIX_PATTERN = /____(\d+)$/;
/**
* Finds the primary agent ID within a set of agent IDs.
* Primary = no suffix (____N) or lowest suffix number.
* @param {Set<string>} agentIds
* @returns {string | null}
*/
function findPrimaryAgentId(agentIds) {
let primaryAgentId = null;
let lowestSuffixIndex = Infinity;
for (const agentId of agentIds) {
const suffixMatch = agentId.match(AGENT_SUFFIX_PATTERN);
if (!suffixMatch) {
return agentId;
}
const suffixIndex = parseInt(suffixMatch[1], 10);
if (suffixIndex < lowestSuffixIndex) {
lowestSuffixIndex = suffixIndex;
primaryAgentId = agentId;
}
}
return primaryAgentId;
}
/**
* Creates a mapMethod for getMessagesForConversation that processes agent content.
* - Strips agentId/groupId metadata from all content
* - For parallel agents (addedConvo with groupId): filters each group to its primary agent
* - For handoffs (agentId without groupId): keeps all content from all agents
* - For multi-agent: applies agent labels to content
*
* The key distinction:
* - Parallel execution (addedConvo): Parts have both agentId AND groupId
* - Handoffs: Parts only have agentId, no groupId
*
* @param {Agent} primaryAgent - Primary agent configuration
* @param {Map<string, Agent>} [agentConfigs] - Additional agent configurations
* @returns {(message: TMessage) => TMessage} Map method for processing messages
*/
function createMultiAgentMapper(primaryAgent, agentConfigs) {
const hasMultipleAgents = (primaryAgent.edges?.length ?? 0) > 0 || (agentConfigs?.size ?? 0) > 0;
/** @type {Record<string, string> | null} */
let agentNames = null;
if (hasMultipleAgents) {
agentNames = { [primaryAgent.id]: primaryAgent.name || 'Assistant' };
if (agentConfigs) {
for (const [agentId, agentConfig] of agentConfigs.entries()) {
agentNames[agentId] = agentConfig.name || agentConfig.id;
}
}
}
return (message) => {
if (message.isCreatedByUser || !Array.isArray(message.content)) {
return message;
}
// Check for metadata
const hasAgentMetadata = message.content.some((part) => part?.agentId || part?.groupId != null);
if (!hasAgentMetadata) {
return message;
}
try {
// Build a map of groupId -> Set of agentIds, to find primary per group
/** @type {Map<number, Set<string>>} */
const groupAgentMap = new Map();
for (const part of message.content) {
const groupId = part?.groupId;
const agentId = part?.agentId;
if (groupId != null && agentId) {
if (!groupAgentMap.has(groupId)) {
groupAgentMap.set(groupId, new Set());
}
groupAgentMap.get(groupId).add(agentId);
}
}
// For each group, find the primary agent
/** @type {Map<number, string>} */
const groupPrimaryMap = new Map();
for (const [groupId, agentIds] of groupAgentMap) {
const primary = findPrimaryAgentId(agentIds);
if (primary) {
groupPrimaryMap.set(groupId, primary);
}
}
/** @type {Array<TMessageContentParts>} */
const filteredContent = [];
/** @type {Record<number, string>} */
const agentIdMap = {};
for (const part of message.content) {
const agentId = part?.agentId;
const groupId = part?.groupId;
// Filtering logic:
// - No groupId (handoffs): always include
// - Has groupId (parallel): only include if it's the primary for that group
const isParallelPart = groupId != null;
const groupPrimary = isParallelPart ? groupPrimaryMap.get(groupId) : null;
const shouldInclude = !isParallelPart || !agentId || agentId === groupPrimary;
if (shouldInclude) {
const newIndex = filteredContent.length;
const { agentId: _a, groupId: _g, ...cleanPart } = part;
filteredContent.push(cleanPart);
if (agentId && hasMultipleAgents) {
agentIdMap[newIndex] = agentId;
}
}
}
const finalContent =
Object.keys(agentIdMap).length > 0 && agentNames
? labelContentByAgent(filteredContent, agentIdMap, agentNames)
: filteredContent;
return { ...message, content: finalContent };
} catch (error) {
logger.error('[AgentClient] Error processing multi-agent message:', error);
return message;
}
};
}
class AgentClient extends BaseClient {
constructor(options = {}) {
super(null, options);
@ -296,14 +120,9 @@ class AgentClient extends BaseClient {
checkVisionRequest() {}
getSaveOptions() {
// TODO:
// would need to be override settings; otherwise, model needs to be undefined
// model: this.override.model,
// instructions: this.override.instructions,
// additional_instructions: this.override.additional_instructions,
let runOptions = {};
try {
runOptions = payloadParser(this.options);
runOptions = payloadParser(this.options) ?? {};
} catch (error) {
logger.error(
'[api/server/controllers/agents/client.js #getSaveOptions] Error parsing options',
@ -314,14 +133,14 @@ class AgentClient extends BaseClient {
return removeNullishValues(
Object.assign(
{
spec: this.options.spec,
iconURL: this.options.iconURL,
endpoint: this.options.endpoint,
agent_id: this.options.agent.id,
modelLabel: this.options.modelLabel,
maxContextTokens: this.options.maxContextTokens,
resendFiles: this.options.resendFiles,
imageDetail: this.options.imageDetail,
spec: this.options.spec,
iconURL: this.options.iconURL,
maxContextTokens: this.maxContextTokens,
},
// TODO: PARSE OPTIONS BY PROVIDER, MAY CONTAIN SENSITIVE DATA
runOptions,
@ -969,7 +788,7 @@ class AgentClient extends BaseClient {
},
user: createSafeUser(this.options.req.user),
},
recursionLimit: agentsEConfig?.recursionLimit ?? 25,
recursionLimit: agentsEConfig?.recursionLimit ?? 50,
signal: abortController.signal,
streamMode: 'values',
version: 'v2',

View file

@ -1,12 +1,7 @@
const { nanoid } = require('nanoid');
const { logger } = require('@librechat/data-schemas');
const { Callback, ToolEndHandler, formatAgentMessages } = require('@librechat/agents');
const { EModelEndpoint, ResourceType, PermissionBits } = require('librechat-data-provider');
const {
Callback,
ToolEndHandler,
formatAgentMessages,
ChatModelStreamHandler,
} = require('@librechat/agents');
const {
writeSSE,
createRun,
@ -325,18 +320,8 @@ const OpenAIChatCompletionController = async (req, res) => {
}
};
// Built-in handler for processing raw model stream chunks
const chatModelStreamHandler = new ChatModelStreamHandler();
// Event handlers for OpenAI-compatible streaming
const handlers = {
// Process raw model chunks and dispatch message/reasoning deltas
on_chat_model_stream: {
handle: async (event, data, metadata, graph) => {
await chatModelStreamHandler.handle(event, data, metadata, graph);
},
},
// Text content streaming
on_message_delta: createHandler((data) => {
const content = data?.delta?.content;
@ -577,7 +562,14 @@ const OpenAIChatCompletionController = async (req, res) => {
writeSSE(res, '[DONE]');
res.end();
} else {
sendErrorResponse(res, 500, errorMessage, 'server_error');
// Forward upstream provider status codes (e.g., Anthropic 400s) instead of masking as 500
const statusCode =
typeof error?.status === 'number' && error.status >= 400 && error.status < 600
? error.status
: 500;
const errorType =
statusCode >= 400 && statusCode < 500 ? 'invalid_request_error' : 'server_error';
sendErrorResponse(res, statusCode, errorMessage, errorType);
}
}
};

View file

@ -1,13 +1,8 @@
const { nanoid } = require('nanoid');
const { v4: uuidv4 } = require('uuid');
const { logger } = require('@librechat/data-schemas');
const { Callback, ToolEndHandler, formatAgentMessages } = require('@librechat/agents');
const { EModelEndpoint, ResourceType, PermissionBits } = require('librechat-data-provider');
const {
Callback,
ToolEndHandler,
formatAgentMessages,
ChatModelStreamHandler,
} = require('@librechat/agents');
const {
createRun,
buildToolSet,
@ -410,9 +405,6 @@ const createResponse = async (req, res) => {
// Collect usage for balance tracking
const collectedUsage = [];
// Built-in handler for processing raw model stream chunks
const chatModelStreamHandler = new ChatModelStreamHandler();
// Artifact promises for processing tool outputs
/** @type {Promise<import('librechat-data-provider').TAttachment | null>[]} */
const artifactPromises = [];
@ -443,11 +435,6 @@ const createResponse = async (req, res) => {
// Combine handlers
const handlers = {
on_chat_model_stream: {
handle: async (event, data, metadata, graph) => {
await chatModelStreamHandler.handle(event, data, metadata, graph);
},
},
on_message_delta: responsesHandlers.on_message_delta,
on_reasoning_delta: responsesHandlers.on_reasoning_delta,
on_run_step: responsesHandlers.on_run_step,
@ -570,8 +557,6 @@ const createResponse = async (req, res) => {
} else {
const aggregatorHandlers = createAggregatorEventHandlers(aggregator);
const chatModelStreamHandler = new ChatModelStreamHandler();
// Collect usage for balance tracking
const collectedUsage = [];
@ -596,11 +581,6 @@ const createResponse = async (req, res) => {
};
const handlers = {
on_chat_model_stream: {
handle: async (event, data, metadata, graph) => {
await chatModelStreamHandler.handle(event, data, metadata, graph);
},
},
on_message_delta: aggregatorHandlers.on_message_delta,
on_reasoning_delta: aggregatorHandlers.on_reasoning_delta,
on_run_step: aggregatorHandlers.on_run_step,
@ -727,7 +707,13 @@ const createResponse = async (req, res) => {
writeDone(res);
res.end();
} else {
sendResponsesErrorResponse(res, 500, errorMessage, 'server_error');
// Forward upstream provider status codes (e.g., Anthropic 400s) instead of masking as 500
const statusCode =
typeof error?.status === 'number' && error.status >= 400 && error.status < 600
? error.status
: 500;
const errorType = statusCode >= 400 && statusCode < 500 ? 'invalid_request' : 'server_error';
sendResponsesErrorResponse(res, statusCode, errorMessage, errorType);
}
}
};

View file

@ -22,6 +22,7 @@ const logoutController = async (req, res) => {
res.clearCookie('refreshToken');
res.clearCookie('openid_access_token');
res.clearCookie('openid_id_token');
res.clearCookie('openid_user_id');
res.clearCookie('token_provider');
const response = { message };

View file

@ -251,6 +251,15 @@ process.on('uncaughtException', (err) => {
return;
}
if (isEnabled(process.env.CONTINUE_ON_UNCAUGHT_EXCEPTION)) {
logger.error('Unhandled error encountered. The app will continue running.', {
name: err?.name,
message: err?.message,
stack: err?.stack,
});
return;
}
process.exit(1);
});

View file

@ -5,9 +5,11 @@ const {
EModelEndpoint,
isAgentsEndpoint,
parseCompactConvo,
getDefaultParamsEndpoint,
} = require('librechat-data-provider');
const azureAssistants = require('~/server/services/Endpoints/azureAssistants');
const assistants = require('~/server/services/Endpoints/assistants');
const { getEndpointsConfig } = require('~/server/services/Config');
const agents = require('~/server/services/Endpoints/agents');
const { updateFilesUsage } = require('~/models');
@ -19,9 +21,24 @@ const buildFunction = {
async function buildEndpointOption(req, res, next) {
const { endpoint, endpointType } = req.body;
let endpointsConfig;
try {
endpointsConfig = await getEndpointsConfig(req);
} catch (error) {
logger.error('Error fetching endpoints config in buildEndpointOption', error);
}
const defaultParamsEndpoint = getDefaultParamsEndpoint(endpointsConfig, endpoint);
let parsedBody;
try {
parsedBody = parseCompactConvo({ endpoint, endpointType, conversation: req.body });
parsedBody = parseCompactConvo({
endpoint,
endpointType,
conversation: req.body,
defaultParamsEndpoint,
});
} catch (error) {
logger.error(`Error parsing compact conversation for endpoint ${endpoint}`, error);
logger.debug({
@ -55,6 +72,7 @@ async function buildEndpointOption(req, res, next) {
endpoint,
endpointType,
conversation: currentModelSpec.preset,
defaultParamsEndpoint,
});
if (currentModelSpec.iconURL != null && currentModelSpec.iconURL !== '') {
parsedBody.iconURL = currentModelSpec.iconURL;

View file

@ -0,0 +1,237 @@
/**
* Wrap parseCompactConvo: the REAL function runs, but jest can observe
* calls and return values. Must be declared before require('./buildEndpointOption')
* so the destructured reference in the middleware captures the wrapper.
*/
jest.mock('librechat-data-provider', () => {
const actual = jest.requireActual('librechat-data-provider');
return {
...actual,
parseCompactConvo: jest.fn((...args) => actual.parseCompactConvo(...args)),
};
});
const { EModelEndpoint, parseCompactConvo } = require('librechat-data-provider');
const mockBuildOptions = jest.fn((_endpoint, parsedBody) => ({
...parsedBody,
endpoint: _endpoint,
}));
jest.mock('~/server/services/Endpoints/azureAssistants', () => ({
buildOptions: mockBuildOptions,
}));
jest.mock('~/server/services/Endpoints/assistants', () => ({
buildOptions: mockBuildOptions,
}));
jest.mock('~/server/services/Endpoints/agents', () => ({
buildOptions: mockBuildOptions,
}));
jest.mock('~/models', () => ({
updateFilesUsage: jest.fn(),
}));
const mockGetEndpointsConfig = jest.fn();
jest.mock('~/server/services/Config', () => ({
getEndpointsConfig: (...args) => mockGetEndpointsConfig(...args),
}));
jest.mock('@librechat/api', () => ({
handleError: jest.fn(),
}));
const buildEndpointOption = require('./buildEndpointOption');
const createReq = (body, config = {}) => ({
body,
config,
baseUrl: '/api/chat',
});
const createRes = () => ({
status: jest.fn().mockReturnThis(),
json: jest.fn().mockReturnThis(),
});
describe('buildEndpointOption - defaultParamsEndpoint parsing', () => {
beforeEach(() => {
jest.clearAllMocks();
});
it('should pass defaultParamsEndpoint to parseCompactConvo and preserve maxOutputTokens', async () => {
mockGetEndpointsConfig.mockResolvedValue({
AnthropicClaude: {
type: EModelEndpoint.custom,
customParams: {
defaultParamsEndpoint: EModelEndpoint.anthropic,
},
},
});
const req = createReq(
{
endpoint: 'AnthropicClaude',
endpointType: EModelEndpoint.custom,
model: 'anthropic/claude-opus-4.5',
temperature: 0.7,
maxOutputTokens: 8192,
topP: 0.9,
maxContextTokens: 50000,
},
{ modelSpecs: null },
);
await buildEndpointOption(req, createRes(), jest.fn());
expect(parseCompactConvo).toHaveBeenCalledWith(
expect.objectContaining({
defaultParamsEndpoint: EModelEndpoint.anthropic,
}),
);
const parsedResult = parseCompactConvo.mock.results[0].value;
expect(parsedResult.maxOutputTokens).toBe(8192);
expect(parsedResult.topP).toBe(0.9);
expect(parsedResult.temperature).toBe(0.7);
expect(parsedResult.maxContextTokens).toBe(50000);
});
it('should strip maxOutputTokens when no defaultParamsEndpoint is configured', async () => {
mockGetEndpointsConfig.mockResolvedValue({
MyOpenRouter: {
type: EModelEndpoint.custom,
},
});
const req = createReq(
{
endpoint: 'MyOpenRouter',
endpointType: EModelEndpoint.custom,
model: 'gpt-4o',
temperature: 0.7,
maxOutputTokens: 8192,
max_tokens: 4096,
},
{ modelSpecs: null },
);
await buildEndpointOption(req, createRes(), jest.fn());
expect(parseCompactConvo).toHaveBeenCalledWith(
expect.objectContaining({
defaultParamsEndpoint: undefined,
}),
);
const parsedResult = parseCompactConvo.mock.results[0].value;
expect(parsedResult.maxOutputTokens).toBeUndefined();
expect(parsedResult.max_tokens).toBe(4096);
expect(parsedResult.temperature).toBe(0.7);
});
it('should strip bedrock region from custom endpoint without defaultParamsEndpoint', async () => {
mockGetEndpointsConfig.mockResolvedValue({
MyEndpoint: {
type: EModelEndpoint.custom,
},
});
const req = createReq(
{
endpoint: 'MyEndpoint',
endpointType: EModelEndpoint.custom,
model: 'gpt-4o',
temperature: 0.7,
region: 'us-east-1',
},
{ modelSpecs: null },
);
await buildEndpointOption(req, createRes(), jest.fn());
const parsedResult = parseCompactConvo.mock.results[0].value;
expect(parsedResult.region).toBeUndefined();
expect(parsedResult.temperature).toBe(0.7);
});
it('should pass defaultParamsEndpoint when re-parsing enforced model spec', async () => {
mockGetEndpointsConfig.mockResolvedValue({
AnthropicClaude: {
type: EModelEndpoint.custom,
customParams: {
defaultParamsEndpoint: EModelEndpoint.anthropic,
},
},
});
const modelSpec = {
name: 'claude-opus-4.5',
preset: {
endpoint: 'AnthropicClaude',
endpointType: EModelEndpoint.custom,
model: 'anthropic/claude-opus-4.5',
temperature: 0.7,
maxOutputTokens: 8192,
maxContextTokens: 50000,
},
};
const req = createReq(
{
endpoint: 'AnthropicClaude',
endpointType: EModelEndpoint.custom,
spec: 'claude-opus-4.5',
model: 'anthropic/claude-opus-4.5',
},
{
modelSpecs: {
enforce: true,
list: [modelSpec],
},
},
);
await buildEndpointOption(req, createRes(), jest.fn());
const enforcedCall = parseCompactConvo.mock.calls[1];
expect(enforcedCall[0]).toEqual(
expect.objectContaining({
defaultParamsEndpoint: EModelEndpoint.anthropic,
}),
);
const enforcedResult = parseCompactConvo.mock.results[1].value;
expect(enforcedResult.maxOutputTokens).toBe(8192);
expect(enforcedResult.temperature).toBe(0.7);
expect(enforcedResult.maxContextTokens).toBe(50000);
});
it('should fall back to OpenAI schema when getEndpointsConfig fails', async () => {
mockGetEndpointsConfig.mockRejectedValue(new Error('Config unavailable'));
const req = createReq(
{
endpoint: 'AnthropicClaude',
endpointType: EModelEndpoint.custom,
model: 'anthropic/claude-opus-4.5',
temperature: 0.7,
maxOutputTokens: 8192,
max_tokens: 4096,
},
{ modelSpecs: null },
);
await buildEndpointOption(req, createRes(), jest.fn());
expect(parseCompactConvo).toHaveBeenCalledWith(
expect.objectContaining({
defaultParamsEndpoint: undefined,
}),
);
const parsedResult = parseCompactConvo.mock.results[0].value;
expect(parsedResult.maxOutputTokens).toBeUndefined();
expect(parsedResult.max_tokens).toBe(4096);
});
});

View file

@ -7,16 +7,13 @@ const { isEnabled } = require('@librechat/api');
* Switches between JWT and OpenID authentication based on cookies and environment settings
*/
const requireJwtAuth = (req, res, next) => {
// Check if token provider is specified in cookies
const cookieHeader = req.headers.cookie;
const tokenProvider = cookieHeader ? cookies.parse(cookieHeader).token_provider : null;
// Use OpenID authentication if token provider is OpenID and OPENID_REUSE_TOKENS is enabled
if (tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
return passport.authenticate('openidJwt', { session: false })(req, res, next);
}
// Default to standard JWT authentication
return passport.authenticate('jwt', { session: false })(req, res, next);
};

View file

@ -385,6 +385,40 @@ describe('Convos Routes', () => {
expect(deleteConvoSharedLink).not.toHaveBeenCalled();
});
it('should return 400 when request body is empty (DoS prevention)', async () => {
const response = await request(app).delete('/api/convos').send({});
expect(response.status).toBe(400);
expect(response.body).toEqual({ error: 'no parameters provided' });
expect(deleteConvos).not.toHaveBeenCalled();
});
it('should return 400 when arg is null (DoS prevention)', async () => {
const response = await request(app).delete('/api/convos').send({ arg: null });
expect(response.status).toBe(400);
expect(response.body).toEqual({ error: 'no parameters provided' });
expect(deleteConvos).not.toHaveBeenCalled();
});
it('should return 400 when arg is undefined (DoS prevention)', async () => {
const response = await request(app).delete('/api/convos').send({ arg: undefined });
expect(response.status).toBe(400);
expect(response.body).toEqual({ error: 'no parameters provided' });
expect(deleteConvos).not.toHaveBeenCalled();
});
it('should return 400 when request body is null (DoS prevention)', async () => {
const response = await request(app)
.delete('/api/convos')
.set('Content-Type', 'application/json')
.send('null');
expect(response.status).toBe(400);
expect(deleteConvos).not.toHaveBeenCalled();
});
it('should return 500 if deleteConvoSharedLink fails', async () => {
const mockConversationId = 'conv-error';

View file

@ -0,0 +1,174 @@
const express = require('express');
const request = require('supertest');
jest.mock('~/models', () => ({
updateUserKey: jest.fn(),
deleteUserKey: jest.fn(),
getUserKeyExpiry: jest.fn(),
}));
jest.mock('~/server/middleware/requireJwtAuth', () => (req, res, next) => next());
jest.mock('~/server/middleware', () => ({
requireJwtAuth: (req, res, next) => next(),
}));
describe('Keys Routes', () => {
let app;
const { updateUserKey, deleteUserKey, getUserKeyExpiry } = require('~/models');
beforeAll(() => {
const keysRouter = require('../keys');
app = express();
app.use(express.json());
app.use((req, res, next) => {
req.user = { id: 'test-user-123' };
next();
});
app.use('/api/keys', keysRouter);
});
beforeEach(() => {
jest.clearAllMocks();
});
describe('PUT /', () => {
it('should update a user key with the authenticated user ID', async () => {
updateUserKey.mockResolvedValue({});
const response = await request(app)
.put('/api/keys')
.send({ name: 'openAI', value: 'sk-test-key-123', expiresAt: '2026-12-31' });
expect(response.status).toBe(201);
expect(updateUserKey).toHaveBeenCalledWith({
userId: 'test-user-123',
name: 'openAI',
value: 'sk-test-key-123',
expiresAt: '2026-12-31',
});
expect(updateUserKey).toHaveBeenCalledTimes(1);
});
it('should not allow userId override via request body (IDOR prevention)', async () => {
updateUserKey.mockResolvedValue({});
const response = await request(app).put('/api/keys').send({
userId: 'attacker-injected-id',
name: 'openAI',
value: 'sk-attacker-key',
});
expect(response.status).toBe(201);
expect(updateUserKey).toHaveBeenCalledWith({
userId: 'test-user-123',
name: 'openAI',
value: 'sk-attacker-key',
expiresAt: undefined,
});
});
it('should ignore extraneous fields from request body', async () => {
updateUserKey.mockResolvedValue({});
const response = await request(app).put('/api/keys').send({
name: 'openAI',
value: 'sk-test-key',
expiresAt: '2026-12-31',
_id: 'injected-mongo-id',
__v: 99,
extra: 'should-be-ignored',
});
expect(response.status).toBe(201);
expect(updateUserKey).toHaveBeenCalledWith({
userId: 'test-user-123',
name: 'openAI',
value: 'sk-test-key',
expiresAt: '2026-12-31',
});
});
it('should handle missing optional fields', async () => {
updateUserKey.mockResolvedValue({});
const response = await request(app)
.put('/api/keys')
.send({ name: 'anthropic', value: 'sk-ant-key' });
expect(response.status).toBe(201);
expect(updateUserKey).toHaveBeenCalledWith({
userId: 'test-user-123',
name: 'anthropic',
value: 'sk-ant-key',
expiresAt: undefined,
});
});
it('should return 400 when request body is null', async () => {
const response = await request(app)
.put('/api/keys')
.set('Content-Type', 'application/json')
.send('null');
expect(response.status).toBe(400);
expect(updateUserKey).not.toHaveBeenCalled();
});
});
describe('DELETE /:name', () => {
it('should delete a user key by name', async () => {
deleteUserKey.mockResolvedValue({});
const response = await request(app).delete('/api/keys/openAI');
expect(response.status).toBe(204);
expect(deleteUserKey).toHaveBeenCalledWith({
userId: 'test-user-123',
name: 'openAI',
});
expect(deleteUserKey).toHaveBeenCalledTimes(1);
});
});
describe('DELETE /', () => {
it('should delete all keys when all=true', async () => {
deleteUserKey.mockResolvedValue({});
const response = await request(app).delete('/api/keys?all=true');
expect(response.status).toBe(204);
expect(deleteUserKey).toHaveBeenCalledWith({
userId: 'test-user-123',
all: true,
});
});
it('should return 400 when all query param is not true', async () => {
const response = await request(app).delete('/api/keys');
expect(response.status).toBe(400);
expect(response.body).toEqual({ error: 'Specify either all=true to delete.' });
expect(deleteUserKey).not.toHaveBeenCalled();
});
});
describe('GET /', () => {
it('should return key expiry for a given key name', async () => {
const mockExpiry = { expiresAt: '2026-12-31' };
getUserKeyExpiry.mockResolvedValue(mockExpiry);
const response = await request(app).get('/api/keys?name=openAI');
expect(response.status).toBe(200);
expect(response.body).toEqual(mockExpiry);
expect(getUserKeyExpiry).toHaveBeenCalledWith({
userId: 'test-user-123',
name: 'openAI',
});
});
});
});

View file

@ -1,8 +1,18 @@
const crypto = require('crypto');
const express = require('express');
const request = require('supertest');
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const cookieParser = require('cookie-parser');
const { getBasePath } = require('@librechat/api');
const { MongoMemoryServer } = require('mongodb-memory-server');
function generateTestCsrfToken(flowId) {
return crypto
.createHmac('sha256', process.env.JWT_SECRET)
.update(flowId)
.digest('hex')
.slice(0, 32);
}
const mockRegistryInstance = {
getServerConfig: jest.fn(),
@ -130,6 +140,7 @@ describe('MCP Routes', () => {
app = express();
app.use(express.json());
app.use(cookieParser());
app.use((req, res, next) => {
req.user = { id: 'test-user-id' };
@ -168,12 +179,12 @@ describe('MCP Routes', () => {
MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({
authorizationUrl: 'https://oauth.example.com/auth',
flowId: 'test-flow-id',
flowId: 'test-user-id:test-server',
});
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
userId: 'test-user-id',
flowId: 'test-flow-id',
flowId: 'test-user-id:test-server',
});
expect(response.status).toBe(302);
@ -190,7 +201,7 @@ describe('MCP Routes', () => {
it('should return 403 when userId does not match authenticated user', async () => {
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
userId: 'different-user-id',
flowId: 'test-flow-id',
flowId: 'test-user-id:test-server',
});
expect(response.status).toBe(403);
@ -228,7 +239,7 @@ describe('MCP Routes', () => {
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
userId: 'test-user-id',
flowId: 'test-flow-id',
flowId: 'test-user-id:test-server',
});
expect(response.status).toBe(400);
@ -245,7 +256,7 @@ describe('MCP Routes', () => {
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
userId: 'test-user-id',
flowId: 'test-flow-id',
flowId: 'test-user-id:test-server',
});
expect(response.status).toBe(500);
@ -255,7 +266,7 @@ describe('MCP Routes', () => {
it('should return 400 when flow state metadata is null', async () => {
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({
id: 'test-flow-id',
id: 'test-user-id:test-server',
metadata: null,
}),
};
@ -265,7 +276,7 @@ describe('MCP Routes', () => {
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
userId: 'test-user-id',
flowId: 'test-flow-id',
flowId: 'test-user-id:test-server',
});
expect(response.status).toBe(400);
@ -280,7 +291,7 @@ describe('MCP Routes', () => {
it('should redirect to error page when OAuth error is received', async () => {
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
error: 'access_denied',
state: 'test-flow-id',
state: 'test-user-id:test-server',
});
const basePath = getBasePath();
@ -290,7 +301,7 @@ describe('MCP Routes', () => {
it('should redirect to error page when code is missing', async () => {
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
state: 'test-flow-id',
state: 'test-user-id:test-server',
});
const basePath = getBasePath();
@ -308,15 +319,50 @@ describe('MCP Routes', () => {
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=missing_state`);
});
it('should redirect to error page when flow state is not found', async () => {
MCPOAuthHandler.getFlowState.mockResolvedValue(null);
it('should redirect to error page when CSRF cookie is missing', async () => {
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'invalid-flow-id',
state: 'test-user-id:test-server',
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(
`${basePath}/oauth/error?error=csrf_validation_failed`,
);
});
it('should redirect to error page when CSRF cookie does not match state', async () => {
const csrfToken = generateTestCsrfToken('different-flow-id');
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: 'test-user-id:test-server',
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(
`${basePath}/oauth/error?error=csrf_validation_failed`,
);
});
it('should redirect to error page when flow state is not found', async () => {
MCPOAuthHandler.getFlowState.mockResolvedValue(null);
const flowId = 'invalid-flow:id';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: flowId,
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`);
});
@ -369,16 +415,22 @@ describe('MCP Routes', () => {
});
setCachedTools.mockResolvedValue();
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
});
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: flowId,
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith(
'test-flow-id',
flowId,
'test-auth-code',
mockFlowManager,
{},
@ -400,16 +452,24 @@ describe('MCP Routes', () => {
'mcp_oauth',
mockTokens,
);
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith(
'test-user-id:test-server',
'mcp_get_tokens',
);
});
it('should redirect to error page when callback processing fails', async () => {
MCPOAuthHandler.getFlowState.mockRejectedValue(new Error('Callback error'));
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
});
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: flowId,
});
const basePath = getBasePath();
expect(response.status).toBe(302);
@ -442,15 +502,21 @@ describe('MCP Routes', () => {
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
});
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: flowId,
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith(flowId, 'mcp_get_tokens');
});
it('should handle reconnection failure after OAuth', async () => {
@ -488,16 +554,22 @@ describe('MCP Routes', () => {
getCachedTools.mockResolvedValue({});
setCachedTools.mockResolvedValue();
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
});
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: flowId,
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
expect(MCPTokenStorage.storeTokens).toHaveBeenCalled();
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith(flowId, 'mcp_get_tokens');
});
it('should redirect to error page if token storage fails', async () => {
@ -530,10 +602,16 @@ describe('MCP Routes', () => {
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
});
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: flowId,
});
const basePath = getBasePath();
expect(response.status).toBe(302);
@ -589,22 +667,27 @@ describe('MCP Routes', () => {
clearReconnection: jest.fn(),
});
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
});
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: flowId,
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
// Verify storeTokens was called with ORIGINAL flow state credentials
expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith(
expect.objectContaining({
userId: 'test-user-id',
serverName: 'test-server',
tokens: mockTokens,
clientInfo: clientInfo, // Uses original flow state, not any "updated" credentials
clientInfo: clientInfo,
metadata: flowState.metadata,
}),
);
@ -631,16 +714,21 @@ describe('MCP Routes', () => {
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
});
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: flowId,
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
// Verify completeOAuthFlow was NOT called (prevented duplicate)
expect(MCPOAuthHandler.completeOAuthFlow).not.toHaveBeenCalled();
expect(MCPTokenStorage.storeTokens).not.toHaveBeenCalled();
});
@ -755,7 +843,7 @@ describe('MCP Routes', () => {
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app).get('/api/mcp/oauth/status/test-flow-id');
const response = await request(app).get('/api/mcp/oauth/status/test-user-id:test-server');
expect(response.status).toBe(200);
expect(response.body).toEqual({
@ -766,6 +854,13 @@ describe('MCP Routes', () => {
});
});
it('should return 403 when flowId does not match authenticated user', async () => {
const response = await request(app).get('/api/mcp/oauth/status/other-user-id:test-server');
expect(response.status).toBe(403);
expect(response.body).toEqual({ error: 'Access denied' });
});
it('should return 404 when flow is not found', async () => {
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue(null),
@ -774,7 +869,7 @@ describe('MCP Routes', () => {
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app).get('/api/mcp/oauth/status/non-existent-flow');
const response = await request(app).get('/api/mcp/oauth/status/test-user-id:non-existent');
expect(response.status).toBe(404);
expect(response.body).toEqual({ error: 'Flow not found' });
@ -788,7 +883,7 @@ describe('MCP Routes', () => {
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app).get('/api/mcp/oauth/status/error-flow-id');
const response = await request(app).get('/api/mcp/oauth/status/test-user-id:error-server');
expect(response.status).toBe(500);
expect(response.body).toEqual({ error: 'Failed to get flow status' });
@ -1375,7 +1470,7 @@ describe('MCP Routes', () => {
refresh_token: 'edge-refresh-token',
};
MCPOAuthHandler.getFlowState = jest.fn().mockResolvedValue({
id: 'test-flow-id',
id: 'test-user-id:test-server',
userId: 'test-user-id',
metadata: {
serverUrl: 'https://example.com',
@ -1403,8 +1498,12 @@ describe('MCP Routes', () => {
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id')
.get(`/api/mcp/test-server/oauth/callback?code=test-code&state=${flowId}`)
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.expect(302);
const basePath = getBasePath();
@ -1424,7 +1523,7 @@ describe('MCP Routes', () => {
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({
id: 'test-flow-id',
id: 'test-user-id:test-server',
userId: 'test-user-id',
metadata: { serverUrl: 'https://example.com', oauth: {} },
clientInfo: {},
@ -1453,8 +1552,12 @@ describe('MCP Routes', () => {
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id')
.get(`/api/mcp/test-server/oauth/callback?code=test-code&state=${flowId}`)
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.expect(302);
const basePath = getBasePath();

View file

@ -1,14 +1,47 @@
const express = require('express');
const jwt = require('jsonwebtoken');
const { getAccessToken, getBasePath } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const {
getBasePath,
getAccessToken,
setOAuthSession,
validateOAuthCsrf,
OAUTH_CSRF_COOKIE,
setOAuthCsrfCookie,
validateOAuthSession,
OAUTH_SESSION_COOKIE,
} = require('@librechat/api');
const { findToken, updateToken, createToken } = require('~/models');
const { requireJwtAuth } = require('~/server/middleware');
const { getFlowStateManager } = require('~/config');
const { getLogStores } = require('~/cache');
const router = express.Router();
const JWT_SECRET = process.env.JWT_SECRET;
const OAUTH_CSRF_COOKIE_PATH = '/api/actions';
/**
* Sets a CSRF cookie binding the action OAuth flow to the current browser session.
* Must be called before the user opens the IdP authorization URL.
*
* @route POST /actions/:action_id/oauth/bind
*/
router.post('/:action_id/oauth/bind', requireJwtAuth, setOAuthSession, async (req, res) => {
try {
const { action_id } = req.params;
const user = req.user;
if (!user?.id) {
return res.status(401).json({ error: 'User not authenticated' });
}
const flowId = `${user.id}:${action_id}`;
setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH);
res.json({ success: true });
} catch (error) {
logger.error('[Action OAuth] Failed to set CSRF binding cookie', error);
res.status(500).json({ error: 'Failed to bind OAuth flow' });
}
});
/**
* Handles the OAuth callback and exchanges the authorization code for tokens.
@ -45,7 +78,22 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
await flowManager.failFlow(identifier, 'oauth', 'Invalid user ID in state parameter');
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
}
identifier = `${decodedState.user}:${action_id}`;
if (
!validateOAuthCsrf(req, res, identifier, OAUTH_CSRF_COOKIE_PATH) &&
!validateOAuthSession(req, decodedState.user)
) {
logger.error('[Action OAuth] CSRF validation failed: no valid CSRF or session cookie', {
identifier,
hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE],
hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE],
});
await flowManager.failFlow(identifier, 'oauth', 'CSRF validation failed');
return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`);
}
const flowState = await flowManager.getFlowState(identifier, 'oauth');
if (!flowState) {
throw new Error('OAuth flow not found');
@ -71,7 +119,6 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
);
await flowManager.completeFlow(identifier, 'oauth', tokenData);
/** Redirect to React success page */
const serverName = flowState.metadata?.action_name || `Action ${action_id}`;
const redirectUrl = `${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`;
res.redirect(redirectUrl);

View file

@ -98,7 +98,7 @@ router.get('/gen_title/:conversationId', async (req, res) => {
router.delete('/', async (req, res) => {
let filter = {};
const { conversationId, source, thread_id, endpoint } = req.body.arg;
const { conversationId, source, thread_id, endpoint } = req.body?.arg ?? {};
// Prevent deletion of all conversations
if (!conversationId && !source && !thread_id && !endpoint) {
@ -160,7 +160,7 @@ router.delete('/all', async (req, res) => {
* @returns {object} 200 - The updated conversation object.
*/
router.post('/archive', validateConvoAccess, async (req, res) => {
const { conversationId, isArchived } = req.body.arg ?? {};
const { conversationId, isArchived } = req.body?.arg ?? {};
if (!conversationId) {
return res.status(400).json({ error: 'conversationId is required' });
@ -194,7 +194,7 @@ const MAX_CONVO_TITLE_LENGTH = 1024;
* @returns {object} 201 - The updated conversation object.
*/
router.post('/update', validateConvoAccess, async (req, res) => {
const { conversationId, title } = req.body.arg ?? {};
const { conversationId, title } = req.body?.arg ?? {};
if (!conversationId) {
return res.status(400).json({ error: 'conversationId is required' });

View file

@ -5,7 +5,11 @@ const { requireJwtAuth } = require('~/server/middleware');
const router = express.Router();
router.put('/', requireJwtAuth, async (req, res) => {
await updateUserKey({ userId: req.user.id, ...req.body });
if (req.body == null || typeof req.body !== 'object') {
return res.status(400).send({ error: 'Invalid request body.' });
}
const { name, value, expiresAt } = req.body;
await updateUserKey({ userId: req.user.id, name, value, expiresAt });
res.status(201).send();
});

View file

@ -8,18 +8,32 @@ const {
Permissions,
} = require('librechat-data-provider');
const {
getBasePath,
createSafeUser,
MCPOAuthHandler,
MCPTokenStorage,
getBasePath,
setOAuthSession,
getUserMCPAuthMap,
validateOAuthCsrf,
OAUTH_CSRF_COOKIE,
setOAuthCsrfCookie,
generateCheckAccess,
validateOAuthSession,
OAUTH_SESSION_COOKIE,
} = require('@librechat/api');
const {
getMCPManager,
getFlowStateManager,
createMCPServerController,
updateMCPServerController,
deleteMCPServerController,
getMCPServersList,
getMCPServerById,
getMCPTools,
} = require('~/server/controllers/mcp');
const {
getOAuthReconnectionManager,
getMCPServersRegistry,
getFlowStateManager,
getMCPManager,
} = require('~/config');
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
const { requireJwtAuth, canAccessMCPServerResource } = require('~/server/middleware');
@ -27,20 +41,14 @@ const { findToken, updateToken, createToken, deleteTokens } = require('~/models'
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { updateMCPServerTools } = require('~/server/services/Config/mcp');
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
const { getMCPTools } = require('~/server/controllers/mcp');
const { findPluginAuthsByKeys } = require('~/models');
const { getRoleByName } = require('~/models/Role');
const { getLogStores } = require('~/cache');
const {
createMCPServerController,
getMCPServerById,
getMCPServersList,
updateMCPServerController,
deleteMCPServerController,
} = require('~/server/controllers/mcp');
const router = Router();
const OAUTH_CSRF_COOKIE_PATH = '/api/mcp';
/**
* Get all MCP tools available to the user
* Returns only MCP tools, completely decoupled from regular LibreChat tools
@ -53,7 +61,7 @@ router.get('/tools', requireJwtAuth, async (req, res) => {
* Initiate OAuth flow
* This endpoint is called when the user clicks the auth link in the UI
*/
router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => {
router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async (req, res) => {
try {
const { serverName } = req.params;
const { userId, flowId } = req.query;
@ -93,7 +101,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => {
logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl });
// Redirect user to the authorization URL
setOAuthCsrfCookie(res, oauthFlowId, OAUTH_CSRF_COOKIE_PATH);
res.redirect(authorizationUrl);
} catch (error) {
logger.error('[MCP OAuth] Failed to initiate OAuth', error);
@ -138,6 +146,25 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
const flowId = state;
logger.debug('[MCP OAuth] Using flow ID from state', { flowId });
const flowParts = flowId.split(':');
if (flowParts.length < 2 || !flowParts[0] || !flowParts[1]) {
logger.error('[MCP OAuth] Invalid flow ID format in state', { flowId });
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
}
const [flowUserId] = flowParts;
if (
!validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH) &&
!validateOAuthSession(req, flowUserId)
) {
logger.error('[MCP OAuth] CSRF validation failed: no valid CSRF or session cookie', {
flowId,
hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE],
hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE],
});
return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`);
}
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
@ -302,13 +329,47 @@ router.get('/oauth/tokens/:flowId', requireJwtAuth, async (req, res) => {
}
});
/**
* Set CSRF binding cookie for OAuth flows initiated outside of HTTP request/response
* (e.g. during chat via SSE). The frontend should call this before opening the OAuth URL
* so the callback can verify the browser matches the flow initiator.
*/
router.post('/:serverName/oauth/bind', requireJwtAuth, setOAuthSession, async (req, res) => {
try {
const { serverName } = req.params;
const user = req.user;
if (!user?.id) {
return res.status(401).json({ error: 'User not authenticated' });
}
const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH);
res.json({ success: true });
} catch (error) {
logger.error('[MCP OAuth] Failed to set CSRF binding cookie', error);
res.status(500).json({ error: 'Failed to bind OAuth flow' });
}
});
/**
* Check OAuth flow status
* This endpoint can be used to poll the status of an OAuth flow
*/
router.get('/oauth/status/:flowId', async (req, res) => {
router.get('/oauth/status/:flowId', requireJwtAuth, async (req, res) => {
try {
const { flowId } = req.params;
const user = req.user;
if (!user?.id) {
return res.status(401).json({ error: 'User not authenticated' });
}
if (!flowId.startsWith(`${user.id}:`) && !flowId.startsWith('system:')) {
return res.status(403).json({ error: 'Access denied' });
}
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
@ -375,7 +436,7 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => {
* Reinitialize MCP server
* This endpoint allows reinitializing a specific MCP server
*/
router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
router.post('/:serverName/reinitialize', requireJwtAuth, setOAuthSession, async (req, res) => {
try {
const { serverName } = req.params;
const user = createSafeUser(req.user);
@ -421,6 +482,11 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
const { success, message, oauthRequired, oauthUrl } = result;
if (oauthRequired) {
const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH);
}
res.json({
success,
message,

View file

@ -29,7 +29,7 @@ const oauthHandler = createOAuthHandler();
router.get('/error', (req, res) => {
/** A single error message is pushed by passport when authentication fails. */
const errorMessage = req.session?.messages?.pop() || 'Unknown error';
const errorMessage = req.session?.messages?.pop() || 'Unknown OAuth error';
logger.error('Error in OAuth authentication:', {
message: errorMessage,
});

View file

@ -8,6 +8,7 @@ const {
logAxiosError,
refreshAccessToken,
GenerationJobManager,
createSSRFSafeAgents,
} = require('@librechat/api');
const {
Time,
@ -133,6 +134,7 @@ async function loadActionSets(searchParams) {
* @param {import('zod').ZodTypeAny | undefined} [params.zodSchema] - The Zod schema for tool input validation/definition
* @param {{ oauth_client_id?: string; oauth_client_secret?: string; }} params.encrypted - The encrypted values for the action.
* @param {string | null} [params.streamId] - The stream ID for resumable streams.
* @param {boolean} [params.useSSRFProtection] - When true, uses SSRF-safe HTTP agents that validate resolved IPs at connect time.
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
*/
async function createActionTool({
@ -145,7 +147,9 @@ async function createActionTool({
description,
encrypted,
streamId = null,
useSSRFProtection = false,
}) {
const ssrfAgents = useSSRFProtection ? createSSRFSafeAgents() : undefined;
/** @type {(toolInput: Object | string, config: GraphRunnableConfig) => Promise<unknown>} */
const _call = async (toolInput, config) => {
try {
@ -324,7 +328,7 @@ async function createActionTool({
}
}
const response = await preparedExecutor.execute();
const response = await preparedExecutor.execute(ssrfAgents);
if (typeof response.data === 'object') {
return JSON.stringify(response.data);

View file

@ -7,7 +7,13 @@ const {
DEFAULT_REFRESH_TOKEN_EXPIRY,
} = require('@librechat/data-schemas');
const { ErrorTypes, SystemRoles, errorsToString } = require('librechat-data-provider');
const { isEnabled, checkEmailConfig, isEmailDomainAllowed, math } = require('@librechat/api');
const {
math,
isEnabled,
checkEmailConfig,
isEmailDomainAllowed,
shouldUseSecureCookie,
} = require('@librechat/api');
const {
findUser,
findToken,
@ -33,7 +39,6 @@ const domains = {
server: process.env.DOMAIN_SERVER,
};
const isProduction = process.env.NODE_ENV === 'production';
const genericVerificationMessage = 'Please check your email to verify your email address.';
/**
@ -392,13 +397,13 @@ const setAuthTokens = async (userId, res, _session = null) => {
res.cookie('refreshToken', refreshToken, {
expires: new Date(refreshTokenExpires),
httpOnly: true,
secure: isProduction,
secure: shouldUseSecureCookie(),
sameSite: 'strict',
});
res.cookie('token_provider', 'librechat', {
expires: new Date(refreshTokenExpires),
httpOnly: true,
secure: isProduction,
secure: shouldUseSecureCookie(),
sameSite: 'strict',
});
return token;
@ -419,7 +424,7 @@ const setAuthTokens = async (userId, res, _session = null) => {
* @param {Object} req - request object (for session access)
* @param {Object} res - response object
* @param {string} [userId] - Optional MongoDB user ID for image path validation
* @returns {String} - access token
* @returns {String} - id_token (preferred) or access_token as the app auth token
*/
const setOpenIDAuthTokens = (tokenset, req, res, userId, existingRefreshToken) => {
try {
@ -448,34 +453,62 @@ const setOpenIDAuthTokens = (tokenset, req, res, userId, existingRefreshToken) =
return;
}
/**
* Use id_token as the app authentication token (Bearer token for JWKS validation).
* The id_token is always a standard JWT signed by the IdP's JWKS keys with the app's
* client_id as audience. The access_token may be opaque or intended for a different
* audience (e.g., Microsoft Graph API), which fails JWKS validation.
* Falls back to access_token for providers where id_token is not available.
*/
const appAuthToken = tokenset.id_token || tokenset.access_token;
/**
* Always set refresh token cookie so it survives express session expiry.
* The session cookie maxAge (SESSION_EXPIRY, default 15 min) is typically shorter
* than the OIDC token lifetime (~1 hour). Without this cookie fallback, the refresh
* token stored only in the session is lost when the session expires, causing the user
* to be signed out on the next token refresh attempt.
* The refresh token is small (opaque string) so it doesn't hit the HTTP/2 header
* size limits that motivated session storage for the larger access_token/id_token.
*/
res.cookie('refreshToken', refreshToken, {
expires: expirationDate,
httpOnly: true,
secure: shouldUseSecureCookie(),
sameSite: 'strict',
});
/** Store tokens server-side in session to avoid large cookies */
if (req.session) {
req.session.openidTokens = {
accessToken: tokenset.access_token,
idToken: tokenset.id_token,
refreshToken: refreshToken,
expiresAt: expirationDate.getTime(),
};
} else {
logger.warn('[setOpenIDAuthTokens] No session available, falling back to cookies');
res.cookie('refreshToken', refreshToken, {
expires: expirationDate,
httpOnly: true,
secure: isProduction,
sameSite: 'strict',
});
res.cookie('openid_access_token', tokenset.access_token, {
expires: expirationDate,
httpOnly: true,
secure: isProduction,
secure: shouldUseSecureCookie(),
sameSite: 'strict',
});
if (tokenset.id_token) {
res.cookie('openid_id_token', tokenset.id_token, {
expires: expirationDate,
httpOnly: true,
secure: shouldUseSecureCookie(),
sameSite: 'strict',
});
}
}
/** Small cookie to indicate token provider (required for auth middleware) */
res.cookie('token_provider', 'openid', {
expires: expirationDate,
httpOnly: true,
secure: isProduction,
secure: shouldUseSecureCookie(),
sameSite: 'strict',
});
if (userId && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
@ -486,11 +519,11 @@ const setOpenIDAuthTokens = (tokenset, req, res, userId, existingRefreshToken) =
res.cookie('openid_user_id', signedUserId, {
expires: expirationDate,
httpOnly: true,
secure: isProduction,
secure: shouldUseSecureCookie(),
sameSite: 'strict',
});
}
return tokenset.access_token;
return appAuthToken;
} catch (error) {
logger.error('[setOpenIDAuthTokens] Error in setting authentication tokens:', error);
throw error;

View file

@ -0,0 +1,269 @@
jest.mock('@librechat/data-schemas', () => ({
logger: { info: jest.fn(), warn: jest.fn(), debug: jest.fn(), error: jest.fn() },
DEFAULT_SESSION_EXPIRY: 900000,
DEFAULT_REFRESH_TOKEN_EXPIRY: 604800000,
}));
jest.mock('librechat-data-provider', () => ({
ErrorTypes: {},
SystemRoles: { USER: 'USER', ADMIN: 'ADMIN' },
errorsToString: jest.fn(),
}));
jest.mock('@librechat/api', () => ({
isEnabled: jest.fn((val) => val === 'true' || val === true),
checkEmailConfig: jest.fn(),
isEmailDomainAllowed: jest.fn(),
math: jest.fn((val, fallback) => (val ? Number(val) : fallback)),
shouldUseSecureCookie: jest.fn(() => false),
}));
jest.mock('~/models', () => ({
findUser: jest.fn(),
findToken: jest.fn(),
createUser: jest.fn(),
updateUser: jest.fn(),
countUsers: jest.fn(),
getUserById: jest.fn(),
findSession: jest.fn(),
createToken: jest.fn(),
deleteTokens: jest.fn(),
deleteSession: jest.fn(),
createSession: jest.fn(),
generateToken: jest.fn(),
deleteUserById: jest.fn(),
generateRefreshToken: jest.fn(),
}));
jest.mock('~/strategies/validators', () => ({ registerSchema: { parse: jest.fn() } }));
jest.mock('~/server/services/Config', () => ({ getAppConfig: jest.fn() }));
jest.mock('~/server/utils', () => ({ sendEmail: jest.fn() }));
const { shouldUseSecureCookie } = require('@librechat/api');
const { setOpenIDAuthTokens } = require('./AuthService');
/** Helper to build a mock Express response */
function mockResponse() {
const cookies = {};
const res = {
cookie: jest.fn((name, value, options) => {
cookies[name] = { value, options };
}),
_cookies: cookies,
};
return res;
}
/** Helper to build a mock Express request with session */
function mockRequest(sessionData = {}) {
return {
session: { openidTokens: null, ...sessionData },
};
}
describe('setOpenIDAuthTokens', () => {
const env = process.env;
beforeEach(() => {
jest.clearAllMocks();
process.env = {
...env,
JWT_REFRESH_SECRET: 'test-refresh-secret',
OPENID_REUSE_TOKENS: 'true',
};
});
afterAll(() => {
process.env = env;
});
describe('token selection (id_token vs access_token)', () => {
it('should return id_token when both id_token and access_token are present', () => {
const tokenset = {
id_token: 'the-id-token',
access_token: 'the-access-token',
refresh_token: 'the-refresh-token',
};
const req = mockRequest();
const res = mockResponse();
const result = setOpenIDAuthTokens(tokenset, req, res, 'user-123');
expect(result).toBe('the-id-token');
});
it('should return access_token when id_token is not available', () => {
const tokenset = {
access_token: 'the-access-token',
refresh_token: 'the-refresh-token',
};
const req = mockRequest();
const res = mockResponse();
const result = setOpenIDAuthTokens(tokenset, req, res, 'user-123');
expect(result).toBe('the-access-token');
});
it('should return access_token when id_token is undefined', () => {
const tokenset = {
id_token: undefined,
access_token: 'the-access-token',
refresh_token: 'the-refresh-token',
};
const req = mockRequest();
const res = mockResponse();
const result = setOpenIDAuthTokens(tokenset, req, res, 'user-123');
expect(result).toBe('the-access-token');
});
it('should return access_token when id_token is null', () => {
const tokenset = {
id_token: null,
access_token: 'the-access-token',
refresh_token: 'the-refresh-token',
};
const req = mockRequest();
const res = mockResponse();
const result = setOpenIDAuthTokens(tokenset, req, res, 'user-123');
expect(result).toBe('the-access-token');
});
it('should return id_token even when id_token and access_token differ', () => {
const tokenset = {
id_token: 'id-token-jwt-signed-by-idp',
access_token: 'opaque-graph-api-token',
refresh_token: 'refresh-token',
};
const req = mockRequest();
const res = mockResponse();
const result = setOpenIDAuthTokens(tokenset, req, res, 'user-123');
expect(result).toBe('id-token-jwt-signed-by-idp');
expect(result).not.toBe('opaque-graph-api-token');
});
});
describe('session token storage', () => {
it('should store the original access_token in session (not id_token)', () => {
const tokenset = {
id_token: 'the-id-token',
access_token: 'the-access-token',
refresh_token: 'the-refresh-token',
};
const req = mockRequest();
const res = mockResponse();
setOpenIDAuthTokens(tokenset, req, res, 'user-123');
expect(req.session.openidTokens.accessToken).toBe('the-access-token');
expect(req.session.openidTokens.refreshToken).toBe('the-refresh-token');
});
});
describe('cookie secure flag', () => {
it('should call shouldUseSecureCookie for every cookie set', () => {
const tokenset = {
id_token: 'the-id-token',
access_token: 'the-access-token',
refresh_token: 'the-refresh-token',
};
const req = mockRequest();
const res = mockResponse();
setOpenIDAuthTokens(tokenset, req, res, 'user-123');
// token_provider + openid_user_id (session path, so no refreshToken/openid_access_token cookies)
const secureCalls = shouldUseSecureCookie.mock.calls.length;
expect(secureCalls).toBeGreaterThanOrEqual(2);
// Verify all cookies use the result of shouldUseSecureCookie
for (const [, cookie] of Object.entries(res._cookies)) {
expect(cookie.options.secure).toBe(false);
}
});
it('should set secure: true when shouldUseSecureCookie returns true', () => {
shouldUseSecureCookie.mockReturnValue(true);
const tokenset = {
id_token: 'the-id-token',
access_token: 'the-access-token',
refresh_token: 'the-refresh-token',
};
const req = mockRequest();
const res = mockResponse();
setOpenIDAuthTokens(tokenset, req, res, 'user-123');
for (const [, cookie] of Object.entries(res._cookies)) {
expect(cookie.options.secure).toBe(true);
}
});
it('should use shouldUseSecureCookie for cookie fallback path (no session)', () => {
shouldUseSecureCookie.mockReturnValue(false);
const tokenset = {
id_token: 'the-id-token',
access_token: 'the-access-token',
refresh_token: 'the-refresh-token',
};
const req = { session: null };
const res = mockResponse();
setOpenIDAuthTokens(tokenset, req, res, 'user-123');
// In the cookie fallback path, we get: refreshToken, openid_access_token, token_provider, openid_user_id
expect(res.cookie).toHaveBeenCalledWith(
'refreshToken',
expect.any(String),
expect.objectContaining({ secure: false }),
);
expect(res.cookie).toHaveBeenCalledWith(
'openid_access_token',
expect.any(String),
expect.objectContaining({ secure: false }),
);
expect(res.cookie).toHaveBeenCalledWith(
'token_provider',
'openid',
expect.objectContaining({ secure: false }),
);
});
});
describe('edge cases', () => {
it('should return undefined when tokenset is null', () => {
const req = mockRequest();
const res = mockResponse();
const result = setOpenIDAuthTokens(null, req, res, 'user-123');
expect(result).toBeUndefined();
});
it('should return undefined when access_token is missing', () => {
const tokenset = { refresh_token: 'refresh' };
const req = mockRequest();
const res = mockResponse();
const result = setOpenIDAuthTokens(tokenset, req, res, 'user-123');
expect(result).toBeUndefined();
});
it('should return undefined when no refresh token is available', () => {
const tokenset = { access_token: 'access', id_token: 'id' };
const req = mockRequest();
const res = mockResponse();
const result = setOpenIDAuthTokens(tokenset, req, res, 'user-123');
expect(result).toBeUndefined();
});
it('should use existingRefreshToken when tokenset has no refresh_token', () => {
const tokenset = {
id_token: 'the-id-token',
access_token: 'the-access-token',
};
const req = mockRequest();
const res = mockResponse();
const result = setOpenIDAuthTokens(tokenset, req, res, 'user-123', 'existing-refresh');
expect(result).toBe('the-id-token');
expect(req.session.openidTokens.refreshToken).toBe('existing-refresh');
});
});
});

View file

@ -1,10 +1,92 @@
const { ToolCacheKeys } = require('../getCachedTools');
const { CacheKeys } = require('librechat-data-provider');
jest.mock('~/cache/getLogStores');
const getLogStores = require('~/cache/getLogStores');
const mockCache = { get: jest.fn(), set: jest.fn(), delete: jest.fn() };
getLogStores.mockReturnValue(mockCache);
const {
ToolCacheKeys,
getCachedTools,
setCachedTools,
getMCPServerTools,
invalidateCachedTools,
} = require('../getCachedTools');
describe('getCachedTools', () => {
beforeEach(() => {
jest.clearAllMocks();
getLogStores.mockReturnValue(mockCache);
});
describe('getCachedTools - Cache Isolation Security', () => {
describe('ToolCacheKeys.MCP_SERVER', () => {
it('should generate cache keys that include userId', () => {
const key = ToolCacheKeys.MCP_SERVER('user123', 'github');
expect(key).toBe('tools:mcp:user123:github');
});
});
describe('TOOL_CACHE namespace usage', () => {
it('getCachedTools should use TOOL_CACHE namespace', async () => {
mockCache.get.mockResolvedValue(null);
await getCachedTools();
expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE);
});
it('getCachedTools with MCP server options should use TOOL_CACHE namespace', async () => {
mockCache.get.mockResolvedValue({ tool1: {} });
await getCachedTools({ userId: 'user1', serverName: 'github' });
expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE);
expect(mockCache.get).toHaveBeenCalledWith(ToolCacheKeys.MCP_SERVER('user1', 'github'));
});
it('setCachedTools should use TOOL_CACHE namespace', async () => {
mockCache.set.mockResolvedValue(true);
const tools = { tool1: { type: 'function' } };
await setCachedTools(tools);
expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE);
expect(mockCache.set).toHaveBeenCalledWith(ToolCacheKeys.GLOBAL, tools, expect.any(Number));
});
it('setCachedTools with MCP server options should use TOOL_CACHE namespace', async () => {
mockCache.set.mockResolvedValue(true);
const tools = { tool1: { type: 'function' } };
await setCachedTools(tools, { userId: 'user1', serverName: 'github' });
expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE);
expect(mockCache.set).toHaveBeenCalledWith(
ToolCacheKeys.MCP_SERVER('user1', 'github'),
tools,
expect.any(Number),
);
});
it('invalidateCachedTools should use TOOL_CACHE namespace', async () => {
mockCache.delete.mockResolvedValue(true);
await invalidateCachedTools({ invalidateGlobal: true });
expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE);
expect(mockCache.delete).toHaveBeenCalledWith(ToolCacheKeys.GLOBAL);
});
it('getMCPServerTools should use TOOL_CACHE namespace', async () => {
mockCache.get.mockResolvedValue(null);
await getMCPServerTools('user1', 'github');
expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE);
expect(mockCache.get).toHaveBeenCalledWith(ToolCacheKeys.MCP_SERVER('user1', 'github'));
});
it('should NOT use CONFIG_STORE namespace', async () => {
mockCache.get.mockResolvedValue(null);
await getCachedTools();
await getMCPServerTools('user1', 'github');
mockCache.set.mockResolvedValue(true);
await setCachedTools({ tool1: {} });
mockCache.delete.mockResolvedValue(true);
await invalidateCachedTools({ invalidateGlobal: true });
const allCalls = getLogStores.mock.calls.flat();
expect(allCalls).not.toContain(CacheKeys.CONFIG_STORE);
expect(allCalls.every((key) => key === CacheKeys.TOOL_CACHE)).toBe(true);
});
});
});

View file

@ -20,7 +20,7 @@ const ToolCacheKeys = {
* @returns {Promise<LCAvailableTools|null>} The available tools object or null if not cached
*/
async function getCachedTools(options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cache = getLogStores(CacheKeys.TOOL_CACHE);
const { userId, serverName } = options;
// Return MCP server-specific tools if requested
@ -43,7 +43,7 @@ async function getCachedTools(options = {}) {
* @returns {Promise<boolean>} Whether the operation was successful
*/
async function setCachedTools(tools, options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cache = getLogStores(CacheKeys.TOOL_CACHE);
const { userId, serverName, ttl = Time.TWELVE_HOURS } = options;
// Cache by MCP server if specified (requires userId)
@ -65,7 +65,7 @@ async function setCachedTools(tools, options = {}) {
* @returns {Promise<void>}
*/
async function invalidateCachedTools(options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cache = getLogStores(CacheKeys.TOOL_CACHE);
const { userId, serverName, invalidateGlobal = false } = options;
const keysToDelete = [];
@ -89,7 +89,7 @@ async function invalidateCachedTools(options = {}) {
* @returns {Promise<LCAvailableTools|null>} The available tools for the server
*/
async function getMCPServerTools(userId, serverName) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cache = getLogStores(CacheKeys.TOOL_CACHE);
const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(userId, serverName));
if (serverTools) {

View file

@ -35,7 +35,7 @@ async function updateMCPServerTools({ userId, serverName, tools }) {
await setCachedTools(serverTools, { userId, serverName });
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cache = getLogStores(CacheKeys.TOOL_CACHE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(
`[MCP Cache] Updated ${tools.length} tools for server ${serverName} (user: ${userId})`,
@ -61,7 +61,7 @@ async function mergeAppTools(appTools) {
const cachedTools = await getCachedTools();
const mergedTools = { ...cachedTools, ...appTools };
await setCachedTools(mergedTools);
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cache = getLogStores(CacheKeys.TOOL_CACHE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(`Merged ${count} app-level tools`);
} catch (error) {

View file

@ -71,7 +71,7 @@ const addTitle = async (req, { text, response, client }) => {
conversationId: response.conversationId,
title,
},
{ context: 'api/server/services/Endpoints/agents/title.js' },
{ context: 'api/server/services/Endpoints/agents/title.js', noUpsert: true },
);
} catch (error) {
logger.error('Error generating title:', error);

View file

@ -69,7 +69,7 @@ const addTitle = async (req, { text, responseText, conversationId }) => {
conversationId,
title,
},
{ context: 'api/server/services/Endpoints/assistants/addTitle.js' },
{ context: 'api/server/services/Endpoints/assistants/addTitle.js', noUpsert: true },
);
} catch (error) {
logger.error('[addTitle] Error generating title:', error);
@ -81,7 +81,7 @@ const addTitle = async (req, { text, responseText, conversationId }) => {
conversationId,
title: fallbackTitle,
},
{ context: 'api/server/services/Endpoints/assistants/addTitle.js' },
{ context: 'api/server/services/Endpoints/assistants/addTitle.js', noUpsert: true },
);
}
};

View file

@ -4,7 +4,7 @@ const mime = require('mime');
const axios = require('axios');
const fetch = require('node-fetch');
const { logger } = require('@librechat/data-schemas');
const { getAzureContainerClient } = require('@librechat/api');
const { getAzureContainerClient, deleteRagFile } = require('@librechat/api');
const defaultBasePath = 'images';
const { AZURE_STORAGE_PUBLIC_ACCESS = 'true', AZURE_CONTAINER_NAME = 'files' } = process.env;
@ -102,6 +102,8 @@ async function getAzureURL({ fileName, basePath = defaultBasePath, userId, conta
* @param {MongoFile} params.file - The file object.
*/
async function deleteFileFromAzure(req, file) {
await deleteRagFile({ userId: req.user.id, file });
try {
const containerClient = await getAzureContainerClient(AZURE_CONTAINER_NAME);
const blobPath = file.filepath.split(`${AZURE_CONTAINER_NAME}/`)[1];

View file

@ -3,7 +3,7 @@ const path = require('path');
const axios = require('axios');
const fetch = require('node-fetch');
const { logger } = require('@librechat/data-schemas');
const { getFirebaseStorage } = require('@librechat/api');
const { getFirebaseStorage, deleteRagFile } = require('@librechat/api');
const { ref, uploadBytes, getDownloadURL, deleteObject } = require('firebase/storage');
const { getBufferMetadata } = require('~/server/utils');
@ -167,27 +167,7 @@ function extractFirebaseFilePath(urlString) {
* Throws an error if there is an issue with deletion.
*/
const deleteFirebaseFile = async (req, file) => {
if (file.embedded && process.env.RAG_API_URL) {
const jwtToken = req.headers.authorization.split(' ')[1];
try {
await axios.delete(`${process.env.RAG_API_URL}/documents`, {
headers: {
Authorization: `Bearer ${jwtToken}`,
'Content-Type': 'application/json',
accept: 'application/json',
},
data: [file.file_id],
});
} catch (error) {
if (error.response?.status === 404) {
logger.warn(
`[deleteFirebaseFile] Document ${file.file_id} not found in RAG API, may have been deleted already`,
);
} else {
logger.error('[deleteFirebaseFile] Error deleting document from RAG API:', error);
}
}
}
await deleteRagFile({ userId: req.user.id, file });
const fileName = extractFirebaseFilePath(file.filepath);
if (!fileName.includes(req.user.id)) {

View file

@ -1,9 +1,9 @@
const fs = require('fs');
const path = require('path');
const axios = require('axios');
const { deleteRagFile } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { EModelEndpoint } = require('librechat-data-provider');
const { generateShortLivedToken } = require('@librechat/api');
const { resizeImageBuffer } = require('~/server/services/Files/images/resize');
const { getBufferMetadata } = require('~/server/utils');
const paths = require('~/config/paths');
@ -213,27 +213,7 @@ const deleteLocalFile = async (req, file) => {
/** Filepath stripped of query parameters (e.g., ?manual=true) */
const cleanFilepath = file.filepath.split('?')[0];
if (file.embedded && process.env.RAG_API_URL) {
const jwtToken = generateShortLivedToken(req.user.id);
try {
await axios.delete(`${process.env.RAG_API_URL}/documents`, {
headers: {
Authorization: `Bearer ${jwtToken}`,
'Content-Type': 'application/json',
accept: 'application/json',
},
data: [file.file_id],
});
} catch (error) {
if (error.response?.status === 404) {
logger.warn(
`[deleteLocalFile] Document ${file.file_id} not found in RAG API, may have been deleted already`,
);
} else {
logger.error('[deleteLocalFile] Error deleting document from RAG API:', error);
}
}
}
await deleteRagFile({ userId: req.user.id, file });
if (cleanFilepath.startsWith(`/uploads/${req.user.id}`)) {
const userUploadDir = path.join(uploads, req.user.id);

View file

@ -1,9 +1,9 @@
const fs = require('fs');
const fetch = require('node-fetch');
const { initializeS3 } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { FileSources } = require('librechat-data-provider');
const { getSignedUrl } = require('@aws-sdk/s3-request-presigner');
const { initializeS3, deleteRagFile } = require('@librechat/api');
const {
PutObjectCommand,
GetObjectCommand,
@ -142,6 +142,8 @@ async function saveURLToS3({ userId, URL, fileName, basePath = defaultBasePath }
* @returns {Promise<void>}
*/
async function deleteFileFromS3(req, file) {
await deleteRagFile({ userId: req.user.id, file });
const key = extractKeyFromS3Url(file.filepath);
const params = { Bucket: bucketName, Key: key };
if (!key.includes(req.user.id)) {

View file

@ -11,8 +11,9 @@ const {
MCPOAuthHandler,
isMCPDomainAllowed,
normalizeServerName,
resolveJsonSchemaRefs,
normalizeJsonSchema,
GenerationJobManager,
resolveJsonSchemaRefs,
} = require('@librechat/api');
const {
Time,
@ -443,7 +444,7 @@ function createToolInstance({
const { description, parameters } = toolDefinition;
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
let schema = parameters ? resolveJsonSchemaRefs(parameters) : null;
let schema = parameters ? normalizeJsonSchema(resolveJsonSchemaRefs(parameters)) : null;
if (!schema || (isGoogle && isEmptyObjectSchema(schema))) {
schema = {

View file

@ -9,30 +9,6 @@ jest.mock('@librechat/data-schemas', () => ({
},
}));
jest.mock('@langchain/core/tools', () => ({
tool: jest.fn((fn, config) => {
const toolInstance = { _call: fn, ...config };
return toolInstance;
}),
}));
jest.mock('@librechat/agents', () => ({
Providers: {
VERTEXAI: 'vertexai',
GOOGLE: 'google',
},
StepTypes: {
TOOL_CALLS: 'tool_calls',
},
GraphEvents: {
ON_RUN_STEP_DELTA: 'on_run_step_delta',
ON_RUN_STEP: 'on_run_step',
},
Constants: {
CONTENT_AND_ARTIFACT: 'content_and_artifact',
},
}));
// Create mock registry instance
const mockRegistryInstance = {
getOAuthServers: jest.fn(() => Promise.resolve(new Set())),
@ -46,26 +22,23 @@ const mockIsMCPDomainAllowed = jest.fn(() => Promise.resolve(true));
const mockGetAppConfig = jest.fn(() => Promise.resolve({}));
jest.mock('@librechat/api', () => {
// Access mock via getter to avoid hoisting issues
const actual = jest.requireActual('@librechat/api');
return {
MCPOAuthHandler: {
generateFlowId: jest.fn(),
},
...actual,
sendEvent: jest.fn(),
normalizeServerName: jest.fn((name) => name),
resolveJsonSchemaRefs: jest.fn((params) => params),
get isMCPDomainAllowed() {
return mockIsMCPDomainAllowed;
},
MCPServersRegistry: {
getInstance: () => mockRegistryInstance,
GenerationJobManager: {
emitChunk: jest.fn(),
},
};
});
const { logger } = require('@librechat/data-schemas');
const { MCPOAuthHandler } = require('@librechat/api');
const { CacheKeys } = require('librechat-data-provider');
const { CacheKeys, Constants } = require('librechat-data-provider');
const D = Constants.mcp_delimiter;
const {
createMCPTool,
createMCPTools,
@ -74,24 +47,6 @@ const {
getServerConnectionStatus,
} = require('./MCP');
jest.mock('librechat-data-provider', () => ({
CacheKeys: {
FLOWS: 'flows',
},
Constants: {
USE_PRELIM_RESPONSE_MESSAGE_ID: 'prelim_response_id',
mcp_delimiter: '::',
mcp_prefix: 'mcp_',
},
ContentTypes: {
TEXT: 'text',
},
isAssistantsEndpoint: jest.fn(() => false),
Time: {
TWO_MINUTES: 120000,
},
}));
jest.mock('./Config', () => ({
loadCustomConfig: jest.fn(),
get getAppConfig() {
@ -132,6 +87,7 @@ describe('tests for the new helper functions used by the MCP connection status e
beforeEach(() => {
jest.clearAllMocks();
jest.spyOn(MCPOAuthHandler, 'generateFlowId');
mockGetMCPManager = require('~/config').getMCPManager;
mockGetFlowStateManager = require('~/config').getFlowStateManager;
@ -735,7 +691,7 @@ describe('User parameter passing tests', () => {
mockReinitMCPServer.mockResolvedValue({
tools: [{ name: 'test-tool' }],
availableTools: {
'test-tool::test-server': {
[`test-tool${D}test-server`]: {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
@ -795,7 +751,7 @@ describe('User parameter passing tests', () => {
mockReinitMCPServer.mockResolvedValue({
availableTools: {
'test-tool::test-server': {
[`test-tool${D}test-server`]: {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
@ -808,7 +764,7 @@ describe('User parameter passing tests', () => {
await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'test-tool::test-server',
toolKey: `test-tool${D}test-server`,
provider: 'openai',
signal: mockSignal,
userMCPAuthMap: {},
@ -830,7 +786,7 @@ describe('User parameter passing tests', () => {
const mockRes = { write: jest.fn(), flush: jest.fn() };
const availableTools = {
'test-tool::test-server': {
[`test-tool${D}test-server`]: {
function: {
description: 'Cached tool',
parameters: { type: 'object', properties: {} },
@ -841,7 +797,7 @@ describe('User parameter passing tests', () => {
await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'test-tool::test-server',
toolKey: `test-tool${D}test-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: availableTools,
@ -864,8 +820,8 @@ describe('User parameter passing tests', () => {
return Promise.resolve({
tools: [{ name: 'tool1' }, { name: 'tool2' }],
availableTools: {
'tool1::server1': { function: { description: 'Tool 1', parameters: {} } },
'tool2::server1': { function: { description: 'Tool 2', parameters: {} } },
[`tool1${D}server1`]: { function: { description: 'Tool 1', parameters: {} } },
[`tool2${D}server1`]: { function: { description: 'Tool 2', parameters: {} } },
},
});
});
@ -896,7 +852,7 @@ describe('User parameter passing tests', () => {
reinitCalls.push(params);
return Promise.resolve({
availableTools: {
'my-tool::my-server': {
[`my-tool${D}my-server`]: {
function: { description: 'My Tool', parameters: {} },
},
},
@ -906,7 +862,7 @@ describe('User parameter passing tests', () => {
await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'my-tool::my-server',
toolKey: `my-tool${D}my-server`,
provider: 'google',
userMCPAuthMap: {},
availableTools: undefined, // Force reinit
@ -940,11 +896,11 @@ describe('User parameter passing tests', () => {
const result = await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'test-tool::test-server',
toolKey: `test-tool${D}test-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: {
'test-tool::test-server': {
[`test-tool${D}test-server`]: {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
@ -987,7 +943,7 @@ describe('User parameter passing tests', () => {
mockIsMCPDomainAllowed.mockResolvedValueOnce(true);
const availableTools = {
'test-tool::test-server': {
[`test-tool${D}test-server`]: {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
@ -998,7 +954,7 @@ describe('User parameter passing tests', () => {
const result = await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'test-tool::test-server',
toolKey: `test-tool${D}test-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools,
@ -1027,7 +983,7 @@ describe('User parameter passing tests', () => {
});
const availableTools = {
'test-tool::test-server': {
[`test-tool${D}test-server`]: {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
@ -1038,7 +994,7 @@ describe('User parameter passing tests', () => {
const result = await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'test-tool::test-server',
toolKey: `test-tool${D}test-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools,
@ -1104,7 +1060,7 @@ describe('User parameter passing tests', () => {
mockIsMCPDomainAllowed.mockResolvedValue(true);
const availableTools = {
'test-tool::test-server': {
[`test-tool${D}test-server`]: {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
@ -1116,7 +1072,7 @@ describe('User parameter passing tests', () => {
await createMCPTool({
res: mockRes,
user: adminUser,
toolKey: 'test-tool::test-server',
toolKey: `test-tool${D}test-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools,
@ -1130,7 +1086,7 @@ describe('User parameter passing tests', () => {
await createMCPTool({
res: mockRes,
user: regularUser,
toolKey: 'test-tool::test-server',
toolKey: `test-tool${D}test-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools,
@ -1158,7 +1114,7 @@ describe('User parameter passing tests', () => {
return Promise.resolve({
tools: [{ name: 'test' }],
availableTools: {
'test::server': { function: { description: 'Test', parameters: {} } },
[`test${D}server`]: { function: { description: 'Test', parameters: {} } },
},
});
});

View file

@ -338,6 +338,7 @@ async function processRequiredActions(client, requiredActions) {
}
// We've already decrypted the metadata, so we can pass it directly
const _allowedDomains = appConfig?.actions?.allowedDomains;
tool = await createActionTool({
userId: client.req.user.id,
res: client.res,
@ -345,6 +346,7 @@ async function processRequiredActions(client, requiredActions) {
requestBuilder,
// Note: intentionally not passing zodSchema, name, and description for assistants API
encrypted, // Pass the encrypted values for OAuth flow
useSSRFProtection: !Array.isArray(_allowedDomains) || _allowedDomains.length === 0,
});
if (!tool) {
logger.warn(
@ -1064,6 +1066,7 @@ async function loadAgentTools({
const zodSchema = zodSchemas[functionName];
if (requestBuilder) {
const _allowedDomains = appConfig?.actions?.allowedDomains;
const tool = await createActionTool({
userId: req.user.id,
res,
@ -1074,6 +1077,7 @@ async function loadAgentTools({
name: toolName,
description: functionSig.description,
streamId,
useSSRFProtection: !Array.isArray(_allowedDomains) || _allowedDomains.length === 0,
});
if (!tool) {
@ -1335,6 +1339,7 @@ async function loadActionToolsForExecution({
});
}
const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g');
for (const toolName of actionToolNames) {
let currentDomain = '';
for (const domain of domainMap.keys()) {
@ -1351,7 +1356,6 @@ async function loadActionToolsForExecution({
const { action, encrypted, zodSchemas, requestBuilders, functionSignatures } =
processedActionSets.get(currentDomain);
const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g');
const normalizedDomain = currentDomain.replace(domainSeparatorRegex, '_');
const functionName = toolName.replace(`${actionDelimiter}${normalizedDomain}`, '');
const functionSig = functionSignatures.find((sig) => sig.name === functionName);
@ -1372,6 +1376,7 @@ async function loadActionToolsForExecution({
requestBuilder,
name: toolName,
description: functionSig?.description ?? '',
useSSRFProtection: !Array.isArray(allowedDomains) || allowedDomains.length === 0,
});
if (!tool) {

View file

@ -1,7 +1,7 @@
const passport = require('passport');
const session = require('express-session');
const { isEnabled } = require('@librechat/api');
const { CacheKeys } = require('librechat-data-provider');
const { isEnabled, shouldUseSecureCookie } = require('@librechat/api');
const { logger, DEFAULT_SESSION_EXPIRY } = require('@librechat/data-schemas');
const {
openIdJwtLogin,
@ -15,38 +15,6 @@ const {
} = require('~/strategies');
const { getLogStores } = require('~/cache');
/**
* Determines if secure cookies should be used.
* Only use secure cookies in production when not on localhost.
* @returns {boolean}
*/
function shouldUseSecureCookie() {
const isProduction = process.env.NODE_ENV === 'production';
const domainServer = process.env.DOMAIN_SERVER || '';
let hostname = '';
if (domainServer) {
try {
const normalized = /^https?:\/\//i.test(domainServer)
? domainServer
: `http://${domainServer}`;
const url = new URL(normalized);
hostname = (url.hostname || '').toLowerCase();
} catch {
// Fallback: treat DOMAIN_SERVER directly as a hostname-like string
hostname = domainServer.toLowerCase();
}
}
const isLocalhost =
hostname === 'localhost' ||
hostname === '127.0.0.1' ||
hostname === '::1' ||
hostname.endsWith('.localhost');
return isProduction && !isLocalhost;
}
/**
* Configures OpenID Connect for the application.
* @param {Express.Application} app - The Express application instance.

View file

@ -84,19 +84,21 @@ const openIdJwtLogin = (openIdConfig) => {
/** Read tokens from session (server-side) to avoid large cookie issues */
const sessionTokens = req.session?.openidTokens;
let accessToken = sessionTokens?.accessToken;
let idToken = sessionTokens?.idToken;
let refreshToken = sessionTokens?.refreshToken;
/** Fallback to cookies for backward compatibility */
if (!accessToken || !refreshToken) {
if (!accessToken || !refreshToken || !idToken) {
const cookieHeader = req.headers.cookie;
const parsedCookies = cookieHeader ? cookies.parse(cookieHeader) : {};
accessToken = accessToken || parsedCookies.openid_access_token;
idToken = idToken || parsedCookies.openid_id_token;
refreshToken = refreshToken || parsedCookies.refreshToken;
}
user.federatedTokens = {
access_token: accessToken || rawToken,
id_token: rawToken,
id_token: idToken,
refresh_token: refreshToken,
expires_at: payload.exp,
};

View file

@ -0,0 +1,183 @@
const { SystemRoles } = require('librechat-data-provider');
// --- Capture the verify callback from JwtStrategy ---
let capturedVerifyCallback;
jest.mock('passport-jwt', () => ({
Strategy: jest.fn((_opts, verifyCallback) => {
capturedVerifyCallback = verifyCallback;
return { name: 'jwt' };
}),
ExtractJwt: {
fromAuthHeaderAsBearerToken: jest.fn(() => 'mock-extractor'),
},
}));
jest.mock('jwks-rsa', () => ({
passportJwtSecret: jest.fn(() => 'mock-secret-provider'),
}));
jest.mock('https-proxy-agent', () => ({
HttpsProxyAgent: jest.fn(),
}));
jest.mock('@librechat/data-schemas', () => ({
logger: { info: jest.fn(), warn: jest.fn(), debug: jest.fn(), error: jest.fn() },
}));
jest.mock('@librechat/api', () => ({
isEnabled: jest.fn(() => false),
findOpenIDUser: jest.fn(),
math: jest.fn((val, fallback) => fallback),
}));
jest.mock('~/models', () => ({
findUser: jest.fn(),
updateUser: jest.fn(),
}));
const { findOpenIDUser } = require('@librechat/api');
const { updateUser } = require('~/models');
const openIdJwtLogin = require('./openIdJwtStrategy');
// Helper: build a mock openIdConfig
const mockOpenIdConfig = {
serverMetadata: () => ({ jwks_uri: 'https://example.com/.well-known/jwks.json' }),
};
// Helper: invoke the captured verify callback
async function invokeVerify(req, payload) {
return new Promise((resolve, reject) => {
capturedVerifyCallback(req, payload, (err, user, info) => {
if (err) {
return reject(err);
}
resolve({ user, info });
});
});
}
describe('openIdJwtStrategy token source handling', () => {
const baseUser = {
_id: { toString: () => 'user-abc' },
role: SystemRoles.USER,
provider: 'openid',
};
const payload = { sub: 'oidc-123', email: 'test@example.com', exp: 9999999999 };
beforeEach(() => {
jest.clearAllMocks();
findOpenIDUser.mockResolvedValue({ user: { ...baseUser }, error: null, migration: false });
updateUser.mockResolvedValue({});
// Initialize the strategy so capturedVerifyCallback is set
openIdJwtLogin(mockOpenIdConfig);
});
it('should read all tokens from session when available', async () => {
const req = {
headers: { authorization: 'Bearer raw-bearer-token' },
session: {
openidTokens: {
accessToken: 'session-access',
idToken: 'session-id',
refreshToken: 'session-refresh',
},
},
};
const { user } = await invokeVerify(req, payload);
expect(user.federatedTokens).toEqual({
access_token: 'session-access',
id_token: 'session-id',
refresh_token: 'session-refresh',
expires_at: payload.exp,
});
});
it('should fall back to cookies when session is absent', async () => {
const req = {
headers: {
authorization: 'Bearer raw-bearer-token',
cookie:
'openid_access_token=cookie-access; openid_id_token=cookie-id; refreshToken=cookie-refresh',
},
};
const { user } = await invokeVerify(req, payload);
expect(user.federatedTokens).toEqual({
access_token: 'cookie-access',
id_token: 'cookie-id',
refresh_token: 'cookie-refresh',
expires_at: payload.exp,
});
});
it('should fall back to cookie for idToken only when session lacks it', async () => {
const req = {
headers: {
authorization: 'Bearer raw-bearer-token',
cookie: 'openid_id_token=cookie-id',
},
session: {
openidTokens: {
accessToken: 'session-access',
// idToken intentionally missing
refreshToken: 'session-refresh',
},
},
};
const { user } = await invokeVerify(req, payload);
expect(user.federatedTokens).toEqual({
access_token: 'session-access',
id_token: 'cookie-id',
refresh_token: 'session-refresh',
expires_at: payload.exp,
});
});
it('should use raw Bearer token as access_token fallback when neither session nor cookie has one', async () => {
const req = {
headers: {
authorization: 'Bearer raw-bearer-token',
cookie: 'openid_id_token=cookie-id; refreshToken=cookie-refresh',
},
};
const { user } = await invokeVerify(req, payload);
expect(user.federatedTokens.access_token).toBe('raw-bearer-token');
expect(user.federatedTokens.id_token).toBe('cookie-id');
expect(user.federatedTokens.refresh_token).toBe('cookie-refresh');
});
it('should set id_token to undefined when not available in session or cookies', async () => {
const req = {
headers: {
authorization: 'Bearer raw-bearer-token',
cookie: 'openid_access_token=cookie-access; refreshToken=cookie-refresh',
},
};
const { user } = await invokeVerify(req, payload);
expect(user.federatedTokens.access_token).toBe('cookie-access');
expect(user.federatedTokens.id_token).toBeUndefined();
expect(user.federatedTokens.refresh_token).toBe('cookie-refresh');
});
it('should keep id_token and access_token as distinct values from cookies', async () => {
const req = {
headers: {
authorization: 'Bearer raw-bearer-token',
cookie:
'openid_access_token=the-access-token; openid_id_token=the-id-token; refreshToken=the-refresh',
},
};
const { user } = await invokeVerify(req, payload);
expect(user.federatedTokens.access_token).toBe('the-access-token');
expect(user.federatedTokens.id_token).toBe('the-id-token');
expect(user.federatedTokens.access_token).not.toBe(user.federatedTokens.id_token);
});
});

View file

@ -287,6 +287,77 @@ function convertToUsername(input, defaultValue = '') {
return defaultValue;
}
/**
* Resolve Azure AD groups when group overage is in effect (groups moved to _claim_names/_claim_sources).
*
* NOTE: Microsoft recommends treating _claim_names/_claim_sources as a signal only and using Microsoft Graph
* to resolve group membership instead of calling the endpoint in _claim_sources directly.
*
* @param {string} accessToken - Access token with Microsoft Graph permissions
* @returns {Promise<string[] | null>} Resolved group IDs or null on failure
* @see https://learn.microsoft.com/en-us/entra/identity-platform/access-token-claims-reference#groups-overage-claim
* @see https://learn.microsoft.com/en-us/graph/api/directoryobject-getmemberobjects
*/
async function resolveGroupsFromOverage(accessToken) {
try {
if (!accessToken) {
logger.error('[openidStrategy] Access token missing; cannot resolve group overage');
return null;
}
// Use /me/getMemberObjects so least-privileged delegated permission User.Read is sufficient
// when resolving the signed-in user's group membership.
const url = 'https://graph.microsoft.com/v1.0/me/getMemberObjects';
logger.debug(
`[openidStrategy] Detected group overage, resolving groups via Microsoft Graph getMemberObjects: ${url}`,
);
const fetchOptions = {
method: 'POST',
headers: {
Authorization: `Bearer ${accessToken}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({ securityEnabledOnly: false }),
};
if (process.env.PROXY) {
const { ProxyAgent } = undici;
fetchOptions.dispatcher = new ProxyAgent(process.env.PROXY);
}
const response = await undici.fetch(url, fetchOptions);
if (!response.ok) {
logger.error(
`[openidStrategy] Failed to resolve groups via Microsoft Graph getMemberObjects: HTTP ${response.status} ${response.statusText}`,
);
return null;
}
const data = await response.json();
const values = Array.isArray(data?.value) ? data.value : null;
if (!values) {
logger.error(
'[openidStrategy] Unexpected response format when resolving groups via Microsoft Graph getMemberObjects',
);
return null;
}
const groupIds = values.filter((id) => typeof id === 'string');
logger.debug(
`[openidStrategy] Successfully resolved ${groupIds.length} groups via Microsoft Graph getMemberObjects`,
);
return groupIds;
} catch (err) {
logger.error(
'[openidStrategy] Error resolving groups via Microsoft Graph getMemberObjects:',
err,
);
return null;
}
}
/**
* Process OpenID authentication tokenset and userinfo
* This is the core logic extracted from the passport strategy callback
@ -350,6 +421,25 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) {
}
let roles = get(decodedToken, requiredRoleParameterPath);
// Handle Azure AD group overage for ID token groups: when hasgroups or _claim_* indicate overage,
// resolve groups via Microsoft Graph instead of relying on token group values.
if (
!Array.isArray(roles) &&
typeof roles !== 'string' &&
requiredRoleTokenKind === 'id' &&
requiredRoleParameterPath === 'groups' &&
decodedToken &&
(decodedToken.hasgroups ||
(decodedToken._claim_names?.groups &&
decodedToken._claim_sources?.[decodedToken._claim_names.groups]))
) {
const overageGroups = await resolveGroupsFromOverage(tokenset.access_token);
if (overageGroups) {
roles = overageGroups;
}
}
if (!roles || (!Array.isArray(roles) && typeof roles !== 'string')) {
logger.error(
`[openidStrategy] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`,
@ -361,7 +451,9 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) {
throw new Error(`You must have ${rolesList} role to log in.`);
}
if (!requiredRoles.some((role) => roles.includes(role))) {
const roleValues = Array.isArray(roles) ? roles : [roles];
if (!requiredRoles.some((role) => roleValues.includes(role))) {
const rolesList =
requiredRoles.length === 1
? `"${requiredRoles[0]}"`
@ -498,6 +590,7 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) {
tokenset,
federatedTokens: {
access_token: tokenset.access_token,
id_token: tokenset.id_token,
refresh_token: tokenset.refresh_token,
expires_at: tokenset.expires_at,
},

View file

@ -1,5 +1,6 @@
const fetch = require('node-fetch');
const jwtDecode = require('jsonwebtoken/decode');
const undici = require('undici');
const { ErrorTypes } = require('librechat-data-provider');
const { findUser, createUser, updateUser } = require('~/models');
const { setupOpenId } = require('./openidStrategy');
@ -7,6 +8,10 @@ const { setupOpenId } = require('./openidStrategy');
// --- Mocks ---
jest.mock('node-fetch');
jest.mock('jsonwebtoken/decode');
jest.mock('undici', () => ({
fetch: jest.fn(),
ProxyAgent: jest.fn(),
}));
jest.mock('~/server/services/Files/strategies', () => ({
getStrategyFunctions: jest.fn(() => ({
saveBuffer: jest.fn().mockResolvedValue('/fake/path/to/avatar.png'),
@ -360,6 +365,25 @@ describe('setupOpenId', () => {
expect(details.message).toBe('You must have "requiredRole" role to log in.');
});
it('should not treat substring matches in string roles as satisfying required role', async () => {
// Arrange override required role to "read" then re-setup
process.env.OPENID_REQUIRED_ROLE = 'read';
await setupOpenId();
verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid');
// Token contains "bread" which *contains* "read" as a substring
jwtDecode.mockReturnValue({
roles: 'bread',
});
// Act
const { user, details } = await validate(tokenset);
// Assert verify that substring match does not grant access
expect(user).toBe(false);
expect(details.message).toBe('You must have "read" role to log in.');
});
it('should allow login when single required role is present (backward compatibility)', async () => {
// Arrange ensure single role configuration (as set in beforeEach)
// OPENID_REQUIRED_ROLE = 'requiredRole'
@ -378,6 +402,292 @@ describe('setupOpenId', () => {
expect(createUser).toHaveBeenCalled();
});
describe('group overage and groups handling', () => {
it.each([
['groups array contains required group', ['group-required', 'other-group'], true, undefined],
[
'groups array missing required group',
['other-group'],
false,
'You must have "group-required" role to log in.',
],
['groups string equals required group', 'group-required', true, undefined],
[
'groups string is other group',
'other-group',
false,
'You must have "group-required" role to log in.',
],
])(
'uses groups claim directly when %s (no overage)',
async (_label, groupsClaim, expectedAllowed, expectedMessage) => {
process.env.OPENID_REQUIRED_ROLE = 'group-required';
process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id';
jwtDecode.mockReturnValue({
groups: groupsClaim,
permissions: ['admin'],
});
await setupOpenId();
verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid');
const { user, details } = await validate(tokenset);
expect(undici.fetch).not.toHaveBeenCalled();
expect(Boolean(user)).toBe(expectedAllowed);
expect(details?.message).toBe(expectedMessage);
},
);
it.each([
['token kind is not id', { kind: 'access', path: 'groups', decoded: { hasgroups: true } }],
['parameter path is not groups', { kind: 'id', path: 'roles', decoded: { hasgroups: true } }],
['decoded token is falsy', { kind: 'id', path: 'groups', decoded: null }],
[
'no overage indicators in decoded token',
{
kind: 'id',
path: 'groups',
decoded: {
permissions: ['admin'],
},
},
],
[
'only _claim_names present (no _claim_sources)',
{
kind: 'id',
path: 'groups',
decoded: {
_claim_names: { groups: 'src1' },
permissions: ['admin'],
},
},
],
[
'only _claim_sources present (no _claim_names)',
{
kind: 'id',
path: 'groups',
decoded: {
_claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } },
permissions: ['admin'],
},
},
],
])('does not attempt overage resolution when %s', async (_label, cfg) => {
process.env.OPENID_REQUIRED_ROLE = 'group-required';
process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = cfg.path;
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = cfg.kind;
jwtDecode.mockReturnValue(cfg.decoded);
await setupOpenId();
verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid');
const { user, details } = await validate(tokenset);
expect(undici.fetch).not.toHaveBeenCalled();
expect(user).toBe(false);
expect(details.message).toBe('You must have "group-required" role to log in.');
const { logger } = require('@librechat/data-schemas');
const expectedTokenKind = cfg.kind === 'access' ? 'access token' : 'id token';
expect(logger.error).toHaveBeenCalledWith(
expect.stringContaining(`Key '${cfg.path}' not found in ${expectedTokenKind}!`),
);
});
});
describe('resolving groups via Microsoft Graph', () => {
it('denies login and does not call Graph when access token is missing', async () => {
process.env.OPENID_REQUIRED_ROLE = 'group-required';
process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id';
const { logger } = require('@librechat/data-schemas');
jwtDecode.mockReturnValue({
hasgroups: true,
permissions: ['admin'],
});
await setupOpenId();
verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid');
const tokensetWithoutAccess = {
...tokenset,
access_token: undefined,
};
const { user, details } = await validate(tokensetWithoutAccess);
expect(user).toBe(false);
expect(details.message).toBe('You must have "group-required" role to log in.');
expect(undici.fetch).not.toHaveBeenCalled();
expect(logger.error).toHaveBeenCalledWith(
expect.stringContaining('Access token missing; cannot resolve group overage'),
);
});
it.each([
[
'Graph returns HTTP error',
async () => ({
ok: false,
status: 403,
statusText: 'Forbidden',
json: async () => ({}),
}),
[
'[openidStrategy] Failed to resolve groups via Microsoft Graph getMemberObjects: HTTP 403 Forbidden',
],
],
[
'Graph network error',
async () => {
throw new Error('network error');
},
[
'[openidStrategy] Error resolving groups via Microsoft Graph getMemberObjects:',
expect.any(Error),
],
],
[
'Graph returns unexpected shape (no value)',
async () => ({
ok: true,
status: 200,
statusText: 'OK',
json: async () => ({}),
}),
[
'[openidStrategy] Unexpected response format when resolving groups via Microsoft Graph getMemberObjects',
],
],
[
'Graph returns invalid value type',
async () => ({
ok: true,
status: 200,
statusText: 'OK',
json: async () => ({ value: 'not-an-array' }),
}),
[
'[openidStrategy] Unexpected response format when resolving groups via Microsoft Graph getMemberObjects',
],
],
])(
'denies login when overage resolution fails because %s',
async (_label, setupFetch, expectedErrorArgs) => {
process.env.OPENID_REQUIRED_ROLE = 'group-required';
process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id';
const { logger } = require('@librechat/data-schemas');
jwtDecode.mockReturnValue({
hasgroups: true,
permissions: ['admin'],
});
await setupOpenId();
verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid');
undici.fetch.mockImplementation(setupFetch);
const { user, details } = await validate(tokenset);
expect(undici.fetch).toHaveBeenCalled();
expect(user).toBe(false);
expect(details.message).toBe('You must have "group-required" role to log in.');
expect(logger.error).toHaveBeenCalledWith(...expectedErrorArgs);
},
);
it.each([
[
'hasgroups overage and Graph contains required group',
{
hasgroups: true,
},
['group-required', 'some-other-group'],
true,
],
[
'_claim_* overage and Graph contains required group',
{
_claim_names: { groups: 'src1' },
_claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } },
},
['group-required', 'some-other-group'],
true,
],
[
'hasgroups overage and Graph does NOT contain required group',
{
hasgroups: true,
},
['some-other-group'],
false,
],
[
'_claim_* overage and Graph does NOT contain required group',
{
_claim_names: { groups: 'src1' },
_claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } },
},
['some-other-group'],
false,
],
])(
'resolves groups via Microsoft Graph when %s',
async (_label, decodedTokenValue, graphGroups, expectedAllowed) => {
process.env.OPENID_REQUIRED_ROLE = 'group-required';
process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id';
const { logger } = require('@librechat/data-schemas');
jwtDecode.mockReturnValue(decodedTokenValue);
await setupOpenId();
verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid');
undici.fetch.mockResolvedValue({
ok: true,
status: 200,
statusText: 'OK',
json: async () => ({
value: graphGroups,
}),
});
const { user } = await validate(tokenset);
expect(undici.fetch).toHaveBeenCalledWith(
'https://graph.microsoft.com/v1.0/me/getMemberObjects',
expect.objectContaining({
method: 'POST',
headers: expect.objectContaining({
Authorization: `Bearer ${tokenset.access_token}`,
}),
}),
);
expect(Boolean(user)).toBe(expectedAllowed);
expect(logger.debug).toHaveBeenCalledWith(
expect.stringContaining(
`Successfully resolved ${graphGroups.length} groups via Microsoft Graph getMemberObjects`,
),
);
},
);
});
it('should attempt to download and save the avatar if picture is provided', async () => {
// Act
const { user } = await validate(tokenset);
@ -465,10 +775,11 @@ describe('setupOpenId', () => {
});
it('should attach federatedTokens to user object for token propagation', async () => {
// Arrange - setup tokenset with access token, refresh token, and expiration
// Arrange - setup tokenset with access token, id token, refresh token, and expiration
const tokensetWithTokens = {
...tokenset,
access_token: 'mock_access_token_abc123',
id_token: 'mock_id_token_def456',
refresh_token: 'mock_refresh_token_xyz789',
expires_at: 1234567890,
};
@ -480,16 +791,37 @@ describe('setupOpenId', () => {
expect(user.federatedTokens).toBeDefined();
expect(user.federatedTokens).toEqual({
access_token: 'mock_access_token_abc123',
id_token: 'mock_id_token_def456',
refresh_token: 'mock_refresh_token_xyz789',
expires_at: 1234567890,
});
});
it('should include id_token in federatedTokens distinct from access_token', async () => {
// Arrange - use different values for access_token and id_token
const tokensetWithTokens = {
...tokenset,
access_token: 'the_access_token',
id_token: 'the_id_token',
refresh_token: 'the_refresh_token',
expires_at: 9999999999,
};
// Act
const { user } = await validate(tokensetWithTokens);
// Assert - id_token and access_token must be different values
expect(user.federatedTokens.access_token).toBe('the_access_token');
expect(user.federatedTokens.id_token).toBe('the_id_token');
expect(user.federatedTokens.id_token).not.toBe(user.federatedTokens.access_token);
});
it('should include tokenset along with federatedTokens', async () => {
// Arrange
const tokensetWithTokens = {
...tokenset,
access_token: 'test_access_token',
id_token: 'test_id_token',
refresh_token: 'test_refresh_token',
expires_at: 9999999999,
};
@ -501,7 +833,9 @@ describe('setupOpenId', () => {
expect(user.tokenset).toBeDefined();
expect(user.federatedTokens).toBeDefined();
expect(user.tokenset.access_token).toBe('test_access_token');
expect(user.tokenset.id_token).toBe('test_id_token');
expect(user.federatedTokens.access_token).toBe('test_access_token');
expect(user.federatedTokens.id_token).toBe('test_id_token');
});
it('should set role to "ADMIN" if OPENID_ADMIN_ROLE is set and user has that role', async () => {

View file

@ -1162,6 +1162,56 @@ describe('Claude Model Tests', () => {
expect(matchModelName(model, EModelEndpoint.anthropic)).toBe('claude-opus-4-6');
});
});
it('should return correct context length for Claude Sonnet 4.6 (1M)', () => {
expect(getModelMaxTokens('claude-sonnet-4-6', EModelEndpoint.anthropic)).toBe(
maxTokensMap[EModelEndpoint.anthropic]['claude-sonnet-4-6'],
);
expect(getModelMaxTokens('claude-sonnet-4-6')).toBe(
maxTokensMap[EModelEndpoint.anthropic]['claude-sonnet-4-6'],
);
});
it('should return correct max output tokens for Claude Sonnet 4.6 (64K)', () => {
const { getModelMaxOutputTokens } = require('@librechat/api');
expect(getModelMaxOutputTokens('claude-sonnet-4-6', EModelEndpoint.anthropic)).toBe(
maxOutputTokensMap[EModelEndpoint.anthropic]['claude-sonnet-4-6'],
);
});
it('should handle Claude Sonnet 4.6 model name variations', () => {
const modelVariations = [
'claude-sonnet-4-6',
'claude-sonnet-4-6-20260101',
'claude-sonnet-4-6-latest',
'anthropic/claude-sonnet-4-6',
'claude-sonnet-4-6/anthropic',
'claude-sonnet-4-6-preview',
];
modelVariations.forEach((model) => {
const modelKey = findMatchingPattern(model, maxTokensMap[EModelEndpoint.anthropic]);
expect(modelKey).toBe('claude-sonnet-4-6');
expect(getModelMaxTokens(model, EModelEndpoint.anthropic)).toBe(
maxTokensMap[EModelEndpoint.anthropic]['claude-sonnet-4-6'],
);
});
});
it('should match model names correctly for Claude Sonnet 4.6', () => {
const modelVariations = [
'claude-sonnet-4-6',
'claude-sonnet-4-6-20260101',
'claude-sonnet-4-6-latest',
'anthropic/claude-sonnet-4-6',
'claude-sonnet-4-6/anthropic',
'claude-sonnet-4-6-preview',
];
modelVariations.forEach((model) => {
expect(matchModelName(model, EModelEndpoint.anthropic)).toBe('claude-sonnet-4-6');
});
});
});
describe('Moonshot/Kimi Model Tests', () => {

View file

@ -80,7 +80,7 @@
"lodash": "^4.17.23",
"lucide-react": "^0.394.0",
"match-sorter": "^8.1.0",
"mermaid": "^11.12.2",
"mermaid": "^11.12.3",
"micromark-extension-llm-math": "^3.1.0",
"qrcode.react": "^4.2.0",
"rc-input-number": "^7.4.2",

View file

@ -1,4 +1,4 @@
import React, { createContext, useContext, useEffect, useRef } from 'react';
import React, { createContext, useContext, useEffect, useMemo, useRef } from 'react';
import { useSetRecoilState } from 'recoil';
import { Tools, Constants, LocalStorageKeys, AgentCapabilities } from 'librechat-data-provider';
import type { TAgentsEndpoint } from 'librechat-data-provider';
@ -9,11 +9,13 @@ import {
useCodeApiKeyForm,
useToolToggle,
} from '~/hooks';
import { getTimestampedValue, setTimestamp } from '~/utils/timestamps';
import { getTimestampedValue } from '~/utils/timestamps';
import { useGetStartupConfig } from '~/data-provider';
import { ephemeralAgentByConvoId } from '~/store';
interface BadgeRowContextType {
conversationId?: string | null;
storageContextKey?: string;
agentsConfig?: TAgentsEndpoint | null;
webSearch: ReturnType<typeof useToolToggle>;
artifacts: ReturnType<typeof useToolToggle>;
@ -38,34 +40,70 @@ interface BadgeRowProviderProps {
children: React.ReactNode;
isSubmitting?: boolean;
conversationId?: string | null;
specName?: string | null;
}
export default function BadgeRowProvider({
children,
isSubmitting,
conversationId,
specName,
}: BadgeRowProviderProps) {
const lastKeyRef = useRef<string>('');
const lastContextKeyRef = useRef<string>('');
const hasInitializedRef = useRef(false);
const { agentsConfig } = useGetAgentsConfig();
const { data: startupConfig } = useGetStartupConfig();
const key = conversationId ?? Constants.NEW_CONVO;
const hasModelSpecs = (startupConfig?.modelSpecs?.list?.length ?? 0) > 0;
/**
* Compute the storage context key for non-spec persistence:
* - `__defaults__`: specs configured but none active shared defaults key
* - undefined: spec active (no persistence) or no specs configured (original behavior)
*
* When a spec is active, tool/MCP state is NOT persisted the admin's spec
* configuration is always applied fresh. Only non-spec user preferences persist.
*/
const storageContextKey = useMemo(() => {
if (!specName && hasModelSpecs) {
return Constants.spec_defaults_key as string;
}
return undefined;
}, [specName, hasModelSpecs]);
/**
* Compute the storage suffix for reading localStorage defaults:
* - New conversations read from environment key (spec or non-spec defaults)
* - Existing conversations read from conversation key (per-conversation state)
*/
const isNewConvo = key === Constants.NEW_CONVO;
const storageSuffix = isNewConvo && storageContextKey ? storageContextKey : key;
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(key));
/** Initialize ephemeralAgent from localStorage on mount and when conversation changes */
/** Initialize ephemeralAgent from localStorage on mount and when conversation/spec changes.
* Skipped when a spec is active applyModelSpecEphemeralAgent handles both new conversations
* (pure spec values) and existing conversations (spec values + localStorage overrides). */
useEffect(() => {
if (isSubmitting) {
return;
}
// Check if this is a new conversation or the first load
if (!hasInitializedRef.current || lastKeyRef.current !== key) {
if (specName) {
// Spec active: applyModelSpecEphemeralAgent handles all state (spec base + localStorage
// overrides for existing conversations). Reset init flag so switching back to non-spec
// triggers a fresh re-init.
hasInitializedRef.current = false;
return;
}
// Check if this is a new conversation/spec or the first load
if (!hasInitializedRef.current || lastContextKeyRef.current !== storageSuffix) {
hasInitializedRef.current = true;
lastKeyRef.current = key;
lastContextKeyRef.current = storageSuffix;
const codeToggleKey = `${LocalStorageKeys.LAST_CODE_TOGGLE_}${key}`;
const webSearchToggleKey = `${LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_}${key}`;
const fileSearchToggleKey = `${LocalStorageKeys.LAST_FILE_SEARCH_TOGGLE_}${key}`;
const artifactsToggleKey = `${LocalStorageKeys.LAST_ARTIFACTS_TOGGLE_}${key}`;
const codeToggleKey = `${LocalStorageKeys.LAST_CODE_TOGGLE_}${storageSuffix}`;
const webSearchToggleKey = `${LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_}${storageSuffix}`;
const fileSearchToggleKey = `${LocalStorageKeys.LAST_FILE_SEARCH_TOGGLE_}${storageSuffix}`;
const artifactsToggleKey = `${LocalStorageKeys.LAST_ARTIFACTS_TOGGLE_}${storageSuffix}`;
const codeToggleValue = getTimestampedValue(codeToggleKey);
const webSearchToggleValue = getTimestampedValue(webSearchToggleKey);
@ -106,39 +144,53 @@ export default function BadgeRowProvider({
}
}
/**
* Always set values for all tools (use defaults if not in `localStorage`)
* If `ephemeralAgent` is `null`, create a new object with just our tool values
*/
const finalValues = {
[Tools.execute_code]: initialValues[Tools.execute_code] ?? false,
[Tools.web_search]: initialValues[Tools.web_search] ?? false,
[Tools.file_search]: initialValues[Tools.file_search] ?? false,
[AgentCapabilities.artifacts]: initialValues[AgentCapabilities.artifacts] ?? false,
};
const hasOverrides = Object.keys(initialValues).length > 0;
setEphemeralAgent((prev) => ({
...(prev || {}),
...finalValues,
}));
Object.entries(finalValues).forEach(([toolKey, value]) => {
if (value !== false) {
let storageKey = artifactsToggleKey;
if (toolKey === Tools.execute_code) {
storageKey = codeToggleKey;
} else if (toolKey === Tools.web_search) {
storageKey = webSearchToggleKey;
} else if (toolKey === Tools.file_search) {
storageKey = fileSearchToggleKey;
/** Read persisted MCP values from localStorage */
let mcpOverrides: string[] | null = null;
const mcpStorageKey = `${LocalStorageKeys.LAST_MCP_}${storageSuffix}`;
const mcpRaw = localStorage.getItem(mcpStorageKey);
if (mcpRaw !== null) {
try {
const parsed = JSON.parse(mcpRaw);
if (Array.isArray(parsed) && parsed.length > 0) {
mcpOverrides = parsed;
}
// Store the value and set timestamp for existing values
localStorage.setItem(storageKey, JSON.stringify(value));
setTimestamp(storageKey);
} catch (e) {
console.error('Failed to parse MCP values:', e);
}
}
setEphemeralAgent((prev) => {
if (prev == null) {
/** ephemeralAgent is null — use localStorage defaults */
if (hasOverrides || mcpOverrides) {
const result = { ...initialValues };
if (mcpOverrides) {
result.mcp = mcpOverrides;
}
return result;
}
return prev;
}
/** ephemeralAgent already has values (from prior state).
* Only fill in undefined keys from localStorage. */
let changed = false;
const result = { ...prev };
for (const [toolKey, value] of Object.entries(initialValues)) {
if (result[toolKey] === undefined) {
result[toolKey] = value;
changed = true;
}
}
if (mcpOverrides && result.mcp === undefined) {
result.mcp = mcpOverrides;
changed = true;
}
return changed ? result : prev;
});
}
}, [key, isSubmitting, setEphemeralAgent]);
}, [storageSuffix, specName, isSubmitting, setEphemeralAgent]);
/** CodeInterpreter hooks */
const codeApiKeyForm = useCodeApiKeyForm({});
@ -146,6 +198,7 @@ export default function BadgeRowProvider({
const codeInterpreter = useToolToggle({
conversationId,
storageContextKey,
setIsDialogOpen: setCodeDialogOpen,
toolKey: Tools.execute_code,
localStorageKey: LocalStorageKeys.LAST_CODE_TOGGLE_,
@ -161,6 +214,7 @@ export default function BadgeRowProvider({
const webSearch = useToolToggle({
conversationId,
storageContextKey,
toolKey: Tools.web_search,
localStorageKey: LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_,
setIsDialogOpen: setWebSearchDialogOpen,
@ -173,6 +227,7 @@ export default function BadgeRowProvider({
/** FileSearch hook */
const fileSearch = useToolToggle({
conversationId,
storageContextKey,
toolKey: Tools.file_search,
localStorageKey: LocalStorageKeys.LAST_FILE_SEARCH_TOGGLE_,
isAuthenticated: true,
@ -181,12 +236,13 @@ export default function BadgeRowProvider({
/** Artifacts hook - using a custom key since it's not a Tool but a capability */
const artifacts = useToolToggle({
conversationId,
storageContextKey,
toolKey: AgentCapabilities.artifacts,
localStorageKey: LocalStorageKeys.LAST_ARTIFACTS_TOGGLE_,
isAuthenticated: true,
});
const mcpServerManager = useMCPServerManager({ conversationId });
const mcpServerManager = useMCPServerManager({ conversationId, storageContextKey });
const value: BadgeRowContextType = {
webSearch,
@ -194,6 +250,7 @@ export default function BadgeRowProvider({
fileSearch,
agentsConfig,
conversationId,
storageContextKey,
codeApiKeyForm,
codeInterpreter,
searchApiKeyForm,

View file

@ -9,6 +9,8 @@ import type {
} from 'librechat-data-provider';
import type { OptionWithIcon, ExtendedFile } from './types';
export type AgentQueryResult = { found: true; agent: Agent } | { found: false };
export type TAgentOption = OptionWithIcon &
Agent & {
knowledge_files?: Array<[string, ExtendedFile]>;

View file

@ -28,6 +28,7 @@ interface BadgeRowProps {
onChange: (badges: Pick<BadgeItem, 'id'>[]) => void;
onToggle?: (badgeId: string, currentActive: boolean) => void;
conversationId?: string | null;
specName?: string | null;
isSubmitting?: boolean;
isInChat: boolean;
}
@ -142,6 +143,7 @@ const dragReducer = (state: DragState, action: DragAction): DragState => {
function BadgeRow({
showEphemeralBadges,
conversationId,
specName,
isSubmitting,
onChange,
onToggle,
@ -320,7 +322,11 @@ function BadgeRow({
}, [dragState.draggedBadge, handleMouseMove, handleMouseUp]);
return (
<BadgeRowProvider conversationId={conversationId} isSubmitting={isSubmitting}>
<BadgeRowProvider
conversationId={conversationId}
specName={specName}
isSubmitting={isSubmitting}
>
<div ref={containerRef} className="relative flex flex-wrap items-center gap-2">
{showEphemeralBadges === true && <ToolsDropdown />}
{tempBadges.map((badge, index) => (

View file

@ -325,6 +325,7 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
}
isSubmitting={isSubmitting}
conversationId={conversationId}
specName={conversation?.spec}
onChange={setBadges}
isInChat={
Array.isArray(conversation?.messages) && conversation.messages.length >= 1

View file

@ -158,11 +158,11 @@ const ImagePreview = ({
<DialogPrimitive.Root open={isModalOpen} onOpenChange={handleOpenChange}>
<DialogPrimitive.Portal>
<DialogPrimitive.Overlay
className="fixed inset-0 z-[100] bg-black/90"
className="fixed inset-0 z-[250] bg-black/90"
onClick={handleBackgroundClick}
/>
<DialogPrimitive.Content
className="fixed inset-0 z-[100] flex items-center justify-center outline-none"
className="fixed inset-0 z-[250] flex items-center justify-center outline-none"
onOpenAutoFocus={(e) => {
e.preventDefault();
closeButtonRef.current?.focus();

View file

@ -11,7 +11,7 @@ import { useHasAccess } from '~/hooks';
import { cn } from '~/utils';
function MCPSelectContent() {
const { conversationId, mcpServerManager } = useBadgeRowContext();
const { conversationId, storageContextKey, mcpServerManager } = useBadgeRowContext();
const {
localize,
isPinned,
@ -128,7 +128,11 @@ function MCPSelectContent() {
</Ariakit.Menu>
</Ariakit.MenuProvider>
{configDialogProps && (
<MCPConfigDialog {...configDialogProps} conversationId={conversationId} />
<MCPConfigDialog
{...configDialogProps}
conversationId={conversationId}
storageContextKey={storageContextKey}
/>
)}
</>
);

View file

@ -15,7 +15,7 @@ interface MCPSubMenuProps {
const MCPSubMenu = React.forwardRef<HTMLDivElement, MCPSubMenuProps>(
({ placeholder, ...props }, ref) => {
const localize = useLocalize();
const { mcpServerManager } = useBadgeRowContext();
const { storageContextKey, mcpServerManager } = useBadgeRowContext();
const {
isPinned,
mcpValues,
@ -106,7 +106,9 @@ const MCPSubMenu = React.forwardRef<HTMLDivElement, MCPSubMenuProps>(
</div>
</Ariakit.Menu>
</Ariakit.MenuProvider>
{configDialogProps && <MCPConfigDialog {...configDialogProps} />}
{configDialogProps && (
<MCPConfigDialog {...configDialogProps} storageContextKey={storageContextKey} />
)}
</div>
);
},

View file

@ -41,7 +41,8 @@ const SubmitButton = React.memo(
const SendButton = React.memo(
forwardRef((props: SendButtonProps, ref: React.ForwardedRef<HTMLButtonElement>) => {
const data = useWatch({ control: props.control });
return <SubmitButton ref={ref} disabled={props.disabled || !data.text} />;
const content = data?.text?.trim();
return <SubmitButton ref={ref} disabled={props.disabled || !content} />;
}),
);

View file

@ -80,12 +80,76 @@ const SettingsButton = ({
);
};
/**
* Lazily-rendered content for an endpoint submenu. By extracting this into a
* separate component, the expensive model-list rendering (and per-item hooks
* such as MutationObservers in EndpointModelItem) only runs when the submenu
* is actually mounted which Ariakit defers via `unmountOnHide`.
*/
function EndpointMenuContent({
endpoint,
endpointIndex,
}: {
endpoint: Endpoint;
endpointIndex: number;
}) {
const localize = useLocalize();
const { agentsMap, assistantsMap, modelSpecs, selectedValues, endpointSearchValues } =
useModelSelectorContext();
const { model: selectedModel, modelSpec: selectedSpec } = selectedValues;
const searchValue = endpointSearchValues[endpoint.value] || '';
const endpointSpecs = useMemo(() => {
if (!modelSpecs || !modelSpecs.length) {
return [];
}
return modelSpecs.filter((spec: TModelSpec) => spec.group === endpoint.value);
}, [modelSpecs, endpoint.value]);
if (isAssistantsEndpoint(endpoint.value) && endpoint.models === undefined) {
return (
<div
className="flex items-center justify-center p-2"
role="status"
aria-label={localize('com_ui_loading')}
>
<Spinner aria-hidden="true" />
</div>
);
}
const filteredModels = searchValue
? filterModels(
endpoint,
(endpoint.models || []).map((model) => model.name),
searchValue,
agentsMap,
assistantsMap,
)
: null;
return (
<>
{endpointSpecs.map((spec: TModelSpec) => (
<ModelSpecItem key={spec.name} spec={spec} isSelected={selectedSpec === spec.name} />
))}
{filteredModels
? renderEndpointModels(
endpoint,
endpoint.models || [],
selectedModel,
filteredModels,
endpointIndex,
)
: endpoint.models &&
renderEndpointModels(endpoint, endpoint.models, selectedModel, undefined, endpointIndex)}
</>
);
}
export function EndpointItem({ endpoint, endpointIndex }: EndpointItemProps) {
const localize = useLocalize();
const {
agentsMap,
assistantsMap,
modelSpecs,
selectedValues,
handleOpenKeyDialog,
handleSelectEndpoint,
@ -93,19 +157,7 @@ export function EndpointItem({ endpoint, endpointIndex }: EndpointItemProps) {
setEndpointSearchValue,
endpointRequiresUserKey,
} = useModelSelectorContext();
const {
model: selectedModel,
endpoint: selectedEndpoint,
modelSpec: selectedSpec,
} = selectedValues;
// Filter modelSpecs for this endpoint (by group matching endpoint value)
const endpointSpecs = useMemo(() => {
if (!modelSpecs || !modelSpecs.length) {
return [];
}
return modelSpecs.filter((spec: TModelSpec) => spec.group === endpoint.value);
}, [modelSpecs, endpoint.value]);
const { endpoint: selectedEndpoint } = selectedValues;
const searchValue = endpointSearchValues[endpoint.value] || '';
const isUserProvided = useMemo(
@ -130,15 +182,6 @@ export function EndpointItem({ endpoint, endpointIndex }: EndpointItemProps) {
const isEndpointSelected = selectedEndpoint === endpoint.value;
if (endpoint.hasModels) {
const filteredModels = searchValue
? filterModels(
endpoint,
(endpoint.models || []).map((model) => model.name),
searchValue,
agentsMap,
assistantsMap,
)
: null;
const placeholder =
isAgentsEndpoint(endpoint.value) || isAssistantsEndpoint(endpoint.value)
? localize('com_endpoint_search_var', { 0: endpoint.label })
@ -147,7 +190,6 @@ export function EndpointItem({ endpoint, endpointIndex }: EndpointItemProps) {
<Menu
id={`endpoint-${endpoint.value}-menu`}
key={`endpoint-${endpoint.value}-item`}
defaultOpen={endpoint.value === selectedEndpoint}
searchValue={searchValue}
onSearch={(value) => setEndpointSearchValue(endpoint.value, value)}
combobox={<input placeholder=" " />}
@ -170,39 +212,7 @@ export function EndpointItem({ endpoint, endpointIndex }: EndpointItemProps) {
</div>
}
>
{isAssistantsEndpoint(endpoint.value) && endpoint.models === undefined ? (
<div
className="flex items-center justify-center p-2"
role="status"
aria-label={localize('com_ui_loading')}
>
<Spinner aria-hidden="true" />
</div>
) : (
<>
{/* Render modelSpecs for this endpoint */}
{endpointSpecs.map((spec: TModelSpec) => (
<ModelSpecItem key={spec.name} spec={spec} isSelected={selectedSpec === spec.name} />
))}
{/* Render endpoint models */}
{filteredModels
? renderEndpointModels(
endpoint,
endpoint.models || [],
selectedModel,
filteredModels,
endpointIndex,
)
: endpoint.models &&
renderEndpointModels(
endpoint,
endpoint.models,
selectedModel,
undefined,
endpointIndex,
)}
</>
)}
<EndpointMenuContent endpoint={endpoint} endpointIndex={endpointIndex} />
</Menu>
);
} else {

View file

@ -111,7 +111,7 @@ export const a: React.ElementType = memo(({ href, children }: TAnchorProps) => {
}, [user?.id, href]);
const { refetch: downloadFile } = useFileDownload(user?.id ?? '', file_id);
const props: { target?: string; onClick?: React.MouseEventHandler } = { target: '_new' };
const props: { target?: string; onClick?: React.MouseEventHandler } = { target: '_blank' };
if (!file_id || !filename) {
return (

View file

@ -38,7 +38,6 @@ const MarkdownLite = memo(
]}
/** @ts-ignore */
rehypePlugins={rehypePlugins}
// linkTarget="_new"
components={
{
code: codeExecution ? code : codeNoExecution,

View file

@ -67,9 +67,20 @@ const Part = memo(
if (part.tool_call_ids != null && !text) {
return null;
}
/** Skip rendering if text is only whitespace to avoid empty Container */
if (!isLast && text.length > 0 && /^\s*$/.test(text)) {
return null;
/** Handle whitespace-only text to avoid layout shift */
if (text.length > 0 && /^\s*$/.test(text)) {
/** Show placeholder for whitespace-only last part during streaming */
if (isLast && showCursor) {
return (
<Container>
<EmptyText />
</Container>
);
}
/** Skip rendering non-last whitespace-only parts to avoid empty Container */
if (!isLast) {
return null;
}
}
return (
<Container>

View file

@ -1,7 +1,12 @@
import { useMemo, useState, useEffect, useRef, useLayoutEffect } from 'react';
import { useMemo, useState, useEffect, useRef, useCallback, useLayoutEffect } from 'react';
import { Button } from '@librechat/client';
import { TriangleAlert } from 'lucide-react';
import { actionDelimiter, actionDomainSeparator, Constants } from 'librechat-data-provider';
import {
Constants,
dataService,
actionDelimiter,
actionDomainSeparator,
} from 'librechat-data-provider';
import type { TAttachment } from 'librechat-data-provider';
import { useLocalize, useProgress } from '~/hooks';
import { AttachmentGroup } from './Parts';
@ -36,9 +41,9 @@ export default function ToolCall({
const [isAnimating, setIsAnimating] = useState(false);
const prevShowInfoRef = useRef<boolean>(showInfo);
const { function_name, domain, isMCPToolCall } = useMemo(() => {
const { function_name, domain, isMCPToolCall, mcpServerName } = useMemo(() => {
if (typeof name !== 'string') {
return { function_name: '', domain: null, isMCPToolCall: false };
return { function_name: '', domain: null, isMCPToolCall: false, mcpServerName: '' };
}
if (name.includes(Constants.mcp_delimiter)) {
const [func, server] = name.split(Constants.mcp_delimiter);
@ -46,6 +51,7 @@ export default function ToolCall({
function_name: func || '',
domain: server && (server.replaceAll(actionDomainSeparator, '.') || null),
isMCPToolCall: true,
mcpServerName: server || '',
};
}
const [func, _domain] = name.includes(actionDelimiter)
@ -55,9 +61,40 @@ export default function ToolCall({
function_name: func || '',
domain: _domain && (_domain.replaceAll(actionDomainSeparator, '.') || null),
isMCPToolCall: false,
mcpServerName: '',
};
}, [name]);
const actionId = useMemo(() => {
if (isMCPToolCall || !auth) {
return '';
}
try {
const url = new URL(auth);
const redirectUri = url.searchParams.get('redirect_uri') || '';
const match = redirectUri.match(/\/api\/actions\/([^/]+)\/oauth\/callback/);
return match?.[1] || '';
} catch {
return '';
}
}, [auth, isMCPToolCall]);
const handleOAuthClick = useCallback(async () => {
if (!auth) {
return;
}
try {
if (isMCPToolCall && mcpServerName) {
await dataService.bindMCPOAuth(mcpServerName);
} else if (actionId) {
await dataService.bindActionOAuth(actionId);
}
} catch (e) {
logger.error('Failed to bind OAuth CSRF cookie', e);
}
window.open(auth, '_blank', 'noopener,noreferrer');
}, [auth, isMCPToolCall, mcpServerName, actionId]);
const error =
typeof output === 'string' && output.toLowerCase().includes('error processing tool');
@ -230,7 +267,7 @@ export default function ToolCall({
className="font-mediu inline-flex items-center justify-center rounded-xl px-4 py-2 text-sm"
variant="default"
rel="noopener noreferrer"
onClick={() => window.open(auth, '_blank', 'noopener,noreferrer')}
onClick={handleOAuthClick}
>
{localize('com_ui_sign_in_to_domain', { 0: authDomain })}
</Button>

View file

@ -24,6 +24,7 @@ interface MCPConfigDialogProps {
serverName: string;
serverStatus?: MCPServerStatus;
conversationId?: string | null;
storageContextKey?: string;
}
export default function MCPConfigDialog({
@ -36,6 +37,7 @@ export default function MCPConfigDialog({
serverName,
serverStatus,
conversationId,
storageContextKey,
}: MCPConfigDialogProps) {
const localize = useLocalize();
@ -167,6 +169,7 @@ export default function MCPConfigDialog({
<ServerInitializationSection
serverName={serverName}
conversationId={conversationId}
storageContextKey={storageContextKey}
requiresOAuth={serverStatus?.requiresOAuth || false}
hasCustomUserVars={fieldsSchema && Object.keys(fieldsSchema).length > 0}
/>

View file

@ -9,12 +9,14 @@ interface ServerInitializationSectionProps {
requiresOAuth: boolean;
hasCustomUserVars?: boolean;
conversationId?: string | null;
storageContextKey?: string;
}
export default function ServerInitializationSection({
serverName,
requiresOAuth,
conversationId,
storageContextKey,
sidePanel = false,
hasCustomUserVars = false,
}: ServerInitializationSectionProps) {
@ -28,7 +30,7 @@ export default function ServerInitializationSection({
initializeServer,
availableMCPServers,
revokeOAuthForServer,
} = useMCPServerManager({ conversationId });
} = useMCPServerManager({ conversationId, storageContextKey });
const { connectionStatus } = useMCPConnectionStatus({
enabled: !!availableMCPServers && availableMCPServers.length > 0,

View file

@ -9,6 +9,7 @@ import { QueryKeys, dataService } from 'librechat-data-provider';
import type t from 'librechat-data-provider';
import { useFavorites, useLocalize, useShowMarketplace, useNewConvo } from '~/hooks';
import { useAssistantsMapContext, useAgentsMapContext } from '~/Providers';
import type { AgentQueryResult } from '~/common';
import useSelectMention from '~/hooks/Input/useSelectMention';
import { useGetEndpointsQuery } from '~/data-provider';
import FavoriteItem from './FavoriteItem';
@ -184,7 +185,20 @@ export default function FavoritesList({
const missingAgentQueries = useQueries({
queries: missingAgentIds.map((agentId) => ({
queryKey: [QueryKeys.agent, agentId],
queryFn: () => dataService.getAgentById({ agent_id: agentId }),
queryFn: async (): Promise<AgentQueryResult> => {
try {
const agent = await dataService.getAgentById({ agent_id: agentId });
return { found: true, agent };
} catch (error) {
if (error && typeof error === 'object' && 'response' in error) {
const axiosError = error as { response?: { status?: number } };
if (axiosError.response?.status === 404) {
return { found: false };
}
}
throw error;
}
},
staleTime: 1000 * 60 * 5,
enabled: missingAgentIds.length > 0,
})),
@ -201,8 +215,8 @@ export default function FavoritesList({
}
}
missingAgentQueries.forEach((query) => {
if (query.data) {
combined[query.data.id] = query.data;
if (query.data?.found) {
combined[query.data.agent.id] = query.data.agent;
}
});
return combined;

View file

@ -0,0 +1,191 @@
import React from 'react';
import { render, waitFor } from '@testing-library/react';
import '@testing-library/jest-dom';
import { QueryClient, QueryClientProvider } from '@tanstack/react-query';
import { RecoilRoot } from 'recoil';
import { DndProvider } from 'react-dnd';
import { HTML5Backend } from 'react-dnd-html5-backend';
import { BrowserRouter } from 'react-router-dom';
import { dataService } from 'librechat-data-provider';
import type t from 'librechat-data-provider';
// Mock store before importing FavoritesList
jest.mock('~/store', () => {
const { atom } = jest.requireActual('recoil');
return {
__esModule: true,
default: {
search: atom({
key: 'mock-search-atom',
default: { query: '' },
}),
conversationByIndex: (index: number) =>
atom({
key: `mock-conversation-atom-${index}`,
default: null,
}),
},
};
});
import FavoritesList from '../FavoritesList';
type FavoriteItem = {
agentId?: string;
model?: string;
endpoint?: string;
};
// Mock dataService
jest.mock('librechat-data-provider', () => ({
...jest.requireActual('librechat-data-provider'),
dataService: {
getAgentById: jest.fn(),
},
}));
// Mock hooks
const mockFavorites: FavoriteItem[] = [];
const mockUseFavorites = jest.fn(() => ({
favorites: mockFavorites,
reorderFavorites: jest.fn(),
isLoading: false,
}));
jest.mock('~/hooks', () => ({
useFavorites: () => mockUseFavorites(),
useLocalize: () => (key: string) => key,
useShowMarketplace: () => false,
useNewConvo: () => ({ newConversation: jest.fn() }),
}));
jest.mock('~/Providers', () => ({
useAssistantsMapContext: () => ({}),
useAgentsMapContext: () => ({}),
}));
jest.mock('~/hooks/Input/useSelectMention', () => () => ({
onSelectEndpoint: jest.fn(),
}));
jest.mock('~/data-provider', () => ({
useGetEndpointsQuery: () => ({ data: {} }),
}));
jest.mock('../FavoriteItem', () => ({
__esModule: true,
default: ({ item, type }: { item: any; type: string }) => (
<div data-testid="favorite-item" data-type={type}>
{type === 'agent' ? item.name : item.model}
</div>
),
}));
const createTestQueryClient = () =>
new QueryClient({
defaultOptions: {
queries: {
retry: false,
},
},
});
const renderWithProviders = (ui: React.ReactElement) => {
const queryClient = createTestQueryClient();
return render(
<QueryClientProvider client={queryClient}>
<RecoilRoot>
<BrowserRouter>
<DndProvider backend={HTML5Backend}>{ui}</DndProvider>
</BrowserRouter>
</RecoilRoot>
</QueryClientProvider>,
);
};
describe('FavoritesList', () => {
beforeEach(() => {
jest.clearAllMocks();
mockFavorites.length = 0;
});
describe('rendering', () => {
it('should render nothing when favorites is empty and marketplace is hidden', () => {
const { container } = renderWithProviders(<FavoritesList />);
expect(container.firstChild).toBeNull();
});
it('should render skeleton while loading', () => {
mockUseFavorites.mockReturnValueOnce({
favorites: [],
reorderFavorites: jest.fn(),
isLoading: true,
});
const { container } = renderWithProviders(<FavoritesList />);
// Skeletons should be present during loading - container should have children
expect(container.firstChild).not.toBeNull();
// When loading, the component renders skeleton placeholders (check for content, not specific CSS)
expect(container.innerHTML).toContain('div');
});
});
describe('missing agent handling', () => {
it('should exclude missing agents (404) from rendered favorites and render valid agents', async () => {
const validAgent: t.Agent = {
id: 'valid-agent',
name: 'Valid Agent',
author: 'test-author',
} as t.Agent;
// Set up favorites with both valid and missing agent
mockFavorites.push({ agentId: 'valid-agent' }, { agentId: 'deleted-agent' });
// Mock getAgentById: valid-agent returns successfully, deleted-agent returns 404
(dataService.getAgentById as jest.Mock).mockImplementation(
({ agent_id }: { agent_id: string }) => {
if (agent_id === 'valid-agent') {
return Promise.resolve(validAgent);
}
if (agent_id === 'deleted-agent') {
return Promise.reject({ response: { status: 404 } });
}
return Promise.reject(new Error('Unknown agent'));
},
);
const { findAllByTestId } = renderWithProviders(<FavoritesList />);
// Wait for queries to resolve
const favoriteItems = await findAllByTestId('favorite-item');
// Only the valid agent should be rendered
expect(favoriteItems).toHaveLength(1);
expect(favoriteItems[0]).toHaveTextContent('Valid Agent');
// The deleted agent should still be requested, but not rendered
expect(dataService.getAgentById as jest.Mock).toHaveBeenCalledWith({
agent_id: 'deleted-agent',
});
});
it('should not show infinite loading skeleton when agents return 404', async () => {
// Set up favorites with only a deleted agent
mockFavorites.push({ agentId: 'deleted-agent' });
// Mock getAgentById to return 404
(dataService.getAgentById as jest.Mock).mockRejectedValue({ response: { status: 404 } });
const { queryAllByTestId } = renderWithProviders(<FavoritesList />);
// Wait for the loading state to resolve after 404 handling by ensuring the agent request was made
await waitFor(() => {
expect(dataService.getAgentById as jest.Mock).toHaveBeenCalledWith({
agent_id: 'deleted-agent',
});
});
// No favorite items should be rendered (deleted agent is filtered out)
expect(queryAllByTestId('favorite-item')).toHaveLength(0);
});
});
});

View file

@ -1,7 +1,7 @@
import { Plus } from 'lucide-react';
import React, { useMemo, useCallback, useRef, useState } from 'react';
import { Plus } from 'lucide-react';
import { Button, useToastContext } from '@librechat/client';
import { useWatch, useForm, FormProvider, type FieldNamesMarkedBoolean } from 'react-hook-form';
import { useWatch, useForm, FormProvider } from 'react-hook-form';
import { useGetModelsQuery } from 'librechat-data-provider/react-query';
import {
Tools,
@ -11,8 +11,10 @@ import {
PermissionBits,
isAssistantsEndpoint,
} from 'librechat-data-provider';
import type { AgentForm, StringOption } from '~/common';
import type { FieldNamesMarkedBoolean } from 'react-hook-form';
import type { Agent } from 'librechat-data-provider';
import type { TranslationKeys } from '~/hooks/useLocalize';
import type { AgentForm, StringOption } from '~/common';
import {
useCreateAgentMutation,
useUpdateAgentMutation,
@ -23,7 +25,6 @@ import {
import { createProviderOption, getDefaultAgentFormValues } from '~/utils';
import { useResourcePermissions } from '~/hooks/useResourcePermissions';
import { useSelectAgent, useLocalize, useAuthContext } from '~/hooks';
import type { TranslationKeys } from '~/hooks/useLocalize';
import { useAgentPanelContext } from '~/Providers/AgentPanelContext';
import AgentPanelSkeleton from './AgentPanelSkeleton';
import AdvancedPanel from './Advanced/AdvancedPanel';

View file

@ -46,7 +46,7 @@ export default function MCPTools({
return null;
}
if (serverInfo.isConnected) {
if (serverInfo?.tools?.length && serverInfo.tools.length > 0) {
return (
<MCPTool key={`${serverInfo.serverName}-${agentId}`} serverInfo={serverInfo} />
);

View file

@ -1,10 +1,10 @@
import { FormProvider } from 'react-hook-form';
import type { useMCPServerForm } from './hooks/useMCPServerForm';
import ConnectionSection from './sections/ConnectionSection';
import BasicInfoSection from './sections/BasicInfoSection';
import TransportSection from './sections/TransportSection';
import AuthSection from './sections/AuthSection';
import TrustSection from './sections/TrustSection';
import type { useMCPServerForm } from './hooks/useMCPServerForm';
import AuthSection from './sections/AuthSection';
interface MCPServerFormProps {
formHook: ReturnType<typeof useMCPServerForm>;

View file

@ -1,13 +1,18 @@
import React, { useState, useEffect } from 'react';
import { Copy, CopyCheck } from 'lucide-react';
import {
OGDialog,
OGDialogTemplate,
OGDialogContent,
OGDialogHeader,
OGDialogTitle,
Label,
Input,
Button,
TrashIcon,
Spinner,
TrashIcon,
useToastContext,
OGDialog,
OGDialogTitle,
OGDialogHeader,
OGDialogFooter,
OGDialogContent,
OGDialogTemplate,
} from '@librechat/client';
import {
SystemRoles,
@ -16,10 +21,10 @@ import {
PermissionBits,
PermissionTypes,
} from 'librechat-data-provider';
import { GenericGrantAccessDialog } from '~/components/Sharing';
import { useAuthContext, useHasAccess, useResourcePermissions, MCPServerDefinition } from '~/hooks';
import { useLocalize } from '~/hooks';
import { GenericGrantAccessDialog } from '~/components/Sharing';
import { useMCPServerForm } from './hooks/useMCPServerForm';
import { useLocalize, useCopyToClipboard } from '~/hooks';
import MCPServerForm from './MCPServerForm';
interface MCPServerDialogProps {
@ -39,8 +44,10 @@ export default function MCPServerDialog({
}: MCPServerDialogProps) {
const localize = useLocalize();
const { user } = useAuthContext();
const { showToast } = useToastContext();
// State for dialogs
const [isCopying, setIsCopying] = useState(false);
const [showDeleteConfirm, setShowDeleteConfirm] = useState(false);
const [showRedirectUriDialog, setShowRedirectUriDialog] = useState(false);
const [createdServerId, setCreatedServerId] = useState<string | null>(null);
@ -99,20 +106,26 @@ export default function MCPServerDialog({
? `${window.location.origin}/api/mcp/${createdServerId}/oauth/callback`
: '';
const copyLink = useCopyToClipboard({ text: redirectUri });
return (
<>
{/* Delete confirmation dialog */}
<OGDialog open={showDeleteConfirm} onOpenChange={(isOpen) => setShowDeleteConfirm(isOpen)}>
<OGDialogTemplate
title={localize('com_ui_delete')}
className="max-w-[450px]"
main={<p className="text-left text-sm">{localize('com_ui_mcp_server_delete_confirm')}</p>}
selection={{
selectHandler: handleDelete,
selectClasses:
'bg-destructive text-white transition-all duration-200 hover:bg-destructive/80',
selectText: isDeleting ? <Spinner /> : localize('com_ui_delete'),
}}
title={localize('com_ui_delete_mcp_server')}
className="w-11/12 max-w-md"
description={localize('com_ui_mcp_server_delete_confirm', { 0: server?.serverName })}
selection={
<Button
onClick={handleDelete}
variant="destructive"
aria-live="polite"
aria-label={isDeleting ? localize('com_ui_deleting') : localize('com_ui_delete')}
>
{isDeleting ? <Spinner /> : localize('com_ui_delete')}
</Button>
}
/>
</OGDialog>
@ -127,48 +140,53 @@ export default function MCPServerDialog({
}
}}
>
<OGDialogContent className="w-full max-w-lg border-none bg-surface-primary text-text-primary">
<OGDialogHeader className="border-b border-border-light px-4 py-3">
<OGDialogContent showCloseButton={false} className="w-11/12 max-w-lg">
<OGDialogHeader>
<OGDialogTitle>{localize('com_ui_mcp_server_created')}</OGDialogTitle>
</OGDialogHeader>
<div className="space-y-4 p-4">
<p className="text-sm text-text-secondary">
{localize('com_ui_redirect_uri_instructions')}
</p>
<div className="rounded-lg border border-border-medium bg-surface-secondary p-3">
<label className="mb-2 block text-xs font-medium text-text-secondary">
<div className="space-y-4">
<p className="text-sm">{localize('com_ui_redirect_uri_instructions')}</p>
<div className="space-y-2">
<Label htmlFor="redirect-uri-input" className="text-sm font-medium">
{localize('com_ui_redirect_uri')}
</label>
</Label>
<div className="flex items-center gap-2">
<input
className="flex-1 rounded border border-border-medium bg-surface-primary px-3 py-2 text-sm"
value={redirectUri}
<Input
id="redirect-uri-input"
type="text"
readOnly
value={redirectUri}
className="flex-1 text-text-secondary"
/>
<Button
onClick={() => {
navigator.clipboard.writeText(redirectUri);
}}
size="icon"
variant="outline"
className="whitespace-nowrap"
onClick={() => {
if (isCopying) return;
showToast({ message: localize('com_ui_copied_to_clipboard') });
copyLink(setIsCopying);
}}
disabled={isCopying}
className="p-0"
aria-label={localize('com_ui_copy_link')}
>
{localize('com_ui_copy_link')}
{isCopying ? <CopyCheck className="size-4" /> : <Copy className="size-4" />}
</Button>
</div>
</div>
<div className="flex justify-end">
<OGDialogFooter>
<Button
variant="default"
onClick={() => {
setShowRedirectUriDialog(false);
onOpenChange(false);
setCreatedServerId(null);
}}
variant="submit"
className="text-white"
>
{localize('com_ui_done')}
</Button>
</div>
</OGDialogFooter>
</div>
</OGDialogContent>
</OGDialog>
@ -187,6 +205,7 @@ export default function MCPServerDialog({
})
: undefined
}
showCloseButton={false}
className="w-11/12 md:max-w-3xl"
main={<MCPServerForm formHook={formHook} />}
footerClassName="sm:justify-between"
@ -194,16 +213,15 @@ export default function MCPServerDialog({
isEditMode ? (
<div className="flex items-center gap-2">
<Button
type="button"
variant="outline"
variant="destructive"
size="sm"
aria-label={localize('com_ui_delete')}
aria-label={localize('com_ui_delete_mcp_server_name', {
0: server?.config?.title || server?.serverName || '',
})}
onClick={() => setShowDeleteConfirm(true)}
disabled={isSubmitting || isDeleting}
>
<div className="flex w-full items-center justify-center gap-2 text-red-500">
<TrashIcon />
</div>
<TrashIcon aria-hidden="true" />
</Button>
{shouldShowShareButton && server && (
<GenericGrantAccessDialog
@ -218,10 +236,15 @@ export default function MCPServerDialog({
buttons={
<Button
type="button"
variant="submit"
variant={isEditMode ? 'default' : 'submit'}
onClick={onSubmit}
disabled={isSubmitting}
className="text-white"
aria-live="polite"
aria-label={
isSubmitting
? localize(isEditMode ? 'com_ui_updating' : 'com_ui_creating')
: localize(isEditMode ? 'com_ui_update_mcp_server' : 'com_ui_create_mcp_server')
}
>
{isSubmitting ? (
<Spinner className="size-4" />

View file

@ -1,11 +1,11 @@
import { useMemo, useState } from 'react';
import { Copy, CopyCheck } from 'lucide-react';
import { useFormContext, useWatch } from 'react-hook-form';
import { Label, Input, Checkbox, SecretInput, Radio, useToastContext } from '@librechat/client';
import { Copy, CopyCheck } from 'lucide-react';
import { useLocalize, useCopyToClipboard } from '~/hooks';
import { cn } from '~/utils';
import { AuthTypeEnum, AuthorizationTypeEnum } from '../hooks/useMCPServerForm';
import type { MCPServerFormData } from '../hooks/useMCPServerForm';
import { useLocalize, useCopyToClipboard } from '~/hooks';
import { cn } from '~/utils';
interface AuthSectionProps {
isEditMode: boolean;
@ -62,15 +62,20 @@ export default function AuthSection({ isEditMode, serverName }: AuthSectionProps
return (
<div className="space-y-3">
{/* Auth Type Radio */}
<div className="space-y-1.5">
<Label className="text-sm font-medium">{localize('com_ui_authentication')}</Label>
<fieldset className="space-y-1.5">
<legend>
<Label id="auth-type-label" className="text-sm font-medium">
{localize('com_ui_authentication')}
</Label>
</legend>
<Radio
options={authTypeOptions}
value={authType || AuthTypeEnum.None}
onChange={(val) => setValue('auth.auth_type', val as AuthTypeEnum)}
fullWidth
aria-labelledby="auth-type-label"
/>
</div>
</fieldset>
{/* API Key Fields */}
{authType === AuthTypeEnum.ServiceHttp && (
@ -83,9 +88,13 @@ export default function AuthSection({ isEditMode, serverName }: AuthSectionProps
onCheckedChange={(checked) =>
setValue('auth.api_key_source', checked ? 'user' : 'admin')
}
aria-label={localize('com_ui_user_provides_key')}
aria-labelledby="user_provides_key_label"
/>
<label htmlFor="user_provides_key" className="cursor-pointer text-sm">
<label
id="user_provides_key_label"
htmlFor="user_provides_key"
className="cursor-pointer text-sm"
>
{localize('com_ui_user_provides_key')}
</label>
</div>
@ -101,8 +110,12 @@ export default function AuthSection({ isEditMode, serverName }: AuthSectionProps
)}
{/* Header Format Radio */}
<div className="space-y-1.5">
<Label className="text-sm font-medium">{localize('com_ui_header_format')}</Label>
<fieldset className="space-y-1.5">
<legend>
<Label id="header-format-label" className="text-sm font-medium">
{localize('com_ui_header_format')}
</Label>
</legend>
<Radio
options={headerFormatOptions}
value={authorizationType || AuthorizationTypeEnum.Bearer}
@ -110,8 +123,9 @@ export default function AuthSection({ isEditMode, serverName }: AuthSectionProps
setValue('auth.api_key_authorization_type', val as AuthorizationTypeEnum)
}
fullWidth
aria-labelledby="header-format-label"
/>
</div>
</fieldset>
{/* Custom header name */}
{authorizationType === AuthorizationTypeEnum.Custom && (
@ -137,27 +151,67 @@ export default function AuthSection({ isEditMode, serverName }: AuthSectionProps
<div className="space-y-1.5">
<Label htmlFor="oauth_client_id" className="text-sm font-medium">
{localize('com_ui_client_id')}{' '}
{!isEditMode && <span className="text-text-secondary">*</span>}
{!isEditMode && (
<>
<span aria-hidden="true" className="text-text-secondary">
*
</span>
<span className="sr-only">{localize('com_ui_field_required')}</span>
</>
)}
</Label>
<Input
id="oauth_client_id"
autoComplete="off"
placeholder={isEditMode ? localize('com_ui_leave_blank_to_keep') : ''}
aria-invalid={errors.auth?.oauth_client_id ? 'true' : 'false'}
aria-describedby={
errors.auth?.oauth_client_id ? 'oauth-client-id-error' : undefined
}
{...register('auth.oauth_client_id', { required: !isEditMode })}
className={cn(errors.auth?.oauth_client_id && 'border-red-500')}
className={cn(errors.auth?.oauth_client_id && 'border-border-destructive')}
/>
{errors.auth?.oauth_client_id && (
<p
id="oauth-client-id-error"
role="alert"
className="text-xs text-text-destructive"
>
{localize('com_ui_field_required')}
</p>
)}
</div>
<div className="space-y-1.5">
<Label htmlFor="oauth_client_secret" className="text-sm font-medium">
{localize('com_ui_client_secret')}{' '}
{!isEditMode && <span className="text-text-secondary">*</span>}
{!isEditMode && (
<>
<span aria-hidden="true" className="text-text-secondary">
*
</span>
<span className="sr-only">{localize('com_ui_field_required')}</span>
</>
)}
</Label>
<SecretInput
id="oauth_client_secret"
placeholder={isEditMode ? localize('com_ui_leave_blank_to_keep') : ''}
aria-invalid={errors.auth?.oauth_client_secret ? 'true' : 'false'}
aria-describedby={
errors.auth?.oauth_client_secret ? 'oauth-client-secret-error' : undefined
}
{...register('auth.oauth_client_secret', { required: !isEditMode })}
className={cn(errors.auth?.oauth_client_secret && 'border-red-500')}
className={cn(errors.auth?.oauth_client_secret && 'border-border-destructive')}
/>
{errors.auth?.oauth_client_secret && (
<p
id="oauth-client-secret-error"
role="alert"
className="text-xs text-text-destructive"
>
{localize('com_ui_field_required')}
</p>
)}
</div>
</div>
@ -196,9 +250,12 @@ export default function AuthSection({ isEditMode, serverName }: AuthSectionProps
{/* Redirect URI */}
{isEditMode && redirectUri && (
<div className="space-y-1.5">
<Label className="text-sm font-medium">{localize('com_ui_redirect_uri')}</Label>
<Label htmlFor="auth-redirect-uri" className="text-sm font-medium">
{localize('com_ui_redirect_uri')}
</Label>
<div className="flex items-center gap-2">
<Input
id="auth-redirect-uri"
type="text"
readOnly
value={redirectUri}

View file

@ -1,9 +1,9 @@
import { useFormContext } from 'react-hook-form';
import { Input, Label, TextareaAutosize } from '@librechat/client';
import { Input, Label, Textarea } from '@librechat/client';
import type { MCPServerFormData } from '../hooks/useMCPServerForm';
import MCPIcon from '~/components/SidePanel/Agents/MCPIcon';
import { useLocalize } from '~/hooks';
import { cn } from '~/utils';
import MCPIcon from '~/components/SidePanel/Agents/MCPIcon';
import type { MCPServerFormData } from '../hooks/useMCPServerForm';
export default function BasicInfoSection() {
const localize = useLocalize();
@ -36,13 +36,19 @@ export default function BasicInfoSection() {
<MCPIcon icon={iconValue} onIconChange={handleIconChange} />
</div>
<div className="w-full space-y-1.5 sm:flex-1">
<Label htmlFor="title" className="text-sm font-medium">
{localize('com_ui_name')} <span className="text-text-secondary">*</span>
<Label htmlFor="mcp-title" className="text-sm font-medium">
{localize('com_ui_name')}{' '}
<span aria-hidden="true" className="text-text-secondary">
*
</span>
<span className="sr-only">{localize('com_ui_field_required')}</span>
</Label>
<Input
id="title"
id="mcp-title"
autoComplete="off"
placeholder={localize('com_agents_mcp_name_placeholder')}
aria-invalid={errors.title ? 'true' : 'false'}
aria-describedby={errors.title ? 'mcp-title-error' : undefined}
{...register('title', {
required: localize('com_ui_field_required'),
pattern: {
@ -50,26 +56,26 @@ export default function BasicInfoSection() {
message: localize('com_ui_mcp_title_invalid'),
},
})}
className={cn(errors.title && 'border-red-500 focus:border-red-500')}
className={cn(errors.title && 'border-border-destructive')}
/>
{errors.title && <p className="text-xs text-red-500">{errors.title.message}</p>}
{errors.title && (
<p id="mcp-title-error" role="alert" className="text-xs text-text-destructive">
{errors.title.message}
</p>
)}
</div>
</div>
{/* Description - always visible, full width */}
{/* Description */}
<div className="space-y-1.5">
<Label htmlFor="description" className="text-sm font-medium">
<Label htmlFor="mcp-description" className="text-sm font-medium">
{localize('com_ui_description')}{' '}
<span className="text-xs text-text-secondary">{localize('com_ui_optional')}</span>
</Label>
<TextareaAutosize
id="description"
aria-label={localize('com_ui_description')}
<Textarea
id="mcp-description"
placeholder={localize('com_agents_mcp_description_placeholder')}
{...register('description')}
minRows={2}
maxRows={4}
className="w-full resize-none rounded-lg border border-input bg-transparent px-3 py-2 text-sm placeholder:text-muted-foreground focus-visible:outline-none"
/>
</div>
</div>

View file

@ -15,13 +15,19 @@ export default function ConnectionSection() {
return (
<div className="space-y-1.5">
<Label htmlFor="url" className="text-sm font-medium">
{localize('com_ui_mcp_url')} <span className="text-text-secondary">*</span>
{localize('com_ui_mcp_url')}{' '}
<span aria-hidden="true" className="text-text-secondary">
*
</span>
<span className="sr-only">{localize('com_ui_field_required')}</span>
</Label>
<Input
id="url"
type="url"
autoComplete="off"
placeholder={localize('com_ui_mcp_server_url_placeholder')}
aria-invalid={errors.url ? 'true' : 'false'}
aria-describedby={errors.url ? 'url-error' : undefined}
{...register('url', {
required: localize('com_ui_field_required'),
validate: (value) => {
@ -29,9 +35,13 @@ export default function ConnectionSection() {
return isValidUrl(normalized) || localize('com_ui_mcp_invalid_url');
},
})}
className={cn(errors.url && 'border-red-500 focus:border-red-500')}
className={cn(errors.url && 'border-border-destructive')}
/>
{errors.url && <p className="text-xs text-red-500">{errors.url.message}</p>}
{errors.url && (
<p id="url-error" role="alert" className="text-xs text-text-destructive">
{errors.url.message}
</p>
)}
</div>
);
}

View file

@ -25,14 +25,19 @@ export default function TransportSection() {
);
return (
<div className="space-y-2">
<Label className="text-sm font-medium">{localize('com_ui_mcp_transport')}</Label>
<fieldset className="space-y-2">
<legend>
<Label id="transport-label" className="text-sm font-medium">
{localize('com_ui_mcp_transport')}
</Label>
</legend>
<Radio
options={transportOptions}
value={transportType}
onChange={handleTransportChange}
fullWidth
aria-labelledby="transport-label"
/>
</div>
</fieldset>
);
}

View file

@ -26,17 +26,17 @@ export default function TrustSection() {
checked={field.value}
onCheckedChange={field.onChange}
aria-labelledby="trust-label"
aria-describedby="trust-description"
aria-describedby={
errors.trust ? 'trust-description trust-error' : 'trust-description'
}
aria-invalid={errors.trust ? 'true' : 'false'}
aria-required="true"
className="mt-0.5"
/>
)}
/>
<Label
id="trust-label"
htmlFor="trust"
className="flex cursor-pointer flex-col gap-0.5 text-sm"
>
<span className="font-medium text-text-primary">
<Label htmlFor="trust" className="flex cursor-pointer flex-col gap-0.5 text-sm">
<span id="trust-label" className="font-medium text-text-primary">
{startupConfig?.interface?.mcpServers?.trustCheckbox?.label ? (
<span
dangerouslySetInnerHTML={{
@ -49,7 +49,9 @@ export default function TrustSection() {
) : (
localize('com_ui_trust_app')
)}{' '}
<span className="text-text-secondary">*</span>
<span aria-hidden="true" className="text-text-secondary">
*
</span>
</span>
<span id="trust-description" className="text-xs font-normal text-text-secondary">
{startupConfig?.interface?.mcpServers?.trustCheckbox?.subLabel ? (
@ -68,7 +70,9 @@ export default function TrustSection() {
</Label>
</div>
{errors.trust && (
<p className="mt-2 text-xs text-red-500">{localize('com_ui_field_required')}</p>
<p id="trust-error" role="alert" className="mt-2 text-xs text-text-destructive">
{localize('com_ui_field_required')}
</p>
)}
</div>
);

View file

@ -96,17 +96,17 @@ function MCPToolSelectDialog({
await new Promise((resolve) => setTimeout(resolve, 500));
}
// Then initialize server if needed
// Only initialize if no cached tools exist; skip if tools are already available from DB
const serverInfo = mcpServersMap.get(serverName);
if (!serverInfo?.isConnected) {
if (!serverInfo?.tools?.length) {
const result = await initializeServer(serverName);
if (result?.success && result.oauthRequired && result.oauthUrl) {
if (result?.oauthRequired && result.oauthUrl) {
setIsInitializing(null);
return;
return; // OAuth flow must complete first
}
}
// Finally, add tools to form
// Add tools to form (refetches from backend's persisted cache)
await addToolsToForm(serverName);
setIsInitializing(null);
} catch (error) {

View file

@ -12,10 +12,10 @@ export const useMCPServersQuery = <TData = t.MCPServersListResponse>(
[QueryKeys.mcpServers],
() => dataService.getMCPServers(),
{
staleTime: 1000 * 60 * 5, // 5 minutes - data stays fresh longer
refetchOnWindowFocus: false,
staleTime: 30 * 1000, // 30 seconds — short enough to pick up servers that finish initializing after first load
refetchOnWindowFocus: true,
refetchOnReconnect: false,
refetchOnMount: false,
refetchOnMount: true,
retry: false,
...config,
},

View file

@ -1,4 +1,5 @@
import { useCallback } from 'react';
import { Constants } from 'librechat-data-provider';
import type { TStartupConfig, TSubmission } from 'librechat-data-provider';
import { useUpdateEphemeralAgent, useApplyNewAgentTemplate } from '~/store/agents';
import { getModelSpec, applyModelSpecEphemeralAgent } from '~/utils';
@ -6,6 +7,10 @@ import { getModelSpec, applyModelSpecEphemeralAgent } from '~/utils';
/**
* Hook that applies a model spec from a preset to an ephemeral agent.
* This is used when initializing a new conversation with a preset that has a spec.
*
* When a spec is provided, its tool settings are applied to the ephemeral agent.
* When no spec is provided but specs are configured, the ephemeral agent is reset
* to null so BadgeRowContext can apply localStorage defaults (non-spec experience).
*/
export function useApplyModelSpecEffects() {
const updateEphemeralAgent = useUpdateEphemeralAgent();
@ -20,6 +25,11 @@ export function useApplyModelSpecEffects() {
startupConfig?: TStartupConfig;
}) => {
if (specName == null || !specName) {
if (startupConfig?.modelSpecs?.list?.length) {
/** Specs are configured but none selected reset ephemeral agent to null
* so BadgeRowContext fills all values (tool toggles + MCP) from localStorage. */
updateEphemeralAgent((convoId ?? Constants.NEW_CONVO) || Constants.NEW_CONVO, null);
}
return;
}
@ -80,6 +90,9 @@ export function useApplyAgentTemplate() {
web_search: ephemeralAgent?.web_search ?? modelSpec.webSearch ?? false,
file_search: ephemeralAgent?.file_search ?? modelSpec.fileSearch ?? false,
execute_code: ephemeralAgent?.execute_code ?? modelSpec.executeCode ?? false,
artifacts:
ephemeralAgent?.artifacts ??
(modelSpec.artifacts === true ? 'default' : modelSpec.artifacts || ''),
};
mergedAgent.mcp = [...new Set(mergedAgent.mcp)];

View file

@ -1,7 +1,12 @@
import { useCallback } from 'react';
import { useRecoilValue } from 'recoil';
import { useGetModelsQuery } from 'librechat-data-provider/react-query';
import { getEndpointField, LocalStorageKeys, isAssistantsEndpoint } from 'librechat-data-provider';
import {
getEndpointField,
LocalStorageKeys,
isAssistantsEndpoint,
getDefaultParamsEndpoint,
} from 'librechat-data-provider';
import type { TEndpointsConfig, EModelEndpoint, TConversation } from 'librechat-data-provider';
import type { AssistantListItem, NewConversationParams } from '~/common';
import useAssistantListMap from '~/hooks/Assistants/useAssistantListMap';
@ -84,11 +89,13 @@ export default function useAddedResponse() {
}
const models = modelsConfig?.[defaultEndpoint ?? ''] ?? [];
const defaultParamsEndpoint = getDefaultParamsEndpoint(endpointsConfig, defaultEndpoint);
newConversation = buildDefaultConvo({
conversation: newConversation,
lastConversationSetup: preset as TConversation,
endpoint: defaultEndpoint ?? ('' as EModelEndpoint),
models,
defaultParamsEndpoint,
});
if (preset?.title != null && preset.title !== '') {

View file

@ -13,6 +13,7 @@ import {
parseCompactConvo,
replaceSpecialVars,
isAssistantsEndpoint,
getDefaultParamsEndpoint,
} from 'librechat-data-provider';
import type {
TMessage,
@ -96,6 +97,8 @@ export default function useChatFunctions({
) => {
setShowStopButton(false);
resetLatestMultiMessage();
text = text.trim();
if (!!isSubmitting || text === '') {
return;
}
@ -133,7 +136,6 @@ export default function useChatFunctions({
// construct the query message
// this is not a real messageId, it is used as placeholder before real messageId returned
text = text.trim();
const intermediateId = overrideUserMessageId ?? v4();
parentMessageId = parentMessageId ?? latestMessage?.messageId ?? Constants.NO_PARENT;
@ -173,12 +175,14 @@ export default function useChatFunctions({
const startupConfig = queryClient.getQueryData<TStartupConfig>([QueryKeys.startupConfig]);
const endpointType = getEndpointField(endpointsConfig, endpoint, 'type');
const iconURL = conversation?.iconURL;
const defaultParamsEndpoint = getDefaultParamsEndpoint(endpointsConfig, endpoint);
/** This becomes part of the `endpointOption` */
const convo = parseCompactConvo({
endpoint: endpoint as EndpointSchemaKey,
endpointType: endpointType as EndpointSchemaKey,
conversation: conversation ?? {},
defaultParamsEndpoint,
});
const { modelDisplayLabel } = endpointsConfig?.[endpoint ?? ''] ?? {};

View file

@ -1,5 +1,5 @@
import { excludedKeys } from 'librechat-data-provider';
import { useGetModelsQuery } from 'librechat-data-provider/react-query';
import { excludedKeys, getDefaultParamsEndpoint } from 'librechat-data-provider';
import type {
TEndpointsConfig,
TModelsConfig,
@ -47,11 +47,14 @@ const useDefaultConvo = () => {
}
}
const defaultParamsEndpoint = getDefaultParamsEndpoint(endpointsConfig, endpoint);
const defaultConvo = buildDefaultConvo({
conversation: conversation as TConversation,
endpoint,
lastConversationSetup: preset as TConversation,
models,
defaultParamsEndpoint,
});
if (!cleanOutput) {

View file

@ -106,6 +106,9 @@ export default function useExportConversation({
// TEXT
const textPart = content[ContentTypes.TEXT];
const text = typeof textPart === 'string' ? textPart : (textPart?.value ?? '');
if (text.trim().length === 0) {
return [];
}
return [sender, text];
}

View file

@ -1,7 +1,12 @@
import { useRecoilValue } from 'recoil';
import { useCallback, useRef, useEffect } from 'react';
import { useGetModelsQuery } from 'librechat-data-provider/react-query';
import { getEndpointField, LocalStorageKeys, isAssistantsEndpoint } from 'librechat-data-provider';
import {
getEndpointField,
LocalStorageKeys,
isAssistantsEndpoint,
getDefaultParamsEndpoint,
} from 'librechat-data-provider';
import type {
TEndpointsConfig,
EModelEndpoint,
@ -117,11 +122,13 @@ const useGenerateConvo = ({
}
const models = modelsConfig?.[defaultEndpoint ?? ''] ?? [];
const defaultParamsEndpoint = getDefaultParamsEndpoint(endpointsConfig, defaultEndpoint);
conversation = buildDefaultConvo({
conversation,
lastConversationSetup: preset as TConversation,
endpoint: defaultEndpoint ?? ('' as EModelEndpoint),
models,
defaultParamsEndpoint,
});
if (preset?.title != null && preset.title !== '') {

View file

@ -2,7 +2,13 @@ import { useCallback } from 'react';
import { useSetRecoilState } from 'recoil';
import { useNavigate } from 'react-router-dom';
import { useQueryClient } from '@tanstack/react-query';
import { QueryKeys, Constants, dataService, getEndpointField } from 'librechat-data-provider';
import {
QueryKeys,
Constants,
dataService,
getEndpointField,
getDefaultParamsEndpoint,
} from 'librechat-data-provider';
import type {
TEndpointsConfig,
TStartupConfig,
@ -106,11 +112,13 @@ const useNavigateToConvo = (index = 0) => {
const models = modelsConfig?.[defaultEndpoint ?? ''] ?? [];
const defaultParamsEndpoint = getDefaultParamsEndpoint(endpointsConfig, defaultEndpoint);
convo = buildDefaultConvo({
models,
conversation,
endpoint: defaultEndpoint,
lastConversationSetup: conversation,
defaultParamsEndpoint,
});
}
clearAllConversations(true);

View file

@ -415,7 +415,7 @@ describe('useMCPSelect', () => {
});
});
it('should handle empty ephemeralAgent.mcp array correctly', async () => {
it('should clear mcpValues when ephemeralAgent.mcp is set to empty array', async () => {
// Create a shared wrapper
const { Wrapper, servers } = createWrapper(['initial-value']);
@ -437,19 +437,21 @@ describe('useMCPSelect', () => {
expect(result.current.mcpHook.mcpValues).toEqual(['initial-value']);
});
// Try to set empty array externally
// Set empty array externally (e.g., spec with no MCP servers)
act(() => {
result.current.setEphemeralAgent({
mcp: [],
});
});
// Values should remain unchanged since empty mcp array doesn't trigger update
// (due to the condition: ephemeralAgent?.mcp && ephemeralAgent.mcp.length > 0)
expect(result.current.mcpHook.mcpValues).toEqual(['initial-value']);
// Jotai atom should be cleared — an explicit empty mcp array means
// the spec (or reset) has no MCP servers, so the visual selection must clear
await waitFor(() => {
expect(result.current.mcpHook.mcpValues).toEqual([]);
});
});
it('should handle ephemeralAgent with clear mcp value', async () => {
it('should handle ephemeralAgent being reset to null', async () => {
// Create a shared wrapper
const { Wrapper, servers } = createWrapper(['server1', 'server2']);
@ -471,16 +473,15 @@ describe('useMCPSelect', () => {
expect(result.current.mcpHook.mcpValues).toEqual(['server1', 'server2']);
});
// Set ephemeralAgent with clear value
// Reset ephemeralAgent to null (simulating non-spec reset)
act(() => {
result.current.setEphemeralAgent({
mcp: [Constants.mcp_clear as string],
});
result.current.setEphemeralAgent(null);
});
// mcpValues should be cleared
// mcpValues should remain unchanged since null ephemeral agent
// doesn't trigger the sync effect (mcps.length === 0)
await waitFor(() => {
expect(result.current.mcpHook.mcpValues).toEqual([]);
expect(result.current.mcpHook.mcpValues).toEqual(['server1', 'server2']);
});
});
@ -590,6 +591,233 @@ describe('useMCPSelect', () => {
});
});
describe('Environment-Keyed Storage (storageContextKey)', () => {
it('should use storageContextKey as atom key for new conversations', async () => {
const { Wrapper, servers } = createWrapper(['server1', 'server2']);
const storageContextKey = '__defaults__';
// Hook A: new conversation with storageContextKey
const { result: resultA } = renderHook(
() => useMCPSelect({ conversationId: null, storageContextKey, servers }),
{ wrapper: Wrapper },
);
act(() => {
resultA.current.setMCPValues(['server1']);
});
await waitFor(() => {
expect(resultA.current.mcpValues).toEqual(['server1']);
});
// Hook B: new conversation WITHOUT storageContextKey (different environment)
const { result: resultB } = renderHook(
() => useMCPSelect({ conversationId: null, servers }),
{ wrapper: Wrapper },
);
// Should NOT see server1 since it's a different atom (NEW_CONVO vs __defaults__)
expect(resultB.current.mcpValues).toEqual([]);
});
it('should use conversationId as atom key for existing conversations even with storageContextKey', async () => {
const conversationId = 'existing-convo-123';
const { Wrapper, servers } = createWrapper(['server1', 'server2']);
const storageContextKey = '__defaults__';
const { result } = renderHook(
() => useMCPSelect({ conversationId, storageContextKey, servers }),
{ wrapper: Wrapper },
);
act(() => {
result.current.setMCPValues(['server1', 'server2']);
});
await waitFor(() => {
expect(result.current.mcpValues).toEqual(['server1', 'server2']);
});
// Verify timestamp was written to the conversation key, not the environment key
const convoKey = `${LocalStorageKeys.LAST_MCP_}${conversationId}`;
expect(setTimestamp).toHaveBeenCalledWith(convoKey);
});
it('should dual-write to environment key when storageContextKey is provided', async () => {
const { Wrapper, servers } = createWrapper(['server1', 'server2']);
const storageContextKey = '__defaults__';
const { result } = renderHook(
() => useMCPSelect({ conversationId: null, storageContextKey, servers }),
{ wrapper: Wrapper },
);
act(() => {
result.current.setMCPValues(['server1', 'server2']);
});
await waitFor(() => {
// Verify dual-write to environment key
const envKey = `${LocalStorageKeys.LAST_MCP_}${storageContextKey}`;
expect(localStorage.getItem(envKey)).toEqual(JSON.stringify(['server1', 'server2']));
expect(setTimestamp).toHaveBeenCalledWith(envKey);
});
});
it('should NOT dual-write when storageContextKey is undefined', async () => {
const conversationId = 'convo-no-specs';
const { Wrapper, servers } = createWrapper(['server1']);
const { result } = renderHook(() => useMCPSelect({ conversationId, servers }), {
wrapper: Wrapper,
});
act(() => {
result.current.setMCPValues(['server1']);
});
await waitFor(() => {
expect(result.current.mcpValues).toEqual(['server1']);
});
// Only the conversation-keyed timestamp should be set, no environment key
const envKey = `${LocalStorageKeys.LAST_MCP_}__defaults__`;
expect(localStorage.getItem(envKey)).toBeNull();
});
it('should isolate per-conversation state from environment defaults', async () => {
const { Wrapper, servers } = createWrapper(['server1', 'server2', 'server3']);
const storageContextKey = '__defaults__';
// Set environment defaults via new conversation
const { result: newConvoResult } = renderHook(
() => useMCPSelect({ conversationId: null, storageContextKey, servers }),
{ wrapper: Wrapper },
);
act(() => {
newConvoResult.current.setMCPValues(['server1', 'server2']);
});
await waitFor(() => {
expect(newConvoResult.current.mcpValues).toEqual(['server1', 'server2']);
});
// Existing conversation should have its own isolated state
const { result: existingResult } = renderHook(
() => useMCPSelect({ conversationId: 'existing-convo', storageContextKey, servers }),
{ wrapper: Wrapper },
);
// Should start empty (its own atom), not inherit from defaults
expect(existingResult.current.mcpValues).toEqual([]);
// Set different value for existing conversation
act(() => {
existingResult.current.setMCPValues(['server3']);
});
await waitFor(() => {
expect(existingResult.current.mcpValues).toEqual(['server3']);
});
// New conversation defaults should be unchanged
expect(newConvoResult.current.mcpValues).toEqual(['server1', 'server2']);
});
});
describe('Spec/Non-Spec Context Switching', () => {
it('should clear MCP when ephemeral agent switches to empty mcp (spec with no MCP)', async () => {
const { Wrapper, servers } = createWrapper(['server1', 'server2']);
const storageContextKey = '__defaults__';
const TestComponent = ({ ctxKey }: { ctxKey?: string }) => {
const mcpHook = useMCPSelect({ conversationId: null, storageContextKey: ctxKey, servers });
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(Constants.NEW_CONVO));
return { mcpHook, setEphemeralAgent };
};
// Start in non-spec context with some servers selected
const { result } = renderHook(() => TestComponent({ ctxKey: storageContextKey }), {
wrapper: Wrapper,
});
act(() => {
result.current.mcpHook.setMCPValues(['server1', 'server2']);
});
await waitFor(() => {
expect(result.current.mcpHook.mcpValues).toEqual(['server1', 'server2']);
});
// Simulate switching to a spec with no MCP — ephemeral agent gets mcp: []
act(() => {
result.current.setEphemeralAgent({ mcp: [] });
});
// MCP values should clear since the spec explicitly has no MCP servers
await waitFor(() => {
expect(result.current.mcpHook.mcpValues).toEqual([]);
});
});
it('should handle ephemeral agent with spec MCP servers syncing to Jotai atom', async () => {
const { Wrapper, servers } = createWrapper(['spec-server1', 'spec-server2']);
const TestComponent = () => {
const mcpHook = useMCPSelect({ conversationId: null, servers });
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(Constants.NEW_CONVO));
return { mcpHook, setEphemeralAgent };
};
const { result } = renderHook(() => TestComponent(), { wrapper: Wrapper });
// Simulate spec application setting ephemeral agent MCP
act(() => {
result.current.setEphemeralAgent({
mcp: ['spec-server1', 'spec-server2'],
execute_code: true,
});
});
await waitFor(() => {
expect(result.current.mcpHook.mcpValues).toEqual(['spec-server1', 'spec-server2']);
});
});
it('should handle null ephemeral agent reset (non-spec with specs configured)', async () => {
const { Wrapper, servers } = createWrapper(['server1', 'server2']);
const TestComponent = () => {
const mcpHook = useMCPSelect({ servers });
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(Constants.NEW_CONVO));
return { mcpHook, setEphemeralAgent };
};
const { result } = renderHook(() => TestComponent(), { wrapper: Wrapper });
// Set values from a spec
act(() => {
result.current.setEphemeralAgent({ mcp: ['server1', 'server2'] });
});
await waitFor(() => {
expect(result.current.mcpHook.mcpValues).toEqual(['server1', 'server2']);
});
// Reset ephemeral agent to null (switching to non-spec)
act(() => {
result.current.setEphemeralAgent(null);
});
// mcpValues should remain unchanged — null ephemeral agent doesn't trigger sync
// (BadgeRowContext will fill from localStorage defaults separately)
await waitFor(() => {
expect(result.current.mcpHook.mcpValues).toEqual(['server1', 'server2']);
});
});
});
describe('Memory Leak Prevention', () => {
it('should not leak memory on repeated updates', async () => {
const values = Array.from({ length: 100 }, (_, i) => `value-${i}`);

View file

@ -9,9 +9,11 @@ import { MCPServerDefinition } from './useMCPServerManager';
export function useMCPSelect({
conversationId,
storageContextKey,
servers,
}: {
conversationId?: string | null;
storageContextKey?: string;
servers: MCPServerDefinition[];
}) {
const key = conversationId ?? Constants.NEW_CONVO;
@ -19,47 +21,61 @@ export function useMCPSelect({
return new Set(servers?.map((s) => s.serverName));
}, [servers]);
/**
* For new conversations, key the MCP atom by environment (spec or defaults)
* so switching between spec non-spec gives each its own atom.
* For existing conversations, key by conversation ID for per-conversation isolation.
*/
const isNewConvo = key === Constants.NEW_CONVO;
const mcpAtomKey = isNewConvo && storageContextKey ? storageContextKey : key;
const [isPinned, setIsPinned] = useAtom(mcpPinnedAtom);
const [mcpValues, setMCPValuesRaw] = useAtom(mcpValuesAtomFamily(key));
const [mcpValues, setMCPValuesRaw] = useAtom(mcpValuesAtomFamily(mcpAtomKey));
const [ephemeralAgent, setEphemeralAgent] = useRecoilState(ephemeralAgentByConvoId(key));
// Sync Jotai state with ephemeral agent state
// Sync ephemeral agent MCP → Jotai atom (strip unconfigured servers)
useEffect(() => {
const mcps = ephemeralAgent?.mcp ?? [];
if (mcps.length === 1 && mcps[0] === Constants.mcp_clear) {
setMCPValuesRaw([]);
} else if (mcps.length > 0) {
// Strip out servers that are not available in the startup config
const mcps = ephemeralAgent?.mcp;
if (Array.isArray(mcps) && mcps.length > 0 && configuredServers.size > 0) {
const activeMcps = mcps.filter((mcp) => configuredServers.has(mcp));
setMCPValuesRaw(activeMcps);
}
}, [ephemeralAgent?.mcp, setMCPValuesRaw, configuredServers]);
useEffect(() => {
setEphemeralAgent((prev) => {
if (!isEqual(prev?.mcp, mcpValues)) {
return { ...(prev ?? {}), mcp: mcpValues };
if (!isEqual(activeMcps, mcpValues)) {
setMCPValuesRaw(activeMcps);
}
return prev;
});
}, [mcpValues, setEphemeralAgent]);
} else if (Array.isArray(mcps) && mcps.length === 0 && mcpValues.length > 0) {
// Ephemeral agent explicitly has empty MCP (e.g., spec with no MCP servers) — clear atom
setMCPValuesRaw([]);
}
}, [ephemeralAgent?.mcp, setMCPValuesRaw, configuredServers, mcpValues]);
// Write timestamp when MCP values change
useEffect(() => {
const mcpStorageKey = `${LocalStorageKeys.LAST_MCP_}${key}`;
const mcpStorageKey = `${LocalStorageKeys.LAST_MCP_}${mcpAtomKey}`;
if (mcpValues.length > 0) {
setTimestamp(mcpStorageKey);
}
}, [mcpValues, key]);
}, [mcpValues, mcpAtomKey]);
/** Stable memoized setter */
/** Stable memoized setter with dual-write to environment key */
const setMCPValues = useCallback(
(value: string[]) => {
if (!Array.isArray(value)) {
return;
}
setMCPValuesRaw(value);
setEphemeralAgent((prev) => {
if (!isEqual(prev?.mcp, value)) {
return { ...(prev ?? {}), mcp: value };
}
return prev;
});
// Dual-write to environment key for new conversation defaults
if (storageContextKey) {
const envKey = `${LocalStorageKeys.LAST_MCP_}${storageContextKey}`;
localStorage.setItem(envKey, JSON.stringify(value));
setTimestamp(envKey);
}
},
[setMCPValuesRaw],
[setMCPValuesRaw, setEphemeralAgent, storageContextKey],
);
return {

View file

@ -28,7 +28,10 @@ export interface MCPServerDefinition {
// The init states (isInitializing, isCancellable, etc.) are stored in the global Jotai atom
type PollIntervals = Record<string, NodeJS.Timeout | null>;
export function useMCPServerManager({ conversationId }: { conversationId?: string | null } = {}) {
export function useMCPServerManager({
conversationId,
storageContextKey,
}: { conversationId?: string | null; storageContextKey?: string } = {}) {
const localize = useLocalize();
const queryClient = useQueryClient();
const { showToast } = useToastContext();
@ -73,6 +76,7 @@ export function useMCPServerManager({ conversationId }: { conversationId?: strin
const { mcpValues, setMCPValues, isPinned, setIsPinned } = useMCPSelect({
conversationId,
storageContextKey,
servers: selectableServers,
});
const mcpValuesRef = useRef(mcpValues);
@ -429,33 +433,6 @@ export function useMCPServerManager({ conversationId }: { conversationId?: strin
[startupConfig?.interface?.mcpServers?.placeholder, localize],
);
const batchToggleServers = useCallback(
(serverNames: string[]) => {
const connectedServers: string[] = [];
const disconnectedServers: string[] = [];
serverNames.forEach((serverName) => {
if (isInitializing(serverName)) {
return;
}
const serverStatus = connectionStatus?.[serverName];
if (serverStatus?.connectionState === 'connected') {
connectedServers.push(serverName);
} else {
disconnectedServers.push(serverName);
}
});
setMCPValues(connectedServers);
disconnectedServers.forEach((serverName) => {
initializeServer(serverName);
});
},
[connectionStatus, setMCPValues, initializeServer, isInitializing],
);
const toggleServerSelection = useCallback(
(serverName: string) => {
if (isInitializing(serverName)) {
@ -469,15 +446,10 @@ export function useMCPServerManager({ conversationId }: { conversationId?: strin
const filteredValues = currentValues.filter((name) => name !== serverName);
setMCPValues(filteredValues);
} else {
const serverStatus = connectionStatus?.[serverName];
if (serverStatus?.connectionState === 'connected') {
setMCPValues([...currentValues, serverName]);
} else {
initializeServer(serverName);
}
setMCPValues([...currentValues, serverName]);
}
},
[mcpValues, setMCPValues, connectionStatus, initializeServer, isInitializing],
[mcpValues, setMCPValues, isInitializing],
);
const handleConfigSave = useCallback(
@ -673,7 +645,6 @@ export function useMCPServerManager({ conversationId }: { conversationId?: strin
isPinned,
setIsPinned,
placeholderText,
batchToggleServers,
toggleServerSelection,
localize,

View file

@ -0,0 +1,328 @@
import React from 'react';
import { renderHook, act, waitFor } from '@testing-library/react';
import { LocalStorageKeys, Tools } from 'librechat-data-provider';
import { RecoilRoot, useRecoilValue, useSetRecoilState } from 'recoil';
import { ephemeralAgentByConvoId } from '~/store';
import { useToolToggle } from '../useToolToggle';
/**
* Tests for useToolToggle the hook responsible for toggling tool badges
* (code execution, web search, file search, artifacts) and persisting state.
*
* Desired behaviors:
* - User toggles persist to per-conversation localStorage
* - In non-spec mode with specs configured (storageContextKey = '__defaults__'),
* toggles ALSO persist to the defaults key so future new conversations inherit them
* - In spec mode (storageContextKey = undefined), toggles only persist per-conversation
* - The hook reflects the current ephemeral agent state
*/
// Mock data-provider auth query
jest.mock('~/data-provider', () => ({
useVerifyAgentToolAuth: jest.fn().mockReturnValue({
data: { authenticated: true },
}),
}));
// Mock timestamps (track calls without actual localStorage timestamp logic)
jest.mock('~/utils/timestamps', () => ({
setTimestamp: jest.fn(),
}));
// Mock useLocalStorageAlt (isPinned state — not relevant to our behavior tests)
jest.mock('~/hooks/useLocalStorageAlt', () => jest.fn(() => [false, jest.fn()]));
const Wrapper: React.FC<{ children: React.ReactNode }> = ({ children }) => (
<RecoilRoot>{children}</RecoilRoot>
);
describe('useToolToggle', () => {
beforeEach(() => {
jest.clearAllMocks();
localStorage.clear();
});
// ─── Dual-Write Behavior ───────────────────────────────────────────
describe('non-spec mode: dual-write to defaults key', () => {
const storageContextKey = '__defaults__';
it('should write to both conversation key and defaults key when user toggles a tool', () => {
const conversationId = 'convo-123';
const { result } = renderHook(
() =>
useToolToggle({
conversationId,
storageContextKey,
toolKey: Tools.execute_code,
localStorageKey: LocalStorageKeys.LAST_CODE_TOGGLE_,
isAuthenticated: true,
}),
{ wrapper: Wrapper },
);
act(() => {
result.current.handleChange({ value: true });
});
// Conversation key: per-conversation persistence
const convoKey = `${LocalStorageKeys.LAST_CODE_TOGGLE_}${conversationId}`;
// Defaults key: persists for future new conversations
const defaultsKey = `${LocalStorageKeys.LAST_CODE_TOGGLE_}${storageContextKey}`;
// Sync effect writes to conversation key
expect(localStorage.getItem(convoKey)).toBe(JSON.stringify(true));
// handleChange dual-writes to defaults key
expect(localStorage.getItem(defaultsKey)).toBe(JSON.stringify(true));
});
it('should persist false values to defaults key when user disables a tool', () => {
const { result } = renderHook(
() =>
useToolToggle({
conversationId: 'convo-456',
storageContextKey,
toolKey: Tools.web_search,
localStorageKey: LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_,
isAuthenticated: true,
}),
{ wrapper: Wrapper },
);
// Enable then disable
act(() => {
result.current.handleChange({ value: true });
});
act(() => {
result.current.handleChange({ value: false });
});
const defaultsKey = `${LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_}${storageContextKey}`;
expect(localStorage.getItem(defaultsKey)).toBe(JSON.stringify(false));
});
});
describe('spec mode: no dual-write', () => {
it('should only write to conversation key, not to any defaults key', () => {
const conversationId = 'spec-convo-789';
const { result } = renderHook(
() =>
useToolToggle({
conversationId,
storageContextKey: undefined, // spec mode
toolKey: Tools.execute_code,
localStorageKey: LocalStorageKeys.LAST_CODE_TOGGLE_,
isAuthenticated: true,
}),
{ wrapper: Wrapper },
);
act(() => {
result.current.handleChange({ value: true });
});
// Conversation key should have the value
const convoKey = `${LocalStorageKeys.LAST_CODE_TOGGLE_}${conversationId}`;
expect(localStorage.getItem(convoKey)).toBe(JSON.stringify(true));
// Defaults key should NOT have a value
const defaultsKey = `${LocalStorageKeys.LAST_CODE_TOGGLE_}__defaults__`;
expect(localStorage.getItem(defaultsKey)).toBeNull();
});
});
// ─── Per-Conversation Isolation ────────────────────────────────────
describe('per-conversation isolation', () => {
it('should maintain separate toggle state per conversation', () => {
const TestComponent = ({ conversationId }: { conversationId: string }) => {
const toggle = useToolToggle({
conversationId,
toolKey: Tools.execute_code,
localStorageKey: LocalStorageKeys.LAST_CODE_TOGGLE_,
isAuthenticated: true,
});
const ephemeralAgent = useRecoilValue(ephemeralAgentByConvoId(conversationId));
return { toggle, ephemeralAgent };
};
// Conversation A: enable code
const { result: resultA } = renderHook(() => TestComponent({ conversationId: 'convo-A' }), {
wrapper: Wrapper,
});
act(() => {
resultA.current.toggle.handleChange({ value: true });
});
// Conversation B: disable code
const { result: resultB } = renderHook(() => TestComponent({ conversationId: 'convo-B' }), {
wrapper: Wrapper,
});
act(() => {
resultB.current.toggle.handleChange({ value: false });
});
// Each conversation has its own value in localStorage
expect(localStorage.getItem(`${LocalStorageKeys.LAST_CODE_TOGGLE_}convo-A`)).toBe('true');
expect(localStorage.getItem(`${LocalStorageKeys.LAST_CODE_TOGGLE_}convo-B`)).toBe('false');
});
});
// ─── Ephemeral Agent Sync ──────────────────────────────────────────
describe('ephemeral agent reflects toggle state', () => {
it('should update ephemeral agent when user toggles a tool', async () => {
const conversationId = 'convo-sync-test';
const TestComponent = () => {
const toggle = useToolToggle({
conversationId,
toolKey: Tools.execute_code,
localStorageKey: LocalStorageKeys.LAST_CODE_TOGGLE_,
isAuthenticated: true,
});
const ephemeralAgent = useRecoilValue(ephemeralAgentByConvoId(conversationId));
return { toggle, ephemeralAgent };
};
const { result } = renderHook(() => TestComponent(), { wrapper: Wrapper });
act(() => {
result.current.toggle.handleChange({ value: true });
});
await waitFor(() => {
expect(result.current.ephemeralAgent?.execute_code).toBe(true);
});
act(() => {
result.current.toggle.handleChange({ value: false });
});
await waitFor(() => {
expect(result.current.ephemeralAgent?.execute_code).toBe(false);
});
});
it('should reflect external ephemeral agent changes in toolValue', async () => {
const conversationId = 'convo-external';
const TestComponent = () => {
const toggle = useToolToggle({
conversationId,
toolKey: Tools.web_search,
localStorageKey: LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_,
isAuthenticated: true,
});
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(conversationId));
return { toggle, setEphemeralAgent };
};
const { result } = renderHook(() => TestComponent(), { wrapper: Wrapper });
// External update (e.g., from applyModelSpecEphemeralAgent)
act(() => {
result.current.setEphemeralAgent({ web_search: true, execute_code: false });
});
await waitFor(() => {
expect(result.current.toggle.toolValue).toBe(true);
expect(result.current.toggle.isToolEnabled).toBe(true);
});
});
it('should sync externally-set ephemeral agent values to localStorage', async () => {
const conversationId = 'convo-sync-ls';
const TestComponent = () => {
const toggle = useToolToggle({
conversationId,
toolKey: Tools.file_search,
localStorageKey: LocalStorageKeys.LAST_FILE_SEARCH_TOGGLE_,
isAuthenticated: true,
});
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(conversationId));
return { toggle, setEphemeralAgent };
};
const { result } = renderHook(() => TestComponent(), { wrapper: Wrapper });
// Simulate applyModelSpecEphemeralAgent setting a value
act(() => {
result.current.setEphemeralAgent({ file_search: true });
});
// The sync effect should write to conversation-keyed localStorage
await waitFor(() => {
const storageKey = `${LocalStorageKeys.LAST_FILE_SEARCH_TOGGLE_}${conversationId}`;
expect(localStorage.getItem(storageKey)).toBe(JSON.stringify(true));
});
});
});
// ─── isToolEnabled computation ─────────────────────────────────────
describe('isToolEnabled computation', () => {
it('should return false when tool is not set', () => {
const { result } = renderHook(
() =>
useToolToggle({
conversationId: 'convo-1',
toolKey: Tools.execute_code,
localStorageKey: LocalStorageKeys.LAST_CODE_TOGGLE_,
isAuthenticated: true,
}),
{ wrapper: Wrapper },
);
expect(result.current.isToolEnabled).toBe(false);
});
it('should treat non-empty string as enabled (artifacts)', async () => {
const conversationId = 'convo-artifacts';
const TestComponent = () => {
const toggle = useToolToggle({
conversationId,
toolKey: 'artifacts',
localStorageKey: LocalStorageKeys.LAST_ARTIFACTS_TOGGLE_,
isAuthenticated: true,
});
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(conversationId));
return { toggle, setEphemeralAgent };
};
const { result } = renderHook(() => TestComponent(), { wrapper: Wrapper });
act(() => {
result.current.setEphemeralAgent({ artifacts: 'default' });
});
await waitFor(() => {
expect(result.current.toggle.isToolEnabled).toBe(true);
});
});
it('should treat empty string as disabled (artifacts off)', async () => {
const conversationId = 'convo-no-artifacts';
const TestComponent = () => {
const toggle = useToolToggle({
conversationId,
toolKey: 'artifacts',
localStorageKey: LocalStorageKeys.LAST_ARTIFACTS_TOGGLE_,
isAuthenticated: true,
});
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(conversationId));
return { toggle, setEphemeralAgent };
};
const { result } = renderHook(() => TestComponent(), { wrapper: Wrapper });
act(() => {
result.current.setEphemeralAgent({ artifacts: '' });
});
await waitFor(() => {
expect(result.current.toggle.isToolEnabled).toBe(false);
});
});
});
});

View file

@ -13,6 +13,7 @@ type ToolValue = boolean | string;
interface UseToolToggleOptions {
conversationId?: string | null;
storageContextKey?: string;
toolKey: string;
localStorageKey: LocalStorageKeys;
isAuthenticated?: boolean;
@ -26,6 +27,7 @@ interface UseToolToggleOptions {
export function useToolToggle({
conversationId,
storageContextKey,
toolKey: _toolKey,
localStorageKey,
isAuthenticated: externalIsAuthenticated,
@ -93,8 +95,22 @@ export function useToolToggle({
...(prev || {}),
[toolKey]: value,
}));
// Dual-write to environment key for new conversation defaults
if (storageContextKey) {
const envKey = `${localStorageKey}${storageContextKey}`;
localStorage.setItem(envKey, JSON.stringify(value));
setTimestamp(envKey);
}
},
[setIsDialogOpen, isAuthenticated, setEphemeralAgent, toolKey],
[
setIsDialogOpen,
isAuthenticated,
setEphemeralAgent,
toolKey,
storageContextKey,
localStorageKey,
],
);
const debouncedChange = useMemo(

View file

@ -14,6 +14,7 @@ import {
LocalStorageKeys,
isEphemeralAgentId,
isAssistantsEndpoint,
getDefaultParamsEndpoint,
} from 'librechat-data-provider';
import type {
TPreset,
@ -191,11 +192,13 @@ const useNewConvo = (index = 0) => {
}
const models = modelsConfig?.[defaultEndpoint] ?? [];
const defaultParamsEndpoint = getDefaultParamsEndpoint(endpointsConfig, defaultEndpoint);
conversation = buildDefaultConvo({
conversation,
lastConversationSetup: activePreset as TConversation,
endpoint: defaultEndpoint,
models,
defaultParamsEndpoint,
});
}

View file

@ -859,11 +859,13 @@
"com_ui_create_api_key": "Create API Key",
"com_ui_create_assistant": "Create Assistant",
"com_ui_create_link": "Create link",
"com_ui_create_mcp_server": "Create MCP server",
"com_ui_create_memory": "Create Memory",
"com_ui_create_new_agent": "Create New Agent",
"com_ui_create_prompt": "Create Prompt",
"com_ui_create_prompt_page": "New Prompt Configuration Page",
"com_ui_created": "Created",
"com_ui_creating": "Creating...",
"com_ui_creating_image": "Creating image. May take a moment",
"com_ui_current": "Current",
"com_ui_currently_production": "Currently in production",
@ -904,6 +906,8 @@
"com_ui_delete_confirm_strong": "This will delete <strong>{{title}}</strong>",
"com_ui_delete_conversation": "Delete chat?",
"com_ui_delete_conversation_tooltip": "Delete conversation",
"com_ui_delete_mcp_server": "Delete MCP Server?",
"com_ui_delete_mcp_server_name": "Delete MCP server {{0}}",
"com_ui_delete_memory": "Delete Memory",
"com_ui_delete_not_allowed": "Delete operation is not allowed",
"com_ui_delete_preset": "Delete Preset?",
@ -916,6 +920,7 @@
"com_ui_delete_tool_confirm": "Are you sure you want to delete this tool?",
"com_ui_delete_tool_save_reminder": "Tool removed. Save the agent to apply changes.",
"com_ui_deleted": "Deleted",
"com_ui_deleting": "Deleting...",
"com_ui_deleting_file": "Deleting file...",
"com_ui_descending": "Desc",
"com_ui_description": "Description",
@ -1438,6 +1443,8 @@
"com_ui_unset": "Unset",
"com_ui_untitled": "Untitled",
"com_ui_update": "Update",
"com_ui_update_mcp_server": "Update MCP server",
"com_ui_updating": "Updating...",
"com_ui_upload": "Upload",
"com_ui_upload_agent_avatar": "Successfully updated agent avatar",
"com_ui_upload_agent_avatar_label": "Upload agent avatar image",

View file

@ -224,6 +224,7 @@
"com_endpoint_agent": "Aģents",
"com_endpoint_agent_placeholder": "Lūdzu, izvēlieties aģentu",
"com_endpoint_ai": "Mākslīgais intelekts",
"com_endpoint_anthropic_effort": "Kontrolē, cik lielu skaitļošanas piepūli piemēro Claude. Mazāka piepūle ietaupa tokenus un samazina ātrumu; lielāka piepūle nodrošina rūpīgākas atbildes. 'Max' ļauj veikt visdziļāko argumentāciju (tikai Opus 4.6).",
"com_endpoint_anthropic_maxoutputtokens": "Maksimālais atbildē ģenerējamo tokenu skaits. Norādiet zemāku vērtību īsākām atbildēm un augstāku vērtību garākām atbildēm. Piezīme: modeļi var apstāties pirms šī maksimālā skaita sasniegšanas.",
"com_endpoint_anthropic_prompt_cache": "Uzvednes kešatmiņa ļauj atkārtoti izmantot lielu kontekstu vai instrukcijas API izsaukumos, samazinot izmaksas un ābildes ātrumu.",
"com_endpoint_anthropic_temp": "Diapazons no 0 līdz 1. Analītiskiem/atbilžu variantiem izmantot temp vērtību tuvāk 0, bet radošiem un ģeneratīviem uzdevumiem — tuvāk 1. Iesakām mainīt šo vai Top P, bet ne abus.",
@ -265,6 +266,7 @@
"com_endpoint_default_with_num": "noklusējums: {{0}}",
"com_endpoint_disable_streaming": "Izslēgt atbilžu straumēšanu un saņemt visu atbildi uzreiz. Noderīgi tādiem modeļiem kā o3, kas pieprasa organizācijas pārbaudi straumēšanai.",
"com_endpoint_disable_streaming_label": "Atspējot straumēšanu",
"com_endpoint_effort": "Piepūle",
"com_endpoint_examples": "Iestatījumi",
"com_endpoint_export": "Eksportēt",
"com_endpoint_export_share": "Eksportēt/kopīgot",
@ -857,11 +859,13 @@
"com_ui_create_api_key": "Izveidot API atslēgu",
"com_ui_create_assistant": "Izveidot palīgu",
"com_ui_create_link": "Izveidot saiti",
"com_ui_create_mcp_server": "Izveidot MCP serveri",
"com_ui_create_memory": "Izveidot atmiņu",
"com_ui_create_new_agent": "Izveidot jaunu aģentu",
"com_ui_create_prompt": "Izveidot uzvedni",
"com_ui_create_prompt_page": "Jauna uzvedņu konfigurācijas lapa",
"com_ui_created": "Izveidots",
"com_ui_creating": "Notiek izveide...",
"com_ui_creating_image": "Attēla izveide. Var aizņemt brīdi.",
"com_ui_current": "Pašreizējais",
"com_ui_currently_production": "Pašlaik produkcijā",
@ -902,6 +906,8 @@
"com_ui_delete_confirm_strong": "Šis izdzēsīs <strong>{{title}}</strong>",
"com_ui_delete_conversation": "Dzēst sarunu?",
"com_ui_delete_conversation_tooltip": "Dzēst sarunu",
"com_ui_delete_mcp_server": "Vai dzēst MCP serveri?",
"com_ui_delete_mcp_server_name": "Dzēst MCP serveri {{0}}",
"com_ui_delete_memory": "Dzēst atmiņu",
"com_ui_delete_not_allowed": "Dzēšanas darbība nav atļauta",
"com_ui_delete_preset": "Vai dzēst iestatījumu?",
@ -914,6 +920,7 @@
"com_ui_delete_tool_confirm": "Vai tiešām vēlaties dzēst šo rīku?",
"com_ui_delete_tool_save_reminder": "Rīks noņemts. Saglabājiet aģentu, lai piemērotu izmaiņas.",
"com_ui_deleted": "Dzēsts",
"com_ui_deleting": "Dzēš...",
"com_ui_deleting_file": "Dzēšu failu...",
"com_ui_descending": "Dilstošs",
"com_ui_description": "Apraksts",
@ -1084,6 +1091,7 @@
"com_ui_manage": "Pārvaldīt",
"com_ui_marketplace": "Katalogs",
"com_ui_marketplace_allow_use": "Atļaut izmantot katalogu",
"com_ui_max": "Maksimums",
"com_ui_max_favorites_reached": "Sasniegts maksimālais piesprausto elementu skaits ({{0}}). Atvienojiet elementu, lai pievienotu citu.",
"com_ui_max_file_size": "PNG, JPG vai JPEG (maks. {{0}})",
"com_ui_max_tags": "Maksimālais atļautais skaits ir {{0}}, izmantojot jaunākās vērtības.",
@ -1437,6 +1445,8 @@
"com_ui_unset": "Neuzlikts",
"com_ui_untitled": "Bez nosaukuma",
"com_ui_update": "Atjauninājums",
"com_ui_update_mcp_server": "Atjaunināt MCP serveri",
"com_ui_updating": "Atjaunina...",
"com_ui_upload": "Augšupielādēt",
"com_ui_upload_agent_avatar": "Aģenta avatars veiksmīgi atjaunināts",
"com_ui_upload_agent_avatar_label": "Augšupielādēt aģenta avatāra attēlu",

Some files were not shown because too many files have changed in this diff Show more