🤖 refactor: Improve Agents Memory Usage, Bump Keyv, Grok 3 (#6850)

* chore: remove unused redis file

* chore: bump keyv dependencies, and update related imports

* refactor: Implement IoRedis client for rate limiting across middleware, as node-redis via keyv is not compatible

* fix: Set max listeners to expected amount

* WIP: memory improvements

* refactor: Simplify getAbortData assignment in createAbortController

* refactor: Update getAbortData to use WeakRef for content management

* WIP: memory improvements in agent chat requests

* refactor: Enhance memory management with finalization registry and cleanup functions

* refactor: Simplify domainParser calls by removing unnecessary request parameter

* refactor: Update parameter types for action tools and agent loading functions to use minimal configs

* refactor: Simplify domainParser tests by removing unnecessary request parameter

* refactor: Simplify domainParser call by removing unnecessary request parameter

* refactor: Enhance client disposal by nullifying additional properties to improve memory management

* refactor: Improve title generation by adding abort controller and timeout handling, consolidate request cleanup

* refactor: Update checkIdleConnections to skip current user when checking for idle connections if passed

* refactor: Update createMCPTool to derive userId from config and handle abort signals

* refactor: Introduce createTokenCounter function and update tokenCounter usage; enhance disposeClient to reset Graph values

* refactor: Update getMCPManager to accept userId parameter for improved idle connection handling

* refactor: Extract logToolError function for improved error handling in AgentClient

* refactor: Update disposeClient to clear handlerRegistry and graphRunnable references in client.run

* refactor: Extract createHandleNewToken function to streamline token handling in initializeClient

* chore: bump @librechat/agents

* refactor: Improve timeout handling in addTitle function for better error management

* refactor: Introduce createFetch instead of using class method

* refactor: Enhance client disposal and request data handling in AskController and EditController

* refactor: Update import statements for AnthropicClient and OpenAIClient to use specific paths

* refactor: Use WeakRef for response handling in SplitStreamHandler to prevent memory leaks

* refactor: Simplify client disposal and rename getReqData to processReqData in AskController and EditController

* refactor: Improve logging structure and parameter handling in OpenAIClient

* refactor: Remove unused GraphEvents and improve stream event handling in AnthropicClient and OpenAIClient

* refactor: Simplify client initialization in AskController and EditController

* refactor: Remove unused mock functions and implement in-memory store for KeyvMongo

* chore: Update dependencies in package-lock.json to latest versions

* refactor: Await token usage recording in OpenAIClient to ensure proper async handling

* refactor: Remove handleAbort route from multiple endpoints and enhance client disposal logic

* refactor: Enhance abort controller logic by managing abortKey more effectively

* refactor: Add newConversation handling in useEventHandlers for improved conversation management

* fix: dropparams

* refactor: Use optional chaining for safer access to request properties in BaseClient

* refactor: Move client disposal and request data processing logic to cleanup module for better organization

* refactor: Remove aborted request check from addTitle function for cleaner logic

* feat: Add Grok 3 model pricing and update tests for new models

* chore: Remove trace warnings and inspect flags from backend start script used for debugging

* refactor: Replace user identifier handling with userId for consistency across controllers, use UserId in clientRegistry

* refactor: Enhance client disposal logic to prevent memory leaks by clearing additional references

* chore: Update @librechat/agents to version 2.4.14 in package.json and package-lock.json
This commit is contained in:
Danny Avila 2025-04-12 18:46:36 -04:00 committed by GitHub
parent 1e6b1b9554
commit 37964975c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
68 changed files with 1796 additions and 623 deletions

View file

@ -9,7 +9,7 @@ const {
getResponseSender, getResponseSender,
validateVisionModel, validateVisionModel,
} = require('librechat-data-provider'); } = require('librechat-data-provider');
const { SplitStreamHandler: _Handler, GraphEvents } = require('@librechat/agents'); const { SplitStreamHandler: _Handler } = require('@librechat/agents');
const { const {
truncateText, truncateText,
formatMessage, formatMessage,
@ -26,10 +26,11 @@ const {
const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils'); const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { createFetch, createStreamEventHandlers } = require('./generators');
const Tokenizer = require('~/server/services/Tokenizer'); const Tokenizer = require('~/server/services/Tokenizer');
const { logger, sendEvent } = require('~/config');
const { sleep } = require('~/server/utils'); const { sleep } = require('~/server/utils');
const BaseClient = require('./BaseClient'); const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
const HUMAN_PROMPT = '\n\nHuman:'; const HUMAN_PROMPT = '\n\nHuman:';
const AI_PROMPT = '\n\nAssistant:'; const AI_PROMPT = '\n\nAssistant:';
@ -184,7 +185,10 @@ class AnthropicClient extends BaseClient {
getClient(requestOptions) { getClient(requestOptions) {
/** @type {Anthropic.ClientOptions} */ /** @type {Anthropic.ClientOptions} */
const options = { const options = {
fetch: this.fetch, fetch: createFetch({
directEndpoint: this.options.directEndpoint,
reverseProxyUrl: this.options.reverseProxyUrl,
}),
apiKey: this.apiKey, apiKey: this.apiKey,
}; };
@ -795,14 +799,11 @@ class AnthropicClient extends BaseClient {
} }
logger.debug('[AnthropicClient]', { ...requestOptions }); logger.debug('[AnthropicClient]', { ...requestOptions });
const handlers = createStreamEventHandlers(this.options.res);
this.streamHandler = new SplitStreamHandler({ this.streamHandler = new SplitStreamHandler({
accumulate: true, accumulate: true,
runId: this.responseMessageId, runId: this.responseMessageId,
handlers: { handlers,
[GraphEvents.ON_RUN_STEP]: (event) => sendEvent(this.options.res, event),
[GraphEvents.ON_MESSAGE_DELTA]: (event) => sendEvent(this.options.res, event),
[GraphEvents.ON_REASONING_DELTA]: (event) => sendEvent(this.options.res, event),
},
}); });
let intermediateReply = this.streamHandler.tokens; let intermediateReply = this.streamHandler.tokens;

View file

@ -28,15 +28,10 @@ class BaseClient {
month: 'long', month: 'long',
day: 'numeric', day: 'numeric',
}); });
this.fetch = this.fetch.bind(this);
/** @type {boolean} */ /** @type {boolean} */
this.skipSaveConvo = false; this.skipSaveConvo = false;
/** @type {boolean} */ /** @type {boolean} */
this.skipSaveUserMessage = false; this.skipSaveUserMessage = false;
/** @type {ClientDatabaseSavePromise} */
this.userMessagePromise;
/** @type {ClientDatabaseSavePromise} */
this.responsePromise;
/** @type {string} */ /** @type {string} */
this.user; this.user;
/** @type {string} */ /** @type {string} */
@ -564,6 +559,8 @@ class BaseClient {
} }
async sendMessage(message, opts = {}) { async sendMessage(message, opts = {}) {
/** @type {Promise<TMessage>} */
let userMessagePromise;
const { user, head, isEdited, conversationId, responseMessageId, saveOptions, userMessage } = const { user, head, isEdited, conversationId, responseMessageId, saveOptions, userMessage } =
await this.handleStartMethods(message, opts); await this.handleStartMethods(message, opts);
@ -625,11 +622,11 @@ class BaseClient {
} }
if (!isEdited && !this.skipSaveUserMessage) { if (!isEdited && !this.skipSaveUserMessage) {
this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user); userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
this.savedMessageIds.add(userMessage.messageId); this.savedMessageIds.add(userMessage.messageId);
if (typeof opts?.getReqData === 'function') { if (typeof opts?.getReqData === 'function') {
opts.getReqData({ opts.getReqData({
userMessagePromise: this.userMessagePromise, userMessagePromise,
}); });
} }
} }
@ -655,7 +652,9 @@ class BaseClient {
/** @type {string|string[]|undefined} */ /** @type {string|string[]|undefined} */
const completion = await this.sendCompletion(payload, opts); const completion = await this.sendCompletion(payload, opts);
if (this.abortController) {
this.abortController.requestCompleted = true; this.abortController.requestCompleted = true;
}
/** @type {TMessage} */ /** @type {TMessage} */
const responseMessage = { const responseMessage = {
@ -703,7 +702,13 @@ class BaseClient {
if (usage != null && Number(usage[this.outputTokensKey]) > 0) { if (usage != null && Number(usage[this.outputTokensKey]) > 0) {
responseMessage.tokenCount = usage[this.outputTokensKey]; responseMessage.tokenCount = usage[this.outputTokensKey];
completionTokens = responseMessage.tokenCount; completionTokens = responseMessage.tokenCount;
await this.updateUserMessageTokenCount({ usage, tokenCountMap, userMessage, opts }); await this.updateUserMessageTokenCount({
usage,
tokenCountMap,
userMessage,
userMessagePromise,
opts,
});
} else { } else {
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage); responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
completionTokens = responseMessage.tokenCount; completionTokens = responseMessage.tokenCount;
@ -712,8 +717,8 @@ class BaseClient {
await this.recordTokenUsage({ promptTokens, completionTokens, usage }); await this.recordTokenUsage({ promptTokens, completionTokens, usage });
} }
if (this.userMessagePromise) { if (userMessagePromise) {
await this.userMessagePromise; await userMessagePromise;
} }
if (this.artifactPromises) { if (this.artifactPromises) {
@ -728,7 +733,11 @@ class BaseClient {
} }
} }
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user); responseMessage.databasePromise = this.saveMessageToDatabase(
responseMessage,
saveOptions,
user,
);
this.savedMessageIds.add(responseMessage.messageId); this.savedMessageIds.add(responseMessage.messageId);
delete responseMessage.tokenCount; delete responseMessage.tokenCount;
return responseMessage; return responseMessage;
@ -749,9 +758,16 @@ class BaseClient {
* @param {StreamUsage} params.usage * @param {StreamUsage} params.usage
* @param {Record<string, number>} params.tokenCountMap * @param {Record<string, number>} params.tokenCountMap
* @param {TMessage} params.userMessage * @param {TMessage} params.userMessage
* @param {Promise<TMessage>} params.userMessagePromise
* @param {object} params.opts * @param {object} params.opts
*/ */
async updateUserMessageTokenCount({ usage, tokenCountMap, userMessage, opts }) { async updateUserMessageTokenCount({
usage,
tokenCountMap,
userMessage,
userMessagePromise,
opts,
}) {
/** @type {boolean} */ /** @type {boolean} */
const shouldUpdateCount = const shouldUpdateCount =
this.calculateCurrentTokenCount != null && this.calculateCurrentTokenCount != null &&
@ -787,7 +803,7 @@ class BaseClient {
Note: we update the user message to be sure it gets the calculated token count; Note: we update the user message to be sure it gets the calculated token count;
though `AskController` saves the user message, EditController does not though `AskController` saves the user message, EditController does not
*/ */
await this.userMessagePromise; await userMessagePromise;
await this.updateMessageInDatabase({ await this.updateMessageInDatabase({
messageId: userMessage.messageId, messageId: userMessage.messageId,
tokenCount: userMessageTokenCount, tokenCount: userMessageTokenCount,
@ -853,7 +869,7 @@ class BaseClient {
} }
const savedMessage = await saveMessage( const savedMessage = await saveMessage(
this.options.req, this.options?.req,
{ {
...message, ...message,
endpoint: this.options.endpoint, endpoint: this.options.endpoint,
@ -877,7 +893,7 @@ class BaseClient {
const existingConvo = const existingConvo =
this.fetchedConvo === true this.fetchedConvo === true
? null ? null
: await getConvo(this.options.req?.user?.id, message.conversationId); : await getConvo(this.options?.req?.user?.id, message.conversationId);
const unsetFields = {}; const unsetFields = {};
const exceptions = new Set(['spec', 'iconURL']); const exceptions = new Set(['spec', 'iconURL']);
@ -897,7 +913,7 @@ class BaseClient {
} }
} }
const conversation = await saveConvo(this.options.req, fieldsToKeep, { const conversation = await saveConvo(this.options?.req, fieldsToKeep, {
context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo', context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo',
unsetFields, unsetFields,
}); });

View file

@ -1,4 +1,4 @@
const Keyv = require('keyv'); const { Keyv } = require('keyv');
const crypto = require('crypto'); const crypto = require('crypto');
const { CohereClient } = require('cohere-ai'); const { CohereClient } = require('cohere-ai');
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source'); const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
@ -339,7 +339,7 @@ class ChatGPTClient extends BaseClient {
opts.body = JSON.stringify(modelOptions); opts.body = JSON.stringify(modelOptions);
if (modelOptions.stream) { if (modelOptions.stream) {
// eslint-disable-next-line no-async-promise-executor
return new Promise(async (resolve, reject) => { return new Promise(async (resolve, reject) => {
try { try {
let done = false; let done = false;

View file

@ -1,7 +1,7 @@
const OpenAI = require('openai'); const OpenAI = require('openai');
const { OllamaClient } = require('./OllamaClient'); const { OllamaClient } = require('./OllamaClient');
const { HttpsProxyAgent } = require('https-proxy-agent'); const { HttpsProxyAgent } = require('https-proxy-agent');
const { SplitStreamHandler, GraphEvents } = require('@librechat/agents'); const { SplitStreamHandler } = require('@librechat/agents');
const { const {
Constants, Constants,
ImageDetail, ImageDetail,
@ -32,17 +32,18 @@ const {
createContextHandlers, createContextHandlers,
} = require('./prompts'); } = require('./prompts');
const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { createFetch, createStreamEventHandlers } = require('./generators');
const { addSpaceIfNeeded, isEnabled, sleep } = require('~/server/utils'); const { addSpaceIfNeeded, isEnabled, sleep } = require('~/server/utils');
const Tokenizer = require('~/server/services/Tokenizer'); const Tokenizer = require('~/server/services/Tokenizer');
const { spendTokens } = require('~/models/spendTokens'); const { spendTokens } = require('~/models/spendTokens');
const { handleOpenAIErrors } = require('./tools/util'); const { handleOpenAIErrors } = require('./tools/util');
const { createLLM, RunManager } = require('./llm'); const { createLLM, RunManager } = require('./llm');
const { logger, sendEvent } = require('~/config');
const ChatGPTClient = require('./ChatGPTClient'); const ChatGPTClient = require('./ChatGPTClient');
const { summaryBuffer } = require('./memory'); const { summaryBuffer } = require('./memory');
const { runTitleChain } = require('./chains'); const { runTitleChain } = require('./chains');
const { tokenSplit } = require('./document'); const { tokenSplit } = require('./document');
const BaseClient = require('./BaseClient'); const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
class OpenAIClient extends BaseClient { class OpenAIClient extends BaseClient {
constructor(apiKey, options = {}) { constructor(apiKey, options = {}) {
@ -609,7 +610,7 @@ class OpenAIClient extends BaseClient {
return result.trim(); return result.trim();
} }
logger.debug('[OpenAIClient] sendCompletion: result', result); logger.debug('[OpenAIClient] sendCompletion: result', { ...result });
if (this.isChatCompletion) { if (this.isChatCompletion) {
reply = result.choices[0].message.content; reply = result.choices[0].message.content;
@ -818,7 +819,7 @@ ${convo}
const completionTokens = this.getTokenCount(title); const completionTokens = this.getTokenCount(title);
this.recordTokenUsage({ promptTokens, completionTokens, context: 'title' }); await this.recordTokenUsage({ promptTokens, completionTokens, context: 'title' });
} catch (e) { } catch (e) {
logger.error( logger.error(
'[OpenAIClient] There was an issue generating the title with the completion method', '[OpenAIClient] There was an issue generating the title with the completion method',
@ -1245,7 +1246,10 @@ ${convo}
let chatCompletion; let chatCompletion;
/** @type {OpenAI} */ /** @type {OpenAI} */
const openai = new OpenAI({ const openai = new OpenAI({
fetch: this.fetch, fetch: createFetch({
directEndpoint: this.options.directEndpoint,
reverseProxyUrl: this.options.reverseProxyUrl,
}),
apiKey: this.apiKey, apiKey: this.apiKey,
...opts, ...opts,
}); });
@ -1275,12 +1279,13 @@ ${convo}
} }
if (this.options.addParams && typeof this.options.addParams === 'object') { if (this.options.addParams && typeof this.options.addParams === 'object') {
const addParams = { ...this.options.addParams };
modelOptions = { modelOptions = {
...modelOptions, ...modelOptions,
...this.options.addParams, ...addParams,
}; };
logger.debug('[OpenAIClient] chatCompletion: added params', { logger.debug('[OpenAIClient] chatCompletion: added params', {
addParams: this.options.addParams, addParams: addParams,
modelOptions, modelOptions,
}); });
} }
@ -1309,11 +1314,12 @@ ${convo}
} }
if (this.options.dropParams && Array.isArray(this.options.dropParams)) { if (this.options.dropParams && Array.isArray(this.options.dropParams)) {
this.options.dropParams.forEach((param) => { const dropParams = [...this.options.dropParams];
dropParams.forEach((param) => {
delete modelOptions[param]; delete modelOptions[param];
}); });
logger.debug('[OpenAIClient] chatCompletion: dropped params', { logger.debug('[OpenAIClient] chatCompletion: dropped params', {
dropParams: this.options.dropParams, dropParams: dropParams,
modelOptions, modelOptions,
}); });
} }
@ -1355,15 +1361,12 @@ ${convo}
delete modelOptions.reasoning_effort; delete modelOptions.reasoning_effort;
} }
const handlers = createStreamEventHandlers(this.options.res);
this.streamHandler = new SplitStreamHandler({ this.streamHandler = new SplitStreamHandler({
reasoningKey, reasoningKey,
accumulate: true, accumulate: true,
runId: this.responseMessageId, runId: this.responseMessageId,
handlers: { handlers,
[GraphEvents.ON_RUN_STEP]: (event) => sendEvent(this.options.res, event),
[GraphEvents.ON_MESSAGE_DELTA]: (event) => sendEvent(this.options.res, event),
[GraphEvents.ON_REASONING_DELTA]: (event) => sendEvent(this.options.res, event),
},
}); });
intermediateReply = this.streamHandler.tokens; intermediateReply = this.streamHandler.tokens;

View file

@ -252,12 +252,14 @@ class PluginsClient extends OpenAIClient {
await this.recordTokenUsage(responseMessage); await this.recordTokenUsage(responseMessage);
} }
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user); const databasePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
delete responseMessage.tokenCount; delete responseMessage.tokenCount;
return { ...responseMessage, ...result }; return { ...responseMessage, ...result, databasePromise };
} }
async sendMessage(message, opts = {}) { async sendMessage(message, opts = {}) {
/** @type {Promise<TMessage>} */
let userMessagePromise;
/** @type {{ filteredTools: string[], includedTools: string[] }} */ /** @type {{ filteredTools: string[], includedTools: string[] }} */
const { filteredTools = [], includedTools = [] } = this.options.req.app.locals; const { filteredTools = [], includedTools = [] } = this.options.req.app.locals;
@ -327,10 +329,10 @@ class PluginsClient extends OpenAIClient {
} }
if (!this.skipSaveUserMessage) { if (!this.skipSaveUserMessage) {
this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user); userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
if (typeof opts?.getReqData === 'function') { if (typeof opts?.getReqData === 'function') {
opts.getReqData({ opts.getReqData({
userMessagePromise: this.userMessagePromise, userMessagePromise,
}); });
} }
} }

View file

@ -0,0 +1,60 @@
const { GraphEvents } = require('@librechat/agents');
const { logger, sendEvent } = require('~/config');
/**
* Makes a function to make HTTP request and logs the process.
* @param {Object} params
* @param {boolean} [params.directEndpoint] - Whether to use a direct endpoint.
* @param {string} [params.reverseProxyUrl] - The reverse proxy URL to use for the request.
* @returns {Promise<Response>} - A promise that resolves to the response of the fetch request.
*/
function createFetch({ directEndpoint = false, reverseProxyUrl = '' }) {
/**
* Makes an HTTP request and logs the process.
* @param {RequestInfo} url - The URL to make the request to. Can be a string or a Request object.
* @param {RequestInit} [init] - Optional init options for the request.
* @returns {Promise<Response>} - A promise that resolves to the response of the fetch request.
*/
return async (_url, init) => {
let url = _url;
if (directEndpoint) {
url = reverseProxyUrl;
}
logger.debug(`Making request to ${url}`);
if (typeof Bun !== 'undefined') {
return await fetch(url, init);
}
return await fetch(url, init);
};
}
// Add this at the module level outside the class
/**
* Creates event handlers for stream events that don't capture client references
* @param {Object} res - The response object to send events to
* @returns {Object} Object containing handler functions
*/
function createStreamEventHandlers(res) {
return {
[GraphEvents.ON_RUN_STEP]: (event) => {
if (res) {
sendEvent(res, event);
}
},
[GraphEvents.ON_MESSAGE_DELTA]: (event) => {
if (res) {
sendEvent(res, event);
}
},
[GraphEvents.ON_REASONING_DELTA]: (event) => {
if (res) {
sendEvent(res, event);
}
},
};
}
module.exports = {
createFetch,
createStreamEventHandlers,
};

View file

@ -123,7 +123,7 @@ const getAuthFields = (toolKey) => {
* *
* @param {object} object * @param {object} object
* @param {string} object.user * @param {string} object.user
* @param {Agent} [object.agent] * @param {Pick<Agent, 'id' | 'provider' | 'model'>} [object.agent]
* @param {string} [object.model] * @param {string} [object.model]
* @param {EModelEndpoint} [object.endpoint] * @param {EModelEndpoint} [object.endpoint]
* @param {LoadToolOptions} [object.options] * @param {LoadToolOptions} [object.options]

View file

@ -1,4 +1,4 @@
const Keyv = require('keyv'); const { Keyv } = require('keyv');
const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider'); const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
const { logFile, violationFile } = require('./keyvFiles'); const { logFile, violationFile } = require('./keyvFiles');
const { math, isEnabled } = require('~/server/utils'); const { math, isEnabled } = require('~/server/utils');

92
api/cache/ioredisClient.js vendored Normal file
View file

@ -0,0 +1,92 @@
const fs = require('fs');
const Redis = require('ioredis');
const { isEnabled } = require('~/server/utils');
const logger = require('~/config/winston');
const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_MAX_LISTENERS } = process.env;
/** @type {import('ioredis').Redis | import('ioredis').Cluster} */
let ioredisClient;
const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40;
function mapURI(uri) {
const regex =
/^(?:(?<scheme>\w+):\/\/)?(?:(?<user>[^:@]+)(?::(?<password>[^@]+))?@)?(?<host>[\w.-]+)(?::(?<port>\d{1,5}))?$/;
const match = uri.match(regex);
if (match) {
const { scheme, user, password, host, port } = match.groups;
return {
scheme: scheme || 'none',
user: user || null,
password: password || null,
host: host || null,
port: port || null,
};
} else {
const parts = uri.split(':');
if (parts.length === 2) {
return {
scheme: 'none',
user: null,
password: null,
host: parts[0],
port: parts[1],
};
}
return {
scheme: 'none',
user: null,
password: null,
host: uri,
port: null,
};
}
}
if (REDIS_URI && isEnabled(USE_REDIS)) {
let redisOptions = null;
if (REDIS_CA) {
const ca = fs.readFileSync(REDIS_CA);
redisOptions = { tls: { ca } };
}
if (isEnabled(USE_REDIS_CLUSTER)) {
const hosts = REDIS_URI.split(',').map((item) => {
var value = mapURI(item);
return {
host: value.host,
port: value.port,
};
});
ioredisClient = new Redis.Cluster(hosts, { redisOptions });
} else {
ioredisClient = new Redis(REDIS_URI, redisOptions);
}
ioredisClient.on('ready', () => {
logger.info('IoRedis connection ready');
});
ioredisClient.on('reconnecting', () => {
logger.info('IoRedis connection reconnecting');
});
ioredisClient.on('end', () => {
logger.info('IoRedis connection ended');
});
ioredisClient.on('close', () => {
logger.info('IoRedis connection closed');
});
ioredisClient.on('error', (err) => logger.error('IoRedis connection error:', err));
ioredisClient.setMaxListeners(redis_max_listeners);
logger.info(
'[Optional] IoRedis initialized for rate limiters. If you have issues, disable Redis or restart the server.',
);
} else {
logger.info('[Optional] IoRedis not initialized for rate limiters.');
}
module.exports = ioredisClient;

View file

@ -1,11 +1,9 @@
const { KeyvFile } = require('keyv-file'); const { KeyvFile } = require('keyv-file');
const logFile = new KeyvFile({ filename: './data/logs.json' }); const logFile = new KeyvFile({ filename: './data/logs.json' }).setMaxListeners(20);
const pendingReqFile = new KeyvFile({ filename: './data/pendingReqCache.json' }); const violationFile = new KeyvFile({ filename: './data/violations.json' }).setMaxListeners(20);
const violationFile = new KeyvFile({ filename: './data/violations.json' });
module.exports = { module.exports = {
logFile, logFile,
pendingReqFile,
violationFile, violationFile,
}; };

View file

@ -1,4 +1,4 @@
const KeyvMongo = require('@keyv/mongo'); const { KeyvMongo } = require('@keyv/mongo');
const { logger } = require('~/config'); const { logger } = require('~/config');
const { MONGO_URI } = process.env ?? {}; const { MONGO_URI } = process.env ?? {};

View file

@ -1,6 +1,6 @@
const fs = require('fs'); const fs = require('fs');
const ioredis = require('ioredis'); const ioredis = require('ioredis');
const KeyvRedis = require('@keyv/redis'); const KeyvRedis = require('@keyv/redis').default;
const { isEnabled } = require('~/server/utils'); const { isEnabled } = require('~/server/utils');
const logger = require('~/config/winston'); const logger = require('~/config/winston');
@ -50,6 +50,7 @@ function mapURI(uri) {
if (REDIS_URI && isEnabled(USE_REDIS)) { if (REDIS_URI && isEnabled(USE_REDIS)) {
let redisOptions = null; let redisOptions = null;
/** @type {import('@keyv/redis').KeyvRedisOptions} */
let keyvOpts = { let keyvOpts = {
useRedisSets: false, useRedisSets: false,
keyPrefix: redis_prefix, keyPrefix: redis_prefix,
@ -74,6 +75,18 @@ if (REDIS_URI && isEnabled(USE_REDIS)) {
} else { } else {
keyvRedis = new KeyvRedis(REDIS_URI, keyvOpts); keyvRedis = new KeyvRedis(REDIS_URI, keyvOpts);
} }
keyvRedis.on('ready', () => {
logger.info('KeyvRedis connection ready');
});
keyvRedis.on('reconnecting', () => {
logger.info('KeyvRedis connection reconnecting');
});
keyvRedis.on('end', () => {
logger.info('KeyvRedis connection ended');
});
keyvRedis.on('close', () => {
logger.info('KeyvRedis connection closed');
});
keyvRedis.on('error', (err) => logger.error('KeyvRedis connection error:', err)); keyvRedis.on('error', (err) => logger.error('KeyvRedis connection error:', err));
keyvRedis.setMaxListeners(redis_max_listeners); keyvRedis.setMaxListeners(redis_max_listeners);
logger.info( logger.info(

4
api/cache/redis.js vendored
View file

@ -1,4 +0,0 @@
const Redis = require('ioredis');
const { REDIS_URI } = process.env ?? {};
const redis = new Redis.Cluster(REDIS_URI);
module.exports = redis;

View file

@ -6,15 +6,19 @@ const logger = require('./winston');
global.EventSource = EventSource; global.EventSource = EventSource;
/** @type {MCPManager} */
let mcpManager = null; let mcpManager = null;
let flowManager = null; let flowManager = null;
/** /**
* @param {string} [userId] - Optional user ID, to avoid disconnecting the current user.
* @returns {MCPManager} * @returns {MCPManager}
*/ */
function getMCPManager() { function getMCPManager(userId) {
if (!mcpManager) { if (!mcpManager) {
mcpManager = MCPManager.getInstance(logger); mcpManager = MCPManager.getInstance(logger);
} else {
mcpManager.checkIdleConnections(userId);
} }
return mcpManager; return mcpManager;
} }

View file

@ -123,6 +123,10 @@ const tokenValues = Object.assign(
'grok-2-1212': { prompt: 2.0, completion: 10.0 }, 'grok-2-1212': { prompt: 2.0, completion: 10.0 },
'grok-2-latest': { prompt: 2.0, completion: 10.0 }, 'grok-2-latest': { prompt: 2.0, completion: 10.0 },
'grok-2': { prompt: 2.0, completion: 10.0 }, 'grok-2': { prompt: 2.0, completion: 10.0 },
'grok-3-mini-fast': { prompt: 0.4, completion: 4 },
'grok-3-mini': { prompt: 0.3, completion: 0.5 },
'grok-3-fast': { prompt: 5.0, completion: 25.0 },
'grok-3': { prompt: 3.0, completion: 15.0 },
'grok-beta': { prompt: 5.0, completion: 15.0 }, 'grok-beta': { prompt: 5.0, completion: 15.0 },
'mistral-large': { prompt: 2.0, completion: 6.0 }, 'mistral-large': { prompt: 2.0, completion: 6.0 },
'pixtral-large': { prompt: 2.0, completion: 6.0 }, 'pixtral-large': { prompt: 2.0, completion: 6.0 },

View file

@ -507,5 +507,27 @@ describe('Grok Model Tests - Pricing', () => {
expect(getMultiplier({ model: 'grok-beta', tokenType: 'prompt' })).toBe(5.0); expect(getMultiplier({ model: 'grok-beta', tokenType: 'prompt' })).toBe(5.0);
expect(getMultiplier({ model: 'grok-beta', tokenType: 'completion' })).toBe(15.0); expect(getMultiplier({ model: 'grok-beta', tokenType: 'completion' })).toBe(15.0);
}); });
test('should return correct prompt and completion rates for Grok 3 models', () => {
expect(getMultiplier({ model: 'grok-3', tokenType: 'prompt' })).toBe(3.0);
expect(getMultiplier({ model: 'grok-3', tokenType: 'completion' })).toBe(15.0);
expect(getMultiplier({ model: 'grok-3-fast', tokenType: 'prompt' })).toBe(5.0);
expect(getMultiplier({ model: 'grok-3-fast', tokenType: 'completion' })).toBe(25.0);
expect(getMultiplier({ model: 'grok-3-mini', tokenType: 'prompt' })).toBe(0.3);
expect(getMultiplier({ model: 'grok-3-mini', tokenType: 'completion' })).toBe(0.5);
expect(getMultiplier({ model: 'grok-3-mini-fast', tokenType: 'prompt' })).toBe(0.4);
expect(getMultiplier({ model: 'grok-3-mini-fast', tokenType: 'completion' })).toBe(4.0);
});
test('should return correct prompt and completion rates for Grok 3 models with prefixes', () => {
expect(getMultiplier({ model: 'xai/grok-3', tokenType: 'prompt' })).toBe(3.0);
expect(getMultiplier({ model: 'xai/grok-3', tokenType: 'completion' })).toBe(15.0);
expect(getMultiplier({ model: 'xai/grok-3-fast', tokenType: 'prompt' })).toBe(5.0);
expect(getMultiplier({ model: 'xai/grok-3-fast', tokenType: 'completion' })).toBe(25.0);
expect(getMultiplier({ model: 'xai/grok-3-mini', tokenType: 'prompt' })).toBe(0.3);
expect(getMultiplier({ model: 'xai/grok-3-mini', tokenType: 'completion' })).toBe(0.5);
expect(getMultiplier({ model: 'xai/grok-3-mini-fast', tokenType: 'prompt' })).toBe(0.4);
expect(getMultiplier({ model: 'xai/grok-3-mini-fast', tokenType: 'completion' })).toBe(4.0);
});
}); });
}); });

View file

@ -42,14 +42,14 @@
"@azure/storage-blob": "^12.26.0", "@azure/storage-blob": "^12.26.0",
"@google/generative-ai": "^0.23.0", "@google/generative-ai": "^0.23.0",
"@googleapis/youtube": "^20.0.0", "@googleapis/youtube": "^20.0.0",
"@keyv/mongo": "^2.1.8", "@keyv/mongo": "^3.0.1",
"@keyv/redis": "^2.8.1", "@keyv/redis": "^4.3.3",
"@langchain/community": "^0.3.39", "@langchain/community": "^0.3.39",
"@langchain/core": "^0.3.43", "@langchain/core": "^0.3.43",
"@langchain/google-genai": "^0.2.2", "@langchain/google-genai": "^0.2.2",
"@langchain/google-vertexai": "^0.2.3", "@langchain/google-vertexai": "^0.2.3",
"@langchain/textsplitters": "^0.1.0", "@langchain/textsplitters": "^0.1.0",
"@librechat/agents": "^2.4.12", "@librechat/agents": "^2.4.14",
"@librechat/data-schemas": "*", "@librechat/data-schemas": "*",
"@waylaidwanderer/fetch-event-source": "^3.0.1", "@waylaidwanderer/fetch-event-source": "^3.0.1",
"axios": "^1.8.2", "axios": "^1.8.2",
@ -76,8 +76,8 @@
"ioredis": "^5.3.2", "ioredis": "^5.3.2",
"js-yaml": "^4.1.0", "js-yaml": "^4.1.0",
"jsonwebtoken": "^9.0.0", "jsonwebtoken": "^9.0.0",
"keyv": "^4.5.4", "keyv": "^5.3.2",
"keyv-file": "^0.2.0", "keyv-file": "^5.1.2",
"klona": "^2.0.6", "klona": "^2.0.6",
"librechat-data-provider": "*", "librechat-data-provider": "*",
"librechat-mcp": "*", "librechat-mcp": "*",

240
api/server/cleanup.js Normal file
View file

@ -0,0 +1,240 @@
const { logger } = require('~/config');
// WeakMap to hold temporary data associated with requests; entries are
// released automatically once the request object itself is garbage collected.
const requestDataMap = new WeakMap();

// Grab FinalizationRegistry off the global so the module still loads on
// runtimes that do not provide it (the registry below is then null).
const FinalizationRegistry = global.FinalizationRegistry || null;

/**
 * FinalizationRegistry to clean up client objects when they are garbage collected.
 * This is used to prevent memory leaks and ensure that client objects are
 * properly disposed of when they are no longer needed.
 * The registry holds a weak reference to the client object and a cleanup
 * callback that is called when the client object is garbage collected.
 * The callback can be used to perform any necessary cleanup operations,
 * such as removing event listeners or freeing up resources.
 * NOTE(review): the callback here only logs; actual teardown happens eagerly
 * via `disposeClient` — this registry is a diagnostic backstop, not a cleaner.
 */
const clientRegistry = FinalizationRegistry
  ? new FinalizationRegistry((heldValue) => {
      try {
        // This will run when the client is garbage collected
        if (heldValue && heldValue.userId) {
          logger.debug(`[FinalizationRegistry] Cleaning up client for user ${heldValue.userId}`);
        } else {
          logger.debug('[FinalizationRegistry] Cleaning up client');
        }
      } catch (e) {
        // Ignore errors — a finalizer callback must never throw
      }
    })
  : null;
/**
* Cleans up the client object by removing references to its properties.
* This is useful for preventing memory leaks and ensuring that the client
* and its properties can be garbage collected when it is no longer needed.
*/
function disposeClient(client) {
if (!client) {
return;
}
try {
if (client.user) {
client.user = null;
}
if (client.apiKey) {
client.apiKey = null;
}
if (client.azure) {
client.azure = null;
}
if (client.conversationId) {
client.conversationId = null;
}
if (client.responseMessageId) {
client.responseMessageId = null;
}
if (client.clientName) {
client.clientName = null;
}
if (client.sender) {
client.sender = null;
}
if (client.model) {
client.model = null;
}
if (client.maxContextTokens) {
client.maxContextTokens = null;
}
if (client.contextStrategy) {
client.contextStrategy = null;
}
if (client.currentDateString) {
client.currentDateString = null;
}
if (client.inputTokensKey) {
client.inputTokensKey = null;
}
if (client.outputTokensKey) {
client.outputTokensKey = null;
}
if (client.run) {
// Break circular references in run
if (client.run.Graph) {
client.run.Graph.resetValues();
client.run.Graph.handlerRegistry = null;
client.run.Graph.runId = null;
client.run.Graph.tools = null;
client.run.Graph.signal = null;
client.run.Graph.config = null;
client.run.Graph.toolEnd = null;
client.run.Graph.toolMap = null;
client.run.Graph.provider = null;
client.run.Graph.streamBuffer = null;
client.run.Graph.clientOptions = null;
client.run.Graph.graphState = null;
client.run.Graph.boundModel = null;
client.run.Graph.systemMessage = null;
client.run.Graph.reasoningKey = null;
client.run.Graph.messages = null;
client.run.Graph.contentData = null;
client.run.Graph.stepKeyIds = null;
client.run.Graph.contentIndexMap = null;
client.run.Graph.toolCallStepIds = null;
client.run.Graph.messageIdsByStepKey = null;
client.run.Graph.messageStepHasToolCalls = null;
client.run.Graph.prelimMessageIdsByStepKey = null;
client.run.Graph.currentTokenType = null;
client.run.Graph.lastToken = null;
client.run.Graph.tokenTypeSwitch = null;
client.run.Graph.indexTokenCountMap = null;
client.run.Graph.currentUsage = null;
client.run.Graph.tokenCounter = null;
client.run.Graph.maxContextTokens = null;
client.run.Graph.pruneMessages = null;
client.run.Graph.lastStreamCall = null;
client.run.Graph.startIndex = null;
client.run.Graph = null;
}
if (client.run.handlerRegistry) {
client.run.handlerRegistry = null;
}
if (client.run.graphRunnable) {
if (client.run.graphRunnable.channels) {
client.run.graphRunnable.channels = null;
}
if (client.run.graphRunnable.nodes) {
client.run.graphRunnable.nodes = null;
}
if (client.run.graphRunnable.lc_kwargs) {
client.run.graphRunnable.lc_kwargs = null;
}
if (client.run.graphRunnable.builder?.nodes) {
client.run.graphRunnable.builder.nodes = null;
client.run.graphRunnable.builder = null;
}
client.run.graphRunnable = null;
}
client.run = null;
}
if (client.sendMessage) {
client.sendMessage = null;
}
if (client.savedMessageIds) {
client.savedMessageIds.clear();
client.savedMessageIds = null;
}
if (client.currentMessages) {
client.currentMessages = null;
}
if (client.streamHandler) {
client.streamHandler = null;
}
if (client.contentParts) {
client.contentParts = null;
}
if (client.abortController) {
client.abortController = null;
}
if (client.collectedUsage) {
client.collectedUsage = null;
}
if (client.indexTokenCountMap) {
client.indexTokenCountMap = null;
}
if (client.agentConfigs) {
client.agentConfigs = null;
}
if (client.artifactPromises) {
client.artifactPromises = null;
}
if (client.usage) {
client.usage = null;
}
if (typeof client.dispose === 'function') {
client.dispose();
}
if (client.options) {
if (client.options.req) {
client.options.req = null;
}
if (client.options.res) {
client.options.res = null;
}
if (client.options.attachments) {
client.options.attachments = null;
}
if (client.options.agent) {
client.options.agent = null;
}
}
client.options = null;
} catch (e) {
// Ignore errors during disposal
}
}
/**
 * Folds freshly reported request data into the accumulated request context.
 *
 * Recognized keys: `userMessage` (also captures its `messageId` as
 * `userMessageId`), `userMessagePromise`, `responseMessageId`,
 * `promptTokens`, `abortKey`, and `conversationId` — the latter is only
 * adopted when the context does not already carry one. Unrecognized keys
 * are ignored.
 *
 * @param {object} [data={}] - Incremental values reported by the pipeline.
 * @param {object} context - The current accumulated context (not mutated).
 * @returns {object} A new context object with the updates applied.
 */
function processReqData(data = {}, context) {
  const next = {
    abortKey: context.abortKey,
    userMessage: context.userMessage,
    userMessagePromise: context.userMessagePromise,
    responseMessageId: context.responseMessageId,
    promptTokens: context.promptTokens,
    conversationId: context.conversationId,
    userMessageId: context.userMessageId,
  };
  for (const key in data) {
    switch (key) {
      case 'userMessage':
        next.userMessage = data[key];
        next.userMessageId = data[key].messageId;
        break;
      case 'userMessagePromise':
        next.userMessagePromise = data[key];
        break;
      case 'responseMessageId':
        next.responseMessageId = data[key];
        break;
      case 'promptTokens':
        next.promptTokens = data[key];
        break;
      case 'abortKey':
        next.abortKey = data[key];
        break;
      case 'conversationId':
        // Keep an already-established conversation id.
        if (!next.conversationId) {
          next.conversationId = data[key];
        }
        break;
      default:
        break;
    }
  }
  return next;
}
// Shared memory-management helpers consumed by the chat request controllers
// (Ask/Edit/Agent): eager client disposal, per-request WeakMap storage,
// GC-time diagnostic registry, and request-context merging.
module.exports = {
  disposeClient,
  requestDataMap,
  clientRegistry,
  processReqData,
};

View file

@ -1,5 +1,15 @@
const { getResponseSender, Constants } = require('librechat-data-provider'); const { getResponseSender, Constants } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware'); const {
handleAbortError,
createAbortController,
cleanupAbortController,
} = require('~/server/middleware');
const {
disposeClient,
processReqData,
clientRegistry,
requestDataMap,
} = require('~/server/cleanup');
const { sendMessage, createOnProgress } = require('~/server/utils'); const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage } = require('~/models'); const { saveMessage } = require('~/models');
const { logger } = require('~/config'); const { logger } = require('~/config');
@ -14,90 +24,162 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
overrideParentMessageId = null, overrideParentMessageId = null,
} = req.body; } = req.body;
let client = null;
let abortKey = null;
let cleanupHandlers = [];
let clientRef = null;
logger.debug('[AskController]', { logger.debug('[AskController]', {
text, text,
conversationId, conversationId,
...endpointOption, ...endpointOption,
modelsConfig: endpointOption.modelsConfig ? 'exists' : '', modelsConfig: endpointOption?.modelsConfig ? 'exists' : '',
}); });
let userMessage; let userMessage = null;
let userMessagePromise; let userMessagePromise = null;
let promptTokens; let promptTokens = null;
let userMessageId; let userMessageId = null;
let responseMessageId; let responseMessageId = null;
let getAbortData = null;
const sender = getResponseSender({ const sender = getResponseSender({
...endpointOption, ...endpointOption,
model: endpointOption.modelOptions.model, model: endpointOption.modelOptions.model,
modelDisplayLabel, modelDisplayLabel,
}); });
const newConvo = !conversationId; const initialConversationId = conversationId;
const user = req.user.id; const newConvo = !initialConversationId;
const userId = req.user.id;
const getReqData = (data = {}) => { let reqDataContext = {
for (let key in data) { userMessage,
if (key === 'userMessage') { userMessagePromise,
userMessage = data[key]; responseMessageId,
userMessageId = data[key].messageId; promptTokens,
} else if (key === 'userMessagePromise') { conversationId,
userMessagePromise = data[key]; userMessageId,
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
} else if (!conversationId && key === 'conversationId') {
conversationId = data[key];
}
}
}; };
let getText; const updateReqData = (data = {}) => {
reqDataContext = processReqData(data, reqDataContext);
abortKey = reqDataContext.abortKey;
userMessage = reqDataContext.userMessage;
userMessagePromise = reqDataContext.userMessagePromise;
responseMessageId = reqDataContext.responseMessageId;
promptTokens = reqDataContext.promptTokens;
conversationId = reqDataContext.conversationId;
userMessageId = reqDataContext.userMessageId;
};
let { onProgress: progressCallback, getPartialText } = createOnProgress();
const performCleanup = () => {
logger.debug('[AskController] Performing cleanup');
if (Array.isArray(cleanupHandlers)) {
for (const handler of cleanupHandlers) {
try { try {
const { client } = await initializeClient({ req, res, endpointOption }); if (typeof handler === 'function') {
const { onProgress: progressCallback, getPartialText } = createOnProgress(); handler();
}
getText = client.getStreamText != null ? client.getStreamText.bind(client) : getPartialText; } catch (e) {
// Ignore
const getAbortData = () => ({ }
sender, }
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getText(),
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
res.on('close', () => {
logger.debug('[AskController] Request closed');
if (!abortController) {
return;
} else if (abortController.signal.aborted) {
return;
} else if (abortController.requestCompleted) {
return;
} }
if (abortKey) {
logger.debug('[AskController] Cleaning up abort controller');
cleanupAbortController(abortKey);
abortKey = null;
}
if (client) {
disposeClient(client);
client = null;
}
reqDataContext = null;
userMessage = null;
userMessagePromise = null;
promptTokens = null;
getAbortData = null;
progressCallback = null;
endpointOption = null;
cleanupHandlers = null;
addTitle = null;
if (requestDataMap.has(req)) {
requestDataMap.delete(req);
}
logger.debug('[AskController] Cleanup completed');
};
try {
({ client } = await initializeClient({ req, res, endpointOption }));
if (clientRegistry && client) {
clientRegistry.register(client, { userId }, client);
}
if (client) {
requestDataMap.set(req, { client });
}
clientRef = new WeakRef(client);
getAbortData = () => {
const currentClient = clientRef.deref();
const currentText =
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
return {
sender,
conversationId,
messageId: reqDataContext.responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: currentText,
userMessage: userMessage,
userMessagePromise: userMessagePromise,
promptTokens: reqDataContext.promptTokens,
};
};
const { onStart, abortController } = createAbortController(
req,
res,
getAbortData,
updateReqData,
);
const closeHandler = () => {
logger.debug('[AskController] Request closed');
if (!abortController || abortController.signal.aborted || abortController.requestCompleted) {
return;
}
abortController.abort(); abortController.abort();
logger.debug('[AskController] Request aborted on close'); logger.debug('[AskController] Request aborted on close');
};
res.on('close', closeHandler);
cleanupHandlers.push(() => {
try {
res.removeListener('close', closeHandler);
} catch (e) {
// Ignore
}
}); });
const messageOptions = { const messageOptions = {
user, user: userId,
parentMessageId, parentMessageId,
conversationId, conversationId: reqDataContext.conversationId,
overrideParentMessageId, overrideParentMessageId,
getReqData, getReqData: updateReqData,
onStart, onStart,
abortController, abortController,
progressCallback, progressCallback,
progressOptions: { progressOptions: {
res, res,
// parentMessageId: overrideParentMessageId || userMessageId,
}, },
}; };
@ -105,58 +187,93 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
let response = await client.sendMessage(text, messageOptions); let response = await client.sendMessage(text, messageOptions);
response.endpoint = endpointOption.endpoint; response.endpoint = endpointOption.endpoint;
const { conversation = {} } = await client.responsePromise; const databasePromise = response.databasePromise;
delete response.databasePromise;
const { conversation: convoData = {} } = await databasePromise;
const conversation = { ...convoData };
conversation.title = conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat'; conversation && !conversation.title ? null : conversation?.title || 'New Chat';
if (client.options.attachments) { const latestUserMessage = reqDataContext.userMessage;
userMessage.files = client.options.attachments;
if (client?.options?.attachments && latestUserMessage) {
latestUserMessage.files = client.options.attachments;
if (endpointOption?.modelOptions?.model) {
conversation.model = endpointOption.modelOptions.model; conversation.model = endpointOption.modelOptions.model;
delete userMessage.image_urls; }
delete latestUserMessage.image_urls;
} }
if (!abortController.signal.aborted) { if (!abortController.signal.aborted) {
const finalResponseMessage = { ...response };
sendMessage(res, { sendMessage(res, {
final: true, final: true,
conversation, conversation,
title: conversation.title, title: conversation.title,
requestMessage: userMessage, requestMessage: latestUserMessage,
responseMessage: response, responseMessage: finalResponseMessage,
}); });
res.end(); res.end();
if (!client.savedMessageIds.has(response.messageId)) { if (client?.savedMessageIds && !client.savedMessageIds.has(response.messageId)) {
await saveMessage( await saveMessage(
req, req,
{ ...response, user }, { ...finalResponseMessage, user: userId },
{ context: 'api/server/controllers/AskController.js - response end' }, { context: 'api/server/controllers/AskController.js - response end' },
); );
} }
} }
if (!client.skipSaveUserMessage) { if (!client?.skipSaveUserMessage && latestUserMessage) {
await saveMessage(req, userMessage, { await saveMessage(req, latestUserMessage, {
context: 'api/server/controllers/AskController.js - don\'t skip saving user message', context: 'api/server/controllers/AskController.js - don\'t skip saving user message',
}); });
} }
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) { if (typeof addTitle === 'function' && parentMessageId === Constants.NO_PARENT && newConvo) {
addTitle(req, { addTitle(req, {
text, text,
response, response: { ...response },
client, client,
})
.then(() => {
logger.debug('[AskController] Title generation started');
})
.catch((err) => {
logger.error('[AskController] Error in title generation', err);
})
.finally(() => {
logger.debug('[AskController] Title generation completed');
performCleanup();
}); });
} else {
performCleanup();
} }
} catch (error) { } catch (error) {
const partialText = getText && getText(); logger.error('[AskController] Error handling request', error);
let partialText = '';
try {
const currentClient = clientRef.deref();
partialText =
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
} catch (getTextError) {
logger.error('[AskController] Error calling getText() during error handling', getTextError);
}
handleAbortError(res, req, error, { handleAbortError(res, req, error, {
sender, sender,
partialText, partialText,
conversationId, conversationId: reqDataContext.conversationId,
messageId: responseMessageId, messageId: reqDataContext.responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId, parentMessageId: overrideParentMessageId ?? reqDataContext.userMessageId ?? parentMessageId,
}).catch((err) => { })
logger.error('[AskController] Error in `handleAbortError`', err); .catch((err) => {
logger.error('[AskController] Error in `handleAbortError` during catch block', err);
})
.finally(() => {
performCleanup();
}); });
} }
}; };

View file

@ -1,5 +1,15 @@
const { getResponseSender } = require('librechat-data-provider'); const { getResponseSender } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware'); const {
handleAbortError,
createAbortController,
cleanupAbortController,
} = require('~/server/middleware');
const {
disposeClient,
processReqData,
clientRegistry,
requestDataMap,
} = require('~/server/cleanup');
const { sendMessage, createOnProgress } = require('~/server/utils'); const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage } = require('~/models'); const { saveMessage } = require('~/models');
const { logger } = require('~/config'); const { logger } = require('~/config');
@ -17,6 +27,11 @@ const EditController = async (req, res, next, initializeClient) => {
overrideParentMessageId = null, overrideParentMessageId = null,
} = req.body; } = req.body;
let client = null;
let abortKey = null;
let cleanupHandlers = [];
let clientRef = null; // Declare clientRef here
logger.debug('[EditController]', { logger.debug('[EditController]', {
text, text,
generation, generation,
@ -26,122 +41,203 @@ const EditController = async (req, res, next, initializeClient) => {
modelsConfig: endpointOption.modelsConfig ? 'exists' : '', modelsConfig: endpointOption.modelsConfig ? 'exists' : '',
}); });
let userMessage; let userMessage = null;
let userMessagePromise; let userMessagePromise = null;
let promptTokens; let promptTokens = null;
let getAbortData = null;
const sender = getResponseSender({ const sender = getResponseSender({
...endpointOption, ...endpointOption,
model: endpointOption.modelOptions.model, model: endpointOption.modelOptions.model,
modelDisplayLabel, modelDisplayLabel,
}); });
const userMessageId = parentMessageId; const userMessageId = parentMessageId;
const user = req.user.id; const userId = req.user.id;
const getReqData = (data = {}) => { let reqDataContext = { userMessage, userMessagePromise, responseMessageId, promptTokens };
for (let key in data) {
if (key === 'userMessage') { const updateReqData = (data = {}) => {
userMessage = data[key]; reqDataContext = processReqData(data, reqDataContext);
} else if (key === 'userMessagePromise') { abortKey = reqDataContext.abortKey;
userMessagePromise = data[key]; userMessage = reqDataContext.userMessage;
} else if (key === 'responseMessageId') { userMessagePromise = reqDataContext.userMessagePromise;
responseMessageId = data[key]; responseMessageId = reqDataContext.responseMessageId;
} else if (key === 'promptTokens') { promptTokens = reqDataContext.promptTokens;
promptTokens = data[key];
}
}
}; };
const { onProgress: progressCallback, getPartialText } = createOnProgress({ let { onProgress: progressCallback, getPartialText } = createOnProgress({
generation, generation,
}); });
let getText; const performCleanup = () => {
logger.debug('[EditController] Performing cleanup');
if (Array.isArray(cleanupHandlers)) {
for (const handler of cleanupHandlers) {
try { try {
const { client } = await initializeClient({ req, res, endpointOption }); if (typeof handler === 'function') {
handler();
getText = client.getStreamText != null ? client.getStreamText.bind(client) : getPartialText; }
} catch (e) {
const getAbortData = () => ({ // Ignore
conversationId, }
userMessagePromise, }
messageId: responseMessageId,
sender,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getText(),
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
res.on('close', () => {
logger.debug('[EditController] Request closed');
if (!abortController) {
return;
} else if (abortController.signal.aborted) {
return;
} else if (abortController.requestCompleted) {
return;
} }
if (abortKey) {
logger.debug('[EditController] Cleaning up abort controller');
cleanupAbortController(abortKey);
abortKey = null;
}
if (client) {
disposeClient(client);
client = null;
}
reqDataContext = null;
userMessage = null;
userMessagePromise = null;
promptTokens = null;
getAbortData = null;
progressCallback = null;
endpointOption = null;
cleanupHandlers = null;
if (requestDataMap.has(req)) {
requestDataMap.delete(req);
}
logger.debug('[EditController] Cleanup completed');
};
try {
({ client } = await initializeClient({ req, res, endpointOption }));
if (clientRegistry && client) {
clientRegistry.register(client, { userId }, client);
}
if (client) {
requestDataMap.set(req, { client });
}
clientRef = new WeakRef(client);
getAbortData = () => {
const currentClient = clientRef.deref();
const currentText =
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
return {
sender,
conversationId,
messageId: reqDataContext.responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: currentText,
userMessage: userMessage,
userMessagePromise: userMessagePromise,
promptTokens: reqDataContext.promptTokens,
};
};
const { onStart, abortController } = createAbortController(
req,
res,
getAbortData,
updateReqData,
);
const closeHandler = () => {
logger.debug('[EditController] Request closed');
if (!abortController || abortController.signal.aborted || abortController.requestCompleted) {
return;
}
abortController.abort(); abortController.abort();
logger.debug('[EditController] Request aborted on close'); logger.debug('[EditController] Request aborted on close');
};
res.on('close', closeHandler);
cleanupHandlers.push(() => {
try {
res.removeListener('close', closeHandler);
} catch (e) {
// Ignore
}
}); });
let response = await client.sendMessage(text, { let response = await client.sendMessage(text, {
user, user: userId,
generation, generation,
isContinued, isContinued,
isEdited: true, isEdited: true,
conversationId, conversationId,
parentMessageId, parentMessageId,
responseMessageId, responseMessageId: reqDataContext.responseMessageId,
overrideParentMessageId, overrideParentMessageId,
getReqData, getReqData: updateReqData,
onStart, onStart,
abortController, abortController,
progressCallback, progressCallback,
progressOptions: { progressOptions: {
res, res,
// parentMessageId: overrideParentMessageId || userMessageId,
}, },
}); });
const { conversation = {} } = await client.responsePromise; const databasePromise = response.databasePromise;
delete response.databasePromise;
const { conversation: convoData = {} } = await databasePromise;
const conversation = { ...convoData };
conversation.title = conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat'; conversation && !conversation.title ? null : conversation?.title || 'New Chat';
if (client.options.attachments) { if (client?.options?.attachments && endpointOption?.modelOptions?.model) {
conversation.model = endpointOption.modelOptions.model; conversation.model = endpointOption.modelOptions.model;
} }
if (!abortController.signal.aborted) { if (!abortController.signal.aborted) {
const finalUserMessage = reqDataContext.userMessage;
const finalResponseMessage = { ...response };
sendMessage(res, { sendMessage(res, {
final: true, final: true,
conversation, conversation,
title: conversation.title, title: conversation.title,
requestMessage: userMessage, requestMessage: finalUserMessage,
responseMessage: response, responseMessage: finalResponseMessage,
}); });
res.end(); res.end();
await saveMessage( await saveMessage(
req, req,
{ ...response, user }, { ...finalResponseMessage, user: userId },
{ context: 'api/server/controllers/EditController.js - response end' }, { context: 'api/server/controllers/EditController.js - response end' },
); );
} }
performCleanup();
} catch (error) { } catch (error) {
const partialText = getText(); logger.error('[EditController] Error handling request', error);
let partialText = '';
try {
const currentClient = clientRef.deref();
partialText =
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
} catch (getTextError) {
logger.error('[EditController] Error calling getText() during error handling', getTextError);
}
handleAbortError(res, req, error, { handleAbortError(res, req, error, {
sender, sender,
partialText, partialText,
conversationId, conversationId,
messageId: responseMessageId, messageId: reqDataContext.responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId, parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId,
}).catch((err) => { })
logger.error('[EditController] Error in `handleAbortError`', err); .catch((err) => {
logger.error('[EditController] Error in `handleAbortError` during catch block', err);
})
.finally(() => {
performCleanup();
}); });
} }
}; };

View file

@ -63,6 +63,21 @@ const noSystemModelRegex = [/\bo1\b/gi];
// const { getFormattedMemories } = require('~/models/Memory'); // const { getFormattedMemories } = require('~/models/Memory');
// const { getCurrentDateTime } = require('~/utils'); // const { getCurrentDateTime } = require('~/utils');
function createTokenCounter(encoding) {
return (message) => {
const countTokens = (text) => Tokenizer.getTokenCount(text, encoding);
return getTokenCountForMessage(message, countTokens);
};
}
function logToolError(graph, error, toolId) {
logger.error(
'[api/server/controllers/agents/client.js #chatCompletion] Tool Error',
error,
toolId,
);
}
class AgentClient extends BaseClient { class AgentClient extends BaseClient {
constructor(options = {}) { constructor(options = {}) {
super(null, options); super(null, options);
@ -535,6 +550,10 @@ class AgentClient extends BaseClient {
} }
async chatCompletion({ payload, abortController = null }) { async chatCompletion({ payload, abortController = null }) {
/** @type {Partial<RunnableConfig> & { version: 'v1' | 'v2'; run_id?: string; streamMode: string }} */
let config;
/** @type {ReturnType<createRun>} */
let run;
try { try {
if (!abortController) { if (!abortController) {
abortController = new AbortController(); abortController = new AbortController();
@ -632,11 +651,11 @@ class AgentClient extends BaseClient {
/** @type {TCustomConfig['endpoints']['agents']} */ /** @type {TCustomConfig['endpoints']['agents']} */
const agentsEConfig = this.options.req.app.locals[EModelEndpoint.agents]; const agentsEConfig = this.options.req.app.locals[EModelEndpoint.agents];
/** @type {Partial<RunnableConfig> & { version: 'v1' | 'v2'; run_id?: string; streamMode: string }} */ config = {
const config = {
configurable: { configurable: {
thread_id: this.conversationId, thread_id: this.conversationId,
last_agent_index: this.agentConfigs?.size ?? 0, last_agent_index: this.agentConfigs?.size ?? 0,
user_id: this.user ?? this.options.req.user?.id,
hide_sequential_outputs: this.options.agent.hide_sequential_outputs, hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
}, },
recursionLimit: agentsEConfig?.recursionLimit, recursionLimit: agentsEConfig?.recursionLimit,
@ -655,15 +674,6 @@ class AgentClient extends BaseClient {
initialMessages = formatContentStrings(initialMessages); initialMessages = formatContentStrings(initialMessages);
} }
/** @type {ReturnType<createRun>} */
let run;
const countTokens = ((text) => this.getTokenCount(text)).bind(this);
/** @type {(message: BaseMessage) => number} */
const tokenCounter = (message) => {
return getTokenCountForMessage(message, countTokens);
};
/** /**
* *
* @param {Agent} agent * @param {Agent} agent
@ -767,19 +777,14 @@ class AgentClient extends BaseClient {
run.Graph.contentData = contentData; run.Graph.contentData = contentData;
} }
const encoding = this.getEncoding();
await run.processStream({ messages }, config, { await run.processStream({ messages }, config, {
keepContent: i !== 0, keepContent: i !== 0,
tokenCounter, tokenCounter: createTokenCounter(encoding),
indexTokenCountMap: currentIndexCountMap, indexTokenCountMap: currentIndexCountMap,
maxContextTokens: agent.maxContextTokens, maxContextTokens: agent.maxContextTokens,
callbacks: { callbacks: {
[Callback.TOOL_ERROR]: (graph, error, toolId) => { [Callback.TOOL_ERROR]: logToolError,
logger.error(
'[api/server/controllers/agents/client.js #chatCompletion] Tool Error',
error,
toolId,
);
},
}, },
}); });
}; };
@ -809,6 +814,8 @@ class AgentClient extends BaseClient {
break; break;
} }
} }
const encoding = this.getEncoding();
const tokenCounter = createTokenCounter(encoding);
for (const [agentId, agent] of this.agentConfigs) { for (const [agentId, agent] of this.agentConfigs) {
if (abortController.signal.aborted === true) { if (abortController.signal.aborted === true) {
break; break;
@ -917,7 +924,7 @@ class AgentClient extends BaseClient {
* @param {string} params.text * @param {string} params.text
* @param {string} params.conversationId * @param {string} params.conversationId
*/ */
async titleConvo({ text }) { async titleConvo({ text, abortController }) {
if (!this.run) { if (!this.run) {
throw new Error('Run not initialized'); throw new Error('Run not initialized');
} }
@ -950,6 +957,7 @@ class AgentClient extends BaseClient {
contentParts: this.contentParts, contentParts: this.contentParts,
clientOptions, clientOptions,
chainOptions: { chainOptions: {
signal: abortController.signal,
callbacks: [ callbacks: [
{ {
handleLLMEnd, handleLLMEnd,
@ -975,7 +983,7 @@ class AgentClient extends BaseClient {
}; };
}); });
this.recordCollectedUsage({ await this.recordCollectedUsage({
model: clientOptions.model, model: clientOptions.model,
context: 'title', context: 'title',
collectedUsage, collectedUsage,

View file

@ -1,5 +1,10 @@
const { Constants } = require('librechat-data-provider'); const { Constants } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware'); const {
handleAbortError,
createAbortController,
cleanupAbortController,
} = require('~/server/middleware');
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
const { sendMessage } = require('~/server/utils'); const { sendMessage } = require('~/server/utils');
const { saveMessage } = require('~/models'); const { saveMessage } = require('~/models');
const { logger } = require('~/config'); const { logger } = require('~/config');
@ -14,16 +19,22 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
} = req.body; } = req.body;
let sender; let sender;
let abortKey;
let userMessage; let userMessage;
let promptTokens; let promptTokens;
let userMessageId; let userMessageId;
let responseMessageId; let responseMessageId;
let userMessagePromise; let userMessagePromise;
let getAbortData;
let client = null;
// Initialize as an array
let cleanupHandlers = [];
const newConvo = !conversationId; const newConvo = !conversationId;
const user = req.user.id; const userId = req.user.id;
const getReqData = (data = {}) => { // Create handler to avoid capturing the entire parent scope
let getReqData = (data = {}) => {
for (let key in data) { for (let key in data) {
if (key === 'userMessage') { if (key === 'userMessage') {
userMessage = data[key]; userMessage = data[key];
@ -36,30 +47,96 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
promptTokens = data[key]; promptTokens = data[key];
} else if (key === 'sender') { } else if (key === 'sender') {
sender = data[key]; sender = data[key];
} else if (key === 'abortKey') {
abortKey = data[key];
} else if (!conversationId && key === 'conversationId') { } else if (!conversationId && key === 'conversationId') {
conversationId = data[key]; conversationId = data[key];
} }
} }
}; };
// Create a function to handle final cleanup
const performCleanup = () => {
logger.debug('[AgentController] Performing cleanup');
// Make sure cleanupHandlers is an array before iterating
if (Array.isArray(cleanupHandlers)) {
// Execute all cleanup handlers
for (const handler of cleanupHandlers) {
try {
if (typeof handler === 'function') {
handler();
}
} catch (e) {
// Ignore cleanup errors
}
}
}
// Clean up abort controller
if (abortKey) {
logger.debug('[AgentController] Cleaning up abort controller');
cleanupAbortController(abortKey);
}
// Dispose client properly
if (client) {
disposeClient(client);
}
// Clear all references
client = null;
getReqData = null;
userMessage = null;
getAbortData = null;
endpointOption.agent = null;
endpointOption = null;
cleanupHandlers = null;
userMessagePromise = null;
// Clear request data map
if (requestDataMap.has(req)) {
requestDataMap.delete(req);
}
logger.debug('[AgentController] Cleanup completed');
};
try { try {
/** @type {{ client: TAgentClient }} */ /** @type {{ client: TAgentClient }} */
const { client } = await initializeClient({ req, res, endpointOption }); const result = await initializeClient({ req, res, endpointOption });
client = result.client;
const getAbortData = () => ({ // Register client with finalization registry if available
if (clientRegistry) {
clientRegistry.register(client, { userId }, client);
}
// Store request data in WeakMap keyed by req object
requestDataMap.set(req, { client });
// Use WeakRef to allow GC but still access content if it exists
const contentRef = new WeakRef(client.contentParts || []);
// Minimize closure scope - only capture small primitives and WeakRef
getAbortData = () => {
// Dereference WeakRef each time
const content = contentRef.deref();
return {
sender, sender,
content: content || [],
userMessage, userMessage,
promptTokens, promptTokens,
conversationId, conversationId,
userMessagePromise, userMessagePromise,
messageId: responseMessageId, messageId: responseMessageId,
content: client.getContentParts(),
parentMessageId: overrideParentMessageId ?? userMessageId, parentMessageId: overrideParentMessageId ?? userMessageId,
}); };
};
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData); const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
res.on('close', () => { // Simple handler to avoid capturing scope
const closeHandler = () => {
logger.debug('[AgentController] Request closed'); logger.debug('[AgentController] Request closed');
if (!abortController) { if (!abortController) {
return; return;
@ -71,10 +148,19 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
abortController.abort(); abortController.abort();
logger.debug('[AgentController] Request aborted on close'); logger.debug('[AgentController] Request aborted on close');
};
res.on('close', closeHandler);
cleanupHandlers.push(() => {
try {
res.removeListener('close', closeHandler);
} catch (e) {
// Ignore
}
}); });
const messageOptions = { const messageOptions = {
user, user: userId,
onStart, onStart,
getReqData, getReqData,
conversationId, conversationId,
@ -83,68 +169,102 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
overrideParentMessageId, overrideParentMessageId,
progressOptions: { progressOptions: {
res, res,
// parentMessageId: overrideParentMessageId || userMessageId,
}, },
}; };
let response = await client.sendMessage(text, messageOptions); let response = await client.sendMessage(text, messageOptions);
response.endpoint = endpointOption.endpoint;
const { conversation = {} } = await client.responsePromise; // Extract what we need and immediately break reference
const messageId = response.messageId;
const endpoint = endpointOption.endpoint;
response.endpoint = endpoint;
// Store database promise locally
const databasePromise = response.databasePromise;
delete response.databasePromise;
// Resolve database-related data
const { conversation: convoData = {} } = await databasePromise;
const conversation = { ...convoData };
conversation.title = conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat'; conversation && !conversation.title ? null : conversation?.title || 'New Chat';
if (req.body.files && client.options.attachments) { // Process files if needed
if (req.body.files && client.options?.attachments) {
userMessage.files = []; userMessage.files = [];
const messageFiles = new Set(req.body.files.map((file) => file.file_id)); const messageFiles = new Set(req.body.files.map((file) => file.file_id));
for (let attachment of client.options.attachments) { for (let attachment of client.options.attachments) {
if (messageFiles.has(attachment.file_id)) { if (messageFiles.has(attachment.file_id)) {
userMessage.files.push(attachment); userMessage.files.push({ ...attachment });
} }
} }
delete userMessage.image_urls; delete userMessage.image_urls;
} }
// Only send if not aborted
if (!abortController.signal.aborted) { if (!abortController.signal.aborted) {
// Create a new response object with minimal copies
const finalResponse = { ...response };
sendMessage(res, { sendMessage(res, {
final: true, final: true,
conversation, conversation,
title: conversation.title, title: conversation.title,
requestMessage: userMessage, requestMessage: userMessage,
responseMessage: response, responseMessage: finalResponse,
}); });
res.end(); res.end();
if (!client.savedMessageIds.has(response.messageId)) { // Save the message if needed
if (client.savedMessageIds && !client.savedMessageIds.has(messageId)) {
await saveMessage( await saveMessage(
req, req,
{ ...response, user }, { ...finalResponse, user: userId },
{ context: 'api/server/controllers/agents/request.js - response end' }, { context: 'api/server/controllers/agents/request.js - response end' },
); );
} }
} }
// Save user message if needed
if (!client.skipSaveUserMessage) { if (!client.skipSaveUserMessage) {
await saveMessage(req, userMessage, { await saveMessage(req, userMessage, {
context: 'api/server/controllers/agents/request.js - don\'t skip saving user message', context: 'api/server/controllers/agents/request.js - don\'t skip saving user message',
}); });
} }
// Add title if needed - extract minimal data
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) { if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
addTitle(req, { addTitle(req, {
text, text,
response, response: { ...response },
client, client,
})
.then(() => {
logger.debug('[AgentController] Title generation started');
})
.catch((err) => {
logger.error('[AgentController] Error in title generation', err);
})
.finally(() => {
logger.debug('[AgentController] Title generation completed');
performCleanup();
}); });
} else {
performCleanup();
} }
} catch (error) { } catch (error) {
// Handle error without capturing much scope
handleAbortError(res, req, error, { handleAbortError(res, req, error, {
conversationId, conversationId,
sender, sender,
messageId: responseMessageId, messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId, parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId,
}).catch((err) => { })
.catch((err) => {
logger.error('[api/server/controllers/agents/request] Error in `handleAbortError`', err); logger.error('[api/server/controllers/agents/request] Error in `handleAbortError`', err);
})
.finally(() => {
performCleanup();
}); });
} }
}; };

View file

@ -1,3 +1,4 @@
// abortMiddleware.js
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider'); const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils'); const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
const { truncateText, smartTruncateText } = require('~/app/clients/prompts'); const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
@ -8,6 +9,68 @@ const { saveMessage, getConvo } = require('~/models');
const { abortRun } = require('./abortRun'); const { abortRun } = require('./abortRun');
const { logger } = require('~/config'); const { logger } = require('~/config');
const abortDataMap = new WeakMap();
function cleanupAbortController(abortKey) {
if (!abortControllers.has(abortKey)) {
return false;
}
const { abortController } = abortControllers.get(abortKey);
if (!abortController) {
abortControllers.delete(abortKey);
return true;
}
// 1. Check if this controller has any composed signals and clean them up
try {
// This creates a temporary composed signal to use for cleanup
const composedSignal = AbortSignal.any([abortController.signal]);
// Get all event types - in practice, AbortSignal typically only uses 'abort'
const eventTypes = ['abort'];
// First, execute a dummy listener removal to handle potential composed signals
for (const eventType of eventTypes) {
const dummyHandler = () => {};
composedSignal.addEventListener(eventType, dummyHandler);
composedSignal.removeEventListener(eventType, dummyHandler);
const listeners = composedSignal.listeners?.(eventType) || [];
for (const listener of listeners) {
composedSignal.removeEventListener(eventType, listener);
}
}
} catch (e) {
logger.debug(`Error cleaning up composed signals: ${e}`);
}
// 2. Abort the controller if not already aborted
if (!abortController.signal.aborted) {
abortController.abort();
}
// 3. Remove from registry
abortControllers.delete(abortKey);
// 4. Clean up any data stored in the WeakMap
if (abortDataMap.has(abortController)) {
abortDataMap.delete(abortController);
}
// 5. Clean up function references on the controller
if (abortController.getAbortData) {
abortController.getAbortData = null;
}
if (abortController.abortCompletion) {
abortController.abortCompletion = null;
}
return true;
}
async function abortMessage(req, res) { async function abortMessage(req, res) {
let { abortKey, endpoint } = req.body; let { abortKey, endpoint } = req.body;
@ -29,19 +92,19 @@ async function abortMessage(req, res) {
if (!abortController) { if (!abortController) {
return res.status(204).send({ message: 'Request not found' }); return res.status(204).send({ message: 'Request not found' });
} }
const finalEvent = await abortController.abortCompletion();
const finalEvent = await abortController.abortCompletion?.();
logger.debug( logger.debug(
`[abortMessage] ID: ${req.user.id} | ${req.user.email} | Aborted request: ` + `[abortMessage] ID: ${req.user.id} | ${req.user.email} | Aborted request: ` +
JSON.stringify({ abortKey }), JSON.stringify({ abortKey }),
); );
abortControllers.delete(abortKey); cleanupAbortController(abortKey);
if (res.headersSent && finalEvent) { if (res.headersSent && finalEvent) {
return sendMessage(res, finalEvent); return sendMessage(res, finalEvent);
} }
res.setHeader('Content-Type', 'application/json'); res.setHeader('Content-Type', 'application/json');
res.send(JSON.stringify(finalEvent)); res.send(JSON.stringify(finalEvent));
} }
@ -62,8 +125,48 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
const abortController = new AbortController(); const abortController = new AbortController();
const { endpointOption } = req.body; const { endpointOption } = req.body;
// Store minimal data in WeakMap to avoid circular references
abortDataMap.set(abortController, {
getAbortDataFn: getAbortData,
userId: req.user.id,
endpoint: endpointOption.endpoint,
iconURL: endpointOption.iconURL,
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
});
// Replace the direct function reference with a wrapper that uses WeakMap
abortController.getAbortData = function () { abortController.getAbortData = function () {
return getAbortData(); const data = abortDataMap.get(this);
if (!data || typeof data.getAbortDataFn !== 'function') {
return {};
}
try {
const result = data.getAbortDataFn();
// Create a copy without circular references
const cleanResult = { ...result };
// If userMessagePromise exists, break its reference to client
if (
cleanResult.userMessagePromise &&
typeof cleanResult.userMessagePromise.then === 'function'
) {
// Create a new promise that fulfills with the same result but doesn't reference the original
const originalPromise = cleanResult.userMessagePromise;
cleanResult.userMessagePromise = new Promise((resolve, reject) => {
originalPromise.then(
(result) => resolve({ ...result }),
(error) => reject(error),
);
});
}
return cleanResult;
} catch (err) {
logger.error('[abortController.getAbortData] Error:', err);
return {};
}
}; };
/** /**
@ -74,6 +177,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
sendMessage(res, { message: userMessage, created: true }); sendMessage(res, { message: userMessage, created: true });
const abortKey = userMessage?.conversationId ?? req.user.id; const abortKey = userMessage?.conversationId ?? req.user.id;
getReqData({ abortKey });
const prevRequest = abortControllers.get(abortKey); const prevRequest = abortControllers.get(abortKey);
const { overrideUserMessageId } = req?.body ?? {}; const { overrideUserMessageId } = req?.body ?? {};
@ -81,34 +185,74 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
const data = prevRequest.abortController.getAbortData(); const data = prevRequest.abortController.getAbortData();
getReqData({ userMessage: data?.userMessage }); getReqData({ userMessage: data?.userMessage });
const addedAbortKey = `${abortKey}:${responseMessageId}`; const addedAbortKey = `${abortKey}:${responseMessageId}`;
abortControllers.set(addedAbortKey, { abortController, ...endpointOption });
res.on('finish', function () { // Store minimal options
abortControllers.delete(addedAbortKey); const minimalOptions = {
}); endpoint: endpointOption.endpoint,
iconURL: endpointOption.iconURL,
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
};
abortControllers.set(addedAbortKey, { abortController, ...minimalOptions });
// Use a simple function for cleanup to avoid capturing context
const cleanupHandler = () => {
try {
cleanupAbortController(addedAbortKey);
} catch (e) {
// Ignore cleanup errors
}
};
res.on('finish', cleanupHandler);
return; return;
} }
abortControllers.set(abortKey, { abortController, ...endpointOption }); // Store minimal options
const minimalOptions = {
res.on('finish', function () { endpoint: endpointOption.endpoint,
abortControllers.delete(abortKey); iconURL: endpointOption.iconURL,
}); model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
}; };
abortControllers.set(abortKey, { abortController, ...minimalOptions });
// Use a simple function for cleanup to avoid capturing context
const cleanupHandler = () => {
try {
cleanupAbortController(abortKey);
} catch (e) {
// Ignore cleanup errors
}
};
res.on('finish', cleanupHandler);
};
// Define abortCompletion without capturing the entire parent scope
abortController.abortCompletion = async function () { abortController.abortCompletion = async function () {
abortController.abort(); this.abort();
// Get data from WeakMap
const ctrlData = abortDataMap.get(this);
if (!ctrlData || !ctrlData.getAbortDataFn) {
return { final: true, conversation: {}, title: 'New Chat' };
}
// Get abort data using stored function
const { conversationId, userMessage, userMessagePromise, promptTokens, ...responseData } = const { conversationId, userMessage, userMessagePromise, promptTokens, ...responseData } =
getAbortData(); ctrlData.getAbortDataFn();
const completionTokens = await countTokens(responseData?.text ?? ''); const completionTokens = await countTokens(responseData?.text ?? '');
const user = req.user.id; const user = ctrlData.userId;
const responseMessage = { const responseMessage = {
...responseData, ...responseData,
conversationId, conversationId,
finish_reason: 'incomplete', finish_reason: 'incomplete',
endpoint: endpointOption.endpoint, endpoint: ctrlData.endpoint,
iconURL: endpointOption.iconURL, iconURL: ctrlData.iconURL,
model: endpointOption.modelOptions?.model ?? endpointOption.model_parameters?.model, model: ctrlData.modelOptions?.model ?? ctrlData.model_parameters?.model,
unfinished: false, unfinished: false,
error: false, error: false,
isCreatedByUser: false, isCreatedByUser: false,
@ -130,10 +274,12 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
if (userMessagePromise) { if (userMessagePromise) {
const resolved = await userMessagePromise; const resolved = await userMessagePromise;
conversation = resolved?.conversation; conversation = resolved?.conversation;
// Break reference to promise
resolved.conversation = null;
} }
if (!conversation) { if (!conversation) {
conversation = await getConvo(req.user.id, conversationId); conversation = await getConvo(user, conversationId);
} }
return { return {
@ -218,11 +364,12 @@ const handleAbortError = async (res, req, error, data) => {
}; };
} }
// Create a simple callback without capturing parent scope
const callback = async () => { const callback = async () => {
if (abortControllers.has(conversationId)) { try {
const { abortController } = abortControllers.get(conversationId); cleanupAbortController(conversationId);
abortController.abort(); } catch (e) {
abortControllers.delete(conversationId); // Ignore cleanup errors
} }
}; };
@ -243,6 +390,7 @@ const handleAbortError = async (res, req, error, data) => {
module.exports = { module.exports = {
handleAbort, handleAbort,
createAbortController,
handleAbortError, handleAbortError,
createAbortController,
cleanupAbortController,
}; };

View file

@ -1,4 +1,4 @@
const Keyv = require('keyv'); const { Keyv } = require('keyv');
const uap = require('ua-parser-js'); const uap = require('ua-parser-js');
const { ViolationTypes } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const { isEnabled, removePorts } = require('~/server/utils'); const { isEnabled, removePorts } = require('~/server/utils');

View file

@ -1,10 +1,9 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis'); const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation'); const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils'); const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config'); const { logger } = require('~/config');
const getEnvironmentVariables = () => { const getEnvironmentVariables = () => {
@ -67,11 +66,9 @@ const createImportLimiters = () => {
}, },
}; };
if (isEnabled(process.env.USE_REDIS)) { if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for import rate limiters.'); logger.debug('Using Redis for import rate limiters.');
const keyv = new Keyv({ store: keyvRedis }); const sendCommand = (...args) => ioredisClient.call(...args);
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const ipStore = new RedisStore({ const ipStore = new RedisStore({
sendCommand, sendCommand,
prefix: 'import_ip_limiter:', prefix: 'import_ip_limiter:',

View file

@ -1,8 +1,7 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis'); const { RedisStore } = require('rate-limit-redis');
const { removePorts, isEnabled } = require('~/server/utils'); const { removePorts, isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis'); const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache'); const { logViolation } = require('~/cache');
const { logger } = require('~/config'); const { logger } = require('~/config');
@ -31,13 +30,10 @@ const limiterOptions = {
keyGenerator: removePorts, keyGenerator: removePorts,
}; };
if (isEnabled(process.env.USE_REDIS)) { if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for login rate limiter.'); logger.debug('Using Redis for login rate limiter.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const store = new RedisStore({ const store = new RedisStore({
sendCommand, sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'login_limiter:', prefix: 'login_limiter:',
}); });
limiterOptions.store = store; limiterOptions.store = store;

View file

@ -1,9 +1,8 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis'); const { RedisStore } = require('rate-limit-redis');
const denyRequest = require('~/server/middleware/denyRequest'); const denyRequest = require('~/server/middleware/denyRequest');
const ioredisClient = require('~/cache/ioredisClient');
const { isEnabled } = require('~/server/utils'); const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logViolation } = require('~/cache'); const { logViolation } = require('~/cache');
const { logger } = require('~/config'); const { logger } = require('~/config');
@ -63,11 +62,9 @@ const userLimiterOptions = {
}, },
}; };
if (isEnabled(process.env.USE_REDIS)) { if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for message rate limiters.'); logger.debug('Using Redis for message rate limiters.');
const keyv = new Keyv({ store: keyvRedis }); const sendCommand = (...args) => ioredisClient.call(...args);
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const ipStore = new RedisStore({ const ipStore = new RedisStore({
sendCommand, sendCommand,
prefix: 'message_ip_limiter:', prefix: 'message_ip_limiter:',

View file

@ -1,8 +1,7 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis'); const { RedisStore } = require('rate-limit-redis');
const { removePorts, isEnabled } = require('~/server/utils'); const { removePorts, isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis'); const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache'); const { logViolation } = require('~/cache');
const { logger } = require('~/config'); const { logger } = require('~/config');
@ -31,13 +30,10 @@ const limiterOptions = {
keyGenerator: removePorts, keyGenerator: removePorts,
}; };
if (isEnabled(process.env.USE_REDIS)) { if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for register rate limiter.'); logger.debug('Using Redis for register rate limiter.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const store = new RedisStore({ const store = new RedisStore({
sendCommand, sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'register_limiter:', prefix: 'register_limiter:',
}); });
limiterOptions.store = store; limiterOptions.store = store;

View file

@ -1,9 +1,8 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis'); const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const { removePorts, isEnabled } = require('~/server/utils'); const { removePorts, isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis'); const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache'); const { logViolation } = require('~/cache');
const { logger } = require('~/config'); const { logger } = require('~/config');
@ -36,13 +35,10 @@ const limiterOptions = {
keyGenerator: removePorts, keyGenerator: removePorts,
}; };
if (isEnabled(process.env.USE_REDIS)) { if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for reset password rate limiter.'); logger.debug('Using Redis for reset password rate limiter.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const store = new RedisStore({ const store = new RedisStore({
sendCommand, sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'reset_password_limiter:', prefix: 'reset_password_limiter:',
}); });
limiterOptions.store = store; limiterOptions.store = store;

View file

@ -1,10 +1,9 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis'); const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation'); const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils'); const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config'); const { logger } = require('~/config');
const getEnvironmentVariables = () => { const getEnvironmentVariables = () => {
@ -67,11 +66,9 @@ const createSTTLimiters = () => {
}, },
}; };
if (isEnabled(process.env.USE_REDIS)) { if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for STT rate limiters.'); logger.debug('Using Redis for STT rate limiters.');
const keyv = new Keyv({ store: keyvRedis }); const sendCommand = (...args) => ioredisClient.call(...args);
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const ipStore = new RedisStore({ const ipStore = new RedisStore({
sendCommand, sendCommand,
prefix: 'stt_ip_limiter:', prefix: 'stt_ip_limiter:',

View file

@ -1,10 +1,9 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis'); const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation'); const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils'); const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config'); const { logger } = require('~/config');
const handler = async (req, res) => { const handler = async (req, res) => {
@ -29,13 +28,10 @@ const limiterOptions = {
}, },
}; };
if (isEnabled(process.env.USE_REDIS)) { if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for tool call rate limiter.'); logger.debug('Using Redis for tool call rate limiter.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const store = new RedisStore({ const store = new RedisStore({
sendCommand, sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'tool_call_limiter:', prefix: 'tool_call_limiter:',
}); });
limiterOptions.store = store; limiterOptions.store = store;

View file

@ -1,10 +1,9 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis'); const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation'); const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils'); const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config'); const { logger } = require('~/config');
const getEnvironmentVariables = () => { const getEnvironmentVariables = () => {
@ -67,11 +66,9 @@ const createTTSLimiters = () => {
}, },
}; };
if (isEnabled(process.env.USE_REDIS)) { if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for TTS rate limiters.'); logger.debug('Using Redis for TTS rate limiters.');
const keyv = new Keyv({ store: keyvRedis }); const sendCommand = (...args) => ioredisClient.call(...args);
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const ipStore = new RedisStore({ const ipStore = new RedisStore({
sendCommand, sendCommand,
prefix: 'tts_ip_limiter:', prefix: 'tts_ip_limiter:',

View file

@ -1,10 +1,9 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis'); const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation'); const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils'); const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config'); const { logger } = require('~/config');
const getEnvironmentVariables = () => { const getEnvironmentVariables = () => {
@ -72,11 +71,9 @@ const createFileLimiters = () => {
}, },
}; };
if (isEnabled(process.env.USE_REDIS)) { if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for file upload rate limiters.'); logger.debug('Using Redis for file upload rate limiters.');
const keyv = new Keyv({ store: keyvRedis }); const sendCommand = (...args) => ioredisClient.call(...args);
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const ipStore = new RedisStore({ const ipStore = new RedisStore({
sendCommand, sendCommand,
prefix: 'file_upload_ip_limiter:', prefix: 'file_upload_ip_limiter:',

View file

@ -1,9 +1,8 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis'); const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const { removePorts, isEnabled } = require('~/server/utils'); const { removePorts, isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis'); const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache'); const { logViolation } = require('~/cache');
const { logger } = require('~/config'); const { logger } = require('~/config');
@ -36,13 +35,10 @@ const limiterOptions = {
keyGenerator: removePorts, keyGenerator: removePorts,
}; };
if (isEnabled(process.env.USE_REDIS)) { if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for verify email rate limiter.'); logger.debug('Using Redis for verify email rate limiter.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const store = new RedisStore({ const store = new RedisStore({
sendCommand, sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'verify_email_limiter:', prefix: 'verify_email_limiter:',
}); });
limiterOptions.store = store; limiterOptions.store = store;

View file

@ -58,7 +58,7 @@ router.post('/:agent_id', async (req, res) => {
} }
let { domain } = metadata; let { domain } = metadata;
domain = await domainParser(req, domain, true); domain = await domainParser(domain, true);
if (!domain) { if (!domain) {
return res.status(400).json({ message: 'No domain provided' }); return res.status(400).json({ message: 'No domain provided' });
@ -164,7 +164,7 @@ router.delete('/:agent_id/:action_id', async (req, res) => {
return true; return true;
}); });
domain = await domainParser(req, domain, true); domain = await domainParser(domain, true);
if (!domain) { if (!domain) {
return res.status(400).json({ message: 'No domain provided' }); return res.status(400).json({ message: 'No domain provided' });

View file

@ -2,7 +2,6 @@ const express = require('express');
const { PermissionTypes, Permissions } = require('librechat-data-provider'); const { PermissionTypes, Permissions } = require('librechat-data-provider');
const { const {
setHeaders, setHeaders,
handleAbort,
moderateText, moderateText,
// validateModel, // validateModel,
generateCheckAccess, generateCheckAccess,
@ -16,7 +15,6 @@ const addTitle = require('~/server/services/Endpoints/agents/title');
const router = express.Router(); const router = express.Router();
router.use(moderateText); router.use(moderateText);
router.post('/abort', handleAbort());
const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]); const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);

View file

@ -1,4 +1,4 @@
const Keyv = require('keyv'); const { Keyv } = require('keyv');
const { KeyvFile } = require('keyv-file'); const { KeyvFile } = require('keyv-file');
const { logger } = require('~/config'); const { logger } = require('~/config');

View file

@ -11,8 +11,6 @@ const {
const router = express.Router(); const router = express.Router();
router.post('/abort', handleAbort());
router.post( router.post(
'/', '/',
validateEndpoint, validateEndpoint,

View file

@ -3,7 +3,6 @@ const AskController = require('~/server/controllers/AskController');
const { initializeClient } = require('~/server/services/Endpoints/custom'); const { initializeClient } = require('~/server/services/Endpoints/custom');
const { addTitle } = require('~/server/services/Endpoints/openAI'); const { addTitle } = require('~/server/services/Endpoints/openAI');
const { const {
handleAbort,
setHeaders, setHeaders,
validateModel, validateModel,
validateEndpoint, validateEndpoint,
@ -12,8 +11,6 @@ const {
const router = express.Router(); const router = express.Router();
router.post('/abort', handleAbort());
router.post( router.post(
'/', '/',
validateEndpoint, validateEndpoint,

View file

@ -3,7 +3,6 @@ const AskController = require('~/server/controllers/AskController');
const { initializeClient, addTitle } = require('~/server/services/Endpoints/google'); const { initializeClient, addTitle } = require('~/server/services/Endpoints/google');
const { const {
setHeaders, setHeaders,
handleAbort,
validateModel, validateModel,
validateEndpoint, validateEndpoint,
buildEndpointOption, buildEndpointOption,
@ -11,8 +10,6 @@ const {
const router = express.Router(); const router = express.Router();
router.post('/abort', handleAbort());
router.post( router.post(
'/', '/',
validateEndpoint, validateEndpoint,

View file

@ -20,7 +20,6 @@ const { logger } = require('~/config');
const router = express.Router(); const router = express.Router();
router.use(moderateText); router.use(moderateText);
router.post('/abort', handleAbort());
router.post( router.post(
'/', '/',
@ -196,7 +195,8 @@ router.post(
logger.debug('[/ask/gptPlugins]', response); logger.debug('[/ask/gptPlugins]', response);
const { conversation = {} } = await client.responsePromise; const { conversation = {} } = await response.databasePromise;
delete response.databasePromise;
conversation.title = conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat'; conversation && !conversation.title ? null : conversation?.title || 'New Chat';

View file

@ -12,7 +12,6 @@ const {
const router = express.Router(); const router = express.Router();
router.use(moderateText); router.use(moderateText);
router.post('/abort', handleAbort());
router.post( router.post(
'/', '/',

View file

@ -36,7 +36,7 @@ router.post('/:assistant_id', async (req, res) => {
} }
let { domain } = metadata; let { domain } = metadata;
domain = await domainParser(req, domain, true); domain = await domainParser(domain, true);
if (!domain) { if (!domain) {
return res.status(400).json({ message: 'No domain provided' }); return res.status(400).json({ message: 'No domain provided' });
@ -172,7 +172,7 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => {
return true; return true;
}); });
domain = await domainParser(req, domain, true); domain = await domainParser(domain, true);
if (!domain) { if (!domain) {
return res.status(400).json({ message: 'No domain provided' }); return res.status(400).json({ message: 'No domain provided' });

View file

@ -14,7 +14,6 @@ const AgentController = require('~/server/controllers/agents/request');
const addTitle = require('~/server/services/Endpoints/agents/title'); const addTitle = require('~/server/services/Endpoints/agents/title');
router.use(moderateText); router.use(moderateText);
router.post('/abort', handleAbort());
/** /**
* @route POST / * @route POST /

View file

@ -3,7 +3,6 @@ const EditController = require('~/server/controllers/EditController');
const { initializeClient } = require('~/server/services/Endpoints/anthropic'); const { initializeClient } = require('~/server/services/Endpoints/anthropic');
const { const {
setHeaders, setHeaders,
handleAbort,
validateModel, validateModel,
validateEndpoint, validateEndpoint,
buildEndpointOption, buildEndpointOption,
@ -11,8 +10,6 @@ const {
const router = express.Router(); const router = express.Router();
router.post('/abort', handleAbort());
router.post( router.post(
'/', '/',
validateEndpoint, validateEndpoint,

View file

@ -12,8 +12,6 @@ const {
const router = express.Router(); const router = express.Router();
router.post('/abort', handleAbort());
router.post( router.post(
'/', '/',
validateEndpoint, validateEndpoint,

View file

@ -3,7 +3,6 @@ const EditController = require('~/server/controllers/EditController');
const { initializeClient } = require('~/server/services/Endpoints/google'); const { initializeClient } = require('~/server/services/Endpoints/google');
const { const {
setHeaders, setHeaders,
handleAbort,
validateModel, validateModel,
validateEndpoint, validateEndpoint,
buildEndpointOption, buildEndpointOption,
@ -11,8 +10,6 @@ const {
const router = express.Router(); const router = express.Router();
router.post('/abort', handleAbort());
router.post( router.post(
'/', '/',
validateEndpoint, validateEndpoint,

View file

@ -2,7 +2,6 @@ const express = require('express');
const { getResponseSender } = require('librechat-data-provider'); const { getResponseSender } = require('librechat-data-provider');
const { const {
setHeaders, setHeaders,
handleAbort,
moderateText, moderateText,
validateModel, validateModel,
handleAbortError, handleAbortError,
@ -19,7 +18,6 @@ const { logger } = require('~/config');
const router = express.Router(); const router = express.Router();
router.use(moderateText); router.use(moderateText);
router.post('/abort', handleAbort());
router.post( router.post(
'/', '/',
@ -173,7 +171,8 @@ router.post(
logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response); logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response);
const { conversation = {} } = await client.responsePromise; const { conversation = {} } = await response.databasePromise;
delete response.databasePromise;
conversation.title = conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat'; conversation && !conversation.title ? null : conversation?.title || 'New Chat';

View file

@ -2,7 +2,6 @@ const express = require('express');
const EditController = require('~/server/controllers/EditController'); const EditController = require('~/server/controllers/EditController');
const { initializeClient } = require('~/server/services/Endpoints/openAI'); const { initializeClient } = require('~/server/services/Endpoints/openAI');
const { const {
handleAbort,
setHeaders, setHeaders,
validateModel, validateModel,
validateEndpoint, validateEndpoint,
@ -12,7 +11,6 @@ const {
const router = express.Router(); const router = express.Router();
router.use(moderateText); router.use(moderateText);
router.post('/abort', handleAbort());
router.post( router.post(
'/', '/',

View file

@ -1,4 +1,4 @@
const Keyv = require('keyv'); const { Keyv } = require('keyv');
const express = require('express'); const express = require('express');
const { MeiliSearch } = require('meilisearch'); const { MeiliSearch } = require('meilisearch');
const { Conversation, getConvosQueried } = require('~/models/Conversation'); const { Conversation, getConvosQueried } = require('~/models/Conversation');

View file

@ -50,7 +50,7 @@ const validateAndUpdateTool = async ({ req, tool, assistant_id }) => {
return null; return null;
} }
const parsedDomain = await domainParser(req, domain, true); const parsedDomain = await domainParser(domain, true);
if (!parsedDomain) { if (!parsedDomain) {
return null; return null;
@ -66,12 +66,11 @@ const validateAndUpdateTool = async ({ req, tool, assistant_id }) => {
* *
* Necessary due to `[a-zA-Z0-9_-]*` Regex Validation, limited to a 64-character maximum. * Necessary due to `[a-zA-Z0-9_-]*` Regex Validation, limited to a 64-character maximum.
* *
* @param {Express.Request} req - The Express Request object.
* @param {string} domain - The domain name to encode/decode. * @param {string} domain - The domain name to encode/decode.
* @param {boolean} inverse - False to decode from base64, true to encode to base64. * @param {boolean} inverse - False to decode from base64, true to encode to base64.
* @returns {Promise<string>} Encoded or decoded domain string. * @returns {Promise<string>} Encoded or decoded domain string.
*/ */
async function domainParser(req, domain, inverse = false) { async function domainParser(domain, inverse = false) {
if (!domain) { if (!domain) {
return; return;
} }
@ -122,7 +121,7 @@ async function loadActionSets(searchParams) {
* Creates a general tool for an entire action set. * Creates a general tool for an entire action set.
* *
* @param {Object} params - The parameters for loading action sets. * @param {Object} params - The parameters for loading action sets.
* @param {ServerRequest} params.req * @param {string} params.userId
* @param {ServerResponse} params.res * @param {ServerResponse} params.res
* @param {Action} params.action - The action set. Necessary for decrypting authentication values. * @param {Action} params.action - The action set. Necessary for decrypting authentication values.
* @param {ActionRequest} params.requestBuilder - The ActionRequest builder class to execute the API call. * @param {ActionRequest} params.requestBuilder - The ActionRequest builder class to execute the API call.
@ -133,7 +132,7 @@ async function loadActionSets(searchParams) {
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input. * @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
*/ */
async function createActionTool({ async function createActionTool({
req, userId,
res, res,
action, action,
requestBuilder, requestBuilder,
@ -154,7 +153,7 @@ async function createActionTool({
try { try {
if (metadata.auth.type === AuthTypeEnum.OAuth && metadata.auth.authorization_url) { if (metadata.auth.type === AuthTypeEnum.OAuth && metadata.auth.authorization_url) {
const action_id = action.action_id; const action_id = action.action_id;
const identifier = `${req.user.id}:${action.action_id}`; const identifier = `${userId}:${action.action_id}`;
const requestLogin = async () => { const requestLogin = async () => {
const { args: _args, stepId, ...toolCall } = config.toolCall ?? {}; const { args: _args, stepId, ...toolCall } = config.toolCall ?? {};
if (!stepId) { if (!stepId) {
@ -162,7 +161,7 @@ async function createActionTool({
} }
const statePayload = { const statePayload = {
nonce: nanoid(), nonce: nanoid(),
user: req.user.id, user: userId,
action_id, action_id,
}; };
@ -206,7 +205,7 @@ async function createActionTool({
'oauth', 'oauth',
{ {
state: stateToken, state: stateToken,
userId: req.user.id, userId: userId,
client_url: metadata.auth.client_url, client_url: metadata.auth.client_url,
redirect_uri: `${process.env.DOMAIN_CLIENT}/api/actions/${action_id}/oauth/callback`, redirect_uri: `${process.env.DOMAIN_CLIENT}/api/actions/${action_id}/oauth/callback`,
/** Encrypted values */ /** Encrypted values */
@ -232,10 +231,10 @@ async function createActionTool({
}; };
const tokenPromises = []; const tokenPromises = [];
tokenPromises.push(findToken({ userId: req.user.id, type: 'oauth', identifier })); tokenPromises.push(findToken({ userId, type: 'oauth', identifier }));
tokenPromises.push( tokenPromises.push(
findToken({ findToken({
userId: req.user.id, userId,
type: 'oauth_refresh', type: 'oauth_refresh',
identifier: `${identifier}:refresh`, identifier: `${identifier}:refresh`,
}), }),
@ -258,9 +257,9 @@ async function createActionTool({
const refresh_token = await decryptV2(refreshTokenData.token); const refresh_token = await decryptV2(refreshTokenData.token);
const refreshTokens = async () => const refreshTokens = async () =>
await refreshAccessToken({ await refreshAccessToken({
userId,
identifier, identifier,
refresh_token, refresh_token,
userId: req.user.id,
client_url: metadata.auth.client_url, client_url: metadata.auth.client_url,
encrypted_oauth_client_id: encrypted.oauth_client_id, encrypted_oauth_client_id: encrypted.oauth_client_id,
encrypted_oauth_client_secret: encrypted.oauth_client_secret, encrypted_oauth_client_secret: encrypted.oauth_client_secret,

View file

@ -78,20 +78,20 @@ describe('domainParser', () => {
// Non-azure request // Non-azure request
it('does not return domain as is if not azure', async () => { it('does not return domain as is if not azure', async () => {
const domain = `example.com${actionDomainSeparator}test${actionDomainSeparator}`; const domain = `example.com${actionDomainSeparator}test${actionDomainSeparator}`;
const result1 = await domainParser(reqNoAzure, domain, false); const result1 = await domainParser(domain, false);
const result2 = await domainParser(reqNoAzure, domain, true); const result2 = await domainParser(domain, true);
expect(result1).not.toEqual(domain); expect(result1).not.toEqual(domain);
expect(result2).not.toEqual(domain); expect(result2).not.toEqual(domain);
}); });
// Test for Empty or Null Inputs // Test for Empty or Null Inputs
it('returns undefined for null domain input', async () => { it('returns undefined for null domain input', async () => {
const result = await domainParser(req, null, true); const result = await domainParser(null, true);
expect(result).toBeUndefined(); expect(result).toBeUndefined();
}); });
it('returns undefined for empty domain input', async () => { it('returns undefined for empty domain input', async () => {
const result = await domainParser(req, '', true); const result = await domainParser('', true);
expect(result).toBeUndefined(); expect(result).toBeUndefined();
}); });
@ -102,7 +102,7 @@ describe('domainParser', () => {
.toString('base64') .toString('base64')
.substring(0, Constants.ENCODED_DOMAIN_LENGTH); .substring(0, Constants.ENCODED_DOMAIN_LENGTH);
await domainParser(req, domain, true); await domainParser(domain, true);
const cachedValue = await globalCache[encodedDomain]; const cachedValue = await globalCache[encodedDomain];
expect(cachedValue).toEqual(Buffer.from(domain).toString('base64')); expect(cachedValue).toEqual(Buffer.from(domain).toString('base64'));
@ -112,14 +112,14 @@ describe('domainParser', () => {
it('encodes domain exactly at threshold without modification', async () => { it('encodes domain exactly at threshold without modification', async () => {
const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - TLD.length) + TLD; const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - TLD.length) + TLD;
const expected = domain.replace(/\./g, actionDomainSeparator); const expected = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, domain, true); const result = await domainParser(domain, true);
expect(result).toEqual(expected); expect(result).toEqual(expected);
}); });
it('encodes domain just below threshold without modification', async () => { it('encodes domain just below threshold without modification', async () => {
const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - 1 - TLD.length) + TLD; const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - 1 - TLD.length) + TLD;
const expected = domain.replace(/\./g, actionDomainSeparator); const expected = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, domain, true); const result = await domainParser(domain, true);
expect(result).toEqual(expected); expect(result).toEqual(expected);
}); });
@ -129,7 +129,7 @@ describe('domainParser', () => {
const encodedDomain = Buffer.from(unicodeDomain) const encodedDomain = Buffer.from(unicodeDomain)
.toString('base64') .toString('base64')
.substring(0, Constants.ENCODED_DOMAIN_LENGTH); .substring(0, Constants.ENCODED_DOMAIN_LENGTH);
const result = await domainParser(req, unicodeDomain, true); const result = await domainParser(unicodeDomain, true);
expect(result).toEqual(encodedDomain); expect(result).toEqual(encodedDomain);
}); });
@ -139,7 +139,6 @@ describe('domainParser', () => {
globalCache[encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH)] = encodedDomain; // Simulate caching globalCache[encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH)] = encodedDomain; // Simulate caching
const result = await domainParser( const result = await domainParser(
req,
encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH), encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH),
false, false,
); );
@ -150,27 +149,27 @@ describe('domainParser', () => {
it('returns domain with replaced separators if no cached domain exists', async () => { it('returns domain with replaced separators if no cached domain exists', async () => {
const domain = 'example.com'; const domain = 'example.com';
const withSeparator = domain.replace(/\./g, actionDomainSeparator); const withSeparator = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, withSeparator, false); const result = await domainParser(withSeparator, false);
expect(result).toEqual(domain); expect(result).toEqual(domain);
}); });
it('returns domain with replaced separators when inverse is false and under encoding length', async () => { it('returns domain with replaced separators when inverse is false and under encoding length', async () => {
const domain = 'examp.com'; const domain = 'examp.com';
const withSeparator = domain.replace(/\./g, actionDomainSeparator); const withSeparator = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, withSeparator, false); const result = await domainParser(withSeparator, false);
expect(result).toEqual(domain); expect(result).toEqual(domain);
}); });
it('replaces periods with actionDomainSeparator when inverse is true and under encoding length', async () => { it('replaces periods with actionDomainSeparator when inverse is true and under encoding length', async () => {
const domain = 'examp.com'; const domain = 'examp.com';
const expected = domain.replace(/\./g, actionDomainSeparator); const expected = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, domain, true); const result = await domainParser(domain, true);
expect(result).toEqual(expected); expect(result).toEqual(expected);
}); });
it('encodes domain when length is above threshold and inverse is true', async () => { it('encodes domain when length is above threshold and inverse is true', async () => {
const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH + 1).concat('.com'); const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH + 1).concat('.com');
const result = await domainParser(req, domain, true); const result = await domainParser(domain, true);
expect(result).not.toEqual(domain); expect(result).not.toEqual(domain);
expect(result.length).toBeLessThanOrEqual(Constants.ENCODED_DOMAIN_LENGTH); expect(result.length).toBeLessThanOrEqual(Constants.ENCODED_DOMAIN_LENGTH);
}); });
@ -180,20 +179,20 @@ describe('domainParser', () => {
const encodedDomain = Buffer.from( const encodedDomain = Buffer.from(
originalDomain.replace(/\./g, actionDomainSeparator), originalDomain.replace(/\./g, actionDomainSeparator),
).toString('base64'); ).toString('base64');
const result = await domainParser(req, encodedDomain, false); const result = await domainParser(encodedDomain, false);
expect(result).toEqual(encodedDomain); expect(result).toEqual(encodedDomain);
}); });
it('decodes encoded value if cached and encoded value is provided, and inverse is false', async () => { it('decodes encoded value if cached and encoded value is provided, and inverse is false', async () => {
const originalDomain = 'example.com'; const originalDomain = 'example.com';
const encodedDomain = await domainParser(req, originalDomain, true); const encodedDomain = await domainParser(originalDomain, true);
const result = await domainParser(req, encodedDomain, false); const result = await domainParser(encodedDomain, false);
expect(result).toEqual(originalDomain); expect(result).toEqual(originalDomain);
}); });
it('handles invalid base64 encoded values gracefully', async () => { it('handles invalid base64 encoded values gracefully', async () => {
const invalidBase64Domain = 'not_base64_encoded'; const invalidBase64Domain = 'not_base64_encoded';
const result = await domainParser(req, invalidBase64Domain, false); const result = await domainParser(invalidBase64Domain, false);
expect(result).toEqual(invalidBase64Domain); expect(result).toEqual(invalidBase64Domain);
}); });
}); });

View file

@ -159,14 +159,20 @@ const initializeAgentOptions = async ({
currentFiles, currentFiles,
agent.tool_resources, agent.tool_resources,
); );
const provider = agent.provider;
const { tools, toolContextMap } = await loadAgentTools({ const { tools, toolContextMap } = await loadAgentTools({
req, req,
res, res,
agent, agent: {
id: agent.id,
tools: agent.tools,
provider,
model: agent.model,
},
tool_resources, tool_resources,
}); });
const provider = agent.provider;
agent.endpoint = provider; agent.endpoint = provider;
let getOptions = providerConfigMap[provider]; let getOptions = providerConfigMap[provider];
if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) { if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) {

View file

@ -2,7 +2,11 @@ const { CacheKeys } = require('librechat-data-provider');
const getLogStores = require('~/cache/getLogStores'); const getLogStores = require('~/cache/getLogStores');
const { isEnabled } = require('~/server/utils'); const { isEnabled } = require('~/server/utils');
const { saveConvo } = require('~/models'); const { saveConvo } = require('~/models');
const { logger } = require('~/config');
/**
* Add title to conversation in a way that avoids memory retention
*/
const addTitle = async (req, { text, response, client }) => { const addTitle = async (req, { text, response, client }) => {
const { TITLE_CONVO = true } = process.env ?? {}; const { TITLE_CONVO = true } = process.env ?? {};
if (!isEnabled(TITLE_CONVO)) { if (!isEnabled(TITLE_CONVO)) {
@ -13,28 +17,43 @@ const addTitle = async (req, { text, response, client }) => {
return; return;
} }
// If the request was aborted, don't generate the title. const titleCache = getLogStores(CacheKeys.GEN_TITLE);
if (client.abortController.signal.aborted) { const key = `${req.user.id}-${response.conversationId}`;
/** @type {NodeJS.Timeout} */
let timeoutId;
try {
const timeoutPromise = new Promise((_, reject) => {
timeoutId = setTimeout(() => reject(new Error('Title generation timeout')), 25000);
}).catch((error) => {
logger.error('Title error:', error);
});
let titlePromise;
let abortController = new AbortController();
if (client && typeof client.titleConvo === 'function') {
titlePromise = Promise.race([
client
.titleConvo({
text,
abortController,
})
.catch((error) => {
logger.error('Client title error:', error);
}),
timeoutPromise,
]);
} else {
return; return;
} }
const titleCache = getLogStores(CacheKeys.GEN_TITLE); const title = await titlePromise;
const key = `${req.user.id}-${response.conversationId}`; if (!abortController.signal.aborted) {
const responseText = abortController.abort();
response?.content && Array.isArray(response?.content) }
? response.content.reduce((acc, block) => { if (timeoutId) {
if (block?.type === 'text') { clearTimeout(timeoutId);
return acc + block.text;
} }
return acc;
}, '')
: (response?.content ?? response?.text ?? '');
const title = await client.titleConvo({
text,
responseText,
conversationId: response.conversationId,
});
await titleCache.set(key, title, 120000); await titleCache.set(key, title, 120000);
await saveConvo( await saveConvo(
req, req,
@ -44,6 +63,9 @@ const addTitle = async (req, { text, response, client }) => {
}, },
{ context: 'api/server/services/Endpoints/agents/title.js' }, { context: 'api/server/services/Endpoints/agents/title.js' },
); );
} catch (error) {
logger.error('Error generating title:', error);
}
}; };
module.exports = addTitle; module.exports = addTitle;

View file

@ -1,7 +1,7 @@
const { EModelEndpoint } = require('librechat-data-provider'); const { EModelEndpoint } = require('librechat-data-provider');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm'); const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
const { AnthropicClient } = require('~/app'); const AnthropicClient = require('~/app/clients/AnthropicClient');
const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => { const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => {
const { ANTHROPIC_API_KEY, ANTHROPIC_REVERSE_PROXY, PROXY } = process.env; const { ANTHROPIC_API_KEY, ANTHROPIC_REVERSE_PROXY, PROXY } = process.env;

View file

@ -13,11 +13,6 @@ const addTitle = async (req, { text, response, client }) => {
return; return;
} }
// If the request was aborted, don't generate the title.
if (client.abortController.signal.aborted) {
return;
}
const titleCache = getLogStores(CacheKeys.GEN_TITLE); const titleCache = getLogStores(CacheKeys.GEN_TITLE);
const key = `${req.user.id}-${response.conversationId}`; const key = `${req.user.id}-${response.conversationId}`;

View file

@ -11,8 +11,8 @@ const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm');
const { getCustomEndpointConfig } = require('~/server/services/Config'); const { getCustomEndpointConfig } = require('~/server/services/Config');
const { fetchModels } = require('~/server/services/ModelService'); const { fetchModels } = require('~/server/services/ModelService');
const { isUserProvided, sleep } = require('~/server/utils'); const { isUserProvided, sleep } = require('~/server/utils');
const OpenAIClient = require('~/app/clients/OpenAIClient');
const getLogStores = require('~/cache/getLogStores'); const getLogStores = require('~/cache/getLogStores');
const { OpenAIClient } = require('~/app');
const { PROXY } = process.env; const { PROXY } = process.env;

View file

@ -7,8 +7,14 @@ const {
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService'); const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm'); const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm');
const { isEnabled, isUserProvided, sleep } = require('~/server/utils'); const { isEnabled, isUserProvided, sleep } = require('~/server/utils');
const OpenAIClient = require('~/app/clients/OpenAIClient');
const { getAzureCredentials } = require('~/utils'); const { getAzureCredentials } = require('~/utils');
const { OpenAIClient } = require('~/app');
function createHandleNewToken(streamRate) {
async () => {
await sleep(streamRate);
};
}
const initializeClient = async ({ const initializeClient = async ({
req, req,
@ -140,14 +146,13 @@ const initializeClient = async ({
clientOptions = Object.assign({ modelOptions }, clientOptions); clientOptions = Object.assign({ modelOptions }, clientOptions);
clientOptions.modelOptions.user = req.user.id; clientOptions.modelOptions.user = req.user.id;
const options = getLLMConfig(apiKey, clientOptions); const options = getLLMConfig(apiKey, clientOptions);
if (!clientOptions.streamRate) { const streamRate = clientOptions.streamRate;
if (!streamRate) {
return options; return options;
} }
options.llmConfig.callbacks = [ options.llmConfig.callbacks = [
{ {
handleLLMNewToken: async () => { handleLLMNewToken: createHandleNewToken(streamRate),
await sleep(clientOptions.streamRate);
},
}, },
]; ];
return options; return options;

View file

@ -13,11 +13,6 @@ const addTitle = async (req, { text, response, client }) => {
return; return;
} }
// If the request was aborted and is not azure, don't generate the title.
if (!client.azure && client.abortController.signal.aborted) {
return;
}
const titleCache = getLogStores(CacheKeys.GEN_TITLE); const titleCache = getLogStores(CacheKeys.GEN_TITLE);
const key = `${req.user.id}-${response.conversationId}`; const key = `${req.user.id}-${response.conversationId}`;

View file

@ -37,9 +37,8 @@ async function createMCPTool({ req, toolKey, provider }) {
} }
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter); const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
const userId = req.user?.id;
if (!userId) { if (!req.user?.id) {
logger.error( logger.error(
`[MCP][${serverName}][${toolName}] User ID not found on request. Cannot create tool.`, `[MCP][${serverName}][${toolName}] User ID not found on request. Cannot create tool.`,
); );
@ -49,15 +48,16 @@ async function createMCPTool({ req, toolKey, provider }) {
/** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise<unknown>} */ /** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise<unknown>} */
const _call = async (toolArguments, config) => { const _call = async (toolArguments, config) => {
try { try {
const mcpManager = getMCPManager(); const derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined;
const mcpManager = getMCPManager(config?.userId);
const result = await mcpManager.callTool({ const result = await mcpManager.callTool({
serverName, serverName,
toolName, toolName,
provider, provider,
toolArguments, toolArguments,
options: { options: {
userId, userId: config?.configurable?.user_id,
signal: config?.signal, signal: derivedSignal,
}, },
}); });
@ -70,7 +70,7 @@ async function createMCPTool({ req, toolKey, provider }) {
return result; return result;
} catch (error) { } catch (error) {
logger.error( logger.error(
`[MCP][User: ${userId}][${serverName}] Error calling "${toolName}" MCP tool:`, `[MCP][User: ${config?.userId}][${serverName}] Error calling "${toolName}" MCP tool:`,
error, error,
); );
throw new Error( throw new Error(

View file

@ -334,7 +334,7 @@ async function processRequiredActions(client, requiredActions) {
const domainMap = new Map(); const domainMap = new Map();
for (const action of actionSets) { for (const action of actionSets) {
const domain = await domainParser(client.req, action.metadata.domain, true); const domain = await domainParser(action.metadata.domain, true);
domainMap.set(domain, action); domainMap.set(domain, action);
// Check if domain is allowed // Check if domain is allowed
@ -404,7 +404,7 @@ async function processRequiredActions(client, requiredActions) {
// We've already decrypted the metadata, so we can pass it directly // We've already decrypted the metadata, so we can pass it directly
tool = await createActionTool({ tool = await createActionTool({
req: client.req, userId: client.req.user.id,
res: client.res, res: client.res,
action, action,
requestBuilder, requestBuilder,
@ -458,7 +458,7 @@ async function processRequiredActions(client, requiredActions) {
* @param {Object} params - Run params containing user and request information. * @param {Object} params - Run params containing user and request information.
* @param {ServerRequest} params.req - The request object. * @param {ServerRequest} params.req - The request object.
* @param {ServerResponse} params.res - The request object. * @param {ServerResponse} params.res - The request object.
* @param {Agent} params.agent - The agent to load tools for. * @param {Pick<Agent, 'id' | 'provider' | 'model' | 'tools'>} params.agent - The agent to load tools for.
* @param {string | undefined} [params.openAIApiKey] - The OpenAI API key. * @param {string | undefined} [params.openAIApiKey] - The OpenAI API key.
* @returns {Promise<{ tools?: StructuredTool[] }>} The agent tools. * @returns {Promise<{ tools?: StructuredTool[] }>} The agent tools.
*/ */
@ -570,7 +570,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
const domainMap = new Map(); const domainMap = new Map();
for (const action of actionSets) { for (const action of actionSets) {
const domain = await domainParser(req, action.metadata.domain, true); const domain = await domainParser(action.metadata.domain, true);
domainMap.set(domain, action); domainMap.set(domain, action);
// Check if domain is allowed (do this once per action set) // Check if domain is allowed (do this once per action set)
@ -639,7 +639,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
if (requestBuilder) { if (requestBuilder) {
const tool = await createActionTool({ const tool = await createActionTool({
req, userId: req.user.id,
res, res,
action, action,
requestBuilder, requestBuilder,

View file

@ -1,4 +1,4 @@
const Keyv = require('keyv'); const { Keyv } = require('keyv');
const passport = require('passport'); const passport = require('passport');
const session = require('express-session'); const session = require('express-session');
const MemoryStore = require('memorystore')(session); const MemoryStore = require('memorystore')(session);
@ -53,7 +53,7 @@ const configureSocialLogins = (app) => {
if (isEnabled(process.env.USE_REDIS)) { if (isEnabled(process.env.USE_REDIS)) {
logger.debug('Using Redis for session storage in OpenID...'); logger.debug('Using Redis for session storage in OpenID...');
const keyv = new Keyv({ store: keyvRedis }); const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis; const client = keyv.opts.store.client;
sessionOptions.store = new RedisStore({ client, prefix: 'openid_session' }); sessionOptions.store = new RedisStore({ client, prefix: 'openid_session' });
} else { } else {
sessionOptions.store = new MemoryStore({ sessionOptions.store = new MemoryStore({

View file

@ -1,6 +1,3 @@
const mockGet = jest.fn();
const mockSet = jest.fn();
jest.mock('@keyv/mongo', () => { jest.mock('@keyv/mongo', () => {
const EventEmitter = require('events'); const EventEmitter = require('events');
class KeyvMongo extends EventEmitter { class KeyvMongo extends EventEmitter {
@ -20,11 +17,32 @@ jest.mock('@keyv/mongo', () => {
...url, ...url,
...options, ...options,
}; };
// In-memory store for tests
this.store = new Map();
} }
get = mockGet; async get(key) {
set = mockSet; return this.store.get(key);
} }
return KeyvMongo; async set(key, value, ttl) {
this.store.set(key, value);
return true;
}
async delete(key) {
return this.store.delete(key);
}
async clear() {
this.store.clear();
return true;
}
}
// Create a store factory function for the test suite
const store = () => new KeyvMongo();
return { KeyvMongo };
}); });

View file

@ -196,6 +196,7 @@ const bedrockModels = {
}; };
const xAIModels = { const xAIModels = {
grok: 131072,
'grok-beta': 131072, 'grok-beta': 131072,
'grok-vision-beta': 8192, 'grok-vision-beta': 8192,
'grok-2': 131072, 'grok-2': 131072,
@ -204,6 +205,10 @@ const xAIModels = {
'grok-2-vision': 32768, 'grok-2-vision': 32768,
'grok-2-vision-latest': 32768, 'grok-2-vision-latest': 32768,
'grok-2-vision-1212': 32768, 'grok-2-vision-1212': 32768,
'grok-3': 131072,
'grok-3-fast': 131072,
'grok-3-mini': 131072,
'grok-3-mini-fast': 131072,
}; };
const aggregateModels = { ...openAIModels, ...googleModels, ...bedrockModels, ...xAIModels }; const aggregateModels = { ...openAIModels, ...googleModels, ...bedrockModels, ...xAIModels };

View file

@ -517,18 +517,30 @@ describe('Grok Model Tests - Tokens', () => {
expect(getModelMaxTokens('grok-2-latest')).toBe(131072); expect(getModelMaxTokens('grok-2-latest')).toBe(131072);
}); });
test('should return correct tokens for Grok 3 series models', () => {
expect(getModelMaxTokens('grok-3')).toBe(131072);
expect(getModelMaxTokens('grok-3-fast')).toBe(131072);
expect(getModelMaxTokens('grok-3-mini')).toBe(131072);
expect(getModelMaxTokens('grok-3-mini-fast')).toBe(131072);
});
test('should handle partial matches for Grok models with prefixes', () => { test('should handle partial matches for Grok models with prefixes', () => {
// Vision models should match before general models // Vision models should match before general models
expect(getModelMaxTokens('openai/grok-2-vision-1212')).toBe(32768); expect(getModelMaxTokens('xai/grok-2-vision-1212')).toBe(32768);
expect(getModelMaxTokens('openai/grok-2-vision')).toBe(32768); expect(getModelMaxTokens('xai/grok-2-vision')).toBe(32768);
expect(getModelMaxTokens('openai/grok-2-vision-latest')).toBe(32768); expect(getModelMaxTokens('xai/grok-2-vision-latest')).toBe(32768);
// Beta models // Beta models
expect(getModelMaxTokens('openai/grok-vision-beta')).toBe(8192); expect(getModelMaxTokens('xai/grok-vision-beta')).toBe(8192);
expect(getModelMaxTokens('openai/grok-beta')).toBe(131072); expect(getModelMaxTokens('xai/grok-beta')).toBe(131072);
// Text models // Text models
expect(getModelMaxTokens('openai/grok-2-1212')).toBe(131072); expect(getModelMaxTokens('xai/grok-2-1212')).toBe(131072);
expect(getModelMaxTokens('openai/grok-2')).toBe(131072); expect(getModelMaxTokens('xai/grok-2')).toBe(131072);
expect(getModelMaxTokens('openai/grok-2-latest')).toBe(131072); expect(getModelMaxTokens('xai/grok-2-latest')).toBe(131072);
// Grok 3 models
expect(getModelMaxTokens('xai/grok-3')).toBe(131072);
expect(getModelMaxTokens('xai/grok-3-fast')).toBe(131072);
expect(getModelMaxTokens('xai/grok-3-mini')).toBe(131072);
expect(getModelMaxTokens('xai/grok-3-mini-fast')).toBe(131072);
}); });
}); });
@ -545,20 +557,30 @@ describe('Grok Model Tests - Tokens', () => {
expect(matchModelName('grok-2-1212')).toBe('grok-2-1212'); expect(matchModelName('grok-2-1212')).toBe('grok-2-1212');
expect(matchModelName('grok-2')).toBe('grok-2'); expect(matchModelName('grok-2')).toBe('grok-2');
expect(matchModelName('grok-2-latest')).toBe('grok-2-latest'); expect(matchModelName('grok-2-latest')).toBe('grok-2-latest');
// Grok 3 models
expect(matchModelName('grok-3')).toBe('grok-3');
expect(matchModelName('grok-3-fast')).toBe('grok-3-fast');
expect(matchModelName('grok-3-mini')).toBe('grok-3-mini');
expect(matchModelName('grok-3-mini-fast')).toBe('grok-3-mini-fast');
}); });
test('should match Grok model variations with prefixes', () => { test('should match Grok model variations with prefixes', () => {
// Vision models should match before general models // Vision models should match before general models
expect(matchModelName('openai/grok-2-vision-1212')).toBe('grok-2-vision-1212'); expect(matchModelName('xai/grok-2-vision-1212')).toBe('grok-2-vision-1212');
expect(matchModelName('openai/grok-2-vision')).toBe('grok-2-vision'); expect(matchModelName('xai/grok-2-vision')).toBe('grok-2-vision');
expect(matchModelName('openai/grok-2-vision-latest')).toBe('grok-2-vision-latest'); expect(matchModelName('xai/grok-2-vision-latest')).toBe('grok-2-vision-latest');
// Beta models // Beta models
expect(matchModelName('openai/grok-vision-beta')).toBe('grok-vision-beta'); expect(matchModelName('xai/grok-vision-beta')).toBe('grok-vision-beta');
expect(matchModelName('openai/grok-beta')).toBe('grok-beta'); expect(matchModelName('xai/grok-beta')).toBe('grok-beta');
// Text models // Text models
expect(matchModelName('openai/grok-2-1212')).toBe('grok-2-1212'); expect(matchModelName('xai/grok-2-1212')).toBe('grok-2-1212');
expect(matchModelName('openai/grok-2')).toBe('grok-2'); expect(matchModelName('xai/grok-2')).toBe('grok-2');
expect(matchModelName('openai/grok-2-latest')).toBe('grok-2-latest'); expect(matchModelName('xai/grok-2-latest')).toBe('grok-2-latest');
// Grok 3 models
expect(matchModelName('xai/grok-3')).toBe('grok-3');
expect(matchModelName('xai/grok-3-fast')).toBe('grok-3-fast');
expect(matchModelName('xai/grok-3-mini')).toBe('grok-3-mini');
expect(matchModelName('xai/grok-3-mini-fast')).toBe('grok-3-mini-fast');
}); });
}); });
}); });

View file

@ -11,6 +11,7 @@ import {
tMessageSchema, tMessageSchema,
tConvoUpdateSchema, tConvoUpdateSchema,
ContentTypes, ContentTypes,
isAssistantsEndpoint,
} from 'librechat-data-provider'; } from 'librechat-data-provider';
import type { import type {
TMessage, TMessage,
@ -622,6 +623,17 @@ export default function useEventHandlers({
const { endpoint: _endpoint, endpointType } = const { endpoint: _endpoint, endpointType } =
(submission.conversation as TConversation | null) ?? {}; (submission.conversation as TConversation | null) ?? {};
const endpoint = endpointType ?? _endpoint; const endpoint = endpointType ?? _endpoint;
if (!isAssistantsEndpoint(endpoint)) {
if (newConversation) {
newConversation({
template: { conversationId: conversationId || v4() },
preset: tPresetSchema.parse(submission.conversation),
});
}
setIsSubmitting(false);
return;
}
try { try {
const response = await fetch(`${EndpointURLs[endpoint ?? '']}/abort`, { const response = await fetch(`${EndpointURLs[endpoint ?? '']}/abort`, {
method: 'POST', method: 'POST',

514
package-lock.json generated

File diff suppressed because it is too large Load diff

View file

@ -151,11 +151,14 @@ export class MCPManager {
} }
/** Check for and disconnect idle connections */ /** Check for and disconnect idle connections */
private checkIdleConnections(): void { private checkIdleConnections(currentUserId?: string): void {
const now = Date.now(); const now = Date.now();
// Iterate through all users to check for idle ones // Iterate through all users to check for idle ones
for (const [userId, lastActivity] of this.userLastActivity.entries()) { for (const [userId, lastActivity] of this.userLastActivity.entries()) {
if (currentUserId && currentUserId === userId) {
continue;
}
if (now - lastActivity > this.USER_CONNECTION_IDLE_TIMEOUT) { if (now - lastActivity > this.USER_CONNECTION_IDLE_TIMEOUT) {
this.logger.info( this.logger.info(
`[MCP][User: ${userId}] User idle for too long. Disconnecting all connections...`, `[MCP][User: ${userId}] User idle for too long. Disconnecting all connections...`,