mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-04-07 00:15:23 +02:00
Merge remote-tracking branch 'upstream/dev' into fix/openid-pkce-secret
# Conflicts: # api/server/routes/config.js # api/strategies/openidStrategy.spec.js
This commit is contained in:
commit
50af2e0ff1
465 changed files with 32926 additions and 5533 deletions
|
|
@ -123,9 +123,6 @@ function disposeClient(client) {
|
|||
if (client.maxContextTokens) {
|
||||
client.maxContextTokens = null;
|
||||
}
|
||||
if (client.contextStrategy) {
|
||||
client.contextStrategy = null;
|
||||
}
|
||||
if (client.currentDateString) {
|
||||
client.currentDateString = null;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,40 +1,12 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { loadDefaultModels, loadConfigModels } = require('~/server/services/Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
* @param {ServerRequest} req
|
||||
* @returns {Promise<TModelsConfig>} The models config.
|
||||
*/
|
||||
const getModelsConfig = async (req) => {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
let modelsConfig = await cache.get(CacheKeys.MODELS_CONFIG);
|
||||
if (!modelsConfig) {
|
||||
modelsConfig = await loadModels(req);
|
||||
}
|
||||
const getModelsConfig = (req) => loadModels(req);
|
||||
|
||||
return modelsConfig;
|
||||
};
|
||||
|
||||
/**
|
||||
* Loads the models from the config.
|
||||
* @param {ServerRequest} req - The Express request object.
|
||||
* @returns {Promise<TModelsConfig>} The models config.
|
||||
*/
|
||||
async function loadModels(req) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cachedModelsConfig = await cache.get(CacheKeys.MODELS_CONFIG);
|
||||
if (cachedModelsConfig) {
|
||||
return cachedModelsConfig;
|
||||
}
|
||||
const defaultModelsConfig = await loadDefaultModels(req);
|
||||
const customModelsConfig = await loadConfigModels(req);
|
||||
|
||||
const modelConfig = { ...defaultModelsConfig, ...customModelsConfig };
|
||||
|
||||
await cache.set(CacheKeys.MODELS_CONFIG, modelConfig);
|
||||
return modelConfig;
|
||||
return { ...defaultModelsConfig, ...customModelsConfig };
|
||||
}
|
||||
|
||||
async function modelController(req, res) {
|
||||
|
|
|
|||
|
|
@ -1,61 +1,37 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { getToolkitKey, checkPluginAuth, filterUniquePlugins } = require('@librechat/api');
|
||||
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
|
||||
const { availableTools, toolkits } = require('~/app/clients/tools');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
const getAvailablePluginsController = async (req, res) => {
|
||||
try {
|
||||
const cache = getLogStores(CacheKeys.TOOL_CACHE);
|
||||
const cachedPlugins = await cache.get(CacheKeys.PLUGINS);
|
||||
if (cachedPlugins) {
|
||||
res.status(200).json(cachedPlugins);
|
||||
return;
|
||||
}
|
||||
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
/** @type {{ filteredTools: string[], includedTools: string[] }} */
|
||||
const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId });
|
||||
const { filteredTools = [], includedTools = [] } = appConfig;
|
||||
/** @type {import('@librechat/api').LCManifestTool[]} */
|
||||
const pluginManifest = availableTools;
|
||||
|
||||
const uniquePlugins = filterUniquePlugins(pluginManifest);
|
||||
let authenticatedPlugins = [];
|
||||
const uniquePlugins = filterUniquePlugins(availableTools);
|
||||
const includeSet = new Set(includedTools);
|
||||
const filterSet = new Set(filteredTools);
|
||||
|
||||
/** includedTools takes precedence — filteredTools ignored when both are set. */
|
||||
const plugins = [];
|
||||
for (const plugin of uniquePlugins) {
|
||||
authenticatedPlugins.push(
|
||||
checkPluginAuth(plugin) ? { ...plugin, authenticated: true } : plugin,
|
||||
);
|
||||
if (includeSet.size > 0) {
|
||||
if (!includeSet.has(plugin.pluginKey)) {
|
||||
continue;
|
||||
}
|
||||
} else if (filterSet.has(plugin.pluginKey)) {
|
||||
continue;
|
||||
}
|
||||
plugins.push(checkPluginAuth(plugin) ? { ...plugin, authenticated: true } : plugin);
|
||||
}
|
||||
|
||||
let plugins = authenticatedPlugins;
|
||||
|
||||
if (includedTools.length > 0) {
|
||||
plugins = plugins.filter((plugin) => includedTools.includes(plugin.pluginKey));
|
||||
} else {
|
||||
plugins = plugins.filter((plugin) => !filteredTools.includes(plugin.pluginKey));
|
||||
}
|
||||
|
||||
await cache.set(CacheKeys.PLUGINS, plugins);
|
||||
res.status(200).json(plugins);
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: error.message });
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves and returns a list of available tools, either from a cache or by reading a plugin manifest file.
|
||||
*
|
||||
* This function first attempts to retrieve the list of tools from a cache. If the tools are not found in the cache,
|
||||
* it reads a plugin manifest file, filters for unique plugins, and determines if each plugin is authenticated.
|
||||
* Only plugins that are marked as available in the application's local state are included in the final list.
|
||||
* The resulting list of tools is then cached and sent to the client.
|
||||
*
|
||||
* @param {object} req - The request object, containing information about the HTTP request.
|
||||
* @param {object} res - The response object, used to send back the desired HTTP response.
|
||||
* @returns {Promise<void>} A promise that resolves when the function has completed.
|
||||
*/
|
||||
const getAvailableTools = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user?.id;
|
||||
|
|
@ -63,18 +39,10 @@ const getAvailableTools = async (req, res) => {
|
|||
logger.warn('[getAvailableTools] User ID not found in request');
|
||||
return res.status(401).json({ message: 'Unauthorized' });
|
||||
}
|
||||
const cache = getLogStores(CacheKeys.TOOL_CACHE);
|
||||
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
|
||||
|
||||
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
|
||||
const appConfig =
|
||||
req.config ?? (await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }));
|
||||
|
||||
// Return early if we have cached tools
|
||||
if (cachedToolsArray != null) {
|
||||
res.status(200).json(cachedToolsArray);
|
||||
return;
|
||||
}
|
||||
|
||||
/** @type {Record<string, FunctionTool> | null} Get tool definitions to filter which tools are actually available */
|
||||
let toolDefinitions = await getCachedTools();
|
||||
|
||||
if (toolDefinitions == null && appConfig?.availableTools != null) {
|
||||
|
|
@ -83,26 +51,17 @@ const getAvailableTools = async (req, res) => {
|
|||
toolDefinitions = appConfig.availableTools;
|
||||
}
|
||||
|
||||
/** @type {import('@librechat/api').LCManifestTool[]} */
|
||||
let pluginManifest = availableTools;
|
||||
const uniquePlugins = filterUniquePlugins(availableTools);
|
||||
const toolDefKeysList = toolDefinitions ? Object.keys(toolDefinitions) : null;
|
||||
const toolDefKeys = toolDefKeysList ? new Set(toolDefKeysList) : null;
|
||||
|
||||
/** @type {TPlugin[]} Deduplicate and authenticate plugins */
|
||||
const uniquePlugins = filterUniquePlugins(pluginManifest);
|
||||
const authenticatedPlugins = uniquePlugins.map((plugin) => {
|
||||
if (checkPluginAuth(plugin)) {
|
||||
return { ...plugin, authenticated: true };
|
||||
} else {
|
||||
return plugin;
|
||||
}
|
||||
});
|
||||
|
||||
/** Filter plugins based on availability */
|
||||
const toolsOutput = [];
|
||||
for (const plugin of authenticatedPlugins) {
|
||||
const isToolDefined = toolDefinitions?.[plugin.pluginKey] !== undefined;
|
||||
for (const plugin of uniquePlugins) {
|
||||
const isToolDefined = toolDefKeys?.has(plugin.pluginKey) === true;
|
||||
const isToolkit =
|
||||
plugin.toolkit === true &&
|
||||
Object.keys(toolDefinitions ?? {}).some(
|
||||
toolDefKeysList != null &&
|
||||
toolDefKeysList.some(
|
||||
(key) => getToolkitKey({ toolkits, toolName: key }) === plugin.pluginKey,
|
||||
);
|
||||
|
||||
|
|
@ -110,13 +69,10 @@ const getAvailableTools = async (req, res) => {
|
|||
continue;
|
||||
}
|
||||
|
||||
toolsOutput.push(plugin);
|
||||
toolsOutput.push(checkPluginAuth(plugin) ? { ...plugin, authenticated: true } : plugin);
|
||||
}
|
||||
|
||||
const finalTools = filterUniquePlugins(toolsOutput);
|
||||
await cache.set(CacheKeys.TOOLS, finalTools);
|
||||
|
||||
res.status(200).json(finalTools);
|
||||
res.status(200).json(toolsOutput);
|
||||
} catch (error) {
|
||||
logger.error('[getAvailableTools]', error);
|
||||
res.status(500).json({ message: error.message });
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { getCachedTools, getAppConfig } = require('~/server/services/Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
|
|
@ -19,22 +17,15 @@ jest.mock('~/server/services/Config', () => ({
|
|||
setCachedTools: jest.fn(),
|
||||
}));
|
||||
|
||||
// loadAndFormatTools mock removed - no longer used in PluginController
|
||||
// getMCPManager mock removed - no longer used in PluginController
|
||||
|
||||
jest.mock('~/app/clients/tools', () => ({
|
||||
availableTools: [],
|
||||
toolkits: [],
|
||||
}));
|
||||
|
||||
jest.mock('~/cache', () => ({
|
||||
getLogStores: jest.fn(),
|
||||
}));
|
||||
|
||||
const { getAvailableTools, getAvailablePluginsController } = require('./PluginController');
|
||||
|
||||
describe('PluginController', () => {
|
||||
let mockReq, mockRes, mockCache;
|
||||
let mockReq, mockRes;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
|
@ -46,17 +37,12 @@ describe('PluginController', () => {
|
|||
},
|
||||
};
|
||||
mockRes = { status: jest.fn().mockReturnThis(), json: jest.fn() };
|
||||
mockCache = { get: jest.fn(), set: jest.fn() };
|
||||
getLogStores.mockReturnValue(mockCache);
|
||||
|
||||
// Clear availableTools and toolkits arrays before each test
|
||||
require('~/app/clients/tools').availableTools.length = 0;
|
||||
require('~/app/clients/tools').toolkits.length = 0;
|
||||
|
||||
// Reset getCachedTools mock to ensure clean state
|
||||
getCachedTools.mockReset();
|
||||
|
||||
// Reset getAppConfig mock to ensure clean state with default values
|
||||
getAppConfig.mockReset();
|
||||
getAppConfig.mockResolvedValue({
|
||||
filteredTools: [],
|
||||
|
|
@ -64,31 +50,8 @@ describe('PluginController', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('cache namespace', () => {
|
||||
it('getAvailablePluginsController should use TOOL_CACHE namespace', async () => {
|
||||
mockCache.get.mockResolvedValue([]);
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE);
|
||||
});
|
||||
|
||||
it('getAvailableTools should use TOOL_CACHE namespace', async () => {
|
||||
mockCache.get.mockResolvedValue([]);
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE);
|
||||
});
|
||||
|
||||
it('should NOT use CONFIG_STORE namespace for tool/plugin operations', async () => {
|
||||
mockCache.get.mockResolvedValue([]);
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
const allCalls = getLogStores.mock.calls.flat();
|
||||
expect(allCalls).not.toContain(CacheKeys.CONFIG_STORE);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAvailablePluginsController', () => {
|
||||
it('should use filterUniquePlugins to remove duplicate plugins', async () => {
|
||||
// Add plugins with duplicates to availableTools
|
||||
const mockPlugins = [
|
||||
{ name: 'Plugin1', pluginKey: 'key1', description: 'First' },
|
||||
{ name: 'Plugin1', pluginKey: 'key1', description: 'First duplicate' },
|
||||
|
|
@ -97,9 +60,6 @@ describe('PluginController', () => {
|
|||
|
||||
require('~/app/clients/tools').availableTools.push(...mockPlugins);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
|
||||
// Configure getAppConfig to return the expected config
|
||||
getAppConfig.mockResolvedValueOnce({
|
||||
filteredTools: [],
|
||||
includedTools: [],
|
||||
|
|
@ -109,21 +69,16 @@ describe('PluginController', () => {
|
|||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
// The real filterUniquePlugins should have removed the duplicate
|
||||
expect(responseData).toHaveLength(2);
|
||||
expect(responseData[0].pluginKey).toBe('key1');
|
||||
expect(responseData[1].pluginKey).toBe('key2');
|
||||
});
|
||||
|
||||
it('should use checkPluginAuth to verify plugin authentication', async () => {
|
||||
// checkPluginAuth returns false for plugins without authConfig
|
||||
// so authenticated property won't be added
|
||||
const mockPlugin = { name: 'Plugin1', pluginKey: 'key1', description: 'First' };
|
||||
|
||||
require('~/app/clients/tools').availableTools.push(mockPlugin);
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
|
||||
// Configure getAppConfig to return the expected config
|
||||
getAppConfig.mockResolvedValueOnce({
|
||||
filteredTools: [],
|
||||
includedTools: [],
|
||||
|
|
@ -132,23 +87,9 @@ describe('PluginController', () => {
|
|||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
// The real checkPluginAuth returns false for plugins without authConfig, so authenticated property is not added
|
||||
expect(responseData[0].authenticated).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return cached plugins when available', async () => {
|
||||
const cachedPlugins = [
|
||||
{ name: 'CachedPlugin', pluginKey: 'cached', description: 'Cached plugin' },
|
||||
];
|
||||
|
||||
mockCache.get.mockResolvedValue(cachedPlugins);
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
// When cache is hit, we return immediately without processing
|
||||
expect(mockRes.json).toHaveBeenCalledWith(cachedPlugins);
|
||||
});
|
||||
|
||||
it('should filter plugins based on includedTools', async () => {
|
||||
const mockPlugins = [
|
||||
{ name: 'Plugin1', pluginKey: 'key1', description: 'First' },
|
||||
|
|
@ -156,9 +97,7 @@ describe('PluginController', () => {
|
|||
];
|
||||
|
||||
require('~/app/clients/tools').availableTools.push(...mockPlugins);
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
|
||||
// Configure getAppConfig to return config with includedTools
|
||||
getAppConfig.mockResolvedValueOnce({
|
||||
filteredTools: [],
|
||||
includedTools: ['key1'],
|
||||
|
|
@ -170,6 +109,47 @@ describe('PluginController', () => {
|
|||
expect(responseData).toHaveLength(1);
|
||||
expect(responseData[0].pluginKey).toBe('key1');
|
||||
});
|
||||
|
||||
it('should exclude plugins in filteredTools', async () => {
|
||||
const mockPlugins = [
|
||||
{ name: 'Plugin1', pluginKey: 'key1', description: 'First' },
|
||||
{ name: 'Plugin2', pluginKey: 'key2', description: 'Second' },
|
||||
];
|
||||
|
||||
require('~/app/clients/tools').availableTools.push(...mockPlugins);
|
||||
|
||||
getAppConfig.mockResolvedValueOnce({
|
||||
filteredTools: ['key2'],
|
||||
includedTools: [],
|
||||
});
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(responseData).toHaveLength(1);
|
||||
expect(responseData[0].pluginKey).toBe('key1');
|
||||
});
|
||||
|
||||
it('should ignore filteredTools when includedTools is set', async () => {
|
||||
const mockPlugins = [
|
||||
{ name: 'Plugin1', pluginKey: 'key1', description: 'First' },
|
||||
{ name: 'Plugin2', pluginKey: 'key2', description: 'Second' },
|
||||
{ name: 'Plugin3', pluginKey: 'key3', description: 'Third' },
|
||||
];
|
||||
|
||||
require('~/app/clients/tools').availableTools.push(...mockPlugins);
|
||||
|
||||
getAppConfig.mockResolvedValueOnce({
|
||||
includedTools: ['key1', 'key2'],
|
||||
filteredTools: ['key2'],
|
||||
});
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(responseData).toHaveLength(2);
|
||||
expect(responseData.map((p) => p.pluginKey)).toEqual(['key1', 'key2']);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAvailableTools', () => {
|
||||
|
|
@ -185,12 +165,11 @@ describe('PluginController', () => {
|
|||
},
|
||||
};
|
||||
|
||||
const mockCachedPlugins = [
|
||||
require('~/app/clients/tools').availableTools.push(
|
||||
{ name: 'user-tool', pluginKey: 'user-tool', description: 'Duplicate user tool' },
|
||||
{ name: 'ManifestTool', pluginKey: 'manifest-tool', description: 'Manifest tool' },
|
||||
];
|
||||
);
|
||||
|
||||
mockCache.get.mockResolvedValue(mockCachedPlugins);
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
|
|
@ -202,24 +181,19 @@ describe('PluginController', () => {
|
|||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(Array.isArray(responseData)).toBe(true);
|
||||
// The real filterUniquePlugins should have deduplicated tools with same pluginKey
|
||||
const userToolCount = responseData.filter((tool) => tool.pluginKey === 'user-tool').length;
|
||||
expect(userToolCount).toBe(1);
|
||||
});
|
||||
|
||||
it('should use checkPluginAuth to verify authentication status', async () => {
|
||||
// Add a plugin to availableTools that will be checked
|
||||
const mockPlugin = {
|
||||
name: 'Tool1',
|
||||
pluginKey: 'tool1',
|
||||
description: 'Tool 1',
|
||||
// No authConfig means checkPluginAuth returns false
|
||||
};
|
||||
|
||||
require('~/app/clients/tools').availableTools.push(mockPlugin);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
// getCachedTools returns the tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
tool1: {
|
||||
type: 'function',
|
||||
|
|
@ -242,7 +216,6 @@ describe('PluginController', () => {
|
|||
expect(Array.isArray(responseData)).toBe(true);
|
||||
const tool = responseData.find((t) => t.pluginKey === 'tool1');
|
||||
expect(tool).toBeDefined();
|
||||
// The real checkPluginAuth returns false for plugins without authConfig, so authenticated property is not added
|
||||
expect(tool.authenticated).toBeUndefined();
|
||||
});
|
||||
|
||||
|
|
@ -256,15 +229,12 @@ describe('PluginController', () => {
|
|||
|
||||
require('~/app/clients/tools').availableTools.push(mockToolkit);
|
||||
|
||||
// Mock toolkits to have a mapping
|
||||
require('~/app/clients/tools').toolkits.push({
|
||||
name: 'Toolkit1',
|
||||
pluginKey: 'toolkit1',
|
||||
tools: ['toolkit1_function'],
|
||||
});
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
// getCachedTools returns the tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
toolkit1_function: {
|
||||
type: 'function',
|
||||
|
|
@ -292,7 +262,7 @@ describe('PluginController', () => {
|
|||
|
||||
describe('helper function integration', () => {
|
||||
it('should handle error cases gracefully', async () => {
|
||||
mockCache.get.mockRejectedValue(new Error('Cache error'));
|
||||
getCachedTools.mockRejectedValue(new Error('Cache error'));
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
|
|
@ -302,17 +272,7 @@ describe('PluginController', () => {
|
|||
});
|
||||
|
||||
describe('edge cases with undefined/null values', () => {
|
||||
it('should handle undefined cache gracefully', async () => {
|
||||
getLogStores.mockReturnValue(undefined);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
});
|
||||
|
||||
it('should handle null cachedTools and cachedUserTools', async () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
// getCachedTools returns empty object instead of null
|
||||
it('should handle null cachedTools', async () => {
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
|
|
@ -321,51 +281,40 @@ describe('PluginController', () => {
|
|||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// Should handle null values gracefully
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
expect(mockRes.json).toHaveBeenCalledWith([]);
|
||||
});
|
||||
|
||||
it('should handle when getCachedTools returns undefined', async () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Mock getCachedTools to return undefined
|
||||
getCachedTools.mockReset();
|
||||
getCachedTools.mockResolvedValueOnce(undefined);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// Should handle undefined values gracefully
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
expect(mockRes.json).toHaveBeenCalledWith([]);
|
||||
});
|
||||
|
||||
it('should handle empty toolDefinitions object', async () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
// Reset getCachedTools to ensure clean state
|
||||
getCachedTools.mockReset();
|
||||
getCachedTools.mockResolvedValue({});
|
||||
mockReq.config = {}; // No mcpConfig at all
|
||||
mockReq.config = {};
|
||||
|
||||
// Ensure no plugins are available
|
||||
require('~/app/clients/tools').availableTools.length = 0;
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// With empty tool definitions, no tools should be in the final output
|
||||
expect(mockRes.json).toHaveBeenCalledWith([]);
|
||||
});
|
||||
|
||||
it('should handle undefined filteredTools and includedTools', async () => {
|
||||
mockReq.config = {};
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
|
||||
// Configure getAppConfig to return config with undefined properties
|
||||
// The controller will use default values [] for filteredTools and includedTools
|
||||
getAppConfig.mockResolvedValueOnce({});
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
|
@ -382,13 +331,8 @@ describe('PluginController', () => {
|
|||
toolkit: true,
|
||||
};
|
||||
|
||||
// No need to mock app.locals anymore as it's not used
|
||||
|
||||
// Add the toolkit to availableTools
|
||||
require('~/app/clients/tools').availableTools.push(mockToolkit);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
// getCachedTools returns empty object to avoid null reference error
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
|
|
@ -397,43 +341,32 @@ describe('PluginController', () => {
|
|||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// Should handle null toolDefinitions gracefully
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
});
|
||||
|
||||
it('should handle undefined toolDefinitions when checking isToolDefined (traversaal_search bug)', async () => {
|
||||
// This test reproduces the bug where toolDefinitions is undefined
|
||||
// and accessing toolDefinitions[plugin.pluginKey] causes a TypeError
|
||||
it('should handle undefined toolDefinitions when checking isToolDefined', async () => {
|
||||
const mockPlugin = {
|
||||
name: 'Traversaal Search',
|
||||
pluginKey: 'traversaal_search',
|
||||
description: 'Search plugin',
|
||||
};
|
||||
|
||||
// Add the plugin to availableTools
|
||||
require('~/app/clients/tools').availableTools.push(mockPlugin);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// CRITICAL: getCachedTools returns undefined
|
||||
// This is what causes the bug when trying to access toolDefinitions[plugin.pluginKey]
|
||||
getCachedTools.mockResolvedValueOnce(undefined);
|
||||
|
||||
// This should not throw an error with the optional chaining fix
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// Should handle undefined toolDefinitions gracefully and return empty array
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
expect(mockRes.json).toHaveBeenCalledWith([]);
|
||||
});
|
||||
|
||||
it('should re-initialize tools from appConfig when cache returns null', async () => {
|
||||
// Setup: Initial state with tools in appConfig
|
||||
const mockAppTools = {
|
||||
tool1: {
|
||||
type: 'function',
|
||||
|
|
@ -453,15 +386,12 @@ describe('PluginController', () => {
|
|||
},
|
||||
};
|
||||
|
||||
// Add matching plugins to availableTools
|
||||
require('~/app/clients/tools').availableTools.push(
|
||||
{ name: 'Tool 1', pluginKey: 'tool1', description: 'Tool 1' },
|
||||
{ name: 'Tool 2', pluginKey: 'tool2', description: 'Tool 2' },
|
||||
);
|
||||
|
||||
// Simulate cache cleared state (returns null)
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValueOnce(null); // Global tools (cache cleared)
|
||||
getCachedTools.mockResolvedValueOnce(null);
|
||||
|
||||
mockReq.config = {
|
||||
filteredTools: [],
|
||||
|
|
@ -469,15 +399,12 @@ describe('PluginController', () => {
|
|||
availableTools: mockAppTools,
|
||||
};
|
||||
|
||||
// Mock setCachedTools to verify it's called to re-initialize
|
||||
const { setCachedTools } = require('~/server/services/Config');
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// Should have re-initialized the cache with tools from appConfig
|
||||
expect(setCachedTools).toHaveBeenCalledWith(mockAppTools);
|
||||
|
||||
// Should still return tools successfully
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(responseData).toHaveLength(2);
|
||||
|
|
@ -486,29 +413,22 @@ describe('PluginController', () => {
|
|||
});
|
||||
|
||||
it('should handle cache clear without appConfig.availableTools gracefully', async () => {
|
||||
// Setup: appConfig without availableTools
|
||||
getAppConfig.mockResolvedValue({
|
||||
filteredTools: [],
|
||||
includedTools: [],
|
||||
// No availableTools property
|
||||
});
|
||||
|
||||
// Clear availableTools array
|
||||
require('~/app/clients/tools').availableTools.length = 0;
|
||||
|
||||
// Cache returns null (cleared state)
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValueOnce(null); // Global tools (cache cleared)
|
||||
getCachedTools.mockResolvedValueOnce(null);
|
||||
|
||||
mockReq.config = {
|
||||
filteredTools: [],
|
||||
includedTools: [],
|
||||
// No availableTools
|
||||
};
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// Should handle gracefully without crashing
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
expect(mockRes.json).toHaveBeenCalledWith([]);
|
||||
});
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ const { getLogStores } = require('~/cache');
|
|||
const db = require('~/models');
|
||||
|
||||
const getUserController = async (req, res) => {
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId });
|
||||
/** @type {IUser} */
|
||||
const userData = req.user.toObject != null ? req.user.toObject() : { ...req.user };
|
||||
/**
|
||||
|
|
@ -165,7 +165,7 @@ const deleteUserMcpServers = async (userId) => {
|
|||
};
|
||||
|
||||
const updateUserPluginsController = async (req, res) => {
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId });
|
||||
const { user } = req;
|
||||
const { pluginKey, action, auth, isEntityTool } = req.body;
|
||||
try {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
* Tests that recordCollectedUsage is called correctly for token spending
|
||||
*/
|
||||
|
||||
const mockProcessStream = jest.fn().mockResolvedValue(undefined);
|
||||
const mockSpendTokens = jest.fn().mockResolvedValue({});
|
||||
const mockSpendStructuredTokens = jest.fn().mockResolvedValue({});
|
||||
const mockRecordCollectedUsage = jest
|
||||
|
|
@ -35,7 +36,7 @@ jest.mock('@librechat/agents', () => ({
|
|||
jest.mock('@librechat/api', () => ({
|
||||
writeSSE: jest.fn(),
|
||||
createRun: jest.fn().mockResolvedValue({
|
||||
processStream: jest.fn().mockResolvedValue(undefined),
|
||||
processStream: mockProcessStream,
|
||||
}),
|
||||
createChunk: jest.fn().mockReturnValue({}),
|
||||
buildToolSet: jest.fn().mockReturnValue(new Set()),
|
||||
|
|
@ -68,6 +69,7 @@ jest.mock('@librechat/api', () => ({
|
|||
toolCalls: new Map(),
|
||||
usage: { promptTokens: 100, completionTokens: 50, reasoningTokens: 0 },
|
||||
}),
|
||||
resolveRecursionLimit: jest.fn().mockReturnValue(50),
|
||||
createToolExecuteHandler: jest.fn().mockReturnValue({ handle: jest.fn() }),
|
||||
isChatCompletionValidationFailure: jest.fn().mockReturnValue(false),
|
||||
}));
|
||||
|
|
@ -286,4 +288,36 @@ describe('OpenAIChatCompletionController', () => {
|
|||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('recursionLimit resolution', () => {
|
||||
it('should pass resolveRecursionLimit result to processStream config', async () => {
|
||||
const { resolveRecursionLimit } = require('@librechat/api');
|
||||
resolveRecursionLimit.mockReturnValueOnce(75);
|
||||
|
||||
await OpenAIChatCompletionController(req, res);
|
||||
|
||||
expect(mockProcessStream).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({ recursionLimit: 75 }),
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should call resolveRecursionLimit with agentsEConfig and agent', async () => {
|
||||
const { resolveRecursionLimit } = require('@librechat/api');
|
||||
const { getAgent } = require('~/models');
|
||||
const mockAgent = { id: 'agent-123', name: 'Test', recursion_limit: 200 };
|
||||
getAgent.mockResolvedValueOnce(mockAgent);
|
||||
|
||||
req.config = {
|
||||
endpoints: {
|
||||
agents: { recursionLimit: 100, maxRecursionLimit: 150, allowedProviders: [] },
|
||||
},
|
||||
};
|
||||
|
||||
await OpenAIChatCompletionController(req, res);
|
||||
|
||||
expect(resolveRecursionLimit).toHaveBeenCalledWith(req.config.endpoints.agents, mockAgent);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ const {
|
|||
recordCollectedUsage,
|
||||
GenerationJobManager,
|
||||
getTransactionsConfig,
|
||||
resolveRecursionLimit,
|
||||
createMemoryProcessor,
|
||||
loadAgent: loadAgentFn,
|
||||
createMultiAgentMapper,
|
||||
|
|
@ -50,6 +51,7 @@ const {
|
|||
const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { createContextHandlers } = require('~/app/clients/prompts');
|
||||
const { resolveConfigServers } = require('~/server/services/MCP');
|
||||
const { getMCPServerTools } = require('~/server/services/Config');
|
||||
const BaseClient = require('~/app/clients/BaseClient');
|
||||
const { getMCPManager } = require('~/config');
|
||||
|
|
@ -377,6 +379,9 @@ class AgentClient extends BaseClient {
|
|||
*/
|
||||
const ephemeralAgent = this.options.req.body.ephemeralAgent;
|
||||
const mcpManager = getMCPManager();
|
||||
|
||||
const configServers = await resolveConfigServers(this.options.req);
|
||||
|
||||
await Promise.all(
|
||||
allAgents.map(({ agent, agentId }) =>
|
||||
applyContextToAgent({
|
||||
|
|
@ -384,6 +389,7 @@ class AgentClient extends BaseClient {
|
|||
agentId,
|
||||
logger,
|
||||
mcpManager,
|
||||
configServers,
|
||||
sharedRunContext,
|
||||
ephemeralAgent: agentId === this.options.agent.id ? ephemeralAgent : undefined,
|
||||
}),
|
||||
|
|
@ -728,7 +734,7 @@ class AgentClient extends BaseClient {
|
|||
},
|
||||
user: createSafeUser(this.options.req.user),
|
||||
},
|
||||
recursionLimit: agentsEConfig?.recursionLimit ?? 50,
|
||||
recursionLimit: resolveRecursionLimit(agentsEConfig, this.options.agent),
|
||||
signal: abortController.signal,
|
||||
streamMode: 'values',
|
||||
version: 'v2',
|
||||
|
|
@ -776,17 +782,6 @@ class AgentClient extends BaseClient {
|
|||
agents.push(...this.agentConfigs.values());
|
||||
}
|
||||
|
||||
if (agents[0].recursion_limit && typeof agents[0].recursion_limit === 'number') {
|
||||
config.recursionLimit = agents[0].recursion_limit;
|
||||
}
|
||||
|
||||
if (
|
||||
agentsEConfig?.maxRecursionLimit &&
|
||||
config.recursionLimit > agentsEConfig?.maxRecursionLimit
|
||||
) {
|
||||
config.recursionLimit = agentsEConfig?.maxRecursionLimit;
|
||||
}
|
||||
|
||||
// TODO: needs to be added as part of AgentContext initialization
|
||||
// const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi];
|
||||
// const noSystemMessages = noSystemModelRegex.some((regex) =>
|
||||
|
|
|
|||
|
|
@ -22,6 +22,10 @@ jest.mock('~/server/services/Config', () => ({
|
|||
getMCPServerTools: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/MCP', () => ({
|
||||
resolveConfigServers: jest.fn().mockResolvedValue({}),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
getAgent: jest.fn(),
|
||||
getRoleByName: jest.fn(),
|
||||
|
|
@ -1315,7 +1319,7 @@ describe('AgentClient - titleConvo', () => {
|
|||
});
|
||||
|
||||
// Verify formatInstructionsForContext was called with correct server names
|
||||
expect(mockFormatInstructions).toHaveBeenCalledWith(['server1', 'server2']);
|
||||
expect(mockFormatInstructions).toHaveBeenCalledWith(['server1', 'server2'], {});
|
||||
|
||||
// Verify the instructions do NOT contain [object Promise]
|
||||
expect(client.options.agent.instructions).not.toContain('[object Promise]');
|
||||
|
|
@ -1355,10 +1359,10 @@ describe('AgentClient - titleConvo', () => {
|
|||
});
|
||||
|
||||
// Verify formatInstructionsForContext was called with ephemeral server names
|
||||
expect(mockFormatInstructions).toHaveBeenCalledWith([
|
||||
'ephemeral-server1',
|
||||
'ephemeral-server2',
|
||||
]);
|
||||
expect(mockFormatInstructions).toHaveBeenCalledWith(
|
||||
['ephemeral-server1', 'ephemeral-server2'],
|
||||
{},
|
||||
);
|
||||
|
||||
// Verify no [object Promise] in instructions
|
||||
expect(client.options.agent.instructions).not.toContain('[object Promise]');
|
||||
|
|
|
|||
|
|
@ -22,6 +22,10 @@ jest.mock('~/config', () => ({
|
|||
})),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/MCP', () => ({
|
||||
resolveConfigServers: jest.fn().mockResolvedValue({}),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/strategies', () => ({
|
||||
getStrategyFunctions: jest.fn(),
|
||||
}));
|
||||
|
|
@ -223,7 +227,27 @@ describe('MCP Tool Authorization', () => {
|
|||
availableTools,
|
||||
});
|
||||
|
||||
expect(mockGetAllServerConfigs).toHaveBeenCalledWith('specific-user-id');
|
||||
expect(mockGetAllServerConfigs).toHaveBeenCalledWith('specific-user-id', undefined);
|
||||
});
|
||||
|
||||
test('should pass configServers to getAllServerConfigs and allow config-override servers', async () => {
|
||||
const configServers = {
|
||||
'config-override-server': { type: 'sse', url: 'https://override.example.com' },
|
||||
};
|
||||
mockGetAllServerConfigs.mockResolvedValue({
|
||||
'config-override-server': configServers['config-override-server'],
|
||||
});
|
||||
|
||||
const result = await filterAuthorizedTools({
|
||||
tools: [`tool${d}config-override-server`, `tool${d}unauthorizedServer`],
|
||||
userId,
|
||||
availableTools,
|
||||
configServers,
|
||||
});
|
||||
|
||||
expect(mockGetAllServerConfigs).toHaveBeenCalledWith(userId, configServers);
|
||||
expect(result).toContain(`tool${d}config-override-server`);
|
||||
expect(result).not.toContain(`tool${d}unauthorizedServer`);
|
||||
});
|
||||
|
||||
test('should only call getAllServerConfigs once even with multiple MCP tools', async () => {
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ const {
|
|||
createErrorResponse,
|
||||
recordCollectedUsage,
|
||||
getTransactionsConfig,
|
||||
resolveRecursionLimit,
|
||||
createToolExecuteHandler,
|
||||
buildNonStreamingResponse,
|
||||
createOpenAIStreamTracker,
|
||||
|
|
@ -194,10 +195,8 @@ const OpenAIChatCompletionController = async (req, res) => {
|
|||
const conversationId = request.conversation_id ?? nanoid();
|
||||
const parentMessageId = request.parent_message_id ?? null;
|
||||
|
||||
// Build allowed providers set
|
||||
const allowedProviders = new Set(
|
||||
appConfig?.endpoints?.[EModelEndpoint.agents]?.allowedProviders,
|
||||
);
|
||||
const agentsEConfig = appConfig?.endpoints?.[EModelEndpoint.agents];
|
||||
const allowedProviders = new Set(agentsEConfig?.allowedProviders);
|
||||
|
||||
// Create tool loader
|
||||
const loadTools = createToolLoader(abortController.signal);
|
||||
|
|
@ -491,7 +490,6 @@ const OpenAIChatCompletionController = async (req, res) => {
|
|||
throw new Error('Failed to create agent run');
|
||||
}
|
||||
|
||||
// Process the stream
|
||||
const config = {
|
||||
runName: 'AgentRun',
|
||||
configurable: {
|
||||
|
|
@ -504,6 +502,7 @@ const OpenAIChatCompletionController = async (req, res) => {
|
|||
},
|
||||
...(userMCPAuthMap != null && { userMCPAuthMap }),
|
||||
},
|
||||
recursionLimit: resolveRecursionLimit(agentsEConfig, agent),
|
||||
signal: abortController.signal,
|
||||
streamMode: 'values',
|
||||
version: 'v2',
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ const { resizeAvatar } = require('~/server/services/Files/images/avatar');
|
|||
const { getFileStrategy } = require('~/server/utils/getFileStrategy');
|
||||
const { filterFile } = require('~/server/services/Files/process');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { resolveConfigServers } = require('~/server/services/MCP');
|
||||
const { getMCPServersRegistry } = require('~/config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const db = require('~/models');
|
||||
|
|
@ -101,9 +102,16 @@ const validateEdgeAgentAccess = async (edges, userId, userRole) => {
|
|||
* @param {string} params.userId - Requesting user ID for MCP server access check
|
||||
* @param {Record<string, unknown>} params.availableTools - Global non-MCP tool cache
|
||||
* @param {string[]} [params.existingTools] - Tools already persisted on the agent document
|
||||
* @param {Record<string, unknown>} [params.configServers] - Config-source MCP servers resolved from appConfig overrides
|
||||
* @returns {Promise<string[]>} Only the authorized subset of tools
|
||||
*/
|
||||
const filterAuthorizedTools = async ({ tools, userId, availableTools, existingTools }) => {
|
||||
const filterAuthorizedTools = async ({
|
||||
tools,
|
||||
userId,
|
||||
availableTools,
|
||||
existingTools,
|
||||
configServers,
|
||||
}) => {
|
||||
const filteredTools = [];
|
||||
let mcpServerConfigs;
|
||||
let registryUnavailable = false;
|
||||
|
|
@ -121,7 +129,8 @@ const filterAuthorizedTools = async ({ tools, userId, availableTools, existingTo
|
|||
|
||||
if (mcpServerConfigs === undefined) {
|
||||
try {
|
||||
mcpServerConfigs = (await getMCPServersRegistry().getAllServerConfigs(userId)) ?? {};
|
||||
mcpServerConfigs =
|
||||
(await getMCPServersRegistry().getAllServerConfigs(userId, configServers)) ?? {};
|
||||
} catch (e) {
|
||||
logger.warn(
|
||||
'[filterAuthorizedTools] MCP registry unavailable, filtering all MCP tools',
|
||||
|
|
@ -192,8 +201,17 @@ const createAgentHandler = async (req, res) => {
|
|||
agentData.author = userId;
|
||||
agentData.tools = [];
|
||||
|
||||
const availableTools = (await getCachedTools()) ?? {};
|
||||
agentData.tools = await filterAuthorizedTools({ tools, userId, availableTools });
|
||||
const hasMCPTools = tools.some((t) => t?.includes(Constants.mcp_delimiter));
|
||||
const [availableTools, configServers] = await Promise.all([
|
||||
getCachedTools().then((t) => t ?? {}),
|
||||
hasMCPTools ? resolveConfigServers(req) : Promise.resolve(undefined),
|
||||
]);
|
||||
agentData.tools = await filterAuthorizedTools({
|
||||
tools,
|
||||
userId,
|
||||
availableTools,
|
||||
configServers,
|
||||
});
|
||||
|
||||
const agent = await db.createAgent(agentData);
|
||||
|
||||
|
|
@ -376,11 +394,15 @@ const updateAgentHandler = async (req, res) => {
|
|||
);
|
||||
|
||||
if (newMCPTools.length > 0) {
|
||||
const availableTools = (await getCachedTools()) ?? {};
|
||||
const [availableTools, configServers] = await Promise.all([
|
||||
getCachedTools().then((t) => t ?? {}),
|
||||
resolveConfigServers(req),
|
||||
]);
|
||||
const approvedNew = await filterAuthorizedTools({
|
||||
tools: newMCPTools,
|
||||
userId: req.user.id,
|
||||
availableTools,
|
||||
configServers,
|
||||
});
|
||||
const rejectedSet = new Set(newMCPTools.filter((t) => !approvedNew.includes(t)));
|
||||
if (rejectedSet.size > 0) {
|
||||
|
|
@ -533,12 +555,16 @@ const duplicateAgentHandler = async (req, res) => {
|
|||
newAgentData.actions = agentActions;
|
||||
|
||||
if (newAgentData.tools?.length) {
|
||||
const availableTools = (await getCachedTools()) ?? {};
|
||||
const [availableTools, configServers] = await Promise.all([
|
||||
getCachedTools().then((t) => t ?? {}),
|
||||
resolveConfigServers(req),
|
||||
]);
|
||||
newAgentData.tools = await filterAuthorizedTools({
|
||||
tools: newAgentData.tools,
|
||||
userId,
|
||||
availableTools,
|
||||
existingTools: newAgentData.tools,
|
||||
configServers,
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -873,12 +899,16 @@ const revertAgentVersionHandler = async (req, res) => {
|
|||
let updatedAgent = await db.revertAgentVersion({ id }, version_index);
|
||||
|
||||
if (updatedAgent.tools?.length) {
|
||||
const availableTools = (await getCachedTools()) ?? {};
|
||||
const [availableTools, configServers] = await Promise.all([
|
||||
getCachedTools().then((t) => t ?? {}),
|
||||
resolveConfigServers(req),
|
||||
]);
|
||||
const filteredTools = await filterAuthorizedTools({
|
||||
tools: updatedAgent.tools,
|
||||
userId: req.user.id,
|
||||
availableTools,
|
||||
existingTools: updatedAgent.tools,
|
||||
configServers,
|
||||
});
|
||||
if (filteredTools.length !== updatedAgent.tools.length) {
|
||||
updatedAgent = await db.updateAgent(
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ const { sendResponse } = require('~/server/middleware/error');
|
|||
const {
|
||||
createAutoRefillTransaction,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
getTransactions,
|
||||
getMultiplier,
|
||||
getConvo,
|
||||
|
|
@ -296,7 +297,14 @@ const chatV1 = async (req, res) => {
|
|||
amount: promptTokens,
|
||||
},
|
||||
},
|
||||
{ findBalanceByUser, getMultiplier, createAutoRefillTransaction, logViolation },
|
||||
{
|
||||
findBalanceByUser,
|
||||
getMultiplier,
|
||||
createAutoRefillTransaction,
|
||||
logViolation,
|
||||
balanceConfig,
|
||||
upsertBalanceFields,
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ const {
|
|||
getMultiplier,
|
||||
getTransactions,
|
||||
findBalanceByUser,
|
||||
upsertBalanceFields,
|
||||
createAutoRefillTransaction,
|
||||
} = require('~/models');
|
||||
const { logViolation, getLogStores } = require('~/cache');
|
||||
|
|
@ -169,7 +170,14 @@ const chatV2 = async (req, res) => {
|
|||
amount: promptTokens,
|
||||
},
|
||||
},
|
||||
{ findBalanceByUser, getMultiplier, createAutoRefillTransaction, logViolation },
|
||||
{
|
||||
findBalanceByUser,
|
||||
getMultiplier,
|
||||
createAutoRefillTransaction,
|
||||
logViolation,
|
||||
balanceConfig,
|
||||
upsertBalanceFields,
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -47,9 +47,15 @@ function createOAuthHandler(redirectUri = domains.client) {
|
|||
const refreshToken =
|
||||
req.user.tokenset?.refresh_token || req.user.federatedTokens?.refresh_token;
|
||||
|
||||
const exchangeCode = await generateAdminExchangeCode(cache, req.user, token, refreshToken);
|
||||
|
||||
const callbackUrl = new URL(redirectUri);
|
||||
const exchangeCode = await generateAdminExchangeCode(
|
||||
cache,
|
||||
req.user,
|
||||
token,
|
||||
refreshToken,
|
||||
callbackUrl.origin,
|
||||
req.pkceChallenge,
|
||||
);
|
||||
callbackUrl.searchParams.set('code', exchangeCode);
|
||||
logger.info(`[OAuth] Admin panel redirect with exchange code for user: ${req.user.email}`);
|
||||
return res.redirect(callbackUrl.toString());
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ const {
|
|||
isMCPInspectionFailedError,
|
||||
} = require('@librechat/api');
|
||||
const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider');
|
||||
const { resolveConfigServers, resolveAllMcpConfigs } = require('~/server/services/MCP');
|
||||
const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config');
|
||||
const { getMCPManager, getMCPServersRegistry } = require('~/config');
|
||||
|
||||
|
|
@ -57,7 +58,7 @@ function handleMCPError(error, res) {
|
|||
}
|
||||
|
||||
/**
|
||||
* Get all MCP tools available to the user
|
||||
* Get all MCP tools available to the user.
|
||||
*/
|
||||
const getMCPTools = async (req, res) => {
|
||||
try {
|
||||
|
|
@ -67,10 +68,10 @@ const getMCPTools = async (req, res) => {
|
|||
return res.status(401).json({ message: 'Unauthorized' });
|
||||
}
|
||||
|
||||
const mcpConfig = await getMCPServersRegistry().getAllServerConfigs(userId);
|
||||
const configuredServers = mcpConfig ? Object.keys(mcpConfig) : [];
|
||||
const mcpConfig = await resolveAllMcpConfigs(userId, req.user);
|
||||
const configuredServers = Object.keys(mcpConfig);
|
||||
|
||||
if (!mcpConfig || Object.keys(mcpConfig).length == 0) {
|
||||
if (!configuredServers.length) {
|
||||
return res.status(200).json({ servers: {} });
|
||||
}
|
||||
|
||||
|
|
@ -115,14 +116,11 @@ const getMCPTools = async (req, res) => {
|
|||
try {
|
||||
const serverTools = serverToolsMap.get(serverName);
|
||||
|
||||
// Get server config once
|
||||
const serverConfig = mcpConfig[serverName];
|
||||
const rawServerConfig = await getMCPServersRegistry().getServerConfig(serverName, userId);
|
||||
|
||||
// Initialize server object with all server-level data
|
||||
const server = {
|
||||
name: serverName,
|
||||
icon: rawServerConfig?.iconPath || '',
|
||||
icon: serverConfig?.iconPath || '',
|
||||
authenticated: true,
|
||||
authConfig: [],
|
||||
tools: [],
|
||||
|
|
@ -183,7 +181,7 @@ const getMCPServersList = async (req, res) => {
|
|||
return res.status(401).json({ message: 'Unauthorized' });
|
||||
}
|
||||
|
||||
const serverConfigs = await getMCPServersRegistry().getAllServerConfigs(userId);
|
||||
const serverConfigs = await resolveAllMcpConfigs(userId, req.user);
|
||||
return res.json(redactAllServerSecrets(serverConfigs));
|
||||
} catch (error) {
|
||||
logger.error('[getMCPServersList]', error);
|
||||
|
|
@ -237,7 +235,12 @@ const getMCPServerById = async (req, res) => {
|
|||
if (!serverName) {
|
||||
return res.status(400).json({ message: 'Server name is required' });
|
||||
}
|
||||
const parsedConfig = await getMCPServersRegistry().getServerConfig(serverName, userId);
|
||||
const configServers = await resolveConfigServers(req);
|
||||
const parsedConfig = await getMCPServersRegistry().getServerConfig(
|
||||
serverName,
|
||||
userId,
|
||||
configServers,
|
||||
);
|
||||
|
||||
if (!parsedConfig) {
|
||||
return res.status(404).json({ message: 'MCP server not found' });
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ const {
|
|||
performStartupChecks,
|
||||
handleJsonParseError,
|
||||
initializeFileStorage,
|
||||
preAuthTenantMiddleware,
|
||||
} = require('@librechat/api');
|
||||
const { connectDb, indexSync } = require('~/db');
|
||||
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
|
||||
|
|
@ -31,6 +32,7 @@ const initializeMCPs = require('./services/initializeMCPs');
|
|||
const configureSocialLogins = require('./socialLogins');
|
||||
const { getAppConfig } = require('./services/Config');
|
||||
const staticCache = require('./utils/staticCache');
|
||||
const optionalJwtAuth = require('./middleware/optionalJwtAuth');
|
||||
const noIndex = require('./middleware/noIndex');
|
||||
const routes = require('./routes');
|
||||
|
||||
|
|
@ -312,7 +314,7 @@ if (cluster.isMaster) {
|
|||
app.use('/api/endpoints', routes.endpoints);
|
||||
app.use('/api/balance', routes.balance);
|
||||
app.use('/api/models', routes.models);
|
||||
app.use('/api/config', routes.config);
|
||||
app.use('/api/config', preAuthTenantMiddleware, optionalJwtAuth, routes.config);
|
||||
app.use('/api/assistants', routes.assistants);
|
||||
app.use('/api/files', await routes.files.initialize());
|
||||
app.use('/images/', createValidateImageRequest(appConfig.secureImageLinks), routes.staticRoute);
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ const express = require('express');
|
|||
const passport = require('passport');
|
||||
const compression = require('compression');
|
||||
const cookieParser = require('cookie-parser');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const mongoSanitize = require('express-mongo-sanitize');
|
||||
const { logger, runAsSystem } = require('@librechat/data-schemas');
|
||||
const {
|
||||
isEnabled,
|
||||
apiNotFound,
|
||||
|
|
@ -21,6 +21,7 @@ const {
|
|||
createStreamServices,
|
||||
initializeFileStorage,
|
||||
updateInterfacePermissions,
|
||||
preAuthTenantMiddleware,
|
||||
} = require('@librechat/api');
|
||||
const { connectDb, indexSync } = require('~/db');
|
||||
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
|
||||
|
|
@ -33,6 +34,7 @@ const initializeMCPs = require('./services/initializeMCPs');
|
|||
const configureSocialLogins = require('./socialLogins');
|
||||
const { getAppConfig } = require('./services/Config');
|
||||
const staticCache = require('./utils/staticCache');
|
||||
const optionalJwtAuth = require('./middleware/optionalJwtAuth');
|
||||
const noIndex = require('./middleware/noIndex');
|
||||
const routes = require('./routes');
|
||||
|
||||
|
|
@ -59,11 +61,20 @@ const startServer = async () => {
|
|||
app.disable('x-powered-by');
|
||||
app.set('trust proxy', trusted_proxy);
|
||||
|
||||
await seedDatabase();
|
||||
const appConfig = await getAppConfig();
|
||||
if (isEnabled(process.env.TENANT_ISOLATION_STRICT)) {
|
||||
logger.warn(
|
||||
'[Security] TENANT_ISOLATION_STRICT is active. Ensure your reverse proxy strips or sets ' +
|
||||
'the X-Tenant-Id header — untrusted clients must not be able to set it directly.',
|
||||
);
|
||||
}
|
||||
|
||||
await runAsSystem(seedDatabase);
|
||||
const appConfig = await getAppConfig({ baseOnly: true });
|
||||
initializeFileStorage(appConfig);
|
||||
await performStartupChecks(appConfig);
|
||||
await updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions });
|
||||
await runAsSystem(async () => {
|
||||
await performStartupChecks(appConfig);
|
||||
await updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions });
|
||||
});
|
||||
|
||||
const indexPath = path.join(appConfig.paths.dist, 'index.html');
|
||||
let indexHTML = fs.readFileSync(indexPath, 'utf8');
|
||||
|
|
@ -137,10 +148,17 @@ const startServer = async () => {
|
|||
/* Per-request capability cache — must be registered before any route that calls hasCapability */
|
||||
app.use(capabilityContextMiddleware);
|
||||
|
||||
app.use('/oauth', routes.oauth);
|
||||
/* Pre-auth tenant context for unauthenticated routes that need tenant scoping.
|
||||
* The reverse proxy / auth gateway sets `X-Tenant-Id` header for multi-tenant deployments. */
|
||||
app.use('/oauth', preAuthTenantMiddleware, routes.oauth);
|
||||
/* API Endpoints */
|
||||
app.use('/api/auth', routes.auth);
|
||||
app.use('/api/auth', preAuthTenantMiddleware, routes.auth);
|
||||
app.use('/api/admin', routes.adminAuth);
|
||||
app.use('/api/admin/config', routes.adminConfig);
|
||||
app.use('/api/admin/grants', routes.adminGrants);
|
||||
app.use('/api/admin/groups', routes.adminGroups);
|
||||
app.use('/api/admin/roles', routes.adminRoles);
|
||||
app.use('/api/admin/users', routes.adminUsers);
|
||||
app.use('/api/actions', routes.actions);
|
||||
app.use('/api/keys', routes.keys);
|
||||
app.use('/api/api-keys', routes.apiKeys);
|
||||
|
|
@ -154,11 +172,11 @@ const startServer = async () => {
|
|||
app.use('/api/endpoints', routes.endpoints);
|
||||
app.use('/api/balance', routes.balance);
|
||||
app.use('/api/models', routes.models);
|
||||
app.use('/api/config', routes.config);
|
||||
app.use('/api/config', preAuthTenantMiddleware, optionalJwtAuth, routes.config);
|
||||
app.use('/api/assistants', routes.assistants);
|
||||
app.use('/api/files', await routes.files.initialize());
|
||||
app.use('/images/', createValidateImageRequest(appConfig.secureImageLinks), routes.staticRoute);
|
||||
app.use('/api/share', routes.share);
|
||||
app.use('/api/share', preAuthTenantMiddleware, routes.share);
|
||||
app.use('/api/roles', routes.roles);
|
||||
app.use('/api/agents', routes.agents);
|
||||
app.use('/api/banner', routes.banner);
|
||||
|
|
@ -204,8 +222,10 @@ const startServer = async () => {
|
|||
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
|
||||
}
|
||||
|
||||
await initializeMCPs();
|
||||
await initializeOAuthReconnectManager();
|
||||
await runAsSystem(async () => {
|
||||
await initializeMCPs();
|
||||
await initializeOAuthReconnectManager();
|
||||
});
|
||||
await checkMigrations();
|
||||
|
||||
// Configure stream services (auto-detects Redis from USE_REDIS env var)
|
||||
|
|
|
|||
116
api/server/middleware/__tests__/requireJwtAuth.spec.js
Normal file
116
api/server/middleware/__tests__/requireJwtAuth.spec.js
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
/**
|
||||
* Integration test: verifies that requireJwtAuth chains tenantContextMiddleware
|
||||
* after successful passport authentication, so ALS tenant context is set for
|
||||
* all downstream middleware and route handlers.
|
||||
*
|
||||
* requireJwtAuth must chain tenantContextMiddleware after passport populates
|
||||
* req.user (not at global app.use() scope where req.user is undefined).
|
||||
* If the chaining is removed, these tests fail.
|
||||
*/
|
||||
|
||||
const { getTenantId } = require('@librechat/data-schemas');
|
||||
|
||||
// ── Mocks ──────────────────────────────────────────────────────────────
|
||||
|
||||
let mockPassportError = null;
|
||||
|
||||
jest.mock('passport', () => ({
|
||||
authenticate: jest.fn(() => {
|
||||
return (req, _res, done) => {
|
||||
if (mockPassportError) {
|
||||
return done(mockPassportError);
|
||||
}
|
||||
if (req._mockUser) {
|
||||
req.user = req._mockUser;
|
||||
}
|
||||
done();
|
||||
};
|
||||
}),
|
||||
}));
|
||||
|
||||
// Mock @librechat/api — the real tenantContextMiddleware is TS and cannot be
|
||||
// required directly from CJS tests. This thin wrapper mirrors the real logic
|
||||
// (read req.user.tenantId, call tenantStorage.run) using the same data-schemas
|
||||
// primitives. The real implementation is covered by packages/api tenant.spec.ts.
|
||||
jest.mock('@librechat/api', () => {
|
||||
const { tenantStorage } = require('@librechat/data-schemas');
|
||||
return {
|
||||
isEnabled: jest.fn(() => false),
|
||||
tenantContextMiddleware: (req, res, next) => {
|
||||
const tenantId = req.user?.tenantId;
|
||||
if (!tenantId) {
|
||||
return next();
|
||||
}
|
||||
return tenantStorage.run({ tenantId }, async () => next());
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
// ── Helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
const requireJwtAuth = require('../requireJwtAuth');
|
||||
|
||||
function mockReq(user) {
|
||||
return { headers: {}, _mockUser: user };
|
||||
}
|
||||
|
||||
function mockRes() {
|
||||
return { status: jest.fn().mockReturnThis(), json: jest.fn().mockReturnThis() };
|
||||
}
|
||||
|
||||
/** Runs requireJwtAuth and returns the tenantId observed inside next(). */
|
||||
function runAuth(user) {
|
||||
return new Promise((resolve) => {
|
||||
const req = mockReq(user);
|
||||
const res = mockRes();
|
||||
requireJwtAuth(req, res, () => {
|
||||
resolve(getTenantId());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────────
|
||||
|
||||
describe('requireJwtAuth tenant context chaining', () => {
|
||||
afterEach(() => {
|
||||
mockPassportError = null;
|
||||
});
|
||||
|
||||
it('forwards passport errors to next() without entering tenant middleware', async () => {
|
||||
mockPassportError = new Error('JWT signature invalid');
|
||||
const req = mockReq(undefined);
|
||||
const res = mockRes();
|
||||
const err = await new Promise((resolve) => {
|
||||
requireJwtAuth(req, res, (e) => resolve(e));
|
||||
});
|
||||
expect(err).toBeInstanceOf(Error);
|
||||
expect(err.message).toBe('JWT signature invalid');
|
||||
expect(getTenantId()).toBeUndefined();
|
||||
});
|
||||
|
||||
it('sets ALS tenant context after passport auth succeeds', async () => {
|
||||
const tenantId = await runAuth({ tenantId: 'tenant-abc', role: 'user' });
|
||||
expect(tenantId).toBe('tenant-abc');
|
||||
});
|
||||
|
||||
it('ALS tenant context is NOT set when user has no tenantId', async () => {
|
||||
const tenantId = await runAuth({ role: 'user' });
|
||||
expect(tenantId).toBeUndefined();
|
||||
});
|
||||
|
||||
it('ALS tenant context is NOT set when user is undefined', async () => {
|
||||
const tenantId = await runAuth(undefined);
|
||||
expect(tenantId).toBeUndefined();
|
||||
});
|
||||
|
||||
it('concurrent requests get isolated tenant contexts', async () => {
|
||||
const results = await Promise.all(
|
||||
['tenant-1', 'tenant-2', 'tenant-3'].map((tid) => runAuth({ tenantId: tid, role: 'user' })),
|
||||
);
|
||||
expect(results).toEqual(['tenant-1', 'tenant-2', 'tenant-3']);
|
||||
});
|
||||
|
||||
it('ALS context is not set at top-level scope (outside any request)', () => {
|
||||
expect(getTenantId()).toBeUndefined();
|
||||
});
|
||||
});
|
||||
178
api/server/middleware/__tests__/validateModel.spec.js
Normal file
178
api/server/middleware/__tests__/validateModel.spec.js
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
handleError: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/controllers/ModelController', () => ({
|
||||
getModelsConfig: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getEndpointsConfig: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/cache', () => ({
|
||||
logViolation: jest.fn(),
|
||||
}));
|
||||
|
||||
const { handleError } = require('@librechat/api');
|
||||
const { getModelsConfig } = require('~/server/controllers/ModelController');
|
||||
const { getEndpointsConfig } = require('~/server/services/Config');
|
||||
const { logViolation } = require('~/cache');
|
||||
const validateModel = require('../validateModel');
|
||||
|
||||
describe('validateModel', () => {
|
||||
let req, res, next;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
req = { body: { model: 'gpt-4o', endpoint: 'openAI' } };
|
||||
res = {};
|
||||
next = jest.fn();
|
||||
getEndpointsConfig.mockResolvedValue({
|
||||
openAI: { userProvide: false },
|
||||
});
|
||||
getModelsConfig.mockResolvedValue({
|
||||
openAI: ['gpt-4o', 'gpt-4o-mini'],
|
||||
});
|
||||
});
|
||||
|
||||
describe('format validation', () => {
|
||||
it('rejects missing model', async () => {
|
||||
req.body.model = undefined;
|
||||
await validateModel(req, res, next);
|
||||
expect(handleError).toHaveBeenCalledWith(res, { text: 'Model not provided' });
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('rejects non-string model', async () => {
|
||||
req.body.model = 12345;
|
||||
await validateModel(req, res, next);
|
||||
expect(handleError).toHaveBeenCalledWith(res, { text: 'Model not provided' });
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('rejects model exceeding 256 chars', async () => {
|
||||
req.body.model = 'a'.repeat(257);
|
||||
await validateModel(req, res, next);
|
||||
expect(handleError).toHaveBeenCalledWith(res, { text: 'Invalid model identifier' });
|
||||
});
|
||||
|
||||
it('rejects model with leading special character', async () => {
|
||||
req.body.model = '.bad-model';
|
||||
await validateModel(req, res, next);
|
||||
expect(handleError).toHaveBeenCalledWith(res, { text: 'Invalid model identifier' });
|
||||
});
|
||||
|
||||
it('rejects model with script injection', async () => {
|
||||
req.body.model = '<script>alert(1)</script>';
|
||||
await validateModel(req, res, next);
|
||||
expect(handleError).toHaveBeenCalledWith(res, { text: 'Invalid model identifier' });
|
||||
});
|
||||
|
||||
it('trims whitespace before validation', async () => {
|
||||
req.body.model = ' gpt-4o ';
|
||||
getModelsConfig.mockResolvedValue({ openAI: ['gpt-4o'] });
|
||||
await validateModel(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(handleError).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('rejects model with spaces in the middle', async () => {
|
||||
req.body.model = 'gpt 4o';
|
||||
await validateModel(req, res, next);
|
||||
expect(handleError).toHaveBeenCalledWith(res, { text: 'Invalid model identifier' });
|
||||
});
|
||||
|
||||
it('accepts standard model IDs', async () => {
|
||||
const validModels = [
|
||||
'gpt-4o',
|
||||
'claude-3-5-sonnet-20241022',
|
||||
'us.amazon.nova-pro-v1:0',
|
||||
'qwen/qwen3.6-plus-preview:free',
|
||||
'Meta-Llama-3-8B-Instruct-4bit',
|
||||
];
|
||||
for (const model of validModels) {
|
||||
jest.clearAllMocks();
|
||||
req.body.model = model;
|
||||
getEndpointsConfig.mockResolvedValue({ openAI: { userProvide: false } });
|
||||
getModelsConfig.mockResolvedValue({ openAI: [model] });
|
||||
next.mockClear();
|
||||
|
||||
await validateModel(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(handleError).not.toHaveBeenCalled();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('userProvide early-return', () => {
|
||||
it('calls next() immediately for userProvide endpoints without checking model list', async () => {
|
||||
getEndpointsConfig.mockResolvedValue({
|
||||
openAI: { userProvide: true },
|
||||
});
|
||||
req.body.model = 'any-model-from-user-key';
|
||||
|
||||
await validateModel(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(getModelsConfig).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('does not call getModelsConfig for userProvide endpoints', async () => {
|
||||
getEndpointsConfig.mockResolvedValue({
|
||||
CustomEndpoint: { userProvide: true },
|
||||
});
|
||||
req.body = { model: 'custom-model', endpoint: 'CustomEndpoint' };
|
||||
|
||||
await validateModel(req, res, next);
|
||||
|
||||
expect(getModelsConfig).not.toHaveBeenCalled();
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('system endpoint list validation', () => {
|
||||
it('rejects a model not in the available list', async () => {
|
||||
req.body.model = 'not-in-list';
|
||||
|
||||
await validateModel(req, res, next);
|
||||
|
||||
expect(logViolation).toHaveBeenCalledWith(
|
||||
req,
|
||||
res,
|
||||
ViolationTypes.ILLEGAL_MODEL_REQUEST,
|
||||
expect.any(Object),
|
||||
expect.anything(),
|
||||
);
|
||||
expect(handleError).toHaveBeenCalledWith(res, { text: 'Illegal model request' });
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('accepts a model in the available list', async () => {
|
||||
req.body.model = 'gpt-4o';
|
||||
|
||||
await validateModel(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(handleError).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('rejects when endpoint has no models loaded', async () => {
|
||||
getModelsConfig.mockResolvedValue({ openAI: undefined });
|
||||
|
||||
await validateModel(req, res, next);
|
||||
|
||||
expect(handleError).toHaveBeenCalledWith(res, { text: 'Endpoint models not loaded' });
|
||||
});
|
||||
|
||||
it('rejects when modelsConfig is null', async () => {
|
||||
getModelsConfig.mockResolvedValue(null);
|
||||
|
||||
await validateModel(req, res, next);
|
||||
|
||||
expect(handleError).toHaveBeenCalledWith(res, { text: 'Models not loaded' });
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -18,6 +18,7 @@ const checkDomainAllowed = async (req, res, next) => {
|
|||
const email = req?.user?.email;
|
||||
const appConfig = await getAppConfig({
|
||||
role: req?.user?.role,
|
||||
tenantId: req?.user?.tenantId,
|
||||
});
|
||||
|
||||
if (email && !isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@ const { getAppConfig } = require('~/server/services/Config');
|
|||
const configMiddleware = async (req, res, next) => {
|
||||
try {
|
||||
const userRole = req.user?.role;
|
||||
req.config = await getAppConfig({ role: userRole });
|
||||
const userId = req.user?.id;
|
||||
const tenantId = req.user?.tenantId;
|
||||
req.config = await getAppConfig({ role: userRole, userId, tenantId });
|
||||
|
||||
next();
|
||||
} catch (error) {
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
const cookies = require('cookie');
|
||||
const passport = require('passport');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { isEnabled, tenantContextMiddleware } = require('@librechat/api');
|
||||
|
||||
// This middleware does not require authentication,
|
||||
// but if the user is authenticated, it will set the user object.
|
||||
// but if the user is authenticated, it will set the user object
|
||||
// and establish tenant ALS context.
|
||||
const optionalJwtAuth = (req, res, next) => {
|
||||
const cookieHeader = req.headers.cookie;
|
||||
const tokenProvider = cookieHeader ? cookies.parse(cookieHeader).token_provider : null;
|
||||
|
|
@ -13,6 +14,7 @@ const optionalJwtAuth = (req, res, next) => {
|
|||
}
|
||||
if (user) {
|
||||
req.user = user;
|
||||
return tenantContextMiddleware(req, res, next);
|
||||
}
|
||||
next();
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,20 +1,29 @@
|
|||
const cookies = require('cookie');
|
||||
const passport = require('passport');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { isEnabled, tenantContextMiddleware } = require('@librechat/api');
|
||||
|
||||
/**
|
||||
* Custom Middleware to handle JWT authentication, with support for OpenID token reuse
|
||||
* Switches between JWT and OpenID authentication based on cookies and environment settings
|
||||
* Custom Middleware to handle JWT authentication, with support for OpenID token reuse.
|
||||
* Switches between JWT and OpenID authentication based on cookies and environment settings.
|
||||
*
|
||||
* After successful authentication (req.user populated), automatically chains into
|
||||
* `tenantContextMiddleware` to propagate `req.user.tenantId` into AsyncLocalStorage
|
||||
* for downstream Mongoose tenant isolation.
|
||||
*/
|
||||
const requireJwtAuth = (req, res, next) => {
|
||||
const cookieHeader = req.headers.cookie;
|
||||
const tokenProvider = cookieHeader ? cookies.parse(cookieHeader).token_provider : null;
|
||||
|
||||
if (tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
|
||||
return passport.authenticate('openidJwt', { session: false })(req, res, next);
|
||||
}
|
||||
const strategy =
|
||||
tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS) ? 'openidJwt' : 'jwt';
|
||||
|
||||
return passport.authenticate('jwt', { session: false })(req, res, next);
|
||||
passport.authenticate(strategy, { session: false })(req, res, (err) => {
|
||||
if (err) {
|
||||
return next(err);
|
||||
}
|
||||
// req.user is now populated by passport — set up tenant ALS context
|
||||
tenantContextMiddleware(req, res, next);
|
||||
});
|
||||
};
|
||||
|
||||
module.exports = requireJwtAuth;
|
||||
|
|
|
|||
|
|
@ -1,7 +1,12 @@
|
|||
const { handleError } = require('@librechat/api');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { getModelsConfig } = require('~/server/controllers/ModelController');
|
||||
const { getEndpointsConfig } = require('~/server/services/Config');
|
||||
const { logViolation } = require('~/cache');
|
||||
|
||||
const MAX_MODEL_STRING_LENGTH = 256;
|
||||
const MODEL_PATTERN = /^[a-zA-Z0-9][a-zA-Z0-9_.:/@+-]*$/;
|
||||
|
||||
/**
|
||||
* Validates the model of the request.
|
||||
*
|
||||
|
|
@ -11,11 +16,27 @@ const { logViolation } = require('~/cache');
|
|||
* @param {Function} next - The Express next function.
|
||||
*/
|
||||
const validateModel = async (req, res, next) => {
|
||||
const { model, endpoint } = req.body;
|
||||
if (!model) {
|
||||
const { endpoint } = req.body;
|
||||
const rawModel = req.body.model;
|
||||
|
||||
if (!rawModel || typeof rawModel !== 'string') {
|
||||
return handleError(res, { text: 'Model not provided' });
|
||||
}
|
||||
|
||||
const model = rawModel.trim();
|
||||
if (!model || model.length > MAX_MODEL_STRING_LENGTH || !MODEL_PATTERN.test(model)) {
|
||||
return handleError(res, { text: 'Invalid model identifier' });
|
||||
}
|
||||
|
||||
req.body.model = model;
|
||||
|
||||
const endpointsConfig = await getEndpointsConfig(req);
|
||||
const endpointConfig = endpointsConfig?.[endpoint];
|
||||
|
||||
if (endpointConfig?.userProvide) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const modelsConfig = await getModelsConfig(req);
|
||||
|
||||
if (!modelsConfig) {
|
||||
|
|
|
|||
|
|
@ -1,25 +1,73 @@
|
|||
jest.mock('~/cache/getLogStores');
|
||||
|
||||
const mockGetAppConfig = jest.fn();
|
||||
jest.mock('~/server/services/Config/app', () => ({
|
||||
getAppConfig: (...args) => mockGetAppConfig(...args),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config/ldap', () => ({
|
||||
getLdapConfig: jest.fn(() => null),
|
||||
}));
|
||||
|
||||
const mockGetTenantId = jest.fn(() => undefined);
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
...jest.requireActual('@librechat/data-schemas'),
|
||||
getTenantId: (...args) => mockGetTenantId(...args),
|
||||
}));
|
||||
|
||||
const request = require('supertest');
|
||||
const express = require('express');
|
||||
const configRoute = require('../config');
|
||||
// file deepcode ignore UseCsurfForExpress/test: test
|
||||
const app = express();
|
||||
app.disable('x-powered-by');
|
||||
app.use('/api/config', configRoute);
|
||||
|
||||
function createApp(user) {
|
||||
const app = express();
|
||||
app.disable('x-powered-by');
|
||||
if (user) {
|
||||
app.use((req, _res, next) => {
|
||||
req.user = user;
|
||||
next();
|
||||
});
|
||||
}
|
||||
app.use('/api/config', configRoute);
|
||||
return app;
|
||||
}
|
||||
|
||||
const baseAppConfig = {
|
||||
registration: { socialLogins: ['google', 'github'] },
|
||||
interfaceConfig: {
|
||||
privacyPolicy: { externalUrl: 'https://example.com/privacy' },
|
||||
termsOfService: { externalUrl: 'https://example.com/tos' },
|
||||
modelSelect: true,
|
||||
},
|
||||
turnstileConfig: { siteKey: 'test-key' },
|
||||
modelSpecs: { list: [{ name: 'test-spec' }] },
|
||||
webSearch: { searchProvider: 'tavily' },
|
||||
};
|
||||
|
||||
const mockUser = {
|
||||
id: 'user123',
|
||||
role: 'USER',
|
||||
tenantId: undefined,
|
||||
};
|
||||
|
||||
afterEach(() => {
|
||||
jest.resetAllMocks();
|
||||
delete process.env.APP_TITLE;
|
||||
delete process.env.CHECK_BALANCE;
|
||||
delete process.env.START_BALANCE;
|
||||
delete process.env.SANDPACK_BUNDLER_URL;
|
||||
delete process.env.SANDPACK_STATIC_BUNDLER_URL;
|
||||
delete process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES;
|
||||
delete process.env.ALLOW_REGISTRATION;
|
||||
delete process.env.ALLOW_SOCIAL_LOGIN;
|
||||
delete process.env.ALLOW_PASSWORD_RESET;
|
||||
delete process.env.DOMAIN_SERVER;
|
||||
delete process.env.GOOGLE_CLIENT_ID;
|
||||
delete process.env.GOOGLE_CLIENT_SECRET;
|
||||
delete process.env.FACEBOOK_CLIENT_ID;
|
||||
delete process.env.FACEBOOK_CLIENT_SECRET;
|
||||
delete process.env.OPENID_CLIENT_ID;
|
||||
delete process.env.OPENID_CLIENT_SECRET;
|
||||
delete process.env.OPENID_ISSUER;
|
||||
delete process.env.OPENID_SESSION_SECRET;
|
||||
delete process.env.OPENID_BUTTON_LABEL;
|
||||
delete process.env.OPENID_AUTO_REDIRECT;
|
||||
delete process.env.OPENID_AUTH_URL;
|
||||
delete process.env.GITHUB_CLIENT_ID;
|
||||
delete process.env.GITHUB_CLIENT_SECRET;
|
||||
delete process.env.DISCORD_CLIENT_ID;
|
||||
|
|
@ -28,78 +76,215 @@ afterEach(() => {
|
|||
delete process.env.SAML_ISSUER;
|
||||
delete process.env.SAML_CERT;
|
||||
delete process.env.SAML_SESSION_SECRET;
|
||||
delete process.env.SAML_BUTTON_LABEL;
|
||||
delete process.env.SAML_IMAGE_URL;
|
||||
delete process.env.DOMAIN_SERVER;
|
||||
delete process.env.ALLOW_REGISTRATION;
|
||||
delete process.env.ALLOW_SOCIAL_LOGIN;
|
||||
delete process.env.ALLOW_PASSWORD_RESET;
|
||||
delete process.env.LDAP_URL;
|
||||
delete process.env.LDAP_BIND_DN;
|
||||
delete process.env.LDAP_BIND_CREDENTIALS;
|
||||
delete process.env.LDAP_USER_SEARCH_BASE;
|
||||
delete process.env.LDAP_SEARCH_FILTER;
|
||||
});
|
||||
|
||||
//TODO: This works/passes locally but http request tests fail with 404 in CI. Need to figure out why.
|
||||
describe('GET /api/config', () => {
|
||||
describe('unauthenticated (no req.user)', () => {
|
||||
it('should call getAppConfig with baseOnly when no tenant context', async () => {
|
||||
mockGetAppConfig.mockResolvedValue(baseAppConfig);
|
||||
mockGetTenantId.mockReturnValue(undefined);
|
||||
const app = createApp(null);
|
||||
|
||||
describe.skip('GET /', () => {
|
||||
it('should return 200 and the correct body', async () => {
|
||||
process.env.APP_TITLE = 'Test Title';
|
||||
process.env.GOOGLE_CLIENT_ID = 'Test Google Client Id';
|
||||
process.env.GOOGLE_CLIENT_SECRET = 'Test Google Client Secret';
|
||||
process.env.FACEBOOK_CLIENT_ID = 'Test Facebook Client Id';
|
||||
process.env.FACEBOOK_CLIENT_SECRET = 'Test Facebook Client Secret';
|
||||
process.env.OPENID_CLIENT_ID = 'Test OpenID Id';
|
||||
process.env.OPENID_CLIENT_SECRET = 'Test OpenID Secret';
|
||||
process.env.OPENID_ISSUER = 'Test OpenID Issuer';
|
||||
process.env.OPENID_SESSION_SECRET = 'Test Secret';
|
||||
process.env.OPENID_BUTTON_LABEL = 'Test OpenID';
|
||||
process.env.OPENID_AUTH_URL = 'http://test-server.com';
|
||||
process.env.GITHUB_CLIENT_ID = 'Test Github client Id';
|
||||
process.env.GITHUB_CLIENT_SECRET = 'Test Github client Secret';
|
||||
process.env.DISCORD_CLIENT_ID = 'Test Discord client Id';
|
||||
process.env.DISCORD_CLIENT_SECRET = 'Test Discord client Secret';
|
||||
process.env.SAML_ENTRY_POINT = 'http://test-server.com';
|
||||
process.env.SAML_ISSUER = 'Test SAML Issuer';
|
||||
process.env.SAML_CERT = 'saml.pem';
|
||||
process.env.SAML_SESSION_SECRET = 'Test Secret';
|
||||
process.env.SAML_BUTTON_LABEL = 'Test SAML';
|
||||
process.env.SAML_IMAGE_URL = 'http://test-server.com';
|
||||
process.env.DOMAIN_SERVER = 'http://test-server.com';
|
||||
process.env.ALLOW_REGISTRATION = 'true';
|
||||
process.env.ALLOW_SOCIAL_LOGIN = 'true';
|
||||
process.env.ALLOW_PASSWORD_RESET = 'true';
|
||||
process.env.LDAP_URL = 'Test LDAP URL';
|
||||
process.env.LDAP_BIND_DN = 'Test LDAP Bind DN';
|
||||
process.env.LDAP_BIND_CREDENTIALS = 'Test LDAP Bind Credentials';
|
||||
process.env.LDAP_USER_SEARCH_BASE = 'Test LDAP User Search Base';
|
||||
process.env.LDAP_SEARCH_FILTER = 'Test LDAP Search Filter';
|
||||
await request(app).get('/api/config');
|
||||
|
||||
const response = await request(app).get('/');
|
||||
expect(mockGetAppConfig).toHaveBeenCalledWith({ baseOnly: true });
|
||||
});
|
||||
|
||||
expect(response.statusCode).toBe(200);
|
||||
expect(response.body).toEqual({
|
||||
appTitle: 'Test Title',
|
||||
socialLogins: ['google', 'facebook', 'openid', 'github', 'discord', 'saml'],
|
||||
discordLoginEnabled: true,
|
||||
facebookLoginEnabled: true,
|
||||
githubLoginEnabled: true,
|
||||
googleLoginEnabled: true,
|
||||
openidLoginEnabled: true,
|
||||
openidLabel: 'Test OpenID',
|
||||
openidImageUrl: 'http://test-server.com',
|
||||
samlLoginEnabled: true,
|
||||
samlLabel: 'Test SAML',
|
||||
samlImageUrl: 'http://test-server.com',
|
||||
ldap: {
|
||||
enabled: true,
|
||||
},
|
||||
serverDomain: 'http://test-server.com',
|
||||
emailLoginEnabled: 'true',
|
||||
registrationEnabled: 'true',
|
||||
passwordResetEnabled: 'true',
|
||||
socialLoginEnabled: 'true',
|
||||
it('should call getAppConfig with tenantId when tenant context is present', async () => {
|
||||
mockGetAppConfig.mockResolvedValue(baseAppConfig);
|
||||
mockGetTenantId.mockReturnValue('tenant-abc');
|
||||
const app = createApp(null);
|
||||
|
||||
await request(app).get('/api/config');
|
||||
|
||||
expect(mockGetAppConfig).toHaveBeenCalledWith({ tenantId: 'tenant-abc' });
|
||||
});
|
||||
|
||||
it('should map tenant-scoped config fields in unauthenticated response', async () => {
|
||||
const tenantConfig = {
|
||||
...baseAppConfig,
|
||||
registration: { socialLogins: ['saml'] },
|
||||
turnstileConfig: { siteKey: 'tenant-key' },
|
||||
};
|
||||
mockGetAppConfig.mockResolvedValue(tenantConfig);
|
||||
mockGetTenantId.mockReturnValue('tenant-abc');
|
||||
const app = createApp(null);
|
||||
|
||||
const response = await request(app).get('/api/config');
|
||||
|
||||
expect(response.statusCode).toBe(200);
|
||||
expect(response.body.socialLogins).toEqual(['saml']);
|
||||
expect(response.body.turnstile).toEqual({ siteKey: 'tenant-key' });
|
||||
expect(response.body).not.toHaveProperty('modelSpecs');
|
||||
});
|
||||
|
||||
it('should return minimal payload without authenticated-only fields', async () => {
|
||||
mockGetAppConfig.mockResolvedValue(baseAppConfig);
|
||||
const app = createApp(null);
|
||||
|
||||
const response = await request(app).get('/api/config');
|
||||
|
||||
expect(response.statusCode).toBe(200);
|
||||
expect(response.body).not.toHaveProperty('modelSpecs');
|
||||
expect(response.body).not.toHaveProperty('balance');
|
||||
expect(response.body).not.toHaveProperty('webSearch');
|
||||
expect(response.body).not.toHaveProperty('bundlerURL');
|
||||
expect(response.body).not.toHaveProperty('staticBundlerURL');
|
||||
expect(response.body).not.toHaveProperty('sharePointFilePickerEnabled');
|
||||
expect(response.body).not.toHaveProperty('conversationImportMaxFileSize');
|
||||
});
|
||||
|
||||
it('should include socialLogins and turnstile from base config', async () => {
|
||||
mockGetAppConfig.mockResolvedValue(baseAppConfig);
|
||||
const app = createApp(null);
|
||||
|
||||
const response = await request(app).get('/api/config');
|
||||
|
||||
expect(response.body.socialLogins).toEqual(['google', 'github']);
|
||||
expect(response.body.turnstile).toEqual({ siteKey: 'test-key' });
|
||||
});
|
||||
|
||||
it('should include only privacyPolicy and termsOfService from interface config', async () => {
|
||||
mockGetAppConfig.mockResolvedValue(baseAppConfig);
|
||||
const app = createApp(null);
|
||||
|
||||
const response = await request(app).get('/api/config');
|
||||
|
||||
expect(response.body.interface).toEqual({
|
||||
privacyPolicy: { externalUrl: 'https://example.com/privacy' },
|
||||
termsOfService: { externalUrl: 'https://example.com/tos' },
|
||||
});
|
||||
expect(response.body.interface).not.toHaveProperty('modelSelect');
|
||||
});
|
||||
|
||||
it('should not include interface if no privacyPolicy or termsOfService', async () => {
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
...baseAppConfig,
|
||||
interfaceConfig: { modelSelect: true },
|
||||
});
|
||||
const app = createApp(null);
|
||||
|
||||
const response = await request(app).get('/api/config');
|
||||
|
||||
expect(response.body).not.toHaveProperty('interface');
|
||||
});
|
||||
|
||||
it('should include shared env var fields', async () => {
|
||||
mockGetAppConfig.mockResolvedValue(baseAppConfig);
|
||||
process.env.APP_TITLE = 'Test App';
|
||||
const app = createApp(null);
|
||||
|
||||
const response = await request(app).get('/api/config');
|
||||
|
||||
expect(response.body.appTitle).toBe('Test App');
|
||||
expect(response.body).toHaveProperty('emailLoginEnabled');
|
||||
expect(response.body).toHaveProperty('serverDomain');
|
||||
});
|
||||
|
||||
it('should return 500 when getAppConfig throws', async () => {
|
||||
mockGetAppConfig.mockRejectedValue(new Error('Config service failure'));
|
||||
const app = createApp(null);
|
||||
|
||||
const response = await request(app).get('/api/config');
|
||||
|
||||
expect(response.statusCode).toBe(500);
|
||||
expect(response.body).toHaveProperty('error');
|
||||
});
|
||||
});
|
||||
|
||||
describe('authenticated (req.user exists)', () => {
|
||||
it('should call getAppConfig with role, userId, and tenantId', async () => {
|
||||
mockGetAppConfig.mockResolvedValue(baseAppConfig);
|
||||
mockGetTenantId.mockReturnValue('fallback-tenant');
|
||||
const app = createApp(mockUser);
|
||||
|
||||
await request(app).get('/api/config');
|
||||
|
||||
expect(mockGetAppConfig).toHaveBeenCalledWith({
|
||||
role: 'USER',
|
||||
userId: 'user123',
|
||||
tenantId: 'fallback-tenant',
|
||||
});
|
||||
});
|
||||
|
||||
it('should prefer user tenantId over getTenantId fallback', async () => {
|
||||
mockGetAppConfig.mockResolvedValue(baseAppConfig);
|
||||
mockGetTenantId.mockReturnValue('fallback-tenant');
|
||||
const app = createApp({ ...mockUser, tenantId: 'user-tenant' });
|
||||
|
||||
await request(app).get('/api/config');
|
||||
|
||||
expect(mockGetAppConfig).toHaveBeenCalledWith({
|
||||
role: 'USER',
|
||||
userId: 'user123',
|
||||
tenantId: 'user-tenant',
|
||||
});
|
||||
});
|
||||
|
||||
it('should include modelSpecs, balance, and webSearch', async () => {
|
||||
mockGetAppConfig.mockResolvedValue(baseAppConfig);
|
||||
process.env.CHECK_BALANCE = 'true';
|
||||
process.env.START_BALANCE = '10000';
|
||||
const app = createApp(mockUser);
|
||||
|
||||
const response = await request(app).get('/api/config');
|
||||
|
||||
expect(response.body.modelSpecs).toEqual({ list: [{ name: 'test-spec' }] });
|
||||
expect(response.body.balance).toEqual({ enabled: true, startBalance: 10000 });
|
||||
expect(response.body.webSearch).toEqual({ searchProvider: 'tavily' });
|
||||
});
|
||||
|
||||
it('should include full interface config', async () => {
|
||||
mockGetAppConfig.mockResolvedValue(baseAppConfig);
|
||||
const app = createApp(mockUser);
|
||||
|
||||
const response = await request(app).get('/api/config');
|
||||
|
||||
expect(response.body.interface).toEqual(baseAppConfig.interfaceConfig);
|
||||
});
|
||||
|
||||
it('should include authenticated-only env var fields', async () => {
|
||||
mockGetAppConfig.mockResolvedValue(baseAppConfig);
|
||||
process.env.SANDPACK_BUNDLER_URL = 'https://bundler.test';
|
||||
process.env.SANDPACK_STATIC_BUNDLER_URL = 'https://static-bundler.test';
|
||||
process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = '5000000';
|
||||
const app = createApp(mockUser);
|
||||
|
||||
const response = await request(app).get('/api/config');
|
||||
|
||||
expect(response.body.bundlerURL).toBe('https://bundler.test');
|
||||
expect(response.body.staticBundlerURL).toBe('https://static-bundler.test');
|
||||
expect(response.body.conversationImportMaxFileSize).toBe(5000000);
|
||||
});
|
||||
|
||||
it('should merge per-user balance override into config', async () => {
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
...baseAppConfig,
|
||||
balance: {
|
||||
enabled: true,
|
||||
startBalance: 50000,
|
||||
},
|
||||
});
|
||||
const app = createApp(mockUser);
|
||||
|
||||
const response = await request(app).get('/api/config');
|
||||
|
||||
expect(response.body.balance).toEqual(
|
||||
expect.objectContaining({
|
||||
enabled: true,
|
||||
startBalance: 50000,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should return 500 when getAppConfig throws', async () => {
|
||||
mockGetAppConfig.mockRejectedValue(new Error('Config service failure'));
|
||||
const app = createApp(mockUser);
|
||||
|
||||
const response = await request(app).get('/api/config');
|
||||
|
||||
expect(response.statusCode).toBe(500);
|
||||
expect(response.body).toHaveProperty('error');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
185
api/server/routes/__tests__/grants.spec.js
Normal file
185
api/server/routes/__tests__/grants.spec.js
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
const express = require('express');
|
||||
const request = require('supertest');
|
||||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { createModels, createMethods } = require('@librechat/data-schemas');
|
||||
const { PrincipalType, SystemRoles } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
* Integration test for the admin grants routes.
|
||||
*
|
||||
* Validates the full Express wiring: route registration → middleware →
|
||||
* handler → real MongoDB. Auth middleware is injected (matching the repo
|
||||
* pattern in keys.spec.js) so we can control the caller identity without
|
||||
* a real JWT, while the handler DI deps use real DB methods.
|
||||
*/
|
||||
|
||||
jest.mock('~/server/middleware', () => ({
|
||||
requireJwtAuth: (_req, _res, next) => next(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/middleware/roles/capabilities', () => ({
|
||||
requireCapability: () => (_req, _res, next) => next(),
|
||||
}));
|
||||
|
||||
let mongoServer;
|
||||
let db;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
await mongoose.connect(mongoServer.getUri());
|
||||
createModels(mongoose);
|
||||
db = createMethods(mongoose);
|
||||
await db.seedSystemGrants();
|
||||
await db.initializeRoles();
|
||||
await db.seedDefaultRoles();
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
const SystemGrant = mongoose.models.SystemGrant;
|
||||
// Clean non-seed grants (keep admin seed)
|
||||
await SystemGrant.deleteMany({
|
||||
$or: [
|
||||
{ principalId: { $ne: SystemRoles.ADMIN } },
|
||||
{ principalType: { $ne: PrincipalType.ROLE } },
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
function createApp(user) {
|
||||
const { createAdminGrantsHandlers, getCachedPrincipals } = require('@librechat/api');
|
||||
|
||||
const handlers = createAdminGrantsHandlers({
|
||||
listGrants: db.listGrants,
|
||||
countGrants: db.countGrants,
|
||||
getCapabilitiesForPrincipal: db.getCapabilitiesForPrincipal,
|
||||
getCapabilitiesForPrincipals: db.getCapabilitiesForPrincipals,
|
||||
grantCapability: db.grantCapability,
|
||||
revokeCapability: db.revokeCapability,
|
||||
getUserPrincipals: db.getUserPrincipals,
|
||||
hasCapabilityForPrincipals: db.hasCapabilityForPrincipals,
|
||||
getHeldCapabilities: db.getHeldCapabilities,
|
||||
getCachedPrincipals,
|
||||
checkRoleExists: async (name) => (await db.getRoleByName(name)) != null,
|
||||
});
|
||||
|
||||
const app = express();
|
||||
app.use(express.json());
|
||||
app.use((req, _res, next) => {
|
||||
req.user = user;
|
||||
next();
|
||||
});
|
||||
|
||||
const router = express.Router();
|
||||
router.get('/', handlers.listGrants);
|
||||
router.get('/effective', handlers.getEffectiveCapabilities);
|
||||
router.get('/:principalType/:principalId', handlers.getPrincipalGrants);
|
||||
router.post('/', handlers.assignGrant);
|
||||
router.delete('/:principalType/:principalId/:capability', handlers.revokeGrant);
|
||||
app.use('/api/admin/grants', router);
|
||||
|
||||
return app;
|
||||
}
|
||||
|
||||
describe('Admin Grants Routes — Integration', () => {
|
||||
const adminUserId = new mongoose.Types.ObjectId();
|
||||
const adminUser = {
|
||||
_id: adminUserId,
|
||||
id: adminUserId.toString(),
|
||||
role: SystemRoles.ADMIN,
|
||||
};
|
||||
|
||||
it('GET / returns seeded admin grants', async () => {
|
||||
const app = createApp(adminUser);
|
||||
const res = await request(app).get('/api/admin/grants').expect(200);
|
||||
|
||||
expect(res.body).toHaveProperty('grants');
|
||||
expect(res.body).toHaveProperty('total');
|
||||
expect(res.body.grants.length).toBeGreaterThan(0);
|
||||
// Seeded grants are for the ADMIN role
|
||||
expect(res.body.grants[0].principalType).toBe(PrincipalType.ROLE);
|
||||
});
|
||||
|
||||
it('GET /effective returns capabilities for admin', async () => {
|
||||
const app = createApp(adminUser);
|
||||
const res = await request(app).get('/api/admin/grants/effective').expect(200);
|
||||
|
||||
expect(res.body).toHaveProperty('capabilities');
|
||||
expect(res.body.capabilities).toContain('access:admin');
|
||||
expect(res.body.capabilities).toContain('manage:roles');
|
||||
});
|
||||
|
||||
it('POST / assigns a grant and DELETE / revokes it', async () => {
|
||||
const app = createApp(adminUser);
|
||||
|
||||
// Assign
|
||||
const assignRes = await request(app)
|
||||
.post('/api/admin/grants')
|
||||
.send({
|
||||
principalType: PrincipalType.ROLE,
|
||||
principalId: SystemRoles.USER,
|
||||
capability: 'read:users',
|
||||
})
|
||||
.expect(201);
|
||||
|
||||
expect(assignRes.body.grant).toMatchObject({
|
||||
principalType: PrincipalType.ROLE,
|
||||
principalId: SystemRoles.USER,
|
||||
capability: 'read:users',
|
||||
});
|
||||
|
||||
// Verify via GET
|
||||
const getRes = await request(app)
|
||||
.get(`/api/admin/grants/${PrincipalType.ROLE}/${SystemRoles.USER}`)
|
||||
.expect(200);
|
||||
|
||||
expect(getRes.body.grants.some((g) => g.capability === 'read:users')).toBe(true);
|
||||
|
||||
// Revoke
|
||||
await request(app)
|
||||
.delete(`/api/admin/grants/${PrincipalType.ROLE}/${SystemRoles.USER}/read:users`)
|
||||
.expect(200);
|
||||
|
||||
// Verify revoked
|
||||
const afterRes = await request(app)
|
||||
.get(`/api/admin/grants/${PrincipalType.ROLE}/${SystemRoles.USER}`)
|
||||
.expect(200);
|
||||
|
||||
expect(afterRes.body.grants.some((g) => g.capability === 'read:users')).toBe(false);
|
||||
});
|
||||
|
||||
it('POST / returns 400 for non-existent role when checkRoleExists is wired', async () => {
|
||||
const app = createApp(adminUser);
|
||||
|
||||
const res = await request(app)
|
||||
.post('/api/admin/grants')
|
||||
.send({
|
||||
principalType: PrincipalType.ROLE,
|
||||
principalId: 'nonexistent-role',
|
||||
capability: 'read:users',
|
||||
})
|
||||
.expect(400);
|
||||
|
||||
expect(res.body.error).toBe('Role not found');
|
||||
});
|
||||
|
||||
it('POST / returns 401 without authenticated user', async () => {
|
||||
const app = createApp(undefined);
|
||||
|
||||
const res = await request(app)
|
||||
.post('/api/admin/grants')
|
||||
.send({
|
||||
principalType: PrincipalType.ROLE,
|
||||
principalId: SystemRoles.USER,
|
||||
capability: 'read:users',
|
||||
})
|
||||
.expect(401);
|
||||
|
||||
expect(res.body).toHaveProperty('error', 'Authentication required');
|
||||
});
|
||||
});
|
||||
|
|
@ -18,6 +18,7 @@ const mockRegistryInstance = {
|
|||
getServerConfig: jest.fn(),
|
||||
getOAuthServers: jest.fn(),
|
||||
getAllServerConfigs: jest.fn(),
|
||||
ensureConfigServers: jest.fn().mockResolvedValue({}),
|
||||
addServer: jest.fn(),
|
||||
updateServer: jest.fn(),
|
||||
removeServer: jest.fn(),
|
||||
|
|
@ -58,6 +59,7 @@ jest.mock('@librechat/api', () => {
|
|||
});
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
getTenantId: jest.fn(),
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
info: jest.fn(),
|
||||
|
|
@ -93,14 +95,18 @@ jest.mock('~/server/services/Config', () => ({
|
|||
getCachedTools: jest.fn(),
|
||||
getMCPServerTools: jest.fn(),
|
||||
loadCustomConfig: jest.fn(),
|
||||
getAppConfig: jest.fn().mockResolvedValue({ mcpConfig: {} }),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config/mcp', () => ({
|
||||
updateMCPServerTools: jest.fn(),
|
||||
}));
|
||||
|
||||
const mockResolveAllMcpConfigs = jest.fn().mockResolvedValue({});
|
||||
jest.mock('~/server/services/MCP', () => ({
|
||||
getMCPSetupData: jest.fn(),
|
||||
resolveConfigServers: jest.fn().mockResolvedValue({}),
|
||||
resolveAllMcpConfigs: (...args) => mockResolveAllMcpConfigs(...args),
|
||||
getServerConnectionStatus: jest.fn(),
|
||||
}));
|
||||
|
||||
|
|
@ -579,6 +585,112 @@ describe('MCP Routes', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('should use oauthHeaders from flow state when present', async () => {
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue({ status: 'PENDING' }),
|
||||
completeFlow: jest.fn().mockResolvedValue(),
|
||||
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||
};
|
||||
const mockFlowState = {
|
||||
serverName: 'test-server',
|
||||
userId: 'test-user-id',
|
||||
metadata: { toolFlowId: 'tool-flow-123' },
|
||||
clientInfo: {},
|
||||
codeVerifier: 'test-verifier',
|
||||
oauthHeaders: { 'X-Custom-Auth': 'header-value' },
|
||||
};
|
||||
const mockTokens = { access_token: 'tok', refresh_token: 'ref' };
|
||||
|
||||
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
require('~/config').getOAuthReconnectionManager.mockReturnValue({
|
||||
clearReconnection: jest.fn(),
|
||||
});
|
||||
require('~/config').getMCPManager.mockReturnValue({
|
||||
getUserConnection: jest.fn().mockResolvedValue({
|
||||
fetchTools: jest.fn().mockResolvedValue([]),
|
||||
}),
|
||||
});
|
||||
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
|
||||
getCachedTools.mockResolvedValue({});
|
||||
setCachedTools.mockResolvedValue();
|
||||
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const csrfToken = generateTestCsrfToken(flowId);
|
||||
|
||||
await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.set('Cookie', [`oauth_csrf=${csrfToken}`])
|
||||
.query({ code: 'auth-code', state: flowId });
|
||||
|
||||
expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith(
|
||||
flowId,
|
||||
'auth-code',
|
||||
mockFlowManager,
|
||||
{ 'X-Custom-Auth': 'header-value' },
|
||||
);
|
||||
expect(mockRegistryInstance.getServerConfig).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should fall back to registry oauth_headers when flow state lacks them', async () => {
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue({ status: 'PENDING' }),
|
||||
completeFlow: jest.fn().mockResolvedValue(),
|
||||
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||
};
|
||||
const mockFlowState = {
|
||||
serverName: 'test-server',
|
||||
userId: 'test-user-id',
|
||||
metadata: { toolFlowId: 'tool-flow-123' },
|
||||
clientInfo: {},
|
||||
codeVerifier: 'test-verifier',
|
||||
};
|
||||
const mockTokens = { access_token: 'tok', refresh_token: 'ref' };
|
||||
|
||||
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
oauth_headers: { 'X-Registry-Header': 'from-registry' },
|
||||
});
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
require('~/config').getOAuthReconnectionManager.mockReturnValue({
|
||||
clearReconnection: jest.fn(),
|
||||
});
|
||||
require('~/config').getMCPManager.mockReturnValue({
|
||||
getUserConnection: jest.fn().mockResolvedValue({
|
||||
fetchTools: jest.fn().mockResolvedValue([]),
|
||||
}),
|
||||
});
|
||||
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
|
||||
getCachedTools.mockResolvedValue({});
|
||||
setCachedTools.mockResolvedValue();
|
||||
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const csrfToken = generateTestCsrfToken(flowId);
|
||||
|
||||
await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.set('Cookie', [`oauth_csrf=${csrfToken}`])
|
||||
.query({ code: 'auth-code', state: flowId });
|
||||
|
||||
expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith(
|
||||
flowId,
|
||||
'auth-code',
|
||||
mockFlowManager,
|
||||
{ 'X-Registry-Header': 'from-registry' },
|
||||
);
|
||||
expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
'test-user-id',
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('should redirect to error page when callback processing fails', async () => {
|
||||
MCPOAuthHandler.getFlowState.mockRejectedValue(new Error('Callback error'));
|
||||
const flowId = 'test-user-id:test-server';
|
||||
|
|
@ -1350,19 +1462,10 @@ describe('MCP Routes', () => {
|
|||
},
|
||||
});
|
||||
|
||||
expect(getMCPSetupData).toHaveBeenCalledWith('test-user-id');
|
||||
expect(getMCPSetupData).toHaveBeenCalledWith('test-user-id', expect.any(Object));
|
||||
expect(getServerConnectionStatus).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should return 404 when MCP config is not found', async () => {
|
||||
getMCPSetupData.mockRejectedValue(new Error('MCP config not found'));
|
||||
|
||||
const response = await request(app).get('/api/mcp/connection/status');
|
||||
|
||||
expect(response.status).toBe(404);
|
||||
expect(response.body).toEqual({ error: 'MCP config not found' });
|
||||
});
|
||||
|
||||
it('should return 500 when connection status check fails', async () => {
|
||||
getMCPSetupData.mockRejectedValue(new Error('Database error'));
|
||||
|
||||
|
|
@ -1437,15 +1540,6 @@ describe('MCP Routes', () => {
|
|||
});
|
||||
});
|
||||
|
||||
it('should return 404 when MCP config is not found', async () => {
|
||||
getMCPSetupData.mockRejectedValue(new Error('MCP config not found'));
|
||||
|
||||
const response = await request(app).get('/api/mcp/connection/status/test-server');
|
||||
|
||||
expect(response.status).toBe(404);
|
||||
expect(response.body).toEqual({ error: 'MCP config not found' });
|
||||
});
|
||||
|
||||
it('should return 500 when connection status check fails', async () => {
|
||||
getMCPSetupData.mockRejectedValue(new Error('Database connection failed'));
|
||||
|
||||
|
|
@ -1704,7 +1798,7 @@ describe('MCP Routes', () => {
|
|||
},
|
||||
};
|
||||
|
||||
mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockServerConfigs);
|
||||
mockResolveAllMcpConfigs.mockResolvedValue(mockServerConfigs);
|
||||
|
||||
const response = await request(app).get('/api/mcp/servers');
|
||||
|
||||
|
|
@ -1721,11 +1815,14 @@ describe('MCP Routes', () => {
|
|||
});
|
||||
expect(response.body['server-1'].headers).toBeUndefined();
|
||||
expect(response.body['server-2'].headers).toBeUndefined();
|
||||
expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith('test-user-id');
|
||||
expect(mockResolveAllMcpConfigs).toHaveBeenCalledWith(
|
||||
'test-user-id',
|
||||
expect.objectContaining({ id: 'test-user-id' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should return empty object when no servers are configured', async () => {
|
||||
mockRegistryInstance.getAllServerConfigs.mockResolvedValue({});
|
||||
mockResolveAllMcpConfigs.mockResolvedValue({});
|
||||
|
||||
const response = await request(app).get('/api/mcp/servers');
|
||||
|
||||
|
|
@ -1749,7 +1846,7 @@ describe('MCP Routes', () => {
|
|||
});
|
||||
|
||||
it('should return 500 when server config retrieval fails', async () => {
|
||||
mockRegistryInstance.getAllServerConfigs.mockRejectedValue(new Error('Database error'));
|
||||
mockResolveAllMcpConfigs.mockRejectedValue(new Error('Database error'));
|
||||
|
||||
const response = await request(app).get('/api/mcp/servers');
|
||||
|
||||
|
|
@ -1939,11 +2036,12 @@ describe('MCP Routes', () => {
|
|||
expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
'test-user-id',
|
||||
{},
|
||||
);
|
||||
});
|
||||
|
||||
it('should return 404 when server not found', async () => {
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue(null);
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue(undefined);
|
||||
|
||||
const response = await request(app).get('/api/mcp/servers/non-existent-server');
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
const express = require('express');
|
||||
const passport = require('passport');
|
||||
const { randomState } = require('openid-client');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const crypto = require('node:crypto');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { SystemCapabilities } = require('@librechat/data-schemas');
|
||||
const { logger, SystemCapabilities } = require('@librechat/data-schemas');
|
||||
const { getAdminPanelUrl, exchangeAdminCode, createSetBalanceConfig } = require('@librechat/api');
|
||||
const { loginController } = require('~/server/controllers/auth/LoginController');
|
||||
const { requireCapability } = require('~/server/middleware/roles/capabilities');
|
||||
|
|
@ -24,6 +23,28 @@ const setBalanceConfig = createSetBalanceConfig({
|
|||
|
||||
const router = express.Router();
|
||||
|
||||
function resolveRequestOrigin(req) {
|
||||
const originHeader = req.get('origin');
|
||||
if (originHeader) {
|
||||
try {
|
||||
return new URL(originHeader).origin;
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
const refererHeader = req.get('referer');
|
||||
if (!refererHeader) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
try {
|
||||
return new URL(refererHeader).origin;
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
router.post(
|
||||
'/login/local',
|
||||
middleware.logHeaders,
|
||||
|
|
@ -52,28 +73,340 @@ router.get('/oauth/openid/check', (req, res) => {
|
|||
res.status(200).json({ message: 'OpenID check successful' });
|
||||
});
|
||||
|
||||
router.get('/oauth/openid', (req, res, next) => {
|
||||
/** PKCE challenge cache TTL: 5 minutes (enough for user to authenticate with IdP) */
|
||||
const PKCE_CHALLENGE_TTL = 5 * 60 * 1000;
|
||||
/** Regex pattern for valid PKCE challenges: 64 hex characters (SHA-256 hex digest) */
|
||||
const PKCE_CHALLENGE_PATTERN = /^[a-f0-9]{64}$/;
|
||||
|
||||
/**
|
||||
* Generates a random hex state string for OAuth flows.
|
||||
* @returns {string} A 32-byte random hex string.
|
||||
*/
|
||||
function generateState() {
|
||||
return crypto.randomBytes(32).toString('hex');
|
||||
}
|
||||
|
||||
/**
|
||||
* Stores a PKCE challenge in cache keyed by state.
|
||||
* @param {string} state - The OAuth state value.
|
||||
* @param {string | undefined} codeChallenge - The PKCE code_challenge from query params.
|
||||
* @param {string} provider - Provider name for logging.
|
||||
* @returns {Promise<boolean>} True if stored successfully or no challenge provided.
|
||||
*/
|
||||
async function storePkceChallenge(state, codeChallenge, provider) {
|
||||
if (typeof codeChallenge !== 'string' || !PKCE_CHALLENGE_PATTERN.test(codeChallenge)) {
|
||||
return true;
|
||||
}
|
||||
try {
|
||||
const cache = getLogStores(CacheKeys.ADMIN_OAUTH_EXCHANGE);
|
||||
await cache.set(`pkce:${state}`, codeChallenge, PKCE_CHALLENGE_TTL);
|
||||
return true;
|
||||
} catch (err) {
|
||||
logger.error(`[admin/oauth/${provider}] Failed to store PKCE challenge:`, err);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Middleware to retrieve PKCE challenge from cache using the OAuth state.
|
||||
* Reads state from req.oauthState (set by a preceding middleware).
|
||||
* @param {string} provider - Provider name for logging.
|
||||
* @returns {Function} Express middleware.
|
||||
*/
|
||||
function retrievePkceChallenge(provider) {
|
||||
return async (req, res, next) => {
|
||||
if (!req.oauthState) {
|
||||
return next();
|
||||
}
|
||||
try {
|
||||
const cache = getLogStores(CacheKeys.ADMIN_OAUTH_EXCHANGE);
|
||||
const challenge = await cache.get(`pkce:${req.oauthState}`);
|
||||
if (challenge) {
|
||||
req.pkceChallenge = challenge;
|
||||
await cache.delete(`pkce:${req.oauthState}`);
|
||||
} else {
|
||||
logger.warn(
|
||||
`[admin/oauth/${provider}/callback] State present but no PKCE challenge found; PKCE will not be enforced for this request`,
|
||||
);
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`[admin/oauth/${provider}/callback] Failed to retrieve PKCE challenge, aborting:`,
|
||||
err,
|
||||
);
|
||||
return res.redirect(
|
||||
`${getAdminPanelUrl()}/auth/${provider}/callback?error=pkce_retrieval_failed&error_description=Failed+to+retrieve+PKCE+challenge`,
|
||||
);
|
||||
}
|
||||
next();
|
||||
};
|
||||
}
|
||||
|
||||
/* ──────────────────────────────────────────────
|
||||
* OpenID Admin Routes
|
||||
* ────────────────────────────────────────────── */
|
||||
|
||||
router.get('/oauth/openid', async (req, res, next) => {
|
||||
const state = generateState();
|
||||
const stored = await storePkceChallenge(state, req.query.code_challenge, 'openid');
|
||||
if (!stored) {
|
||||
return res.redirect(
|
||||
`${getAdminPanelUrl()}/auth/openid/callback?error=pkce_store_failed&error_description=Failed+to+store+PKCE+challenge`,
|
||||
);
|
||||
}
|
||||
|
||||
return passport.authenticate('openidAdmin', {
|
||||
session: false,
|
||||
state: randomState(),
|
||||
state,
|
||||
})(req, res, next);
|
||||
});
|
||||
|
||||
router.get(
|
||||
'/oauth/openid/callback',
|
||||
(req, res, next) => {
|
||||
req.oauthState = typeof req.query.state === 'string' ? req.query.state : undefined;
|
||||
next();
|
||||
},
|
||||
passport.authenticate('openidAdmin', {
|
||||
failureRedirect: `${getAdminPanelUrl()}/auth/openid/callback?error=auth_failed&error_description=Authentication+failed`,
|
||||
failureMessage: true,
|
||||
session: false,
|
||||
}),
|
||||
retrievePkceChallenge('openid'),
|
||||
requireAdminAccess,
|
||||
setBalanceConfig,
|
||||
middleware.checkDomainAllowed,
|
||||
createOAuthHandler(`${getAdminPanelUrl()}/auth/openid/callback`),
|
||||
);
|
||||
|
||||
/* ──────────────────────────────────────────────
|
||||
* SAML Admin Routes
|
||||
* ────────────────────────────────────────────── */
|
||||
|
||||
router.get('/oauth/saml', async (req, res, next) => {
|
||||
const state = generateState();
|
||||
const stored = await storePkceChallenge(state, req.query.code_challenge, 'saml');
|
||||
if (!stored) {
|
||||
return res.redirect(
|
||||
`${getAdminPanelUrl()}/auth/saml/callback?error=pkce_store_failed&error_description=Failed+to+store+PKCE+challenge`,
|
||||
);
|
||||
}
|
||||
|
||||
return passport.authenticate('samlAdmin', {
|
||||
session: false,
|
||||
additionalParams: { RelayState: state },
|
||||
})(req, res, next);
|
||||
});
|
||||
|
||||
router.post(
|
||||
'/oauth/saml/callback',
|
||||
(req, res, next) => {
|
||||
req.oauthState = typeof req.body.RelayState === 'string' ? req.body.RelayState : undefined;
|
||||
next();
|
||||
},
|
||||
passport.authenticate('samlAdmin', {
|
||||
failureRedirect: `${getAdminPanelUrl()}/auth/saml/callback?error=auth_failed&error_description=Authentication+failed`,
|
||||
failureMessage: true,
|
||||
session: false,
|
||||
}),
|
||||
retrievePkceChallenge('saml'),
|
||||
requireAdminAccess,
|
||||
setBalanceConfig,
|
||||
middleware.checkDomainAllowed,
|
||||
createOAuthHandler(`${getAdminPanelUrl()}/auth/saml/callback`),
|
||||
);
|
||||
|
||||
/* ──────────────────────────────────────────────
|
||||
* Google Admin Routes
|
||||
* ────────────────────────────────────────────── */
|
||||
|
||||
router.get('/oauth/google', async (req, res, next) => {
|
||||
const state = generateState();
|
||||
const stored = await storePkceChallenge(state, req.query.code_challenge, 'google');
|
||||
if (!stored) {
|
||||
return res.redirect(
|
||||
`${getAdminPanelUrl()}/auth/google/callback?error=pkce_store_failed&error_description=Failed+to+store+PKCE+challenge`,
|
||||
);
|
||||
}
|
||||
|
||||
return passport.authenticate('googleAdmin', {
|
||||
scope: ['openid', 'profile', 'email'],
|
||||
session: false,
|
||||
state,
|
||||
})(req, res, next);
|
||||
});
|
||||
|
||||
router.get(
|
||||
'/oauth/google/callback',
|
||||
(req, res, next) => {
|
||||
req.oauthState = typeof req.query.state === 'string' ? req.query.state : undefined;
|
||||
next();
|
||||
},
|
||||
passport.authenticate('googleAdmin', {
|
||||
failureRedirect: `${getAdminPanelUrl()}/auth/google/callback?error=auth_failed&error_description=Authentication+failed`,
|
||||
failureMessage: true,
|
||||
session: false,
|
||||
}),
|
||||
retrievePkceChallenge('google'),
|
||||
requireAdminAccess,
|
||||
setBalanceConfig,
|
||||
middleware.checkDomainAllowed,
|
||||
createOAuthHandler(`${getAdminPanelUrl()}/auth/google/callback`),
|
||||
);
|
||||
|
||||
/* ──────────────────────────────────────────────
|
||||
* GitHub Admin Routes
|
||||
* ────────────────────────────────────────────── */
|
||||
|
||||
router.get('/oauth/github', async (req, res, next) => {
|
||||
const state = generateState();
|
||||
const stored = await storePkceChallenge(state, req.query.code_challenge, 'github');
|
||||
if (!stored) {
|
||||
return res.redirect(
|
||||
`${getAdminPanelUrl()}/auth/github/callback?error=pkce_store_failed&error_description=Failed+to+store+PKCE+challenge`,
|
||||
);
|
||||
}
|
||||
|
||||
return passport.authenticate('githubAdmin', {
|
||||
scope: ['user:email', 'read:user'],
|
||||
session: false,
|
||||
state,
|
||||
})(req, res, next);
|
||||
});
|
||||
|
||||
router.get(
|
||||
'/oauth/github/callback',
|
||||
(req, res, next) => {
|
||||
req.oauthState = typeof req.query.state === 'string' ? req.query.state : undefined;
|
||||
next();
|
||||
},
|
||||
passport.authenticate('githubAdmin', {
|
||||
failureRedirect: `${getAdminPanelUrl()}/auth/github/callback?error=auth_failed&error_description=Authentication+failed`,
|
||||
failureMessage: true,
|
||||
session: false,
|
||||
}),
|
||||
retrievePkceChallenge('github'),
|
||||
requireAdminAccess,
|
||||
setBalanceConfig,
|
||||
middleware.checkDomainAllowed,
|
||||
createOAuthHandler(`${getAdminPanelUrl()}/auth/github/callback`),
|
||||
);
|
||||
|
||||
/* ──────────────────────────────────────────────
|
||||
* Discord Admin Routes
|
||||
* ────────────────────────────────────────────── */
|
||||
|
||||
router.get('/oauth/discord', async (req, res, next) => {
|
||||
const state = generateState();
|
||||
const stored = await storePkceChallenge(state, req.query.code_challenge, 'discord');
|
||||
if (!stored) {
|
||||
return res.redirect(
|
||||
`${getAdminPanelUrl()}/auth/discord/callback?error=pkce_store_failed&error_description=Failed+to+store+PKCE+challenge`,
|
||||
);
|
||||
}
|
||||
|
||||
return passport.authenticate('discordAdmin', {
|
||||
scope: ['identify', 'email'],
|
||||
session: false,
|
||||
state,
|
||||
})(req, res, next);
|
||||
});
|
||||
|
||||
router.get(
|
||||
'/oauth/discord/callback',
|
||||
(req, res, next) => {
|
||||
req.oauthState = typeof req.query.state === 'string' ? req.query.state : undefined;
|
||||
next();
|
||||
},
|
||||
passport.authenticate('discordAdmin', {
|
||||
failureRedirect: `${getAdminPanelUrl()}/auth/discord/callback?error=auth_failed&error_description=Authentication+failed`,
|
||||
failureMessage: true,
|
||||
session: false,
|
||||
}),
|
||||
retrievePkceChallenge('discord'),
|
||||
requireAdminAccess,
|
||||
setBalanceConfig,
|
||||
middleware.checkDomainAllowed,
|
||||
createOAuthHandler(`${getAdminPanelUrl()}/auth/discord/callback`),
|
||||
);
|
||||
|
||||
/* ──────────────────────────────────────────────
|
||||
* Facebook Admin Routes
|
||||
* ────────────────────────────────────────────── */
|
||||
|
||||
router.get('/oauth/facebook', async (req, res, next) => {
|
||||
const state = generateState();
|
||||
const stored = await storePkceChallenge(state, req.query.code_challenge, 'facebook');
|
||||
if (!stored) {
|
||||
return res.redirect(
|
||||
`${getAdminPanelUrl()}/auth/facebook/callback?error=pkce_store_failed&error_description=Failed+to+store+PKCE+challenge`,
|
||||
);
|
||||
}
|
||||
|
||||
return passport.authenticate('facebookAdmin', {
|
||||
scope: ['public_profile'],
|
||||
session: false,
|
||||
state,
|
||||
})(req, res, next);
|
||||
});
|
||||
|
||||
router.get(
|
||||
'/oauth/facebook/callback',
|
||||
(req, res, next) => {
|
||||
req.oauthState = typeof req.query.state === 'string' ? req.query.state : undefined;
|
||||
next();
|
||||
},
|
||||
passport.authenticate('facebookAdmin', {
|
||||
failureRedirect: `${getAdminPanelUrl()}/auth/facebook/callback?error=auth_failed&error_description=Authentication+failed`,
|
||||
failureMessage: true,
|
||||
session: false,
|
||||
}),
|
||||
retrievePkceChallenge('facebook'),
|
||||
requireAdminAccess,
|
||||
setBalanceConfig,
|
||||
middleware.checkDomainAllowed,
|
||||
createOAuthHandler(`${getAdminPanelUrl()}/auth/facebook/callback`),
|
||||
);
|
||||
|
||||
/* ──────────────────────────────────────────────
|
||||
* Apple Admin Routes (POST callback)
|
||||
* ────────────────────────────────────────────── */
|
||||
|
||||
router.get('/oauth/apple', async (req, res, next) => {
|
||||
const state = generateState();
|
||||
const stored = await storePkceChallenge(state, req.query.code_challenge, 'apple');
|
||||
if (!stored) {
|
||||
return res.redirect(
|
||||
`${getAdminPanelUrl()}/auth/apple/callback?error=pkce_store_failed&error_description=Failed+to+store+PKCE+challenge`,
|
||||
);
|
||||
}
|
||||
|
||||
return passport.authenticate('appleAdmin', {
|
||||
session: false,
|
||||
state,
|
||||
})(req, res, next);
|
||||
});
|
||||
|
||||
router.post(
|
||||
'/oauth/apple/callback',
|
||||
(req, res, next) => {
|
||||
req.oauthState = typeof req.body.state === 'string' ? req.body.state : undefined;
|
||||
next();
|
||||
},
|
||||
passport.authenticate('appleAdmin', {
|
||||
failureRedirect: `${getAdminPanelUrl()}/auth/apple/callback?error=auth_failed&error_description=Authentication+failed`,
|
||||
failureMessage: true,
|
||||
session: false,
|
||||
}),
|
||||
retrievePkceChallenge('apple'),
|
||||
requireAdminAccess,
|
||||
setBalanceConfig,
|
||||
middleware.checkDomainAllowed,
|
||||
createOAuthHandler(`${getAdminPanelUrl()}/auth/apple/callback`),
|
||||
);
|
||||
|
||||
/** Regex pattern for valid exchange codes: 64 hex characters */
|
||||
const EXCHANGE_CODE_PATTERN = /^[a-f0-9]{64}$/i;
|
||||
const EXCHANGE_CODE_PATTERN = /^[a-f0-9]{64}$/;
|
||||
|
||||
/**
|
||||
* Exchange OAuth authorization code for tokens.
|
||||
|
|
@ -81,12 +414,12 @@ const EXCHANGE_CODE_PATTERN = /^[a-f0-9]{64}$/i;
|
|||
* The code is one-time-use and expires in 30 seconds.
|
||||
*
|
||||
* POST /api/admin/oauth/exchange
|
||||
* Body: { code: string }
|
||||
* Body: { code: string, code_verifier?: string }
|
||||
* Response: { token: string, refreshToken: string, user: object }
|
||||
*/
|
||||
router.post('/oauth/exchange', middleware.loginLimiter, async (req, res) => {
|
||||
try {
|
||||
const { code } = req.body;
|
||||
const { code, code_verifier: codeVerifier } = req.body;
|
||||
|
||||
if (!code) {
|
||||
logger.warn('[admin/oauth/exchange] Missing authorization code');
|
||||
|
|
@ -104,8 +437,20 @@ router.post('/oauth/exchange', middleware.loginLimiter, async (req, res) => {
|
|||
});
|
||||
}
|
||||
|
||||
if (
|
||||
codeVerifier !== undefined &&
|
||||
(typeof codeVerifier !== 'string' || codeVerifier.length < 1 || codeVerifier.length > 512)
|
||||
) {
|
||||
logger.warn('[admin/oauth/exchange] Invalid code_verifier format');
|
||||
return res.status(400).json({
|
||||
error: 'Invalid code_verifier',
|
||||
error_code: 'INVALID_VERIFIER',
|
||||
});
|
||||
}
|
||||
|
||||
const cache = getLogStores(CacheKeys.ADMIN_OAUTH_EXCHANGE);
|
||||
const result = await exchangeAdminCode(cache, code);
|
||||
const requestOrigin = resolveRequestOrigin(req);
|
||||
const result = await exchangeAdminCode(cache, code, requestOrigin, codeVerifier);
|
||||
|
||||
if (!result) {
|
||||
return res.status(401).json({
|
||||
|
|
|
|||
40
api/server/routes/admin/config.js
Normal file
40
api/server/routes/admin/config.js
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
const express = require('express');
|
||||
const { createAdminConfigHandlers } = require('@librechat/api');
|
||||
const { SystemCapabilities } = require('@librechat/data-schemas');
|
||||
const {
|
||||
hasConfigCapability,
|
||||
requireCapability,
|
||||
} = require('~/server/middleware/roles/capabilities');
|
||||
const { getAppConfig, invalidateConfigCaches } = require('~/server/services/Config');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const db = require('~/models');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN);
|
||||
|
||||
const handlers = createAdminConfigHandlers({
|
||||
listAllConfigs: db.listAllConfigs,
|
||||
findConfigByPrincipal: db.findConfigByPrincipal,
|
||||
upsertConfig: db.upsertConfig,
|
||||
patchConfigFields: db.patchConfigFields,
|
||||
unsetConfigField: db.unsetConfigField,
|
||||
deleteConfig: db.deleteConfig,
|
||||
toggleConfigActive: db.toggleConfigActive,
|
||||
hasConfigCapability,
|
||||
getAppConfig,
|
||||
invalidateConfigCaches,
|
||||
});
|
||||
|
||||
router.use(requireJwtAuth, requireAdminAccess);
|
||||
|
||||
router.get('/', handlers.listConfigs);
|
||||
router.get('/base', handlers.getBaseConfig);
|
||||
router.get('/:principalType/:principalId', handlers.getConfig);
|
||||
router.put('/:principalType/:principalId', handlers.upsertConfigOverrides);
|
||||
router.patch('/:principalType/:principalId/fields', handlers.patchConfigField);
|
||||
router.delete('/:principalType/:principalId/fields', handlers.deleteConfigField);
|
||||
router.delete('/:principalType/:principalId', handlers.deleteConfigOverrides);
|
||||
router.patch('/:principalType/:principalId/active', handlers.toggleConfig);
|
||||
|
||||
module.exports = router;
|
||||
35
api/server/routes/admin/grants.js
Normal file
35
api/server/routes/admin/grants.js
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
const express = require('express');
|
||||
const { createAdminGrantsHandlers, getCachedPrincipals } = require('@librechat/api');
|
||||
const { SystemCapabilities } = require('@librechat/data-schemas');
|
||||
const { requireCapability } = require('~/server/middleware/roles/capabilities');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const db = require('~/models');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN);
|
||||
|
||||
const handlers = createAdminGrantsHandlers({
|
||||
listGrants: db.listGrants,
|
||||
countGrants: db.countGrants,
|
||||
getCapabilitiesForPrincipal: db.getCapabilitiesForPrincipal,
|
||||
getCapabilitiesForPrincipals: db.getCapabilitiesForPrincipals,
|
||||
grantCapability: db.grantCapability,
|
||||
revokeCapability: db.revokeCapability,
|
||||
getUserPrincipals: db.getUserPrincipals,
|
||||
hasCapabilityForPrincipals: db.hasCapabilityForPrincipals,
|
||||
getHeldCapabilities: db.getHeldCapabilities,
|
||||
getCachedPrincipals,
|
||||
checkRoleExists: async (name) => (await db.getRoleByName(name)) != null,
|
||||
});
|
||||
|
||||
router.use(requireJwtAuth, requireAdminAccess);
|
||||
|
||||
router.get('/', handlers.listGrants);
|
||||
router.get('/effective', handlers.getEffectiveCapabilities);
|
||||
router.get('/:principalType/:principalId', handlers.getPrincipalGrants);
|
||||
router.post('/', handlers.assignGrant);
|
||||
/** Callers should encodeURIComponent the capability for client compatibility (e.g. manage%3Aconfigs%3Aendpoints). */
|
||||
router.delete('/:principalType/:principalId/:capability', handlers.revokeGrant);
|
||||
|
||||
module.exports = router;
|
||||
40
api/server/routes/admin/groups.js
Normal file
40
api/server/routes/admin/groups.js
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
const express = require('express');
|
||||
const { createAdminGroupsHandlers } = require('@librechat/api');
|
||||
const { SystemCapabilities } = require('@librechat/data-schemas');
|
||||
const { requireCapability } = require('~/server/middleware/roles/capabilities');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const db = require('~/models');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN);
|
||||
const requireReadGroups = requireCapability(SystemCapabilities.READ_GROUPS);
|
||||
const requireManageGroups = requireCapability(SystemCapabilities.MANAGE_GROUPS);
|
||||
|
||||
const handlers = createAdminGroupsHandlers({
|
||||
listGroups: db.listGroups,
|
||||
countGroups: db.countGroups,
|
||||
findGroupById: db.findGroupById,
|
||||
createGroup: db.createGroup,
|
||||
updateGroupById: db.updateGroupById,
|
||||
deleteGroup: db.deleteGroup,
|
||||
addUserToGroup: db.addUserToGroup,
|
||||
removeUserFromGroup: db.removeUserFromGroup,
|
||||
removeMemberById: db.removeMemberById,
|
||||
findUsers: db.findUsers,
|
||||
deleteConfig: db.deleteConfig,
|
||||
deleteAclEntries: db.deleteAclEntries,
|
||||
});
|
||||
|
||||
router.use(requireJwtAuth, requireAdminAccess);
|
||||
|
||||
router.get('/', requireReadGroups, handlers.listGroups);
|
||||
router.post('/', requireManageGroups, handlers.createGroup);
|
||||
router.get('/:id', requireReadGroups, handlers.getGroup);
|
||||
router.patch('/:id', requireManageGroups, handlers.updateGroup);
|
||||
router.delete('/:id', requireManageGroups, handlers.deleteGroup);
|
||||
router.get('/:id/members', requireReadGroups, handlers.getGroupMembers);
|
||||
router.post('/:id/members', requireManageGroups, handlers.addGroupMember);
|
||||
router.delete('/:id/members/:userId', requireManageGroups, handlers.removeGroupMember);
|
||||
|
||||
module.exports = router;
|
||||
46
api/server/routes/admin/roles.js
Normal file
46
api/server/routes/admin/roles.js
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
const express = require('express');
|
||||
const { createAdminRolesHandlers } = require('@librechat/api');
|
||||
const { SystemCapabilities } = require('@librechat/data-schemas');
|
||||
const { requireCapability } = require('~/server/middleware/roles/capabilities');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const db = require('~/models');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN);
|
||||
const requireReadRoles = requireCapability(SystemCapabilities.READ_ROLES);
|
||||
const requireManageRoles = requireCapability(SystemCapabilities.MANAGE_ROLES);
|
||||
|
||||
const handlers = createAdminRolesHandlers({
|
||||
listRoles: db.listRoles,
|
||||
countRoles: db.countRoles,
|
||||
getRoleByName: db.getRoleByName,
|
||||
createRoleByName: db.createRoleByName,
|
||||
updateRoleByName: db.updateRoleByName,
|
||||
updateAccessPermissions: db.updateAccessPermissions,
|
||||
deleteRoleByName: db.deleteRoleByName,
|
||||
findUser: db.findUser,
|
||||
updateUser: db.updateUser,
|
||||
updateUsersByRole: db.updateUsersByRole,
|
||||
findUserIdsByRole: db.findUserIdsByRole,
|
||||
updateUsersRoleByIds: db.updateUsersRoleByIds,
|
||||
listUsersByRole: db.listUsersByRole,
|
||||
countUsersByRole: db.countUsersByRole,
|
||||
deleteConfig: db.deleteConfig,
|
||||
deleteAclEntries: db.deleteAclEntries,
|
||||
deleteGrantsForPrincipal: db.deleteGrantsForPrincipal,
|
||||
});
|
||||
|
||||
router.use(requireJwtAuth, requireAdminAccess);
|
||||
|
||||
router.get('/', requireReadRoles, handlers.listRoles);
|
||||
router.post('/', requireManageRoles, handlers.createRole);
|
||||
router.get('/:name', requireReadRoles, handlers.getRole);
|
||||
router.patch('/:name', requireManageRoles, handlers.updateRole);
|
||||
router.delete('/:name', requireManageRoles, handlers.deleteRole);
|
||||
router.patch('/:name/permissions', requireManageRoles, handlers.updateRolePermissions);
|
||||
router.get('/:name/members', requireReadRoles, handlers.getRoleMembers);
|
||||
router.post('/:name/members', requireManageRoles, handlers.addRoleMember);
|
||||
router.delete('/:name/members/:userId', requireManageRoles, handlers.removeRoleMember);
|
||||
|
||||
module.exports = router;
|
||||
28
api/server/routes/admin/users.js
Normal file
28
api/server/routes/admin/users.js
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
const express = require('express');
|
||||
const { createAdminUsersHandlers } = require('@librechat/api');
|
||||
const { SystemCapabilities } = require('@librechat/data-schemas');
|
||||
const { requireCapability } = require('~/server/middleware/roles/capabilities');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const db = require('~/models');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN);
|
||||
const requireReadUsers = requireCapability(SystemCapabilities.READ_USERS);
|
||||
// const requireManageUsers = requireCapability(SystemCapabilities.MANAGE_USERS);
|
||||
|
||||
const handlers = createAdminUsersHandlers({
|
||||
findUsers: db.findUsers,
|
||||
countUsers: db.countUsers,
|
||||
deleteUserById: db.deleteUserById,
|
||||
deleteConfig: db.deleteConfig,
|
||||
deleteAclEntries: db.deleteAclEntries,
|
||||
});
|
||||
|
||||
router.use(requireJwtAuth, requireAdminAccess);
|
||||
|
||||
router.get('/', requireReadUsers, handlers.listUsers);
|
||||
router.get('/search', requireReadUsers, handlers.searchUsers);
|
||||
// router.delete('/:id', requireManageUsers, handlers.deleteUser);
|
||||
|
||||
module.exports = router;
|
||||
186
api/server/routes/agents/__tests__/streamTenant.spec.js
Normal file
186
api/server/routes/agents/__tests__/streamTenant.spec.js
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
const express = require('express');
const request = require('supertest');

const mockGenerationJobManager = {
  getJob: jest.fn(),
  subscribe: jest.fn(),
  getResumeState: jest.fn(),
  abortJob: jest.fn(),
  getActiveJobIdsForUser: jest.fn().mockResolvedValue([]),
};

jest.mock('@librechat/data-schemas', () => ({
  ...jest.requireActual('@librechat/data-schemas'),
  logger: { debug: jest.fn(), warn: jest.fn(), error: jest.fn(), info: jest.fn() },
}));

jest.mock('@librechat/api', () => ({
  ...jest.requireActual('@librechat/api'),
  isEnabled: jest.fn().mockReturnValue(false),
  GenerationJobManager: mockGenerationJobManager,
}));

jest.mock('~/models', () => ({ saveMessage: jest.fn() }));

// Mutable identity the mocked requireJwtAuth middleware attaches to each request.
let mockUserId = 'user-123';
let mockTenantId;

jest.mock('~/server/middleware', () => ({
  uaParser: (req, res, next) => next(),
  checkBan: (req, res, next) => next(),
  requireJwtAuth: (req, res, next) => {
    req.user = { id: mockUserId, tenantId: mockTenantId };
    next();
  },
  messageIpLimiter: (req, res, next) => next(),
  configMiddleware: (req, res, next) => next(),
  messageUserLimiter: (req, res, next) => next(),
}));

jest.mock('~/server/routes/agents/chat', () => require('express').Router());
jest.mock('~/server/routes/agents/v1', () => ({
  v1: require('express').Router(),
}));
jest.mock('~/server/routes/agents/openai', () => require('express').Router());
jest.mock('~/server/routes/agents/responses', () => require('express').Router());

const agentsRouter = require('../index');
const app = express();
app.use(express.json());
app.use('/agents', agentsRouter);

/** Sets the identity the mocked requireJwtAuth will attach to subsequent requests. */
function actAs(userId, tenantId) {
  mockUserId = userId;
  mockTenantId = tenantId;
}

/** Stubs getJob to resolve a running job with the given metadata (plus optional extra fields). */
function givenRunningJob(metadata, extra = {}) {
  mockGenerationJobManager.getJob.mockResolvedValue({ metadata, status: 'running', ...extra });
}

/** Stubs subscribe so the SSE stream completes immediately and the request can end. */
function stubSubscribeCompletion() {
  mockGenerationJobManager.subscribe.mockImplementation((_streamId, _writeEvent, onDone) => {
    process.nextTick(() => onDone({ done: true }));
    return { unsubscribe: jest.fn() };
  });
}

describe('SSE stream tenant isolation', () => {
  beforeEach(() => {
    jest.clearAllMocks();
    actAs('user-123', undefined);
  });

  describe('GET /chat/stream/:streamId', () => {
    it('returns 403 when a user from a different tenant accesses a stream', async () => {
      actAs('user-456', 'tenant-b');
      givenRunningJob({ userId: 'user-456', tenantId: 'tenant-a' });

      const res = await request(app).get('/agents/chat/stream/stream-123');
      expect(res.status).toBe(403);
      expect(res.body.error).toBe('Unauthorized');
    });

    it('returns 404 when stream does not exist', async () => {
      mockGenerationJobManager.getJob.mockResolvedValue(null);

      const res = await request(app).get('/agents/chat/stream/nonexistent');
      expect(res.status).toBe(404);
    });

    it('proceeds past tenant guard when tenant matches', async () => {
      actAs('user-123', 'tenant-a');
      stubSubscribeCompletion();
      givenRunningJob({ userId: 'user-123', tenantId: 'tenant-a' });

      const res = await request(app).get('/agents/chat/stream/stream-123');
      expect(res.status).toBe(200);
      expect(mockGenerationJobManager.subscribe).toHaveBeenCalledTimes(1);
    });

    it('proceeds past tenant guard when job has no tenantId (single-tenant mode)', async () => {
      actAs('user-123', undefined);
      stubSubscribeCompletion();
      givenRunningJob({ userId: 'user-123' });

      const res = await request(app).get('/agents/chat/stream/stream-123');
      expect(res.status).toBe(200);
      expect(mockGenerationJobManager.subscribe).toHaveBeenCalledTimes(1);
    });

    it('returns 403 when job has tenantId but user has no tenantId', async () => {
      actAs('user-123', undefined);
      givenRunningJob({ userId: 'user-123', tenantId: 'some-tenant' });

      const res = await request(app).get('/agents/chat/stream/stream-123');
      expect(res.status).toBe(403);
    });
  });

  describe('GET /chat/status/:conversationId', () => {
    it('returns 403 when tenant does not match', async () => {
      actAs('user-123', 'tenant-b');
      givenRunningJob({ userId: 'user-123', tenantId: 'tenant-a' });

      const res = await request(app).get('/agents/chat/status/conv-123');
      expect(res.status).toBe(403);
      expect(res.body.error).toBe('Unauthorized');
    });

    it('returns status when tenant matches', async () => {
      actAs('user-123', 'tenant-a');
      givenRunningJob({ userId: 'user-123', tenantId: 'tenant-a' }, { createdAt: Date.now() });
      mockGenerationJobManager.getResumeState.mockResolvedValue(null);

      const res = await request(app).get('/agents/chat/status/conv-123');
      expect(res.status).toBe(200);
      expect(res.body.active).toBe(true);
    });
  });

  describe('POST /chat/abort', () => {
    it('returns 403 when tenant does not match', async () => {
      actAs('user-123', 'tenant-b');
      givenRunningJob({ userId: 'user-123', tenantId: 'tenant-a' });

      const res = await request(app).post('/agents/chat/abort').send({ streamId: 'stream-123' });
      expect(res.status).toBe(403);
      expect(res.body.error).toBe('Unauthorized');
    });
  });
});
|
||||
|
|
@ -17,6 +17,11 @@ const chat = require('./chat');
|
|||
|
||||
const { LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
|
||||
|
||||
/** Untenanted jobs (pre-multi-tenancy) remain accessible if the userId check passes. */
function hasTenantMismatch(job, user) {
  const jobTenant = job.metadata?.tenantId;
  // A job without a tenant is never considered mismatched (single-tenant / legacy jobs).
  if (jobTenant == null) {
    return false;
  }
  return jobTenant !== user.tenantId;
}
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
/**
|
||||
|
|
@ -67,6 +72,10 @@ router.get('/chat/stream/:streamId', async (req, res) => {
|
|||
return res.status(403).json({ error: 'Unauthorized' });
|
||||
}
|
||||
|
||||
if (hasTenantMismatch(job, req.user)) {
|
||||
return res.status(403).json({ error: 'Unauthorized' });
|
||||
}
|
||||
|
||||
res.setHeader('Content-Encoding', 'identity');
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache, no-transform');
|
||||
|
|
@ -150,7 +159,10 @@ router.get('/chat/stream/:streamId', async (req, res) => {
|
|||
* @returns { activeJobIds: string[] }
|
||||
*/
|
||||
router.get('/chat/active', async (req, res) => {
|
||||
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(req.user.id);
|
||||
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(
|
||||
req.user.id,
|
||||
req.user.tenantId,
|
||||
);
|
||||
res.json({ activeJobIds });
|
||||
});
|
||||
|
||||
|
|
@ -174,6 +186,10 @@ router.get('/chat/status/:conversationId', async (req, res) => {
|
|||
return res.status(403).json({ error: 'Unauthorized' });
|
||||
}
|
||||
|
||||
if (hasTenantMismatch(job, req.user)) {
|
||||
return res.status(403).json({ error: 'Unauthorized' });
|
||||
}
|
||||
|
||||
// Get resume state which contains aggregatedContent
|
||||
// Avoid calling both getStreamInfo and getResumeState (both fetch content)
|
||||
const resumeState = await GenerationJobManager.getResumeState(conversationId);
|
||||
|
|
@ -213,7 +229,10 @@ router.post('/chat/abort', async (req, res) => {
|
|||
// This handles the case where frontend sends "new" but job was created with a UUID
|
||||
if (!job && userId) {
|
||||
logger.debug(`[AgentStream] Job not found by ID, checking active jobs for user: ${userId}`);
|
||||
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(userId);
|
||||
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(
|
||||
userId,
|
||||
req.user.tenantId,
|
||||
);
|
||||
if (activeJobIds.length > 0) {
|
||||
// Abort the most recent active job for this user
|
||||
jobStreamId = activeJobIds[0];
|
||||
|
|
@ -230,6 +249,10 @@ router.post('/chat/abort', async (req, res) => {
|
|||
return res.status(403).json({ error: 'Unauthorized' });
|
||||
}
|
||||
|
||||
if (hasTenantMismatch(job, req.user)) {
|
||||
return res.status(403).json({ error: 'Unauthorized' });
|
||||
}
|
||||
|
||||
logger.debug(`[AgentStream] Job found, aborting: ${jobStreamId}`);
|
||||
const abortResult = await GenerationJobManager.abortJob(jobStreamId);
|
||||
logger.debug(`[AgentStream] Job aborted successfully: ${jobStreamId}`, {
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
const express = require('express');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { isEnabled, getBalanceConfig } = require('@librechat/api');
|
||||
const { CacheKeys, defaultSocialLogins } = require('librechat-data-provider');
|
||||
const { defaultSocialLogins } = require('librechat-data-provider');
|
||||
const { logger, getTenantId } = require('@librechat/data-schemas');
|
||||
const { getLdapConfig } = require('~/server/services/Config/ldap');
|
||||
const { getAppConfig } = require('~/server/services/Config/app');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
const router = express.Router();
|
||||
const emailLoginEnabled =
|
||||
|
|
@ -20,128 +19,159 @@ const publicSharedLinksEnabled =
|
|||
const sharePointFilePickerEnabled = isEnabled(process.env.ENABLE_SHAREPOINT_FILEPICKER);
|
||||
const openidReuseTokens = isEnabled(process.env.OPENID_REUSE_TOKENS);
|
||||
|
||||
router.get('/', async function (req, res) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
/** Returns true only on February 11th (month is 0-indexed, so 1 === February). */
function isBirthday() {
  const now = new Date();
  const isFebruary = now.getMonth() === 1;
  return isFebruary && now.getDate() === 11;
}
|
||||
|
||||
const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG);
|
||||
if (cachedStartupConfig) {
|
||||
res.send(cachedStartupConfig);
|
||||
return;
|
||||
}
|
||||
function buildSharedPayload() {
|
||||
const isOpenIdEnabled =
|
||||
!!process.env.OPENID_CLIENT_ID &&
|
||||
(isEnabled(process.env.OPENID_USE_PKCE) || !!process.env.OPENID_CLIENT_SECRET?.trim()) &&
|
||||
!!process.env.OPENID_ISSUER &&
|
||||
!!process.env.OPENID_SESSION_SECRET;
|
||||
|
||||
const isBirthday = () => {
|
||||
const today = new Date();
|
||||
return today.getMonth() === 1 && today.getDate() === 11;
|
||||
};
|
||||
const isSamlEnabled =
|
||||
!!process.env.SAML_ENTRY_POINT &&
|
||||
!!process.env.SAML_ISSUER &&
|
||||
!!process.env.SAML_CERT &&
|
||||
!!process.env.SAML_SESSION_SECRET;
|
||||
|
||||
const ldap = getLdapConfig();
|
||||
|
||||
/** @type {Partial<TStartupConfig>} */
|
||||
const payload = {
|
||||
appTitle: process.env.APP_TITLE || 'LibreChat',
|
||||
discordLoginEnabled: !!process.env.DISCORD_CLIENT_ID && !!process.env.DISCORD_CLIENT_SECRET,
|
||||
facebookLoginEnabled: !!process.env.FACEBOOK_CLIENT_ID && !!process.env.FACEBOOK_CLIENT_SECRET,
|
||||
githubLoginEnabled: !!process.env.GITHUB_CLIENT_ID && !!process.env.GITHUB_CLIENT_SECRET,
|
||||
googleLoginEnabled: !!process.env.GOOGLE_CLIENT_ID && !!process.env.GOOGLE_CLIENT_SECRET,
|
||||
appleLoginEnabled:
|
||||
!!process.env.APPLE_CLIENT_ID &&
|
||||
!!process.env.APPLE_TEAM_ID &&
|
||||
!!process.env.APPLE_KEY_ID &&
|
||||
!!process.env.APPLE_PRIVATE_KEY_PATH,
|
||||
openidLoginEnabled: isOpenIdEnabled,
|
||||
openidLabel: process.env.OPENID_BUTTON_LABEL || 'Continue with OpenID',
|
||||
openidImageUrl: process.env.OPENID_IMAGE_URL,
|
||||
openidAutoRedirect: isEnabled(process.env.OPENID_AUTO_REDIRECT),
|
||||
samlLoginEnabled: !isOpenIdEnabled && isSamlEnabled,
|
||||
samlLabel: process.env.SAML_BUTTON_LABEL,
|
||||
samlImageUrl: process.env.SAML_IMAGE_URL,
|
||||
serverDomain: process.env.DOMAIN_SERVER || 'http://localhost:3080',
|
||||
emailLoginEnabled,
|
||||
registrationEnabled: !ldap?.enabled && isEnabled(process.env.ALLOW_REGISTRATION),
|
||||
socialLoginEnabled: isEnabled(process.env.ALLOW_SOCIAL_LOGIN),
|
||||
emailEnabled:
|
||||
(!!process.env.EMAIL_SERVICE || !!process.env.EMAIL_HOST) &&
|
||||
!!process.env.EMAIL_USERNAME &&
|
||||
!!process.env.EMAIL_PASSWORD &&
|
||||
!!process.env.EMAIL_FROM,
|
||||
passwordResetEnabled,
|
||||
showBirthdayIcon:
|
||||
isBirthday() ||
|
||||
isEnabled(process.env.SHOW_BIRTHDAY_ICON) ||
|
||||
process.env.SHOW_BIRTHDAY_ICON === '',
|
||||
helpAndFaqURL: process.env.HELP_AND_FAQ_URL || 'https://librechat.ai',
|
||||
sharedLinksEnabled,
|
||||
publicSharedLinksEnabled,
|
||||
analyticsGtmId: process.env.ANALYTICS_GTM_ID,
|
||||
openidReuseTokens,
|
||||
};
|
||||
|
||||
const minPasswordLength = parseInt(process.env.MIN_PASSWORD_LENGTH, 10);
|
||||
if (minPasswordLength && !isNaN(minPasswordLength)) {
|
||||
payload.minPasswordLength = minPasswordLength;
|
||||
}
|
||||
|
||||
if (ldap) {
|
||||
payload.ldap = ldap;
|
||||
}
|
||||
|
||||
if (typeof process.env.CUSTOM_FOOTER === 'string') {
|
||||
payload.customFooter = process.env.CUSTOM_FOOTER;
|
||||
}
|
||||
|
||||
return payload;
|
||||
}
|
||||
|
||||
/**
 * Extracts the minimal web-search startup config (provider identifiers only)
 * from the resolved app config.
 * @param {object} [appConfig] - Resolved app config; may be null/undefined.
 * @returns {object|undefined} An object containing only the truthy provider
 *   fields, or undefined when web search is absent or no provider is set.
 */
function buildWebSearchConfig(appConfig) {
  const webSearch = appConfig?.webSearch;
  if (!webSearch) {
    return undefined;
  }
  // Field order matches the payload shape consumed by the client.
  const fields = ['searchProvider', 'scraperProvider', 'rerankerType'];
  const entries = fields.filter((key) => webSearch[key]).map((key) => [key, webSearch[key]]);
  if (entries.length === 0) {
    return undefined;
  }
  return Object.fromEntries(entries);
}
|
||||
|
||||
router.get('/', async function (req, res) {
|
||||
try {
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
const sharedPayload = buildSharedPayload();
|
||||
|
||||
const isOpenIdEnabled =
|
||||
!!process.env.OPENID_CLIENT_ID &&
|
||||
(isEnabled(process.env.OPENID_USE_PKCE) || !!process.env.OPENID_CLIENT_SECRET?.trim()) &&
|
||||
!!process.env.OPENID_ISSUER &&
|
||||
!!process.env.OPENID_SESSION_SECRET;
|
||||
if (!req.user) {
|
||||
const tenantId = getTenantId();
|
||||
const baseConfig = await getAppConfig(tenantId ? { tenantId } : { baseOnly: true });
|
||||
|
||||
const isSamlEnabled =
|
||||
!!process.env.SAML_ENTRY_POINT &&
|
||||
!!process.env.SAML_ISSUER &&
|
||||
!!process.env.SAML_CERT &&
|
||||
!!process.env.SAML_SESSION_SECRET;
|
||||
/** @type {Partial<TStartupConfig>} */
|
||||
const payload = {
|
||||
...sharedPayload,
|
||||
socialLogins: baseConfig?.registration?.socialLogins ?? defaultSocialLogins,
|
||||
turnstile: baseConfig?.turnstileConfig,
|
||||
};
|
||||
|
||||
const interfaceConfig = baseConfig?.interfaceConfig;
|
||||
if (interfaceConfig?.privacyPolicy || interfaceConfig?.termsOfService) {
|
||||
payload.interface = {};
|
||||
if (interfaceConfig.privacyPolicy) {
|
||||
payload.interface.privacyPolicy = interfaceConfig.privacyPolicy;
|
||||
}
|
||||
if (interfaceConfig.termsOfService) {
|
||||
payload.interface.termsOfService = interfaceConfig.termsOfService;
|
||||
}
|
||||
}
|
||||
|
||||
return res.status(200).send(payload);
|
||||
}
|
||||
|
||||
const appConfig = await getAppConfig({
|
||||
role: req.user.role,
|
||||
userId: req.user.id,
|
||||
tenantId: req.user.tenantId || getTenantId(),
|
||||
});
|
||||
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
|
||||
/** @type {TStartupConfig} */
|
||||
const payload = {
|
||||
appTitle: process.env.APP_TITLE || 'LibreChat',
|
||||
...sharedPayload,
|
||||
socialLogins: appConfig?.registration?.socialLogins ?? defaultSocialLogins,
|
||||
discordLoginEnabled: !!process.env.DISCORD_CLIENT_ID && !!process.env.DISCORD_CLIENT_SECRET,
|
||||
facebookLoginEnabled:
|
||||
!!process.env.FACEBOOK_CLIENT_ID && !!process.env.FACEBOOK_CLIENT_SECRET,
|
||||
githubLoginEnabled: !!process.env.GITHUB_CLIENT_ID && !!process.env.GITHUB_CLIENT_SECRET,
|
||||
googleLoginEnabled: !!process.env.GOOGLE_CLIENT_ID && !!process.env.GOOGLE_CLIENT_SECRET,
|
||||
appleLoginEnabled:
|
||||
!!process.env.APPLE_CLIENT_ID &&
|
||||
!!process.env.APPLE_TEAM_ID &&
|
||||
!!process.env.APPLE_KEY_ID &&
|
||||
!!process.env.APPLE_PRIVATE_KEY_PATH,
|
||||
openidLoginEnabled: isOpenIdEnabled,
|
||||
openidLabel: process.env.OPENID_BUTTON_LABEL || 'Continue with OpenID',
|
||||
openidImageUrl: process.env.OPENID_IMAGE_URL,
|
||||
openidAutoRedirect: isEnabled(process.env.OPENID_AUTO_REDIRECT),
|
||||
samlLoginEnabled: !isOpenIdEnabled && isSamlEnabled,
|
||||
samlLabel: process.env.SAML_BUTTON_LABEL,
|
||||
samlImageUrl: process.env.SAML_IMAGE_URL,
|
||||
serverDomain: process.env.DOMAIN_SERVER || 'http://localhost:3080',
|
||||
emailLoginEnabled,
|
||||
registrationEnabled: !ldap?.enabled && isEnabled(process.env.ALLOW_REGISTRATION),
|
||||
socialLoginEnabled: isEnabled(process.env.ALLOW_SOCIAL_LOGIN),
|
||||
emailEnabled:
|
||||
(!!process.env.EMAIL_SERVICE || !!process.env.EMAIL_HOST) &&
|
||||
!!process.env.EMAIL_USERNAME &&
|
||||
!!process.env.EMAIL_PASSWORD &&
|
||||
!!process.env.EMAIL_FROM,
|
||||
passwordResetEnabled,
|
||||
showBirthdayIcon:
|
||||
isBirthday() ||
|
||||
isEnabled(process.env.SHOW_BIRTHDAY_ICON) ||
|
||||
process.env.SHOW_BIRTHDAY_ICON === '',
|
||||
helpAndFaqURL: process.env.HELP_AND_FAQ_URL || 'https://librechat.ai',
|
||||
interface: appConfig?.interfaceConfig,
|
||||
turnstile: appConfig?.turnstileConfig,
|
||||
modelSpecs: appConfig?.modelSpecs,
|
||||
balance: balanceConfig,
|
||||
sharedLinksEnabled,
|
||||
publicSharedLinksEnabled,
|
||||
analyticsGtmId: process.env.ANALYTICS_GTM_ID,
|
||||
bundlerURL: process.env.SANDPACK_BUNDLER_URL,
|
||||
staticBundlerURL: process.env.SANDPACK_STATIC_BUNDLER_URL,
|
||||
sharePointFilePickerEnabled,
|
||||
sharePointBaseUrl: process.env.SHAREPOINT_BASE_URL,
|
||||
sharePointPickerGraphScope: process.env.SHAREPOINT_PICKER_GRAPH_SCOPE,
|
||||
sharePointPickerSharePointScope: process.env.SHAREPOINT_PICKER_SHAREPOINT_SCOPE,
|
||||
openidReuseTokens,
|
||||
conversationImportMaxFileSize: process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES
|
||||
? parseInt(process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES, 10)
|
||||
: 0,
|
||||
};
|
||||
|
||||
const minPasswordLength = parseInt(process.env.MIN_PASSWORD_LENGTH, 10);
|
||||
if (minPasswordLength && !isNaN(minPasswordLength)) {
|
||||
payload.minPasswordLength = minPasswordLength;
|
||||
const webSearch = buildWebSearchConfig(appConfig);
|
||||
if (webSearch) {
|
||||
payload.webSearch = webSearch;
|
||||
}
|
||||
|
||||
const webSearchConfig = appConfig?.webSearch;
|
||||
if (
|
||||
webSearchConfig != null &&
|
||||
(webSearchConfig.searchProvider ||
|
||||
webSearchConfig.scraperProvider ||
|
||||
webSearchConfig.rerankerType)
|
||||
) {
|
||||
payload.webSearch = {};
|
||||
}
|
||||
|
||||
if (webSearchConfig?.searchProvider) {
|
||||
payload.webSearch.searchProvider = webSearchConfig.searchProvider;
|
||||
}
|
||||
if (webSearchConfig?.scraperProvider) {
|
||||
payload.webSearch.scraperProvider = webSearchConfig.scraperProvider;
|
||||
}
|
||||
if (webSearchConfig?.rerankerType) {
|
||||
payload.webSearch.rerankerType = webSearchConfig.rerankerType;
|
||||
}
|
||||
|
||||
if (ldap) {
|
||||
payload.ldap = ldap;
|
||||
}
|
||||
|
||||
if (typeof process.env.CUSTOM_FOOTER === 'string') {
|
||||
payload.customFooter = process.env.CUSTOM_FOOTER;
|
||||
}
|
||||
|
||||
await cache.set(CacheKeys.STARTUP_CONFIG, payload);
|
||||
return res.status(200).send(payload);
|
||||
} catch (err) {
|
||||
logger.error('Error in startup config', err);
|
||||
|
|
|
|||
|
|
@ -267,7 +267,11 @@ router.post(
|
|||
async (req, res) => {
|
||||
try {
|
||||
/* TODO: optimize to return imported conversations and add manually */
|
||||
await importConversations({ filepath: req.file.path, requestUserId: req.user.id });
|
||||
await importConversations({
|
||||
filepath: req.file.path,
|
||||
requestUserId: req.user.id,
|
||||
userRole: req.user.role,
|
||||
});
|
||||
res.status(201).json({ message: 'Conversation(s) imported successfully' });
|
||||
} catch (error) {
|
||||
logger.error('Error processing file', error);
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
const express = require('express');
|
||||
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
|
||||
const endpointController = require('~/server/controllers/EndpointController');
|
||||
|
||||
const router = express.Router();
|
||||
router.get('/', endpointController);
|
||||
/** Auth required for role/tenant-scoped endpoint config resolution. */
|
||||
router.get('/', requireJwtAuth, endpointController);
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -2,6 +2,11 @@ const accessPermissions = require('./accessPermissions');
|
|||
const assistants = require('./assistants');
|
||||
const categories = require('./categories');
|
||||
const adminAuth = require('./admin/auth');
|
||||
const adminConfig = require('./admin/config');
|
||||
const adminGrants = require('./admin/grants');
|
||||
const adminGroups = require('./admin/groups');
|
||||
const adminRoles = require('./admin/roles');
|
||||
const adminUsers = require('./admin/users');
|
||||
const endpoints = require('./endpoints');
|
||||
const staticRoute = require('./static');
|
||||
const messages = require('./messages');
|
||||
|
|
@ -31,6 +36,11 @@ module.exports = {
|
|||
mcp,
|
||||
auth,
|
||||
adminAuth,
|
||||
adminConfig,
|
||||
adminGrants,
|
||||
adminGroups,
|
||||
adminRoles,
|
||||
adminUsers,
|
||||
keys,
|
||||
apiKeys,
|
||||
user,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const { Router } = require('express');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { logger, getTenantId } = require('@librechat/data-schemas');
|
||||
const {
|
||||
CacheKeys,
|
||||
Constants,
|
||||
|
|
@ -36,7 +36,11 @@ const {
|
|||
getFlowStateManager,
|
||||
getMCPManager,
|
||||
} = require('~/config');
|
||||
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
||||
const {
|
||||
getServerConnectionStatus,
|
||||
resolveConfigServers,
|
||||
getMCPSetupData,
|
||||
} = require('~/server/services/MCP');
|
||||
const { requireJwtAuth, canAccessMCPServerResource } = require('~/server/middleware');
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
const { updateMCPServerTools } = require('~/server/services/Config/mcp');
|
||||
|
|
@ -101,7 +105,8 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async
|
|||
return res.status(400).json({ error: 'Invalid flow state' });
|
||||
}
|
||||
|
||||
const oauthHeaders = await getOAuthHeaders(serverName, userId);
|
||||
const configServers = await resolveConfigServers(req);
|
||||
const oauthHeaders = await getOAuthHeaders(serverName, userId, configServers);
|
||||
const {
|
||||
authorizationUrl,
|
||||
flowId: oauthFlowId,
|
||||
|
|
@ -233,7 +238,14 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
}
|
||||
|
||||
logger.debug('[MCP OAuth] Completing OAuth flow');
|
||||
const oauthHeaders = await getOAuthHeaders(serverName, flowState.userId);
|
||||
if (!flowState.oauthHeaders) {
|
||||
logger.warn(
|
||||
'[MCP OAuth] oauthHeaders absent from flow state — config-source server oauth_headers will be empty',
|
||||
{ serverName, flowId },
|
||||
);
|
||||
}
|
||||
const oauthHeaders =
|
||||
flowState.oauthHeaders ?? (await getOAuthHeaders(serverName, flowState.userId));
|
||||
const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager, oauthHeaders);
|
||||
logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route');
|
||||
|
||||
|
|
@ -497,7 +509,12 @@ router.post(
|
|||
logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`);
|
||||
|
||||
const mcpManager = getMCPManager();
|
||||
const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id);
|
||||
const configServers = await resolveConfigServers(req);
|
||||
const serverConfig = await getMCPServersRegistry().getServerConfig(
|
||||
serverName,
|
||||
user.id,
|
||||
configServers,
|
||||
);
|
||||
if (!serverConfig) {
|
||||
return res.status(404).json({
|
||||
error: `MCP server '${serverName}' not found in configuration`,
|
||||
|
|
@ -522,6 +539,8 @@ router.post(
|
|||
const result = await reinitMCPServer({
|
||||
user,
|
||||
serverName,
|
||||
serverConfig,
|
||||
configServers,
|
||||
userMCPAuthMap,
|
||||
});
|
||||
|
||||
|
|
@ -564,6 +583,7 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => {
|
|||
|
||||
const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData(
|
||||
user.id,
|
||||
{ role: user.role, tenantId: getTenantId() },
|
||||
);
|
||||
const connectionStatus = {};
|
||||
|
||||
|
|
@ -593,9 +613,6 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => {
|
|||
connectionStatus,
|
||||
});
|
||||
} catch (error) {
|
||||
if (error.message === 'MCP config not found') {
|
||||
return res.status(404).json({ error: error.message });
|
||||
}
|
||||
logger.error('[MCP Connection Status] Failed to get connection status', error);
|
||||
res.status(500).json({ error: 'Failed to get connection status' });
|
||||
}
|
||||
|
|
@ -616,6 +633,7 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) =>
|
|||
|
||||
const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData(
|
||||
user.id,
|
||||
{ role: user.role, tenantId: getTenantId() },
|
||||
);
|
||||
|
||||
if (!mcpConfig[serverName]) {
|
||||
|
|
@ -640,9 +658,6 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) =>
|
|||
requiresOAuth: serverStatus.requiresOAuth,
|
||||
});
|
||||
} catch (error) {
|
||||
if (error.message === 'MCP config not found') {
|
||||
return res.status(404).json({ error: error.message });
|
||||
}
|
||||
logger.error(
|
||||
`[MCP Per-Server Status] Failed to get connection status for ${req.params.serverName}`,
|
||||
error,
|
||||
|
|
@ -664,7 +679,12 @@ router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, a
|
|||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
|
||||
const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id);
|
||||
const configServers = await resolveConfigServers(req);
|
||||
const serverConfig = await getMCPServersRegistry().getServerConfig(
|
||||
serverName,
|
||||
user.id,
|
||||
configServers,
|
||||
);
|
||||
if (!serverConfig) {
|
||||
return res.status(404).json({
|
||||
error: `MCP server '${serverName}' not found in configuration`,
|
||||
|
|
@ -703,8 +723,12 @@ router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, a
|
|||
}
|
||||
});
|
||||
|
||||
async function getOAuthHeaders(serverName, userId) {
|
||||
const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, userId);
|
||||
/**
 * Resolves the `oauth_headers` configured for an MCP server, if any.
 * @param {string} serverName - Name of the MCP server in the registry.
 * @param {string} userId - ID of the user whose config scope applies.
 * @param {object} [configServers] - Config-sourced server definitions to resolve against.
 * @returns {Promise<object>} The configured OAuth headers, or an empty object when none exist.
 */
async function getOAuthHeaders(serverName, userId, configServers) {
  const registry = getMCPServersRegistry();
  const serverConfig = await registry.getServerConfig(serverName, userId, configServers);
  return serverConfig?.oauth_headers ?? {};
}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ const {
|
|||
checkEmailConfig,
|
||||
isEmailDomainAllowed,
|
||||
shouldUseSecureCookie,
|
||||
resolveAppConfigForUser,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
findUser,
|
||||
|
|
@ -189,7 +190,7 @@ const registerUser = async (user, additionalData = {}) => {
|
|||
|
||||
let newUserId;
|
||||
try {
|
||||
const appConfig = await getAppConfig();
|
||||
const appConfig = await getAppConfig({ baseOnly: true });
|
||||
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
const errorMessage =
|
||||
'The email address provided cannot be used. Please use a different email address.';
|
||||
|
|
@ -255,19 +256,52 @@ const registerUser = async (user, additionalData = {}) => {
|
|||
};
|
||||
|
||||
/**
|
||||
* Request password reset
|
||||
* Request password reset.
|
||||
*
|
||||
* Uses a two-phase domain check: fast-fail with the memory-cached base config
|
||||
* (zero DB queries) to block globally denied domains before user lookup, then
|
||||
* re-check with tenant-scoped config after user lookup so tenant-specific
|
||||
* restrictions are enforced.
|
||||
*
|
||||
* Phase 1 (base check) returns an Error (HTTP 400) — this intentionally reveals
|
||||
* that the domain is globally blocked, but fires before any DB lookup so it
|
||||
* cannot confirm user existence. Phase 2 (tenant check) returns the generic
|
||||
* success message (HTTP 200) to prevent user-enumeration via status codes.
|
||||
*
|
||||
* @param {ServerRequest} req
|
||||
*/
|
||||
const requestPasswordReset = async (req) => {
|
||||
const { email } = req.body;
|
||||
const appConfig = await getAppConfig();
|
||||
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
|
||||
const baseConfig = await getAppConfig({ baseOnly: true });
|
||||
if (!isEmailDomainAllowed(email, baseConfig?.registration?.allowedDomains)) {
|
||||
logger.warn(
|
||||
`[requestPasswordReset] Blocked - email domain not allowed [Email: ${email}] [IP: ${req.ip}]`,
|
||||
);
|
||||
const error = new Error(ErrorTypes.AUTH_FAILED);
|
||||
error.code = ErrorTypes.AUTH_FAILED;
|
||||
error.message = 'Email domain not allowed';
|
||||
return error;
|
||||
}
|
||||
const user = await findUser({ email }, 'email _id');
|
||||
|
||||
const user = await findUser({ email }, 'email _id role tenantId');
|
||||
let appConfig = baseConfig;
|
||||
if (user?.tenantId) {
|
||||
try {
|
||||
appConfig = await resolveAppConfigForUser(getAppConfig, user);
|
||||
} catch (err) {
|
||||
logger.error('[requestPasswordReset] Failed to resolve tenant config, using base:', err);
|
||||
}
|
||||
}
|
||||
|
||||
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
logger.warn(
|
||||
`[requestPasswordReset] Tenant config blocked domain [Email: ${email}] [IP: ${req.ip}]`,
|
||||
);
|
||||
return {
|
||||
message: 'If an account with that email exists, a password reset link has been sent to it.',
|
||||
};
|
||||
}
|
||||
const emailEnabled = checkEmailConfig();
|
||||
|
||||
logger.warn(`[requestPasswordReset] [Password reset request initiated] [Email: ${email}]`);
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ jest.mock('@librechat/api', () => ({
|
|||
isEmailDomainAllowed: jest.fn(),
|
||||
math: jest.fn((val, fallback) => (val ? Number(val) : fallback)),
|
||||
shouldUseSecureCookie: jest.fn(() => false),
|
||||
resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})),
|
||||
}));
|
||||
jest.mock('~/models', () => ({
|
||||
findUser: jest.fn(),
|
||||
|
|
@ -35,8 +36,14 @@ jest.mock('~/strategies/validators', () => ({ registerSchema: { parse: jest.fn()
|
|||
jest.mock('~/server/services/Config', () => ({ getAppConfig: jest.fn() }));
|
||||
jest.mock('~/server/utils', () => ({ sendEmail: jest.fn() }));
|
||||
|
||||
const { shouldUseSecureCookie } = require('@librechat/api');
|
||||
const { setOpenIDAuthTokens } = require('./AuthService');
|
||||
const {
|
||||
shouldUseSecureCookie,
|
||||
isEmailDomainAllowed,
|
||||
resolveAppConfigForUser,
|
||||
} = require('@librechat/api');
|
||||
const { findUser } = require('~/models');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { setOpenIDAuthTokens, requestPasswordReset } = require('./AuthService');
|
||||
|
||||
/** Helper to build a mock Express response */
|
||||
function mockResponse() {
|
||||
|
|
@ -267,3 +274,68 @@ describe('setOpenIDAuthTokens', () => {
|
|||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('requestPasswordReset', () => {
  // Shared fixture: a user bound to a tenant, triggering tenant-scoped config resolution.
  const tenantUser = {
    _id: 'user-tenant',
    email: 'user@example.com',
    tenantId: 'tenant-x',
    role: 'USER',
  };
  const makeReq = (email) => ({ body: { email }, ip: '127.0.0.1' });

  beforeEach(() => {
    jest.clearAllMocks();
    isEmailDomainAllowed.mockReturnValue(true);
    getAppConfig.mockResolvedValue({
      registration: { allowedDomains: ['example.com'] },
    });
    resolveAppConfigForUser.mockResolvedValue({
      registration: { allowedDomains: ['example.com'] },
    });
  });

  it('should fast-fail with base config before DB lookup for blocked domains', async () => {
    isEmailDomainAllowed.mockReturnValue(false);

    const result = await requestPasswordReset(makeReq('blocked@evil.com'));

    expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true });
    expect(findUser).not.toHaveBeenCalled();
    expect(result).toBeInstanceOf(Error);
  });

  it('should call resolveAppConfigForUser for tenant user', async () => {
    findUser.mockResolvedValue(tenantUser);

    await requestPasswordReset(makeReq('user@example.com'));

    expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, tenantUser);
  });

  it('should reuse baseConfig for non-tenant user without calling resolveAppConfigForUser', async () => {
    findUser.mockResolvedValue({ _id: 'user-no-tenant', email: 'user@example.com' });

    await requestPasswordReset(makeReq('user@example.com'));

    expect(resolveAppConfigForUser).not.toHaveBeenCalled();
  });

  it('should return generic response when tenant config blocks the domain (non-enumerable)', async () => {
    findUser.mockResolvedValue(tenantUser);
    // First check (base config) passes, second check (tenant config) blocks.
    isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false);

    const result = await requestPasswordReset(makeReq('user@example.com'));

    expect(result).not.toBeInstanceOf(Error);
    expect(result.message).toContain('If an account with that email exists');
  });
});
|
||||
|
|
|
|||
|
|
@ -0,0 +1,122 @@
|
|||
// ── Mocks ──────────────────────────────────────────────────────────────
|
||||
|
||||
const mockClearAppConfigCache = jest.fn().mockResolvedValue(undefined);
|
||||
const mockClearOverrideCache = jest.fn().mockResolvedValue(undefined);
|
||||
|
||||
jest.mock('~/cache/getLogStores', () => {
|
||||
return jest.fn(() => ({}));
|
||||
});
|
||||
|
||||
jest.mock('~/server/services/start/tools', () => ({
|
||||
loadAndFormatTools: jest.fn(() => ({})),
|
||||
}));
|
||||
|
||||
jest.mock('../loadCustomConfig', () => jest.fn().mockResolvedValue({}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => {
|
||||
const actual = jest.requireActual('@librechat/data-schemas');
|
||||
return { ...actual, AppService: jest.fn(() => ({ availableTools: {} })) };
|
||||
});
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
getApplicableConfigs: jest.fn().mockResolvedValue([]),
|
||||
getUserPrincipals: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
const mockInvalidateCachedTools = jest.fn().mockResolvedValue(undefined);
|
||||
jest.mock('../getCachedTools', () => ({
|
||||
setCachedTools: jest.fn().mockResolvedValue(undefined),
|
||||
invalidateCachedTools: mockInvalidateCachedTools,
|
||||
}));
|
||||
|
||||
const mockClearMcpConfigCache = jest.fn().mockResolvedValue(undefined);
|
||||
jest.mock('@librechat/api', () => ({
|
||||
createAppConfigService: jest.fn(() => ({
|
||||
getAppConfig: jest.fn().mockResolvedValue({ availableTools: {} }),
|
||||
clearAppConfigCache: mockClearAppConfigCache,
|
||||
clearOverrideCache: mockClearOverrideCache,
|
||||
})),
|
||||
clearMcpConfigCache: mockClearMcpConfigCache,
|
||||
}));
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────────
|
||||
|
||||
const { invalidateConfigCaches } = require('../app');
|
||||
|
||||
describe('invalidateConfigCaches', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('clears all caches', async () => {
|
||||
await invalidateConfigCaches();
|
||||
|
||||
expect(mockClearAppConfigCache).toHaveBeenCalledTimes(1);
|
||||
expect(mockClearOverrideCache).toHaveBeenCalledTimes(1);
|
||||
expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true });
|
||||
expect(mockClearMcpConfigCache).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('passes tenantId through to clearOverrideCache', async () => {
|
||||
await invalidateConfigCaches('tenant-a');
|
||||
|
||||
expect(mockClearOverrideCache).toHaveBeenCalledWith('tenant-a');
|
||||
expect(mockClearAppConfigCache).toHaveBeenCalledTimes(1);
|
||||
expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true });
|
||||
});
|
||||
|
||||
it('all operations run in parallel (not sequentially)', async () => {
|
||||
const order = [];
|
||||
|
||||
mockClearAppConfigCache.mockImplementation(
|
||||
() =>
|
||||
new Promise((r) =>
|
||||
setTimeout(() => {
|
||||
order.push('base');
|
||||
r();
|
||||
}, 10),
|
||||
),
|
||||
);
|
||||
mockClearOverrideCache.mockImplementation(
|
||||
() =>
|
||||
new Promise((r) =>
|
||||
setTimeout(() => {
|
||||
order.push('override');
|
||||
r();
|
||||
}, 10),
|
||||
),
|
||||
);
|
||||
mockInvalidateCachedTools.mockImplementation(
|
||||
() =>
|
||||
new Promise((r) =>
|
||||
setTimeout(() => {
|
||||
order.push('tools');
|
||||
r();
|
||||
}, 10),
|
||||
),
|
||||
);
|
||||
mockClearMcpConfigCache.mockImplementation(
|
||||
() =>
|
||||
new Promise((r) =>
|
||||
setTimeout(() => {
|
||||
order.push('mcp');
|
||||
r();
|
||||
}, 10),
|
||||
),
|
||||
);
|
||||
|
||||
await invalidateConfigCaches();
|
||||
|
||||
expect(order).toHaveLength(4);
|
||||
expect(new Set(order)).toEqual(new Set(['base', 'override', 'tools', 'mcp']));
|
||||
});
|
||||
|
||||
it('resolves even when clearAppConfigCache throws (partial failure)', async () => {
|
||||
mockClearAppConfigCache.mockRejectedValueOnce(new Error('cache connection lost'));
|
||||
|
||||
await expect(invalidateConfigCaches()).resolves.not.toThrow();
|
||||
|
||||
expect(mockClearOverrideCache).toHaveBeenCalledTimes(1);
|
||||
expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true });
|
||||
});
|
||||
});
|
||||
|
|
@ -1,12 +1,12 @@
|
|||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { logger, AppService } = require('@librechat/data-schemas');
|
||||
const { AppService, logger } = require('@librechat/data-schemas');
|
||||
const { createAppConfigService, clearMcpConfigCache } = require('@librechat/api');
|
||||
const { setCachedTools, invalidateCachedTools } = require('./getCachedTools');
|
||||
const { loadAndFormatTools } = require('~/server/services/start/tools');
|
||||
const loadCustomConfig = require('./loadCustomConfig');
|
||||
const { setCachedTools } = require('./getCachedTools');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const paths = require('~/config/paths');
|
||||
|
||||
const BASE_CONFIG_KEY = '_BASE_';
|
||||
const db = require('~/models');
|
||||
|
||||
const loadBaseConfig = async () => {
|
||||
/** @type {TCustomConfig} */
|
||||
|
|
@ -20,65 +20,43 @@ const loadBaseConfig = async () => {
|
|||
return AppService({ config, paths, systemTools });
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the app configuration based on user context
|
||||
* @param {Object} [options]
|
||||
* @param {string} [options.role] - User role for role-based config
|
||||
* @param {boolean} [options.refresh] - Force refresh the cache
|
||||
* @returns {Promise<AppConfig>}
|
||||
*/
|
||||
async function getAppConfig(options = {}) {
|
||||
const { role, refresh } = options;
|
||||
|
||||
const cache = getLogStores(CacheKeys.APP_CONFIG);
|
||||
const cacheKey = role ? role : BASE_CONFIG_KEY;
|
||||
|
||||
if (!refresh) {
|
||||
const cached = await cache.get(cacheKey);
|
||||
if (cached) {
|
||||
return cached;
|
||||
}
|
||||
}
|
||||
|
||||
let baseConfig = await cache.get(BASE_CONFIG_KEY);
|
||||
if (!baseConfig) {
|
||||
logger.info('[getAppConfig] App configuration not initialized. Initializing AppService...');
|
||||
baseConfig = await loadBaseConfig();
|
||||
|
||||
if (!baseConfig) {
|
||||
throw new Error('Failed to initialize app configuration through AppService.');
|
||||
}
|
||||
|
||||
if (baseConfig.availableTools) {
|
||||
await setCachedTools(baseConfig.availableTools);
|
||||
}
|
||||
|
||||
await cache.set(BASE_CONFIG_KEY, baseConfig);
|
||||
}
|
||||
|
||||
// For now, return the base config
|
||||
// In the future, this is where we'll apply role-based modifications
|
||||
if (role) {
|
||||
// TODO: Apply role-based config modifications
|
||||
// const roleConfig = await applyRoleBasedConfig(baseConfig, role);
|
||||
// await cache.set(cacheKey, roleConfig);
|
||||
// return roleConfig;
|
||||
}
|
||||
|
||||
return baseConfig;
|
||||
}
|
||||
const { getAppConfig, clearAppConfigCache, clearOverrideCache } = createAppConfigService({
|
||||
loadBaseConfig,
|
||||
setCachedTools,
|
||||
getCache: getLogStores,
|
||||
cacheKeys: CacheKeys,
|
||||
getApplicableConfigs: db.getApplicableConfigs,
|
||||
getUserPrincipals: db.getUserPrincipals,
|
||||
});
|
||||
|
||||
/**
|
||||
* Clear the app configuration cache
|
||||
* @returns {Promise<boolean>}
|
||||
* Invalidate all config-related caches after an admin config mutation.
|
||||
* Clears the base config, per-principal override caches, tool caches,
|
||||
* and the MCP config-source server cache.
|
||||
* @param {string} [tenantId] - Optional tenant ID to scope override cache clearing.
|
||||
*/
|
||||
async function clearAppConfigCache() {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cacheKey = CacheKeys.APP_CONFIG;
|
||||
return await cache.delete(cacheKey);
|
||||
async function invalidateConfigCaches(tenantId) {
|
||||
const results = await Promise.allSettled([
|
||||
clearAppConfigCache(),
|
||||
clearOverrideCache(tenantId),
|
||||
invalidateCachedTools({ invalidateGlobal: true }),
|
||||
clearMcpConfigCache(),
|
||||
]);
|
||||
const labels = [
|
||||
'clearAppConfigCache',
|
||||
'clearOverrideCache',
|
||||
'invalidateCachedTools',
|
||||
'clearMcpConfigCache',
|
||||
];
|
||||
for (let i = 0; i < results.length; i++) {
|
||||
if (results[i].status === 'rejected') {
|
||||
logger.error(`[invalidateConfigCaches] ${labels[i]} failed:`, results[i].reason);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getAppConfig,
|
||||
clearAppConfigCache,
|
||||
invalidateConfigCaches,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,133 +1,10 @@
|
|||
const { loadCustomEndpointsConfig } = require('@librechat/api');
|
||||
const {
|
||||
CacheKeys,
|
||||
EModelEndpoint,
|
||||
isAgentsEndpoint,
|
||||
orderEndpointsConfig,
|
||||
defaultAgentCapabilities,
|
||||
} = require('librechat-data-provider');
|
||||
const { createEndpointsConfigService } = require('@librechat/api');
|
||||
const loadDefaultEndpointsConfig = require('./loadDefaultEConfig');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { getAppConfig } = require('./app');
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {ServerRequest} req
|
||||
* @returns {Promise<TEndpointsConfig>}
|
||||
*/
|
||||
async function getEndpointsConfig(req) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
|
||||
if (cachedEndpointsConfig) {
|
||||
if (cachedEndpointsConfig.gptPlugins) {
|
||||
await cache.delete(CacheKeys.ENDPOINT_CONFIG);
|
||||
} else {
|
||||
return cachedEndpointsConfig;
|
||||
}
|
||||
}
|
||||
|
||||
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
|
||||
const defaultEndpointsConfig = await loadDefaultEndpointsConfig(appConfig);
|
||||
const customEndpointsConfig = loadCustomEndpointsConfig(appConfig?.endpoints?.custom);
|
||||
|
||||
/** @type {TEndpointsConfig} */
|
||||
const mergedConfig = {
|
||||
...defaultEndpointsConfig,
|
||||
...customEndpointsConfig,
|
||||
};
|
||||
|
||||
if (appConfig.endpoints?.[EModelEndpoint.azureOpenAI]) {
|
||||
/** @type {Omit<TConfig, 'order'>} */
|
||||
mergedConfig[EModelEndpoint.azureOpenAI] = {
|
||||
userProvide: false,
|
||||
};
|
||||
}
|
||||
|
||||
// Enable Anthropic endpoint when Vertex AI is configured in YAML
|
||||
if (appConfig.endpoints?.[EModelEndpoint.anthropic]?.vertexConfig?.enabled) {
|
||||
/** @type {Omit<TConfig, 'order'>} */
|
||||
mergedConfig[EModelEndpoint.anthropic] = {
|
||||
userProvide: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (appConfig.endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) {
|
||||
/** @type {Omit<TConfig, 'order'>} */
|
||||
mergedConfig[EModelEndpoint.azureAssistants] = {
|
||||
userProvide: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
mergedConfig[EModelEndpoint.assistants] &&
|
||||
appConfig?.endpoints?.[EModelEndpoint.assistants]
|
||||
) {
|
||||
const { disableBuilder, retrievalModels, capabilities, version, ..._rest } =
|
||||
appConfig.endpoints[EModelEndpoint.assistants];
|
||||
|
||||
mergedConfig[EModelEndpoint.assistants] = {
|
||||
...mergedConfig[EModelEndpoint.assistants],
|
||||
version,
|
||||
retrievalModels,
|
||||
disableBuilder,
|
||||
capabilities,
|
||||
};
|
||||
}
|
||||
if (mergedConfig[EModelEndpoint.agents] && appConfig?.endpoints?.[EModelEndpoint.agents]) {
|
||||
const { disableBuilder, capabilities, allowedProviders, ..._rest } =
|
||||
appConfig.endpoints[EModelEndpoint.agents];
|
||||
|
||||
mergedConfig[EModelEndpoint.agents] = {
|
||||
...mergedConfig[EModelEndpoint.agents],
|
||||
allowedProviders,
|
||||
disableBuilder,
|
||||
capabilities,
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
mergedConfig[EModelEndpoint.azureAssistants] &&
|
||||
appConfig?.endpoints?.[EModelEndpoint.azureAssistants]
|
||||
) {
|
||||
const { disableBuilder, retrievalModels, capabilities, version, ..._rest } =
|
||||
appConfig.endpoints[EModelEndpoint.azureAssistants];
|
||||
|
||||
mergedConfig[EModelEndpoint.azureAssistants] = {
|
||||
...mergedConfig[EModelEndpoint.azureAssistants],
|
||||
version,
|
||||
retrievalModels,
|
||||
disableBuilder,
|
||||
capabilities,
|
||||
};
|
||||
}
|
||||
|
||||
if (mergedConfig[EModelEndpoint.bedrock] && appConfig?.endpoints?.[EModelEndpoint.bedrock]) {
|
||||
const { availableRegions } = appConfig.endpoints[EModelEndpoint.bedrock];
|
||||
mergedConfig[EModelEndpoint.bedrock] = {
|
||||
...mergedConfig[EModelEndpoint.bedrock],
|
||||
availableRegions,
|
||||
};
|
||||
}
|
||||
|
||||
const endpointsConfig = orderEndpointsConfig(mergedConfig);
|
||||
|
||||
await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig);
|
||||
return endpointsConfig;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {ServerRequest} req
|
||||
* @param {import('librechat-data-provider').AgentCapabilities} capability
|
||||
* @returns {Promise<boolean>}
|
||||
*/
|
||||
const checkCapability = async (req, capability) => {
|
||||
const isAgents = isAgentsEndpoint(req.body?.endpointType || req.body?.endpoint);
|
||||
const endpointsConfig = await getEndpointsConfig(req);
|
||||
const capabilities =
|
||||
isAgents || endpointsConfig?.[EModelEndpoint.agents]?.capabilities != null
|
||||
? (endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [])
|
||||
: defaultAgentCapabilities;
|
||||
return capabilities.includes(capability);
|
||||
};
|
||||
const { getEndpointsConfig, checkCapability } = createEndpointsConfigService({
|
||||
getAppConfig,
|
||||
loadDefaultEndpointsConfig,
|
||||
});
|
||||
|
||||
module.exports = { getEndpointsConfig, checkCapability };
|
||||
|
|
|
|||
|
|
@ -1,117 +1,11 @@
|
|||
const { isUserProvided, fetchModels } = require('@librechat/api');
|
||||
const {
|
||||
EModelEndpoint,
|
||||
extractEnvVariable,
|
||||
normalizeEndpointName,
|
||||
} = require('librechat-data-provider');
|
||||
const { createLoadConfigModels, fetchModels } = require('@librechat/api');
|
||||
const { getAppConfig } = require('./app');
|
||||
const db = require('~/models');
|
||||
|
||||
/**
|
||||
* Load config endpoints from the cached configuration object
|
||||
* @function loadConfigModels
|
||||
* @param {ServerRequest} req - The Express request object.
|
||||
*/
|
||||
async function loadConfigModels(req) {
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
if (!appConfig) {
|
||||
return {};
|
||||
}
|
||||
const modelsConfig = {};
|
||||
const azureConfig = appConfig.endpoints?.[EModelEndpoint.azureOpenAI];
|
||||
const { modelNames } = azureConfig ?? {};
|
||||
|
||||
if (modelNames && azureConfig) {
|
||||
modelsConfig[EModelEndpoint.azureOpenAI] = modelNames;
|
||||
}
|
||||
|
||||
if (azureConfig?.assistants && azureConfig.assistantModels) {
|
||||
modelsConfig[EModelEndpoint.azureAssistants] = azureConfig.assistantModels;
|
||||
}
|
||||
|
||||
const bedrockConfig = appConfig.endpoints?.[EModelEndpoint.bedrock];
|
||||
if (bedrockConfig?.models && Array.isArray(bedrockConfig.models)) {
|
||||
modelsConfig[EModelEndpoint.bedrock] = bedrockConfig.models;
|
||||
}
|
||||
|
||||
if (!Array.isArray(appConfig.endpoints?.[EModelEndpoint.custom])) {
|
||||
return modelsConfig;
|
||||
}
|
||||
|
||||
const customEndpoints = appConfig.endpoints[EModelEndpoint.custom].filter(
|
||||
(endpoint) =>
|
||||
endpoint.baseURL &&
|
||||
endpoint.apiKey &&
|
||||
endpoint.name &&
|
||||
endpoint.models &&
|
||||
(endpoint.models.fetch || endpoint.models.default),
|
||||
);
|
||||
|
||||
/**
|
||||
* @type {Record<string, Promise<string[]>>}
|
||||
* Map for promises keyed by unique combination of baseURL and apiKey */
|
||||
const fetchPromisesMap = {};
|
||||
/**
|
||||
* @type {Record<string, string[]>}
|
||||
* Map to associate unique keys with endpoint names; note: one key may can correspond to multiple endpoints */
|
||||
const uniqueKeyToEndpointsMap = {};
|
||||
/**
|
||||
* @type {Record<string, Partial<TEndpoint>>}
|
||||
* Map to associate endpoint names to their configurations */
|
||||
const endpointsMap = {};
|
||||
|
||||
for (let i = 0; i < customEndpoints.length; i++) {
|
||||
const endpoint = customEndpoints[i];
|
||||
const { models, name: configName, baseURL, apiKey, headers: endpointHeaders } = endpoint;
|
||||
const name = normalizeEndpointName(configName);
|
||||
endpointsMap[name] = endpoint;
|
||||
|
||||
const API_KEY = extractEnvVariable(apiKey);
|
||||
const BASE_URL = extractEnvVariable(baseURL);
|
||||
|
||||
const uniqueKey = `${BASE_URL}__${API_KEY}`;
|
||||
|
||||
modelsConfig[name] = [];
|
||||
|
||||
if (models.fetch && !isUserProvided(API_KEY) && !isUserProvided(BASE_URL)) {
|
||||
fetchPromisesMap[uniqueKey] =
|
||||
fetchPromisesMap[uniqueKey] ||
|
||||
fetchModels({
|
||||
name,
|
||||
apiKey: API_KEY,
|
||||
baseURL: BASE_URL,
|
||||
user: req.user.id,
|
||||
userObject: req.user,
|
||||
headers: endpointHeaders,
|
||||
direct: endpoint.directEndpoint,
|
||||
userIdQuery: models.userIdQuery,
|
||||
});
|
||||
uniqueKeyToEndpointsMap[uniqueKey] = uniqueKeyToEndpointsMap[uniqueKey] || [];
|
||||
uniqueKeyToEndpointsMap[uniqueKey].push(name);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (Array.isArray(models.default)) {
|
||||
modelsConfig[name] = models.default.map((model) =>
|
||||
typeof model === 'string' ? model : model.name,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const fetchedData = await Promise.all(Object.values(fetchPromisesMap));
|
||||
const uniqueKeys = Object.keys(fetchPromisesMap);
|
||||
|
||||
for (let i = 0; i < fetchedData.length; i++) {
|
||||
const currentKey = uniqueKeys[i];
|
||||
const modelData = fetchedData[i];
|
||||
const associatedNames = uniqueKeyToEndpointsMap[currentKey];
|
||||
|
||||
for (const name of associatedNames) {
|
||||
const endpoint = endpointsMap[name];
|
||||
modelsConfig[name] = !modelData?.length ? (endpoint.models.default ?? []) : modelData;
|
||||
}
|
||||
}
|
||||
|
||||
return modelsConfig;
|
||||
}
|
||||
const loadConfigModels = createLoadConfigModels({
|
||||
getAppConfig,
|
||||
getUserKeyValues: db.getUserKeyValues,
|
||||
fetchModels,
|
||||
});
|
||||
|
||||
module.exports = loadConfigModels;
|
||||
|
|
|
|||
|
|
@ -7,6 +7,13 @@ jest.mock('@librechat/api', () => ({
|
|||
fetchModels: jest.fn(),
|
||||
}));
|
||||
jest.mock('./app');
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
...jest.requireActual('@librechat/data-schemas'),
|
||||
logger: { debug: jest.fn(), error: jest.fn(), warn: jest.fn() },
|
||||
}));
|
||||
jest.mock('~/models', () => ({
|
||||
getUserKeyValues: jest.fn(),
|
||||
}));
|
||||
|
||||
const exampleConfig = {
|
||||
endpoints: {
|
||||
|
|
@ -68,11 +75,11 @@ describe('loadConfigModels', () => {
|
|||
const originalEnv = process.env;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.resetAllMocks();
|
||||
jest.resetModules();
|
||||
jest.clearAllMocks();
|
||||
fetchModels.mockReset();
|
||||
require('~/models').getUserKeyValues.mockReset();
|
||||
process.env = { ...originalEnv };
|
||||
|
||||
// Default mock for getAppConfig
|
||||
getAppConfig.mockResolvedValue({});
|
||||
});
|
||||
|
||||
|
|
@ -337,6 +344,168 @@ describe('loadConfigModels', () => {
|
|||
expect(result.FalsyFetchModel).toEqual(['defaultModel1', 'defaultModel2']);
|
||||
});
|
||||
|
||||
describe('user-provided API key model fetching', () => {
|
||||
it('fetches models using user-provided API key when key is stored', async () => {
|
||||
const { getUserKeyValues } = require('~/models');
|
||||
getUserKeyValues.mockResolvedValueOnce({
|
||||
apiKey: 'sk-user-key',
|
||||
baseURL: 'https://api.x.com/v1',
|
||||
});
|
||||
getAppConfig.mockResolvedValue({
|
||||
endpoints: {
|
||||
custom: [
|
||||
{
|
||||
name: 'UserEndpoint',
|
||||
apiKey: 'user_provided',
|
||||
baseURL: 'user_provided',
|
||||
models: { fetch: true, default: ['fallback-model'] },
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
fetchModels.mockResolvedValue(['fetched-model-a', 'fetched-model-b']);
|
||||
|
||||
const result = await loadConfigModels(mockRequest);
|
||||
|
||||
expect(getUserKeyValues).toHaveBeenCalledWith({ userId: 'testUserId', name: 'UserEndpoint' });
|
||||
expect(fetchModels).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
apiKey: 'sk-user-key',
|
||||
baseURL: 'https://api.x.com/v1',
|
||||
skipCache: true,
|
||||
}),
|
||||
);
|
||||
expect(result.UserEndpoint).toEqual(['fetched-model-a', 'fetched-model-b']);
|
||||
});
|
||||
|
||||
it('falls back to defaults when getUserKeyValues returns no apiKey', async () => {
|
||||
const { getUserKeyValues } = require('~/models');
|
||||
getUserKeyValues.mockResolvedValueOnce({ baseURL: 'https://api.x.com/v1' });
|
||||
getAppConfig.mockResolvedValue({
|
||||
endpoints: {
|
||||
custom: [
|
||||
{
|
||||
name: 'NoKeyEndpoint',
|
||||
apiKey: 'user_provided',
|
||||
baseURL: 'https://api.x.com/v1',
|
||||
models: { fetch: true, default: ['default-model'] },
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
const result = await loadConfigModels(mockRequest);
|
||||
|
||||
expect(fetchModels).not.toHaveBeenCalled();
|
||||
expect(result.NoKeyEndpoint).toEqual(['default-model']);
|
||||
});
|
||||
|
||||
it('falls back to defaults and logs warn when getUserKeyValues throws infra error', async () => {
|
||||
const { getUserKeyValues } = require('~/models');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
getUserKeyValues.mockRejectedValueOnce(new Error('DB connection timeout'));
|
||||
getAppConfig.mockResolvedValue({
|
||||
endpoints: {
|
||||
custom: [
|
||||
{
|
||||
name: 'ErrorEndpoint',
|
||||
apiKey: 'user_provided',
|
||||
baseURL: 'https://api.example.com/v1',
|
||||
models: { fetch: true, default: ['fallback'] },
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
const result = await loadConfigModels(mockRequest);
|
||||
|
||||
expect(fetchModels).not.toHaveBeenCalled();
|
||||
expect(result.ErrorEndpoint).toEqual(['fallback']);
|
||||
expect(logger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
'Failed to retrieve user key for "ErrorEndpoint": DB connection timeout',
|
||||
),
|
||||
);
|
||||
expect(logger.debug).not.toHaveBeenCalledWith(expect.stringContaining('No user key stored'));
|
||||
});
|
||||
|
||||
it('logs debug (not warn) for NO_USER_KEY errors', async () => {
|
||||
const { getUserKeyValues } = require('~/models');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
getUserKeyValues.mockRejectedValueOnce(new Error(JSON.stringify({ type: 'no_user_key' })));
|
||||
getAppConfig.mockResolvedValue({
|
||||
endpoints: {
|
||||
custom: [
|
||||
{
|
||||
name: 'MissingKeyEndpoint',
|
||||
apiKey: 'user_provided',
|
||||
baseURL: 'https://api.example.com/v1',
|
||||
models: { fetch: true, default: ['default-model'] },
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
const result = await loadConfigModels(mockRequest);
|
||||
|
||||
expect(result.MissingKeyEndpoint).toEqual(['default-model']);
|
||||
expect(logger.debug).toHaveBeenCalledWith(expect.stringContaining('No user key stored'));
|
||||
expect(logger.warn).not.toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to retrieve user key'),
|
||||
);
|
||||
});
|
||||
|
||||
it('skips user key lookup when req.user.id is undefined', async () => {
|
||||
const { getUserKeyValues } = require('~/models');
|
||||
getAppConfig.mockResolvedValue({
|
||||
endpoints: {
|
||||
custom: [
|
||||
{
|
||||
name: 'NoUserEndpoint',
|
||||
apiKey: 'user_provided',
|
||||
baseURL: 'https://api.x.com/v1',
|
||||
models: { fetch: true, default: ['anon-model'] },
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
const result = await loadConfigModels({ user: {} });
|
||||
|
||||
expect(getUserKeyValues).not.toHaveBeenCalled();
|
||||
expect(result.NoUserEndpoint).toEqual(['anon-model']);
|
||||
});
|
||||
|
||||
it('uses stored baseURL only when baseURL is user_provided', async () => {
|
||||
const { getUserKeyValues } = require('~/models');
|
||||
getUserKeyValues.mockResolvedValueOnce({ apiKey: 'sk-key' });
|
||||
getAppConfig.mockResolvedValue({
|
||||
endpoints: {
|
||||
custom: [
|
||||
{
|
||||
name: 'KeyOnly',
|
||||
apiKey: 'user_provided',
|
||||
baseURL: 'https://fixed-base.com/v1',
|
||||
models: { fetch: true, default: ['default'] },
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
fetchModels.mockResolvedValue(['model-from-fixed-base']);
|
||||
|
||||
const result = await loadConfigModels(mockRequest);
|
||||
|
||||
expect(fetchModels).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
apiKey: 'sk-key',
|
||||
baseURL: 'https://fixed-base.com/v1',
|
||||
skipCache: true,
|
||||
}),
|
||||
);
|
||||
expect(result.KeyOnly).toEqual(['model-from-fixed-base']);
|
||||
});
|
||||
});
|
||||
|
||||
it('normalizes Ollama endpoint name to lowercase', async () => {
|
||||
const testCases = [
|
||||
{
|
||||
|
|
|
|||
|
|
@ -16,7 +16,8 @@ const { getAppConfig } = require('./app');
|
|||
*/
|
||||
async function loadDefaultModels(req) {
|
||||
try {
|
||||
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
|
||||
const appConfig =
|
||||
req.config ?? (await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }));
|
||||
const vertexConfig = appConfig?.endpoints?.[EModelEndpoint.anthropic]?.vertexConfig;
|
||||
|
||||
const [openAI, anthropic, azureOpenAI, assistants, azureAssistants, google, bedrock] =
|
||||
|
|
|
|||
|
|
@ -1,97 +1,10 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { createMCPToolCacheService } = require('@librechat/api');
|
||||
const { getCachedTools, setCachedTools } = require('./getCachedTools');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
* Updates MCP tools in the cache for a specific server
|
||||
* @param {Object} params - Parameters for updating MCP tools
|
||||
* @param {string} params.userId - User ID for user-specific caching
|
||||
* @param {string} params.serverName - MCP server name
|
||||
* @param {Array} params.tools - Array of tool objects from MCP server
|
||||
* @returns {Promise<LCAvailableTools>}
|
||||
*/
|
||||
async function updateMCPServerTools({ userId, serverName, tools }) {
|
||||
try {
|
||||
const serverTools = {};
|
||||
const mcpDelimiter = Constants.mcp_delimiter;
|
||||
|
||||
if (tools == null || tools.length === 0) {
|
||||
logger.debug(`[MCP Cache] No tools to update for server ${serverName} (user: ${userId})`);
|
||||
return serverTools;
|
||||
}
|
||||
|
||||
for (const tool of tools) {
|
||||
const name = `${tool.name}${mcpDelimiter}${serverName}`;
|
||||
serverTools[name] = {
|
||||
type: 'function',
|
||||
['function']: {
|
||||
name,
|
||||
description: tool.description,
|
||||
parameters: tool.inputSchema,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
await setCachedTools(serverTools, { userId, serverName });
|
||||
|
||||
const cache = getLogStores(CacheKeys.TOOL_CACHE);
|
||||
await cache.delete(CacheKeys.TOOLS);
|
||||
logger.debug(
|
||||
`[MCP Cache] Updated ${tools.length} tools for server ${serverName} (user: ${userId})`,
|
||||
);
|
||||
return serverTools;
|
||||
} catch (error) {
|
||||
logger.error(`[MCP Cache] Failed to update tools for ${serverName} (user: ${userId}):`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Merges app-level tools with global tools
|
||||
* @param {import('@librechat/api').LCAvailableTools} appTools
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function mergeAppTools(appTools) {
|
||||
try {
|
||||
const count = Object.keys(appTools).length;
|
||||
if (!count) {
|
||||
return;
|
||||
}
|
||||
const cachedTools = await getCachedTools();
|
||||
const mergedTools = { ...cachedTools, ...appTools };
|
||||
await setCachedTools(mergedTools);
|
||||
const cache = getLogStores(CacheKeys.TOOL_CACHE);
|
||||
await cache.delete(CacheKeys.TOOLS);
|
||||
logger.debug(`Merged ${count} app-level tools`);
|
||||
} catch (error) {
|
||||
logger.error('Failed to merge app-level tools:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Caches MCP server tools (no longer merges with global)
|
||||
* @param {object} params
|
||||
* @param {string} params.userId - User ID for user-specific caching
|
||||
* @param {string} params.serverName
|
||||
* @param {import('@librechat/api').LCAvailableTools} params.serverTools
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function cacheMCPServerTools({ userId, serverName, serverTools }) {
|
||||
try {
|
||||
const count = Object.keys(serverTools).length;
|
||||
if (!count) {
|
||||
return;
|
||||
}
|
||||
// Only cache server-specific tools, no merging with global
|
||||
await setCachedTools(serverTools, { userId, serverName });
|
||||
logger.debug(`Cached ${count} MCP server tools for ${serverName} (user: ${userId})`);
|
||||
} catch (error) {
|
||||
logger.error(`Failed to cache MCP server tools for ${serverName} (user: ${userId}):`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
const { mergeAppTools, cacheMCPServerTools, updateMCPServerTools } = createMCPToolCacheService({
|
||||
getCachedTools,
|
||||
setCachedTools,
|
||||
});
|
||||
|
||||
module.exports = {
|
||||
mergeAppTools,
|
||||
|
|
|
|||
|
|
@ -142,6 +142,7 @@ class STTService {
|
|||
req.config ??
|
||||
(await getAppConfig({
|
||||
role: req?.user?.role,
|
||||
tenantId: req?.user?.tenantId,
|
||||
}));
|
||||
const sttSchema = appConfig?.speech?.stt;
|
||||
if (!sttSchema) {
|
||||
|
|
|
|||
|
|
@ -297,6 +297,7 @@ class TTSService {
|
|||
req.config ??
|
||||
(await getAppConfig({
|
||||
role: req.user?.role,
|
||||
tenantId: req.user?.tenantId,
|
||||
}));
|
||||
try {
|
||||
res.setHeader('Content-Type', 'audio/mpeg');
|
||||
|
|
@ -365,6 +366,7 @@ class TTSService {
|
|||
req.config ??
|
||||
(await getAppConfig({
|
||||
role: req.user?.role,
|
||||
tenantId: req.user?.tenantId,
|
||||
}));
|
||||
const provider = this.getProvider(appConfig);
|
||||
const ttsSchema = appConfig?.speech?.tts?.[provider];
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ async function getCustomConfigSpeech(req, res) {
|
|||
try {
|
||||
const appConfig = await getAppConfig({
|
||||
role: req.user?.role,
|
||||
tenantId: req.user?.tenantId,
|
||||
});
|
||||
|
||||
if (!appConfig) {
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ async function getVoices(req, res) {
|
|||
req.config ??
|
||||
(await getAppConfig({
|
||||
role: req.user?.role,
|
||||
tenantId: req.user?.tenantId,
|
||||
}));
|
||||
|
||||
const ttsSchema = appConfig?.speech?.tts;
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
const { scopedCacheKey } = require('@librechat/data-schemas');
|
||||
const {
|
||||
Time,
|
||||
CacheKeys,
|
||||
|
|
@ -67,6 +68,8 @@ function createChunkProcessor(user, messageId) {
|
|||
}
|
||||
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
// Captured at creation time — must be called within an active request ALS scope
|
||||
const cacheKey = scopedCacheKey(messageId);
|
||||
|
||||
/**
|
||||
* @returns {Promise<{ text: string, isFinished: boolean }[] | string>}
|
||||
|
|
@ -81,7 +84,7 @@ function createChunkProcessor(user, messageId) {
|
|||
}
|
||||
|
||||
/** @type { string | { text: string; complete: boolean } } */
|
||||
let message = await messageCache.get(messageId);
|
||||
let message = await messageCache.get(cacheKey);
|
||||
if (!message) {
|
||||
message = await getMessage({ user, messageId });
|
||||
}
|
||||
|
|
@ -92,7 +95,7 @@ function createChunkProcessor(user, messageId) {
|
|||
} else {
|
||||
const text = message.content?.length > 0 ? parseTextParts(message.content) : message.text;
|
||||
messageCache.set(
|
||||
messageId,
|
||||
cacheKey,
|
||||
{
|
||||
text,
|
||||
complete: true,
|
||||
|
|
|
|||
|
|
@ -47,7 +47,10 @@ async function processFileCitations({ user, appConfig, toolArtifact, toolCallId,
|
|||
logger.error(
|
||||
`[processFileCitations] Permission check failed for FILE_CITATIONS: ${error.message}`,
|
||||
);
|
||||
logger.debug(`[processFileCitations] Proceeding with citations due to permission error`);
|
||||
logger.warn(
|
||||
'[processFileCitations] Returning null citations due to permission check error — citations will not be shown for this message',
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -145,6 +148,8 @@ async function enhanceSourcesWithMetadata(sources, appConfig) {
|
|||
metadata: {
|
||||
...source.metadata,
|
||||
storageType: configuredStorageType,
|
||||
fileType: fileRecord.type || undefined,
|
||||
fileBytes: fileRecord.bytes || undefined,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const { tool } = require('@langchain/core/tools');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { logger, getTenantId } = require('@librechat/data-schemas');
|
||||
const {
|
||||
Providers,
|
||||
StepTypes,
|
||||
|
|
@ -14,6 +14,7 @@ const {
|
|||
normalizeJsonSchema,
|
||||
GenerationJobManager,
|
||||
resolveJsonSchemaRefs,
|
||||
buildOAuthToolCallName,
|
||||
} = require('@librechat/api');
|
||||
const { Time, CacheKeys, Constants, isAssistantsEndpoint } = require('librechat-data-provider');
|
||||
const {
|
||||
|
|
@ -53,6 +54,53 @@ function evictStale(map, ttl) {
|
|||
const unavailableMsg =
|
||||
"This tool's MCP server is temporarily unavailable. Please try again shortly.";
|
||||
|
||||
/**
|
||||
* Resolves config-source MCP servers from admin Config overrides for the current
|
||||
* request context. Returns the parsed configs keyed by server name.
|
||||
* @param {import('express').Request} req - Express request with user context
|
||||
* @returns {Promise<Record<string, import('@librechat/api').ParsedServerConfig>>}
|
||||
*/
|
||||
async function resolveConfigServers(req) {
|
||||
try {
|
||||
const registry = getMCPServersRegistry();
|
||||
const user = req?.user;
|
||||
const appConfig = await getAppConfig({
|
||||
role: user?.role,
|
||||
tenantId: getTenantId(),
|
||||
userId: user?.id,
|
||||
});
|
||||
return await registry.ensureConfigServers(appConfig?.mcpConfig || {});
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
'[resolveConfigServers] Failed to resolve config servers, degrading to empty:',
|
||||
error,
|
||||
);
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves config-source servers and merges all server configs (YAML + config + user DB)
|
||||
* for the given user context. Shared helper for controllers needing the full merged config.
|
||||
* @param {string} userId
|
||||
* @param {{ id?: string, role?: string }} [user]
|
||||
* @returns {Promise<Record<string, import('@librechat/api').ParsedServerConfig>>}
|
||||
*/
|
||||
async function resolveAllMcpConfigs(userId, user) {
|
||||
const registry = getMCPServersRegistry();
|
||||
const appConfig = await getAppConfig({ role: user?.role, tenantId: getTenantId(), userId });
|
||||
let configServers = {};
|
||||
try {
|
||||
configServers = await registry.ensureConfigServers(appConfig?.mcpConfig || {});
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
'[resolveAllMcpConfigs] Config server resolution failed, continuing without:',
|
||||
error,
|
||||
);
|
||||
}
|
||||
return await registry.getAllServerConfigs(userId, configServers);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {string} toolName
|
||||
* @param {string} serverName
|
||||
|
|
@ -248,6 +296,7 @@ async function reconnectServer({
|
|||
index,
|
||||
signal,
|
||||
serverName,
|
||||
configServers,
|
||||
userMCPAuthMap,
|
||||
streamId = null,
|
||||
}) {
|
||||
|
|
@ -271,7 +320,7 @@ async function reconnectServer({
|
|||
const stepId = 'step_oauth_login_' + serverName;
|
||||
const toolCall = {
|
||||
id: flowId,
|
||||
name: serverName,
|
||||
name: buildOAuthToolCallName(serverName),
|
||||
type: 'tool_call_chunk',
|
||||
};
|
||||
|
||||
|
|
@ -316,6 +365,7 @@ async function reconnectServer({
|
|||
user,
|
||||
signal,
|
||||
serverName,
|
||||
configServers,
|
||||
oauthStart,
|
||||
flowManager,
|
||||
userMCPAuthMap,
|
||||
|
|
@ -358,15 +408,14 @@ async function createMCPTools({
|
|||
config,
|
||||
provider,
|
||||
serverName,
|
||||
configServers,
|
||||
userMCPAuthMap,
|
||||
streamId = null,
|
||||
}) {
|
||||
// Early domain validation before reconnecting server (avoid wasted work on disallowed domains)
|
||||
// Use getAppConfig() to support per-user/role domain restrictions
|
||||
const serverConfig =
|
||||
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
|
||||
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id, configServers));
|
||||
if (serverConfig?.url) {
|
||||
const appConfig = await getAppConfig({ role: user?.role });
|
||||
const appConfig = await getAppConfig({ role: user?.role, tenantId: user?.tenantId });
|
||||
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
|
||||
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
|
||||
if (!isDomainAllowed) {
|
||||
|
|
@ -381,6 +430,7 @@ async function createMCPTools({
|
|||
index,
|
||||
signal,
|
||||
serverName,
|
||||
configServers,
|
||||
userMCPAuthMap,
|
||||
streamId,
|
||||
});
|
||||
|
|
@ -400,6 +450,7 @@ async function createMCPTools({
|
|||
user,
|
||||
provider,
|
||||
userMCPAuthMap,
|
||||
configServers,
|
||||
streamId,
|
||||
availableTools: result.availableTools,
|
||||
toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`,
|
||||
|
|
@ -439,16 +490,15 @@ async function createMCPTool({
|
|||
userMCPAuthMap,
|
||||
availableTools,
|
||||
config,
|
||||
configServers,
|
||||
streamId = null,
|
||||
}) {
|
||||
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
|
||||
|
||||
// Runtime domain validation: check if the server's domain is still allowed
|
||||
// Use getAppConfig() to support per-user/role domain restrictions
|
||||
const serverConfig =
|
||||
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
|
||||
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id, configServers));
|
||||
if (serverConfig?.url) {
|
||||
const appConfig = await getAppConfig({ role: user?.role });
|
||||
const appConfig = await getAppConfig({ role: user?.role, tenantId: user?.tenantId });
|
||||
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
|
||||
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
|
||||
if (!isDomainAllowed) {
|
||||
|
|
@ -477,6 +527,7 @@ async function createMCPTool({
|
|||
index,
|
||||
signal,
|
||||
serverName,
|
||||
configServers,
|
||||
userMCPAuthMap,
|
||||
streamId,
|
||||
});
|
||||
|
|
@ -500,6 +551,7 @@ async function createMCPTool({
|
|||
provider,
|
||||
toolName,
|
||||
serverName,
|
||||
serverConfig,
|
||||
toolDefinition,
|
||||
streamId,
|
||||
});
|
||||
|
|
@ -509,13 +561,14 @@ function createToolInstance({
|
|||
res,
|
||||
toolName,
|
||||
serverName,
|
||||
serverConfig: capturedServerConfig,
|
||||
toolDefinition,
|
||||
provider: _provider,
|
||||
provider: capturedProvider,
|
||||
streamId = null,
|
||||
}) {
|
||||
/** @type {LCTool} */
|
||||
const { description, parameters } = toolDefinition;
|
||||
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
|
||||
const isGoogle = capturedProvider === Providers.VERTEXAI || capturedProvider === Providers.GOOGLE;
|
||||
|
||||
let schema = parameters ? normalizeJsonSchema(resolveJsonSchemaRefs(parameters)) : null;
|
||||
|
||||
|
|
@ -544,7 +597,7 @@ function createToolInstance({
|
|||
const flowManager = getFlowStateManager(flowsCache);
|
||||
derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined;
|
||||
const mcpManager = getMCPManager(userId);
|
||||
const provider = (config?.metadata?.provider || _provider)?.toLowerCase();
|
||||
const provider = (config?.metadata?.provider || capturedProvider)?.toLowerCase();
|
||||
|
||||
const { args: _args, stepId, ...toolCall } = config.toolCall ?? {};
|
||||
const flowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`;
|
||||
|
|
@ -576,6 +629,7 @@ function createToolInstance({
|
|||
|
||||
const result = await mcpManager.callTool({
|
||||
serverName,
|
||||
serverConfig: capturedServerConfig,
|
||||
toolName,
|
||||
provider,
|
||||
toolArguments,
|
||||
|
|
@ -643,30 +697,36 @@ function createToolInstance({
|
|||
}
|
||||
|
||||
/**
|
||||
* Get MCP setup data including config, connections, and OAuth servers
|
||||
* Get MCP setup data including config, connections, and OAuth servers.
|
||||
* Resolves config-source servers from admin Config overrides when tenant context is available.
|
||||
* @param {string} userId - The user ID
|
||||
* @param {{ role?: string, tenantId?: string }} [options] - Optional role/tenant context
|
||||
* @returns {Object} Object containing mcpConfig, appConnections, userConnections, and oauthServers
|
||||
*/
|
||||
async function getMCPSetupData(userId) {
|
||||
const mcpConfig = await getMCPServersRegistry().getAllServerConfigs(userId);
|
||||
|
||||
if (!mcpConfig) {
|
||||
throw new Error('MCP config not found');
|
||||
}
|
||||
async function getMCPSetupData(userId, options = {}) {
|
||||
const registry = getMCPServersRegistry();
|
||||
const { role, tenantId } = options;
|
||||
|
||||
const appConfig = await getAppConfig({ role, tenantId, userId });
|
||||
const configServers = await registry.ensureConfigServers(appConfig?.mcpConfig || {});
|
||||
const mcpConfig = await registry.getAllServerConfigs(userId, configServers);
|
||||
const mcpManager = getMCPManager(userId);
|
||||
/** @type {Map<string, import('@librechat/api').MCPConnection>} */
|
||||
let appConnections = new Map();
|
||||
try {
|
||||
// Use getLoaded() instead of getAll() to avoid forcing connection creation
|
||||
// Use getLoaded() instead of getAll() to avoid forcing connection creation.
|
||||
// getAll() creates connections for all servers, which is problematic for servers
|
||||
// that require user context (e.g., those with {{LIBRECHAT_USER_ID}} placeholders)
|
||||
// that require user context (e.g., those with {{LIBRECHAT_USER_ID}} placeholders).
|
||||
appConnections = (await mcpManager.appConnections?.getLoaded()) || new Map();
|
||||
} catch (error) {
|
||||
logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error);
|
||||
}
|
||||
const userConnections = mcpManager.getUserConnections(userId) || new Map();
|
||||
const oauthServers = await getMCPServersRegistry().getOAuthServers(userId);
|
||||
const oauthServers = new Set(
|
||||
Object.entries(mcpConfig)
|
||||
.filter(([, config]) => config.requiresOAuth)
|
||||
.map(([name]) => name),
|
||||
);
|
||||
|
||||
return {
|
||||
mcpConfig,
|
||||
|
|
@ -788,6 +848,8 @@ module.exports = {
|
|||
createMCPTool,
|
||||
createMCPTools,
|
||||
getMCPSetupData,
|
||||
resolveConfigServers,
|
||||
resolveAllMcpConfigs,
|
||||
checkOAuthFlowStatus,
|
||||
getServerConnectionStatus,
|
||||
createUnavailableToolStub,
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ const mockRegistryInstance = {
|
|||
getOAuthServers: jest.fn(() => Promise.resolve(new Set())),
|
||||
getAllServerConfigs: jest.fn(() => Promise.resolve({})),
|
||||
getServerConfig: jest.fn(() => Promise.resolve(null)),
|
||||
ensureConfigServers: jest.fn(() => Promise.resolve({})),
|
||||
};
|
||||
|
||||
// Create isMCPDomainAllowed mock that can be configured per-test
|
||||
|
|
@ -113,38 +114,43 @@ describe('tests for the new helper functions used by the MCP connection status e
|
|||
});
|
||||
|
||||
it('should successfully return MCP setup data', async () => {
|
||||
mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfig);
|
||||
const mockConfigWithOAuth = {
|
||||
server1: { type: 'stdio' },
|
||||
server2: { type: 'http', requiresOAuth: true },
|
||||
};
|
||||
mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfigWithOAuth);
|
||||
|
||||
const mockAppConnections = new Map([['server1', { status: 'connected' }]]);
|
||||
const mockUserConnections = new Map([['server2', { status: 'disconnected' }]]);
|
||||
const mockOAuthServers = new Set(['server2']);
|
||||
|
||||
const mockMCPManager = {
|
||||
appConnections: { getLoaded: jest.fn(() => Promise.resolve(mockAppConnections)) },
|
||||
getUserConnections: jest.fn(() => mockUserConnections),
|
||||
};
|
||||
mockGetMCPManager.mockReturnValue(mockMCPManager);
|
||||
mockRegistryInstance.getOAuthServers.mockResolvedValue(mockOAuthServers);
|
||||
|
||||
const result = await getMCPSetupData(mockUserId);
|
||||
|
||||
expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockRegistryInstance.ensureConfigServers).toHaveBeenCalled();
|
||||
expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith(
|
||||
mockUserId,
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockMCPManager.appConnections.getLoaded).toHaveBeenCalled();
|
||||
expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockRegistryInstance.getOAuthServers).toHaveBeenCalledWith(mockUserId);
|
||||
|
||||
expect(result).toEqual({
|
||||
mcpConfig: mockConfig,
|
||||
appConnections: mockAppConnections,
|
||||
userConnections: mockUserConnections,
|
||||
oauthServers: mockOAuthServers,
|
||||
});
|
||||
expect(result.mcpConfig).toEqual(mockConfigWithOAuth);
|
||||
expect(result.appConnections).toEqual(mockAppConnections);
|
||||
expect(result.userConnections).toEqual(mockUserConnections);
|
||||
expect(result.oauthServers).toEqual(new Set(['server2']));
|
||||
});
|
||||
|
||||
it('should throw error when MCP config not found', async () => {
|
||||
mockRegistryInstance.getAllServerConfigs.mockResolvedValue(null);
|
||||
await expect(getMCPSetupData(mockUserId)).rejects.toThrow('MCP config not found');
|
||||
it('should return empty data when no servers are configured', async () => {
|
||||
mockRegistryInstance.getAllServerConfigs.mockResolvedValue({});
|
||||
const result = await getMCPSetupData(mockUserId);
|
||||
expect(result.mcpConfig).toEqual({});
|
||||
expect(result.oauthServers).toEqual(new Set());
|
||||
});
|
||||
|
||||
it('should handle null values from MCP manager gracefully', async () => {
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ const {
|
|||
buildWebSearchContext,
|
||||
buildImageToolContext,
|
||||
buildToolClassification,
|
||||
buildOAuthToolCallName,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
|
|
@ -30,6 +31,7 @@ const {
|
|||
imageGenTools,
|
||||
EModelEndpoint,
|
||||
EToolResources,
|
||||
isActionTool,
|
||||
actionDelimiter,
|
||||
ImageVisionTool,
|
||||
openapiToFunction,
|
||||
|
|
@ -59,6 +61,7 @@ const { manifestToolMap, toolkits } = require('~/app/clients/tools/manifest');
|
|||
const { createOnSearchResults } = require('~/server/services/Tools/search');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
|
||||
const { resolveConfigServers } = require('~/server/services/MCP');
|
||||
const { recordUsage } = require('~/server/services/Threads');
|
||||
const { loadTools } = require('~/app/clients/tools/util');
|
||||
const { redactMessage } = require('~/config/parsers');
|
||||
|
|
@ -488,7 +491,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
|||
if (tool === Tools.web_search) {
|
||||
return checkCapability(AgentCapabilities.web_search);
|
||||
}
|
||||
if (tool.includes(actionDelimiter)) {
|
||||
if (isActionTool(tool)) {
|
||||
return actionsEnabled;
|
||||
}
|
||||
if (!areToolsEnabled) {
|
||||
|
|
@ -513,6 +516,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
|||
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
const configServers = await resolveConfigServers(req);
|
||||
const pendingOAuthServers = new Set();
|
||||
|
||||
const createOAuthEmitter = (serverName) => {
|
||||
|
|
@ -521,7 +525,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
|||
const stepId = 'step_oauth_login_' + serverName;
|
||||
const toolCall = {
|
||||
id: flowId,
|
||||
name: serverName,
|
||||
name: buildOAuthToolCallName(serverName),
|
||||
type: 'tool_call_chunk',
|
||||
};
|
||||
|
||||
|
|
@ -578,6 +582,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
|||
oauthStart,
|
||||
flowManager,
|
||||
serverName,
|
||||
configServers,
|
||||
userMCPAuthMap,
|
||||
});
|
||||
|
||||
|
|
@ -665,6 +670,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
|||
const result = await reinitMCPServer({
|
||||
user: req.user,
|
||||
serverName,
|
||||
configServers,
|
||||
userMCPAuthMap,
|
||||
flowManager,
|
||||
returnOnOAuth: false,
|
||||
|
|
@ -866,7 +872,7 @@ async function loadAgentTools({
|
|||
} else if (tool === Tools.web_search) {
|
||||
includesWebSearch = checkCapability(AgentCapabilities.web_search);
|
||||
return includesWebSearch;
|
||||
} else if (tool.includes(actionDelimiter)) {
|
||||
} else if (isActionTool(tool)) {
|
||||
return actionsEnabled;
|
||||
} else if (!areToolsEnabled) {
|
||||
return false;
|
||||
|
|
@ -973,7 +979,7 @@ async function loadAgentTools({
|
|||
|
||||
agentTools.push(...additionalTools);
|
||||
|
||||
const hasActionTools = _agentTools.some((t) => t.includes(actionDelimiter));
|
||||
const hasActionTools = _agentTools.some((t) => isActionTool(t));
|
||||
if (!hasActionTools) {
|
||||
return {
|
||||
toolRegistry,
|
||||
|
|
@ -1232,8 +1238,11 @@ async function loadToolsForExecution({
|
|||
? [...new Set([...requestedNonSpecialToolNames, ...ptcOrchestratedToolNames])]
|
||||
: requestedNonSpecialToolNames;
|
||||
|
||||
const actionToolNames = allToolNamesToLoad.filter((name) => name.includes(actionDelimiter));
|
||||
const regularToolNames = allToolNamesToLoad.filter((name) => !name.includes(actionDelimiter));
|
||||
const actionToolNames = [];
|
||||
const regularToolNames = [];
|
||||
for (const name of allToolNamesToLoad) {
|
||||
(isActionTool(name) ? actionToolNames : regularToolNames).push(name);
|
||||
}
|
||||
|
||||
if (regularToolNames.length > 0) {
|
||||
const includesWebSearch = regularToolNames.includes(Tools.web_search);
|
||||
|
|
|
|||
|
|
@ -25,11 +25,13 @@ async function reinitMCPServer({
|
|||
signal,
|
||||
forceNew,
|
||||
serverName,
|
||||
configServers,
|
||||
userMCPAuthMap,
|
||||
connectionTimeout,
|
||||
returnOnOAuth = true,
|
||||
oauthStart: _oauthStart,
|
||||
flowManager: _flowManager,
|
||||
serverConfig: providedConfig,
|
||||
}) {
|
||||
/** @type {MCPConnection | null} */
|
||||
let connection = null;
|
||||
|
|
@ -42,13 +44,28 @@ async function reinitMCPServer({
|
|||
|
||||
try {
|
||||
const registry = getMCPServersRegistry();
|
||||
const serverConfig = await registry.getServerConfig(serverName, user?.id);
|
||||
const serverConfig =
|
||||
providedConfig ?? (await registry.getServerConfig(serverName, user?.id, configServers));
|
||||
if (serverConfig?.inspectionFailed) {
|
||||
if (serverConfig.source === 'config') {
|
||||
logger.info(
|
||||
`[MCP Reinitialize] Config-source server ${serverName} has inspectionFailed — retry handled by config cache`,
|
||||
);
|
||||
return {
|
||||
availableTools: null,
|
||||
success: false,
|
||||
message: `MCP server '${serverName}' is still unreachable`,
|
||||
oauthRequired: false,
|
||||
serverName,
|
||||
oauthUrl: null,
|
||||
tools: null,
|
||||
};
|
||||
}
|
||||
logger.info(
|
||||
`[MCP Reinitialize] Server ${serverName} had failed inspection, attempting reinspection`,
|
||||
);
|
||||
try {
|
||||
const storageLocation = serverConfig.dbId ? 'DB' : 'CACHE';
|
||||
const storageLocation = serverConfig.source === 'user' ? 'DB' : 'CACHE';
|
||||
await registry.reinspectServer(serverName, storageLocation, user?.id);
|
||||
logger.info(`[MCP Reinitialize] Reinspection succeeded for server: ${serverName}`);
|
||||
} catch (reinspectError) {
|
||||
|
|
@ -93,6 +110,7 @@ async function reinitMCPServer({
|
|||
returnOnOAuth,
|
||||
customUserVars,
|
||||
connectionTimeout,
|
||||
serverConfig,
|
||||
});
|
||||
|
||||
logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`);
|
||||
|
|
@ -125,6 +143,7 @@ async function reinitMCPServer({
|
|||
oauthStart,
|
||||
customUserVars,
|
||||
connectionTimeout,
|
||||
configServers,
|
||||
});
|
||||
|
||||
if (discoveryResult.tools && discoveryResult.tools.length > 0) {
|
||||
|
|
|
|||
131
api/server/services/__tests__/MCP.spec.js
Normal file
131
api/server/services/__tests__/MCP.spec.js
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
const mockRegistry = {
|
||||
ensureConfigServers: jest.fn(),
|
||||
getAllServerConfigs: jest.fn(),
|
||||
};
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
getMCPServersRegistry: jest.fn(() => mockRegistry),
|
||||
getMCPManager: jest.fn(),
|
||||
getFlowStateManager: jest.fn(),
|
||||
getOAuthReconnectionManager: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
getTenantId: jest.fn(() => 'tenant-1'),
|
||||
logger: { debug: jest.fn(), info: jest.fn(), warn: jest.fn(), error: jest.fn() },
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getAppConfig: jest.fn(),
|
||||
setCachedTools: jest.fn(),
|
||||
getCachedTools: jest.fn(),
|
||||
getMCPServerTools: jest.fn(),
|
||||
loadCustomConfig: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/cache', () => ({ getLogStores: jest.fn() }));
|
||||
jest.mock('~/models', () => ({
|
||||
findToken: jest.fn(),
|
||||
createToken: jest.fn(),
|
||||
updateToken: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/server/services/GraphTokenService', () => ({
|
||||
getGraphApiToken: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/server/services/Tools/mcp', () => ({
|
||||
reinitMCPServer: jest.fn(),
|
||||
}));
|
||||
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { resolveConfigServers, resolveAllMcpConfigs } = require('../MCP');
|
||||
|
||||
describe('resolveConfigServers', () => {
|
||||
beforeEach(() => jest.clearAllMocks());
|
||||
|
||||
it('resolves config servers for the current request context', async () => {
|
||||
getAppConfig.mockResolvedValue({ mcpConfig: { srv: { url: 'http://a' } } });
|
||||
mockRegistry.ensureConfigServers.mockResolvedValue({ srv: { name: 'srv' } });
|
||||
|
||||
const result = await resolveConfigServers({ user: { id: 'u1', role: 'admin' } });
|
||||
|
||||
expect(result).toEqual({ srv: { name: 'srv' } });
|
||||
expect(getAppConfig).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ role: 'admin', userId: 'u1' }),
|
||||
);
|
||||
expect(mockRegistry.ensureConfigServers).toHaveBeenCalledWith({ srv: { url: 'http://a' } });
|
||||
});
|
||||
|
||||
it('returns {} when ensureConfigServers throws', async () => {
|
||||
getAppConfig.mockResolvedValue({ mcpConfig: { srv: {} } });
|
||||
mockRegistry.ensureConfigServers.mockRejectedValue(new Error('inspect failed'));
|
||||
|
||||
const result = await resolveConfigServers({ user: { id: 'u1' } });
|
||||
|
||||
expect(result).toEqual({});
|
||||
});
|
||||
|
||||
it('returns {} when getAppConfig throws', async () => {
|
||||
getAppConfig.mockRejectedValue(new Error('db timeout'));
|
||||
|
||||
const result = await resolveConfigServers({ user: { id: 'u1' } });
|
||||
|
||||
expect(result).toEqual({});
|
||||
});
|
||||
|
||||
it('passes empty mcpConfig when appConfig has none', async () => {
|
||||
getAppConfig.mockResolvedValue({});
|
||||
mockRegistry.ensureConfigServers.mockResolvedValue({});
|
||||
|
||||
await resolveConfigServers({ user: { id: 'u1' } });
|
||||
|
||||
expect(mockRegistry.ensureConfigServers).toHaveBeenCalledWith({});
|
||||
});
|
||||
});
|
||||
|
||||
describe('resolveAllMcpConfigs', () => {
|
||||
beforeEach(() => jest.clearAllMocks());
|
||||
|
||||
it('merges config servers with base servers', async () => {
|
||||
getAppConfig.mockResolvedValue({ mcpConfig: { cfg_srv: {} } });
|
||||
mockRegistry.ensureConfigServers.mockResolvedValue({ cfg_srv: { name: 'cfg_srv' } });
|
||||
mockRegistry.getAllServerConfigs.mockResolvedValue({
|
||||
cfg_srv: { name: 'cfg_srv' },
|
||||
yaml_srv: { name: 'yaml_srv' },
|
||||
});
|
||||
|
||||
const result = await resolveAllMcpConfigs('u1', { id: 'u1', role: 'user' });
|
||||
|
||||
expect(result).toEqual({
|
||||
cfg_srv: { name: 'cfg_srv' },
|
||||
yaml_srv: { name: 'yaml_srv' },
|
||||
});
|
||||
expect(mockRegistry.getAllServerConfigs).toHaveBeenCalledWith('u1', {
|
||||
cfg_srv: { name: 'cfg_srv' },
|
||||
});
|
||||
});
|
||||
|
||||
it('continues with empty configServers when ensureConfigServers fails', async () => {
|
||||
getAppConfig.mockResolvedValue({ mcpConfig: { srv: {} } });
|
||||
mockRegistry.ensureConfigServers.mockRejectedValue(new Error('inspect failed'));
|
||||
mockRegistry.getAllServerConfigs.mockResolvedValue({ yaml_srv: { name: 'yaml_srv' } });
|
||||
|
||||
const result = await resolveAllMcpConfigs('u1', { id: 'u1' });
|
||||
|
||||
expect(result).toEqual({ yaml_srv: { name: 'yaml_srv' } });
|
||||
expect(mockRegistry.getAllServerConfigs).toHaveBeenCalledWith('u1', {});
|
||||
});
|
||||
|
||||
it('propagates getAllServerConfigs failures', async () => {
|
||||
getAppConfig.mockResolvedValue({ mcpConfig: {} });
|
||||
mockRegistry.ensureConfigServers.mockResolvedValue({});
|
||||
mockRegistry.getAllServerConfigs.mockRejectedValue(new Error('redis down'));
|
||||
|
||||
await expect(resolveAllMcpConfigs('u1', { id: 'u1' })).rejects.toThrow('redis down');
|
||||
});
|
||||
|
||||
it('propagates getAppConfig failures', async () => {
|
||||
getAppConfig.mockRejectedValue(new Error('mongo down'));
|
||||
|
||||
await expect(resolveAllMcpConfigs('u1', { id: 'u1' })).rejects.toThrow('mongo down');
|
||||
});
|
||||
});
|
||||
|
|
@ -2,6 +2,7 @@ const {
|
|||
Tools,
|
||||
Constants,
|
||||
EModelEndpoint,
|
||||
isActionTool,
|
||||
actionDelimiter,
|
||||
AgentCapabilities,
|
||||
defaultAgentCapabilities,
|
||||
|
|
@ -64,6 +65,9 @@ jest.mock('~/models', () => ({
|
|||
jest.mock('~/config', () => ({
|
||||
getFlowStateManager: jest.fn(() => ({})),
|
||||
}));
|
||||
jest.mock('~/server/services/MCP', () => ({
|
||||
resolveConfigServers: jest.fn().mockResolvedValue({}),
|
||||
}));
|
||||
jest.mock('~/cache', () => ({
|
||||
getLogStores: jest.fn(() => ({})),
|
||||
}));
|
||||
|
|
@ -140,6 +144,42 @@ describe('ToolService - Action Capability Gating', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('isActionTool — cross-delimiter collision guard', () => {
|
||||
it('should identify real action tools', () => {
|
||||
expect(isActionTool(`get_weather${actionDelimiter}api_example_com`)).toBe(true);
|
||||
expect(isActionTool(`fetch_data${actionDelimiter}my---domain---com`)).toBe(true);
|
||||
});
|
||||
|
||||
it('should identify action tools whose operationId contains _mcp_', () => {
|
||||
expect(isActionTool(`sync_mcp_state${actionDelimiter}api---example---com`)).toBe(true);
|
||||
expect(isActionTool(`get_mcp_config${actionDelimiter}internal---api---com`)).toBe(true);
|
||||
});
|
||||
|
||||
it('should reject MCP tools whose name ends with _action', () => {
|
||||
expect(isActionTool(`get_action${Constants.mcp_delimiter}myserver`)).toBe(false);
|
||||
expect(isActionTool(`fetch_action${Constants.mcp_delimiter}server_name`)).toBe(false);
|
||||
expect(isActionTool(`retrieve_action${Constants.mcp_delimiter}srv`)).toBe(false);
|
||||
});
|
||||
|
||||
it('should reject MCP tools with _action_ in the middle of their name', () => {
|
||||
expect(isActionTool(`get_action_data${Constants.mcp_delimiter}myserver`)).toBe(false);
|
||||
expect(isActionTool(`create_action_item${Constants.mcp_delimiter}server`)).toBe(false);
|
||||
});
|
||||
|
||||
it('should reject tools without the action delimiter', () => {
|
||||
expect(isActionTool('calculator')).toBe(false);
|
||||
expect(isActionTool(`web_search${Constants.mcp_delimiter}myserver`)).toBe(false);
|
||||
});
|
||||
|
||||
it('known limitation: non-RFC domain with _mcp_ substring yields false negative', () => {
|
||||
// RFC 952/1123 prohibit underscores in hostnames, so this is not expected in practice.
|
||||
// Encoded domain `api_mcp_internal_com` places `_mcp_` after `_action_`, which
|
||||
// the guard interprets as the MCP suffix.
|
||||
const edgeCaseTool = `getData${actionDelimiter}api_mcp_internal_com`;
|
||||
expect(isActionTool(edgeCaseTool)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('loadAgentTools (definitionsOnly=true) — action tool filtering', () => {
|
||||
const actionToolName = `get_weather${actionDelimiter}api_example_com`;
|
||||
const regularTool = 'calculator';
|
||||
|
|
@ -180,6 +220,25 @@ describe('ToolService - Action Capability Gating', () => {
|
|||
expect(callArgs.tools).toContain(actionToolName);
|
||||
});
|
||||
|
||||
it('should not filter MCP tools whose name contains _action (cross-delimiter collision)', async () => {
|
||||
const mcpToolWithAction = `get_action${Constants.mcp_delimiter}myserver`;
|
||||
const capabilities = [AgentCapabilities.tools];
|
||||
const req = createMockReq(capabilities);
|
||||
mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities));
|
||||
|
||||
await loadAgentTools({
|
||||
req,
|
||||
res: {},
|
||||
agent: { id: 'agent_123', tools: [regularTool, mcpToolWithAction] },
|
||||
definitionsOnly: true,
|
||||
});
|
||||
|
||||
expect(mockLoadToolDefinitions).toHaveBeenCalledTimes(1);
|
||||
const [callArgs] = mockLoadToolDefinitions.mock.calls[0];
|
||||
expect(callArgs.tools).toContain(mcpToolWithAction);
|
||||
expect(callArgs.tools).toContain(regularTool);
|
||||
});
|
||||
|
||||
it('should return actionsEnabled in the result', async () => {
|
||||
const capabilities = [AgentCapabilities.tools];
|
||||
const req = createMockReq(capabilities);
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ const { createMCPServersRegistry, createMCPManager } = require('~/config');
|
|||
* Initialize MCP servers
|
||||
*/
|
||||
async function initializeMCPs() {
|
||||
const appConfig = await getAppConfig();
|
||||
const appConfig = await getAppConfig({ baseOnly: true });
|
||||
const mcpServers = appConfig.mcpConfig;
|
||||
|
||||
try {
|
||||
|
|
|
|||
|
|
@ -6,11 +6,16 @@ const { logger, DEFAULT_SESSION_EXPIRY } = require('@librechat/data-schemas');
|
|||
const {
|
||||
openIdJwtLogin,
|
||||
facebookLogin,
|
||||
facebookAdminLogin,
|
||||
discordLogin,
|
||||
discordAdminLogin,
|
||||
setupOpenId,
|
||||
googleLogin,
|
||||
googleAdminLogin,
|
||||
githubLogin,
|
||||
githubAdminLogin,
|
||||
appleLogin,
|
||||
appleAdminLogin,
|
||||
setupSaml,
|
||||
} = require('~/strategies');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
|
@ -58,18 +63,23 @@ const configureSocialLogins = async (app) => {
|
|||
|
||||
if (process.env.GOOGLE_CLIENT_ID && process.env.GOOGLE_CLIENT_SECRET) {
|
||||
passport.use(googleLogin());
|
||||
passport.use('googleAdmin', googleAdminLogin());
|
||||
}
|
||||
if (process.env.FACEBOOK_CLIENT_ID && process.env.FACEBOOK_CLIENT_SECRET) {
|
||||
passport.use(facebookLogin());
|
||||
passport.use('facebookAdmin', facebookAdminLogin());
|
||||
}
|
||||
if (process.env.GITHUB_CLIENT_ID && process.env.GITHUB_CLIENT_SECRET) {
|
||||
passport.use(githubLogin());
|
||||
passport.use('githubAdmin', githubAdminLogin());
|
||||
}
|
||||
if (process.env.DISCORD_CLIENT_ID && process.env.DISCORD_CLIENT_SECRET) {
|
||||
passport.use(discordLogin());
|
||||
passport.use('discordAdmin', discordAdminLogin());
|
||||
}
|
||||
if (process.env.APPLE_CLIENT_ID && process.env.APPLE_PRIVATE_KEY_PATH) {
|
||||
passport.use(appleLogin());
|
||||
passport.use('appleAdmin', appleAdminLogin());
|
||||
}
|
||||
if (
|
||||
process.env.OPENID_CLIENT_ID &&
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@ const maxFileSize = resolveImportMaxFileSize();
|
|||
|
||||
/**
|
||||
* Job definition for importing a conversation.
|
||||
* @param {{ filepath, requestUserId }} job - The job object.
|
||||
* @param {{ filepath: string, requestUserId: string, userRole?: string }} job
|
||||
*/
|
||||
const importConversations = async (job) => {
|
||||
const { filepath, requestUserId } = job;
|
||||
const { filepath, requestUserId, userRole } = job;
|
||||
try {
|
||||
logger.debug(`user: ${requestUserId} | Importing conversation(s) from file...`);
|
||||
|
||||
|
|
@ -24,7 +24,7 @@ const importConversations = async (job) => {
|
|||
const fileData = await fs.readFile(filepath, 'utf8');
|
||||
const jsonData = JSON.parse(fileData);
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId);
|
||||
await importer(jsonData, requestUserId, undefined, userRole);
|
||||
logger.debug(`user: ${requestUserId} | Finished importing conversations`);
|
||||
} catch (error) {
|
||||
logger.error(`user: ${requestUserId} | Failed to import conversation: `, error);
|
||||
|
|
|
|||
|
|
@ -8,17 +8,16 @@ jest.mock('~/models', () => ({
|
|||
bulkSaveConvos: jest.fn(),
|
||||
bulkSaveMessages: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/cache/getLogStores');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const mockedCacheGet = jest.fn();
|
||||
getLogStores.mockImplementation(() => ({
|
||||
get: mockedCacheGet,
|
||||
|
||||
const mockGetEndpointsConfig = jest.fn().mockResolvedValue(null);
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getEndpointsConfig: (...args) => mockGetEndpointsConfig(...args),
|
||||
}));
|
||||
|
||||
describe('Import Timestamp Ordering', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockedCacheGet.mockResolvedValue(null);
|
||||
mockGetEndpointsConfig.mockResolvedValue(null);
|
||||
});
|
||||
|
||||
describe('LibreChat Import - Timestamp Issues', () => {
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
const { v4: uuidv4 } = require('uuid');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { EModelEndpoint, Constants, openAISettings, CacheKeys } = require('librechat-data-provider');
|
||||
const { logger, getTenantId } = require('@librechat/data-schemas');
|
||||
const { EModelEndpoint, Constants, openAISettings } = require('librechat-data-provider');
|
||||
const { getEndpointsConfig } = require('~/server/services/Config');
|
||||
const { createImportBatchBuilder } = require('./importBatchBuilder');
|
||||
const { cloneMessagesWithTimestamps } = require('./fork');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
/**
|
||||
* Returns the appropriate importer function based on the provided JSON data.
|
||||
|
|
@ -194,6 +194,7 @@ async function importLibreChatConvo(
|
|||
jsonData,
|
||||
requestUserId,
|
||||
builderFactory = createImportBatchBuilder,
|
||||
userRole,
|
||||
) {
|
||||
try {
|
||||
/** @type {ImportBatchBuilder} */
|
||||
|
|
@ -202,8 +203,9 @@ async function importLibreChatConvo(
|
|||
|
||||
/* Endpoint configuration */
|
||||
let endpoint = jsonData.endpoint ?? options.endpoint ?? EModelEndpoint.openAI;
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const endpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
|
||||
const endpointsConfig = await getEndpointsConfig({
|
||||
user: { id: requestUserId, role: userRole, tenantId: getTenantId() },
|
||||
});
|
||||
const endpointConfig = endpointsConfig?.[endpoint];
|
||||
if (!endpointConfig && endpointsConfig) {
|
||||
endpoint = Object.keys(endpointsConfig)[0];
|
||||
|
|
|
|||
|
|
@ -4,12 +4,13 @@ const { EModelEndpoint, Constants, openAISettings } = require('librechat-data-pr
|
|||
const { getImporter, processAssistantMessage } = require('./importers');
|
||||
const { ImportBatchBuilder } = require('./importBatchBuilder');
|
||||
const { bulkSaveMessages, bulkSaveConvos: _bulkSaveConvos } = require('~/models');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
jest.mock('~/cache/getLogStores');
|
||||
const mockedCacheGet = jest.fn();
|
||||
getLogStores.mockImplementation(() => ({
|
||||
get: mockedCacheGet,
|
||||
const mockGetEndpointsConfig = jest.fn().mockResolvedValue({
|
||||
[EModelEndpoint.openAI]: { userProvide: false },
|
||||
});
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getEndpointsConfig: (...args) => mockGetEndpointsConfig(...args),
|
||||
}));
|
||||
|
||||
// Mock the database methods
|
||||
|
|
@ -758,7 +759,7 @@ describe('importLibreChatConvo', () => {
|
|||
);
|
||||
|
||||
it('should import conversation correctly', async () => {
|
||||
mockedCacheGet.mockResolvedValue({
|
||||
mockGetEndpointsConfig.mockResolvedValue({
|
||||
[EModelEndpoint.openAI]: {},
|
||||
});
|
||||
const expectedNumberOfMessages = 6;
|
||||
|
|
@ -784,7 +785,7 @@ describe('importLibreChatConvo', () => {
|
|||
});
|
||||
|
||||
it('should import linear, non-recursive thread correctly with correct endpoint', async () => {
|
||||
mockedCacheGet.mockResolvedValue({
|
||||
mockGetEndpointsConfig.mockResolvedValue({
|
||||
[EModelEndpoint.azureOpenAI]: {},
|
||||
});
|
||||
|
||||
|
|
@ -924,7 +925,7 @@ describe('importLibreChatConvo', () => {
|
|||
});
|
||||
|
||||
it('should retain properties from the original conversation as well as new settings', async () => {
|
||||
mockedCacheGet.mockResolvedValue({
|
||||
mockGetEndpointsConfig.mockResolvedValue({
|
||||
[EModelEndpoint.azureOpenAI]: {},
|
||||
});
|
||||
const requestUserId = 'user-123';
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue