Merge branch 'main' into fix/client-image-resize-threshold

This commit is contained in:
mattdaniell 2026-03-31 14:12:54 +10:30 committed by GitHub
commit 4b14978473
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
246 changed files with 15929 additions and 2880 deletions

3
.gitattributes vendored Normal file
View file

@ -0,0 +1,3 @@
# Force LF line endings for shell scripts and git hooks (required for cross-platform compatibility)
.husky/* text eol=lf
*.sh text eol=lf

View file

@ -97,6 +97,65 @@ jobs:
path: packages/api/dist
retention-days: 2
typecheck:
name: TypeScript type checks
needs: build
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- uses: actions/checkout@v4
- name: Use Node.js 20.19
uses: actions/setup-node@v4
with:
node-version: '20.19'
- name: Restore node_modules cache
id: cache-node-modules
uses: actions/cache@v4
with:
path: |
node_modules
api/node_modules
packages/api/node_modules
packages/data-provider/node_modules
packages/data-schemas/node_modules
key: node-modules-backend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
- name: Install dependencies
if: steps.cache-node-modules.outputs.cache-hit != 'true'
run: npm ci
- name: Download data-provider build
uses: actions/download-artifact@v4
with:
name: build-data-provider
path: packages/data-provider/dist
- name: Download data-schemas build
uses: actions/download-artifact@v4
with:
name: build-data-schemas
path: packages/data-schemas/dist
- name: Download api build
uses: actions/download-artifact@v4
with:
name: build-api
path: packages/api/dist
- name: Type check data-provider
run: npx tsc --noEmit -p packages/data-provider/tsconfig.json
- name: Type check data-schemas
run: npx tsc --noEmit -p packages/data-schemas/tsconfig.json
- name: Type check @librechat/api
run: npx tsc --noEmit -p packages/api/tsconfig.json
- name: Type check @librechat/client
run: npx tsc --noEmit -p packages/client/tsconfig.json
circular-deps:
name: Circular dependency checks
needs: build

View file

@ -29,6 +29,12 @@ The source code for `@librechat/agents` (major backend dependency, same team) is
## Code Style
### Naming and File Organization
- **Single-word file names** whenever possible (e.g., `permissions.ts`, `capabilities.ts`, `service.ts`).
- When multiple words are needed, prefer grouping related modules under a **single-word directory** rather than using multi-word file names (e.g., `admin/capabilities.ts` not `adminCapabilities.ts`).
- The directory already provides context — `app/service.ts` not `app/appConfigService.ts`.
### Structure and Clarity
- **Never-nesting**: early returns, flat code, minimal indentation. Break complex operations into well-named helpers.

View file

@ -32,7 +32,7 @@
</p>
<p align="center">
<a href="https://railway.com/deploy/b5k2mn?referralCode=HI9hWz">
<a href="https://railway.com/deploy/librechat-official?referralCode=HI9hWz&utm_medium=integration&utm_source=readme&utm_campaign=librechat">
<img src="https://railway.com/button.svg" alt="Deploy on Railway" height="30">
</a>
<a href="https://zeabur.com/templates/0X2ZY8">

View file

@ -1,4 +1,4 @@
<!-- Last synced with README.md: 2026-03-20 (e442984364db02163f3cc3ecb7b2ee5efba66fb9) -->
<!-- Last synced with README.md: 2026-03-28 (cae3888) -->
<p align="center">
<a href="https://librechat.ai">
@ -34,7 +34,7 @@
</p>
<p align="center">
<a href="https://railway.com/deploy/b5k2mn?referralCode=HI9hWz">
<a href="https://railway.com/deploy/librechat-official?referralCode=HI9hWz&utm_medium=integration&utm_source=readme&utm_campaign=librechat">
<img src="https://railway.com/button.svg" alt="Deploy on Railway" height="30">
</a>
<a href="https://zeabur.com/templates/0X2ZY8">

View file

@ -32,7 +32,6 @@ class BaseClient {
constructor(apiKey, options = {}) {
this.apiKey = apiKey;
this.sender = options.sender ?? 'AI';
this.contextStrategy = null;
this.currentDateString = new Date().toLocaleDateString('en-us', {
year: 'numeric',
month: 'long',

View file

@ -14,7 +14,6 @@ const {
buildImageToolContext,
buildWebSearchContext,
} = require('@librechat/api');
const { getMCPServersRegistry } = require('~/config');
const {
Tools,
Constants,
@ -39,12 +38,13 @@ const {
createGeminiImageTool,
createOpenAIImageTools,
} = require('../');
const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
const { createMCPTool, createMCPTools, resolveConfigServers } = require('~/server/services/MCP');
const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch');
const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { createMCPTool, createMCPTools } = require('~/server/services/MCP');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { getMCPServerTools } = require('~/server/services/Config');
const { getMCPServersRegistry } = require('~/config');
const { getRoleByName } = require('~/models');
/**
@ -256,6 +256,12 @@ const loadTools = async ({
const toolContextMap = {};
const requestedMCPTools = {};
/** Resolve config-source servers for the current user/tenant context */
let configServers;
if (tools.some((tool) => tool && mcpToolPattern.test(tool))) {
configServers = await resolveConfigServers(options.req);
}
for (const tool of tools) {
if (tool === Tools.execute_code) {
requestedTools[tool] = async () => {
@ -341,7 +347,7 @@ const loadTools = async ({
continue;
}
const serverConfig = serverName
? await getMCPServersRegistry().getServerConfig(serverName, user)
? await getMCPServersRegistry().getServerConfig(serverName, user, configServers)
: null;
if (!serverConfig) {
logger.warn(
@ -419,6 +425,7 @@ const loadTools = async ({
let index = -1;
const failedMCPServers = new Set();
const safeUser = createSafeUser(options.req?.user);
for (const [serverName, toolConfigs] of Object.entries(requestedMCPTools)) {
index++;
/** @type {LCAvailableTools} */
@ -433,6 +440,7 @@ const loadTools = async ({
signal,
user: safeUser,
userMCPAuthMap,
configServers,
res: options.res,
streamId: options.req?._resumableStreamId || null,
model: agent?.model ?? model,

View file

@ -123,9 +123,6 @@ function disposeClient(client) {
if (client.maxContextTokens) {
client.maxContextTokens = null;
}
if (client.contextStrategy) {
client.contextStrategy = null;
}
if (client.currentDateString) {
client.currentDateString = null;
}

View file

@ -1,5 +1,5 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const { logger, scopedCacheKey } = require('@librechat/data-schemas');
const { loadDefaultModels, loadConfigModels } = require('~/server/services/Config');
const { getLogStores } = require('~/cache');
@ -9,7 +9,8 @@ const { getLogStores } = require('~/cache');
*/
const getModelsConfig = async (req) => {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
let modelsConfig = await cache.get(CacheKeys.MODELS_CONFIG);
const cacheKey = scopedCacheKey(CacheKeys.MODELS_CONFIG);
let modelsConfig = await cache.get(cacheKey);
if (!modelsConfig) {
modelsConfig = await loadModels(req);
}
@ -24,7 +25,8 @@ const getModelsConfig = async (req) => {
*/
async function loadModels(req) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedModelsConfig = await cache.get(CacheKeys.MODELS_CONFIG);
const cacheKey = scopedCacheKey(CacheKeys.MODELS_CONFIG);
const cachedModelsConfig = await cache.get(cacheKey);
if (cachedModelsConfig) {
return cachedModelsConfig;
}
@ -33,7 +35,7 @@ async function loadModels(req) {
const modelConfig = { ...defaultModelsConfig, ...customModelsConfig };
await cache.set(CacheKeys.MODELS_CONFIG, modelConfig);
await cache.set(cacheKey, modelConfig);
return modelConfig;
}

View file

@ -1,5 +1,5 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const { logger, scopedCacheKey } = require('@librechat/data-schemas');
const { getToolkitKey, checkPluginAuth, filterUniquePlugins } = require('@librechat/api');
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
const { availableTools, toolkits } = require('~/app/clients/tools');
@ -9,13 +9,14 @@ const { getLogStores } = require('~/cache');
const getAvailablePluginsController = async (req, res) => {
try {
const cache = getLogStores(CacheKeys.TOOL_CACHE);
const cachedPlugins = await cache.get(CacheKeys.PLUGINS);
const pluginsCacheKey = scopedCacheKey(CacheKeys.PLUGINS);
const cachedPlugins = await cache.get(pluginsCacheKey);
if (cachedPlugins) {
res.status(200).json(cachedPlugins);
return;
}
const appConfig = await getAppConfig({ role: req.user?.role });
const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId });
/** @type {{ filteredTools: string[], includedTools: string[] }} */
const { filteredTools = [], includedTools = [] } = appConfig;
/** @type {import('@librechat/api').LCManifestTool[]} */
@ -37,7 +38,7 @@ const getAvailablePluginsController = async (req, res) => {
plugins = plugins.filter((plugin) => !filteredTools.includes(plugin.pluginKey));
}
await cache.set(CacheKeys.PLUGINS, plugins);
await cache.set(pluginsCacheKey, plugins);
res.status(200).json(plugins);
} catch (error) {
res.status(500).json({ message: error.message });
@ -64,9 +65,11 @@ const getAvailableTools = async (req, res) => {
return res.status(401).json({ message: 'Unauthorized' });
}
const cache = getLogStores(CacheKeys.TOOL_CACHE);
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
const toolsCacheKey = scopedCacheKey(CacheKeys.TOOLS);
const cachedToolsArray = await cache.get(toolsCacheKey);
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
const appConfig =
req.config ?? (await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }));
// Return early if we have cached tools
if (cachedToolsArray != null) {
@ -114,7 +117,7 @@ const getAvailableTools = async (req, res) => {
}
const finalTools = filterUniquePlugins(toolsOutput);
await cache.set(CacheKeys.TOOLS, finalTools);
await cache.set(toolsCacheKey, finalTools);
res.status(200).json(finalTools);
} catch (error) {

View file

@ -8,6 +8,7 @@ jest.mock('@librechat/data-schemas', () => ({
error: jest.fn(),
warn: jest.fn(),
},
scopedCacheKey: jest.fn((key) => key),
}));
jest.mock('~/server/services/Config', () => ({

View file

@ -26,7 +26,7 @@ const { getLogStores } = require('~/cache');
const db = require('~/models');
const getUserController = async (req, res) => {
const appConfig = await getAppConfig({ role: req.user?.role });
const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId });
/** @type {IUser} */
const userData = req.user.toObject != null ? req.user.toObject() : { ...req.user };
/**
@ -165,7 +165,7 @@ const deleteUserMcpServers = async (userId) => {
};
const updateUserPluginsController = async (req, res) => {
const appConfig = await getAppConfig({ role: req.user?.role });
const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId });
const { user } = req;
const { pluginKey, action, auth, isEntityTool } = req.body;
try {

View file

@ -50,6 +50,7 @@ const {
const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { createContextHandlers } = require('~/app/clients/prompts');
const { resolveConfigServers } = require('~/server/services/MCP');
const { getMCPServerTools } = require('~/server/services/Config');
const BaseClient = require('~/app/clients/BaseClient');
const { getMCPManager } = require('~/config');
@ -377,6 +378,9 @@ class AgentClient extends BaseClient {
*/
const ephemeralAgent = this.options.req.body.ephemeralAgent;
const mcpManager = getMCPManager();
const configServers = await resolveConfigServers(this.options.req);
await Promise.all(
allAgents.map(({ agent, agentId }) =>
applyContextToAgent({
@ -384,6 +388,7 @@ class AgentClient extends BaseClient {
agentId,
logger,
mcpManager,
configServers,
sharedRunContext,
ephemeralAgent: agentId === this.options.agent.id ? ephemeralAgent : undefined,
}),

View file

@ -22,6 +22,10 @@ jest.mock('~/server/services/Config', () => ({
getMCPServerTools: jest.fn(),
}));
jest.mock('~/server/services/MCP', () => ({
resolveConfigServers: jest.fn().mockResolvedValue({}),
}));
jest.mock('~/models', () => ({
getAgent: jest.fn(),
getRoleByName: jest.fn(),
@ -1315,7 +1319,7 @@ describe('AgentClient - titleConvo', () => {
});
// Verify formatInstructionsForContext was called with correct server names
expect(mockFormatInstructions).toHaveBeenCalledWith(['server1', 'server2']);
expect(mockFormatInstructions).toHaveBeenCalledWith(['server1', 'server2'], {});
// Verify the instructions do NOT contain [object Promise]
expect(client.options.agent.instructions).not.toContain('[object Promise]');
@ -1355,10 +1359,10 @@ describe('AgentClient - titleConvo', () => {
});
// Verify formatInstructionsForContext was called with ephemeral server names
expect(mockFormatInstructions).toHaveBeenCalledWith([
'ephemeral-server1',
'ephemeral-server2',
]);
expect(mockFormatInstructions).toHaveBeenCalledWith(
['ephemeral-server1', 'ephemeral-server2'],
{},
);
// Verify no [object Promise] in instructions
expect(client.options.agent.instructions).not.toContain('[object Promise]');

View file

@ -14,6 +14,7 @@ const {
isMCPInspectionFailedError,
} = require('@librechat/api');
const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider');
const { resolveConfigServers, resolveAllMcpConfigs } = require('~/server/services/MCP');
const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config');
const { getMCPManager, getMCPServersRegistry } = require('~/config');
@ -57,7 +58,7 @@ function handleMCPError(error, res) {
}
/**
* Get all MCP tools available to the user
* Get all MCP tools available to the user.
*/
const getMCPTools = async (req, res) => {
try {
@ -67,10 +68,10 @@ const getMCPTools = async (req, res) => {
return res.status(401).json({ message: 'Unauthorized' });
}
const mcpConfig = await getMCPServersRegistry().getAllServerConfigs(userId);
const configuredServers = mcpConfig ? Object.keys(mcpConfig) : [];
const mcpConfig = await resolveAllMcpConfigs(userId, req.user);
const configuredServers = Object.keys(mcpConfig);
if (!mcpConfig || Object.keys(mcpConfig).length == 0) {
if (!configuredServers.length) {
return res.status(200).json({ servers: {} });
}
@ -115,14 +116,11 @@ const getMCPTools = async (req, res) => {
try {
const serverTools = serverToolsMap.get(serverName);
// Get server config once
const serverConfig = mcpConfig[serverName];
const rawServerConfig = await getMCPServersRegistry().getServerConfig(serverName, userId);
// Initialize server object with all server-level data
const server = {
name: serverName,
icon: rawServerConfig?.iconPath || '',
icon: serverConfig?.iconPath || '',
authenticated: true,
authConfig: [],
tools: [],
@ -183,7 +181,7 @@ const getMCPServersList = async (req, res) => {
return res.status(401).json({ message: 'Unauthorized' });
}
const serverConfigs = await getMCPServersRegistry().getAllServerConfigs(userId);
const serverConfigs = await resolveAllMcpConfigs(userId, req.user);
return res.json(redactAllServerSecrets(serverConfigs));
} catch (error) {
logger.error('[getMCPServersList]', error);
@ -237,7 +235,12 @@ const getMCPServerById = async (req, res) => {
if (!serverName) {
return res.status(400).json({ message: 'Server name is required' });
}
const parsedConfig = await getMCPServersRegistry().getServerConfig(serverName, userId);
const configServers = await resolveConfigServers(req);
const parsedConfig = await getMCPServersRegistry().getServerConfig(
serverName,
userId,
configServers,
);
if (!parsedConfig) {
return res.status(404).json({ message: 'MCP server not found' });

View file

@ -8,8 +8,8 @@ const express = require('express');
const passport = require('passport');
const compression = require('compression');
const cookieParser = require('cookie-parser');
const { logger } = require('@librechat/data-schemas');
const mongoSanitize = require('express-mongo-sanitize');
const { logger, runAsSystem } = require('@librechat/data-schemas');
const {
isEnabled,
apiNotFound,
@ -21,6 +21,7 @@ const {
createStreamServices,
initializeFileStorage,
updateInterfacePermissions,
preAuthTenantMiddleware,
} = require('@librechat/api');
const { connectDb, indexSync } = require('~/db');
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
@ -59,11 +60,20 @@ const startServer = async () => {
app.disable('x-powered-by');
app.set('trust proxy', trusted_proxy);
await seedDatabase();
const appConfig = await getAppConfig();
if (isEnabled(process.env.TENANT_ISOLATION_STRICT)) {
logger.warn(
'[Security] TENANT_ISOLATION_STRICT is active. Ensure your reverse proxy strips or sets ' +
'the X-Tenant-Id header — untrusted clients must not be able to set it directly.',
);
}
await runAsSystem(seedDatabase);
const appConfig = await getAppConfig({ baseOnly: true });
initializeFileStorage(appConfig);
await performStartupChecks(appConfig);
await updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions });
await runAsSystem(async () => {
await performStartupChecks(appConfig);
await updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions });
});
const indexPath = path.join(appConfig.paths.dist, 'index.html');
let indexHTML = fs.readFileSync(indexPath, 'utf8');
@ -137,10 +147,15 @@ const startServer = async () => {
/* Per-request capability cache — must be registered before any route that calls hasCapability */
app.use(capabilityContextMiddleware);
app.use('/oauth', routes.oauth);
/* Pre-auth tenant context for unauthenticated routes that need tenant scoping.
* The reverse proxy / auth gateway sets `X-Tenant-Id` header for multi-tenant deployments. */
app.use('/oauth', preAuthTenantMiddleware, routes.oauth);
/* API Endpoints */
app.use('/api/auth', routes.auth);
app.use('/api/auth', preAuthTenantMiddleware, routes.auth);
app.use('/api/admin', routes.adminAuth);
app.use('/api/admin/config', routes.adminConfig);
app.use('/api/admin/groups', routes.adminGroups);
app.use('/api/admin/roles', routes.adminRoles);
app.use('/api/actions', routes.actions);
app.use('/api/keys', routes.keys);
app.use('/api/api-keys', routes.apiKeys);
@ -154,11 +169,11 @@ const startServer = async () => {
app.use('/api/endpoints', routes.endpoints);
app.use('/api/balance', routes.balance);
app.use('/api/models', routes.models);
app.use('/api/config', routes.config);
app.use('/api/config', preAuthTenantMiddleware, routes.config);
app.use('/api/assistants', routes.assistants);
app.use('/api/files', await routes.files.initialize());
app.use('/images/', createValidateImageRequest(appConfig.secureImageLinks), routes.staticRoute);
app.use('/api/share', routes.share);
app.use('/api/share', preAuthTenantMiddleware, routes.share);
app.use('/api/roles', routes.roles);
app.use('/api/agents', routes.agents);
app.use('/api/banner', routes.banner);
@ -204,8 +219,10 @@ const startServer = async () => {
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
}
await initializeMCPs();
await initializeOAuthReconnectManager();
await runAsSystem(async () => {
await initializeMCPs();
await initializeOAuthReconnectManager();
});
await checkMigrations();
// Configure stream services (auto-detects Redis from USE_REDIS env var)

View file

@ -0,0 +1,116 @@
/**
* Integration test: verifies that requireJwtAuth chains tenantContextMiddleware
* after successful passport authentication, so ALS tenant context is set for
* all downstream middleware and route handlers.
*
* requireJwtAuth must chain tenantContextMiddleware after passport populates
* req.user (not at global app.use() scope where req.user is undefined).
* If the chaining is removed, these tests fail.
*/
const { getTenantId } = require('@librechat/data-schemas');
// ── Mocks ──────────────────────────────────────────────────────────────
let mockPassportError = null;
jest.mock('passport', () => ({
authenticate: jest.fn(() => {
return (req, _res, done) => {
if (mockPassportError) {
return done(mockPassportError);
}
if (req._mockUser) {
req.user = req._mockUser;
}
done();
};
}),
}));
// Mock @librechat/api — the real tenantContextMiddleware is TS and cannot be
// required directly from CJS tests. This thin wrapper mirrors the real logic
// (read req.user.tenantId, call tenantStorage.run) using the same data-schemas
// primitives. The real implementation is covered by packages/api tenant.spec.ts.
jest.mock('@librechat/api', () => {
const { tenantStorage } = require('@librechat/data-schemas');
return {
isEnabled: jest.fn(() => false),
tenantContextMiddleware: (req, res, next) => {
const tenantId = req.user?.tenantId;
if (!tenantId) {
return next();
}
return tenantStorage.run({ tenantId }, async () => next());
},
};
});
// ── Helpers ─────────────────────────────────────────────────────────────
const requireJwtAuth = require('../requireJwtAuth');
function mockReq(user) {
return { headers: {}, _mockUser: user };
}
function mockRes() {
return { status: jest.fn().mockReturnThis(), json: jest.fn().mockReturnThis() };
}
/** Runs requireJwtAuth and returns the tenantId observed inside next(). */
function runAuth(user) {
return new Promise((resolve) => {
const req = mockReq(user);
const res = mockRes();
requireJwtAuth(req, res, () => {
resolve(getTenantId());
});
});
}
// ── Tests ──────────────────────────────────────────────────────────────
describe('requireJwtAuth tenant context chaining', () => {
afterEach(() => {
mockPassportError = null;
});
it('forwards passport errors to next() without entering tenant middleware', async () => {
mockPassportError = new Error('JWT signature invalid');
const req = mockReq(undefined);
const res = mockRes();
const err = await new Promise((resolve) => {
requireJwtAuth(req, res, (e) => resolve(e));
});
expect(err).toBeInstanceOf(Error);
expect(err.message).toBe('JWT signature invalid');
expect(getTenantId()).toBeUndefined();
});
it('sets ALS tenant context after passport auth succeeds', async () => {
const tenantId = await runAuth({ tenantId: 'tenant-abc', role: 'user' });
expect(tenantId).toBe('tenant-abc');
});
it('ALS tenant context is NOT set when user has no tenantId', async () => {
const tenantId = await runAuth({ role: 'user' });
expect(tenantId).toBeUndefined();
});
it('ALS tenant context is NOT set when user is undefined', async () => {
const tenantId = await runAuth(undefined);
expect(tenantId).toBeUndefined();
});
it('concurrent requests get isolated tenant contexts', async () => {
const results = await Promise.all(
['tenant-1', 'tenant-2', 'tenant-3'].map((tid) => runAuth({ tenantId: tid, role: 'user' })),
);
expect(results).toEqual(['tenant-1', 'tenant-2', 'tenant-3']);
});
it('ALS context is not set at top-level scope (outside any request)', () => {
expect(getTenantId()).toBeUndefined();
});
});

View file

@ -18,6 +18,7 @@ const checkDomainAllowed = async (req, res, next) => {
const email = req?.user?.email;
const appConfig = await getAppConfig({
role: req?.user?.role,
tenantId: req?.user?.tenantId,
});
if (email && !isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {

View file

@ -4,7 +4,9 @@ const { getAppConfig } = require('~/server/services/Config');
const configMiddleware = async (req, res, next) => {
try {
const userRole = req.user?.role;
req.config = await getAppConfig({ role: userRole });
const userId = req.user?.id;
const tenantId = req.user?.tenantId;
req.config = await getAppConfig({ role: userRole, userId, tenantId });
next();
} catch (error) {

View file

@ -1,9 +1,10 @@
const cookies = require('cookie');
const passport = require('passport');
const { isEnabled } = require('@librechat/api');
const { isEnabled, tenantContextMiddleware } = require('@librechat/api');
// This middleware does not require authentication,
// but if the user is authenticated, it will set the user object.
// but if the user is authenticated, it will set the user object
// and establish tenant ALS context.
const optionalJwtAuth = (req, res, next) => {
const cookieHeader = req.headers.cookie;
const tokenProvider = cookieHeader ? cookies.parse(cookieHeader).token_provider : null;
@ -13,6 +14,7 @@ const optionalJwtAuth = (req, res, next) => {
}
if (user) {
req.user = user;
return tenantContextMiddleware(req, res, next);
}
next();
};

View file

@ -1,20 +1,29 @@
const cookies = require('cookie');
const passport = require('passport');
const { isEnabled } = require('@librechat/api');
const { isEnabled, tenantContextMiddleware } = require('@librechat/api');
/**
* Custom Middleware to handle JWT authentication, with support for OpenID token reuse
* Switches between JWT and OpenID authentication based on cookies and environment settings
* Custom Middleware to handle JWT authentication, with support for OpenID token reuse.
* Switches between JWT and OpenID authentication based on cookies and environment settings.
*
* After successful authentication (req.user populated), automatically chains into
* `tenantContextMiddleware` to propagate `req.user.tenantId` into AsyncLocalStorage
* for downstream Mongoose tenant isolation.
*/
const requireJwtAuth = (req, res, next) => {
const cookieHeader = req.headers.cookie;
const tokenProvider = cookieHeader ? cookies.parse(cookieHeader).token_provider : null;
if (tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
return passport.authenticate('openidJwt', { session: false })(req, res, next);
}
const strategy =
tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS) ? 'openidJwt' : 'jwt';
return passport.authenticate('jwt', { session: false })(req, res, next);
passport.authenticate(strategy, { session: false })(req, res, (err) => {
if (err) {
return next(err);
}
// req.user is now populated by passport — set up tenant ALS context
tenantContextMiddleware(req, res, next);
});
};
module.exports = requireJwtAuth;

View file

@ -18,6 +18,7 @@ const mockRegistryInstance = {
getServerConfig: jest.fn(),
getOAuthServers: jest.fn(),
getAllServerConfigs: jest.fn(),
ensureConfigServers: jest.fn().mockResolvedValue({}),
addServer: jest.fn(),
updateServer: jest.fn(),
removeServer: jest.fn(),
@ -58,6 +59,7 @@ jest.mock('@librechat/api', () => {
});
jest.mock('@librechat/data-schemas', () => ({
getTenantId: jest.fn(),
logger: {
debug: jest.fn(),
info: jest.fn(),
@ -93,14 +95,18 @@ jest.mock('~/server/services/Config', () => ({
getCachedTools: jest.fn(),
getMCPServerTools: jest.fn(),
loadCustomConfig: jest.fn(),
getAppConfig: jest.fn().mockResolvedValue({ mcpConfig: {} }),
}));
jest.mock('~/server/services/Config/mcp', () => ({
updateMCPServerTools: jest.fn(),
}));
const mockResolveAllMcpConfigs = jest.fn().mockResolvedValue({});
jest.mock('~/server/services/MCP', () => ({
getMCPSetupData: jest.fn(),
resolveConfigServers: jest.fn().mockResolvedValue({}),
resolveAllMcpConfigs: (...args) => mockResolveAllMcpConfigs(...args),
getServerConnectionStatus: jest.fn(),
}));
@ -579,6 +585,112 @@ describe('MCP Routes', () => {
);
});
it('should use oauthHeaders from flow state when present', async () => {
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({ status: 'PENDING' }),
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
};
const mockFlowState = {
serverName: 'test-server',
userId: 'test-user-id',
metadata: { toolFlowId: 'tool-flow-123' },
clientInfo: {},
codeVerifier: 'test-verifier',
oauthHeaders: { 'X-Custom-Auth': 'header-value' },
};
const mockTokens = { access_token: 'tok', refresh_token: 'ref' };
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
require('~/config').getOAuthReconnectionManager.mockReturnValue({
clearReconnection: jest.fn(),
});
require('~/config').getMCPManager.mockReturnValue({
getUserConnection: jest.fn().mockResolvedValue({
fetchTools: jest.fn().mockResolvedValue([]),
}),
});
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
getCachedTools.mockResolvedValue({});
setCachedTools.mockResolvedValue();
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({ code: 'auth-code', state: flowId });
expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith(
flowId,
'auth-code',
mockFlowManager,
{ 'X-Custom-Auth': 'header-value' },
);
expect(mockRegistryInstance.getServerConfig).not.toHaveBeenCalled();
});
it('should fall back to registry oauth_headers when flow state lacks them', async () => {
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({ status: 'PENDING' }),
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
};
const mockFlowState = {
serverName: 'test-server',
userId: 'test-user-id',
metadata: { toolFlowId: 'tool-flow-123' },
clientInfo: {},
codeVerifier: 'test-verifier',
};
const mockTokens = { access_token: 'tok', refresh_token: 'ref' };
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
mockRegistryInstance.getServerConfig.mockResolvedValue({
oauth_headers: { 'X-Registry-Header': 'from-registry' },
});
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
require('~/config').getOAuthReconnectionManager.mockReturnValue({
clearReconnection: jest.fn(),
});
require('~/config').getMCPManager.mockReturnValue({
getUserConnection: jest.fn().mockResolvedValue({
fetchTools: jest.fn().mockResolvedValue([]),
}),
});
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
getCachedTools.mockResolvedValue({});
setCachedTools.mockResolvedValue();
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({ code: 'auth-code', state: flowId });
expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith(
flowId,
'auth-code',
mockFlowManager,
{ 'X-Registry-Header': 'from-registry' },
);
expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith(
'test-server',
'test-user-id',
undefined,
);
});
it('should redirect to error page when callback processing fails', async () => {
MCPOAuthHandler.getFlowState.mockRejectedValue(new Error('Callback error'));
const flowId = 'test-user-id:test-server';
@ -1350,19 +1462,10 @@ describe('MCP Routes', () => {
},
});
expect(getMCPSetupData).toHaveBeenCalledWith('test-user-id');
expect(getMCPSetupData).toHaveBeenCalledWith('test-user-id', expect.any(Object));
expect(getServerConnectionStatus).toHaveBeenCalledTimes(2);
});
it('should return 404 when MCP config is not found', async () => {
getMCPSetupData.mockRejectedValue(new Error('MCP config not found'));
const response = await request(app).get('/api/mcp/connection/status');
expect(response.status).toBe(404);
expect(response.body).toEqual({ error: 'MCP config not found' });
});
it('should return 500 when connection status check fails', async () => {
getMCPSetupData.mockRejectedValue(new Error('Database error'));
@ -1437,15 +1540,6 @@ describe('MCP Routes', () => {
});
});
it('should return 404 when MCP config is not found', async () => {
getMCPSetupData.mockRejectedValue(new Error('MCP config not found'));
const response = await request(app).get('/api/mcp/connection/status/test-server');
expect(response.status).toBe(404);
expect(response.body).toEqual({ error: 'MCP config not found' });
});
it('should return 500 when connection status check fails', async () => {
getMCPSetupData.mockRejectedValue(new Error('Database connection failed'));
@ -1704,7 +1798,7 @@ describe('MCP Routes', () => {
},
};
mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockServerConfigs);
mockResolveAllMcpConfigs.mockResolvedValue(mockServerConfigs);
const response = await request(app).get('/api/mcp/servers');
@ -1721,11 +1815,14 @@ describe('MCP Routes', () => {
});
expect(response.body['server-1'].headers).toBeUndefined();
expect(response.body['server-2'].headers).toBeUndefined();
expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith('test-user-id');
expect(mockResolveAllMcpConfigs).toHaveBeenCalledWith(
'test-user-id',
expect.objectContaining({ id: 'test-user-id' }),
);
});
it('should return empty object when no servers are configured', async () => {
mockRegistryInstance.getAllServerConfigs.mockResolvedValue({});
mockResolveAllMcpConfigs.mockResolvedValue({});
const response = await request(app).get('/api/mcp/servers');
@ -1749,7 +1846,7 @@ describe('MCP Routes', () => {
});
it('should return 500 when server config retrieval fails', async () => {
mockRegistryInstance.getAllServerConfigs.mockRejectedValue(new Error('Database error'));
mockResolveAllMcpConfigs.mockRejectedValue(new Error('Database error'));
const response = await request(app).get('/api/mcp/servers');
@ -1939,11 +2036,12 @@ describe('MCP Routes', () => {
expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith(
'test-server',
'test-user-id',
{},
);
});
it('should return 404 when server not found', async () => {
mockRegistryInstance.getServerConfig.mockResolvedValue(null);
mockRegistryInstance.getServerConfig.mockResolvedValue(undefined);
const response = await request(app).get('/api/mcp/servers/non-existent-server');

View file

@ -0,0 +1,40 @@
const express = require('express');
const { createAdminConfigHandlers } = require('@librechat/api');
const { SystemCapabilities } = require('@librechat/data-schemas');
const {
hasConfigCapability,
requireCapability,
} = require('~/server/middleware/roles/capabilities');
const { getAppConfig, invalidateConfigCaches } = require('~/server/services/Config');
const { requireJwtAuth } = require('~/server/middleware');
const db = require('~/models');
const router = express.Router();
const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN);
const handlers = createAdminConfigHandlers({
listAllConfigs: db.listAllConfigs,
findConfigByPrincipal: db.findConfigByPrincipal,
upsertConfig: db.upsertConfig,
patchConfigFields: db.patchConfigFields,
unsetConfigField: db.unsetConfigField,
deleteConfig: db.deleteConfig,
toggleConfigActive: db.toggleConfigActive,
hasConfigCapability,
getAppConfig,
invalidateConfigCaches,
});
router.use(requireJwtAuth, requireAdminAccess);
router.get('/', handlers.listConfigs);
router.get('/base', handlers.getBaseConfig);
router.get('/:principalType/:principalId', handlers.getConfig);
router.put('/:principalType/:principalId', handlers.upsertConfigOverrides);
router.patch('/:principalType/:principalId/fields', handlers.patchConfigField);
router.delete('/:principalType/:principalId/fields', handlers.deleteConfigField);
router.delete('/:principalType/:principalId', handlers.deleteConfigOverrides);
router.patch('/:principalType/:principalId/active', handlers.toggleConfig);
module.exports = router;

View file

@ -0,0 +1,41 @@
const express = require('express');
const { createAdminGroupsHandlers } = require('@librechat/api');
const { SystemCapabilities } = require('@librechat/data-schemas');
const { requireCapability } = require('~/server/middleware/roles/capabilities');
const { requireJwtAuth } = require('~/server/middleware');
const db = require('~/models');
const router = express.Router();
const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN);
const requireReadGroups = requireCapability(SystemCapabilities.READ_GROUPS);
const requireManageGroups = requireCapability(SystemCapabilities.MANAGE_GROUPS);
const handlers = createAdminGroupsHandlers({
listGroups: db.listGroups,
countGroups: db.countGroups,
findGroupById: db.findGroupById,
createGroup: db.createGroup,
updateGroupById: db.updateGroupById,
deleteGroup: db.deleteGroup,
addUserToGroup: db.addUserToGroup,
removeUserFromGroup: db.removeUserFromGroup,
removeMemberById: db.removeMemberById,
findUsers: db.findUsers,
deleteConfig: db.deleteConfig,
deleteAclEntries: db.deleteAclEntries,
deleteGrantsForPrincipal: db.deleteGrantsForPrincipal,
});
router.use(requireJwtAuth, requireAdminAccess);
router.get('/', requireReadGroups, handlers.listGroups);
router.post('/', requireManageGroups, handlers.createGroup);
router.get('/:id', requireReadGroups, handlers.getGroup);
router.patch('/:id', requireManageGroups, handlers.updateGroup);
router.delete('/:id', requireManageGroups, handlers.deleteGroup);
router.get('/:id/members', requireReadGroups, handlers.getGroupMembers);
router.post('/:id/members', requireManageGroups, handlers.addGroupMember);
router.delete('/:id/members/:userId', requireManageGroups, handlers.removeGroupMember);
module.exports = router;

View file

@ -0,0 +1,43 @@
const express = require('express');
const { createAdminRolesHandlers } = require('@librechat/api');
const { SystemCapabilities } = require('@librechat/data-schemas');
const { requireCapability } = require('~/server/middleware/roles/capabilities');
const { requireJwtAuth } = require('~/server/middleware');
const db = require('~/models');
const router = express.Router();
const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN);
const requireReadRoles = requireCapability(SystemCapabilities.READ_ROLES);
const requireManageRoles = requireCapability(SystemCapabilities.MANAGE_ROLES);
const handlers = createAdminRolesHandlers({
listRoles: db.listRoles,
countRoles: db.countRoles,
getRoleByName: db.getRoleByName,
createRoleByName: db.createRoleByName,
updateRoleByName: db.updateRoleByName,
updateAccessPermissions: db.updateAccessPermissions,
deleteRoleByName: db.deleteRoleByName,
findUser: db.findUser,
updateUser: db.updateUser,
updateUsersByRole: db.updateUsersByRole,
findUserIdsByRole: db.findUserIdsByRole,
updateUsersRoleByIds: db.updateUsersRoleByIds,
listUsersByRole: db.listUsersByRole,
countUsersByRole: db.countUsersByRole,
});
router.use(requireJwtAuth, requireAdminAccess);
router.get('/', requireReadRoles, handlers.listRoles);
router.post('/', requireManageRoles, handlers.createRole);
router.get('/:name', requireReadRoles, handlers.getRole);
router.patch('/:name', requireManageRoles, handlers.updateRole);
router.delete('/:name', requireManageRoles, handlers.deleteRole);
router.patch('/:name/permissions', requireManageRoles, handlers.updateRolePermissions);
router.get('/:name/members', requireReadRoles, handlers.getRoleMembers);
router.post('/:name/members', requireManageRoles, handlers.addRoleMember);
router.delete('/:name/members/:userId', requireManageRoles, handlers.removeRoleMember);
module.exports = router;

View file

@ -0,0 +1,186 @@
const express = require('express');
const request = require('supertest');
const mockGenerationJobManager = {
getJob: jest.fn(),
subscribe: jest.fn(),
getResumeState: jest.fn(),
abortJob: jest.fn(),
getActiveJobIdsForUser: jest.fn().mockResolvedValue([]),
};
jest.mock('@librechat/data-schemas', () => ({
...jest.requireActual('@librechat/data-schemas'),
logger: {
debug: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
info: jest.fn(),
},
}));
jest.mock('@librechat/api', () => ({
...jest.requireActual('@librechat/api'),
isEnabled: jest.fn().mockReturnValue(false),
GenerationJobManager: mockGenerationJobManager,
}));
jest.mock('~/models', () => ({
saveMessage: jest.fn(),
}));
let mockUserId = 'user-123';
let mockTenantId;
jest.mock('~/server/middleware', () => ({
uaParser: (req, res, next) => next(),
checkBan: (req, res, next) => next(),
requireJwtAuth: (req, res, next) => {
req.user = { id: mockUserId, tenantId: mockTenantId };
next();
},
messageIpLimiter: (req, res, next) => next(),
configMiddleware: (req, res, next) => next(),
messageUserLimiter: (req, res, next) => next(),
}));
jest.mock('~/server/routes/agents/chat', () => require('express').Router());
jest.mock('~/server/routes/agents/v1', () => ({
v1: require('express').Router(),
}));
jest.mock('~/server/routes/agents/openai', () => require('express').Router());
jest.mock('~/server/routes/agents/responses', () => require('express').Router());
const agentsRouter = require('../index');
const app = express();
app.use(express.json());
app.use('/agents', agentsRouter);
function mockSubscribeSuccess() {
mockGenerationJobManager.subscribe.mockImplementation((_streamId, _writeEvent, onDone) => {
process.nextTick(() => onDone({ done: true }));
return { unsubscribe: jest.fn() };
});
}
describe('SSE stream tenant isolation', () => {
beforeEach(() => {
jest.clearAllMocks();
mockUserId = 'user-123';
mockTenantId = undefined;
});
describe('GET /chat/stream/:streamId', () => {
it('returns 403 when a user from a different tenant accesses a stream', async () => {
mockUserId = 'user-456';
mockTenantId = 'tenant-b';
mockGenerationJobManager.getJob.mockResolvedValue({
metadata: { userId: 'user-456', tenantId: 'tenant-a' },
status: 'running',
});
const res = await request(app).get('/agents/chat/stream/stream-123');
expect(res.status).toBe(403);
expect(res.body.error).toBe('Unauthorized');
});
it('returns 404 when stream does not exist', async () => {
mockGenerationJobManager.getJob.mockResolvedValue(null);
const res = await request(app).get('/agents/chat/stream/nonexistent');
expect(res.status).toBe(404);
});
it('proceeds past tenant guard when tenant matches', async () => {
mockUserId = 'user-123';
mockTenantId = 'tenant-a';
mockSubscribeSuccess();
mockGenerationJobManager.getJob.mockResolvedValue({
metadata: { userId: 'user-123', tenantId: 'tenant-a' },
status: 'running',
});
const res = await request(app).get('/agents/chat/stream/stream-123');
expect(res.status).toBe(200);
expect(mockGenerationJobManager.subscribe).toHaveBeenCalledTimes(1);
});
it('proceeds past tenant guard when job has no tenantId (single-tenant mode)', async () => {
mockUserId = 'user-123';
mockTenantId = undefined;
mockSubscribeSuccess();
mockGenerationJobManager.getJob.mockResolvedValue({
metadata: { userId: 'user-123' },
status: 'running',
});
const res = await request(app).get('/agents/chat/stream/stream-123');
expect(res.status).toBe(200);
expect(mockGenerationJobManager.subscribe).toHaveBeenCalledTimes(1);
});
it('returns 403 when job has tenantId but user has no tenantId', async () => {
mockUserId = 'user-123';
mockTenantId = undefined;
mockGenerationJobManager.getJob.mockResolvedValue({
metadata: { userId: 'user-123', tenantId: 'some-tenant' },
status: 'running',
});
const res = await request(app).get('/agents/chat/stream/stream-123');
expect(res.status).toBe(403);
});
});
describe('GET /chat/status/:conversationId', () => {
it('returns 403 when tenant does not match', async () => {
mockUserId = 'user-123';
mockTenantId = 'tenant-b';
mockGenerationJobManager.getJob.mockResolvedValue({
metadata: { userId: 'user-123', tenantId: 'tenant-a' },
status: 'running',
});
const res = await request(app).get('/agents/chat/status/conv-123');
expect(res.status).toBe(403);
expect(res.body.error).toBe('Unauthorized');
});
it('returns status when tenant matches', async () => {
mockUserId = 'user-123';
mockTenantId = 'tenant-a';
mockGenerationJobManager.getJob.mockResolvedValue({
metadata: { userId: 'user-123', tenantId: 'tenant-a' },
status: 'running',
createdAt: Date.now(),
});
mockGenerationJobManager.getResumeState.mockResolvedValue(null);
const res = await request(app).get('/agents/chat/status/conv-123');
expect(res.status).toBe(200);
expect(res.body.active).toBe(true);
});
});
describe('POST /chat/abort', () => {
it('returns 403 when tenant does not match', async () => {
mockUserId = 'user-123';
mockTenantId = 'tenant-b';
mockGenerationJobManager.getJob.mockResolvedValue({
metadata: { userId: 'user-123', tenantId: 'tenant-a' },
status: 'running',
});
const res = await request(app).post('/agents/chat/abort').send({ streamId: 'stream-123' });
expect(res.status).toBe(403);
expect(res.body.error).toBe('Unauthorized');
});
});
});

View file

@ -17,6 +17,11 @@ const chat = require('./chat');
const { LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
/** Untenanted jobs (pre-multi-tenancy) remain accessible if the userId check passes. */
function hasTenantMismatch(job, user) {
return job.metadata?.tenantId != null && job.metadata.tenantId !== user.tenantId;
}
const router = express.Router();
/**
@ -67,6 +72,10 @@ router.get('/chat/stream/:streamId', async (req, res) => {
return res.status(403).json({ error: 'Unauthorized' });
}
if (hasTenantMismatch(job, req.user)) {
return res.status(403).json({ error: 'Unauthorized' });
}
res.setHeader('Content-Encoding', 'identity');
res.setHeader('Content-Type', 'text/event-stream');
res.setHeader('Cache-Control', 'no-cache, no-transform');
@ -150,7 +159,10 @@ router.get('/chat/stream/:streamId', async (req, res) => {
* @returns { activeJobIds: string[] }
*/
router.get('/chat/active', async (req, res) => {
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(req.user.id);
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(
req.user.id,
req.user.tenantId,
);
res.json({ activeJobIds });
});
@ -174,6 +186,10 @@ router.get('/chat/status/:conversationId', async (req, res) => {
return res.status(403).json({ error: 'Unauthorized' });
}
if (hasTenantMismatch(job, req.user)) {
return res.status(403).json({ error: 'Unauthorized' });
}
// Get resume state which contains aggregatedContent
// Avoid calling both getStreamInfo and getResumeState (both fetch content)
const resumeState = await GenerationJobManager.getResumeState(conversationId);
@ -213,7 +229,10 @@ router.post('/chat/abort', async (req, res) => {
// This handles the case where frontend sends "new" but job was created with a UUID
if (!job && userId) {
logger.debug(`[AgentStream] Job not found by ID, checking active jobs for user: ${userId}`);
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(userId);
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(
userId,
req.user.tenantId,
);
if (activeJobIds.length > 0) {
// Abort the most recent active job for this user
jobStreamId = activeJobIds[0];
@ -230,6 +249,10 @@ router.post('/chat/abort', async (req, res) => {
return res.status(403).json({ error: 'Unauthorized' });
}
if (hasTenantMismatch(job, req.user)) {
return res.status(403).json({ error: 'Unauthorized' });
}
logger.debug(`[AgentStream] Job found, aborting: ${jobStreamId}`);
const abortResult = await GenerationJobManager.abortJob(jobStreamId);
logger.debug(`[AgentStream] Job aborted successfully: ${jobStreamId}`, {

View file

@ -1,7 +1,7 @@
const express = require('express');
const { logger } = require('@librechat/data-schemas');
const { isEnabled, getBalanceConfig } = require('@librechat/api');
const { CacheKeys, defaultSocialLogins } = require('librechat-data-provider');
const { logger, getTenantId, scopedCacheKey } = require('@librechat/data-schemas');
const { getLdapConfig } = require('~/server/services/Config/ldap');
const { getAppConfig } = require('~/server/services/Config/app');
const { getLogStores } = require('~/cache');
@ -23,7 +23,8 @@ const openidReuseTokens = isEnabled(process.env.OPENID_REUSE_TOKENS);
router.get('/', async function (req, res) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG);
const cacheKey = scopedCacheKey(CacheKeys.STARTUP_CONFIG);
const cachedStartupConfig = await cache.get(cacheKey);
if (cachedStartupConfig) {
res.send(cachedStartupConfig);
return;
@ -37,7 +38,10 @@ router.get('/', async function (req, res) {
const ldap = getLdapConfig();
try {
const appConfig = await getAppConfig({ role: req.user?.role });
const appConfig = await getAppConfig({
role: req.user?.role,
tenantId: req.user?.tenantId || getTenantId(),
});
const isOpenIdEnabled =
!!process.env.OPENID_CLIENT_ID &&
@ -141,7 +145,7 @@ router.get('/', async function (req, res) {
payload.customFooter = process.env.CUSTOM_FOOTER;
}
await cache.set(CacheKeys.STARTUP_CONFIG, payload);
await cache.set(cacheKey, payload);
return res.status(200).send(payload);
} catch (err) {
logger.error('Error in startup config', err);

View file

@ -2,6 +2,9 @@ const accessPermissions = require('./accessPermissions');
const assistants = require('./assistants');
const categories = require('./categories');
const adminAuth = require('./admin/auth');
const adminConfig = require('./admin/config');
const adminGroups = require('./admin/groups');
const adminRoles = require('./admin/roles');
const endpoints = require('./endpoints');
const staticRoute = require('./static');
const messages = require('./messages');
@ -31,6 +34,9 @@ module.exports = {
mcp,
auth,
adminAuth,
adminConfig,
adminGroups,
adminRoles,
keys,
apiKeys,
user,

View file

@ -1,5 +1,5 @@
const { Router } = require('express');
const { logger } = require('@librechat/data-schemas');
const { logger, getTenantId } = require('@librechat/data-schemas');
const {
CacheKeys,
Constants,
@ -36,7 +36,11 @@ const {
getFlowStateManager,
getMCPManager,
} = require('~/config');
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
const {
getServerConnectionStatus,
resolveConfigServers,
getMCPSetupData,
} = require('~/server/services/MCP');
const { requireJwtAuth, canAccessMCPServerResource } = require('~/server/middleware');
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { updateMCPServerTools } = require('~/server/services/Config/mcp');
@ -101,7 +105,8 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async
return res.status(400).json({ error: 'Invalid flow state' });
}
const oauthHeaders = await getOAuthHeaders(serverName, userId);
const configServers = await resolveConfigServers(req);
const oauthHeaders = await getOAuthHeaders(serverName, userId, configServers);
const {
authorizationUrl,
flowId: oauthFlowId,
@ -233,7 +238,14 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
}
logger.debug('[MCP OAuth] Completing OAuth flow');
const oauthHeaders = await getOAuthHeaders(serverName, flowState.userId);
if (!flowState.oauthHeaders) {
logger.warn(
'[MCP OAuth] oauthHeaders absent from flow state — config-source server oauth_headers will be empty',
{ serverName, flowId },
);
}
const oauthHeaders =
flowState.oauthHeaders ?? (await getOAuthHeaders(serverName, flowState.userId));
const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager, oauthHeaders);
logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route');
@ -497,7 +509,12 @@ router.post(
logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`);
const mcpManager = getMCPManager();
const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id);
const configServers = await resolveConfigServers(req);
const serverConfig = await getMCPServersRegistry().getServerConfig(
serverName,
user.id,
configServers,
);
if (!serverConfig) {
return res.status(404).json({
error: `MCP server '${serverName}' not found in configuration`,
@ -522,6 +539,8 @@ router.post(
const result = await reinitMCPServer({
user,
serverName,
serverConfig,
configServers,
userMCPAuthMap,
});
@ -564,6 +583,7 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => {
const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData(
user.id,
{ role: user.role, tenantId: getTenantId() },
);
const connectionStatus = {};
@ -593,9 +613,6 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => {
connectionStatus,
});
} catch (error) {
if (error.message === 'MCP config not found') {
return res.status(404).json({ error: error.message });
}
logger.error('[MCP Connection Status] Failed to get connection status', error);
res.status(500).json({ error: 'Failed to get connection status' });
}
@ -616,6 +633,7 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) =>
const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData(
user.id,
{ role: user.role, tenantId: getTenantId() },
);
if (!mcpConfig[serverName]) {
@ -640,9 +658,6 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) =>
requiresOAuth: serverStatus.requiresOAuth,
});
} catch (error) {
if (error.message === 'MCP config not found') {
return res.status(404).json({ error: error.message });
}
logger.error(
`[MCP Per-Server Status] Failed to get connection status for ${req.params.serverName}`,
error,
@ -664,7 +679,12 @@ router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, a
return res.status(401).json({ error: 'User not authenticated' });
}
const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id);
const configServers = await resolveConfigServers(req);
const serverConfig = await getMCPServersRegistry().getServerConfig(
serverName,
user.id,
configServers,
);
if (!serverConfig) {
return res.status(404).json({
error: `MCP server '${serverName}' not found in configuration`,
@ -703,8 +723,12 @@ router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, a
}
});
async function getOAuthHeaders(serverName, userId) {
const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, userId);
async function getOAuthHeaders(serverName, userId, configServers) {
const serverConfig = await getMCPServersRegistry().getServerConfig(
serverName,
userId,
configServers,
);
return serverConfig?.oauth_headers ?? {};
}

View file

@ -13,6 +13,7 @@ const {
checkEmailConfig,
isEmailDomainAllowed,
shouldUseSecureCookie,
resolveAppConfigForUser,
} = require('@librechat/api');
const {
findUser,
@ -189,7 +190,7 @@ const registerUser = async (user, additionalData = {}) => {
let newUserId;
try {
const appConfig = await getAppConfig();
const appConfig = await getAppConfig({ baseOnly: true });
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
const errorMessage =
'The email address provided cannot be used. Please use a different email address.';
@ -255,19 +256,52 @@ const registerUser = async (user, additionalData = {}) => {
};
/**
* Request password reset
* Request password reset.
*
* Uses a two-phase domain check: fast-fail with the memory-cached base config
* (zero DB queries) to block globally denied domains before user lookup, then
* re-check with tenant-scoped config after user lookup so tenant-specific
* restrictions are enforced.
*
* Phase 1 (base check) returns an Error (HTTP 400) this intentionally reveals
* that the domain is globally blocked, but fires before any DB lookup so it
* cannot confirm user existence. Phase 2 (tenant check) returns the generic
* success message (HTTP 200) to prevent user-enumeration via status codes.
*
* @param {ServerRequest} req
*/
const requestPasswordReset = async (req) => {
const { email } = req.body;
const appConfig = await getAppConfig();
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
const baseConfig = await getAppConfig({ baseOnly: true });
if (!isEmailDomainAllowed(email, baseConfig?.registration?.allowedDomains)) {
logger.warn(
`[requestPasswordReset] Blocked - email domain not allowed [Email: ${email}] [IP: ${req.ip}]`,
);
const error = new Error(ErrorTypes.AUTH_FAILED);
error.code = ErrorTypes.AUTH_FAILED;
error.message = 'Email domain not allowed';
return error;
}
const user = await findUser({ email }, 'email _id');
const user = await findUser({ email }, 'email _id role tenantId');
let appConfig = baseConfig;
if (user?.tenantId) {
try {
appConfig = await resolveAppConfigForUser(getAppConfig, user);
} catch (err) {
logger.error('[requestPasswordReset] Failed to resolve tenant config, using base:', err);
}
}
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
logger.warn(
`[requestPasswordReset] Tenant config blocked domain [Email: ${email}] [IP: ${req.ip}]`,
);
return {
message: 'If an account with that email exists, a password reset link has been sent to it.',
};
}
const emailEnabled = checkEmailConfig();
logger.warn(`[requestPasswordReset] [Password reset request initiated] [Email: ${email}]`);

View file

@ -14,6 +14,7 @@ jest.mock('@librechat/api', () => ({
isEmailDomainAllowed: jest.fn(),
math: jest.fn((val, fallback) => (val ? Number(val) : fallback)),
shouldUseSecureCookie: jest.fn(() => false),
resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})),
}));
jest.mock('~/models', () => ({
findUser: jest.fn(),
@ -35,8 +36,14 @@ jest.mock('~/strategies/validators', () => ({ registerSchema: { parse: jest.fn()
jest.mock('~/server/services/Config', () => ({ getAppConfig: jest.fn() }));
jest.mock('~/server/utils', () => ({ sendEmail: jest.fn() }));
const { shouldUseSecureCookie } = require('@librechat/api');
const { setOpenIDAuthTokens } = require('./AuthService');
const {
shouldUseSecureCookie,
isEmailDomainAllowed,
resolveAppConfigForUser,
} = require('@librechat/api');
const { findUser } = require('~/models');
const { getAppConfig } = require('~/server/services/Config');
const { setOpenIDAuthTokens, requestPasswordReset } = require('./AuthService');
/** Helper to build a mock Express response */
function mockResponse() {
@ -267,3 +274,68 @@ describe('setOpenIDAuthTokens', () => {
});
});
});
describe('requestPasswordReset', () => {
beforeEach(() => {
jest.clearAllMocks();
isEmailDomainAllowed.mockReturnValue(true);
getAppConfig.mockResolvedValue({
registration: { allowedDomains: ['example.com'] },
});
resolveAppConfigForUser.mockResolvedValue({
registration: { allowedDomains: ['example.com'] },
});
});
it('should fast-fail with base config before DB lookup for blocked domains', async () => {
isEmailDomainAllowed.mockReturnValue(false);
const req = { body: { email: 'blocked@evil.com' }, ip: '127.0.0.1' };
const result = await requestPasswordReset(req);
expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true });
expect(findUser).not.toHaveBeenCalled();
expect(result).toBeInstanceOf(Error);
});
it('should call resolveAppConfigForUser for tenant user', async () => {
const user = {
_id: 'user-tenant',
email: 'user@example.com',
tenantId: 'tenant-x',
role: 'USER',
};
findUser.mockResolvedValue(user);
const req = { body: { email: 'user@example.com' }, ip: '127.0.0.1' };
await requestPasswordReset(req);
expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, user);
});
it('should reuse baseConfig for non-tenant user without calling resolveAppConfigForUser', async () => {
findUser.mockResolvedValue({ _id: 'user-no-tenant', email: 'user@example.com' });
const req = { body: { email: 'user@example.com' }, ip: '127.0.0.1' };
await requestPasswordReset(req);
expect(resolveAppConfigForUser).not.toHaveBeenCalled();
});
it('should return generic response when tenant config blocks the domain (non-enumerable)', async () => {
const user = {
_id: 'user-tenant',
email: 'user@example.com',
tenantId: 'tenant-x',
role: 'USER',
};
findUser.mockResolvedValue(user);
isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false);
const req = { body: { email: 'user@example.com' }, ip: '127.0.0.1' };
const result = await requestPasswordReset(req);
expect(result).not.toBeInstanceOf(Error);
expect(result.message).toContain('If an account with that email exists');
});
});

View file

@ -0,0 +1,139 @@
// ── Mocks ──────────────────────────────────────────────────────────────
const mockConfigStoreDelete = jest.fn().mockResolvedValue(true);
const mockClearAppConfigCache = jest.fn().mockResolvedValue(undefined);
const mockClearOverrideCache = jest.fn().mockResolvedValue(undefined);
jest.mock('~/cache/getLogStores', () => {
return jest.fn(() => ({
delete: mockConfigStoreDelete,
}));
});
jest.mock('~/server/services/start/tools', () => ({
loadAndFormatTools: jest.fn(() => ({})),
}));
jest.mock('../loadCustomConfig', () => jest.fn().mockResolvedValue({}));
jest.mock('@librechat/data-schemas', () => {
const actual = jest.requireActual('@librechat/data-schemas');
return { ...actual, AppService: jest.fn(() => ({ availableTools: {} })) };
});
jest.mock('~/models', () => ({
getApplicableConfigs: jest.fn().mockResolvedValue([]),
getUserPrincipals: jest.fn().mockResolvedValue([]),
}));
const mockInvalidateCachedTools = jest.fn().mockResolvedValue(undefined);
jest.mock('../getCachedTools', () => ({
setCachedTools: jest.fn().mockResolvedValue(undefined),
invalidateCachedTools: mockInvalidateCachedTools,
}));
const mockClearMcpConfigCache = jest.fn().mockResolvedValue(undefined);
jest.mock('@librechat/api', () => ({
createAppConfigService: jest.fn(() => ({
getAppConfig: jest.fn().mockResolvedValue({ availableTools: {} }),
clearAppConfigCache: mockClearAppConfigCache,
clearOverrideCache: mockClearOverrideCache,
})),
clearMcpConfigCache: mockClearMcpConfigCache,
}));
// ── Tests ──────────────────────────────────────────────────────────────
const { CacheKeys } = require('librechat-data-provider');
const { invalidateConfigCaches } = require('../app');
describe('invalidateConfigCaches', () => {
beforeEach(() => {
jest.clearAllMocks();
});
it('clears all four caches', async () => {
await invalidateConfigCaches();
expect(mockClearAppConfigCache).toHaveBeenCalledTimes(1);
expect(mockClearOverrideCache).toHaveBeenCalledTimes(1);
expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true });
expect(mockConfigStoreDelete).toHaveBeenCalledWith(CacheKeys.ENDPOINT_CONFIG);
});
it('passes tenantId through to clearOverrideCache', async () => {
await invalidateConfigCaches('tenant-a');
expect(mockClearOverrideCache).toHaveBeenCalledWith('tenant-a');
expect(mockClearAppConfigCache).toHaveBeenCalledTimes(1);
expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true });
});
it('does not throw when CONFIG_STORE.delete fails', async () => {
mockConfigStoreDelete.mockRejectedValueOnce(new Error('store not found'));
await expect(invalidateConfigCaches()).resolves.not.toThrow();
// Other caches should still have been invalidated
expect(mockClearAppConfigCache).toHaveBeenCalledTimes(1);
expect(mockClearOverrideCache).toHaveBeenCalledTimes(1);
expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true });
});
it('all operations run in parallel (not sequentially)', async () => {
const order = [];
mockClearAppConfigCache.mockImplementation(
() =>
new Promise((r) =>
setTimeout(() => {
order.push('base');
r();
}, 10),
),
);
mockClearOverrideCache.mockImplementation(
() =>
new Promise((r) =>
setTimeout(() => {
order.push('override');
r();
}, 10),
),
);
mockInvalidateCachedTools.mockImplementation(
() =>
new Promise((r) =>
setTimeout(() => {
order.push('tools');
r();
}, 10),
),
);
mockConfigStoreDelete.mockImplementation(
() =>
new Promise((r) =>
setTimeout(() => {
order.push('endpoint');
r();
}, 10),
),
);
await invalidateConfigCaches();
// All four should have been called (parallel execution via Promise.allSettled)
expect(order).toHaveLength(4);
expect(new Set(order)).toEqual(new Set(['base', 'override', 'tools', 'endpoint']));
});
it('resolves even when clearAppConfigCache throws (partial failure)', async () => {
mockClearAppConfigCache.mockRejectedValueOnce(new Error('cache connection lost'));
await expect(invalidateConfigCaches()).resolves.not.toThrow();
// Other caches should still have been invalidated despite the failure
expect(mockClearOverrideCache).toHaveBeenCalledTimes(1);
expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true });
});
});

View file

@ -1,12 +1,12 @@
const { CacheKeys } = require('librechat-data-provider');
const { logger, AppService } = require('@librechat/data-schemas');
const { AppService, logger, scopedCacheKey } = require('@librechat/data-schemas');
const { createAppConfigService, clearMcpConfigCache } = require('@librechat/api');
const { setCachedTools, invalidateCachedTools } = require('./getCachedTools');
const { loadAndFormatTools } = require('~/server/services/start/tools');
const loadCustomConfig = require('./loadCustomConfig');
const { setCachedTools } = require('./getCachedTools');
const getLogStores = require('~/cache/getLogStores');
const paths = require('~/config/paths');
const BASE_CONFIG_KEY = '_BASE_';
const db = require('~/models');
const loadBaseConfig = async () => {
/** @type {TCustomConfig} */
@ -20,65 +20,67 @@ const loadBaseConfig = async () => {
return AppService({ config, paths, systemTools });
};
const { getAppConfig, clearAppConfigCache, clearOverrideCache } = createAppConfigService({
loadBaseConfig,
setCachedTools,
getCache: getLogStores,
cacheKeys: CacheKeys,
getApplicableConfigs: db.getApplicableConfigs,
getUserPrincipals: db.getUserPrincipals,
});
/**
* Get the app configuration based on user context
* @param {Object} [options]
* @param {string} [options.role] - User role for role-based config
* @param {boolean} [options.refresh] - Force refresh the cache
* @returns {Promise<AppConfig>}
* Deletes ENDPOINT_CONFIG entries from CONFIG_STORE.
* Clears both the tenant-scoped key (if in tenant context) and the
* unscoped base key (populated by unauthenticated /api/endpoints calls).
* Other tenants' scoped keys are NOT actively cleared they expire
* via TTL. Config mutations in one tenant do not propagate immediately
* to other tenants' endpoint config caches.
*/
async function getAppConfig(options = {}) {
const { role, refresh } = options;
const cache = getLogStores(CacheKeys.APP_CONFIG);
const cacheKey = role ? role : BASE_CONFIG_KEY;
if (!refresh) {
const cached = await cache.get(cacheKey);
if (cached) {
return cached;
async function clearEndpointConfigCache() {
try {
const configStore = getLogStores(CacheKeys.CONFIG_STORE);
const scoped = scopedCacheKey(CacheKeys.ENDPOINT_CONFIG);
const keys = [scoped];
if (scoped !== CacheKeys.ENDPOINT_CONFIG) {
keys.push(CacheKeys.ENDPOINT_CONFIG);
}
await Promise.all(keys.map((k) => configStore.delete(k)));
} catch {
// CONFIG_STORE or ENDPOINT_CONFIG may not exist — not critical
}
let baseConfig = await cache.get(BASE_CONFIG_KEY);
if (!baseConfig) {
logger.info('[getAppConfig] App configuration not initialized. Initializing AppService...');
baseConfig = await loadBaseConfig();
if (!baseConfig) {
throw new Error('Failed to initialize app configuration through AppService.');
}
if (baseConfig.availableTools) {
await setCachedTools(baseConfig.availableTools);
}
await cache.set(BASE_CONFIG_KEY, baseConfig);
}
// For now, return the base config
// In the future, this is where we'll apply role-based modifications
if (role) {
// TODO: Apply role-based config modifications
// const roleConfig = await applyRoleBasedConfig(baseConfig, role);
// await cache.set(cacheKey, roleConfig);
// return roleConfig;
}
return baseConfig;
}
/**
* Clear the app configuration cache
* @returns {Promise<boolean>}
* Invalidate all config-related caches after an admin config mutation.
* Clears the base config, per-principal override caches, tool caches,
* the endpoints config cache, and the MCP config-source server cache.
* @param {string} [tenantId] - Optional tenant ID to scope override cache clearing.
*/
async function clearAppConfigCache() {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cacheKey = CacheKeys.APP_CONFIG;
return await cache.delete(cacheKey);
async function invalidateConfigCaches(tenantId) {
const results = await Promise.allSettled([
clearAppConfigCache(),
clearOverrideCache(tenantId),
invalidateCachedTools({ invalidateGlobal: true }),
clearEndpointConfigCache(),
clearMcpConfigCache(),
]);
const labels = [
'clearAppConfigCache',
'clearOverrideCache',
'invalidateCachedTools',
'clearEndpointConfigCache',
'clearMcpConfigCache',
];
for (let i = 0; i < results.length; i++) {
if (results[i].status === 'rejected') {
logger.error(`[invalidateConfigCaches] ${labels[i]} failed:`, results[i].reason);
}
}
}
module.exports = {
getAppConfig,
clearAppConfigCache,
invalidateConfigCaches,
};

View file

@ -1,3 +1,4 @@
const { scopedCacheKey } = require('@librechat/data-schemas');
const { loadCustomEndpointsConfig } = require('@librechat/api');
const {
CacheKeys,
@ -17,16 +18,18 @@ const { getAppConfig } = require('./app');
*/
async function getEndpointsConfig(req) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
const cacheKey = scopedCacheKey(CacheKeys.ENDPOINT_CONFIG);
const cachedEndpointsConfig = await cache.get(cacheKey);
if (cachedEndpointsConfig) {
if (cachedEndpointsConfig.gptPlugins) {
await cache.delete(CacheKeys.ENDPOINT_CONFIG);
await cache.delete(cacheKey);
} else {
return cachedEndpointsConfig;
}
}
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
const appConfig =
req.config ?? (await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }));
const defaultEndpointsConfig = await loadDefaultEndpointsConfig(appConfig);
const customEndpointsConfig = loadCustomEndpointsConfig(appConfig?.endpoints?.custom);
@ -111,7 +114,7 @@ async function getEndpointsConfig(req) {
const endpointsConfig = orderEndpointsConfig(mergedConfig);
await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig);
await cache.set(cacheKey, endpointsConfig);
return endpointsConfig;
}

View file

@ -12,7 +12,7 @@ const { getAppConfig } = require('./app');
* @param {ServerRequest} req - The Express request object.
*/
async function loadConfigModels(req) {
const appConfig = await getAppConfig({ role: req.user?.role });
const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId });
if (!appConfig) {
return {};
}

View file

@ -16,7 +16,8 @@ const { getAppConfig } = require('./app');
*/
async function loadDefaultModels(req) {
try {
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
const appConfig =
req.config ?? (await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }));
const vertexConfig = appConfig?.endpoints?.[EModelEndpoint.anthropic]?.vertexConfig;
const [openAI, anthropic, azureOpenAI, assistants, azureAssistants, google, bedrock] =

View file

@ -1,5 +1,5 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { logger, scopedCacheKey } = require('@librechat/data-schemas');
const { getCachedTools, setCachedTools } = require('./getCachedTools');
const { getLogStores } = require('~/cache');
@ -36,7 +36,7 @@ async function updateMCPServerTools({ userId, serverName, tools }) {
await setCachedTools(serverTools, { userId, serverName });
const cache = getLogStores(CacheKeys.TOOL_CACHE);
await cache.delete(CacheKeys.TOOLS);
await cache.delete(scopedCacheKey(CacheKeys.TOOLS));
logger.debug(
`[MCP Cache] Updated ${tools.length} tools for server ${serverName} (user: ${userId})`,
);
@ -48,7 +48,10 @@ async function updateMCPServerTools({ userId, serverName, tools }) {
}
/**
* Merges app-level tools with global tools
* Merges app-level tools with global tools.
* Only the current ALS-scoped key (base key in system/startup context) is cleared.
* Tenant-scoped TOOLS:tenantId keys are NOT actively invalidated they expire
* via TTL on the next tenant request. This matches clearEndpointConfigCache behavior.
* @param {import('@librechat/api').LCAvailableTools} appTools
* @returns {Promise<void>}
*/
@ -62,7 +65,7 @@ async function mergeAppTools(appTools) {
const mergedTools = { ...cachedTools, ...appTools };
await setCachedTools(mergedTools);
const cache = getLogStores(CacheKeys.TOOL_CACHE);
await cache.delete(CacheKeys.TOOLS);
await cache.delete(scopedCacheKey(CacheKeys.TOOLS));
logger.debug(`Merged ${count} app-level tools`);
} catch (error) {
logger.error('Failed to merge app-level tools:', error);

View file

@ -142,6 +142,7 @@ class STTService {
req.config ??
(await getAppConfig({
role: req?.user?.role,
tenantId: req?.user?.tenantId,
}));
const sttSchema = appConfig?.speech?.stt;
if (!sttSchema) {

View file

@ -297,6 +297,7 @@ class TTSService {
req.config ??
(await getAppConfig({
role: req.user?.role,
tenantId: req.user?.tenantId,
}));
try {
res.setHeader('Content-Type', 'audio/mpeg');
@ -365,6 +366,7 @@ class TTSService {
req.config ??
(await getAppConfig({
role: req.user?.role,
tenantId: req.user?.tenantId,
}));
const provider = this.getProvider(appConfig);
const ttsSchema = appConfig?.speech?.tts?.[provider];

View file

@ -17,6 +17,7 @@ async function getCustomConfigSpeech(req, res) {
try {
const appConfig = await getAppConfig({
role: req.user?.role,
tenantId: req.user?.tenantId,
});
if (!appConfig) {

View file

@ -18,6 +18,7 @@ async function getVoices(req, res) {
req.config ??
(await getAppConfig({
role: req.user?.role,
tenantId: req.user?.tenantId,
}));
const ttsSchema = appConfig?.speech?.tts;

View file

@ -1,3 +1,4 @@
const { scopedCacheKey } = require('@librechat/data-schemas');
const {
Time,
CacheKeys,
@ -67,6 +68,8 @@ function createChunkProcessor(user, messageId) {
}
const messageCache = getLogStores(CacheKeys.MESSAGES);
// Captured at creation time — must be called within an active request ALS scope
const cacheKey = scopedCacheKey(messageId);
/**
* @returns {Promise<{ text: string, isFinished: boolean }[] | string>}
@ -81,7 +84,7 @@ function createChunkProcessor(user, messageId) {
}
/** @type { string | { text: string; complete: boolean } } */
let message = await messageCache.get(messageId);
let message = await messageCache.get(cacheKey);
if (!message) {
message = await getMessage({ user, messageId });
}
@ -92,7 +95,7 @@ function createChunkProcessor(user, messageId) {
} else {
const text = message.content?.length > 0 ? parseTextParts(message.content) : message.text;
messageCache.set(
messageId,
cacheKey,
{
text,
complete: true,

View file

@ -1,5 +1,5 @@
const { tool } = require('@langchain/core/tools');
const { logger } = require('@librechat/data-schemas');
const { logger, getTenantId } = require('@librechat/data-schemas');
const {
Providers,
StepTypes,
@ -14,6 +14,7 @@ const {
normalizeJsonSchema,
GenerationJobManager,
resolveJsonSchemaRefs,
buildOAuthToolCallName,
} = require('@librechat/api');
const { Time, CacheKeys, Constants, isAssistantsEndpoint } = require('librechat-data-provider');
const {
@ -53,6 +54,53 @@ function evictStale(map, ttl) {
const unavailableMsg =
"This tool's MCP server is temporarily unavailable. Please try again shortly.";
/**
* Resolves config-source MCP servers from admin Config overrides for the current
* request context. Returns the parsed configs keyed by server name.
* @param {import('express').Request} req - Express request with user context
* @returns {Promise<Record<string, import('@librechat/api').ParsedServerConfig>>}
*/
async function resolveConfigServers(req) {
try {
const registry = getMCPServersRegistry();
const user = req?.user;
const appConfig = await getAppConfig({
role: user?.role,
tenantId: getTenantId(),
userId: user?.id,
});
return await registry.ensureConfigServers(appConfig?.mcpConfig || {});
} catch (error) {
logger.warn(
'[resolveConfigServers] Failed to resolve config servers, degrading to empty:',
error,
);
return {};
}
}
/**
* Resolves config-source servers and merges all server configs (YAML + config + user DB)
* for the given user context. Shared helper for controllers needing the full merged config.
* @param {string} userId
* @param {{ id?: string, role?: string }} [user]
* @returns {Promise<Record<string, import('@librechat/api').ParsedServerConfig>>}
*/
async function resolveAllMcpConfigs(userId, user) {
const registry = getMCPServersRegistry();
const appConfig = await getAppConfig({ role: user?.role, tenantId: getTenantId(), userId });
let configServers = {};
try {
configServers = await registry.ensureConfigServers(appConfig?.mcpConfig || {});
} catch (error) {
logger.warn(
'[resolveAllMcpConfigs] Config server resolution failed, continuing without:',
error,
);
}
return await registry.getAllServerConfigs(userId, configServers);
}
/**
* @param {string} toolName
* @param {string} serverName
@ -248,6 +296,7 @@ async function reconnectServer({
index,
signal,
serverName,
configServers,
userMCPAuthMap,
streamId = null,
}) {
@ -271,7 +320,7 @@ async function reconnectServer({
const stepId = 'step_oauth_login_' + serverName;
const toolCall = {
id: flowId,
name: serverName,
name: buildOAuthToolCallName(serverName),
type: 'tool_call_chunk',
};
@ -316,6 +365,7 @@ async function reconnectServer({
user,
signal,
serverName,
configServers,
oauthStart,
flowManager,
userMCPAuthMap,
@ -358,15 +408,14 @@ async function createMCPTools({
config,
provider,
serverName,
configServers,
userMCPAuthMap,
streamId = null,
}) {
// Early domain validation before reconnecting server (avoid wasted work on disallowed domains)
// Use getAppConfig() to support per-user/role domain restrictions
const serverConfig =
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id, configServers));
if (serverConfig?.url) {
const appConfig = await getAppConfig({ role: user?.role });
const appConfig = await getAppConfig({ role: user?.role, tenantId: user?.tenantId });
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
if (!isDomainAllowed) {
@ -381,6 +430,7 @@ async function createMCPTools({
index,
signal,
serverName,
configServers,
userMCPAuthMap,
streamId,
});
@ -400,6 +450,7 @@ async function createMCPTools({
user,
provider,
userMCPAuthMap,
configServers,
streamId,
availableTools: result.availableTools,
toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`,
@ -439,16 +490,15 @@ async function createMCPTool({
userMCPAuthMap,
availableTools,
config,
configServers,
streamId = null,
}) {
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
// Runtime domain validation: check if the server's domain is still allowed
// Use getAppConfig() to support per-user/role domain restrictions
const serverConfig =
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id, configServers));
if (serverConfig?.url) {
const appConfig = await getAppConfig({ role: user?.role });
const appConfig = await getAppConfig({ role: user?.role, tenantId: user?.tenantId });
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
if (!isDomainAllowed) {
@ -477,6 +527,7 @@ async function createMCPTool({
index,
signal,
serverName,
configServers,
userMCPAuthMap,
streamId,
});
@ -500,6 +551,7 @@ async function createMCPTool({
provider,
toolName,
serverName,
serverConfig,
toolDefinition,
streamId,
});
@ -509,13 +561,14 @@ function createToolInstance({
res,
toolName,
serverName,
serverConfig: capturedServerConfig,
toolDefinition,
provider: _provider,
provider: capturedProvider,
streamId = null,
}) {
/** @type {LCTool} */
const { description, parameters } = toolDefinition;
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
const isGoogle = capturedProvider === Providers.VERTEXAI || capturedProvider === Providers.GOOGLE;
let schema = parameters ? normalizeJsonSchema(resolveJsonSchemaRefs(parameters)) : null;
@ -544,7 +597,7 @@ function createToolInstance({
const flowManager = getFlowStateManager(flowsCache);
derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined;
const mcpManager = getMCPManager(userId);
const provider = (config?.metadata?.provider || _provider)?.toLowerCase();
const provider = (config?.metadata?.provider || capturedProvider)?.toLowerCase();
const { args: _args, stepId, ...toolCall } = config.toolCall ?? {};
const flowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`;
@ -576,6 +629,7 @@ function createToolInstance({
const result = await mcpManager.callTool({
serverName,
serverConfig: capturedServerConfig,
toolName,
provider,
toolArguments,
@ -643,30 +697,36 @@ function createToolInstance({
}
/**
* Get MCP setup data including config, connections, and OAuth servers
* Get MCP setup data including config, connections, and OAuth servers.
* Resolves config-source servers from admin Config overrides when tenant context is available.
* @param {string} userId - The user ID
* @param {{ role?: string, tenantId?: string }} [options] - Optional role/tenant context
* @returns {Object} Object containing mcpConfig, appConnections, userConnections, and oauthServers
*/
async function getMCPSetupData(userId) {
const mcpConfig = await getMCPServersRegistry().getAllServerConfigs(userId);
if (!mcpConfig) {
throw new Error('MCP config not found');
}
async function getMCPSetupData(userId, options = {}) {
const registry = getMCPServersRegistry();
const { role, tenantId } = options;
const appConfig = await getAppConfig({ role, tenantId, userId });
const configServers = await registry.ensureConfigServers(appConfig?.mcpConfig || {});
const mcpConfig = await registry.getAllServerConfigs(userId, configServers);
const mcpManager = getMCPManager(userId);
/** @type {Map<string, import('@librechat/api').MCPConnection>} */
let appConnections = new Map();
try {
// Use getLoaded() instead of getAll() to avoid forcing connection creation
// Use getLoaded() instead of getAll() to avoid forcing connection creation.
// getAll() creates connections for all servers, which is problematic for servers
// that require user context (e.g., those with {{LIBRECHAT_USER_ID}} placeholders)
// that require user context (e.g., those with {{LIBRECHAT_USER_ID}} placeholders).
appConnections = (await mcpManager.appConnections?.getLoaded()) || new Map();
} catch (error) {
logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error);
}
const userConnections = mcpManager.getUserConnections(userId) || new Map();
const oauthServers = await getMCPServersRegistry().getOAuthServers(userId);
const oauthServers = new Set(
Object.entries(mcpConfig)
.filter(([, config]) => config.requiresOAuth)
.map(([name]) => name),
);
return {
mcpConfig,
@ -788,6 +848,8 @@ module.exports = {
createMCPTool,
createMCPTools,
getMCPSetupData,
resolveConfigServers,
resolveAllMcpConfigs,
checkOAuthFlowStatus,
getServerConnectionStatus,
createUnavailableToolStub,

View file

@ -14,6 +14,7 @@ const mockRegistryInstance = {
getOAuthServers: jest.fn(() => Promise.resolve(new Set())),
getAllServerConfigs: jest.fn(() => Promise.resolve({})),
getServerConfig: jest.fn(() => Promise.resolve(null)),
ensureConfigServers: jest.fn(() => Promise.resolve({})),
};
// Create isMCPDomainAllowed mock that can be configured per-test
@ -113,38 +114,43 @@ describe('tests for the new helper functions used by the MCP connection status e
});
it('should successfully return MCP setup data', async () => {
mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfig);
const mockConfigWithOAuth = {
server1: { type: 'stdio' },
server2: { type: 'http', requiresOAuth: true },
};
mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfigWithOAuth);
const mockAppConnections = new Map([['server1', { status: 'connected' }]]);
const mockUserConnections = new Map([['server2', { status: 'disconnected' }]]);
const mockOAuthServers = new Set(['server2']);
const mockMCPManager = {
appConnections: { getLoaded: jest.fn(() => Promise.resolve(mockAppConnections)) },
getUserConnections: jest.fn(() => mockUserConnections),
};
mockGetMCPManager.mockReturnValue(mockMCPManager);
mockRegistryInstance.getOAuthServers.mockResolvedValue(mockOAuthServers);
const result = await getMCPSetupData(mockUserId);
expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith(mockUserId);
expect(mockRegistryInstance.ensureConfigServers).toHaveBeenCalled();
expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith(
mockUserId,
expect.any(Object),
);
expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId);
expect(mockMCPManager.appConnections.getLoaded).toHaveBeenCalled();
expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId);
expect(mockRegistryInstance.getOAuthServers).toHaveBeenCalledWith(mockUserId);
expect(result).toEqual({
mcpConfig: mockConfig,
appConnections: mockAppConnections,
userConnections: mockUserConnections,
oauthServers: mockOAuthServers,
});
expect(result.mcpConfig).toEqual(mockConfigWithOAuth);
expect(result.appConnections).toEqual(mockAppConnections);
expect(result.userConnections).toEqual(mockUserConnections);
expect(result.oauthServers).toEqual(new Set(['server2']));
});
it('should throw error when MCP config not found', async () => {
mockRegistryInstance.getAllServerConfigs.mockResolvedValue(null);
await expect(getMCPSetupData(mockUserId)).rejects.toThrow('MCP config not found');
it('should return empty data when no servers are configured', async () => {
mockRegistryInstance.getAllServerConfigs.mockResolvedValue({});
const result = await getMCPSetupData(mockUserId);
expect(result.mcpConfig).toEqual({});
expect(result.oauthServers).toEqual(new Set());
});
it('should handle null values from MCP manager gracefully', async () => {

View file

@ -19,6 +19,7 @@ const {
buildWebSearchContext,
buildImageToolContext,
buildToolClassification,
buildOAuthToolCallName,
} = require('@librechat/api');
const {
Time,
@ -59,6 +60,7 @@ const { manifestToolMap, toolkits } = require('~/app/clients/tools/manifest');
const { createOnSearchResults } = require('~/server/services/Tools/search');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
const { resolveConfigServers } = require('~/server/services/MCP');
const { recordUsage } = require('~/server/services/Threads');
const { loadTools } = require('~/app/clients/tools/util');
const { redactMessage } = require('~/config/parsers');
@ -513,6 +515,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
const configServers = await resolveConfigServers(req);
const pendingOAuthServers = new Set();
const createOAuthEmitter = (serverName) => {
@ -521,7 +524,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
const stepId = 'step_oauth_login_' + serverName;
const toolCall = {
id: flowId,
name: serverName,
name: buildOAuthToolCallName(serverName),
type: 'tool_call_chunk',
};
@ -578,6 +581,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
oauthStart,
flowManager,
serverName,
configServers,
userMCPAuthMap,
});
@ -665,6 +669,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
const result = await reinitMCPServer({
user: req.user,
serverName,
configServers,
userMCPAuthMap,
flowManager,
returnOnOAuth: false,

View file

@ -25,11 +25,13 @@ async function reinitMCPServer({
signal,
forceNew,
serverName,
configServers,
userMCPAuthMap,
connectionTimeout,
returnOnOAuth = true,
oauthStart: _oauthStart,
flowManager: _flowManager,
serverConfig: providedConfig,
}) {
/** @type {MCPConnection | null} */
let connection = null;
@ -42,13 +44,28 @@ async function reinitMCPServer({
try {
const registry = getMCPServersRegistry();
const serverConfig = await registry.getServerConfig(serverName, user?.id);
const serverConfig =
providedConfig ?? (await registry.getServerConfig(serverName, user?.id, configServers));
if (serverConfig?.inspectionFailed) {
if (serverConfig.source === 'config') {
logger.info(
`[MCP Reinitialize] Config-source server ${serverName} has inspectionFailed — retry handled by config cache`,
);
return {
availableTools: null,
success: false,
message: `MCP server '${serverName}' is still unreachable`,
oauthRequired: false,
serverName,
oauthUrl: null,
tools: null,
};
}
logger.info(
`[MCP Reinitialize] Server ${serverName} had failed inspection, attempting reinspection`,
);
try {
const storageLocation = serverConfig.dbId ? 'DB' : 'CACHE';
const storageLocation = serverConfig.source === 'user' ? 'DB' : 'CACHE';
await registry.reinspectServer(serverName, storageLocation, user?.id);
logger.info(`[MCP Reinitialize] Reinspection succeeded for server: ${serverName}`);
} catch (reinspectError) {
@ -93,6 +110,7 @@ async function reinitMCPServer({
returnOnOAuth,
customUserVars,
connectionTimeout,
serverConfig,
});
logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`);
@ -125,6 +143,7 @@ async function reinitMCPServer({
oauthStart,
customUserVars,
connectionTimeout,
configServers,
});
if (discoveryResult.tools && discoveryResult.tools.length > 0) {

View file

@ -0,0 +1,131 @@
const mockRegistry = {
ensureConfigServers: jest.fn(),
getAllServerConfigs: jest.fn(),
};
jest.mock('~/config', () => ({
getMCPServersRegistry: jest.fn(() => mockRegistry),
getMCPManager: jest.fn(),
getFlowStateManager: jest.fn(),
getOAuthReconnectionManager: jest.fn(),
}));
jest.mock('@librechat/data-schemas', () => ({
getTenantId: jest.fn(() => 'tenant-1'),
logger: { debug: jest.fn(), info: jest.fn(), warn: jest.fn(), error: jest.fn() },
}));
jest.mock('~/server/services/Config', () => ({
getAppConfig: jest.fn(),
setCachedTools: jest.fn(),
getCachedTools: jest.fn(),
getMCPServerTools: jest.fn(),
loadCustomConfig: jest.fn(),
}));
jest.mock('~/cache', () => ({ getLogStores: jest.fn() }));
jest.mock('~/models', () => ({
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
}));
jest.mock('~/server/services/GraphTokenService', () => ({
getGraphApiToken: jest.fn(),
}));
jest.mock('~/server/services/Tools/mcp', () => ({
reinitMCPServer: jest.fn(),
}));
const { getAppConfig } = require('~/server/services/Config');
const { resolveConfigServers, resolveAllMcpConfigs } = require('../MCP');
describe('resolveConfigServers', () => {
beforeEach(() => jest.clearAllMocks());
it('resolves config servers for the current request context', async () => {
getAppConfig.mockResolvedValue({ mcpConfig: { srv: { url: 'http://a' } } });
mockRegistry.ensureConfigServers.mockResolvedValue({ srv: { name: 'srv' } });
const result = await resolveConfigServers({ user: { id: 'u1', role: 'admin' } });
expect(result).toEqual({ srv: { name: 'srv' } });
expect(getAppConfig).toHaveBeenCalledWith(
expect.objectContaining({ role: 'admin', userId: 'u1' }),
);
expect(mockRegistry.ensureConfigServers).toHaveBeenCalledWith({ srv: { url: 'http://a' } });
});
it('returns {} when ensureConfigServers throws', async () => {
getAppConfig.mockResolvedValue({ mcpConfig: { srv: {} } });
mockRegistry.ensureConfigServers.mockRejectedValue(new Error('inspect failed'));
const result = await resolveConfigServers({ user: { id: 'u1' } });
expect(result).toEqual({});
});
it('returns {} when getAppConfig throws', async () => {
getAppConfig.mockRejectedValue(new Error('db timeout'));
const result = await resolveConfigServers({ user: { id: 'u1' } });
expect(result).toEqual({});
});
it('passes empty mcpConfig when appConfig has none', async () => {
getAppConfig.mockResolvedValue({});
mockRegistry.ensureConfigServers.mockResolvedValue({});
await resolveConfigServers({ user: { id: 'u1' } });
expect(mockRegistry.ensureConfigServers).toHaveBeenCalledWith({});
});
});
describe('resolveAllMcpConfigs', () => {
beforeEach(() => jest.clearAllMocks());
it('merges config servers with base servers', async () => {
getAppConfig.mockResolvedValue({ mcpConfig: { cfg_srv: {} } });
mockRegistry.ensureConfigServers.mockResolvedValue({ cfg_srv: { name: 'cfg_srv' } });
mockRegistry.getAllServerConfigs.mockResolvedValue({
cfg_srv: { name: 'cfg_srv' },
yaml_srv: { name: 'yaml_srv' },
});
const result = await resolveAllMcpConfigs('u1', { id: 'u1', role: 'user' });
expect(result).toEqual({
cfg_srv: { name: 'cfg_srv' },
yaml_srv: { name: 'yaml_srv' },
});
expect(mockRegistry.getAllServerConfigs).toHaveBeenCalledWith('u1', {
cfg_srv: { name: 'cfg_srv' },
});
});
it('continues with empty configServers when ensureConfigServers fails', async () => {
getAppConfig.mockResolvedValue({ mcpConfig: { srv: {} } });
mockRegistry.ensureConfigServers.mockRejectedValue(new Error('inspect failed'));
mockRegistry.getAllServerConfigs.mockResolvedValue({ yaml_srv: { name: 'yaml_srv' } });
const result = await resolveAllMcpConfigs('u1', { id: 'u1' });
expect(result).toEqual({ yaml_srv: { name: 'yaml_srv' } });
expect(mockRegistry.getAllServerConfigs).toHaveBeenCalledWith('u1', {});
});
it('propagates getAllServerConfigs failures', async () => {
getAppConfig.mockResolvedValue({ mcpConfig: {} });
mockRegistry.ensureConfigServers.mockResolvedValue({});
mockRegistry.getAllServerConfigs.mockRejectedValue(new Error('redis down'));
await expect(resolveAllMcpConfigs('u1', { id: 'u1' })).rejects.toThrow('redis down');
});
it('propagates getAppConfig failures', async () => {
getAppConfig.mockRejectedValue(new Error('mongo down'));
await expect(resolveAllMcpConfigs('u1', { id: 'u1' })).rejects.toThrow('mongo down');
});
});

View file

@ -64,6 +64,9 @@ jest.mock('~/models', () => ({
jest.mock('~/config', () => ({
getFlowStateManager: jest.fn(() => ({})),
}));
jest.mock('~/server/services/MCP', () => ({
resolveConfigServers: jest.fn().mockResolvedValue({}),
}));
jest.mock('~/cache', () => ({
getLogStores: jest.fn(() => ({})),
}));

View file

@ -7,7 +7,7 @@ const { createMCPServersRegistry, createMCPManager } = require('~/config');
* Initialize MCP servers
*/
async function initializeMCPs() {
const appConfig = await getAppConfig();
const appConfig = await getAppConfig({ baseOnly: true });
const mcpServers = appConfig.mcpConfig;
try {

View file

@ -1,5 +1,5 @@
const { v4: uuidv4 } = require('uuid');
const { logger } = require('@librechat/data-schemas');
const { logger, scopedCacheKey } = require('@librechat/data-schemas');
const { EModelEndpoint, Constants, openAISettings, CacheKeys } = require('librechat-data-provider');
const { createImportBatchBuilder } = require('./importBatchBuilder');
const { cloneMessagesWithTimestamps } = require('./fork');
@ -203,7 +203,7 @@ async function importLibreChatConvo(
/* Endpoint configuration */
let endpoint = jsonData.endpoint ?? options.endpoint ?? EModelEndpoint.openAI;
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const endpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
const endpointsConfig = await cache.get(scopedCacheKey(CacheKeys.ENDPOINT_CONFIG));
const endpointConfig = endpointsConfig?.[endpoint];
if (!endpointConfig && endpointsConfig) {
endpoint = Object.keys(endpointsConfig)[0];

View file

@ -2,7 +2,12 @@ const fs = require('fs');
const LdapStrategy = require('passport-ldapauth');
const { logger } = require('@librechat/data-schemas');
const { SystemRoles, ErrorTypes } = require('librechat-data-provider');
const { isEnabled, getBalanceConfig, isEmailDomainAllowed } = require('@librechat/api');
const {
isEnabled,
getBalanceConfig,
isEmailDomainAllowed,
resolveAppConfigForUser,
} = require('@librechat/api');
const { createUser, findUser, updateUser, countUsers } = require('~/models');
const { getAppConfig } = require('~/server/services/Config');
@ -89,16 +94,6 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
const ldapId =
(LDAP_ID && userinfo[LDAP_ID]) || userinfo.uid || userinfo.sAMAccountName || userinfo.mail;
let user = await findUser({ ldapId });
if (user && user.provider !== 'ldap') {
logger.info(
`[ldapStrategy] User ${user.email} already exists with provider ${user.provider}`,
);
return done(null, false, {
message: ErrorTypes.AUTH_FAILED,
});
}
const fullNameAttributes = LDAP_FULL_NAME && LDAP_FULL_NAME.split(',');
const fullName =
fullNameAttributes && fullNameAttributes.length > 0
@ -122,7 +117,31 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
);
}
const appConfig = await getAppConfig();
// Domain check before findUser for two-phase fast-fail (consistent with SAML/OpenID/social).
// This means cross-provider users from blocked domains get 'Email domain not allowed'
// instead of AUTH_FAILED — both deny access.
const baseConfig = await getAppConfig({ baseOnly: true });
if (!isEmailDomainAllowed(mail, baseConfig?.registration?.allowedDomains)) {
logger.error(
`[LDAP Strategy] Authentication blocked - email domain not allowed [Email: ${mail}]`,
);
return done(null, false, { message: 'Email domain not allowed' });
}
let user = await findUser({ ldapId });
if (user && user.provider !== 'ldap') {
logger.info(
`[ldapStrategy] User ${user.email} already exists with provider ${user.provider}`,
);
return done(null, false, {
message: ErrorTypes.AUTH_FAILED,
});
}
const appConfig = user?.tenantId
? await resolveAppConfigForUser(getAppConfig, user)
: baseConfig;
if (!isEmailDomainAllowed(mail, appConfig?.registration?.allowedDomains)) {
logger.error(
`[LDAP Strategy] Authentication blocked - email domain not allowed [Email: ${mail}]`,

View file

@ -9,10 +9,10 @@ jest.mock('@librechat/data-schemas', () => ({
}));
jest.mock('@librechat/api', () => ({
// isEnabled used for TLS flags
isEnabled: jest.fn(() => false),
isEmailDomainAllowed: jest.fn(() => true),
getBalanceConfig: jest.fn(() => ({ enabled: false })),
resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})),
}));
jest.mock('~/models', () => ({
@ -30,14 +30,15 @@ jest.mock('~/server/services/Config', () => ({
let verifyCallback;
jest.mock('passport-ldapauth', () => {
return jest.fn().mockImplementation((options, verify) => {
verifyCallback = verify; // capture the strategy verify function
verifyCallback = verify;
return { name: 'ldap', options, verify };
});
});
const { ErrorTypes } = require('librechat-data-provider');
const { isEmailDomainAllowed } = require('@librechat/api');
const { isEmailDomainAllowed, resolveAppConfigForUser } = require('@librechat/api');
const { findUser, createUser, updateUser, countUsers } = require('~/models');
const { getAppConfig } = require('~/server/services/Config');
// Helper to call the verify callback and wrap in a Promise for convenience
const callVerify = (userinfo) =>
@ -117,6 +118,7 @@ describe('ldapStrategy', () => {
expect(user).toBe(false);
expect(info).toEqual({ message: ErrorTypes.AUTH_FAILED });
expect(createUser).not.toHaveBeenCalled();
expect(resolveAppConfigForUser).not.toHaveBeenCalled();
});
it('updates an existing ldap user with current LDAP info', async () => {
@ -158,7 +160,6 @@ describe('ldapStrategy', () => {
uid: 'uid999',
givenName: 'John',
cn: 'John Doe',
// no mail and no custom LDAP_EMAIL
};
const { user } = await callVerify(userinfo);
@ -180,4 +181,66 @@ describe('ldapStrategy', () => {
expect(user).toBe(false);
expect(info).toEqual({ message: 'Email domain not allowed' });
});
it('passes getAppConfig and found user to resolveAppConfigForUser', async () => {
const existing = {
_id: 'u3',
provider: 'ldap',
email: 'tenant@example.com',
ldapId: 'uid-tenant',
username: 'tenantuser',
name: 'Tenant User',
tenantId: 'tenant-a',
role: 'USER',
};
findUser.mockResolvedValue(existing);
const userinfo = {
uid: 'uid-tenant',
mail: 'tenant@example.com',
givenName: 'Tenant',
cn: 'Tenant User',
};
await callVerify(userinfo);
expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existing);
});
it('uses baseConfig for new user without calling resolveAppConfigForUser', async () => {
findUser.mockResolvedValue(null);
const userinfo = {
uid: 'uid-new',
mail: 'newuser@example.com',
givenName: 'New',
cn: 'New User',
};
await callVerify(userinfo);
expect(resolveAppConfigForUser).not.toHaveBeenCalled();
expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true });
});
it('should block login when tenant config restricts the domain', async () => {
const existing = {
_id: 'u-blocked',
provider: 'ldap',
ldapId: 'uid-tenant',
tenantId: 'tenant-strict',
role: 'USER',
};
findUser.mockResolvedValue(existing);
resolveAppConfigForUser.mockResolvedValue({
registration: { allowedDomains: ['other.com'] },
});
isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false);
const userinfo = { uid: 'uid-tenant', mail: 'user@example.com', givenName: 'Test', cn: 'Test' };
const { user, info } = await callVerify(userinfo);
expect(user).toBe(false);
expect(info).toEqual({ message: 'Email domain not allowed' });
});
});

View file

@ -15,6 +15,7 @@ const {
findOpenIDUser,
getBalanceConfig,
isEmailDomainAllowed,
resolveAppConfigForUser,
} = require('@librechat/api');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { findUser, createUser, updateUser } = require('~/models');
@ -468,9 +469,10 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) {
Object.assign(userinfo, providerUserinfo);
}
const appConfig = await getAppConfig();
const email = getOpenIdEmail(userinfo);
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
const baseConfig = await getAppConfig({ baseOnly: true });
if (!isEmailDomainAllowed(email, baseConfig?.registration?.allowedDomains)) {
logger.error(
`[OpenID Strategy] Authentication blocked - email domain not allowed [Identifier: ${email}]`,
);
@ -491,6 +493,15 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) {
throw new Error(ErrorTypes.AUTH_FAILED);
}
const appConfig = user?.tenantId ? await resolveAppConfigForUser(getAppConfig, user) : baseConfig;
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
logger.error(
`[OpenID Strategy] Authentication blocked - email domain not allowed [Identifier: ${email}]`,
);
throw new Error('Email domain not allowed');
}
const fullName = getFullName(userinfo);
const requiredRole = process.env.OPENID_REQUIRED_ROLE;

File diff suppressed because it is too large Load diff

View file

@ -5,7 +5,11 @@ const passport = require('passport');
const { ErrorTypes } = require('librechat-data-provider');
const { hashToken, logger } = require('@librechat/data-schemas');
const { Strategy: SamlStrategy } = require('@node-saml/passport-saml');
const { getBalanceConfig, isEmailDomainAllowed } = require('@librechat/api');
const {
getBalanceConfig,
isEmailDomainAllowed,
resolveAppConfigForUser,
} = require('@librechat/api');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { findUser, createUser, updateUser } = require('~/models');
const { getAppConfig } = require('~/server/services/Config');
@ -193,9 +197,9 @@ async function setupSaml() {
logger.debug('[samlStrategy] SAML profile:', profile);
const userEmail = getEmail(profile) || '';
const appConfig = await getAppConfig();
if (!isEmailDomainAllowed(userEmail, appConfig?.registration?.allowedDomains)) {
const baseConfig = await getAppConfig({ baseOnly: true });
if (!isEmailDomainAllowed(userEmail, baseConfig?.registration?.allowedDomains)) {
logger.error(
`[SAML Strategy] Authentication blocked - email domain not allowed [Email: ${userEmail}]`,
);
@ -223,6 +227,17 @@ async function setupSaml() {
});
}
const appConfig = user?.tenantId
? await resolveAppConfigForUser(getAppConfig, user)
: baseConfig;
if (!isEmailDomainAllowed(userEmail, appConfig?.registration?.allowedDomains)) {
logger.error(
`[SAML Strategy] Authentication blocked - email domain not allowed [Email: ${userEmail}]`,
);
return done(null, false, { message: 'Email domain not allowed' });
}
const fullName = getFullName(profile);
const username = convertToUsername(

View file

@ -30,6 +30,7 @@ jest.mock('@librechat/api', () => ({
tokenCredits: 1000,
startBalance: 1000,
})),
resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})),
}));
jest.mock('~/server/services/Config/EndpointService', () => ({
config: {},
@ -47,6 +48,9 @@ const fs = require('fs');
const path = require('path');
const fetch = require('node-fetch');
const { Strategy: SamlStrategy } = require('@node-saml/passport-saml');
const { findUser } = require('~/models');
const { resolveAppConfigForUser } = require('@librechat/api');
const { getAppConfig } = require('~/server/services/Config');
const { setupSaml, getCertificateContent } = require('./samlStrategy');
// Configure fs mock
@ -440,4 +444,50 @@ u7wlOSk+oFzDIO/UILIA
expect(fetch).not.toHaveBeenCalled();
});
it('should pass the found user to resolveAppConfigForUser', async () => {
const existingUser = {
_id: 'tenant-user-id',
provider: 'saml',
samlId: 'saml-1234',
email: 'test@example.com',
tenantId: 'tenant-c',
role: 'USER',
};
findUser.mockResolvedValue(existingUser);
const profile = { ...baseProfile };
await validate(profile);
expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existingUser);
});
it('should use baseConfig for new SAML user without calling resolveAppConfigForUser', async () => {
const profile = { ...baseProfile };
await validate(profile);
expect(resolveAppConfigForUser).not.toHaveBeenCalled();
expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true });
});
it('should block login when tenant config restricts the domain', async () => {
const { isEmailDomainAllowed } = require('@librechat/api');
const existingUser = {
_id: 'tenant-blocked',
provider: 'saml',
samlId: 'saml-1234',
email: 'test@example.com',
tenantId: 'tenant-restrict',
role: 'USER',
};
findUser.mockResolvedValue(existingUser);
resolveAppConfigForUser.mockResolvedValue({
registration: { allowedDomains: ['other.com'] },
});
isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false);
const profile = { ...baseProfile };
const { user } = await validate(profile);
expect(user).toBe(false);
});
});

View file

@ -1,6 +1,6 @@
const { logger } = require('@librechat/data-schemas');
const { ErrorTypes } = require('librechat-data-provider');
const { isEnabled, isEmailDomainAllowed } = require('@librechat/api');
const { isEnabled, isEmailDomainAllowed, resolveAppConfigForUser } = require('@librechat/api');
const { createSocialUser, handleExistingUser } = require('./process');
const { getAppConfig } = require('~/server/services/Config');
const { findUser } = require('~/models');
@ -13,9 +13,8 @@ const socialLogin =
profile,
});
const appConfig = await getAppConfig();
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
const baseConfig = await getAppConfig({ baseOnly: true });
if (!isEmailDomainAllowed(email, baseConfig?.registration?.allowedDomains)) {
logger.error(
`[${provider}Login] Authentication blocked - email domain not allowed [Email: ${email}]`,
);
@ -41,6 +40,20 @@ const socialLogin =
}
}
const appConfig = existingUser?.tenantId
? await resolveAppConfigForUser(getAppConfig, existingUser)
: baseConfig;
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
logger.error(
`[${provider}Login] Authentication blocked - email domain not allowed [Email: ${email}]`,
);
const error = new Error(ErrorTypes.AUTH_FAILED);
error.code = ErrorTypes.AUTH_FAILED;
error.message = 'Email domain not allowed';
return cb(error);
}
if (existingUser?.provider === provider) {
await handleExistingUser(existingUser, avatarUrl, appConfig, email);
return cb(null, existingUser);

View file

@ -3,6 +3,8 @@ const { ErrorTypes } = require('librechat-data-provider');
const { createSocialUser, handleExistingUser } = require('./process');
const socialLogin = require('./socialLogin');
const { findUser } = require('~/models');
const { resolveAppConfigForUser } = require('@librechat/api');
const { getAppConfig } = require('~/server/services/Config');
jest.mock('@librechat/data-schemas', () => {
const actualModule = jest.requireActual('@librechat/data-schemas');
@ -25,6 +27,10 @@ jest.mock('@librechat/api', () => ({
...jest.requireActual('@librechat/api'),
isEnabled: jest.fn().mockReturnValue(true),
isEmailDomainAllowed: jest.fn().mockReturnValue(true),
resolveAppConfigForUser: jest.fn().mockResolvedValue({
fileStrategy: 'local',
balance: { enabled: false },
}),
}));
jest.mock('~/models', () => ({
@ -66,10 +72,7 @@ describe('socialLogin', () => {
googleId: googleId,
};
/** Mock findUser to return user on first call (by googleId), null on second call */
findUser
.mockResolvedValueOnce(existingUser) // First call: finds by googleId
.mockResolvedValueOnce(null); // Second call would be by email, but won't be reached
findUser.mockResolvedValueOnce(existingUser).mockResolvedValueOnce(null);
const mockProfile = {
id: googleId,
@ -83,13 +86,9 @@ describe('socialLogin', () => {
await loginFn(null, null, null, mockProfile, callback);
/** Verify it searched by googleId first */
expect(findUser).toHaveBeenNthCalledWith(1, { googleId: googleId });
/** Verify it did NOT search by email (because it found user by googleId) */
expect(findUser).toHaveBeenCalledTimes(1);
/** Verify handleExistingUser was called with the new email */
expect(handleExistingUser).toHaveBeenCalledWith(
existingUser,
'https://example.com/avatar.png',
@ -97,7 +96,6 @@ describe('socialLogin', () => {
newEmail,
);
/** Verify callback was called with success */
expect(callback).toHaveBeenCalledWith(null, existingUser);
});
@ -113,7 +111,7 @@ describe('socialLogin', () => {
facebookId: facebookId,
};
findUser.mockResolvedValue(existingUser); // Always returns user
findUser.mockResolvedValue(existingUser);
const mockProfile = {
id: facebookId,
@ -127,7 +125,6 @@ describe('socialLogin', () => {
await loginFn(null, null, null, mockProfile, callback);
/** Verify it searched by facebookId first */
expect(findUser).toHaveBeenCalledWith({ facebookId: facebookId });
expect(findUser.mock.calls[0]).toEqual([{ facebookId: facebookId }]);
@ -150,13 +147,10 @@ describe('socialLogin', () => {
_id: 'user789',
email: email,
provider: 'google',
googleId: 'old-google-id', // Different googleId (edge case)
googleId: 'old-google-id',
};
/** First call (by googleId) returns null, second call (by email) returns user */
findUser
.mockResolvedValueOnce(null) // By googleId
.mockResolvedValueOnce(existingUser); // By email
findUser.mockResolvedValueOnce(null).mockResolvedValueOnce(existingUser);
const mockProfile = {
id: googleId,
@ -170,13 +164,10 @@ describe('socialLogin', () => {
await loginFn(null, null, null, mockProfile, callback);
/** Verify both searches happened */
expect(findUser).toHaveBeenNthCalledWith(1, { googleId: googleId });
/** Email passed as-is; findUser implementation handles case normalization */
expect(findUser).toHaveBeenNthCalledWith(2, { email: email });
expect(findUser).toHaveBeenCalledTimes(2);
/** Verify warning log */
expect(logger.warn).toHaveBeenCalledWith(
`[${provider}Login] User found by email: ${email} but not by ${provider}Id`,
);
@ -197,7 +188,6 @@ describe('socialLogin', () => {
googleId: googleId,
};
/** Both searches return null */
findUser.mockResolvedValue(null);
createSocialUser.mockResolvedValue(newUser);
@ -213,10 +203,8 @@ describe('socialLogin', () => {
await loginFn(null, null, null, mockProfile, callback);
/** Verify both searches happened */
expect(findUser).toHaveBeenCalledTimes(2);
/** Verify createSocialUser was called */
expect(createSocialUser).toHaveBeenCalledWith({
email: email,
avatarUrl: 'https://example.com/avatar.png',
@ -242,12 +230,10 @@ describe('socialLogin', () => {
const existingUser = {
_id: 'user123',
email: email,
provider: 'local', // Different provider
provider: 'local',
};
findUser
.mockResolvedValueOnce(null) // By googleId
.mockResolvedValueOnce(existingUser); // By email
findUser.mockResolvedValueOnce(null).mockResolvedValueOnce(existingUser);
const mockProfile = {
id: googleId,
@ -261,7 +247,6 @@ describe('socialLogin', () => {
await loginFn(null, null, null, mockProfile, callback);
/** Verify error callback */
expect(callback).toHaveBeenCalledWith(
expect.objectContaining({
code: ErrorTypes.AUTH_FAILED,
@ -274,4 +259,104 @@ describe('socialLogin', () => {
);
});
});
describe('Tenant-scoped config', () => {
it('should call resolveAppConfigForUser for tenant user', async () => {
const provider = 'google';
const googleId = 'google-tenant-user';
const email = 'tenant@example.com';
const existingUser = {
_id: 'userTenant',
email,
provider: 'google',
googleId,
tenantId: 'tenant-b',
role: 'USER',
};
findUser.mockResolvedValue(existingUser);
const mockProfile = {
id: googleId,
emails: [{ value: email, verified: true }],
photos: [{ value: 'https://example.com/avatar.png' }],
name: { givenName: 'Tenant', familyName: 'User' },
};
const loginFn = socialLogin(provider, mockGetProfileDetails);
const callback = jest.fn();
await loginFn(null, null, null, mockProfile, callback);
expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existingUser);
});
it('should use baseConfig for non-tenant user without calling resolveAppConfigForUser', async () => {
const provider = 'google';
const googleId = 'google-new-tenant';
const email = 'new@example.com';
findUser.mockResolvedValue(null);
createSocialUser.mockResolvedValue({
_id: 'newUser',
email,
provider: 'google',
googleId,
});
const mockProfile = {
id: googleId,
emails: [{ value: email, verified: true }],
photos: [{ value: 'https://example.com/avatar.png' }],
name: { givenName: 'New', familyName: 'User' },
};
const loginFn = socialLogin(provider, mockGetProfileDetails);
const callback = jest.fn();
await loginFn(null, null, null, mockProfile, callback);
expect(resolveAppConfigForUser).not.toHaveBeenCalled();
expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true });
});
it('should block login when tenant config restricts the domain', async () => {
const { isEmailDomainAllowed } = require('@librechat/api');
const provider = 'google';
const googleId = 'google-tenant-blocked';
const email = 'blocked@example.com';
const existingUser = {
_id: 'userBlocked',
email,
provider: 'google',
googleId,
tenantId: 'tenant-restrict',
role: 'USER',
};
findUser.mockResolvedValue(existingUser);
resolveAppConfigForUser.mockResolvedValue({
registration: { allowedDomains: ['other.com'] },
});
isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false);
const mockProfile = {
id: googleId,
emails: [{ value: email, verified: true }],
photos: [{ value: 'https://example.com/avatar.png' }],
name: { givenName: 'Blocked', familyName: 'User' },
};
const loginFn = socialLogin(provider, mockGetProfileDetails);
const callback = jest.fn();
await loginFn(null, null, null, mockProfile, callback);
expect(callback).toHaveBeenCalledWith(
expect.objectContaining({ message: 'Email domain not allowed' }),
);
});
});
});

View file

@ -1813,3 +1813,57 @@ describe('GLM Model Tests (Zhipu AI)', () => {
});
});
});
describe('Mistral Model Tests', () => {
describe('getModelMaxTokens', () => {
test('should return correct tokens for mistral-large-3 (256k context)', () => {
expect(getModelMaxTokens('mistral-large-3', EModelEndpoint.custom)).toBe(
maxTokensMap[EModelEndpoint.custom]['mistral-large-3'],
);
});
test('should match mistral-large-3 for suffixed variants', () => {
expect(getModelMaxTokens('mistral-large-3-instruct', EModelEndpoint.custom)).toBe(
maxTokensMap[EModelEndpoint.custom]['mistral-large-3'],
);
});
test('should not match mistral-large-3 for generic mistral-large', () => {
expect(getModelMaxTokens('mistral-large', EModelEndpoint.custom)).toBe(
maxTokensMap[EModelEndpoint.custom]['mistral-large'],
);
expect(getModelMaxTokens('mistral-large-latest', EModelEndpoint.custom)).toBe(
maxTokensMap[EModelEndpoint.custom]['mistral-large'],
);
});
});
describe('matchModelName', () => {
test('should match mistral-large-3 exactly', () => {
expect(matchModelName('mistral-large-3', EModelEndpoint.custom)).toBe('mistral-large-3');
});
test('should match mistral-large-3 for prefixed/suffixed variants', () => {
expect(matchModelName('mistral/mistral-large-3', EModelEndpoint.custom)).toBe(
'mistral-large-3',
);
expect(matchModelName('mistral-large-3-instruct', EModelEndpoint.custom)).toBe(
'mistral-large-3',
);
});
test('should match generic mistral-large for non-3 variants', () => {
expect(matchModelName('mistral-large-latest', EModelEndpoint.custom)).toBe('mistral-large');
});
});
describe('findMatchingPattern', () => {
test('should prefer mistral-large-3 over mistral-large for mistral-large-3 variants', () => {
const result = findMatchingPattern(
'mistral-large-3-instruct',
maxTokensMap[EModelEndpoint.custom],
);
expect(result).toBe('mistral-large-3');
});
});
});

View file

@ -140,6 +140,55 @@ export function useMessagesOperations() {
);
}
type OptionalMessagesOps = Pick<
MessagesViewContextValue,
'ask' | 'regenerate' | 'handleContinue' | 'getMessages' | 'setMessages'
>;
const NOOP_OPS: OptionalMessagesOps = {
ask: () => {},
regenerate: () => {},
handleContinue: () => {},
getMessages: () => undefined,
setMessages: () => {},
};
/**
* Hook for components that need message operations but may render outside MessagesViewProvider
* (e.g. the /search route). Returns no-op stubs when the provider is absent UI actions will
* be silently discarded rather than crashing. Callers must use optional chaining on
* `getMessages()` results, as it returns `undefined` outside the provider.
*/
export function useOptionalMessagesOperations(): OptionalMessagesOps {
const context = useContext(MessagesViewContext);
const ask = context?.ask;
const regenerate = context?.regenerate;
const handleContinue = context?.handleContinue;
const getMessages = context?.getMessages;
const setMessages = context?.setMessages;
return useMemo(
() => ({
ask: ask ?? NOOP_OPS.ask,
regenerate: regenerate ?? NOOP_OPS.regenerate,
handleContinue: handleContinue ?? NOOP_OPS.handleContinue,
getMessages: getMessages ?? NOOP_OPS.getMessages,
setMessages: setMessages ?? NOOP_OPS.setMessages,
}),
[ask, regenerate, handleContinue, getMessages, setMessages],
);
}
/**
* Hook for components that need conversation data but may render outside MessagesViewProvider
* (e.g. the /search route). Returns `undefined` for both fields when the provider is absent.
*/
export function useOptionalMessagesConversation() {
const context = useContext(MessagesViewContext);
const conversation = context?.conversation;
const conversationId = context?.conversationId;
return useMemo(() => ({ conversation, conversationId }), [conversation, conversationId]);
}
/** Hook for components that only need message state */
export function useMessagesState() {
const { index, latestMessageId, latestMessageDepth, setLatestMessage } = useMessagesViewContext();

View file

@ -0,0 +1,53 @@
import { renderHook } from '@testing-library/react';
import {
useOptionalMessagesOperations,
useOptionalMessagesConversation,
} from '../MessagesViewContext';
describe('useOptionalMessagesOperations', () => {
it('returns noop stubs when rendered outside MessagesViewProvider', () => {
const { result } = renderHook(() => useOptionalMessagesOperations());
expect(result.current.ask).toBeInstanceOf(Function);
expect(result.current.regenerate).toBeInstanceOf(Function);
expect(result.current.handleContinue).toBeInstanceOf(Function);
expect(result.current.getMessages).toBeInstanceOf(Function);
expect(result.current.setMessages).toBeInstanceOf(Function);
});
it('noop stubs do not throw when called', () => {
const { result } = renderHook(() => useOptionalMessagesOperations());
expect(() => result.current.ask({} as never)).not.toThrow();
expect(() => result.current.regenerate({} as never)).not.toThrow();
expect(() => result.current.handleContinue({} as never)).not.toThrow();
expect(() => result.current.setMessages([])).not.toThrow();
});
it('getMessages returns undefined outside the provider', () => {
const { result } = renderHook(() => useOptionalMessagesOperations());
expect(result.current.getMessages()).toBeUndefined();
});
it('returns stable references across re-renders', () => {
const { result, rerender } = renderHook(() => useOptionalMessagesOperations());
const first = result.current;
rerender();
expect(result.current).toBe(first);
});
});
describe('useOptionalMessagesConversation', () => {
it('returns undefined fields when rendered outside MessagesViewProvider', () => {
const { result } = renderHook(() => useOptionalMessagesConversation());
expect(result.current.conversation).toBeUndefined();
expect(result.current.conversationId).toBeUndefined();
});
it('returns stable references across re-renders', () => {
const { result, rerender } = renderHook(() => useOptionalMessagesConversation());
const first = result.current;
rerender();
expect(result.current).toBe(first);
});
});

View file

@ -355,6 +355,28 @@ export type TOptions = {
export type TAskFunction = (props: TAskProps, options?: TOptions) => void;
/**
* Stable context object passed from non-memo'd wrapper components (Message, MessageContent)
* to memo'd inner components (MessageRender, ContentRender) via props.
*
* This avoids subscribing to ChatContext inside memo'd components, which would bypass React.memo
* and cause unnecessary re-renders when `isSubmitting` changes during streaming.
*
* The `isSubmitting` property should use a getter backed by a ref so it returns the current
* value at call-time (for callback guards) without being a reactive dependency.
*/
export type TMessageChatContext = {
ask: (...args: Parameters<TAskFunction>) => void;
index: number;
regenerate: (message: t.TMessage, options?: { addedConvo?: t.TConversation | null }) => void;
conversation: t.TConversation | null;
latestMessageId: string | undefined;
latestMessageDepth: number | undefined;
handleContinue: (e: React.MouseEvent<HTMLButtonElement>) => void;
/** Should be a getter backed by a ref — reads current value without triggering re-renders */
readonly isSubmitting: boolean;
};
export type TMessageProps = {
conversation?: t.TConversation | null;
messageId?: string | null;

View file

@ -1,4 +1,4 @@
import { useCallback, useRef } from 'react';
import { memo, useCallback, useRef } from 'react';
import { MicOff } from 'lucide-react';
import { useToastContext, TooltipAnchor, ListeningIcon, Spinner } from '@librechat/client';
import { useLocalize, useSpeechToText, useGetAudioSettings } from '~/hooks';
@ -7,7 +7,7 @@ import { globalAudioId } from '~/common';
import { cn } from '~/utils';
const isExternalSTT = (speechToTextEndpoint: string) => speechToTextEndpoint === 'external';
export default function AudioRecorder({
export default memo(function AudioRecorder({
disabled,
ask,
methods,
@ -26,10 +26,12 @@ export default function AudioRecorder({
const { speechToTextEndpoint } = useGetAudioSettings();
const existingTextRef = useRef<string>('');
const isSubmittingRef = useRef(isSubmitting);
isSubmittingRef.current = isSubmitting;
const onTranscriptionComplete = useCallback(
(text: string) => {
if (isSubmitting) {
if (isSubmittingRef.current) {
showToast({
message: localize('com_ui_speech_while_submitting'),
status: 'error',
@ -52,7 +54,7 @@ export default function AudioRecorder({
existingTextRef.current = '';
}
},
[ask, reset, showToast, localize, isSubmitting, speechToTextEndpoint],
[ask, reset, showToast, localize, speechToTextEndpoint],
);
const setText = useCallback(
@ -125,4 +127,4 @@ export default function AudioRecorder({
}
/>
);
}
});

View file

@ -3,6 +3,8 @@ import { useWatch } from 'react-hook-form';
import { TextareaAutosize } from '@librechat/client';
import { useRecoilState, useRecoilValue } from 'recoil';
import { Constants, isAssistantsEndpoint, isAgentsEndpoint } from 'librechat-data-provider';
import type { TConversation } from 'librechat-data-provider';
import type { ExtendedFile, FileSetter, ConvoGenerator } from '~/common';
import {
useChatContext,
useChatFormContext,
@ -35,7 +37,30 @@ import BadgeRow from './BadgeRow';
import Mention from './Mention';
import store from '~/store';
const ChatForm = memo(({ index = 0 }: { index?: number }) => {
interface ChatFormProps {
index: number;
/** From ChatContext — individual values so memo can compare them */
files: Map<string, ExtendedFile>;
setFiles: FileSetter;
conversation: TConversation | null;
isSubmitting: boolean;
filesLoading: boolean;
setFilesLoading: React.Dispatch<React.SetStateAction<boolean>>;
newConversation: ConvoGenerator;
handleStopGenerating: (e: React.MouseEvent<HTMLButtonElement>) => void;
}
const ChatForm = memo(function ChatForm({
index,
files,
setFiles,
conversation,
isSubmitting,
filesLoading,
setFilesLoading,
newConversation,
handleStopGenerating,
}: ChatFormProps) {
const submitButtonRef = useRef<HTMLButtonElement>(null);
const textAreaRef = useRef<HTMLTextAreaElement>(null);
useFocusChatEffect(textAreaRef);
@ -65,15 +90,6 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
const { requiresKey } = useRequiresKey();
const methods = useChatFormContext();
const {
files,
setFiles,
conversation,
isSubmitting,
filesLoading,
newConversation,
handleStopGenerating,
} = useChatContext();
const {
generateConversation,
conversation: addedConvo,
@ -120,6 +136,15 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
}
}, [isCollapsed]);
const handleTextareaFocus = useCallback(() => {
handleFocusOrClick();
setIsTextAreaFocused(true);
}, [handleFocusOrClick]);
const handleTextareaBlur = useCallback(() => {
setIsTextAreaFocused(false);
}, []);
useAutoSave({
files,
setFiles,
@ -253,7 +278,12 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
handleSaveBadges={handleSaveBadges}
setBadges={setBadges}
/>
<FileFormChat conversation={conversation} />
<FileFormChat
conversation={conversation}
files={files}
setFiles={setFiles}
setFilesLoading={setFilesLoading}
/>
{endpoint && (
<div className={cn('flex', isRTL ? 'flex-row-reverse' : 'flex-row')}>
<div
@ -284,11 +314,8 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
tabIndex={0}
data-testid="text-input"
rows={1}
onFocus={() => {
handleFocusOrClick();
setIsTextAreaFocused(true);
}}
onBlur={setIsTextAreaFocused.bind(null, false)}
onFocus={handleTextareaFocus}
onBlur={handleTextareaBlur}
aria-label={localize('com_ui_message_input')}
onClick={handleFocusOrClick}
style={{ height: 44, overflowY: 'auto' }}
@ -315,7 +342,13 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
)}
>
<div className={`${isRTL ? 'mr-2' : 'ml-2'}`}>
<AttachFileChat conversation={conversation} disableInputs={disableInputs} />
<AttachFileChat
conversation={conversation}
disableInputs={disableInputs}
files={files}
setFiles={setFiles}
setFilesLoading={setFilesLoading}
/>
</div>
<BadgeRow
showEphemeralBadges={
@ -360,5 +393,77 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
</form>
);
});
ChatForm.displayName = 'ChatForm';
export default ChatForm;
/**
* Wrapper that subscribes to ChatContext and passes stable individual values
* to the memo'd ChatForm. This prevents ChatForm from re-rendering on every
* streaming chunk it only re-renders when the specific values it uses change.
*/
function ChatFormWrapper({ index = 0 }: { index?: number }) {
const {
files,
setFiles,
conversation,
isSubmitting,
filesLoading,
setFilesLoading,
newConversation,
handleStopGenerating,
} = useChatContext();
/**
* Stabilize conversation reference: only update when rendering-relevant fields change,
* not on every metadata update (e.g., title generation during streaming).
*/
const hasMessages = (conversation?.messages?.length ?? 0) > 0;
const stableConversation = useMemo(
() => conversation,
// eslint-disable-next-line react-hooks/exhaustive-deps
[
conversation?.conversationId,
conversation?.endpoint,
conversation?.endpointType,
conversation?.agent_id,
conversation?.assistant_id,
conversation?.spec,
conversation?.useResponsesApi,
conversation?.model,
hasMessages,
],
);
/** Stabilize function refs so they never trigger ChatForm re-renders */
const handleStopRef = useRef(handleStopGenerating);
handleStopRef.current = handleStopGenerating;
const stableHandleStop = useCallback(
(e: React.MouseEvent<HTMLButtonElement>) => handleStopRef.current(e),
[],
);
const newConvoRef = useRef(newConversation);
newConvoRef.current = newConversation;
const stableNewConversation: ConvoGenerator = useCallback(
(...args: Parameters<ConvoGenerator>): ReturnType<ConvoGenerator> =>
newConvoRef.current(...args),
[],
);
return (
<ChatForm
index={index}
files={files}
setFiles={setFiles}
conversation={stableConversation}
isSubmitting={isSubmitting}
filesLoading={filesLoading}
setFilesLoading={setFilesLoading}
newConversation={stableNewConversation}
handleStopGenerating={stableHandleStop}
/>
);
}
ChatFormWrapper.displayName = 'ChatFormWrapper';
export default ChatFormWrapper;

View file

@ -52,4 +52,4 @@ const CollapseChat = ({
);
};
export default CollapseChat;
export default React.memo(CollapseChat);

View file

@ -1,14 +1,33 @@
import React, { useRef } from 'react';
import { FileUpload, TooltipAnchor, AttachmentIcon } from '@librechat/client';
import { useLocalize, useFileHandling } from '~/hooks';
import type { TConversation } from 'librechat-data-provider';
import type { ExtendedFile, FileSetter } from '~/common';
import { useFileHandlingNoChatContext, useLocalize } from '~/hooks';
import { cn } from '~/utils';
const AttachFile = ({ disabled }: { disabled?: boolean | null }) => {
const AttachFile = ({
disabled,
files,
setFiles,
setFilesLoading,
conversation,
}: {
disabled?: boolean | null;
files: Map<string, ExtendedFile>;
setFiles: FileSetter;
setFilesLoading: React.Dispatch<React.SetStateAction<boolean>>;
conversation: TConversation | null;
}) => {
const localize = useLocalize();
const inputRef = useRef<HTMLInputElement>(null);
const isUploadDisabled = disabled ?? false;
const { handleFileChange } = useFileHandling();
const { handleFileChange } = useFileHandlingNoChatContext(undefined, {
files,
setFiles,
setFilesLoading,
conversation,
});
return (
<FileUpload ref={inputRef} handleFileChange={handleFileChange}>

View file

@ -9,6 +9,7 @@ import {
getEndpointFileConfig,
} from 'librechat-data-provider';
import type { TConversation } from 'librechat-data-provider';
import type { ExtendedFile, FileSetter } from '~/common';
import { useGetFileConfig, useGetEndpointsQuery, useGetAgentByIdQuery } from '~/data-provider';
import { useAgentsMapContext } from '~/Providers';
import AttachFileMenu from './AttachFileMenu';
@ -17,9 +18,15 @@ import AttachFile from './AttachFile';
function AttachFileChat({
disableInputs,
conversation,
files,
setFiles,
setFilesLoading,
}: {
disableInputs: boolean;
conversation: TConversation | null;
files: Map<string, ExtendedFile>;
setFiles: FileSetter;
setFilesLoading: React.Dispatch<React.SetStateAction<boolean>>;
}) {
const conversationId = conversation?.conversationId ?? Constants.NEW_CONVO;
const { endpoint } = conversation ?? { endpoint: null };
@ -90,7 +97,15 @@ function AttachFileChat({
);
if (isAssistants && endpointSupportsFiles && !isUploadDisabled) {
return <AttachFile disabled={disableInputs} />;
return (
<AttachFile
disabled={disableInputs}
files={files}
setFiles={setFiles}
setFilesLoading={setFilesLoading}
conversation={conversation}
/>
);
} else if ((isAgents || endpointSupportsFiles) && !isUploadDisabled) {
return (
<AttachFileMenu
@ -101,6 +116,10 @@ function AttachFileChat({
agentId={conversation?.agent_id}
endpointFileConfig={endpointFileConfig}
useResponsesApi={useResponsesApi}
files={files}
setFiles={setFiles}
setFilesLoading={setFilesLoading}
conversation={conversation}
/>
);
}

View file

@ -23,15 +23,16 @@ import {
bedrockDocumentExtensions,
isDocumentSupportedProvider,
} from 'librechat-data-provider';
import type { EndpointFileConfig } from 'librechat-data-provider';
import type { EndpointFileConfig, TConversation } from 'librechat-data-provider';
import type { ExtendedFile, FileSetter } from '~/common';
import {
useAgentToolPermissions,
useAgentCapabilities,
useGetAgentsConfig,
useFileHandling,
useFileHandlingNoChatContext,
useLocalize,
} from '~/hooks';
import useSharePointFileHandling from '~/hooks/Files/useSharePointFileHandling';
import { useSharePointFileHandlingNoChatContext } from '~/hooks/Files/useSharePointFileHandling';
import { SharePointPickerDialog } from '~/components/SharePoint';
import { useGetStartupConfig } from '~/data-provider';
import { ephemeralAgentByConvoId } from '~/store';
@ -53,6 +54,10 @@ interface AttachFileMenuProps {
endpointType?: EModelEndpoint | string;
endpointFileConfig?: EndpointFileConfig;
useResponsesApi?: boolean;
files: Map<string, ExtendedFile>;
setFiles: FileSetter;
setFilesLoading: React.Dispatch<React.SetStateAction<boolean>>;
conversation: TConversation | null;
}
const AttachFileMenu = ({
@ -63,6 +68,10 @@ const AttachFileMenu = ({
conversationId,
endpointFileConfig,
useResponsesApi,
files,
setFiles,
setFilesLoading,
conversation,
}: AttachFileMenuProps) => {
const localize = useLocalize();
const isUploadDisabled = disabled ?? false;
@ -72,10 +81,17 @@ const AttachFileMenu = ({
ephemeralAgentByConvoId(conversationId),
);
const [toolResource, setToolResource] = useState<EToolResources | undefined>();
const { handleFileChange } = useFileHandling();
const { handleSharePointFiles, isProcessing, downloadProgress } = useSharePointFileHandling({
toolResource,
const { handleFileChange } = useFileHandlingNoChatContext(undefined, {
files,
setFiles,
setFilesLoading,
conversation,
});
const { handleSharePointFiles, isProcessing, downloadProgress } =
useSharePointFileHandlingNoChatContext(
{ toolResource },
{ files, setFiles, setFilesLoading, conversation },
);
const { agentsConfig } = useGetAgentsConfig();
const { data: startupConfig } = useGetStartupConfig();

View file

@ -1,16 +1,30 @@
import { memo } from 'react';
import { useRecoilValue } from 'recoil';
import type { TConversation } from 'librechat-data-provider';
import { useChatContext } from '~/Providers';
import { useFileHandling } from '~/hooks';
import type { ExtendedFile, FileSetter } from '~/common';
import { useFileHandlingNoChatContext } from '~/hooks';
import FileRow from './FileRow';
import store from '~/store';
function FileFormChat({ conversation }: { conversation: TConversation | null }) {
const { files, setFiles, setFilesLoading } = useChatContext();
function FileFormChat({
conversation,
files,
setFiles,
setFilesLoading,
}: {
conversation: TConversation | null;
files: Map<string, ExtendedFile>;
setFiles: FileSetter;
setFilesLoading: React.Dispatch<React.SetStateAction<boolean>>;
}) {
const chatDirection = useRecoilValue(store.chatDirection).toLowerCase();
const { endpoint: _endpoint } = conversation ?? { endpoint: null };
const { abortUpload } = useFileHandling();
const { abortUpload } = useFileHandlingNoChatContext(undefined, {
files,
setFiles,
setFilesLoading,
conversation,
});
const isRTL = chatDirection === 'rtl';

View file

@ -59,7 +59,13 @@ function renderComponent(conversation: Record<string, unknown> | null, disableIn
return render(
<QueryClientProvider client={queryClient}>
<RecoilRoot>
<AttachFileChat conversation={conversation as never} disableInputs={disableInputs} />
<AttachFileChat
conversation={conversation as never}
disableInputs={disableInputs}
files={new Map()}
setFiles={() => {}}
setFilesLoading={() => {}}
/>
</RecoilRoot>
</QueryClientProvider>,
);

View file

@ -9,13 +9,14 @@ jest.mock('~/hooks', () => ({
useAgentToolPermissions: jest.fn(),
useAgentCapabilities: jest.fn(),
useGetAgentsConfig: jest.fn(),
useFileHandling: jest.fn(),
useFileHandlingNoChatContext: jest.fn(),
useLocalize: jest.fn(),
}));
jest.mock('~/hooks/Files/useSharePointFileHandling', () => ({
__esModule: true,
default: jest.fn(),
useSharePointFileHandlingNoChatContext: jest.fn(),
}));
jest.mock('~/data-provider', () => ({
@ -52,6 +53,7 @@ jest.mock('@librechat/client', () => {
),
AttachmentIcon: () => R.createElement('span', { 'data-testid': 'attachment-icon' }),
SharePointIcon: () => R.createElement('span', { 'data-testid': 'sharepoint-icon' }),
useToastContext: () => ({ showToast: jest.fn() }),
};
});
@ -66,11 +68,14 @@ jest.mock('@ariakit/react', () => {
const mockUseAgentToolPermissions = jest.requireMock('~/hooks').useAgentToolPermissions;
const mockUseAgentCapabilities = jest.requireMock('~/hooks').useAgentCapabilities;
const mockUseGetAgentsConfig = jest.requireMock('~/hooks').useGetAgentsConfig;
const mockUseFileHandling = jest.requireMock('~/hooks').useFileHandling;
const mockUseFileHandlingNoChatContext = jest.requireMock('~/hooks').useFileHandlingNoChatContext;
const mockUseLocalize = jest.requireMock('~/hooks').useLocalize;
const mockUseSharePointFileHandling = jest.requireMock(
'~/hooks/Files/useSharePointFileHandling',
).default;
const mockUseSharePointFileHandlingNoChatContext = jest.requireMock(
'~/hooks/Files/useSharePointFileHandling',
).useSharePointFileHandlingNoChatContext;
const mockUseGetStartupConfig = jest.requireMock('~/data-provider').useGetStartupConfig;
const queryClient = new QueryClient({ defaultOptions: { queries: { retry: false } } });
@ -92,12 +97,15 @@ function setupMocks(overrides: { provider?: string } = {}) {
codeEnabled: false,
});
mockUseGetAgentsConfig.mockReturnValue({ agentsConfig: {} });
mockUseFileHandling.mockReturnValue({ handleFileChange: jest.fn() });
mockUseSharePointFileHandling.mockReturnValue({
mockUseFileHandlingNoChatContext.mockReturnValue({ handleFileChange: jest.fn() });
const sharePointReturnValue = {
handleSharePointFiles: jest.fn(),
isProcessing: false,
downloadProgress: 0,
});
error: null,
};
mockUseSharePointFileHandling.mockReturnValue(sharePointReturnValue);
mockUseSharePointFileHandlingNoChatContext.mockReturnValue(sharePointReturnValue);
mockUseGetStartupConfig.mockReturnValue({ data: { sharePointFilePickerEnabled: false } });
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
@ -110,7 +118,14 @@ function renderMenu(props: Record<string, unknown> = {}) {
return render(
<QueryClientProvider client={queryClient}>
<RecoilRoot>
<AttachFileMenu conversationId="test-convo" {...props} />
<AttachFileMenu
conversationId="test-convo"
files={new Map()}
setFiles={() => {}}
setFilesLoading={() => {}}
conversation={null}
{...props}
/>
</RecoilRoot>
</QueryClientProvider>,
);

View file

@ -1,8 +1,15 @@
import { memo } from 'react';
import { TooltipAnchor } from '@librechat/client';
import { useLocalize } from '~/hooks';
import { cn } from '~/utils';
export default function StopButton({ stop, setShowStopButton }) {
export default memo(function StopButton({
stop,
setShowStopButton,
}: {
stop: (e: React.MouseEvent<HTMLButtonElement>) => void;
setShowStopButton: (value: boolean) => void;
}) {
const localize = useLocalize();
return (
@ -34,4 +41,4 @@ export default function StopButton({ stop, setShowStopButton }) {
}
></TooltipAnchor>
);
}
});

View file

@ -1,8 +1,9 @@
import { memo } from 'react';
import AddedConvo from './AddedConvo';
import type { TConversation } from 'librechat-data-provider';
import type { SetterOrUpdater } from 'recoil';
export default function TextareaHeader({
export default memo(function TextareaHeader({
addedConvo,
setAddedConvo,
}: {
@ -17,4 +18,4 @@ export default function TextareaHeader({
<AddedConvo addedConvo={addedConvo} setAddedConvo={setAddedConvo} />
</div>
);
}
});

View file

@ -49,19 +49,47 @@ export default function ToolCall({
}
}, [autoExpand, hasOutput]);
const parsedAuthUrl = useMemo(() => {
if (!auth) {
return null;
}
try {
return new URL(auth);
} catch {
return null;
}
}, [auth]);
const { function_name, domain, isMCPToolCall, mcpServerName } = useMemo(() => {
if (typeof name !== 'string') {
return { function_name: '', domain: null, isMCPToolCall: false, mcpServerName: '' };
}
if (name.includes(Constants.mcp_delimiter)) {
const [func, server] = name.split(Constants.mcp_delimiter);
const parts = name.split(Constants.mcp_delimiter);
const func = parts[0];
const server = parts.slice(1).join(Constants.mcp_delimiter);
const displayName = func === 'oauth' ? server : func;
return {
function_name: func || '',
function_name: displayName || '',
domain: server && (server.replaceAll(actionDomainSeparator, '.') || null),
isMCPToolCall: true,
mcpServerName: server || '',
};
}
if (parsedAuthUrl) {
const redirectUri = parsedAuthUrl.searchParams.get('redirect_uri') || '';
const mcpMatch = redirectUri.match(/\/api\/mcp\/([^/]+)\/oauth\/callback/);
if (mcpMatch?.[1]) {
return {
function_name: mcpMatch[1],
domain: null,
isMCPToolCall: true,
mcpServerName: mcpMatch[1],
};
}
}
const [func, _domain] = name.includes(actionDelimiter)
? name.split(actionDelimiter)
: [name, ''];
@ -71,25 +99,20 @@ export default function ToolCall({
isMCPToolCall: false,
mcpServerName: '',
};
}, [name]);
}, [name, parsedAuthUrl]);
const toolIconType = useMemo(() => getToolIconType(name), [name]);
const mcpIconMap = useMCPIconMap();
const mcpIconUrl = isMCPToolCall ? mcpIconMap.get(mcpServerName) : undefined;
const actionId = useMemo(() => {
if (isMCPToolCall || !auth) {
if (isMCPToolCall || !parsedAuthUrl) {
return '';
}
try {
const url = new URL(auth);
const redirectUri = url.searchParams.get('redirect_uri') || '';
const match = redirectUri.match(/\/api\/actions\/([^/]+)\/oauth\/callback/);
return match?.[1] || '';
} catch {
return '';
}
}, [auth, isMCPToolCall]);
const redirectUri = parsedAuthUrl.searchParams.get('redirect_uri') || '';
const match = redirectUri.match(/\/api\/actions\/([^/]+)\/oauth\/callback/);
return match?.[1] || '';
}, [parsedAuthUrl, isMCPToolCall]);
const handleOAuthClick = useCallback(async () => {
if (!auth) {
@ -132,21 +155,8 @@ export default function ToolCall({
);
const authDomain = useMemo(() => {
const authURL = auth ?? '';
if (!authURL) {
return '';
}
try {
const url = new URL(authURL);
return url.hostname;
} catch (e) {
logger.error(
'client/src/components/Chat/Messages/Content/ToolCall.tsx - Failed to parse auth URL',
e,
);
return '';
}
}, [auth]);
return parsedAuthUrl?.hostname ?? '';
}, [parsedAuthUrl]);
const progress = useProgress(initialProgress);
const showCancelled = cancelled || (errorState && !output);

View file

@ -3,11 +3,11 @@ import { ChevronDown } from 'lucide-react';
import { Tools } from 'librechat-data-provider';
import { UIResourceRenderer } from '@mcp-ui/client';
import type { TAttachment, UIResource } from 'librechat-data-provider';
import { useOptionalMessagesOperations } from '~/Providers';
import { useLocalize, useExpandCollapse } from '~/hooks';
import UIResourceCarousel from './UIResourceCarousel';
import { useMessagesOperations } from '~/Providers';
import { OutputRenderer } from './ToolOutput';
import { handleUIAction, cn } from '~/utils';
import { OutputRenderer } from './ToolOutput';
function isSimpleObject(obj: unknown): obj is Record<string, string | number | boolean | null> {
if (typeof obj !== 'object' || obj === null || Array.isArray(obj)) {
@ -102,7 +102,7 @@ export default function ToolCallInfo({
attachments?: TAttachment[];
}) {
const localize = useLocalize();
const { ask } = useMessagesOperations();
const { ask } = useOptionalMessagesOperations();
const [showParams, setShowParams] = useState(false);
const { style: paramsExpandStyle, ref: paramsExpandRef } = useExpandCollapse(showParams);

View file

@ -1,7 +1,7 @@
import React, { useState } from 'react';
import { UIResourceRenderer } from '@mcp-ui/client';
import type { UIResource } from 'librechat-data-provider';
import { useMessagesOperations } from '~/Providers';
import { useOptionalMessagesOperations } from '~/Providers';
import { handleUIAction } from '~/utils';
interface UIResourceCarouselProps {
@ -13,7 +13,7 @@ const UIResourceCarousel: React.FC<UIResourceCarouselProps> = React.memo(({ uiRe
const [showRightArrow, setShowRightArrow] = useState(true);
const [isContainerHovered, setIsContainerHovered] = useState(false);
const scrollContainerRef = React.useRef<HTMLDivElement>(null);
const { ask } = useMessagesOperations();
const { ask } = useOptionalMessagesOperations();
const handleScroll = React.useCallback(() => {
if (!scrollContainerRef.current) return;

View file

@ -3,7 +3,11 @@ import { render, screen } from '@testing-library/react';
import Markdown from '../Markdown';
import { RecoilRoot } from 'recoil';
import { UI_RESOURCE_MARKER } from '~/components/MCPUIResource/plugin';
import { useMessageContext, useMessagesConversation, useMessagesOperations } from '~/Providers';
import {
useMessageContext,
useOptionalMessagesConversation,
useOptionalMessagesOperations,
} from '~/Providers';
import { useGetMessagesByConvoId } from '~/data-provider';
import { useLocalize } from '~/hooks';
@ -12,8 +16,8 @@ import { useLocalize } from '~/hooks';
jest.mock('~/Providers', () => ({
...jest.requireActual('~/Providers'),
useMessageContext: jest.fn(),
useMessagesConversation: jest.fn(),
useMessagesOperations: jest.fn(),
useOptionalMessagesConversation: jest.fn(),
useOptionalMessagesOperations: jest.fn(),
}));
jest.mock('~/data-provider');
jest.mock('~/hooks');
@ -26,11 +30,11 @@ jest.mock('@mcp-ui/client', () => ({
}));
const mockUseMessageContext = useMessageContext as jest.MockedFunction<typeof useMessageContext>;
const mockUseMessagesConversation = useMessagesConversation as jest.MockedFunction<
typeof useMessagesConversation
const mockUseMessagesConversation = useOptionalMessagesConversation as jest.MockedFunction<
typeof useOptionalMessagesConversation
>;
const mockUseMessagesOperations = useMessagesOperations as jest.MockedFunction<
typeof useMessagesOperations
const mockUseMessagesOperations = useOptionalMessagesOperations as jest.MockedFunction<
typeof useOptionalMessagesOperations
>;
const mockUseGetMessagesByConvoId = useGetMessagesByConvoId as jest.MockedFunction<
typeof useGetMessagesByConvoId

View file

@ -1,6 +1,6 @@
import React from 'react';
import { RecoilRoot } from 'recoil';
import { Tools } from 'librechat-data-provider';
import { Tools, Constants } from 'librechat-data-provider';
import { render, screen, fireEvent } from '@testing-library/react';
import ToolCall from '../ToolCall';
@ -53,9 +53,20 @@ jest.mock('../ToolCallInfo', () => ({
jest.mock('../ProgressText', () => ({
__esModule: true,
default: ({ onClick, inProgressText, finishedText, _error, _hasInput, _isExpanded }: any) => (
default: ({
onClick,
inProgressText,
finishedText,
subtitle,
}: {
onClick?: () => void;
inProgressText?: string;
finishedText?: string;
subtitle?: string;
}) => (
<div data-testid="progress-text" onClick={onClick}>
{finishedText || inProgressText}
{subtitle && <span data-testid="subtitle">{subtitle}</span>}
</div>
),
}));
@ -346,6 +357,141 @@ describe('ToolCall', () => {
});
});
describe('MCP OAuth detection', () => {
const d = Constants.mcp_delimiter;
it('should detect MCP OAuth from delimiter in tool-call name', () => {
renderWithRecoil(
<ToolCall
{...mockProps}
name={`oauth${d}my-server`}
initialProgress={0.5}
isSubmitting={true}
auth="https://auth.example.com"
/>,
);
const subtitle = screen.getByTestId('subtitle');
expect(subtitle.textContent).toBe('via my-server');
});
it('should preserve full server name when it contains the delimiter substring', () => {
renderWithRecoil(
<ToolCall
{...mockProps}
name={`oauth${d}foo${d}bar`}
initialProgress={0.5}
isSubmitting={true}
auth="https://auth.example.com"
/>,
);
const subtitle = screen.getByTestId('subtitle');
expect(subtitle.textContent).toBe(`via foo${d}bar`);
});
it('should display server name (not "oauth") as function_name for OAuth tool calls', () => {
renderWithRecoil(
<ToolCall
{...mockProps}
name={`oauth${d}my-server`}
initialProgress={1}
isSubmitting={false}
output="done"
auth="https://auth.example.com"
/>,
);
const progressText = screen.getByTestId('progress-text');
expect(progressText.textContent).toContain('Completed my-server');
expect(progressText.textContent).not.toContain('Completed oauth');
});
it('should display server name even when auth is cleared (post-completion)', () => {
// After OAuth completes, createOAuthEnd re-emits the toolCall without auth.
// The display should still show the server name, not literal "oauth".
renderWithRecoil(
<ToolCall
{...mockProps}
name={`oauth${d}my-server`}
initialProgress={1}
isSubmitting={false}
output="done"
/>,
);
const progressText = screen.getByTestId('progress-text');
expect(progressText.textContent).toContain('Completed my-server');
expect(progressText.textContent).not.toContain('Completed oauth');
});
it('should fallback to auth URL redirect_uri when name lacks delimiter', () => {
const authUrl =
'https://oauth.example.com/authorize?redirect_uri=' +
encodeURIComponent('https://app.example.com/api/mcp/my-server/oauth/callback');
renderWithRecoil(
<ToolCall
{...mockProps}
name="bare_name"
initialProgress={0.5}
isSubmitting={true}
auth={authUrl}
/>,
);
const subtitle = screen.getByTestId('subtitle');
expect(subtitle.textContent).toBe('via my-server');
});
it('should display server name (not raw tool-call ID) in fallback path finished text', () => {
const authUrl =
'https://oauth.example.com/authorize?redirect_uri=' +
encodeURIComponent('https://app.example.com/api/mcp/my-server/oauth/callback');
renderWithRecoil(
<ToolCall
{...mockProps}
name="bare_name"
initialProgress={1}
isSubmitting={false}
output="done"
auth={authUrl}
/>,
);
const progressText = screen.getByTestId('progress-text');
expect(progressText.textContent).toContain('Completed my-server');
expect(progressText.textContent).not.toContain('bare_name');
});
it('should show normalized server name when it contains _mcp_ after prefixing', () => {
// Server named oauth@mcp@server normalizes to oauth_mcp_server,
// gets prefixed to oauth_mcp_oauth_mcp_server. Client parses:
// func="oauth", server="oauth_mcp_server". Visually awkward but
// semantically correct — the normalized name IS oauth_mcp_server.
renderWithRecoil(
<ToolCall
{...mockProps}
name={`oauth${d}oauth${d}server`}
initialProgress={0.5}
isSubmitting={true}
auth="https://auth.example.com"
/>,
);
const subtitle = screen.getByTestId('subtitle');
expect(subtitle.textContent).toBe(`via oauth${d}server`);
});
it('should not misidentify non-MCP action auth as MCP via fallback', () => {
const authUrl =
'https://oauth.example.com/authorize?redirect_uri=' +
encodeURIComponent('https://app.example.com/api/actions/xyz/oauth/callback');
renderWithRecoil(
<ToolCall
{...mockProps}
name="action_name"
initialProgress={0.5}
isSubmitting={true}
auth={authUrl}
/>,
);
expect(screen.queryByTestId('subtitle')).not.toBeInTheDocument();
});
});
describe('A11Y-04: screen reader status announcements', () => {
it('includes sr-only aria-live region for status announcements', () => {
renderWithRecoil(

View file

@ -25,7 +25,7 @@ jest.mock('~/hooks', () => ({
}));
jest.mock('~/Providers', () => ({
useMessagesOperations: () => ({
useOptionalMessagesOperations: () => ({
ask: jest.fn(),
}),
}));

View file

@ -13,10 +13,10 @@ jest.mock('@mcp-ui/client', () => ({
),
}));
// Mock useMessagesOperations hook
// Mock useOptionalMessagesOperations hook
const mockAsk = jest.fn();
jest.mock('~/Providers', () => ({
useMessagesOperations: () => ({
useOptionalMessagesOperations: () => ({
ask: mockAsk,
}),
}));

View file

@ -1,5 +1,5 @@
import React from 'react';
import { useMessageProcess } from '~/hooks';
import { useMessageProcess, useMemoizedChatContext } from '~/hooks';
import type { TMessageProps } from '~/common';
import MessageRender from './ui/MessageRender';
import MultiMessage from './MultiMessage';
@ -23,10 +23,11 @@ const MessageContainer = React.memo(function MessageContainer({
});
export default function Message(props: TMessageProps) {
const { conversation, handleScroll } = useMessageProcess({
const { conversation, handleScroll, isSubmitting } = useMessageProcess({
message: props.message,
});
const { message, currentEditId, setCurrentEditId } = props;
const { chatContext, effectiveIsSubmitting } = useMemoizedChatContext(message, isSubmitting);
if (!message || typeof message !== 'object') {
return null;
@ -38,7 +39,11 @@ export default function Message(props: TMessageProps) {
<>
<MessageContainer handleScroll={handleScroll}>
<div className="m-auto justify-center p-4 py-2 md:gap-6">
<MessageRender {...props} />
<MessageRender
{...props}
isSubmitting={effectiveIsSubmitting}
chatContext={chatContext}
/>
</div>
</MessageContainer>
<MultiMessage

View file

@ -2,7 +2,7 @@ import React, { useCallback, useMemo, memo } from 'react';
import { useAtomValue } from 'jotai';
import { useRecoilValue } from 'recoil';
import type { TMessage } from 'librechat-data-provider';
import type { TMessageProps, TMessageIcon } from '~/common';
import type { TMessageProps, TMessageIcon, TMessageChatContext } from '~/common';
import { cn, getHeaderPrefixForScreenReader, getMessageAriaLabel } from '~/utils';
import MessageContent from '~/components/Chat/Messages/Content/MessageContent';
import { useLocalize, useMessageActions, useContentMetadata } from '~/hooks';
@ -17,12 +17,73 @@ import store from '~/store';
type MessageRenderProps = {
message?: TMessage;
/**
* Effective isSubmitting: false for non-latest messages, real value for latest.
* Computed by the wrapper (Message.tsx) so this memo'd component only re-renders
* when the value actually matters.
*/
isSubmitting?: boolean;
/** Stable context object from wrapper — avoids ChatContext subscription inside memo */
chatContext: TMessageChatContext;
} & Pick<
TMessageProps,
'currentEditId' | 'setCurrentEditId' | 'siblingIdx' | 'setSiblingIdx' | 'siblingCount'
>;
/**
* Custom comparator for React.memo: compares `message` by key fields instead of reference
* because `buildTree` creates new message objects on every streaming update for ALL messages,
* even when only the latest message's text changed.
*/
function areMessageRenderPropsEqual(prev: MessageRenderProps, next: MessageRenderProps): boolean {
if (prev.isSubmitting !== next.isSubmitting) {
return false;
}
if (prev.chatContext !== next.chatContext) {
return false;
}
if (prev.siblingIdx !== next.siblingIdx) {
return false;
}
if (prev.siblingCount !== next.siblingCount) {
return false;
}
if (prev.currentEditId !== next.currentEditId) {
return false;
}
if (prev.setSiblingIdx !== next.setSiblingIdx) {
return false;
}
if (prev.setCurrentEditId !== next.setCurrentEditId) {
return false;
}
const prevMsg = prev.message;
const nextMsg = next.message;
if (prevMsg === nextMsg) {
return true;
}
if (!prevMsg || !nextMsg) {
return prevMsg === nextMsg;
}
return (
prevMsg.messageId === nextMsg.messageId &&
prevMsg.text === nextMsg.text &&
prevMsg.error === nextMsg.error &&
prevMsg.unfinished === nextMsg.unfinished &&
prevMsg.depth === nextMsg.depth &&
prevMsg.isCreatedByUser === nextMsg.isCreatedByUser &&
(prevMsg.children?.length ?? 0) === (nextMsg.children?.length ?? 0) &&
prevMsg.content === nextMsg.content &&
prevMsg.model === nextMsg.model &&
prevMsg.endpoint === nextMsg.endpoint &&
prevMsg.iconURL === nextMsg.iconURL &&
prevMsg.feedback?.rating === nextMsg.feedback?.rating &&
(prevMsg.files?.length ?? 0) === (nextMsg.files?.length ?? 0)
);
}
const MessageRender = memo(function MessageRender({
message: msg,
siblingIdx,
@ -31,6 +92,7 @@ const MessageRender = memo(function MessageRender({
currentEditId,
setCurrentEditId,
isSubmitting = false,
chatContext,
}: MessageRenderProps) {
const localize = useLocalize();
const {
@ -52,6 +114,7 @@ const MessageRender = memo(function MessageRender({
message: msg,
currentEditId,
setCurrentEditId,
chatContext,
});
const fontSize = useAtomValue(fontSizeAtom);
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
@ -63,8 +126,6 @@ const MessageRender = memo(function MessageRender({
[hasNoChildren, msg?.depth, latestMessageDepth],
);
const isLatestMessage = msg?.messageId === latestMessageId;
/** Only pass isSubmitting to the latest message to prevent unnecessary re-renders */
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
const iconData: TMessageIcon = useMemo(
() => ({
@ -92,10 +153,10 @@ const MessageRender = memo(function MessageRender({
messageId,
isLatestMessage,
isExpanded: false as const,
isSubmitting: effectiveIsSubmitting,
isSubmitting,
conversationId: conversation?.conversationId,
}),
[messageId, conversation?.conversationId, effectiveIsSubmitting, isLatestMessage],
[messageId, conversation?.conversationId, isSubmitting, isLatestMessage],
);
if (!msg) {
@ -165,7 +226,7 @@ const MessageRender = memo(function MessageRender({
message={msg}
enterEdit={enterEdit}
error={!!(msg.error ?? false)}
isSubmitting={effectiveIsSubmitting}
isSubmitting={isSubmitting}
unfinished={msg.unfinished ?? false}
isCreatedByUser={msg.isCreatedByUser ?? true}
siblingIdx={siblingIdx ?? 0}
@ -173,7 +234,7 @@ const MessageRender = memo(function MessageRender({
/>
</MessageContext.Provider>
</div>
{hasNoChildren && effectiveIsSubmitting ? (
{hasNoChildren && isSubmitting ? (
<PlaceholderRow />
) : (
<SubRow classes="text-xs">
@ -187,7 +248,7 @@ const MessageRender = memo(function MessageRender({
isEditing={edit}
message={msg}
enterEdit={enterEdit}
isSubmitting={isSubmitting}
isSubmitting={chatContext.isSubmitting}
conversation={conversation ?? null}
regenerate={handleRegenerateMessage}
copyToClipboard={copyToClipboard}
@ -202,7 +263,7 @@ const MessageRender = memo(function MessageRender({
</div>
</div>
);
});
}, areMessageRenderPropsEqual);
MessageRender.displayName = 'MessageRender';
export default MessageRender;

View file

@ -1,8 +1,8 @@
import React from 'react';
import { UIResourceRenderer } from '@mcp-ui/client';
import { handleUIAction } from '~/utils';
import { useOptionalMessagesConversation, useOptionalMessagesOperations } from '~/Providers';
import { useConversationUIResources } from '~/hooks/Messages/useConversationUIResources';
import { useMessagesConversation, useMessagesOperations } from '~/Providers';
import { handleUIAction } from '~/utils';
import { useLocalize } from '~/hooks';
interface MCPUIResourceProps {
@ -13,19 +13,14 @@ interface MCPUIResourceProps {
};
}
/**
* Component that renders an MCP UI resource based on its resource ID.
* Works in both main app and share view.
*/
/** Renders an MCP UI resource based on its resource ID. Works in chat, share, and search views. */
export function MCPUIResource(props: MCPUIResourceProps) {
const { resourceId } = props.node.properties;
const localize = useLocalize();
const { ask } = useMessagesOperations();
const { conversation } = useMessagesConversation();
const { ask } = useOptionalMessagesOperations();
const { conversationId } = useOptionalMessagesConversation();
const conversationResourceMap = useConversationUIResources(
conversation?.conversationId ?? undefined,
);
const conversationResourceMap = useConversationUIResources(conversationId ?? undefined);
const uiResource = conversationResourceMap.get(resourceId ?? '');

View file

@ -1,8 +1,8 @@
import React, { useMemo } from 'react';
import { useConversationUIResources } from '~/hooks/Messages/useConversationUIResources';
import { useMessagesConversation } from '~/Providers';
import UIResourceCarousel from '../Chat/Messages/Content/UIResourceCarousel';
import type { UIResource } from 'librechat-data-provider';
import { useConversationUIResources } from '~/hooks/Messages/useConversationUIResources';
import UIResourceCarousel from '../Chat/Messages/Content/UIResourceCarousel';
import { useOptionalMessagesConversation } from '~/Providers';
interface MCPUIResourceCarouselProps {
node: {
@ -12,16 +12,11 @@ interface MCPUIResourceCarouselProps {
};
}
/**
* Component that renders multiple MCP UI resources in a carousel.
* Works in both main app and share view.
*/
/** Renders multiple MCP UI resources in a carousel. Works in chat, share, and search views. */
export function MCPUIResourceCarousel(props: MCPUIResourceCarouselProps) {
const { conversation } = useMessagesConversation();
const { conversationId } = useOptionalMessagesConversation();
const conversationResourceMap = useConversationUIResources(
conversation?.conversationId ?? undefined,
);
const conversationResourceMap = useConversationUIResources(conversationId ?? undefined);
const uiResources = useMemo(() => {
const { resourceIds = [] } = props.node.properties;

View file

@ -2,7 +2,11 @@ import React from 'react';
import { render, screen } from '@testing-library/react';
import { RecoilRoot } from 'recoil';
import { MCPUIResource } from '../MCPUIResource';
import { useMessageContext, useMessagesConversation, useMessagesOperations } from '~/Providers';
import {
useMessageContext,
useOptionalMessagesConversation,
useOptionalMessagesOperations,
} from '~/Providers';
import { useLocalize } from '~/hooks';
import { handleUIAction } from '~/utils';
@ -22,11 +26,11 @@ jest.mock('@mcp-ui/client', () => ({
}));
const mockUseMessageContext = useMessageContext as jest.MockedFunction<typeof useMessageContext>;
const mockUseMessagesConversation = useMessagesConversation as jest.MockedFunction<
typeof useMessagesConversation
const mockUseMessagesConversation = useOptionalMessagesConversation as jest.MockedFunction<
typeof useOptionalMessagesConversation
>;
const mockUseMessagesOperations = useMessagesOperations as jest.MockedFunction<
typeof useMessagesOperations
const mockUseMessagesOperations = useOptionalMessagesOperations as jest.MockedFunction<
typeof useOptionalMessagesOperations
>;
const mockUseLocalize = useLocalize as jest.MockedFunction<typeof useLocalize>;
const mockHandleUIAction = handleUIAction as jest.MockedFunction<typeof handleUIAction>;

View file

@ -2,7 +2,11 @@ import React from 'react';
import { render, screen } from '@testing-library/react';
import { RecoilRoot } from 'recoil';
import { MCPUIResourceCarousel } from '../MCPUIResourceCarousel';
import { useMessageContext, useMessagesConversation, useMessagesOperations } from '~/Providers';
import {
useMessageContext,
useOptionalMessagesConversation,
useOptionalMessagesOperations,
} from '~/Providers';
// Mock dependencies
jest.mock('~/Providers');
@ -19,11 +23,11 @@ jest.mock('../../Chat/Messages/Content/UIResourceCarousel', () => ({
}));
const mockUseMessageContext = useMessageContext as jest.MockedFunction<typeof useMessageContext>;
const mockUseMessagesConversation = useMessagesConversation as jest.MockedFunction<
typeof useMessagesConversation
const mockUseMessagesConversation = useOptionalMessagesConversation as jest.MockedFunction<
typeof useOptionalMessagesConversation
>;
const mockUseMessagesOperations = useMessagesOperations as jest.MockedFunction<
typeof useMessagesOperations
const mockUseMessagesOperations = useOptionalMessagesOperations as jest.MockedFunction<
typeof useOptionalMessagesOperations
>;
describe('MCPUIResourceCarousel', () => {

View file

@ -2,7 +2,7 @@ import { useCallback, useMemo, memo } from 'react';
import { useAtomValue } from 'jotai';
import { useRecoilValue } from 'recoil';
import type { TMessage, TMessageContentParts } from 'librechat-data-provider';
import type { TMessageProps, TMessageIcon } from '~/common';
import type { TMessageProps, TMessageIcon, TMessageChatContext } from '~/common';
import { useAttachments, useLocalize, useMessageActions, useContentMetadata } from '~/hooks';
import { cn, getHeaderPrefixForScreenReader, getMessageAriaLabel } from '~/utils';
import ContentParts from '~/components/Chat/Messages/Content/ContentParts';
@ -16,12 +16,72 @@ import store from '~/store';
type ContentRenderProps = {
message?: TMessage;
/**
* Effective isSubmitting: false for non-latest messages, real value for latest.
* Computed by the wrapper (MessageContent.tsx) so this memo'd component only re-renders
* when the value actually matters.
*/
isSubmitting?: boolean;
/** Stable context object from wrapper — avoids ChatContext subscription inside memo */
chatContext: TMessageChatContext;
} & Pick<
TMessageProps,
'currentEditId' | 'setCurrentEditId' | 'siblingIdx' | 'setSiblingIdx' | 'siblingCount'
>;
/**
* Custom comparator for React.memo: compares `message` by key fields instead of reference
* because `buildTree` creates new message objects on every streaming update for ALL messages.
*/
function areContentRenderPropsEqual(prev: ContentRenderProps, next: ContentRenderProps): boolean {
if (prev.isSubmitting !== next.isSubmitting) {
return false;
}
if (prev.chatContext !== next.chatContext) {
return false;
}
if (prev.siblingIdx !== next.siblingIdx) {
return false;
}
if (prev.siblingCount !== next.siblingCount) {
return false;
}
if (prev.currentEditId !== next.currentEditId) {
return false;
}
if (prev.setSiblingIdx !== next.setSiblingIdx) {
return false;
}
if (prev.setCurrentEditId !== next.setCurrentEditId) {
return false;
}
const prevMsg = prev.message;
const nextMsg = next.message;
if (prevMsg === nextMsg) {
return true;
}
if (!prevMsg || !nextMsg) {
return prevMsg === nextMsg;
}
return (
prevMsg.messageId === nextMsg.messageId &&
prevMsg.text === nextMsg.text &&
prevMsg.error === nextMsg.error &&
prevMsg.unfinished === nextMsg.unfinished &&
prevMsg.depth === nextMsg.depth &&
prevMsg.isCreatedByUser === nextMsg.isCreatedByUser &&
(prevMsg.children?.length ?? 0) === (nextMsg.children?.length ?? 0) &&
prevMsg.content === nextMsg.content &&
prevMsg.model === nextMsg.model &&
prevMsg.endpoint === nextMsg.endpoint &&
prevMsg.iconURL === nextMsg.iconURL &&
prevMsg.feedback?.rating === nextMsg.feedback?.rating &&
(prevMsg.attachments?.length ?? 0) === (nextMsg.attachments?.length ?? 0)
);
}
const ContentRender = memo(function ContentRender({
message: msg,
siblingIdx,
@ -30,6 +90,7 @@ const ContentRender = memo(function ContentRender({
currentEditId,
setCurrentEditId,
isSubmitting = false,
chatContext,
}: ContentRenderProps) {
const localize = useLocalize();
const { attachments, searchResults } = useAttachments({
@ -55,6 +116,7 @@ const ContentRender = memo(function ContentRender({
searchResults,
currentEditId,
setCurrentEditId,
chatContext,
});
const fontSize = useAtomValue(fontSizeAtom);
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
@ -66,8 +128,6 @@ const ContentRender = memo(function ContentRender({
);
const hasNoChildren = !(msg?.children?.length ?? 0);
const isLatestMessage = msg?.messageId === latestMessageId;
/** Only pass isSubmitting to the latest message to prevent unnecessary re-renders */
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
const iconData: TMessageIcon = useMemo(
() => ({
@ -158,13 +218,13 @@ const ContentRender = memo(function ContentRender({
searchResults={searchResults}
setSiblingIdx={setSiblingIdx}
isLatestMessage={isLatestMessage}
isSubmitting={effectiveIsSubmitting}
isSubmitting={isSubmitting}
isCreatedByUser={msg.isCreatedByUser}
conversationId={conversation?.conversationId}
content={msg.content as Array<TMessageContentParts | undefined>}
/>
</div>
{hasNoChildren && effectiveIsSubmitting ? (
{hasNoChildren && isSubmitting ? (
<PlaceholderRow />
) : (
<SubRow classes="text-xs">
@ -178,7 +238,7 @@ const ContentRender = memo(function ContentRender({
message={msg}
isEditing={edit}
enterEdit={enterEdit}
isSubmitting={isSubmitting}
isSubmitting={chatContext.isSubmitting}
conversation={conversation ?? null}
regenerate={handleRegenerateMessage}
copyToClipboard={copyToClipboard}
@ -193,7 +253,7 @@ const ContentRender = memo(function ContentRender({
</div>
</div>
);
});
}, areContentRenderPropsEqual);
ContentRender.displayName = 'ContentRender';
export default ContentRender;

View file

@ -1,5 +1,5 @@
import React from 'react';
import { useMessageProcess } from '~/hooks';
import { useMessageProcess, useMemoizedChatContext } from '~/hooks';
import type { TMessageProps } from '~/common';
import MultiMessage from '~/components/Chat/Messages/MultiMessage';
@ -28,6 +28,7 @@ export default function MessageContent(props: TMessageProps) {
message: props.message,
});
const { message, currentEditId, setCurrentEditId } = props;
const { chatContext, effectiveIsSubmitting } = useMemoizedChatContext(message, isSubmitting);
if (!message || typeof message !== 'object') {
return null;
@ -39,7 +40,11 @@ export default function MessageContent(props: TMessageProps) {
<>
<MessageContainer handleScroll={handleScroll}>
<div className="m-auto justify-center p-4 py-2 md:gap-6">
<ContentRender {...props} isSubmitting={isSubmitting} />
<ContentRender
{...props}
isSubmitting={effectiveIsSubmitting}
chatContext={chatContext}
/>
</div>
</MessageContainer>
<MultiMessage

View file

@ -1,4 +1,4 @@
import { useCallback } from 'react';
import { useCallback, useMemo } from 'react';
import { useRecoilValue } from 'recoil';
import { useGetModelsQuery } from 'librechat-data-provider/react-query';
import {
@ -122,9 +122,12 @@ export default function useAddedResponse() {
],
);
return {
conversation,
setConversation,
generateConversation,
};
return useMemo(
() => ({
conversation,
setConversation,
generateConversation,
}),
[conversation, setConversation, generateConversation],
);
}

View file

@ -1,6 +1,6 @@
export { default as useDeleteFilesFromTable } from './useDeleteFilesFromTable';
export { default as useSetFilesToDelete } from './useSetFilesToDelete';
export { default as useFileHandling } from './useFileHandling';
export { default as useFileHandling, useFileHandlingNoChatContext } from './useFileHandling';
export { default as useFileDeletion } from './useFileDeletion';
export { default as useUpdateFiles } from './useUpdateFiles';
export { default as useDragHelpers } from './useDragHelpers';

View file

@ -5,6 +5,7 @@ export { default as useSubmitMessage } from './useSubmitMessage';
export type { ContentMetadataResult } from './useContentMetadata';
export { default as useExpandCollapse } from './useExpandCollapse';
export { default as useMessageActions } from './useMessageActions';
export { default as useMemoizedChatContext } from './useMemoizedChatContext';
export { default as useMessageProcess } from './useMessageProcess';
export { default as useMessageHelpers } from './useMessageHelpers';
export { default as useCopyToClipboard } from './useCopyToClipboard';

View file

@ -2,7 +2,7 @@ import { useMemo } from 'react';
import { useRecoilValue } from 'recoil';
import { Tools } from 'librechat-data-provider';
import type { TAttachment, UIResource } from 'librechat-data-provider';
import { useMessagesOperations } from '~/Providers';
import { useOptionalMessagesOperations } from '~/Providers';
import store from '~/store';
/**
@ -16,7 +16,7 @@ import store from '~/store';
export function useConversationUIResources(
conversationId: string | undefined,
): Map<string, UIResource> {
const { getMessages } = useMessagesOperations();
const { getMessages } = useOptionalMessagesOperations();
const conversationAttachmentsMap = useRecoilValue(
store.conversationAttachmentsSelector(conversationId),

View file

@ -0,0 +1,80 @@
import { useRef, useMemo } from 'react';
import type { TMessage } from 'librechat-data-provider';
import type { TMessageChatContext } from '~/common/types';
import { useChatContext } from '~/Providers';
/**
* Creates a stable `TMessageChatContext` object for memo'd message components.
*
* Subscribes to `useChatContext()` internally (intended to be called from non-memo'd
* wrapper components like `Message` and `MessageContent`), then produces:
* - A `chatContext` object that stays referentially stable during streaming
* (uses a getter for `isSubmitting` backed by a ref)
* - A stable `conversation` reference that only updates when rendering-relevant fields change
* - An `effectiveIsSubmitting` value (false for non-latest messages)
*/
export default function useMemoizedChatContext(
message: TMessage | null | undefined,
isSubmitting: boolean,
) {
const chatCtx = useChatContext();
const isSubmittingRef = useRef(isSubmitting);
isSubmittingRef.current = isSubmitting;
/**
* Stabilize conversation: only update when rendering-relevant fields change,
* not on every metadata update (e.g., title generation).
*/
const stableConversation = useMemo(
() => chatCtx.conversation,
// eslint-disable-next-line react-hooks/exhaustive-deps
[
chatCtx.conversation?.conversationId,
chatCtx.conversation?.endpoint,
chatCtx.conversation?.endpointType,
chatCtx.conversation?.model,
chatCtx.conversation?.agent_id,
chatCtx.conversation?.assistant_id,
],
);
/**
* `isSubmitting` is included in deps so that chatContext gets a new reference
* when streaming starts/ends (2x per session). This ensures HoverButtons
* re-renders to update regenerate/edit button visibility via useGenerationsByLatest.
* The getter pattern is still valuable: callbacks reading chatContext.isSubmitting
* at call-time always get the current value even between these re-renders.
*/
const chatContext: TMessageChatContext = useMemo(
() => ({
ask: chatCtx.ask,
index: chatCtx.index,
regenerate: chatCtx.regenerate,
conversation: stableConversation,
latestMessageId: chatCtx.latestMessageId,
latestMessageDepth: chatCtx.latestMessageDepth,
handleContinue: chatCtx.handleContinue,
get isSubmitting() {
return isSubmittingRef.current;
},
}),
// eslint-disable-next-line react-hooks/exhaustive-deps
[
chatCtx.ask,
chatCtx.index,
chatCtx.regenerate,
stableConversation,
chatCtx.latestMessageId,
chatCtx.latestMessageDepth,
chatCtx.handleContinue,
isSubmitting, // intentional: forces new reference on streaming start/end so HoverButtons re-renders
],
);
const messageId = message?.messageId ?? null;
const isLatestMessage = messageId === chatCtx.latestMessageId;
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
return { chatContext, effectiveIsSubmitting };
}

View file

@ -11,7 +11,8 @@ import {
TUpdateFeedbackRequest,
} from 'librechat-data-provider';
import type { TMessageProps } from '~/common';
import { useChatContext, useAssistantsMapContext, useAgentsMapContext } from '~/Providers';
import type { TMessageChatContext } from '~/common/types';
import { useAssistantsMapContext, useAgentsMapContext } from '~/Providers';
import useCopyToClipboard from './useCopyToClipboard';
import { useAuthContext } from '~/hooks/AuthContext';
import { useGetAddedConvo } from '~/hooks/Chat';
@ -23,24 +24,33 @@ export type TMessageActions = Pick<
'message' | 'currentEditId' | 'setCurrentEditId'
> & {
searchResults?: { [key: string]: SearchResultData };
/**
* Stable context object passed from wrapper components to avoid subscribing
* to ChatContext inside memo'd components (which would bypass React.memo).
* The `isSubmitting` property uses a getter backed by a ref, so it always
* returns the current value at call-time without triggering re-renders.
*/
chatContext: TMessageChatContext;
};
export default function useMessageActions(props: TMessageActions) {
const localize = useLocalize();
const { user } = useAuthContext();
const UsernameDisplay = useRecoilValue<boolean>(store.UsernameDisplay);
const { message, currentEditId, setCurrentEditId, searchResults } = props;
const { message, currentEditId, setCurrentEditId, searchResults, chatContext } = props;
const {
ask,
index,
regenerate,
isSubmitting,
conversation,
latestMessageId,
latestMessageDepth,
handleContinue,
} = useChatContext();
// NOTE: isSubmitting is intentionally NOT destructured here.
// chatContext.isSubmitting is a getter backed by a ref — destructuring
// would capture a one-time snapshot. Always access via chatContext.isSubmitting.
} = chatContext;
const getAddedConvo = useGetAddedConvo();
@ -98,13 +108,18 @@ export default function useMessageActions(props: TMessageActions) {
}
}, [agentsMap, conversation?.agent_id, conversation?.endpoint, message?.model]);
/**
* chatContext.isSubmitting is a getter backed by the wrapper's ref,
* so it always returns the current value at call-time even for
* non-latest messages that don't re-render during streaming.
*/
const regenerateMessage = useCallback(() => {
if ((isSubmitting && isCreatedByUser === true) || !message) {
if ((chatContext.isSubmitting && isCreatedByUser === true) || !message) {
return;
}
regenerate(message, { addedConvo: getAddedConvo() });
}, [isSubmitting, isCreatedByUser, message, regenerate, getAddedConvo]);
}, [chatContext, isCreatedByUser, message, regenerate, getAddedConvo]);
const copyToClipboard = useCopyToClipboard({ text, content, searchResults });

View file

@ -1,6 +1,6 @@
import throttle from 'lodash/throttle';
import { Constants } from 'librechat-data-provider';
import { useEffect, useRef, useCallback, useMemo } from 'react';
import { useEffect, useRef, useMemo } from 'react';
import type { TMessage } from 'librechat-data-provider';
import { getTextKey, TEXT_KEY_DIVIDER, logger } from '~/utils';
import { useMessagesViewContext } from '~/Providers';
@ -56,24 +56,25 @@ export default function useMessageProcess({ message }: { message?: TMessage | nu
}
}, [hasNoChildren, message, setLatestMessage, conversation?.conversationId]);
const handleScroll = useCallback(
(event: unknown | TouchEvent | WheelEvent) => {
throttle(() => {
/** Use ref for isSubmitting to stabilize handleScroll across isSubmitting changes */
const isSubmittingRef = useRef(isSubmitting);
isSubmittingRef.current = isSubmitting;
const handleScroll = useMemo(
() =>
throttle((event: unknown) => {
logger.log(
'message_scrolling',
`useMessageProcess: setting abort scroll to ${isSubmitting}, handleScroll event`,
`useMessageProcess: setting abort scroll to ${isSubmittingRef.current}, handleScroll event`,
event,
);
if (isSubmitting) {
setAbortScroll(true);
} else {
setAbortScroll(false);
}
}, 500)();
},
[isSubmitting, setAbortScroll],
setAbortScroll(isSubmittingRef.current);
}, 500),
[setAbortScroll],
);
useEffect(() => () => handleScroll.cancel(), [handleScroll]);
return {
handleScroll,
isSubmitting,

View file

@ -1,5 +1,5 @@
import { renderHook, act } from '@testing-library/react';
import { Constants, ErrorTypes, LocalStorageKeys } from 'librechat-data-provider';
import { Constants, LocalStorageKeys } from 'librechat-data-provider';
import type { TSubmission } from 'librechat-data-provider';
type SSEEventListener = (e: Partial<MessageEvent> & { responseCode?: number }) => void;
@ -34,7 +34,13 @@ jest.mock('sse.js', () => ({
}));
const mockSetQueryData = jest.fn();
const mockQueryClient = { setQueryData: mockSetQueryData };
const mockInvalidateQueries = jest.fn();
const mockRemoveQueries = jest.fn();
const mockQueryClient = {
setQueryData: mockSetQueryData,
invalidateQueries: mockInvalidateQueries,
removeQueries: mockRemoveQueries,
};
jest.mock('@tanstack/react-query', () => ({
...jest.requireActual('@tanstack/react-query'),
@ -63,6 +69,7 @@ jest.mock('~/data-provider', () => ({
useGetStartupConfig: () => ({ data: { balance: { enabled: false } } }),
useGetUserBalance: () => ({ refetch: jest.fn() }),
queueTitleGeneration: jest.fn(),
streamStatusQueryKey: (conversationId: string) => ['streamStatus', conversationId],
}));
const mockErrorHandler = jest.fn();
@ -162,6 +169,11 @@ describe('useResumableSSE - 404 error path', () => {
beforeEach(() => {
mockSSEInstances.length = 0;
localStorage.clear();
mockErrorHandler.mockClear();
mockClearStepMaps.mockClear();
mockSetIsSubmitting.mockClear();
mockInvalidateQueries.mockClear();
mockRemoveQueries.mockClear();
});
const seedDraft = (conversationId: string) => {
@ -200,19 +212,18 @@ describe('useResumableSSE - 404 error path', () => {
unmount();
});
it('calls errorHandler with STREAM_EXPIRED error type on 404', async () => {
it('invalidates message cache and clears stream status on 404 instead of showing error', async () => {
const { unmount } = await render404Scenario(CONV_ID);
expect(mockErrorHandler).toHaveBeenCalledTimes(1);
const call = mockErrorHandler.mock.calls[0][0];
expect(call.data).toBeDefined();
const parsed = JSON.parse(call.data.text);
expect(parsed.type).toBe(ErrorTypes.STREAM_EXPIRED);
expect(call.submission).toEqual(
expect.objectContaining({
conversation: expect.objectContaining({ conversationId: CONV_ID }),
}),
);
expect(mockErrorHandler).not.toHaveBeenCalled();
expect(mockInvalidateQueries).toHaveBeenCalledWith({
queryKey: ['messages', CONV_ID],
});
expect(mockRemoveQueries).toHaveBeenCalledWith({
queryKey: ['streamStatus', CONV_ID],
});
expect(mockClearStepMaps).toHaveBeenCalled();
expect(mockSetIsSubmitting).toHaveBeenCalledWith(false);
unmount();
});

View file

@ -16,7 +16,12 @@ import {
} from 'librechat-data-provider';
import type { TMessage, TPayload, TSubmission, EventSubmission } from 'librechat-data-provider';
import type { EventHandlerParams } from './useEventHandlers';
import { useGetStartupConfig, useGetUserBalance, queueTitleGeneration } from '~/data-provider';
import {
useGetUserBalance,
useGetStartupConfig,
queueTitleGeneration,
streamStatusQueryKey,
} from '~/data-provider';
import type { ActiveJobsResponse } from '~/data-provider';
import { useAuthContext } from '~/hooks/AuthContext';
import useEventHandlers from './useEventHandlers';
@ -343,18 +348,20 @@ export default function useResumableSSE(
/* @ts-ignore - sse.js types don't expose responseCode */
const responseCode = e.responseCode;
// 404 means job doesn't exist (completed/deleted) - don't retry
// 404 → job completed & was cleaned up; messages are persisted in DB.
// Invalidate cache once so react-query refetches instead of showing an error.
if (responseCode === 404) {
console.log('[ResumableSSE] Stream not found (404) - job completed or expired');
const convoId = currentSubmission.conversation?.conversationId;
console.log('[ResumableSSE] Stream 404, invalidating messages for:', convoId);
sse.close();
removeActiveJob(currentStreamId);
clearAllDrafts(currentSubmission.conversation?.conversationId);
errorHandler({
data: {
text: JSON.stringify({ type: ErrorTypes.STREAM_EXPIRED }),
} as unknown as Parameters<typeof errorHandler>[0]['data'],
submission: currentSubmission as EventSubmission,
});
clearAllDrafts(convoId);
clearStepMaps();
if (convoId) {
queryClient.invalidateQueries({ queryKey: [QueryKeys.messages, convoId] });
queryClient.removeQueries({ queryKey: streamStatusQueryKey(convoId) });
}
setIsSubmitting(false);
setShowStopButton(false);
setStreamId(null);
reconnectAttemptRef.current = 0;
@ -544,6 +551,7 @@ export default function useResumableSSE(
startupConfig?.balance?.enabled,
balanceQuery,
removeActiveJob,
queryClient,
],
);

View file

@ -125,7 +125,11 @@ export default function useResumeOnLoad(
conversationId !== Constants.NEW_CONVO &&
processedConvoRef.current !== conversationId; // Don't re-check processed convos
const { data: streamStatus, isSuccess } = useStreamStatus(conversationId, shouldCheck);
const {
data: streamStatus,
isSuccess,
isFetching,
} = useStreamStatus(conversationId, shouldCheck);
useEffect(() => {
console.log('[ResumeOnLoad] Effect check', {
@ -135,6 +139,7 @@ export default function useResumeOnLoad(
hasCurrentSubmission: !!currentSubmission,
currentSubmissionConvoId: currentSubmission?.conversation?.conversationId,
isSuccess,
isFetching,
streamStatusActive: streamStatus?.active,
streamStatusStreamId: streamStatus?.streamId,
processedConvoRef: processedConvoRef.current,
@ -171,8 +176,9 @@ export default function useResumeOnLoad(
);
}
// Wait for stream status query to complete
if (!isSuccess || !streamStatus) {
// Wait for stream status query to complete (including background refetches
// that may replace a stale cached result with fresh data)
if (!isSuccess || !streamStatus || isFetching) {
console.log('[ResumeOnLoad] Waiting for stream status query');
return;
}
@ -183,15 +189,12 @@ export default function useResumeOnLoad(
return;
}
// Check if there's an active job to resume
// DON'T mark as processed here - only mark when we actually create a submission
// This prevents stale cache data from blocking subsequent resume attempts
if (!streamStatus.active || !streamStatus.streamId) {
console.log('[ResumeOnLoad] No active job to resume for:', conversationId);
processedConvoRef.current = conversationId;
return;
}
// Mark as processed NOW - we verified there's an active job and will create submission
processedConvoRef.current = conversationId;
console.log('[ResumeOnLoad] Found active job, creating submission...', {
@ -241,6 +244,7 @@ export default function useResumeOnLoad(
submissionConvoId,
currentSubmission,
isSuccess,
isFetching,
streamStatus,
getMessages,
setSubmission,

View file

@ -8,6 +8,7 @@ export default {
'\\.helper\\.ts$',
'\\.helper\\.d\\.ts$',
'/__tests__/helpers/',
'\\.manual\\.spec\\.[jt]sx?$',
],
coverageReporters: ['text', 'cobertura'],
testResultsProcessor: 'jest-junit',

Some files were not shown because too many files have changed in this diff Show more