⚙️ feat: Adjust Rate of Stream Progress (#3244)

* chore: bump data-provider and add MESSAGES CacheKey

* refactor: avoid saving messages while streaming, save partial text to cache instead

* fix(ci): processChunks

* chore: logging aborted request to debug

* feat: set stream rate for token processing

* chore: specify default stream rate

* fix(ci): Update AppService.js to use optional chaining for endpointLocals assignment

* refactor: abstract the error handler

* feat: streamRate for assistants; refactor: update default rate for token

* refactor: update error handling in assistants/errors.js

* refactor: update error handling in assistants/errors.js
Danny Avila 2024-07-17 10:47:17 -04:00 committed by GitHub
parent 1c282d1517
commit 5d40d0a37a
29 changed files with 661 additions and 309 deletions
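
The thrust of the change set below: every streaming client now awaits a configurable delay between chunks instead of relying on hard-coded, endpoint-specific delays, and partial response text is written to a short-lived cache instead of being saved to the database mid-stream. A minimal sketch of the pacing pattern; the helper name `streamWithRate` and the placeholder default are illustrative, not part of the PR:

```js
// Illustrative only: mirrors the `await sleep(streamRate)` calls added to the
// Anthropic, Google, Ollama, OpenAI, and assistants stream loops below.
const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));

// The real clients fall back to Constants.DEFAULT_STREAM_RATE; its concrete
// value lives in librechat-data-provider and is not shown in this diff.
async function streamWithRate(stream, onToken, streamRate = 1 /* ms, assumed placeholder */) {
  for await (const chunk of stream) {
    const token = chunk?.choices?.[0]?.delta?.content ?? '';
    onToken(token);
    // Pausing between chunks is what slows (or speeds up) visible stream progress.
    await sleep(streamRate);
  }
}
```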

View file

@@ -2,8 +2,9 @@ const Anthropic = require('@anthropic-ai/sdk');
 const { HttpsProxyAgent } = require('https-proxy-agent');
 const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const {
-  getResponseSender,
+  Constants,
   EModelEndpoint,
+  getResponseSender,
   validateVisionModel,
 } = require('librechat-data-provider');
 const { encodeAndFormat } = require('~/server/services/Files/images/encode');
@@ -16,6 +17,7 @@ const {
 } = require('./prompts');
 const spendTokens = require('~/models/spendTokens');
 const { getModelMaxTokens } = require('~/utils');
+const { sleep } = require('~/server/utils');
 const BaseClient = require('./BaseClient');
 const { logger } = require('~/config');
@@ -605,6 +607,7 @@ class AnthropicClient extends BaseClient {
     };
     const maxRetries = 3;
+    const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
     async function processResponse() {
       let attempts = 0;
@@ -627,6 +630,8 @@ class AnthropicClient extends BaseClient {
         } else if (completion.completion) {
           handleChunk(completion.completion);
         }
+
+        await sleep(streamRate);
       }
       // Successful processing, exit loop

View file

@@ -1,10 +1,11 @@
 const crypto = require('crypto');
 const fetch = require('node-fetch');
-const { supportsBalanceCheck, Constants } = require('librechat-data-provider');
+const { supportsBalanceCheck, Constants, CacheKeys, Time } = require('librechat-data-provider');
 const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
 const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
 const checkBalance = require('~/models/checkBalance');
 const { getFiles } = require('~/models/File');
+const { getLogStores } = require('~/cache');
 const TextStream = require('./TextStream');
 const { logger } = require('~/config');
@@ -540,6 +541,15 @@ class BaseClient {
       await this.recordTokenUsage({ promptTokens, completionTokens });
     }
     this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
+    const messageCache = getLogStores(CacheKeys.MESSAGES);
+    messageCache.set(
+      responseMessageId,
+      {
+        text: responseMessage.text,
+        complete: true,
+      },
+      Time.FIVE_MINUTES,
+    );
     delete responseMessage.tokenCount;
     return responseMessage;
   }

View file

@@ -13,10 +13,12 @@ const {
   endpointSettings,
   EModelEndpoint,
   VisionModes,
+  Constants,
   AuthKeys,
 } = require('librechat-data-provider');
 const { encodeAndFormat } = require('~/server/services/Files/images');
 const { getModelMaxTokens } = require('~/utils');
+const { sleep } = require('~/server/utils');
 const { logger } = require('~/config');
 const {
   formatMessage,
@@ -620,8 +622,9 @@ class GoogleClient extends BaseClient {
   }

   async getCompletion(_payload, options = {}) {
-    const { onProgress, abortController } = options;
     const { parameters, instances } = _payload;
+    const { onProgress, abortController } = options;
+    const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
     const { messages: _messages, context, examples: _examples } = instances?.[0] ?? {};

     let examples;
@@ -701,6 +704,7 @@ class GoogleClient extends BaseClient {
         delay,
       });
       reply += chunkText;
+      await sleep(streamRate);
     }
     return reply;
   }
@@ -712,10 +716,17 @@ class GoogleClient extends BaseClient {
       safetySettings: safetySettings,
     });

-    let delay = this.isGenerativeModel ? 12 : 8;
-    if (modelName.includes('flash')) {
-      delay = 5;
-    }
+    let delay = this.options.streamRate || 8;
+
+    if (!this.options.streamRate) {
+      if (this.isGenerativeModel) {
+        delay = 12;
+      }
+      if (modelName.includes('flash')) {
+        delay = 5;
+      }
+    }
+
     for await (const chunk of stream) {
       const chunkText = chunk?.content ?? chunk;
       await this.generateTextStream(chunkText, onProgress, {

View file

@@ -1,7 +1,9 @@
 const { z } = require('zod');
 const axios = require('axios');
 const { Ollama } = require('ollama');
+const { Constants } = require('librechat-data-provider');
 const { deriveBaseURL } = require('~/utils');
+const { sleep } = require('~/server/utils');
 const { logger } = require('~/config');

 const ollamaPayloadSchema = z.object({
@@ -40,6 +42,7 @@ const getValidBase64 = (imageUrl) => {
 class OllamaClient {
   constructor(options = {}) {
     const host = deriveBaseURL(options.baseURL ?? 'http://localhost:11434');
+    this.streamRate = options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
     /** @type {Ollama} */
     this.client = new Ollama({ host });
   }
@@ -136,6 +139,8 @@ class OllamaClient {
         stream.controller.abort();
         break;
       }
+
+      await sleep(this.streamRate);
     }
   }
   // TODO: regular completion

View file

@@ -1182,8 +1182,10 @@ ${convo}
       });
     }

+    const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
+
     if (this.message_file_map && this.isOllama) {
-      const ollamaClient = new OllamaClient({ baseURL });
+      const ollamaClient = new OllamaClient({ baseURL, streamRate });
       return await ollamaClient.chatCompletion({
         payload: modelOptions,
         onProgress,
@@ -1221,8 +1223,6 @@ ${convo}
           }
         });

-        const azureDelay = this.modelOptions.model?.includes('gpt-4') ? 30 : 17;
-
         for await (const chunk of stream) {
           const token = chunk.choices[0]?.delta?.content || '';
           intermediateReply += token;
@@ -1232,9 +1232,7 @@ ${convo}
             break;
           }

-          if (this.azure) {
-            await sleep(azureDelay);
-          }
+          await sleep(streamRate);
         }

         if (!UnexpectedRoleError) {

View file

@@ -1,5 +1,6 @@
 const OpenAIClient = require('./OpenAIClient');
 const { CallbackManager } = require('langchain/callbacks');
+const { CacheKeys, Time } = require('librechat-data-provider');
 const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
 const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
 const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
@@ -11,6 +12,7 @@ const { SelfReflectionTool } = require('./tools');
 const { isEnabled } = require('~/server/utils');
 const { extractBaseURL } = require('~/utils');
 const { loadTools } = require('./tools/util');
+const { getLogStores } = require('~/cache');
 const { logger } = require('~/config');

 class PluginsClient extends OpenAIClient {
@@ -220,6 +222,13 @@ class PluginsClient extends OpenAIClient {
     }
   }

+  /**
+   *
+   * @param {TMessage} responseMessage
+   * @param {Partial<TMessage>} saveOptions
+   * @param {string} user
+   * @returns
+   */
   async handleResponseMessage(responseMessage, saveOptions, user) {
     const { output, errorMessage, ...result } = this.result;
     logger.debug('[PluginsClient][handleResponseMessage] Output:', {
@@ -239,6 +248,15 @@ class PluginsClient extends OpenAIClient {
     }

     this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
+    const messageCache = getLogStores(CacheKeys.MESSAGES);
+    messageCache.set(
+      responseMessage.messageId,
+      {
+        text: responseMessage.text,
+        complete: true,
+      },
+      Time.FIVE_MINUTES,
+    );
     delete responseMessage.tokenCount;
     return { ...responseMessage, ...result };
   }

View file

@@ -1,13 +1,11 @@
 const Keyv = require('keyv');
-const { CacheKeys, ViolationTypes } = require('librechat-data-provider');
+const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
 const { logFile, violationFile } = require('./keyvFiles');
 const { math, isEnabled } = require('~/server/utils');
 const keyvRedis = require('./keyvRedis');
 const keyvMongo = require('./keyvMongo');

 const { BAN_DURATION, USE_REDIS } = process.env ?? {};
-const THIRTY_MINUTES = 1800000;
-const TEN_MINUTES = 600000;

 const duration = math(BAN_DURATION, 7200000);
@@ -29,17 +27,21 @@ const roles = isEnabled(USE_REDIS)
   ? new Keyv({ store: keyvRedis })
   : new Keyv({ namespace: CacheKeys.ROLES });

-const audioRuns = isEnabled(USE_REDIS) // ttl: 30 minutes
-  ? new Keyv({ store: keyvRedis, ttl: TEN_MINUTES })
-  : new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: TEN_MINUTES });
+const audioRuns = isEnabled(USE_REDIS)
+  ? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
+  : new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: Time.TEN_MINUTES });
+
+const messages = isEnabled(USE_REDIS)
+  ? new Keyv({ store: keyvRedis, ttl: Time.FIVE_MINUTES })
+  : new Keyv({ namespace: CacheKeys.MESSAGES, ttl: Time.FIVE_MINUTES });

 const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes
-  ? new Keyv({ store: keyvRedis, ttl: THIRTY_MINUTES })
-  : new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: THIRTY_MINUTES });
+  ? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
+  : new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: Time.THIRTY_MINUTES });

 const genTitle = isEnabled(USE_REDIS) // ttl: 2 minutes
-  ? new Keyv({ store: keyvRedis, ttl: 120000 })
-  : new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: 120000 });
+  ? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
+  : new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES });

 const modelQueries = isEnabled(process.env.USE_REDIS)
   ? new Keyv({ store: keyvRedis })
@@ -47,7 +49,7 @@ const modelQueries = isEnabled(process.env.USE_REDIS)

 const abortKeys = isEnabled(USE_REDIS)
   ? new Keyv({ store: keyvRedis })
-  : new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: 600000 });
+  : new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: Time.TEN_MINUTES });

 const namespaces = {
   [CacheKeys.ROLES]: roles,
@@ -81,6 +83,7 @@ const namespaces = {
   [CacheKeys.GEN_TITLE]: genTitle,
   [CacheKeys.MODEL_QUERIES]: modelQueries,
   [CacheKeys.AUDIO_RUNS]: audioRuns,
+  [CacheKeys.MESSAGES]: messages,
 };

 /**
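
The new MESSAGES store is what replaces mid-stream database writes. A hedged sketch of how the rest of the PR uses it; the wrapper functions are hypothetical, while the two value shapes (a plain partial-text string while streaming, a `{ text, complete }` object on completion) come from the controller and client diffs:

```js
const { CacheKeys, Time } = require('librechat-data-provider');
const { getLogStores } = require('~/cache');

const messageCache = getLogStores(CacheKeys.MESSAGES);

// While streaming, controllers throttle writes of the raw partial text.
async function cachePartialText(responseMessageId, partialText) {
  await messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES);
}

// When the run finishes, clients overwrite the entry with a completion marker.
async function cacheFinalText(responseMessageId, text) {
  await messageCache.set(responseMessageId, { text, complete: true }, Time.FIVE_MINUTES);
}

// Consumers (e.g. the TTS chunk processor) normalize both shapes.
async function readCachedMessage(responseMessageId) {
  const entry = await messageCache.get(responseMessageId);
  const text = typeof entry === 'string' ? entry : entry?.text;
  const complete = typeof entry === 'string' ? false : Boolean(entry?.complete);
  return { text, complete };
}
```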

View file

@@ -1,7 +1,8 @@
 const throttle = require('lodash/throttle');
-const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider');
+const { getResponseSender, Constants, CacheKeys, Time } = require('librechat-data-provider');
 const { createAbortController, handleAbortError } = require('~/server/middleware');
 const { sendMessage, createOnProgress } = require('~/server/utils');
+const { getLogStores } = require('~/cache');
 const { saveMessage } = require('~/models');
 const { logger } = require('~/config');

@@ -51,11 +52,13 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
   try {
     const { client } = await initializeClient({ req, res, endpointOption });

-    const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
+    const messageCache = getLogStores(CacheKeys.MESSAGES);
     const { onProgress: progressCallback, getPartialText } = createOnProgress({
       onProgress: throttle(
         ({ text: partialText }) => {
-          saveMessage(req, {
+          /*
+          const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
+          messageCache.set(responseMessageId, {
             messageId: responseMessageId,
             sender,
             conversationId,
@@ -65,7 +68,10 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
             unfinished,
             error: false,
             user,
-          });
+          }, Time.FIVE_MINUTES);
+          */
+
+          messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES);
         },
         3000,
         { trailing: false },

View file

@@ -1,7 +1,8 @@
 const throttle = require('lodash/throttle');
-const { getResponseSender, EModelEndpoint } = require('librechat-data-provider');
+const { getResponseSender, CacheKeys, Time } = require('librechat-data-provider');
 const { createAbortController, handleAbortError } = require('~/server/middleware');
 const { sendMessage, createOnProgress } = require('~/server/utils');
+const { getLogStores } = require('~/cache');
 const { saveMessage } = require('~/models');
 const { logger } = require('~/config');

@@ -51,12 +52,14 @@ const EditController = async (req, res, next, initializeClient) => {
     }
   };

-  const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
+  const messageCache = getLogStores(CacheKeys.MESSAGES);
   const { onProgress: progressCallback, getPartialText } = createOnProgress({
     generation,
     onProgress: throttle(
       ({ text: partialText }) => {
-        saveMessage(req, {
+        /*
+        const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
+        {
           messageId: responseMessageId,
           sender,
           conversationId,
@@ -67,7 +70,8 @@ const EditController = async (req, res, next, initializeClient) => {
           isEdited: true,
           error: false,
           user,
-        });
+        } */
+        messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES);
       },
       3000,
       { trailing: false },

View file

@@ -1,12 +1,12 @@
 const { v4 } = require('uuid');
 const {
+  Time,
   Constants,
   RunStatus,
   CacheKeys,
   ContentTypes,
   ToolCallTypes,
   EModelEndpoint,
-  ViolationTypes,
   retrievalMimeTypes,
   AssistantStreamEvents,
 } = require('librechat-data-provider');
@@ -14,12 +14,12 @@ const {
   initThread,
   recordUsage,
   saveUserMessage,
-  checkMessageGaps,
   addThreadMetadata,
   saveAssistantMessage,
 } = require('~/server/services/Threads');
-const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
 const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
+const { sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
+const { createErrorHandler } = require('~/server/controllers/assistants/errors');
 const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
 const { createRun, StreamRunManager } = require('~/server/services/Runs');
 const { addTitle } = require('~/server/services/Endpoints/assistants');
@@ -44,7 +44,7 @@ const ten_minutes = 1000 * 60 * 10;
 const chatV2 = async (req, res) => {
   logger.debug('[/assistants/chat/] req.body', req.body);

-  /** @type {{ files: MongoFile[]}} */
+  /** @type {{files: MongoFile[]}} */
   const {
     text,
     model,
@@ -90,140 +90,20 @@ const chatV2 = async (req, res) => {
   /** @type {Run | undefined} - The completed run, undefined if incomplete */
   let completedRun;

-  const handleError = async (error) => {
   [~120 removed lines elided: the inline error-handler body (default error message, messageData, run cancellation, usage recording, checkMessageGaps recovery, and sendResponse calls) moves nearly verbatim to ~/server/controllers/assistants/errors.js, shown in full below]
-  };
+  const getContext = () => ({
+    openai,
+    run_id,
+    endpoint,
+    cacheKey,
+    thread_id,
+    completedRun,
+    assistant_id,
+    conversationId,
+    parentMessageId,
+    responseMessageId,
+  });
+
+  const handleError = createErrorHandler({ req, res, getContext });

   try {
     res.on('close', async () => {
@@ -490,6 +370,11 @@ const chatV2 = async (req, res) => {
       },
     };

+    /** @type {undefined | TAssistantEndpoint} */
+    const config = req.app.locals[endpoint] ?? {};
+    /** @type {undefined | TBaseEndpoint} */
+    const allConfig = req.app.locals.all;
+
     const streamRunManager = new StreamRunManager({
       req,
       res,
@@ -499,6 +384,7 @@ const chatV2 = async (req, res) => {
       attachedFileIds,
       parentMessageId: userMessageId,
       responseMessage: openai.responseMessage,
+      streamRate: allConfig?.streamRate ?? config.streamRate,
       // streamOptions: {
       // },
@@ -511,6 +397,16 @@ const chatV2 = async (req, res) => {
       response = streamRunManager;
       response.text = streamRunManager.intermediateText;
+
+      const messageCache = getLogStores(CacheKeys.MESSAGES);
+      messageCache.set(
+        responseMessageId,
+        {
+          complete: true,
+          text: response.text,
+        },
+        Time.FIVE_MINUTES,
+      );
     };

     await processRun();

View file

@@ -0,0 +1,193 @@
// errorHandler.js
const { sendResponse } = require('~/server/utils');
const { logger } = require('~/config');
const getLogStores = require('~/cache/getLogStores');
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
const { getConvo } = require('~/models/Conversation');
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
/**
* @typedef {Object} ErrorHandlerContext
* @property {OpenAIClient} openai - The OpenAI client
* @property {string} thread_id - The thread ID
* @property {string} run_id - The run ID
* @property {boolean} completedRun - Whether the run has completed
* @property {string} assistant_id - The assistant ID
* @property {string} conversationId - The conversation ID
* @property {string} parentMessageId - The parent message ID
* @property {string} responseMessageId - The response message ID
* @property {string} endpoint - The endpoint being used
* @property {string} cacheKey - The cache key for the current request
*/
/**
* @typedef {Object} ErrorHandlerDependencies
* @property {Express.Request} req - The Express request object
* @property {Express.Response} res - The Express response object
* @property {() => ErrorHandlerContext} getContext - Function to get the current context
* @property {string} [originPath] - The origin path for the error handler
*/
/**
* Creates an error handler function with the given dependencies
* @param {ErrorHandlerDependencies} dependencies - The dependencies for the error handler
* @returns {(error: Error) => Promise<void>} The error handler function
*/
const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/chat/' }) => {
const cache = getLogStores(CacheKeys.ABORT_KEYS);
/**
* Handles errors that occur during the chat process
* @param {Error} error - The error that occurred
* @returns {Promise<void>}
*/
return async (error) => {
const {
openai,
run_id,
endpoint,
cacheKey,
thread_id,
completedRun,
assistant_id,
conversationId,
parentMessageId,
responseMessageId,
} = getContext();
const defaultErrorMessage =
'The Assistant run failed to initialize. Try sending a message in a new conversation.';
const messageData = {
thread_id,
assistant_id,
conversationId,
parentMessageId,
sender: 'System',
user: req.user.id,
shouldSaveMessage: false,
messageId: responseMessageId,
endpoint,
};
if (error.message === 'Run cancelled') {
return res.end();
} else if (error.message === 'Request closed' && completedRun) {
return;
} else if (error.message === 'Request closed') {
logger.debug(`[${originPath}] Request aborted on close`);
} else if (/Files.*are invalid/.test(error.message)) {
const errorMessage = `Files are invalid, or may not have uploaded yet.${
endpoint === 'azureAssistants'
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
: ''
}`;
return sendResponse(req, res, messageData, errorMessage);
} else if (error?.message?.includes('string too long')) {
return sendResponse(
req,
res,
messageData,
'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
);
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
return sendResponse(req, res, messageData, error.message);
} else {
logger.error(`[${originPath}]`, error);
}
if (!openai || !thread_id || !run_id) {
return sendResponse(req, res, messageData, defaultErrorMessage);
}
await new Promise((resolve) => setTimeout(resolve, 2000));
try {
const status = await cache.get(cacheKey);
if (status === 'cancelled') {
logger.debug(`[${originPath}] Run already cancelled`);
return res.end();
}
await cache.delete(cacheKey);
const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
logger.debug(`[${originPath}] Cancelled run:`, cancelledRun);
} catch (error) {
logger.error(`[${originPath}] Error cancelling run`, error);
}
await new Promise((resolve) => setTimeout(resolve, 2000));
let run;
try {
run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
await recordUsage({
...run.usage,
model: run.model,
user: req.user.id,
conversationId,
});
} catch (error) {
logger.error(`[${originPath}] Error fetching or processing run`, error);
}
let finalEvent;
try {
const runMessages = await checkMessageGaps({
openai,
run_id,
endpoint,
thread_id,
conversationId,
latestMessageId: responseMessageId,
});
const errorContentPart = {
text: {
value:
error?.message ?? 'There was an error processing your request. Please try again later.',
},
type: ContentTypes.ERROR,
};
if (!Array.isArray(runMessages[runMessages.length - 1]?.content)) {
runMessages[runMessages.length - 1].content = [errorContentPart];
} else {
const contentParts = runMessages[runMessages.length - 1].content;
for (let i = 0; i < contentParts.length; i++) {
const currentPart = contentParts[i];
/** @type {CodeToolCall | RetrievalToolCall | FunctionToolCall | undefined} */
const toolCall = currentPart?.[ContentTypes.TOOL_CALL];
if (
toolCall &&
toolCall?.function &&
!(toolCall?.function?.output || toolCall?.function?.output?.length)
) {
contentParts[i] = {
...currentPart,
[ContentTypes.TOOL_CALL]: {
...toolCall,
function: {
...toolCall.function,
output: 'error processing tool',
},
},
};
}
}
runMessages[runMessages.length - 1].content.push(errorContentPart);
}
finalEvent = {
final: true,
conversation: await getConvo(req.user.id, conversationId),
runMessages,
};
} catch (error) {
logger.error(`[${originPath}] Error finalizing error process`, error);
return sendResponse(req, res, messageData, 'The Assistant run failed');
}
return sendResponse(req, res, finalEvent);
};
};
module.exports = { createErrorHandler };
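
Abridged from the chatV2 controller diff above, this is how the factory is expected to be wired up; because getContext is a closure evaluated at error time, values assigned later in the request lifecycle (openai, run_id, completedRun, and so on) are picked up lazily:

```js
const { createErrorHandler } = require('~/server/controllers/assistants/errors');

// Inside the controller, after the request-scoped variables are declared:
const getContext = () => ({
  openai,
  run_id,
  endpoint,
  cacheKey,
  thread_id,
  completedRun,
  assistant_id,
  conversationId,
  parentMessageId,
  responseMessageId,
});

const handleError = createErrorHandler({ req, res, getContext });
// ...the controller's try/catch then delegates caught errors to handleError(error).
```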

View file

@@ -30,7 +30,10 @@ async function abortMessage(req, res) {
     return res.status(204).send({ message: 'Request not found' });
   }
   const finalEvent = await abortController.abortCompletion();
-  logger.info('[abortMessage] Aborted request', { abortKey });
+  logger.debug(
+    `[abortMessage] ID: ${req.user.id} | ${req.user.email} | Aborted request: ` +
+      JSON.stringify({ abortKey }),
+  );
   abortControllers.delete(abortKey);

   if (res.headersSent && finalEvent) {

View file

@@ -1,10 +1,11 @@
 const express = require('express');
 const throttle = require('lodash/throttle');
-const { getResponseSender, Constants } = require('librechat-data-provider');
+const { getResponseSender, Constants, CacheKeys, Time } = require('librechat-data-provider');
 const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
 const { sendMessage, createOnProgress } = require('~/server/utils');
 const { addTitle } = require('~/server/services/Endpoints/openAI');
 const { saveMessage } = require('~/models');
+const { getLogStores } = require('~/cache');
 const {
   handleAbort,
   createAbortController,
@@ -71,7 +72,8 @@ router.post(
       }
     };

-    const throttledSaveMessage = throttle(saveMessage, 3000, { trailing: false });
+    const messageCache = getLogStores(CacheKeys.MESSAGES);
+    const throttledSetMessage = throttle(messageCache.set, 3000, { trailing: false });
     let streaming = null;
     let timer = null;

@@ -85,7 +87,8 @@ router.post(
           clearTimeout(timer);
         }

-        throttledSaveMessage(req, {
+        /*
+        {
           messageId: responseMessageId,
           sender,
           conversationId,
@@ -96,7 +99,9 @@ router.post(
           error: false,
           plugins,
           user,
-        });
+        }
+        */
+        throttledSetMessage(responseMessageId, partialText, Time.FIVE_MINUTES);

         streaming = new Promise((resolve) => {
           timer = setTimeout(() => {
View file

@@ -1,19 +1,20 @@
 const express = require('express');
 const throttle = require('lodash/throttle');
-const { getResponseSender } = require('librechat-data-provider');
+const { getResponseSender, CacheKeys, Time } = require('librechat-data-provider');
 const {
-  handleAbort,
-  createAbortController,
-  handleAbortError,
   setHeaders,
+  handleAbort,
+  moderateText,
   validateModel,
+  handleAbortError,
   validateEndpoint,
   buildEndpointOption,
-  moderateText,
+  createAbortController,
 } = require('~/server/middleware');
 const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
 const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
 const { saveMessage } = require('~/models');
+const { getLogStores } = require('~/cache');
 const { validateTools } = require('~/app');
 const { logger } = require('~/config');

@@ -79,7 +80,8 @@ router.post(
       }
     };

-    const throttledSaveMessage = throttle(saveMessage, 3000, { trailing: false });
+    const messageCache = getLogStores(CacheKeys.MESSAGES);
+    const throttledSetMessage = throttle(messageCache.set, 3000, { trailing: false });
     const {
       onProgress: progressCallback,
       sendIntermediateMessage,
@@ -91,7 +93,8 @@ router.post(
           plugin.loading = false;
         }

-        throttledSaveMessage(req, {
+        /*
+        {
           messageId: responseMessageId,
           sender,
           conversationId,
@@ -102,7 +105,9 @@ router.post(
           isEdited: true,
           error: false,
           user,
-        });
+        }
+        */
+        throttledSetMessage(responseMessageId, partialText, Time.FIVE_MINUTES);
       },
     });

View file

@@ -67,17 +67,18 @@ const AppService = async (app) => {
   handleRateLimits(config?.rateLimits);

   const endpointLocals = {};
+  const endpoints = config?.endpoints;

-  if (config?.endpoints?.[EModelEndpoint.azureOpenAI]) {
+  if (endpoints?.[EModelEndpoint.azureOpenAI]) {
     endpointLocals[EModelEndpoint.azureOpenAI] = azureConfigSetup(config);
     checkAzureVariables();
   }

-  if (config?.endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) {
+  if (endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) {
     endpointLocals[EModelEndpoint.azureAssistants] = azureAssistantsDefaults();
   }

-  if (config?.endpoints?.[EModelEndpoint.azureAssistants]) {
+  if (endpoints?.[EModelEndpoint.azureAssistants]) {
     endpointLocals[EModelEndpoint.azureAssistants] = assistantsConfigSetup(
       config,
       EModelEndpoint.azureAssistants,
@@ -85,7 +86,7 @@ const AppService = async (app) => {
     );
   }

-  if (config?.endpoints?.[EModelEndpoint.assistants]) {
+  if (endpoints?.[EModelEndpoint.assistants]) {
     endpointLocals[EModelEndpoint.assistants] = assistantsConfigSetup(
       config,
       EModelEndpoint.assistants,
@@ -93,6 +94,19 @@ const AppService = async (app) => {
     );
   }

+  if (endpoints?.[EModelEndpoint.openAI]) {
+    endpointLocals[EModelEndpoint.openAI] = endpoints[EModelEndpoint.openAI];
+  }
+  if (endpoints?.[EModelEndpoint.google]) {
+    endpointLocals[EModelEndpoint.google] = endpoints[EModelEndpoint.google];
+  }
+  if (endpoints?.[EModelEndpoint.anthropic]) {
+    endpointLocals[EModelEndpoint.anthropic] = endpoints[EModelEndpoint.anthropic];
+  }
+  if (endpoints?.[EModelEndpoint.gptPlugins]) {
+    endpointLocals[EModelEndpoint.gptPlugins] = endpoints[EModelEndpoint.gptPlugins];
+  }
+
   app.locals = {
     ...defaultLocals,
     modelSpecs: config.modelSpecs,
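
The per-endpoint initializeClient changes that follow all repeat the same precedence: the endpoint's own streamRate is applied first, and the shared `all` endpoint config, when present, overrides it. A hypothetical helper capturing that order (not part of the PR):

```js
function resolveStreamRate(appLocals, endpointKey) {
  /** @type {number | undefined} */
  let streamRate = appLocals?.[endpointKey]?.streamRate;

  // Mirrors `if (allConfig) { clientOptions.streamRate = allConfig.streamRate; }`
  // in the diffs below: an `all` block always wins, even if its streamRate is undefined.
  if (appLocals?.all) {
    streamRate = appLocals.all.streamRate;
  }

  // Clients fall back to Constants.DEFAULT_STREAM_RATE when this is undefined.
  return streamRate;
}
```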

View file

@@ -19,11 +19,27 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     checkUserKeyExpiry(expiresAt, EModelEndpoint.anthropic);
   }

+  const clientOptions = {};
+
+  /** @type {undefined | TBaseEndpoint} */
+  const anthropicConfig = req.app.locals[EModelEndpoint.anthropic];
+
+  if (anthropicConfig) {
+    clientOptions.streamRate = anthropicConfig.streamRate;
+  }
+
+  /** @type {undefined | TBaseEndpoint} */
+  const allConfig = req.app.locals.all;
+  if (allConfig) {
+    clientOptions.streamRate = allConfig.streamRate;
+  }
+
   const client = new AnthropicClient(anthropicApiKey, {
     req,
     res,
     reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
     proxy: PROXY ?? null,
+    ...clientOptions,
     ...endpointOption,
   });

View file

@@ -114,9 +114,16 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     contextStrategy: endpointConfig.summarize ? 'summarize' : null,
     directEndpoint: endpointConfig.directEndpoint,
     titleMessageRole: endpointConfig.titleMessageRole,
+    streamRate: endpointConfig.streamRate,
     endpointTokenConfig,
   };

+  /** @type {undefined | TBaseEndpoint} */
+  const allConfig = req.app.locals.all;
+  if (allConfig) {
+    customOptions.streamRate = allConfig.streamRate;
+  }
+
   const clientOptions = {
     reverseProxyUrl: baseURL ?? null,
     proxy: PROXY ?? null,

View file

@@ -27,11 +27,27 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     [AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY,
   };

+  const clientOptions = {};
+
+  /** @type {undefined | TBaseEndpoint} */
+  const allConfig = req.app.locals.all;
+  /** @type {undefined | TBaseEndpoint} */
+  const googleConfig = req.app.locals[EModelEndpoint.google];
+
+  if (googleConfig) {
+    clientOptions.streamRate = googleConfig.streamRate;
+  }
+
+  if (allConfig) {
+    clientOptions.streamRate = allConfig.streamRate;
+  }
+
   const client = new GoogleClient(credentials, {
     req,
     res,
     reverseProxyUrl: GOOGLE_REVERSE_PROXY ?? null,
     proxy: PROXY ?? null,
+    ...clientOptions,
     ...endpointOption,
   });

View file

@@ -8,6 +8,8 @@ jest.mock('~/server/services/UserService', () => ({
   getUserKey: jest.fn().mockImplementation(() => ({})),
 }));

+const app = { locals: {} };
+
 describe('google/initializeClient', () => {
   afterEach(() => {
     jest.clearAllMocks();
@@ -23,6 +25,7 @@ describe('google/initializeClient', () => {
     const req = {
       body: { key: expiresAt },
       user: { id: '123' },
+      app,
     };
     const res = {};
     const endpointOption = { modelOptions: { model: 'default-model' } };
@@ -44,6 +47,7 @@ describe('google/initializeClient', () => {
     const req = {
       body: { key: null },
       user: { id: '123' },
+      app,
     };
     const res = {};
     const endpointOption = { modelOptions: { model: 'default-model' } };
@@ -66,6 +70,7 @@ describe('google/initializeClient', () => {
     const req = {
       body: { key: expiresAt },
       user: { id: '123' },
+      app,
     };
     const res = {};
     const endpointOption = { modelOptions: { model: 'default-model' } };

View file

@@ -86,6 +86,9 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     clientOptions.titleModel = azureConfig.titleModel;
     clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion';

+    const azureRate = modelName.includes('gpt-4') ? 30 : 17;
+    clientOptions.streamRate = azureConfig.streamRate ?? azureRate;
+
     const groupName = modelGroupMap[modelName].group;
     clientOptions.addParams = azureConfig.groupMap[groupName].addParams;
     clientOptions.dropParams = azureConfig.groupMap[groupName].dropParams;
@@ -98,6 +101,19 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     apiKey = clientOptions.azure.azureOpenAIApiKey;
   }

+  /** @type {undefined | TBaseEndpoint} */
+  const pluginsConfig = req.app.locals[EModelEndpoint.gptPlugins];
+
+  if (!useAzure && pluginsConfig) {
+    clientOptions.streamRate = pluginsConfig.streamRate;
+  }
+
+  /** @type {undefined | TBaseEndpoint} */
+  const allConfig = req.app.locals.all;
+  if (allConfig) {
+    clientOptions.streamRate = allConfig.streamRate;
+  }
+
   if (!apiKey) {
     throw new Error(`${endpoint} API key not provided. Please provide it again.`);
   }

View file

@@ -76,6 +76,10 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     clientOptions.titleConvo = azureConfig.titleConvo;
     clientOptions.titleModel = azureConfig.titleModel;

+    const azureRate = modelName.includes('gpt-4') ? 30 : 17;
+    clientOptions.streamRate = azureConfig.streamRate ?? azureRate;
+
     clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion';

     const groupName = modelGroupMap[modelName].group;
@@ -90,6 +94,19 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     apiKey = clientOptions.azure.azureOpenAIApiKey;
   }

+  /** @type {undefined | TBaseEndpoint} */
+  const openAIConfig = req.app.locals[EModelEndpoint.openAI];
+
+  if (!isAzureOpenAI && openAIConfig) {
+    clientOptions.streamRate = openAIConfig.streamRate;
+  }
+
+  /** @type {undefined | TBaseEndpoint} */
+  const allConfig = req.app.locals.all;
+  if (allConfig) {
+    clientOptions.streamRate = allConfig.streamRate;
+  }
+
   if (userProvidesKey & !apiKey) {
     throw new Error(
       JSON.stringify({

View file

@@ -1,5 +1,6 @@
 const WebSocket = require('ws');
-const { Message } = require('~/models/Message');
+const { CacheKeys } = require('librechat-data-provider');
+const { getLogStores } = require('~/cache');

 /**
  * @param {string[]} voiceIds - Array of voice IDs
@@ -104,6 +105,8 @@ function createChunkProcessor(messageId) {
     throw new Error('Message ID is required');
   }

+  const messageCache = getLogStores(CacheKeys.MESSAGES);
+
  /**
   * @returns {Promise<{ text: string, isFinished: boolean }[] | string>}
   */
@@ -116,14 +119,17 @@ function createChunkProcessor(messageId) {
       return `No change in message after ${MAX_NO_CHANGE_COUNT} attempts`;
     }

-    const message = await Message.findOne({ messageId }, 'text unfinished').lean();
+    /** @type { string | { text: string; complete: boolean } } */
+    const message = await messageCache.get(messageId);

-    if (!message || !message.text) {
+    if (!message) {
       notFoundCount++;
       return [];
     }

-    const { text, unfinished } = message;
+    const text = typeof message === 'string' ? message : message.text;
+    const complete = typeof message === 'string' ? false : message.complete;
+
     if (text === processedText) {
       noChangeCount++;
     }
@@ -131,7 +137,7 @@ function createChunkProcessor(messageId) {
     const remainingText = text.slice(processedText.length);
     const chunks = [];

-    if (unfinished && remainingText.length >= 20) {
+    if (!complete && remainingText.length >= 20) {
       const separatorIndex = findLastSeparatorIndex(remainingText);
       if (separatorIndex !== -1) {
         const chunkText = remainingText.slice(0, separatorIndex + 1);
@@ -141,7 +147,7 @@ function createChunkProcessor(messageId) {
         chunks.push({ text: remainingText, isFinished: false });
         processedText = text;
       }
-    } else if (!unfinished && remainingText.trim().length > 0) {
+    } else if (complete && remainingText.trim().length > 0) {
       chunks.push({ text: remainingText.trim(), isFinished: true });
       processedText = text;
     }
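
For context, a hypothetical consumer of createChunkProcessor based on the contract above: a string return is a terminal error sentinel, an array carries the next chunks, and a chunk with isFinished set ends the stream (the polling interval and helper name are illustrative):

```js
async function streamChunksToTTS(messageId, sendToTTS, pollMs = 500) {
  const processChunks = createChunkProcessor(messageId);

  for (;;) {
    const result = await processChunks();

    if (typeof result === 'string') {
      // e.g. 'Message not found after 6 attempts' or 'No change in message after 10 attempts'
      throw new Error(result);
    }

    for (const { text, isFinished } of result) {
      sendToTTS(text);
      if (isFinished) {
        return;
      }
    }

    await new Promise((resolve) => setTimeout(resolve, pollMs));
  }
}
```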

View file

@@ -1,89 +1,145 @@
 const { createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
-const { Message } = require('~/models/Message');

-jest.mock('~/models/Message', () => ({
-  Message: {
-    findOne: jest.fn().mockReturnValue({
-      lean: jest.fn(),
-    }),
-  },
-}));
+jest.mock('keyv');
+
+const globalCache = {};
+jest.mock('~/cache/getLogStores', () => {
+  return jest.fn().mockImplementation(() => {
+    const EventEmitter = require('events');
+    const { CacheKeys } = require('librechat-data-provider');
+
+    class KeyvMongo extends EventEmitter {
+      constructor(url = 'mongodb://127.0.0.1:27017', options) {
+        super();
+        this.ttlSupport = false;
+        url = url ?? {};
+        if (typeof url === 'string') {
+          url = { url };
+        }
+        if (url.uri) {
+          url = { url: url.uri, ...url };
+        }
+        this.opts = {
+          url,
+          collection: 'keyv',
+          ...url,
+          ...options,
+        };
+      }
+
+      get = async (key) => {
+        return new Promise((resolve) => {
+          resolve(globalCache[key] || null);
+        });
+      };
+
+      set = async (key, value) => {
+        return new Promise((resolve) => {
+          globalCache[key] = value;
+          resolve(true);
+        });
+      };
+    }
+
+    return new KeyvMongo('', {
+      namespace: CacheKeys.MESSAGES,
+      ttl: 0,
+    });
+  });
+});

 describe('processChunks', () => {
   let processChunks;
+  let mockMessageCache;

   beforeEach(() => {
+    jest.resetAllMocks();
+    mockMessageCache = {
+      get: jest.fn(),
+    };
+    require('~/cache/getLogStores').mockReturnValue(mockMessageCache);
     processChunks = createChunkProcessor('message-id');
-    Message.findOne.mockClear();
-    Message.findOne().lean.mockClear();
   });

   it('should return an empty array when the message is not found', async () => {
-    Message.findOne().lean.mockResolvedValueOnce(null);
+    mockMessageCache.get.mockResolvedValueOnce(null);

     const result = await processChunks();

     expect(result).toEqual([]);
-    expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
-    expect(Message.findOne().lean).toHaveBeenCalled();
+    expect(mockMessageCache.get).toHaveBeenCalledWith('message-id');
   });

-  it('should return an empty array when the message does not have a text property', async () => {
-    Message.findOne().lean.mockResolvedValueOnce({ unfinished: true });
+  it('should return an error message after MAX_NOT_FOUND_COUNT attempts', async () => {
+    mockMessageCache.get.mockResolvedValue(null);
+
+    for (let i = 0; i < 6; i++) {
+      await processChunks();
+    }

     const result = await processChunks();

-    expect(result).toEqual([]);
-    expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
-    expect(Message.findOne().lean).toHaveBeenCalled();
+    expect(result).toBe('Message not found after 6 attempts');
   });

-  it('should return chunks for an unfinished message with separators', async () => {
+  it('should return chunks for an incomplete message with separators', async () => {
     const messageText = 'This is a long message. It should be split into chunks. Lol hi mom';
-    Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: true });
+    mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: false });

     const result = await processChunks();

     expect(result).toEqual([
       { text: 'This is a long message. It should be split into chunks.', isFinished: false },
     ]);
-    expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
-    expect(Message.findOne().lean).toHaveBeenCalled();
   });

-  it('should return chunks for an unfinished message without separators', async () => {
+  it('should return chunks for an incomplete message without separators', async () => {
     const messageText = 'This is a long message without separators hello there my friend';
-    Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: true });
+    mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: false });

     const result = await processChunks();

     expect(result).toEqual([{ text: messageText, isFinished: false }]);
-    expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
-    expect(Message.findOne().lean).toHaveBeenCalled();
   });

-  it('should return the remaining text as a chunk for a finished message', async () => {
+  it('should return the remaining text as a chunk for a complete message', async () => {
     const messageText = 'This is a finished message.';
-    Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
+    mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true });

     const result = await processChunks();

     expect(result).toEqual([{ text: messageText, isFinished: true }]);
-    expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
-    expect(Message.findOne().lean).toHaveBeenCalled();
   });

-  it('should return an empty array for a finished message with no remaining text', async () => {
+  it('should return an empty array for a complete message with no remaining text', async () => {
     const messageText = 'This is a finished message.';
-    Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
+    mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true });

     await processChunks();
-    Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
+    mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true });

     const result = await processChunks();

     expect(result).toEqual([]);
-    expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
-    expect(Message.findOne().lean).toHaveBeenCalledTimes(2);
+  });
+
+  it('should return an error message after MAX_NO_CHANGE_COUNT attempts with no change', async () => {
+    const messageText = 'This is a message that does not change.';
+    mockMessageCache.get.mockResolvedValue({ text: messageText, complete: false });
+
+    for (let i = 0; i < 11; i++) {
+      await processChunks();
+    }
+
+    const result = await processChunks();
+
+    expect(result).toBe('No change in message after 10 attempts');
+  });
+
+  it('should handle string messages as incomplete', async () => {
+    const messageText = 'This is a message as a string.';
+    mockMessageCache.get.mockResolvedValueOnce(messageText);
+
+    const result = await processChunks();
+
+    expect(result).toEqual([{ text: messageText, isFinished: false }]);
   });
 });

View file

@ -1,17 +1,19 @@
const throttle = require('lodash/throttle'); const throttle = require('lodash/throttle');
const { const {
Time,
CacheKeys,
StepTypes, StepTypes,
ContentTypes, ContentTypes,
ToolCallTypes, ToolCallTypes,
// StepStatus,
MessageContentTypes, MessageContentTypes,
AssistantStreamEvents, AssistantStreamEvents,
Constants,
} = require('librechat-data-provider'); } = require('librechat-data-provider');
const { retrieveAndProcessFile } = require('~/server/services/Files/process'); const { retrieveAndProcessFile } = require('~/server/services/Files/process');
const { processRequiredActions } = require('~/server/services/ToolService'); const { processRequiredActions } = require('~/server/services/ToolService');
const { saveMessage, updateMessageText } = require('~/models/Message'); const { createOnProgress, sendMessage, sleep } = require('~/server/utils');
const { createOnProgress, sendMessage } = require('~/server/utils');
const { processMessages } = require('~/server/services/Threads'); const { processMessages } = require('~/server/services/Threads');
const { getLogStores } = require('~/cache');
const { logger } = require('~/config'); const { logger } = require('~/config');
/** /**
@ -68,8 +70,8 @@ class StreamRunManager {
this.attachedFileIds = fields.attachedFileIds; this.attachedFileIds = fields.attachedFileIds;
/** @type {undefined | Promise<ChatCompletion>} */ /** @type {undefined | Promise<ChatCompletion>} */
this.visionPromise = fields.visionPromise; this.visionPromise = fields.visionPromise;
/** @type {boolean} */ /** @type {number} */
this.savedInitialMessage = false; this.streamRate = fields.streamRate ?? Constants.DEFAULT_STREAM_RATE;
/** /**
* @type {Object.<AssistantStreamEvents, (event: AssistantStreamEvent) => Promise<void>>} * @type {Object.<AssistantStreamEvents, (event: AssistantStreamEvent) => Promise<void>>}
@ -139,11 +141,11 @@ class StreamRunManager {
return this.intermediateText; return this.intermediateText;
} }
/** Saves the initial intermediate message /** Returns the current, intermediate message
* @returns {Promise<void>} * @returns {TMessage}
*/ */
async saveInitialMessage() { getIntermediateMessage() {
return saveMessage(this.req, { return {
conversationId: this.finalMessage.conversationId, conversationId: this.finalMessage.conversationId,
messageId: this.finalMessage.messageId, messageId: this.finalMessage.messageId,
parentMessageId: this.parentMessageId, parentMessageId: this.parentMessageId,
@ -155,7 +157,7 @@ class StreamRunManager {
sender: 'Assistant', sender: 'Assistant',
unfinished: true, unfinished: true,
error: false, error: false,
}); };
} }
/* <------------------ Main Event Handlers ------------------> */ /* <------------------ Main Event Handlers ------------------> */
@ -347,6 +349,8 @@ class StreamRunManager {
type: ContentTypes.TOOL_CALL, type: ContentTypes.TOOL_CALL,
index, index,
}); });
await sleep(this.streamRate);
} }
}; };
@ -444,6 +448,7 @@ class StreamRunManager {
if (content && content.type === MessageContentTypes.TEXT) { if (content && content.type === MessageContentTypes.TEXT) {
this.intermediateText += content.text.value; this.intermediateText += content.text.value;
onProgress(content.text.value); onProgress(content.text.value);
await sleep(this.streamRate);
} }
} }
@ -589,21 +594,14 @@ class StreamRunManager {
const index = this.getStepIndex(stepKey); const index = this.getStepIndex(stepKey);
this.orderedRunSteps.set(index, message_creation); this.orderedRunSteps.set(index, message_creation);
const messageCache = getLogStores(CacheKeys.MESSAGES);
// Create the Factory Function to stream the message // Create the Factory Function to stream the message
const { onProgress: progressCallback } = createOnProgress({ const { onProgress: progressCallback } = createOnProgress({
onProgress: throttle( onProgress: throttle(
() => { () => {
if (!this.savedInitialMessage) { messageCache.set(this.finalMessage.messageId, this.getText(), Time.FIVE_MINUTES);
this.saveInitialMessage();
this.savedInitialMessage = true;
} else {
updateMessageText({
messageId: this.finalMessage.messageId,
text: this.getText(),
});
}
}, },
2000, 3000,
{ trailing: false }, { trailing: false },
), ),
}); });
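
Taken together, the StreamRunManager changes replace per-chunk database writes with a throttled cache write and pace delivery with the configured stream rate. A minimal sketch of that shape follows; the Map-backed cache and the createPacedProgress name are stand-ins for the app's getLogStores(CacheKeys.MESSAGES) store and createOnProgress helper, and the 3 s window mirrors the throttle interval in the diff.

import throttle from 'lodash/throttle';
import { Constants } from 'librechat-data-provider';

// Illustrative stand-ins for the app's cache and sleep helpers.
const messageCache = new Map<string, string>();
const sleep = (ms: number) => new Promise<void>((resolve) => setTimeout(resolve, ms));

/** Sketch: accumulate deltas, persist partial text at most every 3 s, pace chunks by streamRate. */
function createPacedProgress(messageId: string, streamRate: number = Constants.DEFAULT_STREAM_RATE) {
  let text = '';
  const persist = throttle(() => messageCache.set(messageId, text), 3000, { trailing: false });

  return async function onDelta(delta: string) {
    text += delta;
    persist(); // partial text goes to a cache, not the database
    await sleep(streamRate); // delays the next chunk; higher values slow the visible stream
  };
}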

View file

@ -51,6 +51,7 @@ function assistantsConfigSetup(config, assistantsEndpoint, prevConfig = {}) {
excludedIds: parsedConfig.excludedIds, excludedIds: parsedConfig.excludedIds,
privateAssistants: parsedConfig.privateAssistants, privateAssistants: parsedConfig.privateAssistants,
timeoutMs: parsedConfig.timeoutMs, timeoutMs: parsedConfig.timeoutMs,
streamRate: parsedConfig.streamRate,
}; };
} }
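
The same streamRate field is threaded through the config setup above. One plausible way the server could resolve an effective rate per endpoint is sketched below; the helper name and the precedence (endpoint-specific value over the new `all` block over the library default) are assumptions for illustration.

import { Constants, EModelEndpoint } from 'librechat-data-provider';
import type { TBaseEndpoint } from 'librechat-data-provider';

type EndpointsConfig = { all?: TBaseEndpoint } & Partial<Record<EModelEndpoint, TBaseEndpoint>>;

// Assumed precedence: per-endpoint value, then endpoints.all, then the default.
function resolveStreamRate(endpoints: EndpointsConfig | undefined, endpoint: EModelEndpoint): number {
  return (
    endpoints?.[endpoint]?.streamRate ??
    endpoints?.all?.streamRate ??
    Constants.DEFAULT_STREAM_RATE
  );
}

// Example: { all: { streamRate: 25 }, [EModelEndpoint.anthropic]: { streamRate: 40 } } resolves to 40 for Anthropic.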

View file

@ -465,6 +465,12 @@
* @memberof typedefs * @memberof typedefs
*/ */
/**
* @exports TBaseEndpoint
* @typedef {import('librechat-data-provider').TBaseEndpoint} TBaseEndpoint
* @memberof typedefs
*/
/** /**
* @exports TEndpoint * @exports TEndpoint
* @typedef {import('librechat-data-provider').TEndpoint} TEndpoint * @typedef {import('librechat-data-provider').TEndpoint} TEndpoint

package-lock.json generated
View file

@ -29437,7 +29437,7 @@
}, },
"packages/data-provider": { "packages/data-provider": {
"name": "librechat-data-provider", "name": "librechat-data-provider",
"version": "0.7.1", "version": "0.7.2",
"license": "ISC", "license": "ISC",
"dependencies": { "dependencies": {
"@types/js-yaml": "^4.0.9", "@types/js-yaml": "^4.0.9",

View file

@ -1,6 +1,6 @@
{ {
"name": "librechat-data-provider", "name": "librechat-data-provider",
"version": "0.7.1", "version": "0.7.2",
"description": "data services for librechat apps", "description": "data services for librechat apps",
"main": "dist/index.js", "main": "dist/index.js",
"module": "dist/index.es.js", "module": "dist/index.es.js",

View file

@ -136,7 +136,14 @@ export const defaultAssistantsVersion = {
[EModelEndpoint.azureAssistants]: 1, [EModelEndpoint.azureAssistants]: 1,
}; };
export const assistantEndpointSchema = z.object({ export const baseEndpointSchema = z.object({
streamRate: z.number().optional(),
});
export type TBaseEndpoint = z.infer<typeof baseEndpointSchema>;
export const assistantEndpointSchema = baseEndpointSchema.merge(
z.object({
/* assistants specific */ /* assistants specific */
disableBuilder: z.boolean().optional(), disableBuilder: z.boolean().optional(),
pollIntervalMs: z.number().optional(), pollIntervalMs: z.number().optional(),
@ -170,11 +177,13 @@ export const assistantEndpointSchema = z.object({
titleMethod: z.union([z.literal('completion'), z.literal('functions')]).optional(), titleMethod: z.union([z.literal('completion'), z.literal('functions')]).optional(),
titleModel: z.string().optional(), titleModel: z.string().optional(),
headers: z.record(z.any()).optional(), headers: z.record(z.any()).optional(),
}); }),
);
export type TAssistantEndpoint = z.infer<typeof assistantEndpointSchema>; export type TAssistantEndpoint = z.infer<typeof assistantEndpointSchema>;
export const endpointSchema = z.object({ export const endpointSchema = baseEndpointSchema.merge(
z.object({
name: z.string().refine((value) => !eModelEndpointSchema.safeParse(value).success, { name: z.string().refine((value) => !eModelEndpointSchema.safeParse(value).success, {
message: `Value cannot be one of the default endpoint (EModelEndpoint) values: ${Object.values( message: `Value cannot be one of the default endpoint (EModelEndpoint) values: ${Object.values(
EModelEndpoint, EModelEndpoint,
@ -200,7 +209,8 @@ export const endpointSchema = z.object({
customOrder: z.number().optional(), customOrder: z.number().optional(),
directEndpoint: z.boolean().optional(), directEndpoint: z.boolean().optional(),
titleMessageRole: z.string().optional(), titleMessageRole: z.string().optional(),
}); }),
);
export type TEndpoint = z.infer<typeof endpointSchema>; export type TEndpoint = z.infer<typeof endpointSchema>;
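
The merge pattern above is what lets every endpoint schema, including custom ones, accept streamRate without repeating the field. A stripped-down re-creation with placeholder fields, rather than the full endpoint schema:

import { z } from 'zod';

// Shared base: any endpoint may set how quickly tokens are forwarded (in ms).
const baseEndpointSchema = z.object({ streamRate: z.number().optional() });

// Merging keeps the base fields alongside endpoint-specific ones.
const customEndpointSchema = baseEndpointSchema.merge(
  z.object({ name: z.string(), baseURL: z.string().url() }),
);

const parsed = customEndpointSchema.parse({
  name: 'my-provider', // placeholder values
  baseURL: 'https://example.com/v1',
  streamRate: 35,
});
// parsed.streamRate === 35; omitting the field is also valid because it is optional.
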
@ -213,6 +223,7 @@ export const azureEndpointSchema = z
.and( .and(
endpointSchema endpointSchema
.pick({ .pick({
streamRate: true,
titleConvo: true, titleConvo: true,
titleMethod: true, titleMethod: true,
titleModel: true, titleModel: true,
@ -426,10 +437,15 @@ export const configSchema = z.object({
modelSpecs: specsConfigSchema.optional(), modelSpecs: specsConfigSchema.optional(),
endpoints: z endpoints: z
.object({ .object({
all: baseEndpointSchema.optional(),
[EModelEndpoint.openAI]: baseEndpointSchema.optional(),
[EModelEndpoint.google]: baseEndpointSchema.optional(),
[EModelEndpoint.anthropic]: baseEndpointSchema.optional(),
[EModelEndpoint.gptPlugins]: baseEndpointSchema.optional(),
[EModelEndpoint.azureOpenAI]: azureEndpointSchema.optional(), [EModelEndpoint.azureOpenAI]: azureEndpointSchema.optional(),
[EModelEndpoint.azureAssistants]: assistantEndpointSchema.optional(), [EModelEndpoint.azureAssistants]: assistantEndpointSchema.optional(),
[EModelEndpoint.assistants]: assistantEndpointSchema.optional(), [EModelEndpoint.assistants]: assistantEndpointSchema.optional(),
custom: z.array(endpointSchema.partial()).optional(), [EModelEndpoint.custom]: z.array(endpointSchema.partial()).optional(),
}) })
.strict() .strict()
.refine((data) => Object.keys(data).length > 0, { .refine((data) => Object.keys(data).length > 0, {
@ -657,6 +673,16 @@ export enum InfiniteCollections {
SHARED_LINKS = 'sharedLinks', SHARED_LINKS = 'sharedLinks',
} }
/**
* Enum for time intervals
*/
export enum Time {
THIRTY_MINUTES = 1800000,
TEN_MINUTES = 600000,
FIVE_MINUTES = 300000,
TWO_MINUTES = 120000,
}
/** /**
* Enum for cache keys. * Enum for cache keys.
*/ */
@ -727,6 +753,10 @@ export enum CacheKeys {
* Key for the cached audio run Ids. * Key for the cached audio run Ids.
*/ */
AUDIO_RUNS = 'audioRuns', AUDIO_RUNS = 'audioRuns',
/**
* Key for in-progress messages.
*/
MESSAGES = 'messages',
} }
/** /**
@ -911,6 +941,8 @@ export enum Constants {
COMMON_DIVIDER = '__', COMMON_DIVIDER = '__',
/** Max length for commands */ /** Max length for commands */
COMMANDS_MAX_LENGTH = 56, COMMANDS_MAX_LENGTH = 56,
/** Default Stream Rate (ms) */
DEFAULT_STREAM_RATE = 1,
} }
export enum LocalStorageKeys { export enum LocalStorageKeys {
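
Finally, the new Time and CacheKeys entries are what the streaming path above leans on when it stashes partial text: the message store is namespaced under CacheKeys.MESSAGES and entries expire after Time.FIVE_MINUTES. A minimal sketch, assuming a Keyv-style store in place of the app's getLogStores helper:

import Keyv from 'keyv';
import { CacheKeys, Time } from 'librechat-data-provider';

// Illustrative stand-in for getLogStores(CacheKeys.MESSAGES).
const messageCache = new Keyv<string>({ namespace: CacheKeys.MESSAGES });

async function cachePartialText(messageId: string, text: string) {
  // Keyv's third argument is a TTL in milliseconds; the Time enum keeps the value readable.
  await messageCache.set(messageId, text, Time.FIVE_MINUTES);
}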