diff --git a/api/server/controllers/agents/request.js b/api/server/controllers/agents/request.js
index aead06b325..97679a1327 100644
--- a/api/server/controllers/agents/request.js
+++ b/api/server/controllers/agents/request.js
@@ -1,13 +1,17 @@
 const { logger } = require('@librechat/data-schemas');
-const { Constants } = require('librechat-data-provider');
+const { Constants, ViolationTypes } = require('librechat-data-provider');
 const {
   sendEvent,
+  getViolationInfo,
   GenerationJobManager,
+  decrementPendingRequest,
   sanitizeFileForTransmit,
   sanitizeMessageForTransmit,
+  checkAndIncrementPendingRequest,
 } = require('@librechat/api');
-const { handleAbortError } = require('~/server/middleware');
 const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
+const { handleAbortError } = require('~/server/middleware');
+const { logViolation } = require('~/cache');
 const { saveMessage } = require('~/models');
 
 function createCloseHandler(abortController) {
@@ -47,6 +51,13 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
 
   const userId = req.user.id;
 
+  const { allowed, pendingRequests, limit } = await checkAndIncrementPendingRequest(userId);
+  if (!allowed) {
+    const violationInfo = getViolationInfo(pendingRequests, limit);
+    await logViolation(req, res, ViolationTypes.CONCURRENT, violationInfo, violationInfo.score);
+    return res.status(429).json(violationInfo);
+  }
+
   // Generate conversationId upfront if not provided - streamId === conversationId always
   // Treat "new" as a placeholder that needs a real UUID (frontend may send "new" for new convos)
   const conversationId =
@@ -137,6 +148,7 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
 
       if (job.abortController.signal.aborted) {
         GenerationJobManager.completeJob(streamId, 'Request aborted during initialization');
+        await decrementPendingRequest(userId);
         return;
       }
 
@@ -263,6 +275,7 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
 
       GenerationJobManager.emitDone(streamId, finalEvent);
       GenerationJobManager.completeJob(streamId);
+      await decrementPendingRequest(userId);
 
       if (client.savedMessageIds && !client.savedMessageIds.has(messageId)) {
         await saveMessage(
@@ -282,6 +295,7 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
         };
         GenerationJobManager.emitDone(streamId, finalEvent);
         GenerationJobManager.completeJob(streamId, 'Request aborted');
+        await decrementPendingRequest(userId);
       }
 
       if (!client.skipSaveUserMessage && userMessage) {
@@ -322,6 +336,8 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
         GenerationJobManager.completeJob(streamId, error.message);
       }
 
+      await decrementPendingRequest(userId);
+
       if (client) {
         disposeClient(client);
       }
@@ -332,11 +348,12 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
     };
 
     // Start generation and handle any unhandled errors
-    startGeneration().catch((err) => {
+    startGeneration().catch(async (err) => {
       logger.error(
         `[ResumableAgentController] Unhandled error in background generation: ${err.message}`,
       );
       GenerationJobManager.completeJob(streamId, err.message);
+      await decrementPendingRequest(userId);
     });
   } catch (error) {
     logger.error('[ResumableAgentController] Initialization error:', error);
@@ -347,6 +364,7 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
       GenerationJobManager.emitError(streamId, error.message || 'Failed to start generation');
     }
     GenerationJobManager.completeJob(streamId, error.message);
+    await decrementPendingRequest(userId);
     if (client) {
       disposeClient(client);
     }
diff --git a/api/server/index.spec.js b/api/server/index.spec.js
index 4dcd34687e..c73c605518 100644
--- a/api/server/index.spec.js
+++ b/api/server/index.spec.js
@@ -25,6 +25,13 @@ jest.mock('~/app/clients/tools', () => ({
   toolkits: [],
 }));
 
+jest.mock('~/config', () => ({
+  createMCPServersRegistry: jest.fn(),
+  createMCPManager: jest.fn().mockResolvedValue({
+    getAppToolFunctions: jest.fn().mockResolvedValue({}),
+  }),
+}));
+
 describe('Server Configuration', () => {
   // Increase the default timeout to allow for Mongo cleanup
   jest.setTimeout(30_000);
diff --git a/api/server/middleware/concurrentLimiter.js b/api/server/middleware/concurrentLimiter.js
deleted file mode 100644
index 96885e2fd4..0000000000
--- a/api/server/middleware/concurrentLimiter.js
+++ /dev/null
@@ -1,76 +0,0 @@
-const { isEnabled } = require('@librechat/api');
-const { Time, CacheKeys, ViolationTypes } = require('librechat-data-provider');
-const clearPendingReq = require('~/cache/clearPendingReq');
-const { logViolation, getLogStores } = require('~/cache');
-const denyRequest = require('./denyRequest');
-
-const {
-  USE_REDIS,
-  CONCURRENT_MESSAGE_MAX = 1,
-  CONCURRENT_VIOLATION_SCORE: score,
-} = process.env ?? {};
-
-/**
- * Middleware to limit concurrent requests for a user.
- *
- * This middleware checks if a user has exceeded a specified concurrent request limit.
- * If the user exceeds the limit, an error is returned. If the user is within the limit,
- * their request count is incremented. After the request is processed, the count is decremented.
- * If the `cache` store is not available, the middleware will skip its logic.
- *
- * @function
- * @param {Object} req - Express request object containing user information.
- * @param {Object} res - Express response object.
- * @param {import('express').NextFunction} next - Next middleware function.
- * @throws {Error} Throws an error if the user exceeds the concurrent request limit.
- */
-const concurrentLimiter = async (req, res, next) => {
-  const namespace = CacheKeys.PENDING_REQ;
-  const cache = getLogStores(namespace);
-  if (!cache) {
-    return next();
-  }
-
-  if (Object.keys(req?.body ?? {}).length === 1 && req?.body?.abortKey) {
-    return next();
-  }
-
-  const userId = req.user?.id ?? req.user?._id ?? '';
-  const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1);
-  const type = ViolationTypes.CONCURRENT;
-
-  const key = `${isEnabled(USE_REDIS) ? namespace : ''}:${userId}`;
-  const pendingRequests = +((await cache.get(key)) ?? 0);
-
-  if (pendingRequests >= limit) {
-    const errorMessage = {
-      type,
-      limit,
-      pendingRequests,
-    };
-
-    await logViolation(req, res, type, errorMessage, score);
-    return await denyRequest(req, res, errorMessage);
-  } else {
-    await cache.set(key, pendingRequests + 1, Time.ONE_MINUTE);
-  }
-
-  // Ensure the requests are removed from the store once the request is done
-  let cleared = false;
-  const cleanUp = async () => {
-    if (cleared) {
-      return;
-    }
-    cleared = true;
-    await clearPendingReq({ userId, cache });
-  };
-
-  if (pendingRequests < limit) {
-    res.on('finish', cleanUp);
-    res.on('close', cleanUp);
-  }
-
-  next();
-};
-
-module.exports = concurrentLimiter;
diff --git a/api/server/middleware/index.js b/api/server/middleware/index.js
index 2aad5a47e7..64b9fb1618 100644
--- a/api/server/middleware/index.js
+++ b/api/server/middleware/index.js
@@ -3,7 +3,6 @@ const validateRegistration = require('./validateRegistration');
 const buildEndpointOption = require('./buildEndpointOption');
 const validateMessageReq = require('./validateMessageReq');
 const checkDomainAllowed = require('./checkDomainAllowed');
-const concurrentLimiter = require('./concurrentLimiter');
 const requireLocalAuth = require('./requireLocalAuth');
 const canDeleteAccount = require('./canDeleteAccount');
 const accessResources = require('./accessResources');
@@ -42,7 +41,6 @@ module.exports = {
   requireLocalAuth,
   canDeleteAccount,
   configMiddleware,
-  concurrentLimiter,
   checkDomainAllowed,
   validateMessageReq,
   buildEndpointOption,
diff --git a/api/server/routes/agents/index.js b/api/server/routes/agents/index.js
index 21af27d0bc..6933a11534 100644
--- a/api/server/routes/agents/index.js
+++ b/api/server/routes/agents/index.js
@@ -7,13 +7,12 @@ const {
   requireJwtAuth,
   messageIpLimiter,
   configMiddleware,
-  concurrentLimiter,
   messageUserLimiter,
 } = require('~/server/middleware');
 
 const { v1 } = require('./v1');
 const chat = require('./chat');
 
-const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
+const { LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
 
 const router = express.Router();
@@ -208,10 +207,6 @@ router.post('/chat/abort', async (req, res) => {
 const chatRouter = express.Router();
 chatRouter.use(configMiddleware);
 
-if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
-  chatRouter.use(concurrentLimiter);
-}
-
 if (isEnabled(LIMIT_MESSAGE_IP)) {
   chatRouter.use(messageIpLimiter);
 }
diff --git a/api/server/services/initializeMCPs.spec.js b/api/server/services/initializeMCPs.spec.js
index c45c451a1f..e37e12c356 100644
--- a/api/server/services/initializeMCPs.spec.js
+++ b/api/server/services/initializeMCPs.spec.js
@@ -171,7 +171,7 @@ describe('initializeMCPs', () => {
     expect(mockMCPManagerInstance.getAppToolFunctions).not.toHaveBeenCalled();
     expect(mockMergeAppTools).not.toHaveBeenCalled();
     expect(logger.debug).toHaveBeenCalledWith(
-      '[MCP] No configured servers configured. MCPManager ready for UI-based servers.',
+      '[MCP] No servers configured. MCPManager ready for UI-based servers.',
     );
   });
 
@@ -185,7 +185,7 @@
 
     expect(mockMCPManagerInstance.getAppToolFunctions).not.toHaveBeenCalled();
     expect(mockMergeAppTools).not.toHaveBeenCalled();
    expect(logger.debug).toHaveBeenCalledWith(
-      '[MCP] No configured servers configured. MCPManager ready for UI-based servers.',
+      '[MCP] No servers configured. MCPManager ready for UI-based servers.',
     );
   });
diff --git a/client/src/hooks/SSE/useResumableSSE.ts b/client/src/hooks/SSE/useResumableSSE.ts
index c475aeffcc..1bfb2706d5 100644
--- a/client/src/hooks/SSE/useResumableSSE.ts
+++ b/client/src/hooks/SSE/useResumableSSE.ts
@@ -532,9 +532,20 @@ export default function useResumableSSE(
         }
       }
 
-      // All retries failed or non-network error
       console.error('[ResumableSSE] Error starting generation:', lastError);
-      errorHandler({ data: undefined, submission: currentSubmission as EventSubmission });
+
+      const axiosError = lastError as { response?: { data?: Record<string, unknown> } };
+      const errorData = axiosError?.response?.data;
+      if (errorData) {
+        errorHandler({
+          data: { text: JSON.stringify(errorData) } as unknown as Parameters<
+            typeof errorHandler
+          >[0]['data'],
+          submission: currentSubmission as EventSubmission,
+        });
+      } else {
+        errorHandler({ data: undefined, submission: currentSubmission as EventSubmission });
+      }
       setIsSubmitting(false);
       return null;
     },
diff --git a/packages/api/src/middleware/concurrency.ts b/packages/api/src/middleware/concurrency.ts
new file mode 100644
index 0000000000..92ac8b7d46
--- /dev/null
+++ b/packages/api/src/middleware/concurrency.ts
@@ -0,0 +1,227 @@
+import { logger } from '@librechat/data-schemas';
+import { CacheKeys, Time, ViolationTypes } from 'librechat-data-provider';
+import { standardCache, cacheConfig, ioredisClient } from '~/cache';
+import { isEnabled, math } from '~/utils';
+
+const { USE_REDIS } = cacheConfig;
+
+const LIMIT_CONCURRENT_MESSAGES = process.env.LIMIT_CONCURRENT_MESSAGES;
+const CONCURRENT_MESSAGE_MAX = math(process.env.CONCURRENT_MESSAGE_MAX, 2);
+const CONCURRENT_VIOLATION_SCORE = math(process.env.CONCURRENT_VIOLATION_SCORE, 1);
+
+/** Lazily initialized cache for pending requests (used only for in-memory fallback) */
+let pendingReqCache: ReturnType<typeof standardCache> | null = null;
+
+/**
+ * Get or create the pending requests cache for the in-memory fallback.
+ * Uses lazy initialization to avoid creating the cache before the app is ready.
+ */
+function getPendingReqCache(): ReturnType<typeof standardCache> | null {
+  if (!isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
+    return null;
+  }
+  if (!pendingReqCache) {
+    pendingReqCache = standardCache(CacheKeys.PENDING_REQ);
+  }
+  return pendingReqCache;
+}
+
+/**
+ * Build the cache key for a user's pending requests.
+ * Note: ioredisClient already has keyPrefix applied, so we only add namespace:userId
+ */
+function buildKey(userId: string): string {
+  const namespace = CacheKeys.PENDING_REQ;
+  return `${namespace}:${userId}`;
+}
+
+/**
+ * Build the cache key for the in-memory fallback (Keyv).
+ */
+function buildMemoryKey(userId: string): string {
+  return `:${userId}`;
+}
+
+export interface PendingRequestResult {
+  allowed: boolean;
+  pendingRequests: number;
+  limit: number;
+}
+
+export interface ViolationInfo {
+  type: string;
+  limit: number;
+  pendingRequests: number;
+  score: number;
+}
+
+/**
+ * Check if a user can make a new concurrent request and increment the counter if allowed.
+ * This is designed for resumable streams where the HTTP response lifecycle doesn't match
+ * the actual request processing lifecycle.
+ *
+ * When Redis is available, uses atomic INCR to prevent race conditions.
+ * Falls back to non-atomic get/set for the in-memory cache.
+ *
+ * @param userId - The user's ID
+ * @returns Object with `allowed` (boolean), `pendingRequests` (current count), and `limit`
+ */
+export async function checkAndIncrementPendingRequest(
+  userId: string,
+): Promise<PendingRequestResult> {
+  const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1);
+
+  if (!isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
+    return { allowed: true, pendingRequests: 0, limit };
+  }
+
+  if (!userId) {
+    logger.warn('[concurrency] checkAndIncrementPendingRequest called without userId');
+    return { allowed: true, pendingRequests: 0, limit };
+  }
+
+  // Use atomic Redis INCR when available to prevent race conditions
+  if (USE_REDIS && ioredisClient) {
+    const key = buildKey(userId);
+    try {
+      // The pipeline sends INCR and EXPIRE together in one round-trip
+      // This prevents edge cases where a crash between operations would leave the key without a TTL
+      const pipeline = ioredisClient.pipeline();
+      pipeline.incr(key);
+      pipeline.expire(key, 60);
+      const results = await pipeline.exec();
+
+      if (!results || results[0][0]) {
+        throw new Error('Pipeline execution failed');
+      }
+
+      const newCount = results[0][1] as number;
+
+      if (newCount > limit) {
+        // Over limit - decrement back and reject
+        await ioredisClient.decr(key);
+        logger.debug(
+          `[concurrency] User ${userId} exceeded concurrent limit: ${newCount}/${limit}`,
+        );
+        return { allowed: false, pendingRequests: newCount, limit };
+      }
+
+      logger.debug(
+        `[concurrency] User ${userId} incremented pending requests: ${newCount}/${limit}`,
+      );
+      return { allowed: true, pendingRequests: newCount, limit };
+    } catch (error) {
+      logger.error('[concurrency] Redis atomic increment failed:', error);
+      // On Redis error, allow the request to proceed (fail-open)
+      return { allowed: true, pendingRequests: 0, limit };
+    }
+  }
+
+  // Fallback: non-atomic in-memory cache (race condition possible but acceptable for in-memory)
+  const cache = getPendingReqCache();
+  if (!cache) {
+    return { allowed: true, pendingRequests: 0, limit };
+  }
+
+  const key = buildMemoryKey(userId);
+  const pendingRequests = +((await cache.get(key)) ?? 0);
+
+  if (pendingRequests >= limit) {
+    logger.debug(
+      `[concurrency] User ${userId} exceeded concurrent limit: ${pendingRequests}/${limit}`,
+    );
+    return { allowed: false, pendingRequests, limit };
+  }
+
+  await cache.set(key, pendingRequests + 1, Time.ONE_MINUTE);
+  logger.debug(
+    `[concurrency] User ${userId} incremented pending requests: ${pendingRequests + 1}/${limit}`,
+  );
+
+  return { allowed: true, pendingRequests: pendingRequests + 1, limit };
+}
+
+/**
+ * Decrement the pending request counter for a user.
+ * Should be called when a generation job completes, errors, or is aborted.
+ *
+ * This function handles errors internally and will never throw - it's a cleanup
+ * operation that should not interrupt the main flow if cache operations fail.
+ *
+ * When Redis is available, uses atomic DECR to prevent race conditions.
+ * Falls back to non-atomic get/set for the in-memory cache.
+ *
+ * @param userId - The user's ID
+ */
+export async function decrementPendingRequest(userId: string): Promise<void> {
+  try {
+    if (!isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
+      return;
+    }
+
+    if (!userId) {
+      logger.warn('[concurrency] decrementPendingRequest called without userId');
+      return;
+    }
+
+    // Use atomic Redis DECR when available
+    if (USE_REDIS && ioredisClient) {
+      const key = buildKey(userId);
+      try {
+        const newCount = await ioredisClient.decr(key);
+        if (newCount < 0) {
+          // Counter went negative - delete the key to reset it to 0
+          await ioredisClient.del(key);
+          logger.debug(`[concurrency] User ${userId} pending requests cleared (was negative)`);
+        } else if (newCount === 0) {
+          // Clean up zero-value keys
+          await ioredisClient.del(key);
+          logger.debug(`[concurrency] User ${userId} pending requests cleared`);
+        } else {
+          logger.debug(`[concurrency] User ${userId} decremented pending requests: ${newCount}`);
+        }
+      } catch (error) {
+        logger.error('[concurrency] Redis atomic decrement failed:', error);
+      }
+      return;
+    }
+
+    // Fallback: non-atomic in-memory cache
+    const cache = getPendingReqCache();
+    if (!cache) {
+      return;
+    }
+
+    const key = buildMemoryKey(userId);
+    const currentReq = +((await cache.get(key)) ?? 0);
+
+    if (currentReq >= 1) {
+      await cache.set(key, currentReq - 1, Time.ONE_MINUTE);
+      logger.debug(`[concurrency] User ${userId} decremented pending requests: ${currentReq - 1}`);
+    } else {
+      await cache.delete(key);
+      logger.debug(`[concurrency] User ${userId} pending requests cleared (was ${currentReq})`);
+    }
+  } catch (error) {
+    logger.error('[concurrency] Error decrementing pending request:', error);
+  }
+}
+
+/**
+ * Get violation info for logging purposes when a user exceeds the concurrent request limit.
+ */
+export function getViolationInfo(pendingRequests: number, limit: number): ViolationInfo {
+  return {
+    type: ViolationTypes.CONCURRENT,
+    limit,
+    pendingRequests,
+    score: CONCURRENT_VIOLATION_SCORE,
+  };
+}
+
+/**
+ * Check if concurrent message limiting is enabled.
+ */
+export function isConcurrentLimitEnabled(): boolean {
+  return isEnabled(LIMIT_CONCURRENT_MESSAGES);
+}
diff --git a/packages/api/src/middleware/index.ts b/packages/api/src/middleware/index.ts
index 0aa5cf4f86..4398b35e14 100644
--- a/packages/api/src/middleware/index.ts
+++ b/packages/api/src/middleware/index.ts
@@ -2,3 +2,4 @@ export * from './access';
 export * from './error';
 export * from './balance';
 export * from './json';
+export * from './concurrency';