👓 feat: Vision Support for Assistants (#2195)

* refactor(assistants/chat): use promises to speed up initialization, initialize shared variables, and pass `attachedFileIds` to StreamRunManager

* chore: additional typedefs

* fix(OpenAIClient): handle the edge case where `attachments` is already resolved rather than a pending promise

* feat: createVisionPrompt

* feat: Vision Support for Assistants
Danny Avila authored on 2024-03-24 23:43:00 -04:00, committed by GitHub
parent 1f0fb497f8
commit 798e8763d0
16 changed files with 376 additions and 100 deletions
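For orientation before the per-file diffs, a simplified sketch of the initialization flow this commit introduces. The injected callbacks stand in for the real `addVisionPrompt`, `getRequestFileIds`, `checkBalanceBeforeRun`, and `initThread` defined in the route diff below; this is an illustration, not the committed code.

// Simplified sketch: setup work that previously ran sequentially is now awaited together.
async function initializeRun({ addVisionPrompt, getRequestFileIds, checkBalanceBeforeRun, initThread }) {
  const initializeThread = async () => {
    // addVisionPrompt starts the gpt-4-vision-preview completion and keeps it as a pending
    // `visionPromise`; getRequestFileIds collects attached/thread file ids. They run side by side.
    await Promise.all([addVisionPrompt(), getRequestFileIds()]);
    return initThread();
  };

  // Thread initialization and the balance check also run concurrently.
  await Promise.all([initializeThread(), checkBalanceBeforeRun()]);

  // The pending visionPromise is handed to the run (openai.visionPromise / StreamRunManager)
  // and is only awaited when the run invokes the `image_vision` tool.
}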

@@ -92,7 +92,11 @@ class OpenAIClient extends BaseClient {
     }

     this.defaultVisionModel = this.options.visionModel ?? 'gpt-4-vision-preview';
-    this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments));
+    if (typeof this.options.attachments?.then === 'function') {
+      this.options.attachments.then((attachments) => this.checkVisionRequest(attachments));
+    } else {
+      this.checkVisionRequest(this.options.attachments);
+    }

     const { OPENROUTER_API_KEY, OPENAI_FORCE_PROMPT } = process.env ?? {};
     if (OPENROUTER_API_KEY && !this.azure) {
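The duck-typed `then` check above addresses the edge case called out in the commit message: `attachments` may arrive as a pending promise or as an already-resolved value. A standalone sketch of the same pattern, with illustrative names that are not part of the commit:

// Handle `attachments` whether it is a promise or an already-resolved array.
function resolveAttachments(attachments, onReady) {
  if (typeof attachments?.then === 'function') {
    // A thenable: defer the vision check until it settles.
    attachments.then((resolved) => onReady(resolved));
  } else {
    // Already-resolved value (or undefined): check immediately.
    onReady(attachments);
  }
}

// Both calls end up invoking the callback with the same array.
resolveAttachments(Promise.resolve([{ file_id: 'img-1' }]), (files) => console.log(files));
resolveAttachments([{ file_id: 'img-1' }], (files) => console.log(files));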

@@ -0,0 +1,34 @@
+/**
+ * Generates a prompt instructing the user to describe an image in detail, tailored to different types of visual content.
+ * @param {boolean} pluralized - Whether to pluralize the prompt for multiple images.
+ * @returns {string} - The generated vision prompt.
+ */
+const createVisionPrompt = (pluralized = false) => {
+  return `Please describe the image${
+    pluralized ? 's' : ''
+  } in detail, covering relevant aspects such as:
+
+For photographs, illustrations, or artwork:
+- The main subject(s) and their appearance, positioning, and actions
+- The setting, background, and any notable objects or elements
+- Colors, lighting, and overall mood or atmosphere
+- Any interesting details, textures, or patterns
+- The style, technique, or medium used (if discernible)
+
+For screenshots or images containing text:
+- The content and purpose of the text
+- The layout, formatting, and organization of the information
+- Any notable visual elements, such as logos, icons, or graphics
+- The overall context or message conveyed by the screenshot
+
+For graphs, charts, or data visualizations:
+- The type of graph or chart (e.g., bar graph, line chart, pie chart)
+- The variables being compared or analyzed
+- Any trends, patterns, or outliers in the data
+- The axis labels, scales, and units of measurement
+- The title, legend, and any additional context provided
+
+Be as specific and descriptive as possible while maintaining clarity and concision.`;
+};
+
+module.exports = createVisionPrompt;
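A minimal usage sketch for the helper above, mirroring how the assistants route later pairs it with a gpt-4-vision-preview completion. The image URLs and wiring here are illustrative; the committed integration is in the route diff further down.

const OpenAI = require('openai');
const createVisionPrompt = require('./createVisionPrompt');

const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

// Ask a vision-capable model to pre-describe uploaded images so the description
// can later be surfaced through the `image_vision` tool.
async function describeImages(imageUrls) {
  const plural = imageUrls.length > 1;
  const content = [
    { type: 'text', text: createVisionPrompt(plural) },
    ...imageUrls.map((url) => ({ type: 'image_url', image_url: { url } })),
  ];
  const completion = await openai.chat.completions.create({
    model: 'gpt-4-vision-preview',
    messages: [{ role: 'user', content }],
    max_tokens: 4000,
  });
  return completion.choices[0]?.message?.content;
}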

@@ -4,6 +4,7 @@ const handleInputs = require('./handleInputs');
 const instructions = require('./instructions');
 const titlePrompts = require('./titlePrompts');
 const truncateText = require('./truncateText');
+const createVisionPrompt = require('./createVisionPrompt');
 const createContextHandlers = require('./createContextHandlers');

 module.exports = {
@@ -13,5 +14,6 @@ module.exports = {
   ...instructions,
   ...titlePrompts,
   truncateText,
+  createVisionPrompt,
   createContextHandlers,
 };

@@ -4,9 +4,11 @@ const {
   Constants,
   RunStatus,
   CacheKeys,
+  FileSources,
   ContentTypes,
   EModelEndpoint,
   ViolationTypes,
+  ImageVisionTool,
   AssistantStreamEvents,
 } = require('librechat-data-provider');
 const {
@@ -17,9 +19,10 @@ const {
   addThreadMetadata,
   saveAssistantMessage,
 } = require('~/server/services/Threads');
+const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
 const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
 const { addTitle, initializeClient } = require('~/server/services/Endpoints/assistants');
-const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
+const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
 const { createRun, StreamRunManager } = require('~/server/services/Runs');
 const { getTransactions } = require('~/models/Transaction');
 const checkBalance = require('~/models/checkBalance');
@@ -100,6 +103,16 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
   let parentMessageId = _parentId;
   /** @type {TMessage[]} */
   let previousMessages = [];
+  /** @type {import('librechat-data-provider').TConversation | null} */
+  let conversation = null;
+  /** @type {string[]} */
+  let file_ids = [];
+  /** @type {Set<string>} */
+  let attachedFileIds = new Set();
+  /** @type {TMessage | null} */
+  let requestMessage = null;
+  /** @type {undefined | Promise<ChatCompletion>} */
+  let visionPromise;

   const userMessageId = v4();
   const responseMessageId = v4();
@@ -258,7 +271,10 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
       throw new Error('Missing assistant_id');
     }

-    if (isEnabled(process.env.CHECK_BALANCE)) {
+    const checkBalanceBeforeRun = async () => {
+      if (!isEnabled(process.env.CHECK_BALANCE)) {
+        return;
+      }
       const transactions =
         (await getTransactions({
           user: req.user.id,
@@ -288,7 +304,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
           amount: promptTokens,
         },
       });
-    }
+    };

     /** @type {{ openai: OpenAIClient }} */
     const { openai: _openai, client } = await initializeClient({
@@ -300,15 +316,11 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
     openai = _openai;

-    // if (thread_id) {
-    //   previousMessages = await checkMessageGaps({ openai, thread_id, conversationId });
-    // }
-
     if (previousMessages.length) {
       parentMessageId = previousMessages[previousMessages.length - 1].messageId;
     }

-    const userMessage = {
+    let userMessage = {
       role: 'user',
       content: text,
       metadata: {
@@ -316,75 +328,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
       },
     };

-    let thread_file_ids = [];
-    if (convoId) {
-      const convo = await getConvo(req.user.id, convoId);
-      if (convo && convo.file_ids) {
-        thread_file_ids = convo.file_ids;
-      }
-    }
-
-    const file_ids = files.map(({ file_id }) => file_id);
-    if (file_ids.length || thread_file_ids.length) {
-      userMessage.file_ids = file_ids;
-      openai.attachedFileIds = new Set([...file_ids, ...thread_file_ids]);
-    }
-
-    // TODO: may allow multiple messages to be created beforehand in a future update
-    const initThreadBody = {
-      messages: [userMessage],
-      metadata: {
-        user: req.user.id,
-        conversationId,
-      },
-    };
-
-    const result = await initThread({ openai, body: initThreadBody, thread_id });
-    thread_id = result.thread_id;
-
-    createOnTextProgress({
-      openai,
-      conversationId,
-      userMessageId,
-      messageId: responseMessageId,
-      thread_id,
-    });
-
-    const requestMessage = {
-      user: req.user.id,
-      text,
-      messageId: userMessageId,
-      parentMessageId,
-      // TODO: make sure client sends correct format for `files`, use zod
-      files,
-      file_ids,
-      conversationId,
-      isCreatedByUser: true,
-      assistant_id,
-      thread_id,
-      model: assistant_id,
-    };
-
-    previousMessages.push(requestMessage);
-
-    await saveUserMessage({ ...requestMessage, model });
-
-    const conversation = {
-      conversationId,
-      // TODO: title feature
-      title: 'New Chat',
-      endpoint: EModelEndpoint.assistants,
-      promptPrefix: promptPrefix,
-      instructions: instructions,
-      assistant_id,
-      // model,
-    };
-
-    if (file_ids.length) {
-      conversation.file_ids = file_ids;
-    }
-
-    /** @type {CreateRunBody} */
+    /** @type {CreateRunBody | undefined} */
     const body = {
       assistant_id,
       model,
@@ -398,6 +342,143 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
       body.instructions = instructions;
     }

+    const getRequestFileIds = async () => {
+      let thread_file_ids = [];
+      if (convoId) {
+        const convo = await getConvo(req.user.id, convoId);
+        if (convo && convo.file_ids) {
+          thread_file_ids = convo.file_ids;
+        }
+      }
+
+      file_ids = files.map(({ file_id }) => file_id);
+      if (file_ids.length || thread_file_ids.length) {
+        userMessage.file_ids = file_ids;
+        attachedFileIds = new Set([...file_ids, ...thread_file_ids]);
+      }
+    };
+
+    const addVisionPrompt = async () => {
+      if (!req.body.endpointOption.attachments) {
+        return;
+      }
+
+      const assistant = await openai.beta.assistants.retrieve(assistant_id);
+      const visionToolIndex = assistant.tools.findIndex(
+        (tool) => tool.function.name === ImageVisionTool.function.name,
+      );
+
+      if (visionToolIndex === -1) {
+        return;
+      }
+
+      const attachments = await req.body.endpointOption.attachments;
+      let visionMessage = {
+        role: 'user',
+        content: '',
+      };
+
+      const files = await client.addImageURLs(visionMessage, attachments);
+      if (!visionMessage.image_urls?.length) {
+        return;
+      }
+
+      const imageCount = visionMessage.image_urls.length;
+      const plural = imageCount > 1;
+      visionMessage.content = createVisionPrompt(plural);
+      visionMessage = formatMessage({ message: visionMessage, endpoint: EModelEndpoint.openAI });
+
+      visionPromise = openai.chat.completions.create({
+        model: 'gpt-4-vision-preview',
+        messages: [visionMessage],
+        max_tokens: 4000,
+      });
+
+      const pluralized = plural ? 's' : '';
+      body.additional_instructions = `${
+        body.additional_instructions ? `${body.additional_instructions}\n` : ''
+      }The user has uploaded ${imageCount} image${pluralized}.
+Use the \`${ImageVisionTool.function.name}\` tool to retrieve ${
+  plural ? '' : 'a '
+}detailed text description${pluralized} for ${plural ? 'each' : 'the'} image${pluralized}.`;
+
+      return files;
+    };
+
+    const initializeThread = async () => {
+      /** @type {[ undefined | MongoFile[]]}*/
+      const [processedFiles] = await Promise.all([addVisionPrompt(), getRequestFileIds()]);
+
+      // TODO: may allow multiple messages to be created beforehand in a future update
+      const initThreadBody = {
+        messages: [userMessage],
+        metadata: {
+          user: req.user.id,
+          conversationId,
+        },
+      };
+
+      if (processedFiles) {
+        for (const file of processedFiles) {
+          if (file.source !== FileSources.openai) {
+            attachedFileIds.delete(file.file_id);
+            const index = file_ids.indexOf(file.file_id);
+            if (index > -1) {
+              file_ids.splice(index, 1);
+            }
+          }
+        }
+
+        userMessage.file_ids = file_ids;
+      }
+
+      const result = await initThread({ openai, body: initThreadBody, thread_id });
+      thread_id = result.thread_id;
+
+      createOnTextProgress({
+        openai,
+        conversationId,
+        userMessageId,
+        messageId: responseMessageId,
+        thread_id,
+      });
+
+      requestMessage = {
+        user: req.user.id,
+        text,
+        messageId: userMessageId,
+        parentMessageId,
+        // TODO: make sure client sends correct format for `files`, use zod
+        files,
+        file_ids,
+        conversationId,
+        isCreatedByUser: true,
+        assistant_id,
+        thread_id,
+        model: assistant_id,
+      };
+
+      previousMessages.push(requestMessage);
+
+      /* asynchronous */
+      saveUserMessage({ ...requestMessage, model });
+
+      conversation = {
+        conversationId,
+        title: 'New Chat',
+        endpoint: EModelEndpoint.assistants,
+        promptPrefix: promptPrefix,
+        instructions: instructions,
+        assistant_id,
+        // model,
+      };
+
+      if (file_ids.length) {
+        conversation.file_ids = file_ids;
+      }
+    };
+
+    const promises = [initializeThread(), checkBalanceBeforeRun()];
+    await Promise.all(promises);
+
     const sendInitialResponse = () => {
       sendMessage(res, {
         sync: true,
@@ -421,6 +502,8 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
     const processRun = async (retry = false) => {
       if (req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) {
+        openai.attachedFileIds = attachedFileIds;
+        openai.visionPromise = visionPromise;
         if (retry) {
           response = await runAssistant({
             openai,
@@ -463,9 +546,11 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
         req,
         res,
         openai,
-        thread_id,
-        responseMessage: openai.responseMessage,
         handlers,
+        thread_id,
+        visionPromise,
+        attachedFileIds,
+        responseMessage: openai.responseMessage,
         // streamOptions: {
         // },

@@ -59,6 +59,10 @@ class StreamRunManager {
     this.messages = [];
     /** @type {string} */
     this.text = '';
+    /** @type {Set<string>} */
+    this.attachedFileIds = fields.attachedFileIds;
+    /** @type {undefined | Promise<ChatCompletion>} */
+    this.visionPromise = fields.visionPromise;

     /**
      * @type {Object.<AssistantStreamEvents, (event: AssistantStreamEvent) => Promise<void>>}

@@ -468,21 +468,28 @@ async function checkMessageGaps({ openai, latestMessageId, thread_id, run_id, co
 /**
  * Records token usage for a given completion request.
  *
  * @param {Object} params - The parameters for initializing a thread.
  * @param {number} params.prompt_tokens - The number of prompt tokens used.
  * @param {number} params.completion_tokens - The number of completion tokens used.
  * @param {string} params.model - The model used by the assistant run.
  * @param {string} params.user - The user's ID.
  * @param {string} params.conversationId - LibreChat conversation ID.
+ * @param {string} [params.context='message'] - The context of the usage. Defaults to 'message'.
  * @return {Promise<TMessage[]>} A promise that resolves to the updated messages
  */
-const recordUsage = async ({ prompt_tokens, completion_tokens, model, user, conversationId }) => {
+const recordUsage = async ({
+  prompt_tokens,
+  completion_tokens,
+  model,
+  user,
+  conversationId,
+  context = 'message',
+}) => {
   await spendTokens(
     {
       user,
       model,
-      context: 'message',
+      context,
       conversationId,
     },
     { promptTokens: prompt_tokens, completionTokens: completion_tokens },

@@ -4,14 +4,17 @@ const { StructuredTool } = require('langchain/tools');
 const { zodToJsonSchema } = require('zod-to-json-schema');
 const { Calculator } = require('langchain/tools/calculator');
 const {
+  Tools,
   ContentTypes,
   imageGenTools,
+  actionDelimiter,
+  ImageVisionTool,
   openapiToFunction,
   validateAndParseOpenAPISpec,
-  actionDelimiter,
 } = require('librechat-data-provider');
 const { loadActionSets, createActionTool, domainParser } = require('./ActionService');
 const { processFileURL } = require('~/server/services/Files/process');
+const { recordUsage } = require('~/server/services/Threads');
 const { loadTools } = require('~/app/clients/tools/util');
 const { redactMessage } = require('~/config/parsers');
 const { sleep } = require('~/server/utils');
@@ -83,6 +86,8 @@ function loadAndFormatTools({ directory, filter = new Set() }) {
     tools.push(formattedTool);
   }

+  tools.push(ImageVisionTool);
+
   return tools.reduce((map, tool) => {
     map[tool.function.name] = tool;
     return map;
@@ -100,8 +105,8 @@ function loadAndFormatTools({ directory, filter = new Set() }) {
  */
 function formatToOpenAIAssistantTool(tool) {
   return {
-    type: 'function',
-    function: {
+    type: Tools.function,
+    [Tools.function]: {
       name: tool.name,
       description: tool.description,
       parameters: zodToJsonSchema(tool.schema),
@@ -109,13 +114,42 @@ function formatToOpenAIAssistantTool(tool) {
   };
 }

+/**
+ * Processes the required actions by calling the appropriate tools and returning the outputs.
+ * @param {OpenAIClient} client - OpenAI or StreamRunManager Client.
+ * @param {RequiredAction} requiredActions - The current required action.
+ * @returns {Promise<ToolOutput>} The outputs of the tools.
+ */
+const processVisionRequest = async (client, currentAction) => {
+  if (!client.visionPromise) {
+    return {
+      tool_call_id: currentAction.toolCallId,
+      output: 'No image details found.',
+    };
+  }
+
+  /** @type {ChatCompletion | undefined} */
+  const completion = await client.visionPromise;
+  if (completion.usage) {
+    recordUsage({
+      user: client.req.user.id,
+      model: client.req.body.model,
+      conversationId: (client.responseMessage ?? client.finalMessage).conversationId,
+      ...completion.usage,
+    });
+  }
+
+  const output = completion?.choices?.[0]?.message?.content ?? 'No image details found.';
+  return {
+    tool_call_id: currentAction.toolCallId,
+    output,
+  };
+};
+
 /**
  * Processes return required actions from run.
  *
  * @param {OpenAIClient} client - OpenAI or StreamRunManager Client.
  * @param {RequiredAction[]} requiredActions - The required actions to submit outputs for.
  * @returns {Promise<ToolOutputs>} The outputs of the tools.
  *
  */
 async function processRequiredActions(client, requiredActions) {
   logger.debug(
@@ -152,6 +186,10 @@ async function processRequiredActions(client, requiredActions) {
   for (let i = 0; i < requiredActions.length; i++) {
     const currentAction = requiredActions[i];
+    if (currentAction.tool === ImageVisionTool.function.name) {
+      promises.push(processVisionRequest(client, currentAction));
+      continue;
+    }
+
     let tool = ToolMap[currentAction.tool] ?? ActionToolMap[currentAction.tool];

     const handleToolOutput = async (output) => {

@@ -172,6 +172,7 @@ function generateConfig(key, baseURL, assistants = false) {
   config.retrievalModels = defaultRetrievalModels;
   config.capabilities = [
     Capabilities.code_interpreter,
+    Capabilities.image_vision,
     Capabilities.retrieval,
     Capabilities.actions,
     Capabilities.tools,

@@ -32,6 +32,18 @@
  * @memberof typedefs
  */

+/**
+ * @exports ChatCompletionContentPartImage
+ * @typedef {import('openai').OpenAI.ChatCompletionContentPartImage} ChatCompletionContentPartImage
+ * @memberof typedefs
+ */
+
+/**
+ * @exports ChatCompletion
+ * @typedef {import('openai').OpenAI.ChatCompletion} ChatCompletion
+ * @memberof typedefs
+ */
+
 /**
  * @exports OpenAIRequestOptions
  * @typedef {import('openai').OpenAI.RequestOptions} OpenAIRequestOptions

@@ -1,3 +1,4 @@
+import { Capabilities } from 'librechat-data-provider';
 import type { Assistant } from 'librechat-data-provider';
 import type { Option, ExtendedFile } from './types';

@@ -6,8 +7,9 @@ export type TAssistantOption =
   | (Option & Assistant & { files?: Array<[string, ExtendedFile]> });

 export type Actions = {
-  code_interpreter: boolean;
-  retrieval: boolean;
+  [Capabilities.code_interpreter]: boolean;
+  [Capabilities.image_vision]: boolean;
+  [Capabilities.retrieval]: boolean;
 };

 export type AssistantForm = {

@@ -1,4 +1,9 @@
-import { ToolCallTypes, ContentTypes, imageGenTools } from 'librechat-data-provider';
+import {
+  ToolCallTypes,
+  ContentTypes,
+  imageGenTools,
+  isImageVisionTool,
+} from 'librechat-data-provider';
 import type { TMessageContentParts, TMessage } from 'librechat-data-provider';
 import type { TDisplayProps } from '~/common';
 import { ErrorMessage } from './MessageContent';
@@ -96,6 +101,25 @@ export default function Part({
     part[ContentTypes.TOOL_CALL].type === ToolCallTypes.FUNCTION
   ) {
     const toolCall = part[ContentTypes.TOOL_CALL];
+    if (isImageVisionTool(toolCall)) {
+      if (isSubmitting && showCursor) {
+        return (
+          <Container>
+            <div className="markdown prose dark:prose-invert light dark:text-gray-70 my-1 w-full break-words">
+              <DisplayMessage
+                text={''}
+                isCreatedByUser={message.isCreatedByUser}
+                message={message}
+                showCursor={showCursor}
+              />
+            </div>
+          </Container>
+        );
+      }
+      return null;
+    }
+
     return (
       <ToolCall
         initialProgress={toolCall.progress ?? 0.1}

@@ -8,6 +8,7 @@ import {
   Capabilities,
   EModelEndpoint,
   actionDelimiter,
+  ImageVisionTool,
   defaultAssistantFormValues,
 } from 'librechat-data-provider';
 import type { AssistantForm, AssistantPanelProps } from '~/common';
@@ -82,6 +83,10 @@ export default function AssistantPanel({
     () => assistants?.capabilities?.includes(Capabilities.code_interpreter),
     [assistants],
   );
+  const imageVisionEnabled = useMemo(
+    () => assistants?.capabilities?.includes(Capabilities.image_vision),
+    [assistants],
+  );

   useEffect(() => {
     if (model && !retrievalModels.has(model)) {
@@ -157,6 +162,9 @@ export default function AssistantPanel({
     if (data.retrieval) {
       tools.push({ type: Tools.retrieval });
     }
+    if (data.image_vision) {
+      tools.push(ImageVisionTool);
+    }

     const {
       name,
@@ -374,6 +382,37 @@ export default function AssistantPanel({
              </label>
            </div>
          )}
+          {imageVisionEnabled && (
+            <div className="flex items-center">
+              <Controller
+                name={Capabilities.image_vision}
+                control={control}
+                render={({ field }) => (
+                  <Checkbox
+                    {...field}
+                    checked={field.value}
+                    onCheckedChange={field.onChange}
+                    className="relative float-left mr-2 inline-flex h-4 w-4 cursor-pointer"
+                    value={field?.value?.toString()}
+                  />
+                )}
+              />
+              <label
+                className="form-check-label text-token-text-primary w-full cursor-pointer"
+                htmlFor={Capabilities.image_vision}
+                onClick={() =>
+                  setValue(Capabilities.image_vision, !getValues(Capabilities.image_vision), {
+                    shouldDirty: true,
+                  })
+                }
+              >
+                <div className="flex items-center">
+                  {localize('com_assistants_image_vision')}
+                  <QuestionMark />
+                </div>
+              </label>
+            </div>
+          )}
          {retrievalEnabled && (
            <div className="flex items-center">
              <Controller
@@ -417,9 +456,9 @@ export default function AssistantPanel({
            ${actionsEnabled ? localize('com_assistants_actions') : ''}`}
          </label>
          <div className="space-y-1">
-            {functions.map((func) => (
+            {functions.map((func, i) => (
              <AssistantTool
-                key={func}
+                key={`${func}-${i}-${assistant_id}`}
                tool={func}
                allTools={allTools}
                assistant_id={assistant_id}

@@ -3,6 +3,8 @@ import { useCallback, useEffect, useRef } from 'react';
 import {
   defaultAssistantFormValues,
   defaultOrderQuery,
+  isImageVisionTool,
+  Capabilities,
   FileSources,
 } from 'librechat-data-provider';
 import type { UseFormReset } from 'react-hook-form';
@@ -13,7 +15,7 @@ import SelectDropDown from '~/components/ui/SelectDropDown';
 import { useListAssistantsQuery } from '~/data-provider';
 import { useFileMapContext } from '~/Providers';
 import { useLocalize } from '~/hooks';
-import { cn } from '~/utils/';
+import { cn } from '~/utils';

 const keys = new Set(['name', 'id', 'description', 'instructions', 'model']);
@@ -87,20 +89,21 @@ export default function AssistantSelect({
      };

      const actions: Actions = {
-        code_interpreter: false,
-        retrieval: false,
+        [Capabilities.code_interpreter]: false,
+        [Capabilities.image_vision]: false,
+        [Capabilities.retrieval]: false,
      };

      assistant?.tools
-        ?.filter((tool) => tool.type !== 'function')
-        ?.map((tool) => tool.type)
+        ?.filter((tool) => tool.type !== 'function' || isImageVisionTool(tool))
+        ?.map((tool) => tool?.function?.name || tool.type)
        .forEach((tool) => {
          actions[tool] = true;
        });

      const functions =
        assistant?.tools
-          ?.filter((tool) => tool.type === 'function')
+          ?.filter((tool) => tool.type === 'function' && !isImageVisionTool(tool))
          ?.map((tool) => tool.function?.name ?? '') ?? [];

      const formValues: Partial<AssistantForm & Actions> = {

@@ -16,6 +16,7 @@ export default {
     'If you upload files under Knowledge, conversations with your Assistant may include file contents.',
   com_assistants_knowledge_disabled:
     'Assistant must be created, and Code Interpreter or Retrieval must be enabled and saved before uploading files as Knowledge.',
+  com_assistants_image_vision: 'Image Vision',
   com_assistants_code_interpreter: 'Code Interpreter',
   com_assistants_code_interpreter_files:
     'The following files are only available for Code Interpreter:',

@@ -82,6 +82,7 @@ export type TValidatedAzureConfig = {

 export enum Capabilities {
   code_interpreter = 'code_interpreter',
+  image_vision = 'image_vision',
   retrieval = 'retrieval',
   actions = 'actions',
   tools = 'tools',
@@ -100,6 +101,7 @@ export const assistantEndpointSchema = z.object({
     .optional()
     .default([
       Capabilities.code_interpreter,
+      Capabilities.image_vision,
       Capabilities.retrieval,
       Capabilities.actions,
       Capabilities.tools,

@@ -1,5 +1,6 @@
 import { z } from 'zod';
-import type { TMessageContentParts } from './types/assistants';
+import { Tools } from './types/assistants';
+import type { TMessageContentParts, FunctionTool, FunctionToolCall } from './types/assistants';
 import type { TFile } from './types/files';

 export const isUUID = z.string().uuid();
@@ -25,9 +26,26 @@ export const defaultAssistantFormValues = {
   model: '',
   functions: [],
   code_interpreter: false,
+  image_vision: false,
   retrieval: false,
 };

+export const ImageVisionTool: FunctionTool = {
+  type: Tools.function,
+  [Tools.function]: {
+    name: 'image_vision',
+    description: 'Get detailed text descriptions for all current image attachments.',
+    parameters: {
+      type: 'object',
+      properties: {},
+      required: [],
+    },
+  },
+};
+
+export const isImageVisionTool = (tool: FunctionTool | FunctionToolCall) =>
+  tool.type === 'function' && tool.function?.name === ImageVisionTool?.function?.name;
+
 export const endpointSettings = {
   [EModelEndpoint.google]: {
     model: {
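A small consumer-side sketch of the new `ImageVisionTool` and `isImageVisionTool` exports, similar to what AssistantPanel does when the Image Vision checkbox is toggled. The helper below is illustrative and not part of the package:

const { ImageVisionTool, isImageVisionTool } = require('librechat-data-provider');

// Add or remove the image_vision function tool on an assistant's tool list.
function setImageVision(tools, enabled) {
  const withoutVision = tools.filter((tool) => !isImageVisionTool(tool));
  return enabled ? [...withoutVision, ImageVisionTool] : withoutVision;
}

// Example: an assistant that already has code_interpreter keeps it either way.
const updated = setImageVision([{ type: 'code_interpreter' }], true);
console.log(updated.map((tool) => (tool.type === 'function' ? tool.function.name : tool.type)));
// -> ['code_interpreter', 'image_vision']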