👓 feat: Vision Support for Assistants (#2195)

* refactor(assistants/chat): use promises to speed up initialization, initialize shared variables, and pass `attachedFileIds` to the StreamRunManager

* chore: additional typedefs

* fix(OpenAIClient): handle the edge case where the attachments promise is resolved

* feat: createVisionPrompt

* feat: Vision Support for Assistants
Danny Avila 2024-03-24 23:43:00 -04:00 committed by GitHub
parent 1f0fb497f8
commit 798e8763d0
16 changed files with 376 additions and 100 deletions
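Before the diff, the headline change in one sketch: thread setup and the balance check now run concurrently instead of back-to-back. This is a minimal sketch of the pattern only; `setUpThread` and `verifyBalance` are hypothetical stand-ins for the route's `initializeThread` and `checkBalanceBeforeRun` closures below.

    const setUpThread = async () => {
      /* create the thread and format the user message */
    };
    const verifyBalance = async () => {
      /* early-return when CHECK_BALANCE is disabled */
    };

    // Both steps start immediately and are awaited together, so request
    // latency is bounded by the slower step rather than the sum of both.
    await Promise.all([setUpThread(), verifyBalance()]);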

@@ -4,9 +4,11 @@ const {
Constants,
RunStatus,
CacheKeys,
FileSources,
ContentTypes,
EModelEndpoint,
ViolationTypes,
ImageVisionTool,
AssistantStreamEvents,
} = require('librechat-data-provider');
const {
@@ -17,9 +19,10 @@ const {
addThreadMetadata,
saveAssistantMessage,
} = require('~/server/services/Threads');
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const { addTitle, initializeClient } = require('~/server/services/Endpoints/assistants');
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { getTransactions } = require('~/models/Transaction');
const checkBalance = require('~/models/checkBalance');
@@ -100,6 +103,16 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
let parentMessageId = _parentId;
/** @type {TMessage[]} */
let previousMessages = [];
/** @type {import('librechat-data-provider').TConversation | null} */
let conversation = null;
/** @type {string[]} */
let file_ids = [];
/** @type {Set<string>} */
let attachedFileIds = new Set();
/** @type {TMessage | null} */
let requestMessage = null;
/** @type {undefined | Promise<ChatCompletion>} */
let visionPromise;
const userMessageId = v4();
const responseMessageId = v4();
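Per the typedef above, `visionPromise` is `undefined | Promise<ChatCompletion>`: it is only assigned when images were attached, so any consumer must guard before awaiting it. A minimal sketch, assuming the standard OpenAI chat completion shape:

    // `visionPromise` may never have been assigned; guard before awaiting.
    if (visionPromise) {
      const completion = await visionPromise;
      const description = completion.choices[0]?.message?.content ?? '';
      // ...surface `description` as the image_vision tool output...
    }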
@@ -258,7 +271,10 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
throw new Error('Missing assistant_id');
}
if (isEnabled(process.env.CHECK_BALANCE)) {
const checkBalanceBeforeRun = async () => {
if (!isEnabled(process.env.CHECK_BALANCE)) {
return;
}
const transactions =
(await getTransactions({
user: req.user.id,
@@ -288,7 +304,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
amount: promptTokens,
},
});
}
};
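The balance check was previously an inline `if (isEnabled(...)) { ... }` block; inverting it into an early-return guard inside an async closure is what allows it to be deferred and scheduled alongside thread initialization. The shape of the transform, schematically:

    // Before: ran inline, serializing the request path.
    // if (isEnabled(process.env.CHECK_BALANCE)) { /* balance logic */ }

    // After: same logic behind a guard clause, started later via Promise.all.
    const checkBalanceBeforeRun = async () => {
      if (!isEnabled(process.env.CHECK_BALANCE)) {
        return;
      }
      /* balance logic */
    };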
/** @type {{ openai: OpenAIClient }} */
const { openai: _openai, client } = await initializeClient({
@@ -300,15 +316,11 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
openai = _openai;
// if (thread_id) {
// previousMessages = await checkMessageGaps({ openai, thread_id, conversationId });
// }
if (previousMessages.length) {
parentMessageId = previousMessages[previousMessages.length - 1].messageId;
}
const userMessage = {
let userMessage = {
role: 'user',
content: text,
metadata: {
@@ -316,75 +328,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
},
};
let thread_file_ids = [];
if (convoId) {
const convo = await getConvo(req.user.id, convoId);
if (convo && convo.file_ids) {
thread_file_ids = convo.file_ids;
}
}
const file_ids = files.map(({ file_id }) => file_id);
if (file_ids.length || thread_file_ids.length) {
userMessage.file_ids = file_ids;
openai.attachedFileIds = new Set([...file_ids, ...thread_file_ids]);
}
// TODO: may allow multiple messages to be created beforehand in a future update
const initThreadBody = {
messages: [userMessage],
metadata: {
user: req.user.id,
conversationId,
},
};
const result = await initThread({ openai, body: initThreadBody, thread_id });
thread_id = result.thread_id;
createOnTextProgress({
openai,
conversationId,
userMessageId,
messageId: responseMessageId,
thread_id,
});
const requestMessage = {
user: req.user.id,
text,
messageId: userMessageId,
parentMessageId,
// TODO: make sure client sends correct format for `files`, use zod
files,
file_ids,
conversationId,
isCreatedByUser: true,
assistant_id,
thread_id,
model: assistant_id,
};
previousMessages.push(requestMessage);
await saveUserMessage({ ...requestMessage, model });
const conversation = {
conversationId,
// TODO: title feature
title: 'New Chat',
endpoint: EModelEndpoint.assistants,
promptPrefix: promptPrefix,
instructions: instructions,
assistant_id,
// model,
};
if (file_ids.length) {
conversation.file_ids = file_ids;
}
/** @type {CreateRunBody} */
/** @type {CreateRunBody | undefined} */
const body = {
assistant_id,
model,
@@ -398,6 +342,143 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
body.instructions = instructions;
}
const getRequestFileIds = async () => {
let thread_file_ids = [];
if (convoId) {
const convo = await getConvo(req.user.id, convoId);
if (convo && convo.file_ids) {
thread_file_ids = convo.file_ids;
}
}
file_ids = files.map(({ file_id }) => file_id);
if (file_ids.length || thread_file_ids.length) {
userMessage.file_ids = file_ids;
attachedFileIds = new Set([...file_ids, ...thread_file_ids]);
}
};
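`getRequestFileIds` merges the current message's uploads with files already attached to the conversation, and the `Set` handles deduplication. A worked example with made-up ids:

    // 'file-2' is both newly uploaded and already on the thread.
    const file_ids = ['file-1', 'file-2'];
    const thread_file_ids = ['file-2', 'file-3'];
    const attachedFileIds = new Set([...file_ids, ...thread_file_ids]);
    // => Set(3) { 'file-1', 'file-2', 'file-3' }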
const addVisionPrompt = async () => {
if (!req.body.endpointOption.attachments) {
return;
}
const assistant = await openai.beta.assistants.retrieve(assistant_id);
const visionToolIndex = assistant.tools.findIndex(
(tool) => tool.function?.name === ImageVisionTool.function.name,
);
if (visionToolIndex === -1) {
return;
}
const attachments = await req.body.endpointOption.attachments;
let visionMessage = {
role: 'user',
content: '',
};
const files = await client.addImageURLs(visionMessage, attachments);
if (!visionMessage.image_urls?.length) {
return;
}
const imageCount = visionMessage.image_urls.length;
const plural = imageCount > 1;
visionMessage.content = createVisionPrompt(plural);
visionMessage = formatMessage({ message: visionMessage, endpoint: EModelEndpoint.openAI });
visionPromise = openai.chat.completions.create({
model: 'gpt-4-vision-preview',
messages: [visionMessage],
max_tokens: 4000,
});
const pluralized = plural ? 's' : '';
body.additional_instructions = `${
body.additional_instructions ? `${body.additional_instructions}\n` : ''
}The user has uploaded ${imageCount} image${pluralized}.
Use the \`${ImageVisionTool.function.name}\` tool to retrieve ${
plural ? '' : 'a '
}detailed text description${pluralized} for ${plural ? 'each' : 'the'} image${pluralized}.`;
return files;
};
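Note that `addVisionPrompt` never awaits the `gpt-4-vision-preview` call: the completion races the rest of setup and only blocks the run if the assistant actually invokes the `image_vision` tool. The consuming side is not shown in this diff; a hypothetical handler might look like:

    // Hypothetical: resolve the pre-started vision request on demand.
    async function handleImageVisionCall(visionPromise) {
      if (!visionPromise) {
        return 'No image description is available.';
      }
      const completion = await visionPromise; // in flight since setup
      return completion.choices[0]?.message?.content ?? '';
    }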
const initializeThread = async () => {
/** @type {[undefined | MongoFile[]]} */
const [processedFiles] = await Promise.all([addVisionPrompt(), getRequestFileIds()]);
// TODO: may allow multiple messages to be created beforehand in a future update
const initThreadBody = {
messages: [userMessage],
metadata: {
user: req.user.id,
conversationId,
},
};
if (processedFiles) {
for (const file of processedFiles) {
if (file.source !== FileSources.openai) {
attachedFileIds.delete(file.file_id);
const index = file_ids.indexOf(file.file_id);
if (index > -1) {
file_ids.splice(index, 1);
}
}
}
userMessage.file_ids = file_ids;
}
const result = await initThread({ openai, body: initThreadBody, thread_id });
thread_id = result.thread_id;
createOnTextProgress({
openai,
conversationId,
userMessageId,
messageId: responseMessageId,
thread_id,
});
requestMessage = {
user: req.user.id,
text,
messageId: userMessageId,
parentMessageId,
// TODO: make sure client sends correct format for `files`, use zod
files,
file_ids,
conversationId,
isCreatedByUser: true,
assistant_id,
thread_id,
model: assistant_id,
};
previousMessages.push(requestMessage);
/* asynchronous */
saveUserMessage({ ...requestMessage, model });
conversation = {
conversationId,
title: 'New Chat',
endpoint: EModelEndpoint.assistants,
promptPrefix: promptPrefix,
instructions: instructions,
assistant_id,
// model,
};
if (file_ids.length) {
conversation.file_ids = file_ids;
}
};
const promises = [initializeThread(), checkBalanceBeforeRun()];
await Promise.all(promises);
const sendInitialResponse = () => {
sendMessage(res, {
sync: true,
@@ -421,6 +502,8 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
const processRun = async (retry = false) => {
if (req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) {
openai.attachedFileIds = attachedFileIds;
openai.visionPromise = visionPromise;
if (retry) {
response = await runAssistant({
openai,
@@ -463,9 +546,11 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
req,
res,
openai,
thread_id,
responseMessage: openai.responseMessage,
handlers,
thread_id,
visionPromise,
attachedFileIds,
responseMessage: openai.responseMessage,
// streamOptions: {
// },
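End-to-end, the run now travels with its vision context. Assuming the surrounding call constructs a `StreamRunManager` (as the import earlier in the diff suggests), the options read roughly as:

    const streamRunManager = new StreamRunManager({
      req,
      res,
      openai,
      handlers,
      thread_id,
      visionPromise,    // undefined unless images were attached this request
      attachedFileIds,  // Set<string> of file ids attached to the thread
      responseMessage: openai.responseMessage,
    });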