From 4b6d68b3b5d7a963511a4b78a7b3e79a710ec1c6 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Wed, 25 Mar 2026 19:39:29 -0400 Subject: [PATCH 01/18] =?UTF-8?q?=F0=9F=8E=9B=EF=B8=8F=20feat:=20DB-Backed?= =?UTF-8?q?=20Per-Principal=20Config=20System=20(#12354)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * โœจ feat: Add Config schema, model, and methods for role-based DB config overrides Add the database foundation for principal-based configuration overrides (user, group, role) in data-schemas. Includes schema with tenantId and tenant isolation, CRUD methods, and barrel exports. * ๐Ÿ”ง fix: Add shebang and enforce LF line endings for git hooks The pre-commit hook was missing #!/bin/sh, and core.autocrlf=true was converting it to CRLF, both causing "Exec format error" on Windows. Add .gitattributes to force LF for .husky/* and *.sh files. * โœจ feat: Add admin config API routes with section-level capability checks Add /api/admin/config endpoints for managing per-principal config overrides (user, group, role). Handlers in @librechat/api use DI pattern with section-level hasConfigCapability checks for granular access control. Supports full overrides replacement, per-field PATCH via dot-paths, field deletion, toggle active, and listing. * ๐Ÿ› fix: Move deleteConfigField fieldPath from URL param to request body The path-to-regexp wildcard syntax (:fieldPath(*)) is not supported by the version used in Express. Send fieldPath in the DELETE request body instead, which also avoids URL-encoding issues with dotted paths. * โœจ feat: Wire config resolution into getAppConfig with override caching Add mergeConfigOverrides utility in data-schemas for deep-merging DB config overrides into base AppConfig by priority order. Update getAppConfig to query DB for applicable configs when role/userId is provided, with short-TTL caching and a hasAnyConfigs feature flag for zero-cost when no DB configs exist. 
Also: add unique compound index on Config schema, pass userId from config middleware, and signal config changes from admin API handlers. * ๐Ÿ”„ refactor: Extract getAppConfig logic into packages/api as TS service Move override resolution, caching strategy, and signalConfigChange from api/server/services/Config/app.js into packages/api/src/app/appConfigService.ts using the DI factory pattern (createAppConfigService). The JS file becomes a thin wiring layer injecting loadBaseConfig, cache, and DB dependencies. * ๐Ÿงน chore: Rename configResolution.ts to resolution.ts * โœจ feat: Move admin types & capabilities to librechat-data-provider Move SystemCapabilities, CapabilityImplications, and utility functions (hasImpliedCapability, expandImplications) from data-schemas to data-provider so they are available to external consumers like the admin panel without a data-schemas dependency. Add API-friendly admin types: TAdminConfig, TAdminSystemGrant, TAdminAuditLogEntry, TAdminGroup, TAdminMember, TAdminUserSearchResult, TCapabilityCategory, and CAPABILITY_CATEGORIES. data-schemas re-exports these from data-provider and extends with config-schema-derived types (ConfigSection, SystemCapability union). Bump version to 0.8.500. * feat: Add JSON-serializable admin config API response types to data-schemas Add AdminConfig, AdminConfigListResponse, AdminConfigResponse, and AdminConfigDeleteResponse types so both LibreChat API handlers and the admin panel can share the same response contract. Bump version to 0.0.41. * refactor: Move admin capabilities & types from data-provider to data-schemas SystemCapabilities, CapabilityImplications, utility functions, CAPABILITY_CATEGORIES, and admin API response types should not be in data-provider as it gets compiled into the frontend bundle, exposing the capability surface. Moved everything to data-schemas (server-only). All consumers already import from @librechat/data-schemas, so no import changes needed elsewhere. 
Consolidated duplicate AdminConfig type (was in both config.ts and admin.ts). * chore: Bump @librechat/data-schemas to 0.0.42 * refactor: Reorganize admin capabilities into admin/ and types/admin.ts Split systemCapabilities.ts following data-schemas conventions: - Types (BaseSystemCapability, SystemCapability, AdminConfig, etc.) โ†’ src/types/admin.ts - Runtime code (SystemCapabilities, CapabilityImplications, utilities) โ†’ src/admin/capabilities.ts Revert data-provider version to 0.8.401 (no longer modified). * chore: Fix import ordering, rename appConfigService to service - Rename app/appConfigService.ts โ†’ app/service.ts (directory provides context) - Fix import order in admin/config.ts, types/admin.ts, types/config.ts - Add naming convention to AGENTS.md * feat: Add DB base config support (role/__base__) - Add BASE_CONFIG_PRINCIPAL_ID constant for reserved base config doc - getApplicableConfigs always includes __base__ in queries - getAppConfig queries DB even without role/userId when DB configs exist - Bump @librechat/data-schemas to 0.0.43 * fix: Address PR review issues for admin config - Add listAllConfigs method; listConfigs endpoint returns all active configs instead of only __base__ - Normalize principalId to string in all config methods to prevent ObjectId vs string mismatch on user/group lookups - Block __proto__ and all dunder-prefixed segments in field path validation to prevent prototype pollution - Fix configVersion off-by-one: default to 0, guard pre('save') with !isNew, use $inc on findOneAndUpdate - Remove unused getApplicableConfigs from admin handler deps * fix: Enable tree-shaking for data-schemas, bump packages - Switch data-schemas Rollup output to preserveModules so each source file becomes its own chunk; consumers (admin panel) can now import just the modules they need without pulling in winston/mongoose/etc. 
- Add sideEffects: false to data-schemas package.json - Bump data-schemas to 0.0.44, data-provider to 0.8.402 * feat: add capabilities subpath export to data-schemas Adds `@librechat/data-schemas/capabilities` subpath export so browser consumers can import BASE_CONFIG_PRINCIPAL_ID and capability constants without pulling in Node.js-only modules (winston, async_hooks, etc.). Bump version to 0.0.45. * fix: include dist/ in data-provider npm package Add explicit files field so npm includes dist/types/ in the published package. Without this, the root .gitignore exclusion of dist/ causes npm to omit type declarations, breaking TypeScript consumers. * chore: bump librechat-data-provider to 0.8.403 * feat: add GET /api/admin/config/base for raw AppConfig Returns the full AppConfig (YAML + DB base merged) so the admin panel can display actual config field values and structure. The startup config endpoint (/api/config) returns TStartupConfig which is a different shape meant for the frontend app. 
* chore: imports order * fix: address code review findings for admin config Critical: - Fix clearAppConfigCache: was deleting from wrong cache store (CONFIG_STORE instead of APP_CONFIG), now clears BASE and HAS_DB_CONFIGS keys - Eliminate race condition: patchConfigField and deleteConfigField now use atomic MongoDB $set/$unset with dot-path notation instead of read-modify-write cycles, removing the lost-update bug entirely - Add patchConfigFields and unsetConfigField atomic DB methods Major: - Reorder cache check before principal resolution in getAppConfig so getUserPrincipals DB query only fires on cache miss - Replace '' as ConfigSection with typed BROAD_CONFIG_ACCESS constant - Parallelize capability checks with Promise.all instead of sequential awaits in for loops - Use loose equality (== null) for cache miss check to handle both null and undefined returns from cache implementations - Set HAS_DB_CONFIGS_KEY to true on successful config fetch Minor: - Remove dead pre('save') hook from config schema (all writes use findOneAndUpdate which bypasses document hooks) - Consolidate duplicate type imports in resolution.ts - Remove dead deepGet/deepSet/deepUnset functions (replaced by atomic ops) - Add .sort({ priority: 1 }) to getApplicableConfigs query - Rename _impliedBy to impliedByMap * fix: self-referencing BROAD_CONFIG_ACCESS constant * fix: replace type-cast sentinel with proper null parameter Update hasConfigCapability to accept ConfigSection | null where null means broad access check (MANAGE_CONFIGS or READ_CONFIGS only). Removes the '' as ConfigSection type lie from admin config handlers. * fix: remaining review findings + add tests - listAllConfigs accepts optional { isActive } filter so admin listing can show inactive configs (#9) - Standardize session application to .session(session ?? 
null) across all config DB methods (#15) - Export isValidFieldPath and getTopLevelSection for testability - Add 38 tests across 3 spec files: - config.spec.ts (api): path validation, prototype pollution rejection - resolution.spec.ts: deep merge, priority ordering, array replacement - config.spec.ts (data-schemas): full CRUD, ObjectId normalization, atomic $set/$unset, configVersion increment, toggle, __base__ query * fix: address second code review findings - Fix cross-user cache contamination: overrideCacheKey now handles userId-without-role case with its own cache key (#1) - Add broad capability check before DB lookup in getConfig to prevent config existence enumeration (#2/#3) - Move deleteConfigField fieldPath from request body to query parameter for proxy/load balancer compatibility (#5) - Derive BaseSystemCapability from SystemCapabilities const instead of manual string union (#6) - Return 201 on upsert creation, 200 on update (#11) - Remove inline narration comments per AGENTS.md (#12) - Type overrides as Partial in DB methods and handler deps (#13) - Replace double as-unknown-as casts in resolution.ts with generic deepMerge (#14) - Make override cache TTL injectable via AppConfigServiceDeps (#16) - Add exhaustive never check in principalModel switch (#17) * fix: remaining review findings โ€” tests, rename, semantics - Rename signalConfigChange โ†’ markConfigsDirty with JSDoc documenting the stale-window tradeoff and overrideCacheTtl knob - Fix DEFAULT_OVERRIDE_CACHE_TTL naming convention - Add createAppConfigService tests (14 cases): cache behavior, feature flag, cross-user key isolation, fallback on error, markConfigsDirty - Add admin handler integration tests (13 cases): auth ordering, 201/200 on create/update, fieldPath from query param, markConfigsDirty calls, capability checks * fix: global flag corruption + empty overrides auth bypass - Remove HAS_DB_CONFIGS_KEY=false optimization: a scoped query returning no configs does not mean no configs exist 
globally. Setting the flag false from a per-principal query short-circuited all subsequent users. - Add broad manage capability check before section checks in upsertConfigOverrides: empty overrides {} no longer bypasses auth. * test: add regression and invariant tests for config system Regression tests: - Bug 1: User A's empty result does not short-circuit User B's overrides - Bug 2: Empty overrides {} returns 403 without MANAGE_CONFIGS Invariant tests (applied across ALL handlers): - All 5 mutation handlers call markConfigsDirty on success - All 5 mutation handlers return 401 without auth - All 5 mutation handlers return 403 without capability - All 3 read handlers return 403 without capability * fix: third review pass โ€” all findings addressed Service (service.ts): - Restore HAS_DB_CONFIGS=false for base-only queries (no role/userId) so deployments with zero DB configs skip DB queries (#1) - Resolve cache once at factory init instead of per-invocation (#8) - Use BASE_CONFIG_PRINCIPAL_ID constant in overrideCacheKey (#10) - Add JSDoc to clearAppConfigCache documenting stale-window (#4) - Fix log message to not say "from YAML" (#14) Admin handlers (config.ts): - Use configVersion===1 for 201 vs 200, eliminating TOCTOU race (#2) - Add Array.isArray guard on overrides body (#5) - Import CapabilityUser from capabilities.ts, remove duplicate (#6) - Replace as-unknown-as cast with targeted type assertion (#7) - Add MAX_PATCH_ENTRIES=100 cap on entries array (#15) - Reorder deleteConfigField to validate principalType first (#12) - Export CapabilityUser from middleware/capabilities.ts DB methods (config.ts): - Remove isActive:true from patchConfigFields to prevent silent reactivation of disabled configs (#3) Schema (config.ts): - Change principalId from Schema.Types.Mixed to String (#11) Tests: - Add patchConfigField unsafe fieldPath rejection test (#9) - Add base-only HAS_DB_CONFIGS=false test (#1) - Update 201/200 tests to use configVersion instead of findConfig (#2) * 
fix: add read handler 401 invariant tests + document flag behavior - Add invariant: all 3 read handlers return 401 without auth - Document on markConfigsDirty that HAS_DB_CONFIGS stays true after all configs are deleted until clearAppConfigCache or restart * fix: remove HAS_DB_CONFIGS false optimization entirely getApplicableConfigs([]) only queries for __base__, not all configs. A deployment with role/group configs but no __base__ doc gets the flag poisoned to false by a base-only query, silently ignoring all scoped overrides. The optimization is not safe without a comprehensive Config.exists() check, which adds its own DB cost. Removed entirely. The flag is now write-once-true (set when configs are found or by markConfigsDirty) and only cleared by clearAppConfigCache/restart. * chore: reorder import statements in app.js for clarity * refactor: remove HAS_DB_CONFIGS_KEY machinery entirely The three-state flag (false/null/true) was the source of multiple bugs across review rounds. Every attempt to safely set it to false was defeated by getApplicableConfigs querying only a subset of principals. Removed: HAS_DB_CONFIGS_KEY constant, all reads/writes of the flag, markConfigsDirty (now a no-op concept), notifyChange wrapper, and all tests that seeded false manually. The per-user/role TTL cache (overrideCacheTtl, default 60s) is the sole caching mechanism. On cache miss, getApplicableConfigs queries the DB. This is one indexed query per user per TTL window โ€” acceptable for the config override use case. * docs: rewrite admin panel remaining work with current state * perf: cache empty override results to avoid repeated DB queries When getApplicableConfigs returns no configs for a principal, cache baseConfig under their override key with TTL. Without this, every user with no per-principal overrides hits MongoDB on every request after the 60s cache window expires. 
* fix: add tenantId to cache keys + reject PUBLIC principal type - Include tenantId in override cache keys to prevent cross-tenant config contamination. Single-tenant deployments (tenantId undefined) use '_' as placeholder โ€” no behavior change for them. - Reject PrincipalType.PUBLIC in admin config validation โ€” PUBLIC has no PrincipalModel and is never resolved by getApplicableConfigs, so config docs for it would be dead data. - Config middleware passes req.user.tenantId to getAppConfig. * fix: fourth review pass findings DB methods (config.ts): - findConfigByPrincipal accepts { includeInactive } option so admin GET can retrieve inactive configs (#5) - upsertConfig catches E11000 duplicate key on concurrent upserts and retries without upsert flag (#2) - unsetConfigField no longer filters isActive:true, consistent with patchConfigFields (#11) - Typed filter objects replace Record (#12) Admin handlers (config.ts): - patchConfigField: serial broad capability check before Promise.all to pre-warm ALS principal cache, preventing N parallel DB calls (#3) - isValidFieldPath rejects leading/trailing dots and consecutive dots (#7) - Duplicate fieldPaths in patch entries return 400 (#8) - DEFAULT_PRIORITY named constant replaces hardcoded 10 (#14) - Admin getConfig and patchConfigField pass includeInactive to findConfigByPrincipal (#5) - Route import uses barrel instead of direct file path (#13) Resolution (resolution.ts): - deepMerge has MAX_MERGE_DEPTH=10 guard to prevent stack overflow from crafted deeply nested configs (#4) * fix: final review cleanup - Remove ADMIN_PANEL_REMAINING.md (local dev notes with Windows paths) - Add empty-result caching regression test - Add tenantId to AdminConfigDeps.getAppConfig type - Restore exhaustive never check in principalModel switch - Standardize toggleConfigActive session handling to options pattern * fix: validate priority in patchConfigField handler Add the same non-negative number validation for priority that 
upsertConfigOverrides already has. Without this, invalid priority values could be stored via PATCH and corrupt merge ordering. * chore: remove planning doc from PR * fix: correct stale cache key strings in service tests * fix: clean up service tests and harden tenant sentinel - Remove no-op cache delete lines from regression tests - Change no-tenant sentinel from '_' to '__default__' to avoid collision with a real tenant ID when multi-tenancy is enabled - Remove unused CONFIG_STORE from AppConfigServiceDeps * chore: bump @librechat/data-schemas to 0.0.46 * fix: block prototype-poisoning keys in deepMerge Skip __proto__, constructor, and prototype keys during config merge to prevent prototype pollution via PUT /api/admin/config overrides. --- .gitattributes | 3 + AGENTS.md | 6 + api/server/index.js | 1 + api/server/middleware/config/app.js | 4 +- api/server/routes/admin/config.js | 39 ++ api/server/routes/index.js | 2 + api/server/services/Config/app.js | 71 +-- packages/api/src/admin/config.handler.spec.ts | 414 ++++++++++++++ packages/api/src/admin/config.spec.ts | 57 ++ packages/api/src/admin/config.ts | 509 ++++++++++++++++++ packages/api/src/admin/index.ts | 2 + packages/api/src/app/index.ts | 1 + packages/api/src/app/service.spec.ts | 244 +++++++++ packages/api/src/app/service.ts | 155 ++++++ packages/api/src/index.ts | 2 + packages/api/src/middleware/capabilities.ts | 9 +- packages/data-provider/package.json | 5 +- packages/data-schemas/package.json | 8 +- packages/data-schemas/rollup.config.js | 10 +- .../data-schemas/src/admin/capabilities.ts | 199 +++++++ packages/data-schemas/src/admin/index.ts | 1 + packages/data-schemas/src/app/index.ts | 1 + .../data-schemas/src/app/resolution.spec.ts | 108 ++++ packages/data-schemas/src/app/resolution.ts | 54 ++ packages/data-schemas/src/index.ts | 2 +- .../data-schemas/src/methods/config.spec.ts | 297 ++++++++++ packages/data-schemas/src/methods/config.ts | 215 ++++++++ packages/data-schemas/src/methods/index.ts | 8 
+- .../src/methods/systemGrant.spec.ts | 4 +- .../data-schemas/src/methods/systemGrant.ts | 4 +- packages/data-schemas/src/models/config.ts | 8 + packages/data-schemas/src/models/index.ts | 2 + packages/data-schemas/src/schema/config.ts | 55 ++ packages/data-schemas/src/schema/index.ts | 1 + .../data-schemas/src/schema/systemGrant.ts | 4 +- .../data-schemas/src/systemCapabilities.ts | 106 ---- packages/data-schemas/src/types/admin.ts | 126 +++++ packages/data-schemas/src/types/config.ts | 36 ++ packages/data-schemas/src/types/index.ts | 4 + .../data-schemas/src/types/systemGrant.ts | 2 +- 40 files changed, 2596 insertions(+), 183 deletions(-) create mode 100644 .gitattributes create mode 100644 api/server/routes/admin/config.js create mode 100644 packages/api/src/admin/config.handler.spec.ts create mode 100644 packages/api/src/admin/config.spec.ts create mode 100644 packages/api/src/admin/config.ts create mode 100644 packages/api/src/admin/index.ts create mode 100644 packages/api/src/app/service.spec.ts create mode 100644 packages/api/src/app/service.ts create mode 100644 packages/data-schemas/src/admin/capabilities.ts create mode 100644 packages/data-schemas/src/admin/index.ts create mode 100644 packages/data-schemas/src/app/resolution.spec.ts create mode 100644 packages/data-schemas/src/app/resolution.ts create mode 100644 packages/data-schemas/src/methods/config.spec.ts create mode 100644 packages/data-schemas/src/methods/config.ts create mode 100644 packages/data-schemas/src/models/config.ts create mode 100644 packages/data-schemas/src/schema/config.ts delete mode 100644 packages/data-schemas/src/systemCapabilities.ts create mode 100644 packages/data-schemas/src/types/admin.ts create mode 100644 packages/data-schemas/src/types/config.ts diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000..725ac8b6bd --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ +# Force LF line endings for shell scripts and git hooks (required for 
cross-platform compatibility) +.husky/* text eol=lf +*.sh text eol=lf diff --git a/AGENTS.md b/AGENTS.md index ec44607aa7..81362cfc57 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -29,6 +29,12 @@ The source code for `@librechat/agents` (major backend dependency, same team) is ## Code Style +### Naming and File Organization + +- **Single-word file names** whenever possible (e.g., `permissions.ts`, `capabilities.ts`, `service.ts`). +- When multiple words are needed, prefer grouping related modules under a **single-word directory** rather than using multi-word file names (e.g., `admin/capabilities.ts` not `adminCapabilities.ts`). +- The directory already provides context โ€” `app/service.ts` not `app/appConfigService.ts`. + ### Structure and Clarity - **Never-nesting**: early returns, flat code, minimal indentation. Break complex operations into well-named helpers. diff --git a/api/server/index.js b/api/server/index.js index ba376ab335..0a8a29f3b7 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -141,6 +141,7 @@ const startServer = async () => { /* API Endpoints */ app.use('/api/auth', routes.auth); app.use('/api/admin', routes.adminAuth); + app.use('/api/admin/config', routes.adminConfig); app.use('/api/actions', routes.actions); app.use('/api/keys', routes.keys); app.use('/api/api-keys', routes.apiKeys); diff --git a/api/server/middleware/config/app.js b/api/server/middleware/config/app.js index bca3c8f71d..fb5f89b229 100644 --- a/api/server/middleware/config/app.js +++ b/api/server/middleware/config/app.js @@ -4,7 +4,9 @@ const { getAppConfig } = require('~/server/services/Config'); const configMiddleware = async (req, res, next) => { try { const userRole = req.user?.role; - req.config = await getAppConfig({ role: userRole }); + const userId = req.user?.id; + const tenantId = req.user?.tenantId; + req.config = await getAppConfig({ role: userRole, userId, tenantId }); next(); } catch (error) { diff --git a/api/server/routes/admin/config.js 
b/api/server/routes/admin/config.js new file mode 100644 index 0000000000..b9407c6b09 --- /dev/null +++ b/api/server/routes/admin/config.js @@ -0,0 +1,39 @@ +const express = require('express'); +const { createAdminConfigHandlers } = require('@librechat/api'); +const { SystemCapabilities } = require('@librechat/data-schemas'); +const { + hasConfigCapability, + requireCapability, +} = require('~/server/middleware/roles/capabilities'); +const { getAppConfig } = require('~/server/services/Config'); +const { requireJwtAuth } = require('~/server/middleware'); +const db = require('~/models'); + +const router = express.Router(); + +const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN); + +const handlers = createAdminConfigHandlers({ + listAllConfigs: db.listAllConfigs, + findConfigByPrincipal: db.findConfigByPrincipal, + upsertConfig: db.upsertConfig, + patchConfigFields: db.patchConfigFields, + unsetConfigField: db.unsetConfigField, + deleteConfig: db.deleteConfig, + toggleConfigActive: db.toggleConfigActive, + hasConfigCapability, + getAppConfig, +}); + +router.use(requireJwtAuth, requireAdminAccess); + +router.get('/', handlers.listConfigs); +router.get('/base', handlers.getBaseConfig); +router.get('/:principalType/:principalId', handlers.getConfig); +router.put('/:principalType/:principalId', handlers.upsertConfigOverrides); +router.patch('/:principalType/:principalId/fields', handlers.patchConfigField); +router.delete('/:principalType/:principalId/fields', handlers.deleteConfigField); +router.delete('/:principalType/:principalId', handlers.deleteConfigOverrides); +router.patch('/:principalType/:principalId/active', handlers.toggleConfig); + +module.exports = router; diff --git a/api/server/routes/index.js b/api/server/routes/index.js index 6a48919db3..b1f16d5e3c 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -2,6 +2,7 @@ const accessPermissions = require('./accessPermissions'); const assistants = 
require('./assistants'); const categories = require('./categories'); const adminAuth = require('./admin/auth'); +const adminConfig = require('./admin/config'); const endpoints = require('./endpoints'); const staticRoute = require('./static'); const messages = require('./messages'); @@ -31,6 +32,7 @@ module.exports = { mcp, auth, adminAuth, + adminConfig, keys, apiKeys, user, diff --git a/api/server/services/Config/app.js b/api/server/services/Config/app.js index 75a5cbe56d..a63bef2124 100644 --- a/api/server/services/Config/app.js +++ b/api/server/services/Config/app.js @@ -1,12 +1,12 @@ const { CacheKeys } = require('librechat-data-provider'); -const { logger, AppService } = require('@librechat/data-schemas'); +const { AppService } = require('@librechat/data-schemas'); +const { createAppConfigService } = require('@librechat/api'); const { loadAndFormatTools } = require('~/server/services/start/tools'); const loadCustomConfig = require('./loadCustomConfig'); const { setCachedTools } = require('./getCachedTools'); const getLogStores = require('~/cache/getLogStores'); const paths = require('~/config/paths'); - -const BASE_CONFIG_KEY = '_BASE_'; +const db = require('~/models'); const loadBaseConfig = async () => { /** @type {TCustomConfig} */ @@ -20,63 +20,14 @@ const loadBaseConfig = async () => { return AppService({ config, paths, systemTools }); }; -/** - * Get the app configuration based on user context - * @param {Object} [options] - * @param {string} [options.role] - User role for role-based config - * @param {boolean} [options.refresh] - Force refresh the cache - * @returns {Promise} - */ -async function getAppConfig(options = {}) { - const { role, refresh } = options; - - const cache = getLogStores(CacheKeys.APP_CONFIG); - const cacheKey = role ? 
role : BASE_CONFIG_KEY; - - if (!refresh) { - const cached = await cache.get(cacheKey); - if (cached) { - return cached; - } - } - - let baseConfig = await cache.get(BASE_CONFIG_KEY); - if (!baseConfig) { - logger.info('[getAppConfig] App configuration not initialized. Initializing AppService...'); - baseConfig = await loadBaseConfig(); - - if (!baseConfig) { - throw new Error('Failed to initialize app configuration through AppService.'); - } - - if (baseConfig.availableTools) { - await setCachedTools(baseConfig.availableTools); - } - - await cache.set(BASE_CONFIG_KEY, baseConfig); - } - - // For now, return the base config - // In the future, this is where we'll apply role-based modifications - if (role) { - // TODO: Apply role-based config modifications - // const roleConfig = await applyRoleBasedConfig(baseConfig, role); - // await cache.set(cacheKey, roleConfig); - // return roleConfig; - } - - return baseConfig; -} - -/** - * Clear the app configuration cache - * @returns {Promise} - */ -async function clearAppConfigCache() { - const cache = getLogStores(CacheKeys.CONFIG_STORE); - const cacheKey = CacheKeys.APP_CONFIG; - return await cache.delete(cacheKey); -} +const { getAppConfig, clearAppConfigCache } = createAppConfigService({ + loadBaseConfig, + setCachedTools, + getCache: getLogStores, + cacheKeys: CacheKeys, + getApplicableConfigs: db.getApplicableConfigs, + getUserPrincipals: db.getUserPrincipals, +}); module.exports = { getAppConfig, diff --git a/packages/api/src/admin/config.handler.spec.ts b/packages/api/src/admin/config.handler.spec.ts new file mode 100644 index 0000000000..705c54babc --- /dev/null +++ b/packages/api/src/admin/config.handler.spec.ts @@ -0,0 +1,414 @@ +import { createAdminConfigHandlers } from './config'; + +function mockReq(overrides = {}) { + return { + user: { id: 'u1', role: 'ADMIN', _id: { toString: () => 'u1' } }, + params: {}, + body: {}, + query: {}, + ...overrides, + }; +} + +function mockRes() { + const res = { + 
statusCode: 200,
    body: undefined,
    // Chainable res.status(code): records the code so tests can assert on it.
    status: jest.fn((code) => {
      res.statusCode = code;
      return res;
    }),
    // Chainable res.json(data): records the payload so tests can assert on it.
    json: jest.fn((data) => {
      res.body = data;
      return res;
    }),
  };
  return res;
}

/**
 * Builds a handler set wired to fully-mocked AdminConfigDeps.
 * Every dep resolves to a permissive default; pass `overrides` to replace
 * individual mocks for a specific test case.
 */
function createHandlers(overrides = {}) {
  const deps = {
    listAllConfigs: jest.fn().mockResolvedValue([]),
    findConfigByPrincipal: jest.fn().mockResolvedValue(null),
    upsertConfig: jest.fn().mockResolvedValue({
      _id: 'c1',
      principalType: 'role',
      principalId: 'admin',
      overrides: {},
      configVersion: 1,
    }),
    patchConfigFields: jest
      .fn()
      .mockResolvedValue({ _id: 'c1', overrides: { interface: { endpointsMenu: false } } }),
    unsetConfigField: jest.fn().mockResolvedValue({ _id: 'c1', overrides: {} }),
    deleteConfig: jest.fn().mockResolvedValue({ _id: 'c1' }),
    toggleConfigActive: jest.fn().mockResolvedValue({ _id: 'c1', isActive: false }),
    // Capability check defaults to "allowed" so authz-failure tests override it.
    hasConfigCapability: jest.fn().mockResolvedValue(true),

    getAppConfig: jest.fn().mockResolvedValue({ interface: { endpointsMenu: true } }),
    ...overrides,
  };
  const handlers = createAdminConfigHandlers(deps);
  return { handlers, deps };
}

describe('createAdminConfigHandlers', () => {
  describe('getConfig', () => {
    it('returns 403 before DB lookup when user lacks READ_CONFIGS', async () => {
      const { handlers, deps } = createHandlers({
        hasConfigCapability: jest.fn().mockResolvedValue(false),
      });
      const req = mockReq({ params: { principalType: 'role', principalId: 'admin' } });
      const res = mockRes();

      await handlers.getConfig(req, res);

      expect(res.statusCode).toBe(403);
      // Authorization must short-circuit: no DB access on denial.
      expect(deps.findConfigByPrincipal).not.toHaveBeenCalled();
    });

    it('returns 404 when config does not exist', async () => {
      const { handlers } = createHandlers();
      const req = mockReq({ params: { principalType: 'role', principalId: 'nonexistent' } });
      const res = mockRes();

      await handlers.getConfig(req, res);

      expect(res.statusCode).toBe(404);
    });

    it('returns config when authorized and exists', async () => {
      const config = {
        _id: 'c1',
        principalType: 'role',
        principalId: 'admin',
        overrides: { x: 1 },
      };
      const { handlers } = createHandlers({
        findConfigByPrincipal: jest.fn().mockResolvedValue(config),
      });
      const req = mockReq({ params: { principalType: 'role', principalId: 'admin' } });
      const res = mockRes();

      await handlers.getConfig(req, res);

      expect(res.statusCode).toBe(200);
      expect(res.body.config).toEqual(config);
    });

    it('returns 400 for invalid principalType', async () => {
      const { handlers } = createHandlers();
      const req = mockReq({ params: { principalType: 'invalid', principalId: 'x' } });
      const res = mockRes();

      await handlers.getConfig(req, res);

      expect(res.statusCode).toBe(400);
    });

    it('rejects public principalType — not usable for config overrides', async () => {
      const { handlers } = createHandlers();
      const req = mockReq({ params: { principalType: 'public', principalId: 'x' } });
      const res = mockRes();

      await handlers.getConfig(req, res);

      expect(res.statusCode).toBe(400);
    });
  });

  describe('upsertConfigOverrides', () => {
    it('returns 201 when creating a new config (configVersion === 1)', async () => {
      const { handlers } = createHandlers({
        upsertConfig: jest.fn().mockResolvedValue({ _id: 'c1', configVersion: 1 }),
      });
      const req = mockReq({
        params: { principalType: 'role', principalId: 'admin' },
        body: { overrides: { interface: { endpointsMenu: false } } },
      });
      const res = mockRes();

      await handlers.upsertConfigOverrides(req, res);

      expect(res.statusCode).toBe(201);
    });

    it('returns 200 when updating an existing config (configVersion > 1)', async () => {
      const { handlers } = createHandlers({
        upsertConfig: jest.fn().mockResolvedValue({ _id: 'c1', configVersion: 5 }),
      });
      const req = mockReq({
        params: { principalType: 'role', principalId: 'admin' },
        body: { overrides: { interface: { endpointsMenu: false } } },
      });
      const res = mockRes();

      await handlers.upsertConfigOverrides(req, res);

      expect(res.statusCode).toBe(200);
    });

    it('returns 400 when overrides is missing', async () => {
      const { handlers } = createHandlers();
      const req = mockReq({
        params: { principalType: 'role', principalId: 'admin' },
        body: {},
      });
      const res = mockRes();

      await handlers.upsertConfigOverrides(req, res);

      expect(res.statusCode).toBe(400);
    });
  });

  describe('deleteConfigField', () => {
    it('reads fieldPath from query parameter', async () => {
      const { handlers, deps } = createHandlers();
      const req = mockReq({
        params: { principalType: 'role', principalId: 'admin' },
        query: { fieldPath: 'interface.endpointsMenu' },
      });
      const res = mockRes();

      await handlers.deleteConfigField(req, res);

      expect(deps.unsetConfigField).toHaveBeenCalledWith(
        'role',
        'admin',
        'interface.endpointsMenu',
      );
    });

    it('returns 400 when fieldPath query param is missing', async () => {
      const { handlers } = createHandlers();
      const req = mockReq({
        params: { principalType: 'role', principalId: 'admin' },
        query: {},
      });
      const res = mockRes();

      await handlers.deleteConfigField(req, res);

      expect(res.statusCode).toBe(400);
      expect(res.body.error).toContain('query parameter');
    });

    it('rejects unsafe field paths', async () => {
      const { handlers } = createHandlers();
      const req = mockReq({
        params: { principalType: 'role', principalId: 'admin' },
        query: { fieldPath: '__proto__.polluted' },
      });
      const res = mockRes();

      await handlers.deleteConfigField(req, res);

      expect(res.statusCode).toBe(400);
    });
  });

  describe('patchConfigField', () => {
    it('returns 403 when user lacks capability for section', async () => {
      const { handlers } = createHandlers({
        hasConfigCapability: jest.fn().mockResolvedValue(false),
      });
      const req = mockReq({
        params: { principalType: 'role', principalId: 'admin' },
        body: { entries: [{ fieldPath: 'interface.endpointsMenu', value: false }] },
      });
      const res = mockRes();

      await handlers.patchConfigField(req, res);

      expect(res.statusCode).toBe(403);
    });

    it('rejects entries with unsafe field paths (prototype pollution)', async () => {
      const { handlers } = createHandlers();
      const req = mockReq({
        params: { principalType: 'role', principalId: 'admin' },
        body: { entries: [{ fieldPath: '__proto__.polluted', value: true }] },
      });
      const res = mockRes();

      await handlers.patchConfigField(req, res);

      expect(res.statusCode).toBe(400);
    });
  });

  describe('upsertConfigOverrides — Bug 2 regression', () => {
    // Empty overrides have no sections, so a per-section-only check would
    // vacuously pass; the handler must enforce global MANAGE_CONFIGS first.
    it('returns 403 for empty overrides when user lacks MANAGE_CONFIGS', async () => {
      const { handlers } = createHandlers({
        hasConfigCapability: jest.fn().mockResolvedValue(false),
      });
      const req = mockReq({
        params: { principalType: 'role', principalId: 'admin' },
        body: { overrides: {} },
      });
      const res = mockRes();

      await handlers.upsertConfigOverrides(req, res);

      expect(res.statusCode).toBe(403);
    });
  });

  // ── Invariant tests: rules that must hold across ALL handlers ──────

  // NOTE(review): the Record/Array generic parameters below were lost in
  // transit and are reconstructed — confirm against the original file.
  const MUTATION_HANDLERS: Array<{
    name: string;
    reqOverrides: Record<string, unknown>;
  }> = [
    {
      name: 'upsertConfigOverrides',
      reqOverrides: {
        params: { principalType: 'role', principalId: 'admin' },
        body: { overrides: { interface: { endpointsMenu: false } } },
      },
    },
    {
      name: 'patchConfigField',
      reqOverrides: {
        params: { principalType: 'role', principalId: 'admin' },
        body: { entries: [{ fieldPath: 'interface.endpointsMenu', value: false }] },
      },
    },
    {
      name: 'deleteConfigField',
      reqOverrides: {
        params: { principalType: 'role', principalId: 'admin' },
        query: { fieldPath: 'interface.endpointsMenu' },
      },
    },
    {
      name: 'deleteConfigOverrides',
      reqOverrides: {
        params: { principalType: 'role', principalId: 'admin' },
      },
    },
    {
      name: 'toggleConfig',
      reqOverrides: {
        params: { principalType: 'role', principalId: 'admin' },
        body: { isActive: false },
      },
    },
  ];

  describe('invariant: all mutation handlers return 401 without auth', () => {
    for (const { name, reqOverrides } of MUTATION_HANDLERS) {
      it(`${name} returns 401 when user is missing`, async () => {
        const { handlers } = createHandlers();
        const req = mockReq({ ...reqOverrides, user: undefined });
        const res = mockRes();

        await (handlers as Record<string, (req: unknown, res: unknown) => Promise<unknown>>)[name](
          req,
          res,
        );

        expect(res.statusCode).toBe(401);
      });
    }
  });

  describe('invariant: all mutation handlers return 403 without capability', () => {
    for (const { name, reqOverrides } of MUTATION_HANDLERS) {
      it(`${name} returns 403 when user lacks capability`, async () => {
        const { handlers } = createHandlers({
          hasConfigCapability: jest.fn().mockResolvedValue(false),
        });
        const req = mockReq(reqOverrides);
        const res = mockRes();

        await (handlers as Record<string, (req: unknown, res: unknown) => Promise<unknown>>)[name](
          req,
          res,
        );

        expect(res.statusCode).toBe(403);
      });
    }
  });

  describe('invariant: all read handlers return 403 without capability', () => {
    const READ_HANDLERS: Array<{ name: string; reqOverrides: Record<string, unknown> }> = [
      { name: 'listConfigs', reqOverrides: {} },
      { name: 'getBaseConfig', reqOverrides: {} },
      {
        name: 'getConfig',
        reqOverrides: { params: { principalType: 'role', principalId: 'admin' } },
      },
    ];

    for (const { name, reqOverrides } of READ_HANDLERS) {
      it(`${name} returns 403 when user lacks capability`, async () => {
        const { handlers } = createHandlers({
          hasConfigCapability: jest.fn().mockResolvedValue(false),
        });
        const req = mockReq(reqOverrides);
        const res = mockRes();

        await (handlers as Record<string, (req: unknown, res: unknown) => Promise<unknown>>)[name](
          req,
          res,
        );

        expect(res.statusCode).toBe(403);
      });
    }
  });

  describe('invariant: all read handlers return 401 without auth', () => {
    const READ_HANDLERS: Array<{ name: string; reqOverrides: Record<string, unknown> }> = [
      { name: 'listConfigs', reqOverrides: {} },
      { name: 'getBaseConfig', reqOverrides: {} },
      {
        name: 'getConfig',
        reqOverrides: { params: { principalType: 'role', principalId: 'admin' } },
      },
    ];

    for (const { name, reqOverrides } of READ_HANDLERS) {
      it(`${name} returns 401 when user is missing`, async () => {
        const { handlers } = createHandlers();
        const req = mockReq({ ...reqOverrides, user: undefined });
        const res = mockRes();

        await (handlers as Record<string, (req: unknown, res: unknown) => Promise<unknown>>)[name](
          req,
          res,
        );

        expect(res.statusCode).toBe(401);
      });
    }
  });

  describe('getBaseConfig', () => {
    it('returns 403 when user lacks READ_CONFIGS', async () => {
      const { handlers } = createHandlers({
        hasConfigCapability: jest.fn().mockResolvedValue(false),
      });
      const req = mockReq();
      const res = mockRes();

      await handlers.getBaseConfig(req, res);

      expect(res.statusCode).toBe(403);
    });

    it('returns the full AppConfig', async () => {
      const { handlers } = createHandlers();
      const req = mockReq();
      const res = mockRes();

      await handlers.getBaseConfig(req, res);

      expect(res.statusCode).toBe(200);
      expect(res.body.config).toEqual({ interface: { endpointsMenu: true } });
    });
  });
});
diff --git a/packages/api/src/admin/config.spec.ts b/packages/api/src/admin/config.spec.ts
new file mode 100644
index 0000000000..499cfaa35b
--- /dev/null
+++ b/packages/api/src/admin/config.spec.ts
@@ -0,0 +1,57 @@
import { isValidFieldPath, getTopLevelSection } from './config';

describe('isValidFieldPath', () => {
  it('accepts simple dot paths', () => {
    expect(isValidFieldPath('interface.endpointsMenu')).toBe(true);
    expect(isValidFieldPath('registration.socialLogins')).toBe(true);
    expect(isValidFieldPath('a')).toBe(true);
    expect(isValidFieldPath('a.b.c.d')).toBe(true);
  });

  it('rejects empty and non-string', () => {
    expect(isValidFieldPath('')).toBe(false);
    // @ts-expect-error testing invalid input
    expect(isValidFieldPath(undefined)).toBe(false);
    // @ts-expect-error testing invalid input
    expect(isValidFieldPath(null)).toBe(false);
    // @ts-expect-error testing invalid input
    expect(isValidFieldPath(42)).toBe(false);
  });

  it('rejects __proto__ and dunder-prefixed segments', () => {
    expect(isValidFieldPath('__proto__')).toBe(false);
    expect(isValidFieldPath('a.__proto__')).toBe(false);
    expect(isValidFieldPath('__proto__.polluted')).toBe(false);
    expect(isValidFieldPath('a.__proto__.b')).toBe(false);
    expect(isValidFieldPath('__defineGetter__')).toBe(false);
    expect(isValidFieldPath('a.__lookupSetter__')).toBe(false);
    expect(isValidFieldPath('__')).toBe(false);
    expect(isValidFieldPath('a.__.b')).toBe(false);
  });

  it('rejects constructor and prototype segments', () => {
    expect(isValidFieldPath('constructor')).toBe(false);
    expect(isValidFieldPath('a.constructor')).toBe(false);
    expect(isValidFieldPath('constructor.a')).toBe(false);
    expect(isValidFieldPath('prototype')).toBe(false);
    expect(isValidFieldPath('a.prototype')).toBe(false);
    expect(isValidFieldPath('prototype.a')).toBe(false);
  });

  it('allows segments containing but not matching reserved words', () => {
    expect(isValidFieldPath('constructorName')).toBe(true);
    expect(isValidFieldPath('prototypeChain')).toBe(true);
    expect(isValidFieldPath('a.myConstructor')).toBe(true);
  });
});

describe('getTopLevelSection', () => {
  it('returns first segment of a dot path', () => {
    expect(getTopLevelSection('interface.endpointsMenu')).toBe('interface');
    expect(getTopLevelSection('registration.socialLogins.github')).toBe('registration');
  });

  it('returns the whole string when no dots', () => {
    expect(getTopLevelSection('interface')).toBe('interface');
  });
});
diff --git a/packages/api/src/admin/config.ts b/packages/api/src/admin/config.ts
new file mode 100644
index 0000000000..0a1afd5388
--- /dev/null
+++ b/packages/api/src/admin/config.ts
@@ -0,0 +1,509 @@
import { logger } from '@librechat/data-schemas';
import { PrincipalType, PrincipalModel } from 'librechat-data-provider';
import type { TCustomConfig } from 'librechat-data-provider';
import type { AppConfig, ConfigSection, IConfig } from '@librechat/data-schemas';
import type { Types, ClientSession } from 'mongoose';
import type { Response } from 'express';
import type { CapabilityUser } from '~/middleware/capabilities';
import type { ServerRequest } from '~/types/http';

// Rejects any dot-path segment that could reach Object.prototype:
// dunder-prefixed names (`__*`), `constructor`, or `prototype`.
const UNSAFE_SEGMENTS = /(?:^|\.)(__[\w]*|constructor|prototype)(?:\.|$)/;
// Upper bound on entries accepted by a single PATCH /fields request.
const MAX_PATCH_ENTRIES = 100;
// Priority assigned when the caller supplies none and no existing value applies.
const DEFAULT_PRIORITY = 10;

/**
 * True when `path` is a well-formed, safe dot-path: non-empty string with no
 * leading/trailing/double dots and no prototype-polluting segments.
 */
export function isValidFieldPath(path: string): boolean {
  return (
    typeof path === 'string' &&
    path.length > 0 &&
    !path.startsWith('.') &&
    !path.endsWith('.') &&
    !path.includes('..') &&
    !UNSAFE_SEGMENTS.test(path)
  );
}

/** Returns the first segment of a dot-path (the top-level config section). */
export function getTopLevelSection(fieldPath: string): string {
  return fieldPath.split('.')[0];
}

/**
 * DB/config dependencies injected into the admin-config handler factory.
 *
 * NOTE(review): the generic parameters on Promise/Record return types below
 * were lost in transit; they are reconstructed from model-layer usage —
 * confirm against the data-schemas Config method signatures.
 */
export interface AdminConfigDeps {
  listAllConfigs: (filter?: { isActive?: boolean }, session?: ClientSession) => Promise<IConfig[]>;
  findConfigByPrincipal: (
    principalType: PrincipalType,
    principalId: string | Types.ObjectId,
    options?: { includeInactive?: boolean },
    session?: ClientSession,
  ) => Promise<IConfig | null>;
  upsertConfig: (
    principalType: PrincipalType,
    principalId: string | Types.ObjectId,
    principalModel: PrincipalModel,
    overrides: Partial<TCustomConfig>,
    priority: number,
    session?: ClientSession,
  ) => Promise<IConfig>;
  patchConfigFields: (
    principalType: PrincipalType,
    principalId: string | Types.ObjectId,
    principalModel: PrincipalModel,
    fields: Record<string, unknown>,
    priority: number,
    session?: ClientSession,
  ) => Promise<IConfig>;
  unsetConfigField: (
    principalType: PrincipalType,
    principalId: string | Types.ObjectId,
    fieldPath: string,
    session?: ClientSession,
  ) => Promise<IConfig | null>;
  deleteConfig: (
    principalType: PrincipalType,
    principalId: string | Types.ObjectId,
    session?: ClientSession,
  ) => Promise<IConfig | null>;
  toggleConfigActive: (
    principalType: PrincipalType,
    principalId: string | Types.ObjectId,
    isActive: boolean,
    session?: ClientSession,
  ) => Promise<IConfig | null>;
  hasConfigCapability: (
    user: CapabilityUser,
    section: ConfigSection | null,
    verb?: 'manage' | 'read',
  ) => Promise<boolean>;
  getAppConfig?: (options?: {
    role?: string;
    userId?: string;
    tenantId?: string;
  }) => Promise<AppConfig>;
}

// ── Validation helpers ───────────────────────────────────────────────

// Only user/group/role principals may carry config overrides; PUBLIC is excluded.
const CONFIG_PRINCIPAL_TYPES = new Set([
  PrincipalType.USER,
  PrincipalType.GROUP,
  PrincipalType.ROLE,
]);

function validatePrincipalType(value: string): value is PrincipalType {
  return CONFIG_PRINCIPAL_TYPES.has(value as PrincipalType);
}

/** Maps a principal type to the mongoose model used for its principalId ref. */
function principalModel(type: PrincipalType): PrincipalModel {
  switch (type) {
    case PrincipalType.USER:
      return PrincipalModel.USER;
    case PrincipalType.GROUP:
      return PrincipalModel.GROUP;
    case PrincipalType.ROLE:
      return PrincipalModel.ROLE;
    case PrincipalType.PUBLIC:
      // PUBLIC is rejected by validatePrincipalType before this is reached;
      // the arm exists only to keep the switch exhaustive.
      return PrincipalModel.ROLE;
    default: {
      const _exhaustive: never = type;
      logger.warn(`[adminConfig] Unmapped PrincipalType: ${String(_exhaustive)}`);
      return PrincipalModel.ROLE;
    }
  }
}

/** Extracts the minimal user shape needed for capability checks, or null if unauthenticated. */
function getCapabilityUser(req: ServerRequest): CapabilityUser | null {
  if (!req.user) {
    return null;
  }
  return {
    id: req.user.id ?? req.user._id?.toString() ?? '',
    role: req.user.role ?? '',
    tenantId: (req.user as { tenantId?: string }).tenantId,
  };
}

// ── Handler factory ──────────────────────────────────────────────────

export function createAdminConfigHandlers(deps: AdminConfigDeps) {
  const {
    listAllConfigs,
    findConfigByPrincipal,
    upsertConfig,
    patchConfigFields,
    unsetConfigField,
    deleteConfig,
    toggleConfigActive,
    hasConfigCapability,
    getAppConfig,
  } = deps;

  /**
   * GET / — List all active config overrides.
   */
  async function listConfigs(req: ServerRequest, res: Response) {
    try {
      const user = getCapabilityUser(req);
      if (!user) {
        return res.status(401).json({ error: 'Authentication required' });
      }

      if (!(await hasConfigCapability(user, null, 'read'))) {
        return res.status(403).json({ error: 'Insufficient permissions' });
      }

      const configs = await listAllConfigs();
      return res.status(200).json({ configs });
    } catch (error) {
      logger.error('[adminConfig] listConfigs error:', error);
      return res.status(500).json({ error: 'Failed to list configs' });
    }
  }

  /**
   * GET /base — Return the raw AppConfig (YAML + DB base merged).
   * This is the full config structure admins can edit, NOT the startup payload.
   */
  async function getBaseConfig(req: ServerRequest, res: Response) {
    try {
      const user = getCapabilityUser(req);
      if (!user) {
        return res.status(401).json({ error: 'Authentication required' });
      }

      if (!(await hasConfigCapability(user, null, 'read'))) {
        return res.status(403).json({ error: 'Insufficient permissions' });
      }

      // getAppConfig is an optional dependency; without it this endpoint is disabled.
      if (!getAppConfig) {
        return res.status(501).json({ error: 'Base config endpoint not configured' });
      }

      const appConfig = await getAppConfig();
      return res.status(200).json({ config: appConfig });
    } catch (error) {
      logger.error('[adminConfig] getBaseConfig error:', error);
      return res.status(500).json({ error: 'Failed to get base config' });
    }
  }

  /**
   * GET /:principalType/:principalId — Get config for a specific principal.
   */
  async function getConfig(req: ServerRequest, res: Response) {
    try {
      const { principalType, principalId } = req.params as {
        principalType: string;
        principalId: string;
      };

      if (!validatePrincipalType(principalType)) {
        return res.status(400).json({ error: `Invalid principalType: ${principalType}` });
      }

      const user = getCapabilityUser(req);
      if (!user) {
        return res.status(401).json({ error: 'Authentication required' });
      }

      // Capability is checked before any DB access (tests pin this ordering).
      if (!(await hasConfigCapability(user, null, 'read'))) {
        return res.status(403).json({ error: 'Insufficient permissions' });
      }

      // includeInactive: admins may inspect disabled overrides too.
      const config = await findConfigByPrincipal(principalType, principalId, {
        includeInactive: true,
      });
      if (!config) {
        return res.status(404).json({ error: 'Config not found' });
      }

      return res.status(200).json({ config });
    } catch (error) {
      logger.error('[adminConfig] getConfig error:', error);
      return res.status(500).json({ error: 'Failed to get config' });
    }
  }

  /**
   * PUT /:principalType/:principalId — Replace entire overrides for a principal.
   */
  async function upsertConfigOverrides(req: ServerRequest, res: Response) {
    try {
      const { principalType, principalId } = req.params as {
        principalType: string;
        principalId: string;
      };

      if (!validatePrincipalType(principalType)) {
        return res.status(400).json({ error: `Invalid principalType: ${principalType}` });
      }

      const { overrides, priority } = req.body as {
        overrides?: Partial<TCustomConfig>;
        priority?: number;
      };

      if (!overrides || typeof overrides !== 'object' || Array.isArray(overrides)) {
        return res.status(400).json({ error: 'overrides must be a plain object' });
      }

      if (priority != null && (typeof priority !== 'number' || priority < 0)) {
        return res.status(400).json({ error: 'priority must be a non-negative number' });
      }

      const user = getCapabilityUser(req);
      if (!user) {
        return res.status(401).json({ error: 'Authentication required' });
      }

      // Global MANAGE is required first — even for empty overrides, whose
      // section list is empty and would vacuously pass the per-section loop
      // below (see the "Bug 2" regression test in the spec).
      if (!(await hasConfigCapability(user, null, 'manage'))) {
        return res.status(403).json({ error: 'Insufficient permissions' });
      }

      // Each touched top-level section is additionally re-checked.
      const overrideSections = Object.keys(overrides);
      if (overrideSections.length > 0) {
        const allowed = await Promise.all(
          overrideSections.map((s) => hasConfigCapability(user, s as ConfigSection, 'manage')),
        );
        const denied = overrideSections.find((_, i) => !allowed[i]);
        if (denied) {
          return res.status(403).json({
            error: `Insufficient permissions for config section: ${denied}`,
          });
        }
      }

      const config = await upsertConfig(
        principalType,
        principalId,
        principalModel(principalType),
        overrides,
        priority ?? DEFAULT_PRIORITY,
      );

      // configVersion === 1 means the upsert created the document → 201.
      return res.status(config?.configVersion === 1 ? 201 : 200).json({ config });
    } catch (error) {
      logger.error('[adminConfig] upsertConfigOverrides error:', error);
      return res.status(500).json({ error: 'Failed to upsert config' });
    }
  }

  /**
   * PATCH /:principalType/:principalId/fields — Set individual fields via dot-paths.
   */
  async function patchConfigField(req: ServerRequest, res: Response) {
    try {
      const { principalType, principalId } = req.params as {
        principalType: string;
        principalId: string;
      };

      if (!validatePrincipalType(principalType)) {
        return res.status(400).json({ error: `Invalid principalType: ${principalType}` });
      }

      const { entries, priority } = req.body as {
        entries?: Array<{ fieldPath: string; value: unknown }>;
        priority?: number;
      };

      if (priority != null && (typeof priority !== 'number' || priority < 0)) {
        return res.status(400).json({ error: 'priority must be a non-negative number' });
      }

      if (!Array.isArray(entries) || entries.length === 0) {
        return res.status(400).json({ error: 'entries array is required and must not be empty' });
      }

      if (entries.length > MAX_PATCH_ENTRIES) {
        return res
          .status(400)
          .json({ error: `entries array exceeds maximum of ${MAX_PATCH_ENTRIES}` });
      }

      // Reject prototype-polluting or malformed dot-paths before any authz/DB work.
      for (const entry of entries) {
        if (!isValidFieldPath(entry.fieldPath)) {
          return res
            .status(400)
            .json({ error: `Invalid or unsafe field path: ${entry.fieldPath}` });
        }
      }

      const user = getCapabilityUser(req);
      if (!user) {
        return res.status(401).json({ error: 'Authentication required' });
      }

      // Global MANAGE short-circuits; otherwise every touched top-level
      // section must be individually granted. This OR-pattern is safe here
      // because `entries` is guaranteed non-empty above (unlike the upsert
      // handler, whose section list can be empty).
      if (!(await hasConfigCapability(user, null, 'manage'))) {
        const sections = [...new Set(entries.map((e) => getTopLevelSection(e.fieldPath)))];
        const allowed = await Promise.all(
          sections.map((s) => hasConfigCapability(user, s as ConfigSection, 'manage')),
        );
        const denied = sections.find((_, i) => !allowed[i]);
        if (denied) {
          return res.status(403).json({
            error: `Insufficient permissions for config section: ${denied}`,
          });
        }
      }

      // Build the dot-path → value map, rejecting duplicate paths.
      const seen = new Set();
      const fields: Record<string, unknown> = {};
      for (const entry of entries) {
        if (seen.has(entry.fieldPath)) {
          return res.status(400).json({ error: `Duplicate fieldPath: ${entry.fieldPath}` });
        }
        seen.add(entry.fieldPath);
        fields[entry.fieldPath] = entry.value;
      }

      // When no explicit priority is supplied, preserve the existing config's
      // priority rather than silently resetting it to the default.
      const existing =
        priority == null
          ? await findConfigByPrincipal(principalType, principalId, { includeInactive: true })
          : null;

      const config = await patchConfigFields(
        principalType,
        principalId,
        principalModel(principalType),
        fields,
        priority ?? existing?.priority ?? DEFAULT_PRIORITY,
      );

      return res.status(200).json({ config });
    } catch (error) {
      logger.error('[adminConfig] patchConfigField error:', error);
      return res.status(500).json({ error: 'Failed to patch config fields' });
    }
  }

  /**
   * DELETE /:principalType/:principalId/fields?fieldPath=dotted.path
   */
  async function deleteConfigField(req: ServerRequest, res: Response) {
    try {
      const { principalType, principalId } = req.params as {
        principalType: string;
        principalId: string;
      };
      if (!validatePrincipalType(principalType)) {
        return res.status(400).json({ error: `Invalid principalType: ${principalType}` });
      }

      // fieldPath travels as a query parameter (dotted paths in URL params
      // are not supported by the router — see commit history).
      const fieldPath = req.query.fieldPath as string | undefined;

      if (!fieldPath || typeof fieldPath !== 'string') {
        return res.status(400).json({ error: 'fieldPath query parameter is required' });
      }

      if (!isValidFieldPath(fieldPath)) {
        return res.status(400).json({ error: `Invalid or unsafe field path: ${fieldPath}` });
      }

      const user = getCapabilityUser(req);
      if (!user) {
        return res.status(401).json({ error: 'Authentication required' });
      }

      // Section-scoped check: deleting one field only needs MANAGE on that
      // field's top-level section.
      const section = getTopLevelSection(fieldPath);
      if (!(await hasConfigCapability(user, section as ConfigSection, 'manage'))) {
        return res.status(403).json({
          error: `Insufficient permissions for config section: ${section}`,
        });
      }

      const config = await unsetConfigField(principalType, principalId, fieldPath);
      if (!config) {
        return res.status(404).json({ error: 'Config not found' });
      }

      return res.status(200).json({ config });
    } catch (error) {
      logger.error('[adminConfig] deleteConfigField error:', error);
      return res.status(500).json({ error: 'Failed to delete config field' });
    }
  }

  /**
   * DELETE /:principalType/:principalId — Delete an entire config override.
   */
  async function deleteConfigOverrides(req: ServerRequest, res: Response) {
    try {
      const { principalType, principalId } = req.params as {
        principalType: string;
        principalId: string;
      };

      if (!validatePrincipalType(principalType)) {
        return res.status(400).json({ error: `Invalid principalType: ${principalType}` });
      }

      const user = getCapabilityUser(req);
      if (!user) {
        return res.status(401).json({ error: 'Authentication required' });
      }

      // Whole-document deletion requires global MANAGE (all sections affected).
      if (!(await hasConfigCapability(user, null, 'manage'))) {
        return res.status(403).json({ error: 'Insufficient permissions' });
      }

      const config = await deleteConfig(principalType, principalId);
      if (!config) {
        return res.status(404).json({ error: 'Config not found' });
      }

      return res.status(200).json({ success: true });
    } catch (error) {
      logger.error('[adminConfig] deleteConfigOverrides error:', error);
      return res.status(500).json({ error: 'Failed to delete config' });
    }
  }

  /**
   * PATCH /:principalType/:principalId/active — Toggle isActive.
   */
  async function toggleConfig(req: ServerRequest, res: Response) {
    try {
      const { principalType, principalId } = req.params as {
        principalType: string;
        principalId: string;
      };

      if (!validatePrincipalType(principalType)) {
        return res.status(400).json({ error: `Invalid principalType: ${principalType}` });
      }

      const { isActive } = req.body as { isActive?: boolean };
      if (typeof isActive !== 'boolean') {
        return res.status(400).json({ error: 'isActive boolean is required' });
      }

      const user = getCapabilityUser(req);
      if (!user) {
        return res.status(401).json({ error: 'Authentication required' });
      }

      // Toggling affects the entire override document → global MANAGE required.
      if (!(await hasConfigCapability(user, null, 'manage'))) {
        return res.status(403).json({ error: 'Insufficient permissions' });
      }

      const config = await toggleConfigActive(principalType, principalId, isActive);
      if (!config) {
        return res.status(404).json({ error: 'Config not found' });
      }

      return res.status(200).json({ config });
    } catch (error) {
      logger.error('[adminConfig] toggleConfig error:', error);
      return res.status(500).json({ error: 'Failed to toggle config' });
    }
  }

  return {
    listConfigs,
    getBaseConfig,
    getConfig,
    upsertConfigOverrides,
    patchConfigField,
    deleteConfigField,
    deleteConfigOverrides,
    toggleConfig,
  };
}
diff --git a/packages/api/src/admin/index.ts b/packages/api/src/admin/index.ts
new file mode 100644
index 0000000000..bf48ce7345
--- /dev/null
+++ b/packages/api/src/admin/index.ts
@@ -0,0 +1,2 @@
export { createAdminConfigHandlers } from './config';
export type { AdminConfigDeps } from './config';
diff --git a/packages/api/src/app/index.ts b/packages/api/src/app/index.ts
index b95193e943..7acb75e09d 100644
--- a/packages/api/src/app/index.ts
+++ b/packages/api/src/app/index.ts
@@ -1,3 +1,4 @@
+export * from './service';
 export * from './config';
 export * from './permissions';
 export * from './cdn';
diff --git a/packages/api/src/app/service.spec.ts 
b/packages/api/src/app/service.spec.ts
new file mode 100644
index 0000000000..2dfba09e25
--- /dev/null
+++ b/packages/api/src/app/service.spec.ts
@@ -0,0 +1,244 @@
import { createAppConfigService } from './service';

/** In-memory cache double; exposes its backing store for key assertions. */
function createMockCache() {
  const store = new Map();
  return {
    get: jest.fn((key) => Promise.resolve(store.get(key))),
    set: jest.fn((key, value) => {
      store.set(key, value);
      return Promise.resolve(undefined);
    }),
    delete: jest.fn((key) => {
      store.delete(key);
      return Promise.resolve(true);
    }),
    _store: store,
  };
}

/**
 * Builds fully-mocked service deps; `_cache` and `_baseConfig` are
 * test-only handles onto the injected doubles.
 */
function createDeps(overrides = {}) {
  const cache = createMockCache();
  const baseConfig = { interface: { endpointsMenu: true }, endpoints: ['openAI'] };

  return {
    loadBaseConfig: jest.fn().mockResolvedValue(baseConfig),
    setCachedTools: jest.fn().mockResolvedValue(undefined),
    getCache: jest.fn().mockReturnValue(cache),
    cacheKeys: { APP_CONFIG: 'app_config' },
    getApplicableConfigs: jest.fn().mockResolvedValue([]),
    getUserPrincipals: jest.fn().mockResolvedValue([
      { principalType: 'role', principalId: 'USER' },
      { principalType: 'user', principalId: 'uid1' },
    ]),
    _cache: cache,
    _baseConfig: baseConfig,
    ...overrides,
  };
}

describe('createAppConfigService', () => {
  describe('getAppConfig', () => {
    it('loads base config on first call', async () => {
      const deps = createDeps();
      const { getAppConfig } = createAppConfigService(deps);

      const config = await getAppConfig();

      expect(deps.loadBaseConfig).toHaveBeenCalledTimes(1);
      expect(config).toEqual(deps._baseConfig);
    });

    it('caches base config — does not reload on second call', async () => {
      const deps = createDeps();
      const { getAppConfig } = createAppConfigService(deps);

      await getAppConfig();
      await getAppConfig();

      expect(deps.loadBaseConfig).toHaveBeenCalledTimes(1);
    });

    it('reloads base config when refresh is true', async () => {
      const deps = createDeps();
      const { getAppConfig } = createAppConfigService(deps);

      await getAppConfig();
      await getAppConfig({ refresh: true });

      expect(deps.loadBaseConfig).toHaveBeenCalledTimes(2);
    });

    it('queries DB for applicable configs', async () => {
      const deps = createDeps();
      const { getAppConfig } = createAppConfigService(deps);

      await getAppConfig({ role: 'ADMIN' });

      expect(deps.getApplicableConfigs).toHaveBeenCalled();
    });

    it('caches empty result — does not re-query DB on second call', async () => {
      const deps = createDeps({ getApplicableConfigs: jest.fn().mockResolvedValue([]) });
      const { getAppConfig } = createAppConfigService(deps);

      await getAppConfig({ role: 'USER' });
      await getAppConfig({ role: 'USER' });

      expect(deps.getApplicableConfigs).toHaveBeenCalledTimes(1);
    });

    it('merges DB configs when found', async () => {
      const deps = createDeps({
        getApplicableConfigs: jest
          .fn()
          .mockResolvedValue([
            { priority: 10, overrides: { interface: { endpointsMenu: false } }, isActive: true },
          ]),
      });
      const { getAppConfig } = createAppConfigService(deps);

      const config = await getAppConfig({ role: 'ADMIN' });

      // Override wins over base; untouched base keys survive the merge.
      expect(config.interface.endpointsMenu).toBe(false);
      expect(config.endpoints).toEqual(['openAI']);
    });

    it('caches merged result with TTL', async () => {
      const deps = createDeps({
        getApplicableConfigs: jest
          .fn()
          .mockResolvedValue([{ priority: 10, overrides: { x: 1 }, isActive: true }]),
      });
      const { getAppConfig } = createAppConfigService(deps);

      await getAppConfig({ role: 'ADMIN' });
      await getAppConfig({ role: 'ADMIN' });

      expect(deps.getApplicableConfigs).toHaveBeenCalledTimes(1);
    });

    it('uses separate cache keys per userId (no cross-user contamination)', async () => {
      const deps = createDeps({
        getApplicableConfigs: jest
          .fn()
          .mockResolvedValue([
            { priority: 100, overrides: { x: 'user-specific' }, isActive: true },
          ]),
      });
      const { getAppConfig } = createAppConfigService(deps);

      await getAppConfig({ userId: 'uid1' });
      await getAppConfig({ userId: 'uid2' });

      expect(deps.getApplicableConfigs).toHaveBeenCalledTimes(2);
    });

    it('userId without role gets its own cache key', async () => {
      const deps = createDeps({
        getApplicableConfigs: jest
          .fn()
          .mockResolvedValue([{ priority: 100, overrides: { y: 1 }, isActive: true }]),
      });
      const { getAppConfig } = createAppConfigService(deps);

      await getAppConfig({ userId: 'uid1' });

      // Pins the cache-key format: _OVERRIDE_:<role-or-default>:<userId>.
      const cachedKeys = [...deps._cache._store.keys()];
      const overrideKey = cachedKeys.find((k) => k.startsWith('_OVERRIDE_:'));
      expect(overrideKey).toBe('_OVERRIDE_:__default__:uid1');
    });

    it('tenantId is included in cache key to prevent cross-tenant contamination', async () => {
      const deps = createDeps({
        getApplicableConfigs: jest
          .fn()
          .mockResolvedValue([{ priority: 10, overrides: { x: 1 }, isActive: true }]),
      });
      const { getAppConfig } = createAppConfigService(deps);

      await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-a' });
      await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-b' });

      expect(deps.getApplicableConfigs).toHaveBeenCalledTimes(2);
    });

    it('base-only empty result does not block subsequent scoped queries with results', async () => {
      const mockGetConfigs = jest.fn().mockResolvedValue([]);
      const deps = createDeps({ getApplicableConfigs: mockGetConfigs });
      const { getAppConfig } = createAppConfigService(deps);

      await getAppConfig();

      mockGetConfigs.mockResolvedValueOnce([
        { priority: 10, overrides: { restricted: true }, isActive: true },
      ]);
      const config = await getAppConfig({ role: 'ADMIN' });

      expect(mockGetConfigs).toHaveBeenCalledTimes(2);
      // NOTE(review): Record generic params reconstructed — lost in transit.
      expect((config as Record<string, unknown>).restricted).toBe(true);
    });

    it('does not short-circuit other users when one user has no overrides', async () => {
      const mockGetConfigs = jest.fn().mockResolvedValue([]);
      const deps = createDeps({ getApplicableConfigs: mockGetConfigs });
      const { getAppConfig } = createAppConfigService(deps);

      await getAppConfig({ role: 'USER' });
      expect(mockGetConfigs).toHaveBeenCalledTimes(1);

      mockGetConfigs.mockResolvedValueOnce([
        { priority: 10, overrides: { x: 'admin-only' }, isActive: true },
      ]);
      const config = await getAppConfig({ role: 'ADMIN' });

      expect(mockGetConfigs).toHaveBeenCalledTimes(2);
      expect((config as Record<string, unknown>).x).toBe('admin-only');
    });

    it('falls back to base config on getApplicableConfigs error', async () => {
      const deps = createDeps({
        getApplicableConfigs: jest.fn().mockRejectedValue(new Error('DB down')),
      });
      const { getAppConfig } = createAppConfigService(deps);

      const config = await getAppConfig({ role: 'ADMIN' });

      expect(config).toEqual(deps._baseConfig);
    });

    it('calls getUserPrincipals when userId is provided', async () => {
      const deps = createDeps();
      const { getAppConfig } = createAppConfigService(deps);

      await getAppConfig({ role: 'USER', userId: 'uid1' });

      expect(deps.getUserPrincipals).toHaveBeenCalledWith({
        userId: 'uid1',
        role: 'USER',
      });
    });

    it('does not call getUserPrincipals when only role is provided', async () => {
      const deps = createDeps();
      const { getAppConfig } = createAppConfigService(deps);

      await getAppConfig({ role: 'ADMIN' });

      expect(deps.getUserPrincipals).not.toHaveBeenCalled();
    });
  });

  describe('clearAppConfigCache', () => {
    it('clears base config so it reloads on next call', async () => {
      const deps = createDeps();
      const { getAppConfig, clearAppConfigCache } = createAppConfigService(deps);

      await getAppConfig();
      expect(deps.loadBaseConfig).toHaveBeenCalledTimes(1);

      await clearAppConfigCache();
      await getAppConfig();
      expect(deps.loadBaseConfig).toHaveBeenCalledTimes(2);
    });
  });
});
diff --git a/packages/api/src/app/service.ts b/packages/api/src/app/service.ts
new file mode 100644
index 0000000000..b7826e40ee
--- /dev/null
+++ b/packages/api/src/app/service.ts
@@ 
-0,0 +1,155 @@ +import { PrincipalType } from 'librechat-data-provider'; +import { logger, mergeConfigOverrides, BASE_CONFIG_PRINCIPAL_ID } from '@librechat/data-schemas'; +import type { Types } from 'mongoose'; +import type { AppConfig, IConfig } from '@librechat/data-schemas'; + +const BASE_CONFIG_KEY = '_BASE_'; + +const DEFAULT_OVERRIDE_CACHE_TTL = 60_000; + +// โ”€โ”€ Types โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +interface CacheStore { + get: (key: string) => Promise; + set: (key: string, value: unknown, ttl?: number) => Promise; + delete: (key: string) => Promise; +} + +export interface AppConfigServiceDeps { + /** Load the base AppConfig from YAML + AppService processing. */ + loadBaseConfig: () => Promise; + /** Cache tools after base config is loaded. */ + setCachedTools: (tools: Record) => Promise; + /** Get a cache store by key. */ + getCache: (key: string) => CacheStore; + /** The CacheKeys constants from librechat-data-provider. */ + cacheKeys: { APP_CONFIG: string }; + /** Fetch applicable DB config overrides for a set of principals. */ + getApplicableConfigs: ( + principals?: Array<{ principalType: string; principalId?: string | Types.ObjectId }>, + ) => Promise; + /** Resolve full principal list (user + role + groups) from userId/role. */ + getUserPrincipals: (params: { + userId: string | Types.ObjectId; + role?: string | null; + }) => Promise>; + /** TTL in ms for per-user/role merged config caches. Defaults to 60 000. 
*/ + overrideCacheTtl?: number; +} + +// โ”€โ”€ Helpers โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +function overrideCacheKey(role?: string, userId?: string, tenantId?: string): string { + const tenant = tenantId || '__default__'; + if (userId && role) { + return `_OVERRIDE_:${tenant}:${role}:${userId}`; + } + if (userId) { + return `_OVERRIDE_:${tenant}:${userId}`; + } + if (role) { + return `_OVERRIDE_:${tenant}:${role}`; + } + return `_OVERRIDE_:${tenant}:${BASE_CONFIG_PRINCIPAL_ID}`; +} + +// โ”€โ”€ Service factory โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +export function createAppConfigService(deps: AppConfigServiceDeps) { + const { + loadBaseConfig, + setCachedTools, + getCache, + cacheKeys, + getApplicableConfigs, + getUserPrincipals, + overrideCacheTtl = DEFAULT_OVERRIDE_CACHE_TTL, + } = deps; + + const cache = getCache(cacheKeys.APP_CONFIG); + + async function buildPrincipals( + role?: string, + userId?: string, + ): Promise> { + if (userId) { + return getUserPrincipals({ userId, role }); + } + const principals: Array<{ principalType: string; principalId?: string | Types.ObjectId }> = []; + if (role) { + principals.push({ principalType: PrincipalType.ROLE, principalId: role }); + } + return principals; + } + + /** + * Get the app configuration, optionally merged with DB overrides for the given principal. + * + * The base config (from YAML + AppService) is cached indefinitely. Per-principal merged + * configs are cached with a short TTL (`overrideCacheTtl`, default 60s). On cache miss, + * `getApplicableConfigs` queries the DB for matching overrides and merges them by priority. 
+ */ + async function getAppConfig( + options: { role?: string; userId?: string; tenantId?: string; refresh?: boolean } = {}, + ): Promise { + const { role, userId, tenantId, refresh } = options; + + let baseConfig = (await cache.get(BASE_CONFIG_KEY)) as AppConfig | undefined; + if (!baseConfig || refresh) { + logger.info('[getAppConfig] Loading base configuration...'); + baseConfig = await loadBaseConfig(); + + if (!baseConfig) { + throw new Error('Failed to initialize app configuration through AppService.'); + } + + if (baseConfig.availableTools) { + await setCachedTools(baseConfig.availableTools); + } + + await cache.set(BASE_CONFIG_KEY, baseConfig); + } + + const cacheKey = overrideCacheKey(role, userId, tenantId); + if (!refresh) { + const cachedMerged = (await cache.get(cacheKey)) as AppConfig | undefined; + if (cachedMerged) { + return cachedMerged; + } + } + + try { + const principals = await buildPrincipals(role, userId); + const configs = await getApplicableConfigs(principals); + + if (configs.length === 0) { + await cache.set(cacheKey, baseConfig, overrideCacheTtl); + return baseConfig; + } + + const merged = mergeConfigOverrides(baseConfig, configs); + await cache.set(cacheKey, merged, overrideCacheTtl); + return merged; + } catch (error) { + logger.error('[getAppConfig] Error resolving config overrides, falling back to base:', error); + return baseConfig; + } + } + + /** + * Clear the base config cache. Per-user/role override caches (`_OVERRIDE_:*`) + * are NOT flushed โ€” they expire naturally via `overrideCacheTtl`. After calling this, + * the base config will be reloaded from YAML on the next `getAppConfig` call, but + * users with cached overrides may see stale merged configs for up to `overrideCacheTtl` ms. 
+ */ + async function clearAppConfigCache(): Promise { + await cache.delete(BASE_CONFIG_KEY); + } + + return { + getAppConfig, + clearAppConfigCache, + }; +} + +export type AppConfigService = ReturnType; diff --git a/packages/api/src/index.ts b/packages/api/src/index.ts index ef32e7b6b0..5ccf6b0124 100644 --- a/packages/api/src/index.ts +++ b/packages/api/src/index.ts @@ -1,4 +1,6 @@ export * from './app'; +/* Admin */ +export * from './admin'; export * from './cdn'; /* Auth */ export * from './auth'; diff --git a/packages/api/src/middleware/capabilities.ts b/packages/api/src/middleware/capabilities.ts index c06a90ac8e..28d3a0f76e 100644 --- a/packages/api/src/middleware/capabilities.ts +++ b/packages/api/src/middleware/capabilities.ts @@ -26,7 +26,7 @@ interface CapabilityDeps { }) => Promise; } -interface CapabilityUser { +export interface CapabilityUser { id: string; role: string; tenantId?: string; @@ -48,7 +48,7 @@ export type RequireCapabilityFn = ( export type HasConfigCapabilityFn = ( user: CapabilityUser, - section: ConfigSection, + section: ConfigSection | null, verb?: 'manage' | 'read', ) => Promise; @@ -138,11 +138,14 @@ export function generateCapabilityCheck(deps: CapabilityDeps): { */ async function hasConfigCapability( user: CapabilityUser, - section: ConfigSection, + section: ConfigSection | null, verb: 'manage' | 'read' = 'manage', ): Promise { const broadCap = verb === 'manage' ? 
SystemCapabilities.MANAGE_CONFIGS : SystemCapabilities.READ_CONFIGS; + if (section == null) { + return hasCapability(user, broadCap); + } if (await hasCapability(user, broadCap)) { return true; } diff --git a/packages/data-provider/package.json b/packages/data-provider/package.json index a66f4eec4e..1e0c76f37f 100644 --- a/packages/data-provider/package.json +++ b/packages/data-provider/package.json @@ -1,6 +1,6 @@ { "name": "librechat-data-provider", - "version": "0.8.401", + "version": "0.8.403", "description": "data services for librechat apps", "main": "dist/index.js", "module": "dist/index.es.js", @@ -18,6 +18,9 @@ "types": "./dist/types/react-query/index.d.ts" } }, + "files": [ + "dist" + ], "scripts": { "clean": "rimraf dist", "build": "npm run clean && rollup -c --silent --bundleConfigAsCjs", diff --git a/packages/data-schemas/package.json b/packages/data-schemas/package.json index 0376804ad4..0124552002 100644 --- a/packages/data-schemas/package.json +++ b/packages/data-schemas/package.json @@ -1,16 +1,22 @@ { "name": "@librechat/data-schemas", - "version": "0.0.40", + "version": "0.0.46", "description": "Mongoose schemas and models for LibreChat", "type": "module", "main": "dist/index.cjs", "module": "dist/index.es.js", "types": "./dist/types/index.d.ts", + "sideEffects": false, "exports": { ".": { "import": "./dist/index.es.js", "require": "./dist/index.cjs", "types": "./dist/types/index.d.ts" + }, + "./capabilities": { + "import": "./dist/admin/capabilities.es.js", + "require": "./dist/admin/capabilities.cjs", + "types": "./dist/types/admin/capabilities.d.ts" } }, "files": [ diff --git a/packages/data-schemas/rollup.config.js b/packages/data-schemas/rollup.config.js index d58331feee..703630e121 100644 --- a/packages/data-schemas/rollup.config.js +++ b/packages/data-schemas/rollup.config.js @@ -8,14 +8,20 @@ export default { input: 'src/index.ts', output: [ { - file: 'dist/index.es.js', + dir: 'dist', format: 'es', sourcemap: true, + preserveModules: 
true, + preserveModulesRoot: 'src', + entryFileNames: '[name].es.js', }, { - file: 'dist/index.cjs', + dir: 'dist', format: 'cjs', sourcemap: true, + preserveModules: true, + preserveModulesRoot: 'src', + entryFileNames: '[name].cjs', }, ], plugins: [ diff --git a/packages/data-schemas/src/admin/capabilities.ts b/packages/data-schemas/src/admin/capabilities.ts new file mode 100644 index 0000000000..447db235a2 --- /dev/null +++ b/packages/data-schemas/src/admin/capabilities.ts @@ -0,0 +1,199 @@ +import { ResourceType } from 'librechat-data-provider'; +import type { + BaseSystemCapability, + SystemCapability, + ConfigSection, + CapabilityCategory, +} from '~/types/admin'; + +// --------------------------------------------------------------------------- +// System Capabilities +// --------------------------------------------------------------------------- + +/** + * The canonical set of base system capabilities. + * + * These are used by the admin panel and LibreChat API to gate access to + * admin features. Config-section-derived capabilities (e.g. + * `manage:configs:endpoints`) are built on top of these where the + * configSchema is available. + */ +export const SystemCapabilities = { + ACCESS_ADMIN: 'access:admin', + READ_USERS: 'read:users', + MANAGE_USERS: 'manage:users', + READ_GROUPS: 'read:groups', + MANAGE_GROUPS: 'manage:groups', + READ_ROLES: 'read:roles', + MANAGE_ROLES: 'manage:roles', + READ_CONFIGS: 'read:configs', + MANAGE_CONFIGS: 'manage:configs', + ASSIGN_CONFIGS: 'assign:configs', + READ_USAGE: 'read:usage', + READ_AGENTS: 'read:agents', + MANAGE_AGENTS: 'manage:agents', + MANAGE_MCP_SERVERS: 'manage:mcpservers', + READ_PROMPTS: 'read:prompts', + MANAGE_PROMPTS: 'manage:prompts', + /** Reserved โ€” not yet enforced by any middleware. */ + READ_ASSISTANTS: 'read:assistants', + MANAGE_ASSISTANTS: 'manage:assistants', +} as const; + +/** + * Capabilities that are implied by holding a broader capability. + * e.g. `MANAGE_USERS` implies `READ_USERS`. 
+ */ +export const CapabilityImplications: Partial> = + { + [SystemCapabilities.MANAGE_USERS]: [SystemCapabilities.READ_USERS], + [SystemCapabilities.MANAGE_GROUPS]: [SystemCapabilities.READ_GROUPS], + [SystemCapabilities.MANAGE_ROLES]: [SystemCapabilities.READ_ROLES], + [SystemCapabilities.MANAGE_CONFIGS]: [SystemCapabilities.READ_CONFIGS], + [SystemCapabilities.MANAGE_AGENTS]: [SystemCapabilities.READ_AGENTS], + [SystemCapabilities.MANAGE_PROMPTS]: [SystemCapabilities.READ_PROMPTS], + [SystemCapabilities.MANAGE_ASSISTANTS]: [SystemCapabilities.READ_ASSISTANTS], + }; + +// --------------------------------------------------------------------------- +// Capability utility functions +// --------------------------------------------------------------------------- + +/** Reverse map: for a given read capability, which manage capabilities imply it? */ +const impliedByMap: Record = {}; +for (const [manage, reads] of Object.entries(CapabilityImplications)) { + for (const read of reads as string[]) { + if (!impliedByMap[read]) { + impliedByMap[read] = []; + } + impliedByMap[read].push(manage); + } +} + +/** + * Check whether a set of held capabilities satisfies a required capability, + * accounting for the manageโ†’read implication hierarchy. + */ +export function hasImpliedCapability(held: string[], required: string): boolean { + if (held.includes(required)) { + return true; + } + const impliers = impliedByMap[required]; + if (impliers) { + for (const cap of impliers) { + if (held.includes(cap)) { + return true; + } + } + } + return false; +} + +/** + * Given a set of directly-held capabilities, compute the full set including + * all implied capabilities. 
+ */ +export function expandImplications(directCaps: string[]): string[] { + const expanded = new Set(directCaps); + for (const cap of directCaps) { + const implied = CapabilityImplications[cap as BaseSystemCapability]; + if (implied) { + for (const imp of implied) { + expanded.add(imp); + } + } + } + return Array.from(expanded); +} + +// --------------------------------------------------------------------------- +// Resource & config capability mappings +// --------------------------------------------------------------------------- + +/** + * Maps each ACL ResourceType to the SystemCapability that grants + * unrestricted management access. Typed as `Record` + * so adding a new ResourceType variant causes a compile error until a + * capability is assigned here. + */ +export const ResourceCapabilityMap: Record = { + [ResourceType.AGENT]: SystemCapabilities.MANAGE_AGENTS, + [ResourceType.PROMPTGROUP]: SystemCapabilities.MANAGE_PROMPTS, + [ResourceType.MCPSERVER]: SystemCapabilities.MANAGE_MCP_SERVERS, + [ResourceType.REMOTE_AGENT]: SystemCapabilities.MANAGE_AGENTS, +}; + +/** + * Derives a section-level config management capability from a configSchema key. + * @example configCapability('endpoints') โ†’ 'manage:configs:endpoints' + * + * TODO: Section-level config capabilities are scaffolded but not yet active. + * To activate delegated config management: + * 1. Expose POST/DELETE /api/admin/grants endpoints (wiring grantCapability/revokeCapability) + * 2. Seed section-specific grants for delegated admin roles via those endpoints + * 3. Guard config write handlers with hasConfigCapability(user, section) + */ +export function configCapability(section: ConfigSection): `manage:configs:${ConfigSection}` { + return `manage:configs:${section}`; +} + +/** + * Derives a section-level config read capability from a configSchema key. 
+ * @example readConfigCapability('endpoints') โ†’ 'read:configs:endpoints' + */ +export function readConfigCapability(section: ConfigSection): `read:configs:${ConfigSection}` { + return `read:configs:${section}`; +} + +// --------------------------------------------------------------------------- +// Reserved principal IDs +// --------------------------------------------------------------------------- + +/** Reserved principalId for the DB base config (overrides YAML defaults). */ +export const BASE_CONFIG_PRINCIPAL_ID = '__base__'; + +/** Pre-defined UI categories for grouping capabilities in the admin panel. */ +export const CAPABILITY_CATEGORIES: CapabilityCategory[] = [ + { + key: 'users', + labelKey: 'com_cap_cat_users', + capabilities: [SystemCapabilities.MANAGE_USERS, SystemCapabilities.READ_USERS], + }, + { + key: 'groups', + labelKey: 'com_cap_cat_groups', + capabilities: [SystemCapabilities.MANAGE_GROUPS, SystemCapabilities.READ_GROUPS], + }, + { + key: 'roles', + labelKey: 'com_cap_cat_roles', + capabilities: [SystemCapabilities.MANAGE_ROLES, SystemCapabilities.READ_ROLES], + }, + { + key: 'config', + labelKey: 'com_cap_cat_config', + capabilities: [ + SystemCapabilities.MANAGE_CONFIGS, + SystemCapabilities.READ_CONFIGS, + SystemCapabilities.ASSIGN_CONFIGS, + ], + }, + { + key: 'content', + labelKey: 'com_cap_cat_content', + capabilities: [ + SystemCapabilities.MANAGE_AGENTS, + SystemCapabilities.READ_AGENTS, + SystemCapabilities.MANAGE_PROMPTS, + SystemCapabilities.READ_PROMPTS, + SystemCapabilities.MANAGE_ASSISTANTS, + SystemCapabilities.READ_ASSISTANTS, + SystemCapabilities.MANAGE_MCP_SERVERS, + ], + }, + { + key: 'system', + labelKey: 'com_cap_cat_system', + capabilities: [SystemCapabilities.ACCESS_ADMIN, SystemCapabilities.READ_USAGE], + }, +]; diff --git a/packages/data-schemas/src/admin/index.ts b/packages/data-schemas/src/admin/index.ts new file mode 100644 index 0000000000..8d43daada6 --- /dev/null +++ b/packages/data-schemas/src/admin/index.ts 
@@ -0,0 +1 @@ +export * from './capabilities'; diff --git a/packages/data-schemas/src/app/index.ts b/packages/data-schemas/src/app/index.ts index 77cb799f8c..b07a36acd0 100644 --- a/packages/data-schemas/src/app/index.ts +++ b/packages/data-schemas/src/app/index.ts @@ -5,3 +5,4 @@ export * from './specs'; export * from './turnstile'; export * from './vertex'; export * from './web'; +export * from './resolution'; diff --git a/packages/data-schemas/src/app/resolution.spec.ts b/packages/data-schemas/src/app/resolution.spec.ts new file mode 100644 index 0000000000..12f8985a48 --- /dev/null +++ b/packages/data-schemas/src/app/resolution.spec.ts @@ -0,0 +1,108 @@ +import { mergeConfigOverrides } from './resolution'; +import type { AppConfig, IConfig } from '~/types'; + +function fakeConfig(overrides: Record, priority: number): IConfig { + return { + _id: 'fake', + principalType: 'role', + principalId: 'test', + principalModel: 'Role', + priority, + overrides, + isActive: true, + configVersion: 1, + } as unknown as IConfig; +} + +const baseConfig = { + interface: { endpointsMenu: true, sidePanel: true }, + registration: { enabled: true }, + endpoints: ['openAI'], +} as unknown as AppConfig; + +describe('mergeConfigOverrides', () => { + it('returns base config when configs array is empty', () => { + expect(mergeConfigOverrides(baseConfig, [])).toBe(baseConfig); + }); + + it('returns base config when configs is null/undefined', () => { + expect(mergeConfigOverrides(baseConfig, null as unknown as IConfig[])).toBe(baseConfig); + expect(mergeConfigOverrides(baseConfig, undefined as unknown as IConfig[])).toBe(baseConfig); + }); + + it('deep merges a single override into base', () => { + const configs = [fakeConfig({ interface: { endpointsMenu: false } }, 10)]; + const result = mergeConfigOverrides(baseConfig, configs) as unknown as Record; + const iface = result.interface as Record; + expect(iface.endpointsMenu).toBe(false); + expect(iface.sidePanel).toBe(true); + }); + + 
it('sorts by priority โ€” higher priority wins', () => { + const configs = [ + fakeConfig({ registration: { enabled: false } }, 100), + fakeConfig({ registration: { enabled: true, custom: 'yes' } }, 10), + ]; + const result = mergeConfigOverrides(baseConfig, configs) as unknown as Record; + const reg = result.registration as Record; + expect(reg.enabled).toBe(false); + expect(reg.custom).toBe('yes'); + }); + + it('replaces arrays instead of concatenating', () => { + const configs = [fakeConfig({ endpoints: ['anthropic', 'google'] }, 10)]; + const result = mergeConfigOverrides(baseConfig, configs) as unknown as Record; + expect(result.endpoints).toEqual(['anthropic', 'google']); + }); + + it('does not mutate the base config', () => { + const original = JSON.parse(JSON.stringify(baseConfig)); + const configs = [fakeConfig({ interface: { endpointsMenu: false } }, 10)]; + mergeConfigOverrides(baseConfig, configs); + expect(baseConfig).toEqual(original); + }); + + it('handles null override values', () => { + const configs = [fakeConfig({ interface: { endpointsMenu: null } }, 10)]; + const result = mergeConfigOverrides(baseConfig, configs) as unknown as Record; + const iface = result.interface as Record; + expect(iface.endpointsMenu).toBeNull(); + }); + + it('skips configs with no overrides object', () => { + const configs = [fakeConfig(undefined as unknown as Record, 10)]; + const result = mergeConfigOverrides(baseConfig, configs); + expect(result).toEqual(baseConfig); + }); + + it('strips __proto__, constructor, and prototype keys from overrides', () => { + const configs = [ + fakeConfig( + { + __proto__: { polluted: true }, + constructor: { bad: true }, + prototype: { evil: true }, + safe: 'ok', + }, + 10, + ), + ]; + const result = mergeConfigOverrides(baseConfig, configs) as unknown as Record; + expect(result.safe).toBe('ok'); + expect(({} as Record).polluted).toBeUndefined(); + expect(Object.prototype.hasOwnProperty.call(result, 'constructor')).toBe(false); + 
expect(Object.prototype.hasOwnProperty.call(result, 'prototype')).toBe(false); + }); + + it('merges three priority levels in order', () => { + const configs = [ + fakeConfig({ interface: { endpointsMenu: false } }, 0), + fakeConfig({ interface: { endpointsMenu: true, sidePanel: false } }, 10), + fakeConfig({ interface: { sidePanel: true } }, 100), + ]; + const result = mergeConfigOverrides(baseConfig, configs) as unknown as Record; + const iface = result.interface as Record; + expect(iface.endpointsMenu).toBe(true); + expect(iface.sidePanel).toBe(true); + }); +}); diff --git a/packages/data-schemas/src/app/resolution.ts b/packages/data-schemas/src/app/resolution.ts new file mode 100644 index 0000000000..ad1c1fbff0 --- /dev/null +++ b/packages/data-schemas/src/app/resolution.ts @@ -0,0 +1,54 @@ +import type { AppConfig, IConfig } from '~/types'; + +type AnyObject = { [key: string]: unknown }; + +const MAX_MERGE_DEPTH = 10; +const UNSAFE_KEYS = new Set(['__proto__', 'constructor', 'prototype']); + +function deepMerge(target: T, source: AnyObject, depth = 0): T { + const result = { ...target } as AnyObject; + for (const key of Object.keys(source)) { + if (UNSAFE_KEYS.has(key)) { + continue; + } + const sourceVal = source[key]; + const targetVal = result[key]; + if ( + depth < MAX_MERGE_DEPTH && + sourceVal != null && + typeof sourceVal === 'object' && + !Array.isArray(sourceVal) && + targetVal != null && + typeof targetVal === 'object' && + !Array.isArray(targetVal) + ) { + result[key] = deepMerge(targetVal as AnyObject, sourceVal as AnyObject, depth + 1); + } else { + result[key] = sourceVal; + } + } + return result as T; +} + +/** + * Merge DB config overrides into a base AppConfig. + * + * Configs are sorted by priority ascending (lowest first, highest wins). + * Each config's `overrides` is deep-merged into the base config in order. 
+ */ +export function mergeConfigOverrides(baseConfig: AppConfig, configs: IConfig[]): AppConfig { + if (!configs || configs.length === 0) { + return baseConfig; + } + + const sorted = [...configs].sort((a, b) => a.priority - b.priority); + + let merged = { ...baseConfig }; + for (const config of sorted) { + if (config.overrides && typeof config.overrides === 'object') { + merged = deepMerge(merged, config.overrides as AnyObject); + } + } + + return merged; +} diff --git a/packages/data-schemas/src/index.ts b/packages/data-schemas/src/index.ts index aa92b3b2e6..cd683c937c 100644 --- a/packages/data-schemas/src/index.ts +++ b/packages/data-schemas/src/index.ts @@ -1,5 +1,5 @@ export * from './app'; -export * from './systemCapabilities'; +export * from './admin'; export * from './common'; export * from './crypto'; export * from './schema'; diff --git a/packages/data-schemas/src/methods/config.spec.ts b/packages/data-schemas/src/methods/config.spec.ts new file mode 100644 index 0000000000..82f43c2b37 --- /dev/null +++ b/packages/data-schemas/src/methods/config.spec.ts @@ -0,0 +1,297 @@ +import mongoose, { Types } from 'mongoose'; +import { MongoMemoryServer } from 'mongodb-memory-server'; +import { PrincipalType, PrincipalModel } from 'librechat-data-provider'; +import { createConfigMethods } from './config'; +import configSchema from '~/schema/config'; +import type { IConfig } from '~/types'; + +let mongoServer: MongoMemoryServer; +let methods: ReturnType; + +beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + await mongoose.connect(mongoServer.getUri()); + if (!mongoose.models.Config) { + mongoose.model('Config', configSchema); + } + methods = createConfigMethods(mongoose); +}); + +afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); +}); + +beforeEach(async () => { + await mongoose.models.Config.deleteMany({}); +}); + +describe('upsertConfig', () => { + it('creates a new config document', async () => { + const 
result = await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { interface: { endpointsMenu: false } }, + 10, + ); + + expect(result).toBeTruthy(); + expect(result!.principalType).toBe(PrincipalType.ROLE); + expect(result!.principalId).toBe('admin'); + expect(result!.priority).toBe(10); + expect(result!.isActive).toBe(true); + expect(result!.configVersion).toBe(1); + }); + + it('is idempotent โ€” second upsert updates the same doc', async () => { + await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { interface: { endpointsMenu: false } }, + 10, + ); + + await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { interface: { endpointsMenu: true } }, + 10, + ); + + const count = await mongoose.models.Config.countDocuments({}); + expect(count).toBe(1); + }); + + it('increments configVersion on each upsert', async () => { + const first = await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { a: 1 }, + 10, + ); + + const second = await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { a: 2 }, + 10, + ); + + expect(first!.configVersion).toBe(1); + expect(second!.configVersion).toBe(2); + }); + + it('normalizes ObjectId principalId to string', async () => { + const oid = new Types.ObjectId(); + await methods.upsertConfig(PrincipalType.USER, oid, PrincipalModel.USER, { test: true }, 100); + + const found = await methods.findConfigByPrincipal(PrincipalType.USER, oid.toString()); + expect(found).toBeTruthy(); + expect(found!.principalId).toBe(oid.toString()); + }); +}); + +describe('findConfigByPrincipal', () => { + it('finds an active config', async () => { + await methods.upsertConfig(PrincipalType.ROLE, 'admin', PrincipalModel.ROLE, { x: 1 }, 10); + + const result = await methods.findConfigByPrincipal(PrincipalType.ROLE, 'admin'); + expect(result).toBeTruthy(); + 
expect(result!.principalType).toBe(PrincipalType.ROLE); + }); + + it('returns null when no config exists', async () => { + const result = await methods.findConfigByPrincipal(PrincipalType.ROLE, 'nonexistent'); + expect(result).toBeNull(); + }); + + it('does not find inactive configs', async () => { + await methods.upsertConfig(PrincipalType.ROLE, 'admin', PrincipalModel.ROLE, { x: 1 }, 10); + await methods.toggleConfigActive(PrincipalType.ROLE, 'admin', false); + + const result = await methods.findConfigByPrincipal(PrincipalType.ROLE, 'admin'); + expect(result).toBeNull(); + }); +}); + +describe('listAllConfigs', () => { + it('returns all configs when no filter', async () => { + await methods.upsertConfig(PrincipalType.ROLE, 'a', PrincipalModel.ROLE, {}, 10); + await methods.upsertConfig(PrincipalType.ROLE, 'b', PrincipalModel.ROLE, {}, 20); + await methods.toggleConfigActive(PrincipalType.ROLE, 'b', false); + + const all = await methods.listAllConfigs(); + expect(all).toHaveLength(2); + }); + + it('filters by isActive when specified', async () => { + await methods.upsertConfig(PrincipalType.ROLE, 'a', PrincipalModel.ROLE, {}, 10); + await methods.upsertConfig(PrincipalType.ROLE, 'b', PrincipalModel.ROLE, {}, 20); + await methods.toggleConfigActive(PrincipalType.ROLE, 'b', false); + + const active = await methods.listAllConfigs({ isActive: true }); + expect(active).toHaveLength(1); + expect(active[0].principalId).toBe('a'); + + const inactive = await methods.listAllConfigs({ isActive: false }); + expect(inactive).toHaveLength(1); + expect(inactive[0].principalId).toBe('b'); + }); + + it('returns configs sorted by priority ascending', async () => { + await methods.upsertConfig(PrincipalType.ROLE, 'high', PrincipalModel.ROLE, {}, 100); + await methods.upsertConfig(PrincipalType.ROLE, 'low', PrincipalModel.ROLE, {}, 0); + await methods.upsertConfig(PrincipalType.ROLE, 'mid', PrincipalModel.ROLE, {}, 50); + + const configs = await methods.listAllConfigs(); + 
expect(configs.map((c) => c.principalId)).toEqual(['low', 'mid', 'high']); + }); +}); + +describe('getApplicableConfigs', () => { + it('always includes the __base__ config', async () => { + await methods.upsertConfig(PrincipalType.ROLE, '__base__', PrincipalModel.ROLE, { a: 1 }, 0); + + const configs = await methods.getApplicableConfigs([]); + expect(configs).toHaveLength(1); + expect(configs[0].principalId).toBe('__base__'); + }); + + it('returns base + matching principals', async () => { + await methods.upsertConfig(PrincipalType.ROLE, '__base__', PrincipalModel.ROLE, { a: 1 }, 0); + await methods.upsertConfig(PrincipalType.ROLE, 'admin', PrincipalModel.ROLE, { b: 2 }, 10); + await methods.upsertConfig(PrincipalType.ROLE, 'user', PrincipalModel.ROLE, { c: 3 }, 10); + + const configs = await methods.getApplicableConfigs([ + { principalType: PrincipalType.ROLE, principalId: 'admin' }, + ]); + + expect(configs).toHaveLength(2); + expect(configs.map((c) => c.principalId).sort()).toEqual(['__base__', 'admin']); + }); + + it('returns sorted by priority', async () => { + await methods.upsertConfig(PrincipalType.ROLE, '__base__', PrincipalModel.ROLE, {}, 0); + await methods.upsertConfig(PrincipalType.ROLE, 'admin', PrincipalModel.ROLE, {}, 10); + + const configs = await methods.getApplicableConfigs([ + { principalType: PrincipalType.ROLE, principalId: 'admin' }, + ]); + + expect(configs[0].principalId).toBe('__base__'); + expect(configs[1].principalId).toBe('admin'); + }); + + it('skips principals with undefined principalId', async () => { + await methods.upsertConfig(PrincipalType.ROLE, '__base__', PrincipalModel.ROLE, {}, 0); + + const configs = await methods.getApplicableConfigs([ + { principalType: PrincipalType.GROUP, principalId: undefined }, + ]); + + expect(configs).toHaveLength(1); + }); +}); + +describe('patchConfigFields', () => { + it('atomically sets specific fields via $set', async () => { + await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + 
PrincipalModel.ROLE, + { interface: { endpointsMenu: true, sidePanel: true } }, + 10, + ); + + const result = await methods.patchConfigFields( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { 'interface.endpointsMenu': false }, + 10, + ); + + const overrides = result!.overrides as Record; + const iface = overrides.interface as Record; + expect(iface.endpointsMenu).toBe(false); + expect(iface.sidePanel).toBe(true); + }); + + it('creates a config if none exists (upsert)', async () => { + const result = await methods.patchConfigFields( + PrincipalType.ROLE, + 'newrole', + PrincipalModel.ROLE, + { 'interface.endpointsMenu': false }, + 10, + ); + + expect(result).toBeTruthy(); + expect(result!.principalId).toBe('newrole'); + }); +}); + +describe('unsetConfigField', () => { + it('removes a field from overrides via $unset', async () => { + await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { interface: { endpointsMenu: false, sidePanel: false } }, + 10, + ); + + const result = await methods.unsetConfigField( + PrincipalType.ROLE, + 'admin', + 'interface.endpointsMenu', + ); + const overrides = result!.overrides as Record; + const iface = overrides.interface as Record; + expect(iface.endpointsMenu).toBeUndefined(); + expect(iface.sidePanel).toBe(false); + }); + + it('returns null for non-existent config', async () => { + const result = await methods.unsetConfigField(PrincipalType.ROLE, 'ghost', 'a.b'); + expect(result).toBeNull(); + }); +}); + +describe('deleteConfig', () => { + it('deletes and returns the config', async () => { + await methods.upsertConfig(PrincipalType.ROLE, 'admin', PrincipalModel.ROLE, {}, 10); + const deleted = await methods.deleteConfig(PrincipalType.ROLE, 'admin'); + expect(deleted).toBeTruthy(); + + const found = await methods.findConfigByPrincipal(PrincipalType.ROLE, 'admin'); + expect(found).toBeNull(); + }); + + it('returns null when deleting non-existent config', async () => { + const result = await 
methods.deleteConfig(PrincipalType.ROLE, 'ghost'); + expect(result).toBeNull(); + }); +}); + +describe('toggleConfigActive', () => { + it('deactivates an active config', async () => { + await methods.upsertConfig(PrincipalType.ROLE, 'admin', PrincipalModel.ROLE, {}, 10); + + const result = await methods.toggleConfigActive(PrincipalType.ROLE, 'admin', false); + expect(result!.isActive).toBe(false); + }); + + it('reactivates an inactive config', async () => { + await methods.upsertConfig(PrincipalType.ROLE, 'admin', PrincipalModel.ROLE, {}, 10); + await methods.toggleConfigActive(PrincipalType.ROLE, 'admin', false); + + const result = await methods.toggleConfigActive(PrincipalType.ROLE, 'admin', true); + expect(result!.isActive).toBe(true); + }); +}); diff --git a/packages/data-schemas/src/methods/config.ts b/packages/data-schemas/src/methods/config.ts new file mode 100644 index 0000000000..42047d216f --- /dev/null +++ b/packages/data-schemas/src/methods/config.ts @@ -0,0 +1,215 @@ +import { Types } from 'mongoose'; +import { PrincipalType, PrincipalModel } from 'librechat-data-provider'; +import { BASE_CONFIG_PRINCIPAL_ID } from '~/admin/capabilities'; +import type { TCustomConfig } from 'librechat-data-provider'; +import type { Model, ClientSession } from 'mongoose'; +import type { IConfig } from '~/types'; + +export function createConfigMethods(mongoose: typeof import('mongoose')) { + async function findConfigByPrincipal( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + options?: { includeInactive?: boolean }, + session?: ClientSession, + ): Promise { + const Config = mongoose.models.Config as Model; + const filter: { principalType: PrincipalType; principalId: string; isActive?: boolean } = { + principalType, + principalId: principalId.toString(), + }; + if (!options?.includeInactive) { + filter.isActive = true; + } + return await Config.findOne(filter) + .session(session ?? 
null) + .lean(); + } + + async function listAllConfigs( + filter?: { isActive?: boolean }, + session?: ClientSession, + ): Promise { + const Config = mongoose.models.Config as Model; + const where: { isActive?: boolean } = {}; + if (filter?.isActive !== undefined) { + where.isActive = filter.isActive; + } + return await Config.find(where) + .sort({ priority: 1 }) + .session(session ?? null) + .lean(); + } + + async function getApplicableConfigs( + principals?: Array<{ principalType: string; principalId?: string | Types.ObjectId }>, + session?: ClientSession, + ): Promise { + const Config = mongoose.models.Config as Model; + + const basePrincipal = { + principalType: PrincipalType.ROLE as string, + principalId: BASE_CONFIG_PRINCIPAL_ID, + }; + + const principalsQuery = [basePrincipal]; + + if (principals && principals.length > 0) { + for (const p of principals) { + if (p.principalId !== undefined) { + principalsQuery.push({ + principalType: p.principalType, + principalId: p.principalId.toString(), + }); + } + } + } + + return await Config.find({ + $or: principalsQuery, + isActive: true, + }) + .sort({ priority: 1 }) + .session(session ?? null) + .lean(); + } + + async function upsertConfig( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + principalModel: PrincipalModel, + overrides: Partial, + priority: number, + session?: ClientSession, + ): Promise { + const Config = mongoose.models.Config as Model; + + const query = { + principalType, + principalId: principalId.toString(), + }; + + const update = { + $set: { + principalModel, + overrides, + priority, + isActive: true, + }, + $inc: { configVersion: 1 }, + }; + + const options = { + upsert: true, + new: true, + setDefaultsOnInsert: true, + ...(session ? 
{ session } : {}), + }; + + try { + return await Config.findOneAndUpdate(query, update, options); + } catch (err: unknown) { + if ((err as { code?: number }).code === 11000) { + return await Config.findOneAndUpdate( + query, + { $set: update.$set, $inc: update.$inc }, + { new: true, ...(session ? { session } : {}) }, + ); + } + throw err; + } + } + + async function patchConfigFields( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + principalModel: PrincipalModel, + fields: Record, + priority: number, + session?: ClientSession, + ): Promise { + const Config = mongoose.models.Config as Model; + + const setPayload: { principalModel: PrincipalModel; priority: number; [key: string]: unknown } = + { + principalModel, + priority, + }; + + for (const [path, value] of Object.entries(fields)) { + setPayload[`overrides.${path}`] = value; + } + + const options = { + upsert: true, + new: true, + setDefaultsOnInsert: true, + ...(session ? { session } : {}), + }; + + return await Config.findOneAndUpdate( + { principalType, principalId: principalId.toString() }, + { $set: setPayload, $inc: { configVersion: 1 } }, + options, + ); + } + + async function unsetConfigField( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + fieldPath: string, + session?: ClientSession, + ): Promise { + const Config = mongoose.models.Config as Model; + + const options = { + new: true, + ...(session ? { session } : {}), + }; + + return await Config.findOneAndUpdate( + { principalType, principalId: principalId.toString() }, + { $unset: { [`overrides.${fieldPath}`]: '' }, $inc: { configVersion: 1 } }, + options, + ); + } + + async function deleteConfig( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + session?: ClientSession, + ): Promise { + const Config = mongoose.models.Config as Model; + + return await Config.findOneAndDelete({ + principalType, + principalId: principalId.toString(), + }).session(session ?? 
null); + } + + async function toggleConfigActive( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + isActive: boolean, + session?: ClientSession, + ): Promise { + const Config = mongoose.models.Config as Model; + return await Config.findOneAndUpdate( + { principalType, principalId: principalId.toString() }, + { $set: { isActive } }, + { new: true, ...(session ? { session } : {}) }, + ); + } + + return { + listAllConfigs, + findConfigByPrincipal, + getApplicableConfigs, + upsertConfig, + patchConfigFields, + unsetConfigField, + deleteConfig, + toggleConfigActive, + }; +} + +export type ConfigMethods = ReturnType; diff --git a/packages/data-schemas/src/methods/index.ts b/packages/data-schemas/src/methods/index.ts index 11f00e7827..4202cac0eb 100644 --- a/packages/data-schemas/src/methods/index.ts +++ b/packages/data-schemas/src/methods/index.ts @@ -48,6 +48,8 @@ import { createSpendTokensMethods, type SpendTokensMethods } from './spendTokens import { createPromptMethods, type PromptMethods, type PromptDeps } from './prompt'; /* Tier 5 โ€” Agent */ import { createAgentMethods, type AgentMethods, type AgentDeps } from './agent'; +/* Config */ +import { createConfigMethods, type ConfigMethods } from './config'; export { tokenValues, cacheTokenValues, premiumTokenValues, defaultRate }; @@ -80,7 +82,8 @@ export type AllMethods = UserMethods & TransactionMethods & SpendTokensMethods & PromptMethods & - AgentMethods; + AgentMethods & + ConfigMethods; /** Dependencies injected from the api layer into createMethods */ export interface CreateMethodsDeps { @@ -201,6 +204,8 @@ export function createMethods( ...promptMethods, /* Tier 5 */ ...agentMethods, + /* Config */ + ...createConfigMethods(mongoose), }; } @@ -235,4 +240,5 @@ export type { SpendTokensMethods, PromptMethods, AgentMethods, + ConfigMethods, }; diff --git a/packages/data-schemas/src/methods/systemGrant.spec.ts b/packages/data-schemas/src/methods/systemGrant.spec.ts index 
188d31b544..b17285c761 100644 --- a/packages/data-schemas/src/methods/systemGrant.spec.ts +++ b/packages/data-schemas/src/methods/systemGrant.spec.ts @@ -2,8 +2,8 @@ import mongoose, { Types } from 'mongoose'; import { PrincipalType, SystemRoles } from 'librechat-data-provider'; import { MongoMemoryServer } from 'mongodb-memory-server'; import type * as t from '~/types'; -import type { SystemCapability } from '~/systemCapabilities'; -import { SystemCapabilities, CapabilityImplications } from '~/systemCapabilities'; +import type { SystemCapability } from '~/types/admin'; +import { SystemCapabilities, CapabilityImplications } from '~/admin/capabilities'; import { createSystemGrantMethods } from './systemGrant'; import systemGrantSchema from '~/schema/systemGrant'; import logger from '~/config/winston'; diff --git a/packages/data-schemas/src/methods/systemGrant.ts b/packages/data-schemas/src/methods/systemGrant.ts index f0f389d762..6071dd38c5 100644 --- a/packages/data-schemas/src/methods/systemGrant.ts +++ b/packages/data-schemas/src/methods/systemGrant.ts @@ -1,8 +1,8 @@ import { PrincipalType, SystemRoles } from 'librechat-data-provider'; import type { Types, Model, ClientSession } from 'mongoose'; -import type { SystemCapability } from '~/systemCapabilities'; +import type { SystemCapability } from '~/types/admin'; import type { ISystemGrant } from '~/types'; -import { SystemCapabilities, CapabilityImplications } from '~/systemCapabilities'; +import { SystemCapabilities, CapabilityImplications } from '~/admin/capabilities'; import { normalizePrincipalId } from '~/utils/principal'; import logger from '~/config/winston'; diff --git a/packages/data-schemas/src/models/config.ts b/packages/data-schemas/src/models/config.ts new file mode 100644 index 0000000000..97c08ce1da --- /dev/null +++ b/packages/data-schemas/src/models/config.ts @@ -0,0 +1,8 @@ +import configSchema from '~/schema/config'; +import { applyTenantIsolation } from '~/models/plugins/tenantIsolation'; 
+import type * as t from '~/types'; + +export function createConfigModel(mongoose: typeof import('mongoose')) { + applyTenantIsolation(configSchema); + return mongoose.models.Config || mongoose.model('Config', configSchema); +} diff --git a/packages/data-schemas/src/models/index.ts b/packages/data-schemas/src/models/index.ts index 44d94c6ab4..5a8e8f1c2c 100644 --- a/packages/data-schemas/src/models/index.ts +++ b/packages/data-schemas/src/models/index.ts @@ -27,6 +27,7 @@ import { createAccessRoleModel } from './accessRole'; import { createAclEntryModel } from './aclEntry'; import { createSystemGrantModel } from './systemGrant'; import { createGroupModel } from './group'; +import { createConfigModel } from './config'; /** * Creates all database models for all collections @@ -62,5 +63,6 @@ export function createModels(mongoose: typeof import('mongoose')) { AclEntry: createAclEntryModel(mongoose), SystemGrant: createSystemGrantModel(mongoose), Group: createGroupModel(mongoose), + Config: createConfigModel(mongoose), }; } diff --git a/packages/data-schemas/src/schema/config.ts b/packages/data-schemas/src/schema/config.ts new file mode 100644 index 0000000000..be3784d55e --- /dev/null +++ b/packages/data-schemas/src/schema/config.ts @@ -0,0 +1,55 @@ +import { Schema } from 'mongoose'; +import { PrincipalType, PrincipalModel } from 'librechat-data-provider'; +import type { IConfig } from '~/types'; + +const configSchema = new Schema( + { + principalType: { + type: String, + enum: Object.values(PrincipalType), + required: true, + index: true, + }, + principalId: { + type: String, + refPath: 'principalModel', + required: true, + index: true, + }, + principalModel: { + type: String, + enum: Object.values(PrincipalModel), + required: true, + }, + priority: { + type: Number, + required: true, + index: true, + }, + overrides: { + type: Schema.Types.Mixed, + default: {}, + }, + isActive: { + type: Boolean, + default: true, + index: true, + }, + configVersion: { + type: Number, 
+ default: 0, + }, + tenantId: { + type: String, + index: true, + }, + }, + { timestamps: true }, +); + +// Enforce 1:1 principal-to-config (one config document per principal per tenant) +configSchema.index({ principalType: 1, principalId: 1, tenantId: 1 }, { unique: true }); +configSchema.index({ principalType: 1, principalId: 1, isActive: 1, tenantId: 1 }); +configSchema.index({ priority: 1, isActive: 1, tenantId: 1 }); + +export default configSchema; diff --git a/packages/data-schemas/src/schema/index.ts b/packages/data-schemas/src/schema/index.ts index 456eb03ac2..2a5eff658b 100644 --- a/packages/data-schemas/src/schema/index.ts +++ b/packages/data-schemas/src/schema/index.ts @@ -25,3 +25,4 @@ export { default as userSchema } from './user'; export { default as memorySchema } from './memory'; export { default as groupSchema } from './group'; export { default as systemGrantSchema } from './systemGrant'; +export { default as configSchema } from './config'; diff --git a/packages/data-schemas/src/schema/systemGrant.ts b/packages/data-schemas/src/schema/systemGrant.ts index 0366f6080d..a20a407bf1 100644 --- a/packages/data-schemas/src/schema/systemGrant.ts +++ b/packages/data-schemas/src/schema/systemGrant.ts @@ -1,7 +1,7 @@ import { Schema } from 'mongoose'; import { PrincipalType } from 'librechat-data-provider'; -import { SystemCapabilities } from '~/systemCapabilities'; -import type { SystemCapability } from '~/systemCapabilities'; +import { SystemCapabilities } from '~/admin/capabilities'; +import type { SystemCapability } from '~/types/admin'; import type { ISystemGrant } from '~/types'; const baseCapabilities = new Set(Object.values(SystemCapabilities)); diff --git a/packages/data-schemas/src/systemCapabilities.ts b/packages/data-schemas/src/systemCapabilities.ts deleted file mode 100644 index cf2acfbf88..0000000000 --- a/packages/data-schemas/src/systemCapabilities.ts +++ /dev/null @@ -1,106 +0,0 @@ -import type { z } from 'zod'; -import type { configSchema } 
from 'librechat-data-provider'; -import { ResourceType } from 'librechat-data-provider'; - -export const SystemCapabilities = { - ACCESS_ADMIN: 'access:admin', - READ_USERS: 'read:users', - MANAGE_USERS: 'manage:users', - READ_GROUPS: 'read:groups', - MANAGE_GROUPS: 'manage:groups', - READ_ROLES: 'read:roles', - MANAGE_ROLES: 'manage:roles', - READ_CONFIGS: 'read:configs', - MANAGE_CONFIGS: 'manage:configs', - ASSIGN_CONFIGS: 'assign:configs', - READ_USAGE: 'read:usage', - READ_AGENTS: 'read:agents', - MANAGE_AGENTS: 'manage:agents', - MANAGE_MCP_SERVERS: 'manage:mcpservers', - READ_PROMPTS: 'read:prompts', - MANAGE_PROMPTS: 'manage:prompts', - /** Reserved โ€” not yet enforced by any middleware. Grant has no effect until assistant listing is gated. */ - READ_ASSISTANTS: 'read:assistants', - MANAGE_ASSISTANTS: 'manage:assistants', -} as const; - -/** Top-level keys of the configSchema from librechat.yaml. */ -export type ConfigSection = keyof z.infer; - -/** Principal types that can receive config overrides. */ -export type ConfigAssignTarget = 'user' | 'group' | 'role'; - -/** Base capabilities defined in the SystemCapabilities object. */ -type BaseSystemCapability = (typeof SystemCapabilities)[keyof typeof SystemCapabilities]; - -/** Section-level config capabilities derived from configSchema keys. */ -type ConfigSectionCapability = `manage:configs:${ConfigSection}` | `read:configs:${ConfigSection}`; - -/** Principal-scoped config assignment capabilities. */ -type ConfigAssignCapability = `assign:configs:${ConfigAssignTarget}`; - -/** - * Union of all valid capability strings: - * - Base capabilities from SystemCapabilities - * - Section-level config capabilities (manage:configs:
, read:configs:
) - * - Config assignment capabilities (assign:configs:) - */ -export type SystemCapability = - | BaseSystemCapability - | ConfigSectionCapability - | ConfigAssignCapability; - -/** - * Capabilities that are implied by holding a broader capability. - * When `hasCapability` checks for an implied capability, it first expands - * the principal's grant set โ€” so granting `MANAGE_USERS` automatically - * satisfies a `READ_USERS` check without a separate grant. - * - * Implication is one-directional: `MANAGE_USERS` implies `READ_USERS`, - * but `READ_USERS` does NOT imply `MANAGE_USERS`. - */ -export const CapabilityImplications: Partial> = - { - [SystemCapabilities.MANAGE_USERS]: [SystemCapabilities.READ_USERS], - [SystemCapabilities.MANAGE_GROUPS]: [SystemCapabilities.READ_GROUPS], - [SystemCapabilities.MANAGE_ROLES]: [SystemCapabilities.READ_ROLES], - [SystemCapabilities.MANAGE_CONFIGS]: [SystemCapabilities.READ_CONFIGS], - [SystemCapabilities.MANAGE_AGENTS]: [SystemCapabilities.READ_AGENTS], - [SystemCapabilities.MANAGE_PROMPTS]: [SystemCapabilities.READ_PROMPTS], - [SystemCapabilities.MANAGE_ASSISTANTS]: [SystemCapabilities.READ_ASSISTANTS], - }; - -/** - * Maps each ACL ResourceType to the SystemCapability that grants - * unrestricted management access. Typed as `Record` - * so adding a new ResourceType variant causes a compile error until a - * capability is assigned here. - */ -export const ResourceCapabilityMap: Record = { - [ResourceType.AGENT]: SystemCapabilities.MANAGE_AGENTS, - [ResourceType.PROMPTGROUP]: SystemCapabilities.MANAGE_PROMPTS, - [ResourceType.MCPSERVER]: SystemCapabilities.MANAGE_MCP_SERVERS, - [ResourceType.REMOTE_AGENT]: SystemCapabilities.MANAGE_AGENTS, -}; - -/** - * Derives a section-level config management capability from a configSchema key. - * @example configCapability('endpoints') โ†’ 'manage:configs:endpoints' - * - * TODO: Section-level config capabilities are scaffolded but not yet active. 
- * To activate delegated config management: - * 1. Expose POST/DELETE /api/admin/grants endpoints (wiring grantCapability/revokeCapability) - * 2. Seed section-specific grants for delegated admin roles via those endpoints - * 3. Guard config write handlers with hasConfigCapability(user, section) - */ -export function configCapability(section: ConfigSection): `manage:configs:${ConfigSection}` { - return `manage:configs:${section}`; -} - -/** - * Derives a section-level config read capability from a configSchema key. - * @example readConfigCapability('endpoints') โ†’ 'read:configs:endpoints' - */ -export function readConfigCapability(section: ConfigSection): `read:configs:${ConfigSection}` { - return `read:configs:${section}`; -} diff --git a/packages/data-schemas/src/types/admin.ts b/packages/data-schemas/src/types/admin.ts new file mode 100644 index 0000000000..99915f659d --- /dev/null +++ b/packages/data-schemas/src/types/admin.ts @@ -0,0 +1,126 @@ +import type { + PrincipalType, + PrincipalModel, + TCustomConfig, + z, + configSchema, +} from 'librechat-data-provider'; +import type { SystemCapabilities } from '~/admin/capabilities'; + +/* โ”€โ”€ Capability types โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ */ + +/** Base capabilities derived from the SystemCapabilities constant. */ +export type BaseSystemCapability = (typeof SystemCapabilities)[keyof typeof SystemCapabilities]; + +/** Principal types that can receive config overrides. */ +export type ConfigAssignTarget = 'user' | 'group' | 'role'; + +/** Top-level keys of the configSchema from librechat.yaml. */ +export type ConfigSection = keyof z.infer; + +/** Section-level config capabilities derived from configSchema keys. */ +type ConfigSectionCapability = `manage:configs:${ConfigSection}` | `read:configs:${ConfigSection}`; + +/** Principal-scoped config assignment capabilities. 
*/ +type ConfigAssignCapability = `assign:configs:${ConfigAssignTarget}`; + +/** + * Union of all valid capability strings: + * - Base capabilities from SystemCapabilities + * - Section-level config capabilities (manage:configs:
, read:configs:
) + * - Config assignment capabilities (assign:configs:) + */ +export type SystemCapability = + | BaseSystemCapability + | ConfigSectionCapability + | ConfigAssignCapability; + +/** UI grouping of capabilities for the admin panel's capability editor. */ +export type CapabilityCategory = { + key: string; + labelKey: string; + capabilities: BaseSystemCapability[]; +}; + +/* โ”€โ”€ Admin API response types โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ */ + +/** Config document as returned by the admin API (no Mongoose internals). */ +export type AdminConfig = { + _id: string; + principalType: PrincipalType; + principalId: string; + principalModel: PrincipalModel; + priority: number; + overrides: Partial; + isActive: boolean; + configVersion: number; + tenantId?: string; + createdAt?: string; + updatedAt?: string; +}; + +export type AdminConfigListResponse = { + configs: AdminConfig[]; +}; + +export type AdminConfigResponse = { + config: AdminConfig; +}; + +export type AdminConfigDeleteResponse = { + success: boolean; +}; + +/** Audit action types for grant changes. */ +export type AuditAction = 'grant_assigned' | 'grant_removed'; + +/** SystemGrant document as returned by the admin API. */ +export type AdminSystemGrant = { + id: string; + principalType: PrincipalType; + principalId: string; + capability: string; + grantedBy?: string; + grantedAt: string; + expiresAt?: string; +}; + +/** Audit log entry for grant changes as returned by the admin API. */ +export type AdminAuditLogEntry = { + id: string; + action: AuditAction; + actorId: string; + actorName: string; + targetPrincipalType: PrincipalType; + targetPrincipalId: string; + targetName: string; + capability: string; + timestamp: string; +}; + +/** Group as returned by the admin API. 
*/ +export type AdminGroup = { + id: string; + name: string; + description: string; + memberCount: number; + topMembers: { name: string }[]; + isActive: boolean; +}; + +/** Member entry as returned by the admin API for group/role membership lists. */ +export type AdminMember = { + userId: string; + name: string; + email: string; + avatarUrl?: string; + joinedAt: string; +}; + +/** Minimal user info returned by user search endpoints. */ +export type AdminUserSearchResult = { + userId: string; + name: string; + email: string; + avatarUrl?: string; +}; diff --git a/packages/data-schemas/src/types/config.ts b/packages/data-schemas/src/types/config.ts new file mode 100644 index 0000000000..04e0ca58ab --- /dev/null +++ b/packages/data-schemas/src/types/config.ts @@ -0,0 +1,36 @@ +import { PrincipalType, PrincipalModel } from 'librechat-data-provider'; +import type { TCustomConfig } from 'librechat-data-provider'; +import type { Document, Types } from 'mongoose'; + +/** + * Configuration override for a principal (user, group, or role). + * Stores partial overrides at the TCustomConfig (YAML) level, + * which are merged with the base config before processing through AppService. 
+ */ +export type Config = { + /** The type of principal (user, group, role) */ + principalType: PrincipalType; + /** The ID of the principal (ObjectId for users/groups, string for roles) */ + principalId: Types.ObjectId | string; + /** The model name for the principal */ + principalModel: PrincipalModel; + /** Priority level for determining merge order (higher = more specific) */ + priority: number; + /** Configuration overrides matching librechat.yaml structure */ + overrides: Partial; + /** Whether this config override is currently active */ + isActive: boolean; + /** Version number for cache invalidation, auto-increments on overrides change */ + configVersion: number; + /** Tenant identifier for multi-tenancy isolation */ + tenantId?: string; + /** When this config was created */ + createdAt?: Date; + /** When this config was last updated */ + updatedAt?: Date; +}; + +export type IConfig = Config & + Document & { + _id: Types.ObjectId; + }; diff --git a/packages/data-schemas/src/types/index.ts b/packages/data-schemas/src/types/index.ts index 26238cbda1..748ea5d77d 100644 --- a/packages/data-schemas/src/types/index.ts +++ b/packages/data-schemas/src/types/index.ts @@ -28,6 +28,10 @@ export * from './accessRole'; export * from './aclEntry'; export * from './systemGrant'; export * from './group'; +/* Config */ +export * from './config'; +/* Admin */ +export * from './admin'; /* Web */ export * from './web'; /* MCP Servers */ diff --git a/packages/data-schemas/src/types/systemGrant.ts b/packages/data-schemas/src/types/systemGrant.ts index 9f0d576503..09cff1aec6 100644 --- a/packages/data-schemas/src/types/systemGrant.ts +++ b/packages/data-schemas/src/types/systemGrant.ts @@ -1,6 +1,6 @@ import type { Document, Types } from 'mongoose'; import type { PrincipalType } from 'librechat-data-provider'; -import type { SystemCapability } from '~/systemCapabilities'; +import type { SystemCapability } from '~/types/admin'; export type SystemGrant = { /** The type of 
principal โ€” matches PrincipalType enum values */ From df82f2e9b221e21ad8a02de52019927c445c780b Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Thu, 26 Mar 2026 12:27:31 -0400 Subject: [PATCH 02/18] =?UTF-8?q?=F0=9F=8F=81=20fix:=20Invalidate=20Messag?= =?UTF-8?q?e=20Cache=20on=20Stream=20404=20Instead=20of=20Showing=20Error?= =?UTF-8?q?=20(#12411)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: Invalidate message cache on STREAM_EXPIRED instead of showing error When a 404 (stream expired) is received during SSE resume, the generation has already completed and messages are persisted in the database. Instead of injecting an error message into the cache, invalidate the messages query so react-query refetches from the DB. Also clear stale stream status cache and step maps to prevent retries and memory leaks. * fix: Mark conversation as processed when no active job found Prevents useResumeOnLoad from repeatedly re-checking the same conversation when the stream status returns inactive. The ref still resets on conversation change, so navigating away and back will correctly re-check. Also wait for background refetches to settle (isFetching) before acting on inactive status, preventing stale cached active:false from suppressing a valid resume. * test: Update useResumableSSE spec for cache invalidation on 404 Verify message cache invalidation, stream status removal, clearStepMaps, and setIsSubmitting(false) on the 404 path. * fix: Resolve lint warnings from CI Remove unused ErrorTypes import in test, add queryClient to useCallback dependency array in useResumableSSE. 
* Reorder import statements in useResumableSSE.ts --- .../SSE/__tests__/useResumableSSE.spec.ts | 37 ++++++++++++------- client/src/hooks/SSE/useResumableSSE.ts | 28 +++++++++----- client/src/hooks/SSE/useResumeOnLoad.ts | 18 +++++---- 3 files changed, 53 insertions(+), 30 deletions(-) diff --git a/client/src/hooks/SSE/__tests__/useResumableSSE.spec.ts b/client/src/hooks/SSE/__tests__/useResumableSSE.spec.ts index 9100f39858..1717d27c22 100644 --- a/client/src/hooks/SSE/__tests__/useResumableSSE.spec.ts +++ b/client/src/hooks/SSE/__tests__/useResumableSSE.spec.ts @@ -1,5 +1,5 @@ import { renderHook, act } from '@testing-library/react'; -import { Constants, ErrorTypes, LocalStorageKeys } from 'librechat-data-provider'; +import { Constants, LocalStorageKeys } from 'librechat-data-provider'; import type { TSubmission } from 'librechat-data-provider'; type SSEEventListener = (e: Partial & { responseCode?: number }) => void; @@ -34,7 +34,13 @@ jest.mock('sse.js', () => ({ })); const mockSetQueryData = jest.fn(); -const mockQueryClient = { setQueryData: mockSetQueryData }; +const mockInvalidateQueries = jest.fn(); +const mockRemoveQueries = jest.fn(); +const mockQueryClient = { + setQueryData: mockSetQueryData, + invalidateQueries: mockInvalidateQueries, + removeQueries: mockRemoveQueries, +}; jest.mock('@tanstack/react-query', () => ({ ...jest.requireActual('@tanstack/react-query'), @@ -63,6 +69,7 @@ jest.mock('~/data-provider', () => ({ useGetStartupConfig: () => ({ data: { balance: { enabled: false } } }), useGetUserBalance: () => ({ refetch: jest.fn() }), queueTitleGeneration: jest.fn(), + streamStatusQueryKey: (conversationId: string) => ['streamStatus', conversationId], })); const mockErrorHandler = jest.fn(); @@ -162,6 +169,11 @@ describe('useResumableSSE - 404 error path', () => { beforeEach(() => { mockSSEInstances.length = 0; localStorage.clear(); + mockErrorHandler.mockClear(); + mockClearStepMaps.mockClear(); + mockSetIsSubmitting.mockClear(); + 
mockInvalidateQueries.mockClear(); + mockRemoveQueries.mockClear(); }); const seedDraft = (conversationId: string) => { @@ -200,19 +212,18 @@ describe('useResumableSSE - 404 error path', () => { unmount(); }); - it('calls errorHandler with STREAM_EXPIRED error type on 404', async () => { + it('invalidates message cache and clears stream status on 404 instead of showing error', async () => { const { unmount } = await render404Scenario(CONV_ID); - expect(mockErrorHandler).toHaveBeenCalledTimes(1); - const call = mockErrorHandler.mock.calls[0][0]; - expect(call.data).toBeDefined(); - const parsed = JSON.parse(call.data.text); - expect(parsed.type).toBe(ErrorTypes.STREAM_EXPIRED); - expect(call.submission).toEqual( - expect.objectContaining({ - conversation: expect.objectContaining({ conversationId: CONV_ID }), - }), - ); + expect(mockErrorHandler).not.toHaveBeenCalled(); + expect(mockInvalidateQueries).toHaveBeenCalledWith({ + queryKey: ['messages', CONV_ID], + }); + expect(mockRemoveQueries).toHaveBeenCalledWith({ + queryKey: ['streamStatus', CONV_ID], + }); + expect(mockClearStepMaps).toHaveBeenCalled(); + expect(mockSetIsSubmitting).toHaveBeenCalledWith(false); unmount(); }); diff --git a/client/src/hooks/SSE/useResumableSSE.ts b/client/src/hooks/SSE/useResumableSSE.ts index 32820f8392..39dc610dae 100644 --- a/client/src/hooks/SSE/useResumableSSE.ts +++ b/client/src/hooks/SSE/useResumableSSE.ts @@ -16,7 +16,12 @@ import { } from 'librechat-data-provider'; import type { TMessage, TPayload, TSubmission, EventSubmission } from 'librechat-data-provider'; import type { EventHandlerParams } from './useEventHandlers'; -import { useGetStartupConfig, useGetUserBalance, queueTitleGeneration } from '~/data-provider'; +import { + useGetUserBalance, + useGetStartupConfig, + queueTitleGeneration, + streamStatusQueryKey, +} from '~/data-provider'; import type { ActiveJobsResponse } from '~/data-provider'; import { useAuthContext } from '~/hooks/AuthContext'; import 
useEventHandlers from './useEventHandlers'; @@ -343,18 +348,20 @@ export default function useResumableSSE( /* @ts-ignore - sse.js types don't expose responseCode */ const responseCode = e.responseCode; - // 404 means job doesn't exist (completed/deleted) - don't retry + // 404 โ†’ job completed & was cleaned up; messages are persisted in DB. + // Invalidate cache once so react-query refetches instead of showing an error. if (responseCode === 404) { - console.log('[ResumableSSE] Stream not found (404) - job completed or expired'); + const convoId = currentSubmission.conversation?.conversationId; + console.log('[ResumableSSE] Stream 404, invalidating messages for:', convoId); sse.close(); removeActiveJob(currentStreamId); - clearAllDrafts(currentSubmission.conversation?.conversationId); - errorHandler({ - data: { - text: JSON.stringify({ type: ErrorTypes.STREAM_EXPIRED }), - } as unknown as Parameters[0]['data'], - submission: currentSubmission as EventSubmission, - }); + clearAllDrafts(convoId); + clearStepMaps(); + if (convoId) { + queryClient.invalidateQueries({ queryKey: [QueryKeys.messages, convoId] }); + queryClient.removeQueries({ queryKey: streamStatusQueryKey(convoId) }); + } + setIsSubmitting(false); setShowStopButton(false); setStreamId(null); reconnectAttemptRef.current = 0; @@ -544,6 +551,7 @@ export default function useResumableSSE( startupConfig?.balance?.enabled, balanceQuery, removeActiveJob, + queryClient, ], ); diff --git a/client/src/hooks/SSE/useResumeOnLoad.ts b/client/src/hooks/SSE/useResumeOnLoad.ts index f09751db0e..5f0f691787 100644 --- a/client/src/hooks/SSE/useResumeOnLoad.ts +++ b/client/src/hooks/SSE/useResumeOnLoad.ts @@ -125,7 +125,11 @@ export default function useResumeOnLoad( conversationId !== Constants.NEW_CONVO && processedConvoRef.current !== conversationId; // Don't re-check processed convos - const { data: streamStatus, isSuccess } = useStreamStatus(conversationId, shouldCheck); + const { + data: streamStatus, + isSuccess, + 
isFetching, + } = useStreamStatus(conversationId, shouldCheck); useEffect(() => { console.log('[ResumeOnLoad] Effect check', { @@ -135,6 +139,7 @@ export default function useResumeOnLoad( hasCurrentSubmission: !!currentSubmission, currentSubmissionConvoId: currentSubmission?.conversation?.conversationId, isSuccess, + isFetching, streamStatusActive: streamStatus?.active, streamStatusStreamId: streamStatus?.streamId, processedConvoRef: processedConvoRef.current, @@ -171,8 +176,9 @@ export default function useResumeOnLoad( ); } - // Wait for stream status query to complete - if (!isSuccess || !streamStatus) { + // Wait for stream status query to complete (including background refetches + // that may replace a stale cached result with fresh data) + if (!isSuccess || !streamStatus || isFetching) { console.log('[ResumeOnLoad] Waiting for stream status query'); return; } @@ -183,15 +189,12 @@ export default function useResumeOnLoad( return; } - // Check if there's an active job to resume - // DON'T mark as processed here - only mark when we actually create a submission - // This prevents stale cache data from blocking subsequent resume attempts if (!streamStatus.active || !streamStatus.streamId) { console.log('[ResumeOnLoad] No active job to resume for:', conversationId); + processedConvoRef.current = conversationId; return; } - // Mark as processed NOW - we verified there's an active job and will create submission processedConvoRef.current = conversationId; console.log('[ResumeOnLoad] Found active job, creating submission...', { @@ -241,6 +244,7 @@ export default function useResumeOnLoad( submissionConvoId, currentSubmission, isSuccess, + isFetching, streamStatus, getMessages, setSubmission, From 1123f96e6a0b9dbdee35488b9449d584c3a371e8 Mon Sep 17 00:00:00 2001 From: Marco Beretta <81851188+berry-13@users.noreply.github.com> Date: Thu, 26 Mar 2026 18:43:33 +0100 Subject: [PATCH 03/18] =?UTF-8?q?=F0=9F=93=9D=20docs:=20add=20UTM=20tracki?= 
=?UTF-8?q?ng=20parameters=20to=20Railway=20deployment=20links=20(#12228)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7da34974e3..3e05dc686b 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@

- + Deploy on Railway From 359cc63b41383c6e8b3eb74dc072065abbbe7ef3 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Thu, 26 Mar 2026 14:44:31 -0400 Subject: [PATCH 04/18] =?UTF-8?q?=E2=9A=A1=20refactor:=20Use=20in-memory?= =?UTF-8?q?=20cache=20for=20App=20MCP=20configs=20to=20avoid=20Redis=20SCA?= =?UTF-8?q?N=20(#12410)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * โšก perf: Use in-memory cache for App MCP configs to avoid Redis SCAN The 'App' namespace holds static YAML-loaded configs identical on every instance. Storing them in Redis and retrieving via SCAN + batch-GET caused 60s+ stalls under concurrent load (#11624). Since these configs are already loaded into memory at startup, bypass Redis entirely by always returning ServerConfigsCacheInMemory for the 'App' namespace. * โ™ป๏ธ refactor: Extract APP_CACHE_NAMESPACE constant and harden tests - Extract magic string 'App' to a shared `APP_CACHE_NAMESPACE` constant used by both ServerConfigsCacheFactory and MCPServersRegistry - Document that `leaderOnly` is ignored for the App namespace - Reset `cacheConfig.USE_REDIS` in test `beforeEach` to prevent ordering-dependent flakiness - Fix import order in test file (longest to shortest) * ๐Ÿ› fix: Populate App cache on follower instances in cluster mode In cluster deployments, only the leader runs MCPServersInitializer to inspect and cache MCP server configs. Followers previously read these from Redis, but with the App namespace now using in-memory storage, followers would have an empty cache. Add populateLocalCache() so follower processes independently initialize their own in-memory App cache from the same YAML configs after the leader signals completion. The method is idempotent โ€” if the cache is already populated (leader case), it's a no-op. * ๐Ÿ› fix: Use static flag for populateLocalCache idempotency Replace getAllServerConfigs() idempotency check with a static localCachePopulated flag. 
The previous check merged App + DB caches, causing false early returns in deployments with publicly shared DB configs, and poisoned the TTL read-through cache with stale results. The static flag is zero-cost (no async/Redis/DB calls), immune to DB config interference, and is reset alongside hasInitializedThisProcess in resetProcessFlag() for test teardown. Also set localCachePopulated=true after leader initialization completes, so subsequent calls on the leader don't redundantly re-run populateLocalCache. * ๐Ÿ“ docs: Document process-local reset() semantics for App cache With the App namespace using in-memory storage, reset() only clears the calling process's cache. Add JSDoc noting this behavioral change so callers in cluster deployments know each instance must reset independently. * โœ… test: Add follower cache population tests for MCPServersInitializer Cover the populateLocalCache code path: - Follower populates its own App cache after leader signals completion - localCachePopulated flag prevents redundant re-initialization - Fresh follower process independently initializes all servers * ๐Ÿงน style: Fix import order to longest-to-shortest convention * ๐Ÿ”ฌ test: Add Redis perf benchmark to isolate getAll() bottleneck Benchmarks that run against a live Redis instance to measure: 1. SCAN vs batched GET phases independently 2. SCAN cost scaling with total keyspace size (noise keys) 3. Concurrent getAll() at various concurrency levels (1/10/50/100) 4. Alternative: single aggregate key vs SCAN+GET 5. Alternative: raw MGET vs Keyv batch GET (serialization overhead) Run with: npx jest --config packages/api/jest.config.mjs \ --testPathPatterns="perf_benchmark" --coverage=false * โšก feat: Add aggregate-key Redis cache for MCP App configs ServerConfigsCacheRedisAggregateKey stores all configs under a single Redis key, making getAll() a single GET instead of SCAN + N GETs. 
This eliminates the O(keyspace_size) SCAN that caused 60s+ stalls in large deployments while preserving cross-instance visibility โ€” all instances read/write the same Redis key, so reinspection results propagate automatically after readThroughCache TTL expiry. * โ™ป๏ธ refactor: Use aggregate-key cache for App namespace in factory Update ServerConfigsCacheFactory to return ServerConfigsCacheRedisAggregateKey for the App namespace when Redis is enabled, instead of ServerConfigsCacheInMemory. This preserves cross-instance visibility (reinspection results propagate through Redis) while eliminating SCAN. Non-App namespaces still use the standard per-key ServerConfigsCacheRedis. * ๐Ÿ—‘๏ธ revert: Remove populateLocalCache โ€” no longer needed with aggregate key With App configs stored under a single Redis key (aggregate approach), followers read from Redis like before. The populateLocalCache mechanism and its localCachePopulated flag are no longer necessary. Also reverts the process-local reset() JSDoc since reset() is now cluster-wide again via Redis. * ๐Ÿ› fix: Add write mutex to aggregate cache and exclude perf benchmark from CI - Add promise-based write lock to ServerConfigsCacheRedisAggregateKey to prevent concurrent read-modify-write races during parallel initialization (Promise.allSettled runs multiple addServer calls concurrently, causing last-write-wins data loss on the aggregate key) - Rename perf benchmark to cache_integration pattern so CI skips it (requires live Redis) * ๐Ÿ”ง fix: Rename perf benchmark to *.manual.spec.ts to exclude from all CI The cache_integration pattern is picked up by test:cache-integration:mcp in CI. Rename to *.manual.spec.ts which isn't matched by any CI runner. 
* โœ… test: Add cache integration tests for ServerConfigsCacheRedisAggregateKey Tests against a live Redis instance covering: - CRUD operations (add, get, update, remove) - getAll with empty/populated cache - Duplicate add rejection, missing update/remove errors - Concurrent write safety (20 parallel adds without data loss) - Concurrent read safety (50 parallel getAll calls) - Reset clears all configs * ๐Ÿ”ง fix: Rename perf benchmark to *.manual.spec.ts to exclude from all CI The perf benchmark file was renamed to *.manual.spec.ts but no testPathIgnorePatterns existed for that convention. Add .*manual\.spec\. to both test and test:ci scripts, plus jest.config.mjs, so manual-only tests never run in CI unit test jobs. * fix: Address review findings for aggregate key cache - Add successCheck() to all write paths (add/update/remove) so Redis SET failures throw instead of being silently swallowed - Override reset() to use targeted cache.delete(AGGREGATE_KEY) instead of inherited SCAN-based cache.clear() โ€” consistent with eliminating SCAN operations - Document cross-instance write race invariant in class JSDoc: the promise-based writeLock is process-local only; callers must enforce single-writer semantics externally (leader-only init) - Use definite-assignment assertion (let resolve!:) instead of non-null assertion at call site - Fix import type convention in integration test - Verify Promise.allSettled rejections explicitly in concurrent write test - Fix broken run command in benchmark file header * style: Fix import ordering per AGENTS.md convention Local/project imports sorted longest to shortest. 
* chore: Update import ordering and clean up unused imports in MCPServersRegistry.ts * chore: import order * chore: import order --- packages/api/jest.config.mjs | 1 + packages/api/package.json | 4 +- .../src/mcp/registry/MCPServersRegistry.ts | 4 +- .../cache/ServerConfigsCacheFactory.ts | 48 ++- .../ServerConfigsCacheRedisAggregateKey.ts | 136 +++++++ .../ServerConfigsCacheFactory.test.ts | 47 ++- ...gsCacheRedis.perf_benchmark.manual.spec.ts | 336 ++++++++++++++++++ ...edisAggregateKey.cache_integration.spec.ts | 246 +++++++++++++ 8 files changed, 779 insertions(+), 43 deletions(-) create mode 100644 packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts create mode 100644 packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.perf_benchmark.manual.spec.ts create mode 100644 packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedisAggregateKey.cache_integration.spec.ts diff --git a/packages/api/jest.config.mjs b/packages/api/jest.config.mjs index df9cf6bcc2..976b794122 100644 --- a/packages/api/jest.config.mjs +++ b/packages/api/jest.config.mjs @@ -8,6 +8,7 @@ export default { '\\.helper\\.ts$', '\\.helper\\.d\\.ts$', '/__tests__/helpers/', + '\\.manual\\.spec\\.[jt]sx?$', ], coverageReporters: ['text', 'cobertura'], testResultsProcessor: 'jest-junit', diff --git a/packages/api/package.json b/packages/api/package.json index a4e74a7a3c..f09d946ec5 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -18,8 +18,8 @@ "build:dev": "npm run clean && NODE_ENV=development rollup -c --bundleConfigAsCjs", "build:watch": "NODE_ENV=development rollup -c -w --bundleConfigAsCjs", "build:watch:prod": "rollup -c -w --bundleConfigAsCjs", - "test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/\"", - "test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/\"", + "test": "jest --coverage --watch 
--testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/|\\.*manual\\.spec\\.\"", + "test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/|\\.*manual\\.spec\\.\"", "test:cache-integration:core": "jest --testPathPatterns=\"src/cache/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false", "test:cache-integration:cluster": "jest --testPathPatterns=\"src/cluster/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false --runInBand", "test:cache-integration:mcp": "jest --testPathPatterns=\"src/mcp/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false", diff --git a/packages/api/src/mcp/registry/MCPServersRegistry.ts b/packages/api/src/mcp/registry/MCPServersRegistry.ts index 506f5b1baa..b9c1eb66f5 100644 --- a/packages/api/src/mcp/registry/MCPServersRegistry.ts +++ b/packages/api/src/mcp/registry/MCPServersRegistry.ts @@ -2,8 +2,8 @@ import { Keyv } from 'keyv'; import { logger } from '@librechat/data-schemas'; import type { IServerConfigsRepositoryInterface } from './ServerConfigsRepositoryInterface'; import type * as t from '~/mcp/types'; +import { ServerConfigsCacheFactory, APP_CACHE_NAMESPACE } from './cache/ServerConfigsCacheFactory'; import { MCPInspectionFailedError, isMCPDomainNotAllowedError } from '~/mcp/errors'; -import { ServerConfigsCacheFactory } from './cache/ServerConfigsCacheFactory'; import { MCPServerInspector } from './MCPServerInspector'; import { ServerConfigsDB } from './db/ServerConfigsDB'; import { cacheConfig } from '~/cache/cacheConfig'; @@ -33,7 +33,7 @@ export class MCPServersRegistry { constructor(mongoose: typeof import('mongoose'), allowedDomains?: string[] | null) { this.dbConfigsRepo = new ServerConfigsDB(mongoose); - this.cacheConfigsRepo = ServerConfigsCacheFactory.create('App', false); + this.cacheConfigsRepo = ServerConfigsCacheFactory.create(APP_CACHE_NAMESPACE, false); this.allowedDomains = allowedDomains; const ttl = cacheConfig.MCP_REGISTRY_CACHE_TTL; 
diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts index ba0cec90ea..b9549629d6 100644 --- a/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts @@ -1,31 +1,51 @@ -import { cacheConfig } from '~/cache'; +import { ServerConfigsCacheRedisAggregateKey } from './ServerConfigsCacheRedisAggregateKey'; import { ServerConfigsCacheInMemory } from './ServerConfigsCacheInMemory'; import { ServerConfigsCacheRedis } from './ServerConfigsCacheRedis'; +import { cacheConfig } from '~/cache'; -export type ServerConfigsCache = ServerConfigsCacheInMemory | ServerConfigsCacheRedis; +export type ServerConfigsCache = + | ServerConfigsCacheInMemory + | ServerConfigsCacheRedis + | ServerConfigsCacheRedisAggregateKey; /** - * Factory for creating the appropriate ServerConfigsCache implementation based on deployment mode. - * Automatically selects between in-memory and Redis-backed storage depending on USE_REDIS config. - * In single-instance mode (USE_REDIS=false), returns lightweight in-memory cache. - * In cluster mode (USE_REDIS=true), returns Redis-backed cache with distributed coordination. - * Provides a unified interface regardless of the underlying storage mechanism. + * Namespace for YAML-loaded app-level MCP configs. When Redis is enabled, uses a single + * aggregate key instead of per-server keys to avoid the costly SCAN + batch-GET pattern + * in {@link ServerConfigsCacheRedis.getAll} that caused 60s+ stalls under concurrent + * load (see GitHub #11624, #12408). When Redis is disabled, uses in-memory storage. + */ +export const APP_CACHE_NAMESPACE = 'App' as const; + +/** + * Factory for creating the appropriate ServerConfigsCache implementation based on + * deployment mode and namespace. 
+ * + * The {@link APP_CACHE_NAMESPACE} namespace uses {@link ServerConfigsCacheRedisAggregateKey} + * when Redis is enabled โ€” storing all configs under a single key so `getAll()` is one GET + * instead of SCAN + N GETs. Cross-instance visibility is preserved: reinspection results + * propagate through Redis automatically. + * + * Other namespaces use the standard {@link ServerConfigsCacheRedis} (per-key storage with + * SCAN-based enumeration) when Redis is enabled. */ export class ServerConfigsCacheFactory { /** * Create a ServerConfigsCache instance. - * Returns Redis implementation if Redis is configured, otherwise in-memory implementation. * - * @param namespace - The namespace for the cache (e.g., 'App') - only used for Redis namespacing - * @param leaderOnly - Whether operations should only be performed by the leader (only applies to Redis) + * @param namespace - The namespace for the cache. {@link APP_CACHE_NAMESPACE} uses + * aggregate-key Redis storage (or in-memory when Redis is disabled). + * @param leaderOnly - Whether write operations should only be performed by the leader. 
* @returns ServerConfigsCache instance */ static create(namespace: string, leaderOnly: boolean): ServerConfigsCache { - if (cacheConfig.USE_REDIS) { - return new ServerConfigsCacheRedis(namespace, leaderOnly); + if (!cacheConfig.USE_REDIS) { + return new ServerConfigsCacheInMemory(); } - // In-memory mode uses a simple Map - doesn't need namespace - return new ServerConfigsCacheInMemory(); + if (namespace === APP_CACHE_NAMESPACE) { + return new ServerConfigsCacheRedisAggregateKey(namespace, leaderOnly); + } + + return new ServerConfigsCacheRedis(namespace, leaderOnly); } } diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts new file mode 100644 index 0000000000..12f423a1fb --- /dev/null +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts @@ -0,0 +1,136 @@ +import { logger } from '@librechat/data-schemas'; +import type Keyv from 'keyv'; +import type { IServerConfigsRepositoryInterface } from '~/mcp/registry/ServerConfigsRepositoryInterface'; +import type { ParsedServerConfig, AddServerResult } from '~/mcp/types'; +import { BaseRegistryCache } from './BaseRegistryCache'; +import { standardCache } from '~/cache'; + +/** + * Redis-backed MCP server configs cache that stores all entries under a single aggregate key. + * + * Unlike {@link ServerConfigsCacheRedis} which uses SCAN + batch-GET for `getAll()`, this + * implementation stores the entire config map as a single JSON value in Redis. This makes + * `getAll()` a single O(1) GET regardless of keyspace size, eliminating the 60s+ stalls + * caused by SCAN under concurrent load in large deployments (see GitHub #11624, #12408). + * + * Trade-offs: + * - `add/update/remove` use a serialized read-modify-write on the aggregate key via a + * promise-based mutex. 
This prevents concurrent writes from racing within a single + * process (e.g., during `Promise.allSettled` initialization of multiple servers). + * - The entire config map is serialized/deserialized on every operation. With typical MCP + * deployments (~5-50 servers), the JSON payload is small (10-50KB). + * - Cross-instance visibility is preserved: all instances read/write the same Redis key, + * so reinspection results propagate automatically after readThroughCache TTL expiry. + * + * IMPORTANT: The promise-based writeLock serializes writes within a single Node.js process + * only. Concurrent writes from separate instances race at the Redis level (last-write-wins). + * This is acceptable because writes are performed exclusively by the leader during + * initialization via {@link MCPServersInitializer}. `reinspectServer` is manual and rare. + * Callers must enforce this single-writer invariant externally. + */ +const AGGREGATE_KEY = '__all__'; + +export class ServerConfigsCacheRedisAggregateKey + extends BaseRegistryCache + implements IServerConfigsRepositoryInterface +{ + protected readonly cache: Keyv; + private writeLock: Promise = Promise.resolve(); + + constructor(namespace: string, leaderOnly: boolean) { + super(leaderOnly); + this.cache = standardCache(`${this.PREFIX}::Servers::${namespace}`); + } + + /** + * Serializes write operations to prevent concurrent read-modify-write races. + * Reads (`get`, `getAll`) are not serialized โ€” they can run concurrently. 
+ */ + private async withWriteLock(fn: () => Promise): Promise { + const previousLock = this.writeLock; + let resolve!: () => void; + this.writeLock = new Promise((r) => { + resolve = r; + }); + try { + await previousLock; + return await fn(); + } finally { + resolve(); + } + } + + public async getAll(): Promise> { + const startTime = Date.now(); + const result = (await this.cache.get(AGGREGATE_KEY)) as + | Record + | undefined; + const elapsed = Date.now() - startTime; + logger.debug( + `[ServerConfigsCacheRedisAggregateKey] getAll: fetched ${result ? Object.keys(result).length : 0} configs in ${elapsed}ms`, + ); + return result ?? {}; + } + + public async get(serverName: string): Promise { + const all = await this.getAll(); + return all[serverName]; + } + + public async add(serverName: string, config: ParsedServerConfig): Promise { + if (this.leaderOnly) await this.leaderCheck('add MCP servers'); + return this.withWriteLock(async () => { + const all = await this.getAll(); + if (all[serverName]) { + throw new Error( + `Server "${serverName}" already exists in cache. Use update() to modify existing configs.`, + ); + } + const storedConfig = { ...config, updatedAt: Date.now() }; + all[serverName] = storedConfig; + const success = await this.cache.set(AGGREGATE_KEY, all); + this.successCheck(`add App server "${serverName}"`, success); + return { serverName, config: storedConfig }; + }); + } + + public async update(serverName: string, config: ParsedServerConfig): Promise { + if (this.leaderOnly) await this.leaderCheck('update MCP servers'); + return this.withWriteLock(async () => { + const all = await this.getAll(); + if (!all[serverName]) { + throw new Error( + `Server "${serverName}" does not exist in cache. 
Use add() to create new configs.`, + ); + } + all[serverName] = { ...config, updatedAt: Date.now() }; + const success = await this.cache.set(AGGREGATE_KEY, all); + this.successCheck(`update App server "${serverName}"`, success); + }); + } + + public async remove(serverName: string): Promise { + if (this.leaderOnly) await this.leaderCheck('remove MCP servers'); + return this.withWriteLock(async () => { + const all = await this.getAll(); + if (!all[serverName]) { + throw new Error(`Failed to remove server "${serverName}" in cache.`); + } + delete all[serverName]; + const success = await this.cache.set(AGGREGATE_KEY, all); + this.successCheck(`remove App server "${serverName}"`, success); + }); + } + + /** + * Resets the aggregate key directly instead of using SCAN-based `cache.clear()`. + * Only one key (`__all__`) ever exists in this namespace, so a targeted delete is + * more efficient and consistent with the PR's goal of eliminating SCAN operations. + */ + public override async reset(): Promise { + if (this.leaderOnly) { + await this.leaderCheck('reset App MCP servers cache'); + } + await this.cache.delete(AGGREGATE_KEY); + } +} diff --git a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts index 7499ae127e..577b878cc7 100644 --- a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts +++ b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts @@ -1,9 +1,11 @@ -import { ServerConfigsCacheFactory } from '../ServerConfigsCacheFactory'; +import { ServerConfigsCacheFactory, APP_CACHE_NAMESPACE } from '../ServerConfigsCacheFactory'; +import { ServerConfigsCacheRedisAggregateKey } from '../ServerConfigsCacheRedisAggregateKey'; import { ServerConfigsCacheInMemory } from '../ServerConfigsCacheInMemory'; import { ServerConfigsCacheRedis } from '../ServerConfigsCacheRedis'; import { cacheConfig } from '~/cache'; 
// Mock the cache implementations +jest.mock('../ServerConfigsCacheRedisAggregateKey'); jest.mock('../ServerConfigsCacheInMemory'); jest.mock('../ServerConfigsCacheRedis'); @@ -17,53 +19,48 @@ jest.mock('~/cache', () => ({ describe('ServerConfigsCacheFactory', () => { beforeEach(() => { jest.clearAllMocks(); + cacheConfig.USE_REDIS = false; }); describe('create()', () => { - it('should return ServerConfigsCacheRedis when USE_REDIS is true', () => { - // Arrange + it('should return ServerConfigsCacheRedisAggregateKey for App namespace when USE_REDIS is true', () => { cacheConfig.USE_REDIS = true; - // Act - const cache = ServerConfigsCacheFactory.create('App', true); + const cache = ServerConfigsCacheFactory.create(APP_CACHE_NAMESPACE, false); - // Assert - expect(cache).toBeInstanceOf(ServerConfigsCacheRedis); - expect(ServerConfigsCacheRedis).toHaveBeenCalledWith('App', true); + expect(cache).toBeInstanceOf(ServerConfigsCacheRedisAggregateKey); + expect(ServerConfigsCacheRedisAggregateKey).toHaveBeenCalledWith(APP_CACHE_NAMESPACE, false); + expect(ServerConfigsCacheRedis).not.toHaveBeenCalled(); + expect(ServerConfigsCacheInMemory).not.toHaveBeenCalled(); }); - it('should return ServerConfigsCacheInMemory when USE_REDIS is false', () => { - // Arrange + it('should return ServerConfigsCacheInMemory for App namespace when USE_REDIS is false', () => { cacheConfig.USE_REDIS = false; - // Act - const cache = ServerConfigsCacheFactory.create('App', false); + const cache = ServerConfigsCacheFactory.create(APP_CACHE_NAMESPACE, false); - // Assert expect(cache).toBeInstanceOf(ServerConfigsCacheInMemory); - expect(ServerConfigsCacheInMemory).toHaveBeenCalled(); + expect(ServerConfigsCacheInMemory).toHaveBeenCalledWith(); + expect(ServerConfigsCacheRedis).not.toHaveBeenCalled(); + expect(ServerConfigsCacheRedisAggregateKey).not.toHaveBeenCalled(); }); - it('should pass correct parameters to ServerConfigsCacheRedis', () => { - // Arrange + it('should return 
ServerConfigsCacheRedis for non-App namespaces when USE_REDIS is true', () => { cacheConfig.USE_REDIS = true; - // Act - ServerConfigsCacheFactory.create('CustomNamespace', true); + const cache = ServerConfigsCacheFactory.create('CustomNamespace', true); - // Assert + expect(cache).toBeInstanceOf(ServerConfigsCacheRedis); expect(ServerConfigsCacheRedis).toHaveBeenCalledWith('CustomNamespace', true); + expect(ServerConfigsCacheRedisAggregateKey).not.toHaveBeenCalled(); }); - it('should create ServerConfigsCacheInMemory without parameters when USE_REDIS is false', () => { - // Arrange + it('should return ServerConfigsCacheInMemory for non-App namespaces when USE_REDIS is false', () => { cacheConfig.USE_REDIS = false; - // Act - ServerConfigsCacheFactory.create('App', false); + const cache = ServerConfigsCacheFactory.create('CustomNamespace', false); - // Assert - // In-memory cache doesn't use namespace/leaderOnly parameters + expect(cache).toBeInstanceOf(ServerConfigsCacheInMemory); expect(ServerConfigsCacheInMemory).toHaveBeenCalledWith(); }); }); diff --git a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.perf_benchmark.manual.spec.ts b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.perf_benchmark.manual.spec.ts new file mode 100644 index 0000000000..1815d49fe0 --- /dev/null +++ b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.perf_benchmark.manual.spec.ts @@ -0,0 +1,336 @@ +/** + * Performance benchmark for ServerConfigsCacheRedis.getAll() + * + * Requires a live Redis instance. Run manually (excluded from CI): + * npx jest --config packages/api/jest.config.mjs --testPathPatterns="perf_benchmark" --coverage=false + * + * Set env vars as needed: + * USE_REDIS=true REDIS_URI=redis://localhost:6379 npx jest ... + * + * This benchmark isolates the two phases of getAll() โ€” SCAN (key discovery) and + * batched GET (value retrieval) โ€” to identify the actual bottleneck under load. 
+ * It also benchmarks alternative approaches (single aggregate key, MGET) against + * the current SCAN+GET implementation. + */ +import { expect } from '@playwright/test'; +import type { ParsedServerConfig } from '~/mcp/types'; + +describe('ServerConfigsCacheRedis Performance Benchmark', () => { + let ServerConfigsCacheRedis: typeof import('../ServerConfigsCacheRedis').ServerConfigsCacheRedis; + let keyvRedisClient: Awaited['keyvRedisClient']; + let standardCache: Awaited['standardCache']; + + const PREFIX = 'perf-bench'; + + const makeConfig = (i: number): ParsedServerConfig => + ({ + type: 'stdio', + command: `cmd-${i}`, + args: [`arg-${i}`, `--flag-${i}`], + env: { KEY: `value-${i}`, EXTRA: `extra-${i}` }, + requiresOAuth: false, + tools: `tool_a_${i}, tool_b_${i}`, + capabilities: `{"tools":{"listChanged":true}}`, + serverInstructions: `Instructions for server ${i}`, + }) as ParsedServerConfig; + + beforeAll(async () => { + process.env.USE_REDIS = process.env.USE_REDIS ?? 'true'; + process.env.USE_REDIS_CLUSTER = process.env.USE_REDIS_CLUSTER ?? 'true'; + process.env.REDIS_URI = + process.env.REDIS_URI ?? + 'redis://127.0.0.1:7001,redis://127.0.0.1:7002,redis://127.0.0.1:7003'; + process.env.REDIS_KEY_PREFIX = process.env.REDIS_KEY_PREFIX ?? 
'perf-bench-test'; + + const cacheModule = await import('../ServerConfigsCacheRedis'); + const redisClients = await import('~/cache/redisClients'); + const cacheFactory = await import('~/cache'); + + ServerConfigsCacheRedis = cacheModule.ServerConfigsCacheRedis; + keyvRedisClient = redisClients.keyvRedisClient; + standardCache = cacheFactory.standardCache; + + if (!keyvRedisClient) throw new Error('Redis client is not initialized'); + await redisClients.keyvRedisClientReady; + }); + + afterAll(async () => { + if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect(); + }); + + /** Clean up all keys matching our test prefix */ + async function cleanupKeys(pattern: string): Promise { + if (!keyvRedisClient || !('scanIterator' in keyvRedisClient)) return; + const keys: string[] = []; + for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) { + keys.push(key); + } + if (keys.length > 0) { + await Promise.all(keys.map((key) => keyvRedisClient!.del(key))); + } + } + + /** Populate a cache with N configs and return the cache instance */ + async function populateCache( + namespace: string, + count: number, + ): Promise> { + const cache = new ServerConfigsCacheRedis(namespace, false); + for (let i = 0; i < count; i++) { + await cache.add(`server-${i}`, makeConfig(i)); + } + return cache; + } + + /** + * Benchmark 1: Isolate SCAN vs GET phases in current getAll() + * + * Measures time spent in each phase separately to identify the bottleneck. 
+ */ + describe('Phase isolation: SCAN vs batched GET', () => { + const CONFIG_COUNTS = [5, 20, 50]; + + for (const count of CONFIG_COUNTS) { + it(`should measure SCAN and GET phases separately for ${count} configs`, async () => { + const ns = `${PREFIX}-phase-${count}`; + const cache = await populateCache(ns, count); + + try { + // Get the Keyv cache instance namespace for pattern matching + const keyvCache = standardCache(`MCP::ServersRegistry::Servers::${ns}`); + const pattern = `*MCP::ServersRegistry::Servers::${ns}:*`; + + // Phase 1: SCAN only (key discovery) + const scanStart = Date.now(); + const keys: string[] = []; + for await (const key of keyvRedisClient!.scanIterator({ MATCH: pattern })) { + keys.push(key); + } + const scanMs = Date.now() - scanStart; + + // Phase 2: Batched GET only (value retrieval via Keyv) + const keyNames = keys.map((key) => key.substring(key.lastIndexOf(':') + 1)); + const BATCH_SIZE = 100; + const getStart = Date.now(); + for (let i = 0; i < keyNames.length; i += BATCH_SIZE) { + const batch = keyNames.slice(i, i + BATCH_SIZE); + await Promise.all(batch.map((k) => keyvCache.get(k))); + } + const getMs = Date.now() - getStart; + + // Phase 3: Full getAll() (both phases combined) + const fullStart = Date.now(); + const result = await cache.getAll(); + const fullMs = Date.now() - fullStart; + + console.log( + `[${count} configs] SCAN: ${scanMs}ms | GET: ${getMs}ms | Full getAll: ${fullMs}ms | Keys found: ${keys.length}`, + ); + + expect(Object.keys(result).length).toBe(count); + + // Clean up the Keyv instance + await keyvCache.clear(); + } finally { + await cleanupKeys(`*${ns}*`); + } + }); + } + }); + + /** + * Benchmark 2: SCAN cost scales with total Redis keyspace, not just matching keys + * + * Redis SCAN iterates the entire hash table and filters by pattern. With a large + * keyspace (many non-matching keys), SCAN takes longer even if few keys match. + * This test measures SCAN time with background noise keys. 
+ */ + describe('SCAN cost vs keyspace size', () => { + it('should measure SCAN latency with background noise keys', async () => { + const ns = `${PREFIX}-noise`; + const targetCount = 10; + + // Add target configs + const cache = await populateCache(ns, targetCount); + + // Add noise keys in a different namespace to inflate the keyspace + const noiseCount = 500; + const noiseCache = standardCache(`noise-namespace-${Date.now()}`); + for (let i = 0; i < noiseCount; i++) { + await noiseCache.set(`noise-${i}`, { data: `value-${i}` }); + } + + try { + const pattern = `*MCP::ServersRegistry::Servers::${ns}:*`; + + // Measure SCAN with noise + const scanStart = Date.now(); + const keys: string[] = []; + for await (const key of keyvRedisClient!.scanIterator({ MATCH: pattern })) { + keys.push(key); + } + const scanMs = Date.now() - scanStart; + + // Measure full getAll + const fullStart = Date.now(); + const result = await cache.getAll(); + const fullMs = Date.now() - fullStart; + + console.log( + `[${targetCount} configs + ${noiseCount} noise keys] SCAN: ${scanMs}ms | Full getAll: ${fullMs}ms`, + ); + + expect(Object.keys(result).length).toBe(targetCount); + } finally { + await noiseCache.clear(); + await cleanupKeys(`*${ns}*`); + } + }); + }); + + /** + * Benchmark 3: Concurrent getAll() calls (simulates the actual production bottleneck) + * + * Multiple users hitting /api/mcp/* simultaneously, all triggering getAll() + * after the 5s TTL read-through cache expires. 
+ */ + describe('Concurrent getAll() under load', () => { + const CONCURRENCY_LEVELS = [1, 10, 50, 100]; + const CONFIG_COUNT = 30; + + for (const concurrency of CONCURRENCY_LEVELS) { + it(`should measure ${concurrency} concurrent getAll() calls with ${CONFIG_COUNT} configs`, async () => { + const ns = `${PREFIX}-concurrent-${concurrency}`; + const cache = await populateCache(ns, CONFIG_COUNT); + + try { + const startTime = Date.now(); + const promises = Array.from({ length: concurrency }, () => cache.getAll()); + const results = await Promise.all(promises); + const elapsed = Date.now() - startTime; + + console.log( + `[${CONFIG_COUNT} configs x ${concurrency} concurrent] Total: ${elapsed}ms | Per-call avg: ${(elapsed / concurrency).toFixed(1)}ms`, + ); + + for (const result of results) { + expect(Object.keys(result).length).toBe(CONFIG_COUNT); + } + } finally { + await cleanupKeys(`*${ns}*`); + } + }); + } + }); + + /** + * Benchmark 4: Alternative โ€” Single aggregate key + * + * Instead of SCAN+GET, store all configs under one Redis key. + * getAll() becomes a single GET + JSON parse. 
+ */ + describe('Alternative: Single aggregate key', () => { + it('should compare aggregate key vs SCAN+GET for getAll()', async () => { + const ns = `${PREFIX}-aggregate`; + const configCount = 30; + const cache = await populateCache(ns, configCount); + + // Build the aggregate object + const aggregate: Record = {}; + for (let i = 0; i < configCount; i++) { + aggregate[`server-${i}`] = makeConfig(i); + } + + // Store as single key + const aggregateCache = standardCache(`aggregate-test-${Date.now()}`); + await aggregateCache.set('all', aggregate); + + try { + // Measure SCAN+GET approach + const scanStart = Date.now(); + const scanResult = await cache.getAll(); + const scanMs = Date.now() - scanStart; + + // Measure single-key approach + const aggStart = Date.now(); + const aggResult = (await aggregateCache.get('all')) as Record; + const aggMs = Date.now() - aggStart; + + console.log( + `[${configCount} configs] SCAN+GET: ${scanMs}ms | Single key: ${aggMs}ms | Speedup: ${(scanMs / Math.max(aggMs, 1)).toFixed(1)}x`, + ); + + expect(Object.keys(scanResult).length).toBe(configCount); + expect(Object.keys(aggResult).length).toBe(configCount); + + // Concurrent comparison + const concurrency = 100; + const scanConcStart = Date.now(); + await Promise.all(Array.from({ length: concurrency }, () => cache.getAll())); + const scanConcMs = Date.now() - scanConcStart; + + const aggConcStart = Date.now(); + await Promise.all(Array.from({ length: concurrency }, () => aggregateCache.get('all'))); + const aggConcMs = Date.now() - aggConcStart; + + console.log( + `[${configCount} configs x ${concurrency} concurrent] SCAN+GET: ${scanConcMs}ms | Single key: ${aggConcMs}ms | Speedup: ${(scanConcMs / Math.max(aggConcMs, 1)).toFixed(1)}x`, + ); + } finally { + await aggregateCache.clear(); + await cleanupKeys(`*${ns}*`); + } + }); + }); + + /** + * Benchmark 5: Alternative โ€” Raw MGET (bypassing Keyv serialization overhead) + * + * Keyv wraps each value in { value, expires } JSON. 
Using raw MGET on the + * Redis client skips the Keyv layer entirely. + */ + describe('Alternative: Raw MGET vs Keyv batch GET', () => { + it('should compare raw MGET vs Keyv GET for value retrieval', async () => { + const ns = `${PREFIX}-mget`; + const configCount = 30; + const cache = await populateCache(ns, configCount); + + try { + // First, discover keys via SCAN (same for both approaches) + const pattern = `*MCP::ServersRegistry::Servers::${ns}:*`; + const keys: string[] = []; + for await (const key of keyvRedisClient!.scanIterator({ MATCH: pattern })) { + keys.push(key); + } + + // Approach 1: Keyv batch GET (current implementation) + const keyvCache = standardCache(`MCP::ServersRegistry::Servers::${ns}`); + const keyNames = keys.map((key) => key.substring(key.lastIndexOf(':') + 1)); + + const keyvStart = Date.now(); + await Promise.all(keyNames.map((k) => keyvCache.get(k))); + const keyvMs = Date.now() - keyvStart; + + // Approach 2: Raw MGET (no Keyv overhead) + const mgetStart = Date.now(); + if ('mGet' in keyvRedisClient!) 
{ + const rawValues = await ( + keyvRedisClient as { mGet: (keys: string[]) => Promise<(string | null)[]> } + ).mGet(keys); + // Parse the Keyv-wrapped JSON values + rawValues.filter(Boolean).map((v) => JSON.parse(v!)); + } + const mgetMs = Date.now() - mgetStart; + + console.log( + `[${configCount} configs] Keyv batch GET: ${keyvMs}ms | Raw MGET: ${mgetMs}ms | Speedup: ${(keyvMs / Math.max(mgetMs, 1)).toFixed(1)}x`, + ); + + // Clean up + await keyvCache.clear(); + } finally { + await cleanupKeys(`*${ns}*`); + } + }); + }); +}); diff --git a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedisAggregateKey.cache_integration.spec.ts b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedisAggregateKey.cache_integration.spec.ts new file mode 100644 index 0000000000..cbb75609d1 --- /dev/null +++ b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedisAggregateKey.cache_integration.spec.ts @@ -0,0 +1,246 @@ +import { expect } from '@playwright/test'; +import type { ParsedServerConfig } from '~/mcp/types'; + +describe('ServerConfigsCacheRedisAggregateKey Integration Tests', () => { + let ServerConfigsCacheRedisAggregateKey: typeof import('../ServerConfigsCacheRedisAggregateKey').ServerConfigsCacheRedisAggregateKey; + let keyvRedisClient: Awaited['keyvRedisClient']; + + let cache: InstanceType< + typeof import('../ServerConfigsCacheRedisAggregateKey').ServerConfigsCacheRedisAggregateKey + >; + + const mockConfig1 = { + type: 'stdio', + command: 'node', + args: ['server1.js'], + env: { TEST: 'value1' }, + } as ParsedServerConfig; + + const mockConfig2 = { + type: 'stdio', + command: 'python', + args: ['server2.py'], + env: { TEST: 'value2' }, + } as ParsedServerConfig; + + const mockConfig3 = { + type: 'sse', + url: 'http://localhost:3000', + requiresOAuth: true, + } as ParsedServerConfig; + + beforeAll(async () => { + process.env.USE_REDIS = process.env.USE_REDIS ?? 
'true'; + process.env.USE_REDIS_CLUSTER = process.env.USE_REDIS_CLUSTER ?? 'true'; + process.env.REDIS_URI = + process.env.REDIS_URI ?? + 'redis://127.0.0.1:7001,redis://127.0.0.1:7002,redis://127.0.0.1:7003'; + process.env.REDIS_KEY_PREFIX = process.env.REDIS_KEY_PREFIX ?? 'AggregateKey-IntegrationTest'; + + const cacheModule = await import('../ServerConfigsCacheRedisAggregateKey'); + const redisClients = await import('~/cache/redisClients'); + + ServerConfigsCacheRedisAggregateKey = cacheModule.ServerConfigsCacheRedisAggregateKey; + keyvRedisClient = redisClients.keyvRedisClient; + + if (!keyvRedisClient) throw new Error('Redis client is not initialized'); + await redisClients.keyvRedisClientReady; + }); + + beforeEach(() => { + cache = new ServerConfigsCacheRedisAggregateKey('agg-test', false); + }); + + afterEach(async () => { + await cache.reset(); + }); + + afterAll(async () => { + if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect(); + }); + + describe('add and get operations', () => { + it('should add and retrieve a server config', async () => { + await cache.add('server1', mockConfig1); + const result = await cache.get('server1'); + expect(result).toMatchObject(mockConfig1); + }); + + it('should return undefined for non-existent server', async () => { + const result = await cache.get('non-existent'); + expect(result).toBeUndefined(); + }); + + it('should throw error when adding duplicate server', async () => { + await cache.add('server1', mockConfig1); + await expect(cache.add('server1', mockConfig2)).rejects.toThrow( + 'Server "server1" already exists in cache. 
Use update() to modify existing configs.', + ); + }); + + it('should handle multiple server configs', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + await cache.add('server3', mockConfig3); + + expect(await cache.get('server1')).toMatchObject(mockConfig1); + expect(await cache.get('server2')).toMatchObject(mockConfig2); + expect(await cache.get('server3')).toMatchObject(mockConfig3); + }); + }); + + describe('getAll operation', () => { + it('should return empty object when no servers exist', async () => { + const result = await cache.getAll(); + expect(result).toMatchObject({}); + }); + + it('should return all server configs', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + await cache.add('server3', mockConfig3); + + const result = await cache.getAll(); + expect(result).toMatchObject({ + server1: mockConfig1, + server2: mockConfig2, + server3: mockConfig3, + }); + }); + + it('should reflect additions in getAll', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + let result = await cache.getAll(); + expect(Object.keys(result).length).toBe(2); + + await cache.add('server3', mockConfig3); + result = await cache.getAll(); + expect(Object.keys(result).length).toBe(3); + expect(result.server3).toMatchObject(mockConfig3); + }); + }); + + describe('update operation', () => { + it('should update an existing server config', async () => { + await cache.add('server1', mockConfig1); + expect(await cache.get('server1')).toMatchObject(mockConfig1); + + await cache.update('server1', mockConfig2); + const result = await cache.get('server1'); + expect(result).toMatchObject(mockConfig2); + }); + + it('should throw error when updating non-existent server', async () => { + await expect(cache.update('non-existent', mockConfig1)).rejects.toThrow( + 'Server "non-existent" does not exist in cache. 
Use add() to create new configs.', + ); + }); + + it('should reflect updates in getAll', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + await cache.update('server1', mockConfig3); + const result = await cache.getAll(); + expect(result.server1).toMatchObject(mockConfig3); + expect(result.server2).toMatchObject(mockConfig2); + }); + }); + + describe('remove operation', () => { + it('should remove an existing server config', async () => { + await cache.add('server1', mockConfig1); + expect(await cache.get('server1')).toMatchObject(mockConfig1); + + await cache.remove('server1'); + expect(await cache.get('server1')).toBeUndefined(); + }); + + it('should throw error when removing non-existent server', async () => { + await expect(cache.remove('non-existent')).rejects.toThrow( + 'Failed to remove server "non-existent" in cache.', + ); + }); + + it('should remove server from getAll results', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + let result = await cache.getAll(); + expect(Object.keys(result).length).toBe(2); + + await cache.remove('server1'); + result = await cache.getAll(); + expect(Object.keys(result).length).toBe(1); + expect(result.server1).toBeUndefined(); + expect(result.server2).toMatchObject(mockConfig2); + }); + + it('should allow re-adding a removed server', async () => { + await cache.add('server1', mockConfig1); + await cache.remove('server1'); + await cache.add('server1', mockConfig3); + + const result = await cache.get('server1'); + expect(result).toMatchObject(mockConfig3); + }); + }); + + describe('concurrent write safety', () => { + it('should handle concurrent add calls without data loss', async () => { + const configCount = 20; + const promises = Array.from({ length: configCount }, (_, i) => + cache.add(`server-${i}`, { + type: 'stdio', + command: `cmd-${i}`, + args: [`arg-${i}`], + } as ParsedServerConfig), + ); + + const results = 
await Promise.allSettled(promises); + const failures = results.filter((r) => r.status === 'rejected'); + expect(failures).toHaveLength(0); + + const result = await cache.getAll(); + expect(Object.keys(result).length).toBe(configCount); + for (let i = 0; i < configCount; i++) { + expect(result[`server-${i}`]).toBeDefined(); + const config = result[`server-${i}`] as { command?: string }; + expect(config.command).toBe(`cmd-${i}`); + } + }); + + it('should handle concurrent getAll calls', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + await cache.add('server3', mockConfig3); + + const concurrency = 50; + const promises = Array.from({ length: concurrency }, () => cache.getAll()); + const results = await Promise.all(promises); + + for (const result of results) { + expect(Object.keys(result).length).toBe(3); + expect(result.server1).toMatchObject(mockConfig1); + expect(result.server2).toMatchObject(mockConfig2); + expect(result.server3).toMatchObject(mockConfig3); + } + }); + }); + + describe('reset operation', () => { + it('should clear all configs', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + expect(Object.keys(await cache.getAll()).length).toBe(2); + + await cache.reset(); + + const result = await cache.getAll(); + expect(Object.keys(result).length).toBe(0); + }); + }); +}); From 8e2721011e9f13ec168fcfb8e6ccbc158c0fbae5 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Thu, 26 Mar 2026 14:45:13 -0400 Subject: [PATCH 05/18] =?UTF-8?q?=F0=9F=94=91=20fix:=20Robust=20MCP=20OAut?= =?UTF-8?q?h=20Detection=20in=20Tool-Call=20Flow=20(#12418)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(api): add buildOAuthToolCallName utility for MCP OAuth flows Extract a shared utility that builds the synthetic tool-call name used during MCP OAuth flows (oauth_mcp_{normalizedServerName}). 
Uses startsWith on the raw serverName (not the normalized form) to guard against double-wrapping, so names that merely normalize to start with oauth_mcp_ (e.g., oauth@mcp@server) are correctly prefixed while genuinely pre-wrapped names are left as-is. Add 8 unit tests covering normal names, pre-wrapped names, _mcp_ substrings, special characters, non-ASCII, and empty string inputs. * fix(backend): use buildOAuthToolCallName in MCP OAuth flows Replace inline tool-call name construction in both reconnectServer (MCP.js) and createOAuthEmitter (ToolService.js) with the shared buildOAuthToolCallName utility. Remove unused normalizeServerName import from ToolService.js. Fix import ordering in both files. This ensures the oauth_mcp_ prefix is consistently applied so the client correctly identifies MCP OAuth flows and binds the CSRF cookie to the right server. * fix(client): robust MCP OAuth detection and split handling in ToolCall - Fix split() destructuring to preserve tail segments for server names containing _mcp_ (e.g., foo_mcp_bar no longer truncated to foo). - Add auth URL redirect_uri fallback: when the tool-call name lacks the _mcp_ delimiter, parse redirect_uri for the MCP callback path. Set function_name to the extracted server name so progress text shows the server, not the raw tool-call ID. - Display server name instead of literal "oauth" as function_name, gated on auth presence to avoid misidentifying real tools named "oauth". - Consolidate three independent new URL(auth) parses into a single parsedAuthUrl useMemo shared across detection, actionId, and authDomain hooks. - Replace any type on ProgressText test mock with structural type. - Add 8 tests covering delimiter detection, multi-segment names, function_name display, redirect_uri fallback, normalized _mcp_ server names, and non-MCP action auth exclusion. 
* chore: fix import order in utils.test.ts * fix(client): drop auth gate on OAuth displayName so completed flows show server name The createOAuthEnd handler re-emits the toolCall delta without auth, so auth is cleared on the client after OAuth completes. Gating displayName on `func === 'oauth' && auth` caused completed OAuth steps to render "Completed oauth" instead of "Completed my-server". Remove the `&& auth` gate โ€” within the MCP delimiter branch the func="oauth" check alone is sufficient. Also remove `auth` from the useMemo dep array since only `parsedAuthUrl` is referenced. Update the test to assert correct post-completion display. --- api/server/services/MCP.js | 3 +- api/server/services/ToolService.js | 3 +- .../Chat/Messages/Content/ToolCall.tsx | 66 ++++---- .../Content/__tests__/ToolCall.test.tsx | 150 +++++++++++++++++- packages/api/src/mcp/__tests__/utils.test.ts | 50 +++++- packages/api/src/mcp/utils.ts | 16 ++ 6 files changed, 255 insertions(+), 33 deletions(-) diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index 5d97891c55..03563a0cfc 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -14,6 +14,7 @@ const { normalizeJsonSchema, GenerationJobManager, resolveJsonSchemaRefs, + buildOAuthToolCallName, } = require('@librechat/api'); const { Time, CacheKeys, Constants, isAssistantsEndpoint } = require('librechat-data-provider'); const { @@ -271,7 +272,7 @@ async function reconnectServer({ const stepId = 'step_oauth_login_' + serverName; const toolCall = { id: flowId, - name: serverName, + name: buildOAuthToolCallName(serverName), type: 'tool_call_chunk', }; diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index ca75e7eb4f..838de906fe 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -19,6 +19,7 @@ const { buildWebSearchContext, buildImageToolContext, buildToolClassification, + buildOAuthToolCallName, } = require('@librechat/api'); 
const { Time, @@ -521,7 +522,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to const stepId = 'step_oauth_login_' + serverName; const toolCall = { id: flowId, - name: serverName, + name: buildOAuthToolCallName(serverName), type: 'tool_call_chunk', }; diff --git a/client/src/components/Chat/Messages/Content/ToolCall.tsx b/client/src/components/Chat/Messages/Content/ToolCall.tsx index 5abdd45f98..c7dd974577 100644 --- a/client/src/components/Chat/Messages/Content/ToolCall.tsx +++ b/client/src/components/Chat/Messages/Content/ToolCall.tsx @@ -49,19 +49,47 @@ export default function ToolCall({ } }, [autoExpand, hasOutput]); + const parsedAuthUrl = useMemo(() => { + if (!auth) { + return null; + } + try { + return new URL(auth); + } catch { + return null; + } + }, [auth]); + const { function_name, domain, isMCPToolCall, mcpServerName } = useMemo(() => { if (typeof name !== 'string') { return { function_name: '', domain: null, isMCPToolCall: false, mcpServerName: '' }; } if (name.includes(Constants.mcp_delimiter)) { - const [func, server] = name.split(Constants.mcp_delimiter); + const parts = name.split(Constants.mcp_delimiter); + const func = parts[0]; + const server = parts.slice(1).join(Constants.mcp_delimiter); + const displayName = func === 'oauth' ? server : func; return { - function_name: func || '', + function_name: displayName || '', domain: server && (server.replaceAll(actionDomainSeparator, '.') || null), isMCPToolCall: true, mcpServerName: server || '', }; } + + if (parsedAuthUrl) { + const redirectUri = parsedAuthUrl.searchParams.get('redirect_uri') || ''; + const mcpMatch = redirectUri.match(/\/api\/mcp\/([^/]+)\/oauth\/callback/); + if (mcpMatch?.[1]) { + return { + function_name: mcpMatch[1], + domain: null, + isMCPToolCall: true, + mcpServerName: mcpMatch[1], + }; + } + } + const [func, _domain] = name.includes(actionDelimiter) ? 
name.split(actionDelimiter) : [name, '']; @@ -71,25 +99,20 @@ export default function ToolCall({ isMCPToolCall: false, mcpServerName: '', }; - }, [name]); + }, [name, parsedAuthUrl]); const toolIconType = useMemo(() => getToolIconType(name), [name]); const mcpIconMap = useMCPIconMap(); const mcpIconUrl = isMCPToolCall ? mcpIconMap.get(mcpServerName) : undefined; const actionId = useMemo(() => { - if (isMCPToolCall || !auth) { + if (isMCPToolCall || !parsedAuthUrl) { return ''; } - try { - const url = new URL(auth); - const redirectUri = url.searchParams.get('redirect_uri') || ''; - const match = redirectUri.match(/\/api\/actions\/([^/]+)\/oauth\/callback/); - return match?.[1] || ''; - } catch { - return ''; - } - }, [auth, isMCPToolCall]); + const redirectUri = parsedAuthUrl.searchParams.get('redirect_uri') || ''; + const match = redirectUri.match(/\/api\/actions\/([^/]+)\/oauth\/callback/); + return match?.[1] || ''; + }, [parsedAuthUrl, isMCPToolCall]); const handleOAuthClick = useCallback(async () => { if (!auth) { @@ -132,21 +155,8 @@ export default function ToolCall({ ); const authDomain = useMemo(() => { - const authURL = auth ?? ''; - if (!authURL) { - return ''; - } - try { - const url = new URL(authURL); - return url.hostname; - } catch (e) { - logger.error( - 'client/src/components/Chat/Messages/Content/ToolCall.tsx - Failed to parse auth URL', - e, - ); - return ''; - } - }, [auth]); + return parsedAuthUrl?.hostname ?? 
''; + }, [parsedAuthUrl]); const progress = useProgress(initialProgress); const showCancelled = cancelled || (errorState && !output); diff --git a/client/src/components/Chat/Messages/Content/__tests__/ToolCall.test.tsx b/client/src/components/Chat/Messages/Content/__tests__/ToolCall.test.tsx index 41356412f6..14b4b7e07a 100644 --- a/client/src/components/Chat/Messages/Content/__tests__/ToolCall.test.tsx +++ b/client/src/components/Chat/Messages/Content/__tests__/ToolCall.test.tsx @@ -1,6 +1,6 @@ import React from 'react'; import { RecoilRoot } from 'recoil'; -import { Tools } from 'librechat-data-provider'; +import { Tools, Constants } from 'librechat-data-provider'; import { render, screen, fireEvent } from '@testing-library/react'; import ToolCall from '../ToolCall'; @@ -53,9 +53,20 @@ jest.mock('../ToolCallInfo', () => ({ jest.mock('../ProgressText', () => ({ __esModule: true, - default: ({ onClick, inProgressText, finishedText, _error, _hasInput, _isExpanded }: any) => ( + default: ({ + onClick, + inProgressText, + finishedText, + subtitle, + }: { + onClick?: () => void; + inProgressText?: string; + finishedText?: string; + subtitle?: string; + }) => (

), })); @@ -346,6 +357,141 @@ describe('ToolCall', () => { }); }); + describe('MCP OAuth detection', () => { + const d = Constants.mcp_delimiter; + + it('should detect MCP OAuth from delimiter in tool-call name', () => { + renderWithRecoil( + , + ); + const subtitle = screen.getByTestId('subtitle'); + expect(subtitle.textContent).toBe('via my-server'); + }); + + it('should preserve full server name when it contains the delimiter substring', () => { + renderWithRecoil( + , + ); + const subtitle = screen.getByTestId('subtitle'); + expect(subtitle.textContent).toBe(`via foo${d}bar`); + }); + + it('should display server name (not "oauth") as function_name for OAuth tool calls', () => { + renderWithRecoil( + , + ); + const progressText = screen.getByTestId('progress-text'); + expect(progressText.textContent).toContain('Completed my-server'); + expect(progressText.textContent).not.toContain('Completed oauth'); + }); + + it('should display server name even when auth is cleared (post-completion)', () => { + // After OAuth completes, createOAuthEnd re-emits the toolCall without auth. + // The display should still show the server name, not literal "oauth". 
+ renderWithRecoil( + , + ); + const progressText = screen.getByTestId('progress-text'); + expect(progressText.textContent).toContain('Completed my-server'); + expect(progressText.textContent).not.toContain('Completed oauth'); + }); + + it('should fallback to auth URL redirect_uri when name lacks delimiter', () => { + const authUrl = + 'https://oauth.example.com/authorize?redirect_uri=' + + encodeURIComponent('https://app.example.com/api/mcp/my-server/oauth/callback'); + renderWithRecoil( + , + ); + const subtitle = screen.getByTestId('subtitle'); + expect(subtitle.textContent).toBe('via my-server'); + }); + + it('should display server name (not raw tool-call ID) in fallback path finished text', () => { + const authUrl = + 'https://oauth.example.com/authorize?redirect_uri=' + + encodeURIComponent('https://app.example.com/api/mcp/my-server/oauth/callback'); + renderWithRecoil( + , + ); + const progressText = screen.getByTestId('progress-text'); + expect(progressText.textContent).toContain('Completed my-server'); + expect(progressText.textContent).not.toContain('bare_name'); + }); + + it('should show normalized server name when it contains _mcp_ after prefixing', () => { + // Server named oauth@mcp@server normalizes to oauth_mcp_server, + // gets prefixed to oauth_mcp_oauth_mcp_server. Client parses: + // func="oauth", server="oauth_mcp_server". Visually awkward but + // semantically correct โ€” the normalized name IS oauth_mcp_server. 
+ renderWithRecoil( + , + ); + const subtitle = screen.getByTestId('subtitle'); + expect(subtitle.textContent).toBe(`via oauth${d}server`); + }); + + it('should not misidentify non-MCP action auth as MCP via fallback', () => { + const authUrl = + 'https://oauth.example.com/authorize?redirect_uri=' + + encodeURIComponent('https://app.example.com/api/actions/xyz/oauth/callback'); + renderWithRecoil( + , + ); + expect(screen.queryByTestId('subtitle')).not.toBeInTheDocument(); + }); + }); + describe('A11Y-04: screen reader status announcements', () => { it('includes sr-only aria-live region for status announcements', () => { renderWithRecoil( diff --git a/packages/api/src/mcp/__tests__/utils.test.ts b/packages/api/src/mcp/__tests__/utils.test.ts index e4fb31bdad..c244205b99 100644 --- a/packages/api/src/mcp/__tests__/utils.test.ts +++ b/packages/api/src/mcp/__tests__/utils.test.ts @@ -1,4 +1,9 @@ -import { normalizeServerName, redactServerSecrets, redactAllServerSecrets } from '~/mcp/utils'; +import { + buildOAuthToolCallName, + normalizeServerName, + redactAllServerSecrets, + redactServerSecrets, +} from '~/mcp/utils'; import type { ParsedServerConfig } from '~/mcp/types'; describe('normalizeServerName', () => { @@ -28,6 +33,49 @@ describe('normalizeServerName', () => { }); }); +describe('buildOAuthToolCallName', () => { + it('should prefix a simple server name with oauth_mcp_', () => { + expect(buildOAuthToolCallName('my-server')).toBe('oauth_mcp_my-server'); + }); + + it('should not double-wrap a name that already starts with oauth_mcp_', () => { + expect(buildOAuthToolCallName('oauth_mcp_my-server')).toBe('oauth_mcp_my-server'); + }); + + it('should correctly handle server names containing _mcp_ substring', () => { + const result = buildOAuthToolCallName('my_mcp_server'); + expect(result).toBe('oauth_mcp_my_mcp_server'); + }); + + it('should normalize non-ASCII server names before prefixing', () => { + const result = buildOAuthToolCallName('ๆˆ‘็š„ๆœๅŠก'); + 
expect(result).toMatch(/^oauth_mcp_server_\d+$/); + }); + + it('should normalize special characters before prefixing', () => { + expect(buildOAuthToolCallName('server@name!')).toBe('oauth_mcp_server_name'); + }); + + it('should handle empty string server name gracefully', () => { + const result = buildOAuthToolCallName(''); + expect(result).toMatch(/^oauth_mcp_server_\d+$/); + }); + + it('should treat a name already starting with oauth_mcp_ as pre-wrapped', () => { + // At the function level, a name starting with the oauth prefix is + // indistinguishable from a pre-wrapped name โ€” guard prevents double-wrapping. + // Server names with this prefix should be blocked at registration time. + expect(buildOAuthToolCallName('oauth_mcp_github')).toBe('oauth_mcp_github'); + }); + + it('should not treat special chars that normalize to oauth_mcp_* as pre-wrapped', () => { + // oauth@mcp@server does NOT start with 'oauth_mcp_' before normalization, + // so the guard correctly does not fire and the prefix is added. + const result = buildOAuthToolCallName('oauth@mcp@server'); + expect(result).toBe('oauth_mcp_oauth_mcp_server'); + }); +}); + describe('redactServerSecrets', () => { it('should strip apiKey.key from admin-sourced keys', () => { const config: ParsedServerConfig = { diff --git a/packages/api/src/mcp/utils.ts b/packages/api/src/mcp/utils.ts index ff367725fc..db89cffada 100644 --- a/packages/api/src/mcp/utils.ts +++ b/packages/api/src/mcp/utils.ts @@ -97,6 +97,22 @@ export function normalizeServerName(serverName: string): string { return normalized; } +/** + * Builds the synthetic tool-call name used during MCP OAuth flows. + * Format: `oauth` + * + * Guards against the caller passing a pre-wrapped name (one that already + * starts with the oauth prefix in its original, un-normalized form) to + * prevent double-wrapping. 
+ */ +export function buildOAuthToolCallName(serverName: string): string { + const oauthPrefix = `oauth${Constants.mcp_delimiter}`; + if (serverName.startsWith(oauthPrefix)) { + return normalizeServerName(serverName); + } + return `${oauthPrefix}${normalizeServerName(serverName)}`; +} + /** * Sanitizes a URL by removing query parameters to prevent credential leakage in logs. * @param url - The URL to sanitize (string or URL object) From 5e3b7bcde3c01427357f15498b5d7094879799dd Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Thu, 26 Mar 2026 16:39:09 -0400 Subject: [PATCH 06/18] =?UTF-8?q?=F0=9F=8C=8A=20refactor:=20Local=20Snapsh?= =?UTF-8?q?ot=20for=20Aggregate=20Key=20Cache=20to=20Avoid=20Redundant=20R?= =?UTF-8?q?edis=20GETs=20(#12422)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * perf: Add local snapshot to aggregate key cache to avoid redundant Redis GETs getAll() was being called 20+ times per chat request (once per tool, per server config lookup, per connection check). Each call hit Redis even though the data doesn't change within a request cycle. Add an in-memory snapshot with 5s TTL that collapses all reads within the window into a single Redis GET. Writes (add/update/remove/reset) invalidate the snapshot immediately so mutations are never stale. Also removes the debug logger that was producing noisy per-call logs. * fix: Prevent snapshot mutation and guarantee cleanup on write failure - Never mutate the snapshot object in-place during writes. Build a new object (spread) so concurrent readers never observe uncommitted state. - Move invalidateLocalSnapshot() into withWriteLock's finally block so cleanup is guaranteed even when successCheck throws on Redis failure. - After successful writes, populate the snapshot with the committed state to avoid an unnecessary Redis GET on the next read. - Use Date.now() after the await in getAll() so the TTL window isn't shortened by Redis latency. 
- Strengthen tests: spy on underlying Keyv cache to verify N getAll() calls collapse into 1 Redis GET, verify snapshot reference immutability. * fix: Remove dead populateLocalSnapshot calls from write callbacks populateLocalSnapshot was called inside withWriteLock callbacks, but the finally block in withWriteLock always calls invalidateLocalSnapshot immediately after โ€” undoing the populate on every execution path. Remove the dead method and its three call sites. The snapshot is correctly cleared by finally on both success and failure paths. The next getAll() after a write hits Redis once to fetch the committed state, which is acceptable since writes only occur during init and rare manual reinspection. * fix: Derive local snapshot TTL from MCP_REGISTRY_CACHE_TTL config Use cacheConfig.MCP_REGISTRY_CACHE_TTL (default 5000ms) instead of a hardcoded 5s constant. When TTL is 0 (operator explicitly wants no caching), the snapshot is disabled entirely โ€” every getAll() hits Redis. * fix: Add TTL expiry test, document 2ร—TTL staleness, clarify comments - Add missing test for snapshot TTL expiry path (force-expire via localSnapshotExpiry mutation, verify Redis is hit again) - Document 2ร—TTL max cross-instance staleness in localSnapshot JSDoc - Document reset() intentionally bypasses withWriteLock - Add inline comments explaining why early invalidateLocalSnapshot() in write callbacks is distinct from the finally-block cleanup - Update cacheConfig.MCP_REGISTRY_CACHE_TTL JSDoc to reflect both use sites and the staleness implication - Rename misleading test name for snapshot reference immutability - Add epoch sentinel comment on localSnapshotExpiry initialization --- packages/api/src/cache/cacheConfig.ts | 9 +- .../ServerConfigsCacheRedisAggregateKey.ts | 77 ++++++++++++---- ...edisAggregateKey.cache_integration.spec.ts | 92 +++++++++++++++++++ 3 files changed, 159 insertions(+), 19 deletions(-) diff --git a/packages/api/src/cache/cacheConfig.ts 
b/packages/api/src/cache/cacheConfig.ts index 0d4304f5c3..7b4a899e98 100644 --- a/packages/api/src/cache/cacheConfig.ts +++ b/packages/api/src/cache/cacheConfig.ts @@ -128,8 +128,13 @@ const cacheConfig = { REDIS_SCAN_COUNT: math(process.env.REDIS_SCAN_COUNT, 1000), /** - * TTL in milliseconds for MCP registry read-through cache. - * This cache reduces redundant lookups within a single request flow. + * TTL in milliseconds for MCP registry caches. Used by both: + * - `MCPServersRegistry` read-through caches (`readThroughCache`/`readThroughCacheAll`) + * - `ServerConfigsCacheRedisAggregateKey` local snapshot (avoids redundant Redis GETs) + * + * Both layers use this value, so the effective max cross-instance staleness is up + * to 2ร— this value in multi-instance deployments. Set to 0 to disable the local + * snapshot entirely (every `getAll()` hits Redis directly). * @default 5000 (5 seconds) */ MCP_REGISTRY_CACHE_TTL: math(process.env.MCP_REGISTRY_CACHE_TTL, 5000), diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts index 12f423a1fb..e67c1a4a84 100644 --- a/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts @@ -1,9 +1,8 @@ -import { logger } from '@librechat/data-schemas'; import type Keyv from 'keyv'; import type { IServerConfigsRepositoryInterface } from '~/mcp/registry/ServerConfigsRepositoryInterface'; import type { ParsedServerConfig, AddServerResult } from '~/mcp/types'; import { BaseRegistryCache } from './BaseRegistryCache'; -import { standardCache } from '~/cache'; +import { cacheConfig, standardCache } from '~/cache'; /** * Redis-backed MCP server configs cache that stores all entries under a single aggregate key. 
@@ -37,14 +36,38 @@ export class ServerConfigsCacheRedisAggregateKey protected readonly cache: Keyv; private writeLock: Promise = Promise.resolve(); + /** + * In-memory snapshot of the aggregate key to avoid redundant Redis GETs. + * `getAll()` is called 20+ times per chat request (once per tool, per server + * config lookup, per connection check) but the data doesn't change within a + * request cycle. The snapshot collapses all reads within the TTL window into + * a single Redis GET. Invalidated on every write (`add`, `update`, `remove`, `reset`). + * + * NOTE: In multi-instance deployments, the effective max staleness for cross-instance + * writes is up to 2ร—MCP_REGISTRY_CACHE_TTL. This happens when readThroughCacheAll + * (MCPServersRegistry) is populated from a snapshot that is nearly expired. For the + * default 5000ms TTL, worst-case cross-instance propagation is ~10s. This is acceptable + * given the single-writer invariant (leader-only initialization, rare manual reinspection). + */ + private localSnapshot: Record | null = null; + /** Milliseconds since epoch. 0 = epoch = always expired on first check. */ + private localSnapshotExpiry = 0; + constructor(namespace: string, leaderOnly: boolean) { super(leaderOnly); this.cache = standardCache(`${this.PREFIX}::Servers::${namespace}`); } + private invalidateLocalSnapshot(): void { + this.localSnapshot = null; + this.localSnapshotExpiry = 0; + } + /** * Serializes write operations to prevent concurrent read-modify-write races. * Reads (`get`, `getAll`) are not serialized โ€” they can run concurrently. + * Always invalidates the local snapshot in `finally` to guarantee cleanup + * even when the write callback throws (e.g., Redis SET failure). 
*/ private async withWriteLock(fn: () => Promise): Promise { const previousLock = this.writeLock; @@ -56,20 +79,29 @@ export class ServerConfigsCacheRedisAggregateKey await previousLock; return await fn(); } finally { + this.invalidateLocalSnapshot(); resolve(); } } public async getAll(): Promise> { - const startTime = Date.now(); - const result = (await this.cache.get(AGGREGATE_KEY)) as - | Record - | undefined; - const elapsed = Date.now() - startTime; - logger.debug( - `[ServerConfigsCacheRedisAggregateKey] getAll: fetched ${result ? Object.keys(result).length : 0} configs in ${elapsed}ms`, - ); - return result ?? {}; + const ttl = cacheConfig.MCP_REGISTRY_CACHE_TTL; + if (ttl > 0) { + const now = Date.now(); + if (this.localSnapshot !== null && now < this.localSnapshotExpiry) { + return this.localSnapshot; + } + } + + const result = + ((await this.cache.get(AGGREGATE_KEY)) as Record | undefined) ?? + {}; + + if (ttl > 0) { + this.localSnapshot = result; + this.localSnapshotExpiry = Date.now() + ttl; + } + return result; } public async get(serverName: string): Promise { @@ -80,6 +112,10 @@ export class ServerConfigsCacheRedisAggregateKey public async add(serverName: string, config: ParsedServerConfig): Promise { if (this.leaderOnly) await this.leaderCheck('add MCP servers'); return this.withWriteLock(async () => { + // Force fresh Redis read so the read-modify-write uses current data, + // not a snapshot that may predate this write. Distinct from the finally-block + // invalidation which cleans up after the write completes or throws. 
+ this.invalidateLocalSnapshot(); const all = await this.getAll(); if (all[serverName]) { throw new Error( @@ -87,8 +123,8 @@ export class ServerConfigsCacheRedisAggregateKey ); } const storedConfig = { ...config, updatedAt: Date.now() }; - all[serverName] = storedConfig; - const success = await this.cache.set(AGGREGATE_KEY, all); + const newAll = { ...all, [serverName]: storedConfig }; + const success = await this.cache.set(AGGREGATE_KEY, newAll); this.successCheck(`add App server "${serverName}"`, success); return { serverName, config: storedConfig }; }); @@ -97,14 +133,15 @@ export class ServerConfigsCacheRedisAggregateKey public async update(serverName: string, config: ParsedServerConfig): Promise { if (this.leaderOnly) await this.leaderCheck('update MCP servers'); return this.withWriteLock(async () => { + this.invalidateLocalSnapshot(); // Force fresh Redis read (see add() comment) const all = await this.getAll(); if (!all[serverName]) { throw new Error( `Server "${serverName}" does not exist in cache. 
Use add() to create new configs.`, ); } - all[serverName] = { ...config, updatedAt: Date.now() }; - const success = await this.cache.set(AGGREGATE_KEY, all); + const newAll = { ...all, [serverName]: { ...config, updatedAt: Date.now() } }; + const success = await this.cache.set(AGGREGATE_KEY, newAll); this.successCheck(`update App server "${serverName}"`, success); }); } @@ -112,12 +149,13 @@ export class ServerConfigsCacheRedisAggregateKey public async remove(serverName: string): Promise { if (this.leaderOnly) await this.leaderCheck('remove MCP servers'); return this.withWriteLock(async () => { + this.invalidateLocalSnapshot(); // Force fresh Redis read (see add() comment) const all = await this.getAll(); if (!all[serverName]) { throw new Error(`Failed to remove server "${serverName}" in cache.`); } - delete all[serverName]; - const success = await this.cache.set(AGGREGATE_KEY, all); + const { [serverName]: _, ...newAll } = all; + const success = await this.cache.set(AGGREGATE_KEY, newAll); this.successCheck(`remove App server "${serverName}"`, success); }); } @@ -126,11 +164,16 @@ export class ServerConfigsCacheRedisAggregateKey * Resets the aggregate key directly instead of using SCAN-based `cache.clear()`. * Only one key (`__all__`) ever exists in this namespace, so a targeted delete is * more efficient and consistent with the PR's goal of eliminating SCAN operations. + * + * NOTE: Intentionally not serialized via `withWriteLock`. `reset()` is only called + * during lifecycle transitions (test teardown, full reinitialization via + * `MCPServersInitializer`) where no concurrent writes are in flight. 
*/ public override async reset(): Promise { if (this.leaderOnly) { await this.leaderCheck('reset App MCP servers cache'); } await this.cache.delete(AGGREGATE_KEY); + this.invalidateLocalSnapshot(); } } diff --git a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedisAggregateKey.cache_integration.spec.ts b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedisAggregateKey.cache_integration.spec.ts index cbb75609d1..5aeb49b206 100644 --- a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedisAggregateKey.cache_integration.spec.ts +++ b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedisAggregateKey.cache_integration.spec.ts @@ -243,4 +243,96 @@ describe('ServerConfigsCacheRedisAggregateKey Integration Tests', () => { expect(Object.keys(result).length).toBe(0); }); }); + + describe('local snapshot behavior', () => { + it('should collapse repeated getAll calls into a single Redis GET within TTL', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + // Prime the snapshot + await cache.getAll(); + + // Spy on the underlying Keyv cache to count Redis calls + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const cacheGetSpy = jest.spyOn((cache as any).cache, 'get'); + + await cache.getAll(); + await cache.getAll(); + await cache.getAll(); + + // Snapshot should be served; Redis should NOT have been called + expect(cacheGetSpy).not.toHaveBeenCalled(); + cacheGetSpy.mockRestore(); + }); + + it('should invalidate snapshot after add', async () => { + await cache.add('server1', mockConfig1); + const before = await cache.getAll(); + expect(Object.keys(before).length).toBe(1); + + await cache.add('server2', mockConfig2); + const after = await cache.getAll(); + expect(Object.keys(after).length).toBe(2); + }); + + it('should invalidate snapshot after update and preserve other entries', async () => { + await cache.add('server1', mockConfig1); + await 
cache.add('server2', mockConfig2); + expect((await cache.getAll()).server1).toMatchObject(mockConfig1); + + await cache.update('server1', mockConfig3); + const after = await cache.getAll(); + expect(after.server1).toMatchObject(mockConfig3); + expect(after.server2).toMatchObject(mockConfig2); + }); + + it('should invalidate snapshot after remove', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + expect(Object.keys(await cache.getAll()).length).toBe(2); + + await cache.remove('server1'); + const after = await cache.getAll(); + expect(Object.keys(after).length).toBe(1); + expect(after.server1).toBeUndefined(); + expect(after.server2).toMatchObject(mockConfig2); + }); + + it('should invalidate snapshot after reset', async () => { + await cache.add('server1', mockConfig1); + expect(Object.keys(await cache.getAll()).length).toBe(1); + + await cache.reset(); + expect(Object.keys(await cache.getAll()).length).toBe(0); + }); + + it('should not retroactively modify previously returned snapshot references', async () => { + await cache.add('server1', mockConfig1); + + // Prime the snapshot + const snapshot = await cache.getAll(); + expect(Object.keys(snapshot).length).toBe(1); + + // Add a second server โ€” the original snapshot reference should be unmodified + await cache.add('server2', mockConfig2); + expect(Object.keys(snapshot).length).toBe(1); + expect(snapshot.server2).toBeUndefined(); + }); + + it('should hit Redis again after snapshot TTL expires', async () => { + await cache.add('server1', mockConfig1); + await cache.getAll(); // prime snapshot + + // Force-expire the snapshot without sleeping + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (cache as any).localSnapshotExpiry = Date.now() - 1; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const cacheGetSpy = jest.spyOn((cache as any).cache, 'get'); + const result = await cache.getAll(); + 
expect(cacheGetSpy).toHaveBeenCalledTimes(1); + expect(Object.keys(result).length).toBe(1); + cacheGetSpy.mockRestore(); + }); + }); }); From 083042e56cdd4b99fbda254b2e253be559a109be Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Thu, 26 Mar 2026 16:40:37 -0400 Subject: [PATCH 07/18] =?UTF-8?q?=F0=9F=AA=9D=20fix:=20Safe=20Hook=20Fallb?= =?UTF-8?q?acks=20for=20Tool-Call=20Components=20in=20Search=20Route=20(#1?= =?UTF-8?q?2423)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: add useOptionalMessagesOperations hook for context-safe message operations Add a variant of useMessagesOperations that returns no-op functions when MessagesViewProvider is absent instead of throwing, enabling shared components to render safely outside the chat route. * fix: use optional message operations in ToolCallInfo and UIResourceCarousel Switch ToolCallInfo and UIResourceCarousel from useMessagesOperations to useOptionalMessagesOperations so they no longer crash when rendered in the /search route, which lacks MessagesViewProvider. * fix: update test mocks to use useOptionalMessagesOperations * fix: consolidate noops and narrow useMemo dependency in useOptionalMessagesOperations - Replace three noop variants (noopAsync, noopReturn, noop) with a single `const noop = () => undefined` that correctly returns void/undefined - Destructure individual fields from context before the useMemo so the dependency array tracks stable operation references, not the full context object (avoiding unnecessary re-renders on unrelated state changes) - Add useOptionalMessagesConversation for components that need conversation data outside MessagesViewProvider * fix: use optional hooks in MCPUIResource components to prevent search crash MCPUIResource and MCPUIResourceCarousel render inside Markdown prose and can appear in the /search route. 
Switch them from the strict useMessagesOperations/useMessagesConversation hooks to the optional variants that return safe defaults when MessagesViewProvider is absent. * test: update test mocks for optional hook renames * fix: update ToolCallInfo and UIResourceCarousel test mocks to useOptionalMessagesOperations * fix: use optional message operations in useConversationUIResources useConversationUIResources internally called the strict useMessagesOperations(), which still threw when MCPUIResource rendered outside MessagesViewProvider. Switch to useOptionalMessagesOperations so the entire MCPUIResource render chain is safe in the /search route. * style: fix import order per project conventions * fix: replace as-unknown-as casts with typed NOOP_OPS stubs - Define OptionalMessagesOps type and NOOP_OPS constant with properly typed no-op functions, eliminating all `as unknown as T` casts - Use conversationId directly from useOptionalMessagesConversation instead of re-deriving it from conversation object - Update JSDoc to reflect search route support * test: add no-provider regression tests for optional message hooks Verify useOptionalMessagesOperations and useOptionalMessagesConversation return safe defaults when rendered outside MessagesViewProvider, covering the core crash path this PR fixes. 
--- client/src/Providers/MessagesViewContext.tsx | 49 +++++++++++++++++ .../__tests__/MessagesViewContext.spec.tsx | 53 +++++++++++++++++++ .../Chat/Messages/Content/ToolCallInfo.tsx | 6 +-- .../Messages/Content/UIResourceCarousel.tsx | 4 +- .../Content/__tests__/Markdown.mcpui.test.tsx | 18 ++++--- .../Content/__tests__/ToolCallInfo.test.tsx | 2 +- .../__tests__/UIResourceCarousel.test.tsx | 4 +- .../MCPUIResource/MCPUIResource.tsx | 17 +++--- .../MCPUIResource/MCPUIResourceCarousel.tsx | 17 +++--- .../__tests__/MCPUIResource.test.tsx | 14 +++-- .../__tests__/MCPUIResourceCarousel.test.tsx | 14 +++-- .../Messages/useConversationUIResources.ts | 4 +- 12 files changed, 153 insertions(+), 49 deletions(-) create mode 100644 client/src/Providers/__tests__/MessagesViewContext.spec.tsx diff --git a/client/src/Providers/MessagesViewContext.tsx b/client/src/Providers/MessagesViewContext.tsx index f1cae204a4..c44972918c 100644 --- a/client/src/Providers/MessagesViewContext.tsx +++ b/client/src/Providers/MessagesViewContext.tsx @@ -140,6 +140,55 @@ export function useMessagesOperations() { ); } +type OptionalMessagesOps = Pick< + MessagesViewContextValue, + 'ask' | 'regenerate' | 'handleContinue' | 'getMessages' | 'setMessages' +>; + +const NOOP_OPS: OptionalMessagesOps = { + ask: () => {}, + regenerate: () => {}, + handleContinue: () => {}, + getMessages: () => undefined, + setMessages: () => {}, +}; + +/** + * Hook for components that need message operations but may render outside MessagesViewProvider + * (e.g. the /search route). Returns no-op stubs when the provider is absent โ€” UI actions will + * be silently discarded rather than crashing. Callers must use optional chaining on + * `getMessages()` results, as it returns `undefined` outside the provider. 
+ */ +export function useOptionalMessagesOperations(): OptionalMessagesOps { + const context = useContext(MessagesViewContext); + const ask = context?.ask; + const regenerate = context?.regenerate; + const handleContinue = context?.handleContinue; + const getMessages = context?.getMessages; + const setMessages = context?.setMessages; + return useMemo( + () => ({ + ask: ask ?? NOOP_OPS.ask, + regenerate: regenerate ?? NOOP_OPS.regenerate, + handleContinue: handleContinue ?? NOOP_OPS.handleContinue, + getMessages: getMessages ?? NOOP_OPS.getMessages, + setMessages: setMessages ?? NOOP_OPS.setMessages, + }), + [ask, regenerate, handleContinue, getMessages, setMessages], + ); +} + +/** + * Hook for components that need conversation data but may render outside MessagesViewProvider + * (e.g. the /search route). Returns `undefined` for both fields when the provider is absent. + */ +export function useOptionalMessagesConversation() { + const context = useContext(MessagesViewContext); + const conversation = context?.conversation; + const conversationId = context?.conversationId; + return useMemo(() => ({ conversation, conversationId }), [conversation, conversationId]); +} + /** Hook for components that only need message state */ export function useMessagesState() { const { index, latestMessageId, latestMessageDepth, setLatestMessage } = useMessagesViewContext(); diff --git a/client/src/Providers/__tests__/MessagesViewContext.spec.tsx b/client/src/Providers/__tests__/MessagesViewContext.spec.tsx new file mode 100644 index 0000000000..88cd6f702d --- /dev/null +++ b/client/src/Providers/__tests__/MessagesViewContext.spec.tsx @@ -0,0 +1,53 @@ +import { renderHook } from '@testing-library/react'; +import { + useOptionalMessagesOperations, + useOptionalMessagesConversation, +} from '../MessagesViewContext'; + +describe('useOptionalMessagesOperations', () => { + it('returns noop stubs when rendered outside MessagesViewProvider', () => { + const { result } = renderHook(() => 
useOptionalMessagesOperations()); + + expect(result.current.ask).toBeInstanceOf(Function); + expect(result.current.regenerate).toBeInstanceOf(Function); + expect(result.current.handleContinue).toBeInstanceOf(Function); + expect(result.current.getMessages).toBeInstanceOf(Function); + expect(result.current.setMessages).toBeInstanceOf(Function); + }); + + it('noop stubs do not throw when called', () => { + const { result } = renderHook(() => useOptionalMessagesOperations()); + + expect(() => result.current.ask({} as never)).not.toThrow(); + expect(() => result.current.regenerate({} as never)).not.toThrow(); + expect(() => result.current.handleContinue({} as never)).not.toThrow(); + expect(() => result.current.setMessages([])).not.toThrow(); + }); + + it('getMessages returns undefined outside the provider', () => { + const { result } = renderHook(() => useOptionalMessagesOperations()); + expect(result.current.getMessages()).toBeUndefined(); + }); + + it('returns stable references across re-renders', () => { + const { result, rerender } = renderHook(() => useOptionalMessagesOperations()); + const first = result.current; + rerender(); + expect(result.current).toBe(first); + }); +}); + +describe('useOptionalMessagesConversation', () => { + it('returns undefined fields when rendered outside MessagesViewProvider', () => { + const { result } = renderHook(() => useOptionalMessagesConversation()); + expect(result.current.conversation).toBeUndefined(); + expect(result.current.conversationId).toBeUndefined(); + }); + + it('returns stable references across re-renders', () => { + const { result, rerender } = renderHook(() => useOptionalMessagesConversation()); + const first = result.current; + rerender(); + expect(result.current).toBe(first); + }); +}); diff --git a/client/src/components/Chat/Messages/Content/ToolCallInfo.tsx b/client/src/components/Chat/Messages/Content/ToolCallInfo.tsx index 59a564be4d..79ac78dbb2 100644 --- 
a/client/src/components/Chat/Messages/Content/ToolCallInfo.tsx +++ b/client/src/components/Chat/Messages/Content/ToolCallInfo.tsx @@ -3,11 +3,11 @@ import { ChevronDown } from 'lucide-react'; import { Tools } from 'librechat-data-provider'; import { UIResourceRenderer } from '@mcp-ui/client'; import type { TAttachment, UIResource } from 'librechat-data-provider'; +import { useOptionalMessagesOperations } from '~/Providers'; import { useLocalize, useExpandCollapse } from '~/hooks'; import UIResourceCarousel from './UIResourceCarousel'; -import { useMessagesOperations } from '~/Providers'; -import { OutputRenderer } from './ToolOutput'; import { handleUIAction, cn } from '~/utils'; +import { OutputRenderer } from './ToolOutput'; function isSimpleObject(obj: unknown): obj is Record { if (typeof obj !== 'object' || obj === null || Array.isArray(obj)) { @@ -102,7 +102,7 @@ export default function ToolCallInfo({ attachments?: TAttachment[]; }) { const localize = useLocalize(); - const { ask } = useMessagesOperations(); + const { ask } = useOptionalMessagesOperations(); const [showParams, setShowParams] = useState(false); const { style: paramsExpandStyle, ref: paramsExpandRef } = useExpandCollapse(showParams); diff --git a/client/src/components/Chat/Messages/Content/UIResourceCarousel.tsx b/client/src/components/Chat/Messages/Content/UIResourceCarousel.tsx index c0829e5ad9..4cafa643c6 100644 --- a/client/src/components/Chat/Messages/Content/UIResourceCarousel.tsx +++ b/client/src/components/Chat/Messages/Content/UIResourceCarousel.tsx @@ -1,7 +1,7 @@ import React, { useState } from 'react'; import { UIResourceRenderer } from '@mcp-ui/client'; import type { UIResource } from 'librechat-data-provider'; -import { useMessagesOperations } from '~/Providers'; +import { useOptionalMessagesOperations } from '~/Providers'; import { handleUIAction } from '~/utils'; interface UIResourceCarouselProps { @@ -13,7 +13,7 @@ const UIResourceCarousel: React.FC = React.memo(({ uiRe const 
[showRightArrow, setShowRightArrow] = useState(true); const [isContainerHovered, setIsContainerHovered] = useState(false); const scrollContainerRef = React.useRef(null); - const { ask } = useMessagesOperations(); + const { ask } = useOptionalMessagesOperations(); const handleScroll = React.useCallback(() => { if (!scrollContainerRef.current) return; diff --git a/client/src/components/Chat/Messages/Content/__tests__/Markdown.mcpui.test.tsx b/client/src/components/Chat/Messages/Content/__tests__/Markdown.mcpui.test.tsx index 6df66c9e15..6ca06056fa 100644 --- a/client/src/components/Chat/Messages/Content/__tests__/Markdown.mcpui.test.tsx +++ b/client/src/components/Chat/Messages/Content/__tests__/Markdown.mcpui.test.tsx @@ -3,7 +3,11 @@ import { render, screen } from '@testing-library/react'; import Markdown from '../Markdown'; import { RecoilRoot } from 'recoil'; import { UI_RESOURCE_MARKER } from '~/components/MCPUIResource/plugin'; -import { useMessageContext, useMessagesConversation, useMessagesOperations } from '~/Providers'; +import { + useMessageContext, + useOptionalMessagesConversation, + useOptionalMessagesOperations, +} from '~/Providers'; import { useGetMessagesByConvoId } from '~/data-provider'; import { useLocalize } from '~/hooks'; @@ -12,8 +16,8 @@ import { useLocalize } from '~/hooks'; jest.mock('~/Providers', () => ({ ...jest.requireActual('~/Providers'), useMessageContext: jest.fn(), - useMessagesConversation: jest.fn(), - useMessagesOperations: jest.fn(), + useOptionalMessagesConversation: jest.fn(), + useOptionalMessagesOperations: jest.fn(), })); jest.mock('~/data-provider'); jest.mock('~/hooks'); @@ -26,11 +30,11 @@ jest.mock('@mcp-ui/client', () => ({ })); const mockUseMessageContext = useMessageContext as jest.MockedFunction; -const mockUseMessagesConversation = useMessagesConversation as jest.MockedFunction< - typeof useMessagesConversation +const mockUseMessagesConversation = useOptionalMessagesConversation as jest.MockedFunction< + typeof 
useOptionalMessagesConversation >; -const mockUseMessagesOperations = useMessagesOperations as jest.MockedFunction< - typeof useMessagesOperations +const mockUseMessagesOperations = useOptionalMessagesOperations as jest.MockedFunction< + typeof useOptionalMessagesOperations >; const mockUseGetMessagesByConvoId = useGetMessagesByConvoId as jest.MockedFunction< typeof useGetMessagesByConvoId diff --git a/client/src/components/Chat/Messages/Content/__tests__/ToolCallInfo.test.tsx b/client/src/components/Chat/Messages/Content/__tests__/ToolCallInfo.test.tsx index 4a4d80ae8d..38b792ccae 100644 --- a/client/src/components/Chat/Messages/Content/__tests__/ToolCallInfo.test.tsx +++ b/client/src/components/Chat/Messages/Content/__tests__/ToolCallInfo.test.tsx @@ -25,7 +25,7 @@ jest.mock('~/hooks', () => ({ })); jest.mock('~/Providers', () => ({ - useMessagesOperations: () => ({ + useOptionalMessagesOperations: () => ({ ask: jest.fn(), }), })); diff --git a/client/src/components/Chat/Messages/Content/__tests__/UIResourceCarousel.test.tsx b/client/src/components/Chat/Messages/Content/__tests__/UIResourceCarousel.test.tsx index 6d208c2cf2..6e472e3f49 100644 --- a/client/src/components/Chat/Messages/Content/__tests__/UIResourceCarousel.test.tsx +++ b/client/src/components/Chat/Messages/Content/__tests__/UIResourceCarousel.test.tsx @@ -13,10 +13,10 @@ jest.mock('@mcp-ui/client', () => ({ ), })); -// Mock useMessagesOperations hook +// Mock useOptionalMessagesOperations hook const mockAsk = jest.fn(); jest.mock('~/Providers', () => ({ - useMessagesOperations: () => ({ + useOptionalMessagesOperations: () => ({ ask: mockAsk, }), })); diff --git a/client/src/components/MCPUIResource/MCPUIResource.tsx b/client/src/components/MCPUIResource/MCPUIResource.tsx index ddf65c4388..692db889c9 100644 --- a/client/src/components/MCPUIResource/MCPUIResource.tsx +++ b/client/src/components/MCPUIResource/MCPUIResource.tsx @@ -1,8 +1,8 @@ import React from 'react'; import { UIResourceRenderer } 
from '@mcp-ui/client'; -import { handleUIAction } from '~/utils'; +import { useOptionalMessagesConversation, useOptionalMessagesOperations } from '~/Providers'; import { useConversationUIResources } from '~/hooks/Messages/useConversationUIResources'; -import { useMessagesConversation, useMessagesOperations } from '~/Providers'; +import { handleUIAction } from '~/utils'; import { useLocalize } from '~/hooks'; interface MCPUIResourceProps { @@ -13,19 +13,14 @@ interface MCPUIResourceProps { }; } -/** - * Component that renders an MCP UI resource based on its resource ID. - * Works in both main app and share view. - */ +/** Renders an MCP UI resource based on its resource ID. Works in chat, share, and search views. */ export function MCPUIResource(props: MCPUIResourceProps) { const { resourceId } = props.node.properties; const localize = useLocalize(); - const { ask } = useMessagesOperations(); - const { conversation } = useMessagesConversation(); + const { ask } = useOptionalMessagesOperations(); + const { conversationId } = useOptionalMessagesConversation(); - const conversationResourceMap = useConversationUIResources( - conversation?.conversationId ?? undefined, - ); + const conversationResourceMap = useConversationUIResources(conversationId ?? undefined); const uiResource = conversationResourceMap.get(resourceId ?? 
''); diff --git a/client/src/components/MCPUIResource/MCPUIResourceCarousel.tsx b/client/src/components/MCPUIResource/MCPUIResourceCarousel.tsx index cf32318491..ba81a2f153 100644 --- a/client/src/components/MCPUIResource/MCPUIResourceCarousel.tsx +++ b/client/src/components/MCPUIResource/MCPUIResourceCarousel.tsx @@ -1,8 +1,8 @@ import React, { useMemo } from 'react'; -import { useConversationUIResources } from '~/hooks/Messages/useConversationUIResources'; -import { useMessagesConversation } from '~/Providers'; -import UIResourceCarousel from '../Chat/Messages/Content/UIResourceCarousel'; import type { UIResource } from 'librechat-data-provider'; +import { useConversationUIResources } from '~/hooks/Messages/useConversationUIResources'; +import UIResourceCarousel from '../Chat/Messages/Content/UIResourceCarousel'; +import { useOptionalMessagesConversation } from '~/Providers'; interface MCPUIResourceCarouselProps { node: { @@ -12,16 +12,11 @@ interface MCPUIResourceCarouselProps { }; } -/** - * Component that renders multiple MCP UI resources in a carousel. - * Works in both main app and share view. - */ +/** Renders multiple MCP UI resources in a carousel. Works in chat, share, and search views. */ export function MCPUIResourceCarousel(props: MCPUIResourceCarouselProps) { - const { conversation } = useMessagesConversation(); + const { conversationId } = useOptionalMessagesConversation(); - const conversationResourceMap = useConversationUIResources( - conversation?.conversationId ?? undefined, - ); + const conversationResourceMap = useConversationUIResources(conversationId ?? 
undefined); const uiResources = useMemo(() => { const { resourceIds = [] } = props.node.properties; diff --git a/client/src/components/MCPUIResource/__tests__/MCPUIResource.test.tsx b/client/src/components/MCPUIResource/__tests__/MCPUIResource.test.tsx index 53896bb6fe..c37b6d5d51 100644 --- a/client/src/components/MCPUIResource/__tests__/MCPUIResource.test.tsx +++ b/client/src/components/MCPUIResource/__tests__/MCPUIResource.test.tsx @@ -2,7 +2,11 @@ import React from 'react'; import { render, screen } from '@testing-library/react'; import { RecoilRoot } from 'recoil'; import { MCPUIResource } from '../MCPUIResource'; -import { useMessageContext, useMessagesConversation, useMessagesOperations } from '~/Providers'; +import { + useMessageContext, + useOptionalMessagesConversation, + useOptionalMessagesOperations, +} from '~/Providers'; import { useLocalize } from '~/hooks'; import { handleUIAction } from '~/utils'; @@ -22,11 +26,11 @@ jest.mock('@mcp-ui/client', () => ({ })); const mockUseMessageContext = useMessageContext as jest.MockedFunction; -const mockUseMessagesConversation = useMessagesConversation as jest.MockedFunction< - typeof useMessagesConversation +const mockUseMessagesConversation = useOptionalMessagesConversation as jest.MockedFunction< + typeof useOptionalMessagesConversation >; -const mockUseMessagesOperations = useMessagesOperations as jest.MockedFunction< - typeof useMessagesOperations +const mockUseMessagesOperations = useOptionalMessagesOperations as jest.MockedFunction< + typeof useOptionalMessagesOperations >; const mockUseLocalize = useLocalize as jest.MockedFunction; const mockHandleUIAction = handleUIAction as jest.MockedFunction; diff --git a/client/src/components/MCPUIResource/__tests__/MCPUIResourceCarousel.test.tsx b/client/src/components/MCPUIResource/__tests__/MCPUIResourceCarousel.test.tsx index a9f7962ab0..9a5ca934a0 100644 --- a/client/src/components/MCPUIResource/__tests__/MCPUIResourceCarousel.test.tsx +++ 
b/client/src/components/MCPUIResource/__tests__/MCPUIResourceCarousel.test.tsx @@ -2,7 +2,11 @@ import React from 'react'; import { render, screen } from '@testing-library/react'; import { RecoilRoot } from 'recoil'; import { MCPUIResourceCarousel } from '../MCPUIResourceCarousel'; -import { useMessageContext, useMessagesConversation, useMessagesOperations } from '~/Providers'; +import { + useMessageContext, + useOptionalMessagesConversation, + useOptionalMessagesOperations, +} from '~/Providers'; // Mock dependencies jest.mock('~/Providers'); @@ -19,11 +23,11 @@ jest.mock('../../Chat/Messages/Content/UIResourceCarousel', () => ({ })); const mockUseMessageContext = useMessageContext as jest.MockedFunction; -const mockUseMessagesConversation = useMessagesConversation as jest.MockedFunction< - typeof useMessagesConversation +const mockUseMessagesConversation = useOptionalMessagesConversation as jest.MockedFunction< + typeof useOptionalMessagesConversation >; -const mockUseMessagesOperations = useMessagesOperations as jest.MockedFunction< - typeof useMessagesOperations +const mockUseMessagesOperations = useOptionalMessagesOperations as jest.MockedFunction< + typeof useOptionalMessagesOperations >; describe('MCPUIResourceCarousel', () => { diff --git a/client/src/hooks/Messages/useConversationUIResources.ts b/client/src/hooks/Messages/useConversationUIResources.ts index 2333f64e5f..28e9aa035a 100644 --- a/client/src/hooks/Messages/useConversationUIResources.ts +++ b/client/src/hooks/Messages/useConversationUIResources.ts @@ -2,7 +2,7 @@ import { useMemo } from 'react'; import { useRecoilValue } from 'recoil'; import { Tools } from 'librechat-data-provider'; import type { TAttachment, UIResource } from 'librechat-data-provider'; -import { useMessagesOperations } from '~/Providers'; +import { useOptionalMessagesOperations } from '~/Providers'; import store from '~/store'; /** @@ -16,7 +16,7 @@ import store from '~/store'; export function useConversationUIResources( 
conversationId: string | undefined, ): Map { - const { getMessages } = useMessagesOperations(); + const { getMessages } = useOptionalMessagesOperations(); const conversationAttachmentsMap = useRecoilValue( store.conversationAttachmentsSelector(conversationId), From 9f6d8c6e9384a18c49e198351f828d3ab72b4063 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Thu, 26 Mar 2026 17:35:00 -0400 Subject: [PATCH 08/18] =?UTF-8?q?=F0=9F=A7=B5=20feat:=20ALS=20Context=20Mi?= =?UTF-8?q?ddleware,=20Tenant=20Threading,=20and=20Config=20Cache=20Invali?= =?UTF-8?q?dation=20(#12407)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add tenant context middleware for ALS-based isolation Introduces tenantContextMiddleware that propagates req.user.tenantId into AsyncLocalStorage, activating the Mongoose applyTenantIsolation plugin for all downstream DB queries within a request. - Strict mode (TENANT_ISOLATION_STRICT=true) returns 403 if no tenantId - Non-strict mode passes through for backward compatibility - No-op for unauthenticated requests - Includes 6 unit tests covering all paths * feat: register tenant middleware and wrap startup/auth in runAsSystem() - Register tenantContextMiddleware in Express app after capability middleware - Wrap server startup initialization in runAsSystem() for strict mode compat - Wrap auth strategy getAppConfig() calls in runAsSystem() since they run before user context is established (LDAP, SAML, OpenID, social login, AuthService) * feat: thread tenantId through all getAppConfig callers Pass tenantId from req.user to getAppConfig() across all callers that have request context, ensuring correct per-tenant cache key resolution. Also fixes getBaseConfig admin endpoint to scope to requesting admin's tenant instead of returning the unscoped base config. 
Files updated: - Controllers: UserController, PluginController - Middleware: checkDomainAllowed, balance - Routes: config - Services: loadConfigModels, loadDefaultModels, getEndpointsConfig, MCP - Audio services: TTSService, STTService, getVoices, getCustomConfigSpeech - Admin: getBaseConfig endpoint * feat: add config cache invalidation on admin mutations - Add clearOverrideCache(tenantId?) to flush per-principal override caches by enumerating Keyv store keys matching _OVERRIDE_: prefix - Add invalidateConfigCaches() helper that clears base config, override caches, tool caches, and endpoint config cache in one call - Wire invalidation into all 5 admin config mutation handlers (upsert, patch, delete field, delete overrides, toggle active) - Add strict mode warning when __default__ tenant fallback is used - Add 3 new tests for clearOverrideCache (all/scoped/base-preserving) * chore: update getUserPrincipals comment to reflect ALS-based tenant filtering The TODO(#12091) about missing tenantId filtering is resolved by the tenant context middleware + applyTenantIsolation Mongoose plugin. Group queries are now automatically scoped by tenantId via ALS. * fix: replace runAsSystem with baseOnly for pre-tenant code paths App configs are tenant-owned — runAsSystem() would bypass tenant isolation and return cross-tenant DB overrides. Instead, add baseOnly option to getAppConfig() that returns YAML-derived config only, with zero DB queries. All startup code, auth strategies, and MCP initialization now use getAppConfig({ baseOnly: true }) to get the YAML config without touching the Config collection.
* fix: address PR review findings — middleware ordering, types, cache safety - Chain tenantContextMiddleware inside requireJwtAuth after passport auth instead of global app.use() where req.user is always undefined (Finding 1) - Remove global tenantContextMiddleware registration from index.js - Update BalanceMiddlewareOptions to include tenantId, remove redundant cast (Finding 4) - Add warning log when clearOverrideCache cannot enumerate keys on Redis (Finding 3) - Use startsWith instead of includes for cache key filtering (Finding 12) - Use generator loop instead of Array.from for key enumeration (Finding 3) - Selective barrel export — exclude _resetTenantMiddlewareStrictCache (Finding 5) - Move isMainThread check to module level, remove per-request check (Finding 9) - Move mid-file require to top of app.js (Finding 8) - Parallelize invalidateConfigCaches with Promise.all (Finding 10) - Remove clearOverrideCache from public app.js exports (internal only) - Strengthen getUserPrincipals comment re: ALS dependency (Finding 2) * fix: restore runAsSystem for startup DB ops, consolidate require, clarify baseOnly - Restore runAsSystem() around performStartupChecks, updateInterfacePermissions, initializeMCPs, and initializeOAuthReconnectManager — these make Mongoose queries that need system context in strict tenant mode (NEW-3) - Consolidate duplicate require('@librechat/api') in requireJwtAuth.js (NEW-1) - Document that baseOnly ignores role/userId/tenantId in JSDoc (NEW-2) * test: add requireJwtAuth tenant chaining + invalidateConfigCaches tests - requireJwtAuth: 5 tests verifying ALS tenant context is set after passport auth, isolated between concurrent requests, and not set when user has no tenantId (Finding 6) - invalidateConfigCaches: 4 tests verifying all four caches are cleared, tenantId is threaded through, partial failure is handled gracefully, and operations run in parallel via Promise.all (Finding 11) * fix: address Copilot review — passport errors,
namespaced cache keys, /base scoping - Forward passport errors in requireJwtAuth before entering tenant middleware — prevents silent auth failures from reaching handlers (P1) - Account for Keyv namespace prefix in clearOverrideCache — stored keys are namespaced as "APP_CONFIG:_OVERRIDE_:..." not "_OVERRIDE_:...", so override caches were never actually matched/cleared (P2) - Remove role from getBaseConfig — /base should return tenant-scoped base config, not role-merged config that drifts per admin role (P2) - Return tenantStorage.run() for cleaner async semantics - Update mock cache in service.spec.ts to simulate Keyv namespacing * fix: address second review — cache safety, code quality, test reliability - Decouple cache invalidation from mutation response: fire-and-forget with logging so DB mutation success is not masked by cache failures - Extract clearEndpointConfigCache helper from inline IIFE - Move isMainThread check to lazy once-per-process guard (no import side effect) - Memoize process.env read in overrideCacheKey to avoid per-request env lookups and log flooding in strict mode - Remove flaky timer-based parallelism assertion, use structural check - Merge orphaned double JSDoc block on getUserPrincipals - Fix stale [getAppConfig] log prefix → [ensureBaseConfig] - Fix import order in tenant.spec.ts (package types before local values) - Replace "Finding 1" reference with self-contained description - Use real tenantStorage primitives in requireJwtAuth spec mock * fix: move JSDoc to correct function after clearEndpointConfigCache extraction * refactor: remove Redis SCAN from clearOverrideCache, rely on TTL expiry Redis SCAN causes 60s+ stalls under concurrent load (see #12410). APP_CONFIG defaults to FORCED_IN_MEMORY_CACHE_NAMESPACES, so the in-memory store.keys() path handles the standard case. When APP_CONFIG is Redis-backed, overrides expire naturally via overrideCacheTtl (60s default) — an acceptable window for admin config mutations.
* fix: remove return from tenantStorage.run to satisfy void middleware signature * fix: address second review โ€” cache safety, code quality, test reliability - Switch invalidateConfigCaches from Promise.all to Promise.allSettled so partial failures are logged individually instead of producing one undifferentiated error (Finding 3) - Gate overrideCacheKey strict-mode warning behind a once-per-process flag to prevent log flooding under load (Finding 4) - Add test for passport error forwarding in requireJwtAuth โ€” the if (err) { return next(err) } branch now has coverage (Finding 5) - Add test for real partial failure in invalidateConfigCaches where clearAppConfigCache rejects (not just the swallowed endpoint error) * chore: reorder imports in index.js and app.js for consistency - Moved logger and runAsSystem imports to maintain a consistent import order across files. - Improved code readability by ensuring related imports are grouped together. --- api/server/controllers/PluginController.js | 5 +- api/server/controllers/UserController.js | 4 +- api/server/index.js | 16 +- .../__tests__/requireJwtAuth.spec.js | 116 +++++++++++++++ api/server/middleware/checkDomainAllowed.js | 1 + api/server/middleware/requireJwtAuth.js | 23 ++- api/server/routes/admin/config.js | 3 +- api/server/routes/config.js | 2 +- api/server/services/AuthService.js | 4 +- .../__tests__/invalidateConfigCaches.spec.js | 137 ++++++++++++++++++ api/server/services/Config/app.js | 43 +++++- .../services/Config/getEndpointsConfig.js | 3 +- .../services/Config/loadConfigModels.js | 2 +- .../services/Config/loadDefaultModels.js | 3 +- api/server/services/Files/Audio/STTService.js | 1 + api/server/services/Files/Audio/TTSService.js | 2 + .../Files/Audio/getCustomConfigSpeech.js | 1 + api/server/services/Files/Audio/getVoices.js | 1 + api/server/services/MCP.js | 4 +- api/server/services/initializeMCPs.js | 2 +- api/strategies/ldapStrategy.js | 2 +- api/strategies/openidStrategy.js | 2 +- 
api/strategies/samlStrategy.js | 2 +- api/strategies/socialLogin.js | 2 +- packages/api/src/admin/config.ts | 22 ++- packages/api/src/app/service.spec.ts | 102 ++++++++++++- packages/api/src/app/service.ts | 118 +++++++++++++-- .../src/middleware/__tests__/tenant.spec.ts | 101 +++++++++++++ packages/api/src/middleware/balance.ts | 11 +- packages/api/src/middleware/index.ts | 1 + packages/api/src/middleware/tenant.ts | 70 +++++++++ .../data-schemas/src/methods/userGroup.ts | 25 ++-- 32 files changed, 768 insertions(+), 63 deletions(-) create mode 100644 api/server/middleware/__tests__/requireJwtAuth.spec.js create mode 100644 api/server/services/Config/__tests__/invalidateConfigCaches.spec.js create mode 100644 packages/api/src/middleware/__tests__/tenant.spec.ts create mode 100644 packages/api/src/middleware/tenant.ts diff --git a/api/server/controllers/PluginController.js b/api/server/controllers/PluginController.js index 279ffb15fd..14dd284c30 100644 --- a/api/server/controllers/PluginController.js +++ b/api/server/controllers/PluginController.js @@ -15,7 +15,7 @@ const getAvailablePluginsController = async (req, res) => { return; } - const appConfig = await getAppConfig({ role: req.user?.role }); + const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }); /** @type {{ filteredTools: string[], includedTools: string[] }} */ const { filteredTools = [], includedTools = [] } = appConfig; /** @type {import('@librechat/api').LCManifestTool[]} */ @@ -66,7 +66,8 @@ const getAvailableTools = async (req, res) => { const cache = getLogStores(CacheKeys.TOOL_CACHE); const cachedToolsArray = await cache.get(CacheKeys.TOOLS); - const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role })); + const appConfig = + req.config ?? 
(await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId })); // Return early if we have cached tools if (cachedToolsArray != null) { diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index 301c6d2f76..16b68968d9 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -26,7 +26,7 @@ const { getLogStores } = require('~/cache'); const db = require('~/models'); const getUserController = async (req, res) => { - const appConfig = await getAppConfig({ role: req.user?.role }); + const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }); /** @type {IUser} */ const userData = req.user.toObject != null ? req.user.toObject() : { ...req.user }; /** @@ -165,7 +165,7 @@ const deleteUserMcpServers = async (userId) => { }; const updateUserPluginsController = async (req, res) => { - const appConfig = await getAppConfig({ role: req.user?.role }); + const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }); const { user } = req; const { pluginKey, action, auth, isEntityTool } = req.body; try { diff --git a/api/server/index.js b/api/server/index.js index 0a8a29f3b7..de99f06701 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -8,8 +8,8 @@ const express = require('express'); const passport = require('passport'); const compression = require('compression'); const cookieParser = require('cookie-parser'); -const { logger } = require('@librechat/data-schemas'); const mongoSanitize = require('express-mongo-sanitize'); +const { logger, runAsSystem } = require('@librechat/data-schemas'); const { isEnabled, apiNotFound, @@ -60,10 +60,12 @@ const startServer = async () => { app.set('trust proxy', trusted_proxy); await seedDatabase(); - const appConfig = await getAppConfig(); + const appConfig = await getAppConfig({ baseOnly: true }); initializeFileStorage(appConfig); - await 
performStartupChecks(appConfig); - await updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions }); + await runAsSystem(async () => { + await performStartupChecks(appConfig); + await updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions }); + }); const indexPath = path.join(appConfig.paths.dist, 'index.html'); let indexHTML = fs.readFileSync(indexPath, 'utf8'); @@ -205,8 +207,10 @@ const startServer = async () => { logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`); } - await initializeMCPs(); - await initializeOAuthReconnectManager(); + await runAsSystem(async () => { + await initializeMCPs(); + await initializeOAuthReconnectManager(); + }); await checkMigrations(); // Configure stream services (auto-detects Redis from USE_REDIS env var) diff --git a/api/server/middleware/__tests__/requireJwtAuth.spec.js b/api/server/middleware/__tests__/requireJwtAuth.spec.js new file mode 100644 index 0000000000..bc288e5dab --- /dev/null +++ b/api/server/middleware/__tests__/requireJwtAuth.spec.js @@ -0,0 +1,116 @@ +/** + * Integration test: verifies that requireJwtAuth chains tenantContextMiddleware + * after successful passport authentication, so ALS tenant context is set for + * all downstream middleware and route handlers. + * + * requireJwtAuth must chain tenantContextMiddleware after passport populates + * req.user (not at global app.use() scope where req.user is undefined). + * If the chaining is removed, these tests fail. 
+ */ + +const { getTenantId } = require('@librechat/data-schemas'); + +// โ”€โ”€ Mocks โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +let mockPassportError = null; + +jest.mock('passport', () => ({ + authenticate: jest.fn(() => { + return (req, _res, done) => { + if (mockPassportError) { + return done(mockPassportError); + } + if (req._mockUser) { + req.user = req._mockUser; + } + done(); + }; + }), +})); + +// Mock @librechat/api โ€” the real tenantContextMiddleware is TS and cannot be +// required directly from CJS tests. This thin wrapper mirrors the real logic +// (read req.user.tenantId, call tenantStorage.run) using the same data-schemas +// primitives. The real implementation is covered by packages/api tenant.spec.ts. +jest.mock('@librechat/api', () => { + const { tenantStorage } = require('@librechat/data-schemas'); + return { + isEnabled: jest.fn(() => false), + tenantContextMiddleware: (req, res, next) => { + const tenantId = req.user?.tenantId; + if (!tenantId) { + return next(); + } + return tenantStorage.run({ tenantId }, async () => next()); + }, + }; +}); + +// โ”€โ”€ Helpers โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +const requireJwtAuth = require('../requireJwtAuth'); + +function mockReq(user) { + return { headers: {}, _mockUser: user }; +} + +function mockRes() { + return { status: jest.fn().mockReturnThis(), json: jest.fn().mockReturnThis() }; +} + +/** Runs requireJwtAuth and returns the tenantId observed inside next(). 
*/ +function runAuth(user) { + return new Promise((resolve) => { + const req = mockReq(user); + const res = mockRes(); + requireJwtAuth(req, res, () => { + resolve(getTenantId()); + }); + }); +} + +// โ”€โ”€ Tests โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +describe('requireJwtAuth tenant context chaining', () => { + afterEach(() => { + mockPassportError = null; + }); + + it('forwards passport errors to next() without entering tenant middleware', async () => { + mockPassportError = new Error('JWT signature invalid'); + const req = mockReq(undefined); + const res = mockRes(); + const err = await new Promise((resolve) => { + requireJwtAuth(req, res, (e) => resolve(e)); + }); + expect(err).toBeInstanceOf(Error); + expect(err.message).toBe('JWT signature invalid'); + expect(getTenantId()).toBeUndefined(); + }); + + it('sets ALS tenant context after passport auth succeeds', async () => { + const tenantId = await runAuth({ tenantId: 'tenant-abc', role: 'user' }); + expect(tenantId).toBe('tenant-abc'); + }); + + it('ALS tenant context is NOT set when user has no tenantId', async () => { + const tenantId = await runAuth({ role: 'user' }); + expect(tenantId).toBeUndefined(); + }); + + it('ALS tenant context is NOT set when user is undefined', async () => { + const tenantId = await runAuth(undefined); + expect(tenantId).toBeUndefined(); + }); + + it('concurrent requests get isolated tenant contexts', async () => { + const results = await Promise.all( + ['tenant-1', 'tenant-2', 'tenant-3'].map((tid) => runAuth({ tenantId: tid, role: 'user' })), + ); + expect(results).toEqual(['tenant-1', 'tenant-2', 'tenant-3']); + }); + + it('ALS context is not set at top-level scope (outside any request)', () => { + expect(getTenantId()).toBeUndefined(); + }); +}); diff --git a/api/server/middleware/checkDomainAllowed.js 
b/api/server/middleware/checkDomainAllowed.js index 754eb9c127..f7a3f00e68 100644 --- a/api/server/middleware/checkDomainAllowed.js +++ b/api/server/middleware/checkDomainAllowed.js @@ -18,6 +18,7 @@ const checkDomainAllowed = async (req, res, next) => { const email = req?.user?.email; const appConfig = await getAppConfig({ role: req?.user?.role, + tenantId: req?.user?.tenantId, }); if (email && !isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { diff --git a/api/server/middleware/requireJwtAuth.js b/api/server/middleware/requireJwtAuth.js index 16b107aefc..b13e991b23 100644 --- a/api/server/middleware/requireJwtAuth.js +++ b/api/server/middleware/requireJwtAuth.js @@ -1,20 +1,29 @@ const cookies = require('cookie'); const passport = require('passport'); -const { isEnabled } = require('@librechat/api'); +const { isEnabled, tenantContextMiddleware } = require('@librechat/api'); /** - * Custom Middleware to handle JWT authentication, with support for OpenID token reuse - * Switches between JWT and OpenID authentication based on cookies and environment settings + * Custom Middleware to handle JWT authentication, with support for OpenID token reuse. + * Switches between JWT and OpenID authentication based on cookies and environment settings. + * + * After successful authentication (req.user populated), automatically chains into + * `tenantContextMiddleware` to propagate `req.user.tenantId` into AsyncLocalStorage + * for downstream Mongoose tenant isolation. */ const requireJwtAuth = (req, res, next) => { const cookieHeader = req.headers.cookie; const tokenProvider = cookieHeader ? cookies.parse(cookieHeader).token_provider : null; - if (tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) { - return passport.authenticate('openidJwt', { session: false })(req, res, next); - } + const strategy = + tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS) ? 
'openidJwt' : 'jwt'; - return passport.authenticate('jwt', { session: false })(req, res, next); + passport.authenticate(strategy, { session: false })(req, res, (err) => { + if (err) { + return next(err); + } + // req.user is now populated by passport โ€” set up tenant ALS context + tenantContextMiddleware(req, res, next); + }); }; module.exports = requireJwtAuth; diff --git a/api/server/routes/admin/config.js b/api/server/routes/admin/config.js index b9407c6b09..0632077ea9 100644 --- a/api/server/routes/admin/config.js +++ b/api/server/routes/admin/config.js @@ -5,7 +5,7 @@ const { hasConfigCapability, requireCapability, } = require('~/server/middleware/roles/capabilities'); -const { getAppConfig } = require('~/server/services/Config'); +const { getAppConfig, invalidateConfigCaches } = require('~/server/services/Config'); const { requireJwtAuth } = require('~/server/middleware'); const db = require('~/models'); @@ -23,6 +23,7 @@ const handlers = createAdminConfigHandlers({ toggleConfigActive: db.toggleConfigActive, hasConfigCapability, getAppConfig, + invalidateConfigCaches, }); router.use(requireJwtAuth, requireAdminAccess); diff --git a/api/server/routes/config.js b/api/server/routes/config.js index bf60f57e08..0a68ccba4f 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -37,7 +37,7 @@ router.get('/', async function (req, res) { const ldap = getLdapConfig(); try { - const appConfig = await getAppConfig({ role: req.user?.role }); + const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }); const isOpenIdEnabled = !!process.env.OPENID_CLIENT_ID && diff --git a/api/server/services/AuthService.js b/api/server/services/AuthService.js index ef50a365b9..f17c5051a9 100644 --- a/api/server/services/AuthService.js +++ b/api/server/services/AuthService.js @@ -189,7 +189,7 @@ const registerUser = async (user, additionalData = {}) => { let newUserId; try { - const appConfig = await getAppConfig(); + const 
appConfig = await getAppConfig({ baseOnly: true }); if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { const errorMessage = 'The email address provided cannot be used. Please use a different email address.'; @@ -260,7 +260,7 @@ const registerUser = async (user, additionalData = {}) => { */ const requestPasswordReset = async (req) => { const { email } = req.body; - const appConfig = await getAppConfig(); + const appConfig = await getAppConfig({ baseOnly: true }); if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { const error = new Error(ErrorTypes.AUTH_FAILED); error.code = ErrorTypes.AUTH_FAILED; diff --git a/api/server/services/Config/__tests__/invalidateConfigCaches.spec.js b/api/server/services/Config/__tests__/invalidateConfigCaches.spec.js new file mode 100644 index 0000000000..df21786f05 --- /dev/null +++ b/api/server/services/Config/__tests__/invalidateConfigCaches.spec.js @@ -0,0 +1,137 @@ +// โ”€โ”€ Mocks โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +const mockConfigStoreDelete = jest.fn().mockResolvedValue(true); +const mockClearAppConfigCache = jest.fn().mockResolvedValue(undefined); +const mockClearOverrideCache = jest.fn().mockResolvedValue(undefined); + +jest.mock('~/cache/getLogStores', () => { + return jest.fn(() => ({ + delete: mockConfigStoreDelete, + })); +}); + +jest.mock('~/server/services/start/tools', () => ({ + loadAndFormatTools: jest.fn(() => ({})), +})); + +jest.mock('../loadCustomConfig', () => jest.fn().mockResolvedValue({})); + +jest.mock('@librechat/data-schemas', () => { + const actual = jest.requireActual('@librechat/data-schemas'); + return { ...actual, AppService: jest.fn(() => ({ availableTools: {} })) }; +}); + +jest.mock('~/models', () => ({ + getApplicableConfigs: jest.fn().mockResolvedValue([]), + getUserPrincipals: 
jest.fn().mockResolvedValue([]), +})); + +const mockInvalidateCachedTools = jest.fn().mockResolvedValue(undefined); +jest.mock('../getCachedTools', () => ({ + setCachedTools: jest.fn().mockResolvedValue(undefined), + invalidateCachedTools: mockInvalidateCachedTools, +})); + +jest.mock('@librechat/api', () => ({ + createAppConfigService: jest.fn(() => ({ + getAppConfig: jest.fn().mockResolvedValue({ availableTools: {} }), + clearAppConfigCache: mockClearAppConfigCache, + clearOverrideCache: mockClearOverrideCache, + })), +})); + +// โ”€โ”€ Tests โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +const { CacheKeys } = require('librechat-data-provider'); +const { invalidateConfigCaches } = require('../app'); + +describe('invalidateConfigCaches', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('clears all four caches', async () => { + await invalidateConfigCaches(); + + expect(mockClearAppConfigCache).toHaveBeenCalledTimes(1); + expect(mockClearOverrideCache).toHaveBeenCalledTimes(1); + expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true }); + expect(mockConfigStoreDelete).toHaveBeenCalledWith(CacheKeys.ENDPOINT_CONFIG); + }); + + it('passes tenantId through to clearOverrideCache', async () => { + await invalidateConfigCaches('tenant-a'); + + expect(mockClearOverrideCache).toHaveBeenCalledWith('tenant-a'); + expect(mockClearAppConfigCache).toHaveBeenCalledTimes(1); + expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true }); + }); + + it('does not throw when CONFIG_STORE.delete fails', async () => { + mockConfigStoreDelete.mockRejectedValueOnce(new Error('store not found')); + + await expect(invalidateConfigCaches()).resolves.not.toThrow(); + + // Other caches should still have been invalidated + expect(mockClearAppConfigCache).toHaveBeenCalledTimes(1); + 
expect(mockClearOverrideCache).toHaveBeenCalledTimes(1); + expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true }); + }); + + it('all operations run in parallel (not sequentially)', async () => { + const order = []; + + mockClearAppConfigCache.mockImplementation( + () => + new Promise((r) => + setTimeout(() => { + order.push('base'); + r(); + }, 10), + ), + ); + mockClearOverrideCache.mockImplementation( + () => + new Promise((r) => + setTimeout(() => { + order.push('override'); + r(); + }, 10), + ), + ); + mockInvalidateCachedTools.mockImplementation( + () => + new Promise((r) => + setTimeout(() => { + order.push('tools'); + r(); + }, 10), + ), + ); + mockConfigStoreDelete.mockImplementation( + () => + new Promise((r) => + setTimeout(() => { + order.push('endpoint'); + r(); + }, 10), + ), + ); + + await invalidateConfigCaches(); + + // All four should have been called (parallel execution via Promise.allSettled) + expect(order).toHaveLength(4); + expect(new Set(order)).toEqual(new Set(['base', 'override', 'tools', 'endpoint'])); + }); + + it('resolves even when clearAppConfigCache throws (partial failure)', async () => { + mockClearAppConfigCache.mockRejectedValueOnce(new Error('cache connection lost')); + + await expect(invalidateConfigCaches()).resolves.not.toThrow(); + + // Other caches should still have been invalidated despite the failure + expect(mockClearOverrideCache).toHaveBeenCalledTimes(1); + expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true }); + }); +}); diff --git a/api/server/services/Config/app.js b/api/server/services/Config/app.js index a63bef2124..c0180fdb12 100644 --- a/api/server/services/Config/app.js +++ b/api/server/services/Config/app.js @@ -1,9 +1,9 @@ const { CacheKeys } = require('librechat-data-provider'); -const { AppService } = require('@librechat/data-schemas'); const { createAppConfigService } = require('@librechat/api'); +const { AppService, logger } = 
require('@librechat/data-schemas'); +const { setCachedTools, invalidateCachedTools } = require('./getCachedTools'); const { loadAndFormatTools } = require('~/server/services/start/tools'); const loadCustomConfig = require('./loadCustomConfig'); -const { setCachedTools } = require('./getCachedTools'); const getLogStores = require('~/cache/getLogStores'); const paths = require('~/config/paths'); const db = require('~/models'); @@ -20,7 +20,7 @@ const loadBaseConfig = async () => { return AppService({ config, paths, systemTools }); }; -const { getAppConfig, clearAppConfigCache } = createAppConfigService({ +const { getAppConfig, clearAppConfigCache, clearOverrideCache } = createAppConfigService({ loadBaseConfig, setCachedTools, getCache: getLogStores, @@ -29,7 +29,44 @@ const { getAppConfig, clearAppConfigCache } = createAppConfigService({ getUserPrincipals: db.getUserPrincipals, }); +/** Deletes the ENDPOINT_CONFIG entry from CONFIG_STORE. Failures are non-critical and swallowed. */ +async function clearEndpointConfigCache() { + try { + const configStore = getLogStores(CacheKeys.CONFIG_STORE); + await configStore.delete(CacheKeys.ENDPOINT_CONFIG); + } catch { + // CONFIG_STORE or ENDPOINT_CONFIG may not exist โ€” not critical + } +} + +/** + * Invalidate all config-related caches after an admin config mutation. + * Clears the base config, per-principal override caches, tool caches, + * and the endpoints config cache. + * @param {string} [tenantId] - Optional tenant ID to scope override cache clearing. 
+ */ +async function invalidateConfigCaches(tenantId) { + const results = await Promise.allSettled([ + clearAppConfigCache(), + clearOverrideCache(tenantId), + invalidateCachedTools({ invalidateGlobal: true }), + clearEndpointConfigCache(), + ]); + const labels = [ + 'clearAppConfigCache', + 'clearOverrideCache', + 'invalidateCachedTools', + 'clearEndpointConfigCache', + ]; + for (let i = 0; i < results.length; i++) { + if (results[i].status === 'rejected') { + logger.error(`[invalidateConfigCaches] ${labels[i]} failed:`, results[i].reason); + } + } +} + module.exports = { getAppConfig, clearAppConfigCache, + invalidateConfigCaches, }; diff --git a/api/server/services/Config/getEndpointsConfig.js b/api/server/services/Config/getEndpointsConfig.js index bb22584851..476d3d7c80 100644 --- a/api/server/services/Config/getEndpointsConfig.js +++ b/api/server/services/Config/getEndpointsConfig.js @@ -26,7 +26,8 @@ async function getEndpointsConfig(req) { } } - const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role })); + const appConfig = + req.config ?? (await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId })); const defaultEndpointsConfig = await loadDefaultEndpointsConfig(appConfig); const customEndpointsConfig = loadCustomEndpointsConfig(appConfig?.endpoints?.custom); diff --git a/api/server/services/Config/loadConfigModels.js b/api/server/services/Config/loadConfigModels.js index 2bc83ecc3a..b94a719909 100644 --- a/api/server/services/Config/loadConfigModels.js +++ b/api/server/services/Config/loadConfigModels.js @@ -12,7 +12,7 @@ const { getAppConfig } = require('./app'); * @param {ServerRequest} req - The Express request object. 
*/ async function loadConfigModels(req) { - const appConfig = await getAppConfig({ role: req.user?.role }); + const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }); if (!appConfig) { return {}; } diff --git a/api/server/services/Config/loadDefaultModels.js b/api/server/services/Config/loadDefaultModels.js index 31aa831a70..85f2c42a33 100644 --- a/api/server/services/Config/loadDefaultModels.js +++ b/api/server/services/Config/loadDefaultModels.js @@ -16,7 +16,8 @@ const { getAppConfig } = require('./app'); */ async function loadDefaultModels(req) { try { - const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role })); + const appConfig = + req.config ?? (await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId })); const vertexConfig = appConfig?.endpoints?.[EModelEndpoint.anthropic]?.vertexConfig; const [openAI, anthropic, azureOpenAI, assistants, azureAssistants, google, bedrock] = diff --git a/api/server/services/Files/Audio/STTService.js b/api/server/services/Files/Audio/STTService.js index 4ba62a7eeb..c9a35c35ea 100644 --- a/api/server/services/Files/Audio/STTService.js +++ b/api/server/services/Files/Audio/STTService.js @@ -142,6 +142,7 @@ class STTService { req.config ?? (await getAppConfig({ role: req?.user?.role, + tenantId: req?.user?.tenantId, })); const sttSchema = appConfig?.speech?.stt; if (!sttSchema) { diff --git a/api/server/services/Files/Audio/TTSService.js b/api/server/services/Files/Audio/TTSService.js index 2c932968c6..1125dd74ed 100644 --- a/api/server/services/Files/Audio/TTSService.js +++ b/api/server/services/Files/Audio/TTSService.js @@ -297,6 +297,7 @@ class TTSService { req.config ?? (await getAppConfig({ role: req.user?.role, + tenantId: req.user?.tenantId, })); try { res.setHeader('Content-Type', 'audio/mpeg'); @@ -365,6 +366,7 @@ class TTSService { req.config ?? 
(await getAppConfig({ role: req.user?.role, + tenantId: req.user?.tenantId, })); const provider = this.getProvider(appConfig); const ttsSchema = appConfig?.speech?.tts?.[provider]; diff --git a/api/server/services/Files/Audio/getCustomConfigSpeech.js b/api/server/services/Files/Audio/getCustomConfigSpeech.js index d0d0b51ac2..b438771ec1 100644 --- a/api/server/services/Files/Audio/getCustomConfigSpeech.js +++ b/api/server/services/Files/Audio/getCustomConfigSpeech.js @@ -17,6 +17,7 @@ async function getCustomConfigSpeech(req, res) { try { const appConfig = await getAppConfig({ role: req.user?.role, + tenantId: req.user?.tenantId, }); if (!appConfig) { diff --git a/api/server/services/Files/Audio/getVoices.js b/api/server/services/Files/Audio/getVoices.js index f2f8e100c3..22bd7cea6e 100644 --- a/api/server/services/Files/Audio/getVoices.js +++ b/api/server/services/Files/Audio/getVoices.js @@ -18,6 +18,7 @@ async function getVoices(req, res) { req.config ?? (await getAppConfig({ role: req.user?.role, + tenantId: req.user?.tenantId, })); const ttsSchema = appConfig?.speech?.tts; diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index 03563a0cfc..d765d335aa 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -367,7 +367,7 @@ async function createMCPTools({ const serverConfig = config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id)); if (serverConfig?.url) { - const appConfig = await getAppConfig({ role: user?.role }); + const appConfig = await getAppConfig({ role: user?.role, tenantId: user?.tenantId }); const allowedDomains = appConfig?.mcpSettings?.allowedDomains; const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains); if (!isDomainAllowed) { @@ -449,7 +449,7 @@ async function createMCPTool({ const serverConfig = config ?? 
(await getMCPServersRegistry().getServerConfig(serverName, user?.id)); if (serverConfig?.url) { - const appConfig = await getAppConfig({ role: user?.role }); + const appConfig = await getAppConfig({ role: user?.role, tenantId: user?.tenantId }); const allowedDomains = appConfig?.mcpSettings?.allowedDomains; const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains); if (!isDomainAllowed) { diff --git a/api/server/services/initializeMCPs.js b/api/server/services/initializeMCPs.js index c7f27acd0e..5728730131 100644 --- a/api/server/services/initializeMCPs.js +++ b/api/server/services/initializeMCPs.js @@ -7,7 +7,7 @@ const { createMCPServersRegistry, createMCPManager } = require('~/config'); * Initialize MCP servers */ async function initializeMCPs() { - const appConfig = await getAppConfig(); + const appConfig = await getAppConfig({ baseOnly: true }); const mcpServers = appConfig.mcpConfig; try { diff --git a/api/strategies/ldapStrategy.js b/api/strategies/ldapStrategy.js index dcadc26a45..0c99c7b670 100644 --- a/api/strategies/ldapStrategy.js +++ b/api/strategies/ldapStrategy.js @@ -122,7 +122,7 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => { ); } - const appConfig = await getAppConfig(); + const appConfig = await getAppConfig({ baseOnly: true }); if (!isEmailDomainAllowed(mail, appConfig?.registration?.allowedDomains)) { logger.error( `[LDAP Strategy] Authentication blocked - email domain not allowed [Email: ${mail}]`, diff --git a/api/strategies/openidStrategy.js b/api/strategies/openidStrategy.js index 7c43358297..ab7eb60261 100644 --- a/api/strategies/openidStrategy.js +++ b/api/strategies/openidStrategy.js @@ -468,7 +468,7 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) { Object.assign(userinfo, providerUserinfo); } - const appConfig = await getAppConfig(); + const appConfig = await getAppConfig({ baseOnly: true }); const email = getOpenIdEmail(userinfo); if 
(!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { logger.error( diff --git a/api/strategies/samlStrategy.js b/api/strategies/samlStrategy.js index 843baf8a64..abcb3de099 100644 --- a/api/strategies/samlStrategy.js +++ b/api/strategies/samlStrategy.js @@ -193,7 +193,7 @@ async function setupSaml() { logger.debug('[samlStrategy] SAML profile:', profile); const userEmail = getEmail(profile) || ''; - const appConfig = await getAppConfig(); + const appConfig = await getAppConfig({ baseOnly: true }); if (!isEmailDomainAllowed(userEmail, appConfig?.registration?.allowedDomains)) { logger.error( diff --git a/api/strategies/socialLogin.js b/api/strategies/socialLogin.js index 88fb347042..7585e8e2fe 100644 --- a/api/strategies/socialLogin.js +++ b/api/strategies/socialLogin.js @@ -13,7 +13,7 @@ const socialLogin = profile, }); - const appConfig = await getAppConfig(); + const appConfig = await getAppConfig({ baseOnly: true }); if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { logger.error( diff --git a/packages/api/src/admin/config.ts b/packages/api/src/admin/config.ts index 0a1afd5388..b2afd9c69b 100644 --- a/packages/api/src/admin/config.ts +++ b/packages/api/src/admin/config.ts @@ -77,6 +77,8 @@ export interface AdminConfigDeps { userId?: string; tenantId?: string; }) => Promise; + /** Invalidate all config-related caches after a mutation. 
*/ + invalidateConfigCaches?: (tenantId?: string) => Promise; } // โ”€โ”€ Validation helpers โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ @@ -133,6 +135,7 @@ export function createAdminConfigHandlers(deps: AdminConfigDeps) { toggleConfigActive, hasConfigCapability, getAppConfig, + invalidateConfigCaches, } = deps; /** @@ -176,7 +179,9 @@ export function createAdminConfigHandlers(deps: AdminConfigDeps) { return res.status(501).json({ error: 'Base config endpoint not configured' }); } - const appConfig = await getAppConfig(); + const appConfig = await getAppConfig({ + tenantId: user.tenantId, + }); return res.status(200).json({ config: appConfig }); } catch (error) { logger.error('[adminConfig] getBaseConfig error:', error); @@ -278,6 +283,9 @@ export function createAdminConfigHandlers(deps: AdminConfigDeps) { priority ?? DEFAULT_PRIORITY, ); + invalidateConfigCaches?.(user.tenantId)?.catch((err) => + logger.error('[adminConfig] Cache invalidation failed after upsert:', err), + ); return res.status(config?.configVersion === 1 ? 201 : 200).json({ config }); } catch (error) { logger.error('[adminConfig] upsertConfigOverrides error:', error); @@ -367,6 +375,9 @@ export function createAdminConfigHandlers(deps: AdminConfigDeps) { priority ?? existing?.priority ?? 
DEFAULT_PRIORITY, ); + invalidateConfigCaches?.(user.tenantId)?.catch((err) => + logger.error('[adminConfig] Cache invalidation failed after patch:', err), + ); return res.status(200).json({ config }); } catch (error) { logger.error('[adminConfig] patchConfigField error:', error); @@ -414,6 +425,9 @@ export function createAdminConfigHandlers(deps: AdminConfigDeps) { return res.status(404).json({ error: 'Config not found' }); } + invalidateConfigCaches?.(user.tenantId)?.catch((err) => + logger.error('[adminConfig] Cache invalidation failed after field delete:', err), + ); return res.status(200).json({ config }); } catch (error) { logger.error('[adminConfig] deleteConfigField error:', error); @@ -449,6 +463,9 @@ export function createAdminConfigHandlers(deps: AdminConfigDeps) { return res.status(404).json({ error: 'Config not found' }); } + invalidateConfigCaches?.(user.tenantId)?.catch((err) => + logger.error('[adminConfig] Cache invalidation failed after config delete:', err), + ); return res.status(200).json({ success: true }); } catch (error) { logger.error('[adminConfig] deleteConfigOverrides error:', error); @@ -489,6 +506,9 @@ export function createAdminConfigHandlers(deps: AdminConfigDeps) { return res.status(404).json({ error: 'Config not found' }); } + invalidateConfigCaches?.(user.tenantId)?.catch((err) => + logger.error('[adminConfig] Cache invalidation failed after toggle:', err), + ); return res.status(200).json({ config }); } catch (error) { logger.error('[adminConfig] toggleConfig error:', error); diff --git a/packages/api/src/app/service.spec.ts b/packages/api/src/app/service.spec.ts index 2dfba09e25..4232a36dc3 100644 --- a/packages/api/src/app/service.spec.ts +++ b/packages/api/src/app/service.spec.ts @@ -1,17 +1,24 @@ import { createAppConfigService } from './service'; -function createMockCache() { +/** + * Creates a mock cache that simulates Keyv's namespace behavior. 
+ * Keyv stores keys internally as `namespace:key` but its API (get/set/delete) + * accepts un-namespaced keys and auto-prepends the namespace. + */ +function createMockCache(namespace = 'app_config') { const store = new Map(); return { - get: jest.fn((key) => Promise.resolve(store.get(key))), + get: jest.fn((key) => Promise.resolve(store.get(`${namespace}:${key}`))), set: jest.fn((key, value) => { - store.set(key, value); + store.set(`${namespace}:${key}`, value); return Promise.resolve(undefined); }), delete: jest.fn((key) => { - store.delete(key); + store.delete(`${namespace}:${key}`); return Promise.resolve(true); }), + /** Mimic Keyv's opts.store structure for key enumeration in clearOverrideCache */ + opts: { store: { keys: () => store.keys() } }, _store: store, }; } @@ -58,6 +65,23 @@ describe('createAppConfigService', () => { expect(deps.loadBaseConfig).toHaveBeenCalledTimes(1); }); + it('baseOnly returns YAML config without DB queries', async () => { + const deps = createDeps({ + getApplicableConfigs: jest + .fn() + .mockResolvedValue([ + { priority: 10, overrides: { interface: { endpointsMenu: false } }, isActive: true }, + ]), + }); + const { getAppConfig } = createAppConfigService(deps); + + const config = await getAppConfig({ baseOnly: true }); + + expect(deps.loadBaseConfig).toHaveBeenCalledTimes(1); + expect(deps.getApplicableConfigs).not.toHaveBeenCalled(); + expect(config).toEqual(deps._baseConfig); + }); + it('reloads base config when refresh is true', async () => { const deps = createDeps(); const { getAppConfig } = createAppConfigService(deps); @@ -144,8 +168,8 @@ describe('createAppConfigService', () => { await getAppConfig({ userId: 'uid1' }); const cachedKeys = [...deps._cache._store.keys()]; - const overrideKey = cachedKeys.find((k) => k.startsWith('_OVERRIDE_:')); - expect(overrideKey).toBe('_OVERRIDE_:__default__:uid1'); + const overrideKey = cachedKeys.find((k) => k.includes('_OVERRIDE_:')); + 
expect(overrideKey).toBe('app_config:_OVERRIDE_:__default__:uid1'); }); it('tenantId is included in cache key to prevent cross-tenant contamination', async () => { @@ -241,4 +265,70 @@ describe('createAppConfigService', () => { expect(deps.loadBaseConfig).toHaveBeenCalledTimes(2); }); }); + + describe('clearOverrideCache', () => { + it('clears all override caches when no tenantId is provided', async () => { + const deps = createDeps({ + getApplicableConfigs: jest + .fn() + .mockResolvedValue([{ priority: 10, overrides: { x: 1 }, isActive: true }]), + }); + const { getAppConfig, clearOverrideCache } = createAppConfigService(deps); + + await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-a' }); + await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-b' }); + expect(deps.getApplicableConfigs).toHaveBeenCalledTimes(2); + + await clearOverrideCache(); + + // After clearing, both tenants should re-query DB + await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-a' }); + await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-b' }); + expect(deps.getApplicableConfigs).toHaveBeenCalledTimes(4); + }); + + it('clears only specified tenant override caches', async () => { + const deps = createDeps({ + getApplicableConfigs: jest + .fn() + .mockResolvedValue([{ priority: 10, overrides: { x: 1 }, isActive: true }]), + }); + const { getAppConfig, clearOverrideCache } = createAppConfigService(deps); + + await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-a' }); + await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-b' }); + expect(deps.getApplicableConfigs).toHaveBeenCalledTimes(2); + + await clearOverrideCache('tenant-a'); + + // tenant-a should re-query, tenant-b should be cached + await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-a' }); + await getAppConfig({ role: 'ADMIN', tenantId: 'tenant-b' }); + expect(deps.getApplicableConfigs).toHaveBeenCalledTimes(3); + }); + + it('does not clear base config', async () => { + const deps = createDeps(); + const { getAppConfig, 
clearOverrideCache } = createAppConfigService(deps); + + await getAppConfig(); + expect(deps.loadBaseConfig).toHaveBeenCalledTimes(1); + + await clearOverrideCache(); + + await getAppConfig(); + // Base config should still be cached + expect(deps.loadBaseConfig).toHaveBeenCalledTimes(1); + }); + + it('does not throw when store.keys is unavailable (Redis fallback to TTL expiry)', async () => { + const deps = createDeps(); + // Remove store.keys to simulate Redis-backed cache + deps._cache.opts = {}; + const { clearOverrideCache } = createAppConfigService(deps); + + // Should not throw โ€” logs warning and relies on TTL expiry + await expect(clearOverrideCache()).resolves.toBeUndefined(); + }); + }); }); diff --git a/packages/api/src/app/service.ts b/packages/api/src/app/service.ts index b7826e40ee..6c5d307709 100644 --- a/packages/api/src/app/service.ts +++ b/packages/api/src/app/service.ts @@ -13,6 +13,12 @@ interface CacheStore { get: (key: string) => Promise; set: (key: string, value: unknown, ttl?: number) => Promise; delete: (key: string) => Promise; + /** Keyv options โ€” used for key enumeration when clearing override caches. */ + opts?: { + store?: { + keys?: () => IterableIterator; + }; + }; } export interface AppConfigServiceDeps { @@ -39,8 +45,28 @@ export interface AppConfigServiceDeps { // โ”€โ”€ Helpers โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ +let _strictOverride: boolean | undefined; +function isStrictOverrideMode(): boolean { + return (_strictOverride ??= process.env.TENANT_ISOLATION_STRICT === 'true'); +} + +/** @internal Resets the cached strict-override flag. Exposed for test teardown only. 
*/ +let _warnedNoTenantInStrictMode = false; + +export function _resetOverrideStrictCache(): void { + _strictOverride = undefined; + _warnedNoTenantInStrictMode = false; +} + function overrideCacheKey(role?: string, userId?: string, tenantId?: string): string { const tenant = tenantId || '__default__'; + if (!tenantId && isStrictOverrideMode() && !_warnedNoTenantInStrictMode) { + _warnedNoTenantInStrictMode = true; + logger.warn( + '[overrideCacheKey] No tenantId in strict mode โ€” falling back to __default__. ' + + 'This likely indicates a code path that bypasses the tenant context middleware.', + ); + } if (userId && role) { return `_OVERRIDE_:${tenant}:${role}:${userId}`; } @@ -83,20 +109,13 @@ export function createAppConfigService(deps: AppConfigServiceDeps) { } /** - * Get the app configuration, optionally merged with DB overrides for the given principal. - * - * The base config (from YAML + AppService) is cached indefinitely. Per-principal merged - * configs are cached with a short TTL (`overrideCacheTtl`, default 60s). On cache miss, - * `getApplicableConfigs` queries the DB for matching overrides and merges them by priority. + * Ensure the YAML-derived base config is loaded and cached. + * Returns the `_BASE_` config (YAML + AppService). No DB queries. 
*/ - async function getAppConfig( - options: { role?: string; userId?: string; tenantId?: string; refresh?: boolean } = {}, - ): Promise { - const { role, userId, tenantId, refresh } = options; - + async function ensureBaseConfig(refresh?: boolean): Promise { let baseConfig = (await cache.get(BASE_CONFIG_KEY)) as AppConfig | undefined; if (!baseConfig || refresh) { - logger.info('[getAppConfig] Loading base configuration...'); + logger.info('[ensureBaseConfig] Loading base configuration...'); baseConfig = await loadBaseConfig(); if (!baseConfig) { @@ -109,6 +128,37 @@ export function createAppConfigService(deps: AppConfigServiceDeps) { await cache.set(BASE_CONFIG_KEY, baseConfig); } + return baseConfig; + } + + /** + * Get the app configuration, optionally merged with DB overrides for the given principal. + * + * The base config (from YAML + AppService) is cached indefinitely. Per-principal merged + * configs are cached with a short TTL (`overrideCacheTtl`, default 60s). On cache miss, + * `getApplicableConfigs` queries the DB for matching overrides and merges them by priority. + * + * When `baseOnly` is true, returns the YAML-derived config without any DB queries. + * `role`, `userId`, and `tenantId` are ignored in this mode. + * Use this for startup, auth strategies, and other pre-tenant code paths. + */ + async function getAppConfig( + options: { + role?: string; + userId?: string; + tenantId?: string; + refresh?: boolean; + /** When true, return only the YAML-derived base config โ€” no DB override queries. 
*/ + baseOnly?: boolean; + } = {}, + ): Promise { + const { role, userId, tenantId, refresh, baseOnly } = options; + + const baseConfig = await ensureBaseConfig(refresh); + + if (baseOnly) { + return baseConfig; + } const cacheKey = overrideCacheKey(role, userId, tenantId); if (!refresh) { @@ -146,9 +196,55 @@ export function createAppConfigService(deps: AppConfigServiceDeps) { await cache.delete(BASE_CONFIG_KEY); } + /** + * Clear per-principal override caches. When `tenantId` is provided, only caches + * matching `_OVERRIDE_:${tenantId}:*` are deleted. When omitted, ALL override + * caches are cleared. + */ + async function clearOverrideCache(tenantId?: string): Promise { + const namespace = cacheKeys.APP_CONFIG; + const overrideSegment = tenantId ? `_OVERRIDE_:${tenantId}:` : '_OVERRIDE_:'; + + // In-memory store โ€” enumerate keys directly. + // APP_CONFIG defaults to FORCED_IN_MEMORY_CACHE_NAMESPACES, so this is the + // standard path. Redis SCAN is intentionally avoided here โ€” it can cause 60s+ + // stalls under concurrent load (see #12410). When APP_CONFIG is Redis-backed + // and store.keys() is unavailable, overrides expire naturally via TTL. + const store = (cache as CacheStore).opts?.store; + if (store && typeof store.keys === 'function') { + // Keyv stores keys with a namespace prefix (e.g. "APP_CONFIG:_OVERRIDE_:..."). + // We match on the namespaced key but delete using the un-namespaced key + // because Keyv.delete() auto-prepends the namespace. + const namespacedPrefix = `${namespace}:${overrideSegment}`; + const toDelete: string[] = []; + for (const key of store.keys()) { + if (key.startsWith(namespacedPrefix)) { + toDelete.push(key.slice(namespace.length + 1)); + } + } + if (toDelete.length > 0) { + await Promise.all(toDelete.map((key) => cache.delete(key))); + logger.info( + `[clearOverrideCache] Cleared ${toDelete.length} override cache entries` + + (tenantId ? 
` for tenant ${tenantId}` : ''), + ); + } + return; + } + + logger.warn( + '[clearOverrideCache] Cache store does not support key enumeration. ' + + 'Override caches will expire naturally via TTL (%dms). ' + + 'This is expected when APP_CONFIG is Redis-backed โ€” Redis SCAN is avoided ' + + 'for performance reasons (see #12410).', + overrideCacheTtl, + ); + } + return { getAppConfig, clearAppConfigCache, + clearOverrideCache, }; } diff --git a/packages/api/src/middleware/__tests__/tenant.spec.ts b/packages/api/src/middleware/__tests__/tenant.spec.ts new file mode 100644 index 0000000000..7451817941 --- /dev/null +++ b/packages/api/src/middleware/__tests__/tenant.spec.ts @@ -0,0 +1,101 @@ +import { getTenantId } from '@librechat/data-schemas'; +import type { Response, NextFunction } from 'express'; +import type { ServerRequest } from '~/types/http'; +// Import directly from source file โ€” _resetTenantMiddlewareStrictCache is intentionally +// excluded from the public barrel export (index.ts). +import { tenantContextMiddleware, _resetTenantMiddlewareStrictCache } from '../tenant'; + +function mockReq(user?: Record): ServerRequest { + return { user } as unknown as ServerRequest; +} + +function mockRes(): Response { + const res = { + status: jest.fn().mockReturnThis(), + json: jest.fn().mockReturnThis(), + }; + return res as unknown as Response; +} + +/** Runs the middleware and returns a Promise that resolves when next() is called. 
*/ +function runMiddleware(req: ServerRequest, res: Response): Promise { + return new Promise((resolve) => { + const next: NextFunction = () => { + resolve(getTenantId()); + }; + tenantContextMiddleware(req, res, next); + }); +} + +describe('tenantContextMiddleware', () => { + afterEach(() => { + _resetTenantMiddlewareStrictCache(); + delete process.env.TENANT_ISOLATION_STRICT; + }); + + it('sets ALS tenant context for authenticated requests with tenantId', async () => { + const req = mockReq({ tenantId: 'tenant-x', role: 'user' }); + const res = mockRes(); + + const tenantId = await runMiddleware(req, res); + expect(tenantId).toBe('tenant-x'); + }); + + it('is a no-op for unauthenticated requests (no user)', async () => { + const req = mockReq(); + const res = mockRes(); + + const tenantId = await runMiddleware(req, res); + expect(tenantId).toBeUndefined(); + }); + + it('passes through without ALS when user has no tenantId in non-strict mode', async () => { + const req = mockReq({ role: 'user' }); + const res = mockRes(); + + const tenantId = await runMiddleware(req, res); + expect(tenantId).toBeUndefined(); + }); + + it('returns 403 when user has no tenantId in strict mode', () => { + process.env.TENANT_ISOLATION_STRICT = 'true'; + _resetTenantMiddlewareStrictCache(); + + const req = mockReq({ role: 'user' }); + const res = mockRes(); + const next: NextFunction = jest.fn(); + + tenantContextMiddleware(req, res, next); + + expect(res.status).toHaveBeenCalledWith(403); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ error: expect.stringContaining('Tenant context required') }), + ); + expect(next).not.toHaveBeenCalled(); + }); + + it('allows authenticated requests with tenantId in strict mode', async () => { + process.env.TENANT_ISOLATION_STRICT = 'true'; + _resetTenantMiddlewareStrictCache(); + + const req = mockReq({ tenantId: 'tenant-y', role: 'admin' }); + const res = mockRes(); + + const tenantId = await runMiddleware(req, res); + 
expect(tenantId).toBe('tenant-y'); + }); + + it('different requests get independent tenant contexts', async () => { + const runRequest = (tid: string) => { + const req = mockReq({ tenantId: tid, role: 'user' }); + const res = mockRes(); + return runMiddleware(req, res); + }; + + const results = await Promise.all([runRequest('tenant-1'), runRequest('tenant-2')]); + + expect(results).toHaveLength(2); + expect(results).toContain('tenant-1'); + expect(results).toContain('tenant-2'); + }); +}); diff --git a/packages/api/src/middleware/balance.ts b/packages/api/src/middleware/balance.ts index 8c6b149cdd..19719680ec 100644 --- a/packages/api/src/middleware/balance.ts +++ b/packages/api/src/middleware/balance.ts @@ -12,7 +12,11 @@ import type { BalanceUpdateFields } from '~/types'; import { getBalanceConfig } from '~/app/config'; export interface BalanceMiddlewareOptions { - getAppConfig: (options?: { role?: string; refresh?: boolean }) => Promise; + getAppConfig: (options?: { + role?: string; + tenantId?: string; + refresh?: boolean; + }) => Promise; findBalanceByUser: (userId: string) => Promise; upsertBalanceFields: (userId: string, fields: IBalanceUpdate) => Promise; } @@ -92,7 +96,10 @@ export function createSetBalanceConfig({ return async (req: ServerRequest, res: ServerResponse, next: NextFunction): Promise => { try { const user = req.user as IUser & { _id: string | ObjectId }; - const appConfig = await getAppConfig({ role: user?.role }); + const appConfig = await getAppConfig({ + role: user?.role, + tenantId: user?.tenantId, + }); const balanceConfig = getBalanceConfig(appConfig); if (!balanceConfig?.enabled) { return next(); diff --git a/packages/api/src/middleware/index.ts b/packages/api/src/middleware/index.ts index a56b8e4a3e..7d9dee2f8a 100644 --- a/packages/api/src/middleware/index.ts +++ b/packages/api/src/middleware/index.ts @@ -5,5 +5,6 @@ export * from './notFound'; export * from './balance'; export * from './json'; export * from './capabilities'; +export 
{ tenantContextMiddleware } from './tenant'; export * from './concurrency'; export * from './checkBalance'; diff --git a/packages/api/src/middleware/tenant.ts b/packages/api/src/middleware/tenant.ts new file mode 100644 index 0000000000..0b0e003991 --- /dev/null +++ b/packages/api/src/middleware/tenant.ts @@ -0,0 +1,70 @@ +import { isMainThread } from 'worker_threads'; +import { tenantStorage, logger } from '@librechat/data-schemas'; +import type { Response, NextFunction } from 'express'; +import type { ServerRequest } from '~/types/http'; + +let _checkedThread = false; + +let _strictMode: boolean | undefined; + +function isStrict(): boolean { + return (_strictMode ??= process.env.TENANT_ISOLATION_STRICT === 'true'); +} + +/** Resets the cached strict-mode flag. Exposed for test teardown only. */ +export function _resetTenantMiddlewareStrictCache(): void { + _strictMode = undefined; +} + +/** + * Express middleware that propagates the authenticated user's `tenantId` into + * the AsyncLocalStorage context used by the Mongoose tenant-isolation plugin. + * + * **Placement**: Chained automatically by `requireJwtAuth` after successful + * passport authentication (req.user is populated). Must NOT be registered at + * global `app.use()` scope โ€” `req.user` is undefined at that stage. + * + * Behaviour: + * - Authenticated request with `tenantId` โ†’ wraps downstream in `tenantStorage.run({ tenantId })` + * - Authenticated request **without** `tenantId`: + * - Strict mode (`TENANT_ISOLATION_STRICT=true`) โ†’ responds 403 + * - Non-strict (default) โ†’ passes through without ALS context (backward compat) + * - Unauthenticated request โ†’ no-op (calls `next()` directly) + */ +export function tenantContextMiddleware( + req: ServerRequest, + res: Response, + next: NextFunction, +): void { + if (!_checkedThread) { + _checkedThread = true; + if (!isMainThread) { + logger.error( + '[tenantContextMiddleware] Running in a worker thread โ€” ' + + 'ALS context will not propagate. 
This middleware must only run in the main Express process.', + ); + } + } + + const user = req.user as { tenantId?: string } | undefined; + + if (!user) { + next(); + return; + } + + const tenantId = user.tenantId; + + if (!tenantId) { + if (isStrict()) { + res.status(403).json({ error: 'Tenant context required in strict isolation mode' }); + return; + } + next(); + return; + } + + return void tenantStorage.run({ tenantId }, async () => { + next(); + }); +} diff --git a/packages/data-schemas/src/methods/userGroup.ts b/packages/data-schemas/src/methods/userGroup.ts index 5e11c26135..a41358337c 100644 --- a/packages/data-schemas/src/methods/userGroup.ts +++ b/packages/data-schemas/src/methods/userGroup.ts @@ -236,21 +236,28 @@ export function createUserGroupMethods(mongoose: typeof import('mongoose')) { } /** - * Get a list of all principal identifiers for a user (user ID + group IDs + public) - * For use in permission checks + * Get a list of all principal identifiers for a user (user ID + group IDs + public). + * For use in permission checks. + * + * Tenant filtering for group memberships is handled automatically by the + * `applyTenantIsolation` Mongoose plugin on the Group schema. The + * `tenantContextMiddleware` (chained by `requireJwtAuth` after passport auth) + * sets the ALS context, so `getUserGroups()` โ†’ `findGroupsByMemberId()` queries + * are scoped to the requesting tenant. No explicit tenantId parameter is needed. + * + * IMPORTANT: This relies on the ALS tenant context being active. If this + * function is called outside a request context (e.g. startup, background jobs), + * group queries will be unscoped. In strict mode, the Mongoose plugin will + * reject such queries. 
+ * + * Ref: #12091 (resolved by tenant context middleware in requireJwtAuth) + * * @param params - Parameters object * @param params.userId - The user ID * @param params.role - Optional user role (if not provided, will query from DB) * @param session - Optional MongoDB session for transactions * @returns Array of principal objects with type and id */ - /** - * TODO(#12091): This method has no tenantId parameter โ€” it returns ALL group - * memberships for a user regardless of tenant. In multi-tenant mode, group - * principals from other tenants will be included in capability checks, which - * could grant cross-tenant capabilities. Add tenantId filtering here when - * tenant isolation is activated. - */ async function getUserPrincipals( params: { userId: string | Types.ObjectId; From 2e3d66cfe288be3ca75f26eee2fad3b482dd4803 Mon Sep 17 00:00:00 2001 From: Dustin Healy <54083382+dustinhealy@users.noreply.github.com> Date: Thu, 26 Mar 2026 14:36:18 -0700 Subject: [PATCH 09/18] =?UTF-8?q?=F0=9F=91=A5=20feat:=20Admin=20Groups=20A?= =?UTF-8?q?PI=20Endpoints=20(#12387)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add listGroups and deleteGroup methods to userGroup * feat: add admin groups handler factory and Express routes * fix: address convention violations in admin groups handlers * fix: address Copilot review findings in admin groups handlers - Escape regex in listGroups to prevent injection/ReDoS - Validate ObjectId format in all handlers accepting id/userId params - Replace N+1 findUser loop with batched findUsers query - Remove unused findGroupsByMemberId from dep interface - Map Mongoose ValidationError to 400 in create/update handlers - Validate name in updateGroupHandler (reject empty/whitespace) - Handle null updateGroupById result (race condition) - Tighten error message matching in add/remove member handlers * test: add unit tests for admin groups handlers * fix: address code review findings for admin 
groups Atomic delete/update handlers (single DB trip), pass through idOnTheSource, add removeMemberById for non-ObjectId members, deduplicate member results, fix error message exposure, add hard cap/sort to listGroups, replace GroupListFilter with Pick of GroupFilterOptions, validate memberIds as array, trim name in update, fix import order, and improve test hygiene with fresh IDs per test. * fix: cascade cleanup, pagination, and test coverage for admin groups Add deleteGrantsForPrincipal to systemGrant data layer and wire cascade cleanup (Config, AclEntry, SystemGrant) into deleteGroupHandler. Add limit/offset pagination to getGroupMembers. Guard empty PATCH bodies with 400. Remove dead type guard and unnecessary type cast. Add 11 new tests covering cascade delete, idempotent member removal, empty update, search filter, 500 error paths, and pagination. * fix: harden admin groups with cascade resilience, type safety, and fallback removal Wrap cascade cleanup in inner try/catch so partial failure logs but still returns 200 (group is already deleted). Replace Record on deleteAclEntries with proper typed filter. Log warning for unmapped user ObjectIds in createGroup memberIds. Add removeMemberById fallback when removeUserFromGroup throws User not found for ObjectId-format userId. Extract VALID_GROUP_SOURCES constant. Add 3 new tests (60 total). * refactor: add countGroups, pagination, and projection type to data layer Extract buildGroupQuery helper, add countGroups method, support limit/offset/skip in listGroups, standardize session handling to .session(session ?? null), and tighten projection parameter from Record to Record. 
* fix: cascade resilience, pagination, validation, and error clarity for admin groups - Use Promise.allSettled for cascade cleanup so all steps run even if one fails; log individual rejections - Echo deleted group id in delete response - Add countGroups dep and wire limit/offset pagination for listGroups - Deduplicate memberIds before computing total in getGroupMembers - Use { memberIds: 1 } projection in getGroupMembers - Cap memberIds at 500 entries in createGroup - Reject search queries exceeding 200 characters - Clarify addGroupMember error for non-ObjectId userId - Document deleted-user fallback limitation in removeGroupMember * test: extend handler and DB-layer test coverage for admin groups Handler tests: projection assertion, dedup total, memberIds cap, search max length, non-ObjectId memberIds passthrough, cascade partial failure resilience, dedup scenarios, echo id in delete response. DB-layer tests: listGroups sort/filter/pagination, countGroups, deleteGroup, removeMemberById, deleteGrantsForPrincipal. * fix: cast group principalId to ObjectId for ACL entry cleanup deleteAclEntries is a thin deleteMany wrapper with no type casting, but grantPermission stores group principalId as ObjectId. Passing the raw string from req.params would leave orphaned ACL entries on group deletion. * refactor: remove redundant pagination clamping from DB listGroups Handler already clamps limit/offset at the API boundary. The DB method is a general-purpose building block and should not re-validate. 
* fix: add source and name validation, import order, and test coverage for admin groups - Validate source against VALID_GROUP_SOURCES in createGroupHandler - Cap name at 500 characters in both create and update handlers - Document total as upper bound in getGroupMembers response - Document ObjectId requirement for deleteAclEntries in cascade - Fix import ordering in test file (local value after type imports) - Add tests for updateGroup with description, email, avatar fields - Add tests for invalid source and name max-length in both handlers * fix: add field length caps, flatten nested try/catch, and fix logger level in admin groups Add max-length validation for description, email, avatar, and idOnTheSource in create/update handlers. Extract removeObjectIdMember helper to flatten nested try/catch per never-nesting convention. Downgrade unmapped-memberIds log from error to warn. Fix type import ordering and add missing await in removeMemberById for consistency. --- api/server/index.js | 1 + api/server/routes/admin/groups.js | 41 + api/server/routes/index.js | 2 + packages/api/src/admin/groups.spec.ts | 1348 +++++++++++++++++ packages/api/src/admin/groups.ts | 482 ++++++ packages/api/src/admin/index.ts | 2 + .../src/methods/systemGrant.spec.ts | 62 + .../data-schemas/src/methods/systemGrant.ts | 16 + packages/data-schemas/src/methods/user.ts | 14 + .../src/methods/userGroup.methods.spec.ts | 149 ++ .../data-schemas/src/methods/userGroup.ts | 102 +- 11 files changed, 2216 insertions(+), 3 deletions(-) create mode 100644 api/server/routes/admin/groups.js create mode 100644 packages/api/src/admin/groups.spec.ts create mode 100644 packages/api/src/admin/groups.ts diff --git a/api/server/index.js b/api/server/index.js index de99f06701..4ecc966476 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -144,6 +144,7 @@ const startServer = async () => { app.use('/api/auth', routes.auth); app.use('/api/admin', routes.adminAuth); app.use('/api/admin/config', 
routes.adminConfig); + app.use('/api/admin/groups', routes.adminGroups); app.use('/api/actions', routes.actions); app.use('/api/keys', routes.keys); app.use('/api/api-keys', routes.apiKeys); diff --git a/api/server/routes/admin/groups.js b/api/server/routes/admin/groups.js new file mode 100644 index 0000000000..7ca93acaa2 --- /dev/null +++ b/api/server/routes/admin/groups.js @@ -0,0 +1,41 @@ +const express = require('express'); +const { createAdminGroupsHandlers } = require('@librechat/api'); +const { SystemCapabilities } = require('@librechat/data-schemas'); +const { requireCapability } = require('~/server/middleware/roles/capabilities'); +const { requireJwtAuth } = require('~/server/middleware'); +const db = require('~/models'); + +const router = express.Router(); + +const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN); +const requireReadGroups = requireCapability(SystemCapabilities.READ_GROUPS); +const requireManageGroups = requireCapability(SystemCapabilities.MANAGE_GROUPS); + +const handlers = createAdminGroupsHandlers({ + listGroups: db.listGroups, + countGroups: db.countGroups, + findGroupById: db.findGroupById, + createGroup: db.createGroup, + updateGroupById: db.updateGroupById, + deleteGroup: db.deleteGroup, + addUserToGroup: db.addUserToGroup, + removeUserFromGroup: db.removeUserFromGroup, + removeMemberById: db.removeMemberById, + findUsers: db.findUsers, + deleteConfig: db.deleteConfig, + deleteAclEntries: db.deleteAclEntries, + deleteGrantsForPrincipal: db.deleteGrantsForPrincipal, +}); + +router.use(requireJwtAuth, requireAdminAccess); + +router.get('/', requireReadGroups, handlers.listGroups); +router.post('/', requireManageGroups, handlers.createGroup); +router.get('/:id', requireReadGroups, handlers.getGroup); +router.patch('/:id', requireManageGroups, handlers.updateGroup); +router.delete('/:id', requireManageGroups, handlers.deleteGroup); +router.get('/:id/members', requireReadGroups, handlers.getGroupMembers); 
+router.post('/:id/members', requireManageGroups, handlers.addGroupMember); +router.delete('/:id/members/:userId', requireManageGroups, handlers.removeGroupMember); + +module.exports = router; diff --git a/api/server/routes/index.js b/api/server/routes/index.js index b1f16d5e3c..f9a088649c 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -3,6 +3,7 @@ const assistants = require('./assistants'); const categories = require('./categories'); const adminAuth = require('./admin/auth'); const adminConfig = require('./admin/config'); +const adminGroups = require('./admin/groups'); const endpoints = require('./endpoints'); const staticRoute = require('./static'); const messages = require('./messages'); @@ -33,6 +34,7 @@ module.exports = { auth, adminAuth, adminConfig, + adminGroups, keys, apiKeys, user, diff --git a/packages/api/src/admin/groups.spec.ts b/packages/api/src/admin/groups.spec.ts new file mode 100644 index 0000000000..42e32152d9 --- /dev/null +++ b/packages/api/src/admin/groups.spec.ts @@ -0,0 +1,1348 @@ +import { Types } from 'mongoose'; +import { PrincipalType } from 'librechat-data-provider'; +import type { IGroup, IUser } from '@librechat/data-schemas'; +import type { Response } from 'express'; +import type { ServerRequest } from '~/types/http'; +import type { AdminGroupsDeps } from './groups'; +import { createAdminGroupsHandlers } from './groups'; + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { error: jest.fn(), warn: jest.fn() }, +})); + +describe('createAdminGroupsHandlers', () => { + let validId: string; + let validUserId: string; + + beforeEach(() => { + validId = new Types.ObjectId().toString(); + validUserId = new Types.ObjectId().toString(); + }); + + function mockGroup(overrides: Partial = {}): IGroup { + return { + _id: new Types.ObjectId(validId), + name: 'Test Group', + source: 'local', + memberIds: [], + createdAt: new Date(), + updatedAt: new Date(), + 
...overrides, + } as IGroup; + } + + function mockUser(overrides: Partial = {}): IUser { + return { + _id: new Types.ObjectId(validUserId), + name: 'Test User', + email: 'test@example.com', + avatar: 'https://example.com/avatar.png', + ...overrides, + } as IUser; + } + + function createReqRes( + overrides: { + params?: Record; + query?: Record; + body?: Record; + } = {}, + ) { + const req = { + params: overrides.params ?? {}, + query: overrides.query ?? {}, + body: overrides.body ?? {}, + } as unknown as ServerRequest; + + const json = jest.fn(); + const status = jest.fn().mockReturnValue({ json }); + const res = { status, json } as unknown as Response; + + return { req, res, status, json }; + } + + function createDeps(overrides: Partial = {}): AdminGroupsDeps { + return { + listGroups: jest.fn().mockResolvedValue([]), + countGroups: jest.fn().mockResolvedValue(0), + findGroupById: jest.fn().mockResolvedValue(null), + createGroup: jest.fn().mockResolvedValue(mockGroup()), + updateGroupById: jest.fn().mockResolvedValue(mockGroup()), + deleteGroup: jest.fn().mockResolvedValue(mockGroup()), + addUserToGroup: jest.fn().mockResolvedValue({ user: mockUser(), group: mockGroup() }), + removeUserFromGroup: jest.fn().mockResolvedValue({ user: mockUser(), group: mockGroup() }), + removeMemberById: jest.fn().mockResolvedValue(mockGroup()), + findUsers: jest.fn().mockResolvedValue([]), + deleteConfig: jest.fn().mockResolvedValue(null), + deleteAclEntries: jest.fn().mockResolvedValue({ deletedCount: 0 }), + deleteGrantsForPrincipal: jest.fn().mockResolvedValue(undefined), + ...overrides, + }; + } + + describe('listGroups', () => { + it('returns groups with total, limit, offset', async () => { + const groups = [mockGroup()]; + const deps = createDeps({ + listGroups: jest.fn().mockResolvedValue(groups), + countGroups: jest.fn().mockResolvedValue(1), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ query: {} }); + + await 
handlers.listGroups(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ groups, total: 1, limit: 50, offset: 0 }); + }); + + it('passes source and search filters with pagination', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res } = createReqRes({ + query: { source: 'entra', search: 'engineering', limit: '20', offset: '10' }, + }); + + await handlers.listGroups(req, res); + + expect(deps.listGroups).toHaveBeenCalledWith({ + source: 'entra', + search: 'engineering', + limit: 20, + offset: 10, + }); + expect(deps.countGroups).toHaveBeenCalledWith({ + source: 'entra', + search: 'engineering', + }); + }); + + it('passes search filter alone', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res } = createReqRes({ query: { search: 'eng' } }); + + await handlers.listGroups(req, res); + + expect(deps.listGroups).toHaveBeenCalledWith({ search: 'eng', limit: 50, offset: 0 }); + expect(deps.countGroups).toHaveBeenCalledWith({ search: 'eng' }); + }); + + it('ignores invalid source values', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res } = createReqRes({ query: { source: 'invalid' } }); + + await handlers.listGroups(req, res); + + expect(deps.listGroups).toHaveBeenCalledWith({ limit: 50, offset: 0 }); + expect(deps.countGroups).toHaveBeenCalledWith({}); + }); + + it('clamps limit and offset', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res } = createReqRes({ query: { limit: '999', offset: '-5' } }); + + await handlers.listGroups(req, res); + + expect(deps.listGroups).toHaveBeenCalledWith({ limit: 200, offset: 0 }); + }); + + it('returns 400 when search exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, 
res, status, json } = createReqRes({ + query: { search: 'a'.repeat(201) }, + }); + + await handlers.listGroups(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'search must not exceed 200 characters' }); + expect(deps.listGroups).not.toHaveBeenCalled(); + }); + + it('returns 500 when countGroups fails', async () => { + const deps = createDeps({ + countGroups: jest.fn().mockRejectedValue(new Error('count failed')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes(); + + await handlers.listGroups(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to list groups' }); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ listGroups: jest.fn().mockRejectedValue(new Error('db down')) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes(); + + await handlers.listGroups(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to list groups' }); + }); + }); + + describe('getGroup', () => { + it('returns group with 200', async () => { + const group = mockGroup(); + const deps = createDeps({ findGroupById: jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroup(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ group }); + }); + + it('returns 400 for invalid ID', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: 'not-an-id' } }); + + await handlers.getGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Invalid group ID format' }); 
+ expect(deps.findGroupById).not.toHaveBeenCalled(); + }); + + it('returns 404 when group not found', async () => { + const deps = createDeps({ findGroupById: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroup(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Group not found' }); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ + findGroupById: jest.fn().mockRejectedValue(new Error('db down')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroup(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to get group' }); + }); + }); + + describe('createGroup', () => { + it('creates group and returns 201', async () => { + const group = mockGroup(); + const deps = createDeps({ createGroup: jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'New Group', description: 'A group' }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(201); + expect(json).toHaveBeenCalledWith({ group }); + expect(deps.createGroup).toHaveBeenCalledWith( + expect.objectContaining({ + name: 'New Group', + description: 'A group', + source: 'local', + memberIds: [], + }), + ); + }); + + it('normalizes memberIds to idOnTheSource values', async () => { + const userId = new Types.ObjectId().toString(); + const user = { _id: new Types.ObjectId(userId), idOnTheSource: 'ext-norm-1' } as IUser; + const group = mockGroup(); + const deps = createDeps({ + createGroup: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([user]), + }); + const handlers 
= createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + body: { name: 'With Members', memberIds: [userId] }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(201); + expect(deps.findUsers).toHaveBeenCalledWith({ _id: { $in: [userId] } }, 'idOnTheSource'); + expect(deps.createGroup).toHaveBeenCalledWith( + expect.objectContaining({ memberIds: ['ext-norm-1'] }), + ); + }); + + it('logs warning when memberIds contain non-existent user ObjectIds', async () => { + const { logger } = jest.requireMock('@librechat/data-schemas'); + const unknownId = new Types.ObjectId().toString(); + const group = mockGroup(); + const deps = createDeps({ + createGroup: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + body: { name: 'With Unknown', memberIds: [unknownId] }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(201); + expect(logger.warn).toHaveBeenCalledWith( + '[adminGroups] createGroup: memberIds contain unknown user ObjectIds:', + [unknownId], + ); + }); + + it('passes idOnTheSource when provided', async () => { + const group = mockGroup(); + const deps = createDeps({ createGroup: jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + body: { name: 'Entra Group', source: 'entra', idOnTheSource: 'ent-abc-123' }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(201); + expect(deps.createGroup).toHaveBeenCalledWith( + expect.objectContaining({ idOnTheSource: 'ent-abc-123', source: 'entra' }), + ); + }); + + it('returns 400 for invalid source value', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'Bad 
Source', source: 'azure' }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Invalid source value' }); + expect(deps.createGroup).not.toHaveBeenCalled(); + }); + + it('returns 400 when name exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'a'.repeat(501) }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name must not exceed 500 characters' }); + expect(deps.createGroup).not.toHaveBeenCalled(); + }); + + it('returns 400 when description exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'Valid', description: 'x'.repeat(2001) }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ + error: 'description must not exceed 2000 characters', + }); + expect(deps.createGroup).not.toHaveBeenCalled(); + }); + + it('returns 400 when email exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'Valid', email: 'x'.repeat(501) }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'email must not exceed 500 characters' }); + expect(deps.createGroup).not.toHaveBeenCalled(); + }); + + it('returns 400 when avatar exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'Valid', avatar: 'x'.repeat(2001) }, + }); + + await 
handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'avatar must not exceed 2000 characters' }); + expect(deps.createGroup).not.toHaveBeenCalled(); + }); + + it('returns 400 when idOnTheSource exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'Valid', idOnTheSource: 'x'.repeat(501) }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ + error: 'idOnTheSource must not exceed 500 characters', + }); + expect(deps.createGroup).not.toHaveBeenCalled(); + }); + + it('returns 400 when memberIds exceeds cap', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const memberIds = Array.from({ length: 501 }, (_, i) => `ext-${i}`); + const { req, res, status, json } = createReqRes({ + body: { name: 'Too Many Members', memberIds }, + }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'memberIds must not exceed 500 entries' }); + expect(deps.createGroup).not.toHaveBeenCalled(); + }); + + it('passes non-ObjectId memberIds through unchanged', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + body: { name: 'Ext Group', memberIds: ['ext-1', 'ext-2'] }, + }); + + await handlers.createGroup(req, res); + + expect(deps.findUsers).not.toHaveBeenCalled(); + expect(deps.createGroup).toHaveBeenCalledWith( + expect.objectContaining({ memberIds: ['ext-1', 'ext-2'] }), + ); + expect(status).toHaveBeenCalledWith(201); + }); + + it('returns 400 when name is missing', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = 
createReqRes({ body: {} }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name is required' }); + expect(deps.createGroup).not.toHaveBeenCalled(); + }); + + it('returns 400 when name is whitespace-only', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ body: { name: ' ' } }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name is required' }); + }); + + it('returns 400 on ValidationError', async () => { + const validationError = new Error('source must be local or entra'); + validationError.name = 'ValidationError'; + const deps = createDeps({ createGroup: jest.fn().mockRejectedValue(validationError) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ body: { name: 'Test' } }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'source must be local or entra' }); + }); + + it('returns 500 on unexpected error', async () => { + const deps = createDeps({ createGroup: jest.fn().mockRejectedValue(new Error('db crash')) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ body: { name: 'Test' } }); + + await handlers.createGroup(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to create group' }); + }); + }); + + describe('updateGroup', () => { + it('updates group and returns 200', async () => { + const group = mockGroup({ name: 'Updated' }); + const deps = createDeps({ + updateGroupById: jest.fn().mockResolvedValue(group), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId 
}, + body: { name: 'Updated' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ group }); + }); + + it('updates description only', async () => { + const group = mockGroup({ description: 'New desc' }); + const deps = createDeps({ updateGroupById: jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + params: { id: validId }, + body: { description: 'New desc' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(deps.updateGroupById).toHaveBeenCalledWith(validId, { description: 'New desc' }); + }); + + it('updates email only', async () => { + const group = mockGroup({ email: 'team@co.com' }); + const deps = createDeps({ updateGroupById: jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + params: { id: validId }, + body: { email: 'team@co.com' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(deps.updateGroupById).toHaveBeenCalledWith(validId, { email: 'team@co.com' }); + }); + + it('updates avatar only', async () => { + const group = mockGroup({ avatar: 'https://img.co/a.png' }); + const deps = createDeps({ updateGroupById: jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + params: { id: validId }, + body: { avatar: 'https://img.co/a.png' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(deps.updateGroupById).toHaveBeenCalledWith(validId, { + avatar: 'https://img.co/a.png', + }); + }); + + it('updates multiple fields at once', async () => { + const group = mockGroup({ name: 'New', description: 'Desc', email: 'a@b.com' }); + const deps = createDeps({ updateGroupById: 
jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + params: { id: validId }, + body: { name: ' New ', description: 'Desc', email: 'a@b.com' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(deps.updateGroupById).toHaveBeenCalledWith(validId, { + name: 'New', + description: 'Desc', + email: 'a@b.com', + }); + }); + + it('returns 400 for invalid ID', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: 'bad' }, + body: { name: 'Updated' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Invalid group ID format' }); + }); + + it('returns 400 when name is empty string', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { name: '' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name must be a non-empty string' }); + }); + + it('returns 400 when name is whitespace-only', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { name: ' ' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name must be a non-empty string' }); + }); + + it('returns 400 when name exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { name: 'a'.repeat(501) }, + }); + + 
await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name must not exceed 500 characters' }); + expect(deps.updateGroupById).not.toHaveBeenCalled(); + }); + + it('returns 400 when description exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { description: 'x'.repeat(2001) }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ + error: 'description must not exceed 2000 characters', + }); + expect(deps.updateGroupById).not.toHaveBeenCalled(); + }); + + it('returns 400 when email exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { email: 'x'.repeat(501) }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'email must not exceed 500 characters' }); + expect(deps.updateGroupById).not.toHaveBeenCalled(); + }); + + it('returns 400 when avatar exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { avatar: 'x'.repeat(2001) }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'avatar must not exceed 2000 characters' }); + expect(deps.updateGroupById).not.toHaveBeenCalled(); + }); + + it('returns 400 when no valid fields provided', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId 
}, + body: {}, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'No valid fields to update' }); + expect(deps.updateGroupById).not.toHaveBeenCalled(); + }); + + it('returns 404 when updateGroupById returns null', async () => { + const deps = createDeps({ + updateGroupById: jest.fn().mockResolvedValue(null), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { name: 'Updated' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Group not found' }); + }); + + it('returns 400 on ValidationError', async () => { + const validationError = new Error('invalid field'); + validationError.name = 'ValidationError'; + const deps = createDeps({ + updateGroupById: jest.fn().mockRejectedValue(validationError), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { name: 'Updated' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'invalid field' }); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ + updateGroupById: jest.fn().mockRejectedValue(new Error('db down')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { name: 'Updated' }, + }); + + await handlers.updateGroup(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to update group' }); + }); + }); + + describe('deleteGroup', () => { + it('deletes group and returns 200 with id', async () => { + const deps = createDeps({ deleteGroup: jest.fn().mockResolvedValue(mockGroup()) }); + const handlers = 
createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.deleteGroup(req, res); + + expect(deps.deleteGroup).toHaveBeenCalledWith(validId); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true, id: validId }); + }); + + it('returns 400 for invalid ID', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: 'bad-id' } }); + + await handlers.deleteGroup(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Invalid group ID format' }); + }); + + it('returns 404 when deleteGroup returns null', async () => { + const deps = createDeps({ deleteGroup: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.deleteGroup(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Group not found' }); + expect(deps.deleteConfig).not.toHaveBeenCalled(); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ + deleteGroup: jest.fn().mockRejectedValue(new Error('db down')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.deleteGroup(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to delete group' }); + }); + + it('returns 200 even when cascade cleanup partially fails', async () => { + const deps = createDeps({ + deleteGroup: jest.fn().mockResolvedValue(mockGroup()), + deleteAclEntries: jest.fn().mockRejectedValue(new Error('cleanup failed')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = 
createReqRes({ params: { id: validId } }); + + await handlers.deleteGroup(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true, id: validId }); + expect(deps.deleteConfig).toHaveBeenCalledWith(PrincipalType.GROUP, validId); + expect(deps.deleteAclEntries).toHaveBeenCalledWith({ + principalType: PrincipalType.GROUP, + principalId: new Types.ObjectId(validId), + }); + expect(deps.deleteGrantsForPrincipal).toHaveBeenCalledWith(PrincipalType.GROUP, validId); + }); + + it('cleans up Config, AclEntry, and SystemGrant on group delete', async () => { + const deps = createDeps({ deleteGroup: jest.fn().mockResolvedValue(mockGroup()) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ params: { id: validId } }); + + await handlers.deleteGroup(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(deps.deleteConfig).toHaveBeenCalledWith(PrincipalType.GROUP, validId); + expect(deps.deleteAclEntries).toHaveBeenCalledWith({ + principalType: PrincipalType.GROUP, + principalId: new Types.ObjectId(validId), + }); + expect(deps.deleteGrantsForPrincipal).toHaveBeenCalledWith(PrincipalType.GROUP, validId); + }); + }); + + describe('getGroupMembers', () => { + it('fetches group with memberIds projection only', async () => { + const group = mockGroup({ memberIds: [] }); + const deps = createDeps({ findGroupById: jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + expect(deps.findGroupById).toHaveBeenCalledWith(validId, { memberIds: 1 }); + }); + + it('returns empty members for group with no memberIds', async () => { + const group = mockGroup({ memberIds: [] }); + const deps = createDeps({ findGroupById: jest.fn().mockResolvedValue(group) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, 
json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ members: [], total: 0, limit: 50, offset: 0 }); + expect(deps.findUsers).not.toHaveBeenCalled(); + }); + + it('batches member lookup with $or query', async () => { + const user = mockUser({ idOnTheSource: 'ext-123' }); + const group = mockGroup({ memberIds: [validUserId, 'ext-123'] }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([user]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + expect(deps.findUsers).toHaveBeenCalledWith( + { + $or: [ + { idOnTheSource: { $in: [validUserId, 'ext-123'] } }, + { _id: { $in: [validUserId] } }, + ], + }, + 'name email avatar idOnTheSource', + ); + expect(status).toHaveBeenCalledWith(200); + const members = json.mock.calls[0][0].members; + expect(members).toHaveLength(1); + }); + + it('skips _id condition when no valid ObjectIds in memberIds', async () => { + const group = mockGroup({ memberIds: ['ext-1', 'ext-2'] }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + expect(deps.findUsers).toHaveBeenCalledWith( + { $or: [{ idOnTheSource: { $in: ['ext-1', 'ext-2'] } }] }, + 'name email avatar idOnTheSource', + ); + }); + + it('falls back to memberId when user not found', async () => { + const group = mockGroup({ memberIds: ['unknown-member'] }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([]), + }); + const handlers 
= createAdminGroupsHandlers(deps); + const { req, res, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + expect(json.mock.calls[0][0].members).toEqual([ + { userId: 'unknown-member', name: 'unknown-member', email: '', avatarUrl: undefined }, + ]); + }); + + it('deduplicates when identical memberId appears twice', async () => { + const user = mockUser({ idOnTheSource: validUserId }); + const group = mockGroup({ memberIds: [validUserId, validUserId] }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([user]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + const result = json.mock.calls[0][0]; + expect(result.members).toHaveLength(1); + expect(result.total).toBe(1); + }); + + it('deduplicates when objectId and idOnTheSource both present for same user', async () => { + const extId = 'ext-dedup-123'; + const user = mockUser({ idOnTheSource: extId }); + const group = mockGroup({ memberIds: [validUserId, extId] }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([user]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + expect(json.mock.calls[0][0].members).toHaveLength(1); + }); + + it('reports deduplicated total for duplicate memberIds', async () => { + const group = mockGroup({ memberIds: ['m1', 'm2', 'm1', 'm3', 'm2'] }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, 
res); + + const result = json.mock.calls[0][0]; + expect(result.total).toBe(3); + expect(result.members).toHaveLength(3); + }); + + it('paginates members with limit and offset', async () => { + const ids = ['m1', 'm2', 'm3', 'm4', 'm5']; + const group = mockGroup({ memberIds: ids }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, json } = createReqRes({ + params: { id: validId }, + query: { limit: '2', offset: '1' }, + }); + + await handlers.getGroupMembers(req, res); + + const result = json.mock.calls[0][0]; + expect(result.total).toBe(5); + expect(result.limit).toBe(2); + expect(result.offset).toBe(1); + expect(result.members).toHaveLength(2); + expect(result.members[0].userId).toBe('m2'); + expect(result.members[1].userId).toBe('m3'); + }); + + it('caps limit at 200', async () => { + const ids = Array.from({ length: 5 }, (_, i) => `m${i}`); + const group = mockGroup({ memberIds: ids }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + findUsers: jest.fn().mockResolvedValue([]), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, json } = createReqRes({ + params: { id: validId }, + query: { limit: '999' }, + }); + + await handlers.getGroupMembers(req, res); + + const result = json.mock.calls[0][0]; + expect(result.limit).toBe(200); + }); + + it('returns empty when offset exceeds total', async () => { + const group = mockGroup({ memberIds: ['m1', 'm2'] }); + const deps = createDeps({ + findGroupById: jest.fn().mockResolvedValue(group), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, json } = createReqRes({ + params: { id: validId }, + query: { offset: '10' }, + }); + + await handlers.getGroupMembers(req, res); + + const result = json.mock.calls[0][0]; + expect(result.members).toHaveLength(0); + 
expect(result.total).toBe(2); + expect(deps.findUsers).not.toHaveBeenCalled(); + }); + + it('returns 400 for invalid group ID', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: 'nope' } }); + + await handlers.getGroupMembers(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Invalid group ID format' }); + }); + + it('returns 404 when group not found', async () => { + const deps = createDeps({ findGroupById: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Group not found' }); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ + findGroupById: jest.fn().mockRejectedValue(new Error('db down')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { id: validId } }); + + await handlers.getGroupMembers(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to get group members' }); + }); + }); + + describe('addGroupMember', () => { + it('adds member and returns 200', async () => { + const group = mockGroup(); + const deps = createDeps({ + addUserToGroup: jest.fn().mockResolvedValue({ user: mockUser(), group }), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { userId: validUserId }, + }); + + await handlers.addGroupMember(req, res); + + expect(deps.addUserToGroup).toHaveBeenCalledWith(validUserId, validId); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ group }); + }); + + 
it('returns 400 for invalid group ID', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: 'bad' }, + body: { userId: validUserId }, + }); + + await handlers.addGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Invalid group ID format' }); + }); + + it('returns 400 when userId is missing', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: {}, + }); + + await handlers.addGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'userId is required' }); + }); + + it('returns 400 for non-ObjectId userId', async () => { + const deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { userId: 'not-valid' }, + }); + + await handlers.addGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ + error: 'Only native user ObjectIds can be added via this endpoint', + }); + }); + + it('returns 404 when addUserToGroup returns null group', async () => { + const deps = createDeps({ + addUserToGroup: jest.fn().mockResolvedValue({ user: mockUser(), group: null }), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { userId: validUserId }, + }); + + await handlers.addGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Group not found' }); + }); + + it('returns 404 for "User not found" error', async () => { + const deps = createDeps({ + addUserToGroup: jest.fn().mockRejectedValue(new Error('User not found')), + }); + 
const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { userId: validUserId }, + }); + + await handlers.addGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'User not found' }); + }); + + it('returns 500 for unrelated errors', async () => { + const deps = createDeps({ + addUserToGroup: jest.fn().mockRejectedValue(new Error('connection lost')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId }, + body: { userId: validUserId }, + }); + + await handlers.addGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to add member' }); + }); + + it('does not misclassify errors containing "not found" substring', async () => { + const deps = createDeps({ + addUserToGroup: jest.fn().mockRejectedValue(new Error('Permission not found in config')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status } = createReqRes({ + params: { id: validId }, + body: { userId: validUserId }, + }); + + await handlers.addGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(500); + }); + }); + + describe('removeGroupMember', () => { + it('removes member and returns 200', async () => { + const deps = createDeps({ + removeUserFromGroup: jest.fn().mockResolvedValue({ user: mockUser(), group: mockGroup() }), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: validUserId }, + }); + + await handlers.removeGroupMember(req, res); + + expect(deps.removeUserFromGroup).toHaveBeenCalledWith(validUserId, validId); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + + it('returns 400 for invalid group ID', async () => { + const 
deps = createDeps(); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: 'bad', userId: validUserId }, + }); + + await handlers.removeGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Invalid group ID format' }); + }); + + it('removes non-ObjectId member via removeMemberById', async () => { + const deps = createDeps({ + removeMemberById: jest.fn().mockResolvedValue(mockGroup()), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: 'ent-abc-123' }, + }); + + await handlers.removeGroupMember(req, res); + + expect(deps.removeMemberById).toHaveBeenCalledWith(validId, 'ent-abc-123'); + expect(deps.removeUserFromGroup).not.toHaveBeenCalled(); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + + it('returns 404 when removeMemberById returns null', async () => { + const deps = createDeps({ + removeMemberById: jest.fn().mockResolvedValue(null), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: 'ent-abc-123' }, + }); + + await handlers.removeGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Group not found' }); + }); + + it('falls back to removeMemberById when ObjectId userId not found as user', async () => { + const deps = createDeps({ + removeUserFromGroup: jest.fn().mockRejectedValue(new Error('User not found')), + removeMemberById: jest.fn().mockResolvedValue(mockGroup()), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: validUserId }, + }); + + await handlers.removeGroupMember(req, res); + + 
expect(deps.removeUserFromGroup).toHaveBeenCalledWith(validUserId, validId); + expect(deps.removeMemberById).toHaveBeenCalledWith(validId, validUserId); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + + it('returns 404 when removeUserFromGroup returns null group', async () => { + const deps = createDeps({ + removeUserFromGroup: jest.fn().mockResolvedValue({ user: mockUser(), group: null }), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: validUserId }, + }); + + await handlers.removeGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Group not found' }); + }); + + it('returns 404 when fallback removeMemberById also returns null', async () => { + const deps = createDeps({ + removeUserFromGroup: jest.fn().mockRejectedValue(new Error('User not found')), + removeMemberById: jest.fn().mockResolvedValue(null), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: validUserId }, + }); + + await handlers.removeGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Group not found' }); + }); + + it('returns 500 for unrelated errors', async () => { + const deps = createDeps({ + removeUserFromGroup: jest.fn().mockRejectedValue(new Error('timeout')), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: validUserId }, + }); + + await handlers.removeGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to remove member' }); + }); + + it('returns 200 when removing ObjectId member not in group (idempotent delete)', async () => { + const group = mockGroup({ 
memberIds: [] }); + const deps = createDeps({ + removeUserFromGroup: jest.fn().mockResolvedValue({ user: mockUser(), group }), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: validUserId }, + }); + + await handlers.removeGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + + it('returns 200 when removing non-ObjectId member not in group (idempotent delete)', async () => { + const group = mockGroup({ memberIds: [] }); + const deps = createDeps({ + removeMemberById: jest.fn().mockResolvedValue(group), + }); + const handlers = createAdminGroupsHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { id: validId, userId: 'ext-not-in-group' }, + }); + + await handlers.removeGroupMember(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + }); +}); diff --git a/packages/api/src/admin/groups.ts b/packages/api/src/admin/groups.ts new file mode 100644 index 0000000000..58ff4d9782 --- /dev/null +++ b/packages/api/src/admin/groups.ts @@ -0,0 +1,482 @@ +import { Types } from 'mongoose'; +import { PrincipalType } from 'librechat-data-provider'; +import { logger, isValidObjectIdString } from '@librechat/data-schemas'; +import type { + IGroup, + IUser, + IConfig, + CreateGroupRequest, + UpdateGroupRequest, + GroupFilterOptions, +} from '@librechat/data-schemas'; +import type { FilterQuery, ClientSession, DeleteResult } from 'mongoose'; +import type { Response } from 'express'; +import type { ValidationError } from '~/types/error'; +import type { ServerRequest } from '~/types/http'; + +type GroupListFilter = Pick; + +const VALID_GROUP_SOURCES: ReadonlySet = new Set(['local', 'entra']); +const MAX_CREATE_MEMBER_IDS = 500; +const MAX_SEARCH_LENGTH = 200; +const MAX_NAME_LENGTH = 500; +const MAX_DESCRIPTION_LENGTH = 2000; 
+const MAX_EMAIL_LENGTH = 500; +const MAX_AVATAR_LENGTH = 2000; +const MAX_EXTERNAL_ID_LENGTH = 500; + +interface GroupIdParams { + id: string; +} + +interface GroupMemberParams extends GroupIdParams { + userId: string; +} + +export interface AdminGroupsDeps { + listGroups: ( + filter?: GroupListFilter & { limit?: number; offset?: number }, + session?: ClientSession, + ) => Promise; + countGroups: (filter?: GroupListFilter, session?: ClientSession) => Promise; + findGroupById: ( + groupId: string | Types.ObjectId, + projection?: Record, + session?: ClientSession, + ) => Promise; + createGroup: (groupData: Partial, session?: ClientSession) => Promise; + updateGroupById: ( + groupId: string | Types.ObjectId, + data: Partial>, + session?: ClientSession, + ) => Promise; + deleteGroup: ( + groupId: string | Types.ObjectId, + session?: ClientSession, + ) => Promise; + addUserToGroup: ( + userId: string | Types.ObjectId, + groupId: string | Types.ObjectId, + session?: ClientSession, + ) => Promise<{ user: IUser; group: IGroup | null }>; + removeUserFromGroup: ( + userId: string | Types.ObjectId, + groupId: string | Types.ObjectId, + session?: ClientSession, + ) => Promise<{ user: IUser; group: IGroup | null }>; + removeMemberById: ( + groupId: string | Types.ObjectId, + memberId: string, + session?: ClientSession, + ) => Promise; + findUsers: ( + searchCriteria: FilterQuery, + fieldsToSelect?: string | string[] | null, + ) => Promise; + deleteConfig: ( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + ) => Promise; + deleteAclEntries: (filter: { + principalType: PrincipalType; + principalId: string | Types.ObjectId; + }) => Promise; + deleteGrantsForPrincipal: ( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + ) => Promise; +} + +export function createAdminGroupsHandlers(deps: AdminGroupsDeps) { + const { + listGroups, + countGroups, + findGroupById, + createGroup, + updateGroupById, + deleteGroup, + addUserToGroup, + 
removeUserFromGroup, + removeMemberById, + findUsers, + deleteConfig, + deleteAclEntries, + deleteGrantsForPrincipal, + } = deps; + + async function listGroupsHandler(req: ServerRequest, res: Response) { + try { + const { search, source } = req.query as { search?: string; source?: string }; + const filter: GroupListFilter = {}; + if (source && VALID_GROUP_SOURCES.has(source)) { + filter.source = source as IGroup['source']; + } + if (search && search.length > MAX_SEARCH_LENGTH) { + return res + .status(400) + .json({ error: `search must not exceed ${MAX_SEARCH_LENGTH} characters` }); + } + if (search) { + filter.search = search; + } + const limit = Math.min(Math.max(Number(req.query.limit) || 50, 1), 200); + const offset = Math.max(Number(req.query.offset) || 0, 0); + const [groups, total] = await Promise.all([ + listGroups({ ...filter, limit, offset }), + countGroups(filter), + ]); + return res.status(200).json({ groups, total, limit, offset }); + } catch (error) { + logger.error('[adminGroups] listGroups error:', error); + return res.status(500).json({ error: 'Failed to list groups' }); + } + } + + async function getGroupHandler(req: ServerRequest, res: Response) { + try { + const { id } = req.params as GroupIdParams; + if (!isValidObjectIdString(id)) { + return res.status(400).json({ error: 'Invalid group ID format' }); + } + const group = await findGroupById(id); + if (!group) { + return res.status(404).json({ error: 'Group not found' }); + } + return res.status(200).json({ group }); + } catch (error) { + logger.error('[adminGroups] getGroup error:', error); + return res.status(500).json({ error: 'Failed to get group' }); + } + } + + async function createGroupHandler(req: ServerRequest, res: Response) { + try { + const body = req.body as CreateGroupRequest; + if (!body.name || typeof body.name !== 'string' || !body.name.trim()) { + return res.status(400).json({ error: 'name is required' }); + } + if (body.name.trim().length > MAX_NAME_LENGTH) { + return res + 
.status(400) + .json({ error: `name must not exceed ${MAX_NAME_LENGTH} characters` }); + } + if (body.source && !VALID_GROUP_SOURCES.has(body.source)) { + return res.status(400).json({ error: 'Invalid source value' }); + } + if (body.description && body.description.length > MAX_DESCRIPTION_LENGTH) { + return res + .status(400) + .json({ error: `description must not exceed ${MAX_DESCRIPTION_LENGTH} characters` }); + } + if (body.email && body.email.length > MAX_EMAIL_LENGTH) { + return res + .status(400) + .json({ error: `email must not exceed ${MAX_EMAIL_LENGTH} characters` }); + } + if (body.avatar && body.avatar.length > MAX_AVATAR_LENGTH) { + return res + .status(400) + .json({ error: `avatar must not exceed ${MAX_AVATAR_LENGTH} characters` }); + } + if (body.idOnTheSource && body.idOnTheSource.length > MAX_EXTERNAL_ID_LENGTH) { + return res + .status(400) + .json({ error: `idOnTheSource must not exceed ${MAX_EXTERNAL_ID_LENGTH} characters` }); + } + + const rawIds = Array.isArray(body.memberIds) ? body.memberIds : []; + if (rawIds.length > MAX_CREATE_MEMBER_IDS) { + return res + .status(400) + .json({ error: `memberIds must not exceed ${MAX_CREATE_MEMBER_IDS} entries` }); + } + let memberIds = rawIds; + const objectIds = rawIds.filter(isValidObjectIdString); + if (objectIds.length > 0) { + const users = await findUsers({ _id: { $in: objectIds } }, 'idOnTheSource'); + const idMap = new Map(); + for (const user of users) { + const uid = user._id?.toString() ?? 
''; + idMap.set(uid, user.idOnTheSource || uid); + } + const unmapped = objectIds.filter((oid) => !idMap.has(oid)); + if (unmapped.length > 0) { + logger.warn( + '[adminGroups] createGroup: memberIds contain unknown user ObjectIds:', + unmapped, + ); + } + memberIds = rawIds.map((id) => idMap.get(id) || id); + } + + const group = await createGroup({ + name: body.name.trim(), + description: body.description, + email: body.email, + avatar: body.avatar, + source: body.source || 'local', + memberIds, + ...(body.idOnTheSource ? { idOnTheSource: body.idOnTheSource } : {}), + }); + return res.status(201).json({ group }); + } catch (error) { + if ((error as ValidationError).name === 'ValidationError') { + return res.status(400).json({ error: (error as ValidationError).message }); + } + logger.error('[adminGroups] createGroup error:', error); + return res.status(500).json({ error: 'Failed to create group' }); + } + } + + async function updateGroupHandler(req: ServerRequest, res: Response) { + try { + const { id } = req.params as GroupIdParams; + if (!isValidObjectIdString(id)) { + return res.status(400).json({ error: 'Invalid group ID format' }); + } + const body = req.body as UpdateGroupRequest; + + if ( + body.name !== undefined && + (!body.name || typeof body.name !== 'string' || !body.name.trim()) + ) { + return res.status(400).json({ error: 'name must be a non-empty string' }); + } + if (body.name !== undefined && body.name.trim().length > MAX_NAME_LENGTH) { + return res + .status(400) + .json({ error: `name must not exceed ${MAX_NAME_LENGTH} characters` }); + } + if (body.description !== undefined && body.description.length > MAX_DESCRIPTION_LENGTH) { + return res + .status(400) + .json({ error: `description must not exceed ${MAX_DESCRIPTION_LENGTH} characters` }); + } + if (body.email !== undefined && body.email.length > MAX_EMAIL_LENGTH) { + return res + .status(400) + .json({ error: `email must not exceed ${MAX_EMAIL_LENGTH} characters` }); + } + if (body.avatar 
!== undefined && body.avatar.length > MAX_AVATAR_LENGTH) { + return res + .status(400) + .json({ error: `avatar must not exceed ${MAX_AVATAR_LENGTH} characters` }); + } + + const updateData: Partial> = {}; + if (body.name !== undefined) { + updateData.name = body.name.trim(); + } + if (body.description !== undefined) { + updateData.description = body.description; + } + if (body.email !== undefined) { + updateData.email = body.email; + } + if (body.avatar !== undefined) { + updateData.avatar = body.avatar; + } + + if (Object.keys(updateData).length === 0) { + return res.status(400).json({ error: 'No valid fields to update' }); + } + + const group = await updateGroupById(id, updateData); + if (!group) { + return res.status(404).json({ error: 'Group not found' }); + } + return res.status(200).json({ group }); + } catch (error) { + if ((error as ValidationError).name === 'ValidationError') { + return res.status(400).json({ error: (error as ValidationError).message }); + } + logger.error('[adminGroups] updateGroup error:', error); + return res.status(500).json({ error: 'Failed to update group' }); + } + } + + async function deleteGroupHandler(req: ServerRequest, res: Response) { + try { + const { id } = req.params as GroupIdParams; + if (!isValidObjectIdString(id)) { + return res.status(400).json({ error: 'Invalid group ID format' }); + } + const deleted = await deleteGroup(id); + if (!deleted) { + return res.status(404).json({ error: 'Group not found' }); + } + /** + * deleteAclEntries is a raw deleteMany wrapper with no type casting. + * grantPermission stores group principalId as ObjectId, so we must + * cast here. deleteConfig and deleteGrantsForPrincipal normalize internally. 
+ */ + const cleanupResults = await Promise.allSettled([ + deleteConfig(PrincipalType.GROUP, id), + deleteAclEntries({ + principalType: PrincipalType.GROUP, + principalId: new Types.ObjectId(id), + }), + deleteGrantsForPrincipal(PrincipalType.GROUP, id), + ]); + for (const result of cleanupResults) { + if (result.status === 'rejected') { + logger.error('[adminGroups] cascade cleanup step failed for group:', id, result.reason); + } + } + return res.status(200).json({ success: true, id }); + } catch (error) { + logger.error('[adminGroups] deleteGroup error:', error); + return res.status(500).json({ error: 'Failed to delete group' }); + } + } + + async function getGroupMembersHandler(req: ServerRequest, res: Response) { + try { + const { id } = req.params as GroupIdParams; + if (!isValidObjectIdString(id)) { + return res.status(400).json({ error: 'Invalid group ID format' }); + } + const group = await findGroupById(id, { memberIds: 1 }); + if (!group) { + return res.status(404).json({ error: 'Group not found' }); + } + + /** + * `total` counts unique raw memberId strings. After user resolution, two + * distinct strings may map to the same user, so `members.length` can be + * less than the page size. Write paths prevent this for well-formed data. 
+ */ + const allMemberIds = [...new Set(group.memberIds || [])]; + const total = allMemberIds.length; + const limit = Math.min(Math.max(Number(req.query.limit) || 50, 1), 200); + const offset = Math.max(Number(req.query.offset) || 0, 0); + + if (total === 0 || offset >= total) { + return res.status(200).json({ members: [], total, limit, offset }); + } + + const memberIds = allMemberIds.slice(offset, offset + limit); + + const validObjectIds = memberIds.filter(isValidObjectIdString); + const conditions: FilterQuery[] = [{ idOnTheSource: { $in: memberIds } }]; + if (validObjectIds.length > 0) { + conditions.push({ _id: { $in: validObjectIds } }); + } + const users = await findUsers({ $or: conditions }, 'name email avatar idOnTheSource'); + + const userMap = new Map(); + for (const user of users) { + if (user.idOnTheSource) { + userMap.set(user.idOnTheSource, user); + } + if (user._id) { + userMap.set(user._id.toString(), user); + } + } + + const seen = new Set(); + const members: { userId: string; name: string; email: string; avatarUrl?: string }[] = []; + for (const memberId of memberIds) { + const user = userMap.get(memberId); + const userId = user?._id?.toString() ?? memberId; + if (seen.has(userId)) { + continue; + } + seen.add(userId); + members.push({ + userId, + name: user?.name ?? memberId, + email: user?.email ?? 
'', + avatarUrl: user?.avatar, + }); + } + + return res.status(200).json({ members, total, limit, offset }); + } catch (error) { + logger.error('[adminGroups] getGroupMembers error:', error); + return res.status(500).json({ error: 'Failed to get group members' }); + } + } + + async function addGroupMemberHandler(req: ServerRequest, res: Response) { + try { + const { id } = req.params as GroupIdParams; + if (!isValidObjectIdString(id)) { + return res.status(400).json({ error: 'Invalid group ID format' }); + } + const { userId } = req.body as { userId: string }; + if (!userId || typeof userId !== 'string') { + return res.status(400).json({ error: 'userId is required' }); + } + if (!isValidObjectIdString(userId)) { + return res + .status(400) + .json({ error: 'Only native user ObjectIds can be added via this endpoint' }); + } + + const { group } = await addUserToGroup(userId, id); + if (!group) { + return res.status(404).json({ error: 'Group not found' }); + } + return res.status(200).json({ group }); + } catch (error) { + const message = error instanceof Error ? error.message : ''; + const isNotFound = message === 'User not found' || message.startsWith('User not found:'); + if (isNotFound) { + return res.status(404).json({ error: 'User not found' }); + } + logger.error('[adminGroups] addGroupMember error:', error); + return res.status(500).json({ error: 'Failed to add member' }); + } + } + + /** + * Attempt removal of an ObjectId-format member: first via removeUserFromGroup + * (which resolves the user), falling back to a raw $pull if the user record + * no longer exists. Returns null only when the group itself is not found. + */ + async function removeObjectIdMember(groupId: string, userId: string): Promise { + try { + const { group } = await removeUserFromGroup(userId, groupId); + return group; + } catch (err) { + const msg = err instanceof Error ? 
err.message : ''; + if (msg === 'User not found' || msg.startsWith('User not found:')) { + return removeMemberById(groupId, userId); + } + throw err; + } + } + + async function removeGroupMemberHandler(req: ServerRequest, res: Response) { + try { + const { id, userId } = req.params as GroupMemberParams; + if (!isValidObjectIdString(id)) { + return res.status(400).json({ error: 'Invalid group ID format' }); + } + + const group = isValidObjectIdString(userId) + ? await removeObjectIdMember(id, userId) + : await removeMemberById(id, userId); + + if (!group) { + return res.status(404).json({ error: 'Group not found' }); + } + return res.status(200).json({ success: true }); + } catch (error) { + logger.error('[adminGroups] removeGroupMember error:', error); + return res.status(500).json({ error: 'Failed to remove member' }); + } + } + + return { + listGroups: listGroupsHandler, + getGroup: getGroupHandler, + createGroup: createGroupHandler, + updateGroup: updateGroupHandler, + deleteGroup: deleteGroupHandler, + getGroupMembers: getGroupMembersHandler, + addGroupMember: addGroupMemberHandler, + removeGroupMember: removeGroupMemberHandler, + }; +} diff --git a/packages/api/src/admin/index.ts b/packages/api/src/admin/index.ts index bf48ce7345..d833c7e2b0 100644 --- a/packages/api/src/admin/index.ts +++ b/packages/api/src/admin/index.ts @@ -1,2 +1,4 @@ export { createAdminConfigHandlers } from './config'; +export { createAdminGroupsHandlers } from './groups'; export type { AdminConfigDeps } from './config'; +export type { AdminGroupsDeps } from './groups'; diff --git a/packages/data-schemas/src/methods/systemGrant.spec.ts b/packages/data-schemas/src/methods/systemGrant.spec.ts index b17285c761..49b4f7269e 100644 --- a/packages/data-schemas/src/methods/systemGrant.spec.ts +++ b/packages/data-schemas/src/methods/systemGrant.spec.ts @@ -702,6 +702,68 @@ describe('systemGrant methods', () => { }); }); + describe('deleteGrantsForPrincipal', () => { + it('deletes all grants for a 
principal', async () => { + const groupId = new Types.ObjectId(); + + await methods.grantCapability({ + principalType: PrincipalType.GROUP, + principalId: groupId, + capability: SystemCapabilities.READ_USERS, + }); + await methods.grantCapability({ + principalType: PrincipalType.GROUP, + principalId: groupId, + capability: SystemCapabilities.READ_CONFIGS, + }); + + await methods.deleteGrantsForPrincipal(PrincipalType.GROUP, groupId); + + const remaining = await SystemGrant.countDocuments({ + principalType: PrincipalType.GROUP, + principalId: groupId, + }); + expect(remaining).toBe(0); + }); + + it('is a no-op for principal with no grants', async () => { + const groupId = new Types.ObjectId(); + + await expect( + methods.deleteGrantsForPrincipal(PrincipalType.GROUP, groupId), + ).resolves.not.toThrow(); + }); + + it('does not affect other principals', async () => { + const groupA = new Types.ObjectId(); + const groupB = new Types.ObjectId(); + + await methods.grantCapability({ + principalType: PrincipalType.GROUP, + principalId: groupA, + capability: SystemCapabilities.READ_USERS, + }); + await methods.grantCapability({ + principalType: PrincipalType.GROUP, + principalId: groupB, + capability: SystemCapabilities.READ_USERS, + }); + + await methods.deleteGrantsForPrincipal(PrincipalType.GROUP, groupA); + + const remainingA = await SystemGrant.countDocuments({ + principalType: PrincipalType.GROUP, + principalId: groupA, + }); + const remainingB = await SystemGrant.countDocuments({ + principalType: PrincipalType.GROUP, + principalId: groupB, + }); + expect(remainingA).toBe(0); + expect(remainingB).toBe(1); + }); + }); + describe('schema validation', () => { it('rejects null tenantId at the schema level', async () => { await expect( diff --git a/packages/data-schemas/src/methods/systemGrant.ts b/packages/data-schemas/src/methods/systemGrant.ts index 6071dd38c5..4954f50c16 100644 --- a/packages/data-schemas/src/methods/systemGrant.ts +++ 
b/packages/data-schemas/src/methods/systemGrant.ts @@ -246,12 +246,28 @@ export function createSystemGrantMethods(mongoose: typeof import('mongoose')) { } } + /** + * Delete all system grants for a principal. + * Used for cascade cleanup when a principal (group, role) is deleted. + */ + async function deleteGrantsForPrincipal( + principalType: PrincipalType, + principalId: string | Types.ObjectId, + session?: ClientSession, + ): Promise { + const SystemGrant = mongoose.models.SystemGrant as Model; + const normalizedPrincipalId = normalizePrincipalId(principalId, principalType); + const options = session ? { session } : {}; + await SystemGrant.deleteMany({ principalType, principalId: normalizedPrincipalId }, options); + } + return { grantCapability, seedSystemGrants, revokeCapability, hasCapabilityForPrincipals, getCapabilitiesForPrincipal, + deleteGrantsForPrincipal, }; } diff --git a/packages/data-schemas/src/methods/user.ts b/packages/data-schemas/src/methods/user.ts index 74cb4a1e1c..137c01d0cd 100644 --- a/packages/data-schemas/src/methods/user.ts +++ b/packages/data-schemas/src/methods/user.ts @@ -44,6 +44,19 @@ export function createUserMethods(mongoose: typeof import('mongoose')) { return (await query.lean()) as IUser | null; } + async function findUsers( + searchCriteria: FilterQuery, + fieldsToSelect?: string | string[] | null, + ): Promise { + const User = mongoose.models.User; + const normalizedCriteria = normalizeEmailInCriteria(searchCriteria); + const query = User.find(normalizedCriteria); + if (fieldsToSelect) { + query.select(fieldsToSelect); + } + return (await query.lean()) as IUser[]; + } + /** * Count the number of user documents in the collection based on the provided filter. 
*/ @@ -323,6 +336,7 @@ export function createUserMethods(mongoose: typeof import('mongoose')) { return { findUser, + findUsers, countUsers, createUser, updateUser, diff --git a/packages/data-schemas/src/methods/userGroup.methods.spec.ts b/packages/data-schemas/src/methods/userGroup.methods.spec.ts index 8a31544018..51848de091 100644 --- a/packages/data-schemas/src/methods/userGroup.methods.spec.ts +++ b/packages/data-schemas/src/methods/userGroup.methods.spec.ts @@ -600,6 +600,155 @@ describe('UserGroup Methods - Detailed Tests', () => { }); }); + describe('listGroups', () => { + beforeEach(async () => { + await Group.create([ + { name: 'Beta', source: 'local', memberIds: [], email: 'beta@test.com' }, + { name: 'Alpha', source: 'local', memberIds: [], description: 'first group' }, + { name: 'Gamma', source: 'entra', idOnTheSource: 'ext-g', memberIds: [] }, + ]); + }); + + test('returns groups sorted by name', async () => { + const groups = await methods.listGroups(); + + expect(groups).toHaveLength(3); + expect(groups[0].name).toBe('Alpha'); + expect(groups[1].name).toBe('Beta'); + expect(groups[2].name).toBe('Gamma'); + }); + + test('filters by source', async () => { + const groups = await methods.listGroups({ source: 'entra' }); + + expect(groups).toHaveLength(1); + expect(groups[0].name).toBe('Gamma'); + }); + + test('filters by search (name)', async () => { + const groups = await methods.listGroups({ search: 'alpha' }); + + expect(groups).toHaveLength(1); + expect(groups[0].name).toBe('Alpha'); + }); + + test('filters by search (email)', async () => { + const groups = await methods.listGroups({ search: 'beta@test' }); + + expect(groups).toHaveLength(1); + expect(groups[0].name).toBe('Beta'); + }); + + test('filters by search (description)', async () => { + const groups = await methods.listGroups({ search: 'first group' }); + + expect(groups).toHaveLength(1); + expect(groups[0].name).toBe('Alpha'); + }); + + test('respects limit and offset', async () => { + 
const groups = await methods.listGroups({ limit: 1, offset: 1 }); + + expect(groups).toHaveLength(1); + expect(groups[0].name).toBe('Beta'); + }); + + test('returns empty for no matches', async () => { + const groups = await methods.listGroups({ search: 'nonexistent' }); + + expect(groups).toHaveLength(0); + }); + }); + + describe('countGroups', () => { + beforeEach(async () => { + await Group.create([ + { name: 'A', source: 'local', memberIds: [] }, + { name: 'B', source: 'local', memberIds: [] }, + { name: 'C', source: 'entra', idOnTheSource: 'ext-c', memberIds: [] }, + ]); + }); + + test('returns total count', async () => { + const count = await methods.countGroups(); + + expect(count).toBe(3); + }); + + test('respects source filter', async () => { + const count = await methods.countGroups({ source: 'local' }); + + expect(count).toBe(2); + }); + + test('respects search filter', async () => { + const count = await methods.countGroups({ search: 'A' }); + + expect(count).toBe(1); + }); + }); + + describe('deleteGroup', () => { + test('returns deleted group', async () => { + const group = await Group.create({ name: 'ToDelete', source: 'local', memberIds: [] }); + + const deleted = await methods.deleteGroup(group._id as mongoose.Types.ObjectId); + + expect(deleted).toBeDefined(); + expect(deleted?.name).toBe('ToDelete'); + const remaining = await Group.findById(group._id); + expect(remaining).toBeNull(); + }); + + test('returns null for non-existent ID', async () => { + const fakeId = new mongoose.Types.ObjectId(); + const result = await methods.deleteGroup(fakeId); + + expect(result).toBeNull(); + }); + }); + + describe('removeMemberById', () => { + test('removes member from memberIds array', async () => { + const group = await Group.create({ + name: 'Test', + source: 'local', + memberIds: ['m1', 'm2', 'm3'], + }); + + const updated = await methods.removeMemberById( + group._id as mongoose.Types.ObjectId, + 'm2', + ); + + expect(updated).toBeDefined(); + 
expect(updated?.memberIds).toEqual(['m1', 'm3']); + }); + + test('is idempotent when memberId not present', async () => { + const group = await Group.create({ + name: 'Test', + source: 'local', + memberIds: ['m1'], + }); + + const updated = await methods.removeMemberById( + group._id as mongoose.Types.ObjectId, + 'nonexistent', + ); + + expect(updated).toBeDefined(); + expect(updated?.memberIds).toEqual(['m1']); + }); + + test('returns null for non-existent group', async () => { + const fakeId = new mongoose.Types.ObjectId(); + const result = await methods.removeMemberById(fakeId, 'any-id'); + + expect(result).toBeNull(); + }); + }); + describe('sortPrincipalsByRelevance', () => { test('should sort principals by relevance score', async () => { const principals = [ diff --git a/packages/data-schemas/src/methods/userGroup.ts b/packages/data-schemas/src/methods/userGroup.ts index a41358337c..948542e6de 100644 --- a/packages/data-schemas/src/methods/userGroup.ts +++ b/packages/data-schemas/src/methods/userGroup.ts @@ -1,8 +1,9 @@ import { Types } from 'mongoose'; import { PrincipalType } from 'librechat-data-provider'; import type { TUser, TPrincipalSearchResult } from 'librechat-data-provider'; -import type { Model, ClientSession } from 'mongoose'; +import type { Model, ClientSession, FilterQuery } from 'mongoose'; import type { IGroup, IRole, IUser } from '~/types'; +import { escapeRegExp } from '~/utils/string'; export function createUserGroupMethods(mongoose: typeof import('mongoose')) { /** @@ -14,7 +15,7 @@ export function createUserGroupMethods(mongoose: typeof import('mongoose')) { */ async function findGroupById( groupId: string | Types.ObjectId, - projection: Record = {}, + projection: Record = {}, session?: ClientSession, ): Promise { const Group = mongoose.models.Group as Model; @@ -36,7 +37,7 @@ export function createUserGroupMethods(mongoose: typeof import('mongoose')) { async function findGroupByExternalId( idOnTheSource: string, source: 'entra' | 
'local' = 'entra', - projection: Record = {}, + projection: Record = {}, session?: ClientSession, ): Promise { const Group = mongoose.models.Group as Model; @@ -658,6 +659,97 @@ export function createUserGroupMethods(mongoose: typeof import('mongoose')) { return Group.updateMany(filter, update, options || {}); } + function buildGroupQuery(filter: { + source?: 'local' | 'entra'; + search?: string; + }): FilterQuery { + const query: FilterQuery = {}; + if (filter.source) { + query.source = filter.source; + } + if (filter.search) { + const regex = new RegExp(escapeRegExp(filter.search), 'i'); + query.$or = [{ name: regex }, { email: regex }, { description: regex }]; + } + return query; + } + + /** + * List groups with optional source, search, and pagination filters. + * Results are sorted by name. + * @param filter - Optional filter with source, search, limit, and offset fields + * @param session - Optional MongoDB session for transactions + */ + async function listGroups( + filter: { + source?: 'local' | 'entra'; + search?: string; + limit?: number; + offset?: number; + } = {}, + session?: ClientSession, + ): Promise { + const Group = mongoose.models.Group as Model; + const query = buildGroupQuery(filter); + const limit = filter.limit ?? 50; + const offset = filter.offset ?? 0; + return await Group.find(query) + .sort({ name: 1 }) + .skip(offset) + .limit(limit) + .session(session ?? null) + .lean(); + } + + /** + * Count groups matching optional source and search filters. + * @param filter - Optional filter with source and search fields + * @param session - Optional MongoDB session for transactions + */ + async function countGroups( + filter: { source?: 'local' | 'entra'; search?: string } = {}, + session?: ClientSession, + ): Promise { + const Group = mongoose.models.Group as Model; + const query = buildGroupQuery(filter); + return await Group.countDocuments(query).session(session ?? null); + } + + /** + * Delete a group by its ID. 
+ * @param groupId - The group's ObjectId + * @param session - Optional MongoDB session for transactions + */ + async function deleteGroup( + groupId: string | Types.ObjectId, + session?: ClientSession, + ): Promise { + const Group = mongoose.models.Group as Model; + const options = session ? { session } : {}; + return await Group.findByIdAndDelete(groupId, options).lean(); + } + + /** + * Remove a member from a group by raw memberId string ($pull from memberIds). + * Unlike removeUserFromGroup, this does not look up the user first. + * @param groupId - The group's ObjectId + * @param memberId - The raw memberId string to remove (ObjectId or idOnTheSource) + * @param session - Optional MongoDB session for transactions + */ + async function removeMemberById( + groupId: string | Types.ObjectId, + memberId: string, + session?: ClientSession, + ): Promise { + const Group = mongoose.models.Group as Model; + const options = { new: true, ...(session ? { session } : {}) }; + return await Group.findByIdAndUpdate( + groupId, + { $pull: { memberIds: memberId } }, + options, + ).lean(); + } + return { findGroupById, findGroupByExternalId, @@ -677,6 +769,10 @@ export function createUserGroupMethods(mongoose: typeof import('mongoose')) { searchPrincipals, calculateRelevanceScore, sortPrincipalsByRelevance, + listGroups, + countGroups, + deleteGroup, + removeMemberById, }; } From 5972a21479e6cda4b8aff93adb8e7fc17f31772f Mon Sep 17 00:00:00 2001 From: Dustin Healy <54083382+dustinhealy@users.noreply.github.com> Date: Fri, 27 Mar 2026 12:44:47 -0700 Subject: [PATCH 10/18] =?UTF-8?q?=F0=9F=AA=AA=20feat:=20Admin=20Roles=20AP?= =?UTF-8?q?I=20Endpoints=20(#12400)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add createRole and deleteRole methods to role * feat: add admin roles handler factory and Express routes * fix: address convention violations in admin roles handlers * fix: rename createRole/deleteRole to avoid AccessRole name 
collision The existing accessRole.ts already exports createRole/deleteRole for the AccessRole model. In createMethods index.ts, these are spread after roleMethods, overwriting them. Renamed our Role methods to createRoleByName/deleteRoleByName to match the existing pattern (getRoleByName, updateRoleByName) and avoid the collision. * feat: add description field to Role model - Add description to IRole, CreateRoleRequest, UpdateRoleRequest types - Add description field to Mongoose roleSchema (default: '') - Wire description through createRoleHandler and updateRoleHandler - Include description in listRoles select clause so it appears in list * fix: address Copilot review findings in admin roles handlers * test: add unit tests for admin roles and groups handlers * test: add data-layer tests for createRoleByName, deleteRoleByName, listUsersByRole * fix: allow system role updates when name is unchanged The updateRoleHandler guard rejected any request where body.name matched a system role, even when the name was not being changed. This blocked editing a system role's description. Compare against the URL param to only reject actual renames to reserved names. * fix: address external review findings for admin roles - Block renaming system roles (ADMIN/USER) and add user migration on rename - Add input validation: name max-length, trim on update, duplicate name check - Replace fragile String.includes error matching with prefix-based classification - Catch MongoDB 11000 duplicate key in createRoleByName - Add pagination (limit/offset/total) to getRoleMembersHandler - Reverse delete order in deleteRoleByName โ€” reassign users before deletion - Add role existence check in removeRoleMember; drop unused createdAt select - Add Array.isArray guard for permissions input; use consistent ?? 
coalescing - Fix import ordering per AGENTS.md conventions - Type-cast mongoose.models.User as Model for proper TS inference - Add comprehensive tests: rename guards, pagination, validation, 500 paths * fix: address re-review findings for admin roles - Gate deleteRoleByName on existence check โ€” skip user reassignment and cache invalidation when role doesn't exist (fixes test mismatch) - Reverse rename order: migrate users before renaming role so a migration failure leaves the system in a consistent state - Add .sort({ _id: 1 }) to listUsersByRole for deterministic pagination - Import shared AdminMember type from data-schemas instead of local copy; make joinedAt optional since neither groups nor roles populate it - Change IRole.description from optional to required to match schema default - Add data-layer tests for updateUsersByRole and countUsersByRole - Add handler test verifying users-first rename ordering and migration failure safety * fix: add rollback on rename failure and update PR description - Roll back user migration if updateRoleByName returns null during a rename (race: role deleted between existence check and update) - Add test verifying rollback calls updateUsersByRole in reverse - Update PR #12400 description to reflect current test counts (56 handler tests, 40 data-layer tests) and safety features * fix: rollback on rename throw, description validation, delete/DRY cleanup - Hoist isRename/trimmedName above try block so catch can roll back user migration when updateRoleByName throws (not just returns null) - Add description type + max-length (2000) validation in create and update, consistent with groups handler - Remove redundant getRoleByName existence check in deleteRoleHandler โ€” use deleteRoleByName return value directly - Skip no-op name write when body.name equals current name (use isRename) - Extract getUserModel() accessor to DRY repeated Model casts - Use name.trim() consistently in createRoleByName error messages - Add tests: rename-throw 
rollback, description validation (create+update), update delete test mocks to match simplified handler * fix: guard spurious rollback, harden createRole error path, validate before DB calls - Add migrationRan flag to prevent rollback of user migration that never ran - Return generic message on 500 in createRoleHandler, specific only for 409 - Move description validation before DB queries in updateRoleHandler - Return existing role early when update body has no changes - Wrap cache.set in createRoleByName with try/catch to prevent masking DB success - Add JSDoc on 11000 catch explaining compound unique index - Add tests: spurious rollback guard, empty update body, description validation ordering, listUsersByRole pagination * fix: validate permissions in create, RoleConflictError, rollback safety, cache consistency - Add permissions type/array validation in createRoleHandler - Introduce RoleConflictError class replacing fragile string-prefix matching - Wrap rollback in !role null path with try/catch for correct 404 response - Wrap deleteRoleByName cache.set in try/catch matching createRoleByName - Narrow updateRoleHandler body type to { name?, description? } - Add tests: non-string description in create, rollback failure logging, permissions array rejection, description max-length assertion fix * feat: prevent removing the last admin user Add guard in removeRoleMember that checks countUsersByRole before demoting an ADMIN user, returning 400 if they are the last one. 
* fix: move interleaved export below imports, add await to countUsersByRole * fix: paginate listRoles, null-guard permissions handler, fix export ordering - Add limit/offset/total pagination to listRoles matching the groups pattern - Add countRoles data-layer method - Omit permissions from listRoles select (getRole returns full document) - Null-guard re-fetched role in updateRolePermissionsHandler - Move interleaved export below all imports in methods/index.ts * fix: address review findings โ€” race safety, validation DRY, type accuracy, test coverage - Add post-write admin count verification in removeRoleMember to prevent zero-admin race condition (TOCTOU โ†’ rollback if count hits 0) - Make IRole.description optional; backfill in initializeRoles for pre-existing roles that lack the field (.lean() bypasses defaults) - Extract parsePagination, validateNameParam, validateRoleName, and validateDescription helpers to eliminate duplicated validation - Add validateNameParam guard to all 7 handlers reading req.params.name - Catch 11000 in updateRoleByName and surface as 409 via RoleConflictError - Add idempotent skip in addRoleMember when user already has target role - Verify updateRolePermissions test asserts response body - Add data-layer tests: listRoles sort/pagination/projection, countRoles, and createRoleByName 11000 duplicate key race * fix: defensive rollback in removeRoleMember, type/style cleanup, test coverage - Wrap removeRoleMember post-write admin rollback in try/catch so a transient DB failure cannot leave the system with zero administrators - Replace double `as unknown[] as IRole[]` cast with `.lean()` - Type parsePagination param explicitly; extract DEFAULT/MAX page constants - Preserve original error cause in updateRoleByName re-throw - Add test for rollback failure path in removeRoleMember (returns 400) - Add test for pre-existing roles missing description field (.lean()) * chore: bump @librechat/data-schemas to 0.0.47 * fix: stale cache on rename, 
extract renameRole helper, shared pagination, cleanup - Fix updateRoleByName cache bug: invalidate old key and populate new key when updates.name differs from roleName (prevents stale cache after rename) - Extract renameRole helper to eliminate mutable outer-scope state flags (isRename, trimmedName, migrationRan) in updateRoleHandler - Unify system-role protection to 403 for both rename-from and rename-to - Extract parsePagination to shared admin/pagination.ts; use in both roles.ts and groups.ts - Extract name.trim() to local const in createRoleByName (was called 5ร—) - Remove redundant findOne pre-check in deleteRoleByName - Replace getUserModel closure with local const declarations - Remove redundant description ?? '' in createRoleHandler (schema default) - Add doc comment on updateRolePermissionsHandler noting cache dependency - Add data-layer tests for cache rename behavior (old key null, new key set) * fix: harden role guards, add User.role index, validate names, improve tests - Add index on User.role field for efficient member queries at scale - Replace fragile SystemRoles key lookup with value-based Set check (6 sites) - Elevate rename rollback failure logging to CRITICAL (matches removeRoleMember) - Guard removeRoleMember against non-ADMIN system roles (403 for USER) - Fix parsePagination limit=0 gotcha: use parseInt + NaN check instead of || - Add control character and reserved path segment validation to role names - Simplify validateRoleName: remove redundant casts and dead conditions - Add JSDoc to deleteRoleByName documenting non-atomic window - Split mixed value+type import in methods/index.ts per AGENTS.md - Add 9 new tests: permissions assertion, combined rename+desc, createRole with permissions, pagination edge cases, control char/reserved name rejection, system role removeRoleMember guard * fix: exact-case reserved name check, consistent validation, cleaner createRole - Remove .toLowerCase() from reserved name check so only exact matches (members, 
permissions) are rejected, not legitimate names like "Members" - Extract trimmed const in validateRoleName for consistent validation - Add control char check to validateNameParam for parity with body validation - Build createRole roleData conditionally to avoid passing description: undefined - Expand deleteRoleByName JSDoc documenting self-healing design and no-op trade-off * fix: scope rename rollback to only migrated users, prevent cross-role corruption Capture user IDs before forward migration so the rollback path only reverts users this request actually moved. Previously the rollback called updateUsersByRole(newName, currentName) which would sweep all users with the new role โ€” including any independently assigned by a concurrent admin request โ€” causing silent cross-role data corruption. Adds findUserIdsByRole and updateUsersRoleByIds to the data layer. Extracts rollbackMigratedUsers helper to deduplicate rollback sites. * fix: guard last admin in addRoleMember to prevent zero-admin lockout Since each user has exactly one role, addRoleMember implicitly removes the user from their current role. Without a guard, reassigning the sole admin to a non-admin role leaves zero admins and locks out admin management. Adds the same countUsersByRole check used in removeRoleMember. * fix: wire findUserIdsByRole and updateUsersRoleByIds into roles route The scoped rollback deps added in c89b5db were missing from the route DI wiring, causing renameRole to call undefined and return a 500. 
* fix: post-write admin guard in addRoleMember, compound role index, review cleanup - Add post-write admin count check + rollback to addRoleMember to match removeRoleMember's two-phase TOCTOU protection (prevents zero-admin via concurrent requests) - Replace single-field User.role index with compound { role: 1, tenantId: 1 } to align with existing multi-tenant index pattern (email, OAuth IDs) - Narrow listRoles dep return type to RoleListItem (projected fields only) - Refactor validateDescription to early-return style per AGENTS.md - Remove redundant double .lean() in updateRoleByName - Document rename snapshot race window in renameRole JSDoc - Document cache null-set behavior in deleteRoleByName - Add routing-coupling comment on RESERVED_ROLE_NAMES - Add test for addRoleMember post-write rollback * fix: review cleanup โ€” system-role guard, type safety, JSDoc accuracy, tests - Add system-role guard to addRoleMember: block direct assignment to non-ADMIN system roles (403), symmetric with removeRoleMember - Fix RESERVED_ROLE_NAMES comment: explain semantic URL ambiguity, not a routing conflict (Express resolves single vs multi-segment correctly) - Replace _id: unknown with Types.ObjectId | string per AGENTS.md - Narrow listRoles data-layer return type to Pick to match the actual .select() projection - Move updateRoleHandler param check inside try/catch for consistency - Include user IDs in all CRITICAL rollback failure logs for operator recovery - Clarify deleteRoleByName JSDoc: replace "self-healing" with "idempotent", document that recovery requires caller retry - Add tests: system-role guard, promote non-admin to ADMIN, findUserIdsByRole throw prevents migration * fix: include _id in listRoles return type to match RoleListItem Pick omits _id, making it incompatible with the handler dep's RoleListItem which requires _id. 
* fix: case-insensitive system role guard, reject null permissions, check updateUser result - System role name checks now use case-insensitive comparison via toUpperCase() โ€” prevents creating 'admin' or 'user' which would collide with the legacy roles route that uppercases params - Reject permissions: null in createRole (typeof null === 'object' was bypassing the validation) - Check updateUser return in addRoleMember โ€” return 404 if the user was deleted between the findUser and updateUser calls * fix: check updateUser return in removeRoleMember for concurrent delete safety --------- Co-authored-by: Danny Avila --- api/server/index.js | 1 + api/server/routes/admin/roles.js | 43 + api/server/routes/index.js | 2 + packages/api/src/admin/groups.ts | 7 +- packages/api/src/admin/index.ts | 2 + packages/api/src/admin/pagination.ts | 17 + packages/api/src/admin/roles.spec.ts | 1484 +++++++++++++++++ packages/api/src/admin/roles.ts | 550 ++++++ packages/data-schemas/package.json | 2 +- packages/data-schemas/src/index.ts | 1 + packages/data-schemas/src/methods/index.ts | 6 +- .../src/methods/role.methods.spec.ts | 387 ++++- packages/data-schemas/src/methods/role.ts | 187 ++- packages/data-schemas/src/schema/role.ts | 1 + packages/data-schemas/src/schema/user.ts | 1 + packages/data-schemas/src/types/admin.ts | 2 +- packages/data-schemas/src/types/role.ts | 3 + 17 files changed, 2673 insertions(+), 23 deletions(-) create mode 100644 api/server/routes/admin/roles.js create mode 100644 packages/api/src/admin/pagination.ts create mode 100644 packages/api/src/admin/roles.spec.ts create mode 100644 packages/api/src/admin/roles.ts diff --git a/api/server/index.js b/api/server/index.js index 4ecc966476..813b453468 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -145,6 +145,7 @@ const startServer = async () => { app.use('/api/admin', routes.adminAuth); app.use('/api/admin/config', routes.adminConfig); app.use('/api/admin/groups', routes.adminGroups); + 
app.use('/api/admin/roles', routes.adminRoles); app.use('/api/actions', routes.actions); app.use('/api/keys', routes.keys); app.use('/api/api-keys', routes.apiKeys); diff --git a/api/server/routes/admin/roles.js b/api/server/routes/admin/roles.js new file mode 100644 index 0000000000..2d0f1b1128 --- /dev/null +++ b/api/server/routes/admin/roles.js @@ -0,0 +1,43 @@ +const express = require('express'); +const { createAdminRolesHandlers } = require('@librechat/api'); +const { SystemCapabilities } = require('@librechat/data-schemas'); +const { requireCapability } = require('~/server/middleware/roles/capabilities'); +const { requireJwtAuth } = require('~/server/middleware'); +const db = require('~/models'); + +const router = express.Router(); + +const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN); +const requireReadRoles = requireCapability(SystemCapabilities.READ_ROLES); +const requireManageRoles = requireCapability(SystemCapabilities.MANAGE_ROLES); + +const handlers = createAdminRolesHandlers({ + listRoles: db.listRoles, + countRoles: db.countRoles, + getRoleByName: db.getRoleByName, + createRoleByName: db.createRoleByName, + updateRoleByName: db.updateRoleByName, + updateAccessPermissions: db.updateAccessPermissions, + deleteRoleByName: db.deleteRoleByName, + findUser: db.findUser, + updateUser: db.updateUser, + updateUsersByRole: db.updateUsersByRole, + findUserIdsByRole: db.findUserIdsByRole, + updateUsersRoleByIds: db.updateUsersRoleByIds, + listUsersByRole: db.listUsersByRole, + countUsersByRole: db.countUsersByRole, +}); + +router.use(requireJwtAuth, requireAdminAccess); + +router.get('/', requireReadRoles, handlers.listRoles); +router.post('/', requireManageRoles, handlers.createRole); +router.get('/:name', requireReadRoles, handlers.getRole); +router.patch('/:name', requireManageRoles, handlers.updateRole); +router.delete('/:name', requireManageRoles, handlers.deleteRole); +router.patch('/:name/permissions', requireManageRoles, 
handlers.updateRolePermissions); +router.get('/:name/members', requireReadRoles, handlers.getRoleMembers); +router.post('/:name/members', requireManageRoles, handlers.addRoleMember); +router.delete('/:name/members/:userId', requireManageRoles, handlers.removeRoleMember); + +module.exports = router; diff --git a/api/server/routes/index.js b/api/server/routes/index.js index f9a088649c..71ae041fc2 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -4,6 +4,7 @@ const categories = require('./categories'); const adminAuth = require('./admin/auth'); const adminConfig = require('./admin/config'); const adminGroups = require('./admin/groups'); +const adminRoles = require('./admin/roles'); const endpoints = require('./endpoints'); const staticRoute = require('./static'); const messages = require('./messages'); @@ -35,6 +36,7 @@ module.exports = { adminAuth, adminConfig, adminGroups, + adminRoles, keys, apiKeys, user, diff --git a/packages/api/src/admin/groups.ts b/packages/api/src/admin/groups.ts index 58ff4d9782..ab4490e05f 100644 --- a/packages/api/src/admin/groups.ts +++ b/packages/api/src/admin/groups.ts @@ -13,6 +13,7 @@ import type { FilterQuery, ClientSession, DeleteResult } from 'mongoose'; import type { Response } from 'express'; import type { ValidationError } from '~/types/error'; import type { ServerRequest } from '~/types/http'; +import { parsePagination } from './pagination'; type GroupListFilter = Pick; @@ -119,8 +120,7 @@ export function createAdminGroupsHandlers(deps: AdminGroupsDeps) { if (search) { filter.search = search; } - const limit = Math.min(Math.max(Number(req.query.limit) || 50, 1), 200); - const offset = Math.max(Number(req.query.offset) || 0, 0); + const { limit, offset } = parsePagination(req.query); const [groups, total] = await Promise.all([ listGroups({ ...filter, limit, offset }), countGroups(filter), @@ -348,8 +348,7 @@ export function createAdminGroupsHandlers(deps: AdminGroupsDeps) { */ const allMemberIds = 
[...new Set(group.memberIds || [])]; const total = allMemberIds.length; - const limit = Math.min(Math.max(Number(req.query.limit) || 50, 1), 200); - const offset = Math.max(Number(req.query.offset) || 0, 0); + const { limit, offset } = parsePagination(req.query); if (total === 0 || offset >= total) { return res.status(200).json({ members: [], total, limit, offset }); diff --git a/packages/api/src/admin/index.ts b/packages/api/src/admin/index.ts index d833c7e2b0..fe60f1d993 100644 --- a/packages/api/src/admin/index.ts +++ b/packages/api/src/admin/index.ts @@ -1,4 +1,6 @@ export { createAdminConfigHandlers } from './config'; export { createAdminGroupsHandlers } from './groups'; +export { createAdminRolesHandlers } from './roles'; export type { AdminConfigDeps } from './config'; export type { AdminGroupsDeps } from './groups'; +export type { AdminRolesDeps } from './roles'; diff --git a/packages/api/src/admin/pagination.ts b/packages/api/src/admin/pagination.ts new file mode 100644 index 0000000000..69003f0418 --- /dev/null +++ b/packages/api/src/admin/pagination.ts @@ -0,0 +1,17 @@ +export const DEFAULT_PAGE_LIMIT = 50; +export const MAX_PAGE_LIMIT = 200; + +export function parsePagination(query: { limit?: string; offset?: string }): { + limit: number; + offset: number; +} { + const rawLimit = parseInt(query.limit ?? '', 10); + const rawOffset = parseInt(query.offset ?? '', 10); + return { + limit: Math.min( + Math.max(Number.isNaN(rawLimit) ? DEFAULT_PAGE_LIMIT : rawLimit, 1), + MAX_PAGE_LIMIT, + ), + offset: Math.max(Number.isNaN(rawOffset) ? 
0 : rawOffset, 0), + }; +} diff --git a/packages/api/src/admin/roles.spec.ts b/packages/api/src/admin/roles.spec.ts new file mode 100644 index 0000000000..3f43079bfb --- /dev/null +++ b/packages/api/src/admin/roles.spec.ts @@ -0,0 +1,1484 @@ +import { Types } from 'mongoose'; +import { SystemRoles } from 'librechat-data-provider'; +import type { IRole, IUser } from '@librechat/data-schemas'; +import type { Response } from 'express'; +import type { ServerRequest } from '~/types/http'; +import type { AdminRolesDeps } from './roles'; +import { createAdminRolesHandlers } from './roles'; + +const { RoleConflictError } = jest.requireActual('@librechat/data-schemas'); + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { error: jest.fn() }, +})); + +const validUserId = new Types.ObjectId().toString(); + +function mockRole(overrides: Partial = {}): IRole { + return { + name: 'editor', + description: 'Can edit content', + permissions: {}, + ...overrides, + } as IRole; +} + +function mockUser(overrides: Partial = {}): IUser { + return { + _id: new Types.ObjectId(validUserId), + name: 'Test User', + email: 'test@example.com', + avatar: 'https://example.com/avatar.png', + role: 'editor', + ...overrides, + } as IUser; +} + +function createReqRes( + overrides: { + params?: Record; + query?: Record; + body?: Record; + } = {}, +) { + const req = { + params: overrides.params ?? {}, + query: overrides.query ?? {}, + body: overrides.body ?? 
{}, + } as unknown as ServerRequest; + + const json = jest.fn(); + const status = jest.fn().mockReturnValue({ json }); + const res = { status, json } as unknown as Response; + + return { req, res, status, json }; +} + +function createDeps(overrides: Partial = {}): AdminRolesDeps { + return { + listRoles: jest.fn().mockResolvedValue([]), + countRoles: jest.fn().mockResolvedValue(0), + getRoleByName: jest.fn().mockResolvedValue(null), + createRoleByName: jest.fn().mockResolvedValue(mockRole()), + updateRoleByName: jest.fn().mockResolvedValue(mockRole()), + updateAccessPermissions: jest.fn().mockResolvedValue(undefined), + deleteRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: jest.fn().mockResolvedValue(null), + updateUser: jest.fn().mockResolvedValue(mockUser()), + updateUsersByRole: jest.fn().mockResolvedValue(undefined), + findUserIdsByRole: jest.fn().mockResolvedValue(['uid-1', 'uid-2']), + updateUsersRoleByIds: jest.fn().mockResolvedValue(undefined), + listUsersByRole: jest.fn().mockResolvedValue([]), + countUsersByRole: jest.fn().mockResolvedValue(0), + ...overrides, + }; +} + +describe('createAdminRolesHandlers', () => { + describe('listRoles', () => { + it('returns paginated roles with 200', async () => { + const roles = [mockRole()]; + const deps = createDeps({ + listRoles: jest.fn().mockResolvedValue(roles), + countRoles: jest.fn().mockResolvedValue(1), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes(); + + await handlers.listRoles(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ roles, total: 1, limit: 50, offset: 0 }); + expect(deps.listRoles).toHaveBeenCalledWith({ limit: 50, offset: 0 }); + }); + + it('passes custom limit and offset from query', async () => { + const deps = createDeps({ + countRoles: jest.fn().mockResolvedValue(100), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = 
createReqRes({ + query: { limit: '25', offset: '50' }, + }); + + await handlers.listRoles(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ roles: [], total: 100, limit: 25, offset: 50 }); + expect(deps.listRoles).toHaveBeenCalledWith({ limit: 25, offset: 50 }); + }); + + it('clamps limit to 200', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ query: { limit: '999' } }); + + await handlers.listRoles(req, res); + + expect(deps.listRoles).toHaveBeenCalledWith({ limit: 200, offset: 0 }); + }); + + it('clamps negative offset to 0', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ query: { offset: '-5' } }); + + await handlers.listRoles(req, res); + + expect(deps.listRoles).toHaveBeenCalledWith({ limit: 50, offset: 0 }); + }); + + it('treats non-numeric limit as default', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ query: { limit: 'abc' } }); + + await handlers.listRoles(req, res); + + expect(deps.listRoles).toHaveBeenCalledWith({ limit: 50, offset: 0 }); + }); + + it('clamps limit=0 to 1', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ query: { limit: '0' } }); + + await handlers.listRoles(req, res); + + expect(deps.listRoles).toHaveBeenCalledWith({ limit: 1, offset: 0 }); + }); + + it('truncates float offset to integer', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ query: { offset: '1.7' } }); + + await handlers.listRoles(req, res); + + expect(deps.listRoles).toHaveBeenCalledWith({ limit: 50, offset: 1 }); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ listRoles: 
jest.fn().mockRejectedValue(new Error('db down')) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes(); + + await handlers.listRoles(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to list roles' }); + }); + }); + + describe('getRole', () => { + it('returns role with 200', async () => { + const role = mockRole(); + const deps = createDeps({ getRoleByName: jest.fn().mockResolvedValue(role) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 'editor' } }); + + await handlers.getRole(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ role }); + }); + + it('returns 404 when role not found', async () => { + const deps = createDeps({ getRoleByName: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 'nonexistent' } }); + + await handlers.getRole(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockRejectedValue(new Error('db down')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 'editor' } }); + + await handlers.getRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to get role' }); + }); + }); + + describe('createRole', () => { + it('creates role and returns 201', async () => { + const role = mockRole(); + const deps = createDeps({ createRoleByName: jest.fn().mockResolvedValue(role) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'editor', 
description: 'Can edit' }, + }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(201); + expect(json).toHaveBeenCalledWith({ role }); + expect(deps.createRoleByName).toHaveBeenCalledWith({ + name: 'editor', + description: 'Can edit', + permissions: {}, + }); + }); + + it('passes provided permissions to createRoleByName', async () => { + const perms = { chat: { read: true, write: false } } as unknown as IRole['permissions']; + const role = mockRole({ permissions: perms }); + const deps = createDeps({ createRoleByName: jest.fn().mockResolvedValue(role) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'editor', permissions: perms }, + }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(201); + expect(json).toHaveBeenCalledWith({ role }); + expect(deps.createRoleByName).toHaveBeenCalledWith({ + name: 'editor', + permissions: perms, + }); + }); + + it('returns 400 when name is missing', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ body: {} }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name is required' }); + expect(deps.createRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 400 when name is whitespace-only', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ body: { name: ' ' } }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name is required' }); + }); + + it('returns 400 when name contains control characters', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ 
body: { name: 'bad\x00name' } }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name contains invalid characters' }); + expect(deps.createRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 400 when name is a reserved path segment', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ body: { name: 'members' } }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name is a reserved path segment' }); + expect(deps.createRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 400 when name exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'a'.repeat(501) }, + }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name must not exceed 500 characters' }); + expect(deps.createRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 400 when description exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'editor', description: 'a'.repeat(2001) }, + }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ + error: 'description must not exceed 2000 characters', + }); + expect(deps.createRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 409 when role already exists', async () => { + const deps = createDeps({ + createRoleByName: jest + .fn() + .mockRejectedValue(new RoleConflictError('Role "editor" already exists')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, 
res, status, json } = createReqRes({ body: { name: 'editor' } }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(409); + expect(json).toHaveBeenCalledWith({ error: 'Role "editor" already exists' }); + }); + + it('returns 409 when name is reserved system role', async () => { + const deps = createDeps({ + createRoleByName: jest + .fn() + .mockRejectedValue( + new RoleConflictError('Cannot create role with reserved system name: ADMIN'), + ), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ body: { name: 'ADMIN' } }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(409); + expect(json).toHaveBeenCalledWith({ + error: 'Cannot create role with reserved system name: ADMIN', + }); + }); + + it('returns 500 on unexpected error', async () => { + const deps = createDeps({ + createRoleByName: jest.fn().mockRejectedValue(new Error('db crash')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ body: { name: 'editor' } }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to create role' }); + }); + + it('does not classify unrelated errors as 409', async () => { + const deps = createDeps({ + createRoleByName: jest + .fn() + .mockRejectedValue(new Error('Disk space reserved for system use')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status } = createReqRes({ body: { name: 'test' } }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + }); + + it('returns 400 when description is not a string', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'editor', description: 123 }, + }); + + await handlers.createRole(req, res); + + 
expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'description must be a string' }); + expect(deps.createRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 400 when permissions is an array', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + body: { name: 'editor', permissions: [1, 2, 3] }, + }); + + await handlers.createRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'permissions must be an object' }); + expect(deps.createRoleByName).not.toHaveBeenCalled(); + }); + }); + + describe('updateRole', () => { + it('updates role and returns 200', async () => { + const role = mockRole({ name: 'senior-editor' }); + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + updateRoleByName: jest.fn().mockResolvedValue(role), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'senior-editor' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ role }); + expect(deps.updateRoleByName).toHaveBeenCalledWith('editor', { name: 'senior-editor' }); + }); + + it('trims name before storage', async () => { + const role = mockRole({ name: 'trimmed' }); + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + updateRoleByName: jest.fn().mockResolvedValue(role), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ + params: { name: 'editor' }, + body: { name: ' trimmed ' }, + }); + + await handlers.updateRole(req, res); + + expect(deps.updateRoleByName).toHaveBeenCalledWith('editor', { name: 'trimmed' }); + }); + + it('migrates users before 
renaming role', async () => { + const role = mockRole({ name: 'new-name' }); + const callOrder: string[] = []; + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + findUserIdsByRole: jest.fn().mockImplementation(() => { + callOrder.push('findUserIdsByRole'); + return Promise.resolve(['uid-1']); + }), + updateUsersByRole: jest.fn().mockImplementation(() => { + callOrder.push('updateUsersByRole'); + return Promise.resolve(); + }), + updateRoleByName: jest.fn().mockImplementation(() => { + callOrder.push('updateRoleByName'); + return Promise.resolve(role); + }), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'new-name' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(deps.findUserIdsByRole).toHaveBeenCalledWith('editor'); + expect(deps.updateUsersByRole).toHaveBeenCalledWith('editor', 'new-name'); + expect(callOrder).toEqual(['findUserIdsByRole', 'updateUsersByRole', 'updateRoleByName']); + }); + + it('does not rename role when user migration fails', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + updateUsersByRole: jest.fn().mockRejectedValue(new Error('migration failed')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'new-name' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(deps.updateRoleByName).not.toHaveBeenCalled(); + }); + + it('does not migrate users when name unchanged', async () => { + const role = mockRole({ description: 'updated' }); + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + updateRoleByName: jest.fn().mockResolvedValue(role), + }); + const 
handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ + params: { name: 'editor' }, + body: { description: 'updated' }, + }); + + await handlers.updateRole(req, res); + + expect(deps.updateUsersByRole).not.toHaveBeenCalled(); + }); + + it('renames and updates description in a single request', async () => { + const role = mockRole({ name: 'senior-editor', description: 'Updated desc' }); + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + updateRoleByName: jest.fn().mockResolvedValue(role), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'senior-editor', description: 'Updated desc' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ role }); + expect(deps.updateUsersByRole).toHaveBeenCalledWith('editor', 'senior-editor'); + expect(deps.updateRoleByName).toHaveBeenCalledWith('editor', { + name: 'senior-editor', + description: 'Updated desc', + }); + }); + + it('returns 403 when renaming a system role', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: SystemRoles.ADMIN }, + body: { name: 'custom-admin' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(403); + expect(json).toHaveBeenCalledWith({ error: 'Cannot rename system role' }); + expect(deps.getRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 403 when renaming to a system role name', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { name: SystemRoles.ADMIN }, + }); + + await handlers.updateRole(req, res); + + 
expect(status).toHaveBeenCalledWith(403); + expect(json).toHaveBeenCalledWith({ error: 'Cannot use a reserved system role name' }); + }); + + it('returns 409 when target name already exists', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'viewer' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(409); + expect(json).toHaveBeenCalledWith({ error: 'Role "viewer" already exists' }); + }); + + it('returns 400 when name is empty string', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { name: '' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name must be a non-empty string' }); + expect(deps.getRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 400 when name is whitespace-only', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { name: ' ' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name must be a non-empty string' }); + }); + + it('returns 400 when name exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'a'.repeat(501) }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'name must not exceed 500 characters' }); + 
expect(deps.getRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 404 when role not found', async () => { + const deps = createDeps({ getRoleByName: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'nonexistent' }, + body: { description: 'updated' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + }); + + it('returns 404 when updateRoleByName returns null', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + updateRoleByName: jest.fn().mockResolvedValue(null), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { description: 'updated' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + }); + + it('rolls back user migration when rename fails', async () => { + const ids = ['uid-1', 'uid-2']; + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + findUserIdsByRole: jest.fn().mockResolvedValue(ids), + updateRoleByName: jest.fn().mockResolvedValue(null), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'new-name' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + expect(deps.updateUsersByRole).toHaveBeenCalledTimes(1); + expect(deps.updateUsersByRole).toHaveBeenCalledWith('editor', 'new-name'); + expect(deps.updateUsersRoleByIds).toHaveBeenCalledWith(ids, 'editor'); + }); + + it('rolls back user migration when 
rename throws', async () => { + const ids = ['uid-1', 'uid-2']; + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + findUserIdsByRole: jest.fn().mockResolvedValue(ids), + updateRoleByName: jest.fn().mockRejectedValue(new Error('db crash')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'new-name' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(deps.updateUsersByRole).toHaveBeenCalledTimes(1); + expect(deps.updateUsersByRole).toHaveBeenCalledWith('editor', 'new-name'); + expect(deps.updateUsersRoleByIds).toHaveBeenCalledWith(ids, 'editor'); + }); + + it('logs rollback failure and still returns 500', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + findUserIdsByRole: jest.fn().mockResolvedValue(['uid-1']), + updateUsersRoleByIds: jest.fn().mockRejectedValue(new Error('rollback failed')), + updateRoleByName: jest.fn().mockRejectedValue(new Error('rename failed')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'new-name' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(deps.updateUsersByRole).toHaveBeenCalledTimes(1); + expect(deps.updateUsersRoleByIds).toHaveBeenCalledTimes(1); + }); + + it('returns 400 when description exceeds max length', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { description: 'a'.repeat(2001) }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ + error: 'description 
must not exceed 2000 characters', + }); + expect(deps.getRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 500 on unexpected error', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + updateRoleByName: jest.fn().mockRejectedValue(new Error('db error')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { description: 'updated' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to update role' }); + }); + + it('does not roll back when error occurs before user migration', async () => { + const deps = createDeps({ + getRoleByName: jest + .fn() + .mockResolvedValueOnce(mockRole()) + .mockRejectedValueOnce(new Error('db crash')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'new-name' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(deps.updateUsersByRole).not.toHaveBeenCalled(); + }); + + it('does not migrate users when findUserIdsByRole throws', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(mockRole()).mockResolvedValueOnce(null), + findUserIdsByRole: jest.fn().mockRejectedValue(new Error('db crash')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status } = createReqRes({ + params: { name: 'editor' }, + body: { name: 'new-name' }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(deps.updateUsersByRole).not.toHaveBeenCalled(); + expect(deps.updateUsersRoleByIds).not.toHaveBeenCalled(); + }); + + it('returns existing role early when update body has no changes', async () => { + const role = mockRole(); + const deps = createDeps({ + 
getRoleByName: jest.fn().mockResolvedValue(role), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: {}, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ role }); + expect(deps.updateRoleByName).not.toHaveBeenCalled(); + }); + + it('rejects invalid description before making DB calls', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { description: 123 }, + }); + + await handlers.updateRole(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'description must be a string' }); + expect(deps.getRoleByName).not.toHaveBeenCalled(); + }); + }); + + describe('updateRolePermissions', () => { + it('updates permissions and returns 200 with updated role', async () => { + const role = mockRole(); + const updatedRole = mockRole({ + permissions: { chat: { read: true, write: true } } as IRole['permissions'], + }); + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValueOnce(role).mockResolvedValueOnce(updatedRole), + }); + const handlers = createAdminRolesHandlers(deps); + const perms = { chat: { read: true, write: true } }; + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { permissions: perms }, + }); + + await handlers.updateRolePermissions(req, res); + + expect(deps.updateAccessPermissions).toHaveBeenCalledWith('editor', perms, role); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ role: updatedRole }); + }); + + it('returns 400 when permissions is missing', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, 
+ body: {}, + }); + + await handlers.updateRolePermissions(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'permissions object is required' }); + }); + + it('returns 400 when permissions is an array', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { permissions: [1, 2, 3] }, + }); + + await handlers.updateRolePermissions(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'permissions object is required' }); + }); + + it('returns 404 when role not found', async () => { + const deps = createDeps({ getRoleByName: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'nonexistent' }, + body: { permissions: { chat: { read: true } } }, + }); + + await handlers.updateRolePermissions(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + updateAccessPermissions: jest.fn().mockRejectedValue(new Error('db down')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { permissions: { chat: { read: true } } }, + }); + + await handlers.updateRolePermissions(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to update role permissions' }); + }); + }); + + describe('deleteRole', () => { + it('deletes role and returns 200', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 
'editor' } }); + + await handlers.deleteRole(req, res); + + expect(deps.deleteRoleByName).toHaveBeenCalledWith('editor'); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + + it('returns 403 for system role', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: SystemRoles.ADMIN } }); + + await handlers.deleteRole(req, res); + + expect(status).toHaveBeenCalledWith(403); + expect(json).toHaveBeenCalledWith({ error: 'Cannot delete system role' }); + expect(deps.deleteRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 404 when role not found', async () => { + const deps = createDeps({ deleteRoleByName: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 'nonexistent' } }); + + await handlers.deleteRole(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ + deleteRoleByName: jest.fn().mockRejectedValue(new Error('db down')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 'editor' } }); + + await handlers.deleteRole(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to delete role' }); + }); + }); + + describe('getRoleMembers', () => { + it('returns paginated members with 200', async () => { + const user = mockUser(); + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + listUsersByRole: jest.fn().mockResolvedValue([user]), + countUsersByRole: jest.fn().mockResolvedValue(1), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ 
params: { name: 'editor' } }); + + await handlers.getRoleMembers(req, res); + + expect(deps.listUsersByRole).toHaveBeenCalledWith('editor', { limit: 50, offset: 0 }); + expect(deps.countUsersByRole).toHaveBeenCalledWith('editor'); + expect(status).toHaveBeenCalledWith(200); + const response = json.mock.calls[0][0]; + expect(response.members).toHaveLength(1); + expect(response.members[0]).toEqual({ + userId: validUserId, + name: 'Test User', + email: 'test@example.com', + avatarUrl: 'https://example.com/avatar.png', + }); + expect(response.total).toBe(1); + expect(response.limit).toBe(50); + expect(response.offset).toBe(0); + }); + + it('passes pagination parameters from query', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + countUsersByRole: jest.fn().mockResolvedValue(0), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ + params: { name: 'editor' }, + query: { limit: '10', offset: '20' }, + }); + + await handlers.getRoleMembers(req, res); + + expect(deps.listUsersByRole).toHaveBeenCalledWith('editor', { limit: 10, offset: 20 }); + }); + + it('clamps limit to 200', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + countUsersByRole: jest.fn().mockResolvedValue(0), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res } = createReqRes({ + params: { name: 'editor' }, + query: { limit: '999' }, + }); + + await handlers.getRoleMembers(req, res); + + expect(deps.listUsersByRole).toHaveBeenCalledWith('editor', { limit: 200, offset: 0 }); + }); + + it('does not include joinedAt in response', async () => { + const user = mockUser(); + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + listUsersByRole: jest.fn().mockResolvedValue([user]), + countUsersByRole: jest.fn().mockResolvedValue(1), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, 
json } = createReqRes({ params: { name: 'editor' } }); + + await handlers.getRoleMembers(req, res); + + const member = json.mock.calls[0][0].members[0]; + expect(member).not.toHaveProperty('joinedAt'); + }); + + it('returns empty array when no members', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + countUsersByRole: jest.fn().mockResolvedValue(0), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 'editor' } }); + + await handlers.getRoleMembers(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ members: [], total: 0, limit: 50, offset: 0 }); + }); + + it('returns 404 when role not found', async () => { + const deps = createDeps({ getRoleByName: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 'nonexistent' } }); + + await handlers.getRoleMembers(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + }); + + it('returns 500 on error', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + listUsersByRole: jest.fn().mockRejectedValue(new Error('db down')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ params: { name: 'editor' } }); + + await handlers.getRoleMembers(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to get role members' }); + }); + }); + + describe('addRoleMember', () => { + it('adds member and returns 200', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: jest.fn().mockResolvedValue(mockUser({ role: 'viewer' })), + }); + const handlers = createAdminRolesHandlers(deps); + const 
{ req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(deps.updateUser).toHaveBeenCalledWith(validUserId, { role: 'editor' }); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + + it('skips DB write when user already has the target role', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: jest.fn().mockResolvedValue(mockUser({ role: 'editor' })), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + expect(deps.updateUser).not.toHaveBeenCalled(); + }); + + it('returns 400 when userId is missing', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: {}, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'userId is required' }); + }); + + it('returns 400 for invalid ObjectId', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { userId: 'not-valid' }, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Invalid user ID format' }); + }); + + it('returns 404 when role not found', async () => { + const deps = createDeps({ getRoleByName: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminRolesHandlers(deps); + const { req, 
res, status, json } = createReqRes({ + params: { name: 'nonexistent' }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + }); + + it('returns 404 when user not found', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: jest.fn().mockResolvedValue(null), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'User not found' }); + }); + + it('returns 400 when reassigning the last admin to another role', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole({ name: 'editor' })), + findUser: jest.fn().mockResolvedValue(mockUser({ role: SystemRoles.ADMIN })), + countUsersByRole: jest.fn().mockResolvedValue(1), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Cannot remove the last admin user' }); + expect(deps.updateUser).not.toHaveBeenCalled(); + }); + + it('allows reassigning an admin when multiple admins exist', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole({ name: 'editor' })), + findUser: jest.fn().mockResolvedValue(mockUser({ role: SystemRoles.ADMIN })), + countUsersByRole: jest.fn().mockResolvedValue(3), + updateUser: jest.fn().mockResolvedValue(mockUser({ role: 'editor' })), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, 
status } = createReqRes({ + params: { name: 'editor' }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(deps.updateUser).toHaveBeenCalledWith(validUserId, { role: 'editor' }); + }); + + it('rolls back assignment when post-write admin count is zero', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole({ name: 'editor' })), + findUser: jest.fn().mockResolvedValue(mockUser({ role: SystemRoles.ADMIN })), + countUsersByRole: jest.fn().mockResolvedValueOnce(2).mockResolvedValueOnce(0), + updateUser: jest.fn().mockResolvedValue(mockUser()), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(deps.updateUser).toHaveBeenCalledTimes(2); + expect(deps.updateUser).toHaveBeenLastCalledWith(validUserId, { role: SystemRoles.ADMIN }); + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Cannot remove the last admin user' }); + }); + + it('returns 403 when adding to a non-ADMIN system role', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: SystemRoles.USER }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(403); + expect(json).toHaveBeenCalledWith({ + error: 'Cannot directly assign members to a system role', + }); + expect(deps.updateUser).not.toHaveBeenCalled(); + }); + + it('allows promoting a non-admin user to the ADMIN role', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole({ name: SystemRoles.ADMIN })), + findUser: jest.fn().mockResolvedValue(mockUser({ role: 'editor' })), + updateUser: 
jest.fn().mockResolvedValue(mockUser({ role: SystemRoles.ADMIN })), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: SystemRoles.ADMIN }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(deps.updateUser).toHaveBeenCalledWith(validUserId, { role: SystemRoles.ADMIN }); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + + it('returns 500 on unexpected error', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: jest.fn().mockResolvedValue(mockUser({ role: 'viewer' })), + updateUser: jest.fn().mockRejectedValue(new Error('timeout')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor' }, + body: { userId: validUserId }, + }); + + await handlers.addRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to add role member' }); + }); + }); + + describe('removeRoleMember', () => { + it('removes member and returns 200', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: jest.fn().mockResolvedValue(mockUser({ role: 'editor' })), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor', userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(deps.updateUser).toHaveBeenCalledWith(validUserId, { role: SystemRoles.USER }); + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + }); + + it('returns 403 when removing from a non-ADMIN system role', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = 
createReqRes({ + params: { name: SystemRoles.USER, userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(403); + expect(json).toHaveBeenCalledWith({ error: 'Cannot remove members from a system role' }); + expect(deps.getRoleByName).not.toHaveBeenCalled(); + }); + + it('returns 400 for invalid ObjectId', async () => { + const deps = createDeps(); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor', userId: 'bad' }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Invalid user ID format' }); + expect(deps.findUser).not.toHaveBeenCalled(); + }); + + it('returns 404 when role not found', async () => { + const deps = createDeps({ getRoleByName: jest.fn().mockResolvedValue(null) }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'nonexistent', userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'Role not found' }); + expect(deps.findUser).not.toHaveBeenCalled(); + }); + + it('returns 404 when user not found', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: jest.fn().mockResolvedValue(null), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor', userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(404); + expect(json).toHaveBeenCalledWith({ error: 'User not found' }); + }); + + it('returns 400 when user is not a member of the role', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: 
jest.fn().mockResolvedValue(mockUser({ role: 'other-role' })), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor', userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'User is not a member of this role' }); + expect(deps.updateUser).not.toHaveBeenCalled(); + }); + + it('returns 400 when removing the last admin user', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole({ name: SystemRoles.ADMIN })), + findUser: jest.fn().mockResolvedValue(mockUser({ role: SystemRoles.ADMIN })), + countUsersByRole: jest.fn().mockResolvedValue(1), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: SystemRoles.ADMIN, userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Cannot remove the last admin user' }); + expect(deps.updateUser).not.toHaveBeenCalled(); + }); + + it('allows removing an admin when multiple admins exist', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole({ name: SystemRoles.ADMIN })), + findUser: jest.fn().mockResolvedValue(mockUser({ role: SystemRoles.ADMIN })), + countUsersByRole: jest.fn().mockResolvedValue(3), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: SystemRoles.ADMIN, userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(200); + expect(json).toHaveBeenCalledWith({ success: true }); + expect(deps.updateUser).toHaveBeenCalledWith(validUserId, { role: SystemRoles.USER }); + }); + + it('rolls back removal when post-write check finds zero admins', 
async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole({ name: SystemRoles.ADMIN })), + findUser: jest.fn().mockResolvedValue(mockUser({ role: SystemRoles.ADMIN })), + countUsersByRole: jest.fn().mockResolvedValueOnce(2).mockResolvedValueOnce(0), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: SystemRoles.ADMIN, userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Cannot remove the last admin user' }); + expect(deps.updateUser).toHaveBeenCalledTimes(2); + expect(deps.updateUser).toHaveBeenNthCalledWith(1, validUserId, { + role: SystemRoles.USER, + }); + expect(deps.updateUser).toHaveBeenNthCalledWith(2, validUserId, { + role: SystemRoles.ADMIN, + }); + }); + + it('returns 400 even when rollback updateUser throws', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole({ name: SystemRoles.ADMIN })), + findUser: jest.fn().mockResolvedValue(mockUser({ role: SystemRoles.ADMIN })), + countUsersByRole: jest.fn().mockResolvedValueOnce(2).mockResolvedValueOnce(0), + updateUser: jest + .fn() + .mockResolvedValueOnce(mockUser({ role: SystemRoles.USER })) + .mockRejectedValueOnce(new Error('rollback failed')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: SystemRoles.ADMIN, userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(400); + expect(json).toHaveBeenCalledWith({ error: 'Cannot remove the last admin user' }); + expect(deps.updateUser).toHaveBeenCalledTimes(2); + }); + + it('returns 500 on unexpected error', async () => { + const deps = createDeps({ + getRoleByName: jest.fn().mockResolvedValue(mockRole()), + findUser: jest.fn().mockResolvedValue(mockUser({ 
role: 'editor' })), + updateUser: jest.fn().mockRejectedValue(new Error('timeout')), + }); + const handlers = createAdminRolesHandlers(deps); + const { req, res, status, json } = createReqRes({ + params: { name: 'editor', userId: validUserId }, + }); + + await handlers.removeRoleMember(req, res); + + expect(status).toHaveBeenCalledWith(500); + expect(json).toHaveBeenCalledWith({ error: 'Failed to remove role member' }); + }); + }); +}); diff --git a/packages/api/src/admin/roles.ts b/packages/api/src/admin/roles.ts new file mode 100644 index 0000000000..b8c87c23ea --- /dev/null +++ b/packages/api/src/admin/roles.ts @@ -0,0 +1,550 @@ +import { SystemRoles } from 'librechat-data-provider'; +import { logger, isValidObjectIdString, RoleConflictError } from '@librechat/data-schemas'; +import type { IRole, IUser, AdminMember } from '@librechat/data-schemas'; +import type { FilterQuery, Types } from 'mongoose'; +import type { Response } from 'express'; +import type { ServerRequest } from '~/types/http'; +import { parsePagination } from './pagination'; + +const systemRoleValues = new Set(Object.values(SystemRoles)); + +/** Case-insensitive check โ€” the legacy roles route uppercases params. */ +function isSystemRoleName(name: string): boolean { + return systemRoleValues.has(name.toUpperCase()); +} + +const MAX_NAME_LENGTH = 500; +const MAX_DESCRIPTION_LENGTH = 2000; +const CONTROL_CHAR_RE = /\p{Cc}/u; +/** + * Role names that would create semantically ambiguous URLs. + * e.g. GET /api/admin/roles/members โ€” is that "list roles" or "get role named members"? + * Express routing resolves this correctly (single vs multi-segment), but the URLs + * are confusing for API consumers. Keep in sync with sub-path routes in routes/admin/roles.js. 
+ */ +const RESERVED_ROLE_NAMES = new Set(['members', 'permissions']); + +function validateNameParam(name: string): string | null { + if (!name || typeof name !== 'string') { + return 'name parameter is required'; + } + if (name.length > MAX_NAME_LENGTH) { + return `name must not exceed ${MAX_NAME_LENGTH} characters`; + } + if (CONTROL_CHAR_RE.test(name)) { + return 'name contains invalid characters'; + } + return null; +} + +function validateRoleName(name: unknown, required: boolean): string | null { + if (name === undefined) { + return required ? 'name is required' : null; + } + if (typeof name !== 'string' || !name.trim()) { + return required ? 'name is required' : 'name must be a non-empty string'; + } + const trimmed = name.trim(); + if (trimmed.length > MAX_NAME_LENGTH) { + return `name must not exceed ${MAX_NAME_LENGTH} characters`; + } + if (CONTROL_CHAR_RE.test(trimmed)) { + return 'name contains invalid characters'; + } + if (RESERVED_ROLE_NAMES.has(trimmed)) { + return 'name is a reserved path segment'; + } + return null; +} + +function validateDescription(description: unknown): string | null { + if (description === undefined) { + return null; + } + if (typeof description !== 'string') { + return 'description must be a string'; + } + if (description.length > MAX_DESCRIPTION_LENGTH) { + return `description must not exceed ${MAX_DESCRIPTION_LENGTH} characters`; + } + return null; +} + +interface RoleNameParams { + name: string; +} + +interface RoleMemberParams extends RoleNameParams { + userId: string; +} + +export type RoleListItem = { _id: Types.ObjectId | string; name: string; description?: string }; + +export interface AdminRolesDeps { + listRoles: (options?: { limit?: number; offset?: number }) => Promise; + countRoles: () => Promise; + getRoleByName: (name: string, fields?: string | string[] | null) => Promise; + createRoleByName: (roleData: Partial) => Promise; + updateRoleByName: (name: string, updates: Partial) => Promise; + 
updateAccessPermissions: ( + name: string, + perms: Record>, + roleData?: IRole, + ) => Promise; + deleteRoleByName: (name: string) => Promise; + findUser: ( + criteria: FilterQuery, + fields?: string | string[] | null, + ) => Promise; + updateUser: (userId: string, data: Partial) => Promise; + updateUsersByRole: (oldRole: string, newRole: string) => Promise; + findUserIdsByRole: (roleName: string) => Promise; + updateUsersRoleByIds: (userIds: string[], newRole: string) => Promise; + listUsersByRole: ( + roleName: string, + options?: { limit?: number; offset?: number }, + ) => Promise; + countUsersByRole: (roleName: string) => Promise; +} + +export function createAdminRolesHandlers(deps: AdminRolesDeps) { + const { + listRoles, + countRoles, + getRoleByName, + createRoleByName, + updateRoleByName, + updateAccessPermissions, + deleteRoleByName, + findUser, + updateUser, + updateUsersByRole, + findUserIdsByRole, + updateUsersRoleByIds, + listUsersByRole, + countUsersByRole, + } = deps; + + async function listRolesHandler(req: ServerRequest, res: Response) { + try { + const { limit, offset } = parsePagination(req.query); + const [roles, total] = await Promise.all([listRoles({ limit, offset }), countRoles()]); + return res.status(200).json({ roles, total, limit, offset }); + } catch (error) { + logger.error('[adminRoles] listRoles error:', error); + return res.status(500).json({ error: 'Failed to list roles' }); + } + } + + async function getRoleHandler(req: ServerRequest, res: Response) { + try { + const { name } = req.params as RoleNameParams; + const paramError = validateNameParam(name); + if (paramError) { + return res.status(400).json({ error: paramError }); + } + const role = await getRoleByName(name); + if (!role) { + return res.status(404).json({ error: 'Role not found' }); + } + return res.status(200).json({ role }); + } catch (error) { + logger.error('[adminRoles] getRole error:', error); + return res.status(500).json({ error: 'Failed to get role' }); + } + } 
+ + async function createRoleHandler(req: ServerRequest, res: Response) { + try { + const { name, description, permissions } = req.body as { + name?: string; + description?: string; + permissions?: IRole['permissions']; + }; + const nameError = validateRoleName(name, true); + if (nameError) { + return res.status(400).json({ error: nameError }); + } + const descError = validateDescription(description); + if (descError) { + return res.status(400).json({ error: descError }); + } + if ( + permissions !== undefined && + (permissions === null || typeof permissions !== 'object' || Array.isArray(permissions)) + ) { + return res.status(400).json({ error: 'permissions must be an object' }); + } + const roleData: Partial = { + name: (name as string).trim(), + permissions: permissions ?? {}, + }; + if (description !== undefined) { + roleData.description = description; + } + const role = await createRoleByName(roleData); + return res.status(201).json({ role }); + } catch (error) { + logger.error('[adminRoles] createRole error:', error); + if (error instanceof RoleConflictError) { + return res.status(409).json({ error: error.message }); + } + return res.status(500).json({ error: 'Failed to create role' }); + } + } + + async function rollbackMigratedUsers( + migratedIds: string[], + currentName: string, + newName: string, + ): Promise { + if (migratedIds.length === 0) { + return; + } + try { + await updateUsersRoleByIds(migratedIds, currentName); + } catch (rollbackError) { + logger.error( + `[adminRoles] CRITICAL: rename rollback failed โ€” ${migratedIds.length} users have dangling role "${newName}": [${migratedIds.join(', ')}]`, + rollbackError, + ); + } + } + + /** + * Renames a role by migrating users to the new name and updating the role document. + * + * The ID snapshot from `findUserIdsByRole` is a point-in-time read. 
Users assigned + * to `currentName` between the snapshot and the bulk `updateUsersByRole` write will + * be moved to `newName` but will NOT be reverted on rollback. This window is narrow + * and only relevant under concurrent admin operations during a rename. + */ + async function renameRole( + currentName: string, + newName: string, + extraUpdates?: Partial, + ): Promise { + const migratedIds = await findUserIdsByRole(currentName); + await updateUsersByRole(currentName, newName); + try { + const updates: Partial = { name: newName, ...extraUpdates }; + const role = await updateRoleByName(currentName, updates); + if (!role) { + await rollbackMigratedUsers(migratedIds, currentName, newName); + } + return role; + } catch (error) { + await rollbackMigratedUsers(migratedIds, currentName, newName); + throw error; + } + } + + async function updateRoleHandler(req: ServerRequest, res: Response) { + try { + const { name } = req.params as RoleNameParams; + const paramError = validateNameParam(name); + if (paramError) { + return res.status(400).json({ error: paramError }); + } + const body = req.body as { name?: string; description?: string }; + const nameError = validateRoleName(body.name, false); + if (nameError) { + return res.status(400).json({ error: nameError }); + } + const descError = validateDescription(body.description); + if (descError) { + return res.status(400).json({ error: descError }); + } + + const trimmedName = body.name?.trim() ?? 
''; + const isRename = trimmedName !== '' && trimmedName !== name; + + if (isRename && isSystemRoleName(name)) { + return res.status(403).json({ error: 'Cannot rename system role' }); + } + if (isRename && isSystemRoleName(trimmedName)) { + return res.status(403).json({ error: 'Cannot use a reserved system role name' }); + } + + const existing = await getRoleByName(name); + if (!existing) { + return res.status(404).json({ error: 'Role not found' }); + } + + if (isRename) { + const duplicate = await getRoleByName(trimmedName); + if (duplicate) { + return res.status(409).json({ error: `Role "${trimmedName}" already exists` }); + } + } + + const updates: Partial = {}; + if (isRename) { + updates.name = trimmedName; + } + if (body.description !== undefined) { + updates.description = body.description; + } + + if (Object.keys(updates).length === 0) { + return res.status(200).json({ role: existing }); + } + + if (isRename) { + const descUpdate = + body.description !== undefined ? { description: body.description } : undefined; + const role = await renameRole(name, trimmedName, descUpdate); + if (!role) { + return res.status(404).json({ error: 'Role not found' }); + } + return res.status(200).json({ role }); + } + + const role = await updateRoleByName(name, updates); + if (!role) { + return res.status(404).json({ error: 'Role not found' }); + } + return res.status(200).json({ role }); + } catch (error) { + if (error instanceof RoleConflictError) { + return res.status(409).json({ error: error.message }); + } + logger.error('[adminRoles] updateRole error:', error); + return res.status(500).json({ error: 'Failed to update role' }); + } + } + + /** + * The re-fetch via `getRoleByName` after `updateAccessPermissions` depends on the + * callee having written the updated document to the role cache. If the cache layer + * is refactored to stop writing from within `updateAccessPermissions`, this handler + * must be updated to perform an explicit uncached DB read. 
+ */ + async function updateRolePermissionsHandler(req: ServerRequest, res: Response) { + try { + const { name } = req.params as RoleNameParams; + const paramError = validateNameParam(name); + if (paramError) { + return res.status(400).json({ error: paramError }); + } + const { permissions } = req.body as { + permissions: Record>; + }; + + if (!permissions || typeof permissions !== 'object' || Array.isArray(permissions)) { + return res.status(400).json({ error: 'permissions object is required' }); + } + + const existing = await getRoleByName(name); + if (!existing) { + return res.status(404).json({ error: 'Role not found' }); + } + + await updateAccessPermissions(name, permissions, existing); + const updated = await getRoleByName(name); + if (!updated) { + return res.status(404).json({ error: 'Role not found' }); + } + return res.status(200).json({ role: updated }); + } catch (error) { + logger.error('[adminRoles] updateRolePermissions error:', error); + return res.status(500).json({ error: 'Failed to update role permissions' }); + } + } + + async function deleteRoleHandler(req: ServerRequest, res: Response) { + try { + const { name } = req.params as RoleNameParams; + const paramError = validateNameParam(name); + if (paramError) { + return res.status(400).json({ error: paramError }); + } + if (isSystemRoleName(name)) { + return res.status(403).json({ error: 'Cannot delete system role' }); + } + + const deleted = await deleteRoleByName(name); + if (!deleted) { + return res.status(404).json({ error: 'Role not found' }); + } + return res.status(200).json({ success: true }); + } catch (error) { + logger.error('[adminRoles] deleteRole error:', error); + return res.status(500).json({ error: 'Failed to delete role' }); + } + } + + async function getRoleMembersHandler(req: ServerRequest, res: Response) { + try { + const { name } = req.params as RoleNameParams; + const paramError = validateNameParam(name); + if (paramError) { + return res.status(400).json({ error: 
paramError }); + } + const existing = await getRoleByName(name); + if (!existing) { + return res.status(404).json({ error: 'Role not found' }); + } + + const { limit, offset } = parsePagination(req.query); + + const [users, total] = await Promise.all([ + listUsersByRole(name, { limit, offset }), + countUsersByRole(name), + ]); + const members: AdminMember[] = users.map((u) => ({ + userId: u._id?.toString() ?? '', + name: u.name ?? u._id?.toString() ?? '', + email: u.email ?? '', + avatarUrl: u.avatar, + })); + return res.status(200).json({ members, total, limit, offset }); + } catch (error) { + logger.error('[adminRoles] getRoleMembers error:', error); + return res.status(500).json({ error: 'Failed to get role members' }); + } + } + + async function addRoleMemberHandler(req: ServerRequest, res: Response) { + try { + const { name } = req.params as RoleNameParams; + const paramError = validateNameParam(name); + if (paramError) { + return res.status(400).json({ error: paramError }); + } + const { userId } = req.body as { userId: string }; + + if (!userId || typeof userId !== 'string') { + return res.status(400).json({ error: 'userId is required' }); + } + if (!isValidObjectIdString(userId)) { + return res.status(400).json({ error: 'Invalid user ID format' }); + } + + if (isSystemRoleName(name) && name !== SystemRoles.ADMIN) { + return res.status(403).json({ error: 'Cannot directly assign members to a system role' }); + } + + const existing = await getRoleByName(name); + if (!existing) { + return res.status(404).json({ error: 'Role not found' }); + } + + const user = await findUser({ _id: userId }); + if (!user) { + return res.status(404).json({ error: 'User not found' }); + } + + if (user.role === name) { + return res.status(200).json({ success: true }); + } + + if (user.role === SystemRoles.ADMIN && name !== SystemRoles.ADMIN) { + const adminCount = await countUsersByRole(SystemRoles.ADMIN); + if (adminCount <= 1) { + return res.status(400).json({ error: 'Cannot 
remove the last admin user' }); + } + } + + const updated = await updateUser(userId, { role: name }); + if (!updated) { + return res.status(404).json({ error: 'User not found' }); + } + + if (user.role === SystemRoles.ADMIN && name !== SystemRoles.ADMIN) { + const postCount = await countUsersByRole(SystemRoles.ADMIN); + if (postCount === 0) { + try { + await updateUser(userId, { role: SystemRoles.ADMIN }); + } catch (rollbackError) { + logger.error( + `[adminRoles] CRITICAL: admin rollback failed in addRoleMember for user ${userId}:`, + rollbackError, + ); + } + return res.status(400).json({ error: 'Cannot remove the last admin user' }); + } + } + + return res.status(200).json({ success: true }); + } catch (error) { + logger.error('[adminRoles] addRoleMember error:', error); + return res.status(500).json({ error: 'Failed to add role member' }); + } + } + + async function removeRoleMemberHandler(req: ServerRequest, res: Response) { + try { + const { name, userId } = req.params as RoleMemberParams; + const paramError = validateNameParam(name); + if (paramError) { + return res.status(400).json({ error: paramError }); + } + if (!isValidObjectIdString(userId)) { + return res.status(400).json({ error: 'Invalid user ID format' }); + } + + if (isSystemRoleName(name) && name !== SystemRoles.ADMIN) { + return res.status(403).json({ error: 'Cannot remove members from a system role' }); + } + + const existing = await getRoleByName(name); + if (!existing) { + return res.status(404).json({ error: 'Role not found' }); + } + + const user = await findUser({ _id: userId }); + if (!user) { + return res.status(404).json({ error: 'User not found' }); + } + + if (user.role !== name) { + return res.status(400).json({ error: 'User is not a member of this role' }); + } + + if (name === SystemRoles.ADMIN) { + const adminCount = await countUsersByRole(SystemRoles.ADMIN); + if (adminCount <= 1) { + return res.status(400).json({ error: 'Cannot remove the last admin user' }); + } + } + + const 
removed = await updateUser(userId, { role: SystemRoles.USER }); + if (!removed) { + return res.status(404).json({ error: 'User not found' }); + } + + if (name === SystemRoles.ADMIN) { + const postCount = await countUsersByRole(SystemRoles.ADMIN); + if (postCount === 0) { + try { + await updateUser(userId, { role: SystemRoles.ADMIN }); + } catch (rollbackError) { + logger.error( + `[adminRoles] CRITICAL: admin rollback failed for user ${userId}:`, + rollbackError, + ); + } + return res.status(400).json({ error: 'Cannot remove the last admin user' }); + } + } + + return res.status(200).json({ success: true }); + } catch (error) { + logger.error('[adminRoles] removeRoleMember error:', error); + return res.status(500).json({ error: 'Failed to remove role member' }); + } + } + + return { + listRoles: listRolesHandler, + getRole: getRoleHandler, + createRole: createRoleHandler, + updateRole: updateRoleHandler, + updateRolePermissions: updateRolePermissionsHandler, + deleteRole: deleteRoleHandler, + getRoleMembers: getRoleMembersHandler, + addRoleMember: addRoleMemberHandler, + removeRoleMember: removeRoleMemberHandler, + }; +} diff --git a/packages/data-schemas/package.json b/packages/data-schemas/package.json index 0124552002..145b8925d1 100644 --- a/packages/data-schemas/package.json +++ b/packages/data-schemas/package.json @@ -1,6 +1,6 @@ { "name": "@librechat/data-schemas", - "version": "0.0.46", + "version": "0.0.47", "description": "Mongoose schemas and models for LibreChat", "type": "module", "main": "dist/index.cjs", diff --git a/packages/data-schemas/src/index.ts b/packages/data-schemas/src/index.ts index cd683c937c..d673db1f5c 100644 --- a/packages/data-schemas/src/index.ts +++ b/packages/data-schemas/src/index.ts @@ -7,6 +7,7 @@ export * from './utils'; export { createModels } from './models'; export { createMethods, + RoleConflictError, DEFAULT_REFRESH_TOKEN_EXPIRY, DEFAULT_SESSION_EXPIRY, tokenValues, diff --git a/packages/data-schemas/src/methods/index.ts 
b/packages/data-schemas/src/methods/index.ts index 4202cac0eb..830d88ff4c 100644 --- a/packages/data-schemas/src/methods/index.ts +++ b/packages/data-schemas/src/methods/index.ts @@ -1,9 +1,8 @@ import { createSessionMethods, DEFAULT_REFRESH_TOKEN_EXPIRY, type SessionMethods } from './session'; import { createTokenMethods, type TokenMethods } from './token'; -import { createRoleMethods, type RoleMethods, type RoleDeps } from './role'; +import { createRoleMethods, RoleConflictError } from './role'; +import type { RoleMethods, RoleDeps } from './role'; import { createUserMethods, DEFAULT_SESSION_EXPIRY, type UserMethods } from './user'; - -export { DEFAULT_REFRESH_TOKEN_EXPIRY, DEFAULT_SESSION_EXPIRY }; import { createKeyMethods, type KeyMethods } from './key'; import { createFileMethods, type FileMethods } from './file'; /* Memories */ @@ -51,6 +50,7 @@ import { createAgentMethods, type AgentMethods, type AgentDeps } from './agent'; /* Config */ import { createConfigMethods, type ConfigMethods } from './config'; +export { RoleConflictError, DEFAULT_REFRESH_TOKEN_EXPIRY, DEFAULT_SESSION_EXPIRY }; export { tokenValues, cacheTokenValues, premiumTokenValues, defaultRate }; export type AllMethods = UserMethods & diff --git a/packages/data-schemas/src/methods/role.methods.spec.ts b/packages/data-schemas/src/methods/role.methods.spec.ts index 78d7f98ea1..f8a66bef5d 100644 --- a/packages/data-schemas/src/methods/role.methods.spec.ts +++ b/packages/data-schemas/src/methods/role.methods.spec.ts @@ -1,10 +1,17 @@ import mongoose from 'mongoose'; import { MongoMemoryServer } from 'mongodb-memory-server'; import { SystemRoles, Permissions, roleDefaults, PermissionTypes } from 'librechat-data-provider'; -import type { IRole, RolePermissions } from '..'; +import type { IRole, IUser, RolePermissions } from '..'; import { createRoleMethods } from './role'; import { createModels } from '../models'; +jest.mock('~/config/winston', () => ({ + error: jest.fn(), + warn: jest.fn(), + info: 
jest.fn(), + debug: jest.fn(), +})); + const mockCache = { get: jest.fn(), set: jest.fn(), @@ -14,9 +21,18 @@ const mockCache = { const mockGetCache = jest.fn().mockReturnValue(mockCache); let Role: mongoose.Model; +let User: mongoose.Model; let getRoleByName: ReturnType['getRoleByName']; let updateAccessPermissions: ReturnType['updateAccessPermissions']; let initializeRoles: ReturnType['initializeRoles']; +let createRoleByName: ReturnType['createRoleByName']; +let deleteRoleByName: ReturnType['deleteRoleByName']; +let updateUsersByRole: ReturnType['updateUsersByRole']; +let listUsersByRole: ReturnType['listUsersByRole']; +let countUsersByRole: ReturnType['countUsersByRole']; +let updateRoleByName: ReturnType['updateRoleByName']; +let listRoles: ReturnType['listRoles']; +let countRoles: ReturnType['countRoles']; let mongoServer: MongoMemoryServer; beforeAll(async () => { @@ -25,10 +41,19 @@ beforeAll(async () => { await mongoose.connect(mongoUri); createModels(mongoose); Role = mongoose.models.Role; + User = mongoose.models.User as mongoose.Model; const methods = createRoleMethods(mongoose, { getCache: mockGetCache }); getRoleByName = methods.getRoleByName; updateAccessPermissions = methods.updateAccessPermissions; initializeRoles = methods.initializeRoles; + createRoleByName = methods.createRoleByName; + deleteRoleByName = methods.deleteRoleByName; + updateRoleByName = methods.updateRoleByName; + updateUsersByRole = methods.updateUsersByRole; + listUsersByRole = methods.listUsersByRole; + countUsersByRole = methods.countUsersByRole; + listRoles = methods.listRoles; + countRoles = methods.countRoles; }); afterAll(async () => { @@ -38,6 +63,7 @@ afterAll(async () => { beforeEach(async () => { await Role.deleteMany({}); + await User.deleteMany({}); mockGetCache.mockClear(); mockCache.get.mockClear(); mockCache.set.mockClear(); @@ -515,3 +541,362 @@ describe('initializeRoles', () => { expect(userRole.permissions[PermissionTypes.MULTI_CONVO]?.USE).toBeDefined(); }); 
}); + +describe('createRoleByName', () => { + it('creates a custom role and caches it', async () => { + const role = await createRoleByName({ name: 'editor', description: 'Can edit' }); + + expect(role.name).toBe('editor'); + expect(role.description).toBe('Can edit'); + expect(mockCache.set).toHaveBeenCalledWith( + 'editor', + expect.objectContaining({ name: 'editor' }), + ); + + const persisted = await Role.findOne({ name: 'editor' }).lean(); + expect(persisted).toBeTruthy(); + }); + + it('trims whitespace from role name', async () => { + const role = await createRoleByName({ name: ' editor ' }); + + expect(role.name).toBe('editor'); + }); + + it('throws when name is empty', async () => { + await expect(createRoleByName({ name: '' })).rejects.toThrow('Role name is required'); + }); + + it('throws when name is whitespace-only', async () => { + await expect(createRoleByName({ name: ' ' })).rejects.toThrow('Role name is required'); + }); + + it('throws when name is undefined', async () => { + await expect(createRoleByName({})).rejects.toThrow('Role name is required'); + }); + + it('throws for reserved system role names', async () => { + await expect(createRoleByName({ name: SystemRoles.ADMIN })).rejects.toThrow( + /reserved system name/, + ); + await expect(createRoleByName({ name: SystemRoles.USER })).rejects.toThrow( + /reserved system name/, + ); + }); + + it('throws when role already exists', async () => { + await createRoleByName({ name: 'editor' }); + + await expect(createRoleByName({ name: 'editor' })).rejects.toThrow(/already exists/); + }); +}); + +describe('deleteRoleByName', () => { + it('deletes a custom role and reassigns users to USER', async () => { + await createRoleByName({ name: 'editor' }); + await User.create([ + { name: 'Alice', email: 'alice@test.com', role: 'editor', username: 'alice' }, + { name: 'Bob', email: 'bob@test.com', role: 'editor', username: 'bob' }, + { name: 'Carol', email: 'carol@test.com', role: SystemRoles.USER, username: 
'carol' }, + ]); + + const deleted = await deleteRoleByName('editor'); + + expect(deleted).toBeTruthy(); + expect(deleted!.name).toBe('editor'); + + const alice = await User.findOne({ email: 'alice@test.com' }).lean(); + const bob = await User.findOne({ email: 'bob@test.com' }).lean(); + const carol = await User.findOne({ email: 'carol@test.com' }).lean(); + expect(alice!.role).toBe(SystemRoles.USER); + expect(bob!.role).toBe(SystemRoles.USER); + expect(carol!.role).toBe(SystemRoles.USER); + }); + + it('returns null when role does not exist', async () => { + const result = await deleteRoleByName('nonexistent'); + expect(result).toBeNull(); + }); + + it('throws for system roles', async () => { + await expect(deleteRoleByName(SystemRoles.ADMIN)).rejects.toThrow(/Cannot delete system role/); + await expect(deleteRoleByName(SystemRoles.USER)).rejects.toThrow(/Cannot delete system role/); + }); + + it('sets cache entry to null after deletion', async () => { + await createRoleByName({ name: 'editor' }); + mockCache.set.mockClear(); + + await deleteRoleByName('editor'); + + expect(mockCache.set).toHaveBeenCalledWith('editor', null); + }); + + it('returns null and invalidates cache when role does not exist', async () => { + mockCache.set.mockClear(); + + const result = await deleteRoleByName('nonexistent'); + + expect(result).toBeNull(); + expect(mockCache.set).toHaveBeenCalledWith('nonexistent', null); + }); +}); + +describe('updateRoleByName - cache on rename', () => { + it('invalidates old key and populates new key on rename', async () => { + await createRoleByName({ name: 'editor', description: 'Can edit' }); + mockCache.set.mockClear(); + + const updated = await updateRoleByName('editor', { name: 'senior-editor' }); + + expect(updated.name).toBe('senior-editor'); + expect(mockCache.set).toHaveBeenCalledWith('editor', null); + expect(mockCache.set).toHaveBeenCalledWith( + 'senior-editor', + expect.objectContaining({ name: 'senior-editor' }), + ); + }); + + it('writes 
same key when name unchanged', async () => { + await createRoleByName({ name: 'editor' }); + mockCache.set.mockClear(); + + await updateRoleByName('editor', { description: 'Updated desc' }); + + expect(mockCache.set).toHaveBeenCalledWith( + 'editor', + expect.objectContaining({ name: 'editor', description: 'Updated desc' }), + ); + expect(mockCache.set).toHaveBeenCalledTimes(1); + }); +}); + +describe('listUsersByRole', () => { + it('returns users matching the role', async () => { + await User.create([ + { name: 'Alice', email: 'alice@test.com', role: 'editor', username: 'alice' }, + { name: 'Bob', email: 'bob@test.com', role: 'editor', username: 'bob' }, + { name: 'Carol', email: 'carol@test.com', role: SystemRoles.USER, username: 'carol' }, + ]); + + const users = await listUsersByRole('editor'); + + expect(users).toHaveLength(2); + const names = users.map((u) => u.name).sort(); + expect(names).toEqual(['Alice', 'Bob']); + }); + + it('returns empty array when no users have the role', async () => { + const users = await listUsersByRole('nonexistent'); + expect(users).toEqual([]); + }); + + it('respects limit and offset for pagination', async () => { + await User.create([ + { name: 'Alice', email: 'a@test.com', role: 'editor', username: 'a' }, + { name: 'Bob', email: 'b@test.com', role: 'editor', username: 'b' }, + { name: 'Carol', email: 'c@test.com', role: 'editor', username: 'c' }, + { name: 'Dave', email: 'd@test.com', role: 'editor', username: 'd' }, + { name: 'Eve', email: 'e@test.com', role: 'editor', username: 'e' }, + ]); + + const page1 = await listUsersByRole('editor', { limit: 2, offset: 0 }); + const page2 = await listUsersByRole('editor', { limit: 2, offset: 2 }); + const page3 = await listUsersByRole('editor', { limit: 2, offset: 4 }); + + expect(page1).toHaveLength(2); + expect(page2).toHaveLength(2); + expect(page3).toHaveLength(1); + + const allIds = [...page1, ...page2, ...page3].map((u) => u._id!.toString()); + expect(new 
Set(allIds).size).toBe(5); + }); + + it('selects only expected fields', async () => { + await User.create({ + name: 'Alice', + email: 'alice@test.com', + role: 'editor', + username: 'alice', + password: 'secret123', + }); + + const users = await listUsersByRole('editor'); + + expect(users).toHaveLength(1); + expect(users[0].name).toBe('Alice'); + expect(users[0].email).toBe('alice@test.com'); + expect(users[0]._id).toBeDefined(); + expect('password' in users[0]).toBe(false); + expect('username' in users[0]).toBe(false); + }); +}); + +describe('updateUsersByRole', () => { + it('migrates all users from one role to another', async () => { + await User.create([ + { name: 'Alice', email: 'alice@test.com', role: 'editor', username: 'alice' }, + { name: 'Bob', email: 'bob@test.com', role: 'editor', username: 'bob' }, + { name: 'Carol', email: 'carol@test.com', role: SystemRoles.USER, username: 'carol' }, + ]); + + await updateUsersByRole('editor', 'senior-editor'); + + const alice = await User.findOne({ email: 'alice@test.com' }).lean(); + const bob = await User.findOne({ email: 'bob@test.com' }).lean(); + const carol = await User.findOne({ email: 'carol@test.com' }).lean(); + expect(alice!.role).toBe('senior-editor'); + expect(bob!.role).toBe('senior-editor'); + expect(carol!.role).toBe(SystemRoles.USER); + }); + + it('is a no-op when no users have the source role', async () => { + await User.create({ + name: 'Alice', + email: 'alice@test.com', + role: SystemRoles.USER, + username: 'alice', + }); + + await updateUsersByRole('nonexistent', 'new-role'); + + const alice = await User.findOne({ email: 'alice@test.com' }).lean(); + expect(alice!.role).toBe(SystemRoles.USER); + }); +}); + +describe('countUsersByRole', () => { + it('returns the count of users with the given role', async () => { + await User.create([ + { name: 'Alice', email: 'alice@test.com', role: 'editor', username: 'alice' }, + { name: 'Bob', email: 'bob@test.com', role: 'editor', username: 'bob' }, + { name: 
'Carol', email: 'carol@test.com', role: SystemRoles.USER, username: 'carol' }, + ]); + + expect(await countUsersByRole('editor')).toBe(2); + expect(await countUsersByRole(SystemRoles.USER)).toBe(1); + }); + + it('returns 0 when no users have the role', async () => { + expect(await countUsersByRole('nonexistent')).toBe(0); + }); +}); + +describe('listRoles', () => { + beforeEach(async () => { + await Role.deleteMany({}); + }); + + it('returns roles sorted alphabetically by name', async () => { + await Role.create([ + { name: 'zebra', permissions: {} }, + { name: 'alpha', permissions: {} }, + { name: 'middle', permissions: {} }, + ]); + + const roles = await listRoles(); + + expect(roles.map((r) => r.name)).toEqual(['alpha', 'middle', 'zebra']); + }); + + it('respects limit and offset for pagination', async () => { + await Role.create([ + { name: 'a-role', permissions: {} }, + { name: 'b-role', permissions: {} }, + { name: 'c-role', permissions: {} }, + { name: 'd-role', permissions: {} }, + { name: 'e-role', permissions: {} }, + ]); + + const page1 = await listRoles({ limit: 2, offset: 0 }); + const page2 = await listRoles({ limit: 2, offset: 2 }); + const page3 = await listRoles({ limit: 2, offset: 4 }); + + expect(page1).toHaveLength(2); + expect(page1.map((r) => r.name)).toEqual(['a-role', 'b-role']); + expect(page2).toHaveLength(2); + expect(page2.map((r) => r.name)).toEqual(['c-role', 'd-role']); + expect(page3).toHaveLength(1); + expect(page3.map((r) => r.name)).toEqual(['e-role']); + }); + + it('defaults to limit 50 and offset 0', async () => { + await Role.create({ name: 'only-role', permissions: {} }); + + const roles = await listRoles(); + + expect(roles).toHaveLength(1); + expect(roles[0].name).toBe('only-role'); + }); + + it('returns only name and description fields', async () => { + await Role.create({ + name: 'editor', + description: 'Can edit', + permissions: { PROMPTS: { USE: true } }, + }); + + const roles = await listRoles(); + + 
expect(roles).toHaveLength(1); + expect(roles[0].name).toBe('editor'); + expect(roles[0].description).toBe('Can edit'); + expect(roles[0]._id).toBeDefined(); + expect('permissions' in roles[0]).toBe(false); + }); + + it('returns empty array when no roles exist', async () => { + const roles = await listRoles(); + expect(roles).toEqual([]); + }); + + it('returns undefined description for pre-existing roles without the field', async () => { + await Role.collection.insertOne({ name: 'legacy', permissions: {} }); + + const roles = await listRoles(); + + expect(roles).toHaveLength(1); + expect(roles[0].name).toBe('legacy'); + expect(roles[0].description).toBeUndefined(); + }); +}); + +describe('countRoles', () => { + beforeEach(async () => { + await Role.deleteMany({}); + }); + + it('returns the total number of roles', async () => { + await Role.create([ + { name: 'a', permissions: {} }, + { name: 'b', permissions: {} }, + { name: 'c', permissions: {} }, + ]); + + expect(await countRoles()).toBe(3); + }); + + it('returns 0 when no roles exist', async () => { + expect(await countRoles()).toBe(0); + }); +}); + +describe('createRoleByName - duplicate key race', () => { + beforeEach(async () => { + await Role.deleteMany({}); + }); + + it('throws RoleConflictError on concurrent insert (11000)', async () => { + await createRoleByName({ name: 'editor' }); + + const insertSpy = jest.spyOn(Role.prototype, 'save').mockImplementationOnce(() => { + const err = new Error('E11000 duplicate key error') as Error & { code: number }; + err.code = 11000; + throw err; + }); + + await expect(createRoleByName({ name: 'editor2' })).rejects.toThrow(/already exists/); + + insertSpy.mockRestore(); + }); +}); diff --git a/packages/data-schemas/src/methods/role.ts b/packages/data-schemas/src/methods/role.ts index 7b51e45330..442041dcde 100644 --- a/packages/data-schemas/src/methods/role.ts +++ b/packages/data-schemas/src/methods/role.ts @@ -5,9 +5,24 @@ import { permissionsSchema, 
removeNullishValues, } from 'librechat-data-provider'; -import type { IRole } from '~/types'; +import type { Model } from 'mongoose'; +import type { IRole, IUser } from '~/types'; import logger from '~/config/winston'; +const systemRoleValues = new Set(Object.values(SystemRoles)); + +/** Case-insensitive check — the legacy roles route uppercases params. */ +function isSystemRoleName(name: string): boolean { + return systemRoleValues.has(name.toUpperCase()); +} + +export class RoleConflictError extends Error { + constructor(message: string) { + super(message); + this.name = 'RoleConflictError'; + } +} + export interface RoleDeps { /** Returns a cache store for the given key. Injected from getLogStores. */ getCache?: (key: string) => { @@ -30,8 +45,11 @@ export function createRoleMethods(mongoose: typeof import('mongoose'), deps: Rol const defaultPerms = roleDefaults[roleName].permissions; if (!role) { - role = new Role(roleDefaults[roleName]); + role = new Role({ ...roleDefaults[roleName], description: '' }); } else { + if (role.description == null) { + role.description = ''; + } const permissions = role.toObject()?.permissions ?? {}; role.permissions = role.permissions || {}; for (const permType of Object.keys(defaultPerms)) { @@ -45,11 +63,26 @@ export function createRoleMethods(mongoose: typeof import('mongoose'), deps: Rol } /** - * List all roles in the system. + * List all roles in the system. Returns only name and description (projected). */ - async function listRoles() { + async function listRoles(options?: { + limit?: number; + offset?: number; + }): Promise[]> { const Role = mongoose.models.Role; - return await Role.find({}).select('name permissions').lean(); + const limit = options?.limit ?? 50; + const offset = options?.offset ?? 
0; + return await Role.find({}) + .select('name description') + .sort({ name: 1 }) + .skip(offset) + .limit(limit) + .lean(); + } + + async function countRoles(): Promise { + const Role = mongoose.models.Role; + return await Role.countDocuments({}); } /** @@ -73,7 +106,7 @@ export function createRoleMethods(mongoose: typeof import('mongoose'), deps: Rol } const role = await query.lean().exec(); - if (!role && SystemRoles[roleName as keyof typeof SystemRoles]) { + if (!role && systemRoleValues.has(roleName)) { const newRole = await new Role(roleDefaults[roleName as keyof typeof roleDefaults]).save(); if (cache) { await cache.set(roleName, newRole); @@ -96,20 +129,24 @@ export function createRoleMethods(mongoose: typeof import('mongoose'), deps: Rol const cache = deps.getCache?.(CacheKeys.ROLES); try { const Role = mongoose.models.Role; - const role = await Role.findOneAndUpdate( - { name: roleName }, - { $set: updates }, - { new: true, lean: true }, - ) + const role = await Role.findOneAndUpdate({ name: roleName }, { $set: updates }, { new: true }) .select('-__v') .lean() .exec(); if (cache) { - await cache.set(roleName, role); + if (updates.name && updates.name !== roleName) { + await Promise.all([cache.set(roleName, null), cache.set(updates.name, role)]); + } else { + await cache.set(roleName, role); + } } return role as unknown as IRole; } catch (error) { - throw new Error(`Failed to update role: ${(error as Error).message}`); + if (error && typeof error === 'object' && 'code' in error && error.code === 11000) { + const targetName = updates.name ?? roleName; + throw new RoleConflictError(`Role "${targetName}" already exists`); + } + throw new Error(`Failed to update role: ${(error as Error).message}`, { cause: error }); } } @@ -342,13 +379,137 @@ export function createRoleMethods(mongoose: typeof import('mongoose'), deps: Rol } } + /** Rejects names that match system roles. 
*/ + async function createRoleByName(roleData: Partial): Promise { + const { name } = roleData; + if (!name || typeof name !== 'string' || !name.trim()) { + throw new Error('Role name is required'); + } + const trimmed = name.trim(); + if (isSystemRoleName(trimmed)) { + throw new RoleConflictError(`Cannot create role with reserved system name: ${name}`); + } + const Role = mongoose.models.Role; + const existing = await Role.findOne({ name: trimmed }).lean(); + if (existing) { + throw new RoleConflictError(`Role "${trimmed}" already exists`); + } + let role; + try { + role = await new Role({ ...roleData, name: trimmed }).save(); + } catch (err) { + /** + * The compound unique index `{ name: 1, tenantId: 1 }` on the role schema + * (roleSchema.index in schema/role.ts) triggers error 11000 when a concurrent + * request races past the findOne check above. This catch converts it into + * the same user-facing message as the application-level duplicate check. + */ + if (err && typeof err === 'object' && 'code' in err && err.code === 11000) { + throw new RoleConflictError(`Role "${trimmed}" already exists`); + } + throw err; + } + try { + const cache = deps.getCache?.(CacheKeys.ROLES); + if (cache) { + await cache.set(role.name, role.toObject()); + } + } catch (cacheError) { + logger.error(`[createRoleByName] cache set failed for "${role.name}":`, cacheError); + } + return role.toObject() as IRole; + } + + /** + * Guards against deleting system roles. Reassigns affected users back to USER. + * + * No existence pre-check is performed: for a nonexistent role the `updateMany` + * is a harmless no-op and `findOneAndDelete` returns null. This makes the + * function idempotent — a retry after a partial failure will still clean up + * orphaned user references and cache entries. 
+ * + * Without a MongoDB transaction the two writes are non-atomic — if the delete + * fails after the reassignment, users will already have been moved to USER + * while the role document still exists. Recovery requires the caller to retry + * the delete call, which will succeed since the `updateMany` is a no-op on + * the second pass. + */ + async function deleteRoleByName(roleName: string): Promise { + if (isSystemRoleName(roleName)) { + throw new Error(`Cannot delete system role: ${roleName}`); + } + const Role = mongoose.models.Role; + const User = mongoose.models.User as Model; + await User.updateMany({ role: roleName }, { $set: { role: SystemRoles.USER } }); + const deleted = await Role.findOneAndDelete({ name: roleName }).lean(); + try { + const cache = deps.getCache?.(CacheKeys.ROLES); + if (cache) { + // Setting null evicts the stale document. getRoleByName treats falsy cached + // values as a miss and falls through to the DB, so this does not provide + // negative caching — it only prevents serving the pre-deletion document. 
+ await cache.set(roleName, null); + } + } catch (cacheError) { + logger.error(`[deleteRoleByName] cache invalidation failed for "${roleName}":`, cacheError); + } + return deleted as IRole | null; + } + + async function updateUsersByRole(oldRole: string, newRole: string): Promise { + const User = mongoose.models.User as Model; + await User.updateMany({ role: oldRole }, { $set: { role: newRole } }); + } + + async function findUserIdsByRole(roleName: string): Promise { + const User = mongoose.models.User as Model; + const users = await User.find({ role: roleName }).select('_id').lean(); + return users.map((u) => u._id.toString()); + } + + async function updateUsersRoleByIds(userIds: string[], newRole: string): Promise { + if (userIds.length === 0) { + return; + } + const User = mongoose.models.User as Model; + await User.updateMany({ _id: { $in: userIds } }, { $set: { role: newRole } }); + } + + async function listUsersByRole( + roleName: string, + options?: { limit?: number; offset?: number }, + ): Promise { + const User = mongoose.models.User as Model; + const limit = options?.limit ?? 50; + const offset = options?.offset ?? 
0; + return await User.find({ role: roleName }) + .select('_id name email avatar') + .sort({ _id: 1 }) + .skip(offset) + .limit(limit) + .lean(); + } + + async function countUsersByRole(roleName: string): Promise { + const User = mongoose.models.User as Model; + return await User.countDocuments({ role: roleName }); + } + return { listRoles, + countRoles, initializeRoles, getRoleByName, updateRoleByName, updateAccessPermissions, migrateRoleSchema, + createRoleByName, + deleteRoleByName, + updateUsersByRole, + findUserIdsByRole, + updateUsersRoleByIds, + listUsersByRole, + countUsersByRole, }; } diff --git a/packages/data-schemas/src/schema/role.ts b/packages/data-schemas/src/schema/role.ts index 1c27478ef6..ac478c2a83 100644 --- a/packages/data-schemas/src/schema/role.ts +++ b/packages/data-schemas/src/schema/role.ts @@ -73,6 +73,7 @@ const rolePermissionsSchema = new Schema( const roleSchema: Schema = new Schema({ name: { type: String, required: true, index: true }, + description: { type: String, default: '' }, permissions: { type: rolePermissionsSchema, }, diff --git a/packages/data-schemas/src/schema/user.ts b/packages/data-schemas/src/schema/user.ts index 92680415bd..f807ddd8d6 100644 --- a/packages/data-schemas/src/schema/user.ts +++ b/packages/data-schemas/src/schema/user.ts @@ -158,6 +158,7 @@ const userSchema = new Schema( ); userSchema.index({ email: 1, tenantId: 1 }, { unique: true }); +userSchema.index({ role: 1, tenantId: 1 }); const oAuthIdFields = [ 'googleId', diff --git a/packages/data-schemas/src/types/admin.ts b/packages/data-schemas/src/types/admin.ts index 99915f659d..9b30cdb98a 100644 --- a/packages/data-schemas/src/types/admin.ts +++ b/packages/data-schemas/src/types/admin.ts @@ -114,7 +114,7 @@ export type AdminMember = { name: string; email: string; avatarUrl?: string; - joinedAt: string; + joinedAt?: string; }; /** Minimal user info returned by user search endpoints. 
*/ diff --git a/packages/data-schemas/src/types/role.ts b/packages/data-schemas/src/types/role.ts index 60a579240c..bc85284c34 100644 --- a/packages/data-schemas/src/types/role.ts +++ b/packages/data-schemas/src/types/role.ts @@ -5,6 +5,7 @@ import { CursorPaginationParams } from '~/common'; export interface IRole extends Document { name: string; + description?: string; permissions: { [PermissionTypes.BOOKMARKS]?: { [Permissions.USE]?: boolean; @@ -74,11 +75,13 @@ export type RolePermissionsInput = DeepPartial; export interface CreateRoleRequest { name: string; + description?: string; permissions: RolePermissionsInput; } export interface UpdateRoleRequest { name?: string; + description?: string; permissions?: RolePermissionsInput; } From 77712c825f5c6d9303007880ad445adfafd96955 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Fri, 27 Mar 2026 16:08:43 -0400 Subject: [PATCH 11/18] =?UTF-8?q?=F0=9F=8F=A2=20feat:=20Tenant-Scoped=20Ap?= =?UTF-8?q?p=20Config=20in=20Auth=20Login=20Flows=20(#12434)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add resolveAppConfigForUser utility for tenant-scoped auth config TypeScript utility in packages/api that wraps getAppConfig in tenantStorage.run() when the user has a tenantId, falling back to baseOnly for new users or non-tenant deployments. Uses DI pattern (getAppConfig passed as parameter) for testability. Auth flows apply role-level overrides only (userId not passed) because user/group principal resolution is deferred to post-auth. * feat: tenant-scoped app config in auth login flows All auth strategies (LDAP, SAML, OpenID, social login) now use a two-phase domain check consistent with requestPasswordReset: 1. Fast-fail with base config (memory-cached, zero DB queries) 2. DB user lookup 3. 
Tenant-scoped re-check via resolveAppConfigForUser (only when user has a tenantId; otherwise reuse base config) This preserves the original fast-fail protection against globally blocked domains while enabling tenant-specific config overrides. OpenID error ordering preserved: AUTH_FAILED checked before domain re-check so users with wrong providers get the correct error type. registerUser unchanged (baseOnly, no user identity yet). * test: add tenant-scoped config tests for auth strategies Add resolveAppConfig.spec.ts in packages/api with 8 tests: - baseOnly fallback for null/undefined/no-tenant users - tenant-scoped config with role and tenantId - ALS context propagation verified inside getAppConfig callback - undefined role with tenantId edge case Update strategy and AuthService tests to mock resolveAppConfigForUser via @librechat/api. Tests verify two-phase domain check behavior: fast-fail before DB, tenant re-check after. Non-tenant users reuse base config without calling resolveAppConfigForUser. * refactor: skip redundant domain re-check for non-tenant users Guard the second isEmailDomainAllowed call with appConfig !== baseConfig in SAML, OpenID, and social strategies. For non-tenant users the tenant config is the same base config object, so the second check is a no-op. Narrow eslint-disable in resolveAppConfig.spec.ts to the specific require line instead of blanket file-level suppression. 
* fix: address review findings — consistency, tests, and ordering - Consolidate duplicate require('@librechat/api') in AuthService.js - Add two-phase domain check to LDAP (base fast-fail before findUser), making all strategies consistent with PR description - Add appConfig !== baseConfig guard to requestPasswordReset second domain check, consistent with SAML/OpenID/social strategies - Move SAML provider check before tenant config resolution to avoid unnecessary resolveAppConfigForUser call for wrong-provider users - Add tenant domain rejection tests to SAML, OpenID, and social specs verifying that tenant config restrictions actually block login - Add error propagation tests to resolveAppConfig.spec.ts - Remove redundant mockTenantStorage alias in resolveAppConfig.spec.ts - Narrow eslint-disable to specific require line * test: add tenant domain rejection test for LDAP strategy Covers the appConfig !== baseConfig && !isEmailDomainAllowed path, consistent with SAML, OpenID, and social strategy specs. * refactor: rename resolveAppConfig to app/resolve per AGENTS.md Rename resolveAppConfig.ts → resolve.ts and resolveAppConfig.spec.ts → resolve.spec.ts to align with the project's concise naming convention. * fix: remove fragile reference-equality guard, add logging and docs Remove appConfig !== baseConfig guard from all strategies and requestPasswordReset. The guard relied on implicit cache-backend identity semantics (in-memory Keyv returns same object reference) that would silently break with Redis or cloned configs. The second isEmailDomainAllowed call is a cheap synchronous check — always running it is clearer and eliminates the coupling. Add audit logging to requestPasswordReset domain blocks (base and tenant), consistent with all auth strategies. Extract duplicated error construction into makeDomainDeniedError(). 
Wrap resolveAppConfigForUser in requestPasswordReset with try/catch to prevent DB errors from leaking to the client via the controller's generic catch handler. Document the dual tenantId propagation (ALS for DB isolation, explicit param for cache key) in resolveAppConfigForUser JSDoc. Add comment documenting the LDAP error-type ordering change (cross-provider users from blocked domains now get 'domain not allowed' instead of AUTH_FAILED). Assert resolveAppConfigForUser is not called on LDAP provider mismatch path. * fix: return generic response for tenant domain block in password reset Tenant-scoped domain rejection in requestPasswordReset now returns the same generic "If an account with that email exists..." response instead of an Error. This prevents user-enumeration: an attacker cannot distinguish between "email not found" and "tenant blocks this domain" by comparing HTTP responses. The base-config fast-fail (pre-user-lookup) still returns an Error since it fires before any user existence is revealed. * docs: document phase 1 vs phase 2 domain check behavior in JSDoc Phase 1 (base config, pre-findUser) intentionally returns Error/400 to reveal globally blocked domains without confirming user existence. Phase 2 (tenant config, post-findUser) returns generic 200 to prevent user-enumeration. This distinction is now explicit in the JSDoc. 
--- api/server/services/AuthService.js | 42 +- api/server/services/AuthService.spec.js | 76 +- api/strategies/ldapStrategy.js | 43 +- api/strategies/ldapStrategy.spec.js | 71 +- api/strategies/openidStrategy.js | 15 +- api/strategies/openidStrategy.spec.js | 3695 ++++++++++++----------- api/strategies/samlStrategy.js | 21 +- api/strategies/samlStrategy.spec.js | 50 + api/strategies/socialLogin.js | 21 +- api/strategies/socialLogin.test.js | 139 +- packages/api/src/app/index.ts | 1 + packages/api/src/app/resolve.spec.ts | 95 + packages/api/src/app/resolve.ts | 39 + 13 files changed, 2428 insertions(+), 1880 deletions(-) create mode 100644 packages/api/src/app/resolve.spec.ts create mode 100644 packages/api/src/app/resolve.ts diff --git a/api/server/services/AuthService.js b/api/server/services/AuthService.js index f17c5051a9..816a0eac5b 100644 --- a/api/server/services/AuthService.js +++ b/api/server/services/AuthService.js @@ -13,6 +13,7 @@ const { checkEmailConfig, isEmailDomainAllowed, shouldUseSecureCookie, + resolveAppConfigForUser, } = require('@librechat/api'); const { findUser, @@ -255,19 +256,52 @@ const registerUser = async (user, additionalData = {}) => { }; /** - * Request password reset + * Request password reset. + * + * Uses a two-phase domain check: fast-fail with the memory-cached base config + * (zero DB queries) to block globally denied domains before user lookup, then + * re-check with tenant-scoped config after user lookup so tenant-specific + * restrictions are enforced. + * + * Phase 1 (base check) returns an Error (HTTP 400) โ€” this intentionally reveals + * that the domain is globally blocked, but fires before any DB lookup so it + * cannot confirm user existence. Phase 2 (tenant check) returns the generic + * success message (HTTP 200) to prevent user-enumeration via status codes. 
+ * * @param {ServerRequest} req */ const requestPasswordReset = async (req) => { const { email } = req.body; - const appConfig = await getAppConfig({ baseOnly: true }); - if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + + const baseConfig = await getAppConfig({ baseOnly: true }); + if (!isEmailDomainAllowed(email, baseConfig?.registration?.allowedDomains)) { + logger.warn( + `[requestPasswordReset] Blocked - email domain not allowed [Email: ${email}] [IP: ${req.ip}]`, + ); const error = new Error(ErrorTypes.AUTH_FAILED); error.code = ErrorTypes.AUTH_FAILED; error.message = 'Email domain not allowed'; return error; } - const user = await findUser({ email }, 'email _id'); + + const user = await findUser({ email }, 'email _id role tenantId'); + let appConfig = baseConfig; + if (user?.tenantId) { + try { + appConfig = await resolveAppConfigForUser(getAppConfig, user); + } catch (err) { + logger.error('[requestPasswordReset] Failed to resolve tenant config, using base:', err); + } + } + + if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + logger.warn( + `[requestPasswordReset] Tenant config blocked domain [Email: ${email}] [IP: ${req.ip}]`, + ); + return { + message: 'If an account with that email exists, a password reset link has been sent to it.', + }; + } const emailEnabled = checkEmailConfig(); logger.warn(`[requestPasswordReset] [Password reset request initiated] [Email: ${email}]`); diff --git a/api/server/services/AuthService.spec.js b/api/server/services/AuthService.spec.js index da78f8d775..c8abafdbe5 100644 --- a/api/server/services/AuthService.spec.js +++ b/api/server/services/AuthService.spec.js @@ -14,6 +14,7 @@ jest.mock('@librechat/api', () => ({ isEmailDomainAllowed: jest.fn(), math: jest.fn((val, fallback) => (val ? 
Number(val) : fallback)), shouldUseSecureCookie: jest.fn(() => false), + resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})), })); jest.mock('~/models', () => ({ findUser: jest.fn(), @@ -35,8 +36,14 @@ jest.mock('~/strategies/validators', () => ({ registerSchema: { parse: jest.fn() jest.mock('~/server/services/Config', () => ({ getAppConfig: jest.fn() })); jest.mock('~/server/utils', () => ({ sendEmail: jest.fn() })); -const { shouldUseSecureCookie } = require('@librechat/api'); -const { setOpenIDAuthTokens } = require('./AuthService'); +const { + shouldUseSecureCookie, + isEmailDomainAllowed, + resolveAppConfigForUser, +} = require('@librechat/api'); +const { findUser } = require('~/models'); +const { getAppConfig } = require('~/server/services/Config'); +const { setOpenIDAuthTokens, requestPasswordReset } = require('./AuthService'); /** Helper to build a mock Express response */ function mockResponse() { @@ -267,3 +274,68 @@ describe('setOpenIDAuthTokens', () => { }); }); }); + +describe('requestPasswordReset', () => { + beforeEach(() => { + jest.clearAllMocks(); + isEmailDomainAllowed.mockReturnValue(true); + getAppConfig.mockResolvedValue({ + registration: { allowedDomains: ['example.com'] }, + }); + resolveAppConfigForUser.mockResolvedValue({ + registration: { allowedDomains: ['example.com'] }, + }); + }); + + it('should fast-fail with base config before DB lookup for blocked domains', async () => { + isEmailDomainAllowed.mockReturnValue(false); + + const req = { body: { email: 'blocked@evil.com' }, ip: '127.0.0.1' }; + const result = await requestPasswordReset(req); + + expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + expect(findUser).not.toHaveBeenCalled(); + expect(result).toBeInstanceOf(Error); + }); + + it('should call resolveAppConfigForUser for tenant user', async () => { + const user = { + _id: 'user-tenant', + email: 'user@example.com', + tenantId: 'tenant-x', + role: 'USER', + }; + findUser.mockResolvedValue(user); 
+ + const req = { body: { email: 'user@example.com' }, ip: '127.0.0.1' }; + await requestPasswordReset(req); + + expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, user); + }); + + it('should reuse baseConfig for non-tenant user without calling resolveAppConfigForUser', async () => { + findUser.mockResolvedValue({ _id: 'user-no-tenant', email: 'user@example.com' }); + + const req = { body: { email: 'user@example.com' }, ip: '127.0.0.1' }; + await requestPasswordReset(req); + + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); + }); + + it('should return generic response when tenant config blocks the domain (non-enumerable)', async () => { + const user = { + _id: 'user-tenant', + email: 'user@example.com', + tenantId: 'tenant-x', + role: 'USER', + }; + findUser.mockResolvedValue(user); + isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false); + + const req = { body: { email: 'user@example.com' }, ip: '127.0.0.1' }; + const result = await requestPasswordReset(req); + + expect(result).not.toBeInstanceOf(Error); + expect(result.message).toContain('If an account with that email exists'); + }); +}); diff --git a/api/strategies/ldapStrategy.js b/api/strategies/ldapStrategy.js index 0c99c7b670..9253f54196 100644 --- a/api/strategies/ldapStrategy.js +++ b/api/strategies/ldapStrategy.js @@ -2,7 +2,12 @@ const fs = require('fs'); const LdapStrategy = require('passport-ldapauth'); const { logger } = require('@librechat/data-schemas'); const { SystemRoles, ErrorTypes } = require('librechat-data-provider'); -const { isEnabled, getBalanceConfig, isEmailDomainAllowed } = require('@librechat/api'); +const { + isEnabled, + getBalanceConfig, + isEmailDomainAllowed, + resolveAppConfigForUser, +} = require('@librechat/api'); const { createUser, findUser, updateUser, countUsers } = require('~/models'); const { getAppConfig } = require('~/server/services/Config'); @@ -89,16 +94,6 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, 
done) => { const ldapId = (LDAP_ID && userinfo[LDAP_ID]) || userinfo.uid || userinfo.sAMAccountName || userinfo.mail; - let user = await findUser({ ldapId }); - if (user && user.provider !== 'ldap') { - logger.info( - `[ldapStrategy] User ${user.email} already exists with provider ${user.provider}`, - ); - return done(null, false, { - message: ErrorTypes.AUTH_FAILED, - }); - } - const fullNameAttributes = LDAP_FULL_NAME && LDAP_FULL_NAME.split(','); const fullName = fullNameAttributes && fullNameAttributes.length > 0 @@ -122,7 +117,31 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => { ); } - const appConfig = await getAppConfig({ baseOnly: true }); + // Domain check before findUser for two-phase fast-fail (consistent with SAML/OpenID/social). + // This means cross-provider users from blocked domains get 'Email domain not allowed' + // instead of AUTH_FAILED โ€” both deny access. + const baseConfig = await getAppConfig({ baseOnly: true }); + if (!isEmailDomainAllowed(mail, baseConfig?.registration?.allowedDomains)) { + logger.error( + `[LDAP Strategy] Authentication blocked - email domain not allowed [Email: ${mail}]`, + ); + return done(null, false, { message: 'Email domain not allowed' }); + } + + let user = await findUser({ ldapId }); + if (user && user.provider !== 'ldap') { + logger.info( + `[ldapStrategy] User ${user.email} already exists with provider ${user.provider}`, + ); + return done(null, false, { + message: ErrorTypes.AUTH_FAILED, + }); + } + + const appConfig = user?.tenantId + ? 
await resolveAppConfigForUser(getAppConfig, user) + : baseConfig; + if (!isEmailDomainAllowed(mail, appConfig?.registration?.allowedDomains)) { logger.error( `[LDAP Strategy] Authentication blocked - email domain not allowed [Email: ${mail}]`, diff --git a/api/strategies/ldapStrategy.spec.js b/api/strategies/ldapStrategy.spec.js index a00e9b14b7..876d70f845 100644 --- a/api/strategies/ldapStrategy.spec.js +++ b/api/strategies/ldapStrategy.spec.js @@ -9,10 +9,10 @@ jest.mock('@librechat/data-schemas', () => ({ })); jest.mock('@librechat/api', () => ({ - // isEnabled used for TLS flags isEnabled: jest.fn(() => false), isEmailDomainAllowed: jest.fn(() => true), getBalanceConfig: jest.fn(() => ({ enabled: false })), + resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})), })); jest.mock('~/models', () => ({ @@ -30,14 +30,15 @@ jest.mock('~/server/services/Config', () => ({ let verifyCallback; jest.mock('passport-ldapauth', () => { return jest.fn().mockImplementation((options, verify) => { - verifyCallback = verify; // capture the strategy verify function + verifyCallback = verify; return { name: 'ldap', options, verify }; }); }); const { ErrorTypes } = require('librechat-data-provider'); -const { isEmailDomainAllowed } = require('@librechat/api'); +const { isEmailDomainAllowed, resolveAppConfigForUser } = require('@librechat/api'); const { findUser, createUser, updateUser, countUsers } = require('~/models'); +const { getAppConfig } = require('~/server/services/Config'); // Helper to call the verify callback and wrap in a Promise for convenience const callVerify = (userinfo) => @@ -117,6 +118,7 @@ describe('ldapStrategy', () => { expect(user).toBe(false); expect(info).toEqual({ message: ErrorTypes.AUTH_FAILED }); expect(createUser).not.toHaveBeenCalled(); + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); }); it('updates an existing ldap user with current LDAP info', async () => { @@ -158,7 +160,6 @@ describe('ldapStrategy', () => { uid: 
'uid999', givenName: 'John', cn: 'John Doe', - // no mail and no custom LDAP_EMAIL }; const { user } = await callVerify(userinfo); @@ -180,4 +181,66 @@ describe('ldapStrategy', () => { expect(user).toBe(false); expect(info).toEqual({ message: 'Email domain not allowed' }); }); + + it('passes getAppConfig and found user to resolveAppConfigForUser', async () => { + const existing = { + _id: 'u3', + provider: 'ldap', + email: 'tenant@example.com', + ldapId: 'uid-tenant', + username: 'tenantuser', + name: 'Tenant User', + tenantId: 'tenant-a', + role: 'USER', + }; + findUser.mockResolvedValue(existing); + + const userinfo = { + uid: 'uid-tenant', + mail: 'tenant@example.com', + givenName: 'Tenant', + cn: 'Tenant User', + }; + + await callVerify(userinfo); + + expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existing); + }); + + it('uses baseConfig for new user without calling resolveAppConfigForUser', async () => { + findUser.mockResolvedValue(null); + + const userinfo = { + uid: 'uid-new', + mail: 'newuser@example.com', + givenName: 'New', + cn: 'New User', + }; + + await callVerify(userinfo); + + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); + expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); + + it('should block login when tenant config restricts the domain', async () => { + const existing = { + _id: 'u-blocked', + provider: 'ldap', + ldapId: 'uid-tenant', + tenantId: 'tenant-strict', + role: 'USER', + }; + findUser.mockResolvedValue(existing); + resolveAppConfigForUser.mockResolvedValue({ + registration: { allowedDomains: ['other.com'] }, + }); + isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false); + + const userinfo = { uid: 'uid-tenant', mail: 'user@example.com', givenName: 'Test', cn: 'Test' }; + const { user, info } = await callVerify(userinfo); + + expect(user).toBe(false); + expect(info).toEqual({ message: 'Email domain not allowed' }); + }); }); diff --git a/api/strategies/openidStrategy.js 
b/api/strategies/openidStrategy.js index ab7eb60261..7314a84e15 100644 --- a/api/strategies/openidStrategy.js +++ b/api/strategies/openidStrategy.js @@ -15,6 +15,7 @@ const { findOpenIDUser, getBalanceConfig, isEmailDomainAllowed, + resolveAppConfigForUser, } = require('@librechat/api'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { findUser, createUser, updateUser } = require('~/models'); @@ -468,9 +469,10 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) { Object.assign(userinfo, providerUserinfo); } - const appConfig = await getAppConfig({ baseOnly: true }); const email = getOpenIdEmail(userinfo); - if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + + const baseConfig = await getAppConfig({ baseOnly: true }); + if (!isEmailDomainAllowed(email, baseConfig?.registration?.allowedDomains)) { logger.error( `[OpenID Strategy] Authentication blocked - email domain not allowed [Identifier: ${email}]`, ); @@ -491,6 +493,15 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) { throw new Error(ErrorTypes.AUTH_FAILED); } + const appConfig = user?.tenantId ? 
await resolveAppConfigForUser(getAppConfig, user) : baseConfig; + + if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + logger.error( + `[OpenID Strategy] Authentication blocked - email domain not allowed [Identifier: ${email}]`, + ); + throw new Error('Email domain not allowed'); + } + const fullName = getFullName(userinfo); const requiredRole = process.env.OPENID_REQUIRED_ROLE; diff --git a/api/strategies/openidStrategy.spec.js b/api/strategies/openidStrategy.spec.js index 4436fab672..6d824176f7 100644 --- a/api/strategies/openidStrategy.spec.js +++ b/api/strategies/openidStrategy.spec.js @@ -1,1822 +1,1873 @@ -const undici = require('undici'); -const fetch = require('node-fetch'); -const jwtDecode = require('jsonwebtoken/decode'); -const { ErrorTypes } = require('librechat-data-provider'); -const { findUser, createUser, updateUser } = require('~/models'); -const { setupOpenId } = require('./openidStrategy'); - -// --- Mocks --- -jest.mock('node-fetch'); -jest.mock('jsonwebtoken/decode'); -jest.mock('undici', () => ({ - fetch: jest.fn(), - ProxyAgent: jest.fn(), -})); -jest.mock('~/server/services/Files/strategies', () => ({ - getStrategyFunctions: jest.fn(() => ({ - saveBuffer: jest.fn().mockResolvedValue('/fake/path/to/avatar.png'), - })), -})); -jest.mock('~/server/services/Config', () => ({ - getAppConfig: jest.fn().mockResolvedValue({}), -})); -jest.mock('@librechat/api', () => ({ - ...jest.requireActual('@librechat/api'), - isEnabled: jest.fn(() => false), - isEmailDomainAllowed: jest.fn(() => true), - findOpenIDUser: jest.requireActual('@librechat/api').findOpenIDUser, - getBalanceConfig: jest.fn(() => ({ - enabled: false, - })), -})); -jest.mock('~/models', () => ({ - findUser: jest.fn(), - createUser: jest.fn(), - updateUser: jest.fn(), -})); -jest.mock('@librechat/data-schemas', () => ({ - ...jest.requireActual('@librechat/api'), - logger: { - info: jest.fn(), - warn: jest.fn(), - debug: jest.fn(), - error: jest.fn(), - }, - 
hashToken: jest.fn().mockResolvedValue('hashed-token'), -})); -jest.mock('~/cache/getLogStores', () => - jest.fn(() => ({ - get: jest.fn(), - set: jest.fn(), - })), -); - -// Mock the openid-client module and all its dependencies -jest.mock('openid-client', () => { - return { - discovery: jest.fn().mockResolvedValue({ - clientId: 'fake_client_id', - clientSecret: 'fake_client_secret', - issuer: 'https://fake-issuer.com', - // Add any other properties needed by the implementation - }), - fetchUserInfo: jest.fn().mockImplementation(() => { - // Only return additional properties, but don't override any claims - return Promise.resolve({}); - }), - genericGrantRequest: jest.fn().mockResolvedValue({ - access_token: 'exchanged_graph_token', - expires_in: 3600, - }), - customFetch: Symbol('customFetch'), - }; -}); - -jest.mock('openid-client/passport', () => { - /** Store callbacks by strategy name - 'openid' and 'openidAdmin' */ - const verifyCallbacks = {}; - let lastVerifyCallback; - - const mockStrategy = jest.fn((options, verify) => { - lastVerifyCallback = verify; - return { name: 'openid', options, verify }; - }); - - return { - Strategy: mockStrategy, - /** Get the last registered callback (for backward compatibility) */ - __getVerifyCallback: () => lastVerifyCallback, - /** Store callback by name when passport.use is called */ - __setVerifyCallback: (name, callback) => { - verifyCallbacks[name] = callback; - }, - /** Get callback by strategy name */ - __getVerifyCallbackByName: (name) => verifyCallbacks[name], - }; -}); - -// Mock passport - capture strategy name and callback -jest.mock('passport', () => ({ - use: jest.fn((name, strategy) => { - const passportMock = require('openid-client/passport'); - if (strategy && strategy.verify) { - passportMock.__setVerifyCallback(name, strategy.verify); - } - }), -})); - -describe('setupOpenId', () => { - // Store a reference to the verify callback once it's set up - let verifyCallback; - - // Helper to wrap the verify 
callback in a promise - const validate = (tokenset) => - new Promise((resolve, reject) => { - verifyCallback(tokenset, (err, user, details) => { - if (err) { - reject(err); - } else { - resolve({ user, details }); - } - }); - }); - - const tokenset = { - id_token: 'fake_id_token', - access_token: 'fake_access_token', - claims: () => ({ - sub: '1234', - email: 'test@example.com', - email_verified: true, - given_name: 'First', - family_name: 'Last', - name: 'My Full', - preferred_username: 'testusername', - username: 'flast', - picture: 'https://example.com/avatar.png', - }), - }; - - beforeEach(async () => { - // Clear previous mock calls and reset implementations - jest.clearAllMocks(); - - // Reset environment variables needed by the strategy - process.env.OPENID_ISSUER = 'https://fake-issuer.com'; - process.env.OPENID_CLIENT_ID = 'fake_client_id'; - process.env.OPENID_CLIENT_SECRET = 'fake_client_secret'; - process.env.DOMAIN_SERVER = 'https://example.com'; - process.env.OPENID_CALLBACK_URL = '/callback'; - process.env.OPENID_SCOPE = 'openid profile email'; - process.env.OPENID_REQUIRED_ROLE = 'requiredRole'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roles'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - process.env.OPENID_ADMIN_ROLE = 'admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'permissions'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; - delete process.env.OPENID_USERNAME_CLAIM; - delete process.env.OPENID_NAME_CLAIM; - delete process.env.OPENID_EMAIL_CLAIM; - delete process.env.PROXY; - delete process.env.OPENID_USE_PKCE; - - // Default jwtDecode mock returns a token that includes the required role. 
- jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - permissions: ['admin'], - }); - - // By default, assume that no user is found, so createUser will be called - findUser.mockResolvedValue(null); - createUser.mockImplementation(async (userData) => { - // simulate created user with an _id property - return { _id: 'newUserId', ...userData }; - }); - updateUser.mockImplementation(async (id, userData) => { - return { _id: id, ...userData }; - }); - - // For image download, simulate a successful response - const fakeBuffer = Buffer.from('fake image'); - const fakeResponse = { - ok: true, - buffer: jest.fn().mockResolvedValue(fakeBuffer), - }; - fetch.mockResolvedValue(fakeResponse); - - // Call the setup function and capture the verify callback for the regular 'openid' strategy - // (not 'openidAdmin' which requires existing users) - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - }); - - it('should create a new user with correct username when preferred_username claim exists', async () => { - // Arrange โ€“ our userinfo already has preferred_username 'testusername' - const userinfo = tokenset.claims(); - - // Act - const { user } = await validate(tokenset); - - // Assert - expect(user.username).toBe(userinfo.preferred_username); - expect(createUser).toHaveBeenCalledWith( - expect.objectContaining({ - provider: 'openid', - openidId: userinfo.sub, - username: userinfo.preferred_username, - email: userinfo.email, - name: `${userinfo.given_name} ${userinfo.family_name}`, - }), - { enabled: false }, - true, - true, - ); - }); - - it('should use username as username when preferred_username claim is missing', async () => { - // Arrange โ€“ remove preferred_username from userinfo - const userinfo = { ...tokenset.claims() }; - delete userinfo.preferred_username; - // Expect the username to be the "username" - const expectUsername = userinfo.username; - - // Act - const { user } = await validate({ 
...tokenset, claims: () => userinfo }); - - // Assert - expect(user.username).toBe(expectUsername); - expect(createUser).toHaveBeenCalledWith( - expect.objectContaining({ username: expectUsername }), - { enabled: false }, - true, - true, - ); - }); - - it('should use email as username when username and preferred_username are missing', async () => { - // Arrange โ€“ remove username and preferred_username - const userinfo = { ...tokenset.claims() }; - delete userinfo.username; - delete userinfo.preferred_username; - const expectUsername = userinfo.email; - - // Act - const { user } = await validate({ ...tokenset, claims: () => userinfo }); - - // Assert - expect(user.username).toBe(expectUsername); - expect(createUser).toHaveBeenCalledWith( - expect.objectContaining({ username: expectUsername }), - { enabled: false }, - true, - true, - ); - }); - - it('should override username with OPENID_USERNAME_CLAIM when set', async () => { - // Arrange โ€“ set OPENID_USERNAME_CLAIM so that the sub claim is used - process.env.OPENID_USERNAME_CLAIM = 'sub'; - const userinfo = tokenset.claims(); - - // Act - const { user } = await validate(tokenset); - - // Assert โ€“ username should equal the sub (converted as-is) - expect(user.username).toBe(userinfo.sub); - expect(createUser).toHaveBeenCalledWith( - expect.objectContaining({ username: userinfo.sub }), - { enabled: false }, - true, - true, - ); - }); - - it('should set the full name correctly when given_name and family_name exist', async () => { - // Arrange - const userinfo = tokenset.claims(); - const expectedFullName = `${userinfo.given_name} ${userinfo.family_name}`; - - // Act - const { user } = await validate(tokenset); - - // Assert - expect(user.name).toBe(expectedFullName); - }); - - it('should override full name with OPENID_NAME_CLAIM when set', async () => { - // Arrange โ€“ use the name claim as the full name - process.env.OPENID_NAME_CLAIM = 'name'; - const userinfo = { ...tokenset.claims(), name: 'Custom Name' }; - 
- // Act - const { user } = await validate({ ...tokenset, claims: () => userinfo }); - - // Assert - expect(user.name).toBe('Custom Name'); - }); - - it('should update an existing user on login', async () => { - // Arrange โ€“ simulate that a user already exists with openid provider - const existingUser = { - _id: 'existingUserId', - provider: 'openid', - email: tokenset.claims().email, - openidId: '', - username: '', - name: '', - }; - findUser.mockImplementation(async (query) => { - if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) { - return existingUser; - } - return null; - }); - - const userinfo = tokenset.claims(); - - // Act - await validate(tokenset); - - // Assert โ€“ updateUser should be called and the user object updated - expect(updateUser).toHaveBeenCalledWith( - existingUser._id, - expect.objectContaining({ - provider: 'openid', - openidId: userinfo.sub, - username: userinfo.preferred_username, - name: `${userinfo.given_name} ${userinfo.family_name}`, - }), - ); - }); - - it('should block login when email exists with different provider', async () => { - // Arrange โ€“ simulate that a user exists with same email but different provider - const existingUser = { - _id: 'existingUserId', - provider: 'google', - email: tokenset.claims().email, - googleId: 'some-google-id', - username: 'existinguser', - name: 'Existing User', - }; - findUser.mockImplementation(async (query) => { - if (query.email === tokenset.claims().email && !query.provider) { - return existingUser; - } - return null; - }); - - // Act - const result = await validate(tokenset); - - // Assert โ€“ verify that the strategy rejects login - expect(result.user).toBe(false); - expect(result.details.message).toBe(ErrorTypes.AUTH_FAILED); - expect(createUser).not.toHaveBeenCalled(); - expect(updateUser).not.toHaveBeenCalled(); - }); - - it('should block login when email fallback finds user with mismatched openidId', async () => { - const existingUser = { - 
_id: 'existingUserId', - provider: 'openid', - openidId: 'different-sub-claim', - email: tokenset.claims().email, - username: 'existinguser', - name: 'Existing User', - }; - findUser.mockImplementation(async (query) => { - if (query.$or) { - return null; - } - if (query.email === tokenset.claims().email) { - return existingUser; - } - return null; - }); - - const result = await validate(tokenset); - - expect(result.user).toBe(false); - expect(result.details.message).toBe(ErrorTypes.AUTH_FAILED); - expect(createUser).not.toHaveBeenCalled(); - expect(updateUser).not.toHaveBeenCalled(); - }); - - it('should enforce the required role and reject login if missing', async () => { - // Arrange โ€“ simulate a token without the required role. - jwtDecode.mockReturnValue({ - roles: ['SomeOtherRole'], - }); - - // Act - const { user, details } = await validate(tokenset); - - // Assert โ€“ verify that the strategy rejects login - expect(user).toBe(false); - expect(details.message).toBe('You must have "requiredRole" role to log in.'); - }); - - it('should not treat substring matches in string roles as satisfying required role', async () => { - // Arrange โ€“ override required role to "read" then re-setup - process.env.OPENID_REQUIRED_ROLE = 'read'; - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - // Token contains "bread" which *contains* "read" as a substring - jwtDecode.mockReturnValue({ - roles: 'bread', - }); - - // Act - const { user, details } = await validate(tokenset); - - // Assert โ€“ verify that substring match does not grant access - expect(user).toBe(false); - expect(details.message).toBe('You must have "read" role to log in.'); - }); - - it('should allow login when roles claim is a space-separated string containing the required role', async () => { - // Arrange โ€“ IdP returns roles as a space-delimited string - jwtDecode.mockReturnValue({ - roles: 'role1 role2 requiredRole', - }); - - // Act - 
const { user } = await validate(tokenset); - - // Assert โ€“ login succeeds when required role is present after splitting - expect(user).toBeTruthy(); - expect(createUser).toHaveBeenCalled(); - }); - - it('should allow login when roles claim is a comma-separated string containing the required role', async () => { - // Arrange โ€“ IdP returns roles as a comma-delimited string - jwtDecode.mockReturnValue({ - roles: 'role1,role2,requiredRole', - }); - - // Act - const { user } = await validate(tokenset); - - // Assert โ€“ login succeeds when required role is present after splitting - expect(user).toBeTruthy(); - expect(createUser).toHaveBeenCalled(); - }); - - it('should allow login when roles claim is a mixed comma-and-space-separated string containing the required role', async () => { - // Arrange โ€“ IdP returns roles with comma-and-space delimiters - jwtDecode.mockReturnValue({ - roles: 'role1, role2, requiredRole', - }); - - // Act - const { user } = await validate(tokenset); - - // Assert โ€“ login succeeds when required role is present after splitting - expect(user).toBeTruthy(); - expect(createUser).toHaveBeenCalled(); - }); - - it('should reject login when roles claim is a space-separated string that does not contain the required role', async () => { - // Arrange โ€“ IdP returns a delimited string but required role is absent - jwtDecode.mockReturnValue({ - roles: 'role1 role2 otherRole', - }); - - // Act - const { user, details } = await validate(tokenset); - - // Assert โ€“ login is rejected with the correct error message - expect(user).toBe(false); - expect(details.message).toBe('You must have "requiredRole" role to log in.'); - }); - - it('should allow login when single required role is present (backward compatibility)', async () => { - // Arrange โ€“ ensure single role configuration (as set in beforeEach) - // OPENID_REQUIRED_ROLE = 'requiredRole' - // Default jwtDecode mock in beforeEach already returns this role - jwtDecode.mockReturnValue({ - roles: 
['requiredRole', 'anotherRole'], - }); - - // Act - const { user } = await validate(tokenset); - - // Assert โ€“ verify that login succeeds with single role configuration - expect(user).toBeTruthy(); - expect(user.email).toBe(tokenset.claims().email); - expect(user.username).toBe(tokenset.claims().preferred_username); - expect(createUser).toHaveBeenCalled(); - }); - - describe('group overage and groups handling', () => { - it.each([ - ['groups array contains required group', ['group-required', 'other-group'], true, undefined], - [ - 'groups array missing required group', - ['other-group'], - false, - 'You must have "group-required" role to log in.', - ], - ['groups string equals required group', 'group-required', true, undefined], - [ - 'groups string is other group', - 'other-group', - false, - 'You must have "group-required" role to log in.', - ], - ])( - 'uses groups claim directly when %s (no overage)', - async (_label, groupsClaim, expectedAllowed, expectedMessage) => { - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ - groups: groupsClaim, - permissions: ['admin'], - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user, details } = await validate(tokenset); - - expect(undici.fetch).not.toHaveBeenCalled(); - expect(Boolean(user)).toBe(expectedAllowed); - expect(details?.message).toBe(expectedMessage); - }, - ); - - it.each([ - ['token kind is not id', { kind: 'access', path: 'groups', decoded: { hasgroups: true } }], - ['parameter path is not groups', { kind: 'id', path: 'roles', decoded: { hasgroups: true } }], - ['decoded token is falsy', { kind: 'id', path: 'groups', decoded: null }], - [ - 'no overage indicators in decoded token', - { - kind: 'id', - path: 'groups', - decoded: { - permissions: ['admin'], - }, - }, - ], - [ 
- 'only _claim_names present (no _claim_sources)', - { - kind: 'id', - path: 'groups', - decoded: { - _claim_names: { groups: 'src1' }, - permissions: ['admin'], - }, - }, - ], - [ - 'only _claim_sources present (no _claim_names)', - { - kind: 'id', - path: 'groups', - decoded: { - _claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } }, - permissions: ['admin'], - }, - }, - ], - ])('does not attempt overage resolution when %s', async (_label, cfg) => { - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = cfg.path; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = cfg.kind; - - jwtDecode.mockReturnValue(cfg.decoded); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user, details } = await validate(tokenset); - - expect(undici.fetch).not.toHaveBeenCalled(); - expect(user).toBe(false); - expect(details.message).toBe('You must have "group-required" role to log in.'); - const { logger } = require('@librechat/data-schemas'); - const expectedTokenKind = cfg.kind === 'access' ? 
'access token' : 'id token'; - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining(`Key '${cfg.path}' not found in ${expectedTokenKind}!`), - ); - }); - }); - - describe('resolving groups via Microsoft Graph', () => { - it('denies login and does not call Graph when access token is missing', async () => { - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - - const { logger } = require('@librechat/data-schemas'); - - jwtDecode.mockReturnValue({ - hasgroups: true, - permissions: ['admin'], - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const tokensetWithoutAccess = { - ...tokenset, - access_token: undefined, - }; - - const { user, details } = await validate(tokensetWithoutAccess); - - expect(user).toBe(false); - expect(details.message).toBe('You must have "group-required" role to log in.'); - - expect(undici.fetch).not.toHaveBeenCalled(); - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining('Access token missing; cannot resolve group overage'), - ); - }); - - it.each([ - [ - 'Graph returns HTTP error', - async () => ({ - ok: false, - status: 403, - statusText: 'Forbidden', - json: async () => ({}), - }), - [ - '[openidStrategy] Failed to resolve groups via Microsoft Graph getMemberObjects: HTTP 403 Forbidden', - ], - ], - [ - 'Graph network error', - async () => { - throw new Error('network error'); - }, - [ - '[openidStrategy] Error resolving groups via Microsoft Graph getMemberObjects:', - expect.any(Error), - ], - ], - [ - 'Graph returns unexpected shape (no value)', - async () => ({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({}), - }), - [ - '[openidStrategy] Unexpected response format when resolving groups via Microsoft Graph getMemberObjects', - ], - ], - [ - 'Graph returns invalid value type', - async () => 
({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: 'not-an-array' }), - }), - [ - '[openidStrategy] Unexpected response format when resolving groups via Microsoft Graph getMemberObjects', - ], - ], - ])( - 'denies login when overage resolution fails because %s', - async (_label, setupFetch, expectedErrorArgs) => { - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - - const { logger } = require('@librechat/data-schemas'); - - jwtDecode.mockReturnValue({ - hasgroups: true, - permissions: ['admin'], - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockImplementation(setupFetch); - - const { user, details } = await validate(tokenset); - - expect(undici.fetch).toHaveBeenCalled(); - expect(user).toBe(false); - expect(details.message).toBe('You must have "group-required" role to log in.'); - - expect(logger.error).toHaveBeenCalledWith(...expectedErrorArgs); - }, - ); - - it.each([ - [ - 'hasgroups overage and Graph contains required group', - { - hasgroups: true, - }, - ['group-required', 'some-other-group'], - true, - ], - [ - '_claim_* overage and Graph contains required group', - { - _claim_names: { groups: 'src1' }, - _claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } }, - }, - ['group-required', 'some-other-group'], - true, - ], - [ - 'hasgroups overage and Graph does NOT contain required group', - { - hasgroups: true, - }, - ['some-other-group'], - false, - ], - [ - '_claim_* overage and Graph does NOT contain required group', - { - _claim_names: { groups: 'src1' }, - _claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } }, - }, - ['some-other-group'], - false, - ], - ])( - 'resolves groups via Microsoft Graph when %s', - async (_label, decodedTokenValue, graphGroups, expectedAllowed) => { 
- process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - - const { logger } = require('@librechat/data-schemas'); - - jwtDecode.mockReturnValue(decodedTokenValue); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ - value: graphGroups, - }), - }); - - const { user } = await validate(tokenset); - - expect(undici.fetch).toHaveBeenCalledWith( - 'https://graph.microsoft.com/v1.0/me/getMemberObjects', - expect.objectContaining({ - method: 'POST', - headers: expect.objectContaining({ - Authorization: 'Bearer exchanged_graph_token', - }), - }), - ); - expect(Boolean(user)).toBe(expectedAllowed); - - expect(logger.debug).toHaveBeenCalledWith( - expect.stringContaining( - `Successfully resolved ${graphGroups.length} groups via Microsoft Graph getMemberObjects`, - ), - ); - }, - ); - }); - - describe('OBO token exchange for overage', () => { - it('exchanges access token via OBO before calling Graph API', async () => { - const openidClient = require('openid-client'); - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: ['group-required'] }), - }); - - await validate(tokenset); - - expect(openidClient.genericGrantRequest).toHaveBeenCalledWith( - expect.anything(), - 'urn:ietf:params:oauth:grant-type:jwt-bearer', - expect.objectContaining({ - scope: 'https://graph.microsoft.com/User.Read', - assertion: 
tokenset.access_token, - requested_token_use: 'on_behalf_of', - }), - ); - - expect(undici.fetch).toHaveBeenCalledWith( - 'https://graph.microsoft.com/v1.0/me/getMemberObjects', - expect.objectContaining({ - headers: expect.objectContaining({ - Authorization: 'Bearer exchanged_graph_token', - }), - }), - ); - }); - - it('caches the exchanged token and reuses it on subsequent calls', async () => { - const openidClient = require('openid-client'); - const getLogStores = require('~/cache/getLogStores'); - const mockSet = jest.fn(); - const mockGet = jest - .fn() - .mockResolvedValueOnce(undefined) - .mockResolvedValueOnce({ access_token: 'exchanged_graph_token' }); - getLogStores.mockReturnValue({ get: mockGet, set: mockSet }); - - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: ['group-required'] }), - }); - - // First call: cache miss โ†’ OBO exchange โ†’ cache set - await validate(tokenset); - expect(mockSet).toHaveBeenCalledWith( - '1234:overage', - { access_token: 'exchanged_graph_token' }, - 3600000, - ); - expect(openidClient.genericGrantRequest).toHaveBeenCalledTimes(1); - - // Second call: cache hit โ†’ no new OBO exchange - openidClient.genericGrantRequest.mockClear(); - await validate(tokenset); - expect(openidClient.genericGrantRequest).not.toHaveBeenCalled(); - }); - }); - - describe('admin role group overage', () => { - it('resolves admin groups via Graph when overage is detected for admin role', async () => { - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - 
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: ['group-required', 'admin-group-id'] }), - }); - - const { user } = await validate(tokenset); - - expect(user.role).toBe('ADMIN'); - }); - - it('does not grant admin when overage groups do not contain admin role', async () => { - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: ['group-required', 'other-group'] }), - }); - - const { user } = await validate(tokenset); - - expect(user).toBeTruthy(); - expect(user.role).toBeUndefined(); - }); - - it('reuses already-resolved overage groups for admin role check (no duplicate Graph call)', async () => { - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - 
- await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: ['group-required', 'admin-group-id'] }), - }); - - await validate(tokenset); - - // Graph API should be called only once (for required role), admin role reuses the result - expect(undici.fetch).toHaveBeenCalledTimes(1); - }); - - it('demotes existing admin when overage groups no longer contain admin role', async () => { - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; - - const existingAdminUser = { - _id: 'existingAdminId', - provider: 'openid', - email: tokenset.claims().email, - openidId: tokenset.claims().sub, - username: 'adminuser', - name: 'Admin User', - role: 'ADMIN', - }; - - findUser.mockImplementation(async (query) => { - if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) { - return existingAdminUser; - } - return null; - }); - - jwtDecode.mockReturnValue({ hasgroups: true }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: ['group-required'] }), - }); - - const { user } = await validate(tokenset); - - expect(user.role).toBe('USER'); - }); - - it('does not attempt overage for admin role when token kind is not id', async () => { - process.env.OPENID_REQUIRED_ROLE = 'requiredRole'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roles'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - process.env.OPENID_ADMIN_ROLE = 'admin'; - 
process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'access'; - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - hasgroups: true, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - // No Graph call since admin uses access token (not id) - expect(undici.fetch).not.toHaveBeenCalled(); - expect(user.role).toBeUndefined(); - }); - - it('resolves admin via Graph independently when OPENID_REQUIRED_ROLE is not configured', async () => { - delete process.env.OPENID_REQUIRED_ROLE; - process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: ['admin-group-id'] }), - }); - - const { user } = await validate(tokenset); - expect(user.role).toBe('ADMIN'); - expect(undici.fetch).toHaveBeenCalledTimes(1); - }); - - it('denies admin when OPENID_REQUIRED_ROLE is absent and Graph does not contain admin group', async () => { - delete process.env.OPENID_REQUIRED_ROLE; - process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - undici.fetch.mockResolvedValue({ - ok: true, - status: 200, - statusText: 'OK', - json: async () => ({ value: ['other-group'] }), - }); - - const { user } = await validate(tokenset); - expect(user).toBeTruthy(); - expect(user.role).toBeUndefined(); 
- }); - - it('denies login and logs error when OBO exchange throws', async () => { - const openidClient = require('openid-client'); - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - openidClient.genericGrantRequest.mockRejectedValueOnce(new Error('OBO exchange rejected')); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user, details } = await validate(tokenset); - expect(user).toBe(false); - expect(details.message).toBe('You must have "group-required" role to log in.'); - expect(undici.fetch).not.toHaveBeenCalled(); - }); - - it('denies login when OBO exchange returns no access_token', async () => { - const openidClient = require('openid-client'); - process.env.OPENID_REQUIRED_ROLE = 'group-required'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; - process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; - - jwtDecode.mockReturnValue({ hasgroups: true }); - openidClient.genericGrantRequest.mockResolvedValueOnce({ expires_in: 3600 }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user, details } = await validate(tokenset); - expect(user).toBe(false); - expect(details.message).toBe('You must have "group-required" role to log in.'); - expect(undici.fetch).not.toHaveBeenCalled(); - }); - }); - - it('should attempt to download and save the avatar if picture is provided', async () => { - // Act - const { user } = await validate(tokenset); - - // Assert โ€“ verify that download was attempted and the avatar field was set via updateUser - expect(fetch).toHaveBeenCalled(); - // Our mock getStrategyFunctions.saveBuffer returns '/fake/path/to/avatar.png' - expect(user.avatar).toBe('/fake/path/to/avatar.png'); - }); - - it('should not 
attempt to download avatar if picture is not provided', async () => { - // Arrange โ€“ remove picture - const userinfo = { ...tokenset.claims() }; - delete userinfo.picture; - - // Act - await validate({ ...tokenset, claims: () => userinfo }); - - // Assert โ€“ fetch should not be called and avatar should remain undefined or empty - expect(fetch).not.toHaveBeenCalled(); - // Depending on your implementation, user.avatar may be undefined or an empty string. - }); - - it('should support comma-separated multiple roles', async () => { - // Arrange - process.env.OPENID_REQUIRED_ROLE = 'someRole,anotherRole,admin'; - await setupOpenId(); // Re-initialize the strategy - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - jwtDecode.mockReturnValue({ - roles: ['anotherRole', 'aThirdRole'], - }); - - // Act - const { user } = await validate(tokenset); - - // Assert - expect(user).toBeTruthy(); - expect(user.email).toBe(tokenset.claims().email); - }); - - it('should reject login when user has none of the required multiple roles', async () => { - // Arrange - process.env.OPENID_REQUIRED_ROLE = 'someRole,anotherRole,admin'; - await setupOpenId(); // Re-initialize the strategy - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - jwtDecode.mockReturnValue({ - roles: ['aThirdRole', 'aFourthRole'], - }); - - // Act - const { user, details } = await validate(tokenset); - - // Assert - expect(user).toBe(false); - expect(details.message).toBe( - 'You must have one of: "someRole", "anotherRole", "admin" role to log in.', - ); - }); - - it('should handle spaces in comma-separated roles', async () => { - // Arrange - process.env.OPENID_REQUIRED_ROLE = ' someRole , anotherRole , admin '; - await setupOpenId(); // Re-initialize the strategy - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - jwtDecode.mockReturnValue({ - roles: ['someRole'], - }); - - // Act - const { user } 
= await validate(tokenset); - - // Assert - expect(user).toBeTruthy(); - }); - - it('should default to usePKCE false when OPENID_USE_PKCE is not defined', async () => { - const OpenIDStrategy = require('openid-client/passport').Strategy; - - delete process.env.OPENID_USE_PKCE; - await setupOpenId(); - - const callOptions = OpenIDStrategy.mock.calls[OpenIDStrategy.mock.calls.length - 1][0]; - expect(callOptions.usePKCE).toBe(false); - expect(callOptions.params?.code_challenge_method).toBeUndefined(); - }); - - it('should attach federatedTokens to user object for token propagation', async () => { - // Arrange - setup tokenset with access token, id token, refresh token, and expiration - const tokensetWithTokens = { - ...tokenset, - access_token: 'mock_access_token_abc123', - id_token: 'mock_id_token_def456', - refresh_token: 'mock_refresh_token_xyz789', - expires_at: 1234567890, - }; - - // Act - validate with the tokenset containing tokens - const { user } = await validate(tokensetWithTokens); - - // Assert - verify federatedTokens object is attached with correct values - expect(user.federatedTokens).toBeDefined(); - expect(user.federatedTokens).toEqual({ - access_token: 'mock_access_token_abc123', - id_token: 'mock_id_token_def456', - refresh_token: 'mock_refresh_token_xyz789', - expires_at: 1234567890, - }); - }); - - it('should include id_token in federatedTokens distinct from access_token', async () => { - // Arrange - use different values for access_token and id_token - const tokensetWithTokens = { - ...tokenset, - access_token: 'the_access_token', - id_token: 'the_id_token', - refresh_token: 'the_refresh_token', - expires_at: 9999999999, - }; - - // Act - const { user } = await validate(tokensetWithTokens); - - // Assert - id_token and access_token must be different values - expect(user.federatedTokens.access_token).toBe('the_access_token'); - expect(user.federatedTokens.id_token).toBe('the_id_token'); - 
expect(user.federatedTokens.id_token).not.toBe(user.federatedTokens.access_token); - }); - - it('should include tokenset along with federatedTokens', async () => { - // Arrange - const tokensetWithTokens = { - ...tokenset, - access_token: 'test_access_token', - id_token: 'test_id_token', - refresh_token: 'test_refresh_token', - expires_at: 9999999999, - }; - - // Act - const { user } = await validate(tokensetWithTokens); - - // Assert - both tokenset and federatedTokens should be present - expect(user.tokenset).toBeDefined(); - expect(user.federatedTokens).toBeDefined(); - expect(user.tokenset.access_token).toBe('test_access_token'); - expect(user.tokenset.id_token).toBe('test_id_token'); - expect(user.federatedTokens.access_token).toBe('test_access_token'); - expect(user.federatedTokens.id_token).toBe('test_id_token'); - }); - - it('should set role to "ADMIN" if OPENID_ADMIN_ROLE is set and user has that role', async () => { - // Act - const { user } = await validate(tokenset); - - // Assert โ€“ verify that the user role is set to "ADMIN" - expect(user.role).toBe('ADMIN'); - }); - - it('should not set user role if OPENID_ADMIN_ROLE is set but the user does not have that role', async () => { - // Arrange โ€“ simulate a token without the admin permission - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - permissions: ['not-admin'], - }); - - // Act - const { user } = await validate(tokenset); - - // Assert โ€“ verify that the user role is not defined - expect(user.role).toBeUndefined(); - }); - - it('should demote existing admin user when admin role is removed from token', async () => { - // Arrange โ€“ simulate an existing user who is currently an admin - const existingAdminUser = { - _id: 'existingAdminId', - provider: 'openid', - email: tokenset.claims().email, - openidId: tokenset.claims().sub, - username: 'adminuser', - name: 'Admin User', - role: 'ADMIN', - }; - - findUser.mockImplementation(async (query) => { - if (query.openidId === 
tokenset.claims().sub || query.email === tokenset.claims().email) { - return existingAdminUser; - } - return null; - }); - - // Token without admin permission - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - permissions: ['not-admin'], - }); - - const { logger } = require('@librechat/data-schemas'); - - // Act - const { user } = await validate(tokenset); - - // Assert โ€“ verify that the user was demoted - expect(user.role).toBe('USER'); - expect(updateUser).toHaveBeenCalledWith( - existingAdminUser._id, - expect.objectContaining({ - role: 'USER', - }), - ); - expect(logger.info).toHaveBeenCalledWith( - expect.stringContaining('demoted from admin - role no longer present in token'), - ); - }); - - it('should NOT demote admin user when admin role env vars are not configured', async () => { - // Arrange โ€“ remove admin role env vars - delete process.env.OPENID_ADMIN_ROLE; - delete process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH; - delete process.env.OPENID_ADMIN_ROLE_TOKEN_KIND; - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - // Simulate an existing admin user - const existingAdminUser = { - _id: 'existingAdminId', - provider: 'openid', - email: tokenset.claims().email, - openidId: tokenset.claims().sub, - username: 'adminuser', - name: 'Admin User', - role: 'ADMIN', - }; - - findUser.mockImplementation(async (query) => { - if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) { - return existingAdminUser; - } - return null; - }); - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - }); - - // Act - const { user } = await validate(tokenset); - - // Assert โ€“ verify that the admin user was NOT demoted - expect(user.role).toBe('ADMIN'); - expect(updateUser).toHaveBeenCalledWith( - existingAdminUser._id, - expect.objectContaining({ - role: 'ADMIN', - }), - ); - }); - - describe('lodash get - nested path extraction', () => { - it('should extract 
roles from deeply nested token path', async () => { - process.env.OPENID_REQUIRED_ROLE = 'app-user'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'resource_access.my-client.roles'; - - jwtDecode.mockReturnValue({ - resource_access: { - 'my-client': { - roles: ['app-user', 'viewer'], - }, - }, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user).toBeTruthy(); - expect(user.email).toBe(tokenset.claims().email); - }); - - it('should extract roles from three-level nested path', async () => { - process.env.OPENID_REQUIRED_ROLE = 'editor'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'data.access.permissions.roles'; - - jwtDecode.mockReturnValue({ - data: { - access: { - permissions: { - roles: ['editor', 'reader'], - }, - }, - }, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user).toBeTruthy(); - }); - - it('should log error and reject login when required role path does not exist in token', async () => { - const { logger } = require('@librechat/data-schemas'); - process.env.OPENID_REQUIRED_ROLE = 'app-user'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'resource_access.nonexistent.roles'; - - jwtDecode.mockReturnValue({ - resource_access: { - 'my-client': { - roles: ['app-user'], - }, - }, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user, details } = await validate(tokenset); - - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'resource_access.nonexistent.roles' not found in id token!"), - ); - expect(user).toBe(false); - expect(details.message).toContain('role to log in'); - }); - - it('should handle missing intermediate nested path gracefully', async () => { - const { logger 
} = require('@librechat/data-schemas'); - process.env.OPENID_REQUIRED_ROLE = 'user'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'org.team.roles'; - - jwtDecode.mockReturnValue({ - org: { - other: 'value', - }, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'org.team.roles' not found in id token!"), - ); - expect(user).toBe(false); - }); - - it('should extract admin role from nested path in access token', async () => { - process.env.OPENID_ADMIN_ROLE = 'admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'realm_access.roles'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'access'; - - jwtDecode.mockImplementation((token) => { - if (token === 'fake_access_token') { - return { - realm_access: { - roles: ['admin', 'user'], - }, - }; - } - return { - roles: ['requiredRole'], - }; - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user.role).toBe('ADMIN'); - }); - - it('should extract admin role from nested path in userinfo', async () => { - process.env.OPENID_ADMIN_ROLE = 'admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'organization.permissions'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'userinfo'; - - const userinfoWithNestedGroups = { - ...tokenset.claims(), - organization: { - permissions: ['admin', 'write'], - }, - }; - - require('openid-client').fetchUserInfo.mockResolvedValue({ - organization: { - permissions: ['admin', 'write'], - }, - }); - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate({ - ...tokenset, - claims: () => userinfoWithNestedGroups, - }); 
- - expect(user.role).toBe('ADMIN'); - }); - - it('should handle boolean admin role value', async () => { - process.env.OPENID_ADMIN_ROLE = 'admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'is_admin'; - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - is_admin: true, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user.role).toBe('ADMIN'); - }); - - it('should handle string admin role value matching exactly', async () => { - process.env.OPENID_ADMIN_ROLE = 'super-admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'role'; - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - role: 'super-admin', - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user.role).toBe('ADMIN'); - }); - - it('should not set admin role when string value does not match', async () => { - process.env.OPENID_ADMIN_ROLE = 'super-admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'role'; - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - role: 'regular-user', - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user.role).toBeUndefined(); - }); - - it('should handle array admin role value', async () => { - process.env.OPENID_ADMIN_ROLE = 'site-admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'app_roles'; - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - app_roles: ['user', 'site-admin', 'moderator'], - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user.role).toBe('ADMIN'); - }); - - it('should not set admin when role is not 
in array', async () => { - process.env.OPENID_ADMIN_ROLE = 'site-admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'app_roles'; - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - app_roles: ['user', 'moderator'], - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user.role).toBeUndefined(); - }); - - it('should grant admin when admin role claim is a space-separated string containing the admin role', async () => { - // Arrange โ€“ IdP returns admin roles as a space-delimited string - process.env.OPENID_ADMIN_ROLE = 'site-admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'app_roles'; - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - app_roles: 'user site-admin moderator', - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - // Act - const { user } = await validate(tokenset); - - // Assert โ€“ admin role is granted after splitting the delimited string - expect(user.role).toBe('ADMIN'); - }); - - it('should not grant admin when admin role claim is a space-separated string that does not contain the admin role', async () => { - // Arrange โ€“ delimited string present but admin role is absent - process.env.OPENID_ADMIN_ROLE = 'site-admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'app_roles'; - - jwtDecode.mockReturnValue({ - roles: ['requiredRole'], - app_roles: 'user moderator', - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - // Act - const { user } = await validate(tokenset); - - // Assert โ€“ admin role is not granted - expect(user.role).toBeUndefined(); - }); - - it('should handle nested path with special characters in keys', async () => { - process.env.OPENID_REQUIRED_ROLE = 'app-user'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 
'resource_access.my-app-123.roles'; - - jwtDecode.mockReturnValue({ - resource_access: { - 'my-app-123': { - roles: ['app-user'], - }, - }, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(user).toBeTruthy(); - }); - - it('should handle empty object at nested path', async () => { - const { logger } = require('@librechat/data-schemas'); - process.env.OPENID_REQUIRED_ROLE = 'user'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'access.roles'; - - jwtDecode.mockReturnValue({ - access: {}, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'access.roles' not found in id token!"), - ); - expect(user).toBe(false); - }); - - it('should handle null value at intermediate path', async () => { - const { logger } = require('@librechat/data-schemas'); - process.env.OPENID_REQUIRED_ROLE = 'user'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'data.roles'; - - jwtDecode.mockReturnValue({ - data: null, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'data.roles' not found in id token!"), - ); - expect(user).toBe(false); - }); - - it('should reject login with invalid admin role token kind', async () => { - process.env.OPENID_ADMIN_ROLE = 'admin'; - process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'roles'; - process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'invalid'; - - const { logger } = require('@librechat/data-schemas'); - - jwtDecode.mockReturnValue({ - roles: ['requiredRole', 'admin'], - }); - - await setupOpenId(); - verifyCallback = 
require('openid-client/passport').__getVerifyCallbackByName('openid'); - - await expect(validate(tokenset)).rejects.toThrow('Invalid admin role token kind'); - - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining( - "Invalid admin role token kind: invalid. Must be one of 'access', 'id', or 'userinfo'", - ), - ); - }); - - it('should reject login when roles path returns invalid type (object)', async () => { - const { logger } = require('@librechat/data-schemas'); - process.env.OPENID_REQUIRED_ROLE = 'app-user'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roles'; - - jwtDecode.mockReturnValue({ - roles: { admin: true, user: false }, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user, details } = await validate(tokenset); - - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'roles' not found in id token!"), - ); - expect(user).toBe(false); - expect(details.message).toContain('role to log in'); - }); - - it('should reject login when roles path returns invalid type (number)', async () => { - const { logger } = require('@librechat/data-schemas'); - process.env.OPENID_REQUIRED_ROLE = 'user'; - process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roleCount'; - - jwtDecode.mockReturnValue({ - roleCount: 5, - }); - - await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); - - const { user } = await validate(tokenset); - - expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'roleCount' not found in id token!"), - ); - expect(user).toBe(false); - }); - }); - - describe('OPENID_EMAIL_CLAIM', () => { - it('should use the default email when OPENID_EMAIL_CLAIM is not set', async () => { - const { user } = await validate(tokenset); - expect(user.email).toBe('test@example.com'); - }); - - it('should use the configured claim when OPENID_EMAIL_CLAIM is set', async () => { - 
process.env.OPENID_EMAIL_CLAIM = 'upn'; - const userinfo = { ...tokenset.claims(), upn: 'user@corp.example.com' }; - - const { user } = await validate({ ...tokenset, claims: () => userinfo }); - - expect(user.email).toBe('user@corp.example.com'); - expect(createUser).toHaveBeenCalledWith( - expect.objectContaining({ email: 'user@corp.example.com' }), - expect.anything(), - true, - true, - ); - }); - - it('should fall back to preferred_username when email is missing and OPENID_EMAIL_CLAIM is not set', async () => { - const userinfo = { ...tokenset.claims() }; - delete userinfo.email; - - const { user } = await validate({ ...tokenset, claims: () => userinfo }); - - expect(user.email).toBe('testusername'); - }); - - it('should fall back to upn when email and preferred_username are missing and OPENID_EMAIL_CLAIM is not set', async () => { - const userinfo = { ...tokenset.claims(), upn: 'user@corp.example.com' }; - delete userinfo.email; - delete userinfo.preferred_username; - - const { user } = await validate({ ...tokenset, claims: () => userinfo }); - - expect(user.email).toBe('user@corp.example.com'); - }); - - it('should ignore empty string OPENID_EMAIL_CLAIM and use default fallback', async () => { - process.env.OPENID_EMAIL_CLAIM = ''; - - const { user } = await validate(tokenset); - - expect(user.email).toBe('test@example.com'); - }); - - it('should trim whitespace from OPENID_EMAIL_CLAIM and resolve correctly', async () => { - process.env.OPENID_EMAIL_CLAIM = ' upn '; - const userinfo = { ...tokenset.claims(), upn: 'user@corp.example.com' }; - - const { user } = await validate({ ...tokenset, claims: () => userinfo }); - - expect(user.email).toBe('user@corp.example.com'); - }); - - it('should ignore whitespace-only OPENID_EMAIL_CLAIM and use default fallback', async () => { - process.env.OPENID_EMAIL_CLAIM = ' '; - - const { user } = await validate(tokenset); - - expect(user.email).toBe('test@example.com'); - }); - - it('should fall back to default chain with 
warning when configured claim is missing from userinfo', async () => { - const { logger } = require('@librechat/data-schemas'); - process.env.OPENID_EMAIL_CLAIM = 'nonexistent_claim'; - - const { user } = await validate(tokenset); - - expect(user.email).toBe('test@example.com'); - expect(logger.warn).toHaveBeenCalledWith( - expect.stringContaining('OPENID_EMAIL_CLAIM="nonexistent_claim" not present in userinfo'), - ); - }); - }); -}); +const undici = require('undici'); +const fetch = require('node-fetch'); +const jwtDecode = require('jsonwebtoken/decode'); +const { ErrorTypes } = require('librechat-data-provider'); +const { findUser, createUser, updateUser } = require('~/models'); +const { resolveAppConfigForUser } = require('@librechat/api'); +const { getAppConfig } = require('~/server/services/Config'); +const { setupOpenId } = require('./openidStrategy'); + +// --- Mocks --- +jest.mock('node-fetch'); +jest.mock('jsonwebtoken/decode'); +jest.mock('undici', () => ({ + fetch: jest.fn(), + ProxyAgent: jest.fn(), +})); +jest.mock('~/server/services/Files/strategies', () => ({ + getStrategyFunctions: jest.fn(() => ({ + saveBuffer: jest.fn().mockResolvedValue('/fake/path/to/avatar.png'), + })), +})); +jest.mock('~/server/services/Config', () => ({ + getAppConfig: jest.fn().mockResolvedValue({}), +})); +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + isEnabled: jest.fn(() => false), + isEmailDomainAllowed: jest.fn(() => true), + findOpenIDUser: jest.requireActual('@librechat/api').findOpenIDUser, + getBalanceConfig: jest.fn(() => ({ + enabled: false, + })), + resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})), +})); +jest.mock('~/models', () => ({ + findUser: jest.fn(), + createUser: jest.fn(), + updateUser: jest.fn(), +})); +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/api'), + logger: { + info: jest.fn(), + warn: jest.fn(), + debug: jest.fn(), + error: jest.fn(), + }, + 
hashToken: jest.fn().mockResolvedValue('hashed-token'), +})); +jest.mock('~/cache/getLogStores', () => + jest.fn(() => ({ + get: jest.fn(), + set: jest.fn(), + })), +); + +// Mock the openid-client module and all its dependencies +jest.mock('openid-client', () => { + return { + discovery: jest.fn().mockResolvedValue({ + clientId: 'fake_client_id', + clientSecret: 'fake_client_secret', + issuer: 'https://fake-issuer.com', + // Add any other properties needed by the implementation + }), + fetchUserInfo: jest.fn().mockImplementation(() => { + // Only return additional properties, but don't override any claims + return Promise.resolve({}); + }), + genericGrantRequest: jest.fn().mockResolvedValue({ + access_token: 'exchanged_graph_token', + expires_in: 3600, + }), + customFetch: Symbol('customFetch'), + }; +}); + +jest.mock('openid-client/passport', () => { + /** Store callbacks by strategy name - 'openid' and 'openidAdmin' */ + const verifyCallbacks = {}; + let lastVerifyCallback; + + const mockStrategy = jest.fn((options, verify) => { + lastVerifyCallback = verify; + return { name: 'openid', options, verify }; + }); + + return { + Strategy: mockStrategy, + /** Get the last registered callback (for backward compatibility) */ + __getVerifyCallback: () => lastVerifyCallback, + /** Store callback by name when passport.use is called */ + __setVerifyCallback: (name, callback) => { + verifyCallbacks[name] = callback; + }, + /** Get callback by strategy name */ + __getVerifyCallbackByName: (name) => verifyCallbacks[name], + }; +}); + +// Mock passport - capture strategy name and callback +jest.mock('passport', () => ({ + use: jest.fn((name, strategy) => { + const passportMock = require('openid-client/passport'); + if (strategy && strategy.verify) { + passportMock.__setVerifyCallback(name, strategy.verify); + } + }), +})); + +describe('setupOpenId', () => { + // Store a reference to the verify callback once it's set up + let verifyCallback; + + // Helper to wrap the verify 
callback in a promise + const validate = (tokenset) => + new Promise((resolve, reject) => { + verifyCallback(tokenset, (err, user, details) => { + if (err) { + reject(err); + } else { + resolve({ user, details }); + } + }); + }); + + const tokenset = { + id_token: 'fake_id_token', + access_token: 'fake_access_token', + claims: () => ({ + sub: '1234', + email: 'test@example.com', + email_verified: true, + given_name: 'First', + family_name: 'Last', + name: 'My Full', + preferred_username: 'testusername', + username: 'flast', + picture: 'https://example.com/avatar.png', + }), + }; + + beforeEach(async () => { + // Clear previous mock calls and reset implementations + jest.clearAllMocks(); + + // Reset environment variables needed by the strategy + process.env.OPENID_ISSUER = 'https://fake-issuer.com'; + process.env.OPENID_CLIENT_ID = 'fake_client_id'; + process.env.OPENID_CLIENT_SECRET = 'fake_client_secret'; + process.env.DOMAIN_SERVER = 'https://example.com'; + process.env.OPENID_CALLBACK_URL = '/callback'; + process.env.OPENID_SCOPE = 'openid profile email'; + process.env.OPENID_REQUIRED_ROLE = 'requiredRole'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roles'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'permissions'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + delete process.env.OPENID_USERNAME_CLAIM; + delete process.env.OPENID_NAME_CLAIM; + delete process.env.OPENID_EMAIL_CLAIM; + delete process.env.PROXY; + delete process.env.OPENID_USE_PKCE; + + // Default jwtDecode mock returns a token that includes the required role. 
+ jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + permissions: ['admin'], + }); + + // By default, assume that no user is found, so createUser will be called + findUser.mockResolvedValue(null); + createUser.mockImplementation(async (userData) => { + // simulate created user with an _id property + return { _id: 'newUserId', ...userData }; + }); + updateUser.mockImplementation(async (id, userData) => { + return { _id: id, ...userData }; + }); + + // For image download, simulate a successful response + const fakeBuffer = Buffer.from('fake image'); + const fakeResponse = { + ok: true, + buffer: jest.fn().mockResolvedValue(fakeBuffer), + }; + fetch.mockResolvedValue(fakeResponse); + + // Call the setup function and capture the verify callback for the regular 'openid' strategy + // (not 'openidAdmin' which requires existing users) + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + }); + + it('should create a new user with correct username when preferred_username claim exists', async () => { + // Arrange โ€“ our userinfo already has preferred_username 'testusername' + const userinfo = tokenset.claims(); + + // Act + const { user } = await validate(tokenset); + + // Assert + expect(user.username).toBe(userinfo.preferred_username); + expect(createUser).toHaveBeenCalledWith( + expect.objectContaining({ + provider: 'openid', + openidId: userinfo.sub, + username: userinfo.preferred_username, + email: userinfo.email, + name: `${userinfo.given_name} ${userinfo.family_name}`, + }), + { enabled: false }, + true, + true, + ); + }); + + it('should use username as username when preferred_username claim is missing', async () => { + // Arrange โ€“ remove preferred_username from userinfo + const userinfo = { ...tokenset.claims() }; + delete userinfo.preferred_username; + // Expect the username to be the "username" + const expectUsername = userinfo.username; + + // Act + const { user } = await validate({ 
...tokenset, claims: () => userinfo }); + + // Assert + expect(user.username).toBe(expectUsername); + expect(createUser).toHaveBeenCalledWith( + expect.objectContaining({ username: expectUsername }), + { enabled: false }, + true, + true, + ); + }); + + it('should use email as username when username and preferred_username are missing', async () => { + // Arrange โ€“ remove username and preferred_username + const userinfo = { ...tokenset.claims() }; + delete userinfo.username; + delete userinfo.preferred_username; + const expectUsername = userinfo.email; + + // Act + const { user } = await validate({ ...tokenset, claims: () => userinfo }); + + // Assert + expect(user.username).toBe(expectUsername); + expect(createUser).toHaveBeenCalledWith( + expect.objectContaining({ username: expectUsername }), + { enabled: false }, + true, + true, + ); + }); + + it('should override username with OPENID_USERNAME_CLAIM when set', async () => { + // Arrange โ€“ set OPENID_USERNAME_CLAIM so that the sub claim is used + process.env.OPENID_USERNAME_CLAIM = 'sub'; + const userinfo = tokenset.claims(); + + // Act + const { user } = await validate(tokenset); + + // Assert โ€“ username should equal the sub (converted as-is) + expect(user.username).toBe(userinfo.sub); + expect(createUser).toHaveBeenCalledWith( + expect.objectContaining({ username: userinfo.sub }), + { enabled: false }, + true, + true, + ); + }); + + it('should set the full name correctly when given_name and family_name exist', async () => { + // Arrange + const userinfo = tokenset.claims(); + const expectedFullName = `${userinfo.given_name} ${userinfo.family_name}`; + + // Act + const { user } = await validate(tokenset); + + // Assert + expect(user.name).toBe(expectedFullName); + }); + + it('should override full name with OPENID_NAME_CLAIM when set', async () => { + // Arrange โ€“ use the name claim as the full name + process.env.OPENID_NAME_CLAIM = 'name'; + const userinfo = { ...tokenset.claims(), name: 'Custom Name' }; + 
+ // Act + const { user } = await validate({ ...tokenset, claims: () => userinfo }); + + // Assert + expect(user.name).toBe('Custom Name'); + }); + + it('should update an existing user on login', async () => { + // Arrange โ€“ simulate that a user already exists with openid provider + const existingUser = { + _id: 'existingUserId', + provider: 'openid', + email: tokenset.claims().email, + openidId: '', + username: '', + name: '', + }; + findUser.mockImplementation(async (query) => { + if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) { + return existingUser; + } + return null; + }); + + const userinfo = tokenset.claims(); + + // Act + await validate(tokenset); + + // Assert โ€“ updateUser should be called and the user object updated + expect(updateUser).toHaveBeenCalledWith( + existingUser._id, + expect.objectContaining({ + provider: 'openid', + openidId: userinfo.sub, + username: userinfo.preferred_username, + name: `${userinfo.given_name} ${userinfo.family_name}`, + }), + ); + }); + + it('should block login when email exists with different provider', async () => { + // Arrange โ€“ simulate that a user exists with same email but different provider + const existingUser = { + _id: 'existingUserId', + provider: 'google', + email: tokenset.claims().email, + googleId: 'some-google-id', + username: 'existinguser', + name: 'Existing User', + }; + findUser.mockImplementation(async (query) => { + if (query.email === tokenset.claims().email && !query.provider) { + return existingUser; + } + return null; + }); + + // Act + const result = await validate(tokenset); + + // Assert โ€“ verify that the strategy rejects login + expect(result.user).toBe(false); + expect(result.details.message).toBe(ErrorTypes.AUTH_FAILED); + expect(createUser).not.toHaveBeenCalled(); + expect(updateUser).not.toHaveBeenCalled(); + }); + + it('should block login when email fallback finds user with mismatched openidId', async () => { + const existingUser = { + 
_id: 'existingUserId', + provider: 'openid', + openidId: 'different-sub-claim', + email: tokenset.claims().email, + username: 'existinguser', + name: 'Existing User', + }; + findUser.mockImplementation(async (query) => { + if (query.$or) { + return null; + } + if (query.email === tokenset.claims().email) { + return existingUser; + } + return null; + }); + + const result = await validate(tokenset); + + expect(result.user).toBe(false); + expect(result.details.message).toBe(ErrorTypes.AUTH_FAILED); + expect(createUser).not.toHaveBeenCalled(); + expect(updateUser).not.toHaveBeenCalled(); + }); + + it('should enforce the required role and reject login if missing', async () => { + // Arrange โ€“ simulate a token without the required role. + jwtDecode.mockReturnValue({ + roles: ['SomeOtherRole'], + }); + + // Act + const { user, details } = await validate(tokenset); + + // Assert โ€“ verify that the strategy rejects login + expect(user).toBe(false); + expect(details.message).toBe('You must have "requiredRole" role to log in.'); + }); + + it('should not treat substring matches in string roles as satisfying required role', async () => { + // Arrange โ€“ override required role to "read" then re-setup + process.env.OPENID_REQUIRED_ROLE = 'read'; + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + // Token contains "bread" which *contains* "read" as a substring + jwtDecode.mockReturnValue({ + roles: 'bread', + }); + + // Act + const { user, details } = await validate(tokenset); + + // Assert โ€“ verify that substring match does not grant access + expect(user).toBe(false); + expect(details.message).toBe('You must have "read" role to log in.'); + }); + + it('should allow login when roles claim is a space-separated string containing the required role', async () => { + // Arrange โ€“ IdP returns roles as a space-delimited string + jwtDecode.mockReturnValue({ + roles: 'role1 role2 requiredRole', + }); + + // Act + 
const { user } = await validate(tokenset); + + // Assert โ€“ login succeeds when required role is present after splitting + expect(user).toBeTruthy(); + expect(createUser).toHaveBeenCalled(); + }); + + it('should allow login when roles claim is a comma-separated string containing the required role', async () => { + // Arrange โ€“ IdP returns roles as a comma-delimited string + jwtDecode.mockReturnValue({ + roles: 'role1,role2,requiredRole', + }); + + // Act + const { user } = await validate(tokenset); + + // Assert โ€“ login succeeds when required role is present after splitting + expect(user).toBeTruthy(); + expect(createUser).toHaveBeenCalled(); + }); + + it('should allow login when roles claim is a mixed comma-and-space-separated string containing the required role', async () => { + // Arrange โ€“ IdP returns roles with comma-and-space delimiters + jwtDecode.mockReturnValue({ + roles: 'role1, role2, requiredRole', + }); + + // Act + const { user } = await validate(tokenset); + + // Assert โ€“ login succeeds when required role is present after splitting + expect(user).toBeTruthy(); + expect(createUser).toHaveBeenCalled(); + }); + + it('should reject login when roles claim is a space-separated string that does not contain the required role', async () => { + // Arrange โ€“ IdP returns a delimited string but required role is absent + jwtDecode.mockReturnValue({ + roles: 'role1 role2 otherRole', + }); + + // Act + const { user, details } = await validate(tokenset); + + // Assert โ€“ login is rejected with the correct error message + expect(user).toBe(false); + expect(details.message).toBe('You must have "requiredRole" role to log in.'); + }); + + it('should allow login when single required role is present (backward compatibility)', async () => { + // Arrange โ€“ ensure single role configuration (as set in beforeEach) + // OPENID_REQUIRED_ROLE = 'requiredRole' + // Default jwtDecode mock in beforeEach already returns this role + jwtDecode.mockReturnValue({ + roles: 
['requiredRole', 'anotherRole'], + }); + + // Act + const { user } = await validate(tokenset); + + // Assert โ€“ verify that login succeeds with single role configuration + expect(user).toBeTruthy(); + expect(user.email).toBe(tokenset.claims().email); + expect(user.username).toBe(tokenset.claims().preferred_username); + expect(createUser).toHaveBeenCalled(); + }); + + describe('group overage and groups handling', () => { + it.each([ + ['groups array contains required group', ['group-required', 'other-group'], true, undefined], + [ + 'groups array missing required group', + ['other-group'], + false, + 'You must have "group-required" role to log in.', + ], + ['groups string equals required group', 'group-required', true, undefined], + [ + 'groups string is other group', + 'other-group', + false, + 'You must have "group-required" role to log in.', + ], + ])( + 'uses groups claim directly when %s (no overage)', + async (_label, groupsClaim, expectedAllowed, expectedMessage) => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ + groups: groupsClaim, + permissions: ['admin'], + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + + expect(undici.fetch).not.toHaveBeenCalled(); + expect(Boolean(user)).toBe(expectedAllowed); + expect(details?.message).toBe(expectedMessage); + }, + ); + + it.each([ + ['token kind is not id', { kind: 'access', path: 'groups', decoded: { hasgroups: true } }], + ['parameter path is not groups', { kind: 'id', path: 'roles', decoded: { hasgroups: true } }], + ['decoded token is falsy', { kind: 'id', path: 'groups', decoded: null }], + [ + 'no overage indicators in decoded token', + { + kind: 'id', + path: 'groups', + decoded: { + permissions: ['admin'], + }, + }, + ], + [ 
+ 'only _claim_names present (no _claim_sources)', + { + kind: 'id', + path: 'groups', + decoded: { + _claim_names: { groups: 'src1' }, + permissions: ['admin'], + }, + }, + ], + [ + 'only _claim_sources present (no _claim_names)', + { + kind: 'id', + path: 'groups', + decoded: { + _claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } }, + permissions: ['admin'], + }, + }, + ], + ])('does not attempt overage resolution when %s', async (_label, cfg) => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = cfg.path; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = cfg.kind; + + jwtDecode.mockReturnValue(cfg.decoded); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + + expect(undici.fetch).not.toHaveBeenCalled(); + expect(user).toBe(false); + expect(details.message).toBe('You must have "group-required" role to log in.'); + const { logger } = require('@librechat/data-schemas'); + const expectedTokenKind = cfg.kind === 'access' ? 
'access token' : 'id token'; + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining(`Key '${cfg.path}' not found in ${expectedTokenKind}!`), + ); + }); + }); + + describe('resolving groups via Microsoft Graph', () => { + it('denies login and does not call Graph when access token is missing', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + const { logger } = require('@librechat/data-schemas'); + + jwtDecode.mockReturnValue({ + hasgroups: true, + permissions: ['admin'], + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const tokensetWithoutAccess = { + ...tokenset, + access_token: undefined, + }; + + const { user, details } = await validate(tokensetWithoutAccess); + + expect(user).toBe(false); + expect(details.message).toBe('You must have "group-required" role to log in.'); + + expect(undici.fetch).not.toHaveBeenCalled(); + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining('Access token missing; cannot resolve group overage'), + ); + }); + + it.each([ + [ + 'Graph returns HTTP error', + async () => ({ + ok: false, + status: 403, + statusText: 'Forbidden', + json: async () => ({}), + }), + [ + '[openidStrategy] Failed to resolve groups via Microsoft Graph getMemberObjects: HTTP 403 Forbidden', + ], + ], + [ + 'Graph network error', + async () => { + throw new Error('network error'); + }, + [ + '[openidStrategy] Error resolving groups via Microsoft Graph getMemberObjects:', + expect.any(Error), + ], + ], + [ + 'Graph returns unexpected shape (no value)', + async () => ({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({}), + }), + [ + '[openidStrategy] Unexpected response format when resolving groups via Microsoft Graph getMemberObjects', + ], + ], + [ + 'Graph returns invalid value type', + async () => 
({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: 'not-an-array' }), + }), + [ + '[openidStrategy] Unexpected response format when resolving groups via Microsoft Graph getMemberObjects', + ], + ], + ])( + 'denies login when overage resolution fails because %s', + async (_label, setupFetch, expectedErrorArgs) => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + const { logger } = require('@librechat/data-schemas'); + + jwtDecode.mockReturnValue({ + hasgroups: true, + permissions: ['admin'], + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockImplementation(setupFetch); + + const { user, details } = await validate(tokenset); + + expect(undici.fetch).toHaveBeenCalled(); + expect(user).toBe(false); + expect(details.message).toBe('You must have "group-required" role to log in.'); + + expect(logger.error).toHaveBeenCalledWith(...expectedErrorArgs); + }, + ); + + it.each([ + [ + 'hasgroups overage and Graph contains required group', + { + hasgroups: true, + }, + ['group-required', 'some-other-group'], + true, + ], + [ + '_claim_* overage and Graph contains required group', + { + _claim_names: { groups: 'src1' }, + _claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } }, + }, + ['group-required', 'some-other-group'], + true, + ], + [ + 'hasgroups overage and Graph does NOT contain required group', + { + hasgroups: true, + }, + ['some-other-group'], + false, + ], + [ + '_claim_* overage and Graph does NOT contain required group', + { + _claim_names: { groups: 'src1' }, + _claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } }, + }, + ['some-other-group'], + false, + ], + ])( + 'resolves groups via Microsoft Graph when %s', + async (_label, decodedTokenValue, graphGroups, expectedAllowed) => { 
+ process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + const { logger } = require('@librechat/data-schemas'); + + jwtDecode.mockReturnValue(decodedTokenValue); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ + value: graphGroups, + }), + }); + + const { user } = await validate(tokenset); + + expect(undici.fetch).toHaveBeenCalledWith( + 'https://graph.microsoft.com/v1.0/me/getMemberObjects', + expect.objectContaining({ + method: 'POST', + headers: expect.objectContaining({ + Authorization: 'Bearer exchanged_graph_token', + }), + }), + ); + expect(Boolean(user)).toBe(expectedAllowed); + + expect(logger.debug).toHaveBeenCalledWith( + expect.stringContaining( + `Successfully resolved ${graphGroups.length} groups via Microsoft Graph getMemberObjects`, + ), + ); + }, + ); + }); + + describe('OBO token exchange for overage', () => { + it('exchanges access token via OBO before calling Graph API', async () => { + const openidClient = require('openid-client'); + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required'] }), + }); + + await validate(tokenset); + + expect(openidClient.genericGrantRequest).toHaveBeenCalledWith( + expect.anything(), + 'urn:ietf:params:oauth:grant-type:jwt-bearer', + expect.objectContaining({ + scope: 'https://graph.microsoft.com/User.Read', + assertion: 
tokenset.access_token, + requested_token_use: 'on_behalf_of', + }), + ); + + expect(undici.fetch).toHaveBeenCalledWith( + 'https://graph.microsoft.com/v1.0/me/getMemberObjects', + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: 'Bearer exchanged_graph_token', + }), + }), + ); + }); + + it('caches the exchanged token and reuses it on subsequent calls', async () => { + const openidClient = require('openid-client'); + const getLogStores = require('~/cache/getLogStores'); + const mockSet = jest.fn(); + const mockGet = jest + .fn() + .mockResolvedValueOnce(undefined) + .mockResolvedValueOnce({ access_token: 'exchanged_graph_token' }); + getLogStores.mockReturnValue({ get: mockGet, set: mockSet }); + + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required'] }), + }); + + // First call: cache miss โ†’ OBO exchange โ†’ cache set + await validate(tokenset); + expect(mockSet).toHaveBeenCalledWith( + '1234:overage', + { access_token: 'exchanged_graph_token' }, + 3600000, + ); + expect(openidClient.genericGrantRequest).toHaveBeenCalledTimes(1); + + // Second call: cache hit โ†’ no new OBO exchange + openidClient.genericGrantRequest.mockClear(); + await validate(tokenset); + expect(openidClient.genericGrantRequest).not.toHaveBeenCalled(); + }); + }); + + describe('admin role group overage', () => { + it('resolves admin groups via Graph when overage is detected for admin role', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + 
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required', 'admin-group-id'] }), + }); + + const { user } = await validate(tokenset); + + expect(user.role).toBe('ADMIN'); + }); + + it('does not grant admin when overage groups do not contain admin role', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required', 'other-group'] }), + }); + + const { user } = await validate(tokenset); + + expect(user).toBeTruthy(); + expect(user.role).toBeUndefined(); + }); + + it('reuses already-resolved overage groups for admin role check (no duplicate Graph call)', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + 
+ await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required', 'admin-group-id'] }), + }); + + await validate(tokenset); + + // Graph API should be called only once (for required role), admin role reuses the result + expect(undici.fetch).toHaveBeenCalledTimes(1); + }); + + it('demotes existing admin when overage groups no longer contain admin role', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + const existingAdminUser = { + _id: 'existingAdminId', + provider: 'openid', + email: tokenset.claims().email, + openidId: tokenset.claims().sub, + username: 'adminuser', + name: 'Admin User', + role: 'ADMIN', + }; + + findUser.mockImplementation(async (query) => { + if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) { + return existingAdminUser; + } + return null; + }); + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required'] }), + }); + + const { user } = await validate(tokenset); + + expect(user.role).toBe('USER'); + }); + + it('does not attempt overage for admin role when token kind is not id', async () => { + process.env.OPENID_REQUIRED_ROLE = 'requiredRole'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roles'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin'; + 
process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'access'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + hasgroups: true, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + // No Graph call since admin uses access token (not id) + expect(undici.fetch).not.toHaveBeenCalled(); + expect(user.role).toBeUndefined(); + }); + + it('resolves admin via Graph independently when OPENID_REQUIRED_ROLE is not configured', async () => { + delete process.env.OPENID_REQUIRED_ROLE; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['admin-group-id'] }), + }); + + const { user } = await validate(tokenset); + expect(user.role).toBe('ADMIN'); + expect(undici.fetch).toHaveBeenCalledTimes(1); + }); + + it('denies admin when OPENID_REQUIRED_ROLE is absent and Graph does not contain admin group', async () => { + delete process.env.OPENID_REQUIRED_ROLE; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['other-group'] }), + }); + + const { user } = await validate(tokenset); + expect(user).toBeTruthy(); + expect(user.role).toBeUndefined(); 
+ }); + + it('denies login and logs error when OBO exchange throws', async () => { + const openidClient = require('openid-client'); + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + openidClient.genericGrantRequest.mockRejectedValueOnce(new Error('OBO exchange rejected')); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + expect(user).toBe(false); + expect(details.message).toBe('You must have "group-required" role to log in.'); + expect(undici.fetch).not.toHaveBeenCalled(); + }); + + it('denies login when OBO exchange returns no access_token', async () => { + const openidClient = require('openid-client'); + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + openidClient.genericGrantRequest.mockResolvedValueOnce({ expires_in: 3600 }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + expect(user).toBe(false); + expect(details.message).toBe('You must have "group-required" role to log in.'); + expect(undici.fetch).not.toHaveBeenCalled(); + }); + }); + + it('should attempt to download and save the avatar if picture is provided', async () => { + // Act + const { user } = await validate(tokenset); + + // Assert โ€“ verify that download was attempted and the avatar field was set via updateUser + expect(fetch).toHaveBeenCalled(); + // Our mock getStrategyFunctions.saveBuffer returns '/fake/path/to/avatar.png' + expect(user.avatar).toBe('/fake/path/to/avatar.png'); + }); + + it('should not 
attempt to download avatar if picture is not provided', async () => { + // Arrange โ€“ remove picture + const userinfo = { ...tokenset.claims() }; + delete userinfo.picture; + + // Act + await validate({ ...tokenset, claims: () => userinfo }); + + // Assert โ€“ fetch should not be called and avatar should remain undefined or empty + expect(fetch).not.toHaveBeenCalled(); + // Depending on your implementation, user.avatar may be undefined or an empty string. + }); + + it('should support comma-separated multiple roles', async () => { + // Arrange + process.env.OPENID_REQUIRED_ROLE = 'someRole,anotherRole,admin'; + await setupOpenId(); // Re-initialize the strategy + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + jwtDecode.mockReturnValue({ + roles: ['anotherRole', 'aThirdRole'], + }); + + // Act + const { user } = await validate(tokenset); + + // Assert + expect(user).toBeTruthy(); + expect(user.email).toBe(tokenset.claims().email); + }); + + it('should reject login when user has none of the required multiple roles', async () => { + // Arrange + process.env.OPENID_REQUIRED_ROLE = 'someRole,anotherRole,admin'; + await setupOpenId(); // Re-initialize the strategy + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + jwtDecode.mockReturnValue({ + roles: ['aThirdRole', 'aFourthRole'], + }); + + // Act + const { user, details } = await validate(tokenset); + + // Assert + expect(user).toBe(false); + expect(details.message).toBe( + 'You must have one of: "someRole", "anotherRole", "admin" role to log in.', + ); + }); + + it('should handle spaces in comma-separated roles', async () => { + // Arrange + process.env.OPENID_REQUIRED_ROLE = ' someRole , anotherRole , admin '; + await setupOpenId(); // Re-initialize the strategy + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + jwtDecode.mockReturnValue({ + roles: ['someRole'], + }); + + // Act + const { user } 
= await validate(tokenset); + + // Assert + expect(user).toBeTruthy(); + }); + + it('should default to usePKCE false when OPENID_USE_PKCE is not defined', async () => { + const OpenIDStrategy = require('openid-client/passport').Strategy; + + delete process.env.OPENID_USE_PKCE; + await setupOpenId(); + + const callOptions = OpenIDStrategy.mock.calls[OpenIDStrategy.mock.calls.length - 1][0]; + expect(callOptions.usePKCE).toBe(false); + expect(callOptions.params?.code_challenge_method).toBeUndefined(); + }); + + it('should attach federatedTokens to user object for token propagation', async () => { + // Arrange - setup tokenset with access token, id token, refresh token, and expiration + const tokensetWithTokens = { + ...tokenset, + access_token: 'mock_access_token_abc123', + id_token: 'mock_id_token_def456', + refresh_token: 'mock_refresh_token_xyz789', + expires_at: 1234567890, + }; + + // Act - validate with the tokenset containing tokens + const { user } = await validate(tokensetWithTokens); + + // Assert - verify federatedTokens object is attached with correct values + expect(user.federatedTokens).toBeDefined(); + expect(user.federatedTokens).toEqual({ + access_token: 'mock_access_token_abc123', + id_token: 'mock_id_token_def456', + refresh_token: 'mock_refresh_token_xyz789', + expires_at: 1234567890, + }); + }); + + it('should include id_token in federatedTokens distinct from access_token', async () => { + // Arrange - use different values for access_token and id_token + const tokensetWithTokens = { + ...tokenset, + access_token: 'the_access_token', + id_token: 'the_id_token', + refresh_token: 'the_refresh_token', + expires_at: 9999999999, + }; + + // Act + const { user } = await validate(tokensetWithTokens); + + // Assert - id_token and access_token must be different values + expect(user.federatedTokens.access_token).toBe('the_access_token'); + expect(user.federatedTokens.id_token).toBe('the_id_token'); + 
expect(user.federatedTokens.id_token).not.toBe(user.federatedTokens.access_token); + }); + + it('should include tokenset along with federatedTokens', async () => { + // Arrange + const tokensetWithTokens = { + ...tokenset, + access_token: 'test_access_token', + id_token: 'test_id_token', + refresh_token: 'test_refresh_token', + expires_at: 9999999999, + }; + + // Act + const { user } = await validate(tokensetWithTokens); + + // Assert - both tokenset and federatedTokens should be present + expect(user.tokenset).toBeDefined(); + expect(user.federatedTokens).toBeDefined(); + expect(user.tokenset.access_token).toBe('test_access_token'); + expect(user.tokenset.id_token).toBe('test_id_token'); + expect(user.federatedTokens.access_token).toBe('test_access_token'); + expect(user.federatedTokens.id_token).toBe('test_id_token'); + }); + + it('should set role to "ADMIN" if OPENID_ADMIN_ROLE is set and user has that role', async () => { + // Act + const { user } = await validate(tokenset); + + // Assert โ€“ verify that the user role is set to "ADMIN" + expect(user.role).toBe('ADMIN'); + }); + + it('should not set user role if OPENID_ADMIN_ROLE is set but the user does not have that role', async () => { + // Arrange โ€“ simulate a token without the admin permission + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + permissions: ['not-admin'], + }); + + // Act + const { user } = await validate(tokenset); + + // Assert โ€“ verify that the user role is not defined + expect(user.role).toBeUndefined(); + }); + + it('should demote existing admin user when admin role is removed from token', async () => { + // Arrange โ€“ simulate an existing user who is currently an admin + const existingAdminUser = { + _id: 'existingAdminId', + provider: 'openid', + email: tokenset.claims().email, + openidId: tokenset.claims().sub, + username: 'adminuser', + name: 'Admin User', + role: 'ADMIN', + }; + + findUser.mockImplementation(async (query) => { + if (query.openidId === 
tokenset.claims().sub || query.email === tokenset.claims().email) { + return existingAdminUser; + } + return null; + }); + + // Token without admin permission + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + permissions: ['not-admin'], + }); + + const { logger } = require('@librechat/data-schemas'); + + // Act + const { user } = await validate(tokenset); + + // Assert โ€“ verify that the user was demoted + expect(user.role).toBe('USER'); + expect(updateUser).toHaveBeenCalledWith( + existingAdminUser._id, + expect.objectContaining({ + role: 'USER', + }), + ); + expect(logger.info).toHaveBeenCalledWith( + expect.stringContaining('demoted from admin - role no longer present in token'), + ); + }); + + it('should NOT demote admin user when admin role env vars are not configured', async () => { + // Arrange โ€“ remove admin role env vars + delete process.env.OPENID_ADMIN_ROLE; + delete process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH; + delete process.env.OPENID_ADMIN_ROLE_TOKEN_KIND; + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + // Simulate an existing admin user + const existingAdminUser = { + _id: 'existingAdminId', + provider: 'openid', + email: tokenset.claims().email, + openidId: tokenset.claims().sub, + username: 'adminuser', + name: 'Admin User', + role: 'ADMIN', + }; + + findUser.mockImplementation(async (query) => { + if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) { + return existingAdminUser; + } + return null; + }); + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + }); + + // Act + const { user } = await validate(tokenset); + + // Assert โ€“ verify that the admin user was NOT demoted + expect(user.role).toBe('ADMIN'); + expect(updateUser).toHaveBeenCalledWith( + existingAdminUser._id, + expect.objectContaining({ + role: 'ADMIN', + }), + ); + }); + + describe('lodash get - nested path extraction', () => { + it('should extract 
roles from deeply nested token path', async () => { + process.env.OPENID_REQUIRED_ROLE = 'app-user'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'resource_access.my-client.roles'; + + jwtDecode.mockReturnValue({ + resource_access: { + 'my-client': { + roles: ['app-user', 'viewer'], + }, + }, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user).toBeTruthy(); + expect(user.email).toBe(tokenset.claims().email); + }); + + it('should extract roles from three-level nested path', async () => { + process.env.OPENID_REQUIRED_ROLE = 'editor'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'data.access.permissions.roles'; + + jwtDecode.mockReturnValue({ + data: { + access: { + permissions: { + roles: ['editor', 'reader'], + }, + }, + }, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user).toBeTruthy(); + }); + + it('should log error and reject login when required role path does not exist in token', async () => { + const { logger } = require('@librechat/data-schemas'); + process.env.OPENID_REQUIRED_ROLE = 'app-user'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'resource_access.nonexistent.roles'; + + jwtDecode.mockReturnValue({ + resource_access: { + 'my-client': { + roles: ['app-user'], + }, + }, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining("Key 'resource_access.nonexistent.roles' not found in id token!"), + ); + expect(user).toBe(false); + expect(details.message).toContain('role to log in'); + }); + + it('should handle missing intermediate nested path gracefully', async () => { + const { logger 
} = require('@librechat/data-schemas'); + process.env.OPENID_REQUIRED_ROLE = 'user'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'org.team.roles'; + + jwtDecode.mockReturnValue({ + org: { + other: 'value', + }, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining("Key 'org.team.roles' not found in id token!"), + ); + expect(user).toBe(false); + }); + + it('should extract admin role from nested path in access token', async () => { + process.env.OPENID_ADMIN_ROLE = 'admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'realm_access.roles'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'access'; + + jwtDecode.mockImplementation((token) => { + if (token === 'fake_access_token') { + return { + realm_access: { + roles: ['admin', 'user'], + }, + }; + } + return { + roles: ['requiredRole'], + }; + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user.role).toBe('ADMIN'); + }); + + it('should extract admin role from nested path in userinfo', async () => { + process.env.OPENID_ADMIN_ROLE = 'admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'organization.permissions'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'userinfo'; + + const userinfoWithNestedGroups = { + ...tokenset.claims(), + organization: { + permissions: ['admin', 'write'], + }, + }; + + require('openid-client').fetchUserInfo.mockResolvedValue({ + organization: { + permissions: ['admin', 'write'], + }, + }); + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate({ + ...tokenset, + claims: () => userinfoWithNestedGroups, + }); 
+ + expect(user.role).toBe('ADMIN'); + }); + + it('should handle boolean admin role value', async () => { + process.env.OPENID_ADMIN_ROLE = 'admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'is_admin'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + is_admin: true, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user.role).toBe('ADMIN'); + }); + + it('should handle string admin role value matching exactly', async () => { + process.env.OPENID_ADMIN_ROLE = 'super-admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'role'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + role: 'super-admin', + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user.role).toBe('ADMIN'); + }); + + it('should not set admin role when string value does not match', async () => { + process.env.OPENID_ADMIN_ROLE = 'super-admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'role'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + role: 'regular-user', + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user.role).toBeUndefined(); + }); + + it('should handle array admin role value', async () => { + process.env.OPENID_ADMIN_ROLE = 'site-admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'app_roles'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + app_roles: ['user', 'site-admin', 'moderator'], + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user.role).toBe('ADMIN'); + }); + + it('should not set admin when role is not 
in array', async () => { + process.env.OPENID_ADMIN_ROLE = 'site-admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'app_roles'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + app_roles: ['user', 'moderator'], + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user.role).toBeUndefined(); + }); + + it('should grant admin when admin role claim is a space-separated string containing the admin role', async () => { + // Arrange โ€“ IdP returns admin roles as a space-delimited string + process.env.OPENID_ADMIN_ROLE = 'site-admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'app_roles'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + app_roles: 'user site-admin moderator', + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + // Act + const { user } = await validate(tokenset); + + // Assert โ€“ admin role is granted after splitting the delimited string + expect(user.role).toBe('ADMIN'); + }); + + it('should not grant admin when admin role claim is a space-separated string that does not contain the admin role', async () => { + // Arrange โ€“ delimited string present but admin role is absent + process.env.OPENID_ADMIN_ROLE = 'site-admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'app_roles'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + app_roles: 'user moderator', + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + // Act + const { user } = await validate(tokenset); + + // Assert โ€“ admin role is not granted + expect(user.role).toBeUndefined(); + }); + + it('should handle nested path with special characters in keys', async () => { + process.env.OPENID_REQUIRED_ROLE = 'app-user'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 
'resource_access.my-app-123.roles'; + + jwtDecode.mockReturnValue({ + resource_access: { + 'my-app-123': { + roles: ['app-user'], + }, + }, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(user).toBeTruthy(); + }); + + it('should handle empty object at nested path', async () => { + const { logger } = require('@librechat/data-schemas'); + process.env.OPENID_REQUIRED_ROLE = 'user'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'access.roles'; + + jwtDecode.mockReturnValue({ + access: {}, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining("Key 'access.roles' not found in id token!"), + ); + expect(user).toBe(false); + }); + + it('should handle null value at intermediate path', async () => { + const { logger } = require('@librechat/data-schemas'); + process.env.OPENID_REQUIRED_ROLE = 'user'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'data.roles'; + + jwtDecode.mockReturnValue({ + data: null, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining("Key 'data.roles' not found in id token!"), + ); + expect(user).toBe(false); + }); + + it('should reject login with invalid admin role token kind', async () => { + process.env.OPENID_ADMIN_ROLE = 'admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'roles'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'invalid'; + + const { logger } = require('@librechat/data-schemas'); + + jwtDecode.mockReturnValue({ + roles: ['requiredRole', 'admin'], + }); + + await setupOpenId(); + verifyCallback = 
require('openid-client/passport').__getVerifyCallbackByName('openid'); + + await expect(validate(tokenset)).rejects.toThrow('Invalid admin role token kind'); + + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining( + "Invalid admin role token kind: invalid. Must be one of 'access', 'id', or 'userinfo'", + ), + ); + }); + + it('should reject login when roles path returns invalid type (object)', async () => { + const { logger } = require('@librechat/data-schemas'); + process.env.OPENID_REQUIRED_ROLE = 'app-user'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roles'; + + jwtDecode.mockReturnValue({ + roles: { admin: true, user: false }, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining("Key 'roles' not found in id token!"), + ); + expect(user).toBe(false); + expect(details.message).toContain('role to log in'); + }); + + it('should reject login when roles path returns invalid type (number)', async () => { + const { logger } = require('@librechat/data-schemas'); + process.env.OPENID_REQUIRED_ROLE = 'user'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roleCount'; + + jwtDecode.mockReturnValue({ + roleCount: 5, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining("Key 'roleCount' not found in id token!"), + ); + expect(user).toBe(false); + }); + }); + + describe('OPENID_EMAIL_CLAIM', () => { + it('should use the default email when OPENID_EMAIL_CLAIM is not set', async () => { + const { user } = await validate(tokenset); + expect(user.email).toBe('test@example.com'); + }); + + it('should use the configured claim when OPENID_EMAIL_CLAIM is set', async () => { + 
process.env.OPENID_EMAIL_CLAIM = 'upn'; + const userinfo = { ...tokenset.claims(), upn: 'user@corp.example.com' }; + + const { user } = await validate({ ...tokenset, claims: () => userinfo }); + + expect(user.email).toBe('user@corp.example.com'); + expect(createUser).toHaveBeenCalledWith( + expect.objectContaining({ email: 'user@corp.example.com' }), + expect.anything(), + true, + true, + ); + }); + + it('should fall back to preferred_username when email is missing and OPENID_EMAIL_CLAIM is not set', async () => { + const userinfo = { ...tokenset.claims() }; + delete userinfo.email; + + const { user } = await validate({ ...tokenset, claims: () => userinfo }); + + expect(user.email).toBe('testusername'); + }); + + it('should fall back to upn when email and preferred_username are missing and OPENID_EMAIL_CLAIM is not set', async () => { + const userinfo = { ...tokenset.claims(), upn: 'user@corp.example.com' }; + delete userinfo.email; + delete userinfo.preferred_username; + + const { user } = await validate({ ...tokenset, claims: () => userinfo }); + + expect(user.email).toBe('user@corp.example.com'); + }); + + it('should ignore empty string OPENID_EMAIL_CLAIM and use default fallback', async () => { + process.env.OPENID_EMAIL_CLAIM = ''; + + const { user } = await validate(tokenset); + + expect(user.email).toBe('test@example.com'); + }); + + it('should trim whitespace from OPENID_EMAIL_CLAIM and resolve correctly', async () => { + process.env.OPENID_EMAIL_CLAIM = ' upn '; + const userinfo = { ...tokenset.claims(), upn: 'user@corp.example.com' }; + + const { user } = await validate({ ...tokenset, claims: () => userinfo }); + + expect(user.email).toBe('user@corp.example.com'); + }); + + it('should ignore whitespace-only OPENID_EMAIL_CLAIM and use default fallback', async () => { + process.env.OPENID_EMAIL_CLAIM = ' '; + + const { user } = await validate(tokenset); + + expect(user.email).toBe('test@example.com'); + }); + + it('should fall back to default chain with 
warning when configured claim is missing from userinfo', async () => { + const { logger } = require('@librechat/data-schemas'); + process.env.OPENID_EMAIL_CLAIM = 'nonexistent_claim'; + + const { user } = await validate(tokenset); + + expect(user.email).toBe('test@example.com'); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('OPENID_EMAIL_CLAIM="nonexistent_claim" not present in userinfo'), + ); + }); + }); + + describe('Tenant-scoped config', () => { + it('should call resolveAppConfigForUser for tenant user', async () => { + const existingUser = { + _id: 'openid-tenant-user', + provider: 'openid', + openidId: '1234', + email: 'test@example.com', + tenantId: 'tenant-d', + role: 'USER', + }; + findUser.mockResolvedValue(existingUser); + + await validate(tokenset); + + expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existingUser); + }); + + it('should use baseConfig for new user without calling resolveAppConfigForUser', async () => { + findUser.mockResolvedValue(null); + + await validate(tokenset); + + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); + expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); + + it('should block login when tenant config restricts the domain', async () => { + const { isEmailDomainAllowed } = require('@librechat/api'); + const existingUser = { + _id: 'openid-tenant-blocked', + provider: 'openid', + openidId: '1234', + email: 'test@example.com', + tenantId: 'tenant-restrict', + role: 'USER', + }; + findUser.mockResolvedValue(existingUser); + resolveAppConfigForUser.mockResolvedValue({ + registration: { allowedDomains: ['other.com'] }, + }); + isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false); + + const { user, details } = await validate(tokenset); + expect(user).toBe(false); + expect(details).toEqual({ message: 'Email domain not allowed' }); + }); + }); +}); diff --git a/api/strategies/samlStrategy.js b/api/strategies/samlStrategy.js index 
abcb3de099..21e7bdd001 100644 --- a/api/strategies/samlStrategy.js +++ b/api/strategies/samlStrategy.js @@ -5,7 +5,11 @@ const passport = require('passport'); const { ErrorTypes } = require('librechat-data-provider'); const { hashToken, logger } = require('@librechat/data-schemas'); const { Strategy: SamlStrategy } = require('@node-saml/passport-saml'); -const { getBalanceConfig, isEmailDomainAllowed } = require('@librechat/api'); +const { + getBalanceConfig, + isEmailDomainAllowed, + resolveAppConfigForUser, +} = require('@librechat/api'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { findUser, createUser, updateUser } = require('~/models'); const { getAppConfig } = require('~/server/services/Config'); @@ -193,9 +197,9 @@ async function setupSaml() { logger.debug('[samlStrategy] SAML profile:', profile); const userEmail = getEmail(profile) || ''; - const appConfig = await getAppConfig({ baseOnly: true }); - if (!isEmailDomainAllowed(userEmail, appConfig?.registration?.allowedDomains)) { + const baseConfig = await getAppConfig({ baseOnly: true }); + if (!isEmailDomainAllowed(userEmail, baseConfig?.registration?.allowedDomains)) { logger.error( `[SAML Strategy] Authentication blocked - email domain not allowed [Email: ${userEmail}]`, ); @@ -223,6 +227,17 @@ async function setupSaml() { }); } + const appConfig = user?.tenantId + ? 
await resolveAppConfigForUser(getAppConfig, user) + : baseConfig; + + if (!isEmailDomainAllowed(userEmail, appConfig?.registration?.allowedDomains)) { + logger.error( + `[SAML Strategy] Authentication blocked - email domain not allowed [Email: ${userEmail}]`, + ); + return done(null, false, { message: 'Email domain not allowed' }); + } + const fullName = getFullName(profile); const username = convertToUsername( diff --git a/api/strategies/samlStrategy.spec.js b/api/strategies/samlStrategy.spec.js index 1d16719b87..2022d34b33 100644 --- a/api/strategies/samlStrategy.spec.js +++ b/api/strategies/samlStrategy.spec.js @@ -30,6 +30,7 @@ jest.mock('@librechat/api', () => ({ tokenCredits: 1000, startBalance: 1000, })), + resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})), })); jest.mock('~/server/services/Config/EndpointService', () => ({ config: {}, @@ -47,6 +48,9 @@ const fs = require('fs'); const path = require('path'); const fetch = require('node-fetch'); const { Strategy: SamlStrategy } = require('@node-saml/passport-saml'); +const { findUser } = require('~/models'); +const { resolveAppConfigForUser } = require('@librechat/api'); +const { getAppConfig } = require('~/server/services/Config'); const { setupSaml, getCertificateContent } = require('./samlStrategy'); // Configure fs mock @@ -440,4 +444,50 @@ u7wlOSk+oFzDIO/UILIA expect(fetch).not.toHaveBeenCalled(); }); + + it('should pass the found user to resolveAppConfigForUser', async () => { + const existingUser = { + _id: 'tenant-user-id', + provider: 'saml', + samlId: 'saml-1234', + email: 'test@example.com', + tenantId: 'tenant-c', + role: 'USER', + }; + findUser.mockResolvedValue(existingUser); + + const profile = { ...baseProfile }; + await validate(profile); + + expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existingUser); + }); + + it('should use baseConfig for new SAML user without calling resolveAppConfigForUser', async () => { + const profile = { ...baseProfile }; + 
await validate(profile); + + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); + expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); + + it('should block login when tenant config restricts the domain', async () => { + const { isEmailDomainAllowed } = require('@librechat/api'); + const existingUser = { + _id: 'tenant-blocked', + provider: 'saml', + samlId: 'saml-1234', + email: 'test@example.com', + tenantId: 'tenant-restrict', + role: 'USER', + }; + findUser.mockResolvedValue(existingUser); + resolveAppConfigForUser.mockResolvedValue({ + registration: { allowedDomains: ['other.com'] }, + }); + isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false); + + const profile = { ...baseProfile }; + const { user } = await validate(profile); + expect(user).toBe(false); + }); }); diff --git a/api/strategies/socialLogin.js b/api/strategies/socialLogin.js index 7585e8e2fe..a5fe78e17d 100644 --- a/api/strategies/socialLogin.js +++ b/api/strategies/socialLogin.js @@ -1,6 +1,6 @@ const { logger } = require('@librechat/data-schemas'); const { ErrorTypes } = require('librechat-data-provider'); -const { isEnabled, isEmailDomainAllowed } = require('@librechat/api'); +const { isEnabled, isEmailDomainAllowed, resolveAppConfigForUser } = require('@librechat/api'); const { createSocialUser, handleExistingUser } = require('./process'); const { getAppConfig } = require('~/server/services/Config'); const { findUser } = require('~/models'); @@ -13,9 +13,8 @@ const socialLogin = profile, }); - const appConfig = await getAppConfig({ baseOnly: true }); - - if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + const baseConfig = await getAppConfig({ baseOnly: true }); + if (!isEmailDomainAllowed(email, baseConfig?.registration?.allowedDomains)) { logger.error( `[${provider}Login] Authentication blocked - email domain not allowed [Email: ${email}]`, ); @@ -41,6 +40,20 @@ const socialLogin = } } + const appConfig = 
existingUser?.tenantId + ? await resolveAppConfigForUser(getAppConfig, existingUser) + : baseConfig; + + if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + logger.error( + `[${provider}Login] Authentication blocked - email domain not allowed [Email: ${email}]`, + ); + const error = new Error(ErrorTypes.AUTH_FAILED); + error.code = ErrorTypes.AUTH_FAILED; + error.message = 'Email domain not allowed'; + return cb(error); + } + if (existingUser?.provider === provider) { await handleExistingUser(existingUser, avatarUrl, appConfig, email); return cb(null, existingUser); diff --git a/api/strategies/socialLogin.test.js b/api/strategies/socialLogin.test.js index ba4778c8b1..4fde397d55 100644 --- a/api/strategies/socialLogin.test.js +++ b/api/strategies/socialLogin.test.js @@ -3,6 +3,8 @@ const { ErrorTypes } = require('librechat-data-provider'); const { createSocialUser, handleExistingUser } = require('./process'); const socialLogin = require('./socialLogin'); const { findUser } = require('~/models'); +const { resolveAppConfigForUser } = require('@librechat/api'); +const { getAppConfig } = require('~/server/services/Config'); jest.mock('@librechat/data-schemas', () => { const actualModule = jest.requireActual('@librechat/data-schemas'); @@ -25,6 +27,10 @@ jest.mock('@librechat/api', () => ({ ...jest.requireActual('@librechat/api'), isEnabled: jest.fn().mockReturnValue(true), isEmailDomainAllowed: jest.fn().mockReturnValue(true), + resolveAppConfigForUser: jest.fn().mockResolvedValue({ + fileStrategy: 'local', + balance: { enabled: false }, + }), })); jest.mock('~/models', () => ({ @@ -66,10 +72,7 @@ describe('socialLogin', () => { googleId: googleId, }; - /** Mock findUser to return user on first call (by googleId), null on second call */ - findUser - .mockResolvedValueOnce(existingUser) // First call: finds by googleId - .mockResolvedValueOnce(null); // Second call would be by email, but won't be reached + 
findUser.mockResolvedValueOnce(existingUser).mockResolvedValueOnce(null); const mockProfile = { id: googleId, @@ -83,13 +86,9 @@ describe('socialLogin', () => { await loginFn(null, null, null, mockProfile, callback); - /** Verify it searched by googleId first */ expect(findUser).toHaveBeenNthCalledWith(1, { googleId: googleId }); - - /** Verify it did NOT search by email (because it found user by googleId) */ expect(findUser).toHaveBeenCalledTimes(1); - /** Verify handleExistingUser was called with the new email */ expect(handleExistingUser).toHaveBeenCalledWith( existingUser, 'https://example.com/avatar.png', @@ -97,7 +96,6 @@ describe('socialLogin', () => { newEmail, ); - /** Verify callback was called with success */ expect(callback).toHaveBeenCalledWith(null, existingUser); }); @@ -113,7 +111,7 @@ describe('socialLogin', () => { facebookId: facebookId, }; - findUser.mockResolvedValue(existingUser); // Always returns user + findUser.mockResolvedValue(existingUser); const mockProfile = { id: facebookId, @@ -127,7 +125,6 @@ describe('socialLogin', () => { await loginFn(null, null, null, mockProfile, callback); - /** Verify it searched by facebookId first */ expect(findUser).toHaveBeenCalledWith({ facebookId: facebookId }); expect(findUser.mock.calls[0]).toEqual([{ facebookId: facebookId }]); @@ -150,13 +147,10 @@ describe('socialLogin', () => { _id: 'user789', email: email, provider: 'google', - googleId: 'old-google-id', // Different googleId (edge case) + googleId: 'old-google-id', }; - /** First call (by googleId) returns null, second call (by email) returns user */ - findUser - .mockResolvedValueOnce(null) // By googleId - .mockResolvedValueOnce(existingUser); // By email + findUser.mockResolvedValueOnce(null).mockResolvedValueOnce(existingUser); const mockProfile = { id: googleId, @@ -170,13 +164,10 @@ describe('socialLogin', () => { await loginFn(null, null, null, mockProfile, callback); - /** Verify both searches happened */ 
expect(findUser).toHaveBeenNthCalledWith(1, { googleId: googleId }); - /** Email passed as-is; findUser implementation handles case normalization */ expect(findUser).toHaveBeenNthCalledWith(2, { email: email }); expect(findUser).toHaveBeenCalledTimes(2); - /** Verify warning log */ expect(logger.warn).toHaveBeenCalledWith( `[${provider}Login] User found by email: ${email} but not by ${provider}Id`, ); @@ -197,7 +188,6 @@ describe('socialLogin', () => { googleId: googleId, }; - /** Both searches return null */ findUser.mockResolvedValue(null); createSocialUser.mockResolvedValue(newUser); @@ -213,10 +203,8 @@ describe('socialLogin', () => { await loginFn(null, null, null, mockProfile, callback); - /** Verify both searches happened */ expect(findUser).toHaveBeenCalledTimes(2); - /** Verify createSocialUser was called */ expect(createSocialUser).toHaveBeenCalledWith({ email: email, avatarUrl: 'https://example.com/avatar.png', @@ -242,12 +230,10 @@ describe('socialLogin', () => { const existingUser = { _id: 'user123', email: email, - provider: 'local', // Different provider + provider: 'local', }; - findUser - .mockResolvedValueOnce(null) // By googleId - .mockResolvedValueOnce(existingUser); // By email + findUser.mockResolvedValueOnce(null).mockResolvedValueOnce(existingUser); const mockProfile = { id: googleId, @@ -261,7 +247,6 @@ describe('socialLogin', () => { await loginFn(null, null, null, mockProfile, callback); - /** Verify error callback */ expect(callback).toHaveBeenCalledWith( expect.objectContaining({ code: ErrorTypes.AUTH_FAILED, @@ -274,4 +259,104 @@ describe('socialLogin', () => { ); }); }); + + describe('Tenant-scoped config', () => { + it('should call resolveAppConfigForUser for tenant user', async () => { + const provider = 'google'; + const googleId = 'google-tenant-user'; + const email = 'tenant@example.com'; + + const existingUser = { + _id: 'userTenant', + email, + provider: 'google', + googleId, + tenantId: 'tenant-b', + role: 'USER', + }; + + 
findUser.mockResolvedValue(existingUser); + + const mockProfile = { + id: googleId, + emails: [{ value: email, verified: true }], + photos: [{ value: 'https://example.com/avatar.png' }], + name: { givenName: 'Tenant', familyName: 'User' }, + }; + + const loginFn = socialLogin(provider, mockGetProfileDetails); + const callback = jest.fn(); + + await loginFn(null, null, null, mockProfile, callback); + + expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existingUser); + }); + + it('should use baseConfig for non-tenant user without calling resolveAppConfigForUser', async () => { + const provider = 'google'; + const googleId = 'google-new-tenant'; + const email = 'new@example.com'; + + findUser.mockResolvedValue(null); + createSocialUser.mockResolvedValue({ + _id: 'newUser', + email, + provider: 'google', + googleId, + }); + + const mockProfile = { + id: googleId, + emails: [{ value: email, verified: true }], + photos: [{ value: 'https://example.com/avatar.png' }], + name: { givenName: 'New', familyName: 'User' }, + }; + + const loginFn = socialLogin(provider, mockGetProfileDetails); + const callback = jest.fn(); + + await loginFn(null, null, null, mockProfile, callback); + + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); + expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); + + it('should block login when tenant config restricts the domain', async () => { + const { isEmailDomainAllowed } = require('@librechat/api'); + const provider = 'google'; + const googleId = 'google-tenant-blocked'; + const email = 'blocked@example.com'; + + const existingUser = { + _id: 'userBlocked', + email, + provider: 'google', + googleId, + tenantId: 'tenant-restrict', + role: 'USER', + }; + + findUser.mockResolvedValue(existingUser); + resolveAppConfigForUser.mockResolvedValue({ + registration: { allowedDomains: ['other.com'] }, + }); + isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false); + + const mockProfile = { + id: 
googleId, + emails: [{ value: email, verified: true }], + photos: [{ value: 'https://example.com/avatar.png' }], + name: { givenName: 'Blocked', familyName: 'User' }, + }; + + const loginFn = socialLogin(provider, mockGetProfileDetails); + const callback = jest.fn(); + + await loginFn(null, null, null, mockProfile, callback); + + expect(callback).toHaveBeenCalledWith( + expect.objectContaining({ message: 'Email domain not allowed' }), + ); + }); + }); }); diff --git a/packages/api/src/app/index.ts b/packages/api/src/app/index.ts index 7acb75e09d..8d8802f016 100644 --- a/packages/api/src/app/index.ts +++ b/packages/api/src/app/index.ts @@ -3,3 +3,4 @@ export * from './config'; export * from './permissions'; export * from './cdn'; export * from './checks'; +export * from './resolve'; diff --git a/packages/api/src/app/resolve.spec.ts b/packages/api/src/app/resolve.spec.ts new file mode 100644 index 0000000000..d7585198a0 --- /dev/null +++ b/packages/api/src/app/resolve.spec.ts @@ -0,0 +1,95 @@ +import type { AsyncLocalStorage } from 'async_hooks'; + +jest.mock('@librechat/data-schemas', () => { + // eslint-disable-next-line @typescript-eslint/no-require-imports + const { AsyncLocalStorage: ALS } = require('async_hooks'); + return { tenantStorage: new ALS() }; +}); + +import { resolveAppConfigForUser } from './resolve'; + +const { tenantStorage } = jest.requireMock('@librechat/data-schemas') as { + tenantStorage: AsyncLocalStorage<{ tenantId?: string }>; +}; + +describe('resolveAppConfigForUser', () => { + const mockGetAppConfig = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + mockGetAppConfig.mockResolvedValue({ registration: {} }); + }); + + it('calls getAppConfig with baseOnly when user is null', async () => { + await resolveAppConfigForUser(mockGetAppConfig, null); + expect(mockGetAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); + + it('calls getAppConfig with baseOnly when user is undefined', async () => { + await 
resolveAppConfigForUser(mockGetAppConfig, undefined); + expect(mockGetAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); + + it('calls getAppConfig with baseOnly when user has no tenantId', async () => { + await resolveAppConfigForUser(mockGetAppConfig, { role: 'USER' }); + expect(mockGetAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); + + it('calls getAppConfig with role and tenantId when user has tenantId', async () => { + await resolveAppConfigForUser(mockGetAppConfig, { tenantId: 'tenant-a', role: 'USER' }); + expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'USER', tenantId: 'tenant-a' }); + }); + + it('calls tenantStorage.run for tenant users but not for non-tenant users', async () => { + const runSpy = jest.spyOn(tenantStorage, 'run'); + + await resolveAppConfigForUser(mockGetAppConfig, { role: 'USER' }); + expect(runSpy).not.toHaveBeenCalled(); + + await resolveAppConfigForUser(mockGetAppConfig, { tenantId: 'tenant-b', role: 'ADMIN' }); + expect(runSpy).toHaveBeenCalledWith({ tenantId: 'tenant-b' }, expect.any(Function)); + + runSpy.mockRestore(); + }); + + it('makes tenantId available via ALS inside getAppConfig', async () => { + let capturedContext: { tenantId?: string } | undefined; + mockGetAppConfig.mockImplementation(async () => { + capturedContext = tenantStorage.getStore(); + return { registration: {} }; + }); + + await resolveAppConfigForUser(mockGetAppConfig, { tenantId: 'tenant-c', role: 'USER' }); + + expect(capturedContext).toEqual({ tenantId: 'tenant-c' }); + }); + + it('returns the config from getAppConfig', async () => { + const tenantConfig = { registration: { allowedDomains: ['example.com'] } }; + mockGetAppConfig.mockResolvedValue(tenantConfig); + + const result = await resolveAppConfigForUser(mockGetAppConfig, { + tenantId: 'tenant-d', + role: 'USER', + }); + + expect(result).toBe(tenantConfig); + }); + + it('calls getAppConfig with role undefined when user has tenantId but no role', async () => { + await 
resolveAppConfigForUser(mockGetAppConfig, { tenantId: 'tenant-e' }); + expect(mockGetAppConfig).toHaveBeenCalledWith({ role: undefined, tenantId: 'tenant-e' }); + }); + + it('propagates rejection from getAppConfig for tenant users', async () => { + mockGetAppConfig.mockRejectedValue(new Error('config unavailable')); + await expect( + resolveAppConfigForUser(mockGetAppConfig, { tenantId: 'tenant-f', role: 'USER' }), + ).rejects.toThrow('config unavailable'); + }); + + it('propagates rejection from getAppConfig for baseOnly path', async () => { + mockGetAppConfig.mockRejectedValue(new Error('cache failure')); + await expect(resolveAppConfigForUser(mockGetAppConfig, null)).rejects.toThrow('cache failure'); + }); +}); diff --git a/packages/api/src/app/resolve.ts b/packages/api/src/app/resolve.ts new file mode 100644 index 0000000000..0810400222 --- /dev/null +++ b/packages/api/src/app/resolve.ts @@ -0,0 +1,39 @@ +import { tenantStorage } from '@librechat/data-schemas'; +import type { AppConfig } from '@librechat/data-schemas'; + +interface UserForConfigResolution { + tenantId?: string; + role?: string; +} + +type GetAppConfig = (opts: { + role?: string; + tenantId?: string; + baseOnly?: boolean; +}) => Promise; + +/** + * Resolves AppConfig scoped to the given user's tenant when available, + * falling back to YAML-only base config for new users or non-tenant deployments. + * + * Auth flows only apply role-level overrides (userId is not passed) because + * user/group principal resolution requires heavier DB work that is deferred + * to post-authentication config calls. + * + * `tenantId` is propagated through two channels that serve different purposes: + * - `tenantStorage.run()` sets the ALS context so Mongoose's `applyTenantIsolation` + * plugin scopes any DB queries (e.g., `getApplicableConfigs`) to the tenant. + * - The explicit `tenantId` parameter to `getAppConfig` is used for cache-key + * computation in `overrideCacheKey()`. Both channels are required. 
+ */ +export async function resolveAppConfigForUser( + getAppConfig: GetAppConfig, + user: UserForConfigResolution | null | undefined, +): Promise { + if (user?.tenantId) { + return tenantStorage.run({ tenantId: user.tenantId }, async () => + getAppConfig({ role: user.role, tenantId: user.tenantId }), + ); + } + return getAppConfig({ baseOnly: true }); +} From 935288f84127084c5675367b0ae5ad38f438be87 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 28 Mar 2026 10:36:43 -0400 Subject: [PATCH 12/18] =?UTF-8?q?=F0=9F=8F=97=EF=B8=8F=20feat:=203-Tier=20?= =?UTF-8?q?MCP=20Server=20Architecture=20with=20Config-Source=20Lazy=20Ini?= =?UTF-8?q?t=20(#12435)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add MCPServerSource type, tenantMcpPolicy schema, and source-based dbSourced wiring - Add `tenantMcpPolicy` to `mcpSettings` in YAML config schema with `enabled`, `maxServersPerTenant`, `allowedTransports`, and `allowedDomains` - Add `MCPServerSource` type ('yaml' | 'config' | 'user') and `source` field to `ParsedServerConfig` - Change `dbSourced` determination from `!!config.dbId` to `config.source === 'user'` across MCPManager, ConnectionsRepository, UserConnectionManager, and MCPServerInspector - Set `source: 'user'` on all DB-sourced servers in ServerConfigsDB * feat: three-layer MCPServersRegistry with config cache and lazy init - Add `configCacheRepo` as third repository layer between YAML cache and DB for admin-defined config-source MCP servers - Implement `ensureConfigServers()` that identifies config-override servers from resolved `getAppConfig()` mcpConfig, lazily inspects them, and caches parsed configs with `source: 'config'` - Add `lazyInitConfigServer()` with timeout, stub-on-failure, and concurrent-init deduplication via `pendingConfigInits` map - Extend `getAllServerConfigs()` with optional `configServers` param for three-way merge: YAML โ†’ Config โ†’ User - Add `getServerConfig()` lookup through 
config cache layer - Add `invalidateConfigCache()` for clearing config-source inspection results on admin config mutations - Tag `source: 'yaml'` on CACHE-stored servers and `source: 'user'` on DB-stored servers in `addServer()` and `addServerStub()` * feat: wire tenant context into MCP controllers, services, and cache invalidation - Resolve config-source servers via `getAppConfig({ role, tenantId })` in `getMCPTools()` and `getMCPServersList()` controllers - Pass `ensureConfigServers()` results through `getAllServerConfigs()` for three-way merge of YAML + Config + User servers - Add tenant/role context to `getMCPSetupData()` and connection status routes via `getTenantId()` from ALS - Add `clearMcpConfigCache()` to `invalidateConfigCaches()` so admin config mutations trigger re-inspection of config-source MCP servers * feat: enforce tenantMcpPolicy on admin config mcpServers mutations - Add `validateMcpServerPolicy()` helper that checks mcpServers against operator-defined `tenantMcpPolicy` (enabled, maxServersPerTenant, allowedTransports, allowedDomains) - Wire validation into `upsertConfigOverrides` and `patchConfigField` handlers — rejects with 403 when policy is violated - Infer transport type from config shape (command → stdio, url protocol → websocket/sse, type field → streamable-http) - Validate server domains against policy allowlist when configured * revert: remove tenantMcpPolicy schema and enforcement The existing admin config CRUD routes already provide the mechanism for granular MCP server prepopulation (groups, roles, users). The tenantMcpPolicy gating adds unnecessary complexity that can be revisited if needed in the future. 
- Remove tenantMcpPolicy from mcpSettings Zod schema - Remove validateMcpServerPolicy helper and TenantMcpPolicy interface - Remove policy enforcement from upsertConfigOverrides and patchConfigField handlers * test: update test assertions for source field and config-server wiring - Use objectContaining in MCPServersRegistry reset test to account for new source: 'yaml' field on CACHE-stored configs - Add getTenantId and ensureConfigServers mocks to MCP route tests - Add getAppConfig mock to route test Config service mock - Update getMCPSetupData assertion to expect second options argument - Update getAllServerConfigs assertions for new configServers parameter * fix: disconnect active connections when config-source servers are evicted When admin config overrides change and config-source MCP servers are removed, the invalidation now proactively disconnects active connections for evicted servers instead of leaving them lingering until timeout. - Return evicted server names from invalidateConfigCache() - Disconnect app-level connections for evicted servers in clearMcpConfigCache() via MCPManager.appConnections.disconnect() * fix: address code review findings (CRITICAL, MAJOR, MINOR) CRITICAL fixes: - Scope configCacheRepo keys by config content hash to prevent cross-tenant cache poisoning when two tenants define the same server name with different configurations - Change dbSourced checks from `source === 'user'` to `source !== 'yaml' && source !== 'config'` so undefined source (pre-upgrade cached configs) fails closed to restricted mode MAJOR fixes: - Derive OAuth servers from already-computed mcpConfig instead of calling getOAuthServers() separately โ€” config-source OAuth servers are now properly detected - Add parseInt radix (10) and NaN guard with fallback to 30_000 for CONFIG_SERVER_INIT_TIMEOUT_MS - Add CONFIG_CACHE_NAMESPACE to aggregate-key branch in ServerConfigsCacheFactory to avoid SCAN-based Redis stalls - Remove `if (role || tenantId)` guard in 
getMCPSetupData โ€” config servers now always resolve regardless of tenant context MINOR fixes: - Extract resolveAllMcpConfigs() helper in mcp controller to eliminate 3x copy-pasted config resolution boilerplate - Distinguish "not initialized" from real errors in clearMcpConfigCache โ€” log actual failures instead of swallowing - Remove narrative inline comments per style guide - Remove dead try/catch inside Promise.allSettled in ensureConfigServers (inner method never throws) - Memoize YAML server names to avoid repeated cacheConfigsRepo.getAll() calls per request Test updates: - Add ensureConfigServers mock to registry test fixtures - Update getMCPSetupData assertions for inline OAuth derivation * fix: address code review findings (CRITICAL, MAJOR, MINOR) CRITICAL fixes: - Break circular dependency: move CONFIG_CACHE_NAMESPACE from MCPServersRegistry to ServerConfigsCacheFactory - Fix dbSourced fail-closed: use source field when present, fall back to legacy dbId check when absent (backward-compatible with pre-upgrade cached configs that lack source field) MAJOR fixes: - Add CONFIG_CACHE_NAMESPACE to aggregate-key set in ServerConfigsCacheFactory to avoid SCAN-based Redis stalls - Add comprehensive test suite (ensureConfigServers.test.ts, 18 tests) covering lazy init, stub-on-failure, cross-tenant isolation via config hash keys, concurrent deduplication, merge order, and cache invalidation MINOR fixes: - Update MCPServerInspector test assertion for dbSourced change * fix: restore getServerConfig lookup for config-source servers (NEW-1) Add configNameToKey map that indexes server name โ†’ hash-based cache key for O(1) lookup by name in getServerConfig. This restores the config cache layer that was dropped when hash-based keys were introduced. Without this fix, config-source servers appeared in tool listings (via getAllServerConfigs) but getServerConfig returned undefined, breaking all connection and tool call paths. 
- Populate configNameToKey in ensureSingleConfigServer - Clear configNameToKey in invalidateConfigCache and reset - Clear stale read-through cache entries after lazy init - Remove dead code in invalidateConfigCache (config.title, key parsing) - Add getServerConfig tests for config-source server lookup * fix: eliminate configNameToKey race via caller-provided configServers param Replace the process-global configNameToKey map (last-writer-wins under concurrent multi-tenant load) with a configServers parameter on getServerConfig. Callers pass the pre-resolved config servers map directly โ€” no shared mutable state, no cross-tenant race. - Add optional configServers param to getServerConfig; when provided, returns matching config directly without any global lookup - Remove configNameToKey map entirely (was the source of the race) - Extract server names from cache keys via lastIndexOf in invalidateConfigCache (safe for names containing colons) - Use mcpConfig[serverName] directly in getMCPTools instead of a redundant getServerConfig call - Add cross-tenant isolation test for getServerConfig * fix: populate read-through cache after config server lazy init After lazyInitConfigServer succeeds, write the parsed config to readThroughCache keyed by serverName so that getServerConfig calls from ConnectionsRepository, UserConnectionManager, and MCPManager.callTool find the config without needing configServers. Without this, config-source servers appeared in tool listings but every connection attempt and tool call returned undefined. * fix: user-scoped getServerConfig fallback to server-only cache key When getServerConfig is called with a userId (e.g., from callTool or UserConnectionManager), the cache key is serverName::userId. Config-source servers are cached under the server-only key (no userId). Add a fallback so user-scoped lookups find config-source servers in the read-through cache. 
* fix: configCacheRepo fallback, isUserSourced DRY, cross-process race CRITICAL: Add findInConfigCache fallback in getServerConfig so config-source servers remain reachable after readThroughCache TTL expires (5s). Without this, every tool call after 5s returned undefined for config-source servers. MAJOR: Extract isUserSourced() helper to mcp/utils.ts and replace all 5 inline dbSourced ternary expressions (MCPManager x2, ConnectionsRepository, UserConnectionManager, MCPServerInspector). MAJOR: Fix cross-process Redis race in lazyInitConfigServer โ€” when configCacheRepo.add throws (key exists from another process), fall back to reading the existing entry instead of returning undefined. MINOR: Parallelize invalidateConfigCache awaits with Promise.all. Remove redundant .catch(() => {}) inside Promise.allSettled. Tighten dedup test assertion to toBe(1). Add TTL-expiry tests for getServerConfig (with and without userId). * feat: thread configServers through getAppToolFunctions and formatInstructionsForContext Add optional configServers parameter to getAppToolFunctions, getInstructions, and formatInstructionsForContext so config-source server tools and instructions are visible to agent initialization and context injection paths. Existing callers (boot-time init, tests) pass no argument and continue to work unchanged. Agent runtime paths can now thread resolved config servers from request context. 
* fix: stale failure stubs retry after 5 min, upsert for cross-process races - Add CONFIG_STUB_RETRY_MS (5 min) — stale failure stubs are retried instead of permanently disabling config-source servers after transient errors (DNS outage, cold-start race) - Extract upsertConfigCache() helper that tries add then falls back to update, preventing cross-process Redis races where a second instance's successful inspection result was discarded - Add test for stale-stub retry after CONFIG_STUB_RETRY_MS * fix: stamp updatedAt on failure stubs, null-guard callTool config, test cleanup - Add updatedAt: Date.now() to failure stubs in lazyInitConfigServer so CONFIG_STUB_RETRY_MS (5 min) window works correctly — without it, stubs were always considered stale (updatedAt ?? 0 → epoch → always expired) - Add null guard for rawConfig in MCPManager.callTool before passing to preProcessGraphTokens — prevents unsafe `as` cast on undefined - Log double-failure in upsertConfigCache instead of silently swallowing - Replace module-scope Date.now monkey-patch with jest.useFakeTimers / jest.setSystemTime / jest.useRealTimers in ensureConfigServers tests * fix: server-only readThrough fallback only returns truthy values Prevents a cached undefined from a prior no-userId lookup from short-circuiting the DB query on a subsequent userId-scoped lookup. * fix: remove findInConfigCache to eliminate cross-tenant config leakage The findInConfigCache prefix scan (serverName:*) could return any tenant's config after readThrough TTL expires, violating tenant isolation. Config-source servers are now ONLY resolvable through: 1. The configServers param (callers with tenant context from ALS) 2. The readThrough cache (populated by ensureSingleConfigServer, 5s TTL, repopulated on every HTTP request via resolveAllMcpConfigs) Connection/tool-call paths without tenant context rely exclusively on the readThrough cache. 
If it expires before the next HTTP request repopulates it, the server is not found — which is correct because there is no tenant context to determine which config to return. - Remove findInConfigCache method and its call in getServerConfig - Update server-only readThrough fallback to only return truthy values (prevents cached undefined from short-circuiting user-scoped DB lookup) - Update tests to document tenant isolation behavior after cache expiry * style: fix import order per AGENTS.md conventions Sort package imports shortest-to-longest, local imports longest-to-shortest across MCPServersRegistry, ConnectionsRepository, MCPManager, UserConnectionManager, and MCPServerInspector. * fix: eliminate cross-tenant readThrough contamination and TTL-expiry tool failures Thread pre-resolved serverConfig from tool creation context into callTool, removing dependency on the readThrough cache for config-source servers. This fixes two issues: - Cross-tenant contamination: the readThrough cache key was unscoped (just serverName), so concurrent multi-tenant requests for same-named servers would overwrite each other's entries - TTL expiry: tool calls happening >5s after config resolution would fail with "Configuration not found" because the readThrough entry had expired Changes: - Add optional serverConfig param to MCPManager.callTool — uses provided config directly, falling back to getServerConfig lookup for YAML/user servers - Thread serverConfig from createMCPTool through createToolInstance closure to callTool - Remove readThrough write from ensureSingleConfigServer — config-source servers are only accessible via configServers param (tenant-scoped) - Remove server-only readThrough fallback from getServerConfig - Increase config cache hash from 8 to 16 hex chars (64-bit) - Add isUserSourced boundary tests for all source/dbId combinations - Fix double Object.keys call in getMCPTools controller - Update test assertions for new getServerConfig behavior * fix: cache base 
configs for config-server users; narrow upsertConfigCache error handling - Refactor getAllServerConfigs to separate base config fetch (YAML + DB) from config-server layering. Base configs are cached via readThroughCacheAll regardless of whether configServers is provided, eliminating uncached MongoDB queries per request for config-server users - Narrow upsertConfigCache catch to duplicate-key errors only; infrastructure errors (Redis timeouts, network failures) now propagate instead of being silently swallowed, preventing inspection storms during outages * fix: restore correct merge order and document upsert error matching - Restore YAML → Config → User DB precedence in getAllServerConfigs (user DB servers have highest precedence, matching the JSDoc contract) - Add source comment on upsertConfigCache duplicate-key detection linking to the two cache implementations that define the error message * feat: complete config-source server support across all execution paths Wire configServers through the entire agent execution pipeline so config-source MCP servers are fully functional — not just visible in listings but executable in agent sessions. 
- Thread configServers into handleTools.js agent tool pipeline: resolve config servers from tenant context before MCP tool iteration, pass to getServerConfig, createMCPTools, and createMCPTool - Thread configServers into agent instructions pipeline: applyContextToAgent → getMCPInstructionsForServers → formatInstructionsForContext, resolved in client.js before agent context application - Add configServers param to createMCPTool and createMCPTools for reconnect path fallback - Add source field to redactServerSecrets allowlist for client UI differentiation of server tiers - Narrow invalidateConfigCache to only clear readThroughCacheAll (merged results), preserving YAML individual-server readThrough entries - Update context.spec.ts assertions for new configServers parameter * fix: add missing mocks for config-source server dependencies in client.test.js Mock getMCPServersRegistry, getAppConfig, and getTenantId that were added to client.js but not reflected in the test file's jest.mock declarations. * fix: update formatInstructionsForContext assertions for configServers param The test assertions expected formatInstructionsForContext to be called with only the server names array, but it now receives configServers as a second argument after the config-source server feature wiring. * fix: move configServers resolution before MCP tool loop to avoid TDZ configServers was declared with `let` after the first tool loop but referenced inside it via getServerConfig(), causing a ReferenceError temporal dead zone. Move declaration and resolution before the loop, using tools.some(mcpToolPattern) to gate the async resolution. * fix: address review findings — cache bypass, discoverServerTools gap, DRY - #2: getAllServerConfigs now always uses getBaseServerConfigs (cached via readThroughCacheAll) instead of bypassing it when configServers is present. 
Extracts user-DB entries from cached base by diffing against YAML keys to maintain YAML → Config → User DB merge order without extra MongoDB calls. - #3: Add configServers param to ToolDiscoveryOptions and thread it through discoverServerTools → getServerConfig so config-source servers are discoverable during OAuth reconnection flows. - #6: Replace inline import() type annotations in context.ts with proper import type { ParsedServerConfig } per AGENTS.md conventions. - #7: Extract resolveConfigServers(req) helper in MCP.js and use it from handleTools.js and client.js, eliminating the duplicated 6-line config resolution pattern. - #10: Restore removed "why" comment explaining getLoaded() vs getAll() choice in getMCPSetupData — documents non-obvious correctness constraint. - #11: Fix incomplete JSDoc param type on resolveAllMcpConfigs. * fix: consolidate imports, reorder constants, fix YAML-DB merge edge case - Merge duplicate @librechat/data-schemas requires in MCP.js into one - Move resolveConfigServers after module-level constants - Fix getAllServerConfigs edge case where user-DB entry overriding a YAML entry with the same name was excluded from userDbConfigs; now uses reference equality check to detect DB-overwritten YAML keys * fix: replace fragile string-match error detection with proper upsert method Add upsert() to IServerConfigsRepositoryInterface and all implementations (InMemory, Redis, RedisAggregateKey, DB). This eliminates the brittle error message string match ('already exists in cache') in upsertConfigCache that was the only thing preventing cross-process init races from silently discarding inspection results. 
Each implementation handles add-or-update atomically: - InMemory: direct Map.set() - Redis: direct cache.set() - RedisAggregateKey: read-modify-write under write lock - DB: delegates to update() (DB servers use explicit add() with ACL setup) * fix: wire configServers through remaining HTTP endpoints - getMCPServerById: use resolveAllMcpConfigs instead of bare getServerConfig - reinitialize route: resolve configServers before getServerConfig - auth-values route: resolve configServers before getServerConfig - getOAuthHeaders: accept configServers param, thread from callers - Update mcp.spec.js tests to mock getAllServerConfigs for GET by name * fix: thread serverConfig through getConnection for config-source servers Config-source servers exist only in configCacheRepo, not in YAML cache or DB. When callTool → getConnection → getUserConnection → getServerConfig runs without configServers, it returns undefined and throws. Fix by threading the pre-resolved serverConfig (providedConfig) from callTool through getConnection → getUserConnection → createUserConnectionInternal, using it as a fallback before the registry lookup. 
* fix: thread configServers through reinit, reconnect, and tool definition paths Wire configServers through every remaining call chain that creates or reconnects MCP server connections: - reinitMCPServer: accepts serverConfig and configServers, uses them for getServerConfig fallback, getConnection, and discoverServerTools - reconnectServer: accepts and passes configServers to reinitMCPServer - createMCPTools/createMCPTool: pass configServers to reconnectServer - ToolService.loadToolDefinitionsWrapper: resolves configServers from req, passes to both reinitMCPServer call sites - reinitialize route: passes serverConfig and configServers to reinitMCPServer * fix: address review findings — simplify merge, harden error paths, fix log labels - Simplify getAllServerConfigs merge: replace fragile reference-equality loop with direct spread { ...yamlConfigs, ...configServers, ...base } - Guard upsertConfigCache in lazyInitConfigServer catch block so cache failures don't mask the original inspection error - Deduplicate getYamlServerNames cold-start with promise dedup pattern - Remove dead `if (!mcpConfig)` guard in getMCPSetupData - Fix hardcoded "App server" in ServerConfigsCacheRedisAggregateKey error messages — now uses this.namespace for correct Config/App labeling - Remove misleading OAuth callback comment about readThrough cache - Move resolveConfigServers after module-level constants in MCP.js * fix: clear rejected yamlServerNames promise, fix config-source reinspect, fix reset log label - Clear yamlServerNamesPromise on rejection so transient cache errors don't permanently prevent ensureConfigServers from working - Skip reinspectServer for config-source servers (source: 'config') in reinitMCPServer — they lack a CACHE/DB storage location; retry is handled by CONFIG_STUB_RETRY_MS in ensureConfigServers - Use source field instead of dbId for storageLocation derivation - Fix remaining hardcoded "App" in reset() leaderCheck message * fix: persist oauthHeaders in flow 
state for config-source OAuth servers The OAuth callback route has no JWT auth context and cannot resolve config-source server configs. Previously, getOAuthHeaders would silently return {} for config-source servers, dropping custom token exchange headers. Now oauthHeaders are persisted in MCPOAuthFlowMetadata during flow initiation (which has auth context), and the callback reads them from the stored flow state with a fallback to the registry lookup for YAML/user-DB servers. * fix: update tests for getMCPSetupData null guard removal and ToolService mock - MCP.spec.js: update test to expect graceful handling of null mcpConfig instead of a throw (getAllServerConfigs always returns an object) - MCP.js: add defensive || {} for Object.entries(mcpConfig) in case of null from test mocks - ToolService.spec.js: add missing mock for ~/server/services/MCP (resolveConfigServers) * fix: address review findings — DRY, naming, logging, dead code, defensive guards - #1: Simplify getAllServerConfigs to single getBaseServerConfigs call, eliminating redundant double-fetch of cacheConfigsRepo.getAll() - #2: Add warning log when oauthHeaders absent from OAuth callback flow state - #3: Extract resolveAllMcpConfigs to MCP.js service layer; controller imports shared helper instead of reimplementing - #4: Rename _serverConfig/_provider to capturedServerConfig/capturedProvider in createToolInstance — these are actively used, not unused - #5: Log rejected results from ensureConfigServers Promise.allSettled so cache errors are visible instead of silently dropped - #6: Remove dead 'MCP config not found' error handlers from routes - #7: Document circular-dependency reason for dynamic require in clearMcpConfigCache - #8: Remove logger.error from withTimeout to prevent double-logging timeouts - #10: Add explicit userId guard in ServerConfigsDB.upsert with clear error message - #12: Use spread instead of mutation in addServer for immutability consistency - Add upsert mock to 
ensureConfigServers.test.ts DB mock - Update route tests for resolveAllMcpConfigs import change * fix: restore correct merge priority, use immutable spread, fix test mock - getAllServerConfigs: { ...configServers, ...base } so userDB wins over configServers, matching documented "User DB (highest)" priority - lazyInitConfigServer: use immutable spread instead of direct mutation for parsedConfig.source, consistent with addServer fix - Fix test to mock getAllServerConfigs as {} instead of null, remove unnecessary || {} defensive guard in getMCPSetupData * fix: error handling, stable hashing, flatten nesting, remove dead param - Wrap resolveConfigServers/resolveAllMcpConfigs in try/catch with graceful {} fallback so transient DB/cache errors don't crash tool pipeline - Sort keys in configCacheKey JSON.stringify for deterministic hashing regardless of object property insertion order - Flatten clearMcpConfigCache from 3 nested try-catch to early returns; document that user connections are cleaned up lazily (accepted tradeoff) - Remove dead configServers param from getAppToolFunctions (never passed) - Add security rationale comment for source field in redactServerSecrets * fix: use recursive key-sorting replacer in configCacheKey to prevent cross-tenant cache collision The array replacer in JSON.stringify acts as a property allowlist at every nesting depth, silently dropping nested keys like headers['X-API-Key'], oauth.client_secret, etc. Two configs with different nested values but identical top-level structure produced the same hash, causing cross-tenant cache hits and potential credential contamination. Switch to a function replacer that recursively sorts keys at all depths without dropping any properties. Also document the known gap in getOAuthServers: config-source OAuth servers are not covered by auto-reconnection or uninstall cleanup because callers lack request context. 
* fix: move clearMcpConfigCache to packages/api to eliminate circular dependency The function only depends on MCPServersRegistry and MCPManager, both of which live in packages/api. Import it directly from @librechat/api in the CJS layer instead of using dynamic require('~/config'). * chore: imports/fields ordering * fix: address review findings — error handling, targeted lookup, test gaps - Narrow resolveAllMcpConfigs catch to only wrap ensureConfigServers so getAppConfig/getAllServerConfigs failures propagate instead of masking infrastructure errors as empty server lists. - Use targeted getServerConfig in getMCPServerById instead of fetching all server configs for a single-server lookup. - Forward configServers to inner createMCPTool calls so reconnect path works for config-source servers. - Update getAllServerConfigs JSDoc to document disjoint-key design. - Add OAuth callback oauthHeaders fallback tests (flow state present vs registry fallback). - Add resolveConfigServers/resolveAllMcpConfigs unit tests covering happy path and error propagation. 
* fix: add getOAuthReconnectionManager mock to OAuth callback tests * chore: imports ordering --- api/app/clients/tools/util/handleTools.js | 16 +- api/server/controllers/agents/client.js | 5 + api/server/controllers/agents/client.test.js | 14 +- api/server/controllers/mcp.js | 23 +- api/server/routes/__tests__/mcp.spec.js | 146 ++++++-- api/server/routes/mcp.js | 52 ++- .../__tests__/invalidateConfigCaches.spec.js | 2 + api/server/services/Config/app.js | 6 +- api/server/services/MCP.js | 101 ++++-- api/server/services/MCP.spec.js | 34 +- api/server/services/ToolService.js | 4 + api/server/services/Tools/mcp.js | 23 +- api/server/services/__tests__/MCP.spec.js | 131 +++++++ .../services/__tests__/ToolService.spec.js | 3 + packages/api/src/agents/context.spec.ts | 20 +- packages/api/src/agents/context.ts | 18 +- packages/api/src/index.ts | 1 + packages/api/src/mcp/ConnectionsRepository.ts | 4 +- packages/api/src/mcp/MCPManager.ts | 40 ++- packages/api/src/mcp/UserConnectionManager.ts | 11 +- packages/api/src/mcp/__tests__/utils.test.ts | 27 ++ packages/api/src/mcp/cache.ts | 43 +++ packages/api/src/mcp/oauth/handler.ts | 2 + packages/api/src/mcp/oauth/types.ts | 2 + .../src/mcp/registry/MCPServerInspector.ts | 4 +- .../src/mcp/registry/MCPServersRegistry.ts | 323 +++++++++++++++-- .../ServerConfigsRepositoryInterface.ts | 3 + .../__tests__/MCPServerInspector.test.ts | 12 +- .../__tests__/MCPServersRegistry.test.ts | 12 +- .../__tests__/ensureConfigServers.test.ts | 328 ++++++++++++++++++ .../cache/ServerConfigsCacheFactory.ts | 14 +- .../cache/ServerConfigsCacheInMemory.ts | 4 + .../registry/cache/ServerConfigsCacheRedis.ts | 6 + .../ServerConfigsCacheRedisAggregateKey.ts | 22 +- .../src/mcp/registry/db/ServerConfigsDB.ts | 20 ++ packages/api/src/mcp/types/index.ts | 12 + packages/api/src/mcp/utils.ts | 11 + 37 files changed, 1337 insertions(+), 162 deletions(-) create mode 100644 api/server/services/__tests__/MCP.spec.js create mode 100644 
packages/api/src/mcp/cache.ts create mode 100644 packages/api/src/mcp/registry/__tests__/ensureConfigServers.test.ts diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index 4b86101425..8adb43f945 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -14,7 +14,6 @@ const { buildImageToolContext, buildWebSearchContext, } = require('@librechat/api'); -const { getMCPServersRegistry } = require('~/config'); const { Tools, Constants, @@ -39,12 +38,13 @@ const { createGeminiImageTool, createOpenAIImageTools, } = require('../'); -const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); +const { createMCPTool, createMCPTools, resolveConfigServers } = require('~/server/services/MCP'); const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch'); +const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); const { getUserPluginAuthValue } = require('~/server/services/PluginService'); -const { createMCPTool, createMCPTools } = require('~/server/services/MCP'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { getMCPServerTools } = require('~/server/services/Config'); +const { getMCPServersRegistry } = require('~/config'); const { getRoleByName } = require('~/models'); /** @@ -256,6 +256,12 @@ const loadTools = async ({ const toolContextMap = {}; const requestedMCPTools = {}; + /** Resolve config-source servers for the current user/tenant context */ + let configServers; + if (tools.some((tool) => tool && mcpToolPattern.test(tool))) { + configServers = await resolveConfigServers(options.req); + } + for (const tool of tools) { if (tool === Tools.execute_code) { requestedTools[tool] = async () => { @@ -341,7 +347,7 @@ const loadTools = async ({ continue; } const serverConfig = serverName - ? await getMCPServersRegistry().getServerConfig(serverName, user) + ? 
await getMCPServersRegistry().getServerConfig(serverName, user, configServers) : null; if (!serverConfig) { logger.warn( @@ -419,6 +425,7 @@ const loadTools = async ({ let index = -1; const failedMCPServers = new Set(); const safeUser = createSafeUser(options.req?.user); + for (const [serverName, toolConfigs] of Object.entries(requestedMCPTools)) { index++; /** @type {LCAvailableTools} */ @@ -433,6 +440,7 @@ const loadTools = async ({ signal, user: safeUser, userMCPAuthMap, + configServers, res: options.res, streamId: options.req?._resumableStreamId || null, model: agent?.model ?? model, diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 47a10165e3..d6795a4be9 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -50,6 +50,7 @@ const { const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const { createContextHandlers } = require('~/app/clients/prompts'); +const { resolveConfigServers } = require('~/server/services/MCP'); const { getMCPServerTools } = require('~/server/services/Config'); const BaseClient = require('~/app/clients/BaseClient'); const { getMCPManager } = require('~/config'); @@ -377,6 +378,9 @@ class AgentClient extends BaseClient { */ const ephemeralAgent = this.options.req.body.ephemeralAgent; const mcpManager = getMCPManager(); + + const configServers = await resolveConfigServers(this.options.req); + await Promise.all( allAgents.map(({ agent, agentId }) => applyContextToAgent({ @@ -384,6 +388,7 @@ class AgentClient extends BaseClient { agentId, logger, mcpManager, + configServers, sharedRunContext, ephemeralAgent: agentId === this.options.agent.id ? 
ephemeralAgent : undefined, }), diff --git a/api/server/controllers/agents/client.test.js b/api/server/controllers/agents/client.test.js index 41a806f66d..1595f652f7 100644 --- a/api/server/controllers/agents/client.test.js +++ b/api/server/controllers/agents/client.test.js @@ -22,6 +22,10 @@ jest.mock('~/server/services/Config', () => ({ getMCPServerTools: jest.fn(), })); +jest.mock('~/server/services/MCP', () => ({ + resolveConfigServers: jest.fn().mockResolvedValue({}), +})); + jest.mock('~/models', () => ({ getAgent: jest.fn(), getRoleByName: jest.fn(), @@ -1315,7 +1319,7 @@ describe('AgentClient - titleConvo', () => { }); // Verify formatInstructionsForContext was called with correct server names - expect(mockFormatInstructions).toHaveBeenCalledWith(['server1', 'server2']); + expect(mockFormatInstructions).toHaveBeenCalledWith(['server1', 'server2'], {}); // Verify the instructions do NOT contain [object Promise] expect(client.options.agent.instructions).not.toContain('[object Promise]'); @@ -1355,10 +1359,10 @@ describe('AgentClient - titleConvo', () => { }); // Verify formatInstructionsForContext was called with ephemeral server names - expect(mockFormatInstructions).toHaveBeenCalledWith([ - 'ephemeral-server1', - 'ephemeral-server2', - ]); + expect(mockFormatInstructions).toHaveBeenCalledWith( + ['ephemeral-server1', 'ephemeral-server2'], + {}, + ); // Verify no [object Promise] in instructions expect(client.options.agent.instructions).not.toContain('[object Promise]'); diff --git a/api/server/controllers/mcp.js b/api/server/controllers/mcp.js index 729f01da9d..e31bb93bc6 100644 --- a/api/server/controllers/mcp.js +++ b/api/server/controllers/mcp.js @@ -14,6 +14,7 @@ const { isMCPInspectionFailedError, } = require('@librechat/api'); const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider'); +const { resolveConfigServers, resolveAllMcpConfigs } = require('~/server/services/MCP'); const { cacheMCPServerTools, getMCPServerTools } = 
require('~/server/services/Config'); const { getMCPManager, getMCPServersRegistry } = require('~/config'); @@ -57,7 +58,7 @@ function handleMCPError(error, res) { } /** - * Get all MCP tools available to the user + * Get all MCP tools available to the user. */ const getMCPTools = async (req, res) => { try { @@ -67,10 +68,10 @@ const getMCPTools = async (req, res) => { return res.status(401).json({ message: 'Unauthorized' }); } - const mcpConfig = await getMCPServersRegistry().getAllServerConfigs(userId); - const configuredServers = mcpConfig ? Object.keys(mcpConfig) : []; + const mcpConfig = await resolveAllMcpConfigs(userId, req.user); + const configuredServers = Object.keys(mcpConfig); - if (!mcpConfig || Object.keys(mcpConfig).length == 0) { + if (!configuredServers.length) { return res.status(200).json({ servers: {} }); } @@ -115,14 +116,11 @@ const getMCPTools = async (req, res) => { try { const serverTools = serverToolsMap.get(serverName); - // Get server config once const serverConfig = mcpConfig[serverName]; - const rawServerConfig = await getMCPServersRegistry().getServerConfig(serverName, userId); - // Initialize server object with all server-level data const server = { name: serverName, - icon: rawServerConfig?.iconPath || '', + icon: serverConfig?.iconPath || '', authenticated: true, authConfig: [], tools: [], @@ -183,7 +181,7 @@ const getMCPServersList = async (req, res) => { return res.status(401).json({ message: 'Unauthorized' }); } - const serverConfigs = await getMCPServersRegistry().getAllServerConfigs(userId); + const serverConfigs = await resolveAllMcpConfigs(userId, req.user); return res.json(redactAllServerSecrets(serverConfigs)); } catch (error) { logger.error('[getMCPServersList]', error); @@ -237,7 +235,12 @@ const getMCPServerById = async (req, res) => { if (!serverName) { return res.status(400).json({ message: 'Server name is required' }); } - const parsedConfig = await getMCPServersRegistry().getServerConfig(serverName, userId); + const 
configServers = await resolveConfigServers(req); + const parsedConfig = await getMCPServersRegistry().getServerConfig( + serverName, + userId, + configServers, + ); if (!parsedConfig) { return res.status(404).json({ message: 'MCP server not found' }); diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index 1ad8cac087..f194f361d3 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -18,6 +18,7 @@ const mockRegistryInstance = { getServerConfig: jest.fn(), getOAuthServers: jest.fn(), getAllServerConfigs: jest.fn(), + ensureConfigServers: jest.fn().mockResolvedValue({}), addServer: jest.fn(), updateServer: jest.fn(), removeServer: jest.fn(), @@ -58,6 +59,7 @@ jest.mock('@librechat/api', () => { }); jest.mock('@librechat/data-schemas', () => ({ + getTenantId: jest.fn(), logger: { debug: jest.fn(), info: jest.fn(), @@ -93,14 +95,18 @@ jest.mock('~/server/services/Config', () => ({ getCachedTools: jest.fn(), getMCPServerTools: jest.fn(), loadCustomConfig: jest.fn(), + getAppConfig: jest.fn().mockResolvedValue({ mcpConfig: {} }), })); jest.mock('~/server/services/Config/mcp', () => ({ updateMCPServerTools: jest.fn(), })); +const mockResolveAllMcpConfigs = jest.fn().mockResolvedValue({}); jest.mock('~/server/services/MCP', () => ({ getMCPSetupData: jest.fn(), + resolveConfigServers: jest.fn().mockResolvedValue({}), + resolveAllMcpConfigs: (...args) => mockResolveAllMcpConfigs(...args), getServerConnectionStatus: jest.fn(), })); @@ -579,6 +585,112 @@ describe('MCP Routes', () => { ); }); + it('should use oauthHeaders from flow state when present', async () => { + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue({ status: 'PENDING' }), + completeFlow: jest.fn().mockResolvedValue(), + deleteFlow: jest.fn().mockResolvedValue(true), + }; + const mockFlowState = { + serverName: 'test-server', + userId: 'test-user-id', + metadata: { toolFlowId: 'tool-flow-123' }, + 
clientInfo: {}, + codeVerifier: 'test-verifier', + oauthHeaders: { 'X-Custom-Auth': 'header-value' }, + }; + const mockTokens = { access_token: 'tok', refresh_token: 'ref' }; + + MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); + MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); + MCPTokenStorage.storeTokens.mockResolvedValue(); + getLogStores.mockReturnValue({}); + require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + require('~/config').getOAuthReconnectionManager.mockReturnValue({ + clearReconnection: jest.fn(), + }); + require('~/config').getMCPManager.mockReturnValue({ + getUserConnection: jest.fn().mockResolvedValue({ + fetchTools: jest.fn().mockResolvedValue([]), + }), + }); + const { getCachedTools, setCachedTools } = require('~/server/services/Config'); + getCachedTools.mockResolvedValue({}); + setCachedTools.mockResolvedValue(); + + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ code: 'auth-code', state: flowId }); + + expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith( + flowId, + 'auth-code', + mockFlowManager, + { 'X-Custom-Auth': 'header-value' }, + ); + expect(mockRegistryInstance.getServerConfig).not.toHaveBeenCalled(); + }); + + it('should fall back to registry oauth_headers when flow state lacks them', async () => { + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue({ status: 'PENDING' }), + completeFlow: jest.fn().mockResolvedValue(), + deleteFlow: jest.fn().mockResolvedValue(true), + }; + const mockFlowState = { + serverName: 'test-server', + userId: 'test-user-id', + metadata: { toolFlowId: 'tool-flow-123' }, + clientInfo: {}, + codeVerifier: 'test-verifier', + }; + const mockTokens = { access_token: 'tok', refresh_token: 'ref' }; + + 
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); + MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); + MCPTokenStorage.storeTokens.mockResolvedValue(); + mockRegistryInstance.getServerConfig.mockResolvedValue({ + oauth_headers: { 'X-Registry-Header': 'from-registry' }, + }); + getLogStores.mockReturnValue({}); + require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + require('~/config').getOAuthReconnectionManager.mockReturnValue({ + clearReconnection: jest.fn(), + }); + require('~/config').getMCPManager.mockReturnValue({ + getUserConnection: jest.fn().mockResolvedValue({ + fetchTools: jest.fn().mockResolvedValue([]), + }), + }); + const { getCachedTools, setCachedTools } = require('~/server/services/Config'); + getCachedTools.mockResolvedValue({}); + setCachedTools.mockResolvedValue(); + + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ code: 'auth-code', state: flowId }); + + expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith( + flowId, + 'auth-code', + mockFlowManager, + { 'X-Registry-Header': 'from-registry' }, + ); + expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith( + 'test-server', + 'test-user-id', + undefined, + ); + }); + it('should redirect to error page when callback processing fails', async () => { MCPOAuthHandler.getFlowState.mockRejectedValue(new Error('Callback error')); const flowId = 'test-user-id:test-server'; @@ -1350,19 +1462,10 @@ describe('MCP Routes', () => { }, }); - expect(getMCPSetupData).toHaveBeenCalledWith('test-user-id'); + expect(getMCPSetupData).toHaveBeenCalledWith('test-user-id', expect.any(Object)); expect(getServerConnectionStatus).toHaveBeenCalledTimes(2); }); - it('should return 404 when MCP config is not found', async () => { - getMCPSetupData.mockRejectedValue(new Error('MCP 
config not found')); - - const response = await request(app).get('/api/mcp/connection/status'); - - expect(response.status).toBe(404); - expect(response.body).toEqual({ error: 'MCP config not found' }); - }); - it('should return 500 when connection status check fails', async () => { getMCPSetupData.mockRejectedValue(new Error('Database error')); @@ -1437,15 +1540,6 @@ describe('MCP Routes', () => { }); }); - it('should return 404 when MCP config is not found', async () => { - getMCPSetupData.mockRejectedValue(new Error('MCP config not found')); - - const response = await request(app).get('/api/mcp/connection/status/test-server'); - - expect(response.status).toBe(404); - expect(response.body).toEqual({ error: 'MCP config not found' }); - }); - it('should return 500 when connection status check fails', async () => { getMCPSetupData.mockRejectedValue(new Error('Database connection failed')); @@ -1704,7 +1798,7 @@ describe('MCP Routes', () => { }, }; - mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockServerConfigs); + mockResolveAllMcpConfigs.mockResolvedValue(mockServerConfigs); const response = await request(app).get('/api/mcp/servers'); @@ -1721,11 +1815,14 @@ describe('MCP Routes', () => { }); expect(response.body['server-1'].headers).toBeUndefined(); expect(response.body['server-2'].headers).toBeUndefined(); - expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith('test-user-id'); + expect(mockResolveAllMcpConfigs).toHaveBeenCalledWith( + 'test-user-id', + expect.objectContaining({ id: 'test-user-id' }), + ); }); it('should return empty object when no servers are configured', async () => { - mockRegistryInstance.getAllServerConfigs.mockResolvedValue({}); + mockResolveAllMcpConfigs.mockResolvedValue({}); const response = await request(app).get('/api/mcp/servers'); @@ -1749,7 +1846,7 @@ describe('MCP Routes', () => { }); it('should return 500 when server config retrieval fails', async () => { - 
mockRegistryInstance.getAllServerConfigs.mockRejectedValue(new Error('Database error')); + mockResolveAllMcpConfigs.mockRejectedValue(new Error('Database error')); const response = await request(app).get('/api/mcp/servers'); @@ -1939,11 +2036,12 @@ describe('MCP Routes', () => { expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith( 'test-server', 'test-user-id', + {}, ); }); it('should return 404 when server not found', async () => { - mockRegistryInstance.getServerConfig.mockResolvedValue(null); + mockRegistryInstance.getServerConfig.mockResolvedValue(undefined); const response = await request(app).get('/api/mcp/servers/non-existent-server'); diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index d6d7ed5ea0..c6496ad4b4 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -1,5 +1,5 @@ const { Router } = require('express'); -const { logger } = require('@librechat/data-schemas'); +const { logger, getTenantId } = require('@librechat/data-schemas'); const { CacheKeys, Constants, @@ -36,7 +36,11 @@ const { getFlowStateManager, getMCPManager, } = require('~/config'); -const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP'); +const { + getServerConnectionStatus, + resolveConfigServers, + getMCPSetupData, +} = require('~/server/services/MCP'); const { requireJwtAuth, canAccessMCPServerResource } = require('~/server/middleware'); const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { updateMCPServerTools } = require('~/server/services/Config/mcp'); @@ -101,7 +105,8 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async return res.status(400).json({ error: 'Invalid flow state' }); } - const oauthHeaders = await getOAuthHeaders(serverName, userId); + const configServers = await resolveConfigServers(req); + const oauthHeaders = await getOAuthHeaders(serverName, userId, configServers); const { authorizationUrl, flowId: oauthFlowId, @@ 
-233,7 +238,14 @@ router.get('/:serverName/oauth/callback', async (req, res) => { } logger.debug('[MCP OAuth] Completing OAuth flow'); - const oauthHeaders = await getOAuthHeaders(serverName, flowState.userId); + if (!flowState.oauthHeaders) { + logger.warn( + '[MCP OAuth] oauthHeaders absent from flow state โ€” config-source server oauth_headers will be empty', + { serverName, flowId }, + ); + } + const oauthHeaders = + flowState.oauthHeaders ?? (await getOAuthHeaders(serverName, flowState.userId)); const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager, oauthHeaders); logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route'); @@ -497,7 +509,12 @@ router.post( logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`); const mcpManager = getMCPManager(); - const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id); + const configServers = await resolveConfigServers(req); + const serverConfig = await getMCPServersRegistry().getServerConfig( + serverName, + user.id, + configServers, + ); if (!serverConfig) { return res.status(404).json({ error: `MCP server '${serverName}' not found in configuration`, @@ -522,6 +539,8 @@ router.post( const result = await reinitMCPServer({ user, serverName, + serverConfig, + configServers, userMCPAuthMap, }); @@ -564,6 +583,7 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => { const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData( user.id, + { role: user.role, tenantId: getTenantId() }, ); const connectionStatus = {}; @@ -593,9 +613,6 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => { connectionStatus, }); } catch (error) { - if (error.message === 'MCP config not found') { - return res.status(404).json({ error: error.message }); - } logger.error('[MCP Connection Status] Failed to get connection status', error); res.status(500).json({ error: 'Failed to get 
connection status' }); } @@ -616,6 +633,7 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) => const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData( user.id, + { role: user.role, tenantId: getTenantId() }, ); if (!mcpConfig[serverName]) { @@ -640,9 +658,6 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) => requiresOAuth: serverStatus.requiresOAuth, }); } catch (error) { - if (error.message === 'MCP config not found') { - return res.status(404).json({ error: error.message }); - } logger.error( `[MCP Per-Server Status] Failed to get connection status for ${req.params.serverName}`, error, @@ -664,7 +679,12 @@ router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, a return res.status(401).json({ error: 'User not authenticated' }); } - const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id); + const configServers = await resolveConfigServers(req); + const serverConfig = await getMCPServersRegistry().getServerConfig( + serverName, + user.id, + configServers, + ); if (!serverConfig) { return res.status(404).json({ error: `MCP server '${serverName}' not found in configuration`, @@ -703,8 +723,12 @@ router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, a } }); -async function getOAuthHeaders(serverName, userId) { - const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, userId); +async function getOAuthHeaders(serverName, userId, configServers) { + const serverConfig = await getMCPServersRegistry().getServerConfig( + serverName, + userId, + configServers, + ); return serverConfig?.oauth_headers ?? 
{}; } diff --git a/api/server/services/Config/__tests__/invalidateConfigCaches.spec.js b/api/server/services/Config/__tests__/invalidateConfigCaches.spec.js index df21786f05..49e94bc081 100644 --- a/api/server/services/Config/__tests__/invalidateConfigCaches.spec.js +++ b/api/server/services/Config/__tests__/invalidateConfigCaches.spec.js @@ -32,12 +32,14 @@ jest.mock('../getCachedTools', () => ({ invalidateCachedTools: mockInvalidateCachedTools, })); +const mockClearMcpConfigCache = jest.fn().mockResolvedValue(undefined); jest.mock('@librechat/api', () => ({ createAppConfigService: jest.fn(() => ({ getAppConfig: jest.fn().mockResolvedValue({ availableTools: {} }), clearAppConfigCache: mockClearAppConfigCache, clearOverrideCache: mockClearOverrideCache, })), + clearMcpConfigCache: mockClearMcpConfigCache, })); // โ”€โ”€ Tests โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ diff --git a/api/server/services/Config/app.js b/api/server/services/Config/app.js index c0180fdb12..7530ca1031 100644 --- a/api/server/services/Config/app.js +++ b/api/server/services/Config/app.js @@ -1,6 +1,6 @@ const { CacheKeys } = require('librechat-data-provider'); -const { createAppConfigService } = require('@librechat/api'); const { AppService, logger } = require('@librechat/data-schemas'); +const { createAppConfigService, clearMcpConfigCache } = require('@librechat/api'); const { setCachedTools, invalidateCachedTools } = require('./getCachedTools'); const { loadAndFormatTools } = require('~/server/services/start/tools'); const loadCustomConfig = require('./loadCustomConfig'); @@ -42,7 +42,7 @@ async function clearEndpointConfigCache() { /** * Invalidate all config-related caches after an admin config mutation. * Clears the base config, per-principal override caches, tool caches, - * and the endpoints config cache. 
+ * the endpoints config cache, and the MCP config-source server cache. * @param {string} [tenantId] - Optional tenant ID to scope override cache clearing. */ async function invalidateConfigCaches(tenantId) { @@ -51,12 +51,14 @@ async function invalidateConfigCaches(tenantId) { clearOverrideCache(tenantId), invalidateCachedTools({ invalidateGlobal: true }), clearEndpointConfigCache(), + clearMcpConfigCache(), ]); const labels = [ 'clearAppConfigCache', 'clearOverrideCache', 'invalidateCachedTools', 'clearEndpointConfigCache', + 'clearMcpConfigCache', ]; for (let i = 0; i < results.length; i++) { if (results[i].status === 'rejected') { diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index d765d335aa..dbb44740a9 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -1,5 +1,5 @@ const { tool } = require('@langchain/core/tools'); -const { logger } = require('@librechat/data-schemas'); +const { logger, getTenantId } = require('@librechat/data-schemas'); const { Providers, StepTypes, @@ -54,6 +54,53 @@ function evictStale(map, ttl) { const unavailableMsg = "This tool's MCP server is temporarily unavailable. Please try again shortly."; +/** + * Resolves config-source MCP servers from admin Config overrides for the current + * request context. Returns the parsed configs keyed by server name. 
+ * @param {import('express').Request} req - Express request with user context + * @returns {Promise>} + */ +async function resolveConfigServers(req) { + try { + const registry = getMCPServersRegistry(); + const user = req?.user; + const appConfig = await getAppConfig({ + role: user?.role, + tenantId: getTenantId(), + userId: user?.id, + }); + return await registry.ensureConfigServers(appConfig?.mcpConfig || {}); + } catch (error) { + logger.warn( + '[resolveConfigServers] Failed to resolve config servers, degrading to empty:', + error, + ); + return {}; + } +} + +/** + * Resolves config-source servers and merges all server configs (YAML + config + user DB) + * for the given user context. Shared helper for controllers needing the full merged config. + * @param {string} userId + * @param {{ id?: string, role?: string }} [user] + * @returns {Promise>} + */ +async function resolveAllMcpConfigs(userId, user) { + const registry = getMCPServersRegistry(); + const appConfig = await getAppConfig({ role: user?.role, tenantId: getTenantId(), userId }); + let configServers = {}; + try { + configServers = await registry.ensureConfigServers(appConfig?.mcpConfig || {}); + } catch (error) { + logger.warn( + '[resolveAllMcpConfigs] Config server resolution failed, continuing without:', + error, + ); + } + return await registry.getAllServerConfigs(userId, configServers); +} + /** * @param {string} toolName * @param {string} serverName @@ -249,6 +296,7 @@ async function reconnectServer({ index, signal, serverName, + configServers, userMCPAuthMap, streamId = null, }) { @@ -317,6 +365,7 @@ async function reconnectServer({ user, signal, serverName, + configServers, oauthStart, flowManager, userMCPAuthMap, @@ -359,13 +408,12 @@ async function createMCPTools({ config, provider, serverName, + configServers, userMCPAuthMap, streamId = null, }) { - // Early domain validation before reconnecting server (avoid wasted work on disallowed domains) - // Use getAppConfig() to support per-user/role 
domain restrictions const serverConfig = - config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id)); + config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id, configServers)); if (serverConfig?.url) { const appConfig = await getAppConfig({ role: user?.role, tenantId: user?.tenantId }); const allowedDomains = appConfig?.mcpSettings?.allowedDomains; @@ -382,6 +430,7 @@ async function createMCPTools({ index, signal, serverName, + configServers, userMCPAuthMap, streamId, }); @@ -401,6 +450,7 @@ async function createMCPTools({ user, provider, userMCPAuthMap, + configServers, streamId, availableTools: result.availableTools, toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`, @@ -440,14 +490,13 @@ async function createMCPTool({ userMCPAuthMap, availableTools, config, + configServers, streamId = null, }) { const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter); - // Runtime domain validation: check if the server's domain is still allowed - // Use getAppConfig() to support per-user/role domain restrictions const serverConfig = - config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id)); + config ?? 
(await getMCPServersRegistry().getServerConfig(serverName, user?.id, configServers)); if (serverConfig?.url) { const appConfig = await getAppConfig({ role: user?.role, tenantId: user?.tenantId }); const allowedDomains = appConfig?.mcpSettings?.allowedDomains; @@ -478,6 +527,7 @@ async function createMCPTool({ index, signal, serverName, + configServers, userMCPAuthMap, streamId, }); @@ -501,6 +551,7 @@ async function createMCPTool({ provider, toolName, serverName, + serverConfig, toolDefinition, streamId, }); @@ -510,13 +561,14 @@ function createToolInstance({ res, toolName, serverName, + serverConfig: capturedServerConfig, toolDefinition, - provider: _provider, + provider: capturedProvider, streamId = null, }) { /** @type {LCTool} */ const { description, parameters } = toolDefinition; - const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE; + const isGoogle = capturedProvider === Providers.VERTEXAI || capturedProvider === Providers.GOOGLE; let schema = parameters ? normalizeJsonSchema(resolveJsonSchemaRefs(parameters)) : null; @@ -545,7 +597,7 @@ function createToolInstance({ const flowManager = getFlowStateManager(flowsCache); derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined; const mcpManager = getMCPManager(userId); - const provider = (config?.metadata?.provider || _provider)?.toLowerCase(); + const provider = (config?.metadata?.provider || capturedProvider)?.toLowerCase(); const { args: _args, stepId, ...toolCall } = config.toolCall ?? 
{}; const flowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`; @@ -577,6 +629,7 @@ function createToolInstance({ const result = await mcpManager.callTool({ serverName, + serverConfig: capturedServerConfig, toolName, provider, toolArguments, @@ -644,30 +697,36 @@ function createToolInstance({ } /** - * Get MCP setup data including config, connections, and OAuth servers + * Get MCP setup data including config, connections, and OAuth servers. + * Resolves config-source servers from admin Config overrides when tenant context is available. * @param {string} userId - The user ID + * @param {{ role?: string, tenantId?: string }} [options] - Optional role/tenant context * @returns {Object} Object containing mcpConfig, appConnections, userConnections, and oauthServers */ -async function getMCPSetupData(userId) { - const mcpConfig = await getMCPServersRegistry().getAllServerConfigs(userId); - - if (!mcpConfig) { - throw new Error('MCP config not found'); - } +async function getMCPSetupData(userId, options = {}) { + const registry = getMCPServersRegistry(); + const { role, tenantId } = options; + const appConfig = await getAppConfig({ role, tenantId, userId }); + const configServers = await registry.ensureConfigServers(appConfig?.mcpConfig || {}); + const mcpConfig = await registry.getAllServerConfigs(userId, configServers); const mcpManager = getMCPManager(userId); /** @type {Map} */ let appConnections = new Map(); try { - // Use getLoaded() instead of getAll() to avoid forcing connection creation + // Use getLoaded() instead of getAll() to avoid forcing connection creation. // getAll() creates connections for all servers, which is problematic for servers - // that require user context (e.g., those with {{LIBRECHAT_USER_ID}} placeholders) + // that require user context (e.g., those with {{LIBRECHAT_USER_ID}} placeholders). 
appConnections = (await mcpManager.appConnections?.getLoaded()) || new Map(); } catch (error) { logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error); } const userConnections = mcpManager.getUserConnections(userId) || new Map(); - const oauthServers = await getMCPServersRegistry().getOAuthServers(userId); + const oauthServers = new Set( + Object.entries(mcpConfig) + .filter(([, config]) => config.requiresOAuth) + .map(([name]) => name), + ); return { mcpConfig, @@ -789,6 +848,8 @@ module.exports = { createMCPTool, createMCPTools, getMCPSetupData, + resolveConfigServers, + resolveAllMcpConfigs, checkOAuthFlowStatus, getServerConnectionStatus, createUnavailableToolStub, diff --git a/api/server/services/MCP.spec.js b/api/server/services/MCP.spec.js index 14a9ef90ed..c9925827f8 100644 --- a/api/server/services/MCP.spec.js +++ b/api/server/services/MCP.spec.js @@ -14,6 +14,7 @@ const mockRegistryInstance = { getOAuthServers: jest.fn(() => Promise.resolve(new Set())), getAllServerConfigs: jest.fn(() => Promise.resolve({})), getServerConfig: jest.fn(() => Promise.resolve(null)), + ensureConfigServers: jest.fn(() => Promise.resolve({})), }; // Create isMCPDomainAllowed mock that can be configured per-test @@ -113,38 +114,43 @@ describe('tests for the new helper functions used by the MCP connection status e }); it('should successfully return MCP setup data', async () => { - mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfig); + const mockConfigWithOAuth = { + server1: { type: 'stdio' }, + server2: { type: 'http', requiresOAuth: true }, + }; + mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfigWithOAuth); const mockAppConnections = new Map([['server1', { status: 'connected' }]]); const mockUserConnections = new Map([['server2', { status: 'disconnected' }]]); - const mockOAuthServers = new Set(['server2']); const mockMCPManager = { appConnections: { getLoaded: jest.fn(() => Promise.resolve(mockAppConnections)) }, 
getUserConnections: jest.fn(() => mockUserConnections), }; mockGetMCPManager.mockReturnValue(mockMCPManager); - mockRegistryInstance.getOAuthServers.mockResolvedValue(mockOAuthServers); const result = await getMCPSetupData(mockUserId); - expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith(mockUserId); + expect(mockRegistryInstance.ensureConfigServers).toHaveBeenCalled(); + expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith( + mockUserId, + expect.any(Object), + ); expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId); expect(mockMCPManager.appConnections.getLoaded).toHaveBeenCalled(); expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId); - expect(mockRegistryInstance.getOAuthServers).toHaveBeenCalledWith(mockUserId); - expect(result).toEqual({ - mcpConfig: mockConfig, - appConnections: mockAppConnections, - userConnections: mockUserConnections, - oauthServers: mockOAuthServers, - }); + expect(result.mcpConfig).toEqual(mockConfigWithOAuth); + expect(result.appConnections).toEqual(mockAppConnections); + expect(result.userConnections).toEqual(mockUserConnections); + expect(result.oauthServers).toEqual(new Set(['server2'])); }); - it('should throw error when MCP config not found', async () => { - mockRegistryInstance.getAllServerConfigs.mockResolvedValue(null); - await expect(getMCPSetupData(mockUserId)).rejects.toThrow('MCP config not found'); + it('should return empty data when no servers are configured', async () => { + mockRegistryInstance.getAllServerConfigs.mockResolvedValue({}); + const result = await getMCPSetupData(mockUserId); + expect(result.mcpConfig).toEqual({}); + expect(result.oauthServers).toEqual(new Set()); }); it('should handle null values from MCP manager gracefully', async () => { diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index 838de906fe..c11843cb69 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ 
-60,6 +60,7 @@ const { manifestToolMap, toolkits } = require('~/app/clients/tools/manifest'); const { createOnSearchResults } = require('~/server/services/Tools/search'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { reinitMCPServer } = require('~/server/services/Tools/mcp'); +const { resolveConfigServers } = require('~/server/services/MCP'); const { recordUsage } = require('~/server/services/Threads'); const { loadTools } = require('~/app/clients/tools/util'); const { redactMessage } = require('~/config/parsers'); @@ -514,6 +515,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to const flowsCache = getLogStores(CacheKeys.FLOWS); const flowManager = getFlowStateManager(flowsCache); + const configServers = await resolveConfigServers(req); const pendingOAuthServers = new Set(); const createOAuthEmitter = (serverName) => { @@ -579,6 +581,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to oauthStart, flowManager, serverName, + configServers, userMCPAuthMap, }); @@ -666,6 +669,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to const result = await reinitMCPServer({ user: req.user, serverName, + configServers, userMCPAuthMap, flowManager, returnOnOAuth: false, diff --git a/api/server/services/Tools/mcp.js b/api/server/services/Tools/mcp.js index 7589043e10..f1ebcf9796 100644 --- a/api/server/services/Tools/mcp.js +++ b/api/server/services/Tools/mcp.js @@ -25,11 +25,13 @@ async function reinitMCPServer({ signal, forceNew, serverName, + configServers, userMCPAuthMap, connectionTimeout, returnOnOAuth = true, oauthStart: _oauthStart, flowManager: _flowManager, + serverConfig: providedConfig, }) { /** @type {MCPConnection | null} */ let connection = null; @@ -42,13 +44,28 @@ async function reinitMCPServer({ try { const registry = getMCPServersRegistry(); - const serverConfig = await registry.getServerConfig(serverName, user?.id); + const 
serverConfig = + providedConfig ?? (await registry.getServerConfig(serverName, user?.id, configServers)); if (serverConfig?.inspectionFailed) { + if (serverConfig.source === 'config') { + logger.info( + `[MCP Reinitialize] Config-source server ${serverName} has inspectionFailed โ€” retry handled by config cache`, + ); + return { + availableTools: null, + success: false, + message: `MCP server '${serverName}' is still unreachable`, + oauthRequired: false, + serverName, + oauthUrl: null, + tools: null, + }; + } logger.info( `[MCP Reinitialize] Server ${serverName} had failed inspection, attempting reinspection`, ); try { - const storageLocation = serverConfig.dbId ? 'DB' : 'CACHE'; + const storageLocation = serverConfig.source === 'user' ? 'DB' : 'CACHE'; await registry.reinspectServer(serverName, storageLocation, user?.id); logger.info(`[MCP Reinitialize] Reinspection succeeded for server: ${serverName}`); } catch (reinspectError) { @@ -93,6 +110,7 @@ async function reinitMCPServer({ returnOnOAuth, customUserVars, connectionTimeout, + serverConfig, }); logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`); @@ -125,6 +143,7 @@ async function reinitMCPServer({ oauthStart, customUserVars, connectionTimeout, + configServers, }); if (discoveryResult.tools && discoveryResult.tools.length > 0) { diff --git a/api/server/services/__tests__/MCP.spec.js b/api/server/services/__tests__/MCP.spec.js new file mode 100644 index 0000000000..39e99d54ac --- /dev/null +++ b/api/server/services/__tests__/MCP.spec.js @@ -0,0 +1,131 @@ +const mockRegistry = { + ensureConfigServers: jest.fn(), + getAllServerConfigs: jest.fn(), +}; + +jest.mock('~/config', () => ({ + getMCPServersRegistry: jest.fn(() => mockRegistry), + getMCPManager: jest.fn(), + getFlowStateManager: jest.fn(), + getOAuthReconnectionManager: jest.fn(), +})); + +jest.mock('@librechat/data-schemas', () => ({ + getTenantId: jest.fn(() => 'tenant-1'), + logger: { debug: jest.fn(), info: 
jest.fn(), warn: jest.fn(), error: jest.fn() }, +})); + +jest.mock('~/server/services/Config', () => ({ + getAppConfig: jest.fn(), + setCachedTools: jest.fn(), + getCachedTools: jest.fn(), + getMCPServerTools: jest.fn(), + loadCustomConfig: jest.fn(), +})); + +jest.mock('~/cache', () => ({ getLogStores: jest.fn() })); +jest.mock('~/models', () => ({ + findToken: jest.fn(), + createToken: jest.fn(), + updateToken: jest.fn(), +})); +jest.mock('~/server/services/GraphTokenService', () => ({ + getGraphApiToken: jest.fn(), +})); +jest.mock('~/server/services/Tools/mcp', () => ({ + reinitMCPServer: jest.fn(), +})); + +const { getAppConfig } = require('~/server/services/Config'); +const { resolveConfigServers, resolveAllMcpConfigs } = require('../MCP'); + +describe('resolveConfigServers', () => { + beforeEach(() => jest.clearAllMocks()); + + it('resolves config servers for the current request context', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: { srv: { url: 'http://a' } } }); + mockRegistry.ensureConfigServers.mockResolvedValue({ srv: { name: 'srv' } }); + + const result = await resolveConfigServers({ user: { id: 'u1', role: 'admin' } }); + + expect(result).toEqual({ srv: { name: 'srv' } }); + expect(getAppConfig).toHaveBeenCalledWith( + expect.objectContaining({ role: 'admin', userId: 'u1' }), + ); + expect(mockRegistry.ensureConfigServers).toHaveBeenCalledWith({ srv: { url: 'http://a' } }); + }); + + it('returns {} when ensureConfigServers throws', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: { srv: {} } }); + mockRegistry.ensureConfigServers.mockRejectedValue(new Error('inspect failed')); + + const result = await resolveConfigServers({ user: { id: 'u1' } }); + + expect(result).toEqual({}); + }); + + it('returns {} when getAppConfig throws', async () => { + getAppConfig.mockRejectedValue(new Error('db timeout')); + + const result = await resolveConfigServers({ user: { id: 'u1' } }); + + expect(result).toEqual({}); + }); + + it('passes 
empty mcpConfig when appConfig has none', async () => { + getAppConfig.mockResolvedValue({}); + mockRegistry.ensureConfigServers.mockResolvedValue({}); + + await resolveConfigServers({ user: { id: 'u1' } }); + + expect(mockRegistry.ensureConfigServers).toHaveBeenCalledWith({}); + }); +}); + +describe('resolveAllMcpConfigs', () => { + beforeEach(() => jest.clearAllMocks()); + + it('merges config servers with base servers', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: { cfg_srv: {} } }); + mockRegistry.ensureConfigServers.mockResolvedValue({ cfg_srv: { name: 'cfg_srv' } }); + mockRegistry.getAllServerConfigs.mockResolvedValue({ + cfg_srv: { name: 'cfg_srv' }, + yaml_srv: { name: 'yaml_srv' }, + }); + + const result = await resolveAllMcpConfigs('u1', { id: 'u1', role: 'user' }); + + expect(result).toEqual({ + cfg_srv: { name: 'cfg_srv' }, + yaml_srv: { name: 'yaml_srv' }, + }); + expect(mockRegistry.getAllServerConfigs).toHaveBeenCalledWith('u1', { + cfg_srv: { name: 'cfg_srv' }, + }); + }); + + it('continues with empty configServers when ensureConfigServers fails', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: { srv: {} } }); + mockRegistry.ensureConfigServers.mockRejectedValue(new Error('inspect failed')); + mockRegistry.getAllServerConfigs.mockResolvedValue({ yaml_srv: { name: 'yaml_srv' } }); + + const result = await resolveAllMcpConfigs('u1', { id: 'u1' }); + + expect(result).toEqual({ yaml_srv: { name: 'yaml_srv' } }); + expect(mockRegistry.getAllServerConfigs).toHaveBeenCalledWith('u1', {}); + }); + + it('propagates getAllServerConfigs failures', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: {} }); + mockRegistry.ensureConfigServers.mockResolvedValue({}); + mockRegistry.getAllServerConfigs.mockRejectedValue(new Error('redis down')); + + await expect(resolveAllMcpConfigs('u1', { id: 'u1' })).rejects.toThrow('redis down'); + }); + + it('propagates getAppConfig failures', async () => { + getAppConfig.mockRejectedValue(new 
Error('mongo down')); + + await expect(resolveAllMcpConfigs('u1', { id: 'u1' })).rejects.toThrow('mongo down'); + }); +}); diff --git a/api/server/services/__tests__/ToolService.spec.js b/api/server/services/__tests__/ToolService.spec.js index a468a88eb3..6e06804280 100644 --- a/api/server/services/__tests__/ToolService.spec.js +++ b/api/server/services/__tests__/ToolService.spec.js @@ -64,6 +64,9 @@ jest.mock('~/models', () => ({ jest.mock('~/config', () => ({ getFlowStateManager: jest.fn(() => ({})), })); +jest.mock('~/server/services/MCP', () => ({ + resolveConfigServers: jest.fn().mockResolvedValue({}), +})); jest.mock('~/cache', () => ({ getLogStores: jest.fn(() => ({})), })); diff --git a/packages/api/src/agents/context.spec.ts b/packages/api/src/agents/context.spec.ts index c5358209c7..1d995a52bb 100644 --- a/packages/api/src/agents/context.spec.ts +++ b/packages/api/src/agents/context.spec.ts @@ -154,10 +154,10 @@ describe('Agent Context Utilities', () => { ); expect(result).toBe(instructions); - expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith([ - 'server1', - 'server2', - ]); + expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith( + ['server1', 'server2'], + undefined, + ); expect(mockLogger.debug).toHaveBeenCalledWith( '[AgentContext] Fetched MCP instructions for servers:', ['server1', 'server2'], @@ -345,9 +345,10 @@ describe('Agent Context Utilities', () => { logger: mockLogger, }); - expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith([ - 'ephemeral-server', - ]); + expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith( + ['ephemeral-server'], + undefined, + ); expect(agent.instructions).toContain('Ephemeral MCP'); }); @@ -375,7 +376,10 @@ describe('Agent Context Utilities', () => { logger: mockLogger, }); - expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith(['agent-server']); + 
expect(mockMCPManager.formatInstructionsForContext).toHaveBeenCalledWith( + ['agent-server'], + undefined, + ); }); it('should work without agentId', async () => { diff --git a/packages/api/src/agents/context.ts b/packages/api/src/agents/context.ts index ebae2e0f9f..c526fd13fe 100644 --- a/packages/api/src/agents/context.ts +++ b/packages/api/src/agents/context.ts @@ -1,8 +1,9 @@ -import { DynamicStructuredTool } from '@langchain/core/tools'; import { Constants } from 'librechat-data-provider'; +import { DynamicStructuredTool } from '@langchain/core/tools'; import type { Agent, TEphemeralAgent } from 'librechat-data-provider'; import type { LCTool } from '@librechat/agents'; import type { Logger } from 'winston'; +import type { ParsedServerConfig } from '~/mcp/types'; import type { MCPManager } from '~/mcp/MCPManager'; /** @@ -63,12 +64,16 @@ export async function getMCPInstructionsForServers( mcpServers: string[], mcpManager: MCPManager, logger?: Logger, + configServers?: Record, ): Promise { if (!mcpServers.length) { return ''; } try { - const mcpInstructions = await mcpManager.formatInstructionsForContext(mcpServers); + const mcpInstructions = await mcpManager.formatInstructionsForContext( + mcpServers, + configServers, + ); if (mcpInstructions && logger) { logger.debug('[AgentContext] Fetched MCP instructions for servers:', mcpServers); } @@ -125,6 +130,7 @@ export async function applyContextToAgent({ ephemeralAgent, agentId, logger, + configServers, }: { agent: AgentWithTools; sharedRunContext: string; @@ -132,12 +138,18 @@ export async function applyContextToAgent({ ephemeralAgent?: TEphemeralAgent; agentId?: string; logger?: Logger; + configServers?: Record; }): Promise { const baseInstructions = agent.instructions || ''; try { const mcpServers = ephemeralAgent?.mcp?.length ? 
ephemeralAgent.mcp : extractMCPServers(agent); - const mcpInstructions = await getMCPInstructionsForServers(mcpServers, mcpManager, logger); + const mcpInstructions = await getMCPInstructionsForServers( + mcpServers, + mcpManager, + logger, + configServers, + ); agent.instructions = buildAgentInstructions({ sharedRunContext, diff --git a/packages/api/src/index.ts b/packages/api/src/index.ts index 5ccf6b0124..7a04b8e74a 100644 --- a/packages/api/src/index.ts +++ b/packages/api/src/index.ts @@ -14,6 +14,7 @@ export * from './mcp/oauth'; export * from './mcp/auth'; export * from './mcp/zod'; export * from './mcp/errors'; +export * from './mcp/cache'; /* Utilities */ export * from './mcp/utils'; export * from './utils'; diff --git a/packages/api/src/mcp/ConnectionsRepository.ts b/packages/api/src/mcp/ConnectionsRepository.ts index 6313faa8d4..79976b1199 100644 --- a/packages/api/src/mcp/ConnectionsRepository.ts +++ b/packages/api/src/mcp/ConnectionsRepository.ts @@ -2,7 +2,7 @@ import { logger } from '@librechat/data-schemas'; import type * as t from './types'; import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; -import { hasCustomUserVars } from './utils'; +import { hasCustomUserVars, isUserSourced } from './utils'; import { MCPConnection } from './connection'; const CONNECT_CONCURRENCY = 3; @@ -82,7 +82,7 @@ export class ConnectionsRepository { { serverName, serverConfig, - dbSourced: !!(serverConfig as t.ParsedServerConfig).dbId, + dbSourced: isUserSourced(serverConfig as t.ParsedServerConfig), useSSRFProtection: registry.shouldEnableSSRFProtection(), allowedDomains: registry.getAllowedDomains(), }, diff --git a/packages/api/src/mcp/MCPManager.ts b/packages/api/src/mcp/MCPManager.ts index 935307fa49..12227de39f 100644 --- a/packages/api/src/mcp/MCPManager.ts +++ b/packages/api/src/mcp/MCPManager.ts @@ -18,6 +18,7 @@ import { preProcessGraphTokens } from '~/utils/graph'; import { 
formatToolContent } from './parsers'; import { MCPConnection } from './connection'; import { processMCPEnv } from '~/utils/env'; +import { isUserSourced } from './utils'; /** * Centralized manager for MCP server connections and tool execution. @@ -53,6 +54,8 @@ export class MCPManager extends UserConnectionManager { user?: IUser; forceNew?: boolean; flowManager?: FlowStateManager; + /** Pre-resolved config for config-source servers not in YAML/DB */ + serverConfig?: t.ParsedServerConfig; } & Omit, ): Promise { //the get method checks if the config is still valid as app level @@ -91,6 +94,7 @@ export class MCPManager extends UserConnectionManager { const serverConfig = await MCPServersRegistry.getInstance().getServerConfig( serverName, user?.id, + args.configServers, ); if (!serverConfig) { @@ -103,7 +107,7 @@ export class MCPManager extends UserConnectionManager { const registry = MCPServersRegistry.getInstance(); const useSSRFProtection = registry.shouldEnableSSRFProtection(); const allowedDomains = registry.getAllowedDomains(); - const dbSourced = !!serverConfig.dbId; + const dbSourced = isUserSourced(serverConfig); const basic: t.BasicConnectionOptions = { dbSourced, serverName, @@ -193,9 +197,15 @@ export class MCPManager extends UserConnectionManager { * @param serverNames Optional array of server names. If not provided or empty, returns all servers. 
* @returns Object mapping server names to their instructions */ - private async getInstructions(serverNames?: string[]): Promise> { + private async getInstructions( + serverNames?: string[], + configServers?: Record, + ): Promise> { const instructions: Record = {}; - const configs = await MCPServersRegistry.getInstance().getAllServerConfigs(); + const configs = await MCPServersRegistry.getInstance().getAllServerConfigs( + undefined, + configServers, + ); for (const [serverName, config] of Object.entries(configs)) { if (config.serverInstructions != null) { instructions[serverName] = config.serverInstructions as string; @@ -210,9 +220,11 @@ export class MCPManager extends UserConnectionManager { * @param serverNames Optional array of server names to include. If not provided, includes all servers. * @returns Formatted instructions string ready for context injection */ - public async formatInstructionsForContext(serverNames?: string[]): Promise { - /** Instructions for specified servers or all stored instructions */ - const instructionsToInclude = await this.getInstructions(serverNames); + public async formatInstructionsForContext( + serverNames?: string[], + configServers?: Record, + ): Promise { + const instructionsToInclude = await this.getInstructions(serverNames, configServers); if (Object.keys(instructionsToInclude).length === 0) { return ''; @@ -248,6 +260,7 @@ Please follow these instructions when using tools from the respective MCP server async callTool({ user, serverName, + serverConfig: providedConfig, toolName, provider, toolArguments, @@ -262,6 +275,8 @@ Please follow these instructions when using tools from the respective MCP server }: { user?: IUser; serverName: string; + /** Pre-resolved config from tool creation context โ€” avoids readThrough TTL and cross-tenant issues */ + serverConfig?: t.ParsedServerConfig; toolName: string; provider: t.Provider; toolArguments?: Record; @@ -292,6 +307,7 @@ Please follow these instructions when using tools from the 
respective MCP server signal: options?.signal, customUserVars, requestBody, + serverConfig: providedConfig, }); if (!(await connection.isConnected())) { @@ -302,8 +318,16 @@ Please follow these instructions when using tools from the respective MCP server ); } - const rawConfig = await MCPServersRegistry.getInstance().getServerConfig(serverName, userId); - const isDbSourced = !!rawConfig?.dbId; + const rawConfig = + providedConfig ?? + (await MCPServersRegistry.getInstance().getServerConfig(serverName, userId)); + if (!rawConfig) { + throw new McpError( + ErrorCode.InvalidRequest, + `${logPrefix} Configuration for server "${serverName}" not found.`, + ); + } + const isDbSourced = isUserSourced(rawConfig); /** Pre-process Graph token placeholders (async) before the synchronous processMCPEnv pass */ const graphProcessedConfig = isDbSourced diff --git a/packages/api/src/mcp/UserConnectionManager.ts b/packages/api/src/mcp/UserConnectionManager.ts index 2e9d5be467..760f84c75e 100644 --- a/packages/api/src/mcp/UserConnectionManager.ts +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -4,6 +4,7 @@ import type * as t from './types'; import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; +import { isUserSourced } from './utils'; import { MCPConnection } from './connection'; import { mcpConfig } from './mcpConfig'; @@ -38,6 +39,8 @@ export abstract class UserConnectionManager { opts: { serverName: string; forceNew?: boolean; + /** Pre-resolved config for config-source servers not in YAML/DB */ + serverConfig?: t.ParsedServerConfig; } & Omit, ): Promise { const { serverName, forceNew, user } = opts; @@ -85,9 +88,11 @@ export abstract class UserConnectionManager { signal, returnOnOAuth = false, connectionTimeout, + serverConfig: providedConfig, }: { serverName: string; forceNew?: boolean; + serverConfig?: 
t.ParsedServerConfig; } & Omit, userId: string, ): Promise { @@ -98,7 +103,9 @@ export abstract class UserConnectionManager { ); } - const config = await MCPServersRegistry.getInstance().getServerConfig(serverName, userId); + const config = + providedConfig ?? + (await MCPServersRegistry.getInstance().getServerConfig(serverName, userId)); const userServerMap = this.userConnections.get(userId); let connection = forceNew ? undefined : userServerMap?.get(serverName); @@ -158,7 +165,7 @@ export abstract class UserConnectionManager { { serverConfig: config, serverName: serverName, - dbSourced: !!config.dbId, + dbSourced: isUserSourced(config), useSSRFProtection: registry.shouldEnableSSRFProtection(), allowedDomains: registry.getAllowedDomains(), }, diff --git a/packages/api/src/mcp/__tests__/utils.test.ts b/packages/api/src/mcp/__tests__/utils.test.ts index c244205b99..b9c2a31fa5 100644 --- a/packages/api/src/mcp/__tests__/utils.test.ts +++ b/packages/api/src/mcp/__tests__/utils.test.ts @@ -3,6 +3,7 @@ import { normalizeServerName, redactAllServerSecrets, redactServerSecrets, + isUserSourced, } from '~/mcp/utils'; import type { ParsedServerConfig } from '~/mcp/types'; @@ -273,3 +274,29 @@ describe('redactAllServerSecrets', () => { expect((redacted['server-c'] as Record).command).toBeUndefined(); }); }); + +describe('isUserSourced', () => { + it('returns false when source is yaml', () => { + expect(isUserSourced({ source: 'yaml' })).toBe(false); + }); + + it('returns false when source is config', () => { + expect(isUserSourced({ source: 'config' })).toBe(false); + }); + + it('returns true when source is user', () => { + expect(isUserSourced({ source: 'user' })).toBe(true); + }); + + it('falls back to dbId when source is undefined โ€” dbId present means user-sourced', () => { + expect(isUserSourced({ source: undefined, dbId: 'abc123' })).toBe(true); + }); + + it('falls back to dbId when source is undefined โ€” no dbId means trusted', () => { + expect(isUserSourced({ 
source: undefined, dbId: undefined })).toBe(false); + }); + + it('returns false when both source and dbId are absent (pre-upgrade YAML server)', () => { + expect(isUserSourced({})).toBe(false); + }); +}); diff --git a/packages/api/src/mcp/cache.ts b/packages/api/src/mcp/cache.ts new file mode 100644 index 0000000000..e68ef42b3c --- /dev/null +++ b/packages/api/src/mcp/cache.ts @@ -0,0 +1,43 @@ +import { logger } from '@librechat/data-schemas'; +import { MCPServersRegistry } from './registry/MCPServersRegistry'; +import { MCPManager } from './MCPManager'; + +/** + * Clears config-source MCP server inspection cache so servers are re-inspected on next access. + * Best-effort disconnection of app-level connections for evicted servers. + * + * User-level connections (used by config-source servers) are cleaned up lazily via + * the stale-check mechanism on the next tool call โ€” this is an accepted design tradeoff + * since iterating all active user sessions is expensive and config mutations are rare. 
+ */ +export async function clearMcpConfigCache(): Promise { + let registry: MCPServersRegistry; + try { + registry = MCPServersRegistry.getInstance(); + } catch { + return; + } + + let evictedServers: string[]; + try { + evictedServers = await registry.invalidateConfigCache(); + } catch (error) { + logger.error('[clearMcpConfigCache] Failed to invalidate config cache:', error); + return; + } + + if (!evictedServers.length) { + return; + } + + try { + const mcpManager = MCPManager.getInstance(); + if (mcpManager?.appConnections) { + await Promise.allSettled( + evictedServers.map((serverName) => mcpManager.appConnections!.disconnect(serverName)), + ); + } + } catch { + // MCPManager not yet initialized โ€” connections cleaned up lazily + } +} diff --git a/packages/api/src/mcp/oauth/handler.ts b/packages/api/src/mcp/oauth/handler.ts index 873af5c66d..e128dec308 100644 --- a/packages/api/src/mcp/oauth/handler.ts +++ b/packages/api/src/mcp/oauth/handler.ts @@ -467,6 +467,7 @@ export class MCPOAuthHandler { codeVerifier, clientInfo, metadata, + ...(Object.keys(oauthHeaders).length > 0 && { oauthHeaders }), }; logger.debug( @@ -573,6 +574,7 @@ export class MCPOAuthHandler { clientInfo, metadata, resourceMetadata, + ...(Object.keys(oauthHeaders).length > 0 && { oauthHeaders }), }; logger.debug( diff --git a/packages/api/src/mcp/oauth/types.ts b/packages/api/src/mcp/oauth/types.ts index 2138b4a782..bc5f53f60c 100644 --- a/packages/api/src/mcp/oauth/types.ts +++ b/packages/api/src/mcp/oauth/types.ts @@ -89,6 +89,8 @@ export interface MCPOAuthFlowMetadata extends FlowMetadata { metadata?: OAuthMetadata; resourceMetadata?: OAuthProtectedResourceMetadata; authorizationUrl?: string; + /** Custom headers for OAuth token exchange, persisted at flow initiation for the callback. 
*/ + oauthHeaders?: Record; } export interface MCPOAuthTokens extends OAuthTokens { diff --git a/packages/api/src/mcp/registry/MCPServerInspector.ts b/packages/api/src/mcp/registry/MCPServerInspector.ts index 7f31211680..f064fbb7e5 100644 --- a/packages/api/src/mcp/registry/MCPServerInspector.ts +++ b/packages/api/src/mcp/registry/MCPServerInspector.ts @@ -4,9 +4,9 @@ import type { MCPConnection } from '~/mcp/connection'; import type * as t from '~/mcp/types'; import { isMCPDomainAllowed, extractMCPServerDomain } from '~/auth/domain'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; +import { hasCustomUserVars, isUserSourced } from '~/mcp/utils'; import { MCPDomainNotAllowedError } from '~/mcp/errors'; import { detectOAuthRequirement } from '~/mcp/oauth'; -import { hasCustomUserVars } from '~/mcp/utils'; import { isEnabled } from '~/utils'; /** @@ -73,7 +73,7 @@ export class MCPServerInspector { this.connection = await MCPConnectionFactory.create({ serverConfig: this.config, serverName: this.serverName, - dbSourced: !!this.config.dbId, + dbSourced: isUserSourced(this.config), useSSRFProtection: this.useSSRFProtection, allowedDomains: this.allowedDomains, }); diff --git a/packages/api/src/mcp/registry/MCPServersRegistry.ts b/packages/api/src/mcp/registry/MCPServersRegistry.ts index b9c1eb66f5..6c98a6b8dd 100644 --- a/packages/api/src/mcp/registry/MCPServersRegistry.ts +++ b/packages/api/src/mcp/registry/MCPServersRegistry.ts @@ -1,28 +1,48 @@ import { Keyv } from 'keyv'; +import { createHash } from 'crypto'; import { logger } from '@librechat/data-schemas'; import type { IServerConfigsRepositoryInterface } from './ServerConfigsRepositoryInterface'; import type * as t from '~/mcp/types'; -import { ServerConfigsCacheFactory, APP_CACHE_NAMESPACE } from './cache/ServerConfigsCacheFactory'; +import { + ServerConfigsCacheFactory, + APP_CACHE_NAMESPACE, + CONFIG_CACHE_NAMESPACE, +} from './cache/ServerConfigsCacheFactory'; import { 
MCPInspectionFailedError, isMCPDomainNotAllowedError } from '~/mcp/errors'; import { MCPServerInspector } from './MCPServerInspector'; import { ServerConfigsDB } from './db/ServerConfigsDB'; import { cacheConfig } from '~/cache/cacheConfig'; +import { withTimeout } from '~/utils'; + +/** How long a failure stub is considered fresh before re-attempting inspection (5 minutes). */ +const CONFIG_STUB_RETRY_MS = 5 * 60 * 1000; + +const CONFIG_SERVER_INIT_TIMEOUT_MS = (() => { + const raw = process.env.MCP_INIT_TIMEOUT_MS; + if (raw == null) { + return 30_000; + } + const parsed = parseInt(raw, 10); + return Number.isFinite(parsed) && parsed > 0 ? parsed : 30_000; +})(); /** * Central registry for managing MCP server configurations. * Authoritative source of truth for all MCP servers provided by LibreChat. * - * Uses a two-repository architecture: - * - Cache Repository: Stores YAML-defined configs loaded at startup (in-memory or Redis-backed) - * - DB Repository: Stores dynamic configs created at runtime (not yet implemented) + * Uses a three-layer architecture: + * - YAML Cache (cacheConfigsRepo): Operator-defined configs loaded at startup (in-memory or Redis) + * - Config Cache (configCacheRepo): Admin-defined configs from Config overrides, lazily initialized + * - DB Repository (dbConfigsRepo): User-provided configs created at runtime (MongoDB + ACL) * - * Query priority: Cache configs are checked first, then DB configs. + * Query priority: YAML cache โ†’ Config cache โ†’ DB. 
*/ export class MCPServersRegistry { private static instance: MCPServersRegistry; private readonly dbConfigsRepo: IServerConfigsRepositoryInterface; private readonly cacheConfigsRepo: IServerConfigsRepositoryInterface; + private readonly configCacheRepo: IServerConfigsRepositoryInterface; private readonly allowedDomains?: string[] | null; private readonly readThroughCache: Keyv; private readonly readThroughCacheAll: Keyv>; @@ -31,9 +51,20 @@ export class MCPServersRegistry { Promise> >(); + /** Tracks in-flight config server initializations to prevent duplicate work. */ + private readonly pendingConfigInits = new Map< + string, + Promise + >(); + + /** Memoized YAML server names โ€” set once after boot-time init, never changes. */ + private yamlServerNames: Set | null = null; + private yamlServerNamesPromise: Promise> | null = null; + constructor(mongoose: typeof import('mongoose'), allowedDomains?: string[] | null) { this.dbConfigsRepo = new ServerConfigsDB(mongoose); this.cacheConfigsRepo = ServerConfigsCacheFactory.create(APP_CACHE_NAMESPACE, false); + this.configCacheRepo = ServerConfigsCacheFactory.create(CONFIG_CACHE_NAMESPACE, false); this.allowedDomains = allowedDomains; const ttl = cacheConfig.MCP_REGISTRY_CACHE_TTL; @@ -86,22 +117,29 @@ export class MCPServersRegistry { return !Array.isArray(this.allowedDomains) || this.allowedDomains.length === 0; } + /** + * Returns the config for a single server. When `configServers` is provided, config-source + * servers are resolved from it directly (no global state, no cross-tenant race). 
+ */ public async getServerConfig( serverName: string, userId?: string, + configServers?: Record, ): Promise { + if (configServers?.[serverName]) { + return configServers[serverName]; + } + const cacheKey = this.getReadThroughCacheKey(serverName, userId); if (await this.readThroughCache.has(cacheKey)) { return await this.readThroughCache.get(cacheKey); } - // First we check if any config exist with the cache - // Yaml config are pre loaded to the cache - const configFromCache = await this.cacheConfigsRepo.get(serverName); - if (configFromCache) { - await this.readThroughCache.set(cacheKey, configFromCache); - return configFromCache; + const configFromYaml = await this.cacheConfigsRepo.get(serverName); + if (configFromYaml) { + await this.readThroughCache.set(cacheKey, configFromYaml); + return configFromYaml; } const configFromDB = await this.dbConfigsRepo.get(serverName, userId); @@ -109,7 +147,30 @@ export class MCPServersRegistry { return configFromDB; } - public async getAllServerConfigs(userId?: string): Promise> { + /** + * Returns all server configs visible to the given user. + * YAML and Config tiers are mutually exclusive by design (`ensureConfigServers` filters + * YAML names), so the spread order only matters for User DB (highest priority) overriding both. + */ + public async getAllServerConfigs( + userId?: string, + configServers?: Record, + ): Promise> { + if (configServers == null || !Object.keys(configServers).length) { + return this.getBaseServerConfigs(userId); + } + const base = await this.getBaseServerConfigs(userId); + return { ...configServers, ...base }; + } + + /** + * Returns YAML + user-DB server configs, cached via `readThroughCacheAll`. + * Always called by `getAllServerConfigs` so the DB query is amortized across + * requests within the TTL window regardless of whether `configServers` is present. + */ + private async getBaseServerConfigs( + userId?: string, + ): Promise> { const cacheKey = userId ?? 
'__no_user__'; if (await this.readThroughCacheAll.has(cacheKey)) { @@ -121,7 +182,7 @@ export class MCPServersRegistry { return pending; } - const fetchPromise = this.fetchAllServerConfigs(cacheKey, userId); + const fetchPromise = this.fetchBaseServerConfigs(cacheKey, userId); this.pendingGetAllPromises.set(cacheKey, fetchPromise); try { @@ -131,7 +192,7 @@ export class MCPServersRegistry { } } - private async fetchAllServerConfigs( + private async fetchBaseServerConfigs( cacheKey: string, userId?: string, ): Promise> { @@ -155,7 +216,8 @@ export class MCPServersRegistry { userId?: string, ): Promise { const configRepo = this.getConfigRepository(storageLocation); - const stubConfig: t.ParsedServerConfig = { ...config, inspectionFailed: true }; + const source: t.MCPServerSource = storageLocation === 'CACHE' ? 'yaml' : 'user'; + const stubConfig: t.ParsedServerConfig = { ...config, inspectionFailed: true, source }; const result = await configRepo.add(serverName, stubConfig, userId); await this.readThroughCache.delete(this.getReadThroughCacheKey(serverName, userId)); await this.readThroughCache.delete(this.getReadThroughCacheKey(serverName)); @@ -179,13 +241,16 @@ export class MCPServersRegistry { ); } catch (error) { logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error); - // Preserve domain-specific error for better error handling if (isMCPDomainNotAllowedError(error)) { throw error; } throw new MCPInspectionFailedError(serverName, error as Error); } - return await configRepo.add(serverName, parsedConfig, userId); + const tagged = { + ...parsedConfig, + source: (storageLocation === 'CACHE' ? 
'yaml' : 'user') as t.MCPServerSource, + }; + return await configRepo.add(serverName, tagged, userId); } /** @@ -267,7 +332,6 @@ export class MCPServersRegistry { ); } catch (error) { logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error); - // Preserve domain-specific error for better error handling if (isMCPDomainNotAllowedError(error)) { throw error; } @@ -277,8 +341,180 @@ export class MCPServersRegistry { return parsedConfig; } - // TODO: This is currently used to determine if a server requires OAuth. However, this info can - // can be determined through config.requiresOAuth. Refactor usages and remove this method. + /** + * Ensures that config-source MCP servers (from admin Config overrides) are initialized. + * Identifies servers in `resolvedMcpConfig` that are not from YAML, lazily initializes + * any not yet in the config cache, and returns their parsed configs. + * + * Config cache keys are scoped by a hash of the raw config to prevent cross-tenant + * cache poisoning when two tenants define a server with the same name but different configs. 
+ */ + public async ensureConfigServers( + resolvedMcpConfig: Record, + ): Promise> { + if (!resolvedMcpConfig || Object.keys(resolvedMcpConfig).length === 0) { + return {}; + } + + const yamlNames = await this.getYamlServerNames(); + const configServerEntries = Object.entries(resolvedMcpConfig).filter( + ([name]) => !yamlNames.has(name), + ); + + if (configServerEntries.length === 0) { + return {}; + } + + const result: Record = {}; + + const settled = await Promise.allSettled( + configServerEntries.map(async ([serverName, rawConfig]) => { + const parsed = await this.ensureSingleConfigServer(serverName, rawConfig); + if (parsed) { + result[serverName] = parsed; + } + }), + ); + for (const outcome of settled) { + if (outcome.status === 'rejected') { + logger.error('[MCPServersRegistry][ensureConfigServers] Unexpected error:', outcome.reason); + } + } + + return result; + } + + /** + * Ensures a single config-source server is initialized. + * Cache key is scoped by config hash to prevent cross-tenant poisoning. + * Deduplicates concurrent init requests for the same server+config. + * Stale failure stubs are retried after `CONFIG_STUB_RETRY_MS` to recover from transient errors. + */ + private async ensureSingleConfigServer( + serverName: string, + rawConfig: t.MCPOptions, + ): Promise { + const cacheKey = this.configCacheKey(serverName, rawConfig); + + const cached = await this.configCacheRepo.get(cacheKey); + if (cached) { + const isStaleStub = + cached.inspectionFailed && Date.now() - (cached.updatedAt ?? 
0) > CONFIG_STUB_RETRY_MS; + if (!isStaleStub) { + return cached; + } + logger.info(`[MCP][config][${serverName}] Retrying stale failure stub`); + } + + const pending = this.pendingConfigInits.get(cacheKey); + if (pending) { + return pending; + } + + const initPromise = this.lazyInitConfigServer(cacheKey, serverName, rawConfig); + this.pendingConfigInits.set(cacheKey, initPromise); + + try { + return await initPromise; + } finally { + this.pendingConfigInits.delete(cacheKey); + } + } + + /** + * Lazily initializes a config-source MCP server: inspects capabilities/tools, then + * stores the parsed config in the config cache with `source: 'config'`. + */ + private async lazyInitConfigServer( + cacheKey: string, + serverName: string, + rawConfig: t.MCPOptions, + ): Promise { + const prefix = `[MCP][config][${serverName}]`; + logger.info(`${prefix} Lazy-initializing config-source server`); + + try { + const inspected = await withTimeout( + MCPServerInspector.inspect(serverName, rawConfig, undefined, this.allowedDomains), + CONFIG_SERVER_INIT_TIMEOUT_MS, + `${prefix} Server initialization timed out`, + ); + + const parsedConfig: t.ParsedServerConfig = { ...inspected, source: 'config' }; + await this.upsertConfigCache(cacheKey, parsedConfig); + + logger.info( + `${prefix} Initialized: tools=${parsedConfig.tools ?? 'N/A'}, ` + + `duration=${parsedConfig.initDuration ?? 
'N/A'}ms`, + ); + return parsedConfig; + } catch (error) { + logger.error(`${prefix} Failed to initialize:`, error); + + const stubConfig: t.ParsedServerConfig = { + ...rawConfig, + inspectionFailed: true, + source: 'config', + updatedAt: Date.now(), + }; + try { + await this.upsertConfigCache(cacheKey, stubConfig); + logger.info(`${prefix} Stored stub config for recovery`); + } catch (cacheError) { + logger.error( + `${prefix} Failed to store stub config (will retry on next request):`, + cacheError, + ); + } + return stubConfig; + } + } + + /** + * Writes a config to `configCacheRepo` using the atomic upsert operation. + * Safe for cross-process races โ€” the underlying cache handles add-or-update internally. + */ + private async upsertConfigCache(cacheKey: string, config: t.ParsedServerConfig): Promise { + await this.configCacheRepo.upsert(cacheKey, config); + } + + /** + * Clears the config-source server cache, forcing re-inspection on next access. + * Called when admin config overrides change (e.g., mcpServers mutation). + * + * @returns Names of servers that were evicted from the config cache. + * Callers should disconnect active connections for these servers. + */ + public async invalidateConfigCache(): Promise { + const allCached = await this.configCacheRepo.getAll(); + const evictedNames = [ + ...new Set( + Object.keys(allCached).map((key) => { + const lastColon = key.lastIndexOf(':'); + return lastColon > 0 ? key.slice(0, lastColon) : key; + }), + ), + ]; + + await Promise.all([ + this.configCacheRepo.reset(), + // Only clear readThroughCacheAll (merged results that may include stale config servers). + // readThroughCache (individual YAML/user lookups) is unaffected by config mutations. 
+ this.readThroughCacheAll.clear(), + ]); + + if (evictedNames.length > 0) { + logger.info( + `[MCPServersRegistry] Config server cache invalidated, evicted: ${evictedNames.join(', ')}`, + ); + } + return evictedNames; + } + + // TODO: Refactor callers to use config.requiresOAuth directly instead of this method. + // Known gap: config-source OAuth servers are not included here because callers + // (OAuthReconnectionManager, UserController) lack request context to resolve configServers. + // Config-source OAuth auto-reconnection and uninstall cleanup require a separate mechanism. public async getOAuthServers(userId?: string): Promise> { const allServers = await this.getAllServerConfigs(userId); const oauthServers = Object.entries(allServers).filter(([, config]) => config.requiresOAuth); @@ -287,8 +523,11 @@ export class MCPServersRegistry { public async reset(): Promise { await this.cacheConfigsRepo.reset(); + await this.configCacheRepo.reset(); await this.readThroughCache.clear(); await this.readThroughCacheAll.clear(); + this.yamlServerNames = null; + this.yamlServerNamesPromise = null; } public async removeServer( @@ -316,4 +555,48 @@ export class MCPServersRegistry { private getReadThroughCacheKey(serverName: string, userId?: string): string { return userId ? `${serverName}::${userId}` : serverName; } + + /** + * Returns memoized YAML server names. Populated lazily on first call after boot/reset. + * YAML servers don't change after boot, so this avoids repeated `getAll()` calls. + * Uses promise deduplication to prevent concurrent cold-start double-fetch. 
+ */ + private getYamlServerNames(): Promise> { + if (this.yamlServerNames) { + return Promise.resolve(this.yamlServerNames); + } + if (this.yamlServerNamesPromise) { + return this.yamlServerNamesPromise; + } + this.yamlServerNamesPromise = this.cacheConfigsRepo + .getAll() + .then((configs) => { + this.yamlServerNames = new Set(Object.keys(configs)); + this.yamlServerNamesPromise = null; + return this.yamlServerNames; + }) + .catch((err) => { + this.yamlServerNamesPromise = null; + throw err; + }); + return this.yamlServerNamesPromise; + } + + /** + * Produces a config-cache key scoped by server name AND a hash of the raw config. + * Prevents cross-tenant cache poisoning when two tenants define the same server name + * with different configurations. + */ + private configCacheKey(serverName: string, rawConfig: t.MCPOptions): string { + const sorted = JSON.stringify(rawConfig, (_key, value: unknown) => { + if (value !== null && typeof value === 'object' && !Array.isArray(value)) { + return Object.fromEntries( + Object.entries(value as Record).sort(([a], [b]) => a.localeCompare(b)), + ); + } + return value; + }); + const hash = createHash('sha256').update(sorted).digest('hex').slice(0, 16); + return `${serverName}:${hash}`; + } } diff --git a/packages/api/src/mcp/registry/ServerConfigsRepositoryInterface.ts b/packages/api/src/mcp/registry/ServerConfigsRepositoryInterface.ts index 1c913dd1a3..4bf0fdd615 100644 --- a/packages/api/src/mcp/registry/ServerConfigsRepositoryInterface.ts +++ b/packages/api/src/mcp/registry/ServerConfigsRepositoryInterface.ts @@ -9,6 +9,9 @@ export interface IServerConfigsRepositoryInterface { //ACL Entry check if update is possible update(serverName: string, config: ParsedServerConfig, userId?: string): Promise; + /** Atomic add-or-update without requiring callers to inspect error messages. 
*/ + upsert(serverName: string, config: ParsedServerConfig, userId?: string): Promise; + //ACL Entry check if remove is possible remove(serverName: string, userId?: string): Promise; diff --git a/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts b/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts index f0ab75c9b4..2012f82e31 100644 --- a/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts +++ b/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts @@ -321,12 +321,12 @@ describe('MCPServerInspector', () => { const result = await MCPServerInspector.inspect('test_server', rawConfig); // Verify factory was called to create connection - expect(MCPConnectionFactory.create).toHaveBeenCalledWith({ - serverName: 'test_server', - serverConfig: expect.objectContaining({ type: 'stdio', command: 'node' }), - useSSRFProtection: true, - dbSourced: false, - }); + expect(MCPConnectionFactory.create).toHaveBeenCalledWith( + expect.objectContaining({ + serverName: 'test_server', + serverConfig: expect.objectContaining({ type: 'stdio', command: 'node' }), + }), + ); // Verify temporary connection was disconnected expect(tempMockConnection.disconnect).toHaveBeenCalled(); diff --git a/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts b/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts index 8891120717..a20c09705f 100644 --- a/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts +++ b/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts @@ -112,8 +112,8 @@ describe('MCPServersRegistry', () => { const userConfigBefore = await registry.getServerConfig('user_server'); const allConfigsBefore = await registry.getAllServerConfigs(); - expect(appConfigBefore).toEqual(testParsedConfig); - expect(userConfigBefore).toEqual(testParsedConfig); + expect(appConfigBefore).toEqual(expect.objectContaining(testParsedConfig)); + 
expect(userConfigBefore).toEqual(expect.objectContaining(testParsedConfig)); expect(Object.keys(allConfigsBefore)).toHaveLength(2); // Reset everything @@ -250,22 +250,18 @@ describe('MCPServersRegistry', () => { }); it('should use different cache keys for different userIds', async () => { - // Spy on the cache repository get method + await registry['cacheConfigsRepo'].add('test_server', testParsedConfig); const cacheRepoGetSpy = jest.spyOn(registry['cacheConfigsRepo'], 'get'); - // First call without userId await registry.getServerConfig('test_server'); expect(cacheRepoGetSpy).toHaveBeenCalledTimes(1); - // Call with userId - should be a different cache key, so hits repository again await registry.getServerConfig('test_server', 'user123'); expect(cacheRepoGetSpy).toHaveBeenCalledTimes(2); - // Repeat call with same userId - should hit read-through cache await registry.getServerConfig('test_server', 'user123'); - expect(cacheRepoGetSpy).toHaveBeenCalledTimes(2); // Still 2 + expect(cacheRepoGetSpy).toHaveBeenCalledTimes(2); - // Call with different userId - should hit repository await registry.getServerConfig('test_server', 'user456'); expect(cacheRepoGetSpy).toHaveBeenCalledTimes(3); }); diff --git a/packages/api/src/mcp/registry/__tests__/ensureConfigServers.test.ts b/packages/api/src/mcp/registry/__tests__/ensureConfigServers.test.ts new file mode 100644 index 0000000000..70eb2f75c4 --- /dev/null +++ b/packages/api/src/mcp/registry/__tests__/ensureConfigServers.test.ts @@ -0,0 +1,328 @@ +import type * as t from '~/mcp/types'; +import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; +import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector'; + +jest.mock('~/mcp/registry/MCPServerInspector'); +jest.mock('~/mcp/registry/db/ServerConfigsDB', () => ({ + ServerConfigsDB: jest.fn().mockImplementation(() => ({ + get: jest.fn().mockResolvedValue(undefined), + getAll: jest.fn().mockResolvedValue({}), + add: 
jest.fn().mockResolvedValue(undefined), + update: jest.fn().mockResolvedValue(undefined), + upsert: jest.fn().mockResolvedValue(undefined), + remove: jest.fn().mockResolvedValue(undefined), + reset: jest.fn().mockResolvedValue(undefined), + })), +})); + +const FIXED_TIME = 1699564800000; + +const mockMongoose = {} as typeof import('mongoose'); + +const sseConfig: t.MCPOptions = { + type: 'sse', + url: 'https://mcp.example.com/sse', +} as unknown as t.MCPOptions; + +const altSseConfig: t.MCPOptions = { + type: 'sse', + url: 'https://mcp.other-tenant.com/sse', +} as unknown as t.MCPOptions; + +const yamlConfig: t.MCPOptions = { + type: 'stdio', + command: 'node', + args: ['tools.js'], +} as unknown as t.MCPOptions; + +function makeParsedConfig(overrides: Partial = {}): t.ParsedServerConfig { + return { + type: 'sse', + url: 'https://mcp.example.com/sse', + requiresOAuth: false, + tools: 'tool_a, tool_b', + capabilities: '{}', + initDuration: 42, + ...overrides, + } as unknown as t.ParsedServerConfig; +} + +describe('MCPServersRegistry โ€” ensureConfigServers', () => { + let registry: MCPServersRegistry; + let inspectSpy: jest.SpyInstance; + + beforeAll(() => { + jest.useFakeTimers(); + jest.setSystemTime(new Date(FIXED_TIME)); + }); + + afterAll(() => { + jest.useRealTimers(); + }); + + beforeEach(async () => { + (MCPServersRegistry as unknown as { instance: undefined }).instance = undefined; + MCPServersRegistry.createInstance(mockMongoose); + registry = MCPServersRegistry.getInstance(); + + inspectSpy = jest + .spyOn(MCPServerInspector, 'inspect') + .mockImplementation(async (_serverName: string, rawConfig: t.MCPOptions) => + makeParsedConfig(rawConfig as unknown as Partial), + ); + + await registry.reset(); + }); + + afterEach(() => { + inspectSpy.mockClear(); + }); + + it('should return empty for empty input', async () => { + expect(await registry.ensureConfigServers({})).toEqual({}); + }); + + it('should return empty for null/undefined input', async () => { + 
expect( + await registry.ensureConfigServers(null as unknown as Record), + ).toEqual({}); + expect( + await registry.ensureConfigServers(undefined as unknown as Record), + ).toEqual({}); + }); + + it('should exclude YAML servers from config-source detection', async () => { + await registry.addServer('yaml_server', yamlConfig, 'CACHE'); + + const result = await registry.ensureConfigServers({ + yaml_server: yamlConfig, + config_server: sseConfig, + }); + + expect(result).toHaveProperty('config_server'); + expect(result).not.toHaveProperty('yaml_server'); + }); + + it('should return empty when all servers are YAML', async () => { + await registry.addServer('yaml_a', yamlConfig, 'CACHE'); + await registry.addServer('yaml_b', yamlConfig, 'CACHE'); + inspectSpy.mockClear(); + + const result = await registry.ensureConfigServers({ + yaml_a: yamlConfig, + yaml_b: yamlConfig, + }); + + expect(result).toEqual({}); + expect(inspectSpy).not.toHaveBeenCalled(); + }); + + it('should lazy-initialize a config-source server and tag source as config', async () => { + const result = await registry.ensureConfigServers({ my_server: sseConfig }); + + expect(result).toHaveProperty('my_server'); + expect(result.my_server.source).toBe('config'); + expect(inspectSpy).toHaveBeenCalledTimes(1); + expect(inspectSpy).toHaveBeenCalledWith('my_server', sseConfig, undefined, undefined); + }); + + it('should return cached result on second call without re-inspecting', async () => { + await registry.ensureConfigServers({ my_server: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(1); + + const result2 = await registry.ensureConfigServers({ my_server: sseConfig }); + expect(result2).toHaveProperty('my_server'); + expect(result2.my_server.source).toBe('config'); + expect(inspectSpy).toHaveBeenCalledTimes(1); + }); + + it('should store inspectionFailed stub on inspection failure', async () => { + inspectSpy.mockRejectedValueOnce(new Error('connection refused')); + + const result = await 
registry.ensureConfigServers({ bad_server: sseConfig }); + + expect(result).toHaveProperty('bad_server'); + expect(result.bad_server.inspectionFailed).toBe(true); + expect(result.bad_server.source).toBe('config'); + }); + + it('should return stub from cache on repeated failure without re-inspecting', async () => { + inspectSpy.mockRejectedValueOnce(new Error('connection refused')); + await registry.ensureConfigServers({ bad_server: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(1); + + const result2 = await registry.ensureConfigServers({ bad_server: sseConfig }); + expect(result2.bad_server.inspectionFailed).toBe(true); + expect(inspectSpy).toHaveBeenCalledTimes(1); + }); + + it('should retry stale failure stub after CONFIG_STUB_RETRY_MS', async () => { + inspectSpy.mockRejectedValueOnce(new Error('transient DNS failure')); + await registry.ensureConfigServers({ flaky_server: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(1); + + jest.setSystemTime(new Date(FIXED_TIME + 6 * 60 * 1000)); + + const result = await registry.ensureConfigServers({ flaky_server: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(2); + expect(result.flaky_server.inspectionFailed).toBeUndefined(); + expect(result.flaky_server.source).toBe('config'); + + jest.setSystemTime(new Date(FIXED_TIME)); + }); + + describe('cross-tenant isolation', () => { + it('should use different cache keys for same server name with different configs', async () => { + inspectSpy.mockClear(); + const resultA = await registry.ensureConfigServers({ shared_name: sseConfig }); + expect(resultA.shared_name.source).toBe('config'); + expect(inspectSpy).toHaveBeenCalledTimes(1); + + const resultB = await registry.ensureConfigServers({ shared_name: altSseConfig }); + expect(resultB.shared_name.source).toBe('config'); + expect(inspectSpy).toHaveBeenCalledTimes(2); + }); + + it('should return tenant-A config for tenant-A and tenant-B config for tenant-B', async () => { + const resultA = await 
registry.ensureConfigServers({ srv: sseConfig }); + const resultB = await registry.ensureConfigServers({ srv: altSseConfig }); + + expect((resultA.srv as unknown as { url: string }).url).toBe('https://mcp.example.com/sse'); + expect((resultB.srv as unknown as { url: string }).url).toBe( + 'https://mcp.other-tenant.com/sse', + ); + }); + }); + + describe('concurrent deduplication', () => { + it('should only inspect once for multiple parallel calls with the same config', async () => { + inspectSpy.mockClear(); + // Fire two calls simultaneously โ€” both see cache miss, but only one should inspect + const [r1, r2] = await Promise.all([ + registry.ensureConfigServers({ dedup_srv: sseConfig }), + registry.ensureConfigServers({ dedup_srv: sseConfig }), + ]); + + expect(r1.dedup_srv).toBeDefined(); + expect(r2.dedup_srv).toBeDefined(); + expect(inspectSpy).toHaveBeenCalledTimes(1); + + // Subsequent call must NOT re-inspect (cached) + inspectSpy.mockClear(); + await registry.ensureConfigServers({ dedup_srv: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(0); + }); + }); + + describe('merge order', () => { + it('should merge YAML โ†’ config โ†’ user with correct precedence in getAllServerConfigs', async () => { + await registry.addServer('yaml_srv', yamlConfig, 'CACHE'); + + const configServers = await registry.ensureConfigServers({ config_srv: sseConfig }); + + const all = await registry.getAllServerConfigs(undefined, configServers); + expect(all).toHaveProperty('yaml_srv'); + expect(all).toHaveProperty('config_srv'); + expect(all.yaml_srv.source).toBe('yaml'); + expect(all.config_srv.source).toBe('config'); + }); + + it('should let config servers appear alongside user DB servers', async () => { + const mockDbConfigs = { + user_srv: makeParsedConfig({ source: 'user', dbId: 'abc123' }), + }; + jest.spyOn(registry['dbConfigsRepo'], 'getAll').mockResolvedValue(mockDbConfigs); + + const configServers = await registry.ensureConfigServers({ config_srv: sseConfig }); + 
const all = await registry.getAllServerConfigs('user-1', configServers); + + expect(all).toHaveProperty('config_srv'); + expect(all).toHaveProperty('user_srv'); + expect(all.config_srv.source).toBe('config'); + expect(all.user_srv.source).toBe('user'); + }); + }); + + describe('invalidateConfigCache', () => { + it('should clear config cache and force re-inspection on next call', async () => { + await registry.ensureConfigServers({ my_server: sseConfig }); + inspectSpy.mockClear(); + + await registry.invalidateConfigCache(); + + await registry.ensureConfigServers({ my_server: sseConfig }); + expect(inspectSpy).toHaveBeenCalledTimes(1); + }); + + it('should return evicted server names', async () => { + await registry.ensureConfigServers({ srv_a: sseConfig, srv_b: altSseConfig }); + const evicted = await registry.invalidateConfigCache(); + expect(evicted.length).toBeGreaterThan(0); + }); + + it('should return empty array when nothing is cached', async () => { + const evicted = await registry.invalidateConfigCache(); + expect(evicted).toEqual([]); + }); + }); + + describe('getServerConfig with configServers', () => { + it('should return config-source server when configServers is passed', async () => { + const configServers = await registry.ensureConfigServers({ config_srv: sseConfig }); + const config = await registry.getServerConfig('config_srv', undefined, configServers); + expect(config).toBeDefined(); + expect(config?.source).toBe('config'); + }); + + it('should return config-source server with userId when configServers is passed', async () => { + const configServers = await registry.ensureConfigServers({ config_srv: sseConfig }); + const config = await registry.getServerConfig('config_srv', 'user-123', configServers); + expect(config).toBeDefined(); + expect(config?.source).toBe('config'); + }); + + it('should return undefined for config-source server without configServers (tenant isolation)', async () => { + await registry.ensureConfigServers({ config_srv: 
sseConfig }); + const config = await registry.getServerConfig('config_srv'); + expect(config).toBeUndefined(); + }); + + it('should return correct config after invalidation and re-init', async () => { + const configServers1 = await registry.ensureConfigServers({ config_srv: sseConfig }); + expect(await registry.getServerConfig('config_srv', undefined, configServers1)).toBeDefined(); + + await registry.invalidateConfigCache(); + + const configServers2 = await registry.ensureConfigServers({ config_srv: sseConfig }); + const config = await registry.getServerConfig('config_srv', undefined, configServers2); + expect(config).toBeDefined(); + expect(config?.source).toBe('config'); + }); + + it('should not cross-contaminate between tenant configServers maps', async () => { + const tenantA = await registry.ensureConfigServers({ srv: sseConfig }); + const tenantB = await registry.ensureConfigServers({ srv: altSseConfig }); + + const configA = await registry.getServerConfig('srv', undefined, tenantA); + const configB = await registry.getServerConfig('srv', undefined, tenantB); + + expect((configA as unknown as { url: string }).url).toBe('https://mcp.example.com/sse'); + expect((configB as unknown as { url: string }).url).toBe('https://mcp.other-tenant.com/sse'); + }); + }); + + describe('source tagging', () => { + it('should tag CACHE-stored servers as yaml', async () => { + await registry.addServer('yaml_srv', yamlConfig, 'CACHE'); + const config = await registry.getServerConfig('yaml_srv'); + expect(config?.source).toBe('yaml'); + }); + + it('should tag stubs as yaml when stored in CACHE', async () => { + await registry.addServerStub('stub_srv', yamlConfig, 'CACHE'); + const config = await registry.getServerConfig('stub_srv'); + expect(config?.source).toBe('yaml'); + expect(config?.inspectionFailed).toBe(true); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts 
index b9549629d6..ebe19b59e3 100644 --- a/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts @@ -16,11 +16,17 @@ export type ServerConfigsCache = */ export const APP_CACHE_NAMESPACE = 'App' as const; +/** Namespace for admin-defined config-override MCP server inspection results. */ +export const CONFIG_CACHE_NAMESPACE = 'Config' as const; + +/** Namespaces that use the aggregate-key optimization to avoid SCAN+N-GETs stalls. */ +const AGGREGATE_KEY_NAMESPACES = new Set([APP_CACHE_NAMESPACE, CONFIG_CACHE_NAMESPACE]); + /** * Factory for creating the appropriate ServerConfigsCache implementation based on * deployment mode and namespace. * - * The {@link APP_CACHE_NAMESPACE} namespace uses {@link ServerConfigsCacheRedisAggregateKey} + * Namespaces in {@link AGGREGATE_KEY_NAMESPACES} use {@link ServerConfigsCacheRedisAggregateKey} * when Redis is enabled โ€” storing all configs under a single key so `getAll()` is one GET * instead of SCAN + N GETs. Cross-instance visibility is preserved: reinspection results * propagate through Redis automatically. @@ -32,8 +38,8 @@ export class ServerConfigsCacheFactory { /** * Create a ServerConfigsCache instance. * - * @param namespace - The namespace for the cache. {@link APP_CACHE_NAMESPACE} uses - * aggregate-key Redis storage (or in-memory when Redis is disabled). + * @param namespace - The namespace for the cache. Namespaces in {@link AGGREGATE_KEY_NAMESPACES} + * use aggregate-key Redis storage (or in-memory when Redis is disabled). * @param leaderOnly - Whether write operations should only be performed by the leader. 
* @returns ServerConfigsCache instance */ @@ -42,7 +48,7 @@ export class ServerConfigsCacheFactory { return new ServerConfigsCacheInMemory(); } - if (namespace === APP_CACHE_NAMESPACE) { + if (AGGREGATE_KEY_NAMESPACES.has(namespace)) { return new ServerConfigsCacheRedisAggregateKey(namespace, leaderOnly); } diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts index 384c477756..5a7fd35b9f 100644 --- a/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts @@ -28,6 +28,10 @@ export class ServerConfigsCacheInMemory { this.cache.set(serverName, { ...config, updatedAt: Date.now() }); } + public async upsert(serverName: string, config: ParsedServerConfig): Promise { + this.cache.set(serverName, { ...config, updatedAt: Date.now() }); + } + public async remove(serverName: string): Promise { if (!this.cache.delete(serverName)) { throw new Error(`Failed to remove server "${serverName}" in cache.`); diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts index d3154baf73..af1316056d 100644 --- a/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts @@ -52,6 +52,12 @@ export class ServerConfigsCacheRedis this.successCheck(`update ${this.namespace} server "${serverName}"`, success); } + public async upsert(serverName: string, config: ParsedServerConfig): Promise { + if (this.leaderOnly) await this.leaderCheck(`upsert ${this.namespace} MCP servers`); + const success = await this.cache.set(serverName, { ...config, updatedAt: Date.now() }); + this.successCheck(`upsert ${this.namespace} server "${serverName}"`, success); + } + public async remove(serverName: string): Promise { if (this.leaderOnly) await this.leaderCheck(`remove ${this.namespace} 
MCP servers`); const success = await this.cache.delete(serverName); diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts index e67c1a4a84..5fc32bd7aa 100644 --- a/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedisAggregateKey.ts @@ -53,8 +53,11 @@ export class ServerConfigsCacheRedisAggregateKey /** Milliseconds since epoch. 0 = epoch = always expired on first check. */ private localSnapshotExpiry = 0; + private readonly namespace: string; + constructor(namespace: string, leaderOnly: boolean) { super(leaderOnly); + this.namespace = namespace; this.cache = standardCache(`${this.PREFIX}::Servers::${namespace}`); } @@ -125,7 +128,7 @@ export class ServerConfigsCacheRedisAggregateKey const storedConfig = { ...config, updatedAt: Date.now() }; const newAll = { ...all, [serverName]: storedConfig }; const success = await this.cache.set(AGGREGATE_KEY, newAll); - this.successCheck(`add App server "${serverName}"`, success); + this.successCheck(`add ${this.namespace} server "${serverName}"`, success); return { serverName, config: storedConfig }; }); } @@ -142,7 +145,18 @@ export class ServerConfigsCacheRedisAggregateKey } const newAll = { ...all, [serverName]: { ...config, updatedAt: Date.now() } }; const success = await this.cache.set(AGGREGATE_KEY, newAll); - this.successCheck(`update App server "${serverName}"`, success); + this.successCheck(`update ${this.namespace} server "${serverName}"`, success); + }); + } + + public async upsert(serverName: string, config: ParsedServerConfig): Promise { + if (this.leaderOnly) await this.leaderCheck('upsert MCP servers'); + return this.withWriteLock(async () => { + this.invalidateLocalSnapshot(); + const all = await this.getAll(); + const newAll = { ...all, [serverName]: { ...config, updatedAt: Date.now() } }; + const success = await 
this.cache.set(AGGREGATE_KEY, newAll); + this.successCheck(`upsert ${this.namespace} server "${serverName}"`, success); }); } @@ -156,7 +170,7 @@ export class ServerConfigsCacheRedisAggregateKey } const { [serverName]: _, ...newAll } = all; const success = await this.cache.set(AGGREGATE_KEY, newAll); - this.successCheck(`remove App server "${serverName}"`, success); + this.successCheck(`remove ${this.namespace} server "${serverName}"`, success); }); } @@ -171,7 +185,7 @@ export class ServerConfigsCacheRedisAggregateKey */ public override async reset(): Promise { if (this.leaderOnly) { - await this.leaderCheck('reset App MCP servers cache'); + await this.leaderCheck(`reset ${this.namespace} MCP servers cache`); } await this.cache.delete(AGGREGATE_KEY); this.invalidateLocalSnapshot(); diff --git a/packages/api/src/mcp/registry/db/ServerConfigsDB.ts b/packages/api/src/mcp/registry/db/ServerConfigsDB.ts index 9981f6b00b..b1649c66ca 100644 --- a/packages/api/src/mcp/registry/db/ServerConfigsDB.ts +++ b/packages/api/src/mcp/registry/db/ServerConfigsDB.ts @@ -220,6 +220,25 @@ export class ServerConfigsDB implements IServerConfigsRepositoryInterface { await this._dbMethods.updateMCPServer(serverName, { config: configToSave }); } + /** + * Atomic add-or-update. For DB-backed servers this delegates to update since + * DB servers are always created via the explicit add() flow with ACL setup. + * Config-source servers should use configCacheRepo, not dbConfigsRepo. + */ + public async upsert( + serverName: string, + config: ParsedServerConfig, + userId?: string, + ): Promise { + if (!userId) { + throw new Error( + `[ServerConfigsDB.upsert] User ID is required for DB-backed MCP server upsert of "${serverName}". ` + + 'Config-source servers should use configCacheRepo, not dbConfigsRepo.', + ); + } + return this.update(serverName, config, userId); + } + /** * Deletes an MCP server and removes all associated ACL entries. 
* @param serverName - The serverName of the server to remove @@ -411,6 +430,7 @@ export class ServerConfigsDB implements IServerConfigsRepositoryInterface { const config: ParsedServerConfig = { ...serverDBDoc.config, dbId: (serverDBDoc._id as Types.ObjectId).toString(), + source: 'user', updatedAt: serverDBDoc.updatedAt?.getTime(), }; return await this.decryptConfig(config); diff --git a/packages/api/src/mcp/types/index.ts b/packages/api/src/mcp/types/index.ts index 6cb5e02f0b..32c2787165 100644 --- a/packages/api/src/mcp/types/index.ts +++ b/packages/api/src/mcp/types/index.ts @@ -144,6 +144,14 @@ export type ImageFormatter = (item: ImageContent) => FormattedContent; export type FormattedToolResponse = FormattedContentResult; +/** + * Origin of an MCP server definition. + * - `'yaml'` โ€” operator-defined in librechat.yaml, full trust, boot-time init + * - `'config'` โ€” admin-defined via Config override, full trust, lazy init + * - `'user'` โ€” user-provided via UI, sandboxed (restricted placeholder resolution) + */ +export type MCPServerSource = 'yaml' | 'config' | 'user'; + export type ParsedServerConfig = MCPOptions & { url?: string; requiresOAuth?: boolean; @@ -154,6 +162,8 @@ export type ParsedServerConfig = MCPOptions & { initDuration?: number; updatedAt?: number; dbId?: string; + /** Origin of this server definition โ€” determines trust level and placeholder resolution */ + source?: MCPServerSource; /** True if access is only via agent (not directly shared with user) */ consumeOnly?: boolean; /** True when inspection failed at startup; the server is known but not fully initialized */ @@ -202,6 +212,8 @@ export interface ToolDiscoveryOptions { customUserVars?: Record; requestBody?: RequestBody; connectionTimeout?: number; + /** Pre-resolved config-source servers for tenant-scoped lookup */ + configServers?: Record; } export interface ToolDiscoveryResult { diff --git a/packages/api/src/mcp/utils.ts b/packages/api/src/mcp/utils.ts index db89cffada..653a96d5bd 
100644 --- a/packages/api/src/mcp/utils.ts +++ b/packages/api/src/mcp/utils.ts @@ -8,6 +8,15 @@ export function hasCustomUserVars(config: Pick 0; } +/** + * Determines whether a server config is user-sourced (sandboxed placeholder resolution). + * When `source` is set, it is authoritative. When absent (pre-upgrade cached configs), + * falls back to the legacy `dbId` heuristic for backward compatibility. + */ +export function isUserSourced(config: Pick): boolean { + return config.source != null ? config.source === 'user' : !!config.dbId; +} + /** * Allowlist-based sanitization for API responses. Only explicitly listed fields are included; * new fields added to ParsedServerConfig are excluded by default until allowlisted here. @@ -31,6 +40,8 @@ export function redactServerSecrets(config: ParsedServerConfig): Partial Date: Sat, 28 Mar 2026 16:43:50 -0400 Subject: [PATCH 13/18] =?UTF-8?q?=F0=9F=8F=97=EF=B8=8F=20feat:=20bulkWrite?= =?UTF-8?q?=20isolation,=20pre-auth=20context,=20strict-mode=20fixes=20(#1?= =?UTF-8?q?2445)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: wrap seedDatabase() in runAsSystem() for strict tenant mode seedDatabase() was called without tenant context at startup, causing every Mongoose operation inside it to throw when TENANT_ISOLATION_STRICT=true. Wrapping in runAsSystem() gives it the SYSTEM_TENANT_ID sentinel so the isolation plugin skips filtering, matching the pattern already used for performStartupChecks and updateInterfacePermissions. * fix: chain tenantContextMiddleware in optionalJwtAuth optionalJwtAuth populated req.user but never established ALS tenant context, unlike requireJwtAuth which chains tenantContextMiddleware after successful auth. Authenticated users hitting routes with optionalJwtAuth (e.g. /api/banner) had no tenant isolation. 
* feat: tenant-safe bulkWrite wrapper and call-site migration Mongoose's bulkWrite() does not trigger schema-level middleware hooks, so the applyTenantIsolation plugin cannot intercept it. This adds a tenantSafeBulkWrite() utility that injects the current ALS tenant context into every operation's filter/document before delegating to native bulkWrite. Migrates all 8 runtime bulkWrite call sites: - agentCategory (seedCategories, ensureDefaultCategories) - conversation (bulkSaveConvos) - message (bulkSaveMessages) - file (batchUpdateFiles) - conversationTag (updateTagsForConversation, bulkIncrementTagCounts) - aclEntry (bulkWriteAclEntries) systemGrant.seedSystemGrants is intentionally not migrated — it uses explicit tenantId: { $exists: false } filters and is exempt from the isolation plugin. * feat: pre-auth tenant middleware and tenant-scoped config cache Adds preAuthTenantMiddleware that reads X-Tenant-Id from the request header and wraps downstream in tenantStorage ALS context. Wired onto /oauth, /api/auth, /api/config, and /api/share — unauthenticated routes that need tenant scoping before JWT auth runs. The /api/config cache key is now tenant-scoped (STARTUP_CONFIG:${tenantId}) so multi-tenant deployments serve the correct login page config per tenant. The middleware is intentionally minimal — no subdomain parsing, no OIDC claim extraction. The private fork's reverse proxy or auth gateway sets the header. * feat: accept optional tenantId in updateInterfacePermissions When tenantId is provided, the function re-enters inside tenantStorage.run({ tenantId }) so all downstream Mongoose queries target that tenant's roles instead of the system context. This lets the private fork's tenant provisioning flow call updateInterfacePermissions per-tenant after creating tenant-scoped ADMIN/USER roles. * fix: tenant-filter $lookup in getPromptGroup aggregation The $lookup stage in getPromptGroup() queried the prompts collection without tenant filtering. 
While the outer PromptGroup aggregate is protected by the tenantIsolation plugin's pre('aggregate') hook, $lookup runs as an internal MongoDB operation that bypasses Mongoose hooks entirely. Converts from simple field-based $lookup to pipeline-based $lookup with an explicit tenantId match when tenant context is active. * fix: replace field-level unique indexes with tenant-scoped compounds Field-level unique:true creates a globally-unique single-field index in MongoDB, which would cause insert failures across tenants sharing the same ID values. - agent.id: removed field-level unique, added { id, tenantId } compound - convo.conversationId: removed field-level unique (compound at line 50 already exists: { conversationId, user, tenantId }) - message.messageId: removed field-level unique (compound at line 165 already exists: { messageId, user, tenantId }) - preset.presetId: removed field-level unique, added { presetId, tenantId } compound * fix: scope MODELS_CONFIG, ENDPOINT_CONFIG, PLUGINS, TOOLS caches by tenant These caches store per-tenant configuration (available models, endpoint settings, plugin availability, tool definitions) but were using global cache keys. In multi-tenant mode, one tenant's cached config would be served to all tenants. Appends :${tenantId} to cache keys when tenant context is active. Falls back to the unscoped key when no tenant context exists (backward compatible for single-tenant OSS deployments). 
Covers all read, write, and delete sites: - ModelController.js: get/set MODELS_CONFIG - PluginController.js: get/set PLUGINS, get/set TOOLS - getEndpointsConfig.js: get/set/delete ENDPOINT_CONFIG - app.js: delete ENDPOINT_CONFIG (clearEndpointConfigCache) - mcp.js: delete TOOLS (updateMCPTools, mergeAppTools) - importers.js: get ENDPOINT_CONFIG * fix: add getTenantId to PluginController spec mock The data-schemas mock was missing getTenantId, causing all PluginController tests to throw when the controller calls getTenantId() for tenant-scoped cache keys. * fix: address review findings — migration, strict-mode, DRY, types Addresses all CRITICAL, MAJOR, and MINOR review findings: F1 (CRITICAL): Add agents, conversations, messages, presets to SUPERSEDED_INDEXES in tenantIndexes.ts so dropSupersededTenantIndexes() drops the old single-field unique indexes that block multi-tenant inserts. F2 (CRITICAL): Unknown bulkWrite op types now throw in strict mode instead of silently passing through without tenant injection. F3 (MAJOR): Replace wildcard export with named export for tenantSafeBulkWrite, hiding _resetBulkWriteStrictCache from the public package API. F5 (MAJOR): Restore AnyBulkWriteOperation[] typing on bulkWriteAclEntries — the unparameterized wrapper accepts parameterized ops as a subtype. F7 (MAJOR): Fix config.js tenant precedence — JWT-derived req.user.tenantId now takes priority over the X-Tenant-Id header for authenticated requests. F8 (MINOR): Extract scopedCacheKey() helper into tenantContext.ts and replace all 11 inline occurrences across 7 files. F9 (MINOR): Use simple localField/foreignField $lookup for the non-tenant getPromptGroup path (more efficient index seeks). F12 (NIT): Remove redundant BulkOp type alias. F13 (NIT): Remove debug log that leaked raw tenantId. * fix: add new superseded indexes to tenantIndexes test fixture The test creates old indexes to verify the migration drops them. 
Missing fixture entries for agents.id_1, conversations.conversationId_1, messages.messageId_1, and presets.presetId_1 caused the count assertion to fail (expected 22, got 18). * fix: restore logger.warn for unknown bulk op types in non-strict mode * fix: block SYSTEM_TENANT_ID sentinel from external header input CRITICAL: preAuthTenantMiddleware accepted any string as X-Tenant-Id, including '__SYSTEM__'. The tenantIsolation plugin treats SYSTEM_TENANT_ID as an explicit bypass — skipping ALL query filters. A client sending X-Tenant-Id: __SYSTEM__ to pre-auth routes (/api/share, /api/config, /api/auth, /oauth) would execute Mongoose operations without tenant isolation. Fixes: - preAuthTenantMiddleware rejects SYSTEM_TENANT_ID in header - scopedCacheKey returns the base key (not key:__SYSTEM__) in system context, preventing stale cache entries during runAsSystem() - updateInterfacePermissions guards tenantId against SYSTEM_TENANT_ID - $lookup pipeline separates $expr join from constant tenantId match for better index utilization - Regression test for sentinel rejection in preAuthTenant.spec.ts - Remove redundant getTenantId() call in config.js * test: add missing deleteMany/replaceOne coverage, fix vacuous ALS assertions bulkWrite spec: - deleteMany: verifies tenant-scoped deletion leaves other tenants untouched - replaceOne: verifies tenantId injected into both filter and replacement - replaceOne overwrite: verifies a conflicting tenantId in the replacement document is overwritten by the ALS tenant (defense-in-depth) - empty ops array: verifies graceful handling preAuthTenant spec: - All negative-case tests now use the capturedNext pattern to verify getTenantId() inside the middleware's execution context, not the test runner's outer frame (which was always undefined regardless) * feat: tenant-isolate MESSAGES cache, FLOWS cache, and GenerationJobManager MESSAGES cache (streamAudio.js): - Cache key now uses scopedCacheKey(messageId) to prefix with tenantId, 
preventing cross-tenant message content reads during TTS streaming. FLOWS cache (FlowStateManager): - getFlowKey() now generates ${type}:${tenantId}:${flowId} when tenant context is active, isolating OAuth flow state per tenant. GenerationJobManager: - tenantId added to SerializableJobData and GenerationJobMetadata - createJob() captures the current ALS tenant context (excluding SYSTEM_TENANT_ID) and stores it in job metadata - SSE subscription endpoint validates job.metadata.tenantId matches req.user.tenantId, blocking cross-tenant stream access - Both InMemoryJobStore and RedisJobStore updated to accept tenantId * fix: add getTenantId and SYSTEM_TENANT_ID to MCP OAuth test mocks FlowStateManager.getFlowKey() now calls getTenantId() for tenant-scoped flow keys. The 4 MCP OAuth test files mock @librechat/data-schemas without these exports, causing TypeError at runtime. * fix: correct import ordering per AGENTS.md conventions Package imports sorted shortest to longest line length, local imports sorted longest to shortest — fixes ordering violations introduced by our new imports across 8 files. * fix: deserialize tenantId in RedisJobStore — cross-tenant SSE guard was no-op in Redis mode serializeJob() writes tenantId to the Redis hash via Object.entries, but deserializeJob() manually enumerates fields and omitted tenantId. Every getJob() from Redis returned tenantId: undefined, causing the SSE route's cross-tenant guard to short-circuit (undefined && ... → false). 
* test: SSE tenant guard, FlowStateManager key consistency, ALS scope docs SSE stream tenant tests (streamTenant.spec.js): - Cross-tenant user accessing another tenant's stream → 403 - Same-tenant user accessing own stream → allowed - OSS mode (no tenantId on job) → tenant check skipped FlowStateManager tenant tests (manager.tenant.spec.ts): - completeFlow finds flow created under same tenant context - completeFlow does NOT find flow under different tenant context - Unscoped flows are separate from tenant-scoped flows Documentation: - JSDoc on getFlowKey documenting ALS context consistency requirement - Comment on streamAudio.js scopedCacheKey capture site * fix: SSE stream tests hang on success path, remove internal fork references The success-path tests entered the SSE streaming code which never closes, causing timeout. Mock subscribe() to end the response immediately. Restructured assertions to verify non-403/non-404. Removed "private fork" and "OSS" references from code and test descriptions — replaced with "deployment layer", "multi-tenant deployments", and "single-tenant mode". * fix: address review findings — test rigor, tenant ID validation, docs F1: SSE stream tests now mock subscribe() with correct signature (streamId, writeEvent, onDone, onError) and assert 200 status, verifying the tenant guard actually allows through same-tenant users. F2: completeFlow logs the attempted key and ALS tenantId when flow is not found, so reverse proxy misconfiguration (missing X-Tenant-Id on OAuth callback) produces an actionable warning. F3/F10: preAuthTenantMiddleware validates tenant ID format — rejects colons, special characters, and values exceeding 128 chars. Trims whitespace. Prevents cache key collisions via crafted headers. F4: Documented cache invalidation scope limitation in clearEndpointConfigCache — only the calling tenant's key is cleared; other tenants expire via TTL.
F7: getFlowKey JSDoc now lists all 8 methods requiring consistent ALS context. F8: Added dedicated scopedCacheKey unit tests — base key without context, base key in system context, scoped key with tenant, no ALS leakage across scope boundaries. * fix: revert flow key tenant scoping, fix SSE test timing FlowStateManager: Reverts tenant-scoped flow keys. OAuth callbacks arrive without tenant ALS context (provider redirects don't carry X-Tenant-Id), so completeFlow/failFlow would never find flows created under tenant context. Flow IDs are random UUIDs with no collision risk, and flow data is ephemeral (TTL-bounded). SSE tests: Use process.nextTick for onDone callback so Express response headers are flushed before res.write/res.end are called. * fix: restore getTenantId import for completeFlow diagnostic log * fix: correct completeFlow warning message, add missing flow test The warning referenced X-Tenant-Id header consistency which was only relevant when flow keys were tenant-scoped (since reverted). Updated to list actual causes: TTL expiry, missing flow, or routing to a different instance without shared Keyv storage. Removed the getTenantId() call and import — no longer needed since flow keys are unscoped. Added test for the !flowState branch in completeFlow — verifies return false and logger.warn on nonexistent flow ID. * fix: add explicit return type to recursive updateInterfacePermissions The recursive call (tenantId branch calls itself without tenantId) causes TypeScript to infer circular return type 'any'. Adding an explicit Promise return type satisfies the rollup typescript plugin. * fix: update MCPOAuthRaceCondition test to match new completeFlow warning * fix: clearEndpointConfigCache deletes both scoped and unscoped keys Unauthenticated /api/endpoints requests populate the unscoped ENDPOINT_CONFIG key. Admin config mutations clear only the tenant-scoped key, leaving the unscoped entry stale indefinitely. Now deletes both when in tenant context.
* fix: tenant guard on abort/status endpoints, warn logs, test coverage F1: Add tenant guard to /chat/status/:conversationId and /chat/abort matching the existing guard on /chat/stream/:streamId. The status endpoint exposes aggregatedContent (AI response text) which requires tenant-level access control. F2: preAuthTenantMiddleware now logs warn for rejected __SYSTEM__ sentinel and malformed tenant IDs, providing observability for bypass probing attempts. F3: Abort fallback path (getActiveJobIdsForUser) now has tenant check after resolving the job. F4: Test for strict mode + SYSTEM_TENANT_ID — verifies runAsSystem bypasses tenantSafeBulkWrite without throwing in strict mode. F5: Test for job with tenantId + user without tenantId → 403. F10: Regex uses idiomatic hyphen-at-start form. F11: Test descriptions changed from "rejects" to "ignores" since middleware calls next() (not 4xx). Also fixes MCPOAuthRaceCondition test assertion to match updated completeFlow warning message. * fix: test coverage for logger.warn, status/abort guards, consistency A: preAuthTenant spec now mocks logger and asserts warn calls for __SYSTEM__ sentinel, malformed characters, and oversized headers. B: streamTenant spec expanded with status and abort endpoint tests — cross-tenant status returns 403, same-tenant returns 200 with body, cross-tenant abort returns 403. C: Abort endpoint uses req.user.tenantId (not req.user?.tenantId) matching stream/status pattern — requireJwtAuth guarantees req.user. D: Malformed header warning now includes ip in log metadata, matching the sentinel warning for consistent SOC correlation.
* fix: assert ip field in malformed header warn tests * fix: parallelize cache deletes, document tenant guard, fix import order - clearEndpointConfigCache uses Promise.all for independent cache deletes instead of sequential awaits - SSE stream tenant guard has inline comment explaining backward-compat behavior for untenanted legacy jobs - conversation.ts local imports reordered longest-to-shortest per AGENTS.md * fix: tenant-qualify userJobs keys, document tenant guard backward-compat Job store userJobs keys now include tenantId when available: - Redis: stream:user:{tenantId:userId}:jobs (falls back to stream:user:{userId}:jobs when no tenant) - InMemory: composite key tenantId:userId in userJobMap getActiveJobIdsByUser/getActiveJobIdsForUser accept optional tenantId parameter, threaded through from req.user.tenantId at all call sites (/chat/active and /chat/abort fallback). Added inline comments on all three SSE tenant guards explaining the backward-compat design: untenanted legacy jobs remain accessible when the userId check passes. * fix: tenant-qualify empty-set cleanup, document cross-tenant cache staleness Fix InMemoryJobStore.getActiveJobIdsByUser empty-set cleanup to use the tenant-qualified userKey instead of bare userId — prevents orphaned empty Sets accumulating in userJobMap for multi-tenant users. Document cross-tenant staleness in clearEndpointConfigCache JSDoc — other tenants' scoped keys expire via TTL, not active invalidation. * fix: cleanup userJobMap leak, startup warning, DRY tenant guard, docs F1: InMemoryJobStore.cleanup() now removes entries from userJobMap before calling deleteJob, preventing orphaned empty Sets from accumulating with tenant-qualified composite keys. F2: Startup warning when TENANT_ISOLATION_STRICT is active — reminds operators to configure reverse proxy to control X-Tenant-Id header. F3: mergeAppTools JSDoc documents that tenant-scoped TOOLS keys are not actively invalidated (matching clearEndpointConfigCache pattern).
F5: Abort handler getActiveJobIdsForUser call uses req.user.tenantId (not req.user?.tenantId) โ€” consistent with stream/status handlers. F6: updateInterfacePermissions JSDoc clarifies SYSTEM_TENANT_ID behavior โ€” falls through to caller's ALS context. F7: Extracted hasTenantMismatch() helper, replacing three identical inline tenant guard blocks across stream/status/abort endpoints. F9: scopedCacheKey JSDoc documents both passthrough cases (no context and SYSTEM_TENANT_ID context). * fix: clean userJobMap in evictOldest โ€” same leak as cleanup() --- api/server/controllers/ModelController.js | 10 +- api/server/controllers/PluginController.js | 12 +- .../controllers/PluginController.spec.js | 1 + api/server/index.js | 20 +- api/server/middleware/optionalJwtAuth.js | 6 +- .../agents/__tests__/streamTenant.spec.js | 186 +++++++++ api/server/routes/agents/index.js | 27 +- api/server/routes/config.js | 12 +- api/server/services/Config/app.js | 18 +- .../services/Config/getEndpointsConfig.js | 8 +- api/server/services/Config/mcp.js | 11 +- .../services/Files/Audio/streamAudio.js | 7 +- api/server/utils/import/importers.js | 4 +- packages/api/src/app/permissions.ts | 17 +- packages/api/src/flow/manager.tenant.spec.ts | 49 +++ packages/api/src/flow/manager.ts | 10 +- .../__tests__/MCPOAuthCSRFFallback.test.ts | 2 + .../src/mcp/__tests__/MCPOAuthFlow.test.ts | 2 + .../__tests__/MCPOAuthRaceCondition.test.ts | 4 +- .../mcp/__tests__/MCPOAuthTokenExpiry.test.ts | 2 + packages/api/src/middleware/index.ts | 1 + .../api/src/middleware/preAuthTenant.spec.ts | 129 ++++++ packages/api/src/middleware/preAuthTenant.ts | 72 ++++ .../api/src/stream/GenerationJobManager.ts | 11 +- .../implementations/InMemoryJobStore.ts | 38 +- .../stream/implementations/RedisJobStore.ts | 14 +- .../api/src/stream/interfaces/IJobStore.ts | 4 +- packages/api/src/types/stream.ts | 1 + .../src/config/tenantContext.spec.ts | 26 ++ .../data-schemas/src/config/tenantContext.ts | 13 + 
packages/data-schemas/src/index.ts | 8 +- packages/data-schemas/src/methods/aclEntry.ts | 7 +- .../data-schemas/src/methods/agentCategory.ts | 5 +- .../data-schemas/src/methods/conversation.ts | 5 +- .../src/methods/conversationTag.ts | 5 +- packages/data-schemas/src/methods/file.ts | 3 +- packages/data-schemas/src/methods/message.ts | 3 +- packages/data-schemas/src/methods/prompt.ts | 40 +- .../src/migrations/tenantIndexes.spec.ts | 14 + .../src/migrations/tenantIndexes.ts | 6 +- packages/data-schemas/src/schema/agent.ts | 3 +- packages/data-schemas/src/schema/convo.ts | 1 - packages/data-schemas/src/schema/message.ts | 1 - packages/data-schemas/src/schema/preset.ts | 3 +- packages/data-schemas/src/utils/index.ts | 1 + .../src/utils/tenantBulkWrite.spec.ts | 376 ++++++++++++++++++ .../data-schemas/src/utils/tenantBulkWrite.ts | 109 +++++ 47 files changed, 1224 insertions(+), 83 deletions(-) create mode 100644 api/server/routes/agents/__tests__/streamTenant.spec.js create mode 100644 packages/api/src/flow/manager.tenant.spec.ts create mode 100644 packages/api/src/middleware/preAuthTenant.spec.ts create mode 100644 packages/api/src/middleware/preAuthTenant.ts create mode 100644 packages/data-schemas/src/config/tenantContext.spec.ts create mode 100644 packages/data-schemas/src/utils/tenantBulkWrite.spec.ts create mode 100644 packages/data-schemas/src/utils/tenantBulkWrite.ts diff --git a/api/server/controllers/ModelController.js b/api/server/controllers/ModelController.js index 805d9eef27..1741d3f6b1 100644 --- a/api/server/controllers/ModelController.js +++ b/api/server/controllers/ModelController.js @@ -1,5 +1,5 @@ -const { logger } = require('@librechat/data-schemas'); const { CacheKeys } = require('librechat-data-provider'); +const { logger, scopedCacheKey } = require('@librechat/data-schemas'); const { loadDefaultModels, loadConfigModels } = require('~/server/services/Config'); const { getLogStores } = require('~/cache'); @@ -9,7 +9,8 @@ const { getLogStores } = 
require('~/cache'); */ const getModelsConfig = async (req) => { const cache = getLogStores(CacheKeys.CONFIG_STORE); - let modelsConfig = await cache.get(CacheKeys.MODELS_CONFIG); + const cacheKey = scopedCacheKey(CacheKeys.MODELS_CONFIG); + let modelsConfig = await cache.get(cacheKey); if (!modelsConfig) { modelsConfig = await loadModels(req); } @@ -24,7 +25,8 @@ const getModelsConfig = async (req) => { */ async function loadModels(req) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const cachedModelsConfig = await cache.get(CacheKeys.MODELS_CONFIG); + const cacheKey = scopedCacheKey(CacheKeys.MODELS_CONFIG); + const cachedModelsConfig = await cache.get(cacheKey); if (cachedModelsConfig) { return cachedModelsConfig; } @@ -33,7 +35,7 @@ async function loadModels(req) { const modelConfig = { ...defaultModelsConfig, ...customModelsConfig }; - await cache.set(CacheKeys.MODELS_CONFIG, modelConfig); + await cache.set(cacheKey, modelConfig); return modelConfig; } diff --git a/api/server/controllers/PluginController.js b/api/server/controllers/PluginController.js index 14dd284c30..7c47fe4d57 100644 --- a/api/server/controllers/PluginController.js +++ b/api/server/controllers/PluginController.js @@ -1,5 +1,5 @@ -const { logger } = require('@librechat/data-schemas'); const { CacheKeys } = require('librechat-data-provider'); +const { logger, scopedCacheKey } = require('@librechat/data-schemas'); const { getToolkitKey, checkPluginAuth, filterUniquePlugins } = require('@librechat/api'); const { getCachedTools, setCachedTools } = require('~/server/services/Config'); const { availableTools, toolkits } = require('~/app/clients/tools'); @@ -9,7 +9,8 @@ const { getLogStores } = require('~/cache'); const getAvailablePluginsController = async (req, res) => { try { const cache = getLogStores(CacheKeys.TOOL_CACHE); - const cachedPlugins = await cache.get(CacheKeys.PLUGINS); + const pluginsCacheKey = scopedCacheKey(CacheKeys.PLUGINS); + const cachedPlugins = await 
cache.get(pluginsCacheKey); if (cachedPlugins) { res.status(200).json(cachedPlugins); return; @@ -37,7 +38,7 @@ const getAvailablePluginsController = async (req, res) => { plugins = plugins.filter((plugin) => !filteredTools.includes(plugin.pluginKey)); } - await cache.set(CacheKeys.PLUGINS, plugins); + await cache.set(pluginsCacheKey, plugins); res.status(200).json(plugins); } catch (error) { res.status(500).json({ message: error.message }); @@ -64,7 +65,8 @@ const getAvailableTools = async (req, res) => { return res.status(401).json({ message: 'Unauthorized' }); } const cache = getLogStores(CacheKeys.TOOL_CACHE); - const cachedToolsArray = await cache.get(CacheKeys.TOOLS); + const toolsCacheKey = scopedCacheKey(CacheKeys.TOOLS); + const cachedToolsArray = await cache.get(toolsCacheKey); const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId })); @@ -115,7 +117,7 @@ const getAvailableTools = async (req, res) => { } const finalTools = filterUniquePlugins(toolsOutput); - await cache.set(CacheKeys.TOOLS, finalTools); + await cache.set(toolsCacheKey, finalTools); res.status(200).json(finalTools); } catch (error) { diff --git a/api/server/controllers/PluginController.spec.js b/api/server/controllers/PluginController.spec.js index 06a51a3bd6..fdbc2401ce 100644 --- a/api/server/controllers/PluginController.spec.js +++ b/api/server/controllers/PluginController.spec.js @@ -8,6 +8,7 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), warn: jest.fn(), }, + scopedCacheKey: jest.fn((key) => key), })); jest.mock('~/server/services/Config', () => ({ diff --git a/api/server/index.js b/api/server/index.js index 813b453468..4b919b1ceb 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -21,6 +21,7 @@ const { createStreamServices, initializeFileStorage, updateInterfacePermissions, + preAuthTenantMiddleware, } = require('@librechat/api'); const { connectDb, indexSync } = require('~/db'); const 
initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager'); @@ -59,7 +60,14 @@ const startServer = async () => { app.disable('x-powered-by'); app.set('trust proxy', trusted_proxy); - await seedDatabase(); + if (isEnabled(process.env.TENANT_ISOLATION_STRICT)) { + logger.warn( + '[Security] TENANT_ISOLATION_STRICT is active. Ensure your reverse proxy strips or sets ' + + 'the X-Tenant-Id header โ€” untrusted clients must not be able to set it directly.', + ); + } + + await runAsSystem(seedDatabase); const appConfig = await getAppConfig({ baseOnly: true }); initializeFileStorage(appConfig); await runAsSystem(async () => { @@ -139,9 +147,11 @@ const startServer = async () => { /* Per-request capability cache โ€” must be registered before any route that calls hasCapability */ app.use(capabilityContextMiddleware); - app.use('/oauth', routes.oauth); + /* Pre-auth tenant context for unauthenticated routes that need tenant scoping. + * The reverse proxy / auth gateway sets `X-Tenant-Id` header for multi-tenant deployments. 
*/ + app.use('/oauth', preAuthTenantMiddleware, routes.oauth); /* API Endpoints */ - app.use('/api/auth', routes.auth); + app.use('/api/auth', preAuthTenantMiddleware, routes.auth); app.use('/api/admin', routes.adminAuth); app.use('/api/admin/config', routes.adminConfig); app.use('/api/admin/groups', routes.adminGroups); @@ -159,11 +169,11 @@ const startServer = async () => { app.use('/api/endpoints', routes.endpoints); app.use('/api/balance', routes.balance); app.use('/api/models', routes.models); - app.use('/api/config', routes.config); + app.use('/api/config', preAuthTenantMiddleware, routes.config); app.use('/api/assistants', routes.assistants); app.use('/api/files', await routes.files.initialize()); app.use('/images/', createValidateImageRequest(appConfig.secureImageLinks), routes.staticRoute); - app.use('/api/share', routes.share); + app.use('/api/share', preAuthTenantMiddleware, routes.share); app.use('/api/roles', routes.roles); app.use('/api/agents', routes.agents); app.use('/api/banner', routes.banner); diff --git a/api/server/middleware/optionalJwtAuth.js b/api/server/middleware/optionalJwtAuth.js index 2f59fdda4a..d46478d36e 100644 --- a/api/server/middleware/optionalJwtAuth.js +++ b/api/server/middleware/optionalJwtAuth.js @@ -1,9 +1,10 @@ const cookies = require('cookie'); const passport = require('passport'); -const { isEnabled } = require('@librechat/api'); +const { isEnabled, tenantContextMiddleware } = require('@librechat/api'); // This middleware does not require authentication, -// but if the user is authenticated, it will set the user object. +// but if the user is authenticated, it will set the user object +// and establish tenant ALS context. const optionalJwtAuth = (req, res, next) => { const cookieHeader = req.headers.cookie; const tokenProvider = cookieHeader ? 
cookies.parse(cookieHeader).token_provider : null; @@ -13,6 +14,7 @@ const optionalJwtAuth = (req, res, next) => { } if (user) { req.user = user; + return tenantContextMiddleware(req, res, next); } next(); }; diff --git a/api/server/routes/agents/__tests__/streamTenant.spec.js b/api/server/routes/agents/__tests__/streamTenant.spec.js new file mode 100644 index 0000000000..1f89953186 --- /dev/null +++ b/api/server/routes/agents/__tests__/streamTenant.spec.js @@ -0,0 +1,186 @@ +const express = require('express'); +const request = require('supertest'); + +const mockGenerationJobManager = { + getJob: jest.fn(), + subscribe: jest.fn(), + getResumeState: jest.fn(), + abortJob: jest.fn(), + getActiveJobIdsForUser: jest.fn().mockResolvedValue([]), +}; + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { + debug: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + info: jest.fn(), + }, +})); + +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + isEnabled: jest.fn().mockReturnValue(false), + GenerationJobManager: mockGenerationJobManager, +})); + +jest.mock('~/models', () => ({ + saveMessage: jest.fn(), +})); + +let mockUserId = 'user-123'; +let mockTenantId; + +jest.mock('~/server/middleware', () => ({ + uaParser: (req, res, next) => next(), + checkBan: (req, res, next) => next(), + requireJwtAuth: (req, res, next) => { + req.user = { id: mockUserId, tenantId: mockTenantId }; + next(); + }, + messageIpLimiter: (req, res, next) => next(), + configMiddleware: (req, res, next) => next(), + messageUserLimiter: (req, res, next) => next(), +})); + +jest.mock('~/server/routes/agents/chat', () => require('express').Router()); +jest.mock('~/server/routes/agents/v1', () => ({ + v1: require('express').Router(), +})); +jest.mock('~/server/routes/agents/openai', () => require('express').Router()); +jest.mock('~/server/routes/agents/responses', () => require('express').Router()); + +const 
agentsRouter = require('../index'); +const app = express(); +app.use(express.json()); +app.use('/agents', agentsRouter); + +function mockSubscribeSuccess() { + mockGenerationJobManager.subscribe.mockImplementation((_streamId, _writeEvent, onDone) => { + process.nextTick(() => onDone({ done: true })); + return { unsubscribe: jest.fn() }; + }); +} + +describe('SSE stream tenant isolation', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockUserId = 'user-123'; + mockTenantId = undefined; + }); + + describe('GET /chat/stream/:streamId', () => { + it('returns 403 when a user from a different tenant accesses a stream', async () => { + mockUserId = 'user-456'; + mockTenantId = 'tenant-b'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-456', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(403); + expect(res.body.error).toBe('Unauthorized'); + }); + + it('returns 404 when stream does not exist', async () => { + mockGenerationJobManager.getJob.mockResolvedValue(null); + + const res = await request(app).get('/agents/chat/stream/nonexistent'); + expect(res.status).toBe(404); + }); + + it('proceeds past tenant guard when tenant matches', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-a'; + mockSubscribeSuccess(); + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(200); + expect(mockGenerationJobManager.subscribe).toHaveBeenCalledTimes(1); + }); + + it('proceeds past tenant guard when job has no tenantId (single-tenant mode)', async () => { + mockUserId = 'user-123'; + mockTenantId = undefined; + mockSubscribeSuccess(); + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123' }, + status: 
'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(200); + expect(mockGenerationJobManager.subscribe).toHaveBeenCalledTimes(1); + }); + + it('returns 403 when job has tenantId but user has no tenantId', async () => { + mockUserId = 'user-123'; + mockTenantId = undefined; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'some-tenant' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(403); + }); + }); + + describe('GET /chat/status/:conversationId', () => { + it('returns 403 when tenant does not match', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-b'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/status/conv-123'); + expect(res.status).toBe(403); + expect(res.body.error).toBe('Unauthorized'); + }); + + it('returns status when tenant matches', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-a'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + createdAt: Date.now(), + }); + mockGenerationJobManager.getResumeState.mockResolvedValue(null); + + const res = await request(app).get('/agents/chat/status/conv-123'); + expect(res.status).toBe(200); + expect(res.body.active).toBe(true); + }); + }); + + describe('POST /chat/abort', () => { + it('returns 403 when tenant does not match', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-b'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).post('/agents/chat/abort').send({ streamId: 'stream-123' }); + 
expect(res.status).toBe(403); + expect(res.body.error).toBe('Unauthorized'); + }); + }); +}); diff --git a/api/server/routes/agents/index.js b/api/server/routes/agents/index.js index 86966a3f3e..eb42046bed 100644 --- a/api/server/routes/agents/index.js +++ b/api/server/routes/agents/index.js @@ -17,6 +17,11 @@ const chat = require('./chat'); const { LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {}; +/** Untenanted jobs (pre-multi-tenancy) remain accessible if the userId check passes. */ +function hasTenantMismatch(job, user) { + return job.metadata?.tenantId != null && job.metadata.tenantId !== user.tenantId; +} + const router = express.Router(); /** @@ -67,6 +72,10 @@ router.get('/chat/stream/:streamId', async (req, res) => { return res.status(403).json({ error: 'Unauthorized' }); } + if (hasTenantMismatch(job, req.user)) { + return res.status(403).json({ error: 'Unauthorized' }); + } + res.setHeader('Content-Encoding', 'identity'); res.setHeader('Content-Type', 'text/event-stream'); res.setHeader('Cache-Control', 'no-cache, no-transform'); @@ -150,7 +159,10 @@ router.get('/chat/stream/:streamId', async (req, res) => { * @returns { activeJobIds: string[] } */ router.get('/chat/active', async (req, res) => { - const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(req.user.id); + const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser( + req.user.id, + req.user.tenantId, + ); res.json({ activeJobIds }); }); @@ -174,6 +186,10 @@ router.get('/chat/status/:conversationId', async (req, res) => { return res.status(403).json({ error: 'Unauthorized' }); } + if (hasTenantMismatch(job, req.user)) { + return res.status(403).json({ error: 'Unauthorized' }); + } + // Get resume state which contains aggregatedContent // Avoid calling both getStreamInfo and getResumeState (both fetch content) const resumeState = await GenerationJobManager.getResumeState(conversationId); @@ -213,7 +229,10 @@ router.post('/chat/abort', async (req, res) => { 
// This handles the case where frontend sends "new" but job was created with a UUID if (!job && userId) { logger.debug(`[AgentStream] Job not found by ID, checking active jobs for user: ${userId}`); - const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(userId); + const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser( + userId, + req.user.tenantId, + ); if (activeJobIds.length > 0) { // Abort the most recent active job for this user jobStreamId = activeJobIds[0]; @@ -230,6 +249,10 @@ router.post('/chat/abort', async (req, res) => { return res.status(403).json({ error: 'Unauthorized' }); } + if (hasTenantMismatch(job, req.user)) { + return res.status(403).json({ error: 'Unauthorized' }); + } + logger.debug(`[AgentStream] Job found, aborting: ${jobStreamId}`); const abortResult = await GenerationJobManager.abortJob(jobStreamId); logger.debug(`[AgentStream] Job aborted successfully: ${jobStreamId}`, { diff --git a/api/server/routes/config.js b/api/server/routes/config.js index 0a68ccba4f..8caa180854 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -1,7 +1,7 @@ const express = require('express'); -const { logger } = require('@librechat/data-schemas'); const { isEnabled, getBalanceConfig } = require('@librechat/api'); const { CacheKeys, defaultSocialLogins } = require('librechat-data-provider'); +const { logger, getTenantId, scopedCacheKey } = require('@librechat/data-schemas'); const { getLdapConfig } = require('~/server/services/Config/ldap'); const { getAppConfig } = require('~/server/services/Config/app'); const { getLogStores } = require('~/cache'); @@ -23,7 +23,8 @@ const openidReuseTokens = isEnabled(process.env.OPENID_REUSE_TOKENS); router.get('/', async function (req, res) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG); + const cacheKey = scopedCacheKey(CacheKeys.STARTUP_CONFIG); + const cachedStartupConfig = await 
cache.get(cacheKey); if (cachedStartupConfig) { res.send(cachedStartupConfig); return; @@ -37,7 +38,10 @@ router.get('/', async function (req, res) { const ldap = getLdapConfig(); try { - const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }); + const appConfig = await getAppConfig({ + role: req.user?.role, + tenantId: req.user?.tenantId || getTenantId(), + }); const isOpenIdEnabled = !!process.env.OPENID_CLIENT_ID && @@ -141,7 +145,7 @@ router.get('/', async function (req, res) { payload.customFooter = process.env.CUSTOM_FOOTER; } - await cache.set(CacheKeys.STARTUP_CONFIG, payload); + await cache.set(cacheKey, payload); return res.status(200).send(payload); } catch (err) { logger.error('Error in startup config', err); diff --git a/api/server/services/Config/app.js b/api/server/services/Config/app.js index 7530ca1031..3256732ec2 100644 --- a/api/server/services/Config/app.js +++ b/api/server/services/Config/app.js @@ -1,5 +1,5 @@ const { CacheKeys } = require('librechat-data-provider'); -const { AppService, logger } = require('@librechat/data-schemas'); +const { AppService, logger, scopedCacheKey } = require('@librechat/data-schemas'); const { createAppConfigService, clearMcpConfigCache } = require('@librechat/api'); const { setCachedTools, invalidateCachedTools } = require('./getCachedTools'); const { loadAndFormatTools } = require('~/server/services/start/tools'); @@ -29,11 +29,23 @@ const { getAppConfig, clearAppConfigCache, clearOverrideCache } = createAppConfi getUserPrincipals: db.getUserPrincipals, }); -/** Deletes the ENDPOINT_CONFIG entry from CONFIG_STORE. Failures are non-critical and swallowed. */ +/** + * Deletes ENDPOINT_CONFIG entries from CONFIG_STORE. + * Clears both the tenant-scoped key (if in tenant context) and the + * unscoped base key (populated by unauthenticated /api/endpoints calls). + * Other tenants' scoped keys are NOT actively cleared โ€” they expire + * via TTL. 
Config mutations in one tenant do not propagate immediately + * to other tenants' endpoint config caches. + */ async function clearEndpointConfigCache() { try { const configStore = getLogStores(CacheKeys.CONFIG_STORE); - await configStore.delete(CacheKeys.ENDPOINT_CONFIG); + const scoped = scopedCacheKey(CacheKeys.ENDPOINT_CONFIG); + const keys = [scoped]; + if (scoped !== CacheKeys.ENDPOINT_CONFIG) { + keys.push(CacheKeys.ENDPOINT_CONFIG); + } + await Promise.all(keys.map((k) => configStore.delete(k))); } catch { // CONFIG_STORE or ENDPOINT_CONFIG may not exist โ€” not critical } diff --git a/api/server/services/Config/getEndpointsConfig.js b/api/server/services/Config/getEndpointsConfig.js index 476d3d7c80..cd0230ad4a 100644 --- a/api/server/services/Config/getEndpointsConfig.js +++ b/api/server/services/Config/getEndpointsConfig.js @@ -1,3 +1,4 @@ +const { scopedCacheKey } = require('@librechat/data-schemas'); const { loadCustomEndpointsConfig } = require('@librechat/api'); const { CacheKeys, @@ -17,10 +18,11 @@ const { getAppConfig } = require('./app'); */ async function getEndpointsConfig(req) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG); + const cacheKey = scopedCacheKey(CacheKeys.ENDPOINT_CONFIG); + const cachedEndpointsConfig = await cache.get(cacheKey); if (cachedEndpointsConfig) { if (cachedEndpointsConfig.gptPlugins) { - await cache.delete(CacheKeys.ENDPOINT_CONFIG); + await cache.delete(cacheKey); } else { return cachedEndpointsConfig; } @@ -112,7 +114,7 @@ async function getEndpointsConfig(req) { const endpointsConfig = orderEndpointsConfig(mergedConfig); - await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig); + await cache.set(cacheKey, endpointsConfig); return endpointsConfig; } diff --git a/api/server/services/Config/mcp.js b/api/server/services/Config/mcp.js index cc4e98b59e..869c9e66da 100644 --- a/api/server/services/Config/mcp.js +++ 
b/api/server/services/Config/mcp.js @@ -1,5 +1,5 @@ -const { logger } = require('@librechat/data-schemas'); const { CacheKeys, Constants } = require('librechat-data-provider'); +const { logger, scopedCacheKey } = require('@librechat/data-schemas'); const { getCachedTools, setCachedTools } = require('./getCachedTools'); const { getLogStores } = require('~/cache'); @@ -36,7 +36,7 @@ async function updateMCPServerTools({ userId, serverName, tools }) { await setCachedTools(serverTools, { userId, serverName }); const cache = getLogStores(CacheKeys.TOOL_CACHE); - await cache.delete(CacheKeys.TOOLS); + await cache.delete(scopedCacheKey(CacheKeys.TOOLS)); logger.debug( `[MCP Cache] Updated ${tools.length} tools for server ${serverName} (user: ${userId})`, ); @@ -48,7 +48,10 @@ async function updateMCPServerTools({ userId, serverName, tools }) { } /** - * Merges app-level tools with global tools + * Merges app-level tools with global tools. + * Only the current ALS-scoped key (base key in system/startup context) is cleared. + * Tenant-scoped TOOLS:tenantId keys are NOT actively invalidated โ€” they expire + * via TTL on the next tenant request. This matches clearEndpointConfigCache behavior. 
* @param {import('@librechat/api').LCAvailableTools} appTools * @returns {Promise} */ @@ -62,7 +65,7 @@ async function mergeAppTools(appTools) { const mergedTools = { ...cachedTools, ...appTools }; await setCachedTools(mergedTools); const cache = getLogStores(CacheKeys.TOOL_CACHE); - await cache.delete(CacheKeys.TOOLS); + await cache.delete(scopedCacheKey(CacheKeys.TOOLS)); logger.debug(`Merged ${count} app-level tools`); } catch (error) { logger.error('Failed to merge app-level tools:', error); diff --git a/api/server/services/Files/Audio/streamAudio.js b/api/server/services/Files/Audio/streamAudio.js index c28a96edff..7120399b5e 100644 --- a/api/server/services/Files/Audio/streamAudio.js +++ b/api/server/services/Files/Audio/streamAudio.js @@ -1,3 +1,4 @@ +const { scopedCacheKey } = require('@librechat/data-schemas'); const { Time, CacheKeys, @@ -67,6 +68,8 @@ function createChunkProcessor(user, messageId) { } const messageCache = getLogStores(CacheKeys.MESSAGES); + // Captured at creation time โ€” must be called within an active request ALS scope + const cacheKey = scopedCacheKey(messageId); /** * @returns {Promise<{ text: string, isFinished: boolean }[] | string>} @@ -81,7 +84,7 @@ function createChunkProcessor(user, messageId) { } /** @type { string | { text: string; complete: boolean } } */ - let message = await messageCache.get(messageId); + let message = await messageCache.get(cacheKey); if (!message) { message = await getMessage({ user, messageId }); } @@ -92,7 +95,7 @@ function createChunkProcessor(user, messageId) { } else { const text = message.content?.length > 0 ? 
parseTextParts(message.content) : message.text; messageCache.set( - messageId, + cacheKey, { text, complete: true, diff --git a/api/server/utils/import/importers.js b/api/server/utils/import/importers.js index 39734c181c..f8b3be4dab 100644 --- a/api/server/utils/import/importers.js +++ b/api/server/utils/import/importers.js @@ -1,5 +1,5 @@ const { v4: uuidv4 } = require('uuid'); -const { logger } = require('@librechat/data-schemas'); +const { logger, scopedCacheKey } = require('@librechat/data-schemas'); const { EModelEndpoint, Constants, openAISettings, CacheKeys } = require('librechat-data-provider'); const { createImportBatchBuilder } = require('./importBatchBuilder'); const { cloneMessagesWithTimestamps } = require('./fork'); @@ -203,7 +203,7 @@ async function importLibreChatConvo( /* Endpoint configuration */ let endpoint = jsonData.endpoint ?? options.endpoint ?? EModelEndpoint.openAI; const cache = getLogStores(CacheKeys.CONFIG_STORE); - const endpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG); + const endpointsConfig = await cache.get(scopedCacheKey(CacheKeys.ENDPOINT_CONFIG)); const endpointConfig = endpointsConfig?.[endpoint]; if (!endpointConfig && endpointsConfig) { endpoint = Object.keys(endpointsConfig)[0]; diff --git a/packages/api/src/app/permissions.ts b/packages/api/src/app/permissions.ts index 5a557adfcf..92da1342ce 100644 --- a/packages/api/src/app/permissions.ts +++ b/packages/api/src/app/permissions.ts @@ -1,4 +1,4 @@ -import { logger } from '@librechat/data-schemas'; +import { logger, tenantStorage, SYSTEM_TENANT_ID } from '@librechat/data-schemas'; import { SystemRoles, Permissions, @@ -54,6 +54,7 @@ export async function updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions, + tenantId, }: { appConfig: AppConfig; getRoleByName: (roleName: string, fieldsToSelect?: string | string[]) => Promise; @@ -63,7 +64,19 @@ export async function updateInterfacePermissions({ roleData?: IRole | null, ) => Promise; -}) { 
+ /** + * Optional tenant ID for scoping role updates to a specific tenant. + * When provided (and not SYSTEM_TENANT_ID), runs inside `tenantStorage.run({ tenantId })`. + * When omitted or SYSTEM_TENANT_ID, uses the caller's existing ALS context. + */ + tenantId?: string; +}): Promise { + if (tenantId && tenantId !== SYSTEM_TENANT_ID) { + return tenantStorage.run({ tenantId }, async () => + updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions }), + ); + } + const loadedInterface = appConfig?.interfaceConfig; if (!loadedInterface) { return; diff --git a/packages/api/src/flow/manager.tenant.spec.ts b/packages/api/src/flow/manager.tenant.spec.ts new file mode 100644 index 0000000000..14b780c34b --- /dev/null +++ b/packages/api/src/flow/manager.tenant.spec.ts @@ -0,0 +1,49 @@ +import { Keyv } from 'keyv'; +import { logger, tenantStorage } from '@librechat/data-schemas'; +import { FlowStateManager } from './manager'; + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, +})); + +describe('FlowStateManager flow keys are not tenant-scoped', () => { + let manager: FlowStateManager; + + beforeEach(() => { + jest.clearAllMocks(); + const store = new Keyv({ store: new Map() }); + manager = new FlowStateManager(store, { ci: true, ttl: 60_000 }); + }); + + it('completeFlow finds a flow regardless of tenant context (OAuth callback compatibility)', async () => { + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await manager.initFlow('flow-1', 'oauth', {}); + }); + + const found = await manager.completeFlow('flow-1', 'oauth', { token: 'abc' }); + expect(found).toBe(true); + }); + + it('completeFlow works when both creation and completion have the same tenant', async () => { + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await manager.initFlow('flow-2', 'oauth', {}); + const 
found = await manager.completeFlow('flow-2', 'oauth', { token: 'abc' }); + expect(found).toBe(true); + }); + }); + + it('completeFlow returns false and logs when flow does not exist', async () => { + const found = await manager.completeFlow('ghost-flow', 'oauth', { token: 'x' }); + expect(found).toBe(false); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('ghost-flow'), + expect.objectContaining({ flowId: 'ghost-flow', type: 'oauth' }), + ); + }); +}); diff --git a/packages/api/src/flow/manager.ts b/packages/api/src/flow/manager.ts index b68b9edb7a..544cba9560 100644 --- a/packages/api/src/flow/manager.ts +++ b/packages/api/src/flow/manager.ts @@ -53,6 +53,12 @@ export class FlowStateManager { process.on('SIGHUP', cleanup); } + /** + * Flow keys are intentionally NOT tenant-scoped. OAuth callbacks arrive + * without tenant ALS context (the provider redirect doesn't carry + * X-Tenant-Id). Flow IDs are random UUIDs with no collision risk, and + * flow data is ephemeral (TTL-bounded, no sensitive user content). + */ private getFlowKey(flowId: string, type: string): string { return `${type}:${flowId}`; } @@ -253,7 +259,9 @@ export class FlowStateManager { if (!flowState) { logger.warn( - '[FlowStateManager] Flow state not found during completion โ€” cannot recover metadata, skipping', + `[FlowStateManager] completeFlow: flow not found โ€” key=${flowKey}. 
` + + 'Possible causes: flow TTL expired before callback arrived, flow was never created, or ' + + 'the callback is routing to a different instance without shared Keyv storage.', { flowId, type }, ); return false; diff --git a/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts index cdba06cf8d..c0a861817c 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts @@ -34,6 +34,8 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), debug: jest.fn(), }, + getTenantId: jest.fn(), + SYSTEM_TENANT_ID: '__SYSTEM__', encryptV2: jest.fn(async (val: string) => `enc:${val}`), decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts index f73a5ed3e8..7e26165cad 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts @@ -20,6 +20,8 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), debug: jest.fn(), }, + getTenantId: jest.fn(), + SYSTEM_TENANT_ID: '__SYSTEM__', encryptV2: jest.fn(async (val: string) => `enc:${val}`), decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts index cb6187ab45..d5fb1d67f7 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts @@ -23,6 +23,8 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), debug: jest.fn(), }, + getTenantId: jest.fn(), + SYSTEM_TENANT_ID: '__SYSTEM__', encryptV2: jest.fn(async (val: string) => `enc:${val}`), decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); @@ -258,7 +260,7 @@ 
describe('MCP OAuth Race Condition Fixes', () => { expect(stateAfterComplete).toBeUndefined(); expect(mockLogger.warn).toHaveBeenCalledWith( - expect.stringContaining('cannot recover metadata'), + expect.stringContaining('flow not found'), expect.any(Object), ); }); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts index 986ac4c8b4..b5cbc869a8 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts @@ -26,6 +26,8 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), debug: jest.fn(), }, + getTenantId: jest.fn(), + SYSTEM_TENANT_ID: '__SYSTEM__', encryptV2: jest.fn(async (val: string) => `enc:${val}`), decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); diff --git a/packages/api/src/middleware/index.ts b/packages/api/src/middleware/index.ts index 7d9dee2f8a..b91fee2999 100644 --- a/packages/api/src/middleware/index.ts +++ b/packages/api/src/middleware/index.ts @@ -6,5 +6,6 @@ export * from './balance'; export * from './json'; export * from './capabilities'; export { tenantContextMiddleware } from './tenant'; +export { preAuthTenantMiddleware } from './preAuthTenant'; export * from './concurrency'; export * from './checkBalance'; diff --git a/packages/api/src/middleware/preAuthTenant.spec.ts b/packages/api/src/middleware/preAuthTenant.spec.ts new file mode 100644 index 0000000000..ed35da2324 --- /dev/null +++ b/packages/api/src/middleware/preAuthTenant.spec.ts @@ -0,0 +1,129 @@ +import { getTenantId, logger } from '@librechat/data-schemas'; +import { preAuthTenantMiddleware } from './preAuthTenant'; +import type { Request, Response, NextFunction } from 'express'; + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { + warn: jest.fn(), + error: jest.fn(), + info: jest.fn(), + debug: jest.fn(), + }, +})); + 
+describe('preAuthTenantMiddleware', () => { + let req: Partial; + let res: Partial; + + beforeEach(() => { + jest.clearAllMocks(); + req = { headers: {} }; + res = {}; + }); + + it('calls next() without ALS context when no X-Tenant-Id header is present', () => { + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + }); + + it('calls next() without ALS context when X-Tenant-Id header is empty', () => { + req.headers = { 'x-tenant-id': '' }; + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + }); + + it('wraps downstream in ALS context when X-Tenant-Id header is present', () => { + req.headers = { 'x-tenant-id': 'acme-corp' }; + let capturedTenantId: string | undefined; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBe('acme-corp'); + }); + + it('ignores __SYSTEM__ sentinel and logs warning', () => { + req.headers = { 'x-tenant-id': '__SYSTEM__' }; + req.ip = '10.0.0.1'; + req.path = '/api/config'; + let capturedTenantId: string | undefined = 'should-be-overwritten'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('__SYSTEM__'), + expect.objectContaining({ ip: '10.0.0.1', path: '/api/config' }), + ); + }); + + it('ignores array-valued headers (Express can produce these)', () => { + 
req.headers = { 'x-tenant-id': ['a', 'b'] as unknown as string }; + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + }); + + it('ignores tenant IDs containing invalid characters and logs warning', () => { + req.headers = { 'x-tenant-id': 'tenant:injected' }; + req.ip = '192.168.1.1'; + req.path = '/api/auth/login'; + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('malformed'), + expect.objectContaining({ ip: '192.168.1.1', path: '/api/auth/login' }), + ); + }); + + it('trims whitespace from tenant ID header', () => { + req.headers = { 'x-tenant-id': ' acme-corp ' }; + let capturedTenantId: string | undefined; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBe('acme-corp'); + }); + + it('ignores tenant IDs exceeding max length and logs warning', () => { + req.headers = { 'x-tenant-id': 'a'.repeat(200) }; + req.ip = '192.168.1.1'; + req.path = '/api/share/abc'; + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('malformed'), + expect.objectContaining({ ip: '192.168.1.1', length: 200, path: '/api/share/abc' }), + ); + }); +}); diff --git 
a/packages/api/src/middleware/preAuthTenant.ts b/packages/api/src/middleware/preAuthTenant.ts new file mode 100644 index 0000000000..bab91f3a18 --- /dev/null +++ b/packages/api/src/middleware/preAuthTenant.ts @@ -0,0 +1,72 @@ +import { tenantStorage, logger, SYSTEM_TENANT_ID } from '@librechat/data-schemas'; +import type { Request, Response, NextFunction } from 'express'; + +/** + * Pre-authentication tenant context middleware for unauthenticated routes. + * + * Reads the tenant identifier from the `X-Tenant-Id` request header and wraps + * downstream handlers in `tenantStorage.run()` so that Mongoose queries and + * config resolution run within the correct tenant scope. + * + * **Where to use**: Mount on routes that must be tenant-aware before + * authentication has occurred: + * - `GET /api/config` โ€” login page needs tenant-specific config (social logins, registration) + * - `/api/auth/*` โ€” login, register, password reset + * - `/oauth/*` โ€” OAuth callback flows + * - `GET /api/share/:shareId` โ€” public shared conversation links + * + * **How the header gets set**: The deployment's reverse proxy, auth gateway, + * or OpenID strategy sets `X-Tenant-Id` based on subdomain, path, or OIDC claim. + * This middleware does NOT resolve tenants from subdomains or tokens โ€” that is + * the responsibility of the deployment layer. + * + * **Design**: Intentionally minimal. No subdomain parsing, no OIDC claim + * extraction, no YAML-driven strategy. Multi-tenant deployments can: + * 1. Set the header in the reverse proxy / ingress (simplest), + * 2. Replace this middleware's resolver logic entirely, or + * 3. Layer additional resolution on top (e.g., OpenID `tenant` claim โ†’ header). + * + * If no header is present, downstream runs without tenant ALS context (same as + * single-tenant mode). This preserves backward compatibility. 
+ */ +const MAX_TENANT_ID_LENGTH = 128; +const VALID_TENANT_ID = /^[-a-zA-Z0-9_.]+$/; + +export function preAuthTenantMiddleware(req: Request, res: Response, next: NextFunction): void { + const raw = req.headers['x-tenant-id']; + + if (!raw || typeof raw !== 'string') { + next(); + return; + } + + const tenantId = raw.trim(); + + if (!tenantId) { + next(); + return; + } + + if (tenantId === SYSTEM_TENANT_ID) { + logger.warn('[preAuthTenant] Rejected __SYSTEM__ sentinel in X-Tenant-Id header', { + ip: req.ip, + path: req.path, + }); + next(); + return; + } + + if (tenantId.length > MAX_TENANT_ID_LENGTH || !VALID_TENANT_ID.test(tenantId)) { + logger.warn('[preAuthTenant] Rejected malformed X-Tenant-Id header', { + ip: req.ip, + length: tenantId.length, + path: req.path, + }); + next(); + return; + } + + return void tenantStorage.run({ tenantId }, async () => { + next(); + }); +} diff --git a/packages/api/src/stream/GenerationJobManager.ts b/packages/api/src/stream/GenerationJobManager.ts index 3e04ab734b..5993c911ff 100644 --- a/packages/api/src/stream/GenerationJobManager.ts +++ b/packages/api/src/stream/GenerationJobManager.ts @@ -1,4 +1,4 @@ -import { logger } from '@librechat/data-schemas'; +import { logger, getTenantId, SYSTEM_TENANT_ID } from '@librechat/data-schemas'; import type { StandardGraph } from '@librechat/agents'; import { parseTextParts } from 'librechat-data-provider'; import type { Agents, TMessageContentParts } from 'librechat-data-provider'; @@ -197,7 +197,9 @@ class GenerationJobManagerClass { userId: string, conversationId?: string, ): Promise { - const jobData = await this.jobStore.createJob(streamId, userId, conversationId); + const tenantId = getTenantId(); + const safeTenantId = tenantId && tenantId !== SYSTEM_TENANT_ID ? tenantId : undefined; + const jobData = await this.jobStore.createJob(streamId, userId, conversationId, safeTenantId); /** * Create runtime state with readyPromise. 
@@ -355,6 +357,7 @@ class GenerationJobManagerClass { error: jobData.error, metadata: { userId: jobData.userId, + tenantId: jobData.tenantId, conversationId: jobData.conversationId, userMessage: jobData.userMessage, responseMessageId: jobData.responseMessageId, @@ -1255,8 +1258,8 @@ class GenerationJobManagerClass { * @param userId - The user ID to query * @returns Array of conversation IDs with active jobs */ - async getActiveJobIdsForUser(userId: string): Promise { - return this.jobStore.getActiveJobIdsByUser(userId); + async getActiveJobIdsForUser(userId: string, tenantId?: string): Promise { + return this.jobStore.getActiveJobIdsByUser(userId, tenantId); } /** diff --git a/packages/api/src/stream/implementations/InMemoryJobStore.ts b/packages/api/src/stream/implementations/InMemoryJobStore.ts index cc82a69963..7280c3ce80 100644 --- a/packages/api/src/stream/implementations/InMemoryJobStore.ts +++ b/packages/api/src/stream/implementations/InMemoryJobStore.ts @@ -70,6 +70,7 @@ export class InMemoryJobStore implements IJobStore { streamId: string, userId: string, conversationId?: string, + tenantId?: string, ): Promise { if (this.jobs.size >= this.maxJobs) { await this.evictOldest(); @@ -78,6 +79,7 @@ export class InMemoryJobStore implements IJobStore { const job: SerializableJobData = { streamId, userId, + ...(tenantId && { tenantId }), status: 'running', createdAt: Date.now(), conversationId, @@ -86,11 +88,12 @@ export class InMemoryJobStore implements IJobStore { this.jobs.set(streamId, job); - // Track job by userId for efficient user-scoped queries - let userJobs = this.userJobMap.get(userId); + // Track job by userId (tenant-qualified when available) for efficient user-scoped queries + const userKey = tenantId ? 
`${tenantId}:${userId}` : userId; + let userJobs = this.userJobMap.get(userKey); if (!userJobs) { userJobs = new Set(); - this.userJobMap.set(userId, userJobs); + this.userJobMap.set(userKey, userJobs); } userJobs.add(streamId); @@ -146,6 +149,17 @@ export class InMemoryJobStore implements IJobStore { } for (const id of toDelete) { + const job = this.jobs.get(id); + if (job) { + const userKey = job.tenantId ? `${job.tenantId}:${job.userId}` : job.userId; + const userJobs = this.userJobMap.get(userKey); + if (userJobs) { + userJobs.delete(id); + if (userJobs.size === 0) { + this.userJobMap.delete(userKey); + } + } + } await this.deleteJob(id); } @@ -169,6 +183,17 @@ export class InMemoryJobStore implements IJobStore { if (oldestId) { logger.warn(`[InMemoryJobStore] Evicting oldest job: ${oldestId}`); + const job = this.jobs.get(oldestId); + if (job) { + const userKey = job.tenantId ? `${job.tenantId}:${job.userId}` : job.userId; + const userJobs = this.userJobMap.get(userKey); + if (userJobs) { + userJobs.delete(oldestId); + if (userJobs.size === 0) { + this.userJobMap.delete(userKey); + } + } + } await this.deleteJob(oldestId); } } @@ -205,8 +230,9 @@ export class InMemoryJobStore implements IJobStore { * Returns conversation IDs of running jobs belonging to the user. * Also performs self-healing cleanup: removes stale entries for jobs that no longer exist. */ - async getActiveJobIdsByUser(userId: string): Promise { - const trackedIds = this.userJobMap.get(userId); + async getActiveJobIdsByUser(userId: string, tenantId?: string): Promise { + const userKey = tenantId ? 
`${tenantId}:${userId}` : userId; + const trackedIds = this.userJobMap.get(userKey); if (!trackedIds || trackedIds.size === 0) { return []; } @@ -226,7 +252,7 @@ export class InMemoryJobStore implements IJobStore { // Clean up empty set if (trackedIds.size === 0) { - this.userJobMap.delete(userId); + this.userJobMap.delete(userKey); } return activeIds; diff --git a/packages/api/src/stream/implementations/RedisJobStore.ts b/packages/api/src/stream/implementations/RedisJobStore.ts index 727fe066eb..a631bc2044 100644 --- a/packages/api/src/stream/implementations/RedisJobStore.ts +++ b/packages/api/src/stream/implementations/RedisJobStore.ts @@ -29,8 +29,9 @@ const KEYS = { runSteps: (streamId: string) => `stream:{${streamId}}:runsteps`, /** Running jobs set for cleanup (global set - single slot) */ runningJobs: 'stream:running', - /** User's active jobs set: stream:user:{userId}:jobs */ - userJobs: (userId: string) => `stream:user:{${userId}}:jobs`, + /** User's active jobs set, tenant-qualified when tenantId is available */ + userJobs: (userId: string, tenantId?: string) => + tenantId ? 
`stream:user:{${tenantId}:${userId}}:jobs` : `stream:user:{${userId}}:jobs`, }; /** @@ -140,10 +141,12 @@ export class RedisJobStore implements IJobStore { streamId: string, userId: string, conversationId?: string, + tenantId?: string, ): Promise { const job: SerializableJobData = { streamId, userId, + ...(tenantId && { tenantId }), status: 'running', createdAt: Date.now(), conversationId, @@ -151,7 +154,7 @@ export class RedisJobStore implements IJobStore { }; const key = KEYS.job(streamId); - const userJobsKey = KEYS.userJobs(userId); + const userJobsKey = KEYS.userJobs(userId, tenantId); // For cluster mode, we can't pipeline keys on different slots // The job key uses hash tag {streamId}, runningJobs and userJobs are on different slots @@ -377,8 +380,8 @@ export class RedisJobStore implements IJobStore { * @param userId - The user ID to query * @returns Array of conversation IDs with active jobs */ - async getActiveJobIdsByUser(userId: string): Promise { - const userJobsKey = KEYS.userJobs(userId); + async getActiveJobIdsByUser(userId: string, tenantId?: string): Promise { + const userJobsKey = KEYS.userJobs(userId, tenantId); const trackedIds = await this.redis.smembers(userJobsKey); if (trackedIds.length === 0) { @@ -868,6 +871,7 @@ export class RedisJobStore implements IJobStore { return { streamId: data.streamId, userId: data.userId, + tenantId: data.tenantId || undefined, status: data.status as JobStatus, createdAt: parseInt(data.createdAt, 10), completedAt: data.completedAt ? 
parseInt(data.completedAt, 10) : undefined, diff --git a/packages/api/src/stream/interfaces/IJobStore.ts b/packages/api/src/stream/interfaces/IJobStore.ts index fadddb840d..b59eed66f8 100644 --- a/packages/api/src/stream/interfaces/IJobStore.ts +++ b/packages/api/src/stream/interfaces/IJobStore.ts @@ -12,6 +12,7 @@ export type JobStatus = 'running' | 'complete' | 'error' | 'aborted'; export interface SerializableJobData { streamId: string; userId: string; + tenantId?: string; status: JobStatus; createdAt: number; completedAt?: number; @@ -149,6 +150,7 @@ export interface IJobStore { streamId: string, userId: string, conversationId?: string, + tenantId?: string, ): Promise; /** Get a job by streamId (streamId === conversationId) */ @@ -186,7 +188,7 @@ export interface IJobStore { * @param userId - The user ID to query * @returns Array of conversation IDs with active jobs */ - getActiveJobIdsByUser(userId: string): Promise; + getActiveJobIdsByUser(userId: string, tenantId?: string): Promise; // ===== Content State Methods ===== // These methods manage volatile content state tied to each job. 
diff --git a/packages/api/src/types/stream.ts b/packages/api/src/types/stream.ts index 068d9c8db8..dd125a1aab 100644 --- a/packages/api/src/types/stream.ts +++ b/packages/api/src/types/stream.ts @@ -4,6 +4,7 @@ import type { ServerSentEvent } from '~/types'; export interface GenerationJobMetadata { userId: string; + tenantId?: string; conversationId?: string; /** User message data for rebuilding submission on reconnect */ userMessage?: Agents.UserMessageMeta; diff --git a/packages/data-schemas/src/config/tenantContext.spec.ts b/packages/data-schemas/src/config/tenantContext.spec.ts new file mode 100644 index 0000000000..7e6cc0748d --- /dev/null +++ b/packages/data-schemas/src/config/tenantContext.spec.ts @@ -0,0 +1,26 @@ +import { tenantStorage, runAsSystem, scopedCacheKey } from './tenantContext'; + +describe('scopedCacheKey', () => { + it('returns base key when no ALS context is set', () => { + expect(scopedCacheKey('MODELS_CONFIG')).toBe('MODELS_CONFIG'); + }); + + it('returns base key in SYSTEM_TENANT_ID context', async () => { + await runAsSystem(async () => { + expect(scopedCacheKey('MODELS_CONFIG')).toBe('MODELS_CONFIG'); + }); + }); + + it('appends tenantId when tenant context is active', async () => { + await tenantStorage.run({ tenantId: 'acme' }, async () => { + expect(scopedCacheKey('MODELS_CONFIG')).toBe('MODELS_CONFIG:acme'); + }); + }); + + it('does not leak tenant context outside ALS scope', async () => { + await tenantStorage.run({ tenantId: 'acme' }, async () => { + expect(scopedCacheKey('KEY')).toBe('KEY:acme'); + }); + expect(scopedCacheKey('KEY')).toBe('KEY'); + }); +}); diff --git a/packages/data-schemas/src/config/tenantContext.ts b/packages/data-schemas/src/config/tenantContext.ts index e5e4376a90..eb77edb27d 100644 --- a/packages/data-schemas/src/config/tenantContext.ts +++ b/packages/data-schemas/src/config/tenantContext.ts @@ -26,3 +26,16 @@ export function getTenantId(): string | undefined { export function runAsSystem(fn: () => 
Promise): Promise { return tenantStorage.run({ tenantId: SYSTEM_TENANT_ID }, fn); } + +/** + * Appends `:${tenantId}` to a cache key when a non-system tenant context is active. + * Returns the base key unchanged when no ALS context is set or when running + * inside `runAsSystem()` (SYSTEM_TENANT_ID context). + */ +export function scopedCacheKey(baseKey: string): string { + const tenantId = getTenantId(); + if (!tenantId || tenantId === SYSTEM_TENANT_ID) { + return baseKey; + } + return `${baseKey}:${tenantId}`; +} diff --git a/packages/data-schemas/src/index.ts b/packages/data-schemas/src/index.ts index d673db1f5c..1139f83f17 100644 --- a/packages/data-schemas/src/index.ts +++ b/packages/data-schemas/src/index.ts @@ -19,6 +19,12 @@ export type * from './types'; export type * from './methods'; export { default as logger } from './config/winston'; export { default as meiliLogger } from './config/meiliLogger'; -export { tenantStorage, getTenantId, runAsSystem, SYSTEM_TENANT_ID } from './config/tenantContext'; +export { + tenantStorage, + getTenantId, + runAsSystem, + scopedCacheKey, + SYSTEM_TENANT_ID, +} from './config/tenantContext'; export type { TenantContext } from './config/tenantContext'; export { dropSupersededTenantIndexes, dropSupersededPromptGroupIndexes } from './migrations'; diff --git a/packages/data-schemas/src/methods/aclEntry.ts b/packages/data-schemas/src/methods/aclEntry.ts index 82e277254a..2f61861029 100644 --- a/packages/data-schemas/src/methods/aclEntry.ts +++ b/packages/data-schemas/src/methods/aclEntry.ts @@ -8,6 +8,7 @@ import type { Model, } from 'mongoose'; import type { IAclEntry } from '~/types'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; export function createAclEntryMethods(mongoose: typeof import('mongoose')) { /** @@ -378,7 +379,7 @@ export function createAclEntryMethods(mongoose: typeof import('mongoose')) { options?: { session?: ClientSession }, ) { const AclEntry = mongoose.models.AclEntry as Model; - return 
AclEntry.bulkWrite(ops, options || {}); + return tenantSafeBulkWrite(AclEntry, ops, options || {}); } /** @@ -448,7 +449,9 @@ export function createAclEntryMethods(mongoose: typeof import('mongoose')) { { $group: { _id: '$resourceId' } }, ]); - const multiOwnerIds = new Set(otherOwners.map((doc: { _id: Types.ObjectId }) => doc._id.toString())); + const multiOwnerIds = new Set( + otherOwners.map((doc: { _id: Types.ObjectId }) => doc._id.toString()), + ); return ownedIds.filter((id) => !multiOwnerIds.has(id.toString())); } diff --git a/packages/data-schemas/src/methods/agentCategory.ts b/packages/data-schemas/src/methods/agentCategory.ts index 2dd4678075..baf33207aa 100644 --- a/packages/data-schemas/src/methods/agentCategory.ts +++ b/packages/data-schemas/src/methods/agentCategory.ts @@ -1,5 +1,6 @@ import type { Model, Types } from 'mongoose'; import type { IAgentCategory } from '~/types'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; export function createAgentCategoryMethods(mongoose: typeof import('mongoose')) { /** @@ -74,7 +75,7 @@ export function createAgentCategoryMethods(mongoose: typeof import('mongoose')) }, })); - return await AgentCategory.bulkWrite(operations); + return await tenantSafeBulkWrite(AgentCategory, operations); } /** @@ -241,7 +242,7 @@ export function createAgentCategoryMethods(mongoose: typeof import('mongoose')) }, })); - await AgentCategory.bulkWrite(bulkOps, { ordered: false }); + await tenantSafeBulkWrite(AgentCategory, bulkOps, { ordered: false }); } return updates.length > 0 || created > 0; diff --git a/packages/data-schemas/src/methods/conversation.ts b/packages/data-schemas/src/methods/conversation.ts index 7a62afef9e..abfe16bf2d 100644 --- a/packages/data-schemas/src/methods/conversation.ts +++ b/packages/data-schemas/src/methods/conversation.ts @@ -1,6 +1,7 @@ import type { FilterQuery, Model, SortOrder } from 'mongoose'; -import logger from '~/config/winston'; import { createTempChatExpirationDate } from 
'~/utils/tempChatRetention'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; +import logger from '~/config/winston'; import type { AppConfig, IConversation } from '~/types'; import type { MessageMethods } from './message'; import type { DeleteResult } from 'mongoose'; @@ -228,7 +229,7 @@ export function createConversationMethods( }, })); - const result = await Conversation.bulkWrite(bulkOps); + const result = await tenantSafeBulkWrite(Conversation, bulkOps); return result; } catch (error) { logger.error('[bulkSaveConvos] Error saving conversations in bulk', error); diff --git a/packages/data-schemas/src/methods/conversationTag.ts b/packages/data-schemas/src/methods/conversationTag.ts index af1e43babb..085948bab5 100644 --- a/packages/data-schemas/src/methods/conversationTag.ts +++ b/packages/data-schemas/src/methods/conversationTag.ts @@ -1,4 +1,5 @@ import type { Model } from 'mongoose'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; import logger from '~/config/winston'; interface IConversationTag { @@ -233,7 +234,7 @@ export function createConversationTagMethods(mongoose: typeof import('mongoose') } if (bulkOps.length > 0) { - await ConversationTag.bulkWrite(bulkOps); + await tenantSafeBulkWrite(ConversationTag, bulkOps); } const updatedConversation = ( @@ -273,7 +274,7 @@ export function createConversationTagMethods(mongoose: typeof import('mongoose') }, })); - const result = await ConversationTag.bulkWrite(bulkOps); + const result = await tenantSafeBulkWrite(ConversationTag, bulkOps); if (result && result.modifiedCount > 0) { logger.debug( `user: ${user} | Incremented tag counts - modified ${result.modifiedCount} tags`, diff --git a/packages/data-schemas/src/methods/file.ts b/packages/data-schemas/src/methods/file.ts index 3d7db88c3f..4c0969afb3 100644 --- a/packages/data-schemas/src/methods/file.ts +++ b/packages/data-schemas/src/methods/file.ts @@ -2,6 +2,7 @@ import logger from '../config/winston'; import { EToolResources, 
FileContext } from 'librechat-data-provider'; import type { FilterQuery, SortOrder, Model } from 'mongoose'; import type { IMongoFile } from '~/types/file'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; /** Factory function that takes mongoose instance and returns the file methods */ export function createFileMethods(mongoose: typeof import('mongoose')) { @@ -322,7 +323,7 @@ export function createFileMethods(mongoose: typeof import('mongoose')) { }, })); - const result = await File.bulkWrite(bulkOperations); + const result = await tenantSafeBulkWrite(File, bulkOperations); logger.info(`Updated ${result.modifiedCount} files with new S3 URLs`); } diff --git a/packages/data-schemas/src/methods/message.ts b/packages/data-schemas/src/methods/message.ts index ae5ca72b12..2e638b6bfb 100644 --- a/packages/data-schemas/src/methods/message.ts +++ b/packages/data-schemas/src/methods/message.ts @@ -1,6 +1,7 @@ import type { DeleteResult, FilterQuery, Model } from 'mongoose'; import logger from '~/config/winston'; import { createTempChatExpirationDate } from '~/utils/tempChatRetention'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; import type { AppConfig, IMessage } from '~/types'; /** Simple UUID v4 regex to replace zod validation */ @@ -165,7 +166,7 @@ export function createMessageMethods(mongoose: typeof import('mongoose')): Messa upsert: true, }, })); - const result = await Message.bulkWrite(bulkOps); + const result = await tenantSafeBulkWrite(Message, bulkOps); return result; } catch (err) { logger.error('Error saving messages in bulk:', err); diff --git a/packages/data-schemas/src/methods/prompt.ts b/packages/data-schemas/src/methods/prompt.ts index a1b6bfde37..86d830fecd 100644 --- a/packages/data-schemas/src/methods/prompt.ts +++ b/packages/data-schemas/src/methods/prompt.ts @@ -1,8 +1,9 @@ import { ResourceType, SystemCategories } from 'librechat-data-provider'; import type { Model, Types } from 'mongoose'; import type { IAclEntry, 
IPrompt, IPromptGroup, IPromptGroupDocument } from '~/types'; -import { escapeRegExp } from '~/utils/string'; +import { getTenantId, SYSTEM_TENANT_ID } from '~/config/tenantContext'; import { isValidObjectIdString } from '~/utils/objectId'; +import { escapeRegExp } from '~/utils/string'; import logger from '~/config/winston'; export interface PromptDeps { @@ -508,16 +509,37 @@ export function createPromptMethods(mongoose: typeof import('mongoose'), deps: P if (typeof matchFilter._id === 'string') { matchFilter._id = new ObjectId(matchFilter._id); } + const tenantId = getTenantId(); + const useTenantFilter = tenantId && tenantId !== SYSTEM_TENANT_ID; + + const lookupStage = useTenantFilter + ? { + $lookup: { + from: 'prompts', + let: { prodId: '$productionId' }, + pipeline: [ + { + $match: { + $expr: { $eq: ['$_id', '$$prodId'] }, + tenantId, + }, + }, + ], + as: 'productionPrompt', + }, + } + : { + $lookup: { + from: 'prompts', + localField: 'productionId', + foreignField: '_id', + as: 'productionPrompt', + }, + }; + const result = await PromptGroup.aggregate([ { $match: matchFilter }, - { - $lookup: { - from: 'prompts', - localField: 'productionId', - foreignField: '_id', - as: 'productionPrompt', - }, - }, + lookupStage, { $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } }, ]); const group = result[0] || null; diff --git a/packages/data-schemas/src/migrations/tenantIndexes.spec.ts b/packages/data-schemas/src/migrations/tenantIndexes.spec.ts index 4637e7d0ad..e62b587a6e 100644 --- a/packages/data-schemas/src/migrations/tenantIndexes.spec.ts +++ b/packages/data-schemas/src/migrations/tenantIndexes.spec.ts @@ -47,7 +47,13 @@ describe('dropSupersededTenantIndexes', () => { await db.createCollection('roles'); await db.collection('roles').createIndex({ name: 1 }, { unique: true, name: 'name_1' }); + await db.createCollection('agents'); + await db.collection('agents').createIndex({ id: 1 }, { unique: true, name: 'id_1' }); + await 
db.createCollection('conversations'); + await db + .collection('conversations') + .createIndex({ conversationId: 1 }, { unique: true, name: 'conversationId_1' }); await db .collection('conversations') .createIndex( @@ -56,10 +62,18 @@ describe('dropSupersededTenantIndexes', () => { ); await db.createCollection('messages'); + await db + .collection('messages') + .createIndex({ messageId: 1 }, { unique: true, name: 'messageId_1' }); await db .collection('messages') .createIndex({ messageId: 1, user: 1 }, { unique: true, name: 'messageId_1_user_1' }); + await db.createCollection('presets'); + await db + .collection('presets') + .createIndex({ presetId: 1 }, { unique: true, name: 'presetId_1' }); + await db.createCollection('agentcategories'); await db .collection('agentcategories') diff --git a/packages/data-schemas/src/migrations/tenantIndexes.ts b/packages/data-schemas/src/migrations/tenantIndexes.ts index c68df4db2b..a8b4e51768 100644 --- a/packages/data-schemas/src/migrations/tenantIndexes.ts +++ b/packages/data-schemas/src/migrations/tenantIndexes.ts @@ -24,8 +24,10 @@ const SUPERSEDED_INDEXES: Record = { 'appleId_1', ], roles: ['name_1'], - conversations: ['conversationId_1_user_1'], - messages: ['messageId_1_user_1'], + agents: ['id_1'], + conversations: ['conversationId_1', 'conversationId_1_user_1'], + messages: ['messageId_1', 'messageId_1_user_1'], + presets: ['presetId_1'], agentcategories: ['value_1'], accessroles: ['accessRoleId_1'], conversationtags: ['tag_1_user_1'], diff --git a/packages/data-schemas/src/schema/agent.ts b/packages/data-schemas/src/schema/agent.ts index 42a7ca5418..70734d0ceb 100644 --- a/packages/data-schemas/src/schema/agent.ts +++ b/packages/data-schemas/src/schema/agent.ts @@ -5,8 +5,6 @@ const agentSchema = new Schema( { id: { type: String, - index: true, - unique: true, required: true, }, name: { @@ -124,6 +122,7 @@ const agentSchema = new Schema( }, ); +agentSchema.index({ id: 1, tenantId: 1 }, { unique: true }); 
agentSchema.index({ updatedAt: -1, _id: 1 }); agentSchema.index({ 'edges.to': 1 }); diff --git a/packages/data-schemas/src/schema/convo.ts b/packages/data-schemas/src/schema/convo.ts index 9ed8949e9c..c8f394935a 100644 --- a/packages/data-schemas/src/schema/convo.ts +++ b/packages/data-schemas/src/schema/convo.ts @@ -6,7 +6,6 @@ const convoSchema: Schema = new Schema( { conversationId: { type: String, - unique: true, required: true, index: true, meiliIndex: true, diff --git a/packages/data-schemas/src/schema/message.ts b/packages/data-schemas/src/schema/message.ts index ff3468918e..9879efae55 100644 --- a/packages/data-schemas/src/schema/message.ts +++ b/packages/data-schemas/src/schema/message.ts @@ -5,7 +5,6 @@ const messageSchema: Schema = new Schema( { messageId: { type: String, - unique: true, required: true, index: true, meiliIndex: true, diff --git a/packages/data-schemas/src/schema/preset.ts b/packages/data-schemas/src/schema/preset.ts index 33c217ea23..5af5163fd3 100644 --- a/packages/data-schemas/src/schema/preset.ts +++ b/packages/data-schemas/src/schema/preset.ts @@ -60,7 +60,6 @@ const presetSchema: Schema = new Schema( { presetId: { type: String, - unique: true, required: true, index: true, }, @@ -88,4 +87,6 @@ const presetSchema: Schema = new Schema( { timestamps: true }, ); +presetSchema.index({ presetId: 1, tenantId: 1 }, { unique: true }); + export default presetSchema; diff --git a/packages/data-schemas/src/utils/index.ts b/packages/data-schemas/src/utils/index.ts index c071f4e827..17e43ac3ca 100644 --- a/packages/data-schemas/src/utils/index.ts +++ b/packages/data-schemas/src/utils/index.ts @@ -1,5 +1,6 @@ export * from './principal'; export * from './string'; export * from './tempChatRetention'; +export { tenantSafeBulkWrite } from './tenantBulkWrite'; export * from './transactions'; export * from './objectId'; diff --git a/packages/data-schemas/src/utils/tenantBulkWrite.spec.ts b/packages/data-schemas/src/utils/tenantBulkWrite.spec.ts new file 
mode 100644 index 0000000000..059868b8a1 --- /dev/null +++ b/packages/data-schemas/src/utils/tenantBulkWrite.spec.ts @@ -0,0 +1,376 @@ +import mongoose, { Schema } from 'mongoose'; +import { MongoMemoryServer } from 'mongodb-memory-server'; +import { tenantStorage, runAsSystem, SYSTEM_TENANT_ID } from '~/config/tenantContext'; +import { applyTenantIsolation, _resetStrictCache } from '~/models/plugins/tenantIsolation'; +import { tenantSafeBulkWrite, _resetBulkWriteStrictCache } from './tenantBulkWrite'; + +let mongoServer: InstanceType; + +interface ITestDoc { + name: string; + value?: number; + tenantId?: string; +} + +function createTestModel(suffix: string) { + const schema = new Schema({ + name: { type: String, required: true }, + value: { type: Number, default: 0 }, + tenantId: { type: String, index: true }, + }); + applyTenantIsolation(schema); + const modelName = `TestBulkWrite_${suffix}_${Date.now()}`; + return mongoose.model(modelName, schema); +} + +beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + await mongoose.connect(mongoServer.getUri()); +}); + +afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); +}); + +afterEach(() => { + delete process.env.TENANT_ISOLATION_STRICT; + _resetStrictCache(); + _resetBulkWriteStrictCache(); +}); + +describe('tenantSafeBulkWrite', () => { + describe('with tenant context', () => { + it('injects tenantId into updateOne filters', async () => { + const Model = createTestModel('updateOne'); + + // Seed data for two tenants + await runAsSystem(async () => { + await Model.create([ + { name: 'doc1', value: 1, tenantId: 'tenant-a' }, + { name: 'doc1', value: 1, tenantId: 'tenant-b' }, + ]); + }); + + // Update only tenant-a's doc + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + updateOne: { + filter: { name: 'doc1' }, + update: { $set: { value: 99 } }, + }, + }, + ]); + }); + + // Verify tenant-a was updated, 
tenant-b was not + const docs = await runAsSystem(async () => Model.find({}).lean()); + const docA = docs.find((d) => d.tenantId === 'tenant-a'); + const docB = docs.find((d) => d.tenantId === 'tenant-b'); + expect(docA?.value).toBe(99); + expect(docB?.value).toBe(1); + }); + + it('injects tenantId into insertOne documents', async () => { + const Model = createTestModel('insertOne'); + + await tenantStorage.run({ tenantId: 'tenant-x' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + insertOne: { + document: { name: 'new-doc', value: 42 } as ITestDoc, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs).toHaveLength(1); + expect(docs[0].tenantId).toBe('tenant-x'); + expect(docs[0].name).toBe('new-doc'); + }); + + it('injects tenantId into deleteOne filters', async () => { + const Model = createTestModel('deleteOne'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'to-delete', tenantId: 'tenant-a' }, + { name: 'to-delete', tenantId: 'tenant-b' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + deleteOne: { + filter: { name: 'to-delete' }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs).toHaveLength(1); + expect(docs[0].tenantId).toBe('tenant-b'); + }); + + it('injects tenantId into updateMany filters', async () => { + const Model = createTestModel('updateMany'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'batch', value: 0, tenantId: 'tenant-a' }, + { name: 'batch', value: 0, tenantId: 'tenant-a' }, + { name: 'batch', value: 0, tenantId: 'tenant-b' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + updateMany: { + filter: { name: 'batch' }, + update: { $set: { value: 5 } }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => 
Model.find({}).lean()); + const tenantADocs = docs.filter((d) => d.tenantId === 'tenant-a'); + const tenantBDocs = docs.filter((d) => d.tenantId === 'tenant-b'); + expect(tenantADocs.every((d) => d.value === 5)).toBe(true); + expect(tenantBDocs[0].value).toBe(0); + }); + }); + + describe('with SYSTEM_TENANT_ID', () => { + it('skips tenantId injection (cross-tenant operation)', async () => { + const Model = createTestModel('system'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'sys-doc', value: 0, tenantId: 'tenant-a' }, + { name: 'sys-doc', value: 0, tenantId: 'tenant-b' }, + ]); + }); + + // System context should update ALL docs regardless of tenant + await runAsSystem(async () => { + await tenantSafeBulkWrite(Model, [ + { + updateMany: { + filter: { name: 'sys-doc' }, + update: { $set: { value: 100 } }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs.every((d) => d.value === 100)).toBe(true); + }); + }); + + describe('with SYSTEM_TENANT_ID in strict mode', () => { + it('does not throw when runAsSystem is used in strict mode', async () => { + process.env.TENANT_ISOLATION_STRICT = 'true'; + _resetBulkWriteStrictCache(); + + const Model = createTestModel('systemStrict'); + + await runAsSystem(async () => { + await Model.create({ name: 'strict-sys', value: 0 }); + }); + + await expect( + runAsSystem(async () => + tenantSafeBulkWrite(Model, [ + { + updateOne: { + filter: { name: 'strict-sys' }, + update: { $set: { value: 42 } }, + }, + }, + ]), + ), + ).resolves.toBeDefined(); + }); + }); + + describe('deleteMany and replaceOne', () => { + it('injects tenantId into deleteMany filters', async () => { + const Model = createTestModel('deleteMany'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'batch-del', value: 0, tenantId: 'tenant-a' }, + { name: 'batch-del', value: 0, tenantId: 'tenant-a' }, + { name: 'batch-del', value: 0, tenantId: 'tenant-b' }, + ]); + }); + 
+ await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [{ deleteMany: { filter: { name: 'batch-del' } } }]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs).toHaveLength(1); + expect(docs[0].tenantId).toBe('tenant-b'); + }); + + it('injects tenantId into replaceOne filter and replacement', async () => { + const Model = createTestModel('replaceOne'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'to-replace', value: 1, tenantId: 'tenant-a' }, + { name: 'to-replace', value: 1, tenantId: 'tenant-b' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + replaceOne: { + filter: { name: 'to-replace' }, + replacement: { name: 'replaced', value: 99 }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).sort({ name: 1 }).lean()); + const replaced = docs.find((d) => d.name === 'replaced'); + const untouched = docs.find((d) => d.tenantId === 'tenant-b'); + expect(replaced?.value).toBe(99); + expect(replaced?.tenantId).toBe('tenant-a'); + expect(untouched?.value).toBe(1); + }); + + it('replaceOne overwrites a conflicting tenantId in the replacement document', async () => { + const Model = createTestModel('replaceOverwrite'); + + await runAsSystem(async () => { + await Model.create({ name: 'conflict', value: 1, tenantId: 'tenant-a' }); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + replaceOne: { + filter: { name: 'conflict' }, + replacement: { name: 'conflict', value: 2, tenantId: 'tenant-evil' } as ITestDoc, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs).toHaveLength(1); + expect(docs[0].tenantId).toBe('tenant-a'); + expect(docs[0].value).toBe(2); + }); + }); + + describe('edge cases', () => { + it('handles empty ops array', async () => 
{ + const Model = createTestModel('emptyOps'); + const result = await tenantStorage.run({ tenantId: 'tenant-x' }, async () => + tenantSafeBulkWrite(Model, []), + ); + expect(result.insertedCount).toBe(0); + expect(result.modifiedCount).toBe(0); + }); + }); + + describe('without tenant context', () => { + it('passes through in non-strict mode', async () => { + const Model = createTestModel('noCtx'); + + await runAsSystem(async () => { + await Model.create({ name: 'no-ctx', value: 0 }); + }); + + // No ALS context โ€” non-strict should pass through + const result = await tenantSafeBulkWrite(Model, [ + { + updateOne: { + filter: { name: 'no-ctx' }, + update: { $set: { value: 10 } }, + }, + }, + ]); + + expect(result.modifiedCount).toBe(1); + }); + + it('throws in strict mode', async () => { + process.env.TENANT_ISOLATION_STRICT = 'true'; + _resetBulkWriteStrictCache(); + + const Model = createTestModel('strict'); + + await expect( + tenantSafeBulkWrite(Model, [ + { + updateOne: { + filter: { name: 'any' }, + update: { $set: { value: 1 } }, + }, + }, + ]), + ).rejects.toThrow('bulkWrite on TestBulkWrite_strict'); + }); + }); + + describe('mixed operations', () => { + it('handles a batch of mixed insert, update, delete operations', async () => { + const Model = createTestModel('mixed'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'existing1', value: 1, tenantId: 'tenant-m' }, + { name: 'to-remove', value: 2, tenantId: 'tenant-m' }, + { name: 'existing1', value: 1, tenantId: 'tenant-other' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-m' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + insertOne: { + document: { name: 'new-item', value: 10 } as ITestDoc, + }, + }, + { + updateOne: { + filter: { name: 'existing1' }, + update: { $set: { value: 50 } }, + }, + }, + { + deleteOne: { + filter: { name: 'to-remove' }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).sort({ name: 1 }).lean()); 
+ + // tenant-other's doc should be untouched + const otherDoc = docs.find((d) => d.tenantId === 'tenant-other' && d.name === 'existing1'); + expect(otherDoc?.value).toBe(1); + + // tenant-m: existing1 updated, to-remove deleted, new-item inserted + const tenantMDocs = docs.filter((d) => d.tenantId === 'tenant-m'); + expect(tenantMDocs).toHaveLength(2); + expect(tenantMDocs.find((d) => d.name === 'existing1')?.value).toBe(50); + expect(tenantMDocs.find((d) => d.name === 'new-item')?.value).toBe(10); + expect(tenantMDocs.find((d) => d.name === 'to-remove')).toBeUndefined(); + }); + }); +}); diff --git a/packages/data-schemas/src/utils/tenantBulkWrite.ts b/packages/data-schemas/src/utils/tenantBulkWrite.ts new file mode 100644 index 0000000000..16ef5fa057 --- /dev/null +++ b/packages/data-schemas/src/utils/tenantBulkWrite.ts @@ -0,0 +1,109 @@ +import type { AnyBulkWriteOperation, Model, MongooseBulkWriteOptions } from 'mongoose'; +import type { BulkWriteResult } from 'mongodb'; +import { getTenantId, SYSTEM_TENANT_ID } from '~/config/tenantContext'; +import logger from '~/config/winston'; + +let _strictMode: boolean | undefined; + +function isStrict(): boolean { + return (_strictMode ??= process.env.TENANT_ISOLATION_STRICT === 'true'); +} + +/** Resets the cached strict-mode flag. Exposed for test teardown only. */ +export function _resetBulkWriteStrictCache(): void { + _strictMode = undefined; +} + +/** + * Tenant-safe wrapper around Mongoose `Model.bulkWrite()`. + * + * Mongoose's `bulkWrite` does not trigger schema-level middleware hooks, so the + * `applyTenantIsolation` plugin cannot intercept it. This wrapper injects the + * current ALS tenant context into every operation's filter and/or document + * before delegating to the native `bulkWrite`. + * + * Behavior: + * - **tenantId present** (normal request): injects `{ tenantId }` into every + * operation filter (updateOne, deleteOne, replaceOne) and document (insertOne). 
+ * - **SYSTEM_TENANT_ID**: skips injection (cross-tenant system operation). + * - **No tenantId + strict mode**: throws (fail-closed, same as the plugin). + * - **No tenantId + non-strict**: passes through without injection (backward compat). + */ +export async function tenantSafeBulkWrite( + model: Model, + ops: AnyBulkWriteOperation[], + options?: MongooseBulkWriteOptions, +): Promise { + const tenantId = getTenantId(); + + if (!tenantId) { + if (isStrict()) { + throw new Error( + `[TenantIsolation] bulkWrite on ${model.modelName} attempted without tenant context in strict mode`, + ); + } + return model.bulkWrite(ops, options); + } + + if (tenantId === SYSTEM_TENANT_ID) { + return model.bulkWrite(ops, options); + } + + const injected = ops.map((op) => injectTenantId(op, tenantId)); + return model.bulkWrite(injected, options); +} + +/** + * Injects `tenantId` into a single bulk-write operation. + * Returns a new operation object โ€” does not mutate the original. + */ +function injectTenantId(op: AnyBulkWriteOperation, tenantId: string): AnyBulkWriteOperation { + if ('insertOne' in op) { + return { + insertOne: { + document: { ...op.insertOne.document, tenantId }, + }, + }; + } + + if ('updateOne' in op) { + const { filter, ...rest } = op.updateOne; + return { updateOne: { ...rest, filter: { ...filter, tenantId } } }; + } + + if ('updateMany' in op) { + const { filter, ...rest } = op.updateMany; + return { updateMany: { ...rest, filter: { ...filter, tenantId } } }; + } + + if ('deleteOne' in op) { + const { filter, ...rest } = op.deleteOne; + return { deleteOne: { ...rest, filter: { ...filter, tenantId } } }; + } + + if ('deleteMany' in op) { + const { filter, ...rest } = op.deleteMany; + return { deleteMany: { ...rest, filter: { ...filter, tenantId } } }; + } + + if ('replaceOne' in op) { + const { filter, replacement, ...rest } = op.replaceOne; + return { + replaceOne: { + ...rest, + filter: { ...filter, tenantId }, + replacement: { ...replacement, tenantId }, + 
}, + }; + } + + if (isStrict()) { + throw new Error( + '[TenantIsolation] Unknown bulkWrite operation type in strict mode โ€” refusing to pass through without tenant injection', + ); + } + logger.warn( + '[tenantSafeBulkWrite] Unknown bulk op type, passing through without tenant injection', + ); + return op; +} From d5c7d9f525b7214bcd814911f85489d19c7dddb3 Mon Sep 17 00:00:00 2001 From: Marco Beretta <81851188+berry-13@users.noreply.github.com> Date: Sun, 29 Mar 2026 01:10:36 +0100 Subject: [PATCH 14/18] =?UTF-8?q?=F0=9F=93=9D=20docs:=20update=20deploymen?= =?UTF-8?q?t=20link=20for=20Railway=20in=20README=20and=20README.zh.md=20(?= =?UTF-8?q?#12449)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * docs: update deployment link for Railway in README * docs: mirror Railway deploy link update in README.zh.md Agent-Logs-Url: https://github.com/danny-avila/LibreChat/sessions/ea1b6e56-f93d-47a7-9d62-6157d824acff Co-authored-by: berry-13 <81851188+berry-13@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> --- README.md | 2 +- README.zh.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 3e05dc686b..a7f68d9a92 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@

- + Deploy on Railway diff --git a/README.zh.md b/README.zh.md index cc9cb5a205..7f74057413 100644 --- a/README.zh.md +++ b/README.zh.md @@ -1,4 +1,4 @@ - +

@@ -34,7 +34,7 @@

- + Deploy on Railway From fda1bfc3ccf2a7e0d0abecdd80c28a72a84be230 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 28 Mar 2026 21:06:39 -0400 Subject: [PATCH 15/18] =?UTF-8?q?=F0=9F=94=AC=20ci:=20Add=20TypeScript=20T?= =?UTF-8?q?ype=20Checks=20to=20Backend=20Workflow=20and=20Fix=20All=20Type?= =?UTF-8?q?=20Errors=20(#12451)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(data-schemas): resolve TypeScript strict type check errors in source files - Constrain ConfigSection to string keys via `string & keyof TCustomConfig` - Replace broken `z` import from data-provider with TCustomConfig derivation - Add `_id: Types.ObjectId` to IUser matching other Document interfaces - Add `federatedTokens` and `openidTokens` optional fields to IUser - Type mongoose model accessors as `Model` and `Model` - Widen `getPremiumRate` param to accept `number | null` - Widen `bulkWriteAclEntries` ops to untyped `AnyBulkWriteOperation[]` - Fix `getUserPrincipals` return type to use `PrincipalType` enum - Add non-null assertions for `connection.db` in migration files - Import DailyRotateFile constructor directly instead of relying on broken module augmentation across mismatched node_modules trees - Add winston-daily-rotate-file as devDependency for type resolution * fix(data-schemas): resolve TypeScript type errors in test files - Replace arbitrary test keys with valid TCustomConfig properties in config.spec - Use non-null assertions for permission objects in role.methods.spec - Replace `.SHARED_GLOBAL` access with `.not.toHaveProperty()` for legacy field - Add non-null assertions for balance, writeRate, readRate in spendTokens.spec - Update mock user _id to use ObjectId in user.test - Remove unused Schema import in tenantIndexes.spec * fix(api): resolve TypeScript strict type check errors across source and test files - Widen getUserPrincipals dep type in capabilities middleware - Fix federatedTokens type in createSafeUser return - Use 
proper mock req type for read-only properties in preAuthTenant.spec - Replace `as IUser` casts with ObjectId-typed mocks in openid/oidc specs - Use TokenExchangeMethodEnum values instead of string literals in MCP specs - Fix SessionStore type compatibility in sessionCache specs - Replace `catch (error: any)` with `(error as Error)` in redis specs - Remove invalid properties from test data in initialize and MCP specs - Add String.prototype.isWellFormed declaration for sanitizeTitle spec * fix(client): resolve TypeScript type errors in shared client components - Add default values for destructured bindings in OGDialogTemplate - Replace broken ExtendedFile import with inline type in FileIcon * ci: add TypeScript type-check job to backend review workflow Add a `typecheck` job that runs `tsc --noEmit` on all four TypeScript workspaces (data-provider, data-schemas, @librechat/api, @librechat/client) after the build step. Catches type errors that rollup builds may miss. * fix(data-schemas): add local type declaration for DailyRotateFile transport The `winston-daily-rotate-file` package ships a module augmentation for `winston/lib/winston/transports`, but it fails when winston and winston-daily-rotate-file resolve from different node_modules trees (which happens in this monorepo due to npm hoisting). Add a local `.d.ts` declaration that augments the same module path from within data-schemas' compilation unit, so `tsc --noEmit` passes while keeping the original runtime pattern (`new winston.transports.DailyRotateFile`). 
* fix: address code review findings from PR #12451 - Restore typed `AnyBulkWriteOperation[]` on bulkWriteAclEntries, cast to untyped only at the tenantSafeBulkWrite call site (Finding 1) - Type `findUser` model accessor consistently with `findUsers` (Finding 2) - Replace inline `import('mongoose').ClientSession` with top-level import type - Use `toHaveLength` for spy assertions in playwright-expect spec file - Replace numbered Record casts with `.not.toHaveProperty()` in role.methods.spec for SHARED_GLOBAL assertions - Use per-test ObjectIds instead of shared testUserId in openid.spec - Replace inline `import()` type annotations with top-level SessionData import in sessionCache spec - Remove extraneous blank line in user.ts searchUsers * refactor: address remaining review findings (4โ€“7) - Extract OIDCTokens interface in user.ts; deduplicate across IUser fields and oidc.ts FederatedTokens (Finding 4) - Move String.isWellFormed declaration from spec file to project-level src/types/es2024-string.d.ts (Finding 5) - Replace verbose `= undefined` defaults in OGDialogTemplate with null coalescing pattern (Finding 6) - Replace `Record` TestConfig with named interface containing explicit test fields (Finding 7) --- .github/workflows/backend-review.yml | 59 +++++++++++++++++++ packages/api/src/admin/config.handler.spec.ts | 25 +++++--- packages/api/src/app/service.spec.ts | 22 +++++-- packages/api/src/auth/openid.spec.ts | 53 ++++++++--------- .../sessionCache.cache_integration.spec.ts | 49 +++++++-------- .../redisClients.cache_integration.spec.ts | 24 +++++--- .../src/endpoints/custom/initialize.spec.ts | 4 +- .../src/mcp/__tests__/MCPOAuthFlow.test.ts | 9 +-- .../mcp/__tests__/MCPOAuthSecurity.test.ts | 27 +++++---- .../__tests__/MCPOAuthTokenStorage.test.ts | 4 +- ...gsCacheRedis.perf_benchmark.manual.spec.ts | 13 +++- ...edisAggregateKey.cache_integration.spec.ts | 4 +- packages/api/src/middleware/capabilities.ts | 7 ++- .../api/src/middleware/preAuthTenant.spec.ts | 
2 +- packages/api/src/types/es2024-string.d.ts | 4 ++ packages/api/src/utils/env.ts | 4 +- packages/api/src/utils/graph.spec.ts | 42 ++++++------- packages/api/src/utils/oidc.spec.ts | 24 ++++---- packages/api/src/utils/oidc.ts | 40 ++++++------- packages/api/types/index.d.ts | 4 ++ .../src/components/OGDialogTemplate.tsx | 5 +- packages/client/src/svgs/FileIcon.tsx | 3 +- packages/data-schemas/src/methods/aclEntry.ts | 6 +- .../data-schemas/src/methods/config.spec.ts | 54 ++++++++++++++--- .../src/methods/role.methods.spec.ts | 26 ++++---- packages/data-schemas/src/methods/role.ts | 2 +- .../src/methods/spendTokens.spec.ts | 22 +++---- packages/data-schemas/src/methods/tx.ts | 2 +- .../data-schemas/src/methods/user.test.ts | 2 +- packages/data-schemas/src/methods/user.ts | 10 ++-- .../src/methods/userGroup.spec.ts | 2 +- .../data-schemas/src/methods/userGroup.ts | 9 +-- .../src/migrations/promptGroupIndexes.ts | 2 +- .../src/migrations/tenantIndexes.spec.ts | 18 +++--- .../src/migrations/tenantIndexes.ts | 2 +- packages/data-schemas/src/types/admin.ts | 10 +--- packages/data-schemas/src/types/user.ts | 10 ++++ .../src/types/winston-transports.d.ts | 34 +++++++++++ 38 files changed, 406 insertions(+), 233 deletions(-) create mode 100644 packages/api/src/types/es2024-string.d.ts create mode 100644 packages/api/types/index.d.ts create mode 100644 packages/data-schemas/src/types/winston-transports.d.ts diff --git a/.github/workflows/backend-review.yml b/.github/workflows/backend-review.yml index 038c90627e..9dd3905c0e 100644 --- a/.github/workflows/backend-review.yml +++ b/.github/workflows/backend-review.yml @@ -97,6 +97,65 @@ jobs: path: packages/api/dist retention-days: 2 + typecheck: + name: TypeScript type checks + needs: build + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - uses: actions/checkout@v4 + + - name: Use Node.js 20.19 + uses: actions/setup-node@v4 + with: + node-version: '20.19' + + - name: Restore node_modules cache + id: 
cache-node-modules + uses: actions/cache@v4 + with: + path: | + node_modules + api/node_modules + packages/api/node_modules + packages/data-provider/node_modules + packages/data-schemas/node_modules + key: node-modules-backend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }} + + - name: Install dependencies + if: steps.cache-node-modules.outputs.cache-hit != 'true' + run: npm ci + + - name: Download data-provider build + uses: actions/download-artifact@v4 + with: + name: build-data-provider + path: packages/data-provider/dist + + - name: Download data-schemas build + uses: actions/download-artifact@v4 + with: + name: build-data-schemas + path: packages/data-schemas/dist + + - name: Download api build + uses: actions/download-artifact@v4 + with: + name: build-api + path: packages/api/dist + + - name: Type check data-provider + run: npx tsc --noEmit -p packages/data-provider/tsconfig.json + + - name: Type check data-schemas + run: npx tsc --noEmit -p packages/data-schemas/tsconfig.json + + - name: Type check @librechat/api + run: npx tsc --noEmit -p packages/api/tsconfig.json + + - name: Type check @librechat/client + run: npx tsc --noEmit -p packages/client/tsconfig.json + circular-deps: name: Circular dependency checks needs: build diff --git a/packages/api/src/admin/config.handler.spec.ts b/packages/api/src/admin/config.handler.spec.ts index 705c54babc..708d114e72 100644 --- a/packages/api/src/admin/config.handler.spec.ts +++ b/packages/api/src/admin/config.handler.spec.ts @@ -1,3 +1,5 @@ +import type { Response } from 'express'; +import type { ServerRequest } from '~/types/http'; import { createAdminConfigHandlers } from './config'; function mockReq(overrides = {}) { @@ -7,23 +9,30 @@ function mockReq(overrides = {}) { body: {}, query: {}, ...overrides, - }; + } as Partial as ServerRequest; +} + +interface MockRes { + statusCode: number; + body: undefined | { config?: unknown; error?: string; [key: string]: unknown }; + status: jest.Mock; + json: 
jest.Mock; } function mockRes() { - const res = { + const res: MockRes = { statusCode: 200, body: undefined, - status: jest.fn((code) => { + status: jest.fn((code: number) => { res.statusCode = code; return res; }), - json: jest.fn((data) => { + json: jest.fn((data: MockRes['body']) => { res.body = data; return res; }), }; - return res; + return res as Partial as Response & MockRes; } function createHandlers(overrides = {}) { @@ -93,7 +102,7 @@ describe('createAdminConfigHandlers', () => { await handlers.getConfig(req, res); expect(res.statusCode).toBe(200); - expect(res.body.config).toEqual(config); + expect(res.body!.config).toEqual(config); }); it('returns 400 for invalid principalType', async () => { @@ -191,7 +200,7 @@ describe('createAdminConfigHandlers', () => { await handlers.deleteConfigField(req, res); expect(res.statusCode).toBe(400); - expect(res.body.error).toContain('query parameter'); + expect(res.body!.error).toContain('query parameter'); }); it('rejects unsafe field paths', async () => { @@ -408,7 +417,7 @@ describe('createAdminConfigHandlers', () => { await handlers.getBaseConfig(req, res); expect(res.statusCode).toBe(200); - expect(res.body.config).toEqual({ interface: { endpointsMenu: true } }); + expect(res.body!.config).toEqual({ interface: { endpointsMenu: true } }); }); }); }); diff --git a/packages/api/src/app/service.spec.ts b/packages/api/src/app/service.spec.ts index 4232a36dc3..c410783793 100644 --- a/packages/api/src/app/service.spec.ts +++ b/packages/api/src/app/service.spec.ts @@ -1,5 +1,13 @@ +import type { AppConfig } from '@librechat/data-schemas'; import { createAppConfigService } from './service'; +/** Extends AppConfig with mock fields used by merge behavior tests. */ +interface TestConfig extends AppConfig { + restricted?: boolean; + x?: string; + interface?: { endpointsMenu?: boolean; [key: string]: boolean | undefined }; +} + /** * Creates a mock cache that simulates Keyv's namespace behavior. 
* Keyv stores keys internally as `namespace:key` but its API (get/set/delete) @@ -18,7 +26,9 @@ function createMockCache(namespace = 'app_config') { return Promise.resolve(true); }), /** Mimic Keyv's opts.store structure for key enumeration in clearOverrideCache */ - opts: { store: { keys: () => store.keys() } }, + opts: { store: { keys: () => store.keys() } } as { + store?: { keys: () => IterableIterator }; + }, _store: store, }; } @@ -123,8 +133,10 @@ describe('createAppConfigService', () => { const config = await getAppConfig({ role: 'ADMIN' }); - expect(config.interface.endpointsMenu).toBe(false); - expect(config.endpoints).toEqual(['openAI']); + // Test data uses mock fields that don't exist on AppConfig to verify merge behavior + const merged = config as TestConfig; + expect(merged.interface?.endpointsMenu).toBe(false); + expect(merged.endpoints).toEqual(['openAI']); }); it('caches merged result with TTL', async () => { @@ -199,7 +211,7 @@ describe('createAppConfigService', () => { const config = await getAppConfig({ role: 'ADMIN' }); expect(mockGetConfigs).toHaveBeenCalledTimes(2); - expect((config as Record).restricted).toBe(true); + expect((config as TestConfig).restricted).toBe(true); }); it('does not short-circuit other users when one user has no overrides', async () => { @@ -216,7 +228,7 @@ describe('createAppConfigService', () => { const config = await getAppConfig({ role: 'ADMIN' }); expect(mockGetConfigs).toHaveBeenCalledTimes(2); - expect((config as Record).x).toBe('admin-only'); + expect((config as TestConfig).x).toBe('admin-only'); }); it('falls back to base config on getApplicableConfigs error', async () => { diff --git a/packages/api/src/auth/openid.spec.ts b/packages/api/src/auth/openid.spec.ts index 0761a24e85..2cf3992cdf 100644 --- a/packages/api/src/auth/openid.spec.ts +++ b/packages/api/src/auth/openid.spec.ts @@ -1,8 +1,13 @@ +import { Types } from 'mongoose'; import { ErrorTypes } from 'librechat-data-provider'; import { logger } from 
'@librechat/data-schemas'; import type { IUser, UserMethods } from '@librechat/data-schemas'; import { findOpenIDUser } from './openid'; +function newId() { + return new Types.ObjectId(); +} + jest.mock('@librechat/data-schemas', () => ({ ...jest.requireActual('@librechat/data-schemas'), logger: { @@ -24,7 +29,7 @@ describe('findOpenIDUser', () => { describe('Primary condition searches', () => { it('should find user by openidId', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', openidId: 'openid_123', email: 'user@example.com', @@ -51,7 +56,7 @@ describe('findOpenIDUser', () => { it('should find user by idOnTheSource', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', idOnTheSource: 'source_123', email: 'user@example.com', @@ -78,7 +83,7 @@ describe('findOpenIDUser', () => { it('should find user by both openidId and idOnTheSource', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', openidId: 'openid_123', idOnTheSource: 'source_123', @@ -109,16 +114,14 @@ describe('findOpenIDUser', () => { describe('Email-based searches', () => { it('should find user by email when primary conditions fail and openidId matches', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', openidId: 'openid_123', email: 'user@example.com', username: 'testuser', } as IUser; - mockFindUser - .mockResolvedValueOnce(null) - .mockResolvedValueOnce(mockUser); + mockFindUser.mockResolvedValueOnce(null).mockResolvedValueOnce(mockUser); const result = await findOpenIDUser({ openidId: 'openid_123', @@ -179,7 +182,7 @@ describe('findOpenIDUser', () => { describe('Provider conflict handling', () => { it('should return error when user has different provider', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'google', email: 'user@example.com', username: 'testuser', @@ -204,16 +207,14 @@ 
describe('findOpenIDUser', () => { it('should reject email fallback when existing openidId does not match token sub', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', openidId: 'openid_456', email: 'user@example.com', username: 'testuser', } as IUser; - mockFindUser - .mockResolvedValueOnce(null) - .mockResolvedValueOnce(mockUser); + mockFindUser.mockResolvedValueOnce(null).mockResolvedValueOnce(mockUser); const result = await findOpenIDUser({ openidId: 'openid_123', @@ -230,16 +231,14 @@ describe('findOpenIDUser', () => { it('should allow email fallback when existing openidId matches token sub', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', openidId: 'openid_123', email: 'user@example.com', username: 'testuser', } as IUser; - mockFindUser - .mockResolvedValueOnce(null) - .mockResolvedValueOnce(mockUser); + mockFindUser.mockResolvedValueOnce(null).mockResolvedValueOnce(mockUser); const result = await findOpenIDUser({ openidId: 'openid_123', @@ -258,7 +257,7 @@ describe('findOpenIDUser', () => { describe('User migration scenarios', () => { it('should prepare user for migration when email exists without openidId', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), email: 'user@example.com', username: 'testuser', // No provider and no openidId - needs migration @@ -287,16 +286,14 @@ describe('findOpenIDUser', () => { it('should reject when user already has a different openidId', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', openidId: 'existing_openid', email: 'user@example.com', username: 'testuser', } as IUser; - mockFindUser - .mockResolvedValueOnce(null) - .mockResolvedValueOnce(mockUser); + mockFindUser.mockResolvedValueOnce(null).mockResolvedValueOnce(mockUser); const result = await findOpenIDUser({ openidId: 'openid_123', @@ -313,16 +310,14 @@ describe('findOpenIDUser', () => { it('should reject when 
user has no provider but a different openidId', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), openidId: 'existing_openid', email: 'user@example.com', username: 'testuser', // No provider field โ€” tests a different branch than openid-provider mismatch } as IUser; - mockFindUser - .mockResolvedValueOnce(null) - .mockResolvedValueOnce(mockUser); + mockFindUser.mockResolvedValueOnce(null).mockResolvedValueOnce(mockUser); const result = await findOpenIDUser({ openidId: 'openid_123', @@ -422,16 +417,14 @@ describe('findOpenIDUser', () => { it('should pass email to findUser for case-insensitive lookup (findUser handles normalization)', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', openidId: 'openid_123', email: 'user@example.com', username: 'testuser', } as IUser; - mockFindUser - .mockResolvedValueOnce(null) - .mockResolvedValueOnce(mockUser); + mockFindUser.mockResolvedValueOnce(null).mockResolvedValueOnce(mockUser); const result = await findOpenIDUser({ openidId: 'openid_123', @@ -460,7 +453,7 @@ describe('findOpenIDUser', () => { it('should reject email fallback when openidId is empty and user has a stored openidId', async () => { const mockUser: IUser = { - _id: 'user123', + _id: newId(), provider: 'openid', openidId: 'existing-real-id', email: 'user@example.com', diff --git a/packages/api/src/cache/__tests__/cacheFactory/sessionCache.cache_integration.spec.ts b/packages/api/src/cache/__tests__/cacheFactory/sessionCache.cache_integration.spec.ts index f4ded8bc74..99c0d69b37 100644 --- a/packages/api/src/cache/__tests__/cacheFactory/sessionCache.cache_integration.spec.ts +++ b/packages/api/src/cache/__tests__/cacheFactory/sessionCache.cache_integration.spec.ts @@ -1,33 +1,36 @@ -interface SessionData { +import type { MemoryStore, SessionData } from 'express-session'; +import type { RedisStore as ConnectRedis } from 'connect-redis'; + +interface TestSessionData { [key: string]: unknown; 
cookie?: { maxAge: number }; user?: { id: string; name: string }; userId?: string; } -interface SessionStore { - prefix?: string; - set: (id: string, data: SessionData, callback?: (err?: Error) => void) => void; - get: (id: string, callback: (err: Error | null, data?: SessionData | null) => void) => void; - destroy: (id: string, callback?: (err?: Error) => void) => void; - touch: (id: string, data: SessionData, callback?: (err?: Error) => void) => void; - on?: (event: string, handler: (...args: unknown[]) => void) => void; -} +type CacheSessionStore = MemoryStore | ConnectRedis; describe('sessionCache', () => { let originalEnv: NodeJS.ProcessEnv; - // Helper to make session stores async - const asyncStore = (store: SessionStore) => ({ - set: (id: string, data: SessionData) => - new Promise((resolve) => store.set(id, data, () => resolve())), + // Helper to make session stores async โ€” uses generic store type to bridge + // between MemoryStore/ConnectRedis and the test's relaxed SessionData shape. + // The store methods accept express-session's SessionData but test data is + // intentionally simpler; the cast bridges the gap for integration tests. 
+ const asyncStore = (store: CacheSessionStore) => ({ + set: (id: string, data: TestSessionData) => + new Promise((resolve) => + store.set(id, data as Partial as SessionData, () => resolve()), + ), get: (id: string) => - new Promise((resolve) => - store.get(id, (_, data) => resolve(data)), + new Promise((resolve) => + store.get(id, (_, data) => resolve(data as TestSessionData | null | undefined)), ), destroy: (id: string) => new Promise((resolve) => store.destroy(id, () => resolve())), - touch: (id: string, data: SessionData) => - new Promise((resolve) => store.touch(id, data, () => resolve())), + touch: (id: string, data: TestSessionData) => + new Promise((resolve) => + store.touch(id, data as Partial as SessionData, () => resolve()), + ), }); beforeEach(() => { @@ -66,11 +69,11 @@ describe('sessionCache', () => { // Verify it returns a ConnectRedis instance expect(store).toBeDefined(); expect(store.constructor.name).toBe('RedisStore'); - expect(store.prefix).toBe('test-sessions:'); + expect((store as CacheSessionStore & { prefix: string }).prefix).toBe('test-sessions:'); // Test session operations const sessionId = 'sess:123456'; - const sessionData: SessionData = { + const sessionData: TestSessionData = { user: { id: 'user123', name: 'Test User' }, cookie: { maxAge: 3600000 }, }; @@ -107,7 +110,7 @@ describe('sessionCache', () => { // Test session operations const sessionId = 'mem:789012'; - const sessionData: SessionData = { + const sessionData: TestSessionData = { user: { id: 'user456', name: 'Memory User' }, cookie: { maxAge: 3600000 }, }; @@ -135,8 +138,8 @@ describe('sessionCache', () => { const store1 = cacheFactory.sessionCache('namespace1'); const store2 = cacheFactory.sessionCache('namespace2:'); - expect(store1.prefix).toBe('namespace1:'); - expect(store2.prefix).toBe('namespace2:'); + expect((store1 as CacheSessionStore & { prefix: string }).prefix).toBe('namespace1:'); + expect((store2 as CacheSessionStore & { prefix: string 
}).prefix).toBe('namespace2:'); }); test('should register error handler for Redis connection', async () => { @@ -171,7 +174,7 @@ describe('sessionCache', () => { } const sessionId = 'ttl:12345'; - const sessionData: SessionData = { userId: 'ttl-user' }; + const sessionData: TestSessionData = { userId: 'ttl-user' }; const async = asyncStore(store); // Set session with short TTL diff --git a/packages/api/src/cache/__tests__/redisClients.cache_integration.spec.ts b/packages/api/src/cache/__tests__/redisClients.cache_integration.spec.ts index dc9a325746..77e8c01436 100644 --- a/packages/api/src/cache/__tests__/redisClients.cache_integration.spec.ts +++ b/packages/api/src/cache/__tests__/redisClients.cache_integration.spec.ts @@ -59,8 +59,8 @@ describe('redisClients Integration Tests', () => { if (keys.length > 0) { await ioredisClient.del(...keys); } - } catch (error: any) { - console.warn('Error cleaning up test keys:', error.message); + } catch (error) { + console.warn('Error cleaning up test keys:', (error as Error).message); } } @@ -70,8 +70,8 @@ describe('redisClients Integration Tests', () => { if (ioredisClient.status === 'ready') { ioredisClient.disconnect(); } - } catch (error: any) { - console.warn('Error disconnecting ioredis client:', error.message); + } catch (error) { + console.warn('Error disconnecting ioredis client:', (error as Error).message); } ioredisClient = null; } @@ -80,8 +80,8 @@ describe('redisClients Integration Tests', () => { try { // Try to disconnect - keyv/redis client doesn't have an isReady property await keyvRedisClient.disconnect(); - } catch (error: any) { - console.warn('Error disconnecting keyv redis client:', error.message); + } catch (error) { + console.warn('Error disconnecting keyv redis client:', (error as Error).message); } keyvRedisClient = null; } @@ -138,7 +138,11 @@ describe('redisClients Integration Tests', () => { test('should connect and perform set/get/delete operations', async () => { const clients = await 
import('../redisClients'); keyvRedisClient = clients.keyvRedisClient; - await testRedisOperations(keyvRedisClient!, 'keyv-single', clients.keyvRedisClientReady!); + await testRedisOperations( + keyvRedisClient!, + 'keyv-single', + clients.keyvRedisClientReady!.then(() => undefined), + ); }); }); @@ -150,7 +154,11 @@ describe('redisClients Integration Tests', () => { const clients = await import('../redisClients'); keyvRedisClient = clients.keyvRedisClient; - await testRedisOperations(keyvRedisClient!, 'keyv-cluster', clients.keyvRedisClientReady!); + await testRedisOperations( + keyvRedisClient!, + 'keyv-cluster', + clients.keyvRedisClientReady!.then(() => undefined), + ); }); }); }); diff --git a/packages/api/src/endpoints/custom/initialize.spec.ts b/packages/api/src/endpoints/custom/initialize.spec.ts index 3705f98977..eddd7cb515 100644 --- a/packages/api/src/endpoints/custom/initialize.spec.ts +++ b/packages/api/src/endpoints/custom/initialize.spec.ts @@ -81,7 +81,7 @@ describe('initializeCustom โ€“ Agents API user key resolution', () => { userApiKey: 'sk-user-key', }); // Simulate Agents API request body (no `key` field) - params.req.body = { model: 'agent_123', messages: [] }; + params.req.body = { model: 'agent_123' }; await initializeCustom(params); @@ -104,7 +104,7 @@ describe('initializeCustom โ€“ Agents API user key resolution', () => { baseURL: AuthType.USER_PROVIDED, userBaseURL: 'https://user-api.example.com/v1', }); - params.req.body = { model: 'agent_123', messages: [] }; + params.req.body = { model: 'agent_123' }; await initializeCustom(params); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts index 7e26165cad..cbd29d3571 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts @@ -7,6 +7,7 @@ import { createHash } from 'crypto'; import { Keyv } from 'keyv'; +import { TokenExchangeMethodEnum } from 
'librechat-data-provider'; import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth'; import { FlowStateManager } from '~/flow/manager'; import { createOAuthMCPServer, MockKeyv, InMemoryTokenStore } from './helpers/oauthTestServer'; @@ -94,7 +95,7 @@ describe('MCP OAuth Flow โ€” Real HTTP Server', () => { token_url: `${server.url}token`, client_id: clientInfo.client_id, client_secret: clientInfo.client_secret, - token_exchange_method: 'DefaultPost', + token_exchange_method: TokenExchangeMethodEnum.DefaultPost, }, ); @@ -133,7 +134,7 @@ describe('MCP OAuth Flow โ€” Real HTTP Server', () => { { token_url: `${rotatingServer.url}token`, client_id: 'anon', - token_exchange_method: 'DefaultPost', + token_exchange_method: TokenExchangeMethodEnum.DefaultPost, }, ); @@ -157,7 +158,7 @@ describe('MCP OAuth Flow โ€” Real HTTP Server', () => { { token_url: `${server.url}token`, client_id: 'anon', - token_exchange_method: 'DefaultPost', + token_exchange_method: TokenExchangeMethodEnum.DefaultPost, }, ), ).rejects.toThrow(); @@ -414,7 +415,7 @@ describe('MCP OAuth Flow โ€” Real HTTP Server', () => { const state = await flowManager.getFlowState(flowId, 'mcp_oauth'); expect(state?.status).toBe('COMPLETED'); - expect(state?.result?.access_token).toBe(tokens.access_token); + expect((state?.result as MCPOAuthTokens | undefined)?.access_token).toBe(tokens.access_token); }); it('should fail flow when authorization code is invalid', async () => { diff --git a/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts index a2d0440d42..d50e29eab7 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts @@ -304,10 +304,10 @@ describe('MCP OAuth allowedDomains SSRF exemption for admin-trusted hosts', () = }); it('should allow private revocationEndpoint when hostname is in allowedDomains', async () => { - const mockFetch = jest.fn().mockResolvedValue({ - ok: 
true, - status: 200, - } as Response); + const mockFetch = Object.assign( + jest.fn().mockResolvedValue({ ok: true, status: 200 } as Response), + { preconnect: jest.fn() }, + ); const originalFetch = global.fetch; global.fetch = mockFetch; @@ -333,14 +333,17 @@ describe('MCP OAuth allowedDomains SSRF exemption for admin-trusted hosts', () = }); it('should allow localhost token_url in refreshOAuthTokens when localhost is in allowedDomains', async () => { - const mockFetch = jest.fn().mockResolvedValue({ - ok: true, - json: async () => ({ - access_token: 'new-access-token', - token_type: 'Bearer', - expires_in: 3600, - }), - } as Response); + const mockFetch = Object.assign( + jest.fn().mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: 'new-access-token', + token_type: 'Bearer', + expires_in: 3600, + }), + } as Response), + { preconnect: jest.fn() }, + ); const originalFetch = global.fetch; global.fetch = mockFetch; diff --git a/packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts index 3805586453..2d3905d2fb 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts @@ -160,7 +160,7 @@ describe('MCPTokenStorage', () => { serverName: 'srv1', tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 }, createToken: store.createToken, - clientInfo: { client_id: 'cid', client_secret: 'csec', redirect_uris: [] }, + clientInfo: { client_id: 'cid', client_secret: 'csec' }, }); const clientSaved = await store.findToken({ @@ -525,7 +525,7 @@ describe('MCPTokenStorage', () => { refresh_token: 'my-refresh-token', }, createToken: store.createToken, - clientInfo: { client_id: 'cid', client_secret: 'sec', redirect_uris: [] }, + clientInfo: { client_id: 'cid', client_secret: 'sec' }, }); const result = await MCPTokenStorage.getTokens({ diff --git 
a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.perf_benchmark.manual.spec.ts b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.perf_benchmark.manual.spec.ts index 1815d49fe0..d9dc7bb978 100644 --- a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.perf_benchmark.manual.spec.ts +++ b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.perf_benchmark.manual.spec.ts @@ -13,6 +13,7 @@ * the current SCAN+GET implementation. */ import { expect } from '@playwright/test'; +import type { RedisClientType } from 'redis'; import type { ParsedServerConfig } from '~/mcp/types'; describe('ServerConfigsCacheRedis Performance Benchmark', () => { @@ -103,7 +104,9 @@ describe('ServerConfigsCacheRedis Performance Benchmark', () => { // Phase 1: SCAN only (key discovery) const scanStart = Date.now(); const keys: string[] = []; - for await (const key of keyvRedisClient!.scanIterator({ MATCH: pattern })) { + for await (const key of (keyvRedisClient as RedisClientType).scanIterator({ + MATCH: pattern, + })) { keys.push(key); } const scanMs = Date.now() - scanStart; @@ -166,7 +169,9 @@ describe('ServerConfigsCacheRedis Performance Benchmark', () => { // Measure SCAN with noise const scanStart = Date.now(); const keys: string[] = []; - for await (const key of keyvRedisClient!.scanIterator({ MATCH: pattern })) { + for await (const key of (keyvRedisClient as RedisClientType).scanIterator({ + MATCH: pattern, + })) { keys.push(key); } const scanMs = Date.now() - scanStart; @@ -299,7 +304,9 @@ describe('ServerConfigsCacheRedis Performance Benchmark', () => { // First, discover keys via SCAN (same for both approaches) const pattern = `*MCP::ServersRegistry::Servers::${ns}:*`; const keys: string[] = []; - for await (const key of keyvRedisClient!.scanIterator({ MATCH: pattern })) { + for await (const key of (keyvRedisClient as RedisClientType).scanIterator({ + MATCH: pattern, + })) { keys.push(key); } diff --git 
a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedisAggregateKey.cache_integration.spec.ts b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedisAggregateKey.cache_integration.spec.ts index 5aeb49b206..4ec30187a2 100644 --- a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedisAggregateKey.cache_integration.spec.ts +++ b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedisAggregateKey.cache_integration.spec.ts @@ -261,7 +261,7 @@ describe('ServerConfigsCacheRedisAggregateKey Integration Tests', () => { await cache.getAll(); // Snapshot should be served; Redis should NOT have been called - expect(cacheGetSpy).not.toHaveBeenCalled(); + expect(cacheGetSpy.mock.calls).toHaveLength(0); cacheGetSpy.mockRestore(); }); @@ -330,7 +330,7 @@ describe('ServerConfigsCacheRedisAggregateKey Integration Tests', () => { // eslint-disable-next-line @typescript-eslint/no-explicit-any const cacheGetSpy = jest.spyOn((cache as any).cache, 'get'); const result = await cache.getAll(); - expect(cacheGetSpy).toHaveBeenCalledTimes(1); + expect(cacheGetSpy.mock.calls).toHaveLength(1); expect(Object.keys(result).length).toBe(1); cacheGetSpy.mockRestore(); }); diff --git a/packages/api/src/middleware/capabilities.ts b/packages/api/src/middleware/capabilities.ts index 28d3a0f76e..a3f1fe9038 100644 --- a/packages/api/src/middleware/capabilities.ts +++ b/packages/api/src/middleware/capabilities.ts @@ -9,7 +9,7 @@ import { import type { PrincipalType } from 'librechat-data-provider'; import type { SystemCapability, ConfigSection } from '@librechat/data-schemas'; import type { NextFunction, Response } from 'express'; -import type { Types } from 'mongoose'; +import type { Types, ClientSession } from 'mongoose'; import type { ServerRequest } from '~/types/http'; interface ResolvedPrincipal { @@ -18,7 +18,10 @@ interface ResolvedPrincipal { } interface CapabilityDeps { - getUserPrincipals: (params: { userId: string; role: string }) 
=> Promise; + getUserPrincipals: ( + params: { userId: string | Types.ObjectId; role?: string | null }, + session?: ClientSession, + ) => Promise; hasCapabilityForPrincipals: (params: { principals: ResolvedPrincipal[]; capability: SystemCapability; diff --git a/packages/api/src/middleware/preAuthTenant.spec.ts b/packages/api/src/middleware/preAuthTenant.spec.ts index ed35da2324..669a43c84f 100644 --- a/packages/api/src/middleware/preAuthTenant.spec.ts +++ b/packages/api/src/middleware/preAuthTenant.spec.ts @@ -13,7 +13,7 @@ jest.mock('@librechat/data-schemas', () => ({ })); describe('preAuthTenantMiddleware', () => { - let req: Partial; + let req: { headers: Record; ip?: string; path?: string }; let res: Partial; beforeEach(() => { diff --git a/packages/api/src/types/es2024-string.d.ts b/packages/api/src/types/es2024-string.d.ts new file mode 100644 index 0000000000..f25bc46bda --- /dev/null +++ b/packages/api/src/types/es2024-string.d.ts @@ -0,0 +1,4 @@ +/** String.prototype.isWellFormed โ€” ES2024 API, available in Node 20+ but absent from TS 5.3 lib */ +interface String { + isWellFormed(): boolean; +} diff --git a/packages/api/src/utils/env.ts b/packages/api/src/utils/env.ts index adeeb24b34..f71a131c09 100644 --- a/packages/api/src/utils/env.ts +++ b/packages/api/src/utils/env.ts @@ -84,12 +84,12 @@ export function encodeHeaderValue(value: string): string { */ export function createSafeUser( user: IUser | null | undefined, -): Partial & { federatedTokens?: unknown } { +): Partial & { federatedTokens?: IUser['federatedTokens'] } { if (!user) { return {}; } - const safeUser: Partial & { federatedTokens?: unknown } = {}; + const safeUser: Partial & { federatedTokens?: IUser['federatedTokens'] } = {}; for (const field of ALLOWED_USER_FIELDS) { if (field in user) { safeUser[field] = user[field]; diff --git a/packages/api/src/utils/graph.spec.ts b/packages/api/src/utils/graph.spec.ts index 4f1fa14983..91f8a29eff 100644 --- a/packages/api/src/utils/graph.spec.ts +++ 
b/packages/api/src/utils/graph.spec.ts @@ -1,4 +1,4 @@ -import type { TUser } from 'librechat-data-provider'; +import type { IUser } from '@librechat/data-schemas'; import type { GraphTokenResolver, GraphTokenOptions } from './graph'; import { containsGraphTokenPlaceholder, @@ -94,9 +94,9 @@ describe('Graph Token Utilities', () => { }); it('should return false for non-object values', () => { - expect(recordContainsGraphTokenPlaceholder('string' as unknown as Record)).toBe( - false, - ); + expect( + recordContainsGraphTokenPlaceholder('string' as unknown as Record), + ).toBe(false); }); }); @@ -141,7 +141,7 @@ describe('Graph Token Utilities', () => { }); describe('resolveGraphTokenPlaceholder', () => { - const mockUser: Partial = { + const mockUser: Partial = { id: 'user-123', provider: 'openid', openidId: 'oidc-sub-456', @@ -157,7 +157,7 @@ describe('Graph Token Utilities', () => { it('should return original value when no placeholder is present', async () => { const value = 'Bearer static-token'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, }); expect(result).toBe('Bearer static-token'); @@ -174,7 +174,7 @@ describe('Graph Token Utilities', () => { it('should return original value when graphTokenResolver is not provided', async () => { const value = 'Bearer {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, }); expect(result).toBe(value); }); @@ -184,7 +184,7 @@ describe('Graph Token Utilities', () => { const value = 'Bearer {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, }); expect(result).toBe(value); @@ -196,7 +196,7 @@ describe('Graph Token Utilities', () => { const value = 'Bearer 
{{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, }); expect(result).toBe(value); @@ -208,7 +208,7 @@ describe('Graph Token Utilities', () => { const value = 'Bearer {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, }); expect(result).toBe(value); @@ -220,7 +220,7 @@ describe('Graph Token Utilities', () => { const value = 'Bearer {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, }); expect(result).toBe('Bearer resolved-graph-token'); @@ -233,7 +233,7 @@ describe('Graph Token Utilities', () => { const value = 'Primary: {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}, Secondary: {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, }); expect(result).toBe('Primary: resolved-graph-token, Secondary: resolved-graph-token'); @@ -242,11 +242,13 @@ describe('Graph Token Utilities', () => { it('should return original value when graph token exchange fails', async () => { mockExtractOpenIDTokenInfo.mockReturnValue({ accessToken: 'access-token' }); mockIsOpenIDTokenValid.mockReturnValue(true); - const failingResolver: GraphTokenResolver = jest.fn().mockRejectedValue(new Error('Exchange failed')); + const failingResolver: GraphTokenResolver = jest + .fn() + .mockRejectedValue(new Error('Exchange failed')); const value = 'Bearer {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, 
graphTokenResolver: failingResolver, }); expect(result).toBe(value); @@ -259,7 +261,7 @@ describe('Graph Token Utilities', () => { const value = 'Bearer {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; const result = await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: emptyResolver, }); expect(result).toBe(value); @@ -271,7 +273,7 @@ describe('Graph Token Utilities', () => { const value = 'Bearer {{LIBRECHAT_GRAPH_ACCESS_TOKEN}}'; await resolveGraphTokenPlaceholder(value, { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, scopes: 'custom-scope', }); @@ -286,7 +288,7 @@ describe('Graph Token Utilities', () => { }); describe('resolveGraphTokensInRecord', () => { - const mockUser: Partial = { + const mockUser: Partial = { id: 'user-123', provider: 'openid', }; @@ -299,7 +301,7 @@ describe('Graph Token Utilities', () => { }); const options: GraphTokenOptions = { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, }; @@ -348,7 +350,7 @@ describe('Graph Token Utilities', () => { }); describe('preProcessGraphTokens', () => { - const mockUser: Partial = { + const mockUser: Partial = { id: 'user-123', provider: 'openid', }; @@ -361,7 +363,7 @@ describe('Graph Token Utilities', () => { }); const graphOptions: GraphTokenOptions = { - user: mockUser as TUser, + user: mockUser as Partial as IUser, graphTokenResolver: mockGraphTokenResolver, }; diff --git a/packages/api/src/utils/oidc.spec.ts b/packages/api/src/utils/oidc.spec.ts index 0d7216304b..e7088d9897 100644 --- a/packages/api/src/utils/oidc.spec.ts +++ b/packages/api/src/utils/oidc.spec.ts @@ -1,10 +1,10 @@ import { extractOpenIDTokenInfo, isOpenIDTokenValid, processOpenIDPlaceholders } from './oidc'; -import type { TUser } from 'librechat-data-provider'; +import type { IUser } from '@librechat/data-schemas'; describe('OpenID Token Utilities', 
() => { describe('extractOpenIDTokenInfo', () => { it('should extract token info from user with federatedTokens', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'openid', openidId: 'oidc-sub-456', @@ -36,7 +36,7 @@ describe('OpenID Token Utilities', () => { }); it('should return null when user is not OpenID provider', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'email', }; @@ -46,7 +46,7 @@ describe('OpenID Token Utilities', () => { }); it('should return token info when user has no federatedTokens but is OpenID provider', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'openid', openidId: 'oidc-sub-456', @@ -66,7 +66,7 @@ describe('OpenID Token Utilities', () => { }); it('should extract partial token info when some tokens are missing', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'openid', openidId: 'oidc-sub-456', @@ -89,7 +89,7 @@ describe('OpenID Token Utilities', () => { }); it('should prioritize openidId over regular id', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'openid', openidId: 'oidc-sub-456', @@ -104,7 +104,7 @@ describe('OpenID Token Utilities', () => { }); it('should fall back to regular id when openidId is not available', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'openid', federatedTokens: { @@ -397,7 +397,7 @@ describe('OpenID Token Utilities', () => { describe('Integration: Full OpenID Token Flow', () => { it('should extract, validate, and process tokens correctly', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'openid', openidId: 'oidc-sub-456', @@ -428,7 +428,7 @@ describe('OpenID Token Utilities', () => { }); it('should resolve LIBRECHAT_OPENID_ID_TOKEN and LIBRECHAT_OPENID_ACCESS_TOKEN to different values', () => { - const user: Partial = { + const user: Partial 
= { id: 'user-123', provider: 'openid', openidId: 'oidc-sub-456', @@ -457,7 +457,7 @@ describe('OpenID Token Utilities', () => { }); it('should handle expired tokens correctly', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'openid', openidId: 'oidc-sub-456', @@ -481,7 +481,7 @@ describe('OpenID Token Utilities', () => { }); it('should handle user with no federatedTokens but still has OpenID provider', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'openid', openidId: 'oidc-sub-456', @@ -499,7 +499,7 @@ describe('OpenID Token Utilities', () => { }); it('should handle non-OpenID users', () => { - const user: Partial = { + const user: Partial = { id: 'user-123', provider: 'email', }; diff --git a/packages/api/src/utils/oidc.ts b/packages/api/src/utils/oidc.ts index dbf41818c4..51056406c1 100644 --- a/packages/api/src/utils/oidc.ts +++ b/packages/api/src/utils/oidc.ts @@ -1,5 +1,5 @@ import { logger } from '@librechat/data-schemas'; -import type { IUser } from '@librechat/data-schemas'; +import type { IUser, OIDCTokens } from '@librechat/data-schemas'; export interface OpenIDTokenInfo { accessToken?: string; @@ -11,14 +11,7 @@ export interface OpenIDTokenInfo { claims?: Record; } -interface FederatedTokens { - access_token?: string; - id_token?: string; - refresh_token?: string; - expires_at?: number; -} - -function isFederatedTokens(obj: unknown): obj is FederatedTokens { +function isFederatedTokens(obj: unknown): obj is OIDCTokens { if (!obj || typeof obj !== 'object') { return false; } @@ -61,23 +54,24 @@ export function extractOpenIDTokenInfo( const tokenInfo: OpenIDTokenInfo = {}; - if ('federatedTokens' in user && isFederatedTokens(user.federatedTokens)) { - const tokens = user.federatedTokens; + const federated = user.federatedTokens; + const openid = user.openidTokens; + + if (federated && isFederatedTokens(federated)) { logger.debug('[extractOpenIDTokenInfo] Found 
federatedTokens:', { - has_access_token: !!tokens.access_token, - has_id_token: !!tokens.id_token, - has_refresh_token: !!tokens.refresh_token, - expires_at: tokens.expires_at, + has_access_token: !!federated.access_token, + has_id_token: !!federated.id_token, + has_refresh_token: !!federated.refresh_token, + expires_at: federated.expires_at, }); - tokenInfo.accessToken = tokens.access_token; - tokenInfo.idToken = tokens.id_token; - tokenInfo.expiresAt = tokens.expires_at; - } else if ('openidTokens' in user && isFederatedTokens(user.openidTokens)) { - const tokens = user.openidTokens; + tokenInfo.accessToken = federated.access_token; + tokenInfo.idToken = federated.id_token; + tokenInfo.expiresAt = federated.expires_at; + } else if (openid && isFederatedTokens(openid)) { logger.debug('[extractOpenIDTokenInfo] Found openidTokens'); - tokenInfo.accessToken = tokens.access_token; - tokenInfo.idToken = tokens.id_token; - tokenInfo.expiresAt = tokens.expires_at; + tokenInfo.accessToken = openid.access_token; + tokenInfo.idToken = openid.id_token; + tokenInfo.expiresAt = openid.expires_at; } tokenInfo.userId = user.openidId || user.id; diff --git a/packages/api/types/index.d.ts b/packages/api/types/index.d.ts new file mode 100644 index 0000000000..f25bc46bda --- /dev/null +++ b/packages/api/types/index.d.ts @@ -0,0 +1,4 @@ +/** String.prototype.isWellFormed โ€” ES2024 API, available in Node 20+ but absent from TS 5.3 lib */ +interface String { + isWellFormed(): boolean; +} diff --git a/packages/client/src/components/OGDialogTemplate.tsx b/packages/client/src/components/OGDialogTemplate.tsx index 300ae5b194..2414915a4b 100644 --- a/packages/client/src/components/OGDialogTemplate.tsx +++ b/packages/client/src/components/OGDialogTemplate.tsx @@ -80,9 +80,8 @@ const OGDialogTemplate = forwardRef((props: DialogTemplateProps, ref: Ref; + file?: Partial & { progress?: number }; fileType: { fill: string; paths: React.FC; diff --git 
a/packages/data-schemas/src/methods/aclEntry.ts b/packages/data-schemas/src/methods/aclEntry.ts index 2f61861029..d93693641c 100644 --- a/packages/data-schemas/src/methods/aclEntry.ts +++ b/packages/data-schemas/src/methods/aclEntry.ts @@ -7,7 +7,7 @@ import type { DeleteResult, Model, } from 'mongoose'; -import type { IAclEntry } from '~/types'; +import type { AclEntry, IAclEntry } from '~/types'; import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; export function createAclEntryMethods(mongoose: typeof import('mongoose')) { @@ -375,11 +375,11 @@ export function createAclEntryMethods(mongoose: typeof import('mongoose')) { * @param options - Optional query options (e.g., { session }) */ async function bulkWriteAclEntries( - ops: AnyBulkWriteOperation[], + ops: AnyBulkWriteOperation[], options?: { session?: ClientSession }, ) { const AclEntry = mongoose.models.AclEntry as Model; - return tenantSafeBulkWrite(AclEntry, ops, options || {}); + return tenantSafeBulkWrite(AclEntry, ops as AnyBulkWriteOperation[], options || {}); } /** diff --git a/packages/data-schemas/src/methods/config.spec.ts b/packages/data-schemas/src/methods/config.spec.ts index 82f43c2b37..8bcf73a733 100644 --- a/packages/data-schemas/src/methods/config.spec.ts +++ b/packages/data-schemas/src/methods/config.spec.ts @@ -70,7 +70,7 @@ describe('upsertConfig', () => { PrincipalType.ROLE, 'admin', PrincipalModel.ROLE, - { a: 1 }, + { interface: { endpointsMenu: true } }, 10, ); @@ -78,7 +78,7 @@ describe('upsertConfig', () => { PrincipalType.ROLE, 'admin', PrincipalModel.ROLE, - { a: 2 }, + { interface: { endpointsMenu: false } }, 10, ); @@ -88,7 +88,7 @@ describe('upsertConfig', () => { it('normalizes ObjectId principalId to string', async () => { const oid = new Types.ObjectId(); - await methods.upsertConfig(PrincipalType.USER, oid, PrincipalModel.USER, { test: true }, 100); + await methods.upsertConfig(PrincipalType.USER, oid, PrincipalModel.USER, { cache: true }, 100); const found = await 
methods.findConfigByPrincipal(PrincipalType.USER, oid.toString()); expect(found).toBeTruthy(); @@ -98,7 +98,13 @@ describe('upsertConfig', () => { describe('findConfigByPrincipal', () => { it('finds an active config', async () => { - await methods.upsertConfig(PrincipalType.ROLE, 'admin', PrincipalModel.ROLE, { x: 1 }, 10); + await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { cache: true }, + 10, + ); const result = await methods.findConfigByPrincipal(PrincipalType.ROLE, 'admin'); expect(result).toBeTruthy(); @@ -111,7 +117,13 @@ describe('findConfigByPrincipal', () => { }); it('does not find inactive configs', async () => { - await methods.upsertConfig(PrincipalType.ROLE, 'admin', PrincipalModel.ROLE, { x: 1 }, 10); + await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { cache: true }, + 10, + ); await methods.toggleConfigActive(PrincipalType.ROLE, 'admin', false); const result = await methods.findConfigByPrincipal(PrincipalType.ROLE, 'admin'); @@ -155,7 +167,13 @@ describe('listAllConfigs', () => { describe('getApplicableConfigs', () => { it('always includes the __base__ config', async () => { - await methods.upsertConfig(PrincipalType.ROLE, '__base__', PrincipalModel.ROLE, { a: 1 }, 0); + await methods.upsertConfig( + PrincipalType.ROLE, + '__base__', + PrincipalModel.ROLE, + { cache: true }, + 0, + ); const configs = await methods.getApplicableConfigs([]); expect(configs).toHaveLength(1); @@ -163,9 +181,27 @@ describe('getApplicableConfigs', () => { }); it('returns base + matching principals', async () => { - await methods.upsertConfig(PrincipalType.ROLE, '__base__', PrincipalModel.ROLE, { a: 1 }, 0); - await methods.upsertConfig(PrincipalType.ROLE, 'admin', PrincipalModel.ROLE, { b: 2 }, 10); - await methods.upsertConfig(PrincipalType.ROLE, 'user', PrincipalModel.ROLE, { c: 3 }, 10); + await methods.upsertConfig( + PrincipalType.ROLE, + '__base__', + PrincipalModel.ROLE, + { cache: true }, 
+ 0, + ); + await methods.upsertConfig( + PrincipalType.ROLE, + 'admin', + PrincipalModel.ROLE, + { version: '2' }, + 10, + ); + await methods.upsertConfig( + PrincipalType.ROLE, + 'user', + PrincipalModel.ROLE, + { version: '3' }, + 10, + ); const configs = await methods.getApplicableConfigs([ { principalType: PrincipalType.ROLE, principalId: 'admin' }, diff --git a/packages/data-schemas/src/methods/role.methods.spec.ts b/packages/data-schemas/src/methods/role.methods.spec.ts index f8a66bef5d..be75be7b6f 100644 --- a/packages/data-schemas/src/methods/role.methods.spec.ts +++ b/packages/data-schemas/src/methods/role.methods.spec.ts @@ -285,12 +285,12 @@ describe('updateAccessPermissions', () => { const updatedRole = await getRoleByName(SystemRoles.USER); // SHARED_GLOBAL=true โ†’ SHARE=true (inherited) - expect(updatedRole.permissions[PermissionTypes.PROMPTS].SHARE).toBe(true); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]!.SHARE).toBe(true); // SHARED_GLOBAL=false โ†’ SHARE=false (inherited) - expect(updatedRole.permissions[PermissionTypes.AGENTS].SHARE).toBe(false); + expect(updatedRole.permissions[PermissionTypes.AGENTS]!.SHARE).toBe(false); // SHARED_GLOBAL cleaned up - expect(updatedRole.permissions[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBeUndefined(); - expect(updatedRole.permissions[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeUndefined(); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]).not.toHaveProperty('SHARED_GLOBAL'); + expect(updatedRole.permissions[PermissionTypes.AGENTS]).not.toHaveProperty('SHARED_GLOBAL'); }); it('should respect explicit SHARE in update payload and not override it with SHARED_GLOBAL', async () => { @@ -309,8 +309,8 @@ describe('updateAccessPermissions', () => { const updatedRole = await getRoleByName(SystemRoles.USER); - expect(updatedRole.permissions[PermissionTypes.PROMPTS].SHARE).toBe(false); - expect(updatedRole.permissions[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBeUndefined(); + 
expect(updatedRole.permissions[PermissionTypes.PROMPTS]!.SHARE).toBe(false); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]).not.toHaveProperty('SHARED_GLOBAL'); }); it('should migrate SHARED_GLOBAL to SHARE even when the permType is not in the update payload', async () => { @@ -336,13 +336,13 @@ describe('updateAccessPermissions', () => { const updatedRole = await getRoleByName(SystemRoles.USER); // SHARE should have been inherited from SHARED_GLOBAL, not silently dropped - expect(updatedRole.permissions[PermissionTypes.PROMPTS].SHARE).toBe(true); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]!.SHARE).toBe(true); // SHARED_GLOBAL should be removed - expect(updatedRole.permissions[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBeUndefined(); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]).not.toHaveProperty('SHARED_GLOBAL'); // Original USE should be untouched - expect(updatedRole.permissions[PermissionTypes.PROMPTS].USE).toBe(true); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]!.USE).toBe(true); // The actual update should have applied - expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO].USE).toBe(true); + expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO]!.USE).toBe(true); }); it('should remove orphaned SHARED_GLOBAL when SHARE already exists and permType is not in update', async () => { @@ -366,9 +366,9 @@ describe('updateAccessPermissions', () => { const updatedRole = await getRoleByName(SystemRoles.USER); - expect(updatedRole.permissions[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBeUndefined(); - expect(updatedRole.permissions[PermissionTypes.PROMPTS].SHARE).toBe(true); - expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO].USE).toBe(true); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]).not.toHaveProperty('SHARED_GLOBAL'); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]!.SHARE).toBe(true); + expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO]!.USE).toBe(true); }); 
it('should not update MULTI_CONVO permissions when no changes are needed', async () => { diff --git a/packages/data-schemas/src/methods/role.ts b/packages/data-schemas/src/methods/role.ts index 442041dcde..e84b91420a 100644 --- a/packages/data-schemas/src/methods/role.ts +++ b/packages/data-schemas/src/methods/role.ts @@ -69,7 +69,7 @@ export function createRoleMethods(mongoose: typeof import('mongoose'), deps: Rol limit?: number; offset?: number; }): Promise[]> { - const Role = mongoose.models.Role; + const Role = mongoose.models.Role as Model; const limit = options?.limit ?? 50; const offset = options?.offset ?? 0; return await Role.find({}) diff --git a/packages/data-schemas/src/methods/spendTokens.spec.ts b/packages/data-schemas/src/methods/spendTokens.spec.ts index 5730bc7bdd..d505663d57 100644 --- a/packages/data-schemas/src/methods/spendTokens.spec.ts +++ b/packages/data-schemas/src/methods/spendTokens.spec.ts @@ -864,8 +864,8 @@ describe('spendTokens', () => { const expectedCompletionCost = tokenUsage.completionTokens * premiumCompletionRate; expect(result).not.toBeNull(); - expect(result!.prompt.prompt).toBeCloseTo(-expectedPromptCost, 0); - expect(result!.completion.completion).toBeCloseTo(-expectedCompletionCost, 0); + expect(result!.prompt!.prompt).toBeCloseTo(-expectedPromptCost, 0); + expect(result!.completion!.completion).toBeCloseTo(-expectedCompletionCost, 0); }); it('should charge standard rates for structured tokens when below threshold', async () => { @@ -907,8 +907,8 @@ describe('spendTokens', () => { const expectedCompletionCost = tokenUsage.completionTokens * standardCompletionRate; expect(result).not.toBeNull(); - expect(result!.prompt.prompt).toBeCloseTo(-expectedPromptCost, 0); - expect(result!.completion.completion).toBeCloseTo(-expectedCompletionCost, 0); + expect(result!.prompt!.prompt).toBeCloseTo(-expectedPromptCost, 0); + expect(result!.completion!.completion).toBeCloseTo(-expectedCompletionCost, 0); }); it('should charge standard 
rates for gemini-3.1-pro-preview when prompt tokens are below threshold', async () => { @@ -937,7 +937,7 @@ describe('spendTokens', () => { completionTokens * tokenValues['gemini-3.1'].completion; const balance = await Balance.findOne({ user: userId }); - expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + expect(balance!.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); }); it('should charge premium rates for gemini-3.1-pro-preview when prompt tokens exceed threshold', async () => { @@ -966,7 +966,7 @@ describe('spendTokens', () => { completionTokens * premiumTokenValues['gemini-3.1'].completion; const balance = await Balance.findOne({ user: userId }); - expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + expect(balance!.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); }); it('should charge premium rates for gemini-3.1-pro-preview-customtools when prompt tokens exceed threshold', async () => { @@ -995,7 +995,7 @@ describe('spendTokens', () => { completionTokens * premiumTokenValues['gemini-3.1'].completion; const balance = await Balance.findOne({ user: userId }); - expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + expect(balance!.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); }); it('should charge premium rates for structured gemini-3.1 tokens when total input exceeds threshold', async () => { @@ -1032,13 +1032,13 @@ describe('spendTokens', () => { const expectedPromptCost = tokenUsage.promptTokens.input * premiumPromptRate + - tokenUsage.promptTokens.write * writeRate + - tokenUsage.promptTokens.read * readRate; + tokenUsage.promptTokens.write * writeRate! 
+ + tokenUsage.promptTokens.read * readRate!; const expectedCompletionCost = tokenUsage.completionTokens * premiumCompletionRate; expect(result).not.toBeNull(); - expect(result!.prompt.prompt).toBeCloseTo(-expectedPromptCost, 0); - expect(result!.completion.completion).toBeCloseTo(-expectedCompletionCost, 0); + expect(result!.prompt!.prompt).toBeCloseTo(-expectedPromptCost, 0); + expect(result!.completion!.completion).toBeCloseTo(-expectedCompletionCost, 0); }); it('should not apply premium pricing to non-premium models regardless of prompt size', async () => { diff --git a/packages/data-schemas/src/methods/tx.ts b/packages/data-schemas/src/methods/tx.ts index a1be4190ba..a048874457 100644 --- a/packages/data-schemas/src/methods/tx.ts +++ b/packages/data-schemas/src/methods/tx.ts @@ -387,7 +387,7 @@ export function createTxMethods(_mongoose: typeof import('mongoose'), txDeps: Tx function getPremiumRate( valueKey: string, tokenType: string, - inputTokenCount?: number, + inputTokenCount?: number | null, ): number | null { if (inputTokenCount == null) { return null; diff --git a/packages/data-schemas/src/methods/user.test.ts b/packages/data-schemas/src/methods/user.test.ts index 522e4fe158..5e557805e4 100644 --- a/packages/data-schemas/src/methods/user.test.ts +++ b/packages/data-schemas/src/methods/user.test.ts @@ -18,7 +18,7 @@ describe('User Methods', () => { describe('generateToken', () => { const mockUser = { - _id: 'user123', + _id: new mongoose.Types.ObjectId('aaaaaaaaaaaaaaaaaaaaaaaa'), username: 'testuser', provider: 'local', email: 'test@example.com', diff --git a/packages/data-schemas/src/methods/user.ts b/packages/data-schemas/src/methods/user.ts index 137c01d0cd..0b630e49b3 100644 --- a/packages/data-schemas/src/methods/user.ts +++ b/packages/data-schemas/src/methods/user.ts @@ -35,26 +35,26 @@ export function createUserMethods(mongoose: typeof import('mongoose')) { searchCriteria: FilterQuery, fieldsToSelect?: string | string[] | null, ): Promise { - 
const User = mongoose.models.User; + const User = mongoose.models.User as mongoose.Model; const normalizedCriteria = normalizeEmailInCriteria(searchCriteria); const query = User.findOne(normalizedCriteria); if (fieldsToSelect) { query.select(fieldsToSelect); } - return (await query.lean()) as IUser | null; + return await query.lean(); } async function findUsers( searchCriteria: FilterQuery, fieldsToSelect?: string | string[] | null, ): Promise { - const User = mongoose.models.User; + const User = mongoose.models.User as mongoose.Model; const normalizedCriteria = normalizeEmailInCriteria(searchCriteria); const query = User.find(normalizedCriteria); if (fieldsToSelect) { query.select(fieldsToSelect); } - return (await query.lean()) as IUser[]; + return await query.lean(); } /** @@ -301,8 +301,6 @@ export function createUserMethods(mongoose: typeof import('mongoose')) { .sort((a, b) => b._searchScore - a._searchScore) .slice(0, limit) .map((user) => { - // Remove the search score from final results - // eslint-disable-next-line @typescript-eslint/no-unused-vars const { _searchScore, ...userWithoutScore } = user; return userWithoutScore; }); diff --git a/packages/data-schemas/src/methods/userGroup.spec.ts b/packages/data-schemas/src/methods/userGroup.spec.ts index 675fdb2592..ca83ced7d9 100644 --- a/packages/data-schemas/src/methods/userGroup.spec.ts +++ b/packages/data-schemas/src/methods/userGroup.spec.ts @@ -496,7 +496,7 @@ describe('userGroup methods', () => { it('returns the updated user document', async () => { const user = await createTestUser({ idOnTheSource: 'user-ext-1' }); const { user: updatedUser } = await methods.syncUserEntraGroups(user._id, []); - expect(updatedUser._id.toString()).toBe(user._id.toString()); + expect((updatedUser._id as Types.ObjectId).toString()).toBe(user._id.toString()); }); }); diff --git a/packages/data-schemas/src/methods/userGroup.ts b/packages/data-schemas/src/methods/userGroup.ts index 948542e6de..0e6b57adb2 100644 --- 
a/packages/data-schemas/src/methods/userGroup.ts +++ b/packages/data-schemas/src/methods/userGroup.ts @@ -265,13 +265,14 @@ export function createUserGroupMethods(mongoose: typeof import('mongoose')) { role?: string | null; }, session?: ClientSession, - ): Promise> { + ): Promise> { const { userId, role } = params; /** `userId` must be an `ObjectId` for USER principal since ACL entries store `ObjectId`s */ const userObjectId = typeof userId === 'string' ? new Types.ObjectId(userId) : userId; - const principals: Array<{ principalType: string; principalId?: string | Types.ObjectId }> = [ - { principalType: PrincipalType.USER, principalId: userObjectId }, - ]; + const principals: Array<{ + principalType: PrincipalType; + principalId?: string | Types.ObjectId; + }> = [{ principalType: PrincipalType.USER, principalId: userObjectId }]; // If role is not provided, query user to get it let userRole = role; diff --git a/packages/data-schemas/src/migrations/promptGroupIndexes.ts b/packages/data-schemas/src/migrations/promptGroupIndexes.ts index 4b6013c9e4..2d389f3f09 100644 --- a/packages/data-schemas/src/migrations/promptGroupIndexes.ts +++ b/packages/data-schemas/src/migrations/promptGroupIndexes.ts @@ -18,7 +18,7 @@ export async function dropSupersededPromptGroupIndexes( let collection; try { - collection = connection.db.collection(collectionName); + collection = connection.db!.collection(collectionName); } catch { result.skipped.push( ...SUPERSEDED_PROMPT_GROUP_INDEXES.map( diff --git a/packages/data-schemas/src/migrations/tenantIndexes.spec.ts b/packages/data-schemas/src/migrations/tenantIndexes.spec.ts index e62b587a6e..6a0987d757 100644 --- a/packages/data-schemas/src/migrations/tenantIndexes.spec.ts +++ b/packages/data-schemas/src/migrations/tenantIndexes.spec.ts @@ -1,4 +1,4 @@ -import mongoose, { Schema } from 'mongoose'; +import mongoose from 'mongoose'; import { MongoMemoryServer } from 'mongodb-memory-server'; import { dropSupersededTenantIndexes, 
SUPERSEDED_INDEXES } from './tenantIndexes'; @@ -24,7 +24,7 @@ afterAll(async () => { describe('dropSupersededTenantIndexes', () => { describe('with pre-existing single-field unique indexes (simulates upgrade)', () => { beforeAll(async () => { - const db = mongoose.connection.db; + const db = mongoose.connection.db!; await db.createCollection('users'); const users = db.collection('users'); @@ -133,7 +133,7 @@ describe('dropSupersededTenantIndexes', () => { }); it('old unique indexes are actually gone from users collection', async () => { - const indexes = await mongoose.connection.db.collection('users').indexes(); + const indexes = await mongoose.connection.db!.collection('users').indexes(); const indexNames = indexes.map((idx) => idx.name); expect(indexNames).not.toContain('email_1'); @@ -143,14 +143,14 @@ describe('dropSupersededTenantIndexes', () => { }); it('old unique indexes are actually gone from roles collection', async () => { - const indexes = await mongoose.connection.db.collection('roles').indexes(); + const indexes = await mongoose.connection.db!.collection('roles').indexes(); const indexNames = indexes.map((idx) => idx.name); expect(indexNames).not.toContain('name_1'); }); it('old compound unique indexes are gone from conversations collection', async () => { - const indexes = await mongoose.connection.db.collection('conversations').indexes(); + const indexes = await mongoose.connection.db!.collection('conversations').indexes(); const indexNames = indexes.map((idx) => idx.name); expect(indexNames).not.toContain('conversationId_1_user_1'); @@ -159,7 +159,7 @@ describe('dropSupersededTenantIndexes', () => { describe('multi-tenant writes after migration', () => { beforeAll(async () => { - const db = mongoose.connection.db; + const db = mongoose.connection.db!; const users = db.collection('users'); await users.createIndex( @@ -169,7 +169,7 @@ describe('dropSupersededTenantIndexes', () => { }); it('allows same email in different tenants after old index is 
dropped', async () => { - const users = mongoose.connection.db.collection('users'); + const users = mongoose.connection.db!.collection('users'); await users.insertOne({ email: 'shared@example.com', @@ -196,7 +196,7 @@ describe('dropSupersededTenantIndexes', () => { }); it('still rejects duplicate email within same tenant', async () => { - const users = mongoose.connection.db.collection('users'); + const users = mongoose.connection.db!.collection('users'); await users.insertOne({ email: 'unique-within@example.com', @@ -247,7 +247,7 @@ describe('dropSupersededTenantIndexes', () => { partialConnection = mongoose.createConnection(partialServer.getUri()); await partialConnection.asPromise(); - const db = partialConnection.db; + const db = partialConnection.db!; await db.createCollection('users'); await db.collection('users').createIndex({ email: 1 }, { unique: true, name: 'email_1' }); }); diff --git a/packages/data-schemas/src/migrations/tenantIndexes.ts b/packages/data-schemas/src/migrations/tenantIndexes.ts index a8b4e51768..6536423ad2 100644 --- a/packages/data-schemas/src/migrations/tenantIndexes.ts +++ b/packages/data-schemas/src/migrations/tenantIndexes.ts @@ -55,7 +55,7 @@ export async function dropSupersededTenantIndexes( const result: MigrationResult = { dropped: [], skipped: [], errors: [] }; for (const [collectionName, indexNames] of Object.entries(SUPERSEDED_INDEXES)) { - const collection = connection.db.collection(collectionName); + const collection = connection.db!.collection(collectionName); let existingIndexes: Array<{ name?: string }>; try { diff --git a/packages/data-schemas/src/types/admin.ts b/packages/data-schemas/src/types/admin.ts index 9b30cdb98a..a16f68ae9c 100644 --- a/packages/data-schemas/src/types/admin.ts +++ b/packages/data-schemas/src/types/admin.ts @@ -1,10 +1,4 @@ -import type { - PrincipalType, - PrincipalModel, - TCustomConfig, - z, - configSchema, -} from 'librechat-data-provider'; +import type { PrincipalType, PrincipalModel, 
TCustomConfig } from 'librechat-data-provider'; import type { SystemCapabilities } from '~/admin/capabilities'; /* โ”€โ”€ Capability types โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ */ @@ -16,7 +10,7 @@ export type BaseSystemCapability = (typeof SystemCapabilities)[keyof typeof Syst export type ConfigAssignTarget = 'user' | 'group' | 'role'; /** Top-level keys of the configSchema from librechat.yaml. */ -export type ConfigSection = keyof z.infer; +export type ConfigSection = string & keyof TCustomConfig; /** Section-level config capabilities derived from configSchema keys. */ type ConfigSectionCapability = `manage:configs:${ConfigSection}` | `read:configs:${ConfigSection}`; diff --git a/packages/data-schemas/src/types/user.ts b/packages/data-schemas/src/types/user.ts index 0fac46ee63..2d8eb82f47 100644 --- a/packages/data-schemas/src/types/user.ts +++ b/packages/data-schemas/src/types/user.ts @@ -2,6 +2,7 @@ import type { Document, Types } from 'mongoose'; import { CursorPaginationParams } from '~/common'; export interface IUser extends Document { + _id: Types.ObjectId; name?: string; username?: string; email: string; @@ -50,6 +51,15 @@ export interface IUser extends Document { /** Field for external source identification (for consistency with TPrincipal schema) */ idOnTheSource?: string; tenantId?: string; + federatedTokens?: OIDCTokens; + openidTokens?: OIDCTokens; +} + +export interface OIDCTokens { + access_token?: string; + id_token?: string; + refresh_token?: string; + expires_at?: number; } export interface BalanceConfig { diff --git a/packages/data-schemas/src/types/winston-transports.d.ts b/packages/data-schemas/src/types/winston-transports.d.ts new file mode 100644 index 0000000000..704486e5ce --- /dev/null +++ b/packages/data-schemas/src/types/winston-transports.d.ts @@ -0,0 +1,34 @@ +import type TransportStream from 'winston-transport'; + +/** + * 
Module augmentation for winston's transports namespace. + * + * `winston-daily-rotate-file` ships its own augmentation targeting + * `'winston/lib/winston/transports'`, but it fails when winston and + * winston-daily-rotate-file resolve from different node_modules trees + * (which happens in this monorepo due to npm hoisting). This local + * declaration bridges the gap so `tsc --noEmit` passes. + */ +declare module 'winston/lib/winston/transports' { + interface Transports { + DailyRotateFile: new ( + opts?: { + level?: string; + filename?: string; + datePattern?: string; + zippedArchive?: boolean; + maxSize?: string | number; + maxFiles?: string | number; + dirname?: string; + stream?: NodeJS.WritableStream; + frequency?: string; + utc?: boolean; + extension?: string; + createSymlink?: boolean; + symlinkName?: string; + auditFile?: string; + format?: import('logform').Format; + } & TransportStream.TransportStreamOptions, + ) => TransportStream; + } +} From f82d4300a4717f5ed035ef27c47f75559d0a8439 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 28 Mar 2026 23:44:58 -0400 Subject: [PATCH 16/18] =?UTF-8?q?=F0=9F=A7=B9=20chore:=20Remove=20Deprecat?= =?UTF-8?q?ed=20Gemini=202.0=20Models=20&=20Fix=20Mistral-Large-3=20Contex?= =?UTF-8?q?t=20Window=20(#12453)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore: remove deprecated Gemini 2.0 models from default models list Remove gemini-2.0-flash-001 and gemini-2.0-flash-lite from the Google default models array, as they have been deprecated by Google. Closes #12444 * fix: add mistral-large-3 max context tokens (256k) Add mistral-large-3 with 255000 max context tokens to the mistralModels map. Without this entry, the model falls back to the generic mistral-large key (131k), causing context window errors when using tools with Azure AI Foundry deployments. 
Closes #12429 * test: add mistral-large-3 token resolution tests and fix key ordering Add test coverage for mistral-large-3 context token resolution, verifying exact match, suffixed variants, and longest-match precedence over the generic mistral-large key. Reorder the mistral-large-3 entry after mistral-large to follow the file's documented convention of listing newer models last for reverse-scan performance. --- api/utils/tokens.spec.js | 54 ++++++++++++++++++++++++++++ packages/api/src/utils/tokens.ts | 1 + packages/data-provider/src/config.ts | 3 -- 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/api/utils/tokens.spec.js b/api/utils/tokens.spec.js index 6cecdb95c8..dfa6762ee5 100644 --- a/api/utils/tokens.spec.js +++ b/api/utils/tokens.spec.js @@ -1813,3 +1813,57 @@ describe('GLM Model Tests (Zhipu AI)', () => { }); }); }); + +describe('Mistral Model Tests', () => { + describe('getModelMaxTokens', () => { + test('should return correct tokens for mistral-large-3 (256k context)', () => { + expect(getModelMaxTokens('mistral-large-3', EModelEndpoint.custom)).toBe( + maxTokensMap[EModelEndpoint.custom]['mistral-large-3'], + ); + }); + + test('should match mistral-large-3 for suffixed variants', () => { + expect(getModelMaxTokens('mistral-large-3-instruct', EModelEndpoint.custom)).toBe( + maxTokensMap[EModelEndpoint.custom]['mistral-large-3'], + ); + }); + + test('should not match mistral-large-3 for generic mistral-large', () => { + expect(getModelMaxTokens('mistral-large', EModelEndpoint.custom)).toBe( + maxTokensMap[EModelEndpoint.custom]['mistral-large'], + ); + expect(getModelMaxTokens('mistral-large-latest', EModelEndpoint.custom)).toBe( + maxTokensMap[EModelEndpoint.custom]['mistral-large'], + ); + }); + }); + + describe('matchModelName', () => { + test('should match mistral-large-3 exactly', () => { + expect(matchModelName('mistral-large-3', EModelEndpoint.custom)).toBe('mistral-large-3'); + }); + + test('should match mistral-large-3 for 
prefixed/suffixed variants', () => { + expect(matchModelName('mistral/mistral-large-3', EModelEndpoint.custom)).toBe( + 'mistral-large-3', + ); + expect(matchModelName('mistral-large-3-instruct', EModelEndpoint.custom)).toBe( + 'mistral-large-3', + ); + }); + + test('should match generic mistral-large for non-3 variants', () => { + expect(matchModelName('mistral-large-latest', EModelEndpoint.custom)).toBe('mistral-large'); + }); + }); + + describe('findMatchingPattern', () => { + test('should prefer mistral-large-3 over mistral-large for mistral-large-3 variants', () => { + const result = findMatchingPattern( + 'mistral-large-3-instruct', + maxTokensMap[EModelEndpoint.custom], + ); + expect(result).toBe('mistral-large-3'); + }); + }); +}); diff --git a/packages/api/src/utils/tokens.ts b/packages/api/src/utils/tokens.ts index ae09da4f28..14215698a6 100644 --- a/packages/api/src/utils/tokens.ts +++ b/packages/api/src/utils/tokens.ts @@ -72,6 +72,7 @@ const mistralModels = { 'mistral-large-2402': 127500, 'mistral-large-2407': 127500, 'mistral-large': 131000, + 'mistral-large-3': 255000, 'mistral-saba': 32000, 'ministral-3b': 131000, 'ministral-8b': 131000, diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index e641d7b63a..bb89c56f82 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -1238,9 +1238,6 @@ export const defaultModels = { 'gemini-2.5-pro', 'gemini-2.5-flash', 'gemini-2.5-flash-lite', - // Gemini 2.0 Models - 'gemini-2.0-flash-001', - 'gemini-2.0-flash-lite', ], [EModelEndpoint.anthropic]: sharedAnthropicModels, [EModelEndpoint.openAI]: [ From 0d94881c2db0010c96340fafb806a710dd05a250 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 29 Mar 2026 01:10:57 -0400 Subject: [PATCH 17/18] =?UTF-8?q?=F0=9F=A7=B9=20refactor:=20Tighten=20Conf?= =?UTF-8?q?ig=20Schema=20Typing=20and=20Remove=20Deprecated=20Fields=20(#1?= =?UTF-8?q?2452)?= MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: Remove deprecated and unused fields from endpoint schemas - Remove summarize, summaryModel from endpointSchema and azureEndpointSchema - Remove plugins from azureEndpointSchema - Remove customOrder from endpointSchema and azureEndpointSchema - Remove baseURL from all and agents endpoint schemas - Type paramDefinitions with full SettingDefinition-based schema - Clean up summarize/summaryModel references in initialize.ts and config.spec.ts * refactor: Improve MCP transport schema typing - Add defaults to transport type discriminators (stdio, websocket, sse) - Type stderr field as IOType union instead of z.any() * refactor: Add narrowed preset schema for model specs - Create tModelSpecPresetSchema omitting system/DB/deprecated fields - Update tModelSpecSchema to use the narrowed preset schema * test: Add explicit type field to MCP test fixtures Add transport type discriminator to test objects that construct MCPOptions/ParsedServerConfig directly, required after type field changed from optional to default in schema definitions. * chore: Bump librechat-data-provider to 0.8.404 * refactor: Tighten z.record(z.any()) fields to precise value types - Type headers fields as z.record(z.string()) in endpoint, assistant, and azure schemas - Type addParams as z.record(z.union([z.string(), z.number(), z.boolean(), z.null()])) - Type azure additionalHeaders as z.record(z.string()) - Type memory model_parameters as z.record(z.union([z.string(), z.number(), z.boolean()])) - Type firecrawl changeTrackingOptions.schema as z.record(z.string()) * refactor: Type supportedMimeTypes schema as z.array(z.string()) Replace z.array(z.any()).refine() with z.array(z.string()) since config input is always strings that get converted to RegExp via convertStringsToRegex() after parsing. Destructure supportedMimeTypes from spreads to avoid string[]/RegExp[] type mismatch. 
* refactor: Tighten enum, role, and numeric constraint schemas - Type engineSTT as enum ['openai', 'azureOpenAI'] - Type engineTTS as enum ['openai', 'azureOpenAI', 'elevenlabs', 'localai'] - Constrain playbackRate to 0.25โ€“4 range - Type titleMessageRole as enum ['system', 'user', 'assistant'] - Add int().nonnegative() to MCP timeout and firecrawl timeout * chore: Bump librechat-data-provider to 0.8.405 * fix: Accept both string and RegExp in supportedMimeTypes schema The schema must accept both string[] (config input) and RegExp[] (post-merge runtime) since tests validate merged output against the schema. Use z.union([z.string(), z.instanceof(RegExp)]) to handle both. * refactor: Address review findings for schema tightening PR - Revert changeTrackingOptions.schema to z.record(z.unknown()) (JSON Schema is nested, not flat strings) - Remove dead contextStrategy code from BaseClient.js and cleanup.js - Extract paramDefinitionSchema to named exported constant - Add .int() constraint to columnSpan and columns - Apply consistent .int().nonnegative() to initTimeout, sseReadTimeout, scraperTimeout - Update stale stderr JSDoc to match actual accepted types - Add comprehensive tests for paramDefinitionSchema, tModelSpecPresetSchema, endpointSchema deprecated field stripping, and azureEndpointSchema * fix: Address second review pass findings - Revert supportedMimeTypesSchema to z.array(z.string()) and remove as string[] casts โ€” fix tests to not validate merged RegExp[] output against the config input schema - Remove unused tModelSpecSchema import from test file - Consolidate duplicate '../src/schemas' imports - Add expiredAt coverage to tModelSpecPresetSchema test - Assert plugins is absent in azureEndpointSchema test - Add sync comments for engineSTT/engineTTS enum literals * refactor: Omit preset-management fields from tModelSpecPresetSchema Omit conversationId, presetId, title, defaultPreset, and order from the model spec preset schema โ€” these are preset-management 
fields that don't belong in model spec configuration. --- api/app/clients/BaseClient.js | 1 - api/server/cleanup.js | 3 - .../api/src/endpoints/custom/initialize.ts | 2 - .../api/src/endpoints/openai/config.spec.ts | 4 - .../__tests__/ConnectionsRepository.test.ts | 4 +- .../MCPConnectionAgentLifecycle.test.ts | 12 +- packages/api/src/mcp/__tests__/mcp.spec.ts | 8 + .../ServerConfigsCacheInMemory.test.ts | 3 + packages/data-provider/package.json | 2 +- .../specs/config-schemas.spec.ts | 254 ++++++++++++++++++ .../data-provider/specs/filetypes.spec.ts | 7 - packages/data-provider/src/config.ts | 76 ++++-- packages/data-provider/src/file-config.ts | 31 +-- packages/data-provider/src/mcp.ts | 24 +- packages/data-provider/src/models.ts | 8 +- packages/data-provider/src/schemas.ts | 32 ++- 16 files changed, 384 insertions(+), 87 deletions(-) create mode 100644 packages/data-provider/specs/config-schemas.spec.ts diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index ec5ccfb5f4..08cb1f6ada 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -32,7 +32,6 @@ class BaseClient { constructor(apiKey, options = {}) { this.apiKey = apiKey; this.sender = options.sender ?? 
'AI'; - this.contextStrategy = null; this.currentDateString = new Date().toLocaleDateString('en-us', { year: 'numeric', month: 'long', diff --git a/api/server/cleanup.js b/api/server/cleanup.js index 364c02cd8a..c27814292d 100644 --- a/api/server/cleanup.js +++ b/api/server/cleanup.js @@ -123,9 +123,6 @@ function disposeClient(client) { if (client.maxContextTokens) { client.maxContextTokens = null; } - if (client.contextStrategy) { - client.contextStrategy = null; - } if (client.currentDateString) { client.currentDateString = null; } diff --git a/packages/api/src/endpoints/custom/initialize.ts b/packages/api/src/endpoints/custom/initialize.ts index 1250721500..ea0d2dbf5d 100644 --- a/packages/api/src/endpoints/custom/initialize.ts +++ b/packages/api/src/endpoints/custom/initialize.ts @@ -32,10 +32,8 @@ function buildCustomOptions( customParams: endpointConfig.customParams, titleConvo: endpointConfig.titleConvo, titleModel: endpointConfig.titleModel, - summaryModel: endpointConfig.summaryModel, modelDisplayLabel: endpointConfig.modelDisplayLabel, titleMethod: endpointConfig.titleMethod ?? 'completion', - contextStrategy: endpointConfig.summarize ? 
'summarize' : null, directEndpoint: endpointConfig.directEndpoint, titleMessageRole: endpointConfig.titleMessageRole, streamRate: endpointConfig.streamRate, diff --git a/packages/api/src/endpoints/openai/config.spec.ts b/packages/api/src/endpoints/openai/config.spec.ts index cdf9d6f14c..46ad6a6295 100644 --- a/packages/api/src/endpoints/openai/config.spec.ts +++ b/packages/api/src/endpoints/openai/config.spec.ts @@ -1399,10 +1399,8 @@ describe('getOpenAIConfig', () => { dropParams: ['presence_penalty'], titleConvo: true, titleModel: 'gpt-3.5-turbo', - summaryModel: 'gpt-3.5-turbo', modelDisplayLabel: 'Custom GPT-4', titleMethod: 'completion', - contextStrategy: 'summarize', directEndpoint: true, titleMessageRole: 'user', streamRate: 25, @@ -1417,10 +1415,8 @@ describe('getOpenAIConfig', () => { customParams: {}, titleConvo: endpointConfig.titleConvo, titleModel: endpointConfig.titleModel, - summaryModel: endpointConfig.summaryModel, modelDisplayLabel: endpointConfig.modelDisplayLabel, titleMethod: endpointConfig.titleMethod, - contextStrategy: endpointConfig.contextStrategy, directEndpoint: endpointConfig.directEndpoint, titleMessageRole: endpointConfig.titleMessageRole, streamRate: endpointConfig.streamRate, diff --git a/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts b/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts index 7a93960765..dfb57a1faf 100644 --- a/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts +++ b/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts @@ -46,8 +46,8 @@ describe('ConnectionsRepository', () => { beforeEach(() => { mockServerConfigs = { - server1: { url: 'http://localhost:3001' }, - server2: { command: 'test-command', args: ['--test'] }, + server1: { url: 'http://localhost:3001', type: 'sse' }, + server2: { command: 'test-command', args: ['--test'], type: 'stdio' }, server3: { url: 'ws://localhost:8080', type: 'websocket' }, }; diff --git 
a/packages/api/src/mcp/__tests__/MCPConnectionAgentLifecycle.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionAgentLifecycle.test.ts index 281bd590db..c7b6b273ba 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionAgentLifecycle.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionAgentLifecycle.test.ts @@ -377,7 +377,7 @@ describe('MCPConnection Agent lifecycle โ€“ SSE', () => { it('reuses the same Agents across multiple requests instead of creating one per request', async () => { conn = new MCPConnection({ serverName: 'test-sse', - serverConfig: { url: server.url }, + serverConfig: { url: server.url, type: 'sse' }, useSSRFProtection: false, }); @@ -402,7 +402,7 @@ describe('MCPConnection Agent lifecycle โ€“ SSE', () => { it('calls Agent.close() on every registered Agent when disconnect() is called', async () => { conn = new MCPConnection({ serverName: 'test-sse', - serverConfig: { url: server.url }, + serverConfig: { url: server.url, type: 'sse' }, useSSRFProtection: false, }); @@ -417,7 +417,7 @@ describe('MCPConnection Agent lifecycle โ€“ SSE', () => { it('closes at least two Agents for SSE transport (eventSourceInit + fetch)', async () => { conn = new MCPConnection({ serverName: 'test-sse', - serverConfig: { url: server.url }, + serverConfig: { url: server.url, type: 'sse' }, useSSRFProtection: false, }); @@ -431,7 +431,7 @@ describe('MCPConnection Agent lifecycle โ€“ SSE', () => { it('does not double-close Agents when disconnect() is called twice', async () => { conn = new MCPConnection({ serverName: 'test-sse', - serverConfig: { url: server.url }, + serverConfig: { url: server.url, type: 'sse' }, useSSRFProtection: false, }); @@ -533,7 +533,7 @@ describe('MCPConnection SSE 404 handling โ€“ session-aware', () => { function makeConn() { return new MCPConnection({ serverName: 'test-404', - serverConfig: { url: 'http://127.0.0.1:1/sse' }, + serverConfig: { url: 'http://127.0.0.1:1/sse', type: 'sse' }, useSSRFProtection: false, }); } @@ -599,7 
+599,7 @@ describe('MCPConnection SSE stream disconnect handling', () => { function makeConn() { return new MCPConnection({ serverName: 'test-sse-disconnect', - serverConfig: { url: 'http://127.0.0.1:1/sse' }, + serverConfig: { url: 'http://127.0.0.1:1/sse', type: 'sse' }, useSSRFProtection: false, }); } diff --git a/packages/api/src/mcp/__tests__/mcp.spec.ts b/packages/api/src/mcp/__tests__/mcp.spec.ts index d64f9f3afa..d5cc44569f 100644 --- a/packages/api/src/mcp/__tests__/mcp.spec.ts +++ b/packages/api/src/mcp/__tests__/mcp.spec.ts @@ -179,6 +179,7 @@ describe('Environment Variable Extraction (MCP)', () => { describe('processMCPEnv', () => { it('should create a deep clone of the input object', () => { const originalObj: MCPOptions = { + type: 'stdio', command: 'node', args: ['server.js'], env: { @@ -202,6 +203,7 @@ describe('Environment Variable Extraction (MCP)', () => { it('should process environment variables in env field', () => { const options: MCPOptions = { + type: 'stdio', command: 'node', args: ['server.js'], env: { @@ -252,6 +254,7 @@ describe('Environment Variable Extraction (MCP)', () => { it('should not modify objects without env or headers', () => { const options: MCPOptions = { + type: 'stdio', command: 'node', args: ['server.js'], timeout: 5000, @@ -433,6 +436,7 @@ describe('Environment Variable Extraction (MCP)', () => { ldapId: 'ldap-user-123', }); const options: MCPOptions = { + type: 'stdio', command: 'node', args: ['server.js'], env: { @@ -599,6 +603,7 @@ describe('Environment Variable Extraction (MCP)', () => { CUSTOM_VAR_2: 'custom-value-2', }; const options: MCPOptions = { + type: 'stdio', command: 'node', args: ['server.js'], env: { @@ -674,6 +679,7 @@ describe('Environment Variable Extraction (MCP)', () => { PROFILE_NAME: 'production-profile', }; const options: MCPOptions = { + type: 'stdio', command: 'npx', args: [ '-y', @@ -734,6 +740,7 @@ describe('Environment Variable Extraction (MCP)', () => { UNUSED_VAR: 'unused-value', }; const 
options: MCPOptions = { + type: 'stdio', command: 'node', args: ['server.js'], env: { @@ -959,6 +966,7 @@ describe('Environment Variable Extraction (MCP)', () => { }) as unknown as IUser; const options: MCPOptions = { + type: 'stdio', command: 'node', args: ['mcp-server.js', '--user', '{{LIBRECHAT_USER_USERNAME}}'], env: { diff --git a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts index c123325c1f..b8827a3fe9 100644 --- a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts +++ b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts @@ -12,6 +12,7 @@ describe('ServerConfigsCacheInMemory Integration Tests', () => { // Test data const mockConfig1: ParsedServerConfig = { + type: 'stdio', command: 'node', args: ['server1.js'], env: { TEST: 'value1' }, @@ -19,6 +20,7 @@ describe('ServerConfigsCacheInMemory Integration Tests', () => { }; const mockConfig2: ParsedServerConfig = { + type: 'stdio', command: 'python', args: ['server2.py'], env: { TEST: 'value2' }, @@ -26,6 +28,7 @@ describe('ServerConfigsCacheInMemory Integration Tests', () => { }; const mockConfig3: ParsedServerConfig = { + type: 'stdio', command: 'node', args: ['server3.js'], url: 'http://localhost:3000', diff --git a/packages/data-provider/package.json b/packages/data-provider/package.json index 1e0c76f37f..0cbe9258f2 100644 --- a/packages/data-provider/package.json +++ b/packages/data-provider/package.json @@ -1,6 +1,6 @@ { "name": "librechat-data-provider", - "version": "0.8.403", + "version": "0.8.405", "description": "data services for librechat apps", "main": "dist/index.js", "module": "dist/index.es.js", diff --git a/packages/data-provider/specs/config-schemas.spec.ts b/packages/data-provider/specs/config-schemas.spec.ts new file mode 100644 index 0000000000..fabd35cec9 --- /dev/null +++ 
b/packages/data-provider/specs/config-schemas.spec.ts @@ -0,0 +1,254 @@ +import { + endpointSchema, + paramDefinitionSchema, + agentsEndpointSchema, + azureEndpointSchema, +} from '../src/config'; +import { tModelSpecPresetSchema, EModelEndpoint } from '../src/schemas'; + +describe('paramDefinitionSchema', () => { + it('accepts a minimal definition with only key', () => { + const result = paramDefinitionSchema.safeParse({ key: 'temperature' }); + expect(result.success).toBe(true); + }); + + it('accepts a full definition with all fields', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'temperature', + type: 'number', + component: 'slider', + default: 0.7, + label: 'Temperature', + range: { min: 0, max: 2, step: 0.01 }, + columns: 2, + columnSpan: 1, + includeInput: true, + descriptionSide: 'right', + }); + expect(result.success).toBe(true); + }); + + it('rejects columns > 4', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'test', + columns: 5, + }); + expect(result.success).toBe(false); + }); + + it('rejects columns < 1', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'test', + columns: 0, + }); + expect(result.success).toBe(false); + }); + + it('rejects non-integer columns', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'test', + columns: 2.5, + }); + expect(result.success).toBe(false); + }); + + it('rejects non-integer columnSpan', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'test', + columnSpan: 1.5, + }); + expect(result.success).toBe(false); + }); + + it('rejects negative minTags', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'test', + minTags: -1, + }); + expect(result.success).toBe(false); + }); + + it('rejects invalid descriptionSide', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'test', + descriptionSide: 'diagonal', + }); + expect(result.success).toBe(false); + }); + + it('rejects invalid type enum value', () => { + 
const result = paramDefinitionSchema.safeParse({ + key: 'test', + type: 'invalid', + }); + expect(result.success).toBe(false); + }); + + it('rejects invalid component enum value', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'test', + component: 'wheel', + }); + expect(result.success).toBe(false); + }); + + it('allows type and component to be omitted (merged from defaults at runtime)', () => { + const result = paramDefinitionSchema.safeParse({ + key: 'temperature', + range: { min: 0, max: 2, step: 0.01 }, + }); + expect(result.success).toBe(true); + expect(result.data).not.toHaveProperty('type'); + expect(result.data).not.toHaveProperty('component'); + }); +}); + +describe('tModelSpecPresetSchema', () => { + it('strips system/DB fields from preset', () => { + const result = tModelSpecPresetSchema.safeParse({ + conversationId: 'conv-123', + presetId: 'preset-456', + title: 'My Preset', + defaultPreset: true, + order: 3, + isArchived: true, + user: 'user123', + messages: ['msg1'], + tags: ['tag1'], + file_ids: ['file1'], + expiredAt: '2026-12-31', + parentMessageId: 'parent1', + model: 'gpt-4o', + endpoint: EModelEndpoint.openAI, + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data).not.toHaveProperty('conversationId'); + expect(result.data).not.toHaveProperty('presetId'); + expect(result.data).not.toHaveProperty('title'); + expect(result.data).not.toHaveProperty('defaultPreset'); + expect(result.data).not.toHaveProperty('order'); + expect(result.data).not.toHaveProperty('isArchived'); + expect(result.data).not.toHaveProperty('user'); + expect(result.data).not.toHaveProperty('messages'); + expect(result.data).not.toHaveProperty('tags'); + expect(result.data).not.toHaveProperty('file_ids'); + expect(result.data).not.toHaveProperty('expiredAt'); + expect(result.data).not.toHaveProperty('parentMessageId'); + expect(result.data).toHaveProperty('model', 'gpt-4o'); + } + }); + + it('strips deprecated fields', () => { 
+ const result = tModelSpecPresetSchema.safeParse({ + resendImages: true, + chatGptLabel: 'old-label', + model: 'gpt-4o', + endpoint: EModelEndpoint.openAI, + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data).not.toHaveProperty('resendImages'); + expect(result.data).not.toHaveProperty('chatGptLabel'); + } + }); + + it('strips frontend-only fields', () => { + const result = tModelSpecPresetSchema.safeParse({ + greeting: 'Hello!', + iconURL: 'https://example.com/icon.png', + spec: 'some-spec', + presetOverride: { model: 'other' }, + model: 'gpt-4o', + endpoint: EModelEndpoint.openAI, + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data).not.toHaveProperty('greeting'); + expect(result.data).not.toHaveProperty('iconURL'); + expect(result.data).not.toHaveProperty('spec'); + expect(result.data).not.toHaveProperty('presetOverride'); + } + }); + + it('preserves valid preset fields', () => { + const result = tModelSpecPresetSchema.safeParse({ + model: 'gpt-4o', + endpoint: EModelEndpoint.openAI, + temperature: 0.7, + topP: 0.9, + maxOutputTokens: 4096, + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.model).toBe('gpt-4o'); + expect(result.data.temperature).toBe(0.7); + expect(result.data.topP).toBe(0.9); + expect(result.data.maxOutputTokens).toBe(4096); + } + }); +}); + +describe('endpointSchema deprecated fields', () => { + const validEndpoint = { + name: 'CustomEndpoint', + apiKey: 'test-key', + baseURL: 'https://api.example.com', + models: { default: ['model-1'] }, + }; + + it('silently strips deprecated summarize field', () => { + const result = endpointSchema.safeParse({ + ...validEndpoint, + summarize: true, + summaryModel: 'gpt-4o', + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data).not.toHaveProperty('summarize'); + expect(result.data).not.toHaveProperty('summaryModel'); + } + }); + + it('silently strips deprecated 
customOrder field', () => { + const result = endpointSchema.safeParse({ + ...validEndpoint, + customOrder: 5, + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data).not.toHaveProperty('customOrder'); + } + }); +}); + +describe('agentsEndpointSchema', () => { + it('does not accept baseURL', () => { + const result = agentsEndpointSchema.safeParse({ + baseURL: 'https://example.com', + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data).not.toHaveProperty('baseURL'); + } + }); +}); + +describe('azureEndpointSchema', () => { + it('silently strips plugins field', () => { + const result = azureEndpointSchema.safeParse({ + groups: [ + { + group: 'test-group', + apiKey: 'test-key', + models: { 'gpt-4': true }, + }, + ], + plugins: true, + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data).not.toHaveProperty('plugins'); + } + }); +}); diff --git a/packages/data-provider/specs/filetypes.spec.ts b/packages/data-provider/specs/filetypes.spec.ts index 39711dadd9..dba6cd4795 100644 --- a/packages/data-provider/specs/filetypes.spec.ts +++ b/packages/data-provider/specs/filetypes.spec.ts @@ -8,7 +8,6 @@ import { retrievalMimeTypes, excelFileTypes, excelMimeTypes, - fileConfigSchema, mergeFileConfig, mbToBytes, } from '../src/file-config'; @@ -126,8 +125,6 @@ describe('mergeFileConfig', () => { test('merges minimal update correctly', () => { const result = mergeFileConfig(dynamicConfigs.minimalUpdate); expect(result.serverFileSizeLimit).toEqual(mbToBytes(1024)); - const parsedResult = fileConfigSchema.safeParse(result); - expect(parsedResult.success).toBeTruthy(); }); test('overrides default endpoint with full new configuration', () => { @@ -136,8 +133,6 @@ describe('mergeFileConfig', () => { expect(result.endpoints.default.supportedMimeTypes).toEqual( expect.arrayContaining([new RegExp('^video/.*$')]), ); - const parsedResult = fileConfigSchema.safeParse(result); - 
expect(parsedResult.success).toBeTruthy(); }); test('adds new endpoint configuration correctly', () => { @@ -147,8 +142,6 @@ describe('mergeFileConfig', () => { expect(result.endpoints.newEndpoint.supportedMimeTypes).toEqual( expect.arrayContaining([new RegExp('^application/json$')]), ); - const parsedResult = fileConfigSchema.safeParse(result); - expect(parsedResult.success).toBeTruthy(); }); test('disables an endpoint and sets numeric fields to 0 and empties supportedMimeTypes', () => { diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index bb89c56f82..ae3f5b9560 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -2,6 +2,7 @@ import { z } from 'zod'; import type { ZodError } from 'zod'; import type { TEndpointsConfig, TModelsConfig, TConfig } from './types'; import { EModelEndpoint, eModelEndpointSchema, isAgentsEndpoint } from './schemas'; +import { ComponentTypes, SettingTypes, OptionTypes } from './generate'; import { specsConfigSchema, TSpecsConfig } from './models'; import { fileConfigSchema } from './file-config'; import { apiBaseUrl } from './api-endpoints'; @@ -120,11 +121,11 @@ export const azureBaseSchema = z.object({ instanceName: z.string().optional(), deploymentName: z.string().optional(), assistants: z.boolean().optional(), - addParams: z.record(z.any()).optional(), + addParams: z.record(z.union([z.string(), z.number(), z.boolean(), z.null()])).optional(), dropParams: z.array(z.string()).optional(), version: z.string().optional(), baseURL: z.string().optional(), - additionalHeaders: z.record(z.any()).optional(), + additionalHeaders: z.record(z.string()).optional(), }); export type TAzureBaseSchema = z.infer; @@ -257,7 +258,7 @@ export const assistantEndpointSchema = baseEndpointSchema.merge( userIdQuery: z.boolean().optional(), }) .optional(), - headers: z.record(z.any()).optional(), + headers: z.record(z.string()).optional(), }), ); @@ -279,6 +280,7 @@ export const 
defaultAgentCapabilities = [ ]; export const agentsEndpointSchema = baseEndpointSchema + .omit({ baseURL: true }) .merge( z.object({ /* agents specific */ @@ -305,6 +307,43 @@ export const agentsEndpointSchema = baseEndpointSchema export type TAgentsEndpoint = z.infer; +export const paramDefinitionSchema = z.object({ + key: z.string(), + description: z.string().optional(), + type: z.nativeEnum(SettingTypes).optional(), + default: z.union([z.number(), z.boolean(), z.string(), z.array(z.string())]).optional(), + showLabel: z.boolean().optional(), + showDefault: z.boolean().optional(), + options: z.array(z.string()).optional(), + range: z + .object({ + min: z.number(), + max: z.number(), + step: z.number().optional(), + }) + .optional(), + enumMappings: z.record(z.union([z.number(), z.boolean(), z.string()])).optional(), + component: z.nativeEnum(ComponentTypes).optional(), + optionType: z.nativeEnum(OptionTypes).optional(), + columnSpan: z.number().int().nonnegative().optional(), + columns: z.number().int().min(1).max(4).optional(), + label: z.string().optional(), + placeholder: z.string().optional(), + labelCode: z.boolean().optional(), + placeholderCode: z.boolean().optional(), + descriptionCode: z.boolean().optional(), + minText: z.number().optional(), + maxText: z.number().optional(), + minTags: z.number().min(0).optional(), + maxTags: z.number().min(0).optional(), + includeInput: z.boolean().optional(), + descriptionSide: z.enum(['top', 'right', 'bottom', 'left']).optional(), + searchPlaceholder: z.string().optional(), + selectPlaceholder: z.string().optional(), + searchPlaceholderCode: z.boolean().optional(), + selectPlaceholderCode: z.boolean().optional(), +}); + export const endpointSchema = baseEndpointSchema.merge( z.object({ name: z.string().refine((value) => !eModelEndpointSchema.safeParse(value).success, { @@ -319,23 +358,20 @@ export const endpointSchema = baseEndpointSchema.merge( fetch: z.boolean().optional(), userIdQuery: z.boolean().optional(), }), 
- summarize: z.boolean().optional(), - summaryModel: z.string().optional(), iconURL: z.string().optional(), modelDisplayLabel: z.string().optional(), - headers: z.record(z.any()).optional(), - addParams: z.record(z.any()).optional(), + headers: z.record(z.string()).optional(), + addParams: z.record(z.union([z.string(), z.number(), z.boolean(), z.null()])).optional(), dropParams: z.array(z.string()).optional(), customParams: z .object({ defaultParamsEndpoint: z.string().default('custom'), - paramDefinitions: z.array(z.record(z.any())).optional(), + paramDefinitions: z.array(paramDefinitionSchema).optional(), }) .strict() .optional(), - customOrder: z.number().optional(), directEndpoint: z.boolean().optional(), - titleMessageRole: z.string().optional(), + titleMessageRole: z.enum(['system', 'user', 'assistant']).optional(), }), ); @@ -344,7 +380,6 @@ export type TEndpoint = z.infer; export const azureEndpointSchema = z .object({ groups: azureGroupConfigsSchema, - plugins: z.boolean().optional(), assistants: z.boolean().optional(), }) .and( @@ -356,9 +391,6 @@ export const azureEndpointSchema = z titleModel: true, titlePrompt: true, titlePromptTemplate: true, - summarize: true, - summaryModel: true, - customOrder: true, }) .partial(), ); @@ -501,7 +533,8 @@ const speechTab = z .optional() .or( z.object({ - engineSTT: z.string().optional(), + /** Keep in sync with STTProviders enum (defined below โ€” cannot reference due to eval order) */ + engineSTT: z.enum(['openai', 'azureOpenAI']).optional(), languageSTT: z.string().optional(), autoTranscribeAudio: z.boolean().optional(), decibelValue: z.number().optional(), @@ -514,11 +547,12 @@ const speechTab = z .optional() .or( z.object({ - engineTTS: z.string().optional(), + /** Keep in sync with TTSProviders enum (defined below โ€” cannot reference due to eval order) */ + engineTTS: z.enum(['openai', 'azureOpenAI', 'elevenlabs', 'localai']).optional(), voice: z.string().optional(), languageTTS: z.string().optional(), 
automaticPlayback: z.boolean().optional(), - playbackRate: z.number().optional(), + playbackRate: z.number().min(0.25).max(4).optional(), cacheTTS: z.boolean().optional(), }), ) @@ -864,7 +898,7 @@ export const webSearchSchema = z.object({ searchProvider: z.nativeEnum(SearchProviders).optional(), scraperProvider: z.nativeEnum(ScraperProviders).optional(), rerankerType: z.nativeEnum(RerankerTypes).optional(), - scraperTimeout: z.number().optional(), + scraperTimeout: z.number().int().nonnegative().optional(), safeSearch: z.nativeEnum(SafeSearchTypes).default(SafeSearchTypes.MODERATE), firecrawlOptions: z .object({ @@ -873,7 +907,7 @@ export const webSearchSchema = z.object({ excludeTags: z.array(z.string()).optional(), headers: z.record(z.string()).optional(), waitFor: z.number().optional(), - timeout: z.number().optional(), + timeout: z.number().int().nonnegative().optional(), maxAge: z.number().optional(), mobile: z.boolean().optional(), skipTlsVerification: z.boolean().optional(), @@ -942,7 +976,7 @@ export const memorySchema = z.object({ provider: z.string(), model: z.string(), instructions: z.string().optional(), - model_parameters: z.record(z.any()).optional(), + model_parameters: z.record(z.union([z.string(), z.number(), z.boolean()])).optional(), }), ]) .optional(), @@ -1026,7 +1060,7 @@ export const configSchema = z.object({ modelSpecs: specsConfigSchema.optional(), endpoints: z .object({ - all: baseEndpointSchema.optional(), + all: baseEndpointSchema.omit({ baseURL: true }).optional(), [EModelEndpoint.openAI]: baseEndpointSchema.optional(), [EModelEndpoint.google]: baseEndpointSchema.optional(), [EModelEndpoint.anthropic]: anthropicEndpointSchema.optional(), diff --git a/packages/data-provider/src/file-config.ts b/packages/data-provider/src/file-config.ts index 32a1a28cc9..7ec184755d 100644 --- a/packages/data-provider/src/file-config.ts +++ b/packages/data-provider/src/file-config.ts @@ -442,22 +442,7 @@ export const fileConfig = { }, }; -const 
supportedMimeTypesSchema = z - .array(z.any()) - .optional() - .refine( - (mimeTypes) => { - if (!mimeTypes) { - return true; - } - return mimeTypes.every( - (mimeType) => mimeType instanceof RegExp || typeof mimeType === 'string', - ); - }, - { - message: 'Each mimeType must be a string or a RegExp object.', - }, - ); +const supportedMimeTypesSchema = z.array(z.string()).optional(); export const endpointFileConfigSchema = z.object({ disabled: z.boolean().optional(), @@ -690,22 +675,24 @@ export function mergeFileConfig(dynamic: z.infer | unde } if (dynamic.ocr !== undefined) { + const { supportedMimeTypes: ocrMimeTypes, ...ocrRest } = dynamic.ocr; mergedConfig.ocr = { ...mergedConfig.ocr, - ...dynamic.ocr, + ...ocrRest, }; - if (dynamic.ocr.supportedMimeTypes) { - mergedConfig.ocr.supportedMimeTypes = convertStringsToRegex(dynamic.ocr.supportedMimeTypes); + if (ocrMimeTypes) { + mergedConfig.ocr.supportedMimeTypes = convertStringsToRegex(ocrMimeTypes); } } if (dynamic.text !== undefined) { + const { supportedMimeTypes: textMimeTypes, ...textRest } = dynamic.text; mergedConfig.text = { ...mergedConfig.text, - ...dynamic.text, + ...textRest, }; - if (dynamic.text.supportedMimeTypes) { - mergedConfig.text.supportedMimeTypes = convertStringsToRegex(dynamic.text.supportedMimeTypes); + if (textMimeTypes) { + mergedConfig.text.supportedMimeTypes = convertStringsToRegex(textMimeTypes); } } diff --git a/packages/data-provider/src/mcp.ts b/packages/data-provider/src/mcp.ts index 3ad296c4ec..b22a599b9b 100644 --- a/packages/data-provider/src/mcp.ts +++ b/packages/data-provider/src/mcp.ts @@ -18,10 +18,10 @@ const BaseOptionsSchema = z.object({ */ startup: z.boolean().optional(), iconPath: z.string().optional(), - timeout: z.number().optional(), + timeout: z.number().int().nonnegative().optional(), /** Timeout (ms) for the long-lived SSE GET stream body before undici aborts it. Default: 300_000 (5 min). 
*/ - sseReadTimeout: z.number().positive().optional(), - initTimeout: z.number().optional(), + sseReadTimeout: z.number().int().positive().optional(), + initTimeout: z.number().int().nonnegative().optional(), /** Controls visibility in chat dropdown menu (MCPSelect) */ chatMenu: z.boolean().optional(), /** @@ -104,7 +104,7 @@ const BaseOptionsSchema = z.object({ }); export const StdioOptionsSchema = BaseOptionsSchema.extend({ - type: z.literal('stdio').optional(), + type: z.literal('stdio').default('stdio'), /** * The executable to run to start the server. */ @@ -134,17 +134,17 @@ export const StdioOptionsSchema = BaseOptionsSchema.extend({ return processedEnv; }), /** - * How to handle stderr of the child process. This matches the semantics of Node's `child_process.spawn`. - * - * @type {import('node:child_process').IOType | import('node:stream').Stream | number} - * - * The default is "inherit", meaning messages to stderr will be printed to the parent process's stderr. + * How to handle stderr of the child process. + * Accepts: 'pipe' | 'ignore' | 'inherit' | file descriptor number. + * Defaults to "inherit". 
*/ - stderr: z.any().optional(), + stderr: z + .union([z.enum(['pipe', 'ignore', 'inherit']), z.number().int().nonnegative()]) + .optional(), }); export const WebSocketOptionsSchema = BaseOptionsSchema.extend({ - type: z.literal('websocket').optional(), + type: z.literal('websocket').default('websocket'), url: z .string() .transform((val: string) => extractEnvVariable(val)) @@ -161,7 +161,7 @@ export const WebSocketOptionsSchema = BaseOptionsSchema.extend({ }); export const SSEOptionsSchema = BaseOptionsSchema.extend({ - type: z.literal('sse').optional(), + type: z.literal('sse').default('sse'), headers: z.record(z.string(), z.string()).optional(), url: z .string() diff --git a/packages/data-provider/src/models.ts b/packages/data-provider/src/models.ts index c2dbe2cf77..82c2042d8a 100644 --- a/packages/data-provider/src/models.ts +++ b/packages/data-provider/src/models.ts @@ -1,8 +1,8 @@ import { z } from 'zod'; -import type { TPreset } from './schemas'; +import type { TModelSpecPreset } from './schemas'; import { EModelEndpoint, - tPresetSchema, + tModelSpecPresetSchema, eModelEndpointSchema, AuthType, authTypeSchema, @@ -11,7 +11,7 @@ import { export type TModelSpec = { name: string; label: string; - preset: TPreset; + preset: TModelSpecPreset; order?: number; default?: boolean; description?: string; @@ -42,7 +42,7 @@ export type TModelSpec = { export const tModelSpecSchema = z.object({ name: z.string(), label: z.string(), - preset: tPresetSchema, + preset: tModelSpecPresetSchema, order: z.number().optional(), default: z.boolean().optional(), description: z.string().optional(), diff --git a/packages/data-provider/src/schemas.ts b/packages/data-provider/src/schemas.ts index 19ba804556..084f74af86 100644 --- a/packages/data-provider/src/schemas.ts +++ b/packages/data-provider/src/schemas.ts @@ -635,11 +635,15 @@ export const tMessageSchema = z.object({ calibrationRatio: z .number() .optional() - .describe('EMA ratio of provider-reported vs local token estimates; 
seeds the pruner on subsequent runs'), + .describe( + 'EMA ratio of provider-reported vs local token estimates; seeds the pruner on subsequent runs', + ), encoding: z .string() .optional() - .describe('Tokenizer encoding used when this ratio was computed (e.g. "claude", "o200k_base")'), + .describe( + 'Tokenizer encoding used when this ratio was computed (e.g. "claude", "o200k_base")', + ), }) .optional(), }); @@ -919,6 +923,30 @@ export const tQueryParamsSchema = tConversationSchema }), ); +/** Narrowed preset schema for use in model specs โ€” omits system/DB/deprecated fields */ +export const tModelSpecPresetSchema = tPresetSchema.omit({ + conversationId: true, + presetId: true, + title: true, + defaultPreset: true, + order: true, + isArchived: true, + user: true, + messages: true, + tags: true, + file_ids: true, + expiredAt: true, + parentMessageId: true, + resendImages: true, + chatGptLabel: true, + presetOverride: true, + greeting: true, + iconURL: true, + spec: true, +}); + +export type TModelSpecPreset = z.infer; + export type TPreset = z.infer; export type TSetOption = ( From 7e2b51697edf9a5961db94b6abf48ac8449dfb9b Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 29 Mar 2026 17:05:12 -0400 Subject: [PATCH 18/18] =?UTF-8?q?=F0=9F=AA=A2=20refactor:=20Eliminate=20Un?= =?UTF-8?q?necessary=20Re-renders=20During=20Message=20Streaming=20(#12454?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: add TMessageChatContext type for stable context passing Defines a type for a stable context object that wrapper components pass to memo'd message components, avoiding direct ChatContext subscriptions that bypass React.memo during streaming. * perf: remove ChatContext subscription from useMessageActions useMessageActions previously called useChatContext() inside memo'd components (MessageRender, ContentRender), bypassing React.memo when isSubmitting changed during streaming. 
Now accepts a stable chatContext param instead, using a ref for the isSubmitting guard in regenerateMessage. Also stabilizes handleScroll in useMessageProcess by using a ref for isSubmitting instead of including it in useCallback deps. * perf: pass stable chatContext to memo'd message components Wrapper components (Message, MessageContent) now create a stable chatContext object via useMemo with a getter-backed isSubmitting, and compute effectiveIsSubmitting (false for non-latest messages). This ensures MessageRender and ContentRender (both React.memo'd) only re-render for the latest message during streaming, preventing unnecessary re-renders of all prior messages and their SubRow, HoverButtons, and SiblingSwitch children. * perf: add custom memo comparators to prevent message reference re-renders buildTree creates new message objects on every streaming update for ALL messages, not just the changed one. This defeats React.memo's default shallow comparison since the message prop has a new reference even when the content hasn't changed. Custom areEqual comparators now compare message by key fields (messageId, text, error, depth, children length, etc.) instead of reference equality, preventing unnecessary re-renders of SubRow, Files, HoverButtons and other children for non-latest messages. * perf: memoize ChatForm children to prevent streaming re-renders - Wrap StopButton in React.memo - Wrap AudioRecorder in React.memo, use ref for isSubmitting in onTranscriptionComplete callback to stabilize it - Remove useChatContext() from FileFormChat (bypassed its memo during streaming), accept files/setFiles/setFilesLoading as props from ChatForm instead * perf: stabilize ChatForm child props to prevent cascading re-renders ChatForm re-renders frequently during streaming (ChatContext changes). This caused StopButton and AttachFileChat/AttachFileMenu to re-render despite being memo'd, because their props were new references each time. 
- Wrap handleStopGenerating in a ref-based stable callback so StopButton always receives the same function reference - Create stableConversation via useMemo keyed on rendering-relevant fields only (conversationId, endpoint, agent_id, etc.), so AttachFileChat and FileFormChat don't re-render from unrelated conversation metadata updates (e.g., title generation) * perf: remove ChatContext subscription from AttachFileMenu and FileFormChat Both components used useFileHandling() which internally calls useChatContext(), bypassing their React.memo wrappers and causing re-renders on every streaming chunk. Switch to useFileHandlingNoChatContext() which accepts file state as parameters. The state (files, setFiles, setFilesLoading, conversation) is passed down from ChatForm โ†’ AttachFileChat โ†’ AttachFileMenu as props, keeping the memo chain intact. * fix: update imports and test mocks for useFileHandlingNoChatContext - Re-export useFileHandlingNoChatContext from hooks barrel - Import from ~/hooks instead of direct path for test compatibility - Add useToastContext mock to @librechat/client in AttachFileMenu tests since useFileHandlingNoChatContext runs the core hook which needs it - Add useFileHandlingNoChatContext to ~/hooks test mock * perf: fix remaining ChatForm streaming re-renders - Switch AttachFileMenu from useSharePointFileHandling (subscribes to ChatContext) to useSharePointFileHandlingNoChatContext with explicit file state props - Memoize ChatForm textarea onFocus/onBlur handlers with useCallback to prevent TextareaAutosize re-renders (inline arrow functions and .bind() created new references on every ChatForm render) - Update AttachFileMenu test mocks for new hook variants * refactor: add displayName to ChatForm for React DevTools * perf: prevent ChatForm re-renders during streaming via wrapper pattern ChatForm was re-rendering on every streaming chunk because it subscribed to useChatContext() internally, and the ChatContext value changed frequently during 
streaming. Extract context subscription into a ChatFormWrapper that: - Subscribes to useChatContext() (re-renders on every chunk, cheap) - Stabilizes conversation via selective useMemo - Stabilizes handleStopGenerating via ref-based callback - Passes individual stable values as props to ChatForm ChatForm (memo'd) now receives context values as props instead of subscribing directly. Since individual values (files, setFiles, isSubmitting, etc.) are stable references during streaming, ChatForm's memo prevents re-renders entirely โ€” it only re-renders when isSubmitting actually toggles (2x per stream: start/end). * perf: stabilize newConversation prop and memoize CollapseChat - Wrap newConversation in ref-based stable callback in ChatFormWrapper (was the remaining unstable prop causing ChatForm to re-render) - Wrap CollapseChat in React.memo to prevent re-renders from parent * perf: memoize useAddedResponse return value useAddedResponse returned a new object literal on every render, causing AddedChatContext.Provider to trigger re-renders of all consumers (including ChatForm) on every streaming chunk. Wrap in useMemo so the context value stays referentially stable. * perf: memoize TextareaHeader to prevent re-renders from ChatForm * perf: address review findings for streaming render optimization Finding 1: Switch AttachFile.tsx from useFileHandling to useFileHandlingNoChatContext, closing the optimization hole for standard (non-agent) chat endpoints. Finding 2: Replace content reference equality with length comparison in both memo comparators โ€” safer against buildTree array reconstruction. Finding 3: Add conversation?.model to stableConversation deps in ChatFormWrapper so file uploads use the correct model after switches. Finding 4/14: Fix stableNewConversation to explicitly return the underlying call's result instead of discarding it via `as` cast. 
Finding 5/6: Extract useMemoizedChatContext hook shared by Message.tsx and MessageContent.tsx โ€” eliminates ~70 lines of duplication and stabilizes chatContext.conversation via selective useMemo to prevent post-stream metadata updates from re-rendering all messages. Finding 8: Use TMessage type for regenerate param instead of Record. Finding 9: Use FileSetter alias in FileFormChat instead of inline type. Finding 11: Fix pre-existing broken throttle in useMessageProcess โ€” was creating a new throttle instance per call, providing zero deduplication. Now retains the instance via useMemo. Finding 12: Initialize isSubmittingRef with chatContext.isSubmitting instead of false for consistency. Finding 13: Add ChatFormWrapper displayName. * fix: revert content comparison to reference equality in memo comparators The length-based comparison (content?.length) missed updates within existing content parts during streaming โ€” text chunks update a part's content without changing the array length, so the comparator returned true and skipped re-renders for the latest message. Reference equality (===) is correct here: buildTree preserves content array references for unchanged messages via shallow spread, while React Query gives the latest message a new reference when its content updates during streaming. * fix: cancel throttled handleScroll on unmount and remove unused import * fix: use chatContext getter directly in regenerateMessage callback The local isSubmittingRef was stale for non-latest messages (which don't re-render during streaming by design). chatContext.isSubmitting is a getter backed by the wrapper's ref, so reading it at call-time always returns the current value regardless of whether the component has re-rendered. * fix: remove unused useCallback import from useMemoizedChatContext * fix: pass global isSubmitting to HoverButtons for action gating HoverButtons uses isSubmitting via useGenerationsByLatest to disable regenerate and hide edit buttons during streaming. 
Passing the effective value (false for non-latest messages) re-enabled those actions mid-stream, risking overlapping edits/regenerations. Use chatContext.isSubmitting (getter, always returns current value) for HoverButtons while keeping the effective value for rendering-only UI (cursor, placeholder, streaming indicator). * fix: address second review โ€” stale HoverButtons, messages dep, cleanup - Add isSubmitting to chatContext useMemo deps in useMemoizedChatContext so HoverButtons correctly updates when streaming starts/ends (2 extra re-renders per session, belt-and-suspenders for post-stream state) - Change conversation?.messages?.length dep to boolean in ChatFormWrapper stableConversation โ€” only need 0โ†”1+ transition for landing page check, not exact count on every message addition - Add defensive comment at chatContext destructuring point in useMessageActions explaining why isSubmitting must not be destructured - Remove dead mockUseFileHandling.mockReturnValue from AttachFileMenu tests * chore: remove dead useFileHandling mock artifacts from AttachFileMenu tests * fix: resolve eslint warnings for useMemo dependencies - Extract complex expression (conversation?.messages?.length ?? 
0) > 0 to hasMessages variable for static analysis in ChatFormWrapper - Add eslint-disable for intentional isSubmitting dep in useMemoizedChatContext (forces new chatContext reference on streaming start/end so HoverButtons re-renders) --- client/src/common/types.ts | 22 +++ .../components/Chat/Input/AudioRecorder.tsx | 12 +- client/src/components/Chat/Input/ChatForm.tsx | 141 +++++++++++++++--- .../components/Chat/Input/CollapseChat.tsx | 2 +- .../Chat/Input/Files/AttachFile.tsx | 25 +++- .../Chat/Input/Files/AttachFileChat.tsx | 21 ++- .../Chat/Input/Files/AttachFileMenu.tsx | 28 +++- .../Chat/Input/Files/FileFormChat.tsx | 24 ++- .../Files/__tests__/AttachFileChat.spec.tsx | 8 +- .../Files/__tests__/AttachFileMenu.spec.tsx | 27 +++- .../src/components/Chat/Input/StopButton.tsx | 11 +- .../components/Chat/Input/TextareaHeader.tsx | 5 +- .../src/components/Chat/Messages/Message.tsx | 11 +- .../Chat/Messages/ui/MessageRender.tsx | 79 ++++++++-- .../src/components/Messages/ContentRender.tsx | 74 ++++++++- .../components/Messages/MessageContent.tsx | 9 +- client/src/hooks/Chat/useAddedResponse.ts | 15 +- client/src/hooks/Files/index.ts | 2 +- client/src/hooks/Messages/index.ts | 1 + .../hooks/Messages/useMemoizedChatContext.ts | 80 ++++++++++ .../src/hooks/Messages/useMessageActions.tsx | 27 +++- .../src/hooks/Messages/useMessageProcess.tsx | 27 ++-- 22 files changed, 554 insertions(+), 97 deletions(-) create mode 100644 client/src/hooks/Messages/useMemoizedChatContext.ts diff --git a/client/src/common/types.ts b/client/src/common/types.ts index 85044bb2bc..6ca408685f 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -355,6 +355,28 @@ export type TOptions = { export type TAskFunction = (props: TAskProps, options?: TOptions) => void; +/** + * Stable context object passed from non-memo'd wrapper components (Message, MessageContent) + * to memo'd inner components (MessageRender, ContentRender) via props. 
+ * + * This avoids subscribing to ChatContext inside memo'd components, which would bypass React.memo + * and cause unnecessary re-renders when `isSubmitting` changes during streaming. + * + * The `isSubmitting` property should use a getter backed by a ref so it returns the current + * value at call-time (for callback guards) without being a reactive dependency. + */ +export type TMessageChatContext = { + ask: (...args: Parameters) => void; + index: number; + regenerate: (message: t.TMessage, options?: { addedConvo?: t.TConversation | null }) => void; + conversation: t.TConversation | null; + latestMessageId: string | undefined; + latestMessageDepth: number | undefined; + handleContinue: (e: React.MouseEvent) => void; + /** Should be a getter backed by a ref โ€” reads current value without triggering re-renders */ + readonly isSubmitting: boolean; +}; + export type TMessageProps = { conversation?: t.TConversation | null; messageId?: string | null; diff --git a/client/src/components/Chat/Input/AudioRecorder.tsx b/client/src/components/Chat/Input/AudioRecorder.tsx index dbf2c29d83..e9e19d0904 100644 --- a/client/src/components/Chat/Input/AudioRecorder.tsx +++ b/client/src/components/Chat/Input/AudioRecorder.tsx @@ -1,4 +1,4 @@ -import { useCallback, useRef } from 'react'; +import { memo, useCallback, useRef } from 'react'; import { MicOff } from 'lucide-react'; import { useToastContext, TooltipAnchor, ListeningIcon, Spinner } from '@librechat/client'; import { useLocalize, useSpeechToText, useGetAudioSettings } from '~/hooks'; @@ -7,7 +7,7 @@ import { globalAudioId } from '~/common'; import { cn } from '~/utils'; const isExternalSTT = (speechToTextEndpoint: string) => speechToTextEndpoint === 'external'; -export default function AudioRecorder({ +export default memo(function AudioRecorder({ disabled, ask, methods, @@ -26,10 +26,12 @@ export default function AudioRecorder({ const { speechToTextEndpoint } = useGetAudioSettings(); const existingTextRef = useRef(''); + 
const isSubmittingRef = useRef(isSubmitting); + isSubmittingRef.current = isSubmitting; const onTranscriptionComplete = useCallback( (text: string) => { - if (isSubmitting) { + if (isSubmittingRef.current) { showToast({ message: localize('com_ui_speech_while_submitting'), status: 'error', @@ -52,7 +54,7 @@ export default function AudioRecorder({ existingTextRef.current = ''; } }, - [ask, reset, showToast, localize, isSubmitting, speechToTextEndpoint], + [ask, reset, showToast, localize, speechToTextEndpoint], ); const setText = useCallback( @@ -125,4 +127,4 @@ export default function AudioRecorder({ } /> ); -} +}); diff --git a/client/src/components/Chat/Input/ChatForm.tsx b/client/src/components/Chat/Input/ChatForm.tsx index fed355dcb3..9e0ad7f382 100644 --- a/client/src/components/Chat/Input/ChatForm.tsx +++ b/client/src/components/Chat/Input/ChatForm.tsx @@ -3,6 +3,8 @@ import { useWatch } from 'react-hook-form'; import { TextareaAutosize } from '@librechat/client'; import { useRecoilState, useRecoilValue } from 'recoil'; import { Constants, isAssistantsEndpoint, isAgentsEndpoint } from 'librechat-data-provider'; +import type { TConversation } from 'librechat-data-provider'; +import type { ExtendedFile, FileSetter, ConvoGenerator } from '~/common'; import { useChatContext, useChatFormContext, @@ -35,7 +37,30 @@ import BadgeRow from './BadgeRow'; import Mention from './Mention'; import store from '~/store'; -const ChatForm = memo(({ index = 0 }: { index?: number }) => { +interface ChatFormProps { + index: number; + /** From ChatContext โ€” individual values so memo can compare them */ + files: Map; + setFiles: FileSetter; + conversation: TConversation | null; + isSubmitting: boolean; + filesLoading: boolean; + setFilesLoading: React.Dispatch>; + newConversation: ConvoGenerator; + handleStopGenerating: (e: React.MouseEvent) => void; +} + +const ChatForm = memo(function ChatForm({ + index, + files, + setFiles, + conversation, + isSubmitting, + filesLoading, + 
setFilesLoading, + newConversation, + handleStopGenerating, +}: ChatFormProps) { const submitButtonRef = useRef(null); const textAreaRef = useRef(null); useFocusChatEffect(textAreaRef); @@ -65,15 +90,6 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => { const { requiresKey } = useRequiresKey(); const methods = useChatFormContext(); - const { - files, - setFiles, - conversation, - isSubmitting, - filesLoading, - newConversation, - handleStopGenerating, - } = useChatContext(); const { generateConversation, conversation: addedConvo, @@ -120,6 +136,15 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => { } }, [isCollapsed]); + const handleTextareaFocus = useCallback(() => { + handleFocusOrClick(); + setIsTextAreaFocused(true); + }, [handleFocusOrClick]); + + const handleTextareaBlur = useCallback(() => { + setIsTextAreaFocused(false); + }, []); + useAutoSave({ files, setFiles, @@ -253,7 +278,12 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => { handleSaveBadges={handleSaveBadges} setBadges={setBadges} /> - + {endpoint && (

{ tabIndex={0} data-testid="text-input" rows={1} - onFocus={() => { - handleFocusOrClick(); - setIsTextAreaFocused(true); - }} - onBlur={setIsTextAreaFocused.bind(null, false)} + onFocus={handleTextareaFocus} + onBlur={handleTextareaBlur} aria-label={localize('com_ui_message_input')} onClick={handleFocusOrClick} style={{ height: 44, overflowY: 'auto' }} @@ -315,7 +342,13 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => { )} >
- +
{ ); }); +ChatForm.displayName = 'ChatForm'; -export default ChatForm; +/** + * Wrapper that subscribes to ChatContext and passes stable individual values + * to the memo'd ChatForm. This prevents ChatForm from re-rendering on every + * streaming chunk โ€” it only re-renders when the specific values it uses change. + */ +function ChatFormWrapper({ index = 0 }: { index?: number }) { + const { + files, + setFiles, + conversation, + isSubmitting, + filesLoading, + setFilesLoading, + newConversation, + handleStopGenerating, + } = useChatContext(); + + /** + * Stabilize conversation reference: only update when rendering-relevant fields change, + * not on every metadata update (e.g., title generation during streaming). + */ + const hasMessages = (conversation?.messages?.length ?? 0) > 0; + const stableConversation = useMemo( + () => conversation, + // eslint-disable-next-line react-hooks/exhaustive-deps + [ + conversation?.conversationId, + conversation?.endpoint, + conversation?.endpointType, + conversation?.agent_id, + conversation?.assistant_id, + conversation?.spec, + conversation?.useResponsesApi, + conversation?.model, + hasMessages, + ], + ); + + /** Stabilize function refs so they never trigger ChatForm re-renders */ + const handleStopRef = useRef(handleStopGenerating); + handleStopRef.current = handleStopGenerating; + const stableHandleStop = useCallback( + (e: React.MouseEvent) => handleStopRef.current(e), + [], + ); + + const newConvoRef = useRef(newConversation); + newConvoRef.current = newConversation; + const stableNewConversation: ConvoGenerator = useCallback( + (...args: Parameters): ReturnType => + newConvoRef.current(...args), + [], + ); + + return ( + + ); +} + +ChatFormWrapper.displayName = 'ChatFormWrapper'; + +export default ChatFormWrapper; diff --git a/client/src/components/Chat/Input/CollapseChat.tsx b/client/src/components/Chat/Input/CollapseChat.tsx index ea099ed69b..7efe52dc8d 100644 --- a/client/src/components/Chat/Input/CollapseChat.tsx +++ 
b/client/src/components/Chat/Input/CollapseChat.tsx @@ -52,4 +52,4 @@ const CollapseChat = ({ ); }; -export default CollapseChat; +export default React.memo(CollapseChat); diff --git a/client/src/components/Chat/Input/Files/AttachFile.tsx b/client/src/components/Chat/Input/Files/AttachFile.tsx index 38a3fa8c6f..098fa2c4c3 100644 --- a/client/src/components/Chat/Input/Files/AttachFile.tsx +++ b/client/src/components/Chat/Input/Files/AttachFile.tsx @@ -1,14 +1,33 @@ import React, { useRef } from 'react'; import { FileUpload, TooltipAnchor, AttachmentIcon } from '@librechat/client'; -import { useLocalize, useFileHandling } from '~/hooks'; +import type { TConversation } from 'librechat-data-provider'; +import type { ExtendedFile, FileSetter } from '~/common'; +import { useFileHandlingNoChatContext, useLocalize } from '~/hooks'; import { cn } from '~/utils'; -const AttachFile = ({ disabled }: { disabled?: boolean | null }) => { +const AttachFile = ({ + disabled, + files, + setFiles, + setFilesLoading, + conversation, +}: { + disabled?: boolean | null; + files: Map; + setFiles: FileSetter; + setFilesLoading: React.Dispatch>; + conversation: TConversation | null; +}) => { const localize = useLocalize(); const inputRef = useRef(null); const isUploadDisabled = disabled ?? 
false; - const { handleFileChange } = useFileHandling(); + const { handleFileChange } = useFileHandlingNoChatContext(undefined, { + files, + setFiles, + setFilesLoading, + conversation, + }); return ( diff --git a/client/src/components/Chat/Input/Files/AttachFileChat.tsx b/client/src/components/Chat/Input/Files/AttachFileChat.tsx index 2f954d01d5..7eb9b0c474 100644 --- a/client/src/components/Chat/Input/Files/AttachFileChat.tsx +++ b/client/src/components/Chat/Input/Files/AttachFileChat.tsx @@ -9,6 +9,7 @@ import { getEndpointFileConfig, } from 'librechat-data-provider'; import type { TConversation } from 'librechat-data-provider'; +import type { ExtendedFile, FileSetter } from '~/common'; import { useGetFileConfig, useGetEndpointsQuery, useGetAgentByIdQuery } from '~/data-provider'; import { useAgentsMapContext } from '~/Providers'; import AttachFileMenu from './AttachFileMenu'; @@ -17,9 +18,15 @@ import AttachFile from './AttachFile'; function AttachFileChat({ disableInputs, conversation, + files, + setFiles, + setFilesLoading, }: { disableInputs: boolean; conversation: TConversation | null; + files: Map; + setFiles: FileSetter; + setFilesLoading: React.Dispatch>; }) { const conversationId = conversation?.conversationId ?? Constants.NEW_CONVO; const { endpoint } = conversation ?? 
{ endpoint: null }; @@ -90,7 +97,15 @@ function AttachFileChat({ ); if (isAssistants && endpointSupportsFiles && !isUploadDisabled) { - return ; + return ( + + ); } else if ((isAgents || endpointSupportsFiles) && !isUploadDisabled) { return ( ); } diff --git a/client/src/components/Chat/Input/Files/AttachFileMenu.tsx b/client/src/components/Chat/Input/Files/AttachFileMenu.tsx index 62072e49e5..181d219c08 100644 --- a/client/src/components/Chat/Input/Files/AttachFileMenu.tsx +++ b/client/src/components/Chat/Input/Files/AttachFileMenu.tsx @@ -23,15 +23,16 @@ import { bedrockDocumentExtensions, isDocumentSupportedProvider, } from 'librechat-data-provider'; -import type { EndpointFileConfig } from 'librechat-data-provider'; +import type { EndpointFileConfig, TConversation } from 'librechat-data-provider'; +import type { ExtendedFile, FileSetter } from '~/common'; import { useAgentToolPermissions, useAgentCapabilities, useGetAgentsConfig, - useFileHandling, + useFileHandlingNoChatContext, useLocalize, } from '~/hooks'; -import useSharePointFileHandling from '~/hooks/Files/useSharePointFileHandling'; +import { useSharePointFileHandlingNoChatContext } from '~/hooks/Files/useSharePointFileHandling'; import { SharePointPickerDialog } from '~/components/SharePoint'; import { useGetStartupConfig } from '~/data-provider'; import { ephemeralAgentByConvoId } from '~/store'; @@ -53,6 +54,10 @@ interface AttachFileMenuProps { endpointType?: EModelEndpoint | string; endpointFileConfig?: EndpointFileConfig; useResponsesApi?: boolean; + files: Map; + setFiles: FileSetter; + setFilesLoading: React.Dispatch>; + conversation: TConversation | null; } const AttachFileMenu = ({ @@ -63,6 +68,10 @@ const AttachFileMenu = ({ conversationId, endpointFileConfig, useResponsesApi, + files, + setFiles, + setFilesLoading, + conversation, }: AttachFileMenuProps) => { const localize = useLocalize(); const isUploadDisabled = disabled ?? 
false; @@ -72,10 +81,17 @@ const AttachFileMenu = ({ ephemeralAgentByConvoId(conversationId), ); const [toolResource, setToolResource] = useState(); - const { handleFileChange } = useFileHandling(); - const { handleSharePointFiles, isProcessing, downloadProgress } = useSharePointFileHandling({ - toolResource, + const { handleFileChange } = useFileHandlingNoChatContext(undefined, { + files, + setFiles, + setFilesLoading, + conversation, }); + const { handleSharePointFiles, isProcessing, downloadProgress } = + useSharePointFileHandlingNoChatContext( + { toolResource }, + { files, setFiles, setFilesLoading, conversation }, + ); const { agentsConfig } = useGetAgentsConfig(); const { data: startupConfig } = useGetStartupConfig(); diff --git a/client/src/components/Chat/Input/Files/FileFormChat.tsx b/client/src/components/Chat/Input/Files/FileFormChat.tsx index 3c01f2d642..4d37938c5d 100644 --- a/client/src/components/Chat/Input/Files/FileFormChat.tsx +++ b/client/src/components/Chat/Input/Files/FileFormChat.tsx @@ -1,16 +1,30 @@ import { memo } from 'react'; import { useRecoilValue } from 'recoil'; import type { TConversation } from 'librechat-data-provider'; -import { useChatContext } from '~/Providers'; -import { useFileHandling } from '~/hooks'; +import type { ExtendedFile, FileSetter } from '~/common'; +import { useFileHandlingNoChatContext } from '~/hooks'; import FileRow from './FileRow'; import store from '~/store'; -function FileFormChat({ conversation }: { conversation: TConversation | null }) { - const { files, setFiles, setFilesLoading } = useChatContext(); +function FileFormChat({ + conversation, + files, + setFiles, + setFilesLoading, +}: { + conversation: TConversation | null; + files: Map; + setFiles: FileSetter; + setFilesLoading: React.Dispatch>; +}) { const chatDirection = useRecoilValue(store.chatDirection).toLowerCase(); const { endpoint: _endpoint } = conversation ?? 
{ endpoint: null }; - const { abortUpload } = useFileHandling(); + const { abortUpload } = useFileHandlingNoChatContext(undefined, { + files, + setFiles, + setFilesLoading, + conversation, + }); const isRTL = chatDirection === 'rtl'; diff --git a/client/src/components/Chat/Input/Files/__tests__/AttachFileChat.spec.tsx b/client/src/components/Chat/Input/Files/__tests__/AttachFileChat.spec.tsx index cea55f5ce8..80f06a1b89 100644 --- a/client/src/components/Chat/Input/Files/__tests__/AttachFileChat.spec.tsx +++ b/client/src/components/Chat/Input/Files/__tests__/AttachFileChat.spec.tsx @@ -59,7 +59,13 @@ function renderComponent(conversation: Record | null, disableIn return render( - + {}} + setFilesLoading={() => {}} + /> , ); diff --git a/client/src/components/Chat/Input/Files/__tests__/AttachFileMenu.spec.tsx b/client/src/components/Chat/Input/Files/__tests__/AttachFileMenu.spec.tsx index cf08721207..c2710d4ef8 100644 --- a/client/src/components/Chat/Input/Files/__tests__/AttachFileMenu.spec.tsx +++ b/client/src/components/Chat/Input/Files/__tests__/AttachFileMenu.spec.tsx @@ -9,13 +9,14 @@ jest.mock('~/hooks', () => ({ useAgentToolPermissions: jest.fn(), useAgentCapabilities: jest.fn(), useGetAgentsConfig: jest.fn(), - useFileHandling: jest.fn(), + useFileHandlingNoChatContext: jest.fn(), useLocalize: jest.fn(), })); jest.mock('~/hooks/Files/useSharePointFileHandling', () => ({ __esModule: true, default: jest.fn(), + useSharePointFileHandlingNoChatContext: jest.fn(), })); jest.mock('~/data-provider', () => ({ @@ -52,6 +53,7 @@ jest.mock('@librechat/client', () => { ), AttachmentIcon: () => R.createElement('span', { 'data-testid': 'attachment-icon' }), SharePointIcon: () => R.createElement('span', { 'data-testid': 'sharepoint-icon' }), + useToastContext: () => ({ showToast: jest.fn() }), }; }); @@ -66,11 +68,14 @@ jest.mock('@ariakit/react', () => { const mockUseAgentToolPermissions = jest.requireMock('~/hooks').useAgentToolPermissions; const 
mockUseAgentCapabilities = jest.requireMock('~/hooks').useAgentCapabilities; const mockUseGetAgentsConfig = jest.requireMock('~/hooks').useGetAgentsConfig; -const mockUseFileHandling = jest.requireMock('~/hooks').useFileHandling; +const mockUseFileHandlingNoChatContext = jest.requireMock('~/hooks').useFileHandlingNoChatContext; const mockUseLocalize = jest.requireMock('~/hooks').useLocalize; const mockUseSharePointFileHandling = jest.requireMock( '~/hooks/Files/useSharePointFileHandling', ).default; +const mockUseSharePointFileHandlingNoChatContext = jest.requireMock( + '~/hooks/Files/useSharePointFileHandling', +).useSharePointFileHandlingNoChatContext; const mockUseGetStartupConfig = jest.requireMock('~/data-provider').useGetStartupConfig; const queryClient = new QueryClient({ defaultOptions: { queries: { retry: false } } }); @@ -92,12 +97,15 @@ function setupMocks(overrides: { provider?: string } = {}) { codeEnabled: false, }); mockUseGetAgentsConfig.mockReturnValue({ agentsConfig: {} }); - mockUseFileHandling.mockReturnValue({ handleFileChange: jest.fn() }); - mockUseSharePointFileHandling.mockReturnValue({ + mockUseFileHandlingNoChatContext.mockReturnValue({ handleFileChange: jest.fn() }); + const sharePointReturnValue = { handleSharePointFiles: jest.fn(), isProcessing: false, downloadProgress: 0, - }); + error: null, + }; + mockUseSharePointFileHandling.mockReturnValue(sharePointReturnValue); + mockUseSharePointFileHandlingNoChatContext.mockReturnValue(sharePointReturnValue); mockUseGetStartupConfig.mockReturnValue({ data: { sharePointFilePickerEnabled: false } }); mockUseAgentToolPermissions.mockReturnValue({ fileSearchAllowedByAgent: false, @@ -110,7 +118,14 @@ function renderMenu(props: Record = {}) { return render( - + {}} + setFilesLoading={() => {}} + conversation={null} + {...props} + /> , ); diff --git a/client/src/components/Chat/Input/StopButton.tsx b/client/src/components/Chat/Input/StopButton.tsx index 4a058777f1..fd94ba806c 100644 --- 
a/client/src/components/Chat/Input/StopButton.tsx +++ b/client/src/components/Chat/Input/StopButton.tsx @@ -1,8 +1,15 @@ +import { memo } from 'react'; import { TooltipAnchor } from '@librechat/client'; import { useLocalize } from '~/hooks'; import { cn } from '~/utils'; -export default function StopButton({ stop, setShowStopButton }) { +export default memo(function StopButton({ + stop, + setShowStopButton, +}: { + stop: (e: React.MouseEvent) => void; + setShowStopButton: (value: boolean) => void; +}) { const localize = useLocalize(); return ( @@ -34,4 +41,4 @@ export default function StopButton({ stop, setShowStopButton }) { } > ); -} +}); diff --git a/client/src/components/Chat/Input/TextareaHeader.tsx b/client/src/components/Chat/Input/TextareaHeader.tsx index 9e67252efe..06c1802585 100644 --- a/client/src/components/Chat/Input/TextareaHeader.tsx +++ b/client/src/components/Chat/Input/TextareaHeader.tsx @@ -1,8 +1,9 @@ +import { memo } from 'react'; import AddedConvo from './AddedConvo'; import type { TConversation } from 'librechat-data-provider'; import type { SetterOrUpdater } from 'recoil'; -export default function TextareaHeader({ +export default memo(function TextareaHeader({ addedConvo, setAddedConvo, }: { @@ -17,4 +18,4 @@ export default function TextareaHeader({
); -} +}); diff --git a/client/src/components/Chat/Messages/Message.tsx b/client/src/components/Chat/Messages/Message.tsx index f9db38fdab..53aef812fc 100644 --- a/client/src/components/Chat/Messages/Message.tsx +++ b/client/src/components/Chat/Messages/Message.tsx @@ -1,5 +1,5 @@ import React from 'react'; -import { useMessageProcess } from '~/hooks'; +import { useMessageProcess, useMemoizedChatContext } from '~/hooks'; import type { TMessageProps } from '~/common'; import MessageRender from './ui/MessageRender'; import MultiMessage from './MultiMessage'; @@ -23,10 +23,11 @@ const MessageContainer = React.memo(function MessageContainer({ }); export default function Message(props: TMessageProps) { - const { conversation, handleScroll } = useMessageProcess({ + const { conversation, handleScroll, isSubmitting } = useMessageProcess({ message: props.message, }); const { message, currentEditId, setCurrentEditId } = props; + const { chatContext, effectiveIsSubmitting } = useMemoizedChatContext(message, isSubmitting); if (!message || typeof message !== 'object') { return null; @@ -38,7 +39,11 @@ export default function Message(props: TMessageProps) { <>
- +
; +/** + * Custom comparator for React.memo: compares `message` by key fields instead of reference + * because `buildTree` creates new message objects on every streaming update for ALL messages, + * even when only the latest message's text changed. + */ +function areMessageRenderPropsEqual(prev: MessageRenderProps, next: MessageRenderProps): boolean { + if (prev.isSubmitting !== next.isSubmitting) { + return false; + } + if (prev.chatContext !== next.chatContext) { + return false; + } + if (prev.siblingIdx !== next.siblingIdx) { + return false; + } + if (prev.siblingCount !== next.siblingCount) { + return false; + } + if (prev.currentEditId !== next.currentEditId) { + return false; + } + if (prev.setSiblingIdx !== next.setSiblingIdx) { + return false; + } + if (prev.setCurrentEditId !== next.setCurrentEditId) { + return false; + } + + const prevMsg = prev.message; + const nextMsg = next.message; + if (prevMsg === nextMsg) { + return true; + } + if (!prevMsg || !nextMsg) { + return prevMsg === nextMsg; + } + + return ( + prevMsg.messageId === nextMsg.messageId && + prevMsg.text === nextMsg.text && + prevMsg.error === nextMsg.error && + prevMsg.unfinished === nextMsg.unfinished && + prevMsg.depth === nextMsg.depth && + prevMsg.isCreatedByUser === nextMsg.isCreatedByUser && + (prevMsg.children?.length ?? 0) === (nextMsg.children?.length ?? 0) && + prevMsg.content === nextMsg.content && + prevMsg.model === nextMsg.model && + prevMsg.endpoint === nextMsg.endpoint && + prevMsg.iconURL === nextMsg.iconURL && + prevMsg.feedback?.rating === nextMsg.feedback?.rating && + (prevMsg.files?.length ?? 0) === (nextMsg.files?.length ?? 
0) + ); +} + const MessageRender = memo(function MessageRender({ message: msg, siblingIdx, @@ -31,6 +92,7 @@ const MessageRender = memo(function MessageRender({ currentEditId, setCurrentEditId, isSubmitting = false, + chatContext, }: MessageRenderProps) { const localize = useLocalize(); const { @@ -52,6 +114,7 @@ const MessageRender = memo(function MessageRender({ message: msg, currentEditId, setCurrentEditId, + chatContext, }); const fontSize = useAtomValue(fontSizeAtom); const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace); @@ -63,8 +126,6 @@ const MessageRender = memo(function MessageRender({ [hasNoChildren, msg?.depth, latestMessageDepth], ); const isLatestMessage = msg?.messageId === latestMessageId; - /** Only pass isSubmitting to the latest message to prevent unnecessary re-renders */ - const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false; const iconData: TMessageIcon = useMemo( () => ({ @@ -92,10 +153,10 @@ const MessageRender = memo(function MessageRender({ messageId, isLatestMessage, isExpanded: false as const, - isSubmitting: effectiveIsSubmitting, + isSubmitting, conversationId: conversation?.conversationId, }), - [messageId, conversation?.conversationId, effectiveIsSubmitting, isLatestMessage], + [messageId, conversation?.conversationId, isSubmitting, isLatestMessage], ); if (!msg) { @@ -165,7 +226,7 @@ const MessageRender = memo(function MessageRender({ message={msg} enterEdit={enterEdit} error={!!(msg.error ?? false)} - isSubmitting={effectiveIsSubmitting} + isSubmitting={isSubmitting} unfinished={msg.unfinished ?? false} isCreatedByUser={msg.isCreatedByUser ?? true} siblingIdx={siblingIdx ?? 0} @@ -173,7 +234,7 @@ const MessageRender = memo(function MessageRender({ />
- {hasNoChildren && effectiveIsSubmitting ? ( + {hasNoChildren && isSubmitting ? ( ) : ( @@ -187,7 +248,7 @@ const MessageRender = memo(function MessageRender({ isEditing={edit} message={msg} enterEdit={enterEdit} - isSubmitting={isSubmitting} + isSubmitting={chatContext.isSubmitting} conversation={conversation ?? null} regenerate={handleRegenerateMessage} copyToClipboard={copyToClipboard} @@ -202,7 +263,7 @@ const MessageRender = memo(function MessageRender({ ); -}); +}, areMessageRenderPropsEqual); MessageRender.displayName = 'MessageRender'; export default MessageRender; diff --git a/client/src/components/Messages/ContentRender.tsx b/client/src/components/Messages/ContentRender.tsx index 6b3f05ce5d..4ba8db36f8 100644 --- a/client/src/components/Messages/ContentRender.tsx +++ b/client/src/components/Messages/ContentRender.tsx @@ -2,7 +2,7 @@ import { useCallback, useMemo, memo } from 'react'; import { useAtomValue } from 'jotai'; import { useRecoilValue } from 'recoil'; import type { TMessage, TMessageContentParts } from 'librechat-data-provider'; -import type { TMessageProps, TMessageIcon } from '~/common'; +import type { TMessageProps, TMessageIcon, TMessageChatContext } from '~/common'; import { useAttachments, useLocalize, useMessageActions, useContentMetadata } from '~/hooks'; import { cn, getHeaderPrefixForScreenReader, getMessageAriaLabel } from '~/utils'; import ContentParts from '~/components/Chat/Messages/Content/ContentParts'; @@ -16,12 +16,72 @@ import store from '~/store'; type ContentRenderProps = { message?: TMessage; + /** + * Effective isSubmitting: false for non-latest messages, real value for latest. + * Computed by the wrapper (MessageContent.tsx) so this memo'd component only re-renders + * when the value actually matters. 
+ */ isSubmitting?: boolean; + /** Stable context object from wrapper โ€” avoids ChatContext subscription inside memo */ + chatContext: TMessageChatContext; } & Pick< TMessageProps, 'currentEditId' | 'setCurrentEditId' | 'siblingIdx' | 'setSiblingIdx' | 'siblingCount' >; +/** + * Custom comparator for React.memo: compares `message` by key fields instead of reference + * because `buildTree` creates new message objects on every streaming update for ALL messages. + */ +function areContentRenderPropsEqual(prev: ContentRenderProps, next: ContentRenderProps): boolean { + if (prev.isSubmitting !== next.isSubmitting) { + return false; + } + if (prev.chatContext !== next.chatContext) { + return false; + } + if (prev.siblingIdx !== next.siblingIdx) { + return false; + } + if (prev.siblingCount !== next.siblingCount) { + return false; + } + if (prev.currentEditId !== next.currentEditId) { + return false; + } + if (prev.setSiblingIdx !== next.setSiblingIdx) { + return false; + } + if (prev.setCurrentEditId !== next.setCurrentEditId) { + return false; + } + + const prevMsg = prev.message; + const nextMsg = next.message; + if (prevMsg === nextMsg) { + return true; + } + if (!prevMsg || !nextMsg) { + return prevMsg === nextMsg; + } + + return ( + prevMsg.messageId === nextMsg.messageId && + prevMsg.text === nextMsg.text && + prevMsg.error === nextMsg.error && + prevMsg.unfinished === nextMsg.unfinished && + prevMsg.depth === nextMsg.depth && + prevMsg.isCreatedByUser === nextMsg.isCreatedByUser && + (prevMsg.children?.length ?? 0) === (nextMsg.children?.length ?? 0) && + prevMsg.content === nextMsg.content && + prevMsg.model === nextMsg.model && + prevMsg.endpoint === nextMsg.endpoint && + prevMsg.iconURL === nextMsg.iconURL && + prevMsg.feedback?.rating === nextMsg.feedback?.rating && + (prevMsg.attachments?.length ?? 0) === (nextMsg.attachments?.length ?? 
0) + ); +} + const ContentRender = memo(function ContentRender({ message: msg, siblingIdx, @@ -30,6 +90,7 @@ const ContentRender = memo(function ContentRender({ currentEditId, setCurrentEditId, isSubmitting = false, + chatContext, }: ContentRenderProps) { const localize = useLocalize(); const { attachments, searchResults } = useAttachments({ @@ -55,6 +116,7 @@ const ContentRender = memo(function ContentRender({ searchResults, currentEditId, setCurrentEditId, + chatContext, }); const fontSize = useAtomValue(fontSizeAtom); const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace); @@ -66,8 +128,6 @@ const ContentRender = memo(function ContentRender({ ); const hasNoChildren = !(msg?.children?.length ?? 0); const isLatestMessage = msg?.messageId === latestMessageId; - /** Only pass isSubmitting to the latest message to prevent unnecessary re-renders */ - const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false; const iconData: TMessageIcon = useMemo( () => ({ @@ -158,13 +218,13 @@ const ContentRender = memo(function ContentRender({ searchResults={searchResults} setSiblingIdx={setSiblingIdx} isLatestMessage={isLatestMessage} - isSubmitting={effectiveIsSubmitting} + isSubmitting={isSubmitting} isCreatedByUser={msg.isCreatedByUser} conversationId={conversation?.conversationId} content={msg.content as Array} /> - {hasNoChildren && effectiveIsSubmitting ? ( + {hasNoChildren && isSubmitting ? ( ) : ( @@ -178,7 +238,7 @@ const ContentRender = memo(function ContentRender({ message={msg} isEditing={edit} enterEdit={enterEdit} - isSubmitting={isSubmitting} + isSubmitting={chatContext.isSubmitting} conversation={conversation ?? 
null} regenerate={handleRegenerateMessage} copyToClipboard={copyToClipboard} @@ -193,7 +253,7 @@ const ContentRender = memo(function ContentRender({ ); -}); +}, areContentRenderPropsEqual); ContentRender.displayName = 'ContentRender'; export default ContentRender; diff --git a/client/src/components/Messages/MessageContent.tsx b/client/src/components/Messages/MessageContent.tsx index 0e53b1c840..977e397022 100644 --- a/client/src/components/Messages/MessageContent.tsx +++ b/client/src/components/Messages/MessageContent.tsx @@ -1,5 +1,5 @@ import React from 'react'; -import { useMessageProcess } from '~/hooks'; +import { useMessageProcess, useMemoizedChatContext } from '~/hooks'; import type { TMessageProps } from '~/common'; import MultiMessage from '~/components/Chat/Messages/MultiMessage'; @@ -28,6 +28,7 @@ export default function MessageContent(props: TMessageProps) { message: props.message, }); const { message, currentEditId, setCurrentEditId } = props; + const { chatContext, effectiveIsSubmitting } = useMemoizedChatContext(message, isSubmitting); if (!message || typeof message !== 'object') { return null; @@ -39,7 +40,11 @@ export default function MessageContent(props: TMessageProps) { <>
- +
({ + conversation, + setConversation, + generateConversation, + }), + [conversation, setConversation, generateConversation], + ); } diff --git a/client/src/hooks/Files/index.ts b/client/src/hooks/Files/index.ts index df86c02a96..499572f0e0 100644 --- a/client/src/hooks/Files/index.ts +++ b/client/src/hooks/Files/index.ts @@ -1,6 +1,6 @@ export { default as useDeleteFilesFromTable } from './useDeleteFilesFromTable'; export { default as useSetFilesToDelete } from './useSetFilesToDelete'; -export { default as useFileHandling } from './useFileHandling'; +export { default as useFileHandling, useFileHandlingNoChatContext } from './useFileHandling'; export { default as useFileDeletion } from './useFileDeletion'; export { default as useUpdateFiles } from './useUpdateFiles'; export { default as useDragHelpers } from './useDragHelpers'; diff --git a/client/src/hooks/Messages/index.ts b/client/src/hooks/Messages/index.ts index a78a1ef553..439b7e152e 100644 --- a/client/src/hooks/Messages/index.ts +++ b/client/src/hooks/Messages/index.ts @@ -5,6 +5,7 @@ export { default as useSubmitMessage } from './useSubmitMessage'; export type { ContentMetadataResult } from './useContentMetadata'; export { default as useExpandCollapse } from './useExpandCollapse'; export { default as useMessageActions } from './useMessageActions'; +export { default as useMemoizedChatContext } from './useMemoizedChatContext'; export { default as useMessageProcess } from './useMessageProcess'; export { default as useMessageHelpers } from './useMessageHelpers'; export { default as useCopyToClipboard } from './useCopyToClipboard'; diff --git a/client/src/hooks/Messages/useMemoizedChatContext.ts b/client/src/hooks/Messages/useMemoizedChatContext.ts new file mode 100644 index 0000000000..aa35372a8e --- /dev/null +++ b/client/src/hooks/Messages/useMemoizedChatContext.ts @@ -0,0 +1,80 @@ +import { useRef, useMemo } from 'react'; +import type { TMessage } from 'librechat-data-provider'; +import type { 
TMessageChatContext } from '~/common/types'; +import { useChatContext } from '~/Providers'; + +/** + * Creates a stable `TMessageChatContext` object for memo'd message components. + * + * Subscribes to `useChatContext()` internally (intended to be called from non-memo'd + * wrapper components like `Message` and `MessageContent`), then produces: + * - A `chatContext` object that stays referentially stable during streaming + * (uses a getter for `isSubmitting` backed by a ref) + * - A stable `conversation` reference that only updates when rendering-relevant fields change + * - An `effectiveIsSubmitting` value (false for non-latest messages) + */ +export default function useMemoizedChatContext( + message: TMessage | null | undefined, + isSubmitting: boolean, +) { + const chatCtx = useChatContext(); + + const isSubmittingRef = useRef(isSubmitting); + isSubmittingRef.current = isSubmitting; + + /** + * Stabilize conversation: only update when rendering-relevant fields change, + * not on every metadata update (e.g., title generation). + */ + const stableConversation = useMemo( + () => chatCtx.conversation, + // eslint-disable-next-line react-hooks/exhaustive-deps + [ + chatCtx.conversation?.conversationId, + chatCtx.conversation?.endpoint, + chatCtx.conversation?.endpointType, + chatCtx.conversation?.model, + chatCtx.conversation?.agent_id, + chatCtx.conversation?.assistant_id, + ], + ); + + /** + * `isSubmitting` is included in deps so that chatContext gets a new reference + * when streaming starts/ends (2x per session). This ensures HoverButtons + * re-renders to update regenerate/edit button visibility via useGenerationsByLatest. + * The getter pattern is still valuable: callbacks reading chatContext.isSubmitting + * at call-time always get the current value even between these re-renders. 
+ */ + const chatContext: TMessageChatContext = useMemo( + () => ({ + ask: chatCtx.ask, + index: chatCtx.index, + regenerate: chatCtx.regenerate, + conversation: stableConversation, + latestMessageId: chatCtx.latestMessageId, + latestMessageDepth: chatCtx.latestMessageDepth, + handleContinue: chatCtx.handleContinue, + get isSubmitting() { + return isSubmittingRef.current; + }, + }), + // eslint-disable-next-line react-hooks/exhaustive-deps + [ + chatCtx.ask, + chatCtx.index, + chatCtx.regenerate, + stableConversation, + chatCtx.latestMessageId, + chatCtx.latestMessageDepth, + chatCtx.handleContinue, + isSubmitting, // intentional: forces new reference on streaming start/end so HoverButtons re-renders + ], + ); + + const messageId = message?.messageId ?? null; + const isLatestMessage = messageId === chatCtx.latestMessageId; + const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false; + + return { chatContext, effectiveIsSubmitting }; +} diff --git a/client/src/hooks/Messages/useMessageActions.tsx b/client/src/hooks/Messages/useMessageActions.tsx index e8946b895b..590ba6a40e 100644 --- a/client/src/hooks/Messages/useMessageActions.tsx +++ b/client/src/hooks/Messages/useMessageActions.tsx @@ -11,7 +11,8 @@ import { TUpdateFeedbackRequest, } from 'librechat-data-provider'; import type { TMessageProps } from '~/common'; -import { useChatContext, useAssistantsMapContext, useAgentsMapContext } from '~/Providers'; +import type { TMessageChatContext } from '~/common/types'; +import { useAssistantsMapContext, useAgentsMapContext } from '~/Providers'; import useCopyToClipboard from './useCopyToClipboard'; import { useAuthContext } from '~/hooks/AuthContext'; import { useGetAddedConvo } from '~/hooks/Chat'; @@ -23,24 +24,33 @@ export type TMessageActions = Pick< 'message' | 'currentEditId' | 'setCurrentEditId' > & { searchResults?: { [key: string]: SearchResultData }; + /** + * Stable context object passed from wrapper components to avoid subscribing + * to 
ChatContext inside memo'd components (which would bypass React.memo). + * The `isSubmitting` property uses a getter backed by a ref, so it always + * returns the current value at call-time without triggering re-renders. + */ + chatContext: TMessageChatContext; }; export default function useMessageActions(props: TMessageActions) { const localize = useLocalize(); const { user } = useAuthContext(); const UsernameDisplay = useRecoilValue(store.UsernameDisplay); - const { message, currentEditId, setCurrentEditId, searchResults } = props; + const { message, currentEditId, setCurrentEditId, searchResults, chatContext } = props; const { ask, index, regenerate, - isSubmitting, conversation, latestMessageId, latestMessageDepth, handleContinue, - } = useChatContext(); + // NOTE: isSubmitting is intentionally NOT destructured here. + // chatContext.isSubmitting is a getter backed by a ref โ€” destructuring + // would capture a one-time snapshot. Always access via chatContext.isSubmitting. + } = chatContext; const getAddedConvo = useGetAddedConvo(); @@ -98,13 +108,18 @@ export default function useMessageActions(props: TMessageActions) { } }, [agentsMap, conversation?.agent_id, conversation?.endpoint, message?.model]); + /** + * chatContext.isSubmitting is a getter backed by the wrapper's ref, + * so it always returns the current value at call-time โ€” even for + * non-latest messages that don't re-render during streaming. 
+ */ const regenerateMessage = useCallback(() => { - if ((isSubmitting && isCreatedByUser === true) || !message) { + if ((chatContext.isSubmitting && isCreatedByUser === true) || !message) { return; } regenerate(message, { addedConvo: getAddedConvo() }); - }, [isSubmitting, isCreatedByUser, message, regenerate, getAddedConvo]); + }, [chatContext, isCreatedByUser, message, regenerate, getAddedConvo]); const copyToClipboard = useCopyToClipboard({ text, content, searchResults }); diff --git a/client/src/hooks/Messages/useMessageProcess.tsx b/client/src/hooks/Messages/useMessageProcess.tsx index 37738b50a9..bb49670a2f 100644 --- a/client/src/hooks/Messages/useMessageProcess.tsx +++ b/client/src/hooks/Messages/useMessageProcess.tsx @@ -1,6 +1,6 @@ import throttle from 'lodash/throttle'; import { Constants } from 'librechat-data-provider'; -import { useEffect, useRef, useCallback, useMemo } from 'react'; +import { useEffect, useRef, useMemo } from 'react'; import type { TMessage } from 'librechat-data-provider'; import { getTextKey, TEXT_KEY_DIVIDER, logger } from '~/utils'; import { useMessagesViewContext } from '~/Providers'; @@ -56,24 +56,25 @@ export default function useMessageProcess({ message }: { message?: TMessage | nu } }, [hasNoChildren, message, setLatestMessage, conversation?.conversationId]); - const handleScroll = useCallback( - (event: unknown | TouchEvent | WheelEvent) => { - throttle(() => { + /** Use ref for isSubmitting to stabilize handleScroll across isSubmitting changes */ + const isSubmittingRef = useRef(isSubmitting); + isSubmittingRef.current = isSubmitting; + + const handleScroll = useMemo( + () => + throttle((event: unknown) => { logger.log( 'message_scrolling', - `useMessageProcess: setting abort scroll to ${isSubmitting}, handleScroll event`, + `useMessageProcess: setting abort scroll to ${isSubmittingRef.current}, handleScroll event`, event, ); - if (isSubmitting) { - setAbortScroll(true); - } else { - setAbortScroll(false); - } - }, 500)(); 
- }, - [isSubmitting, setAbortScroll], + setAbortScroll(isSubmittingRef.current); + }, 500), + [setAbortScroll], ); + useEffect(() => () => handleScroll.cancel(), [handleScroll]); + return { handleScroll, isSubmitting,