⚙️ feat: Adjust Rate of Stream Progress (#3244)

* chore: bump data-provider and add MESSAGES CacheKey

* refactor: avoid saving messages while streaming, save partial text to cache instead

* fix(ci): processChunks

* chore: log aborted requests at debug level

* feat: set stream rate for token processing

* chore: specify default stream rate

* fix(ci): Update AppService.js to use optional chaining for endpointLocals assignment

* refactor: abstract the error handler

* feat: streamRate for assistants; refactor: update default rate for token processing

* refactor: update error handling in assistants/errors.js

* refactor: update error handling in assistants/errors.js
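
In short: while a response is streaming, partial text is no longer saved to the database on every progress tick; it is written to a MESSAGES cache entry with a five-minute TTL, and token emission is paced by a configurable streamRate (set per endpoint, or globally via an "all" endpoint config). The recurring cache pattern across the controllers below looks roughly like this sketch, assuming LibreChat's cache helpers are in scope:

  const throttle = require('lodash/throttle');
  const { CacheKeys, Time } = require('librechat-data-provider');
  const { getLogStores } = require('~/cache');

  const messageCache = getLogStores(CacheKeys.MESSAGES);

  // Replaces the old throttled saveMessage(...) call in the streaming progress callback:
  const cachePartial = throttle(
    (messageId, partialText) => messageCache.set(messageId, partialText, Time.FIVE_MINUTES),
    3000,
    { trailing: false },
  );
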
Danny Avila 2024-07-17 10:47:17 -04:00 committed by GitHub
parent 1c282d1517
commit 5d40d0a37a
29 changed files with 661 additions and 309 deletions

View file

@ -1,7 +1,8 @@
const throttle = require('lodash/throttle');
const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider');
const { getResponseSender, Constants, CacheKeys, Time } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { getLogStores } = require('~/cache');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
@ -51,11 +52,13 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
try {
const { client } = await initializeClient({ req, res, endpointOption });
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
const messageCache = getLogStores(CacheKeys.MESSAGES);
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: throttle(
({ text: partialText }) => {
saveMessage(req, {
/*
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
messageCache.set(responseMessageId, {
messageId: responseMessageId,
sender,
conversationId,
@ -65,7 +68,10 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
unfinished,
error: false,
user,
});
}, Time.FIVE_MINUTES);
*/
messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES);
},
3000,
{ trailing: false },
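
The { trailing: false } option matters here: lodash's throttle then fires only on the leading edge, so at most one cache write happens per 3-second window and no extra write fires after the stream goes quiet (the finished response is still persisted separately once the request completes). A small illustration of that behavior:

  const throttle = require('lodash/throttle');

  let writes = 0;
  const write = throttle(() => { writes += 1; }, 3000, { trailing: false });

  write(); // fires immediately (leading edge), writes === 1
  write(); // ignored, still inside the 3-second window
  write(); // ignored; with the default trailing: true, this last call would fire at the 3s mark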

View file

@ -1,7 +1,8 @@
const throttle = require('lodash/throttle');
const { getResponseSender, EModelEndpoint } = require('librechat-data-provider');
const { getResponseSender, CacheKeys, Time } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { getLogStores } = require('~/cache');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
@ -51,12 +52,14 @@ const EditController = async (req, res, next, initializeClient) => {
}
};
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
const messageCache = getLogStores(CacheKeys.MESSAGES);
const { onProgress: progressCallback, getPartialText } = createOnProgress({
generation,
onProgress: throttle(
({ text: partialText }) => {
saveMessage(req, {
/*
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
{
messageId: responseMessageId,
sender,
conversationId,
@ -67,7 +70,8 @@ const EditController = async (req, res, next, initializeClient) => {
isEdited: true,
error: false,
user,
});
} */
messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES);
},
3000,
{ trailing: false },

View file

@ -1,12 +1,12 @@
const { v4 } = require('uuid');
const {
Time,
Constants,
RunStatus,
CacheKeys,
ContentTypes,
ToolCallTypes,
EModelEndpoint,
ViolationTypes,
retrievalMimeTypes,
AssistantStreamEvents,
} = require('librechat-data-provider');
@ -14,12 +14,12 @@ const {
initThread,
recordUsage,
saveUserMessage,
checkMessageGaps,
addThreadMetadata,
saveAssistantMessage,
} = require('~/server/services/Threads');
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const { sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { createErrorHandler } = require('~/server/controllers/assistants/errors');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
@ -44,7 +44,7 @@ const ten_minutes = 1000 * 60 * 10;
const chatV2 = async (req, res) => {
logger.debug('[/assistants/chat/] req.body', req.body);
/** @type {{ files: MongoFile[]}} */
/** @type {{files: MongoFile[]}} */
const {
text,
model,
@ -90,140 +90,20 @@ const chatV2 = async (req, res) => {
/** @type {Run | undefined} - The completed run, undefined if incomplete */
let completedRun;
const handleError = async (error) => {
const defaultErrorMessage =
'The Assistant run failed to initialize. Try sending a message in a new conversation.';
const messageData = {
thread_id,
assistant_id,
conversationId,
parentMessageId,
sender: 'System',
user: req.user.id,
shouldSaveMessage: false,
messageId: responseMessageId,
endpoint,
};
const getContext = () => ({
openai,
run_id,
endpoint,
cacheKey,
thread_id,
completedRun,
assistant_id,
conversationId,
parentMessageId,
responseMessageId,
});
if (error.message === 'Run cancelled') {
return res.end();
} else if (error.message === 'Request closed' && completedRun) {
return;
} else if (error.message === 'Request closed') {
logger.debug('[/assistants/chat/] Request aborted on close');
} else if (/Files.*are invalid/.test(error.message)) {
const errorMessage = `Files are invalid, or may not have uploaded yet.${
endpoint === EModelEndpoint.azureAssistants
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
: ''
}`;
return sendResponse(req, res, messageData, errorMessage);
} else if (error?.message?.includes('string too long')) {
return sendResponse(
req,
res,
messageData,
'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
);
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
return sendResponse(req, res, messageData, error.message);
} else {
logger.error('[/assistants/chat/]', error);
}
if (!openai || !thread_id || !run_id) {
return sendResponse(req, res, messageData, defaultErrorMessage);
}
await sleep(2000);
try {
const status = await cache.get(cacheKey);
if (status === 'cancelled') {
logger.debug('[/assistants/chat/] Run already cancelled');
return res.end();
}
await cache.delete(cacheKey);
const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
logger.debug('[/assistants/chat/] Cancelled run:', cancelledRun);
} catch (error) {
logger.error('[/assistants/chat/] Error cancelling run', error);
}
await sleep(2000);
let run;
try {
run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
await recordUsage({
...run.usage,
model: run.model,
user: req.user.id,
conversationId,
});
} catch (error) {
logger.error('[/assistants/chat/] Error fetching or processing run', error);
}
let finalEvent;
try {
const runMessages = await checkMessageGaps({
openai,
run_id,
endpoint,
thread_id,
conversationId,
latestMessageId: responseMessageId,
});
const errorContentPart = {
text: {
value:
error?.message ?? 'There was an error processing your request. Please try again later.',
},
type: ContentTypes.ERROR,
};
if (!Array.isArray(runMessages[runMessages.length - 1]?.content)) {
runMessages[runMessages.length - 1].content = [errorContentPart];
} else {
const contentParts = runMessages[runMessages.length - 1].content;
for (let i = 0; i < contentParts.length; i++) {
const currentPart = contentParts[i];
/** @type {CodeToolCall | RetrievalToolCall | FunctionToolCall | undefined} */
const toolCall = currentPart?.[ContentTypes.TOOL_CALL];
if (
toolCall &&
toolCall?.function &&
!(toolCall?.function?.output || toolCall?.function?.output?.length)
) {
contentParts[i] = {
...currentPart,
[ContentTypes.TOOL_CALL]: {
...toolCall,
function: {
...toolCall.function,
output: 'error processing tool',
},
},
};
}
}
runMessages[runMessages.length - 1].content.push(errorContentPart);
}
finalEvent = {
final: true,
conversation: await getConvo(req.user.id, conversationId),
runMessages,
};
} catch (error) {
logger.error('[/assistants/chat/] Error finalizing error process', error);
return sendResponse(req, res, messageData, 'The Assistant run failed');
}
return sendResponse(req, res, finalEvent);
};
const handleError = createErrorHandler({ req, res, getContext });
try {
res.on('close', async () => {
@ -490,6 +370,11 @@ const chatV2 = async (req, res) => {
},
};
/** @type {undefined | TAssistantEndpoint} */
const config = req.app.locals[endpoint] ?? {};
/** @type {undefined | TBaseEndpoint} */
const allConfig = req.app.locals.all;
const streamRunManager = new StreamRunManager({
req,
res,
@ -499,6 +384,7 @@ const chatV2 = async (req, res) => {
attachedFileIds,
parentMessageId: userMessageId,
responseMessage: openai.responseMessage,
streamRate: allConfig?.streamRate ?? config.streamRate,
// streamOptions: {
// },
@ -511,6 +397,16 @@ const chatV2 = async (req, res) => {
response = streamRunManager;
response.text = streamRunManager.intermediateText;
const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(
responseMessageId,
{
complete: true,
text: response.text,
},
Time.FIVE_MINUTES,
);
};
await processRun();
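
Note the precedence in the streamRate passed to StreamRunManager above: a global endpoints.all value wins, and the endpoint's own streamRate (assistants or azureAssistants) is only a fallback. A tiny sketch with illustrative values:

  // Hypothetical app.locals contents, for illustration only:
  const locals = { assistants: { streamRate: 25 }, all: { streamRate: 10 } };

  const config = locals.assistants ?? {};
  const allConfig = locals.all;
  const streamRate = allConfig?.streamRate ?? config.streamRate; // 10: `all` wins whenever it defines a rate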

View file

@ -0,0 +1,193 @@
// errorHandler.js
const { sendResponse } = require('~/server/utils');
const { logger } = require('~/config');
const getLogStores = require('~/cache/getLogStores');
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
const { getConvo } = require('~/models/Conversation');
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
/**
* @typedef {Object} ErrorHandlerContext
* @property {OpenAIClient} openai - The OpenAI client
* @property {string} thread_id - The thread ID
* @property {string} run_id - The run ID
* @property {boolean} completedRun - Whether the run has completed
* @property {string} assistant_id - The assistant ID
* @property {string} conversationId - The conversation ID
* @property {string} parentMessageId - The parent message ID
* @property {string} responseMessageId - The response message ID
* @property {string} endpoint - The endpoint being used
* @property {string} cacheKey - The cache key for the current request
*/
/**
* @typedef {Object} ErrorHandlerDependencies
* @property {Express.Request} req - The Express request object
* @property {Express.Response} res - The Express response object
* @property {() => ErrorHandlerContext} getContext - Function to get the current context
* @property {string} [originPath] - The origin path for the error handler
*/
/**
* Creates an error handler function with the given dependencies
* @param {ErrorHandlerDependencies} dependencies - The dependencies for the error handler
* @returns {(error: Error) => Promise<void>} The error handler function
*/
const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/chat/' }) => {
const cache = getLogStores(CacheKeys.ABORT_KEYS);
/**
* Handles errors that occur during the chat process
* @param {Error} error - The error that occurred
* @returns {Promise<void>}
*/
return async (error) => {
const {
openai,
run_id,
endpoint,
cacheKey,
thread_id,
completedRun,
assistant_id,
conversationId,
parentMessageId,
responseMessageId,
} = getContext();
const defaultErrorMessage =
'The Assistant run failed to initialize. Try sending a message in a new conversation.';
const messageData = {
thread_id,
assistant_id,
conversationId,
parentMessageId,
sender: 'System',
user: req.user.id,
shouldSaveMessage: false,
messageId: responseMessageId,
endpoint,
};
if (error.message === 'Run cancelled') {
return res.end();
} else if (error.message === 'Request closed' && completedRun) {
return;
} else if (error.message === 'Request closed') {
logger.debug(`[${originPath}] Request aborted on close`);
} else if (/Files.*are invalid/.test(error.message)) {
const errorMessage = `Files are invalid, or may not have uploaded yet.${
endpoint === 'azureAssistants'
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
: ''
}`;
return sendResponse(req, res, messageData, errorMessage);
} else if (error?.message?.includes('string too long')) {
return sendResponse(
req,
res,
messageData,
'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
);
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
return sendResponse(req, res, messageData, error.message);
} else {
logger.error(`[${originPath}]`, error);
}
if (!openai || !thread_id || !run_id) {
return sendResponse(req, res, messageData, defaultErrorMessage);
}
await new Promise((resolve) => setTimeout(resolve, 2000));
try {
const status = await cache.get(cacheKey);
if (status === 'cancelled') {
logger.debug(`[${originPath}] Run already cancelled`);
return res.end();
}
await cache.delete(cacheKey);
const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
logger.debug(`[${originPath}] Cancelled run:`, cancelledRun);
} catch (error) {
logger.error(`[${originPath}] Error cancelling run`, error);
}
await new Promise((resolve) => setTimeout(resolve, 2000));
let run;
try {
run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
await recordUsage({
...run.usage,
model: run.model,
user: req.user.id,
conversationId,
});
} catch (error) {
logger.error(`[${originPath}] Error fetching or processing run`, error);
}
let finalEvent;
try {
const runMessages = await checkMessageGaps({
openai,
run_id,
endpoint,
thread_id,
conversationId,
latestMessageId: responseMessageId,
});
const errorContentPart = {
text: {
value:
error?.message ?? 'There was an error processing your request. Please try again later.',
},
type: ContentTypes.ERROR,
};
if (!Array.isArray(runMessages[runMessages.length - 1]?.content)) {
runMessages[runMessages.length - 1].content = [errorContentPart];
} else {
const contentParts = runMessages[runMessages.length - 1].content;
for (let i = 0; i < contentParts.length; i++) {
const currentPart = contentParts[i];
/** @type {CodeToolCall | RetrievalToolCall | FunctionToolCall | undefined} */
const toolCall = currentPart?.[ContentTypes.TOOL_CALL];
if (
toolCall &&
toolCall?.function &&
!(toolCall?.function?.output || toolCall?.function?.output?.length)
) {
contentParts[i] = {
...currentPart,
[ContentTypes.TOOL_CALL]: {
...toolCall,
function: {
...toolCall.function,
output: 'error processing tool',
},
},
};
}
}
runMessages[runMessages.length - 1].content.push(errorContentPart);
}
finalEvent = {
final: true,
conversation: await getConvo(req.user.id, conversationId),
runMessages,
};
} catch (error) {
logger.error(`[${originPath}] Error finalizing error process`, error);
return sendResponse(req, res, messageData, 'The Assistant run failed');
}
return sendResponse(req, res, finalEvent);
};
};
module.exports = { createErrorHandler };
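
As the chatV2 controller above shows, a caller supplies a getContext closure so the handler always reads the latest run and thread identifiers at error time, then uses the returned function wherever an error can surface. A condensed usage sketch (the in-scope variables are the same ones the controller gathers):

  const { createErrorHandler } = require('~/server/controllers/assistants/errors');

  const getContext = () => ({
    // same fields the chatV2 controller collects above:
    openai, run_id, endpoint, cacheKey, thread_id, completedRun,
    assistant_id, conversationId, parentMessageId, responseMessageId,
  });

  const handleError = createErrorHandler({ req, res, getContext });

  try {
    // ... initialize the thread and run the assistant ...
  } catch (error) {
    await handleError(error);
  }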

View file

@ -30,7 +30,10 @@ async function abortMessage(req, res) {
return res.status(204).send({ message: 'Request not found' });
}
const finalEvent = await abortController.abortCompletion();
logger.info('[abortMessage] Aborted request', { abortKey });
logger.debug(
`[abortMessage] ID: ${req.user.id} | ${req.user.email} | Aborted request: ` +
JSON.stringify({ abortKey }),
);
abortControllers.delete(abortKey);
if (res.headersSent && finalEvent) {

View file

@ -1,10 +1,11 @@
const express = require('express');
const throttle = require('lodash/throttle');
const { getResponseSender, Constants } = require('librechat-data-provider');
const { getResponseSender, Constants, CacheKeys, Time } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { addTitle } = require('~/server/services/Endpoints/openAI');
const { saveMessage } = require('~/models');
const { getLogStores } = require('~/cache');
const {
handleAbort,
createAbortController,
@ -71,7 +72,8 @@ router.post(
}
};
const throttledSaveMessage = throttle(saveMessage, 3000, { trailing: false });
const messageCache = getLogStores(CacheKeys.MESSAGES);
const throttledSetMessage = throttle(messageCache.set, 3000, { trailing: false });
let streaming = null;
let timer = null;
@ -85,7 +87,8 @@ router.post(
clearTimeout(timer);
}
throttledSaveMessage(req, {
/*
{
messageId: responseMessageId,
sender,
conversationId,
@ -96,7 +99,9 @@ router.post(
error: false,
plugins,
user,
});
}
*/
throttledSetMessage(responseMessageId, partialText, Time.FIVE_MINUTES);
streaming = new Promise((resolve) => {
timer = setTimeout(() => {

View file

@ -1,19 +1,20 @@
const express = require('express');
const throttle = require('lodash/throttle');
const { getResponseSender } = require('librechat-data-provider');
const { getResponseSender, CacheKeys, Time } = require('librechat-data-provider');
const {
handleAbort,
createAbortController,
handleAbortError,
setHeaders,
handleAbort,
moderateText,
validateModel,
handleAbortError,
validateEndpoint,
buildEndpointOption,
moderateText,
createAbortController,
} = require('~/server/middleware');
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { saveMessage } = require('~/models');
const { getLogStores } = require('~/cache');
const { validateTools } = require('~/app');
const { logger } = require('~/config');
@ -79,7 +80,8 @@ router.post(
}
};
const throttledSaveMessage = throttle(saveMessage, 3000, { trailing: false });
const messageCache = getLogStores(CacheKeys.MESSAGES);
const throttledSetMessage = throttle(messageCache.set, 3000, { trailing: false });
const {
onProgress: progressCallback,
sendIntermediateMessage,
@ -91,7 +93,8 @@ router.post(
plugin.loading = false;
}
throttledSaveMessage(req, {
/*
{
messageId: responseMessageId,
sender,
conversationId,
@ -102,7 +105,9 @@ router.post(
isEdited: true,
error: false,
user,
});
}
*/
throttledSetMessage(responseMessageId, partialText, Time.FIVE_MINUTES);
},
});

View file

@ -67,17 +67,18 @@ const AppService = async (app) => {
handleRateLimits(config?.rateLimits);
const endpointLocals = {};
const endpoints = config?.endpoints;
if (config?.endpoints?.[EModelEndpoint.azureOpenAI]) {
if (endpoints?.[EModelEndpoint.azureOpenAI]) {
endpointLocals[EModelEndpoint.azureOpenAI] = azureConfigSetup(config);
checkAzureVariables();
}
if (config?.endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) {
if (endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) {
endpointLocals[EModelEndpoint.azureAssistants] = azureAssistantsDefaults();
}
if (config?.endpoints?.[EModelEndpoint.azureAssistants]) {
if (endpoints?.[EModelEndpoint.azureAssistants]) {
endpointLocals[EModelEndpoint.azureAssistants] = assistantsConfigSetup(
config,
EModelEndpoint.azureAssistants,
@ -85,7 +86,7 @@ const AppService = async (app) => {
);
}
if (config?.endpoints?.[EModelEndpoint.assistants]) {
if (endpoints?.[EModelEndpoint.assistants]) {
endpointLocals[EModelEndpoint.assistants] = assistantsConfigSetup(
config,
EModelEndpoint.assistants,
@ -93,6 +94,19 @@ const AppService = async (app) => {
);
}
if (endpoints?.[EModelEndpoint.openAI]) {
endpointLocals[EModelEndpoint.openAI] = endpoints[EModelEndpoint.openAI];
}
if (endpoints?.[EModelEndpoint.google]) {
endpointLocals[EModelEndpoint.google] = endpoints[EModelEndpoint.google];
}
if (endpoints?.[EModelEndpoint.anthropic]) {
endpointLocals[EModelEndpoint.anthropic] = endpoints[EModelEndpoint.anthropic];
}
if (endpoints?.[EModelEndpoint.gptPlugins]) {
endpointLocals[EModelEndpoint.gptPlugins] = endpoints[EModelEndpoint.gptPlugins];
}
app.locals = {
...defaultLocals,
modelSpecs: config.modelSpecs,
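
With these additions, each configured endpoint block (openAI, google, anthropic, gptPlugins, plus the Azure and assistants setups above) is copied onto app.locals, which is where the client initializers later read per-endpoint options such as streamRate. A condensed, illustrative equivalent of the four new blocks:

  const endpointLocals = {};
  for (const key of ['openAI', 'google', 'anthropic', 'gptPlugins']) {
    if (config?.endpoints?.[key]) {
      endpointLocals[key] = config.endpoints[key];
    }
  }
  // e.g. endpoints.openAI.streamRate: 25 in the config becomes app.locals.openAI.streamRate === 25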

View file

@ -19,11 +19,27 @@ const initializeClient = async ({ req, res, endpointOption }) => {
checkUserKeyExpiry(expiresAt, EModelEndpoint.anthropic);
}
const clientOptions = {};
/** @type {undefined | TBaseEndpoint} */
const anthropicConfig = req.app.locals[EModelEndpoint.anthropic];
if (anthropicConfig) {
clientOptions.streamRate = anthropicConfig.streamRate;
}
/** @type {undefined | TBaseEndpoint} */
const allConfig = req.app.locals.all;
if (allConfig) {
clientOptions.streamRate = allConfig.streamRate;
}
const client = new AnthropicClient(anthropicApiKey, {
req,
res,
reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
proxy: PROXY ?? null,
...clientOptions,
...endpointOption,
});
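
Because allConfig.streamRate is assigned after the Anthropic-specific value, a global endpoints.all.streamRate overrides the endpoint's own setting; the Google, gptPlugins, and OpenAI initializers below apply the same order (whereas the assistants controller reaches the same precedence with a ?? fallback). A condensed sketch with illustrative values:

  const locals = { anthropic: { streamRate: 35 }, all: { streamRate: 10 } };

  const clientOptions = {};
  if (locals.anthropic) {
    clientOptions.streamRate = locals.anthropic.streamRate; // 35
  }
  if (locals.all) {
    clientOptions.streamRate = locals.all.streamRate; // overridden to 10
  }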

View file

@ -114,9 +114,16 @@ const initializeClient = async ({ req, res, endpointOption }) => {
contextStrategy: endpointConfig.summarize ? 'summarize' : null,
directEndpoint: endpointConfig.directEndpoint,
titleMessageRole: endpointConfig.titleMessageRole,
streamRate: endpointConfig.streamRate,
endpointTokenConfig,
};
/** @type {undefined | TBaseEndpoint} */
const allConfig = req.app.locals.all;
if (allConfig) {
customOptions.streamRate = allConfig.streamRate;
}
const clientOptions = {
reverseProxyUrl: baseURL ?? null,
proxy: PROXY ?? null,

View file

@ -27,11 +27,27 @@ const initializeClient = async ({ req, res, endpointOption }) => {
[AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY,
};
const clientOptions = {};
/** @type {undefined | TBaseEndpoint} */
const allConfig = req.app.locals.all;
/** @type {undefined | TBaseEndpoint} */
const googleConfig = req.app.locals[EModelEndpoint.google];
if (googleConfig) {
clientOptions.streamRate = googleConfig.streamRate;
}
if (allConfig) {
clientOptions.streamRate = allConfig.streamRate;
}
const client = new GoogleClient(credentials, {
req,
res,
reverseProxyUrl: GOOGLE_REVERSE_PROXY ?? null,
proxy: PROXY ?? null,
...clientOptions,
...endpointOption,
});

View file

@ -8,6 +8,8 @@ jest.mock('~/server/services/UserService', () => ({
getUserKey: jest.fn().mockImplementation(() => ({})),
}));
const app = { locals: {} };
describe('google/initializeClient', () => {
afterEach(() => {
jest.clearAllMocks();
@ -23,6 +25,7 @@ describe('google/initializeClient', () => {
const req = {
body: { key: expiresAt },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
@ -44,6 +47,7 @@ describe('google/initializeClient', () => {
const req = {
body: { key: null },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
@ -66,6 +70,7 @@ describe('google/initializeClient', () => {
const req = {
body: { key: expiresAt },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };

View file

@ -86,6 +86,9 @@ const initializeClient = async ({ req, res, endpointOption }) => {
clientOptions.titleModel = azureConfig.titleModel;
clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion';
const azureRate = modelName.includes('gpt-4') ? 30 : 17;
clientOptions.streamRate = azureConfig.streamRate ?? azureRate;
const groupName = modelGroupMap[modelName].group;
clientOptions.addParams = azureConfig.groupMap[groupName].addParams;
clientOptions.dropParams = azureConfig.groupMap[groupName].dropParams;
@ -98,6 +101,19 @@ const initializeClient = async ({ req, res, endpointOption }) => {
apiKey = clientOptions.azure.azureOpenAIApiKey;
}
/** @type {undefined | TBaseEndpoint} */
const pluginsConfig = req.app.locals[EModelEndpoint.gptPlugins];
if (!useAzure && pluginsConfig) {
clientOptions.streamRate = pluginsConfig.streamRate;
}
/** @type {undefined | TBaseEndpoint} */
const allConfig = req.app.locals.all;
if (allConfig) {
clientOptions.streamRate = allConfig.streamRate;
}
if (!apiKey) {
throw new Error(`${endpoint} API key not provided. Please provide it again.`);
}

View file

@ -76,6 +76,10 @@ const initializeClient = async ({ req, res, endpointOption }) => {
clientOptions.titleConvo = azureConfig.titleConvo;
clientOptions.titleModel = azureConfig.titleModel;
const azureRate = modelName.includes('gpt-4') ? 30 : 17;
clientOptions.streamRate = azureConfig.streamRate ?? azureRate;
clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion';
const groupName = modelGroupMap[modelName].group;
@ -90,6 +94,19 @@ const initializeClient = async ({ req, res, endpointOption }) => {
apiKey = clientOptions.azure.azureOpenAIApiKey;
}
/** @type {undefined | TBaseEndpoint} */
const openAIConfig = req.app.locals[EModelEndpoint.openAI];
if (!isAzureOpenAI && openAIConfig) {
clientOptions.streamRate = openAIConfig.streamRate;
}
/** @type {undefined | TBaseEndpoint} */
const allConfig = req.app.locals.all;
if (allConfig) {
clientOptions.streamRate = allConfig.streamRate;
}
if (userProvidesKey & !apiKey) {
throw new Error(
JSON.stringify({

View file

@ -1,5 +1,6 @@
const WebSocket = require('ws');
const { Message } = require('~/models/Message');
const { CacheKeys } = require('librechat-data-provider');
const { getLogStores } = require('~/cache');
/**
* @param {string[]} voiceIds - Array of voice IDs
@ -104,6 +105,8 @@ function createChunkProcessor(messageId) {
throw new Error('Message ID is required');
}
const messageCache = getLogStores(CacheKeys.MESSAGES);
/**
* @returns {Promise<{ text: string, isFinished: boolean }[] | string>}
*/
@ -116,14 +119,17 @@ function createChunkProcessor(messageId) {
return `No change in message after ${MAX_NO_CHANGE_COUNT} attempts`;
}
const message = await Message.findOne({ messageId }, 'text unfinished').lean();
/** @type { string | { text: string; complete: boolean } } */
const message = await messageCache.get(messageId);
if (!message || !message.text) {
if (!message) {
notFoundCount++;
return [];
}
const { text, unfinished } = message;
const text = typeof message === 'string' ? message : message.text;
const complete = typeof message === 'string' ? false : message.complete;
if (text === processedText) {
noChangeCount++;
}
@ -131,7 +137,7 @@ function createChunkProcessor(messageId) {
const remainingText = text.slice(processedText.length);
const chunks = [];
if (unfinished && remainingText.length >= 20) {
if (!complete && remainingText.length >= 20) {
const separatorIndex = findLastSeparatorIndex(remainingText);
if (separatorIndex !== -1) {
const chunkText = remainingText.slice(0, separatorIndex + 1);
@ -141,7 +147,7 @@ function createChunkProcessor(messageId) {
chunks.push({ text: remainingText, isFinished: false });
processedText = text;
}
} else if (!unfinished && remainingText.trim().length > 0) {
} else if (complete && remainingText.trim().length > 0) {
chunks.push({ text: remainingText.trim(), isFinished: true });
processedText = text;
}
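
Because the chunk processor now reads from the message cache instead of Mongo, it has to accept both payload shapes that writers use: a bare string for in-progress text (the Ask, Edit, and plugins controllers) and a { text, complete } object once the assistants run finishes. A brief sketch of how a cached entry is interpreted (the helper name is illustrative):

  /** Normalizes a cached message entry to { text, complete }. */
  function readCachedMessage(entry) {
    if (entry == null) {
      return null; // not cached yet; the caller counts not-found attempts
    }
    const text = typeof entry === 'string' ? entry : entry.text;
    const complete = typeof entry === 'string' ? false : entry.complete === true;
    return { text, complete };
  }

  // readCachedMessage('partial tex')                      -> { text: 'partial tex', complete: false }
  // readCachedMessage({ text: 'done.', complete: true })  -> { text: 'done.', complete: true }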

View file

@ -1,89 +1,145 @@
const { createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
const { Message } = require('~/models/Message');
jest.mock('~/models/Message', () => ({
Message: {
findOne: jest.fn().mockReturnValue({
lean: jest.fn(),
}),
},
}));
jest.mock('keyv');
const globalCache = {};
jest.mock('~/cache/getLogStores', () => {
return jest.fn().mockImplementation(() => {
const EventEmitter = require('events');
const { CacheKeys } = require('librechat-data-provider');
class KeyvMongo extends EventEmitter {
constructor(url = 'mongodb://127.0.0.1:27017', options) {
super();
this.ttlSupport = false;
url = url ?? {};
if (typeof url === 'string') {
url = { url };
}
if (url.uri) {
url = { url: url.uri, ...url };
}
this.opts = {
url,
collection: 'keyv',
...url,
...options,
};
}
get = async (key) => {
return new Promise((resolve) => {
resolve(globalCache[key] || null);
});
};
set = async (key, value) => {
return new Promise((resolve) => {
globalCache[key] = value;
resolve(true);
});
};
}
return new KeyvMongo('', {
namespace: CacheKeys.MESSAGES,
ttl: 0,
});
});
});
describe('processChunks', () => {
let processChunks;
let mockMessageCache;
beforeEach(() => {
jest.resetAllMocks();
mockMessageCache = {
get: jest.fn(),
};
require('~/cache/getLogStores').mockReturnValue(mockMessageCache);
processChunks = createChunkProcessor('message-id');
Message.findOne.mockClear();
Message.findOne().lean.mockClear();
});
it('should return an empty array when the message is not found', async () => {
Message.findOne().lean.mockResolvedValueOnce(null);
mockMessageCache.get.mockResolvedValueOnce(null);
const result = await processChunks();
expect(result).toEqual([]);
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalled();
expect(mockMessageCache.get).toHaveBeenCalledWith('message-id');
});
it('should return an empty array when the message does not have a text property', async () => {
Message.findOne().lean.mockResolvedValueOnce({ unfinished: true });
it('should return an error message after MAX_NOT_FOUND_COUNT attempts', async () => {
mockMessageCache.get.mockResolvedValue(null);
for (let i = 0; i < 6; i++) {
await processChunks();
}
const result = await processChunks();
expect(result).toEqual([]);
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalled();
expect(result).toBe('Message not found after 6 attempts');
});
it('should return chunks for an unfinished message with separators', async () => {
it('should return chunks for an incomplete message with separators', async () => {
const messageText = 'This is a long message. It should be split into chunks. Lol hi mom';
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: true });
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: false });
const result = await processChunks();
expect(result).toEqual([
{ text: 'This is a long message. It should be split into chunks.', isFinished: false },
]);
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalled();
});
it('should return chunks for an unfinished message without separators', async () => {
it('should return chunks for an incomplete message without separators', async () => {
const messageText = 'This is a long message without separators hello there my friend';
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: true });
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: false });
const result = await processChunks();
expect(result).toEqual([{ text: messageText, isFinished: false }]);
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalled();
});
it('should return the remaining text as a chunk for a finished message', async () => {
it('should return the remaining text as a chunk for a complete message', async () => {
const messageText = 'This is a finished message.';
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true });
const result = await processChunks();
expect(result).toEqual([{ text: messageText, isFinished: true }]);
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalled();
});
it('should return an empty array for a finished message with no remaining text', async () => {
it('should return an empty array for a complete message with no remaining text', async () => {
const messageText = 'This is a finished message.';
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true });
await processChunks();
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true });
const result = await processChunks();
expect(result).toEqual([]);
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalledTimes(2);
});
it('should return an error message after MAX_NO_CHANGE_COUNT attempts with no change', async () => {
const messageText = 'This is a message that does not change.';
mockMessageCache.get.mockResolvedValue({ text: messageText, complete: false });
for (let i = 0; i < 11; i++) {
await processChunks();
}
const result = await processChunks();
expect(result).toBe('No change in message after 10 attempts');
});
it('should handle string messages as incomplete', async () => {
const messageText = 'This is a message as a string.';
mockMessageCache.get.mockResolvedValueOnce(messageText);
const result = await processChunks();
expect(result).toEqual([{ text: messageText, isFinished: false }]);
});
});

View file

@ -1,17 +1,19 @@
const throttle = require('lodash/throttle');
const {
Time,
CacheKeys,
StepTypes,
ContentTypes,
ToolCallTypes,
// StepStatus,
MessageContentTypes,
AssistantStreamEvents,
Constants,
} = require('librechat-data-provider');
const { retrieveAndProcessFile } = require('~/server/services/Files/process');
const { processRequiredActions } = require('~/server/services/ToolService');
const { saveMessage, updateMessageText } = require('~/models/Message');
const { createOnProgress, sendMessage } = require('~/server/utils');
const { createOnProgress, sendMessage, sleep } = require('~/server/utils');
const { processMessages } = require('~/server/services/Threads');
const { getLogStores } = require('~/cache');
const { logger } = require('~/config');
/**
@ -68,8 +70,8 @@ class StreamRunManager {
this.attachedFileIds = fields.attachedFileIds;
/** @type {undefined | Promise<ChatCompletion>} */
this.visionPromise = fields.visionPromise;
/** @type {boolean} */
this.savedInitialMessage = false;
/** @type {number} */
this.streamRate = fields.streamRate ?? Constants.DEFAULT_STREAM_RATE;
/**
* @type {Object.<AssistantStreamEvents, (event: AssistantStreamEvent) => Promise<void>>}
@ -139,11 +141,11 @@ class StreamRunManager {
return this.intermediateText;
}
/** Saves the initial intermediate message
* @returns {Promise<void>}
/** Returns the current, intermediate message
* @returns {TMessage}
*/
async saveInitialMessage() {
return saveMessage(this.req, {
getIntermediateMessage() {
return {
conversationId: this.finalMessage.conversationId,
messageId: this.finalMessage.messageId,
parentMessageId: this.parentMessageId,
@ -155,7 +157,7 @@ class StreamRunManager {
sender: 'Assistant',
unfinished: true,
error: false,
});
};
}
/* <------------------ Main Event Handlers ------------------> */
@ -347,6 +349,8 @@ class StreamRunManager {
type: ContentTypes.TOOL_CALL,
index,
});
await sleep(this.streamRate);
}
};
@ -444,6 +448,7 @@ class StreamRunManager {
if (content && content.type === MessageContentTypes.TEXT) {
this.intermediateText += content.text.value;
onProgress(content.text.value);
await sleep(this.streamRate);
}
}
@ -589,21 +594,14 @@ class StreamRunManager {
const index = this.getStepIndex(stepKey);
this.orderedRunSteps.set(index, message_creation);
const messageCache = getLogStores(CacheKeys.MESSAGES);
// Create the Factory Function to stream the message
const { onProgress: progressCallback } = createOnProgress({
onProgress: throttle(
() => {
if (!this.savedInitialMessage) {
this.saveInitialMessage();
this.savedInitialMessage = true;
} else {
updateMessageText({
messageId: this.finalMessage.messageId,
text: this.getText(),
});
}
messageCache.set(this.finalMessage.messageId, this.getText(), Time.FIVE_MINUTES);
},
2000,
3000,
{ trailing: false },
),
});
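
Inside StreamRunManager, the new pacing is simply an awaited sleep of streamRate milliseconds after each forwarded text delta or tool-call chunk. A minimal sketch of that loop, assuming the project's sleep helper and an illustrative sendDelta callback:

  const { sleep } = require('~/server/utils');

  // Illustrative: forward deltas to the client, waiting `streamRate` ms between them.
  async function forwardDeltas(deltas, sendDelta, streamRate) {
    for (const delta of deltas) {
      sendDelta(delta); // e.g. onProgress(content.text.value) in the real handler
      await sleep(streamRate); // the pacing added by this PR
    }
  }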

View file

@ -51,6 +51,7 @@ function assistantsConfigSetup(config, assistantsEndpoint, prevConfig = {}) {
excludedIds: parsedConfig.excludedIds,
privateAssistants: parsedConfig.privateAssistants,
timeoutMs: parsedConfig.timeoutMs,
streamRate: parsedConfig.streamRate,
};
}