Mirror of https://github.com/danny-avila/LibreChat.git (synced 2025-09-22 08:12:00 +02:00)
⚙️ feat: Adjust Rate of Stream Progress (#3244)
* chore: bump data-provider and add MESSAGES CacheKey
* refactor: avoid saving messages while streaming, save partial text to cache instead
* fix(ci): processChunks
* chore: logging aborted request to debug
* feat: set stream rate for token processing
* chore: specify default stream rate
* fix(ci): Update AppService.js to use optional chaining for endpointLocals assignment
* refactor: abstract the error handler
* feat: streamRate for assistants; refactor: update default rate for token
* refactor: update error handling in assistants/errors.js
This commit is contained in:
parent
1c282d1517
commit
5d40d0a37a
29 changed files with 661 additions and 309 deletions
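The core of the change: every streaming client now paces chunk emission with a configurable `streamRate` delay (milliseconds between chunks), falling back to `Constants.DEFAULT_STREAM_RATE`, instead of each client hard-coding its own timing. A minimal sketch of the pattern the diff repeats across clients (the `sleep` helper mirrors the one imported from `~/server/utils`; the chunk shape is illustrative):

    // Sketch of the pacing pattern applied in each client below.
    const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));

    async function streamWithRate(chunks, onProgress, streamRate) {
      for (const chunk of chunks) {
        onProgress(chunk); // forward the token/chunk to the consumer
        await sleep(streamRate); // throttle emission instead of flushing as fast as possible
      }
    }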
@@ -2,8 +2,9 @@ const Anthropic = require('@anthropic-ai/sdk');
 const { HttpsProxyAgent } = require('https-proxy-agent');
 const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const {
-  getResponseSender,
+  Constants,
   EModelEndpoint,
+  getResponseSender,
   validateVisionModel,
 } = require('librechat-data-provider');
 const { encodeAndFormat } = require('~/server/services/Files/images/encode');
@@ -16,6 +17,7 @@ const {
 } = require('./prompts');
 const spendTokens = require('~/models/spendTokens');
 const { getModelMaxTokens } = require('~/utils');
+const { sleep } = require('~/server/utils');
 const BaseClient = require('./BaseClient');
 const { logger } = require('~/config');
 
@@ -605,6 +607,7 @@ class AnthropicClient extends BaseClient {
     };
 
     const maxRetries = 3;
+    const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
     async function processResponse() {
       let attempts = 0;
 
@@ -627,6 +630,8 @@ class AnthropicClient extends BaseClient {
         } else if (completion.completion) {
           handleChunk(completion.completion);
         }
+
+        await sleep(streamRate);
       }
 
       // Successful processing, exit loop
@@ -1,10 +1,11 @@
 const crypto = require('crypto');
 const fetch = require('node-fetch');
-const { supportsBalanceCheck, Constants } = require('librechat-data-provider');
+const { supportsBalanceCheck, Constants, CacheKeys, Time } = require('librechat-data-provider');
 const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
 const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
 const checkBalance = require('~/models/checkBalance');
 const { getFiles } = require('~/models/File');
+const { getLogStores } = require('~/cache');
 const TextStream = require('./TextStream');
 const { logger } = require('~/config');
 
@@ -540,6 +541,15 @@ class BaseClient {
       await this.recordTokenUsage({ promptTokens, completionTokens });
     }
     this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
+    const messageCache = getLogStores(CacheKeys.MESSAGES);
+    messageCache.set(
+      responseMessageId,
+      {
+        text: responseMessage.text,
+        complete: true,
+      },
+      Time.FIVE_MINUTES,
+    );
     delete responseMessage.tokenCount;
     return responseMessage;
   }
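The cache write above replaces the old behavior of persisting partial messages to the database mid-stream: the final text is stored under the response message's ID for five minutes, where the audio chunk processor (see the `streamAudio.js` hunks near the end of this diff) can pick it up. A sketch of reading the entry back, assuming the Keyv-style `get` used elsewhere in the diff:

    // Sketch: consuming the cache entry written above.
    const messageCache = getLogStores(CacheKeys.MESSAGES);
    const entry = await messageCache.get(responseMessageId);
    if (typeof entry === 'string') {
      // still streaming: controllers store the raw partial text
    } else if (entry?.complete) {
      // finished: { text, complete: true }
    }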
@@ -13,10 +13,12 @@ const {
   endpointSettings,
   EModelEndpoint,
   VisionModes,
+  Constants,
   AuthKeys,
 } = require('librechat-data-provider');
 const { encodeAndFormat } = require('~/server/services/Files/images');
 const { getModelMaxTokens } = require('~/utils');
+const { sleep } = require('~/server/utils');
 const { logger } = require('~/config');
 const {
   formatMessage,
@@ -620,8 +622,9 @@ class GoogleClient extends BaseClient {
   }
 
   async getCompletion(_payload, options = {}) {
-    const { onProgress, abortController } = options;
     const { parameters, instances } = _payload;
+    const { onProgress, abortController } = options;
+    const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
     const { messages: _messages, context, examples: _examples } = instances?.[0] ?? {};
 
     let examples;
@@ -701,6 +704,7 @@ class GoogleClient extends BaseClient {
         delay,
       });
       reply += chunkText;
+      await sleep(streamRate);
     }
     return reply;
   }
@@ -712,10 +716,17 @@ class GoogleClient extends BaseClient {
       safetySettings: safetySettings,
     });
 
-    let delay = this.isGenerativeModel ? 12 : 8;
+    let delay = this.options.streamRate || 8;
+
+    if (!this.options.streamRate) {
+      if (this.isGenerativeModel) {
+        delay = 12;
+      }
       if (modelName.includes('flash')) {
         delay = 5;
       }
+    }
+
     for await (const chunk of stream) {
       const chunkText = chunk?.content ?? chunk;
       await this.generateTextStream(chunkText, onProgress, {
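The rewritten block keeps the old model-based heuristics (12 ms for generative models, 5 ms for 'flash' models, 8 ms otherwise), but only as fallbacks when no `streamRate` is configured. The same precedence, written flat (a sketch, not the committed code):

    // configured streamRate > flash (5 ms) > generative (12 ms) > default (8 ms)
    function resolveDelay(streamRate, isGenerativeModel, modelName) {
      if (streamRate) {
        return streamRate;
      }
      if (modelName.includes('flash')) {
        return 5;
      }
      return isGenerativeModel ? 12 : 8;
    }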
@@ -1,7 +1,9 @@
 const { z } = require('zod');
 const axios = require('axios');
 const { Ollama } = require('ollama');
+const { Constants } = require('librechat-data-provider');
 const { deriveBaseURL } = require('~/utils');
+const { sleep } = require('~/server/utils');
 const { logger } = require('~/config');
 
 const ollamaPayloadSchema = z.object({
@@ -40,6 +42,7 @@ const getValidBase64 = (imageUrl) => {
 class OllamaClient {
   constructor(options = {}) {
     const host = deriveBaseURL(options.baseURL ?? 'http://localhost:11434');
+    this.streamRate = options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
     /** @type {Ollama} */
     this.client = new Ollama({ host });
   }
@@ -136,6 +139,8 @@ class OllamaClient {
         stream.controller.abort();
         break;
       }
+
+      await sleep(this.streamRate);
     }
   }
   // TODO: regular completion
@@ -1182,8 +1182,10 @@ ${convo}
       });
     }
 
+    const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
+
     if (this.message_file_map && this.isOllama) {
-      const ollamaClient = new OllamaClient({ baseURL });
+      const ollamaClient = new OllamaClient({ baseURL, streamRate });
       return await ollamaClient.chatCompletion({
         payload: modelOptions,
         onProgress,
@@ -1221,8 +1223,6 @@ ${convo}
       }
     });
 
-    const azureDelay = this.modelOptions.model?.includes('gpt-4') ? 30 : 17;
-
     for await (const chunk of stream) {
       const token = chunk.choices[0]?.delta?.content || '';
       intermediateReply += token;
@@ -1232,9 +1232,7 @@ ${convo}
         break;
       }
 
-      if (this.azure) {
-        await sleep(azureDelay);
-      }
+      await sleep(streamRate);
     }
 
     if (!UnexpectedRoleError) {
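The hard-coded `azureDelay` (30 ms for gpt-4 models, 17 ms otherwise) removed here is not lost: the same heuristic reappears later in this diff inside the Azure branches of the gptPlugins and openAI `initializeClient` functions, as a default for `clientOptions.streamRate`. The completion loop itself now sleeps unconditionally on the single resolved `streamRate`.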
@@ -1,5 +1,6 @@
 const OpenAIClient = require('./OpenAIClient');
 const { CallbackManager } = require('langchain/callbacks');
+const { CacheKeys, Time } = require('librechat-data-provider');
 const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
 const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
 const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
@@ -11,6 +12,7 @@ const { SelfReflectionTool } = require('./tools');
 const { isEnabled } = require('~/server/utils');
 const { extractBaseURL } = require('~/utils');
 const { loadTools } = require('./tools/util');
+const { getLogStores } = require('~/cache');
 const { logger } = require('~/config');
 
 class PluginsClient extends OpenAIClient {
@@ -220,6 +222,13 @@ class PluginsClient extends OpenAIClient {
     }
   }
 
+  /**
+   *
+   * @param {TMessage} responseMessage
+   * @param {Partial<TMessage>} saveOptions
+   * @param {string} user
+   * @returns
+   */
   async handleResponseMessage(responseMessage, saveOptions, user) {
     const { output, errorMessage, ...result } = this.result;
     logger.debug('[PluginsClient][handleResponseMessage] Output:', {
@@ -239,6 +248,15 @@ class PluginsClient extends OpenAIClient {
     }
 
     this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
+    const messageCache = getLogStores(CacheKeys.MESSAGES);
+    messageCache.set(
+      responseMessage.messageId,
+      {
+        text: responseMessage.text,
+        complete: true,
+      },
+      Time.FIVE_MINUTES,
+    );
     delete responseMessage.tokenCount;
     return { ...responseMessage, ...result };
   }
25 api/cache/getLogStores.js vendored
@@ -1,13 +1,11 @@
 const Keyv = require('keyv');
-const { CacheKeys, ViolationTypes } = require('librechat-data-provider');
+const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
 const { logFile, violationFile } = require('./keyvFiles');
 const { math, isEnabled } = require('~/server/utils');
 const keyvRedis = require('./keyvRedis');
 const keyvMongo = require('./keyvMongo');
 
 const { BAN_DURATION, USE_REDIS } = process.env ?? {};
-const THIRTY_MINUTES = 1800000;
-const TEN_MINUTES = 600000;
 
 const duration = math(BAN_DURATION, 7200000);
 
@@ -29,17 +27,21 @@ const roles = isEnabled(USE_REDIS)
   ? new Keyv({ store: keyvRedis })
   : new Keyv({ namespace: CacheKeys.ROLES });
 
-const audioRuns = isEnabled(USE_REDIS) // ttl: 30 minutes
-  ? new Keyv({ store: keyvRedis, ttl: TEN_MINUTES })
-  : new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: TEN_MINUTES });
+const audioRuns = isEnabled(USE_REDIS)
+  ? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
+  : new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: Time.TEN_MINUTES });
+
+const messages = isEnabled(USE_REDIS)
+  ? new Keyv({ store: keyvRedis, ttl: Time.FIVE_MINUTES })
+  : new Keyv({ namespace: CacheKeys.MESSAGES, ttl: Time.FIVE_MINUTES });
 
 const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes
-  ? new Keyv({ store: keyvRedis, ttl: THIRTY_MINUTES })
-  : new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: THIRTY_MINUTES });
+  ? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
+  : new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: Time.THIRTY_MINUTES });
 
 const genTitle = isEnabled(USE_REDIS) // ttl: 2 minutes
-  ? new Keyv({ store: keyvRedis, ttl: 120000 })
-  : new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: 120000 });
+  ? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
+  : new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES });
 
 const modelQueries = isEnabled(process.env.USE_REDIS)
   ? new Keyv({ store: keyvRedis })
@@ -47,7 +49,7 @@ const modelQueries = isEnabled(process.env.USE_REDIS)
 
 const abortKeys = isEnabled(USE_REDIS)
   ? new Keyv({ store: keyvRedis })
-  : new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: 600000 });
+  : new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: Time.TEN_MINUTES });
 
 const namespaces = {
   [CacheKeys.ROLES]: roles,
@@ -81,6 +83,7 @@ const namespaces = {
   [CacheKeys.GEN_TITLE]: genTitle,
   [CacheKeys.MODEL_QUERIES]: modelQueries,
   [CacheKeys.AUDIO_RUNS]: audioRuns,
+  [CacheKeys.MESSAGES]: messages,
 };
 
 /**
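The new `messages` store follows the file's existing pattern: Redis-backed when `USE_REDIS` is enabled, otherwise an in-memory namespace, with a five-minute TTL either way, and the magic-number TTLs are replaced by `Time` constants. A sketch of the TTL behavior, assuming the `keyv` package's documented API:

    const Keyv = require('keyv');

    // Sketch: entries in a TTL'd Keyv store expire automatically.
    const messages = new Keyv({ namespace: 'messages', ttl: 5 * 60 * 1000 });

    (async () => {
      await messages.set('message-id', 'partial text'); // inherits the store's default TTL
      console.log(await messages.get('message-id')); // 'partial text' until the TTL lapses
    })();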
@@ -1,7 +1,8 @@
 const throttle = require('lodash/throttle');
-const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider');
+const { getResponseSender, Constants, CacheKeys, Time } = require('librechat-data-provider');
 const { createAbortController, handleAbortError } = require('~/server/middleware');
 const { sendMessage, createOnProgress } = require('~/server/utils');
+const { getLogStores } = require('~/cache');
 const { saveMessage } = require('~/models');
 const { logger } = require('~/config');
 
@@ -51,11 +52,13 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
 
   try {
     const { client } = await initializeClient({ req, res, endpointOption });
-    const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
+    const messageCache = getLogStores(CacheKeys.MESSAGES);
     const { onProgress: progressCallback, getPartialText } = createOnProgress({
       onProgress: throttle(
         ({ text: partialText }) => {
-          saveMessage(req, {
+          /*
+          const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
+          messageCache.set(responseMessageId, {
             messageId: responseMessageId,
             sender,
             conversationId,
@@ -65,7 +68,10 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
             unfinished,
             error: false,
             user,
-          });
+          }, Time.FIVE_MINUTES);
+          */
+
+          messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES);
         },
         3000,
         { trailing: false },
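The shape of the replacement: the throttled `onProgress` callback no longer issues a database write per tick; it refreshes a cache key at most once per 3-second window (with `trailing: false`, calls in between are simply dropped), and the full message is saved once when the stream completes (see the `BaseClient` hunk above). A sketch of that throttling, under the names used in the diff:

    const throttle = require('lodash/throttle');

    // At most one cache write per 3 s window; intermediate calls are dropped.
    const writePartial = throttle(
      (partialText) => messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES),
      3000,
      { trailing: false },
    );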
@@ -1,7 +1,8 @@
 const throttle = require('lodash/throttle');
-const { getResponseSender, EModelEndpoint } = require('librechat-data-provider');
+const { getResponseSender, CacheKeys, Time } = require('librechat-data-provider');
 const { createAbortController, handleAbortError } = require('~/server/middleware');
 const { sendMessage, createOnProgress } = require('~/server/utils');
+const { getLogStores } = require('~/cache');
 const { saveMessage } = require('~/models');
 const { logger } = require('~/config');
 
@@ -51,12 +52,14 @@ const EditController = async (req, res, next, initializeClient) => {
     }
   };
 
-  const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
+  const messageCache = getLogStores(CacheKeys.MESSAGES);
   const { onProgress: progressCallback, getPartialText } = createOnProgress({
     generation,
     onProgress: throttle(
       ({ text: partialText }) => {
-        saveMessage(req, {
+        /*
+        const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
+        {
           messageId: responseMessageId,
           sender,
           conversationId,
|
@ -67,7 +70,8 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||||
isEdited: true,
|
isEdited: true,
|
||||||
error: false,
|
error: false,
|
||||||
user,
|
user,
|
||||||
});
|
} */
|
||||||
|
messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES);
|
||||||
},
|
},
|
||||||
3000,
|
3000,
|
||||||
{ trailing: false },
|
{ trailing: false },
|
||||||
|
|
|
@@ -1,12 +1,12 @@
 const { v4 } = require('uuid');
 const {
+  Time,
   Constants,
   RunStatus,
   CacheKeys,
   ContentTypes,
   ToolCallTypes,
   EModelEndpoint,
-  ViolationTypes,
   retrievalMimeTypes,
   AssistantStreamEvents,
 } = require('librechat-data-provider');
|
@ -14,12 +14,12 @@ const {
|
||||||
initThread,
|
initThread,
|
||||||
recordUsage,
|
recordUsage,
|
||||||
saveUserMessage,
|
saveUserMessage,
|
||||||
checkMessageGaps,
|
|
||||||
addThreadMetadata,
|
addThreadMetadata,
|
||||||
saveAssistantMessage,
|
saveAssistantMessage,
|
||||||
} = require('~/server/services/Threads');
|
} = require('~/server/services/Threads');
|
||||||
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
|
|
||||||
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
||||||
|
const { sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
|
||||||
|
const { createErrorHandler } = require('~/server/controllers/assistants/errors');
|
||||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||||
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
||||||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||||
|
@ -90,140 +90,20 @@ const chatV2 = async (req, res) => {
|
||||||
/** @type {Run | undefined} - The completed run, undefined if incomplete */
|
/** @type {Run | undefined} - The completed run, undefined if incomplete */
|
||||||
let completedRun;
|
let completedRun;
|
||||||
|
|
||||||
const handleError = async (error) => {
|
const getContext = () => ({
|
||||||
const defaultErrorMessage =
|
|
||||||
'The Assistant run failed to initialize. Try sending a message in a new conversation.';
|
|
||||||
const messageData = {
|
|
||||||
thread_id,
|
|
||||||
assistant_id,
|
|
||||||
conversationId,
|
|
||||||
parentMessageId,
|
|
||||||
sender: 'System',
|
|
||||||
user: req.user.id,
|
|
||||||
shouldSaveMessage: false,
|
|
||||||
messageId: responseMessageId,
|
|
||||||
endpoint,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (error.message === 'Run cancelled') {
|
|
||||||
return res.end();
|
|
||||||
} else if (error.message === 'Request closed' && completedRun) {
|
|
||||||
return;
|
|
||||||
} else if (error.message === 'Request closed') {
|
|
||||||
logger.debug('[/assistants/chat/] Request aborted on close');
|
|
||||||
} else if (/Files.*are invalid/.test(error.message)) {
|
|
||||||
const errorMessage = `Files are invalid, or may not have uploaded yet.${
|
|
||||||
endpoint === EModelEndpoint.azureAssistants
|
|
||||||
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
|
|
||||||
: ''
|
|
||||||
}`;
|
|
||||||
return sendResponse(req, res, messageData, errorMessage);
|
|
||||||
} else if (error?.message?.includes('string too long')) {
|
|
||||||
return sendResponse(
|
|
||||||
req,
|
|
||||||
res,
|
|
||||||
messageData,
|
|
||||||
'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
|
|
||||||
);
|
|
||||||
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
|
|
||||||
return sendResponse(req, res, messageData, error.message);
|
|
||||||
} else {
|
|
||||||
logger.error('[/assistants/chat/]', error);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!openai || !thread_id || !run_id) {
|
|
||||||
return sendResponse(req, res, messageData, defaultErrorMessage);
|
|
||||||
}
|
|
||||||
|
|
||||||
await sleep(2000);
|
|
||||||
|
|
||||||
try {
|
|
||||||
const status = await cache.get(cacheKey);
|
|
||||||
if (status === 'cancelled') {
|
|
||||||
logger.debug('[/assistants/chat/] Run already cancelled');
|
|
||||||
return res.end();
|
|
||||||
}
|
|
||||||
await cache.delete(cacheKey);
|
|
||||||
const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
|
|
||||||
logger.debug('[/assistants/chat/] Cancelled run:', cancelledRun);
|
|
||||||
} catch (error) {
|
|
||||||
logger.error('[/assistants/chat/] Error cancelling run', error);
|
|
||||||
}
|
|
||||||
|
|
||||||
await sleep(2000);
|
|
||||||
|
|
||||||
let run;
|
|
||||||
try {
|
|
||||||
run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
|
|
||||||
await recordUsage({
|
|
||||||
...run.usage,
|
|
||||||
model: run.model,
|
|
||||||
user: req.user.id,
|
|
||||||
conversationId,
|
|
||||||
});
|
|
||||||
} catch (error) {
|
|
||||||
logger.error('[/assistants/chat/] Error fetching or processing run', error);
|
|
||||||
}
|
|
||||||
|
|
||||||
let finalEvent;
|
|
||||||
try {
|
|
||||||
const runMessages = await checkMessageGaps({
|
|
||||||
openai,
|
openai,
|
||||||
run_id,
|
run_id,
|
||||||
endpoint,
|
endpoint,
|
||||||
|
cacheKey,
|
||||||
thread_id,
|
thread_id,
|
||||||
|
completedRun,
|
||||||
|
assistant_id,
|
||||||
conversationId,
|
conversationId,
|
||||||
latestMessageId: responseMessageId,
|
parentMessageId,
|
||||||
|
responseMessageId,
|
||||||
});
|
});
|
||||||
|
|
||||||
const errorContentPart = {
|
const handleError = createErrorHandler({ req, res, getContext });
|
||||||
text: {
|
|
||||||
value:
|
|
||||||
error?.message ?? 'There was an error processing your request. Please try again later.',
|
|
||||||
},
|
|
||||||
type: ContentTypes.ERROR,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (!Array.isArray(runMessages[runMessages.length - 1]?.content)) {
|
|
||||||
runMessages[runMessages.length - 1].content = [errorContentPart];
|
|
||||||
} else {
|
|
||||||
const contentParts = runMessages[runMessages.length - 1].content;
|
|
||||||
for (let i = 0; i < contentParts.length; i++) {
|
|
||||||
const currentPart = contentParts[i];
|
|
||||||
/** @type {CodeToolCall | RetrievalToolCall | FunctionToolCall | undefined} */
|
|
||||||
const toolCall = currentPart?.[ContentTypes.TOOL_CALL];
|
|
||||||
if (
|
|
||||||
toolCall &&
|
|
||||||
toolCall?.function &&
|
|
||||||
!(toolCall?.function?.output || toolCall?.function?.output?.length)
|
|
||||||
) {
|
|
||||||
contentParts[i] = {
|
|
||||||
...currentPart,
|
|
||||||
[ContentTypes.TOOL_CALL]: {
|
|
||||||
...toolCall,
|
|
||||||
function: {
|
|
||||||
...toolCall.function,
|
|
||||||
output: 'error processing tool',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
runMessages[runMessages.length - 1].content.push(errorContentPart);
|
|
||||||
}
|
|
||||||
|
|
||||||
finalEvent = {
|
|
||||||
final: true,
|
|
||||||
conversation: await getConvo(req.user.id, conversationId),
|
|
||||||
runMessages,
|
|
||||||
};
|
|
||||||
} catch (error) {
|
|
||||||
logger.error('[/assistants/chat/] Error finalizing error process', error);
|
|
||||||
return sendResponse(req, res, messageData, 'The Assistant run failed');
|
|
||||||
}
|
|
||||||
|
|
||||||
return sendResponse(req, res, finalEvent);
|
|
||||||
};
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
res.on('close', async () => {
|
res.on('close', async () => {
|
||||||
|
@@ -490,6 +370,11 @@ const chatV2 = async (req, res) => {
       },
     };
 
+    /** @type {undefined | TAssistantEndpoint} */
+    const config = req.app.locals[endpoint] ?? {};
+    /** @type {undefined | TBaseEndpoint} */
+    const allConfig = req.app.locals.all;
+
     const streamRunManager = new StreamRunManager({
       req,
       res,
@@ -499,6 +384,7 @@ const chatV2 = async (req, res) => {
       attachedFileIds,
       parentMessageId: userMessageId,
       responseMessage: openai.responseMessage,
+      streamRate: allConfig?.streamRate ?? config.streamRate,
       // streamOptions: {
 
       // },
@@ -511,6 +397,16 @@ const chatV2 = async (req, res) => {
 
       response = streamRunManager;
       response.text = streamRunManager.intermediateText;
+
+      const messageCache = getLogStores(CacheKeys.MESSAGES);
+      messageCache.set(
+        responseMessageId,
+        {
+          complete: true,
+          text: response.text,
+        },
+        Time.FIVE_MINUTES,
+      );
     };
 
     await processRun();
193 api/server/controllers/assistants/errors.js Normal file
@@ -0,0 +1,193 @@
+// errorHandler.js
+const { sendResponse } = require('~/server/utils');
+const { logger } = require('~/config');
+const getLogStores = require('~/cache/getLogStores');
+const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
+const { getConvo } = require('~/models/Conversation');
+const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
+
+/**
+ * @typedef {Object} ErrorHandlerContext
+ * @property {OpenAIClient} openai - The OpenAI client
+ * @property {string} thread_id - The thread ID
+ * @property {string} run_id - The run ID
+ * @property {boolean} completedRun - Whether the run has completed
+ * @property {string} assistant_id - The assistant ID
+ * @property {string} conversationId - The conversation ID
+ * @property {string} parentMessageId - The parent message ID
+ * @property {string} responseMessageId - The response message ID
+ * @property {string} endpoint - The endpoint being used
+ * @property {string} cacheKey - The cache key for the current request
+ */
+
+/**
+ * @typedef {Object} ErrorHandlerDependencies
+ * @property {Express.Request} req - The Express request object
+ * @property {Express.Response} res - The Express response object
+ * @property {() => ErrorHandlerContext} getContext - Function to get the current context
+ * @property {string} [originPath] - The origin path for the error handler
+ */
+
+/**
+ * Creates an error handler function with the given dependencies
+ * @param {ErrorHandlerDependencies} dependencies - The dependencies for the error handler
+ * @returns {(error: Error) => Promise<void>} The error handler function
+ */
+const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/chat/' }) => {
+  const cache = getLogStores(CacheKeys.ABORT_KEYS);
+
+  /**
+   * Handles errors that occur during the chat process
+   * @param {Error} error - The error that occurred
+   * @returns {Promise<void>}
+   */
+  return async (error) => {
+    const {
+      openai,
+      run_id,
+      endpoint,
+      cacheKey,
+      thread_id,
+      completedRun,
+      assistant_id,
+      conversationId,
+      parentMessageId,
+      responseMessageId,
+    } = getContext();
+
+    const defaultErrorMessage =
+      'The Assistant run failed to initialize. Try sending a message in a new conversation.';
+    const messageData = {
+      thread_id,
+      assistant_id,
+      conversationId,
+      parentMessageId,
+      sender: 'System',
+      user: req.user.id,
+      shouldSaveMessage: false,
+      messageId: responseMessageId,
+      endpoint,
+    };
+
+    if (error.message === 'Run cancelled') {
+      return res.end();
+    } else if (error.message === 'Request closed' && completedRun) {
+      return;
+    } else if (error.message === 'Request closed') {
+      logger.debug(`[${originPath}] Request aborted on close`);
+    } else if (/Files.*are invalid/.test(error.message)) {
+      const errorMessage = `Files are invalid, or may not have uploaded yet.${
+        endpoint === 'azureAssistants'
+          ? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
+          : ''
+      }`;
+      return sendResponse(req, res, messageData, errorMessage);
+    } else if (error?.message?.includes('string too long')) {
+      return sendResponse(
+        req,
+        res,
+        messageData,
+        'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
+      );
+    } else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
+      return sendResponse(req, res, messageData, error.message);
+    } else {
+      logger.error(`[${originPath}]`, error);
+    }
+
+    if (!openai || !thread_id || !run_id) {
+      return sendResponse(req, res, messageData, defaultErrorMessage);
+    }
+
+    await new Promise((resolve) => setTimeout(resolve, 2000));
+
+    try {
+      const status = await cache.get(cacheKey);
+      if (status === 'cancelled') {
+        logger.debug(`[${originPath}] Run already cancelled`);
+        return res.end();
+      }
+      await cache.delete(cacheKey);
+      const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
+      logger.debug(`[${originPath}] Cancelled run:`, cancelledRun);
+    } catch (error) {
+      logger.error(`[${originPath}] Error cancelling run`, error);
+    }
+
+    await new Promise((resolve) => setTimeout(resolve, 2000));
+
+    let run;
+    try {
+      run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
+      await recordUsage({
+        ...run.usage,
+        model: run.model,
+        user: req.user.id,
+        conversationId,
+      });
+    } catch (error) {
+      logger.error(`[${originPath}] Error fetching or processing run`, error);
+    }
+
+    let finalEvent;
+    try {
+      const runMessages = await checkMessageGaps({
+        openai,
+        run_id,
+        endpoint,
+        thread_id,
+        conversationId,
+        latestMessageId: responseMessageId,
+      });
+
+      const errorContentPart = {
+        text: {
+          value:
+            error?.message ?? 'There was an error processing your request. Please try again later.',
+        },
+        type: ContentTypes.ERROR,
+      };
+
+      if (!Array.isArray(runMessages[runMessages.length - 1]?.content)) {
+        runMessages[runMessages.length - 1].content = [errorContentPart];
+      } else {
+        const contentParts = runMessages[runMessages.length - 1].content;
+        for (let i = 0; i < contentParts.length; i++) {
+          const currentPart = contentParts[i];
+          /** @type {CodeToolCall | RetrievalToolCall | FunctionToolCall | undefined} */
+          const toolCall = currentPart?.[ContentTypes.TOOL_CALL];
+          if (
+            toolCall &&
+            toolCall?.function &&
+            !(toolCall?.function?.output || toolCall?.function?.output?.length)
+          ) {
+            contentParts[i] = {
+              ...currentPart,
+              [ContentTypes.TOOL_CALL]: {
+                ...toolCall,
+                function: {
+                  ...toolCall.function,
+                  output: 'error processing tool',
+                },
+              },
+            };
+          }
+        }
+        runMessages[runMessages.length - 1].content.push(errorContentPart);
+      }
+
+      finalEvent = {
+        final: true,
+        conversation: await getConvo(req.user.id, conversationId),
+        runMessages,
+      };
+    } catch (error) {
+      logger.error(`[${originPath}] Error finalizing error process`, error);
+      return sendResponse(req, res, messageData, 'The Assistant run failed');
+    }
+
+    return sendResponse(req, res, finalEvent);
+  };
+};
+
+module.exports = { createErrorHandler };
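Because context values like `openai` and `run_id` are assigned over the course of the request, callers hand the factory a `getContext` closure rather than a snapshot, so the handler reads current values at the moment an error fires. A minimal usage sketch (the free variables stand in for the caller's request-scoped state, as in `chatV2` above):

    const { createErrorHandler } = require('~/server/controllers/assistants/errors');

    // Values such as openai, run_id, completedRun are filled in as the run progresses.
    const getContext = () => ({
      openai,
      thread_id,
      run_id,
      completedRun,
      assistant_id,
      conversationId,
      parentMessageId,
      responseMessageId,
      endpoint,
      cacheKey,
    });

    const handleError = createErrorHandler({ req, res, getContext });

    try {
      // ... create and stream the run ...
    } catch (error) {
      await handleError(error);
    }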
@@ -30,7 +30,10 @@ async function abortMessage(req, res) {
     return res.status(204).send({ message: 'Request not found' });
   }
   const finalEvent = await abortController.abortCompletion();
-  logger.info('[abortMessage] Aborted request', { abortKey });
+  logger.debug(
+    `[abortMessage] ID: ${req.user.id} | ${req.user.email} | Aborted request: ` +
+      JSON.stringify({ abortKey }),
+  );
   abortControllers.delete(abortKey);
 
   if (res.headersSent && finalEvent) {
@@ -1,10 +1,11 @@
 const express = require('express');
 const throttle = require('lodash/throttle');
-const { getResponseSender, Constants } = require('librechat-data-provider');
+const { getResponseSender, Constants, CacheKeys, Time } = require('librechat-data-provider');
 const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
 const { sendMessage, createOnProgress } = require('~/server/utils');
 const { addTitle } = require('~/server/services/Endpoints/openAI');
 const { saveMessage } = require('~/models');
+const { getLogStores } = require('~/cache');
 const {
   handleAbort,
   createAbortController,
@@ -71,7 +72,8 @@ router.post(
       }
     };
 
-    const throttledSaveMessage = throttle(saveMessage, 3000, { trailing: false });
+    const messageCache = getLogStores(CacheKeys.MESSAGES);
+    const throttledSetMessage = throttle(messageCache.set, 3000, { trailing: false });
     let streaming = null;
    let timer = null;
 
@@ -85,7 +87,8 @@ router.post(
           clearTimeout(timer);
         }
 
-        throttledSaveMessage(req, {
+        /*
+        {
           messageId: responseMessageId,
           sender,
           conversationId,
@@ -96,7 +99,9 @@ router.post(
           error: false,
           plugins,
           user,
-        });
+        }
+        */
+        throttledSetMessage(responseMessageId, partialText, Time.FIVE_MINUTES);
 
         streaming = new Promise((resolve) => {
           timer = setTimeout(() => {
@@ -1,19 +1,20 @@
 const express = require('express');
 const throttle = require('lodash/throttle');
-const { getResponseSender } = require('librechat-data-provider');
+const { getResponseSender, CacheKeys, Time } = require('librechat-data-provider');
 const {
-  handleAbort,
-  createAbortController,
-  handleAbortError,
   setHeaders,
+  handleAbort,
+  moderateText,
   validateModel,
+  handleAbortError,
   validateEndpoint,
   buildEndpointOption,
-  moderateText,
+  createAbortController,
 } = require('~/server/middleware');
 const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
 const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
 const { saveMessage } = require('~/models');
+const { getLogStores } = require('~/cache');
 const { validateTools } = require('~/app');
 const { logger } = require('~/config');
 
@@ -79,7 +80,8 @@ router.post(
       }
     };
 
-    const throttledSaveMessage = throttle(saveMessage, 3000, { trailing: false });
+    const messageCache = getLogStores(CacheKeys.MESSAGES);
+    const throttledSetMessage = throttle(messageCache.set, 3000, { trailing: false });
     const {
       onProgress: progressCallback,
       sendIntermediateMessage,
@@ -91,7 +93,8 @@ router.post(
           plugin.loading = false;
         }
 
-        throttledSaveMessage(req, {
+        /*
+        {
           messageId: responseMessageId,
           sender,
           conversationId,
@@ -102,7 +105,9 @@ router.post(
           isEdited: true,
           error: false,
           user,
-        });
+        }
+        */
+        throttledSetMessage(responseMessageId, partialText, Time.FIVE_MINUTES);
       },
     });
 
@@ -67,17 +67,18 @@ const AppService = async (app) => {
   handleRateLimits(config?.rateLimits);
 
   const endpointLocals = {};
+  const endpoints = config?.endpoints;
 
-  if (config?.endpoints?.[EModelEndpoint.azureOpenAI]) {
+  if (endpoints?.[EModelEndpoint.azureOpenAI]) {
     endpointLocals[EModelEndpoint.azureOpenAI] = azureConfigSetup(config);
     checkAzureVariables();
   }
 
-  if (config?.endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) {
+  if (endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) {
     endpointLocals[EModelEndpoint.azureAssistants] = azureAssistantsDefaults();
   }
 
-  if (config?.endpoints?.[EModelEndpoint.azureAssistants]) {
+  if (endpoints?.[EModelEndpoint.azureAssistants]) {
     endpointLocals[EModelEndpoint.azureAssistants] = assistantsConfigSetup(
       config,
       EModelEndpoint.azureAssistants,
@@ -85,7 +86,7 @@ const AppService = async (app) => {
     );
   }
 
-  if (config?.endpoints?.[EModelEndpoint.assistants]) {
+  if (endpoints?.[EModelEndpoint.assistants]) {
     endpointLocals[EModelEndpoint.assistants] = assistantsConfigSetup(
       config,
       EModelEndpoint.assistants,
@@ -93,6 +94,19 @@ const AppService = async (app) => {
     );
   }
 
+  if (endpoints?.[EModelEndpoint.openAI]) {
+    endpointLocals[EModelEndpoint.openAI] = endpoints[EModelEndpoint.openAI];
+  }
+  if (endpoints?.[EModelEndpoint.google]) {
+    endpointLocals[EModelEndpoint.google] = endpoints[EModelEndpoint.google];
+  }
+  if (endpoints?.[EModelEndpoint.anthropic]) {
+    endpointLocals[EModelEndpoint.anthropic] = endpoints[EModelEndpoint.anthropic];
+  }
+  if (endpoints?.[EModelEndpoint.gptPlugins]) {
+    endpointLocals[EModelEndpoint.gptPlugins] = endpoints[EModelEndpoint.gptPlugins];
+  }
+
   app.locals = {
     ...defaultLocals,
     modelSpecs: config.modelSpecs,
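These new blocks copy each endpoint's raw custom config onto `app.locals`, which is where the `initializeClient` functions below read `streamRate` from: the per-endpoint section via `req.app.locals[endpoint]`, plus the optional global section via `req.app.locals.all`. Roughly, after `AppService` runs (a sketch):

    // Per-endpoint config is now reachable from request handlers:
    const anthropicConfig = req.app.locals[EModelEndpoint.anthropic]; // e.g. { streamRate: 25 }
    const allConfig = req.app.locals.all; // optional global section, checked last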
@@ -19,11 +19,27 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     checkUserKeyExpiry(expiresAt, EModelEndpoint.anthropic);
   }
 
+  const clientOptions = {};
+
+  /** @type {undefined | TBaseEndpoint} */
+  const anthropicConfig = req.app.locals[EModelEndpoint.anthropic];
+
+  if (anthropicConfig) {
+    clientOptions.streamRate = anthropicConfig.streamRate;
+  }
+
+  /** @type {undefined | TBaseEndpoint} */
+  const allConfig = req.app.locals.all;
+  if (allConfig) {
+    clientOptions.streamRate = allConfig.streamRate;
+  }
+
   const client = new AnthropicClient(anthropicApiKey, {
     req,
     res,
     reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
     proxy: PROXY ?? null,
+    ...clientOptions,
     ...endpointOption,
   });
 
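Note the assignment order in this initializer: the endpoint-specific rate is applied first and `all.streamRate` afterward, so a configured global section overrides the per-endpoint value. The google, gptPlugins, and openAI initializers below repeat the same two-step pattern.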
@@ -114,9 +114,16 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     contextStrategy: endpointConfig.summarize ? 'summarize' : null,
     directEndpoint: endpointConfig.directEndpoint,
     titleMessageRole: endpointConfig.titleMessageRole,
+    streamRate: endpointConfig.streamRate,
     endpointTokenConfig,
   };
 
+  /** @type {undefined | TBaseEndpoint} */
+  const allConfig = req.app.locals.all;
+  if (allConfig) {
+    customOptions.streamRate = allConfig.streamRate;
+  }
+
   const clientOptions = {
     reverseProxyUrl: baseURL ?? null,
     proxy: PROXY ?? null,
@@ -27,11 +27,27 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     [AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY,
   };
 
+  const clientOptions = {};
+
+  /** @type {undefined | TBaseEndpoint} */
+  const allConfig = req.app.locals.all;
+  /** @type {undefined | TBaseEndpoint} */
+  const googleConfig = req.app.locals[EModelEndpoint.google];
+
+  if (googleConfig) {
+    clientOptions.streamRate = googleConfig.streamRate;
+  }
+
+  if (allConfig) {
+    clientOptions.streamRate = allConfig.streamRate;
+  }
+
   const client = new GoogleClient(credentials, {
     req,
     res,
     reverseProxyUrl: GOOGLE_REVERSE_PROXY ?? null,
     proxy: PROXY ?? null,
+    ...clientOptions,
     ...endpointOption,
   });
 
@@ -8,6 +8,8 @@ jest.mock('~/server/services/UserService', () => ({
   getUserKey: jest.fn().mockImplementation(() => ({})),
 }));
 
+const app = { locals: {} };
+
 describe('google/initializeClient', () => {
   afterEach(() => {
     jest.clearAllMocks();
@@ -23,6 +25,7 @@ describe('google/initializeClient', () => {
     const req = {
       body: { key: expiresAt },
       user: { id: '123' },
+      app,
     };
     const res = {};
     const endpointOption = { modelOptions: { model: 'default-model' } };
@@ -44,6 +47,7 @@ describe('google/initializeClient', () => {
     const req = {
       body: { key: null },
       user: { id: '123' },
+      app,
     };
     const res = {};
     const endpointOption = { modelOptions: { model: 'default-model' } };
@@ -66,6 +70,7 @@ describe('google/initializeClient', () => {
     const req = {
       body: { key: expiresAt },
       user: { id: '123' },
+      app,
     };
     const res = {};
     const endpointOption = { modelOptions: { model: 'default-model' } };
@@ -86,6 +86,9 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     clientOptions.titleModel = azureConfig.titleModel;
     clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion';
 
+    const azureRate = modelName.includes('gpt-4') ? 30 : 17;
+    clientOptions.streamRate = azureConfig.streamRate ?? azureRate;
+
     const groupName = modelGroupMap[modelName].group;
     clientOptions.addParams = azureConfig.groupMap[groupName].addParams;
     clientOptions.dropParams = azureConfig.groupMap[groupName].dropParams;
@@ -98,6 +101,19 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     apiKey = clientOptions.azure.azureOpenAIApiKey;
   }
 
+  /** @type {undefined | TBaseEndpoint} */
+  const pluginsConfig = req.app.locals[EModelEndpoint.gptPlugins];
+
+  if (!useAzure && pluginsConfig) {
+    clientOptions.streamRate = pluginsConfig.streamRate;
+  }
+
+  /** @type {undefined | TBaseEndpoint} */
+  const allConfig = req.app.locals.all;
+  if (allConfig) {
+    clientOptions.streamRate = allConfig.streamRate;
+  }
+
   if (!apiKey) {
     throw new Error(`${endpoint} API key not provided. Please provide it again.`);
   }
@@ -76,6 +76,10 @@ const initializeClient = async ({ req, res, endpointOption }) => {
 
     clientOptions.titleConvo = azureConfig.titleConvo;
     clientOptions.titleModel = azureConfig.titleModel;
+
+    const azureRate = modelName.includes('gpt-4') ? 30 : 17;
+    clientOptions.streamRate = azureConfig.streamRate ?? azureRate;
+
     clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion';
 
     const groupName = modelGroupMap[modelName].group;
@@ -90,6 +94,19 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     apiKey = clientOptions.azure.azureOpenAIApiKey;
   }
 
+  /** @type {undefined | TBaseEndpoint} */
+  const openAIConfig = req.app.locals[EModelEndpoint.openAI];
+
+  if (!isAzureOpenAI && openAIConfig) {
+    clientOptions.streamRate = openAIConfig.streamRate;
+  }
+
+  /** @type {undefined | TBaseEndpoint} */
+  const allConfig = req.app.locals.all;
+  if (allConfig) {
+    clientOptions.streamRate = allConfig.streamRate;
+  }
+
   if (userProvidesKey & !apiKey) {
     throw new Error(
       JSON.stringify({
@@ -1,5 +1,6 @@
 const WebSocket = require('ws');
-const { Message } = require('~/models/Message');
+const { CacheKeys } = require('librechat-data-provider');
+const { getLogStores } = require('~/cache');
 
 /**
  * @param {string[]} voiceIds - Array of voice IDs
@ -104,6 +105,8 @@ function createChunkProcessor(messageId) {
|
||||||
throw new Error('Message ID is required');
|
throw new Error('Message ID is required');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @returns {Promise<{ text: string, isFinished: boolean }[] | string>}
|
* @returns {Promise<{ text: string, isFinished: boolean }[] | string>}
|
||||||
*/
|
*/
|
||||||
|
@ -116,14 +119,17 @@ function createChunkProcessor(messageId) {
|
||||||
return `No change in message after ${MAX_NO_CHANGE_COUNT} attempts`;
|
return `No change in message after ${MAX_NO_CHANGE_COUNT} attempts`;
|
||||||
}
|
}
|
||||||
|
|
||||||
const message = await Message.findOne({ messageId }, 'text unfinished').lean();
|
/** @type { string | { text: string; complete: boolean } } */
|
||||||
|
const message = await messageCache.get(messageId);
|
||||||
|
|
||||||
if (!message || !message.text) {
|
if (!message) {
|
||||||
notFoundCount++;
|
notFoundCount++;
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
const { text, unfinished } = message;
|
const text = typeof message === 'string' ? message : message.text;
|
||||||
|
const complete = typeof message === 'string' ? false : message.complete;
|
||||||
|
|
||||||
if (text === processedText) {
|
if (text === processedText) {
|
||||||
noChangeCount++;
|
noChangeCount++;
|
||||||
}
|
}
|
||||||
|
@ -131,7 +137,7 @@ function createChunkProcessor(messageId) {
|
||||||
const remainingText = text.slice(processedText.length);
|
const remainingText = text.slice(processedText.length);
|
||||||
const chunks = [];
|
const chunks = [];
|
||||||
|
|
||||||
if (unfinished && remainingText.length >= 20) {
|
if (!complete && remainingText.length >= 20) {
|
||||||
const separatorIndex = findLastSeparatorIndex(remainingText);
|
const separatorIndex = findLastSeparatorIndex(remainingText);
|
||||||
if (separatorIndex !== -1) {
|
if (separatorIndex !== -1) {
|
||||||
const chunkText = remainingText.slice(0, separatorIndex + 1);
|
const chunkText = remainingText.slice(0, separatorIndex + 1);
|
||||||
|
@ -141,7 +147,7 @@ function createChunkProcessor(messageId) {
|
||||||
chunks.push({ text: remainingText, isFinished: false });
|
chunks.push({ text: remainingText, isFinished: false });
|
||||||
processedText = text;
|
processedText = text;
|
||||||
}
|
}
|
||||||
} else if (!unfinished && remainingText.trim().length > 0) {
|
} else if (complete && remainingText.trim().length > 0) {
|
||||||
chunks.push({ text: remainingText.trim(), isFinished: true });
|
chunks.push({ text: remainingText.trim(), isFinished: true });
|
||||||
processedText = text;
|
processedText = text;
|
||||||
}
|
}
|
||||||
|
|
|
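With this change the chunk processor polls the message cache instead of Mongo; each call yields either an array of { text, isFinished } chunks or an error string once the retry counters run out. A rough consumer sketch under those assumptions (synthesize and the 500 ms poll interval are hypothetical, not part of this change):

async function drainChunks(messageId, synthesize) {
  // synthesize: hypothetical async TTS call, e.g. (text) => Promise<void>
  const processChunks = createChunkProcessor(messageId);
  for (;;) {
    const result = await processChunks();
    if (typeof result === 'string') {
      // counters exhausted: 'Message not found after N attempts' or
      // 'No change in message after N attempts'
      throw new Error(result);
    }
    for (const { text, isFinished } of result) {
      await synthesize(text);
      if (isFinished) {
        return; // final chunk: the message is complete
      }
    }
    await new Promise((resolve) => setTimeout(resolve, 500)); // poll interval (assumed)
  }
}
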
@ -1,89 +1,145 @@
const { createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
const { Message } = require('~/models/Message');

jest.mock('~/models/Message', () => ({
Message: {
findOne: jest.fn().mockReturnValue({
lean: jest.fn(),
}),
},
}));
jest.mock('keyv');

const globalCache = {};
jest.mock('~/cache/getLogStores', () => {
return jest.fn().mockImplementation(() => {
const EventEmitter = require('events');
const { CacheKeys } = require('librechat-data-provider');

class KeyvMongo extends EventEmitter {
constructor(url = 'mongodb://127.0.0.1:27017', options) {
super();
this.ttlSupport = false;
url = url ?? {};
if (typeof url === 'string') {
url = { url };
}
if (url.uri) {
url = { url: url.uri, ...url };
}
this.opts = {
url,
collection: 'keyv',
...url,
...options,
};
}

get = async (key) => {
return new Promise((resolve) => {
resolve(globalCache[key] || null);
});
};

set = async (key, value) => {
return new Promise((resolve) => {
globalCache[key] = value;
resolve(true);
});
};
}

return new KeyvMongo('', {
namespace: CacheKeys.MESSAGES,
ttl: 0,
});
});
});

describe('processChunks', () => {
let processChunks;
let mockMessageCache;

beforeEach(() => {
jest.resetAllMocks();
mockMessageCache = {
get: jest.fn(),
};
require('~/cache/getLogStores').mockReturnValue(mockMessageCache);
processChunks = createChunkProcessor('message-id');
Message.findOne.mockClear();
Message.findOne().lean.mockClear();
});

it('should return an empty array when the message is not found', async () => {
Message.findOne().lean.mockResolvedValueOnce(null);
mockMessageCache.get.mockResolvedValueOnce(null);

const result = await processChunks();

expect(result).toEqual([]);
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalled();
expect(mockMessageCache.get).toHaveBeenCalledWith('message-id');
});

it('should return an empty array when the message does not have a text property', async () => {
it('should return an error message after MAX_NOT_FOUND_COUNT attempts', async () => {
Message.findOne().lean.mockResolvedValueOnce({ unfinished: true });
mockMessageCache.get.mockResolvedValue(null);

for (let i = 0; i < 6; i++) {
await processChunks();
}
const result = await processChunks();

expect(result).toEqual([]);
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalled();
expect(result).toBe('Message not found after 6 attempts');
});

it('should return chunks for an unfinished message with separators', async () => {
it('should return chunks for an incomplete message with separators', async () => {
const messageText = 'This is a long message. It should be split into chunks. Lol hi mom';
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: true });
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: false });

const result = await processChunks();

expect(result).toEqual([
{ text: 'This is a long message. It should be split into chunks.', isFinished: false },
]);
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalled();
});

it('should return chunks for an unfinished message without separators', async () => {
it('should return chunks for an incomplete message without separators', async () => {
const messageText = 'This is a long message without separators hello there my friend';
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: true });
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: false });

const result = await processChunks();

expect(result).toEqual([{ text: messageText, isFinished: false }]);
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalled();
});

it('should return the remaining text as a chunk for a finished message', async () => {
it('should return the remaining text as a chunk for a complete message', async () => {
const messageText = 'This is a finished message.';
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true });

const result = await processChunks();

expect(result).toEqual([{ text: messageText, isFinished: true }]);
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalled();
});

it('should return an empty array for a finished message with no remaining text', async () => {
it('should return an empty array for a complete message with no remaining text', async () => {
const messageText = 'This is a finished message.';
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true });

await processChunks();
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true });
const result = await processChunks();

expect(result).toEqual([]);
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalledTimes(2);
});

it('should return an error message after MAX_NO_CHANGE_COUNT attempts with no change', async () => {
const messageText = 'This is a message that does not change.';
mockMessageCache.get.mockResolvedValue({ text: messageText, complete: false });

for (let i = 0; i < 11; i++) {
await processChunks();
}
const result = await processChunks();

expect(result).toBe('No change in message after 10 attempts');
});

it('should handle string messages as incomplete', async () => {
const messageText = 'This is a message as a string.';
mockMessageCache.get.mockResolvedValueOnce(messageText);

const result = await processChunks();

expect(result).toEqual([{ text: messageText, isFinished: false }]);
});
});

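One note on the pattern this spec relies on: jest.mock installs the KeyvMongo stand-in at module load, and beforeEach then swaps in a fresh mock via mockReturnValue, so each test dictates what the cache's get() resolves to. A reduced sketch of the same idea (cache contents illustrative):

jest.mock('~/cache/getLogStores'); // auto-mock; the factory above is the elaborate variant
const getLogStores = require('~/cache/getLogStores');

// Per test: hand back a cache whose get() resolves to whatever the case needs.
const fakeCache = {
  get: jest.fn().mockResolvedValue({ text: 'partial text', complete: false }),
};
getLogStores.mockReturnValue(fakeCache);
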
@ -1,17 +1,19 @@
const throttle = require('lodash/throttle');
const {
Time,
CacheKeys,
StepTypes,
ContentTypes,
ToolCallTypes,
// StepStatus,
MessageContentTypes,
AssistantStreamEvents,
Constants,
} = require('librechat-data-provider');
const { retrieveAndProcessFile } = require('~/server/services/Files/process');
const { processRequiredActions } = require('~/server/services/ToolService');
const { saveMessage, updateMessageText } = require('~/models/Message');
const { createOnProgress, sendMessage } = require('~/server/utils');
const { createOnProgress, sendMessage, sleep } = require('~/server/utils');
const { processMessages } = require('~/server/services/Threads');
const { getLogStores } = require('~/cache');
const { logger } = require('~/config');

/**
@ -68,8 +70,8 @@ class StreamRunManager {
this.attachedFileIds = fields.attachedFileIds;
/** @type {undefined | Promise<ChatCompletion>} */
this.visionPromise = fields.visionPromise;
/** @type {boolean} */
this.savedInitialMessage = false;
/** @type {number} */
this.streamRate = fields.streamRate ?? Constants.DEFAULT_STREAM_RATE;

/**
* @type {Object.<AssistantStreamEvents, (event: AssistantStreamEvent) => Promise<void>>}
@ -139,11 +141,11 @@ class StreamRunManager {
return this.intermediateText;
}

/** Saves the initial intermediate message
* @returns {Promise<void>}
*/
async saveInitialMessage() {
return saveMessage(this.req, {
/** Returns the current, intermediate message
* @returns {TMessage}
*/
getIntermediateMessage() {
return {
conversationId: this.finalMessage.conversationId,
messageId: this.finalMessage.messageId,
parentMessageId: this.parentMessageId,
@ -155,7 +157,7 @@ class StreamRunManager {
sender: 'Assistant',
unfinished: true,
error: false,
});
};
}

/* <------------------ Main Event Handlers ------------------> */
@ -347,6 +349,8 @@ class StreamRunManager {
type: ContentTypes.TOOL_CALL,
index,
});

await sleep(this.streamRate);
}
};

@ -444,6 +448,7 @@ class StreamRunManager {
if (content && content.type === MessageContentTypes.TEXT) {
this.intermediateText += content.text.value;
onProgress(content.text.value);
await sleep(this.streamRate);
}
}

@ -589,21 +594,14 @@ class StreamRunManager {
const index = this.getStepIndex(stepKey);
this.orderedRunSteps.set(index, message_creation);

const messageCache = getLogStores(CacheKeys.MESSAGES);
// Create the Factory Function to stream the message
const { onProgress: progressCallback } = createOnProgress({
onProgress: throttle(
() => {
if (!this.savedInitialMessage) {
this.saveInitialMessage();
this.savedInitialMessage = true;
} else {
updateMessageText({
messageId: this.finalMessage.messageId,
text: this.getText(),
});
}
messageCache.set(this.finalMessage.messageId, this.getText(), Time.FIVE_MINUTES);
},
2000,
3000,
{ trailing: false },
),
});

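The added sleep(this.streamRate) awaits are the entire pacing mechanism: a short delay between emitted deltas caps how quickly tokens reach the client. A self-contained sketch of the idea (streamDeltas and its inputs are illustrative, not part of this change):

const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));

async function streamDeltas(deltas, onProgress, streamRate = 1) {
  for (const delta of deltas) {
    onProgress(delta); // forward the token/text delta to the client
    await sleep(streamRate); // pause streamRate ms before the next one
  }
}

// e.g. roughly 25 deltas per second at streamRate = 40:
// await streamDeltas(['Hel', 'lo', '!'], (d) => process.stdout.write(d), 40);
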
@ -51,6 +51,7 @@ function assistantsConfigSetup(config, assistantsEndpoint, prevConfig = {}) {
excludedIds: parsedConfig.excludedIds,
privateAssistants: parsedConfig.privateAssistants,
timeoutMs: parsedConfig.timeoutMs,
streamRate: parsedConfig.streamRate,
};
}

@ -465,6 +465,12 @@
* @memberof typedefs
*/

/**
* @exports TBaseEndpoint
* @typedef {import('librechat-data-provider').TBaseEndpoint} TBaseEndpoint
* @memberof typedefs
*/

/**
* @exports TEndpoint
* @typedef {import('librechat-data-provider').TEndpoint} TEndpoint

package-lock.json (generated)
@ -29437,7 +29437,7 @@
},
"packages/data-provider": {
"name": "librechat-data-provider",
"version": "0.7.1",
"version": "0.7.2",
"license": "ISC",
"dependencies": {
"@types/js-yaml": "^4.0.9",

@ -1,6 +1,6 @@
{
"name": "librechat-data-provider",
"version": "0.7.1",
"version": "0.7.2",
"description": "data services for librechat apps",
"main": "dist/index.js",
"module": "dist/index.es.js",

@ -136,7 +136,14 @@ export const defaultAssistantsVersion = {
[EModelEndpoint.azureAssistants]: 1,
};

export const assistantEndpointSchema = z.object({
export const baseEndpointSchema = z.object({
streamRate: z.number().optional(),
});

export type TBaseEndpoint = z.infer<typeof baseEndpointSchema>;

export const assistantEndpointSchema = baseEndpointSchema.merge(
z.object({
/* assistants specific */
disableBuilder: z.boolean().optional(),
pollIntervalMs: z.number().optional(),
@ -170,11 +177,13 @@ export const assistantEndpointSchema = z.object({
titleMethod: z.union([z.literal('completion'), z.literal('functions')]).optional(),
titleModel: z.string().optional(),
headers: z.record(z.any()).optional(),
});
}),
);

export type TAssistantEndpoint = z.infer<typeof assistantEndpointSchema>;

export const endpointSchema = z.object({
export const endpointSchema = baseEndpointSchema.merge(
z.object({
name: z.string().refine((value) => !eModelEndpointSchema.safeParse(value).success, {
message: `Value cannot be one of the default endpoint (EModelEndpoint) values: ${Object.values(
EModelEndpoint,
@ -200,7 +209,8 @@ export const endpointSchema = z.object({
customOrder: z.number().optional(),
directEndpoint: z.boolean().optional(),
titleMessageRole: z.string().optional(),
});
}),
);

export type TEndpoint = z.infer<typeof endpointSchema>;

@ -213,6 +223,7 @@ export const azureEndpointSchema = z
.and(
endpointSchema
.pick({
streamRate: true,
titleConvo: true,
titleMethod: true,
titleModel: true,
@ -426,10 +437,15 @@ export const configSchema = z.object({
modelSpecs: specsConfigSchema.optional(),
endpoints: z
.object({
all: baseEndpointSchema.optional(),
[EModelEndpoint.openAI]: baseEndpointSchema.optional(),
[EModelEndpoint.google]: baseEndpointSchema.optional(),
[EModelEndpoint.anthropic]: baseEndpointSchema.optional(),
[EModelEndpoint.gptPlugins]: baseEndpointSchema.optional(),
[EModelEndpoint.azureOpenAI]: azureEndpointSchema.optional(),
[EModelEndpoint.azureAssistants]: assistantEndpointSchema.optional(),
[EModelEndpoint.assistants]: assistantEndpointSchema.optional(),
custom: z.array(endpointSchema.partial()).optional(),
[EModelEndpoint.custom]: z.array(endpointSchema.partial()).optional(),
})
.strict()
.refine((data) => Object.keys(data).length > 0, {
@ -657,6 +673,16 @@ export enum InfiniteCollections {
SHARED_LINKS = 'sharedLinks',
}

/**
* Enum for time intervals
*/
export enum Time {
THIRTY_MINUTES = 1800000,
TEN_MINUTES = 600000,
FIVE_MINUTES = 300000,
TWO_MINUTES = 120000,
}

/**
* Enum for cache keys.
*/
@ -727,6 +753,10 @@ export enum CacheKeys {
* Key for the cached audio run Ids.
*/
AUDIO_RUNS = 'audioRuns',
/**
* Key for in-progress messages.
*/
MESSAGES = 'messages',
}

/**
@ -911,6 +941,8 @@ export enum Constants {
COMMON_DIVIDER = '__',
/** Max length for commands */
COMMANDS_MAX_LENGTH = 56,
/** Default Stream Rate (ms) */
DEFAULT_STREAM_RATE = 1,
}

export enum LocalStorageKeys {
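
Because every endpoint schema now merges baseEndpointSchema, a streamRate key validates anywhere an endpoint block appears, including under all. A small runnable sketch of the merge semantics (schema and endpoint names local to the example):

const { z } = require('zod'); // zod already backs the schemas above

const baseEndpointSchema = z.object({
  streamRate: z.number().optional(),
});

// merge() returns a new object schema carrying both key sets:
const customEndpointSchema = baseEndpointSchema.merge(
  z.object({ name: z.string() }),
);

console.log(customEndpointSchema.parse({ name: 'myEndpoint', streamRate: 25 }));
// -> { name: 'myEndpoint', streamRate: 25 }
console.log(customEndpointSchema.safeParse({ name: 1 }).success); // false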