Merge branch 'main' into feat/Multitenant-login-OIDC

This commit is contained in:
Ruben Talstra 2025-05-14 21:23:24 +02:00 committed by GitHub
commit a85e853e12
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
535 changed files with 28767 additions and 13591 deletions

387
api/server/cleanup.js Normal file
View file

@ -0,0 +1,387 @@
const { logger } = require('~/config');
// WeakMap to hold temporary data associated with requests
const requestDataMap = new WeakMap();
const FinalizationRegistry = global.FinalizationRegistry || null;
/**
* FinalizationRegistry to clean up client objects when they are garbage collected.
* This is used to prevent memory leaks and ensure that client objects are
* properly disposed of when they are no longer needed.
* The registry holds a weak reference to the client object and a cleanup
* callback that is called when the client object is garbage collected.
* The callback can be used to perform any necessary cleanup operations,
* such as removing event listeners or freeing up resources.
*/
const clientRegistry = FinalizationRegistry
? new FinalizationRegistry((heldValue) => {
try {
// This will run when the client is garbage collected
if (heldValue && heldValue.userId) {
logger.debug(`[FinalizationRegistry] Cleaning up client for user ${heldValue.userId}`);
} else {
logger.debug('[FinalizationRegistry] Cleaning up client');
}
} catch (e) {
// Ignore errors
}
})
: null;
/**
* Cleans up the client object by removing references to its properties.
* This is useful for preventing memory leaks and ensuring that the client
* and its properties can be garbage collected when it is no longer needed.
*/
function disposeClient(client) {
if (!client) {
return;
}
try {
if (client.user) {
client.user = null;
}
if (client.apiKey) {
client.apiKey = null;
}
if (client.azure) {
client.azure = null;
}
if (client.conversationId) {
client.conversationId = null;
}
if (client.responseMessageId) {
client.responseMessageId = null;
}
if (client.message_file_map) {
client.message_file_map = null;
}
if (client.clientName) {
client.clientName = null;
}
if (client.sender) {
client.sender = null;
}
if (client.model) {
client.model = null;
}
if (client.maxContextTokens) {
client.maxContextTokens = null;
}
if (client.contextStrategy) {
client.contextStrategy = null;
}
if (client.currentDateString) {
client.currentDateString = null;
}
if (client.inputTokensKey) {
client.inputTokensKey = null;
}
if (client.outputTokensKey) {
client.outputTokensKey = null;
}
if (client.skipSaveUserMessage !== undefined) {
client.skipSaveUserMessage = null;
}
if (client.visionMode) {
client.visionMode = null;
}
if (client.continued !== undefined) {
client.continued = null;
}
if (client.fetchedConvo !== undefined) {
client.fetchedConvo = null;
}
if (client.previous_summary) {
client.previous_summary = null;
}
if (client.metadata) {
client.metadata = null;
}
if (client.isVisionModel) {
client.isVisionModel = null;
}
if (client.isChatCompletion !== undefined) {
client.isChatCompletion = null;
}
if (client.contextHandlers) {
client.contextHandlers = null;
}
if (client.augmentedPrompt) {
client.augmentedPrompt = null;
}
if (client.systemMessage) {
client.systemMessage = null;
}
if (client.azureEndpoint) {
client.azureEndpoint = null;
}
if (client.langchainProxy) {
client.langchainProxy = null;
}
if (client.isOmni !== undefined) {
client.isOmni = null;
}
if (client.runManager) {
client.runManager = null;
}
// Properties specific to AnthropicClient
if (client.message_start) {
client.message_start = null;
}
if (client.message_delta) {
client.message_delta = null;
}
if (client.isClaude3 !== undefined) {
client.isClaude3 = null;
}
if (client.useMessages !== undefined) {
client.useMessages = null;
}
if (client.isLegacyOutput !== undefined) {
client.isLegacyOutput = null;
}
if (client.supportsCacheControl !== undefined) {
client.supportsCacheControl = null;
}
// Properties specific to GoogleClient
if (client.serviceKey) {
client.serviceKey = null;
}
if (client.project_id) {
client.project_id = null;
}
if (client.client_email) {
client.client_email = null;
}
if (client.private_key) {
client.private_key = null;
}
if (client.access_token) {
client.access_token = null;
}
if (client.reverseProxyUrl) {
client.reverseProxyUrl = null;
}
if (client.authHeader) {
client.authHeader = null;
}
if (client.isGenerativeModel !== undefined) {
client.isGenerativeModel = null;
}
// Properties specific to OpenAIClient
if (client.ChatGPTClient) {
client.ChatGPTClient = null;
}
if (client.completionsUrl) {
client.completionsUrl = null;
}
if (client.shouldSummarize !== undefined) {
client.shouldSummarize = null;
}
if (client.isOllama !== undefined) {
client.isOllama = null;
}
if (client.FORCE_PROMPT !== undefined) {
client.FORCE_PROMPT = null;
}
if (client.isChatGptModel !== undefined) {
client.isChatGptModel = null;
}
if (client.isUnofficialChatGptModel !== undefined) {
client.isUnofficialChatGptModel = null;
}
if (client.useOpenRouter !== undefined) {
client.useOpenRouter = null;
}
if (client.startToken) {
client.startToken = null;
}
if (client.endToken) {
client.endToken = null;
}
if (client.userLabel) {
client.userLabel = null;
}
if (client.chatGptLabel) {
client.chatGptLabel = null;
}
if (client.modelLabel) {
client.modelLabel = null;
}
if (client.modelOptions) {
client.modelOptions = null;
}
if (client.defaultVisionModel) {
client.defaultVisionModel = null;
}
if (client.maxPromptTokens) {
client.maxPromptTokens = null;
}
if (client.maxResponseTokens) {
client.maxResponseTokens = null;
}
if (client.run) {
// Break circular references in run
if (client.run.Graph) {
client.run.Graph.resetValues();
client.run.Graph.handlerRegistry = null;
client.run.Graph.runId = null;
client.run.Graph.tools = null;
client.run.Graph.signal = null;
client.run.Graph.config = null;
client.run.Graph.toolEnd = null;
client.run.Graph.toolMap = null;
client.run.Graph.provider = null;
client.run.Graph.streamBuffer = null;
client.run.Graph.clientOptions = null;
client.run.Graph.graphState = null;
if (client.run.Graph.boundModel?.client) {
client.run.Graph.boundModel.client = null;
}
client.run.Graph.boundModel = null;
client.run.Graph.systemMessage = null;
client.run.Graph.reasoningKey = null;
client.run.Graph.messages = null;
client.run.Graph.contentData = null;
client.run.Graph.stepKeyIds = null;
client.run.Graph.contentIndexMap = null;
client.run.Graph.toolCallStepIds = null;
client.run.Graph.messageIdsByStepKey = null;
client.run.Graph.messageStepHasToolCalls = null;
client.run.Graph.prelimMessageIdsByStepKey = null;
client.run.Graph.currentTokenType = null;
client.run.Graph.lastToken = null;
client.run.Graph.tokenTypeSwitch = null;
client.run.Graph.indexTokenCountMap = null;
client.run.Graph.currentUsage = null;
client.run.Graph.tokenCounter = null;
client.run.Graph.maxContextTokens = null;
client.run.Graph.pruneMessages = null;
client.run.Graph.lastStreamCall = null;
client.run.Graph.startIndex = null;
client.run.Graph = null;
}
if (client.run.handlerRegistry) {
client.run.handlerRegistry = null;
}
if (client.run.graphRunnable) {
if (client.run.graphRunnable.channels) {
client.run.graphRunnable.channels = null;
}
if (client.run.graphRunnable.nodes) {
client.run.graphRunnable.nodes = null;
}
if (client.run.graphRunnable.lc_kwargs) {
client.run.graphRunnable.lc_kwargs = null;
}
if (client.run.graphRunnable.builder?.nodes) {
client.run.graphRunnable.builder.nodes = null;
client.run.graphRunnable.builder = null;
}
client.run.graphRunnable = null;
}
client.run = null;
}
if (client.sendMessage) {
client.sendMessage = null;
}
if (client.savedMessageIds) {
client.savedMessageIds.clear();
client.savedMessageIds = null;
}
if (client.currentMessages) {
client.currentMessages = null;
}
if (client.streamHandler) {
client.streamHandler = null;
}
if (client.contentParts) {
client.contentParts = null;
}
if (client.abortController) {
client.abortController = null;
}
if (client.collectedUsage) {
client.collectedUsage = null;
}
if (client.indexTokenCountMap) {
client.indexTokenCountMap = null;
}
if (client.agentConfigs) {
client.agentConfigs = null;
}
if (client.artifactPromises) {
client.artifactPromises = null;
}
if (client.usage) {
client.usage = null;
}
if (typeof client.dispose === 'function') {
client.dispose();
}
if (client.options) {
if (client.options.req) {
client.options.req = null;
}
if (client.options.res) {
client.options.res = null;
}
if (client.options.attachments) {
client.options.attachments = null;
}
if (client.options.agent) {
client.options.agent = null;
}
}
client.options = null;
} catch (e) {
// Ignore errors during disposal
}
}
function processReqData(data = {}, context) {
let {
abortKey,
userMessage,
userMessagePromise,
responseMessageId,
promptTokens,
conversationId,
userMessageId,
} = context;
for (const key in data) {
if (key === 'userMessage') {
userMessage = data[key];
userMessageId = data[key].messageId;
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
} else if (key === 'abortKey') {
abortKey = data[key];
} else if (!conversationId && key === 'conversationId') {
conversationId = data[key];
}
}
return {
abortKey,
userMessage,
userMessagePromise,
responseMessageId,
promptTokens,
conversationId,
userMessageId,
};
}
module.exports = {
disposeClient,
requestDataMap,
clientRegistry,
processReqData,
};

View file

@ -1,5 +1,15 @@
const { getResponseSender, Constants } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const {
handleAbortError,
createAbortController,
cleanupAbortController,
} = require('~/server/middleware');
const {
disposeClient,
processReqData,
clientRegistry,
requestDataMap,
} = require('~/server/cleanup');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
@ -14,90 +24,162 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
overrideParentMessageId = null,
} = req.body;
let client = null;
let abortKey = null;
let cleanupHandlers = [];
let clientRef = null;
logger.debug('[AskController]', {
text,
conversationId,
...endpointOption,
modelsConfig: endpointOption.modelsConfig ? 'exists' : '',
modelsConfig: endpointOption?.modelsConfig ? 'exists' : '',
});
let userMessage;
let userMessagePromise;
let promptTokens;
let userMessageId;
let responseMessageId;
let userMessage = null;
let userMessagePromise = null;
let promptTokens = null;
let userMessageId = null;
let responseMessageId = null;
let getAbortData = null;
const sender = getResponseSender({
...endpointOption,
model: endpointOption.modelOptions.model,
modelDisplayLabel,
});
const newConvo = !conversationId;
const user = req.user.id;
const initialConversationId = conversationId;
const newConvo = !initialConversationId;
const userId = req.user.id;
const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
userMessageId = data[key].messageId;
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
} else if (!conversationId && key === 'conversationId') {
conversationId = data[key];
}
}
let reqDataContext = {
userMessage,
userMessagePromise,
responseMessageId,
promptTokens,
conversationId,
userMessageId,
};
let getText;
const updateReqData = (data = {}) => {
reqDataContext = processReqData(data, reqDataContext);
abortKey = reqDataContext.abortKey;
userMessage = reqDataContext.userMessage;
userMessagePromise = reqDataContext.userMessagePromise;
responseMessageId = reqDataContext.responseMessageId;
promptTokens = reqDataContext.promptTokens;
conversationId = reqDataContext.conversationId;
userMessageId = reqDataContext.userMessageId;
};
let { onProgress: progressCallback, getPartialText } = createOnProgress();
const performCleanup = () => {
logger.debug('[AskController] Performing cleanup');
if (Array.isArray(cleanupHandlers)) {
for (const handler of cleanupHandlers) {
try {
if (typeof handler === 'function') {
handler();
}
} catch (e) {
// Ignore
}
}
}
if (abortKey) {
logger.debug('[AskController] Cleaning up abort controller');
cleanupAbortController(abortKey);
abortKey = null;
}
if (client) {
disposeClient(client);
client = null;
}
reqDataContext = null;
userMessage = null;
userMessagePromise = null;
promptTokens = null;
getAbortData = null;
progressCallback = null;
endpointOption = null;
cleanupHandlers = null;
addTitle = null;
if (requestDataMap.has(req)) {
requestDataMap.delete(req);
}
logger.debug('[AskController] Cleanup completed');
};
try {
const { client } = await initializeClient({ req, res, endpointOption });
const { onProgress: progressCallback, getPartialText } = createOnProgress();
({ client } = await initializeClient({ req, res, endpointOption }));
if (clientRegistry && client) {
clientRegistry.register(client, { userId }, client);
}
getText = client.getStreamText != null ? client.getStreamText.bind(client) : getPartialText;
if (client) {
requestDataMap.set(req, { client });
}
const getAbortData = () => ({
sender,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getText(),
userMessage,
promptTokens,
});
clientRef = new WeakRef(client);
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
getAbortData = () => {
const currentClient = clientRef?.deref();
const currentText =
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
res.on('close', () => {
return {
sender,
conversationId,
messageId: reqDataContext.responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: currentText,
userMessage: userMessage,
userMessagePromise: userMessagePromise,
promptTokens: reqDataContext.promptTokens,
};
};
const { onStart, abortController } = createAbortController(
req,
res,
getAbortData,
updateReqData,
);
const closeHandler = () => {
logger.debug('[AskController] Request closed');
if (!abortController) {
return;
} else if (abortController.signal.aborted) {
return;
} else if (abortController.requestCompleted) {
if (!abortController || abortController.signal.aborted || abortController.requestCompleted) {
return;
}
abortController.abort();
logger.debug('[AskController] Request aborted on close');
};
res.on('close', closeHandler);
cleanupHandlers.push(() => {
try {
res.removeListener('close', closeHandler);
} catch (e) {
// Ignore
}
});
const messageOptions = {
user,
user: userId,
parentMessageId,
conversationId,
conversationId: reqDataContext.conversationId,
overrideParentMessageId,
getReqData,
getReqData: updateReqData,
onStart,
abortController,
progressCallback,
progressOptions: {
res,
// parentMessageId: overrideParentMessageId || userMessageId,
},
};
@ -105,59 +187,95 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
let response = await client.sendMessage(text, messageOptions);
response.endpoint = endpointOption.endpoint;
const { conversation = {} } = await client.responsePromise;
const databasePromise = response.databasePromise;
delete response.databasePromise;
const { conversation: convoData = {} } = await databasePromise;
const conversation = { ...convoData };
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
if (client.options.attachments) {
userMessage.files = client.options.attachments;
conversation.model = endpointOption.modelOptions.model;
delete userMessage.image_urls;
const latestUserMessage = reqDataContext.userMessage;
if (client?.options?.attachments && latestUserMessage) {
latestUserMessage.files = client.options.attachments;
if (endpointOption?.modelOptions?.model) {
conversation.model = endpointOption.modelOptions.model;
}
delete latestUserMessage.image_urls;
}
if (!abortController.signal.aborted) {
const finalResponseMessage = { ...response };
sendMessage(res, {
final: true,
conversation,
title: conversation.title,
requestMessage: userMessage,
responseMessage: response,
requestMessage: latestUserMessage,
responseMessage: finalResponseMessage,
});
res.end();
if (!client.savedMessageIds.has(response.messageId)) {
if (client?.savedMessageIds && !client.savedMessageIds.has(response.messageId)) {
await saveMessage(
req,
{ ...response, user },
{ ...finalResponseMessage, user: userId },
{ context: 'api/server/controllers/AskController.js - response end' },
);
}
}
if (!client.skipSaveUserMessage) {
await saveMessage(req, userMessage, {
context: 'api/server/controllers/AskController.js - don\'t skip saving user message',
if (!client?.skipSaveUserMessage && latestUserMessage) {
await saveMessage(req, latestUserMessage, {
context: "api/server/controllers/AskController.js - don't skip saving user message",
});
}
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
if (typeof addTitle === 'function' && parentMessageId === Constants.NO_PARENT && newConvo) {
addTitle(req, {
text,
response,
response: { ...response },
client,
});
})
.then(() => {
logger.debug('[AskController] Title generation started');
})
.catch((err) => {
logger.error('[AskController] Error in title generation', err);
})
.finally(() => {
logger.debug('[AskController] Title generation completed');
performCleanup();
});
} else {
performCleanup();
}
} catch (error) {
const partialText = getText && getText();
logger.error('[AskController] Error handling request', error);
let partialText = '';
try {
const currentClient = clientRef?.deref();
partialText =
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
} catch (getTextError) {
logger.error('[AskController] Error calling getText() during error handling', getTextError);
}
handleAbortError(res, req, error, {
sender,
partialText,
conversationId,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId,
}).catch((err) => {
logger.error('[AskController] Error in `handleAbortError`', err);
});
conversationId: reqDataContext.conversationId,
messageId: reqDataContext.responseMessageId,
parentMessageId: overrideParentMessageId ?? reqDataContext.userMessageId ?? parentMessageId,
userMessageId: reqDataContext.userMessageId,
})
.catch((err) => {
logger.error('[AskController] Error in `handleAbortError` during catch block', err);
})
.finally(() => {
performCleanup();
});
}
};

View file

@ -1,5 +1,15 @@
const { getResponseSender } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const {
handleAbortError,
createAbortController,
cleanupAbortController,
} = require('~/server/middleware');
const {
disposeClient,
processReqData,
clientRegistry,
requestDataMap,
} = require('~/server/cleanup');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
@ -17,6 +27,11 @@ const EditController = async (req, res, next, initializeClient) => {
overrideParentMessageId = null,
} = req.body;
let client = null;
let abortKey = null;
let cleanupHandlers = [];
let clientRef = null; // Declare clientRef here
logger.debug('[EditController]', {
text,
generation,
@ -26,123 +41,205 @@ const EditController = async (req, res, next, initializeClient) => {
modelsConfig: endpointOption.modelsConfig ? 'exists' : '',
});
let userMessage;
let userMessagePromise;
let promptTokens;
let userMessage = null;
let userMessagePromise = null;
let promptTokens = null;
let getAbortData = null;
const sender = getResponseSender({
...endpointOption,
model: endpointOption.modelOptions.model,
modelDisplayLabel,
});
const userMessageId = parentMessageId;
const user = req.user.id;
const userId = req.user.id;
const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
}
}
let reqDataContext = { userMessage, userMessagePromise, responseMessageId, promptTokens };
const updateReqData = (data = {}) => {
reqDataContext = processReqData(data, reqDataContext);
abortKey = reqDataContext.abortKey;
userMessage = reqDataContext.userMessage;
userMessagePromise = reqDataContext.userMessagePromise;
responseMessageId = reqDataContext.responseMessageId;
promptTokens = reqDataContext.promptTokens;
};
const { onProgress: progressCallback, getPartialText } = createOnProgress({
let { onProgress: progressCallback, getPartialText } = createOnProgress({
generation,
});
let getText;
const performCleanup = () => {
logger.debug('[EditController] Performing cleanup');
if (Array.isArray(cleanupHandlers)) {
for (const handler of cleanupHandlers) {
try {
if (typeof handler === 'function') {
handler();
}
} catch (e) {
// Ignore
}
}
}
if (abortKey) {
logger.debug('[AskController] Cleaning up abort controller');
cleanupAbortController(abortKey);
abortKey = null;
}
if (client) {
disposeClient(client);
client = null;
}
reqDataContext = null;
userMessage = null;
userMessagePromise = null;
promptTokens = null;
getAbortData = null;
progressCallback = null;
endpointOption = null;
cleanupHandlers = null;
if (requestDataMap.has(req)) {
requestDataMap.delete(req);
}
logger.debug('[EditController] Cleanup completed');
};
try {
const { client } = await initializeClient({ req, res, endpointOption });
({ client } = await initializeClient({ req, res, endpointOption }));
getText = client.getStreamText != null ? client.getStreamText.bind(client) : getPartialText;
if (clientRegistry && client) {
clientRegistry.register(client, { userId }, client);
}
const getAbortData = () => ({
conversationId,
userMessagePromise,
messageId: responseMessageId,
sender,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getText(),
userMessage,
promptTokens,
});
if (client) {
requestDataMap.set(req, { client });
}
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
clientRef = new WeakRef(client);
res.on('close', () => {
getAbortData = () => {
const currentClient = clientRef?.deref();
const currentText =
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
return {
sender,
conversationId,
messageId: reqDataContext.responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: currentText,
userMessage: userMessage,
userMessagePromise: userMessagePromise,
promptTokens: reqDataContext.promptTokens,
};
};
const { onStart, abortController } = createAbortController(
req,
res,
getAbortData,
updateReqData,
);
const closeHandler = () => {
logger.debug('[EditController] Request closed');
if (!abortController) {
return;
} else if (abortController.signal.aborted) {
return;
} else if (abortController.requestCompleted) {
if (!abortController || abortController.signal.aborted || abortController.requestCompleted) {
return;
}
abortController.abort();
logger.debug('[EditController] Request aborted on close');
};
res.on('close', closeHandler);
cleanupHandlers.push(() => {
try {
res.removeListener('close', closeHandler);
} catch (e) {
// Ignore
}
});
let response = await client.sendMessage(text, {
user,
user: userId,
generation,
isContinued,
isEdited: true,
conversationId,
parentMessageId,
responseMessageId,
responseMessageId: reqDataContext.responseMessageId,
overrideParentMessageId,
getReqData,
getReqData: updateReqData,
onStart,
abortController,
progressCallback,
progressOptions: {
res,
// parentMessageId: overrideParentMessageId || userMessageId,
},
});
const { conversation = {} } = await client.responsePromise;
const databasePromise = response.databasePromise;
delete response.databasePromise;
const { conversation: convoData = {} } = await databasePromise;
const conversation = { ...convoData };
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
if (client.options.attachments) {
if (client?.options?.attachments && endpointOption?.modelOptions?.model) {
conversation.model = endpointOption.modelOptions.model;
}
if (!abortController.signal.aborted) {
const finalUserMessage = reqDataContext.userMessage;
const finalResponseMessage = { ...response };
sendMessage(res, {
final: true,
conversation,
title: conversation.title,
requestMessage: userMessage,
responseMessage: response,
requestMessage: finalUserMessage,
responseMessage: finalResponseMessage,
});
res.end();
await saveMessage(
req,
{ ...response, user },
{ ...finalResponseMessage, user: userId },
{ context: 'api/server/controllers/EditController.js - response end' },
);
}
performCleanup();
} catch (error) {
const partialText = getText();
logger.error('[EditController] Error handling request', error);
let partialText = '';
try {
const currentClient = clientRef?.deref();
partialText =
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
} catch (getTextError) {
logger.error('[EditController] Error calling getText() during error handling', getTextError);
}
handleAbortError(res, req, error, {
sender,
partialText,
conversationId,
messageId: responseMessageId,
messageId: reqDataContext.responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId,
}).catch((err) => {
logger.error('[EditController] Error in `handleAbortError`', err);
});
userMessageId,
})
.catch((err) => {
logger.error('[EditController] Error in `handleAbortError` during catch block', err);
})
.finally(() => {
performCleanup();
});
}
};

View file

@ -1,5 +1,5 @@
const { CacheKeys, AuthType } = require('librechat-data-provider');
const { addOpenAPISpecs } = require('~/app/clients/tools/util/addOpenAPISpecs');
const { getToolkitKey } = require('~/server/services/ToolService');
const { getCustomConfig } = require('~/server/services/Config');
const { availableTools } = require('~/app/clients/tools');
const { getMCPManager } = require('~/config');
@ -69,7 +69,7 @@ const getAvailablePluginsController = async (req, res) => {
);
}
let plugins = await addOpenAPISpecs(authenticatedPlugins);
let plugins = authenticatedPlugins;
if (includedTools.length > 0) {
plugins = plugins.filter((plugin) => includedTools.includes(plugin.pluginKey));
@ -105,11 +105,11 @@ const getAvailableTools = async (req, res) => {
return;
}
const pluginManifest = availableTools;
let pluginManifest = availableTools;
const customConfig = await getCustomConfig();
if (customConfig?.mcpServers != null) {
const mcpManager = await getMCPManager();
await mcpManager.loadManifestTools(pluginManifest);
const mcpManager = getMCPManager();
pluginManifest = await mcpManager.loadManifestTools(pluginManifest);
}
/** @type {TPlugin[]} */
@ -128,7 +128,7 @@ const getAvailableTools = async (req, res) => {
(plugin) =>
toolDefinitions[plugin.pluginKey] !== undefined ||
(plugin.toolkit === true &&
Object.keys(toolDefinitions).some((key) => key.startsWith(`${plugin.pluginKey}_`))),
Object.keys(toolDefinitions).some((key) => getToolkitKey(key) === plugin.pluginKey)),
);
await cache.set(CacheKeys.TOOLS, tools);

View file

@ -1,6 +1,8 @@
const { FileSources } = require('librechat-data-provider');
const {
Balance,
getFiles,
updateUser,
deleteFiles,
deleteConvos,
deletePresets,
@ -12,6 +14,7 @@ const User = require('~/models/User');
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService');
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
const { processDeleteRequest } = require('~/server/services/Files/process');
const { deleteAllSharedLinks } = require('~/models/Share');
const { deleteToolCalls } = require('~/models/ToolCall');
@ -19,8 +22,23 @@ const { Transaction } = require('~/models/Transaction');
const { logger } = require('~/config');
const getUserController = async (req, res) => {
/** @type {MongoUser} */
const userData = req.user.toObject != null ? req.user.toObject() : { ...req.user };
delete userData.totpSecret;
if (req.app.locals.fileStrategy === FileSources.s3 && userData.avatar) {
const avatarNeedsRefresh = needsRefresh(userData.avatar, 3600);
if (!avatarNeedsRefresh) {
return res.status(200).send(userData);
}
const originalAvatar = userData.avatar;
try {
userData.avatar = await getNewS3URL(userData.avatar);
await updateUser(userData.id, { avatar: userData.avatar });
} catch (error) {
userData.avatar = originalAvatar;
logger.error('Error getting new S3 URL for avatar:', error);
}
}
res.status(200).send(userData);
};

View file

@ -14,15 +14,6 @@ const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { saveBase64Image } = require('~/server/services/Files/process');
const { logger, sendEvent } = require('~/config');
/** @typedef {import('@librechat/agents').Graph} Graph */
/** @typedef {import('@librechat/agents').EventHandler} EventHandler */
/** @typedef {import('@librechat/agents').ModelEndData} ModelEndData */
/** @typedef {import('@librechat/agents').ToolEndData} ToolEndData */
/** @typedef {import('@librechat/agents').ToolEndCallback} ToolEndCallback */
/** @typedef {import('@librechat/agents').ChatModelStreamHandler} ChatModelStreamHandler */
/** @typedef {import('@librechat/agents').ContentAggregatorResult['aggregateContent']} ContentAggregator */
/** @typedef {import('@librechat/agents').GraphEvents} GraphEvents */
class ModelEndHandler {
/**
* @param {Array<UsageMetadata>} collectedUsage
@ -38,7 +29,7 @@ class ModelEndHandler {
* @param {string} event
* @param {ModelEndData | undefined} data
* @param {Record<string, unknown> | undefined} metadata
* @param {Graph} graph
* @param {StandardGraph} graph
* @returns
*/
handle(event, data, metadata, graph) {
@ -61,7 +52,10 @@ class ModelEndHandler {
}
this.collectedUsage.push(usage);
if (!graph.clientOptions?.disableStreaming) {
const streamingDisabled = !!(
graph.clientOptions?.disableStreaming || graph?.boundModel?.disableStreaming
);
if (!streamingDisabled) {
return;
}
if (!data.output.content) {
@ -246,7 +240,11 @@ function createToolEndCallback({ req, res, artifactPromises }) {
if (output.artifact.content) {
/** @type {FormattedContent[]} */
const content = output.artifact.content;
for (const part of content) {
for (let i = 0; i < content.length; i++) {
const part = content[i];
if (!part) {
continue;
}
if (part.type !== 'image_url') {
continue;
}
@ -254,8 +252,10 @@ function createToolEndCallback({ req, res, artifactPromises }) {
artifactPromises.push(
(async () => {
const filename = `${output.name}_${output.tool_call_id}_img_${nanoid()}`;
const file_id = output.artifact.file_ids?.[i];
const file = await saveBase64Image(url, {
req,
file_id,
filename,
endpoint: metadata.provider,
context: FileContext.image_generation,

View file

@ -20,11 +20,9 @@ const {
const {
Constants,
VisionModes,
openAISchema,
ContentTypes,
EModelEndpoint,
KnownEndpoints,
anthropicSchema,
isAgentsEndpoint,
AgentCapabilities,
bedrockInputSchema,
@ -35,6 +33,7 @@ const { addCacheControl, createContextHandlers } = require('~/app/clients/prompt
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
const Tokenizer = require('~/server/services/Tokenizer');
const BaseClient = require('~/app/clients/BaseClient');
const { logger, sendEvent } = require('~/config');
@ -43,21 +42,43 @@ const { createRun } = require('./run');
/** @typedef {import('@librechat/agents').MessageContentComplex} MessageContentComplex */
/** @typedef {import('@langchain/core/runnables').RunnableConfig} RunnableConfig */
const providerParsers = {
[EModelEndpoint.openAI]: openAISchema.parse,
[EModelEndpoint.azureOpenAI]: openAISchema.parse,
[EModelEndpoint.anthropic]: anthropicSchema.parse,
[EModelEndpoint.bedrock]: bedrockInputSchema.parse,
/**
* @param {ServerRequest} req
* @param {Agent} agent
* @param {string} endpoint
*/
const payloadParser = ({ req, agent, endpoint }) => {
if (isAgentsEndpoint(endpoint)) {
return { model: undefined };
} else if (endpoint === EModelEndpoint.bedrock) {
return bedrockInputSchema.parse(agent.model_parameters);
}
return req.body.endpointOption.model_parameters;
};
const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deepseek]);
const noSystemModelRegex = [/\bo1\b/gi];
const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi];
// const { processMemory, memoryInstructions } = require('~/server/services/Endpoints/agents/memory');
// const { getFormattedMemories } = require('~/models/Memory');
// const { getCurrentDateTime } = require('~/utils');
function createTokenCounter(encoding) {
return (message) => {
const countTokens = (text) => Tokenizer.getTokenCount(text, encoding);
return getTokenCountForMessage(message, countTokens);
};
}
function logToolError(graph, error, toolId) {
logger.error(
'[api/server/controllers/agents/client.js #chatCompletion] Tool Error',
error,
toolId,
);
}
class AgentClient extends BaseClient {
constructor(options = {}) {
super(null, options);
@ -127,19 +148,13 @@ class AgentClient extends BaseClient {
* @param {MongoFile[]} attachments
*/
checkVisionRequest(attachments) {
logger.info(
'[api/server/controllers/agents/client.js #checkVisionRequest] not implemented',
attachments,
);
// if (!attachments) {
// return;
// }
// const availableModels = this.options.modelsConfig?.[this.options.endpoint];
// if (!availableModels) {
// return;
// }
// let visionRequestDetected = false;
// for (const file of attachments) {
// if (file?.type?.includes('image')) {
@ -150,13 +165,11 @@ class AgentClient extends BaseClient {
// if (!visionRequestDetected) {
// return;
// }
// this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });
// if (this.isVisionModel) {
// delete this.modelOptions.stop;
// return;
// }
// for (const model of availableModels) {
// if (!validateVisionModel({ model, availableModels })) {
// continue;
@ -166,42 +179,31 @@ class AgentClient extends BaseClient {
// delete this.modelOptions.stop;
// return;
// }
// if (!availableModels.includes(this.defaultVisionModel)) {
// return;
// }
// if (!validateVisionModel({ model: this.defaultVisionModel, availableModels })) {
// return;
// }
// this.modelOptions.model = this.defaultVisionModel;
// this.isVisionModel = true;
// delete this.modelOptions.stop;
}
getSaveOptions() {
const parseOptions = providerParsers[this.options.endpoint];
let runOptions =
this.options.endpoint === EModelEndpoint.agents
? {
model: undefined,
// TODO:
// would need to be override settings; otherwise, model needs to be undefined
// model: this.override.model,
// instructions: this.override.instructions,
// additional_instructions: this.override.additional_instructions,
}
: {};
if (parseOptions) {
try {
runOptions = parseOptions(this.options.agent.model_parameters);
} catch (error) {
logger.error(
'[api/server/controllers/agents/client.js #getSaveOptions] Error parsing options',
error,
);
}
// TODO:
// would need to be override settings; otherwise, model needs to be undefined
// model: this.override.model,
// instructions: this.override.instructions,
// additional_instructions: this.override.additional_instructions,
let runOptions = {};
try {
runOptions = payloadParser(this.options);
} catch (error) {
logger.error(
'[api/server/controllers/agents/client.js #getSaveOptions] Error parsing options',
error,
);
}
return removeNullishValues(
@ -352,7 +354,9 @@ class AgentClient extends BaseClient {
this.contextHandlers?.processFile(file);
continue;
}
if (file.metadata?.fileIdentifier) {
continue;
}
// orderedMessages[i].tokenCount += this.calculateImageTokenCost({
// width: file.width,
// height: file.height,
@ -471,6 +475,7 @@ class AgentClient extends BaseClient {
err,
);
});
continue;
}
spendTokens(txMetadata, {
promptTokens: usage.input_tokens,
@ -538,6 +543,10 @@ class AgentClient extends BaseClient {
}
async chatCompletion({ payload, abortController = null }) {
/** @type {Partial<RunnableConfig> & { version: 'v1' | 'v2'; run_id?: string; streamMode: string }} */
let config;
/** @type {ReturnType<createRun>} */
let run;
try {
if (!abortController) {
abortController = new AbortController();
@ -635,11 +644,11 @@ class AgentClient extends BaseClient {
/** @type {TCustomConfig['endpoints']['agents']} */
const agentsEConfig = this.options.req.app.locals[EModelEndpoint.agents];
/** @type {Partial<RunnableConfig> & { version: 'v1' | 'v2'; run_id?: string; streamMode: string }} */
const config = {
config = {
configurable: {
thread_id: this.conversationId,
last_agent_index: this.agentConfigs?.size ?? 0,
user_id: this.user ?? this.options.req.user?.id,
hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
},
recursionLimit: agentsEConfig?.recursionLimit,
@ -654,19 +663,10 @@ class AgentClient extends BaseClient {
this.indexTokenCountMap,
toolSet,
);
if (legacyContentEndpoints.has(this.options.agent.endpoint)) {
if (legacyContentEndpoints.has(this.options.agent.endpoint?.toLowerCase())) {
initialMessages = formatContentStrings(initialMessages);
}
/** @type {ReturnType<createRun>} */
let run;
const countTokens = ((text) => this.getTokenCount(text)).bind(this);
/** @type {(message: BaseMessage) => number} */
const tokenCounter = (message) => {
return getTokenCountForMessage(message, countTokens);
};
/**
*
* @param {Agent} agent
@ -718,12 +718,14 @@ class AgentClient extends BaseClient {
}
if (noSystemMessages === true && systemContent?.length) {
let latestMessage = _messages.pop().content;
const latestMessageContent = _messages.pop().content;
if (typeof latestMessage !== 'string') {
latestMessage = latestMessage[0].text;
latestMessageContent[0].text = [systemContent, latestMessageContent[0].text].join('\n');
_messages.push(new HumanMessage({ content: latestMessageContent }));
} else {
const text = [systemContent, latestMessageContent].join('\n');
_messages.push(new HumanMessage(text));
}
latestMessage = [systemContent, latestMessage].join('\n');
_messages.push(new HumanMessage(latestMessage));
}
let messages = _messages;
@ -770,21 +772,18 @@ class AgentClient extends BaseClient {
run.Graph.contentData = contentData;
}
const encoding = this.getEncoding();
await run.processStream({ messages }, config, {
keepContent: i !== 0,
tokenCounter,
tokenCounter: createTokenCounter(encoding),
indexTokenCountMap: currentIndexCountMap,
maxContextTokens: agent.maxContextTokens,
callbacks: {
[Callback.TOOL_ERROR]: (graph, error, toolId) => {
logger.error(
'[api/server/controllers/agents/client.js #chatCompletion] Tool Error',
error,
toolId,
);
},
[Callback.TOOL_ERROR]: logToolError,
},
});
config.signal = null;
};
await runAgent(this.options.agent, initialMessages);
@ -812,6 +811,8 @@ class AgentClient extends BaseClient {
break;
}
}
const encoding = this.getEncoding();
const tokenCounter = createTokenCounter(encoding);
for (const [agentId, agent] of this.agentConfigs) {
if (abortController.signal.aborted === true) {
break;
@ -920,18 +921,27 @@ class AgentClient extends BaseClient {
* @param {string} params.text
* @param {string} params.conversationId
*/
async titleConvo({ text }) {
async titleConvo({ text, abortController }) {
if (!this.run) {
throw new Error('Run not initialized');
}
const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator();
const endpoint = this.options.agent.endpoint;
const { req, res } = this.options;
/** @type {import('@librechat/agents').ClientOptions} */
const clientOptions = {
let clientOptions = {
maxTokens: 75,
};
let endpointConfig = this.options.req.app.locals[this.options.agent.endpoint];
let endpointConfig = req.app.locals[endpoint];
if (!endpointConfig) {
endpointConfig = await getCustomEndpointConfig(this.options.agent.endpoint);
try {
endpointConfig = await getCustomEndpointConfig(endpoint);
} catch (err) {
logger.error(
'[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config',
err,
);
}
}
if (
endpointConfig &&
@ -940,12 +950,35 @@ class AgentClient extends BaseClient {
) {
clientOptions.model = endpointConfig.titleModel;
}
if (
endpoint === EModelEndpoint.azureOpenAI &&
clientOptions.model &&
this.options.agent.model_parameters.model !== clientOptions.model
) {
clientOptions =
(
await initOpenAI({
req,
res,
optionsOnly: true,
overrideModel: clientOptions.model,
overrideEndpoint: endpoint,
endpointOption: {
model_parameters: clientOptions,
},
})
)?.llmConfig ?? clientOptions;
}
if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
delete clientOptions.maxTokens;
}
try {
const titleResult = await this.run.generateTitle({
inputText: text,
contentParts: this.contentParts,
clientOptions,
chainOptions: {
signal: abortController.signal,
callbacks: [
{
handleLLMEnd,
@ -971,7 +1004,7 @@ class AgentClient extends BaseClient {
};
});
this.recordCollectedUsage({
await this.recordCollectedUsage({
model: clientOptions.model,
context: 'title',
collectedUsage,

View file

@ -1,5 +1,10 @@
const { Constants } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const {
handleAbortError,
createAbortController,
cleanupAbortController,
} = require('~/server/middleware');
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
const { sendMessage } = require('~/server/utils');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
@ -14,16 +19,22 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
} = req.body;
let sender;
let abortKey;
let userMessage;
let promptTokens;
let userMessageId;
let responseMessageId;
let userMessagePromise;
let getAbortData;
let client = null;
// Initialize as an array
let cleanupHandlers = [];
const newConvo = !conversationId;
const user = req.user.id;
const userId = req.user.id;
const getReqData = (data = {}) => {
// Create handler to avoid capturing the entire parent scope
let getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
@ -36,30 +47,96 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
promptTokens = data[key];
} else if (key === 'sender') {
sender = data[key];
} else if (key === 'abortKey') {
abortKey = data[key];
} else if (!conversationId && key === 'conversationId') {
conversationId = data[key];
}
}
};
// Create a function to handle final cleanup
const performCleanup = () => {
logger.debug('[AgentController] Performing cleanup');
// Make sure cleanupHandlers is an array before iterating
if (Array.isArray(cleanupHandlers)) {
// Execute all cleanup handlers
for (const handler of cleanupHandlers) {
try {
if (typeof handler === 'function') {
handler();
}
} catch (e) {
// Ignore cleanup errors
}
}
}
// Clean up abort controller
if (abortKey) {
logger.debug('[AgentController] Cleaning up abort controller');
cleanupAbortController(abortKey);
}
// Dispose client properly
if (client) {
disposeClient(client);
}
// Clear all references
client = null;
getReqData = null;
userMessage = null;
getAbortData = null;
endpointOption.agent = null;
endpointOption = null;
cleanupHandlers = null;
userMessagePromise = null;
// Clear request data map
if (requestDataMap.has(req)) {
requestDataMap.delete(req);
}
logger.debug('[AgentController] Cleanup completed');
};
try {
/** @type {{ client: TAgentClient }} */
const { client } = await initializeClient({ req, res, endpointOption });
const result = await initializeClient({ req, res, endpointOption });
client = result.client;
const getAbortData = () => ({
sender,
userMessage,
promptTokens,
conversationId,
userMessagePromise,
messageId: responseMessageId,
content: client.getContentParts(),
parentMessageId: overrideParentMessageId ?? userMessageId,
});
// Register client with finalization registry if available
if (clientRegistry) {
clientRegistry.register(client, { userId }, client);
}
// Store request data in WeakMap keyed by req object
requestDataMap.set(req, { client });
// Use WeakRef to allow GC but still access content if it exists
const contentRef = new WeakRef(client.contentParts || []);
// Minimize closure scope - only capture small primitives and WeakRef
getAbortData = () => {
// Dereference WeakRef each time
const content = contentRef.deref();
return {
sender,
content: content || [],
userMessage,
promptTokens,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
};
};
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
res.on('close', () => {
// Simple handler to avoid capturing scope
const closeHandler = () => {
logger.debug('[AgentController] Request closed');
if (!abortController) {
return;
@ -71,10 +148,19 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
abortController.abort();
logger.debug('[AgentController] Request aborted on close');
};
res.on('close', closeHandler);
cleanupHandlers.push(() => {
try {
res.removeListener('close', closeHandler);
} catch (e) {
// Ignore
}
});
const messageOptions = {
user,
user: userId,
onStart,
getReqData,
conversationId,
@ -83,69 +169,104 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
overrideParentMessageId,
progressOptions: {
res,
// parentMessageId: overrideParentMessageId || userMessageId,
},
};
let response = await client.sendMessage(text, messageOptions);
response.endpoint = endpointOption.endpoint;
const { conversation = {} } = await client.responsePromise;
// Extract what we need and immediately break reference
const messageId = response.messageId;
const endpoint = endpointOption.endpoint;
response.endpoint = endpoint;
// Store database promise locally
const databasePromise = response.databasePromise;
delete response.databasePromise;
// Resolve database-related data
const { conversation: convoData = {} } = await databasePromise;
const conversation = { ...convoData };
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
if (req.body.files && client.options.attachments) {
// Process files if needed
if (req.body.files && client.options?.attachments) {
userMessage.files = [];
const messageFiles = new Set(req.body.files.map((file) => file.file_id));
for (let attachment of client.options.attachments) {
if (messageFiles.has(attachment.file_id)) {
userMessage.files.push(attachment);
userMessage.files.push({ ...attachment });
}
}
delete userMessage.image_urls;
}
// Only send if not aborted
if (!abortController.signal.aborted) {
// Create a new response object with minimal copies
const finalResponse = { ...response };
sendMessage(res, {
final: true,
conversation,
title: conversation.title,
requestMessage: userMessage,
responseMessage: response,
responseMessage: finalResponse,
});
res.end();
if (!client.savedMessageIds.has(response.messageId)) {
// Save the message if needed
if (client.savedMessageIds && !client.savedMessageIds.has(messageId)) {
await saveMessage(
req,
{ ...response, user },
{ ...finalResponse, user: userId },
{ context: 'api/server/controllers/agents/request.js - response end' },
);
}
}
// Save user message if needed
if (!client.skipSaveUserMessage) {
await saveMessage(req, userMessage, {
context: 'api/server/controllers/agents/request.js - don\'t skip saving user message',
});
}
// Add title if needed - extract minimal data
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
addTitle(req, {
text,
response,
response: { ...response },
client,
});
})
.then(() => {
logger.debug('[AgentController] Title generation started');
})
.catch((err) => {
logger.error('[AgentController] Error in title generation', err);
})
.finally(() => {
logger.debug('[AgentController] Title generation completed');
performCleanup();
});
} else {
performCleanup();
}
} catch (error) {
// Handle error without capturing much scope
handleAbortError(res, req, error, {
conversationId,
sender,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId,
}).catch((err) => {
logger.error('[api/server/controllers/agents/request] Error in `handleAbortError`', err);
});
userMessageId,
})
.catch((err) => {
logger.error('[api/server/controllers/agents/request] Error in `handleAbortError`', err);
})
.finally(() => {
performCleanup();
});
}
};

View file

@ -11,6 +11,13 @@ const { providerEndpointMap, KnownEndpoints } = require('librechat-data-provider
* @typedef {import('@librechat/agents').IState} IState
*/
const customProviders = new Set([
Providers.XAI,
Providers.OLLAMA,
Providers.DEEPSEEK,
Providers.OPENROUTER,
]);
/**
* Creates a new Run instance with custom handlers and configuration.
*
@ -43,6 +50,15 @@ async function createRun({
agent.model_parameters,
);
/** Resolves issues with new OpenAI usage field */
if (
customProviders.has(agent.provider) ||
(agent.provider === Providers.OPENAI && agent.endpoint !== agent.provider)
) {
llmConfig.streamUsage = false;
llmConfig.usage = true;
}
/** @type {'reasoning_content' | 'reasoning'} */
let reasoningKey;
if (
@ -51,10 +67,6 @@ async function createRun({
) {
reasoningKey = 'reasoning';
}
if (/o1(?!-(?:mini|preview)).*$/.test(llmConfig.model)) {
llmConfig.streaming = false;
llmConfig.disableStreaming = true;
}
/** @type {StandardGraphConfig} */
const graphConfig = {

View file

@ -4,6 +4,7 @@ const {
Tools,
Constants,
FileContext,
FileSources,
SystemRoles,
EToolResources,
actionDelimiter,
@ -17,9 +18,10 @@ const {
} = require('~/models/Agent');
const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { refreshS3Url } = require('~/server/services/Files/S3/crud');
const { updateAction, getActions } = require('~/models/Action');
const { getProjectByName } = require('~/models/Project');
const { updateAgentProjects } = require('~/models/Agent');
const { getProjectByName } = require('~/models/Project');
const { deleteFileByFilter } = require('~/models/File');
const { logger } = require('~/config');
@ -102,6 +104,14 @@ const getAgentHandler = async (req, res) => {
return res.status(404).json({ error: 'Agent not found' });
}
if (agent.avatar && agent.avatar?.source === FileSources.s3) {
const originalUrl = agent.avatar.filepath;
agent.avatar.filepath = await refreshS3Url(agent.avatar);
if (originalUrl !== agent.avatar.filepath) {
await updateAgent({ id }, { avatar: agent.avatar });
}
}
agent.author = agent.author.toString();
agent.isCollaborative = !!agent.isCollaborative;
@ -212,6 +222,11 @@ const duplicateAgentHandler = async (req, res) => {
tool_resources: _tool_resources = {},
...cloneData
} = agent;
cloneData.name = `${agent.name} (${new Date().toLocaleString('en-US', {
dateStyle: 'short',
timeStyle: 'short',
hour12: false,
})})`;
if (_tool_resources?.[EToolResources.ocr]) {
cloneData.tool_resources = {

View file

@ -19,7 +19,7 @@ const {
addThreadMetadata,
saveAssistantMessage,
} = require('~/server/services/Threads');
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { sendResponse, sendMessage, sleep, countTokens } = require('~/server/utils');
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
@ -27,7 +27,7 @@ const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
const { createRunBody } = require('~/server/services/createRunBody');
const { getTransactions } = require('~/models/Transaction');
const checkBalance = require('~/models/checkBalance');
const { checkBalance } = require('~/models/balanceMethods');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { getModelMaxTokens } = require('~/utils');
@ -119,7 +119,7 @@ const chatV1 = async (req, res) => {
} else if (/Files.*are invalid/.test(error.message)) {
const errorMessage = `Files are invalid, or may not have uploaded yet.${
endpoint === EModelEndpoint.azureAssistants
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload."
: ''
}`;
return sendResponse(req, res, messageData, errorMessage);
@ -248,7 +248,8 @@ const chatV1 = async (req, res) => {
}
const checkBalanceBeforeRun = async () => {
if (!isEnabled(process.env.CHECK_BALANCE)) {
const balance = req.app?.locals?.balance;
if (!balance?.enabled) {
return;
}
const transactions =
@ -378,8 +379,8 @@ const chatV1 = async (req, res) => {
body.additional_instructions ? `${body.additional_instructions}\n` : ''
}The user has uploaded ${imageCount} image${pluralized}.
Use the \`${ImageVisionTool.function.name}\` tool to retrieve ${
plural ? '' : 'a '
}detailed text description${pluralized} for ${plural ? 'each' : 'the'} image${pluralized}.`;
plural ? '' : 'a '
}detailed text description${pluralized} for ${plural ? 'each' : 'the'} image${pluralized}.`;
return files;
};
@ -575,6 +576,8 @@ const chatV1 = async (req, res) => {
thread_id,
model: assistant_id,
endpoint,
spec: endpointOption.spec,
iconURL: endpointOption.iconURL,
};
sendMessage(res, {

View file

@ -18,14 +18,14 @@ const {
saveAssistantMessage,
} = require('~/server/services/Threads');
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const { sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { createErrorHandler } = require('~/server/controllers/assistants/errors');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
const { sendMessage, sleep, countTokens } = require('~/server/utils');
const { createRunBody } = require('~/server/services/createRunBody');
const { getTransactions } = require('~/models/Transaction');
const checkBalance = require('~/models/checkBalance');
const { checkBalance } = require('~/models/balanceMethods');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { getModelMaxTokens } = require('~/utils');
@ -124,7 +124,8 @@ const chatV2 = async (req, res) => {
}
const checkBalanceBeforeRun = async () => {
if (!isEnabled(process.env.CHECK_BALANCE)) {
const balance = req.app?.locals?.balance;
if (!balance?.enabled) {
return;
}
const transactions =
@ -427,6 +428,8 @@ const chatV2 = async (req, res) => {
thread_id,
model: assistant_id,
endpoint,
spec: endpointOption.spec,
iconURL: endpointOption.iconURL,
};
sendMessage(res, {

View file

@ -88,8 +88,8 @@ const startServer = async () => {
app.use('/api/actions', routes.actions);
app.use('/api/keys', routes.keys);
app.use('/api/user', routes.user);
app.use('/api/search', routes.search);
app.use('/api/ask', routes.ask);
app.use('/api/search', routes.search);
app.use('/api/edit', routes.edit);
app.use('/api/messages', routes.messages);
app.use('/api/convos', routes.convos);

View file

@ -1,3 +1,4 @@
// abortMiddleware.js
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
@ -8,6 +9,68 @@ const { saveMessage, getConvo } = require('~/models');
const { abortRun } = require('./abortRun');
const { logger } = require('~/config');
const abortDataMap = new WeakMap();
function cleanupAbortController(abortKey) {
if (!abortControllers.has(abortKey)) {
return false;
}
const { abortController } = abortControllers.get(abortKey);
if (!abortController) {
abortControllers.delete(abortKey);
return true;
}
// 1. Check if this controller has any composed signals and clean them up
try {
// This creates a temporary composed signal to use for cleanup
const composedSignal = AbortSignal.any([abortController.signal]);
// Get all event types - in practice, AbortSignal typically only uses 'abort'
const eventTypes = ['abort'];
// First, execute a dummy listener removal to handle potential composed signals
for (const eventType of eventTypes) {
const dummyHandler = () => {};
composedSignal.addEventListener(eventType, dummyHandler);
composedSignal.removeEventListener(eventType, dummyHandler);
const listeners = composedSignal.listeners?.(eventType) || [];
for (const listener of listeners) {
composedSignal.removeEventListener(eventType, listener);
}
}
} catch (e) {
logger.debug(`Error cleaning up composed signals: ${e}`);
}
// 2. Abort the controller if not already aborted
if (!abortController.signal.aborted) {
abortController.abort();
}
// 3. Remove from registry
abortControllers.delete(abortKey);
// 4. Clean up any data stored in the WeakMap
if (abortDataMap.has(abortController)) {
abortDataMap.delete(abortController);
}
// 5. Clean up function references on the controller
if (abortController.getAbortData) {
abortController.getAbortData = null;
}
if (abortController.abortCompletion) {
abortController.abortCompletion = null;
}
return true;
}
async function abortMessage(req, res) {
let { abortKey, endpoint } = req.body;
@ -29,24 +92,24 @@ async function abortMessage(req, res) {
if (!abortController) {
return res.status(204).send({ message: 'Request not found' });
}
const finalEvent = await abortController.abortCompletion();
const finalEvent = await abortController.abortCompletion?.();
logger.debug(
`[abortMessage] ID: ${req.user.id} | ${req.user.email} | Aborted request: ` +
JSON.stringify({ abortKey }),
);
abortControllers.delete(abortKey);
cleanupAbortController(abortKey);
if (res.headersSent && finalEvent) {
return sendMessage(res, finalEvent);
}
res.setHeader('Content-Type', 'application/json');
res.send(JSON.stringify(finalEvent));
}
const handleAbort = () => {
return async (req, res) => {
const handleAbort = function () {
return async function (req, res) {
try {
if (isEnabled(process.env.LIMIT_CONCURRENT_MESSAGES)) {
await clearPendingReq({ userId: req.user.id });
@ -62,8 +125,48 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
const abortController = new AbortController();
const { endpointOption } = req.body;
// Store minimal data in WeakMap to avoid circular references
abortDataMap.set(abortController, {
getAbortDataFn: getAbortData,
userId: req.user.id,
endpoint: endpointOption.endpoint,
iconURL: endpointOption.iconURL,
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
});
// Replace the direct function reference with a wrapper that uses WeakMap
abortController.getAbortData = function () {
return getAbortData();
const data = abortDataMap.get(this);
if (!data || typeof data.getAbortDataFn !== 'function') {
return {};
}
try {
const result = data.getAbortDataFn();
// Create a copy without circular references
const cleanResult = { ...result };
// If userMessagePromise exists, break its reference to client
if (
cleanResult.userMessagePromise &&
typeof cleanResult.userMessagePromise.then === 'function'
) {
// Create a new promise that fulfills with the same result but doesn't reference the original
const originalPromise = cleanResult.userMessagePromise;
cleanResult.userMessagePromise = new Promise((resolve, reject) => {
originalPromise.then(
(result) => resolve({ ...result }),
(error) => reject(error),
);
});
}
return cleanResult;
} catch (err) {
logger.error('[abortController.getAbortData] Error:', err);
return {};
}
};
/**
@ -74,6 +177,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
sendMessage(res, { message: userMessage, created: true });
const abortKey = userMessage?.conversationId ?? req.user.id;
getReqData({ abortKey });
const prevRequest = abortControllers.get(abortKey);
const { overrideUserMessageId } = req?.body ?? {};
@ -81,34 +185,74 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
const data = prevRequest.abortController.getAbortData();
getReqData({ userMessage: data?.userMessage });
const addedAbortKey = `${abortKey}:${responseMessageId}`;
abortControllers.set(addedAbortKey, { abortController, ...endpointOption });
res.on('finish', function () {
abortControllers.delete(addedAbortKey);
});
// Store minimal options
const minimalOptions = {
endpoint: endpointOption.endpoint,
iconURL: endpointOption.iconURL,
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
};
abortControllers.set(addedAbortKey, { abortController, ...minimalOptions });
// Use a simple function for cleanup to avoid capturing context
const cleanupHandler = () => {
try {
cleanupAbortController(addedAbortKey);
} catch (e) {
// Ignore cleanup errors
}
};
res.on('finish', cleanupHandler);
return;
}
abortControllers.set(abortKey, { abortController, ...endpointOption });
// Store minimal options
const minimalOptions = {
endpoint: endpointOption.endpoint,
iconURL: endpointOption.iconURL,
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
};
res.on('finish', function () {
abortControllers.delete(abortKey);
});
abortControllers.set(abortKey, { abortController, ...minimalOptions });
// Use a simple function for cleanup to avoid capturing context
const cleanupHandler = () => {
try {
cleanupAbortController(abortKey);
} catch (e) {
// Ignore cleanup errors
}
};
res.on('finish', cleanupHandler);
};
// Define abortCompletion without capturing the entire parent scope
abortController.abortCompletion = async function () {
abortController.abort();
this.abort();
// Get data from WeakMap
const ctrlData = abortDataMap.get(this);
if (!ctrlData || !ctrlData.getAbortDataFn) {
return { final: true, conversation: {}, title: 'New Chat' };
}
// Get abort data using stored function
const { conversationId, userMessage, userMessagePromise, promptTokens, ...responseData } =
getAbortData();
ctrlData.getAbortDataFn();
const completionTokens = await countTokens(responseData?.text ?? '');
const user = req.user.id;
const user = ctrlData.userId;
const responseMessage = {
...responseData,
conversationId,
finish_reason: 'incomplete',
endpoint: endpointOption.endpoint,
iconURL: endpointOption.iconURL,
model: endpointOption.modelOptions?.model ?? endpointOption.model_parameters?.model,
endpoint: ctrlData.endpoint,
iconURL: ctrlData.iconURL,
model: ctrlData.modelOptions?.model ?? ctrlData.model_parameters?.model,
unfinished: false,
error: false,
isCreatedByUser: false,
@ -130,10 +274,12 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
if (userMessagePromise) {
const resolved = await userMessagePromise;
conversation = resolved?.conversation;
// Break reference to promise
resolved.conversation = null;
}
if (!conversation) {
conversation = await getConvo(req.user.id, conversationId);
conversation = await getConvo(user, conversationId);
}
return {
@ -148,6 +294,13 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
return { abortController, onStart };
};
/**
* @param {ServerResponse} res
* @param {ServerRequest} req
* @param {Error | unknown} error
* @param {Partial<TMessage> & { partialText?: string }} data
* @returns { Promise<void> }
*/
const handleAbortError = async (res, req, error, data) => {
if (error?.message?.includes('base64')) {
logger.error('[handleAbortError] Error in base64 encoding', {
@ -158,7 +311,7 @@ const handleAbortError = async (res, req, error, data) => {
} else {
logger.error('[handleAbortError] AI response error; aborting request:', error);
}
const { sender, conversationId, messageId, parentMessageId, partialText } = data;
const { sender, conversationId, messageId, parentMessageId, userMessageId, partialText } = data;
if (error.stack && error.stack.includes('google')) {
logger.warn(
@ -178,17 +331,30 @@ const handleAbortError = async (res, req, error, data) => {
errorText = `{"type":"${ErrorTypes.NO_SYSTEM_MESSAGES}"}`;
}
/**
* @param {string} partialText
* @returns {Promise<void>}
*/
const respondWithError = async (partialText) => {
const endpointOption = req.body?.endpointOption;
let options = {
sender,
messageId,
conversationId,
parentMessageId,
text: errorText,
shouldSaveMessage: true,
user: req.user.id,
spec: endpointOption?.spec,
iconURL: endpointOption?.iconURL,
modelLabel: endpointOption?.modelLabel,
shouldSaveMessage: userMessageId != null,
model: endpointOption?.modelOptions?.model || req.body?.model,
};
if (req.body?.agent_id) {
options.agent_id = req.body.agent_id;
}
if (partialText) {
options = {
...options,
@ -198,11 +364,12 @@ const handleAbortError = async (res, req, error, data) => {
};
}
// Create a simple callback without capturing parent scope
const callback = async () => {
if (abortControllers.has(conversationId)) {
const { abortController } = abortControllers.get(conversationId);
abortController.abort();
abortControllers.delete(conversationId);
try {
cleanupAbortController(conversationId);
} catch (e) {
// Ignore cleanup errors
}
};
@ -223,6 +390,7 @@ const handleAbortError = async (res, req, error, data) => {
module.exports = {
handleAbort,
createAbortController,
handleAbortError,
createAbortController,
cleanupAbortController,
};

View file

@ -1,6 +1,11 @@
const { parseCompactConvo, EModelEndpoint, isAgentsEndpoint } = require('librechat-data-provider');
const { getModelsConfig } = require('~/server/controllers/ModelController');
const {
parseCompactConvo,
EModelEndpoint,
isAgentsEndpoint,
EndpointURLs,
} = require('librechat-data-provider');
const azureAssistants = require('~/server/services/Endpoints/azureAssistants');
const { getModelsConfig } = require('~/server/controllers/ModelController');
const assistants = require('~/server/services/Endpoints/assistants');
const gptPlugins = require('~/server/services/Endpoints/gptPlugins');
const { processFiles } = require('~/server/services/Files/process');
@ -77,8 +82,9 @@ async function buildEndpointOption(req, res, next) {
}
try {
const isAgents = isAgentsEndpoint(endpoint);
const endpointFn = buildFunction[endpointType ?? endpoint];
const isAgents =
isAgentsEndpoint(endpoint) || req.baseUrl.startsWith(EndpointURLs[EModelEndpoint.agents]);
const endpointFn = buildFunction[isAgents ? EModelEndpoint.agents : (endpointType ?? endpoint)];
const builder = isAgents ? (...args) => endpointFn(req, ...args) : endpointFn;
// TODO: use object params

View file

@ -1,4 +1,4 @@
const Keyv = require('keyv');
const { Keyv } = require('keyv');
const uap = require('ua-parser-js');
const { ViolationTypes } = require('librechat-data-provider');
const { isEnabled, removePorts } = require('~/server/utils');

View file

@ -1,4 +1,4 @@
const { Time } = require('librechat-data-provider');
const { Time, CacheKeys } = require('librechat-data-provider');
const clearPendingReq = require('~/cache/clearPendingReq');
const { logViolation, getLogStores } = require('~/cache');
const { isEnabled } = require('~/server/utils');
@ -25,7 +25,7 @@ const {
* @throws {Error} Throws an error if the user exceeds the concurrent request limit.
*/
const concurrentLimiter = async (req, res, next) => {
const namespace = 'pending_req';
const namespace = CacheKeys.PENDING_REQ;
const cache = getLogStores(namespace);
if (!cache) {
return next();

View file

@ -8,6 +8,7 @@ const concurrentLimiter = require('./concurrentLimiter');
const validateEndpoint = require('./validateEndpoint');
const requireLocalAuth = require('./requireLocalAuth');
const canDeleteAccount = require('./canDeleteAccount');
const setBalanceConfig = require('./setBalanceConfig');
const requireLdapAuth = require('./requireLdapAuth');
const abortMiddleware = require('./abortMiddleware');
const checkInviteUser = require('./checkInviteUser');
@ -41,6 +42,7 @@ module.exports = {
requireLocalAuth,
canDeleteAccount,
validateEndpoint,
setBalanceConfig,
concurrentLimiter,
checkDomainAllowed,
validateMessageReq,

View file

@ -1,10 +1,9 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config');
const getEnvironmentVariables = () => {
@ -67,11 +66,9 @@ const createImportLimiters = () => {
},
};
if (isEnabled(process.env.USE_REDIS)) {
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for import rate limiters.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'import_ip_limiter:',

View file

@ -1,8 +1,7 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { removePorts, isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
@ -31,13 +30,10 @@ const limiterOptions = {
keyGenerator: removePorts,
};
if (isEnabled(process.env.USE_REDIS)) {
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for login rate limiter.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const store = new RedisStore({
sendCommand,
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'login_limiter:',
});
limiterOptions.store = store;

View file

@ -1,9 +1,8 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const denyRequest = require('~/server/middleware/denyRequest');
const ioredisClient = require('~/cache/ioredisClient');
const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
@ -63,11 +62,9 @@ const userLimiterOptions = {
},
};
if (isEnabled(process.env.USE_REDIS)) {
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for message rate limiters.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'message_ip_limiter:',

View file

@ -1,8 +1,7 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { removePorts, isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
@ -31,13 +30,10 @@ const limiterOptions = {
keyGenerator: removePorts,
};
if (isEnabled(process.env.USE_REDIS)) {
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for register rate limiter.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const store = new RedisStore({
sendCommand,
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'register_limiter:',
});
limiterOptions.store = store;

View file

@ -1,9 +1,8 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const { removePorts, isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
@ -36,13 +35,10 @@ const limiterOptions = {
keyGenerator: removePorts,
};
if (isEnabled(process.env.USE_REDIS)) {
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for reset password rate limiter.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const store = new RedisStore({
sendCommand,
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'reset_password_limiter:',
});
limiterOptions.store = store;

View file

@ -1,10 +1,9 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config');
const getEnvironmentVariables = () => {
@ -67,11 +66,9 @@ const createSTTLimiters = () => {
},
};
if (isEnabled(process.env.USE_REDIS)) {
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for STT rate limiters.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'stt_ip_limiter:',

View file

@ -1,10 +1,9 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config');
const handler = async (req, res) => {
@ -29,13 +28,10 @@ const limiterOptions = {
},
};
if (isEnabled(process.env.USE_REDIS)) {
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for tool call rate limiter.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const store = new RedisStore({
sendCommand,
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'tool_call_limiter:',
});
limiterOptions.store = store;

View file

@ -1,10 +1,9 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config');
const getEnvironmentVariables = () => {
@ -67,11 +66,9 @@ const createTTSLimiters = () => {
},
};
if (isEnabled(process.env.USE_REDIS)) {
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for TTS rate limiters.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'tts_ip_limiter:',

View file

@ -1,10 +1,9 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config');
const getEnvironmentVariables = () => {
@ -72,11 +71,9 @@ const createFileLimiters = () => {
},
};
if (isEnabled(process.env.USE_REDIS)) {
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for file upload rate limiters.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'file_upload_ip_limiter:',

View file

@ -1,9 +1,8 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const { removePorts, isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
@ -36,13 +35,10 @@ const limiterOptions = {
keyGenerator: removePorts,
};
if (isEnabled(process.env.USE_REDIS)) {
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for verify email rate limiter.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const store = new RedisStore({
sendCommand,
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'verify_email_limiter:',
});
limiterOptions.store = store;

View file

@ -1,39 +1,41 @@
const axios = require('axios');
const { ErrorTypes } = require('librechat-data-provider');
const { isEnabled } = require('~/server/utils');
const denyRequest = require('./denyRequest');
const { logger } = require('~/config');
async function moderateText(req, res, next) {
if (process.env.OPENAI_MODERATION === 'true') {
try {
const { text } = req.body;
if (!isEnabled(process.env.OPENAI_MODERATION)) {
return next();
}
try {
const { text } = req.body;
const response = await axios.post(
process.env.OPENAI_MODERATION_REVERSE_PROXY || 'https://api.openai.com/v1/moderations',
{
input: text,
const response = await axios.post(
process.env.OPENAI_MODERATION_REVERSE_PROXY || 'https://api.openai.com/v1/moderations',
{
input: text,
},
{
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${process.env.OPENAI_MODERATION_API_KEY}`,
},
{
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${process.env.OPENAI_MODERATION_API_KEY}`,
},
},
);
},
);
const results = response.data.results;
const flagged = results.some((result) => result.flagged);
const results = response.data.results;
const flagged = results.some((result) => result.flagged);
if (flagged) {
const type = ErrorTypes.MODERATION;
const errorMessage = { type };
return await denyRequest(req, res, errorMessage);
}
} catch (error) {
logger.error('Error in moderateText:', error);
const errorMessage = 'error in moderation check';
if (flagged) {
const type = ErrorTypes.MODERATION;
const errorMessage = { type };
return await denyRequest(req, res, errorMessage);
}
} catch (error) {
logger.error('Error in moderateText:', error);
const errorMessage = 'error in moderation check';
return await denyRequest(req, res, errorMessage);
}
next();
}

View file

@ -17,9 +17,9 @@ const checkAccess = async (user, permissionType, permissions, bodyProps = {}, ch
}
const role = await getRoleByName(user.role);
if (role && role[permissionType]) {
if (role && role.permissions && role.permissions[permissionType]) {
const hasAnyPermission = permissions.some((permission) => {
if (role[permissionType][permission]) {
if (role.permissions[permissionType][permission]) {
return true;
}

View file

@ -0,0 +1,91 @@
const { getBalanceConfig } = require('~/server/services/Config');
const Balance = require('~/models/Balance');
const { logger } = require('~/config');
/**
* Middleware to synchronize user balance settings with current balance configuration.
* @function
* @param {Object} req - Express request object containing user information.
* @param {Object} res - Express response object.
* @param {import('express').NextFunction} next - Next middleware function.
*/
const setBalanceConfig = async (req, res, next) => {
try {
const balanceConfig = await getBalanceConfig();
if (!balanceConfig?.enabled) {
return next();
}
if (balanceConfig.startBalance == null) {
return next();
}
const userId = req.user._id;
const userBalanceRecord = await Balance.findOne({ user: userId }).lean();
const updateFields = buildUpdateFields(balanceConfig, userBalanceRecord);
if (Object.keys(updateFields).length === 0) {
return next();
}
await Balance.findOneAndUpdate(
{ user: userId },
{ $set: updateFields },
{ upsert: true, new: true },
);
next();
} catch (error) {
logger.error('Error setting user balance:', error);
next(error);
}
};
/**
* Build an object containing fields that need updating
* @param {Object} config - The balance configuration
* @param {Object|null} userRecord - The user's current balance record, if any
* @returns {Object} Fields that need updating
*/
function buildUpdateFields(config, userRecord) {
const updateFields = {};
// Ensure user record has the required fields
if (!userRecord) {
updateFields.user = userRecord?.user;
updateFields.tokenCredits = config.startBalance;
}
if (userRecord?.tokenCredits == null && config.startBalance != null) {
updateFields.tokenCredits = config.startBalance;
}
const isAutoRefillConfigValid =
config.autoRefillEnabled &&
config.refillIntervalValue != null &&
config.refillIntervalUnit != null &&
config.refillAmount != null;
if (!isAutoRefillConfigValid) {
return updateFields;
}
if (userRecord?.autoRefillEnabled !== config.autoRefillEnabled) {
updateFields.autoRefillEnabled = config.autoRefillEnabled;
}
if (userRecord?.refillIntervalValue !== config.refillIntervalValue) {
updateFields.refillIntervalValue = config.refillIntervalValue;
}
if (userRecord?.refillIntervalUnit !== config.refillIntervalUnit) {
updateFields.refillIntervalUnit = config.refillIntervalUnit;
}
if (userRecord?.refillAmount !== config.refillAmount) {
updateFields.refillAmount = config.refillAmount;
}
return updateFields;
}
module.exports = setBalanceConfig;

View file

@ -1,5 +1,6 @@
const express = require('express');
const jwt = require('jsonwebtoken');
const { CacheKeys } = require('librechat-data-provider');
const { getAccessToken } = require('~/server/services/TokenService');
const { logger, getFlowStateManager } = require('~/config');
const { getLogStores } = require('~/cache');
@ -19,8 +20,8 @@ const JWT_SECRET = process.env.JWT_SECRET;
router.get('/:action_id/oauth/callback', async (req, res) => {
const { action_id } = req.params;
const { code, state } = req.query;
const flowManager = await getFlowStateManager(getLogStores);
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
let identifier = action_id;
try {
let decodedState;

View file

@ -58,7 +58,7 @@ router.post('/:agent_id', async (req, res) => {
}
let { domain } = metadata;
domain = await domainParser(req, domain, true);
domain = await domainParser(domain, true);
if (!domain) {
return res.status(400).json({ message: 'No domain provided' });
@ -164,7 +164,7 @@ router.delete('/:agent_id/:action_id', async (req, res) => {
return true;
});
domain = await domainParser(req, domain, true);
domain = await domainParser(domain, true);
if (!domain) {
return res.status(400).json({ message: 'No domain provided' });

View file

@ -2,7 +2,7 @@ const express = require('express');
const { PermissionTypes, Permissions } = require('librechat-data-provider');
const {
setHeaders,
handleAbort,
moderateText,
// validateModel,
generateCheckAccess,
validateConvoAccess,
@ -14,28 +14,37 @@ const addTitle = require('~/server/services/Endpoints/agents/title');
const router = express.Router();
router.post('/abort', handleAbort());
router.use(moderateText);
const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
router.use(checkAgentAccess);
router.use(validateConvoAccess);
router.use(buildEndpointOption);
router.use(setHeaders);
const controller = async (req, res, next) => {
await AgentController(req, res, next, initializeClient, addTitle);
};
/**
* @route POST /
* @route POST / (regular endpoint)
* @desc Chat with an assistant
* @access Public
* @param {express.Request} req - The request object, containing the request data.
* @param {express.Response} res - The response object, used to send back a response.
* @returns {void}
*/
router.post(
'/',
// validateModel,
checkAgentAccess,
validateConvoAccess,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AgentController(req, res, next, initializeClient, addTitle);
},
);
router.post('/', controller);
/**
* @route POST /:endpoint (ephemeral agents)
* @desc Chat with an assistant
* @access Public
* @param {express.Request} req - The request object, containing the request data.
* @param {express.Response} res - The response object, used to send back a response.
* @returns {void}
*/
router.post('/:endpoint', controller);
module.exports = router;

View file

@ -1,21 +1,40 @@
const express = require('express');
const router = express.Router();
const {
uaParser,
checkBan,
requireJwtAuth,
// concurrentLimiter,
// messageIpLimiter,
// messageUserLimiter,
messageIpLimiter,
concurrentLimiter,
messageUserLimiter,
} = require('~/server/middleware');
const { isEnabled } = require('~/server/utils');
const { v1 } = require('./v1');
const chat = require('./chat');
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
const router = express.Router();
router.use(requireJwtAuth);
router.use(checkBan);
router.use(uaParser);
router.use('/', v1);
router.use('/chat', chat);
const chatRouter = express.Router();
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
chatRouter.use(concurrentLimiter);
}
if (isEnabled(LIMIT_MESSAGE_IP)) {
chatRouter.use(messageIpLimiter);
}
if (isEnabled(LIMIT_MESSAGE_USER)) {
chatRouter.use(messageUserLimiter);
}
chatRouter.use('/', chat);
router.use('/chat', chatRouter);
module.exports = router;

View file

@ -1,4 +1,4 @@
const Keyv = require('keyv');
const { Keyv } = require('keyv');
const { KeyvFile } = require('keyv-file');
const { logger } = require('~/config');

View file

@ -11,8 +11,6 @@ const {
const router = express.Router();
router.post('/abort', handleAbort());
router.post(
'/',
validateEndpoint,

View file

@ -3,7 +3,6 @@ const AskController = require('~/server/controllers/AskController');
const { initializeClient } = require('~/server/services/Endpoints/custom');
const { addTitle } = require('~/server/services/Endpoints/openAI');
const {
handleAbort,
setHeaders,
validateModel,
validateEndpoint,
@ -12,8 +11,6 @@ const {
const router = express.Router();
router.post('/abort', handleAbort());
router.post(
'/',
validateEndpoint,

View file

@ -3,7 +3,6 @@ const AskController = require('~/server/controllers/AskController');
const { initializeClient, addTitle } = require('~/server/services/Endpoints/google');
const {
setHeaders,
handleAbort,
validateModel,
validateEndpoint,
buildEndpointOption,
@ -11,8 +10,6 @@ const {
const router = express.Router();
router.post('/abort', handleAbort());
router.post(
'/',
validateEndpoint,

View file

@ -20,7 +20,6 @@ const { logger } = require('~/config');
const router = express.Router();
router.use(moderateText);
router.post('/abort', handleAbort());
router.post(
'/',
@ -196,7 +195,8 @@ router.post(
logger.debug('[/ask/gptPlugins]', response);
const { conversation = {} } = await client.responsePromise;
const { conversation = {} } = await response.databasePromise;
delete response.databasePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';

View file

@ -1,10 +1,4 @@
const express = require('express');
const openAI = require('./openAI');
const custom = require('./custom');
const google = require('./google');
const anthropic = require('./anthropic');
const gptPlugins = require('./gptPlugins');
const { isEnabled } = require('~/server/utils');
const { EModelEndpoint } = require('librechat-data-provider');
const {
uaParser,
@ -15,6 +9,12 @@ const {
messageUserLimiter,
validateConvoAccess,
} = require('~/server/middleware');
const { isEnabled } = require('~/server/utils');
const gptPlugins = require('./gptPlugins');
const anthropic = require('./anthropic');
const custom = require('./custom');
const google = require('./google');
const openAI = require('./openAI');
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};

View file

@ -12,7 +12,6 @@ const {
const router = express.Router();
router.use(moderateText);
router.post('/abort', handleAbort());
router.post(
'/',

View file

@ -36,7 +36,7 @@ router.post('/:assistant_id', async (req, res) => {
}
let { domain } = metadata;
domain = await domainParser(req, domain, true);
domain = await domainParser(domain, true);
if (!domain) {
return res.status(400).json({ message: 'No domain provided' });
@ -172,7 +172,7 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => {
return true;
});
domain = await domainParser(req, domain, true);
domain = await domainParser(domain, true);
if (!domain) {
return res.status(400).json({ message: 'No domain provided' });

View file

@ -23,6 +23,7 @@ const {
checkInviteUser,
registerLimiter,
requireLdapAuth,
setBalanceConfig,
requireLocalAuth,
resetPasswordLimiter,
validateRegistration,
@ -40,6 +41,7 @@ router.post(
loginLimiter,
checkBan,
ldapAuth ? requireLdapAuth : requireLocalAuth,
setBalanceConfig,
loginController,
);
router.post('/refresh', refreshController);

View file

@ -4,6 +4,7 @@ const router = express.Router();
const {
setHeaders,
handleAbort,
moderateText,
// validateModel,
// validateEndpoint,
buildEndpointOption,
@ -12,7 +13,7 @@ const { initializeClient } = require('~/server/services/Endpoints/bedrock');
const AgentController = require('~/server/controllers/agents/request');
const addTitle = require('~/server/services/Endpoints/agents/title');
router.post('/abort', handleAbort());
router.use(moderateText);
/**
* @route POST /

View file

@ -1,19 +1,35 @@
const express = require('express');
const router = express.Router();
const {
uaParser,
checkBan,
requireJwtAuth,
// concurrentLimiter,
// messageIpLimiter,
// messageUserLimiter,
messageIpLimiter,
concurrentLimiter,
messageUserLimiter,
} = require('~/server/middleware');
const { isEnabled } = require('~/server/utils');
const chat = require('./chat');
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
const router = express.Router();
router.use(requireJwtAuth);
router.use(checkBan);
router.use(uaParser);
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
router.use(concurrentLimiter);
}
if (isEnabled(LIMIT_MESSAGE_IP)) {
router.use(messageIpLimiter);
}
if (isEnabled(LIMIT_MESSAGE_USER)) {
router.use(messageUserLimiter);
}
router.use('/chat', chat);
module.exports = router;

View file

@ -68,7 +68,6 @@ router.get('/', async function (req, res) {
!!process.env.EMAIL_PASSWORD &&
!!process.env.EMAIL_FROM,
passwordResetEnabled,
checkBalance: isEnabled(process.env.CHECK_BALANCE),
showBirthdayIcon:
isBirthday() ||
isEnabled(process.env.SHOW_BIRTHDAY_ICON) ||
@ -76,11 +75,13 @@ router.get('/', async function (req, res) {
helpAndFaqURL: process.env.HELP_AND_FAQ_URL || 'https://librechat.ai',
interface: req.app.locals.interfaceConfig,
modelSpecs: req.app.locals.modelSpecs,
balance: req.app.locals.balance,
sharedLinksEnabled,
publicSharedLinksEnabled,
analyticsGtmId: process.env.ANALYTICS_GTM_ID,
instanceProjectId: instanceProject._id.toString(),
bundlerURL: process.env.SANDPACK_BUNDLER_URL,
staticBundlerURL: process.env.SANDPACK_STATIC_BUNDLER_URL,
};
if (ldap) {

View file

@ -1,16 +1,17 @@
const multer = require('multer');
const express = require('express');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const { getConvosByPage, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation');
const { getConvosByCursor, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation');
const { forkConversation, duplicateConversation } = require('~/server/utils/import/fork');
const { storage, importFileFilter } = require('~/server/routes/files/multer');
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
const { importConversations } = require('~/server/utils/import');
const { createImportLimiters } = require('~/server/middleware');
const { deleteToolCalls } = require('~/models/ToolCall');
const { isEnabled, sleep } = require('~/server/utils');
const getLogStores = require('~/cache/getLogStores');
const { sleep } = require('~/server/utils');
const { logger } = require('~/config');
const assistantClients = {
[EModelEndpoint.azureAssistants]: require('~/server/services/Endpoints/azureAssistants'),
[EModelEndpoint.assistants]: require('~/server/services/Endpoints/assistants'),
@ -20,28 +21,30 @@ const router = express.Router();
router.use(requireJwtAuth);
router.get('/', async (req, res) => {
let pageNumber = req.query.pageNumber || 1;
pageNumber = parseInt(pageNumber, 10);
const limit = parseInt(req.query.limit, 10) || 25;
const cursor = req.query.cursor;
const isArchived = isEnabled(req.query.isArchived);
const search = req.query.search ? decodeURIComponent(req.query.search) : undefined;
const order = req.query.order || 'desc';
if (isNaN(pageNumber) || pageNumber < 1) {
return res.status(400).json({ error: 'Invalid page number' });
}
let pageSize = req.query.pageSize || 25;
pageSize = parseInt(pageSize, 10);
if (isNaN(pageSize) || pageSize < 1) {
return res.status(400).json({ error: 'Invalid page size' });
}
const isArchived = req.query.isArchived === 'true';
let tags;
if (req.query.tags) {
tags = Array.isArray(req.query.tags) ? req.query.tags : [req.query.tags];
} else {
tags = undefined;
}
res.status(200).send(await getConvosByPage(req.user.id, pageNumber, pageSize, isArchived, tags));
try {
const result = await getConvosByCursor(req.user.id, {
cursor,
limit,
isArchived,
tags,
search,
order,
});
res.status(200).json(result);
} catch (error) {
res.status(500).json({ error: 'Error fetching conversations' });
}
});
router.get('/:conversationId', async (req, res) => {
@ -76,22 +79,28 @@ router.post('/gen_title', async (req, res) => {
}
});
router.post('/clear', async (req, res) => {
router.delete('/', async (req, res) => {
let filter = {};
const { conversationId, source, thread_id, endpoint } = req.body.arg;
if (conversationId) {
filter = { conversationId };
// Prevent deletion of all conversations
if (!conversationId && !source && !thread_id && !endpoint) {
return res.status(400).json({
error: 'no parameters provided',
});
}
if (source === 'button' && !conversationId) {
if (conversationId) {
filter = { conversationId };
} else if (source === 'button') {
return res.status(200).send('No conversationId provided');
}
if (
typeof endpoint != 'undefined' &&
typeof endpoint !== 'undefined' &&
Object.prototype.propertyIsEnumerable.call(assistantClients, endpoint)
) {
/** @type {{ openai: OpenAI}} */
/** @type {{ openai: OpenAI }} */
const { openai } = await assistantClients[endpoint].initializeClient({ req, res });
try {
const response = await openai.beta.threads.del(thread_id);
@ -101,9 +110,6 @@ router.post('/clear', async (req, res) => {
}
}
// for debugging deletion source
// logger.debug('source:', source);
try {
const dbResponse = await deleteConvos(req.user.id, filter);
await deleteToolCalls(req.user.id, filter.conversationId);
@ -114,6 +120,17 @@ router.post('/clear', async (req, res) => {
}
});
router.delete('/all', async (req, res) => {
try {
const dbResponse = await deleteConvos(req.user.id, {});
await deleteToolCalls(req.user.id);
res.status(201).json(dbResponse);
} catch (error) {
logger.error('Error clearing conversations', error);
res.status(500).send('Error clearing conversations');
}
});
router.post('/update', async (req, res) => {
const update = req.body.arg;

View file

@ -3,7 +3,6 @@ const EditController = require('~/server/controllers/EditController');
const { initializeClient } = require('~/server/services/Endpoints/anthropic');
const {
setHeaders,
handleAbort,
validateModel,
validateEndpoint,
buildEndpointOption,
@ -11,8 +10,6 @@ const {
const router = express.Router();
router.post('/abort', handleAbort());
router.post(
'/',
validateEndpoint,

View file

@ -12,8 +12,6 @@ const {
const router = express.Router();
router.post('/abort', handleAbort());
router.post(
'/',
validateEndpoint,

View file

@ -3,7 +3,6 @@ const EditController = require('~/server/controllers/EditController');
const { initializeClient } = require('~/server/services/Endpoints/google');
const {
setHeaders,
handleAbort,
validateModel,
validateEndpoint,
buildEndpointOption,
@ -11,8 +10,6 @@ const {
const router = express.Router();
router.post('/abort', handleAbort());
router.post(
'/',
validateEndpoint,

View file

@ -2,7 +2,6 @@ const express = require('express');
const { getResponseSender } = require('librechat-data-provider');
const {
setHeaders,
handleAbort,
moderateText,
validateModel,
handleAbortError,
@ -19,7 +18,6 @@ const { logger } = require('~/config');
const router = express.Router();
router.use(moderateText);
router.post('/abort', handleAbort());
router.post(
'/',
@ -173,7 +171,8 @@ router.post(
logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response);
const { conversation = {} } = await client.responsePromise;
const { conversation = {} } = await response.databasePromise;
delete response.databasePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';

View file

@ -2,7 +2,6 @@ const express = require('express');
const EditController = require('~/server/controllers/EditController');
const { initializeClient } = require('~/server/services/Endpoints/openAI');
const {
handleAbort,
setHeaders,
validateModel,
validateEndpoint,
@ -12,7 +11,6 @@ const {
const router = express.Router();
router.use(moderateText);
router.post('/abort', handleAbort());
router.post(
'/',

View file

@ -2,7 +2,9 @@ const fs = require('fs').promises;
const express = require('express');
const { EnvVar } = require('@librechat/agents');
const {
Time,
isUUID,
CacheKeys,
FileSources,
EModelEndpoint,
isAgentsEndpoint,
@ -17,8 +19,11 @@ const {
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { refreshS3FileUrls } = require('~/server/services/Files/S3/crud');
const { getFiles, batchUpdateFiles } = require('~/models/File');
const { getAssistant } = require('~/models/Assistant');
const { getAgent } = require('~/models/Agent');
const { getFiles } = require('~/models/File');
const { getLogStores } = require('~/cache');
const { logger } = require('~/config');
const router = express.Router();
@ -26,6 +31,18 @@ const router = express.Router();
router.get('/', async (req, res) => {
try {
const files = await getFiles({ user: req.user.id });
if (req.app.locals.fileStrategy === FileSources.s3) {
try {
const cache = getLogStores(CacheKeys.S3_EXPIRY_INTERVAL);
const alreadyChecked = await cache.get(req.user.id);
if (!alreadyChecked) {
await refreshS3FileUrls(files, batchUpdateFiles);
await cache.set(req.user.id, true, Time.THIRTY_MINUTES);
}
} catch (error) {
logger.warn('[/files] Error refreshing S3 file URLs:', error);
}
}
res.status(200).send(files);
} catch (error) {
logger.error('[/files] Error getting files:', error);
@ -78,7 +95,7 @@ router.delete('/', async (req, res) => {
});
}
/* Handle entity unlinking even if no valid files to delete */
/* Handle agent unlinking even if no valid files to delete */
if (req.body.agent_id && req.body.tool_resource && dbFiles.length === 0) {
const agent = await getAgent({
id: req.body.agent_id,
@ -88,7 +105,21 @@ router.delete('/', async (req, res) => {
const agentFiles = files.filter((f) => toolResourceFiles.includes(f.file_id));
await processDeleteRequest({ req, files: agentFiles });
res.status(200).json({ message: 'File associations removed successfully' });
res.status(200).json({ message: 'File associations removed successfully from agent' });
return;
}
/* Handle assistant unlinking even if no valid files to delete */
if (req.body.assistant_id && req.body.tool_resource && dbFiles.length === 0) {
const assistant = await getAssistant({
id: req.body.assistant_id,
});
const toolResourceFiles = assistant.tool_resources?.[req.body.tool_resource]?.file_ids ?? [];
const assistantFiles = files.filter((f) => toolResourceFiles.includes(f.file_id));
await processDeleteRequest({ req, files: assistantFiles });
res.status(200).json({ message: 'File associations removed successfully from assistant' });
return;
}

View file

@ -10,6 +10,7 @@ const balance = require('./balance');
const plugins = require('./plugins');
const bedrock = require('./bedrock');
const actions = require('./actions');
const banner = require('./banner');
const search = require('./search');
const models = require('./models');
const convos = require('./convos');
@ -25,7 +26,6 @@ const edit = require('./edit');
const keys = require('./keys');
const user = require('./user');
const ask = require('./ask');
const banner = require('./banner');
module.exports = {
ask,
@ -38,13 +38,14 @@ module.exports = {
oauth,
files,
share,
banner,
agents,
bedrock,
convos,
search,
prompts,
config,
models,
bedrock,
prompts,
plugins,
actions,
presets,
@ -55,5 +56,4 @@ module.exports = {
assistants,
categories,
staticRoute,
banner,
};

View file

@ -10,12 +10,90 @@ const {
} = require('~/models');
const { findAllArtifacts, replaceArtifactContent } = require('~/server/services/Artifacts/update');
const { requireJwtAuth, validateMessageReq } = require('~/server/middleware');
const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc');
const { getConvosQueried } = require('~/models/Conversation');
const { countTokens } = require('~/server/utils');
const { Message } = require('~/models/Message');
const { logger } = require('~/config');
const router = express.Router();
router.use(requireJwtAuth);
router.get('/', async (req, res) => {
try {
const user = req.user.id ?? '';
const {
cursor = null,
sortBy = 'createdAt',
sortDirection = 'desc',
pageSize: pageSizeRaw,
conversationId,
messageId,
search,
} = req.query;
const pageSize = parseInt(pageSizeRaw, 10) || 25;
let response;
const sortField = ['endpoint', 'createdAt', 'updatedAt'].includes(sortBy)
? sortBy
: 'createdAt';
const sortOrder = sortDirection === 'asc' ? 1 : -1;
if (conversationId && messageId) {
const message = await Message.findOne({ conversationId, messageId, user: user }).lean();
response = { messages: message ? [message] : [], nextCursor: null };
} else if (conversationId) {
const filter = { conversationId, user: user };
if (cursor) {
filter[sortField] = sortOrder === 1 ? { $gt: cursor } : { $lt: cursor };
}
const messages = await Message.find(filter)
.sort({ [sortField]: sortOrder })
.limit(pageSize + 1)
.lean();
const nextCursor = messages.length > pageSize ? messages.pop()[sortField] : null;
response = { messages, nextCursor };
} else if (search) {
const searchResults = await Message.meiliSearch(search, undefined, true);
const messages = searchResults.hits || [];
const result = await getConvosQueried(req.user.id, messages, cursor);
const activeMessages = [];
for (let i = 0; i < messages.length; i++) {
let message = messages[i];
if (message.conversationId.includes('--')) {
message.conversationId = cleanUpPrimaryKeyValue(message.conversationId);
}
if (result.convoMap[message.conversationId]) {
const convo = result.convoMap[message.conversationId];
const dbMessage = await getMessage({ user, messageId: message.messageId });
activeMessages.push({
...message,
title: convo.title,
conversationId: message.conversationId,
model: convo.model,
isCreatedByUser: dbMessage?.isCreatedByUser,
endpoint: dbMessage?.endpoint,
iconURL: dbMessage?.iconURL,
});
}
}
response = { messages: activeMessages, nextCursor: null };
} else {
response = { messages: [], nextCursor: null };
}
res.status(200).json(response);
} catch (error) {
logger.error('Error fetching messages:', error);
res.status(500).json({ error: 'Internal server error' });
}
});
router.post('/artifact/:messageId', async (req, res) => {
try {
const { messageId } = req.params;

View file

@ -1,7 +1,13 @@
// file deepcode ignore NoRateLimitingForLogin: Rate limiting is handled by the `loginLimiter` middleware
const express = require('express');
const passport = require('passport');
const { loginLimiter, logHeaders, checkBan, checkDomainAllowed } = require('~/server/middleware');
const {
checkBan,
logHeaders,
loginLimiter,
setBalanceConfig,
checkDomainAllowed,
} = require('~/server/middleware');
const { setAuthTokens } = require('~/server/services/AuthService');
const { logger } = require('~/config');
const { chooseOpenIdStrategy } = require('~/server/utils/openidHelper');
@ -57,6 +63,7 @@ router.get(
session: false,
scope: ['openid', 'profile', 'email'],
}),
setBalanceConfig,
oauthHandler,
);
@ -81,6 +88,7 @@ router.get(
scope: ['public_profile'],
profileFields: ['id', 'email', 'name'],
}),
setBalanceConfig,
oauthHandler,
);
@ -113,6 +121,7 @@ router.get(
next(err);
}
},
setBalanceConfig,
oauthHandler,
);
@ -135,6 +144,7 @@ router.get(
session: false,
scope: ['user:email', 'read:user'],
}),
setBalanceConfig,
oauthHandler,
);
@ -157,6 +167,7 @@ router.get(
session: false,
scope: ['identify', 'email'],
}),
setBalanceConfig,
oauthHandler,
);
@ -177,6 +188,7 @@ router.post(
failureMessage: true,
session: false,
}),
setBalanceConfig,
oauthHandler,
);

View file

@ -48,7 +48,7 @@ router.put('/:roleName/prompts', checkAdmin, async (req, res) => {
const { roleName: _r } = req.params;
// TODO: TEMP, use a better parsing for roleName
const roleName = _r.toUpperCase();
/** @type {TRole['PROMPTS']} */
/** @type {TRole['permissions']['PROMPTS']} */
const updates = req.body;
try {
@ -59,10 +59,16 @@ router.put('/:roleName/prompts', checkAdmin, async (req, res) => {
return res.status(404).send({ message: 'Role not found' });
}
const currentPermissions =
role.permissions?.[PermissionTypes.PROMPTS] || role[PermissionTypes.PROMPTS] || {};
const mergedUpdates = {
[PermissionTypes.PROMPTS]: {
...role[PermissionTypes.PROMPTS],
...parsedUpdates,
permissions: {
...role.permissions,
[PermissionTypes.PROMPTS]: {
...currentPermissions,
...parsedUpdates,
},
},
};
@ -81,7 +87,7 @@ router.put('/:roleName/agents', checkAdmin, async (req, res) => {
const { roleName: _r } = req.params;
// TODO: TEMP, use a better parsing for roleName
const roleName = _r.toUpperCase();
/** @type {TRole['AGENTS']} */
/** @type {TRole['permissions']['AGENTS']} */
const updates = req.body;
try {
@ -92,17 +98,23 @@ router.put('/:roleName/agents', checkAdmin, async (req, res) => {
return res.status(404).send({ message: 'Role not found' });
}
const currentPermissions =
role.permissions?.[PermissionTypes.AGENTS] || role[PermissionTypes.AGENTS] || {};
const mergedUpdates = {
[PermissionTypes.AGENTS]: {
...role[PermissionTypes.AGENTS],
...parsedUpdates,
permissions: {
...role.permissions,
[PermissionTypes.AGENTS]: {
...currentPermissions,
...parsedUpdates,
},
},
};
const updatedRole = await updateRoleByName(roleName, mergedUpdates);
res.status(200).send(updatedRole);
} catch (error) {
return res.status(400).send({ message: 'Invalid prompt permissions.', error: error.errors });
return res.status(400).send({ message: 'Invalid agent permissions.', error: error.errors });
}
});

View file

@ -1,93 +1,17 @@
const Keyv = require('keyv');
const express = require('express');
const { MeiliSearch } = require('meilisearch');
const { Conversation, getConvosQueried } = require('~/models/Conversation');
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc');
const { reduceHits } = require('~/lib/utils/reduceHits');
const { isEnabled } = require('~/server/utils');
const { Message } = require('~/models/Message');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config');
const router = express.Router();
const expiration = 60 * 1000;
const cache = isEnabled(process.env.USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: 'search', ttl: expiration });
router.use(requireJwtAuth);
router.get('/sync', async function (req, res) {
await Message.syncWithMeili();
await Conversation.syncWithMeili();
res.send('synced');
});
router.get('/', async function (req, res) {
try {
let user = req.user.id ?? '';
const { q } = req.query;
const pageNumber = req.query.pageNumber || 1;
const key = `${user}:search:${q}`;
const cached = await cache.get(key);
if (cached) {
logger.debug('[/search] cache hit: ' + key);
const { pages, pageSize, messages } = cached;
res
.status(200)
.send({ conversations: cached[pageNumber], pages, pageNumber, pageSize, messages });
return;
}
const messages = (await Message.meiliSearch(q, undefined, true)).hits;
const titles = (await Conversation.meiliSearch(q)).hits;
const sortedHits = reduceHits(messages, titles);
const result = await getConvosQueried(user, sortedHits, pageNumber);
const activeMessages = [];
for (let i = 0; i < messages.length; i++) {
let message = messages[i];
if (message.conversationId.includes('--')) {
message.conversationId = cleanUpPrimaryKeyValue(message.conversationId);
}
if (result.convoMap[message.conversationId]) {
const convo = result.convoMap[message.conversationId];
const { title, chatGptLabel, model } = convo;
message = { ...message, ...{ title, chatGptLabel, model } };
activeMessages.push(message);
}
}
result.messages = activeMessages;
if (result.cache) {
result.cache.messages = activeMessages;
cache.set(key, result.cache, expiration);
delete result.cache;
}
delete result.convoMap;
res.status(200).send(result);
} catch (error) {
logger.error('[/search] Error while searching messages & conversations', error);
res.status(500).send({ message: 'Error searching' });
}
});
router.get('/test', async function (req, res) {
const { q } = req.query;
const messages = (
await Message.meiliSearch(q, { attributesToHighlight: ['text'] }, true)
).hits.map((message) => {
const { _formatted, ...rest } = message;
return { ...rest, searchResult: true, text: _formatted.text };
});
res.send(messages);
});
router.get('/enable', async function (req, res) {
let result = false;
if (!isEnabled(process.env.SEARCH)) {
return res.send(false);
}
try {
const client = new MeiliSearch({
host: process.env.MEILI_HOST,
@ -95,8 +19,7 @@ router.get('/enable', async function (req, res) {
});
const { status } = await client.health();
result = status === 'available' && !!process.env.SEARCH;
return res.send(result);
return res.send(status === 'available');
} catch (error) {
return res.send(false);
}

View file

@ -13,7 +13,6 @@ const {
actionDomainSeparator,
} = require('librechat-data-provider');
const { refreshAccessToken } = require('~/server/services/TokenService');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { logger, getFlowStateManager, sendEvent } = require('~/config');
const { encryptV2, decryptV2 } = require('~/server/utils/crypto');
const { getActions, deleteActions } = require('~/models/Action');
@ -51,7 +50,7 @@ const validateAndUpdateTool = async ({ req, tool, assistant_id }) => {
return null;
}
const parsedDomain = await domainParser(req, domain, true);
const parsedDomain = await domainParser(domain, true);
if (!parsedDomain) {
return null;
@ -67,16 +66,14 @@ const validateAndUpdateTool = async ({ req, tool, assistant_id }) => {
*
* Necessary due to `[a-zA-Z0-9_-]*` Regex Validation, limited to a 64-character maximum.
*
* @param {Express.Request} req - The Express Request object.
* @param {string} domain - The domain name to encode/decode.
* @param {boolean} inverse - False to decode from base64, true to encode to base64.
* @returns {Promise<string>} Encoded or decoded domain string.
*/
async function domainParser(req, domain, inverse = false) {
async function domainParser(domain, inverse = false) {
if (!domain) {
return;
}
const domainsCache = getLogStores(CacheKeys.ENCODED_DOMAINS);
const cachedDomain = await domainsCache.get(domain);
if (inverse && cachedDomain) {
@ -123,47 +120,39 @@ async function loadActionSets(searchParams) {
* Creates a general tool for an entire action set.
*
* @param {Object} params - The parameters for loading action sets.
* @param {ServerRequest} params.req
* @param {string} params.userId
* @param {ServerResponse} params.res
* @param {Action} params.action - The action set. Necessary for decrypting authentication values.
* @param {ActionRequest} params.requestBuilder - The ActionRequest builder class to execute the API call.
* @param {string | undefined} [params.name] - The name of the tool.
* @param {string | undefined} [params.description] - The description for the tool.
* @param {import('zod').ZodTypeAny | undefined} [params.zodSchema] - The Zod schema for tool input validation/definition
* @param {{ oauth_client_id?: string; oauth_client_secret?: string; }} params.encrypted - The encrypted values for the action.
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
*/
async function createActionTool({
req,
userId,
res,
action,
requestBuilder,
zodSchema,
name,
description,
encrypted,
}) {
const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain);
if (!isDomainAllowed) {
return null;
}
const encrypted = {
oauth_client_id: action.metadata.oauth_client_id,
oauth_client_secret: action.metadata.oauth_client_secret,
};
action.metadata = await decryptMetadata(action.metadata);
/** @type {(toolInput: Object | string, config: GraphRunnableConfig) => Promise<unknown>} */
const _call = async (toolInput, config) => {
try {
/** @type {import('librechat-data-provider').ActionMetadataRuntime} */
const metadata = action.metadata;
const executor = requestBuilder.createExecutor();
const preparedExecutor = executor.setParams(toolInput);
const preparedExecutor = executor.setParams(toolInput ?? {});
if (metadata.auth && metadata.auth.type !== AuthTypeEnum.None) {
try {
if (metadata.auth.type === AuthTypeEnum.OAuth && metadata.auth.authorization_url) {
const action_id = action.action_id;
const identifier = `${req.user.id}:${action.action_id}`;
const identifier = `${userId}:${action.action_id}`;
const requestLogin = async () => {
const { args: _args, stepId, ...toolCall } = config.toolCall ?? {};
if (!stepId) {
@ -171,7 +160,7 @@ async function createActionTool({
}
const statePayload = {
nonce: nanoid(),
user: req.user.id,
user: userId,
action_id,
};
@ -198,26 +187,33 @@ async function createActionTool({
expires_at: Date.now() + Time.TWO_MINUTES,
},
};
const flowManager = await getFlowStateManager(getLogStores);
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
await flowManager.createFlowWithHandler(
`${identifier}:login`,
`${identifier}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`,
'oauth_login',
async () => {
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
logger.debug('Sent OAuth login request to client', { action_id, identifier });
return true;
},
config?.signal,
);
logger.debug('Waiting for OAuth Authorization response', { action_id, identifier });
const result = await flowManager.createFlow(identifier, 'oauth', {
state: stateToken,
userId: req.user.id,
client_url: metadata.auth.client_url,
redirect_uri: `${process.env.DOMAIN_CLIENT}/api/actions/${action_id}/oauth/callback`,
/** Encrypted values */
encrypted_oauth_client_id: encrypted.oauth_client_id,
encrypted_oauth_client_secret: encrypted.oauth_client_secret,
});
const result = await flowManager.createFlow(
identifier,
'oauth',
{
state: stateToken,
userId: userId,
client_url: metadata.auth.client_url,
redirect_uri: `${process.env.DOMAIN_CLIENT}/api/actions/${action_id}/oauth/callback`,
/** Encrypted values */
encrypted_oauth_client_id: encrypted.oauth_client_id,
encrypted_oauth_client_secret: encrypted.oauth_client_secret,
},
config?.signal,
);
logger.debug('Received OAuth Authorization response', { action_id, identifier });
data.delta.auth = undefined;
data.delta.expires_at = undefined;
@ -235,10 +231,10 @@ async function createActionTool({
};
const tokenPromises = [];
tokenPromises.push(findToken({ userId: req.user.id, type: 'oauth', identifier }));
tokenPromises.push(findToken({ userId, type: 'oauth', identifier }));
tokenPromises.push(
findToken({
userId: req.user.id,
userId,
type: 'oauth_refresh',
identifier: `${identifier}:refresh`,
}),
@ -261,18 +257,20 @@ async function createActionTool({
const refresh_token = await decryptV2(refreshTokenData.token);
const refreshTokens = async () =>
await refreshAccessToken({
userId,
identifier,
refresh_token,
userId: req.user.id,
client_url: metadata.auth.client_url,
encrypted_oauth_client_id: encrypted.oauth_client_id,
encrypted_oauth_client_secret: encrypted.oauth_client_secret,
});
const flowManager = await getFlowStateManager(getLogStores);
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
const refreshData = await flowManager.createFlowWithHandler(
`${identifier}:refresh`,
'oauth_refresh',
refreshTokens,
config?.signal,
);
metadata.oauth_access_token = refreshData.access_token;
if (refreshData.refresh_token) {
@ -308,9 +306,8 @@ async function createActionTool({
}
return response.data;
} catch (error) {
const logMessage = `API call to ${action.metadata.domain} failed`;
logAxiosError({ message: logMessage, error });
throw error;
const message = `API call to ${action.metadata.domain} failed:`;
return logAxiosError({ message, error });
}
};
@ -327,6 +324,27 @@ async function createActionTool({
};
}
/**
* Encrypts a sensitive value.
* @param {string} value
* @returns {Promise<string>}
*/
async function encryptSensitiveValue(value) {
// Encode API key to handle special characters like ":"
const encodedValue = encodeURIComponent(value);
return await encryptV2(encodedValue);
}
/**
* Decrypts a sensitive value.
* @param {string} value
* @returns {Promise<string>}
*/
async function decryptSensitiveValue(value) {
const decryptedValue = await decryptV2(value);
return decodeURIComponent(decryptedValue);
}
/**
* Encrypts sensitive metadata values for an action.
*
@ -339,17 +357,19 @@ async function encryptMetadata(metadata) {
// ServiceHttp
if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) {
if (metadata.api_key) {
encryptedMetadata.api_key = await encryptV2(metadata.api_key);
encryptedMetadata.api_key = await encryptSensitiveValue(metadata.api_key);
}
}
// OAuth
else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) {
if (metadata.oauth_client_id) {
encryptedMetadata.oauth_client_id = await encryptV2(metadata.oauth_client_id);
encryptedMetadata.oauth_client_id = await encryptSensitiveValue(metadata.oauth_client_id);
}
if (metadata.oauth_client_secret) {
encryptedMetadata.oauth_client_secret = await encryptV2(metadata.oauth_client_secret);
encryptedMetadata.oauth_client_secret = await encryptSensitiveValue(
metadata.oauth_client_secret,
);
}
}
@ -368,17 +388,19 @@ async function decryptMetadata(metadata) {
// ServiceHttp
if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) {
if (metadata.api_key) {
decryptedMetadata.api_key = await decryptV2(metadata.api_key);
decryptedMetadata.api_key = await decryptSensitiveValue(metadata.api_key);
}
}
// OAuth
else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) {
if (metadata.oauth_client_id) {
decryptedMetadata.oauth_client_id = await decryptV2(metadata.oauth_client_id);
decryptedMetadata.oauth_client_id = await decryptSensitiveValue(metadata.oauth_client_id);
}
if (metadata.oauth_client_secret) {
decryptedMetadata.oauth_client_secret = await decryptV2(metadata.oauth_client_secret);
decryptedMetadata.oauth_client_secret = await decryptSensitiveValue(
metadata.oauth_client_secret,
);
}
}

View file

@ -78,20 +78,20 @@ describe('domainParser', () => {
// Non-azure request
it('does not return domain as is if not azure', async () => {
const domain = `example.com${actionDomainSeparator}test${actionDomainSeparator}`;
const result1 = await domainParser(reqNoAzure, domain, false);
const result2 = await domainParser(reqNoAzure, domain, true);
const result1 = await domainParser(domain, false);
const result2 = await domainParser(domain, true);
expect(result1).not.toEqual(domain);
expect(result2).not.toEqual(domain);
});
// Test for Empty or Null Inputs
it('returns undefined for null domain input', async () => {
const result = await domainParser(req, null, true);
const result = await domainParser(null, true);
expect(result).toBeUndefined();
});
it('returns undefined for empty domain input', async () => {
const result = await domainParser(req, '', true);
const result = await domainParser('', true);
expect(result).toBeUndefined();
});
@ -102,7 +102,7 @@ describe('domainParser', () => {
.toString('base64')
.substring(0, Constants.ENCODED_DOMAIN_LENGTH);
await domainParser(req, domain, true);
await domainParser(domain, true);
const cachedValue = await globalCache[encodedDomain];
expect(cachedValue).toEqual(Buffer.from(domain).toString('base64'));
@ -112,14 +112,14 @@ describe('domainParser', () => {
it('encodes domain exactly at threshold without modification', async () => {
const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - TLD.length) + TLD;
const expected = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, domain, true);
const result = await domainParser(domain, true);
expect(result).toEqual(expected);
});
it('encodes domain just below threshold without modification', async () => {
const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - 1 - TLD.length) + TLD;
const expected = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, domain, true);
const result = await domainParser(domain, true);
expect(result).toEqual(expected);
});
@ -129,7 +129,7 @@ describe('domainParser', () => {
const encodedDomain = Buffer.from(unicodeDomain)
.toString('base64')
.substring(0, Constants.ENCODED_DOMAIN_LENGTH);
const result = await domainParser(req, unicodeDomain, true);
const result = await domainParser(unicodeDomain, true);
expect(result).toEqual(encodedDomain);
});
@ -139,7 +139,6 @@ describe('domainParser', () => {
globalCache[encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH)] = encodedDomain; // Simulate caching
const result = await domainParser(
req,
encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH),
false,
);
@ -150,27 +149,27 @@ describe('domainParser', () => {
it('returns domain with replaced separators if no cached domain exists', async () => {
const domain = 'example.com';
const withSeparator = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, withSeparator, false);
const result = await domainParser(withSeparator, false);
expect(result).toEqual(domain);
});
it('returns domain with replaced separators when inverse is false and under encoding length', async () => {
const domain = 'examp.com';
const withSeparator = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, withSeparator, false);
const result = await domainParser(withSeparator, false);
expect(result).toEqual(domain);
});
it('replaces periods with actionDomainSeparator when inverse is true and under encoding length', async () => {
const domain = 'examp.com';
const expected = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, domain, true);
const result = await domainParser(domain, true);
expect(result).toEqual(expected);
});
it('encodes domain when length is above threshold and inverse is true', async () => {
const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH + 1).concat('.com');
const result = await domainParser(req, domain, true);
const result = await domainParser(domain, true);
expect(result).not.toEqual(domain);
expect(result.length).toBeLessThanOrEqual(Constants.ENCODED_DOMAIN_LENGTH);
});
@ -180,20 +179,20 @@ describe('domainParser', () => {
const encodedDomain = Buffer.from(
originalDomain.replace(/\./g, actionDomainSeparator),
).toString('base64');
const result = await domainParser(req, encodedDomain, false);
const result = await domainParser(encodedDomain, false);
expect(result).toEqual(encodedDomain);
});
it('decodes encoded value if cached and encoded value is provided, and inverse is false', async () => {
const originalDomain = 'example.com';
const encodedDomain = await domainParser(req, originalDomain, true);
const result = await domainParser(req, encodedDomain, false);
const encodedDomain = await domainParser(originalDomain, true);
const result = await domainParser(encodedDomain, false);
expect(result).toEqual(originalDomain);
});
it('handles invalid base64 encoded values gracefully', async () => {
const invalidBase64Domain = 'not_base64_encoded';
const result = await domainParser(req, invalidBase64Domain, false);
const result = await domainParser(invalidBase64Domain, false);
expect(result).toEqual(invalidBase64Domain);
});
});

View file

@ -9,15 +9,16 @@ const { checkVariables, checkHealth, checkConfig, checkAzureVariables } = requir
const { azureAssistantsDefaults, assistantsConfigSetup } = require('./start/assistants');
const { initializeAzureBlobService } = require('./Files/Azure/initialize');
const { initializeFirebase } = require('./Files/Firebase/initialize');
const { initializeS3 } = require('./Files/S3/initialize');
const loadCustomConfig = require('./Config/loadCustomConfig');
const handleRateLimits = require('./Config/handleRateLimits');
const { loadDefaultInterface } = require('./start/interface');
const { azureConfigSetup } = require('./start/azureOpenAI');
const { processModelSpecs } = require('./start/modelSpecs');
const { initializeS3 } = require('./Files/S3/initialize');
const { loadAndFormatTools } = require('./ToolService');
const { agentsConfigSetup } = require('./start/agents');
const { initializeRoles } = require('~/models/Role');
const { isEnabled } = require('~/server/utils');
const { getMCPManager } = require('~/config');
const paths = require('~/config/paths');
@ -29,7 +30,7 @@ const paths = require('~/config/paths');
*/
const AppService = async (app) => {
await initializeRoles();
/** @type {TCustomConfig}*/
/** @type {TCustomConfig} */
const config = (await loadCustomConfig()) ?? {};
const configDefaults = getConfigDefaults();
@ -37,6 +38,11 @@ const AppService = async (app) => {
const filteredTools = config.filteredTools;
const includedTools = config.includedTools;
const fileStrategy = config.fileStrategy ?? configDefaults.fileStrategy;
const startBalance = process.env.START_BALANCE;
const balance = config.balance ?? {
enabled: isEnabled(process.env.CHECK_BALANCE),
startBalance: startBalance ? parseInt(startBalance, 10) : undefined,
};
const imageOutputType = config?.imageOutputType ?? configDefaults.imageOutputType;
process.env.CDN_PROVIDER = fileStrategy;
@ -46,13 +52,13 @@ const AppService = async (app) => {
if (fileStrategy === FileSources.firebase) {
initializeFirebase();
} else if (fileStrategy === FileSources.azure) {
} else if (fileStrategy === FileSources.azure_blob) {
initializeAzureBlobService();
} else if (fileStrategy === FileSources.s3) {
initializeS3();
}
/** @type {Record<string, FunctionTool} */
/** @type {Record<string, FunctionTool>} */
const availableTools = loadAndFormatTools({
adminFilter: filteredTools,
adminIncluded: includedTools,
@ -60,7 +66,7 @@ const AppService = async (app) => {
});
if (config.mcpServers != null) {
const mcpManager = await getMCPManager();
const mcpManager = getMCPManager();
await mcpManager.initializeMCP(config.mcpServers, processMCPEnv);
await mcpManager.mapAvailableTools(availableTools);
}
@ -79,6 +85,7 @@ const AppService = async (app) => {
availableTools,
imageOutputType,
interfaceConfig,
balance,
};
if (!Object.keys(config).length) {
@ -139,7 +146,7 @@ const AppService = async (app) => {
...defaultLocals,
fileConfig: config?.fileConfig,
secureImageLinks: config?.secureImageLinks,
modelSpecs: processModelSpecs(endpoints, config.modelSpecs),
modelSpecs: processModelSpecs(endpoints, config.modelSpecs, interfaceConfig),
...endpointLocals,
};
};

View file

@ -15,6 +15,9 @@ jest.mock('./Config/loadCustomConfig', () => {
Promise.resolve({
registration: { socialLogins: ['testLogin'] },
fileStrategy: 'testStrategy',
balance: {
enabled: true,
},
}),
);
});
@ -124,6 +127,9 @@ describe('AppService', () => {
imageOutputType: expect.any(String),
fileConfig: undefined,
secureImageLinks: undefined,
balance: { enabled: true },
filteredTools: undefined,
includedTools: undefined,
});
});
@ -341,9 +347,6 @@ describe('AppService', () => {
process.env.FILE_UPLOAD_USER_MAX = 'initialUserMax';
process.env.FILE_UPLOAD_USER_WINDOW = 'initialUserWindow';
// Mock a custom configuration without specific rate limits
require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({}));
await AppService(app);
// Verify that process.env falls back to the initial values
@ -404,9 +407,6 @@ describe('AppService', () => {
process.env.IMPORT_USER_MAX = 'initialUserMax';
process.env.IMPORT_USER_WINDOW = 'initialUserWindow';
// Mock a custom configuration without specific rate limits
require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({}));
await AppService(app);
// Verify that process.env falls back to the initial values
@ -445,13 +445,27 @@ describe('AppService updating app.locals and issuing warnings', () => {
expect(app.locals.availableTools).toBeDefined();
expect(app.locals.fileStrategy).toEqual(FileSources.local);
expect(app.locals.socialLogins).toEqual(defaultSocialLogins);
expect(app.locals.balance).toEqual(
expect.objectContaining({
enabled: false,
startBalance: undefined,
}),
);
});
it('should update app.locals with values from loadCustomConfig', async () => {
// Mock loadCustomConfig to return a specific config object
// Mock loadCustomConfig to return a specific config object with a complete balance config
const customConfig = {
fileStrategy: 'firebase',
registration: { socialLogins: ['testLogin'] },
balance: {
enabled: false,
startBalance: 5000,
autoRefillEnabled: true,
refillIntervalValue: 15,
refillIntervalUnit: 'hours',
refillAmount: 5000,
},
};
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
Promise.resolve(customConfig),
@ -464,6 +478,7 @@ describe('AppService updating app.locals and issuing warnings', () => {
expect(app.locals.availableTools).toBeDefined();
expect(app.locals.fileStrategy).toEqual(customConfig.fileStrategy);
expect(app.locals.socialLogins).toEqual(customConfig.registration.socialLogins);
expect(app.locals.balance).toEqual(customConfig.balance);
});
it('should apply the assistants endpoint configuration correctly to app.locals', async () => {

View file

@ -56,7 +56,7 @@ const logoutUser = async (req, refreshToken) => {
try {
req.session.destroy();
} catch (destroyErr) {
logger.error('[logoutUser] Failed to destroy session.', destroyErr);
logger.debug('[logoutUser] Failed to destroy session.', destroyErr);
}
return { status: 200, message: 'Logout successful' };
@ -91,7 +91,7 @@ const sendVerificationEmail = async (user) => {
subject: 'Verify your email',
payload: {
appName: process.env.APP_TITLE || 'LibreChat',
name: user.name,
name: user.name || user.username || user.email,
verificationLink: verificationLink,
year: new Date().getFullYear(),
},
@ -278,7 +278,7 @@ const requestPasswordReset = async (req) => {
subject: 'Password Reset Request',
payload: {
appName: process.env.APP_TITLE || 'LibreChat',
name: user.name,
name: user.name || user.username || user.email,
link: link,
year: new Date().getFullYear(),
},
@ -331,7 +331,7 @@ const resetPassword = async (userId, token, password) => {
subject: 'Password Reset Successfully',
payload: {
appName: process.env.APP_TITLE || 'LibreChat',
name: user.name,
name: user.name || user.username || user.email,
year: new Date().getFullYear(),
},
template: 'passwordReset.handlebars',
@ -414,7 +414,7 @@ const resendVerificationEmail = async (req) => {
subject: 'Verify your email',
payload: {
appName: process.env.APP_TITLE || 'LibreChat',
name: user.name,
name: user.name || user.username || user.email,
verificationLink: verificationLink,
year: new Date().getFullYear(),
},

View file

@ -1,5 +1,5 @@
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const { normalizeEndpointName } = require('~/server/utils');
const { normalizeEndpointName, isEnabled } = require('~/server/utils');
const loadCustomConfig = require('./loadCustomConfig');
const getLogStores = require('~/cache/getLogStores');
@ -23,6 +23,26 @@ async function getCustomConfig() {
return customConfig;
}
/**
* Retrieves the configuration object
* @function getBalanceConfig
* @returns {Promise<TCustomConfig['balance'] | null>}
* */
async function getBalanceConfig() {
const isLegacyEnabled = isEnabled(process.env.CHECK_BALANCE);
const startBalance = process.env.START_BALANCE;
/** @type {TCustomConfig['balance']} */
const config = {
enabled: isLegacyEnabled,
startBalance: startBalance != null && startBalance ? parseInt(startBalance, 10) : undefined,
};
const customConfig = await getCustomConfig();
if (!customConfig) {
return config;
}
return { ...config, ...(customConfig?.['balance'] ?? {}) };
}
/**
*
* @param {string | EModelEndpoint} endpoint
@ -40,4 +60,4 @@ const getCustomEndpointConfig = async (endpoint) => {
);
};
module.exports = { getCustomConfig, getCustomEndpointConfig };
module.exports = { getCustomConfig, getBalanceConfig, getCustomEndpointConfig };

View file

@ -33,10 +33,12 @@ async function getEndpointsConfig(req) {
};
}
if (mergedConfig[EModelEndpoint.agents] && req.app.locals?.[EModelEndpoint.agents]) {
const { disableBuilder, capabilities, ..._rest } = req.app.locals[EModelEndpoint.agents];
const { disableBuilder, capabilities, allowedProviders, ..._rest } =
req.app.locals[EModelEndpoint.agents];
mergedConfig[EModelEndpoint.agents] = {
...mergedConfig[EModelEndpoint.agents],
allowedProviders,
disableBuilder,
capabilities,
};

View file

@ -1,12 +1,15 @@
const { isAgentsEndpoint, Constants } = require('librechat-data-provider');
const { loadAgent } = require('~/models/Agent');
const { logger } = require('~/config');
const buildOptions = (req, endpoint, parsedBody) => {
const buildOptions = (req, endpoint, parsedBody, endpointType) => {
const { spec, iconURL, agent_id, instructions, maxContextTokens, ...model_parameters } =
parsedBody;
const agentPromise = loadAgent({
req,
agent_id,
agent_id: isAgentsEndpoint(endpoint) ? agent_id : Constants.EPHEMERAL_AGENT_ID,
endpoint,
model_parameters,
}).catch((error) => {
logger.error(`[/agents/:${agent_id}] Error retrieving agent during build options step`, error);
return undefined;
@ -17,6 +20,7 @@ const buildOptions = (req, endpoint, parsedBody) => {
iconURL,
endpoint,
agent_id,
endpointType,
instructions,
maxContextTokens,
model_parameters,

View file

@ -1,8 +1,12 @@
const { createContentAggregator, Providers } = require('@librechat/agents');
const {
Constants,
ErrorTypes,
EModelEndpoint,
EToolResources,
getResponseSender,
AgentCapabilities,
replaceSpecialVars,
providerEndpointMap,
} = require('librechat-data-provider');
const {
@ -39,12 +43,19 @@ const providerConfigMap = {
};
/**
* @param {ServerRequest} req
* @param {Promise<Array<MongoFile | null>> | undefined} _attachments
* @param {AgentToolResources | undefined} _tool_resources
* @param {Object} params
* @param {ServerRequest} params.req
* @param {Promise<Array<MongoFile | null>> | undefined} [params.attachments]
* @param {Set<string>} params.requestFileSet
* @param {AgentToolResources | undefined} [params.tool_resources]
* @returns {Promise<{ attachments: Array<MongoFile | undefined> | undefined, tool_resources: AgentToolResources | undefined }>}
*/
const primeResources = async (req, _attachments, _tool_resources) => {
const primeResources = async ({
req,
attachments: _attachments,
tool_resources: _tool_resources,
requestFileSet,
}) => {
try {
/** @type {Array<MongoFile | undefined> | undefined} */
let attachments;
@ -52,7 +63,7 @@ const primeResources = async (req, _attachments, _tool_resources) => {
const isOCREnabled = (req.app.locals?.[EModelEndpoint.agents]?.capabilities ?? []).includes(
AgentCapabilities.ocr,
);
if (tool_resources.ocr?.file_ids && isOCREnabled) {
if (tool_resources[EToolResources.ocr]?.file_ids && isOCREnabled) {
const context = await getFiles(
{
file_id: { $in: tool_resources.ocr.file_ids },
@ -77,17 +88,28 @@ const primeResources = async (req, _attachments, _tool_resources) => {
continue;
}
if (file.metadata?.fileIdentifier) {
const execute_code = tool_resources.execute_code ?? {};
const execute_code = tool_resources[EToolResources.execute_code] ?? {};
if (!execute_code.files) {
tool_resources.execute_code = { ...execute_code, files: [] };
tool_resources[EToolResources.execute_code] = { ...execute_code, files: [] };
}
tool_resources.execute_code.files.push(file);
tool_resources[EToolResources.execute_code].files.push(file);
} else if (file.embedded === true) {
const file_search = tool_resources.file_search ?? {};
const file_search = tool_resources[EToolResources.file_search] ?? {};
if (!file_search.files) {
tool_resources.file_search = { ...file_search, files: [] };
tool_resources[EToolResources.file_search] = { ...file_search, files: [] };
}
tool_resources.file_search.files.push(file);
tool_resources[EToolResources.file_search].files.push(file);
} else if (
requestFileSet.has(file.file_id) &&
file.type.startsWith('image') &&
file.height &&
file.width
) {
const image_edit = tool_resources[EToolResources.image_edit] ?? {};
if (!image_edit.files) {
tool_resources[EToolResources.image_edit] = { ...image_edit, files: [] };
}
tool_resources[EToolResources.image_edit].files.push(file);
}
attachments.push(file);
@ -99,11 +121,25 @@ const primeResources = async (req, _attachments, _tool_resources) => {
}
};
/**
* @param {...string | number} values
* @returns {string | number | undefined}
*/
function optionalChainWithEmptyCheck(...values) {
for (const value of values) {
if (value !== undefined && value !== null && value !== '') {
return value;
}
}
return values[values.length - 1];
}
/**
* @param {object} params
* @param {ServerRequest} params.req
* @param {ServerResponse} params.res
* @param {Agent} params.agent
* @param {Set<string>} [params.allowedProviders]
* @param {object} [params.endpointOption]
* @param {boolean} [params.isInitialAgent]
* @returns {Promise<Agent>}
@ -113,8 +149,14 @@ const initializeAgentOptions = async ({
res,
agent,
endpointOption,
allowedProviders,
isInitialAgent = false,
}) => {
if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) {
throw new Error(
`{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`,
);
}
let currentFiles;
/** @type {Array<MongoFile>} */
const requestFiles = req.body.files ?? [];
@ -124,7 +166,14 @@ const initializeAgentOptions = async ({
(agent.model_parameters?.resendFiles ?? true) === true
) {
const fileIds = (await getConvoFiles(req.body.conversationId)) ?? [];
const toolFiles = await getToolFilesByIds(fileIds);
/** @type {Set<EToolResources>} */
const toolResourceSet = new Set();
for (const tool of agent.tools) {
if (EToolResources[tool]) {
toolResourceSet.add(EToolResources[tool]);
}
}
const toolFiles = await getToolFilesByIds(fileIds, toolResourceSet);
if (requestFiles.length || toolFiles.length) {
currentFiles = await processFiles(requestFiles.concat(toolFiles));
}
@ -132,19 +181,26 @@ const initializeAgentOptions = async ({
currentFiles = await processFiles(requestFiles);
}
const { attachments, tool_resources } = await primeResources(
const { attachments, tool_resources } = await primeResources({
req,
currentFiles,
agent.tool_resources,
);
const { tools, toolContextMap } = await loadAgentTools({
req,
res,
agent,
tool_resources,
attachments: currentFiles,
tool_resources: agent.tool_resources,
requestFileSet: new Set(requestFiles.map((file) => file.file_id)),
});
const provider = agent.provider;
const { tools, toolContextMap } = await loadAgentTools({
req,
res,
agent: {
id: agent.id,
tools: agent.tools,
provider,
model: agent.model,
},
tool_resources,
});
agent.endpoint = provider;
let getOptions = providerConfigMap[provider];
if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) {
@ -177,6 +233,13 @@ const initializeAgentOptions = async ({
endpointOption: _endpointOption,
});
if (
agent.endpoint === EModelEndpoint.azureOpenAI &&
options.llmConfig?.azureOpenAIApiInstanceName == null
) {
agent.provider = Providers.OPENAI;
}
if (options.provider != null) {
agent.provider = options.provider;
}
@ -191,6 +254,13 @@ const initializeAgentOptions = async ({
agent.model_parameters.model = agent.model;
}
if (agent.instructions && agent.instructions !== '') {
agent.instructions = replaceSpecialVars({
text: agent.instructions,
user: req.user,
});
}
if (typeof agent.artifacts === 'string' && agent.artifacts !== '') {
agent.additional_instructions = generateArtifactsPrompt({
endpoint: agent.provider,
@ -200,12 +270,17 @@ const initializeAgentOptions = async ({
const tokensModel =
agent.provider === EModelEndpoint.azureOpenAI ? agent.model : agent.model_parameters.model;
const maxTokens = agent.model_parameters.maxOutputTokens ?? agent.model_parameters.maxTokens ?? 0;
const maxContextTokens =
agent.model_parameters.maxContextTokens ??
agent.max_context_tokens ??
getModelMaxTokens(tokensModel, providerEndpointMap[provider]) ??
4096;
const maxTokens = optionalChainWithEmptyCheck(
agent.model_parameters.maxOutputTokens,
agent.model_parameters.maxTokens,
0,
);
const maxContextTokens = optionalChainWithEmptyCheck(
agent.model_parameters.maxContextTokens,
agent.max_context_tokens,
getModelMaxTokens(tokensModel, providerEndpointMap[provider]),
4096,
);
return {
...agent,
tools,
@ -245,6 +320,8 @@ const initializeClient = async ({ req, res, endpointOption }) => {
}
const agentConfigs = new Map();
/** @type {Set<string>} */
const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders);
// Handle primary agent
const primaryConfig = await initializeAgentOptions({
@ -252,6 +329,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
res,
agent: primaryAgent,
endpointOption,
allowedProviders,
isInitialAgent: true,
});
@ -267,6 +345,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
res,
agent,
endpointOption,
allowedProviders,
});
agentConfigs.set(agentId, config);
}
@ -292,10 +371,14 @@ const initializeClient = async ({ req, res, endpointOption }) => {
agent: primaryConfig,
spec: endpointOption.spec,
iconURL: endpointOption.iconURL,
endpoint: EModelEndpoint.agents,
attachments: primaryConfig.attachments,
endpointType: endpointOption.endpointType,
maxContextTokens: primaryConfig.maxContextTokens,
resendFiles: primaryConfig.model_parameters?.resendFiles ?? true,
endpoint:
primaryConfig.id === Constants.EPHEMERAL_AGENT_ID
? primaryConfig.endpoint
: EModelEndpoint.agents,
});
return { client };

View file

@ -2,7 +2,11 @@ const { CacheKeys } = require('librechat-data-provider');
const getLogStores = require('~/cache/getLogStores');
const { isEnabled } = require('~/server/utils');
const { saveConvo } = require('~/models');
const { logger } = require('~/config');
/**
* Add title to conversation in a way that avoids memory retention
*/
const addTitle = async (req, { text, response, client }) => {
const { TITLE_CONVO = true } = process.env ?? {};
if (!isEnabled(TITLE_CONVO)) {
@ -13,37 +17,55 @@ const addTitle = async (req, { text, response, client }) => {
return;
}
// If the request was aborted, don't generate the title.
if (client.abortController.signal.aborted) {
return;
}
const titleCache = getLogStores(CacheKeys.GEN_TITLE);
const key = `${req.user.id}-${response.conversationId}`;
const responseText =
response?.content && Array.isArray(response?.content)
? response.content.reduce((acc, block) => {
if (block?.type === 'text') {
return acc + block.text;
}
return acc;
}, '')
: (response?.content ?? response?.text ?? '');
/** @type {NodeJS.Timeout} */
let timeoutId;
try {
const timeoutPromise = new Promise((_, reject) => {
timeoutId = setTimeout(() => reject(new Error('Title generation timeout')), 25000);
}).catch((error) => {
logger.error('Title error:', error);
});
const title = await client.titleConvo({
text,
responseText,
conversationId: response.conversationId,
});
await titleCache.set(key, title, 120000);
await saveConvo(
req,
{
conversationId: response.conversationId,
title,
},
{ context: 'api/server/services/Endpoints/agents/title.js' },
);
let titlePromise;
let abortController = new AbortController();
if (client && typeof client.titleConvo === 'function') {
titlePromise = Promise.race([
client
.titleConvo({
text,
abortController,
})
.catch((error) => {
logger.error('Client title error:', error);
}),
timeoutPromise,
]);
} else {
return;
}
const title = await titlePromise;
if (!abortController.signal.aborted) {
abortController.abort();
}
if (timeoutId) {
clearTimeout(timeoutId);
}
await titleCache.set(key, title, 120000);
await saveConvo(
req,
{
conversationId: response.conversationId,
title,
},
{ context: 'api/server/services/Endpoints/agents/title.js' },
);
} catch (error) {
logger.error('Error generating title:', error);
}
};
module.exports = addTitle;

View file

@ -1,7 +1,7 @@
const { EModelEndpoint } = require('librechat-data-provider');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
const { AnthropicClient } = require('~/app');
const AnthropicClient = require('~/app/clients/AnthropicClient');
const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => {
const { ANTHROPIC_API_KEY, ANTHROPIC_REVERSE_PROXY, PROXY } = process.env;

View file

@ -13,11 +13,6 @@ const addTitle = async (req, { text, response, client }) => {
return;
}
// If the request was aborted, don't generate the title.
if (client.abortController.signal.aborted) {
return;
}
const titleCache = getLogStores(CacheKeys.GEN_TITLE);
const key = `${req.user.id}-${response.conversationId}`;

View file

@ -3,7 +3,6 @@ const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
const { getAssistant } = require('~/models/Assistant');
const buildOptions = async (endpoint, parsedBody) => {
const { promptPrefix, assistant_id, iconURL, greeting, spec, artifacts, ...modelOptions } =
parsedBody;
const endpointOption = removeNullishValues({

View file

@ -8,7 +8,7 @@ const {
removeNullishValues,
} = require('librechat-data-provider');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const { sleep } = require('~/server/utils');
const { createHandleLLMNewToken } = require('~/app/clients/generators');
const getOptions = async ({ req, overrideModel, endpointOption }) => {
const {
@ -90,12 +90,7 @@ const getOptions = async ({ req, overrideModel, endpointOption }) => {
llmConfig.callbacks = [
{
handleLLMNewToken: async () => {
if (!streamRate) {
return;
}
await sleep(streamRate);
},
handleLLMNewToken: createHandleLLMNewToken(streamRate),
},
];

View file

@ -9,10 +9,11 @@ const { Providers } = require('@librechat/agents');
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm');
const { getCustomEndpointConfig } = require('~/server/services/Config');
const { createHandleLLMNewToken } = require('~/app/clients/generators');
const { fetchModels } = require('~/server/services/ModelService');
const { isUserProvided, sleep } = require('~/server/utils');
const OpenAIClient = require('~/app/clients/OpenAIClient');
const { isUserProvided } = require('~/server/utils');
const getLogStores = require('~/cache/getLogStores');
const { OpenAIClient } = require('~/app');
const { PROXY } = process.env;
@ -148,9 +149,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
}
options.llmConfig.callbacks = [
{
handleLLMNewToken: async () => {
await sleep(customOptions.streamRate);
},
handleLLMNewToken: createHandleLLMNewToken(clientOptions.streamRate),
},
];
return options;

View file

@ -6,9 +6,10 @@ const {
} = require('librechat-data-provider');
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm');
const { isEnabled, isUserProvided, sleep } = require('~/server/utils');
const { createHandleLLMNewToken } = require('~/app/clients/generators');
const { isEnabled, isUserProvided } = require('~/server/utils');
const OpenAIClient = require('~/app/clients/OpenAIClient');
const { getAzureCredentials } = require('~/utils');
const { OpenAIClient } = require('~/app');
const initializeClient = async ({
req,
@ -140,14 +141,13 @@ const initializeClient = async ({
clientOptions = Object.assign({ modelOptions }, clientOptions);
clientOptions.modelOptions.user = req.user.id;
const options = getLLMConfig(apiKey, clientOptions);
if (!clientOptions.streamRate) {
const streamRate = clientOptions.streamRate;
if (!streamRate) {
return options;
}
options.llmConfig.callbacks = [
{
handleLLMNewToken: async () => {
await sleep(clientOptions.streamRate);
},
handleLLMNewToken: createHandleLLMNewToken(streamRate),
},
];
return options;

View file

@ -136,7 +136,7 @@ function getLLMConfig(apiKey, options = {}, endpoint = null) {
Object.assign(llmConfig, azure);
llmConfig.model = llmConfig.azureOpenAIApiDeploymentName;
} else {
llmConfig.openAIApiKey = apiKey;
llmConfig.apiKey = apiKey;
// Object.assign(llmConfig, {
// configuration: { apiKey },
// });
@ -153,6 +153,12 @@ function getLLMConfig(apiKey, options = {}, endpoint = null) {
delete llmConfig.reasoning_effort;
}
if (llmConfig?.['max_tokens'] != null) {
/** @type {number} */
llmConfig.maxTokens = llmConfig['max_tokens'];
delete llmConfig['max_tokens'];
}
return {
/** @type {OpenAIClientOptions} */
llmConfig,

View file

@ -13,11 +13,6 @@ const addTitle = async (req, { text, response, client }) => {
return;
}
// If the request was aborted and is not azure, don't generate the title.
if (!client.azure && client.abortController.signal.aborted) {
return;
}
const titleCache = getLogStores(CacheKeys.GEN_TITLE);
const key = `${req.user.id}-${response.conversationId}`;

View file

@ -7,6 +7,78 @@ const { getCustomConfig } = require('~/server/services/Config');
const { genAzureEndpoint } = require('~/utils');
const { logger } = require('~/config');
/**
* Maps MIME types to their corresponding file extensions for audio files.
* @type {Object}
*/
const MIME_TO_EXTENSION_MAP = {
// MP4 container formats
'audio/mp4': 'm4a',
'audio/x-m4a': 'm4a',
// Ogg formats
'audio/ogg': 'ogg',
'audio/vorbis': 'ogg',
'application/ogg': 'ogg',
// Wave formats
'audio/wav': 'wav',
'audio/x-wav': 'wav',
'audio/wave': 'wav',
// MP3 formats
'audio/mp3': 'mp3',
'audio/mpeg': 'mp3',
'audio/mpeg3': 'mp3',
// WebM formats
'audio/webm': 'webm',
// Additional formats
'audio/flac': 'flac',
'audio/x-flac': 'flac',
};
/**
* Gets the file extension from the MIME type.
* @param {string} mimeType - The MIME type.
* @returns {string} The file extension.
*/
function getFileExtensionFromMime(mimeType) {
// Default fallback
if (!mimeType) {
return 'webm';
}
// Direct lookup (fastest)
const extension = MIME_TO_EXTENSION_MAP[mimeType];
if (extension) {
return extension;
}
// Try to extract subtype as fallback
const subtype = mimeType.split('/')[1]?.toLowerCase();
// If subtype matches a known extension
if (['mp3', 'mp4', 'ogg', 'wav', 'webm', 'm4a', 'flac'].includes(subtype)) {
return subtype === 'mp4' ? 'm4a' : subtype;
}
// Generic checks for partial matches
if (subtype?.includes('mp4') || subtype?.includes('m4a')) {
return 'm4a';
}
if (subtype?.includes('ogg')) {
return 'ogg';
}
if (subtype?.includes('wav')) {
return 'wav';
}
if (subtype?.includes('mp3') || subtype?.includes('mpeg')) {
return 'mp3';
}
if (subtype?.includes('webm')) {
return 'webm';
}
return 'webm'; // Default fallback
}
/**
* Service class for handling Speech-to-Text (STT) operations.
* @class
@ -170,8 +242,10 @@ class STTService {
throw new Error('Invalid provider');
}
const fileExtension = getFileExtensionFromMime(audioFile.mimetype);
const audioReadStream = Readable.from(audioBuffer);
audioReadStream.path = 'audio.wav';
audioReadStream.path = `audio.${fileExtension}`;
const [url, data, headers] = strategy.call(this, sttSchema, audioReadStream, audioFile);

View file

@ -1,4 +1,10 @@
const { CacheKeys, findLastSeparatorIndex, SEPARATORS, Time } = require('librechat-data-provider');
const {
Time,
CacheKeys,
SEPARATORS,
parseTextParts,
findLastSeparatorIndex,
} = require('librechat-data-provider');
const { getMessage } = require('~/models/Message');
const { getLogStores } = require('~/cache');
@ -84,10 +90,11 @@ function createChunkProcessor(user, messageId) {
notFoundCount++;
return [];
} else {
const text = message.content?.length > 0 ? parseTextParts(message.content) : message.text;
messageCache.set(
messageId,
{
text: message.text,
text,
complete: true,
},
Time.FIVE_MINUTES,
@ -95,7 +102,7 @@ function createChunkProcessor(user, messageId) {
}
const text = typeof message === 'string' ? message : message.text;
const complete = typeof message === 'string' ? false : message.complete ?? true;
const complete = typeof message === 'string' ? false : (message.complete ?? true);
if (text === processedText) {
noChangeCount++;

View file

@ -1,11 +1,13 @@
const fs = require('fs');
const path = require('path');
const mime = require('mime');
const axios = require('axios');
const fetch = require('node-fetch');
const { logger } = require('~/config');
const { getAzureContainerClient } = require('./initialize');
const defaultBasePath = 'images';
const { AZURE_STORAGE_PUBLIC_ACCESS = 'true', AZURE_CONTAINER_NAME = 'files' } = process.env;
/**
* Uploads a buffer to Azure Blob Storage.
@ -29,10 +31,9 @@ async function saveBufferToAzure({
}) {
try {
const containerClient = getAzureContainerClient(containerName);
const access = AZURE_STORAGE_PUBLIC_ACCESS?.toLowerCase() === 'true' ? 'blob' : undefined;
// Create the container if it doesn't exist. This is done per operation.
await containerClient.createIfNotExists({
access: process.env.AZURE_STORAGE_PUBLIC_ACCESS ? 'blob' : undefined,
});
await containerClient.createIfNotExists({ access });
const blobPath = `${basePath}/${userId}/${fileName}`;
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
await blockBlobClient.uploadData(buffer);
@ -97,25 +98,21 @@ async function getAzureURL({ fileName, basePath = defaultBasePath, userId, conta
* Deletes a blob from Azure Blob Storage.
*
* @param {Object} params
* @param {string} params.fileName - The name of the file.
* @param {string} [params.basePath='images'] - The base folder where the file is stored.
* @param {string} params.userId - The user's id.
* @param {string} [params.containerName] - The Azure Blob container name.
* @param {ServerRequest} params.req - The Express request object.
* @param {MongoFile} params.file - The file object.
*/
async function deleteFileFromAzure({
fileName,
basePath = defaultBasePath,
userId,
containerName,
}) {
async function deleteFileFromAzure(req, file) {
try {
const containerClient = getAzureContainerClient(containerName);
const blobPath = `${basePath}/${userId}/${fileName}`;
const containerClient = getAzureContainerClient(AZURE_CONTAINER_NAME);
const blobPath = file.filepath.split(`${AZURE_CONTAINER_NAME}/`)[1];
if (!blobPath.includes(req.user.id)) {
throw new Error('User ID not found in blob path');
}
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
await blockBlobClient.delete();
logger.debug('[deleteFileFromAzure] Blob deleted successfully from Azure Blob Storage');
} catch (error) {
logger.error('[deleteFileFromAzure] Error deleting blob:', error.message);
logger.error('[deleteFileFromAzure] Error deleting blob:', error);
if (error.statusCode === 404) {
return;
}
@ -123,6 +120,65 @@ async function deleteFileFromAzure({
}
}
/**
* Streams a file from disk directly to Azure Blob Storage without loading
* the entire file into memory.
*
* @param {Object} params
* @param {string} params.userId - The user's id.
* @param {string} params.filePath - The local file path to upload.
* @param {string} params.fileName - The name of the file in Azure.
* @param {string} [params.basePath='images'] - The base folder within the container.
* @param {string} [params.containerName] - The Azure Blob container name.
* @returns {Promise<string>} The URL of the uploaded blob.
*/
async function streamFileToAzure({
userId,
filePath,
fileName,
basePath = defaultBasePath,
containerName,
}) {
try {
const containerClient = getAzureContainerClient(containerName);
const access = AZURE_STORAGE_PUBLIC_ACCESS?.toLowerCase() === 'true' ? 'blob' : undefined;
// Create the container if it doesn't exist
await containerClient.createIfNotExists({ access });
const blobPath = `${basePath}/${userId}/${fileName}`;
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
// Get file size for proper content length
const stats = await fs.promises.stat(filePath);
// Create read stream from the file
const fileStream = fs.createReadStream(filePath);
const blobContentType = mime.getType(fileName);
await blockBlobClient.uploadStream(
fileStream,
undefined, // Use default concurrency (5)
undefined, // Use default buffer size (8MB)
{
blobHTTPHeaders: {
blobContentType,
},
onProgress: (progress) => {
logger.debug(
`[streamFileToAzure] Upload progress: ${progress.loadedBytes} bytes of ${stats.size}`,
);
},
},
);
return blockBlobClient.url;
} catch (error) {
logger.error('[streamFileToAzure] Error streaming file:', error);
throw error;
}
}
/**
* Uploads a file from the local file system to Azure Blob Storage.
*
@ -146,18 +202,19 @@ async function uploadFileToAzure({
}) {
try {
const inputFilePath = file.path;
const inputBuffer = await fs.promises.readFile(inputFilePath);
const bytes = Buffer.byteLength(inputBuffer);
const stats = await fs.promises.stat(inputFilePath);
const bytes = stats.size;
const userId = req.user.id;
const fileName = `${file_id}__${path.basename(inputFilePath)}`;
const fileURL = await saveBufferToAzure({
const fileURL = await streamFileToAzure({
userId,
buffer: inputBuffer,
filePath: inputFilePath,
fileName,
basePath,
containerName,
});
await fs.promises.unlink(inputFilePath);
return { filepath: fileURL, bytes };
} catch (error) {
logger.error('[uploadFileToAzure] Error uploading file:', error);

View file

@ -32,11 +32,12 @@ async function getCodeOutputDownloadStream(fileIdentifier, apiKey) {
const response = await axios(options);
return response;
} catch (error) {
logAxiosError({
message: `Error downloading code environment file stream: ${error.message}`,
error,
});
throw new Error(`Error downloading file: ${error.message}`);
throw new Error(
logAxiosError({
message: `Error downloading code environment file stream: ${error.message}`,
error,
}),
);
}
}
@ -89,11 +90,12 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = ''
return `${fileIdentifier}?entity_id=${entity_id}`;
} catch (error) {
logAxiosError({
message: `Error uploading code environment file: ${error.message}`,
error,
});
throw new Error(`Error uploading code environment file: ${error.message}`);
throw new Error(
logAxiosError({
message: `Error uploading code environment file: ${error.message}`,
error,
}),
);
}
}

View file

@ -309,6 +309,24 @@ function getLocalFileStream(req, filepath) {
throw new Error(`Invalid file path: ${filepath}`);
}
return fs.createReadStream(fullPath);
} else if (filepath.includes('/images/')) {
const basePath = filepath.split('/images/')[1];
if (!basePath) {
logger.warn(`Invalid base path: ${filepath}`);
throw new Error(`Invalid file path: ${filepath}`);
}
const fullPath = path.join(req.app.locals.paths.imageOutput, basePath);
const publicDir = req.app.locals.paths.imageOutput;
const rel = path.relative(publicDir, fullPath);
if (rel.startsWith('..') || path.isAbsolute(rel) || rel.includes(`..${path.sep}`)) {
logger.warn(`Invalid relative file path: ${filepath}`);
throw new Error(`Invalid file path: ${filepath}`);
}
return fs.createReadStream(fullPath);
}
return fs.createReadStream(filepath);

View file

@ -5,7 +5,7 @@ const FormData = require('form-data');
const { FileSources, envVarRegex, extractEnvVariable } = require('librechat-data-provider');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { logger, createAxiosInstance } = require('~/config');
const { logAxiosError } = require('~/utils');
const { logAxiosError } = require('~/utils/axios');
const axios = createAxiosInstance();
@ -69,16 +69,20 @@ async function getSignedUrl({
/**
* @param {Object} params
* @param {string} params.apiKey
* @param {string} params.documentUrl
* @param {string} params.url - The document or image URL
* @param {string} [params.documentType='document_url'] - 'document_url' or 'image_url'
* @param {string} [params.model]
* @param {string} [params.baseURL]
* @returns {Promise<OCRResult>}
*/
async function performOCR({
apiKey,
documentUrl,
url,
documentType = 'document_url',
model = 'mistral-ocr-latest',
baseURL = 'https://api.mistral.ai/v1',
}) {
const documentKey = documentType === 'image_url' ? 'image_url' : 'document_url';
return axios
.post(
`${baseURL}/ocr`,
@ -86,8 +90,8 @@ async function performOCR({
model,
include_image_base64: false,
document: {
type: 'document_url',
document_url: documentUrl,
type: documentType,
[documentKey]: url,
},
},
{
@ -109,6 +113,19 @@ function extractVariableName(str) {
return match ? match[1] : null;
}
/**
* Uploads a file to the Mistral OCR API and processes the OCR result.
*
* @param {Object} params - The params object.
* @param {ServerRequest} params.req - The request object from Express. It should have a `user` property with an `id`
* representing the user
* @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should
* have a `mimetype` property that tells us the file type
* @param {string} params.file_id - The file ID.
* @param {string} [params.entity_id] - The entity ID, not used here but passed for consistency.
* @returns {Promise<{ filepath: string, bytes: number }>} - The result object containing the processed `text` and `images` (not currently used),
* along with the `filename` and `bytes` properties.
*/
const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => {
try {
/** @type {TCustomConfig['ocr']} */
@ -160,11 +177,18 @@ const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => {
fileId: mistralFile.id,
});
const mimetype = (file.mimetype || '').toLowerCase();
const originalname = file.originalname || '';
const isImage =
mimetype.startsWith('image') || /\.(png|jpe?g|gif|bmp|webp|tiff?)$/i.test(originalname);
const documentType = isImage ? 'image_url' : 'document_url';
const ocrResult = await performOCR({
apiKey,
baseURL,
model,
documentUrl: signedUrlResponse.url,
url: signedUrlResponse.url,
documentType,
});
let aggregatedText = '';
@ -194,8 +218,7 @@ const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => {
};
} catch (error) {
const message = 'Error uploading document to Mistral OCR API';
logAxiosError({ error, message });
throw new Error(message);
throw new Error(logAxiosError({ error, message }));
}
};

View file

@ -29,9 +29,6 @@ const mockAxios = {
jest.mock('axios', () => mockAxios);
jest.mock('fs');
jest.mock('~/utils', () => ({
logAxiosError: jest.fn(),
}));
jest.mock('~/config', () => ({
logger: {
error: jest.fn(),
@ -175,7 +172,7 @@ describe('MistralOCR Service', () => {
});
describe('performOCR', () => {
it('should perform OCR using Mistral API', async () => {
it('should perform OCR using Mistral API (document_url)', async () => {
const mockResponse = {
data: {
pages: [{ markdown: 'Page 1 content' }, { markdown: 'Page 2 content' }],
@ -185,8 +182,9 @@ describe('MistralOCR Service', () => {
const result = await performOCR({
apiKey: 'test-api-key',
documentUrl: 'https://document-url.com',
url: 'https://document-url.com',
model: 'mistral-ocr-latest',
documentType: 'document_url',
});
expect(mockAxios.post).toHaveBeenCalledWith(
@ -209,6 +207,41 @@ describe('MistralOCR Service', () => {
expect(result).toEqual(mockResponse.data);
});
it('should perform OCR using Mistral API (image_url)', async () => {
const mockResponse = {
data: {
pages: [{ markdown: 'Image OCR content' }],
},
};
mockAxios.post.mockResolvedValueOnce(mockResponse);
const result = await performOCR({
apiKey: 'test-api-key',
url: 'https://image-url.com/image.png',
model: 'mistral-ocr-latest',
documentType: 'image_url',
});
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/ocr',
{
model: 'mistral-ocr-latest',
include_image_base64: false,
document: {
type: 'image_url',
image_url: 'https://image-url.com/image.png',
},
},
{
headers: {
'Content-Type': 'application/json',
Authorization: 'Bearer test-api-key',
},
},
);
expect(result).toEqual(mockResponse.data);
});
it('should handle errors during OCR processing', async () => {
const errorMessage = 'OCR processing error';
mockAxios.post.mockRejectedValueOnce(new Error(errorMessage));
@ -216,7 +249,7 @@ describe('MistralOCR Service', () => {
await expect(
performOCR({
apiKey: 'test-api-key',
documentUrl: 'https://document-url.com',
url: 'https://document-url.com',
}),
).rejects.toThrow();
@ -298,6 +331,7 @@ describe('MistralOCR Service', () => {
const file = {
path: '/tmp/upload/file.pdf',
originalname: 'document.pdf',
mimetype: 'application/pdf',
};
const result = await uploadMistralOCR({
@ -325,6 +359,90 @@ describe('MistralOCR Service', () => {
});
});
it('should process OCR for an image file and use image_url type', async () => {
const { loadAuthValues } = require('~/server/services/Tools/credentials');
loadAuthValues.mockResolvedValue({
OCR_API_KEY: 'test-api-key',
OCR_BASEURL: 'https://api.mistral.ai/v1',
});
// Mock file upload response
mockAxios.post.mockResolvedValueOnce({
data: { id: 'file-456', purpose: 'ocr' },
});
// Mock signed URL response
mockAxios.get.mockResolvedValueOnce({
data: { url: 'https://signed-url.com/image.png' },
});
// Mock OCR response for image
mockAxios.post.mockResolvedValueOnce({
data: {
pages: [
{
markdown: 'Image OCR result',
images: [{ image_base64: 'imgbase64' }],
},
],
},
});
const req = {
user: { id: 'user456' },
app: {
locals: {
ocr: {
apiKey: '${OCR_API_KEY}',
baseURL: '${OCR_BASEURL}',
mistralModel: 'mistral-medium',
},
},
},
};
const file = {
path: '/tmp/upload/image.png',
originalname: 'image.png',
mimetype: 'image/png',
};
const result = await uploadMistralOCR({
req,
file,
file_id: 'file456',
entity_id: 'entity456',
});
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/image.png');
expect(loadAuthValues).toHaveBeenCalledWith({
userId: 'user456',
authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
optional: expect.any(Set),
});
// Check that the OCR API was called with image_url type
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/ocr',
expect.objectContaining({
document: expect.objectContaining({
type: 'image_url',
image_url: 'https://signed-url.com/image.png',
}),
}),
expect.any(Object),
);
expect(result).toEqual({
filename: 'image.png',
bytes: expect.any(Number),
filepath: 'mistral_ocr',
text: expect.stringContaining('Image OCR result'),
images: ['imgbase64'],
});
});
it('should process variable references in configuration', async () => {
// Setup mocks with environment variables
const { loadAuthValues } = require('~/server/services/Tools/credentials');
@ -494,9 +612,6 @@ describe('MistralOCR Service', () => {
}),
).rejects.toThrow('Error uploading document to Mistral OCR API');
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
const { logAxiosError } = require('~/utils');
expect(logAxiosError).toHaveBeenCalled();
});
it('should handle single page documents without page numbering', async () => {

View file

@ -1,7 +1,13 @@
const fs = require('fs');
const path = require('path');
const fetch = require('node-fetch');
const { PutObjectCommand, GetObjectCommand, DeleteObjectCommand } = require('@aws-sdk/client-s3');
const { FileSources } = require('librechat-data-provider');
const {
PutObjectCommand,
GetObjectCommand,
HeadObjectCommand,
DeleteObjectCommand,
} = require('@aws-sdk/client-s3');
const { getSignedUrl } = require('@aws-sdk/s3-request-presigner');
const { initializeS3 } = require('./initialize');
const { logger } = require('~/config');
@ -9,6 +15,34 @@ const { logger } = require('~/config');
const bucketName = process.env.AWS_BUCKET_NAME;
const defaultBasePath = 'images';
let s3UrlExpirySeconds = 7 * 24 * 60 * 60;
let s3RefreshExpiryMs = null;
if (process.env.S3_URL_EXPIRY_SECONDS !== undefined) {
const parsed = parseInt(process.env.S3_URL_EXPIRY_SECONDS, 10);
if (!isNaN(parsed) && parsed > 0) {
s3UrlExpirySeconds = Math.min(parsed, 7 * 24 * 60 * 60);
} else {
logger.warn(
`[S3] Invalid S3_URL_EXPIRY_SECONDS value: "${process.env.S3_URL_EXPIRY_SECONDS}". Using 7-day expiry.`,
);
}
}
if (process.env.S3_REFRESH_EXPIRY_MS !== null && process.env.S3_REFRESH_EXPIRY_MS) {
const parsed = parseInt(process.env.S3_REFRESH_EXPIRY_MS, 10);
if (!isNaN(parsed) && parsed > 0) {
s3RefreshExpiryMs = parsed;
logger.info(`[S3] Using custom refresh expiry time: ${s3RefreshExpiryMs}ms`);
} else {
logger.warn(
`[S3] Invalid S3_REFRESH_EXPIRY_MS value: "${process.env.S3_REFRESH_EXPIRY_MS}". Using default refresh logic.`,
);
}
}
/**
* Constructs the S3 key based on the base path, user ID, and file name.
*/
@ -39,13 +73,14 @@ async function saveBufferToS3({ userId, buffer, fileName, basePath = defaultBase
}
/**
* Retrieves a signed URL for a file stored in S3.
* Retrieves a URL for a file stored in S3.
* Returns a signed URL with expiration time or a proxy URL based on config
*
* @param {Object} params
* @param {string} params.userId - The user's unique identifier.
* @param {string} params.fileName - The file name in S3.
* @param {string} [params.basePath='images'] - The base path in the bucket.
* @returns {Promise<string>} A signed URL valid for 24 hours.
* @returns {Promise<string>} A URL to access the S3 object
*/
async function getS3URL({ userId, fileName, basePath = defaultBasePath }) {
const key = getS3Key(basePath, userId, fileName);
@ -53,7 +88,7 @@ async function getS3URL({ userId, fileName, basePath = defaultBasePath }) {
try {
const s3 = initializeS3();
return await getSignedUrl(s3, new GetObjectCommand(params), { expiresIn: 86400 });
return await getSignedUrl(s3, new GetObjectCommand(params), { expiresIn: s3UrlExpirySeconds });
} catch (error) {
logger.error('[getS3URL] Error getting signed URL from S3:', error.message);
throw error;
@ -86,21 +121,51 @@ async function saveURLToS3({ userId, URL, fileName, basePath = defaultBasePath }
* Deletes a file from S3.
*
* @param {Object} params
* @param {string} params.userId - The user's unique identifier.
* @param {string} params.fileName - The file name in S3.
* @param {string} [params.basePath='images'] - The base path in the bucket.
* @param {ServerRequest} params.req
* @param {MongoFile} params.file - The file object to delete.
* @returns {Promise<void>}
*/
async function deleteFileFromS3({ userId, fileName, basePath = defaultBasePath }) {
const key = getS3Key(basePath, userId, fileName);
async function deleteFileFromS3(req, file) {
const key = extractKeyFromS3Url(file.filepath);
const params = { Bucket: bucketName, Key: key };
if (!key.includes(req.user.id)) {
const message = `[deleteFileFromS3] User ID mismatch: ${req.user.id} vs ${key}`;
logger.error(message);
throw new Error(message);
}
try {
const s3 = initializeS3();
await s3.send(new DeleteObjectCommand(params));
logger.debug('[deleteFileFromS3] File deleted successfully from S3');
try {
const headCommand = new HeadObjectCommand(params);
await s3.send(headCommand);
logger.debug('[deleteFileFromS3] File exists, proceeding with deletion');
} catch (headErr) {
if (headErr.name === 'NotFound') {
logger.warn(`[deleteFileFromS3] File does not exist: ${key}`);
return;
}
}
const deleteResult = await s3.send(new DeleteObjectCommand(params));
logger.debug('[deleteFileFromS3] Delete command response:', JSON.stringify(deleteResult));
try {
await s3.send(new HeadObjectCommand(params));
logger.error('[deleteFileFromS3] File still exists after deletion!');
} catch (verifyErr) {
if (verifyErr.name === 'NotFound') {
logger.debug(`[deleteFileFromS3] Verified file is deleted: ${key}`);
} else {
logger.error('[deleteFileFromS3] Error verifying deletion:', verifyErr);
}
}
logger.debug('[deleteFileFromS3] S3 File deletion completed');
} catch (error) {
logger.error('[deleteFileFromS3] Error deleting file from S3:', error.message);
logger.error(`[deleteFileFromS3] Error deleting file from S3: ${error.message}`);
logger.error(error.stack);
// If the file is not found, we can safely return.
if (error.code === 'NoSuchKey') {
return;
@ -110,7 +175,7 @@ async function deleteFileFromS3({ userId, fileName, basePath = defaultBasePath }
}
/**
* Uploads a local file to S3.
* Uploads a local file to S3 by streaming it directly without loading into memory.
*
* @param {Object} params
* @param {import('express').Request} params.req - The Express request (must include user).
@ -122,37 +187,272 @@ async function deleteFileFromS3({ userId, fileName, basePath = defaultBasePath }
async function uploadFileToS3({ req, file, file_id, basePath = defaultBasePath }) {
try {
const inputFilePath = file.path;
const inputBuffer = await fs.promises.readFile(inputFilePath);
const bytes = Buffer.byteLength(inputBuffer);
const userId = req.user.id;
const fileName = `${file_id}__${path.basename(inputFilePath)}`;
const fileURL = await saveBufferToS3({ userId, buffer: inputBuffer, fileName, basePath });
await fs.promises.unlink(inputFilePath);
const key = getS3Key(basePath, userId, fileName);
const stats = await fs.promises.stat(inputFilePath);
const bytes = stats.size;
const fileStream = fs.createReadStream(inputFilePath);
const s3 = initializeS3();
const uploadParams = {
Bucket: bucketName,
Key: key,
Body: fileStream,
};
await s3.send(new PutObjectCommand(uploadParams));
const fileURL = await getS3URL({ userId, fileName, basePath });
return { filepath: fileURL, bytes };
} catch (error) {
logger.error('[uploadFileToS3] Error uploading file to S3:', error.message);
logger.error('[uploadFileToS3] Error streaming file to S3:', error);
try {
if (file && file.path) {
await fs.promises.unlink(file.path);
}
} catch (unlinkError) {
logger.error(
'[uploadFileToS3] Error deleting temporary file, likely already deleted:',
unlinkError.message,
);
}
throw error;
}
}
/**
* Extracts the S3 key from a URL or returns the key if already properly formatted
*
* @param {string} fileUrlOrKey - The file URL or key
* @returns {string} The S3 key
*/
function extractKeyFromS3Url(fileUrlOrKey) {
if (!fileUrlOrKey) {
throw new Error('Invalid input: URL or key is empty');
}
try {
const url = new URL(fileUrlOrKey);
return url.pathname.substring(1);
} catch (error) {
const parts = fileUrlOrKey.split('/');
if (parts.length >= 3 && !fileUrlOrKey.startsWith('http') && !fileUrlOrKey.startsWith('/')) {
return fileUrlOrKey;
}
return fileUrlOrKey.startsWith('/') ? fileUrlOrKey.substring(1) : fileUrlOrKey;
}
}
/**
* Retrieves a readable stream for a file stored in S3.
*
* @param {ServerRequest} req - Server request object.
* @param {string} filePath - The S3 key of the file.
* @returns {Promise<NodeJS.ReadableStream>}
*/
async function getS3FileStream(filePath) {
const params = { Bucket: bucketName, Key: filePath };
async function getS3FileStream(_req, filePath) {
try {
const Key = extractKeyFromS3Url(filePath);
const params = { Bucket: bucketName, Key };
const s3 = initializeS3();
const data = await s3.send(new GetObjectCommand(params));
return data.Body; // Returns a Node.js ReadableStream.
} catch (error) {
logger.error('[getS3FileStream] Error retrieving S3 file stream:', error.message);
logger.error('[getS3FileStream] Error retrieving S3 file stream:', error);
throw error;
}
}
/**
* Determines if a signed S3 URL is close to expiration
*
* @param {string} signedUrl - The signed S3 URL
* @param {number} bufferSeconds - Buffer time in seconds
* @returns {boolean} True if the URL needs refreshing
*/
function needsRefresh(signedUrl, bufferSeconds) {
try {
// Parse the URL
const url = new URL(signedUrl);
// Check if it has the signature parameters that indicate it's a signed URL
// X-Amz-Signature is the most reliable indicator for AWS signed URLs
if (!url.searchParams.has('X-Amz-Signature')) {
// Not a signed URL, so no expiration to check (or it's already a proxy URL)
return false;
}
// Extract the expiration time from the URL
const expiresParam = url.searchParams.get('X-Amz-Expires');
const dateParam = url.searchParams.get('X-Amz-Date');
if (!expiresParam || !dateParam) {
// Missing expiration information, assume it needs refresh to be safe
return true;
}
// Parse the AWS date format (YYYYMMDDTHHMMSSZ)
const year = dateParam.substring(0, 4);
const month = dateParam.substring(4, 6);
const day = dateParam.substring(6, 8);
const hour = dateParam.substring(9, 11);
const minute = dateParam.substring(11, 13);
const second = dateParam.substring(13, 15);
const dateObj = new Date(`${year}-${month}-${day}T${hour}:${minute}:${second}Z`);
const expiresAtDate = new Date(dateObj.getTime() + parseInt(expiresParam) * 1000);
// Check if it's close to expiration
const now = new Date();
// If S3_REFRESH_EXPIRY_MS is set, use it to determine if URL is expired
if (s3RefreshExpiryMs !== null) {
const urlCreationTime = dateObj.getTime();
const urlAge = now.getTime() - urlCreationTime;
return urlAge >= s3RefreshExpiryMs;
}
// Otherwise use the default buffer-based logic
const bufferTime = new Date(now.getTime() + bufferSeconds * 1000);
return expiresAtDate <= bufferTime;
} catch (error) {
logger.error('Error checking URL expiration:', error);
// If we can't determine, assume it needs refresh to be safe
return true;
}
}
/**
* Generates a new URL for an expired S3 URL
* @param {string} currentURL - The current file URL
* @returns {Promise<string | undefined>}
*/
async function getNewS3URL(currentURL) {
try {
const s3Key = extractKeyFromS3Url(currentURL);
if (!s3Key) {
return;
}
const keyParts = s3Key.split('/');
if (keyParts.length < 3) {
return;
}
const basePath = keyParts[0];
const userId = keyParts[1];
const fileName = keyParts.slice(2).join('/');
return await getS3URL({
userId,
fileName,
basePath,
});
} catch (error) {
logger.error('Error getting new S3 URL:', error);
}
}
/**
* Refreshes S3 URLs for an array of files if they're expired or close to expiring
*
* @param {MongoFile[]} files - Array of file documents
* @param {(files: MongoFile[]) => Promise<void>} batchUpdateFiles - Function to update files in the database
* @param {number} [bufferSeconds=3600] - Buffer time in seconds to check for expiration
* @returns {Promise<MongoFile[]>} The files with refreshed URLs if needed
*/
async function refreshS3FileUrls(files, batchUpdateFiles, bufferSeconds = 3600) {
if (!files || !Array.isArray(files) || files.length === 0) {
return files;
}
const filesToUpdate = [];
for (let i = 0; i < files.length; i++) {
const file = files[i];
if (!file?.file_id) {
continue;
}
if (file.source !== FileSources.s3) {
continue;
}
if (!file.filepath) {
continue;
}
if (!needsRefresh(file.filepath, bufferSeconds)) {
continue;
}
try {
const newURL = await getNewS3URL(file.filepath);
if (!newURL) {
continue;
}
filesToUpdate.push({
file_id: file.file_id,
filepath: newURL,
});
files[i].filepath = newURL;
} catch (error) {
logger.error(`Error refreshing S3 URL for file ${file.file_id}:`, error);
}
}
if (filesToUpdate.length > 0) {
await batchUpdateFiles(filesToUpdate);
}
return files;
}
/**
* Refreshes a single S3 URL if it's expired or close to expiring
*
* @param {{ filepath: string, source: string }} fileObj - Simple file object containing filepath and source
* @param {number} [bufferSeconds=3600] - Buffer time in seconds to check for expiration
* @returns {Promise<string>} The refreshed URL or the original URL if no refresh needed
*/
async function refreshS3Url(fileObj, bufferSeconds = 3600) {
if (!fileObj || fileObj.source !== FileSources.s3 || !fileObj.filepath) {
return fileObj?.filepath || '';
}
if (!needsRefresh(fileObj.filepath, bufferSeconds)) {
return fileObj.filepath;
}
try {
const s3Key = extractKeyFromS3Url(fileObj.filepath);
if (!s3Key) {
logger.warn(`Unable to extract S3 key from URL: ${fileObj.filepath}`);
return fileObj.filepath;
}
const keyParts = s3Key.split('/');
if (keyParts.length < 3) {
logger.warn(`Invalid S3 key format: ${s3Key}`);
return fileObj.filepath;
}
const basePath = keyParts[0];
const userId = keyParts[1];
const fileName = keyParts.slice(2).join('/');
const newUrl = await getS3URL({
userId,
fileName,
basePath,
});
logger.debug(`Refreshed S3 URL for key: ${s3Key}`);
return newUrl;
} catch (error) {
logger.error(`Error refreshing S3 URL: ${error.message}`);
return fileObj.filepath;
}
}
module.exports = {
saveBufferToS3,
saveURLToS3,
@ -160,4 +460,8 @@ module.exports = {
deleteFileFromS3,
uploadFileToS3,
getS3FileStream,
refreshS3FileUrls,
refreshS3Url,
needsRefresh,
getNewS3URL,
};

View file

@ -7,8 +7,47 @@ const {
EModelEndpoint,
} = require('librechat-data-provider');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { logAxiosError } = require('~/utils');
const { logger } = require('~/config');
/**
* Converts a readable stream to a base64 encoded string.
*
* @param {NodeJS.ReadableStream} stream - The readable stream to convert.
* @param {boolean} [destroyStream=true] - Whether to destroy the stream after processing.
* @returns {Promise<string>} - Promise resolving to the base64 encoded content.
*/
async function streamToBase64(stream, destroyStream = true) {
return new Promise((resolve, reject) => {
const chunks = [];
stream.on('data', (chunk) => {
chunks.push(chunk);
});
stream.on('end', () => {
try {
const buffer = Buffer.concat(chunks);
const base64Data = buffer.toString('base64');
chunks.length = 0; // Clear the array
resolve(base64Data);
} catch (err) {
reject(err);
}
});
stream.on('error', (error) => {
chunks.length = 0;
reject(error);
});
}).finally(() => {
// Clean up the stream if required
if (destroyStream && stream.destroy && typeof stream.destroy === 'function') {
stream.destroy();
}
});
}
/**
* Fetches an image from a URL and returns its base64 representation.
*
@ -22,10 +61,12 @@ async function fetchImageToBase64(url) {
const response = await axios.get(url, {
responseType: 'arraybuffer',
});
return Buffer.from(response.data).toString('base64');
const base64Data = Buffer.from(response.data).toString('base64');
response.data = null;
return base64Data;
} catch (error) {
logger.error('Error fetching image to convert to base64', error);
throw error;
const message = 'Error fetching image to convert to base64';
throw new Error(logAxiosError({ message, error }));
}
}
@ -37,17 +78,21 @@ const base64Only = new Set([
EModelEndpoint.bedrock,
]);
const blobStorageSources = new Set([FileSources.azure_blob, FileSources.s3]);
/**
* Encodes and formats the given files.
* @param {Express.Request} req - The request object.
* @param {Array<MongoFile>} files - The array of files to encode and format.
* @param {EModelEndpoint} [endpoint] - Optional: The endpoint for the image.
* @param {string} [mode] - Optional: The endpoint mode for the image.
* @returns {Promise<Object>} - A promise that resolves to the result object containing the encoded images and file details.
* @returns {Promise<{ text: string; files: MongoFile[]; image_urls: MessageContentImageUrl[] }>} - A promise that resolves to the result object containing the encoded images and file details.
*/
async function encodeAndFormat(req, files, endpoint, mode) {
const promises = [];
/** @type {Record<FileSources, Pick<ReturnType<typeof getStrategyFunctions>, 'prepareImagePayload' | 'getDownloadStream'>>} */
const encodingMethods = {};
/** @type {{ text: string; files: MongoFile[]; image_urls: MessageContentImageUrl[] }} */
const result = {
text: '',
files: [],
@ -59,6 +104,7 @@ async function encodeAndFormat(req, files, endpoint, mode) {
}
for (let file of files) {
/** @type {FileSources} */
const source = file.source ?? FileSources.local;
if (source === FileSources.text && file.text) {
result.text += `${!result.text ? 'Attached document(s):\n```md' : '\n\n---\n\n'}# "${file.filename}"\n${file.text}\n`;
@ -70,18 +116,29 @@ async function encodeAndFormat(req, files, endpoint, mode) {
}
if (!encodingMethods[source]) {
const { prepareImagePayload } = getStrategyFunctions(source);
const { prepareImagePayload, getDownloadStream } = getStrategyFunctions(source);
if (!prepareImagePayload) {
throw new Error(`Encoding function not implemented for ${source}`);
}
encodingMethods[source] = prepareImagePayload;
encodingMethods[source] = { prepareImagePayload, getDownloadStream };
}
const preparePayload = encodingMethods[source];
/* Google & Anthropic don't support passing URLs to payload */
if (source !== FileSources.local && base64Only.has(endpoint)) {
const preparePayload = encodingMethods[source].prepareImagePayload;
/* We need to fetch the image and convert it to base64 if we are using S3/Azure Blob storage. */
if (blobStorageSources.has(source)) {
try {
const downloadStream = encodingMethods[source].getDownloadStream;
let stream = await downloadStream(req, file.filepath);
let base64Data = await streamToBase64(stream);
stream = null;
promises.push([file, base64Data]);
base64Data = null;
continue;
} catch (error) {
// Error handling code
}
} else if (source !== FileSources.local && base64Only.has(endpoint)) {
const [_file, imageURL] = await preparePayload(req, file);
promises.push([_file, await fetchImageToBase64(imageURL)]);
continue;
@ -97,6 +154,7 @@ async function encodeAndFormat(req, files, endpoint, mode) {
/** @type {Array<[MongoFile, string]>} */
const formattedImages = await Promise.all(promises);
promises.length = 0;
for (const [file, imageContent] of formattedImages) {
const fileMetadata = {
@ -129,8 +187,8 @@ async function encodeAndFormat(req, files, endpoint, mode) {
};
if (mode === VisionModes.agents) {
result.image_urls.push(imagePart);
result.files.push(fileMetadata);
result.image_urls.push({ ...imagePart });
result.files.push({ ...fileMetadata });
continue;
}
@ -152,10 +210,11 @@ async function encodeAndFormat(req, files, endpoint, mode) {
delete imagePart.image_url;
}
result.image_urls.push(imagePart);
result.files.push(fileMetadata);
result.image_urls.push({ ...imagePart });
result.files.push({ ...fileMetadata });
}
return result;
formattedImages.length = 0;
return { ...result };
}
module.exports = {

View file

@ -492,7 +492,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
let fileInfoMetadata;
const entity_id = messageAttachment === true ? undefined : agent_id;
const basePath = mime.getType(file.originalname)?.startsWith('image') ? 'images' : 'uploads';
if (tool_resource === EToolResources.execute_code) {
const isCodeEnabled = await checkCapability(req, AgentCapabilities.execute_code);
if (!isCodeEnabled) {
@ -520,7 +520,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
throw new Error('OCR capability is not enabled for Agents');
}
const { handleFileUpload } = getStrategyFunctions(
const { handleFileUpload: uploadMistralOCR } = getStrategyFunctions(
req.app.locals?.ocr?.strategy ?? FileSources.mistral_ocr,
);
const { file_id, temp_file_id } = metadata;
@ -532,7 +532,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
images,
filename,
filepath: ocrFileURL,
} = await handleFileUpload({ req, file, file_id, entity_id: agent_id });
} = await uploadMistralOCR({ req, file, file_id, entity_id: agent_id, basePath });
const fileInfo = removeNullishValues({
text,
@ -540,7 +540,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
file_id,
temp_file_id,
user: req.user.id,
type: file.mimetype,
type: 'text/plain',
filepath: ocrFileURL,
source: FileSources.text,
filename: filename ?? file.originalname,
@ -582,6 +582,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
file,
file_id,
entity_id,
basePath,
});
let filepath = _filepath;

View file

@ -211,6 +211,8 @@ const getStrategyFunctions = (fileSource) => {
} else if (fileSource === FileSources.openai) {
return openAIStrategy();
} else if (fileSource === FileSources.azure) {
return openAIStrategy();
} else if (fileSource === FileSources.azure_blob) {
return azureStrategy();
} else if (fileSource === FileSources.vectordb) {
return vectorStrategy();

View file

@ -13,13 +13,13 @@ const { logger, getMCPManager } = require('~/config');
* Creates a general tool for an entire action set.
*
* @param {Object} params - The parameters for loading action sets.
* @param {ServerRequest} params.req - The name of the tool.
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {string} params.toolKey - The toolKey for the tool.
* @param {import('@librechat/agents').Providers | EModelEndpoint} params.provider - The provider for the tool.
* @param {string} params.model - The model for the tool.
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
*/
async function createMCPTool({ req, toolKey, provider }) {
async function createMCPTool({ req, toolKey, provider: _provider }) {
const toolDefinition = req.app.locals.availableTools[toolKey]?.function;
if (!toolDefinition) {
logger.error(`Tool ${toolKey} not found in available tools`);
@ -27,9 +27,10 @@ async function createMCPTool({ req, toolKey, provider }) {
}
/** @type {LCTool} */
const { description, parameters } = toolDefinition;
const isGoogle = provider === Providers.VERTEXAI || provider === Providers.GOOGLE;
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
let schema = convertJsonSchemaToZod(parameters, {
allowEmptyObject: !isGoogle,
transformOneOfAnyOf: true,
});
if (!schema) {
@ -37,19 +38,31 @@ async function createMCPTool({ req, toolKey, provider }) {
}
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
if (!req.user?.id) {
logger.error(
`[MCP][${serverName}][${toolName}] User ID not found on request. Cannot create tool.`,
);
throw new Error(`User ID not found on request. Cannot create tool for ${toolKey}.`);
}
/** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise<unknown>} */
const _call = async (toolArguments, config) => {
try {
const mcpManager = await getMCPManager();
const derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined;
const mcpManager = getMCPManager(config?.configurable?.user_id);
const provider = (config?.metadata?.provider || _provider)?.toLowerCase();
const result = await mcpManager.callTool({
serverName,
toolName,
provider,
toolArguments,
options: {
signal: config?.signal,
userId: config?.configurable?.user_id,
signal: derivedSignal,
},
});
if (isAssistantsEndpoint(provider) && Array.isArray(result)) {
return result[0];
}
@ -58,8 +71,13 @@ async function createMCPTool({ req, toolKey, provider }) {
}
return result;
} catch (error) {
logger.error(`${toolName} MCP server tool call failed`, error);
return `${toolName} MCP server tool call failed.`;
logger.error(
`[MCP][User: ${config?.configurable?.user_id}][${serverName}] Error calling "${toolName}" MCP tool:`,
error,
);
throw new Error(
`"${toolKey}" tool call failed${error?.message ? `: ${error?.message}` : '.'}`,
);
}
};

View file

@ -55,8 +55,7 @@ async function retrieveRun({ thread_id, run_id, timeout, openai }) {
return response.data;
} catch (error) {
const message = '[retrieveRun] Failed to retrieve run data:';
logAxiosError({ message, error });
throw error;
throw new Error(logAxiosError({ message, error }));
}
}

View file

@ -132,6 +132,8 @@ async function saveUserMessage(req, params) {
* @param {string} params.endpoint - The conversation endpoint
* @param {string} params.parentMessageId - The latest user message that triggered this response.
* @param {string} [params.instructions] - Optional: from preset for `instructions` field.
* @param {string} [params.spec] - Optional: Model spec identifier.
* @param {string} [params.iconURL]
* Overrides the instructions of the assistant.
* @param {string} [params.promptPrefix] - Optional: from preset for `additional_instructions` field.
* @return {Promise<Run>} A promise that resolves to the created run object.
@ -154,6 +156,8 @@ async function saveAssistantMessage(req, params) {
text: params.text,
unfinished: false,
// tokenCount,
iconURL: params.iconURL,
spec: params.spec,
});
await saveConvo(
@ -165,6 +169,8 @@ async function saveAssistantMessage(req, params) {
instructions: params.instructions,
assistant_id: params.assistant_id,
model: params.model,
iconURL: params.iconURL,
spec: params.spec,
},
{ context: 'api/server/services/Threads/manage.js #saveAssistantMessage' },
);

View file

@ -93,11 +93,12 @@ const refreshAccessToken = async ({
return response.data;
} catch (error) {
const message = 'Error refreshing OAuth tokens';
logAxiosError({
message,
error,
});
throw new Error(message);
throw new Error(
logAxiosError({
message,
error,
}),
);
}
};
@ -156,11 +157,12 @@ const getAccessToken = async ({
return response.data;
} catch (error) {
const message = 'Error exchanging OAuth code';
logAxiosError({
message,
error,
});
throw new Error(message);
throw new Error(
logAxiosError({
message,
error,
}),
);
}
};

View file

@ -8,6 +8,7 @@ const {
ErrorTypes,
ContentTypes,
imageGenTools,
EToolResources,
EModelEndpoint,
actionDelimiter,
ImageVisionTool,
@ -15,9 +16,20 @@ const {
AgentCapabilities,
validateAndParseOpenAPISpec,
} = require('librechat-data-provider');
const {
createActionTool,
decryptMetadata,
loadActionSets,
domainParser,
} = require('./ActionService');
const {
createOpenAIImageTools,
createYouTubeTools,
manifestToolMap,
toolkits,
} = require('~/app/clients/tools');
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
const { createYouTubeTools, manifestToolMap, toolkits } = require('~/app/clients/tools');
const { loadActionSets, createActionTool, domainParser } = require('./ActionService');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { getEndpointsConfig } = require('~/server/services/Config');
const { recordUsage } = require('~/server/services/Threads');
const { loadTools } = require('~/app/clients/tools/util');
@ -25,6 +37,30 @@ const { redactMessage } = require('~/config/parsers');
const { sleep } = require('~/server/utils');
const { logger } = require('~/config');
/**
* @param {string} toolName
* @returns {string | undefined} toolKey
*/
function getToolkitKey(toolName) {
/** @type {string|undefined} */
let toolkitKey;
for (const toolkit of toolkits) {
if (toolName.startsWith(EToolResources.image_edit)) {
const splitMatches = toolkit.pluginKey.split('_');
const suffix = splitMatches[splitMatches.length - 1];
if (toolName.endsWith(suffix)) {
toolkitKey = toolkit.pluginKey;
break;
}
}
if (toolName.startsWith(toolkit.pluginKey)) {
toolkitKey = toolkit.pluginKey;
break;
}
}
return toolkitKey;
}
/**
* Loads and formats tools from the specified tool directory.
*
@ -97,14 +133,16 @@ function loadAndFormatTools({ directory, adminFilter = [], adminIncluded = [] })
tools.push(formattedTool);
}
/** Basic Tools; schema: { input: string } */
const basicToolInstances = [new Calculator(), ...createYouTubeTools({ override: true })];
/** Basic Tools & Toolkits; schema: { input: string } */
const basicToolInstances = [
new Calculator(),
...createOpenAIImageTools({ override: true }),
...createYouTubeTools({ override: true }),
];
for (const toolInstance of basicToolInstances) {
const formattedTool = formatToOpenAIAssistantTool(toolInstance);
let toolName = formattedTool[Tools.function].name;
toolName = toolkits.some((toolkit) => toolName.startsWith(toolkit.pluginKey))
? toolName.split('_')[0]
: toolName;
toolName = getToolkitKey(toolName) ?? toolName;
if (filter.has(toolName) && included.size === 0) {
continue;
}
@ -315,58 +353,95 @@ async function processRequiredActions(client, requiredActions) {
if (!tool) {
// throw new Error(`Tool ${currentAction.tool} not found.`);
// Load all action sets once if not already loaded
if (!actionSets.length) {
actionSets =
(await loadActionSets({
assistant_id: client.req.body.assistant_id,
})) ?? [];
// Process all action sets once
// Map domains to their processed action sets
const processedDomains = new Map();
const domainMap = new Map();
for (const action of actionSets) {
const domain = await domainParser(action.metadata.domain, true);
domainMap.set(domain, action);
// Check if domain is allowed
const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain);
if (!isDomainAllowed) {
continue;
}
// Validate and parse OpenAPI spec
const validationResult = validateAndParseOpenAPISpec(action.metadata.raw_spec);
if (!validationResult.spec) {
throw new Error(
`Invalid spec: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`,
);
}
// Process the OpenAPI spec
const { requestBuilders } = openapiToFunction(validationResult.spec);
// Store encrypted values for OAuth flow
const encrypted = {
oauth_client_id: action.metadata.oauth_client_id,
oauth_client_secret: action.metadata.oauth_client_secret,
};
// Decrypt metadata
const decryptedAction = { ...action };
decryptedAction.metadata = await decryptMetadata(action.metadata);
processedDomains.set(domain, {
action: decryptedAction,
requestBuilders,
encrypted,
});
// Store builders for reuse
ActionBuildersMap[action.metadata.domain] = requestBuilders;
}
// Update actionSets reference to use the domain map
actionSets = { domainMap, processedDomains };
}
let actionSet = null;
// Find the matching domain for this tool
let currentDomain = '';
for (let action of actionSets) {
const domain = await domainParser(client.req, action.metadata.domain, true);
for (const domain of actionSets.domainMap.keys()) {
if (currentAction.tool.includes(domain)) {
currentDomain = domain;
actionSet = action;
break;
}
}
if (!actionSet) {
if (!currentDomain || !actionSets.processedDomains.has(currentDomain)) {
// TODO: try `function` if no action set is found
// throw new Error(`Tool ${currentAction.tool} not found.`);
continue;
}
let builders = ActionBuildersMap[actionSet.metadata.domain];
if (!builders) {
const validationResult = validateAndParseOpenAPISpec(actionSet.metadata.raw_spec);
if (!validationResult.spec) {
throw new Error(
`Invalid spec: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`,
);
}
const { requestBuilders } = openapiToFunction(validationResult.spec);
ActionToolMap[actionSet.metadata.domain] = requestBuilders;
builders = requestBuilders;
}
const { action, requestBuilders, encrypted } = actionSets.processedDomains.get(currentDomain);
const functionName = currentAction.tool.replace(`${actionDelimiter}${currentDomain}`, '');
const requestBuilder = builders[functionName];
const requestBuilder = requestBuilders[functionName];
if (!requestBuilder) {
// throw new Error(`Tool ${currentAction.tool} not found.`);
continue;
}
// We've already decrypted the metadata, so we can pass it directly
tool = await createActionTool({
req: client.req,
userId: client.req.user.id,
res: client.res,
action: actionSet,
action,
requestBuilder,
// Note: intentionally not passing zodSchema, name, and description for assistants API
encrypted, // Pass the encrypted values for OAuth flow
});
if (!tool) {
logger.warn(
@ -415,7 +490,7 @@ async function processRequiredActions(client, requiredActions) {
* @param {Object} params - Run params containing user and request information.
* @param {ServerRequest} params.req - The request object.
* @param {ServerResponse} params.res - The request object.
* @param {Agent} params.agent - The agent to load tools for.
* @param {Pick<Agent, 'id' | 'provider' | 'model' | 'tools'} params.agent - The agent to load tools for.
* @param {string | undefined} [params.openAIApiKey] - The OpenAI API key.
* @returns {Promise<{ tools?: StructuredTool[] }>} The agent tools.
*/
@ -425,21 +500,16 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
}
const endpointsConfig = await getEndpointsConfig(req);
const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
const areToolsEnabled = capabilities.includes(AgentCapabilities.tools);
if (!areToolsEnabled) {
logger.debug('Tools are not enabled for this agent.');
return {};
}
const isFileSearchEnabled = capabilities.includes(AgentCapabilities.file_search);
const isCodeEnabled = capabilities.includes(AgentCapabilities.execute_code);
const areActionsEnabled = capabilities.includes(AgentCapabilities.actions);
const enabledCapabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []);
const checkCapability = (capability) => enabledCapabilities.has(capability);
const areToolsEnabled = checkCapability(AgentCapabilities.tools);
const _agentTools = agent.tools?.filter((tool) => {
if (tool === Tools.file_search && !isFileSearchEnabled) {
return false;
} else if (tool === Tools.execute_code && !isCodeEnabled) {
if (tool === Tools.file_search) {
return checkCapability(AgentCapabilities.file_search);
} else if (tool === Tools.execute_code) {
return checkCapability(AgentCapabilities.execute_code);
} else if (!areToolsEnabled && !tool.includes(actionDelimiter)) {
return false;
}
return true;
@ -473,6 +543,10 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
continue;
}
if (!areToolsEnabled) {
continue;
}
if (tool.mcp === true) {
agentTools.push(tool);
continue;
@ -505,14 +579,69 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
return map;
}, {});
if (!areActionsEnabled) {
if (!checkCapability(AgentCapabilities.actions)) {
return {
tools: agentTools,
toolContextMap,
};
}
let actionSets = [];
const actionSets = (await loadActionSets({ agent_id: agent.id })) ?? [];
if (actionSets.length === 0) {
if (_agentTools.length > 0 && agentTools.length === 0) {
logger.warn(`No tools found for the specified tool calls: ${_agentTools.join(', ')}`);
}
return {
tools: agentTools,
toolContextMap,
};
}
// Process each action set once (validate spec, decrypt metadata)
const processedActionSets = new Map();
const domainMap = new Map();
for (const action of actionSets) {
const domain = await domainParser(action.metadata.domain, true);
domainMap.set(domain, action);
// Check if domain is allowed (do this once per action set)
const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain);
if (!isDomainAllowed) {
continue;
}
// Validate and parse OpenAPI spec once per action set
const validationResult = validateAndParseOpenAPISpec(action.metadata.raw_spec);
if (!validationResult.spec) {
continue;
}
const encrypted = {
oauth_client_id: action.metadata.oauth_client_id,
oauth_client_secret: action.metadata.oauth_client_secret,
};
// Decrypt metadata once per action set
const decryptedAction = { ...action };
decryptedAction.metadata = await decryptMetadata(action.metadata);
// Process the OpenAPI spec once per action set
const { requestBuilders, functionSignatures, zodSchemas } = openapiToFunction(
validationResult.spec,
true,
);
processedActionSets.set(domain, {
action: decryptedAction,
requestBuilders,
functionSignatures,
zodSchemas,
encrypted,
});
}
// Now map tools to the processed action sets
const ActionToolMap = {};
for (const toolName of _agentTools) {
@ -520,55 +649,47 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
continue;
}
if (!actionSets.length) {
actionSets = (await loadActionSets({ agent_id: agent.id })) ?? [];
}
let actionSet = null;
// Find the matching domain for this tool
let currentDomain = '';
for (let action of actionSets) {
const domain = await domainParser(req, action.metadata.domain, true);
for (const domain of domainMap.keys()) {
if (toolName.includes(domain)) {
currentDomain = domain;
actionSet = action;
break;
}
}
if (!actionSet) {
if (!currentDomain || !processedActionSets.has(currentDomain)) {
continue;
}
const validationResult = validateAndParseOpenAPISpec(actionSet.metadata.raw_spec);
if (validationResult.spec) {
const { requestBuilders, functionSignatures, zodSchemas } = openapiToFunction(
validationResult.spec,
true,
);
const functionName = toolName.replace(`${actionDelimiter}${currentDomain}`, '');
const functionSig = functionSignatures.find((sig) => sig.name === functionName);
const requestBuilder = requestBuilders[functionName];
const zodSchema = zodSchemas[functionName];
const { action, encrypted, zodSchemas, requestBuilders, functionSignatures } =
processedActionSets.get(currentDomain);
const functionName = toolName.replace(`${actionDelimiter}${currentDomain}`, '');
const functionSig = functionSignatures.find((sig) => sig.name === functionName);
const requestBuilder = requestBuilders[functionName];
const zodSchema = zodSchemas[functionName];
if (requestBuilder) {
const tool = await createActionTool({
req,
res,
action: actionSet,
requestBuilder,
zodSchema,
name: toolName,
description: functionSig.description,
});
if (!tool) {
logger.warn(
`Invalid action: user: ${req.user.id} | agent_id: ${agent.id} | toolName: ${toolName}`,
);
throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
}
agentTools.push(tool);
ActionToolMap[toolName] = tool;
if (requestBuilder) {
const tool = await createActionTool({
userId: req.user.id,
res,
action,
requestBuilder,
zodSchema,
encrypted,
name: toolName,
description: functionSig.description,
});
if (!tool) {
logger.warn(
`Invalid action: user: ${req.user.id} | agent_id: ${agent.id} | toolName: ${toolName}`,
);
throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
}
agentTools.push(tool);
ActionToolMap[toolName] = tool;
}
}
@ -584,6 +705,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
}
module.exports = {
getToolkitKey,
loadAgentTools,
loadAndFormatTools,
processRequiredActions,

View file

@ -13,6 +13,24 @@ const secretDefaults = {
JWT_REFRESH_SECRET: 'eaa5191f2914e30b9387fd84e254e4ba6fc51b4654968a9b0803b456a54b8418',
};
const deprecatedVariables = [
{
key: 'CHECK_BALANCE',
description:
'Please use the `balance` field in the `librechat.yaml` config file instead.\nMore info: https://librechat.ai/docs/configuration/librechat_yaml/object_structure/balance#overview',
},
{
key: 'START_BALANCE',
description:
'Please use the `balance` field in the `librechat.yaml` config file instead.\nMore info: https://librechat.ai/docs/configuration/librechat_yaml/object_structure/balance#overview',
},
{
key: 'GOOGLE_API_KEY',
description:
'Please use the `GOOGLE_SEARCH_API_KEY` environment variable for the Google Search Tool instead.',
},
];
/**
* Checks environment variables for default secrets and deprecated variables.
* Logs warnings for any default secret values being used and for usage of deprecated `GOOGLE_API_KEY`.
@ -37,19 +55,11 @@ function checkVariables() {
\u200B`);
}
if (process.env.GOOGLE_API_KEY) {
logger.warn(
'The `GOOGLE_API_KEY` environment variable is deprecated.\nPlease use the `GOOGLE_SEARCH_API_KEY` environment variable instead.',
);
}
if (process.env.OPENROUTER_API_KEY) {
logger.warn(
`The \`OPENROUTER_API_KEY\` environment variable is deprecated and its functionality will be removed soon.
Use of this environment variable is highly discouraged as it can lead to unexpected errors when using custom endpoints.
Please use the config (\`librechat.yaml\`) file for setting up OpenRouter, and use \`OPENROUTER_KEY\` or another environment variable instead.`,
);
}
deprecatedVariables.forEach(({ key, description }) => {
if (process.env[key]) {
logger.warn(`The \`${key}\` environment variable is deprecated. ${description}`);
}
});
checkPasswordReset();
}

View file

@ -18,12 +18,15 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol
const { interface: interfaceConfig } = config ?? {};
const { interface: defaults } = configDefaults;
const hasModelSpecs = config?.modelSpecs?.list?.length > 0;
const includesAddedEndpoints = config?.modelSpecs?.addedEndpoints?.length > 0;
/** @type {TCustomConfig['interface']} */
const loadedInterface = removeNullishValues({
endpointsMenu:
interfaceConfig?.endpointsMenu ?? (hasModelSpecs ? false : defaults.endpointsMenu),
modelSelect: interfaceConfig?.modelSelect ?? (hasModelSpecs ? false : defaults.modelSelect),
modelSelect:
interfaceConfig?.modelSelect ??
(hasModelSpecs ? includesAddedEndpoints : defaults.modelSelect),
parameters: interfaceConfig?.parameters ?? (hasModelSpecs ? false : defaults.parameters),
presets: interfaceConfig?.presets ?? (hasModelSpecs ? false : defaults.presets),
sidePanel: interfaceConfig?.sidePanel ?? defaults.sidePanel,

View file

@ -6,9 +6,10 @@ const { logger } = require('~/config');
* Sets up Model Specs from the config (`librechat.yaml`) file.
* @param {TCustomConfig['endpoints']} [endpoints] - The loaded custom configuration for endpoints.
* @param {TCustomConfig['modelSpecs'] | undefined} [modelSpecs] - The loaded custom configuration for model specs.
* @param {TCustomConfig['interface'] | undefined} [interfaceConfig] - The loaded interface configuration.
* @returns {TCustomConfig['modelSpecs'] | undefined} The processed model specs, if any.
*/
function processModelSpecs(endpoints, _modelSpecs) {
function processModelSpecs(endpoints, _modelSpecs, interfaceConfig) {
if (!_modelSpecs) {
return undefined;
}
@ -20,6 +21,19 @@ function processModelSpecs(endpoints, _modelSpecs) {
const customEndpoints = endpoints?.[EModelEndpoint.custom] ?? [];
if (interfaceConfig.modelSelect !== true && (_modelSpecs.addedEndpoints?.length ?? 0) > 0) {
logger.warn(
`To utilize \`addedEndpoints\`, which allows provider/model selections alongside model specs, set \`modelSelect: true\` in the interface configuration.
Example:
\`\`\`yaml
interface:
modelSelect: true
\`\`\`
`,
);
}
for (const spec of list) {
if (EModelEndpoint[spec.preset.endpoint] && spec.preset.endpoint !== EModelEndpoint.custom) {
modelSpecs.push(spec);

View file

@ -1,4 +1,4 @@
const Keyv = require('keyv');
const { Keyv } = require('keyv');
const passport = require('passport');
const session = require('express-session');
const MemoryStore = require('memorystore')(session);
@ -49,7 +49,7 @@ const configureSocialLogins = (app) => {
if (isEnabled(process.env.USE_REDIS)) {
logger.debug('Using Redis for session storage in OpenID...');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const client = keyv.opts.store.client;
sessionOptions.store = new RedisStore({ client, prefix: 'openid_session' });
} else {
sessionOptions.store = new MemoryStore({

View file

@ -70,7 +70,13 @@ const sendError = async (req, res, options, callback) => {
}
if (shouldSaveMessage) {
await saveMessage(req, { ...errorMessage, user });
await saveMessage(
req,
{ ...errorMessage, user },
{
context: 'api/server/utils/streamResponse.js - sendError',
},
);
}
if (!errorMessage.error) {