mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-02-27 04:44:10 +01:00
🚦 refactor: Concurrent Request Limiter for Resumable Streams (#11167)
* feat: Implement concurrent request handling in ResumableAgentController - Introduced a new concurrency management system by adding `checkAndIncrementPendingRequest` and `decrementPendingRequest` functions to manage user request limits. - Replaced the previous `concurrentLimiter` middleware with a more integrated approach directly within the `ResumableAgentController`. - Enhanced violation logging and request denial for users exceeding their concurrent request limits. - Removed the obsolete `concurrentLimiter` middleware file and updated related imports across the codebase. * refactor: Simplify error handling in ResumableAgentController and enhance SSE error management - Removed the `denyRequest` middleware and replaced it with a direct response for concurrent request violations in the ResumableAgentController. - Improved error handling in the `useResumableSSE` hook to differentiate between network errors and other error types, ensuring more informative error responses are sent to the error handler. * test: Enhance MCP server configuration tests with new mocks and improved logging - Added mocks for MCP server registry and manager in `index.spec.js` to facilitate testing of server configurations. - Updated debug logging in `initializeMCPs.spec.js` to simplify messages regarding server configurations, improving clarity in test outputs. * refactor: Enhance concurrency management in request handling - Updated `checkAndIncrementPendingRequest` and `decrementPendingRequest` functions to utilize Redis for atomic request counting, improving concurrency control. - Added error handling for Redis operations to ensure requests can proceed even during Redis failures. - Streamlined cache key generation for both Redis and in-memory fallback, enhancing clarity and performance in managing pending requests. - Improved comments and documentation for better understanding of the concurrency logic and its implications. * refactor: Improve atomicity in Redis operations for pending request management - Updated `checkAndIncrementPendingRequest` to utilize Redis pipelines for atomic INCR and EXPIRE operations, enhancing concurrency control and preventing edge cases. - Added error handling for pipeline execution failures to ensure robust request management. - Improved comments for clarity on the concurrency logic and its implications.
This commit is contained in:
parent
a2361aa891
commit
a7aa4dc91b
9 changed files with 272 additions and 91 deletions
|
|
@ -1,13 +1,17 @@
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { Constants } = require('librechat-data-provider');
|
const { Constants, ViolationTypes } = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
sendEvent,
|
sendEvent,
|
||||||
|
getViolationInfo,
|
||||||
GenerationJobManager,
|
GenerationJobManager,
|
||||||
|
decrementPendingRequest,
|
||||||
sanitizeFileForTransmit,
|
sanitizeFileForTransmit,
|
||||||
sanitizeMessageForTransmit,
|
sanitizeMessageForTransmit,
|
||||||
|
checkAndIncrementPendingRequest,
|
||||||
} = require('@librechat/api');
|
} = require('@librechat/api');
|
||||||
const { handleAbortError } = require('~/server/middleware');
|
|
||||||
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
|
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
|
||||||
|
const { handleAbortError } = require('~/server/middleware');
|
||||||
|
const { logViolation } = require('~/cache');
|
||||||
const { saveMessage } = require('~/models');
|
const { saveMessage } = require('~/models');
|
||||||
|
|
||||||
function createCloseHandler(abortController) {
|
function createCloseHandler(abortController) {
|
||||||
|
|
@ -47,6 +51,13 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
|
||||||
|
|
||||||
const userId = req.user.id;
|
const userId = req.user.id;
|
||||||
|
|
||||||
|
const { allowed, pendingRequests, limit } = await checkAndIncrementPendingRequest(userId);
|
||||||
|
if (!allowed) {
|
||||||
|
const violationInfo = getViolationInfo(pendingRequests, limit);
|
||||||
|
await logViolation(req, res, ViolationTypes.CONCURRENT, violationInfo, violationInfo.score);
|
||||||
|
return res.status(429).json(violationInfo);
|
||||||
|
}
|
||||||
|
|
||||||
// Generate conversationId upfront if not provided - streamId === conversationId always
|
// Generate conversationId upfront if not provided - streamId === conversationId always
|
||||||
// Treat "new" as a placeholder that needs a real UUID (frontend may send "new" for new convos)
|
// Treat "new" as a placeholder that needs a real UUID (frontend may send "new" for new convos)
|
||||||
const conversationId =
|
const conversationId =
|
||||||
|
|
@ -137,6 +148,7 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
|
||||||
|
|
||||||
if (job.abortController.signal.aborted) {
|
if (job.abortController.signal.aborted) {
|
||||||
GenerationJobManager.completeJob(streamId, 'Request aborted during initialization');
|
GenerationJobManager.completeJob(streamId, 'Request aborted during initialization');
|
||||||
|
await decrementPendingRequest(userId);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -263,6 +275,7 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
|
||||||
|
|
||||||
GenerationJobManager.emitDone(streamId, finalEvent);
|
GenerationJobManager.emitDone(streamId, finalEvent);
|
||||||
GenerationJobManager.completeJob(streamId);
|
GenerationJobManager.completeJob(streamId);
|
||||||
|
await decrementPendingRequest(userId);
|
||||||
|
|
||||||
if (client.savedMessageIds && !client.savedMessageIds.has(messageId)) {
|
if (client.savedMessageIds && !client.savedMessageIds.has(messageId)) {
|
||||||
await saveMessage(
|
await saveMessage(
|
||||||
|
|
@ -282,6 +295,7 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
|
||||||
};
|
};
|
||||||
GenerationJobManager.emitDone(streamId, finalEvent);
|
GenerationJobManager.emitDone(streamId, finalEvent);
|
||||||
GenerationJobManager.completeJob(streamId, 'Request aborted');
|
GenerationJobManager.completeJob(streamId, 'Request aborted');
|
||||||
|
await decrementPendingRequest(userId);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!client.skipSaveUserMessage && userMessage) {
|
if (!client.skipSaveUserMessage && userMessage) {
|
||||||
|
|
@ -322,6 +336,8 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
|
||||||
GenerationJobManager.completeJob(streamId, error.message);
|
GenerationJobManager.completeJob(streamId, error.message);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
await decrementPendingRequest(userId);
|
||||||
|
|
||||||
if (client) {
|
if (client) {
|
||||||
disposeClient(client);
|
disposeClient(client);
|
||||||
}
|
}
|
||||||
|
|
@ -332,11 +348,12 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
|
||||||
};
|
};
|
||||||
|
|
||||||
// Start generation and handle any unhandled errors
|
// Start generation and handle any unhandled errors
|
||||||
startGeneration().catch((err) => {
|
startGeneration().catch(async (err) => {
|
||||||
logger.error(
|
logger.error(
|
||||||
`[ResumableAgentController] Unhandled error in background generation: ${err.message}`,
|
`[ResumableAgentController] Unhandled error in background generation: ${err.message}`,
|
||||||
);
|
);
|
||||||
GenerationJobManager.completeJob(streamId, err.message);
|
GenerationJobManager.completeJob(streamId, err.message);
|
||||||
|
await decrementPendingRequest(userId);
|
||||||
});
|
});
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('[ResumableAgentController] Initialization error:', error);
|
logger.error('[ResumableAgentController] Initialization error:', error);
|
||||||
|
|
@ -347,6 +364,7 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
|
||||||
GenerationJobManager.emitError(streamId, error.message || 'Failed to start generation');
|
GenerationJobManager.emitError(streamId, error.message || 'Failed to start generation');
|
||||||
}
|
}
|
||||||
GenerationJobManager.completeJob(streamId, error.message);
|
GenerationJobManager.completeJob(streamId, error.message);
|
||||||
|
await decrementPendingRequest(userId);
|
||||||
if (client) {
|
if (client) {
|
||||||
disposeClient(client);
|
disposeClient(client);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,13 @@ jest.mock('~/app/clients/tools', () => ({
|
||||||
toolkits: [],
|
toolkits: [],
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/config', () => ({
|
||||||
|
createMCPServersRegistry: jest.fn(),
|
||||||
|
createMCPManager: jest.fn().mockResolvedValue({
|
||||||
|
getAppToolFunctions: jest.fn().mockResolvedValue({}),
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
describe('Server Configuration', () => {
|
describe('Server Configuration', () => {
|
||||||
// Increase the default timeout to allow for Mongo cleanup
|
// Increase the default timeout to allow for Mongo cleanup
|
||||||
jest.setTimeout(30_000);
|
jest.setTimeout(30_000);
|
||||||
|
|
|
||||||
|
|
@ -1,76 +0,0 @@
|
||||||
const { isEnabled } = require('@librechat/api');
|
|
||||||
const { Time, CacheKeys, ViolationTypes } = require('librechat-data-provider');
|
|
||||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
|
||||||
const { logViolation, getLogStores } = require('~/cache');
|
|
||||||
const denyRequest = require('./denyRequest');
|
|
||||||
|
|
||||||
const {
|
|
||||||
USE_REDIS,
|
|
||||||
CONCURRENT_MESSAGE_MAX = 1,
|
|
||||||
CONCURRENT_VIOLATION_SCORE: score,
|
|
||||||
} = process.env ?? {};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Middleware to limit concurrent requests for a user.
|
|
||||||
*
|
|
||||||
* This middleware checks if a user has exceeded a specified concurrent request limit.
|
|
||||||
* If the user exceeds the limit, an error is returned. If the user is within the limit,
|
|
||||||
* their request count is incremented. After the request is processed, the count is decremented.
|
|
||||||
* If the `cache` store is not available, the middleware will skip its logic.
|
|
||||||
*
|
|
||||||
* @function
|
|
||||||
* @param {Object} req - Express request object containing user information.
|
|
||||||
* @param {Object} res - Express response object.
|
|
||||||
* @param {import('express').NextFunction} next - Next middleware function.
|
|
||||||
* @throws {Error} Throws an error if the user exceeds the concurrent request limit.
|
|
||||||
*/
|
|
||||||
const concurrentLimiter = async (req, res, next) => {
|
|
||||||
const namespace = CacheKeys.PENDING_REQ;
|
|
||||||
const cache = getLogStores(namespace);
|
|
||||||
if (!cache) {
|
|
||||||
return next();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (Object.keys(req?.body ?? {}).length === 1 && req?.body?.abortKey) {
|
|
||||||
return next();
|
|
||||||
}
|
|
||||||
|
|
||||||
const userId = req.user?.id ?? req.user?._id ?? '';
|
|
||||||
const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1);
|
|
||||||
const type = ViolationTypes.CONCURRENT;
|
|
||||||
|
|
||||||
const key = `${isEnabled(USE_REDIS) ? namespace : ''}:${userId}`;
|
|
||||||
const pendingRequests = +((await cache.get(key)) ?? 0);
|
|
||||||
|
|
||||||
if (pendingRequests >= limit) {
|
|
||||||
const errorMessage = {
|
|
||||||
type,
|
|
||||||
limit,
|
|
||||||
pendingRequests,
|
|
||||||
};
|
|
||||||
|
|
||||||
await logViolation(req, res, type, errorMessage, score);
|
|
||||||
return await denyRequest(req, res, errorMessage);
|
|
||||||
} else {
|
|
||||||
await cache.set(key, pendingRequests + 1, Time.ONE_MINUTE);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure the requests are removed from the store once the request is done
|
|
||||||
let cleared = false;
|
|
||||||
const cleanUp = async () => {
|
|
||||||
if (cleared) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
cleared = true;
|
|
||||||
await clearPendingReq({ userId, cache });
|
|
||||||
};
|
|
||||||
|
|
||||||
if (pendingRequests < limit) {
|
|
||||||
res.on('finish', cleanUp);
|
|
||||||
res.on('close', cleanUp);
|
|
||||||
}
|
|
||||||
|
|
||||||
next();
|
|
||||||
};
|
|
||||||
|
|
||||||
module.exports = concurrentLimiter;
|
|
||||||
|
|
@ -3,7 +3,6 @@ const validateRegistration = require('./validateRegistration');
|
||||||
const buildEndpointOption = require('./buildEndpointOption');
|
const buildEndpointOption = require('./buildEndpointOption');
|
||||||
const validateMessageReq = require('./validateMessageReq');
|
const validateMessageReq = require('./validateMessageReq');
|
||||||
const checkDomainAllowed = require('./checkDomainAllowed');
|
const checkDomainAllowed = require('./checkDomainAllowed');
|
||||||
const concurrentLimiter = require('./concurrentLimiter');
|
|
||||||
const requireLocalAuth = require('./requireLocalAuth');
|
const requireLocalAuth = require('./requireLocalAuth');
|
||||||
const canDeleteAccount = require('./canDeleteAccount');
|
const canDeleteAccount = require('./canDeleteAccount');
|
||||||
const accessResources = require('./accessResources');
|
const accessResources = require('./accessResources');
|
||||||
|
|
@ -42,7 +41,6 @@ module.exports = {
|
||||||
requireLocalAuth,
|
requireLocalAuth,
|
||||||
canDeleteAccount,
|
canDeleteAccount,
|
||||||
configMiddleware,
|
configMiddleware,
|
||||||
concurrentLimiter,
|
|
||||||
checkDomainAllowed,
|
checkDomainAllowed,
|
||||||
validateMessageReq,
|
validateMessageReq,
|
||||||
buildEndpointOption,
|
buildEndpointOption,
|
||||||
|
|
|
||||||
|
|
@ -7,13 +7,12 @@ const {
|
||||||
requireJwtAuth,
|
requireJwtAuth,
|
||||||
messageIpLimiter,
|
messageIpLimiter,
|
||||||
configMiddleware,
|
configMiddleware,
|
||||||
concurrentLimiter,
|
|
||||||
messageUserLimiter,
|
messageUserLimiter,
|
||||||
} = require('~/server/middleware');
|
} = require('~/server/middleware');
|
||||||
const { v1 } = require('./v1');
|
const { v1 } = require('./v1');
|
||||||
const chat = require('./chat');
|
const chat = require('./chat');
|
||||||
|
|
||||||
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
|
const { LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
|
|
||||||
|
|
@ -208,10 +207,6 @@ router.post('/chat/abort', async (req, res) => {
|
||||||
const chatRouter = express.Router();
|
const chatRouter = express.Router();
|
||||||
chatRouter.use(configMiddleware);
|
chatRouter.use(configMiddleware);
|
||||||
|
|
||||||
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
|
|
||||||
chatRouter.use(concurrentLimiter);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isEnabled(LIMIT_MESSAGE_IP)) {
|
if (isEnabled(LIMIT_MESSAGE_IP)) {
|
||||||
chatRouter.use(messageIpLimiter);
|
chatRouter.use(messageIpLimiter);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -171,7 +171,7 @@ describe('initializeMCPs', () => {
|
||||||
expect(mockMCPManagerInstance.getAppToolFunctions).not.toHaveBeenCalled();
|
expect(mockMCPManagerInstance.getAppToolFunctions).not.toHaveBeenCalled();
|
||||||
expect(mockMergeAppTools).not.toHaveBeenCalled();
|
expect(mockMergeAppTools).not.toHaveBeenCalled();
|
||||||
expect(logger.debug).toHaveBeenCalledWith(
|
expect(logger.debug).toHaveBeenCalledWith(
|
||||||
'[MCP] No configured servers configured. MCPManager ready for UI-based servers.',
|
'[MCP] No servers configured. MCPManager ready for UI-based servers.',
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -185,7 +185,7 @@ describe('initializeMCPs', () => {
|
||||||
expect(mockMCPManagerInstance.getAppToolFunctions).not.toHaveBeenCalled();
|
expect(mockMCPManagerInstance.getAppToolFunctions).not.toHaveBeenCalled();
|
||||||
expect(mockMergeAppTools).not.toHaveBeenCalled();
|
expect(mockMergeAppTools).not.toHaveBeenCalled();
|
||||||
expect(logger.debug).toHaveBeenCalledWith(
|
expect(logger.debug).toHaveBeenCalledWith(
|
||||||
'[MCP] No configured servers configured. MCPManager ready for UI-based servers.',
|
'[MCP] No servers configured. MCPManager ready for UI-based servers.',
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -532,9 +532,20 @@ export default function useResumableSSE(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// All retries failed or non-network error
|
|
||||||
console.error('[ResumableSSE] Error starting generation:', lastError);
|
console.error('[ResumableSSE] Error starting generation:', lastError);
|
||||||
errorHandler({ data: undefined, submission: currentSubmission as EventSubmission });
|
|
||||||
|
const axiosError = lastError as { response?: { data?: Record<string, unknown> } };
|
||||||
|
const errorData = axiosError?.response?.data;
|
||||||
|
if (errorData) {
|
||||||
|
errorHandler({
|
||||||
|
data: { text: JSON.stringify(errorData) } as unknown as Parameters<
|
||||||
|
typeof errorHandler
|
||||||
|
>[0]['data'],
|
||||||
|
submission: currentSubmission as EventSubmission,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
errorHandler({ data: undefined, submission: currentSubmission as EventSubmission });
|
||||||
|
}
|
||||||
setIsSubmitting(false);
|
setIsSubmitting(false);
|
||||||
return null;
|
return null;
|
||||||
},
|
},
|
||||||
|
|
|
||||||
227
packages/api/src/middleware/concurrency.ts
Normal file
227
packages/api/src/middleware/concurrency.ts
Normal file
|
|
@ -0,0 +1,227 @@
|
||||||
|
import { logger } from '@librechat/data-schemas';
|
||||||
|
import { CacheKeys, Time, ViolationTypes } from 'librechat-data-provider';
|
||||||
|
import { standardCache, cacheConfig, ioredisClient } from '~/cache';
|
||||||
|
import { isEnabled, math } from '~/utils';
|
||||||
|
|
||||||
|
const { USE_REDIS } = cacheConfig;
|
||||||
|
|
||||||
|
const LIMIT_CONCURRENT_MESSAGES = process.env.LIMIT_CONCURRENT_MESSAGES;
|
||||||
|
const CONCURRENT_MESSAGE_MAX = math(process.env.CONCURRENT_MESSAGE_MAX, 2);
|
||||||
|
const CONCURRENT_VIOLATION_SCORE = math(process.env.CONCURRENT_VIOLATION_SCORE, 1);
|
||||||
|
|
||||||
|
/** Lazily initialized cache for pending requests (used only for in-memory fallback) */
|
||||||
|
let pendingReqCache: ReturnType<typeof standardCache> | null = null;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get or create the pending requests cache for in-memory fallback.
|
||||||
|
* Uses lazy initialization to avoid creating cache before app is ready.
|
||||||
|
*/
|
||||||
|
function getPendingReqCache(): ReturnType<typeof standardCache> | null {
|
||||||
|
if (!isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (!pendingReqCache) {
|
||||||
|
pendingReqCache = standardCache(CacheKeys.PENDING_REQ);
|
||||||
|
}
|
||||||
|
return pendingReqCache;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build the cache key for a user's pending requests.
|
||||||
|
* Note: ioredisClient already has keyPrefix applied, so we only add namespace:userId
|
||||||
|
*/
|
||||||
|
function buildKey(userId: string): string {
|
||||||
|
const namespace = CacheKeys.PENDING_REQ;
|
||||||
|
return `${namespace}:${userId}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build the cache key for in-memory fallback (Keyv).
|
||||||
|
*/
|
||||||
|
function buildMemoryKey(userId: string): string {
|
||||||
|
return `:${userId}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PendingRequestResult {
|
||||||
|
allowed: boolean;
|
||||||
|
pendingRequests: number;
|
||||||
|
limit: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ViolationInfo {
|
||||||
|
type: string;
|
||||||
|
limit: number;
|
||||||
|
pendingRequests: number;
|
||||||
|
score: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if a user can make a new concurrent request and increment the counter if allowed.
|
||||||
|
* This is designed for resumable streams where the HTTP response lifecycle doesn't match
|
||||||
|
* the actual request processing lifecycle.
|
||||||
|
*
|
||||||
|
* When Redis is available, uses atomic INCR to prevent race conditions.
|
||||||
|
* Falls back to non-atomic get/set for in-memory cache.
|
||||||
|
*
|
||||||
|
* @param userId - The user's ID
|
||||||
|
* @returns Object with `allowed` (boolean), `pendingRequests` (current count), and `limit`
|
||||||
|
*/
|
||||||
|
export async function checkAndIncrementPendingRequest(
|
||||||
|
userId: string,
|
||||||
|
): Promise<PendingRequestResult> {
|
||||||
|
const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1);
|
||||||
|
|
||||||
|
if (!isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
|
||||||
|
return { allowed: true, pendingRequests: 0, limit };
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!userId) {
|
||||||
|
logger.warn('[concurrency] checkAndIncrementPendingRequest called without userId');
|
||||||
|
return { allowed: true, pendingRequests: 0, limit };
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use atomic Redis INCR when available to prevent race conditions
|
||||||
|
if (USE_REDIS && ioredisClient) {
|
||||||
|
const key = buildKey(userId);
|
||||||
|
try {
|
||||||
|
// Pipeline ensures INCR and EXPIRE execute atomically in one round-trip
|
||||||
|
// This prevents edge cases where crash between operations leaves key without TTL
|
||||||
|
const pipeline = ioredisClient.pipeline();
|
||||||
|
pipeline.incr(key);
|
||||||
|
pipeline.expire(key, 60);
|
||||||
|
const results = await pipeline.exec();
|
||||||
|
|
||||||
|
if (!results || results[0][0]) {
|
||||||
|
throw new Error('Pipeline execution failed');
|
||||||
|
}
|
||||||
|
|
||||||
|
const newCount = results[0][1] as number;
|
||||||
|
|
||||||
|
if (newCount > limit) {
|
||||||
|
// Over limit - decrement back and reject
|
||||||
|
await ioredisClient.decr(key);
|
||||||
|
logger.debug(
|
||||||
|
`[concurrency] User ${userId} exceeded concurrent limit: ${newCount}/${limit}`,
|
||||||
|
);
|
||||||
|
return { allowed: false, pendingRequests: newCount, limit };
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
`[concurrency] User ${userId} incremented pending requests: ${newCount}/${limit}`,
|
||||||
|
);
|
||||||
|
return { allowed: true, pendingRequests: newCount, limit };
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('[concurrency] Redis atomic increment failed:', error);
|
||||||
|
// On Redis error, allow the request to proceed (fail-open)
|
||||||
|
return { allowed: true, pendingRequests: 0, limit };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: non-atomic in-memory cache (race condition possible but acceptable for in-memory)
|
||||||
|
const cache = getPendingReqCache();
|
||||||
|
if (!cache) {
|
||||||
|
return { allowed: true, pendingRequests: 0, limit };
|
||||||
|
}
|
||||||
|
|
||||||
|
const key = buildMemoryKey(userId);
|
||||||
|
const pendingRequests = +((await cache.get(key)) ?? 0);
|
||||||
|
|
||||||
|
if (pendingRequests >= limit) {
|
||||||
|
logger.debug(
|
||||||
|
`[concurrency] User ${userId} exceeded concurrent limit: ${pendingRequests}/${limit}`,
|
||||||
|
);
|
||||||
|
return { allowed: false, pendingRequests, limit };
|
||||||
|
}
|
||||||
|
|
||||||
|
await cache.set(key, pendingRequests + 1, Time.ONE_MINUTE);
|
||||||
|
logger.debug(
|
||||||
|
`[concurrency] User ${userId} incremented pending requests: ${pendingRequests + 1}/${limit}`,
|
||||||
|
);
|
||||||
|
|
||||||
|
return { allowed: true, pendingRequests: pendingRequests + 1, limit };
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Decrement the pending request counter for a user.
|
||||||
|
* Should be called when a generation job completes, errors, or is aborted.
|
||||||
|
*
|
||||||
|
* This function handles errors internally and will never throw - it's a cleanup
|
||||||
|
* operation that should not interrupt the main flow if cache operations fail.
|
||||||
|
*
|
||||||
|
* When Redis is available, uses atomic DECR to prevent race conditions.
|
||||||
|
* Falls back to non-atomic get/set for in-memory cache.
|
||||||
|
*
|
||||||
|
* @param userId - The user's ID
|
||||||
|
*/
|
||||||
|
export async function decrementPendingRequest(userId: string): Promise<void> {
|
||||||
|
try {
|
||||||
|
if (!isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!userId) {
|
||||||
|
logger.warn('[concurrency] decrementPendingRequest called without userId');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use atomic Redis DECR when available
|
||||||
|
if (USE_REDIS && ioredisClient) {
|
||||||
|
const key = buildKey(userId);
|
||||||
|
try {
|
||||||
|
const newCount = await ioredisClient.decr(key);
|
||||||
|
if (newCount < 0) {
|
||||||
|
// Counter went negative - reset to 0 and delete
|
||||||
|
await ioredisClient.del(key);
|
||||||
|
logger.debug(`[concurrency] User ${userId} pending requests cleared (was negative)`);
|
||||||
|
} else if (newCount === 0) {
|
||||||
|
// Clean up zero-value keys
|
||||||
|
await ioredisClient.del(key);
|
||||||
|
logger.debug(`[concurrency] User ${userId} pending requests cleared`);
|
||||||
|
} else {
|
||||||
|
logger.debug(`[concurrency] User ${userId} decremented pending requests: ${newCount}`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('[concurrency] Redis atomic decrement failed:', error);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: non-atomic in-memory cache
|
||||||
|
const cache = getPendingReqCache();
|
||||||
|
if (!cache) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const key = buildMemoryKey(userId);
|
||||||
|
const currentReq = +((await cache.get(key)) ?? 0);
|
||||||
|
|
||||||
|
if (currentReq >= 1) {
|
||||||
|
await cache.set(key, currentReq - 1, Time.ONE_MINUTE);
|
||||||
|
logger.debug(`[concurrency] User ${userId} decremented pending requests: ${currentReq - 1}`);
|
||||||
|
} else {
|
||||||
|
await cache.delete(key);
|
||||||
|
logger.debug(`[concurrency] User ${userId} pending requests cleared (was ${currentReq})`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('[concurrency] Error decrementing pending request:', error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get violation info for logging purposes when a user exceeds the concurrent request limit.
|
||||||
|
*/
|
||||||
|
export function getViolationInfo(pendingRequests: number, limit: number): ViolationInfo {
|
||||||
|
return {
|
||||||
|
type: ViolationTypes.CONCURRENT,
|
||||||
|
limit,
|
||||||
|
pendingRequests,
|
||||||
|
score: CONCURRENT_VIOLATION_SCORE,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if concurrent message limiting is enabled.
|
||||||
|
*/
|
||||||
|
export function isConcurrentLimitEnabled(): boolean {
|
||||||
|
return isEnabled(LIMIT_CONCURRENT_MESSAGES);
|
||||||
|
}
|
||||||
|
|
@ -2,3 +2,4 @@ export * from './access';
|
||||||
export * from './error';
|
export * from './error';
|
||||||
export * from './balance';
|
export * from './balance';
|
||||||
export * from './json';
|
export * from './json';
|
||||||
|
export * from './concurrency';
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue