diff --git a/api/app/clients/AnthropicClient.js b/api/app/clients/AnthropicClient.js index 2373a321f5..16b21ea2e3 100644 --- a/api/app/clients/AnthropicClient.js +++ b/api/app/clients/AnthropicClient.js @@ -2,8 +2,9 @@ const Anthropic = require('@anthropic-ai/sdk'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); const { - getResponseSender, + Constants, EModelEndpoint, + getResponseSender, validateVisionModel, } = require('librechat-data-provider'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); @@ -16,6 +17,7 @@ const { } = require('./prompts'); const spendTokens = require('~/models/spendTokens'); const { getModelMaxTokens } = require('~/utils'); +const { sleep } = require('~/server/utils'); const BaseClient = require('./BaseClient'); const { logger } = require('~/config'); @@ -605,6 +607,7 @@ class AnthropicClient extends BaseClient { }; const maxRetries = 3; + const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE; async function processResponse() { let attempts = 0; @@ -627,6 +630,8 @@ class AnthropicClient extends BaseClient { } else if (completion.completion) { handleChunk(completion.completion); } + + await sleep(streamRate); } // Successful processing, exit loop diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index c7b4f977c8..b09a6a5d95 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -1,10 +1,11 @@ const crypto = require('crypto'); const fetch = require('node-fetch'); -const { supportsBalanceCheck, Constants } = require('librechat-data-provider'); +const { supportsBalanceCheck, Constants, CacheKeys, Time } = require('librechat-data-provider'); const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models'); const { addSpaceIfNeeded, isEnabled } = require('~/server/utils'); const checkBalance = require('~/models/checkBalance'); const { getFiles } = require('~/models/File'); +const { getLogStores } = require('~/cache'); const TextStream = require('./TextStream'); const { logger } = require('~/config'); @@ -540,6 +541,15 @@ class BaseClient { await this.recordTokenUsage({ promptTokens, completionTokens }); } this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user); + const messageCache = getLogStores(CacheKeys.MESSAGES); + messageCache.set( + responseMessageId, + { + text: responseMessage.text, + complete: true, + }, + Time.FIVE_MINUTES, + ); delete responseMessage.tokenCount; return responseMessage; } @@ -598,7 +608,11 @@ class BaseClient { * @param {string | null} user */ async saveMessageToDatabase(message, endpointOptions, user = null) { - const savedMessage = await saveMessage({ + if (this.user && user !== this.user) { + throw new Error('User mismatch.'); + } + + const savedMessage = await saveMessage(this.options.req, { ...message, endpoint: this.options.endpoint, unfinished: false, @@ -619,7 +633,7 @@ class BaseClient { } async updateMessageInDatabase(message) { - await updateMessage(message); + await updateMessage(this.options.req, message); } /** diff --git a/api/app/clients/GoogleClient.js b/api/app/clients/GoogleClient.js index a01df71841..e115ab1db8 100644 --- a/api/app/clients/GoogleClient.js +++ b/api/app/clients/GoogleClient.js @@ -13,10 +13,12 @@ const { endpointSettings, EModelEndpoint, VisionModes, + Constants, AuthKeys, } = require('librechat-data-provider'); const { encodeAndFormat } = require('~/server/services/Files/images'); const { getModelMaxTokens } = require('~/utils'); +const { sleep } = require('~/server/utils'); const { logger } = require('~/config'); const { formatMessage, @@ -620,8 +622,9 @@ class GoogleClient extends BaseClient { } async getCompletion(_payload, options = {}) { - const { onProgress, abortController } = options; const { parameters, instances } = _payload; + const { onProgress, abortController } = options; + const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE; const { messages: _messages, context, examples: _examples } = instances?.[0] ?? {}; let examples; @@ -701,6 +704,7 @@ class GoogleClient extends BaseClient { delay, }); reply += chunkText; + await sleep(streamRate); } return reply; } @@ -712,10 +716,17 @@ class GoogleClient extends BaseClient { safetySettings: safetySettings, }); - let delay = this.isGenerativeModel ? 12 : 8; - if (modelName.includes('flash')) { - delay = 5; + let delay = this.options.streamRate || 8; + + if (!this.options.streamRate) { + if (this.isGenerativeModel) { + delay = 12; + } + if (modelName.includes('flash')) { + delay = 5; + } } + for await (const chunk of stream) { const chunkText = chunk?.content ?? chunk; await this.generateTextStream(chunkText, onProgress, { diff --git a/api/app/clients/OllamaClient.js b/api/app/clients/OllamaClient.js index 57bc8754fb..c88ef72d58 100644 --- a/api/app/clients/OllamaClient.js +++ b/api/app/clients/OllamaClient.js @@ -1,7 +1,9 @@ const { z } = require('zod'); const axios = require('axios'); const { Ollama } = require('ollama'); +const { Constants } = require('librechat-data-provider'); const { deriveBaseURL } = require('~/utils'); +const { sleep } = require('~/server/utils'); const { logger } = require('~/config'); const ollamaPayloadSchema = z.object({ @@ -40,6 +42,7 @@ const getValidBase64 = (imageUrl) => { class OllamaClient { constructor(options = {}) { const host = deriveBaseURL(options.baseURL ?? 'http://localhost:11434'); + this.streamRate = options.streamRate ?? Constants.DEFAULT_STREAM_RATE; /** @type {Ollama} */ this.client = new Ollama({ host }); } @@ -136,6 +139,8 @@ class OllamaClient { stream.controller.abort(); break; } + + await sleep(this.streamRate); } } // TODO: regular completion diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 7520cbb897..ccc5165fc7 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -1182,8 +1182,10 @@ ${convo} }); } + const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE; + if (this.message_file_map && this.isOllama) { - const ollamaClient = new OllamaClient({ baseURL }); + const ollamaClient = new OllamaClient({ baseURL, streamRate }); return await ollamaClient.chatCompletion({ payload: modelOptions, onProgress, @@ -1221,8 +1223,6 @@ ${convo} } }); - const azureDelay = this.modelOptions.model?.includes('gpt-4') ? 30 : 17; - for await (const chunk of stream) { const token = chunk.choices[0]?.delta?.content || ''; intermediateReply += token; @@ -1232,9 +1232,7 @@ ${convo} break; } - if (this.azure) { - await sleep(azureDelay); - } + await sleep(streamRate); } if (!UnexpectedRoleError) { diff --git a/api/app/clients/PluginsClient.js b/api/app/clients/PluginsClient.js index 2ce0ece4e7..a23fb019ba 100644 --- a/api/app/clients/PluginsClient.js +++ b/api/app/clients/PluginsClient.js @@ -1,5 +1,6 @@ const OpenAIClient = require('./OpenAIClient'); const { CallbackManager } = require('langchain/callbacks'); +const { CacheKeys, Time } = require('librechat-data-provider'); const { BufferMemory, ChatMessageHistory } = require('langchain/memory'); const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents'); const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers'); @@ -11,6 +12,7 @@ const { SelfReflectionTool } = require('./tools'); const { isEnabled } = require('~/server/utils'); const { extractBaseURL } = require('~/utils'); const { loadTools } = require('./tools/util'); +const { getLogStores } = require('~/cache'); const { logger } = require('~/config'); class PluginsClient extends OpenAIClient { @@ -220,6 +222,13 @@ class PluginsClient extends OpenAIClient { } } + /** + * + * @param {TMessage} responseMessage + * @param {Partial} saveOptions + * @param {string} user + * @returns + */ async handleResponseMessage(responseMessage, saveOptions, user) { const { output, errorMessage, ...result } = this.result; logger.debug('[PluginsClient][handleResponseMessage] Output:', { @@ -239,6 +248,15 @@ class PluginsClient extends OpenAIClient { } this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user); + const messageCache = getLogStores(CacheKeys.MESSAGES); + messageCache.set( + responseMessage.messageId, + { + text: responseMessage.text, + complete: true, + }, + Time.FIVE_MINUTES, + ); delete responseMessage.tokenCount; return { ...responseMessage, ...result }; } diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js index 9a7282e25a..2b33751a04 100644 --- a/api/cache/getLogStores.js +++ b/api/cache/getLogStores.js @@ -1,13 +1,11 @@ const Keyv = require('keyv'); -const { CacheKeys, ViolationTypes } = require('librechat-data-provider'); +const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider'); const { logFile, violationFile } = require('./keyvFiles'); const { math, isEnabled } = require('~/server/utils'); const keyvRedis = require('./keyvRedis'); const keyvMongo = require('./keyvMongo'); const { BAN_DURATION, USE_REDIS } = process.env ?? {}; -const THIRTY_MINUTES = 1800000; -const TEN_MINUTES = 600000; const duration = math(BAN_DURATION, 7200000); @@ -29,17 +27,21 @@ const roles = isEnabled(USE_REDIS) ? new Keyv({ store: keyvRedis }) : new Keyv({ namespace: CacheKeys.ROLES }); -const audioRuns = isEnabled(USE_REDIS) // ttl: 30 minutes - ? new Keyv({ store: keyvRedis, ttl: TEN_MINUTES }) - : new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: TEN_MINUTES }); +const audioRuns = isEnabled(USE_REDIS) + ? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES }) + : new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: Time.TEN_MINUTES }); + +const messages = isEnabled(USE_REDIS) + ? new Keyv({ store: keyvRedis, ttl: Time.FIVE_MINUTES }) + : new Keyv({ namespace: CacheKeys.MESSAGES, ttl: Time.FIVE_MINUTES }); const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes - ? new Keyv({ store: keyvRedis, ttl: THIRTY_MINUTES }) - : new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: THIRTY_MINUTES }); + ? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES }) + : new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: Time.THIRTY_MINUTES }); const genTitle = isEnabled(USE_REDIS) // ttl: 2 minutes - ? new Keyv({ store: keyvRedis, ttl: 120000 }) - : new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: 120000 }); + ? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES }) + : new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES }); const modelQueries = isEnabled(process.env.USE_REDIS) ? new Keyv({ store: keyvRedis }) @@ -47,7 +49,7 @@ const modelQueries = isEnabled(process.env.USE_REDIS) const abortKeys = isEnabled(USE_REDIS) ? new Keyv({ store: keyvRedis }) - : new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: 600000 }); + : new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: Time.TEN_MINUTES }); const namespaces = { [CacheKeys.ROLES]: roles, @@ -81,6 +83,7 @@ const namespaces = { [CacheKeys.GEN_TITLE]: genTitle, [CacheKeys.MODEL_QUERIES]: modelQueries, [CacheKeys.AUDIO_RUNS]: audioRuns, + [CacheKeys.MESSAGES]: messages, }; /** diff --git a/api/models/Message.js b/api/models/Message.js index b9c82ca36b..460c693439 100644 --- a/api/models/Message.js +++ b/api/models/Message.js @@ -4,11 +4,37 @@ const logger = require('~/config/winston'); const idSchema = z.string().uuid(); -module.exports = { - Message, - - async saveMessage({ - user, +/** + * Saves a message in the database. + * + * @async + * @function saveMessage + * @param {Express.Request} req - The request object containing user information. + * @param {Object} params - The message data object. + * @param {string} params.endpoint - The endpoint where the message originated. + * @param {string} params.iconURL - The URL of the sender's icon. + * @param {string} params.messageId - The unique identifier for the message. + * @param {string} params.newMessageId - The new unique identifier for the message (if applicable). + * @param {string} params.conversationId - The identifier of the conversation. + * @param {string} [params.parentMessageId] - The identifier of the parent message, if any. + * @param {string} params.sender - The identifier of the sender. + * @param {string} params.text - The text content of the message. + * @param {boolean} params.isCreatedByUser - Indicates if the message was created by the user. + * @param {string} [params.error] - Any error associated with the message. + * @param {boolean} [params.unfinished] - Indicates if the message is unfinished. + * @param {Object[]} [params.files] - An array of files associated with the message. + * @param {boolean} [params.isEdited] - Indicates if the message was edited. + * @param {string} [params.finish_reason] - Reason for finishing the message. + * @param {number} [params.tokenCount] - The number of tokens in the message. + * @param {string} [params.plugin] - Plugin associated with the message. + * @param {Object[]} [params.plugins] - An array of plugins associated with the message. + * @param {string} [params.model] - The model used to generate the message. + * @returns {Promise} The updated or newly inserted message document. + * @throws {Error} If there is an error in saving the message. + */ +async function saveMessage( + req, + { endpoint, iconURL, messageId, @@ -27,178 +53,271 @@ module.exports = { plugin, plugins, model, - }) { - try { - const validConvoId = idSchema.safeParse(conversationId); - if (!validConvoId.success) { - return; - } + }, +) { + try { + if (!req || !req.user || !req.user.id) { + throw new Error('User not authenticated'); + } - const update = { - user, - iconURL, - endpoint, - messageId: newMessageId || messageId, - conversationId, - parentMessageId, - sender, - text, - isCreatedByUser, - isEdited, - finish_reason, - error, - unfinished, - tokenCount, - plugin, - plugins, - model, - }; + const validConvoId = idSchema.safeParse(conversationId); + if (!validConvoId.success) { + throw new Error('Invalid conversation ID'); + } - if (files) { - update.files = files; - } + const update = { + user: req.user.id, + iconURL, + endpoint, + messageId: newMessageId || messageId, + conversationId, + parentMessageId, + sender, + text, + isCreatedByUser, + isEdited, + finish_reason, + error, + unfinished, + tokenCount, + plugin, + plugins, + model, + }; - const message = await Message.findOneAndUpdate({ messageId }, update, { + if (files) { + update.files = files; + } + + const message = await Message.findOneAndUpdate({ messageId, user: req.user.id }, update, { + upsert: true, + new: true, + }); + + return message.toObject(); + } catch (err) { + logger.error('Error saving message:', err); + throw err; + } +} + +/** + * Saves multiple messages in the database in bulk. + * + * @async + * @function bulkSaveMessages + * @param {Object[]} messages - An array of message objects to save. + * @returns {Promise} The result of the bulk write operation. + * @throws {Error} If there is an error in saving messages in bulk. + */ +async function bulkSaveMessages(messages) { + try { + const bulkOps = messages.map((message) => ({ + updateOne: { + filter: { messageId: message.messageId }, + update: message, upsert: true, + }, + })); + + const result = await Message.bulkWrite(bulkOps); + return result; + } catch (err) { + logger.error('Error saving messages in bulk:', err); + throw err; + } +} + +/** + * Records a message in the database. + * + * @async + * @function recordMessage + * @param {Object} params - The message data object. + * @param {string} params.user - The identifier of the user. + * @param {string} params.endpoint - The endpoint where the message originated. + * @param {string} params.messageId - The unique identifier for the message. + * @param {string} params.conversationId - The identifier of the conversation. + * @param {string} [params.parentMessageId] - The identifier of the parent message, if any. + * @param {Partial} rest - Any additional properties from the TMessage typedef not explicitly listed. + * @returns {Promise} The updated or newly inserted message document. + * @throws {Error} If there is an error in saving the message. + */ +async function recordMessage({ + user, + endpoint, + messageId, + conversationId, + parentMessageId, + ...rest +}) { + try { + // No parsing of convoId as may use threadId + const message = { + user, + endpoint, + messageId, + conversationId, + parentMessageId, + ...rest, + }; + + return await Message.findOneAndUpdate({ user, messageId }, message, { + upsert: true, + new: true, + }); + } catch (err) { + logger.error('Error saving message:', err); + throw err; + } +} + +/** + * Updates the text of a message. + * + * @async + * @function updateMessageText + * @param {Object} params - The update data object. + * @param {Object} req - The request object. + * @param {string} params.messageId - The unique identifier for the message. + * @param {string} params.text - The new text content of the message. + * @returns {Promise} + * @throws {Error} If there is an error in updating the message text. + */ +async function updateMessageText(req, { messageId, text }) { + try { + await Message.updateOne({ messageId, user: req.user.id }, { text }); + } catch (err) { + logger.error('Error updating message text:', err); + throw err; + } +} + +/** + * Updates a message. + * + * @async + * @function updateMessage + * @param {Object} message - The message object containing update data. + * @param {Object} req - The request object. + * @param {string} message.messageId - The unique identifier for the message. + * @param {string} [message.text] - The new text content of the message. + * @param {Object[]} [message.files] - The files associated with the message. + * @param {boolean} [message.isCreatedByUser] - Indicates if the message was created by the user. + * @param {string} [message.sender] - The identifier of the sender. + * @param {number} [message.tokenCount] - The number of tokens in the message. + * @returns {Promise} The updated message document. + * @throws {Error} If there is an error in updating the message or if the message is not found. + */ +async function updateMessage(req, message) { + try { + const { messageId, ...update } = message; + update.isEdited = true; + const updatedMessage = await Message.findOneAndUpdate( + { messageId, user: req.user.id }, + update, + { new: true, + }, + ); + + if (!updatedMessage) { + throw new Error('Message not found or user not authorized.'); + } + + return { + messageId: updatedMessage.messageId, + conversationId: updatedMessage.conversationId, + parentMessageId: updatedMessage.parentMessageId, + sender: updatedMessage.sender, + text: updatedMessage.text, + isCreatedByUser: updatedMessage.isCreatedByUser, + tokenCount: updatedMessage.tokenCount, + isEdited: true, + }; + } catch (err) { + logger.error('Error updating message:', err); + throw err; + } +} + +/** + * Deletes messages in a conversation since a specific message. + * + * @async + * @function deleteMessagesSince + * @param {Object} params - The parameters object. + * @param {Object} req - The request object. + * @param {string} params.messageId - The unique identifier for the message. + * @param {string} params.conversationId - The identifier of the conversation. + * @returns {Promise} The number of deleted messages. + * @throws {Error} If there is an error in deleting messages. + */ +async function deleteMessagesSince(req, { messageId, conversationId }) { + try { + const message = await Message.findOne({ messageId, user: req.user.id }).lean(); + + if (message) { + const query = Message.find({ conversationId, user: req.user.id }); + return await query.deleteMany({ + createdAt: { $gt: message.createdAt }, }); - - return message.toObject(); - } catch (err) { - logger.error('Error saving message:', err); - throw new Error('Failed to save message.'); } - }, + return undefined; + } catch (err) { + logger.error('Error deleting messages:', err); + throw err; + } +} - async bulkSaveMessages(messages) { - try { - const bulkOps = messages.map((message) => ({ - updateOne: { - filter: { messageId: message.messageId }, - update: message, - upsert: true, - }, - })); - - const result = await Message.bulkWrite(bulkOps); - return result; - } catch (err) { - logger.error('Error saving messages in bulk:', err); - throw new Error('Failed to save messages in bulk.'); +/** + * Retrieves messages from the database. + * @async + * @function getMessages + * @param {Record} filter - The filter criteria. + * @param {string | undefined} [select] - The fields to select. + * @returns {Promise} The messages that match the filter criteria. + * @throws {Error} If there is an error in retrieving messages. + */ +async function getMessages(filter, select) { + try { + if (select) { + return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean(); } - }, - /** - * Records a message in the database. - * - * @async - * @function recordMessage - * @param {Object} params - The message data object. - * @param {string} params.user - The identifier of the user. - * @param {string} params.endpoint - The endpoint where the message originated. - * @param {string} params.messageId - The unique identifier for the message. - * @param {string} params.conversationId - The identifier of the conversation. - * @param {string} [params.parentMessageId] - The identifier of the parent message, if any. - * @param {Partial} rest - Any additional properties from the TMessage typedef not explicitly listed. - * @returns {Promise} The updated or newly inserted message document. - * @throws {Error} If there is an error in saving the message. - */ - async recordMessage({ user, endpoint, messageId, conversationId, parentMessageId, ...rest }) { - try { - // No parsing of convoId as may use threadId - const message = { - user, - endpoint, - messageId, - conversationId, - parentMessageId, - ...rest, - }; + return await Message.find(filter).sort({ createdAt: 1 }).lean(); + } catch (err) { + logger.error('Error getting messages:', err); + throw err; + } +} - return await Message.findOneAndUpdate({ user, messageId }, message, { - upsert: true, - new: true, - }); - } catch (err) { - logger.error('Error saving message:', err); - throw new Error('Failed to save message.'); - } - }, - async updateMessageText({ messageId, text }) { - try { - await Message.updateOne({ messageId }, { text }); - } catch (err) { - logger.error('Error updating message text:', err); - throw new Error('Failed to update message text.'); - } - }, - async updateMessage(message) { - try { - const { messageId, ...update } = message; - update.isEdited = true; - const updatedMessage = await Message.findOneAndUpdate({ messageId }, update, { - new: true, - }); +/** + * Deletes messages from the database. + * + * @async + * @function deleteMessages + * @param {Object} filter - The filter criteria to find messages to delete. + * @returns {Promise} The number of deleted messages. + * @throws {Error} If there is an error in deleting messages. + */ +async function deleteMessages(filter) { + try { + return await Message.deleteMany(filter); + } catch (err) { + logger.error('Error deleting messages:', err); + throw err; + } +} - if (!updatedMessage) { - throw new Error('Message not found.'); - } - - return { - messageId: updatedMessage.messageId, - conversationId: updatedMessage.conversationId, - parentMessageId: updatedMessage.parentMessageId, - sender: updatedMessage.sender, - text: updatedMessage.text, - isCreatedByUser: updatedMessage.isCreatedByUser, - tokenCount: updatedMessage.tokenCount, - isEdited: true, - }; - } catch (err) { - logger.error('Error updating message:', err); - throw new Error('Failed to update message.'); - } - }, - async deleteMessagesSince({ messageId, conversationId }) { - try { - const message = await Message.findOne({ messageId }).lean(); - - if (message) { - return await Message.find({ conversationId }).deleteMany({ - createdAt: { $gt: message.createdAt }, - }); - } - } catch (err) { - logger.error('Error deleting messages:', err); - throw new Error('Failed to delete messages.'); - } - }, - - /** - * Retrieves messages from the database. - * @param {Record} filter - * @param {string | undefined} [select] - * @returns - */ - async getMessages(filter, select) { - try { - if (select) { - return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean(); - } - - return await Message.find(filter).sort({ createdAt: 1 }).lean(); - } catch (err) { - logger.error('Error getting messages:', err); - throw new Error('Failed to get messages.'); - } - }, - - async deleteMessages(filter) { - try { - return await Message.deleteMany(filter); - } catch (err) { - logger.error('Error deleting messages:', err); - throw new Error('Failed to delete messages.'); - } - }, +module.exports = { + Message, + saveMessage, + bulkSaveMessages, + recordMessage, + updateMessageText, + updateMessage, + deleteMessagesSince, + getMessages, + deleteMessages, }; diff --git a/api/models/Message.spec.js b/api/models/Message.spec.js new file mode 100644 index 0000000000..d5fbba9ed8 --- /dev/null +++ b/api/models/Message.spec.js @@ -0,0 +1,239 @@ +const mongoose = require('mongoose'); +const { v4: uuidv4 } = require('uuid'); + +jest.mock('mongoose'); + +const mockFindQuery = { + select: jest.fn().mockReturnThis(), + sort: jest.fn().mockReturnThis(), + lean: jest.fn().mockReturnThis(), + deleteMany: jest.fn().mockResolvedValue({ deletedCount: 1 }), +}; + +const mockSchema = { + findOneAndUpdate: jest.fn(), + updateOne: jest.fn(), + findOne: jest.fn(() => ({ + lean: jest.fn(), + })), + find: jest.fn(() => mockFindQuery), + deleteMany: jest.fn(), +}; + +mongoose.model.mockReturnValue(mockSchema); + +jest.mock('~/models/schema/messageSchema', () => mockSchema); + +jest.mock('~/config/winston', () => ({ + error: jest.fn(), +})); + +const { + saveMessage, + getMessages, + updateMessage, + deleteMessages, + updateMessageText, + deleteMessagesSince, +} = require('~/models/Message'); + +describe('Message Operations', () => { + let mockReq; + let mockMessage; + + beforeEach(() => { + jest.clearAllMocks(); + + mockReq = { + user: { id: 'user123' }, + }; + + mockMessage = { + messageId: 'msg123', + conversationId: uuidv4(), + text: 'Hello, world!', + user: 'user123', + }; + + mockSchema.findOneAndUpdate.mockResolvedValue({ + toObject: () => mockMessage, + }); + }); + + describe('saveMessage', () => { + it('should save a message for an authenticated user', async () => { + const result = await saveMessage(mockReq, mockMessage); + expect(result).toEqual(mockMessage); + expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith( + { messageId: 'msg123', user: 'user123' }, + expect.objectContaining({ user: 'user123' }), + expect.any(Object), + ); + }); + + it('should throw an error for unauthenticated user', async () => { + mockReq.user = null; + await expect(saveMessage(mockReq, mockMessage)).rejects.toThrow('User not authenticated'); + }); + + it('should throw an error for invalid conversation ID', async () => { + mockMessage.conversationId = 'invalid-id'; + await expect(saveMessage(mockReq, mockMessage)).rejects.toThrow('Invalid conversation ID'); + }); + }); + + describe('updateMessageText', () => { + it('should update message text for the authenticated user', async () => { + await updateMessageText(mockReq, { messageId: 'msg123', text: 'Updated text' }); + expect(mockSchema.updateOne).toHaveBeenCalledWith( + { messageId: 'msg123', user: 'user123' }, + { text: 'Updated text' }, + ); + }); + }); + + describe('updateMessage', () => { + it('should update a message for the authenticated user', async () => { + mockSchema.findOneAndUpdate.mockResolvedValue(mockMessage); + const result = await updateMessage(mockReq, { messageId: 'msg123', text: 'Updated text' }); + expect(result).toEqual( + expect.objectContaining({ + messageId: 'msg123', + text: 'Hello, world!', + isEdited: true, + }), + ); + }); + + it('should throw an error if message is not found', async () => { + mockSchema.findOneAndUpdate.mockResolvedValue(null); + await expect( + updateMessage(mockReq, { messageId: 'nonexistent', text: 'Test' }), + ).rejects.toThrow('Message not found or user not authorized.'); + }); + }); + + describe('deleteMessagesSince', () => { + it('should delete messages only for the authenticated user', async () => { + mockSchema.findOne().lean.mockResolvedValueOnce({ createdAt: new Date() }); + mockFindQuery.deleteMany.mockResolvedValueOnce({ deletedCount: 1 }); + const result = await deleteMessagesSince(mockReq, { + messageId: 'msg123', + conversationId: 'convo123', + }); + expect(mockSchema.findOne).toHaveBeenCalledWith({ messageId: 'msg123', user: 'user123' }); + expect(mockSchema.find).not.toHaveBeenCalled(); + expect(result).toBeUndefined(); + }); + + it('should return undefined if no message is found', async () => { + mockSchema.findOne().lean.mockResolvedValueOnce(null); + const result = await deleteMessagesSince(mockReq, { + messageId: 'nonexistent', + conversationId: 'convo123', + }); + expect(result).toBeUndefined(); + }); + }); + + describe('getMessages', () => { + it('should retrieve messages with the correct filter', async () => { + const filter = { conversationId: 'convo123' }; + await getMessages(filter); + expect(mockSchema.find).toHaveBeenCalledWith(filter); + expect(mockFindQuery.sort).toHaveBeenCalledWith({ createdAt: 1 }); + expect(mockFindQuery.lean).toHaveBeenCalled(); + }); + }); + + describe('deleteMessages', () => { + it('should delete messages with the correct filter', async () => { + await deleteMessages({ user: 'user123' }); + expect(mockSchema.deleteMany).toHaveBeenCalledWith({ user: 'user123' }); + }); + }); + + describe('Conversation Hijacking Prevention', () => { + it('should not allow editing a message in another user\'s conversation', async () => { + const attackerReq = { user: { id: 'attacker123' } }; + const victimConversationId = 'victim-convo-123'; + const victimMessageId = 'victim-msg-123'; + + mockSchema.findOneAndUpdate.mockResolvedValue(null); + + await expect( + updateMessage(attackerReq, { + messageId: victimMessageId, + conversationId: victimConversationId, + text: 'Hacked message', + }), + ).rejects.toThrow('Message not found or user not authorized.'); + + expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith( + { messageId: victimMessageId, user: 'attacker123' }, + expect.anything(), + expect.anything(), + ); + }); + + it('should not allow deleting messages from another user\'s conversation', async () => { + const attackerReq = { user: { id: 'attacker123' } }; + const victimConversationId = 'victim-convo-123'; + const victimMessageId = 'victim-msg-123'; + + mockSchema.findOne().lean.mockResolvedValueOnce(null); // Simulating message not found for this user + const result = await deleteMessagesSince(attackerReq, { + messageId: victimMessageId, + conversationId: victimConversationId, + }); + + expect(result).toBeUndefined(); + expect(mockSchema.findOne).toHaveBeenCalledWith({ + messageId: victimMessageId, + user: 'attacker123', + }); + }); + + it('should not allow inserting a new message into another user\'s conversation', async () => { + const attackerReq = { user: { id: 'attacker123' } }; + const victimConversationId = uuidv4(); // Use a valid UUID + + await expect( + saveMessage(attackerReq, { + conversationId: victimConversationId, + text: 'Inserted malicious message', + messageId: 'new-msg-123', + }), + ).resolves.not.toThrow(); // It should not throw an error + + // Check that the message was saved with the attacker's user ID + expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith( + { messageId: 'new-msg-123', user: 'attacker123' }, + expect.objectContaining({ + user: 'attacker123', + conversationId: victimConversationId, + }), + expect.anything(), + ); + }); + + it('should allow retrieving messages from any conversation', async () => { + const victimConversationId = 'victim-convo-123'; + + await getMessages({ conversationId: victimConversationId }); + + expect(mockSchema.find).toHaveBeenCalledWith({ + conversationId: victimConversationId, + }); + + mockSchema.find.mockReturnValueOnce({ + select: jest.fn().mockReturnThis(), + sort: jest.fn().mockReturnThis(), + lean: jest.fn().mockResolvedValue([{ text: 'Test message' }]), + }); + + const result = await getMessages({ conversationId: victimConversationId }); + expect(result).toEqual([{ text: 'Test message' }]); + }); + }); +}); diff --git a/api/server/controllers/AskController.js b/api/server/controllers/AskController.js index 2c07398f77..674c22a834 100644 --- a/api/server/controllers/AskController.js +++ b/api/server/controllers/AskController.js @@ -1,7 +1,8 @@ const throttle = require('lodash/throttle'); -const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider'); +const { getResponseSender, Constants, CacheKeys, Time } = require('librechat-data-provider'); const { createAbortController, handleAbortError } = require('~/server/middleware'); const { sendMessage, createOnProgress } = require('~/server/utils'); +const { getLogStores } = require('~/cache'); const { saveMessage } = require('~/models'); const { logger } = require('~/config'); @@ -51,11 +52,13 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { try { const { client } = await initializeClient({ req, res, endpointOption }); - const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true; + const messageCache = getLogStores(CacheKeys.MESSAGES); const { onProgress: progressCallback, getPartialText } = createOnProgress({ onProgress: throttle( ({ text: partialText }) => { - saveMessage({ + /* + const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true; + messageCache.set(responseMessageId, { messageId: responseMessageId, sender, conversationId, @@ -65,7 +68,10 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { unfinished, error: false, user, - }); + }, Time.FIVE_MINUTES); + */ + + messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES); }, 3000, { trailing: false }, @@ -144,11 +150,11 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { }); res.end(); - await saveMessage({ ...response, user }); + await saveMessage(req, { ...response, user }); } if (!client.skipSaveUserMessage) { - await saveMessage(userMessage); + await saveMessage(req, userMessage); } if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) { diff --git a/api/server/controllers/EditController.js b/api/server/controllers/EditController.js index 3315454d38..e8be7f3e7a 100644 --- a/api/server/controllers/EditController.js +++ b/api/server/controllers/EditController.js @@ -1,7 +1,8 @@ const throttle = require('lodash/throttle'); -const { getResponseSender, EModelEndpoint } = require('librechat-data-provider'); +const { getResponseSender, CacheKeys, Time } = require('librechat-data-provider'); const { createAbortController, handleAbortError } = require('~/server/middleware'); const { sendMessage, createOnProgress } = require('~/server/utils'); +const { getLogStores } = require('~/cache'); const { saveMessage } = require('~/models'); const { logger } = require('~/config'); @@ -51,12 +52,14 @@ const EditController = async (req, res, next, initializeClient) => { } }; - const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true; + const messageCache = getLogStores(CacheKeys.MESSAGES); const { onProgress: progressCallback, getPartialText } = createOnProgress({ generation, onProgress: throttle( ({ text: partialText }) => { - saveMessage({ + /* + const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true; + { messageId: responseMessageId, sender, conversationId, @@ -67,7 +70,8 @@ const EditController = async (req, res, next, initializeClient) => { isEdited: true, error: false, user, - }); + } */ + messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES); }, 3000, { trailing: false }, @@ -141,7 +145,7 @@ const EditController = async (req, res, next, initializeClient) => { }); res.end(); - await saveMessage({ ...response, user }); + await saveMessage(req, { ...response, user }); } } catch (error) { const partialText = getPartialText(); diff --git a/api/server/controllers/assistants/chatV1.js b/api/server/controllers/assistants/chatV1.js index 624f013af2..3107e78c79 100644 --- a/api/server/controllers/assistants/chatV1.js +++ b/api/server/controllers/assistants/chatV1.js @@ -120,21 +120,22 @@ const chatV1 = async (req, res) => { ? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.' : '' }`; - return sendResponse(res, messageData, errorMessage); + return sendResponse(req, res, messageData, errorMessage); } else if (error?.message?.includes('string too long')) { return sendResponse( + req, res, messageData, 'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.', ); } else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) { - return sendResponse(res, messageData, error.message); + return sendResponse(req, res, messageData, error.message); } else { logger.error('[/assistants/chat/]', error); } if (!openai || !thread_id || !run_id) { - return sendResponse(res, messageData, defaultErrorMessage); + return sendResponse(req, res, messageData, defaultErrorMessage); } await sleep(2000); @@ -221,10 +222,10 @@ const chatV1 = async (req, res) => { }; } catch (error) { logger.error('[/assistants/chat/] Error finalizing error process', error); - return sendResponse(res, messageData, 'The Assistant run failed'); + return sendResponse(req, res, messageData, 'The Assistant run failed'); } - return sendResponse(res, finalEvent); + return sendResponse(req, res, finalEvent); }; try { diff --git a/api/server/controllers/assistants/chatV2.js b/api/server/controllers/assistants/chatV2.js index 3b73d1520f..67e106ca0d 100644 --- a/api/server/controllers/assistants/chatV2.js +++ b/api/server/controllers/assistants/chatV2.js @@ -1,12 +1,12 @@ const { v4 } = require('uuid'); const { + Time, Constants, RunStatus, CacheKeys, ContentTypes, ToolCallTypes, EModelEndpoint, - ViolationTypes, retrievalMimeTypes, AssistantStreamEvents, } = require('librechat-data-provider'); @@ -14,12 +14,12 @@ const { initThread, recordUsage, saveUserMessage, - checkMessageGaps, addThreadMetadata, saveAssistantMessage, } = require('~/server/services/Threads'); -const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils'); const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService'); +const { sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils'); +const { createErrorHandler } = require('~/server/controllers/assistants/errors'); const validateAuthor = require('~/server/middleware/assistants/validateAuthor'); const { createRun, StreamRunManager } = require('~/server/services/Runs'); const { addTitle } = require('~/server/services/Endpoints/assistants'); @@ -44,7 +44,7 @@ const ten_minutes = 1000 * 60 * 10; const chatV2 = async (req, res) => { logger.debug('[/assistants/chat/] req.body', req.body); - /** @type {{ files: MongoFile[]}} */ + /** @type {{files: MongoFile[]}} */ const { text, model, @@ -90,139 +90,20 @@ const chatV2 = async (req, res) => { /** @type {Run | undefined} - The completed run, undefined if incomplete */ let completedRun; - const handleError = async (error) => { - const defaultErrorMessage = - 'The Assistant run failed to initialize. Try sending a message in a new conversation.'; - const messageData = { - thread_id, - assistant_id, - conversationId, - parentMessageId, - sender: 'System', - user: req.user.id, - shouldSaveMessage: false, - messageId: responseMessageId, - endpoint, - }; + const getContext = () => ({ + openai, + run_id, + endpoint, + cacheKey, + thread_id, + completedRun, + assistant_id, + conversationId, + parentMessageId, + responseMessageId, + }); - if (error.message === 'Run cancelled') { - return res.end(); - } else if (error.message === 'Request closed' && completedRun) { - return; - } else if (error.message === 'Request closed') { - logger.debug('[/assistants/chat/] Request aborted on close'); - } else if (/Files.*are invalid/.test(error.message)) { - const errorMessage = `Files are invalid, or may not have uploaded yet.${ - endpoint === EModelEndpoint.azureAssistants - ? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.' - : '' - }`; - return sendResponse(res, messageData, errorMessage); - } else if (error?.message?.includes('string too long')) { - return sendResponse( - res, - messageData, - 'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.', - ); - } else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) { - return sendResponse(res, messageData, error.message); - } else { - logger.error('[/assistants/chat/]', error); - } - - if (!openai || !thread_id || !run_id) { - return sendResponse(res, messageData, defaultErrorMessage); - } - - await sleep(2000); - - try { - const status = await cache.get(cacheKey); - if (status === 'cancelled') { - logger.debug('[/assistants/chat/] Run already cancelled'); - return res.end(); - } - await cache.delete(cacheKey); - const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id); - logger.debug('[/assistants/chat/] Cancelled run:', cancelledRun); - } catch (error) { - logger.error('[/assistants/chat/] Error cancelling run', error); - } - - await sleep(2000); - - let run; - try { - run = await openai.beta.threads.runs.retrieve(thread_id, run_id); - await recordUsage({ - ...run.usage, - model: run.model, - user: req.user.id, - conversationId, - }); - } catch (error) { - logger.error('[/assistants/chat/] Error fetching or processing run', error); - } - - let finalEvent; - try { - const runMessages = await checkMessageGaps({ - openai, - run_id, - endpoint, - thread_id, - conversationId, - latestMessageId: responseMessageId, - }); - - const errorContentPart = { - text: { - value: - error?.message ?? 'There was an error processing your request. Please try again later.', - }, - type: ContentTypes.ERROR, - }; - - if (!Array.isArray(runMessages[runMessages.length - 1]?.content)) { - runMessages[runMessages.length - 1].content = [errorContentPart]; - } else { - const contentParts = runMessages[runMessages.length - 1].content; - for (let i = 0; i < contentParts.length; i++) { - const currentPart = contentParts[i]; - /** @type {CodeToolCall | RetrievalToolCall | FunctionToolCall | undefined} */ - const toolCall = currentPart?.[ContentTypes.TOOL_CALL]; - if ( - toolCall && - toolCall?.function && - !(toolCall?.function?.output || toolCall?.function?.output?.length) - ) { - contentParts[i] = { - ...currentPart, - [ContentTypes.TOOL_CALL]: { - ...toolCall, - function: { - ...toolCall.function, - output: 'error processing tool', - }, - }, - }; - } - } - runMessages[runMessages.length - 1].content.push(errorContentPart); - } - - finalEvent = { - final: true, - conversation: await getConvo(req.user.id, conversationId), - runMessages, - }; - } catch (error) { - logger.error('[/assistants/chat/] Error finalizing error process', error); - return sendResponse(res, messageData, 'The Assistant run failed'); - } - - return sendResponse(res, finalEvent); - }; + const handleError = createErrorHandler({ req, res, getContext }); try { res.on('close', async () => { @@ -489,6 +370,11 @@ const chatV2 = async (req, res) => { }, }; + /** @type {undefined | TAssistantEndpoint} */ + const config = req.app.locals[endpoint] ?? {}; + /** @type {undefined | TBaseEndpoint} */ + const allConfig = req.app.locals.all; + const streamRunManager = new StreamRunManager({ req, res, @@ -498,6 +384,7 @@ const chatV2 = async (req, res) => { attachedFileIds, parentMessageId: userMessageId, responseMessage: openai.responseMessage, + streamRate: allConfig?.streamRate ?? config.streamRate, // streamOptions: { // }, @@ -510,6 +397,16 @@ const chatV2 = async (req, res) => { response = streamRunManager; response.text = streamRunManager.intermediateText; + + const messageCache = getLogStores(CacheKeys.MESSAGES); + messageCache.set( + responseMessageId, + { + complete: true, + text: response.text, + }, + Time.FIVE_MINUTES, + ); }; await processRun(); diff --git a/api/server/controllers/assistants/errors.js b/api/server/controllers/assistants/errors.js new file mode 100644 index 0000000000..a4b880bf04 --- /dev/null +++ b/api/server/controllers/assistants/errors.js @@ -0,0 +1,193 @@ +// errorHandler.js +const { sendResponse } = require('~/server/utils'); +const { logger } = require('~/config'); +const getLogStores = require('~/cache/getLogStores'); +const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider'); +const { getConvo } = require('~/models/Conversation'); +const { recordUsage, checkMessageGaps } = require('~/server/services/Threads'); + +/** + * @typedef {Object} ErrorHandlerContext + * @property {OpenAIClient} openai - The OpenAI client + * @property {string} thread_id - The thread ID + * @property {string} run_id - The run ID + * @property {boolean} completedRun - Whether the run has completed + * @property {string} assistant_id - The assistant ID + * @property {string} conversationId - The conversation ID + * @property {string} parentMessageId - The parent message ID + * @property {string} responseMessageId - The response message ID + * @property {string} endpoint - The endpoint being used + * @property {string} cacheKey - The cache key for the current request + */ + +/** + * @typedef {Object} ErrorHandlerDependencies + * @property {Express.Request} req - The Express request object + * @property {Express.Response} res - The Express response object + * @property {() => ErrorHandlerContext} getContext - Function to get the current context + * @property {string} [originPath] - The origin path for the error handler + */ + +/** + * Creates an error handler function with the given dependencies + * @param {ErrorHandlerDependencies} dependencies - The dependencies for the error handler + * @returns {(error: Error) => Promise} The error handler function + */ +const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/chat/' }) => { + const cache = getLogStores(CacheKeys.ABORT_KEYS); + + /** + * Handles errors that occur during the chat process + * @param {Error} error - The error that occurred + * @returns {Promise} + */ + return async (error) => { + const { + openai, + run_id, + endpoint, + cacheKey, + thread_id, + completedRun, + assistant_id, + conversationId, + parentMessageId, + responseMessageId, + } = getContext(); + + const defaultErrorMessage = + 'The Assistant run failed to initialize. Try sending a message in a new conversation.'; + const messageData = { + thread_id, + assistant_id, + conversationId, + parentMessageId, + sender: 'System', + user: req.user.id, + shouldSaveMessage: false, + messageId: responseMessageId, + endpoint, + }; + + if (error.message === 'Run cancelled') { + return res.end(); + } else if (error.message === 'Request closed' && completedRun) { + return; + } else if (error.message === 'Request closed') { + logger.debug(`[${originPath}] Request aborted on close`); + } else if (/Files.*are invalid/.test(error.message)) { + const errorMessage = `Files are invalid, or may not have uploaded yet.${ + endpoint === 'azureAssistants' + ? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.' + : '' + }`; + return sendResponse(req, res, messageData, errorMessage); + } else if (error?.message?.includes('string too long')) { + return sendResponse( + req, + res, + messageData, + 'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.', + ); + } else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) { + return sendResponse(req, res, messageData, error.message); + } else { + logger.error(`[${originPath}]`, error); + } + + if (!openai || !thread_id || !run_id) { + return sendResponse(req, res, messageData, defaultErrorMessage); + } + + await new Promise((resolve) => setTimeout(resolve, 2000)); + + try { + const status = await cache.get(cacheKey); + if (status === 'cancelled') { + logger.debug(`[${originPath}] Run already cancelled`); + return res.end(); + } + await cache.delete(cacheKey); + const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id); + logger.debug(`[${originPath}] Cancelled run:`, cancelledRun); + } catch (error) { + logger.error(`[${originPath}] Error cancelling run`, error); + } + + await new Promise((resolve) => setTimeout(resolve, 2000)); + + let run; + try { + run = await openai.beta.threads.runs.retrieve(thread_id, run_id); + await recordUsage({ + ...run.usage, + model: run.model, + user: req.user.id, + conversationId, + }); + } catch (error) { + logger.error(`[${originPath}] Error fetching or processing run`, error); + } + + let finalEvent; + try { + const runMessages = await checkMessageGaps({ + openai, + run_id, + endpoint, + thread_id, + conversationId, + latestMessageId: responseMessageId, + }); + + const errorContentPart = { + text: { + value: + error?.message ?? 'There was an error processing your request. Please try again later.', + }, + type: ContentTypes.ERROR, + }; + + if (!Array.isArray(runMessages[runMessages.length - 1]?.content)) { + runMessages[runMessages.length - 1].content = [errorContentPart]; + } else { + const contentParts = runMessages[runMessages.length - 1].content; + for (let i = 0; i < contentParts.length; i++) { + const currentPart = contentParts[i]; + /** @type {CodeToolCall | RetrievalToolCall | FunctionToolCall | undefined} */ + const toolCall = currentPart?.[ContentTypes.TOOL_CALL]; + if ( + toolCall && + toolCall?.function && + !(toolCall?.function?.output || toolCall?.function?.output?.length) + ) { + contentParts[i] = { + ...currentPart, + [ContentTypes.TOOL_CALL]: { + ...toolCall, + function: { + ...toolCall.function, + output: 'error processing tool', + }, + }, + }; + } + } + runMessages[runMessages.length - 1].content.push(errorContentPart); + } + + finalEvent = { + final: true, + conversation: await getConvo(req.user.id, conversationId), + runMessages, + }; + } catch (error) { + logger.error(`[${originPath}] Error finalizing error process`, error); + return sendResponse(req, res, messageData, 'The Assistant run failed'); + } + + return sendResponse(req, res, finalEvent); + }; +}; + +module.exports = { createErrorHandler }; diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index 6b93bcac28..a8ef269c9f 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -30,7 +30,10 @@ async function abortMessage(req, res) { return res.status(204).send({ message: 'Request not found' }); } const finalEvent = await abortController.abortCompletion(); - logger.info('[abortMessage] Aborted request', { abortKey }); + logger.debug( + `[abortMessage] ID: ${req.user.id} | ${req.user.email} | Aborted request: ` + + JSON.stringify({ abortKey }), + ); abortControllers.delete(abortKey); if (res.headersSent && finalEvent) { @@ -116,7 +119,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => { { promptTokens, completionTokens }, ); - saveMessage({ ...responseMessage, user }); + saveMessage(req, { ...responseMessage, user }); let conversation; if (userMessagePromise) { @@ -190,7 +193,7 @@ const handleAbortError = async (res, req, error, data) => { } }; - await sendError(res, options, callback); + await sendError(req, res, options, callback); }; if (partialText && partialText.length > 5) { diff --git a/api/server/middleware/denyRequest.js b/api/server/middleware/denyRequest.js index 37952176bf..8e89bccee0 100644 --- a/api/server/middleware/denyRequest.js +++ b/api/server/middleware/denyRequest.js @@ -41,10 +41,10 @@ const denyRequest = async (req, res, errorMessage) => { const shouldSaveMessage = _convoId && parentMessageId && parentMessageId !== Constants.NO_PARENT; if (shouldSaveMessage) { - await saveMessage({ ...userMessage, user: req.user.id }); + await saveMessage(req, { ...userMessage, user: req.user.id }); } - return await sendError(res, { + return await sendError(req, res, { sender: getResponseSender(req.body), messageId: crypto.randomUUID(), conversationId, diff --git a/api/server/middleware/validateImageRequest.js b/api/server/middleware/validateImageRequest.js index c0e8e5fe83..e07e48cc71 100644 --- a/api/server/middleware/validateImageRequest.js +++ b/api/server/middleware/validateImageRequest.js @@ -31,10 +31,14 @@ function validateImageRequest(req, res, next) { return res.status(403).send('Access Denied'); } - if (req.path.includes(payload.id)) { + const fullPath = decodeURIComponent(req.originalUrl); + const pathPattern = new RegExp(`^/images/${payload.id}/[^/]+$`); + + if (pathPattern.test(fullPath)) { logger.debug('[validateImageRequest] Image request validated'); next(); } else { + logger.warn('[validateImageRequest] Invalid image path'); res.status(403).send('Access Denied'); } } diff --git a/api/server/routes/ask/askChatGPTBrowser.js b/api/server/routes/ask/askChatGPTBrowser.js index 4ce1770b8e..8b4be397a0 100644 --- a/api/server/routes/ask/askChatGPTBrowser.js +++ b/api/server/routes/ask/askChatGPTBrowser.js @@ -51,7 +51,7 @@ router.post('/', setHeaders, async (req, res) => { }); if (!overrideParentMessageId) { - await saveMessage({ ...userMessage, user: req.user.id }); + await saveMessage(req, { ...userMessage, user: req.user.id }); await saveConvo(req.user.id, { ...userMessage, ...endpointOption, @@ -93,7 +93,7 @@ const ask = async ({ const currentTimestamp = Date.now(); if (currentTimestamp - lastSavedTimestamp > 500) { lastSavedTimestamp = currentTimestamp; - saveMessage({ + saveMessage(req, { messageId: responseMessageId, sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI', conversationId, @@ -159,7 +159,7 @@ const ask = async ({ isCreatedByUser: false, }; - await saveMessage({ ...responseMessage, user }); + await saveMessage(req, { ...responseMessage, user }); responseMessage.messageId = newResponseMessageId; // STEP2 update the conversation @@ -192,7 +192,7 @@ const ask = async ({ // If response has parentMessageId, the fake userMessage.messageId should be updated to the real one. if (!overrideParentMessageId) { - await saveMessage({ + await saveMessage(req, { ...userMessage, user, messageId: userMessageId, @@ -229,7 +229,7 @@ const ask = async ({ isCreatedByUser: false, text: `${getPartialMessage() ?? ''}\n\nError message: "${error.message}"`, }; - await saveMessage({ ...errorMessage, user }); + await saveMessage(req, { ...errorMessage, user }); handleError(res, errorMessage); } }; diff --git a/api/server/routes/ask/bingAI.js b/api/server/routes/ask/bingAI.js index 916cda4b10..b5763c3b3d 100644 --- a/api/server/routes/ask/bingAI.js +++ b/api/server/routes/ask/bingAI.js @@ -70,7 +70,7 @@ router.post('/', setHeaders, async (req, res) => { }); if (!overrideParentMessageId) { - await saveMessage({ ...userMessage, user: req.user.id }); + await saveMessage(req, { ...userMessage, user: req.user.id }); await saveConvo(req.user.id, { ...userMessage, ...endpointOption, @@ -118,7 +118,7 @@ const ask = async ({ const currentTimestamp = Date.now(); if (currentTimestamp - lastSavedTimestamp > 500) { lastSavedTimestamp = currentTimestamp; - saveMessage({ + saveMessage(req, { messageId: responseMessageId, sender: model, conversationId, @@ -197,7 +197,7 @@ const ask = async ({ isCreatedByUser: false, }; - await saveMessage({ ...responseMessage, user }); + await saveMessage(req, { ...responseMessage, user }); responseMessage.messageId = newResponseMessageId; let conversationUpdate = { @@ -221,7 +221,7 @@ const ask = async ({ // If response has parentMessageId, the fake userMessage.messageId should be updated to the real one. if (!overrideParentMessageId) { - await saveMessage({ + await saveMessage(req, { ...userMessage, user, messageId: userMessageId, @@ -266,7 +266,7 @@ const ask = async ({ isCreatedByUser: false, }; - saveMessage({ ...responseMessage, user }); + saveMessage(req, { ...responseMessage, user }); return { title: await getConvoTitle(user, conversationId), @@ -288,7 +288,7 @@ const ask = async ({ model, isCreatedByUser: false, }; - await saveMessage({ ...errorMessage, user }); + await saveMessage(req, { ...errorMessage, user }); handleError(res, errorMessage); } } diff --git a/api/server/routes/ask/gptPlugins.js b/api/server/routes/ask/gptPlugins.js index 1db3e333dc..602ff25086 100644 --- a/api/server/routes/ask/gptPlugins.js +++ b/api/server/routes/ask/gptPlugins.js @@ -1,10 +1,11 @@ const express = require('express'); const throttle = require('lodash/throttle'); -const { getResponseSender, Constants } = require('librechat-data-provider'); +const { getResponseSender, Constants, CacheKeys, Time } = require('librechat-data-provider'); const { initializeClient } = require('~/server/services/Endpoints/gptPlugins'); const { sendMessage, createOnProgress } = require('~/server/utils'); const { addTitle } = require('~/server/services/Endpoints/openAI'); const { saveMessage } = require('~/models'); +const { getLogStores } = require('~/cache'); const { handleAbort, createAbortController, @@ -71,7 +72,8 @@ router.post( } }; - const throttledSaveMessage = throttle(saveMessage, 3000, { trailing: false }); + const messageCache = getLogStores(CacheKeys.MESSAGES); + const throttledSetMessage = throttle(messageCache.set, 3000, { trailing: false }); let streaming = null; let timer = null; @@ -85,7 +87,8 @@ router.post( clearTimeout(timer); } - throttledSaveMessage({ + /* + { messageId: responseMessageId, sender, conversationId, @@ -96,7 +99,9 @@ router.post( error: false, plugins, user, - }); + } + */ + throttledSetMessage(responseMessageId, partialText, Time.FIVE_MINUTES); streaming = new Promise((resolve) => { timer = setTimeout(() => { @@ -170,7 +175,7 @@ router.post( const onChainEnd = () => { if (!client.skipSaveUserMessage) { - saveMessage({ ...userMessage, user }); + saveMessage(req, { ...userMessage, user }); } sendIntermediateMessage(res, { plugins, @@ -208,7 +213,7 @@ router.post( logger.debug('[/ask/gptPlugins]', response); response.plugins = plugins.map((p) => ({ ...p, loading: false })); - await saveMessage({ ...response, user }); + await saveMessage(req, { ...response, user }); const { conversation = {} } = await client.responsePromise; conversation.title = diff --git a/api/server/routes/edit/gptPlugins.js b/api/server/routes/edit/gptPlugins.js index 4db05bd493..926c8e4f5f 100644 --- a/api/server/routes/edit/gptPlugins.js +++ b/api/server/routes/edit/gptPlugins.js @@ -1,19 +1,20 @@ const express = require('express'); const throttle = require('lodash/throttle'); -const { getResponseSender } = require('librechat-data-provider'); +const { getResponseSender, CacheKeys, Time } = require('librechat-data-provider'); const { - handleAbort, - createAbortController, - handleAbortError, setHeaders, + handleAbort, + moderateText, validateModel, + handleAbortError, validateEndpoint, buildEndpointOption, - moderateText, + createAbortController, } = require('~/server/middleware'); const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils'); const { initializeClient } = require('~/server/services/Endpoints/gptPlugins'); const { saveMessage } = require('~/models'); +const { getLogStores } = require('~/cache'); const { validateTools } = require('~/app'); const { logger } = require('~/config'); @@ -79,7 +80,8 @@ router.post( } }; - const throttledSaveMessage = throttle(saveMessage, 3000, { trailing: false }); + const messageCache = getLogStores(CacheKeys.MESSAGES); + const throttledSetMessage = throttle(messageCache.set, 3000, { trailing: false }); const { onProgress: progressCallback, sendIntermediateMessage, @@ -91,7 +93,8 @@ router.post( plugin.loading = false; } - throttledSaveMessage({ + /* + { messageId: responseMessageId, sender, conversationId, @@ -102,7 +105,9 @@ router.post( isEdited: true, error: false, user, - }); + } + */ + throttledSetMessage(responseMessageId, partialText, Time.FIVE_MINUTES); }, }); @@ -110,7 +115,7 @@ router.post( let { intermediateSteps: steps } = data; plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.'; plugin.loading = false; - saveMessage({ ...userMessage, user }); + saveMessage(req, { ...userMessage, user }); sendIntermediateMessage(res, { plugin, parentMessageId: userMessage.messageId, @@ -141,7 +146,7 @@ router.post( plugin.inputs.push(formattedAction); plugin.latest = formattedAction.plugin; if (!start && !client.skipSaveUserMessage) { - saveMessage({ ...userMessage, user }); + saveMessage(req, { ...userMessage, user }); } sendIntermediateMessage(res, { plugin, @@ -180,7 +185,7 @@ router.post( logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response); response.plugin = { ...plugin, loading: false }; - await saveMessage({ ...response, user }); + await saveMessage(req, { ...response, user }); const { conversation = {} } = await client.responsePromise; conversation.title = diff --git a/api/server/routes/messages.js b/api/server/routes/messages.js index e0bdadfc50..f06a19b483 100644 --- a/api/server/routes/messages.js +++ b/api/server/routes/messages.js @@ -1,17 +1,12 @@ const express = require('express'); const router = express.Router(); -const { - getMessages, - updateMessage, - saveConvo, - saveMessage, - deleteMessages, -} = require('../../models'); -const { countTokens } = require('../utils'); -const { requireJwtAuth, validateMessageReq } = require('../middleware/'); +const { saveConvo, saveMessage, getMessages, updateMessage, deleteMessages } = require('~/models'); +const { requireJwtAuth, validateMessageReq } = require('~/server/middleware'); +const { countTokens } = require('~/server/utils'); router.use(requireJwtAuth); +/* Note: It's necessary to add `validateMessageReq` within route definition for correct params */ router.get('/:conversationId', validateMessageReq, async (req, res) => { const { conversationId } = req.params; res.status(200).send(await getMessages({ conversationId }, '-_id -__v -user')); @@ -20,7 +15,7 @@ router.get('/:conversationId', validateMessageReq, async (req, res) => { // CREATE router.post('/:conversationId', validateMessageReq, async (req, res) => { const message = req.body; - const savedMessage = await saveMessage({ ...message, user: req.user.id }); + const savedMessage = await saveMessage(req, { ...message, user: req.user.id }); await saveConvo(req.user.id, savedMessage); res.status(201).send(savedMessage); }); @@ -36,7 +31,8 @@ router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) = const { messageId, model } = req.params; const { text } = req.body; const tokenCount = await countTokens(text, model); - res.status(201).json(await updateMessage({ messageId, text, tokenCount })); + const result = await updateMessage(req, { messageId, text, tokenCount }); + res.status(201).json(result); }); // DELETE diff --git a/api/server/services/AppService.js b/api/server/services/AppService.js index e416d5f6e7..d776aa63b7 100644 --- a/api/server/services/AppService.js +++ b/api/server/services/AppService.js @@ -67,17 +67,18 @@ const AppService = async (app) => { handleRateLimits(config?.rateLimits); const endpointLocals = {}; + const endpoints = config?.endpoints; - if (config?.endpoints?.[EModelEndpoint.azureOpenAI]) { + if (endpoints?.[EModelEndpoint.azureOpenAI]) { endpointLocals[EModelEndpoint.azureOpenAI] = azureConfigSetup(config); checkAzureVariables(); } - if (config?.endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) { + if (endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) { endpointLocals[EModelEndpoint.azureAssistants] = azureAssistantsDefaults(); } - if (config?.endpoints?.[EModelEndpoint.azureAssistants]) { + if (endpoints?.[EModelEndpoint.azureAssistants]) { endpointLocals[EModelEndpoint.azureAssistants] = assistantsConfigSetup( config, EModelEndpoint.azureAssistants, @@ -85,7 +86,7 @@ const AppService = async (app) => { ); } - if (config?.endpoints?.[EModelEndpoint.assistants]) { + if (endpoints?.[EModelEndpoint.assistants]) { endpointLocals[EModelEndpoint.assistants] = assistantsConfigSetup( config, EModelEndpoint.assistants, @@ -93,6 +94,19 @@ const AppService = async (app) => { ); } + if (endpoints?.[EModelEndpoint.openAI]) { + endpointLocals[EModelEndpoint.openAI] = endpoints[EModelEndpoint.openAI]; + } + if (endpoints?.[EModelEndpoint.google]) { + endpointLocals[EModelEndpoint.google] = endpoints[EModelEndpoint.google]; + } + if (endpoints?.[EModelEndpoint.anthropic]) { + endpointLocals[EModelEndpoint.anthropic] = endpoints[EModelEndpoint.anthropic]; + } + if (endpoints?.[EModelEndpoint.gptPlugins]) { + endpointLocals[EModelEndpoint.gptPlugins] = endpoints[EModelEndpoint.gptPlugins]; + } + app.locals = { ...defaultLocals, modelSpecs: config.modelSpecs, diff --git a/api/server/services/Endpoints/anthropic/initializeClient.js b/api/server/services/Endpoints/anthropic/initializeClient.js index c5d6696b3e..42b902b1fc 100644 --- a/api/server/services/Endpoints/anthropic/initializeClient.js +++ b/api/server/services/Endpoints/anthropic/initializeClient.js @@ -19,11 +19,27 @@ const initializeClient = async ({ req, res, endpointOption }) => { checkUserKeyExpiry(expiresAt, EModelEndpoint.anthropic); } + const clientOptions = {}; + + /** @type {undefined | TBaseEndpoint} */ + const anthropicConfig = req.app.locals[EModelEndpoint.anthropic]; + + if (anthropicConfig) { + clientOptions.streamRate = anthropicConfig.streamRate; + } + + /** @type {undefined | TBaseEndpoint} */ + const allConfig = req.app.locals.all; + if (allConfig) { + clientOptions.streamRate = allConfig.streamRate; + } + const client = new AnthropicClient(anthropicApiKey, { req, res, reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null, proxy: PROXY ?? null, + ...clientOptions, ...endpointOption, }); diff --git a/api/server/services/Endpoints/custom/initializeClient.js b/api/server/services/Endpoints/custom/initializeClient.js index 9fb6bfd1af..dbc7a769fb 100644 --- a/api/server/services/Endpoints/custom/initializeClient.js +++ b/api/server/services/Endpoints/custom/initializeClient.js @@ -114,9 +114,16 @@ const initializeClient = async ({ req, res, endpointOption }) => { contextStrategy: endpointConfig.summarize ? 'summarize' : null, directEndpoint: endpointConfig.directEndpoint, titleMessageRole: endpointConfig.titleMessageRole, + streamRate: endpointConfig.streamRate, endpointTokenConfig, }; + /** @type {undefined | TBaseEndpoint} */ + const allConfig = req.app.locals.all; + if (allConfig) { + customOptions.streamRate = allConfig.streamRate; + } + const clientOptions = { reverseProxyUrl: baseURL ?? null, proxy: PROXY ?? null, diff --git a/api/server/services/Endpoints/google/initializeClient.js b/api/server/services/Endpoints/google/initializeClient.js index d2099edcf5..788375e1e7 100644 --- a/api/server/services/Endpoints/google/initializeClient.js +++ b/api/server/services/Endpoints/google/initializeClient.js @@ -27,11 +27,27 @@ const initializeClient = async ({ req, res, endpointOption }) => { [AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY, }; + const clientOptions = {}; + + /** @type {undefined | TBaseEndpoint} */ + const allConfig = req.app.locals.all; + /** @type {undefined | TBaseEndpoint} */ + const googleConfig = req.app.locals[EModelEndpoint.google]; + + if (googleConfig) { + clientOptions.streamRate = googleConfig.streamRate; + } + + if (allConfig) { + clientOptions.streamRate = allConfig.streamRate; + } + const client = new GoogleClient(credentials, { req, res, reverseProxyUrl: GOOGLE_REVERSE_PROXY ?? null, proxy: PROXY ?? null, + ...clientOptions, ...endpointOption, }); diff --git a/api/server/services/Endpoints/google/initializeClient.spec.js b/api/server/services/Endpoints/google/initializeClient.spec.js index b46a535618..657dcbcaa8 100644 --- a/api/server/services/Endpoints/google/initializeClient.spec.js +++ b/api/server/services/Endpoints/google/initializeClient.spec.js @@ -8,6 +8,8 @@ jest.mock('~/server/services/UserService', () => ({ getUserKey: jest.fn().mockImplementation(() => ({})), })); +const app = { locals: {} }; + describe('google/initializeClient', () => { afterEach(() => { jest.clearAllMocks(); @@ -23,6 +25,7 @@ describe('google/initializeClient', () => { const req = { body: { key: expiresAt }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = { modelOptions: { model: 'default-model' } }; @@ -44,6 +47,7 @@ describe('google/initializeClient', () => { const req = { body: { key: null }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = { modelOptions: { model: 'default-model' } }; @@ -66,6 +70,7 @@ describe('google/initializeClient', () => { const req = { body: { key: expiresAt }, user: { id: '123' }, + app, }; const res = {}; const endpointOption = { modelOptions: { model: 'default-model' } }; diff --git a/api/server/services/Endpoints/gptPlugins/initializeClient.js b/api/server/services/Endpoints/gptPlugins/initializeClient.js index 312b23eb67..7e79d42564 100644 --- a/api/server/services/Endpoints/gptPlugins/initializeClient.js +++ b/api/server/services/Endpoints/gptPlugins/initializeClient.js @@ -86,6 +86,9 @@ const initializeClient = async ({ req, res, endpointOption }) => { clientOptions.titleModel = azureConfig.titleModel; clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion'; + const azureRate = modelName.includes('gpt-4') ? 30 : 17; + clientOptions.streamRate = azureConfig.streamRate ?? azureRate; + const groupName = modelGroupMap[modelName].group; clientOptions.addParams = azureConfig.groupMap[groupName].addParams; clientOptions.dropParams = azureConfig.groupMap[groupName].dropParams; @@ -98,6 +101,19 @@ const initializeClient = async ({ req, res, endpointOption }) => { apiKey = clientOptions.azure.azureOpenAIApiKey; } + /** @type {undefined | TBaseEndpoint} */ + const pluginsConfig = req.app.locals[EModelEndpoint.gptPlugins]; + + if (!useAzure && pluginsConfig) { + clientOptions.streamRate = pluginsConfig.streamRate; + } + + /** @type {undefined | TBaseEndpoint} */ + const allConfig = req.app.locals.all; + if (allConfig) { + clientOptions.streamRate = allConfig.streamRate; + } + if (!apiKey) { throw new Error(`${endpoint} API key not provided. Please provide it again.`); } diff --git a/api/server/services/Endpoints/openAI/initializeClient.js b/api/server/services/Endpoints/openAI/initializeClient.js index 9a3a5c4189..1518cba028 100644 --- a/api/server/services/Endpoints/openAI/initializeClient.js +++ b/api/server/services/Endpoints/openAI/initializeClient.js @@ -76,6 +76,10 @@ const initializeClient = async ({ req, res, endpointOption }) => { clientOptions.titleConvo = azureConfig.titleConvo; clientOptions.titleModel = azureConfig.titleModel; + + const azureRate = modelName.includes('gpt-4') ? 30 : 17; + clientOptions.streamRate = azureConfig.streamRate ?? azureRate; + clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion'; const groupName = modelGroupMap[modelName].group; @@ -90,6 +94,19 @@ const initializeClient = async ({ req, res, endpointOption }) => { apiKey = clientOptions.azure.azureOpenAIApiKey; } + /** @type {undefined | TBaseEndpoint} */ + const openAIConfig = req.app.locals[EModelEndpoint.openAI]; + + if (!isAzureOpenAI && openAIConfig) { + clientOptions.streamRate = openAIConfig.streamRate; + } + + /** @type {undefined | TBaseEndpoint} */ + const allConfig = req.app.locals.all; + if (allConfig) { + clientOptions.streamRate = allConfig.streamRate; + } + if (userProvidesKey & !apiKey) { throw new Error( JSON.stringify({ diff --git a/api/server/services/Files/Audio/getCustomConfigSpeech.js b/api/server/services/Files/Audio/getCustomConfigSpeech.js index eca49f711b..d22a143574 100644 --- a/api/server/services/Files/Audio/getCustomConfigSpeech.js +++ b/api/server/services/Files/Audio/getCustomConfigSpeech.js @@ -15,37 +15,43 @@ const getCustomConfig = require('~/server/services/Config/getCustomConfig'); async function getCustomConfigSpeech(req, res) { try { const customConfig = await getCustomConfig(); + const sttExternal = !!customConfig.speech?.stt; + const ttsExternal = !!customConfig.speech?.tts; + let settings = { + sttExternal, + ttsExternal, + }; if (!customConfig || !customConfig.speech?.speechTab) { - throw new Error('Configuration or speechTab schema is missing'); + return res.status(200).send(settings); } - const ttsSchema = customConfig.speech?.speechTab; - let settings = {}; + const speechTab = customConfig.speech.speechTab; - if (ttsSchema.advancedMode !== undefined) { - settings.advancedMode = ttsSchema.advancedMode; + if (speechTab.advancedMode !== undefined) { + settings.advancedMode = speechTab.advancedMode; } - if (ttsSchema.speechToText) { - for (const key in ttsSchema.speechToText) { - if (ttsSchema.speechToText[key] !== undefined) { - settings[key] = ttsSchema.speechToText[key]; + if (speechTab.speechToText) { + for (const key in speechTab.speechToText) { + if (speechTab.speechToText[key] !== undefined) { + settings[key] = speechTab.speechToText[key]; } } } - if (ttsSchema.textToSpeech) { - for (const key in ttsSchema.textToSpeech) { - if (ttsSchema.textToSpeech[key] !== undefined) { - settings[key] = ttsSchema.textToSpeech[key]; + if (speechTab.textToSpeech) { + for (const key in speechTab.textToSpeech) { + if (speechTab.textToSpeech[key] !== undefined) { + settings[key] = speechTab.textToSpeech[key]; } } } return res.status(200).send(settings); } catch (error) { - res.status(200).send(); + console.error('Failed to get custom config speech settings:', error); + res.status(500).send('Internal Server Error'); } } diff --git a/api/server/services/Files/Audio/streamAudio.js b/api/server/services/Files/Audio/streamAudio.js index 9f301e710b..eb8134e958 100644 --- a/api/server/services/Files/Audio/streamAudio.js +++ b/api/server/services/Files/Audio/streamAudio.js @@ -1,5 +1,6 @@ const WebSocket = require('ws'); -const { Message } = require('~/models/Message'); +const { CacheKeys } = require('librechat-data-provider'); +const { getLogStores } = require('~/cache'); /** * @param {string[]} voiceIds - Array of voice IDs @@ -104,6 +105,8 @@ function createChunkProcessor(messageId) { throw new Error('Message ID is required'); } + const messageCache = getLogStores(CacheKeys.MESSAGES); + /** * @returns {Promise<{ text: string, isFinished: boolean }[] | string>} */ @@ -116,14 +119,17 @@ function createChunkProcessor(messageId) { return `No change in message after ${MAX_NO_CHANGE_COUNT} attempts`; } - const message = await Message.findOne({ messageId }, 'text unfinished').lean(); + /** @type { string | { text: string; complete: boolean } } */ + const message = await messageCache.get(messageId); - if (!message || !message.text) { + if (!message) { notFoundCount++; return []; } - const { text, unfinished } = message; + const text = typeof message === 'string' ? message : message.text; + const complete = typeof message === 'string' ? false : message.complete; + if (text === processedText) { noChangeCount++; } @@ -131,7 +137,7 @@ function createChunkProcessor(messageId) { const remainingText = text.slice(processedText.length); const chunks = []; - if (unfinished && remainingText.length >= 20) { + if (!complete && remainingText.length >= 20) { const separatorIndex = findLastSeparatorIndex(remainingText); if (separatorIndex !== -1) { const chunkText = remainingText.slice(0, separatorIndex + 1); @@ -141,7 +147,7 @@ function createChunkProcessor(messageId) { chunks.push({ text: remainingText, isFinished: false }); processedText = text; } - } else if (!unfinished && remainingText.trim().length > 0) { + } else if (complete && remainingText.trim().length > 0) { chunks.push({ text: remainingText.trim(), isFinished: true }); processedText = text; } diff --git a/api/server/services/Files/Audio/streamAudio.spec.js b/api/server/services/Files/Audio/streamAudio.spec.js index 7aff8dbfa7..501e252c14 100644 --- a/api/server/services/Files/Audio/streamAudio.spec.js +++ b/api/server/services/Files/Audio/streamAudio.spec.js @@ -1,89 +1,145 @@ const { createChunkProcessor, splitTextIntoChunks } = require('./streamAudio'); -const { Message } = require('~/models/Message'); -jest.mock('~/models/Message', () => ({ - Message: { - findOne: jest.fn().mockReturnValue({ - lean: jest.fn(), - }), - }, -})); +jest.mock('keyv'); + +const globalCache = {}; +jest.mock('~/cache/getLogStores', () => { + return jest.fn().mockImplementation(() => { + const EventEmitter = require('events'); + const { CacheKeys } = require('librechat-data-provider'); + + class KeyvMongo extends EventEmitter { + constructor(url = 'mongodb://127.0.0.1:27017', options) { + super(); + this.ttlSupport = false; + url = url ?? {}; + if (typeof url === 'string') { + url = { url }; + } + if (url.uri) { + url = { url: url.uri, ...url }; + } + this.opts = { + url, + collection: 'keyv', + ...url, + ...options, + }; + } + + get = async (key) => { + return new Promise((resolve) => { + resolve(globalCache[key] || null); + }); + }; + + set = async (key, value) => { + return new Promise((resolve) => { + globalCache[key] = value; + resolve(true); + }); + }; + } + + return new KeyvMongo('', { + namespace: CacheKeys.MESSAGES, + ttl: 0, + }); + }); +}); describe('processChunks', () => { let processChunks; + let mockMessageCache; beforeEach(() => { + jest.resetAllMocks(); + mockMessageCache = { + get: jest.fn(), + }; + require('~/cache/getLogStores').mockReturnValue(mockMessageCache); processChunks = createChunkProcessor('message-id'); - Message.findOne.mockClear(); - Message.findOne().lean.mockClear(); }); it('should return an empty array when the message is not found', async () => { - Message.findOne().lean.mockResolvedValueOnce(null); + mockMessageCache.get.mockResolvedValueOnce(null); const result = await processChunks(); expect(result).toEqual([]); - expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished'); - expect(Message.findOne().lean).toHaveBeenCalled(); + expect(mockMessageCache.get).toHaveBeenCalledWith('message-id'); }); - it('should return an empty array when the message does not have a text property', async () => { - Message.findOne().lean.mockResolvedValueOnce({ unfinished: true }); + it('should return an error message after MAX_NOT_FOUND_COUNT attempts', async () => { + mockMessageCache.get.mockResolvedValue(null); + for (let i = 0; i < 6; i++) { + await processChunks(); + } const result = await processChunks(); - expect(result).toEqual([]); - expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished'); - expect(Message.findOne().lean).toHaveBeenCalled(); + expect(result).toBe('Message not found after 6 attempts'); }); - it('should return chunks for an unfinished message with separators', async () => { + it('should return chunks for an incomplete message with separators', async () => { const messageText = 'This is a long message. It should be split into chunks. Lol hi mom'; - Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: true }); + mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: false }); const result = await processChunks(); expect(result).toEqual([ { text: 'This is a long message. It should be split into chunks.', isFinished: false }, ]); - expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished'); - expect(Message.findOne().lean).toHaveBeenCalled(); }); - it('should return chunks for an unfinished message without separators', async () => { + it('should return chunks for an incomplete message without separators', async () => { const messageText = 'This is a long message without separators hello there my friend'; - Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: true }); + mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: false }); const result = await processChunks(); expect(result).toEqual([{ text: messageText, isFinished: false }]); - expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished'); - expect(Message.findOne().lean).toHaveBeenCalled(); }); - it('should return the remaining text as a chunk for a finished message', async () => { + it('should return the remaining text as a chunk for a complete message', async () => { const messageText = 'This is a finished message.'; - Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false }); + mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true }); const result = await processChunks(); expect(result).toEqual([{ text: messageText, isFinished: true }]); - expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished'); - expect(Message.findOne().lean).toHaveBeenCalled(); }); - it('should return an empty array for a finished message with no remaining text', async () => { + it('should return an empty array for a complete message with no remaining text', async () => { const messageText = 'This is a finished message.'; - Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false }); + mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true }); await processChunks(); - Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false }); + mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true }); const result = await processChunks(); expect(result).toEqual([]); - expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished'); - expect(Message.findOne().lean).toHaveBeenCalledTimes(2); + }); + + it('should return an error message after MAX_NO_CHANGE_COUNT attempts with no change', async () => { + const messageText = 'This is a message that does not change.'; + mockMessageCache.get.mockResolvedValue({ text: messageText, complete: false }); + + for (let i = 0; i < 11; i++) { + await processChunks(); + } + const result = await processChunks(); + + expect(result).toBe('No change in message after 10 attempts'); + }); + + it('should handle string messages as incomplete', async () => { + const messageText = 'This is a message as a string.'; + mockMessageCache.get.mockResolvedValueOnce(messageText); + + const result = await processChunks(); + + expect(result).toEqual([{ text: messageText, isFinished: false }]); }); }); diff --git a/api/server/services/Runs/StreamRunManager.js b/api/server/services/Runs/StreamRunManager.js index 01c97c0f79..951818bb6f 100644 --- a/api/server/services/Runs/StreamRunManager.js +++ b/api/server/services/Runs/StreamRunManager.js @@ -1,17 +1,19 @@ const throttle = require('lodash/throttle'); const { + Time, + CacheKeys, StepTypes, ContentTypes, ToolCallTypes, - // StepStatus, MessageContentTypes, AssistantStreamEvents, + Constants, } = require('librechat-data-provider'); const { retrieveAndProcessFile } = require('~/server/services/Files/process'); const { processRequiredActions } = require('~/server/services/ToolService'); -const { saveMessage, updateMessageText } = require('~/models/Message'); -const { createOnProgress, sendMessage } = require('~/server/utils'); +const { createOnProgress, sendMessage, sleep } = require('~/server/utils'); const { processMessages } = require('~/server/services/Threads'); +const { getLogStores } = require('~/cache'); const { logger } = require('~/config'); /** @@ -68,8 +70,8 @@ class StreamRunManager { this.attachedFileIds = fields.attachedFileIds; /** @type {undefined | Promise} */ this.visionPromise = fields.visionPromise; - /** @type {boolean} */ - this.savedInitialMessage = false; + /** @type {number} */ + this.streamRate = fields.streamRate ?? Constants.DEFAULT_STREAM_RATE; /** * @type {Object. Promise>} @@ -139,11 +141,11 @@ class StreamRunManager { return this.intermediateText; } - /** Saves the initial intermediate message - * @returns {Promise} + /** Returns the current, intermediate message + * @returns {TMessage} */ - async saveInitialMessage() { - return saveMessage({ + getIntermediateMessage() { + return { conversationId: this.finalMessage.conversationId, messageId: this.finalMessage.messageId, parentMessageId: this.parentMessageId, @@ -155,7 +157,7 @@ class StreamRunManager { sender: 'Assistant', unfinished: true, error: false, - }); + }; } /* <------------------ Main Event Handlers ------------------> */ @@ -347,6 +349,8 @@ class StreamRunManager { type: ContentTypes.TOOL_CALL, index, }); + + await sleep(this.streamRate); } }; @@ -444,6 +448,7 @@ class StreamRunManager { if (content && content.type === MessageContentTypes.TEXT) { this.intermediateText += content.text.value; onProgress(content.text.value); + await sleep(this.streamRate); } } @@ -589,21 +594,14 @@ class StreamRunManager { const index = this.getStepIndex(stepKey); this.orderedRunSteps.set(index, message_creation); + const messageCache = getLogStores(CacheKeys.MESSAGES); // Create the Factory Function to stream the message const { onProgress: progressCallback } = createOnProgress({ onProgress: throttle( () => { - if (!this.savedInitialMessage) { - this.saveInitialMessage(); - this.savedInitialMessage = true; - } else { - updateMessageText({ - messageId: this.finalMessage.messageId, - text: this.getText(), - }); - } + messageCache.set(this.finalMessage.messageId, this.getText(), Time.FIVE_MINUTES); }, - 2000, + 3000, { trailing: false }, ), }); diff --git a/api/server/services/start/assistants.js b/api/server/services/start/assistants.js index ab96db8701..b46edc676b 100644 --- a/api/server/services/start/assistants.js +++ b/api/server/services/start/assistants.js @@ -51,6 +51,7 @@ function assistantsConfigSetup(config, assistantsEndpoint, prevConfig = {}) { excludedIds: parsedConfig.excludedIds, privateAssistants: parsedConfig.privateAssistants, timeoutMs: parsedConfig.timeoutMs, + streamRate: parsedConfig.streamRate, }; } diff --git a/api/server/utils/streamResponse.js b/api/server/utils/streamResponse.js index b7a691d91a..0f042339a9 100644 --- a/api/server/utils/streamResponse.js +++ b/api/server/utils/streamResponse.js @@ -30,7 +30,8 @@ const sendMessage = (res, message, event = 'message') => { /** * Processes an error with provided options, saves the error message and sends a corresponding SSE response * @async - * @param {object} res - The server response. + * @param {object} req - The request. + * @param {object} res - The response. * @param {object} options - The options for handling the error containing message properties. * @param {object} options.user - The user ID. * @param {string} options.sender - The sender of the message. @@ -41,7 +42,7 @@ const sendMessage = (res, message, event = 'message') => { * @param {boolean} options.shouldSaveMessage - [Optional] Whether the message should be saved. Default is true. * @param {function} callback - [Optional] The callback function to be executed. */ -const sendError = async (res, options, callback) => { +const sendError = async (req, res, options, callback) => { const { user, sender, @@ -69,7 +70,7 @@ const sendError = async (res, options, callback) => { } if (shouldSaveMessage) { - await saveMessage({ ...errorMessage, user }); + await saveMessage(req, { ...errorMessage, user }); } if (!errorMessage.error) { @@ -97,11 +98,12 @@ const sendError = async (res, options, callback) => { /** * Sends the response based on whether headers have been sent or not. + * @param {Express.Request} req - The server response. * @param {Express.Response} res - The server response. * @param {Object} data - The data to be sent. * @param {string} [errorMessage] - The error message, if any. */ -const sendResponse = (res, data, errorMessage) => { +const sendResponse = (req, res, data, errorMessage) => { if (!res.headersSent) { if (errorMessage) { return res.status(500).json({ error: errorMessage }); @@ -110,7 +112,7 @@ const sendResponse = (res, data, errorMessage) => { } if (errorMessage) { - return sendError(res, { ...data, text: errorMessage }); + return sendError(req, res, { ...data, text: errorMessage }); } return sendMessage(res, data); }; diff --git a/api/typedefs.js b/api/typedefs.js index ecf78c1374..c8f46c6d9b 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -465,6 +465,12 @@ * @memberof typedefs */ +/** + * @exports TBaseEndpoint + * @typedef {import('librechat-data-provider').TBaseEndpoint} TBaseEndpoint + * @memberof typedefs + */ + /** * @exports TEndpoint * @typedef {import('librechat-data-provider').TEndpoint} TEndpoint diff --git a/client/src/components/Chat/Input/AudioRecorder.tsx b/client/src/components/Chat/Input/AudioRecorder.tsx index dd088ea3c8..1fbb3cc61b 100644 --- a/client/src/components/Chat/Input/AudioRecorder.tsx +++ b/client/src/components/Chat/Input/AudioRecorder.tsx @@ -4,16 +4,19 @@ import { ListeningIcon, Spinner } from '~/components/svg'; import { useLocalize, useSpeechToText } from '~/hooks'; import { useChatFormContext } from '~/Providers'; import { globalAudioId } from '~/common'; +import { cn } from '~/utils'; export default function AudioRecorder({ textAreaRef, methods, ask, + isRTL, disabled, }: { textAreaRef: React.RefObject; methods: ReturnType; ask: (data: { text: string }) => void; + isRTL: boolean; disabled: boolean; }) { const localize = useLocalize(); @@ -77,7 +80,12 @@ export default function AudioRecorder({ - - - {localize('com_nav_send_message')} - - - - ); - }), + forwardRef( + (props: { disabled: boolean; isRTL: boolean }, ref: React.ForwardedRef) => { + const localize = useLocalize(); + return ( + + + + + + + {localize('com_nav_send_message')} + + + + ); + }, + ), ); const SendButton = React.memo( forwardRef((props: SendButtonProps, ref: React.ForwardedRef) => { const data = useWatch({ control: props.control }); - return ; + return ; }), ); diff --git a/client/src/components/Chat/Input/StopButton.tsx b/client/src/components/Chat/Input/StopButton.tsx index 125ca1ea25..28ac9bbff5 100644 --- a/client/src/components/Chat/Input/StopButton.tsx +++ b/client/src/components/Chat/Input/StopButton.tsx @@ -1,6 +1,13 @@ -export default function StopButton({ stop, setShowStopButton }) { +import { cn } from '~/utils'; + +export default function StopButton({ stop, setShowStopButton, isRTL }) { return ( -
+