🛠️ fix: Minor Fixes in Message, Ask/EditController, OpenAIClient, and countTokens (#1463)

* fix(Message): avoid overwriting unprovided properties

* fix(OpenAIClient): return intermediateReply on user abort

* fix(AskController): do not send/save final message if abort was triggered

* fix(countTokens): avoid fetching remote registry and exclusively use cl100k_base or p50k_base weights for token counting

* refactor(Message/messageSchema): rely on messageSchema for default values when saving messages

* fix(EditController): do not send/save final message if abort was triggered

* fix(config/helpers): fix module resolution error
This commit is contained in:
Danny Avila 2023-12-30 14:34:32 -05:00 committed by GitHub
parent e4c555f95a
commit 431fc6284f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 37 additions and 31 deletions

View file

@ -847,7 +847,7 @@ ${convo}
err?.message?.includes('abort') || err?.message?.includes('abort') ||
(err instanceof OpenAI.APIError && err?.message?.includes('abort')) (err instanceof OpenAI.APIError && err?.message?.includes('abort'))
) { ) {
return ''; return intermediateReply;
} }
if ( if (
err?.message?.includes( err?.message?.includes(

View file

@ -15,16 +15,16 @@ module.exports = {
parentMessageId, parentMessageId,
sender, sender,
text, text,
isCreatedByUser = false, isCreatedByUser,
error, error,
unfinished, unfinished,
files, files,
isEdited = false, isEdited,
finish_reason = null, finish_reason,
tokenCount = null, tokenCount,
plugin = null, plugin,
plugins = null, plugins,
model = null, model,
}) { }) {
try { try {
const validConvoId = idSchema.safeParse(conversationId); const validConvoId = idSchema.safeParse(conversationId);

View file

@ -21,6 +21,7 @@ const messageSchema = mongoose.Schema(
}, },
model: { model: {
type: String, type: String,
default: null,
}, },
conversationSignature: { conversationSignature: {
type: String, type: String,

View file

@ -118,16 +118,19 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
delete userMessage.image_urls; delete userMessage.image_urls;
} }
sendMessage(res, { if (!abortController.signal.aborted) {
title: await getConvoTitle(user, conversationId), sendMessage(res, {
final: true, title: await getConvoTitle(user, conversationId),
conversation: await getConvo(user, conversationId), final: true,
requestMessage: userMessage, conversation: await getConvo(user, conversationId),
responseMessage: response, requestMessage: userMessage,
}); responseMessage: response,
res.end(); });
res.end();
await saveMessage({ ...response, user });
}
await saveMessage({ ...response, user });
await saveMessage(userMessage); await saveMessage(userMessage);
if (addTitle && parentMessageId === '00000000-0000-0000-0000-000000000000' && newConvo) { if (addTitle && parentMessageId === '00000000-0000-0000-0000-000000000000' && newConvo) {

View file

@ -112,16 +112,18 @@ const EditController = async (req, res, next, initializeClient) => {
response = { ...response, ...metadata }; response = { ...response, ...metadata };
} }
await saveMessage({ ...response, user }); if (!abortController.signal.aborted) {
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
final: true,
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
sendMessage(res, { await saveMessage({ ...response, user });
title: await getConvoTitle(user, conversationId), }
final: true,
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
} catch (error) { } catch (error) {
const partialText = getPartialText(); const partialText = getPartialText();
handleAbortError(res, req, error, { handleAbortError(res, req, error, {

View file

@ -1,13 +1,12 @@
const { load } = require('tiktoken/load');
const { Tiktoken } = require('tiktoken/lite'); const { Tiktoken } = require('tiktoken/lite');
const registry = require('tiktoken/registry.json'); const p50k_base = require('tiktoken/encoders/p50k_base.json');
const models = require('tiktoken/model_to_encoding.json'); const cl100k_base = require('tiktoken/encoders/cl100k_base.json');
const logger = require('~/config/winston'); const logger = require('~/config/winston');
const countTokens = async (text = '', modelName = 'gpt-3.5-turbo') => { const countTokens = async (text = '', modelName = 'gpt-3.5-turbo') => {
let encoder = null; let encoder = null;
try { try {
const model = await load(registry[models[modelName]]); const model = modelName.includes('text-davinci-003') ? p50k_base : cl100k_base;
encoder = new Tiktoken(model.bpe_ranks, model.special_tokens, model.pat_str); encoder = new Tiktoken(model.bpe_ranks, model.special_tokens, model.pat_str);
const tokens = encoder.encode(text); const tokens = encoder.encode(text);
encoder.free(); encoder.free();

View file

@ -6,7 +6,8 @@ const fs = require('fs');
const path = require('path'); const path = require('path');
const readline = require('readline'); const readline = require('readline');
const { execSync } = require('child_process'); const { execSync } = require('child_process');
const { connectDb } = require('@librechat/backend/lib/db'); require('module-alias')({ base: path.resolve(__dirname, '..', 'api') });
const connectDb = require('~/lib/db/connectDb');
const askQuestion = (query) => { const askQuestion = (query) => {
const rl = readline.createInterface({ const rl = readline.createInterface({