🚀 feat: Assistants Streaming (#2159)

* chore: bump openai to 4.29.0 and npm audit fix

* chore: remove unnecessary stream field from ContentData

* feat: new enum and types for AssistantStreamEvent

* refactor(AssistantService): remove stream field and add conversationId to text ContentData
> - return `finalMessage` and `text` on run completion
> - move `processMessages` to services/Threads to avoid circular dependencies with new stream handling
> - refactor(processMessages/retrieveAndProcessFile): add new `client` field to differentiate new RunClient type

* WIP: new assistants stream handling

* chore: store messages in StreamRunManager

* chore: add additional typedefs

* fix: pass req and openai to StreamRunManager

* fix(AssistantService): pass openai as client to `retrieveAndProcessFile`

* WIP: streaming tool i/o, handle in_progress and completed run steps

* feat(assistants): process required actions with streaming enabled

* chore: condense early return check for useSSE useEffect

* chore: remove unnecessary comments and only handle completed tool calls when not function

* feat: add TTL for assistants run abort cacheKey

* feat: abort stream runs

* fix(assistants): render streaming cursor

* fix(assistants): hide edit icon as functionality is not supported

* fix(textArea): handle pasting edge cases; first, when onChange events wouldn't fire; second, when textarea wouldn't resize

* chore: memoize Conversations

* chore(useTextarea): reverse args order

* fix: load default capabilities when Azure is configured to support assistants but the `assistants` endpoint is not configured

* fix(AssistantSelect): update the form's assistant model when an assistant is selected

* fix(actions): handle Azure's strict validation for function names to fix CRUD for actions

* chore: remove content data debug log as it fires in rapid succession

* feat: improve UX for assistant errors mid-request

* feat: add tool call localizations and replace any domain separators in Azure action names

* refactor(chat): error out tool calls without outputs during handleError

* fix(ToolService): handle domain separators allowing Azure use of actions

* refactor(StreamRunManager): types and throw Error if tool submission fails
Danny Avila 2024-03-21 22:42:25 -04:00 committed by GitHub
parent ed64c76053
commit f427ad792a
39 changed files with 1503 additions and 330 deletions
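
At a high level, the change replaces run polling with the openai SDK's event stream: `createAndStream` yields typed `AssistantStreamEvents`, and the new `StreamRunManager` dispatches each one to a handler. Below is a minimal sketch of that dispatch pattern, assuming openai@4.29.0 (which this PR bumps to) and an already-initialized client; the handler bodies are illustrative, not the route's actual logic.

```js
// Minimal sketch of the event-dispatch pattern StreamRunManager is built on.
// Assumes openai@4.29.0, where createAndStream returns an async-iterable
// stream of { event, data } objects.
async function streamRun(openai, thread_id, body) {
  const handlers = {
    'thread.message.delta': async (event) => {
      const part = event.data.delta.content?.[0];
      if (part?.type === 'text') {
        process.stdout.write(part.text?.value ?? '');
      }
    },
    'thread.run.completed': async (event) => {
      console.log('\nrun completed:', event.data.id);
    },
  };

  const stream = openai.beta.threads.runs.createAndStream(thread_id, body);
  for await (const event of stream) {
    const handler = handlers[event.event];
    if (handler) {
      await handler(event);
    } else {
      // Mirrors StreamRunManager.handleEvent: unknown events are logged, not fatal.
      console.warn(`Unhandled event type: ${event.event}`);
    }
  }
}
```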

View file

@@ -4,9 +4,10 @@ const { checkMessageGaps, recordUsage } = require('~/server/services/Threads');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { sendMessage } = require('~/server/utils');
// const spendTokens = require('~/models/spendTokens');
const { logger } = require('~/config');
const three_minutes = 1000 * 60 * 3;
async function abortRun(req, res) {
res.setHeader('Content-Type', 'application/json');
const { abortKey } = req.body;
@@ -40,7 +41,7 @@ async function abortRun(req, res) {
const { openai } = await initializeClient({ req, res });
try {
await cache.set(cacheKey, 'cancelled');
await cache.set(cacheKey, 'cancelled', three_minutes);
const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
logger.debug('[abortRun] Cancelled run:', cancelledRun);
} catch (error) {
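
The abort hunk above adds a TTL so a `cancelled` flag cannot outlive its run. `getLogStores` is backed by Keyv-style caches whose `set(key, value, ttl)` takes a TTL in milliseconds; here is a small sketch of the expiring-flag pattern, using a plain in-memory Keyv as a stand-in for the real store.

```js
// Sketch of the expiring abort flag; an in-memory Keyv stands in for the
// cache returned by getLogStores, and ttl is in milliseconds.
const Keyv = require('keyv');

const cache = new Keyv();
const three_minutes = 1000 * 60 * 3;

async function markCancelled(cacheKey) {
  await cache.set(cacheKey, 'cancelled', three_minutes);
}

async function wasCancelled(cacheKey) {
  // Once the TTL lapses, get() resolves to undefined and the key is gone,
  // so a stale abort no longer shadows a new run under the same key.
  return (await cache.get(cacheKey)) === 'cancelled';
}
```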

View file

@@ -2,9 +2,9 @@ const { v4 } = require('uuid');
const express = require('express');
const { actionDelimiter } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/assistants');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { updateAction, getActions, deleteAction } = require('~/models/Action');
const { updateAssistant, getAssistant } = require('~/models/Assistant');
const { encryptMetadata } = require('~/server/services/ActionService');
const { logger } = require('~/config');
const router = express.Router();
@@ -44,7 +44,10 @@ router.post('/:assistant_id', async (req, res) => {
let metadata = encryptMetadata(_metadata);
const { domain } = metadata;
let { domain } = metadata;
/* Azure doesn't support periods in function names */
domain = domainParser(req, domain, true);
if (!domain) {
return res.status(400).json({ message: 'No domain provided' });
}
@@ -141,9 +144,10 @@ router.post('/:assistant_id', async (req, res) => {
* @param {string} req.params.action_id - The ID of the action to delete.
* @returns {Object} 200 - success response - application/json
*/
router.delete('/:assistant_id/:action_id', async (req, res) => {
router.delete('/:assistant_id/:action_id/:model', async (req, res) => {
try {
const { assistant_id, action_id } = req.params;
const { assistant_id, action_id, model } = req.params;
req.body.model = model;
/** @type {{ openai: OpenAI }} */
const { openai } = await initializeClient({ req, res });
@@ -167,6 +171,8 @@ router.delete('/:assistant_id/:action_id', async (req, res) => {
return true;
});
domain = domainParser(req, domain, true);
const updatedTools = tools.filter(
(tool) => !(tool.function && tool.function.name.includes(domain)),
);

View file

@@ -4,8 +4,10 @@ const {
Constants,
RunStatus,
CacheKeys,
ContentTypes,
EModelEndpoint,
ViolationTypes,
AssistantStreamEvents,
} = require('librechat-data-provider');
const {
initThread,
@@ -18,8 +20,8 @@ const {
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const { addTitle, initializeClient } = require('~/server/services/Endpoints/assistants');
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { getTransactions } = require('~/models/Transaction');
const { createRun } = require('~/server/services/Runs');
const checkBalance = require('~/models/checkBalance');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
@@ -38,6 +40,8 @@ const {
router.post('/abort', handleAbort());
const ten_minutes = 1000 * 60 * 10;
/**
* @route POST /
* @desc Chat with an assistant
@@ -147,7 +151,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
return sendResponse(res, messageData, defaultErrorMessage);
}
await sleep(3000);
await sleep(2000);
try {
const status = await cache.get(cacheKey);
@@ -187,6 +191,42 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
latestMessageId: responseMessageId,
});
const errorContentPart = {
text: {
value:
error?.message ?? 'There was an error processing your request. Please try again later.',
},
type: ContentTypes.ERROR,
};
if (!Array.isArray(runMessages[runMessages.length - 1]?.content)) {
runMessages[runMessages.length - 1].content = [errorContentPart];
} else {
const contentParts = runMessages[runMessages.length - 1].content;
for (let i = 0; i < contentParts.length; i++) {
const currentPart = contentParts[i];
/** @type {CodeToolCall | RetrievalToolCall | FunctionToolCall | undefined} */
const toolCall = currentPart?.[ContentTypes.TOOL_CALL];
if (
toolCall &&
toolCall?.function &&
!(toolCall?.function?.output || toolCall?.function?.output?.length)
) {
contentParts[i] = {
...currentPart,
[ContentTypes.TOOL_CALL]: {
...toolCall,
function: {
...toolCall.function,
output: 'error processing tool',
},
},
};
}
}
runMessages[runMessages.length - 1].content.push(errorContentPart);
}
finalEvent = {
title: 'New Chat',
final: true,
@@ -358,53 +398,107 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
body.instructions = instructions;
}
/* NOTE:
* By default, a Run will use the model and tools configuration specified in Assistant object,
* but you can override most of these when creating the Run for added flexibility:
*/
const run = await createRun({
openai,
thread_id,
body,
});
run_id = run.id;
await cache.set(cacheKey, `${thread_id}:${run_id}`);
sendMessage(res, {
sync: true,
conversationId,
// messages: previousMessages,
requestMessage,
responseMessage: {
user: req.user.id,
messageId: openai.responseMessage.messageId,
parentMessageId: userMessageId,
const sendInitialResponse = () => {
sendMessage(res, {
sync: true,
conversationId,
assistant_id,
thread_id,
model: assistant_id,
},
});
// messages: previousMessages,
requestMessage,
responseMessage: {
user: req.user.id,
messageId: openai.responseMessage.messageId,
parentMessageId: userMessageId,
conversationId,
assistant_id,
thread_id,
model: assistant_id,
},
});
};
// todo: retry logic
let response = await runAssistant({ openai, thread_id, run_id });
logger.debug('[/assistants/chat/] response', response);
/** @type {RunResponse | typeof StreamRunManager | undefined} */
let response;
if (response.run.status === RunStatus.IN_PROGRESS) {
response = await runAssistant({
const processRun = async (retry = false) => {
if (req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) {
if (retry) {
response = await runAssistant({
openai,
thread_id,
run_id,
in_progress: openai.in_progress,
});
return;
}
/* NOTE:
* By default, a Run will use the model and tools configuration specified in Assistant object,
* but you can override most of these when creating the Run for added flexibility:
*/
const run = await createRun({
openai,
thread_id,
body,
});
run_id = run.id;
await cache.set(cacheKey, `${thread_id}:${run_id}`, ten_minutes);
sendInitialResponse();
// todo: retry logic
response = await runAssistant({ openai, thread_id, run_id });
return;
}
/** @type {{[AssistantStreamEvents.ThreadRunCreated]: (event: ThreadRunCreated) => Promise<void>}} */
const handlers = {
[AssistantStreamEvents.ThreadRunCreated]: async (event) => {
await cache.set(cacheKey, `${thread_id}:${event.data.id}`, ten_minutes);
run_id = event.data.id;
sendInitialResponse();
},
};
const streamRunManager = new StreamRunManager({
req,
res,
openai,
thread_id,
run_id,
in_progress: openai.in_progress,
responseMessage: openai.responseMessage,
handlers,
// streamOptions: {
// },
});
await streamRunManager.runAssistant({
thread_id,
body,
});
response = streamRunManager;
};
await processRun();
logger.debug('[/assistants/chat/] response', {
run: response.run,
steps: response.steps,
});
if (response.run.status === RunStatus.CANCELLED) {
logger.debug('[/assistants/chat/] Run cancelled, handled by `abortRun`');
return res.end();
}
if (response.run.status === RunStatus.IN_PROGRESS) {
processRun(true);
}
completedRun = response.run;
/** @type {ResponseMessage} */
const responseMessage = {
...openai.responseMessage,
...response.finalMessage,
parentMessageId: userMessageId,
conversationId,
user: req.user.id,
@@ -413,9 +507,6 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
model: assistant_id,
};
// TODO: token count from usage returned in run
// TODO: parse responses, save to db, send to user
sendMessage(res, {
title: 'New Chat',
final: true,
@@ -432,7 +523,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
if (parentMessageId === Constants.NO_PARENT && !_thread_id) {
addTitle(req, {
text,
responseText: openai.responseText,
responseText: response.text,
conversationId,
client,
});
@@ -447,7 +538,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
if (!response.run.usage) {
await sleep(3000);
completedRun = await openai.beta.threads.runs.retrieve(thread_id, run.id);
completedRun = await openai.beta.threads.runs.retrieve(thread_id, response.run.id);
if (completedRun.usage) {
await recordUsage({
...completedRun.usage,
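
One UX change in this route deserves a closer look: when a request fails mid-run, the hunks above patch the last run message so the client still receives a renderable state, appending an `ERROR` content part and stamping any tool call that never produced output with a placeholder. A condensed sketch of that repair, with the `ContentTypes` enum values inlined as plain strings:

```js
// Condensed sketch of the mid-request error repair; 'error' and 'tool_call'
// stand in for the ContentTypes enum values used in the route.
function backfillError(lastMessage, error) {
  const errorContentPart = {
    text: {
      value:
        error?.message ?? 'There was an error processing your request. Please try again later.',
    },
    type: 'error',
  };

  if (!Array.isArray(lastMessage?.content)) {
    lastMessage.content = [errorContentPart];
    return;
  }

  for (const part of lastMessage.content) {
    const toolCall = part?.['tool_call'];
    // A tool call with no output would otherwise render as perpetually running.
    if (toolCall?.function && !toolCall.function.output) {
      toolCall.function.output = 'error processing tool';
    }
  }
  lastMessage.content.push(errorContentPart);
}
```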

View file

@@ -1,8 +1,35 @@
const { AuthTypeEnum } = require('librechat-data-provider');
const { AuthTypeEnum, EModelEndpoint, actionDomainSeparator } = require('librechat-data-provider');
const { encryptV2, decryptV2 } = require('~/server/utils/crypto');
const { getActions } = require('~/models/Action');
const { logger } = require('~/config');
/**
* Parses the domain for an action.
*
* Azure OpenAI Assistants API doesn't support periods in function
* names due to `[a-zA-Z0-9_-]*` Regex Validation.
*
* @param {Express.Request} req - Express Request object
* @param {string} domain - The domain for the action
* @param {boolean} inverse - If true, replaces periods with `actionDomainSeparator`
* @returns {string} The parsed domain
*/
function domainParser(req, domain, inverse = false) {
if (!domain) {
return;
}
if (!req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) {
return domain;
}
if (inverse) {
return domain.replace(/\./g, actionDomainSeparator);
}
return domain.replace(actionDomainSeparator, '.');
}
/**
* Loads action sets based on the user and assistant ID.
*
@@ -117,4 +144,5 @@ module.exports = {
createActionTool,
encryptMetadata,
decryptMetadata,
domainParser,
};
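
For reference, the substitution is symmetric in intent but not quite in implementation. A hedged sketch, with `'---'` standing in for the real `actionDomainSeparator` exported by librechat-data-provider and `'azureOpenAI'` for `EModelEndpoint.azureOpenAI`:

```js
// Illustrative only: '---' is an assumed stand-in for actionDomainSeparator.
const actionDomainSeparator = '---';

function domainParser(req, domain, inverse = false) {
  if (!domain) {
    return;
  }
  if (!req.app.locals.azureOpenAI?.assistants) {
    return domain; // only Azure-backed assistants need the substitution
  }
  return inverse
    ? domain.replace(/\./g, actionDomainSeparator) // encode for Azure
    : domain.replace(actionDomainSeparator, '.'); // decode for lookups
}

// domainParser(req, 'api.example.com', true) -> 'api---example---com'
// domainParser(req, 'api---example---com')   -> 'api.example---com'
// Note the asymmetry: the decode path passes a string pattern to replace(),
// which only restores the first separator; a /g regex would restore them all.
```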

View file

@@ -1,6 +1,7 @@
const {
Constants,
FileSources,
Capabilities,
EModelEndpoint,
defaultSocialLogins,
validateAzureGroups,
@@ -122,6 +123,13 @@ const AppService = async (app) => {
);
}
});
if (azureConfiguration.assistants) {
endpointLocals[EModelEndpoint.assistants] = {
// Note: may need to add retrieval models here in the future
capabilities: [Capabilities.tools, Capabilities.actions, Capabilities.code_interpreter],
};
}
}
if (config?.endpoints?.[EModelEndpoint.assistants]) {
@@ -133,8 +141,11 @@ const AppService = async (app) => {
);
}
const prevConfig = endpointLocals[EModelEndpoint.assistants] ?? {};
/** @type {Partial<TAssistantEndpoint>} */
endpointLocals[EModelEndpoint.assistants] = {
...prevConfig,
retrievalModels: parsedConfig.retrievalModels,
disableBuilder: parsedConfig.disableBuilder,
pollIntervalMs: parsedConfig.pollIntervalMs,
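
The net effect in AppService: an Azure group with assistants enabled seeds default capabilities, and an explicit `assistants` endpoint block is merged over that seed rather than replacing it. A rough sketch of the resulting `app.locals` shape, with hypothetical config values and the `Capabilities` members assumed to be plain strings:

```js
// Rough sketch of the merge order; values are hypothetical.
const seededByAzure = {
  capabilities: ['tools', 'actions', 'code_interpreter'], // from azureConfiguration.assistants
};

const parsedConfig = { retrievalModels: ['gpt-4-turbo'], disableBuilder: false, pollIntervalMs: 750 };

const assistantsLocals = {
  ...seededByAzure, // prevConfig survives even when an assistants block is also set
  retrievalModels: parsedConfig.retrievalModels,
  disableBuilder: parsedConfig.disableBuilder,
  pollIntervalMs: parsedConfig.pollIntervalMs,
};
// app.locals for the assistants endpoint now carries both the Azure-seeded
// capabilities and the explicit endpoint settings.
```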

View file

@@ -4,18 +4,17 @@ const {
StepTypes,
RunStatus,
StepStatus,
FilePurpose,
ContentTypes,
ToolCallTypes,
imageExtRegex,
imageGenTools,
EModelEndpoint,
defaultOrderQuery,
} = require('librechat-data-provider');
const { retrieveAndProcessFile } = require('~/server/services/Files/process');
const { RunManager, waitForRun } = require('~/server/services/Runs');
const { processRequiredActions } = require('~/server/services/ToolService');
const { createOnProgress, sendMessage, sleep } = require('~/server/utils');
const { RunManager, waitForRun } = require('~/server/services/Runs');
const { processMessages } = require('~/server/services/Threads');
const { TextStream } = require('~/app/clients');
const { logger } = require('~/config');
@@ -230,6 +229,7 @@ function createInProgressHandler(openai, thread_id, messages) {
const { file_id } = output.image;
const file = await retrieveAndProcessFile({
openai,
client: openai,
file_id,
basename: `${file_id}.png`,
});
@@ -299,7 +299,7 @@ function createInProgressHandler(openai, thread_id, messages) {
openai.index++;
}
const result = await processMessages(openai, [message]);
const result = await processMessages({ openai, client: openai, messages: [message] });
openai.addContentData({
[ContentTypes.TEXT]: { value: result.text },
type: ContentTypes.TEXT,
@@ -318,8 +318,8 @@ function createInProgressHandler(openai, thread_id, messages) {
res: openai.res,
index: messageIndex,
messageId: openai.responseMessage.messageId,
conversationId: openai.responseMessage.conversationId,
type: ContentTypes.TEXT,
stream: true,
thread_id,
});
@@ -416,7 +416,13 @@ async function runAssistant({
// const { messages: sortedMessages, text } = await processMessages(openai, messages);
// return { run, steps, messages: sortedMessages, text };
const sortedMessages = messages.sort((a, b) => a.created_at - b.created_at);
return { run, steps, messages: sortedMessages };
return {
run,
steps,
messages: sortedMessages,
finalMessage: openai.responseMessage,
text: openai.responseText,
};
}
const { submit_tool_outputs } = run.required_action;
@@ -447,98 +453,8 @@ async function runAssistant({
});
}
/**
* Sorts, processes, and flattens messages to a single string.
*
* @param {OpenAIClient} openai - The OpenAI client instance.
* @param {ThreadMessage[]} messages - An array of messages.
* @returns {Promise<{messages: ThreadMessage[], text: string}>} The sorted messages and the flattened text.
*/
async function processMessages(openai, messages = []) {
const sorted = messages.sort((a, b) => a.created_at - b.created_at);
let text = '';
for (const message of sorted) {
message.files = [];
for (const content of message.content) {
const processImageFile =
content.type === 'image_file' && !openai.processedFileIds.has(content.image_file?.file_id);
if (processImageFile) {
const { file_id } = content.image_file;
const file = await retrieveAndProcessFile({ openai, file_id, basename: `${file_id}.png` });
openai.processedFileIds.add(file_id);
message.files.push(file);
continue;
}
text += (content.text?.value ?? '') + ' ';
logger.debug('[processMessages] Processing message:', { value: text });
// Process annotations if they exist
if (!content.text?.annotations?.length) {
continue;
}
logger.debug('[processMessages] Processing annotations:', content.text.annotations);
for (const annotation of content.text.annotations) {
logger.debug('Current annotation:', annotation);
let file;
const processFilePath =
annotation.file_path && !openai.processedFileIds.has(annotation.file_path?.file_id);
if (processFilePath) {
const basename = imageExtRegex.test(annotation.text)
? path.basename(annotation.text)
: null;
file = await retrieveAndProcessFile({
openai,
file_id: annotation.file_path.file_id,
basename,
});
openai.processedFileIds.add(annotation.file_path.file_id);
}
const processFileCitation =
annotation.file_citation &&
!openai.processedFileIds.has(annotation.file_citation?.file_id);
if (processFileCitation) {
file = await retrieveAndProcessFile({
openai,
file_id: annotation.file_citation.file_id,
unknownType: true,
});
openai.processedFileIds.add(annotation.file_citation.file_id);
}
if (!file && (annotation.file_path || annotation.file_citation)) {
const { file_id } = annotation.file_citation || annotation.file_path || {};
file = await retrieveAndProcessFile({ openai, file_id, unknownType: true });
openai.processedFileIds.add(file_id);
}
if (!file) {
continue;
}
if (file.purpose && file.purpose === FilePurpose.Assistants) {
text = text.replace(annotation.text, file.filename);
} else if (file.filepath) {
text = text.replace(annotation.text, file.filepath);
}
message.files.push(file);
}
}
}
return { messages: sorted, text };
}
module.exports = {
getResponse,
runAssistant,
processMessages,
createOnTextProgress,
};
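
With `finalMessage` and `text` now in `runAssistant`'s return value (see the hunk above), callers no longer reach into client state for the response text. A hypothetical consumer of the polling path's new shape, with placeholder ids:

```js
// Hypothetical consumer; thread/run ids are placeholders.
const { runAssistant } = require('~/server/services/AssistantService');

async function pollRun(openai) {
  const { run, steps, messages, finalMessage, text } = await runAssistant({
    openai,
    thread_id: 'thread_abc123',
    run_id: 'run_abc123',
  });
  console.log(run.status, steps.length, messages.length);
  return { responseText: text, responseMessage: finalMessage };
}
```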

View file

@@ -338,19 +338,26 @@ const processFileUpload = async ({ req, res, file, metadata }) => {
* Retrieves and processes an OpenAI file based on its type.
*
* @param {Object} params - The params passed to the function.
* @param {OpenAIClient} params.openai - The params passed to the function.
* @param {OpenAIClient} params.openai - The OpenAI client instance.
* @param {RunClient} params.client - The LibreChat client instance: either refers to `openai` or `streamRunManager`.
* @param {string} params.file_id - The ID of the file to retrieve.
* @param {string} params.basename - The basename of the file (if image); e.g., 'image.jpg'.
* @param {boolean} [params.unknownType] - Whether the file type is unknown.
* @returns {Promise<{file_id: string, filepath: string, source: string, bytes?: number, width?: number, height?: number} | null>}
* - Returns null if `file_id` is not defined; else, the file metadata if successfully retrieved and processed.
*/
async function retrieveAndProcessFile({ openai, file_id, basename: _basename, unknownType }) {
async function retrieveAndProcessFile({
openai,
client,
file_id,
basename: _basename,
unknownType,
}) {
if (!file_id) {
return null;
}
if (openai.attachedFileIds?.has(file_id)) {
if (client.attachedFileIds?.has(file_id)) {
return {
file_id,
// filepath: TODO: local source filepath?,
@@ -416,7 +423,7 @@ async function retrieveAndProcessFile({ openai, file_id, basename: _basename, un
*/
const processAsImage = async (dataBuffer, fileExt) => {
// Logic to process image files, convert to webp, etc.
const _file = await convertToWebP(openai.req, dataBuffer, 'high', `${file_id}${fileExt}`);
const _file = await convertToWebP(client.req, dataBuffer, 'high', `${file_id}${fileExt}`);
const file = {
..._file,
type: 'image/webp',
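
The `client` parameter exists because the streaming path keeps per-run bookkeeping (`processedFileIds`, `attachedFileIds`, `req`) on the StreamRunManager instance instead of on the OpenAI client. A hedged call sketch:

```js
// Hedged sketch: `openai` still makes the API calls, while `client` is
// whichever object owns the run bookkeeping -- the OpenAI client in the
// polling path, or the StreamRunManager instance when streaming.
const { retrieveAndProcessFile } = require('~/server/services/Files/process');

async function fetchRunImage(openai, client, file_id) {
  const file = await retrieveAndProcessFile({
    openai,
    client,
    file_id,
    basename: `${file_id}.png`, // the extension routes it through image processing
  });
  return file; // { file_id, filepath, source, width?, height?, ... } or null
}
```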

View file

@@ -0,0 +1,618 @@
const path = require('path');
const {
StepTypes,
ContentTypes,
ToolCallTypes,
// StepStatus,
MessageContentTypes,
AssistantStreamEvents,
} = require('librechat-data-provider');
const { retrieveAndProcessFile } = require('~/server/services/Files/process');
const { processRequiredActions } = require('~/server/services/ToolService');
const { createOnProgress, sendMessage } = require('~/server/utils');
const { processMessages } = require('~/server/services/Threads');
const { logger } = require('~/config');
/**
* Implements the StreamRunManager functionality for managing the streaming
* and processing of run steps, messages, and tool calls within a thread.
* @implements {StreamRunManager}
*/
class StreamRunManager {
constructor(fields) {
this.index = 0;
/** @type {Map<string, RunStep>} */
this.steps = new Map();
/** @type {Map<string, number>} */
this.mappedOrder = new Map();
/** @type {Map<string, StepToolCall>} */
this.orderedRunSteps = new Map();
/** @type {Set<string>} */
this.processedFileIds = new Set();
/** @type {Map<string, (delta: ToolCallDelta | string) => Promise<void>>} */
this.progressCallbacks = new Map();
/** @type {Run | null} */
this.run = null;
/** @type {Express.Request} */
this.req = fields.req;
/** @type {Express.Response} */
this.res = fields.res;
/** @type {OpenAI} */
this.openai = fields.openai;
/** @type {string} */
this.apiKey = this.openai.apiKey;
/** @type {string} */
this.thread_id = fields.thread_id;
/** @type {RunCreateAndStreamParams} */
this.initialRunBody = fields.runBody;
/**
* @type {Object.<AssistantStreamEvents, (event: AssistantStreamEvent) => Promise<void>>}
*/
this.clientHandlers = fields.handlers ?? {};
/** @type {OpenAIRequestOptions} */
this.streamOptions = fields.streamOptions ?? {};
/** @type {Partial<TMessage>} */
this.finalMessage = fields.responseMessage ?? {};
/** @type {ThreadMessage[]} */
this.messages = [];
/** @type {string} */
this.text = '';
/**
* @type {Object.<AssistantStreamEvents, (event: AssistantStreamEvent) => Promise<void>>}
*/
this.handlers = {
[AssistantStreamEvents.ThreadCreated]: this.handleThreadCreated,
[AssistantStreamEvents.ThreadRunCreated]: this.handleRunEvent,
[AssistantStreamEvents.ThreadRunQueued]: this.handleRunEvent,
[AssistantStreamEvents.ThreadRunInProgress]: this.handleRunEvent,
[AssistantStreamEvents.ThreadRunRequiresAction]: this.handleRunEvent,
[AssistantStreamEvents.ThreadRunCompleted]: this.handleRunEvent,
[AssistantStreamEvents.ThreadRunFailed]: this.handleRunEvent,
[AssistantStreamEvents.ThreadRunCancelling]: this.handleRunEvent,
[AssistantStreamEvents.ThreadRunCancelled]: this.handleRunEvent,
[AssistantStreamEvents.ThreadRunExpired]: this.handleRunEvent,
[AssistantStreamEvents.ThreadRunStepCreated]: this.handleRunStepEvent,
[AssistantStreamEvents.ThreadRunStepInProgress]: this.handleRunStepEvent,
[AssistantStreamEvents.ThreadRunStepCompleted]: this.handleRunStepEvent,
[AssistantStreamEvents.ThreadRunStepFailed]: this.handleRunStepEvent,
[AssistantStreamEvents.ThreadRunStepCancelled]: this.handleRunStepEvent,
[AssistantStreamEvents.ThreadRunStepExpired]: this.handleRunStepEvent,
[AssistantStreamEvents.ThreadRunStepDelta]: this.handleRunStepDeltaEvent,
[AssistantStreamEvents.ThreadMessageCreated]: this.handleMessageEvent,
[AssistantStreamEvents.ThreadMessageInProgress]: this.handleMessageEvent,
[AssistantStreamEvents.ThreadMessageCompleted]: this.handleMessageEvent,
[AssistantStreamEvents.ThreadMessageIncomplete]: this.handleMessageEvent,
[AssistantStreamEvents.ThreadMessageDelta]: this.handleMessageDeltaEvent,
[AssistantStreamEvents.ErrorEvent]: this.handleErrorEvent,
};
}
/**
*
* Sends the content data to the client via SSE.
*
* @param {StreamContentData} data
* @returns {Promise<void>}
*/
async addContentData(data) {
const { type, index } = data;
this.finalMessage.content[index] = { type, [type]: data[type] };
if (type === ContentTypes.TEXT) {
this.text += data[type].value;
return;
}
const contentData = {
index,
type,
[type]: data[type],
thread_id: this.thread_id,
messageId: this.finalMessage.messageId,
conversationId: this.finalMessage.conversationId,
};
sendMessage(this.res, contentData);
}
/* <------------------ Main Event Handlers ------------------> */
/**
* Run the assistant and handle the events.
* @param {Object} params -
* The parameters for running the assistant.
* @param {string} params.thread_id - The thread id.
* @param {RunCreateAndStreamParams} params.body - The body of the run.
* @returns {Promise<void>}
*/
async runAssistant({ thread_id, body }) {
const streamRun = this.openai.beta.threads.runs.createAndStream(
thread_id,
body,
this.streamOptions,
);
for await (const event of streamRun) {
await this.handleEvent(event);
}
}
/**
* Handle the event.
* @param {AssistantStreamEvent} event - The stream event object.
* @returns {Promise<void>}
*/
async handleEvent(event) {
const handler = this.handlers[event.event];
const clientHandler = this.clientHandlers[event.event];
if (clientHandler) {
await clientHandler.call(this, event);
}
if (handler) {
await handler.call(this, event);
} else {
logger.warn(`Unhandled event type: ${event.event}`);
}
}
/**
* Handle thread.created event
* @param {ThreadCreated} event -
* The thread.created event object.
*/
async handleThreadCreated(event) {
logger.debug('Thread created:', event.data);
}
/**
* Handle Run Events
* @param {ThreadRunCreated | ThreadRunQueued | ThreadRunInProgress | ThreadRunRequiresAction | ThreadRunCompleted | ThreadRunFailed | ThreadRunCancelling | ThreadRunCancelled | ThreadRunExpired} event -
* The run event object.
*/
async handleRunEvent(event) {
this.run = event.data;
logger.debug('Run event:', this.run);
if (event.event === AssistantStreamEvents.ThreadRunRequiresAction) {
await this.onRunRequiresAction(event);
} else if (event.event === AssistantStreamEvents.ThreadRunCompleted) {
logger.debug('Run completed:', this.run);
}
}
/**
* Handle Run Step Events
* @param {ThreadRunStepCreated | ThreadRunStepInProgress | ThreadRunStepCompleted | ThreadRunStepFailed | ThreadRunStepCancelled | ThreadRunStepExpired} event -
* The run step event object.
*/
async handleRunStepEvent(event) {
logger.debug('Run step event:', event.data);
const step = event.data;
this.steps.set(step.id, step);
if (event.event === AssistantStreamEvents.ThreadRunStepCreated) {
this.onRunStepCreated(event);
} else if (event.event === AssistantStreamEvents.ThreadRunStepCompleted) {
this.onRunStepCompleted(event);
}
}
/* <------------------ Delta Events ------------------> */
/** @param {CodeImageOutput} output */
async handleCodeImageOutput(output) {
if (this.processedFileIds.has(output.image?.file_id)) {
return;
}
const { file_id } = output.image;
const file = await retrieveAndProcessFile({
openai: this.openai,
client: this,
file_id,
basename: `${file_id}.png`,
});
// toolCall.asset_pointer = file.filepath;
const prelimImage = {
file_id,
filename: path.basename(file.filepath),
filepath: file.filepath,
height: file.height,
width: file.width,
};
// check if every key has a value before adding to content
const prelimImageKeys = Object.keys(prelimImage);
const validImageFile = prelimImageKeys.every((key) => prelimImage[key]);
if (!validImageFile) {
return;
}
const index = this.getStepIndex(file_id);
const image_file = {
[ContentTypes.IMAGE_FILE]: prelimImage,
type: ContentTypes.IMAGE_FILE,
index,
};
this.addContentData(image_file);
this.processedFileIds.add(file_id);
}
/**
* Create Tool Call Stream
* @param {number} index - The index of the tool call.
* @param {StepToolCall} toolCall -
* The current tool call object.
*/
createToolCallStream(index, toolCall) {
/** @type {StepToolCall} */
const state = toolCall;
const type = state.type;
const data = state[type];
/** @param {ToolCallDelta} delta */
const deltaHandler = async (delta) => {
for (const key in delta) {
if (!Object.prototype.hasOwnProperty.call(data, key)) {
logger.warn(`Unhandled tool call key "${key}", delta: `, delta);
continue;
}
if (Array.isArray(delta[key])) {
if (!Array.isArray(data[key])) {
data[key] = [];
}
for (const d of delta[key]) {
if (typeof d === 'object' && !Object.prototype.hasOwnProperty.call(d, 'index')) {
logger.warn('Expected an object with an \'index\' for array updates but got:', d);
continue;
}
const imageOutput = type === ToolCallTypes.CODE_INTERPRETER && d?.type === 'image';
if (imageOutput) {
await this.handleCodeImageOutput(d);
continue;
}
const { index, ...updateData } = d;
// Ensure the data at index is an object or undefined before assigning
if (typeof data[key][index] !== 'object' || data[key][index] === null) {
data[key][index] = {};
}
// Merge the updateData into data[key][index]
for (const updateKey in updateData) {
data[key][index][updateKey] = updateData[updateKey];
}
}
} else if (typeof delta[key] === 'string' && typeof data[key] === 'string') {
// Concatenate strings
data[key] += delta[key];
} else if (
typeof delta[key] === 'object' &&
delta[key] !== null &&
!Array.isArray(delta[key])
) {
// Merge objects
data[key] = { ...data[key], ...delta[key] };
} else {
// Directly set the value for other types
data[key] = delta[key];
}
state[type] = data;
this.addContentData({
[ContentTypes.TOOL_CALL]: toolCall,
type: ContentTypes.TOOL_CALL,
index,
});
}
};
return deltaHandler;
}
/**
* @param {string} stepId -
* @param {StepToolCall} toolCall -
*
*/
handleNewToolCall(stepId, toolCall) {
const stepKey = this.generateToolCallKey(stepId, toolCall);
const index = this.getStepIndex(stepKey);
this.getStepIndex(toolCall.id, index);
toolCall.progress = 0.01;
this.orderedRunSteps.set(index, toolCall);
const progressCallback = this.createToolCallStream(index, toolCall);
this.progressCallbacks.set(stepKey, progressCallback);
this.addContentData({
[ContentTypes.TOOL_CALL]: toolCall,
type: ContentTypes.TOOL_CALL,
index,
});
}
/**
* Handle Completed Tool Call
* @param {string} stepId - The id of the step the tool_call is part of.
* @param {StepToolCall} toolCall - The tool call object.
*
*/
handleCompletedToolCall(stepId, toolCall) {
if (toolCall.type === ToolCallTypes.FUNCTION) {
return;
}
const stepKey = this.generateToolCallKey(stepId, toolCall);
const index = this.getStepIndex(stepKey);
toolCall.progress = 1;
this.orderedRunSteps.set(index, toolCall);
this.addContentData({
[ContentTypes.TOOL_CALL]: toolCall,
type: ContentTypes.TOOL_CALL,
index,
});
}
/**
* Handle Run Step Delta Event
* @param {ThreadRunStepDelta} event -
* The run step delta event object.
*/
async handleRunStepDeltaEvent(event) {
const { delta, id: stepId } = event.data;
if (!delta.step_details) {
logger.warn('Undefined or unhandled run step delta:', delta);
return;
}
/** @type {{ tool_calls: Array<ToolCallDeltaObject> }} */
const { tool_calls } = delta.step_details;
if (!tool_calls) {
logger.warn('Unhandled run step details', delta.step_details);
return;
}
for (const toolCall of tool_calls) {
const stepKey = this.generateToolCallKey(stepId, toolCall);
if (!this.mappedOrder.has(stepKey)) {
this.handleNewToolCall(stepId, toolCall);
continue;
}
const toolCallDelta = toolCall[toolCall.type];
const progressCallback = this.progressCallbacks.get(stepKey);
await progressCallback(toolCallDelta);
}
}
/**
* Handle Message Delta Event
* @param {ThreadMessageDelta} event -
* The Message Delta event object.
*/
async handleMessageDeltaEvent(event) {
const message = event.data;
const onProgress = this.progressCallbacks.get(message.id);
const content = message.delta.content?.[0];
if (content && content.type === MessageContentTypes.TEXT) {
onProgress(content.text.value);
}
}
/**
* Handle Error Event
* @param {ErrorEvent} event -
* The Error event object.
*/
async handleErrorEvent(event) {
logger.error('Error event:', event.data);
}
/* <------------------ Misc. Helpers ------------------> */
/**
* Gets the step index for a given step key, creating a new index if it doesn't exist.
* @param {string} stepKey -
* The access key for the step. Either a message.id, tool_call key, or file_id.
* @param {number | undefined} [overrideIndex] - An override index to use an alternative stepKey.
* This is necessary due to the toolCall Id being unavailable in delta stream events.
* @returns {number | undefined} index - The index of the step; `undefined` if invalid key or using overrideIndex.
*/
getStepIndex(stepKey, overrideIndex) {
if (!stepKey) {
return;
}
if (!isNaN(overrideIndex)) {
this.mappedOrder.set(stepKey, overrideIndex);
return;
}
let index = this.mappedOrder.get(stepKey);
if (index === undefined) {
index = this.index;
this.mappedOrder.set(stepKey, this.index);
this.index++;
}
return index;
}
/**
* Generate Tool Call Key
* @param {string} stepId - The id of the step the tool_call is part of.
* @param {StepToolCall} toolCall - The tool call object.
* @returns {string} key - The generated key for the tool call.
*/
generateToolCallKey(stepId, toolCall) {
return `${stepId}_tool_call_${toolCall.index}_${toolCall.type}`;
}
/* <------------------ Run Event handlers ------------------> */
/**
* Handle Run Events Requiring Action
* @param {ThreadRunRequiresAction} event -
* The run event object requiring action.
*/
async onRunRequiresAction(event) {
const run = event.data;
const { submit_tool_outputs } = run.required_action;
const actions = submit_tool_outputs.tool_calls.map((item) => {
const functionCall = item.function;
const args = JSON.parse(functionCall.arguments);
return {
tool: functionCall.name,
toolInput: args,
toolCallId: item.id,
run_id: run.id,
thread_id: this.thread_id,
};
});
const { tool_outputs } = await processRequiredActions(this, actions);
/** @type {AssistantStream | undefined} */
let toolRun;
try {
toolRun = this.openai.beta.threads.runs.submitToolOutputsStream(
run.thread_id,
run.id,
{
tool_outputs,
stream: true,
},
this.streamOptions,
);
} catch (error) {
logger.error('Error submitting tool outputs:', error);
throw error;
}
for await (const event of toolRun) {
await this.handleEvent(event);
}
}
/* <------------------ RunStep Event handlers ------------------> */
/**
* Handle Run Step Created Events
* @param {ThreadRunStepCreated} event -
* The created run step event object.
*/
async onRunStepCreated(event) {
const step = event.data;
const isMessage = step.type === StepTypes.MESSAGE_CREATION;
if (isMessage) {
/** @type {MessageCreationStepDetails} */
const { message_creation } = step.step_details;
const stepKey = message_creation.message_id;
const index = this.getStepIndex(stepKey);
this.orderedRunSteps.set(index, message_creation);
// Create the Factory Function to stream the message
const { onProgress: progressCallback } = createOnProgress({
// todo: add option to save partialText to db
// onProgress: () => {},
});
// This creates a function that attaches all of the parameters
// specified here to each SSE message generated by the TextStream
const onProgress = progressCallback({
index,
res: this.res,
messageId: this.finalMessage.messageId,
conversationId: this.finalMessage.conversationId,
thread_id: this.thread_id,
type: ContentTypes.TEXT,
});
this.progressCallbacks.set(stepKey, onProgress);
this.orderedRunSteps.set(index, step);
return;
}
if (step.type !== StepTypes.TOOL_CALLS) {
logger.warn('Unhandled step creation type:', step.type);
return;
}
/** @type {{ tool_calls: StepToolCall[] }} */
const { tool_calls } = step.step_details;
for (const toolCall of tool_calls) {
this.handleNewToolCall(step.id, toolCall);
}
}
/**
* Handle Run Step Completed Events
* @param {ThreadRunStepCompleted} event -
* The completed run step event object.
*/
async onRunStepCompleted(event) {
const step = event.data;
const isMessage = step.type === StepTypes.MESSAGE_CREATION;
if (isMessage) {
logger.warn('RunStep Message completion: to be handled by Message Event.', step);
return;
}
/** @type {{ tool_calls: StepToolCall[] }} */
const { tool_calls } = step.step_details;
for (let i = 0; i < tool_calls.length; i++) {
const toolCall = tool_calls[i];
toolCall.index = i;
this.handleCompletedToolCall(step.id, toolCall);
}
}
/* <------------------ Message Event handlers ------------------> */
/**
* Handle Message Event
* @param {ThreadMessageCreated | ThreadMessageInProgress | ThreadMessageCompleted | ThreadMessageIncomplete} event -
* The Message event object.
*/
async handleMessageEvent(event) {
if (event.event === AssistantStreamEvents.ThreadMessageCompleted) {
this.messageCompleted(event);
}
}
/**
* Handle Message Completed Events
* @param {ThreadMessageCompleted} event -
* The Completed Message event object.
*/
async messageCompleted(event) {
const message = event.data;
const result = await processMessages({
openai: this.openai,
client: this,
messages: [message],
});
const index = this.mappedOrder.get(message.id);
this.addContentData({
[ContentTypes.TEXT]: { value: result.text },
type: ContentTypes.TEXT,
index,
});
this.messages.push(message);
}
}
module.exports = StreamRunManager;
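
For context, the chat route wires this class up with the request/response pair plus optional per-event `handlers`, which run before the built-in handler for the same event. A trimmed sketch of that call site, with the abort-key bookkeeping reduced to a placeholder log line:

```js
// Trimmed sketch of the chat route's call site; the handler body is a
// placeholder and ids come from the surrounding request.
const { StreamRunManager } = require('~/server/services/Runs');

async function chatWithAssistant({ req, res, openai, thread_id, body }) {
  const streamRunManager = new StreamRunManager({
    req,
    res,
    openai,
    thread_id,
    responseMessage: openai.responseMessage,
    handlers: {
      // Client handlers fire before the built-in one for the same event.
      'thread.run.created': async (event) => {
        console.log('run created:', event.data.id); // e.g., cache `${thread_id}:${event.data.id}`
      },
    },
  });

  await streamRunManager.runAssistant({ thread_id, body });
  return { run: streamRunManager.run, text: streamRunManager.text };
}
```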

View file

@@ -1,9 +1,11 @@
const handle = require('./handle');
const methods = require('./methods');
const RunManager = require('./RunManager');
const StreamRunManager = require('./StreamRunManager');
module.exports = {
...handle,
...methods,
RunManager,
StreamRunManager,
};

View file

@@ -1,14 +1,19 @@
const path = require('path');
const { v4 } = require('uuid');
const {
EModelEndpoint,
Constants,
defaultOrderQuery,
FilePurpose,
ContentTypes,
imageExtRegex,
EModelEndpoint,
defaultOrderQuery,
} = require('librechat-data-provider');
const { retrieveAndProcessFile } = require('~/server/services/Files/process');
const { recordMessage, getMessages } = require('~/models/Message');
const { saveConvo } = require('~/models/Conversation');
const spendTokens = require('~/models/spendTokens');
const { countTokens } = require('~/server/utils');
const { logger } = require('~/config');
/**
* Initializes a new thread or adds messages to an existing thread.
@@ -484,9 +489,108 @@ const recordUsage = async ({ prompt_tokens, completion_tokens, model, user, conv
);
};
/**
* Sorts, processes, and flattens messages to a single string.
*
* @param {object} params - The parameters object.
* @param {OpenAIClient} params.openai - The OpenAI client instance.
* @param {RunClient} params.client - The LibreChat client that manages the run: either refers to `OpenAI` or `StreamRunManager`.
* @param {ThreadMessage[]} params.messages - An array of messages.
* @returns {Promise<{messages: ThreadMessage[], text: string}>} The sorted messages and the flattened text.
*/
async function processMessages({ openai, client, messages = [] }) {
const sorted = messages.sort((a, b) => a.created_at - b.created_at);
let text = '';
for (const message of sorted) {
message.files = [];
for (const content of message.content) {
const processImageFile =
content.type === 'image_file' && !client.processedFileIds.has(content.image_file?.file_id);
if (processImageFile) {
const { file_id } = content.image_file;
const file = await retrieveAndProcessFile({
openai,
client,
file_id,
basename: `${file_id}.png`,
});
client.processedFileIds.add(file_id);
message.files.push(file);
continue;
}
text += (content.text?.value ?? '') + ' ';
logger.debug('[processMessages] Processing message:', { value: text });
// Process annotations if they exist
if (!content.text?.annotations?.length) {
continue;
}
logger.debug('[processMessages] Processing annotations:', content.text.annotations);
for (const annotation of content.text.annotations) {
logger.debug('Current annotation:', annotation);
let file;
const processFilePath =
annotation.file_path && !client.processedFileIds.has(annotation.file_path?.file_id);
if (processFilePath) {
const basename = imageExtRegex.test(annotation.text)
? path.basename(annotation.text)
: null;
file = await retrieveAndProcessFile({
openai,
client,
file_id: annotation.file_path.file_id,
basename,
});
client.processedFileIds.add(annotation.file_path.file_id);
}
const processFileCitation =
annotation.file_citation &&
!client.processedFileIds.has(annotation.file_citation?.file_id);
if (processFileCitation) {
file = await retrieveAndProcessFile({
openai,
client,
file_id: annotation.file_citation.file_id,
unknownType: true,
});
client.processedFileIds.add(annotation.file_citation.file_id);
}
if (!file && (annotation.file_path || annotation.file_citation)) {
const { file_id } = annotation.file_citation || annotation.file_path || {};
file = await retrieveAndProcessFile({ openai, client, file_id, unknownType: true });
client.processedFileIds.add(file_id);
}
if (!file) {
continue;
}
if (file.purpose && file.purpose === FilePurpose.Assistants) {
text = text.replace(annotation.text, file.filename);
} else if (file.filepath) {
text = text.replace(annotation.text, file.filepath);
}
message.files.push(file);
}
}
}
return { messages: sorted, text };
}
module.exports = {
initThread,
recordUsage,
processMessages,
saveUserMessage,
checkMessageGaps,
addThreadMetadata,
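
Now that it lives in the Threads service, `processMessages` is shared by both run clients. A hedged usage sketch; the `client` supplies the `processedFileIds` bookkeeping (and `req`/`attachedFileIds` further down in file processing), and the thread messages are abbreviated:

```js
// Hedged usage sketch; `client` may be the OpenAI client or a
// StreamRunManager instance.
const { processMessages } = require('~/server/services/Threads');

async function flattenRunOutput(openai, client, threadMessages) {
  const { messages, text } = await processMessages({ openai, client, messages: threadMessages });
  // `messages` come back sorted by created_at, each with a `files` array of
  // processed attachments; `text` is the flattened content with annotation
  // placeholders replaced by filenames or filepaths.
  return { messages, text };
}
```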

View file

@@ -10,7 +10,7 @@ const {
validateAndParseOpenAPISpec,
actionDelimiter,
} = require('librechat-data-provider');
const { loadActionSets, createActionTool } = require('./ActionService');
const { loadActionSets, createActionTool, domainParser } = require('./ActionService');
const { processFileURL } = require('~/server/services/Files/process');
const { loadTools } = require('~/app/clients/tools/util');
const { redactMessage } = require('~/config/parsers');
@@ -112,26 +112,26 @@ function formatToOpenAIAssistantTool(tool) {
/**
* Processes return required actions from run.
*
* @param {OpenAIClient} openai - OpenAI Client.
* @param {OpenAIClient} client - OpenAI or StreamRunManager Client.
* @param {RequiredAction[]} requiredActions - The required actions to submit outputs for.
* @returns {Promise<ToolOutputs>} The outputs of the tools.
*
*/
async function processRequiredActions(openai, requiredActions) {
async function processRequiredActions(client, requiredActions) {
logger.debug(
`[required actions] user: ${openai.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`,
`[required actions] user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`,
requiredActions,
);
const tools = requiredActions.map((action) => action.tool);
const loadedTools = await loadTools({
user: openai.req.user.id,
model: openai.req.body.model ?? 'gpt-3.5-turbo-1106',
user: client.req.user.id,
model: client.req.body.model ?? 'gpt-3.5-turbo-1106',
tools,
functions: true,
options: {
processFileURL,
openAIApiKey: openai.apiKey,
fileStrategy: openai.req.app.locals.fileStrategy,
openAIApiKey: client.apiKey,
fileStrategy: client.req.app.locals.fileStrategy,
returnMetadata: true,
},
skipSpecs: true,
@@ -170,14 +170,14 @@ async function processRequiredActions(openai, requiredActions) {
action: isActionTool,
};
const toolCallIndex = openai.mappedOrder.get(toolCall.id);
const toolCallIndex = client.mappedOrder.get(toolCall.id);
if (imageGenTools.has(currentAction.tool)) {
const imageOutput = output;
toolCall.function.output = `${currentAction.tool} displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.`;
// Streams the "Finished" state of the tool call in the UI
openai.addContentData({
client.addContentData({
[ContentTypes.TOOL_CALL]: toolCall,
index: toolCallIndex,
type: ContentTypes.TOOL_CALL,
@@ -198,10 +198,10 @@ async function processRequiredActions(openai, requiredActions) {
index: toolCallIndex,
};
openai.addContentData(image_file);
client.addContentData(image_file);
// Update the stored tool call
openai.seenToolCalls.set(toolCall.id, toolCall);
client.seenToolCalls && client.seenToolCalls.set(toolCall.id, toolCall);
return {
tool_call_id: currentAction.toolCallId,
@@ -209,8 +209,8 @@ async function processRequiredActions(openai, requiredActions) {
};
}
openai.seenToolCalls.set(toolCall.id, toolCall);
openai.addContentData({
client.seenToolCalls && client.seenToolCalls.set(toolCall.id, toolCall);
client.addContentData({
[ContentTypes.TOOL_CALL]: toolCall,
index: toolCallIndex,
type: ContentTypes.TOOL_CALL,
@@ -230,13 +230,13 @@ async function processRequiredActions(openai, requiredActions) {
if (!actionSets.length) {
actionSets =
(await loadActionSets({
user: openai.req.user.id,
assistant_id: openai.req.body.assistant_id,
user: client.req.user.id,
assistant_id: client.req.body.assistant_id,
})) ?? [];
}
const actionSet = actionSets.find((action) =>
currentAction.tool.includes(action.metadata.domain),
currentAction.tool.includes(domainParser(client.req, action.metadata.domain, true)),
);
if (!actionSet) {
@@ -251,7 +251,7 @@ async function processRequiredActions(openai, requiredActions) {
const validationResult = validateAndParseOpenAPISpec(actionSet.metadata.raw_spec);
if (!validationResult.spec) {
throw new Error(
`Invalid spec: user: ${openai.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`,
`Invalid spec: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`,
);
}
const { requestBuilders } = openapiToFunction(validationResult.spec);
@@ -260,7 +260,7 @@ async function processRequiredActions(openai, requiredActions) {
}
const functionName = currentAction.tool.replace(
`${actionDelimiter}${actionSet.metadata.domain}`,
`${actionDelimiter}${domainParser(client.req, actionSet.metadata.domain, true)}`,
'',
);
const requestBuilder = builders[functionName];
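
The two `domainParser` calls above are what make Azure action lookups line up: stored action domains keep their periods, while the tool names coming back from the run carry the separator. A rough worked example, with stand-in values (`'_action_'` for `actionDelimiter` and `'---'` for `actionDomainSeparator`; the real ones live in librechat-data-provider):

```js
// Worked example with stand-in delimiter values.
const actionDelimiter = '_action_';
const toAzureDomain = (domain) => domain.replace(/\./g, '---'); // ~ domainParser(req, domain, true)

const tool = 'getWeather_action_api---example---com'; // as returned by the run's required action
const actionSet = { metadata: { domain: 'api.example.com' } }; // as stored

// Matching the action set:
const matches = tool.includes(toAzureDomain(actionSet.metadata.domain)); // true

// Recovering the bare function name for the request builder:
const functionName = tool.replace(`${actionDelimiter}${toAzureDomain(actionSet.metadata.domain)}`, '');
// functionName === 'getWeather'
```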