🚦 refactor: Concurrent Request Limiter for Resumable Streams (#11167)

* feat: Implement concurrent request handling in ResumableAgentController

- Introduced a new concurrency management system: `checkAndIncrementPendingRequest` and `decrementPendingRequest` now enforce per-user pending-request limits (see the usage sketch after this list).
- Replaced the previous `concurrentLimiter` middleware with a more integrated approach directly within the `ResumableAgentController`.
- Enhanced violation logging and request denial for users exceeding their concurrent request limits.
- Removed the obsolete `concurrentLimiter` middleware file and updated related imports across the codebase.
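
A minimal sketch of the controller-level gate, condensed from the `ResumableAgentController` diff below. The `gatedController` wrapper and the `try/finally` framing are illustrative only; the actual controller decrements on each completion, abort, and error path because generation continues in the background after the handler returns.

```js
const {
  getViolationInfo,
  decrementPendingRequest,
  checkAndIncrementPendingRequest,
} = require('@librechat/api');
const { ViolationTypes } = require('librechat-data-provider');
const { logViolation } = require('~/cache');

async function gatedController(req, res) {
  const userId = req.user.id;
  const { allowed, pendingRequests, limit } = await checkAndIncrementPendingRequest(userId);
  if (!allowed) {
    const violationInfo = getViolationInfo(pendingRequests, limit);
    await logViolation(req, res, ViolationTypes.CONCURRENT, violationInfo, violationInfo.score);
    return res.status(429).json(violationInfo);
  }
  try {
    // ... start generation (the real controller streams / resumes here) ...
  } finally {
    // Illustrative: the real controller releases the slot on each done/abort/error path
    // rather than in a single finally block.
    await decrementPendingRequest(userId);
  }
}
```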

* refactor: Simplify error handling in ResumableAgentController and enhance SSE error management

- Removed the `denyRequest` middleware from the concurrent-violation path, replacing it with a direct `429` response in the ResumableAgentController.
- Improved error handling in the `useResumableSSE` hook to differentiate network errors from other error types, so the error handler receives more informative error responses (sketched below).
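
The hook change is not included in the hunks shown below, so this is a hypothetical sketch of the network-vs-server error split; the function name, detection checks, and error shapes are assumptions rather than the actual hook code.

```js
// Hypothetical routing of stream errors to the caller-provided error handler.
function routeStreamError(err, errorHandler) {
  // fetch/EventSource network failures typically surface as a TypeError or an aborted
  // request, without any structured payload from the server.
  const isNetworkError = err instanceof TypeError || err?.name === 'AbortError';
  if (isNetworkError) {
    errorHandler({ type: 'network', message: err.message || 'Network error during stream' });
    return;
  }
  // Anything else is treated as a server-side error and forwarded with whatever detail
  // the response carried, so the handler can show a more informative message.
  errorHandler({
    type: 'server',
    message: err?.message ?? 'Unknown stream error',
    data: err?.data,
  });
}
```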

* test: Enhance MCP server configuration tests with new mocks and improved logging

- Added mocks for MCP server registry and manager in `index.spec.js` to facilitate testing of server configurations.
- Updated the debug-log expectations in `initializeMCPs.spec.js` to match the simplified server-configuration messages, improving clarity in test output.

* refactor: Enhance concurrency management in request handling

- Updated `checkAndIncrementPendingRequest` and `decrementPendingRequest` functions to utilize Redis for atomic request counting, improving concurrency control.
- Added error handling for Redis operations to ensure requests can proceed even during Redis failures.
- Streamlined cache key generation so Redis and the in-memory fallback share the same key scheme, keeping pending-request bookkeeping consistent across both paths (see the fallback sketch after this list).
- Improved comments and documentation for better understanding of the concurrency logic and its implications.
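
A rough sketch of the shared key shape and in-memory fallback counting implied by this commit; the map, key prefix, and return shape are assumptions rather than the actual `@librechat/api` implementation.

```js
// Hypothetical in-memory fallback used when Redis is not configured.
const pendingCounts = new Map();

function pendingKey(userId) {
  return `pending_req:${userId}`; // same key shape as the Redis path, for consistency
}

function incrementInMemory(userId, limit) {
  const key = pendingKey(userId);
  const current = pendingCounts.get(key) ?? 0;
  if (current >= limit) {
    return { allowed: false, pendingRequests: current, limit };
  }
  pendingCounts.set(key, current + 1);
  return { allowed: true, pendingRequests: current + 1, limit };
}

function decrementInMemory(userId) {
  const key = pendingKey(userId);
  const current = pendingCounts.get(key) ?? 0;
  if (current <= 1) {
    pendingCounts.delete(key); // drop the entry entirely once no requests remain
  } else {
    pendingCounts.set(key, current - 1);
  }
}
```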

* refactor: Improve atomicity in Redis operations for pending request management

- Updated `checkAndIncrementPendingRequest` to utilize Redis pipelines for atomic INCR and EXPIRE operations, closing the window where a counter could be incremented without its TTL being set (pipeline sketch below).
- Added error handling for pipeline execution failures to ensure robust request management.
- Improved comments for clarity on the concurrency logic and its implications.
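
A sketch of the pipelined INCR + EXPIRE pattern described above, assuming an ioredis-compatible client. The fail-open behavior mirrors the earlier bullet about tolerating Redis failures; the key prefix, TTL, and over-limit rollback are illustrative assumptions, not the exact implementation.

```js
const Redis = require('ioredis');

// Sketch only: key prefix, TTL, and rollback-on-deny are illustrative.
async function incrementWithPipeline(redis, userId, limit, ttlSeconds = 60) {
  const key = `pending_req:${userId}`;
  try {
    // One round trip: INCR and EXPIRE are queued together and sent as a pipeline.
    const results = await redis.pipeline().incr(key).expire(key, ttlSeconds).exec();
    const [incrErr, count] = results[0]; // ioredis returns [error, value] per queued command
    if (incrErr) {
      throw incrErr;
    }
    if (count > limit) {
      await redis.decr(key); // roll back this request's increment before denying
      return { allowed: false, pendingRequests: count - 1, limit };
    }
    return { allowed: true, pendingRequests: count, limit };
  } catch (err) {
    // Fail open: if Redis is unreachable, let the request proceed (per the earlier bullet).
    return { allowed: true, pendingRequests: 0, limit };
  }
}

// e.g. const { allowed } = await incrementWithPipeline(new Redis(), userId, 1);
```
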
Committed by Danny Avila on 2026-01-01 11:10:56 -05:00 (via GitHub) · commit a7aa4dc91b · parent a2361aa891 · 9 changed files with 272 additions and 91 deletions


@@ -1,13 +1,17 @@
const { logger } = require('@librechat/data-schemas');
const { Constants } = require('librechat-data-provider');
const { Constants, ViolationTypes } = require('librechat-data-provider');
const {
sendEvent,
getViolationInfo,
GenerationJobManager,
decrementPendingRequest,
sanitizeFileForTransmit,
sanitizeMessageForTransmit,
checkAndIncrementPendingRequest,
} = require('@librechat/api');
const { handleAbortError } = require('~/server/middleware');
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
const { handleAbortError } = require('~/server/middleware');
const { logViolation } = require('~/cache');
const { saveMessage } = require('~/models');
function createCloseHandler(abortController) {
@@ -47,6 +51,13 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
const userId = req.user.id;
const { allowed, pendingRequests, limit } = await checkAndIncrementPendingRequest(userId);
if (!allowed) {
const violationInfo = getViolationInfo(pendingRequests, limit);
await logViolation(req, res, ViolationTypes.CONCURRENT, violationInfo, violationInfo.score);
return res.status(429).json(violationInfo);
}
// Generate conversationId upfront if not provided - streamId === conversationId always
// Treat "new" as a placeholder that needs a real UUID (frontend may send "new" for new convos)
const conversationId =
@@ -137,6 +148,7 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
if (job.abortController.signal.aborted) {
GenerationJobManager.completeJob(streamId, 'Request aborted during initialization');
await decrementPendingRequest(userId);
return;
}
@@ -263,6 +275,7 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
GenerationJobManager.emitDone(streamId, finalEvent);
GenerationJobManager.completeJob(streamId);
await decrementPendingRequest(userId);
if (client.savedMessageIds && !client.savedMessageIds.has(messageId)) {
await saveMessage(
@@ -282,6 +295,7 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
};
GenerationJobManager.emitDone(streamId, finalEvent);
GenerationJobManager.completeJob(streamId, 'Request aborted');
await decrementPendingRequest(userId);
}
if (!client.skipSaveUserMessage && userMessage) {
@@ -322,6 +336,8 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
GenerationJobManager.completeJob(streamId, error.message);
}
await decrementPendingRequest(userId);
if (client) {
disposeClient(client);
}
@@ -332,11 +348,12 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
};
// Start generation and handle any unhandled errors
startGeneration().catch((err) => {
startGeneration().catch(async (err) => {
logger.error(
`[ResumableAgentController] Unhandled error in background generation: ${err.message}`,
);
GenerationJobManager.completeJob(streamId, err.message);
await decrementPendingRequest(userId);
});
} catch (error) {
logger.error('[ResumableAgentController] Initialization error:', error);
@@ -347,6 +364,7 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
GenerationJobManager.emitError(streamId, error.message || 'Failed to start generation');
}
GenerationJobManager.completeJob(streamId, error.message);
await decrementPendingRequest(userId);
if (client) {
disposeClient(client);
}


@@ -25,6 +25,13 @@ jest.mock('~/app/clients/tools', () => ({
toolkits: [],
}));
jest.mock('~/config', () => ({
createMCPServersRegistry: jest.fn(),
createMCPManager: jest.fn().mockResolvedValue({
getAppToolFunctions: jest.fn().mockResolvedValue({}),
}),
}));
describe('Server Configuration', () => {
// Increase the default timeout to allow for Mongo cleanup
jest.setTimeout(30_000);


@@ -1,76 +0,0 @@
const { isEnabled } = require('@librechat/api');
const { Time, CacheKeys, ViolationTypes } = require('librechat-data-provider');
const clearPendingReq = require('~/cache/clearPendingReq');
const { logViolation, getLogStores } = require('~/cache');
const denyRequest = require('./denyRequest');
const {
USE_REDIS,
CONCURRENT_MESSAGE_MAX = 1,
CONCURRENT_VIOLATION_SCORE: score,
} = process.env ?? {};
/**
* Middleware to limit concurrent requests for a user.
*
* This middleware checks if a user has exceeded a specified concurrent request limit.
* If the user exceeds the limit, an error is returned. If the user is within the limit,
* their request count is incremented. After the request is processed, the count is decremented.
* If the `cache` store is not available, the middleware will skip its logic.
*
* @function
* @param {Object} req - Express request object containing user information.
* @param {Object} res - Express response object.
* @param {import('express').NextFunction} next - Next middleware function.
* @throws {Error} Throws an error if the user exceeds the concurrent request limit.
*/
const concurrentLimiter = async (req, res, next) => {
const namespace = CacheKeys.PENDING_REQ;
const cache = getLogStores(namespace);
if (!cache) {
return next();
}
if (Object.keys(req?.body ?? {}).length === 1 && req?.body?.abortKey) {
return next();
}
const userId = req.user?.id ?? req.user?._id ?? '';
const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1);
const type = ViolationTypes.CONCURRENT;
const key = `${isEnabled(USE_REDIS) ? namespace : ''}:${userId}`;
const pendingRequests = +((await cache.get(key)) ?? 0);
if (pendingRequests >= limit) {
const errorMessage = {
type,
limit,
pendingRequests,
};
await logViolation(req, res, type, errorMessage, score);
return await denyRequest(req, res, errorMessage);
} else {
await cache.set(key, pendingRequests + 1, Time.ONE_MINUTE);
}
// Ensure the requests are removed from the store once the request is done
let cleared = false;
const cleanUp = async () => {
if (cleared) {
return;
}
cleared = true;
await clearPendingReq({ userId, cache });
};
if (pendingRequests < limit) {
res.on('finish', cleanUp);
res.on('close', cleanUp);
}
next();
};
module.exports = concurrentLimiter;


@@ -3,7 +3,6 @@ const validateRegistration = require('./validateRegistration');
const buildEndpointOption = require('./buildEndpointOption');
const validateMessageReq = require('./validateMessageReq');
const checkDomainAllowed = require('./checkDomainAllowed');
const concurrentLimiter = require('./concurrentLimiter');
const requireLocalAuth = require('./requireLocalAuth');
const canDeleteAccount = require('./canDeleteAccount');
const accessResources = require('./accessResources');
@@ -42,7 +41,6 @@ module.exports = {
requireLocalAuth,
canDeleteAccount,
configMiddleware,
concurrentLimiter,
checkDomainAllowed,
validateMessageReq,
buildEndpointOption,


@@ -7,13 +7,12 @@ const {
requireJwtAuth,
messageIpLimiter,
configMiddleware,
concurrentLimiter,
messageUserLimiter,
} = require('~/server/middleware');
const { v1 } = require('./v1');
const chat = require('./chat');
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
const { LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
const router = express.Router();
@@ -208,10 +207,6 @@ router.post('/chat/abort', async (req, res) => {
const chatRouter = express.Router();
chatRouter.use(configMiddleware);
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
chatRouter.use(concurrentLimiter);
}
if (isEnabled(LIMIT_MESSAGE_IP)) {
chatRouter.use(messageIpLimiter);
}


@@ -171,7 +171,7 @@ describe('initializeMCPs', () => {
expect(mockMCPManagerInstance.getAppToolFunctions).not.toHaveBeenCalled();
expect(mockMergeAppTools).not.toHaveBeenCalled();
expect(logger.debug).toHaveBeenCalledWith(
'[MCP] No configured servers configured. MCPManager ready for UI-based servers.',
'[MCP] No servers configured. MCPManager ready for UI-based servers.',
);
});
@@ -185,7 +185,7 @@ describe('initializeMCPs', () => {
expect(mockMCPManagerInstance.getAppToolFunctions).not.toHaveBeenCalled();
expect(mockMergeAppTools).not.toHaveBeenCalled();
expect(logger.debug).toHaveBeenCalledWith(
'[MCP] No configured servers configured. MCPManager ready for UI-based servers.',
'[MCP] No servers configured. MCPManager ready for UI-based servers.',
);
});