mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-04-07 00:15:23 +02:00
Merge branch 'main' into fix/client-image-resize-threshold
This commit is contained in:
commit
4b14978473
246 changed files with 15929 additions and 2880 deletions
3
.gitattributes
vendored
Normal file
3
.gitattributes
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# Force LF line endings for shell scripts and git hooks (required for cross-platform compatibility)
|
||||
.husky/* text eol=lf
|
||||
*.sh text eol=lf
|
||||
59
.github/workflows/backend-review.yml
vendored
59
.github/workflows/backend-review.yml
vendored
|
|
@ -97,6 +97,65 @@ jobs:
|
|||
path: packages/api/dist
|
||||
retention-days: 2
|
||||
|
||||
typecheck:
|
||||
name: TypeScript type checks
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Use Node.js 20.19
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20.19'
|
||||
|
||||
- name: Restore node_modules cache
|
||||
id: cache-node-modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
node_modules
|
||||
api/node_modules
|
||||
packages/api/node_modules
|
||||
packages/data-provider/node_modules
|
||||
packages/data-schemas/node_modules
|
||||
key: node-modules-backend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache-node-modules.outputs.cache-hit != 'true'
|
||||
run: npm ci
|
||||
|
||||
- name: Download data-provider build
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: build-data-provider
|
||||
path: packages/data-provider/dist
|
||||
|
||||
- name: Download data-schemas build
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: build-data-schemas
|
||||
path: packages/data-schemas/dist
|
||||
|
||||
- name: Download api build
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: build-api
|
||||
path: packages/api/dist
|
||||
|
||||
- name: Type check data-provider
|
||||
run: npx tsc --noEmit -p packages/data-provider/tsconfig.json
|
||||
|
||||
- name: Type check data-schemas
|
||||
run: npx tsc --noEmit -p packages/data-schemas/tsconfig.json
|
||||
|
||||
- name: Type check @librechat/api
|
||||
run: npx tsc --noEmit -p packages/api/tsconfig.json
|
||||
|
||||
- name: Type check @librechat/client
|
||||
run: npx tsc --noEmit -p packages/client/tsconfig.json
|
||||
|
||||
circular-deps:
|
||||
name: Circular dependency checks
|
||||
needs: build
|
||||
|
|
|
|||
|
|
@ -29,6 +29,12 @@ The source code for `@librechat/agents` (major backend dependency, same team) is
|
|||
|
||||
## Code Style
|
||||
|
||||
### Naming and File Organization
|
||||
|
||||
- **Single-word file names** whenever possible (e.g., `permissions.ts`, `capabilities.ts`, `service.ts`).
|
||||
- When multiple words are needed, prefer grouping related modules under a **single-word directory** rather than using multi-word file names (e.g., `admin/capabilities.ts` not `adminCapabilities.ts`).
|
||||
- The directory already provides context — `app/service.ts` not `app/appConfigService.ts`.
|
||||
|
||||
### Structure and Clarity
|
||||
|
||||
- **Never-nesting**: early returns, flat code, minimal indentation. Break complex operations into well-named helpers.
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@
|
|||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://railway.com/deploy/b5k2mn?referralCode=HI9hWz">
|
||||
<a href="https://railway.com/deploy/librechat-official?referralCode=HI9hWz&utm_medium=integration&utm_source=readme&utm_campaign=librechat">
|
||||
<img src="https://railway.com/button.svg" alt="Deploy on Railway" height="30">
|
||||
</a>
|
||||
<a href="https://zeabur.com/templates/0X2ZY8">
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
<!-- Last synced with README.md: 2026-03-20 (e442984364db02163f3cc3ecb7b2ee5efba66fb9) -->
|
||||
<!-- Last synced with README.md: 2026-03-28 (cae3888) -->
|
||||
|
||||
<p align="center">
|
||||
<a href="https://librechat.ai">
|
||||
|
|
@ -34,7 +34,7 @@
|
|||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://railway.com/deploy/b5k2mn?referralCode=HI9hWz">
|
||||
<a href="https://railway.com/deploy/librechat-official?referralCode=HI9hWz&utm_medium=integration&utm_source=readme&utm_campaign=librechat">
|
||||
<img src="https://railway.com/button.svg" alt="Deploy on Railway" height="30">
|
||||
</a>
|
||||
<a href="https://zeabur.com/templates/0X2ZY8">
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@ class BaseClient {
|
|||
constructor(apiKey, options = {}) {
|
||||
this.apiKey = apiKey;
|
||||
this.sender = options.sender ?? 'AI';
|
||||
this.contextStrategy = null;
|
||||
this.currentDateString = new Date().toLocaleDateString('en-us', {
|
||||
year: 'numeric',
|
||||
month: 'long',
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ const {
|
|||
buildImageToolContext,
|
||||
buildWebSearchContext,
|
||||
} = require('@librechat/api');
|
||||
const { getMCPServersRegistry } = require('~/config');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
|
|
@ -39,12 +38,13 @@ const {
|
|||
createGeminiImageTool,
|
||||
createOpenAIImageTools,
|
||||
} = require('../');
|
||||
const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
|
||||
const { createMCPTool, createMCPTools, resolveConfigServers } = require('~/server/services/MCP');
|
||||
const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch');
|
||||
const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
const { createMCPTool, createMCPTools } = require('~/server/services/MCP');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { getMCPServerTools } = require('~/server/services/Config');
|
||||
const { getMCPServersRegistry } = require('~/config');
|
||||
const { getRoleByName } = require('~/models');
|
||||
|
||||
/**
|
||||
|
|
@ -256,6 +256,12 @@ const loadTools = async ({
|
|||
const toolContextMap = {};
|
||||
const requestedMCPTools = {};
|
||||
|
||||
/** Resolve config-source servers for the current user/tenant context */
|
||||
let configServers;
|
||||
if (tools.some((tool) => tool && mcpToolPattern.test(tool))) {
|
||||
configServers = await resolveConfigServers(options.req);
|
||||
}
|
||||
|
||||
for (const tool of tools) {
|
||||
if (tool === Tools.execute_code) {
|
||||
requestedTools[tool] = async () => {
|
||||
|
|
@ -341,7 +347,7 @@ const loadTools = async ({
|
|||
continue;
|
||||
}
|
||||
const serverConfig = serverName
|
||||
? await getMCPServersRegistry().getServerConfig(serverName, user)
|
||||
? await getMCPServersRegistry().getServerConfig(serverName, user, configServers)
|
||||
: null;
|
||||
if (!serverConfig) {
|
||||
logger.warn(
|
||||
|
|
@ -419,6 +425,7 @@ const loadTools = async ({
|
|||
let index = -1;
|
||||
const failedMCPServers = new Set();
|
||||
const safeUser = createSafeUser(options.req?.user);
|
||||
|
||||
for (const [serverName, toolConfigs] of Object.entries(requestedMCPTools)) {
|
||||
index++;
|
||||
/** @type {LCAvailableTools} */
|
||||
|
|
@ -433,6 +440,7 @@ const loadTools = async ({
|
|||
signal,
|
||||
user: safeUser,
|
||||
userMCPAuthMap,
|
||||
configServers,
|
||||
res: options.res,
|
||||
streamId: options.req?._resumableStreamId || null,
|
||||
model: agent?.model ?? model,
|
||||
|
|
|
|||
|
|
@ -123,9 +123,6 @@ function disposeClient(client) {
|
|||
if (client.maxContextTokens) {
|
||||
client.maxContextTokens = null;
|
||||
}
|
||||
if (client.contextStrategy) {
|
||||
client.contextStrategy = null;
|
||||
}
|
||||
if (client.currentDateString) {
|
||||
client.currentDateString = null;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { logger, scopedCacheKey } = require('@librechat/data-schemas');
|
||||
const { loadDefaultModels, loadConfigModels } = require('~/server/services/Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
|
|
@ -9,7 +9,8 @@ const { getLogStores } = require('~/cache');
|
|||
*/
|
||||
const getModelsConfig = async (req) => {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
let modelsConfig = await cache.get(CacheKeys.MODELS_CONFIG);
|
||||
const cacheKey = scopedCacheKey(CacheKeys.MODELS_CONFIG);
|
||||
let modelsConfig = await cache.get(cacheKey);
|
||||
if (!modelsConfig) {
|
||||
modelsConfig = await loadModels(req);
|
||||
}
|
||||
|
|
@ -24,7 +25,8 @@ const getModelsConfig = async (req) => {
|
|||
*/
|
||||
async function loadModels(req) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cachedModelsConfig = await cache.get(CacheKeys.MODELS_CONFIG);
|
||||
const cacheKey = scopedCacheKey(CacheKeys.MODELS_CONFIG);
|
||||
const cachedModelsConfig = await cache.get(cacheKey);
|
||||
if (cachedModelsConfig) {
|
||||
return cachedModelsConfig;
|
||||
}
|
||||
|
|
@ -33,7 +35,7 @@ async function loadModels(req) {
|
|||
|
||||
const modelConfig = { ...defaultModelsConfig, ...customModelsConfig };
|
||||
|
||||
await cache.set(CacheKeys.MODELS_CONFIG, modelConfig);
|
||||
await cache.set(cacheKey, modelConfig);
|
||||
return modelConfig;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { logger, scopedCacheKey } = require('@librechat/data-schemas');
|
||||
const { getToolkitKey, checkPluginAuth, filterUniquePlugins } = require('@librechat/api');
|
||||
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
|
||||
const { availableTools, toolkits } = require('~/app/clients/tools');
|
||||
|
|
@ -9,13 +9,14 @@ const { getLogStores } = require('~/cache');
|
|||
const getAvailablePluginsController = async (req, res) => {
|
||||
try {
|
||||
const cache = getLogStores(CacheKeys.TOOL_CACHE);
|
||||
const cachedPlugins = await cache.get(CacheKeys.PLUGINS);
|
||||
const pluginsCacheKey = scopedCacheKey(CacheKeys.PLUGINS);
|
||||
const cachedPlugins = await cache.get(pluginsCacheKey);
|
||||
if (cachedPlugins) {
|
||||
res.status(200).json(cachedPlugins);
|
||||
return;
|
||||
}
|
||||
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId });
|
||||
/** @type {{ filteredTools: string[], includedTools: string[] }} */
|
||||
const { filteredTools = [], includedTools = [] } = appConfig;
|
||||
/** @type {import('@librechat/api').LCManifestTool[]} */
|
||||
|
|
@ -37,7 +38,7 @@ const getAvailablePluginsController = async (req, res) => {
|
|||
plugins = plugins.filter((plugin) => !filteredTools.includes(plugin.pluginKey));
|
||||
}
|
||||
|
||||
await cache.set(CacheKeys.PLUGINS, plugins);
|
||||
await cache.set(pluginsCacheKey, plugins);
|
||||
res.status(200).json(plugins);
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: error.message });
|
||||
|
|
@ -64,9 +65,11 @@ const getAvailableTools = async (req, res) => {
|
|||
return res.status(401).json({ message: 'Unauthorized' });
|
||||
}
|
||||
const cache = getLogStores(CacheKeys.TOOL_CACHE);
|
||||
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
|
||||
const toolsCacheKey = scopedCacheKey(CacheKeys.TOOLS);
|
||||
const cachedToolsArray = await cache.get(toolsCacheKey);
|
||||
|
||||
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
|
||||
const appConfig =
|
||||
req.config ?? (await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }));
|
||||
|
||||
// Return early if we have cached tools
|
||||
if (cachedToolsArray != null) {
|
||||
|
|
@ -114,7 +117,7 @@ const getAvailableTools = async (req, res) => {
|
|||
}
|
||||
|
||||
const finalTools = filterUniquePlugins(toolsOutput);
|
||||
await cache.set(CacheKeys.TOOLS, finalTools);
|
||||
await cache.set(toolsCacheKey, finalTools);
|
||||
|
||||
res.status(200).json(finalTools);
|
||||
} catch (error) {
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ jest.mock('@librechat/data-schemas', () => ({
|
|||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
},
|
||||
scopedCacheKey: jest.fn((key) => key),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ const { getLogStores } = require('~/cache');
|
|||
const db = require('~/models');
|
||||
|
||||
const getUserController = async (req, res) => {
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId });
|
||||
/** @type {IUser} */
|
||||
const userData = req.user.toObject != null ? req.user.toObject() : { ...req.user };
|
||||
/**
|
||||
|
|
@ -165,7 +165,7 @@ const deleteUserMcpServers = async (userId) => {
|
|||
};
|
||||
|
||||
const updateUserPluginsController = async (req, res) => {
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId });
|
||||
const { user } = req;
|
||||
const { pluginKey, action, auth, isEntityTool } = req.body;
|
||||
try {
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ const {
|
|||
const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { createContextHandlers } = require('~/app/clients/prompts');
|
||||
const { resolveConfigServers } = require('~/server/services/MCP');
|
||||
const { getMCPServerTools } = require('~/server/services/Config');
|
||||
const BaseClient = require('~/app/clients/BaseClient');
|
||||
const { getMCPManager } = require('~/config');
|
||||
|
|
@ -377,6 +378,9 @@ class AgentClient extends BaseClient {
|
|||
*/
|
||||
const ephemeralAgent = this.options.req.body.ephemeralAgent;
|
||||
const mcpManager = getMCPManager();
|
||||
|
||||
const configServers = await resolveConfigServers(this.options.req);
|
||||
|
||||
await Promise.all(
|
||||
allAgents.map(({ agent, agentId }) =>
|
||||
applyContextToAgent({
|
||||
|
|
@ -384,6 +388,7 @@ class AgentClient extends BaseClient {
|
|||
agentId,
|
||||
logger,
|
||||
mcpManager,
|
||||
configServers,
|
||||
sharedRunContext,
|
||||
ephemeralAgent: agentId === this.options.agent.id ? ephemeralAgent : undefined,
|
||||
}),
|
||||
|
|
|
|||
|
|
@ -22,6 +22,10 @@ jest.mock('~/server/services/Config', () => ({
|
|||
getMCPServerTools: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/MCP', () => ({
|
||||
resolveConfigServers: jest.fn().mockResolvedValue({}),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
getAgent: jest.fn(),
|
||||
getRoleByName: jest.fn(),
|
||||
|
|
@ -1315,7 +1319,7 @@ describe('AgentClient - titleConvo', () => {
|
|||
});
|
||||
|
||||
// Verify formatInstructionsForContext was called with correct server names
|
||||
expect(mockFormatInstructions).toHaveBeenCalledWith(['server1', 'server2']);
|
||||
expect(mockFormatInstructions).toHaveBeenCalledWith(['server1', 'server2'], {});
|
||||
|
||||
// Verify the instructions do NOT contain [object Promise]
|
||||
expect(client.options.agent.instructions).not.toContain('[object Promise]');
|
||||
|
|
@ -1355,10 +1359,10 @@ describe('AgentClient - titleConvo', () => {
|
|||
});
|
||||
|
||||
// Verify formatInstructionsForContext was called with ephemeral server names
|
||||
expect(mockFormatInstructions).toHaveBeenCalledWith([
|
||||
'ephemeral-server1',
|
||||
'ephemeral-server2',
|
||||
]);
|
||||
expect(mockFormatInstructions).toHaveBeenCalledWith(
|
||||
['ephemeral-server1', 'ephemeral-server2'],
|
||||
{},
|
||||
);
|
||||
|
||||
// Verify no [object Promise] in instructions
|
||||
expect(client.options.agent.instructions).not.toContain('[object Promise]');
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ const {
|
|||
isMCPInspectionFailedError,
|
||||
} = require('@librechat/api');
|
||||
const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider');
|
||||
const { resolveConfigServers, resolveAllMcpConfigs } = require('~/server/services/MCP');
|
||||
const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config');
|
||||
const { getMCPManager, getMCPServersRegistry } = require('~/config');
|
||||
|
||||
|
|
@ -57,7 +58,7 @@ function handleMCPError(error, res) {
|
|||
}
|
||||
|
||||
/**
|
||||
* Get all MCP tools available to the user
|
||||
* Get all MCP tools available to the user.
|
||||
*/
|
||||
const getMCPTools = async (req, res) => {
|
||||
try {
|
||||
|
|
@ -67,10 +68,10 @@ const getMCPTools = async (req, res) => {
|
|||
return res.status(401).json({ message: 'Unauthorized' });
|
||||
}
|
||||
|
||||
const mcpConfig = await getMCPServersRegistry().getAllServerConfigs(userId);
|
||||
const configuredServers = mcpConfig ? Object.keys(mcpConfig) : [];
|
||||
const mcpConfig = await resolveAllMcpConfigs(userId, req.user);
|
||||
const configuredServers = Object.keys(mcpConfig);
|
||||
|
||||
if (!mcpConfig || Object.keys(mcpConfig).length == 0) {
|
||||
if (!configuredServers.length) {
|
||||
return res.status(200).json({ servers: {} });
|
||||
}
|
||||
|
||||
|
|
@ -115,14 +116,11 @@ const getMCPTools = async (req, res) => {
|
|||
try {
|
||||
const serverTools = serverToolsMap.get(serverName);
|
||||
|
||||
// Get server config once
|
||||
const serverConfig = mcpConfig[serverName];
|
||||
const rawServerConfig = await getMCPServersRegistry().getServerConfig(serverName, userId);
|
||||
|
||||
// Initialize server object with all server-level data
|
||||
const server = {
|
||||
name: serverName,
|
||||
icon: rawServerConfig?.iconPath || '',
|
||||
icon: serverConfig?.iconPath || '',
|
||||
authenticated: true,
|
||||
authConfig: [],
|
||||
tools: [],
|
||||
|
|
@ -183,7 +181,7 @@ const getMCPServersList = async (req, res) => {
|
|||
return res.status(401).json({ message: 'Unauthorized' });
|
||||
}
|
||||
|
||||
const serverConfigs = await getMCPServersRegistry().getAllServerConfigs(userId);
|
||||
const serverConfigs = await resolveAllMcpConfigs(userId, req.user);
|
||||
return res.json(redactAllServerSecrets(serverConfigs));
|
||||
} catch (error) {
|
||||
logger.error('[getMCPServersList]', error);
|
||||
|
|
@ -237,7 +235,12 @@ const getMCPServerById = async (req, res) => {
|
|||
if (!serverName) {
|
||||
return res.status(400).json({ message: 'Server name is required' });
|
||||
}
|
||||
const parsedConfig = await getMCPServersRegistry().getServerConfig(serverName, userId);
|
||||
const configServers = await resolveConfigServers(req);
|
||||
const parsedConfig = await getMCPServersRegistry().getServerConfig(
|
||||
serverName,
|
||||
userId,
|
||||
configServers,
|
||||
);
|
||||
|
||||
if (!parsedConfig) {
|
||||
return res.status(404).json({ message: 'MCP server not found' });
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ const express = require('express');
|
|||
const passport = require('passport');
|
||||
const compression = require('compression');
|
||||
const cookieParser = require('cookie-parser');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const mongoSanitize = require('express-mongo-sanitize');
|
||||
const { logger, runAsSystem } = require('@librechat/data-schemas');
|
||||
const {
|
||||
isEnabled,
|
||||
apiNotFound,
|
||||
|
|
@ -21,6 +21,7 @@ const {
|
|||
createStreamServices,
|
||||
initializeFileStorage,
|
||||
updateInterfacePermissions,
|
||||
preAuthTenantMiddleware,
|
||||
} = require('@librechat/api');
|
||||
const { connectDb, indexSync } = require('~/db');
|
||||
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
|
||||
|
|
@ -59,11 +60,20 @@ const startServer = async () => {
|
|||
app.disable('x-powered-by');
|
||||
app.set('trust proxy', trusted_proxy);
|
||||
|
||||
await seedDatabase();
|
||||
const appConfig = await getAppConfig();
|
||||
if (isEnabled(process.env.TENANT_ISOLATION_STRICT)) {
|
||||
logger.warn(
|
||||
'[Security] TENANT_ISOLATION_STRICT is active. Ensure your reverse proxy strips or sets ' +
|
||||
'the X-Tenant-Id header — untrusted clients must not be able to set it directly.',
|
||||
);
|
||||
}
|
||||
|
||||
await runAsSystem(seedDatabase);
|
||||
const appConfig = await getAppConfig({ baseOnly: true });
|
||||
initializeFileStorage(appConfig);
|
||||
await performStartupChecks(appConfig);
|
||||
await updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions });
|
||||
await runAsSystem(async () => {
|
||||
await performStartupChecks(appConfig);
|
||||
await updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions });
|
||||
});
|
||||
|
||||
const indexPath = path.join(appConfig.paths.dist, 'index.html');
|
||||
let indexHTML = fs.readFileSync(indexPath, 'utf8');
|
||||
|
|
@ -137,10 +147,15 @@ const startServer = async () => {
|
|||
/* Per-request capability cache — must be registered before any route that calls hasCapability */
|
||||
app.use(capabilityContextMiddleware);
|
||||
|
||||
app.use('/oauth', routes.oauth);
|
||||
/* Pre-auth tenant context for unauthenticated routes that need tenant scoping.
|
||||
* The reverse proxy / auth gateway sets `X-Tenant-Id` header for multi-tenant deployments. */
|
||||
app.use('/oauth', preAuthTenantMiddleware, routes.oauth);
|
||||
/* API Endpoints */
|
||||
app.use('/api/auth', routes.auth);
|
||||
app.use('/api/auth', preAuthTenantMiddleware, routes.auth);
|
||||
app.use('/api/admin', routes.adminAuth);
|
||||
app.use('/api/admin/config', routes.adminConfig);
|
||||
app.use('/api/admin/groups', routes.adminGroups);
|
||||
app.use('/api/admin/roles', routes.adminRoles);
|
||||
app.use('/api/actions', routes.actions);
|
||||
app.use('/api/keys', routes.keys);
|
||||
app.use('/api/api-keys', routes.apiKeys);
|
||||
|
|
@ -154,11 +169,11 @@ const startServer = async () => {
|
|||
app.use('/api/endpoints', routes.endpoints);
|
||||
app.use('/api/balance', routes.balance);
|
||||
app.use('/api/models', routes.models);
|
||||
app.use('/api/config', routes.config);
|
||||
app.use('/api/config', preAuthTenantMiddleware, routes.config);
|
||||
app.use('/api/assistants', routes.assistants);
|
||||
app.use('/api/files', await routes.files.initialize());
|
||||
app.use('/images/', createValidateImageRequest(appConfig.secureImageLinks), routes.staticRoute);
|
||||
app.use('/api/share', routes.share);
|
||||
app.use('/api/share', preAuthTenantMiddleware, routes.share);
|
||||
app.use('/api/roles', routes.roles);
|
||||
app.use('/api/agents', routes.agents);
|
||||
app.use('/api/banner', routes.banner);
|
||||
|
|
@ -204,8 +219,10 @@ const startServer = async () => {
|
|||
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
|
||||
}
|
||||
|
||||
await initializeMCPs();
|
||||
await initializeOAuthReconnectManager();
|
||||
await runAsSystem(async () => {
|
||||
await initializeMCPs();
|
||||
await initializeOAuthReconnectManager();
|
||||
});
|
||||
await checkMigrations();
|
||||
|
||||
// Configure stream services (auto-detects Redis from USE_REDIS env var)
|
||||
|
|
|
|||
116
api/server/middleware/__tests__/requireJwtAuth.spec.js
Normal file
116
api/server/middleware/__tests__/requireJwtAuth.spec.js
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
/**
|
||||
* Integration test: verifies that requireJwtAuth chains tenantContextMiddleware
|
||||
* after successful passport authentication, so ALS tenant context is set for
|
||||
* all downstream middleware and route handlers.
|
||||
*
|
||||
* requireJwtAuth must chain tenantContextMiddleware after passport populates
|
||||
* req.user (not at global app.use() scope where req.user is undefined).
|
||||
* If the chaining is removed, these tests fail.
|
||||
*/
|
||||
|
||||
const { getTenantId } = require('@librechat/data-schemas');
|
||||
|
||||
// ── Mocks ──────────────────────────────────────────────────────────────
|
||||
|
||||
let mockPassportError = null;
|
||||
|
||||
jest.mock('passport', () => ({
|
||||
authenticate: jest.fn(() => {
|
||||
return (req, _res, done) => {
|
||||
if (mockPassportError) {
|
||||
return done(mockPassportError);
|
||||
}
|
||||
if (req._mockUser) {
|
||||
req.user = req._mockUser;
|
||||
}
|
||||
done();
|
||||
};
|
||||
}),
|
||||
}));
|
||||
|
||||
// Mock @librechat/api — the real tenantContextMiddleware is TS and cannot be
|
||||
// required directly from CJS tests. This thin wrapper mirrors the real logic
|
||||
// (read req.user.tenantId, call tenantStorage.run) using the same data-schemas
|
||||
// primitives. The real implementation is covered by packages/api tenant.spec.ts.
|
||||
jest.mock('@librechat/api', () => {
|
||||
const { tenantStorage } = require('@librechat/data-schemas');
|
||||
return {
|
||||
isEnabled: jest.fn(() => false),
|
||||
tenantContextMiddleware: (req, res, next) => {
|
||||
const tenantId = req.user?.tenantId;
|
||||
if (!tenantId) {
|
||||
return next();
|
||||
}
|
||||
return tenantStorage.run({ tenantId }, async () => next());
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
// ── Helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
const requireJwtAuth = require('../requireJwtAuth');
|
||||
|
||||
function mockReq(user) {
|
||||
return { headers: {}, _mockUser: user };
|
||||
}
|
||||
|
||||
function mockRes() {
|
||||
return { status: jest.fn().mockReturnThis(), json: jest.fn().mockReturnThis() };
|
||||
}
|
||||
|
||||
/** Runs requireJwtAuth and returns the tenantId observed inside next(). */
|
||||
function runAuth(user) {
|
||||
return new Promise((resolve) => {
|
||||
const req = mockReq(user);
|
||||
const res = mockRes();
|
||||
requireJwtAuth(req, res, () => {
|
||||
resolve(getTenantId());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────────
|
||||
|
||||
describe('requireJwtAuth tenant context chaining', () => {
|
||||
afterEach(() => {
|
||||
mockPassportError = null;
|
||||
});
|
||||
|
||||
it('forwards passport errors to next() without entering tenant middleware', async () => {
|
||||
mockPassportError = new Error('JWT signature invalid');
|
||||
const req = mockReq(undefined);
|
||||
const res = mockRes();
|
||||
const err = await new Promise((resolve) => {
|
||||
requireJwtAuth(req, res, (e) => resolve(e));
|
||||
});
|
||||
expect(err).toBeInstanceOf(Error);
|
||||
expect(err.message).toBe('JWT signature invalid');
|
||||
expect(getTenantId()).toBeUndefined();
|
||||
});
|
||||
|
||||
it('sets ALS tenant context after passport auth succeeds', async () => {
|
||||
const tenantId = await runAuth({ tenantId: 'tenant-abc', role: 'user' });
|
||||
expect(tenantId).toBe('tenant-abc');
|
||||
});
|
||||
|
||||
it('ALS tenant context is NOT set when user has no tenantId', async () => {
|
||||
const tenantId = await runAuth({ role: 'user' });
|
||||
expect(tenantId).toBeUndefined();
|
||||
});
|
||||
|
||||
it('ALS tenant context is NOT set when user is undefined', async () => {
|
||||
const tenantId = await runAuth(undefined);
|
||||
expect(tenantId).toBeUndefined();
|
||||
});
|
||||
|
||||
it('concurrent requests get isolated tenant contexts', async () => {
|
||||
const results = await Promise.all(
|
||||
['tenant-1', 'tenant-2', 'tenant-3'].map((tid) => runAuth({ tenantId: tid, role: 'user' })),
|
||||
);
|
||||
expect(results).toEqual(['tenant-1', 'tenant-2', 'tenant-3']);
|
||||
});
|
||||
|
||||
it('ALS context is not set at top-level scope (outside any request)', () => {
|
||||
expect(getTenantId()).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
|
@ -18,6 +18,7 @@ const checkDomainAllowed = async (req, res, next) => {
|
|||
const email = req?.user?.email;
|
||||
const appConfig = await getAppConfig({
|
||||
role: req?.user?.role,
|
||||
tenantId: req?.user?.tenantId,
|
||||
});
|
||||
|
||||
if (email && !isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@ const { getAppConfig } = require('~/server/services/Config');
|
|||
const configMiddleware = async (req, res, next) => {
|
||||
try {
|
||||
const userRole = req.user?.role;
|
||||
req.config = await getAppConfig({ role: userRole });
|
||||
const userId = req.user?.id;
|
||||
const tenantId = req.user?.tenantId;
|
||||
req.config = await getAppConfig({ role: userRole, userId, tenantId });
|
||||
|
||||
next();
|
||||
} catch (error) {
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
const cookies = require('cookie');
|
||||
const passport = require('passport');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { isEnabled, tenantContextMiddleware } = require('@librechat/api');
|
||||
|
||||
// This middleware does not require authentication,
|
||||
// but if the user is authenticated, it will set the user object.
|
||||
// but if the user is authenticated, it will set the user object
|
||||
// and establish tenant ALS context.
|
||||
const optionalJwtAuth = (req, res, next) => {
|
||||
const cookieHeader = req.headers.cookie;
|
||||
const tokenProvider = cookieHeader ? cookies.parse(cookieHeader).token_provider : null;
|
||||
|
|
@ -13,6 +14,7 @@ const optionalJwtAuth = (req, res, next) => {
|
|||
}
|
||||
if (user) {
|
||||
req.user = user;
|
||||
return tenantContextMiddleware(req, res, next);
|
||||
}
|
||||
next();
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,20 +1,29 @@
|
|||
const cookies = require('cookie');
|
||||
const passport = require('passport');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { isEnabled, tenantContextMiddleware } = require('@librechat/api');
|
||||
|
||||
/**
|
||||
* Custom Middleware to handle JWT authentication, with support for OpenID token reuse
|
||||
* Switches between JWT and OpenID authentication based on cookies and environment settings
|
||||
* Custom Middleware to handle JWT authentication, with support for OpenID token reuse.
|
||||
* Switches between JWT and OpenID authentication based on cookies and environment settings.
|
||||
*
|
||||
* After successful authentication (req.user populated), automatically chains into
|
||||
* `tenantContextMiddleware` to propagate `req.user.tenantId` into AsyncLocalStorage
|
||||
* for downstream Mongoose tenant isolation.
|
||||
*/
|
||||
const requireJwtAuth = (req, res, next) => {
|
||||
const cookieHeader = req.headers.cookie;
|
||||
const tokenProvider = cookieHeader ? cookies.parse(cookieHeader).token_provider : null;
|
||||
|
||||
if (tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
|
||||
return passport.authenticate('openidJwt', { session: false })(req, res, next);
|
||||
}
|
||||
const strategy =
|
||||
tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS) ? 'openidJwt' : 'jwt';
|
||||
|
||||
return passport.authenticate('jwt', { session: false })(req, res, next);
|
||||
passport.authenticate(strategy, { session: false })(req, res, (err) => {
|
||||
if (err) {
|
||||
return next(err);
|
||||
}
|
||||
// req.user is now populated by passport — set up tenant ALS context
|
||||
tenantContextMiddleware(req, res, next);
|
||||
});
|
||||
};
|
||||
|
||||
module.exports = requireJwtAuth;
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ const mockRegistryInstance = {
|
|||
getServerConfig: jest.fn(),
|
||||
getOAuthServers: jest.fn(),
|
||||
getAllServerConfigs: jest.fn(),
|
||||
ensureConfigServers: jest.fn().mockResolvedValue({}),
|
||||
addServer: jest.fn(),
|
||||
updateServer: jest.fn(),
|
||||
removeServer: jest.fn(),
|
||||
|
|
@ -58,6 +59,7 @@ jest.mock('@librechat/api', () => {
|
|||
});
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
getTenantId: jest.fn(),
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
info: jest.fn(),
|
||||
|
|
@ -93,14 +95,18 @@ jest.mock('~/server/services/Config', () => ({
|
|||
getCachedTools: jest.fn(),
|
||||
getMCPServerTools: jest.fn(),
|
||||
loadCustomConfig: jest.fn(),
|
||||
getAppConfig: jest.fn().mockResolvedValue({ mcpConfig: {} }),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config/mcp', () => ({
|
||||
updateMCPServerTools: jest.fn(),
|
||||
}));
|
||||
|
||||
const mockResolveAllMcpConfigs = jest.fn().mockResolvedValue({});
|
||||
jest.mock('~/server/services/MCP', () => ({
|
||||
getMCPSetupData: jest.fn(),
|
||||
resolveConfigServers: jest.fn().mockResolvedValue({}),
|
||||
resolveAllMcpConfigs: (...args) => mockResolveAllMcpConfigs(...args),
|
||||
getServerConnectionStatus: jest.fn(),
|
||||
}));
|
||||
|
||||
|
|
@ -579,6 +585,112 @@ describe('MCP Routes', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('should use oauthHeaders from flow state when present', async () => {
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue({ status: 'PENDING' }),
|
||||
completeFlow: jest.fn().mockResolvedValue(),
|
||||
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||
};
|
||||
const mockFlowState = {
|
||||
serverName: 'test-server',
|
||||
userId: 'test-user-id',
|
||||
metadata: { toolFlowId: 'tool-flow-123' },
|
||||
clientInfo: {},
|
||||
codeVerifier: 'test-verifier',
|
||||
oauthHeaders: { 'X-Custom-Auth': 'header-value' },
|
||||
};
|
||||
const mockTokens = { access_token: 'tok', refresh_token: 'ref' };
|
||||
|
||||
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
require('~/config').getOAuthReconnectionManager.mockReturnValue({
|
||||
clearReconnection: jest.fn(),
|
||||
});
|
||||
require('~/config').getMCPManager.mockReturnValue({
|
||||
getUserConnection: jest.fn().mockResolvedValue({
|
||||
fetchTools: jest.fn().mockResolvedValue([]),
|
||||
}),
|
||||
});
|
||||
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
|
||||
getCachedTools.mockResolvedValue({});
|
||||
setCachedTools.mockResolvedValue();
|
||||
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const csrfToken = generateTestCsrfToken(flowId);
|
||||
|
||||
await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.set('Cookie', [`oauth_csrf=${csrfToken}`])
|
||||
.query({ code: 'auth-code', state: flowId });
|
||||
|
||||
expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith(
|
||||
flowId,
|
||||
'auth-code',
|
||||
mockFlowManager,
|
||||
{ 'X-Custom-Auth': 'header-value' },
|
||||
);
|
||||
expect(mockRegistryInstance.getServerConfig).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should fall back to registry oauth_headers when flow state lacks them', async () => {
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue({ status: 'PENDING' }),
|
||||
completeFlow: jest.fn().mockResolvedValue(),
|
||||
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||
};
|
||||
const mockFlowState = {
|
||||
serverName: 'test-server',
|
||||
userId: 'test-user-id',
|
||||
metadata: { toolFlowId: 'tool-flow-123' },
|
||||
clientInfo: {},
|
||||
codeVerifier: 'test-verifier',
|
||||
};
|
||||
const mockTokens = { access_token: 'tok', refresh_token: 'ref' };
|
||||
|
||||
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
oauth_headers: { 'X-Registry-Header': 'from-registry' },
|
||||
});
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
require('~/config').getOAuthReconnectionManager.mockReturnValue({
|
||||
clearReconnection: jest.fn(),
|
||||
});
|
||||
require('~/config').getMCPManager.mockReturnValue({
|
||||
getUserConnection: jest.fn().mockResolvedValue({
|
||||
fetchTools: jest.fn().mockResolvedValue([]),
|
||||
}),
|
||||
});
|
||||
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
|
||||
getCachedTools.mockResolvedValue({});
|
||||
setCachedTools.mockResolvedValue();
|
||||
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const csrfToken = generateTestCsrfToken(flowId);
|
||||
|
||||
await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.set('Cookie', [`oauth_csrf=${csrfToken}`])
|
||||
.query({ code: 'auth-code', state: flowId });
|
||||
|
||||
expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith(
|
||||
flowId,
|
||||
'auth-code',
|
||||
mockFlowManager,
|
||||
{ 'X-Registry-Header': 'from-registry' },
|
||||
);
|
||||
expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
'test-user-id',
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('should redirect to error page when callback processing fails', async () => {
|
||||
MCPOAuthHandler.getFlowState.mockRejectedValue(new Error('Callback error'));
|
||||
const flowId = 'test-user-id:test-server';
|
||||
|
|
@ -1350,19 +1462,10 @@ describe('MCP Routes', () => {
|
|||
},
|
||||
});
|
||||
|
||||
expect(getMCPSetupData).toHaveBeenCalledWith('test-user-id');
|
||||
expect(getMCPSetupData).toHaveBeenCalledWith('test-user-id', expect.any(Object));
|
||||
expect(getServerConnectionStatus).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should return 404 when MCP config is not found', async () => {
|
||||
getMCPSetupData.mockRejectedValue(new Error('MCP config not found'));
|
||||
|
||||
const response = await request(app).get('/api/mcp/connection/status');
|
||||
|
||||
expect(response.status).toBe(404);
|
||||
expect(response.body).toEqual({ error: 'MCP config not found' });
|
||||
});
|
||||
|
||||
it('should return 500 when connection status check fails', async () => {
|
||||
getMCPSetupData.mockRejectedValue(new Error('Database error'));
|
||||
|
||||
|
|
@ -1437,15 +1540,6 @@ describe('MCP Routes', () => {
|
|||
});
|
||||
});
|
||||
|
||||
it('should return 404 when MCP config is not found', async () => {
|
||||
getMCPSetupData.mockRejectedValue(new Error('MCP config not found'));
|
||||
|
||||
const response = await request(app).get('/api/mcp/connection/status/test-server');
|
||||
|
||||
expect(response.status).toBe(404);
|
||||
expect(response.body).toEqual({ error: 'MCP config not found' });
|
||||
});
|
||||
|
||||
it('should return 500 when connection status check fails', async () => {
|
||||
getMCPSetupData.mockRejectedValue(new Error('Database connection failed'));
|
||||
|
||||
|
|
@ -1704,7 +1798,7 @@ describe('MCP Routes', () => {
|
|||
},
|
||||
};
|
||||
|
||||
mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockServerConfigs);
|
||||
mockResolveAllMcpConfigs.mockResolvedValue(mockServerConfigs);
|
||||
|
||||
const response = await request(app).get('/api/mcp/servers');
|
||||
|
||||
|
|
@ -1721,11 +1815,14 @@ describe('MCP Routes', () => {
|
|||
});
|
||||
expect(response.body['server-1'].headers).toBeUndefined();
|
||||
expect(response.body['server-2'].headers).toBeUndefined();
|
||||
expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith('test-user-id');
|
||||
expect(mockResolveAllMcpConfigs).toHaveBeenCalledWith(
|
||||
'test-user-id',
|
||||
expect.objectContaining({ id: 'test-user-id' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should return empty object when no servers are configured', async () => {
|
||||
mockRegistryInstance.getAllServerConfigs.mockResolvedValue({});
|
||||
mockResolveAllMcpConfigs.mockResolvedValue({});
|
||||
|
||||
const response = await request(app).get('/api/mcp/servers');
|
||||
|
||||
|
|
@ -1749,7 +1846,7 @@ describe('MCP Routes', () => {
|
|||
});
|
||||
|
||||
it('should return 500 when server config retrieval fails', async () => {
|
||||
mockRegistryInstance.getAllServerConfigs.mockRejectedValue(new Error('Database error'));
|
||||
mockResolveAllMcpConfigs.mockRejectedValue(new Error('Database error'));
|
||||
|
||||
const response = await request(app).get('/api/mcp/servers');
|
||||
|
||||
|
|
@ -1939,11 +2036,12 @@ describe('MCP Routes', () => {
|
|||
expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
'test-user-id',
|
||||
{},
|
||||
);
|
||||
});
|
||||
|
||||
it('should return 404 when server not found', async () => {
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue(null);
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue(undefined);
|
||||
|
||||
const response = await request(app).get('/api/mcp/servers/non-existent-server');
|
||||
|
||||
|
|
|
|||
40
api/server/routes/admin/config.js
Normal file
40
api/server/routes/admin/config.js
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
const express = require('express');
|
||||
const { createAdminConfigHandlers } = require('@librechat/api');
|
||||
const { SystemCapabilities } = require('@librechat/data-schemas');
|
||||
const {
|
||||
hasConfigCapability,
|
||||
requireCapability,
|
||||
} = require('~/server/middleware/roles/capabilities');
|
||||
const { getAppConfig, invalidateConfigCaches } = require('~/server/services/Config');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const db = require('~/models');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN);
|
||||
|
||||
const handlers = createAdminConfigHandlers({
|
||||
listAllConfigs: db.listAllConfigs,
|
||||
findConfigByPrincipal: db.findConfigByPrincipal,
|
||||
upsertConfig: db.upsertConfig,
|
||||
patchConfigFields: db.patchConfigFields,
|
||||
unsetConfigField: db.unsetConfigField,
|
||||
deleteConfig: db.deleteConfig,
|
||||
toggleConfigActive: db.toggleConfigActive,
|
||||
hasConfigCapability,
|
||||
getAppConfig,
|
||||
invalidateConfigCaches,
|
||||
});
|
||||
|
||||
router.use(requireJwtAuth, requireAdminAccess);
|
||||
|
||||
router.get('/', handlers.listConfigs);
|
||||
router.get('/base', handlers.getBaseConfig);
|
||||
router.get('/:principalType/:principalId', handlers.getConfig);
|
||||
router.put('/:principalType/:principalId', handlers.upsertConfigOverrides);
|
||||
router.patch('/:principalType/:principalId/fields', handlers.patchConfigField);
|
||||
router.delete('/:principalType/:principalId/fields', handlers.deleteConfigField);
|
||||
router.delete('/:principalType/:principalId', handlers.deleteConfigOverrides);
|
||||
router.patch('/:principalType/:principalId/active', handlers.toggleConfig);
|
||||
|
||||
module.exports = router;
|
||||
41
api/server/routes/admin/groups.js
Normal file
41
api/server/routes/admin/groups.js
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
const express = require('express');
|
||||
const { createAdminGroupsHandlers } = require('@librechat/api');
|
||||
const { SystemCapabilities } = require('@librechat/data-schemas');
|
||||
const { requireCapability } = require('~/server/middleware/roles/capabilities');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const db = require('~/models');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN);
|
||||
const requireReadGroups = requireCapability(SystemCapabilities.READ_GROUPS);
|
||||
const requireManageGroups = requireCapability(SystemCapabilities.MANAGE_GROUPS);
|
||||
|
||||
const handlers = createAdminGroupsHandlers({
|
||||
listGroups: db.listGroups,
|
||||
countGroups: db.countGroups,
|
||||
findGroupById: db.findGroupById,
|
||||
createGroup: db.createGroup,
|
||||
updateGroupById: db.updateGroupById,
|
||||
deleteGroup: db.deleteGroup,
|
||||
addUserToGroup: db.addUserToGroup,
|
||||
removeUserFromGroup: db.removeUserFromGroup,
|
||||
removeMemberById: db.removeMemberById,
|
||||
findUsers: db.findUsers,
|
||||
deleteConfig: db.deleteConfig,
|
||||
deleteAclEntries: db.deleteAclEntries,
|
||||
deleteGrantsForPrincipal: db.deleteGrantsForPrincipal,
|
||||
});
|
||||
|
||||
router.use(requireJwtAuth, requireAdminAccess);
|
||||
|
||||
router.get('/', requireReadGroups, handlers.listGroups);
|
||||
router.post('/', requireManageGroups, handlers.createGroup);
|
||||
router.get('/:id', requireReadGroups, handlers.getGroup);
|
||||
router.patch('/:id', requireManageGroups, handlers.updateGroup);
|
||||
router.delete('/:id', requireManageGroups, handlers.deleteGroup);
|
||||
router.get('/:id/members', requireReadGroups, handlers.getGroupMembers);
|
||||
router.post('/:id/members', requireManageGroups, handlers.addGroupMember);
|
||||
router.delete('/:id/members/:userId', requireManageGroups, handlers.removeGroupMember);
|
||||
|
||||
module.exports = router;
|
||||
43
api/server/routes/admin/roles.js
Normal file
43
api/server/routes/admin/roles.js
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
const express = require('express');
|
||||
const { createAdminRolesHandlers } = require('@librechat/api');
|
||||
const { SystemCapabilities } = require('@librechat/data-schemas');
|
||||
const { requireCapability } = require('~/server/middleware/roles/capabilities');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const db = require('~/models');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN);
|
||||
const requireReadRoles = requireCapability(SystemCapabilities.READ_ROLES);
|
||||
const requireManageRoles = requireCapability(SystemCapabilities.MANAGE_ROLES);
|
||||
|
||||
const handlers = createAdminRolesHandlers({
|
||||
listRoles: db.listRoles,
|
||||
countRoles: db.countRoles,
|
||||
getRoleByName: db.getRoleByName,
|
||||
createRoleByName: db.createRoleByName,
|
||||
updateRoleByName: db.updateRoleByName,
|
||||
updateAccessPermissions: db.updateAccessPermissions,
|
||||
deleteRoleByName: db.deleteRoleByName,
|
||||
findUser: db.findUser,
|
||||
updateUser: db.updateUser,
|
||||
updateUsersByRole: db.updateUsersByRole,
|
||||
findUserIdsByRole: db.findUserIdsByRole,
|
||||
updateUsersRoleByIds: db.updateUsersRoleByIds,
|
||||
listUsersByRole: db.listUsersByRole,
|
||||
countUsersByRole: db.countUsersByRole,
|
||||
});
|
||||
|
||||
router.use(requireJwtAuth, requireAdminAccess);
|
||||
|
||||
router.get('/', requireReadRoles, handlers.listRoles);
|
||||
router.post('/', requireManageRoles, handlers.createRole);
|
||||
router.get('/:name', requireReadRoles, handlers.getRole);
|
||||
router.patch('/:name', requireManageRoles, handlers.updateRole);
|
||||
router.delete('/:name', requireManageRoles, handlers.deleteRole);
|
||||
router.patch('/:name/permissions', requireManageRoles, handlers.updateRolePermissions);
|
||||
router.get('/:name/members', requireReadRoles, handlers.getRoleMembers);
|
||||
router.post('/:name/members', requireManageRoles, handlers.addRoleMember);
|
||||
router.delete('/:name/members/:userId', requireManageRoles, handlers.removeRoleMember);
|
||||
|
||||
module.exports = router;
|
||||
186
api/server/routes/agents/__tests__/streamTenant.spec.js
Normal file
186
api/server/routes/agents/__tests__/streamTenant.spec.js
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
const express = require('express');
|
||||
const request = require('supertest');
|
||||
|
||||
const mockGenerationJobManager = {
|
||||
getJob: jest.fn(),
|
||||
subscribe: jest.fn(),
|
||||
getResumeState: jest.fn(),
|
||||
abortJob: jest.fn(),
|
||||
getActiveJobIdsForUser: jest.fn().mockResolvedValue([]),
|
||||
};
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
...jest.requireActual('@librechat/data-schemas'),
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
info: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
...jest.requireActual('@librechat/api'),
|
||||
isEnabled: jest.fn().mockReturnValue(false),
|
||||
GenerationJobManager: mockGenerationJobManager,
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
saveMessage: jest.fn(),
|
||||
}));
|
||||
|
||||
let mockUserId = 'user-123';
|
||||
let mockTenantId;
|
||||
|
||||
jest.mock('~/server/middleware', () => ({
|
||||
uaParser: (req, res, next) => next(),
|
||||
checkBan: (req, res, next) => next(),
|
||||
requireJwtAuth: (req, res, next) => {
|
||||
req.user = { id: mockUserId, tenantId: mockTenantId };
|
||||
next();
|
||||
},
|
||||
messageIpLimiter: (req, res, next) => next(),
|
||||
configMiddleware: (req, res, next) => next(),
|
||||
messageUserLimiter: (req, res, next) => next(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/routes/agents/chat', () => require('express').Router());
|
||||
jest.mock('~/server/routes/agents/v1', () => ({
|
||||
v1: require('express').Router(),
|
||||
}));
|
||||
jest.mock('~/server/routes/agents/openai', () => require('express').Router());
|
||||
jest.mock('~/server/routes/agents/responses', () => require('express').Router());
|
||||
|
||||
const agentsRouter = require('../index');
|
||||
const app = express();
|
||||
app.use(express.json());
|
||||
app.use('/agents', agentsRouter);
|
||||
|
||||
function mockSubscribeSuccess() {
|
||||
mockGenerationJobManager.subscribe.mockImplementation((_streamId, _writeEvent, onDone) => {
|
||||
process.nextTick(() => onDone({ done: true }));
|
||||
return { unsubscribe: jest.fn() };
|
||||
});
|
||||
}
|
||||
|
||||
describe('SSE stream tenant isolation', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockUserId = 'user-123';
|
||||
mockTenantId = undefined;
|
||||
});
|
||||
|
||||
describe('GET /chat/stream/:streamId', () => {
|
||||
it('returns 403 when a user from a different tenant accesses a stream', async () => {
|
||||
mockUserId = 'user-456';
|
||||
mockTenantId = 'tenant-b';
|
||||
|
||||
mockGenerationJobManager.getJob.mockResolvedValue({
|
||||
metadata: { userId: 'user-456', tenantId: 'tenant-a' },
|
||||
status: 'running',
|
||||
});
|
||||
|
||||
const res = await request(app).get('/agents/chat/stream/stream-123');
|
||||
expect(res.status).toBe(403);
|
||||
expect(res.body.error).toBe('Unauthorized');
|
||||
});
|
||||
|
||||
it('returns 404 when stream does not exist', async () => {
|
||||
mockGenerationJobManager.getJob.mockResolvedValue(null);
|
||||
|
||||
const res = await request(app).get('/agents/chat/stream/nonexistent');
|
||||
expect(res.status).toBe(404);
|
||||
});
|
||||
|
||||
it('proceeds past tenant guard when tenant matches', async () => {
|
||||
mockUserId = 'user-123';
|
||||
mockTenantId = 'tenant-a';
|
||||
mockSubscribeSuccess();
|
||||
|
||||
mockGenerationJobManager.getJob.mockResolvedValue({
|
||||
metadata: { userId: 'user-123', tenantId: 'tenant-a' },
|
||||
status: 'running',
|
||||
});
|
||||
|
||||
const res = await request(app).get('/agents/chat/stream/stream-123');
|
||||
expect(res.status).toBe(200);
|
||||
expect(mockGenerationJobManager.subscribe).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('proceeds past tenant guard when job has no tenantId (single-tenant mode)', async () => {
|
||||
mockUserId = 'user-123';
|
||||
mockTenantId = undefined;
|
||||
mockSubscribeSuccess();
|
||||
|
||||
mockGenerationJobManager.getJob.mockResolvedValue({
|
||||
metadata: { userId: 'user-123' },
|
||||
status: 'running',
|
||||
});
|
||||
|
||||
const res = await request(app).get('/agents/chat/stream/stream-123');
|
||||
expect(res.status).toBe(200);
|
||||
expect(mockGenerationJobManager.subscribe).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('returns 403 when job has tenantId but user has no tenantId', async () => {
|
||||
mockUserId = 'user-123';
|
||||
mockTenantId = undefined;
|
||||
|
||||
mockGenerationJobManager.getJob.mockResolvedValue({
|
||||
metadata: { userId: 'user-123', tenantId: 'some-tenant' },
|
||||
status: 'running',
|
||||
});
|
||||
|
||||
const res = await request(app).get('/agents/chat/stream/stream-123');
|
||||
expect(res.status).toBe(403);
|
||||
});
|
||||
});
|
||||
|
||||
describe('GET /chat/status/:conversationId', () => {
|
||||
it('returns 403 when tenant does not match', async () => {
|
||||
mockUserId = 'user-123';
|
||||
mockTenantId = 'tenant-b';
|
||||
|
||||
mockGenerationJobManager.getJob.mockResolvedValue({
|
||||
metadata: { userId: 'user-123', tenantId: 'tenant-a' },
|
||||
status: 'running',
|
||||
});
|
||||
|
||||
const res = await request(app).get('/agents/chat/status/conv-123');
|
||||
expect(res.status).toBe(403);
|
||||
expect(res.body.error).toBe('Unauthorized');
|
||||
});
|
||||
|
||||
it('returns status when tenant matches', async () => {
|
||||
mockUserId = 'user-123';
|
||||
mockTenantId = 'tenant-a';
|
||||
|
||||
mockGenerationJobManager.getJob.mockResolvedValue({
|
||||
metadata: { userId: 'user-123', tenantId: 'tenant-a' },
|
||||
status: 'running',
|
||||
createdAt: Date.now(),
|
||||
});
|
||||
mockGenerationJobManager.getResumeState.mockResolvedValue(null);
|
||||
|
||||
const res = await request(app).get('/agents/chat/status/conv-123');
|
||||
expect(res.status).toBe(200);
|
||||
expect(res.body.active).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('POST /chat/abort', () => {
|
||||
it('returns 403 when tenant does not match', async () => {
|
||||
mockUserId = 'user-123';
|
||||
mockTenantId = 'tenant-b';
|
||||
|
||||
mockGenerationJobManager.getJob.mockResolvedValue({
|
||||
metadata: { userId: 'user-123', tenantId: 'tenant-a' },
|
||||
status: 'running',
|
||||
});
|
||||
|
||||
const res = await request(app).post('/agents/chat/abort').send({ streamId: 'stream-123' });
|
||||
expect(res.status).toBe(403);
|
||||
expect(res.body.error).toBe('Unauthorized');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -17,6 +17,11 @@ const chat = require('./chat');
|
|||
|
||||
const { LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
|
||||
|
||||
/** Untenanted jobs (pre-multi-tenancy) remain accessible if the userId check passes. */
|
||||
function hasTenantMismatch(job, user) {
|
||||
return job.metadata?.tenantId != null && job.metadata.tenantId !== user.tenantId;
|
||||
}
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
/**
|
||||
|
|
@ -67,6 +72,10 @@ router.get('/chat/stream/:streamId', async (req, res) => {
|
|||
return res.status(403).json({ error: 'Unauthorized' });
|
||||
}
|
||||
|
||||
if (hasTenantMismatch(job, req.user)) {
|
||||
return res.status(403).json({ error: 'Unauthorized' });
|
||||
}
|
||||
|
||||
res.setHeader('Content-Encoding', 'identity');
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache, no-transform');
|
||||
|
|
@ -150,7 +159,10 @@ router.get('/chat/stream/:streamId', async (req, res) => {
|
|||
* @returns { activeJobIds: string[] }
|
||||
*/
|
||||
router.get('/chat/active', async (req, res) => {
|
||||
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(req.user.id);
|
||||
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(
|
||||
req.user.id,
|
||||
req.user.tenantId,
|
||||
);
|
||||
res.json({ activeJobIds });
|
||||
});
|
||||
|
||||
|
|
@ -174,6 +186,10 @@ router.get('/chat/status/:conversationId', async (req, res) => {
|
|||
return res.status(403).json({ error: 'Unauthorized' });
|
||||
}
|
||||
|
||||
if (hasTenantMismatch(job, req.user)) {
|
||||
return res.status(403).json({ error: 'Unauthorized' });
|
||||
}
|
||||
|
||||
// Get resume state which contains aggregatedContent
|
||||
// Avoid calling both getStreamInfo and getResumeState (both fetch content)
|
||||
const resumeState = await GenerationJobManager.getResumeState(conversationId);
|
||||
|
|
@ -213,7 +229,10 @@ router.post('/chat/abort', async (req, res) => {
|
|||
// This handles the case where frontend sends "new" but job was created with a UUID
|
||||
if (!job && userId) {
|
||||
logger.debug(`[AgentStream] Job not found by ID, checking active jobs for user: ${userId}`);
|
||||
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(userId);
|
||||
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(
|
||||
userId,
|
||||
req.user.tenantId,
|
||||
);
|
||||
if (activeJobIds.length > 0) {
|
||||
// Abort the most recent active job for this user
|
||||
jobStreamId = activeJobIds[0];
|
||||
|
|
@ -230,6 +249,10 @@ router.post('/chat/abort', async (req, res) => {
|
|||
return res.status(403).json({ error: 'Unauthorized' });
|
||||
}
|
||||
|
||||
if (hasTenantMismatch(job, req.user)) {
|
||||
return res.status(403).json({ error: 'Unauthorized' });
|
||||
}
|
||||
|
||||
logger.debug(`[AgentStream] Job found, aborting: ${jobStreamId}`);
|
||||
const abortResult = await GenerationJobManager.abortJob(jobStreamId);
|
||||
logger.debug(`[AgentStream] Job aborted successfully: ${jobStreamId}`, {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
const express = require('express');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { isEnabled, getBalanceConfig } = require('@librechat/api');
|
||||
const { CacheKeys, defaultSocialLogins } = require('librechat-data-provider');
|
||||
const { logger, getTenantId, scopedCacheKey } = require('@librechat/data-schemas');
|
||||
const { getLdapConfig } = require('~/server/services/Config/ldap');
|
||||
const { getAppConfig } = require('~/server/services/Config/app');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
|
@ -23,7 +23,8 @@ const openidReuseTokens = isEnabled(process.env.OPENID_REUSE_TOKENS);
|
|||
router.get('/', async function (req, res) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
|
||||
const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG);
|
||||
const cacheKey = scopedCacheKey(CacheKeys.STARTUP_CONFIG);
|
||||
const cachedStartupConfig = await cache.get(cacheKey);
|
||||
if (cachedStartupConfig) {
|
||||
res.send(cachedStartupConfig);
|
||||
return;
|
||||
|
|
@ -37,7 +38,10 @@ router.get('/', async function (req, res) {
|
|||
const ldap = getLdapConfig();
|
||||
|
||||
try {
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
const appConfig = await getAppConfig({
|
||||
role: req.user?.role,
|
||||
tenantId: req.user?.tenantId || getTenantId(),
|
||||
});
|
||||
|
||||
const isOpenIdEnabled =
|
||||
!!process.env.OPENID_CLIENT_ID &&
|
||||
|
|
@ -141,7 +145,7 @@ router.get('/', async function (req, res) {
|
|||
payload.customFooter = process.env.CUSTOM_FOOTER;
|
||||
}
|
||||
|
||||
await cache.set(CacheKeys.STARTUP_CONFIG, payload);
|
||||
await cache.set(cacheKey, payload);
|
||||
return res.status(200).send(payload);
|
||||
} catch (err) {
|
||||
logger.error('Error in startup config', err);
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@ const accessPermissions = require('./accessPermissions');
|
|||
const assistants = require('./assistants');
|
||||
const categories = require('./categories');
|
||||
const adminAuth = require('./admin/auth');
|
||||
const adminConfig = require('./admin/config');
|
||||
const adminGroups = require('./admin/groups');
|
||||
const adminRoles = require('./admin/roles');
|
||||
const endpoints = require('./endpoints');
|
||||
const staticRoute = require('./static');
|
||||
const messages = require('./messages');
|
||||
|
|
@ -31,6 +34,9 @@ module.exports = {
|
|||
mcp,
|
||||
auth,
|
||||
adminAuth,
|
||||
adminConfig,
|
||||
adminGroups,
|
||||
adminRoles,
|
||||
keys,
|
||||
apiKeys,
|
||||
user,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const { Router } = require('express');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { logger, getTenantId } = require('@librechat/data-schemas');
|
||||
const {
|
||||
CacheKeys,
|
||||
Constants,
|
||||
|
|
@ -36,7 +36,11 @@ const {
|
|||
getFlowStateManager,
|
||||
getMCPManager,
|
||||
} = require('~/config');
|
||||
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
||||
const {
|
||||
getServerConnectionStatus,
|
||||
resolveConfigServers,
|
||||
getMCPSetupData,
|
||||
} = require('~/server/services/MCP');
|
||||
const { requireJwtAuth, canAccessMCPServerResource } = require('~/server/middleware');
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
const { updateMCPServerTools } = require('~/server/services/Config/mcp');
|
||||
|
|
@ -101,7 +105,8 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async
|
|||
return res.status(400).json({ error: 'Invalid flow state' });
|
||||
}
|
||||
|
||||
const oauthHeaders = await getOAuthHeaders(serverName, userId);
|
||||
const configServers = await resolveConfigServers(req);
|
||||
const oauthHeaders = await getOAuthHeaders(serverName, userId, configServers);
|
||||
const {
|
||||
authorizationUrl,
|
||||
flowId: oauthFlowId,
|
||||
|
|
@ -233,7 +238,14 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
}
|
||||
|
||||
logger.debug('[MCP OAuth] Completing OAuth flow');
|
||||
const oauthHeaders = await getOAuthHeaders(serverName, flowState.userId);
|
||||
if (!flowState.oauthHeaders) {
|
||||
logger.warn(
|
||||
'[MCP OAuth] oauthHeaders absent from flow state — config-source server oauth_headers will be empty',
|
||||
{ serverName, flowId },
|
||||
);
|
||||
}
|
||||
const oauthHeaders =
|
||||
flowState.oauthHeaders ?? (await getOAuthHeaders(serverName, flowState.userId));
|
||||
const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager, oauthHeaders);
|
||||
logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route');
|
||||
|
||||
|
|
@ -497,7 +509,12 @@ router.post(
|
|||
logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`);
|
||||
|
||||
const mcpManager = getMCPManager();
|
||||
const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id);
|
||||
const configServers = await resolveConfigServers(req);
|
||||
const serverConfig = await getMCPServersRegistry().getServerConfig(
|
||||
serverName,
|
||||
user.id,
|
||||
configServers,
|
||||
);
|
||||
if (!serverConfig) {
|
||||
return res.status(404).json({
|
||||
error: `MCP server '${serverName}' not found in configuration`,
|
||||
|
|
@ -522,6 +539,8 @@ router.post(
|
|||
const result = await reinitMCPServer({
|
||||
user,
|
||||
serverName,
|
||||
serverConfig,
|
||||
configServers,
|
||||
userMCPAuthMap,
|
||||
});
|
||||
|
||||
|
|
@ -564,6 +583,7 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => {
|
|||
|
||||
const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData(
|
||||
user.id,
|
||||
{ role: user.role, tenantId: getTenantId() },
|
||||
);
|
||||
const connectionStatus = {};
|
||||
|
||||
|
|
@ -593,9 +613,6 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => {
|
|||
connectionStatus,
|
||||
});
|
||||
} catch (error) {
|
||||
if (error.message === 'MCP config not found') {
|
||||
return res.status(404).json({ error: error.message });
|
||||
}
|
||||
logger.error('[MCP Connection Status] Failed to get connection status', error);
|
||||
res.status(500).json({ error: 'Failed to get connection status' });
|
||||
}
|
||||
|
|
@ -616,6 +633,7 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) =>
|
|||
|
||||
const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData(
|
||||
user.id,
|
||||
{ role: user.role, tenantId: getTenantId() },
|
||||
);
|
||||
|
||||
if (!mcpConfig[serverName]) {
|
||||
|
|
@ -640,9 +658,6 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) =>
|
|||
requiresOAuth: serverStatus.requiresOAuth,
|
||||
});
|
||||
} catch (error) {
|
||||
if (error.message === 'MCP config not found') {
|
||||
return res.status(404).json({ error: error.message });
|
||||
}
|
||||
logger.error(
|
||||
`[MCP Per-Server Status] Failed to get connection status for ${req.params.serverName}`,
|
||||
error,
|
||||
|
|
@ -664,7 +679,12 @@ router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, a
|
|||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
|
||||
const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id);
|
||||
const configServers = await resolveConfigServers(req);
|
||||
const serverConfig = await getMCPServersRegistry().getServerConfig(
|
||||
serverName,
|
||||
user.id,
|
||||
configServers,
|
||||
);
|
||||
if (!serverConfig) {
|
||||
return res.status(404).json({
|
||||
error: `MCP server '${serverName}' not found in configuration`,
|
||||
|
|
@ -703,8 +723,12 @@ router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, a
|
|||
}
|
||||
});
|
||||
|
||||
async function getOAuthHeaders(serverName, userId) {
|
||||
const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, userId);
|
||||
async function getOAuthHeaders(serverName, userId, configServers) {
|
||||
const serverConfig = await getMCPServersRegistry().getServerConfig(
|
||||
serverName,
|
||||
userId,
|
||||
configServers,
|
||||
);
|
||||
return serverConfig?.oauth_headers ?? {};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ const {
|
|||
checkEmailConfig,
|
||||
isEmailDomainAllowed,
|
||||
shouldUseSecureCookie,
|
||||
resolveAppConfigForUser,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
findUser,
|
||||
|
|
@ -189,7 +190,7 @@ const registerUser = async (user, additionalData = {}) => {
|
|||
|
||||
let newUserId;
|
||||
try {
|
||||
const appConfig = await getAppConfig();
|
||||
const appConfig = await getAppConfig({ baseOnly: true });
|
||||
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
const errorMessage =
|
||||
'The email address provided cannot be used. Please use a different email address.';
|
||||
|
|
@ -255,19 +256,52 @@ const registerUser = async (user, additionalData = {}) => {
|
|||
};
|
||||
|
||||
/**
|
||||
* Request password reset
|
||||
* Request password reset.
|
||||
*
|
||||
* Uses a two-phase domain check: fast-fail with the memory-cached base config
|
||||
* (zero DB queries) to block globally denied domains before user lookup, then
|
||||
* re-check with tenant-scoped config after user lookup so tenant-specific
|
||||
* restrictions are enforced.
|
||||
*
|
||||
* Phase 1 (base check) returns an Error (HTTP 400) — this intentionally reveals
|
||||
* that the domain is globally blocked, but fires before any DB lookup so it
|
||||
* cannot confirm user existence. Phase 2 (tenant check) returns the generic
|
||||
* success message (HTTP 200) to prevent user-enumeration via status codes.
|
||||
*
|
||||
* @param {ServerRequest} req
|
||||
*/
|
||||
const requestPasswordReset = async (req) => {
|
||||
const { email } = req.body;
|
||||
const appConfig = await getAppConfig();
|
||||
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
|
||||
const baseConfig = await getAppConfig({ baseOnly: true });
|
||||
if (!isEmailDomainAllowed(email, baseConfig?.registration?.allowedDomains)) {
|
||||
logger.warn(
|
||||
`[requestPasswordReset] Blocked - email domain not allowed [Email: ${email}] [IP: ${req.ip}]`,
|
||||
);
|
||||
const error = new Error(ErrorTypes.AUTH_FAILED);
|
||||
error.code = ErrorTypes.AUTH_FAILED;
|
||||
error.message = 'Email domain not allowed';
|
||||
return error;
|
||||
}
|
||||
const user = await findUser({ email }, 'email _id');
|
||||
|
||||
const user = await findUser({ email }, 'email _id role tenantId');
|
||||
let appConfig = baseConfig;
|
||||
if (user?.tenantId) {
|
||||
try {
|
||||
appConfig = await resolveAppConfigForUser(getAppConfig, user);
|
||||
} catch (err) {
|
||||
logger.error('[requestPasswordReset] Failed to resolve tenant config, using base:', err);
|
||||
}
|
||||
}
|
||||
|
||||
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
logger.warn(
|
||||
`[requestPasswordReset] Tenant config blocked domain [Email: ${email}] [IP: ${req.ip}]`,
|
||||
);
|
||||
return {
|
||||
message: 'If an account with that email exists, a password reset link has been sent to it.',
|
||||
};
|
||||
}
|
||||
const emailEnabled = checkEmailConfig();
|
||||
|
||||
logger.warn(`[requestPasswordReset] [Password reset request initiated] [Email: ${email}]`);
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ jest.mock('@librechat/api', () => ({
|
|||
isEmailDomainAllowed: jest.fn(),
|
||||
math: jest.fn((val, fallback) => (val ? Number(val) : fallback)),
|
||||
shouldUseSecureCookie: jest.fn(() => false),
|
||||
resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})),
|
||||
}));
|
||||
jest.mock('~/models', () => ({
|
||||
findUser: jest.fn(),
|
||||
|
|
@ -35,8 +36,14 @@ jest.mock('~/strategies/validators', () => ({ registerSchema: { parse: jest.fn()
|
|||
jest.mock('~/server/services/Config', () => ({ getAppConfig: jest.fn() }));
|
||||
jest.mock('~/server/utils', () => ({ sendEmail: jest.fn() }));
|
||||
|
||||
const { shouldUseSecureCookie } = require('@librechat/api');
|
||||
const { setOpenIDAuthTokens } = require('./AuthService');
|
||||
const {
|
||||
shouldUseSecureCookie,
|
||||
isEmailDomainAllowed,
|
||||
resolveAppConfigForUser,
|
||||
} = require('@librechat/api');
|
||||
const { findUser } = require('~/models');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { setOpenIDAuthTokens, requestPasswordReset } = require('./AuthService');
|
||||
|
||||
/** Helper to build a mock Express response */
|
||||
function mockResponse() {
|
||||
|
|
@ -267,3 +274,68 @@ describe('setOpenIDAuthTokens', () => {
|
|||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('requestPasswordReset', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
isEmailDomainAllowed.mockReturnValue(true);
|
||||
getAppConfig.mockResolvedValue({
|
||||
registration: { allowedDomains: ['example.com'] },
|
||||
});
|
||||
resolveAppConfigForUser.mockResolvedValue({
|
||||
registration: { allowedDomains: ['example.com'] },
|
||||
});
|
||||
});
|
||||
|
||||
it('should fast-fail with base config before DB lookup for blocked domains', async () => {
|
||||
isEmailDomainAllowed.mockReturnValue(false);
|
||||
|
||||
const req = { body: { email: 'blocked@evil.com' }, ip: '127.0.0.1' };
|
||||
const result = await requestPasswordReset(req);
|
||||
|
||||
expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true });
|
||||
expect(findUser).not.toHaveBeenCalled();
|
||||
expect(result).toBeInstanceOf(Error);
|
||||
});
|
||||
|
||||
it('should call resolveAppConfigForUser for tenant user', async () => {
|
||||
const user = {
|
||||
_id: 'user-tenant',
|
||||
email: 'user@example.com',
|
||||
tenantId: 'tenant-x',
|
||||
role: 'USER',
|
||||
};
|
||||
findUser.mockResolvedValue(user);
|
||||
|
||||
const req = { body: { email: 'user@example.com' }, ip: '127.0.0.1' };
|
||||
await requestPasswordReset(req);
|
||||
|
||||
expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, user);
|
||||
});
|
||||
|
||||
it('should reuse baseConfig for non-tenant user without calling resolveAppConfigForUser', async () => {
|
||||
findUser.mockResolvedValue({ _id: 'user-no-tenant', email: 'user@example.com' });
|
||||
|
||||
const req = { body: { email: 'user@example.com' }, ip: '127.0.0.1' };
|
||||
await requestPasswordReset(req);
|
||||
|
||||
expect(resolveAppConfigForUser).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return generic response when tenant config blocks the domain (non-enumerable)', async () => {
|
||||
const user = {
|
||||
_id: 'user-tenant',
|
||||
email: 'user@example.com',
|
||||
tenantId: 'tenant-x',
|
||||
role: 'USER',
|
||||
};
|
||||
findUser.mockResolvedValue(user);
|
||||
isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false);
|
||||
|
||||
const req = { body: { email: 'user@example.com' }, ip: '127.0.0.1' };
|
||||
const result = await requestPasswordReset(req);
|
||||
|
||||
expect(result).not.toBeInstanceOf(Error);
|
||||
expect(result.message).toContain('If an account with that email exists');
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -0,0 +1,139 @@
|
|||
// ── Mocks ──────────────────────────────────────────────────────────────
|
||||
|
||||
const mockConfigStoreDelete = jest.fn().mockResolvedValue(true);
|
||||
const mockClearAppConfigCache = jest.fn().mockResolvedValue(undefined);
|
||||
const mockClearOverrideCache = jest.fn().mockResolvedValue(undefined);
|
||||
|
||||
jest.mock('~/cache/getLogStores', () => {
|
||||
return jest.fn(() => ({
|
||||
delete: mockConfigStoreDelete,
|
||||
}));
|
||||
});
|
||||
|
||||
jest.mock('~/server/services/start/tools', () => ({
|
||||
loadAndFormatTools: jest.fn(() => ({})),
|
||||
}));
|
||||
|
||||
jest.mock('../loadCustomConfig', () => jest.fn().mockResolvedValue({}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => {
|
||||
const actual = jest.requireActual('@librechat/data-schemas');
|
||||
return { ...actual, AppService: jest.fn(() => ({ availableTools: {} })) };
|
||||
});
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
getApplicableConfigs: jest.fn().mockResolvedValue([]),
|
||||
getUserPrincipals: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
const mockInvalidateCachedTools = jest.fn().mockResolvedValue(undefined);
|
||||
jest.mock('../getCachedTools', () => ({
|
||||
setCachedTools: jest.fn().mockResolvedValue(undefined),
|
||||
invalidateCachedTools: mockInvalidateCachedTools,
|
||||
}));
|
||||
|
||||
const mockClearMcpConfigCache = jest.fn().mockResolvedValue(undefined);
|
||||
jest.mock('@librechat/api', () => ({
|
||||
createAppConfigService: jest.fn(() => ({
|
||||
getAppConfig: jest.fn().mockResolvedValue({ availableTools: {} }),
|
||||
clearAppConfigCache: mockClearAppConfigCache,
|
||||
clearOverrideCache: mockClearOverrideCache,
|
||||
})),
|
||||
clearMcpConfigCache: mockClearMcpConfigCache,
|
||||
}));
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────────
|
||||
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { invalidateConfigCaches } = require('../app');
|
||||
|
||||
describe('invalidateConfigCaches', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('clears all four caches', async () => {
|
||||
await invalidateConfigCaches();
|
||||
|
||||
expect(mockClearAppConfigCache).toHaveBeenCalledTimes(1);
|
||||
expect(mockClearOverrideCache).toHaveBeenCalledTimes(1);
|
||||
expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true });
|
||||
expect(mockConfigStoreDelete).toHaveBeenCalledWith(CacheKeys.ENDPOINT_CONFIG);
|
||||
});
|
||||
|
||||
it('passes tenantId through to clearOverrideCache', async () => {
|
||||
await invalidateConfigCaches('tenant-a');
|
||||
|
||||
expect(mockClearOverrideCache).toHaveBeenCalledWith('tenant-a');
|
||||
expect(mockClearAppConfigCache).toHaveBeenCalledTimes(1);
|
||||
expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true });
|
||||
});
|
||||
|
||||
it('does not throw when CONFIG_STORE.delete fails', async () => {
|
||||
mockConfigStoreDelete.mockRejectedValueOnce(new Error('store not found'));
|
||||
|
||||
await expect(invalidateConfigCaches()).resolves.not.toThrow();
|
||||
|
||||
// Other caches should still have been invalidated
|
||||
expect(mockClearAppConfigCache).toHaveBeenCalledTimes(1);
|
||||
expect(mockClearOverrideCache).toHaveBeenCalledTimes(1);
|
||||
expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true });
|
||||
});
|
||||
|
||||
it('all operations run in parallel (not sequentially)', async () => {
|
||||
const order = [];
|
||||
|
||||
mockClearAppConfigCache.mockImplementation(
|
||||
() =>
|
||||
new Promise((r) =>
|
||||
setTimeout(() => {
|
||||
order.push('base');
|
||||
r();
|
||||
}, 10),
|
||||
),
|
||||
);
|
||||
mockClearOverrideCache.mockImplementation(
|
||||
() =>
|
||||
new Promise((r) =>
|
||||
setTimeout(() => {
|
||||
order.push('override');
|
||||
r();
|
||||
}, 10),
|
||||
),
|
||||
);
|
||||
mockInvalidateCachedTools.mockImplementation(
|
||||
() =>
|
||||
new Promise((r) =>
|
||||
setTimeout(() => {
|
||||
order.push('tools');
|
||||
r();
|
||||
}, 10),
|
||||
),
|
||||
);
|
||||
mockConfigStoreDelete.mockImplementation(
|
||||
() =>
|
||||
new Promise((r) =>
|
||||
setTimeout(() => {
|
||||
order.push('endpoint');
|
||||
r();
|
||||
}, 10),
|
||||
),
|
||||
);
|
||||
|
||||
await invalidateConfigCaches();
|
||||
|
||||
// All four should have been called (parallel execution via Promise.allSettled)
|
||||
expect(order).toHaveLength(4);
|
||||
expect(new Set(order)).toEqual(new Set(['base', 'override', 'tools', 'endpoint']));
|
||||
});
|
||||
|
||||
it('resolves even when clearAppConfigCache throws (partial failure)', async () => {
|
||||
mockClearAppConfigCache.mockRejectedValueOnce(new Error('cache connection lost'));
|
||||
|
||||
await expect(invalidateConfigCaches()).resolves.not.toThrow();
|
||||
|
||||
// Other caches should still have been invalidated despite the failure
|
||||
expect(mockClearOverrideCache).toHaveBeenCalledTimes(1);
|
||||
expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true });
|
||||
});
|
||||
});
|
||||
|
|
@ -1,12 +1,12 @@
|
|||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { logger, AppService } = require('@librechat/data-schemas');
|
||||
const { AppService, logger, scopedCacheKey } = require('@librechat/data-schemas');
|
||||
const { createAppConfigService, clearMcpConfigCache } = require('@librechat/api');
|
||||
const { setCachedTools, invalidateCachedTools } = require('./getCachedTools');
|
||||
const { loadAndFormatTools } = require('~/server/services/start/tools');
|
||||
const loadCustomConfig = require('./loadCustomConfig');
|
||||
const { setCachedTools } = require('./getCachedTools');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const paths = require('~/config/paths');
|
||||
|
||||
const BASE_CONFIG_KEY = '_BASE_';
|
||||
const db = require('~/models');
|
||||
|
||||
const loadBaseConfig = async () => {
|
||||
/** @type {TCustomConfig} */
|
||||
|
|
@ -20,65 +20,67 @@ const loadBaseConfig = async () => {
|
|||
return AppService({ config, paths, systemTools });
|
||||
};
|
||||
|
||||
const { getAppConfig, clearAppConfigCache, clearOverrideCache } = createAppConfigService({
|
||||
loadBaseConfig,
|
||||
setCachedTools,
|
||||
getCache: getLogStores,
|
||||
cacheKeys: CacheKeys,
|
||||
getApplicableConfigs: db.getApplicableConfigs,
|
||||
getUserPrincipals: db.getUserPrincipals,
|
||||
});
|
||||
|
||||
/**
|
||||
* Get the app configuration based on user context
|
||||
* @param {Object} [options]
|
||||
* @param {string} [options.role] - User role for role-based config
|
||||
* @param {boolean} [options.refresh] - Force refresh the cache
|
||||
* @returns {Promise<AppConfig>}
|
||||
* Deletes ENDPOINT_CONFIG entries from CONFIG_STORE.
|
||||
* Clears both the tenant-scoped key (if in tenant context) and the
|
||||
* unscoped base key (populated by unauthenticated /api/endpoints calls).
|
||||
* Other tenants' scoped keys are NOT actively cleared — they expire
|
||||
* via TTL. Config mutations in one tenant do not propagate immediately
|
||||
* to other tenants' endpoint config caches.
|
||||
*/
|
||||
async function getAppConfig(options = {}) {
|
||||
const { role, refresh } = options;
|
||||
|
||||
const cache = getLogStores(CacheKeys.APP_CONFIG);
|
||||
const cacheKey = role ? role : BASE_CONFIG_KEY;
|
||||
|
||||
if (!refresh) {
|
||||
const cached = await cache.get(cacheKey);
|
||||
if (cached) {
|
||||
return cached;
|
||||
async function clearEndpointConfigCache() {
|
||||
try {
|
||||
const configStore = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const scoped = scopedCacheKey(CacheKeys.ENDPOINT_CONFIG);
|
||||
const keys = [scoped];
|
||||
if (scoped !== CacheKeys.ENDPOINT_CONFIG) {
|
||||
keys.push(CacheKeys.ENDPOINT_CONFIG);
|
||||
}
|
||||
await Promise.all(keys.map((k) => configStore.delete(k)));
|
||||
} catch {
|
||||
// CONFIG_STORE or ENDPOINT_CONFIG may not exist — not critical
|
||||
}
|
||||
|
||||
let baseConfig = await cache.get(BASE_CONFIG_KEY);
|
||||
if (!baseConfig) {
|
||||
logger.info('[getAppConfig] App configuration not initialized. Initializing AppService...');
|
||||
baseConfig = await loadBaseConfig();
|
||||
|
||||
if (!baseConfig) {
|
||||
throw new Error('Failed to initialize app configuration through AppService.');
|
||||
}
|
||||
|
||||
if (baseConfig.availableTools) {
|
||||
await setCachedTools(baseConfig.availableTools);
|
||||
}
|
||||
|
||||
await cache.set(BASE_CONFIG_KEY, baseConfig);
|
||||
}
|
||||
|
||||
// For now, return the base config
|
||||
// In the future, this is where we'll apply role-based modifications
|
||||
if (role) {
|
||||
// TODO: Apply role-based config modifications
|
||||
// const roleConfig = await applyRoleBasedConfig(baseConfig, role);
|
||||
// await cache.set(cacheKey, roleConfig);
|
||||
// return roleConfig;
|
||||
}
|
||||
|
||||
return baseConfig;
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear the app configuration cache
|
||||
* @returns {Promise<boolean>}
|
||||
* Invalidate all config-related caches after an admin config mutation.
|
||||
* Clears the base config, per-principal override caches, tool caches,
|
||||
* the endpoints config cache, and the MCP config-source server cache.
|
||||
* @param {string} [tenantId] - Optional tenant ID to scope override cache clearing.
|
||||
*/
|
||||
async function clearAppConfigCache() {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cacheKey = CacheKeys.APP_CONFIG;
|
||||
return await cache.delete(cacheKey);
|
||||
async function invalidateConfigCaches(tenantId) {
|
||||
const results = await Promise.allSettled([
|
||||
clearAppConfigCache(),
|
||||
clearOverrideCache(tenantId),
|
||||
invalidateCachedTools({ invalidateGlobal: true }),
|
||||
clearEndpointConfigCache(),
|
||||
clearMcpConfigCache(),
|
||||
]);
|
||||
const labels = [
|
||||
'clearAppConfigCache',
|
||||
'clearOverrideCache',
|
||||
'invalidateCachedTools',
|
||||
'clearEndpointConfigCache',
|
||||
'clearMcpConfigCache',
|
||||
];
|
||||
for (let i = 0; i < results.length; i++) {
|
||||
if (results[i].status === 'rejected') {
|
||||
logger.error(`[invalidateConfigCaches] ${labels[i]} failed:`, results[i].reason);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getAppConfig,
|
||||
clearAppConfigCache,
|
||||
invalidateConfigCaches,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
const { scopedCacheKey } = require('@librechat/data-schemas');
|
||||
const { loadCustomEndpointsConfig } = require('@librechat/api');
|
||||
const {
|
||||
CacheKeys,
|
||||
|
|
@ -17,16 +18,18 @@ const { getAppConfig } = require('./app');
|
|||
*/
|
||||
async function getEndpointsConfig(req) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
|
||||
const cacheKey = scopedCacheKey(CacheKeys.ENDPOINT_CONFIG);
|
||||
const cachedEndpointsConfig = await cache.get(cacheKey);
|
||||
if (cachedEndpointsConfig) {
|
||||
if (cachedEndpointsConfig.gptPlugins) {
|
||||
await cache.delete(CacheKeys.ENDPOINT_CONFIG);
|
||||
await cache.delete(cacheKey);
|
||||
} else {
|
||||
return cachedEndpointsConfig;
|
||||
}
|
||||
}
|
||||
|
||||
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
|
||||
const appConfig =
|
||||
req.config ?? (await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }));
|
||||
const defaultEndpointsConfig = await loadDefaultEndpointsConfig(appConfig);
|
||||
const customEndpointsConfig = loadCustomEndpointsConfig(appConfig?.endpoints?.custom);
|
||||
|
||||
|
|
@ -111,7 +114,7 @@ async function getEndpointsConfig(req) {
|
|||
|
||||
const endpointsConfig = orderEndpointsConfig(mergedConfig);
|
||||
|
||||
await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig);
|
||||
await cache.set(cacheKey, endpointsConfig);
|
||||
return endpointsConfig;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ const { getAppConfig } = require('./app');
|
|||
* @param {ServerRequest} req - The Express request object.
|
||||
*/
|
||||
async function loadConfigModels(req) {
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId });
|
||||
if (!appConfig) {
|
||||
return {};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,7 +16,8 @@ const { getAppConfig } = require('./app');
|
|||
*/
|
||||
async function loadDefaultModels(req) {
|
||||
try {
|
||||
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
|
||||
const appConfig =
|
||||
req.config ?? (await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }));
|
||||
const vertexConfig = appConfig?.endpoints?.[EModelEndpoint.anthropic]?.vertexConfig;
|
||||
|
||||
const [openAI, anthropic, azureOpenAI, assistants, azureAssistants, google, bedrock] =
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { logger, scopedCacheKey } = require('@librechat/data-schemas');
|
||||
const { getCachedTools, setCachedTools } = require('./getCachedTools');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
|
|
@ -36,7 +36,7 @@ async function updateMCPServerTools({ userId, serverName, tools }) {
|
|||
await setCachedTools(serverTools, { userId, serverName });
|
||||
|
||||
const cache = getLogStores(CacheKeys.TOOL_CACHE);
|
||||
await cache.delete(CacheKeys.TOOLS);
|
||||
await cache.delete(scopedCacheKey(CacheKeys.TOOLS));
|
||||
logger.debug(
|
||||
`[MCP Cache] Updated ${tools.length} tools for server ${serverName} (user: ${userId})`,
|
||||
);
|
||||
|
|
@ -48,7 +48,10 @@ async function updateMCPServerTools({ userId, serverName, tools }) {
|
|||
}
|
||||
|
||||
/**
|
||||
* Merges app-level tools with global tools
|
||||
* Merges app-level tools with global tools.
|
||||
* Only the current ALS-scoped key (base key in system/startup context) is cleared.
|
||||
* Tenant-scoped TOOLS:tenantId keys are NOT actively invalidated — they expire
|
||||
* via TTL on the next tenant request. This matches clearEndpointConfigCache behavior.
|
||||
* @param {import('@librechat/api').LCAvailableTools} appTools
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
|
|
@ -62,7 +65,7 @@ async function mergeAppTools(appTools) {
|
|||
const mergedTools = { ...cachedTools, ...appTools };
|
||||
await setCachedTools(mergedTools);
|
||||
const cache = getLogStores(CacheKeys.TOOL_CACHE);
|
||||
await cache.delete(CacheKeys.TOOLS);
|
||||
await cache.delete(scopedCacheKey(CacheKeys.TOOLS));
|
||||
logger.debug(`Merged ${count} app-level tools`);
|
||||
} catch (error) {
|
||||
logger.error('Failed to merge app-level tools:', error);
|
||||
|
|
|
|||
|
|
@ -142,6 +142,7 @@ class STTService {
|
|||
req.config ??
|
||||
(await getAppConfig({
|
||||
role: req?.user?.role,
|
||||
tenantId: req?.user?.tenantId,
|
||||
}));
|
||||
const sttSchema = appConfig?.speech?.stt;
|
||||
if (!sttSchema) {
|
||||
|
|
|
|||
|
|
@ -297,6 +297,7 @@ class TTSService {
|
|||
req.config ??
|
||||
(await getAppConfig({
|
||||
role: req.user?.role,
|
||||
tenantId: req.user?.tenantId,
|
||||
}));
|
||||
try {
|
||||
res.setHeader('Content-Type', 'audio/mpeg');
|
||||
|
|
@ -365,6 +366,7 @@ class TTSService {
|
|||
req.config ??
|
||||
(await getAppConfig({
|
||||
role: req.user?.role,
|
||||
tenantId: req.user?.tenantId,
|
||||
}));
|
||||
const provider = this.getProvider(appConfig);
|
||||
const ttsSchema = appConfig?.speech?.tts?.[provider];
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ async function getCustomConfigSpeech(req, res) {
|
|||
try {
|
||||
const appConfig = await getAppConfig({
|
||||
role: req.user?.role,
|
||||
tenantId: req.user?.tenantId,
|
||||
});
|
||||
|
||||
if (!appConfig) {
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ async function getVoices(req, res) {
|
|||
req.config ??
|
||||
(await getAppConfig({
|
||||
role: req.user?.role,
|
||||
tenantId: req.user?.tenantId,
|
||||
}));
|
||||
|
||||
const ttsSchema = appConfig?.speech?.tts;
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
const { scopedCacheKey } = require('@librechat/data-schemas');
|
||||
const {
|
||||
Time,
|
||||
CacheKeys,
|
||||
|
|
@ -67,6 +68,8 @@ function createChunkProcessor(user, messageId) {
|
|||
}
|
||||
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
// Captured at creation time — must be called within an active request ALS scope
|
||||
const cacheKey = scopedCacheKey(messageId);
|
||||
|
||||
/**
|
||||
* @returns {Promise<{ text: string, isFinished: boolean }[] | string>}
|
||||
|
|
@ -81,7 +84,7 @@ function createChunkProcessor(user, messageId) {
|
|||
}
|
||||
|
||||
/** @type { string | { text: string; complete: boolean } } */
|
||||
let message = await messageCache.get(messageId);
|
||||
let message = await messageCache.get(cacheKey);
|
||||
if (!message) {
|
||||
message = await getMessage({ user, messageId });
|
||||
}
|
||||
|
|
@ -92,7 +95,7 @@ function createChunkProcessor(user, messageId) {
|
|||
} else {
|
||||
const text = message.content?.length > 0 ? parseTextParts(message.content) : message.text;
|
||||
messageCache.set(
|
||||
messageId,
|
||||
cacheKey,
|
||||
{
|
||||
text,
|
||||
complete: true,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const { tool } = require('@langchain/core/tools');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { logger, getTenantId } = require('@librechat/data-schemas');
|
||||
const {
|
||||
Providers,
|
||||
StepTypes,
|
||||
|
|
@ -14,6 +14,7 @@ const {
|
|||
normalizeJsonSchema,
|
||||
GenerationJobManager,
|
||||
resolveJsonSchemaRefs,
|
||||
buildOAuthToolCallName,
|
||||
} = require('@librechat/api');
|
||||
const { Time, CacheKeys, Constants, isAssistantsEndpoint } = require('librechat-data-provider');
|
||||
const {
|
||||
|
|
@ -53,6 +54,53 @@ function evictStale(map, ttl) {
|
|||
const unavailableMsg =
|
||||
"This tool's MCP server is temporarily unavailable. Please try again shortly.";
|
||||
|
||||
/**
|
||||
* Resolves config-source MCP servers from admin Config overrides for the current
|
||||
* request context. Returns the parsed configs keyed by server name.
|
||||
* @param {import('express').Request} req - Express request with user context
|
||||
* @returns {Promise<Record<string, import('@librechat/api').ParsedServerConfig>>}
|
||||
*/
|
||||
async function resolveConfigServers(req) {
|
||||
try {
|
||||
const registry = getMCPServersRegistry();
|
||||
const user = req?.user;
|
||||
const appConfig = await getAppConfig({
|
||||
role: user?.role,
|
||||
tenantId: getTenantId(),
|
||||
userId: user?.id,
|
||||
});
|
||||
return await registry.ensureConfigServers(appConfig?.mcpConfig || {});
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
'[resolveConfigServers] Failed to resolve config servers, degrading to empty:',
|
||||
error,
|
||||
);
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves config-source servers and merges all server configs (YAML + config + user DB)
|
||||
* for the given user context. Shared helper for controllers needing the full merged config.
|
||||
* @param {string} userId
|
||||
* @param {{ id?: string, role?: string }} [user]
|
||||
* @returns {Promise<Record<string, import('@librechat/api').ParsedServerConfig>>}
|
||||
*/
|
||||
async function resolveAllMcpConfigs(userId, user) {
|
||||
const registry = getMCPServersRegistry();
|
||||
const appConfig = await getAppConfig({ role: user?.role, tenantId: getTenantId(), userId });
|
||||
let configServers = {};
|
||||
try {
|
||||
configServers = await registry.ensureConfigServers(appConfig?.mcpConfig || {});
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
'[resolveAllMcpConfigs] Config server resolution failed, continuing without:',
|
||||
error,
|
||||
);
|
||||
}
|
||||
return await registry.getAllServerConfigs(userId, configServers);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {string} toolName
|
||||
* @param {string} serverName
|
||||
|
|
@ -248,6 +296,7 @@ async function reconnectServer({
|
|||
index,
|
||||
signal,
|
||||
serverName,
|
||||
configServers,
|
||||
userMCPAuthMap,
|
||||
streamId = null,
|
||||
}) {
|
||||
|
|
@ -271,7 +320,7 @@ async function reconnectServer({
|
|||
const stepId = 'step_oauth_login_' + serverName;
|
||||
const toolCall = {
|
||||
id: flowId,
|
||||
name: serverName,
|
||||
name: buildOAuthToolCallName(serverName),
|
||||
type: 'tool_call_chunk',
|
||||
};
|
||||
|
||||
|
|
@ -316,6 +365,7 @@ async function reconnectServer({
|
|||
user,
|
||||
signal,
|
||||
serverName,
|
||||
configServers,
|
||||
oauthStart,
|
||||
flowManager,
|
||||
userMCPAuthMap,
|
||||
|
|
@ -358,15 +408,14 @@ async function createMCPTools({
|
|||
config,
|
||||
provider,
|
||||
serverName,
|
||||
configServers,
|
||||
userMCPAuthMap,
|
||||
streamId = null,
|
||||
}) {
|
||||
// Early domain validation before reconnecting server (avoid wasted work on disallowed domains)
|
||||
// Use getAppConfig() to support per-user/role domain restrictions
|
||||
const serverConfig =
|
||||
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
|
||||
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id, configServers));
|
||||
if (serverConfig?.url) {
|
||||
const appConfig = await getAppConfig({ role: user?.role });
|
||||
const appConfig = await getAppConfig({ role: user?.role, tenantId: user?.tenantId });
|
||||
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
|
||||
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
|
||||
if (!isDomainAllowed) {
|
||||
|
|
@ -381,6 +430,7 @@ async function createMCPTools({
|
|||
index,
|
||||
signal,
|
||||
serverName,
|
||||
configServers,
|
||||
userMCPAuthMap,
|
||||
streamId,
|
||||
});
|
||||
|
|
@ -400,6 +450,7 @@ async function createMCPTools({
|
|||
user,
|
||||
provider,
|
||||
userMCPAuthMap,
|
||||
configServers,
|
||||
streamId,
|
||||
availableTools: result.availableTools,
|
||||
toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`,
|
||||
|
|
@ -439,16 +490,15 @@ async function createMCPTool({
|
|||
userMCPAuthMap,
|
||||
availableTools,
|
||||
config,
|
||||
configServers,
|
||||
streamId = null,
|
||||
}) {
|
||||
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
|
||||
|
||||
// Runtime domain validation: check if the server's domain is still allowed
|
||||
// Use getAppConfig() to support per-user/role domain restrictions
|
||||
const serverConfig =
|
||||
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
|
||||
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id, configServers));
|
||||
if (serverConfig?.url) {
|
||||
const appConfig = await getAppConfig({ role: user?.role });
|
||||
const appConfig = await getAppConfig({ role: user?.role, tenantId: user?.tenantId });
|
||||
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
|
||||
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
|
||||
if (!isDomainAllowed) {
|
||||
|
|
@ -477,6 +527,7 @@ async function createMCPTool({
|
|||
index,
|
||||
signal,
|
||||
serverName,
|
||||
configServers,
|
||||
userMCPAuthMap,
|
||||
streamId,
|
||||
});
|
||||
|
|
@ -500,6 +551,7 @@ async function createMCPTool({
|
|||
provider,
|
||||
toolName,
|
||||
serverName,
|
||||
serverConfig,
|
||||
toolDefinition,
|
||||
streamId,
|
||||
});
|
||||
|
|
@ -509,13 +561,14 @@ function createToolInstance({
|
|||
res,
|
||||
toolName,
|
||||
serverName,
|
||||
serverConfig: capturedServerConfig,
|
||||
toolDefinition,
|
||||
provider: _provider,
|
||||
provider: capturedProvider,
|
||||
streamId = null,
|
||||
}) {
|
||||
/** @type {LCTool} */
|
||||
const { description, parameters } = toolDefinition;
|
||||
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
|
||||
const isGoogle = capturedProvider === Providers.VERTEXAI || capturedProvider === Providers.GOOGLE;
|
||||
|
||||
let schema = parameters ? normalizeJsonSchema(resolveJsonSchemaRefs(parameters)) : null;
|
||||
|
||||
|
|
@ -544,7 +597,7 @@ function createToolInstance({
|
|||
const flowManager = getFlowStateManager(flowsCache);
|
||||
derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined;
|
||||
const mcpManager = getMCPManager(userId);
|
||||
const provider = (config?.metadata?.provider || _provider)?.toLowerCase();
|
||||
const provider = (config?.metadata?.provider || capturedProvider)?.toLowerCase();
|
||||
|
||||
const { args: _args, stepId, ...toolCall } = config.toolCall ?? {};
|
||||
const flowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`;
|
||||
|
|
@ -576,6 +629,7 @@ function createToolInstance({
|
|||
|
||||
const result = await mcpManager.callTool({
|
||||
serverName,
|
||||
serverConfig: capturedServerConfig,
|
||||
toolName,
|
||||
provider,
|
||||
toolArguments,
|
||||
|
|
@ -643,30 +697,36 @@ function createToolInstance({
|
|||
}
|
||||
|
||||
/**
|
||||
* Get MCP setup data including config, connections, and OAuth servers
|
||||
* Get MCP setup data including config, connections, and OAuth servers.
|
||||
* Resolves config-source servers from admin Config overrides when tenant context is available.
|
||||
* @param {string} userId - The user ID
|
||||
* @param {{ role?: string, tenantId?: string }} [options] - Optional role/tenant context
|
||||
* @returns {Object} Object containing mcpConfig, appConnections, userConnections, and oauthServers
|
||||
*/
|
||||
async function getMCPSetupData(userId) {
|
||||
const mcpConfig = await getMCPServersRegistry().getAllServerConfigs(userId);
|
||||
|
||||
if (!mcpConfig) {
|
||||
throw new Error('MCP config not found');
|
||||
}
|
||||
async function getMCPSetupData(userId, options = {}) {
|
||||
const registry = getMCPServersRegistry();
|
||||
const { role, tenantId } = options;
|
||||
|
||||
const appConfig = await getAppConfig({ role, tenantId, userId });
|
||||
const configServers = await registry.ensureConfigServers(appConfig?.mcpConfig || {});
|
||||
const mcpConfig = await registry.getAllServerConfigs(userId, configServers);
|
||||
const mcpManager = getMCPManager(userId);
|
||||
/** @type {Map<string, import('@librechat/api').MCPConnection>} */
|
||||
let appConnections = new Map();
|
||||
try {
|
||||
// Use getLoaded() instead of getAll() to avoid forcing connection creation
|
||||
// Use getLoaded() instead of getAll() to avoid forcing connection creation.
|
||||
// getAll() creates connections for all servers, which is problematic for servers
|
||||
// that require user context (e.g., those with {{LIBRECHAT_USER_ID}} placeholders)
|
||||
// that require user context (e.g., those with {{LIBRECHAT_USER_ID}} placeholders).
|
||||
appConnections = (await mcpManager.appConnections?.getLoaded()) || new Map();
|
||||
} catch (error) {
|
||||
logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error);
|
||||
}
|
||||
const userConnections = mcpManager.getUserConnections(userId) || new Map();
|
||||
const oauthServers = await getMCPServersRegistry().getOAuthServers(userId);
|
||||
const oauthServers = new Set(
|
||||
Object.entries(mcpConfig)
|
||||
.filter(([, config]) => config.requiresOAuth)
|
||||
.map(([name]) => name),
|
||||
);
|
||||
|
||||
return {
|
||||
mcpConfig,
|
||||
|
|
@ -788,6 +848,8 @@ module.exports = {
|
|||
createMCPTool,
|
||||
createMCPTools,
|
||||
getMCPSetupData,
|
||||
resolveConfigServers,
|
||||
resolveAllMcpConfigs,
|
||||
checkOAuthFlowStatus,
|
||||
getServerConnectionStatus,
|
||||
createUnavailableToolStub,
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ const mockRegistryInstance = {
|
|||
getOAuthServers: jest.fn(() => Promise.resolve(new Set())),
|
||||
getAllServerConfigs: jest.fn(() => Promise.resolve({})),
|
||||
getServerConfig: jest.fn(() => Promise.resolve(null)),
|
||||
ensureConfigServers: jest.fn(() => Promise.resolve({})),
|
||||
};
|
||||
|
||||
// Create isMCPDomainAllowed mock that can be configured per-test
|
||||
|
|
@ -113,38 +114,43 @@ describe('tests for the new helper functions used by the MCP connection status e
|
|||
});
|
||||
|
||||
it('should successfully return MCP setup data', async () => {
|
||||
mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfig);
|
||||
const mockConfigWithOAuth = {
|
||||
server1: { type: 'stdio' },
|
||||
server2: { type: 'http', requiresOAuth: true },
|
||||
};
|
||||
mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfigWithOAuth);
|
||||
|
||||
const mockAppConnections = new Map([['server1', { status: 'connected' }]]);
|
||||
const mockUserConnections = new Map([['server2', { status: 'disconnected' }]]);
|
||||
const mockOAuthServers = new Set(['server2']);
|
||||
|
||||
const mockMCPManager = {
|
||||
appConnections: { getLoaded: jest.fn(() => Promise.resolve(mockAppConnections)) },
|
||||
getUserConnections: jest.fn(() => mockUserConnections),
|
||||
};
|
||||
mockGetMCPManager.mockReturnValue(mockMCPManager);
|
||||
mockRegistryInstance.getOAuthServers.mockResolvedValue(mockOAuthServers);
|
||||
|
||||
const result = await getMCPSetupData(mockUserId);
|
||||
|
||||
expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockRegistryInstance.ensureConfigServers).toHaveBeenCalled();
|
||||
expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith(
|
||||
mockUserId,
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockMCPManager.appConnections.getLoaded).toHaveBeenCalled();
|
||||
expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockRegistryInstance.getOAuthServers).toHaveBeenCalledWith(mockUserId);
|
||||
|
||||
expect(result).toEqual({
|
||||
mcpConfig: mockConfig,
|
||||
appConnections: mockAppConnections,
|
||||
userConnections: mockUserConnections,
|
||||
oauthServers: mockOAuthServers,
|
||||
});
|
||||
expect(result.mcpConfig).toEqual(mockConfigWithOAuth);
|
||||
expect(result.appConnections).toEqual(mockAppConnections);
|
||||
expect(result.userConnections).toEqual(mockUserConnections);
|
||||
expect(result.oauthServers).toEqual(new Set(['server2']));
|
||||
});
|
||||
|
||||
it('should throw error when MCP config not found', async () => {
|
||||
mockRegistryInstance.getAllServerConfigs.mockResolvedValue(null);
|
||||
await expect(getMCPSetupData(mockUserId)).rejects.toThrow('MCP config not found');
|
||||
it('should return empty data when no servers are configured', async () => {
|
||||
mockRegistryInstance.getAllServerConfigs.mockResolvedValue({});
|
||||
const result = await getMCPSetupData(mockUserId);
|
||||
expect(result.mcpConfig).toEqual({});
|
||||
expect(result.oauthServers).toEqual(new Set());
|
||||
});
|
||||
|
||||
it('should handle null values from MCP manager gracefully', async () => {
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ const {
|
|||
buildWebSearchContext,
|
||||
buildImageToolContext,
|
||||
buildToolClassification,
|
||||
buildOAuthToolCallName,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
|
|
@ -59,6 +60,7 @@ const { manifestToolMap, toolkits } = require('~/app/clients/tools/manifest');
|
|||
const { createOnSearchResults } = require('~/server/services/Tools/search');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
|
||||
const { resolveConfigServers } = require('~/server/services/MCP');
|
||||
const { recordUsage } = require('~/server/services/Threads');
|
||||
const { loadTools } = require('~/app/clients/tools/util');
|
||||
const { redactMessage } = require('~/config/parsers');
|
||||
|
|
@ -513,6 +515,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
|||
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
const configServers = await resolveConfigServers(req);
|
||||
const pendingOAuthServers = new Set();
|
||||
|
||||
const createOAuthEmitter = (serverName) => {
|
||||
|
|
@ -521,7 +524,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
|||
const stepId = 'step_oauth_login_' + serverName;
|
||||
const toolCall = {
|
||||
id: flowId,
|
||||
name: serverName,
|
||||
name: buildOAuthToolCallName(serverName),
|
||||
type: 'tool_call_chunk',
|
||||
};
|
||||
|
||||
|
|
@ -578,6 +581,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
|||
oauthStart,
|
||||
flowManager,
|
||||
serverName,
|
||||
configServers,
|
||||
userMCPAuthMap,
|
||||
});
|
||||
|
||||
|
|
@ -665,6 +669,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
|||
const result = await reinitMCPServer({
|
||||
user: req.user,
|
||||
serverName,
|
||||
configServers,
|
||||
userMCPAuthMap,
|
||||
flowManager,
|
||||
returnOnOAuth: false,
|
||||
|
|
|
|||
|
|
@ -25,11 +25,13 @@ async function reinitMCPServer({
|
|||
signal,
|
||||
forceNew,
|
||||
serverName,
|
||||
configServers,
|
||||
userMCPAuthMap,
|
||||
connectionTimeout,
|
||||
returnOnOAuth = true,
|
||||
oauthStart: _oauthStart,
|
||||
flowManager: _flowManager,
|
||||
serverConfig: providedConfig,
|
||||
}) {
|
||||
/** @type {MCPConnection | null} */
|
||||
let connection = null;
|
||||
|
|
@ -42,13 +44,28 @@ async function reinitMCPServer({
|
|||
|
||||
try {
|
||||
const registry = getMCPServersRegistry();
|
||||
const serverConfig = await registry.getServerConfig(serverName, user?.id);
|
||||
const serverConfig =
|
||||
providedConfig ?? (await registry.getServerConfig(serverName, user?.id, configServers));
|
||||
if (serverConfig?.inspectionFailed) {
|
||||
if (serverConfig.source === 'config') {
|
||||
logger.info(
|
||||
`[MCP Reinitialize] Config-source server ${serverName} has inspectionFailed — retry handled by config cache`,
|
||||
);
|
||||
return {
|
||||
availableTools: null,
|
||||
success: false,
|
||||
message: `MCP server '${serverName}' is still unreachable`,
|
||||
oauthRequired: false,
|
||||
serverName,
|
||||
oauthUrl: null,
|
||||
tools: null,
|
||||
};
|
||||
}
|
||||
logger.info(
|
||||
`[MCP Reinitialize] Server ${serverName} had failed inspection, attempting reinspection`,
|
||||
);
|
||||
try {
|
||||
const storageLocation = serverConfig.dbId ? 'DB' : 'CACHE';
|
||||
const storageLocation = serverConfig.source === 'user' ? 'DB' : 'CACHE';
|
||||
await registry.reinspectServer(serverName, storageLocation, user?.id);
|
||||
logger.info(`[MCP Reinitialize] Reinspection succeeded for server: ${serverName}`);
|
||||
} catch (reinspectError) {
|
||||
|
|
@ -93,6 +110,7 @@ async function reinitMCPServer({
|
|||
returnOnOAuth,
|
||||
customUserVars,
|
||||
connectionTimeout,
|
||||
serverConfig,
|
||||
});
|
||||
|
||||
logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`);
|
||||
|
|
@ -125,6 +143,7 @@ async function reinitMCPServer({
|
|||
oauthStart,
|
||||
customUserVars,
|
||||
connectionTimeout,
|
||||
configServers,
|
||||
});
|
||||
|
||||
if (discoveryResult.tools && discoveryResult.tools.length > 0) {
|
||||
|
|
|
|||
131
api/server/services/__tests__/MCP.spec.js
Normal file
131
api/server/services/__tests__/MCP.spec.js
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
const mockRegistry = {
|
||||
ensureConfigServers: jest.fn(),
|
||||
getAllServerConfigs: jest.fn(),
|
||||
};
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
getMCPServersRegistry: jest.fn(() => mockRegistry),
|
||||
getMCPManager: jest.fn(),
|
||||
getFlowStateManager: jest.fn(),
|
||||
getOAuthReconnectionManager: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
getTenantId: jest.fn(() => 'tenant-1'),
|
||||
logger: { debug: jest.fn(), info: jest.fn(), warn: jest.fn(), error: jest.fn() },
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getAppConfig: jest.fn(),
|
||||
setCachedTools: jest.fn(),
|
||||
getCachedTools: jest.fn(),
|
||||
getMCPServerTools: jest.fn(),
|
||||
loadCustomConfig: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/cache', () => ({ getLogStores: jest.fn() }));
|
||||
jest.mock('~/models', () => ({
|
||||
findToken: jest.fn(),
|
||||
createToken: jest.fn(),
|
||||
updateToken: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/server/services/GraphTokenService', () => ({
|
||||
getGraphApiToken: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/server/services/Tools/mcp', () => ({
|
||||
reinitMCPServer: jest.fn(),
|
||||
}));
|
||||
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { resolveConfigServers, resolveAllMcpConfigs } = require('../MCP');
|
||||
|
||||
describe('resolveConfigServers', () => {
|
||||
beforeEach(() => jest.clearAllMocks());
|
||||
|
||||
it('resolves config servers for the current request context', async () => {
|
||||
getAppConfig.mockResolvedValue({ mcpConfig: { srv: { url: 'http://a' } } });
|
||||
mockRegistry.ensureConfigServers.mockResolvedValue({ srv: { name: 'srv' } });
|
||||
|
||||
const result = await resolveConfigServers({ user: { id: 'u1', role: 'admin' } });
|
||||
|
||||
expect(result).toEqual({ srv: { name: 'srv' } });
|
||||
expect(getAppConfig).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ role: 'admin', userId: 'u1' }),
|
||||
);
|
||||
expect(mockRegistry.ensureConfigServers).toHaveBeenCalledWith({ srv: { url: 'http://a' } });
|
||||
});
|
||||
|
||||
it('returns {} when ensureConfigServers throws', async () => {
|
||||
getAppConfig.mockResolvedValue({ mcpConfig: { srv: {} } });
|
||||
mockRegistry.ensureConfigServers.mockRejectedValue(new Error('inspect failed'));
|
||||
|
||||
const result = await resolveConfigServers({ user: { id: 'u1' } });
|
||||
|
||||
expect(result).toEqual({});
|
||||
});
|
||||
|
||||
it('returns {} when getAppConfig throws', async () => {
|
||||
getAppConfig.mockRejectedValue(new Error('db timeout'));
|
||||
|
||||
const result = await resolveConfigServers({ user: { id: 'u1' } });
|
||||
|
||||
expect(result).toEqual({});
|
||||
});
|
||||
|
||||
it('passes empty mcpConfig when appConfig has none', async () => {
|
||||
getAppConfig.mockResolvedValue({});
|
||||
mockRegistry.ensureConfigServers.mockResolvedValue({});
|
||||
|
||||
await resolveConfigServers({ user: { id: 'u1' } });
|
||||
|
||||
expect(mockRegistry.ensureConfigServers).toHaveBeenCalledWith({});
|
||||
});
|
||||
});
|
||||
|
||||
describe('resolveAllMcpConfigs', () => {
|
||||
beforeEach(() => jest.clearAllMocks());
|
||||
|
||||
it('merges config servers with base servers', async () => {
|
||||
getAppConfig.mockResolvedValue({ mcpConfig: { cfg_srv: {} } });
|
||||
mockRegistry.ensureConfigServers.mockResolvedValue({ cfg_srv: { name: 'cfg_srv' } });
|
||||
mockRegistry.getAllServerConfigs.mockResolvedValue({
|
||||
cfg_srv: { name: 'cfg_srv' },
|
||||
yaml_srv: { name: 'yaml_srv' },
|
||||
});
|
||||
|
||||
const result = await resolveAllMcpConfigs('u1', { id: 'u1', role: 'user' });
|
||||
|
||||
expect(result).toEqual({
|
||||
cfg_srv: { name: 'cfg_srv' },
|
||||
yaml_srv: { name: 'yaml_srv' },
|
||||
});
|
||||
expect(mockRegistry.getAllServerConfigs).toHaveBeenCalledWith('u1', {
|
||||
cfg_srv: { name: 'cfg_srv' },
|
||||
});
|
||||
});
|
||||
|
||||
it('continues with empty configServers when ensureConfigServers fails', async () => {
|
||||
getAppConfig.mockResolvedValue({ mcpConfig: { srv: {} } });
|
||||
mockRegistry.ensureConfigServers.mockRejectedValue(new Error('inspect failed'));
|
||||
mockRegistry.getAllServerConfigs.mockResolvedValue({ yaml_srv: { name: 'yaml_srv' } });
|
||||
|
||||
const result = await resolveAllMcpConfigs('u1', { id: 'u1' });
|
||||
|
||||
expect(result).toEqual({ yaml_srv: { name: 'yaml_srv' } });
|
||||
expect(mockRegistry.getAllServerConfigs).toHaveBeenCalledWith('u1', {});
|
||||
});
|
||||
|
||||
it('propagates getAllServerConfigs failures', async () => {
|
||||
getAppConfig.mockResolvedValue({ mcpConfig: {} });
|
||||
mockRegistry.ensureConfigServers.mockResolvedValue({});
|
||||
mockRegistry.getAllServerConfigs.mockRejectedValue(new Error('redis down'));
|
||||
|
||||
await expect(resolveAllMcpConfigs('u1', { id: 'u1' })).rejects.toThrow('redis down');
|
||||
});
|
||||
|
||||
it('propagates getAppConfig failures', async () => {
|
||||
getAppConfig.mockRejectedValue(new Error('mongo down'));
|
||||
|
||||
await expect(resolveAllMcpConfigs('u1', { id: 'u1' })).rejects.toThrow('mongo down');
|
||||
});
|
||||
});
|
||||
|
|
@ -64,6 +64,9 @@ jest.mock('~/models', () => ({
|
|||
jest.mock('~/config', () => ({
|
||||
getFlowStateManager: jest.fn(() => ({})),
|
||||
}));
|
||||
jest.mock('~/server/services/MCP', () => ({
|
||||
resolveConfigServers: jest.fn().mockResolvedValue({}),
|
||||
}));
|
||||
jest.mock('~/cache', () => ({
|
||||
getLogStores: jest.fn(() => ({})),
|
||||
}));
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ const { createMCPServersRegistry, createMCPManager } = require('~/config');
|
|||
* Initialize MCP servers
|
||||
*/
|
||||
async function initializeMCPs() {
|
||||
const appConfig = await getAppConfig();
|
||||
const appConfig = await getAppConfig({ baseOnly: true });
|
||||
const mcpServers = appConfig.mcpConfig;
|
||||
|
||||
try {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const { v4: uuidv4 } = require('uuid');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { logger, scopedCacheKey } = require('@librechat/data-schemas');
|
||||
const { EModelEndpoint, Constants, openAISettings, CacheKeys } = require('librechat-data-provider');
|
||||
const { createImportBatchBuilder } = require('./importBatchBuilder');
|
||||
const { cloneMessagesWithTimestamps } = require('./fork');
|
||||
|
|
@ -203,7 +203,7 @@ async function importLibreChatConvo(
|
|||
/* Endpoint configuration */
|
||||
let endpoint = jsonData.endpoint ?? options.endpoint ?? EModelEndpoint.openAI;
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const endpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
|
||||
const endpointsConfig = await cache.get(scopedCacheKey(CacheKeys.ENDPOINT_CONFIG));
|
||||
const endpointConfig = endpointsConfig?.[endpoint];
|
||||
if (!endpointConfig && endpointsConfig) {
|
||||
endpoint = Object.keys(endpointsConfig)[0];
|
||||
|
|
|
|||
|
|
@ -2,7 +2,12 @@ const fs = require('fs');
|
|||
const LdapStrategy = require('passport-ldapauth');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { SystemRoles, ErrorTypes } = require('librechat-data-provider');
|
||||
const { isEnabled, getBalanceConfig, isEmailDomainAllowed } = require('@librechat/api');
|
||||
const {
|
||||
isEnabled,
|
||||
getBalanceConfig,
|
||||
isEmailDomainAllowed,
|
||||
resolveAppConfigForUser,
|
||||
} = require('@librechat/api');
|
||||
const { createUser, findUser, updateUser, countUsers } = require('~/models');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
|
||||
|
|
@ -89,16 +94,6 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
|
|||
const ldapId =
|
||||
(LDAP_ID && userinfo[LDAP_ID]) || userinfo.uid || userinfo.sAMAccountName || userinfo.mail;
|
||||
|
||||
let user = await findUser({ ldapId });
|
||||
if (user && user.provider !== 'ldap') {
|
||||
logger.info(
|
||||
`[ldapStrategy] User ${user.email} already exists with provider ${user.provider}`,
|
||||
);
|
||||
return done(null, false, {
|
||||
message: ErrorTypes.AUTH_FAILED,
|
||||
});
|
||||
}
|
||||
|
||||
const fullNameAttributes = LDAP_FULL_NAME && LDAP_FULL_NAME.split(',');
|
||||
const fullName =
|
||||
fullNameAttributes && fullNameAttributes.length > 0
|
||||
|
|
@ -122,7 +117,31 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
|
|||
);
|
||||
}
|
||||
|
||||
const appConfig = await getAppConfig();
|
||||
// Domain check before findUser for two-phase fast-fail (consistent with SAML/OpenID/social).
|
||||
// This means cross-provider users from blocked domains get 'Email domain not allowed'
|
||||
// instead of AUTH_FAILED — both deny access.
|
||||
const baseConfig = await getAppConfig({ baseOnly: true });
|
||||
if (!isEmailDomainAllowed(mail, baseConfig?.registration?.allowedDomains)) {
|
||||
logger.error(
|
||||
`[LDAP Strategy] Authentication blocked - email domain not allowed [Email: ${mail}]`,
|
||||
);
|
||||
return done(null, false, { message: 'Email domain not allowed' });
|
||||
}
|
||||
|
||||
let user = await findUser({ ldapId });
|
||||
if (user && user.provider !== 'ldap') {
|
||||
logger.info(
|
||||
`[ldapStrategy] User ${user.email} already exists with provider ${user.provider}`,
|
||||
);
|
||||
return done(null, false, {
|
||||
message: ErrorTypes.AUTH_FAILED,
|
||||
});
|
||||
}
|
||||
|
||||
const appConfig = user?.tenantId
|
||||
? await resolveAppConfigForUser(getAppConfig, user)
|
||||
: baseConfig;
|
||||
|
||||
if (!isEmailDomainAllowed(mail, appConfig?.registration?.allowedDomains)) {
|
||||
logger.error(
|
||||
`[LDAP Strategy] Authentication blocked - email domain not allowed [Email: ${mail}]`,
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ jest.mock('@librechat/data-schemas', () => ({
|
|||
}));
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
// isEnabled used for TLS flags
|
||||
isEnabled: jest.fn(() => false),
|
||||
isEmailDomainAllowed: jest.fn(() => true),
|
||||
getBalanceConfig: jest.fn(() => ({ enabled: false })),
|
||||
resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
|
|
@ -30,14 +30,15 @@ jest.mock('~/server/services/Config', () => ({
|
|||
let verifyCallback;
|
||||
jest.mock('passport-ldapauth', () => {
|
||||
return jest.fn().mockImplementation((options, verify) => {
|
||||
verifyCallback = verify; // capture the strategy verify function
|
||||
verifyCallback = verify;
|
||||
return { name: 'ldap', options, verify };
|
||||
});
|
||||
});
|
||||
|
||||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { isEmailDomainAllowed } = require('@librechat/api');
|
||||
const { isEmailDomainAllowed, resolveAppConfigForUser } = require('@librechat/api');
|
||||
const { findUser, createUser, updateUser, countUsers } = require('~/models');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
|
||||
// Helper to call the verify callback and wrap in a Promise for convenience
|
||||
const callVerify = (userinfo) =>
|
||||
|
|
@ -117,6 +118,7 @@ describe('ldapStrategy', () => {
|
|||
expect(user).toBe(false);
|
||||
expect(info).toEqual({ message: ErrorTypes.AUTH_FAILED });
|
||||
expect(createUser).not.toHaveBeenCalled();
|
||||
expect(resolveAppConfigForUser).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('updates an existing ldap user with current LDAP info', async () => {
|
||||
|
|
@ -158,7 +160,6 @@ describe('ldapStrategy', () => {
|
|||
uid: 'uid999',
|
||||
givenName: 'John',
|
||||
cn: 'John Doe',
|
||||
// no mail and no custom LDAP_EMAIL
|
||||
};
|
||||
|
||||
const { user } = await callVerify(userinfo);
|
||||
|
|
@ -180,4 +181,66 @@ describe('ldapStrategy', () => {
|
|||
expect(user).toBe(false);
|
||||
expect(info).toEqual({ message: 'Email domain not allowed' });
|
||||
});
|
||||
|
||||
it('passes getAppConfig and found user to resolveAppConfigForUser', async () => {
|
||||
const existing = {
|
||||
_id: 'u3',
|
||||
provider: 'ldap',
|
||||
email: 'tenant@example.com',
|
||||
ldapId: 'uid-tenant',
|
||||
username: 'tenantuser',
|
||||
name: 'Tenant User',
|
||||
tenantId: 'tenant-a',
|
||||
role: 'USER',
|
||||
};
|
||||
findUser.mockResolvedValue(existing);
|
||||
|
||||
const userinfo = {
|
||||
uid: 'uid-tenant',
|
||||
mail: 'tenant@example.com',
|
||||
givenName: 'Tenant',
|
||||
cn: 'Tenant User',
|
||||
};
|
||||
|
||||
await callVerify(userinfo);
|
||||
|
||||
expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existing);
|
||||
});
|
||||
|
||||
it('uses baseConfig for new user without calling resolveAppConfigForUser', async () => {
|
||||
findUser.mockResolvedValue(null);
|
||||
|
||||
const userinfo = {
|
||||
uid: 'uid-new',
|
||||
mail: 'newuser@example.com',
|
||||
givenName: 'New',
|
||||
cn: 'New User',
|
||||
};
|
||||
|
||||
await callVerify(userinfo);
|
||||
|
||||
expect(resolveAppConfigForUser).not.toHaveBeenCalled();
|
||||
expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true });
|
||||
});
|
||||
|
||||
it('should block login when tenant config restricts the domain', async () => {
|
||||
const existing = {
|
||||
_id: 'u-blocked',
|
||||
provider: 'ldap',
|
||||
ldapId: 'uid-tenant',
|
||||
tenantId: 'tenant-strict',
|
||||
role: 'USER',
|
||||
};
|
||||
findUser.mockResolvedValue(existing);
|
||||
resolveAppConfigForUser.mockResolvedValue({
|
||||
registration: { allowedDomains: ['other.com'] },
|
||||
});
|
||||
isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false);
|
||||
|
||||
const userinfo = { uid: 'uid-tenant', mail: 'user@example.com', givenName: 'Test', cn: 'Test' };
|
||||
const { user, info } = await callVerify(userinfo);
|
||||
|
||||
expect(user).toBe(false);
|
||||
expect(info).toEqual({ message: 'Email domain not allowed' });
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ const {
|
|||
findOpenIDUser,
|
||||
getBalanceConfig,
|
||||
isEmailDomainAllowed,
|
||||
resolveAppConfigForUser,
|
||||
} = require('@librechat/api');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { findUser, createUser, updateUser } = require('~/models');
|
||||
|
|
@ -468,9 +469,10 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) {
|
|||
Object.assign(userinfo, providerUserinfo);
|
||||
}
|
||||
|
||||
const appConfig = await getAppConfig();
|
||||
const email = getOpenIdEmail(userinfo);
|
||||
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
|
||||
const baseConfig = await getAppConfig({ baseOnly: true });
|
||||
if (!isEmailDomainAllowed(email, baseConfig?.registration?.allowedDomains)) {
|
||||
logger.error(
|
||||
`[OpenID Strategy] Authentication blocked - email domain not allowed [Identifier: ${email}]`,
|
||||
);
|
||||
|
|
@ -491,6 +493,15 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) {
|
|||
throw new Error(ErrorTypes.AUTH_FAILED);
|
||||
}
|
||||
|
||||
const appConfig = user?.tenantId ? await resolveAppConfigForUser(getAppConfig, user) : baseConfig;
|
||||
|
||||
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
logger.error(
|
||||
`[OpenID Strategy] Authentication blocked - email domain not allowed [Identifier: ${email}]`,
|
||||
);
|
||||
throw new Error('Email domain not allowed');
|
||||
}
|
||||
|
||||
const fullName = getFullName(userinfo);
|
||||
|
||||
const requiredRole = process.env.OPENID_REQUIRED_ROLE;
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -5,7 +5,11 @@ const passport = require('passport');
|
|||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { hashToken, logger } = require('@librechat/data-schemas');
|
||||
const { Strategy: SamlStrategy } = require('@node-saml/passport-saml');
|
||||
const { getBalanceConfig, isEmailDomainAllowed } = require('@librechat/api');
|
||||
const {
|
||||
getBalanceConfig,
|
||||
isEmailDomainAllowed,
|
||||
resolveAppConfigForUser,
|
||||
} = require('@librechat/api');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { findUser, createUser, updateUser } = require('~/models');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
|
|
@ -193,9 +197,9 @@ async function setupSaml() {
|
|||
logger.debug('[samlStrategy] SAML profile:', profile);
|
||||
|
||||
const userEmail = getEmail(profile) || '';
|
||||
const appConfig = await getAppConfig();
|
||||
|
||||
if (!isEmailDomainAllowed(userEmail, appConfig?.registration?.allowedDomains)) {
|
||||
const baseConfig = await getAppConfig({ baseOnly: true });
|
||||
if (!isEmailDomainAllowed(userEmail, baseConfig?.registration?.allowedDomains)) {
|
||||
logger.error(
|
||||
`[SAML Strategy] Authentication blocked - email domain not allowed [Email: ${userEmail}]`,
|
||||
);
|
||||
|
|
@ -223,6 +227,17 @@ async function setupSaml() {
|
|||
});
|
||||
}
|
||||
|
||||
const appConfig = user?.tenantId
|
||||
? await resolveAppConfigForUser(getAppConfig, user)
|
||||
: baseConfig;
|
||||
|
||||
if (!isEmailDomainAllowed(userEmail, appConfig?.registration?.allowedDomains)) {
|
||||
logger.error(
|
||||
`[SAML Strategy] Authentication blocked - email domain not allowed [Email: ${userEmail}]`,
|
||||
);
|
||||
return done(null, false, { message: 'Email domain not allowed' });
|
||||
}
|
||||
|
||||
const fullName = getFullName(profile);
|
||||
|
||||
const username = convertToUsername(
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ jest.mock('@librechat/api', () => ({
|
|||
tokenCredits: 1000,
|
||||
startBalance: 1000,
|
||||
})),
|
||||
resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})),
|
||||
}));
|
||||
jest.mock('~/server/services/Config/EndpointService', () => ({
|
||||
config: {},
|
||||
|
|
@ -47,6 +48,9 @@ const fs = require('fs');
|
|||
const path = require('path');
|
||||
const fetch = require('node-fetch');
|
||||
const { Strategy: SamlStrategy } = require('@node-saml/passport-saml');
|
||||
const { findUser } = require('~/models');
|
||||
const { resolveAppConfigForUser } = require('@librechat/api');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { setupSaml, getCertificateContent } = require('./samlStrategy');
|
||||
|
||||
// Configure fs mock
|
||||
|
|
@ -440,4 +444,50 @@ u7wlOSk+oFzDIO/UILIA
|
|||
|
||||
expect(fetch).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should pass the found user to resolveAppConfigForUser', async () => {
|
||||
const existingUser = {
|
||||
_id: 'tenant-user-id',
|
||||
provider: 'saml',
|
||||
samlId: 'saml-1234',
|
||||
email: 'test@example.com',
|
||||
tenantId: 'tenant-c',
|
||||
role: 'USER',
|
||||
};
|
||||
findUser.mockResolvedValue(existingUser);
|
||||
|
||||
const profile = { ...baseProfile };
|
||||
await validate(profile);
|
||||
|
||||
expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existingUser);
|
||||
});
|
||||
|
||||
it('should use baseConfig for new SAML user without calling resolveAppConfigForUser', async () => {
|
||||
const profile = { ...baseProfile };
|
||||
await validate(profile);
|
||||
|
||||
expect(resolveAppConfigForUser).not.toHaveBeenCalled();
|
||||
expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true });
|
||||
});
|
||||
|
||||
it('should block login when tenant config restricts the domain', async () => {
|
||||
const { isEmailDomainAllowed } = require('@librechat/api');
|
||||
const existingUser = {
|
||||
_id: 'tenant-blocked',
|
||||
provider: 'saml',
|
||||
samlId: 'saml-1234',
|
||||
email: 'test@example.com',
|
||||
tenantId: 'tenant-restrict',
|
||||
role: 'USER',
|
||||
};
|
||||
findUser.mockResolvedValue(existingUser);
|
||||
resolveAppConfigForUser.mockResolvedValue({
|
||||
registration: { allowedDomains: ['other.com'] },
|
||||
});
|
||||
isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false);
|
||||
|
||||
const profile = { ...baseProfile };
|
||||
const { user } = await validate(profile);
|
||||
expect(user).toBe(false);
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { isEnabled, isEmailDomainAllowed } = require('@librechat/api');
|
||||
const { isEnabled, isEmailDomainAllowed, resolveAppConfigForUser } = require('@librechat/api');
|
||||
const { createSocialUser, handleExistingUser } = require('./process');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { findUser } = require('~/models');
|
||||
|
|
@ -13,9 +13,8 @@ const socialLogin =
|
|||
profile,
|
||||
});
|
||||
|
||||
const appConfig = await getAppConfig();
|
||||
|
||||
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
const baseConfig = await getAppConfig({ baseOnly: true });
|
||||
if (!isEmailDomainAllowed(email, baseConfig?.registration?.allowedDomains)) {
|
||||
logger.error(
|
||||
`[${provider}Login] Authentication blocked - email domain not allowed [Email: ${email}]`,
|
||||
);
|
||||
|
|
@ -41,6 +40,20 @@ const socialLogin =
|
|||
}
|
||||
}
|
||||
|
||||
const appConfig = existingUser?.tenantId
|
||||
? await resolveAppConfigForUser(getAppConfig, existingUser)
|
||||
: baseConfig;
|
||||
|
||||
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
logger.error(
|
||||
`[${provider}Login] Authentication blocked - email domain not allowed [Email: ${email}]`,
|
||||
);
|
||||
const error = new Error(ErrorTypes.AUTH_FAILED);
|
||||
error.code = ErrorTypes.AUTH_FAILED;
|
||||
error.message = 'Email domain not allowed';
|
||||
return cb(error);
|
||||
}
|
||||
|
||||
if (existingUser?.provider === provider) {
|
||||
await handleExistingUser(existingUser, avatarUrl, appConfig, email);
|
||||
return cb(null, existingUser);
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ const { ErrorTypes } = require('librechat-data-provider');
|
|||
const { createSocialUser, handleExistingUser } = require('./process');
|
||||
const socialLogin = require('./socialLogin');
|
||||
const { findUser } = require('~/models');
|
||||
const { resolveAppConfigForUser } = require('@librechat/api');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => {
|
||||
const actualModule = jest.requireActual('@librechat/data-schemas');
|
||||
|
|
@ -25,6 +27,10 @@ jest.mock('@librechat/api', () => ({
|
|||
...jest.requireActual('@librechat/api'),
|
||||
isEnabled: jest.fn().mockReturnValue(true),
|
||||
isEmailDomainAllowed: jest.fn().mockReturnValue(true),
|
||||
resolveAppConfigForUser: jest.fn().mockResolvedValue({
|
||||
fileStrategy: 'local',
|
||||
balance: { enabled: false },
|
||||
}),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
|
|
@ -66,10 +72,7 @@ describe('socialLogin', () => {
|
|||
googleId: googleId,
|
||||
};
|
||||
|
||||
/** Mock findUser to return user on first call (by googleId), null on second call */
|
||||
findUser
|
||||
.mockResolvedValueOnce(existingUser) // First call: finds by googleId
|
||||
.mockResolvedValueOnce(null); // Second call would be by email, but won't be reached
|
||||
findUser.mockResolvedValueOnce(existingUser).mockResolvedValueOnce(null);
|
||||
|
||||
const mockProfile = {
|
||||
id: googleId,
|
||||
|
|
@ -83,13 +86,9 @@ describe('socialLogin', () => {
|
|||
|
||||
await loginFn(null, null, null, mockProfile, callback);
|
||||
|
||||
/** Verify it searched by googleId first */
|
||||
expect(findUser).toHaveBeenNthCalledWith(1, { googleId: googleId });
|
||||
|
||||
/** Verify it did NOT search by email (because it found user by googleId) */
|
||||
expect(findUser).toHaveBeenCalledTimes(1);
|
||||
|
||||
/** Verify handleExistingUser was called with the new email */
|
||||
expect(handleExistingUser).toHaveBeenCalledWith(
|
||||
existingUser,
|
||||
'https://example.com/avatar.png',
|
||||
|
|
@ -97,7 +96,6 @@ describe('socialLogin', () => {
|
|||
newEmail,
|
||||
);
|
||||
|
||||
/** Verify callback was called with success */
|
||||
expect(callback).toHaveBeenCalledWith(null, existingUser);
|
||||
});
|
||||
|
||||
|
|
@ -113,7 +111,7 @@ describe('socialLogin', () => {
|
|||
facebookId: facebookId,
|
||||
};
|
||||
|
||||
findUser.mockResolvedValue(existingUser); // Always returns user
|
||||
findUser.mockResolvedValue(existingUser);
|
||||
|
||||
const mockProfile = {
|
||||
id: facebookId,
|
||||
|
|
@ -127,7 +125,6 @@ describe('socialLogin', () => {
|
|||
|
||||
await loginFn(null, null, null, mockProfile, callback);
|
||||
|
||||
/** Verify it searched by facebookId first */
|
||||
expect(findUser).toHaveBeenCalledWith({ facebookId: facebookId });
|
||||
expect(findUser.mock.calls[0]).toEqual([{ facebookId: facebookId }]);
|
||||
|
||||
|
|
@ -150,13 +147,10 @@ describe('socialLogin', () => {
|
|||
_id: 'user789',
|
||||
email: email,
|
||||
provider: 'google',
|
||||
googleId: 'old-google-id', // Different googleId (edge case)
|
||||
googleId: 'old-google-id',
|
||||
};
|
||||
|
||||
/** First call (by googleId) returns null, second call (by email) returns user */
|
||||
findUser
|
||||
.mockResolvedValueOnce(null) // By googleId
|
||||
.mockResolvedValueOnce(existingUser); // By email
|
||||
findUser.mockResolvedValueOnce(null).mockResolvedValueOnce(existingUser);
|
||||
|
||||
const mockProfile = {
|
||||
id: googleId,
|
||||
|
|
@ -170,13 +164,10 @@ describe('socialLogin', () => {
|
|||
|
||||
await loginFn(null, null, null, mockProfile, callback);
|
||||
|
||||
/** Verify both searches happened */
|
||||
expect(findUser).toHaveBeenNthCalledWith(1, { googleId: googleId });
|
||||
/** Email passed as-is; findUser implementation handles case normalization */
|
||||
expect(findUser).toHaveBeenNthCalledWith(2, { email: email });
|
||||
expect(findUser).toHaveBeenCalledTimes(2);
|
||||
|
||||
/** Verify warning log */
|
||||
expect(logger.warn).toHaveBeenCalledWith(
|
||||
`[${provider}Login] User found by email: ${email} but not by ${provider}Id`,
|
||||
);
|
||||
|
|
@ -197,7 +188,6 @@ describe('socialLogin', () => {
|
|||
googleId: googleId,
|
||||
};
|
||||
|
||||
/** Both searches return null */
|
||||
findUser.mockResolvedValue(null);
|
||||
createSocialUser.mockResolvedValue(newUser);
|
||||
|
||||
|
|
@ -213,10 +203,8 @@ describe('socialLogin', () => {
|
|||
|
||||
await loginFn(null, null, null, mockProfile, callback);
|
||||
|
||||
/** Verify both searches happened */
|
||||
expect(findUser).toHaveBeenCalledTimes(2);
|
||||
|
||||
/** Verify createSocialUser was called */
|
||||
expect(createSocialUser).toHaveBeenCalledWith({
|
||||
email: email,
|
||||
avatarUrl: 'https://example.com/avatar.png',
|
||||
|
|
@ -242,12 +230,10 @@ describe('socialLogin', () => {
|
|||
const existingUser = {
|
||||
_id: 'user123',
|
||||
email: email,
|
||||
provider: 'local', // Different provider
|
||||
provider: 'local',
|
||||
};
|
||||
|
||||
findUser
|
||||
.mockResolvedValueOnce(null) // By googleId
|
||||
.mockResolvedValueOnce(existingUser); // By email
|
||||
findUser.mockResolvedValueOnce(null).mockResolvedValueOnce(existingUser);
|
||||
|
||||
const mockProfile = {
|
||||
id: googleId,
|
||||
|
|
@ -261,7 +247,6 @@ describe('socialLogin', () => {
|
|||
|
||||
await loginFn(null, null, null, mockProfile, callback);
|
||||
|
||||
/** Verify error callback */
|
||||
expect(callback).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
code: ErrorTypes.AUTH_FAILED,
|
||||
|
|
@ -274,4 +259,104 @@ describe('socialLogin', () => {
|
|||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Tenant-scoped config', () => {
|
||||
it('should call resolveAppConfigForUser for tenant user', async () => {
|
||||
const provider = 'google';
|
||||
const googleId = 'google-tenant-user';
|
||||
const email = 'tenant@example.com';
|
||||
|
||||
const existingUser = {
|
||||
_id: 'userTenant',
|
||||
email,
|
||||
provider: 'google',
|
||||
googleId,
|
||||
tenantId: 'tenant-b',
|
||||
role: 'USER',
|
||||
};
|
||||
|
||||
findUser.mockResolvedValue(existingUser);
|
||||
|
||||
const mockProfile = {
|
||||
id: googleId,
|
||||
emails: [{ value: email, verified: true }],
|
||||
photos: [{ value: 'https://example.com/avatar.png' }],
|
||||
name: { givenName: 'Tenant', familyName: 'User' },
|
||||
};
|
||||
|
||||
const loginFn = socialLogin(provider, mockGetProfileDetails);
|
||||
const callback = jest.fn();
|
||||
|
||||
await loginFn(null, null, null, mockProfile, callback);
|
||||
|
||||
expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existingUser);
|
||||
});
|
||||
|
||||
it('should use baseConfig for non-tenant user without calling resolveAppConfigForUser', async () => {
|
||||
const provider = 'google';
|
||||
const googleId = 'google-new-tenant';
|
||||
const email = 'new@example.com';
|
||||
|
||||
findUser.mockResolvedValue(null);
|
||||
createSocialUser.mockResolvedValue({
|
||||
_id: 'newUser',
|
||||
email,
|
||||
provider: 'google',
|
||||
googleId,
|
||||
});
|
||||
|
||||
const mockProfile = {
|
||||
id: googleId,
|
||||
emails: [{ value: email, verified: true }],
|
||||
photos: [{ value: 'https://example.com/avatar.png' }],
|
||||
name: { givenName: 'New', familyName: 'User' },
|
||||
};
|
||||
|
||||
const loginFn = socialLogin(provider, mockGetProfileDetails);
|
||||
const callback = jest.fn();
|
||||
|
||||
await loginFn(null, null, null, mockProfile, callback);
|
||||
|
||||
expect(resolveAppConfigForUser).not.toHaveBeenCalled();
|
||||
expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true });
|
||||
});
|
||||
|
||||
it('should block login when tenant config restricts the domain', async () => {
|
||||
const { isEmailDomainAllowed } = require('@librechat/api');
|
||||
const provider = 'google';
|
||||
const googleId = 'google-tenant-blocked';
|
||||
const email = 'blocked@example.com';
|
||||
|
||||
const existingUser = {
|
||||
_id: 'userBlocked',
|
||||
email,
|
||||
provider: 'google',
|
||||
googleId,
|
||||
tenantId: 'tenant-restrict',
|
||||
role: 'USER',
|
||||
};
|
||||
|
||||
findUser.mockResolvedValue(existingUser);
|
||||
resolveAppConfigForUser.mockResolvedValue({
|
||||
registration: { allowedDomains: ['other.com'] },
|
||||
});
|
||||
isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false);
|
||||
|
||||
const mockProfile = {
|
||||
id: googleId,
|
||||
emails: [{ value: email, verified: true }],
|
||||
photos: [{ value: 'https://example.com/avatar.png' }],
|
||||
name: { givenName: 'Blocked', familyName: 'User' },
|
||||
};
|
||||
|
||||
const loginFn = socialLogin(provider, mockGetProfileDetails);
|
||||
const callback = jest.fn();
|
||||
|
||||
await loginFn(null, null, null, mockProfile, callback);
|
||||
|
||||
expect(callback).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ message: 'Email domain not allowed' }),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1813,3 +1813,57 @@ describe('GLM Model Tests (Zhipu AI)', () => {
|
|||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Mistral Model Tests', () => {
|
||||
describe('getModelMaxTokens', () => {
|
||||
test('should return correct tokens for mistral-large-3 (256k context)', () => {
|
||||
expect(getModelMaxTokens('mistral-large-3', EModelEndpoint.custom)).toBe(
|
||||
maxTokensMap[EModelEndpoint.custom]['mistral-large-3'],
|
||||
);
|
||||
});
|
||||
|
||||
test('should match mistral-large-3 for suffixed variants', () => {
|
||||
expect(getModelMaxTokens('mistral-large-3-instruct', EModelEndpoint.custom)).toBe(
|
||||
maxTokensMap[EModelEndpoint.custom]['mistral-large-3'],
|
||||
);
|
||||
});
|
||||
|
||||
test('should not match mistral-large-3 for generic mistral-large', () => {
|
||||
expect(getModelMaxTokens('mistral-large', EModelEndpoint.custom)).toBe(
|
||||
maxTokensMap[EModelEndpoint.custom]['mistral-large'],
|
||||
);
|
||||
expect(getModelMaxTokens('mistral-large-latest', EModelEndpoint.custom)).toBe(
|
||||
maxTokensMap[EModelEndpoint.custom]['mistral-large'],
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('matchModelName', () => {
|
||||
test('should match mistral-large-3 exactly', () => {
|
||||
expect(matchModelName('mistral-large-3', EModelEndpoint.custom)).toBe('mistral-large-3');
|
||||
});
|
||||
|
||||
test('should match mistral-large-3 for prefixed/suffixed variants', () => {
|
||||
expect(matchModelName('mistral/mistral-large-3', EModelEndpoint.custom)).toBe(
|
||||
'mistral-large-3',
|
||||
);
|
||||
expect(matchModelName('mistral-large-3-instruct', EModelEndpoint.custom)).toBe(
|
||||
'mistral-large-3',
|
||||
);
|
||||
});
|
||||
|
||||
test('should match generic mistral-large for non-3 variants', () => {
|
||||
expect(matchModelName('mistral-large-latest', EModelEndpoint.custom)).toBe('mistral-large');
|
||||
});
|
||||
});
|
||||
|
||||
describe('findMatchingPattern', () => {
|
||||
test('should prefer mistral-large-3 over mistral-large for mistral-large-3 variants', () => {
|
||||
const result = findMatchingPattern(
|
||||
'mistral-large-3-instruct',
|
||||
maxTokensMap[EModelEndpoint.custom],
|
||||
);
|
||||
expect(result).toBe('mistral-large-3');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -140,6 +140,55 @@ export function useMessagesOperations() {
|
|||
);
|
||||
}
|
||||
|
||||
type OptionalMessagesOps = Pick<
|
||||
MessagesViewContextValue,
|
||||
'ask' | 'regenerate' | 'handleContinue' | 'getMessages' | 'setMessages'
|
||||
>;
|
||||
|
||||
const NOOP_OPS: OptionalMessagesOps = {
|
||||
ask: () => {},
|
||||
regenerate: () => {},
|
||||
handleContinue: () => {},
|
||||
getMessages: () => undefined,
|
||||
setMessages: () => {},
|
||||
};
|
||||
|
||||
/**
|
||||
* Hook for components that need message operations but may render outside MessagesViewProvider
|
||||
* (e.g. the /search route). Returns no-op stubs when the provider is absent — UI actions will
|
||||
* be silently discarded rather than crashing. Callers must use optional chaining on
|
||||
* `getMessages()` results, as it returns `undefined` outside the provider.
|
||||
*/
|
||||
export function useOptionalMessagesOperations(): OptionalMessagesOps {
|
||||
const context = useContext(MessagesViewContext);
|
||||
const ask = context?.ask;
|
||||
const regenerate = context?.regenerate;
|
||||
const handleContinue = context?.handleContinue;
|
||||
const getMessages = context?.getMessages;
|
||||
const setMessages = context?.setMessages;
|
||||
return useMemo(
|
||||
() => ({
|
||||
ask: ask ?? NOOP_OPS.ask,
|
||||
regenerate: regenerate ?? NOOP_OPS.regenerate,
|
||||
handleContinue: handleContinue ?? NOOP_OPS.handleContinue,
|
||||
getMessages: getMessages ?? NOOP_OPS.getMessages,
|
||||
setMessages: setMessages ?? NOOP_OPS.setMessages,
|
||||
}),
|
||||
[ask, regenerate, handleContinue, getMessages, setMessages],
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook for components that need conversation data but may render outside MessagesViewProvider
|
||||
* (e.g. the /search route). Returns `undefined` for both fields when the provider is absent.
|
||||
*/
|
||||
export function useOptionalMessagesConversation() {
|
||||
const context = useContext(MessagesViewContext);
|
||||
const conversation = context?.conversation;
|
||||
const conversationId = context?.conversationId;
|
||||
return useMemo(() => ({ conversation, conversationId }), [conversation, conversationId]);
|
||||
}
|
||||
|
||||
/** Hook for components that only need message state */
|
||||
export function useMessagesState() {
|
||||
const { index, latestMessageId, latestMessageDepth, setLatestMessage } = useMessagesViewContext();
|
||||
|
|
|
|||
53
client/src/Providers/__tests__/MessagesViewContext.spec.tsx
Normal file
53
client/src/Providers/__tests__/MessagesViewContext.spec.tsx
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
import { renderHook } from '@testing-library/react';
|
||||
import {
|
||||
useOptionalMessagesOperations,
|
||||
useOptionalMessagesConversation,
|
||||
} from '../MessagesViewContext';
|
||||
|
||||
describe('useOptionalMessagesOperations', () => {
|
||||
it('returns noop stubs when rendered outside MessagesViewProvider', () => {
|
||||
const { result } = renderHook(() => useOptionalMessagesOperations());
|
||||
|
||||
expect(result.current.ask).toBeInstanceOf(Function);
|
||||
expect(result.current.regenerate).toBeInstanceOf(Function);
|
||||
expect(result.current.handleContinue).toBeInstanceOf(Function);
|
||||
expect(result.current.getMessages).toBeInstanceOf(Function);
|
||||
expect(result.current.setMessages).toBeInstanceOf(Function);
|
||||
});
|
||||
|
||||
it('noop stubs do not throw when called', () => {
|
||||
const { result } = renderHook(() => useOptionalMessagesOperations());
|
||||
|
||||
expect(() => result.current.ask({} as never)).not.toThrow();
|
||||
expect(() => result.current.regenerate({} as never)).not.toThrow();
|
||||
expect(() => result.current.handleContinue({} as never)).not.toThrow();
|
||||
expect(() => result.current.setMessages([])).not.toThrow();
|
||||
});
|
||||
|
||||
it('getMessages returns undefined outside the provider', () => {
|
||||
const { result } = renderHook(() => useOptionalMessagesOperations());
|
||||
expect(result.current.getMessages()).toBeUndefined();
|
||||
});
|
||||
|
||||
it('returns stable references across re-renders', () => {
|
||||
const { result, rerender } = renderHook(() => useOptionalMessagesOperations());
|
||||
const first = result.current;
|
||||
rerender();
|
||||
expect(result.current).toBe(first);
|
||||
});
|
||||
});
|
||||
|
||||
describe('useOptionalMessagesConversation', () => {
|
||||
it('returns undefined fields when rendered outside MessagesViewProvider', () => {
|
||||
const { result } = renderHook(() => useOptionalMessagesConversation());
|
||||
expect(result.current.conversation).toBeUndefined();
|
||||
expect(result.current.conversationId).toBeUndefined();
|
||||
});
|
||||
|
||||
it('returns stable references across re-renders', () => {
|
||||
const { result, rerender } = renderHook(() => useOptionalMessagesConversation());
|
||||
const first = result.current;
|
||||
rerender();
|
||||
expect(result.current).toBe(first);
|
||||
});
|
||||
});
|
||||
|
|
@ -355,6 +355,28 @@ export type TOptions = {
|
|||
|
||||
export type TAskFunction = (props: TAskProps, options?: TOptions) => void;
|
||||
|
||||
/**
|
||||
* Stable context object passed from non-memo'd wrapper components (Message, MessageContent)
|
||||
* to memo'd inner components (MessageRender, ContentRender) via props.
|
||||
*
|
||||
* This avoids subscribing to ChatContext inside memo'd components, which would bypass React.memo
|
||||
* and cause unnecessary re-renders when `isSubmitting` changes during streaming.
|
||||
*
|
||||
* The `isSubmitting` property should use a getter backed by a ref so it returns the current
|
||||
* value at call-time (for callback guards) without being a reactive dependency.
|
||||
*/
|
||||
export type TMessageChatContext = {
|
||||
ask: (...args: Parameters<TAskFunction>) => void;
|
||||
index: number;
|
||||
regenerate: (message: t.TMessage, options?: { addedConvo?: t.TConversation | null }) => void;
|
||||
conversation: t.TConversation | null;
|
||||
latestMessageId: string | undefined;
|
||||
latestMessageDepth: number | undefined;
|
||||
handleContinue: (e: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
/** Should be a getter backed by a ref — reads current value without triggering re-renders */
|
||||
readonly isSubmitting: boolean;
|
||||
};
|
||||
|
||||
export type TMessageProps = {
|
||||
conversation?: t.TConversation | null;
|
||||
messageId?: string | null;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import { useCallback, useRef } from 'react';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { MicOff } from 'lucide-react';
|
||||
import { useToastContext, TooltipAnchor, ListeningIcon, Spinner } from '@librechat/client';
|
||||
import { useLocalize, useSpeechToText, useGetAudioSettings } from '~/hooks';
|
||||
|
|
@ -7,7 +7,7 @@ import { globalAudioId } from '~/common';
|
|||
import { cn } from '~/utils';
|
||||
|
||||
const isExternalSTT = (speechToTextEndpoint: string) => speechToTextEndpoint === 'external';
|
||||
export default function AudioRecorder({
|
||||
export default memo(function AudioRecorder({
|
||||
disabled,
|
||||
ask,
|
||||
methods,
|
||||
|
|
@ -26,10 +26,12 @@ export default function AudioRecorder({
|
|||
const { speechToTextEndpoint } = useGetAudioSettings();
|
||||
|
||||
const existingTextRef = useRef<string>('');
|
||||
const isSubmittingRef = useRef(isSubmitting);
|
||||
isSubmittingRef.current = isSubmitting;
|
||||
|
||||
const onTranscriptionComplete = useCallback(
|
||||
(text: string) => {
|
||||
if (isSubmitting) {
|
||||
if (isSubmittingRef.current) {
|
||||
showToast({
|
||||
message: localize('com_ui_speech_while_submitting'),
|
||||
status: 'error',
|
||||
|
|
@ -52,7 +54,7 @@ export default function AudioRecorder({
|
|||
existingTextRef.current = '';
|
||||
}
|
||||
},
|
||||
[ask, reset, showToast, localize, isSubmitting, speechToTextEndpoint],
|
||||
[ask, reset, showToast, localize, speechToTextEndpoint],
|
||||
);
|
||||
|
||||
const setText = useCallback(
|
||||
|
|
@ -125,4 +127,4 @@ export default function AudioRecorder({
|
|||
}
|
||||
/>
|
||||
);
|
||||
}
|
||||
});
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ import { useWatch } from 'react-hook-form';
|
|||
import { TextareaAutosize } from '@librechat/client';
|
||||
import { useRecoilState, useRecoilValue } from 'recoil';
|
||||
import { Constants, isAssistantsEndpoint, isAgentsEndpoint } from 'librechat-data-provider';
|
||||
import type { TConversation } from 'librechat-data-provider';
|
||||
import type { ExtendedFile, FileSetter, ConvoGenerator } from '~/common';
|
||||
import {
|
||||
useChatContext,
|
||||
useChatFormContext,
|
||||
|
|
@ -35,7 +37,30 @@ import BadgeRow from './BadgeRow';
|
|||
import Mention from './Mention';
|
||||
import store from '~/store';
|
||||
|
||||
const ChatForm = memo(({ index = 0 }: { index?: number }) => {
|
||||
interface ChatFormProps {
|
||||
index: number;
|
||||
/** From ChatContext — individual values so memo can compare them */
|
||||
files: Map<string, ExtendedFile>;
|
||||
setFiles: FileSetter;
|
||||
conversation: TConversation | null;
|
||||
isSubmitting: boolean;
|
||||
filesLoading: boolean;
|
||||
setFilesLoading: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
newConversation: ConvoGenerator;
|
||||
handleStopGenerating: (e: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
}
|
||||
|
||||
const ChatForm = memo(function ChatForm({
|
||||
index,
|
||||
files,
|
||||
setFiles,
|
||||
conversation,
|
||||
isSubmitting,
|
||||
filesLoading,
|
||||
setFilesLoading,
|
||||
newConversation,
|
||||
handleStopGenerating,
|
||||
}: ChatFormProps) {
|
||||
const submitButtonRef = useRef<HTMLButtonElement>(null);
|
||||
const textAreaRef = useRef<HTMLTextAreaElement>(null);
|
||||
useFocusChatEffect(textAreaRef);
|
||||
|
|
@ -65,15 +90,6 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
|
|||
|
||||
const { requiresKey } = useRequiresKey();
|
||||
const methods = useChatFormContext();
|
||||
const {
|
||||
files,
|
||||
setFiles,
|
||||
conversation,
|
||||
isSubmitting,
|
||||
filesLoading,
|
||||
newConversation,
|
||||
handleStopGenerating,
|
||||
} = useChatContext();
|
||||
const {
|
||||
generateConversation,
|
||||
conversation: addedConvo,
|
||||
|
|
@ -120,6 +136,15 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
|
|||
}
|
||||
}, [isCollapsed]);
|
||||
|
||||
const handleTextareaFocus = useCallback(() => {
|
||||
handleFocusOrClick();
|
||||
setIsTextAreaFocused(true);
|
||||
}, [handleFocusOrClick]);
|
||||
|
||||
const handleTextareaBlur = useCallback(() => {
|
||||
setIsTextAreaFocused(false);
|
||||
}, []);
|
||||
|
||||
useAutoSave({
|
||||
files,
|
||||
setFiles,
|
||||
|
|
@ -253,7 +278,12 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
|
|||
handleSaveBadges={handleSaveBadges}
|
||||
setBadges={setBadges}
|
||||
/>
|
||||
<FileFormChat conversation={conversation} />
|
||||
<FileFormChat
|
||||
conversation={conversation}
|
||||
files={files}
|
||||
setFiles={setFiles}
|
||||
setFilesLoading={setFilesLoading}
|
||||
/>
|
||||
{endpoint && (
|
||||
<div className={cn('flex', isRTL ? 'flex-row-reverse' : 'flex-row')}>
|
||||
<div
|
||||
|
|
@ -284,11 +314,8 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
|
|||
tabIndex={0}
|
||||
data-testid="text-input"
|
||||
rows={1}
|
||||
onFocus={() => {
|
||||
handleFocusOrClick();
|
||||
setIsTextAreaFocused(true);
|
||||
}}
|
||||
onBlur={setIsTextAreaFocused.bind(null, false)}
|
||||
onFocus={handleTextareaFocus}
|
||||
onBlur={handleTextareaBlur}
|
||||
aria-label={localize('com_ui_message_input')}
|
||||
onClick={handleFocusOrClick}
|
||||
style={{ height: 44, overflowY: 'auto' }}
|
||||
|
|
@ -315,7 +342,13 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
|
|||
)}
|
||||
>
|
||||
<div className={`${isRTL ? 'mr-2' : 'ml-2'}`}>
|
||||
<AttachFileChat conversation={conversation} disableInputs={disableInputs} />
|
||||
<AttachFileChat
|
||||
conversation={conversation}
|
||||
disableInputs={disableInputs}
|
||||
files={files}
|
||||
setFiles={setFiles}
|
||||
setFilesLoading={setFilesLoading}
|
||||
/>
|
||||
</div>
|
||||
<BadgeRow
|
||||
showEphemeralBadges={
|
||||
|
|
@ -360,5 +393,77 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
|
|||
</form>
|
||||
);
|
||||
});
|
||||
ChatForm.displayName = 'ChatForm';
|
||||
|
||||
export default ChatForm;
|
||||
/**
|
||||
* Wrapper that subscribes to ChatContext and passes stable individual values
|
||||
* to the memo'd ChatForm. This prevents ChatForm from re-rendering on every
|
||||
* streaming chunk — it only re-renders when the specific values it uses change.
|
||||
*/
|
||||
function ChatFormWrapper({ index = 0 }: { index?: number }) {
|
||||
const {
|
||||
files,
|
||||
setFiles,
|
||||
conversation,
|
||||
isSubmitting,
|
||||
filesLoading,
|
||||
setFilesLoading,
|
||||
newConversation,
|
||||
handleStopGenerating,
|
||||
} = useChatContext();
|
||||
|
||||
/**
|
||||
* Stabilize conversation reference: only update when rendering-relevant fields change,
|
||||
* not on every metadata update (e.g., title generation during streaming).
|
||||
*/
|
||||
const hasMessages = (conversation?.messages?.length ?? 0) > 0;
|
||||
const stableConversation = useMemo(
|
||||
() => conversation,
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
[
|
||||
conversation?.conversationId,
|
||||
conversation?.endpoint,
|
||||
conversation?.endpointType,
|
||||
conversation?.agent_id,
|
||||
conversation?.assistant_id,
|
||||
conversation?.spec,
|
||||
conversation?.useResponsesApi,
|
||||
conversation?.model,
|
||||
hasMessages,
|
||||
],
|
||||
);
|
||||
|
||||
/** Stabilize function refs so they never trigger ChatForm re-renders */
|
||||
const handleStopRef = useRef(handleStopGenerating);
|
||||
handleStopRef.current = handleStopGenerating;
|
||||
const stableHandleStop = useCallback(
|
||||
(e: React.MouseEvent<HTMLButtonElement>) => handleStopRef.current(e),
|
||||
[],
|
||||
);
|
||||
|
||||
const newConvoRef = useRef(newConversation);
|
||||
newConvoRef.current = newConversation;
|
||||
const stableNewConversation: ConvoGenerator = useCallback(
|
||||
(...args: Parameters<ConvoGenerator>): ReturnType<ConvoGenerator> =>
|
||||
newConvoRef.current(...args),
|
||||
[],
|
||||
);
|
||||
|
||||
return (
|
||||
<ChatForm
|
||||
index={index}
|
||||
files={files}
|
||||
setFiles={setFiles}
|
||||
conversation={stableConversation}
|
||||
isSubmitting={isSubmitting}
|
||||
filesLoading={filesLoading}
|
||||
setFilesLoading={setFilesLoading}
|
||||
newConversation={stableNewConversation}
|
||||
handleStopGenerating={stableHandleStop}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
ChatFormWrapper.displayName = 'ChatFormWrapper';
|
||||
|
||||
export default ChatFormWrapper;
|
||||
|
|
|
|||
|
|
@ -52,4 +52,4 @@ const CollapseChat = ({
|
|||
);
|
||||
};
|
||||
|
||||
export default CollapseChat;
|
||||
export default React.memo(CollapseChat);
|
||||
|
|
|
|||
|
|
@ -1,14 +1,33 @@
|
|||
import React, { useRef } from 'react';
|
||||
import { FileUpload, TooltipAnchor, AttachmentIcon } from '@librechat/client';
|
||||
import { useLocalize, useFileHandling } from '~/hooks';
|
||||
import type { TConversation } from 'librechat-data-provider';
|
||||
import type { ExtendedFile, FileSetter } from '~/common';
|
||||
import { useFileHandlingNoChatContext, useLocalize } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
const AttachFile = ({ disabled }: { disabled?: boolean | null }) => {
|
||||
const AttachFile = ({
|
||||
disabled,
|
||||
files,
|
||||
setFiles,
|
||||
setFilesLoading,
|
||||
conversation,
|
||||
}: {
|
||||
disabled?: boolean | null;
|
||||
files: Map<string, ExtendedFile>;
|
||||
setFiles: FileSetter;
|
||||
setFilesLoading: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
conversation: TConversation | null;
|
||||
}) => {
|
||||
const localize = useLocalize();
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const isUploadDisabled = disabled ?? false;
|
||||
|
||||
const { handleFileChange } = useFileHandling();
|
||||
const { handleFileChange } = useFileHandlingNoChatContext(undefined, {
|
||||
files,
|
||||
setFiles,
|
||||
setFilesLoading,
|
||||
conversation,
|
||||
});
|
||||
|
||||
return (
|
||||
<FileUpload ref={inputRef} handleFileChange={handleFileChange}>
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import {
|
|||
getEndpointFileConfig,
|
||||
} from 'librechat-data-provider';
|
||||
import type { TConversation } from 'librechat-data-provider';
|
||||
import type { ExtendedFile, FileSetter } from '~/common';
|
||||
import { useGetFileConfig, useGetEndpointsQuery, useGetAgentByIdQuery } from '~/data-provider';
|
||||
import { useAgentsMapContext } from '~/Providers';
|
||||
import AttachFileMenu from './AttachFileMenu';
|
||||
|
|
@ -17,9 +18,15 @@ import AttachFile from './AttachFile';
|
|||
function AttachFileChat({
|
||||
disableInputs,
|
||||
conversation,
|
||||
files,
|
||||
setFiles,
|
||||
setFilesLoading,
|
||||
}: {
|
||||
disableInputs: boolean;
|
||||
conversation: TConversation | null;
|
||||
files: Map<string, ExtendedFile>;
|
||||
setFiles: FileSetter;
|
||||
setFilesLoading: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
}) {
|
||||
const conversationId = conversation?.conversationId ?? Constants.NEW_CONVO;
|
||||
const { endpoint } = conversation ?? { endpoint: null };
|
||||
|
|
@ -90,7 +97,15 @@ function AttachFileChat({
|
|||
);
|
||||
|
||||
if (isAssistants && endpointSupportsFiles && !isUploadDisabled) {
|
||||
return <AttachFile disabled={disableInputs} />;
|
||||
return (
|
||||
<AttachFile
|
||||
disabled={disableInputs}
|
||||
files={files}
|
||||
setFiles={setFiles}
|
||||
setFilesLoading={setFilesLoading}
|
||||
conversation={conversation}
|
||||
/>
|
||||
);
|
||||
} else if ((isAgents || endpointSupportsFiles) && !isUploadDisabled) {
|
||||
return (
|
||||
<AttachFileMenu
|
||||
|
|
@ -101,6 +116,10 @@ function AttachFileChat({
|
|||
agentId={conversation?.agent_id}
|
||||
endpointFileConfig={endpointFileConfig}
|
||||
useResponsesApi={useResponsesApi}
|
||||
files={files}
|
||||
setFiles={setFiles}
|
||||
setFilesLoading={setFilesLoading}
|
||||
conversation={conversation}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,15 +23,16 @@ import {
|
|||
bedrockDocumentExtensions,
|
||||
isDocumentSupportedProvider,
|
||||
} from 'librechat-data-provider';
|
||||
import type { EndpointFileConfig } from 'librechat-data-provider';
|
||||
import type { EndpointFileConfig, TConversation } from 'librechat-data-provider';
|
||||
import type { ExtendedFile, FileSetter } from '~/common';
|
||||
import {
|
||||
useAgentToolPermissions,
|
||||
useAgentCapabilities,
|
||||
useGetAgentsConfig,
|
||||
useFileHandling,
|
||||
useFileHandlingNoChatContext,
|
||||
useLocalize,
|
||||
} from '~/hooks';
|
||||
import useSharePointFileHandling from '~/hooks/Files/useSharePointFileHandling';
|
||||
import { useSharePointFileHandlingNoChatContext } from '~/hooks/Files/useSharePointFileHandling';
|
||||
import { SharePointPickerDialog } from '~/components/SharePoint';
|
||||
import { useGetStartupConfig } from '~/data-provider';
|
||||
import { ephemeralAgentByConvoId } from '~/store';
|
||||
|
|
@ -53,6 +54,10 @@ interface AttachFileMenuProps {
|
|||
endpointType?: EModelEndpoint | string;
|
||||
endpointFileConfig?: EndpointFileConfig;
|
||||
useResponsesApi?: boolean;
|
||||
files: Map<string, ExtendedFile>;
|
||||
setFiles: FileSetter;
|
||||
setFilesLoading: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
conversation: TConversation | null;
|
||||
}
|
||||
|
||||
const AttachFileMenu = ({
|
||||
|
|
@ -63,6 +68,10 @@ const AttachFileMenu = ({
|
|||
conversationId,
|
||||
endpointFileConfig,
|
||||
useResponsesApi,
|
||||
files,
|
||||
setFiles,
|
||||
setFilesLoading,
|
||||
conversation,
|
||||
}: AttachFileMenuProps) => {
|
||||
const localize = useLocalize();
|
||||
const isUploadDisabled = disabled ?? false;
|
||||
|
|
@ -72,10 +81,17 @@ const AttachFileMenu = ({
|
|||
ephemeralAgentByConvoId(conversationId),
|
||||
);
|
||||
const [toolResource, setToolResource] = useState<EToolResources | undefined>();
|
||||
const { handleFileChange } = useFileHandling();
|
||||
const { handleSharePointFiles, isProcessing, downloadProgress } = useSharePointFileHandling({
|
||||
toolResource,
|
||||
const { handleFileChange } = useFileHandlingNoChatContext(undefined, {
|
||||
files,
|
||||
setFiles,
|
||||
setFilesLoading,
|
||||
conversation,
|
||||
});
|
||||
const { handleSharePointFiles, isProcessing, downloadProgress } =
|
||||
useSharePointFileHandlingNoChatContext(
|
||||
{ toolResource },
|
||||
{ files, setFiles, setFilesLoading, conversation },
|
||||
);
|
||||
|
||||
const { agentsConfig } = useGetAgentsConfig();
|
||||
const { data: startupConfig } = useGetStartupConfig();
|
||||
|
|
|
|||
|
|
@ -1,16 +1,30 @@
|
|||
import { memo } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import type { TConversation } from 'librechat-data-provider';
|
||||
import { useChatContext } from '~/Providers';
|
||||
import { useFileHandling } from '~/hooks';
|
||||
import type { ExtendedFile, FileSetter } from '~/common';
|
||||
import { useFileHandlingNoChatContext } from '~/hooks';
|
||||
import FileRow from './FileRow';
|
||||
import store from '~/store';
|
||||
|
||||
function FileFormChat({ conversation }: { conversation: TConversation | null }) {
|
||||
const { files, setFiles, setFilesLoading } = useChatContext();
|
||||
function FileFormChat({
|
||||
conversation,
|
||||
files,
|
||||
setFiles,
|
||||
setFilesLoading,
|
||||
}: {
|
||||
conversation: TConversation | null;
|
||||
files: Map<string, ExtendedFile>;
|
||||
setFiles: FileSetter;
|
||||
setFilesLoading: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
}) {
|
||||
const chatDirection = useRecoilValue(store.chatDirection).toLowerCase();
|
||||
const { endpoint: _endpoint } = conversation ?? { endpoint: null };
|
||||
const { abortUpload } = useFileHandling();
|
||||
const { abortUpload } = useFileHandlingNoChatContext(undefined, {
|
||||
files,
|
||||
setFiles,
|
||||
setFilesLoading,
|
||||
conversation,
|
||||
});
|
||||
|
||||
const isRTL = chatDirection === 'rtl';
|
||||
|
||||
|
|
|
|||
|
|
@ -59,7 +59,13 @@ function renderComponent(conversation: Record<string, unknown> | null, disableIn
|
|||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<RecoilRoot>
|
||||
<AttachFileChat conversation={conversation as never} disableInputs={disableInputs} />
|
||||
<AttachFileChat
|
||||
conversation={conversation as never}
|
||||
disableInputs={disableInputs}
|
||||
files={new Map()}
|
||||
setFiles={() => {}}
|
||||
setFilesLoading={() => {}}
|
||||
/>
|
||||
</RecoilRoot>
|
||||
</QueryClientProvider>,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -9,13 +9,14 @@ jest.mock('~/hooks', () => ({
|
|||
useAgentToolPermissions: jest.fn(),
|
||||
useAgentCapabilities: jest.fn(),
|
||||
useGetAgentsConfig: jest.fn(),
|
||||
useFileHandling: jest.fn(),
|
||||
useFileHandlingNoChatContext: jest.fn(),
|
||||
useLocalize: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/hooks/Files/useSharePointFileHandling', () => ({
|
||||
__esModule: true,
|
||||
default: jest.fn(),
|
||||
useSharePointFileHandlingNoChatContext: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/data-provider', () => ({
|
||||
|
|
@ -52,6 +53,7 @@ jest.mock('@librechat/client', () => {
|
|||
),
|
||||
AttachmentIcon: () => R.createElement('span', { 'data-testid': 'attachment-icon' }),
|
||||
SharePointIcon: () => R.createElement('span', { 'data-testid': 'sharepoint-icon' }),
|
||||
useToastContext: () => ({ showToast: jest.fn() }),
|
||||
};
|
||||
});
|
||||
|
||||
|
|
@ -66,11 +68,14 @@ jest.mock('@ariakit/react', () => {
|
|||
const mockUseAgentToolPermissions = jest.requireMock('~/hooks').useAgentToolPermissions;
|
||||
const mockUseAgentCapabilities = jest.requireMock('~/hooks').useAgentCapabilities;
|
||||
const mockUseGetAgentsConfig = jest.requireMock('~/hooks').useGetAgentsConfig;
|
||||
const mockUseFileHandling = jest.requireMock('~/hooks').useFileHandling;
|
||||
const mockUseFileHandlingNoChatContext = jest.requireMock('~/hooks').useFileHandlingNoChatContext;
|
||||
const mockUseLocalize = jest.requireMock('~/hooks').useLocalize;
|
||||
const mockUseSharePointFileHandling = jest.requireMock(
|
||||
'~/hooks/Files/useSharePointFileHandling',
|
||||
).default;
|
||||
const mockUseSharePointFileHandlingNoChatContext = jest.requireMock(
|
||||
'~/hooks/Files/useSharePointFileHandling',
|
||||
).useSharePointFileHandlingNoChatContext;
|
||||
const mockUseGetStartupConfig = jest.requireMock('~/data-provider').useGetStartupConfig;
|
||||
|
||||
const queryClient = new QueryClient({ defaultOptions: { queries: { retry: false } } });
|
||||
|
|
@ -92,12 +97,15 @@ function setupMocks(overrides: { provider?: string } = {}) {
|
|||
codeEnabled: false,
|
||||
});
|
||||
mockUseGetAgentsConfig.mockReturnValue({ agentsConfig: {} });
|
||||
mockUseFileHandling.mockReturnValue({ handleFileChange: jest.fn() });
|
||||
mockUseSharePointFileHandling.mockReturnValue({
|
||||
mockUseFileHandlingNoChatContext.mockReturnValue({ handleFileChange: jest.fn() });
|
||||
const sharePointReturnValue = {
|
||||
handleSharePointFiles: jest.fn(),
|
||||
isProcessing: false,
|
||||
downloadProgress: 0,
|
||||
});
|
||||
error: null,
|
||||
};
|
||||
mockUseSharePointFileHandling.mockReturnValue(sharePointReturnValue);
|
||||
mockUseSharePointFileHandlingNoChatContext.mockReturnValue(sharePointReturnValue);
|
||||
mockUseGetStartupConfig.mockReturnValue({ data: { sharePointFilePickerEnabled: false } });
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
|
|
@ -110,7 +118,14 @@ function renderMenu(props: Record<string, unknown> = {}) {
|
|||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<RecoilRoot>
|
||||
<AttachFileMenu conversationId="test-convo" {...props} />
|
||||
<AttachFileMenu
|
||||
conversationId="test-convo"
|
||||
files={new Map()}
|
||||
setFiles={() => {}}
|
||||
setFilesLoading={() => {}}
|
||||
conversation={null}
|
||||
{...props}
|
||||
/>
|
||||
</RecoilRoot>
|
||||
</QueryClientProvider>,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -1,8 +1,15 @@
|
|||
import { memo } from 'react';
|
||||
import { TooltipAnchor } from '@librechat/client';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
export default function StopButton({ stop, setShowStopButton }) {
|
||||
export default memo(function StopButton({
|
||||
stop,
|
||||
setShowStopButton,
|
||||
}: {
|
||||
stop: (e: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
setShowStopButton: (value: boolean) => void;
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
|
||||
return (
|
||||
|
|
@ -34,4 +41,4 @@ export default function StopButton({ stop, setShowStopButton }) {
|
|||
}
|
||||
></TooltipAnchor>
|
||||
);
|
||||
}
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import { memo } from 'react';
|
||||
import AddedConvo from './AddedConvo';
|
||||
import type { TConversation } from 'librechat-data-provider';
|
||||
import type { SetterOrUpdater } from 'recoil';
|
||||
|
||||
export default function TextareaHeader({
|
||||
export default memo(function TextareaHeader({
|
||||
addedConvo,
|
||||
setAddedConvo,
|
||||
}: {
|
||||
|
|
@ -17,4 +18,4 @@ export default function TextareaHeader({
|
|||
<AddedConvo addedConvo={addedConvo} setAddedConvo={setAddedConvo} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
});
|
||||
|
|
|
|||
|
|
@ -49,19 +49,47 @@ export default function ToolCall({
|
|||
}
|
||||
}, [autoExpand, hasOutput]);
|
||||
|
||||
const parsedAuthUrl = useMemo(() => {
|
||||
if (!auth) {
|
||||
return null;
|
||||
}
|
||||
try {
|
||||
return new URL(auth);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}, [auth]);
|
||||
|
||||
const { function_name, domain, isMCPToolCall, mcpServerName } = useMemo(() => {
|
||||
if (typeof name !== 'string') {
|
||||
return { function_name: '', domain: null, isMCPToolCall: false, mcpServerName: '' };
|
||||
}
|
||||
if (name.includes(Constants.mcp_delimiter)) {
|
||||
const [func, server] = name.split(Constants.mcp_delimiter);
|
||||
const parts = name.split(Constants.mcp_delimiter);
|
||||
const func = parts[0];
|
||||
const server = parts.slice(1).join(Constants.mcp_delimiter);
|
||||
const displayName = func === 'oauth' ? server : func;
|
||||
return {
|
||||
function_name: func || '',
|
||||
function_name: displayName || '',
|
||||
domain: server && (server.replaceAll(actionDomainSeparator, '.') || null),
|
||||
isMCPToolCall: true,
|
||||
mcpServerName: server || '',
|
||||
};
|
||||
}
|
||||
|
||||
if (parsedAuthUrl) {
|
||||
const redirectUri = parsedAuthUrl.searchParams.get('redirect_uri') || '';
|
||||
const mcpMatch = redirectUri.match(/\/api\/mcp\/([^/]+)\/oauth\/callback/);
|
||||
if (mcpMatch?.[1]) {
|
||||
return {
|
||||
function_name: mcpMatch[1],
|
||||
domain: null,
|
||||
isMCPToolCall: true,
|
||||
mcpServerName: mcpMatch[1],
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
const [func, _domain] = name.includes(actionDelimiter)
|
||||
? name.split(actionDelimiter)
|
||||
: [name, ''];
|
||||
|
|
@ -71,25 +99,20 @@ export default function ToolCall({
|
|||
isMCPToolCall: false,
|
||||
mcpServerName: '',
|
||||
};
|
||||
}, [name]);
|
||||
}, [name, parsedAuthUrl]);
|
||||
|
||||
const toolIconType = useMemo(() => getToolIconType(name), [name]);
|
||||
const mcpIconMap = useMCPIconMap();
|
||||
const mcpIconUrl = isMCPToolCall ? mcpIconMap.get(mcpServerName) : undefined;
|
||||
|
||||
const actionId = useMemo(() => {
|
||||
if (isMCPToolCall || !auth) {
|
||||
if (isMCPToolCall || !parsedAuthUrl) {
|
||||
return '';
|
||||
}
|
||||
try {
|
||||
const url = new URL(auth);
|
||||
const redirectUri = url.searchParams.get('redirect_uri') || '';
|
||||
const match = redirectUri.match(/\/api\/actions\/([^/]+)\/oauth\/callback/);
|
||||
return match?.[1] || '';
|
||||
} catch {
|
||||
return '';
|
||||
}
|
||||
}, [auth, isMCPToolCall]);
|
||||
const redirectUri = parsedAuthUrl.searchParams.get('redirect_uri') || '';
|
||||
const match = redirectUri.match(/\/api\/actions\/([^/]+)\/oauth\/callback/);
|
||||
return match?.[1] || '';
|
||||
}, [parsedAuthUrl, isMCPToolCall]);
|
||||
|
||||
const handleOAuthClick = useCallback(async () => {
|
||||
if (!auth) {
|
||||
|
|
@ -132,21 +155,8 @@ export default function ToolCall({
|
|||
);
|
||||
|
||||
const authDomain = useMemo(() => {
|
||||
const authURL = auth ?? '';
|
||||
if (!authURL) {
|
||||
return '';
|
||||
}
|
||||
try {
|
||||
const url = new URL(authURL);
|
||||
return url.hostname;
|
||||
} catch (e) {
|
||||
logger.error(
|
||||
'client/src/components/Chat/Messages/Content/ToolCall.tsx - Failed to parse auth URL',
|
||||
e,
|
||||
);
|
||||
return '';
|
||||
}
|
||||
}, [auth]);
|
||||
return parsedAuthUrl?.hostname ?? '';
|
||||
}, [parsedAuthUrl]);
|
||||
|
||||
const progress = useProgress(initialProgress);
|
||||
const showCancelled = cancelled || (errorState && !output);
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@ import { ChevronDown } from 'lucide-react';
|
|||
import { Tools } from 'librechat-data-provider';
|
||||
import { UIResourceRenderer } from '@mcp-ui/client';
|
||||
import type { TAttachment, UIResource } from 'librechat-data-provider';
|
||||
import { useOptionalMessagesOperations } from '~/Providers';
|
||||
import { useLocalize, useExpandCollapse } from '~/hooks';
|
||||
import UIResourceCarousel from './UIResourceCarousel';
|
||||
import { useMessagesOperations } from '~/Providers';
|
||||
import { OutputRenderer } from './ToolOutput';
|
||||
import { handleUIAction, cn } from '~/utils';
|
||||
import { OutputRenderer } from './ToolOutput';
|
||||
|
||||
function isSimpleObject(obj: unknown): obj is Record<string, string | number | boolean | null> {
|
||||
if (typeof obj !== 'object' || obj === null || Array.isArray(obj)) {
|
||||
|
|
@ -102,7 +102,7 @@ export default function ToolCallInfo({
|
|||
attachments?: TAttachment[];
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
const { ask } = useMessagesOperations();
|
||||
const { ask } = useOptionalMessagesOperations();
|
||||
const [showParams, setShowParams] = useState(false);
|
||||
const { style: paramsExpandStyle, ref: paramsExpandRef } = useExpandCollapse(showParams);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import React, { useState } from 'react';
|
||||
import { UIResourceRenderer } from '@mcp-ui/client';
|
||||
import type { UIResource } from 'librechat-data-provider';
|
||||
import { useMessagesOperations } from '~/Providers';
|
||||
import { useOptionalMessagesOperations } from '~/Providers';
|
||||
import { handleUIAction } from '~/utils';
|
||||
|
||||
interface UIResourceCarouselProps {
|
||||
|
|
@ -13,7 +13,7 @@ const UIResourceCarousel: React.FC<UIResourceCarouselProps> = React.memo(({ uiRe
|
|||
const [showRightArrow, setShowRightArrow] = useState(true);
|
||||
const [isContainerHovered, setIsContainerHovered] = useState(false);
|
||||
const scrollContainerRef = React.useRef<HTMLDivElement>(null);
|
||||
const { ask } = useMessagesOperations();
|
||||
const { ask } = useOptionalMessagesOperations();
|
||||
|
||||
const handleScroll = React.useCallback(() => {
|
||||
if (!scrollContainerRef.current) return;
|
||||
|
|
|
|||
|
|
@ -3,7 +3,11 @@ import { render, screen } from '@testing-library/react';
|
|||
import Markdown from '../Markdown';
|
||||
import { RecoilRoot } from 'recoil';
|
||||
import { UI_RESOURCE_MARKER } from '~/components/MCPUIResource/plugin';
|
||||
import { useMessageContext, useMessagesConversation, useMessagesOperations } from '~/Providers';
|
||||
import {
|
||||
useMessageContext,
|
||||
useOptionalMessagesConversation,
|
||||
useOptionalMessagesOperations,
|
||||
} from '~/Providers';
|
||||
import { useGetMessagesByConvoId } from '~/data-provider';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
|
|
@ -12,8 +16,8 @@ import { useLocalize } from '~/hooks';
|
|||
jest.mock('~/Providers', () => ({
|
||||
...jest.requireActual('~/Providers'),
|
||||
useMessageContext: jest.fn(),
|
||||
useMessagesConversation: jest.fn(),
|
||||
useMessagesOperations: jest.fn(),
|
||||
useOptionalMessagesConversation: jest.fn(),
|
||||
useOptionalMessagesOperations: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/data-provider');
|
||||
jest.mock('~/hooks');
|
||||
|
|
@ -26,11 +30,11 @@ jest.mock('@mcp-ui/client', () => ({
|
|||
}));
|
||||
|
||||
const mockUseMessageContext = useMessageContext as jest.MockedFunction<typeof useMessageContext>;
|
||||
const mockUseMessagesConversation = useMessagesConversation as jest.MockedFunction<
|
||||
typeof useMessagesConversation
|
||||
const mockUseMessagesConversation = useOptionalMessagesConversation as jest.MockedFunction<
|
||||
typeof useOptionalMessagesConversation
|
||||
>;
|
||||
const mockUseMessagesOperations = useMessagesOperations as jest.MockedFunction<
|
||||
typeof useMessagesOperations
|
||||
const mockUseMessagesOperations = useOptionalMessagesOperations as jest.MockedFunction<
|
||||
typeof useOptionalMessagesOperations
|
||||
>;
|
||||
const mockUseGetMessagesByConvoId = useGetMessagesByConvoId as jest.MockedFunction<
|
||||
typeof useGetMessagesByConvoId
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import React from 'react';
|
||||
import { RecoilRoot } from 'recoil';
|
||||
import { Tools } from 'librechat-data-provider';
|
||||
import { Tools, Constants } from 'librechat-data-provider';
|
||||
import { render, screen, fireEvent } from '@testing-library/react';
|
||||
import ToolCall from '../ToolCall';
|
||||
|
||||
|
|
@ -53,9 +53,20 @@ jest.mock('../ToolCallInfo', () => ({
|
|||
|
||||
jest.mock('../ProgressText', () => ({
|
||||
__esModule: true,
|
||||
default: ({ onClick, inProgressText, finishedText, _error, _hasInput, _isExpanded }: any) => (
|
||||
default: ({
|
||||
onClick,
|
||||
inProgressText,
|
||||
finishedText,
|
||||
subtitle,
|
||||
}: {
|
||||
onClick?: () => void;
|
||||
inProgressText?: string;
|
||||
finishedText?: string;
|
||||
subtitle?: string;
|
||||
}) => (
|
||||
<div data-testid="progress-text" onClick={onClick}>
|
||||
{finishedText || inProgressText}
|
||||
{subtitle && <span data-testid="subtitle">{subtitle}</span>}
|
||||
</div>
|
||||
),
|
||||
}));
|
||||
|
|
@ -346,6 +357,141 @@ describe('ToolCall', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('MCP OAuth detection', () => {
|
||||
const d = Constants.mcp_delimiter;
|
||||
|
||||
it('should detect MCP OAuth from delimiter in tool-call name', () => {
|
||||
renderWithRecoil(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
name={`oauth${d}my-server`}
|
||||
initialProgress={0.5}
|
||||
isSubmitting={true}
|
||||
auth="https://auth.example.com"
|
||||
/>,
|
||||
);
|
||||
const subtitle = screen.getByTestId('subtitle');
|
||||
expect(subtitle.textContent).toBe('via my-server');
|
||||
});
|
||||
|
||||
it('should preserve full server name when it contains the delimiter substring', () => {
|
||||
renderWithRecoil(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
name={`oauth${d}foo${d}bar`}
|
||||
initialProgress={0.5}
|
||||
isSubmitting={true}
|
||||
auth="https://auth.example.com"
|
||||
/>,
|
||||
);
|
||||
const subtitle = screen.getByTestId('subtitle');
|
||||
expect(subtitle.textContent).toBe(`via foo${d}bar`);
|
||||
});
|
||||
|
||||
it('should display server name (not "oauth") as function_name for OAuth tool calls', () => {
|
||||
renderWithRecoil(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
name={`oauth${d}my-server`}
|
||||
initialProgress={1}
|
||||
isSubmitting={false}
|
||||
output="done"
|
||||
auth="https://auth.example.com"
|
||||
/>,
|
||||
);
|
||||
const progressText = screen.getByTestId('progress-text');
|
||||
expect(progressText.textContent).toContain('Completed my-server');
|
||||
expect(progressText.textContent).not.toContain('Completed oauth');
|
||||
});
|
||||
|
||||
it('should display server name even when auth is cleared (post-completion)', () => {
|
||||
// After OAuth completes, createOAuthEnd re-emits the toolCall without auth.
|
||||
// The display should still show the server name, not literal "oauth".
|
||||
renderWithRecoil(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
name={`oauth${d}my-server`}
|
||||
initialProgress={1}
|
||||
isSubmitting={false}
|
||||
output="done"
|
||||
/>,
|
||||
);
|
||||
const progressText = screen.getByTestId('progress-text');
|
||||
expect(progressText.textContent).toContain('Completed my-server');
|
||||
expect(progressText.textContent).not.toContain('Completed oauth');
|
||||
});
|
||||
|
||||
it('should fallback to auth URL redirect_uri when name lacks delimiter', () => {
|
||||
const authUrl =
|
||||
'https://oauth.example.com/authorize?redirect_uri=' +
|
||||
encodeURIComponent('https://app.example.com/api/mcp/my-server/oauth/callback');
|
||||
renderWithRecoil(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
name="bare_name"
|
||||
initialProgress={0.5}
|
||||
isSubmitting={true}
|
||||
auth={authUrl}
|
||||
/>,
|
||||
);
|
||||
const subtitle = screen.getByTestId('subtitle');
|
||||
expect(subtitle.textContent).toBe('via my-server');
|
||||
});
|
||||
|
||||
it('should display server name (not raw tool-call ID) in fallback path finished text', () => {
|
||||
const authUrl =
|
||||
'https://oauth.example.com/authorize?redirect_uri=' +
|
||||
encodeURIComponent('https://app.example.com/api/mcp/my-server/oauth/callback');
|
||||
renderWithRecoil(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
name="bare_name"
|
||||
initialProgress={1}
|
||||
isSubmitting={false}
|
||||
output="done"
|
||||
auth={authUrl}
|
||||
/>,
|
||||
);
|
||||
const progressText = screen.getByTestId('progress-text');
|
||||
expect(progressText.textContent).toContain('Completed my-server');
|
||||
expect(progressText.textContent).not.toContain('bare_name');
|
||||
});
|
||||
|
||||
it('should show normalized server name when it contains _mcp_ after prefixing', () => {
|
||||
// Server named oauth@mcp@server normalizes to oauth_mcp_server,
|
||||
// gets prefixed to oauth_mcp_oauth_mcp_server. Client parses:
|
||||
// func="oauth", server="oauth_mcp_server". Visually awkward but
|
||||
// semantically correct — the normalized name IS oauth_mcp_server.
|
||||
renderWithRecoil(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
name={`oauth${d}oauth${d}server`}
|
||||
initialProgress={0.5}
|
||||
isSubmitting={true}
|
||||
auth="https://auth.example.com"
|
||||
/>,
|
||||
);
|
||||
const subtitle = screen.getByTestId('subtitle');
|
||||
expect(subtitle.textContent).toBe(`via oauth${d}server`);
|
||||
});
|
||||
|
||||
it('should not misidentify non-MCP action auth as MCP via fallback', () => {
|
||||
const authUrl =
|
||||
'https://oauth.example.com/authorize?redirect_uri=' +
|
||||
encodeURIComponent('https://app.example.com/api/actions/xyz/oauth/callback');
|
||||
renderWithRecoil(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
name="action_name"
|
||||
initialProgress={0.5}
|
||||
isSubmitting={true}
|
||||
auth={authUrl}
|
||||
/>,
|
||||
);
|
||||
expect(screen.queryByTestId('subtitle')).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('A11Y-04: screen reader status announcements', () => {
|
||||
it('includes sr-only aria-live region for status announcements', () => {
|
||||
renderWithRecoil(
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ jest.mock('~/hooks', () => ({
|
|||
}));
|
||||
|
||||
jest.mock('~/Providers', () => ({
|
||||
useMessagesOperations: () => ({
|
||||
useOptionalMessagesOperations: () => ({
|
||||
ask: jest.fn(),
|
||||
}),
|
||||
}));
|
||||
|
|
|
|||
|
|
@ -13,10 +13,10 @@ jest.mock('@mcp-ui/client', () => ({
|
|||
),
|
||||
}));
|
||||
|
||||
// Mock useMessagesOperations hook
|
||||
// Mock useOptionalMessagesOperations hook
|
||||
const mockAsk = jest.fn();
|
||||
jest.mock('~/Providers', () => ({
|
||||
useMessagesOperations: () => ({
|
||||
useOptionalMessagesOperations: () => ({
|
||||
ask: mockAsk,
|
||||
}),
|
||||
}));
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import React from 'react';
|
||||
import { useMessageProcess } from '~/hooks';
|
||||
import { useMessageProcess, useMemoizedChatContext } from '~/hooks';
|
||||
import type { TMessageProps } from '~/common';
|
||||
import MessageRender from './ui/MessageRender';
|
||||
import MultiMessage from './MultiMessage';
|
||||
|
|
@ -23,10 +23,11 @@ const MessageContainer = React.memo(function MessageContainer({
|
|||
});
|
||||
|
||||
export default function Message(props: TMessageProps) {
|
||||
const { conversation, handleScroll } = useMessageProcess({
|
||||
const { conversation, handleScroll, isSubmitting } = useMessageProcess({
|
||||
message: props.message,
|
||||
});
|
||||
const { message, currentEditId, setCurrentEditId } = props;
|
||||
const { chatContext, effectiveIsSubmitting } = useMemoizedChatContext(message, isSubmitting);
|
||||
|
||||
if (!message || typeof message !== 'object') {
|
||||
return null;
|
||||
|
|
@ -38,7 +39,11 @@ export default function Message(props: TMessageProps) {
|
|||
<>
|
||||
<MessageContainer handleScroll={handleScroll}>
|
||||
<div className="m-auto justify-center p-4 py-2 md:gap-6">
|
||||
<MessageRender {...props} />
|
||||
<MessageRender
|
||||
{...props}
|
||||
isSubmitting={effectiveIsSubmitting}
|
||||
chatContext={chatContext}
|
||||
/>
|
||||
</div>
|
||||
</MessageContainer>
|
||||
<MultiMessage
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import React, { useCallback, useMemo, memo } from 'react';
|
|||
import { useAtomValue } from 'jotai';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import type { TMessage } from 'librechat-data-provider';
|
||||
import type { TMessageProps, TMessageIcon } from '~/common';
|
||||
import type { TMessageProps, TMessageIcon, TMessageChatContext } from '~/common';
|
||||
import { cn, getHeaderPrefixForScreenReader, getMessageAriaLabel } from '~/utils';
|
||||
import MessageContent from '~/components/Chat/Messages/Content/MessageContent';
|
||||
import { useLocalize, useMessageActions, useContentMetadata } from '~/hooks';
|
||||
|
|
@ -17,12 +17,73 @@ import store from '~/store';
|
|||
|
||||
type MessageRenderProps = {
|
||||
message?: TMessage;
|
||||
/**
|
||||
* Effective isSubmitting: false for non-latest messages, real value for latest.
|
||||
* Computed by the wrapper (Message.tsx) so this memo'd component only re-renders
|
||||
* when the value actually matters.
|
||||
*/
|
||||
isSubmitting?: boolean;
|
||||
/** Stable context object from wrapper — avoids ChatContext subscription inside memo */
|
||||
chatContext: TMessageChatContext;
|
||||
} & Pick<
|
||||
TMessageProps,
|
||||
'currentEditId' | 'setCurrentEditId' | 'siblingIdx' | 'setSiblingIdx' | 'siblingCount'
|
||||
>;
|
||||
|
||||
/**
|
||||
* Custom comparator for React.memo: compares `message` by key fields instead of reference
|
||||
* because `buildTree` creates new message objects on every streaming update for ALL messages,
|
||||
* even when only the latest message's text changed.
|
||||
*/
|
||||
function areMessageRenderPropsEqual(prev: MessageRenderProps, next: MessageRenderProps): boolean {
|
||||
if (prev.isSubmitting !== next.isSubmitting) {
|
||||
return false;
|
||||
}
|
||||
if (prev.chatContext !== next.chatContext) {
|
||||
return false;
|
||||
}
|
||||
if (prev.siblingIdx !== next.siblingIdx) {
|
||||
return false;
|
||||
}
|
||||
if (prev.siblingCount !== next.siblingCount) {
|
||||
return false;
|
||||
}
|
||||
if (prev.currentEditId !== next.currentEditId) {
|
||||
return false;
|
||||
}
|
||||
if (prev.setSiblingIdx !== next.setSiblingIdx) {
|
||||
return false;
|
||||
}
|
||||
if (prev.setCurrentEditId !== next.setCurrentEditId) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const prevMsg = prev.message;
|
||||
const nextMsg = next.message;
|
||||
if (prevMsg === nextMsg) {
|
||||
return true;
|
||||
}
|
||||
if (!prevMsg || !nextMsg) {
|
||||
return prevMsg === nextMsg;
|
||||
}
|
||||
|
||||
return (
|
||||
prevMsg.messageId === nextMsg.messageId &&
|
||||
prevMsg.text === nextMsg.text &&
|
||||
prevMsg.error === nextMsg.error &&
|
||||
prevMsg.unfinished === nextMsg.unfinished &&
|
||||
prevMsg.depth === nextMsg.depth &&
|
||||
prevMsg.isCreatedByUser === nextMsg.isCreatedByUser &&
|
||||
(prevMsg.children?.length ?? 0) === (nextMsg.children?.length ?? 0) &&
|
||||
prevMsg.content === nextMsg.content &&
|
||||
prevMsg.model === nextMsg.model &&
|
||||
prevMsg.endpoint === nextMsg.endpoint &&
|
||||
prevMsg.iconURL === nextMsg.iconURL &&
|
||||
prevMsg.feedback?.rating === nextMsg.feedback?.rating &&
|
||||
(prevMsg.files?.length ?? 0) === (nextMsg.files?.length ?? 0)
|
||||
);
|
||||
}
|
||||
|
||||
const MessageRender = memo(function MessageRender({
|
||||
message: msg,
|
||||
siblingIdx,
|
||||
|
|
@ -31,6 +92,7 @@ const MessageRender = memo(function MessageRender({
|
|||
currentEditId,
|
||||
setCurrentEditId,
|
||||
isSubmitting = false,
|
||||
chatContext,
|
||||
}: MessageRenderProps) {
|
||||
const localize = useLocalize();
|
||||
const {
|
||||
|
|
@ -52,6 +114,7 @@ const MessageRender = memo(function MessageRender({
|
|||
message: msg,
|
||||
currentEditId,
|
||||
setCurrentEditId,
|
||||
chatContext,
|
||||
});
|
||||
const fontSize = useAtomValue(fontSizeAtom);
|
||||
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
|
||||
|
|
@ -63,8 +126,6 @@ const MessageRender = memo(function MessageRender({
|
|||
[hasNoChildren, msg?.depth, latestMessageDepth],
|
||||
);
|
||||
const isLatestMessage = msg?.messageId === latestMessageId;
|
||||
/** Only pass isSubmitting to the latest message to prevent unnecessary re-renders */
|
||||
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
|
||||
|
||||
const iconData: TMessageIcon = useMemo(
|
||||
() => ({
|
||||
|
|
@ -92,10 +153,10 @@ const MessageRender = memo(function MessageRender({
|
|||
messageId,
|
||||
isLatestMessage,
|
||||
isExpanded: false as const,
|
||||
isSubmitting: effectiveIsSubmitting,
|
||||
isSubmitting,
|
||||
conversationId: conversation?.conversationId,
|
||||
}),
|
||||
[messageId, conversation?.conversationId, effectiveIsSubmitting, isLatestMessage],
|
||||
[messageId, conversation?.conversationId, isSubmitting, isLatestMessage],
|
||||
);
|
||||
|
||||
if (!msg) {
|
||||
|
|
@ -165,7 +226,7 @@ const MessageRender = memo(function MessageRender({
|
|||
message={msg}
|
||||
enterEdit={enterEdit}
|
||||
error={!!(msg.error ?? false)}
|
||||
isSubmitting={effectiveIsSubmitting}
|
||||
isSubmitting={isSubmitting}
|
||||
unfinished={msg.unfinished ?? false}
|
||||
isCreatedByUser={msg.isCreatedByUser ?? true}
|
||||
siblingIdx={siblingIdx ?? 0}
|
||||
|
|
@ -173,7 +234,7 @@ const MessageRender = memo(function MessageRender({
|
|||
/>
|
||||
</MessageContext.Provider>
|
||||
</div>
|
||||
{hasNoChildren && effectiveIsSubmitting ? (
|
||||
{hasNoChildren && isSubmitting ? (
|
||||
<PlaceholderRow />
|
||||
) : (
|
||||
<SubRow classes="text-xs">
|
||||
|
|
@ -187,7 +248,7 @@ const MessageRender = memo(function MessageRender({
|
|||
isEditing={edit}
|
||||
message={msg}
|
||||
enterEdit={enterEdit}
|
||||
isSubmitting={isSubmitting}
|
||||
isSubmitting={chatContext.isSubmitting}
|
||||
conversation={conversation ?? null}
|
||||
regenerate={handleRegenerateMessage}
|
||||
copyToClipboard={copyToClipboard}
|
||||
|
|
@ -202,7 +263,7 @@ const MessageRender = memo(function MessageRender({
|
|||
</div>
|
||||
</div>
|
||||
);
|
||||
});
|
||||
}, areMessageRenderPropsEqual);
|
||||
MessageRender.displayName = 'MessageRender';
|
||||
|
||||
export default MessageRender;
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import React from 'react';
|
||||
import { UIResourceRenderer } from '@mcp-ui/client';
|
||||
import { handleUIAction } from '~/utils';
|
||||
import { useOptionalMessagesConversation, useOptionalMessagesOperations } from '~/Providers';
|
||||
import { useConversationUIResources } from '~/hooks/Messages/useConversationUIResources';
|
||||
import { useMessagesConversation, useMessagesOperations } from '~/Providers';
|
||||
import { handleUIAction } from '~/utils';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
interface MCPUIResourceProps {
|
||||
|
|
@ -13,19 +13,14 @@ interface MCPUIResourceProps {
|
|||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Component that renders an MCP UI resource based on its resource ID.
|
||||
* Works in both main app and share view.
|
||||
*/
|
||||
/** Renders an MCP UI resource based on its resource ID. Works in chat, share, and search views. */
|
||||
export function MCPUIResource(props: MCPUIResourceProps) {
|
||||
const { resourceId } = props.node.properties;
|
||||
const localize = useLocalize();
|
||||
const { ask } = useMessagesOperations();
|
||||
const { conversation } = useMessagesConversation();
|
||||
const { ask } = useOptionalMessagesOperations();
|
||||
const { conversationId } = useOptionalMessagesConversation();
|
||||
|
||||
const conversationResourceMap = useConversationUIResources(
|
||||
conversation?.conversationId ?? undefined,
|
||||
);
|
||||
const conversationResourceMap = useConversationUIResources(conversationId ?? undefined);
|
||||
|
||||
const uiResource = conversationResourceMap.get(resourceId ?? '');
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import React, { useMemo } from 'react';
|
||||
import { useConversationUIResources } from '~/hooks/Messages/useConversationUIResources';
|
||||
import { useMessagesConversation } from '~/Providers';
|
||||
import UIResourceCarousel from '../Chat/Messages/Content/UIResourceCarousel';
|
||||
import type { UIResource } from 'librechat-data-provider';
|
||||
import { useConversationUIResources } from '~/hooks/Messages/useConversationUIResources';
|
||||
import UIResourceCarousel from '../Chat/Messages/Content/UIResourceCarousel';
|
||||
import { useOptionalMessagesConversation } from '~/Providers';
|
||||
|
||||
interface MCPUIResourceCarouselProps {
|
||||
node: {
|
||||
|
|
@ -12,16 +12,11 @@ interface MCPUIResourceCarouselProps {
|
|||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Component that renders multiple MCP UI resources in a carousel.
|
||||
* Works in both main app and share view.
|
||||
*/
|
||||
/** Renders multiple MCP UI resources in a carousel. Works in chat, share, and search views. */
|
||||
export function MCPUIResourceCarousel(props: MCPUIResourceCarouselProps) {
|
||||
const { conversation } = useMessagesConversation();
|
||||
const { conversationId } = useOptionalMessagesConversation();
|
||||
|
||||
const conversationResourceMap = useConversationUIResources(
|
||||
conversation?.conversationId ?? undefined,
|
||||
);
|
||||
const conversationResourceMap = useConversationUIResources(conversationId ?? undefined);
|
||||
|
||||
const uiResources = useMemo(() => {
|
||||
const { resourceIds = [] } = props.node.properties;
|
||||
|
|
|
|||
|
|
@ -2,7 +2,11 @@ import React from 'react';
|
|||
import { render, screen } from '@testing-library/react';
|
||||
import { RecoilRoot } from 'recoil';
|
||||
import { MCPUIResource } from '../MCPUIResource';
|
||||
import { useMessageContext, useMessagesConversation, useMessagesOperations } from '~/Providers';
|
||||
import {
|
||||
useMessageContext,
|
||||
useOptionalMessagesConversation,
|
||||
useOptionalMessagesOperations,
|
||||
} from '~/Providers';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { handleUIAction } from '~/utils';
|
||||
|
||||
|
|
@ -22,11 +26,11 @@ jest.mock('@mcp-ui/client', () => ({
|
|||
}));
|
||||
|
||||
const mockUseMessageContext = useMessageContext as jest.MockedFunction<typeof useMessageContext>;
|
||||
const mockUseMessagesConversation = useMessagesConversation as jest.MockedFunction<
|
||||
typeof useMessagesConversation
|
||||
const mockUseMessagesConversation = useOptionalMessagesConversation as jest.MockedFunction<
|
||||
typeof useOptionalMessagesConversation
|
||||
>;
|
||||
const mockUseMessagesOperations = useMessagesOperations as jest.MockedFunction<
|
||||
typeof useMessagesOperations
|
||||
const mockUseMessagesOperations = useOptionalMessagesOperations as jest.MockedFunction<
|
||||
typeof useOptionalMessagesOperations
|
||||
>;
|
||||
const mockUseLocalize = useLocalize as jest.MockedFunction<typeof useLocalize>;
|
||||
const mockHandleUIAction = handleUIAction as jest.MockedFunction<typeof handleUIAction>;
|
||||
|
|
|
|||
|
|
@ -2,7 +2,11 @@ import React from 'react';
|
|||
import { render, screen } from '@testing-library/react';
|
||||
import { RecoilRoot } from 'recoil';
|
||||
import { MCPUIResourceCarousel } from '../MCPUIResourceCarousel';
|
||||
import { useMessageContext, useMessagesConversation, useMessagesOperations } from '~/Providers';
|
||||
import {
|
||||
useMessageContext,
|
||||
useOptionalMessagesConversation,
|
||||
useOptionalMessagesOperations,
|
||||
} from '~/Providers';
|
||||
|
||||
// Mock dependencies
|
||||
jest.mock('~/Providers');
|
||||
|
|
@ -19,11 +23,11 @@ jest.mock('../../Chat/Messages/Content/UIResourceCarousel', () => ({
|
|||
}));
|
||||
|
||||
const mockUseMessageContext = useMessageContext as jest.MockedFunction<typeof useMessageContext>;
|
||||
const mockUseMessagesConversation = useMessagesConversation as jest.MockedFunction<
|
||||
typeof useMessagesConversation
|
||||
const mockUseMessagesConversation = useOptionalMessagesConversation as jest.MockedFunction<
|
||||
typeof useOptionalMessagesConversation
|
||||
>;
|
||||
const mockUseMessagesOperations = useMessagesOperations as jest.MockedFunction<
|
||||
typeof useMessagesOperations
|
||||
const mockUseMessagesOperations = useOptionalMessagesOperations as jest.MockedFunction<
|
||||
typeof useOptionalMessagesOperations
|
||||
>;
|
||||
|
||||
describe('MCPUIResourceCarousel', () => {
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import { useCallback, useMemo, memo } from 'react';
|
|||
import { useAtomValue } from 'jotai';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import type { TMessage, TMessageContentParts } from 'librechat-data-provider';
|
||||
import type { TMessageProps, TMessageIcon } from '~/common';
|
||||
import type { TMessageProps, TMessageIcon, TMessageChatContext } from '~/common';
|
||||
import { useAttachments, useLocalize, useMessageActions, useContentMetadata } from '~/hooks';
|
||||
import { cn, getHeaderPrefixForScreenReader, getMessageAriaLabel } from '~/utils';
|
||||
import ContentParts from '~/components/Chat/Messages/Content/ContentParts';
|
||||
|
|
@ -16,12 +16,72 @@ import store from '~/store';
|
|||
|
||||
type ContentRenderProps = {
|
||||
message?: TMessage;
|
||||
/**
|
||||
* Effective isSubmitting: false for non-latest messages, real value for latest.
|
||||
* Computed by the wrapper (MessageContent.tsx) so this memo'd component only re-renders
|
||||
* when the value actually matters.
|
||||
*/
|
||||
isSubmitting?: boolean;
|
||||
/** Stable context object from wrapper — avoids ChatContext subscription inside memo */
|
||||
chatContext: TMessageChatContext;
|
||||
} & Pick<
|
||||
TMessageProps,
|
||||
'currentEditId' | 'setCurrentEditId' | 'siblingIdx' | 'setSiblingIdx' | 'siblingCount'
|
||||
>;
|
||||
|
||||
/**
|
||||
* Custom comparator for React.memo: compares `message` by key fields instead of reference
|
||||
* because `buildTree` creates new message objects on every streaming update for ALL messages.
|
||||
*/
|
||||
function areContentRenderPropsEqual(prev: ContentRenderProps, next: ContentRenderProps): boolean {
|
||||
if (prev.isSubmitting !== next.isSubmitting) {
|
||||
return false;
|
||||
}
|
||||
if (prev.chatContext !== next.chatContext) {
|
||||
return false;
|
||||
}
|
||||
if (prev.siblingIdx !== next.siblingIdx) {
|
||||
return false;
|
||||
}
|
||||
if (prev.siblingCount !== next.siblingCount) {
|
||||
return false;
|
||||
}
|
||||
if (prev.currentEditId !== next.currentEditId) {
|
||||
return false;
|
||||
}
|
||||
if (prev.setSiblingIdx !== next.setSiblingIdx) {
|
||||
return false;
|
||||
}
|
||||
if (prev.setCurrentEditId !== next.setCurrentEditId) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const prevMsg = prev.message;
|
||||
const nextMsg = next.message;
|
||||
if (prevMsg === nextMsg) {
|
||||
return true;
|
||||
}
|
||||
if (!prevMsg || !nextMsg) {
|
||||
return prevMsg === nextMsg;
|
||||
}
|
||||
|
||||
return (
|
||||
prevMsg.messageId === nextMsg.messageId &&
|
||||
prevMsg.text === nextMsg.text &&
|
||||
prevMsg.error === nextMsg.error &&
|
||||
prevMsg.unfinished === nextMsg.unfinished &&
|
||||
prevMsg.depth === nextMsg.depth &&
|
||||
prevMsg.isCreatedByUser === nextMsg.isCreatedByUser &&
|
||||
(prevMsg.children?.length ?? 0) === (nextMsg.children?.length ?? 0) &&
|
||||
prevMsg.content === nextMsg.content &&
|
||||
prevMsg.model === nextMsg.model &&
|
||||
prevMsg.endpoint === nextMsg.endpoint &&
|
||||
prevMsg.iconURL === nextMsg.iconURL &&
|
||||
prevMsg.feedback?.rating === nextMsg.feedback?.rating &&
|
||||
(prevMsg.attachments?.length ?? 0) === (nextMsg.attachments?.length ?? 0)
|
||||
);
|
||||
}
|
||||
|
||||
const ContentRender = memo(function ContentRender({
|
||||
message: msg,
|
||||
siblingIdx,
|
||||
|
|
@ -30,6 +90,7 @@ const ContentRender = memo(function ContentRender({
|
|||
currentEditId,
|
||||
setCurrentEditId,
|
||||
isSubmitting = false,
|
||||
chatContext,
|
||||
}: ContentRenderProps) {
|
||||
const localize = useLocalize();
|
||||
const { attachments, searchResults } = useAttachments({
|
||||
|
|
@ -55,6 +116,7 @@ const ContentRender = memo(function ContentRender({
|
|||
searchResults,
|
||||
currentEditId,
|
||||
setCurrentEditId,
|
||||
chatContext,
|
||||
});
|
||||
const fontSize = useAtomValue(fontSizeAtom);
|
||||
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
|
||||
|
|
@ -66,8 +128,6 @@ const ContentRender = memo(function ContentRender({
|
|||
);
|
||||
const hasNoChildren = !(msg?.children?.length ?? 0);
|
||||
const isLatestMessage = msg?.messageId === latestMessageId;
|
||||
/** Only pass isSubmitting to the latest message to prevent unnecessary re-renders */
|
||||
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
|
||||
|
||||
const iconData: TMessageIcon = useMemo(
|
||||
() => ({
|
||||
|
|
@ -158,13 +218,13 @@ const ContentRender = memo(function ContentRender({
|
|||
searchResults={searchResults}
|
||||
setSiblingIdx={setSiblingIdx}
|
||||
isLatestMessage={isLatestMessage}
|
||||
isSubmitting={effectiveIsSubmitting}
|
||||
isSubmitting={isSubmitting}
|
||||
isCreatedByUser={msg.isCreatedByUser}
|
||||
conversationId={conversation?.conversationId}
|
||||
content={msg.content as Array<TMessageContentParts | undefined>}
|
||||
/>
|
||||
</div>
|
||||
{hasNoChildren && effectiveIsSubmitting ? (
|
||||
{hasNoChildren && isSubmitting ? (
|
||||
<PlaceholderRow />
|
||||
) : (
|
||||
<SubRow classes="text-xs">
|
||||
|
|
@ -178,7 +238,7 @@ const ContentRender = memo(function ContentRender({
|
|||
message={msg}
|
||||
isEditing={edit}
|
||||
enterEdit={enterEdit}
|
||||
isSubmitting={isSubmitting}
|
||||
isSubmitting={chatContext.isSubmitting}
|
||||
conversation={conversation ?? null}
|
||||
regenerate={handleRegenerateMessage}
|
||||
copyToClipboard={copyToClipboard}
|
||||
|
|
@ -193,7 +253,7 @@ const ContentRender = memo(function ContentRender({
|
|||
</div>
|
||||
</div>
|
||||
);
|
||||
});
|
||||
}, areContentRenderPropsEqual);
|
||||
ContentRender.displayName = 'ContentRender';
|
||||
|
||||
export default ContentRender;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import React from 'react';
|
||||
import { useMessageProcess } from '~/hooks';
|
||||
import { useMessageProcess, useMemoizedChatContext } from '~/hooks';
|
||||
import type { TMessageProps } from '~/common';
|
||||
|
||||
import MultiMessage from '~/components/Chat/Messages/MultiMessage';
|
||||
|
|
@ -28,6 +28,7 @@ export default function MessageContent(props: TMessageProps) {
|
|||
message: props.message,
|
||||
});
|
||||
const { message, currentEditId, setCurrentEditId } = props;
|
||||
const { chatContext, effectiveIsSubmitting } = useMemoizedChatContext(message, isSubmitting);
|
||||
|
||||
if (!message || typeof message !== 'object') {
|
||||
return null;
|
||||
|
|
@ -39,7 +40,11 @@ export default function MessageContent(props: TMessageProps) {
|
|||
<>
|
||||
<MessageContainer handleScroll={handleScroll}>
|
||||
<div className="m-auto justify-center p-4 py-2 md:gap-6">
|
||||
<ContentRender {...props} isSubmitting={isSubmitting} />
|
||||
<ContentRender
|
||||
{...props}
|
||||
isSubmitting={effectiveIsSubmitting}
|
||||
chatContext={chatContext}
|
||||
/>
|
||||
</div>
|
||||
</MessageContainer>
|
||||
<MultiMessage
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import { useCallback } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { useGetModelsQuery } from 'librechat-data-provider/react-query';
|
||||
import {
|
||||
|
|
@ -122,9 +122,12 @@ export default function useAddedResponse() {
|
|||
],
|
||||
);
|
||||
|
||||
return {
|
||||
conversation,
|
||||
setConversation,
|
||||
generateConversation,
|
||||
};
|
||||
return useMemo(
|
||||
() => ({
|
||||
conversation,
|
||||
setConversation,
|
||||
generateConversation,
|
||||
}),
|
||||
[conversation, setConversation, generateConversation],
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
export { default as useDeleteFilesFromTable } from './useDeleteFilesFromTable';
|
||||
export { default as useSetFilesToDelete } from './useSetFilesToDelete';
|
||||
export { default as useFileHandling } from './useFileHandling';
|
||||
export { default as useFileHandling, useFileHandlingNoChatContext } from './useFileHandling';
|
||||
export { default as useFileDeletion } from './useFileDeletion';
|
||||
export { default as useUpdateFiles } from './useUpdateFiles';
|
||||
export { default as useDragHelpers } from './useDragHelpers';
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ export { default as useSubmitMessage } from './useSubmitMessage';
|
|||
export type { ContentMetadataResult } from './useContentMetadata';
|
||||
export { default as useExpandCollapse } from './useExpandCollapse';
|
||||
export { default as useMessageActions } from './useMessageActions';
|
||||
export { default as useMemoizedChatContext } from './useMemoizedChatContext';
|
||||
export { default as useMessageProcess } from './useMessageProcess';
|
||||
export { default as useMessageHelpers } from './useMessageHelpers';
|
||||
export { default as useCopyToClipboard } from './useCopyToClipboard';
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import { useMemo } from 'react';
|
|||
import { useRecoilValue } from 'recoil';
|
||||
import { Tools } from 'librechat-data-provider';
|
||||
import type { TAttachment, UIResource } from 'librechat-data-provider';
|
||||
import { useMessagesOperations } from '~/Providers';
|
||||
import { useOptionalMessagesOperations } from '~/Providers';
|
||||
import store from '~/store';
|
||||
|
||||
/**
|
||||
|
|
@ -16,7 +16,7 @@ import store from '~/store';
|
|||
export function useConversationUIResources(
|
||||
conversationId: string | undefined,
|
||||
): Map<string, UIResource> {
|
||||
const { getMessages } = useMessagesOperations();
|
||||
const { getMessages } = useOptionalMessagesOperations();
|
||||
|
||||
const conversationAttachmentsMap = useRecoilValue(
|
||||
store.conversationAttachmentsSelector(conversationId),
|
||||
|
|
|
|||
80
client/src/hooks/Messages/useMemoizedChatContext.ts
Normal file
80
client/src/hooks/Messages/useMemoizedChatContext.ts
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
import { useRef, useMemo } from 'react';
|
||||
import type { TMessage } from 'librechat-data-provider';
|
||||
import type { TMessageChatContext } from '~/common/types';
|
||||
import { useChatContext } from '~/Providers';
|
||||
|
||||
/**
|
||||
* Creates a stable `TMessageChatContext` object for memo'd message components.
|
||||
*
|
||||
* Subscribes to `useChatContext()` internally (intended to be called from non-memo'd
|
||||
* wrapper components like `Message` and `MessageContent`), then produces:
|
||||
* - A `chatContext` object that stays referentially stable during streaming
|
||||
* (uses a getter for `isSubmitting` backed by a ref)
|
||||
* - A stable `conversation` reference that only updates when rendering-relevant fields change
|
||||
* - An `effectiveIsSubmitting` value (false for non-latest messages)
|
||||
*/
|
||||
export default function useMemoizedChatContext(
|
||||
message: TMessage | null | undefined,
|
||||
isSubmitting: boolean,
|
||||
) {
|
||||
const chatCtx = useChatContext();
|
||||
|
||||
const isSubmittingRef = useRef(isSubmitting);
|
||||
isSubmittingRef.current = isSubmitting;
|
||||
|
||||
/**
|
||||
* Stabilize conversation: only update when rendering-relevant fields change,
|
||||
* not on every metadata update (e.g., title generation).
|
||||
*/
|
||||
const stableConversation = useMemo(
|
||||
() => chatCtx.conversation,
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
[
|
||||
chatCtx.conversation?.conversationId,
|
||||
chatCtx.conversation?.endpoint,
|
||||
chatCtx.conversation?.endpointType,
|
||||
chatCtx.conversation?.model,
|
||||
chatCtx.conversation?.agent_id,
|
||||
chatCtx.conversation?.assistant_id,
|
||||
],
|
||||
);
|
||||
|
||||
/**
|
||||
* `isSubmitting` is included in deps so that chatContext gets a new reference
|
||||
* when streaming starts/ends (2x per session). This ensures HoverButtons
|
||||
* re-renders to update regenerate/edit button visibility via useGenerationsByLatest.
|
||||
* The getter pattern is still valuable: callbacks reading chatContext.isSubmitting
|
||||
* at call-time always get the current value even between these re-renders.
|
||||
*/
|
||||
const chatContext: TMessageChatContext = useMemo(
|
||||
() => ({
|
||||
ask: chatCtx.ask,
|
||||
index: chatCtx.index,
|
||||
regenerate: chatCtx.regenerate,
|
||||
conversation: stableConversation,
|
||||
latestMessageId: chatCtx.latestMessageId,
|
||||
latestMessageDepth: chatCtx.latestMessageDepth,
|
||||
handleContinue: chatCtx.handleContinue,
|
||||
get isSubmitting() {
|
||||
return isSubmittingRef.current;
|
||||
},
|
||||
}),
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
[
|
||||
chatCtx.ask,
|
||||
chatCtx.index,
|
||||
chatCtx.regenerate,
|
||||
stableConversation,
|
||||
chatCtx.latestMessageId,
|
||||
chatCtx.latestMessageDepth,
|
||||
chatCtx.handleContinue,
|
||||
isSubmitting, // intentional: forces new reference on streaming start/end so HoverButtons re-renders
|
||||
],
|
||||
);
|
||||
|
||||
const messageId = message?.messageId ?? null;
|
||||
const isLatestMessage = messageId === chatCtx.latestMessageId;
|
||||
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
|
||||
|
||||
return { chatContext, effectiveIsSubmitting };
|
||||
}
|
||||
|
|
@ -11,7 +11,8 @@ import {
|
|||
TUpdateFeedbackRequest,
|
||||
} from 'librechat-data-provider';
|
||||
import type { TMessageProps } from '~/common';
|
||||
import { useChatContext, useAssistantsMapContext, useAgentsMapContext } from '~/Providers';
|
||||
import type { TMessageChatContext } from '~/common/types';
|
||||
import { useAssistantsMapContext, useAgentsMapContext } from '~/Providers';
|
||||
import useCopyToClipboard from './useCopyToClipboard';
|
||||
import { useAuthContext } from '~/hooks/AuthContext';
|
||||
import { useGetAddedConvo } from '~/hooks/Chat';
|
||||
|
|
@ -23,24 +24,33 @@ export type TMessageActions = Pick<
|
|||
'message' | 'currentEditId' | 'setCurrentEditId'
|
||||
> & {
|
||||
searchResults?: { [key: string]: SearchResultData };
|
||||
/**
|
||||
* Stable context object passed from wrapper components to avoid subscribing
|
||||
* to ChatContext inside memo'd components (which would bypass React.memo).
|
||||
* The `isSubmitting` property uses a getter backed by a ref, so it always
|
||||
* returns the current value at call-time without triggering re-renders.
|
||||
*/
|
||||
chatContext: TMessageChatContext;
|
||||
};
|
||||
|
||||
export default function useMessageActions(props: TMessageActions) {
|
||||
const localize = useLocalize();
|
||||
const { user } = useAuthContext();
|
||||
const UsernameDisplay = useRecoilValue<boolean>(store.UsernameDisplay);
|
||||
const { message, currentEditId, setCurrentEditId, searchResults } = props;
|
||||
const { message, currentEditId, setCurrentEditId, searchResults, chatContext } = props;
|
||||
|
||||
const {
|
||||
ask,
|
||||
index,
|
||||
regenerate,
|
||||
isSubmitting,
|
||||
conversation,
|
||||
latestMessageId,
|
||||
latestMessageDepth,
|
||||
handleContinue,
|
||||
} = useChatContext();
|
||||
// NOTE: isSubmitting is intentionally NOT destructured here.
|
||||
// chatContext.isSubmitting is a getter backed by a ref — destructuring
|
||||
// would capture a one-time snapshot. Always access via chatContext.isSubmitting.
|
||||
} = chatContext;
|
||||
|
||||
const getAddedConvo = useGetAddedConvo();
|
||||
|
||||
|
|
@ -98,13 +108,18 @@ export default function useMessageActions(props: TMessageActions) {
|
|||
}
|
||||
}, [agentsMap, conversation?.agent_id, conversation?.endpoint, message?.model]);
|
||||
|
||||
/**
|
||||
* chatContext.isSubmitting is a getter backed by the wrapper's ref,
|
||||
* so it always returns the current value at call-time — even for
|
||||
* non-latest messages that don't re-render during streaming.
|
||||
*/
|
||||
const regenerateMessage = useCallback(() => {
|
||||
if ((isSubmitting && isCreatedByUser === true) || !message) {
|
||||
if ((chatContext.isSubmitting && isCreatedByUser === true) || !message) {
|
||||
return;
|
||||
}
|
||||
|
||||
regenerate(message, { addedConvo: getAddedConvo() });
|
||||
}, [isSubmitting, isCreatedByUser, message, regenerate, getAddedConvo]);
|
||||
}, [chatContext, isCreatedByUser, message, regenerate, getAddedConvo]);
|
||||
|
||||
const copyToClipboard = useCopyToClipboard({ text, content, searchResults });
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import throttle from 'lodash/throttle';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import { useEffect, useRef, useCallback, useMemo } from 'react';
|
||||
import { useEffect, useRef, useMemo } from 'react';
|
||||
import type { TMessage } from 'librechat-data-provider';
|
||||
import { getTextKey, TEXT_KEY_DIVIDER, logger } from '~/utils';
|
||||
import { useMessagesViewContext } from '~/Providers';
|
||||
|
|
@ -56,24 +56,25 @@ export default function useMessageProcess({ message }: { message?: TMessage | nu
|
|||
}
|
||||
}, [hasNoChildren, message, setLatestMessage, conversation?.conversationId]);
|
||||
|
||||
const handleScroll = useCallback(
|
||||
(event: unknown | TouchEvent | WheelEvent) => {
|
||||
throttle(() => {
|
||||
/** Use ref for isSubmitting to stabilize handleScroll across isSubmitting changes */
|
||||
const isSubmittingRef = useRef(isSubmitting);
|
||||
isSubmittingRef.current = isSubmitting;
|
||||
|
||||
const handleScroll = useMemo(
|
||||
() =>
|
||||
throttle((event: unknown) => {
|
||||
logger.log(
|
||||
'message_scrolling',
|
||||
`useMessageProcess: setting abort scroll to ${isSubmitting}, handleScroll event`,
|
||||
`useMessageProcess: setting abort scroll to ${isSubmittingRef.current}, handleScroll event`,
|
||||
event,
|
||||
);
|
||||
if (isSubmitting) {
|
||||
setAbortScroll(true);
|
||||
} else {
|
||||
setAbortScroll(false);
|
||||
}
|
||||
}, 500)();
|
||||
},
|
||||
[isSubmitting, setAbortScroll],
|
||||
setAbortScroll(isSubmittingRef.current);
|
||||
}, 500),
|
||||
[setAbortScroll],
|
||||
);
|
||||
|
||||
useEffect(() => () => handleScroll.cancel(), [handleScroll]);
|
||||
|
||||
return {
|
||||
handleScroll,
|
||||
isSubmitting,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import { renderHook, act } from '@testing-library/react';
|
||||
import { Constants, ErrorTypes, LocalStorageKeys } from 'librechat-data-provider';
|
||||
import { Constants, LocalStorageKeys } from 'librechat-data-provider';
|
||||
import type { TSubmission } from 'librechat-data-provider';
|
||||
|
||||
type SSEEventListener = (e: Partial<MessageEvent> & { responseCode?: number }) => void;
|
||||
|
|
@ -34,7 +34,13 @@ jest.mock('sse.js', () => ({
|
|||
}));
|
||||
|
||||
const mockSetQueryData = jest.fn();
|
||||
const mockQueryClient = { setQueryData: mockSetQueryData };
|
||||
const mockInvalidateQueries = jest.fn();
|
||||
const mockRemoveQueries = jest.fn();
|
||||
const mockQueryClient = {
|
||||
setQueryData: mockSetQueryData,
|
||||
invalidateQueries: mockInvalidateQueries,
|
||||
removeQueries: mockRemoveQueries,
|
||||
};
|
||||
|
||||
jest.mock('@tanstack/react-query', () => ({
|
||||
...jest.requireActual('@tanstack/react-query'),
|
||||
|
|
@ -63,6 +69,7 @@ jest.mock('~/data-provider', () => ({
|
|||
useGetStartupConfig: () => ({ data: { balance: { enabled: false } } }),
|
||||
useGetUserBalance: () => ({ refetch: jest.fn() }),
|
||||
queueTitleGeneration: jest.fn(),
|
||||
streamStatusQueryKey: (conversationId: string) => ['streamStatus', conversationId],
|
||||
}));
|
||||
|
||||
const mockErrorHandler = jest.fn();
|
||||
|
|
@ -162,6 +169,11 @@ describe('useResumableSSE - 404 error path', () => {
|
|||
beforeEach(() => {
|
||||
mockSSEInstances.length = 0;
|
||||
localStorage.clear();
|
||||
mockErrorHandler.mockClear();
|
||||
mockClearStepMaps.mockClear();
|
||||
mockSetIsSubmitting.mockClear();
|
||||
mockInvalidateQueries.mockClear();
|
||||
mockRemoveQueries.mockClear();
|
||||
});
|
||||
|
||||
const seedDraft = (conversationId: string) => {
|
||||
|
|
@ -200,19 +212,18 @@ describe('useResumableSSE - 404 error path', () => {
|
|||
unmount();
|
||||
});
|
||||
|
||||
it('calls errorHandler with STREAM_EXPIRED error type on 404', async () => {
|
||||
it('invalidates message cache and clears stream status on 404 instead of showing error', async () => {
|
||||
const { unmount } = await render404Scenario(CONV_ID);
|
||||
|
||||
expect(mockErrorHandler).toHaveBeenCalledTimes(1);
|
||||
const call = mockErrorHandler.mock.calls[0][0];
|
||||
expect(call.data).toBeDefined();
|
||||
const parsed = JSON.parse(call.data.text);
|
||||
expect(parsed.type).toBe(ErrorTypes.STREAM_EXPIRED);
|
||||
expect(call.submission).toEqual(
|
||||
expect.objectContaining({
|
||||
conversation: expect.objectContaining({ conversationId: CONV_ID }),
|
||||
}),
|
||||
);
|
||||
expect(mockErrorHandler).not.toHaveBeenCalled();
|
||||
expect(mockInvalidateQueries).toHaveBeenCalledWith({
|
||||
queryKey: ['messages', CONV_ID],
|
||||
});
|
||||
expect(mockRemoveQueries).toHaveBeenCalledWith({
|
||||
queryKey: ['streamStatus', CONV_ID],
|
||||
});
|
||||
expect(mockClearStepMaps).toHaveBeenCalled();
|
||||
expect(mockSetIsSubmitting).toHaveBeenCalledWith(false);
|
||||
unmount();
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,12 @@ import {
|
|||
} from 'librechat-data-provider';
|
||||
import type { TMessage, TPayload, TSubmission, EventSubmission } from 'librechat-data-provider';
|
||||
import type { EventHandlerParams } from './useEventHandlers';
|
||||
import { useGetStartupConfig, useGetUserBalance, queueTitleGeneration } from '~/data-provider';
|
||||
import {
|
||||
useGetUserBalance,
|
||||
useGetStartupConfig,
|
||||
queueTitleGeneration,
|
||||
streamStatusQueryKey,
|
||||
} from '~/data-provider';
|
||||
import type { ActiveJobsResponse } from '~/data-provider';
|
||||
import { useAuthContext } from '~/hooks/AuthContext';
|
||||
import useEventHandlers from './useEventHandlers';
|
||||
|
|
@ -343,18 +348,20 @@ export default function useResumableSSE(
|
|||
/* @ts-ignore - sse.js types don't expose responseCode */
|
||||
const responseCode = e.responseCode;
|
||||
|
||||
// 404 means job doesn't exist (completed/deleted) - don't retry
|
||||
// 404 → job completed & was cleaned up; messages are persisted in DB.
|
||||
// Invalidate cache once so react-query refetches instead of showing an error.
|
||||
if (responseCode === 404) {
|
||||
console.log('[ResumableSSE] Stream not found (404) - job completed or expired');
|
||||
const convoId = currentSubmission.conversation?.conversationId;
|
||||
console.log('[ResumableSSE] Stream 404, invalidating messages for:', convoId);
|
||||
sse.close();
|
||||
removeActiveJob(currentStreamId);
|
||||
clearAllDrafts(currentSubmission.conversation?.conversationId);
|
||||
errorHandler({
|
||||
data: {
|
||||
text: JSON.stringify({ type: ErrorTypes.STREAM_EXPIRED }),
|
||||
} as unknown as Parameters<typeof errorHandler>[0]['data'],
|
||||
submission: currentSubmission as EventSubmission,
|
||||
});
|
||||
clearAllDrafts(convoId);
|
||||
clearStepMaps();
|
||||
if (convoId) {
|
||||
queryClient.invalidateQueries({ queryKey: [QueryKeys.messages, convoId] });
|
||||
queryClient.removeQueries({ queryKey: streamStatusQueryKey(convoId) });
|
||||
}
|
||||
setIsSubmitting(false);
|
||||
setShowStopButton(false);
|
||||
setStreamId(null);
|
||||
reconnectAttemptRef.current = 0;
|
||||
|
|
@ -544,6 +551,7 @@ export default function useResumableSSE(
|
|||
startupConfig?.balance?.enabled,
|
||||
balanceQuery,
|
||||
removeActiveJob,
|
||||
queryClient,
|
||||
],
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -125,7 +125,11 @@ export default function useResumeOnLoad(
|
|||
conversationId !== Constants.NEW_CONVO &&
|
||||
processedConvoRef.current !== conversationId; // Don't re-check processed convos
|
||||
|
||||
const { data: streamStatus, isSuccess } = useStreamStatus(conversationId, shouldCheck);
|
||||
const {
|
||||
data: streamStatus,
|
||||
isSuccess,
|
||||
isFetching,
|
||||
} = useStreamStatus(conversationId, shouldCheck);
|
||||
|
||||
useEffect(() => {
|
||||
console.log('[ResumeOnLoad] Effect check', {
|
||||
|
|
@ -135,6 +139,7 @@ export default function useResumeOnLoad(
|
|||
hasCurrentSubmission: !!currentSubmission,
|
||||
currentSubmissionConvoId: currentSubmission?.conversation?.conversationId,
|
||||
isSuccess,
|
||||
isFetching,
|
||||
streamStatusActive: streamStatus?.active,
|
||||
streamStatusStreamId: streamStatus?.streamId,
|
||||
processedConvoRef: processedConvoRef.current,
|
||||
|
|
@ -171,8 +176,9 @@ export default function useResumeOnLoad(
|
|||
);
|
||||
}
|
||||
|
||||
// Wait for stream status query to complete
|
||||
if (!isSuccess || !streamStatus) {
|
||||
// Wait for stream status query to complete (including background refetches
|
||||
// that may replace a stale cached result with fresh data)
|
||||
if (!isSuccess || !streamStatus || isFetching) {
|
||||
console.log('[ResumeOnLoad] Waiting for stream status query');
|
||||
return;
|
||||
}
|
||||
|
|
@ -183,15 +189,12 @@ export default function useResumeOnLoad(
|
|||
return;
|
||||
}
|
||||
|
||||
// Check if there's an active job to resume
|
||||
// DON'T mark as processed here - only mark when we actually create a submission
|
||||
// This prevents stale cache data from blocking subsequent resume attempts
|
||||
if (!streamStatus.active || !streamStatus.streamId) {
|
||||
console.log('[ResumeOnLoad] No active job to resume for:', conversationId);
|
||||
processedConvoRef.current = conversationId;
|
||||
return;
|
||||
}
|
||||
|
||||
// Mark as processed NOW - we verified there's an active job and will create submission
|
||||
processedConvoRef.current = conversationId;
|
||||
|
||||
console.log('[ResumeOnLoad] Found active job, creating submission...', {
|
||||
|
|
@ -241,6 +244,7 @@ export default function useResumeOnLoad(
|
|||
submissionConvoId,
|
||||
currentSubmission,
|
||||
isSuccess,
|
||||
isFetching,
|
||||
streamStatus,
|
||||
getMessages,
|
||||
setSubmission,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ export default {
|
|||
'\\.helper\\.ts$',
|
||||
'\\.helper\\.d\\.ts$',
|
||||
'/__tests__/helpers/',
|
||||
'\\.manual\\.spec\\.[jt]sx?$',
|
||||
],
|
||||
coverageReporters: ['text', 'cobertura'],
|
||||
testResultsProcessor: 'jest-junit',
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue