diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000..725ac8b6bd --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ +# Force LF line endings for shell scripts and git hooks (required for cross-platform compatibility) +.husky/* text eol=lf +*.sh text eol=lf diff --git a/.github/workflows/backend-review.yml b/.github/workflows/backend-review.yml index 038c90627e..9dd3905c0e 100644 --- a/.github/workflows/backend-review.yml +++ b/.github/workflows/backend-review.yml @@ -97,6 +97,65 @@ jobs: path: packages/api/dist retention-days: 2 + typecheck: + name: TypeScript type checks + needs: build + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - uses: actions/checkout@v4 + + - name: Use Node.js 20.19 + uses: actions/setup-node@v4 + with: + node-version: '20.19' + + - name: Restore node_modules cache + id: cache-node-modules + uses: actions/cache@v4 + with: + path: | + node_modules + api/node_modules + packages/api/node_modules + packages/data-provider/node_modules + packages/data-schemas/node_modules + key: node-modules-backend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }} + + - name: Install dependencies + if: steps.cache-node-modules.outputs.cache-hit != 'true' + run: npm ci + + - name: Download data-provider build + uses: actions/download-artifact@v4 + with: + name: build-data-provider + path: packages/data-provider/dist + + - name: Download data-schemas build + uses: actions/download-artifact@v4 + with: + name: build-data-schemas + path: packages/data-schemas/dist + + - name: Download api build + uses: actions/download-artifact@v4 + with: + name: build-api + path: packages/api/dist + + - name: Type check data-provider + run: npx tsc --noEmit -p packages/data-provider/tsconfig.json + + - name: Type check data-schemas + run: npx tsc --noEmit -p packages/data-schemas/tsconfig.json + + - name: Type check @librechat/api + run: npx tsc --noEmit -p packages/api/tsconfig.json + + - name: Type check @librechat/client + run: 
npx tsc --noEmit -p packages/client/tsconfig.json + circular-deps: name: Circular dependency checks needs: build diff --git a/AGENTS.md b/AGENTS.md index ec44607aa7..81362cfc57 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -29,6 +29,12 @@ The source code for `@librechat/agents` (major backend dependency, same team) is ## Code Style +### Naming and File Organization + +- **Single-word file names** whenever possible (e.g., `permissions.ts`, `capabilities.ts`, `service.ts`). +- When multiple words are needed, prefer grouping related modules under a **single-word directory** rather than using multi-word file names (e.g., `admin/capabilities.ts` not `adminCapabilities.ts`). +- The directory already provides context — `app/service.ts` not `app/appConfigService.ts`. + ### Structure and Clarity - **Never-nesting**: early returns, flat code, minimal indentation. Break complex operations into well-named helpers. diff --git a/README.md b/README.md index 7da34974e3..a7f68d9a92 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@

- + Deploy on Railway diff --git a/README.zh.md b/README.zh.md index cc9cb5a205..7f74057413 100644 --- a/README.zh.md +++ b/README.zh.md @@ -1,4 +1,4 @@ - +

@@ -34,7 +34,7 @@

- + Deploy on Railway diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index ec5ccfb5f4..08cb1f6ada 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -32,7 +32,6 @@ class BaseClient { constructor(apiKey, options = {}) { this.apiKey = apiKey; this.sender = options.sender ?? 'AI'; - this.contextStrategy = null; this.currentDateString = new Date().toLocaleDateString('en-us', { year: 'numeric', month: 'long', diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index 4b86101425..8adb43f945 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -14,7 +14,6 @@ const { buildImageToolContext, buildWebSearchContext, } = require('@librechat/api'); -const { getMCPServersRegistry } = require('~/config'); const { Tools, Constants, @@ -39,12 +38,13 @@ const { createGeminiImageTool, createOpenAIImageTools, } = require('../'); -const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); +const { createMCPTool, createMCPTools, resolveConfigServers } = require('~/server/services/MCP'); const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch'); +const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); const { getUserPluginAuthValue } = require('~/server/services/PluginService'); -const { createMCPTool, createMCPTools } = require('~/server/services/MCP'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { getMCPServerTools } = require('~/server/services/Config'); +const { getMCPServersRegistry } = require('~/config'); const { getRoleByName } = require('~/models'); /** @@ -256,6 +256,12 @@ const loadTools = async ({ const toolContextMap = {}; const requestedMCPTools = {}; + /** Resolve config-source servers for the current user/tenant context */ + let configServers; + if (tools.some((tool) => tool && 
mcpToolPattern.test(tool))) { + configServers = await resolveConfigServers(options.req); + } + for (const tool of tools) { if (tool === Tools.execute_code) { requestedTools[tool] = async () => { @@ -341,7 +347,7 @@ const loadTools = async ({ continue; } const serverConfig = serverName - ? await getMCPServersRegistry().getServerConfig(serverName, user) + ? await getMCPServersRegistry().getServerConfig(serverName, user, configServers) : null; if (!serverConfig) { logger.warn( @@ -419,6 +425,7 @@ const loadTools = async ({ let index = -1; const failedMCPServers = new Set(); const safeUser = createSafeUser(options.req?.user); + for (const [serverName, toolConfigs] of Object.entries(requestedMCPTools)) { index++; /** @type {LCAvailableTools} */ @@ -433,6 +440,7 @@ const loadTools = async ({ signal, user: safeUser, userMCPAuthMap, + configServers, res: options.res, streamId: options.req?._resumableStreamId || null, model: agent?.model ?? model, diff --git a/api/server/cleanup.js b/api/server/cleanup.js index 364c02cd8a..c27814292d 100644 --- a/api/server/cleanup.js +++ b/api/server/cleanup.js @@ -123,9 +123,6 @@ function disposeClient(client) { if (client.maxContextTokens) { client.maxContextTokens = null; } - if (client.contextStrategy) { - client.contextStrategy = null; - } if (client.currentDateString) { client.currentDateString = null; } diff --git a/api/server/controllers/ModelController.js b/api/server/controllers/ModelController.js index 805d9eef27..1741d3f6b1 100644 --- a/api/server/controllers/ModelController.js +++ b/api/server/controllers/ModelController.js @@ -1,5 +1,5 @@ -const { logger } = require('@librechat/data-schemas'); const { CacheKeys } = require('librechat-data-provider'); +const { logger, scopedCacheKey } = require('@librechat/data-schemas'); const { loadDefaultModels, loadConfigModels } = require('~/server/services/Config'); const { getLogStores } = require('~/cache'); @@ -9,7 +9,8 @@ const { getLogStores } = require('~/cache'); */ const 
getModelsConfig = async (req) => { const cache = getLogStores(CacheKeys.CONFIG_STORE); - let modelsConfig = await cache.get(CacheKeys.MODELS_CONFIG); + const cacheKey = scopedCacheKey(CacheKeys.MODELS_CONFIG); + let modelsConfig = await cache.get(cacheKey); if (!modelsConfig) { modelsConfig = await loadModels(req); } @@ -24,7 +25,8 @@ const getModelsConfig = async (req) => { */ async function loadModels(req) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const cachedModelsConfig = await cache.get(CacheKeys.MODELS_CONFIG); + const cacheKey = scopedCacheKey(CacheKeys.MODELS_CONFIG); + const cachedModelsConfig = await cache.get(cacheKey); if (cachedModelsConfig) { return cachedModelsConfig; } @@ -33,7 +35,7 @@ async function loadModels(req) { const modelConfig = { ...defaultModelsConfig, ...customModelsConfig }; - await cache.set(CacheKeys.MODELS_CONFIG, modelConfig); + await cache.set(cacheKey, modelConfig); return modelConfig; } diff --git a/api/server/controllers/PluginController.js b/api/server/controllers/PluginController.js index 279ffb15fd..7c47fe4d57 100644 --- a/api/server/controllers/PluginController.js +++ b/api/server/controllers/PluginController.js @@ -1,5 +1,5 @@ -const { logger } = require('@librechat/data-schemas'); const { CacheKeys } = require('librechat-data-provider'); +const { logger, scopedCacheKey } = require('@librechat/data-schemas'); const { getToolkitKey, checkPluginAuth, filterUniquePlugins } = require('@librechat/api'); const { getCachedTools, setCachedTools } = require('~/server/services/Config'); const { availableTools, toolkits } = require('~/app/clients/tools'); @@ -9,13 +9,14 @@ const { getLogStores } = require('~/cache'); const getAvailablePluginsController = async (req, res) => { try { const cache = getLogStores(CacheKeys.TOOL_CACHE); - const cachedPlugins = await cache.get(CacheKeys.PLUGINS); + const pluginsCacheKey = scopedCacheKey(CacheKeys.PLUGINS); + const cachedPlugins = await cache.get(pluginsCacheKey); if 
(cachedPlugins) { res.status(200).json(cachedPlugins); return; } - const appConfig = await getAppConfig({ role: req.user?.role }); + const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }); /** @type {{ filteredTools: string[], includedTools: string[] }} */ const { filteredTools = [], includedTools = [] } = appConfig; /** @type {import('@librechat/api').LCManifestTool[]} */ @@ -37,7 +38,7 @@ const getAvailablePluginsController = async (req, res) => { plugins = plugins.filter((plugin) => !filteredTools.includes(plugin.pluginKey)); } - await cache.set(CacheKeys.PLUGINS, plugins); + await cache.set(pluginsCacheKey, plugins); res.status(200).json(plugins); } catch (error) { res.status(500).json({ message: error.message }); @@ -64,9 +65,11 @@ const getAvailableTools = async (req, res) => { return res.status(401).json({ message: 'Unauthorized' }); } const cache = getLogStores(CacheKeys.TOOL_CACHE); - const cachedToolsArray = await cache.get(CacheKeys.TOOLS); + const toolsCacheKey = scopedCacheKey(CacheKeys.TOOLS); + const cachedToolsArray = await cache.get(toolsCacheKey); - const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role })); + const appConfig = + req.config ?? 
(await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId })); // Return early if we have cached tools if (cachedToolsArray != null) { @@ -114,7 +117,7 @@ const getAvailableTools = async (req, res) => { } const finalTools = filterUniquePlugins(toolsOutput); - await cache.set(CacheKeys.TOOLS, finalTools); + await cache.set(toolsCacheKey, finalTools); res.status(200).json(finalTools); } catch (error) { diff --git a/api/server/controllers/PluginController.spec.js b/api/server/controllers/PluginController.spec.js index 06a51a3bd6..fdbc2401ce 100644 --- a/api/server/controllers/PluginController.spec.js +++ b/api/server/controllers/PluginController.spec.js @@ -8,6 +8,7 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), warn: jest.fn(), }, + scopedCacheKey: jest.fn((key) => key), })); jest.mock('~/server/services/Config', () => ({ diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index 301c6d2f76..16b68968d9 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -26,7 +26,7 @@ const { getLogStores } = require('~/cache'); const db = require('~/models'); const getUserController = async (req, res) => { - const appConfig = await getAppConfig({ role: req.user?.role }); + const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }); /** @type {IUser} */ const userData = req.user.toObject != null ? 
req.user.toObject() : { ...req.user }; /** @@ -165,7 +165,7 @@ const deleteUserMcpServers = async (userId) => { }; const updateUserPluginsController = async (req, res) => { - const appConfig = await getAppConfig({ role: req.user?.role }); + const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }); const { user } = req; const { pluginKey, action, auth, isEntityTool } = req.body; try { diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 47a10165e3..d6795a4be9 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -50,6 +50,7 @@ const { const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const { createContextHandlers } = require('~/app/clients/prompts'); +const { resolveConfigServers } = require('~/server/services/MCP'); const { getMCPServerTools } = require('~/server/services/Config'); const BaseClient = require('~/app/clients/BaseClient'); const { getMCPManager } = require('~/config'); @@ -377,6 +378,9 @@ class AgentClient extends BaseClient { */ const ephemeralAgent = this.options.req.body.ephemeralAgent; const mcpManager = getMCPManager(); + + const configServers = await resolveConfigServers(this.options.req); + await Promise.all( allAgents.map(({ agent, agentId }) => applyContextToAgent({ @@ -384,6 +388,7 @@ class AgentClient extends BaseClient { agentId, logger, mcpManager, + configServers, sharedRunContext, ephemeralAgent: agentId === this.options.agent.id ? 
ephemeralAgent : undefined, }), diff --git a/api/server/controllers/agents/client.test.js b/api/server/controllers/agents/client.test.js index 41a806f66d..1595f652f7 100644 --- a/api/server/controllers/agents/client.test.js +++ b/api/server/controllers/agents/client.test.js @@ -22,6 +22,10 @@ jest.mock('~/server/services/Config', () => ({ getMCPServerTools: jest.fn(), })); +jest.mock('~/server/services/MCP', () => ({ + resolveConfigServers: jest.fn().mockResolvedValue({}), +})); + jest.mock('~/models', () => ({ getAgent: jest.fn(), getRoleByName: jest.fn(), @@ -1315,7 +1319,7 @@ describe('AgentClient - titleConvo', () => { }); // Verify formatInstructionsForContext was called with correct server names - expect(mockFormatInstructions).toHaveBeenCalledWith(['server1', 'server2']); + expect(mockFormatInstructions).toHaveBeenCalledWith(['server1', 'server2'], {}); // Verify the instructions do NOT contain [object Promise] expect(client.options.agent.instructions).not.toContain('[object Promise]'); @@ -1355,10 +1359,10 @@ describe('AgentClient - titleConvo', () => { }); // Verify formatInstructionsForContext was called with ephemeral server names - expect(mockFormatInstructions).toHaveBeenCalledWith([ - 'ephemeral-server1', - 'ephemeral-server2', - ]); + expect(mockFormatInstructions).toHaveBeenCalledWith( + ['ephemeral-server1', 'ephemeral-server2'], + {}, + ); // Verify no [object Promise] in instructions expect(client.options.agent.instructions).not.toContain('[object Promise]'); diff --git a/api/server/controllers/mcp.js b/api/server/controllers/mcp.js index 729f01da9d..e31bb93bc6 100644 --- a/api/server/controllers/mcp.js +++ b/api/server/controllers/mcp.js @@ -14,6 +14,7 @@ const { isMCPInspectionFailedError, } = require('@librechat/api'); const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider'); +const { resolveConfigServers, resolveAllMcpConfigs } = require('~/server/services/MCP'); const { cacheMCPServerTools, getMCPServerTools } = 
require('~/server/services/Config'); const { getMCPManager, getMCPServersRegistry } = require('~/config'); @@ -57,7 +58,7 @@ function handleMCPError(error, res) { } /** - * Get all MCP tools available to the user + * Get all MCP tools available to the user. */ const getMCPTools = async (req, res) => { try { @@ -67,10 +68,10 @@ const getMCPTools = async (req, res) => { return res.status(401).json({ message: 'Unauthorized' }); } - const mcpConfig = await getMCPServersRegistry().getAllServerConfigs(userId); - const configuredServers = mcpConfig ? Object.keys(mcpConfig) : []; + const mcpConfig = await resolveAllMcpConfigs(userId, req.user); + const configuredServers = Object.keys(mcpConfig); - if (!mcpConfig || Object.keys(mcpConfig).length == 0) { + if (!configuredServers.length) { return res.status(200).json({ servers: {} }); } @@ -115,14 +116,11 @@ const getMCPTools = async (req, res) => { try { const serverTools = serverToolsMap.get(serverName); - // Get server config once const serverConfig = mcpConfig[serverName]; - const rawServerConfig = await getMCPServersRegistry().getServerConfig(serverName, userId); - // Initialize server object with all server-level data const server = { name: serverName, - icon: rawServerConfig?.iconPath || '', + icon: serverConfig?.iconPath || '', authenticated: true, authConfig: [], tools: [], @@ -183,7 +181,7 @@ const getMCPServersList = async (req, res) => { return res.status(401).json({ message: 'Unauthorized' }); } - const serverConfigs = await getMCPServersRegistry().getAllServerConfigs(userId); + const serverConfigs = await resolveAllMcpConfigs(userId, req.user); return res.json(redactAllServerSecrets(serverConfigs)); } catch (error) { logger.error('[getMCPServersList]', error); @@ -237,7 +235,12 @@ const getMCPServerById = async (req, res) => { if (!serverName) { return res.status(400).json({ message: 'Server name is required' }); } - const parsedConfig = await getMCPServersRegistry().getServerConfig(serverName, userId); + const 
configServers = await resolveConfigServers(req); + const parsedConfig = await getMCPServersRegistry().getServerConfig( + serverName, + userId, + configServers, + ); if (!parsedConfig) { return res.status(404).json({ message: 'MCP server not found' }); diff --git a/api/server/index.js b/api/server/index.js index ba376ab335..4b919b1ceb 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -8,8 +8,8 @@ const express = require('express'); const passport = require('passport'); const compression = require('compression'); const cookieParser = require('cookie-parser'); -const { logger } = require('@librechat/data-schemas'); const mongoSanitize = require('express-mongo-sanitize'); +const { logger, runAsSystem } = require('@librechat/data-schemas'); const { isEnabled, apiNotFound, @@ -21,6 +21,7 @@ const { createStreamServices, initializeFileStorage, updateInterfacePermissions, + preAuthTenantMiddleware, } = require('@librechat/api'); const { connectDb, indexSync } = require('~/db'); const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager'); @@ -59,11 +60,20 @@ const startServer = async () => { app.disable('x-powered-by'); app.set('trust proxy', trusted_proxy); - await seedDatabase(); - const appConfig = await getAppConfig(); + if (isEnabled(process.env.TENANT_ISOLATION_STRICT)) { + logger.warn( + '[Security] TENANT_ISOLATION_STRICT is active. 
Ensure your reverse proxy strips or sets ' + + 'the X-Tenant-Id header — untrusted clients must not be able to set it directly.', + ); + } + + await runAsSystem(seedDatabase); + const appConfig = await getAppConfig({ baseOnly: true }); initializeFileStorage(appConfig); - await performStartupChecks(appConfig); - await updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions }); + await runAsSystem(async () => { + await performStartupChecks(appConfig); + await updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions }); + }); const indexPath = path.join(appConfig.paths.dist, 'index.html'); let indexHTML = fs.readFileSync(indexPath, 'utf8'); @@ -137,10 +147,15 @@ const startServer = async () => { /* Per-request capability cache — must be registered before any route that calls hasCapability */ app.use(capabilityContextMiddleware); - app.use('/oauth', routes.oauth); + /* Pre-auth tenant context for unauthenticated routes that need tenant scoping. + * The reverse proxy / auth gateway sets `X-Tenant-Id` header for multi-tenant deployments. 
*/ + app.use('/oauth', preAuthTenantMiddleware, routes.oauth); /* API Endpoints */ - app.use('/api/auth', routes.auth); + app.use('/api/auth', preAuthTenantMiddleware, routes.auth); app.use('/api/admin', routes.adminAuth); + app.use('/api/admin/config', routes.adminConfig); + app.use('/api/admin/groups', routes.adminGroups); + app.use('/api/admin/roles', routes.adminRoles); app.use('/api/actions', routes.actions); app.use('/api/keys', routes.keys); app.use('/api/api-keys', routes.apiKeys); @@ -154,11 +169,11 @@ const startServer = async () => { app.use('/api/endpoints', routes.endpoints); app.use('/api/balance', routes.balance); app.use('/api/models', routes.models); - app.use('/api/config', routes.config); + app.use('/api/config', preAuthTenantMiddleware, routes.config); app.use('/api/assistants', routes.assistants); app.use('/api/files', await routes.files.initialize()); app.use('/images/', createValidateImageRequest(appConfig.secureImageLinks), routes.staticRoute); - app.use('/api/share', routes.share); + app.use('/api/share', preAuthTenantMiddleware, routes.share); app.use('/api/roles', routes.roles); app.use('/api/agents', routes.agents); app.use('/api/banner', routes.banner); @@ -204,8 +219,10 @@ const startServer = async () => { logger.info(`Server listening at http://${host == '0.0.0.0' ? 
'localhost' : host}:${port}`); } - await initializeMCPs(); - await initializeOAuthReconnectManager(); + await runAsSystem(async () => { + await initializeMCPs(); + await initializeOAuthReconnectManager(); + }); await checkMigrations(); // Configure stream services (auto-detects Redis from USE_REDIS env var) diff --git a/api/server/middleware/__tests__/requireJwtAuth.spec.js b/api/server/middleware/__tests__/requireJwtAuth.spec.js new file mode 100644 index 0000000000..bc288e5dab --- /dev/null +++ b/api/server/middleware/__tests__/requireJwtAuth.spec.js @@ -0,0 +1,116 @@ +/** + * Integration test: verifies that requireJwtAuth chains tenantContextMiddleware + * after successful passport authentication, so ALS tenant context is set for + * all downstream middleware and route handlers. + * + * requireJwtAuth must chain tenantContextMiddleware after passport populates + * req.user (not at global app.use() scope where req.user is undefined). + * If the chaining is removed, these tests fail. + */ + +const { getTenantId } = require('@librechat/data-schemas'); + +// ── Mocks ────────────────────────────────────────────────────────────── + +let mockPassportError = null; + +jest.mock('passport', () => ({ + authenticate: jest.fn(() => { + return (req, _res, done) => { + if (mockPassportError) { + return done(mockPassportError); + } + if (req._mockUser) { + req.user = req._mockUser; + } + done(); + }; + }), +})); + +// Mock @librechat/api — the real tenantContextMiddleware is TS and cannot be +// required directly from CJS tests. This thin wrapper mirrors the real logic +// (read req.user.tenantId, call tenantStorage.run) using the same data-schemas +// primitives. The real implementation is covered by packages/api tenant.spec.ts. 
+jest.mock('@librechat/api', () => { + const { tenantStorage } = require('@librechat/data-schemas'); + return { + isEnabled: jest.fn(() => false), + tenantContextMiddleware: (req, res, next) => { + const tenantId = req.user?.tenantId; + if (!tenantId) { + return next(); + } + return tenantStorage.run({ tenantId }, async () => next()); + }, + }; +}); + +// ── Helpers ───────────────────────────────────────────────────────────── + +const requireJwtAuth = require('../requireJwtAuth'); + +function mockReq(user) { + return { headers: {}, _mockUser: user }; +} + +function mockRes() { + return { status: jest.fn().mockReturnThis(), json: jest.fn().mockReturnThis() }; +} + +/** Runs requireJwtAuth and returns the tenantId observed inside next(). */ +function runAuth(user) { + return new Promise((resolve) => { + const req = mockReq(user); + const res = mockRes(); + requireJwtAuth(req, res, () => { + resolve(getTenantId()); + }); + }); +} + +// ── Tests ────────────────────────────────────────────────────────────── + +describe('requireJwtAuth tenant context chaining', () => { + afterEach(() => { + mockPassportError = null; + }); + + it('forwards passport errors to next() without entering tenant middleware', async () => { + mockPassportError = new Error('JWT signature invalid'); + const req = mockReq(undefined); + const res = mockRes(); + const err = await new Promise((resolve) => { + requireJwtAuth(req, res, (e) => resolve(e)); + }); + expect(err).toBeInstanceOf(Error); + expect(err.message).toBe('JWT signature invalid'); + expect(getTenantId()).toBeUndefined(); + }); + + it('sets ALS tenant context after passport auth succeeds', async () => { + const tenantId = await runAuth({ tenantId: 'tenant-abc', role: 'user' }); + expect(tenantId).toBe('tenant-abc'); + }); + + it('ALS tenant context is NOT set when user has no tenantId', async () => { + const tenantId = await runAuth({ role: 'user' }); + expect(tenantId).toBeUndefined(); + }); + + it('ALS tenant context is NOT set when 
user is undefined', async () => { + const tenantId = await runAuth(undefined); + expect(tenantId).toBeUndefined(); + }); + + it('concurrent requests get isolated tenant contexts', async () => { + const results = await Promise.all( + ['tenant-1', 'tenant-2', 'tenant-3'].map((tid) => runAuth({ tenantId: tid, role: 'user' })), + ); + expect(results).toEqual(['tenant-1', 'tenant-2', 'tenant-3']); + }); + + it('ALS context is not set at top-level scope (outside any request)', () => { + expect(getTenantId()).toBeUndefined(); + }); +}); diff --git a/api/server/middleware/checkDomainAllowed.js b/api/server/middleware/checkDomainAllowed.js index 754eb9c127..f7a3f00e68 100644 --- a/api/server/middleware/checkDomainAllowed.js +++ b/api/server/middleware/checkDomainAllowed.js @@ -18,6 +18,7 @@ const checkDomainAllowed = async (req, res, next) => { const email = req?.user?.email; const appConfig = await getAppConfig({ role: req?.user?.role, + tenantId: req?.user?.tenantId, }); if (email && !isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { diff --git a/api/server/middleware/config/app.js b/api/server/middleware/config/app.js index bca3c8f71d..fb5f89b229 100644 --- a/api/server/middleware/config/app.js +++ b/api/server/middleware/config/app.js @@ -4,7 +4,9 @@ const { getAppConfig } = require('~/server/services/Config'); const configMiddleware = async (req, res, next) => { try { const userRole = req.user?.role; - req.config = await getAppConfig({ role: userRole }); + const userId = req.user?.id; + const tenantId = req.user?.tenantId; + req.config = await getAppConfig({ role: userRole, userId, tenantId }); next(); } catch (error) { diff --git a/api/server/middleware/optionalJwtAuth.js b/api/server/middleware/optionalJwtAuth.js index 2f59fdda4a..d46478d36e 100644 --- a/api/server/middleware/optionalJwtAuth.js +++ b/api/server/middleware/optionalJwtAuth.js @@ -1,9 +1,10 @@ const cookies = require('cookie'); const passport = require('passport'); -const { 
isEnabled } = require('@librechat/api'); +const { isEnabled, tenantContextMiddleware } = require('@librechat/api'); // This middleware does not require authentication, -// but if the user is authenticated, it will set the user object. +// but if the user is authenticated, it will set the user object +// and establish tenant ALS context. const optionalJwtAuth = (req, res, next) => { const cookieHeader = req.headers.cookie; const tokenProvider = cookieHeader ? cookies.parse(cookieHeader).token_provider : null; @@ -13,6 +14,7 @@ const optionalJwtAuth = (req, res, next) => { } if (user) { req.user = user; + return tenantContextMiddleware(req, res, next); } next(); }; diff --git a/api/server/middleware/requireJwtAuth.js b/api/server/middleware/requireJwtAuth.js index 16b107aefc..b13e991b23 100644 --- a/api/server/middleware/requireJwtAuth.js +++ b/api/server/middleware/requireJwtAuth.js @@ -1,20 +1,29 @@ const cookies = require('cookie'); const passport = require('passport'); -const { isEnabled } = require('@librechat/api'); +const { isEnabled, tenantContextMiddleware } = require('@librechat/api'); /** - * Custom Middleware to handle JWT authentication, with support for OpenID token reuse - * Switches between JWT and OpenID authentication based on cookies and environment settings + * Custom Middleware to handle JWT authentication, with support for OpenID token reuse. + * Switches between JWT and OpenID authentication based on cookies and environment settings. + * + * After successful authentication (req.user populated), automatically chains into + * `tenantContextMiddleware` to propagate `req.user.tenantId` into AsyncLocalStorage + * for downstream Mongoose tenant isolation. */ const requireJwtAuth = (req, res, next) => { const cookieHeader = req.headers.cookie; const tokenProvider = cookieHeader ? 
cookies.parse(cookieHeader).token_provider : null; - if (tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) { - return passport.authenticate('openidJwt', { session: false })(req, res, next); - } + const strategy = + tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS) ? 'openidJwt' : 'jwt'; - return passport.authenticate('jwt', { session: false })(req, res, next); + passport.authenticate(strategy, { session: false })(req, res, (err) => { + if (err) { + return next(err); + } + // req.user is now populated by passport — set up tenant ALS context + tenantContextMiddleware(req, res, next); + }); }; module.exports = requireJwtAuth; diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index 1ad8cac087..f194f361d3 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -18,6 +18,7 @@ const mockRegistryInstance = { getServerConfig: jest.fn(), getOAuthServers: jest.fn(), getAllServerConfigs: jest.fn(), + ensureConfigServers: jest.fn().mockResolvedValue({}), addServer: jest.fn(), updateServer: jest.fn(), removeServer: jest.fn(), @@ -58,6 +59,7 @@ jest.mock('@librechat/api', () => { }); jest.mock('@librechat/data-schemas', () => ({ + getTenantId: jest.fn(), logger: { debug: jest.fn(), info: jest.fn(), @@ -93,14 +95,18 @@ jest.mock('~/server/services/Config', () => ({ getCachedTools: jest.fn(), getMCPServerTools: jest.fn(), loadCustomConfig: jest.fn(), + getAppConfig: jest.fn().mockResolvedValue({ mcpConfig: {} }), })); jest.mock('~/server/services/Config/mcp', () => ({ updateMCPServerTools: jest.fn(), })); +const mockResolveAllMcpConfigs = jest.fn().mockResolvedValue({}); jest.mock('~/server/services/MCP', () => ({ getMCPSetupData: jest.fn(), + resolveConfigServers: jest.fn().mockResolvedValue({}), + resolveAllMcpConfigs: (...args) => mockResolveAllMcpConfigs(...args), getServerConnectionStatus: jest.fn(), })); @@ -579,6 +585,112 @@ describe('MCP 
Routes', () => { ); }); + it('should use oauthHeaders from flow state when present', async () => { + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue({ status: 'PENDING' }), + completeFlow: jest.fn().mockResolvedValue(), + deleteFlow: jest.fn().mockResolvedValue(true), + }; + const mockFlowState = { + serverName: 'test-server', + userId: 'test-user-id', + metadata: { toolFlowId: 'tool-flow-123' }, + clientInfo: {}, + codeVerifier: 'test-verifier', + oauthHeaders: { 'X-Custom-Auth': 'header-value' }, + }; + const mockTokens = { access_token: 'tok', refresh_token: 'ref' }; + + MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); + MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); + MCPTokenStorage.storeTokens.mockResolvedValue(); + getLogStores.mockReturnValue({}); + require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + require('~/config').getOAuthReconnectionManager.mockReturnValue({ + clearReconnection: jest.fn(), + }); + require('~/config').getMCPManager.mockReturnValue({ + getUserConnection: jest.fn().mockResolvedValue({ + fetchTools: jest.fn().mockResolvedValue([]), + }), + }); + const { getCachedTools, setCachedTools } = require('~/server/services/Config'); + getCachedTools.mockResolvedValue({}); + setCachedTools.mockResolvedValue(); + + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ code: 'auth-code', state: flowId }); + + expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith( + flowId, + 'auth-code', + mockFlowManager, + { 'X-Custom-Auth': 'header-value' }, + ); + expect(mockRegistryInstance.getServerConfig).not.toHaveBeenCalled(); + }); + + it('should fall back to registry oauth_headers when flow state lacks them', async () => { + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue({ status: 
'PENDING' }), + completeFlow: jest.fn().mockResolvedValue(), + deleteFlow: jest.fn().mockResolvedValue(true), + }; + const mockFlowState = { + serverName: 'test-server', + userId: 'test-user-id', + metadata: { toolFlowId: 'tool-flow-123' }, + clientInfo: {}, + codeVerifier: 'test-verifier', + }; + const mockTokens = { access_token: 'tok', refresh_token: 'ref' }; + + MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); + MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); + MCPTokenStorage.storeTokens.mockResolvedValue(); + mockRegistryInstance.getServerConfig.mockResolvedValue({ + oauth_headers: { 'X-Registry-Header': 'from-registry' }, + }); + getLogStores.mockReturnValue({}); + require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + require('~/config').getOAuthReconnectionManager.mockReturnValue({ + clearReconnection: jest.fn(), + }); + require('~/config').getMCPManager.mockReturnValue({ + getUserConnection: jest.fn().mockResolvedValue({ + fetchTools: jest.fn().mockResolvedValue([]), + }), + }); + const { getCachedTools, setCachedTools } = require('~/server/services/Config'); + getCachedTools.mockResolvedValue({}); + setCachedTools.mockResolvedValue(); + + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ code: 'auth-code', state: flowId }); + + expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith( + flowId, + 'auth-code', + mockFlowManager, + { 'X-Registry-Header': 'from-registry' }, + ); + expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith( + 'test-server', + 'test-user-id', + undefined, + ); + }); + it('should redirect to error page when callback processing fails', async () => { MCPOAuthHandler.getFlowState.mockRejectedValue(new Error('Callback error')); const flowId = 'test-user-id:test-server'; @@ -1350,19 +1462,10 @@ 
describe('MCP Routes', () => { }, }); - expect(getMCPSetupData).toHaveBeenCalledWith('test-user-id'); + expect(getMCPSetupData).toHaveBeenCalledWith('test-user-id', expect.any(Object)); expect(getServerConnectionStatus).toHaveBeenCalledTimes(2); }); - it('should return 404 when MCP config is not found', async () => { - getMCPSetupData.mockRejectedValue(new Error('MCP config not found')); - - const response = await request(app).get('/api/mcp/connection/status'); - - expect(response.status).toBe(404); - expect(response.body).toEqual({ error: 'MCP config not found' }); - }); - it('should return 500 when connection status check fails', async () => { getMCPSetupData.mockRejectedValue(new Error('Database error')); @@ -1437,15 +1540,6 @@ describe('MCP Routes', () => { }); }); - it('should return 404 when MCP config is not found', async () => { - getMCPSetupData.mockRejectedValue(new Error('MCP config not found')); - - const response = await request(app).get('/api/mcp/connection/status/test-server'); - - expect(response.status).toBe(404); - expect(response.body).toEqual({ error: 'MCP config not found' }); - }); - it('should return 500 when connection status check fails', async () => { getMCPSetupData.mockRejectedValue(new Error('Database connection failed')); @@ -1704,7 +1798,7 @@ describe('MCP Routes', () => { }, }; - mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockServerConfigs); + mockResolveAllMcpConfigs.mockResolvedValue(mockServerConfigs); const response = await request(app).get('/api/mcp/servers'); @@ -1721,11 +1815,14 @@ describe('MCP Routes', () => { }); expect(response.body['server-1'].headers).toBeUndefined(); expect(response.body['server-2'].headers).toBeUndefined(); - expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith('test-user-id'); + expect(mockResolveAllMcpConfigs).toHaveBeenCalledWith( + 'test-user-id', + expect.objectContaining({ id: 'test-user-id' }), + ); }); it('should return empty object when no servers are 
configured', async () => { - mockRegistryInstance.getAllServerConfigs.mockResolvedValue({}); + mockResolveAllMcpConfigs.mockResolvedValue({}); const response = await request(app).get('/api/mcp/servers'); @@ -1749,7 +1846,7 @@ describe('MCP Routes', () => { }); it('should return 500 when server config retrieval fails', async () => { - mockRegistryInstance.getAllServerConfigs.mockRejectedValue(new Error('Database error')); + mockResolveAllMcpConfigs.mockRejectedValue(new Error('Database error')); const response = await request(app).get('/api/mcp/servers'); @@ -1939,11 +2036,12 @@ describe('MCP Routes', () => { expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith( 'test-server', 'test-user-id', + {}, ); }); it('should return 404 when server not found', async () => { - mockRegistryInstance.getServerConfig.mockResolvedValue(null); + mockRegistryInstance.getServerConfig.mockResolvedValue(undefined); const response = await request(app).get('/api/mcp/servers/non-existent-server'); diff --git a/api/server/routes/admin/config.js b/api/server/routes/admin/config.js new file mode 100644 index 0000000000..0632077ea9 --- /dev/null +++ b/api/server/routes/admin/config.js @@ -0,0 +1,40 @@ +const express = require('express'); +const { createAdminConfigHandlers } = require('@librechat/api'); +const { SystemCapabilities } = require('@librechat/data-schemas'); +const { + hasConfigCapability, + requireCapability, +} = require('~/server/middleware/roles/capabilities'); +const { getAppConfig, invalidateConfigCaches } = require('~/server/services/Config'); +const { requireJwtAuth } = require('~/server/middleware'); +const db = require('~/models'); + +const router = express.Router(); + +const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN); + +const handlers = createAdminConfigHandlers({ + listAllConfigs: db.listAllConfigs, + findConfigByPrincipal: db.findConfigByPrincipal, + upsertConfig: db.upsertConfig, + patchConfigFields: db.patchConfigFields, + 
unsetConfigField: db.unsetConfigField, + deleteConfig: db.deleteConfig, + toggleConfigActive: db.toggleConfigActive, + hasConfigCapability, + getAppConfig, + invalidateConfigCaches, +}); + +router.use(requireJwtAuth, requireAdminAccess); + +router.get('/', handlers.listConfigs); +router.get('/base', handlers.getBaseConfig); +router.get('/:principalType/:principalId', handlers.getConfig); +router.put('/:principalType/:principalId', handlers.upsertConfigOverrides); +router.patch('/:principalType/:principalId/fields', handlers.patchConfigField); +router.delete('/:principalType/:principalId/fields', handlers.deleteConfigField); +router.delete('/:principalType/:principalId', handlers.deleteConfigOverrides); +router.patch('/:principalType/:principalId/active', handlers.toggleConfig); + +module.exports = router; diff --git a/api/server/routes/admin/groups.js b/api/server/routes/admin/groups.js new file mode 100644 index 0000000000..7ca93acaa2 --- /dev/null +++ b/api/server/routes/admin/groups.js @@ -0,0 +1,41 @@ +const express = require('express'); +const { createAdminGroupsHandlers } = require('@librechat/api'); +const { SystemCapabilities } = require('@librechat/data-schemas'); +const { requireCapability } = require('~/server/middleware/roles/capabilities'); +const { requireJwtAuth } = require('~/server/middleware'); +const db = require('~/models'); + +const router = express.Router(); + +const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN); +const requireReadGroups = requireCapability(SystemCapabilities.READ_GROUPS); +const requireManageGroups = requireCapability(SystemCapabilities.MANAGE_GROUPS); + +const handlers = createAdminGroupsHandlers({ + listGroups: db.listGroups, + countGroups: db.countGroups, + findGroupById: db.findGroupById, + createGroup: db.createGroup, + updateGroupById: db.updateGroupById, + deleteGroup: db.deleteGroup, + addUserToGroup: db.addUserToGroup, + removeUserFromGroup: db.removeUserFromGroup, + removeMemberById: 
db.removeMemberById, + findUsers: db.findUsers, + deleteConfig: db.deleteConfig, + deleteAclEntries: db.deleteAclEntries, + deleteGrantsForPrincipal: db.deleteGrantsForPrincipal, +}); + +router.use(requireJwtAuth, requireAdminAccess); + +router.get('/', requireReadGroups, handlers.listGroups); +router.post('/', requireManageGroups, handlers.createGroup); +router.get('/:id', requireReadGroups, handlers.getGroup); +router.patch('/:id', requireManageGroups, handlers.updateGroup); +router.delete('/:id', requireManageGroups, handlers.deleteGroup); +router.get('/:id/members', requireReadGroups, handlers.getGroupMembers); +router.post('/:id/members', requireManageGroups, handlers.addGroupMember); +router.delete('/:id/members/:userId', requireManageGroups, handlers.removeGroupMember); + +module.exports = router; diff --git a/api/server/routes/admin/roles.js b/api/server/routes/admin/roles.js new file mode 100644 index 0000000000..2d0f1b1128 --- /dev/null +++ b/api/server/routes/admin/roles.js @@ -0,0 +1,43 @@ +const express = require('express'); +const { createAdminRolesHandlers } = require('@librechat/api'); +const { SystemCapabilities } = require('@librechat/data-schemas'); +const { requireCapability } = require('~/server/middleware/roles/capabilities'); +const { requireJwtAuth } = require('~/server/middleware'); +const db = require('~/models'); + +const router = express.Router(); + +const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN); +const requireReadRoles = requireCapability(SystemCapabilities.READ_ROLES); +const requireManageRoles = requireCapability(SystemCapabilities.MANAGE_ROLES); + +const handlers = createAdminRolesHandlers({ + listRoles: db.listRoles, + countRoles: db.countRoles, + getRoleByName: db.getRoleByName, + createRoleByName: db.createRoleByName, + updateRoleByName: db.updateRoleByName, + updateAccessPermissions: db.updateAccessPermissions, + deleteRoleByName: db.deleteRoleByName, + findUser: db.findUser, + updateUser: 
db.updateUser, + updateUsersByRole: db.updateUsersByRole, + findUserIdsByRole: db.findUserIdsByRole, + updateUsersRoleByIds: db.updateUsersRoleByIds, + listUsersByRole: db.listUsersByRole, + countUsersByRole: db.countUsersByRole, +}); + +router.use(requireJwtAuth, requireAdminAccess); + +router.get('/', requireReadRoles, handlers.listRoles); +router.post('/', requireManageRoles, handlers.createRole); +router.get('/:name', requireReadRoles, handlers.getRole); +router.patch('/:name', requireManageRoles, handlers.updateRole); +router.delete('/:name', requireManageRoles, handlers.deleteRole); +router.patch('/:name/permissions', requireManageRoles, handlers.updateRolePermissions); +router.get('/:name/members', requireReadRoles, handlers.getRoleMembers); +router.post('/:name/members', requireManageRoles, handlers.addRoleMember); +router.delete('/:name/members/:userId', requireManageRoles, handlers.removeRoleMember); + +module.exports = router; diff --git a/api/server/routes/agents/__tests__/streamTenant.spec.js b/api/server/routes/agents/__tests__/streamTenant.spec.js new file mode 100644 index 0000000000..1f89953186 --- /dev/null +++ b/api/server/routes/agents/__tests__/streamTenant.spec.js @@ -0,0 +1,186 @@ +const express = require('express'); +const request = require('supertest'); + +const mockGenerationJobManager = { + getJob: jest.fn(), + subscribe: jest.fn(), + getResumeState: jest.fn(), + abortJob: jest.fn(), + getActiveJobIdsForUser: jest.fn().mockResolvedValue([]), +}; + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { + debug: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + info: jest.fn(), + }, +})); + +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + isEnabled: jest.fn().mockReturnValue(false), + GenerationJobManager: mockGenerationJobManager, +})); + +jest.mock('~/models', () => ({ + saveMessage: jest.fn(), +})); + +let mockUserId = 'user-123'; +let 
mockTenantId; + +jest.mock('~/server/middleware', () => ({ + uaParser: (req, res, next) => next(), + checkBan: (req, res, next) => next(), + requireJwtAuth: (req, res, next) => { + req.user = { id: mockUserId, tenantId: mockTenantId }; + next(); + }, + messageIpLimiter: (req, res, next) => next(), + configMiddleware: (req, res, next) => next(), + messageUserLimiter: (req, res, next) => next(), +})); + +jest.mock('~/server/routes/agents/chat', () => require('express').Router()); +jest.mock('~/server/routes/agents/v1', () => ({ + v1: require('express').Router(), +})); +jest.mock('~/server/routes/agents/openai', () => require('express').Router()); +jest.mock('~/server/routes/agents/responses', () => require('express').Router()); + +const agentsRouter = require('../index'); +const app = express(); +app.use(express.json()); +app.use('/agents', agentsRouter); + +function mockSubscribeSuccess() { + mockGenerationJobManager.subscribe.mockImplementation((_streamId, _writeEvent, onDone) => { + process.nextTick(() => onDone({ done: true })); + return { unsubscribe: jest.fn() }; + }); +} + +describe('SSE stream tenant isolation', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockUserId = 'user-123'; + mockTenantId = undefined; + }); + + describe('GET /chat/stream/:streamId', () => { + it('returns 403 when a user from a different tenant accesses a stream', async () => { + mockUserId = 'user-456'; + mockTenantId = 'tenant-b'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-456', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(403); + expect(res.body.error).toBe('Unauthorized'); + }); + + it('returns 404 when stream does not exist', async () => { + mockGenerationJobManager.getJob.mockResolvedValue(null); + + const res = await request(app).get('/agents/chat/stream/nonexistent'); + expect(res.status).toBe(404); + }); + + it('proceeds 
past tenant guard when tenant matches', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-a'; + mockSubscribeSuccess(); + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(200); + expect(mockGenerationJobManager.subscribe).toHaveBeenCalledTimes(1); + }); + + it('proceeds past tenant guard when job has no tenantId (single-tenant mode)', async () => { + mockUserId = 'user-123'; + mockTenantId = undefined; + mockSubscribeSuccess(); + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(200); + expect(mockGenerationJobManager.subscribe).toHaveBeenCalledTimes(1); + }); + + it('returns 403 when job has tenantId but user has no tenantId', async () => { + mockUserId = 'user-123'; + mockTenantId = undefined; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'some-tenant' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(403); + }); + }); + + describe('GET /chat/status/:conversationId', () => { + it('returns 403 when tenant does not match', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-b'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/status/conv-123'); + expect(res.status).toBe(403); + expect(res.body.error).toBe('Unauthorized'); + }); + + it('returns status when tenant matches', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-a'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { 
userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + createdAt: Date.now(), + }); + mockGenerationJobManager.getResumeState.mockResolvedValue(null); + + const res = await request(app).get('/agents/chat/status/conv-123'); + expect(res.status).toBe(200); + expect(res.body.active).toBe(true); + }); + }); + + describe('POST /chat/abort', () => { + it('returns 403 when tenant does not match', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-b'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).post('/agents/chat/abort').send({ streamId: 'stream-123' }); + expect(res.status).toBe(403); + expect(res.body.error).toBe('Unauthorized'); + }); + }); +}); diff --git a/api/server/routes/agents/index.js b/api/server/routes/agents/index.js index 86966a3f3e..eb42046bed 100644 --- a/api/server/routes/agents/index.js +++ b/api/server/routes/agents/index.js @@ -17,6 +17,11 @@ const chat = require('./chat'); const { LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {}; +/** Untenanted jobs (pre-multi-tenancy) remain accessible if the userId check passes. 
*/ +function hasTenantMismatch(job, user) { + return job.metadata?.tenantId != null && job.metadata.tenantId !== user.tenantId; +} + const router = express.Router(); /** @@ -67,6 +72,10 @@ router.get('/chat/stream/:streamId', async (req, res) => { return res.status(403).json({ error: 'Unauthorized' }); } + if (hasTenantMismatch(job, req.user)) { + return res.status(403).json({ error: 'Unauthorized' }); + } + res.setHeader('Content-Encoding', 'identity'); res.setHeader('Content-Type', 'text/event-stream'); res.setHeader('Cache-Control', 'no-cache, no-transform'); @@ -150,7 +159,10 @@ router.get('/chat/stream/:streamId', async (req, res) => { * @returns { activeJobIds: string[] } */ router.get('/chat/active', async (req, res) => { - const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(req.user.id); + const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser( + req.user.id, + req.user.tenantId, + ); res.json({ activeJobIds }); }); @@ -174,6 +186,10 @@ router.get('/chat/status/:conversationId', async (req, res) => { return res.status(403).json({ error: 'Unauthorized' }); } + if (hasTenantMismatch(job, req.user)) { + return res.status(403).json({ error: 'Unauthorized' }); + } + // Get resume state which contains aggregatedContent // Avoid calling both getStreamInfo and getResumeState (both fetch content) const resumeState = await GenerationJobManager.getResumeState(conversationId); @@ -213,7 +229,10 @@ router.post('/chat/abort', async (req, res) => { // This handles the case where frontend sends "new" but job was created with a UUID if (!job && userId) { logger.debug(`[AgentStream] Job not found by ID, checking active jobs for user: ${userId}`); - const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(userId); + const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser( + userId, + req.user.tenantId, + ); if (activeJobIds.length > 0) { // Abort the most recent active job for this user jobStreamId = 
activeJobIds[0]; @@ -230,6 +249,10 @@ router.post('/chat/abort', async (req, res) => { return res.status(403).json({ error: 'Unauthorized' }); } + if (hasTenantMismatch(job, req.user)) { + return res.status(403).json({ error: 'Unauthorized' }); + } + logger.debug(`[AgentStream] Job found, aborting: ${jobStreamId}`); const abortResult = await GenerationJobManager.abortJob(jobStreamId); logger.debug(`[AgentStream] Job aborted successfully: ${jobStreamId}`, { diff --git a/api/server/routes/config.js b/api/server/routes/config.js index bf60f57e08..8caa180854 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -1,7 +1,7 @@ const express = require('express'); -const { logger } = require('@librechat/data-schemas'); const { isEnabled, getBalanceConfig } = require('@librechat/api'); const { CacheKeys, defaultSocialLogins } = require('librechat-data-provider'); +const { logger, getTenantId, scopedCacheKey } = require('@librechat/data-schemas'); const { getLdapConfig } = require('~/server/services/Config/ldap'); const { getAppConfig } = require('~/server/services/Config/app'); const { getLogStores } = require('~/cache'); @@ -23,7 +23,8 @@ const openidReuseTokens = isEnabled(process.env.OPENID_REUSE_TOKENS); router.get('/', async function (req, res) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG); + const cacheKey = scopedCacheKey(CacheKeys.STARTUP_CONFIG); + const cachedStartupConfig = await cache.get(cacheKey); if (cachedStartupConfig) { res.send(cachedStartupConfig); return; @@ -37,7 +38,10 @@ router.get('/', async function (req, res) { const ldap = getLdapConfig(); try { - const appConfig = await getAppConfig({ role: req.user?.role }); + const appConfig = await getAppConfig({ + role: req.user?.role, + tenantId: req.user?.tenantId || getTenantId(), + }); const isOpenIdEnabled = !!process.env.OPENID_CLIENT_ID && @@ -141,7 +145,7 @@ router.get('/', async function (req, res) 
{ payload.customFooter = process.env.CUSTOM_FOOTER; } - await cache.set(CacheKeys.STARTUP_CONFIG, payload); + await cache.set(cacheKey, payload); return res.status(200).send(payload); } catch (err) { logger.error('Error in startup config', err); diff --git a/api/server/routes/index.js b/api/server/routes/index.js index 6a48919db3..71ae041fc2 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -2,6 +2,9 @@ const accessPermissions = require('./accessPermissions'); const assistants = require('./assistants'); const categories = require('./categories'); const adminAuth = require('./admin/auth'); +const adminConfig = require('./admin/config'); +const adminGroups = require('./admin/groups'); +const adminRoles = require('./admin/roles'); const endpoints = require('./endpoints'); const staticRoute = require('./static'); const messages = require('./messages'); @@ -31,6 +34,9 @@ module.exports = { mcp, auth, adminAuth, + adminConfig, + adminGroups, + adminRoles, keys, apiKeys, user, diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index d6d7ed5ea0..c6496ad4b4 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -1,5 +1,5 @@ const { Router } = require('express'); -const { logger } = require('@librechat/data-schemas'); +const { logger, getTenantId } = require('@librechat/data-schemas'); const { CacheKeys, Constants, @@ -36,7 +36,11 @@ const { getFlowStateManager, getMCPManager, } = require('~/config'); -const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP'); +const { + getServerConnectionStatus, + resolveConfigServers, + getMCPSetupData, +} = require('~/server/services/MCP'); const { requireJwtAuth, canAccessMCPServerResource } = require('~/server/middleware'); const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { updateMCPServerTools } = require('~/server/services/Config/mcp'); @@ -101,7 +105,8 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, 
setOAuthSession, async return res.status(400).json({ error: 'Invalid flow state' }); } - const oauthHeaders = await getOAuthHeaders(serverName, userId); + const configServers = await resolveConfigServers(req); + const oauthHeaders = await getOAuthHeaders(serverName, userId, configServers); const { authorizationUrl, flowId: oauthFlowId, @@ -233,7 +238,14 @@ router.get('/:serverName/oauth/callback', async (req, res) => { } logger.debug('[MCP OAuth] Completing OAuth flow'); - const oauthHeaders = await getOAuthHeaders(serverName, flowState.userId); + if (!flowState.oauthHeaders) { + logger.warn( + '[MCP OAuth] oauthHeaders absent from flow state — config-source server oauth_headers will be empty', + { serverName, flowId }, + ); + } + const oauthHeaders = + flowState.oauthHeaders ?? (await getOAuthHeaders(serverName, flowState.userId)); const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager, oauthHeaders); logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route'); @@ -497,7 +509,12 @@ router.post( logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`); const mcpManager = getMCPManager(); - const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id); + const configServers = await resolveConfigServers(req); + const serverConfig = await getMCPServersRegistry().getServerConfig( + serverName, + user.id, + configServers, + ); if (!serverConfig) { return res.status(404).json({ error: `MCP server '${serverName}' not found in configuration`, @@ -522,6 +539,8 @@ router.post( const result = await reinitMCPServer({ user, serverName, + serverConfig, + configServers, userMCPAuthMap, }); @@ -564,6 +583,7 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => { const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData( user.id, + { role: user.role, tenantId: getTenantId() }, ); const connectionStatus = {}; @@ -593,9 +613,6 @@ 
router.get('/connection/status', requireJwtAuth, async (req, res) => { connectionStatus, }); } catch (error) { - if (error.message === 'MCP config not found') { - return res.status(404).json({ error: error.message }); - } logger.error('[MCP Connection Status] Failed to get connection status', error); res.status(500).json({ error: 'Failed to get connection status' }); } @@ -616,6 +633,7 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) => const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData( user.id, + { role: user.role, tenantId: getTenantId() }, ); if (!mcpConfig[serverName]) { @@ -640,9 +658,6 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) => requiresOAuth: serverStatus.requiresOAuth, }); } catch (error) { - if (error.message === 'MCP config not found') { - return res.status(404).json({ error: error.message }); - } logger.error( `[MCP Per-Server Status] Failed to get connection status for ${req.params.serverName}`, error, @@ -664,7 +679,12 @@ router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, a return res.status(401).json({ error: 'User not authenticated' }); } - const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id); + const configServers = await resolveConfigServers(req); + const serverConfig = await getMCPServersRegistry().getServerConfig( + serverName, + user.id, + configServers, + ); if (!serverConfig) { return res.status(404).json({ error: `MCP server '${serverName}' not found in configuration`, @@ -703,8 +723,12 @@ router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, a } }); -async function getOAuthHeaders(serverName, userId) { - const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, userId); +async function getOAuthHeaders(serverName, userId, configServers) { + const serverConfig = await getMCPServersRegistry().getServerConfig( + serverName, + 
userId, + configServers, + ); return serverConfig?.oauth_headers ?? {}; } diff --git a/api/server/services/AuthService.js b/api/server/services/AuthService.js index ef50a365b9..816a0eac5b 100644 --- a/api/server/services/AuthService.js +++ b/api/server/services/AuthService.js @@ -13,6 +13,7 @@ const { checkEmailConfig, isEmailDomainAllowed, shouldUseSecureCookie, + resolveAppConfigForUser, } = require('@librechat/api'); const { findUser, @@ -189,7 +190,7 @@ const registerUser = async (user, additionalData = {}) => { let newUserId; try { - const appConfig = await getAppConfig(); + const appConfig = await getAppConfig({ baseOnly: true }); if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { const errorMessage = 'The email address provided cannot be used. Please use a different email address.'; @@ -255,19 +256,52 @@ const registerUser = async (user, additionalData = {}) => { }; /** - * Request password reset + * Request password reset. + * + * Uses a two-phase domain check: fast-fail with the memory-cached base config + * (zero DB queries) to block globally denied domains before user lookup, then + * re-check with tenant-scoped config after user lookup so tenant-specific + * restrictions are enforced. + * + * Phase 1 (base check) returns an Error (HTTP 400) — this intentionally reveals + * that the domain is globally blocked, but fires before any DB lookup so it + * cannot confirm user existence. Phase 2 (tenant check) returns the generic + * success message (HTTP 200) to prevent user-enumeration via status codes. 
+ * * @param {ServerRequest} req */ const requestPasswordReset = async (req) => { const { email } = req.body; - const appConfig = await getAppConfig(); - if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + + const baseConfig = await getAppConfig({ baseOnly: true }); + if (!isEmailDomainAllowed(email, baseConfig?.registration?.allowedDomains)) { + logger.warn( + `[requestPasswordReset] Blocked - email domain not allowed [Email: ${email}] [IP: ${req.ip}]`, + ); const error = new Error(ErrorTypes.AUTH_FAILED); error.code = ErrorTypes.AUTH_FAILED; error.message = 'Email domain not allowed'; return error; } - const user = await findUser({ email }, 'email _id'); + + const user = await findUser({ email }, 'email _id role tenantId'); + let appConfig = baseConfig; + if (user?.tenantId) { + try { + appConfig = await resolveAppConfigForUser(getAppConfig, user); + } catch (err) { + logger.error('[requestPasswordReset] Failed to resolve tenant config, using base:', err); + } + } + + if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + logger.warn( + `[requestPasswordReset] Tenant config blocked domain [Email: ${email}] [IP: ${req.ip}]`, + ); + return { + message: 'If an account with that email exists, a password reset link has been sent to it.', + }; + } const emailEnabled = checkEmailConfig(); logger.warn(`[requestPasswordReset] [Password reset request initiated] [Email: ${email}]`); diff --git a/api/server/services/AuthService.spec.js b/api/server/services/AuthService.spec.js index da78f8d775..c8abafdbe5 100644 --- a/api/server/services/AuthService.spec.js +++ b/api/server/services/AuthService.spec.js @@ -14,6 +14,7 @@ jest.mock('@librechat/api', () => ({ isEmailDomainAllowed: jest.fn(), math: jest.fn((val, fallback) => (val ? 
Number(val) : fallback)), shouldUseSecureCookie: jest.fn(() => false), + resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})), })); jest.mock('~/models', () => ({ findUser: jest.fn(), @@ -35,8 +36,14 @@ jest.mock('~/strategies/validators', () => ({ registerSchema: { parse: jest.fn() jest.mock('~/server/services/Config', () => ({ getAppConfig: jest.fn() })); jest.mock('~/server/utils', () => ({ sendEmail: jest.fn() })); -const { shouldUseSecureCookie } = require('@librechat/api'); -const { setOpenIDAuthTokens } = require('./AuthService'); +const { + shouldUseSecureCookie, + isEmailDomainAllowed, + resolveAppConfigForUser, +} = require('@librechat/api'); +const { findUser } = require('~/models'); +const { getAppConfig } = require('~/server/services/Config'); +const { setOpenIDAuthTokens, requestPasswordReset } = require('./AuthService'); /** Helper to build a mock Express response */ function mockResponse() { @@ -267,3 +274,68 @@ describe('setOpenIDAuthTokens', () => { }); }); }); + +describe('requestPasswordReset', () => { + beforeEach(() => { + jest.clearAllMocks(); + isEmailDomainAllowed.mockReturnValue(true); + getAppConfig.mockResolvedValue({ + registration: { allowedDomains: ['example.com'] }, + }); + resolveAppConfigForUser.mockResolvedValue({ + registration: { allowedDomains: ['example.com'] }, + }); + }); + + it('should fast-fail with base config before DB lookup for blocked domains', async () => { + isEmailDomainAllowed.mockReturnValue(false); + + const req = { body: { email: 'blocked@evil.com' }, ip: '127.0.0.1' }; + const result = await requestPasswordReset(req); + + expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + expect(findUser).not.toHaveBeenCalled(); + expect(result).toBeInstanceOf(Error); + }); + + it('should call resolveAppConfigForUser for tenant user', async () => { + const user = { + _id: 'user-tenant', + email: 'user@example.com', + tenantId: 'tenant-x', + role: 'USER', + }; + findUser.mockResolvedValue(user); 
+ + const req = { body: { email: 'user@example.com' }, ip: '127.0.0.1' }; + await requestPasswordReset(req); + + expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, user); + }); + + it('should reuse baseConfig for non-tenant user without calling resolveAppConfigForUser', async () => { + findUser.mockResolvedValue({ _id: 'user-no-tenant', email: 'user@example.com' }); + + const req = { body: { email: 'user@example.com' }, ip: '127.0.0.1' }; + await requestPasswordReset(req); + + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); + }); + + it('should return generic response when tenant config blocks the domain (non-enumerable)', async () => { + const user = { + _id: 'user-tenant', + email: 'user@example.com', + tenantId: 'tenant-x', + role: 'USER', + }; + findUser.mockResolvedValue(user); + isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false); + + const req = { body: { email: 'user@example.com' }, ip: '127.0.0.1' }; + const result = await requestPasswordReset(req); + + expect(result).not.toBeInstanceOf(Error); + expect(result.message).toContain('If an account with that email exists'); + }); +}); diff --git a/api/server/services/Config/__tests__/invalidateConfigCaches.spec.js b/api/server/services/Config/__tests__/invalidateConfigCaches.spec.js new file mode 100644 index 0000000000..49e94bc081 --- /dev/null +++ b/api/server/services/Config/__tests__/invalidateConfigCaches.spec.js @@ -0,0 +1,139 @@ +// ── Mocks ────────────────────────────────────────────────────────────── + +const mockConfigStoreDelete = jest.fn().mockResolvedValue(true); +const mockClearAppConfigCache = jest.fn().mockResolvedValue(undefined); +const mockClearOverrideCache = jest.fn().mockResolvedValue(undefined); + +jest.mock('~/cache/getLogStores', () => { + return jest.fn(() => ({ + delete: mockConfigStoreDelete, + })); +}); + +jest.mock('~/server/services/start/tools', () => ({ + loadAndFormatTools: jest.fn(() => ({})), +})); + 
+jest.mock('../loadCustomConfig', () => jest.fn().mockResolvedValue({})); + +jest.mock('@librechat/data-schemas', () => { + const actual = jest.requireActual('@librechat/data-schemas'); + return { ...actual, AppService: jest.fn(() => ({ availableTools: {} })) }; +}); + +jest.mock('~/models', () => ({ + getApplicableConfigs: jest.fn().mockResolvedValue([]), + getUserPrincipals: jest.fn().mockResolvedValue([]), +})); + +const mockInvalidateCachedTools = jest.fn().mockResolvedValue(undefined); +jest.mock('../getCachedTools', () => ({ + setCachedTools: jest.fn().mockResolvedValue(undefined), + invalidateCachedTools: mockInvalidateCachedTools, +})); + +const mockClearMcpConfigCache = jest.fn().mockResolvedValue(undefined); +jest.mock('@librechat/api', () => ({ + createAppConfigService: jest.fn(() => ({ + getAppConfig: jest.fn().mockResolvedValue({ availableTools: {} }), + clearAppConfigCache: mockClearAppConfigCache, + clearOverrideCache: mockClearOverrideCache, + })), + clearMcpConfigCache: mockClearMcpConfigCache, +})); + +// ── Tests ────────────────────────────────────────────────────────────── + +const { CacheKeys } = require('librechat-data-provider'); +const { invalidateConfigCaches } = require('../app'); + +describe('invalidateConfigCaches', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('clears all four caches', async () => { + await invalidateConfigCaches(); + + expect(mockClearAppConfigCache).toHaveBeenCalledTimes(1); + expect(mockClearOverrideCache).toHaveBeenCalledTimes(1); + expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true }); + expect(mockConfigStoreDelete).toHaveBeenCalledWith(CacheKeys.ENDPOINT_CONFIG); + }); + + it('passes tenantId through to clearOverrideCache', async () => { + await invalidateConfigCaches('tenant-a'); + + expect(mockClearOverrideCache).toHaveBeenCalledWith('tenant-a'); + expect(mockClearAppConfigCache).toHaveBeenCalledTimes(1); + 
expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true }); + }); + + it('does not throw when CONFIG_STORE.delete fails', async () => { + mockConfigStoreDelete.mockRejectedValueOnce(new Error('store not found')); + + await expect(invalidateConfigCaches()).resolves.not.toThrow(); + + // Other caches should still have been invalidated + expect(mockClearAppConfigCache).toHaveBeenCalledTimes(1); + expect(mockClearOverrideCache).toHaveBeenCalledTimes(1); + expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true }); + }); + + it('all operations run in parallel (not sequentially)', async () => { + const order = []; + + mockClearAppConfigCache.mockImplementation( + () => + new Promise((r) => + setTimeout(() => { + order.push('base'); + r(); + }, 10), + ), + ); + mockClearOverrideCache.mockImplementation( + () => + new Promise((r) => + setTimeout(() => { + order.push('override'); + r(); + }, 10), + ), + ); + mockInvalidateCachedTools.mockImplementation( + () => + new Promise((r) => + setTimeout(() => { + order.push('tools'); + r(); + }, 10), + ), + ); + mockConfigStoreDelete.mockImplementation( + () => + new Promise((r) => + setTimeout(() => { + order.push('endpoint'); + r(); + }, 10), + ), + ); + + await invalidateConfigCaches(); + + // All four should have been called (parallel execution via Promise.allSettled) + expect(order).toHaveLength(4); + expect(new Set(order)).toEqual(new Set(['base', 'override', 'tools', 'endpoint'])); + }); + + it('resolves even when clearAppConfigCache throws (partial failure)', async () => { + mockClearAppConfigCache.mockRejectedValueOnce(new Error('cache connection lost')); + + await expect(invalidateConfigCaches()).resolves.not.toThrow(); + + // Other caches should still have been invalidated despite the failure + expect(mockClearOverrideCache).toHaveBeenCalledTimes(1); + expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true }); + }); +}); diff --git 
a/api/server/services/Config/app.js b/api/server/services/Config/app.js index 75a5cbe56d..3256732ec2 100644 --- a/api/server/services/Config/app.js +++ b/api/server/services/Config/app.js @@ -1,12 +1,12 @@ const { CacheKeys } = require('librechat-data-provider'); -const { logger, AppService } = require('@librechat/data-schemas'); +const { AppService, logger, scopedCacheKey } = require('@librechat/data-schemas'); +const { createAppConfigService, clearMcpConfigCache } = require('@librechat/api'); +const { setCachedTools, invalidateCachedTools } = require('./getCachedTools'); const { loadAndFormatTools } = require('~/server/services/start/tools'); const loadCustomConfig = require('./loadCustomConfig'); -const { setCachedTools } = require('./getCachedTools'); const getLogStores = require('~/cache/getLogStores'); const paths = require('~/config/paths'); - -const BASE_CONFIG_KEY = '_BASE_'; +const db = require('~/models'); const loadBaseConfig = async () => { /** @type {TCustomConfig} */ @@ -20,65 +20,67 @@ const loadBaseConfig = async () => { return AppService({ config, paths, systemTools }); }; +const { getAppConfig, clearAppConfigCache, clearOverrideCache } = createAppConfigService({ + loadBaseConfig, + setCachedTools, + getCache: getLogStores, + cacheKeys: CacheKeys, + getApplicableConfigs: db.getApplicableConfigs, + getUserPrincipals: db.getUserPrincipals, +}); + /** - * Get the app configuration based on user context - * @param {Object} [options] - * @param {string} [options.role] - User role for role-based config - * @param {boolean} [options.refresh] - Force refresh the cache - * @returns {Promise} + * Deletes ENDPOINT_CONFIG entries from CONFIG_STORE. + * Clears both the tenant-scoped key (if in tenant context) and the + * unscoped base key (populated by unauthenticated /api/endpoints calls). + * Other tenants' scoped keys are NOT actively cleared — they expire + * via TTL. 
Config mutations in one tenant do not propagate immediately + * to other tenants' endpoint config caches. */ -async function getAppConfig(options = {}) { - const { role, refresh } = options; - - const cache = getLogStores(CacheKeys.APP_CONFIG); - const cacheKey = role ? role : BASE_CONFIG_KEY; - - if (!refresh) { - const cached = await cache.get(cacheKey); - if (cached) { - return cached; +async function clearEndpointConfigCache() { + try { + const configStore = getLogStores(CacheKeys.CONFIG_STORE); + const scoped = scopedCacheKey(CacheKeys.ENDPOINT_CONFIG); + const keys = [scoped]; + if (scoped !== CacheKeys.ENDPOINT_CONFIG) { + keys.push(CacheKeys.ENDPOINT_CONFIG); } + await Promise.all(keys.map((k) => configStore.delete(k))); + } catch { + // CONFIG_STORE or ENDPOINT_CONFIG may not exist — not critical } - - let baseConfig = await cache.get(BASE_CONFIG_KEY); - if (!baseConfig) { - logger.info('[getAppConfig] App configuration not initialized. Initializing AppService...'); - baseConfig = await loadBaseConfig(); - - if (!baseConfig) { - throw new Error('Failed to initialize app configuration through AppService.'); - } - - if (baseConfig.availableTools) { - await setCachedTools(baseConfig.availableTools); - } - - await cache.set(BASE_CONFIG_KEY, baseConfig); - } - - // For now, return the base config - // In the future, this is where we'll apply role-based modifications - if (role) { - // TODO: Apply role-based config modifications - // const roleConfig = await applyRoleBasedConfig(baseConfig, role); - // await cache.set(cacheKey, roleConfig); - // return roleConfig; - } - - return baseConfig; } /** - * Clear the app configuration cache - * @returns {Promise} + * Invalidate all config-related caches after an admin config mutation. + * Clears the base config, per-principal override caches, tool caches, + * the endpoints config cache, and the MCP config-source server cache. + * @param {string} [tenantId] - Optional tenant ID to scope override cache clearing. 
*/ -async function clearAppConfigCache() { - const cache = getLogStores(CacheKeys.CONFIG_STORE); - const cacheKey = CacheKeys.APP_CONFIG; - return await cache.delete(cacheKey); +async function invalidateConfigCaches(tenantId) { + const results = await Promise.allSettled([ + clearAppConfigCache(), + clearOverrideCache(tenantId), + invalidateCachedTools({ invalidateGlobal: true }), + clearEndpointConfigCache(), + clearMcpConfigCache(), + ]); + const labels = [ + 'clearAppConfigCache', + 'clearOverrideCache', + 'invalidateCachedTools', + 'clearEndpointConfigCache', + 'clearMcpConfigCache', + ]; + for (let i = 0; i < results.length; i++) { + if (results[i].status === 'rejected') { + logger.error(`[invalidateConfigCaches] ${labels[i]} failed:`, results[i].reason); + } + } } module.exports = { getAppConfig, clearAppConfigCache, + invalidateConfigCaches, }; diff --git a/api/server/services/Config/getEndpointsConfig.js b/api/server/services/Config/getEndpointsConfig.js index bb22584851..cd0230ad4a 100644 --- a/api/server/services/Config/getEndpointsConfig.js +++ b/api/server/services/Config/getEndpointsConfig.js @@ -1,3 +1,4 @@ +const { scopedCacheKey } = require('@librechat/data-schemas'); const { loadCustomEndpointsConfig } = require('@librechat/api'); const { CacheKeys, @@ -17,16 +18,18 @@ const { getAppConfig } = require('./app'); */ async function getEndpointsConfig(req) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG); + const cacheKey = scopedCacheKey(CacheKeys.ENDPOINT_CONFIG); + const cachedEndpointsConfig = await cache.get(cacheKey); if (cachedEndpointsConfig) { if (cachedEndpointsConfig.gptPlugins) { - await cache.delete(CacheKeys.ENDPOINT_CONFIG); + await cache.delete(cacheKey); } else { return cachedEndpointsConfig; } } - const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role })); + const appConfig = + req.config ?? 
(await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId })); const defaultEndpointsConfig = await loadDefaultEndpointsConfig(appConfig); const customEndpointsConfig = loadCustomEndpointsConfig(appConfig?.endpoints?.custom); @@ -111,7 +114,7 @@ async function getEndpointsConfig(req) { const endpointsConfig = orderEndpointsConfig(mergedConfig); - await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig); + await cache.set(cacheKey, endpointsConfig); return endpointsConfig; } diff --git a/api/server/services/Config/loadConfigModels.js b/api/server/services/Config/loadConfigModels.js index 2bc83ecc3a..b94a719909 100644 --- a/api/server/services/Config/loadConfigModels.js +++ b/api/server/services/Config/loadConfigModels.js @@ -12,7 +12,7 @@ const { getAppConfig } = require('./app'); * @param {ServerRequest} req - The Express request object. */ async function loadConfigModels(req) { - const appConfig = await getAppConfig({ role: req.user?.role }); + const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }); if (!appConfig) { return {}; } diff --git a/api/server/services/Config/loadDefaultModels.js b/api/server/services/Config/loadDefaultModels.js index 31aa831a70..85f2c42a33 100644 --- a/api/server/services/Config/loadDefaultModels.js +++ b/api/server/services/Config/loadDefaultModels.js @@ -16,7 +16,8 @@ const { getAppConfig } = require('./app'); */ async function loadDefaultModels(req) { try { - const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role })); + const appConfig = + req.config ?? 
(await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId })); const vertexConfig = appConfig?.endpoints?.[EModelEndpoint.anthropic]?.vertexConfig; const [openAI, anthropic, azureOpenAI, assistants, azureAssistants, google, bedrock] = diff --git a/api/server/services/Config/mcp.js b/api/server/services/Config/mcp.js index cc4e98b59e..869c9e66da 100644 --- a/api/server/services/Config/mcp.js +++ b/api/server/services/Config/mcp.js @@ -1,5 +1,5 @@ -const { logger } = require('@librechat/data-schemas'); const { CacheKeys, Constants } = require('librechat-data-provider'); +const { logger, scopedCacheKey } = require('@librechat/data-schemas'); const { getCachedTools, setCachedTools } = require('./getCachedTools'); const { getLogStores } = require('~/cache'); @@ -36,7 +36,7 @@ async function updateMCPServerTools({ userId, serverName, tools }) { await setCachedTools(serverTools, { userId, serverName }); const cache = getLogStores(CacheKeys.TOOL_CACHE); - await cache.delete(CacheKeys.TOOLS); + await cache.delete(scopedCacheKey(CacheKeys.TOOLS)); logger.debug( `[MCP Cache] Updated ${tools.length} tools for server ${serverName} (user: ${userId})`, ); @@ -48,7 +48,10 @@ async function updateMCPServerTools({ userId, serverName, tools }) { } /** - * Merges app-level tools with global tools + * Merges app-level tools with global tools. + * Only the current ALS-scoped key (base key in system/startup context) is cleared. + * Tenant-scoped TOOLS:tenantId keys are NOT actively invalidated — they expire + * via TTL on the next tenant request. This matches clearEndpointConfigCache behavior. 
* @param {import('@librechat/api').LCAvailableTools} appTools * @returns {Promise} */ @@ -62,7 +65,7 @@ async function mergeAppTools(appTools) { const mergedTools = { ...cachedTools, ...appTools }; await setCachedTools(mergedTools); const cache = getLogStores(CacheKeys.TOOL_CACHE); - await cache.delete(CacheKeys.TOOLS); + await cache.delete(scopedCacheKey(CacheKeys.TOOLS)); logger.debug(`Merged ${count} app-level tools`); } catch (error) { logger.error('Failed to merge app-level tools:', error); diff --git a/api/server/services/Files/Audio/STTService.js b/api/server/services/Files/Audio/STTService.js index 4ba62a7eeb..c9a35c35ea 100644 --- a/api/server/services/Files/Audio/STTService.js +++ b/api/server/services/Files/Audio/STTService.js @@ -142,6 +142,7 @@ class STTService { req.config ?? (await getAppConfig({ role: req?.user?.role, + tenantId: req?.user?.tenantId, })); const sttSchema = appConfig?.speech?.stt; if (!sttSchema) { diff --git a/api/server/services/Files/Audio/TTSService.js b/api/server/services/Files/Audio/TTSService.js index 2c932968c6..1125dd74ed 100644 --- a/api/server/services/Files/Audio/TTSService.js +++ b/api/server/services/Files/Audio/TTSService.js @@ -297,6 +297,7 @@ class TTSService { req.config ?? (await getAppConfig({ role: req.user?.role, + tenantId: req.user?.tenantId, })); try { res.setHeader('Content-Type', 'audio/mpeg'); @@ -365,6 +366,7 @@ class TTSService { req.config ?? 
(await getAppConfig({ role: req.user?.role, + tenantId: req.user?.tenantId, })); const provider = this.getProvider(appConfig); const ttsSchema = appConfig?.speech?.tts?.[provider]; diff --git a/api/server/services/Files/Audio/getCustomConfigSpeech.js b/api/server/services/Files/Audio/getCustomConfigSpeech.js index d0d0b51ac2..b438771ec1 100644 --- a/api/server/services/Files/Audio/getCustomConfigSpeech.js +++ b/api/server/services/Files/Audio/getCustomConfigSpeech.js @@ -17,6 +17,7 @@ async function getCustomConfigSpeech(req, res) { try { const appConfig = await getAppConfig({ role: req.user?.role, + tenantId: req.user?.tenantId, }); if (!appConfig) { diff --git a/api/server/services/Files/Audio/getVoices.js b/api/server/services/Files/Audio/getVoices.js index f2f8e100c3..22bd7cea6e 100644 --- a/api/server/services/Files/Audio/getVoices.js +++ b/api/server/services/Files/Audio/getVoices.js @@ -18,6 +18,7 @@ async function getVoices(req, res) { req.config ?? (await getAppConfig({ role: req.user?.role, + tenantId: req.user?.tenantId, })); const ttsSchema = appConfig?.speech?.tts; diff --git a/api/server/services/Files/Audio/streamAudio.js b/api/server/services/Files/Audio/streamAudio.js index c28a96edff..7120399b5e 100644 --- a/api/server/services/Files/Audio/streamAudio.js +++ b/api/server/services/Files/Audio/streamAudio.js @@ -1,3 +1,4 @@ +const { scopedCacheKey } = require('@librechat/data-schemas'); const { Time, CacheKeys, @@ -67,6 +68,8 @@ function createChunkProcessor(user, messageId) { } const messageCache = getLogStores(CacheKeys.MESSAGES); + // Captured at creation time — must be called within an active request ALS scope + const cacheKey = scopedCacheKey(messageId); /** * @returns {Promise<{ text: string, isFinished: boolean }[] | string>} @@ -81,7 +84,7 @@ function createChunkProcessor(user, messageId) { } /** @type { string | { text: string; complete: boolean } } */ - let message = await messageCache.get(messageId); + let message = await 
messageCache.get(cacheKey); if (!message) { message = await getMessage({ user, messageId }); } @@ -92,7 +95,7 @@ function createChunkProcessor(user, messageId) { } else { const text = message.content?.length > 0 ? parseTextParts(message.content) : message.text; messageCache.set( - messageId, + cacheKey, { text, complete: true, diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index 5d97891c55..dbb44740a9 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -1,5 +1,5 @@ const { tool } = require('@langchain/core/tools'); -const { logger } = require('@librechat/data-schemas'); +const { logger, getTenantId } = require('@librechat/data-schemas'); const { Providers, StepTypes, @@ -14,6 +14,7 @@ const { normalizeJsonSchema, GenerationJobManager, resolveJsonSchemaRefs, + buildOAuthToolCallName, } = require('@librechat/api'); const { Time, CacheKeys, Constants, isAssistantsEndpoint } = require('librechat-data-provider'); const { @@ -53,6 +54,53 @@ function evictStale(map, ttl) { const unavailableMsg = "This tool's MCP server is temporarily unavailable. Please try again shortly."; +/** + * Resolves config-source MCP servers from admin Config overrides for the current + * request context. Returns the parsed configs keyed by server name. + * @param {import('express').Request} req - Express request with user context + * @returns {Promise>} + */ +async function resolveConfigServers(req) { + try { + const registry = getMCPServersRegistry(); + const user = req?.user; + const appConfig = await getAppConfig({ + role: user?.role, + tenantId: getTenantId(), + userId: user?.id, + }); + return await registry.ensureConfigServers(appConfig?.mcpConfig || {}); + } catch (error) { + logger.warn( + '[resolveConfigServers] Failed to resolve config servers, degrading to empty:', + error, + ); + return {}; + } +} + +/** + * Resolves config-source servers and merges all server configs (YAML + config + user DB) + * for the given user context. 
Shared helper for controllers needing the full merged config. + * @param {string} userId + * @param {{ id?: string, role?: string }} [user] + * @returns {Promise>} + */ +async function resolveAllMcpConfigs(userId, user) { + const registry = getMCPServersRegistry(); + const appConfig = await getAppConfig({ role: user?.role, tenantId: getTenantId(), userId }); + let configServers = {}; + try { + configServers = await registry.ensureConfigServers(appConfig?.mcpConfig || {}); + } catch (error) { + logger.warn( + '[resolveAllMcpConfigs] Config server resolution failed, continuing without:', + error, + ); + } + return await registry.getAllServerConfigs(userId, configServers); +} + /** * @param {string} toolName * @param {string} serverName @@ -248,6 +296,7 @@ async function reconnectServer({ index, signal, serverName, + configServers, userMCPAuthMap, streamId = null, }) { @@ -271,7 +320,7 @@ async function reconnectServer({ const stepId = 'step_oauth_login_' + serverName; const toolCall = { id: flowId, - name: serverName, + name: buildOAuthToolCallName(serverName), type: 'tool_call_chunk', }; @@ -316,6 +365,7 @@ async function reconnectServer({ user, signal, serverName, + configServers, oauthStart, flowManager, userMCPAuthMap, @@ -358,15 +408,14 @@ async function createMCPTools({ config, provider, serverName, + configServers, userMCPAuthMap, streamId = null, }) { - // Early domain validation before reconnecting server (avoid wasted work on disallowed domains) - // Use getAppConfig() to support per-user/role domain restrictions const serverConfig = - config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id)); + config ?? 
(await getMCPServersRegistry().getServerConfig(serverName, user?.id, configServers)); if (serverConfig?.url) { - const appConfig = await getAppConfig({ role: user?.role }); + const appConfig = await getAppConfig({ role: user?.role, tenantId: user?.tenantId }); const allowedDomains = appConfig?.mcpSettings?.allowedDomains; const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains); if (!isDomainAllowed) { @@ -381,6 +430,7 @@ async function createMCPTools({ index, signal, serverName, + configServers, userMCPAuthMap, streamId, }); @@ -400,6 +450,7 @@ async function createMCPTools({ user, provider, userMCPAuthMap, + configServers, streamId, availableTools: result.availableTools, toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`, @@ -439,16 +490,15 @@ async function createMCPTool({ userMCPAuthMap, availableTools, config, + configServers, streamId = null, }) { const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter); - // Runtime domain validation: check if the server's domain is still allowed - // Use getAppConfig() to support per-user/role domain restrictions const serverConfig = - config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id)); + config ?? 
(await getMCPServersRegistry().getServerConfig(serverName, user?.id, configServers)); if (serverConfig?.url) { - const appConfig = await getAppConfig({ role: user?.role }); + const appConfig = await getAppConfig({ role: user?.role, tenantId: user?.tenantId }); const allowedDomains = appConfig?.mcpSettings?.allowedDomains; const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains); if (!isDomainAllowed) { @@ -477,6 +527,7 @@ async function createMCPTool({ index, signal, serverName, + configServers, userMCPAuthMap, streamId, }); @@ -500,6 +551,7 @@ async function createMCPTool({ provider, toolName, serverName, + serverConfig, toolDefinition, streamId, }); @@ -509,13 +561,14 @@ function createToolInstance({ res, toolName, serverName, + serverConfig: capturedServerConfig, toolDefinition, - provider: _provider, + provider: capturedProvider, streamId = null, }) { /** @type {LCTool} */ const { description, parameters } = toolDefinition; - const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE; + const isGoogle = capturedProvider === Providers.VERTEXAI || capturedProvider === Providers.GOOGLE; let schema = parameters ? normalizeJsonSchema(resolveJsonSchemaRefs(parameters)) : null; @@ -544,7 +597,7 @@ function createToolInstance({ const flowManager = getFlowStateManager(flowsCache); derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined; const mcpManager = getMCPManager(userId); - const provider = (config?.metadata?.provider || _provider)?.toLowerCase(); + const provider = (config?.metadata?.provider || capturedProvider)?.toLowerCase(); const { args: _args, stepId, ...toolCall } = config.toolCall ?? 
{}; const flowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`; @@ -576,6 +629,7 @@ function createToolInstance({ const result = await mcpManager.callTool({ serverName, + serverConfig: capturedServerConfig, toolName, provider, toolArguments, @@ -643,30 +697,36 @@ function createToolInstance({ } /** - * Get MCP setup data including config, connections, and OAuth servers + * Get MCP setup data including config, connections, and OAuth servers. + * Resolves config-source servers from admin Config overrides when tenant context is available. * @param {string} userId - The user ID + * @param {{ role?: string, tenantId?: string }} [options] - Optional role/tenant context * @returns {Object} Object containing mcpConfig, appConnections, userConnections, and oauthServers */ -async function getMCPSetupData(userId) { - const mcpConfig = await getMCPServersRegistry().getAllServerConfigs(userId); - - if (!mcpConfig) { - throw new Error('MCP config not found'); - } +async function getMCPSetupData(userId, options = {}) { + const registry = getMCPServersRegistry(); + const { role, tenantId } = options; + const appConfig = await getAppConfig({ role, tenantId, userId }); + const configServers = await registry.ensureConfigServers(appConfig?.mcpConfig || {}); + const mcpConfig = await registry.getAllServerConfigs(userId, configServers); const mcpManager = getMCPManager(userId); /** @type {Map} */ let appConnections = new Map(); try { - // Use getLoaded() instead of getAll() to avoid forcing connection creation + // Use getLoaded() instead of getAll() to avoid forcing connection creation. // getAll() creates connections for all servers, which is problematic for servers - // that require user context (e.g., those with {{LIBRECHAT_USER_ID}} placeholders) + // that require user context (e.g., those with {{LIBRECHAT_USER_ID}} placeholders). 
appConnections = (await mcpManager.appConnections?.getLoaded()) || new Map(); } catch (error) { logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error); } const userConnections = mcpManager.getUserConnections(userId) || new Map(); - const oauthServers = await getMCPServersRegistry().getOAuthServers(userId); + const oauthServers = new Set( + Object.entries(mcpConfig) + .filter(([, config]) => config.requiresOAuth) + .map(([name]) => name), + ); return { mcpConfig, @@ -788,6 +848,8 @@ module.exports = { createMCPTool, createMCPTools, getMCPSetupData, + resolveConfigServers, + resolveAllMcpConfigs, checkOAuthFlowStatus, getServerConnectionStatus, createUnavailableToolStub, diff --git a/api/server/services/MCP.spec.js b/api/server/services/MCP.spec.js index 14a9ef90ed..c9925827f8 100644 --- a/api/server/services/MCP.spec.js +++ b/api/server/services/MCP.spec.js @@ -14,6 +14,7 @@ const mockRegistryInstance = { getOAuthServers: jest.fn(() => Promise.resolve(new Set())), getAllServerConfigs: jest.fn(() => Promise.resolve({})), getServerConfig: jest.fn(() => Promise.resolve(null)), + ensureConfigServers: jest.fn(() => Promise.resolve({})), }; // Create isMCPDomainAllowed mock that can be configured per-test @@ -113,38 +114,43 @@ describe('tests for the new helper functions used by the MCP connection status e }); it('should successfully return MCP setup data', async () => { - mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfig); + const mockConfigWithOAuth = { + server1: { type: 'stdio' }, + server2: { type: 'http', requiresOAuth: true }, + }; + mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfigWithOAuth); const mockAppConnections = new Map([['server1', { status: 'connected' }]]); const mockUserConnections = new Map([['server2', { status: 'disconnected' }]]); - const mockOAuthServers = new Set(['server2']); const mockMCPManager = { appConnections: { getLoaded: jest.fn(() => Promise.resolve(mockAppConnections)) }, 
getUserConnections: jest.fn(() => mockUserConnections), }; mockGetMCPManager.mockReturnValue(mockMCPManager); - mockRegistryInstance.getOAuthServers.mockResolvedValue(mockOAuthServers); const result = await getMCPSetupData(mockUserId); - expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith(mockUserId); + expect(mockRegistryInstance.ensureConfigServers).toHaveBeenCalled(); + expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith( + mockUserId, + expect.any(Object), + ); expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId); expect(mockMCPManager.appConnections.getLoaded).toHaveBeenCalled(); expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId); - expect(mockRegistryInstance.getOAuthServers).toHaveBeenCalledWith(mockUserId); - expect(result).toEqual({ - mcpConfig: mockConfig, - appConnections: mockAppConnections, - userConnections: mockUserConnections, - oauthServers: mockOAuthServers, - }); + expect(result.mcpConfig).toEqual(mockConfigWithOAuth); + expect(result.appConnections).toEqual(mockAppConnections); + expect(result.userConnections).toEqual(mockUserConnections); + expect(result.oauthServers).toEqual(new Set(['server2'])); }); - it('should throw error when MCP config not found', async () => { - mockRegistryInstance.getAllServerConfigs.mockResolvedValue(null); - await expect(getMCPSetupData(mockUserId)).rejects.toThrow('MCP config not found'); + it('should return empty data when no servers are configured', async () => { + mockRegistryInstance.getAllServerConfigs.mockResolvedValue({}); + const result = await getMCPSetupData(mockUserId); + expect(result.mcpConfig).toEqual({}); + expect(result.oauthServers).toEqual(new Set()); }); it('should handle null values from MCP manager gracefully', async () => { diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index ca75e7eb4f..c11843cb69 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ 
-19,6 +19,7 @@ const { buildWebSearchContext, buildImageToolContext, buildToolClassification, + buildOAuthToolCallName, } = require('@librechat/api'); const { Time, @@ -59,6 +60,7 @@ const { manifestToolMap, toolkits } = require('~/app/clients/tools/manifest'); const { createOnSearchResults } = require('~/server/services/Tools/search'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { reinitMCPServer } = require('~/server/services/Tools/mcp'); +const { resolveConfigServers } = require('~/server/services/MCP'); const { recordUsage } = require('~/server/services/Threads'); const { loadTools } = require('~/app/clients/tools/util'); const { redactMessage } = require('~/config/parsers'); @@ -513,6 +515,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to const flowsCache = getLogStores(CacheKeys.FLOWS); const flowManager = getFlowStateManager(flowsCache); + const configServers = await resolveConfigServers(req); const pendingOAuthServers = new Set(); const createOAuthEmitter = (serverName) => { @@ -521,7 +524,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to const stepId = 'step_oauth_login_' + serverName; const toolCall = { id: flowId, - name: serverName, + name: buildOAuthToolCallName(serverName), type: 'tool_call_chunk', }; @@ -578,6 +581,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to oauthStart, flowManager, serverName, + configServers, userMCPAuthMap, }); @@ -665,6 +669,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to const result = await reinitMCPServer({ user: req.user, serverName, + configServers, userMCPAuthMap, flowManager, returnOnOAuth: false, diff --git a/api/server/services/Tools/mcp.js b/api/server/services/Tools/mcp.js index 7589043e10..f1ebcf9796 100644 --- a/api/server/services/Tools/mcp.js +++ b/api/server/services/Tools/mcp.js @@ -25,11 +25,13 @@ async function 
reinitMCPServer({ signal, forceNew, serverName, + configServers, userMCPAuthMap, connectionTimeout, returnOnOAuth = true, oauthStart: _oauthStart, flowManager: _flowManager, + serverConfig: providedConfig, }) { /** @type {MCPConnection | null} */ let connection = null; @@ -42,13 +44,28 @@ async function reinitMCPServer({ try { const registry = getMCPServersRegistry(); - const serverConfig = await registry.getServerConfig(serverName, user?.id); + const serverConfig = + providedConfig ?? (await registry.getServerConfig(serverName, user?.id, configServers)); if (serverConfig?.inspectionFailed) { + if (serverConfig.source === 'config') { + logger.info( + `[MCP Reinitialize] Config-source server ${serverName} has inspectionFailed — retry handled by config cache`, + ); + return { + availableTools: null, + success: false, + message: `MCP server '${serverName}' is still unreachable`, + oauthRequired: false, + serverName, + oauthUrl: null, + tools: null, + }; + } logger.info( `[MCP Reinitialize] Server ${serverName} had failed inspection, attempting reinspection`, ); try { - const storageLocation = serverConfig.dbId ? 'DB' : 'CACHE'; + const storageLocation = serverConfig.source === 'user' ? 
'DB' : 'CACHE'; await registry.reinspectServer(serverName, storageLocation, user?.id); logger.info(`[MCP Reinitialize] Reinspection succeeded for server: ${serverName}`); } catch (reinspectError) { @@ -93,6 +110,7 @@ async function reinitMCPServer({ returnOnOAuth, customUserVars, connectionTimeout, + serverConfig, }); logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`); @@ -125,6 +143,7 @@ async function reinitMCPServer({ oauthStart, customUserVars, connectionTimeout, + configServers, }); if (discoveryResult.tools && discoveryResult.tools.length > 0) { diff --git a/api/server/services/__tests__/MCP.spec.js b/api/server/services/__tests__/MCP.spec.js new file mode 100644 index 0000000000..39e99d54ac --- /dev/null +++ b/api/server/services/__tests__/MCP.spec.js @@ -0,0 +1,131 @@ +const mockRegistry = { + ensureConfigServers: jest.fn(), + getAllServerConfigs: jest.fn(), +}; + +jest.mock('~/config', () => ({ + getMCPServersRegistry: jest.fn(() => mockRegistry), + getMCPManager: jest.fn(), + getFlowStateManager: jest.fn(), + getOAuthReconnectionManager: jest.fn(), +})); + +jest.mock('@librechat/data-schemas', () => ({ + getTenantId: jest.fn(() => 'tenant-1'), + logger: { debug: jest.fn(), info: jest.fn(), warn: jest.fn(), error: jest.fn() }, +})); + +jest.mock('~/server/services/Config', () => ({ + getAppConfig: jest.fn(), + setCachedTools: jest.fn(), + getCachedTools: jest.fn(), + getMCPServerTools: jest.fn(), + loadCustomConfig: jest.fn(), +})); + +jest.mock('~/cache', () => ({ getLogStores: jest.fn() })); +jest.mock('~/models', () => ({ + findToken: jest.fn(), + createToken: jest.fn(), + updateToken: jest.fn(), +})); +jest.mock('~/server/services/GraphTokenService', () => ({ + getGraphApiToken: jest.fn(), +})); +jest.mock('~/server/services/Tools/mcp', () => ({ + reinitMCPServer: jest.fn(), +})); + +const { getAppConfig } = require('~/server/services/Config'); +const { resolveConfigServers, resolveAllMcpConfigs } = 
require('../MCP'); + +describe('resolveConfigServers', () => { + beforeEach(() => jest.clearAllMocks()); + + it('resolves config servers for the current request context', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: { srv: { url: 'http://a' } } }); + mockRegistry.ensureConfigServers.mockResolvedValue({ srv: { name: 'srv' } }); + + const result = await resolveConfigServers({ user: { id: 'u1', role: 'admin' } }); + + expect(result).toEqual({ srv: { name: 'srv' } }); + expect(getAppConfig).toHaveBeenCalledWith( + expect.objectContaining({ role: 'admin', userId: 'u1' }), + ); + expect(mockRegistry.ensureConfigServers).toHaveBeenCalledWith({ srv: { url: 'http://a' } }); + }); + + it('returns {} when ensureConfigServers throws', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: { srv: {} } }); + mockRegistry.ensureConfigServers.mockRejectedValue(new Error('inspect failed')); + + const result = await resolveConfigServers({ user: { id: 'u1' } }); + + expect(result).toEqual({}); + }); + + it('returns {} when getAppConfig throws', async () => { + getAppConfig.mockRejectedValue(new Error('db timeout')); + + const result = await resolveConfigServers({ user: { id: 'u1' } }); + + expect(result).toEqual({}); + }); + + it('passes empty mcpConfig when appConfig has none', async () => { + getAppConfig.mockResolvedValue({}); + mockRegistry.ensureConfigServers.mockResolvedValue({}); + + await resolveConfigServers({ user: { id: 'u1' } }); + + expect(mockRegistry.ensureConfigServers).toHaveBeenCalledWith({}); + }); +}); + +describe('resolveAllMcpConfigs', () => { + beforeEach(() => jest.clearAllMocks()); + + it('merges config servers with base servers', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: { cfg_srv: {} } }); + mockRegistry.ensureConfigServers.mockResolvedValue({ cfg_srv: { name: 'cfg_srv' } }); + mockRegistry.getAllServerConfigs.mockResolvedValue({ + cfg_srv: { name: 'cfg_srv' }, + yaml_srv: { name: 'yaml_srv' }, + }); + + const result = 
await resolveAllMcpConfigs('u1', { id: 'u1', role: 'user' }); + + expect(result).toEqual({ + cfg_srv: { name: 'cfg_srv' }, + yaml_srv: { name: 'yaml_srv' }, + }); + expect(mockRegistry.getAllServerConfigs).toHaveBeenCalledWith('u1', { + cfg_srv: { name: 'cfg_srv' }, + }); + }); + + it('continues with empty configServers when ensureConfigServers fails', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: { srv: {} } }); + mockRegistry.ensureConfigServers.mockRejectedValue(new Error('inspect failed')); + mockRegistry.getAllServerConfigs.mockResolvedValue({ yaml_srv: { name: 'yaml_srv' } }); + + const result = await resolveAllMcpConfigs('u1', { id: 'u1' }); + + expect(result).toEqual({ yaml_srv: { name: 'yaml_srv' } }); + expect(mockRegistry.getAllServerConfigs).toHaveBeenCalledWith('u1', {}); + }); + + it('propagates getAllServerConfigs failures', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: {} }); + mockRegistry.ensureConfigServers.mockResolvedValue({}); + mockRegistry.getAllServerConfigs.mockRejectedValue(new Error('redis down')); + + await expect(resolveAllMcpConfigs('u1', { id: 'u1' })).rejects.toThrow('redis down'); + }); + + it('propagates getAppConfig failures', async () => { + getAppConfig.mockRejectedValue(new Error('mongo down')); + + await expect(resolveAllMcpConfigs('u1', { id: 'u1' })).rejects.toThrow('mongo down'); + }); +}); diff --git a/api/server/services/__tests__/ToolService.spec.js b/api/server/services/__tests__/ToolService.spec.js index a468a88eb3..6e06804280 100644 --- a/api/server/services/__tests__/ToolService.spec.js +++ b/api/server/services/__tests__/ToolService.spec.js @@ -64,6 +64,9 @@ jest.mock('~/models', () => ({ jest.mock('~/config', () => ({ getFlowStateManager: jest.fn(() => ({})), })); +jest.mock('~/server/services/MCP', () => ({ + resolveConfigServers: jest.fn().mockResolvedValue({}), +})); jest.mock('~/cache', () => ({ getLogStores: jest.fn(() => ({})), })); diff --git 
a/api/server/services/initializeMCPs.js b/api/server/services/initializeMCPs.js index c7f27acd0e..5728730131 100644 --- a/api/server/services/initializeMCPs.js +++ b/api/server/services/initializeMCPs.js @@ -7,7 +7,7 @@ const { createMCPServersRegistry, createMCPManager } = require('~/config'); * Initialize MCP servers */ async function initializeMCPs() { - const appConfig = await getAppConfig(); + const appConfig = await getAppConfig({ baseOnly: true }); const mcpServers = appConfig.mcpConfig; try { diff --git a/api/server/utils/import/importers.js b/api/server/utils/import/importers.js index 39734c181c..f8b3be4dab 100644 --- a/api/server/utils/import/importers.js +++ b/api/server/utils/import/importers.js @@ -1,5 +1,5 @@ const { v4: uuidv4 } = require('uuid'); -const { logger } = require('@librechat/data-schemas'); +const { logger, scopedCacheKey } = require('@librechat/data-schemas'); const { EModelEndpoint, Constants, openAISettings, CacheKeys } = require('librechat-data-provider'); const { createImportBatchBuilder } = require('./importBatchBuilder'); const { cloneMessagesWithTimestamps } = require('./fork'); @@ -203,7 +203,7 @@ async function importLibreChatConvo( /* Endpoint configuration */ let endpoint = jsonData.endpoint ?? options.endpoint ?? 
EModelEndpoint.openAI; const cache = getLogStores(CacheKeys.CONFIG_STORE); - const endpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG); + const endpointsConfig = await cache.get(scopedCacheKey(CacheKeys.ENDPOINT_CONFIG)); const endpointConfig = endpointsConfig?.[endpoint]; if (!endpointConfig && endpointsConfig) { endpoint = Object.keys(endpointsConfig)[0]; diff --git a/api/strategies/ldapStrategy.js b/api/strategies/ldapStrategy.js index dcadc26a45..9253f54196 100644 --- a/api/strategies/ldapStrategy.js +++ b/api/strategies/ldapStrategy.js @@ -2,7 +2,12 @@ const fs = require('fs'); const LdapStrategy = require('passport-ldapauth'); const { logger } = require('@librechat/data-schemas'); const { SystemRoles, ErrorTypes } = require('librechat-data-provider'); -const { isEnabled, getBalanceConfig, isEmailDomainAllowed } = require('@librechat/api'); +const { + isEnabled, + getBalanceConfig, + isEmailDomainAllowed, + resolveAppConfigForUser, +} = require('@librechat/api'); const { createUser, findUser, updateUser, countUsers } = require('~/models'); const { getAppConfig } = require('~/server/services/Config'); @@ -89,16 +94,6 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => { const ldapId = (LDAP_ID && userinfo[LDAP_ID]) || userinfo.uid || userinfo.sAMAccountName || userinfo.mail; - let user = await findUser({ ldapId }); - if (user && user.provider !== 'ldap') { - logger.info( - `[ldapStrategy] User ${user.email} already exists with provider ${user.provider}`, - ); - return done(null, false, { - message: ErrorTypes.AUTH_FAILED, - }); - } - const fullNameAttributes = LDAP_FULL_NAME && LDAP_FULL_NAME.split(','); const fullName = fullNameAttributes && fullNameAttributes.length > 0 @@ -122,7 +117,31 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => { ); } - const appConfig = await getAppConfig(); + // Domain check before findUser for two-phase fast-fail (consistent with SAML/OpenID/social). 
+ // This means cross-provider users from blocked domains get 'Email domain not allowed' + // instead of AUTH_FAILED — both deny access. + const baseConfig = await getAppConfig({ baseOnly: true }); + if (!isEmailDomainAllowed(mail, baseConfig?.registration?.allowedDomains)) { + logger.error( + `[LDAP Strategy] Authentication blocked - email domain not allowed [Email: ${mail}]`, + ); + return done(null, false, { message: 'Email domain not allowed' }); + } + + let user = await findUser({ ldapId }); + if (user && user.provider !== 'ldap') { + logger.info( + `[ldapStrategy] User ${user.email} already exists with provider ${user.provider}`, + ); + return done(null, false, { + message: ErrorTypes.AUTH_FAILED, + }); + } + + const appConfig = user?.tenantId + ? await resolveAppConfigForUser(getAppConfig, user) + : baseConfig; + if (!isEmailDomainAllowed(mail, appConfig?.registration?.allowedDomains)) { logger.error( `[LDAP Strategy] Authentication blocked - email domain not allowed [Email: ${mail}]`, diff --git a/api/strategies/ldapStrategy.spec.js b/api/strategies/ldapStrategy.spec.js index a00e9b14b7..876d70f845 100644 --- a/api/strategies/ldapStrategy.spec.js +++ b/api/strategies/ldapStrategy.spec.js @@ -9,10 +9,10 @@ jest.mock('@librechat/data-schemas', () => ({ })); jest.mock('@librechat/api', () => ({ - // isEnabled used for TLS flags isEnabled: jest.fn(() => false), isEmailDomainAllowed: jest.fn(() => true), getBalanceConfig: jest.fn(() => ({ enabled: false })), + resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})), })); jest.mock('~/models', () => ({ @@ -30,14 +30,15 @@ jest.mock('~/server/services/Config', () => ({ let verifyCallback; jest.mock('passport-ldapauth', () => { return jest.fn().mockImplementation((options, verify) => { - verifyCallback = verify; // capture the strategy verify function + verifyCallback = verify; return { name: 'ldap', options, verify }; }); }); const { ErrorTypes } = require('librechat-data-provider'); -const { 
isEmailDomainAllowed } = require('@librechat/api'); +const { isEmailDomainAllowed, resolveAppConfigForUser } = require('@librechat/api'); const { findUser, createUser, updateUser, countUsers } = require('~/models'); +const { getAppConfig } = require('~/server/services/Config'); // Helper to call the verify callback and wrap in a Promise for convenience const callVerify = (userinfo) => @@ -117,6 +118,7 @@ describe('ldapStrategy', () => { expect(user).toBe(false); expect(info).toEqual({ message: ErrorTypes.AUTH_FAILED }); expect(createUser).not.toHaveBeenCalled(); + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); }); it('updates an existing ldap user with current LDAP info', async () => { @@ -158,7 +160,6 @@ describe('ldapStrategy', () => { uid: 'uid999', givenName: 'John', cn: 'John Doe', - // no mail and no custom LDAP_EMAIL }; const { user } = await callVerify(userinfo); @@ -180,4 +181,66 @@ describe('ldapStrategy', () => { expect(user).toBe(false); expect(info).toEqual({ message: 'Email domain not allowed' }); }); + + it('passes getAppConfig and found user to resolveAppConfigForUser', async () => { + const existing = { + _id: 'u3', + provider: 'ldap', + email: 'tenant@example.com', + ldapId: 'uid-tenant', + username: 'tenantuser', + name: 'Tenant User', + tenantId: 'tenant-a', + role: 'USER', + }; + findUser.mockResolvedValue(existing); + + const userinfo = { + uid: 'uid-tenant', + mail: 'tenant@example.com', + givenName: 'Tenant', + cn: 'Tenant User', + }; + + await callVerify(userinfo); + + expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existing); + }); + + it('uses baseConfig for new user without calling resolveAppConfigForUser', async () => { + findUser.mockResolvedValue(null); + + const userinfo = { + uid: 'uid-new', + mail: 'newuser@example.com', + givenName: 'New', + cn: 'New User', + }; + + await callVerify(userinfo); + + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); + expect(getAppConfig).toHaveBeenCalledWith({ 
baseOnly: true }); + }); + + it('should block login when tenant config restricts the domain', async () => { + const existing = { + _id: 'u-blocked', + provider: 'ldap', + ldapId: 'uid-tenant', + tenantId: 'tenant-strict', + role: 'USER', + }; + findUser.mockResolvedValue(existing); + resolveAppConfigForUser.mockResolvedValue({ + registration: { allowedDomains: ['other.com'] }, + }); + isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false); + + const userinfo = { uid: 'uid-tenant', mail: 'user@example.com', givenName: 'Test', cn: 'Test' }; + const { user, info } = await callVerify(userinfo); + + expect(user).toBe(false); + expect(info).toEqual({ message: 'Email domain not allowed' }); + }); }); diff --git a/api/strategies/openidStrategy.js b/api/strategies/openidStrategy.js index 7c43358297..7314a84e15 100644 --- a/api/strategies/openidStrategy.js +++ b/api/strategies/openidStrategy.js @@ -15,6 +15,7 @@ const { findOpenIDUser, getBalanceConfig, isEmailDomainAllowed, + resolveAppConfigForUser, } = require('@librechat/api'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { findUser, createUser, updateUser } = require('~/models'); @@ -468,9 +469,10 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) { Object.assign(userinfo, providerUserinfo); } - const appConfig = await getAppConfig(); const email = getOpenIdEmail(userinfo); - if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + + const baseConfig = await getAppConfig({ baseOnly: true }); + if (!isEmailDomainAllowed(email, baseConfig?.registration?.allowedDomains)) { logger.error( `[OpenID Strategy] Authentication blocked - email domain not allowed [Identifier: ${email}]`, ); @@ -491,6 +493,15 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) { throw new Error(ErrorTypes.AUTH_FAILED); } + const appConfig = user?.tenantId ? 
await resolveAppConfigForUser(getAppConfig, user) : baseConfig; + + if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + logger.error( + `[OpenID Strategy] Authentication blocked - email domain not allowed [Identifier: ${email}]`, + ); + throw new Error('Email domain not allowed'); + } + const fullName = getFullName(userinfo); const requiredRole = process.env.OPENID_REQUIRED_ROLE; diff --git a/api/strategies/openidStrategy.spec.js b/api/strategies/openidStrategy.spec.js index 4436fab672..6d824176f7 100644 --- a/api/strategies/openidStrategy.spec.js +++ b/api/strategies/openidStrategy.spec.js @@ -1,1822 +1,1873 @@ -const undici = require('undici'); -const fetch = require('node-fetch'); -const jwtDecode = require('jsonwebtoken/decode'); -const { ErrorTypes } = require('librechat-data-provider'); -const { findUser, createUser, updateUser } = require('~/models'); -const { setupOpenId } = require('./openidStrategy'); - -// --- Mocks --- -jest.mock('node-fetch'); -jest.mock('jsonwebtoken/decode'); -jest.mock('undici', () => ({ - fetch: jest.fn(), - ProxyAgent: jest.fn(), -})); -jest.mock('~/server/services/Files/strategies', () => ({ - getStrategyFunctions: jest.fn(() => ({ - saveBuffer: jest.fn().mockResolvedValue('/fake/path/to/avatar.png'), - })), -})); -jest.mock('~/server/services/Config', () => ({ - getAppConfig: jest.fn().mockResolvedValue({}), -})); -jest.mock('@librechat/api', () => ({ - ...jest.requireActual('@librechat/api'), - isEnabled: jest.fn(() => false), - isEmailDomainAllowed: jest.fn(() => true), - findOpenIDUser: jest.requireActual('@librechat/api').findOpenIDUser, - getBalanceConfig: jest.fn(() => ({ - enabled: false, - })), -})); -jest.mock('~/models', () => ({ - findUser: jest.fn(), - createUser: jest.fn(), - updateUser: jest.fn(), -})); -jest.mock('@librechat/data-schemas', () => ({ - ...jest.requireActual('@librechat/api'), - logger: { - info: jest.fn(), - warn: jest.fn(), - debug: jest.fn(), - error: jest.fn(), - }, - 
hashToken: jest.fn().mockResolvedValue('hashed-token'), -})); -jest.mock('~/cache/getLogStores', () => - jest.fn(() => ({ - get: jest.fn(), - set: jest.fn(), - })), -); - -// Mock the openid-client module and all its dependencies -jest.mock('openid-client', () => { - return { - discovery: jest.fn().mockResolvedValue({ - clientId: 'fake_client_id', - clientSecret: 'fake_client_secret', - issuer: 'https://fake-issuer.com', - // Add any other properties needed by the implementation - }), - fetchUserInfo: jest.fn().mockImplementation(() => { - // Only return additional properties, but don't override any claims - return Promise.resolve({}); - }), - genericGrantRequest: jest.fn().mockResolvedValue({ - access_token: 'exchanged_graph_token', - expires_in: 3600, - }), - customFetch: Symbol('customFetch'), - }; -}); - -jest.mock('openid-client/passport', () => { - /** Store callbacks by strategy name - 'openid' and 'openidAdmin' */ - const verifyCallbacks = {}; - let lastVerifyCallback; - - const mockStrategy = jest.fn((options, verify) => { - lastVerifyCallback = verify; - return { name: 'openid', options, verify }; - }); - - return { - Strategy: mockStrategy, - /** Get the last registered callback (for backward compatibility) */ - __getVerifyCallback: () => lastVerifyCallback, - /** Store callback by name when passport.use is called */ - __setVerifyCallback: (name, callback) => { - verifyCallbacks[name] = callback; - }, - /** Get callback by strategy name */ - __getVerifyCallbackByName: (name) => verifyCallbacks[name], - }; -}); - -// Mock passport - capture strategy name and callback -jest.mock('passport', () => ({ - use: jest.fn((name, strategy) => { - const passportMock = require('openid-client/passport'); - if (strategy && strategy.verify) { - passportMock.__setVerifyCallback(name, strategy.verify); - } - }), -})); - -describe('setupOpenId', () => { - // Store a reference to the verify callback once it's set up - let verifyCallback; - - // Helper to wrap the verify 
callback in a promise - const validate = (tokenset) => - new Promise((resolve, reject) => { - verifyCallback(tokenset, (err, user, details) => { - if (err) { - reject(err); - } else { - resolve({ user, details }); - } - }); - }); - - const tokenset = { - id_token: 'fake_id_token', - access_token: 'fake_access_token', - claims: () => ({ - sub: '1234', - email: 'test@example.com', - email_verified: true, - given_name: 'First', - family_name: 'Last', - name: 'My Full', - preferred_username: 'testusername', - username: 'flast', - picture: 'https://example.com/avatar.png', - }), - }; - - beforeEach(async () => { - // Clear previous mock calls and reset implementations - jest.clearAllMocks(); - - // Reset environment variables needed by the strategy - process.env.OPENID_ISSUER = 'https://fake-issuer.com'; - process.env.OPENID_CLIENT_ID = 'fake_client_id'; - process.env.OPENID_CLIENT_SECRET = 'fake_client_secret'; - process.env.DOMAIN_SERVER = 'https://example.com'; - process.env.OPENID_CALLBACK_URL = '/callback'; - process.env.OPENID_SCOPE = 'openid profile email'; - process.env.OPENID_REQUIRED_ROLE = 'requiredRole'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roles'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - process.env.OPENID_ADMIN_ROLE = 'admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'permissions'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; - delete process.env.OPENID_USERNAME_CLAIM; - delete process.env.OPENID_NAME_CLAIM; - delete process.env.OPENID_EMAIL_CLAIM; - delete process.env.PROXY; - delete process.env.OPENID_USE_PKCE; - - // Default jwtDecode mock returns a token that includes the required role. 
- jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - permissions: ['admin'], - }); - - // By default, assume that no user is found, so createUser will be called - findUser.mockResolvedValue(null); - createUser.mockImplementation(async (userData) => { - // simulate created user with an _id property - return { _id: 'newUserId', ...userData }; - }); - updateUser.mockImplementation(async (id, userData) => { - return { _id: id, ...userData }; - }); - - // For image download, simulate a successful response - const fakeBuffer = Buffer.from('fake image'); - const fakeResponse = { - ok: true, - buffer: jest.fn().mockResolvedValue(fakeBuffer), - }; - fetch.mockResolvedValue(fakeResponse); - - // Call the setup function and capture the verify callback for the regular 'openid' strategy - // (not 'openidAdmin' which requires existing users) - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - }); - - it('should create a new user with correct username when preferred_username claim exists', async () => { - // Arrange – our userinfo already has preferred_username 'testusername' - const userinfo = tokenset.claims(); - - // Act - const { user } = await validate(tokenset); - - // Assert - expect(user.username).toBe(userinfo.preferred_username); - expect(createUser).toHaveBeenCalledWith( - expect.objectContaining({ - provider: 'openid', - openidId: userinfo.sub, - username: userinfo.preferred_username, - email: userinfo.email, - name: `${userinfo.given_name} ${userinfo.family_name}`, - }), - { enabled: false }, - true, - true, - ); - }); - - it('should use username as username when preferred_username claim is missing', async () => { - // Arrange – remove preferred_username from userinfo - const userinfo = { ...tokenset.claims() }; - delete userinfo.preferred_username; - // Expect the username to be the "username" - const expectUsername = userinfo.username; - - // Act - const { user } = await validate({ ...tokenset, 
claims: () => userinfo }); - - // Assert - expect(user.username).toBe(expectUsername); - expect(createUser).toHaveBeenCalledWith( - expect.objectContaining({ username: expectUsername }), - { enabled: false }, - true, - true, - ); - }); - - it('should use email as username when username and preferred_username are missing', async () => { - // Arrange – remove username and preferred_username - const userinfo = { ...tokenset.claims() }; - delete userinfo.username; - delete userinfo.preferred_username; - const expectUsername = userinfo.email; - - // Act - const { user } = await validate({ ...tokenset, claims: () => userinfo }); - - // Assert - expect(user.username).toBe(expectUsername); - expect(createUser).toHaveBeenCalledWith( - expect.objectContaining({ username: expectUsername }), - { enabled: false }, - true, - true, - ); - }); - - it('should override username with OPENID_USERNAME_CLAIM when set', async () => { - // Arrange – set OPENID_USERNAME_CLAIM so that the sub claim is used - process.env.OPENID_USERNAME_CLAIM = 'sub'; - const userinfo = tokenset.claims(); - - // Act - const { user } = await validate(tokenset); - - // Assert – username should equal the sub (converted as-is) - expect(user.username).toBe(userinfo.sub); - expect(createUser).toHaveBeenCalledWith( - expect.objectContaining({ username: userinfo.sub }), - { enabled: false }, - true, - true, - ); - }); - - it('should set the full name correctly when given_name and family_name exist', async () => { - // Arrange - const userinfo = tokenset.claims(); - const expectedFullName = `${userinfo.given_name} ${userinfo.family_name}`; - - // Act - const { user } = await validate(tokenset); - - // Assert - expect(user.name).toBe(expectedFullName); - }); - - it('should override full name with OPENID_NAME_CLAIM when set', async () => { - // Arrange – use the name claim as the full name - process.env.OPENID_NAME_CLAIM = 'name'; - const userinfo = { ...tokenset.claims(), name: 'Custom Name' }; - - // Act - const { 
user } = await validate({ ...tokenset, claims: () => userinfo }); - - // Assert - expect(user.name).toBe('Custom Name'); - }); - - it('should update an existing user on login', async () => { - // Arrange – simulate that a user already exists with openid provider - const existingUser = { - _id: 'existingUserId', - provider: 'openid', - email: tokenset.claims().email, - openidId: '', - username: '', - name: '', - }; - findUser.mockImplementation(async (query) => { - if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) { - return existingUser; - } - return null; - }); - - const userinfo = tokenset.claims(); - - // Act - await validate(tokenset); - - // Assert – updateUser should be called and the user object updated - expect(updateUser).toHaveBeenCalledWith( - existingUser._id, - expect.objectContaining({ - provider: 'openid', - openidId: userinfo.sub, - username: userinfo.preferred_username, - name: `${userinfo.given_name} ${userinfo.family_name}`, - }), - ); - }); - - it('should block login when email exists with different provider', async () => { - // Arrange – simulate that a user exists with same email but different provider - const existingUser = { - _id: 'existingUserId', - provider: 'google', - email: tokenset.claims().email, - googleId: 'some-google-id', - username: 'existinguser', - name: 'Existing User', - }; - findUser.mockImplementation(async (query) => { - if (query.email === tokenset.claims().email && !query.provider) { - return existingUser; - } - return null; - }); - - // Act - const result = await validate(tokenset); - - // Assert – verify that the strategy rejects login - expect(result.user).toBe(false); - expect(result.details.message).toBe(ErrorTypes.AUTH_FAILED); - expect(createUser).not.toHaveBeenCalled(); - expect(updateUser).not.toHaveBeenCalled(); - }); - - it('should block login when email fallback finds user with mismatched openidId', async () => { - const existingUser = { - _id: 'existingUserId', - 
provider: 'openid', - openidId: 'different-sub-claim', - email: tokenset.claims().email, - username: 'existinguser', - name: 'Existing User', - }; - findUser.mockImplementation(async (query) => { - if (query.$or) { - return null; - } - if (query.email === tokenset.claims().email) { - return existingUser; - } - return null; - }); - - const result = await validate(tokenset); - - expect(result.user).toBe(false); - expect(result.details.message).toBe(ErrorTypes.AUTH_FAILED); - expect(createUser).not.toHaveBeenCalled(); - expect(updateUser).not.toHaveBeenCalled(); - }); - - it('should enforce the required role and reject login if missing', async () => { - // Arrange – simulate a token without the required role. - jwtDecode.mockReturnValue({ - roles: ['SomeOtherRole'], - }); - - // Act - const { user, details } = await validate(tokenset); - - // Assert – verify that the strategy rejects login - expect(user).toBe(false); - expect(details.message).toBe('You must have "requiredRole" role to log in.'); - }); - - it('should not treat substring matches in string roles as satisfying required role', async () => { - // Arrange – override required role to "read" then re-setup - process.env.OPENID_REQUIRED_ROLE = 'read'; - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - // Token contains "bread" which *contains* "read" as a substring - jwtDecode.mockReturnValue({ - roles: 'bread', - }); - - // Act - const { user, details } = await validate(tokenset); - - // Assert – verify that substring match does not grant access - expect(user).toBe(false); - expect(details.message).toBe('You must have "read" role to log in.'); - }); - - it('should allow login when roles claim is a space-separated string containing the required role', async () => { - // Arrange – IdP returns roles as a space-delimited string - jwtDecode.mockReturnValue({ - roles: 'role1 role2 requiredRole', - }); - - // Act - const { user } = await 
validate(tokenset); - - // Assert – login succeeds when required role is present after splitting - expect(user).toBeTruthy(); - expect(createUser).toHaveBeenCalled(); - }); - - it('should allow login when roles claim is a comma-separated string containing the required role', async () => { - // Arrange – IdP returns roles as a comma-delimited string - jwtDecode.mockReturnValue({ - roles: 'role1,role2,requiredRole', - }); - - // Act - const { user } = await validate(tokenset); - - // Assert – login succeeds when required role is present after splitting - expect(user).toBeTruthy(); - expect(createUser).toHaveBeenCalled(); - }); - - it('should allow login when roles claim is a mixed comma-and-space-separated string containing the required role', async () => { - // Arrange – IdP returns roles with comma-and-space delimiters - jwtDecode.mockReturnValue({ - roles: 'role1, role2, requiredRole', - }); - - // Act - const { user } = await validate(tokenset); - - // Assert – login succeeds when required role is present after splitting - expect(user).toBeTruthy(); - expect(createUser).toHaveBeenCalled(); - }); - - it('should reject login when roles claim is a space-separated string that does not contain the required role', async () => { - // Arrange – IdP returns a delimited string but required role is absent - jwtDecode.mockReturnValue({ - roles: 'role1 role2 otherRole', - }); - - // Act - const { user, details } = await validate(tokenset); - - // Assert – login is rejected with the correct error message - expect(user).toBe(false); - expect(details.message).toBe('You must have "requiredRole" role to log in.'); - }); - - it('should allow login when single required role is present (backward compatibility)', async () => { - // Arrange – ensure single role configuration (as set in beforeEach) - // OPENID_REQUIRED_ROLE = 'requiredRole' - // Default jwtDecode mock in beforeEach already returns this role - jwtDecode.mockReturnValue({ - roles: ['requiredRole', 'anotherRole'], - }); - 
- // Act - const { user } = await validate(tokenset); - - // Assert – verify that login succeeds with single role configuration - expect(user).toBeTruthy(); - expect(user.email).toBe(tokenset.claims().email); - expect(user.username).toBe(tokenset.claims().preferred_username); - expect(createUser).toHaveBeenCalled(); - }); - - describe('group overage and groups handling', () => { - it.each([ - ['groups array contains required group', ['group-required', 'other-group'], true, undefined], - [ - 'groups array missing required group', - ['other-group'], - false, - 'You must have "group-required" role to log in.', - ], - ['groups string equals required group', 'group-required', true, undefined], - [ - 'groups string is other group', - 'other-group', - false, - 'You must have "group-required" role to log in.', - ], - ])( - 'uses groups claim directly when %s (no overage)', - async (_label, groupsClaim, expectedAllowed, expectedMessage) => { - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ - groups: groupsClaim, - permissions: ['admin'], - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user, details } = await validate(tokenset); - - expect(undici.fetch).not.toHaveBeenCalled(); - expect(Boolean(user)).toBe(expectedAllowed); - expect(details?.message).toBe(expectedMessage); - }, - ); - - it.each([ - ['token kind is not id', { kind: 'access', path: 'groups', decoded: { hasgroups: true } }], - ['parameter path is not groups', { kind: 'id', path: 'roles', decoded: { hasgroups: true } }], - ['decoded token is falsy', { kind: 'id', path: 'groups', decoded: null }], - [ - 'no overage indicators in decoded token', - { - kind: 'id', - path: 'groups', - decoded: { - permissions: ['admin'], - }, - }, - ], - [ - 'only _claim_names present (no 
_claim_sources)', - { - kind: 'id', - path: 'groups', - decoded: { - _claim_names: { groups: 'src1' }, - permissions: ['admin'], - }, - }, - ], - [ - 'only _claim_sources present (no _claim_names)', - { - kind: 'id', - path: 'groups', - decoded: { - _claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } }, - permissions: ['admin'], - }, - }, - ], - ])('does not attempt overage resolution when %s', async (_label, cfg) => { - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = cfg.path; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = cfg.kind; - - jwtDecode.mockReturnValue(cfg.decoded); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user, details } = await validate(tokenset); - - expect(undici.fetch).not.toHaveBeenCalled(); - expect(user).toBe(false); - expect(details.message).toBe('You must have "group-required" role to log in.'); - const { logger } = require('@librechat/data-schemas'); - const expectedTokenKind = cfg.kind === 'access' ? 
'access token' : 'id token'; - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining(`Key '${cfg.path}' not found in ${expectedTokenKind}!`), - ); - }); - }); - - describe('resolving groups via Microsoft Graph', () => { - it('denies login and does not call Graph when access token is missing', async () => { - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - - const { logger } = require('@librechat/data-schemas'); - - jwtDecode.mockReturnValue({ - hasgroups: true, - permissions: ['admin'], - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const tokensetWithoutAccess = { - ...tokenset, - access_token: undefined, - }; - - const { user, details } = await validate(tokensetWithoutAccess); - - expect(user).toBe(false); - expect(details.message).toBe('You must have "group-required" role to log in.'); - - expect(undici.fetch).not.toHaveBeenCalled(); - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining('Access token missing; cannot resolve group overage'), - ); - }); - - it.each([ - [ - 'Graph returns HTTP error', - async () => ({ - ok: false, - status: 403, - statusText: 'Forbidden', - json: async () => ({}), - }), - [ - '[openidStrategy] Failed to resolve groups via Microsoft Graph getMemberObjects: HTTP 403 Forbidden', - ], - ], - [ - 'Graph network error', - async () => { - throw new Error('network error'); - }, - [ - '[openidStrategy] Error resolving groups via Microsoft Graph getMemberObjects:', - expect.any(Error), - ], - ], - [ - 'Graph returns unexpected shape (no value)', - async () => ({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({}), - }), - [ - '[openidStrategy] Unexpected response format when resolving groups via Microsoft Graph getMemberObjects', - ], - ], - [ - 'Graph returns invalid value type', - async () => 
({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: 'not-an-array' }), - }), - [ - '[openidStrategy] Unexpected response format when resolving groups via Microsoft Graph getMemberObjects', - ], - ], - ])( - 'denies login when overage resolution fails because %s', - async (_label, setupFetch, expectedErrorArgs) => { - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - - const { logger } = require('@librechat/data-schemas'); - - jwtDecode.mockReturnValue({ - hasgroups: true, - permissions: ['admin'], - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockImplementation(setupFetch); - - const { user, details } = await validate(tokenset); - - expect(undici.fetch).toHaveBeenCalled(); - expect(user).toBe(false); - expect(details.message).toBe('You must have "group-required" role to log in.'); - - expect(logger.error).toHaveBeenCalledWith(...expectedErrorArgs); - }, - ); - - it.each([ - [ - 'hasgroups overage and Graph contains required group', - { - hasgroups: true, - }, - ['group-required', 'some-other-group'], - true, - ], - [ - '_claim_* overage and Graph contains required group', - { - _claim_names: { groups: 'src1' }, - _claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } }, - }, - ['group-required', 'some-other-group'], - true, - ], - [ - 'hasgroups overage and Graph does NOT contain required group', - { - hasgroups: true, - }, - ['some-other-group'], - false, - ], - [ - '_claim_* overage and Graph does NOT contain required group', - { - _claim_names: { groups: 'src1' }, - _claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } }, - }, - ['some-other-group'], - false, - ], - ])( - 'resolves groups via Microsoft Graph when %s', - async (_label, decodedTokenValue, graphGroups, expectedAllowed) => { 
- process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - - const { logger } = require('@librechat/data-schemas'); - - jwtDecode.mockReturnValue(decodedTokenValue); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ - value: graphGroups, - }), - }); - - const { user } = await validate(tokenset); - - expect(undici.fetch).toHaveBeenCalledWith( - 'https://graph.microsoft.com/v1.0/me/getMemberObjects', - expect.objectContaining({ - method: 'POST', - headers: expect.objectContaining({ - Authorization: 'Bearer exchanged_graph_token', - }), - }), - ); - expect(Boolean(user)).toBe(expectedAllowed); - - expect(logger.debug).toHaveBeenCalledWith( - expect.stringContaining( - `Successfully resolved ${graphGroups.length} groups via Microsoft Graph getMemberObjects`, - ), - ); - }, - ); - }); - - describe('OBO token exchange for overage', () => { - it('exchanges access token via OBO before calling Graph API', async () => { - const openidClient = require('openid-client'); - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: ['group-required'] }), - }); - - await validate(tokenset); - - expect(openidClient.genericGrantRequest).toHaveBeenCalledWith( - expect.anything(), - 'urn:ietf:params:oauth:grant-type:jwt-bearer', - expect.objectContaining({ - scope: 'https://graph.microsoft.com/User.Read', - assertion: 
tokenset.access_token, - requested_token_use: 'on_behalf_of', - }), - ); - - expect(undici.fetch).toHaveBeenCalledWith( - 'https://graph.microsoft.com/v1.0/me/getMemberObjects', - expect.objectContaining({ - headers: expect.objectContaining({ - Authorization: 'Bearer exchanged_graph_token', - }), - }), - ); - }); - - it('caches the exchanged token and reuses it on subsequent calls', async () => { - const openidClient = require('openid-client'); - const getLogStores = require('~/cache/getLogStores'); - const mockSet = jest.fn(); - const mockGet = jest - .fn() - .mockResolvedValueOnce(undefined) - .mockResolvedValueOnce({ access_token: 'exchanged_graph_token' }); - getLogStores.mockReturnValue({ get: mockGet, set: mockSet }); - - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: ['group-required'] }), - }); - - // First call: cache miss → OBO exchange → cache set - await validate(tokenset); - expect(mockSet).toHaveBeenCalledWith( - '1234:overage', - { access_token: 'exchanged_graph_token' }, - 3600000, - ); - expect(openidClient.genericGrantRequest).toHaveBeenCalledTimes(1); - - // Second call: cache hit → no new OBO exchange - openidClient.genericGrantRequest.mockClear(); - await validate(tokenset); - expect(openidClient.genericGrantRequest).not.toHaveBeenCalled(); - }); - }); - - describe('admin role group overage', () => { - it('resolves admin groups via Graph when overage is detected for admin role', async () => { - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - 
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: ['group-required', 'admin-group-id'] }), - }); - - const { user } = await validate(tokenset); - - expect(user.role).toBe('ADMIN'); - }); - - it('does not grant admin when overage groups do not contain admin role', async () => { - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: ['group-required', 'other-group'] }), - }); - - const { user } = await validate(tokenset); - - expect(user).toBeTruthy(); - expect(user.role).toBeUndefined(); - }); - - it('reuses already-resolved overage groups for admin role check (no duplicate Graph call)', async () => { - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - 
- await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: ['group-required', 'admin-group-id'] }), - }); - - await validate(tokenset); - - // Graph API should be called only once (for required role), admin role reuses the result - expect(undici.fetch).toHaveBeenCalledTimes(1); - }); - - it('demotes existing admin when overage groups no longer contain admin role', async () => { - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; - - const existingAdminUser = { - _id: 'existingAdminId', - provider: 'openid', - email: tokenset.claims().email, - openidId: tokenset.claims().sub, - username: 'adminuser', - name: 'Admin User', - role: 'ADMIN', - }; - - findUser.mockImplementation(async (query) => { - if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) { - return existingAdminUser; - } - return null; - }); - - jwtDecode.mockReturnValue({ hasgroups: true }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: ['group-required'] }), - }); - - const { user } = await validate(tokenset); - - expect(user.role).toBe('USER'); - }); - - it('does not attempt overage for admin role when token kind is not id', async () => { - process.env.OPENID_REQUIRED_ROLE = 'requiredRole'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roles'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - process.env.OPENID_ADMIN_ROLE = 'admin'; - 
process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'access'; - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - hasgroups: true, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - // No Graph call since admin uses access token (not id) - expect(undici.fetch).not.toHaveBeenCalled(); - expect(user.role).toBeUndefined(); - }); - - it('resolves admin via Graph independently when OPENID_REQUIRED_ROLE is not configured', async () => { - delete process.env.OPENID_REQUIRED_ROLE; - process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: ['admin-group-id'] }), - }); - - const { user } = await validate(tokenset); - expect(user.role).toBe('ADMIN'); - expect(undici.fetch).toHaveBeenCalledTimes(1); - }); - - it('denies admin when OPENID_REQUIRED_ROLE is absent and Graph does not contain admin group', async () => { - delete process.env.OPENID_REQUIRED_ROLE; - process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: ['other-group'] }), - }); - - const { user } = await validate(tokenset); - expect(user).toBeTruthy(); - expect(user.role).toBeUndefined(); 
- }); - - it('denies login and logs error when OBO exchange throws', async () => { - const openidClient = require('openid-client'); - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - openidClient.genericGrantRequest.mockRejectedValueOnce(new Error('OBO exchange rejected')); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user, details } = await validate(tokenset); - expect(user).toBe(false); - expect(details.message).toBe('You must have "group-required" role to log in.'); - expect(undici.fetch).not.toHaveBeenCalled(); - }); - - it('denies login when OBO exchange returns no access_token', async () => { - const openidClient = require('openid-client'); - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - openidClient.genericGrantRequest.mockResolvedValueOnce({ expires_in: 3600 }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user, details } = await validate(tokenset); - expect(user).toBe(false); - expect(details.message).toBe('You must have "group-required" role to log in.'); - expect(undici.fetch).not.toHaveBeenCalled(); - }); - }); - - it('should attempt to download and save the avatar if picture is provided', async () => { - // Act - const { user } = await validate(tokenset); - - // Assert – verify that download was attempted and the avatar field was set via updateUser - expect(fetch).toHaveBeenCalled(); - // Our mock getStrategyFunctions.saveBuffer returns '/fake/path/to/avatar.png' - expect(user.avatar).toBe('/fake/path/to/avatar.png'); - }); - - it('should not 
attempt to download avatar if picture is not provided', async () => { - // Arrange – remove picture - const userinfo = { ...tokenset.claims() }; - delete userinfo.picture; - - // Act - await validate({ ...tokenset, claims: () => userinfo }); - - // Assert – fetch should not be called and avatar should remain undefined or empty - expect(fetch).not.toHaveBeenCalled(); - // Depending on your implementation, user.avatar may be undefined or an empty string. - }); - - it('should support comma-separated multiple roles', async () => { - // Arrange - process.env.OPENID_REQUIRED_ROLE = 'someRole,anotherRole,admin'; - await setupOpenId(); // Re-initialize the strategy - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - jwtDecode.mockReturnValue({ - roles: ['anotherRole', 'aThirdRole'], - }); - - // Act - const { user } = await validate(tokenset); - - // Assert - expect(user).toBeTruthy(); - expect(user.email).toBe(tokenset.claims().email); - }); - - it('should reject login when user has none of the required multiple roles', async () => { - // Arrange - process.env.OPENID_REQUIRED_ROLE = 'someRole,anotherRole,admin'; - await setupOpenId(); // Re-initialize the strategy - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - jwtDecode.mockReturnValue({ - roles: ['aThirdRole', 'aFourthRole'], - }); - - // Act - const { user, details } = await validate(tokenset); - - // Assert - expect(user).toBe(false); - expect(details.message).toBe( - 'You must have one of: "someRole", "anotherRole", "admin" role to log in.', - ); - }); - - it('should handle spaces in comma-separated roles', async () => { - // Arrange - process.env.OPENID_REQUIRED_ROLE = ' someRole , anotherRole , admin '; - await setupOpenId(); // Re-initialize the strategy - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - jwtDecode.mockReturnValue({ - roles: ['someRole'], - }); - - // Act - const { user } = 
await validate(tokenset); - - // Assert - expect(user).toBeTruthy(); - }); - - it('should default to usePKCE false when OPENID_USE_PKCE is not defined', async () => { - const OpenIDStrategy = require('openid-client/passport').Strategy; - - delete process.env.OPENID_USE_PKCE; - await setupOpenId(); - - const callOptions = OpenIDStrategy.mock.calls[OpenIDStrategy.mock.calls.length - 1][0]; - expect(callOptions.usePKCE).toBe(false); - expect(callOptions.params?.code_challenge_method).toBeUndefined(); - }); - - it('should attach federatedTokens to user object for token propagation', async () => { - // Arrange - setup tokenset with access token, id token, refresh token, and expiration - const tokensetWithTokens = { - ...tokenset, - access_token: 'mock_access_token_abc123', - id_token: 'mock_id_token_def456', - refresh_token: 'mock_refresh_token_xyz789', - expires_at: 1234567890, - }; - - // Act - validate with the tokenset containing tokens - const { user } = await validate(tokensetWithTokens); - - // Assert - verify federatedTokens object is attached with correct values - expect(user.federatedTokens).toBeDefined(); - expect(user.federatedTokens).toEqual({ - access_token: 'mock_access_token_abc123', - id_token: 'mock_id_token_def456', - refresh_token: 'mock_refresh_token_xyz789', - expires_at: 1234567890, - }); - }); - - it('should include id_token in federatedTokens distinct from access_token', async () => { - // Arrange - use different values for access_token and id_token - const tokensetWithTokens = { - ...tokenset, - access_token: 'the_access_token', - id_token: 'the_id_token', - refresh_token: 'the_refresh_token', - expires_at: 9999999999, - }; - - // Act - const { user } = await validate(tokensetWithTokens); - - // Assert - id_token and access_token must be different values - expect(user.federatedTokens.access_token).toBe('the_access_token'); - expect(user.federatedTokens.id_token).toBe('the_id_token'); - 
expect(user.federatedTokens.id_token).not.toBe(user.federatedTokens.access_token); - }); - - it('should include tokenset along with federatedTokens', async () => { - // Arrange - const tokensetWithTokens = { - ...tokenset, - access_token: 'test_access_token', - id_token: 'test_id_token', - refresh_token: 'test_refresh_token', - expires_at: 9999999999, - }; - - // Act - const { user } = await validate(tokensetWithTokens); - - // Assert - both tokenset and federatedTokens should be present - expect(user.tokenset).toBeDefined(); - expect(user.federatedTokens).toBeDefined(); - expect(user.tokenset.access_token).toBe('test_access_token'); - expect(user.tokenset.id_token).toBe('test_id_token'); - expect(user.federatedTokens.access_token).toBe('test_access_token'); - expect(user.federatedTokens.id_token).toBe('test_id_token'); - }); - - it('should set role to "ADMIN" if OPENID_ADMIN_ROLE is set and user has that role', async () => { - // Act - const { user } = await validate(tokenset); - - // Assert – verify that the user role is set to "ADMIN" - expect(user.role).toBe('ADMIN'); - }); - - it('should not set user role if OPENID_ADMIN_ROLE is set but the user does not have that role', async () => { - // Arrange – simulate a token without the admin permission - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - permissions: ['not-admin'], - }); - - // Act - const { user } = await validate(tokenset); - - // Assert – verify that the user role is not defined - expect(user.role).toBeUndefined(); - }); - - it('should demote existing admin user when admin role is removed from token', async () => { - // Arrange – simulate an existing user who is currently an admin - const existingAdminUser = { - _id: 'existingAdminId', - provider: 'openid', - email: tokenset.claims().email, - openidId: tokenset.claims().sub, - username: 'adminuser', - name: 'Admin User', - role: 'ADMIN', - }; - - findUser.mockImplementation(async (query) => { - if (query.openidId === tokenset.claims().sub || 
query.email === tokenset.claims().email) { - return existingAdminUser; - } - return null; - }); - - // Token without admin permission - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - permissions: ['not-admin'], - }); - - const { logger } = require('@librechat/data-schemas'); - - // Act - const { user } = await validate(tokenset); - - // Assert – verify that the user was demoted - expect(user.role).toBe('USER'); - expect(updateUser).toHaveBeenCalledWith( - existingAdminUser._id, - expect.objectContaining({ - role: 'USER', - }), - ); - expect(logger.info).toHaveBeenCalledWith( - expect.stringContaining('demoted from admin - role no longer present in token'), - ); - }); - - it('should NOT demote admin user when admin role env vars are not configured', async () => { - // Arrange – remove admin role env vars - delete process.env.OPENID_ADMIN_ROLE; - delete process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH; - delete process.env.OPENID_ADMIN_ROLE_TOKEN_KIND; - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - // Simulate an existing admin user - const existingAdminUser = { - _id: 'existingAdminId', - provider: 'openid', - email: tokenset.claims().email, - openidId: tokenset.claims().sub, - username: 'adminuser', - name: 'Admin User', - role: 'ADMIN', - }; - - findUser.mockImplementation(async (query) => { - if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) { - return existingAdminUser; - } - return null; - }); - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - }); - - // Act - const { user } = await validate(tokenset); - - // Assert – verify that the admin user was NOT demoted - expect(user.role).toBe('ADMIN'); - expect(updateUser).toHaveBeenCalledWith( - existingAdminUser._id, - expect.objectContaining({ - role: 'ADMIN', - }), - ); - }); - - describe('lodash get - nested path extraction', () => { - it('should extract roles from deeply nested token 
path', async () => { - process.env.OPENID_REQUIRED_ROLE = 'app-user'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'resource_access.my-client.roles'; - - jwtDecode.mockReturnValue({ - resource_access: { - 'my-client': { - roles: ['app-user', 'viewer'], - }, - }, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user).toBeTruthy(); - expect(user.email).toBe(tokenset.claims().email); - }); - - it('should extract roles from three-level nested path', async () => { - process.env.OPENID_REQUIRED_ROLE = 'editor'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'data.access.permissions.roles'; - - jwtDecode.mockReturnValue({ - data: { - access: { - permissions: { - roles: ['editor', 'reader'], - }, - }, - }, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user).toBeTruthy(); - }); - - it('should log error and reject login when required role path does not exist in token', async () => { - const { logger } = require('@librechat/data-schemas'); - process.env.OPENID_REQUIRED_ROLE = 'app-user'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'resource_access.nonexistent.roles'; - - jwtDecode.mockReturnValue({ - resource_access: { - 'my-client': { - roles: ['app-user'], - }, - }, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user, details } = await validate(tokenset); - - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'resource_access.nonexistent.roles' not found in id token!"), - ); - expect(user).toBe(false); - expect(details.message).toContain('role to log in'); - }); - - it('should handle missing intermediate nested path gracefully', async () => { - const { logger } = 
require('@librechat/data-schemas'); - process.env.OPENID_REQUIRED_ROLE = 'user'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'org.team.roles'; - - jwtDecode.mockReturnValue({ - org: { - other: 'value', - }, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'org.team.roles' not found in id token!"), - ); - expect(user).toBe(false); - }); - - it('should extract admin role from nested path in access token', async () => { - process.env.OPENID_ADMIN_ROLE = 'admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'realm_access.roles'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'access'; - - jwtDecode.mockImplementation((token) => { - if (token === 'fake_access_token') { - return { - realm_access: { - roles: ['admin', 'user'], - }, - }; - } - return { - roles: ['requiredRole'], - }; - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user.role).toBe('ADMIN'); - }); - - it('should extract admin role from nested path in userinfo', async () => { - process.env.OPENID_ADMIN_ROLE = 'admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'organization.permissions'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'userinfo'; - - const userinfoWithNestedGroups = { - ...tokenset.claims(), - organization: { - permissions: ['admin', 'write'], - }, - }; - - require('openid-client').fetchUserInfo.mockResolvedValue({ - organization: { - permissions: ['admin', 'write'], - }, - }); - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate({ - ...tokenset, - claims: () => userinfoWithNestedGroups, - }); - - 
expect(user.role).toBe('ADMIN'); - }); - - it('should handle boolean admin role value', async () => { - process.env.OPENID_ADMIN_ROLE = 'admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'is_admin'; - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - is_admin: true, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user.role).toBe('ADMIN'); - }); - - it('should handle string admin role value matching exactly', async () => { - process.env.OPENID_ADMIN_ROLE = 'super-admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'role'; - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - role: 'super-admin', - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user.role).toBe('ADMIN'); - }); - - it('should not set admin role when string value does not match', async () => { - process.env.OPENID_ADMIN_ROLE = 'super-admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'role'; - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - role: 'regular-user', - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user.role).toBeUndefined(); - }); - - it('should handle array admin role value', async () => { - process.env.OPENID_ADMIN_ROLE = 'site-admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'app_roles'; - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - app_roles: ['user', 'site-admin', 'moderator'], - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user.role).toBe('ADMIN'); - }); - - it('should not set admin when role is not in 
array', async () => { - process.env.OPENID_ADMIN_ROLE = 'site-admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'app_roles'; - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - app_roles: ['user', 'moderator'], - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user.role).toBeUndefined(); - }); - - it('should grant admin when admin role claim is a space-separated string containing the admin role', async () => { - // Arrange – IdP returns admin roles as a space-delimited string - process.env.OPENID_ADMIN_ROLE = 'site-admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'app_roles'; - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - app_roles: 'user site-admin moderator', - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - // Act - const { user } = await validate(tokenset); - - // Assert – admin role is granted after splitting the delimited string - expect(user.role).toBe('ADMIN'); - }); - - it('should not grant admin when admin role claim is a space-separated string that does not contain the admin role', async () => { - // Arrange – delimited string present but admin role is absent - process.env.OPENID_ADMIN_ROLE = 'site-admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'app_roles'; - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - app_roles: 'user moderator', - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - // Act - const { user } = await validate(tokenset); - - // Assert – admin role is not granted - expect(user.role).toBeUndefined(); - }); - - it('should handle nested path with special characters in keys', async () => { - process.env.OPENID_REQUIRED_ROLE = 'app-user'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 
'resource_access.my-app-123.roles'; - - jwtDecode.mockReturnValue({ - resource_access: { - 'my-app-123': { - roles: ['app-user'], - }, - }, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user).toBeTruthy(); - }); - - it('should handle empty object at nested path', async () => { - const { logger } = require('@librechat/data-schemas'); - process.env.OPENID_REQUIRED_ROLE = 'user'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'access.roles'; - - jwtDecode.mockReturnValue({ - access: {}, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'access.roles' not found in id token!"), - ); - expect(user).toBe(false); - }); - - it('should handle null value at intermediate path', async () => { - const { logger } = require('@librechat/data-schemas'); - process.env.OPENID_REQUIRED_ROLE = 'user'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'data.roles'; - - jwtDecode.mockReturnValue({ - data: null, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'data.roles' not found in id token!"), - ); - expect(user).toBe(false); - }); - - it('should reject login with invalid admin role token kind', async () => { - process.env.OPENID_ADMIN_ROLE = 'admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'roles'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'invalid'; - - const { logger } = require('@librechat/data-schemas'); - - jwtDecode.mockReturnValue({ - roles: ['requiredRole', 'admin'], - }); - - await setupOpenId(); - verifyCallback = 
require('openid-client/passport').__getVerifyCallbackByName('openid'); - - await expect(validate(tokenset)).rejects.toThrow('Invalid admin role token kind'); - - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining( - "Invalid admin role token kind: invalid. Must be one of 'access', 'id', or 'userinfo'", - ), - ); - }); - - it('should reject login when roles path returns invalid type (object)', async () => { - const { logger } = require('@librechat/data-schemas'); - process.env.OPENID_REQUIRED_ROLE = 'app-user'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roles'; - - jwtDecode.mockReturnValue({ - roles: { admin: true, user: false }, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user, details } = await validate(tokenset); - - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'roles' not found in id token!"), - ); - expect(user).toBe(false); - expect(details.message).toContain('role to log in'); - }); - - it('should reject login when roles path returns invalid type (number)', async () => { - const { logger } = require('@librechat/data-schemas'); - process.env.OPENID_REQUIRED_ROLE = 'user'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roleCount'; - - jwtDecode.mockReturnValue({ - roleCount: 5, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'roleCount' not found in id token!"), - ); - expect(user).toBe(false); - }); - }); - - describe('OPENID_EMAIL_CLAIM', () => { - it('should use the default email when OPENID_EMAIL_CLAIM is not set', async () => { - const { user } = await validate(tokenset); - expect(user.email).toBe('test@example.com'); - }); - - it('should use the configured claim when OPENID_EMAIL_CLAIM is set', async () => { - 
process.env.OPENID_EMAIL_CLAIM = 'upn'; - const userinfo = { ...tokenset.claims(), upn: 'user@corp.example.com' }; - - const { user } = await validate({ ...tokenset, claims: () => userinfo }); - - expect(user.email).toBe('user@corp.example.com'); - expect(createUser).toHaveBeenCalledWith( - expect.objectContaining({ email: 'user@corp.example.com' }), - expect.anything(), - true, - true, - ); - }); - - it('should fall back to preferred_username when email is missing and OPENID_EMAIL_CLAIM is not set', async () => { - const userinfo = { ...tokenset.claims() }; - delete userinfo.email; - - const { user } = await validate({ ...tokenset, claims: () => userinfo }); - - expect(user.email).toBe('testusername'); - }); - - it('should fall back to upn when email and preferred_username are missing and OPENID_EMAIL_CLAIM is not set', async () => { - const userinfo = { ...tokenset.claims(), upn: 'user@corp.example.com' }; - delete userinfo.email; - delete userinfo.preferred_username; - - const { user } = await validate({ ...tokenset, claims: () => userinfo }); - - expect(user.email).toBe('user@corp.example.com'); - }); - - it('should ignore empty string OPENID_EMAIL_CLAIM and use default fallback', async () => { - process.env.OPENID_EMAIL_CLAIM = ''; - - const { user } = await validate(tokenset); - - expect(user.email).toBe('test@example.com'); - }); - - it('should trim whitespace from OPENID_EMAIL_CLAIM and resolve correctly', async () => { - process.env.OPENID_EMAIL_CLAIM = ' upn '; - const userinfo = { ...tokenset.claims(), upn: 'user@corp.example.com' }; - - const { user } = await validate({ ...tokenset, claims: () => userinfo }); - - expect(user.email).toBe('user@corp.example.com'); - }); - - it('should ignore whitespace-only OPENID_EMAIL_CLAIM and use default fallback', async () => { - process.env.OPENID_EMAIL_CLAIM = ' '; - - const { user } = await validate(tokenset); - - expect(user.email).toBe('test@example.com'); - }); - - it('should fall back to default chain with 
warning when configured claim is missing from userinfo', async () => { - const { logger } = require('@librechat/data-schemas'); - process.env.OPENID_EMAIL_CLAIM = 'nonexistent_claim'; - - const { user } = await validate(tokenset); - - expect(user.email).toBe('test@example.com'); - expect(logger.warn).toHaveBeenCalledWith( - expect.stringContaining('OPENID_EMAIL_CLAIM="nonexistent_claim" not present in userinfo'), - ); - }); - }); -}); +const undici = require('undici'); +const fetch = require('node-fetch'); +const jwtDecode = require('jsonwebtoken/decode'); +const { ErrorTypes } = require('librechat-data-provider'); +const { findUser, createUser, updateUser } = require('~/models'); +const { resolveAppConfigForUser } = require('@librechat/api'); +const { getAppConfig } = require('~/server/services/Config'); +const { setupOpenId } = require('./openidStrategy'); + +// --- Mocks --- +jest.mock('node-fetch'); +jest.mock('jsonwebtoken/decode'); +jest.mock('undici', () => ({ + fetch: jest.fn(), + ProxyAgent: jest.fn(), +})); +jest.mock('~/server/services/Files/strategies', () => ({ + getStrategyFunctions: jest.fn(() => ({ + saveBuffer: jest.fn().mockResolvedValue('/fake/path/to/avatar.png'), + })), +})); +jest.mock('~/server/services/Config', () => ({ + getAppConfig: jest.fn().mockResolvedValue({}), +})); +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + isEnabled: jest.fn(() => false), + isEmailDomainAllowed: jest.fn(() => true), + findOpenIDUser: jest.requireActual('@librechat/api').findOpenIDUser, + getBalanceConfig: jest.fn(() => ({ + enabled: false, + })), + resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})), +})); +jest.mock('~/models', () => ({ + findUser: jest.fn(), + createUser: jest.fn(), + updateUser: jest.fn(), +})); +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { + info: jest.fn(), + warn: jest.fn(), + debug: jest.fn(), + error: jest.fn(), + }, + 
hashToken: jest.fn().mockResolvedValue('hashed-token'), +})); +jest.mock('~/cache/getLogStores', () => + jest.fn(() => ({ + get: jest.fn(), + set: jest.fn(), + })), +); + +// Mock the openid-client module and all its dependencies +jest.mock('openid-client', () => { + return { + discovery: jest.fn().mockResolvedValue({ + clientId: 'fake_client_id', + clientSecret: 'fake_client_secret', + issuer: 'https://fake-issuer.com', + // Add any other properties needed by the implementation + }), + fetchUserInfo: jest.fn().mockImplementation(() => { + // Only return additional properties, but don't override any claims + return Promise.resolve({}); + }), + genericGrantRequest: jest.fn().mockResolvedValue({ + access_token: 'exchanged_graph_token', + expires_in: 3600, + }), + customFetch: Symbol('customFetch'), + }; +}); + +jest.mock('openid-client/passport', () => { + /** Store callbacks by strategy name - 'openid' and 'openidAdmin' */ + const verifyCallbacks = {}; + let lastVerifyCallback; + + const mockStrategy = jest.fn((options, verify) => { + lastVerifyCallback = verify; + return { name: 'openid', options, verify }; + }); + + return { + Strategy: mockStrategy, + /** Get the last registered callback (for backward compatibility) */ + __getVerifyCallback: () => lastVerifyCallback, + /** Store callback by name when passport.use is called */ + __setVerifyCallback: (name, callback) => { + verifyCallbacks[name] = callback; + }, + /** Get callback by strategy name */ + __getVerifyCallbackByName: (name) => verifyCallbacks[name], + }; +}); + +// Mock passport - capture strategy name and callback +jest.mock('passport', () => ({ + use: jest.fn((name, strategy) => { + const passportMock = require('openid-client/passport'); + if (strategy && strategy.verify) { + passportMock.__setVerifyCallback(name, strategy.verify); + } + }), +})); + +describe('setupOpenId', () => { + // Store a reference to the verify callback once it's set up + let verifyCallback; + + // Helper to wrap the verify 
callback in a promise + const validate = (tokenset) => + new Promise((resolve, reject) => { + verifyCallback(tokenset, (err, user, details) => { + if (err) { + reject(err); + } else { + resolve({ user, details }); + } + }); + }); + + const tokenset = { + id_token: 'fake_id_token', + access_token: 'fake_access_token', + claims: () => ({ + sub: '1234', + email: 'test@example.com', + email_verified: true, + given_name: 'First', + family_name: 'Last', + name: 'My Full', + preferred_username: 'testusername', + username: 'flast', + picture: 'https://example.com/avatar.png', + }), + }; + + beforeEach(async () => { + // Clear previous mock calls and reset implementations + jest.clearAllMocks(); + + // Reset environment variables needed by the strategy + process.env.OPENID_ISSUER = 'https://fake-issuer.com'; + process.env.OPENID_CLIENT_ID = 'fake_client_id'; + process.env.OPENID_CLIENT_SECRET = 'fake_client_secret'; + process.env.DOMAIN_SERVER = 'https://example.com'; + process.env.OPENID_CALLBACK_URL = '/callback'; + process.env.OPENID_SCOPE = 'openid profile email'; + process.env.OPENID_REQUIRED_ROLE = 'requiredRole'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roles'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'permissions'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + delete process.env.OPENID_USERNAME_CLAIM; + delete process.env.OPENID_NAME_CLAIM; + delete process.env.OPENID_EMAIL_CLAIM; + delete process.env.PROXY; + delete process.env.OPENID_USE_PKCE; + + // Default jwtDecode mock returns a token that includes the required role. 
+ jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + permissions: ['admin'], + }); + + // By default, assume that no user is found, so createUser will be called + findUser.mockResolvedValue(null); + createUser.mockImplementation(async (userData) => { + // simulate created user with an _id property + return { _id: 'newUserId', ...userData }; + }); + updateUser.mockImplementation(async (id, userData) => { + return { _id: id, ...userData }; + }); + + // For image download, simulate a successful response + const fakeBuffer = Buffer.from('fake image'); + const fakeResponse = { + ok: true, + buffer: jest.fn().mockResolvedValue(fakeBuffer), + }; + fetch.mockResolvedValue(fakeResponse); + + // Call the setup function and capture the verify callback for the regular 'openid' strategy + // (not 'openidAdmin' which requires existing users) + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + }); + + it('should create a new user with correct username when preferred_username claim exists', async () => { + // Arrange – our userinfo already has preferred_username 'testusername' + const userinfo = tokenset.claims(); + + // Act + const { user } = await validate(tokenset); + + // Assert + expect(user.username).toBe(userinfo.preferred_username); + expect(createUser).toHaveBeenCalledWith( + expect.objectContaining({ + provider: 'openid', + openidId: userinfo.sub, + username: userinfo.preferred_username, + email: userinfo.email, + name: `${userinfo.given_name} ${userinfo.family_name}`, + }), + { enabled: false }, + true, + true, + ); + }); + + it('should use username as username when preferred_username claim is missing', async () => { + // Arrange – remove preferred_username from userinfo + const userinfo = { ...tokenset.claims() }; + delete userinfo.preferred_username; + // Expect the username to be the "username" + const expectUsername = userinfo.username; + + // Act + const { user } = await validate({ ...tokenset, 
claims: () => userinfo }); + + // Assert + expect(user.username).toBe(expectUsername); + expect(createUser).toHaveBeenCalledWith( + expect.objectContaining({ username: expectUsername }), + { enabled: false }, + true, + true, + ); + }); + + it('should use email as username when username and preferred_username are missing', async () => { + // Arrange – remove username and preferred_username + const userinfo = { ...tokenset.claims() }; + delete userinfo.username; + delete userinfo.preferred_username; + const expectUsername = userinfo.email; + + // Act + const { user } = await validate({ ...tokenset, claims: () => userinfo }); + + // Assert + expect(user.username).toBe(expectUsername); + expect(createUser).toHaveBeenCalledWith( + expect.objectContaining({ username: expectUsername }), + { enabled: false }, + true, + true, + ); + }); + + it('should override username with OPENID_USERNAME_CLAIM when set', async () => { + // Arrange – set OPENID_USERNAME_CLAIM so that the sub claim is used + process.env.OPENID_USERNAME_CLAIM = 'sub'; + const userinfo = tokenset.claims(); + + // Act + const { user } = await validate(tokenset); + + // Assert – username should equal the sub (converted as-is) + expect(user.username).toBe(userinfo.sub); + expect(createUser).toHaveBeenCalledWith( + expect.objectContaining({ username: userinfo.sub }), + { enabled: false }, + true, + true, + ); + }); + + it('should set the full name correctly when given_name and family_name exist', async () => { + // Arrange + const userinfo = tokenset.claims(); + const expectedFullName = `${userinfo.given_name} ${userinfo.family_name}`; + + // Act + const { user } = await validate(tokenset); + + // Assert + expect(user.name).toBe(expectedFullName); + }); + + it('should override full name with OPENID_NAME_CLAIM when set', async () => { + // Arrange – use the name claim as the full name + process.env.OPENID_NAME_CLAIM = 'name'; + const userinfo = { ...tokenset.claims(), name: 'Custom Name' }; + + // Act + const { 
user } = await validate({ ...tokenset, claims: () => userinfo }); + + // Assert + expect(user.name).toBe('Custom Name'); + }); + + it('should update an existing user on login', async () => { + // Arrange – simulate that a user already exists with openid provider + const existingUser = { + _id: 'existingUserId', + provider: 'openid', + email: tokenset.claims().email, + openidId: '', + username: '', + name: '', + }; + findUser.mockImplementation(async (query) => { + if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) { + return existingUser; + } + return null; + }); + + const userinfo = tokenset.claims(); + + // Act + await validate(tokenset); + + // Assert – updateUser should be called and the user object updated + expect(updateUser).toHaveBeenCalledWith( + existingUser._id, + expect.objectContaining({ + provider: 'openid', + openidId: userinfo.sub, + username: userinfo.preferred_username, + name: `${userinfo.given_name} ${userinfo.family_name}`, + }), + ); + }); + + it('should block login when email exists with different provider', async () => { + // Arrange – simulate that a user exists with same email but different provider + const existingUser = { + _id: 'existingUserId', + provider: 'google', + email: tokenset.claims().email, + googleId: 'some-google-id', + username: 'existinguser', + name: 'Existing User', + }; + findUser.mockImplementation(async (query) => { + if (query.email === tokenset.claims().email && !query.provider) { + return existingUser; + } + return null; + }); + + // Act + const result = await validate(tokenset); + + // Assert – verify that the strategy rejects login + expect(result.user).toBe(false); + expect(result.details.message).toBe(ErrorTypes.AUTH_FAILED); + expect(createUser).not.toHaveBeenCalled(); + expect(updateUser).not.toHaveBeenCalled(); + }); + + it('should block login when email fallback finds user with mismatched openidId', async () => { + const existingUser = { + _id: 'existingUserId', + 
provider: 'openid', + openidId: 'different-sub-claim', + email: tokenset.claims().email, + username: 'existinguser', + name: 'Existing User', + }; + findUser.mockImplementation(async (query) => { + if (query.$or) { + return null; + } + if (query.email === tokenset.claims().email) { + return existingUser; + } + return null; + }); + + const result = await validate(tokenset); + + expect(result.user).toBe(false); + expect(result.details.message).toBe(ErrorTypes.AUTH_FAILED); + expect(createUser).not.toHaveBeenCalled(); + expect(updateUser).not.toHaveBeenCalled(); + }); + + it('should enforce the required role and reject login if missing', async () => { + // Arrange – simulate a token without the required role. + jwtDecode.mockReturnValue({ + roles: ['SomeOtherRole'], + }); + + // Act + const { user, details } = await validate(tokenset); + + // Assert – verify that the strategy rejects login + expect(user).toBe(false); + expect(details.message).toBe('You must have "requiredRole" role to log in.'); + }); + + it('should not treat substring matches in string roles as satisfying required role', async () => { + // Arrange – override required role to "read" then re-setup + process.env.OPENID_REQUIRED_ROLE = 'read'; + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + // Token contains "bread" which *contains* "read" as a substring + jwtDecode.mockReturnValue({ + roles: 'bread', + }); + + // Act + const { user, details } = await validate(tokenset); + + // Assert – verify that substring match does not grant access + expect(user).toBe(false); + expect(details.message).toBe('You must have "read" role to log in.'); + }); + + it('should allow login when roles claim is a space-separated string containing the required role', async () => { + // Arrange – IdP returns roles as a space-delimited string + jwtDecode.mockReturnValue({ + roles: 'role1 role2 requiredRole', + }); + + // Act + const { user } = await 
validate(tokenset); + + // Assert – login succeeds when required role is present after splitting + expect(user).toBeTruthy(); + expect(createUser).toHaveBeenCalled(); + }); + + it('should allow login when roles claim is a comma-separated string containing the required role', async () => { + // Arrange – IdP returns roles as a comma-delimited string + jwtDecode.mockReturnValue({ + roles: 'role1,role2,requiredRole', + }); + + // Act + const { user } = await validate(tokenset); + + // Assert – login succeeds when required role is present after splitting + expect(user).toBeTruthy(); + expect(createUser).toHaveBeenCalled(); + }); + + it('should allow login when roles claim is a mixed comma-and-space-separated string containing the required role', async () => { + // Arrange – IdP returns roles with comma-and-space delimiters + jwtDecode.mockReturnValue({ + roles: 'role1, role2, requiredRole', + }); + + // Act + const { user } = await validate(tokenset); + + // Assert – login succeeds when required role is present after splitting + expect(user).toBeTruthy(); + expect(createUser).toHaveBeenCalled(); + }); + + it('should reject login when roles claim is a space-separated string that does not contain the required role', async () => { + // Arrange – IdP returns a delimited string but required role is absent + jwtDecode.mockReturnValue({ + roles: 'role1 role2 otherRole', + }); + + // Act + const { user, details } = await validate(tokenset); + + // Assert – login is rejected with the correct error message + expect(user).toBe(false); + expect(details.message).toBe('You must have "requiredRole" role to log in.'); + }); + + it('should allow login when single required role is present (backward compatibility)', async () => { + // Arrange – ensure single role configuration (as set in beforeEach) + // OPENID_REQUIRED_ROLE = 'requiredRole' + // Default jwtDecode mock in beforeEach already returns this role + jwtDecode.mockReturnValue({ + roles: ['requiredRole', 'anotherRole'], + }); + 
+ // Act + const { user } = await validate(tokenset); + + // Assert – verify that login succeeds with single role configuration + expect(user).toBeTruthy(); + expect(user.email).toBe(tokenset.claims().email); + expect(user.username).toBe(tokenset.claims().preferred_username); + expect(createUser).toHaveBeenCalled(); + }); + + describe('group overage and groups handling', () => { + it.each([ + ['groups array contains required group', ['group-required', 'other-group'], true, undefined], + [ + 'groups array missing required group', + ['other-group'], + false, + 'You must have "group-required" role to log in.', + ], + ['groups string equals required group', 'group-required', true, undefined], + [ + 'groups string is other group', + 'other-group', + false, + 'You must have "group-required" role to log in.', + ], + ])( + 'uses groups claim directly when %s (no overage)', + async (_label, groupsClaim, expectedAllowed, expectedMessage) => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ + groups: groupsClaim, + permissions: ['admin'], + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + + expect(undici.fetch).not.toHaveBeenCalled(); + expect(Boolean(user)).toBe(expectedAllowed); + expect(details?.message).toBe(expectedMessage); + }, + ); + + it.each([ + ['token kind is not id', { kind: 'access', path: 'groups', decoded: { hasgroups: true } }], + ['parameter path is not groups', { kind: 'id', path: 'roles', decoded: { hasgroups: true } }], + ['decoded token is falsy', { kind: 'id', path: 'groups', decoded: null }], + [ + 'no overage indicators in decoded token', + { + kind: 'id', + path: 'groups', + decoded: { + permissions: ['admin'], + }, + }, + ], + [ + 'only _claim_names present (no 
_claim_sources)', + { + kind: 'id', + path: 'groups', + decoded: { + _claim_names: { groups: 'src1' }, + permissions: ['admin'], + }, + }, + ], + [ + 'only _claim_sources present (no _claim_names)', + { + kind: 'id', + path: 'groups', + decoded: { + _claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } }, + permissions: ['admin'], + }, + }, + ], + ])('does not attempt overage resolution when %s', async (_label, cfg) => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = cfg.path; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = cfg.kind; + + jwtDecode.mockReturnValue(cfg.decoded); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + + expect(undici.fetch).not.toHaveBeenCalled(); + expect(user).toBe(false); + expect(details.message).toBe('You must have "group-required" role to log in.'); + const { logger } = require('@librechat/data-schemas'); + const expectedTokenKind = cfg.kind === 'access' ? 
'access token' : 'id token'; + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining(`Key '${cfg.path}' not found in ${expectedTokenKind}!`), + ); + }); + }); + + describe('resolving groups via Microsoft Graph', () => { + it('denies login and does not call Graph when access token is missing', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + const { logger } = require('@librechat/data-schemas'); + + jwtDecode.mockReturnValue({ + hasgroups: true, + permissions: ['admin'], + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const tokensetWithoutAccess = { + ...tokenset, + access_token: undefined, + }; + + const { user, details } = await validate(tokensetWithoutAccess); + + expect(user).toBe(false); + expect(details.message).toBe('You must have "group-required" role to log in.'); + + expect(undici.fetch).not.toHaveBeenCalled(); + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining('Access token missing; cannot resolve group overage'), + ); + }); + + it.each([ + [ + 'Graph returns HTTP error', + async () => ({ + ok: false, + status: 403, + statusText: 'Forbidden', + json: async () => ({}), + }), + [ + '[openidStrategy] Failed to resolve groups via Microsoft Graph getMemberObjects: HTTP 403 Forbidden', + ], + ], + [ + 'Graph network error', + async () => { + throw new Error('network error'); + }, + [ + '[openidStrategy] Error resolving groups via Microsoft Graph getMemberObjects:', + expect.any(Error), + ], + ], + [ + 'Graph returns unexpected shape (no value)', + async () => ({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({}), + }), + [ + '[openidStrategy] Unexpected response format when resolving groups via Microsoft Graph getMemberObjects', + ], + ], + [ + 'Graph returns invalid value type', + async () => 
({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: 'not-an-array' }), + }), + [ + '[openidStrategy] Unexpected response format when resolving groups via Microsoft Graph getMemberObjects', + ], + ], + ])( + 'denies login when overage resolution fails because %s', + async (_label, setupFetch, expectedErrorArgs) => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + const { logger } = require('@librechat/data-schemas'); + + jwtDecode.mockReturnValue({ + hasgroups: true, + permissions: ['admin'], + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockImplementation(setupFetch); + + const { user, details } = await validate(tokenset); + + expect(undici.fetch).toHaveBeenCalled(); + expect(user).toBe(false); + expect(details.message).toBe('You must have "group-required" role to log in.'); + + expect(logger.error).toHaveBeenCalledWith(...expectedErrorArgs); + }, + ); + + it.each([ + [ + 'hasgroups overage and Graph contains required group', + { + hasgroups: true, + }, + ['group-required', 'some-other-group'], + true, + ], + [ + '_claim_* overage and Graph contains required group', + { + _claim_names: { groups: 'src1' }, + _claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } }, + }, + ['group-required', 'some-other-group'], + true, + ], + [ + 'hasgroups overage and Graph does NOT contain required group', + { + hasgroups: true, + }, + ['some-other-group'], + false, + ], + [ + '_claim_* overage and Graph does NOT contain required group', + { + _claim_names: { groups: 'src1' }, + _claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } }, + }, + ['some-other-group'], + false, + ], + ])( + 'resolves groups via Microsoft Graph when %s', + async (_label, decodedTokenValue, graphGroups, expectedAllowed) => { 
+ process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + const { logger } = require('@librechat/data-schemas'); + + jwtDecode.mockReturnValue(decodedTokenValue); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ + value: graphGroups, + }), + }); + + const { user } = await validate(tokenset); + + expect(undici.fetch).toHaveBeenCalledWith( + 'https://graph.microsoft.com/v1.0/me/getMemberObjects', + expect.objectContaining({ + method: 'POST', + headers: expect.objectContaining({ + Authorization: 'Bearer exchanged_graph_token', + }), + }), + ); + expect(Boolean(user)).toBe(expectedAllowed); + + expect(logger.debug).toHaveBeenCalledWith( + expect.stringContaining( + `Successfully resolved ${graphGroups.length} groups via Microsoft Graph getMemberObjects`, + ), + ); + }, + ); + }); + + describe('OBO token exchange for overage', () => { + it('exchanges access token via OBO before calling Graph API', async () => { + const openidClient = require('openid-client'); + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required'] }), + }); + + await validate(tokenset); + + expect(openidClient.genericGrantRequest).toHaveBeenCalledWith( + expect.anything(), + 'urn:ietf:params:oauth:grant-type:jwt-bearer', + expect.objectContaining({ + scope: 'https://graph.microsoft.com/User.Read', + assertion: 
tokenset.access_token, + requested_token_use: 'on_behalf_of', + }), + ); + + expect(undici.fetch).toHaveBeenCalledWith( + 'https://graph.microsoft.com/v1.0/me/getMemberObjects', + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: 'Bearer exchanged_graph_token', + }), + }), + ); + }); + + it('caches the exchanged token and reuses it on subsequent calls', async () => { + const openidClient = require('openid-client'); + const getLogStores = require('~/cache/getLogStores'); + const mockSet = jest.fn(); + const mockGet = jest + .fn() + .mockResolvedValueOnce(undefined) + .mockResolvedValueOnce({ access_token: 'exchanged_graph_token' }); + getLogStores.mockReturnValue({ get: mockGet, set: mockSet }); + + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required'] }), + }); + + // First call: cache miss → OBO exchange → cache set + await validate(tokenset); + expect(mockSet).toHaveBeenCalledWith( + '1234:overage', + { access_token: 'exchanged_graph_token' }, + 3600000, + ); + expect(openidClient.genericGrantRequest).toHaveBeenCalledTimes(1); + + // Second call: cache hit → no new OBO exchange + openidClient.genericGrantRequest.mockClear(); + await validate(tokenset); + expect(openidClient.genericGrantRequest).not.toHaveBeenCalled(); + }); + }); + + describe('admin role group overage', () => { + it('resolves admin groups via Graph when overage is detected for admin role', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + 
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required', 'admin-group-id'] }), + }); + + const { user } = await validate(tokenset); + + expect(user.role).toBe('ADMIN'); + }); + + it('does not grant admin when overage groups do not contain admin role', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required', 'other-group'] }), + }); + + const { user } = await validate(tokenset); + + expect(user).toBeTruthy(); + expect(user.role).toBeUndefined(); + }); + + it('reuses already-resolved overage groups for admin role check (no duplicate Graph call)', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + 
+ await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required', 'admin-group-id'] }), + }); + + await validate(tokenset); + + // Graph API should be called only once (for required role), admin role reuses the result + expect(undici.fetch).toHaveBeenCalledTimes(1); + }); + + it('demotes existing admin when overage groups no longer contain admin role', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + const existingAdminUser = { + _id: 'existingAdminId', + provider: 'openid', + email: tokenset.claims().email, + openidId: tokenset.claims().sub, + username: 'adminuser', + name: 'Admin User', + role: 'ADMIN', + }; + + findUser.mockImplementation(async (query) => { + if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) { + return existingAdminUser; + } + return null; + }); + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required'] }), + }); + + const { user } = await validate(tokenset); + + expect(user.role).toBe('USER'); + }); + + it('does not attempt overage for admin role when token kind is not id', async () => { + process.env.OPENID_REQUIRED_ROLE = 'requiredRole'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roles'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin'; + 
process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'access'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + hasgroups: true, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + // No Graph call since admin uses access token (not id) + expect(undici.fetch).not.toHaveBeenCalled(); + expect(user.role).toBeUndefined(); + }); + + it('resolves admin via Graph independently when OPENID_REQUIRED_ROLE is not configured', async () => { + delete process.env.OPENID_REQUIRED_ROLE; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['admin-group-id'] }), + }); + + const { user } = await validate(tokenset); + expect(user.role).toBe('ADMIN'); + expect(undici.fetch).toHaveBeenCalledTimes(1); + }); + + it('denies admin when OPENID_REQUIRED_ROLE is absent and Graph does not contain admin group', async () => { + delete process.env.OPENID_REQUIRED_ROLE; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['other-group'] }), + }); + + const { user } = await validate(tokenset); + expect(user).toBeTruthy(); + expect(user.role).toBeUndefined(); 
+ }); + + it('denies login and logs error when OBO exchange throws', async () => { + const openidClient = require('openid-client'); + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + openidClient.genericGrantRequest.mockRejectedValueOnce(new Error('OBO exchange rejected')); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + expect(user).toBe(false); + expect(details.message).toBe('You must have "group-required" role to log in.'); + expect(undici.fetch).not.toHaveBeenCalled(); + }); + + it('denies login when OBO exchange returns no access_token', async () => { + const openidClient = require('openid-client'); + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + openidClient.genericGrantRequest.mockResolvedValueOnce({ expires_in: 3600 }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + expect(user).toBe(false); + expect(details.message).toBe('You must have "group-required" role to log in.'); + expect(undici.fetch).not.toHaveBeenCalled(); + }); + }); + + it('should attempt to download and save the avatar if picture is provided', async () => { + // Act + const { user } = await validate(tokenset); + + // Assert – verify that download was attempted and the avatar field was set via updateUser + expect(fetch).toHaveBeenCalled(); + // Our mock getStrategyFunctions.saveBuffer returns '/fake/path/to/avatar.png' + expect(user.avatar).toBe('/fake/path/to/avatar.png'); + }); + + it('should not 
attempt to download avatar if picture is not provided', async () => { + // Arrange – remove picture + const userinfo = { ...tokenset.claims() }; + delete userinfo.picture; + + // Act + await validate({ ...tokenset, claims: () => userinfo }); + + // Assert – fetch should not be called and avatar should remain undefined or empty + expect(fetch).not.toHaveBeenCalled(); + // Depending on your implementation, user.avatar may be undefined or an empty string. + }); + + it('should support comma-separated multiple roles', async () => { + // Arrange + process.env.OPENID_REQUIRED_ROLE = 'someRole,anotherRole,admin'; + await setupOpenId(); // Re-initialize the strategy + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + jwtDecode.mockReturnValue({ + roles: ['anotherRole', 'aThirdRole'], + }); + + // Act + const { user } = await validate(tokenset); + + // Assert + expect(user).toBeTruthy(); + expect(user.email).toBe(tokenset.claims().email); + }); + + it('should reject login when user has none of the required multiple roles', async () => { + // Arrange + process.env.OPENID_REQUIRED_ROLE = 'someRole,anotherRole,admin'; + await setupOpenId(); // Re-initialize the strategy + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + jwtDecode.mockReturnValue({ + roles: ['aThirdRole', 'aFourthRole'], + }); + + // Act + const { user, details } = await validate(tokenset); + + // Assert + expect(user).toBe(false); + expect(details.message).toBe( + 'You must have one of: "someRole", "anotherRole", "admin" role to log in.', + ); + }); + + it('should handle spaces in comma-separated roles', async () => { + // Arrange + process.env.OPENID_REQUIRED_ROLE = ' someRole , anotherRole , admin '; + await setupOpenId(); // Re-initialize the strategy + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + jwtDecode.mockReturnValue({ + roles: ['someRole'], + }); + + // Act + const { user } = 
await validate(tokenset); + + // Assert + expect(user).toBeTruthy(); + }); + + it('should default to usePKCE false when OPENID_USE_PKCE is not defined', async () => { + const OpenIDStrategy = require('openid-client/passport').Strategy; + + delete process.env.OPENID_USE_PKCE; + await setupOpenId(); + + const callOptions = OpenIDStrategy.mock.calls[OpenIDStrategy.mock.calls.length - 1][0]; + expect(callOptions.usePKCE).toBe(false); + expect(callOptions.params?.code_challenge_method).toBeUndefined(); + }); + + it('should attach federatedTokens to user object for token propagation', async () => { + // Arrange - setup tokenset with access token, id token, refresh token, and expiration + const tokensetWithTokens = { + ...tokenset, + access_token: 'mock_access_token_abc123', + id_token: 'mock_id_token_def456', + refresh_token: 'mock_refresh_token_xyz789', + expires_at: 1234567890, + }; + + // Act - validate with the tokenset containing tokens + const { user } = await validate(tokensetWithTokens); + + // Assert - verify federatedTokens object is attached with correct values + expect(user.federatedTokens).toBeDefined(); + expect(user.federatedTokens).toEqual({ + access_token: 'mock_access_token_abc123', + id_token: 'mock_id_token_def456', + refresh_token: 'mock_refresh_token_xyz789', + expires_at: 1234567890, + }); + }); + + it('should include id_token in federatedTokens distinct from access_token', async () => { + // Arrange - use different values for access_token and id_token + const tokensetWithTokens = { + ...tokenset, + access_token: 'the_access_token', + id_token: 'the_id_token', + refresh_token: 'the_refresh_token', + expires_at: 9999999999, + }; + + // Act + const { user } = await validate(tokensetWithTokens); + + // Assert - id_token and access_token must be different values + expect(user.federatedTokens.access_token).toBe('the_access_token'); + expect(user.federatedTokens.id_token).toBe('the_id_token'); + 
expect(user.federatedTokens.id_token).not.toBe(user.federatedTokens.access_token); + }); + + it('should include tokenset along with federatedTokens', async () => { + // Arrange + const tokensetWithTokens = { + ...tokenset, + access_token: 'test_access_token', + id_token: 'test_id_token', + refresh_token: 'test_refresh_token', + expires_at: 9999999999, + }; + + // Act + const { user } = await validate(tokensetWithTokens); + + // Assert - both tokenset and federatedTokens should be present + expect(user.tokenset).toBeDefined(); + expect(user.federatedTokens).toBeDefined(); + expect(user.tokenset.access_token).toBe('test_access_token'); + expect(user.tokenset.id_token).toBe('test_id_token'); + expect(user.federatedTokens.access_token).toBe('test_access_token'); + expect(user.federatedTokens.id_token).toBe('test_id_token'); + }); + + it('should set role to "ADMIN" if OPENID_ADMIN_ROLE is set and user has that role', async () => { + // Act + const { user } = await validate(tokenset); + + // Assert – verify that the user role is set to "ADMIN" + expect(user.role).toBe('ADMIN'); + }); + + it('should not set user role if OPENID_ADMIN_ROLE is set but the user does not have that role', async () => { + // Arrange – simulate a token without the admin permission + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + permissions: ['not-admin'], + }); + + // Act + const { user } = await validate(tokenset); + + // Assert – verify that the user role is not defined + expect(user.role).toBeUndefined(); + }); + + it('should demote existing admin user when admin role is removed from token', async () => { + // Arrange – simulate an existing user who is currently an admin + const existingAdminUser = { + _id: 'existingAdminId', + provider: 'openid', + email: tokenset.claims().email, + openidId: tokenset.claims().sub, + username: 'adminuser', + name: 'Admin User', + role: 'ADMIN', + }; + + findUser.mockImplementation(async (query) => { + if (query.openidId === tokenset.claims().sub || 
query.email === tokenset.claims().email) { + return existingAdminUser; + } + return null; + }); + + // Token without admin permission + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + permissions: ['not-admin'], + }); + + const { logger } = require('@librechat/data-schemas'); + + // Act + const { user } = await validate(tokenset); + + // Assert – verify that the user was demoted + expect(user.role).toBe('USER'); + expect(updateUser).toHaveBeenCalledWith( + existingAdminUser._id, + expect.objectContaining({ + role: 'USER', + }), + ); + expect(logger.info).toHaveBeenCalledWith( + expect.stringContaining('demoted from admin - role no longer present in token'), + ); + }); + + it('should NOT demote admin user when admin role env vars are not configured', async () => { + // Arrange – remove admin role env vars + delete process.env.OPENID_ADMIN_ROLE; + delete process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH; + delete process.env.OPENID_ADMIN_ROLE_TOKEN_KIND; + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + // Simulate an existing admin user + const existingAdminUser = { + _id: 'existingAdminId', + provider: 'openid', + email: tokenset.claims().email, + openidId: tokenset.claims().sub, + username: 'adminuser', + name: 'Admin User', + role: 'ADMIN', + }; + + findUser.mockImplementation(async (query) => { + if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) { + return existingAdminUser; + } + return null; + }); + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + }); + + // Act + const { user } = await validate(tokenset); + + // Assert – verify that the admin user was NOT demoted + expect(user.role).toBe('ADMIN'); + expect(updateUser).toHaveBeenCalledWith( + existingAdminUser._id, + expect.objectContaining({ + role: 'ADMIN', + }), + ); + }); + + describe('lodash get - nested path extraction', () => { + it('should extract roles from deeply nested token 
path', async () => { + process.env.OPENID_REQUIRED_ROLE = 'app-user'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'resource_access.my-client.roles'; + + jwtDecode.mockReturnValue({ + resource_access: { + 'my-client': { + roles: ['app-user', 'viewer'], + }, + }, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user).toBeTruthy(); + expect(user.email).toBe(tokenset.claims().email); + }); + + it('should extract roles from three-level nested path', async () => { + process.env.OPENID_REQUIRED_ROLE = 'editor'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'data.access.permissions.roles'; + + jwtDecode.mockReturnValue({ + data: { + access: { + permissions: { + roles: ['editor', 'reader'], + }, + }, + }, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user).toBeTruthy(); + }); + + it('should log error and reject login when required role path does not exist in token', async () => { + const { logger } = require('@librechat/data-schemas'); + process.env.OPENID_REQUIRED_ROLE = 'app-user'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'resource_access.nonexistent.roles'; + + jwtDecode.mockReturnValue({ + resource_access: { + 'my-client': { + roles: ['app-user'], + }, + }, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining("Key 'resource_access.nonexistent.roles' not found in id token!"), + ); + expect(user).toBe(false); + expect(details.message).toContain('role to log in'); + }); + + it('should handle missing intermediate nested path gracefully', async () => { + const { logger } = 
require('@librechat/data-schemas'); + process.env.OPENID_REQUIRED_ROLE = 'user'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'org.team.roles'; + + jwtDecode.mockReturnValue({ + org: { + other: 'value', + }, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining("Key 'org.team.roles' not found in id token!"), + ); + expect(user).toBe(false); + }); + + it('should extract admin role from nested path in access token', async () => { + process.env.OPENID_ADMIN_ROLE = 'admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'realm_access.roles'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'access'; + + jwtDecode.mockImplementation((token) => { + if (token === 'fake_access_token') { + return { + realm_access: { + roles: ['admin', 'user'], + }, + }; + } + return { + roles: ['requiredRole'], + }; + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user.role).toBe('ADMIN'); + }); + + it('should extract admin role from nested path in userinfo', async () => { + process.env.OPENID_ADMIN_ROLE = 'admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'organization.permissions'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'userinfo'; + + const userinfoWithNestedGroups = { + ...tokenset.claims(), + organization: { + permissions: ['admin', 'write'], + }, + }; + + require('openid-client').fetchUserInfo.mockResolvedValue({ + organization: { + permissions: ['admin', 'write'], + }, + }); + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate({ + ...tokenset, + claims: () => userinfoWithNestedGroups, + }); + + 
expect(user.role).toBe('ADMIN'); + }); + + it('should handle boolean admin role value', async () => { + process.env.OPENID_ADMIN_ROLE = 'admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'is_admin'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + is_admin: true, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user.role).toBe('ADMIN'); + }); + + it('should handle string admin role value matching exactly', async () => { + process.env.OPENID_ADMIN_ROLE = 'super-admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'role'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + role: 'super-admin', + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user.role).toBe('ADMIN'); + }); + + it('should not set admin role when string value does not match', async () => { + process.env.OPENID_ADMIN_ROLE = 'super-admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'role'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + role: 'regular-user', + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user.role).toBeUndefined(); + }); + + it('should handle array admin role value', async () => { + process.env.OPENID_ADMIN_ROLE = 'site-admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'app_roles'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + app_roles: ['user', 'site-admin', 'moderator'], + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user.role).toBe('ADMIN'); + }); + + it('should not set admin when role is not in 
array', async () => { + process.env.OPENID_ADMIN_ROLE = 'site-admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'app_roles'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + app_roles: ['user', 'moderator'], + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user.role).toBeUndefined(); + }); + + it('should grant admin when admin role claim is a space-separated string containing the admin role', async () => { + // Arrange – IdP returns admin roles as a space-delimited string + process.env.OPENID_ADMIN_ROLE = 'site-admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'app_roles'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + app_roles: 'user site-admin moderator', + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + // Act + const { user } = await validate(tokenset); + + // Assert – admin role is granted after splitting the delimited string + expect(user.role).toBe('ADMIN'); + }); + + it('should not grant admin when admin role claim is a space-separated string that does not contain the admin role', async () => { + // Arrange – delimited string present but admin role is absent + process.env.OPENID_ADMIN_ROLE = 'site-admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'app_roles'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + app_roles: 'user moderator', + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + // Act + const { user } = await validate(tokenset); + + // Assert – admin role is not granted + expect(user.role).toBeUndefined(); + }); + + it('should handle nested path with special characters in keys', async () => { + process.env.OPENID_REQUIRED_ROLE = 'app-user'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 
'resource_access.my-app-123.roles'; + + jwtDecode.mockReturnValue({ + resource_access: { + 'my-app-123': { + roles: ['app-user'], + }, + }, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user).toBeTruthy(); + }); + + it('should handle empty object at nested path', async () => { + const { logger } = require('@librechat/data-schemas'); + process.env.OPENID_REQUIRED_ROLE = 'user'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'access.roles'; + + jwtDecode.mockReturnValue({ + access: {}, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining("Key 'access.roles' not found in id token!"), + ); + expect(user).toBe(false); + }); + + it('should handle null value at intermediate path', async () => { + const { logger } = require('@librechat/data-schemas'); + process.env.OPENID_REQUIRED_ROLE = 'user'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'data.roles'; + + jwtDecode.mockReturnValue({ + data: null, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining("Key 'data.roles' not found in id token!"), + ); + expect(user).toBe(false); + }); + + it('should reject login with invalid admin role token kind', async () => { + process.env.OPENID_ADMIN_ROLE = 'admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'roles'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'invalid'; + + const { logger } = require('@librechat/data-schemas'); + + jwtDecode.mockReturnValue({ + roles: ['requiredRole', 'admin'], + }); + + await setupOpenId(); + verifyCallback = 
require('openid-client/passport').__getVerifyCallbackByName('openid'); + + await expect(validate(tokenset)).rejects.toThrow('Invalid admin role token kind'); + + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining( + "Invalid admin role token kind: invalid. Must be one of 'access', 'id', or 'userinfo'", + ), + ); + }); + + it('should reject login when roles path returns invalid type (object)', async () => { + const { logger } = require('@librechat/data-schemas'); + process.env.OPENID_REQUIRED_ROLE = 'app-user'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roles'; + + jwtDecode.mockReturnValue({ + roles: { admin: true, user: false }, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining("Key 'roles' not found in id token!"), + ); + expect(user).toBe(false); + expect(details.message).toContain('role to log in'); + }); + + it('should reject login when roles path returns invalid type (number)', async () => { + const { logger } = require('@librechat/data-schemas'); + process.env.OPENID_REQUIRED_ROLE = 'user'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roleCount'; + + jwtDecode.mockReturnValue({ + roleCount: 5, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining("Key 'roleCount' not found in id token!"), + ); + expect(user).toBe(false); + }); + }); + + describe('OPENID_EMAIL_CLAIM', () => { + it('should use the default email when OPENID_EMAIL_CLAIM is not set', async () => { + const { user } = await validate(tokenset); + expect(user.email).toBe('test@example.com'); + }); + + it('should use the configured claim when OPENID_EMAIL_CLAIM is set', async () => { + 
process.env.OPENID_EMAIL_CLAIM = 'upn'; + const userinfo = { ...tokenset.claims(), upn: 'user@corp.example.com' }; + + const { user } = await validate({ ...tokenset, claims: () => userinfo }); + + expect(user.email).toBe('user@corp.example.com'); + expect(createUser).toHaveBeenCalledWith( + expect.objectContaining({ email: 'user@corp.example.com' }), + expect.anything(), + true, + true, + ); + }); + + it('should fall back to preferred_username when email is missing and OPENID_EMAIL_CLAIM is not set', async () => { + const userinfo = { ...tokenset.claims() }; + delete userinfo.email; + + const { user } = await validate({ ...tokenset, claims: () => userinfo }); + + expect(user.email).toBe('testusername'); + }); + + it('should fall back to upn when email and preferred_username are missing and OPENID_EMAIL_CLAIM is not set', async () => { + const userinfo = { ...tokenset.claims(), upn: 'user@corp.example.com' }; + delete userinfo.email; + delete userinfo.preferred_username; + + const { user } = await validate({ ...tokenset, claims: () => userinfo }); + + expect(user.email).toBe('user@corp.example.com'); + }); + + it('should ignore empty string OPENID_EMAIL_CLAIM and use default fallback', async () => { + process.env.OPENID_EMAIL_CLAIM = ''; + + const { user } = await validate(tokenset); + + expect(user.email).toBe('test@example.com'); + }); + + it('should trim whitespace from OPENID_EMAIL_CLAIM and resolve correctly', async () => { + process.env.OPENID_EMAIL_CLAIM = ' upn '; + const userinfo = { ...tokenset.claims(), upn: 'user@corp.example.com' }; + + const { user } = await validate({ ...tokenset, claims: () => userinfo }); + + expect(user.email).toBe('user@corp.example.com'); + }); + + it('should ignore whitespace-only OPENID_EMAIL_CLAIM and use default fallback', async () => { + process.env.OPENID_EMAIL_CLAIM = ' '; + + const { user } = await validate(tokenset); + + expect(user.email).toBe('test@example.com'); + }); + + it('should fall back to default chain with 
warning when configured claim is missing from userinfo', async () => { + const { logger } = require('@librechat/data-schemas'); + process.env.OPENID_EMAIL_CLAIM = 'nonexistent_claim'; + + const { user } = await validate(tokenset); + + expect(user.email).toBe('test@example.com'); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('OPENID_EMAIL_CLAIM="nonexistent_claim" not present in userinfo'), + ); + }); + }); + + describe('Tenant-scoped config', () => { + it('should call resolveAppConfigForUser for tenant user', async () => { + const existingUser = { + _id: 'openid-tenant-user', + provider: 'openid', + openidId: '1234', + email: 'test@example.com', + tenantId: 'tenant-d', + role: 'USER', + }; + findUser.mockResolvedValue(existingUser); + + await validate(tokenset); + + expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existingUser); + }); + + it('should use baseConfig for new user without calling resolveAppConfigForUser', async () => { + findUser.mockResolvedValue(null); + + await validate(tokenset); + + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); + expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); + + it('should block login when tenant config restricts the domain', async () => { + const { isEmailDomainAllowed } = require('@librechat/api'); + const existingUser = { + _id: 'openid-tenant-blocked', + provider: 'openid', + openidId: '1234', + email: 'test@example.com', + tenantId: 'tenant-restrict', + role: 'USER', + }; + findUser.mockResolvedValue(existingUser); + resolveAppConfigForUser.mockResolvedValue({ + registration: { allowedDomains: ['other.com'] }, + }); + isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false); + + const { user, details } = await validate(tokenset); + expect(user).toBe(false); + expect(details).toEqual({ message: 'Email domain not allowed' }); + }); + }); +}); diff --git a/api/strategies/samlStrategy.js b/api/strategies/samlStrategy.js index 
843baf8a64..21e7bdd001 100644 --- a/api/strategies/samlStrategy.js +++ b/api/strategies/samlStrategy.js @@ -5,7 +5,11 @@ const passport = require('passport'); const { ErrorTypes } = require('librechat-data-provider'); const { hashToken, logger } = require('@librechat/data-schemas'); const { Strategy: SamlStrategy } = require('@node-saml/passport-saml'); -const { getBalanceConfig, isEmailDomainAllowed } = require('@librechat/api'); +const { + getBalanceConfig, + isEmailDomainAllowed, + resolveAppConfigForUser, +} = require('@librechat/api'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { findUser, createUser, updateUser } = require('~/models'); const { getAppConfig } = require('~/server/services/Config'); @@ -193,9 +197,9 @@ async function setupSaml() { logger.debug('[samlStrategy] SAML profile:', profile); const userEmail = getEmail(profile) || ''; - const appConfig = await getAppConfig(); - if (!isEmailDomainAllowed(userEmail, appConfig?.registration?.allowedDomains)) { + const baseConfig = await getAppConfig({ baseOnly: true }); + if (!isEmailDomainAllowed(userEmail, baseConfig?.registration?.allowedDomains)) { logger.error( `[SAML Strategy] Authentication blocked - email domain not allowed [Email: ${userEmail}]`, ); @@ -223,6 +227,17 @@ async function setupSaml() { }); } + const appConfig = user?.tenantId + ? 
await resolveAppConfigForUser(getAppConfig, user) + : baseConfig; + + if (!isEmailDomainAllowed(userEmail, appConfig?.registration?.allowedDomains)) { + logger.error( + `[SAML Strategy] Authentication blocked - email domain not allowed [Email: ${userEmail}]`, + ); + return done(null, false, { message: 'Email domain not allowed' }); + } + const fullName = getFullName(profile); const username = convertToUsername( diff --git a/api/strategies/samlStrategy.spec.js b/api/strategies/samlStrategy.spec.js index 1d16719b87..2022d34b33 100644 --- a/api/strategies/samlStrategy.spec.js +++ b/api/strategies/samlStrategy.spec.js @@ -30,6 +30,7 @@ jest.mock('@librechat/api', () => ({ tokenCredits: 1000, startBalance: 1000, })), + resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})), })); jest.mock('~/server/services/Config/EndpointService', () => ({ config: {}, @@ -47,6 +48,9 @@ const fs = require('fs'); const path = require('path'); const fetch = require('node-fetch'); const { Strategy: SamlStrategy } = require('@node-saml/passport-saml'); +const { findUser } = require('~/models'); +const { resolveAppConfigForUser } = require('@librechat/api'); +const { getAppConfig } = require('~/server/services/Config'); const { setupSaml, getCertificateContent } = require('./samlStrategy'); // Configure fs mock @@ -440,4 +444,50 @@ u7wlOSk+oFzDIO/UILIA expect(fetch).not.toHaveBeenCalled(); }); + + it('should pass the found user to resolveAppConfigForUser', async () => { + const existingUser = { + _id: 'tenant-user-id', + provider: 'saml', + samlId: 'saml-1234', + email: 'test@example.com', + tenantId: 'tenant-c', + role: 'USER', + }; + findUser.mockResolvedValue(existingUser); + + const profile = { ...baseProfile }; + await validate(profile); + + expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existingUser); + }); + + it('should use baseConfig for new SAML user without calling resolveAppConfigForUser', async () => { + const profile = { ...baseProfile }; + 
await validate(profile); + + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); + expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); + + it('should block login when tenant config restricts the domain', async () => { + const { isEmailDomainAllowed } = require('@librechat/api'); + const existingUser = { + _id: 'tenant-blocked', + provider: 'saml', + samlId: 'saml-1234', + email: 'test@example.com', + tenantId: 'tenant-restrict', + role: 'USER', + }; + findUser.mockResolvedValue(existingUser); + resolveAppConfigForUser.mockResolvedValue({ + registration: { allowedDomains: ['other.com'] }, + }); + isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false); + + const profile = { ...baseProfile }; + const { user } = await validate(profile); + expect(user).toBe(false); + }); }); diff --git a/api/strategies/socialLogin.js b/api/strategies/socialLogin.js index 88fb347042..a5fe78e17d 100644 --- a/api/strategies/socialLogin.js +++ b/api/strategies/socialLogin.js @@ -1,6 +1,6 @@ const { logger } = require('@librechat/data-schemas'); const { ErrorTypes } = require('librechat-data-provider'); -const { isEnabled, isEmailDomainAllowed } = require('@librechat/api'); +const { isEnabled, isEmailDomainAllowed, resolveAppConfigForUser } = require('@librechat/api'); const { createSocialUser, handleExistingUser } = require('./process'); const { getAppConfig } = require('~/server/services/Config'); const { findUser } = require('~/models'); @@ -13,9 +13,8 @@ const socialLogin = profile, }); - const appConfig = await getAppConfig(); - - if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + const baseConfig = await getAppConfig({ baseOnly: true }); + if (!isEmailDomainAllowed(email, baseConfig?.registration?.allowedDomains)) { logger.error( `[${provider}Login] Authentication blocked - email domain not allowed [Email: ${email}]`, ); @@ -41,6 +40,20 @@ const socialLogin = } } + const appConfig = existingUser?.tenantId + ? 
await resolveAppConfigForUser(getAppConfig, existingUser) + : baseConfig; + + if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + logger.error( + `[${provider}Login] Authentication blocked - email domain not allowed [Email: ${email}]`, + ); + const error = new Error(ErrorTypes.AUTH_FAILED); + error.code = ErrorTypes.AUTH_FAILED; + error.message = 'Email domain not allowed'; + return cb(error); + } + if (existingUser?.provider === provider) { await handleExistingUser(existingUser, avatarUrl, appConfig, email); return cb(null, existingUser); diff --git a/api/strategies/socialLogin.test.js b/api/strategies/socialLogin.test.js index ba4778c8b1..4fde397d55 100644 --- a/api/strategies/socialLogin.test.js +++ b/api/strategies/socialLogin.test.js @@ -3,6 +3,8 @@ const { ErrorTypes } = require('librechat-data-provider'); const { createSocialUser, handleExistingUser } = require('./process'); const socialLogin = require('./socialLogin'); const { findUser } = require('~/models'); +const { resolveAppConfigForUser } = require('@librechat/api'); +const { getAppConfig } = require('~/server/services/Config'); jest.mock('@librechat/data-schemas', () => { const actualModule = jest.requireActual('@librechat/data-schemas'); @@ -25,6 +27,10 @@ jest.mock('@librechat/api', () => ({ ...jest.requireActual('@librechat/api'), isEnabled: jest.fn().mockReturnValue(true), isEmailDomainAllowed: jest.fn().mockReturnValue(true), + resolveAppConfigForUser: jest.fn().mockResolvedValue({ + fileStrategy: 'local', + balance: { enabled: false }, + }), })); jest.mock('~/models', () => ({ @@ -66,10 +72,7 @@ describe('socialLogin', () => { googleId: googleId, }; - /** Mock findUser to return user on first call (by googleId), null on second call */ - findUser - .mockResolvedValueOnce(existingUser) // First call: finds by googleId - .mockResolvedValueOnce(null); // Second call would be by email, but won't be reached + 
findUser.mockResolvedValueOnce(existingUser).mockResolvedValueOnce(null); const mockProfile = { id: googleId, @@ -83,13 +86,9 @@ describe('socialLogin', () => { await loginFn(null, null, null, mockProfile, callback); - /** Verify it searched by googleId first */ expect(findUser).toHaveBeenNthCalledWith(1, { googleId: googleId }); - - /** Verify it did NOT search by email (because it found user by googleId) */ expect(findUser).toHaveBeenCalledTimes(1); - /** Verify handleExistingUser was called with the new email */ expect(handleExistingUser).toHaveBeenCalledWith( existingUser, 'https://example.com/avatar.png', @@ -97,7 +96,6 @@ describe('socialLogin', () => { newEmail, ); - /** Verify callback was called with success */ expect(callback).toHaveBeenCalledWith(null, existingUser); }); @@ -113,7 +111,7 @@ describe('socialLogin', () => { facebookId: facebookId, }; - findUser.mockResolvedValue(existingUser); // Always returns user + findUser.mockResolvedValue(existingUser); const mockProfile = { id: facebookId, @@ -127,7 +125,6 @@ describe('socialLogin', () => { await loginFn(null, null, null, mockProfile, callback); - /** Verify it searched by facebookId first */ expect(findUser).toHaveBeenCalledWith({ facebookId: facebookId }); expect(findUser.mock.calls[0]).toEqual([{ facebookId: facebookId }]); @@ -150,13 +147,10 @@ describe('socialLogin', () => { _id: 'user789', email: email, provider: 'google', - googleId: 'old-google-id', // Different googleId (edge case) + googleId: 'old-google-id', }; - /** First call (by googleId) returns null, second call (by email) returns user */ - findUser - .mockResolvedValueOnce(null) // By googleId - .mockResolvedValueOnce(existingUser); // By email + findUser.mockResolvedValueOnce(null).mockResolvedValueOnce(existingUser); const mockProfile = { id: googleId, @@ -170,13 +164,10 @@ describe('socialLogin', () => { await loginFn(null, null, null, mockProfile, callback); - /** Verify both searches happened */ 
expect(findUser).toHaveBeenNthCalledWith(1, { googleId: googleId }); - /** Email passed as-is; findUser implementation handles case normalization */ expect(findUser).toHaveBeenNthCalledWith(2, { email: email }); expect(findUser).toHaveBeenCalledTimes(2); - /** Verify warning log */ expect(logger.warn).toHaveBeenCalledWith( `[${provider}Login] User found by email: ${email} but not by ${provider}Id`, ); @@ -197,7 +188,6 @@ describe('socialLogin', () => { googleId: googleId, }; - /** Both searches return null */ findUser.mockResolvedValue(null); createSocialUser.mockResolvedValue(newUser); @@ -213,10 +203,8 @@ describe('socialLogin', () => { await loginFn(null, null, null, mockProfile, callback); - /** Verify both searches happened */ expect(findUser).toHaveBeenCalledTimes(2); - /** Verify createSocialUser was called */ expect(createSocialUser).toHaveBeenCalledWith({ email: email, avatarUrl: 'https://example.com/avatar.png', @@ -242,12 +230,10 @@ describe('socialLogin', () => { const existingUser = { _id: 'user123', email: email, - provider: 'local', // Different provider + provider: 'local', }; - findUser - .mockResolvedValueOnce(null) // By googleId - .mockResolvedValueOnce(existingUser); // By email + findUser.mockResolvedValueOnce(null).mockResolvedValueOnce(existingUser); const mockProfile = { id: googleId, @@ -261,7 +247,6 @@ describe('socialLogin', () => { await loginFn(null, null, null, mockProfile, callback); - /** Verify error callback */ expect(callback).toHaveBeenCalledWith( expect.objectContaining({ code: ErrorTypes.AUTH_FAILED, @@ -274,4 +259,104 @@ describe('socialLogin', () => { ); }); }); + + describe('Tenant-scoped config', () => { + it('should call resolveAppConfigForUser for tenant user', async () => { + const provider = 'google'; + const googleId = 'google-tenant-user'; + const email = 'tenant@example.com'; + + const existingUser = { + _id: 'userTenant', + email, + provider: 'google', + googleId, + tenantId: 'tenant-b', + role: 'USER', + }; + + 
findUser.mockResolvedValue(existingUser); + + const mockProfile = { + id: googleId, + emails: [{ value: email, verified: true }], + photos: [{ value: 'https://example.com/avatar.png' }], + name: { givenName: 'Tenant', familyName: 'User' }, + }; + + const loginFn = socialLogin(provider, mockGetProfileDetails); + const callback = jest.fn(); + + await loginFn(null, null, null, mockProfile, callback); + + expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existingUser); + }); + + it('should use baseConfig for non-tenant user without calling resolveAppConfigForUser', async () => { + const provider = 'google'; + const googleId = 'google-new-tenant'; + const email = 'new@example.com'; + + findUser.mockResolvedValue(null); + createSocialUser.mockResolvedValue({ + _id: 'newUser', + email, + provider: 'google', + googleId, + }); + + const mockProfile = { + id: googleId, + emails: [{ value: email, verified: true }], + photos: [{ value: 'https://example.com/avatar.png' }], + name: { givenName: 'New', familyName: 'User' }, + }; + + const loginFn = socialLogin(provider, mockGetProfileDetails); + const callback = jest.fn(); + + await loginFn(null, null, null, mockProfile, callback); + + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); + expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); + + it('should block login when tenant config restricts the domain', async () => { + const { isEmailDomainAllowed } = require('@librechat/api'); + const provider = 'google'; + const googleId = 'google-tenant-blocked'; + const email = 'blocked@example.com'; + + const existingUser = { + _id: 'userBlocked', + email, + provider: 'google', + googleId, + tenantId: 'tenant-restrict', + role: 'USER', + }; + + findUser.mockResolvedValue(existingUser); + resolveAppConfigForUser.mockResolvedValue({ + registration: { allowedDomains: ['other.com'] }, + }); + isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false); + + const mockProfile = { + id: 
googleId, + emails: [{ value: email, verified: true }], + photos: [{ value: 'https://example.com/avatar.png' }], + name: { givenName: 'Blocked', familyName: 'User' }, + }; + + const loginFn = socialLogin(provider, mockGetProfileDetails); + const callback = jest.fn(); + + await loginFn(null, null, null, mockProfile, callback); + + expect(callback).toHaveBeenCalledWith( + expect.objectContaining({ message: 'Email domain not allowed' }), + ); + }); + }); }); diff --git a/api/utils/tokens.spec.js b/api/utils/tokens.spec.js index 6cecdb95c8..dfa6762ee5 100644 --- a/api/utils/tokens.spec.js +++ b/api/utils/tokens.spec.js @@ -1813,3 +1813,57 @@ describe('GLM Model Tests (Zhipu AI)', () => { }); }); }); + +describe('Mistral Model Tests', () => { + describe('getModelMaxTokens', () => { + test('should return correct tokens for mistral-large-3 (256k context)', () => { + expect(getModelMaxTokens('mistral-large-3', EModelEndpoint.custom)).toBe( + maxTokensMap[EModelEndpoint.custom]['mistral-large-3'], + ); + }); + + test('should match mistral-large-3 for suffixed variants', () => { + expect(getModelMaxTokens('mistral-large-3-instruct', EModelEndpoint.custom)).toBe( + maxTokensMap[EModelEndpoint.custom]['mistral-large-3'], + ); + }); + + test('should not match mistral-large-3 for generic mistral-large', () => { + expect(getModelMaxTokens('mistral-large', EModelEndpoint.custom)).toBe( + maxTokensMap[EModelEndpoint.custom]['mistral-large'], + ); + expect(getModelMaxTokens('mistral-large-latest', EModelEndpoint.custom)).toBe( + maxTokensMap[EModelEndpoint.custom]['mistral-large'], + ); + }); + }); + + describe('matchModelName', () => { + test('should match mistral-large-3 exactly', () => { + expect(matchModelName('mistral-large-3', EModelEndpoint.custom)).toBe('mistral-large-3'); + }); + + test('should match mistral-large-3 for prefixed/suffixed variants', () => { + expect(matchModelName('mistral/mistral-large-3', EModelEndpoint.custom)).toBe( + 'mistral-large-3', + ); + 
expect(matchModelName('mistral-large-3-instruct', EModelEndpoint.custom)).toBe( + 'mistral-large-3', + ); + }); + + test('should match generic mistral-large for non-3 variants', () => { + expect(matchModelName('mistral-large-latest', EModelEndpoint.custom)).toBe('mistral-large'); + }); + }); + + describe('findMatchingPattern', () => { + test('should prefer mistral-large-3 over mistral-large for mistral-large-3 variants', () => { + const result = findMatchingPattern( + 'mistral-large-3-instruct', + maxTokensMap[EModelEndpoint.custom], + ); + expect(result).toBe('mistral-large-3'); + }); + }); +}); diff --git a/client/src/Providers/MessagesViewContext.tsx b/client/src/Providers/MessagesViewContext.tsx index f1cae204a4..c44972918c 100644 --- a/client/src/Providers/MessagesViewContext.tsx +++ b/client/src/Providers/MessagesViewContext.tsx @@ -140,6 +140,55 @@ export function useMessagesOperations() { ); } +type OptionalMessagesOps = Pick< + MessagesViewContextValue, + 'ask' | 'regenerate' | 'handleContinue' | 'getMessages' | 'setMessages' +>; + +const NOOP_OPS: OptionalMessagesOps = { + ask: () => {}, + regenerate: () => {}, + handleContinue: () => {}, + getMessages: () => undefined, + setMessages: () => {}, +}; + +/** + * Hook for components that need message operations but may render outside MessagesViewProvider + * (e.g. the /search route). Returns no-op stubs when the provider is absent — UI actions will + * be silently discarded rather than crashing. Callers must use optional chaining on + * `getMessages()` results, as it returns `undefined` outside the provider. + */ +export function useOptionalMessagesOperations(): OptionalMessagesOps { + const context = useContext(MessagesViewContext); + const ask = context?.ask; + const regenerate = context?.regenerate; + const handleContinue = context?.handleContinue; + const getMessages = context?.getMessages; + const setMessages = context?.setMessages; + return useMemo( + () => ({ + ask: ask ?? 
NOOP_OPS.ask, + regenerate: regenerate ?? NOOP_OPS.regenerate, + handleContinue: handleContinue ?? NOOP_OPS.handleContinue, + getMessages: getMessages ?? NOOP_OPS.getMessages, + setMessages: setMessages ?? NOOP_OPS.setMessages, + }), + [ask, regenerate, handleContinue, getMessages, setMessages], + ); +} + +/** + * Hook for components that need conversation data but may render outside MessagesViewProvider + * (e.g. the /search route). Returns `undefined` for both fields when the provider is absent. + */ +export function useOptionalMessagesConversation() { + const context = useContext(MessagesViewContext); + const conversation = context?.conversation; + const conversationId = context?.conversationId; + return useMemo(() => ({ conversation, conversationId }), [conversation, conversationId]); +} + /** Hook for components that only need message state */ export function useMessagesState() { const { index, latestMessageId, latestMessageDepth, setLatestMessage } = useMessagesViewContext(); diff --git a/client/src/Providers/__tests__/MessagesViewContext.spec.tsx b/client/src/Providers/__tests__/MessagesViewContext.spec.tsx new file mode 100644 index 0000000000..88cd6f702d --- /dev/null +++ b/client/src/Providers/__tests__/MessagesViewContext.spec.tsx @@ -0,0 +1,53 @@ +import { renderHook } from '@testing-library/react'; +import { + useOptionalMessagesOperations, + useOptionalMessagesConversation, +} from '../MessagesViewContext'; + +describe('useOptionalMessagesOperations', () => { + it('returns noop stubs when rendered outside MessagesViewProvider', () => { + const { result } = renderHook(() => useOptionalMessagesOperations()); + + expect(result.current.ask).toBeInstanceOf(Function); + expect(result.current.regenerate).toBeInstanceOf(Function); + expect(result.current.handleContinue).toBeInstanceOf(Function); + expect(result.current.getMessages).toBeInstanceOf(Function); + expect(result.current.setMessages).toBeInstanceOf(Function); + }); + + it('noop stubs do not throw 
when called', () => { + const { result } = renderHook(() => useOptionalMessagesOperations()); + + expect(() => result.current.ask({} as never)).not.toThrow(); + expect(() => result.current.regenerate({} as never)).not.toThrow(); + expect(() => result.current.handleContinue({} as never)).not.toThrow(); + expect(() => result.current.setMessages([])).not.toThrow(); + }); + + it('getMessages returns undefined outside the provider', () => { + const { result } = renderHook(() => useOptionalMessagesOperations()); + expect(result.current.getMessages()).toBeUndefined(); + }); + + it('returns stable references across re-renders', () => { + const { result, rerender } = renderHook(() => useOptionalMessagesOperations()); + const first = result.current; + rerender(); + expect(result.current).toBe(first); + }); +}); + +describe('useOptionalMessagesConversation', () => { + it('returns undefined fields when rendered outside MessagesViewProvider', () => { + const { result } = renderHook(() => useOptionalMessagesConversation()); + expect(result.current.conversation).toBeUndefined(); + expect(result.current.conversationId).toBeUndefined(); + }); + + it('returns stable references across re-renders', () => { + const { result, rerender } = renderHook(() => useOptionalMessagesConversation()); + const first = result.current; + rerender(); + expect(result.current).toBe(first); + }); +}); diff --git a/client/src/common/types.ts b/client/src/common/types.ts index 85044bb2bc..6ca408685f 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -355,6 +355,28 @@ export type TOptions = { export type TAskFunction = (props: TAskProps, options?: TOptions) => void; +/** + * Stable context object passed from non-memo'd wrapper components (Message, MessageContent) + * to memo'd inner components (MessageRender, ContentRender) via props. 
+ * + * This avoids subscribing to ChatContext inside memo'd components, which would bypass React.memo + * and cause unnecessary re-renders when `isSubmitting` changes during streaming. + * + * The `isSubmitting` property should use a getter backed by a ref so it returns the current + * value at call-time (for callback guards) without being a reactive dependency. + */ +export type TMessageChatContext = { + ask: (...args: Parameters) => void; + index: number; + regenerate: (message: t.TMessage, options?: { addedConvo?: t.TConversation | null }) => void; + conversation: t.TConversation | null; + latestMessageId: string | undefined; + latestMessageDepth: number | undefined; + handleContinue: (e: React.MouseEvent) => void; + /** Should be a getter backed by a ref — reads current value without triggering re-renders */ + readonly isSubmitting: boolean; +}; + export type TMessageProps = { conversation?: t.TConversation | null; messageId?: string | null; diff --git a/client/src/components/Chat/Input/AudioRecorder.tsx b/client/src/components/Chat/Input/AudioRecorder.tsx index dbf2c29d83..e9e19d0904 100644 --- a/client/src/components/Chat/Input/AudioRecorder.tsx +++ b/client/src/components/Chat/Input/AudioRecorder.tsx @@ -1,4 +1,4 @@ -import { useCallback, useRef } from 'react'; +import { memo, useCallback, useRef } from 'react'; import { MicOff } from 'lucide-react'; import { useToastContext, TooltipAnchor, ListeningIcon, Spinner } from '@librechat/client'; import { useLocalize, useSpeechToText, useGetAudioSettings } from '~/hooks'; @@ -7,7 +7,7 @@ import { globalAudioId } from '~/common'; import { cn } from '~/utils'; const isExternalSTT = (speechToTextEndpoint: string) => speechToTextEndpoint === 'external'; -export default function AudioRecorder({ +export default memo(function AudioRecorder({ disabled, ask, methods, @@ -26,10 +26,12 @@ export default function AudioRecorder({ const { speechToTextEndpoint } = useGetAudioSettings(); const existingTextRef = useRef(''); + 
const isSubmittingRef = useRef(isSubmitting); + isSubmittingRef.current = isSubmitting; const onTranscriptionComplete = useCallback( (text: string) => { - if (isSubmitting) { + if (isSubmittingRef.current) { showToast({ message: localize('com_ui_speech_while_submitting'), status: 'error', @@ -52,7 +54,7 @@ export default function AudioRecorder({ existingTextRef.current = ''; } }, - [ask, reset, showToast, localize, isSubmitting, speechToTextEndpoint], + [ask, reset, showToast, localize, speechToTextEndpoint], ); const setText = useCallback( @@ -125,4 +127,4 @@ export default function AudioRecorder({ } /> ); -} +}); diff --git a/client/src/components/Chat/Input/ChatForm.tsx b/client/src/components/Chat/Input/ChatForm.tsx index fed355dcb3..9e0ad7f382 100644 --- a/client/src/components/Chat/Input/ChatForm.tsx +++ b/client/src/components/Chat/Input/ChatForm.tsx @@ -3,6 +3,8 @@ import { useWatch } from 'react-hook-form'; import { TextareaAutosize } from '@librechat/client'; import { useRecoilState, useRecoilValue } from 'recoil'; import { Constants, isAssistantsEndpoint, isAgentsEndpoint } from 'librechat-data-provider'; +import type { TConversation } from 'librechat-data-provider'; +import type { ExtendedFile, FileSetter, ConvoGenerator } from '~/common'; import { useChatContext, useChatFormContext, @@ -35,7 +37,30 @@ import BadgeRow from './BadgeRow'; import Mention from './Mention'; import store from '~/store'; -const ChatForm = memo(({ index = 0 }: { index?: number }) => { +interface ChatFormProps { + index: number; + /** From ChatContext — individual values so memo can compare them */ + files: Map; + setFiles: FileSetter; + conversation: TConversation | null; + isSubmitting: boolean; + filesLoading: boolean; + setFilesLoading: React.Dispatch>; + newConversation: ConvoGenerator; + handleStopGenerating: (e: React.MouseEvent) => void; +} + +const ChatForm = memo(function ChatForm({ + index, + files, + setFiles, + conversation, + isSubmitting, + filesLoading, + 
setFilesLoading, + newConversation, + handleStopGenerating, +}: ChatFormProps) { const submitButtonRef = useRef(null); const textAreaRef = useRef(null); useFocusChatEffect(textAreaRef); @@ -65,15 +90,6 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => { const { requiresKey } = useRequiresKey(); const methods = useChatFormContext(); - const { - files, - setFiles, - conversation, - isSubmitting, - filesLoading, - newConversation, - handleStopGenerating, - } = useChatContext(); const { generateConversation, conversation: addedConvo, @@ -120,6 +136,15 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => { } }, [isCollapsed]); + const handleTextareaFocus = useCallback(() => { + handleFocusOrClick(); + setIsTextAreaFocused(true); + }, [handleFocusOrClick]); + + const handleTextareaBlur = useCallback(() => { + setIsTextAreaFocused(false); + }, []); + useAutoSave({ files, setFiles, @@ -253,7 +278,12 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => { handleSaveBadges={handleSaveBadges} setBadges={setBadges} /> - + {endpoint && (

{ tabIndex={0} data-testid="text-input" rows={1} - onFocus={() => { - handleFocusOrClick(); - setIsTextAreaFocused(true); - }} - onBlur={setIsTextAreaFocused.bind(null, false)} + onFocus={handleTextareaFocus} + onBlur={handleTextareaBlur} aria-label={localize('com_ui_message_input')} onClick={handleFocusOrClick} style={{ height: 44, overflowY: 'auto' }} @@ -315,7 +342,13 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => { )} >
- +
{ ); }); +ChatForm.displayName = 'ChatForm'; -export default ChatForm; +/** + * Wrapper that subscribes to ChatContext and passes stable individual values + * to the memo'd ChatForm. This prevents ChatForm from re-rendering on every + * streaming chunk — it only re-renders when the specific values it uses change. + */ +function ChatFormWrapper({ index = 0 }: { index?: number }) { + const { + files, + setFiles, + conversation, + isSubmitting, + filesLoading, + setFilesLoading, + newConversation, + handleStopGenerating, + } = useChatContext(); + + /** + * Stabilize conversation reference: only update when rendering-relevant fields change, + * not on every metadata update (e.g., title generation during streaming). + */ + const hasMessages = (conversation?.messages?.length ?? 0) > 0; + const stableConversation = useMemo( + () => conversation, + // eslint-disable-next-line react-hooks/exhaustive-deps + [ + conversation?.conversationId, + conversation?.endpoint, + conversation?.endpointType, + conversation?.agent_id, + conversation?.assistant_id, + conversation?.spec, + conversation?.useResponsesApi, + conversation?.model, + hasMessages, + ], + ); + + /** Stabilize function refs so they never trigger ChatForm re-renders */ + const handleStopRef = useRef(handleStopGenerating); + handleStopRef.current = handleStopGenerating; + const stableHandleStop = useCallback( + (e: React.MouseEvent) => handleStopRef.current(e), + [], + ); + + const newConvoRef = useRef(newConversation); + newConvoRef.current = newConversation; + const stableNewConversation: ConvoGenerator = useCallback( + (...args: Parameters): ReturnType => + newConvoRef.current(...args), + [], + ); + + return ( + + ); +} + +ChatFormWrapper.displayName = 'ChatFormWrapper'; + +export default ChatFormWrapper; diff --git a/client/src/components/Chat/Input/CollapseChat.tsx b/client/src/components/Chat/Input/CollapseChat.tsx index ea099ed69b..7efe52dc8d 100644 --- a/client/src/components/Chat/Input/CollapseChat.tsx +++ 
b/client/src/components/Chat/Input/CollapseChat.tsx @@ -52,4 +52,4 @@ const CollapseChat = ({ ); }; -export default CollapseChat; +export default React.memo(CollapseChat); diff --git a/client/src/components/Chat/Input/Files/AttachFile.tsx b/client/src/components/Chat/Input/Files/AttachFile.tsx index 38a3fa8c6f..098fa2c4c3 100644 --- a/client/src/components/Chat/Input/Files/AttachFile.tsx +++ b/client/src/components/Chat/Input/Files/AttachFile.tsx @@ -1,14 +1,33 @@ import React, { useRef } from 'react'; import { FileUpload, TooltipAnchor, AttachmentIcon } from '@librechat/client'; -import { useLocalize, useFileHandling } from '~/hooks'; +import type { TConversation } from 'librechat-data-provider'; +import type { ExtendedFile, FileSetter } from '~/common'; +import { useFileHandlingNoChatContext, useLocalize } from '~/hooks'; import { cn } from '~/utils'; -const AttachFile = ({ disabled }: { disabled?: boolean | null }) => { +const AttachFile = ({ + disabled, + files, + setFiles, + setFilesLoading, + conversation, +}: { + disabled?: boolean | null; + files: Map; + setFiles: FileSetter; + setFilesLoading: React.Dispatch>; + conversation: TConversation | null; +}) => { const localize = useLocalize(); const inputRef = useRef(null); const isUploadDisabled = disabled ?? 
false; - const { handleFileChange } = useFileHandling(); + const { handleFileChange } = useFileHandlingNoChatContext(undefined, { + files, + setFiles, + setFilesLoading, + conversation, + }); return ( diff --git a/client/src/components/Chat/Input/Files/AttachFileChat.tsx b/client/src/components/Chat/Input/Files/AttachFileChat.tsx index 2f954d01d5..7eb9b0c474 100644 --- a/client/src/components/Chat/Input/Files/AttachFileChat.tsx +++ b/client/src/components/Chat/Input/Files/AttachFileChat.tsx @@ -9,6 +9,7 @@ import { getEndpointFileConfig, } from 'librechat-data-provider'; import type { TConversation } from 'librechat-data-provider'; +import type { ExtendedFile, FileSetter } from '~/common'; import { useGetFileConfig, useGetEndpointsQuery, useGetAgentByIdQuery } from '~/data-provider'; import { useAgentsMapContext } from '~/Providers'; import AttachFileMenu from './AttachFileMenu'; @@ -17,9 +18,15 @@ import AttachFile from './AttachFile'; function AttachFileChat({ disableInputs, conversation, + files, + setFiles, + setFilesLoading, }: { disableInputs: boolean; conversation: TConversation | null; + files: Map; + setFiles: FileSetter; + setFilesLoading: React.Dispatch>; }) { const conversationId = conversation?.conversationId ?? Constants.NEW_CONVO; const { endpoint } = conversation ?? 
{ endpoint: null }; @@ -90,7 +97,15 @@ function AttachFileChat({ ); if (isAssistants && endpointSupportsFiles && !isUploadDisabled) { - return ; + return ( + + ); } else if ((isAgents || endpointSupportsFiles) && !isUploadDisabled) { return ( ); } diff --git a/client/src/components/Chat/Input/Files/AttachFileMenu.tsx b/client/src/components/Chat/Input/Files/AttachFileMenu.tsx index 62072e49e5..181d219c08 100644 --- a/client/src/components/Chat/Input/Files/AttachFileMenu.tsx +++ b/client/src/components/Chat/Input/Files/AttachFileMenu.tsx @@ -23,15 +23,16 @@ import { bedrockDocumentExtensions, isDocumentSupportedProvider, } from 'librechat-data-provider'; -import type { EndpointFileConfig } from 'librechat-data-provider'; +import type { EndpointFileConfig, TConversation } from 'librechat-data-provider'; +import type { ExtendedFile, FileSetter } from '~/common'; import { useAgentToolPermissions, useAgentCapabilities, useGetAgentsConfig, - useFileHandling, + useFileHandlingNoChatContext, useLocalize, } from '~/hooks'; -import useSharePointFileHandling from '~/hooks/Files/useSharePointFileHandling'; +import { useSharePointFileHandlingNoChatContext } from '~/hooks/Files/useSharePointFileHandling'; import { SharePointPickerDialog } from '~/components/SharePoint'; import { useGetStartupConfig } from '~/data-provider'; import { ephemeralAgentByConvoId } from '~/store'; @@ -53,6 +54,10 @@ interface AttachFileMenuProps { endpointType?: EModelEndpoint | string; endpointFileConfig?: EndpointFileConfig; useResponsesApi?: boolean; + files: Map; + setFiles: FileSetter; + setFilesLoading: React.Dispatch>; + conversation: TConversation | null; } const AttachFileMenu = ({ @@ -63,6 +68,10 @@ const AttachFileMenu = ({ conversationId, endpointFileConfig, useResponsesApi, + files, + setFiles, + setFilesLoading, + conversation, }: AttachFileMenuProps) => { const localize = useLocalize(); const isUploadDisabled = disabled ?? 
false; @@ -72,10 +81,17 @@ const AttachFileMenu = ({ ephemeralAgentByConvoId(conversationId), ); const [toolResource, setToolResource] = useState(); - const { handleFileChange } = useFileHandling(); - const { handleSharePointFiles, isProcessing, downloadProgress } = useSharePointFileHandling({ - toolResource, + const { handleFileChange } = useFileHandlingNoChatContext(undefined, { + files, + setFiles, + setFilesLoading, + conversation, }); + const { handleSharePointFiles, isProcessing, downloadProgress } = + useSharePointFileHandlingNoChatContext( + { toolResource }, + { files, setFiles, setFilesLoading, conversation }, + ); const { agentsConfig } = useGetAgentsConfig(); const { data: startupConfig } = useGetStartupConfig(); diff --git a/client/src/components/Chat/Input/Files/FileFormChat.tsx b/client/src/components/Chat/Input/Files/FileFormChat.tsx index 3c01f2d642..4d37938c5d 100644 --- a/client/src/components/Chat/Input/Files/FileFormChat.tsx +++ b/client/src/components/Chat/Input/Files/FileFormChat.tsx @@ -1,16 +1,30 @@ import { memo } from 'react'; import { useRecoilValue } from 'recoil'; import type { TConversation } from 'librechat-data-provider'; -import { useChatContext } from '~/Providers'; -import { useFileHandling } from '~/hooks'; +import type { ExtendedFile, FileSetter } from '~/common'; +import { useFileHandlingNoChatContext } from '~/hooks'; import FileRow from './FileRow'; import store from '~/store'; -function FileFormChat({ conversation }: { conversation: TConversation | null }) { - const { files, setFiles, setFilesLoading } = useChatContext(); +function FileFormChat({ + conversation, + files, + setFiles, + setFilesLoading, +}: { + conversation: TConversation | null; + files: Map; + setFiles: FileSetter; + setFilesLoading: React.Dispatch>; +}) { const chatDirection = useRecoilValue(store.chatDirection).toLowerCase(); const { endpoint: _endpoint } = conversation ?? 
{ endpoint: null }; - const { abortUpload } = useFileHandling(); + const { abortUpload } = useFileHandlingNoChatContext(undefined, { + files, + setFiles, + setFilesLoading, + conversation, + }); const isRTL = chatDirection === 'rtl'; diff --git a/client/src/components/Chat/Input/Files/__tests__/AttachFileChat.spec.tsx b/client/src/components/Chat/Input/Files/__tests__/AttachFileChat.spec.tsx index cea55f5ce8..80f06a1b89 100644 --- a/client/src/components/Chat/Input/Files/__tests__/AttachFileChat.spec.tsx +++ b/client/src/components/Chat/Input/Files/__tests__/AttachFileChat.spec.tsx @@ -59,7 +59,13 @@ function renderComponent(conversation: Record | null, disableIn return render( - + {}} + setFilesLoading={() => {}} + /> , ); diff --git a/client/src/components/Chat/Input/Files/__tests__/AttachFileMenu.spec.tsx b/client/src/components/Chat/Input/Files/__tests__/AttachFileMenu.spec.tsx index cf08721207..c2710d4ef8 100644 --- a/client/src/components/Chat/Input/Files/__tests__/AttachFileMenu.spec.tsx +++ b/client/src/components/Chat/Input/Files/__tests__/AttachFileMenu.spec.tsx @@ -9,13 +9,14 @@ jest.mock('~/hooks', () => ({ useAgentToolPermissions: jest.fn(), useAgentCapabilities: jest.fn(), useGetAgentsConfig: jest.fn(), - useFileHandling: jest.fn(), + useFileHandlingNoChatContext: jest.fn(), useLocalize: jest.fn(), })); jest.mock('~/hooks/Files/useSharePointFileHandling', () => ({ __esModule: true, default: jest.fn(), + useSharePointFileHandlingNoChatContext: jest.fn(), })); jest.mock('~/data-provider', () => ({ @@ -52,6 +53,7 @@ jest.mock('@librechat/client', () => { ), AttachmentIcon: () => R.createElement('span', { 'data-testid': 'attachment-icon' }), SharePointIcon: () => R.createElement('span', { 'data-testid': 'sharepoint-icon' }), + useToastContext: () => ({ showToast: jest.fn() }), }; }); @@ -66,11 +68,14 @@ jest.mock('@ariakit/react', () => { const mockUseAgentToolPermissions = jest.requireMock('~/hooks').useAgentToolPermissions; const 
mockUseAgentCapabilities = jest.requireMock('~/hooks').useAgentCapabilities; const mockUseGetAgentsConfig = jest.requireMock('~/hooks').useGetAgentsConfig; -const mockUseFileHandling = jest.requireMock('~/hooks').useFileHandling; +const mockUseFileHandlingNoChatContext = jest.requireMock('~/hooks').useFileHandlingNoChatContext; const mockUseLocalize = jest.requireMock('~/hooks').useLocalize; const mockUseSharePointFileHandling = jest.requireMock( '~/hooks/Files/useSharePointFileHandling', ).default; +const mockUseSharePointFileHandlingNoChatContext = jest.requireMock( + '~/hooks/Files/useSharePointFileHandling', +).useSharePointFileHandlingNoChatContext; const mockUseGetStartupConfig = jest.requireMock('~/data-provider').useGetStartupConfig; const queryClient = new QueryClient({ defaultOptions: { queries: { retry: false } } }); @@ -92,12 +97,15 @@ function setupMocks(overrides: { provider?: string } = {}) { codeEnabled: false, }); mockUseGetAgentsConfig.mockReturnValue({ agentsConfig: {} }); - mockUseFileHandling.mockReturnValue({ handleFileChange: jest.fn() }); - mockUseSharePointFileHandling.mockReturnValue({ + mockUseFileHandlingNoChatContext.mockReturnValue({ handleFileChange: jest.fn() }); + const sharePointReturnValue = { handleSharePointFiles: jest.fn(), isProcessing: false, downloadProgress: 0, - }); + error: null, + }; + mockUseSharePointFileHandling.mockReturnValue(sharePointReturnValue); + mockUseSharePointFileHandlingNoChatContext.mockReturnValue(sharePointReturnValue); mockUseGetStartupConfig.mockReturnValue({ data: { sharePointFilePickerEnabled: false } }); mockUseAgentToolPermissions.mockReturnValue({ fileSearchAllowedByAgent: false, @@ -110,7 +118,14 @@ function renderMenu(props: Record = {}) { return render( - + {}} + setFilesLoading={() => {}} + conversation={null} + {...props} + /> , ); diff --git a/client/src/components/Chat/Input/StopButton.tsx b/client/src/components/Chat/Input/StopButton.tsx index 4a058777f1..fd94ba806c 100644 --- 
a/client/src/components/Chat/Input/StopButton.tsx +++ b/client/src/components/Chat/Input/StopButton.tsx @@ -1,8 +1,15 @@ +import { memo } from 'react'; import { TooltipAnchor } from '@librechat/client'; import { useLocalize } from '~/hooks'; import { cn } from '~/utils'; -export default function StopButton({ stop, setShowStopButton }) { +export default memo(function StopButton({ + stop, + setShowStopButton, +}: { + stop: (e: React.MouseEvent) => void; + setShowStopButton: (value: boolean) => void; +}) { const localize = useLocalize(); return ( @@ -34,4 +41,4 @@ export default function StopButton({ stop, setShowStopButton }) { } > ); -} +}); diff --git a/client/src/components/Chat/Input/TextareaHeader.tsx b/client/src/components/Chat/Input/TextareaHeader.tsx index 9e67252efe..06c1802585 100644 --- a/client/src/components/Chat/Input/TextareaHeader.tsx +++ b/client/src/components/Chat/Input/TextareaHeader.tsx @@ -1,8 +1,9 @@ +import { memo } from 'react'; import AddedConvo from './AddedConvo'; import type { TConversation } from 'librechat-data-provider'; import type { SetterOrUpdater } from 'recoil'; -export default function TextareaHeader({ +export default memo(function TextareaHeader({ addedConvo, setAddedConvo, }: { @@ -17,4 +18,4 @@ export default function TextareaHeader({
); -} +}); diff --git a/client/src/components/Chat/Messages/Content/ToolCall.tsx b/client/src/components/Chat/Messages/Content/ToolCall.tsx index 5abdd45f98..c7dd974577 100644 --- a/client/src/components/Chat/Messages/Content/ToolCall.tsx +++ b/client/src/components/Chat/Messages/Content/ToolCall.tsx @@ -49,19 +49,47 @@ export default function ToolCall({ } }, [autoExpand, hasOutput]); + const parsedAuthUrl = useMemo(() => { + if (!auth) { + return null; + } + try { + return new URL(auth); + } catch { + return null; + } + }, [auth]); + const { function_name, domain, isMCPToolCall, mcpServerName } = useMemo(() => { if (typeof name !== 'string') { return { function_name: '', domain: null, isMCPToolCall: false, mcpServerName: '' }; } if (name.includes(Constants.mcp_delimiter)) { - const [func, server] = name.split(Constants.mcp_delimiter); + const parts = name.split(Constants.mcp_delimiter); + const func = parts[0]; + const server = parts.slice(1).join(Constants.mcp_delimiter); + const displayName = func === 'oauth' ? server : func; return { - function_name: func || '', + function_name: displayName || '', domain: server && (server.replaceAll(actionDomainSeparator, '.') || null), isMCPToolCall: true, mcpServerName: server || '', }; } + + if (parsedAuthUrl) { + const redirectUri = parsedAuthUrl.searchParams.get('redirect_uri') || ''; + const mcpMatch = redirectUri.match(/\/api\/mcp\/([^/]+)\/oauth\/callback/); + if (mcpMatch?.[1]) { + return { + function_name: mcpMatch[1], + domain: null, + isMCPToolCall: true, + mcpServerName: mcpMatch[1], + }; + } + } + const [func, _domain] = name.includes(actionDelimiter) ? name.split(actionDelimiter) : [name, '']; @@ -71,25 +99,20 @@ export default function ToolCall({ isMCPToolCall: false, mcpServerName: '', }; - }, [name]); + }, [name, parsedAuthUrl]); const toolIconType = useMemo(() => getToolIconType(name), [name]); const mcpIconMap = useMCPIconMap(); const mcpIconUrl = isMCPToolCall ? 
mcpIconMap.get(mcpServerName) : undefined; const actionId = useMemo(() => { - if (isMCPToolCall || !auth) { + if (isMCPToolCall || !parsedAuthUrl) { return ''; } - try { - const url = new URL(auth); - const redirectUri = url.searchParams.get('redirect_uri') || ''; - const match = redirectUri.match(/\/api\/actions\/([^/]+)\/oauth\/callback/); - return match?.[1] || ''; - } catch { - return ''; - } - }, [auth, isMCPToolCall]); + const redirectUri = parsedAuthUrl.searchParams.get('redirect_uri') || ''; + const match = redirectUri.match(/\/api\/actions\/([^/]+)\/oauth\/callback/); + return match?.[1] || ''; + }, [parsedAuthUrl, isMCPToolCall]); const handleOAuthClick = useCallback(async () => { if (!auth) { @@ -132,21 +155,8 @@ export default function ToolCall({ ); const authDomain = useMemo(() => { - const authURL = auth ?? ''; - if (!authURL) { - return ''; - } - try { - const url = new URL(authURL); - return url.hostname; - } catch (e) { - logger.error( - 'client/src/components/Chat/Messages/Content/ToolCall.tsx - Failed to parse auth URL', - e, - ); - return ''; - } - }, [auth]); + return parsedAuthUrl?.hostname ?? 
''; + }, [parsedAuthUrl]); const progress = useProgress(initialProgress); const showCancelled = cancelled || (errorState && !output); diff --git a/client/src/components/Chat/Messages/Content/ToolCallInfo.tsx b/client/src/components/Chat/Messages/Content/ToolCallInfo.tsx index 59a564be4d..79ac78dbb2 100644 --- a/client/src/components/Chat/Messages/Content/ToolCallInfo.tsx +++ b/client/src/components/Chat/Messages/Content/ToolCallInfo.tsx @@ -3,11 +3,11 @@ import { ChevronDown } from 'lucide-react'; import { Tools } from 'librechat-data-provider'; import { UIResourceRenderer } from '@mcp-ui/client'; import type { TAttachment, UIResource } from 'librechat-data-provider'; +import { useOptionalMessagesOperations } from '~/Providers'; import { useLocalize, useExpandCollapse } from '~/hooks'; import UIResourceCarousel from './UIResourceCarousel'; -import { useMessagesOperations } from '~/Providers'; -import { OutputRenderer } from './ToolOutput'; import { handleUIAction, cn } from '~/utils'; +import { OutputRenderer } from './ToolOutput'; function isSimpleObject(obj: unknown): obj is Record { if (typeof obj !== 'object' || obj === null || Array.isArray(obj)) { @@ -102,7 +102,7 @@ export default function ToolCallInfo({ attachments?: TAttachment[]; }) { const localize = useLocalize(); - const { ask } = useMessagesOperations(); + const { ask } = useOptionalMessagesOperations(); const [showParams, setShowParams] = useState(false); const { style: paramsExpandStyle, ref: paramsExpandRef } = useExpandCollapse(showParams); diff --git a/client/src/components/Chat/Messages/Content/UIResourceCarousel.tsx b/client/src/components/Chat/Messages/Content/UIResourceCarousel.tsx index c0829e5ad9..4cafa643c6 100644 --- a/client/src/components/Chat/Messages/Content/UIResourceCarousel.tsx +++ b/client/src/components/Chat/Messages/Content/UIResourceCarousel.tsx @@ -1,7 +1,7 @@ import React, { useState } from 'react'; import { UIResourceRenderer } from '@mcp-ui/client'; import type { UIResource 
} from 'librechat-data-provider'; -import { useMessagesOperations } from '~/Providers'; +import { useOptionalMessagesOperations } from '~/Providers'; import { handleUIAction } from '~/utils'; interface UIResourceCarouselProps { @@ -13,7 +13,7 @@ const UIResourceCarousel: React.FC = React.memo(({ uiRe const [showRightArrow, setShowRightArrow] = useState(true); const [isContainerHovered, setIsContainerHovered] = useState(false); const scrollContainerRef = React.useRef(null); - const { ask } = useMessagesOperations(); + const { ask } = useOptionalMessagesOperations(); const handleScroll = React.useCallback(() => { if (!scrollContainerRef.current) return; diff --git a/client/src/components/Chat/Messages/Content/__tests__/Markdown.mcpui.test.tsx b/client/src/components/Chat/Messages/Content/__tests__/Markdown.mcpui.test.tsx index 6df66c9e15..6ca06056fa 100644 --- a/client/src/components/Chat/Messages/Content/__tests__/Markdown.mcpui.test.tsx +++ b/client/src/components/Chat/Messages/Content/__tests__/Markdown.mcpui.test.tsx @@ -3,7 +3,11 @@ import { render, screen } from '@testing-library/react'; import Markdown from '../Markdown'; import { RecoilRoot } from 'recoil'; import { UI_RESOURCE_MARKER } from '~/components/MCPUIResource/plugin'; -import { useMessageContext, useMessagesConversation, useMessagesOperations } from '~/Providers'; +import { + useMessageContext, + useOptionalMessagesConversation, + useOptionalMessagesOperations, +} from '~/Providers'; import { useGetMessagesByConvoId } from '~/data-provider'; import { useLocalize } from '~/hooks'; @@ -12,8 +16,8 @@ import { useLocalize } from '~/hooks'; jest.mock('~/Providers', () => ({ ...jest.requireActual('~/Providers'), useMessageContext: jest.fn(), - useMessagesConversation: jest.fn(), - useMessagesOperations: jest.fn(), + useOptionalMessagesConversation: jest.fn(), + useOptionalMessagesOperations: jest.fn(), })); jest.mock('~/data-provider'); jest.mock('~/hooks'); @@ -26,11 +30,11 @@ jest.mock('@mcp-ui/client', 
() => ({ })); const mockUseMessageContext = useMessageContext as jest.MockedFunction; -const mockUseMessagesConversation = useMessagesConversation as jest.MockedFunction< - typeof useMessagesConversation +const mockUseMessagesConversation = useOptionalMessagesConversation as jest.MockedFunction< + typeof useOptionalMessagesConversation >; -const mockUseMessagesOperations = useMessagesOperations as jest.MockedFunction< - typeof useMessagesOperations +const mockUseMessagesOperations = useOptionalMessagesOperations as jest.MockedFunction< + typeof useOptionalMessagesOperations >; const mockUseGetMessagesByConvoId = useGetMessagesByConvoId as jest.MockedFunction< typeof useGetMessagesByConvoId diff --git a/client/src/components/Chat/Messages/Content/__tests__/ToolCall.test.tsx b/client/src/components/Chat/Messages/Content/__tests__/ToolCall.test.tsx index 41356412f6..14b4b7e07a 100644 --- a/client/src/components/Chat/Messages/Content/__tests__/ToolCall.test.tsx +++ b/client/src/components/Chat/Messages/Content/__tests__/ToolCall.test.tsx @@ -1,6 +1,6 @@ import React from 'react'; import { RecoilRoot } from 'recoil'; -import { Tools } from 'librechat-data-provider'; +import { Tools, Constants } from 'librechat-data-provider'; import { render, screen, fireEvent } from '@testing-library/react'; import ToolCall from '../ToolCall'; @@ -53,9 +53,20 @@ jest.mock('../ToolCallInfo', () => ({ jest.mock('../ProgressText', () => ({ __esModule: true, - default: ({ onClick, inProgressText, finishedText, _error, _hasInput, _isExpanded }: any) => ( + default: ({ + onClick, + inProgressText, + finishedText, + subtitle, + }: { + onClick?: () => void; + inProgressText?: string; + finishedText?: string; + subtitle?: string; + }) => (
{finishedText || inProgressText} + {subtitle && {subtitle}}
), })); @@ -346,6 +357,141 @@ describe('ToolCall', () => { }); }); + describe('MCP OAuth detection', () => { + const d = Constants.mcp_delimiter; + + it('should detect MCP OAuth from delimiter in tool-call name', () => { + renderWithRecoil( + , + ); + const subtitle = screen.getByTestId('subtitle'); + expect(subtitle.textContent).toBe('via my-server'); + }); + + it('should preserve full server name when it contains the delimiter substring', () => { + renderWithRecoil( + , + ); + const subtitle = screen.getByTestId('subtitle'); + expect(subtitle.textContent).toBe(`via foo${d}bar`); + }); + + it('should display server name (not "oauth") as function_name for OAuth tool calls', () => { + renderWithRecoil( + , + ); + const progressText = screen.getByTestId('progress-text'); + expect(progressText.textContent).toContain('Completed my-server'); + expect(progressText.textContent).not.toContain('Completed oauth'); + }); + + it('should display server name even when auth is cleared (post-completion)', () => { + // After OAuth completes, createOAuthEnd re-emits the toolCall without auth. + // The display should still show the server name, not literal "oauth". 
+ renderWithRecoil( + , + ); + const progressText = screen.getByTestId('progress-text'); + expect(progressText.textContent).toContain('Completed my-server'); + expect(progressText.textContent).not.toContain('Completed oauth'); + }); + + it('should fallback to auth URL redirect_uri when name lacks delimiter', () => { + const authUrl = + 'https://oauth.example.com/authorize?redirect_uri=' + + encodeURIComponent('https://app.example.com/api/mcp/my-server/oauth/callback'); + renderWithRecoil( + , + ); + const subtitle = screen.getByTestId('subtitle'); + expect(subtitle.textContent).toBe('via my-server'); + }); + + it('should display server name (not raw tool-call ID) in fallback path finished text', () => { + const authUrl = + 'https://oauth.example.com/authorize?redirect_uri=' + + encodeURIComponent('https://app.example.com/api/mcp/my-server/oauth/callback'); + renderWithRecoil( + , + ); + const progressText = screen.getByTestId('progress-text'); + expect(progressText.textContent).toContain('Completed my-server'); + expect(progressText.textContent).not.toContain('bare_name'); + }); + + it('should show normalized server name when it contains _mcp_ after prefixing', () => { + // Server named oauth@mcp@server normalizes to oauth_mcp_server, + // gets prefixed to oauth_mcp_oauth_mcp_server. Client parses: + // func="oauth", server="oauth_mcp_server". Visually awkward but + // semantically correct — the normalized name IS oauth_mcp_server. 
+ renderWithRecoil( + , + ); + const subtitle = screen.getByTestId('subtitle'); + expect(subtitle.textContent).toBe(`via oauth${d}server`); + }); + + it('should not misidentify non-MCP action auth as MCP via fallback', () => { + const authUrl = + 'https://oauth.example.com/authorize?redirect_uri=' + + encodeURIComponent('https://app.example.com/api/actions/xyz/oauth/callback'); + renderWithRecoil( + , + ); + expect(screen.queryByTestId('subtitle')).not.toBeInTheDocument(); + }); + }); + describe('A11Y-04: screen reader status announcements', () => { it('includes sr-only aria-live region for status announcements', () => { renderWithRecoil( diff --git a/client/src/components/Chat/Messages/Content/__tests__/ToolCallInfo.test.tsx b/client/src/components/Chat/Messages/Content/__tests__/ToolCallInfo.test.tsx index 4a4d80ae8d..38b792ccae 100644 --- a/client/src/components/Chat/Messages/Content/__tests__/ToolCallInfo.test.tsx +++ b/client/src/components/Chat/Messages/Content/__tests__/ToolCallInfo.test.tsx @@ -25,7 +25,7 @@ jest.mock('~/hooks', () => ({ })); jest.mock('~/Providers', () => ({ - useMessagesOperations: () => ({ + useOptionalMessagesOperations: () => ({ ask: jest.fn(), }), })); diff --git a/client/src/components/Chat/Messages/Content/__tests__/UIResourceCarousel.test.tsx b/client/src/components/Chat/Messages/Content/__tests__/UIResourceCarousel.test.tsx index 6d208c2cf2..6e472e3f49 100644 --- a/client/src/components/Chat/Messages/Content/__tests__/UIResourceCarousel.test.tsx +++ b/client/src/components/Chat/Messages/Content/__tests__/UIResourceCarousel.test.tsx @@ -13,10 +13,10 @@ jest.mock('@mcp-ui/client', () => ({ ), })); -// Mock useMessagesOperations hook +// Mock useOptionalMessagesOperations hook const mockAsk = jest.fn(); jest.mock('~/Providers', () => ({ - useMessagesOperations: () => ({ + useOptionalMessagesOperations: () => ({ ask: mockAsk, }), })); diff --git a/client/src/components/Chat/Messages/Message.tsx 
b/client/src/components/Chat/Messages/Message.tsx index f9db38fdab..53aef812fc 100644 --- a/client/src/components/Chat/Messages/Message.tsx +++ b/client/src/components/Chat/Messages/Message.tsx @@ -1,5 +1,5 @@ import React from 'react'; -import { useMessageProcess } from '~/hooks'; +import { useMessageProcess, useMemoizedChatContext } from '~/hooks'; import type { TMessageProps } from '~/common'; import MessageRender from './ui/MessageRender'; import MultiMessage from './MultiMessage'; @@ -23,10 +23,11 @@ const MessageContainer = React.memo(function MessageContainer({ }); export default function Message(props: TMessageProps) { - const { conversation, handleScroll } = useMessageProcess({ + const { conversation, handleScroll, isSubmitting } = useMessageProcess({ message: props.message, }); const { message, currentEditId, setCurrentEditId } = props; + const { chatContext, effectiveIsSubmitting } = useMemoizedChatContext(message, isSubmitting); if (!message || typeof message !== 'object') { return null; @@ -38,7 +39,11 @@ export default function Message(props: TMessageProps) { <>
- +
; +/** + * Custom comparator for React.memo: compares `message` by key fields instead of reference + * because `buildTree` creates new message objects on every streaming update for ALL messages, + * even when only the latest message's text changed. + */ +function areMessageRenderPropsEqual(prev: MessageRenderProps, next: MessageRenderProps): boolean { + if (prev.isSubmitting !== next.isSubmitting) { + return false; + } + if (prev.chatContext !== next.chatContext) { + return false; + } + if (prev.siblingIdx !== next.siblingIdx) { + return false; + } + if (prev.siblingCount !== next.siblingCount) { + return false; + } + if (prev.currentEditId !== next.currentEditId) { + return false; + } + if (prev.setSiblingIdx !== next.setSiblingIdx) { + return false; + } + if (prev.setCurrentEditId !== next.setCurrentEditId) { + return false; + } + + const prevMsg = prev.message; + const nextMsg = next.message; + if (prevMsg === nextMsg) { + return true; + } + if (!prevMsg || !nextMsg) { + return prevMsg === nextMsg; + } + + return ( + prevMsg.messageId === nextMsg.messageId && + prevMsg.text === nextMsg.text && + prevMsg.error === nextMsg.error && + prevMsg.unfinished === nextMsg.unfinished && + prevMsg.depth === nextMsg.depth && + prevMsg.isCreatedByUser === nextMsg.isCreatedByUser && + (prevMsg.children?.length ?? 0) === (nextMsg.children?.length ?? 0) && + prevMsg.content === nextMsg.content && + prevMsg.model === nextMsg.model && + prevMsg.endpoint === nextMsg.endpoint && + prevMsg.iconURL === nextMsg.iconURL && + prevMsg.feedback?.rating === nextMsg.feedback?.rating && + (prevMsg.files?.length ?? 0) === (nextMsg.files?.length ?? 
0) + ); +} + const MessageRender = memo(function MessageRender({ message: msg, siblingIdx, @@ -31,6 +92,7 @@ const MessageRender = memo(function MessageRender({ currentEditId, setCurrentEditId, isSubmitting = false, + chatContext, }: MessageRenderProps) { const localize = useLocalize(); const { @@ -52,6 +114,7 @@ const MessageRender = memo(function MessageRender({ message: msg, currentEditId, setCurrentEditId, + chatContext, }); const fontSize = useAtomValue(fontSizeAtom); const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace); @@ -63,8 +126,6 @@ const MessageRender = memo(function MessageRender({ [hasNoChildren, msg?.depth, latestMessageDepth], ); const isLatestMessage = msg?.messageId === latestMessageId; - /** Only pass isSubmitting to the latest message to prevent unnecessary re-renders */ - const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false; const iconData: TMessageIcon = useMemo( () => ({ @@ -92,10 +153,10 @@ const MessageRender = memo(function MessageRender({ messageId, isLatestMessage, isExpanded: false as const, - isSubmitting: effectiveIsSubmitting, + isSubmitting, conversationId: conversation?.conversationId, }), - [messageId, conversation?.conversationId, effectiveIsSubmitting, isLatestMessage], + [messageId, conversation?.conversationId, isSubmitting, isLatestMessage], ); if (!msg) { @@ -165,7 +226,7 @@ const MessageRender = memo(function MessageRender({ message={msg} enterEdit={enterEdit} error={!!(msg.error ?? false)} - isSubmitting={effectiveIsSubmitting} + isSubmitting={isSubmitting} unfinished={msg.unfinished ?? false} isCreatedByUser={msg.isCreatedByUser ?? true} siblingIdx={siblingIdx ?? 0} @@ -173,7 +234,7 @@ const MessageRender = memo(function MessageRender({ />
- {hasNoChildren && effectiveIsSubmitting ? ( + {hasNoChildren && isSubmitting ? ( ) : ( @@ -187,7 +248,7 @@ const MessageRender = memo(function MessageRender({ isEditing={edit} message={msg} enterEdit={enterEdit} - isSubmitting={isSubmitting} + isSubmitting={chatContext.isSubmitting} conversation={conversation ?? null} regenerate={handleRegenerateMessage} copyToClipboard={copyToClipboard} @@ -202,7 +263,7 @@ const MessageRender = memo(function MessageRender({ ); -}); +}, areMessageRenderPropsEqual); MessageRender.displayName = 'MessageRender'; export default MessageRender; diff --git a/client/src/components/MCPUIResource/MCPUIResource.tsx b/client/src/components/MCPUIResource/MCPUIResource.tsx index ddf65c4388..692db889c9 100644 --- a/client/src/components/MCPUIResource/MCPUIResource.tsx +++ b/client/src/components/MCPUIResource/MCPUIResource.tsx @@ -1,8 +1,8 @@ import React from 'react'; import { UIResourceRenderer } from '@mcp-ui/client'; -import { handleUIAction } from '~/utils'; +import { useOptionalMessagesConversation, useOptionalMessagesOperations } from '~/Providers'; import { useConversationUIResources } from '~/hooks/Messages/useConversationUIResources'; -import { useMessagesConversation, useMessagesOperations } from '~/Providers'; +import { handleUIAction } from '~/utils'; import { useLocalize } from '~/hooks'; interface MCPUIResourceProps { @@ -13,19 +13,14 @@ interface MCPUIResourceProps { }; } -/** - * Component that renders an MCP UI resource based on its resource ID. - * Works in both main app and share view. - */ +/** Renders an MCP UI resource based on its resource ID. Works in chat, share, and search views. 
*/ export function MCPUIResource(props: MCPUIResourceProps) { const { resourceId } = props.node.properties; const localize = useLocalize(); - const { ask } = useMessagesOperations(); - const { conversation } = useMessagesConversation(); + const { ask } = useOptionalMessagesOperations(); + const { conversationId } = useOptionalMessagesConversation(); - const conversationResourceMap = useConversationUIResources( - conversation?.conversationId ?? undefined, - ); + const conversationResourceMap = useConversationUIResources(conversationId ?? undefined); const uiResource = conversationResourceMap.get(resourceId ?? ''); diff --git a/client/src/components/MCPUIResource/MCPUIResourceCarousel.tsx b/client/src/components/MCPUIResource/MCPUIResourceCarousel.tsx index cf32318491..ba81a2f153 100644 --- a/client/src/components/MCPUIResource/MCPUIResourceCarousel.tsx +++ b/client/src/components/MCPUIResource/MCPUIResourceCarousel.tsx @@ -1,8 +1,8 @@ import React, { useMemo } from 'react'; -import { useConversationUIResources } from '~/hooks/Messages/useConversationUIResources'; -import { useMessagesConversation } from '~/Providers'; -import UIResourceCarousel from '../Chat/Messages/Content/UIResourceCarousel'; import type { UIResource } from 'librechat-data-provider'; +import { useConversationUIResources } from '~/hooks/Messages/useConversationUIResources'; +import UIResourceCarousel from '../Chat/Messages/Content/UIResourceCarousel'; +import { useOptionalMessagesConversation } from '~/Providers'; interface MCPUIResourceCarouselProps { node: { @@ -12,16 +12,11 @@ interface MCPUIResourceCarouselProps { }; } -/** - * Component that renders multiple MCP UI resources in a carousel. - * Works in both main app and share view. - */ +/** Renders multiple MCP UI resources in a carousel. Works in chat, share, and search views. 
*/ export function MCPUIResourceCarousel(props: MCPUIResourceCarouselProps) { - const { conversation } = useMessagesConversation(); + const { conversationId } = useOptionalMessagesConversation(); - const conversationResourceMap = useConversationUIResources( - conversation?.conversationId ?? undefined, - ); + const conversationResourceMap = useConversationUIResources(conversationId ?? undefined); const uiResources = useMemo(() => { const { resourceIds = [] } = props.node.properties; diff --git a/client/src/components/MCPUIResource/__tests__/MCPUIResource.test.tsx b/client/src/components/MCPUIResource/__tests__/MCPUIResource.test.tsx index 53896bb6fe..c37b6d5d51 100644 --- a/client/src/components/MCPUIResource/__tests__/MCPUIResource.test.tsx +++ b/client/src/components/MCPUIResource/__tests__/MCPUIResource.test.tsx @@ -2,7 +2,11 @@ import React from 'react'; import { render, screen } from '@testing-library/react'; import { RecoilRoot } from 'recoil'; import { MCPUIResource } from '../MCPUIResource'; -import { useMessageContext, useMessagesConversation, useMessagesOperations } from '~/Providers'; +import { + useMessageContext, + useOptionalMessagesConversation, + useOptionalMessagesOperations, +} from '~/Providers'; import { useLocalize } from '~/hooks'; import { handleUIAction } from '~/utils'; @@ -22,11 +26,11 @@ jest.mock('@mcp-ui/client', () => ({ })); const mockUseMessageContext = useMessageContext as jest.MockedFunction; -const mockUseMessagesConversation = useMessagesConversation as jest.MockedFunction< - typeof useMessagesConversation +const mockUseMessagesConversation = useOptionalMessagesConversation as jest.MockedFunction< + typeof useOptionalMessagesConversation >; -const mockUseMessagesOperations = useMessagesOperations as jest.MockedFunction< - typeof useMessagesOperations +const mockUseMessagesOperations = useOptionalMessagesOperations as jest.MockedFunction< + typeof useOptionalMessagesOperations >; const mockUseLocalize = useLocalize as 
jest.MockedFunction; const mockHandleUIAction = handleUIAction as jest.MockedFunction; diff --git a/client/src/components/MCPUIResource/__tests__/MCPUIResourceCarousel.test.tsx b/client/src/components/MCPUIResource/__tests__/MCPUIResourceCarousel.test.tsx index a9f7962ab0..9a5ca934a0 100644 --- a/client/src/components/MCPUIResource/__tests__/MCPUIResourceCarousel.test.tsx +++ b/client/src/components/MCPUIResource/__tests__/MCPUIResourceCarousel.test.tsx @@ -2,7 +2,11 @@ import React from 'react'; import { render, screen } from '@testing-library/react'; import { RecoilRoot } from 'recoil'; import { MCPUIResourceCarousel } from '../MCPUIResourceCarousel'; -import { useMessageContext, useMessagesConversation, useMessagesOperations } from '~/Providers'; +import { + useMessageContext, + useOptionalMessagesConversation, + useOptionalMessagesOperations, +} from '~/Providers'; // Mock dependencies jest.mock('~/Providers'); @@ -19,11 +23,11 @@ jest.mock('../../Chat/Messages/Content/UIResourceCarousel', () => ({ })); const mockUseMessageContext = useMessageContext as jest.MockedFunction; -const mockUseMessagesConversation = useMessagesConversation as jest.MockedFunction< - typeof useMessagesConversation +const mockUseMessagesConversation = useOptionalMessagesConversation as jest.MockedFunction< + typeof useOptionalMessagesConversation >; -const mockUseMessagesOperations = useMessagesOperations as jest.MockedFunction< - typeof useMessagesOperations +const mockUseMessagesOperations = useOptionalMessagesOperations as jest.MockedFunction< + typeof useOptionalMessagesOperations >; describe('MCPUIResourceCarousel', () => { diff --git a/client/src/components/Messages/ContentRender.tsx b/client/src/components/Messages/ContentRender.tsx index 6b3f05ce5d..4ba8db36f8 100644 --- a/client/src/components/Messages/ContentRender.tsx +++ b/client/src/components/Messages/ContentRender.tsx @@ -2,7 +2,7 @@ import { useCallback, useMemo, memo } from 'react'; import { useAtomValue } from 'jotai'; 
import { useRecoilValue } from 'recoil'; import type { TMessage, TMessageContentParts } from 'librechat-data-provider'; -import type { TMessageProps, TMessageIcon } from '~/common'; +import type { TMessageProps, TMessageIcon, TMessageChatContext } from '~/common'; import { useAttachments, useLocalize, useMessageActions, useContentMetadata } from '~/hooks'; import { cn, getHeaderPrefixForScreenReader, getMessageAriaLabel } from '~/utils'; import ContentParts from '~/components/Chat/Messages/Content/ContentParts'; @@ -16,12 +16,72 @@ import store from '~/store'; type ContentRenderProps = { message?: TMessage; + /** + * Effective isSubmitting: false for non-latest messages, real value for latest. + * Computed by the wrapper (MessageContent.tsx) so this memo'd component only re-renders + * when the value actually matters. + */ isSubmitting?: boolean; + /** Stable context object from wrapper — avoids ChatContext subscription inside memo */ + chatContext: TMessageChatContext; } & Pick< TMessageProps, 'currentEditId' | 'setCurrentEditId' | 'siblingIdx' | 'setSiblingIdx' | 'siblingCount' >; +/** + * Custom comparator for React.memo: compares `message` by key fields instead of reference + * because `buildTree` creates new message objects on every streaming update for ALL messages. 
+ */ +function areContentRenderPropsEqual(prev: ContentRenderProps, next: ContentRenderProps): boolean { + if (prev.isSubmitting !== next.isSubmitting) { + return false; + } + if (prev.chatContext !== next.chatContext) { + return false; + } + if (prev.siblingIdx !== next.siblingIdx) { + return false; + } + if (prev.siblingCount !== next.siblingCount) { + return false; + } + if (prev.currentEditId !== next.currentEditId) { + return false; + } + if (prev.setSiblingIdx !== next.setSiblingIdx) { + return false; + } + if (prev.setCurrentEditId !== next.setCurrentEditId) { + return false; + } + + const prevMsg = prev.message; + const nextMsg = next.message; + if (prevMsg === nextMsg) { + return true; + } + if (!prevMsg || !nextMsg) { + return prevMsg === nextMsg; + } + + return ( + prevMsg.messageId === nextMsg.messageId && + prevMsg.text === nextMsg.text && + prevMsg.error === nextMsg.error && + prevMsg.unfinished === nextMsg.unfinished && + prevMsg.depth === nextMsg.depth && + prevMsg.isCreatedByUser === nextMsg.isCreatedByUser && + (prevMsg.children?.length ?? 0) === (nextMsg.children?.length ?? 0) && + prevMsg.content === nextMsg.content && + prevMsg.model === nextMsg.model && + prevMsg.endpoint === nextMsg.endpoint && + prevMsg.iconURL === nextMsg.iconURL && + prevMsg.feedback?.rating === nextMsg.feedback?.rating && + (prevMsg.attachments?.length ?? 0) === (nextMsg.attachments?.length ?? 
0) + ); +} + const ContentRender = memo(function ContentRender({ message: msg, siblingIdx, @@ -30,6 +90,7 @@ const ContentRender = memo(function ContentRender({ currentEditId, setCurrentEditId, isSubmitting = false, + chatContext, }: ContentRenderProps) { const localize = useLocalize(); const { attachments, searchResults } = useAttachments({ @@ -55,6 +116,7 @@ const ContentRender = memo(function ContentRender({ searchResults, currentEditId, setCurrentEditId, + chatContext, }); const fontSize = useAtomValue(fontSizeAtom); const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace); @@ -66,8 +128,6 @@ const ContentRender = memo(function ContentRender({ ); const hasNoChildren = !(msg?.children?.length ?? 0); const isLatestMessage = msg?.messageId === latestMessageId; - /** Only pass isSubmitting to the latest message to prevent unnecessary re-renders */ - const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false; const iconData: TMessageIcon = useMemo( () => ({ @@ -158,13 +218,13 @@ const ContentRender = memo(function ContentRender({ searchResults={searchResults} setSiblingIdx={setSiblingIdx} isLatestMessage={isLatestMessage} - isSubmitting={effectiveIsSubmitting} + isSubmitting={isSubmitting} isCreatedByUser={msg.isCreatedByUser} conversationId={conversation?.conversationId} content={msg.content as Array} /> - {hasNoChildren && effectiveIsSubmitting ? ( + {hasNoChildren && isSubmitting ? ( ) : ( @@ -178,7 +238,7 @@ const ContentRender = memo(function ContentRender({ message={msg} isEditing={edit} enterEdit={enterEdit} - isSubmitting={isSubmitting} + isSubmitting={chatContext.isSubmitting} conversation={conversation ?? 
null} regenerate={handleRegenerateMessage} copyToClipboard={copyToClipboard} @@ -193,7 +253,7 @@ const ContentRender = memo(function ContentRender({ ); -}); +}, areContentRenderPropsEqual); ContentRender.displayName = 'ContentRender'; export default ContentRender; diff --git a/client/src/components/Messages/MessageContent.tsx b/client/src/components/Messages/MessageContent.tsx index 0e53b1c840..977e397022 100644 --- a/client/src/components/Messages/MessageContent.tsx +++ b/client/src/components/Messages/MessageContent.tsx @@ -1,5 +1,5 @@ import React from 'react'; -import { useMessageProcess } from '~/hooks'; +import { useMessageProcess, useMemoizedChatContext } from '~/hooks'; import type { TMessageProps } from '~/common'; import MultiMessage from '~/components/Chat/Messages/MultiMessage'; @@ -28,6 +28,7 @@ export default function MessageContent(props: TMessageProps) { message: props.message, }); const { message, currentEditId, setCurrentEditId } = props; + const { chatContext, effectiveIsSubmitting } = useMemoizedChatContext(message, isSubmitting); if (!message || typeof message !== 'object') { return null; @@ -39,7 +40,11 @@ export default function MessageContent(props: TMessageProps) { <>
- +
({ + conversation, + setConversation, + generateConversation, + }), + [conversation, setConversation, generateConversation], + ); } diff --git a/client/src/hooks/Files/index.ts b/client/src/hooks/Files/index.ts index df86c02a96..499572f0e0 100644 --- a/client/src/hooks/Files/index.ts +++ b/client/src/hooks/Files/index.ts @@ -1,6 +1,6 @@ export { default as useDeleteFilesFromTable } from './useDeleteFilesFromTable'; export { default as useSetFilesToDelete } from './useSetFilesToDelete'; -export { default as useFileHandling } from './useFileHandling'; +export { default as useFileHandling, useFileHandlingNoChatContext } from './useFileHandling'; export { default as useFileDeletion } from './useFileDeletion'; export { default as useUpdateFiles } from './useUpdateFiles'; export { default as useDragHelpers } from './useDragHelpers'; diff --git a/client/src/hooks/Messages/index.ts b/client/src/hooks/Messages/index.ts index a78a1ef553..439b7e152e 100644 --- a/client/src/hooks/Messages/index.ts +++ b/client/src/hooks/Messages/index.ts @@ -5,6 +5,7 @@ export { default as useSubmitMessage } from './useSubmitMessage'; export type { ContentMetadataResult } from './useContentMetadata'; export { default as useExpandCollapse } from './useExpandCollapse'; export { default as useMessageActions } from './useMessageActions'; +export { default as useMemoizedChatContext } from './useMemoizedChatContext'; export { default as useMessageProcess } from './useMessageProcess'; export { default as useMessageHelpers } from './useMessageHelpers'; export { default as useCopyToClipboard } from './useCopyToClipboard'; diff --git a/client/src/hooks/Messages/useConversationUIResources.ts b/client/src/hooks/Messages/useConversationUIResources.ts index 2333f64e5f..28e9aa035a 100644 --- a/client/src/hooks/Messages/useConversationUIResources.ts +++ b/client/src/hooks/Messages/useConversationUIResources.ts @@ -2,7 +2,7 @@ import { useMemo } from 'react'; import { useRecoilValue } from 'recoil'; import { 
Tools } from 'librechat-data-provider'; import type { TAttachment, UIResource } from 'librechat-data-provider'; -import { useMessagesOperations } from '~/Providers'; +import { useOptionalMessagesOperations } from '~/Providers'; import store from '~/store'; /** @@ -16,7 +16,7 @@ import store from '~/store'; export function useConversationUIResources( conversationId: string | undefined, ): Map { - const { getMessages } = useMessagesOperations(); + const { getMessages } = useOptionalMessagesOperations(); const conversationAttachmentsMap = useRecoilValue( store.conversationAttachmentsSelector(conversationId), diff --git a/client/src/hooks/Messages/useMemoizedChatContext.ts b/client/src/hooks/Messages/useMemoizedChatContext.ts new file mode 100644 index 0000000000..aa35372a8e --- /dev/null +++ b/client/src/hooks/Messages/useMemoizedChatContext.ts @@ -0,0 +1,80 @@ +import { useRef, useMemo } from 'react'; +import type { TMessage } from 'librechat-data-provider'; +import type { TMessageChatContext } from '~/common/types'; +import { useChatContext } from '~/Providers'; + +/** + * Creates a stable `TMessageChatContext` object for memo'd message components. 
+ * + * Subscribes to `useChatContext()` internally (intended to be called from non-memo'd + * wrapper components like `Message` and `MessageContent`), then produces: + * - A `chatContext` object that stays referentially stable during streaming + * (uses a getter for `isSubmitting` backed by a ref) + * - A stable `conversation` reference that only updates when rendering-relevant fields change + * - An `effectiveIsSubmitting` value (false for non-latest messages) + */ +export default function useMemoizedChatContext( + message: TMessage | null | undefined, + isSubmitting: boolean, +) { + const chatCtx = useChatContext(); + + const isSubmittingRef = useRef(isSubmitting); + isSubmittingRef.current = isSubmitting; + + /** + * Stabilize conversation: only update when rendering-relevant fields change, + * not on every metadata update (e.g., title generation). + */ + const stableConversation = useMemo( + () => chatCtx.conversation, + // eslint-disable-next-line react-hooks/exhaustive-deps + [ + chatCtx.conversation?.conversationId, + chatCtx.conversation?.endpoint, + chatCtx.conversation?.endpointType, + chatCtx.conversation?.model, + chatCtx.conversation?.agent_id, + chatCtx.conversation?.assistant_id, + ], + ); + + /** + * `isSubmitting` is included in deps so that chatContext gets a new reference + * when streaming starts/ends (2x per session). This ensures HoverButtons + * re-renders to update regenerate/edit button visibility via useGenerationsByLatest. + * The getter pattern is still valuable: callbacks reading chatContext.isSubmitting + * at call-time always get the current value even between these re-renders. 
+ */ + const chatContext: TMessageChatContext = useMemo( + () => ({ + ask: chatCtx.ask, + index: chatCtx.index, + regenerate: chatCtx.regenerate, + conversation: stableConversation, + latestMessageId: chatCtx.latestMessageId, + latestMessageDepth: chatCtx.latestMessageDepth, + handleContinue: chatCtx.handleContinue, + get isSubmitting() { + return isSubmittingRef.current; + }, + }), + // eslint-disable-next-line react-hooks/exhaustive-deps + [ + chatCtx.ask, + chatCtx.index, + chatCtx.regenerate, + stableConversation, + chatCtx.latestMessageId, + chatCtx.latestMessageDepth, + chatCtx.handleContinue, + isSubmitting, // intentional: forces new reference on streaming start/end so HoverButtons re-renders + ], + ); + + const messageId = message?.messageId ?? null; + const isLatestMessage = messageId === chatCtx.latestMessageId; + const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false; + + return { chatContext, effectiveIsSubmitting }; +} diff --git a/client/src/hooks/Messages/useMessageActions.tsx b/client/src/hooks/Messages/useMessageActions.tsx index e8946b895b..590ba6a40e 100644 --- a/client/src/hooks/Messages/useMessageActions.tsx +++ b/client/src/hooks/Messages/useMessageActions.tsx @@ -11,7 +11,8 @@ import { TUpdateFeedbackRequest, } from 'librechat-data-provider'; import type { TMessageProps } from '~/common'; -import { useChatContext, useAssistantsMapContext, useAgentsMapContext } from '~/Providers'; +import type { TMessageChatContext } from '~/common/types'; +import { useAssistantsMapContext, useAgentsMapContext } from '~/Providers'; import useCopyToClipboard from './useCopyToClipboard'; import { useAuthContext } from '~/hooks/AuthContext'; import { useGetAddedConvo } from '~/hooks/Chat'; @@ -23,24 +24,33 @@ export type TMessageActions = Pick< 'message' | 'currentEditId' | 'setCurrentEditId' > & { searchResults?: { [key: string]: SearchResultData }; + /** + * Stable context object passed from wrapper components to avoid subscribing + * to 
ChatContext inside memo'd components (which would bypass React.memo). + * The `isSubmitting` property uses a getter backed by a ref, so it always + * returns the current value at call-time without triggering re-renders. + */ + chatContext: TMessageChatContext; }; export default function useMessageActions(props: TMessageActions) { const localize = useLocalize(); const { user } = useAuthContext(); const UsernameDisplay = useRecoilValue(store.UsernameDisplay); - const { message, currentEditId, setCurrentEditId, searchResults } = props; + const { message, currentEditId, setCurrentEditId, searchResults, chatContext } = props; const { ask, index, regenerate, - isSubmitting, conversation, latestMessageId, latestMessageDepth, handleContinue, - } = useChatContext(); + // NOTE: isSubmitting is intentionally NOT destructured here. + // chatContext.isSubmitting is a getter backed by a ref — destructuring + // would capture a one-time snapshot. Always access via chatContext.isSubmitting. + } = chatContext; const getAddedConvo = useGetAddedConvo(); @@ -98,13 +108,18 @@ export default function useMessageActions(props: TMessageActions) { } }, [agentsMap, conversation?.agent_id, conversation?.endpoint, message?.model]); + /** + * chatContext.isSubmitting is a getter backed by the wrapper's ref, + * so it always returns the current value at call-time — even for + * non-latest messages that don't re-render during streaming. 
+ */ const regenerateMessage = useCallback(() => { - if ((isSubmitting && isCreatedByUser === true) || !message) { + if ((chatContext.isSubmitting && isCreatedByUser === true) || !message) { return; } regenerate(message, { addedConvo: getAddedConvo() }); - }, [isSubmitting, isCreatedByUser, message, regenerate, getAddedConvo]); + }, [chatContext, isCreatedByUser, message, regenerate, getAddedConvo]); const copyToClipboard = useCopyToClipboard({ text, content, searchResults }); diff --git a/client/src/hooks/Messages/useMessageProcess.tsx b/client/src/hooks/Messages/useMessageProcess.tsx index 37738b50a9..bb49670a2f 100644 --- a/client/src/hooks/Messages/useMessageProcess.tsx +++ b/client/src/hooks/Messages/useMessageProcess.tsx @@ -1,6 +1,6 @@ import throttle from 'lodash/throttle'; import { Constants } from 'librechat-data-provider'; -import { useEffect, useRef, useCallback, useMemo } from 'react'; +import { useEffect, useRef, useMemo } from 'react'; import type { TMessage } from 'librechat-data-provider'; import { getTextKey, TEXT_KEY_DIVIDER, logger } from '~/utils'; import { useMessagesViewContext } from '~/Providers'; @@ -56,24 +56,25 @@ export default function useMessageProcess({ message }: { message?: TMessage | nu } }, [hasNoChildren, message, setLatestMessage, conversation?.conversationId]); - const handleScroll = useCallback( - (event: unknown | TouchEvent | WheelEvent) => { - throttle(() => { + /** Use ref for isSubmitting to stabilize handleScroll across isSubmitting changes */ + const isSubmittingRef = useRef(isSubmitting); + isSubmittingRef.current = isSubmitting; + + const handleScroll = useMemo( + () => + throttle((event: unknown) => { logger.log( 'message_scrolling', - `useMessageProcess: setting abort scroll to ${isSubmitting}, handleScroll event`, + `useMessageProcess: setting abort scroll to ${isSubmittingRef.current}, handleScroll event`, event, ); - if (isSubmitting) { - setAbortScroll(true); - } else { - setAbortScroll(false); - } - }, 500)(); 
- }, - [isSubmitting, setAbortScroll], + setAbortScroll(isSubmittingRef.current); + }, 500), + [setAbortScroll], ); + useEffect(() => () => handleScroll.cancel(), [handleScroll]); + return { handleScroll, isSubmitting, diff --git a/client/src/hooks/SSE/__tests__/useResumableSSE.spec.ts b/client/src/hooks/SSE/__tests__/useResumableSSE.spec.ts index 9100f39858..1717d27c22 100644 --- a/client/src/hooks/SSE/__tests__/useResumableSSE.spec.ts +++ b/client/src/hooks/SSE/__tests__/useResumableSSE.spec.ts @@ -1,5 +1,5 @@ import { renderHook, act } from '@testing-library/react'; -import { Constants, ErrorTypes, LocalStorageKeys } from 'librechat-data-provider'; +import { Constants, LocalStorageKeys } from 'librechat-data-provider'; import type { TSubmission } from 'librechat-data-provider'; type SSEEventListener = (e: Partial & { responseCode?: number }) => void; @@ -34,7 +34,13 @@ jest.mock('sse.js', () => ({ })); const mockSetQueryData = jest.fn(); -const mockQueryClient = { setQueryData: mockSetQueryData }; +const mockInvalidateQueries = jest.fn(); +const mockRemoveQueries = jest.fn(); +const mockQueryClient = { + setQueryData: mockSetQueryData, + invalidateQueries: mockInvalidateQueries, + removeQueries: mockRemoveQueries, +}; jest.mock('@tanstack/react-query', () => ({ ...jest.requireActual('@tanstack/react-query'), @@ -63,6 +69,7 @@ jest.mock('~/data-provider', () => ({ useGetStartupConfig: () => ({ data: { balance: { enabled: false } } }), useGetUserBalance: () => ({ refetch: jest.fn() }), queueTitleGeneration: jest.fn(), + streamStatusQueryKey: (conversationId: string) => ['streamStatus', conversationId], })); const mockErrorHandler = jest.fn(); @@ -162,6 +169,11 @@ describe('useResumableSSE - 404 error path', () => { beforeEach(() => { mockSSEInstances.length = 0; localStorage.clear(); + mockErrorHandler.mockClear(); + mockClearStepMaps.mockClear(); + mockSetIsSubmitting.mockClear(); + mockInvalidateQueries.mockClear(); + mockRemoveQueries.mockClear(); }); const 
seedDraft = (conversationId: string) => { @@ -200,19 +212,18 @@ describe('useResumableSSE - 404 error path', () => { unmount(); }); - it('calls errorHandler with STREAM_EXPIRED error type on 404', async () => { + it('invalidates message cache and clears stream status on 404 instead of showing error', async () => { const { unmount } = await render404Scenario(CONV_ID); - expect(mockErrorHandler).toHaveBeenCalledTimes(1); - const call = mockErrorHandler.mock.calls[0][0]; - expect(call.data).toBeDefined(); - const parsed = JSON.parse(call.data.text); - expect(parsed.type).toBe(ErrorTypes.STREAM_EXPIRED); - expect(call.submission).toEqual( - expect.objectContaining({ - conversation: expect.objectContaining({ conversationId: CONV_ID }), - }), - ); + expect(mockErrorHandler).not.toHaveBeenCalled(); + expect(mockInvalidateQueries).toHaveBeenCalledWith({ + queryKey: ['messages', CONV_ID], + }); + expect(mockRemoveQueries).toHaveBeenCalledWith({ + queryKey: ['streamStatus', CONV_ID], + }); + expect(mockClearStepMaps).toHaveBeenCalled(); + expect(mockSetIsSubmitting).toHaveBeenCalledWith(false); unmount(); }); diff --git a/client/src/hooks/SSE/useResumableSSE.ts b/client/src/hooks/SSE/useResumableSSE.ts index 32820f8392..39dc610dae 100644 --- a/client/src/hooks/SSE/useResumableSSE.ts +++ b/client/src/hooks/SSE/useResumableSSE.ts @@ -16,7 +16,12 @@ import { } from 'librechat-data-provider'; import type { TMessage, TPayload, TSubmission, EventSubmission } from 'librechat-data-provider'; import type { EventHandlerParams } from './useEventHandlers'; -import { useGetStartupConfig, useGetUserBalance, queueTitleGeneration } from '~/data-provider'; +import { + useGetUserBalance, + useGetStartupConfig, + queueTitleGeneration, + streamStatusQueryKey, +} from '~/data-provider'; import type { ActiveJobsResponse } from '~/data-provider'; import { useAuthContext } from '~/hooks/AuthContext'; import useEventHandlers from './useEventHandlers'; @@ -343,18 +348,20 @@ export default function 
useResumableSSE( /* @ts-ignore - sse.js types don't expose responseCode */ const responseCode = e.responseCode; - // 404 means job doesn't exist (completed/deleted) - don't retry + // 404 → job completed & was cleaned up; messages are persisted in DB. + // Invalidate cache once so react-query refetches instead of showing an error. if (responseCode === 404) { - console.log('[ResumableSSE] Stream not found (404) - job completed or expired'); + const convoId = currentSubmission.conversation?.conversationId; + console.log('[ResumableSSE] Stream 404, invalidating messages for:', convoId); sse.close(); removeActiveJob(currentStreamId); - clearAllDrafts(currentSubmission.conversation?.conversationId); - errorHandler({ - data: { - text: JSON.stringify({ type: ErrorTypes.STREAM_EXPIRED }), - } as unknown as Parameters[0]['data'], - submission: currentSubmission as EventSubmission, - }); + clearAllDrafts(convoId); + clearStepMaps(); + if (convoId) { + queryClient.invalidateQueries({ queryKey: [QueryKeys.messages, convoId] }); + queryClient.removeQueries({ queryKey: streamStatusQueryKey(convoId) }); + } + setIsSubmitting(false); setShowStopButton(false); setStreamId(null); reconnectAttemptRef.current = 0; @@ -544,6 +551,7 @@ export default function useResumableSSE( startupConfig?.balance?.enabled, balanceQuery, removeActiveJob, + queryClient, ], ); diff --git a/client/src/hooks/SSE/useResumeOnLoad.ts b/client/src/hooks/SSE/useResumeOnLoad.ts index f09751db0e..5f0f691787 100644 --- a/client/src/hooks/SSE/useResumeOnLoad.ts +++ b/client/src/hooks/SSE/useResumeOnLoad.ts @@ -125,7 +125,11 @@ export default function useResumeOnLoad( conversationId !== Constants.NEW_CONVO && processedConvoRef.current !== conversationId; // Don't re-check processed convos - const { data: streamStatus, isSuccess } = useStreamStatus(conversationId, shouldCheck); + const { + data: streamStatus, + isSuccess, + isFetching, + } = useStreamStatus(conversationId, shouldCheck); useEffect(() => { 
console.log('[ResumeOnLoad] Effect check', { @@ -135,6 +139,7 @@ export default function useResumeOnLoad( hasCurrentSubmission: !!currentSubmission, currentSubmissionConvoId: currentSubmission?.conversation?.conversationId, isSuccess, + isFetching, streamStatusActive: streamStatus?.active, streamStatusStreamId: streamStatus?.streamId, processedConvoRef: processedConvoRef.current, @@ -171,8 +176,9 @@ export default function useResumeOnLoad( ); } - // Wait for stream status query to complete - if (!isSuccess || !streamStatus) { + // Wait for stream status query to complete (including background refetches + // that may replace a stale cached result with fresh data) + if (!isSuccess || !streamStatus || isFetching) { console.log('[ResumeOnLoad] Waiting for stream status query'); return; } @@ -183,15 +189,12 @@ export default function useResumeOnLoad( return; } - // Check if there's an active job to resume - // DON'T mark as processed here - only mark when we actually create a submission - // This prevents stale cache data from blocking subsequent resume attempts if (!streamStatus.active || !streamStatus.streamId) { console.log('[ResumeOnLoad] No active job to resume for:', conversationId); + processedConvoRef.current = conversationId; return; } - // Mark as processed NOW - we verified there's an active job and will create submission processedConvoRef.current = conversationId; console.log('[ResumeOnLoad] Found active job, creating submission...', { @@ -241,6 +244,7 @@ export default function useResumeOnLoad( submissionConvoId, currentSubmission, isSuccess, + isFetching, streamStatus, getMessages, setSubmission, diff --git a/packages/api/jest.config.mjs b/packages/api/jest.config.mjs index df9cf6bcc2..976b794122 100644 --- a/packages/api/jest.config.mjs +++ b/packages/api/jest.config.mjs @@ -8,6 +8,7 @@ export default { '\\.helper\\.ts$', '\\.helper\\.d\\.ts$', '/__tests__/helpers/', + '\\.manual\\.spec\\.[jt]sx?$', ], coverageReporters: ['text', 'cobertura'], 
testResultsProcessor: 'jest-junit', diff --git a/packages/api/package.json b/packages/api/package.json index a4e74a7a3c..f09d946ec5 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -18,8 +18,8 @@ "build:dev": "npm run clean && NODE_ENV=development rollup -c --bundleConfigAsCjs", "build:watch": "NODE_ENV=development rollup -c -w --bundleConfigAsCjs", "build:watch:prod": "rollup -c -w --bundleConfigAsCjs", - "test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/\"", - "test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/\"", + "test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/|\\.*manual\\.spec\\.\"", + "test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/|\\.*manual\\.spec\\.\"", "test:cache-integration:core": "jest --testPathPatterns=\"src/cache/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false", "test:cache-integration:cluster": "jest --testPathPatterns=\"src/cluster/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false --runInBand", "test:cache-integration:mcp": "jest --testPathPatterns=\"src/mcp/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false", diff --git a/packages/api/src/admin/config.handler.spec.ts b/packages/api/src/admin/config.handler.spec.ts new file mode 100644 index 0000000000..708d114e72 --- /dev/null +++ b/packages/api/src/admin/config.handler.spec.ts @@ -0,0 +1,423 @@ +import type { Response } from 'express'; +import type { ServerRequest } from '~/types/http'; +import { createAdminConfigHandlers } from './config'; + +function mockReq(overrides = {}) { + return { + user: { id: 'u1', role: 'ADMIN', _id: { toString: () => 'u1' } }, + params: {}, + body: {}, + query: {}, + ...overrides, + } as Partial as ServerRequest; +} + +interface MockRes { + statusCode: number; + body: undefined | 
{ config?: unknown; error?: string; [key: string]: unknown }; + status: jest.Mock; + json: jest.Mock; +} + +function mockRes() { + const res: MockRes = { + statusCode: 200, + body: undefined, + status: jest.fn((code: number) => { + res.statusCode = code; + return res; + }), + json: jest.fn((data: MockRes['body']) => { + res.body = data; + return res; + }), + }; + return res as Partial as Response & MockRes; +} + +function createHandlers(overrides = {}) { + const deps = { + listAllConfigs: jest.fn().mockResolvedValue([]), + findConfigByPrincipal: jest.fn().mockResolvedValue(null), + upsertConfig: jest.fn().mockResolvedValue({ + _id: 'c1', + principalType: 'role', + principalId: 'admin', + overrides: {}, + configVersion: 1, + }), + patchConfigFields: jest + .fn() + .mockResolvedValue({ _id: 'c1', overrides: { interface: { endpointsMenu: false } } }), + unsetConfigField: jest.fn().mockResolvedValue({ _id: 'c1', overrides: {} }), + deleteConfig: jest.fn().mockResolvedValue({ _id: 'c1' }), + toggleConfigActive: jest.fn().mockResolvedValue({ _id: 'c1', isActive: false }), + hasConfigCapability: jest.fn().mockResolvedValue(true), + + getAppConfig: jest.fn().mockResolvedValue({ interface: { endpointsMenu: true } }), + ...overrides, + }; + const handlers = createAdminConfigHandlers(deps); + return { handlers, deps }; +} + +describe('createAdminConfigHandlers', () => { + describe('getConfig', () => { + it('returns 403 before DB lookup when user lacks READ_CONFIGS', async () => { + const { handlers, deps } = createHandlers({ + hasConfigCapability: jest.fn().mockResolvedValue(false), + }); + const req = mockReq({ params: { principalType: 'role', principalId: 'admin' } }); + const res = mockRes(); + + await handlers.getConfig(req, res); + + expect(res.statusCode).toBe(403); + expect(deps.findConfigByPrincipal).not.toHaveBeenCalled(); + }); + + it('returns 404 when config does not exist', async () => { + const { handlers } = createHandlers(); + const req = mockReq({ params: { 
principalType: 'role', principalId: 'nonexistent' } }); + const res = mockRes(); + + await handlers.getConfig(req, res); + + expect(res.statusCode).toBe(404); + }); + + it('returns config when authorized and exists', async () => { + const config = { + _id: 'c1', + principalType: 'role', + principalId: 'admin', + overrides: { x: 1 }, + }; + const { handlers } = createHandlers({ + findConfigByPrincipal: jest.fn().mockResolvedValue(config), + }); + const req = mockReq({ params: { principalType: 'role', principalId: 'admin' } }); + const res = mockRes(); + + await handlers.getConfig(req, res); + + expect(res.statusCode).toBe(200); + expect(res.body!.config).toEqual(config); + }); + + it('returns 400 for invalid principalType', async () => { + const { handlers } = createHandlers(); + const req = mockReq({ params: { principalType: 'invalid', principalId: 'x' } }); + const res = mockRes(); + + await handlers.getConfig(req, res); + + expect(res.statusCode).toBe(400); + }); + + it('rejects public principalType — not usable for config overrides', async () => { + const { handlers } = createHandlers(); + const req = mockReq({ params: { principalType: 'public', principalId: 'x' } }); + const res = mockRes(); + + await handlers.getConfig(req, res); + + expect(res.statusCode).toBe(400); + }); + }); + + describe('upsertConfigOverrides', () => { + it('returns 201 when creating a new config (configVersion === 1)', async () => { + const { handlers } = createHandlers({ + upsertConfig: jest.fn().mockResolvedValue({ _id: 'c1', configVersion: 1 }), + }); + const req = mockReq({ + params: { principalType: 'role', principalId: 'admin' }, + body: { overrides: { interface: { endpointsMenu: false } } }, + }); + const res = mockRes(); + + await handlers.upsertConfigOverrides(req, res); + + expect(res.statusCode).toBe(201); + }); + + it('returns 200 when updating an existing config (configVersion > 1)', async () => { + const { handlers } = createHandlers({ + upsertConfig: 
jest.fn().mockResolvedValue({ _id: 'c1', configVersion: 5 }), + }); + const req = mockReq({ + params: { principalType: 'role', principalId: 'admin' }, + body: { overrides: { interface: { endpointsMenu: false } } }, + }); + const res = mockRes(); + + await handlers.upsertConfigOverrides(req, res); + + expect(res.statusCode).toBe(200); + }); + + it('returns 400 when overrides is missing', async () => { + const { handlers } = createHandlers(); + const req = mockReq({ + params: { principalType: 'role', principalId: 'admin' }, + body: {}, + }); + const res = mockRes(); + + await handlers.upsertConfigOverrides(req, res); + + expect(res.statusCode).toBe(400); + }); + }); + + describe('deleteConfigField', () => { + it('reads fieldPath from query parameter', async () => { + const { handlers, deps } = createHandlers(); + const req = mockReq({ + params: { principalType: 'role', principalId: 'admin' }, + query: { fieldPath: 'interface.endpointsMenu' }, + }); + const res = mockRes(); + + await handlers.deleteConfigField(req, res); + + expect(deps.unsetConfigField).toHaveBeenCalledWith( + 'role', + 'admin', + 'interface.endpointsMenu', + ); + }); + + it('returns 400 when fieldPath query param is missing', async () => { + const { handlers } = createHandlers(); + const req = mockReq({ + params: { principalType: 'role', principalId: 'admin' }, + query: {}, + }); + const res = mockRes(); + + await handlers.deleteConfigField(req, res); + + expect(res.statusCode).toBe(400); + expect(res.body!.error).toContain('query parameter'); + }); + + it('rejects unsafe field paths', async () => { + const { handlers } = createHandlers(); + const req = mockReq({ + params: { principalType: 'role', principalId: 'admin' }, + query: { fieldPath: '__proto__.polluted' }, + }); + const res = mockRes(); + + await handlers.deleteConfigField(req, res); + + expect(res.statusCode).toBe(400); + }); + }); + + describe('patchConfigField', () => { + it('returns 403 when user lacks capability for section', async () 
=> { + const { handlers } = createHandlers({ + hasConfigCapability: jest.fn().mockResolvedValue(false), + }); + const req = mockReq({ + params: { principalType: 'role', principalId: 'admin' }, + body: { entries: [{ fieldPath: 'interface.endpointsMenu', value: false }] }, + }); + const res = mockRes(); + + await handlers.patchConfigField(req, res); + + expect(res.statusCode).toBe(403); + }); + + it('rejects entries with unsafe field paths (prototype pollution)', async () => { + const { handlers } = createHandlers(); + const req = mockReq({ + params: { principalType: 'role', principalId: 'admin' }, + body: { entries: [{ fieldPath: '__proto__.polluted', value: true }] }, + }); + const res = mockRes(); + + await handlers.patchConfigField(req, res); + + expect(res.statusCode).toBe(400); + }); + }); + + describe('upsertConfigOverrides — Bug 2 regression', () => { + it('returns 403 for empty overrides when user lacks MANAGE_CONFIGS', async () => { + const { handlers } = createHandlers({ + hasConfigCapability: jest.fn().mockResolvedValue(false), + }); + const req = mockReq({ + params: { principalType: 'role', principalId: 'admin' }, + body: { overrides: {} }, + }); + const res = mockRes(); + + await handlers.upsertConfigOverrides(req, res); + + expect(res.statusCode).toBe(403); + }); + }); + + // ── Invariant tests: rules that must hold across ALL handlers ────── + + const MUTATION_HANDLERS: Array<{ + name: string; + reqOverrides: Record; + }> = [ + { + name: 'upsertConfigOverrides', + reqOverrides: { + params: { principalType: 'role', principalId: 'admin' }, + body: { overrides: { interface: { endpointsMenu: false } } }, + }, + }, + { + name: 'patchConfigField', + reqOverrides: { + params: { principalType: 'role', principalId: 'admin' }, + body: { entries: [{ fieldPath: 'interface.endpointsMenu', value: false }] }, + }, + }, + { + name: 'deleteConfigField', + reqOverrides: { + params: { principalType: 'role', principalId: 'admin' }, + query: { fieldPath: 
'interface.endpointsMenu' }, + }, + }, + { + name: 'deleteConfigOverrides', + reqOverrides: { + params: { principalType: 'role', principalId: 'admin' }, + }, + }, + { + name: 'toggleConfig', + reqOverrides: { + params: { principalType: 'role', principalId: 'admin' }, + body: { isActive: false }, + }, + }, + ]; + + describe('invariant: all mutation handlers return 401 without auth', () => { + for (const { name, reqOverrides } of MUTATION_HANDLERS) { + it(`${name} returns 401 when user is missing`, async () => { + const { handlers } = createHandlers(); + const req = mockReq({ ...reqOverrides, user: undefined }); + const res = mockRes(); + + await (handlers as Record Promise>)[name]( + req, + res, + ); + + expect(res.statusCode).toBe(401); + }); + } + }); + + describe('invariant: all mutation handlers return 403 without capability', () => { + for (const { name, reqOverrides } of MUTATION_HANDLERS) { + it(`${name} returns 403 when user lacks capability`, async () => { + const { handlers } = createHandlers({ + hasConfigCapability: jest.fn().mockResolvedValue(false), + }); + const req = mockReq(reqOverrides); + const res = mockRes(); + + await (handlers as Record Promise>)[name]( + req, + res, + ); + + expect(res.statusCode).toBe(403); + }); + } + }); + + describe('invariant: all read handlers return 403 without capability', () => { + const READ_HANDLERS: Array<{ name: string; reqOverrides: Record }> = [ + { name: 'listConfigs', reqOverrides: {} }, + { name: 'getBaseConfig', reqOverrides: {} }, + { + name: 'getConfig', + reqOverrides: { params: { principalType: 'role', principalId: 'admin' } }, + }, + ]; + + for (const { name, reqOverrides } of READ_HANDLERS) { + it(`${name} returns 403 when user lacks capability`, async () => { + const { handlers } = createHandlers({ + hasConfigCapability: jest.fn().mockResolvedValue(false), + }); + const req = mockReq(reqOverrides); + const res = mockRes(); + + await (handlers as Record Promise>)[name]( + req, + res, + ); + + 
expect(res.statusCode).toBe(403); + }); + } + }); + + describe('invariant: all read handlers return 401 without auth', () => { + const READ_HANDLERS: Array<{ name: string; reqOverrides: Record }> = [ + { name: 'listConfigs', reqOverrides: {} }, + { name: 'getBaseConfig', reqOverrides: {} }, + { + name: 'getConfig', + reqOverrides: { params: { principalType: 'role', principalId: 'admin' } }, + }, + ]; + + for (const { name, reqOverrides } of READ_HANDLERS) { + it(`${name} returns 401 when user is missing`, async () => { + const { handlers } = createHandlers(); + const req = mockReq({ ...reqOverrides, user: undefined }); + const res = mockRes(); + + await (handlers as Record Promise>)[name]( + req, + res, + ); + + expect(res.statusCode).toBe(401); + }); + } + }); + + describe('getBaseConfig', () => { + it('returns 403 when user lacks READ_CONFIGS', async () => { + const { handlers } = createHandlers({ + hasConfigCapability: jest.fn().mockResolvedValue(false), + }); + const req = mockReq(); + const res = mockRes(); + + await handlers.getBaseConfig(req, res); + + expect(res.statusCode).toBe(403); + }); + + it('returns the full AppConfig', async () => { + const { handlers } = createHandlers(); + const req = mockReq(); + const res = mockRes(); + + await handlers.getBaseConfig(req, res); + + expect(res.statusCode).toBe(200); + expect(res.body!.config).toEqual({ interface: { endpointsMenu: true } }); + }); + }); +}); diff --git a/packages/api/src/admin/config.spec.ts b/packages/api/src/admin/config.spec.ts new file mode 100644 index 0000000000..499cfaa35b --- /dev/null +++ b/packages/api/src/admin/config.spec.ts @@ -0,0 +1,57 @@ +import { isValidFieldPath, getTopLevelSection } from './config'; + +describe('isValidFieldPath', () => { + it('accepts simple dot paths', () => { + expect(isValidFieldPath('interface.endpointsMenu')).toBe(true); + expect(isValidFieldPath('registration.socialLogins')).toBe(true); + expect(isValidFieldPath('a')).toBe(true); + 
expect(isValidFieldPath('a.b.c.d')).toBe(true); + }); + + it('rejects empty and non-string', () => { + expect(isValidFieldPath('')).toBe(false); + // @ts-expect-error testing invalid input + expect(isValidFieldPath(undefined)).toBe(false); + // @ts-expect-error testing invalid input + expect(isValidFieldPath(null)).toBe(false); + // @ts-expect-error testing invalid input + expect(isValidFieldPath(42)).toBe(false); + }); + + it('rejects __proto__ and dunder-prefixed segments', () => { + expect(isValidFieldPath('__proto__')).toBe(false); + expect(isValidFieldPath('a.__proto__')).toBe(false); + expect(isValidFieldPath('__proto__.polluted')).toBe(false); + expect(isValidFieldPath('a.__proto__.b')).toBe(false); + expect(isValidFieldPath('__defineGetter__')).toBe(false); + expect(isValidFieldPath('a.__lookupSetter__')).toBe(false); + expect(isValidFieldPath('__')).toBe(false); + expect(isValidFieldPath('a.__.b')).toBe(false); + }); + + it('rejects constructor and prototype segments', () => { + expect(isValidFieldPath('constructor')).toBe(false); + expect(isValidFieldPath('a.constructor')).toBe(false); + expect(isValidFieldPath('constructor.a')).toBe(false); + expect(isValidFieldPath('prototype')).toBe(false); + expect(isValidFieldPath('a.prototype')).toBe(false); + expect(isValidFieldPath('prototype.a')).toBe(false); + }); + + it('allows segments containing but not matching reserved words', () => { + expect(isValidFieldPath('constructorName')).toBe(true); + expect(isValidFieldPath('prototypeChain')).toBe(true); + expect(isValidFieldPath('a.myConstructor')).toBe(true); + }); +}); + +describe('getTopLevelSection', () => { + it('returns first segment of a dot path', () => { + expect(getTopLevelSection('interface.endpointsMenu')).toBe('interface'); + expect(getTopLevelSection('registration.socialLogins.github')).toBe('registration'); + }); + + it('returns the whole string when no dots', () => { + expect(getTopLevelSection('interface')).toBe('interface'); + }); +}); diff 
--git a/packages/api/src/admin/config.ts b/packages/api/src/admin/config.ts new file mode 100644 index 0000000000..b2afd9c69b --- /dev/null +++ b/packages/api/src/admin/config.ts @@ -0,0 +1,529 @@ +import { logger } from '@librechat/data-schemas'; +import { PrincipalType, PrincipalModel } from 'librechat-data-provider'; +import type { TCustomConfig } from 'librechat-data-provider'; +import type { AppConfig, ConfigSection, IConfig } from '@librechat/data-schemas'; +import type { Types, ClientSession } from 'mongoose'; +import type { Response } from 'express'; +import type { CapabilityUser } from '~/middleware/capabilities'; +import type { ServerRequest } from '~/types/http'; + +const UNSAFE_SEGMENTS = /(?:^|\.)(__[\w]*|constructor|prototype)(?:\.|$)/; +const MAX_PATCH_ENTRIES = 100; +const DEFAULT_PRIORITY = 10; + +export function isValidFieldPath(path: string): boolean { + return ( + typeof path === 'string' && + path.length > 0 && + !path.startsWith('.') && + !path.endsWith('.') && + !path.includes('..') && + !UNSAFE_SEGMENTS.test(path) + ); +} + +export function getTopLevelSection(fieldPath: string): string { + return fieldPath.split('.')[0]; +} + +export interface AdminConfigDeps { + listAllConfigs: (filter?: { isActive?: boolean }, session?: ClientSession) => Promise; + findConfigByPrincipal: ( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + options?: { includeInactive?: boolean }, + session?: ClientSession, + ) => Promise; + upsertConfig: ( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + principalModel: PrincipalModel, + overrides: Partial, + priority: number, + session?: ClientSession, + ) => Promise; + patchConfigFields: ( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + principalModel: PrincipalModel, + fields: Record, + priority: number, + session?: ClientSession, + ) => Promise; + unsetConfigField: ( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + 
fieldPath: string, + session?: ClientSession, + ) => Promise; + deleteConfig: ( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + session?: ClientSession, + ) => Promise; + toggleConfigActive: ( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + isActive: boolean, + session?: ClientSession, + ) => Promise; + hasConfigCapability: ( + user: CapabilityUser, + section: ConfigSection | null, + verb?: 'manage' | 'read', + ) => Promise; + getAppConfig?: (options?: { + role?: string; + userId?: string; + tenantId?: string; + }) => Promise; + /** Invalidate all config-related caches after a mutation. */ + invalidateConfigCaches?: (tenantId?: string) => Promise; +} + +// ── Validation helpers ─────────────────────────────────────────────── + +const CONFIG_PRINCIPAL_TYPES = new Set([ + PrincipalType.USER, + PrincipalType.GROUP, + PrincipalType.ROLE, +]); + +function validatePrincipalType(value: string): value is PrincipalType { + return CONFIG_PRINCIPAL_TYPES.has(value as PrincipalType); +} + +function principalModel(type: PrincipalType): PrincipalModel { + switch (type) { + case PrincipalType.USER: + return PrincipalModel.USER; + case PrincipalType.GROUP: + return PrincipalModel.GROUP; + case PrincipalType.ROLE: + return PrincipalModel.ROLE; + case PrincipalType.PUBLIC: + return PrincipalModel.ROLE; + default: { + const _exhaustive: never = type; + logger.warn(`[adminConfig] Unmapped PrincipalType: ${String(_exhaustive)}`); + return PrincipalModel.ROLE; + } + } +} + +function getCapabilityUser(req: ServerRequest): CapabilityUser | null { + if (!req.user) { + return null; + } + return { + id: req.user.id ?? req.user._id?.toString() ?? '', + role: req.user.role ?? 
'', + tenantId: (req.user as { tenantId?: string }).tenantId, + }; +} + +// ── Handler factory ────────────────────────────────────────────────── + +export function createAdminConfigHandlers(deps: AdminConfigDeps) { + const { + listAllConfigs, + findConfigByPrincipal, + upsertConfig, + patchConfigFields, + unsetConfigField, + deleteConfig, + toggleConfigActive, + hasConfigCapability, + getAppConfig, + invalidateConfigCaches, + } = deps; + + /** + * GET / — List all active config overrides. + */ + async function listConfigs(req: ServerRequest, res: Response) { + try { + const user = getCapabilityUser(req); + if (!user) { + return res.status(401).json({ error: 'Authentication required' }); + } + + if (!(await hasConfigCapability(user, null, 'read'))) { + return res.status(403).json({ error: 'Insufficient permissions' }); + } + + const configs = await listAllConfigs(); + return res.status(200).json({ configs }); + } catch (error) { + logger.error('[adminConfig] listConfigs error:', error); + return res.status(500).json({ error: 'Failed to list configs' }); + } + } + + /** + * GET /base — Return the raw AppConfig (YAML + DB base merged). + * This is the full config structure admins can edit, NOT the startup payload. 
+ */ + async function getBaseConfig(req: ServerRequest, res: Response) { + try { + const user = getCapabilityUser(req); + if (!user) { + return res.status(401).json({ error: 'Authentication required' }); + } + + if (!(await hasConfigCapability(user, null, 'read'))) { + return res.status(403).json({ error: 'Insufficient permissions' }); + } + + if (!getAppConfig) { + return res.status(501).json({ error: 'Base config endpoint not configured' }); + } + + const appConfig = await getAppConfig({ + tenantId: user.tenantId, + }); + return res.status(200).json({ config: appConfig }); + } catch (error) { + logger.error('[adminConfig] getBaseConfig error:', error); + return res.status(500).json({ error: 'Failed to get base config' }); + } + } + + /** + * GET /:principalType/:principalId — Get config for a specific principal. + */ + async function getConfig(req: ServerRequest, res: Response) { + try { + const { principalType, principalId } = req.params as { + principalType: string; + principalId: string; + }; + + if (!validatePrincipalType(principalType)) { + return res.status(400).json({ error: `Invalid principalType: ${principalType}` }); + } + + const user = getCapabilityUser(req); + if (!user) { + return res.status(401).json({ error: 'Authentication required' }); + } + + if (!(await hasConfigCapability(user, null, 'read'))) { + return res.status(403).json({ error: 'Insufficient permissions' }); + } + + const config = await findConfigByPrincipal(principalType, principalId, { + includeInactive: true, + }); + if (!config) { + return res.status(404).json({ error: 'Config not found' }); + } + + return res.status(200).json({ config }); + } catch (error) { + logger.error('[adminConfig] getConfig error:', error); + return res.status(500).json({ error: 'Failed to get config' }); + } + } + + /** + * PUT /:principalType/:principalId — Replace entire overrides for a principal. 
+ */ + async function upsertConfigOverrides(req: ServerRequest, res: Response) { + try { + const { principalType, principalId } = req.params as { + principalType: string; + principalId: string; + }; + + if (!validatePrincipalType(principalType)) { + return res.status(400).json({ error: `Invalid principalType: ${principalType}` }); + } + + const { overrides, priority } = req.body as { + overrides?: Partial; + priority?: number; + }; + + if (!overrides || typeof overrides !== 'object' || Array.isArray(overrides)) { + return res.status(400).json({ error: 'overrides must be a plain object' }); + } + + if (priority != null && (typeof priority !== 'number' || priority < 0)) { + return res.status(400).json({ error: 'priority must be a non-negative number' }); + } + + const user = getCapabilityUser(req); + if (!user) { + return res.status(401).json({ error: 'Authentication required' }); + } + + if (!(await hasConfigCapability(user, null, 'manage'))) { + return res.status(403).json({ error: 'Insufficient permissions' }); + } + + const overrideSections = Object.keys(overrides); + if (overrideSections.length > 0) { + const allowed = await Promise.all( + overrideSections.map((s) => hasConfigCapability(user, s as ConfigSection, 'manage')), + ); + const denied = overrideSections.find((_, i) => !allowed[i]); + if (denied) { + return res.status(403).json({ + error: `Insufficient permissions for config section: ${denied}`, + }); + } + } + + const config = await upsertConfig( + principalType, + principalId, + principalModel(principalType), + overrides, + priority ?? DEFAULT_PRIORITY, + ); + + invalidateConfigCaches?.(user.tenantId)?.catch((err) => + logger.error('[adminConfig] Cache invalidation failed after upsert:', err), + ); + return res.status(config?.configVersion === 1 ? 
201 : 200).json({ config }); + } catch (error) { + logger.error('[adminConfig] upsertConfigOverrides error:', error); + return res.status(500).json({ error: 'Failed to upsert config' }); + } + } + + /** + * PATCH /:principalType/:principalId/fields — Set individual fields via dot-paths. + */ + async function patchConfigField(req: ServerRequest, res: Response) { + try { + const { principalType, principalId } = req.params as { + principalType: string; + principalId: string; + }; + + if (!validatePrincipalType(principalType)) { + return res.status(400).json({ error: `Invalid principalType: ${principalType}` }); + } + + const { entries, priority } = req.body as { + entries?: Array<{ fieldPath: string; value: unknown }>; + priority?: number; + }; + + if (priority != null && (typeof priority !== 'number' || priority < 0)) { + return res.status(400).json({ error: 'priority must be a non-negative number' }); + } + + if (!Array.isArray(entries) || entries.length === 0) { + return res.status(400).json({ error: 'entries array is required and must not be empty' }); + } + + if (entries.length > MAX_PATCH_ENTRIES) { + return res + .status(400) + .json({ error: `entries array exceeds maximum of ${MAX_PATCH_ENTRIES}` }); + } + + for (const entry of entries) { + if (!isValidFieldPath(entry.fieldPath)) { + return res + .status(400) + .json({ error: `Invalid or unsafe field path: ${entry.fieldPath}` }); + } + } + + const user = getCapabilityUser(req); + if (!user) { + return res.status(401).json({ error: 'Authentication required' }); + } + + if (!(await hasConfigCapability(user, null, 'manage'))) { + const sections = [...new Set(entries.map((e) => getTopLevelSection(e.fieldPath)))]; + const allowed = await Promise.all( + sections.map((s) => hasConfigCapability(user, s as ConfigSection, 'manage')), + ); + const denied = sections.find((_, i) => !allowed[i]); + if (denied) { + return res.status(403).json({ + error: `Insufficient permissions for config section: ${denied}`, + }); + } + } 
+ + const seen = new Set(); + const fields: Record = {}; + for (const entry of entries) { + if (seen.has(entry.fieldPath)) { + return res.status(400).json({ error: `Duplicate fieldPath: ${entry.fieldPath}` }); + } + seen.add(entry.fieldPath); + fields[entry.fieldPath] = entry.value; + } + + const existing = + priority == null + ? await findConfigByPrincipal(principalType, principalId, { includeInactive: true }) + : null; + + const config = await patchConfigFields( + principalType, + principalId, + principalModel(principalType), + fields, + priority ?? existing?.priority ?? DEFAULT_PRIORITY, + ); + + invalidateConfigCaches?.(user.tenantId)?.catch((err) => + logger.error('[adminConfig] Cache invalidation failed after patch:', err), + ); + return res.status(200).json({ config }); + } catch (error) { + logger.error('[adminConfig] patchConfigField error:', error); + return res.status(500).json({ error: 'Failed to patch config fields' }); + } + } + + /** + * DELETE /:principalType/:principalId/fields?fieldPath=dotted.path + */ + async function deleteConfigField(req: ServerRequest, res: Response) { + try { + const { principalType, principalId } = req.params as { + principalType: string; + principalId: string; + }; + if (!validatePrincipalType(principalType)) { + return res.status(400).json({ error: `Invalid principalType: ${principalType}` }); + } + + const fieldPath = req.query.fieldPath as string | undefined; + + if (!fieldPath || typeof fieldPath !== 'string') { + return res.status(400).json({ error: 'fieldPath query parameter is required' }); + } + + if (!isValidFieldPath(fieldPath)) { + return res.status(400).json({ error: `Invalid or unsafe field path: ${fieldPath}` }); + } + + const user = getCapabilityUser(req); + if (!user) { + return res.status(401).json({ error: 'Authentication required' }); + } + + const section = getTopLevelSection(fieldPath); + if (!(await hasConfigCapability(user, section as ConfigSection, 'manage'))) { + return res.status(403).json({ + 
error: `Insufficient permissions for config section: ${section}`, + }); + } + + const config = await unsetConfigField(principalType, principalId, fieldPath); + if (!config) { + return res.status(404).json({ error: 'Config not found' }); + } + + invalidateConfigCaches?.(user.tenantId)?.catch((err) => + logger.error('[adminConfig] Cache invalidation failed after field delete:', err), + ); + return res.status(200).json({ config }); + } catch (error) { + logger.error('[adminConfig] deleteConfigField error:', error); + return res.status(500).json({ error: 'Failed to delete config field' }); + } + } + + /** + * DELETE /:principalType/:principalId — Delete an entire config override. + */ + async function deleteConfigOverrides(req: ServerRequest, res: Response) { + try { + const { principalType, principalId } = req.params as { + principalType: string; + principalId: string; + }; + + if (!validatePrincipalType(principalType)) { + return res.status(400).json({ error: `Invalid principalType: ${principalType}` }); + } + + const user = getCapabilityUser(req); + if (!user) { + return res.status(401).json({ error: 'Authentication required' }); + } + + if (!(await hasConfigCapability(user, null, 'manage'))) { + return res.status(403).json({ error: 'Insufficient permissions' }); + } + + const config = await deleteConfig(principalType, principalId); + if (!config) { + return res.status(404).json({ error: 'Config not found' }); + } + + invalidateConfigCaches?.(user.tenantId)?.catch((err) => + logger.error('[adminConfig] Cache invalidation failed after config delete:', err), + ); + return res.status(200).json({ success: true }); + } catch (error) { + logger.error('[adminConfig] deleteConfigOverrides error:', error); + return res.status(500).json({ error: 'Failed to delete config' }); + } + } + + /** + * PATCH /:principalType/:principalId/active — Toggle isActive. 
+ */ + async function toggleConfig(req: ServerRequest, res: Response) { + try { + const { principalType, principalId } = req.params as { + principalType: string; + principalId: string; + }; + + if (!validatePrincipalType(principalType)) { + return res.status(400).json({ error: `Invalid principalType: ${principalType}` }); + } + + const { isActive } = req.body as { isActive?: boolean }; + if (typeof isActive !== 'boolean') { + return res.status(400).json({ error: 'isActive boolean is required' }); + } + + const user = getCapabilityUser(req); + if (!user) { + return res.status(401).json({ error: 'Authentication required' }); + } + + if (!(await hasConfigCapability(user, null, 'manage'))) { + return res.status(403).json({ error: 'Insufficient permissions' }); + } + + const config = await toggleConfigActive(principalType, principalId, isActive); + if (!config) { + return res.status(404).json({ error: 'Config not found' }); + } + + invalidateConfigCaches?.(user.tenantId)?.catch((err) => + logger.error('[adminConfig] Cache invalidation failed after toggle:', err), + ); + return res.status(200).json({ config }); + } catch (error) { + logger.error('[adminConfig] toggleConfig error:', error); + return res.status(500).json({ error: 'Failed to toggle config' }); + } + } + + return { + listConfigs, + getBaseConfig, + getConfig, + upsertConfigOverrides, + patchConfigField, + deleteConfigField, + deleteConfigOverrides, + toggleConfig, + }; +} diff --git a/packages/api/src/admin/groups.spec.ts b/packages/api/src/admin/groups.spec.ts new file mode 100644 index 0000000000..42e32152d9 --- /dev/null +++ b/packages/api/src/admin/groups.spec.ts @@ -0,0 +1,1348 @@ +import { Types } from 'mongoose'; +import { PrincipalType } from 'librechat-data-provider'; +import type { IGroup, IUser } from '@librechat/data-schemas'; +import type { Response } from 'express'; +import type { ServerRequest } from '~/types/http'; +import type { AdminGroupsDeps } from './groups'; +import { 
createAdminGroupsHandlers } from './groups'; + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { error: jest.fn(), warn: jest.fn() }, +})); + +describe('createAdminGroupsHandlers', () => { + let validId: string; + let validUserId: string; + + beforeEach(() => { + validId = new Types.ObjectId().toString(); + validUserId = new Types.ObjectId().toString(); + }); + + function mockGroup(overrides: Partial = {}): IGroup { + return { + _id: new Types.ObjectId(validId), + name: 'Test Group', + source: 'local', + memberIds: [], + createdAt: new Date(), + updatedAt: new Date(), + ...overrides, + } as IGroup; + } + + function mockUser(overrides: Partial = {}): IUser { + return { + _id: new Types.ObjectId(validUserId), + name: 'Test User', + email: 'test@example.com', + avatar: 'https://example.com/avatar.png', + ...overrides, + } as IUser; + } + + function createReqRes( + overrides: { + params?: Record; + query?: Record; + body?: Record; + } = {}, + ) { + const req = { + params: overrides.params ?? {}, + query: overrides.query ?? {}, + body: overrides.body ?? 
{}, + } as unknown as ServerRequest; + + const json = jest.fn(); + const status = jest.fn().mockReturnValue({ json }); + const res = { status, json } as unknown as Response; + + return { req, res, status, json }; + } + + function createDeps(overrides: Partial = {}): AdminGroupsDeps { + return { + listGroups: jest.fn().mockResolvedValue([]), + countGroups: jest.fn().mockResolvedValue(0), + findGroupById: jest.fn().mockResolvedValue(null), + createGroup: jest.fn().mockResolvedValue(mockGroup()), + updateGroupById: jest.fn().mockResolvedValue(mockGroup()), + deleteGroup: jest.fn().mockResolvedValue(mockGroup()), + addUserToGroup: jest.fn().mockResolvedValue({ user: mockUser(), group: mockGroup() }), + removeUserFromGroup: jest.fn().mockResolvedValue({ user: mockUser(), group: mockGroup() }), + removeMemberById: jest.fn().mockResolvedValue(mockGroup()), + findUsers: jest.fn().mockResolvedValue([]), + deleteConfig: jest.fn().mockResolvedValue(null), + deleteAclEntries: jest.fn().mockResolvedValue({ deletedCount: 0 }), + deleteGrantsForPrincipal: jest.fn().mockResolvedValue(undefined), + ...overrides, + }; + } + + describe('listGroups', () => { + it('returns groups with total, limit, offset', async () => { + const groups = [mockGroup()]; + const deps = createDeps({ + listGroups: jest.fn().mockResolvedValue(groups), + countGroups: jest.fn().mockResolvedValue(1), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ query: {} }); + + await handlers.listGroups(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ groups, total: 1, limit: 50, offset: 0 }); + }); + + it('passes source and search filters with pagination', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res } = createReqRes({ + query: { source: 'entra', search: 'engineering', limit: '20', offset: '10' }, + }); + + await handlers.listGroups(req, res); + + 
expect(deps.listGroups).toHaveBeenCalledWith({ + source: 'entra', + search: 'engineering', + limit: 20, + offset: 10, + }); + expect(deps.countGroups).toHaveBeenCalledWith({ + source: 'entra', + search: 'engineering', + }); + }); + + it('passes search filter alone', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res } = createReqRes({ query: { search: 'eng' } }); + + await handlers.listGroups(req, res); + + expect(deps.listGroups).toHaveBeenCalledWith({ search: 'eng', limit: 50, offset: 0 }); + expect(deps.countGroups).toHaveBeenCalledWith({ search: 'eng' }); + }); + + it('ignores invalid source values', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res } = createReqRes({ query: { source: 'invalid' } }); + + await handlers.listGroups(req, res); + + expect(deps.listGroups).toHaveBeenCalledWith({ limit: 50, offset: 0 }); + expect(deps.countGroups).toHaveBeenCalledWith({}); + }); + + it('clamps limit and offset', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res } = createReqRes({ query: { limit: '999', offset: '-5' } }); + + await handlers.listGroups(req, res); + + expect(deps.listGroups).toHaveBeenCalledWith({ limit: 200, offset: 0 }); + }); + + it('returns 400 when search exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + query: { search: 'a'.repeat(201) }, + }); + + await handlers.listGroups(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'search must not exceed 200 characters' }); + expect(deps.listGroups).not.toHaveBeenCalled(); + }); + + it('returns 500 when countGroups fails', async () => { + const deps = createDeps({ + countGroups: jest.fn().mockRejectedValue(new Error('count failed')), + }); + const 
handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes(); + + await handlers.listGroups(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to list groups' }); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ listGroups: jest.fn().mockRejectedValue(new Error('db down')) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes(); + + await handlers.listGroups(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to list groups' }); + }); + }); + + describe('getGroup', () => { + it('returns group with 200', async () => { + const group = mockGroup(); + const deps = createDeps({ findGroupById: jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroup(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ group }); + }); + + it('returns 400 for invalid ID', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: 'not-an-id' } }); + + await handlers.getGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Invalid group ID format' }); + expect(deps.findGroupById).not.toHaveBeenCalled(); + }); + + it('returns 404 when group not found', async () => { + const deps = createDeps({ findGroupById: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroup(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Group not found' }); + 
}); + + it('returns 500 on error', async () => { + const deps = createDeps({ + findGroupById: jest.fn().mockRejectedValue(new Error('db down')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroup(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to get group' }); + }); + }); + + describe('createGroup', () => { + it('creates group and returns 201', async () => { + const group = mockGroup(); + const deps = createDeps({ createGroup: jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'New Group', description: 'A group' }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(201); + expect(json).toHaveBeenCalledWith({ group }); + expect(deps.createGroup).toHaveBeenCalledWith( + expect.objectContaining({ + name: 'New Group', + description: 'A group', + source: 'local', + memberIds: [], + }), + ); + }); + + it('normalizes memberIds to idOnTheSource values', async () => { + const userId = new Types.ObjectId().toString(); + const user = { _id: new Types.ObjectId(userId), idOnTheSource: 'ext-norm-1' } as IUser; + const group = mockGroup(); + const deps = createDeps({ + createGroup: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([user]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + body: { name: 'With Members', memberIds: [userId] }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(201); + expect(deps.findUsers).toHaveBeenCalledWith({ _id: { $in: [userId] } }, 'idOnTheSource'); + expect(deps.createGroup).toHaveBeenCalledWith( + expect.objectContaining({ memberIds: ['ext-norm-1'] }), + ); + }); + + it('logs warning when memberIds contain 
non-existent user ObjectIds', async () => { + const { logger } = jest.requireMock('@librechat/data-schemas'); + const unknownId = new Types.ObjectId().toString(); + const group = mockGroup(); + const deps = createDeps({ + createGroup: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + body: { name: 'With Unknown', memberIds: [unknownId] }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(201); + expect(logger.warn).toHaveBeenCalledWith( + '[adminGroups] createGroup: memberIds contain unknown user ObjectIds:', + [unknownId], + ); + }); + + it('passes idOnTheSource when provided', async () => { + const group = mockGroup(); + const deps = createDeps({ createGroup: jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + body: { name: 'Entra Group', source: 'entra', idOnTheSource: 'ent-abc-123' }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(201); + expect(deps.createGroup).toHaveBeenCalledWith( + expect.objectContaining({ idOnTheSource: 'ent-abc-123', source: 'entra' }), + ); + }); + + it('returns 400 for invalid source value', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'Bad Source', source: 'azure' }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Invalid source value' }); + expect(deps.createGroup).not.toHaveBeenCalled(); + }); + + it('returns 400 when name exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'a'.repeat(501) }, + }); 
+ + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name must not exceed 500 characters' }); + expect(deps.createGroup).not.toHaveBeenCalled(); + }); + + it('returns 400 when description exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'Valid', description: 'x'.repeat(2001) }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ + error: 'description must not exceed 2000 characters', + }); + expect(deps.createGroup).not.toHaveBeenCalled(); + }); + + it('returns 400 when email exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'Valid', email: 'x'.repeat(501) }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'email must not exceed 500 characters' }); + expect(deps.createGroup).not.toHaveBeenCalled(); + }); + + it('returns 400 when avatar exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'Valid', avatar: 'x'.repeat(2001) }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'avatar must not exceed 2000 characters' }); + expect(deps.createGroup).not.toHaveBeenCalled(); + }); + + it('returns 400 when idOnTheSource exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'Valid', idOnTheSource: 'x'.repeat(501) }, + 
}); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ + error: 'idOnTheSource must not exceed 500 characters', + }); + expect(deps.createGroup).not.toHaveBeenCalled(); + }); + + it('returns 400 when memberIds exceeds cap', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const memberIds = Array.from({ length: 501 }, (_, i) => `ext-${i}`); + const { req, res, status, json } = createReqRes({ + body: { name: 'Too Many Members', memberIds }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'memberIds must not exceed 500 entries' }); + expect(deps.createGroup).not.toHaveBeenCalled(); + }); + + it('passes non-ObjectId memberIds through unchanged', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + body: { name: 'Ext Group', memberIds: ['ext-1', 'ext-2'] }, + }); + + await handlers.createGroup(req, res); + + expect(deps.findUsers).not.toHaveBeenCalled(); + expect(deps.createGroup).toHaveBeenCalledWith( + expect.objectContaining({ memberIds: ['ext-1', 'ext-2'] }), + ); + expect(status).toHaveBeenCalledWith(201); + }); + + it('returns 400 when name is missing', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ body: {} }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name is required' }); + expect(deps.createGroup).not.toHaveBeenCalled(); + }); + + it('returns 400 when name is whitespace-only', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ body: { name: ' ' } }); + + await 
handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name is required' }); + }); + + it('returns 400 on ValidationError', async () => { + const validationError = new Error('source must be local or entra'); + validationError.name = 'ValidationError'; + const deps = createDeps({ createGroup: jest.fn().mockRejectedValue(validationError) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ body: { name: 'Test' } }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'source must be local or entra' }); + }); + + it('returns 500 on unexpected error', async () => { + const deps = createDeps({ createGroup: jest.fn().mockRejectedValue(new Error('db crash')) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ body: { name: 'Test' } }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to create group' }); + }); + }); + + describe('updateGroup', () => { + it('updates group and returns 200', async () => { + const group = mockGroup({ name: 'Updated' }); + const deps = createDeps({ + updateGroupById: jest.fn().mockResolvedValue(group), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { name: 'Updated' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ group }); + }); + + it('updates description only', async () => { + const group = mockGroup({ description: 'New desc' }); + const deps = createDeps({ updateGroupById: jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + 
params: { id: validId }, + body: { description: 'New desc' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(deps.updateGroupById).toHaveBeenCalledWith(validId, { description: 'New desc' }); + }); + + it('updates email only', async () => { + const group = mockGroup({ email: 'team@co.com' }); + const deps = createDeps({ updateGroupById: jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + params: { id: validId }, + body: { email: 'team@co.com' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(deps.updateGroupById).toHaveBeenCalledWith(validId, { email: 'team@co.com' }); + }); + + it('updates avatar only', async () => { + const group = mockGroup({ avatar: 'https://img.co/a.png' }); + const deps = createDeps({ updateGroupById: jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + params: { id: validId }, + body: { avatar: 'https://img.co/a.png' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(deps.updateGroupById).toHaveBeenCalledWith(validId, { + avatar: 'https://img.co/a.png', + }); + }); + + it('updates multiple fields at once', async () => { + const group = mockGroup({ name: 'New', description: 'Desc', email: 'a@b.com' }); + const deps = createDeps({ updateGroupById: jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + params: { id: validId }, + body: { name: ' New ', description: 'Desc', email: 'a@b.com' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(deps.updateGroupById).toHaveBeenCalledWith(validId, { + name: 'New', + description: 'Desc', + email: 'a@b.com', + }); + }); + + it('returns 
400 for invalid ID', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: 'bad' }, + body: { name: 'Updated' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Invalid group ID format' }); + }); + + it('returns 400 when name is empty string', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { name: '' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name must be a non-empty string' }); + }); + + it('returns 400 when name is whitespace-only', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { name: ' ' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name must be a non-empty string' }); + }); + + it('returns 400 when name exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { name: 'a'.repeat(501) }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name must not exceed 500 characters' }); + expect(deps.updateGroupById).not.toHaveBeenCalled(); + }); + + it('returns 400 when description exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { 
description: 'x'.repeat(2001) }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ + error: 'description must not exceed 2000 characters', + }); + expect(deps.updateGroupById).not.toHaveBeenCalled(); + }); + + it('returns 400 when email exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { email: 'x'.repeat(501) }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'email must not exceed 500 characters' }); + expect(deps.updateGroupById).not.toHaveBeenCalled(); + }); + + it('returns 400 when avatar exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { avatar: 'x'.repeat(2001) }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'avatar must not exceed 2000 characters' }); + expect(deps.updateGroupById).not.toHaveBeenCalled(); + }); + + it('returns 400 when no valid fields provided', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: {}, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'No valid fields to update' }); + expect(deps.updateGroupById).not.toHaveBeenCalled(); + }); + + it('returns 404 when updateGroupById returns null', async () => { + const deps = createDeps({ + updateGroupById: jest.fn().mockResolvedValue(null), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, 
res, status, json } = createReqRes({ + params: { id: validId }, + body: { name: 'Updated' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Group not found' }); + }); + + it('returns 400 on ValidationError', async () => { + const validationError = new Error('invalid field'); + validationError.name = 'ValidationError'; + const deps = createDeps({ + updateGroupById: jest.fn().mockRejectedValue(validationError), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { name: 'Updated' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'invalid field' }); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ + updateGroupById: jest.fn().mockRejectedValue(new Error('db down')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { name: 'Updated' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to update group' }); + }); + }); + + describe('deleteGroup', () => { + it('deletes group and returns 200 with id', async () => { + const deps = createDeps({ deleteGroup: jest.fn().mockResolvedValue(mockGroup()) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.deleteGroup(req, res); + + expect(deps.deleteGroup).toHaveBeenCalledWith(validId); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true, id: validId }); + }); + + it('returns 400 for invalid ID', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + 
const { req, res, status, json } = createReqRes({ params: { id: 'bad-id' } }); + + await handlers.deleteGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Invalid group ID format' }); + }); + + it('returns 404 when deleteGroup returns null', async () => { + const deps = createDeps({ deleteGroup: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.deleteGroup(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Group not found' }); + expect(deps.deleteConfig).not.toHaveBeenCalled(); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ + deleteGroup: jest.fn().mockRejectedValue(new Error('db down')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.deleteGroup(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to delete group' }); + }); + + it('returns 200 even when cascade cleanup partially fails', async () => { + const deps = createDeps({ + deleteGroup: jest.fn().mockResolvedValue(mockGroup()), + deleteAclEntries: jest.fn().mockRejectedValue(new Error('cleanup failed')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.deleteGroup(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true, id: validId }); + expect(deps.deleteConfig).toHaveBeenCalledWith(PrincipalType.GROUP, validId); + expect(deps.deleteAclEntries).toHaveBeenCalledWith({ + principalType: PrincipalType.GROUP, + principalId: new Types.ObjectId(validId), + }); + 
expect(deps.deleteGrantsForPrincipal).toHaveBeenCalledWith(PrincipalType.GROUP, validId); + }); + + it('cleans up Config, AclEntry, and SystemGrant on group delete', async () => { + const deps = createDeps({ deleteGroup: jest.fn().mockResolvedValue(mockGroup()) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ params: { id: validId } }); + + await handlers.deleteGroup(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(deps.deleteConfig).toHaveBeenCalledWith(PrincipalType.GROUP, validId); + expect(deps.deleteAclEntries).toHaveBeenCalledWith({ + principalType: PrincipalType.GROUP, + principalId: new Types.ObjectId(validId), + }); + expect(deps.deleteGrantsForPrincipal).toHaveBeenCalledWith(PrincipalType.GROUP, validId); + }); + }); + + describe('getGroupMembers', () => { + it('fetches group with memberIds projection only', async () => { + const group = mockGroup({ memberIds: [] }); + const deps = createDeps({ findGroupById: jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + expect(deps.findGroupById).toHaveBeenCalledWith(validId, { memberIds: 1 }); + }); + + it('returns empty members for group with no memberIds', async () => { + const group = mockGroup({ memberIds: [] }); + const deps = createDeps({ findGroupById: jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ members: [], total: 0, limit: 50, offset: 0 }); + expect(deps.findUsers).not.toHaveBeenCalled(); + }); + + it('batches member lookup with $or query', async () => { + const user = mockUser({ idOnTheSource: 'ext-123' }); + const group = 
mockGroup({ memberIds: [validUserId, 'ext-123'] }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([user]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + expect(deps.findUsers).toHaveBeenCalledWith( + { + $or: [ + { idOnTheSource: { $in: [validUserId, 'ext-123'] } }, + { _id: { $in: [validUserId] } }, + ], + }, + 'name email avatar idOnTheSource', + ); + expect(status).toHaveBeenCalledWith(200); + const members = json.mock.calls[0][0].members; + expect(members).toHaveLength(1); + }); + + it('skips _id condition when no valid ObjectIds in memberIds', async () => { + const group = mockGroup({ memberIds: ['ext-1', 'ext-2'] }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + expect(deps.findUsers).toHaveBeenCalledWith( + { $or: [{ idOnTheSource: { $in: ['ext-1', 'ext-2'] } }] }, + 'name email avatar idOnTheSource', + ); + }); + + it('falls back to memberId when user not found', async () => { + const group = mockGroup({ memberIds: ['unknown-member'] }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + expect(json.mock.calls[0][0].members).toEqual([ + { userId: 'unknown-member', name: 'unknown-member', email: '', avatarUrl: undefined }, + ]); + }); + + it('deduplicates when identical memberId appears twice', async () => { + const user = mockUser({ 
idOnTheSource: validUserId }); + const group = mockGroup({ memberIds: [validUserId, validUserId] }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([user]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + const result = json.mock.calls[0][0]; + expect(result.members).toHaveLength(1); + expect(result.total).toBe(1); + }); + + it('deduplicates when objectId and idOnTheSource both present for same user', async () => { + const extId = 'ext-dedup-123'; + const user = mockUser({ idOnTheSource: extId }); + const group = mockGroup({ memberIds: [validUserId, extId] }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([user]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + expect(json.mock.calls[0][0].members).toHaveLength(1); + }); + + it('reports deduplicated total for duplicate memberIds', async () => { + const group = mockGroup({ memberIds: ['m1', 'm2', 'm1', 'm3', 'm2'] }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + const result = json.mock.calls[0][0]; + expect(result.total).toBe(3); + expect(result.members).toHaveLength(3); + }); + + it('paginates members with limit and offset', async () => { + const ids = ['m1', 'm2', 'm3', 'm4', 'm5']; + const group = mockGroup({ memberIds: ids }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + findUsers: 
jest.fn().mockResolvedValue([]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, json } = createReqRes({ + params: { id: validId }, + query: { limit: '2', offset: '1' }, + }); + + await handlers.getGroupMembers(req, res); + + const result = json.mock.calls[0][0]; + expect(result.total).toBe(5); + expect(result.limit).toBe(2); + expect(result.offset).toBe(1); + expect(result.members).toHaveLength(2); + expect(result.members[0].userId).toBe('m2'); + expect(result.members[1].userId).toBe('m3'); + }); + + it('caps limit at 200', async () => { + const ids = Array.from({ length: 5 }, (_, i) => `m${i}`); + const group = mockGroup({ memberIds: ids }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, json } = createReqRes({ + params: { id: validId }, + query: { limit: '999' }, + }); + + await handlers.getGroupMembers(req, res); + + const result = json.mock.calls[0][0]; + expect(result.limit).toBe(200); + }); + + it('returns empty when offset exceeds total', async () => { + const group = mockGroup({ memberIds: ['m1', 'm2'] }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, json } = createReqRes({ + params: { id: validId }, + query: { offset: '10' }, + }); + + await handlers.getGroupMembers(req, res); + + const result = json.mock.calls[0][0]; + expect(result.members).toHaveLength(0); + expect(result.total).toBe(2); + expect(deps.findUsers).not.toHaveBeenCalled(); + }); + + it('returns 400 for invalid group ID', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: 'nope' } }); + + await handlers.getGroupMembers(req, res); + + expect(status).toHaveBeenCalledWith(400); + 
expect(json).toHaveBeenCalledWith({ error: 'Invalid group ID format' }); + }); + + it('returns 404 when group not found', async () => { + const deps = createDeps({ findGroupById: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Group not found' }); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ + findGroupById: jest.fn().mockRejectedValue(new Error('db down')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to get group members' }); + }); + }); + + describe('addGroupMember', () => { + it('adds member and returns 200', async () => { + const group = mockGroup(); + const deps = createDeps({ + addUserToGroup: jest.fn().mockResolvedValue({ user: mockUser(), group }), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { userId: validUserId }, + }); + + await handlers.addGroupMember(req, res); + + expect(deps.addUserToGroup).toHaveBeenCalledWith(validUserId, validId); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ group }); + }); + + it('returns 400 for invalid group ID', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: 'bad' }, + body: { userId: validUserId }, + }); + + await handlers.addGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Invalid group 
ID format' }); + }); + + it('returns 400 when userId is missing', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: {}, + }); + + await handlers.addGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'userId is required' }); + }); + + it('returns 400 for non-ObjectId userId', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { userId: 'not-valid' }, + }); + + await handlers.addGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ + error: 'Only native user ObjectIds can be added via this endpoint', + }); + }); + + it('returns 404 when addUserToGroup returns null group', async () => { + const deps = createDeps({ + addUserToGroup: jest.fn().mockResolvedValue({ user: mockUser(), group: null }), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { userId: validUserId }, + }); + + await handlers.addGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Group not found' }); + }); + + it('returns 404 for "User not found" error', async () => { + const deps = createDeps({ + addUserToGroup: jest.fn().mockRejectedValue(new Error('User not found')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { userId: validUserId }, + }); + + await handlers.addGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'User not found' }); + }); + + it('returns 500 for unrelated errors', async () => { + const deps = 
createDeps({ + addUserToGroup: jest.fn().mockRejectedValue(new Error('connection lost')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { userId: validUserId }, + }); + + await handlers.addGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to add member' }); + }); + + it('does not misclassify errors containing "not found" substring', async () => { + const deps = createDeps({ + addUserToGroup: jest.fn().mockRejectedValue(new Error('Permission not found in config')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + params: { id: validId }, + body: { userId: validUserId }, + }); + + await handlers.addGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(500); + }); + }); + + describe('removeGroupMember', () => { + it('removes member and returns 200', async () => { + const deps = createDeps({ + removeUserFromGroup: jest.fn().mockResolvedValue({ user: mockUser(), group: mockGroup() }), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: validUserId }, + }); + + await handlers.removeGroupMember(req, res); + + expect(deps.removeUserFromGroup).toHaveBeenCalledWith(validUserId, validId); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + + it('returns 400 for invalid group ID', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: 'bad', userId: validUserId }, + }); + + await handlers.removeGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Invalid group ID format' }); + }); + + it('removes non-ObjectId member via 
removeMemberById', async () => { + const deps = createDeps({ + removeMemberById: jest.fn().mockResolvedValue(mockGroup()), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: 'ent-abc-123' }, + }); + + await handlers.removeGroupMember(req, res); + + expect(deps.removeMemberById).toHaveBeenCalledWith(validId, 'ent-abc-123'); + expect(deps.removeUserFromGroup).not.toHaveBeenCalled(); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + + it('returns 404 when removeMemberById returns null', async () => { + const deps = createDeps({ + removeMemberById: jest.fn().mockResolvedValue(null), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: 'ent-abc-123' }, + }); + + await handlers.removeGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Group not found' }); + }); + + it('falls back to removeMemberById when ObjectId userId not found as user', async () => { + const deps = createDeps({ + removeUserFromGroup: jest.fn().mockRejectedValue(new Error('User not found')), + removeMemberById: jest.fn().mockResolvedValue(mockGroup()), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: validUserId }, + }); + + await handlers.removeGroupMember(req, res); + + expect(deps.removeUserFromGroup).toHaveBeenCalledWith(validUserId, validId); + expect(deps.removeMemberById).toHaveBeenCalledWith(validId, validUserId); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + + it('returns 404 when removeUserFromGroup returns null group', async () => { + const deps = createDeps({ + removeUserFromGroup: jest.fn().mockResolvedValue({ user: mockUser(), group: null 
}), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: validUserId }, + }); + + await handlers.removeGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Group not found' }); + }); + + it('returns 404 when fallback removeMemberById also returns null', async () => { + const deps = createDeps({ + removeUserFromGroup: jest.fn().mockRejectedValue(new Error('User not found')), + removeMemberById: jest.fn().mockResolvedValue(null), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: validUserId }, + }); + + await handlers.removeGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Group not found' }); + }); + + it('returns 500 for unrelated errors', async () => { + const deps = createDeps({ + removeUserFromGroup: jest.fn().mockRejectedValue(new Error('timeout')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: validUserId }, + }); + + await handlers.removeGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to remove member' }); + }); + + it('returns 200 when removing ObjectId member not in group (idempotent delete)', async () => { + const group = mockGroup({ memberIds: [] }); + const deps = createDeps({ + removeUserFromGroup: jest.fn().mockResolvedValue({ user: mockUser(), group }), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: validUserId }, + }); + + await handlers.removeGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + + 
it('returns 200 when removing non-ObjectId member not in group (idempotent delete)', async () => { + const group = mockGroup({ memberIds: [] }); + const deps = createDeps({ + removeMemberById: jest.fn().mockResolvedValue(group), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: 'ext-not-in-group' }, + }); + + await handlers.removeGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + }); +}); diff --git a/packages/api/src/admin/groups.ts b/packages/api/src/admin/groups.ts new file mode 100644 index 0000000000..ab4490e05f --- /dev/null +++ b/packages/api/src/admin/groups.ts @@ -0,0 +1,481 @@ +import { Types } from 'mongoose'; +import { PrincipalType } from 'librechat-data-provider'; +import { logger, isValidObjectIdString } from '@librechat/data-schemas'; +import type { + IGroup, + IUser, + IConfig, + CreateGroupRequest, + UpdateGroupRequest, + GroupFilterOptions, +} from '@librechat/data-schemas'; +import type { FilterQuery, ClientSession, DeleteResult } from 'mongoose'; +import type { Response } from 'express'; +import type { ValidationError } from '~/types/error'; +import type { ServerRequest } from '~/types/http'; +import { parsePagination } from './pagination'; + +type GroupListFilter = Pick<GroupFilterOptions, 'source' | 'search'>; + +const VALID_GROUP_SOURCES: ReadonlySet<string> = new Set(['local', 'entra']); +const MAX_CREATE_MEMBER_IDS = 500; +const MAX_SEARCH_LENGTH = 200; +const MAX_NAME_LENGTH = 500; +const MAX_DESCRIPTION_LENGTH = 2000; +const MAX_EMAIL_LENGTH = 500; +const MAX_AVATAR_LENGTH = 2000; +const MAX_EXTERNAL_ID_LENGTH = 500; + +interface GroupIdParams { + id: string; +} + +interface GroupMemberParams extends GroupIdParams { + userId: string; +} + +export interface AdminGroupsDeps { + listGroups: ( + filter?: GroupListFilter & { limit?: number; offset?: number }, + session?: ClientSession, + ) => Promise<IGroup[]>; + countGroups: 
(filter?: GroupListFilter, session?: ClientSession) => Promise<number>; + findGroupById: ( + groupId: string | Types.ObjectId, + projection?: Record<string, unknown>, + session?: ClientSession, + ) => Promise<IGroup | null>; + createGroup: (groupData: Partial<IGroup>, session?: ClientSession) => Promise<IGroup>; + updateGroupById: ( + groupId: string | Types.ObjectId, + data: Partial<Pick<IGroup, 'name' | 'description' | 'email' | 'avatar'>>, + session?: ClientSession, + ) => Promise<IGroup | null>; + deleteGroup: ( + groupId: string | Types.ObjectId, + session?: ClientSession, + ) => Promise<IGroup | null>; + addUserToGroup: ( + userId: string | Types.ObjectId, + groupId: string | Types.ObjectId, + session?: ClientSession, + ) => Promise<{ user: IUser; group: IGroup | null }>; + removeUserFromGroup: ( + userId: string | Types.ObjectId, + groupId: string | Types.ObjectId, + session?: ClientSession, + ) => Promise<{ user: IUser; group: IGroup | null }>; + removeMemberById: ( + groupId: string | Types.ObjectId, + memberId: string, + session?: ClientSession, + ) => Promise<IGroup | null>; + findUsers: ( + searchCriteria: FilterQuery<IUser>, + fieldsToSelect?: string | string[] | null, + ) => Promise<IUser[]>; + deleteConfig: ( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + ) => Promise<IConfig | null>; + deleteAclEntries: (filter: { + principalType: PrincipalType; + principalId: string | Types.ObjectId; + }) => Promise<DeleteResult>; + deleteGrantsForPrincipal: ( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + ) => Promise<DeleteResult>; +} + +export function createAdminGroupsHandlers(deps: AdminGroupsDeps) { + const { + listGroups, + countGroups, + findGroupById, + createGroup, + updateGroupById, + deleteGroup, + addUserToGroup, + removeUserFromGroup, + removeMemberById, + findUsers, + deleteConfig, + deleteAclEntries, + deleteGrantsForPrincipal, + } = deps; + + async function listGroupsHandler(req: ServerRequest, res: Response) { + try { + const { search, source } = req.query as { search?: string; source?: string }; + const filter: GroupListFilter = {}; + if (source && VALID_GROUP_SOURCES.has(source)) { + filter.source = source 
as IGroup['source']; + } + if (search && search.length > MAX_SEARCH_LENGTH) { + return res + .status(400) + .json({ error: `search must not exceed ${MAX_SEARCH_LENGTH} characters` }); + } + if (search) { + filter.search = search; + } + const { limit, offset } = parsePagination(req.query); + const [groups, total] = await Promise.all([ + listGroups({ ...filter, limit, offset }), + countGroups(filter), + ]); + return res.status(200).json({ groups, total, limit, offset }); + } catch (error) { + logger.error('[adminGroups] listGroups error:', error); + return res.status(500).json({ error: 'Failed to list groups' }); + } + } + + async function getGroupHandler(req: ServerRequest, res: Response) { + try { + const { id } = req.params as GroupIdParams; + if (!isValidObjectIdString(id)) { + return res.status(400).json({ error: 'Invalid group ID format' }); + } + const group = await findGroupById(id); + if (!group) { + return res.status(404).json({ error: 'Group not found' }); + } + return res.status(200).json({ group }); + } catch (error) { + logger.error('[adminGroups] getGroup error:', error); + return res.status(500).json({ error: 'Failed to get group' }); + } + } + + async function createGroupHandler(req: ServerRequest, res: Response) { + try { + const body = req.body as CreateGroupRequest; + if (!body.name || typeof body.name !== 'string' || !body.name.trim()) { + return res.status(400).json({ error: 'name is required' }); + } + if (body.name.trim().length > MAX_NAME_LENGTH) { + return res + .status(400) + .json({ error: `name must not exceed ${MAX_NAME_LENGTH} characters` }); + } + if (body.source && !VALID_GROUP_SOURCES.has(body.source)) { + return res.status(400).json({ error: 'Invalid source value' }); + } + if (body.description && body.description.length > MAX_DESCRIPTION_LENGTH) { + return res + .status(400) + .json({ error: `description must not exceed ${MAX_DESCRIPTION_LENGTH} characters` }); + } + if (body.email && body.email.length > MAX_EMAIL_LENGTH) { + 
return res + .status(400) + .json({ error: `email must not exceed ${MAX_EMAIL_LENGTH} characters` }); + } + if (body.avatar && body.avatar.length > MAX_AVATAR_LENGTH) { + return res + .status(400) + .json({ error: `avatar must not exceed ${MAX_AVATAR_LENGTH} characters` }); + } + if (body.idOnTheSource && body.idOnTheSource.length > MAX_EXTERNAL_ID_LENGTH) { + return res + .status(400) + .json({ error: `idOnTheSource must not exceed ${MAX_EXTERNAL_ID_LENGTH} characters` }); + } + + const rawIds = Array.isArray(body.memberIds) ? body.memberIds : []; + if (rawIds.length > MAX_CREATE_MEMBER_IDS) { + return res + .status(400) + .json({ error: `memberIds must not exceed ${MAX_CREATE_MEMBER_IDS} entries` }); + } + let memberIds = rawIds; + const objectIds = rawIds.filter(isValidObjectIdString); + if (objectIds.length > 0) { + const users = await findUsers({ _id: { $in: objectIds } }, 'idOnTheSource'); + const idMap = new Map(); + for (const user of users) { + const uid = user._id?.toString() ?? ''; + idMap.set(uid, user.idOnTheSource || uid); + } + const unmapped = objectIds.filter((oid) => !idMap.has(oid)); + if (unmapped.length > 0) { + logger.warn( + '[adminGroups] createGroup: memberIds contain unknown user ObjectIds:', + unmapped, + ); + } + memberIds = rawIds.map((id) => idMap.get(id) || id); + } + + const group = await createGroup({ + name: body.name.trim(), + description: body.description, + email: body.email, + avatar: body.avatar, + source: body.source || 'local', + memberIds, + ...(body.idOnTheSource ? 
{ idOnTheSource: body.idOnTheSource } : {}), + }); + return res.status(201).json({ group }); + } catch (error) { + if ((error as ValidationError).name === 'ValidationError') { + return res.status(400).json({ error: (error as ValidationError).message }); + } + logger.error('[adminGroups] createGroup error:', error); + return res.status(500).json({ error: 'Failed to create group' }); + } + } + + async function updateGroupHandler(req: ServerRequest, res: Response) { + try { + const { id } = req.params as GroupIdParams; + if (!isValidObjectIdString(id)) { + return res.status(400).json({ error: 'Invalid group ID format' }); + } + const body = req.body as UpdateGroupRequest; + + if ( + body.name !== undefined && + (!body.name || typeof body.name !== 'string' || !body.name.trim()) + ) { + return res.status(400).json({ error: 'name must be a non-empty string' }); + } + if (body.name !== undefined && body.name.trim().length > MAX_NAME_LENGTH) { + return res + .status(400) + .json({ error: `name must not exceed ${MAX_NAME_LENGTH} characters` }); + } + if (body.description !== undefined && body.description.length > MAX_DESCRIPTION_LENGTH) { + return res + .status(400) + .json({ error: `description must not exceed ${MAX_DESCRIPTION_LENGTH} characters` }); + } + if (body.email !== undefined && body.email.length > MAX_EMAIL_LENGTH) { + return res + .status(400) + .json({ error: `email must not exceed ${MAX_EMAIL_LENGTH} characters` }); + } + if (body.avatar !== undefined && body.avatar.length > MAX_AVATAR_LENGTH) { + return res + .status(400) + .json({ error: `avatar must not exceed ${MAX_AVATAR_LENGTH} characters` }); + } + + const updateData: Partial<Pick<IGroup, 'name' | 'description' | 'email' | 'avatar'>> = {}; + if (body.name !== undefined) { + updateData.name = body.name.trim(); + } + if (body.description !== undefined) { + updateData.description = body.description; + } + if (body.email !== undefined) { + updateData.email = body.email; + } + if (body.avatar !== undefined) { + updateData.avatar = body.avatar; + } + + if 
(Object.keys(updateData).length === 0) { + return res.status(400).json({ error: 'No valid fields to update' }); + } + + const group = await updateGroupById(id, updateData); + if (!group) { + return res.status(404).json({ error: 'Group not found' }); + } + return res.status(200).json({ group }); + } catch (error) { + if ((error as ValidationError).name === 'ValidationError') { + return res.status(400).json({ error: (error as ValidationError).message }); + } + logger.error('[adminGroups] updateGroup error:', error); + return res.status(500).json({ error: 'Failed to update group' }); + } + } + + async function deleteGroupHandler(req: ServerRequest, res: Response) { + try { + const { id } = req.params as GroupIdParams; + if (!isValidObjectIdString(id)) { + return res.status(400).json({ error: 'Invalid group ID format' }); + } + const deleted = await deleteGroup(id); + if (!deleted) { + return res.status(404).json({ error: 'Group not found' }); + } + /** + * deleteAclEntries is a raw deleteMany wrapper with no type casting. + * grantPermission stores group principalId as ObjectId, so we must + * cast here. deleteConfig and deleteGrantsForPrincipal normalize internally. 
+ */ + const cleanupResults = await Promise.allSettled([ + deleteConfig(PrincipalType.GROUP, id), + deleteAclEntries({ + principalType: PrincipalType.GROUP, + principalId: new Types.ObjectId(id), + }), + deleteGrantsForPrincipal(PrincipalType.GROUP, id), + ]); + for (const result of cleanupResults) { + if (result.status === 'rejected') { + logger.error('[adminGroups] cascade cleanup step failed for group:', id, result.reason); + } + } + return res.status(200).json({ success: true, id }); + } catch (error) { + logger.error('[adminGroups] deleteGroup error:', error); + return res.status(500).json({ error: 'Failed to delete group' }); + } + } + + async function getGroupMembersHandler(req: ServerRequest, res: Response) { + try { + const { id } = req.params as GroupIdParams; + if (!isValidObjectIdString(id)) { + return res.status(400).json({ error: 'Invalid group ID format' }); + } + const group = await findGroupById(id, { memberIds: 1 }); + if (!group) { + return res.status(404).json({ error: 'Group not found' }); + } + + /** + * `total` counts unique raw memberId strings. After user resolution, two + * distinct strings may map to the same user, so `members.length` can be + * less than the page size. Write paths prevent this for well-formed data. 
+ */ + const allMemberIds = [...new Set(group.memberIds || [])]; + const total = allMemberIds.length; + const { limit, offset } = parsePagination(req.query); + + if (total === 0 || offset >= total) { + return res.status(200).json({ members: [], total, limit, offset }); + } + + const memberIds = allMemberIds.slice(offset, offset + limit); + + const validObjectIds = memberIds.filter(isValidObjectIdString); + const conditions: FilterQuery<IUser>[] = [{ idOnTheSource: { $in: memberIds } }]; + if (validObjectIds.length > 0) { + conditions.push({ _id: { $in: validObjectIds } }); + } + const users = await findUsers({ $or: conditions }, 'name email avatar idOnTheSource'); + + const userMap = new Map(); + for (const user of users) { + if (user.idOnTheSource) { + userMap.set(user.idOnTheSource, user); + } + if (user._id) { + userMap.set(user._id.toString(), user); + } + } + + const seen = new Set(); + const members: { userId: string; name: string; email: string; avatarUrl?: string }[] = []; + for (const memberId of memberIds) { + const user = userMap.get(memberId); + const userId = user?._id?.toString() ?? memberId; + if (seen.has(userId)) { + continue; + } + seen.add(userId); + members.push({ + userId, + name: user?.name ?? memberId, + email: user?.email ?? 
'', + avatarUrl: user?.avatar, + }); + } + + return res.status(200).json({ members, total, limit, offset }); + } catch (error) { + logger.error('[adminGroups] getGroupMembers error:', error); + return res.status(500).json({ error: 'Failed to get group members' }); + } + } + + async function addGroupMemberHandler(req: ServerRequest, res: Response) { + try { + const { id } = req.params as GroupIdParams; + if (!isValidObjectIdString(id)) { + return res.status(400).json({ error: 'Invalid group ID format' }); + } + const { userId } = req.body as { userId: string }; + if (!userId || typeof userId !== 'string') { + return res.status(400).json({ error: 'userId is required' }); + } + if (!isValidObjectIdString(userId)) { + return res + .status(400) + .json({ error: 'Only native user ObjectIds can be added via this endpoint' }); + } + + const { group } = await addUserToGroup(userId, id); + if (!group) { + return res.status(404).json({ error: 'Group not found' }); + } + return res.status(200).json({ group }); + } catch (error) { + const message = error instanceof Error ? error.message : ''; + const isNotFound = message === 'User not found' || message.startsWith('User not found:'); + if (isNotFound) { + return res.status(404).json({ error: 'User not found' }); + } + logger.error('[adminGroups] addGroupMember error:', error); + return res.status(500).json({ error: 'Failed to add member' }); + } + } + + /** + * Attempt removal of an ObjectId-format member: first via removeUserFromGroup + * (which resolves the user), falling back to a raw $pull if the user record + * no longer exists. Returns null only when the group itself is not found. + */ + async function removeObjectIdMember(groupId: string, userId: string): Promise<IGroup | null> { + try { + const { group } = await removeUserFromGroup(userId, groupId); + return group; + } catch (err) { + const msg = err instanceof Error ? 
err.message : ''; + if (msg === 'User not found' || msg.startsWith('User not found:')) { + return removeMemberById(groupId, userId); + } + throw err; + } + } + + async function removeGroupMemberHandler(req: ServerRequest, res: Response) { + try { + const { id, userId } = req.params as GroupMemberParams; + if (!isValidObjectIdString(id)) { + return res.status(400).json({ error: 'Invalid group ID format' }); + } + + const group = isValidObjectIdString(userId) + ? await removeObjectIdMember(id, userId) + : await removeMemberById(id, userId); + + if (!group) { + return res.status(404).json({ error: 'Group not found' }); + } + return res.status(200).json({ success: true }); + } catch (error) { + logger.error('[adminGroups] removeGroupMember error:', error); + return res.status(500).json({ error: 'Failed to remove member' }); + } + } + + return { + listGroups: listGroupsHandler, + getGroup: getGroupHandler, + createGroup: createGroupHandler, + updateGroup: updateGroupHandler, + deleteGroup: deleteGroupHandler, + getGroupMembers: getGroupMembersHandler, + addGroupMember: addGroupMemberHandler, + removeGroupMember: removeGroupMemberHandler, + }; +} diff --git a/packages/api/src/admin/index.ts b/packages/api/src/admin/index.ts new file mode 100644 index 0000000000..fe60f1d993 --- /dev/null +++ b/packages/api/src/admin/index.ts @@ -0,0 +1,6 @@ +export { createAdminConfigHandlers } from './config'; +export { createAdminGroupsHandlers } from './groups'; +export { createAdminRolesHandlers } from './roles'; +export type { AdminConfigDeps } from './config'; +export type { AdminGroupsDeps } from './groups'; +export type { AdminRolesDeps } from './roles'; diff --git a/packages/api/src/admin/pagination.ts b/packages/api/src/admin/pagination.ts new file mode 100644 index 0000000000..69003f0418 --- /dev/null +++ b/packages/api/src/admin/pagination.ts @@ -0,0 +1,17 @@ +export const DEFAULT_PAGE_LIMIT = 50; +export const MAX_PAGE_LIMIT = 200; + +export function parsePagination(query: { 
limit?: string; offset?: string }): { + limit: number; + offset: number; +} { + const rawLimit = parseInt(query.limit ?? '', 10); + const rawOffset = parseInt(query.offset ?? '', 10); + return { + limit: Math.min( + Math.max(Number.isNaN(rawLimit) ? DEFAULT_PAGE_LIMIT : rawLimit, 1), + MAX_PAGE_LIMIT, + ), + offset: Math.max(Number.isNaN(rawOffset) ? 0 : rawOffset, 0), + }; +} diff --git a/packages/api/src/admin/roles.spec.ts b/packages/api/src/admin/roles.spec.ts new file mode 100644 index 0000000000..3f43079bfb --- /dev/null +++ b/packages/api/src/admin/roles.spec.ts @@ -0,0 +1,1484 @@ +import { Types } from 'mongoose'; +import { SystemRoles } from 'librechat-data-provider'; +import type { IRole, IUser } from '@librechat/data-schemas'; +import type { Response } from 'express'; +import type { ServerRequest } from '~/types/http'; +import type { AdminRolesDeps } from './roles'; +import { createAdminRolesHandlers } from './roles'; + +const { RoleConflictError } = jest.requireActual('@librechat/data-schemas'); + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { error: jest.fn() }, +})); + +const validUserId = new Types.ObjectId().toString(); + +function mockRole(overrides: Partial<IRole> = {}): IRole { + return { + name: 'editor', + description: 'Can edit content', + permissions: {}, + ...overrides, + } as IRole; +} + +function mockUser(overrides: Partial<IUser> = {}): IUser { + return { + _id: new Types.ObjectId(validUserId), + name: 'Test User', + email: 'test@example.com', + avatar: 'https://example.com/avatar.png', + role: 'editor', + ...overrides, + } as IUser; +} + +function createReqRes( + overrides: { + params?: Record<string, unknown>; + query?: Record<string, unknown>; + body?: Record<string, unknown>; + } = {}, +) { + const req = { + params: overrides.params ?? {}, + query: overrides.query ?? {}, + body: overrides.body ?? 
{}, + } as unknown as ServerRequest; + + const json = jest.fn(); + const status = jest.fn().mockReturnValue({ json }); + const res = { status, json } as unknown as Response; + + return { req, res, status, json }; +} + +function createDeps(overrides: Partial<AdminRolesDeps> = {}): AdminRolesDeps { + return { + listRoles: jest.fn().mockResolvedValue([]), + countRoles: jest.fn().mockResolvedValue(0), + getRoleByName: jest.fn().mockResolvedValue(null), + createRoleByName: jest.fn().mockResolvedValue(mockRole()), + updateRoleByName: jest.fn().mockResolvedValue(mockRole()), + updateAccessPermissions: jest.fn().mockResolvedValue(undefined), + deleteRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: jest.fn().mockResolvedValue(null), + updateUser: jest.fn().mockResolvedValue(mockUser()), + updateUsersByRole: jest.fn().mockResolvedValue(undefined), + findUserIdsByRole: jest.fn().mockResolvedValue(['uid-1', 'uid-2']), + updateUsersRoleByIds: jest.fn().mockResolvedValue(undefined), + listUsersByRole: jest.fn().mockResolvedValue([]), + countUsersByRole: jest.fn().mockResolvedValue(0), + ...overrides, + }; +} + +describe('createAdminRolesHandlers', () => { + describe('listRoles', () => { + it('returns paginated roles with 200', async () => { + const roles = [mockRole()]; + const deps = createDeps({ + listRoles: jest.fn().mockResolvedValue(roles), + countRoles: jest.fn().mockResolvedValue(1), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes(); + + await handlers.listRoles(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ roles, total: 1, limit: 50, offset: 0 }); + expect(deps.listRoles).toHaveBeenCalledWith({ limit: 50, offset: 0 }); + }); + + it('passes custom limit and offset from query', async () => { + const deps = createDeps({ + countRoles: jest.fn().mockResolvedValue(100), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = 
createReqRes({ + query: { limit: '25', offset: '50' }, + }); + + await handlers.listRoles(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ roles: [], total: 100, limit: 25, offset: 50 }); + expect(deps.listRoles).toHaveBeenCalledWith({ limit: 25, offset: 50 }); + }); + + it('clamps limit to 200', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ query: { limit: '999' } }); + + await handlers.listRoles(req, res); + + expect(deps.listRoles).toHaveBeenCalledWith({ limit: 200, offset: 0 }); + }); + + it('clamps negative offset to 0', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ query: { offset: '-5' } }); + + await handlers.listRoles(req, res); + + expect(deps.listRoles).toHaveBeenCalledWith({ limit: 50, offset: 0 }); + }); + + it('treats non-numeric limit as default', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ query: { limit: 'abc' } }); + + await handlers.listRoles(req, res); + + expect(deps.listRoles).toHaveBeenCalledWith({ limit: 50, offset: 0 }); + }); + + it('clamps limit=0 to 1', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ query: { limit: '0' } }); + + await handlers.listRoles(req, res); + + expect(deps.listRoles).toHaveBeenCalledWith({ limit: 1, offset: 0 }); + }); + + it('truncates float offset to integer', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ query: { offset: '1.7' } }); + + await handlers.listRoles(req, res); + + expect(deps.listRoles).toHaveBeenCalledWith({ limit: 50, offset: 1 }); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ listRoles: 
jest.fn().mockRejectedValue(new Error('db down')) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes(); + + await handlers.listRoles(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to list roles' }); + }); + }); + + describe('getRole', () => { + it('returns role with 200', async () => { + const role = mockRole(); + const deps = createDeps({ getRoleByName: jest.fn().mockResolvedValue(role) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 'editor' } }); + + await handlers.getRole(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ role }); + }); + + it('returns 404 when role not found', async () => { + const deps = createDeps({ getRoleByName: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 'nonexistent' } }); + + await handlers.getRole(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockRejectedValue(new Error('db down')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 'editor' } }); + + await handlers.getRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to get role' }); + }); + }); + + describe('createRole', () => { + it('creates role and returns 201', async () => { + const role = mockRole(); + const deps = createDeps({ createRoleByName: jest.fn().mockResolvedValue(role) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'editor', 
description: 'Can edit' }, + }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(201); + expect(json).toHaveBeenCalledWith({ role }); + expect(deps.createRoleByName).toHaveBeenCalledWith({ + name: 'editor', + description: 'Can edit', + permissions: {}, + }); + }); + + it('passes provided permissions to createRoleByName', async () => { + const perms = { chat: { read: true, write: false } } as unknown as IRole['permissions']; + const role = mockRole({ permissions: perms }); + const deps = createDeps({ createRoleByName: jest.fn().mockResolvedValue(role) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'editor', permissions: perms }, + }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(201); + expect(json).toHaveBeenCalledWith({ role }); + expect(deps.createRoleByName).toHaveBeenCalledWith({ + name: 'editor', + permissions: perms, + }); + }); + + it('returns 400 when name is missing', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ body: {} }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name is required' }); + expect(deps.createRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 400 when name is whitespace-only', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ body: { name: ' ' } }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name is required' }); + }); + + it('returns 400 when name contains control characters', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ 
body: { name: 'bad\x00name' } }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name contains invalid characters' }); + expect(deps.createRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 400 when name is a reserved path segment', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ body: { name: 'members' } }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name is a reserved path segment' }); + expect(deps.createRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 400 when name exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'a'.repeat(501) }, + }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name must not exceed 500 characters' }); + expect(deps.createRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 400 when description exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'editor', description: 'a'.repeat(2001) }, + }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ + error: 'description must not exceed 2000 characters', + }); + expect(deps.createRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 409 when role already exists', async () => { + const deps = createDeps({ + createRoleByName: jest + .fn() + .mockRejectedValue(new RoleConflictError('Role "editor" already exists')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, 
res, status, json } = createReqRes({ body: { name: 'editor' } }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(409); + expect(json).toHaveBeenCalledWith({ error: 'Role "editor" already exists' }); + }); + + it('returns 409 when name is reserved system role', async () => { + const deps = createDeps({ + createRoleByName: jest + .fn() + .mockRejectedValue( + new RoleConflictError('Cannot create role with reserved system name: ADMIN'), + ), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ body: { name: 'ADMIN' } }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(409); + expect(json).toHaveBeenCalledWith({ + error: 'Cannot create role with reserved system name: ADMIN', + }); + }); + + it('returns 500 on unexpected error', async () => { + const deps = createDeps({ + createRoleByName: jest.fn().mockRejectedValue(new Error('db crash')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ body: { name: 'editor' } }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to create role' }); + }); + + it('does not classify unrelated errors as 409', async () => { + const deps = createDeps({ + createRoleByName: jest + .fn() + .mockRejectedValue(new Error('Disk space reserved for system use')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status } = createReqRes({ body: { name: 'test' } }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + }); + + it('returns 400 when description is not a string', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'editor', description: 123 }, + }); + + await handlers.createRole(req, res); + + 
expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'description must be a string' }); + expect(deps.createRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 400 when permissions is an array', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'editor', permissions: [1, 2, 3] }, + }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'permissions must be an object' }); + expect(deps.createRoleByName).not.toHaveBeenCalled(); + }); + }); + + describe('updateRole', () => { + it('updates role and returns 200', async () => { + const role = mockRole({ name: 'senior-editor' }); + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + updateRoleByName: jest.fn().mockResolvedValue(role), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'senior-editor' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ role }); + expect(deps.updateRoleByName).toHaveBeenCalledWith('editor', { name: 'senior-editor' }); + }); + + it('trims name before storage', async () => { + const role = mockRole({ name: 'trimmed' }); + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + updateRoleByName: jest.fn().mockResolvedValue(role), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ + params: { name: 'editor' }, + body: { name: ' trimmed ' }, + }); + + await handlers.updateRole(req, res); + + expect(deps.updateRoleByName).toHaveBeenCalledWith('editor', { name: 'trimmed' }); + }); + + it('migrates users before 
renaming role', async () => { + const role = mockRole({ name: 'new-name' }); + const callOrder: string[] = []; + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + findUserIdsByRole: jest.fn().mockImplementation(() => { + callOrder.push('findUserIdsByRole'); + return Promise.resolve(['uid-1']); + }), + updateUsersByRole: jest.fn().mockImplementation(() => { + callOrder.push('updateUsersByRole'); + return Promise.resolve(); + }), + updateRoleByName: jest.fn().mockImplementation(() => { + callOrder.push('updateRoleByName'); + return Promise.resolve(role); + }), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'new-name' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(deps.findUserIdsByRole).toHaveBeenCalledWith('editor'); + expect(deps.updateUsersByRole).toHaveBeenCalledWith('editor', 'new-name'); + expect(callOrder).toEqual(['findUserIdsByRole', 'updateUsersByRole', 'updateRoleByName']); + }); + + it('does not rename role when user migration fails', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + updateUsersByRole: jest.fn().mockRejectedValue(new Error('migration failed')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'new-name' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(deps.updateRoleByName).not.toHaveBeenCalled(); + }); + + it('does not migrate users when name unchanged', async () => { + const role = mockRole({ description: 'updated' }); + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + updateRoleByName: jest.fn().mockResolvedValue(role), + }); + const 
handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ + params: { name: 'editor' }, + body: { description: 'updated' }, + }); + + await handlers.updateRole(req, res); + + expect(deps.updateUsersByRole).not.toHaveBeenCalled(); + }); + + it('renames and updates description in a single request', async () => { + const role = mockRole({ name: 'senior-editor', description: 'Updated desc' }); + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + updateRoleByName: jest.fn().mockResolvedValue(role), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'senior-editor', description: 'Updated desc' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ role }); + expect(deps.updateUsersByRole).toHaveBeenCalledWith('editor', 'senior-editor'); + expect(deps.updateRoleByName).toHaveBeenCalledWith('editor', { + name: 'senior-editor', + description: 'Updated desc', + }); + }); + + it('returns 403 when renaming a system role', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: SystemRoles.ADMIN }, + body: { name: 'custom-admin' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(403); + expect(json).toHaveBeenCalledWith({ error: 'Cannot rename system role' }); + expect(deps.getRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 403 when renaming to a system role name', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { name: SystemRoles.ADMIN }, + }); + + await handlers.updateRole(req, res); + + 
expect(status).toHaveBeenCalledWith(403); + expect(json).toHaveBeenCalledWith({ error: 'Cannot use a reserved system role name' }); + }); + + it('returns 409 when target name already exists', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'viewer' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(409); + expect(json).toHaveBeenCalledWith({ error: 'Role "viewer" already exists' }); + }); + + it('returns 400 when name is empty string', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { name: '' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name must be a non-empty string' }); + expect(deps.getRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 400 when name is whitespace-only', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { name: ' ' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name must be a non-empty string' }); + }); + + it('returns 400 when name exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'a'.repeat(501) }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name must not exceed 500 characters' }); + 
expect(deps.getRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 404 when role not found', async () => { + const deps = createDeps({ getRoleByName: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'nonexistent' }, + body: { description: 'updated' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + }); + + it('returns 404 when updateRoleByName returns null', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + updateRoleByName: jest.fn().mockResolvedValue(null), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { description: 'updated' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + }); + + it('rolls back user migration when rename fails', async () => { + const ids = ['uid-1', 'uid-2']; + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + findUserIdsByRole: jest.fn().mockResolvedValue(ids), + updateRoleByName: jest.fn().mockResolvedValue(null), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'new-name' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + expect(deps.updateUsersByRole).toHaveBeenCalledTimes(1); + expect(deps.updateUsersByRole).toHaveBeenCalledWith('editor', 'new-name'); + expect(deps.updateUsersRoleByIds).toHaveBeenCalledWith(ids, 'editor'); + }); + + it('rolls back user migration when 
rename throws', async () => { + const ids = ['uid-1', 'uid-2']; + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + findUserIdsByRole: jest.fn().mockResolvedValue(ids), + updateRoleByName: jest.fn().mockRejectedValue(new Error('db crash')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'new-name' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(deps.updateUsersByRole).toHaveBeenCalledTimes(1); + expect(deps.updateUsersByRole).toHaveBeenCalledWith('editor', 'new-name'); + expect(deps.updateUsersRoleByIds).toHaveBeenCalledWith(ids, 'editor'); + }); + + it('logs rollback failure and still returns 500', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + findUserIdsByRole: jest.fn().mockResolvedValue(['uid-1']), + updateUsersRoleByIds: jest.fn().mockRejectedValue(new Error('rollback failed')), + updateRoleByName: jest.fn().mockRejectedValue(new Error('rename failed')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'new-name' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(deps.updateUsersByRole).toHaveBeenCalledTimes(1); + expect(deps.updateUsersRoleByIds).toHaveBeenCalledTimes(1); + }); + + it('returns 400 when description exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { description: 'a'.repeat(2001) }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ + error: 'description 
must not exceed 2000 characters', + }); + expect(deps.getRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 500 on unexpected error', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + updateRoleByName: jest.fn().mockRejectedValue(new Error('db error')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { description: 'updated' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to update role' }); + }); + + it('does not roll back when error occurs before user migration', async () => { + const deps = createDeps({ + getRoleByName: jest + .fn() + .mockResolvedValueOnce(mockRole()) + .mockRejectedValueOnce(new Error('db crash')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'new-name' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(deps.updateUsersByRole).not.toHaveBeenCalled(); + }); + + it('does not migrate users when findUserIdsByRole throws', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + findUserIdsByRole: jest.fn().mockRejectedValue(new Error('db crash')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'new-name' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(deps.updateUsersByRole).not.toHaveBeenCalled(); + expect(deps.updateUsersRoleByIds).not.toHaveBeenCalled(); + }); + + it('returns existing role early when update body has no changes', async () => { + const role = mockRole(); + const deps = createDeps({ + 
getRoleByName: jest.fn().mockResolvedValue(role), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: {}, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ role }); + expect(deps.updateRoleByName).not.toHaveBeenCalled(); + }); + + it('rejects invalid description before making DB calls', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { description: 123 }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'description must be a string' }); + expect(deps.getRoleByName).not.toHaveBeenCalled(); + }); + }); + + describe('updateRolePermissions', () => { + it('updates permissions and returns 200 with updated role', async () => { + const role = mockRole(); + const updatedRole = mockRole({ + permissions: { chat: { read: true, write: true } } as IRole['permissions'], + }); + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(role).mockResolvedValueOnce(updatedRole), + }); + const handlers = createAdminRolesHandlers(deps); + const perms = { chat: { read: true, write: true } }; + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { permissions: perms }, + }); + + await handlers.updateRolePermissions(req, res); + + expect(deps.updateAccessPermissions).toHaveBeenCalledWith('editor', perms, role); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ role: updatedRole }); + }); + + it('returns 400 when permissions is missing', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, 
+ body: {}, + }); + + await handlers.updateRolePermissions(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'permissions object is required' }); + }); + + it('returns 400 when permissions is an array', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { permissions: [1, 2, 3] }, + }); + + await handlers.updateRolePermissions(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'permissions object is required' }); + }); + + it('returns 404 when role not found', async () => { + const deps = createDeps({ getRoleByName: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'nonexistent' }, + body: { permissions: { chat: { read: true } } }, + }); + + await handlers.updateRolePermissions(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + updateAccessPermissions: jest.fn().mockRejectedValue(new Error('db down')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { permissions: { chat: { read: true } } }, + }); + + await handlers.updateRolePermissions(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to update role permissions' }); + }); + }); + + describe('deleteRole', () => { + it('deletes role and returns 200', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 
'editor' } }); + + await handlers.deleteRole(req, res); + + expect(deps.deleteRoleByName).toHaveBeenCalledWith('editor'); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + + it('returns 403 for system role', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: SystemRoles.ADMIN } }); + + await handlers.deleteRole(req, res); + + expect(status).toHaveBeenCalledWith(403); + expect(json).toHaveBeenCalledWith({ error: 'Cannot delete system role' }); + expect(deps.deleteRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 404 when role not found', async () => { + const deps = createDeps({ deleteRoleByName: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 'nonexistent' } }); + + await handlers.deleteRole(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ + deleteRoleByName: jest.fn().mockRejectedValue(new Error('db down')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 'editor' } }); + + await handlers.deleteRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to delete role' }); + }); + }); + + describe('getRoleMembers', () => { + it('returns paginated members with 200', async () => { + const user = mockUser(); + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + listUsersByRole: jest.fn().mockResolvedValue([user]), + countUsersByRole: jest.fn().mockResolvedValue(1), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ 
params: { name: 'editor' } }); + + await handlers.getRoleMembers(req, res); + + expect(deps.listUsersByRole).toHaveBeenCalledWith('editor', { limit: 50, offset: 0 }); + expect(deps.countUsersByRole).toHaveBeenCalledWith('editor'); + expect(status).toHaveBeenCalledWith(200); + const response = json.mock.calls[0][0]; + expect(response.members).toHaveLength(1); + expect(response.members[0]).toEqual({ + userId: validUserId, + name: 'Test User', + email: 'test@example.com', + avatarUrl: 'https://example.com/avatar.png', + }); + expect(response.total).toBe(1); + expect(response.limit).toBe(50); + expect(response.offset).toBe(0); + }); + + it('passes pagination parameters from query', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + countUsersByRole: jest.fn().mockResolvedValue(0), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ + params: { name: 'editor' }, + query: { limit: '10', offset: '20' }, + }); + + await handlers.getRoleMembers(req, res); + + expect(deps.listUsersByRole).toHaveBeenCalledWith('editor', { limit: 10, offset: 20 }); + }); + + it('clamps limit to 200', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + countUsersByRole: jest.fn().mockResolvedValue(0), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ + params: { name: 'editor' }, + query: { limit: '999' }, + }); + + await handlers.getRoleMembers(req, res); + + expect(deps.listUsersByRole).toHaveBeenCalledWith('editor', { limit: 200, offset: 0 }); + }); + + it('does not include joinedAt in response', async () => { + const user = mockUser(); + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + listUsersByRole: jest.fn().mockResolvedValue([user]), + countUsersByRole: jest.fn().mockResolvedValue(1), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, 
json } = createReqRes({ params: { name: 'editor' } }); + + await handlers.getRoleMembers(req, res); + + const member = json.mock.calls[0][0].members[0]; + expect(member).not.toHaveProperty('joinedAt'); + }); + + it('returns empty array when no members', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + countUsersByRole: jest.fn().mockResolvedValue(0), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 'editor' } }); + + await handlers.getRoleMembers(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ members: [], total: 0, limit: 50, offset: 0 }); + }); + + it('returns 404 when role not found', async () => { + const deps = createDeps({ getRoleByName: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 'nonexistent' } }); + + await handlers.getRoleMembers(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + listUsersByRole: jest.fn().mockRejectedValue(new Error('db down')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 'editor' } }); + + await handlers.getRoleMembers(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to get role members' }); + }); + }); + + describe('addRoleMember', () => { + it('adds member and returns 200', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: jest.fn().mockResolvedValue(mockUser({ role: 'viewer' })), + }); + const handlers = createAdminRolesHandlers(deps); + const 
{ req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(deps.updateUser).toHaveBeenCalledWith(validUserId, { role: 'editor' }); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + + it('skips DB write when user already has the target role', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: jest.fn().mockResolvedValue(mockUser({ role: 'editor' })), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + expect(deps.updateUser).not.toHaveBeenCalled(); + }); + + it('returns 400 when userId is missing', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: {}, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'userId is required' }); + }); + + it('returns 400 for invalid ObjectId', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { userId: 'not-valid' }, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Invalid user ID format' }); + }); + + it('returns 404 when role not found', async () => { + const deps = createDeps({ getRoleByName: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminRolesHandlers(deps); + const { req, 
res, status, json } = createReqRes({ + params: { name: 'nonexistent' }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + }); + + it('returns 404 when user not found', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: jest.fn().mockResolvedValue(null), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'User not found' }); + }); + + it('returns 400 when reassigning the last admin to another role', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole({ name: 'editor' })), + findUser: jest.fn().mockResolvedValue(mockUser({ role: SystemRoles.ADMIN })), + countUsersByRole: jest.fn().mockResolvedValue(1), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Cannot remove the last admin user' }); + expect(deps.updateUser).not.toHaveBeenCalled(); + }); + + it('allows reassigning an admin when multiple admins exist', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole({ name: 'editor' })), + findUser: jest.fn().mockResolvedValue(mockUser({ role: SystemRoles.ADMIN })), + countUsersByRole: jest.fn().mockResolvedValue(3), + updateUser: jest.fn().mockResolvedValue(mockUser({ role: 'editor' })), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, 
status } = createReqRes({ + params: { name: 'editor' }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(deps.updateUser).toHaveBeenCalledWith(validUserId, { role: 'editor' }); + }); + + it('rolls back assignment when post-write admin count is zero', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole({ name: 'editor' })), + findUser: jest.fn().mockResolvedValue(mockUser({ role: SystemRoles.ADMIN })), + countUsersByRole: jest.fn().mockResolvedValueOnce(2).mockResolvedValueOnce(0), + updateUser: jest.fn().mockResolvedValue(mockUser()), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(deps.updateUser).toHaveBeenCalledTimes(2); + expect(deps.updateUser).toHaveBeenLastCalledWith(validUserId, { role: SystemRoles.ADMIN }); + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Cannot remove the last admin user' }); + }); + + it('returns 403 when adding to a non-ADMIN system role', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: SystemRoles.USER }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(403); + expect(json).toHaveBeenCalledWith({ + error: 'Cannot directly assign members to a system role', + }); + expect(deps.updateUser).not.toHaveBeenCalled(); + }); + + it('allows promoting a non-admin user to the ADMIN role', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole({ name: SystemRoles.ADMIN })), + findUser: jest.fn().mockResolvedValue(mockUser({ role: 'editor' })), + updateUser: 
jest.fn().mockResolvedValue(mockUser({ role: SystemRoles.ADMIN })), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: SystemRoles.ADMIN }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(deps.updateUser).toHaveBeenCalledWith(validUserId, { role: SystemRoles.ADMIN }); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + + it('returns 500 on unexpected error', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: jest.fn().mockResolvedValue(mockUser({ role: 'viewer' })), + updateUser: jest.fn().mockRejectedValue(new Error('timeout')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to add role member' }); + }); + }); + + describe('removeRoleMember', () => { + it('removes member and returns 200', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: jest.fn().mockResolvedValue(mockUser({ role: 'editor' })), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor', userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(deps.updateUser).toHaveBeenCalledWith(validUserId, { role: SystemRoles.USER }); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + + it('returns 403 when removing from a non-ADMIN system role', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = 
createReqRes({ + params: { name: SystemRoles.USER, userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(403); + expect(json).toHaveBeenCalledWith({ error: 'Cannot remove members from a system role' }); + expect(deps.getRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 400 for invalid ObjectId', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor', userId: 'bad' }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Invalid user ID format' }); + expect(deps.findUser).not.toHaveBeenCalled(); + }); + + it('returns 404 when role not found', async () => { + const deps = createDeps({ getRoleByName: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'nonexistent', userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + expect(deps.findUser).not.toHaveBeenCalled(); + }); + + it('returns 404 when user not found', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: jest.fn().mockResolvedValue(null), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor', userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'User not found' }); + }); + + it('returns 400 when user is not a member of the role', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: 
jest.fn().mockResolvedValue(mockUser({ role: 'other-role' })), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor', userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'User is not a member of this role' }); + expect(deps.updateUser).not.toHaveBeenCalled(); + }); + + it('returns 400 when removing the last admin user', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole({ name: SystemRoles.ADMIN })), + findUser: jest.fn().mockResolvedValue(mockUser({ role: SystemRoles.ADMIN })), + countUsersByRole: jest.fn().mockResolvedValue(1), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: SystemRoles.ADMIN, userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Cannot remove the last admin user' }); + expect(deps.updateUser).not.toHaveBeenCalled(); + }); + + it('allows removing an admin when multiple admins exist', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole({ name: SystemRoles.ADMIN })), + findUser: jest.fn().mockResolvedValue(mockUser({ role: SystemRoles.ADMIN })), + countUsersByRole: jest.fn().mockResolvedValue(3), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: SystemRoles.ADMIN, userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + expect(deps.updateUser).toHaveBeenCalledWith(validUserId, { role: SystemRoles.USER }); + }); + + it('rolls back removal when post-write check finds zero admins', 
async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole({ name: SystemRoles.ADMIN })), + findUser: jest.fn().mockResolvedValue(mockUser({ role: SystemRoles.ADMIN })), + countUsersByRole: jest.fn().mockResolvedValueOnce(2).mockResolvedValueOnce(0), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: SystemRoles.ADMIN, userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Cannot remove the last admin user' }); + expect(deps.updateUser).toHaveBeenCalledTimes(2); + expect(deps.updateUser).toHaveBeenNthCalledWith(1, validUserId, { + role: SystemRoles.USER, + }); + expect(deps.updateUser).toHaveBeenNthCalledWith(2, validUserId, { + role: SystemRoles.ADMIN, + }); + }); + + it('returns 400 even when rollback updateUser throws', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole({ name: SystemRoles.ADMIN })), + findUser: jest.fn().mockResolvedValue(mockUser({ role: SystemRoles.ADMIN })), + countUsersByRole: jest.fn().mockResolvedValueOnce(2).mockResolvedValueOnce(0), + updateUser: jest + .fn() + .mockResolvedValueOnce(mockUser({ role: SystemRoles.USER })) + .mockRejectedValueOnce(new Error('rollback failed')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: SystemRoles.ADMIN, userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Cannot remove the last admin user' }); + expect(deps.updateUser).toHaveBeenCalledTimes(2); + }); + + it('returns 500 on unexpected error', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: jest.fn().mockResolvedValue(mockUser({ 
role: 'editor' })), + updateUser: jest.fn().mockRejectedValue(new Error('timeout')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor', userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to remove role member' }); + }); + }); +}); diff --git a/packages/api/src/admin/roles.ts b/packages/api/src/admin/roles.ts new file mode 100644 index 0000000000..b8c87c23ea --- /dev/null +++ b/packages/api/src/admin/roles.ts @@ -0,0 +1,550 @@ +import { SystemRoles } from 'librechat-data-provider'; +import { logger, isValidObjectIdString, RoleConflictError } from '@librechat/data-schemas'; +import type { IRole, IUser, AdminMember } from '@librechat/data-schemas'; +import type { FilterQuery, Types } from 'mongoose'; +import type { Response } from 'express'; +import type { ServerRequest } from '~/types/http'; +import { parsePagination } from './pagination'; + +const systemRoleValues = new Set(Object.values(SystemRoles)); + +/** Case-insensitive check — the legacy roles route uppercases params. */ +function isSystemRoleName(name: string): boolean { + return systemRoleValues.has(name.toUpperCase()); +} + +const MAX_NAME_LENGTH = 500; +const MAX_DESCRIPTION_LENGTH = 2000; +const CONTROL_CHAR_RE = /\p{Cc}/u; +/** + * Role names that would create semantically ambiguous URLs. + * e.g. GET /api/admin/roles/members — is that "list roles" or "get role named members"? + * Express routing resolves this correctly (single vs multi-segment), but the URLs + * are confusing for API consumers. Keep in sync with sub-path routes in routes/admin/roles.js. 
+ */
+const RESERVED_ROLE_NAMES = new Set(['members', 'permissions']);
+
+function validateNameParam(name: string): string | null {
+  if (!name || typeof name !== 'string') {
+    return 'name parameter is required';
+  }
+  if (name.length > MAX_NAME_LENGTH) {
+    return `name must not exceed ${MAX_NAME_LENGTH} characters`;
+  }
+  if (CONTROL_CHAR_RE.test(name)) {
+    return 'name contains invalid characters';
+  }
+  return null;
+}
+
+function validateRoleName(name: unknown, required: boolean): string | null {
+  if (name === undefined) {
+    return required ? 'name is required' : null;
+  }
+  if (typeof name !== 'string' || !name.trim()) {
+    return required ? 'name is required' : 'name must be a non-empty string';
+  }
+  const trimmed = name.trim();
+  if (trimmed.length > MAX_NAME_LENGTH) {
+    return `name must not exceed ${MAX_NAME_LENGTH} characters`;
+  }
+  if (CONTROL_CHAR_RE.test(trimmed)) {
+    return 'name contains invalid characters';
+  }
+  if (RESERVED_ROLE_NAMES.has(trimmed)) {
+    return 'name is a reserved path segment';
+  }
+  return null;
+}
+
+function validateDescription(description: unknown): string | null {
+  if (description === undefined) {
+    return null;
+  }
+  if (typeof description !== 'string') {
+    return 'description must be a string';
+  }
+  if (description.length > MAX_DESCRIPTION_LENGTH) {
+    return `description must not exceed ${MAX_DESCRIPTION_LENGTH} characters`;
+  }
+  return null;
+}
+
+interface RoleNameParams {
+  name: string;
+}
+
+interface RoleMemberParams extends RoleNameParams {
+  userId: string;
+}
+
+export type RoleListItem = { _id: Types.ObjectId | string; name: string; description?: string };
+
+export interface AdminRolesDeps {
+  listRoles: (options?: { limit?: number; offset?: number }) => Promise<RoleListItem[]>;
+  countRoles: () => Promise<number>;
+  getRoleByName: (name: string, fields?: string | string[] | null) => Promise<IRole | null>;
+  createRoleByName: (roleData: Partial<IRole>) => Promise<IRole>;
+  updateRoleByName: (name: string, updates: Partial<IRole>) => Promise<IRole | null>;
+
updateAccessPermissions: (
+    name: string,
+    perms: Record<string, Record<string, boolean>>,
+    roleData?: IRole,
+  ) => Promise<void>;
+  deleteRoleByName: (name: string) => Promise<IRole | null>;
+  findUser: (
+    criteria: FilterQuery<IUser>,
+    fields?: string | string[] | null,
+  ) => Promise<IUser | null>;
+  updateUser: (userId: string, data: Partial<IUser>) => Promise<IUser | null>;
+  updateUsersByRole: (oldRole: string, newRole: string) => Promise<void>;
+  findUserIdsByRole: (roleName: string) => Promise<string[]>;
+  updateUsersRoleByIds: (userIds: string[], newRole: string) => Promise<void>;
+  listUsersByRole: (
+    roleName: string,
+    options?: { limit?: number; offset?: number },
+  ) => Promise<IUser[]>;
+  countUsersByRole: (roleName: string) => Promise<number>;
+}
+
+export function createAdminRolesHandlers(deps: AdminRolesDeps) {
+  const {
+    listRoles,
+    countRoles,
+    getRoleByName,
+    createRoleByName,
+    updateRoleByName,
+    updateAccessPermissions,
+    deleteRoleByName,
+    findUser,
+    updateUser,
+    updateUsersByRole,
+    findUserIdsByRole,
+    updateUsersRoleByIds,
+    listUsersByRole,
+    countUsersByRole,
+  } = deps;
+
+  async function listRolesHandler(req: ServerRequest, res: Response) {
+    try {
+      const { limit, offset } = parsePagination(req.query);
+      const [roles, total] = await Promise.all([listRoles({ limit, offset }), countRoles()]);
+      return res.status(200).json({ roles, total, limit, offset });
+    } catch (error) {
+      logger.error('[adminRoles] listRoles error:', error);
+      return res.status(500).json({ error: 'Failed to list roles' });
+    }
+  }
+
+  async function getRoleHandler(req: ServerRequest, res: Response) {
+    try {
+      const { name } = req.params as RoleNameParams;
+      const paramError = validateNameParam(name);
+      if (paramError) {
+        return res.status(400).json({ error: paramError });
+      }
+      const role = await getRoleByName(name);
+      if (!role) {
+        return res.status(404).json({ error: 'Role not found' });
+      }
+      return res.status(200).json({ role });
+    } catch (error) {
+      logger.error('[adminRoles] getRole error:', error);
+      return res.status(500).json({ error: 'Failed to get role' });
+    }
+  }
+ + async function createRoleHandler(req: ServerRequest, res: Response) { + try { + const { name, description, permissions } = req.body as { + name?: string; + description?: string; + permissions?: IRole['permissions']; + }; + const nameError = validateRoleName(name, true); + if (nameError) { + return res.status(400).json({ error: nameError }); + } + const descError = validateDescription(description); + if (descError) { + return res.status(400).json({ error: descError }); + } + if ( + permissions !== undefined && + (permissions === null || typeof permissions !== 'object' || Array.isArray(permissions)) + ) { + return res.status(400).json({ error: 'permissions must be an object' }); + } + const roleData: Partial = { + name: (name as string).trim(), + permissions: permissions ?? {}, + }; + if (description !== undefined) { + roleData.description = description; + } + const role = await createRoleByName(roleData); + return res.status(201).json({ role }); + } catch (error) { + logger.error('[adminRoles] createRole error:', error); + if (error instanceof RoleConflictError) { + return res.status(409).json({ error: error.message }); + } + return res.status(500).json({ error: 'Failed to create role' }); + } + } + + async function rollbackMigratedUsers( + migratedIds: string[], + currentName: string, + newName: string, + ): Promise { + if (migratedIds.length === 0) { + return; + } + try { + await updateUsersRoleByIds(migratedIds, currentName); + } catch (rollbackError) { + logger.error( + `[adminRoles] CRITICAL: rename rollback failed — ${migratedIds.length} users have dangling role "${newName}": [${migratedIds.join(', ')}]`, + rollbackError, + ); + } + } + + /** + * Renames a role by migrating users to the new name and updating the role document. + * + * The ID snapshot from `findUserIdsByRole` is a point-in-time read. 
Users assigned + * to `currentName` between the snapshot and the bulk `updateUsersByRole` write will + * be moved to `newName` but will NOT be reverted on rollback. This window is narrow + * and only relevant under concurrent admin operations during a rename. + */ + async function renameRole( + currentName: string, + newName: string, + extraUpdates?: Partial, + ): Promise { + const migratedIds = await findUserIdsByRole(currentName); + await updateUsersByRole(currentName, newName); + try { + const updates: Partial = { name: newName, ...extraUpdates }; + const role = await updateRoleByName(currentName, updates); + if (!role) { + await rollbackMigratedUsers(migratedIds, currentName, newName); + } + return role; + } catch (error) { + await rollbackMigratedUsers(migratedIds, currentName, newName); + throw error; + } + } + + async function updateRoleHandler(req: ServerRequest, res: Response) { + try { + const { name } = req.params as RoleNameParams; + const paramError = validateNameParam(name); + if (paramError) { + return res.status(400).json({ error: paramError }); + } + const body = req.body as { name?: string; description?: string }; + const nameError = validateRoleName(body.name, false); + if (nameError) { + return res.status(400).json({ error: nameError }); + } + const descError = validateDescription(body.description); + if (descError) { + return res.status(400).json({ error: descError }); + } + + const trimmedName = body.name?.trim() ?? 
''; + const isRename = trimmedName !== '' && trimmedName !== name; + + if (isRename && isSystemRoleName(name)) { + return res.status(403).json({ error: 'Cannot rename system role' }); + } + if (isRename && isSystemRoleName(trimmedName)) { + return res.status(403).json({ error: 'Cannot use a reserved system role name' }); + } + + const existing = await getRoleByName(name); + if (!existing) { + return res.status(404).json({ error: 'Role not found' }); + } + + if (isRename) { + const duplicate = await getRoleByName(trimmedName); + if (duplicate) { + return res.status(409).json({ error: `Role "${trimmedName}" already exists` }); + } + } + + const updates: Partial = {}; + if (isRename) { + updates.name = trimmedName; + } + if (body.description !== undefined) { + updates.description = body.description; + } + + if (Object.keys(updates).length === 0) { + return res.status(200).json({ role: existing }); + } + + if (isRename) { + const descUpdate = + body.description !== undefined ? { description: body.description } : undefined; + const role = await renameRole(name, trimmedName, descUpdate); + if (!role) { + return res.status(404).json({ error: 'Role not found' }); + } + return res.status(200).json({ role }); + } + + const role = await updateRoleByName(name, updates); + if (!role) { + return res.status(404).json({ error: 'Role not found' }); + } + return res.status(200).json({ role }); + } catch (error) { + if (error instanceof RoleConflictError) { + return res.status(409).json({ error: error.message }); + } + logger.error('[adminRoles] updateRole error:', error); + return res.status(500).json({ error: 'Failed to update role' }); + } + } + + /** + * The re-fetch via `getRoleByName` after `updateAccessPermissions` depends on the + * callee having written the updated document to the role cache. If the cache layer + * is refactored to stop writing from within `updateAccessPermissions`, this handler + * must be updated to perform an explicit uncached DB read. 
+ */ + async function updateRolePermissionsHandler(req: ServerRequest, res: Response) { + try { + const { name } = req.params as RoleNameParams; + const paramError = validateNameParam(name); + if (paramError) { + return res.status(400).json({ error: paramError }); + } + const { permissions } = req.body as { + permissions: Record>; + }; + + if (!permissions || typeof permissions !== 'object' || Array.isArray(permissions)) { + return res.status(400).json({ error: 'permissions object is required' }); + } + + const existing = await getRoleByName(name); + if (!existing) { + return res.status(404).json({ error: 'Role not found' }); + } + + await updateAccessPermissions(name, permissions, existing); + const updated = await getRoleByName(name); + if (!updated) { + return res.status(404).json({ error: 'Role not found' }); + } + return res.status(200).json({ role: updated }); + } catch (error) { + logger.error('[adminRoles] updateRolePermissions error:', error); + return res.status(500).json({ error: 'Failed to update role permissions' }); + } + } + + async function deleteRoleHandler(req: ServerRequest, res: Response) { + try { + const { name } = req.params as RoleNameParams; + const paramError = validateNameParam(name); + if (paramError) { + return res.status(400).json({ error: paramError }); + } + if (isSystemRoleName(name)) { + return res.status(403).json({ error: 'Cannot delete system role' }); + } + + const deleted = await deleteRoleByName(name); + if (!deleted) { + return res.status(404).json({ error: 'Role not found' }); + } + return res.status(200).json({ success: true }); + } catch (error) { + logger.error('[adminRoles] deleteRole error:', error); + return res.status(500).json({ error: 'Failed to delete role' }); + } + } + + async function getRoleMembersHandler(req: ServerRequest, res: Response) { + try { + const { name } = req.params as RoleNameParams; + const paramError = validateNameParam(name); + if (paramError) { + return res.status(400).json({ error: 
paramError }); + } + const existing = await getRoleByName(name); + if (!existing) { + return res.status(404).json({ error: 'Role not found' }); + } + + const { limit, offset } = parsePagination(req.query); + + const [users, total] = await Promise.all([ + listUsersByRole(name, { limit, offset }), + countUsersByRole(name), + ]); + const members: AdminMember[] = users.map((u) => ({ + userId: u._id?.toString() ?? '', + name: u.name ?? u._id?.toString() ?? '', + email: u.email ?? '', + avatarUrl: u.avatar, + })); + return res.status(200).json({ members, total, limit, offset }); + } catch (error) { + logger.error('[adminRoles] getRoleMembers error:', error); + return res.status(500).json({ error: 'Failed to get role members' }); + } + } + + async function addRoleMemberHandler(req: ServerRequest, res: Response) { + try { + const { name } = req.params as RoleNameParams; + const paramError = validateNameParam(name); + if (paramError) { + return res.status(400).json({ error: paramError }); + } + const { userId } = req.body as { userId: string }; + + if (!userId || typeof userId !== 'string') { + return res.status(400).json({ error: 'userId is required' }); + } + if (!isValidObjectIdString(userId)) { + return res.status(400).json({ error: 'Invalid user ID format' }); + } + + if (isSystemRoleName(name) && name !== SystemRoles.ADMIN) { + return res.status(403).json({ error: 'Cannot directly assign members to a system role' }); + } + + const existing = await getRoleByName(name); + if (!existing) { + return res.status(404).json({ error: 'Role not found' }); + } + + const user = await findUser({ _id: userId }); + if (!user) { + return res.status(404).json({ error: 'User not found' }); + } + + if (user.role === name) { + return res.status(200).json({ success: true }); + } + + if (user.role === SystemRoles.ADMIN && name !== SystemRoles.ADMIN) { + const adminCount = await countUsersByRole(SystemRoles.ADMIN); + if (adminCount <= 1) { + return res.status(400).json({ error: 'Cannot 
remove the last admin user' }); + } + } + + const updated = await updateUser(userId, { role: name }); + if (!updated) { + return res.status(404).json({ error: 'User not found' }); + } + + if (user.role === SystemRoles.ADMIN && name !== SystemRoles.ADMIN) { + const postCount = await countUsersByRole(SystemRoles.ADMIN); + if (postCount === 0) { + try { + await updateUser(userId, { role: SystemRoles.ADMIN }); + } catch (rollbackError) { + logger.error( + `[adminRoles] CRITICAL: admin rollback failed in addRoleMember for user ${userId}:`, + rollbackError, + ); + } + return res.status(400).json({ error: 'Cannot remove the last admin user' }); + } + } + + return res.status(200).json({ success: true }); + } catch (error) { + logger.error('[adminRoles] addRoleMember error:', error); + return res.status(500).json({ error: 'Failed to add role member' }); + } + } + + async function removeRoleMemberHandler(req: ServerRequest, res: Response) { + try { + const { name, userId } = req.params as RoleMemberParams; + const paramError = validateNameParam(name); + if (paramError) { + return res.status(400).json({ error: paramError }); + } + if (!isValidObjectIdString(userId)) { + return res.status(400).json({ error: 'Invalid user ID format' }); + } + + if (isSystemRoleName(name) && name !== SystemRoles.ADMIN) { + return res.status(403).json({ error: 'Cannot remove members from a system role' }); + } + + const existing = await getRoleByName(name); + if (!existing) { + return res.status(404).json({ error: 'Role not found' }); + } + + const user = await findUser({ _id: userId }); + if (!user) { + return res.status(404).json({ error: 'User not found' }); + } + + if (user.role !== name) { + return res.status(400).json({ error: 'User is not a member of this role' }); + } + + if (name === SystemRoles.ADMIN) { + const adminCount = await countUsersByRole(SystemRoles.ADMIN); + if (adminCount <= 1) { + return res.status(400).json({ error: 'Cannot remove the last admin user' }); + } + } + + const 
removed = await updateUser(userId, { role: SystemRoles.USER }); + if (!removed) { + return res.status(404).json({ error: 'User not found' }); + } + + if (name === SystemRoles.ADMIN) { + const postCount = await countUsersByRole(SystemRoles.ADMIN); + if (postCount === 0) { + try { + await updateUser(userId, { role: SystemRoles.ADMIN }); + } catch (rollbackError) { + logger.error( + `[adminRoles] CRITICAL: admin rollback failed for user ${userId}:`, + rollbackError, + ); + } + return res.status(400).json({ error: 'Cannot remove the last admin user' }); + } + } + + return res.status(200).json({ success: true }); + } catch (error) { + logger.error('[adminRoles] removeRoleMember error:', error); + return res.status(500).json({ error: 'Failed to remove role member' }); + } + } + + return { + listRoles: listRolesHandler, + getRole: getRoleHandler, + createRole: createRoleHandler, + updateRole: updateRoleHandler, + updateRolePermissions: updateRolePermissionsHandler, + deleteRole: deleteRoleHandler, + getRoleMembers: getRoleMembersHandler, + addRoleMember: addRoleMemberHandler, + removeRoleMember: removeRoleMemberHandler, + }; +} diff --git a/packages/api/src/agents/context.spec.ts b/packages/api/src/agents/context.spec.ts index c5358209c7..1d995a52bb 100644 --- a/packages/api/src/agents/context.spec.ts +++ b/packages/api/src/agents/context.spec.ts @@ -154,10 +154,10 @@ describe('Agent Context Utilities', () => { ); expect(result).toBe(instructions); - expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith([ - 'server1', - 'server2', - ]); + expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith( + ['server1', 'server2'], + undefined, + ); expect(mockLogger.debug).toHaveBeenCalledWith( '[AgentContext] Fetched MCP instructions for servers:', ['server1', 'server2'], @@ -345,9 +345,10 @@ describe('Agent Context Utilities', () => { logger: mockLogger, }); - expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith([ - 
'ephemeral-server', - ]); + expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith( + ['ephemeral-server'], + undefined, + ); expect(agent.instructions).toContain('Ephemeral MCP'); }); @@ -375,7 +376,10 @@ describe('Agent Context Utilities', () => { logger: mockLogger, }); - expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith(['agent-server']); + expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith( + ['agent-server'], + undefined, + ); }); it('should work without agentId', async () => { diff --git a/packages/api/src/agents/context.ts b/packages/api/src/agents/context.ts index ebae2e0f9f..c526fd13fe 100644 --- a/packages/api/src/agents/context.ts +++ b/packages/api/src/agents/context.ts @@ -1,8 +1,9 @@ -import { DynamicStructuredTool } from '@langchain/core/tools'; import { Constants } from 'librechat-data-provider'; +import { DynamicStructuredTool } from '@langchain/core/tools'; import type { Agent, TEphemeralAgent } from 'librechat-data-provider'; import type { LCTool } from '@librechat/agents'; import type { Logger } from 'winston'; +import type { ParsedServerConfig } from '~/mcp/types'; import type { MCPManager } from '~/mcp/MCPManager'; /** @@ -63,12 +64,16 @@ export async function getMCPInstructionsForServers( mcpServers: string[], mcpManager: MCPManager, logger?: Logger, + configServers?: Record, ): Promise { if (!mcpServers.length) { return ''; } try { - const mcpInstructions = await mcpManager.formatInstructionsForContext(mcpServers); + const mcpInstructions = await mcpManager.formatInstructionsForContext( + mcpServers, + configServers, + ); if (mcpInstructions && logger) { logger.debug('[AgentContext] Fetched MCP instructions for servers:', mcpServers); } @@ -125,6 +130,7 @@ export async function applyContextToAgent({ ephemeralAgent, agentId, logger, + configServers, }: { agent: AgentWithTools; sharedRunContext: string; @@ -132,12 +138,18 @@ export async function applyContextToAgent({ 
ephemeralAgent?: TEphemeralAgent; agentId?: string; logger?: Logger; + configServers?: Record; }): Promise { const baseInstructions = agent.instructions || ''; try { const mcpServers = ephemeralAgent?.mcp?.length ? ephemeralAgent.mcp : extractMCPServers(agent); - const mcpInstructions = await getMCPInstructionsForServers(mcpServers, mcpManager, logger); + const mcpInstructions = await getMCPInstructionsForServers( + mcpServers, + mcpManager, + logger, + configServers, + ); agent.instructions = buildAgentInstructions({ sharedRunContext, diff --git a/packages/api/src/app/index.ts b/packages/api/src/app/index.ts index b95193e943..8d8802f016 100644 --- a/packages/api/src/app/index.ts +++ b/packages/api/src/app/index.ts @@ -1,4 +1,6 @@ +export * from './service'; export * from './config'; export * from './permissions'; export * from './cdn'; export * from './checks'; +export * from './resolve'; diff --git a/packages/api/src/app/permissions.ts b/packages/api/src/app/permissions.ts index 5a557adfcf..92da1342ce 100644 --- a/packages/api/src/app/permissions.ts +++ b/packages/api/src/app/permissions.ts @@ -1,4 +1,4 @@ -import { logger } from '@librechat/data-schemas'; +import { logger, tenantStorage, SYSTEM_TENANT_ID } from '@librechat/data-schemas'; import { SystemRoles, Permissions, @@ -54,6 +54,7 @@ export async function updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions, + tenantId, }: { appConfig: AppConfig; getRoleByName: (roleName: string, fieldsToSelect?: string | string[]) => Promise; @@ -63,7 +64,19 @@ export async function updateInterfacePermissions({ roleData?: IRole | null, ) => Promise; -}) { + /** + * Optional tenant ID for scoping role updates to a specific tenant. + * When provided (and not SYSTEM_TENANT_ID), runs inside `tenantStorage.run({ tenantId })`. + * When omitted or SYSTEM_TENANT_ID, uses the caller's existing ALS context. 
+ */ + tenantId?: string; +}): Promise { + if (tenantId && tenantId !== SYSTEM_TENANT_ID) { + return tenantStorage.run({ tenantId }, async () => + updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions }), + ); + } + const loadedInterface = appConfig?.interfaceConfig; if (!loadedInterface) { return; diff --git a/packages/api/src/app/resolve.spec.ts b/packages/api/src/app/resolve.spec.ts new file mode 100644 index 0000000000..d7585198a0 --- /dev/null +++ b/packages/api/src/app/resolve.spec.ts @@ -0,0 +1,95 @@ +import type { AsyncLocalStorage } from 'async_hooks'; + +jest.mock('@librechat/data-schemas', () => { + // eslint-disable-next-line @typescript-eslint/no-require-imports + const { AsyncLocalStorage: ALS } = require('async_hooks'); + return { tenantStorage: new ALS() }; +}); + +import { resolveAppConfigForUser } from './resolve'; + +const { tenantStorage } = jest.requireMock('@librechat/data-schemas') as { + tenantStorage: AsyncLocalStorage<{ tenantId?: string }>; +}; + +describe('resolveAppConfigForUser', () => { + const mockGetAppConfig = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + mockGetAppConfig.mockResolvedValue({ registration: {} }); + }); + + it('calls getAppConfig with baseOnly when user is null', async () => { + await resolveAppConfigForUser(mockGetAppConfig, null); + expect(mockGetAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); + + it('calls getAppConfig with baseOnly when user is undefined', async () => { + await resolveAppConfigForUser(mockGetAppConfig, undefined); + expect(mockGetAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); + + it('calls getAppConfig with baseOnly when user has no tenantId', async () => { + await resolveAppConfigForUser(mockGetAppConfig, { role: 'USER' }); + expect(mockGetAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); + + it('calls getAppConfig with role and tenantId when user has tenantId', async () => { + await resolveAppConfigForUser(mockGetAppConfig, 
{ tenantId: 'tenant-a', role: 'USER' }); + expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'USER', tenantId: 'tenant-a' }); + }); + + it('calls tenantStorage.run for tenant users but not for non-tenant users', async () => { + const runSpy = jest.spyOn(tenantStorage, 'run'); + + await resolveAppConfigForUser(mockGetAppConfig, { role: 'USER' }); + expect(runSpy).not.toHaveBeenCalled(); + + await resolveAppConfigForUser(mockGetAppConfig, { tenantId: 'tenant-b', role: 'ADMIN' }); + expect(runSpy).toHaveBeenCalledWith({ tenantId: 'tenant-b' }, expect.any(Function)); + + runSpy.mockRestore(); + }); + + it('makes tenantId available via ALS inside getAppConfig', async () => { + let capturedContext: { tenantId?: string } | undefined; + mockGetAppConfig.mockImplementation(async () => { + capturedContext = tenantStorage.getStore(); + return { registration: {} }; + }); + + await resolveAppConfigForUser(mockGetAppConfig, { tenantId: 'tenant-c', role: 'USER' }); + + expect(capturedContext).toEqual({ tenantId: 'tenant-c' }); + }); + + it('returns the config from getAppConfig', async () => { + const tenantConfig = { registration: { allowedDomains: ['example.com'] } }; + mockGetAppConfig.mockResolvedValue(tenantConfig); + + const result = await resolveAppConfigForUser(mockGetAppConfig, { + tenantId: 'tenant-d', + role: 'USER', + }); + + expect(result).toBe(tenantConfig); + }); + + it('calls getAppConfig with role undefined when user has tenantId but no role', async () => { + await resolveAppConfigForUser(mockGetAppConfig, { tenantId: 'tenant-e' }); + expect(mockGetAppConfig).toHaveBeenCalledWith({ role: undefined, tenantId: 'tenant-e' }); + }); + + it('propagates rejection from getAppConfig for tenant users', async () => { + mockGetAppConfig.mockRejectedValue(new Error('config unavailable')); + await expect( + resolveAppConfigForUser(mockGetAppConfig, { tenantId: 'tenant-f', role: 'USER' }), + ).rejects.toThrow('config unavailable'); + }); + + it('propagates rejection from 
getAppConfig for baseOnly path', async () => {
+    mockGetAppConfig.mockRejectedValue(new Error('cache failure'));
+    await expect(resolveAppConfigForUser(mockGetAppConfig, null)).rejects.toThrow('cache failure');
+  });
+});
diff --git a/packages/api/src/app/resolve.ts b/packages/api/src/app/resolve.ts
new file mode 100644
index 0000000000..0810400222
--- /dev/null
+++ b/packages/api/src/app/resolve.ts
@@ -0,0 +1,39 @@
+import { tenantStorage } from '@librechat/data-schemas';
+import type { AppConfig } from '@librechat/data-schemas';
+
+interface UserForConfigResolution {
+  tenantId?: string;
+  role?: string;
+}
+
+type GetAppConfig = (opts: {
+  role?: string;
+  tenantId?: string;
+  baseOnly?: boolean;
+}) => Promise<AppConfig>;
+
+/**
+ * Resolves AppConfig scoped to the given user's tenant when available,
+ * falling back to YAML-only base config for new users or non-tenant deployments.
+ *
+ * Auth flows only apply role-level overrides (userId is not passed) because
+ * user/group principal resolution requires heavier DB work that is deferred
+ * to post-authentication config calls.
+ *
+ * `tenantId` is propagated through two channels that serve different purposes:
+ * - `tenantStorage.run()` sets the ALS context so Mongoose's `applyTenantIsolation`
+ *   plugin scopes any DB queries (e.g., `getApplicableConfigs`) to the tenant.
+ * - The explicit `tenantId` parameter to `getAppConfig` is used for cache-key
+ *   computation in `overrideCacheKey()`. Both channels are required.
+ */
+export async function resolveAppConfigForUser(
+  getAppConfig: GetAppConfig,
+  user: UserForConfigResolution | null | undefined,
+): Promise<AppConfig> {
+  if (user?.tenantId) {
+    return tenantStorage.run({ tenantId: user.tenantId }, async () =>
+      getAppConfig({ role: user.role, tenantId: user.tenantId }),
+    );
+  }
+  return getAppConfig({ baseOnly: true });
+}
diff --git a/packages/api/src/app/service.spec.ts b/packages/api/src/app/service.spec.ts
new file mode 100644
index 0000000000..c410783793
--- /dev/null
+++ b/packages/api/src/app/service.spec.ts
@@ -0,0 +1,346 @@
+import type { AppConfig } from '@librechat/data-schemas';
+import { createAppConfigService } from './service';
+
+/** Extends AppConfig with mock fields used by merge behavior tests. */
+interface TestConfig extends AppConfig {
+  restricted?: boolean;
+  x?: string;
+  interface?: { endpointsMenu?: boolean; [key: string]: boolean | undefined };
+}
+
+/**
+ * Creates a mock cache that simulates Keyv's namespace behavior.
+ * Keyv stores keys internally as `namespace:key` but its API (get/set/delete)
+ * accepts un-namespaced keys and auto-prepends the namespace.
+ */ +function createMockCache(namespace = 'app_config') { + const store = new Map(); + return { + get: jest.fn((key) => Promise.resolve(store.get(`${namespace}:${key}`))), + set: jest.fn((key, value) => { + store.set(`${namespace}:${key}`, value); + return Promise.resolve(undefined); + }), + delete: jest.fn((key) => { + store.delete(`${namespace}:${key}`); + return Promise.resolve(true); + }), + /** Mimic Keyv's opts.store structure for key enumeration in clearOverrideCache */ + opts: { store: { keys: () => store.keys() } } as { + store?: { keys: () => IterableIterator }; + }, + _store: store, + }; +} + +function createDeps(overrides = {}) { + const cache = createMockCache(); + const baseConfig = { interface: { endpointsMenu: true }, endpoints: ['openAI'] }; + + return { + loadBaseConfig: jest.fn().mockResolvedValue(baseConfig), + setCachedTools: jest.fn().mockResolvedValue(undefined), + getCache: jest.fn().mockReturnValue(cache), + cacheKeys: { APP_CONFIG: 'app_config' }, + getApplicableConfigs: jest.fn().mockResolvedValue([]), + getUserPrincipals: jest.fn().mockResolvedValue([ + { principalType: 'role', principalId: 'USER' }, + { principalType: 'user', principalId: 'uid1' }, + ]), + _cache: cache, + _baseConfig: baseConfig, + ...overrides, + }; +} + +describe('createAppConfigService', () => { + describe('getAppConfig', () => { + it('loads base config on first call', async () => { + const deps = createDeps(); + const { getAppConfig } = createAppConfigService(deps); + + const config = await getAppConfig(); + + expect(deps.loadBaseConfig).toHaveBeenCalledTimes(1); + expect(config).toEqual(deps._baseConfig); + }); + + it('caches base config — does not reload on second call', async () => { + const deps = createDeps(); + const { getAppConfig } = createAppConfigService(deps); + + await getAppConfig(); + await getAppConfig(); + + expect(deps.loadBaseConfig).toHaveBeenCalledTimes(1); + }); + + it('baseOnly returns YAML config without DB queries', async () => { + const 
deps = createDeps({ + getApplicableConfigs: jest + .fn() + .mockResolvedValue([ + { priority: 10, overrides: { interface: { endpointsMenu: false } }, isActive: true }, + ]), + }); + const { getAppConfig } = createAppConfigService(deps); + + const config = await getAppConfig({ baseOnly: true }); + + expect(deps.loadBaseConfig).toHaveBeenCalledTimes(1); + expect(deps.getApplicableConfigs).not.toHaveBeenCalled(); + expect(config).toEqual(deps._baseConfig); + }); + + it('reloads base config when refresh is true', async () => { + const deps = createDeps(); + const { getAppConfig } = createAppConfigService(deps); + + await getAppConfig(); + await getAppConfig({ refresh: true }); + + expect(deps.loadBaseConfig).toHaveBeenCalledTimes(2); + }); + + it('queries DB for applicable configs', async () => { + const deps = createDeps(); + const { getAppConfig } = createAppConfigService(deps); + + await getAppConfig({ role: 'ADMIN' }); + + expect(deps.getApplicableConfigs).toHaveBeenCalled(); + }); + + it('caches empty result — does not re-query DB on second call', async () => { + const deps = createDeps({ getApplicableConfigs: jest.fn().mockResolvedValue([]) }); + const { getAppConfig } = createAppConfigService(deps); + + await getAppConfig({ role: 'USER' }); + await getAppConfig({ role: 'USER' }); + + expect(deps.getApplicableConfigs).toHaveBeenCalledTimes(1); + }); + + it('merges DB configs when found', async () => { + const deps = createDeps({ + getApplicableConfigs: jest + .fn() + .mockResolvedValue([ + { priority: 10, overrides: { interface: { endpointsMenu: false } }, isActive: true }, + ]), + }); + const { getAppConfig } = createAppConfigService(deps); + + const config = await getAppConfig({ role: 'ADMIN' }); + + // Test data uses mock fields that don't exist on AppConfig to verify merge behavior + const merged = config as TestConfig; + expect(merged.interface?.endpointsMenu).toBe(false); + expect(merged.endpoints).toEqual(['openAI']); + }); + + it('caches merged result 
with TTL', async () => { + const deps = createDeps({ + getApplicableConfigs: jest + .fn() + .mockResolvedValue([{ priority: 10, overrides: { x: 1 }, isActive: true }]), + }); + const { getAppConfig } = createAppConfigService(deps); + + await getAppConfig({ role: 'ADMIN' }); + await getAppConfig({ role: 'ADMIN' }); + + expect(deps.getApplicableConfigs).toHaveBeenCalledTimes(1); + }); + + it('uses separate cache keys per userId (no cross-user contamination)', async () => { + const deps = createDeps({ + getApplicableConfigs: jest + .fn() + .mockResolvedValue([ + { priority: 100, overrides: { x: 'user-specific' }, isActive: true }, + ]), + }); + const { getAppConfig } = createAppConfigService(deps); + + await getAppConfig({ userId: 'uid1' }); + await getAppConfig({ userId: 'uid2' }); + + expect(deps.getApplicableConfigs).toHaveBeenCalledTimes(2); + }); + + it('userId without role gets its own cache key', async () => { + const deps = createDeps({ + getApplicableConfigs: jest + .fn() + .mockResolvedValue([{ priority: 100, overrides: { y: 1 }, isActive: true }]), + }); + const { getAppConfig } = createAppConfigService(deps); + + await getAppConfig({ userId: 'uid1' }); + + const cachedKeys = [...deps._cache._store.keys()]; + const overrideKey = cachedKeys.find((k) => k.includes('_OVERRIDE_:')); + expect(overrideKey).toBe('app_config:_OVERRIDE_:__default__:uid1'); + }); + + it('tenantId is included in cache key to prevent cross-tenant contamination', async () => { + const deps = createDeps({ + getApplicableConfigs: jest + .fn() + .mockResolvedValue([{ priority: 10, overrides: { x: 1 }, isActive: true }]), + }); + const { getAppConfig } = createAppConfigService(deps); + + await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-a' }); + await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-b' }); + + expect(deps.getApplicableConfigs).toHaveBeenCalledTimes(2); + }); + + it('base-only empty result does not block subsequent scoped queries with results', async () => { + const 
mockGetConfigs = jest.fn().mockResolvedValue([]); + const deps = createDeps({ getApplicableConfigs: mockGetConfigs }); + const { getAppConfig } = createAppConfigService(deps); + + await getAppConfig(); + + mockGetConfigs.mockResolvedValueOnce([ + { priority: 10, overrides: { restricted: true }, isActive: true }, + ]); + const config = await getAppConfig({ role: 'ADMIN' }); + + expect(mockGetConfigs).toHaveBeenCalledTimes(2); + expect((config as TestConfig).restricted).toBe(true); + }); + + it('does not short-circuit other users when one user has no overrides', async () => { + const mockGetConfigs = jest.fn().mockResolvedValue([]); + const deps = createDeps({ getApplicableConfigs: mockGetConfigs }); + const { getAppConfig } = createAppConfigService(deps); + + await getAppConfig({ role: 'USER' }); + expect(mockGetConfigs).toHaveBeenCalledTimes(1); + + mockGetConfigs.mockResolvedValueOnce([ + { priority: 10, overrides: { x: 'admin-only' }, isActive: true }, + ]); + const config = await getAppConfig({ role: 'ADMIN' }); + + expect(mockGetConfigs).toHaveBeenCalledTimes(2); + expect((config as TestConfig).x).toBe('admin-only'); + }); + + it('falls back to base config on getApplicableConfigs error', async () => { + const deps = createDeps({ + getApplicableConfigs: jest.fn().mockRejectedValue(new Error('DB down')), + }); + const { getAppConfig } = createAppConfigService(deps); + + const config = await getAppConfig({ role: 'ADMIN' }); + + expect(config).toEqual(deps._baseConfig); + }); + + it('calls getUserPrincipals when userId is provided', async () => { + const deps = createDeps(); + const { getAppConfig } = createAppConfigService(deps); + + await getAppConfig({ role: 'USER', userId: 'uid1' }); + + expect(deps.getUserPrincipals).toHaveBeenCalledWith({ + userId: 'uid1', + role: 'USER', + }); + }); + + it('does not call getUserPrincipals when only role is provided', async () => { + const deps = createDeps(); + const { getAppConfig } = createAppConfigService(deps); + + await 
getAppConfig({ role: 'ADMIN' }); + + expect(deps.getUserPrincipals).not.toHaveBeenCalled(); + }); + }); + + describe('clearAppConfigCache', () => { + it('clears base config so it reloads on next call', async () => { + const deps = createDeps(); + const { getAppConfig, clearAppConfigCache } = createAppConfigService(deps); + + await getAppConfig(); + expect(deps.loadBaseConfig).toHaveBeenCalledTimes(1); + + await clearAppConfigCache(); + await getAppConfig(); + expect(deps.loadBaseConfig).toHaveBeenCalledTimes(2); + }); + }); + + describe('clearOverrideCache', () => { + it('clears all override caches when no tenantId is provided', async () => { + const deps = createDeps({ + getApplicableConfigs: jest + .fn() + .mockResolvedValue([{ priority: 10, overrides: { x: 1 }, isActive: true }]), + }); + const { getAppConfig, clearOverrideCache } = createAppConfigService(deps); + + await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-a' }); + await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-b' }); + expect(deps.getApplicableConfigs).toHaveBeenCalledTimes(2); + + await clearOverrideCache(); + + // After clearing, both tenants should re-query DB + await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-a' }); + await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-b' }); + expect(deps.getApplicableConfigs).toHaveBeenCalledTimes(4); + }); + + it('clears only specified tenant override caches', async () => { + const deps = createDeps({ + getApplicableConfigs: jest + .fn() + .mockResolvedValue([{ priority: 10, overrides: { x: 1 }, isActive: true }]), + }); + const { getAppConfig, clearOverrideCache } = createAppConfigService(deps); + + await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-a' }); + await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-b' }); + expect(deps.getApplicableConfigs).toHaveBeenCalledTimes(2); + + await clearOverrideCache('tenant-a'); + + // tenant-a should re-query, tenant-b should be cached + await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-a' }); 
+ await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-b' }); + expect(deps.getApplicableConfigs).toHaveBeenCalledTimes(3); + }); + + it('does not clear base config', async () => { + const deps = createDeps(); + const { getAppConfig, clearOverrideCache } = createAppConfigService(deps); + + await getAppConfig(); + expect(deps.loadBaseConfig).toHaveBeenCalledTimes(1); + + await clearOverrideCache(); + + await getAppConfig(); + // Base config should still be cached + expect(deps.loadBaseConfig).toHaveBeenCalledTimes(1); + }); + + it('does not throw when store.keys is unavailable (Redis fallback to TTL expiry)', async () => { + const deps = createDeps(); + // Remove store.keys to simulate Redis-backed cache + deps._cache.opts = {}; + const { clearOverrideCache } = createAppConfigService(deps); + + // Should not throw — logs warning and relies on TTL expiry + await expect(clearOverrideCache()).resolves.toBeUndefined(); + }); + }); +}); diff --git a/packages/api/src/app/service.ts b/packages/api/src/app/service.ts new file mode 100644 index 0000000000..6c5d307709 --- /dev/null +++ b/packages/api/src/app/service.ts @@ -0,0 +1,251 @@ +import { PrincipalType } from 'librechat-data-provider'; +import { logger, mergeConfigOverrides, BASE_CONFIG_PRINCIPAL_ID } from '@librechat/data-schemas'; +import type { Types } from 'mongoose'; +import type { AppConfig, IConfig } from '@librechat/data-schemas'; + +const BASE_CONFIG_KEY = '_BASE_'; + +const DEFAULT_OVERRIDE_CACHE_TTL = 60_000; + +// ── Types ──────────────────────────────────────────────────────────── + +interface CacheStore { + get: (key: string) => Promise; + set: (key: string, value: unknown, ttl?: number) => Promise; + delete: (key: string) => Promise; + /** Keyv options — used for key enumeration when clearing override caches. */ + opts?: { + store?: { + keys?: () => IterableIterator; + }; + }; +} + +export interface AppConfigServiceDeps { + /** Load the base AppConfig from YAML + AppService processing. 
*/ + loadBaseConfig: () => Promise; + /** Cache tools after base config is loaded. */ + setCachedTools: (tools: Record) => Promise; + /** Get a cache store by key. */ + getCache: (key: string) => CacheStore; + /** The CacheKeys constants from librechat-data-provider. */ + cacheKeys: { APP_CONFIG: string }; + /** Fetch applicable DB config overrides for a set of principals. */ + getApplicableConfigs: ( + principals?: Array<{ principalType: string; principalId?: string | Types.ObjectId }>, + ) => Promise; + /** Resolve full principal list (user + role + groups) from userId/role. */ + getUserPrincipals: (params: { + userId: string | Types.ObjectId; + role?: string | null; + }) => Promise>; + /** TTL in ms for per-user/role merged config caches. Defaults to 60 000. */ + overrideCacheTtl?: number; +} + +// ── Helpers ────────────────────────────────────────────────────────── + +let _strictOverride: boolean | undefined; +function isStrictOverrideMode(): boolean { + return (_strictOverride ??= process.env.TENANT_ISOLATION_STRICT === 'true'); +} + +/** @internal Resets the cached strict-override flag. Exposed for test teardown only. */ +let _warnedNoTenantInStrictMode = false; + +export function _resetOverrideStrictCache(): void { + _strictOverride = undefined; + _warnedNoTenantInStrictMode = false; +} + +function overrideCacheKey(role?: string, userId?: string, tenantId?: string): string { + const tenant = tenantId || '__default__'; + if (!tenantId && isStrictOverrideMode() && !_warnedNoTenantInStrictMode) { + _warnedNoTenantInStrictMode = true; + logger.warn( + '[overrideCacheKey] No tenantId in strict mode — falling back to __default__. 
' + + 'This likely indicates a code path that bypasses the tenant context middleware.', + ); + } + if (userId && role) { + return `_OVERRIDE_:${tenant}:${role}:${userId}`; + } + if (userId) { + return `_OVERRIDE_:${tenant}:${userId}`; + } + if (role) { + return `_OVERRIDE_:${tenant}:${role}`; + } + return `_OVERRIDE_:${tenant}:${BASE_CONFIG_PRINCIPAL_ID}`; +} + +// ── Service factory ────────────────────────────────────────────────── + +export function createAppConfigService(deps: AppConfigServiceDeps) { + const { + loadBaseConfig, + setCachedTools, + getCache, + cacheKeys, + getApplicableConfigs, + getUserPrincipals, + overrideCacheTtl = DEFAULT_OVERRIDE_CACHE_TTL, + } = deps; + + const cache = getCache(cacheKeys.APP_CONFIG); + + async function buildPrincipals( + role?: string, + userId?: string, + ): Promise> { + if (userId) { + return getUserPrincipals({ userId, role }); + } + const principals: Array<{ principalType: string; principalId?: string | Types.ObjectId }> = []; + if (role) { + principals.push({ principalType: PrincipalType.ROLE, principalId: role }); + } + return principals; + } + + /** + * Ensure the YAML-derived base config is loaded and cached. + * Returns the `_BASE_` config (YAML + AppService). No DB queries. + */ + async function ensureBaseConfig(refresh?: boolean): Promise { + let baseConfig = (await cache.get(BASE_CONFIG_KEY)) as AppConfig | undefined; + if (!baseConfig || refresh) { + logger.info('[ensureBaseConfig] Loading base configuration...'); + baseConfig = await loadBaseConfig(); + + if (!baseConfig) { + throw new Error('Failed to initialize app configuration through AppService.'); + } + + if (baseConfig.availableTools) { + await setCachedTools(baseConfig.availableTools); + } + + await cache.set(BASE_CONFIG_KEY, baseConfig); + } + return baseConfig; + } + + /** + * Get the app configuration, optionally merged with DB overrides for the given principal. + * + * The base config (from YAML + AppService) is cached indefinitely. 
Per-principal merged + * configs are cached with a short TTL (`overrideCacheTtl`, default 60s). On cache miss, + * `getApplicableConfigs` queries the DB for matching overrides and merges them by priority. + * + * When `baseOnly` is true, returns the YAML-derived config without any DB queries. + * `role`, `userId`, and `tenantId` are ignored in this mode. + * Use this for startup, auth strategies, and other pre-tenant code paths. + */ + async function getAppConfig( + options: { + role?: string; + userId?: string; + tenantId?: string; + refresh?: boolean; + /** When true, return only the YAML-derived base config — no DB override queries. */ + baseOnly?: boolean; + } = {}, + ): Promise { + const { role, userId, tenantId, refresh, baseOnly } = options; + + const baseConfig = await ensureBaseConfig(refresh); + + if (baseOnly) { + return baseConfig; + } + + const cacheKey = overrideCacheKey(role, userId, tenantId); + if (!refresh) { + const cachedMerged = (await cache.get(cacheKey)) as AppConfig | undefined; + if (cachedMerged) { + return cachedMerged; + } + } + + try { + const principals = await buildPrincipals(role, userId); + const configs = await getApplicableConfigs(principals); + + if (configs.length === 0) { + await cache.set(cacheKey, baseConfig, overrideCacheTtl); + return baseConfig; + } + + const merged = mergeConfigOverrides(baseConfig, configs); + await cache.set(cacheKey, merged, overrideCacheTtl); + return merged; + } catch (error) { + logger.error('[getAppConfig] Error resolving config overrides, falling back to base:', error); + return baseConfig; + } + } + + /** + * Clear the base config cache. Per-user/role override caches (`_OVERRIDE_:*`) + * are NOT flushed — they expire naturally via `overrideCacheTtl`. After calling this, + * the base config will be reloaded from YAML on the next `getAppConfig` call, but + * users with cached overrides may see stale merged configs for up to `overrideCacheTtl` ms. 
+ */ + async function clearAppConfigCache(): Promise { + await cache.delete(BASE_CONFIG_KEY); + } + + /** + * Clear per-principal override caches. When `tenantId` is provided, only caches + * matching `_OVERRIDE_:${tenantId}:*` are deleted. When omitted, ALL override + * caches are cleared. + */ + async function clearOverrideCache(tenantId?: string): Promise { + const namespace = cacheKeys.APP_CONFIG; + const overrideSegment = tenantId ? `_OVERRIDE_:${tenantId}:` : '_OVERRIDE_:'; + + // In-memory store — enumerate keys directly. + // APP_CONFIG defaults to FORCED_IN_MEMORY_CACHE_NAMESPACES, so this is the + // standard path. Redis SCAN is intentionally avoided here — it can cause 60s+ + // stalls under concurrent load (see #12410). When APP_CONFIG is Redis-backed + // and store.keys() is unavailable, overrides expire naturally via TTL. + const store = (cache as CacheStore).opts?.store; + if (store && typeof store.keys === 'function') { + // Keyv stores keys with a namespace prefix (e.g. "APP_CONFIG:_OVERRIDE_:..."). + // We match on the namespaced key but delete using the un-namespaced key + // because Keyv.delete() auto-prepends the namespace. + const namespacedPrefix = `${namespace}:${overrideSegment}`; + const toDelete: string[] = []; + for (const key of store.keys()) { + if (key.startsWith(namespacedPrefix)) { + toDelete.push(key.slice(namespace.length + 1)); + } + } + if (toDelete.length > 0) { + await Promise.all(toDelete.map((key) => cache.delete(key))); + logger.info( + `[clearOverrideCache] Cleared ${toDelete.length} override cache entries` + + (tenantId ? ` for tenant ${tenantId}` : ''), + ); + } + return; + } + + logger.warn( + '[clearOverrideCache] Cache store does not support key enumeration. ' + + 'Override caches will expire naturally via TTL (%dms). 
' + + 'This is expected when APP_CONFIG is Redis-backed — Redis SCAN is avoided ' + + 'for performance reasons (see #12410).', + overrideCacheTtl, + ); + } + + return { + getAppConfig, + clearAppConfigCache, + clearOverrideCache, + }; +} + +export type AppConfigService = ReturnType; diff --git a/packages/api/src/auth/openid.spec.ts b/packages/api/src/auth/openid.spec.ts index 0761a24e85..2cf3992cdf 100644 --- a/packages/api/src/auth/openid.spec.ts +++ b/packages/api/src/auth/openid.spec.ts @@ -1,8 +1,13 @@ +import { Types } from 'mongoose'; import { ErrorTypes } from 'librechat-data-provider'; import { logger } from '@librechat/data-schemas'; import type { IUser, UserMethods } from '@librechat/data-schemas'; import { findOpenIDUser } from './openid'; +function newId() { + return new Types.ObjectId(); +} + jest.mock('@librechat/data-schemas', () => ({ ...jest.requireActual('@librechat/data-schemas'), logger: { @@ -24,7 +29,7 @@ describe('findOpenIDUser', () => { describe('Primary condition searches', () => { it('should find user by openidId', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', openidId: 'openid_123', email: 'user@example.com', @@ -51,7 +56,7 @@ describe('findOpenIDUser', () => { it('should find user by idOnTheSource', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', idOnTheSource: 'source_123', email: 'user@example.com', @@ -78,7 +83,7 @@ describe('findOpenIDUser', () => { it('should find user by both openidId and idOnTheSource', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', openidId: 'openid_123', idOnTheSource: 'source_123', @@ -109,16 +114,14 @@ describe('findOpenIDUser', () => { describe('Email-based searches', () => { it('should find user by email when primary conditions fail and openidId matches', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', openidId: 
'openid_123', email: 'user@example.com', username: 'testuser', } as IUser; - mockFindUser - .mockResolvedValueOnce(null) - .mockResolvedValueOnce(mockUser); + mockFindUser.mockResolvedValueOnce(null).mockResolvedValueOnce(mockUser); const result = await findOpenIDUser({ openidId: 'openid_123', @@ -179,7 +182,7 @@ describe('findOpenIDUser', () => { describe('Provider conflict handling', () => { it('should return error when user has different provider', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'google', email: 'user@example.com', username: 'testuser', @@ -204,16 +207,14 @@ describe('findOpenIDUser', () => { it('should reject email fallback when existing openidId does not match token sub', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', openidId: 'openid_456', email: 'user@example.com', username: 'testuser', } as IUser; - mockFindUser - .mockResolvedValueOnce(null) - .mockResolvedValueOnce(mockUser); + mockFindUser.mockResolvedValueOnce(null).mockResolvedValueOnce(mockUser); const result = await findOpenIDUser({ openidId: 'openid_123', @@ -230,16 +231,14 @@ describe('findOpenIDUser', () => { it('should allow email fallback when existing openidId matches token sub', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', openidId: 'openid_123', email: 'user@example.com', username: 'testuser', } as IUser; - mockFindUser - .mockResolvedValueOnce(null) - .mockResolvedValueOnce(mockUser); + mockFindUser.mockResolvedValueOnce(null).mockResolvedValueOnce(mockUser); const result = await findOpenIDUser({ openidId: 'openid_123', @@ -258,7 +257,7 @@ describe('findOpenIDUser', () => { describe('User migration scenarios', () => { it('should prepare user for migration when email exists without openidId', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), email: 'user@example.com', username: 'testuser', // No provider and no openidId - 
needs migration @@ -287,16 +286,14 @@ describe('findOpenIDUser', () => { it('should reject when user already has a different openidId', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', openidId: 'existing_openid', email: 'user@example.com', username: 'testuser', } as IUser; - mockFindUser - .mockResolvedValueOnce(null) - .mockResolvedValueOnce(mockUser); + mockFindUser.mockResolvedValueOnce(null).mockResolvedValueOnce(mockUser); const result = await findOpenIDUser({ openidId: 'openid_123', @@ -313,16 +310,14 @@ describe('findOpenIDUser', () => { it('should reject when user has no provider but a different openidId', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), openidId: 'existing_openid', email: 'user@example.com', username: 'testuser', // No provider field — tests a different branch than openid-provider mismatch } as IUser; - mockFindUser - .mockResolvedValueOnce(null) - .mockResolvedValueOnce(mockUser); + mockFindUser.mockResolvedValueOnce(null).mockResolvedValueOnce(mockUser); const result = await findOpenIDUser({ openidId: 'openid_123', @@ -422,16 +417,14 @@ describe('findOpenIDUser', () => { it('should pass email to findUser for case-insensitive lookup (findUser handles normalization)', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', openidId: 'openid_123', email: 'user@example.com', username: 'testuser', } as IUser; - mockFindUser - .mockResolvedValueOnce(null) - .mockResolvedValueOnce(mockUser); + mockFindUser.mockResolvedValueOnce(null).mockResolvedValueOnce(mockUser); const result = await findOpenIDUser({ openidId: 'openid_123', @@ -460,7 +453,7 @@ describe('findOpenIDUser', () => { it('should reject email fallback when openidId is empty and user has a stored openidId', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', openidId: 'existing-real-id', email: 'user@example.com', diff --git 
a/packages/api/src/cache/__tests__/cacheFactory/sessionCache.cache_integration.spec.ts b/packages/api/src/cache/__tests__/cacheFactory/sessionCache.cache_integration.spec.ts index f4ded8bc74..99c0d69b37 100644 --- a/packages/api/src/cache/__tests__/cacheFactory/sessionCache.cache_integration.spec.ts +++ b/packages/api/src/cache/__tests__/cacheFactory/sessionCache.cache_integration.spec.ts @@ -1,33 +1,36 @@ -interface SessionData { +import type { MemoryStore, SessionData } from 'express-session'; +import type { RedisStore as ConnectRedis } from 'connect-redis'; + +interface TestSessionData { [key: string]: unknown; cookie?: { maxAge: number }; user?: { id: string; name: string }; userId?: string; } -interface SessionStore { - prefix?: string; - set: (id: string, data: SessionData, callback?: (err?: Error) => void) => void; - get: (id: string, callback: (err: Error | null, data?: SessionData | null) => void) => void; - destroy: (id: string, callback?: (err?: Error) => void) => void; - touch: (id: string, data: SessionData, callback?: (err?: Error) => void) => void; - on?: (event: string, handler: (...args: unknown[]) => void) => void; -} +type CacheSessionStore = MemoryStore | ConnectRedis; describe('sessionCache', () => { let originalEnv: NodeJS.ProcessEnv; - // Helper to make session stores async - const asyncStore = (store: SessionStore) => ({ - set: (id: string, data: SessionData) => - new Promise((resolve) => store.set(id, data, () => resolve())), + // Helper to make session stores async — uses generic store type to bridge + // between MemoryStore/ConnectRedis and the test's relaxed SessionData shape. + // The store methods accept express-session's SessionData but test data is + // intentionally simpler; the cast bridges the gap for integration tests. 
+ const asyncStore = (store: CacheSessionStore) => ({ + set: (id: string, data: TestSessionData) => + new Promise((resolve) => + store.set(id, data as Partial as SessionData, () => resolve()), + ), get: (id: string) => - new Promise((resolve) => - store.get(id, (_, data) => resolve(data)), + new Promise((resolve) => + store.get(id, (_, data) => resolve(data as TestSessionData | null | undefined)), ), destroy: (id: string) => new Promise((resolve) => store.destroy(id, () => resolve())), - touch: (id: string, data: SessionData) => - new Promise((resolve) => store.touch(id, data, () => resolve())), + touch: (id: string, data: TestSessionData) => + new Promise((resolve) => + store.touch(id, data as Partial as SessionData, () => resolve()), + ), }); beforeEach(() => { @@ -66,11 +69,11 @@ describe('sessionCache', () => { // Verify it returns a ConnectRedis instance expect(store).toBeDefined(); expect(store.constructor.name).toBe('RedisStore'); - expect(store.prefix).toBe('test-sessions:'); + expect((store as CacheSessionStore & { prefix: string }).prefix).toBe('test-sessions:'); // Test session operations const sessionId = 'sess:123456'; - const sessionData: SessionData = { + const sessionData: TestSessionData = { user: { id: 'user123', name: 'Test User' }, cookie: { maxAge: 3600000 }, }; @@ -107,7 +110,7 @@ describe('sessionCache', () => { // Test session operations const sessionId = 'mem:789012'; - const sessionData: SessionData = { + const sessionData: TestSessionData = { user: { id: 'user456', name: 'Memory User' }, cookie: { maxAge: 3600000 }, }; @@ -135,8 +138,8 @@ describe('sessionCache', () => { const store1 = cacheFactory.sessionCache('namespace1'); const store2 = cacheFactory.sessionCache('namespace2:'); - expect(store1.prefix).toBe('namespace1:'); - expect(store2.prefix).toBe('namespace2:'); + expect((store1 as CacheSessionStore & { prefix: string }).prefix).toBe('namespace1:'); + expect((store2 as CacheSessionStore & { prefix: string 
}).prefix).toBe('namespace2:'); }); test('should register error handler for Redis connection', async () => { @@ -171,7 +174,7 @@ describe('sessionCache', () => { } const sessionId = 'ttl:12345'; - const sessionData: SessionData = { userId: 'ttl-user' }; + const sessionData: TestSessionData = { userId: 'ttl-user' }; const async = asyncStore(store); // Set session with short TTL diff --git a/packages/api/src/cache/__tests__/redisClients.cache_integration.spec.ts b/packages/api/src/cache/__tests__/redisClients.cache_integration.spec.ts index dc9a325746..77e8c01436 100644 --- a/packages/api/src/cache/__tests__/redisClients.cache_integration.spec.ts +++ b/packages/api/src/cache/__tests__/redisClients.cache_integration.spec.ts @@ -59,8 +59,8 @@ describe('redisClients Integration Tests', () => { if (keys.length > 0) { await ioredisClient.del(...keys); } - } catch (error: any) { - console.warn('Error cleaning up test keys:', error.message); + } catch (error) { + console.warn('Error cleaning up test keys:', (error as Error).message); } } @@ -70,8 +70,8 @@ describe('redisClients Integration Tests', () => { if (ioredisClient.status === 'ready') { ioredisClient.disconnect(); } - } catch (error: any) { - console.warn('Error disconnecting ioredis client:', error.message); + } catch (error) { + console.warn('Error disconnecting ioredis client:', (error as Error).message); } ioredisClient = null; } @@ -80,8 +80,8 @@ describe('redisClients Integration Tests', () => { try { // Try to disconnect - keyv/redis client doesn't have an isReady property await keyvRedisClient.disconnect(); - } catch (error: any) { - console.warn('Error disconnecting keyv redis client:', error.message); + } catch (error) { + console.warn('Error disconnecting keyv redis client:', (error as Error).message); } keyvRedisClient = null; } @@ -138,7 +138,11 @@ describe('redisClients Integration Tests', () => { test('should connect and perform set/get/delete operations', async () => { const clients = await 
import('../redisClients'); keyvRedisClient = clients.keyvRedisClient; - await testRedisOperations(keyvRedisClient!, 'keyv-single', clients.keyvRedisClientReady!); + await testRedisOperations( + keyvRedisClient!, + 'keyv-single', + clients.keyvRedisClientReady!.then(() => undefined), + ); }); }); @@ -150,7 +154,11 @@ describe('redisClients Integration Tests', () => { const clients = await import('../redisClients'); keyvRedisClient = clients.keyvRedisClient; - await testRedisOperations(keyvRedisClient!, 'keyv-cluster', clients.keyvRedisClientReady!); + await testRedisOperations( + keyvRedisClient!, + 'keyv-cluster', + clients.keyvRedisClientReady!.then(() => undefined), + ); }); }); }); diff --git a/packages/api/src/cache/cacheConfig.ts b/packages/api/src/cache/cacheConfig.ts index 0d4304f5c3..7b4a899e98 100644 --- a/packages/api/src/cache/cacheConfig.ts +++ b/packages/api/src/cache/cacheConfig.ts @@ -128,8 +128,13 @@ const cacheConfig = { REDIS_SCAN_COUNT: math(process.env.REDIS_SCAN_COUNT, 1000), /** - * TTL in milliseconds for MCP registry read-through cache. - * This cache reduces redundant lookups within a single request flow. + * TTL in milliseconds for MCP registry caches. Used by both: + * - `MCPServersRegistry` read-through caches (`readThroughCache`/`readThroughCacheAll`) + * - `ServerConfigsCacheRedisAggregateKey` local snapshot (avoids redundant Redis GETs) + * + * Both layers use this value, so the effective max cross-instance staleness is up + * to 2× this value in multi-instance deployments. Set to 0 to disable the local + * snapshot entirely (every `getAll()` hits Redis directly). 
* @default 5000 (5 seconds) */ MCP_REGISTRY_CACHE_TTL: math(process.env.MCP_REGISTRY_CACHE_TTL, 5000), diff --git a/packages/api/src/endpoints/custom/initialize.spec.ts b/packages/api/src/endpoints/custom/initialize.spec.ts index 3705f98977..eddd7cb515 100644 --- a/packages/api/src/endpoints/custom/initialize.spec.ts +++ b/packages/api/src/endpoints/custom/initialize.spec.ts @@ -81,7 +81,7 @@ describe('initializeCustom – Agents API user key resolution', () => { userApiKey: 'sk-user-key', }); // Simulate Agents API request body (no `key` field) - params.req.body = { model: 'agent_123', messages: [] }; + params.req.body = { model: 'agent_123' }; await initializeCustom(params); @@ -104,7 +104,7 @@ describe('initializeCustom – Agents API user key resolution', () => { baseURL: AuthType.USER_PROVIDED, userBaseURL: 'https://user-api.example.com/v1', }); - params.req.body = { model: 'agent_123', messages: [] }; + params.req.body = { model: 'agent_123' }; await initializeCustom(params); diff --git a/packages/api/src/endpoints/custom/initialize.ts b/packages/api/src/endpoints/custom/initialize.ts index 1250721500..ea0d2dbf5d 100644 --- a/packages/api/src/endpoints/custom/initialize.ts +++ b/packages/api/src/endpoints/custom/initialize.ts @@ -32,10 +32,8 @@ function buildCustomOptions( customParams: endpointConfig.customParams, titleConvo: endpointConfig.titleConvo, titleModel: endpointConfig.titleModel, - summaryModel: endpointConfig.summaryModel, modelDisplayLabel: endpointConfig.modelDisplayLabel, titleMethod: endpointConfig.titleMethod ?? 'completion', - contextStrategy: endpointConfig.summarize ? 
'summarize' : null, directEndpoint: endpointConfig.directEndpoint, titleMessageRole: endpointConfig.titleMessageRole, streamRate: endpointConfig.streamRate, diff --git a/packages/api/src/endpoints/openai/config.spec.ts b/packages/api/src/endpoints/openai/config.spec.ts index cdf9d6f14c..46ad6a6295 100644 --- a/packages/api/src/endpoints/openai/config.spec.ts +++ b/packages/api/src/endpoints/openai/config.spec.ts @@ -1399,10 +1399,8 @@ describe('getOpenAIConfig', () => { dropParams: ['presence_penalty'], titleConvo: true, titleModel: 'gpt-3.5-turbo', - summaryModel: 'gpt-3.5-turbo', modelDisplayLabel: 'Custom GPT-4', titleMethod: 'completion', - contextStrategy: 'summarize', directEndpoint: true, titleMessageRole: 'user', streamRate: 25, @@ -1417,10 +1415,8 @@ describe('getOpenAIConfig', () => { customParams: {}, titleConvo: endpointConfig.titleConvo, titleModel: endpointConfig.titleModel, - summaryModel: endpointConfig.summaryModel, modelDisplayLabel: endpointConfig.modelDisplayLabel, titleMethod: endpointConfig.titleMethod, - contextStrategy: endpointConfig.contextStrategy, directEndpoint: endpointConfig.directEndpoint, titleMessageRole: endpointConfig.titleMessageRole, streamRate: endpointConfig.streamRate, diff --git a/packages/api/src/flow/manager.tenant.spec.ts b/packages/api/src/flow/manager.tenant.spec.ts new file mode 100644 index 0000000000..14b780c34b --- /dev/null +++ b/packages/api/src/flow/manager.tenant.spec.ts @@ -0,0 +1,49 @@ +import { Keyv } from 'keyv'; +import { logger, tenantStorage } from '@librechat/data-schemas'; +import { FlowStateManager } from './manager'; + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, +})); + +describe('FlowStateManager flow keys are not tenant-scoped', () => { + let manager: FlowStateManager; + + beforeEach(() => { + jest.clearAllMocks(); + const store = new Keyv({ store: 
new Map() }); + manager = new FlowStateManager(store, { ci: true, ttl: 60_000 }); + }); + + it('completeFlow finds a flow regardless of tenant context (OAuth callback compatibility)', async () => { + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await manager.initFlow('flow-1', 'oauth', {}); + }); + + const found = await manager.completeFlow('flow-1', 'oauth', { token: 'abc' }); + expect(found).toBe(true); + }); + + it('completeFlow works when both creation and completion have the same tenant', async () => { + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await manager.initFlow('flow-2', 'oauth', {}); + const found = await manager.completeFlow('flow-2', 'oauth', { token: 'abc' }); + expect(found).toBe(true); + }); + }); + + it('completeFlow returns false and logs when flow does not exist', async () => { + const found = await manager.completeFlow('ghost-flow', 'oauth', { token: 'x' }); + expect(found).toBe(false); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('ghost-flow'), + expect.objectContaining({ flowId: 'ghost-flow', type: 'oauth' }), + ); + }); +}); diff --git a/packages/api/src/flow/manager.ts b/packages/api/src/flow/manager.ts index b68b9edb7a..544cba9560 100644 --- a/packages/api/src/flow/manager.ts +++ b/packages/api/src/flow/manager.ts @@ -53,6 +53,12 @@ export class FlowStateManager { process.on('SIGHUP', cleanup); } + /** + * Flow keys are intentionally NOT tenant-scoped. OAuth callbacks arrive + * without tenant ALS context (the provider redirect doesn't carry + * X-Tenant-Id). Flow IDs are random UUIDs with no collision risk, and + * flow data is ephemeral (TTL-bounded, no sensitive user content). 
+ */ private getFlowKey(flowId: string, type: string): string { return `${type}:${flowId}`; } @@ -253,7 +259,9 @@ export class FlowStateManager { if (!flowState) { logger.warn( - '[FlowStateManager] Flow state not found during completion — cannot recover metadata, skipping', + `[FlowStateManager] completeFlow: flow not found — key=${flowKey}. ` + + 'Possible causes: flow TTL expired before callback arrived, flow was never created, or ' + + 'the callback is routing to a different instance without shared Keyv storage.', { flowId, type }, ); return false; diff --git a/packages/api/src/index.ts b/packages/api/src/index.ts index ef32e7b6b0..7a04b8e74a 100644 --- a/packages/api/src/index.ts +++ b/packages/api/src/index.ts @@ -1,4 +1,6 @@ export * from './app'; +/* Admin */ +export * from './admin'; export * from './cdn'; /* Auth */ export * from './auth'; @@ -12,6 +14,7 @@ export * from './mcp/oauth'; export * from './mcp/auth'; export * from './mcp/zod'; export * from './mcp/errors'; +export * from './mcp/cache'; /* Utilities */ export * from './mcp/utils'; export * from './utils'; diff --git a/packages/api/src/mcp/ConnectionsRepository.ts b/packages/api/src/mcp/ConnectionsRepository.ts index 6313faa8d4..79976b1199 100644 --- a/packages/api/src/mcp/ConnectionsRepository.ts +++ b/packages/api/src/mcp/ConnectionsRepository.ts @@ -2,7 +2,7 @@ import { logger } from '@librechat/data-schemas'; import type * as t from './types'; import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; -import { hasCustomUserVars } from './utils'; +import { hasCustomUserVars, isUserSourced } from './utils'; import { MCPConnection } from './connection'; const CONNECT_CONCURRENCY = 3; @@ -82,7 +82,7 @@ export class ConnectionsRepository { { serverName, serverConfig, - dbSourced: !!(serverConfig as t.ParsedServerConfig).dbId, + dbSourced: isUserSourced(serverConfig as t.ParsedServerConfig), useSSRFProtection: 
registry.shouldEnableSSRFProtection(), allowedDomains: registry.getAllowedDomains(), }, diff --git a/packages/api/src/mcp/MCPManager.ts b/packages/api/src/mcp/MCPManager.ts index 935307fa49..12227de39f 100644 --- a/packages/api/src/mcp/MCPManager.ts +++ b/packages/api/src/mcp/MCPManager.ts @@ -18,6 +18,7 @@ import { preProcessGraphTokens } from '~/utils/graph'; import { formatToolContent } from './parsers'; import { MCPConnection } from './connection'; import { processMCPEnv } from '~/utils/env'; +import { isUserSourced } from './utils'; /** * Centralized manager for MCP server connections and tool execution. @@ -53,6 +54,8 @@ export class MCPManager extends UserConnectionManager { user?: IUser; forceNew?: boolean; flowManager?: FlowStateManager; + /** Pre-resolved config for config-source servers not in YAML/DB */ + serverConfig?: t.ParsedServerConfig; } & Omit, ): Promise { //the get method checks if the config is still valid as app level @@ -91,6 +94,7 @@ export class MCPManager extends UserConnectionManager { const serverConfig = await MCPServersRegistry.getInstance().getServerConfig( serverName, user?.id, + args.configServers, ); if (!serverConfig) { @@ -103,7 +107,7 @@ export class MCPManager extends UserConnectionManager { const registry = MCPServersRegistry.getInstance(); const useSSRFProtection = registry.shouldEnableSSRFProtection(); const allowedDomains = registry.getAllowedDomains(); - const dbSourced = !!serverConfig.dbId; + const dbSourced = isUserSourced(serverConfig); const basic: t.BasicConnectionOptions = { dbSourced, serverName, @@ -193,9 +197,15 @@ export class MCPManager extends UserConnectionManager { * @param serverNames Optional array of server names. If not provided or empty, returns all servers. 
* @returns Object mapping server names to their instructions */ - private async getInstructions(serverNames?: string[]): Promise> { + private async getInstructions( + serverNames?: string[], + configServers?: Record, + ): Promise> { const instructions: Record = {}; - const configs = await MCPServersRegistry.getInstance().getAllServerConfigs(); + const configs = await MCPServersRegistry.getInstance().getAllServerConfigs( + undefined, + configServers, + ); for (const [serverName, config] of Object.entries(configs)) { if (config.serverInstructions != null) { instructions[serverName] = config.serverInstructions as string; @@ -210,9 +220,11 @@ export class MCPManager extends UserConnectionManager { * @param serverNames Optional array of server names to include. If not provided, includes all servers. * @returns Formatted instructions string ready for context injection */ - public async formatInstructionsForContext(serverNames?: string[]): Promise { - /** Instructions for specified servers or all stored instructions */ - const instructionsToInclude = await this.getInstructions(serverNames); + public async formatInstructionsForContext( + serverNames?: string[], + configServers?: Record, + ): Promise { + const instructionsToInclude = await this.getInstructions(serverNames, configServers); if (Object.keys(instructionsToInclude).length === 0) { return ''; @@ -248,6 +260,7 @@ Please follow these instructions when using tools from the respective MCP server async callTool({ user, serverName, + serverConfig: providedConfig, toolName, provider, toolArguments, @@ -262,6 +275,8 @@ Please follow these instructions when using tools from the respective MCP server }: { user?: IUser; serverName: string; + /** Pre-resolved config from tool creation context — avoids readThrough TTL and cross-tenant issues */ + serverConfig?: t.ParsedServerConfig; toolName: string; provider: t.Provider; toolArguments?: Record; @@ -292,6 +307,7 @@ Please follow these instructions when using tools from the 
respective MCP server signal: options?.signal, customUserVars, requestBody, + serverConfig: providedConfig, }); if (!(await connection.isConnected())) { @@ -302,8 +318,16 @@ Please follow these instructions when using tools from the respective MCP server ); } - const rawConfig = await MCPServersRegistry.getInstance().getServerConfig(serverName, userId); - const isDbSourced = !!rawConfig?.dbId; + const rawConfig = + providedConfig ?? + (await MCPServersRegistry.getInstance().getServerConfig(serverName, userId)); + if (!rawConfig) { + throw new McpError( + ErrorCode.InvalidRequest, + `${logPrefix} Configuration for server "${serverName}" not found.`, + ); + } + const isDbSourced = isUserSourced(rawConfig); /** Pre-process Graph token placeholders (async) before the synchronous processMCPEnv pass */ const graphProcessedConfig = isDbSourced diff --git a/packages/api/src/mcp/UserConnectionManager.ts b/packages/api/src/mcp/UserConnectionManager.ts index 2e9d5be467..760f84c75e 100644 --- a/packages/api/src/mcp/UserConnectionManager.ts +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -4,6 +4,7 @@ import type * as t from './types'; import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; +import { isUserSourced } from './utils'; import { MCPConnection } from './connection'; import { mcpConfig } from './mcpConfig'; @@ -38,6 +39,8 @@ export abstract class UserConnectionManager { opts: { serverName: string; forceNew?: boolean; + /** Pre-resolved config for config-source servers not in YAML/DB */ + serverConfig?: t.ParsedServerConfig; } & Omit, ): Promise { const { serverName, forceNew, user } = opts; @@ -85,9 +88,11 @@ export abstract class UserConnectionManager { signal, returnOnOAuth = false, connectionTimeout, + serverConfig: providedConfig, }: { serverName: string; forceNew?: boolean; + serverConfig?: 
t.ParsedServerConfig; } & Omit, userId: string, ): Promise { @@ -98,7 +103,9 @@ export abstract class UserConnectionManager { ); } - const config = await MCPServersRegistry.getInstance().getServerConfig(serverName, userId); + const config = + providedConfig ?? + (await MCPServersRegistry.getInstance().getServerConfig(serverName, userId)); const userServerMap = this.userConnections.get(userId); let connection = forceNew ? undefined : userServerMap?.get(serverName); @@ -158,7 +165,7 @@ export abstract class UserConnectionManager { { serverConfig: config, serverName: serverName, - dbSourced: !!config.dbId, + dbSourced: isUserSourced(config), useSSRFProtection: registry.shouldEnableSSRFProtection(), allowedDomains: registry.getAllowedDomains(), }, diff --git a/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts b/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts index 7a93960765..dfb57a1faf 100644 --- a/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts +++ b/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts @@ -46,8 +46,8 @@ describe('ConnectionsRepository', () => { beforeEach(() => { mockServerConfigs = { - server1: { url: 'http://localhost:3001' }, - server2: { command: 'test-command', args: ['--test'] }, + server1: { url: 'http://localhost:3001', type: 'sse' }, + server2: { command: 'test-command', args: ['--test'], type: 'stdio' }, server3: { url: 'ws://localhost:8080', type: 'websocket' }, }; diff --git a/packages/api/src/mcp/__tests__/MCPConnectionAgentLifecycle.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionAgentLifecycle.test.ts index 281bd590db..c7b6b273ba 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionAgentLifecycle.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionAgentLifecycle.test.ts @@ -377,7 +377,7 @@ describe('MCPConnection Agent lifecycle – SSE', () => { it('reuses the same Agents across multiple requests instead of creating one per request', async () => { conn = new MCPConnection({ 
serverName: 'test-sse', - serverConfig: { url: server.url }, + serverConfig: { url: server.url, type: 'sse' }, useSSRFProtection: false, }); @@ -402,7 +402,7 @@ describe('MCPConnection Agent lifecycle – SSE', () => { it('calls Agent.close() on every registered Agent when disconnect() is called', async () => { conn = new MCPConnection({ serverName: 'test-sse', - serverConfig: { url: server.url }, + serverConfig: { url: server.url, type: 'sse' }, useSSRFProtection: false, }); @@ -417,7 +417,7 @@ describe('MCPConnection Agent lifecycle – SSE', () => { it('closes at least two Agents for SSE transport (eventSourceInit + fetch)', async () => { conn = new MCPConnection({ serverName: 'test-sse', - serverConfig: { url: server.url }, + serverConfig: { url: server.url, type: 'sse' }, useSSRFProtection: false, }); @@ -431,7 +431,7 @@ describe('MCPConnection Agent lifecycle – SSE', () => { it('does not double-close Agents when disconnect() is called twice', async () => { conn = new MCPConnection({ serverName: 'test-sse', - serverConfig: { url: server.url }, + serverConfig: { url: server.url, type: 'sse' }, useSSRFProtection: false, }); @@ -533,7 +533,7 @@ describe('MCPConnection SSE 404 handling – session-aware', () => { function makeConn() { return new MCPConnection({ serverName: 'test-404', - serverConfig: { url: 'http://127.0.0.1:1/sse' }, + serverConfig: { url: 'http://127.0.0.1:1/sse', type: 'sse' }, useSSRFProtection: false, }); } @@ -599,7 +599,7 @@ describe('MCPConnection SSE stream disconnect handling', () => { function makeConn() { return new MCPConnection({ serverName: 'test-sse-disconnect', - serverConfig: { url: 'http://127.0.0.1:1/sse' }, + serverConfig: { url: 'http://127.0.0.1:1/sse', type: 'sse' }, useSSRFProtection: false, }); } diff --git a/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts index cdba06cf8d..c0a861817c 100644 --- 
a/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts @@ -34,6 +34,8 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), debug: jest.fn(), }, + getTenantId: jest.fn(), + SYSTEM_TENANT_ID: '__SYSTEM__', encryptV2: jest.fn(async (val: string) => `enc:${val}`), decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts index f73a5ed3e8..cbd29d3571 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts @@ -7,6 +7,7 @@ import { createHash } from 'crypto'; import { Keyv } from 'keyv'; +import { TokenExchangeMethodEnum } from 'librechat-data-provider'; import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth'; import { FlowStateManager } from '~/flow/manager'; import { createOAuthMCPServer, MockKeyv, InMemoryTokenStore } from './helpers/oauthTestServer'; @@ -20,6 +21,8 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), debug: jest.fn(), }, + getTenantId: jest.fn(), + SYSTEM_TENANT_ID: '__SYSTEM__', encryptV2: jest.fn(async (val: string) => `enc:${val}`), decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); @@ -92,7 +95,7 @@ describe('MCP OAuth Flow — Real HTTP Server', () => { token_url: `${server.url}token`, client_id: clientInfo.client_id, client_secret: clientInfo.client_secret, - token_exchange_method: 'DefaultPost', + token_exchange_method: TokenExchangeMethodEnum.DefaultPost, }, ); @@ -131,7 +134,7 @@ describe('MCP OAuth Flow — Real HTTP Server', () => { { token_url: `${rotatingServer.url}token`, client_id: 'anon', - token_exchange_method: 'DefaultPost', + token_exchange_method: TokenExchangeMethodEnum.DefaultPost, }, ); @@ -155,7 +158,7 @@ describe('MCP OAuth Flow — Real HTTP Server', () => { { token_url: `${server.url}token`, client_id: 
'anon', - token_exchange_method: 'DefaultPost', + token_exchange_method: TokenExchangeMethodEnum.DefaultPost, }, ), ).rejects.toThrow(); @@ -412,7 +415,7 @@ describe('MCP OAuth Flow — Real HTTP Server', () => { const state = await flowManager.getFlowState(flowId, 'mcp_oauth'); expect(state?.status).toBe('COMPLETED'); - expect(state?.result?.access_token).toBe(tokens.access_token); + expect((state?.result as MCPOAuthTokens | undefined)?.access_token).toBe(tokens.access_token); }); it('should fail flow when authorization code is invalid', async () => { diff --git a/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts index cb6187ab45..d5fb1d67f7 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts @@ -23,6 +23,8 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), debug: jest.fn(), }, + getTenantId: jest.fn(), + SYSTEM_TENANT_ID: '__SYSTEM__', encryptV2: jest.fn(async (val: string) => `enc:${val}`), decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); @@ -258,7 +260,7 @@ describe('MCP OAuth Race Condition Fixes', () => { expect(stateAfterComplete).toBeUndefined(); expect(mockLogger.warn).toHaveBeenCalledWith( - expect.stringContaining('cannot recover metadata'), + expect.stringContaining('flow not found'), expect.any(Object), ); }); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts index a2d0440d42..d50e29eab7 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts @@ -304,10 +304,10 @@ describe('MCP OAuth allowedDomains SSRF exemption for admin-trusted hosts', () = }); it('should allow private revocationEndpoint when hostname is in allowedDomains', async () => { - const mockFetch = jest.fn().mockResolvedValue({ - ok: true, - 
status: 200, - } as Response); + const mockFetch = Object.assign( + jest.fn().mockResolvedValue({ ok: true, status: 200 } as Response), + { preconnect: jest.fn() }, + ); const originalFetch = global.fetch; global.fetch = mockFetch; @@ -333,14 +333,17 @@ describe('MCP OAuth allowedDomains SSRF exemption for admin-trusted hosts', () = }); it('should allow localhost token_url in refreshOAuthTokens when localhost is in allowedDomains', async () => { - const mockFetch = jest.fn().mockResolvedValue({ - ok: true, - json: async () => ({ - access_token: 'new-access-token', - token_type: 'Bearer', - expires_in: 3600, - }), - } as Response); + const mockFetch = Object.assign( + jest.fn().mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: 'new-access-token', + token_type: 'Bearer', + expires_in: 3600, + }), + } as Response), + { preconnect: jest.fn() }, + ); const originalFetch = global.fetch; global.fetch = mockFetch; diff --git a/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts index 986ac4c8b4..b5cbc869a8 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts @@ -26,6 +26,8 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), debug: jest.fn(), }, + getTenantId: jest.fn(), + SYSTEM_TENANT_ID: '__SYSTEM__', encryptV2: jest.fn(async (val: string) => `enc:${val}`), decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts index 3805586453..2d3905d2fb 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts @@ -160,7 +160,7 @@ describe('MCPTokenStorage', () => { serverName: 'srv1', tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 }, createToken: 
store.createToken, - clientInfo: { client_id: 'cid', client_secret: 'csec', redirect_uris: [] }, + clientInfo: { client_id: 'cid', client_secret: 'csec' }, }); const clientSaved = await store.findToken({ @@ -525,7 +525,7 @@ describe('MCPTokenStorage', () => { refresh_token: 'my-refresh-token', }, createToken: store.createToken, - clientInfo: { client_id: 'cid', client_secret: 'sec', redirect_uris: [] }, + clientInfo: { client_id: 'cid', client_secret: 'sec' }, }); const result = await MCPTokenStorage.getTokens({ diff --git a/packages/api/src/mcp/__tests__/mcp.spec.ts b/packages/api/src/mcp/__tests__/mcp.spec.ts index d64f9f3afa..d5cc44569f 100644 --- a/packages/api/src/mcp/__tests__/mcp.spec.ts +++ b/packages/api/src/mcp/__tests__/mcp.spec.ts @@ -179,6 +179,7 @@ describe('Environment Variable Extraction (MCP)', () => { describe('processMCPEnv', () => { it('should create a deep clone of the input object', () => { const originalObj: MCPOptions = { + type: 'stdio', command: 'node', args: ['server.js'], env: { @@ -202,6 +203,7 @@ describe('Environment Variable Extraction (MCP)', () => { it('should process environment variables in env field', () => { const options: MCPOptions = { + type: 'stdio', command: 'node', args: ['server.js'], env: { @@ -252,6 +254,7 @@ describe('Environment Variable Extraction (MCP)', () => { it('should not modify objects without env or headers', () => { const options: MCPOptions = { + type: 'stdio', command: 'node', args: ['server.js'], timeout: 5000, @@ -433,6 +436,7 @@ describe('Environment Variable Extraction (MCP)', () => { ldapId: 'ldap-user-123', }); const options: MCPOptions = { + type: 'stdio', command: 'node', args: ['server.js'], env: { @@ -599,6 +603,7 @@ describe('Environment Variable Extraction (MCP)', () => { CUSTOM_VAR_2: 'custom-value-2', }; const options: MCPOptions = { + type: 'stdio', command: 'node', args: ['server.js'], env: { @@ -674,6 +679,7 @@ describe('Environment Variable Extraction (MCP)', () => { PROFILE_NAME: 
'production-profile', }; const options: MCPOptions = { + type: 'stdio', command: 'npx', args: [ '-y', @@ -734,6 +740,7 @@ describe('Environment Variable Extraction (MCP)', () => { UNUSED_VAR: 'unused-value', }; const options: MCPOptions = { + type: 'stdio', command: 'node', args: ['server.js'], env: { @@ -959,6 +966,7 @@ describe('Environment Variable Extraction (MCP)', () => { }) as unknown as IUser; const options: MCPOptions = { + type: 'stdio', command: 'node', args: ['mcp-server.js', '--user', '{{LIBRECHAT_USER_USERNAME}}'], env: { diff --git a/packages/api/src/mcp/__tests__/utils.test.ts b/packages/api/src/mcp/__tests__/utils.test.ts index e4fb31bdad..b9c2a31fa5 100644 --- a/packages/api/src/mcp/__tests__/utils.test.ts +++ b/packages/api/src/mcp/__tests__/utils.test.ts @@ -1,4 +1,10 @@ -import { normalizeServerName, redactServerSecrets, redactAllServerSecrets } from '~/mcp/utils'; +import { + buildOAuthToolCallName, + normalizeServerName, + redactAllServerSecrets, + redactServerSecrets, + isUserSourced, +} from '~/mcp/utils'; import type { ParsedServerConfig } from '~/mcp/types'; describe('normalizeServerName', () => { @@ -28,6 +34,49 @@ describe('normalizeServerName', () => { }); }); +describe('buildOAuthToolCallName', () => { + it('should prefix a simple server name with oauth_mcp_', () => { + expect(buildOAuthToolCallName('my-server')).toBe('oauth_mcp_my-server'); + }); + + it('should not double-wrap a name that already starts with oauth_mcp_', () => { + expect(buildOAuthToolCallName('oauth_mcp_my-server')).toBe('oauth_mcp_my-server'); + }); + + it('should correctly handle server names containing _mcp_ substring', () => { + const result = buildOAuthToolCallName('my_mcp_server'); + expect(result).toBe('oauth_mcp_my_mcp_server'); + }); + + it('should normalize non-ASCII server names before prefixing', () => { + const result = buildOAuthToolCallName('我的服务'); + expect(result).toMatch(/^oauth_mcp_server_\d+$/); + }); + + it('should normalize special characters 
before prefixing', () => { + expect(buildOAuthToolCallName('server@name!')).toBe('oauth_mcp_server_name'); + }); + + it('should handle empty string server name gracefully', () => { + const result = buildOAuthToolCallName(''); + expect(result).toMatch(/^oauth_mcp_server_\d+$/); + }); + + it('should treat a name already starting with oauth_mcp_ as pre-wrapped', () => { + // At the function level, a name starting with the oauth prefix is + // indistinguishable from a pre-wrapped name — guard prevents double-wrapping. + // Server names with this prefix should be blocked at registration time. + expect(buildOAuthToolCallName('oauth_mcp_github')).toBe('oauth_mcp_github'); + }); + + it('should not treat special chars that normalize to oauth_mcp_* as pre-wrapped', () => { + // oauth@mcp@server does NOT start with 'oauth_mcp_' before normalization, + // so the guard correctly does not fire and the prefix is added. + const result = buildOAuthToolCallName('oauth@mcp@server'); + expect(result).toBe('oauth_mcp_oauth_mcp_server'); + }); +}); + describe('redactServerSecrets', () => { it('should strip apiKey.key from admin-sourced keys', () => { const config: ParsedServerConfig = { @@ -225,3 +274,29 @@ describe('redactAllServerSecrets', () => { expect((redacted['server-c'] as Record).command).toBeUndefined(); }); }); + +describe('isUserSourced', () => { + it('returns false when source is yaml', () => { + expect(isUserSourced({ source: 'yaml' })).toBe(false); + }); + + it('returns false when source is config', () => { + expect(isUserSourced({ source: 'config' })).toBe(false); + }); + + it('returns true when source is user', () => { + expect(isUserSourced({ source: 'user' })).toBe(true); + }); + + it('falls back to dbId when source is undefined — dbId present means user-sourced', () => { + expect(isUserSourced({ source: undefined, dbId: 'abc123' })).toBe(true); + }); + + it('falls back to dbId when source is undefined — no dbId means trusted', () => { + expect(isUserSourced({ source: 
undefined, dbId: undefined })).toBe(false); + }); + + it('returns false when both source and dbId are absent (pre-upgrade YAML server)', () => { + expect(isUserSourced({})).toBe(false); + }); +}); diff --git a/packages/api/src/mcp/cache.ts b/packages/api/src/mcp/cache.ts new file mode 100644 index 0000000000..e68ef42b3c --- /dev/null +++ b/packages/api/src/mcp/cache.ts @@ -0,0 +1,43 @@ +import { logger } from '@librechat/data-schemas'; +import { MCPServersRegistry } from './registry/MCPServersRegistry'; +import { MCPManager } from './MCPManager'; + +/** + * Clears config-source MCP server inspection cache so servers are re-inspected on next access. + * Best-effort disconnection of app-level connections for evicted servers. + * + * User-level connections (used by config-source servers) are cleaned up lazily via + * the stale-check mechanism on the next tool call — this is an accepted design tradeoff + * since iterating all active user sessions is expensive and config mutations are rare. + */ +export async function clearMcpConfigCache(): Promise { + let registry: MCPServersRegistry; + try { + registry = MCPServersRegistry.getInstance(); + } catch { + return; + } + + let evictedServers: string[]; + try { + evictedServers = await registry.invalidateConfigCache(); + } catch (error) { + logger.error('[clearMcpConfigCache] Failed to invalidate config cache:', error); + return; + } + + if (!evictedServers.length) { + return; + } + + try { + const mcpManager = MCPManager.getInstance(); + if (mcpManager?.appConnections) { + await Promise.allSettled( + evictedServers.map((serverName) => mcpManager.appConnections!.disconnect(serverName)), + ); + } + } catch { + // MCPManager not yet initialized — connections cleaned up lazily + } +} diff --git a/packages/api/src/mcp/oauth/handler.ts b/packages/api/src/mcp/oauth/handler.ts index 873af5c66d..e128dec308 100644 --- a/packages/api/src/mcp/oauth/handler.ts +++ b/packages/api/src/mcp/oauth/handler.ts @@ -467,6 +467,7 @@ export class 
MCPOAuthHandler { codeVerifier, clientInfo, metadata, + ...(Object.keys(oauthHeaders).length > 0 && { oauthHeaders }), }; logger.debug( @@ -573,6 +574,7 @@ export class MCPOAuthHandler { clientInfo, metadata, resourceMetadata, + ...(Object.keys(oauthHeaders).length > 0 && { oauthHeaders }), }; logger.debug( diff --git a/packages/api/src/mcp/oauth/types.ts b/packages/api/src/mcp/oauth/types.ts index 2138b4a782..bc5f53f60c 100644 --- a/packages/api/src/mcp/oauth/types.ts +++ b/packages/api/src/mcp/oauth/types.ts @@ -89,6 +89,8 @@ export interface MCPOAuthFlowMetadata extends FlowMetadata { metadata?: OAuthMetadata; resourceMetadata?: OAuthProtectedResourceMetadata; authorizationUrl?: string; + /** Custom headers for OAuth token exchange, persisted at flow initiation for the callback. */ + oauthHeaders?: Record; } export interface MCPOAuthTokens extends OAuthTokens { diff --git a/packages/api/src/mcp/registry/MCPServerInspector.ts b/packages/api/src/mcp/registry/MCPServerInspector.ts index 7f31211680..f064fbb7e5 100644 --- a/packages/api/src/mcp/registry/MCPServerInspector.ts +++ b/packages/api/src/mcp/registry/MCPServerInspector.ts @@ -4,9 +4,9 @@ import type { MCPConnection } from '~/mcp/connection'; import type * as t from '~/mcp/types'; import { isMCPDomainAllowed, extractMCPServerDomain } from '~/auth/domain'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; +import { hasCustomUserVars, isUserSourced } from '~/mcp/utils'; import { MCPDomainNotAllowedError } from '~/mcp/errors'; import { detectOAuthRequirement } from '~/mcp/oauth'; -import { hasCustomUserVars } from '~/mcp/utils'; import { isEnabled } from '~/utils'; /** @@ -73,7 +73,7 @@ export class MCPServerInspector { this.connection = await MCPConnectionFactory.create({ serverConfig: this.config, serverName: this.serverName, - dbSourced: !!this.config.dbId, + dbSourced: isUserSourced(this.config), useSSRFProtection: this.useSSRFProtection, allowedDomains: this.allowedDomains, }); diff --git 
a/packages/api/src/mcp/registry/MCPServersRegistry.ts b/packages/api/src/mcp/registry/MCPServersRegistry.ts index 506f5b1baa..6c98a6b8dd 100644 --- a/packages/api/src/mcp/registry/MCPServersRegistry.ts +++ b/packages/api/src/mcp/registry/MCPServersRegistry.ts @@ -1,28 +1,48 @@ import { Keyv } from 'keyv'; +import { createHash } from 'crypto'; import { logger } from '@librechat/data-schemas'; import type { IServerConfigsRepositoryInterface } from './ServerConfigsRepositoryInterface'; import type * as t from '~/mcp/types'; +import { + ServerConfigsCacheFactory, + APP_CACHE_NAMESPACE, + CONFIG_CACHE_NAMESPACE, +} from './cache/ServerConfigsCacheFactory'; import { MCPInspectionFailedError, isMCPDomainNotAllowedError } from '~/mcp/errors'; -import { ServerConfigsCacheFactory } from './cache/ServerConfigsCacheFactory'; import { MCPServerInspector } from './MCPServerInspector'; import { ServerConfigsDB } from './db/ServerConfigsDB'; import { cacheConfig } from '~/cache/cacheConfig'; +import { withTimeout } from '~/utils'; + +/** How long a failure stub is considered fresh before re-attempting inspection (5 minutes). */ +const CONFIG_STUB_RETRY_MS = 5 * 60 * 1000; + +const CONFIG_SERVER_INIT_TIMEOUT_MS = (() => { + const raw = process.env.MCP_INIT_TIMEOUT_MS; + if (raw == null) { + return 30_000; + } + const parsed = parseInt(raw, 10); + return Number.isFinite(parsed) && parsed > 0 ? parsed : 30_000; +})(); /** * Central registry for managing MCP server configurations. * Authoritative source of truth for all MCP servers provided by LibreChat. 
* - * Uses a two-repository architecture: - * - Cache Repository: Stores YAML-defined configs loaded at startup (in-memory or Redis-backed) - * - DB Repository: Stores dynamic configs created at runtime (not yet implemented) + * Uses a three-layer architecture: + * - YAML Cache (cacheConfigsRepo): Operator-defined configs loaded at startup (in-memory or Redis) + * - Config Cache (configCacheRepo): Admin-defined configs from Config overrides, lazily initialized + * - DB Repository (dbConfigsRepo): User-provided configs created at runtime (MongoDB + ACL) * - * Query priority: Cache configs are checked first, then DB configs. + * Query priority: YAML cache → Config cache → DB. */ export class MCPServersRegistry { private static instance: MCPServersRegistry; private readonly dbConfigsRepo: IServerConfigsRepositoryInterface; private readonly cacheConfigsRepo: IServerConfigsRepositoryInterface; + private readonly configCacheRepo: IServerConfigsRepositoryInterface; private readonly allowedDomains?: string[] | null; private readonly readThroughCache: Keyv; private readonly readThroughCacheAll: Keyv>; @@ -31,9 +51,20 @@ export class MCPServersRegistry { Promise> >(); + /** Tracks in-flight config server initializations to prevent duplicate work. */ + private readonly pendingConfigInits = new Map< + string, + Promise + >(); + + /** Memoized YAML server names — set once after boot-time init, never changes. 
*/ + private yamlServerNames: Set<string> | null = null; + private yamlServerNamesPromise: Promise<Set<string>> | null = null; + constructor(mongoose: typeof import('mongoose'), allowedDomains?: string[] | null) { this.dbConfigsRepo = new ServerConfigsDB(mongoose); - this.cacheConfigsRepo = ServerConfigsCacheFactory.create('App', false); + this.cacheConfigsRepo = ServerConfigsCacheFactory.create(APP_CACHE_NAMESPACE, false); + this.configCacheRepo = ServerConfigsCacheFactory.create(CONFIG_CACHE_NAMESPACE, false); this.allowedDomains = allowedDomains; const ttl = cacheConfig.MCP_REGISTRY_CACHE_TTL; @@ -86,22 +117,29 @@ export class MCPServersRegistry { return !Array.isArray(this.allowedDomains) || this.allowedDomains.length === 0; } + /** + * Returns the config for a single server. When `configServers` is provided, config-source + * servers are resolved from it directly (no global state, no cross-tenant race). + */ public async getServerConfig( serverName: string, userId?: string, + configServers?: Record<string, t.ParsedServerConfig>, ): Promise<t.ParsedServerConfig | undefined> { + if (configServers?.[serverName]) { + return configServers[serverName]; + } + const cacheKey = this.getReadThroughCacheKey(serverName, userId); if (await this.readThroughCache.has(cacheKey)) { return await this.readThroughCache.get(cacheKey); } - // First we check if any config exist with the cache - // Yaml config are pre loaded to the cache - const configFromCache = await this.cacheConfigsRepo.get(serverName); - if (configFromCache) { - await this.readThroughCache.set(cacheKey, configFromCache); - return configFromCache; + const configFromYaml = await this.cacheConfigsRepo.get(serverName); + if (configFromYaml) { + await this.readThroughCache.set(cacheKey, configFromYaml); + return configFromYaml; } const configFromDB = await this.dbConfigsRepo.get(serverName, userId); @@ -109,7 +147,30 @@ export class MCPServersRegistry { return configFromDB; } - public async getAllServerConfigs(userId?: string): Promise<Record<string, t.ParsedServerConfig>> { + /** + * Returns all server configs visible to the given 
user. + * YAML and Config tiers are mutually exclusive by design (`ensureConfigServers` filters + * YAML names), so the spread order only matters for User DB (highest priority) overriding both. + */ + public async getAllServerConfigs( + userId?: string, + configServers?: Record, + ): Promise> { + if (configServers == null || !Object.keys(configServers).length) { + return this.getBaseServerConfigs(userId); + } + const base = await this.getBaseServerConfigs(userId); + return { ...configServers, ...base }; + } + + /** + * Returns YAML + user-DB server configs, cached via `readThroughCacheAll`. + * Always called by `getAllServerConfigs` so the DB query is amortized across + * requests within the TTL window regardless of whether `configServers` is present. + */ + private async getBaseServerConfigs( + userId?: string, + ): Promise> { const cacheKey = userId ?? '__no_user__'; if (await this.readThroughCacheAll.has(cacheKey)) { @@ -121,7 +182,7 @@ export class MCPServersRegistry { return pending; } - const fetchPromise = this.fetchAllServerConfigs(cacheKey, userId); + const fetchPromise = this.fetchBaseServerConfigs(cacheKey, userId); this.pendingGetAllPromises.set(cacheKey, fetchPromise); try { @@ -131,7 +192,7 @@ export class MCPServersRegistry { } } - private async fetchAllServerConfigs( + private async fetchBaseServerConfigs( cacheKey: string, userId?: string, ): Promise> { @@ -155,7 +216,8 @@ export class MCPServersRegistry { userId?: string, ): Promise { const configRepo = this.getConfigRepository(storageLocation); - const stubConfig: t.ParsedServerConfig = { ...config, inspectionFailed: true }; + const source: t.MCPServerSource = storageLocation === 'CACHE' ? 
'yaml' : 'user'; + const stubConfig: t.ParsedServerConfig = { ...config, inspectionFailed: true, source }; const result = await configRepo.add(serverName, stubConfig, userId); await this.readThroughCache.delete(this.getReadThroughCacheKey(serverName, userId)); await this.readThroughCache.delete(this.getReadThroughCacheKey(serverName)); @@ -179,13 +241,16 @@ export class MCPServersRegistry { ); } catch (error) { logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error); - // Preserve domain-specific error for better error handling if (isMCPDomainNotAllowedError(error)) { throw error; } throw new MCPInspectionFailedError(serverName, error as Error); } - return await configRepo.add(serverName, parsedConfig, userId); + const tagged = { + ...parsedConfig, + source: (storageLocation === 'CACHE' ? 'yaml' : 'user') as t.MCPServerSource, + }; + return await configRepo.add(serverName, tagged, userId); } /** @@ -267,7 +332,6 @@ export class MCPServersRegistry { ); } catch (error) { logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error); - // Preserve domain-specific error for better error handling if (isMCPDomainNotAllowedError(error)) { throw error; } @@ -277,8 +341,180 @@ export class MCPServersRegistry { return parsedConfig; } - // TODO: This is currently used to determine if a server requires OAuth. However, this info can - // can be determined through config.requiresOAuth. Refactor usages and remove this method. + /** + * Ensures that config-source MCP servers (from admin Config overrides) are initialized. + * Identifies servers in `resolvedMcpConfig` that are not from YAML, lazily initializes + * any not yet in the config cache, and returns their parsed configs. + * + * Config cache keys are scoped by a hash of the raw config to prevent cross-tenant + * cache poisoning when two tenants define a server with the same name but different configs. 
+ */ + public async ensureConfigServers( + resolvedMcpConfig: Record<string, t.MCPOptions>, + ): Promise<Record<string, t.ParsedServerConfig>> { + if (!resolvedMcpConfig || Object.keys(resolvedMcpConfig).length === 0) { + return {}; + } + + const yamlNames = await this.getYamlServerNames(); + const configServerEntries = Object.entries(resolvedMcpConfig).filter( + ([name]) => !yamlNames.has(name), + ); + + if (configServerEntries.length === 0) { + return {}; + } + + const result: Record<string, t.ParsedServerConfig> = {}; + + const settled = await Promise.allSettled( + configServerEntries.map(async ([serverName, rawConfig]) => { + const parsed = await this.ensureSingleConfigServer(serverName, rawConfig); + if (parsed) { + result[serverName] = parsed; + } + }), + ); + for (const outcome of settled) { + if (outcome.status === 'rejected') { + logger.error('[MCPServersRegistry][ensureConfigServers] Unexpected error:', outcome.reason); + } + } + + return result; + } + + /** + * Ensures a single config-source server is initialized. + * Cache key is scoped by config hash to prevent cross-tenant poisoning. + * Deduplicates concurrent init requests for the same server+config. + * Stale failure stubs are retried after `CONFIG_STUB_RETRY_MS` to recover from transient errors. + */ + private async ensureSingleConfigServer( + serverName: string, + rawConfig: t.MCPOptions, + ): Promise<t.ParsedServerConfig> { + const cacheKey = this.configCacheKey(serverName, rawConfig); + + const cached = await this.configCacheRepo.get(cacheKey); + if (cached) { + const isStaleStub = + cached.inspectionFailed && Date.now() - (cached.updatedAt ?? 
0) > CONFIG_STUB_RETRY_MS; + if (!isStaleStub) { + return cached; + } + logger.info(`[MCP][config][${serverName}] Retrying stale failure stub`); + } + + const pending = this.pendingConfigInits.get(cacheKey); + if (pending) { + return pending; + } + + const initPromise = this.lazyInitConfigServer(cacheKey, serverName, rawConfig); + this.pendingConfigInits.set(cacheKey, initPromise); + + try { + return await initPromise; + } finally { + this.pendingConfigInits.delete(cacheKey); + } + } + + /** + * Lazily initializes a config-source MCP server: inspects capabilities/tools, then + * stores the parsed config in the config cache with `source: 'config'`. + */ + private async lazyInitConfigServer( + cacheKey: string, + serverName: string, + rawConfig: t.MCPOptions, + ): Promise { + const prefix = `[MCP][config][${serverName}]`; + logger.info(`${prefix} Lazy-initializing config-source server`); + + try { + const inspected = await withTimeout( + MCPServerInspector.inspect(serverName, rawConfig, undefined, this.allowedDomains), + CONFIG_SERVER_INIT_TIMEOUT_MS, + `${prefix} Server initialization timed out`, + ); + + const parsedConfig: t.ParsedServerConfig = { ...inspected, source: 'config' }; + await this.upsertConfigCache(cacheKey, parsedConfig); + + logger.info( + `${prefix} Initialized: tools=${parsedConfig.tools ?? 'N/A'}, ` + + `duration=${parsedConfig.initDuration ?? 
'N/A'}ms`, + ); + return parsedConfig; + } catch (error) { + logger.error(`${prefix} Failed to initialize:`, error); + + const stubConfig: t.ParsedServerConfig = { + ...rawConfig, + inspectionFailed: true, + source: 'config', + updatedAt: Date.now(), + }; + try { + await this.upsertConfigCache(cacheKey, stubConfig); + logger.info(`${prefix} Stored stub config for recovery`); + } catch (cacheError) { + logger.error( + `${prefix} Failed to store stub config (will retry on next request):`, + cacheError, + ); + } + return stubConfig; + } + } + + /** + * Writes a config to `configCacheRepo` using the atomic upsert operation. + * Safe for cross-process races — the underlying cache handles add-or-update internally. + */ + private async upsertConfigCache(cacheKey: string, config: t.ParsedServerConfig): Promise { + await this.configCacheRepo.upsert(cacheKey, config); + } + + /** + * Clears the config-source server cache, forcing re-inspection on next access. + * Called when admin config overrides change (e.g., mcpServers mutation). + * + * @returns Names of servers that were evicted from the config cache. + * Callers should disconnect active connections for these servers. + */ + public async invalidateConfigCache(): Promise { + const allCached = await this.configCacheRepo.getAll(); + const evictedNames = [ + ...new Set( + Object.keys(allCached).map((key) => { + const lastColon = key.lastIndexOf(':'); + return lastColon > 0 ? key.slice(0, lastColon) : key; + }), + ), + ]; + + await Promise.all([ + this.configCacheRepo.reset(), + // Only clear readThroughCacheAll (merged results that may include stale config servers). + // readThroughCache (individual YAML/user lookups) is unaffected by config mutations. 
+ this.readThroughCacheAll.clear(), + ]); + + if (evictedNames.length > 0) { + logger.info( + `[MCPServersRegistry] Config server cache invalidated, evicted: ${evictedNames.join(', ')}`, + ); + } + return evictedNames; + } + + // TODO: Refactor callers to use config.requiresOAuth directly instead of this method. + // Known gap: config-source OAuth servers are not included here because callers + // (OAuthReconnectionManager, UserController) lack request context to resolve configServers. + // Config-source OAuth auto-reconnection and uninstall cleanup require a separate mechanism. public async getOAuthServers(userId?: string): Promise> { const allServers = await this.getAllServerConfigs(userId); const oauthServers = Object.entries(allServers).filter(([, config]) => config.requiresOAuth); @@ -287,8 +523,11 @@ export class MCPServersRegistry { public async reset(): Promise { await this.cacheConfigsRepo.reset(); + await this.configCacheRepo.reset(); await this.readThroughCache.clear(); await this.readThroughCacheAll.clear(); + this.yamlServerNames = null; + this.yamlServerNamesPromise = null; } public async removeServer( @@ -316,4 +555,48 @@ export class MCPServersRegistry { private getReadThroughCacheKey(serverName: string, userId?: string): string { return userId ? `${serverName}::${userId}` : serverName; } + + /** + * Returns memoized YAML server names. Populated lazily on first call after boot/reset. + * YAML servers don't change after boot, so this avoids repeated `getAll()` calls. + * Uses promise deduplication to prevent concurrent cold-start double-fetch. 
+ */ + private getYamlServerNames(): Promise<Set<string>> { + if (this.yamlServerNames) { + return Promise.resolve(this.yamlServerNames); + } + if (this.yamlServerNamesPromise) { + return this.yamlServerNamesPromise; + } + this.yamlServerNamesPromise = this.cacheConfigsRepo + .getAll() + .then((configs) => { + this.yamlServerNames = new Set(Object.keys(configs)); + this.yamlServerNamesPromise = null; + return this.yamlServerNames; + }) + .catch((err) => { + this.yamlServerNamesPromise = null; + throw err; + }); + return this.yamlServerNamesPromise; + } + + /** + * Produces a config-cache key scoped by server name AND a hash of the raw config. + * Prevents cross-tenant cache poisoning when two tenants define the same server name + * with different configurations. + */ + private configCacheKey(serverName: string, rawConfig: t.MCPOptions): string { + const sorted = JSON.stringify(rawConfig, (_key, value: unknown) => { + if (value !== null && typeof value === 'object' && !Array.isArray(value)) { + return Object.fromEntries( + Object.entries(value as Record<string, unknown>).sort(([a], [b]) => a.localeCompare(b)), + ); + } + return value; + }); + const hash = createHash('sha256').update(sorted).digest('hex').slice(0, 16); + return `${serverName}:${hash}`; + } } diff --git a/packages/api/src/mcp/registry/ServerConfigsRepositoryInterface.ts b/packages/api/src/mcp/registry/ServerConfigsRepositoryInterface.ts index 1c913dd1a3..4bf0fdd615 100644 --- a/packages/api/src/mcp/registry/ServerConfigsRepositoryInterface.ts +++ b/packages/api/src/mcp/registry/ServerConfigsRepositoryInterface.ts @@ -9,6 +9,9 @@ export interface IServerConfigsRepositoryInterface { //ACL Entry check if update is possible update(serverName: string, config: ParsedServerConfig, userId?: string): Promise<void>; + /** Atomic add-or-update without requiring callers to inspect error messages. 
*/ + upsert(serverName: string, config: ParsedServerConfig, userId?: string): Promise; + //ACL Entry check if remove is possible remove(serverName: string, userId?: string): Promise; diff --git a/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts b/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts index f0ab75c9b4..2012f82e31 100644 --- a/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts +++ b/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts @@ -321,12 +321,12 @@ describe('MCPServerInspector', () => { const result = await MCPServerInspector.inspect('test_server', rawConfig); // Verify factory was called to create connection - expect(MCPConnectionFactory.create).toHaveBeenCalledWith({ - serverName: 'test_server', - serverConfig: expect.objectContaining({ type: 'stdio', command: 'node' }), - useSSRFProtection: true, - dbSourced: false, - }); + expect(MCPConnectionFactory.create).toHaveBeenCalledWith( + expect.objectContaining({ + serverName: 'test_server', + serverConfig: expect.objectContaining({ type: 'stdio', command: 'node' }), + }), + ); // Verify temporary connection was disconnected expect(tempMockConnection.disconnect).toHaveBeenCalled(); diff --git a/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts b/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts index 8891120717..a20c09705f 100644 --- a/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts +++ b/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts @@ -112,8 +112,8 @@ describe('MCPServersRegistry', () => { const userConfigBefore = await registry.getServerConfig('user_server'); const allConfigsBefore = await registry.getAllServerConfigs(); - expect(appConfigBefore).toEqual(testParsedConfig); - expect(userConfigBefore).toEqual(testParsedConfig); + expect(appConfigBefore).toEqual(expect.objectContaining(testParsedConfig)); + 
expect(userConfigBefore).toEqual(expect.objectContaining(testParsedConfig)); expect(Object.keys(allConfigsBefore)).toHaveLength(2); // Reset everything @@ -250,22 +250,18 @@ describe('MCPServersRegistry', () => { }); it('should use different cache keys for different userIds', async () => { - // Spy on the cache repository get method + await registry['cacheConfigsRepo'].add('test_server', testParsedConfig); const cacheRepoGetSpy = jest.spyOn(registry['cacheConfigsRepo'], 'get'); - // First call without userId await registry.getServerConfig('test_server'); expect(cacheRepoGetSpy).toHaveBeenCalledTimes(1); - // Call with userId - should be a different cache key, so hits repository again await registry.getServerConfig('test_server', 'user123'); expect(cacheRepoGetSpy).toHaveBeenCalledTimes(2); - // Repeat call with same userId - should hit read-through cache await registry.getServerConfig('test_server', 'user123'); - expect(cacheRepoGetSpy).toHaveBeenCalledTimes(2); // Still 2 + expect(cacheRepoGetSpy).toHaveBeenCalledTimes(2); - // Call with different userId - should hit repository await registry.getServerConfig('test_server', 'user456'); expect(cacheRepoGetSpy).toHaveBeenCalledTimes(3); }); diff --git a/packages/api/src/mcp/registry/__tests__/ensureConfigServers.test.ts b/packages/api/src/mcp/registry/__tests__/ensureConfigServers.test.ts new file mode 100644 index 0000000000..70eb2f75c4 --- /dev/null +++ b/packages/api/src/mcp/registry/__tests__/ensureConfigServers.test.ts @@ -0,0 +1,328 @@ +import type * as t from '~/mcp/types'; +import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; +import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector'; + +jest.mock('~/mcp/registry/MCPServerInspector'); +jest.mock('~/mcp/registry/db/ServerConfigsDB', () => ({ + ServerConfigsDB: jest.fn().mockImplementation(() => ({ + get: jest.fn().mockResolvedValue(undefined), + getAll: jest.fn().mockResolvedValue({}), + add: 
jest.fn().mockResolvedValue(undefined), + update: jest.fn().mockResolvedValue(undefined), + upsert: jest.fn().mockResolvedValue(undefined), + remove: jest.fn().mockResolvedValue(undefined), + reset: jest.fn().mockResolvedValue(undefined), + })), +})); + +const FIXED_TIME = 1699564800000; + +const mockMongoose = {} as typeof import('mongoose'); + +const sseConfig: t.MCPOptions = { + type: 'sse', + url: 'https://mcp.example.com/sse', +} as unknown as t.MCPOptions; + +const altSseConfig: t.MCPOptions = { + type: 'sse', + url: 'https://mcp.other-tenant.com/sse', +} as unknown as t.MCPOptions; + +const yamlConfig: t.MCPOptions = { + type: 'stdio', + command: 'node', + args: ['tools.js'], +} as unknown as t.MCPOptions; + +function makeParsedConfig(overrides: Partial = {}): t.ParsedServerConfig { + return { + type: 'sse', + url: 'https://mcp.example.com/sse', + requiresOAuth: false, + tools: 'tool_a, tool_b', + capabilities: '{}', + initDuration: 42, + ...overrides, + } as unknown as t.ParsedServerConfig; +} + +describe('MCPServersRegistry — ensureConfigServers', () => { + let registry: MCPServersRegistry; + let inspectSpy: jest.SpyInstance; + + beforeAll(() => { + jest.useFakeTimers(); + jest.setSystemTime(new Date(FIXED_TIME)); + }); + + afterAll(() => { + jest.useRealTimers(); + }); + + beforeEach(async () => { + (MCPServersRegistry as unknown as { instance: undefined }).instance = undefined; + MCPServersRegistry.createInstance(mockMongoose); + registry = MCPServersRegistry.getInstance(); + + inspectSpy = jest + .spyOn(MCPServerInspector, 'inspect') + .mockImplementation(async (_serverName: string, rawConfig: t.MCPOptions) => + makeParsedConfig(rawConfig as unknown as Partial), + ); + + await registry.reset(); + }); + + afterEach(() => { + inspectSpy.mockClear(); + }); + + it('should return empty for empty input', async () => { + expect(await registry.ensureConfigServers({})).toEqual({}); + }); + + it('should return empty for null/undefined input', async () => { + 
expect( + await registry.ensureConfigServers(null as unknown as Record), + ).toEqual({}); + expect( + await registry.ensureConfigServers(undefined as unknown as Record), + ).toEqual({}); + }); + + it('should exclude YAML servers from config-source detection', async () => { + await registry.addServer('yaml_server', yamlConfig, 'CACHE'); + + const result = await registry.ensureConfigServers({ + yaml_server: yamlConfig, + config_server: sseConfig, + }); + + expect(result).toHaveProperty('config_server'); + expect(result).not.toHaveProperty('yaml_server'); + }); + + it('should return empty when all servers are YAML', async () => { + await registry.addServer('yaml_a', yamlConfig, 'CACHE'); + await registry.addServer('yaml_b', yamlConfig, 'CACHE'); + inspectSpy.mockClear(); + + const result = await registry.ensureConfigServers({ + yaml_a: yamlConfig, + yaml_b: yamlConfig, + }); + + expect(result).toEqual({}); + expect(inspectSpy).not.toHaveBeenCalled(); + }); + + it('should lazy-initialize a config-source server and tag source as config', async () => { + const result = await registry.ensureConfigServers({ my_server: sseConfig }); + + expect(result).toHaveProperty('my_server'); + expect(result.my_server.source).toBe('config'); + expect(inspectSpy).toHaveBeenCalledTimes(1); + expect(inspectSpy).toHaveBeenCalledWith('my_server', sseConfig, undefined, undefined); + }); + + it('should return cached result on second call without re-inspecting', async () => { + await registry.ensureConfigServers({ my_server: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(1); + + const result2 = await registry.ensureConfigServers({ my_server: sseConfig }); + expect(result2).toHaveProperty('my_server'); + expect(result2.my_server.source).toBe('config'); + expect(inspectSpy).toHaveBeenCalledTimes(1); + }); + + it('should store inspectionFailed stub on inspection failure', async () => { + inspectSpy.mockRejectedValueOnce(new Error('connection refused')); + + const result = await 
registry.ensureConfigServers({ bad_server: sseConfig }); + + expect(result).toHaveProperty('bad_server'); + expect(result.bad_server.inspectionFailed).toBe(true); + expect(result.bad_server.source).toBe('config'); + }); + + it('should return stub from cache on repeated failure without re-inspecting', async () => { + inspectSpy.mockRejectedValueOnce(new Error('connection refused')); + await registry.ensureConfigServers({ bad_server: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(1); + + const result2 = await registry.ensureConfigServers({ bad_server: sseConfig }); + expect(result2.bad_server.inspectionFailed).toBe(true); + expect(inspectSpy).toHaveBeenCalledTimes(1); + }); + + it('should retry stale failure stub after CONFIG_STUB_RETRY_MS', async () => { + inspectSpy.mockRejectedValueOnce(new Error('transient DNS failure')); + await registry.ensureConfigServers({ flaky_server: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(1); + + jest.setSystemTime(new Date(FIXED_TIME + 6 * 60 * 1000)); + + const result = await registry.ensureConfigServers({ flaky_server: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(2); + expect(result.flaky_server.inspectionFailed).toBeUndefined(); + expect(result.flaky_server.source).toBe('config'); + + jest.setSystemTime(new Date(FIXED_TIME)); + }); + + describe('cross-tenant isolation', () => { + it('should use different cache keys for same server name with different configs', async () => { + inspectSpy.mockClear(); + const resultA = await registry.ensureConfigServers({ shared_name: sseConfig }); + expect(resultA.shared_name.source).toBe('config'); + expect(inspectSpy).toHaveBeenCalledTimes(1); + + const resultB = await registry.ensureConfigServers({ shared_name: altSseConfig }); + expect(resultB.shared_name.source).toBe('config'); + expect(inspectSpy).toHaveBeenCalledTimes(2); + }); + + it('should return tenant-A config for tenant-A and tenant-B config for tenant-B', async () => { + const resultA = await 
registry.ensureConfigServers({ srv: sseConfig }); + const resultB = await registry.ensureConfigServers({ srv: altSseConfig }); + + expect((resultA.srv as unknown as { url: string }).url).toBe('https://mcp.example.com/sse'); + expect((resultB.srv as unknown as { url: string }).url).toBe( + 'https://mcp.other-tenant.com/sse', + ); + }); + }); + + describe('concurrent deduplication', () => { + it('should only inspect once for multiple parallel calls with the same config', async () => { + inspectSpy.mockClear(); + // Fire two calls simultaneously — both see cache miss, but only one should inspect + const [r1, r2] = await Promise.all([ + registry.ensureConfigServers({ dedup_srv: sseConfig }), + registry.ensureConfigServers({ dedup_srv: sseConfig }), + ]); + + expect(r1.dedup_srv).toBeDefined(); + expect(r2.dedup_srv).toBeDefined(); + expect(inspectSpy).toHaveBeenCalledTimes(1); + + // Subsequent call must NOT re-inspect (cached) + inspectSpy.mockClear(); + await registry.ensureConfigServers({ dedup_srv: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(0); + }); + }); + + describe('merge order', () => { + it('should merge YAML → config → user with correct precedence in getAllServerConfigs', async () => { + await registry.addServer('yaml_srv', yamlConfig, 'CACHE'); + + const configServers = await registry.ensureConfigServers({ config_srv: sseConfig }); + + const all = await registry.getAllServerConfigs(undefined, configServers); + expect(all).toHaveProperty('yaml_srv'); + expect(all).toHaveProperty('config_srv'); + expect(all.yaml_srv.source).toBe('yaml'); + expect(all.config_srv.source).toBe('config'); + }); + + it('should let config servers appear alongside user DB servers', async () => { + const mockDbConfigs = { + user_srv: makeParsedConfig({ source: 'user', dbId: 'abc123' }), + }; + jest.spyOn(registry['dbConfigsRepo'], 'getAll').mockResolvedValue(mockDbConfigs); + + const configServers = await registry.ensureConfigServers({ config_srv: sseConfig }); + const 
all = await registry.getAllServerConfigs('user-1', configServers); + + expect(all).toHaveProperty('config_srv'); + expect(all).toHaveProperty('user_srv'); + expect(all.config_srv.source).toBe('config'); + expect(all.user_srv.source).toBe('user'); + }); + }); + + describe('invalidateConfigCache', () => { + it('should clear config cache and force re-inspection on next call', async () => { + await registry.ensureConfigServers({ my_server: sseConfig }); + inspectSpy.mockClear(); + + await registry.invalidateConfigCache(); + + await registry.ensureConfigServers({ my_server: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(1); + }); + + it('should return evicted server names', async () => { + await registry.ensureConfigServers({ srv_a: sseConfig, srv_b: altSseConfig }); + const evicted = await registry.invalidateConfigCache(); + expect(evicted.length).toBeGreaterThan(0); + }); + + it('should return empty array when nothing is cached', async () => { + const evicted = await registry.invalidateConfigCache(); + expect(evicted).toEqual([]); + }); + }); + + describe('getServerConfig with configServers', () => { + it('should return config-source server when configServers is passed', async () => { + const configServers = await registry.ensureConfigServers({ config_srv: sseConfig }); + const config = await registry.getServerConfig('config_srv', undefined, configServers); + expect(config).toBeDefined(); + expect(config?.source).toBe('config'); + }); + + it('should return config-source server with userId when configServers is passed', async () => { + const configServers = await registry.ensureConfigServers({ config_srv: sseConfig }); + const config = await registry.getServerConfig('config_srv', 'user-123', configServers); + expect(config).toBeDefined(); + expect(config?.source).toBe('config'); + }); + + it('should return undefined for config-source server without configServers (tenant isolation)', async () => { + await registry.ensureConfigServers({ config_srv: sseConfig 
}); + const config = await registry.getServerConfig('config_srv'); + expect(config).toBeUndefined(); + }); + + it('should return correct config after invalidation and re-init', async () => { + const configServers1 = await registry.ensureConfigServers({ config_srv: sseConfig }); + expect(await registry.getServerConfig('config_srv', undefined, configServers1)).toBeDefined(); + + await registry.invalidateConfigCache(); + + const configServers2 = await registry.ensureConfigServers({ config_srv: sseConfig }); + const config = await registry.getServerConfig('config_srv', undefined, configServers2); + expect(config).toBeDefined(); + expect(config?.source).toBe('config'); + }); + + it('should not cross-contaminate between tenant configServers maps', async () => { + const tenantA = await registry.ensureConfigServers({ srv: sseConfig }); + const tenantB = await registry.ensureConfigServers({ srv: altSseConfig }); + + const configA = await registry.getServerConfig('srv', undefined, tenantA); + const configB = await registry.getServerConfig('srv', undefined, tenantB); + + expect((configA as unknown as { url: string }).url).toBe('https://mcp.example.com/sse'); + expect((configB as unknown as { url: string }).url).toBe('https://mcp.other-tenant.com/sse'); + }); + }); + + describe('source tagging', () => { + it('should tag CACHE-stored servers as yaml', async () => { + await registry.addServer('yaml_srv', yamlConfig, 'CACHE'); + const config = await registry.getServerConfig('yaml_srv'); + expect(config?.source).toBe('yaml'); + }); + + it('should tag stubs as yaml when stored in CACHE', async () => { + await registry.addServerStub('stub_srv', yamlConfig, 'CACHE'); + const config = await registry.getServerConfig('stub_srv'); + expect(config?.source).toBe('yaml'); + expect(config?.inspectionFailed).toBe(true); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts index 
ba0cec90ea..ebe19b59e3 100644 --- a/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts @@ -1,31 +1,57 @@ -import { cacheConfig } from '~/cache'; +import { ServerConfigsCacheRedisAggregateKey } from './ServerConfigsCacheRedisAggregateKey'; import { ServerConfigsCacheInMemory } from './ServerConfigsCacheInMemory'; import { ServerConfigsCacheRedis } from './ServerConfigsCacheRedis'; +import { cacheConfig } from '~/cache'; -export type ServerConfigsCache = ServerConfigsCacheInMemory | ServerConfigsCacheRedis; +export type ServerConfigsCache = + | ServerConfigsCacheInMemory + | ServerConfigsCacheRedis + | ServerConfigsCacheRedisAggregateKey; /** - * Factory for creating the appropriate ServerConfigsCache implementation based on deployment mode. - * Automatically selects between in-memory and Redis-backed storage depending on USE_REDIS config. - * In single-instance mode (USE_REDIS=false), returns lightweight in-memory cache. - * In cluster mode (USE_REDIS=true), returns Redis-backed cache with distributed coordination. - * Provides a unified interface regardless of the underlying storage mechanism. + * Namespace for YAML-loaded app-level MCP configs. When Redis is enabled, uses a single + * aggregate key instead of per-server keys to avoid the costly SCAN + batch-GET pattern + * in {@link ServerConfigsCacheRedis.getAll} that caused 60s+ stalls under concurrent + * load (see GitHub #11624, #12408). When Redis is disabled, uses in-memory storage. + */ +export const APP_CACHE_NAMESPACE = 'App' as const; + +/** Namespace for admin-defined config-override MCP server inspection results. */ +export const CONFIG_CACHE_NAMESPACE = 'Config' as const; + +/** Namespaces that use the aggregate-key optimization to avoid SCAN+N-GETs stalls. 
*/ +const AGGREGATE_KEY_NAMESPACES = new Set([APP_CACHE_NAMESPACE, CONFIG_CACHE_NAMESPACE]); + +/** + * Factory for creating the appropriate ServerConfigsCache implementation based on + * deployment mode and namespace. + * + * Namespaces in {@link AGGREGATE_KEY_NAMESPACES} use {@link ServerConfigsCacheRedisAggregateKey} + * when Redis is enabled — storing all configs under a single key so `getAll()` is one GET + * instead of SCAN + N GETs. Cross-instance visibility is preserved: reinspection results + * propagate through Redis automatically. + * + * Other namespaces use the standard {@link ServerConfigsCacheRedis} (per-key storage with + * SCAN-based enumeration) when Redis is enabled. */ export class ServerConfigsCacheFactory { /** * Create a ServerConfigsCache instance. - * Returns Redis implementation if Redis is configured, otherwise in-memory implementation. * - * @param namespace - The namespace for the cache (e.g., 'App') - only used for Redis namespacing - * @param leaderOnly - Whether operations should only be performed by the leader (only applies to Redis) + * @param namespace - The namespace for the cache. Namespaces in {@link AGGREGATE_KEY_NAMESPACES} + * use aggregate-key Redis storage (or in-memory when Redis is disabled). + * @param leaderOnly - Whether write operations should only be performed by the leader. 
* @returns ServerConfigsCache instance */ static create(namespace: string, leaderOnly: boolean): ServerConfigsCache { - if (cacheConfig.USE_REDIS) { - return new ServerConfigsCacheRedis(namespace, leaderOnly); + if (!cacheConfig.USE_REDIS) { + return new ServerConfigsCacheInMemory(); } - // In-memory mode uses a simple Map - doesn't need namespace - return new ServerConfigsCacheInMemory(); + if (AGGREGATE_KEY_NAMESPACES.has(namespace)) { + return new ServerConfigsCacheRedisAggregateKey(namespace, leaderOnly); + } + + return new ServerConfigsCacheRedis(namespace, leaderOnly); } } diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts index 384c477756..5a7fd35b9f 100644 --- a/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts @@ -28,6 +28,10 @@ export class ServerConfigsCacheInMemory { this.cache.set(serverName, { ...config, updatedAt: Date.now() }); } + public async upsert(serverName: string, config: ParsedServerConfig): Promise { + this.cache.set(serverName, { ...config, updatedAt: Date.now() }); + } + public async remove(serverName: string): Promise { if (!this.cache.delete(serverName)) { throw new Error(`Failed to remove server "${serverName}" in cache.`); diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts index d3154baf73..af1316056d 100644 --- a/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts @@ -52,6 +52,12 @@ export class ServerConfigsCacheRedis this.successCheck(`update ${this.namespace} server "${serverName}"`, success); } + public async upsert(serverName: string, config: ParsedServerConfig): Promise { + if (this.leaderOnly) await this.leaderCheck(`upsert ${this.namespace} MCP servers`); + const success = await 
this.cache.set(serverName, { ...config, updatedAt: Date.now() }); + this.successCheck(`upsert ${this.namespace} server "${serverName}"`, success); + } + public async remove(serverName: string): Promise { if (this.leaderOnly) await this.leaderCheck(`remove ${this.namespace} MCP servers`); const success = await this.cache.delete(serverName); diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts new file mode 100644 index 0000000000..5fc32bd7aa --- /dev/null +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts @@ -0,0 +1,193 @@ +import type Keyv from 'keyv'; +import type { IServerConfigsRepositoryInterface } from '~/mcp/registry/ServerConfigsRepositoryInterface'; +import type { ParsedServerConfig, AddServerResult } from '~/mcp/types'; +import { BaseRegistryCache } from './BaseRegistryCache'; +import { cacheConfig, standardCache } from '~/cache'; + +/** + * Redis-backed MCP server configs cache that stores all entries under a single aggregate key. + * + * Unlike {@link ServerConfigsCacheRedis} which uses SCAN + batch-GET for `getAll()`, this + * implementation stores the entire config map as a single JSON value in Redis. This makes + * `getAll()` a single O(1) GET regardless of keyspace size, eliminating the 60s+ stalls + * caused by SCAN under concurrent load in large deployments (see GitHub #11624, #12408). + * + * Trade-offs: + * - `add/update/remove` use a serialized read-modify-write on the aggregate key via a + * promise-based mutex. This prevents concurrent writes from racing within a single + * process (e.g., during `Promise.allSettled` initialization of multiple servers). + * - The entire config map is serialized/deserialized on every operation. With typical MCP + * deployments (~5-50 servers), the JSON payload is small (10-50KB). 
+ * - Cross-instance visibility is preserved: all instances read/write the same Redis key, + * so reinspection results propagate automatically after readThroughCache TTL expiry. + * + * IMPORTANT: The promise-based writeLock serializes writes within a single Node.js process + * only. Concurrent writes from separate instances race at the Redis level (last-write-wins). + * This is acceptable because writes are performed exclusively by the leader during + * initialization via {@link MCPServersInitializer}. `reinspectServer` is manual and rare. + * Callers must enforce this single-writer invariant externally. + */ +const AGGREGATE_KEY = '__all__'; + +export class ServerConfigsCacheRedisAggregateKey + extends BaseRegistryCache + implements IServerConfigsRepositoryInterface +{ + protected readonly cache: Keyv; + private writeLock: Promise = Promise.resolve(); + + /** + * In-memory snapshot of the aggregate key to avoid redundant Redis GETs. + * `getAll()` is called 20+ times per chat request (once per tool, per server + * config lookup, per connection check) but the data doesn't change within a + * request cycle. The snapshot collapses all reads within the TTL window into + * a single Redis GET. Invalidated on every write (`add`, `update`, `remove`, `reset`). + * + * NOTE: In multi-instance deployments, the effective max staleness for cross-instance + * writes is up to 2×MCP_REGISTRY_CACHE_TTL. This happens when readThroughCacheAll + * (MCPServersRegistry) is populated from a snapshot that is nearly expired. For the + * default 5000ms TTL, worst-case cross-instance propagation is ~10s. This is acceptable + * given the single-writer invariant (leader-only initialization, rare manual reinspection). + */ + private localSnapshot: Record | null = null; + /** Milliseconds since epoch. 0 = epoch = always expired on first check. 
*/ + private localSnapshotExpiry = 0; + + private readonly namespace: string; + + constructor(namespace: string, leaderOnly: boolean) { + super(leaderOnly); + this.namespace = namespace; + this.cache = standardCache(`${this.PREFIX}::Servers::${namespace}`); + } + + private invalidateLocalSnapshot(): void { + this.localSnapshot = null; + this.localSnapshotExpiry = 0; + } + + /** + * Serializes write operations to prevent concurrent read-modify-write races. + * Reads (`get`, `getAll`) are not serialized — they can run concurrently. + * Always invalidates the local snapshot in `finally` to guarantee cleanup + * even when the write callback throws (e.g., Redis SET failure). + */ + private async withWriteLock(fn: () => Promise): Promise { + const previousLock = this.writeLock; + let resolve!: () => void; + this.writeLock = new Promise((r) => { + resolve = r; + }); + try { + await previousLock; + return await fn(); + } finally { + this.invalidateLocalSnapshot(); + resolve(); + } + } + + public async getAll(): Promise> { + const ttl = cacheConfig.MCP_REGISTRY_CACHE_TTL; + if (ttl > 0) { + const now = Date.now(); + if (this.localSnapshot !== null && now < this.localSnapshotExpiry) { + return this.localSnapshot; + } + } + + const result = + ((await this.cache.get(AGGREGATE_KEY)) as Record | undefined) ?? + {}; + + if (ttl > 0) { + this.localSnapshot = result; + this.localSnapshotExpiry = Date.now() + ttl; + } + return result; + } + + public async get(serverName: string): Promise { + const all = await this.getAll(); + return all[serverName]; + } + + public async add(serverName: string, config: ParsedServerConfig): Promise { + if (this.leaderOnly) await this.leaderCheck('add MCP servers'); + return this.withWriteLock(async () => { + // Force fresh Redis read so the read-modify-write uses current data, + // not a snapshot that may predate this write. Distinct from the finally-block + // invalidation which cleans up after the write completes or throws. 
+ this.invalidateLocalSnapshot(); + const all = await this.getAll(); + if (all[serverName]) { + throw new Error( + `Server "${serverName}" already exists in cache. Use update() to modify existing configs.`, + ); + } + const storedConfig = { ...config, updatedAt: Date.now() }; + const newAll = { ...all, [serverName]: storedConfig }; + const success = await this.cache.set(AGGREGATE_KEY, newAll); + this.successCheck(`add ${this.namespace} server "${serverName}"`, success); + return { serverName, config: storedConfig }; + }); + } + + public async update(serverName: string, config: ParsedServerConfig): Promise { + if (this.leaderOnly) await this.leaderCheck('update MCP servers'); + return this.withWriteLock(async () => { + this.invalidateLocalSnapshot(); // Force fresh Redis read (see add() comment) + const all = await this.getAll(); + if (!all[serverName]) { + throw new Error( + `Server "${serverName}" does not exist in cache. Use add() to create new configs.`, + ); + } + const newAll = { ...all, [serverName]: { ...config, updatedAt: Date.now() } }; + const success = await this.cache.set(AGGREGATE_KEY, newAll); + this.successCheck(`update ${this.namespace} server "${serverName}"`, success); + }); + } + + public async upsert(serverName: string, config: ParsedServerConfig): Promise { + if (this.leaderOnly) await this.leaderCheck('upsert MCP servers'); + return this.withWriteLock(async () => { + this.invalidateLocalSnapshot(); + const all = await this.getAll(); + const newAll = { ...all, [serverName]: { ...config, updatedAt: Date.now() } }; + const success = await this.cache.set(AGGREGATE_KEY, newAll); + this.successCheck(`upsert ${this.namespace} server "${serverName}"`, success); + }); + } + + public async remove(serverName: string): Promise { + if (this.leaderOnly) await this.leaderCheck('remove MCP servers'); + return this.withWriteLock(async () => { + this.invalidateLocalSnapshot(); // Force fresh Redis read (see add() comment) + const all = await this.getAll(); + if 
(!all[serverName]) { + throw new Error(`Failed to remove server "${serverName}" in cache.`); + } + const { [serverName]: _, ...newAll } = all; + const success = await this.cache.set(AGGREGATE_KEY, newAll); + this.successCheck(`remove ${this.namespace} server "${serverName}"`, success); + }); + } + + /** + * Resets the aggregate key directly instead of using SCAN-based `cache.clear()`. + * Only one key (`__all__`) ever exists in this namespace, so a targeted delete is + * more efficient and consistent with the PR's goal of eliminating SCAN operations. + * + * NOTE: Intentionally not serialized via `withWriteLock`. `reset()` is only called + * during lifecycle transitions (test teardown, full reinitialization via + * `MCPServersInitializer`) where no concurrent writes are in flight. + */ + public override async reset(): Promise { + if (this.leaderOnly) { + await this.leaderCheck(`reset ${this.namespace} MCP servers cache`); + } + await this.cache.delete(AGGREGATE_KEY); + this.invalidateLocalSnapshot(); + } +} diff --git a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts index 7499ae127e..577b878cc7 100644 --- a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts +++ b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts @@ -1,9 +1,11 @@ -import { ServerConfigsCacheFactory } from '../ServerConfigsCacheFactory'; +import { ServerConfigsCacheFactory, APP_CACHE_NAMESPACE } from '../ServerConfigsCacheFactory'; +import { ServerConfigsCacheRedisAggregateKey } from '../ServerConfigsCacheRedisAggregateKey'; import { ServerConfigsCacheInMemory } from '../ServerConfigsCacheInMemory'; import { ServerConfigsCacheRedis } from '../ServerConfigsCacheRedis'; import { cacheConfig } from '~/cache'; // Mock the cache implementations +jest.mock('../ServerConfigsCacheRedisAggregateKey'); 
jest.mock('../ServerConfigsCacheInMemory'); jest.mock('../ServerConfigsCacheRedis'); @@ -17,53 +19,48 @@ jest.mock('~/cache', () => ({ describe('ServerConfigsCacheFactory', () => { beforeEach(() => { jest.clearAllMocks(); + cacheConfig.USE_REDIS = false; }); describe('create()', () => { - it('should return ServerConfigsCacheRedis when USE_REDIS is true', () => { - // Arrange + it('should return ServerConfigsCacheRedisAggregateKey for App namespace when USE_REDIS is true', () => { cacheConfig.USE_REDIS = true; - // Act - const cache = ServerConfigsCacheFactory.create('App', true); + const cache = ServerConfigsCacheFactory.create(APP_CACHE_NAMESPACE, false); - // Assert - expect(cache).toBeInstanceOf(ServerConfigsCacheRedis); - expect(ServerConfigsCacheRedis).toHaveBeenCalledWith('App', true); + expect(cache).toBeInstanceOf(ServerConfigsCacheRedisAggregateKey); + expect(ServerConfigsCacheRedisAggregateKey).toHaveBeenCalledWith(APP_CACHE_NAMESPACE, false); + expect(ServerConfigsCacheRedis).not.toHaveBeenCalled(); + expect(ServerConfigsCacheInMemory).not.toHaveBeenCalled(); }); - it('should return ServerConfigsCacheInMemory when USE_REDIS is false', () => { - // Arrange + it('should return ServerConfigsCacheInMemory for App namespace when USE_REDIS is false', () => { cacheConfig.USE_REDIS = false; - // Act - const cache = ServerConfigsCacheFactory.create('App', false); + const cache = ServerConfigsCacheFactory.create(APP_CACHE_NAMESPACE, false); - // Assert expect(cache).toBeInstanceOf(ServerConfigsCacheInMemory); - expect(ServerConfigsCacheInMemory).toHaveBeenCalled(); + expect(ServerConfigsCacheInMemory).toHaveBeenCalledWith(); + expect(ServerConfigsCacheRedis).not.toHaveBeenCalled(); + expect(ServerConfigsCacheRedisAggregateKey).not.toHaveBeenCalled(); }); - it('should pass correct parameters to ServerConfigsCacheRedis', () => { - // Arrange + it('should return ServerConfigsCacheRedis for non-App namespaces when USE_REDIS is true', () => { cacheConfig.USE_REDIS = 
true; - // Act - ServerConfigsCacheFactory.create('CustomNamespace', true); + const cache = ServerConfigsCacheFactory.create('CustomNamespace', true); - // Assert + expect(cache).toBeInstanceOf(ServerConfigsCacheRedis); expect(ServerConfigsCacheRedis).toHaveBeenCalledWith('CustomNamespace', true); + expect(ServerConfigsCacheRedisAggregateKey).not.toHaveBeenCalled(); }); - it('should create ServerConfigsCacheInMemory without parameters when USE_REDIS is false', () => { - // Arrange + it('should return ServerConfigsCacheInMemory for non-App namespaces when USE_REDIS is false', () => { cacheConfig.USE_REDIS = false; - // Act - ServerConfigsCacheFactory.create('App', false); + const cache = ServerConfigsCacheFactory.create('CustomNamespace', false); - // Assert - // In-memory cache doesn't use namespace/leaderOnly parameters + expect(cache).toBeInstanceOf(ServerConfigsCacheInMemory); expect(ServerConfigsCacheInMemory).toHaveBeenCalledWith(); }); }); diff --git a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts index c123325c1f..b8827a3fe9 100644 --- a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts +++ b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts @@ -12,6 +12,7 @@ describe('ServerConfigsCacheInMemory Integration Tests', () => { // Test data const mockConfig1: ParsedServerConfig = { + type: 'stdio', command: 'node', args: ['server1.js'], env: { TEST: 'value1' }, @@ -19,6 +20,7 @@ describe('ServerConfigsCacheInMemory Integration Tests', () => { }; const mockConfig2: ParsedServerConfig = { + type: 'stdio', command: 'python', args: ['server2.py'], env: { TEST: 'value2' }, @@ -26,6 +28,7 @@ describe('ServerConfigsCacheInMemory Integration Tests', () => { }; const mockConfig3: ParsedServerConfig = { + type: 'stdio', command: 'node', args: ['server3.js'], url: 'http://localhost:3000', diff 
--git a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.perf_benchmark.manual.spec.ts b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.perf_benchmark.manual.spec.ts new file mode 100644 index 0000000000..d9dc7bb978 --- /dev/null +++ b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.perf_benchmark.manual.spec.ts @@ -0,0 +1,343 @@ +/** + * Performance benchmark for ServerConfigsCacheRedis.getAll() + * + * Requires a live Redis instance. Run manually (excluded from CI): + * npx jest --config packages/api/jest.config.mjs --testPathPatterns="perf_benchmark" --coverage=false + * + * Set env vars as needed: + * USE_REDIS=true REDIS_URI=redis://localhost:6379 npx jest ... + * + * This benchmark isolates the two phases of getAll() — SCAN (key discovery) and + * batched GET (value retrieval) — to identify the actual bottleneck under load. + * It also benchmarks alternative approaches (single aggregate key, MGET) against + * the current SCAN+GET implementation. + */ +import { expect } from '@playwright/test'; +import type { RedisClientType } from 'redis'; +import type { ParsedServerConfig } from '~/mcp/types'; + +describe('ServerConfigsCacheRedis Performance Benchmark', () => { + let ServerConfigsCacheRedis: typeof import('../ServerConfigsCacheRedis').ServerConfigsCacheRedis; + let keyvRedisClient: Awaited['keyvRedisClient']; + let standardCache: Awaited['standardCache']; + + const PREFIX = 'perf-bench'; + + const makeConfig = (i: number): ParsedServerConfig => + ({ + type: 'stdio', + command: `cmd-${i}`, + args: [`arg-${i}`, `--flag-${i}`], + env: { KEY: `value-${i}`, EXTRA: `extra-${i}` }, + requiresOAuth: false, + tools: `tool_a_${i}, tool_b_${i}`, + capabilities: `{"tools":{"listChanged":true}}`, + serverInstructions: `Instructions for server ${i}`, + }) as ParsedServerConfig; + + beforeAll(async () => { + process.env.USE_REDIS = process.env.USE_REDIS ?? 
'true'; + process.env.USE_REDIS_CLUSTER = process.env.USE_REDIS_CLUSTER ?? 'true'; + process.env.REDIS_URI = + process.env.REDIS_URI ?? + 'redis://127.0.0.1:7001,redis://127.0.0.1:7002,redis://127.0.0.1:7003'; + process.env.REDIS_KEY_PREFIX = process.env.REDIS_KEY_PREFIX ?? 'perf-bench-test'; + + const cacheModule = await import('../ServerConfigsCacheRedis'); + const redisClients = await import('~/cache/redisClients'); + const cacheFactory = await import('~/cache'); + + ServerConfigsCacheRedis = cacheModule.ServerConfigsCacheRedis; + keyvRedisClient = redisClients.keyvRedisClient; + standardCache = cacheFactory.standardCache; + + if (!keyvRedisClient) throw new Error('Redis client is not initialized'); + await redisClients.keyvRedisClientReady; + }); + + afterAll(async () => { + if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect(); + }); + + /** Clean up all keys matching our test prefix */ + async function cleanupKeys(pattern: string): Promise { + if (!keyvRedisClient || !('scanIterator' in keyvRedisClient)) return; + const keys: string[] = []; + for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) { + keys.push(key); + } + if (keys.length > 0) { + await Promise.all(keys.map((key) => keyvRedisClient!.del(key))); + } + } + + /** Populate a cache with N configs and return the cache instance */ + async function populateCache( + namespace: string, + count: number, + ): Promise> { + const cache = new ServerConfigsCacheRedis(namespace, false); + for (let i = 0; i < count; i++) { + await cache.add(`server-${i}`, makeConfig(i)); + } + return cache; + } + + /** + * Benchmark 1: Isolate SCAN vs GET phases in current getAll() + * + * Measures time spent in each phase separately to identify the bottleneck. 
+ */ + describe('Phase isolation: SCAN vs batched GET', () => { + const CONFIG_COUNTS = [5, 20, 50]; + + for (const count of CONFIG_COUNTS) { + it(`should measure SCAN and GET phases separately for ${count} configs`, async () => { + const ns = `${PREFIX}-phase-${count}`; + const cache = await populateCache(ns, count); + + try { + // Get the Keyv cache instance namespace for pattern matching + const keyvCache = standardCache(`MCP::ServersRegistry::Servers::${ns}`); + const pattern = `*MCP::ServersRegistry::Servers::${ns}:*`; + + // Phase 1: SCAN only (key discovery) + const scanStart = Date.now(); + const keys: string[] = []; + for await (const key of (keyvRedisClient as RedisClientType).scanIterator({ + MATCH: pattern, + })) { + keys.push(key); + } + const scanMs = Date.now() - scanStart; + + // Phase 2: Batched GET only (value retrieval via Keyv) + const keyNames = keys.map((key) => key.substring(key.lastIndexOf(':') + 1)); + const BATCH_SIZE = 100; + const getStart = Date.now(); + for (let i = 0; i < keyNames.length; i += BATCH_SIZE) { + const batch = keyNames.slice(i, i + BATCH_SIZE); + await Promise.all(batch.map((k) => keyvCache.get(k))); + } + const getMs = Date.now() - getStart; + + // Phase 3: Full getAll() (both phases combined) + const fullStart = Date.now(); + const result = await cache.getAll(); + const fullMs = Date.now() - fullStart; + + console.log( + `[${count} configs] SCAN: ${scanMs}ms | GET: ${getMs}ms | Full getAll: ${fullMs}ms | Keys found: ${keys.length}`, + ); + + expect(Object.keys(result).length).toBe(count); + + // Clean up the Keyv instance + await keyvCache.clear(); + } finally { + await cleanupKeys(`*${ns}*`); + } + }); + } + }); + + /** + * Benchmark 2: SCAN cost scales with total Redis keyspace, not just matching keys + * + * Redis SCAN iterates the entire hash table and filters by pattern. With a large + * keyspace (many non-matching keys), SCAN takes longer even if few keys match. 
+ * This test measures SCAN time with background noise keys. + */ + describe('SCAN cost vs keyspace size', () => { + it('should measure SCAN latency with background noise keys', async () => { + const ns = `${PREFIX}-noise`; + const targetCount = 10; + + // Add target configs + const cache = await populateCache(ns, targetCount); + + // Add noise keys in a different namespace to inflate the keyspace + const noiseCount = 500; + const noiseCache = standardCache(`noise-namespace-${Date.now()}`); + for (let i = 0; i < noiseCount; i++) { + await noiseCache.set(`noise-${i}`, { data: `value-${i}` }); + } + + try { + const pattern = `*MCP::ServersRegistry::Servers::${ns}:*`; + + // Measure SCAN with noise + const scanStart = Date.now(); + const keys: string[] = []; + for await (const key of (keyvRedisClient as RedisClientType).scanIterator({ + MATCH: pattern, + })) { + keys.push(key); + } + const scanMs = Date.now() - scanStart; + + // Measure full getAll + const fullStart = Date.now(); + const result = await cache.getAll(); + const fullMs = Date.now() - fullStart; + + console.log( + `[${targetCount} configs + ${noiseCount} noise keys] SCAN: ${scanMs}ms | Full getAll: ${fullMs}ms`, + ); + + expect(Object.keys(result).length).toBe(targetCount); + } finally { + await noiseCache.clear(); + await cleanupKeys(`*${ns}*`); + } + }); + }); + + /** + * Benchmark 3: Concurrent getAll() calls (simulates the actual production bottleneck) + * + * Multiple users hitting /api/mcp/* simultaneously, all triggering getAll() + * after the 5s TTL read-through cache expires. 
+ */ + describe('Concurrent getAll() under load', () => { + const CONCURRENCY_LEVELS = [1, 10, 50, 100]; + const CONFIG_COUNT = 30; + + for (const concurrency of CONCURRENCY_LEVELS) { + it(`should measure ${concurrency} concurrent getAll() calls with ${CONFIG_COUNT} configs`, async () => { + const ns = `${PREFIX}-concurrent-${concurrency}`; + const cache = await populateCache(ns, CONFIG_COUNT); + + try { + const startTime = Date.now(); + const promises = Array.from({ length: concurrency }, () => cache.getAll()); + const results = await Promise.all(promises); + const elapsed = Date.now() - startTime; + + console.log( + `[${CONFIG_COUNT} configs x ${concurrency} concurrent] Total: ${elapsed}ms | Per-call avg: ${(elapsed / concurrency).toFixed(1)}ms`, + ); + + for (const result of results) { + expect(Object.keys(result).length).toBe(CONFIG_COUNT); + } + } finally { + await cleanupKeys(`*${ns}*`); + } + }); + } + }); + + /** + * Benchmark 4: Alternative — Single aggregate key + * + * Instead of SCAN+GET, store all configs under one Redis key. + * getAll() becomes a single GET + JSON parse. 
+ */ + describe('Alternative: Single aggregate key', () => { + it('should compare aggregate key vs SCAN+GET for getAll()', async () => { + const ns = `${PREFIX}-aggregate`; + const configCount = 30; + const cache = await populateCache(ns, configCount); + + // Build the aggregate object + const aggregate: Record = {}; + for (let i = 0; i < configCount; i++) { + aggregate[`server-${i}`] = makeConfig(i); + } + + // Store as single key + const aggregateCache = standardCache(`aggregate-test-${Date.now()}`); + await aggregateCache.set('all', aggregate); + + try { + // Measure SCAN+GET approach + const scanStart = Date.now(); + const scanResult = await cache.getAll(); + const scanMs = Date.now() - scanStart; + + // Measure single-key approach + const aggStart = Date.now(); + const aggResult = (await aggregateCache.get('all')) as Record; + const aggMs = Date.now() - aggStart; + + console.log( + `[${configCount} configs] SCAN+GET: ${scanMs}ms | Single key: ${aggMs}ms | Speedup: ${(scanMs / Math.max(aggMs, 1)).toFixed(1)}x`, + ); + + expect(Object.keys(scanResult).length).toBe(configCount); + expect(Object.keys(aggResult).length).toBe(configCount); + + // Concurrent comparison + const concurrency = 100; + const scanConcStart = Date.now(); + await Promise.all(Array.from({ length: concurrency }, () => cache.getAll())); + const scanConcMs = Date.now() - scanConcStart; + + const aggConcStart = Date.now(); + await Promise.all(Array.from({ length: concurrency }, () => aggregateCache.get('all'))); + const aggConcMs = Date.now() - aggConcStart; + + console.log( + `[${configCount} configs x ${concurrency} concurrent] SCAN+GET: ${scanConcMs}ms | Single key: ${aggConcMs}ms | Speedup: ${(scanConcMs / Math.max(aggConcMs, 1)).toFixed(1)}x`, + ); + } finally { + await aggregateCache.clear(); + await cleanupKeys(`*${ns}*`); + } + }); + }); + + /** + * Benchmark 5: Alternative — Raw MGET (bypassing Keyv serialization overhead) + * + * Keyv wraps each value in { value, expires } JSON. 
Using raw MGET on the + * Redis client skips the Keyv layer entirely. + */ + describe('Alternative: Raw MGET vs Keyv batch GET', () => { + it('should compare raw MGET vs Keyv GET for value retrieval', async () => { + const ns = `${PREFIX}-mget`; + const configCount = 30; + const cache = await populateCache(ns, configCount); + + try { + // First, discover keys via SCAN (same for both approaches) + const pattern = `*MCP::ServersRegistry::Servers::${ns}:*`; + const keys: string[] = []; + for await (const key of (keyvRedisClient as RedisClientType).scanIterator({ + MATCH: pattern, + })) { + keys.push(key); + } + + // Approach 1: Keyv batch GET (current implementation) + const keyvCache = standardCache(`MCP::ServersRegistry::Servers::${ns}`); + const keyNames = keys.map((key) => key.substring(key.lastIndexOf(':') + 1)); + + const keyvStart = Date.now(); + await Promise.all(keyNames.map((k) => keyvCache.get(k))); + const keyvMs = Date.now() - keyvStart; + + // Approach 2: Raw MGET (no Keyv overhead) + const mgetStart = Date.now(); + if ('mGet' in keyvRedisClient!) 
{ + const rawValues = await ( + keyvRedisClient as { mGet: (keys: string[]) => Promise<(string | null)[]> } + ).mGet(keys); + // Parse the Keyv-wrapped JSON values + rawValues.filter(Boolean).map((v) => JSON.parse(v!)); + } + const mgetMs = Date.now() - mgetStart; + + console.log( + `[${configCount} configs] Keyv batch GET: ${keyvMs}ms | Raw MGET: ${mgetMs}ms | Speedup: ${(keyvMs / Math.max(mgetMs, 1)).toFixed(1)}x`, + ); + + // Clean up + await keyvCache.clear(); + } finally { + await cleanupKeys(`*${ns}*`); + } + }); + }); +}); diff --git a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedisAggregateKey.cache_integration.spec.ts b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedisAggregateKey.cache_integration.spec.ts new file mode 100644 index 0000000000..4ec30187a2 --- /dev/null +++ b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedisAggregateKey.cache_integration.spec.ts @@ -0,0 +1,338 @@ +import { expect } from '@playwright/test'; +import type { ParsedServerConfig } from '~/mcp/types'; + +describe('ServerConfigsCacheRedisAggregateKey Integration Tests', () => { + let ServerConfigsCacheRedisAggregateKey: typeof import('../ServerConfigsCacheRedisAggregateKey').ServerConfigsCacheRedisAggregateKey; + let keyvRedisClient: Awaited['keyvRedisClient']; + + let cache: InstanceType< + typeof import('../ServerConfigsCacheRedisAggregateKey').ServerConfigsCacheRedisAggregateKey + >; + + const mockConfig1 = { + type: 'stdio', + command: 'node', + args: ['server1.js'], + env: { TEST: 'value1' }, + } as ParsedServerConfig; + + const mockConfig2 = { + type: 'stdio', + command: 'python', + args: ['server2.py'], + env: { TEST: 'value2' }, + } as ParsedServerConfig; + + const mockConfig3 = { + type: 'sse', + url: 'http://localhost:3000', + requiresOAuth: true, + } as ParsedServerConfig; + + beforeAll(async () => { + process.env.USE_REDIS = process.env.USE_REDIS ?? 
'true'; + process.env.USE_REDIS_CLUSTER = process.env.USE_REDIS_CLUSTER ?? 'true'; + process.env.REDIS_URI = + process.env.REDIS_URI ?? + 'redis://127.0.0.1:7001,redis://127.0.0.1:7002,redis://127.0.0.1:7003'; + process.env.REDIS_KEY_PREFIX = process.env.REDIS_KEY_PREFIX ?? 'AggregateKey-IntegrationTest'; + + const cacheModule = await import('../ServerConfigsCacheRedisAggregateKey'); + const redisClients = await import('~/cache/redisClients'); + + ServerConfigsCacheRedisAggregateKey = cacheModule.ServerConfigsCacheRedisAggregateKey; + keyvRedisClient = redisClients.keyvRedisClient; + + if (!keyvRedisClient) throw new Error('Redis client is not initialized'); + await redisClients.keyvRedisClientReady; + }); + + beforeEach(() => { + cache = new ServerConfigsCacheRedisAggregateKey('agg-test', false); + }); + + afterEach(async () => { + await cache.reset(); + }); + + afterAll(async () => { + if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect(); + }); + + describe('add and get operations', () => { + it('should add and retrieve a server config', async () => { + await cache.add('server1', mockConfig1); + const result = await cache.get('server1'); + expect(result).toMatchObject(mockConfig1); + }); + + it('should return undefined for non-existent server', async () => { + const result = await cache.get('non-existent'); + expect(result).toBeUndefined(); + }); + + it('should throw error when adding duplicate server', async () => { + await cache.add('server1', mockConfig1); + await expect(cache.add('server1', mockConfig2)).rejects.toThrow( + 'Server "server1" already exists in cache. 
Use update() to modify existing configs.', + ); + }); + + it('should handle multiple server configs', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + await cache.add('server3', mockConfig3); + + expect(await cache.get('server1')).toMatchObject(mockConfig1); + expect(await cache.get('server2')).toMatchObject(mockConfig2); + expect(await cache.get('server3')).toMatchObject(mockConfig3); + }); + }); + + describe('getAll operation', () => { + it('should return empty object when no servers exist', async () => { + const result = await cache.getAll(); + expect(result).toMatchObject({}); + }); + + it('should return all server configs', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + await cache.add('server3', mockConfig3); + + const result = await cache.getAll(); + expect(result).toMatchObject({ + server1: mockConfig1, + server2: mockConfig2, + server3: mockConfig3, + }); + }); + + it('should reflect additions in getAll', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + let result = await cache.getAll(); + expect(Object.keys(result).length).toBe(2); + + await cache.add('server3', mockConfig3); + result = await cache.getAll(); + expect(Object.keys(result).length).toBe(3); + expect(result.server3).toMatchObject(mockConfig3); + }); + }); + + describe('update operation', () => { + it('should update an existing server config', async () => { + await cache.add('server1', mockConfig1); + expect(await cache.get('server1')).toMatchObject(mockConfig1); + + await cache.update('server1', mockConfig2); + const result = await cache.get('server1'); + expect(result).toMatchObject(mockConfig2); + }); + + it('should throw error when updating non-existent server', async () => { + await expect(cache.update('non-existent', mockConfig1)).rejects.toThrow( + 'Server "non-existent" does not exist in cache. 
Use add() to create new configs.', + ); + }); + + it('should reflect updates in getAll', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + await cache.update('server1', mockConfig3); + const result = await cache.getAll(); + expect(result.server1).toMatchObject(mockConfig3); + expect(result.server2).toMatchObject(mockConfig2); + }); + }); + + describe('remove operation', () => { + it('should remove an existing server config', async () => { + await cache.add('server1', mockConfig1); + expect(await cache.get('server1')).toMatchObject(mockConfig1); + + await cache.remove('server1'); + expect(await cache.get('server1')).toBeUndefined(); + }); + + it('should throw error when removing non-existent server', async () => { + await expect(cache.remove('non-existent')).rejects.toThrow( + 'Failed to remove server "non-existent" in cache.', + ); + }); + + it('should remove server from getAll results', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + let result = await cache.getAll(); + expect(Object.keys(result).length).toBe(2); + + await cache.remove('server1'); + result = await cache.getAll(); + expect(Object.keys(result).length).toBe(1); + expect(result.server1).toBeUndefined(); + expect(result.server2).toMatchObject(mockConfig2); + }); + + it('should allow re-adding a removed server', async () => { + await cache.add('server1', mockConfig1); + await cache.remove('server1'); + await cache.add('server1', mockConfig3); + + const result = await cache.get('server1'); + expect(result).toMatchObject(mockConfig3); + }); + }); + + describe('concurrent write safety', () => { + it('should handle concurrent add calls without data loss', async () => { + const configCount = 20; + const promises = Array.from({ length: configCount }, (_, i) => + cache.add(`server-${i}`, { + type: 'stdio', + command: `cmd-${i}`, + args: [`arg-${i}`], + } as ParsedServerConfig), + ); + + const results = 
await Promise.allSettled(promises); + const failures = results.filter((r) => r.status === 'rejected'); + expect(failures).toHaveLength(0); + + const result = await cache.getAll(); + expect(Object.keys(result).length).toBe(configCount); + for (let i = 0; i < configCount; i++) { + expect(result[`server-${i}`]).toBeDefined(); + const config = result[`server-${i}`] as { command?: string }; + expect(config.command).toBe(`cmd-${i}`); + } + }); + + it('should handle concurrent getAll calls', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + await cache.add('server3', mockConfig3); + + const concurrency = 50; + const promises = Array.from({ length: concurrency }, () => cache.getAll()); + const results = await Promise.all(promises); + + for (const result of results) { + expect(Object.keys(result).length).toBe(3); + expect(result.server1).toMatchObject(mockConfig1); + expect(result.server2).toMatchObject(mockConfig2); + expect(result.server3).toMatchObject(mockConfig3); + } + }); + }); + + describe('reset operation', () => { + it('should clear all configs', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + expect(Object.keys(await cache.getAll()).length).toBe(2); + + await cache.reset(); + + const result = await cache.getAll(); + expect(Object.keys(result).length).toBe(0); + }); + }); + + describe('local snapshot behavior', () => { + it('should collapse repeated getAll calls into a single Redis GET within TTL', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + // Prime the snapshot + await cache.getAll(); + + // Spy on the underlying Keyv cache to count Redis calls + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const cacheGetSpy = jest.spyOn((cache as any).cache, 'get'); + + await cache.getAll(); + await cache.getAll(); + await cache.getAll(); + + // Snapshot should be served; Redis should NOT have 
been called + expect(cacheGetSpy.mock.calls).toHaveLength(0); + cacheGetSpy.mockRestore(); + }); + + it('should invalidate snapshot after add', async () => { + await cache.add('server1', mockConfig1); + const before = await cache.getAll(); + expect(Object.keys(before).length).toBe(1); + + await cache.add('server2', mockConfig2); + const after = await cache.getAll(); + expect(Object.keys(after).length).toBe(2); + }); + + it('should invalidate snapshot after update and preserve other entries', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + expect((await cache.getAll()).server1).toMatchObject(mockConfig1); + + await cache.update('server1', mockConfig3); + const after = await cache.getAll(); + expect(after.server1).toMatchObject(mockConfig3); + expect(after.server2).toMatchObject(mockConfig2); + }); + + it('should invalidate snapshot after remove', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + expect(Object.keys(await cache.getAll()).length).toBe(2); + + await cache.remove('server1'); + const after = await cache.getAll(); + expect(Object.keys(after).length).toBe(1); + expect(after.server1).toBeUndefined(); + expect(after.server2).toMatchObject(mockConfig2); + }); + + it('should invalidate snapshot after reset', async () => { + await cache.add('server1', mockConfig1); + expect(Object.keys(await cache.getAll()).length).toBe(1); + + await cache.reset(); + expect(Object.keys(await cache.getAll()).length).toBe(0); + }); + + it('should not retroactively modify previously returned snapshot references', async () => { + await cache.add('server1', mockConfig1); + + // Prime the snapshot + const snapshot = await cache.getAll(); + expect(Object.keys(snapshot).length).toBe(1); + + // Add a second server — the original snapshot reference should be unmodified + await cache.add('server2', mockConfig2); + expect(Object.keys(snapshot).length).toBe(1); + 
expect(snapshot.server2).toBeUndefined(); + }); + + it('should hit Redis again after snapshot TTL expires', async () => { + await cache.add('server1', mockConfig1); + await cache.getAll(); // prime snapshot + + // Force-expire the snapshot without sleeping + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (cache as any).localSnapshotExpiry = Date.now() - 1; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const cacheGetSpy = jest.spyOn((cache as any).cache, 'get'); + const result = await cache.getAll(); + expect(cacheGetSpy.mock.calls).toHaveLength(1); + expect(Object.keys(result).length).toBe(1); + cacheGetSpy.mockRestore(); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/db/ServerConfigsDB.ts b/packages/api/src/mcp/registry/db/ServerConfigsDB.ts index 9981f6b00b..b1649c66ca 100644 --- a/packages/api/src/mcp/registry/db/ServerConfigsDB.ts +++ b/packages/api/src/mcp/registry/db/ServerConfigsDB.ts @@ -220,6 +220,25 @@ export class ServerConfigsDB implements IServerConfigsRepositoryInterface { await this._dbMethods.updateMCPServer(serverName, { config: configToSave }); } + /** + * Atomic add-or-update. For DB-backed servers this delegates to update since + * DB servers are always created via the explicit add() flow with ACL setup. + * Config-source servers should use configCacheRepo, not dbConfigsRepo. + */ + public async upsert( + serverName: string, + config: ParsedServerConfig, + userId?: string, + ): Promise { + if (!userId) { + throw new Error( + `[ServerConfigsDB.upsert] User ID is required for DB-backed MCP server upsert of "${serverName}". ` + + 'Config-source servers should use configCacheRepo, not dbConfigsRepo.', + ); + } + return this.update(serverName, config, userId); + } + /** * Deletes an MCP server and removes all associated ACL entries. 
* @param serverName - The serverName of the server to remove @@ -411,6 +430,7 @@ export class ServerConfigsDB implements IServerConfigsRepositoryInterface { const config: ParsedServerConfig = { ...serverDBDoc.config, dbId: (serverDBDoc._id as Types.ObjectId).toString(), + source: 'user', updatedAt: serverDBDoc.updatedAt?.getTime(), }; return await this.decryptConfig(config); diff --git a/packages/api/src/mcp/types/index.ts b/packages/api/src/mcp/types/index.ts index 6cb5e02f0b..32c2787165 100644 --- a/packages/api/src/mcp/types/index.ts +++ b/packages/api/src/mcp/types/index.ts @@ -144,6 +144,14 @@ export type ImageFormatter = (item: ImageContent) => FormattedContent; export type FormattedToolResponse = FormattedContentResult; +/** + * Origin of an MCP server definition. + * - `'yaml'` — operator-defined in librechat.yaml, full trust, boot-time init + * - `'config'` — admin-defined via Config override, full trust, lazy init + * - `'user'` — user-provided via UI, sandboxed (restricted placeholder resolution) + */ +export type MCPServerSource = 'yaml' | 'config' | 'user'; + export type ParsedServerConfig = MCPOptions & { url?: string; requiresOAuth?: boolean; @@ -154,6 +162,8 @@ export type ParsedServerConfig = MCPOptions & { initDuration?: number; updatedAt?: number; dbId?: string; + /** Origin of this server definition — determines trust level and placeholder resolution */ + source?: MCPServerSource; /** True if access is only via agent (not directly shared with user) */ consumeOnly?: boolean; /** True when inspection failed at startup; the server is known but not fully initialized */ @@ -202,6 +212,8 @@ export interface ToolDiscoveryOptions { customUserVars?: Record; requestBody?: RequestBody; connectionTimeout?: number; + /** Pre-resolved config-source servers for tenant-scoped lookup */ + configServers?: Record; } export interface ToolDiscoveryResult { diff --git a/packages/api/src/mcp/utils.ts b/packages/api/src/mcp/utils.ts index ff367725fc..653a96d5bd 100644 
--- a/packages/api/src/mcp/utils.ts +++ b/packages/api/src/mcp/utils.ts @@ -8,6 +8,15 @@ export function hasCustomUserVars(config: Pick 0; } +/** + * Determines whether a server config is user-sourced (sandboxed placeholder resolution). + * When `source` is set, it is authoritative. When absent (pre-upgrade cached configs), + * falls back to the legacy `dbId` heuristic for backward compatibility. + */ +export function isUserSourced(config: Pick): boolean { + return config.source != null ? config.source === 'user' : !!config.dbId; +} + /** * Allowlist-based sanitization for API responses. Only explicitly listed fields are included; * new fields added to ParsedServerConfig are excluded by default until allowlisted here. @@ -31,6 +40,8 @@ export function redactServerSecrets(config: ParsedServerConfig): Partial` + * + * Guards against the caller passing a pre-wrapped name (one that already + * starts with the oauth prefix in its original, un-normalized form) to + * prevent double-wrapping. + */ +export function buildOAuthToolCallName(serverName: string): string { + const oauthPrefix = `oauth${Constants.mcp_delimiter}`; + if (serverName.startsWith(oauthPrefix)) { + return normalizeServerName(serverName); + } + return `${oauthPrefix}${normalizeServerName(serverName)}`; +} + /** * Sanitizes a URL by removing query parameters to prevent credential leakage in logs. 
* @param url - The URL to sanitize (string or URL object) diff --git a/packages/api/src/middleware/__tests__/tenant.spec.ts b/packages/api/src/middleware/__tests__/tenant.spec.ts new file mode 100644 index 0000000000..7451817941 --- /dev/null +++ b/packages/api/src/middleware/__tests__/tenant.spec.ts @@ -0,0 +1,101 @@ +import { getTenantId } from '@librechat/data-schemas'; +import type { Response, NextFunction } from 'express'; +import type { ServerRequest } from '~/types/http'; +// Import directly from source file — _resetTenantMiddlewareStrictCache is intentionally +// excluded from the public barrel export (index.ts). +import { tenantContextMiddleware, _resetTenantMiddlewareStrictCache } from '../tenant'; + +function mockReq(user?: Record): ServerRequest { + return { user } as unknown as ServerRequest; +} + +function mockRes(): Response { + const res = { + status: jest.fn().mockReturnThis(), + json: jest.fn().mockReturnThis(), + }; + return res as unknown as Response; +} + +/** Runs the middleware and returns a Promise that resolves when next() is called. 
*/ +function runMiddleware(req: ServerRequest, res: Response): Promise { + return new Promise((resolve) => { + const next: NextFunction = () => { + resolve(getTenantId()); + }; + tenantContextMiddleware(req, res, next); + }); +} + +describe('tenantContextMiddleware', () => { + afterEach(() => { + _resetTenantMiddlewareStrictCache(); + delete process.env.TENANT_ISOLATION_STRICT; + }); + + it('sets ALS tenant context for authenticated requests with tenantId', async () => { + const req = mockReq({ tenantId: 'tenant-x', role: 'user' }); + const res = mockRes(); + + const tenantId = await runMiddleware(req, res); + expect(tenantId).toBe('tenant-x'); + }); + + it('is a no-op for unauthenticated requests (no user)', async () => { + const req = mockReq(); + const res = mockRes(); + + const tenantId = await runMiddleware(req, res); + expect(tenantId).toBeUndefined(); + }); + + it('passes through without ALS when user has no tenantId in non-strict mode', async () => { + const req = mockReq({ role: 'user' }); + const res = mockRes(); + + const tenantId = await runMiddleware(req, res); + expect(tenantId).toBeUndefined(); + }); + + it('returns 403 when user has no tenantId in strict mode', () => { + process.env.TENANT_ISOLATION_STRICT = 'true'; + _resetTenantMiddlewareStrictCache(); + + const req = mockReq({ role: 'user' }); + const res = mockRes(); + const next: NextFunction = jest.fn(); + + tenantContextMiddleware(req, res, next); + + expect(res.status).toHaveBeenCalledWith(403); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ error: expect.stringContaining('Tenant context required') }), + ); + expect(next).not.toHaveBeenCalled(); + }); + + it('allows authenticated requests with tenantId in strict mode', async () => { + process.env.TENANT_ISOLATION_STRICT = 'true'; + _resetTenantMiddlewareStrictCache(); + + const req = mockReq({ tenantId: 'tenant-y', role: 'admin' }); + const res = mockRes(); + + const tenantId = await runMiddleware(req, res); + 
expect(tenantId).toBe('tenant-y'); + }); + + it('different requests get independent tenant contexts', async () => { + const runRequest = (tid: string) => { + const req = mockReq({ tenantId: tid, role: 'user' }); + const res = mockRes(); + return runMiddleware(req, res); + }; + + const results = await Promise.all([runRequest('tenant-1'), runRequest('tenant-2')]); + + expect(results).toHaveLength(2); + expect(results).toContain('tenant-1'); + expect(results).toContain('tenant-2'); + }); +}); diff --git a/packages/api/src/middleware/balance.ts b/packages/api/src/middleware/balance.ts index 8c6b149cdd..19719680ec 100644 --- a/packages/api/src/middleware/balance.ts +++ b/packages/api/src/middleware/balance.ts @@ -12,7 +12,11 @@ import type { BalanceUpdateFields } from '~/types'; import { getBalanceConfig } from '~/app/config'; export interface BalanceMiddlewareOptions { - getAppConfig: (options?: { role?: string; refresh?: boolean }) => Promise; + getAppConfig: (options?: { + role?: string; + tenantId?: string; + refresh?: boolean; + }) => Promise; findBalanceByUser: (userId: string) => Promise; upsertBalanceFields: (userId: string, fields: IBalanceUpdate) => Promise; } @@ -92,7 +96,10 @@ export function createSetBalanceConfig({ return async (req: ServerRequest, res: ServerResponse, next: NextFunction): Promise => { try { const user = req.user as IUser & { _id: string | ObjectId }; - const appConfig = await getAppConfig({ role: user?.role }); + const appConfig = await getAppConfig({ + role: user?.role, + tenantId: user?.tenantId, + }); const balanceConfig = getBalanceConfig(appConfig); if (!balanceConfig?.enabled) { return next(); diff --git a/packages/api/src/middleware/capabilities.ts b/packages/api/src/middleware/capabilities.ts index c06a90ac8e..a3f1fe9038 100644 --- a/packages/api/src/middleware/capabilities.ts +++ b/packages/api/src/middleware/capabilities.ts @@ -9,7 +9,7 @@ import { import type { PrincipalType } from 'librechat-data-provider'; import type { 
SystemCapability, ConfigSection } from '@librechat/data-schemas'; import type { NextFunction, Response } from 'express'; -import type { Types } from 'mongoose'; +import type { Types, ClientSession } from 'mongoose'; import type { ServerRequest } from '~/types/http'; interface ResolvedPrincipal { @@ -18,7 +18,10 @@ interface ResolvedPrincipal { } interface CapabilityDeps { - getUserPrincipals: (params: { userId: string; role: string }) => Promise; + getUserPrincipals: ( + params: { userId: string | Types.ObjectId; role?: string | null }, + session?: ClientSession, + ) => Promise; hasCapabilityForPrincipals: (params: { principals: ResolvedPrincipal[]; capability: SystemCapability; @@ -26,7 +29,7 @@ interface CapabilityDeps { }) => Promise; } -interface CapabilityUser { +export interface CapabilityUser { id: string; role: string; tenantId?: string; @@ -48,7 +51,7 @@ export type RequireCapabilityFn = ( export type HasConfigCapabilityFn = ( user: CapabilityUser, - section: ConfigSection, + section: ConfigSection | null, verb?: 'manage' | 'read', ) => Promise; @@ -138,11 +141,14 @@ export function generateCapabilityCheck(deps: CapabilityDeps): { */ async function hasConfigCapability( user: CapabilityUser, - section: ConfigSection, + section: ConfigSection | null, verb: 'manage' | 'read' = 'manage', ): Promise { const broadCap = verb === 'manage' ? 
SystemCapabilities.MANAGE_CONFIGS : SystemCapabilities.READ_CONFIGS; + if (section == null) { + return hasCapability(user, broadCap); + } if (await hasCapability(user, broadCap)) { return true; } diff --git a/packages/api/src/middleware/index.ts b/packages/api/src/middleware/index.ts index a56b8e4a3e..b91fee2999 100644 --- a/packages/api/src/middleware/index.ts +++ b/packages/api/src/middleware/index.ts @@ -5,5 +5,7 @@ export * from './notFound'; export * from './balance'; export * from './json'; export * from './capabilities'; +export { tenantContextMiddleware } from './tenant'; +export { preAuthTenantMiddleware } from './preAuthTenant'; export * from './concurrency'; export * from './checkBalance'; diff --git a/packages/api/src/middleware/preAuthTenant.spec.ts b/packages/api/src/middleware/preAuthTenant.spec.ts new file mode 100644 index 0000000000..669a43c84f --- /dev/null +++ b/packages/api/src/middleware/preAuthTenant.spec.ts @@ -0,0 +1,129 @@ +import { getTenantId, logger } from '@librechat/data-schemas'; +import { preAuthTenantMiddleware } from './preAuthTenant'; +import type { Request, Response, NextFunction } from 'express'; + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { + warn: jest.fn(), + error: jest.fn(), + info: jest.fn(), + debug: jest.fn(), + }, +})); + +describe('preAuthTenantMiddleware', () => { + let req: { headers: Record; ip?: string; path?: string }; + let res: Partial; + + beforeEach(() => { + jest.clearAllMocks(); + req = { headers: {} }; + res = {}; + }); + + it('calls next() without ALS context when no X-Tenant-Id header is present', () => { + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + }); + + it('calls next() without ALS context when X-Tenant-Id header is 
empty', () => { + req.headers = { 'x-tenant-id': '' }; + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + }); + + it('wraps downstream in ALS context when X-Tenant-Id header is present', () => { + req.headers = { 'x-tenant-id': 'acme-corp' }; + let capturedTenantId: string | undefined; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBe('acme-corp'); + }); + + it('ignores __SYSTEM__ sentinel and logs warning', () => { + req.headers = { 'x-tenant-id': '__SYSTEM__' }; + req.ip = '10.0.0.1'; + req.path = '/api/config'; + let capturedTenantId: string | undefined = 'should-be-overwritten'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('__SYSTEM__'), + expect.objectContaining({ ip: '10.0.0.1', path: '/api/config' }), + ); + }); + + it('ignores array-valued headers (Express can produce these)', () => { + req.headers = { 'x-tenant-id': ['a', 'b'] as unknown as string }; + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + }); + + it('ignores tenant IDs containing invalid characters and logs warning', () => { + req.headers = { 'x-tenant-id': 'tenant:injected' }; + req.ip = '192.168.1.1'; + req.path = '/api/auth/login'; + let capturedTenantId: string | undefined = 'sentinel'; + const 
capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('malformed'), + expect.objectContaining({ ip: '192.168.1.1', path: '/api/auth/login' }), + ); + }); + + it('trims whitespace from tenant ID header', () => { + req.headers = { 'x-tenant-id': ' acme-corp ' }; + let capturedTenantId: string | undefined; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBe('acme-corp'); + }); + + it('ignores tenant IDs exceeding max length and logs warning', () => { + req.headers = { 'x-tenant-id': 'a'.repeat(200) }; + req.ip = '192.168.1.1'; + req.path = '/api/share/abc'; + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('malformed'), + expect.objectContaining({ ip: '192.168.1.1', length: 200, path: '/api/share/abc' }), + ); + }); +}); diff --git a/packages/api/src/middleware/preAuthTenant.ts b/packages/api/src/middleware/preAuthTenant.ts new file mode 100644 index 0000000000..bab91f3a18 --- /dev/null +++ b/packages/api/src/middleware/preAuthTenant.ts @@ -0,0 +1,72 @@ +import { tenantStorage, logger, SYSTEM_TENANT_ID } from '@librechat/data-schemas'; +import type { Request, Response, NextFunction } from 'express'; + +/** + * Pre-authentication tenant context middleware for unauthenticated routes. 
+ * + * Reads the tenant identifier from the `X-Tenant-Id` request header and wraps + * downstream handlers in `tenantStorage.run()` so that Mongoose queries and + * config resolution run within the correct tenant scope. + * + * **Where to use**: Mount on routes that must be tenant-aware before + * authentication has occurred: + * - `GET /api/config` — login page needs tenant-specific config (social logins, registration) + * - `/api/auth/*` — login, register, password reset + * - `/oauth/*` — OAuth callback flows + * - `GET /api/share/:shareId` — public shared conversation links + * + * **How the header gets set**: The deployment's reverse proxy, auth gateway, + * or OpenID strategy sets `X-Tenant-Id` based on subdomain, path, or OIDC claim. + * This middleware does NOT resolve tenants from subdomains or tokens — that is + * the responsibility of the deployment layer. + * + * **Design**: Intentionally minimal. No subdomain parsing, no OIDC claim + * extraction, no YAML-driven strategy. Multi-tenant deployments can: + * 1. Set the header in the reverse proxy / ingress (simplest), + * 2. Replace this middleware's resolver logic entirely, or + * 3. Layer additional resolution on top (e.g., OpenID `tenant` claim → header). + * + * If no header is present, downstream runs without tenant ALS context (same as + * single-tenant mode). This preserves backward compatibility. 
+ */ +const MAX_TENANT_ID_LENGTH = 128; +const VALID_TENANT_ID = /^[-a-zA-Z0-9_.]+$/; + +export function preAuthTenantMiddleware(req: Request, res: Response, next: NextFunction): void { + const raw = req.headers['x-tenant-id']; + + if (!raw || typeof raw !== 'string') { + next(); + return; + } + + const tenantId = raw.trim(); + + if (!tenantId) { + next(); + return; + } + + if (tenantId === SYSTEM_TENANT_ID) { + logger.warn('[preAuthTenant] Rejected __SYSTEM__ sentinel in X-Tenant-Id header', { + ip: req.ip, + path: req.path, + }); + next(); + return; + } + + if (tenantId.length > MAX_TENANT_ID_LENGTH || !VALID_TENANT_ID.test(tenantId)) { + logger.warn('[preAuthTenant] Rejected malformed X-Tenant-Id header', { + ip: req.ip, + length: tenantId.length, + path: req.path, + }); + next(); + return; + } + + return void tenantStorage.run({ tenantId }, async () => { + next(); + }); +} diff --git a/packages/api/src/middleware/tenant.ts b/packages/api/src/middleware/tenant.ts new file mode 100644 index 0000000000..0b0e003991 --- /dev/null +++ b/packages/api/src/middleware/tenant.ts @@ -0,0 +1,70 @@ +import { isMainThread } from 'worker_threads'; +import { tenantStorage, logger } from '@librechat/data-schemas'; +import type { Response, NextFunction } from 'express'; +import type { ServerRequest } from '~/types/http'; + +let _checkedThread = false; + +let _strictMode: boolean | undefined; + +function isStrict(): boolean { + return (_strictMode ??= process.env.TENANT_ISOLATION_STRICT === 'true'); +} + +/** Resets the cached strict-mode flag. Exposed for test teardown only. */ +export function _resetTenantMiddlewareStrictCache(): void { + _strictMode = undefined; +} + +/** + * Express middleware that propagates the authenticated user's `tenantId` into + * the AsyncLocalStorage context used by the Mongoose tenant-isolation plugin. + * + * **Placement**: Chained automatically by `requireJwtAuth` after successful + * passport authentication (req.user is populated). 
Must NOT be registered at + * global `app.use()` scope — `req.user` is undefined at that stage. + * + * Behaviour: + * - Authenticated request with `tenantId` → wraps downstream in `tenantStorage.run({ tenantId })` + * - Authenticated request **without** `tenantId`: + * - Strict mode (`TENANT_ISOLATION_STRICT=true`) → responds 403 + * - Non-strict (default) → passes through without ALS context (backward compat) + * - Unauthenticated request → no-op (calls `next()` directly) + */ +export function tenantContextMiddleware( + req: ServerRequest, + res: Response, + next: NextFunction, +): void { + if (!_checkedThread) { + _checkedThread = true; + if (!isMainThread) { + logger.error( + '[tenantContextMiddleware] Running in a worker thread — ' + + 'ALS context will not propagate. This middleware must only run in the main Express process.', + ); + } + } + + const user = req.user as { tenantId?: string } | undefined; + + if (!user) { + next(); + return; + } + + const tenantId = user.tenantId; + + if (!tenantId) { + if (isStrict()) { + res.status(403).json({ error: 'Tenant context required in strict isolation mode' }); + return; + } + next(); + return; + } + + return void tenantStorage.run({ tenantId }, async () => { + next(); + }); +} diff --git a/packages/api/src/stream/GenerationJobManager.ts b/packages/api/src/stream/GenerationJobManager.ts index 3e04ab734b..5993c911ff 100644 --- a/packages/api/src/stream/GenerationJobManager.ts +++ b/packages/api/src/stream/GenerationJobManager.ts @@ -1,4 +1,4 @@ -import { logger } from '@librechat/data-schemas'; +import { logger, getTenantId, SYSTEM_TENANT_ID } from '@librechat/data-schemas'; import type { StandardGraph } from '@librechat/agents'; import { parseTextParts } from 'librechat-data-provider'; import type { Agents, TMessageContentParts } from 'librechat-data-provider'; @@ -197,7 +197,9 @@ class GenerationJobManagerClass { userId: string, conversationId?: string, ): Promise { - const jobData = await 
this.jobStore.createJob(streamId, userId, conversationId); + const tenantId = getTenantId(); + const safeTenantId = tenantId && tenantId !== SYSTEM_TENANT_ID ? tenantId : undefined; + const jobData = await this.jobStore.createJob(streamId, userId, conversationId, safeTenantId); /** * Create runtime state with readyPromise. @@ -355,6 +357,7 @@ class GenerationJobManagerClass { error: jobData.error, metadata: { userId: jobData.userId, + tenantId: jobData.tenantId, conversationId: jobData.conversationId, userMessage: jobData.userMessage, responseMessageId: jobData.responseMessageId, @@ -1255,8 +1258,8 @@ class GenerationJobManagerClass { * @param userId - The user ID to query * @returns Array of conversation IDs with active jobs */ - async getActiveJobIdsForUser(userId: string): Promise { - return this.jobStore.getActiveJobIdsByUser(userId); + async getActiveJobIdsForUser(userId: string, tenantId?: string): Promise { + return this.jobStore.getActiveJobIdsByUser(userId, tenantId); } /** diff --git a/packages/api/src/stream/implementations/InMemoryJobStore.ts b/packages/api/src/stream/implementations/InMemoryJobStore.ts index cc82a69963..7280c3ce80 100644 --- a/packages/api/src/stream/implementations/InMemoryJobStore.ts +++ b/packages/api/src/stream/implementations/InMemoryJobStore.ts @@ -70,6 +70,7 @@ export class InMemoryJobStore implements IJobStore { streamId: string, userId: string, conversationId?: string, + tenantId?: string, ): Promise { if (this.jobs.size >= this.maxJobs) { await this.evictOldest(); @@ -78,6 +79,7 @@ export class InMemoryJobStore implements IJobStore { const job: SerializableJobData = { streamId, userId, + ...(tenantId && { tenantId }), status: 'running', createdAt: Date.now(), conversationId, @@ -86,11 +88,12 @@ export class InMemoryJobStore implements IJobStore { this.jobs.set(streamId, job); - // Track job by userId for efficient user-scoped queries - let userJobs = this.userJobMap.get(userId); + // Track job by userId (tenant-qualified when 
available) for efficient user-scoped queries + const userKey = tenantId ? `${tenantId}:${userId}` : userId; + let userJobs = this.userJobMap.get(userKey); if (!userJobs) { userJobs = new Set(); - this.userJobMap.set(userId, userJobs); + this.userJobMap.set(userKey, userJobs); } userJobs.add(streamId); @@ -146,6 +149,17 @@ export class InMemoryJobStore implements IJobStore { } for (const id of toDelete) { + const job = this.jobs.get(id); + if (job) { + const userKey = job.tenantId ? `${job.tenantId}:${job.userId}` : job.userId; + const userJobs = this.userJobMap.get(userKey); + if (userJobs) { + userJobs.delete(id); + if (userJobs.size === 0) { + this.userJobMap.delete(userKey); + } + } + } await this.deleteJob(id); } @@ -169,6 +183,17 @@ export class InMemoryJobStore implements IJobStore { if (oldestId) { logger.warn(`[InMemoryJobStore] Evicting oldest job: ${oldestId}`); + const job = this.jobs.get(oldestId); + if (job) { + const userKey = job.tenantId ? `${job.tenantId}:${job.userId}` : job.userId; + const userJobs = this.userJobMap.get(userKey); + if (userJobs) { + userJobs.delete(oldestId); + if (userJobs.size === 0) { + this.userJobMap.delete(userKey); + } + } + } await this.deleteJob(oldestId); } } @@ -205,8 +230,9 @@ export class InMemoryJobStore implements IJobStore { * Returns conversation IDs of running jobs belonging to the user. * Also performs self-healing cleanup: removes stale entries for jobs that no longer exist. */ - async getActiveJobIdsByUser(userId: string): Promise { - const trackedIds = this.userJobMap.get(userId); + async getActiveJobIdsByUser(userId: string, tenantId?: string): Promise { + const userKey = tenantId ? 
`${tenantId}:${userId}` : userId; + const trackedIds = this.userJobMap.get(userKey); if (!trackedIds || trackedIds.size === 0) { return []; } @@ -226,7 +252,7 @@ export class InMemoryJobStore implements IJobStore { // Clean up empty set if (trackedIds.size === 0) { - this.userJobMap.delete(userId); + this.userJobMap.delete(userKey); } return activeIds; diff --git a/packages/api/src/stream/implementations/RedisJobStore.ts b/packages/api/src/stream/implementations/RedisJobStore.ts index 727fe066eb..a631bc2044 100644 --- a/packages/api/src/stream/implementations/RedisJobStore.ts +++ b/packages/api/src/stream/implementations/RedisJobStore.ts @@ -29,8 +29,9 @@ const KEYS = { runSteps: (streamId: string) => `stream:{${streamId}}:runsteps`, /** Running jobs set for cleanup (global set - single slot) */ runningJobs: 'stream:running', - /** User's active jobs set: stream:user:{userId}:jobs */ - userJobs: (userId: string) => `stream:user:{${userId}}:jobs`, + /** User's active jobs set, tenant-qualified when tenantId is available */ + userJobs: (userId: string, tenantId?: string) => + tenantId ? 
`stream:user:{${tenantId}:${userId}}:jobs` : `stream:user:{${userId}}:jobs`, }; /** @@ -140,10 +141,12 @@ export class RedisJobStore implements IJobStore { streamId: string, userId: string, conversationId?: string, + tenantId?: string, ): Promise { const job: SerializableJobData = { streamId, userId, + ...(tenantId && { tenantId }), status: 'running', createdAt: Date.now(), conversationId, @@ -151,7 +154,7 @@ export class RedisJobStore implements IJobStore { }; const key = KEYS.job(streamId); - const userJobsKey = KEYS.userJobs(userId); + const userJobsKey = KEYS.userJobs(userId, tenantId); // For cluster mode, we can't pipeline keys on different slots // The job key uses hash tag {streamId}, runningJobs and userJobs are on different slots @@ -377,8 +380,8 @@ export class RedisJobStore implements IJobStore { * @param userId - The user ID to query * @returns Array of conversation IDs with active jobs */ - async getActiveJobIdsByUser(userId: string): Promise { - const userJobsKey = KEYS.userJobs(userId); + async getActiveJobIdsByUser(userId: string, tenantId?: string): Promise { + const userJobsKey = KEYS.userJobs(userId, tenantId); const trackedIds = await this.redis.smembers(userJobsKey); if (trackedIds.length === 0) { @@ -868,6 +871,7 @@ export class RedisJobStore implements IJobStore { return { streamId: data.streamId, userId: data.userId, + tenantId: data.tenantId || undefined, status: data.status as JobStatus, createdAt: parseInt(data.createdAt, 10), completedAt: data.completedAt ? 
parseInt(data.completedAt, 10) : undefined, diff --git a/packages/api/src/stream/interfaces/IJobStore.ts b/packages/api/src/stream/interfaces/IJobStore.ts index fadddb840d..b59eed66f8 100644 --- a/packages/api/src/stream/interfaces/IJobStore.ts +++ b/packages/api/src/stream/interfaces/IJobStore.ts @@ -12,6 +12,7 @@ export type JobStatus = 'running' | 'complete' | 'error' | 'aborted'; export interface SerializableJobData { streamId: string; userId: string; + tenantId?: string; status: JobStatus; createdAt: number; completedAt?: number; @@ -149,6 +150,7 @@ export interface IJobStore { streamId: string, userId: string, conversationId?: string, + tenantId?: string, ): Promise<SerializableJobData>; /** Get a job by streamId (streamId === conversationId) */ @@ -186,7 +188,7 @@ * @param userId - The user ID to query * @returns Array of conversation IDs with active jobs */ - getActiveJobIdsByUser(userId: string): Promise<string[]>; + getActiveJobIdsByUser(userId: string, tenantId?: string): Promise<string[]>; // ===== Content State Methods ===== // These methods manage volatile content state tied to each job. 
diff --git a/packages/api/src/types/es2024-string.d.ts b/packages/api/src/types/es2024-string.d.ts new file mode 100644 index 0000000000..f25bc46bda --- /dev/null +++ b/packages/api/src/types/es2024-string.d.ts @@ -0,0 +1,4 @@ +/** String.prototype.isWellFormed — ES2024 API, available in Node 20+ but absent from TS 5.3 lib */ +interface String { + isWellFormed(): boolean; +} diff --git a/packages/api/src/types/stream.ts b/packages/api/src/types/stream.ts index 068d9c8db8..dd125a1aab 100644 --- a/packages/api/src/types/stream.ts +++ b/packages/api/src/types/stream.ts @@ -4,6 +4,7 @@ import type { ServerSentEvent } from '~/types'; export interface GenerationJobMetadata { userId: string; + tenantId?: string; conversationId?: string; /** User message data for rebuilding submission on reconnect */ userMessage?: Agents.UserMessageMeta; diff --git a/packages/api/src/utils/env.ts b/packages/api/src/utils/env.ts index adeeb24b34..f71a131c09 100644 --- a/packages/api/src/utils/env.ts +++ b/packages/api/src/utils/env.ts @@ -84,12 +84,12 @@ export function encodeHeaderValue(value: string): string { */ export function createSafeUser( user: IUser | null | undefined, -): Partial<IUser> & { federatedTokens?: unknown } { +): Partial<IUser> & { federatedTokens?: IUser['federatedTokens'] } { if (!user) { return {}; } - const safeUser: Partial<IUser> & { federatedTokens?: unknown } = {}; + const safeUser: Partial<IUser> & { federatedTokens?: IUser['federatedTokens'] } = {}; for (const field of ALLOWED_USER_FIELDS) { if (field in user) { safeUser[field] = user[field]; diff --git a/packages/api/src/utils/graph.spec.ts b/packages/api/src/utils/graph.spec.ts index 4f1fa14983..91f8a29eff 100644 --- a/packages/api/src/utils/graph.spec.ts +++ b/packages/api/src/utils/graph.spec.ts @@ 
-94,9 +94,9 @@ describe('Graph Token Utilities', () => { }); it('should return false for non-object values', () => { - expect(recordContainsGraphTokenPlaceholder('string' as unknown as Record)).toBe( - false, - ); + expect( + recordContainsGraphTokenPlaceholder('string' as unknown as Record), + ).toBe(false); }); }); @@ -141,7 +141,7 @@ describe('Graph Token Utilities', () => { }); describe('resolveGraphTokenPlaceholder', () => { - const mockUser: Partial = { + const mockUser: Partial = { id: 'user-123', provider: 'openid', openidId: 'oidc-sub-456', @@ -157,7 +157,7 @@ describe('Graph Token Utilities', () => { it('should return original value when no placeholder is present', async () => { const value = 'Bearer static-token'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, }); expect(result).toBe('Bearer static-token'); @@ -174,7 +174,7 @@ describe('Graph Token Utilities', () => { it('should return original value when graphTokenResolver is not provided', async () => { const value = 'Bearer {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, }); expect(result).toBe(value); }); @@ -184,7 +184,7 @@ describe('Graph Token Utilities', () => { const value = 'Bearer {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, }); expect(result).toBe(value); @@ -196,7 +196,7 @@ describe('Graph Token Utilities', () => { const value = 'Bearer {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, }); expect(result).toBe(value); @@ -208,7 +208,7 @@ describe('Graph Token 
Utilities', () => { const value = 'Bearer {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, }); expect(result).toBe(value); @@ -220,7 +220,7 @@ describe('Graph Token Utilities', () => { const value = 'Bearer {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, }); expect(result).toBe('Bearer resolved-graph-token'); @@ -233,7 +233,7 @@ describe('Graph Token Utilities', () => { const value = 'Primary: {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}, Secondary: {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, }); expect(result).toBe('Primary: resolved-graph-token, Secondary: resolved-graph-token'); @@ -242,11 +242,13 @@ describe('Graph Token Utilities', () => { it('should return original value when graph token exchange fails', async () => { mockExtractOpenIDTokenInfo.mockReturnValue({ accessToken: 'access-token' }); mockIsOpenIDTokenValid.mockReturnValue(true); - const failingResolver: GraphTokenResolver = jest.fn().mockRejectedValue(new Error('Exchange failed')); + const failingResolver: GraphTokenResolver = jest + .fn() + .mockRejectedValue(new Error('Exchange failed')); const value = 'Bearer {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: failingResolver, }); expect(result).toBe(value); @@ -259,7 +261,7 @@ describe('Graph Token Utilities', () => { const value = 'Bearer {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: 
mockUser as Partial as IUser, graphTokenResolver: emptyResolver, }); expect(result).toBe(value); @@ -271,7 +273,7 @@ describe('Graph Token Utilities', () => { const value = 'Bearer {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, scopes: 'custom-scope', }); @@ -286,7 +288,7 @@ describe('Graph Token Utilities', () => { }); describe('resolveGraphTokensInRecord', () => { - const mockUser: Partial = { + const mockUser: Partial = { id: 'user-123', provider: 'openid', }; @@ -299,7 +301,7 @@ describe('Graph Token Utilities', () => { }); const options: GraphTokenOptions = { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, }; @@ -348,7 +350,7 @@ describe('Graph Token Utilities', () => { }); describe('preProcessGraphTokens', () => { - const mockUser: Partial = { + const mockUser: Partial = { id: 'user-123', provider: 'openid', }; @@ -361,7 +363,7 @@ describe('Graph Token Utilities', () => { }); const graphOptions: GraphTokenOptions = { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, }; diff --git a/packages/api/src/utils/oidc.spec.ts b/packages/api/src/utils/oidc.spec.ts index 0d7216304b..e7088d9897 100644 --- a/packages/api/src/utils/oidc.spec.ts +++ b/packages/api/src/utils/oidc.spec.ts @@ -1,10 +1,10 @@ import { extractOpenIDTokenInfo, isOpenIDTokenValid, processOpenIDPlaceholders } from './oidc'; -import type { TUser } from 'librechat-data-provider'; +import type { IUser } from '@librechat/data-schemas'; describe('OpenID Token Utilities', () => { describe('extractOpenIDTokenInfo', () => { it('should extract token info from user with federatedTokens', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'openid', openidId: 'oidc-sub-456', @@ -36,7 +36,7 @@ describe('OpenID Token 
Utilities', () => { }); it('should return null when user is not OpenID provider', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'email', }; @@ -46,7 +46,7 @@ describe('OpenID Token Utilities', () => { }); it('should return token info when user has no federatedTokens but is OpenID provider', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'openid', openidId: 'oidc-sub-456', @@ -66,7 +66,7 @@ describe('OpenID Token Utilities', () => { }); it('should extract partial token info when some tokens are missing', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'openid', openidId: 'oidc-sub-456', @@ -89,7 +89,7 @@ describe('OpenID Token Utilities', () => { }); it('should prioritize openidId over regular id', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'openid', openidId: 'oidc-sub-456', @@ -104,7 +104,7 @@ describe('OpenID Token Utilities', () => { }); it('should fall back to regular id when openidId is not available', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'openid', federatedTokens: { @@ -397,7 +397,7 @@ describe('OpenID Token Utilities', () => { describe('Integration: Full OpenID Token Flow', () => { it('should extract, validate, and process tokens correctly', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'openid', openidId: 'oidc-sub-456', @@ -428,7 +428,7 @@ describe('OpenID Token Utilities', () => { }); it('should resolve LIBRECHAT_OPENID_ID_TOKEN and LIBRECHAT_OPENID_ACCESS_TOKEN to different values', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'openid', openidId: 'oidc-sub-456', @@ -457,7 +457,7 @@ describe('OpenID Token Utilities', () => { }); it('should handle expired tokens correctly', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'openid', 
openidId: 'oidc-sub-456', @@ -481,7 +481,7 @@ describe('OpenID Token Utilities', () => { }); it('should handle user with no federatedTokens but still has OpenID provider', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'openid', openidId: 'oidc-sub-456', @@ -499,7 +499,7 @@ describe('OpenID Token Utilities', () => { }); it('should handle non-OpenID users', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'email', }; diff --git a/packages/api/src/utils/oidc.ts b/packages/api/src/utils/oidc.ts index dbf41818c4..51056406c1 100644 --- a/packages/api/src/utils/oidc.ts +++ b/packages/api/src/utils/oidc.ts @@ -1,5 +1,5 @@ import { logger } from '@librechat/data-schemas'; -import type { IUser } from '@librechat/data-schemas'; +import type { IUser, OIDCTokens } from '@librechat/data-schemas'; export interface OpenIDTokenInfo { accessToken?: string; @@ -11,14 +11,7 @@ export interface OpenIDTokenInfo { claims?: Record; } -interface FederatedTokens { - access_token?: string; - id_token?: string; - refresh_token?: string; - expires_at?: number; -} - -function isFederatedTokens(obj: unknown): obj is FederatedTokens { +function isFederatedTokens(obj: unknown): obj is OIDCTokens { if (!obj || typeof obj !== 'object') { return false; } @@ -61,23 +54,24 @@ export function extractOpenIDTokenInfo( const tokenInfo: OpenIDTokenInfo = {}; - if ('federatedTokens' in user && isFederatedTokens(user.federatedTokens)) { - const tokens = user.federatedTokens; + const federated = user.federatedTokens; + const openid = user.openidTokens; + + if (federated && isFederatedTokens(federated)) { logger.debug('[extractOpenIDTokenInfo] Found federatedTokens:', { - has_access_token: !!tokens.access_token, - has_id_token: !!tokens.id_token, - has_refresh_token: !!tokens.refresh_token, - expires_at: tokens.expires_at, + has_access_token: !!federated.access_token, + has_id_token: !!federated.id_token, + has_refresh_token: 
!!federated.refresh_token, + expires_at: federated.expires_at, }); - tokenInfo.accessToken = tokens.access_token; - tokenInfo.idToken = tokens.id_token; - tokenInfo.expiresAt = tokens.expires_at; - } else if ('openidTokens' in user && isFederatedTokens(user.openidTokens)) { - const tokens = user.openidTokens; + tokenInfo.accessToken = federated.access_token; + tokenInfo.idToken = federated.id_token; + tokenInfo.expiresAt = federated.expires_at; + } else if (openid && isFederatedTokens(openid)) { logger.debug('[extractOpenIDTokenInfo] Found openidTokens'); - tokenInfo.accessToken = tokens.access_token; - tokenInfo.idToken = tokens.id_token; - tokenInfo.expiresAt = tokens.expires_at; + tokenInfo.accessToken = openid.access_token; + tokenInfo.idToken = openid.id_token; + tokenInfo.expiresAt = openid.expires_at; } tokenInfo.userId = user.openidId || user.id; diff --git a/packages/api/src/utils/tokens.ts b/packages/api/src/utils/tokens.ts index ae09da4f28..14215698a6 100644 --- a/packages/api/src/utils/tokens.ts +++ b/packages/api/src/utils/tokens.ts @@ -72,6 +72,7 @@ const mistralModels = { 'mistral-large-2402': 127500, 'mistral-large-2407': 127500, 'mistral-large': 131000, + 'mistral-large-3': 255000, 'mistral-saba': 32000, 'ministral-3b': 131000, 'ministral-8b': 131000, diff --git a/packages/api/types/index.d.ts b/packages/api/types/index.d.ts new file mode 100644 index 0000000000..f25bc46bda --- /dev/null +++ b/packages/api/types/index.d.ts @@ -0,0 +1,4 @@ +/** String.prototype.isWellFormed — ES2024 API, available in Node 20+ but absent from TS 5.3 lib */ +interface String { + isWellFormed(): boolean; +} diff --git a/packages/client/src/components/OGDialogTemplate.tsx b/packages/client/src/components/OGDialogTemplate.tsx index 300ae5b194..2414915a4b 100644 --- a/packages/client/src/components/OGDialogTemplate.tsx +++ b/packages/client/src/components/OGDialogTemplate.tsx @@ -80,9 +80,8 @@ const OGDialogTemplate = forwardRef((props: DialogTemplateProps, ref: Ref; + 
file?: Partial & { progress?: number }; fileType: { fill: string; paths: React.FC; diff --git a/packages/data-provider/package.json b/packages/data-provider/package.json index a66f4eec4e..0cbe9258f2 100644 --- a/packages/data-provider/package.json +++ b/packages/data-provider/package.json @@ -1,6 +1,6 @@ { "name": "librechat-data-provider", - "version": "0.8.401", + "version": "0.8.405", "description": "data services for librechat apps", "main": "dist/index.js", "module": "dist/index.es.js", @@ -18,6 +18,9 @@ "types": "./dist/types/react-query/index.d.ts" } }, + "files": [ + "dist" + ], "scripts": { "clean": "rimraf dist", "build": "npm run clean && rollup -c --silent --bundleConfigAsCjs", diff --git a/packages/data-provider/specs/config-schemas.spec.ts b/packages/data-provider/specs/config-schemas.spec.ts new file mode 100644 index 0000000000..fabd35cec9 --- /dev/null +++ b/packages/data-provider/specs/config-schemas.spec.ts @@ -0,0 +1,254 @@ +import { + endpointSchema, + paramDefinitionSchema, + agentsEndpointSchema, + azureEndpointSchema, +} from '../src/config'; +import { tModelSpecPresetSchema, EModelEndpoint } from '../src/schemas'; + +describe('paramDefinitionSchema', () => { + it('accepts a minimal definition with only key', () => { + const result = paramDefinitionSchema.safeParse({ key: 'temperature' }); + expect(result.success).toBe(true); + }); + + it('accepts a full definition with all fields', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'temperature', + type: 'number', + component: 'slider', + default: 0.7, + label: 'Temperature', + range: { min: 0, max: 2, step: 0.01 }, + columns: 2, + columnSpan: 1, + includeInput: true, + descriptionSide: 'right', + }); + expect(result.success).toBe(true); + }); + + it('rejects columns > 4', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'test', + columns: 5, + }); + expect(result.success).toBe(false); + }); + + it('rejects columns < 1', () => { + const result = 
paramDefinitionSchema.safeParse({ + key: 'test', + columns: 0, + }); + expect(result.success).toBe(false); + }); + + it('rejects non-integer columns', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'test', + columns: 2.5, + }); + expect(result.success).toBe(false); + }); + + it('rejects non-integer columnSpan', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'test', + columnSpan: 1.5, + }); + expect(result.success).toBe(false); + }); + + it('rejects negative minTags', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'test', + minTags: -1, + }); + expect(result.success).toBe(false); + }); + + it('rejects invalid descriptionSide', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'test', + descriptionSide: 'diagonal', + }); + expect(result.success).toBe(false); + }); + + it('rejects invalid type enum value', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'test', + type: 'invalid', + }); + expect(result.success).toBe(false); + }); + + it('rejects invalid component enum value', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'test', + component: 'wheel', + }); + expect(result.success).toBe(false); + }); + + it('allows type and component to be omitted (merged from defaults at runtime)', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'temperature', + range: { min: 0, max: 2, step: 0.01 }, + }); + expect(result.success).toBe(true); + expect(result.data).not.toHaveProperty('type'); + expect(result.data).not.toHaveProperty('component'); + }); +}); + +describe('tModelSpecPresetSchema', () => { + it('strips system/DB fields from preset', () => { + const result = tModelSpecPresetSchema.safeParse({ + conversationId: 'conv-123', + presetId: 'preset-456', + title: 'My Preset', + defaultPreset: true, + order: 3, + isArchived: true, + user: 'user123', + messages: ['msg1'], + tags: ['tag1'], + file_ids: ['file1'], + expiredAt: '2026-12-31', + parentMessageId: 
'parent1', + model: 'gpt-4o', + endpoint: EModelEndpoint.openAI, + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data).not.toHaveProperty('conversationId'); + expect(result.data).not.toHaveProperty('presetId'); + expect(result.data).not.toHaveProperty('title'); + expect(result.data).not.toHaveProperty('defaultPreset'); + expect(result.data).not.toHaveProperty('order'); + expect(result.data).not.toHaveProperty('isArchived'); + expect(result.data).not.toHaveProperty('user'); + expect(result.data).not.toHaveProperty('messages'); + expect(result.data).not.toHaveProperty('tags'); + expect(result.data).not.toHaveProperty('file_ids'); + expect(result.data).not.toHaveProperty('expiredAt'); + expect(result.data).not.toHaveProperty('parentMessageId'); + expect(result.data).toHaveProperty('model', 'gpt-4o'); + } + }); + + it('strips deprecated fields', () => { + const result = tModelSpecPresetSchema.safeParse({ + resendImages: true, + chatGptLabel: 'old-label', + model: 'gpt-4o', + endpoint: EModelEndpoint.openAI, + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data).not.toHaveProperty('resendImages'); + expect(result.data).not.toHaveProperty('chatGptLabel'); + } + }); + + it('strips frontend-only fields', () => { + const result = tModelSpecPresetSchema.safeParse({ + greeting: 'Hello!', + iconURL: 'https://example.com/icon.png', + spec: 'some-spec', + presetOverride: { model: 'other' }, + model: 'gpt-4o', + endpoint: EModelEndpoint.openAI, + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data).not.toHaveProperty('greeting'); + expect(result.data).not.toHaveProperty('iconURL'); + expect(result.data).not.toHaveProperty('spec'); + expect(result.data).not.toHaveProperty('presetOverride'); + } + }); + + it('preserves valid preset fields', () => { + const result = tModelSpecPresetSchema.safeParse({ + model: 'gpt-4o', + endpoint: EModelEndpoint.openAI, + temperature: 0.7, + topP: 
0.9, + maxOutputTokens: 4096, + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.model).toBe('gpt-4o'); + expect(result.data.temperature).toBe(0.7); + expect(result.data.topP).toBe(0.9); + expect(result.data.maxOutputTokens).toBe(4096); + } + }); +}); + +describe('endpointSchema deprecated fields', () => { + const validEndpoint = { + name: 'CustomEndpoint', + apiKey: 'test-key', + baseURL: 'https://api.example.com', + models: { default: ['model-1'] }, + }; + + it('silently strips deprecated summarize field', () => { + const result = endpointSchema.safeParse({ + ...validEndpoint, + summarize: true, + summaryModel: 'gpt-4o', + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data).not.toHaveProperty('summarize'); + expect(result.data).not.toHaveProperty('summaryModel'); + } + }); + + it('silently strips deprecated customOrder field', () => { + const result = endpointSchema.safeParse({ + ...validEndpoint, + customOrder: 5, + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data).not.toHaveProperty('customOrder'); + } + }); +}); + +describe('agentsEndpointSchema', () => { + it('does not accept baseURL', () => { + const result = agentsEndpointSchema.safeParse({ + baseURL: 'https://example.com', + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data).not.toHaveProperty('baseURL'); + } + }); +}); + +describe('azureEndpointSchema', () => { + it('silently strips plugins field', () => { + const result = azureEndpointSchema.safeParse({ + groups: [ + { + group: 'test-group', + apiKey: 'test-key', + models: { 'gpt-4': true }, + }, + ], + plugins: true, + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data).not.toHaveProperty('plugins'); + } + }); +}); diff --git a/packages/data-provider/specs/filetypes.spec.ts b/packages/data-provider/specs/filetypes.spec.ts index 39711dadd9..dba6cd4795 100644 --- 
a/packages/data-provider/specs/filetypes.spec.ts +++ b/packages/data-provider/specs/filetypes.spec.ts @@ -8,7 +8,6 @@ import { retrievalMimeTypes, excelFileTypes, excelMimeTypes, - fileConfigSchema, mergeFileConfig, mbToBytes, } from '../src/file-config'; @@ -126,8 +125,6 @@ describe('mergeFileConfig', () => { test('merges minimal update correctly', () => { const result = mergeFileConfig(dynamicConfigs.minimalUpdate); expect(result.serverFileSizeLimit).toEqual(mbToBytes(1024)); - const parsedResult = fileConfigSchema.safeParse(result); - expect(parsedResult.success).toBeTruthy(); }); test('overrides default endpoint with full new configuration', () => { @@ -136,8 +133,6 @@ describe('mergeFileConfig', () => { expect(result.endpoints.default.supportedMimeTypes).toEqual( expect.arrayContaining([new RegExp('^video/.*$')]), ); - const parsedResult = fileConfigSchema.safeParse(result); - expect(parsedResult.success).toBeTruthy(); }); test('adds new endpoint configuration correctly', () => { @@ -147,8 +142,6 @@ describe('mergeFileConfig', () => { expect(result.endpoints.newEndpoint.supportedMimeTypes).toEqual( expect.arrayContaining([new RegExp('^application/json$')]), ); - const parsedResult = fileConfigSchema.safeParse(result); - expect(parsedResult.success).toBeTruthy(); }); test('disables an endpoint and sets numeric fields to 0 and empties supportedMimeTypes', () => { diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index e641d7b63a..ae3f5b9560 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -2,6 +2,7 @@ import { z } from 'zod'; import type { ZodError } from 'zod'; import type { TEndpointsConfig, TModelsConfig, TConfig } from './types'; import { EModelEndpoint, eModelEndpointSchema, isAgentsEndpoint } from './schemas'; +import { ComponentTypes, SettingTypes, OptionTypes } from './generate'; import { specsConfigSchema, TSpecsConfig } from './models'; import { fileConfigSchema } 
from './file-config'; import { apiBaseUrl } from './api-endpoints'; @@ -120,11 +121,11 @@ export const azureBaseSchema = z.object({ instanceName: z.string().optional(), deploymentName: z.string().optional(), assistants: z.boolean().optional(), - addParams: z.record(z.any()).optional(), + addParams: z.record(z.union([z.string(), z.number(), z.boolean(), z.null()])).optional(), dropParams: z.array(z.string()).optional(), version: z.string().optional(), baseURL: z.string().optional(), - additionalHeaders: z.record(z.any()).optional(), + additionalHeaders: z.record(z.string()).optional(), }); export type TAzureBaseSchema = z.infer<typeof azureBaseSchema>; @@ -257,7 +258,7 @@ export const assistantEndpointSchema = baseEndpointSchema.merge( userIdQuery: z.boolean().optional(), }) .optional(), - headers: z.record(z.any()).optional(), + headers: z.record(z.string()).optional(), }), ); @@ -279,6 +280,7 @@ export const defaultAgentCapabilities = [ ]; export const agentsEndpointSchema = baseEndpointSchema + .omit({ baseURL: true }) .merge( z.object({ /* agents specific */ @@ -305,6 +307,43 @@ export const agentsEndpointSchema = baseEndpointSchema export type TAgentsEndpoint = z.infer<typeof agentsEndpointSchema>; +export const paramDefinitionSchema = z.object({ + key: z.string(), + description: z.string().optional(), + type: z.nativeEnum(SettingTypes).optional(), + default: z.union([z.number(), z.boolean(), z.string(), z.array(z.string())]).optional(), + showLabel: z.boolean().optional(), + showDefault: z.boolean().optional(), + options: z.array(z.string()).optional(), + range: z + .object({ + min: z.number(), + max: z.number(), + step: z.number().optional(), + }) + .optional(), + enumMappings: z.record(z.union([z.number(), z.boolean(), z.string()])).optional(), + component: z.nativeEnum(ComponentTypes).optional(), + optionType: z.nativeEnum(OptionTypes).optional(), + columnSpan: z.number().int().nonnegative().optional(), + columns: z.number().int().min(1).max(4).optional(), + label: z.string().optional(), + placeholder: 
z.string().optional(), + labelCode: z.boolean().optional(), + placeholderCode: z.boolean().optional(), + descriptionCode: z.boolean().optional(), + minText: z.number().optional(), + maxText: z.number().optional(), + minTags: z.number().min(0).optional(), + maxTags: z.number().min(0).optional(), + includeInput: z.boolean().optional(), + descriptionSide: z.enum(['top', 'right', 'bottom', 'left']).optional(), + searchPlaceholder: z.string().optional(), + selectPlaceholder: z.string().optional(), + searchPlaceholderCode: z.boolean().optional(), + selectPlaceholderCode: z.boolean().optional(), +}); + export const endpointSchema = baseEndpointSchema.merge( z.object({ name: z.string().refine((value) => !eModelEndpointSchema.safeParse(value).success, { @@ -319,23 +358,20 @@ export const endpointSchema = baseEndpointSchema.merge( fetch: z.boolean().optional(), userIdQuery: z.boolean().optional(), }), - summarize: z.boolean().optional(), - summaryModel: z.string().optional(), iconURL: z.string().optional(), modelDisplayLabel: z.string().optional(), - headers: z.record(z.any()).optional(), - addParams: z.record(z.any()).optional(), + headers: z.record(z.string()).optional(), + addParams: z.record(z.union([z.string(), z.number(), z.boolean(), z.null()])).optional(), dropParams: z.array(z.string()).optional(), customParams: z .object({ defaultParamsEndpoint: z.string().default('custom'), - paramDefinitions: z.array(z.record(z.any())).optional(), + paramDefinitions: z.array(paramDefinitionSchema).optional(), }) .strict() .optional(), - customOrder: z.number().optional(), directEndpoint: z.boolean().optional(), - titleMessageRole: z.string().optional(), + titleMessageRole: z.enum(['system', 'user', 'assistant']).optional(), }), ); @@ -344,7 +380,6 @@ export type TEndpoint = z.infer; export const azureEndpointSchema = z .object({ groups: azureGroupConfigsSchema, - plugins: z.boolean().optional(), assistants: z.boolean().optional(), }) .and( @@ -356,9 +391,6 @@ export const 
azureEndpointSchema = z titleModel: true, titlePrompt: true, titlePromptTemplate: true, - summarize: true, - summaryModel: true, - customOrder: true, }) .partial(), ); @@ -501,7 +533,8 @@ const speechTab = z .optional() .or( z.object({ - engineSTT: z.string().optional(), + /** Keep in sync with STTProviders enum (defined below — cannot reference due to eval order) */ + engineSTT: z.enum(['openai', 'azureOpenAI']).optional(), languageSTT: z.string().optional(), autoTranscribeAudio: z.boolean().optional(), decibelValue: z.number().optional(), @@ -514,11 +547,12 @@ const speechTab = z .optional() .or( z.object({ - engineTTS: z.string().optional(), + /** Keep in sync with TTSProviders enum (defined below — cannot reference due to eval order) */ + engineTTS: z.enum(['openai', 'azureOpenAI', 'elevenlabs', 'localai']).optional(), voice: z.string().optional(), languageTTS: z.string().optional(), automaticPlayback: z.boolean().optional(), - playbackRate: z.number().optional(), + playbackRate: z.number().min(0.25).max(4).optional(), cacheTTS: z.boolean().optional(), }), ) @@ -864,7 +898,7 @@ export const webSearchSchema = z.object({ searchProvider: z.nativeEnum(SearchProviders).optional(), scraperProvider: z.nativeEnum(ScraperProviders).optional(), rerankerType: z.nativeEnum(RerankerTypes).optional(), - scraperTimeout: z.number().optional(), + scraperTimeout: z.number().int().nonnegative().optional(), safeSearch: z.nativeEnum(SafeSearchTypes).default(SafeSearchTypes.MODERATE), firecrawlOptions: z .object({ @@ -873,7 +907,7 @@ export const webSearchSchema = z.object({ excludeTags: z.array(z.string()).optional(), headers: z.record(z.string()).optional(), waitFor: z.number().optional(), - timeout: z.number().optional(), + timeout: z.number().int().nonnegative().optional(), maxAge: z.number().optional(), mobile: z.boolean().optional(), skipTlsVerification: z.boolean().optional(), @@ -942,7 +976,7 @@ export const memorySchema = z.object({ provider: z.string(), model: z.string(), 
instructions: z.string().optional(), - model_parameters: z.record(z.any()).optional(), + model_parameters: z.record(z.union([z.string(), z.number(), z.boolean()])).optional(), }), ]) .optional(), @@ -1026,7 +1060,7 @@ export const configSchema = z.object({ modelSpecs: specsConfigSchema.optional(), endpoints: z .object({ - all: baseEndpointSchema.optional(), + all: baseEndpointSchema.omit({ baseURL: true }).optional(), [EModelEndpoint.openAI]: baseEndpointSchema.optional(), [EModelEndpoint.google]: baseEndpointSchema.optional(), [EModelEndpoint.anthropic]: anthropicEndpointSchema.optional(), @@ -1238,9 +1272,6 @@ export const defaultModels = { 'gemini-2.5-pro', 'gemini-2.5-flash', 'gemini-2.5-flash-lite', - // Gemini 2.0 Models - 'gemini-2.0-flash-001', - 'gemini-2.0-flash-lite', ], [EModelEndpoint.anthropic]: sharedAnthropicModels, [EModelEndpoint.openAI]: [ diff --git a/packages/data-provider/src/file-config.ts b/packages/data-provider/src/file-config.ts index e8781ba6b2..9c2b3bf7df 100644 --- a/packages/data-provider/src/file-config.ts +++ b/packages/data-provider/src/file-config.ts @@ -443,22 +443,7 @@ export const fileConfig = { }, }; -const supportedMimeTypesSchema = z - .array(z.any()) - .optional() - .refine( - (mimeTypes) => { - if (!mimeTypes) { - return true; - } - return mimeTypes.every( - (mimeType) => mimeType instanceof RegExp || typeof mimeType === 'string', - ); - }, - { - message: 'Each mimeType must be a string or a RegExp object.', - }, - ); +const supportedMimeTypesSchema = z.array(z.string()).optional(); export const endpointFileConfigSchema = z.object({ disabled: z.boolean().optional(), @@ -692,22 +677,24 @@ export function mergeFileConfig(dynamic: z.infer | unde } if (dynamic.ocr !== undefined) { + const { supportedMimeTypes: ocrMimeTypes, ...ocrRest } = dynamic.ocr; mergedConfig.ocr = { ...mergedConfig.ocr, - ...dynamic.ocr, + ...ocrRest, }; - if (dynamic.ocr.supportedMimeTypes) { - mergedConfig.ocr.supportedMimeTypes = 
convertStringsToRegex(dynamic.ocr.supportedMimeTypes); + if (ocrMimeTypes) { + mergedConfig.ocr.supportedMimeTypes = convertStringsToRegex(ocrMimeTypes); } } if (dynamic.text !== undefined) { + const { supportedMimeTypes: textMimeTypes, ...textRest } = dynamic.text; mergedConfig.text = { ...mergedConfig.text, - ...dynamic.text, + ...textRest, }; - if (dynamic.text.supportedMimeTypes) { - mergedConfig.text.supportedMimeTypes = convertStringsToRegex(dynamic.text.supportedMimeTypes); + if (textMimeTypes) { + mergedConfig.text.supportedMimeTypes = convertStringsToRegex(textMimeTypes); } } diff --git a/packages/data-provider/src/mcp.ts b/packages/data-provider/src/mcp.ts index 3ad296c4ec..b22a599b9b 100644 --- a/packages/data-provider/src/mcp.ts +++ b/packages/data-provider/src/mcp.ts @@ -18,10 +18,10 @@ const BaseOptionsSchema = z.object({ */ startup: z.boolean().optional(), iconPath: z.string().optional(), - timeout: z.number().optional(), + timeout: z.number().int().nonnegative().optional(), /** Timeout (ms) for the long-lived SSE GET stream body before undici aborts it. Default: 300_000 (5 min). */ - sseReadTimeout: z.number().positive().optional(), - initTimeout: z.number().optional(), + sseReadTimeout: z.number().int().positive().optional(), + initTimeout: z.number().int().nonnegative().optional(), /** Controls visibility in chat dropdown menu (MCPSelect) */ chatMenu: z.boolean().optional(), /** @@ -104,7 +104,7 @@ const BaseOptionsSchema = z.object({ }); export const StdioOptionsSchema = BaseOptionsSchema.extend({ - type: z.literal('stdio').optional(), + type: z.literal('stdio').default('stdio'), /** * The executable to run to start the server. */ @@ -134,17 +134,17 @@ export const StdioOptionsSchema = BaseOptionsSchema.extend({ return processedEnv; }), /** - * How to handle stderr of the child process. This matches the semantics of Node's `child_process.spawn`. 
- * - * @type {import('node:child_process').IOType | import('node:stream').Stream | number} - * - * The default is "inherit", meaning messages to stderr will be printed to the parent process's stderr. + * How to handle stderr of the child process. + * Accepts: 'pipe' | 'ignore' | 'inherit' | file descriptor number. + * Defaults to "inherit". */ - stderr: z.any().optional(), + stderr: z + .union([z.enum(['pipe', 'ignore', 'inherit']), z.number().int().nonnegative()]) + .optional(), }); export const WebSocketOptionsSchema = BaseOptionsSchema.extend({ - type: z.literal('websocket').optional(), + type: z.literal('websocket').default('websocket'), url: z .string() .transform((val: string) => extractEnvVariable(val)) @@ -161,7 +161,7 @@ export const WebSocketOptionsSchema = BaseOptionsSchema.extend({ }); export const SSEOptionsSchema = BaseOptionsSchema.extend({ - type: z.literal('sse').optional(), + type: z.literal('sse').default('sse'), headers: z.record(z.string(), z.string()).optional(), url: z .string() diff --git a/packages/data-provider/src/models.ts b/packages/data-provider/src/models.ts index c2dbe2cf77..82c2042d8a 100644 --- a/packages/data-provider/src/models.ts +++ b/packages/data-provider/src/models.ts @@ -1,8 +1,8 @@ import { z } from 'zod'; -import type { TPreset } from './schemas'; +import type { TModelSpecPreset } from './schemas'; import { EModelEndpoint, - tPresetSchema, + tModelSpecPresetSchema, eModelEndpointSchema, AuthType, authTypeSchema, @@ -11,7 +11,7 @@ import { export type TModelSpec = { name: string; label: string; - preset: TPreset; + preset: TModelSpecPreset; order?: number; default?: boolean; description?: string; @@ -42,7 +42,7 @@ export type TModelSpec = { export const tModelSpecSchema = z.object({ name: z.string(), label: z.string(), - preset: tPresetSchema, + preset: tModelSpecPresetSchema, order: z.number().optional(), default: z.boolean().optional(), description: z.string().optional(), diff --git 
a/packages/data-provider/src/schemas.ts b/packages/data-provider/src/schemas.ts index 19ba804556..084f74af86 100644 --- a/packages/data-provider/src/schemas.ts +++ b/packages/data-provider/src/schemas.ts @@ -635,11 +635,15 @@ export const tMessageSchema = z.object({ calibrationRatio: z .number() .optional() - .describe('EMA ratio of provider-reported vs local token estimates; seeds the pruner on subsequent runs'), + .describe( + 'EMA ratio of provider-reported vs local token estimates; seeds the pruner on subsequent runs', + ), encoding: z .string() .optional() - .describe('Tokenizer encoding used when this ratio was computed (e.g. "claude", "o200k_base")'), + .describe( + 'Tokenizer encoding used when this ratio was computed (e.g. "claude", "o200k_base")', + ), }) .optional(), }); @@ -919,6 +923,30 @@ export const tQueryParamsSchema = tConversationSchema }), ); +/** Narrowed preset schema for use in model specs — omits system/DB/deprecated fields */ +export const tModelSpecPresetSchema = tPresetSchema.omit({ + conversationId: true, + presetId: true, + title: true, + defaultPreset: true, + order: true, + isArchived: true, + user: true, + messages: true, + tags: true, + file_ids: true, + expiredAt: true, + parentMessageId: true, + resendImages: true, + chatGptLabel: true, + presetOverride: true, + greeting: true, + iconURL: true, + spec: true, +}); + +export type TModelSpecPreset = z.infer; + export type TPreset = z.infer; export type TSetOption = ( diff --git a/packages/data-schemas/package.json b/packages/data-schemas/package.json index 0376804ad4..145b8925d1 100644 --- a/packages/data-schemas/package.json +++ b/packages/data-schemas/package.json @@ -1,16 +1,22 @@ { "name": "@librechat/data-schemas", - "version": "0.0.40", + "version": "0.0.47", "description": "Mongoose schemas and models for LibreChat", "type": "module", "main": "dist/index.cjs", "module": "dist/index.es.js", "types": "./dist/types/index.d.ts", + "sideEffects": false, "exports": { ".": { 
"import": "./dist/index.es.js", "require": "./dist/index.cjs", "types": "./dist/types/index.d.ts" + }, + "./capabilities": { + "import": "./dist/admin/capabilities.es.js", + "require": "./dist/admin/capabilities.cjs", + "types": "./dist/types/admin/capabilities.d.ts" } }, "files": [ diff --git a/packages/data-schemas/rollup.config.js b/packages/data-schemas/rollup.config.js index d58331feee..703630e121 100644 --- a/packages/data-schemas/rollup.config.js +++ b/packages/data-schemas/rollup.config.js @@ -8,14 +8,20 @@ export default { input: 'src/index.ts', output: [ { - file: 'dist/index.es.js', + dir: 'dist', format: 'es', sourcemap: true, + preserveModules: true, + preserveModulesRoot: 'src', + entryFileNames: '[name].es.js', }, { - file: 'dist/index.cjs', + dir: 'dist', format: 'cjs', sourcemap: true, + preserveModules: true, + preserveModulesRoot: 'src', + entryFileNames: '[name].cjs', }, ], plugins: [ diff --git a/packages/data-schemas/src/admin/capabilities.ts b/packages/data-schemas/src/admin/capabilities.ts new file mode 100644 index 0000000000..447db235a2 --- /dev/null +++ b/packages/data-schemas/src/admin/capabilities.ts @@ -0,0 +1,199 @@ +import { ResourceType } from 'librechat-data-provider'; +import type { + BaseSystemCapability, + SystemCapability, + ConfigSection, + CapabilityCategory, +} from '~/types/admin'; + +// --------------------------------------------------------------------------- +// System Capabilities +// --------------------------------------------------------------------------- + +/** + * The canonical set of base system capabilities. + * + * These are used by the admin panel and LibreChat API to gate access to + * admin features. Config-section-derived capabilities (e.g. + * `manage:configs:endpoints`) are built on top of these where the + * configSchema is available. 
+ */ +export const SystemCapabilities = { + ACCESS_ADMIN: 'access:admin', + READ_USERS: 'read:users', + MANAGE_USERS: 'manage:users', + READ_GROUPS: 'read:groups', + MANAGE_GROUPS: 'manage:groups', + READ_ROLES: 'read:roles', + MANAGE_ROLES: 'manage:roles', + READ_CONFIGS: 'read:configs', + MANAGE_CONFIGS: 'manage:configs', + ASSIGN_CONFIGS: 'assign:configs', + READ_USAGE: 'read:usage', + READ_AGENTS: 'read:agents', + MANAGE_AGENTS: 'manage:agents', + MANAGE_MCP_SERVERS: 'manage:mcpservers', + READ_PROMPTS: 'read:prompts', + MANAGE_PROMPTS: 'manage:prompts', + /** Reserved — not yet enforced by any middleware. */ + READ_ASSISTANTS: 'read:assistants', + MANAGE_ASSISTANTS: 'manage:assistants', +} as const; + +/** + * Capabilities that are implied by holding a broader capability. + * e.g. `MANAGE_USERS` implies `READ_USERS`. + */ +export const CapabilityImplications: Partial<Record<BaseSystemCapability, BaseSystemCapability[]>> = + { + [SystemCapabilities.MANAGE_USERS]: [SystemCapabilities.READ_USERS], + [SystemCapabilities.MANAGE_GROUPS]: [SystemCapabilities.READ_GROUPS], + [SystemCapabilities.MANAGE_ROLES]: [SystemCapabilities.READ_ROLES], + [SystemCapabilities.MANAGE_CONFIGS]: [SystemCapabilities.READ_CONFIGS], + [SystemCapabilities.MANAGE_AGENTS]: [SystemCapabilities.READ_AGENTS], + [SystemCapabilities.MANAGE_PROMPTS]: [SystemCapabilities.READ_PROMPTS], + [SystemCapabilities.MANAGE_ASSISTANTS]: [SystemCapabilities.READ_ASSISTANTS], + }; + +// --------------------------------------------------------------------------- +// Capability utility functions +// --------------------------------------------------------------------------- + +/** Reverse map: for a given read capability, which manage capabilities imply it?
 */ +const impliedByMap: Record<string, string[]> = {}; +for (const [manage, reads] of Object.entries(CapabilityImplications)) { + for (const read of reads as string[]) { + if (!impliedByMap[read]) { + impliedByMap[read] = []; + } + impliedByMap[read].push(manage); + } +} + +/** + * Check whether a set of held capabilities satisfies a required capability, + * accounting for the manage→read implication hierarchy. + */ +export function hasImpliedCapability(held: string[], required: string): boolean { + if (held.includes(required)) { + return true; + } + const impliers = impliedByMap[required]; + if (impliers) { + for (const cap of impliers) { + if (held.includes(cap)) { + return true; + } + } + } + return false; +} + +/** + * Given a set of directly-held capabilities, compute the full set including + * all implied capabilities. + */ +export function expandImplications(directCaps: string[]): string[] { + const expanded = new Set(directCaps); + for (const cap of directCaps) { + const implied = CapabilityImplications[cap as BaseSystemCapability]; + if (implied) { + for (const imp of implied) { + expanded.add(imp); + } + } + } + return Array.from(expanded); +} + +// --------------------------------------------------------------------------- +// Resource & config capability mappings +// --------------------------------------------------------------------------- + +/** + * Maps each ACL ResourceType to the SystemCapability that grants + * unrestricted management access. Typed as `Record<ResourceType, SystemCapability>` + * so adding a new ResourceType variant causes a compile error until a + * capability is assigned here. + */ +export const ResourceCapabilityMap: Record<ResourceType, SystemCapability> = { + [ResourceType.AGENT]: SystemCapabilities.MANAGE_AGENTS, + [ResourceType.PROMPTGROUP]: SystemCapabilities.MANAGE_PROMPTS, + [ResourceType.MCPSERVER]: SystemCapabilities.MANAGE_MCP_SERVERS, + [ResourceType.REMOTE_AGENT]: SystemCapabilities.MANAGE_AGENTS, +}; + +/** + * Derives a section-level config management capability from a configSchema key. 
+ * @example configCapability('endpoints') → 'manage:configs:endpoints' + * + * TODO: Section-level config capabilities are scaffolded but not yet active. + * To activate delegated config management: + * 1. Expose POST/DELETE /api/admin/grants endpoints (wiring grantCapability/revokeCapability) + * 2. Seed section-specific grants for delegated admin roles via those endpoints + * 3. Guard config write handlers with hasConfigCapability(user, section) + */ +export function configCapability(section: ConfigSection): `manage:configs:${ConfigSection}` { + return `manage:configs:${section}`; +} + +/** + * Derives a section-level config read capability from a configSchema key. + * @example readConfigCapability('endpoints') → 'read:configs:endpoints' + */ +export function readConfigCapability(section: ConfigSection): `read:configs:${ConfigSection}` { + return `read:configs:${section}`; +} + +// --------------------------------------------------------------------------- +// Reserved principal IDs +// --------------------------------------------------------------------------- + +/** Reserved principalId for the DB base config (overrides YAML defaults). */ +export const BASE_CONFIG_PRINCIPAL_ID = '__base__'; + +/** Pre-defined UI categories for grouping capabilities in the admin panel. 
*/ +export const CAPABILITY_CATEGORIES: CapabilityCategory[] = [ + { + key: 'users', + labelKey: 'com_cap_cat_users', + capabilities: [SystemCapabilities.MANAGE_USERS, SystemCapabilities.READ_USERS], + }, + { + key: 'groups', + labelKey: 'com_cap_cat_groups', + capabilities: [SystemCapabilities.MANAGE_GROUPS, SystemCapabilities.READ_GROUPS], + }, + { + key: 'roles', + labelKey: 'com_cap_cat_roles', + capabilities: [SystemCapabilities.MANAGE_ROLES, SystemCapabilities.READ_ROLES], + }, + { + key: 'config', + labelKey: 'com_cap_cat_config', + capabilities: [ + SystemCapabilities.MANAGE_CONFIGS, + SystemCapabilities.READ_CONFIGS, + SystemCapabilities.ASSIGN_CONFIGS, + ], + }, + { + key: 'content', + labelKey: 'com_cap_cat_content', + capabilities: [ + SystemCapabilities.MANAGE_AGENTS, + SystemCapabilities.READ_AGENTS, + SystemCapabilities.MANAGE_PROMPTS, + SystemCapabilities.READ_PROMPTS, + SystemCapabilities.MANAGE_ASSISTANTS, + SystemCapabilities.READ_ASSISTANTS, + SystemCapabilities.MANAGE_MCP_SERVERS, + ], + }, + { + key: 'system', + labelKey: 'com_cap_cat_system', + capabilities: [SystemCapabilities.ACCESS_ADMIN, SystemCapabilities.READ_USAGE], + }, +]; diff --git a/packages/data-schemas/src/admin/index.ts b/packages/data-schemas/src/admin/index.ts new file mode 100644 index 0000000000..8d43daada6 --- /dev/null +++ b/packages/data-schemas/src/admin/index.ts @@ -0,0 +1 @@ +export * from './capabilities'; diff --git a/packages/data-schemas/src/app/index.ts b/packages/data-schemas/src/app/index.ts index 77cb799f8c..b07a36acd0 100644 --- a/packages/data-schemas/src/app/index.ts +++ b/packages/data-schemas/src/app/index.ts @@ -5,3 +5,4 @@ export * from './specs'; export * from './turnstile'; export * from './vertex'; export * from './web'; +export * from './resolution'; diff --git a/packages/data-schemas/src/app/resolution.spec.ts b/packages/data-schemas/src/app/resolution.spec.ts new file mode 100644 index 0000000000..12f8985a48 --- /dev/null +++ 
b/packages/data-schemas/src/app/resolution.spec.ts @@ -0,0 +1,108 @@ +import { mergeConfigOverrides } from './resolution'; +import type { AppConfig, IConfig } from '~/types'; + +function fakeConfig(overrides: Record, priority: number): IConfig { + return { + _id: 'fake', + principalType: 'role', + principalId: 'test', + principalModel: 'Role', + priority, + overrides, + isActive: true, + configVersion: 1, + } as unknown as IConfig; +} + +const baseConfig = { + interface: { endpointsMenu: true, sidePanel: true }, + registration: { enabled: true }, + endpoints: ['openAI'], +} as unknown as AppConfig; + +describe('mergeConfigOverrides', () => { + it('returns base config when configs array is empty', () => { + expect(mergeConfigOverrides(baseConfig, [])).toBe(baseConfig); + }); + + it('returns base config when configs is null/undefined', () => { + expect(mergeConfigOverrides(baseConfig, null as unknown as IConfig[])).toBe(baseConfig); + expect(mergeConfigOverrides(baseConfig, undefined as unknown as IConfig[])).toBe(baseConfig); + }); + + it('deep merges a single override into base', () => { + const configs = [fakeConfig({ interface: { endpointsMenu: false } }, 10)]; + const result = mergeConfigOverrides(baseConfig, configs) as unknown as Record; + const iface = result.interface as Record; + expect(iface.endpointsMenu).toBe(false); + expect(iface.sidePanel).toBe(true); + }); + + it('sorts by priority — higher priority wins', () => { + const configs = [ + fakeConfig({ registration: { enabled: false } }, 100), + fakeConfig({ registration: { enabled: true, custom: 'yes' } }, 10), + ]; + const result = mergeConfigOverrides(baseConfig, configs) as unknown as Record; + const reg = result.registration as Record; + expect(reg.enabled).toBe(false); + expect(reg.custom).toBe('yes'); + }); + + it('replaces arrays instead of concatenating', () => { + const configs = [fakeConfig({ endpoints: ['anthropic', 'google'] }, 10)]; + const result = mergeConfigOverrides(baseConfig, configs) 
as unknown as Record; + expect(result.endpoints).toEqual(['anthropic', 'google']); + }); + + it('does not mutate the base config', () => { + const original = JSON.parse(JSON.stringify(baseConfig)); + const configs = [fakeConfig({ interface: { endpointsMenu: false } }, 10)]; + mergeConfigOverrides(baseConfig, configs); + expect(baseConfig).toEqual(original); + }); + + it('handles null override values', () => { + const configs = [fakeConfig({ interface: { endpointsMenu: null } }, 10)]; + const result = mergeConfigOverrides(baseConfig, configs) as unknown as Record; + const iface = result.interface as Record; + expect(iface.endpointsMenu).toBeNull(); + }); + + it('skips configs with no overrides object', () => { + const configs = [fakeConfig(undefined as unknown as Record, 10)]; + const result = mergeConfigOverrides(baseConfig, configs); + expect(result).toEqual(baseConfig); + }); + + it('strips __proto__, constructor, and prototype keys from overrides', () => { + const configs = [ + fakeConfig( + { + __proto__: { polluted: true }, + constructor: { bad: true }, + prototype: { evil: true }, + safe: 'ok', + }, + 10, + ), + ]; + const result = mergeConfigOverrides(baseConfig, configs) as unknown as Record; + expect(result.safe).toBe('ok'); + expect(({} as Record).polluted).toBeUndefined(); + expect(Object.prototype.hasOwnProperty.call(result, 'constructor')).toBe(false); + expect(Object.prototype.hasOwnProperty.call(result, 'prototype')).toBe(false); + }); + + it('merges three priority levels in order', () => { + const configs = [ + fakeConfig({ interface: { endpointsMenu: false } }, 0), + fakeConfig({ interface: { endpointsMenu: true, sidePanel: false } }, 10), + fakeConfig({ interface: { sidePanel: true } }, 100), + ]; + const result = mergeConfigOverrides(baseConfig, configs) as unknown as Record; + const iface = result.interface as Record; + expect(iface.endpointsMenu).toBe(true); + expect(iface.sidePanel).toBe(true); + }); +}); diff --git 
a/packages/data-schemas/src/app/resolution.ts b/packages/data-schemas/src/app/resolution.ts new file mode 100644 index 0000000000..ad1c1fbff0 --- /dev/null +++ b/packages/data-schemas/src/app/resolution.ts @@ -0,0 +1,54 @@ +import type { AppConfig, IConfig } from '~/types'; + +type AnyObject = { [key: string]: unknown }; + +const MAX_MERGE_DEPTH = 10; +const UNSAFE_KEYS = new Set(['__proto__', 'constructor', 'prototype']); + +function deepMerge(target: T, source: AnyObject, depth = 0): T { + const result = { ...target } as AnyObject; + for (const key of Object.keys(source)) { + if (UNSAFE_KEYS.has(key)) { + continue; + } + const sourceVal = source[key]; + const targetVal = result[key]; + if ( + depth < MAX_MERGE_DEPTH && + sourceVal != null && + typeof sourceVal === 'object' && + !Array.isArray(sourceVal) && + targetVal != null && + typeof targetVal === 'object' && + !Array.isArray(targetVal) + ) { + result[key] = deepMerge(targetVal as AnyObject, sourceVal as AnyObject, depth + 1); + } else { + result[key] = sourceVal; + } + } + return result as T; +} + +/** + * Merge DB config overrides into a base AppConfig. + * + * Configs are sorted by priority ascending (lowest first, highest wins). + * Each config's `overrides` is deep-merged into the base config in order. 
+ */ +export function mergeConfigOverrides(baseConfig: AppConfig, configs: IConfig[]): AppConfig { + if (!configs || configs.length === 0) { + return baseConfig; + } + + const sorted = [...configs].sort((a, b) => a.priority - b.priority); + + let merged = { ...baseConfig }; + for (const config of sorted) { + if (config.overrides && typeof config.overrides === 'object') { + merged = deepMerge(merged, config.overrides as AnyObject); + } + } + + return merged; +} diff --git a/packages/data-schemas/src/config/tenantContext.spec.ts b/packages/data-schemas/src/config/tenantContext.spec.ts new file mode 100644 index 0000000000..7e6cc0748d --- /dev/null +++ b/packages/data-schemas/src/config/tenantContext.spec.ts @@ -0,0 +1,26 @@ +import { tenantStorage, runAsSystem, scopedCacheKey } from './tenantContext'; + +describe('scopedCacheKey', () => { + it('returns base key when no ALS context is set', () => { + expect(scopedCacheKey('MODELS_CONFIG')).toBe('MODELS_CONFIG'); + }); + + it('returns base key in SYSTEM_TENANT_ID context', async () => { + await runAsSystem(async () => { + expect(scopedCacheKey('MODELS_CONFIG')).toBe('MODELS_CONFIG'); + }); + }); + + it('appends tenantId when tenant context is active', async () => { + await tenantStorage.run({ tenantId: 'acme' }, async () => { + expect(scopedCacheKey('MODELS_CONFIG')).toBe('MODELS_CONFIG:acme'); + }); + }); + + it('does not leak tenant context outside ALS scope', async () => { + await tenantStorage.run({ tenantId: 'acme' }, async () => { + expect(scopedCacheKey('KEY')).toBe('KEY:acme'); + }); + expect(scopedCacheKey('KEY')).toBe('KEY'); + }); +}); diff --git a/packages/data-schemas/src/config/tenantContext.ts b/packages/data-schemas/src/config/tenantContext.ts index e5e4376a90..eb77edb27d 100644 --- a/packages/data-schemas/src/config/tenantContext.ts +++ b/packages/data-schemas/src/config/tenantContext.ts @@ -26,3 +26,16 @@ export function getTenantId(): string | undefined { export function runAsSystem(fn: () => 
Promise): Promise { return tenantStorage.run({ tenantId: SYSTEM_TENANT_ID }, fn); } + +/** + * Appends `:${tenantId}` to a cache key when a non-system tenant context is active. + * Returns the base key unchanged when no ALS context is set or when running + * inside `runAsSystem()` (SYSTEM_TENANT_ID context). + */ +export function scopedCacheKey(baseKey: string): string { + const tenantId = getTenantId(); + if (!tenantId || tenantId === SYSTEM_TENANT_ID) { + return baseKey; + } + return `${baseKey}:${tenantId}`; +} diff --git a/packages/data-schemas/src/index.ts b/packages/data-schemas/src/index.ts index aa92b3b2e6..1139f83f17 100644 --- a/packages/data-schemas/src/index.ts +++ b/packages/data-schemas/src/index.ts @@ -1,5 +1,5 @@ export * from './app'; -export * from './systemCapabilities'; +export * from './admin'; export * from './common'; export * from './crypto'; export * from './schema'; @@ -7,6 +7,7 @@ export * from './utils'; export { createModels } from './models'; export { createMethods, + RoleConflictError, DEFAULT_REFRESH_TOKEN_EXPIRY, DEFAULT_SESSION_EXPIRY, tokenValues, @@ -18,6 +19,12 @@ export type * from './types'; export type * from './methods'; export { default as logger } from './config/winston'; export { default as meiliLogger } from './config/meiliLogger'; -export { tenantStorage, getTenantId, runAsSystem, SYSTEM_TENANT_ID } from './config/tenantContext'; +export { + tenantStorage, + getTenantId, + runAsSystem, + scopedCacheKey, + SYSTEM_TENANT_ID, +} from './config/tenantContext'; export type { TenantContext } from './config/tenantContext'; export { dropSupersededTenantIndexes, dropSupersededPromptGroupIndexes } from './migrations'; diff --git a/packages/data-schemas/src/methods/aclEntry.ts b/packages/data-schemas/src/methods/aclEntry.ts index 82e277254a..d93693641c 100644 --- a/packages/data-schemas/src/methods/aclEntry.ts +++ b/packages/data-schemas/src/methods/aclEntry.ts @@ -7,7 +7,8 @@ import type { DeleteResult, Model, } from 'mongoose'; 
-import type { IAclEntry } from '~/types'; +import type { AclEntry, IAclEntry } from '~/types'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; export function createAclEntryMethods(mongoose: typeof import('mongoose')) { /** @@ -374,11 +375,11 @@ export function createAclEntryMethods(mongoose: typeof import('mongoose')) { * @param options - Optional query options (e.g., { session }) */ async function bulkWriteAclEntries( - ops: AnyBulkWriteOperation[], + ops: AnyBulkWriteOperation[], options?: { session?: ClientSession }, ) { const AclEntry = mongoose.models.AclEntry as Model; - return AclEntry.bulkWrite(ops, options || {}); + return tenantSafeBulkWrite(AclEntry, ops as AnyBulkWriteOperation[], options || {}); } /** @@ -448,7 +449,9 @@ export function createAclEntryMethods(mongoose: typeof import('mongoose')) { { $group: { _id: '$resourceId' } }, ]); - const multiOwnerIds = new Set(otherOwners.map((doc: { _id: Types.ObjectId }) => doc._id.toString())); + const multiOwnerIds = new Set( + otherOwners.map((doc: { _id: Types.ObjectId }) => doc._id.toString()), + ); return ownedIds.filter((id) => !multiOwnerIds.has(id.toString())); } diff --git a/packages/data-schemas/src/methods/agentCategory.ts b/packages/data-schemas/src/methods/agentCategory.ts index 2dd4678075..baf33207aa 100644 --- a/packages/data-schemas/src/methods/agentCategory.ts +++ b/packages/data-schemas/src/methods/agentCategory.ts @@ -1,5 +1,6 @@ import type { Model, Types } from 'mongoose'; import type { IAgentCategory } from '~/types'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; export function createAgentCategoryMethods(mongoose: typeof import('mongoose')) { /** @@ -74,7 +75,7 @@ export function createAgentCategoryMethods(mongoose: typeof import('mongoose')) }, })); - return await AgentCategory.bulkWrite(operations); + return await tenantSafeBulkWrite(AgentCategory, operations); } /** @@ -241,7 +242,7 @@ export function createAgentCategoryMethods(mongoose: typeof 
import('mongoose')) }, })); - await AgentCategory.bulkWrite(bulkOps, { ordered: false }); + await tenantSafeBulkWrite(AgentCategory, bulkOps, { ordered: false }); } return updates.length > 0 || created > 0; diff --git a/packages/data-schemas/src/methods/config.spec.ts b/packages/data-schemas/src/methods/config.spec.ts new file mode 100644 index 0000000000..8bcf73a733 --- /dev/null +++ b/packages/data-schemas/src/methods/config.spec.ts @@ -0,0 +1,333 @@ +import mongoose, { Types } from 'mongoose'; +import { MongoMemoryServer } from 'mongodb-memory-server'; +import { PrincipalType, PrincipalModel } from 'librechat-data-provider'; +import { createConfigMethods } from './config'; +import configSchema from '~/schema/config'; +import type { IConfig } from '~/types'; + +let mongoServer: MongoMemoryServer; +let methods: ReturnType; + +beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + await mongoose.connect(mongoServer.getUri()); + if (!mongoose.models.Config) { + mongoose.model('Config', configSchema); + } + methods = createConfigMethods(mongoose); +}); + +afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); +}); + +beforeEach(async () => { + await mongoose.models.Config.deleteMany({}); +}); + +describe('upsertConfig', () => { + it('creates a new config document', async () => { + const result = await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { interface: { endpointsMenu: false } }, + 10, + ); + + expect(result).toBeTruthy(); + expect(result!.principalType).toBe(PrincipalType.ROLE); + expect(result!.principalId).toBe('admin'); + expect(result!.priority).toBe(10); + expect(result!.isActive).toBe(true); + expect(result!.configVersion).toBe(1); + }); + + it('is idempotent — second upsert updates the same doc', async () => { + await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { interface: { endpointsMenu: false } }, + 10, + ); + + await 
methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { interface: { endpointsMenu: true } }, + 10, + ); + + const count = await mongoose.models.Config.countDocuments({}); + expect(count).toBe(1); + }); + + it('increments configVersion on each upsert', async () => { + const first = await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { interface: { endpointsMenu: true } }, + 10, + ); + + const second = await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { interface: { endpointsMenu: false } }, + 10, + ); + + expect(first!.configVersion).toBe(1); + expect(second!.configVersion).toBe(2); + }); + + it('normalizes ObjectId principalId to string', async () => { + const oid = new Types.ObjectId(); + await methods.upsertConfig(PrincipalType.USER, oid, PrincipalModel.USER, { cache: true }, 100); + + const found = await methods.findConfigByPrincipal(PrincipalType.USER, oid.toString()); + expect(found).toBeTruthy(); + expect(found!.principalId).toBe(oid.toString()); + }); +}); + +describe('findConfigByPrincipal', () => { + it('finds an active config', async () => { + await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { cache: true }, + 10, + ); + + const result = await methods.findConfigByPrincipal(PrincipalType.ROLE, 'admin'); + expect(result).toBeTruthy(); + expect(result!.principalType).toBe(PrincipalType.ROLE); + }); + + it('returns null when no config exists', async () => { + const result = await methods.findConfigByPrincipal(PrincipalType.ROLE, 'nonexistent'); + expect(result).toBeNull(); + }); + + it('does not find inactive configs', async () => { + await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { cache: true }, + 10, + ); + await methods.toggleConfigActive(PrincipalType.ROLE, 'admin', false); + + const result = await methods.findConfigByPrincipal(PrincipalType.ROLE, 'admin'); + 
expect(result).toBeNull(); + }); +}); + +describe('listAllConfigs', () => { + it('returns all configs when no filter', async () => { + await methods.upsertConfig(PrincipalType.ROLE, 'a', PrincipalModel.ROLE, {}, 10); + await methods.upsertConfig(PrincipalType.ROLE, 'b', PrincipalModel.ROLE, {}, 20); + await methods.toggleConfigActive(PrincipalType.ROLE, 'b', false); + + const all = await methods.listAllConfigs(); + expect(all).toHaveLength(2); + }); + + it('filters by isActive when specified', async () => { + await methods.upsertConfig(PrincipalType.ROLE, 'a', PrincipalModel.ROLE, {}, 10); + await methods.upsertConfig(PrincipalType.ROLE, 'b', PrincipalModel.ROLE, {}, 20); + await methods.toggleConfigActive(PrincipalType.ROLE, 'b', false); + + const active = await methods.listAllConfigs({ isActive: true }); + expect(active).toHaveLength(1); + expect(active[0].principalId).toBe('a'); + + const inactive = await methods.listAllConfigs({ isActive: false }); + expect(inactive).toHaveLength(1); + expect(inactive[0].principalId).toBe('b'); + }); + + it('returns configs sorted by priority ascending', async () => { + await methods.upsertConfig(PrincipalType.ROLE, 'high', PrincipalModel.ROLE, {}, 100); + await methods.upsertConfig(PrincipalType.ROLE, 'low', PrincipalModel.ROLE, {}, 0); + await methods.upsertConfig(PrincipalType.ROLE, 'mid', PrincipalModel.ROLE, {}, 50); + + const configs = await methods.listAllConfigs(); + expect(configs.map((c) => c.principalId)).toEqual(['low', 'mid', 'high']); + }); +}); + +describe('getApplicableConfigs', () => { + it('always includes the __base__ config', async () => { + await methods.upsertConfig( + PrincipalType.ROLE, + '__base__', + PrincipalModel.ROLE, + { cache: true }, + 0, + ); + + const configs = await methods.getApplicableConfigs([]); + expect(configs).toHaveLength(1); + expect(configs[0].principalId).toBe('__base__'); + }); + + it('returns base + matching principals', async () => { + await methods.upsertConfig( + 
PrincipalType.ROLE, + '__base__', + PrincipalModel.ROLE, + { cache: true }, + 0, + ); + await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { version: '2' }, + 10, + ); + await methods.upsertConfig( + PrincipalType.ROLE, + 'user', + PrincipalModel.ROLE, + { version: '3' }, + 10, + ); + + const configs = await methods.getApplicableConfigs([ + { principalType: PrincipalType.ROLE, principalId: 'admin' }, + ]); + + expect(configs).toHaveLength(2); + expect(configs.map((c) => c.principalId).sort()).toEqual(['__base__', 'admin']); + }); + + it('returns sorted by priority', async () => { + await methods.upsertConfig(PrincipalType.ROLE, '__base__', PrincipalModel.ROLE, {}, 0); + await methods.upsertConfig(PrincipalType.ROLE, 'admin', PrincipalModel.ROLE, {}, 10); + + const configs = await methods.getApplicableConfigs([ + { principalType: PrincipalType.ROLE, principalId: 'admin' }, + ]); + + expect(configs[0].principalId).toBe('__base__'); + expect(configs[1].principalId).toBe('admin'); + }); + + it('skips principals with undefined principalId', async () => { + await methods.upsertConfig(PrincipalType.ROLE, '__base__', PrincipalModel.ROLE, {}, 0); + + const configs = await methods.getApplicableConfigs([ + { principalType: PrincipalType.GROUP, principalId: undefined }, + ]); + + expect(configs).toHaveLength(1); + }); +}); + +describe('patchConfigFields', () => { + it('atomically sets specific fields via $set', async () => { + await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { interface: { endpointsMenu: true, sidePanel: true } }, + 10, + ); + + const result = await methods.patchConfigFields( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { 'interface.endpointsMenu': false }, + 10, + ); + + const overrides = result!.overrides as Record; + const iface = overrides.interface as Record; + expect(iface.endpointsMenu).toBe(false); + expect(iface.sidePanel).toBe(true); + }); + + it('creates a config if 
none exists (upsert)', async () => { + const result = await methods.patchConfigFields( + PrincipalType.ROLE, + 'newrole', + PrincipalModel.ROLE, + { 'interface.endpointsMenu': false }, + 10, + ); + + expect(result).toBeTruthy(); + expect(result!.principalId).toBe('newrole'); + }); +}); + +describe('unsetConfigField', () => { + it('removes a field from overrides via $unset', async () => { + await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { interface: { endpointsMenu: false, sidePanel: false } }, + 10, + ); + + const result = await methods.unsetConfigField( + PrincipalType.ROLE, + 'admin', + 'interface.endpointsMenu', + ); + const overrides = result!.overrides as Record; + const iface = overrides.interface as Record; + expect(iface.endpointsMenu).toBeUndefined(); + expect(iface.sidePanel).toBe(false); + }); + + it('returns null for non-existent config', async () => { + const result = await methods.unsetConfigField(PrincipalType.ROLE, 'ghost', 'a.b'); + expect(result).toBeNull(); + }); +}); + +describe('deleteConfig', () => { + it('deletes and returns the config', async () => { + await methods.upsertConfig(PrincipalType.ROLE, 'admin', PrincipalModel.ROLE, {}, 10); + const deleted = await methods.deleteConfig(PrincipalType.ROLE, 'admin'); + expect(deleted).toBeTruthy(); + + const found = await methods.findConfigByPrincipal(PrincipalType.ROLE, 'admin'); + expect(found).toBeNull(); + }); + + it('returns null when deleting non-existent config', async () => { + const result = await methods.deleteConfig(PrincipalType.ROLE, 'ghost'); + expect(result).toBeNull(); + }); +}); + +describe('toggleConfigActive', () => { + it('deactivates an active config', async () => { + await methods.upsertConfig(PrincipalType.ROLE, 'admin', PrincipalModel.ROLE, {}, 10); + + const result = await methods.toggleConfigActive(PrincipalType.ROLE, 'admin', false); + expect(result!.isActive).toBe(false); + }); + + it('reactivates an inactive config', async () => { 
+ await methods.upsertConfig(PrincipalType.ROLE, 'admin', PrincipalModel.ROLE, {}, 10); + await methods.toggleConfigActive(PrincipalType.ROLE, 'admin', false); + + const result = await methods.toggleConfigActive(PrincipalType.ROLE, 'admin', true); + expect(result!.isActive).toBe(true); + }); +}); diff --git a/packages/data-schemas/src/methods/config.ts b/packages/data-schemas/src/methods/config.ts new file mode 100644 index 0000000000..42047d216f --- /dev/null +++ b/packages/data-schemas/src/methods/config.ts @@ -0,0 +1,215 @@ +import { Types } from 'mongoose'; +import { PrincipalType, PrincipalModel } from 'librechat-data-provider'; +import { BASE_CONFIG_PRINCIPAL_ID } from '~/admin/capabilities'; +import type { TCustomConfig } from 'librechat-data-provider'; +import type { Model, ClientSession } from 'mongoose'; +import type { IConfig } from '~/types'; + +export function createConfigMethods(mongoose: typeof import('mongoose')) { + async function findConfigByPrincipal( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + options?: { includeInactive?: boolean }, + session?: ClientSession, + ): Promise { + const Config = mongoose.models.Config as Model; + const filter: { principalType: PrincipalType; principalId: string; isActive?: boolean } = { + principalType, + principalId: principalId.toString(), + }; + if (!options?.includeInactive) { + filter.isActive = true; + } + return await Config.findOne(filter) + .session(session ?? null) + .lean(); + } + + async function listAllConfigs( + filter?: { isActive?: boolean }, + session?: ClientSession, + ): Promise { + const Config = mongoose.models.Config as Model; + const where: { isActive?: boolean } = {}; + if (filter?.isActive !== undefined) { + where.isActive = filter.isActive; + } + return await Config.find(where) + .sort({ priority: 1 }) + .session(session ?? 
null) + .lean(); + } + + async function getApplicableConfigs( + principals?: Array<{ principalType: string; principalId?: string | Types.ObjectId }>, + session?: ClientSession, + ): Promise { + const Config = mongoose.models.Config as Model; + + const basePrincipal = { + principalType: PrincipalType.ROLE as string, + principalId: BASE_CONFIG_PRINCIPAL_ID, + }; + + const principalsQuery = [basePrincipal]; + + if (principals && principals.length > 0) { + for (const p of principals) { + if (p.principalId !== undefined) { + principalsQuery.push({ + principalType: p.principalType, + principalId: p.principalId.toString(), + }); + } + } + } + + return await Config.find({ + $or: principalsQuery, + isActive: true, + }) + .sort({ priority: 1 }) + .session(session ?? null) + .lean(); + } + + async function upsertConfig( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + principalModel: PrincipalModel, + overrides: Partial, + priority: number, + session?: ClientSession, + ): Promise { + const Config = mongoose.models.Config as Model; + + const query = { + principalType, + principalId: principalId.toString(), + }; + + const update = { + $set: { + principalModel, + overrides, + priority, + isActive: true, + }, + $inc: { configVersion: 1 }, + }; + + const options = { + upsert: true, + new: true, + setDefaultsOnInsert: true, + ...(session ? { session } : {}), + }; + + try { + return await Config.findOneAndUpdate(query, update, options); + } catch (err: unknown) { + if ((err as { code?: number }).code === 11000) { + return await Config.findOneAndUpdate( + query, + { $set: update.$set, $inc: update.$inc }, + { new: true, ...(session ? 
{ session } : {}) }, + ); + } + throw err; + } + } + + async function patchConfigFields( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + principalModel: PrincipalModel, + fields: Record, + priority: number, + session?: ClientSession, + ): Promise { + const Config = mongoose.models.Config as Model; + + const setPayload: { principalModel: PrincipalModel; priority: number; [key: string]: unknown } = + { + principalModel, + priority, + }; + + for (const [path, value] of Object.entries(fields)) { + setPayload[`overrides.${path}`] = value; + } + + const options = { + upsert: true, + new: true, + setDefaultsOnInsert: true, + ...(session ? { session } : {}), + }; + + return await Config.findOneAndUpdate( + { principalType, principalId: principalId.toString() }, + { $set: setPayload, $inc: { configVersion: 1 } }, + options, + ); + } + + async function unsetConfigField( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + fieldPath: string, + session?: ClientSession, + ): Promise { + const Config = mongoose.models.Config as Model; + + const options = { + new: true, + ...(session ? { session } : {}), + }; + + return await Config.findOneAndUpdate( + { principalType, principalId: principalId.toString() }, + { $unset: { [`overrides.${fieldPath}`]: '' }, $inc: { configVersion: 1 } }, + options, + ); + } + + async function deleteConfig( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + session?: ClientSession, + ): Promise { + const Config = mongoose.models.Config as Model; + + return await Config.findOneAndDelete({ + principalType, + principalId: principalId.toString(), + }).session(session ?? 
null); + } + + async function toggleConfigActive( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + isActive: boolean, + session?: ClientSession, + ): Promise { + const Config = mongoose.models.Config as Model; + return await Config.findOneAndUpdate( + { principalType, principalId: principalId.toString() }, + { $set: { isActive } }, + { new: true, ...(session ? { session } : {}) }, + ); + } + + return { + listAllConfigs, + findConfigByPrincipal, + getApplicableConfigs, + upsertConfig, + patchConfigFields, + unsetConfigField, + deleteConfig, + toggleConfigActive, + }; +} + +export type ConfigMethods = ReturnType; diff --git a/packages/data-schemas/src/methods/conversation.ts b/packages/data-schemas/src/methods/conversation.ts index 7a62afef9e..abfe16bf2d 100644 --- a/packages/data-schemas/src/methods/conversation.ts +++ b/packages/data-schemas/src/methods/conversation.ts @@ -1,6 +1,7 @@ import type { FilterQuery, Model, SortOrder } from 'mongoose'; -import logger from '~/config/winston'; import { createTempChatExpirationDate } from '~/utils/tempChatRetention'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; +import logger from '~/config/winston'; import type { AppConfig, IConversation } from '~/types'; import type { MessageMethods } from './message'; import type { DeleteResult } from 'mongoose'; @@ -228,7 +229,7 @@ export function createConversationMethods( }, })); - const result = await Conversation.bulkWrite(bulkOps); + const result = await tenantSafeBulkWrite(Conversation, bulkOps); return result; } catch (error) { logger.error('[bulkSaveConvos] Error saving conversations in bulk', error); diff --git a/packages/data-schemas/src/methods/conversationTag.ts b/packages/data-schemas/src/methods/conversationTag.ts index af1e43babb..085948bab5 100644 --- a/packages/data-schemas/src/methods/conversationTag.ts +++ b/packages/data-schemas/src/methods/conversationTag.ts @@ -1,4 +1,5 @@ import type { Model } from 'mongoose'; +import { 
tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; import logger from '~/config/winston'; interface IConversationTag { @@ -233,7 +234,7 @@ export function createConversationTagMethods(mongoose: typeof import('mongoose') } if (bulkOps.length > 0) { - await ConversationTag.bulkWrite(bulkOps); + await tenantSafeBulkWrite(ConversationTag, bulkOps); } const updatedConversation = ( @@ -273,7 +274,7 @@ export function createConversationTagMethods(mongoose: typeof import('mongoose') }, })); - const result = await ConversationTag.bulkWrite(bulkOps); + const result = await tenantSafeBulkWrite(ConversationTag, bulkOps); if (result && result.modifiedCount > 0) { logger.debug( `user: ${user} | Incremented tag counts - modified ${result.modifiedCount} tags`, diff --git a/packages/data-schemas/src/methods/file.ts b/packages/data-schemas/src/methods/file.ts index 3d7db88c3f..4c0969afb3 100644 --- a/packages/data-schemas/src/methods/file.ts +++ b/packages/data-schemas/src/methods/file.ts @@ -2,6 +2,7 @@ import logger from '../config/winston'; import { EToolResources, FileContext } from 'librechat-data-provider'; import type { FilterQuery, SortOrder, Model } from 'mongoose'; import type { IMongoFile } from '~/types/file'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; /** Factory function that takes mongoose instance and returns the file methods */ export function createFileMethods(mongoose: typeof import('mongoose')) { @@ -322,7 +323,7 @@ export function createFileMethods(mongoose: typeof import('mongoose')) { }, })); - const result = await File.bulkWrite(bulkOperations); + const result = await tenantSafeBulkWrite(File, bulkOperations); logger.info(`Updated ${result.modifiedCount} files with new S3 URLs`); } diff --git a/packages/data-schemas/src/methods/index.ts b/packages/data-schemas/src/methods/index.ts index 11f00e7827..830d88ff4c 100644 --- a/packages/data-schemas/src/methods/index.ts +++ b/packages/data-schemas/src/methods/index.ts @@ -1,9 +1,8 @@ import 
{ createSessionMethods, DEFAULT_REFRESH_TOKEN_EXPIRY, type SessionMethods } from './session'; import { createTokenMethods, type TokenMethods } from './token'; -import { createRoleMethods, type RoleMethods, type RoleDeps } from './role'; +import { createRoleMethods, RoleConflictError } from './role'; +import type { RoleMethods, RoleDeps } from './role'; import { createUserMethods, DEFAULT_SESSION_EXPIRY, type UserMethods } from './user'; - -export { DEFAULT_REFRESH_TOKEN_EXPIRY, DEFAULT_SESSION_EXPIRY }; import { createKeyMethods, type KeyMethods } from './key'; import { createFileMethods, type FileMethods } from './file'; /* Memories */ @@ -48,7 +47,10 @@ import { createSpendTokensMethods, type SpendTokensMethods } from './spendTokens import { createPromptMethods, type PromptMethods, type PromptDeps } from './prompt'; /* Tier 5 — Agent */ import { createAgentMethods, type AgentMethods, type AgentDeps } from './agent'; +/* Config */ +import { createConfigMethods, type ConfigMethods } from './config'; +export { RoleConflictError, DEFAULT_REFRESH_TOKEN_EXPIRY, DEFAULT_SESSION_EXPIRY }; export { tokenValues, cacheTokenValues, premiumTokenValues, defaultRate }; export type AllMethods = UserMethods & @@ -80,7 +82,8 @@ export type AllMethods = UserMethods & TransactionMethods & SpendTokensMethods & PromptMethods & - AgentMethods; + AgentMethods & + ConfigMethods; /** Dependencies injected from the api layer into createMethods */ export interface CreateMethodsDeps { @@ -201,6 +204,8 @@ export function createMethods( ...promptMethods, /* Tier 5 */ ...agentMethods, + /* Config */ + ...createConfigMethods(mongoose), }; } @@ -235,4 +240,5 @@ export type { SpendTokensMethods, PromptMethods, AgentMethods, + ConfigMethods, }; diff --git a/packages/data-schemas/src/methods/message.ts b/packages/data-schemas/src/methods/message.ts index ae5ca72b12..2e638b6bfb 100644 --- a/packages/data-schemas/src/methods/message.ts +++ b/packages/data-schemas/src/methods/message.ts @@ -1,6 +1,7 @@ 
import type { DeleteResult, FilterQuery, Model } from 'mongoose'; import logger from '~/config/winston'; import { createTempChatExpirationDate } from '~/utils/tempChatRetention'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; import type { AppConfig, IMessage } from '~/types'; /** Simple UUID v4 regex to replace zod validation */ @@ -165,7 +166,7 @@ export function createMessageMethods(mongoose: typeof import('mongoose')): Messa upsert: true, }, })); - const result = await Message.bulkWrite(bulkOps); + const result = await tenantSafeBulkWrite(Message, bulkOps); return result; } catch (err) { logger.error('Error saving messages in bulk:', err); diff --git a/packages/data-schemas/src/methods/prompt.ts b/packages/data-schemas/src/methods/prompt.ts index a1b6bfde37..86d830fecd 100644 --- a/packages/data-schemas/src/methods/prompt.ts +++ b/packages/data-schemas/src/methods/prompt.ts @@ -1,8 +1,9 @@ import { ResourceType, SystemCategories } from 'librechat-data-provider'; import type { Model, Types } from 'mongoose'; import type { IAclEntry, IPrompt, IPromptGroup, IPromptGroupDocument } from '~/types'; -import { escapeRegExp } from '~/utils/string'; +import { getTenantId, SYSTEM_TENANT_ID } from '~/config/tenantContext'; import { isValidObjectIdString } from '~/utils/objectId'; +import { escapeRegExp } from '~/utils/string'; import logger from '~/config/winston'; export interface PromptDeps { @@ -508,16 +509,37 @@ export function createPromptMethods(mongoose: typeof import('mongoose'), deps: P if (typeof matchFilter._id === 'string') { matchFilter._id = new ObjectId(matchFilter._id); } + const tenantId = getTenantId(); + const useTenantFilter = tenantId && tenantId !== SYSTEM_TENANT_ID; + + const lookupStage = useTenantFilter + ? 
{ + $lookup: { + from: 'prompts', + let: { prodId: '$productionId' }, + pipeline: [ + { + $match: { + $expr: { $eq: ['$_id', '$$prodId'] }, + tenantId, + }, + }, + ], + as: 'productionPrompt', + }, + } + : { + $lookup: { + from: 'prompts', + localField: 'productionId', + foreignField: '_id', + as: 'productionPrompt', + }, + }; + const result = await PromptGroup.aggregate([ { $match: matchFilter }, - { - $lookup: { - from: 'prompts', - localField: 'productionId', - foreignField: '_id', - as: 'productionPrompt', - }, - }, + lookupStage, { $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } }, ]); const group = result[0] || null; diff --git a/packages/data-schemas/src/methods/role.methods.spec.ts b/packages/data-schemas/src/methods/role.methods.spec.ts index 78d7f98ea1..be75be7b6f 100644 --- a/packages/data-schemas/src/methods/role.methods.spec.ts +++ b/packages/data-schemas/src/methods/role.methods.spec.ts @@ -1,10 +1,17 @@ import mongoose from 'mongoose'; import { MongoMemoryServer } from 'mongodb-memory-server'; import { SystemRoles, Permissions, roleDefaults, PermissionTypes } from 'librechat-data-provider'; -import type { IRole, RolePermissions } from '..'; +import type { IRole, IUser, RolePermissions } from '..'; import { createRoleMethods } from './role'; import { createModels } from '../models'; +jest.mock('~/config/winston', () => ({ + error: jest.fn(), + warn: jest.fn(), + info: jest.fn(), + debug: jest.fn(), +})); + const mockCache = { get: jest.fn(), set: jest.fn(), @@ -14,9 +21,18 @@ const mockCache = { const mockGetCache = jest.fn().mockReturnValue(mockCache); let Role: mongoose.Model; +let User: mongoose.Model; let getRoleByName: ReturnType['getRoleByName']; let updateAccessPermissions: ReturnType['updateAccessPermissions']; let initializeRoles: ReturnType['initializeRoles']; +let createRoleByName: ReturnType['createRoleByName']; +let deleteRoleByName: ReturnType['deleteRoleByName']; +let updateUsersByRole: 
ReturnType['updateUsersByRole']; +let listUsersByRole: ReturnType['listUsersByRole']; +let countUsersByRole: ReturnType['countUsersByRole']; +let updateRoleByName: ReturnType['updateRoleByName']; +let listRoles: ReturnType['listRoles']; +let countRoles: ReturnType['countRoles']; let mongoServer: MongoMemoryServer; beforeAll(async () => { @@ -25,10 +41,19 @@ beforeAll(async () => { await mongoose.connect(mongoUri); createModels(mongoose); Role = mongoose.models.Role; + User = mongoose.models.User as mongoose.Model; const methods = createRoleMethods(mongoose, { getCache: mockGetCache }); getRoleByName = methods.getRoleByName; updateAccessPermissions = methods.updateAccessPermissions; initializeRoles = methods.initializeRoles; + createRoleByName = methods.createRoleByName; + deleteRoleByName = methods.deleteRoleByName; + updateRoleByName = methods.updateRoleByName; + updateUsersByRole = methods.updateUsersByRole; + listUsersByRole = methods.listUsersByRole; + countUsersByRole = methods.countUsersByRole; + listRoles = methods.listRoles; + countRoles = methods.countRoles; }); afterAll(async () => { @@ -38,6 +63,7 @@ afterAll(async () => { beforeEach(async () => { await Role.deleteMany({}); + await User.deleteMany({}); mockGetCache.mockClear(); mockCache.get.mockClear(); mockCache.set.mockClear(); @@ -259,12 +285,12 @@ describe('updateAccessPermissions', () => { const updatedRole = await getRoleByName(SystemRoles.USER); // SHARED_GLOBAL=true → SHARE=true (inherited) - expect(updatedRole.permissions[PermissionTypes.PROMPTS].SHARE).toBe(true); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]!.SHARE).toBe(true); // SHARED_GLOBAL=false → SHARE=false (inherited) - expect(updatedRole.permissions[PermissionTypes.AGENTS].SHARE).toBe(false); + expect(updatedRole.permissions[PermissionTypes.AGENTS]!.SHARE).toBe(false); // SHARED_GLOBAL cleaned up - expect(updatedRole.permissions[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBeUndefined(); - 
expect(updatedRole.permissions[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeUndefined(); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]).not.toHaveProperty('SHARED_GLOBAL'); + expect(updatedRole.permissions[PermissionTypes.AGENTS]).not.toHaveProperty('SHARED_GLOBAL'); }); it('should respect explicit SHARE in update payload and not override it with SHARED_GLOBAL', async () => { @@ -283,8 +309,8 @@ describe('updateAccessPermissions', () => { const updatedRole = await getRoleByName(SystemRoles.USER); - expect(updatedRole.permissions[PermissionTypes.PROMPTS].SHARE).toBe(false); - expect(updatedRole.permissions[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBeUndefined(); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]!.SHARE).toBe(false); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]).not.toHaveProperty('SHARED_GLOBAL'); }); it('should migrate SHARED_GLOBAL to SHARE even when the permType is not in the update payload', async () => { @@ -310,13 +336,13 @@ describe('updateAccessPermissions', () => { const updatedRole = await getRoleByName(SystemRoles.USER); // SHARE should have been inherited from SHARED_GLOBAL, not silently dropped - expect(updatedRole.permissions[PermissionTypes.PROMPTS].SHARE).toBe(true); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]!.SHARE).toBe(true); // SHARED_GLOBAL should be removed - expect(updatedRole.permissions[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBeUndefined(); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]).not.toHaveProperty('SHARED_GLOBAL'); // Original USE should be untouched - expect(updatedRole.permissions[PermissionTypes.PROMPTS].USE).toBe(true); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]!.USE).toBe(true); // The actual update should have applied - expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO].USE).toBe(true); + expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO]!.USE).toBe(true); }); it('should remove orphaned SHARED_GLOBAL when SHARE already 
exists and permType is not in update', async () => { @@ -340,9 +366,9 @@ describe('updateAccessPermissions', () => { const updatedRole = await getRoleByName(SystemRoles.USER); - expect(updatedRole.permissions[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBeUndefined(); - expect(updatedRole.permissions[PermissionTypes.PROMPTS].SHARE).toBe(true); - expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO].USE).toBe(true); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]).not.toHaveProperty('SHARED_GLOBAL'); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]!.SHARE).toBe(true); + expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO]!.USE).toBe(true); }); it('should not update MULTI_CONVO permissions when no changes are needed', async () => { @@ -515,3 +541,362 @@ describe('initializeRoles', () => { expect(userRole.permissions[PermissionTypes.MULTI_CONVO]?.USE).toBeDefined(); }); }); + +describe('createRoleByName', () => { + it('creates a custom role and caches it', async () => { + const role = await createRoleByName({ name: 'editor', description: 'Can edit' }); + + expect(role.name).toBe('editor'); + expect(role.description).toBe('Can edit'); + expect(mockCache.set).toHaveBeenCalledWith( + 'editor', + expect.objectContaining({ name: 'editor' }), + ); + + const persisted = await Role.findOne({ name: 'editor' }).lean(); + expect(persisted).toBeTruthy(); + }); + + it('trims whitespace from role name', async () => { + const role = await createRoleByName({ name: ' editor ' }); + + expect(role.name).toBe('editor'); + }); + + it('throws when name is empty', async () => { + await expect(createRoleByName({ name: '' })).rejects.toThrow('Role name is required'); + }); + + it('throws when name is whitespace-only', async () => { + await expect(createRoleByName({ name: ' ' })).rejects.toThrow('Role name is required'); + }); + + it('throws when name is undefined', async () => { + await expect(createRoleByName({})).rejects.toThrow('Role name is required'); + }); + 
+ it('throws for reserved system role names', async () => { + await expect(createRoleByName({ name: SystemRoles.ADMIN })).rejects.toThrow( + /reserved system name/, + ); + await expect(createRoleByName({ name: SystemRoles.USER })).rejects.toThrow( + /reserved system name/, + ); + }); + + it('throws when role already exists', async () => { + await createRoleByName({ name: 'editor' }); + + await expect(createRoleByName({ name: 'editor' })).rejects.toThrow(/already exists/); + }); +}); + +describe('deleteRoleByName', () => { + it('deletes a custom role and reassigns users to USER', async () => { + await createRoleByName({ name: 'editor' }); + await User.create([ + { name: 'Alice', email: 'alice@test.com', role: 'editor', username: 'alice' }, + { name: 'Bob', email: 'bob@test.com', role: 'editor', username: 'bob' }, + { name: 'Carol', email: 'carol@test.com', role: SystemRoles.USER, username: 'carol' }, + ]); + + const deleted = await deleteRoleByName('editor'); + + expect(deleted).toBeTruthy(); + expect(deleted!.name).toBe('editor'); + + const alice = await User.findOne({ email: 'alice@test.com' }).lean(); + const bob = await User.findOne({ email: 'bob@test.com' }).lean(); + const carol = await User.findOne({ email: 'carol@test.com' }).lean(); + expect(alice!.role).toBe(SystemRoles.USER); + expect(bob!.role).toBe(SystemRoles.USER); + expect(carol!.role).toBe(SystemRoles.USER); + }); + + it('returns null when role does not exist', async () => { + const result = await deleteRoleByName('nonexistent'); + expect(result).toBeNull(); + }); + + it('throws for system roles', async () => { + await expect(deleteRoleByName(SystemRoles.ADMIN)).rejects.toThrow(/Cannot delete system role/); + await expect(deleteRoleByName(SystemRoles.USER)).rejects.toThrow(/Cannot delete system role/); + }); + + it('sets cache entry to null after deletion', async () => { + await createRoleByName({ name: 'editor' }); + mockCache.set.mockClear(); + + await deleteRoleByName('editor'); + + 
expect(mockCache.set).toHaveBeenCalledWith('editor', null); + }); + + it('returns null and invalidates cache when role does not exist', async () => { + mockCache.set.mockClear(); + + const result = await deleteRoleByName('nonexistent'); + + expect(result).toBeNull(); + expect(mockCache.set).toHaveBeenCalledWith('nonexistent', null); + }); +}); + +describe('updateRoleByName - cache on rename', () => { + it('invalidates old key and populates new key on rename', async () => { + await createRoleByName({ name: 'editor', description: 'Can edit' }); + mockCache.set.mockClear(); + + const updated = await updateRoleByName('editor', { name: 'senior-editor' }); + + expect(updated.name).toBe('senior-editor'); + expect(mockCache.set).toHaveBeenCalledWith('editor', null); + expect(mockCache.set).toHaveBeenCalledWith( + 'senior-editor', + expect.objectContaining({ name: 'senior-editor' }), + ); + }); + + it('writes same key when name unchanged', async () => { + await createRoleByName({ name: 'editor' }); + mockCache.set.mockClear(); + + await updateRoleByName('editor', { description: 'Updated desc' }); + + expect(mockCache.set).toHaveBeenCalledWith( + 'editor', + expect.objectContaining({ name: 'editor', description: 'Updated desc' }), + ); + expect(mockCache.set).toHaveBeenCalledTimes(1); + }); +}); + +describe('listUsersByRole', () => { + it('returns users matching the role', async () => { + await User.create([ + { name: 'Alice', email: 'alice@test.com', role: 'editor', username: 'alice' }, + { name: 'Bob', email: 'bob@test.com', role: 'editor', username: 'bob' }, + { name: 'Carol', email: 'carol@test.com', role: SystemRoles.USER, username: 'carol' }, + ]); + + const users = await listUsersByRole('editor'); + + expect(users).toHaveLength(2); + const names = users.map((u) => u.name).sort(); + expect(names).toEqual(['Alice', 'Bob']); + }); + + it('returns empty array when no users have the role', async () => { + const users = await listUsersByRole('nonexistent'); + 
expect(users).toEqual([]); + }); + + it('respects limit and offset for pagination', async () => { + await User.create([ + { name: 'Alice', email: 'a@test.com', role: 'editor', username: 'a' }, + { name: 'Bob', email: 'b@test.com', role: 'editor', username: 'b' }, + { name: 'Carol', email: 'c@test.com', role: 'editor', username: 'c' }, + { name: 'Dave', email: 'd@test.com', role: 'editor', username: 'd' }, + { name: 'Eve', email: 'e@test.com', role: 'editor', username: 'e' }, + ]); + + const page1 = await listUsersByRole('editor', { limit: 2, offset: 0 }); + const page2 = await listUsersByRole('editor', { limit: 2, offset: 2 }); + const page3 = await listUsersByRole('editor', { limit: 2, offset: 4 }); + + expect(page1).toHaveLength(2); + expect(page2).toHaveLength(2); + expect(page3).toHaveLength(1); + + const allIds = [...page1, ...page2, ...page3].map((u) => u._id!.toString()); + expect(new Set(allIds).size).toBe(5); + }); + + it('selects only expected fields', async () => { + await User.create({ + name: 'Alice', + email: 'alice@test.com', + role: 'editor', + username: 'alice', + password: 'secret123', + }); + + const users = await listUsersByRole('editor'); + + expect(users).toHaveLength(1); + expect(users[0].name).toBe('Alice'); + expect(users[0].email).toBe('alice@test.com'); + expect(users[0]._id).toBeDefined(); + expect('password' in users[0]).toBe(false); + expect('username' in users[0]).toBe(false); + }); +}); + +describe('updateUsersByRole', () => { + it('migrates all users from one role to another', async () => { + await User.create([ + { name: 'Alice', email: 'alice@test.com', role: 'editor', username: 'alice' }, + { name: 'Bob', email: 'bob@test.com', role: 'editor', username: 'bob' }, + { name: 'Carol', email: 'carol@test.com', role: SystemRoles.USER, username: 'carol' }, + ]); + + await updateUsersByRole('editor', 'senior-editor'); + + const alice = await User.findOne({ email: 'alice@test.com' }).lean(); + const bob = await User.findOne({ email: 
'bob@test.com' }).lean(); + const carol = await User.findOne({ email: 'carol@test.com' }).lean(); + expect(alice!.role).toBe('senior-editor'); + expect(bob!.role).toBe('senior-editor'); + expect(carol!.role).toBe(SystemRoles.USER); + }); + + it('is a no-op when no users have the source role', async () => { + await User.create({ + name: 'Alice', + email: 'alice@test.com', + role: SystemRoles.USER, + username: 'alice', + }); + + await updateUsersByRole('nonexistent', 'new-role'); + + const alice = await User.findOne({ email: 'alice@test.com' }).lean(); + expect(alice!.role).toBe(SystemRoles.USER); + }); +}); + +describe('countUsersByRole', () => { + it('returns the count of users with the given role', async () => { + await User.create([ + { name: 'Alice', email: 'alice@test.com', role: 'editor', username: 'alice' }, + { name: 'Bob', email: 'bob@test.com', role: 'editor', username: 'bob' }, + { name: 'Carol', email: 'carol@test.com', role: SystemRoles.USER, username: 'carol' }, + ]); + + expect(await countUsersByRole('editor')).toBe(2); + expect(await countUsersByRole(SystemRoles.USER)).toBe(1); + }); + + it('returns 0 when no users have the role', async () => { + expect(await countUsersByRole('nonexistent')).toBe(0); + }); +}); + +describe('listRoles', () => { + beforeEach(async () => { + await Role.deleteMany({}); + }); + + it('returns roles sorted alphabetically by name', async () => { + await Role.create([ + { name: 'zebra', permissions: {} }, + { name: 'alpha', permissions: {} }, + { name: 'middle', permissions: {} }, + ]); + + const roles = await listRoles(); + + expect(roles.map((r) => r.name)).toEqual(['alpha', 'middle', 'zebra']); + }); + + it('respects limit and offset for pagination', async () => { + await Role.create([ + { name: 'a-role', permissions: {} }, + { name: 'b-role', permissions: {} }, + { name: 'c-role', permissions: {} }, + { name: 'd-role', permissions: {} }, + { name: 'e-role', permissions: {} }, + ]); + + const page1 = await listRoles({ 
limit: 2, offset: 0 }); + const page2 = await listRoles({ limit: 2, offset: 2 }); + const page3 = await listRoles({ limit: 2, offset: 4 }); + + expect(page1).toHaveLength(2); + expect(page1.map((r) => r.name)).toEqual(['a-role', 'b-role']); + expect(page2).toHaveLength(2); + expect(page2.map((r) => r.name)).toEqual(['c-role', 'd-role']); + expect(page3).toHaveLength(1); + expect(page3.map((r) => r.name)).toEqual(['e-role']); + }); + + it('defaults to limit 50 and offset 0', async () => { + await Role.create({ name: 'only-role', permissions: {} }); + + const roles = await listRoles(); + + expect(roles).toHaveLength(1); + expect(roles[0].name).toBe('only-role'); + }); + + it('returns only name and description fields', async () => { + await Role.create({ + name: 'editor', + description: 'Can edit', + permissions: { PROMPTS: { USE: true } }, + }); + + const roles = await listRoles(); + + expect(roles).toHaveLength(1); + expect(roles[0].name).toBe('editor'); + expect(roles[0].description).toBe('Can edit'); + expect(roles[0]._id).toBeDefined(); + expect('permissions' in roles[0]).toBe(false); + }); + + it('returns empty array when no roles exist', async () => { + const roles = await listRoles(); + expect(roles).toEqual([]); + }); + + it('returns undefined description for pre-existing roles without the field', async () => { + await Role.collection.insertOne({ name: 'legacy', permissions: {} }); + + const roles = await listRoles(); + + expect(roles).toHaveLength(1); + expect(roles[0].name).toBe('legacy'); + expect(roles[0].description).toBeUndefined(); + }); +}); + +describe('countRoles', () => { + beforeEach(async () => { + await Role.deleteMany({}); + }); + + it('returns the total number of roles', async () => { + await Role.create([ + { name: 'a', permissions: {} }, + { name: 'b', permissions: {} }, + { name: 'c', permissions: {} }, + ]); + + expect(await countRoles()).toBe(3); + }); + + it('returns 0 when no roles exist', async () => { + expect(await 
countRoles()).toBe(0); + }); +}); + +describe('createRoleByName - duplicate key race', () => { + beforeEach(async () => { + await Role.deleteMany({}); + }); + + it('throws RoleConflictError on concurrent insert (11000)', async () => { + await createRoleByName({ name: 'editor' }); + + const insertSpy = jest.spyOn(Role.prototype, 'save').mockImplementationOnce(() => { + const err = new Error('E11000 duplicate key error') as Error & { code: number }; + err.code = 11000; + throw err; + }); + + await expect(createRoleByName({ name: 'editor2' })).rejects.toThrow(/already exists/); + + insertSpy.mockRestore(); + }); +}); diff --git a/packages/data-schemas/src/methods/role.ts b/packages/data-schemas/src/methods/role.ts index 7b51e45330..e84b91420a 100644 --- a/packages/data-schemas/src/methods/role.ts +++ b/packages/data-schemas/src/methods/role.ts @@ -5,9 +5,24 @@ import { permissionsSchema, removeNullishValues, } from 'librechat-data-provider'; -import type { IRole } from '~/types'; +import type { Model } from 'mongoose'; +import type { IRole, IUser } from '~/types'; import logger from '~/config/winston'; +const systemRoleValues = new Set(Object.values(SystemRoles)); + +/** Case-insensitive check — the legacy roles route uppercases params. */ +function isSystemRoleName(name: string): boolean { + return systemRoleValues.has(name.toUpperCase()); +} + +export class RoleConflictError extends Error { + constructor(message: string) { + super(message); + this.name = 'RoleConflictError'; + } +} + export interface RoleDeps { /** Returns a cache store for the given key. Injected from getLogStores. 
*/ + getCache?: (key: string) => { @@ -30,8 +45,11 @@ export function createRoleMethods(mongoose: typeof import('mongoose'), deps: Rol const defaultPerms = roleDefaults[roleName].permissions; if (!role) { - role = new Role(roleDefaults[roleName]); + role = new Role({ ...roleDefaults[roleName], description: '' }); } else { + if (role.description == null) { + role.description = ''; + } const permissions = role.toObject()?.permissions ?? {}; role.permissions = role.permissions || {}; for (const permType of Object.keys(defaultPerms)) { @@ -45,11 +63,26 @@ export function createRoleMethods(mongoose: typeof import('mongoose'), deps: Rol } /** - * List all roles in the system. + * List all roles in the system. Returns only name and description (projected). */ - async function listRoles() { + async function listRoles(options?: { + limit?: number; + offset?: number; + }): Promise<Pick<IRole, 'name' | 'description'>[]> { + const Role = mongoose.models.Role as Model<IRole>; + const limit = options?.limit ?? 50; + const offset = options?.offset ?? 
0; + return await Role.find({}) + .select('name description') + .sort({ name: 1 }) + .skip(offset) + .limit(limit) + .lean(); + } + + async function countRoles(): Promise<number> { const Role = mongoose.models.Role; - return await Role.find({}).select('name permissions').lean(); + return await Role.countDocuments({}); } /** @@ -73,7 +106,7 @@ export function createRoleMethods(mongoose: typeof import('mongoose'), deps: Rol } const role = await query.lean().exec(); - if (!role && SystemRoles[roleName as keyof typeof SystemRoles]) { + if (!role && systemRoleValues.has(roleName)) { const newRole = await new Role(roleDefaults[roleName as keyof typeof roleDefaults]).save(); if (cache) { await cache.set(roleName, newRole); @@ -96,20 +129,24 @@ export function createRoleMethods(mongoose: typeof import('mongoose'), deps: Rol const cache = deps.getCache?.(CacheKeys.ROLES); try { const Role = mongoose.models.Role; - const role = await Role.findOneAndUpdate( - { name: roleName }, - { $set: updates }, - { new: true, lean: true }, - ) + const role = await Role.findOneAndUpdate({ name: roleName }, { $set: updates }, { new: true }) .select('-__v') .lean() .exec(); if (cache) { - await cache.set(roleName, role); + if (updates.name && updates.name !== roleName) { + await Promise.all([cache.set(roleName, null), cache.set(updates.name, role)]); + } else { + await cache.set(roleName, role); + } } return role as unknown as IRole; } catch (error) { - throw new Error(`Failed to update role: ${(error as Error).message}`); + if (error && typeof error === 'object' && 'code' in error && error.code === 11000) { + const targetName = updates.name ?? roleName; + throw new RoleConflictError(`Role "${targetName}" already exists`); + } + throw new Error(`Failed to update role: ${(error as Error).message}`, { cause: error }); } } @@ -342,13 +379,137 @@ export function createRoleMethods(mongoose: typeof import('mongoose'), deps: Rol } } + /** Rejects names that match system roles. 
 */ + async function createRoleByName(roleData: Partial<IRole>): Promise<IRole> { + const { name } = roleData; + if (!name || typeof name !== 'string' || !name.trim()) { + throw new Error('Role name is required'); + } + const trimmed = name.trim(); + if (isSystemRoleName(trimmed)) { + throw new RoleConflictError(`Cannot create role with reserved system name: ${name}`); + } + const Role = mongoose.models.Role; + const existing = await Role.findOne({ name: trimmed }).lean(); + if (existing) { + throw new RoleConflictError(`Role "${trimmed}" already exists`); + } + let role; + try { + role = await new Role({ ...roleData, name: trimmed }).save(); + } catch (err) { + /** + * The compound unique index `{ name: 1, tenantId: 1 }` on the role schema + * (roleSchema.index in schema/role.ts) triggers error 11000 when a concurrent + * request races past the findOne check above. This catch converts it into + * the same user-facing message as the application-level duplicate check. + */ + if (err && typeof err === 'object' && 'code' in err && err.code === 11000) { + throw new RoleConflictError(`Role "${trimmed}" already exists`); + } + throw err; + } + try { + const cache = deps.getCache?.(CacheKeys.ROLES); + if (cache) { + await cache.set(role.name, role.toObject()); + } + } catch (cacheError) { + logger.error(`[createRoleByName] cache set failed for "${role.name}":`, cacheError); + } + return role.toObject() as IRole; + } + + /** + * Guards against deleting system roles. Reassigns affected users back to USER. + * + * No existence pre-check is performed: for a nonexistent role the `updateMany` + * is a harmless no-op and `findOneAndDelete` returns null. This makes the + * function idempotent — a retry after a partial failure will still clean up + * orphaned user references and cache entries. + * + * Without a MongoDB transaction the two writes are non-atomic — if the delete + * fails after the reassignment, users will already have been moved to USER + * while the role document still exists. 
Recovery requires the caller to retry + * the delete call, which will succeed since the `updateMany` is a no-op on + * the second pass. + */ + async function deleteRoleByName(roleName: string): Promise<IRole | null> { + if (isSystemRoleName(roleName)) { + throw new Error(`Cannot delete system role: ${roleName}`); + } + const Role = mongoose.models.Role; + const User = mongoose.models.User as Model<IUser>; + await User.updateMany({ role: roleName }, { $set: { role: SystemRoles.USER } }); + const deleted = await Role.findOneAndDelete({ name: roleName }).lean(); + try { + const cache = deps.getCache?.(CacheKeys.ROLES); + if (cache) { + // Setting null evicts the stale document. getRoleByName treats falsy cached + // values as a miss and falls through to the DB, so this does not provide + // negative caching — it only prevents serving the pre-deletion document. + await cache.set(roleName, null); + } + } catch (cacheError) { + logger.error(`[deleteRoleByName] cache invalidation failed for "${roleName}":`, cacheError); + } + return deleted as IRole | null; + } + + async function updateUsersByRole(oldRole: string, newRole: string): Promise<void> { + const User = mongoose.models.User as Model<IUser>; + await User.updateMany({ role: oldRole }, { $set: { role: newRole } }); + } + + async function findUserIdsByRole(roleName: string): Promise<string[]> { + const User = mongoose.models.User as Model<IUser>; + const users = await User.find({ role: roleName }).select('_id').lean(); + return users.map((u) => u._id.toString()); + } + + async function updateUsersRoleByIds(userIds: string[], newRole: string): Promise<void> { + if (userIds.length === 0) { + return; + } + const User = mongoose.models.User as Model<IUser>; + await User.updateMany({ _id: { $in: userIds } }, { $set: { role: newRole } }); + } + + async function listUsersByRole( + roleName: string, + options?: { limit?: number; offset?: number }, + ): Promise<IUser[]> { + const User = mongoose.models.User as Model<IUser>; + const limit = options?.limit ?? 50; + const offset = options?.offset ?? 
0; + return await User.find({ role: roleName }) + .select('_id name email avatar') + .sort({ _id: 1 }) + .skip(offset) + .limit(limit) + .lean(); + } + + async function countUsersByRole(roleName: string): Promise<number> { + const User = mongoose.models.User as Model<IUser>; + return await User.countDocuments({ role: roleName }); + } + return { listRoles, + countRoles, initializeRoles, getRoleByName, updateRoleByName, updateAccessPermissions, migrateRoleSchema, + createRoleByName, + deleteRoleByName, + updateUsersByRole, + findUserIdsByRole, + updateUsersRoleByIds, + listUsersByRole, + countUsersByRole, }; } diff --git a/packages/data-schemas/src/methods/spendTokens.spec.ts b/packages/data-schemas/src/methods/spendTokens.spec.ts index 5730bc7bdd..d505663d57 100644 --- a/packages/data-schemas/src/methods/spendTokens.spec.ts +++ b/packages/data-schemas/src/methods/spendTokens.spec.ts @@ -864,8 +864,8 @@ describe('spendTokens', () => { const expectedCompletionCost = tokenUsage.completionTokens * premiumCompletionRate; expect(result).not.toBeNull(); - expect(result!.prompt.prompt).toBeCloseTo(-expectedPromptCost, 0); - expect(result!.completion.completion).toBeCloseTo(-expectedCompletionCost, 0); + expect(result!.prompt!.prompt).toBeCloseTo(-expectedPromptCost, 0); + expect(result!.completion!.completion).toBeCloseTo(-expectedCompletionCost, 0); }); it('should charge standard rates for structured tokens when below threshold', async () => { @@ -907,8 +907,8 @@ describe('spendTokens', () => { const expectedCompletionCost = tokenUsage.completionTokens * standardCompletionRate; expect(result).not.toBeNull(); - expect(result!.prompt.prompt).toBeCloseTo(-expectedPromptCost, 0); - expect(result!.completion.completion).toBeCloseTo(-expectedCompletionCost, 0); + expect(result!.prompt!.prompt).toBeCloseTo(-expectedPromptCost, 0); + expect(result!.completion!.completion).toBeCloseTo(-expectedCompletionCost, 0); }); it('should charge standard rates for gemini-3.1-pro-preview when prompt tokens 
are below threshold', async () => { @@ -937,7 +937,7 @@ describe('spendTokens', () => { completionTokens * tokenValues['gemini-3.1'].completion; const balance = await Balance.findOne({ user: userId }); - expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + expect(balance!.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); }); it('should charge premium rates for gemini-3.1-pro-preview when prompt tokens exceed threshold', async () => { @@ -966,7 +966,7 @@ describe('spendTokens', () => { completionTokens * premiumTokenValues['gemini-3.1'].completion; const balance = await Balance.findOne({ user: userId }); - expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + expect(balance!.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); }); it('should charge premium rates for gemini-3.1-pro-preview-customtools when prompt tokens exceed threshold', async () => { @@ -995,7 +995,7 @@ describe('spendTokens', () => { completionTokens * premiumTokenValues['gemini-3.1'].completion; const balance = await Balance.findOne({ user: userId }); - expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + expect(balance!.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); }); it('should charge premium rates for structured gemini-3.1 tokens when total input exceeds threshold', async () => { @@ -1032,13 +1032,13 @@ describe('spendTokens', () => { const expectedPromptCost = tokenUsage.promptTokens.input * premiumPromptRate + - tokenUsage.promptTokens.write * writeRate + - tokenUsage.promptTokens.read * readRate; + tokenUsage.promptTokens.write * writeRate! 
+ + tokenUsage.promptTokens.read * readRate!; const expectedCompletionCost = tokenUsage.completionTokens * premiumCompletionRate; expect(result).not.toBeNull(); - expect(result!.prompt.prompt).toBeCloseTo(-expectedPromptCost, 0); - expect(result!.completion.completion).toBeCloseTo(-expectedCompletionCost, 0); + expect(result!.prompt!.prompt).toBeCloseTo(-expectedPromptCost, 0); + expect(result!.completion!.completion).toBeCloseTo(-expectedCompletionCost, 0); }); it('should not apply premium pricing to non-premium models regardless of prompt size', async () => { diff --git a/packages/data-schemas/src/methods/systemGrant.spec.ts b/packages/data-schemas/src/methods/systemGrant.spec.ts index 188d31b544..49b4f7269e 100644 --- a/packages/data-schemas/src/methods/systemGrant.spec.ts +++ b/packages/data-schemas/src/methods/systemGrant.spec.ts @@ -2,8 +2,8 @@ import mongoose, { Types } from 'mongoose'; import { PrincipalType, SystemRoles } from 'librechat-data-provider'; import { MongoMemoryServer } from 'mongodb-memory-server'; import type * as t from '~/types'; -import type { SystemCapability } from '~/systemCapabilities'; -import { SystemCapabilities, CapabilityImplications } from '~/systemCapabilities'; +import type { SystemCapability } from '~/types/admin'; +import { SystemCapabilities, CapabilityImplications } from '~/admin/capabilities'; import { createSystemGrantMethods } from './systemGrant'; import systemGrantSchema from '~/schema/systemGrant'; import logger from '~/config/winston'; @@ -702,6 +702,68 @@ describe('systemGrant methods', () => { }); }); + describe('deleteGrantsForPrincipal', () => { + it('deletes all grants for a principal', async () => { + const groupId = new Types.ObjectId(); + + await methods.grantCapability({ + principalType: PrincipalType.GROUP, + principalId: groupId, + capability: SystemCapabilities.READ_USERS, + }); + await methods.grantCapability({ + principalType: PrincipalType.GROUP, + principalId: groupId, + capability: 
SystemCapabilities.READ_CONFIGS, + }); + + await methods.deleteGrantsForPrincipal(PrincipalType.GROUP, groupId); + + const remaining = await SystemGrant.countDocuments({ + principalType: PrincipalType.GROUP, + principalId: groupId, + }); + expect(remaining).toBe(0); + }); + + it('is a no-op for principal with no grants', async () => { + const groupId = new Types.ObjectId(); + + await expect( + methods.deleteGrantsForPrincipal(PrincipalType.GROUP, groupId), + ).resolves.not.toThrow(); + }); + + it('does not affect other principals', async () => { + const groupA = new Types.ObjectId(); + const groupB = new Types.ObjectId(); + + await methods.grantCapability({ + principalType: PrincipalType.GROUP, + principalId: groupA, + capability: SystemCapabilities.READ_USERS, + }); + await methods.grantCapability({ + principalType: PrincipalType.GROUP, + principalId: groupB, + capability: SystemCapabilities.READ_USERS, + }); + + await methods.deleteGrantsForPrincipal(PrincipalType.GROUP, groupA); + + const remainingA = await SystemGrant.countDocuments({ + principalType: PrincipalType.GROUP, + principalId: groupA, + }); + const remainingB = await SystemGrant.countDocuments({ + principalType: PrincipalType.GROUP, + principalId: groupB, + }); + expect(remainingA).toBe(0); + expect(remainingB).toBe(1); + }); + }); + describe('schema validation', () => { it('rejects null tenantId at the schema level', async () => { await expect( diff --git a/packages/data-schemas/src/methods/systemGrant.ts b/packages/data-schemas/src/methods/systemGrant.ts index f0f389d762..4954f50c16 100644 --- a/packages/data-schemas/src/methods/systemGrant.ts +++ b/packages/data-schemas/src/methods/systemGrant.ts @@ -1,8 +1,8 @@ import { PrincipalType, SystemRoles } from 'librechat-data-provider'; import type { Types, Model, ClientSession } from 'mongoose'; -import type { SystemCapability } from '~/systemCapabilities'; +import type { SystemCapability } from '~/types/admin'; import type { ISystemGrant } from 
'~/types'; -import { SystemCapabilities, CapabilityImplications } from '~/systemCapabilities'; +import { SystemCapabilities, CapabilityImplications } from '~/admin/capabilities'; import { normalizePrincipalId } from '~/utils/principal'; import logger from '~/config/winston'; @@ -246,12 +246,28 @@ export function createSystemGrantMethods(mongoose: typeof import('mongoose')) { } } + /** + * Delete all system grants for a principal. + * Used for cascade cleanup when a principal (group, role) is deleted. + */ + async function deleteGrantsForPrincipal( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + session?: ClientSession, + ): Promise<void> { + const SystemGrant = mongoose.models.SystemGrant as Model<ISystemGrant>; + const normalizedPrincipalId = normalizePrincipalId(principalId, principalType); + const options = session ? { session } : {}; + await SystemGrant.deleteMany({ principalType, principalId: normalizedPrincipalId }, options); + } + return { grantCapability, seedSystemGrants, revokeCapability, hasCapabilityForPrincipals, getCapabilitiesForPrincipal, + deleteGrantsForPrincipal, }; } diff --git a/packages/data-schemas/src/methods/tx.ts b/packages/data-schemas/src/methods/tx.ts index a1be4190ba..a048874457 100644 --- a/packages/data-schemas/src/methods/tx.ts +++ b/packages/data-schemas/src/methods/tx.ts @@ -387,7 +387,7 @@ export function createTxMethods(_mongoose: typeof import('mongoose'), txDeps: Tx function getPremiumRate( valueKey: string, tokenType: string, - inputTokenCount?: number, + inputTokenCount?: number | null, ): number | null { if (inputTokenCount == null) { return null; diff --git a/packages/data-schemas/src/methods/user.test.ts b/packages/data-schemas/src/methods/user.test.ts index 522e4fe158..5e557805e4 100644 --- a/packages/data-schemas/src/methods/user.test.ts +++ b/packages/data-schemas/src/methods/user.test.ts @@ -18,7 +18,7 @@ describe('User Methods', () => { describe('generateToken', () => { const mockUser = { - _id: 'user123', + 
_id: new mongoose.Types.ObjectId('aaaaaaaaaaaaaaaaaaaaaaaa'), username: 'testuser', provider: 'local', email: 'test@example.com', diff --git a/packages/data-schemas/src/methods/user.ts b/packages/data-schemas/src/methods/user.ts index 74cb4a1e1c..0b630e49b3 100644 --- a/packages/data-schemas/src/methods/user.ts +++ b/packages/data-schemas/src/methods/user.ts @@ -35,13 +35,26 @@ export function createUserMethods(mongoose: typeof import('mongoose')) { searchCriteria: FilterQuery<IUser>, fieldsToSelect?: string | string[] | null, ): Promise<IUser | null> { - const User = mongoose.models.User; + const User = mongoose.models.User as mongoose.Model<IUser>; const normalizedCriteria = normalizeEmailInCriteria(searchCriteria); const query = User.findOne(normalizedCriteria); if (fieldsToSelect) { query.select(fieldsToSelect); } - return (await query.lean()) as IUser | null; + return await query.lean(); + } + + async function findUsers( + searchCriteria: FilterQuery<IUser>, + fieldsToSelect?: string | string[] | null, + ): Promise<IUser[]> { + const User = mongoose.models.User as mongoose.Model<IUser>; + const normalizedCriteria = normalizeEmailInCriteria(searchCriteria); + const query = User.find(normalizedCriteria); + if (fieldsToSelect) { + query.select(fieldsToSelect); + } + return await query.lean(); } /** @@ -288,8 +301,6 @@ export function createUserMethods(mongoose: typeof import('mongoose')) { .sort((a, b) => b._searchScore - a._searchScore) .slice(0, limit) .map((user) => { - // Remove the search score from final results - // eslint-disable-next-line @typescript-eslint/no-unused-vars const { _searchScore, ...userWithoutScore } = user; return userWithoutScore; }); @@ -323,6 +334,7 @@ export function createUserMethods(mongoose: typeof import('mongoose')) { return { findUser, + findUsers, countUsers, createUser, updateUser, diff --git a/packages/data-schemas/src/methods/userGroup.methods.spec.ts b/packages/data-schemas/src/methods/userGroup.methods.spec.ts index 8a31544018..51848de091 100644 --- 
a/packages/data-schemas/src/methods/userGroup.methods.spec.ts +++ b/packages/data-schemas/src/methods/userGroup.methods.spec.ts @@ -600,6 +600,155 @@ describe('UserGroup Methods - Detailed Tests', () => { }); }); + describe('listGroups', () => { + beforeEach(async () => { + await Group.create([ + { name: 'Beta', source: 'local', memberIds: [], email: 'beta@test.com' }, + { name: 'Alpha', source: 'local', memberIds: [], description: 'first group' }, + { name: 'Gamma', source: 'entra', idOnTheSource: 'ext-g', memberIds: [] }, + ]); + }); + + test('returns groups sorted by name', async () => { + const groups = await methods.listGroups(); + + expect(groups).toHaveLength(3); + expect(groups[0].name).toBe('Alpha'); + expect(groups[1].name).toBe('Beta'); + expect(groups[2].name).toBe('Gamma'); + }); + + test('filters by source', async () => { + const groups = await methods.listGroups({ source: 'entra' }); + + expect(groups).toHaveLength(1); + expect(groups[0].name).toBe('Gamma'); + }); + + test('filters by search (name)', async () => { + const groups = await methods.listGroups({ search: 'alpha' }); + + expect(groups).toHaveLength(1); + expect(groups[0].name).toBe('Alpha'); + }); + + test('filters by search (email)', async () => { + const groups = await methods.listGroups({ search: 'beta@test' }); + + expect(groups).toHaveLength(1); + expect(groups[0].name).toBe('Beta'); + }); + + test('filters by search (description)', async () => { + const groups = await methods.listGroups({ search: 'first group' }); + + expect(groups).toHaveLength(1); + expect(groups[0].name).toBe('Alpha'); + }); + + test('respects limit and offset', async () => { + const groups = await methods.listGroups({ limit: 1, offset: 1 }); + + expect(groups).toHaveLength(1); + expect(groups[0].name).toBe('Beta'); + }); + + test('returns empty for no matches', async () => { + const groups = await methods.listGroups({ search: 'nonexistent' }); + + expect(groups).toHaveLength(0); + }); + }); + + 
describe('countGroups', () => { + beforeEach(async () => { + await Group.create([ + { name: 'A', source: 'local', memberIds: [] }, + { name: 'B', source: 'local', memberIds: [] }, + { name: 'C', source: 'entra', idOnTheSource: 'ext-c', memberIds: [] }, + ]); + }); + + test('returns total count', async () => { + const count = await methods.countGroups(); + + expect(count).toBe(3); + }); + + test('respects source filter', async () => { + const count = await methods.countGroups({ source: 'local' }); + + expect(count).toBe(2); + }); + + test('respects search filter', async () => { + const count = await methods.countGroups({ search: 'A' }); + + expect(count).toBe(1); + }); + }); + + describe('deleteGroup', () => { + test('returns deleted group', async () => { + const group = await Group.create({ name: 'ToDelete', source: 'local', memberIds: [] }); + + const deleted = await methods.deleteGroup(group._id as mongoose.Types.ObjectId); + + expect(deleted).toBeDefined(); + expect(deleted?.name).toBe('ToDelete'); + const remaining = await Group.findById(group._id); + expect(remaining).toBeNull(); + }); + + test('returns null for non-existent ID', async () => { + const fakeId = new mongoose.Types.ObjectId(); + const result = await methods.deleteGroup(fakeId); + + expect(result).toBeNull(); + }); + }); + + describe('removeMemberById', () => { + test('removes member from memberIds array', async () => { + const group = await Group.create({ + name: 'Test', + source: 'local', + memberIds: ['m1', 'm2', 'm3'], + }); + + const updated = await methods.removeMemberById( + group._id as mongoose.Types.ObjectId, + 'm2', + ); + + expect(updated).toBeDefined(); + expect(updated?.memberIds).toEqual(['m1', 'm3']); + }); + + test('is idempotent when memberId not present', async () => { + const group = await Group.create({ + name: 'Test', + source: 'local', + memberIds: ['m1'], + }); + + const updated = await methods.removeMemberById( + group._id as mongoose.Types.ObjectId, + 'nonexistent', + ); 
+ + expect(updated).toBeDefined(); + expect(updated?.memberIds).toEqual(['m1']); + }); + + test('returns null for non-existent group', async () => { + const fakeId = new mongoose.Types.ObjectId(); + const result = await methods.removeMemberById(fakeId, 'any-id'); + + expect(result).toBeNull(); + }); + }); + describe('sortPrincipalsByRelevance', () => { test('should sort principals by relevance score', async () => { const principals = [ diff --git a/packages/data-schemas/src/methods/userGroup.spec.ts b/packages/data-schemas/src/methods/userGroup.spec.ts index 675fdb2592..ca83ced7d9 100644 --- a/packages/data-schemas/src/methods/userGroup.spec.ts +++ b/packages/data-schemas/src/methods/userGroup.spec.ts @@ -496,7 +496,7 @@ describe('userGroup methods', () => { it('returns the updated user document', async () => { const user = await createTestUser({ idOnTheSource: 'user-ext-1' }); const { user: updatedUser } = await methods.syncUserEntraGroups(user._id, []); - expect(updatedUser._id.toString()).toBe(user._id.toString()); + expect((updatedUser._id as Types.ObjectId).toString()).toBe(user._id.toString()); }); }); diff --git a/packages/data-schemas/src/methods/userGroup.ts b/packages/data-schemas/src/methods/userGroup.ts index 5e11c26135..0e6b57adb2 100644 --- a/packages/data-schemas/src/methods/userGroup.ts +++ b/packages/data-schemas/src/methods/userGroup.ts @@ -1,8 +1,9 @@ import { Types } from 'mongoose'; import { PrincipalType } from 'librechat-data-provider'; import type { TUser, TPrincipalSearchResult } from 'librechat-data-provider'; -import type { Model, ClientSession } from 'mongoose'; +import type { Model, ClientSession, FilterQuery } from 'mongoose'; import type { IGroup, IRole, IUser } from '~/types'; +import { escapeRegExp } from '~/utils/string'; export function createUserGroupMethods(mongoose: typeof import('mongoose')) { /** @@ -14,7 +15,7 @@ export function createUserGroupMethods(mongoose: typeof import('mongoose')) { */ async function findGroupById( 
groupId: string | Types.ObjectId, - projection: Record = {}, + projection: Record = {}, session?: ClientSession, ): Promise { const Group = mongoose.models.Group as Model; @@ -36,7 +37,7 @@ export function createUserGroupMethods(mongoose: typeof import('mongoose')) { async function findGroupByExternalId( idOnTheSource: string, source: 'entra' | 'local' = 'entra', - projection: Record = {}, + projection: Record = {}, session?: ClientSession, ): Promise { const Group = mongoose.models.Group as Model; @@ -236,34 +237,42 @@ export function createUserGroupMethods(mongoose: typeof import('mongoose')) { } /** - * Get a list of all principal identifiers for a user (user ID + group IDs + public) - * For use in permission checks + * Get a list of all principal identifiers for a user (user ID + group IDs + public). + * For use in permission checks. + * + * Tenant filtering for group memberships is handled automatically by the + * `applyTenantIsolation` Mongoose plugin on the Group schema. The + * `tenantContextMiddleware` (chained by `requireJwtAuth` after passport auth) + * sets the ALS context, so `getUserGroups()` → `findGroupsByMemberId()` queries + * are scoped to the requesting tenant. No explicit tenantId parameter is needed. + * + * IMPORTANT: This relies on the ALS tenant context being active. If this + * function is called outside a request context (e.g. startup, background jobs), + * group queries will be unscoped. In strict mode, the Mongoose plugin will + * reject such queries. + * + * Ref: #12091 (resolved by tenant context middleware in requireJwtAuth) + * * @param params - Parameters object * @param params.userId - The user ID * @param params.role - Optional user role (if not provided, will query from DB) * @param session - Optional MongoDB session for transactions * @returns Array of principal objects with type and id */ - /** - * TODO(#12091): This method has no tenantId parameter — it returns ALL group - * memberships for a user regardless of tenant. 
In multi-tenant mode, group - * principals from other tenants will be included in capability checks, which - * could grant cross-tenant capabilities. Add tenantId filtering here when - * tenant isolation is activated. - */ async function getUserPrincipals( params: { userId: string | Types.ObjectId; role?: string | null; }, session?: ClientSession, - ): Promise> { + ): Promise> { const { userId, role } = params; /** `userId` must be an `ObjectId` for USER principal since ACL entries store `ObjectId`s */ const userObjectId = typeof userId === 'string' ? new Types.ObjectId(userId) : userId; - const principals: Array<{ principalType: string; principalId?: string | Types.ObjectId }> = [ - { principalType: PrincipalType.USER, principalId: userObjectId }, - ]; + const principals: Array<{ + principalType: PrincipalType; + principalId?: string | Types.ObjectId; + }> = [{ principalType: PrincipalType.USER, principalId: userObjectId }]; // If role is not provided, query user to get it let userRole = role; @@ -651,6 +660,97 @@ export function createUserGroupMethods(mongoose: typeof import('mongoose')) { return Group.updateMany(filter, update, options || {}); } + function buildGroupQuery(filter: { + source?: 'local' | 'entra'; + search?: string; + }): FilterQuery { + const query: FilterQuery = {}; + if (filter.source) { + query.source = filter.source; + } + if (filter.search) { + const regex = new RegExp(escapeRegExp(filter.search), 'i'); + query.$or = [{ name: regex }, { email: regex }, { description: regex }]; + } + return query; + } + + /** + * List groups with optional source, search, and pagination filters. + * Results are sorted by name. 
+ * @param filter - Optional filter with source, search, limit, and offset fields + * @param session - Optional MongoDB session for transactions + */ + async function listGroups( + filter: { + source?: 'local' | 'entra'; + search?: string; + limit?: number; + offset?: number; + } = {}, + session?: ClientSession, + ): Promise { + const Group = mongoose.models.Group as Model; + const query = buildGroupQuery(filter); + const limit = filter.limit ?? 50; + const offset = filter.offset ?? 0; + return await Group.find(query) + .sort({ name: 1 }) + .skip(offset) + .limit(limit) + .session(session ?? null) + .lean(); + } + + /** + * Count groups matching optional source and search filters. + * @param filter - Optional filter with source and search fields + * @param session - Optional MongoDB session for transactions + */ + async function countGroups( + filter: { source?: 'local' | 'entra'; search?: string } = {}, + session?: ClientSession, + ): Promise { + const Group = mongoose.models.Group as Model; + const query = buildGroupQuery(filter); + return await Group.countDocuments(query).session(session ?? null); + } + + /** + * Delete a group by its ID. + * @param groupId - The group's ObjectId + * @param session - Optional MongoDB session for transactions + */ + async function deleteGroup( + groupId: string | Types.ObjectId, + session?: ClientSession, + ): Promise { + const Group = mongoose.models.Group as Model; + const options = session ? { session } : {}; + return await Group.findByIdAndDelete(groupId, options).lean(); + } + + /** + * Remove a member from a group by raw memberId string ($pull from memberIds). + * Unlike removeUserFromGroup, this does not look up the user first. 
+ * @param groupId - The group's ObjectId + * @param memberId - The raw memberId string to remove (ObjectId or idOnTheSource) + * @param session - Optional MongoDB session for transactions + */ + async function removeMemberById( + groupId: string | Types.ObjectId, + memberId: string, + session?: ClientSession, + ): Promise { + const Group = mongoose.models.Group as Model; + const options = { new: true, ...(session ? { session } : {}) }; + return await Group.findByIdAndUpdate( + groupId, + { $pull: { memberIds: memberId } }, + options, + ).lean(); + } + return { findGroupById, findGroupByExternalId, @@ -670,6 +770,10 @@ export function createUserGroupMethods(mongoose: typeof import('mongoose')) { searchPrincipals, calculateRelevanceScore, sortPrincipalsByRelevance, + listGroups, + countGroups, + deleteGroup, + removeMemberById, }; } diff --git a/packages/data-schemas/src/migrations/promptGroupIndexes.ts b/packages/data-schemas/src/migrations/promptGroupIndexes.ts index 4b6013c9e4..2d389f3f09 100644 --- a/packages/data-schemas/src/migrations/promptGroupIndexes.ts +++ b/packages/data-schemas/src/migrations/promptGroupIndexes.ts @@ -18,7 +18,7 @@ export async function dropSupersededPromptGroupIndexes( let collection; try { - collection = connection.db.collection(collectionName); + collection = connection.db!.collection(collectionName); } catch { result.skipped.push( ...SUPERSEDED_PROMPT_GROUP_INDEXES.map( diff --git a/packages/data-schemas/src/migrations/tenantIndexes.spec.ts b/packages/data-schemas/src/migrations/tenantIndexes.spec.ts index 4637e7d0ad..6a0987d757 100644 --- a/packages/data-schemas/src/migrations/tenantIndexes.spec.ts +++ b/packages/data-schemas/src/migrations/tenantIndexes.spec.ts @@ -1,4 +1,4 @@ -import mongoose, { Schema } from 'mongoose'; +import mongoose from 'mongoose'; import { MongoMemoryServer } from 'mongodb-memory-server'; import { dropSupersededTenantIndexes, SUPERSEDED_INDEXES } from './tenantIndexes'; @@ -24,7 +24,7 @@ afterAll(async () 
=> { describe('dropSupersededTenantIndexes', () => { describe('with pre-existing single-field unique indexes (simulates upgrade)', () => { beforeAll(async () => { - const db = mongoose.connection.db; + const db = mongoose.connection.db!; await db.createCollection('users'); const users = db.collection('users'); @@ -47,7 +47,13 @@ describe('dropSupersededTenantIndexes', () => { await db.createCollection('roles'); await db.collection('roles').createIndex({ name: 1 }, { unique: true, name: 'name_1' }); + await db.createCollection('agents'); + await db.collection('agents').createIndex({ id: 1 }, { unique: true, name: 'id_1' }); + await db.createCollection('conversations'); + await db + .collection('conversations') + .createIndex({ conversationId: 1 }, { unique: true, name: 'conversationId_1' }); await db .collection('conversations') .createIndex( @@ -56,10 +62,18 @@ describe('dropSupersededTenantIndexes', () => { ); await db.createCollection('messages'); + await db + .collection('messages') + .createIndex({ messageId: 1 }, { unique: true, name: 'messageId_1' }); await db .collection('messages') .createIndex({ messageId: 1, user: 1 }, { unique: true, name: 'messageId_1_user_1' }); + await db.createCollection('presets'); + await db + .collection('presets') + .createIndex({ presetId: 1 }, { unique: true, name: 'presetId_1' }); + await db.createCollection('agentcategories'); await db .collection('agentcategories') @@ -119,7 +133,7 @@ describe('dropSupersededTenantIndexes', () => { }); it('old unique indexes are actually gone from users collection', async () => { - const indexes = await mongoose.connection.db.collection('users').indexes(); + const indexes = await mongoose.connection.db!.collection('users').indexes(); const indexNames = indexes.map((idx) => idx.name); expect(indexNames).not.toContain('email_1'); @@ -129,14 +143,14 @@ describe('dropSupersededTenantIndexes', () => { }); it('old unique indexes are actually gone from roles collection', async () => { - const 
indexes = await mongoose.connection.db.collection('roles').indexes(); + const indexes = await mongoose.connection.db!.collection('roles').indexes(); const indexNames = indexes.map((idx) => idx.name); expect(indexNames).not.toContain('name_1'); }); it('old compound unique indexes are gone from conversations collection', async () => { - const indexes = await mongoose.connection.db.collection('conversations').indexes(); + const indexes = await mongoose.connection.db!.collection('conversations').indexes(); const indexNames = indexes.map((idx) => idx.name); expect(indexNames).not.toContain('conversationId_1_user_1'); @@ -145,7 +159,7 @@ describe('dropSupersededTenantIndexes', () => { describe('multi-tenant writes after migration', () => { beforeAll(async () => { - const db = mongoose.connection.db; + const db = mongoose.connection.db!; const users = db.collection('users'); await users.createIndex( @@ -155,7 +169,7 @@ describe('dropSupersededTenantIndexes', () => { }); it('allows same email in different tenants after old index is dropped', async () => { - const users = mongoose.connection.db.collection('users'); + const users = mongoose.connection.db!.collection('users'); await users.insertOne({ email: 'shared@example.com', @@ -182,7 +196,7 @@ describe('dropSupersededTenantIndexes', () => { }); it('still rejects duplicate email within same tenant', async () => { - const users = mongoose.connection.db.collection('users'); + const users = mongoose.connection.db!.collection('users'); await users.insertOne({ email: 'unique-within@example.com', @@ -233,7 +247,7 @@ describe('dropSupersededTenantIndexes', () => { partialConnection = mongoose.createConnection(partialServer.getUri()); await partialConnection.asPromise(); - const db = partialConnection.db; + const db = partialConnection.db!; await db.createCollection('users'); await db.collection('users').createIndex({ email: 1 }, { unique: true, name: 'email_1' }); }); diff --git 
a/packages/data-schemas/src/migrations/tenantIndexes.ts b/packages/data-schemas/src/migrations/tenantIndexes.ts index c68df4db2b..6536423ad2 100644 --- a/packages/data-schemas/src/migrations/tenantIndexes.ts +++ b/packages/data-schemas/src/migrations/tenantIndexes.ts @@ -24,8 +24,10 @@ const SUPERSEDED_INDEXES: Record = { 'appleId_1', ], roles: ['name_1'], - conversations: ['conversationId_1_user_1'], - messages: ['messageId_1_user_1'], + agents: ['id_1'], + conversations: ['conversationId_1', 'conversationId_1_user_1'], + messages: ['messageId_1', 'messageId_1_user_1'], + presets: ['presetId_1'], agentcategories: ['value_1'], accessroles: ['accessRoleId_1'], conversationtags: ['tag_1_user_1'], @@ -53,7 +55,7 @@ export async function dropSupersededTenantIndexes( const result: MigrationResult = { dropped: [], skipped: [], errors: [] }; for (const [collectionName, indexNames] of Object.entries(SUPERSEDED_INDEXES)) { - const collection = connection.db.collection(collectionName); + const collection = connection.db!.collection(collectionName); let existingIndexes: Array<{ name?: string }>; try { diff --git a/packages/data-schemas/src/models/config.ts b/packages/data-schemas/src/models/config.ts new file mode 100644 index 0000000000..97c08ce1da --- /dev/null +++ b/packages/data-schemas/src/models/config.ts @@ -0,0 +1,8 @@ +import configSchema from '~/schema/config'; +import { applyTenantIsolation } from '~/models/plugins/tenantIsolation'; +import type * as t from '~/types'; + +export function createConfigModel(mongoose: typeof import('mongoose')) { + applyTenantIsolation(configSchema); + return mongoose.models.Config || mongoose.model('Config', configSchema); +} diff --git a/packages/data-schemas/src/models/index.ts b/packages/data-schemas/src/models/index.ts index 44d94c6ab4..5a8e8f1c2c 100644 --- a/packages/data-schemas/src/models/index.ts +++ b/packages/data-schemas/src/models/index.ts @@ -27,6 +27,7 @@ import { createAccessRoleModel } from './accessRole'; import { 
createAclEntryModel } from './aclEntry'; import { createSystemGrantModel } from './systemGrant'; import { createGroupModel } from './group'; +import { createConfigModel } from './config'; /** * Creates all database models for all collections @@ -62,5 +63,6 @@ export function createModels(mongoose: typeof import('mongoose')) { AclEntry: createAclEntryModel(mongoose), SystemGrant: createSystemGrantModel(mongoose), Group: createGroupModel(mongoose), + Config: createConfigModel(mongoose), }; } diff --git a/packages/data-schemas/src/schema/agent.ts b/packages/data-schemas/src/schema/agent.ts index 42a7ca5418..70734d0ceb 100644 --- a/packages/data-schemas/src/schema/agent.ts +++ b/packages/data-schemas/src/schema/agent.ts @@ -5,8 +5,6 @@ const agentSchema = new Schema( { id: { type: String, - index: true, - unique: true, required: true, }, name: { @@ -124,6 +122,7 @@ const agentSchema = new Schema( }, ); +agentSchema.index({ id: 1, tenantId: 1 }, { unique: true }); agentSchema.index({ updatedAt: -1, _id: 1 }); agentSchema.index({ 'edges.to': 1 }); diff --git a/packages/data-schemas/src/schema/config.ts b/packages/data-schemas/src/schema/config.ts new file mode 100644 index 0000000000..be3784d55e --- /dev/null +++ b/packages/data-schemas/src/schema/config.ts @@ -0,0 +1,55 @@ +import { Schema } from 'mongoose'; +import { PrincipalType, PrincipalModel } from 'librechat-data-provider'; +import type { IConfig } from '~/types'; + +const configSchema = new Schema( + { + principalType: { + type: String, + enum: Object.values(PrincipalType), + required: true, + index: true, + }, + principalId: { + type: String, + refPath: 'principalModel', + required: true, + index: true, + }, + principalModel: { + type: String, + enum: Object.values(PrincipalModel), + required: true, + }, + priority: { + type: Number, + required: true, + index: true, + }, + overrides: { + type: Schema.Types.Mixed, + default: {}, + }, + isActive: { + type: Boolean, + default: true, + index: true, + }, + 
configVersion: { + type: Number, + default: 0, + }, + tenantId: { + type: String, + index: true, + }, + }, + { timestamps: true }, +); + +// Enforce 1:1 principal-to-config (one config document per principal per tenant) +configSchema.index({ principalType: 1, principalId: 1, tenantId: 1 }, { unique: true }); +configSchema.index({ principalType: 1, principalId: 1, isActive: 1, tenantId: 1 }); +configSchema.index({ priority: 1, isActive: 1, tenantId: 1 }); + +export default configSchema; diff --git a/packages/data-schemas/src/schema/convo.ts b/packages/data-schemas/src/schema/convo.ts index 9ed8949e9c..c8f394935a 100644 --- a/packages/data-schemas/src/schema/convo.ts +++ b/packages/data-schemas/src/schema/convo.ts @@ -6,7 +6,6 @@ const convoSchema: Schema = new Schema( { conversationId: { type: String, - unique: true, required: true, index: true, meiliIndex: true, diff --git a/packages/data-schemas/src/schema/index.ts b/packages/data-schemas/src/schema/index.ts index 456eb03ac2..2a5eff658b 100644 --- a/packages/data-schemas/src/schema/index.ts +++ b/packages/data-schemas/src/schema/index.ts @@ -25,3 +25,4 @@ export { default as userSchema } from './user'; export { default as memorySchema } from './memory'; export { default as groupSchema } from './group'; export { default as systemGrantSchema } from './systemGrant'; +export { default as configSchema } from './config'; diff --git a/packages/data-schemas/src/schema/message.ts b/packages/data-schemas/src/schema/message.ts index ff3468918e..9879efae55 100644 --- a/packages/data-schemas/src/schema/message.ts +++ b/packages/data-schemas/src/schema/message.ts @@ -5,7 +5,6 @@ const messageSchema: Schema = new Schema( { messageId: { type: String, - unique: true, required: true, index: true, meiliIndex: true, diff --git a/packages/data-schemas/src/schema/preset.ts b/packages/data-schemas/src/schema/preset.ts index 33c217ea23..5af5163fd3 100644 --- a/packages/data-schemas/src/schema/preset.ts +++ 
b/packages/data-schemas/src/schema/preset.ts @@ -60,7 +60,6 @@ const presetSchema: Schema = new Schema( { presetId: { type: String, - unique: true, required: true, index: true, }, @@ -88,4 +87,6 @@ const presetSchema: Schema = new Schema( { timestamps: true }, ); +presetSchema.index({ presetId: 1, tenantId: 1 }, { unique: true }); + export default presetSchema; diff --git a/packages/data-schemas/src/schema/role.ts b/packages/data-schemas/src/schema/role.ts index 1c27478ef6..ac478c2a83 100644 --- a/packages/data-schemas/src/schema/role.ts +++ b/packages/data-schemas/src/schema/role.ts @@ -73,6 +73,7 @@ const rolePermissionsSchema = new Schema( const roleSchema: Schema = new Schema({ name: { type: String, required: true, index: true }, + description: { type: String, default: '' }, permissions: { type: rolePermissionsSchema, }, diff --git a/packages/data-schemas/src/schema/systemGrant.ts b/packages/data-schemas/src/schema/systemGrant.ts index 0366f6080d..a20a407bf1 100644 --- a/packages/data-schemas/src/schema/systemGrant.ts +++ b/packages/data-schemas/src/schema/systemGrant.ts @@ -1,7 +1,7 @@ import { Schema } from 'mongoose'; import { PrincipalType } from 'librechat-data-provider'; -import { SystemCapabilities } from '~/systemCapabilities'; -import type { SystemCapability } from '~/systemCapabilities'; +import { SystemCapabilities } from '~/admin/capabilities'; +import type { SystemCapability } from '~/types/admin'; import type { ISystemGrant } from '~/types'; const baseCapabilities = new Set(Object.values(SystemCapabilities)); diff --git a/packages/data-schemas/src/schema/user.ts b/packages/data-schemas/src/schema/user.ts index 92680415bd..f807ddd8d6 100644 --- a/packages/data-schemas/src/schema/user.ts +++ b/packages/data-schemas/src/schema/user.ts @@ -158,6 +158,7 @@ const userSchema = new Schema( ); userSchema.index({ email: 1, tenantId: 1 }, { unique: true }); +userSchema.index({ role: 1, tenantId: 1 }); const oAuthIdFields = [ 'googleId', diff --git 
a/packages/data-schemas/src/systemCapabilities.ts b/packages/data-schemas/src/systemCapabilities.ts deleted file mode 100644 index cf2acfbf88..0000000000 --- a/packages/data-schemas/src/systemCapabilities.ts +++ /dev/null @@ -1,106 +0,0 @@ -import type { z } from 'zod'; -import type { configSchema } from 'librechat-data-provider'; -import { ResourceType } from 'librechat-data-provider'; - -export const SystemCapabilities = { - ACCESS_ADMIN: 'access:admin', - READ_USERS: 'read:users', - MANAGE_USERS: 'manage:users', - READ_GROUPS: 'read:groups', - MANAGE_GROUPS: 'manage:groups', - READ_ROLES: 'read:roles', - MANAGE_ROLES: 'manage:roles', - READ_CONFIGS: 'read:configs', - MANAGE_CONFIGS: 'manage:configs', - ASSIGN_CONFIGS: 'assign:configs', - READ_USAGE: 'read:usage', - READ_AGENTS: 'read:agents', - MANAGE_AGENTS: 'manage:agents', - MANAGE_MCP_SERVERS: 'manage:mcpservers', - READ_PROMPTS: 'read:prompts', - MANAGE_PROMPTS: 'manage:prompts', - /** Reserved — not yet enforced by any middleware. Grant has no effect until assistant listing is gated. */ - READ_ASSISTANTS: 'read:assistants', - MANAGE_ASSISTANTS: 'manage:assistants', -} as const; - -/** Top-level keys of the configSchema from librechat.yaml. */ -export type ConfigSection = keyof z.infer; - -/** Principal types that can receive config overrides. */ -export type ConfigAssignTarget = 'user' | 'group' | 'role'; - -/** Base capabilities defined in the SystemCapabilities object. */ -type BaseSystemCapability = (typeof SystemCapabilities)[keyof typeof SystemCapabilities]; - -/** Section-level config capabilities derived from configSchema keys. */ -type ConfigSectionCapability = `manage:configs:${ConfigSection}` | `read:configs:${ConfigSection}`; - -/** Principal-scoped config assignment capabilities. 
*/ -type ConfigAssignCapability = `assign:configs:${ConfigAssignTarget}`; - -/** - * Union of all valid capability strings: - * - Base capabilities from SystemCapabilities - * - Section-level config capabilities (manage:configs:
<section>, read:configs:<section>
) - * - Config assignment capabilities (assign:configs:) - */ -export type SystemCapability = - | BaseSystemCapability - | ConfigSectionCapability - | ConfigAssignCapability; - -/** - * Capabilities that are implied by holding a broader capability. - * When `hasCapability` checks for an implied capability, it first expands - * the principal's grant set — so granting `MANAGE_USERS` automatically - * satisfies a `READ_USERS` check without a separate grant. - * - * Implication is one-directional: `MANAGE_USERS` implies `READ_USERS`, - * but `READ_USERS` does NOT imply `MANAGE_USERS`. - */ -export const CapabilityImplications: Partial> = - { - [SystemCapabilities.MANAGE_USERS]: [SystemCapabilities.READ_USERS], - [SystemCapabilities.MANAGE_GROUPS]: [SystemCapabilities.READ_GROUPS], - [SystemCapabilities.MANAGE_ROLES]: [SystemCapabilities.READ_ROLES], - [SystemCapabilities.MANAGE_CONFIGS]: [SystemCapabilities.READ_CONFIGS], - [SystemCapabilities.MANAGE_AGENTS]: [SystemCapabilities.READ_AGENTS], - [SystemCapabilities.MANAGE_PROMPTS]: [SystemCapabilities.READ_PROMPTS], - [SystemCapabilities.MANAGE_ASSISTANTS]: [SystemCapabilities.READ_ASSISTANTS], - }; - -/** - * Maps each ACL ResourceType to the SystemCapability that grants - * unrestricted management access. Typed as `Record` - * so adding a new ResourceType variant causes a compile error until a - * capability is assigned here. - */ -export const ResourceCapabilityMap: Record = { - [ResourceType.AGENT]: SystemCapabilities.MANAGE_AGENTS, - [ResourceType.PROMPTGROUP]: SystemCapabilities.MANAGE_PROMPTS, - [ResourceType.MCPSERVER]: SystemCapabilities.MANAGE_MCP_SERVERS, - [ResourceType.REMOTE_AGENT]: SystemCapabilities.MANAGE_AGENTS, -}; - -/** - * Derives a section-level config management capability from a configSchema key. - * @example configCapability('endpoints') → 'manage:configs:endpoints' - * - * TODO: Section-level config capabilities are scaffolded but not yet active. 
- * To activate delegated config management: - * 1. Expose POST/DELETE /api/admin/grants endpoints (wiring grantCapability/revokeCapability) - * 2. Seed section-specific grants for delegated admin roles via those endpoints - * 3. Guard config write handlers with hasConfigCapability(user, section) - */ -export function configCapability(section: ConfigSection): `manage:configs:${ConfigSection}` { - return `manage:configs:${section}`; -} - -/** - * Derives a section-level config read capability from a configSchema key. - * @example readConfigCapability('endpoints') → 'read:configs:endpoints' - */ -export function readConfigCapability(section: ConfigSection): `read:configs:${ConfigSection}` { - return `read:configs:${section}`; -} diff --git a/packages/data-schemas/src/types/admin.ts b/packages/data-schemas/src/types/admin.ts new file mode 100644 index 0000000000..a16f68ae9c --- /dev/null +++ b/packages/data-schemas/src/types/admin.ts @@ -0,0 +1,120 @@ +import type { PrincipalType, PrincipalModel, TCustomConfig } from 'librechat-data-provider'; +import type { SystemCapabilities } from '~/admin/capabilities'; + +/* ── Capability types ───────────────────────────────────────────────── */ + +/** Base capabilities derived from the SystemCapabilities constant. */ +export type BaseSystemCapability = (typeof SystemCapabilities)[keyof typeof SystemCapabilities]; + +/** Principal types that can receive config overrides. */ +export type ConfigAssignTarget = 'user' | 'group' | 'role'; + +/** Top-level keys of the configSchema from librechat.yaml. */ +export type ConfigSection = string & keyof TCustomConfig; + +/** Section-level config capabilities derived from configSchema keys. */ +type ConfigSectionCapability = `manage:configs:${ConfigSection}` | `read:configs:${ConfigSection}`; + +/** Principal-scoped config assignment capabilities. 
*/ +type ConfigAssignCapability = `assign:configs:${ConfigAssignTarget}`; + +/** + * Union of all valid capability strings: + * - Base capabilities from SystemCapabilities + * - Section-level config capabilities (manage:configs:
<section>, read:configs:<section>
) + * - Config assignment capabilities (assign:configs:) + */ +export type SystemCapability = + | BaseSystemCapability + | ConfigSectionCapability + | ConfigAssignCapability; + +/** UI grouping of capabilities for the admin panel's capability editor. */ +export type CapabilityCategory = { + key: string; + labelKey: string; + capabilities: BaseSystemCapability[]; +}; + +/* ── Admin API response types ───────────────────────────────────────── */ + +/** Config document as returned by the admin API (no Mongoose internals). */ +export type AdminConfig = { + _id: string; + principalType: PrincipalType; + principalId: string; + principalModel: PrincipalModel; + priority: number; + overrides: Partial; + isActive: boolean; + configVersion: number; + tenantId?: string; + createdAt?: string; + updatedAt?: string; +}; + +export type AdminConfigListResponse = { + configs: AdminConfig[]; +}; + +export type AdminConfigResponse = { + config: AdminConfig; +}; + +export type AdminConfigDeleteResponse = { + success: boolean; +}; + +/** Audit action types for grant changes. */ +export type AuditAction = 'grant_assigned' | 'grant_removed'; + +/** SystemGrant document as returned by the admin API. */ +export type AdminSystemGrant = { + id: string; + principalType: PrincipalType; + principalId: string; + capability: string; + grantedBy?: string; + grantedAt: string; + expiresAt?: string; +}; + +/** Audit log entry for grant changes as returned by the admin API. */ +export type AdminAuditLogEntry = { + id: string; + action: AuditAction; + actorId: string; + actorName: string; + targetPrincipalType: PrincipalType; + targetPrincipalId: string; + targetName: string; + capability: string; + timestamp: string; +}; + +/** Group as returned by the admin API. 
*/ +export type AdminGroup = { + id: string; + name: string; + description: string; + memberCount: number; + topMembers: { name: string }[]; + isActive: boolean; +}; + +/** Member entry as returned by the admin API for group/role membership lists. */ +export type AdminMember = { + userId: string; + name: string; + email: string; + avatarUrl?: string; + joinedAt?: string; +}; + +/** Minimal user info returned by user search endpoints. */ +export type AdminUserSearchResult = { + userId: string; + name: string; + email: string; + avatarUrl?: string; +}; diff --git a/packages/data-schemas/src/types/config.ts b/packages/data-schemas/src/types/config.ts new file mode 100644 index 0000000000..04e0ca58ab --- /dev/null +++ b/packages/data-schemas/src/types/config.ts @@ -0,0 +1,36 @@ +import { PrincipalType, PrincipalModel } from 'librechat-data-provider'; +import type { TCustomConfig } from 'librechat-data-provider'; +import type { Document, Types } from 'mongoose'; + +/** + * Configuration override for a principal (user, group, or role). + * Stores partial overrides at the TCustomConfig (YAML) level, + * which are merged with the base config before processing through AppService. 
+ */ +export type Config = { + /** The type of principal (user, group, role) */ + principalType: PrincipalType; + /** The ID of the principal (ObjectId for users/groups, string for roles) */ + principalId: Types.ObjectId | string; + /** The model name for the principal */ + principalModel: PrincipalModel; + /** Priority level for determining merge order (higher = more specific) */ + priority: number; + /** Configuration overrides matching librechat.yaml structure */ + overrides: Partial; + /** Whether this config override is currently active */ + isActive: boolean; + /** Version number for cache invalidation, auto-increments on overrides change */ + configVersion: number; + /** Tenant identifier for multi-tenancy isolation */ + tenantId?: string; + /** When this config was created */ + createdAt?: Date; + /** When this config was last updated */ + updatedAt?: Date; +}; + +export type IConfig = Config & + Document & { + _id: Types.ObjectId; + }; diff --git a/packages/data-schemas/src/types/index.ts b/packages/data-schemas/src/types/index.ts index 26238cbda1..748ea5d77d 100644 --- a/packages/data-schemas/src/types/index.ts +++ b/packages/data-schemas/src/types/index.ts @@ -28,6 +28,10 @@ export * from './accessRole'; export * from './aclEntry'; export * from './systemGrant'; export * from './group'; +/* Config */ +export * from './config'; +/* Admin */ +export * from './admin'; /* Web */ export * from './web'; /* MCP Servers */ diff --git a/packages/data-schemas/src/types/role.ts b/packages/data-schemas/src/types/role.ts index 60a579240c..bc85284c34 100644 --- a/packages/data-schemas/src/types/role.ts +++ b/packages/data-schemas/src/types/role.ts @@ -5,6 +5,7 @@ import { CursorPaginationParams } from '~/common'; export interface IRole extends Document { name: string; + description?: string; permissions: { [PermissionTypes.BOOKMARKS]?: { [Permissions.USE]?: boolean; @@ -74,11 +75,13 @@ export type RolePermissionsInput = DeepPartial; export interface CreateRoleRequest 
{ name: string; + description?: string; permissions: RolePermissionsInput; } export interface UpdateRoleRequest { name?: string; + description?: string; permissions?: RolePermissionsInput; } diff --git a/packages/data-schemas/src/types/systemGrant.ts b/packages/data-schemas/src/types/systemGrant.ts index 9f0d576503..09cff1aec6 100644 --- a/packages/data-schemas/src/types/systemGrant.ts +++ b/packages/data-schemas/src/types/systemGrant.ts @@ -1,6 +1,6 @@ import type { Document, Types } from 'mongoose'; import type { PrincipalType } from 'librechat-data-provider'; -import type { SystemCapability } from '~/systemCapabilities'; +import type { SystemCapability } from '~/types/admin'; export type SystemGrant = { /** The type of principal — matches PrincipalType enum values */ diff --git a/packages/data-schemas/src/types/user.ts b/packages/data-schemas/src/types/user.ts index 0fac46ee63..2d8eb82f47 100644 --- a/packages/data-schemas/src/types/user.ts +++ b/packages/data-schemas/src/types/user.ts @@ -2,6 +2,7 @@ import type { Document, Types } from 'mongoose'; import { CursorPaginationParams } from '~/common'; export interface IUser extends Document { + _id: Types.ObjectId; name?: string; username?: string; email: string; @@ -50,6 +51,15 @@ export interface IUser extends Document { /** Field for external source identification (for consistency with TPrincipal schema) */ idOnTheSource?: string; tenantId?: string; + federatedTokens?: OIDCTokens; + openidTokens?: OIDCTokens; +} + +export interface OIDCTokens { + access_token?: string; + id_token?: string; + refresh_token?: string; + expires_at?: number; } export interface BalanceConfig { diff --git a/packages/data-schemas/src/types/winston-transports.d.ts b/packages/data-schemas/src/types/winston-transports.d.ts new file mode 100644 index 0000000000..704486e5ce --- /dev/null +++ b/packages/data-schemas/src/types/winston-transports.d.ts @@ -0,0 +1,34 @@ +import type TransportStream from 'winston-transport'; + +/** + * Module 
augmentation for winston's transports namespace. + * + * `winston-daily-rotate-file` ships its own augmentation targeting + * `'winston/lib/winston/transports'`, but it fails when winston and + * winston-daily-rotate-file resolve from different node_modules trees + * (which happens in this monorepo due to npm hoisting). This local + * declaration bridges the gap so `tsc --noEmit` passes. + */ +declare module 'winston/lib/winston/transports' { + interface Transports { + DailyRotateFile: new ( + opts?: { + level?: string; + filename?: string; + datePattern?: string; + zippedArchive?: boolean; + maxSize?: string | number; + maxFiles?: string | number; + dirname?: string; + stream?: NodeJS.WritableStream; + frequency?: string; + utc?: boolean; + extension?: string; + createSymlink?: boolean; + symlinkName?: string; + auditFile?: string; + format?: import('logform').Format; + } & TransportStream.TransportStreamOptions, + ) => TransportStream; + } +} diff --git a/packages/data-schemas/src/utils/index.ts b/packages/data-schemas/src/utils/index.ts index c071f4e827..17e43ac3ca 100644 --- a/packages/data-schemas/src/utils/index.ts +++ b/packages/data-schemas/src/utils/index.ts @@ -1,5 +1,6 @@ export * from './principal'; export * from './string'; export * from './tempChatRetention'; +export { tenantSafeBulkWrite } from './tenantBulkWrite'; export * from './transactions'; export * from './objectId'; diff --git a/packages/data-schemas/src/utils/tenantBulkWrite.spec.ts b/packages/data-schemas/src/utils/tenantBulkWrite.spec.ts new file mode 100644 index 0000000000..059868b8a1 --- /dev/null +++ b/packages/data-schemas/src/utils/tenantBulkWrite.spec.ts @@ -0,0 +1,376 @@ +import mongoose, { Schema } from 'mongoose'; +import { MongoMemoryServer } from 'mongodb-memory-server'; +import { tenantStorage, runAsSystem, SYSTEM_TENANT_ID } from '~/config/tenantContext'; +import { applyTenantIsolation, _resetStrictCache } from '~/models/plugins/tenantIsolation'; +import { 
tenantSafeBulkWrite, _resetBulkWriteStrictCache } from './tenantBulkWrite';
+
+let mongoServer: InstanceType<typeof MongoMemoryServer>;
+
+interface ITestDoc {
+  name: string;
+  value?: number;
+  tenantId?: string;
+}
+
+function createTestModel(suffix: string) {
+  const schema = new Schema({
+    name: { type: String, required: true },
+    value: { type: Number, default: 0 },
+    tenantId: { type: String, index: true },
+  });
+  applyTenantIsolation(schema);
+  const modelName = `TestBulkWrite_${suffix}_${Date.now()}`;
+  return mongoose.model(modelName, schema);
+}
+
+beforeAll(async () => {
+  mongoServer = await MongoMemoryServer.create();
+  await mongoose.connect(mongoServer.getUri());
+});
+
+afterAll(async () => {
+  await mongoose.disconnect();
+  await mongoServer.stop();
+});
+
+afterEach(() => {
+  delete process.env.TENANT_ISOLATION_STRICT;
+  _resetStrictCache();
+  _resetBulkWriteStrictCache();
+});
+
+describe('tenantSafeBulkWrite', () => {
+  describe('with tenant context', () => {
+    it('injects tenantId into updateOne filters', async () => {
+      const Model = createTestModel('updateOne');
+
+      // Seed data for two tenants
+      await runAsSystem(async () => {
+        await Model.create([
+          { name: 'doc1', value: 1, tenantId: 'tenant-a' },
+          { name: 'doc1', value: 1, tenantId: 'tenant-b' },
+        ]);
+      });
+
+      // Update only tenant-a's doc
+      await tenantStorage.run({ tenantId: 'tenant-a' }, async () => {
+        await tenantSafeBulkWrite(Model, [
+          {
+            updateOne: {
+              filter: { name: 'doc1' },
+              update: { $set: { value: 99 } },
+            },
+          },
+        ]);
+      });
+
+      // Verify tenant-a was updated, tenant-b was not
+      const docs = await runAsSystem(async () => Model.find({}).lean());
+      const docA = docs.find((d) => d.tenantId === 'tenant-a');
+      const docB = docs.find((d) => d.tenantId === 'tenant-b');
+      expect(docA?.value).toBe(99);
+      expect(docB?.value).toBe(1);
+    });
+
+    it('injects tenantId into insertOne documents', async () => {
+      const Model = createTestModel('insertOne');
+
+      await tenantStorage.run({ tenantId: 'tenant-x'
}, async () => { + await tenantSafeBulkWrite(Model, [ + { + insertOne: { + document: { name: 'new-doc', value: 42 } as ITestDoc, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs).toHaveLength(1); + expect(docs[0].tenantId).toBe('tenant-x'); + expect(docs[0].name).toBe('new-doc'); + }); + + it('injects tenantId into deleteOne filters', async () => { + const Model = createTestModel('deleteOne'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'to-delete', tenantId: 'tenant-a' }, + { name: 'to-delete', tenantId: 'tenant-b' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + deleteOne: { + filter: { name: 'to-delete' }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs).toHaveLength(1); + expect(docs[0].tenantId).toBe('tenant-b'); + }); + + it('injects tenantId into updateMany filters', async () => { + const Model = createTestModel('updateMany'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'batch', value: 0, tenantId: 'tenant-a' }, + { name: 'batch', value: 0, tenantId: 'tenant-a' }, + { name: 'batch', value: 0, tenantId: 'tenant-b' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + updateMany: { + filter: { name: 'batch' }, + update: { $set: { value: 5 } }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + const tenantADocs = docs.filter((d) => d.tenantId === 'tenant-a'); + const tenantBDocs = docs.filter((d) => d.tenantId === 'tenant-b'); + expect(tenantADocs.every((d) => d.value === 5)).toBe(true); + expect(tenantBDocs[0].value).toBe(0); + }); + }); + + describe('with SYSTEM_TENANT_ID', () => { + it('skips tenantId injection (cross-tenant operation)', async () => { + const Model = createTestModel('system'); + + await 
runAsSystem(async () => { + await Model.create([ + { name: 'sys-doc', value: 0, tenantId: 'tenant-a' }, + { name: 'sys-doc', value: 0, tenantId: 'tenant-b' }, + ]); + }); + + // System context should update ALL docs regardless of tenant + await runAsSystem(async () => { + await tenantSafeBulkWrite(Model, [ + { + updateMany: { + filter: { name: 'sys-doc' }, + update: { $set: { value: 100 } }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs.every((d) => d.value === 100)).toBe(true); + }); + }); + + describe('with SYSTEM_TENANT_ID in strict mode', () => { + it('does not throw when runAsSystem is used in strict mode', async () => { + process.env.TENANT_ISOLATION_STRICT = 'true'; + _resetBulkWriteStrictCache(); + + const Model = createTestModel('systemStrict'); + + await runAsSystem(async () => { + await Model.create({ name: 'strict-sys', value: 0 }); + }); + + await expect( + runAsSystem(async () => + tenantSafeBulkWrite(Model, [ + { + updateOne: { + filter: { name: 'strict-sys' }, + update: { $set: { value: 42 } }, + }, + }, + ]), + ), + ).resolves.toBeDefined(); + }); + }); + + describe('deleteMany and replaceOne', () => { + it('injects tenantId into deleteMany filters', async () => { + const Model = createTestModel('deleteMany'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'batch-del', value: 0, tenantId: 'tenant-a' }, + { name: 'batch-del', value: 0, tenantId: 'tenant-a' }, + { name: 'batch-del', value: 0, tenantId: 'tenant-b' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [{ deleteMany: { filter: { name: 'batch-del' } } }]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs).toHaveLength(1); + expect(docs[0].tenantId).toBe('tenant-b'); + }); + + it('injects tenantId into replaceOne filter and replacement', async () => { + const Model = createTestModel('replaceOne'); + + 
await runAsSystem(async () => { + await Model.create([ + { name: 'to-replace', value: 1, tenantId: 'tenant-a' }, + { name: 'to-replace', value: 1, tenantId: 'tenant-b' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + replaceOne: { + filter: { name: 'to-replace' }, + replacement: { name: 'replaced', value: 99 }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).sort({ name: 1 }).lean()); + const replaced = docs.find((d) => d.name === 'replaced'); + const untouched = docs.find((d) => d.tenantId === 'tenant-b'); + expect(replaced?.value).toBe(99); + expect(replaced?.tenantId).toBe('tenant-a'); + expect(untouched?.value).toBe(1); + }); + + it('replaceOne overwrites a conflicting tenantId in the replacement document', async () => { + const Model = createTestModel('replaceOverwrite'); + + await runAsSystem(async () => { + await Model.create({ name: 'conflict', value: 1, tenantId: 'tenant-a' }); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + replaceOne: { + filter: { name: 'conflict' }, + replacement: { name: 'conflict', value: 2, tenantId: 'tenant-evil' } as ITestDoc, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs).toHaveLength(1); + expect(docs[0].tenantId).toBe('tenant-a'); + expect(docs[0].value).toBe(2); + }); + }); + + describe('edge cases', () => { + it('handles empty ops array', async () => { + const Model = createTestModel('emptyOps'); + const result = await tenantStorage.run({ tenantId: 'tenant-x' }, async () => + tenantSafeBulkWrite(Model, []), + ); + expect(result.insertedCount).toBe(0); + expect(result.modifiedCount).toBe(0); + }); + }); + + describe('without tenant context', () => { + it('passes through in non-strict mode', async () => { + const Model = createTestModel('noCtx'); + + await runAsSystem(async () => { + await 
Model.create({ name: 'no-ctx', value: 0 }); + }); + + // No ALS context — non-strict should pass through + const result = await tenantSafeBulkWrite(Model, [ + { + updateOne: { + filter: { name: 'no-ctx' }, + update: { $set: { value: 10 } }, + }, + }, + ]); + + expect(result.modifiedCount).toBe(1); + }); + + it('throws in strict mode', async () => { + process.env.TENANT_ISOLATION_STRICT = 'true'; + _resetBulkWriteStrictCache(); + + const Model = createTestModel('strict'); + + await expect( + tenantSafeBulkWrite(Model, [ + { + updateOne: { + filter: { name: 'any' }, + update: { $set: { value: 1 } }, + }, + }, + ]), + ).rejects.toThrow('bulkWrite on TestBulkWrite_strict'); + }); + }); + + describe('mixed operations', () => { + it('handles a batch of mixed insert, update, delete operations', async () => { + const Model = createTestModel('mixed'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'existing1', value: 1, tenantId: 'tenant-m' }, + { name: 'to-remove', value: 2, tenantId: 'tenant-m' }, + { name: 'existing1', value: 1, tenantId: 'tenant-other' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-m' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + insertOne: { + document: { name: 'new-item', value: 10 } as ITestDoc, + }, + }, + { + updateOne: { + filter: { name: 'existing1' }, + update: { $set: { value: 50 } }, + }, + }, + { + deleteOne: { + filter: { name: 'to-remove' }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).sort({ name: 1 }).lean()); + + // tenant-other's doc should be untouched + const otherDoc = docs.find((d) => d.tenantId === 'tenant-other' && d.name === 'existing1'); + expect(otherDoc?.value).toBe(1); + + // tenant-m: existing1 updated, to-remove deleted, new-item inserted + const tenantMDocs = docs.filter((d) => d.tenantId === 'tenant-m'); + expect(tenantMDocs).toHaveLength(2); + expect(tenantMDocs.find((d) => d.name === 'existing1')?.value).toBe(50); + 
expect(tenantMDocs.find((d) => d.name === 'new-item')?.value).toBe(10); + expect(tenantMDocs.find((d) => d.name === 'to-remove')).toBeUndefined(); + }); + }); +}); diff --git a/packages/data-schemas/src/utils/tenantBulkWrite.ts b/packages/data-schemas/src/utils/tenantBulkWrite.ts new file mode 100644 index 0000000000..16ef5fa057 --- /dev/null +++ b/packages/data-schemas/src/utils/tenantBulkWrite.ts @@ -0,0 +1,109 @@ +import type { AnyBulkWriteOperation, Model, MongooseBulkWriteOptions } from 'mongoose'; +import type { BulkWriteResult } from 'mongodb'; +import { getTenantId, SYSTEM_TENANT_ID } from '~/config/tenantContext'; +import logger from '~/config/winston'; + +let _strictMode: boolean | undefined; + +function isStrict(): boolean { + return (_strictMode ??= process.env.TENANT_ISOLATION_STRICT === 'true'); +} + +/** Resets the cached strict-mode flag. Exposed for test teardown only. */ +export function _resetBulkWriteStrictCache(): void { + _strictMode = undefined; +} + +/** + * Tenant-safe wrapper around Mongoose `Model.bulkWrite()`. + * + * Mongoose's `bulkWrite` does not trigger schema-level middleware hooks, so the + * `applyTenantIsolation` plugin cannot intercept it. This wrapper injects the + * current ALS tenant context into every operation's filter and/or document + * before delegating to the native `bulkWrite`. + * + * Behavior: + * - **tenantId present** (normal request): injects `{ tenantId }` into every + * operation filter (updateOne, deleteOne, replaceOne) and document (insertOne). + * - **SYSTEM_TENANT_ID**: skips injection (cross-tenant system operation). + * - **No tenantId + strict mode**: throws (fail-closed, same as the plugin). + * - **No tenantId + non-strict**: passes through without injection (backward compat). 
+ */
+export async function tenantSafeBulkWrite<T>(
+  model: Model<T>,
+  ops: AnyBulkWriteOperation<T>[],
+  options?: MongooseBulkWriteOptions,
+): Promise<BulkWriteResult> {
+  const tenantId = getTenantId();
+
+  if (!tenantId) {
+    if (isStrict()) {
+      throw new Error(
+        `[TenantIsolation] bulkWrite on ${model.modelName} attempted without tenant context in strict mode`,
+      );
+    }
+    return model.bulkWrite(ops, options);
+  }
+
+  if (tenantId === SYSTEM_TENANT_ID) {
+    return model.bulkWrite(ops, options);
+  }
+
+  const injected = ops.map((op) => injectTenantId(op, tenantId));
+  return model.bulkWrite(injected, options);
+}
+
+/**
+ * Injects `tenantId` into a single bulk-write operation.
+ * Returns a new operation object — does not mutate the original.
+ */
+function injectTenantId<T>(op: AnyBulkWriteOperation<T>, tenantId: string): AnyBulkWriteOperation<T> {
+  if ('insertOne' in op) {
+    return {
+      insertOne: {
+        document: { ...op.insertOne.document, tenantId },
+      },
+    };
+  }
+
+  if ('updateOne' in op) {
+    const { filter, ...rest } = op.updateOne;
+    return { updateOne: { ...rest, filter: { ...filter, tenantId } } };
+  }
+
+  if ('updateMany' in op) {
+    const { filter, ...rest } = op.updateMany;
+    return { updateMany: { ...rest, filter: { ...filter, tenantId } } };
+  }
+
+  if ('deleteOne' in op) {
+    const { filter, ...rest } = op.deleteOne;
+    return { deleteOne: { ...rest, filter: { ...filter, tenantId } } };
+  }
+
+  if ('deleteMany' in op) {
+    const { filter, ...rest } = op.deleteMany;
+    return { deleteMany: { ...rest, filter: { ...filter, tenantId } } };
+  }
+
+  if ('replaceOne' in op) {
+    const { filter, replacement, ...rest } = op.replaceOne;
+    return {
+      replaceOne: {
+        ...rest,
+        filter: { ...filter, tenantId },
+        replacement: { ...replacement, tenantId },
+      },
+    };
+  }
+
+  if (isStrict()) {
+    throw new Error(
+      '[TenantIsolation] Unknown bulkWrite operation type in strict mode — refusing to pass through without tenant injection',
+    );
+  }
+  logger.warn(
+    '[tenantSafeBulkWrite] Unknown bulk op
type, passing through without tenant injection', + ); + return op; +}