From 877c2efc851342e876250e1b6aebb8e471a09f7a Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 28 Mar 2026 16:43:50 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=8F=97=EF=B8=8F=20feat:=20bulkWrite=20iso?= =?UTF-8?q?lation,=20pre-auth=20context,=20strict-mode=20fixes=20(#12445)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: wrap seedDatabase() in runAsSystem() for strict tenant mode seedDatabase() was called without tenant context at startup, causing every Mongoose operation inside it to throw when TENANT_ISOLATION_STRICT=true. Wrapping in runAsSystem() gives it the SYSTEM_TENANT_ID sentinel so the isolation plugin skips filtering, matching the pattern already used for performStartupChecks and updateInterfacePermissions. * fix: chain tenantContextMiddleware in optionalJwtAuth optionalJwtAuth populated req.user but never established ALS tenant context, unlike requireJwtAuth which chains tenantContextMiddleware after successful auth. Authenticated users hitting routes with optionalJwtAuth (e.g. /api/banner) had no tenant isolation. * feat: tenant-safe bulkWrite wrapper and call-site migration Mongoose's bulkWrite() does not trigger schema-level middleware hooks, so the applyTenantIsolation plugin cannot intercept it. This adds a tenantSafeBulkWrite() utility that injects the current ALS tenant context into every operation's filter/document before delegating to native bulkWrite. Migrates all 8 runtime bulkWrite call sites: - agentCategory (seedCategories, ensureDefaultCategories) - conversation (bulkSaveConvos) - message (bulkSaveMessages) - file (batchUpdateFiles) - conversationTag (updateTagsForConversation, bulkIncrementTagCounts) - aclEntry (bulkWriteAclEntries) systemGrant.seedSystemGrants is intentionally not migrated — it uses explicit tenantId: { $exists: false } filters and is exempt from the isolation plugin. * feat: pre-auth tenant middleware and tenant-scoped config cache Adds preAuthTenantMiddleware that reads X-Tenant-Id from the request header and wraps downstream in tenantStorage ALS context. Wired onto /oauth, /api/auth, /api/config, and /api/share — unauthenticated routes that need tenant scoping before JWT auth runs. The /api/config cache key is now tenant-scoped (STARTUP_CONFIG:${tenantId}) so multi-tenant deployments serve the correct login page config per tenant. The middleware is intentionally minimal — no subdomain parsing, no OIDC claim extraction. The private fork's reverse proxy or auth gateway sets the header. * feat: accept optional tenantId in updateInterfacePermissions When tenantId is provided, the function re-enters inside tenantStorage.run({ tenantId }) so all downstream Mongoose queries target that tenant's roles instead of the system context. This lets the private fork's tenant provisioning flow call updateInterfacePermissions per-tenant after creating tenant-scoped ADMIN/USER roles. * fix: tenant-filter $lookup in getPromptGroup aggregation The $lookup stage in getPromptGroup() queried the prompts collection without tenant filtering. While the outer PromptGroup aggregate is protected by the tenantIsolation plugin's pre('aggregate') hook, $lookup runs as an internal MongoDB operation that bypasses Mongoose hooks entirely. Converts from simple field-based $lookup to pipeline-based $lookup with an explicit tenantId match when tenant context is active. * fix: replace field-level unique indexes with tenant-scoped compounds Field-level unique:true creates a globally-unique single-field index in MongoDB, which would cause insert failures across tenants sharing the same ID values. - agent.id: removed field-level unique, added { id, tenantId } compound - convo.conversationId: removed field-level unique (compound at line 50 already exists: { conversationId, user, tenantId }) - message.messageId: removed field-level unique (compound at line 165 already exists: { messageId, user, tenantId }) - preset.presetId: removed field-level unique, added { presetId, tenantId } compound * fix: scope MODELS_CONFIG, ENDPOINT_CONFIG, PLUGINS, TOOLS caches by tenant These caches store per-tenant configuration (available models, endpoint settings, plugin availability, tool definitions) but were using global cache keys. In multi-tenant mode, one tenant's cached config would be served to all tenants. Appends :${tenantId} to cache keys when tenant context is active. Falls back to the unscoped key when no tenant context exists (backward compatible for single-tenant OSS deployments). Covers all read, write, and delete sites: - ModelController.js: get/set MODELS_CONFIG - PluginController.js: get/set PLUGINS, get/set TOOLS - getEndpointsConfig.js: get/set/delete ENDPOINT_CONFIG - app.js: delete ENDPOINT_CONFIG (clearEndpointConfigCache) - mcp.js: delete TOOLS (updateMCPTools, mergeAppTools) - importers.js: get ENDPOINT_CONFIG * fix: add getTenantId to PluginController spec mock The data-schemas mock was missing getTenantId, causing all PluginController tests to throw when the controller calls getTenantId() for tenant-scoped cache keys. * fix: address review findings — migration, strict-mode, DRY, types Addresses all CRITICAL, MAJOR, and MINOR review findings: F1 (CRITICAL): Add agents, conversations, messages, presets to SUPERSEDED_INDEXES in tenantIndexes.ts so dropSupersededTenantIndexes() drops the old single-field unique indexes that block multi-tenant inserts. F2 (CRITICAL): Unknown bulkWrite op types now throw in strict mode instead of silently passing through without tenant injection. F3 (MAJOR): Replace wildcard export with named export for tenantSafeBulkWrite, hiding _resetBulkWriteStrictCache from the public package API. F5 (MAJOR): Restore AnyBulkWriteOperation[] typing on bulkWriteAclEntries — the unparameterized wrapper accepts parameterized ops as a subtype. F7 (MAJOR): Fix config.js tenant precedence — JWT-derived req.user.tenantId now takes priority over the X-Tenant-Id header for authenticated requests. F8 (MINOR): Extract scopedCacheKey() helper into tenantContext.ts and replace all 11 inline occurrences across 7 files. F9 (MINOR): Use simple localField/foreignField $lookup for the non-tenant getPromptGroup path (more efficient index seeks). F12 (NIT): Remove redundant BulkOp type alias. F13 (NIT): Remove debug log that leaked raw tenantId. * fix: add new superseded indexes to tenantIndexes test fixture The test creates old indexes to verify the migration drops them. Missing fixture entries for agents.id_1, conversations.conversationId_1, messages.messageId_1, and presets.presetId_1 caused the count assertion to fail (expected 22, got 18). * fix: restore logger.warn for unknown bulk op types in non-strict mode * fix: block SYSTEM_TENANT_ID sentinel from external header input CRITICAL: preAuthTenantMiddleware accepted any string as X-Tenant-Id, including '__SYSTEM__'. The tenantIsolation plugin treats SYSTEM_TENANT_ID as an explicit bypass — skipping ALL query filters. A client sending X-Tenant-Id: __SYSTEM__ to pre-auth routes (/api/share, /api/config, /api/auth, /oauth) would execute Mongoose operations without tenant isolation. Fixes: - preAuthTenantMiddleware rejects SYSTEM_TENANT_ID in header - scopedCacheKey returns the base key (not key:__SYSTEM__) in system context, preventing stale cache entries during runAsSystem() - updateInterfacePermissions guards tenantId against SYSTEM_TENANT_ID - $lookup pipeline separates $expr join from constant tenantId match for better index utilization - Regression test for sentinel rejection in preAuthTenant.spec.ts - Remove redundant getTenantId() call in config.js * test: add missing deleteMany/replaceOne coverage, fix vacuous ALS assertions bulkWrite spec: - deleteMany: verifies tenant-scoped deletion leaves other tenants untouched - replaceOne: verifies tenantId injected into both filter and replacement - replaceOne overwrite: verifies a conflicting tenantId in the replacement document is overwritten by the ALS tenant (defense-in-depth) - empty ops array: verifies graceful handling preAuthTenant spec: - All negative-case tests now use the capturedNext pattern to verify getTenantId() inside the middleware's execution context, not the test runner's outer frame (which was always undefined regardless) * feat: tenant-isolate MESSAGES cache, FLOWS cache, and GenerationJobManager MESSAGES cache (streamAudio.js): - Cache key now uses scopedCacheKey(messageId) to prefix with tenantId, preventing cross-tenant message content reads during TTS streaming. FLOWS cache (FlowStateManager): - getFlowKey() now generates ${type}:${tenantId}:${flowId} when tenant context is active, isolating OAuth flow state per tenant. GenerationJobManager: - tenantId added to SerializableJobData and GenerationJobMetadata - createJob() captures the current ALS tenant context (excluding SYSTEM_TENANT_ID) and stores it in job metadata - SSE subscription endpoint validates job.metadata.tenantId matches req.user.tenantId, blocking cross-tenant stream access - Both InMemoryJobStore and RedisJobStore updated to accept tenantId * fix: add getTenantId and SYSTEM_TENANT_ID to MCP OAuth test mocks FlowStateManager.getFlowKey() now calls getTenantId() for tenant-scoped flow keys. The 4 MCP OAuth test files mock @librechat/data-schemas without these exports, causing TypeError at runtime. * fix: correct import ordering per AGENTS.md conventions Package imports sorted shortest to longest line length, local imports sorted longest to shortest — fixes ordering violations introduced by our new imports across 8 files. * fix: deserialize tenantId in RedisJobStore — cross-tenant SSE guard was no-op in Redis mode serializeJob() writes tenantId to the Redis hash via Object.entries, but deserializeJob() manually enumerates fields and omitted tenantId. Every getJob() from Redis returned tenantId: undefined, causing the SSE route's cross-tenant guard to short-circuit (undefined && ... → false). * test: SSE tenant guard, FlowStateManager key consistency, ALS scope docs SSE stream tenant tests (streamTenant.spec.js): - Cross-tenant user accessing another tenant's stream → 403 - Same-tenant user accessing own stream → allowed - OSS mode (no tenantId on job) → tenant check skipped FlowStateManager tenant tests (manager.tenant.spec.ts): - completeFlow finds flow created under same tenant context - completeFlow does NOT find flow under different tenant context - Unscoped flows are separate from tenant-scoped flows Documentation: - JSDoc on getFlowKey documenting ALS context consistency requirement - Comment on streamAudio.js scopedCacheKey capture site * fix: SSE stream tests hang on success path, remove internal fork references The success-path tests entered the SSE streaming code which never closes, causing timeout. Mock subscribe() to end the response immediately. Restructured assertions to verify non-403/non-404. Removed "private fork" and "OSS" references from code and test descriptions — replaced with "deployment layer", "multi-tenant deployments", and "single-tenant mode". * fix: address review findings — test rigor, tenant ID validation, docs F1: SSE stream tests now mock subscribe() with correct signature (streamId, writeEvent, onDone, onError) and assert 200 status, verifying the tenant guard actually allows through same-tenant users. F2: completeFlow logs the attempted key and ALS tenantId when flow is not found, so reverse proxy misconfiguration (missing X-Tenant-Id on OAuth callback) produces an actionable warning. F3/F10: preAuthTenantMiddleware validates tenant ID format — rejects colons, special characters, and values exceeding 128 chars. Trims whitespace. Prevents cache key collisions via crafted headers. F4: Documented cache invalidation scope limitation in clearEndpointConfigCache — only the calling tenant's key is cleared; other tenants expire via TTL. F7: getFlowKey JSDoc now lists all 8 methods requiring consistent ALS context. F8: Added dedicated scopedCacheKey unit tests — base key without context, base key in system context, scoped key with tenant, no ALS leakage across scope boundaries. * fix: revert flow key tenant scoping, fix SSE test timing FlowStateManager: Reverts tenant-scoped flow keys. OAuth callbacks arrive without tenant ALS context (provider redirects don't carry X-Tenant-Id), so completeFlow/failFlow would never find flows created under tenant context. Flow IDs are random UUIDs with no collision risk, and flow data is ephemeral (TTL-bounded). SSE tests: Use process.nextTick for onDone callback so Express response headers are flushed before res.write/res.end are called. * fix: restore getTenantId import for completeFlow diagnostic log * fix: correct completeFlow warning message, add missing flow test The warning referenced X-Tenant-Id header consistency which was only relevant when flow keys were tenant-scoped (since reverted). Updated to list actual causes: TTL expiry, missing flow, or routing to a different instance without shared Keyv storage. Removed the getTenantId() call and import — no longer needed since flow keys are unscoped. Added test for the !flowState branch in completeFlow — verifies return false and logger.warn on nonexistent flow ID. * fix: add explicit return type to recursive updateInterfacePermissions The recursive call (tenantId branch calls itself without tenantId) causes TypeScript to infer circular return type 'any'. Adding explicit Promise satisfies the rollup typescript plugin. * fix: update MCPOAuthRaceCondition test to match new completeFlow warning * fix: clearEndpointConfigCache deletes both scoped and unscoped keys Unauthenticated /api/endpoints requests populate the unscoped ENDPOINT_CONFIG key. Admin config mutations clear only the tenant-scoped key, leaving the unscoped entry stale indefinitely. Now deletes both when in tenant context. * fix: tenant guard on abort/status endpoints, warn logs, test coverage F1: Add tenant guard to /chat/status/:conversationId and /chat/abort matching the existing guard on /chat/stream/:streamId. The status endpoint exposes aggregatedContent (AI response text) which requires tenant-level access control. F2: preAuthTenantMiddleware now logs warn for rejected __SYSTEM__ sentinel and malformed tenant IDs, providing observability for bypass probing attempts. F3: Abort fallback path (getActiveJobIdsForUser) now has tenant check after resolving the job. F4: Test for strict mode + SYSTEM_TENANT_ID — verifies runAsSystem bypasses tenantSafeBulkWrite without throwing in strict mode. F5: Test for job with tenantId + user without tenantId → 403. F10: Regex uses idiomatic hyphen-at-start form. F11: Test descriptions changed from "rejects" to "ignores" since middleware calls next() (not 4xx). Also fixes MCPOAuthRaceCondition test assertion to match updated completeFlow warning message. * fix: test coverage for logger.warn, status/abort guards, consistency A: preAuthTenant spec now mocks logger and asserts warn calls for __SYSTEM__ sentinel, malformed characters, and oversized headers. B: streamTenant spec expanded with status and abort endpoint tests — cross-tenant status returns 403, same-tenant returns 200 with body, cross-tenant abort returns 403. C: Abort endpoint uses req.user.tenantId (not req.user?.tenantId) matching stream/status pattern — requireJwtAuth guarantees req.user. D: Malformed header warning now includes ip in log metadata, matching the sentinel warning for consistent SOC correlation. * fix: assert ip field in malformed header warn tests * fix: parallelize cache deletes, document tenant guard, fix import order - clearEndpointConfigCache uses Promise.all for independent cache deletes instead of sequential awaits - SSE stream tenant guard has inline comment explaining backward-compat behavior for untenanted legacy jobs - conversation.ts local imports reordered longest-to-shortest per AGENTS.md * fix: tenant-qualify userJobs keys, document tenant guard backward-compat Job store userJobs keys now include tenantId when available: - Redis: stream:user:{tenantId:userId}:jobs (falls back to stream:user:{userId}:jobs when no tenant) - InMemory: composite key tenantId:userId in userJobMap getActiveJobIdsByUser/getActiveJobIdsForUser accept optional tenantId parameter, threaded through from req.user.tenantId at all call sites (/chat/active and /chat/abort fallback). Added inline comments on all three SSE tenant guards explaining the backward-compat design: untenanted legacy jobs remain accessible when the userId check passes. * fix: parallelize cache deletes, document tenant guard, fix import order Fix InMemoryJobStore.getActiveJobIdsByUser empty-set cleanup to use the tenant-qualified userKey instead of bare userId — prevents orphaned empty Sets accumulating in userJobMap for multi-tenant users. Document cross-tenant staleness in clearEndpointConfigCache JSDoc — other tenants' scoped keys expire via TTL, not active invalidation. * fix: cleanup userJobMap leak, startup warning, DRY tenant guard, docs F1: InMemoryJobStore.cleanup() now removes entries from userJobMap before calling deleteJob, preventing orphaned empty Sets from accumulating with tenant-qualified composite keys. F2: Startup warning when TENANT_ISOLATION_STRICT is active — reminds operators to configure reverse proxy to control X-Tenant-Id header. F3: mergeAppTools JSDoc documents that tenant-scoped TOOLS keys are not actively invalidated (matching clearEndpointConfigCache pattern). F5: Abort handler getActiveJobIdsForUser call uses req.user.tenantId (not req.user?.tenantId) — consistent with stream/status handlers. F6: updateInterfacePermissions JSDoc clarifies SYSTEM_TENANT_ID behavior — falls through to caller's ALS context. F7: Extracted hasTenantMismatch() helper, replacing three identical inline tenant guard blocks across stream/status/abort endpoints. F9: scopedCacheKey JSDoc documents both passthrough cases (no context and SYSTEM_TENANT_ID context). * fix: clean userJobMap in evictOldest — same leak as cleanup() --- api/server/controllers/ModelController.js | 10 +- api/server/controllers/PluginController.js | 12 +- .../controllers/PluginController.spec.js | 1 + api/server/index.js | 20 +- api/server/middleware/optionalJwtAuth.js | 6 +- .../agents/__tests__/streamTenant.spec.js | 186 +++++++++ api/server/routes/agents/index.js | 27 +- api/server/routes/config.js | 12 +- api/server/services/Config/app.js | 18 +- .../services/Config/getEndpointsConfig.js | 8 +- api/server/services/Config/mcp.js | 11 +- .../services/Files/Audio/streamAudio.js | 7 +- api/server/utils/import/importers.js | 4 +- packages/api/src/app/permissions.ts | 17 +- packages/api/src/flow/manager.tenant.spec.ts | 49 +++ packages/api/src/flow/manager.ts | 10 +- .../__tests__/MCPOAuthCSRFFallback.test.ts | 2 + .../src/mcp/__tests__/MCPOAuthFlow.test.ts | 2 + .../__tests__/MCPOAuthRaceCondition.test.ts | 4 +- .../mcp/__tests__/MCPOAuthTokenExpiry.test.ts | 2 + packages/api/src/middleware/index.ts | 1 + .../api/src/middleware/preAuthTenant.spec.ts | 129 ++++++ packages/api/src/middleware/preAuthTenant.ts | 72 ++++ .../api/src/stream/GenerationJobManager.ts | 11 +- .../implementations/InMemoryJobStore.ts | 38 +- .../stream/implementations/RedisJobStore.ts | 14 +- .../api/src/stream/interfaces/IJobStore.ts | 4 +- packages/api/src/types/stream.ts | 1 + .../src/config/tenantContext.spec.ts | 26 ++ .../data-schemas/src/config/tenantContext.ts | 13 + packages/data-schemas/src/index.ts | 8 +- packages/data-schemas/src/methods/aclEntry.ts | 7 +- .../data-schemas/src/methods/agentCategory.ts | 5 +- .../data-schemas/src/methods/conversation.ts | 5 +- .../src/methods/conversationTag.ts | 5 +- packages/data-schemas/src/methods/file.ts | 3 +- packages/data-schemas/src/methods/message.ts | 3 +- packages/data-schemas/src/methods/prompt.ts | 40 +- .../src/migrations/tenantIndexes.spec.ts | 14 + .../src/migrations/tenantIndexes.ts | 6 +- packages/data-schemas/src/schema/agent.ts | 3 +- packages/data-schemas/src/schema/convo.ts | 1 - packages/data-schemas/src/schema/message.ts | 1 - packages/data-schemas/src/schema/preset.ts | 3 +- packages/data-schemas/src/utils/index.ts | 1 + .../src/utils/tenantBulkWrite.spec.ts | 376 ++++++++++++++++++ .../data-schemas/src/utils/tenantBulkWrite.ts | 109 +++++ 47 files changed, 1224 insertions(+), 83 deletions(-) create mode 100644 api/server/routes/agents/__tests__/streamTenant.spec.js create mode 100644 packages/api/src/flow/manager.tenant.spec.ts create mode 100644 packages/api/src/middleware/preAuthTenant.spec.ts create mode 100644 packages/api/src/middleware/preAuthTenant.ts create mode 100644 packages/data-schemas/src/config/tenantContext.spec.ts create mode 100644 packages/data-schemas/src/utils/tenantBulkWrite.spec.ts create mode 100644 packages/data-schemas/src/utils/tenantBulkWrite.ts diff --git a/api/server/controllers/ModelController.js b/api/server/controllers/ModelController.js index 805d9eef27..1741d3f6b1 100644 --- a/api/server/controllers/ModelController.js +++ b/api/server/controllers/ModelController.js @@ -1,5 +1,5 @@ -const { logger } = require('@librechat/data-schemas'); const { CacheKeys } = require('librechat-data-provider'); +const { logger, scopedCacheKey } = require('@librechat/data-schemas'); const { loadDefaultModels, loadConfigModels } = require('~/server/services/Config'); const { getLogStores } = require('~/cache'); @@ -9,7 +9,8 @@ const { getLogStores } = require('~/cache'); */ const getModelsConfig = async (req) => { const cache = getLogStores(CacheKeys.CONFIG_STORE); - let modelsConfig = await cache.get(CacheKeys.MODELS_CONFIG); + const cacheKey = scopedCacheKey(CacheKeys.MODELS_CONFIG); + let modelsConfig = await cache.get(cacheKey); if (!modelsConfig) { modelsConfig = await loadModels(req); } @@ -24,7 +25,8 @@ const getModelsConfig = async (req) => { */ async function loadModels(req) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const cachedModelsConfig = await cache.get(CacheKeys.MODELS_CONFIG); + const cacheKey = scopedCacheKey(CacheKeys.MODELS_CONFIG); + const cachedModelsConfig = await cache.get(cacheKey); if (cachedModelsConfig) { return cachedModelsConfig; } @@ -33,7 +35,7 @@ async function loadModels(req) { const modelConfig = { ...defaultModelsConfig, ...customModelsConfig }; - await cache.set(CacheKeys.MODELS_CONFIG, modelConfig); + await cache.set(cacheKey, modelConfig); return modelConfig; } diff --git a/api/server/controllers/PluginController.js b/api/server/controllers/PluginController.js index 14dd284c30..7c47fe4d57 100644 --- a/api/server/controllers/PluginController.js +++ b/api/server/controllers/PluginController.js @@ -1,5 +1,5 @@ -const { logger } = require('@librechat/data-schemas'); const { CacheKeys } = require('librechat-data-provider'); +const { logger, scopedCacheKey } = require('@librechat/data-schemas'); const { getToolkitKey, checkPluginAuth, filterUniquePlugins } = require('@librechat/api'); const { getCachedTools, setCachedTools } = require('~/server/services/Config'); const { availableTools, toolkits } = require('~/app/clients/tools'); @@ -9,7 +9,8 @@ const { getLogStores } = require('~/cache'); const getAvailablePluginsController = async (req, res) => { try { const cache = getLogStores(CacheKeys.TOOL_CACHE); - const cachedPlugins = await cache.get(CacheKeys.PLUGINS); + const pluginsCacheKey = scopedCacheKey(CacheKeys.PLUGINS); + const cachedPlugins = await cache.get(pluginsCacheKey); if (cachedPlugins) { res.status(200).json(cachedPlugins); return; @@ -37,7 +38,7 @@ const getAvailablePluginsController = async (req, res) => { plugins = plugins.filter((plugin) => !filteredTools.includes(plugin.pluginKey)); } - await cache.set(CacheKeys.PLUGINS, plugins); + await cache.set(pluginsCacheKey, plugins); res.status(200).json(plugins); } catch (error) { res.status(500).json({ message: error.message }); @@ -64,7 +65,8 @@ const getAvailableTools = async (req, res) => { return res.status(401).json({ message: 'Unauthorized' }); } const cache = getLogStores(CacheKeys.TOOL_CACHE); - const cachedToolsArray = await cache.get(CacheKeys.TOOLS); + const toolsCacheKey = scopedCacheKey(CacheKeys.TOOLS); + const cachedToolsArray = await cache.get(toolsCacheKey); const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId })); @@ -115,7 +117,7 @@ const getAvailableTools = async (req, res) => { } const finalTools = filterUniquePlugins(toolsOutput); - await cache.set(CacheKeys.TOOLS, finalTools); + await cache.set(toolsCacheKey, finalTools); res.status(200).json(finalTools); } catch (error) { diff --git a/api/server/controllers/PluginController.spec.js b/api/server/controllers/PluginController.spec.js index 06a51a3bd6..fdbc2401ce 100644 --- a/api/server/controllers/PluginController.spec.js +++ b/api/server/controllers/PluginController.spec.js @@ -8,6 +8,7 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), warn: jest.fn(), }, + scopedCacheKey: jest.fn((key) => key), })); jest.mock('~/server/services/Config', () => ({ diff --git a/api/server/index.js b/api/server/index.js index 813b453468..4b919b1ceb 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -21,6 +21,7 @@ const { createStreamServices, initializeFileStorage, updateInterfacePermissions, + preAuthTenantMiddleware, } = require('@librechat/api'); const { connectDb, indexSync } = require('~/db'); const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager'); @@ -59,7 +60,14 @@ const startServer = async () => { app.disable('x-powered-by'); app.set('trust proxy', trusted_proxy); - await seedDatabase(); + if (isEnabled(process.env.TENANT_ISOLATION_STRICT)) { + logger.warn( + '[Security] TENANT_ISOLATION_STRICT is active. Ensure your reverse proxy strips or sets ' + + 'the X-Tenant-Id header — untrusted clients must not be able to set it directly.', + ); + } + + await runAsSystem(seedDatabase); const appConfig = await getAppConfig({ baseOnly: true }); initializeFileStorage(appConfig); await runAsSystem(async () => { @@ -139,9 +147,11 @@ const startServer = async () => { /* Per-request capability cache — must be registered before any route that calls hasCapability */ app.use(capabilityContextMiddleware); - app.use('/oauth', routes.oauth); + /* Pre-auth tenant context for unauthenticated routes that need tenant scoping. + * The reverse proxy / auth gateway sets `X-Tenant-Id` header for multi-tenant deployments. */ + app.use('/oauth', preAuthTenantMiddleware, routes.oauth); /* API Endpoints */ - app.use('/api/auth', routes.auth); + app.use('/api/auth', preAuthTenantMiddleware, routes.auth); app.use('/api/admin', routes.adminAuth); app.use('/api/admin/config', routes.adminConfig); app.use('/api/admin/groups', routes.adminGroups); @@ -159,11 +169,11 @@ const startServer = async () => { app.use('/api/endpoints', routes.endpoints); app.use('/api/balance', routes.balance); app.use('/api/models', routes.models); - app.use('/api/config', routes.config); + app.use('/api/config', preAuthTenantMiddleware, routes.config); app.use('/api/assistants', routes.assistants); app.use('/api/files', await routes.files.initialize()); app.use('/images/', createValidateImageRequest(appConfig.secureImageLinks), routes.staticRoute); - app.use('/api/share', routes.share); + app.use('/api/share', preAuthTenantMiddleware, routes.share); app.use('/api/roles', routes.roles); app.use('/api/agents', routes.agents); app.use('/api/banner', routes.banner); diff --git a/api/server/middleware/optionalJwtAuth.js b/api/server/middleware/optionalJwtAuth.js index 2f59fdda4a..d46478d36e 100644 --- a/api/server/middleware/optionalJwtAuth.js +++ b/api/server/middleware/optionalJwtAuth.js @@ -1,9 +1,10 @@ const cookies = require('cookie'); const passport = require('passport'); -const { isEnabled } = require('@librechat/api'); +const { isEnabled, tenantContextMiddleware } = require('@librechat/api'); // This middleware does not require authentication, -// but if the user is authenticated, it will set the user object. +// but if the user is authenticated, it will set the user object +// and establish tenant ALS context. const optionalJwtAuth = (req, res, next) => { const cookieHeader = req.headers.cookie; const tokenProvider = cookieHeader ? cookies.parse(cookieHeader).token_provider : null; @@ -13,6 +14,7 @@ const optionalJwtAuth = (req, res, next) => { } if (user) { req.user = user; + return tenantContextMiddleware(req, res, next); } next(); }; diff --git a/api/server/routes/agents/__tests__/streamTenant.spec.js b/api/server/routes/agents/__tests__/streamTenant.spec.js new file mode 100644 index 0000000000..1f89953186 --- /dev/null +++ b/api/server/routes/agents/__tests__/streamTenant.spec.js @@ -0,0 +1,186 @@ +const express = require('express'); +const request = require('supertest'); + +const mockGenerationJobManager = { + getJob: jest.fn(), + subscribe: jest.fn(), + getResumeState: jest.fn(), + abortJob: jest.fn(), + getActiveJobIdsForUser: jest.fn().mockResolvedValue([]), +}; + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { + debug: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + info: jest.fn(), + }, +})); + +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + isEnabled: jest.fn().mockReturnValue(false), + GenerationJobManager: mockGenerationJobManager, +})); + +jest.mock('~/models', () => ({ + saveMessage: jest.fn(), +})); + +let mockUserId = 'user-123'; +let mockTenantId; + +jest.mock('~/server/middleware', () => ({ + uaParser: (req, res, next) => next(), + checkBan: (req, res, next) => next(), + requireJwtAuth: (req, res, next) => { + req.user = { id: mockUserId, tenantId: mockTenantId }; + next(); + }, + messageIpLimiter: (req, res, next) => next(), + configMiddleware: (req, res, next) => next(), + messageUserLimiter: (req, res, next) => next(), +})); + +jest.mock('~/server/routes/agents/chat', () => require('express').Router()); +jest.mock('~/server/routes/agents/v1', () => ({ + v1: require('express').Router(), +})); +jest.mock('~/server/routes/agents/openai', () => require('express').Router()); +jest.mock('~/server/routes/agents/responses', () => require('express').Router()); + +const agentsRouter = require('../index'); +const app = express(); +app.use(express.json()); +app.use('/agents', agentsRouter); + +function mockSubscribeSuccess() { + mockGenerationJobManager.subscribe.mockImplementation((_streamId, _writeEvent, onDone) => { + process.nextTick(() => onDone({ done: true })); + return { unsubscribe: jest.fn() }; + }); +} + +describe('SSE stream tenant isolation', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockUserId = 'user-123'; + mockTenantId = undefined; + }); + + describe('GET /chat/stream/:streamId', () => { + it('returns 403 when a user from a different tenant accesses a stream', async () => { + mockUserId = 'user-456'; + mockTenantId = 'tenant-b'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-456', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(403); + expect(res.body.error).toBe('Unauthorized'); + }); + + it('returns 404 when stream does not exist', async () => { + mockGenerationJobManager.getJob.mockResolvedValue(null); + + const res = await request(app).get('/agents/chat/stream/nonexistent'); + expect(res.status).toBe(404); + }); + + it('proceeds past tenant guard when tenant matches', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-a'; + mockSubscribeSuccess(); + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(200); + expect(mockGenerationJobManager.subscribe).toHaveBeenCalledTimes(1); + }); + + it('proceeds past tenant guard when job has no tenantId (single-tenant mode)', async () => { + mockUserId = 'user-123'; + mockTenantId = undefined; + mockSubscribeSuccess(); + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(200); + expect(mockGenerationJobManager.subscribe).toHaveBeenCalledTimes(1); + }); + + it('returns 403 when job has tenantId but user has no tenantId', async () => { + mockUserId = 'user-123'; + mockTenantId = undefined; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'some-tenant' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(403); + }); + }); + + describe('GET /chat/status/:conversationId', () => { + it('returns 403 when tenant does not match', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-b'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/status/conv-123'); + expect(res.status).toBe(403); + expect(res.body.error).toBe('Unauthorized'); + }); + + it('returns status when tenant matches', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-a'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + createdAt: Date.now(), + }); + mockGenerationJobManager.getResumeState.mockResolvedValue(null); + + const res = await request(app).get('/agents/chat/status/conv-123'); + expect(res.status).toBe(200); + expect(res.body.active).toBe(true); + }); + }); + + describe('POST /chat/abort', () => { + it('returns 403 when tenant does not match', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-b'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).post('/agents/chat/abort').send({ streamId: 'stream-123' }); + expect(res.status).toBe(403); + expect(res.body.error).toBe('Unauthorized'); + }); + }); +}); diff --git a/api/server/routes/agents/index.js b/api/server/routes/agents/index.js index 86966a3f3e..eb42046bed 100644 --- a/api/server/routes/agents/index.js +++ b/api/server/routes/agents/index.js @@ -17,6 +17,11 @@ const chat = require('./chat'); const { LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {}; +/** Untenanted jobs (pre-multi-tenancy) remain accessible if the userId check passes. */ +function hasTenantMismatch(job, user) { + return job.metadata?.tenantId != null && job.metadata.tenantId !== user.tenantId; +} + const router = express.Router(); /** @@ -67,6 +72,10 @@ router.get('/chat/stream/:streamId', async (req, res) => { return res.status(403).json({ error: 'Unauthorized' }); } + if (hasTenantMismatch(job, req.user)) { + return res.status(403).json({ error: 'Unauthorized' }); + } + res.setHeader('Content-Encoding', 'identity'); res.setHeader('Content-Type', 'text/event-stream'); res.setHeader('Cache-Control', 'no-cache, no-transform'); @@ -150,7 +159,10 @@ router.get('/chat/stream/:streamId', async (req, res) => { * @returns { activeJobIds: string[] } */ router.get('/chat/active', async (req, res) => { - const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(req.user.id); + const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser( + req.user.id, + req.user.tenantId, + ); res.json({ activeJobIds }); }); @@ -174,6 +186,10 @@ router.get('/chat/status/:conversationId', async (req, res) => { return res.status(403).json({ error: 'Unauthorized' }); } + if (hasTenantMismatch(job, req.user)) { + return res.status(403).json({ error: 'Unauthorized' }); + } + // Get resume state which contains aggregatedContent // Avoid calling both getStreamInfo and getResumeState (both fetch content) const resumeState = await GenerationJobManager.getResumeState(conversationId); @@ -213,7 +229,10 @@ router.post('/chat/abort', async (req, res) => { // This handles the case where frontend sends "new" but job was created with a UUID if (!job && userId) { logger.debug(`[AgentStream] Job not found by ID, checking active jobs for user: ${userId}`); - const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(userId); + const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser( + userId, + req.user.tenantId, + ); if (activeJobIds.length > 0) { // Abort the most recent active job for this user jobStreamId = activeJobIds[0]; @@ -230,6 +249,10 @@ router.post('/chat/abort', async (req, res) => { return res.status(403).json({ error: 'Unauthorized' }); } + if (hasTenantMismatch(job, req.user)) { + return res.status(403).json({ error: 'Unauthorized' }); + } + logger.debug(`[AgentStream] Job found, aborting: ${jobStreamId}`); const abortResult = await GenerationJobManager.abortJob(jobStreamId); logger.debug(`[AgentStream] Job aborted successfully: ${jobStreamId}`, { diff --git a/api/server/routes/config.js b/api/server/routes/config.js index 0a68ccba4f..8caa180854 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -1,7 +1,7 @@ const express = require('express'); -const { logger } = require('@librechat/data-schemas'); const { isEnabled, getBalanceConfig } = require('@librechat/api'); const { CacheKeys, defaultSocialLogins } = require('librechat-data-provider'); +const { logger, getTenantId, scopedCacheKey } = require('@librechat/data-schemas'); const { getLdapConfig } = require('~/server/services/Config/ldap'); const { getAppConfig } = require('~/server/services/Config/app'); const { getLogStores } = require('~/cache'); @@ -23,7 +23,8 @@ const openidReuseTokens = isEnabled(process.env.OPENID_REUSE_TOKENS); router.get('/', async function (req, res) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG); + const cacheKey = scopedCacheKey(CacheKeys.STARTUP_CONFIG); + const cachedStartupConfig = await cache.get(cacheKey); if (cachedStartupConfig) { res.send(cachedStartupConfig); return; @@ -37,7 +38,10 @@ router.get('/', async function (req, res) { const ldap = getLdapConfig(); try { - const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }); + const appConfig = await getAppConfig({ + role: req.user?.role, + tenantId: req.user?.tenantId || getTenantId(), + }); const isOpenIdEnabled = !!process.env.OPENID_CLIENT_ID && @@ -141,7 +145,7 @@ router.get('/', async function (req, res) { payload.customFooter = process.env.CUSTOM_FOOTER; } - await cache.set(CacheKeys.STARTUP_CONFIG, payload); + await cache.set(cacheKey, payload); return res.status(200).send(payload); } catch (err) { logger.error('Error in startup config', err); diff --git a/api/server/services/Config/app.js b/api/server/services/Config/app.js index 7530ca1031..3256732ec2 100644 --- a/api/server/services/Config/app.js +++ b/api/server/services/Config/app.js @@ -1,5 +1,5 @@ const { CacheKeys } = require('librechat-data-provider'); -const { AppService, logger } = require('@librechat/data-schemas'); +const { AppService, logger, scopedCacheKey } = require('@librechat/data-schemas'); const { createAppConfigService, clearMcpConfigCache } = require('@librechat/api'); const { setCachedTools, invalidateCachedTools } = require('./getCachedTools'); const { loadAndFormatTools } = require('~/server/services/start/tools'); @@ -29,11 +29,23 @@ const { getAppConfig, clearAppConfigCache, clearOverrideCache } = createAppConfi getUserPrincipals: db.getUserPrincipals, }); -/** Deletes the ENDPOINT_CONFIG entry from CONFIG_STORE. Failures are non-critical and swallowed. */ +/** + * Deletes ENDPOINT_CONFIG entries from CONFIG_STORE. + * Clears both the tenant-scoped key (if in tenant context) and the + * unscoped base key (populated by unauthenticated /api/endpoints calls). + * Other tenants' scoped keys are NOT actively cleared — they expire + * via TTL. Config mutations in one tenant do not propagate immediately + * to other tenants' endpoint config caches. + */ async function clearEndpointConfigCache() { try { const configStore = getLogStores(CacheKeys.CONFIG_STORE); - await configStore.delete(CacheKeys.ENDPOINT_CONFIG); + const scoped = scopedCacheKey(CacheKeys.ENDPOINT_CONFIG); + const keys = [scoped]; + if (scoped !== CacheKeys.ENDPOINT_CONFIG) { + keys.push(CacheKeys.ENDPOINT_CONFIG); + } + await Promise.all(keys.map((k) => configStore.delete(k))); } catch { // CONFIG_STORE or ENDPOINT_CONFIG may not exist — not critical } diff --git a/api/server/services/Config/getEndpointsConfig.js b/api/server/services/Config/getEndpointsConfig.js index 476d3d7c80..cd0230ad4a 100644 --- a/api/server/services/Config/getEndpointsConfig.js +++ b/api/server/services/Config/getEndpointsConfig.js @@ -1,3 +1,4 @@ +const { scopedCacheKey } = require('@librechat/data-schemas'); const { loadCustomEndpointsConfig } = require('@librechat/api'); const { CacheKeys, @@ -17,10 +18,11 @@ const { getAppConfig } = require('./app'); */ async function getEndpointsConfig(req) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG); + const cacheKey = scopedCacheKey(CacheKeys.ENDPOINT_CONFIG); + const cachedEndpointsConfig = await cache.get(cacheKey); if (cachedEndpointsConfig) { if (cachedEndpointsConfig.gptPlugins) { - await cache.delete(CacheKeys.ENDPOINT_CONFIG); + await cache.delete(cacheKey); } else { return cachedEndpointsConfig; } @@ -112,7 +114,7 @@ async function getEndpointsConfig(req) { const endpointsConfig = orderEndpointsConfig(mergedConfig); - await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig); + await cache.set(cacheKey, endpointsConfig); return endpointsConfig; } diff --git a/api/server/services/Config/mcp.js b/api/server/services/Config/mcp.js index cc4e98b59e..869c9e66da 100644 --- a/api/server/services/Config/mcp.js +++ b/api/server/services/Config/mcp.js @@ -1,5 +1,5 @@ -const { logger } = require('@librechat/data-schemas'); const { CacheKeys, Constants } = require('librechat-data-provider'); +const { logger, scopedCacheKey } = require('@librechat/data-schemas'); const { getCachedTools, setCachedTools } = require('./getCachedTools'); const { getLogStores } = require('~/cache'); @@ -36,7 +36,7 @@ async function updateMCPServerTools({ userId, serverName, tools }) { await setCachedTools(serverTools, { userId, serverName }); const cache = getLogStores(CacheKeys.TOOL_CACHE); - await cache.delete(CacheKeys.TOOLS); + await cache.delete(scopedCacheKey(CacheKeys.TOOLS)); logger.debug( `[MCP Cache] Updated ${tools.length} tools for server ${serverName} (user: ${userId})`, ); @@ -48,7 +48,10 @@ async function updateMCPServerTools({ userId, serverName, tools }) { } /** - * Merges app-level tools with global tools + * Merges app-level tools with global tools. + * Only the current ALS-scoped key (base key in system/startup context) is cleared. + * Tenant-scoped TOOLS:tenantId keys are NOT actively invalidated — they expire + * via TTL on the next tenant request. This matches clearEndpointConfigCache behavior. * @param {import('@librechat/api').LCAvailableTools} appTools * @returns {Promise} */ @@ -62,7 +65,7 @@ async function mergeAppTools(appTools) { const mergedTools = { ...cachedTools, ...appTools }; await setCachedTools(mergedTools); const cache = getLogStores(CacheKeys.TOOL_CACHE); - await cache.delete(CacheKeys.TOOLS); + await cache.delete(scopedCacheKey(CacheKeys.TOOLS)); logger.debug(`Merged ${count} app-level tools`); } catch (error) { logger.error('Failed to merge app-level tools:', error); diff --git a/api/server/services/Files/Audio/streamAudio.js b/api/server/services/Files/Audio/streamAudio.js index c28a96edff..7120399b5e 100644 --- a/api/server/services/Files/Audio/streamAudio.js +++ b/api/server/services/Files/Audio/streamAudio.js @@ -1,3 +1,4 @@ +const { scopedCacheKey } = require('@librechat/data-schemas'); const { Time, CacheKeys, @@ -67,6 +68,8 @@ function createChunkProcessor(user, messageId) { } const messageCache = getLogStores(CacheKeys.MESSAGES); + // Captured at creation time — must be called within an active request ALS scope + const cacheKey = scopedCacheKey(messageId); /** * @returns {Promise<{ text: string, isFinished: boolean }[] | string>} @@ -81,7 +84,7 @@ function createChunkProcessor(user, messageId) { } /** @type { string | { text: string; complete: boolean } } */ - let message = await messageCache.get(messageId); + let message = await messageCache.get(cacheKey); if (!message) { message = await getMessage({ user, messageId }); } @@ -92,7 +95,7 @@ function createChunkProcessor(user, messageId) { } else { const text = message.content?.length > 0 ? parseTextParts(message.content) : message.text; messageCache.set( - messageId, + cacheKey, { text, complete: true, diff --git a/api/server/utils/import/importers.js b/api/server/utils/import/importers.js index 39734c181c..f8b3be4dab 100644 --- a/api/server/utils/import/importers.js +++ b/api/server/utils/import/importers.js @@ -1,5 +1,5 @@ const { v4: uuidv4 } = require('uuid'); -const { logger } = require('@librechat/data-schemas'); +const { logger, scopedCacheKey } = require('@librechat/data-schemas'); const { EModelEndpoint, Constants, openAISettings, CacheKeys } = require('librechat-data-provider'); const { createImportBatchBuilder } = require('./importBatchBuilder'); const { cloneMessagesWithTimestamps } = require('./fork'); @@ -203,7 +203,7 @@ async function importLibreChatConvo( /* Endpoint configuration */ let endpoint = jsonData.endpoint ?? options.endpoint ?? EModelEndpoint.openAI; const cache = getLogStores(CacheKeys.CONFIG_STORE); - const endpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG); + const endpointsConfig = await cache.get(scopedCacheKey(CacheKeys.ENDPOINT_CONFIG)); const endpointConfig = endpointsConfig?.[endpoint]; if (!endpointConfig && endpointsConfig) { endpoint = Object.keys(endpointsConfig)[0]; diff --git a/packages/api/src/app/permissions.ts b/packages/api/src/app/permissions.ts index 5a557adfcf..92da1342ce 100644 --- a/packages/api/src/app/permissions.ts +++ b/packages/api/src/app/permissions.ts @@ -1,4 +1,4 @@ -import { logger } from '@librechat/data-schemas'; +import { logger, tenantStorage, SYSTEM_TENANT_ID } from '@librechat/data-schemas'; import { SystemRoles, Permissions, @@ -54,6 +54,7 @@ export async function updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions, + tenantId, }: { appConfig: AppConfig; getRoleByName: (roleName: string, fieldsToSelect?: string | string[]) => Promise; @@ -63,7 +64,19 @@ export async function updateInterfacePermissions({ roleData?: IRole | null, ) => Promise; -}) { + /** + * Optional tenant ID for scoping role updates to a specific tenant. + * When provided (and not SYSTEM_TENANT_ID), runs inside `tenantStorage.run({ tenantId })`. + * When omitted or SYSTEM_TENANT_ID, uses the caller's existing ALS context. + */ + tenantId?: string; +}): Promise { + if (tenantId && tenantId !== SYSTEM_TENANT_ID) { + return tenantStorage.run({ tenantId }, async () => + updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions }), + ); + } + const loadedInterface = appConfig?.interfaceConfig; if (!loadedInterface) { return; diff --git a/packages/api/src/flow/manager.tenant.spec.ts b/packages/api/src/flow/manager.tenant.spec.ts new file mode 100644 index 0000000000..14b780c34b --- /dev/null +++ b/packages/api/src/flow/manager.tenant.spec.ts @@ -0,0 +1,49 @@ +import { Keyv } from 'keyv'; +import { logger, tenantStorage } from '@librechat/data-schemas'; +import { FlowStateManager } from './manager'; + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, +})); + +describe('FlowStateManager flow keys are not tenant-scoped', () => { + let manager: FlowStateManager; + + beforeEach(() => { + jest.clearAllMocks(); + const store = new Keyv({ store: new Map() }); + manager = new FlowStateManager(store, { ci: true, ttl: 60_000 }); + }); + + it('completeFlow finds a flow regardless of tenant context (OAuth callback compatibility)', async () => { + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await manager.initFlow('flow-1', 'oauth', {}); + }); + + const found = await manager.completeFlow('flow-1', 'oauth', { token: 'abc' }); + expect(found).toBe(true); + }); + + it('completeFlow works when both creation and completion have the same tenant', async () => { + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await manager.initFlow('flow-2', 'oauth', {}); + const found = await manager.completeFlow('flow-2', 'oauth', { token: 'abc' }); + expect(found).toBe(true); + }); + }); + + it('completeFlow returns false and logs when flow does not exist', async () => { + const found = await manager.completeFlow('ghost-flow', 'oauth', { token: 'x' }); + expect(found).toBe(false); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('ghost-flow'), + expect.objectContaining({ flowId: 'ghost-flow', type: 'oauth' }), + ); + }); +}); diff --git a/packages/api/src/flow/manager.ts b/packages/api/src/flow/manager.ts index b68b9edb7a..544cba9560 100644 --- a/packages/api/src/flow/manager.ts +++ b/packages/api/src/flow/manager.ts @@ -53,6 +53,12 @@ export class FlowStateManager { process.on('SIGHUP', cleanup); } + /** + * Flow keys are intentionally NOT tenant-scoped. OAuth callbacks arrive + * without tenant ALS context (the provider redirect doesn't carry + * X-Tenant-Id). Flow IDs are random UUIDs with no collision risk, and + * flow data is ephemeral (TTL-bounded, no sensitive user content). + */ private getFlowKey(flowId: string, type: string): string { return `${type}:${flowId}`; } @@ -253,7 +259,9 @@ export class FlowStateManager { if (!flowState) { logger.warn( - '[FlowStateManager] Flow state not found during completion — cannot recover metadata, skipping', + `[FlowStateManager] completeFlow: flow not found — key=${flowKey}. ` + + 'Possible causes: flow TTL expired before callback arrived, flow was never created, or ' + + 'the callback is routing to a different instance without shared Keyv storage.', { flowId, type }, ); return false; diff --git a/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts index cdba06cf8d..c0a861817c 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts @@ -34,6 +34,8 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), debug: jest.fn(), }, + getTenantId: jest.fn(), + SYSTEM_TENANT_ID: '__SYSTEM__', encryptV2: jest.fn(async (val: string) => `enc:${val}`), decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts index f73a5ed3e8..7e26165cad 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts @@ -20,6 +20,8 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), debug: jest.fn(), }, + getTenantId: jest.fn(), + SYSTEM_TENANT_ID: '__SYSTEM__', encryptV2: jest.fn(async (val: string) => `enc:${val}`), decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts index cb6187ab45..d5fb1d67f7 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts @@ -23,6 +23,8 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), debug: jest.fn(), }, + getTenantId: jest.fn(), + SYSTEM_TENANT_ID: '__SYSTEM__', encryptV2: jest.fn(async (val: string) => `enc:${val}`), decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); @@ -258,7 +260,7 @@ describe('MCP OAuth Race Condition Fixes', () => { expect(stateAfterComplete).toBeUndefined(); expect(mockLogger.warn).toHaveBeenCalledWith( - expect.stringContaining('cannot recover metadata'), + expect.stringContaining('flow not found'), expect.any(Object), ); }); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts index 986ac4c8b4..b5cbc869a8 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts @@ -26,6 +26,8 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), debug: jest.fn(), }, + getTenantId: jest.fn(), + SYSTEM_TENANT_ID: '__SYSTEM__', encryptV2: jest.fn(async (val: string) => `enc:${val}`), decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); diff --git a/packages/api/src/middleware/index.ts b/packages/api/src/middleware/index.ts index 7d9dee2f8a..b91fee2999 100644 --- a/packages/api/src/middleware/index.ts +++ b/packages/api/src/middleware/index.ts @@ -6,5 +6,6 @@ export * from './balance'; export * from './json'; export * from './capabilities'; export { tenantContextMiddleware } from './tenant'; +export { preAuthTenantMiddleware } from './preAuthTenant'; export * from './concurrency'; export * from './checkBalance'; diff --git a/packages/api/src/middleware/preAuthTenant.spec.ts b/packages/api/src/middleware/preAuthTenant.spec.ts new file mode 100644 index 0000000000..ed35da2324 --- /dev/null +++ b/packages/api/src/middleware/preAuthTenant.spec.ts @@ -0,0 +1,129 @@ +import { getTenantId, logger } from '@librechat/data-schemas'; +import { preAuthTenantMiddleware } from './preAuthTenant'; +import type { Request, Response, NextFunction } from 'express'; + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { + warn: jest.fn(), + error: jest.fn(), + info: jest.fn(), + debug: jest.fn(), + }, +})); + +describe('preAuthTenantMiddleware', () => { + let req: Partial; + let res: Partial; + + beforeEach(() => { + jest.clearAllMocks(); + req = { headers: {} }; + res = {}; + }); + + it('calls next() without ALS context when no X-Tenant-Id header is present', () => { + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + }); + + it('calls next() without ALS context when X-Tenant-Id header is empty', () => { + req.headers = { 'x-tenant-id': '' }; + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + }); + + it('wraps downstream in ALS context when X-Tenant-Id header is present', () => { + req.headers = { 'x-tenant-id': 'acme-corp' }; + let capturedTenantId: string | undefined; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBe('acme-corp'); + }); + + it('ignores __SYSTEM__ sentinel and logs warning', () => { + req.headers = { 'x-tenant-id': '__SYSTEM__' }; + req.ip = '10.0.0.1'; + req.path = '/api/config'; + let capturedTenantId: string | undefined = 'should-be-overwritten'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('__SYSTEM__'), + expect.objectContaining({ ip: '10.0.0.1', path: '/api/config' }), + ); + }); + + it('ignores array-valued headers (Express can produce these)', () => { + req.headers = { 'x-tenant-id': ['a', 'b'] as unknown as string }; + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + }); + + it('ignores tenant IDs containing invalid characters and logs warning', () => { + req.headers = { 'x-tenant-id': 'tenant:injected' }; + req.ip = '192.168.1.1'; + req.path = '/api/auth/login'; + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('malformed'), + expect.objectContaining({ ip: '192.168.1.1', path: '/api/auth/login' }), + ); + }); + + it('trims whitespace from tenant ID header', () => { + req.headers = { 'x-tenant-id': ' acme-corp ' }; + let capturedTenantId: string | undefined; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBe('acme-corp'); + }); + + it('ignores tenant IDs exceeding max length and logs warning', () => { + req.headers = { 'x-tenant-id': 'a'.repeat(200) }; + req.ip = '192.168.1.1'; + req.path = '/api/share/abc'; + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('malformed'), + expect.objectContaining({ ip: '192.168.1.1', length: 200, path: '/api/share/abc' }), + ); + }); +}); diff --git a/packages/api/src/middleware/preAuthTenant.ts b/packages/api/src/middleware/preAuthTenant.ts new file mode 100644 index 0000000000..bab91f3a18 --- /dev/null +++ b/packages/api/src/middleware/preAuthTenant.ts @@ -0,0 +1,72 @@ +import { tenantStorage, logger, SYSTEM_TENANT_ID } from '@librechat/data-schemas'; +import type { Request, Response, NextFunction } from 'express'; + +/** + * Pre-authentication tenant context middleware for unauthenticated routes. + * + * Reads the tenant identifier from the `X-Tenant-Id` request header and wraps + * downstream handlers in `tenantStorage.run()` so that Mongoose queries and + * config resolution run within the correct tenant scope. + * + * **Where to use**: Mount on routes that must be tenant-aware before + * authentication has occurred: + * - `GET /api/config` — login page needs tenant-specific config (social logins, registration) + * - `/api/auth/*` — login, register, password reset + * - `/oauth/*` — OAuth callback flows + * - `GET /api/share/:shareId` — public shared conversation links + * + * **How the header gets set**: The deployment's reverse proxy, auth gateway, + * or OpenID strategy sets `X-Tenant-Id` based on subdomain, path, or OIDC claim. + * This middleware does NOT resolve tenants from subdomains or tokens — that is + * the responsibility of the deployment layer. + * + * **Design**: Intentionally minimal. No subdomain parsing, no OIDC claim + * extraction, no YAML-driven strategy. Multi-tenant deployments can: + * 1. Set the header in the reverse proxy / ingress (simplest), + * 2. Replace this middleware's resolver logic entirely, or + * 3. Layer additional resolution on top (e.g., OpenID `tenant` claim → header). + * + * If no header is present, downstream runs without tenant ALS context (same as + * single-tenant mode). This preserves backward compatibility. + */ +const MAX_TENANT_ID_LENGTH = 128; +const VALID_TENANT_ID = /^[-a-zA-Z0-9_.]+$/; + +export function preAuthTenantMiddleware(req: Request, res: Response, next: NextFunction): void { + const raw = req.headers['x-tenant-id']; + + if (!raw || typeof raw !== 'string') { + next(); + return; + } + + const tenantId = raw.trim(); + + if (!tenantId) { + next(); + return; + } + + if (tenantId === SYSTEM_TENANT_ID) { + logger.warn('[preAuthTenant] Rejected __SYSTEM__ sentinel in X-Tenant-Id header', { + ip: req.ip, + path: req.path, + }); + next(); + return; + } + + if (tenantId.length > MAX_TENANT_ID_LENGTH || !VALID_TENANT_ID.test(tenantId)) { + logger.warn('[preAuthTenant] Rejected malformed X-Tenant-Id header', { + ip: req.ip, + length: tenantId.length, + path: req.path, + }); + next(); + return; + } + + return void tenantStorage.run({ tenantId }, async () => { + next(); + }); +} diff --git a/packages/api/src/stream/GenerationJobManager.ts b/packages/api/src/stream/GenerationJobManager.ts index 3e04ab734b..5993c911ff 100644 --- a/packages/api/src/stream/GenerationJobManager.ts +++ b/packages/api/src/stream/GenerationJobManager.ts @@ -1,4 +1,4 @@ -import { logger } from '@librechat/data-schemas'; +import { logger, getTenantId, SYSTEM_TENANT_ID } from '@librechat/data-schemas'; import type { StandardGraph } from '@librechat/agents'; import { parseTextParts } from 'librechat-data-provider'; import type { Agents, TMessageContentParts } from 'librechat-data-provider'; @@ -197,7 +197,9 @@ class GenerationJobManagerClass { userId: string, conversationId?: string, ): Promise { - const jobData = await this.jobStore.createJob(streamId, userId, conversationId); + const tenantId = getTenantId(); + const safeTenantId = tenantId && tenantId !== SYSTEM_TENANT_ID ? tenantId : undefined; + const jobData = await this.jobStore.createJob(streamId, userId, conversationId, safeTenantId); /** * Create runtime state with readyPromise. @@ -355,6 +357,7 @@ class GenerationJobManagerClass { error: jobData.error, metadata: { userId: jobData.userId, + tenantId: jobData.tenantId, conversationId: jobData.conversationId, userMessage: jobData.userMessage, responseMessageId: jobData.responseMessageId, @@ -1255,8 +1258,8 @@ class GenerationJobManagerClass { * @param userId - The user ID to query * @returns Array of conversation IDs with active jobs */ - async getActiveJobIdsForUser(userId: string): Promise { - return this.jobStore.getActiveJobIdsByUser(userId); + async getActiveJobIdsForUser(userId: string, tenantId?: string): Promise { + return this.jobStore.getActiveJobIdsByUser(userId, tenantId); } /** diff --git a/packages/api/src/stream/implementations/InMemoryJobStore.ts b/packages/api/src/stream/implementations/InMemoryJobStore.ts index cc82a69963..7280c3ce80 100644 --- a/packages/api/src/stream/implementations/InMemoryJobStore.ts +++ b/packages/api/src/stream/implementations/InMemoryJobStore.ts @@ -70,6 +70,7 @@ export class InMemoryJobStore implements IJobStore { streamId: string, userId: string, conversationId?: string, + tenantId?: string, ): Promise { if (this.jobs.size >= this.maxJobs) { await this.evictOldest(); @@ -78,6 +79,7 @@ export class InMemoryJobStore implements IJobStore { const job: SerializableJobData = { streamId, userId, + ...(tenantId && { tenantId }), status: 'running', createdAt: Date.now(), conversationId, @@ -86,11 +88,12 @@ export class InMemoryJobStore implements IJobStore { this.jobs.set(streamId, job); - // Track job by userId for efficient user-scoped queries - let userJobs = this.userJobMap.get(userId); + // Track job by userId (tenant-qualified when available) for efficient user-scoped queries + const userKey = tenantId ? `${tenantId}:${userId}` : userId; + let userJobs = this.userJobMap.get(userKey); if (!userJobs) { userJobs = new Set(); - this.userJobMap.set(userId, userJobs); + this.userJobMap.set(userKey, userJobs); } userJobs.add(streamId); @@ -146,6 +149,17 @@ export class InMemoryJobStore implements IJobStore { } for (const id of toDelete) { + const job = this.jobs.get(id); + if (job) { + const userKey = job.tenantId ? `${job.tenantId}:${job.userId}` : job.userId; + const userJobs = this.userJobMap.get(userKey); + if (userJobs) { + userJobs.delete(id); + if (userJobs.size === 0) { + this.userJobMap.delete(userKey); + } + } + } await this.deleteJob(id); } @@ -169,6 +183,17 @@ export class InMemoryJobStore implements IJobStore { if (oldestId) { logger.warn(`[InMemoryJobStore] Evicting oldest job: ${oldestId}`); + const job = this.jobs.get(oldestId); + if (job) { + const userKey = job.tenantId ? `${job.tenantId}:${job.userId}` : job.userId; + const userJobs = this.userJobMap.get(userKey); + if (userJobs) { + userJobs.delete(oldestId); + if (userJobs.size === 0) { + this.userJobMap.delete(userKey); + } + } + } await this.deleteJob(oldestId); } } @@ -205,8 +230,9 @@ export class InMemoryJobStore implements IJobStore { * Returns conversation IDs of running jobs belonging to the user. * Also performs self-healing cleanup: removes stale entries for jobs that no longer exist. */ - async getActiveJobIdsByUser(userId: string): Promise { - const trackedIds = this.userJobMap.get(userId); + async getActiveJobIdsByUser(userId: string, tenantId?: string): Promise { + const userKey = tenantId ? `${tenantId}:${userId}` : userId; + const trackedIds = this.userJobMap.get(userKey); if (!trackedIds || trackedIds.size === 0) { return []; } @@ -226,7 +252,7 @@ export class InMemoryJobStore implements IJobStore { // Clean up empty set if (trackedIds.size === 0) { - this.userJobMap.delete(userId); + this.userJobMap.delete(userKey); } return activeIds; diff --git a/packages/api/src/stream/implementations/RedisJobStore.ts b/packages/api/src/stream/implementations/RedisJobStore.ts index 727fe066eb..a631bc2044 100644 --- a/packages/api/src/stream/implementations/RedisJobStore.ts +++ b/packages/api/src/stream/implementations/RedisJobStore.ts @@ -29,8 +29,9 @@ const KEYS = { runSteps: (streamId: string) => `stream:{${streamId}}:runsteps`, /** Running jobs set for cleanup (global set - single slot) */ runningJobs: 'stream:running', - /** User's active jobs set: stream:user:{userId}:jobs */ - userJobs: (userId: string) => `stream:user:{${userId}}:jobs`, + /** User's active jobs set, tenant-qualified when tenantId is available */ + userJobs: (userId: string, tenantId?: string) => + tenantId ? `stream:user:{${tenantId}:${userId}}:jobs` : `stream:user:{${userId}}:jobs`, }; /** @@ -140,10 +141,12 @@ export class RedisJobStore implements IJobStore { streamId: string, userId: string, conversationId?: string, + tenantId?: string, ): Promise { const job: SerializableJobData = { streamId, userId, + ...(tenantId && { tenantId }), status: 'running', createdAt: Date.now(), conversationId, @@ -151,7 +154,7 @@ export class RedisJobStore implements IJobStore { }; const key = KEYS.job(streamId); - const userJobsKey = KEYS.userJobs(userId); + const userJobsKey = KEYS.userJobs(userId, tenantId); // For cluster mode, we can't pipeline keys on different slots // The job key uses hash tag {streamId}, runningJobs and userJobs are on different slots @@ -377,8 +380,8 @@ export class RedisJobStore implements IJobStore { * @param userId - The user ID to query * @returns Array of conversation IDs with active jobs */ - async getActiveJobIdsByUser(userId: string): Promise { - const userJobsKey = KEYS.userJobs(userId); + async getActiveJobIdsByUser(userId: string, tenantId?: string): Promise { + const userJobsKey = KEYS.userJobs(userId, tenantId); const trackedIds = await this.redis.smembers(userJobsKey); if (trackedIds.length === 0) { @@ -868,6 +871,7 @@ export class RedisJobStore implements IJobStore { return { streamId: data.streamId, userId: data.userId, + tenantId: data.tenantId || undefined, status: data.status as JobStatus, createdAt: parseInt(data.createdAt, 10), completedAt: data.completedAt ? parseInt(data.completedAt, 10) : undefined, diff --git a/packages/api/src/stream/interfaces/IJobStore.ts b/packages/api/src/stream/interfaces/IJobStore.ts index fadddb840d..b59eed66f8 100644 --- a/packages/api/src/stream/interfaces/IJobStore.ts +++ b/packages/api/src/stream/interfaces/IJobStore.ts @@ -12,6 +12,7 @@ export type JobStatus = 'running' | 'complete' | 'error' | 'aborted'; export interface SerializableJobData { streamId: string; userId: string; + tenantId?: string; status: JobStatus; createdAt: number; completedAt?: number; @@ -149,6 +150,7 @@ export interface IJobStore { streamId: string, userId: string, conversationId?: string, + tenantId?: string, ): Promise; /** Get a job by streamId (streamId === conversationId) */ @@ -186,7 +188,7 @@ export interface IJobStore { * @param userId - The user ID to query * @returns Array of conversation IDs with active jobs */ - getActiveJobIdsByUser(userId: string): Promise; + getActiveJobIdsByUser(userId: string, tenantId?: string): Promise; // ===== Content State Methods ===== // These methods manage volatile content state tied to each job. diff --git a/packages/api/src/types/stream.ts b/packages/api/src/types/stream.ts index 068d9c8db8..dd125a1aab 100644 --- a/packages/api/src/types/stream.ts +++ b/packages/api/src/types/stream.ts @@ -4,6 +4,7 @@ import type { ServerSentEvent } from '~/types'; export interface GenerationJobMetadata { userId: string; + tenantId?: string; conversationId?: string; /** User message data for rebuilding submission on reconnect */ userMessage?: Agents.UserMessageMeta; diff --git a/packages/data-schemas/src/config/tenantContext.spec.ts b/packages/data-schemas/src/config/tenantContext.spec.ts new file mode 100644 index 0000000000..7e6cc0748d --- /dev/null +++ b/packages/data-schemas/src/config/tenantContext.spec.ts @@ -0,0 +1,26 @@ +import { tenantStorage, runAsSystem, scopedCacheKey } from './tenantContext'; + +describe('scopedCacheKey', () => { + it('returns base key when no ALS context is set', () => { + expect(scopedCacheKey('MODELS_CONFIG')).toBe('MODELS_CONFIG'); + }); + + it('returns base key in SYSTEM_TENANT_ID context', async () => { + await runAsSystem(async () => { + expect(scopedCacheKey('MODELS_CONFIG')).toBe('MODELS_CONFIG'); + }); + }); + + it('appends tenantId when tenant context is active', async () => { + await tenantStorage.run({ tenantId: 'acme' }, async () => { + expect(scopedCacheKey('MODELS_CONFIG')).toBe('MODELS_CONFIG:acme'); + }); + }); + + it('does not leak tenant context outside ALS scope', async () => { + await tenantStorage.run({ tenantId: 'acme' }, async () => { + expect(scopedCacheKey('KEY')).toBe('KEY:acme'); + }); + expect(scopedCacheKey('KEY')).toBe('KEY'); + }); +}); diff --git a/packages/data-schemas/src/config/tenantContext.ts b/packages/data-schemas/src/config/tenantContext.ts index e5e4376a90..eb77edb27d 100644 --- a/packages/data-schemas/src/config/tenantContext.ts +++ b/packages/data-schemas/src/config/tenantContext.ts @@ -26,3 +26,16 @@ export function getTenantId(): string | undefined { export function runAsSystem(fn: () => Promise): Promise { return tenantStorage.run({ tenantId: SYSTEM_TENANT_ID }, fn); } + +/** + * Appends `:${tenantId}` to a cache key when a non-system tenant context is active. + * Returns the base key unchanged when no ALS context is set or when running + * inside `runAsSystem()` (SYSTEM_TENANT_ID context). + */ +export function scopedCacheKey(baseKey: string): string { + const tenantId = getTenantId(); + if (!tenantId || tenantId === SYSTEM_TENANT_ID) { + return baseKey; + } + return `${baseKey}:${tenantId}`; +} diff --git a/packages/data-schemas/src/index.ts b/packages/data-schemas/src/index.ts index d673db1f5c..1139f83f17 100644 --- a/packages/data-schemas/src/index.ts +++ b/packages/data-schemas/src/index.ts @@ -19,6 +19,12 @@ export type * from './types'; export type * from './methods'; export { default as logger } from './config/winston'; export { default as meiliLogger } from './config/meiliLogger'; -export { tenantStorage, getTenantId, runAsSystem, SYSTEM_TENANT_ID } from './config/tenantContext'; +export { + tenantStorage, + getTenantId, + runAsSystem, + scopedCacheKey, + SYSTEM_TENANT_ID, +} from './config/tenantContext'; export type { TenantContext } from './config/tenantContext'; export { dropSupersededTenantIndexes, dropSupersededPromptGroupIndexes } from './migrations'; diff --git a/packages/data-schemas/src/methods/aclEntry.ts b/packages/data-schemas/src/methods/aclEntry.ts index 82e277254a..2f61861029 100644 --- a/packages/data-schemas/src/methods/aclEntry.ts +++ b/packages/data-schemas/src/methods/aclEntry.ts @@ -8,6 +8,7 @@ import type { Model, } from 'mongoose'; import type { IAclEntry } from '~/types'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; export function createAclEntryMethods(mongoose: typeof import('mongoose')) { /** @@ -378,7 +379,7 @@ export function createAclEntryMethods(mongoose: typeof import('mongoose')) { options?: { session?: ClientSession }, ) { const AclEntry = mongoose.models.AclEntry as Model; - return AclEntry.bulkWrite(ops, options || {}); + return tenantSafeBulkWrite(AclEntry, ops, options || {}); } /** @@ -448,7 +449,9 @@ export function createAclEntryMethods(mongoose: typeof import('mongoose')) { { $group: { _id: '$resourceId' } }, ]); - const multiOwnerIds = new Set(otherOwners.map((doc: { _id: Types.ObjectId }) => doc._id.toString())); + const multiOwnerIds = new Set( + otherOwners.map((doc: { _id: Types.ObjectId }) => doc._id.toString()), + ); return ownedIds.filter((id) => !multiOwnerIds.has(id.toString())); } diff --git a/packages/data-schemas/src/methods/agentCategory.ts b/packages/data-schemas/src/methods/agentCategory.ts index 2dd4678075..baf33207aa 100644 --- a/packages/data-schemas/src/methods/agentCategory.ts +++ b/packages/data-schemas/src/methods/agentCategory.ts @@ -1,5 +1,6 @@ import type { Model, Types } from 'mongoose'; import type { IAgentCategory } from '~/types'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; export function createAgentCategoryMethods(mongoose: typeof import('mongoose')) { /** @@ -74,7 +75,7 @@ export function createAgentCategoryMethods(mongoose: typeof import('mongoose')) }, })); - return await AgentCategory.bulkWrite(operations); + return await tenantSafeBulkWrite(AgentCategory, operations); } /** @@ -241,7 +242,7 @@ export function createAgentCategoryMethods(mongoose: typeof import('mongoose')) }, })); - await AgentCategory.bulkWrite(bulkOps, { ordered: false }); + await tenantSafeBulkWrite(AgentCategory, bulkOps, { ordered: false }); } return updates.length > 0 || created > 0; diff --git a/packages/data-schemas/src/methods/conversation.ts b/packages/data-schemas/src/methods/conversation.ts index 7a62afef9e..abfe16bf2d 100644 --- a/packages/data-schemas/src/methods/conversation.ts +++ b/packages/data-schemas/src/methods/conversation.ts @@ -1,6 +1,7 @@ import type { FilterQuery, Model, SortOrder } from 'mongoose'; -import logger from '~/config/winston'; import { createTempChatExpirationDate } from '~/utils/tempChatRetention'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; +import logger from '~/config/winston'; import type { AppConfig, IConversation } from '~/types'; import type { MessageMethods } from './message'; import type { DeleteResult } from 'mongoose'; @@ -228,7 +229,7 @@ export function createConversationMethods( }, })); - const result = await Conversation.bulkWrite(bulkOps); + const result = await tenantSafeBulkWrite(Conversation, bulkOps); return result; } catch (error) { logger.error('[bulkSaveConvos] Error saving conversations in bulk', error); diff --git a/packages/data-schemas/src/methods/conversationTag.ts b/packages/data-schemas/src/methods/conversationTag.ts index af1e43babb..085948bab5 100644 --- a/packages/data-schemas/src/methods/conversationTag.ts +++ b/packages/data-schemas/src/methods/conversationTag.ts @@ -1,4 +1,5 @@ import type { Model } from 'mongoose'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; import logger from '~/config/winston'; interface IConversationTag { @@ -233,7 +234,7 @@ export function createConversationTagMethods(mongoose: typeof import('mongoose') } if (bulkOps.length > 0) { - await ConversationTag.bulkWrite(bulkOps); + await tenantSafeBulkWrite(ConversationTag, bulkOps); } const updatedConversation = ( @@ -273,7 +274,7 @@ export function createConversationTagMethods(mongoose: typeof import('mongoose') }, })); - const result = await ConversationTag.bulkWrite(bulkOps); + const result = await tenantSafeBulkWrite(ConversationTag, bulkOps); if (result && result.modifiedCount > 0) { logger.debug( `user: ${user} | Incremented tag counts - modified ${result.modifiedCount} tags`, diff --git a/packages/data-schemas/src/methods/file.ts b/packages/data-schemas/src/methods/file.ts index 3d7db88c3f..4c0969afb3 100644 --- a/packages/data-schemas/src/methods/file.ts +++ b/packages/data-schemas/src/methods/file.ts @@ -2,6 +2,7 @@ import logger from '../config/winston'; import { EToolResources, FileContext } from 'librechat-data-provider'; import type { FilterQuery, SortOrder, Model } from 'mongoose'; import type { IMongoFile } from '~/types/file'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; /** Factory function that takes mongoose instance and returns the file methods */ export function createFileMethods(mongoose: typeof import('mongoose')) { @@ -322,7 +323,7 @@ export function createFileMethods(mongoose: typeof import('mongoose')) { }, })); - const result = await File.bulkWrite(bulkOperations); + const result = await tenantSafeBulkWrite(File, bulkOperations); logger.info(`Updated ${result.modifiedCount} files with new S3 URLs`); } diff --git a/packages/data-schemas/src/methods/message.ts b/packages/data-schemas/src/methods/message.ts index ae5ca72b12..2e638b6bfb 100644 --- a/packages/data-schemas/src/methods/message.ts +++ b/packages/data-schemas/src/methods/message.ts @@ -1,6 +1,7 @@ import type { DeleteResult, FilterQuery, Model } from 'mongoose'; import logger from '~/config/winston'; import { createTempChatExpirationDate } from '~/utils/tempChatRetention'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; import type { AppConfig, IMessage } from '~/types'; /** Simple UUID v4 regex to replace zod validation */ @@ -165,7 +166,7 @@ export function createMessageMethods(mongoose: typeof import('mongoose')): Messa upsert: true, }, })); - const result = await Message.bulkWrite(bulkOps); + const result = await tenantSafeBulkWrite(Message, bulkOps); return result; } catch (err) { logger.error('Error saving messages in bulk:', err); diff --git a/packages/data-schemas/src/methods/prompt.ts b/packages/data-schemas/src/methods/prompt.ts index a1b6bfde37..86d830fecd 100644 --- a/packages/data-schemas/src/methods/prompt.ts +++ b/packages/data-schemas/src/methods/prompt.ts @@ -1,8 +1,9 @@ import { ResourceType, SystemCategories } from 'librechat-data-provider'; import type { Model, Types } from 'mongoose'; import type { IAclEntry, IPrompt, IPromptGroup, IPromptGroupDocument } from '~/types'; -import { escapeRegExp } from '~/utils/string'; +import { getTenantId, SYSTEM_TENANT_ID } from '~/config/tenantContext'; import { isValidObjectIdString } from '~/utils/objectId'; +import { escapeRegExp } from '~/utils/string'; import logger from '~/config/winston'; export interface PromptDeps { @@ -508,16 +509,37 @@ export function createPromptMethods(mongoose: typeof import('mongoose'), deps: P if (typeof matchFilter._id === 'string') { matchFilter._id = new ObjectId(matchFilter._id); } + const tenantId = getTenantId(); + const useTenantFilter = tenantId && tenantId !== SYSTEM_TENANT_ID; + + const lookupStage = useTenantFilter + ? { + $lookup: { + from: 'prompts', + let: { prodId: '$productionId' }, + pipeline: [ + { + $match: { + $expr: { $eq: ['$_id', '$$prodId'] }, + tenantId, + }, + }, + ], + as: 'productionPrompt', + }, + } + : { + $lookup: { + from: 'prompts', + localField: 'productionId', + foreignField: '_id', + as: 'productionPrompt', + }, + }; + const result = await PromptGroup.aggregate([ { $match: matchFilter }, - { - $lookup: { - from: 'prompts', - localField: 'productionId', - foreignField: '_id', - as: 'productionPrompt', - }, - }, + lookupStage, { $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } }, ]); const group = result[0] || null; diff --git a/packages/data-schemas/src/migrations/tenantIndexes.spec.ts b/packages/data-schemas/src/migrations/tenantIndexes.spec.ts index 4637e7d0ad..e62b587a6e 100644 --- a/packages/data-schemas/src/migrations/tenantIndexes.spec.ts +++ b/packages/data-schemas/src/migrations/tenantIndexes.spec.ts @@ -47,7 +47,13 @@ describe('dropSupersededTenantIndexes', () => { await db.createCollection('roles'); await db.collection('roles').createIndex({ name: 1 }, { unique: true, name: 'name_1' }); + await db.createCollection('agents'); + await db.collection('agents').createIndex({ id: 1 }, { unique: true, name: 'id_1' }); + await db.createCollection('conversations'); + await db + .collection('conversations') + .createIndex({ conversationId: 1 }, { unique: true, name: 'conversationId_1' }); await db .collection('conversations') .createIndex( @@ -56,10 +62,18 @@ describe('dropSupersededTenantIndexes', () => { ); await db.createCollection('messages'); + await db + .collection('messages') + .createIndex({ messageId: 1 }, { unique: true, name: 'messageId_1' }); await db .collection('messages') .createIndex({ messageId: 1, user: 1 }, { unique: true, name: 'messageId_1_user_1' }); + await db.createCollection('presets'); + await db + .collection('presets') + .createIndex({ presetId: 1 }, { unique: true, name: 'presetId_1' }); + await db.createCollection('agentcategories'); await db .collection('agentcategories') diff --git a/packages/data-schemas/src/migrations/tenantIndexes.ts b/packages/data-schemas/src/migrations/tenantIndexes.ts index c68df4db2b..a8b4e51768 100644 --- a/packages/data-schemas/src/migrations/tenantIndexes.ts +++ b/packages/data-schemas/src/migrations/tenantIndexes.ts @@ -24,8 +24,10 @@ const SUPERSEDED_INDEXES: Record = { 'appleId_1', ], roles: ['name_1'], - conversations: ['conversationId_1_user_1'], - messages: ['messageId_1_user_1'], + agents: ['id_1'], + conversations: ['conversationId_1', 'conversationId_1_user_1'], + messages: ['messageId_1', 'messageId_1_user_1'], + presets: ['presetId_1'], agentcategories: ['value_1'], accessroles: ['accessRoleId_1'], conversationtags: ['tag_1_user_1'], diff --git a/packages/data-schemas/src/schema/agent.ts b/packages/data-schemas/src/schema/agent.ts index 42a7ca5418..70734d0ceb 100644 --- a/packages/data-schemas/src/schema/agent.ts +++ b/packages/data-schemas/src/schema/agent.ts @@ -5,8 +5,6 @@ const agentSchema = new Schema( { id: { type: String, - index: true, - unique: true, required: true, }, name: { @@ -124,6 +122,7 @@ const agentSchema = new Schema( }, ); +agentSchema.index({ id: 1, tenantId: 1 }, { unique: true }); agentSchema.index({ updatedAt: -1, _id: 1 }); agentSchema.index({ 'edges.to': 1 }); diff --git a/packages/data-schemas/src/schema/convo.ts b/packages/data-schemas/src/schema/convo.ts index 9ed8949e9c..c8f394935a 100644 --- a/packages/data-schemas/src/schema/convo.ts +++ b/packages/data-schemas/src/schema/convo.ts @@ -6,7 +6,6 @@ const convoSchema: Schema = new Schema( { conversationId: { type: String, - unique: true, required: true, index: true, meiliIndex: true, diff --git a/packages/data-schemas/src/schema/message.ts b/packages/data-schemas/src/schema/message.ts index ff3468918e..9879efae55 100644 --- a/packages/data-schemas/src/schema/message.ts +++ b/packages/data-schemas/src/schema/message.ts @@ -5,7 +5,6 @@ const messageSchema: Schema = new Schema( { messageId: { type: String, - unique: true, required: true, index: true, meiliIndex: true, diff --git a/packages/data-schemas/src/schema/preset.ts b/packages/data-schemas/src/schema/preset.ts index 33c217ea23..5af5163fd3 100644 --- a/packages/data-schemas/src/schema/preset.ts +++ b/packages/data-schemas/src/schema/preset.ts @@ -60,7 +60,6 @@ const presetSchema: Schema = new Schema( { presetId: { type: String, - unique: true, required: true, index: true, }, @@ -88,4 +87,6 @@ const presetSchema: Schema = new Schema( { timestamps: true }, ); +presetSchema.index({ presetId: 1, tenantId: 1 }, { unique: true }); + export default presetSchema; diff --git a/packages/data-schemas/src/utils/index.ts b/packages/data-schemas/src/utils/index.ts index c071f4e827..17e43ac3ca 100644 --- a/packages/data-schemas/src/utils/index.ts +++ b/packages/data-schemas/src/utils/index.ts @@ -1,5 +1,6 @@ export * from './principal'; export * from './string'; export * from './tempChatRetention'; +export { tenantSafeBulkWrite } from './tenantBulkWrite'; export * from './transactions'; export * from './objectId'; diff --git a/packages/data-schemas/src/utils/tenantBulkWrite.spec.ts b/packages/data-schemas/src/utils/tenantBulkWrite.spec.ts new file mode 100644 index 0000000000..059868b8a1 --- /dev/null +++ b/packages/data-schemas/src/utils/tenantBulkWrite.spec.ts @@ -0,0 +1,376 @@ +import mongoose, { Schema } from 'mongoose'; +import { MongoMemoryServer } from 'mongodb-memory-server'; +import { tenantStorage, runAsSystem, SYSTEM_TENANT_ID } from '~/config/tenantContext'; +import { applyTenantIsolation, _resetStrictCache } from '~/models/plugins/tenantIsolation'; +import { tenantSafeBulkWrite, _resetBulkWriteStrictCache } from './tenantBulkWrite'; + +let mongoServer: InstanceType; + +interface ITestDoc { + name: string; + value?: number; + tenantId?: string; +} + +function createTestModel(suffix: string) { + const schema = new Schema({ + name: { type: String, required: true }, + value: { type: Number, default: 0 }, + tenantId: { type: String, index: true }, + }); + applyTenantIsolation(schema); + const modelName = `TestBulkWrite_${suffix}_${Date.now()}`; + return mongoose.model(modelName, schema); +} + +beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + await mongoose.connect(mongoServer.getUri()); +}); + +afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); +}); + +afterEach(() => { + delete process.env.TENANT_ISOLATION_STRICT; + _resetStrictCache(); + _resetBulkWriteStrictCache(); +}); + +describe('tenantSafeBulkWrite', () => { + describe('with tenant context', () => { + it('injects tenantId into updateOne filters', async () => { + const Model = createTestModel('updateOne'); + + // Seed data for two tenants + await runAsSystem(async () => { + await Model.create([ + { name: 'doc1', value: 1, tenantId: 'tenant-a' }, + { name: 'doc1', value: 1, tenantId: 'tenant-b' }, + ]); + }); + + // Update only tenant-a's doc + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + updateOne: { + filter: { name: 'doc1' }, + update: { $set: { value: 99 } }, + }, + }, + ]); + }); + + // Verify tenant-a was updated, tenant-b was not + const docs = await runAsSystem(async () => Model.find({}).lean()); + const docA = docs.find((d) => d.tenantId === 'tenant-a'); + const docB = docs.find((d) => d.tenantId === 'tenant-b'); + expect(docA?.value).toBe(99); + expect(docB?.value).toBe(1); + }); + + it('injects tenantId into insertOne documents', async () => { + const Model = createTestModel('insertOne'); + + await tenantStorage.run({ tenantId: 'tenant-x' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + insertOne: { + document: { name: 'new-doc', value: 42 } as ITestDoc, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs).toHaveLength(1); + expect(docs[0].tenantId).toBe('tenant-x'); + expect(docs[0].name).toBe('new-doc'); + }); + + it('injects tenantId into deleteOne filters', async () => { + const Model = createTestModel('deleteOne'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'to-delete', tenantId: 'tenant-a' }, + { name: 'to-delete', tenantId: 'tenant-b' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + deleteOne: { + filter: { name: 'to-delete' }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs).toHaveLength(1); + expect(docs[0].tenantId).toBe('tenant-b'); + }); + + it('injects tenantId into updateMany filters', async () => { + const Model = createTestModel('updateMany'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'batch', value: 0, tenantId: 'tenant-a' }, + { name: 'batch', value: 0, tenantId: 'tenant-a' }, + { name: 'batch', value: 0, tenantId: 'tenant-b' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + updateMany: { + filter: { name: 'batch' }, + update: { $set: { value: 5 } }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + const tenantADocs = docs.filter((d) => d.tenantId === 'tenant-a'); + const tenantBDocs = docs.filter((d) => d.tenantId === 'tenant-b'); + expect(tenantADocs.every((d) => d.value === 5)).toBe(true); + expect(tenantBDocs[0].value).toBe(0); + }); + }); + + describe('with SYSTEM_TENANT_ID', () => { + it('skips tenantId injection (cross-tenant operation)', async () => { + const Model = createTestModel('system'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'sys-doc', value: 0, tenantId: 'tenant-a' }, + { name: 'sys-doc', value: 0, tenantId: 'tenant-b' }, + ]); + }); + + // System context should update ALL docs regardless of tenant + await runAsSystem(async () => { + await tenantSafeBulkWrite(Model, [ + { + updateMany: { + filter: { name: 'sys-doc' }, + update: { $set: { value: 100 } }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs.every((d) => d.value === 100)).toBe(true); + }); + }); + + describe('with SYSTEM_TENANT_ID in strict mode', () => { + it('does not throw when runAsSystem is used in strict mode', async () => { + process.env.TENANT_ISOLATION_STRICT = 'true'; + _resetBulkWriteStrictCache(); + + const Model = createTestModel('systemStrict'); + + await runAsSystem(async () => { + await Model.create({ name: 'strict-sys', value: 0 }); + }); + + await expect( + runAsSystem(async () => + tenantSafeBulkWrite(Model, [ + { + updateOne: { + filter: { name: 'strict-sys' }, + update: { $set: { value: 42 } }, + }, + }, + ]), + ), + ).resolves.toBeDefined(); + }); + }); + + describe('deleteMany and replaceOne', () => { + it('injects tenantId into deleteMany filters', async () => { + const Model = createTestModel('deleteMany'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'batch-del', value: 0, tenantId: 'tenant-a' }, + { name: 'batch-del', value: 0, tenantId: 'tenant-a' }, + { name: 'batch-del', value: 0, tenantId: 'tenant-b' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [{ deleteMany: { filter: { name: 'batch-del' } } }]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs).toHaveLength(1); + expect(docs[0].tenantId).toBe('tenant-b'); + }); + + it('injects tenantId into replaceOne filter and replacement', async () => { + const Model = createTestModel('replaceOne'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'to-replace', value: 1, tenantId: 'tenant-a' }, + { name: 'to-replace', value: 1, tenantId: 'tenant-b' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + replaceOne: { + filter: { name: 'to-replace' }, + replacement: { name: 'replaced', value: 99 }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).sort({ name: 1 }).lean()); + const replaced = docs.find((d) => d.name === 'replaced'); + const untouched = docs.find((d) => d.tenantId === 'tenant-b'); + expect(replaced?.value).toBe(99); + expect(replaced?.tenantId).toBe('tenant-a'); + expect(untouched?.value).toBe(1); + }); + + it('replaceOne overwrites a conflicting tenantId in the replacement document', async () => { + const Model = createTestModel('replaceOverwrite'); + + await runAsSystem(async () => { + await Model.create({ name: 'conflict', value: 1, tenantId: 'tenant-a' }); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + replaceOne: { + filter: { name: 'conflict' }, + replacement: { name: 'conflict', value: 2, tenantId: 'tenant-evil' } as ITestDoc, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs).toHaveLength(1); + expect(docs[0].tenantId).toBe('tenant-a'); + expect(docs[0].value).toBe(2); + }); + }); + + describe('edge cases', () => { + it('handles empty ops array', async () => { + const Model = createTestModel('emptyOps'); + const result = await tenantStorage.run({ tenantId: 'tenant-x' }, async () => + tenantSafeBulkWrite(Model, []), + ); + expect(result.insertedCount).toBe(0); + expect(result.modifiedCount).toBe(0); + }); + }); + + describe('without tenant context', () => { + it('passes through in non-strict mode', async () => { + const Model = createTestModel('noCtx'); + + await runAsSystem(async () => { + await Model.create({ name: 'no-ctx', value: 0 }); + }); + + // No ALS context — non-strict should pass through + const result = await tenantSafeBulkWrite(Model, [ + { + updateOne: { + filter: { name: 'no-ctx' }, + update: { $set: { value: 10 } }, + }, + }, + ]); + + expect(result.modifiedCount).toBe(1); + }); + + it('throws in strict mode', async () => { + process.env.TENANT_ISOLATION_STRICT = 'true'; + _resetBulkWriteStrictCache(); + + const Model = createTestModel('strict'); + + await expect( + tenantSafeBulkWrite(Model, [ + { + updateOne: { + filter: { name: 'any' }, + update: { $set: { value: 1 } }, + }, + }, + ]), + ).rejects.toThrow('bulkWrite on TestBulkWrite_strict'); + }); + }); + + describe('mixed operations', () => { + it('handles a batch of mixed insert, update, delete operations', async () => { + const Model = createTestModel('mixed'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'existing1', value: 1, tenantId: 'tenant-m' }, + { name: 'to-remove', value: 2, tenantId: 'tenant-m' }, + { name: 'existing1', value: 1, tenantId: 'tenant-other' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-m' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + insertOne: { + document: { name: 'new-item', value: 10 } as ITestDoc, + }, + }, + { + updateOne: { + filter: { name: 'existing1' }, + update: { $set: { value: 50 } }, + }, + }, + { + deleteOne: { + filter: { name: 'to-remove' }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).sort({ name: 1 }).lean()); + + // tenant-other's doc should be untouched + const otherDoc = docs.find((d) => d.tenantId === 'tenant-other' && d.name === 'existing1'); + expect(otherDoc?.value).toBe(1); + + // tenant-m: existing1 updated, to-remove deleted, new-item inserted + const tenantMDocs = docs.filter((d) => d.tenantId === 'tenant-m'); + expect(tenantMDocs).toHaveLength(2); + expect(tenantMDocs.find((d) => d.name === 'existing1')?.value).toBe(50); + expect(tenantMDocs.find((d) => d.name === 'new-item')?.value).toBe(10); + expect(tenantMDocs.find((d) => d.name === 'to-remove')).toBeUndefined(); + }); + }); +}); diff --git a/packages/data-schemas/src/utils/tenantBulkWrite.ts b/packages/data-schemas/src/utils/tenantBulkWrite.ts new file mode 100644 index 0000000000..16ef5fa057 --- /dev/null +++ b/packages/data-schemas/src/utils/tenantBulkWrite.ts @@ -0,0 +1,109 @@ +import type { AnyBulkWriteOperation, Model, MongooseBulkWriteOptions } from 'mongoose'; +import type { BulkWriteResult } from 'mongodb'; +import { getTenantId, SYSTEM_TENANT_ID } from '~/config/tenantContext'; +import logger from '~/config/winston'; + +let _strictMode: boolean | undefined; + +function isStrict(): boolean { + return (_strictMode ??= process.env.TENANT_ISOLATION_STRICT === 'true'); +} + +/** Resets the cached strict-mode flag. Exposed for test teardown only. */ +export function _resetBulkWriteStrictCache(): void { + _strictMode = undefined; +} + +/** + * Tenant-safe wrapper around Mongoose `Model.bulkWrite()`. + * + * Mongoose's `bulkWrite` does not trigger schema-level middleware hooks, so the + * `applyTenantIsolation` plugin cannot intercept it. This wrapper injects the + * current ALS tenant context into every operation's filter and/or document + * before delegating to the native `bulkWrite`. + * + * Behavior: + * - **tenantId present** (normal request): injects `{ tenantId }` into every + * operation filter (updateOne, deleteOne, replaceOne) and document (insertOne). + * - **SYSTEM_TENANT_ID**: skips injection (cross-tenant system operation). + * - **No tenantId + strict mode**: throws (fail-closed, same as the plugin). + * - **No tenantId + non-strict**: passes through without injection (backward compat). + */ +export async function tenantSafeBulkWrite( + model: Model, + ops: AnyBulkWriteOperation[], + options?: MongooseBulkWriteOptions, +): Promise { + const tenantId = getTenantId(); + + if (!tenantId) { + if (isStrict()) { + throw new Error( + `[TenantIsolation] bulkWrite on ${model.modelName} attempted without tenant context in strict mode`, + ); + } + return model.bulkWrite(ops, options); + } + + if (tenantId === SYSTEM_TENANT_ID) { + return model.bulkWrite(ops, options); + } + + const injected = ops.map((op) => injectTenantId(op, tenantId)); + return model.bulkWrite(injected, options); +} + +/** + * Injects `tenantId` into a single bulk-write operation. + * Returns a new operation object — does not mutate the original. + */ +function injectTenantId(op: AnyBulkWriteOperation, tenantId: string): AnyBulkWriteOperation { + if ('insertOne' in op) { + return { + insertOne: { + document: { ...op.insertOne.document, tenantId }, + }, + }; + } + + if ('updateOne' in op) { + const { filter, ...rest } = op.updateOne; + return { updateOne: { ...rest, filter: { ...filter, tenantId } } }; + } + + if ('updateMany' in op) { + const { filter, ...rest } = op.updateMany; + return { updateMany: { ...rest, filter: { ...filter, tenantId } } }; + } + + if ('deleteOne' in op) { + const { filter, ...rest } = op.deleteOne; + return { deleteOne: { ...rest, filter: { ...filter, tenantId } } }; + } + + if ('deleteMany' in op) { + const { filter, ...rest } = op.deleteMany; + return { deleteMany: { ...rest, filter: { ...filter, tenantId } } }; + } + + if ('replaceOne' in op) { + const { filter, replacement, ...rest } = op.replaceOne; + return { + replaceOne: { + ...rest, + filter: { ...filter, tenantId }, + replacement: { ...replacement, tenantId }, + }, + }; + } + + if (isStrict()) { + throw new Error( + '[TenantIsolation] Unknown bulkWrite operation type in strict mode — refusing to pass through without tenant injection', + ); + } + logger.warn( + '[tenantSafeBulkWrite] Unknown bulk op type, passing through without tenant injection', + ); + return op; +}