diff --git a/api/server/controllers/ModelController.js b/api/server/controllers/ModelController.js index 805d9eef27..1741d3f6b1 100644 --- a/api/server/controllers/ModelController.js +++ b/api/server/controllers/ModelController.js @@ -1,5 +1,5 @@ -const { logger } = require('@librechat/data-schemas'); const { CacheKeys } = require('librechat-data-provider'); +const { logger, scopedCacheKey } = require('@librechat/data-schemas'); const { loadDefaultModels, loadConfigModels } = require('~/server/services/Config'); const { getLogStores } = require('~/cache'); @@ -9,7 +9,8 @@ const { getLogStores } = require('~/cache'); */ const getModelsConfig = async (req) => { const cache = getLogStores(CacheKeys.CONFIG_STORE); - let modelsConfig = await cache.get(CacheKeys.MODELS_CONFIG); + const cacheKey = scopedCacheKey(CacheKeys.MODELS_CONFIG); + let modelsConfig = await cache.get(cacheKey); if (!modelsConfig) { modelsConfig = await loadModels(req); } @@ -24,7 +25,8 @@ const getModelsConfig = async (req) => { */ async function loadModels(req) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const cachedModelsConfig = await cache.get(CacheKeys.MODELS_CONFIG); + const cacheKey = scopedCacheKey(CacheKeys.MODELS_CONFIG); + const cachedModelsConfig = await cache.get(cacheKey); if (cachedModelsConfig) { return cachedModelsConfig; } @@ -33,7 +35,7 @@ async function loadModels(req) { const modelConfig = { ...defaultModelsConfig, ...customModelsConfig }; - await cache.set(CacheKeys.MODELS_CONFIG, modelConfig); + await cache.set(cacheKey, modelConfig); return modelConfig; } diff --git a/api/server/controllers/PluginController.js b/api/server/controllers/PluginController.js index 14dd284c30..7c47fe4d57 100644 --- a/api/server/controllers/PluginController.js +++ b/api/server/controllers/PluginController.js @@ -1,5 +1,5 @@ -const { logger } = require('@librechat/data-schemas'); const { CacheKeys } = require('librechat-data-provider'); +const { logger, scopedCacheKey } = require('@librechat/data-schemas'); const { getToolkitKey, checkPluginAuth, filterUniquePlugins } = require('@librechat/api'); const { getCachedTools, setCachedTools } = require('~/server/services/Config'); const { availableTools, toolkits } = require('~/app/clients/tools'); @@ -9,7 +9,8 @@ const { getLogStores } = require('~/cache'); const getAvailablePluginsController = async (req, res) => { try { const cache = getLogStores(CacheKeys.TOOL_CACHE); - const cachedPlugins = await cache.get(CacheKeys.PLUGINS); + const pluginsCacheKey = scopedCacheKey(CacheKeys.PLUGINS); + const cachedPlugins = await cache.get(pluginsCacheKey); if (cachedPlugins) { res.status(200).json(cachedPlugins); return; @@ -37,7 +38,7 @@ const getAvailablePluginsController = async (req, res) => { plugins = plugins.filter((plugin) => !filteredTools.includes(plugin.pluginKey)); } - await cache.set(CacheKeys.PLUGINS, plugins); + await cache.set(pluginsCacheKey, plugins); res.status(200).json(plugins); } catch (error) { res.status(500).json({ message: error.message }); @@ -64,7 +65,8 @@ const getAvailableTools = async (req, res) => { return res.status(401).json({ message: 'Unauthorized' }); } const cache = getLogStores(CacheKeys.TOOL_CACHE); - const cachedToolsArray = await cache.get(CacheKeys.TOOLS); + const toolsCacheKey = scopedCacheKey(CacheKeys.TOOLS); + const cachedToolsArray = await cache.get(toolsCacheKey); const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId })); @@ -115,7 +117,7 @@ const getAvailableTools = async (req, res) => { } const finalTools = filterUniquePlugins(toolsOutput); - await cache.set(CacheKeys.TOOLS, finalTools); + await cache.set(toolsCacheKey, finalTools); res.status(200).json(finalTools); } catch (error) { diff --git a/api/server/controllers/PluginController.spec.js b/api/server/controllers/PluginController.spec.js index 06a51a3bd6..fdbc2401ce 100644 --- a/api/server/controllers/PluginController.spec.js +++ b/api/server/controllers/PluginController.spec.js @@ -8,6 +8,7 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), warn: jest.fn(), }, + scopedCacheKey: jest.fn((key) => key), })); jest.mock('~/server/services/Config', () => ({ diff --git a/api/server/index.js b/api/server/index.js index 813b453468..4b919b1ceb 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -21,6 +21,7 @@ const { createStreamServices, initializeFileStorage, updateInterfacePermissions, + preAuthTenantMiddleware, } = require('@librechat/api'); const { connectDb, indexSync } = require('~/db'); const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager'); @@ -59,7 +60,14 @@ const startServer = async () => { app.disable('x-powered-by'); app.set('trust proxy', trusted_proxy); - await seedDatabase(); + if (isEnabled(process.env.TENANT_ISOLATION_STRICT)) { + logger.warn( + '[Security] TENANT_ISOLATION_STRICT is active. Ensure your reverse proxy strips or sets ' + + 'the X-Tenant-Id header — untrusted clients must not be able to set it directly.', + ); + } + + await runAsSystem(seedDatabase); const appConfig = await getAppConfig({ baseOnly: true }); initializeFileStorage(appConfig); await runAsSystem(async () => { @@ -139,9 +147,11 @@ const startServer = async () => { /* Per-request capability cache — must be registered before any route that calls hasCapability */ app.use(capabilityContextMiddleware); - app.use('/oauth', routes.oauth); + /* Pre-auth tenant context for unauthenticated routes that need tenant scoping. + * The reverse proxy / auth gateway sets `X-Tenant-Id` header for multi-tenant deployments. */ + app.use('/oauth', preAuthTenantMiddleware, routes.oauth); /* API Endpoints */ - app.use('/api/auth', routes.auth); + app.use('/api/auth', preAuthTenantMiddleware, routes.auth); app.use('/api/admin', routes.adminAuth); app.use('/api/admin/config', routes.adminConfig); app.use('/api/admin/groups', routes.adminGroups); @@ -159,11 +169,11 @@ const startServer = async () => { app.use('/api/endpoints', routes.endpoints); app.use('/api/balance', routes.balance); app.use('/api/models', routes.models); - app.use('/api/config', routes.config); + app.use('/api/config', preAuthTenantMiddleware, routes.config); app.use('/api/assistants', routes.assistants); app.use('/api/files', await routes.files.initialize()); app.use('/images/', createValidateImageRequest(appConfig.secureImageLinks), routes.staticRoute); - app.use('/api/share', routes.share); + app.use('/api/share', preAuthTenantMiddleware, routes.share); app.use('/api/roles', routes.roles); app.use('/api/agents', routes.agents); app.use('/api/banner', routes.banner); diff --git a/api/server/middleware/optionalJwtAuth.js b/api/server/middleware/optionalJwtAuth.js index 2f59fdda4a..d46478d36e 100644 --- a/api/server/middleware/optionalJwtAuth.js +++ b/api/server/middleware/optionalJwtAuth.js @@ -1,9 +1,10 @@ const cookies = require('cookie'); const passport = require('passport'); -const { isEnabled } = require('@librechat/api'); +const { isEnabled, tenantContextMiddleware } = require('@librechat/api'); // This middleware does not require authentication, -// but if the user is authenticated, it will set the user object. +// but if the user is authenticated, it will set the user object +// and establish tenant ALS context. const optionalJwtAuth = (req, res, next) => { const cookieHeader = req.headers.cookie; const tokenProvider = cookieHeader ? cookies.parse(cookieHeader).token_provider : null; @@ -13,6 +14,7 @@ const optionalJwtAuth = (req, res, next) => { } if (user) { req.user = user; + return tenantContextMiddleware(req, res, next); } next(); }; diff --git a/api/server/routes/agents/__tests__/streamTenant.spec.js b/api/server/routes/agents/__tests__/streamTenant.spec.js new file mode 100644 index 0000000000..1f89953186 --- /dev/null +++ b/api/server/routes/agents/__tests__/streamTenant.spec.js @@ -0,0 +1,186 @@ +const express = require('express'); +const request = require('supertest'); + +const mockGenerationJobManager = { + getJob: jest.fn(), + subscribe: jest.fn(), + getResumeState: jest.fn(), + abortJob: jest.fn(), + getActiveJobIdsForUser: jest.fn().mockResolvedValue([]), +}; + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { + debug: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + info: jest.fn(), + }, +})); + +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + isEnabled: jest.fn().mockReturnValue(false), + GenerationJobManager: mockGenerationJobManager, +})); + +jest.mock('~/models', () => ({ + saveMessage: jest.fn(), +})); + +let mockUserId = 'user-123'; +let mockTenantId; + +jest.mock('~/server/middleware', () => ({ + uaParser: (req, res, next) => next(), + checkBan: (req, res, next) => next(), + requireJwtAuth: (req, res, next) => { + req.user = { id: mockUserId, tenantId: mockTenantId }; + next(); + }, + messageIpLimiter: (req, res, next) => next(), + configMiddleware: (req, res, next) => next(), + messageUserLimiter: (req, res, next) => next(), +})); + +jest.mock('~/server/routes/agents/chat', () => require('express').Router()); +jest.mock('~/server/routes/agents/v1', () => ({ + v1: require('express').Router(), +})); +jest.mock('~/server/routes/agents/openai', () => require('express').Router()); +jest.mock('~/server/routes/agents/responses', () => require('express').Router()); + +const agentsRouter = require('../index'); +const app = express(); +app.use(express.json()); +app.use('/agents', agentsRouter); + +function mockSubscribeSuccess() { + mockGenerationJobManager.subscribe.mockImplementation((_streamId, _writeEvent, onDone) => { + process.nextTick(() => onDone({ done: true })); + return { unsubscribe: jest.fn() }; + }); +} + +describe('SSE stream tenant isolation', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockUserId = 'user-123'; + mockTenantId = undefined; + }); + + describe('GET /chat/stream/:streamId', () => { + it('returns 403 when a user from a different tenant accesses a stream', async () => { + mockUserId = 'user-456'; + mockTenantId = 'tenant-b'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-456', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(403); + expect(res.body.error).toBe('Unauthorized'); + }); + + it('returns 404 when stream does not exist', async () => { + mockGenerationJobManager.getJob.mockResolvedValue(null); + + const res = await request(app).get('/agents/chat/stream/nonexistent'); + expect(res.status).toBe(404); + }); + + it('proceeds past tenant guard when tenant matches', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-a'; + mockSubscribeSuccess(); + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(200); + expect(mockGenerationJobManager.subscribe).toHaveBeenCalledTimes(1); + }); + + it('proceeds past tenant guard when job has no tenantId (single-tenant mode)', async () => { + mockUserId = 'user-123'; + mockTenantId = undefined; + mockSubscribeSuccess(); + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(200); + expect(mockGenerationJobManager.subscribe).toHaveBeenCalledTimes(1); + }); + + it('returns 403 when job has tenantId but user has no tenantId', async () => { + mockUserId = 'user-123'; + mockTenantId = undefined; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'some-tenant' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(403); + }); + }); + + describe('GET /chat/status/:conversationId', () => { + it('returns 403 when tenant does not match', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-b'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/status/conv-123'); + expect(res.status).toBe(403); + expect(res.body.error).toBe('Unauthorized'); + }); + + it('returns status when tenant matches', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-a'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + createdAt: Date.now(), + }); + mockGenerationJobManager.getResumeState.mockResolvedValue(null); + + const res = await request(app).get('/agents/chat/status/conv-123'); + expect(res.status).toBe(200); + expect(res.body.active).toBe(true); + }); + }); + + describe('POST /chat/abort', () => { + it('returns 403 when tenant does not match', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-b'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).post('/agents/chat/abort').send({ streamId: 'stream-123' }); + expect(res.status).toBe(403); + expect(res.body.error).toBe('Unauthorized'); + }); + }); +}); diff --git a/api/server/routes/agents/index.js b/api/server/routes/agents/index.js index 86966a3f3e..eb42046bed 100644 --- a/api/server/routes/agents/index.js +++ b/api/server/routes/agents/index.js @@ -17,6 +17,11 @@ const chat = require('./chat'); const { LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {}; +/** Untenanted jobs (pre-multi-tenancy) remain accessible if the userId check passes. */ +function hasTenantMismatch(job, user) { + return job.metadata?.tenantId != null && job.metadata.tenantId !== user.tenantId; +} + const router = express.Router(); /** @@ -67,6 +72,10 @@ router.get('/chat/stream/:streamId', async (req, res) => { return res.status(403).json({ error: 'Unauthorized' }); } + if (hasTenantMismatch(job, req.user)) { + return res.status(403).json({ error: 'Unauthorized' }); + } + res.setHeader('Content-Encoding', 'identity'); res.setHeader('Content-Type', 'text/event-stream'); res.setHeader('Cache-Control', 'no-cache, no-transform'); @@ -150,7 +159,10 @@ router.get('/chat/stream/:streamId', async (req, res) => { * @returns { activeJobIds: string[] } */ router.get('/chat/active', async (req, res) => { - const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(req.user.id); + const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser( + req.user.id, + req.user.tenantId, + ); res.json({ activeJobIds }); }); @@ -174,6 +186,10 @@ router.get('/chat/status/:conversationId', async (req, res) => { return res.status(403).json({ error: 'Unauthorized' }); } + if (hasTenantMismatch(job, req.user)) { + return res.status(403).json({ error: 'Unauthorized' }); + } + // Get resume state which contains aggregatedContent // Avoid calling both getStreamInfo and getResumeState (both fetch content) const resumeState = await GenerationJobManager.getResumeState(conversationId); @@ -213,7 +229,10 @@ router.post('/chat/abort', async (req, res) => { // This handles the case where frontend sends "new" but job was created with a UUID if (!job && userId) { logger.debug(`[AgentStream] Job not found by ID, checking active jobs for user: ${userId}`); - const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(userId); + const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser( + userId, + req.user.tenantId, + ); if (activeJobIds.length > 0) { // Abort the most recent active job for this user jobStreamId = activeJobIds[0]; @@ -230,6 +249,10 @@ router.post('/chat/abort', async (req, res) => { return res.status(403).json({ error: 'Unauthorized' }); } + if (hasTenantMismatch(job, req.user)) { + return res.status(403).json({ error: 'Unauthorized' }); + } + logger.debug(`[AgentStream] Job found, aborting: ${jobStreamId}`); const abortResult = await GenerationJobManager.abortJob(jobStreamId); logger.debug(`[AgentStream] Job aborted successfully: ${jobStreamId}`, { diff --git a/api/server/routes/config.js b/api/server/routes/config.js index 0a68ccba4f..8caa180854 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -1,7 +1,7 @@ const express = require('express'); -const { logger } = require('@librechat/data-schemas'); const { isEnabled, getBalanceConfig } = require('@librechat/api'); const { CacheKeys, defaultSocialLogins } = require('librechat-data-provider'); +const { logger, getTenantId, scopedCacheKey } = require('@librechat/data-schemas'); const { getLdapConfig } = require('~/server/services/Config/ldap'); const { getAppConfig } = require('~/server/services/Config/app'); const { getLogStores } = require('~/cache'); @@ -23,7 +23,8 @@ const openidReuseTokens = isEnabled(process.env.OPENID_REUSE_TOKENS); router.get('/', async function (req, res) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG); + const cacheKey = scopedCacheKey(CacheKeys.STARTUP_CONFIG); + const cachedStartupConfig = await cache.get(cacheKey); if (cachedStartupConfig) { res.send(cachedStartupConfig); return; @@ -37,7 +38,10 @@ router.get('/', async function (req, res) { const ldap = getLdapConfig(); try { - const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }); + const appConfig = await getAppConfig({ + role: req.user?.role, + tenantId: req.user?.tenantId || getTenantId(), + }); const isOpenIdEnabled = !!process.env.OPENID_CLIENT_ID && @@ -141,7 +145,7 @@ router.get('/', async function (req, res) { payload.customFooter = process.env.CUSTOM_FOOTER; } - await cache.set(CacheKeys.STARTUP_CONFIG, payload); + await cache.set(cacheKey, payload); return res.status(200).send(payload); } catch (err) { logger.error('Error in startup config', err); diff --git a/api/server/services/Config/app.js b/api/server/services/Config/app.js index 7530ca1031..3256732ec2 100644 --- a/api/server/services/Config/app.js +++ b/api/server/services/Config/app.js @@ -1,5 +1,5 @@ const { CacheKeys } = require('librechat-data-provider'); -const { AppService, logger } = require('@librechat/data-schemas'); +const { AppService, logger, scopedCacheKey } = require('@librechat/data-schemas'); const { createAppConfigService, clearMcpConfigCache } = require('@librechat/api'); const { setCachedTools, invalidateCachedTools } = require('./getCachedTools'); const { loadAndFormatTools } = require('~/server/services/start/tools'); @@ -29,11 +29,23 @@ const { getAppConfig, clearAppConfigCache, clearOverrideCache } = createAppConfi getUserPrincipals: db.getUserPrincipals, }); -/** Deletes the ENDPOINT_CONFIG entry from CONFIG_STORE. Failures are non-critical and swallowed. */ +/** + * Deletes ENDPOINT_CONFIG entries from CONFIG_STORE. + * Clears both the tenant-scoped key (if in tenant context) and the + * unscoped base key (populated by unauthenticated /api/endpoints calls). + * Other tenants' scoped keys are NOT actively cleared — they expire + * via TTL. Config mutations in one tenant do not propagate immediately + * to other tenants' endpoint config caches. + */ async function clearEndpointConfigCache() { try { const configStore = getLogStores(CacheKeys.CONFIG_STORE); - await configStore.delete(CacheKeys.ENDPOINT_CONFIG); + const scoped = scopedCacheKey(CacheKeys.ENDPOINT_CONFIG); + const keys = [scoped]; + if (scoped !== CacheKeys.ENDPOINT_CONFIG) { + keys.push(CacheKeys.ENDPOINT_CONFIG); + } + await Promise.all(keys.map((k) => configStore.delete(k))); } catch { // CONFIG_STORE or ENDPOINT_CONFIG may not exist — not critical } diff --git a/api/server/services/Config/getEndpointsConfig.js b/api/server/services/Config/getEndpointsConfig.js index 476d3d7c80..cd0230ad4a 100644 --- a/api/server/services/Config/getEndpointsConfig.js +++ b/api/server/services/Config/getEndpointsConfig.js @@ -1,3 +1,4 @@ +const { scopedCacheKey } = require('@librechat/data-schemas'); const { loadCustomEndpointsConfig } = require('@librechat/api'); const { CacheKeys, @@ -17,10 +18,11 @@ const { getAppConfig } = require('./app'); */ async function getEndpointsConfig(req) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG); + const cacheKey = scopedCacheKey(CacheKeys.ENDPOINT_CONFIG); + const cachedEndpointsConfig = await cache.get(cacheKey); if (cachedEndpointsConfig) { if (cachedEndpointsConfig.gptPlugins) { - await cache.delete(CacheKeys.ENDPOINT_CONFIG); + await cache.delete(cacheKey); } else { return cachedEndpointsConfig; } @@ -112,7 +114,7 @@ async function getEndpointsConfig(req) { const endpointsConfig = orderEndpointsConfig(mergedConfig); - await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig); + await cache.set(cacheKey, endpointsConfig); return endpointsConfig; } diff --git a/api/server/services/Config/mcp.js b/api/server/services/Config/mcp.js index cc4e98b59e..869c9e66da 100644 --- a/api/server/services/Config/mcp.js +++ b/api/server/services/Config/mcp.js @@ -1,5 +1,5 @@ -const { logger } = require('@librechat/data-schemas'); const { CacheKeys, Constants } = require('librechat-data-provider'); +const { logger, scopedCacheKey } = require('@librechat/data-schemas'); const { getCachedTools, setCachedTools } = require('./getCachedTools'); const { getLogStores } = require('~/cache'); @@ -36,7 +36,7 @@ async function updateMCPServerTools({ userId, serverName, tools }) { await setCachedTools(serverTools, { userId, serverName }); const cache = getLogStores(CacheKeys.TOOL_CACHE); - await cache.delete(CacheKeys.TOOLS); + await cache.delete(scopedCacheKey(CacheKeys.TOOLS)); logger.debug( `[MCP Cache] Updated ${tools.length} tools for server ${serverName} (user: ${userId})`, ); @@ -48,7 +48,10 @@ async function updateMCPServerTools({ userId, serverName, tools }) { } /** - * Merges app-level tools with global tools + * Merges app-level tools with global tools. + * Only the current ALS-scoped key (base key in system/startup context) is cleared. + * Tenant-scoped TOOLS:tenantId keys are NOT actively invalidated — they expire + * via TTL on the next tenant request. This matches clearEndpointConfigCache behavior. * @param {import('@librechat/api').LCAvailableTools} appTools * @returns {Promise} */ @@ -62,7 +65,7 @@ async function mergeAppTools(appTools) { const mergedTools = { ...cachedTools, ...appTools }; await setCachedTools(mergedTools); const cache = getLogStores(CacheKeys.TOOL_CACHE); - await cache.delete(CacheKeys.TOOLS); + await cache.delete(scopedCacheKey(CacheKeys.TOOLS)); logger.debug(`Merged ${count} app-level tools`); } catch (error) { logger.error('Failed to merge app-level tools:', error); diff --git a/api/server/services/Files/Audio/streamAudio.js b/api/server/services/Files/Audio/streamAudio.js index c28a96edff..7120399b5e 100644 --- a/api/server/services/Files/Audio/streamAudio.js +++ b/api/server/services/Files/Audio/streamAudio.js @@ -1,3 +1,4 @@ +const { scopedCacheKey } = require('@librechat/data-schemas'); const { Time, CacheKeys, @@ -67,6 +68,8 @@ function createChunkProcessor(user, messageId) { } const messageCache = getLogStores(CacheKeys.MESSAGES); + // Captured at creation time — must be called within an active request ALS scope + const cacheKey = scopedCacheKey(messageId); /** * @returns {Promise<{ text: string, isFinished: boolean }[] | string>} @@ -81,7 +84,7 @@ function createChunkProcessor(user, messageId) { } /** @type { string | { text: string; complete: boolean } } */ - let message = await messageCache.get(messageId); + let message = await messageCache.get(cacheKey); if (!message) { message = await getMessage({ user, messageId }); } @@ -92,7 +95,7 @@ function createChunkProcessor(user, messageId) { } else { const text = message.content?.length > 0 ? parseTextParts(message.content) : message.text; messageCache.set( - messageId, + cacheKey, { text, complete: true, diff --git a/api/server/utils/import/importers.js b/api/server/utils/import/importers.js index 39734c181c..f8b3be4dab 100644 --- a/api/server/utils/import/importers.js +++ b/api/server/utils/import/importers.js @@ -1,5 +1,5 @@ const { v4: uuidv4 } = require('uuid'); -const { logger } = require('@librechat/data-schemas'); +const { logger, scopedCacheKey } = require('@librechat/data-schemas'); const { EModelEndpoint, Constants, openAISettings, CacheKeys } = require('librechat-data-provider'); const { createImportBatchBuilder } = require('./importBatchBuilder'); const { cloneMessagesWithTimestamps } = require('./fork'); @@ -203,7 +203,7 @@ async function importLibreChatConvo( /* Endpoint configuration */ let endpoint = jsonData.endpoint ?? options.endpoint ?? EModelEndpoint.openAI; const cache = getLogStores(CacheKeys.CONFIG_STORE); - const endpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG); + const endpointsConfig = await cache.get(scopedCacheKey(CacheKeys.ENDPOINT_CONFIG)); const endpointConfig = endpointsConfig?.[endpoint]; if (!endpointConfig && endpointsConfig) { endpoint = Object.keys(endpointsConfig)[0]; diff --git a/packages/api/src/app/permissions.ts b/packages/api/src/app/permissions.ts index 5a557adfcf..92da1342ce 100644 --- a/packages/api/src/app/permissions.ts +++ b/packages/api/src/app/permissions.ts @@ -1,4 +1,4 @@ -import { logger } from '@librechat/data-schemas'; +import { logger, tenantStorage, SYSTEM_TENANT_ID } from '@librechat/data-schemas'; import { SystemRoles, Permissions, @@ -54,6 +54,7 @@ export async function updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions, + tenantId, }: { appConfig: AppConfig; getRoleByName: (roleName: string, fieldsToSelect?: string | string[]) => Promise; @@ -63,7 +64,19 @@ export async function updateInterfacePermissions({ roleData?: IRole | null, ) => Promise; -}) { + /** + * Optional tenant ID for scoping role updates to a specific tenant. + * When provided (and not SYSTEM_TENANT_ID), runs inside `tenantStorage.run({ tenantId })`. + * When omitted or SYSTEM_TENANT_ID, uses the caller's existing ALS context. + */ + tenantId?: string; +}): Promise { + if (tenantId && tenantId !== SYSTEM_TENANT_ID) { + return tenantStorage.run({ tenantId }, async () => + updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions }), + ); + } + const loadedInterface = appConfig?.interfaceConfig; if (!loadedInterface) { return; diff --git a/packages/api/src/flow/manager.tenant.spec.ts b/packages/api/src/flow/manager.tenant.spec.ts new file mode 100644 index 0000000000..14b780c34b --- /dev/null +++ b/packages/api/src/flow/manager.tenant.spec.ts @@ -0,0 +1,49 @@ +import { Keyv } from 'keyv'; +import { logger, tenantStorage } from '@librechat/data-schemas'; +import { FlowStateManager } from './manager'; + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, +})); + +describe('FlowStateManager flow keys are not tenant-scoped', () => { + let manager: FlowStateManager; + + beforeEach(() => { + jest.clearAllMocks(); + const store = new Keyv({ store: new Map() }); + manager = new FlowStateManager(store, { ci: true, ttl: 60_000 }); + }); + + it('completeFlow finds a flow regardless of tenant context (OAuth callback compatibility)', async () => { + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await manager.initFlow('flow-1', 'oauth', {}); + }); + + const found = await manager.completeFlow('flow-1', 'oauth', { token: 'abc' }); + expect(found).toBe(true); + }); + + it('completeFlow works when both creation and completion have the same tenant', async () => { + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await manager.initFlow('flow-2', 'oauth', {}); + const found = await manager.completeFlow('flow-2', 'oauth', { token: 'abc' }); + expect(found).toBe(true); + }); + }); + + it('completeFlow returns false and logs when flow does not exist', async () => { + const found = await manager.completeFlow('ghost-flow', 'oauth', { token: 'x' }); + expect(found).toBe(false); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('ghost-flow'), + expect.objectContaining({ flowId: 'ghost-flow', type: 'oauth' }), + ); + }); +}); diff --git a/packages/api/src/flow/manager.ts b/packages/api/src/flow/manager.ts index b68b9edb7a..544cba9560 100644 --- a/packages/api/src/flow/manager.ts +++ b/packages/api/src/flow/manager.ts @@ -53,6 +53,12 @@ export class FlowStateManager { process.on('SIGHUP', cleanup); } + /** + * Flow keys are intentionally NOT tenant-scoped. OAuth callbacks arrive + * without tenant ALS context (the provider redirect doesn't carry + * X-Tenant-Id). Flow IDs are random UUIDs with no collision risk, and + * flow data is ephemeral (TTL-bounded, no sensitive user content). + */ private getFlowKey(flowId: string, type: string): string { return `${type}:${flowId}`; } @@ -253,7 +259,9 @@ export class FlowStateManager { if (!flowState) { logger.warn( - '[FlowStateManager] Flow state not found during completion — cannot recover metadata, skipping', + `[FlowStateManager] completeFlow: flow not found — key=${flowKey}. ` + + 'Possible causes: flow TTL expired before callback arrived, flow was never created, or ' + + 'the callback is routing to a different instance without shared Keyv storage.', { flowId, type }, ); return false; diff --git a/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts index cdba06cf8d..c0a861817c 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts @@ -34,6 +34,8 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), debug: jest.fn(), }, + getTenantId: jest.fn(), + SYSTEM_TENANT_ID: '__SYSTEM__', encryptV2: jest.fn(async (val: string) => `enc:${val}`), decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts index f73a5ed3e8..7e26165cad 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts @@ -20,6 +20,8 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), debug: jest.fn(), }, + getTenantId: jest.fn(), + SYSTEM_TENANT_ID: '__SYSTEM__', encryptV2: jest.fn(async (val: string) => `enc:${val}`), decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts index cb6187ab45..d5fb1d67f7 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts @@ -23,6 +23,8 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), debug: jest.fn(), }, + getTenantId: jest.fn(), + SYSTEM_TENANT_ID: '__SYSTEM__', encryptV2: jest.fn(async (val: string) => `enc:${val}`), decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); @@ -258,7 +260,7 @@ describe('MCP OAuth Race Condition Fixes', () => { expect(stateAfterComplete).toBeUndefined(); expect(mockLogger.warn).toHaveBeenCalledWith( - expect.stringContaining('cannot recover metadata'), + expect.stringContaining('flow not found'), expect.any(Object), ); }); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts index 986ac4c8b4..b5cbc869a8 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts @@ -26,6 +26,8 @@ jest.mock('@librechat/data-schemas', () => ({ error: jest.fn(), debug: jest.fn(), }, + getTenantId: jest.fn(), + SYSTEM_TENANT_ID: '__SYSTEM__', encryptV2: jest.fn(async (val: string) => `enc:${val}`), decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); diff --git a/packages/api/src/middleware/index.ts b/packages/api/src/middleware/index.ts index 7d9dee2f8a..b91fee2999 100644 --- a/packages/api/src/middleware/index.ts +++ b/packages/api/src/middleware/index.ts @@ -6,5 +6,6 @@ export * from './balance'; export * from './json'; export * from './capabilities'; export { tenantContextMiddleware } from './tenant'; +export { preAuthTenantMiddleware } from './preAuthTenant'; export * from './concurrency'; export * from './checkBalance'; diff --git a/packages/api/src/middleware/preAuthTenant.spec.ts b/packages/api/src/middleware/preAuthTenant.spec.ts new file mode 100644 index 0000000000..ed35da2324 --- /dev/null +++ b/packages/api/src/middleware/preAuthTenant.spec.ts @@ -0,0 +1,129 @@ +import { getTenantId, logger } from '@librechat/data-schemas'; +import { preAuthTenantMiddleware } from './preAuthTenant'; +import type { Request, Response, NextFunction } from 'express'; + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { + warn: jest.fn(), + error: jest.fn(), + info: jest.fn(), + debug: jest.fn(), + }, +})); + +describe('preAuthTenantMiddleware', () => { + let req: Partial; + let res: Partial; + + beforeEach(() => { + jest.clearAllMocks(); + req = { headers: {} }; + res = {}; + }); + + it('calls next() without ALS context when no X-Tenant-Id header is present', () => { + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + }); + + it('calls next() without ALS context when X-Tenant-Id header is empty', () => { + req.headers = { 'x-tenant-id': '' }; + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + }); + + it('wraps downstream in ALS context when X-Tenant-Id header is present', () => { + req.headers = { 'x-tenant-id': 'acme-corp' }; + let capturedTenantId: string | undefined; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBe('acme-corp'); + }); + + it('ignores __SYSTEM__ sentinel and logs warning', () => { + req.headers = { 'x-tenant-id': '__SYSTEM__' }; + req.ip = '10.0.0.1'; + req.path = '/api/config'; + let capturedTenantId: string | undefined = 'should-be-overwritten'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('__SYSTEM__'), + expect.objectContaining({ ip: '10.0.0.1', path: '/api/config' }), + ); + }); + + it('ignores array-valued headers (Express can produce these)', () => { + req.headers = { 'x-tenant-id': ['a', 'b'] as unknown as string }; + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + }); + + it('ignores tenant IDs containing invalid characters and logs warning', () => { + req.headers = { 'x-tenant-id': 'tenant:injected' }; + req.ip = '192.168.1.1'; + req.path = '/api/auth/login'; + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('malformed'), + expect.objectContaining({ ip: '192.168.1.1', path: '/api/auth/login' }), + ); + }); + + it('trims whitespace from tenant ID header', () => { + req.headers = { 'x-tenant-id': ' acme-corp ' }; + let capturedTenantId: string | undefined; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBe('acme-corp'); + }); + + it('ignores tenant IDs exceeding max length and logs warning', () => { + req.headers = { 'x-tenant-id': 'a'.repeat(200) }; + req.ip = '192.168.1.1'; + req.path = '/api/share/abc'; + let capturedTenantId: string | undefined = 'sentinel'; + const capturedNext: NextFunction = () => { + capturedTenantId = getTenantId(); + }; + + preAuthTenantMiddleware(req as Request, res as Response, capturedNext); + expect(capturedTenantId).toBeUndefined(); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('malformed'), + expect.objectContaining({ ip: '192.168.1.1', length: 200, path: '/api/share/abc' }), + ); + }); +}); diff --git a/packages/api/src/middleware/preAuthTenant.ts b/packages/api/src/middleware/preAuthTenant.ts new file mode 100644 index 0000000000..bab91f3a18 --- /dev/null +++ b/packages/api/src/middleware/preAuthTenant.ts @@ -0,0 +1,72 @@ +import { tenantStorage, logger, SYSTEM_TENANT_ID } from '@librechat/data-schemas'; +import type { Request, Response, NextFunction } from 'express'; + +/** + * Pre-authentication tenant context middleware for unauthenticated routes. + * + * Reads the tenant identifier from the `X-Tenant-Id` request header and wraps + * downstream handlers in `tenantStorage.run()` so that Mongoose queries and + * config resolution run within the correct tenant scope. + * + * **Where to use**: Mount on routes that must be tenant-aware before + * authentication has occurred: + * - `GET /api/config` — login page needs tenant-specific config (social logins, registration) + * - `/api/auth/*` — login, register, password reset + * - `/oauth/*` — OAuth callback flows + * - `GET /api/share/:shareId` — public shared conversation links + * + * **How the header gets set**: The deployment's reverse proxy, auth gateway, + * or OpenID strategy sets `X-Tenant-Id` based on subdomain, path, or OIDC claim. + * This middleware does NOT resolve tenants from subdomains or tokens — that is + * the responsibility of the deployment layer. + * + * **Design**: Intentionally minimal. No subdomain parsing, no OIDC claim + * extraction, no YAML-driven strategy. Multi-tenant deployments can: + * 1. Set the header in the reverse proxy / ingress (simplest), + * 2. Replace this middleware's resolver logic entirely, or + * 3. Layer additional resolution on top (e.g., OpenID `tenant` claim → header). + * + * If no header is present, downstream runs without tenant ALS context (same as + * single-tenant mode). This preserves backward compatibility. + */ +const MAX_TENANT_ID_LENGTH = 128; +const VALID_TENANT_ID = /^[-a-zA-Z0-9_.]+$/; + +export function preAuthTenantMiddleware(req: Request, res: Response, next: NextFunction): void { + const raw = req.headers['x-tenant-id']; + + if (!raw || typeof raw !== 'string') { + next(); + return; + } + + const tenantId = raw.trim(); + + if (!tenantId) { + next(); + return; + } + + if (tenantId === SYSTEM_TENANT_ID) { + logger.warn('[preAuthTenant] Rejected __SYSTEM__ sentinel in X-Tenant-Id header', { + ip: req.ip, + path: req.path, + }); + next(); + return; + } + + if (tenantId.length > MAX_TENANT_ID_LENGTH || !VALID_TENANT_ID.test(tenantId)) { + logger.warn('[preAuthTenant] Rejected malformed X-Tenant-Id header', { + ip: req.ip, + length: tenantId.length, + path: req.path, + }); + next(); + return; + } + + return void tenantStorage.run({ tenantId }, async () => { + next(); + }); +} diff --git a/packages/api/src/stream/GenerationJobManager.ts b/packages/api/src/stream/GenerationJobManager.ts index 3e04ab734b..5993c911ff 100644 --- a/packages/api/src/stream/GenerationJobManager.ts +++ b/packages/api/src/stream/GenerationJobManager.ts @@ -1,4 +1,4 @@ -import { logger } from '@librechat/data-schemas'; +import { logger, getTenantId, SYSTEM_TENANT_ID } from '@librechat/data-schemas'; import type { StandardGraph } from '@librechat/agents'; import { parseTextParts } from 'librechat-data-provider'; import type { Agents, TMessageContentParts } from 'librechat-data-provider'; @@ -197,7 +197,9 @@ class GenerationJobManagerClass { userId: string, conversationId?: string, ): Promise { - const jobData = await this.jobStore.createJob(streamId, userId, conversationId); + const tenantId = getTenantId(); + const safeTenantId = tenantId && tenantId !== SYSTEM_TENANT_ID ? tenantId : undefined; + const jobData = await this.jobStore.createJob(streamId, userId, conversationId, safeTenantId); /** * Create runtime state with readyPromise. @@ -355,6 +357,7 @@ class GenerationJobManagerClass { error: jobData.error, metadata: { userId: jobData.userId, + tenantId: jobData.tenantId, conversationId: jobData.conversationId, userMessage: jobData.userMessage, responseMessageId: jobData.responseMessageId, @@ -1255,8 +1258,8 @@ class GenerationJobManagerClass { * @param userId - The user ID to query * @returns Array of conversation IDs with active jobs */ - async getActiveJobIdsForUser(userId: string): Promise { - return this.jobStore.getActiveJobIdsByUser(userId); + async getActiveJobIdsForUser(userId: string, tenantId?: string): Promise { + return this.jobStore.getActiveJobIdsByUser(userId, tenantId); } /** diff --git a/packages/api/src/stream/implementations/InMemoryJobStore.ts b/packages/api/src/stream/implementations/InMemoryJobStore.ts index cc82a69963..7280c3ce80 100644 --- a/packages/api/src/stream/implementations/InMemoryJobStore.ts +++ b/packages/api/src/stream/implementations/InMemoryJobStore.ts @@ -70,6 +70,7 @@ export class InMemoryJobStore implements IJobStore { streamId: string, userId: string, conversationId?: string, + tenantId?: string, ): Promise { if (this.jobs.size >= this.maxJobs) { await this.evictOldest(); @@ -78,6 +79,7 @@ export class InMemoryJobStore implements IJobStore { const job: SerializableJobData = { streamId, userId, + ...(tenantId && { tenantId }), status: 'running', createdAt: Date.now(), conversationId, @@ -86,11 +88,12 @@ export class InMemoryJobStore implements IJobStore { this.jobs.set(streamId, job); - // Track job by userId for efficient user-scoped queries - let userJobs = this.userJobMap.get(userId); + // Track job by userId (tenant-qualified when available) for efficient user-scoped queries + const userKey = tenantId ? `${tenantId}:${userId}` : userId; + let userJobs = this.userJobMap.get(userKey); if (!userJobs) { userJobs = new Set(); - this.userJobMap.set(userId, userJobs); + this.userJobMap.set(userKey, userJobs); } userJobs.add(streamId); @@ -146,6 +149,17 @@ export class InMemoryJobStore implements IJobStore { } for (const id of toDelete) { + const job = this.jobs.get(id); + if (job) { + const userKey = job.tenantId ? `${job.tenantId}:${job.userId}` : job.userId; + const userJobs = this.userJobMap.get(userKey); + if (userJobs) { + userJobs.delete(id); + if (userJobs.size === 0) { + this.userJobMap.delete(userKey); + } + } + } await this.deleteJob(id); } @@ -169,6 +183,17 @@ export class InMemoryJobStore implements IJobStore { if (oldestId) { logger.warn(`[InMemoryJobStore] Evicting oldest job: ${oldestId}`); + const job = this.jobs.get(oldestId); + if (job) { + const userKey = job.tenantId ? `${job.tenantId}:${job.userId}` : job.userId; + const userJobs = this.userJobMap.get(userKey); + if (userJobs) { + userJobs.delete(oldestId); + if (userJobs.size === 0) { + this.userJobMap.delete(userKey); + } + } + } await this.deleteJob(oldestId); } } @@ -205,8 +230,9 @@ export class InMemoryJobStore implements IJobStore { * Returns conversation IDs of running jobs belonging to the user. * Also performs self-healing cleanup: removes stale entries for jobs that no longer exist. */ - async getActiveJobIdsByUser(userId: string): Promise { - const trackedIds = this.userJobMap.get(userId); + async getActiveJobIdsByUser(userId: string, tenantId?: string): Promise { + const userKey = tenantId ? `${tenantId}:${userId}` : userId; + const trackedIds = this.userJobMap.get(userKey); if (!trackedIds || trackedIds.size === 0) { return []; } @@ -226,7 +252,7 @@ export class InMemoryJobStore implements IJobStore { // Clean up empty set if (trackedIds.size === 0) { - this.userJobMap.delete(userId); + this.userJobMap.delete(userKey); } return activeIds; diff --git a/packages/api/src/stream/implementations/RedisJobStore.ts b/packages/api/src/stream/implementations/RedisJobStore.ts index 727fe066eb..a631bc2044 100644 --- a/packages/api/src/stream/implementations/RedisJobStore.ts +++ b/packages/api/src/stream/implementations/RedisJobStore.ts @@ -29,8 +29,9 @@ const KEYS = { runSteps: (streamId: string) => `stream:{${streamId}}:runsteps`, /** Running jobs set for cleanup (global set - single slot) */ runningJobs: 'stream:running', - /** User's active jobs set: stream:user:{userId}:jobs */ - userJobs: (userId: string) => `stream:user:{${userId}}:jobs`, + /** User's active jobs set, tenant-qualified when tenantId is available */ + userJobs: (userId: string, tenantId?: string) => + tenantId ? `stream:user:{${tenantId}:${userId}}:jobs` : `stream:user:{${userId}}:jobs`, }; /** @@ -140,10 +141,12 @@ export class RedisJobStore implements IJobStore { streamId: string, userId: string, conversationId?: string, + tenantId?: string, ): Promise { const job: SerializableJobData = { streamId, userId, + ...(tenantId && { tenantId }), status: 'running', createdAt: Date.now(), conversationId, @@ -151,7 +154,7 @@ export class RedisJobStore implements IJobStore { }; const key = KEYS.job(streamId); - const userJobsKey = KEYS.userJobs(userId); + const userJobsKey = KEYS.userJobs(userId, tenantId); // For cluster mode, we can't pipeline keys on different slots // The job key uses hash tag {streamId}, runningJobs and userJobs are on different slots @@ -377,8 +380,8 @@ export class RedisJobStore implements IJobStore { * @param userId - The user ID to query * @returns Array of conversation IDs with active jobs */ - async getActiveJobIdsByUser(userId: string): Promise { - const userJobsKey = KEYS.userJobs(userId); + async getActiveJobIdsByUser(userId: string, tenantId?: string): Promise { + const userJobsKey = KEYS.userJobs(userId, tenantId); const trackedIds = await this.redis.smembers(userJobsKey); if (trackedIds.length === 0) { @@ -868,6 +871,7 @@ export class RedisJobStore implements IJobStore { return { streamId: data.streamId, userId: data.userId, + tenantId: data.tenantId || undefined, status: data.status as JobStatus, createdAt: parseInt(data.createdAt, 10), completedAt: data.completedAt ? parseInt(data.completedAt, 10) : undefined, diff --git a/packages/api/src/stream/interfaces/IJobStore.ts b/packages/api/src/stream/interfaces/IJobStore.ts index fadddb840d..b59eed66f8 100644 --- a/packages/api/src/stream/interfaces/IJobStore.ts +++ b/packages/api/src/stream/interfaces/IJobStore.ts @@ -12,6 +12,7 @@ export type JobStatus = 'running' | 'complete' | 'error' | 'aborted'; export interface SerializableJobData { streamId: string; userId: string; + tenantId?: string; status: JobStatus; createdAt: number; completedAt?: number; @@ -149,6 +150,7 @@ export interface IJobStore { streamId: string, userId: string, conversationId?: string, + tenantId?: string, ): Promise; /** Get a job by streamId (streamId === conversationId) */ @@ -186,7 +188,7 @@ export interface IJobStore { * @param userId - The user ID to query * @returns Array of conversation IDs with active jobs */ - getActiveJobIdsByUser(userId: string): Promise; + getActiveJobIdsByUser(userId: string, tenantId?: string): Promise; // ===== Content State Methods ===== // These methods manage volatile content state tied to each job. diff --git a/packages/api/src/types/stream.ts b/packages/api/src/types/stream.ts index 068d9c8db8..dd125a1aab 100644 --- a/packages/api/src/types/stream.ts +++ b/packages/api/src/types/stream.ts @@ -4,6 +4,7 @@ import type { ServerSentEvent } from '~/types'; export interface GenerationJobMetadata { userId: string; + tenantId?: string; conversationId?: string; /** User message data for rebuilding submission on reconnect */ userMessage?: Agents.UserMessageMeta; diff --git a/packages/data-schemas/src/config/tenantContext.spec.ts b/packages/data-schemas/src/config/tenantContext.spec.ts new file mode 100644 index 0000000000..7e6cc0748d --- /dev/null +++ b/packages/data-schemas/src/config/tenantContext.spec.ts @@ -0,0 +1,26 @@ +import { tenantStorage, runAsSystem, scopedCacheKey } from './tenantContext'; + +describe('scopedCacheKey', () => { + it('returns base key when no ALS context is set', () => { + expect(scopedCacheKey('MODELS_CONFIG')).toBe('MODELS_CONFIG'); + }); + + it('returns base key in SYSTEM_TENANT_ID context', async () => { + await runAsSystem(async () => { + expect(scopedCacheKey('MODELS_CONFIG')).toBe('MODELS_CONFIG'); + }); + }); + + it('appends tenantId when tenant context is active', async () => { + await tenantStorage.run({ tenantId: 'acme' }, async () => { + expect(scopedCacheKey('MODELS_CONFIG')).toBe('MODELS_CONFIG:acme'); + }); + }); + + it('does not leak tenant context outside ALS scope', async () => { + await tenantStorage.run({ tenantId: 'acme' }, async () => { + expect(scopedCacheKey('KEY')).toBe('KEY:acme'); + }); + expect(scopedCacheKey('KEY')).toBe('KEY'); + }); +}); diff --git a/packages/data-schemas/src/config/tenantContext.ts b/packages/data-schemas/src/config/tenantContext.ts index e5e4376a90..eb77edb27d 100644 --- a/packages/data-schemas/src/config/tenantContext.ts +++ b/packages/data-schemas/src/config/tenantContext.ts @@ -26,3 +26,16 @@ export function getTenantId(): string | undefined { export function runAsSystem(fn: () => Promise): Promise { return tenantStorage.run({ tenantId: SYSTEM_TENANT_ID }, fn); } + +/** + * Appends `:${tenantId}` to a cache key when a non-system tenant context is active. + * Returns the base key unchanged when no ALS context is set or when running + * inside `runAsSystem()` (SYSTEM_TENANT_ID context). + */ +export function scopedCacheKey(baseKey: string): string { + const tenantId = getTenantId(); + if (!tenantId || tenantId === SYSTEM_TENANT_ID) { + return baseKey; + } + return `${baseKey}:${tenantId}`; +} diff --git a/packages/data-schemas/src/index.ts b/packages/data-schemas/src/index.ts index d673db1f5c..1139f83f17 100644 --- a/packages/data-schemas/src/index.ts +++ b/packages/data-schemas/src/index.ts @@ -19,6 +19,12 @@ export type * from './types'; export type * from './methods'; export { default as logger } from './config/winston'; export { default as meiliLogger } from './config/meiliLogger'; -export { tenantStorage, getTenantId, runAsSystem, SYSTEM_TENANT_ID } from './config/tenantContext'; +export { + tenantStorage, + getTenantId, + runAsSystem, + scopedCacheKey, + SYSTEM_TENANT_ID, +} from './config/tenantContext'; export type { TenantContext } from './config/tenantContext'; export { dropSupersededTenantIndexes, dropSupersededPromptGroupIndexes } from './migrations'; diff --git a/packages/data-schemas/src/methods/aclEntry.ts b/packages/data-schemas/src/methods/aclEntry.ts index 82e277254a..2f61861029 100644 --- a/packages/data-schemas/src/methods/aclEntry.ts +++ b/packages/data-schemas/src/methods/aclEntry.ts @@ -8,6 +8,7 @@ import type { Model, } from 'mongoose'; import type { IAclEntry } from '~/types'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; export function createAclEntryMethods(mongoose: typeof import('mongoose')) { /** @@ -378,7 +379,7 @@ export function createAclEntryMethods(mongoose: typeof import('mongoose')) { options?: { session?: ClientSession }, ) { const AclEntry = mongoose.models.AclEntry as Model; - return AclEntry.bulkWrite(ops, options || {}); + return tenantSafeBulkWrite(AclEntry, ops, options || {}); } /** @@ -448,7 +449,9 @@ export function createAclEntryMethods(mongoose: typeof import('mongoose')) { { $group: { _id: '$resourceId' } }, ]); - const multiOwnerIds = new Set(otherOwners.map((doc: { _id: Types.ObjectId }) => doc._id.toString())); + const multiOwnerIds = new Set( + otherOwners.map((doc: { _id: Types.ObjectId }) => doc._id.toString()), + ); return ownedIds.filter((id) => !multiOwnerIds.has(id.toString())); } diff --git a/packages/data-schemas/src/methods/agentCategory.ts b/packages/data-schemas/src/methods/agentCategory.ts index 2dd4678075..baf33207aa 100644 --- a/packages/data-schemas/src/methods/agentCategory.ts +++ b/packages/data-schemas/src/methods/agentCategory.ts @@ -1,5 +1,6 @@ import type { Model, Types } from 'mongoose'; import type { IAgentCategory } from '~/types'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; export function createAgentCategoryMethods(mongoose: typeof import('mongoose')) { /** @@ -74,7 +75,7 @@ export function createAgentCategoryMethods(mongoose: typeof import('mongoose')) }, })); - return await AgentCategory.bulkWrite(operations); + return await tenantSafeBulkWrite(AgentCategory, operations); } /** @@ -241,7 +242,7 @@ export function createAgentCategoryMethods(mongoose: typeof import('mongoose')) }, })); - await AgentCategory.bulkWrite(bulkOps, { ordered: false }); + await tenantSafeBulkWrite(AgentCategory, bulkOps, { ordered: false }); } return updates.length > 0 || created > 0; diff --git a/packages/data-schemas/src/methods/conversation.ts b/packages/data-schemas/src/methods/conversation.ts index 7a62afef9e..abfe16bf2d 100644 --- a/packages/data-schemas/src/methods/conversation.ts +++ b/packages/data-schemas/src/methods/conversation.ts @@ -1,6 +1,7 @@ import type { FilterQuery, Model, SortOrder } from 'mongoose'; -import logger from '~/config/winston'; import { createTempChatExpirationDate } from '~/utils/tempChatRetention'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; +import logger from '~/config/winston'; import type { AppConfig, IConversation } from '~/types'; import type { MessageMethods } from './message'; import type { DeleteResult } from 'mongoose'; @@ -228,7 +229,7 @@ export function createConversationMethods( }, })); - const result = await Conversation.bulkWrite(bulkOps); + const result = await tenantSafeBulkWrite(Conversation, bulkOps); return result; } catch (error) { logger.error('[bulkSaveConvos] Error saving conversations in bulk', error); diff --git a/packages/data-schemas/src/methods/conversationTag.ts b/packages/data-schemas/src/methods/conversationTag.ts index af1e43babb..085948bab5 100644 --- a/packages/data-schemas/src/methods/conversationTag.ts +++ b/packages/data-schemas/src/methods/conversationTag.ts @@ -1,4 +1,5 @@ import type { Model } from 'mongoose'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; import logger from '~/config/winston'; interface IConversationTag { @@ -233,7 +234,7 @@ export function createConversationTagMethods(mongoose: typeof import('mongoose') } if (bulkOps.length > 0) { - await ConversationTag.bulkWrite(bulkOps); + await tenantSafeBulkWrite(ConversationTag, bulkOps); } const updatedConversation = ( @@ -273,7 +274,7 @@ export function createConversationTagMethods(mongoose: typeof import('mongoose') }, })); - const result = await ConversationTag.bulkWrite(bulkOps); + const result = await tenantSafeBulkWrite(ConversationTag, bulkOps); if (result && result.modifiedCount > 0) { logger.debug( `user: ${user} | Incremented tag counts - modified ${result.modifiedCount} tags`, diff --git a/packages/data-schemas/src/methods/file.ts b/packages/data-schemas/src/methods/file.ts index 3d7db88c3f..4c0969afb3 100644 --- a/packages/data-schemas/src/methods/file.ts +++ b/packages/data-schemas/src/methods/file.ts @@ -2,6 +2,7 @@ import logger from '../config/winston'; import { EToolResources, FileContext } from 'librechat-data-provider'; import type { FilterQuery, SortOrder, Model } from 'mongoose'; import type { IMongoFile } from '~/types/file'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; /** Factory function that takes mongoose instance and returns the file methods */ export function createFileMethods(mongoose: typeof import('mongoose')) { @@ -322,7 +323,7 @@ export function createFileMethods(mongoose: typeof import('mongoose')) { }, })); - const result = await File.bulkWrite(bulkOperations); + const result = await tenantSafeBulkWrite(File, bulkOperations); logger.info(`Updated ${result.modifiedCount} files with new S3 URLs`); } diff --git a/packages/data-schemas/src/methods/message.ts b/packages/data-schemas/src/methods/message.ts index ae5ca72b12..2e638b6bfb 100644 --- a/packages/data-schemas/src/methods/message.ts +++ b/packages/data-schemas/src/methods/message.ts @@ -1,6 +1,7 @@ import type { DeleteResult, FilterQuery, Model } from 'mongoose'; import logger from '~/config/winston'; import { createTempChatExpirationDate } from '~/utils/tempChatRetention'; +import { tenantSafeBulkWrite } from '~/utils/tenantBulkWrite'; import type { AppConfig, IMessage } from '~/types'; /** Simple UUID v4 regex to replace zod validation */ @@ -165,7 +166,7 @@ export function createMessageMethods(mongoose: typeof import('mongoose')): Messa upsert: true, }, })); - const result = await Message.bulkWrite(bulkOps); + const result = await tenantSafeBulkWrite(Message, bulkOps); return result; } catch (err) { logger.error('Error saving messages in bulk:', err); diff --git a/packages/data-schemas/src/methods/prompt.ts b/packages/data-schemas/src/methods/prompt.ts index a1b6bfde37..86d830fecd 100644 --- a/packages/data-schemas/src/methods/prompt.ts +++ b/packages/data-schemas/src/methods/prompt.ts @@ -1,8 +1,9 @@ import { ResourceType, SystemCategories } from 'librechat-data-provider'; import type { Model, Types } from 'mongoose'; import type { IAclEntry, IPrompt, IPromptGroup, IPromptGroupDocument } from '~/types'; -import { escapeRegExp } from '~/utils/string'; +import { getTenantId, SYSTEM_TENANT_ID } from '~/config/tenantContext'; import { isValidObjectIdString } from '~/utils/objectId'; +import { escapeRegExp } from '~/utils/string'; import logger from '~/config/winston'; export interface PromptDeps { @@ -508,16 +509,37 @@ export function createPromptMethods(mongoose: typeof import('mongoose'), deps: P if (typeof matchFilter._id === 'string') { matchFilter._id = new ObjectId(matchFilter._id); } + const tenantId = getTenantId(); + const useTenantFilter = tenantId && tenantId !== SYSTEM_TENANT_ID; + + const lookupStage = useTenantFilter + ? { + $lookup: { + from: 'prompts', + let: { prodId: '$productionId' }, + pipeline: [ + { + $match: { + $expr: { $eq: ['$_id', '$$prodId'] }, + tenantId, + }, + }, + ], + as: 'productionPrompt', + }, + } + : { + $lookup: { + from: 'prompts', + localField: 'productionId', + foreignField: '_id', + as: 'productionPrompt', + }, + }; + const result = await PromptGroup.aggregate([ { $match: matchFilter }, - { - $lookup: { - from: 'prompts', - localField: 'productionId', - foreignField: '_id', - as: 'productionPrompt', - }, - }, + lookupStage, { $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } }, ]); const group = result[0] || null; diff --git a/packages/data-schemas/src/migrations/tenantIndexes.spec.ts b/packages/data-schemas/src/migrations/tenantIndexes.spec.ts index 4637e7d0ad..e62b587a6e 100644 --- a/packages/data-schemas/src/migrations/tenantIndexes.spec.ts +++ b/packages/data-schemas/src/migrations/tenantIndexes.spec.ts @@ -47,7 +47,13 @@ describe('dropSupersededTenantIndexes', () => { await db.createCollection('roles'); await db.collection('roles').createIndex({ name: 1 }, { unique: true, name: 'name_1' }); + await db.createCollection('agents'); + await db.collection('agents').createIndex({ id: 1 }, { unique: true, name: 'id_1' }); + await db.createCollection('conversations'); + await db + .collection('conversations') + .createIndex({ conversationId: 1 }, { unique: true, name: 'conversationId_1' }); await db .collection('conversations') .createIndex( @@ -56,10 +62,18 @@ describe('dropSupersededTenantIndexes', () => { ); await db.createCollection('messages'); + await db + .collection('messages') + .createIndex({ messageId: 1 }, { unique: true, name: 'messageId_1' }); await db .collection('messages') .createIndex({ messageId: 1, user: 1 }, { unique: true, name: 'messageId_1_user_1' }); + await db.createCollection('presets'); + await db + .collection('presets') + .createIndex({ presetId: 1 }, { unique: true, name: 'presetId_1' }); + await db.createCollection('agentcategories'); await db .collection('agentcategories') diff --git a/packages/data-schemas/src/migrations/tenantIndexes.ts b/packages/data-schemas/src/migrations/tenantIndexes.ts index c68df4db2b..a8b4e51768 100644 --- a/packages/data-schemas/src/migrations/tenantIndexes.ts +++ b/packages/data-schemas/src/migrations/tenantIndexes.ts @@ -24,8 +24,10 @@ const SUPERSEDED_INDEXES: Record = { 'appleId_1', ], roles: ['name_1'], - conversations: ['conversationId_1_user_1'], - messages: ['messageId_1_user_1'], + agents: ['id_1'], + conversations: ['conversationId_1', 'conversationId_1_user_1'], + messages: ['messageId_1', 'messageId_1_user_1'], + presets: ['presetId_1'], agentcategories: ['value_1'], accessroles: ['accessRoleId_1'], conversationtags: ['tag_1_user_1'], diff --git a/packages/data-schemas/src/schema/agent.ts b/packages/data-schemas/src/schema/agent.ts index 42a7ca5418..70734d0ceb 100644 --- a/packages/data-schemas/src/schema/agent.ts +++ b/packages/data-schemas/src/schema/agent.ts @@ -5,8 +5,6 @@ const agentSchema = new Schema( { id: { type: String, - index: true, - unique: true, required: true, }, name: { @@ -124,6 +122,7 @@ const agentSchema = new Schema( }, ); +agentSchema.index({ id: 1, tenantId: 1 }, { unique: true }); agentSchema.index({ updatedAt: -1, _id: 1 }); agentSchema.index({ 'edges.to': 1 }); diff --git a/packages/data-schemas/src/schema/convo.ts b/packages/data-schemas/src/schema/convo.ts index 9ed8949e9c..c8f394935a 100644 --- a/packages/data-schemas/src/schema/convo.ts +++ b/packages/data-schemas/src/schema/convo.ts @@ -6,7 +6,6 @@ const convoSchema: Schema = new Schema( { conversationId: { type: String, - unique: true, required: true, index: true, meiliIndex: true, diff --git a/packages/data-schemas/src/schema/message.ts b/packages/data-schemas/src/schema/message.ts index ff3468918e..9879efae55 100644 --- a/packages/data-schemas/src/schema/message.ts +++ b/packages/data-schemas/src/schema/message.ts @@ -5,7 +5,6 @@ const messageSchema: Schema = new Schema( { messageId: { type: String, - unique: true, required: true, index: true, meiliIndex: true, diff --git a/packages/data-schemas/src/schema/preset.ts b/packages/data-schemas/src/schema/preset.ts index 33c217ea23..5af5163fd3 100644 --- a/packages/data-schemas/src/schema/preset.ts +++ b/packages/data-schemas/src/schema/preset.ts @@ -60,7 +60,6 @@ const presetSchema: Schema = new Schema( { presetId: { type: String, - unique: true, required: true, index: true, }, @@ -88,4 +87,6 @@ const presetSchema: Schema = new Schema( { timestamps: true }, ); +presetSchema.index({ presetId: 1, tenantId: 1 }, { unique: true }); + export default presetSchema; diff --git a/packages/data-schemas/src/utils/index.ts b/packages/data-schemas/src/utils/index.ts index c071f4e827..17e43ac3ca 100644 --- a/packages/data-schemas/src/utils/index.ts +++ b/packages/data-schemas/src/utils/index.ts @@ -1,5 +1,6 @@ export * from './principal'; export * from './string'; export * from './tempChatRetention'; +export { tenantSafeBulkWrite } from './tenantBulkWrite'; export * from './transactions'; export * from './objectId'; diff --git a/packages/data-schemas/src/utils/tenantBulkWrite.spec.ts b/packages/data-schemas/src/utils/tenantBulkWrite.spec.ts new file mode 100644 index 0000000000..059868b8a1 --- /dev/null +++ b/packages/data-schemas/src/utils/tenantBulkWrite.spec.ts @@ -0,0 +1,376 @@ +import mongoose, { Schema } from 'mongoose'; +import { MongoMemoryServer } from 'mongodb-memory-server'; +import { tenantStorage, runAsSystem, SYSTEM_TENANT_ID } from '~/config/tenantContext'; +import { applyTenantIsolation, _resetStrictCache } from '~/models/plugins/tenantIsolation'; +import { tenantSafeBulkWrite, _resetBulkWriteStrictCache } from './tenantBulkWrite'; + +let mongoServer: InstanceType; + +interface ITestDoc { + name: string; + value?: number; + tenantId?: string; +} + +function createTestModel(suffix: string) { + const schema = new Schema({ + name: { type: String, required: true }, + value: { type: Number, default: 0 }, + tenantId: { type: String, index: true }, + }); + applyTenantIsolation(schema); + const modelName = `TestBulkWrite_${suffix}_${Date.now()}`; + return mongoose.model(modelName, schema); +} + +beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + await mongoose.connect(mongoServer.getUri()); +}); + +afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); +}); + +afterEach(() => { + delete process.env.TENANT_ISOLATION_STRICT; + _resetStrictCache(); + _resetBulkWriteStrictCache(); +}); + +describe('tenantSafeBulkWrite', () => { + describe('with tenant context', () => { + it('injects tenantId into updateOne filters', async () => { + const Model = createTestModel('updateOne'); + + // Seed data for two tenants + await runAsSystem(async () => { + await Model.create([ + { name: 'doc1', value: 1, tenantId: 'tenant-a' }, + { name: 'doc1', value: 1, tenantId: 'tenant-b' }, + ]); + }); + + // Update only tenant-a's doc + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + updateOne: { + filter: { name: 'doc1' }, + update: { $set: { value: 99 } }, + }, + }, + ]); + }); + + // Verify tenant-a was updated, tenant-b was not + const docs = await runAsSystem(async () => Model.find({}).lean()); + const docA = docs.find((d) => d.tenantId === 'tenant-a'); + const docB = docs.find((d) => d.tenantId === 'tenant-b'); + expect(docA?.value).toBe(99); + expect(docB?.value).toBe(1); + }); + + it('injects tenantId into insertOne documents', async () => { + const Model = createTestModel('insertOne'); + + await tenantStorage.run({ tenantId: 'tenant-x' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + insertOne: { + document: { name: 'new-doc', value: 42 } as ITestDoc, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs).toHaveLength(1); + expect(docs[0].tenantId).toBe('tenant-x'); + expect(docs[0].name).toBe('new-doc'); + }); + + it('injects tenantId into deleteOne filters', async () => { + const Model = createTestModel('deleteOne'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'to-delete', tenantId: 'tenant-a' }, + { name: 'to-delete', tenantId: 'tenant-b' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + deleteOne: { + filter: { name: 'to-delete' }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs).toHaveLength(1); + expect(docs[0].tenantId).toBe('tenant-b'); + }); + + it('injects tenantId into updateMany filters', async () => { + const Model = createTestModel('updateMany'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'batch', value: 0, tenantId: 'tenant-a' }, + { name: 'batch', value: 0, tenantId: 'tenant-a' }, + { name: 'batch', value: 0, tenantId: 'tenant-b' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + updateMany: { + filter: { name: 'batch' }, + update: { $set: { value: 5 } }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + const tenantADocs = docs.filter((d) => d.tenantId === 'tenant-a'); + const tenantBDocs = docs.filter((d) => d.tenantId === 'tenant-b'); + expect(tenantADocs.every((d) => d.value === 5)).toBe(true); + expect(tenantBDocs[0].value).toBe(0); + }); + }); + + describe('with SYSTEM_TENANT_ID', () => { + it('skips tenantId injection (cross-tenant operation)', async () => { + const Model = createTestModel('system'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'sys-doc', value: 0, tenantId: 'tenant-a' }, + { name: 'sys-doc', value: 0, tenantId: 'tenant-b' }, + ]); + }); + + // System context should update ALL docs regardless of tenant + await runAsSystem(async () => { + await tenantSafeBulkWrite(Model, [ + { + updateMany: { + filter: { name: 'sys-doc' }, + update: { $set: { value: 100 } }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs.every((d) => d.value === 100)).toBe(true); + }); + }); + + describe('with SYSTEM_TENANT_ID in strict mode', () => { + it('does not throw when runAsSystem is used in strict mode', async () => { + process.env.TENANT_ISOLATION_STRICT = 'true'; + _resetBulkWriteStrictCache(); + + const Model = createTestModel('systemStrict'); + + await runAsSystem(async () => { + await Model.create({ name: 'strict-sys', value: 0 }); + }); + + await expect( + runAsSystem(async () => + tenantSafeBulkWrite(Model, [ + { + updateOne: { + filter: { name: 'strict-sys' }, + update: { $set: { value: 42 } }, + }, + }, + ]), + ), + ).resolves.toBeDefined(); + }); + }); + + describe('deleteMany and replaceOne', () => { + it('injects tenantId into deleteMany filters', async () => { + const Model = createTestModel('deleteMany'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'batch-del', value: 0, tenantId: 'tenant-a' }, + { name: 'batch-del', value: 0, tenantId: 'tenant-a' }, + { name: 'batch-del', value: 0, tenantId: 'tenant-b' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [{ deleteMany: { filter: { name: 'batch-del' } } }]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs).toHaveLength(1); + expect(docs[0].tenantId).toBe('tenant-b'); + }); + + it('injects tenantId into replaceOne filter and replacement', async () => { + const Model = createTestModel('replaceOne'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'to-replace', value: 1, tenantId: 'tenant-a' }, + { name: 'to-replace', value: 1, tenantId: 'tenant-b' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + replaceOne: { + filter: { name: 'to-replace' }, + replacement: { name: 'replaced', value: 99 }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).sort({ name: 1 }).lean()); + const replaced = docs.find((d) => d.name === 'replaced'); + const untouched = docs.find((d) => d.tenantId === 'tenant-b'); + expect(replaced?.value).toBe(99); + expect(replaced?.tenantId).toBe('tenant-a'); + expect(untouched?.value).toBe(1); + }); + + it('replaceOne overwrites a conflicting tenantId in the replacement document', async () => { + const Model = createTestModel('replaceOverwrite'); + + await runAsSystem(async () => { + await Model.create({ name: 'conflict', value: 1, tenantId: 'tenant-a' }); + }); + + await tenantStorage.run({ tenantId: 'tenant-a' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + replaceOne: { + filter: { name: 'conflict' }, + replacement: { name: 'conflict', value: 2, tenantId: 'tenant-evil' } as ITestDoc, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).lean()); + expect(docs).toHaveLength(1); + expect(docs[0].tenantId).toBe('tenant-a'); + expect(docs[0].value).toBe(2); + }); + }); + + describe('edge cases', () => { + it('handles empty ops array', async () => { + const Model = createTestModel('emptyOps'); + const result = await tenantStorage.run({ tenantId: 'tenant-x' }, async () => + tenantSafeBulkWrite(Model, []), + ); + expect(result.insertedCount).toBe(0); + expect(result.modifiedCount).toBe(0); + }); + }); + + describe('without tenant context', () => { + it('passes through in non-strict mode', async () => { + const Model = createTestModel('noCtx'); + + await runAsSystem(async () => { + await Model.create({ name: 'no-ctx', value: 0 }); + }); + + // No ALS context — non-strict should pass through + const result = await tenantSafeBulkWrite(Model, [ + { + updateOne: { + filter: { name: 'no-ctx' }, + update: { $set: { value: 10 } }, + }, + }, + ]); + + expect(result.modifiedCount).toBe(1); + }); + + it('throws in strict mode', async () => { + process.env.TENANT_ISOLATION_STRICT = 'true'; + _resetBulkWriteStrictCache(); + + const Model = createTestModel('strict'); + + await expect( + tenantSafeBulkWrite(Model, [ + { + updateOne: { + filter: { name: 'any' }, + update: { $set: { value: 1 } }, + }, + }, + ]), + ).rejects.toThrow('bulkWrite on TestBulkWrite_strict'); + }); + }); + + describe('mixed operations', () => { + it('handles a batch of mixed insert, update, delete operations', async () => { + const Model = createTestModel('mixed'); + + await runAsSystem(async () => { + await Model.create([ + { name: 'existing1', value: 1, tenantId: 'tenant-m' }, + { name: 'to-remove', value: 2, tenantId: 'tenant-m' }, + { name: 'existing1', value: 1, tenantId: 'tenant-other' }, + ]); + }); + + await tenantStorage.run({ tenantId: 'tenant-m' }, async () => { + await tenantSafeBulkWrite(Model, [ + { + insertOne: { + document: { name: 'new-item', value: 10 } as ITestDoc, + }, + }, + { + updateOne: { + filter: { name: 'existing1' }, + update: { $set: { value: 50 } }, + }, + }, + { + deleteOne: { + filter: { name: 'to-remove' }, + }, + }, + ]); + }); + + const docs = await runAsSystem(async () => Model.find({}).sort({ name: 1 }).lean()); + + // tenant-other's doc should be untouched + const otherDoc = docs.find((d) => d.tenantId === 'tenant-other' && d.name === 'existing1'); + expect(otherDoc?.value).toBe(1); + + // tenant-m: existing1 updated, to-remove deleted, new-item inserted + const tenantMDocs = docs.filter((d) => d.tenantId === 'tenant-m'); + expect(tenantMDocs).toHaveLength(2); + expect(tenantMDocs.find((d) => d.name === 'existing1')?.value).toBe(50); + expect(tenantMDocs.find((d) => d.name === 'new-item')?.value).toBe(10); + expect(tenantMDocs.find((d) => d.name === 'to-remove')).toBeUndefined(); + }); + }); +}); diff --git a/packages/data-schemas/src/utils/tenantBulkWrite.ts b/packages/data-schemas/src/utils/tenantBulkWrite.ts new file mode 100644 index 0000000000..16ef5fa057 --- /dev/null +++ b/packages/data-schemas/src/utils/tenantBulkWrite.ts @@ -0,0 +1,109 @@ +import type { AnyBulkWriteOperation, Model, MongooseBulkWriteOptions } from 'mongoose'; +import type { BulkWriteResult } from 'mongodb'; +import { getTenantId, SYSTEM_TENANT_ID } from '~/config/tenantContext'; +import logger from '~/config/winston'; + +let _strictMode: boolean | undefined; + +function isStrict(): boolean { + return (_strictMode ??= process.env.TENANT_ISOLATION_STRICT === 'true'); +} + +/** Resets the cached strict-mode flag. Exposed for test teardown only. */ +export function _resetBulkWriteStrictCache(): void { + _strictMode = undefined; +} + +/** + * Tenant-safe wrapper around Mongoose `Model.bulkWrite()`. + * + * Mongoose's `bulkWrite` does not trigger schema-level middleware hooks, so the + * `applyTenantIsolation` plugin cannot intercept it. This wrapper injects the + * current ALS tenant context into every operation's filter and/or document + * before delegating to the native `bulkWrite`. + * + * Behavior: + * - **tenantId present** (normal request): injects `{ tenantId }` into every + * operation filter (updateOne, deleteOne, replaceOne) and document (insertOne). + * - **SYSTEM_TENANT_ID**: skips injection (cross-tenant system operation). + * - **No tenantId + strict mode**: throws (fail-closed, same as the plugin). + * - **No tenantId + non-strict**: passes through without injection (backward compat). + */ +export async function tenantSafeBulkWrite( + model: Model, + ops: AnyBulkWriteOperation[], + options?: MongooseBulkWriteOptions, +): Promise { + const tenantId = getTenantId(); + + if (!tenantId) { + if (isStrict()) { + throw new Error( + `[TenantIsolation] bulkWrite on ${model.modelName} attempted without tenant context in strict mode`, + ); + } + return model.bulkWrite(ops, options); + } + + if (tenantId === SYSTEM_TENANT_ID) { + return model.bulkWrite(ops, options); + } + + const injected = ops.map((op) => injectTenantId(op, tenantId)); + return model.bulkWrite(injected, options); +} + +/** + * Injects `tenantId` into a single bulk-write operation. + * Returns a new operation object — does not mutate the original. + */ +function injectTenantId(op: AnyBulkWriteOperation, tenantId: string): AnyBulkWriteOperation { + if ('insertOne' in op) { + return { + insertOne: { + document: { ...op.insertOne.document, tenantId }, + }, + }; + } + + if ('updateOne' in op) { + const { filter, ...rest } = op.updateOne; + return { updateOne: { ...rest, filter: { ...filter, tenantId } } }; + } + + if ('updateMany' in op) { + const { filter, ...rest } = op.updateMany; + return { updateMany: { ...rest, filter: { ...filter, tenantId } } }; + } + + if ('deleteOne' in op) { + const { filter, ...rest } = op.deleteOne; + return { deleteOne: { ...rest, filter: { ...filter, tenantId } } }; + } + + if ('deleteMany' in op) { + const { filter, ...rest } = op.deleteMany; + return { deleteMany: { ...rest, filter: { ...filter, tenantId } } }; + } + + if ('replaceOne' in op) { + const { filter, replacement, ...rest } = op.replaceOne; + return { + replaceOne: { + ...rest, + filter: { ...filter, tenantId }, + replacement: { ...replacement, tenantId }, + }, + }; + } + + if (isStrict()) { + throw new Error( + '[TenantIsolation] Unknown bulkWrite operation type in strict mode — refusing to pass through without tenant injection', + ); + } + logger.warn( + '[tenantSafeBulkWrite] Unknown bulk op type, passing through without tenant injection', + ); + return op; +}