🚀 refactor: Enhance Custom Endpoints, Message Logic, and Payload Handling (#2895)

* chore: use node-fetch for the OpenAIClient `fetch` option so AbortController usage does not crash on the Bun runtime

* chore: variable order

* fix(useSSE): prevent the finalHandler call in abortConversation from updating messages/conversation after the user has navigated away

* chore: params order

* refactor: organize intermediate message logic and ensure correct variables are passed

* fix: add stt and tts routes before the upload limiters to prevent bans

* fix(abortRun): temporary fix that deletes unfinished messages to avoid parent/child relationship issues in the message thread

* refactor: Update AnthropicClient to use node-fetch for fetch key and add proxy support

* fix(gptPlugins): ensure parentMessageId/messageId relationship is maintained

* feat(BaseClient): custom fetch function to analyze/edit payloads just before sending (also prevents abortController crash on Bun runtime)

* feat: `directEndpoint` and `titleMessageRole` custom endpoint options

* chore: Bump version to 0.6.6 in data-provider package.json
Danny Avila authored on 2024-05-28 14:52:12 -04:00, committed by GitHub
parent 0ee060d730
commit 40685f6eb4
17 changed files with 137 additions and 29 deletions


@@ -1,4 +1,5 @@
 const Anthropic = require('@anthropic-ai/sdk');
+const { HttpsProxyAgent } = require('https-proxy-agent');
 const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const {
   getResponseSender,
@@ -123,9 +124,14 @@ class AnthropicClient extends BaseClient {
   getClient() {
     /** @type {Anthropic.default.RequestOptions} */
     const options = {
+      fetch: this.fetch,
       apiKey: this.apiKey,
     };
 
+    if (this.options.proxy) {
+      options.httpAgent = new HttpsProxyAgent(this.options.proxy);
+    }
+
     if (this.options.reverseProxyUrl) {
       options.baseURL = this.options.reverseProxyUrl;
     }


@@ -1,4 +1,5 @@
 const crypto = require('crypto');
+const fetch = require('node-fetch');
 const { supportsBalanceCheck, Constants } = require('librechat-data-provider');
 const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
 const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
@@ -17,6 +18,7 @@ class BaseClient {
       month: 'long',
       day: 'numeric',
     });
+    this.fetch = this.fetch.bind(this);
   }
 
   setOptions() {
@@ -54,6 +56,22 @@ class BaseClient {
     });
   }
 
+  /**
+   * Makes an HTTP request and logs the process.
+   *
+   * @param {RequestInfo} url - The URL to make the request to. Can be a string or a Request object.
+   * @param {RequestInit} [init] - Optional init options for the request.
+   * @returns {Promise<Response>} - A promise that resolves to the response of the fetch request.
+   */
+  async fetch(_url, init) {
+    let url = _url;
+    if (this.options.directEndpoint) {
+      url = this.options.reverseProxyUrl;
+    }
+    logger.debug(`Making request to ${url}`);
+    return await fetch(url, init);
+  }
+
   getBuildMessagesOptions() {
     throw new Error('Subclasses must implement getBuildMessagesOptions');
   }
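The new `BaseClient.fetch` above is the hook behind the commit's "analyze/edit payloads just before sending": any client that passes `this.fetch` to its SDK (as the Anthropic and OpenAI diffs do) routes every outgoing request through one place, where `directEndpoint` can also redirect the request to the reverse proxy URL. Below is a minimal, hypothetical sketch of that idea; it is not code from this commit, and `makeCustomFetch` and its `dropParams` handling are illustrative only.

```js
const fetch = require('node-fetch');

/**
 * Hypothetical wrapper in the spirit of BaseClient.fetch: reroute the request when
 * `directEndpoint` is set, and edit the JSON payload right before it is sent.
 */
function makeCustomFetch(options) {
  return async function customFetch(requestUrl, init = {}) {
    // With `directEndpoint`, every request goes to the configured reverse proxy URL as-is.
    const url = options.directEndpoint ? options.reverseProxyUrl : requestUrl;

    let body = init.body;
    if (typeof body === 'string' && Array.isArray(options.dropParams)) {
      const payload = JSON.parse(body);
      for (const param of options.dropParams) {
        delete payload[param]; // edit the payload just before sending
      }
      body = JSON.stringify(payload);
    }

    return fetch(url, { ...init, body });
  };
}

// Usage sketch: pass it to an SDK client the same way the diffs pass `this.fetch`, e.g.
// new OpenAI({ apiKey, fetch: makeCustomFetch({ directEndpoint: true, reverseProxyUrl, dropParams: ['user'] }) });
```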


@@ -589,7 +589,7 @@ class OpenAIClient extends BaseClient {
     let streamResult = null;
     this.modelOptions.user = this.user;
     const invalidBaseUrl = this.completionsUrl && extractBaseURL(this.completionsUrl) === null;
-    const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion || typeof Bun !== 'undefined');
+    const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion);
     if (typeof opts.onProgress === 'function' && useOldMethod) {
       const completionResult = await this.getCompletion(
         payload,
@@ -829,7 +829,7 @@ class OpenAIClient extends BaseClient {
     const instructionsPayload = [
       {
-        role: 'system',
+        role: this.options.titleMessageRole ?? 'system',
         content: `Please generate ${titleInstruction}
 
 ${convo}
@@ -1134,6 +1134,7 @@ ${convo}
     let chatCompletion;
     /** @type {OpenAI} */
     const openai = new OpenAI({
+      fetch: this.fetch,
      apiKey: this.apiKey,
       ...opts,
     });


@@ -268,7 +268,7 @@ class PluginsClient extends OpenAIClient {
     if (opts.progressCallback) {
       opts.onProgress = opts.progressCallback.call(null, {
         ...(opts.progressOptions ?? {}),
-        parentMessageId: opts.progressOptions?.parentMessageId ?? userMessage.messageId,
+        parentMessageId: userMessage.messageId,
         messageId: responseMessageId,
       });
     }


@@ -129,6 +129,14 @@ module.exports = {
       throw new Error('Failed to save message.');
     }
   },
+  async updateMessageText({ messageId, text }) {
+    try {
+      await Message.updateOne({ messageId }, { text });
+    } catch (err) {
+      logger.error('Error updating message text:', err);
+      throw new Error('Failed to update message text.');
+    }
+  },
   async updateMessage(message) {
     try {
       const { messageId, ...update } = message;


@@ -496,7 +496,7 @@ const chatV2 = async (req, res) => {
       handlers,
       thread_id,
       attachedFileIds,
-      parentMessageId,
+      parentMessageId: userMessageId,
       responseMessage: openai.responseMessage,
       // streamOptions: {


@@ -1,6 +1,7 @@
 const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider');
 const { initializeClient } = require('~/server/services/Endpoints/assistants');
 const { checkMessageGaps, recordUsage } = require('~/server/services/Threads');
+const { deleteMessages } = require('~/models/Message');
 const { getConvo } = require('~/models/Conversation');
 const getLogStores = require('~/cache/getLogStores');
 const { sendMessage } = require('~/server/utils');
@@ -66,13 +67,19 @@ async function abortRun(req, res) {
     logger.error('[abortRun] Error fetching or processing run', error);
   }
 
+  /* TODO: a reconciling strategy between the existing intermediate message would be more optimal than deleting it */
+  await deleteMessages({
+    user: req.user.id,
+    unfinished: true,
+    conversationId,
+  });
   runMessages = await checkMessageGaps({
     openai,
+    run_id,
     endpoint,
     thread_id,
-    run_id,
+    latestMessageId,
     conversationId,
-    latestMessageId,
   });
 
   const finalEvent = {


@@ -106,7 +106,11 @@ router.post(
     const pluginMap = new Map();
     const onAgentAction = async (action, runId) => {
       pluginMap.set(runId, action.tool);
-      sendIntermediateMessage(res, { plugins });
+      sendIntermediateMessage(res, {
+        plugins,
+        parentMessageId: userMessage.messageId,
+        messageId: responseMessageId,
+      });
     };
 
     const onToolStart = async (tool, input, runId, parentRunId) => {
@@ -124,7 +128,11 @@ router.post(
       }
       const extraTokens = ':::plugin:::\n';
       plugins.push(latestPlugin);
-      sendIntermediateMessage(res, { plugins }, extraTokens);
+      sendIntermediateMessage(
+        res,
+        { plugins, parentMessageId: userMessage.messageId, messageId: responseMessageId },
+        extraTokens,
+      );
     };
 
     const onToolEnd = async (output, runId) => {
@@ -142,7 +150,11 @@ router.post(
     const onChainEnd = () => {
       saveMessage({ ...userMessage, user });
-      sendIntermediateMessage(res, { plugins });
+      sendIntermediateMessage(res, {
+        plugins,
+        parentMessageId: userMessage.messageId,
+        messageId: responseMessageId,
+      });
     };
 
     const getAbortData = () => ({


@@ -110,7 +110,11 @@ router.post(
       if (!start) {
         saveMessage({ ...userMessage, user });
       }
-      sendIntermediateMessage(res, { plugin });
+      sendIntermediateMessage(res, {
+        plugin,
+        parentMessageId: userMessage.messageId,
+        messageId: responseMessageId,
+      });
       // logger.debug('PLUGIN ACTION', formattedAction);
     };
 
@@ -119,7 +123,11 @@ router.post(
       plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
       plugin.loading = false;
       saveMessage({ ...userMessage, user });
-      sendIntermediateMessage(res, { plugin });
+      sendIntermediateMessage(res, {
+        plugin,
+        parentMessageId: userMessage.messageId,
+        messageId: responseMessageId,
+      });
       // logger.debug('CHAIN END', plugin.outputs);
     };


@@ -14,6 +14,10 @@ const initialize = async () => {
   router.use(checkBan);
   router.use(uaParser);
 
+  /* Important: stt/tts routes must be added before the upload limiters */
+  router.use('/stt', stt);
+  router.use('/tts', tts);
+
   const upload = await createMulterInstance();
   const { fileUploadIpLimiter, fileUploadUserLimiter } = createFileLimiters();
   router.post('*', fileUploadIpLimiter, fileUploadUserLimiter);
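The registration order above is the whole fix: Express runs router layers in the order they are added, so a request fully handled by the stt/tts routes never reaches the wildcard upload limiters registered after them. A small, self-contained illustration of the same principle follows; the route names and handlers are made up, not the repo's actual router.

```js
const express = require('express');

const router = express.Router();

// Registered first: a POST to /speech is answered here and never reaches the limiter below.
router.post('/speech', (req, res) => res.json({ limited: false }));

// Wildcard limiter: runs for any POST request that was not already handled above.
router.post('*', (req, res, next) => {
  res.set('X-Upload-Limiter', 'applied');
  next();
});

// Registered after the limiter, so requests to it pass through the limiter first.
router.post('/files', (req, res) => res.json({ limited: true }));

const app = express();
app.use('/api', router);
```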


@@ -112,6 +112,8 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     modelDisplayLabel: endpointConfig.modelDisplayLabel,
     titleMethod: endpointConfig.titleMethod ?? 'completion',
     contextStrategy: endpointConfig.summarize ? 'summarize' : null,
+    directEndpoint: endpointConfig.directEndpoint,
+    titleMessageRole: endpointConfig.titleMessageRole,
     endpointTokenConfig,
   };


@@ -9,9 +9,9 @@ const {
 } = require('librechat-data-provider');
 const { retrieveAndProcessFile } = require('~/server/services/Files/process');
 const { processRequiredActions } = require('~/server/services/ToolService');
+const { saveMessage, updateMessageText } = require('~/models/Message');
 const { createOnProgress, sendMessage } = require('~/server/utils');
 const { processMessages } = require('~/server/services/Threads');
-const { saveMessage } = require('~/models');
 const { logger } = require('~/config');
 
 /**
@@ -68,6 +68,8 @@ class StreamRunManager {
     this.attachedFileIds = fields.attachedFileIds;
     /** @type {undefined | Promise<ChatCompletion>} */
     this.visionPromise = fields.visionPromise;
+    /** @type {boolean} */
+    this.savedInitialMessage = false;
 
     /**
      * @type {Object.<AssistantStreamEvents, (event: AssistantStreamEvent) => Promise<void>>}
@@ -129,6 +131,33 @@ class StreamRunManager {
     sendMessage(this.res, contentData);
   }
 
+  /* <------------------ Misc. Helpers ------------------> */
+
+  /** Returns the latest intermediate text
+   * @returns {string}
+   */
+  getText() {
+    return this.intermediateText;
+  }
+
+  /** Saves the initial intermediate message
+   * @returns {Promise<void>}
+   */
+  async saveInitialMessage() {
+    return saveMessage({
+      conversationId: this.finalMessage.conversationId,
+      messageId: this.finalMessage.messageId,
+      parentMessageId: this.parentMessageId,
+      model: this.req.body.assistant_id,
+      endpoint: this.req.body.endpoint,
+      isCreatedByUser: false,
+      user: this.req.user.id,
+      text: this.getText(),
+      sender: 'Assistant',
+      unfinished: true,
+      error: false,
+    });
+  }
+
   /* <------------------ Main Event Handlers ------------------> */
 
   /**
@@ -530,23 +559,20 @@ class StreamRunManager {
     const stepKey = message_creation.message_id;
     const index = this.getStepIndex(stepKey);
     this.orderedRunSteps.set(index, message_creation);
-    const getText = () => this.intermediateText;
     // Create the Factory Function to stream the message
     const { onProgress: progressCallback } = createOnProgress({
       onProgress: throttle(
         () => {
-          const text = getText();
-          saveMessage({
-            messageId: this.finalMessage.messageId,
-            conversationId: this.finalMessage.conversationId,
-            parentMessageId: this.parentMessageId,
-            model: this.req.body.model,
-            user: this.req.user.id,
-            sender: 'Assistant',
-            unfinished: true,
-            error: false,
-            text,
-          });
+          if (!this.savedInitialMessage) {
+            this.saveInitialMessage();
+            this.savedInitialMessage = true;
+          } else {
+            updateMessageText({
+              messageId: this.finalMessage.messageId,
+              text: this.getText(),
+            });
+          }
         },
         2000,
         { trailing: false },

bun.lockb (binary file not shown)


@@ -6,9 +6,9 @@ import store from '~/store';
 function usePauseGlobalAudio(index = 0) {
   /* Global Audio Variables */
   const setAudioRunId = useSetRecoilState(store.audioRunFamily(index));
-  const setGlobalIsPlaying = useSetRecoilState(store.globalAudioPlayingFamily(index));
   const setIsGlobalAudioFetching = useSetRecoilState(store.globalAudioFetchingFamily(index));
   const [globalAudioURL, setGlobalAudioURL] = useRecoilState(store.globalAudioURLFamily(index));
+  const setGlobalIsPlaying = useSetRecoilState(store.globalAudioPlayingFamily(index));
 
   const pauseGlobalAudio = useCallback(() => {
     if (globalAudioURL) {


@@ -282,6 +282,12 @@ export default function useSSE(submission: TSubmission | null, index = 0) {
       setShowStopButton(false);
       setCompleted((prev) => new Set(prev.add(submission?.initialResponse?.messageId)));
 
+      const currentMessages = getMessages();
+      // Early return if messages are empty; i.e., the user navigated away
+      if (!currentMessages?.length) {
+        return setIsSubmitting(false);
+      }
+
       // update the messages; if assistants endpoint, client doesn't receive responseMessage
       if (runMessages) {
         setMessages([...runMessages]);
@@ -323,7 +329,15 @@ export default function useSSE(submission: TSubmission | null, index = 0) {
       setIsSubmitting(false);
     },
-    [genTitle, queryClient, setMessages, setConversation, setIsSubmitting, setShowStopButton],
+    [
+      genTitle,
+      queryClient,
+      getMessages,
+      setMessages,
+      setConversation,
+      setIsSubmitting,
+      setShowStopButton,
+    ],
   );
 
   const errorHandler = useCallback(


@@ -1,6 +1,6 @@
 {
   "name": "librechat-data-provider",
-  "version": "0.6.5",
+  "version": "0.6.6",
   "description": "data services for librechat apps",
   "main": "dist/index.js",
   "module": "dist/index.es.js",


@@ -198,6 +198,8 @@ export const endpointSchema = z.object({
   addParams: z.record(z.any()).optional(),
   dropParams: z.array(z.string()).optional(),
   customOrder: z.number().optional(),
+  directEndpoint: z.boolean().optional(),
+  titleMessageRole: z.string().optional(),
 });
 
 export type TEndpoint = z.infer<typeof endpointSchema>;
@@ -747,7 +749,7 @@ export enum Constants {
   /** Key for the app's version. */
   VERSION = 'v0.7.2',
   /** Key for the Custom Config's version (librechat.yaml). */
-  CONFIG_VERSION = '1.1.2',
+  CONFIG_VERSION = '1.1.3',
   /** Standard value for the first message's `parentMessageId` value, to indicate no parent exists. */
   NO_PARENT = '00000000-0000-0000-0000-000000000000',
   /** Fixed, encoded domain length for Azure OpenAI Assistants Function name parsing. */
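With the two optional fields added to `endpointSchema` above and wired through the custom endpoint initializer, a custom endpoint definition can opt into both new behaviors. The object below is a hypothetical entry (normally written in librechat.yaml) that the updated schema would accept; only `directEndpoint` and `titleMessageRole` come from this commit, the other fields and all values are illustrative.

```js
const exampleEndpoint = {
  name: 'ExampleProvider',
  apiKey: '${EXAMPLE_PROVIDER_KEY}',
  baseURL: 'https://api.example.com/v1/chat/completions',
  // New in this commit:
  directEndpoint: true, // requests are sent to baseURL exactly as configured (see BaseClient.fetch)
  titleMessageRole: 'user', // role used for the title-generation prompt instead of 'system'
};

// endpointSchema.parse(exampleEndpoint); // passes with the new optional fields
```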