diff --git a/.env.example b/.env.example index a42ed5d68..7800677e2 100644 --- a/.env.example +++ b/.env.example @@ -164,6 +164,16 @@ ASSISTANTS_API_KEY=user_provided # ASSISTANTS_BASE_URL= # ASSISTANTS_MODELS=gpt-4o,gpt-3.5-turbo-0125,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-16k,gpt-3.5-turbo,gpt-4,gpt-4-0314,gpt-4-32k-0314,gpt-4-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-1106,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview +#==========================# +# Azure Assistants API # +#==========================# + +# Note: You should map your credentials with custom variables according to your Azure OpenAI Configuration +# The models for Azure Assistants are also determined by your Azure OpenAI configuration. + +# More info, including how to enable use of Assistants with Azure here: +# https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints/azure#using-assistants-with-azure + #============# # OpenRouter # #============# diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index b4a50bc05..f22908757 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -756,6 +756,8 @@ class OpenAIClient extends BaseClient { * In case of failure, it will return the default title, "New Chat". 
*/ async titleConvo({ text, conversationId, responseText = '' }) { + this.conversationId = conversationId; + if (this.options.attachments) { delete this.options.attachments; } @@ -838,13 +840,17 @@ ${convo} try { let useChatCompletion = true; + if (this.options.reverseProxyUrl === CohereConstants.API_URL) { useChatCompletion = false; } + title = ( await this.sendPayload(instructionsPayload, { modelOptions, useChatCompletion }) ).replaceAll('"', ''); + const completionTokens = this.getTokenCount(title); + this.recordTokenUsage({ promptTokens, completionTokens, context: 'title' }); } catch (e) { logger.error( @@ -868,6 +874,7 @@ ${convo} context: 'title', tokenBuffer: 150, }); + title = await runTitleChain({ llm, text, convo, signal: this.abortController.signal }); } catch (e) { if (e?.message?.toLowerCase()?.includes('abort')) { @@ -1005,9 +1012,9 @@ ${convo} await spendTokens( { context, - user: this.user, model: this.modelOptions.model, conversationId: this.conversationId, + user: this.user ?? this.options.req.user?.id, endpointTokenConfig: this.options.endpointTokenConfig, }, { promptTokens, completionTokens }, diff --git a/api/app/clients/specs/OpenAIClient.test.js b/api/app/clients/specs/OpenAIClient.test.js index 7ef4fdcae..459039841 100644 --- a/api/app/clients/specs/OpenAIClient.test.js +++ b/api/app/clients/specs/OpenAIClient.test.js @@ -144,6 +144,7 @@ describe('OpenAIClient', () => { const defaultOptions = { // debug: true, + req: {}, openaiApiKey: 'new-api-key', modelOptions: { model, diff --git a/api/models/Action.js b/api/models/Action.js index 9acac078b..86bd5d859 100644 --- a/api/models/Action.js +++ b/api/models/Action.js @@ -62,8 +62,24 @@ const deleteAction = async (searchParams, session = null) => { return await Action.findOneAndDelete(searchParams, options).lean(); }; -module.exports = { - updateAction, - getActions, - deleteAction, +/** + * Deletes actions by params, within a transaction session if provided. 
+ * + * @param {Object} searchParams - The search parameters to find the actions to delete. + * @param {string} searchParams.action_id - The ID of the action(s) to delete. + * @param {string} searchParams.user - The user ID of the action's author. + * @param {mongoose.ClientSession} [session] - The transaction session to use (optional). + * @returns {Promise} A promise that resolves to the number of deleted action documents. + */ +const deleteActions = async (searchParams, session = null) => { + const options = session ? { session } : {}; + const result = await Action.deleteMany(searchParams, options); + return result.deletedCount; +}; + +module.exports = { + getActions, + updateAction, + deleteAction, + deleteActions, }; diff --git a/api/models/Assistant.js b/api/models/Assistant.js index 17e407722..bf9382d0e 100644 --- a/api/models/Assistant.js +++ b/api/models/Assistant.js @@ -39,8 +39,21 @@ const getAssistants = async (searchParams) => { return await Assistant.find(searchParams).lean(); }; +/** + * Deletes an assistant based on the provided ID. + * + * @param {Object} searchParams - The search parameters to find the assistant to delete. + * @param {string} searchParams.assistant_id - The ID of the assistant to delete. + * @param {string} searchParams.user - The user ID of the assistant's author. + * @returns {Promise} Resolves when the assistant has been successfully deleted. 
+ */ +const deleteAssistant = async (searchParams) => { + return await Assistant.findOneAndDelete(searchParams); +}; + module.exports = { updateAssistant, + deleteAssistant, getAssistants, getAssistant, }; diff --git a/api/models/plugins/mongoMeili.js b/api/models/plugins/mongoMeili.js index 261b5c50c..df9633830 100644 --- a/api/models/plugins/mongoMeili.js +++ b/api/models/plugins/mongoMeili.js @@ -155,7 +155,7 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { function (results, value, key) { return { ...results, [key]: 1 }; }, - { _id: 1 }, + { _id: 1, __v: 1 }, ), ).lean(); diff --git a/api/models/spendTokens.js b/api/models/spendTokens.js index 830cda207..917d0c93d 100644 --- a/api/models/spendTokens.js +++ b/api/models/spendTokens.js @@ -40,7 +40,7 @@ const spendTokens = async (txData, tokenUsage) => { }); } - if (!completionTokens) { + if (!completionTokens && isNaN(completionTokens)) { logger.debug('[spendTokens] !completionTokens', { prompt, completion }); return; } diff --git a/api/package.json b/api/package.json index d91b6031e..d4e0132dd 100644 --- a/api/package.json +++ b/api/package.json @@ -76,7 +76,7 @@ "nodejs-gpt": "^1.37.4", "nodemailer": "^6.9.4", "ollama": "^0.5.0", - "openai": "4.36.0", + "openai": "^4.47.1", "openai-chat-tokens": "^0.2.8", "openid-client": "^5.4.2", "passport": "^0.6.0", diff --git a/api/server/controllers/EndpointController.js b/api/server/controllers/EndpointController.js index b99dd5eda..d80ea6b14 100644 --- a/api/server/controllers/EndpointController.js +++ b/api/server/controllers/EndpointController.js @@ -16,10 +16,28 @@ async function endpointController(req, res) { /** @type {TEndpointsConfig} */ const mergedConfig = { ...defaultEndpointsConfig, ...customConfigEndpoints }; if (mergedConfig[EModelEndpoint.assistants] && req.app.locals?.[EModelEndpoint.assistants]) { - const { disableBuilder, retrievalModels, capabilities, ..._rest } = + const { disableBuilder, retrievalModels, capabilities, 
version, ..._rest } = req.app.locals[EModelEndpoint.assistants]; + mergedConfig[EModelEndpoint.assistants] = { ...mergedConfig[EModelEndpoint.assistants], + version, + retrievalModels, + disableBuilder, + capabilities, + }; + } + + if ( + mergedConfig[EModelEndpoint.azureAssistants] && + req.app.locals?.[EModelEndpoint.azureAssistants] + ) { + const { disableBuilder, retrievalModels, capabilities, version, ..._rest } = + req.app.locals[EModelEndpoint.azureAssistants]; + + mergedConfig[EModelEndpoint.azureAssistants] = { + ...mergedConfig[EModelEndpoint.azureAssistants], + version, retrievalModels, disableBuilder, capabilities, diff --git a/api/server/routes/assistants/chat.js b/api/server/controllers/assistants/chatV1.js similarity index 93% rename from api/server/routes/assistants/chat.js rename to api/server/controllers/assistants/chatV1.js index 96a09d02d..34f9e9203 100644 --- a/api/server/routes/assistants/chat.js +++ b/api/server/controllers/assistants/chatV1.js @@ -1,14 +1,13 @@ const { v4 } = require('uuid'); -const express = require('express'); const { Constants, RunStatus, CacheKeys, - FileSources, ContentTypes, EModelEndpoint, ViolationTypes, ImageVisionTool, + checkOpenAIStorage, AssistantStreamEvents, } = require('librechat-data-provider'); const { @@ -21,27 +20,18 @@ const { } = require('~/server/services/Threads'); const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils'); const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService'); -const { addTitle, initializeClient } = require('~/server/services/Endpoints/assistants'); const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts'); const { createRun, StreamRunManager } = require('~/server/services/Runs'); +const { addTitle } = require('~/server/services/Endpoints/assistants'); const { getTransactions } = require('~/models/Transaction'); const checkBalance = require('~/models/checkBalance'); const { getConvo } = 
require('~/models/Conversation'); const getLogStores = require('~/cache/getLogStores'); const { getModelMaxTokens } = require('~/utils'); +const { getOpenAIClient } = require('./helpers'); const { logger } = require('~/config'); -const router = express.Router(); -const { - setHeaders, - handleAbort, - validateModel, - handleAbortError, - // validateEndpoint, - buildEndpointOption, -} = require('~/server/middleware'); - -router.post('/abort', handleAbort()); +const { handleAbortError } = require('~/server/middleware'); const ten_minutes = 1000 * 60 * 10; @@ -49,16 +39,17 @@ const ten_minutes = 1000 * 60 * 10; * @route POST / * @desc Chat with an assistant * @access Public - * @param {express.Request} req - The request object, containing the request data. - * @param {express.Response} res - The response object, used to send back a response. + * @param {Express.Request} req - The request object, containing the request data. + * @param {Express.Response} res - The response object, used to send back a response. 
* @returns {void} */ -router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res) => { +const chatV1 = async (req, res) => { logger.debug('[/assistants/chat/] req.body', req.body); const { text, model, + endpoint, files = [], promptPrefix, assistant_id, @@ -70,7 +61,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res } = req.body; /** @type {Partial} */ - const assistantsConfig = req.app.locals?.[EModelEndpoint.assistants]; + const assistantsConfig = req.app.locals?.[endpoint]; if (assistantsConfig) { const { supportedIds, excludedIds } = assistantsConfig; @@ -138,7 +129,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res user: req.user.id, shouldSaveMessage: false, messageId: responseMessageId, - endpoint: EModelEndpoint.assistants, + endpoint, }; if (error.message === 'Run cancelled') { @@ -149,7 +140,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res logger.debug('[/assistants/chat/] Request aborted on close'); } else if (/Files.*are invalid/.test(error.message)) { const errorMessage = `Files are invalid, or may not have uploaded yet.${ - req.app.locals?.[EModelEndpoint.azureOpenAI].assistants + endpoint === EModelEndpoint.azureAssistants ? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.' 
: '' }`; @@ -205,6 +196,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res const runMessages = await checkMessageGaps({ openai, run_id, + endpoint, thread_id, conversationId, latestMessageId: responseMessageId, @@ -311,8 +303,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res }); }; - /** @type {{ openai: OpenAIClient }} */ - const { openai: _openai, client } = await initializeClient({ + const { openai: _openai, client } = await getOpenAIClient({ req, res, endpointOption: req.body.endpointOption, @@ -370,10 +361,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res /** @type {MongoFile[]} */ const attachments = await req.body.endpointOption.attachments; - if ( - attachments && - attachments.every((attachment) => attachment.source === FileSources.openai) - ) { + if (attachments && attachments.every((attachment) => checkOpenAIStorage(attachment.source))) { return; } @@ -431,7 +419,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res if (processedFiles) { for (const file of processedFiles) { - if (file.source !== FileSources.openai) { + if (!checkOpenAIStorage(file.source)) { attachedFileIds.delete(file.file_id); const index = file_ids.indexOf(file.file_id); if (index > -1) { @@ -467,6 +455,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res assistant_id, thread_id, model: assistant_id, + endpoint, }; previousMessages.push(requestMessage); @@ -476,7 +465,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res conversation = { conversationId, - endpoint: EModelEndpoint.assistants, + endpoint, promptPrefix: promptPrefix, instructions: instructions, assistant_id, @@ -513,7 +502,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res let response; const processRun = async (retry = false) => { - if 
(req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) { + if (endpoint === EModelEndpoint.azureAssistants) { body.model = openai._options.model; openai.attachedFileIds = attachedFileIds; openai.visionPromise = visionPromise; @@ -603,6 +592,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res assistant_id, thread_id, model: assistant_id, + endpoint, }; sendMessage(res, { @@ -655,6 +645,6 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res } catch (error) { await handleError(error); } -}); +}; -module.exports = router; +module.exports = chatV1; diff --git a/api/server/controllers/assistants/chatV2.js b/api/server/controllers/assistants/chatV2.js new file mode 100644 index 000000000..c72d5fc9b --- /dev/null +++ b/api/server/controllers/assistants/chatV2.js @@ -0,0 +1,618 @@ +const { v4 } = require('uuid'); +const { + Constants, + RunStatus, + CacheKeys, + ContentTypes, + ToolCallTypes, + EModelEndpoint, + ViolationTypes, + retrievalMimeTypes, + AssistantStreamEvents, +} = require('librechat-data-provider'); +const { + initThread, + recordUsage, + saveUserMessage, + checkMessageGaps, + addThreadMetadata, + saveAssistantMessage, +} = require('~/server/services/Threads'); +const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils'); +const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService'); +const { createRun, StreamRunManager } = require('~/server/services/Runs'); +const { addTitle } = require('~/server/services/Endpoints/assistants'); +const { getTransactions } = require('~/models/Transaction'); +const checkBalance = require('~/models/checkBalance'); +const { getConvo } = require('~/models/Conversation'); +const getLogStores = require('~/cache/getLogStores'); +const { getModelMaxTokens } = require('~/utils'); +const { getOpenAIClient } = require('./helpers'); +const { logger } = require('~/config'); + +const { handleAbortError } = 
require('~/server/middleware'); + +const ten_minutes = 1000 * 60 * 10; + +/** + * @route POST / + * @desc Chat with an assistant + * @access Public + * @param {Express.Request} req - The request object, containing the request data. + * @param {Express.Response} res - The response object, used to send back a response. + * @returns {void} + */ +const chatV2 = async (req, res) => { + logger.debug('[/assistants/chat/] req.body', req.body); + + /** @type {{ files: MongoFile[]}} */ + const { + text, + model, + endpoint, + files = [], + promptPrefix, + assistant_id, + instructions, + thread_id: _thread_id, + messageId: _messageId, + conversationId: convoId, + parentMessageId: _parentId = Constants.NO_PARENT, + } = req.body; + + /** @type {Partial} */ + const assistantsConfig = req.app.locals?.[endpoint]; + + if (assistantsConfig) { + const { supportedIds, excludedIds } = assistantsConfig; + const error = { message: 'Assistant not supported' }; + if (supportedIds?.length && !supportedIds.includes(assistant_id)) { + return await handleAbortError(res, req, error, { + sender: 'System', + conversationId: convoId, + messageId: v4(), + parentMessageId: _messageId, + error, + }); + } else if (excludedIds?.length && excludedIds.includes(assistant_id)) { + return await handleAbortError(res, req, error, { + sender: 'System', + conversationId: convoId, + messageId: v4(), + parentMessageId: _messageId, + }); + } + } + + /** @type {OpenAIClient} */ + let openai; + /** @type {string|undefined} - the current thread id */ + let thread_id = _thread_id; + /** @type {string|undefined} - the current run id */ + let run_id; + /** @type {string|undefined} - the parent messageId */ + let parentMessageId = _parentId; + /** @type {TMessage[]} */ + let previousMessages = []; + /** @type {import('librechat-data-provider').TConversation | null} */ + let conversation = null; + /** @type {string[]} */ + let file_ids = []; + /** @type {Set} */ + let attachedFileIds = new Set(); + /** @type {TMessage | 
null} */ + let requestMessage = null; + + const userMessageId = v4(); + const responseMessageId = v4(); + + /** @type {string} - The conversation UUID - created if undefined */ + const conversationId = convoId ?? v4(); + + const cache = getLogStores(CacheKeys.ABORT_KEYS); + const cacheKey = `${req.user.id}:${conversationId}`; + + /** @type {Run | undefined} - The completed run, undefined if incomplete */ + let completedRun; + + const handleError = async (error) => { + const defaultErrorMessage = + 'The Assistant run failed to initialize. Try sending a message in a new conversation.'; + const messageData = { + thread_id, + assistant_id, + conversationId, + parentMessageId, + sender: 'System', + user: req.user.id, + shouldSaveMessage: false, + messageId: responseMessageId, + endpoint, + }; + + if (error.message === 'Run cancelled') { + return res.end(); + } else if (error.message === 'Request closed' && completedRun) { + return; + } else if (error.message === 'Request closed') { + logger.debug('[/assistants/chat/] Request aborted on close'); + } else if (/Files.*are invalid/.test(error.message)) { + const errorMessage = `Files are invalid, or may not have uploaded yet.${ + endpoint === EModelEndpoint.azureAssistants + ? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.' + : '' + }`; + return sendResponse(res, messageData, errorMessage); + } else if (error?.message?.includes('string too long')) { + return sendResponse( + res, + messageData, + 'Message too long. The Assistants API has a limit of 32,768 characters per message. 
Please shorten it and try again.', + ); + } else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) { + return sendResponse(res, messageData, error.message); + } else { + logger.error('[/assistants/chat/]', error); + } + + if (!openai || !thread_id || !run_id) { + return sendResponse(res, messageData, defaultErrorMessage); + } + + await sleep(2000); + + try { + const status = await cache.get(cacheKey); + if (status === 'cancelled') { + logger.debug('[/assistants/chat/] Run already cancelled'); + return res.end(); + } + await cache.delete(cacheKey); + const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id); + logger.debug('[/assistants/chat/] Cancelled run:', cancelledRun); + } catch (error) { + logger.error('[/assistants/chat/] Error cancelling run', error); + } + + await sleep(2000); + + let run; + try { + run = await openai.beta.threads.runs.retrieve(thread_id, run_id); + await recordUsage({ + ...run.usage, + model: run.model, + user: req.user.id, + conversationId, + }); + } catch (error) { + logger.error('[/assistants/chat/] Error fetching or processing run', error); + } + + let finalEvent; + try { + const runMessages = await checkMessageGaps({ + openai, + run_id, + endpoint, + thread_id, + conversationId, + latestMessageId: responseMessageId, + }); + + const errorContentPart = { + text: { + value: + error?.message ?? 'There was an error processing your request. 
Please try again later.', + }, + type: ContentTypes.ERROR, + }; + + if (!Array.isArray(runMessages[runMessages.length - 1]?.content)) { + runMessages[runMessages.length - 1].content = [errorContentPart]; + } else { + const contentParts = runMessages[runMessages.length - 1].content; + for (let i = 0; i < contentParts.length; i++) { + const currentPart = contentParts[i]; + /** @type {CodeToolCall | RetrievalToolCall | FunctionToolCall | undefined} */ + const toolCall = currentPart?.[ContentTypes.TOOL_CALL]; + if ( + toolCall && + toolCall?.function && + !(toolCall?.function?.output || toolCall?.function?.output?.length) + ) { + contentParts[i] = { + ...currentPart, + [ContentTypes.TOOL_CALL]: { + ...toolCall, + function: { + ...toolCall.function, + output: 'error processing tool', + }, + }, + }; + } + } + runMessages[runMessages.length - 1].content.push(errorContentPart); + } + + finalEvent = { + final: true, + conversation: await getConvo(req.user.id, conversationId), + runMessages, + }; + } catch (error) { + logger.error('[/assistants/chat/] Error finalizing error process', error); + return sendResponse(res, messageData, 'The Assistant run failed'); + } + + return sendResponse(res, finalEvent); + }; + + try { + res.on('close', async () => { + if (!completedRun) { + await handleError(new Error('Request closed')); + } + }); + + if (convoId && !_thread_id) { + completedRun = true; + throw new Error('Missing thread_id for existing conversation'); + } + + if (!assistant_id) { + completedRun = true; + throw new Error('Missing assistant_id'); + } + + const checkBalanceBeforeRun = async () => { + if (!isEnabled(process.env.CHECK_BALANCE)) { + return; + } + const transactions = + (await getTransactions({ + user: req.user.id, + context: 'message', + conversationId, + })) ?? 
[]; + + const totalPreviousTokens = Math.abs( + transactions.reduce((acc, curr) => acc + curr.rawAmount, 0), + ); + + // TODO: make promptBuffer a config option; buffer for titles, needs buffer for system instructions + const promptBuffer = parentMessageId === Constants.NO_PARENT && !_thread_id ? 200 : 0; + // 5 is added for labels + let promptTokens = (await countTokens(text + (promptPrefix ?? ''))) + 5; + promptTokens += totalPreviousTokens + promptBuffer; + // Count tokens up to the current context window + promptTokens = Math.min(promptTokens, getModelMaxTokens(model)); + + await checkBalance({ + req, + res, + txData: { + model, + user: req.user.id, + tokenType: 'prompt', + amount: promptTokens, + }, + }); + }; + + const { openai: _openai, client } = await getOpenAIClient({ + req, + res, + endpointOption: req.body.endpointOption, + initAppClient: true, + }); + + openai = _openai; + + if (previousMessages.length) { + parentMessageId = previousMessages[previousMessages.length - 1].messageId; + } + + let userMessage = { + role: 'user', + content: [ + { + type: ContentTypes.TEXT, + text, + }, + ], + metadata: { + messageId: userMessageId, + }, + }; + + /** @type {CreateRunBody | undefined} */ + const body = { + assistant_id, + model, + }; + + if (promptPrefix) { + body.additional_instructions = promptPrefix; + } + + if (instructions) { + body.instructions = instructions; + } + + const getRequestFileIds = async () => { + let thread_file_ids = []; + if (convoId) { + const convo = await getConvo(req.user.id, convoId); + if (convo && convo.file_ids) { + thread_file_ids = convo.file_ids; + } + } + + if (files.length || thread_file_ids.length) { + attachedFileIds = new Set([...file_ids, ...thread_file_ids]); + + let attachmentIndex = 0; + for (const file of files) { + file_ids.push(file.file_id); + if (file.type.startsWith('image')) { + userMessage.content.push({ + type: ContentTypes.IMAGE_FILE, + [ContentTypes.IMAGE_FILE]: { file_id: file.file_id }, + }); + } + + if 
(!userMessage.attachments) { + userMessage.attachments = []; + } + + userMessage.attachments.push({ + file_id: file.file_id, + tools: [{ type: ToolCallTypes.CODE_INTERPRETER }], + }); + + if (file.type.startsWith('image')) { + continue; + } + + const mimeType = file.type; + const isSupportedByRetrieval = retrievalMimeTypes.some((regex) => regex.test(mimeType)); + if (isSupportedByRetrieval) { + userMessage.attachments[attachmentIndex].tools.push({ + type: ToolCallTypes.FILE_SEARCH, + }); + } + + attachmentIndex++; + } + } + }; + + const initializeThread = async () => { + await getRequestFileIds(); + + // TODO: may allow multiple messages to be created beforehand in a future update + const initThreadBody = { + messages: [userMessage], + metadata: { + user: req.user.id, + conversationId, + }, + }; + + const result = await initThread({ openai, body: initThreadBody, thread_id }); + thread_id = result.thread_id; + + createOnTextProgress({ + openai, + conversationId, + userMessageId, + messageId: responseMessageId, + thread_id, + }); + + requestMessage = { + user: req.user.id, + text, + messageId: userMessageId, + parentMessageId, + // TODO: make sure client sends correct format for `files`, use zod + files, + file_ids, + conversationId, + isCreatedByUser: true, + assistant_id, + thread_id, + model: assistant_id, + endpoint, + }; + + previousMessages.push(requestMessage); + + /* asynchronous */ + saveUserMessage({ ...requestMessage, model }); + + conversation = { + conversationId, + endpoint, + promptPrefix: promptPrefix, + instructions: instructions, + assistant_id, + // model, + }; + + if (file_ids.length) { + conversation.file_ids = file_ids; + } + }; + + const promises = [initializeThread(), checkBalanceBeforeRun()]; + await Promise.all(promises); + + const sendInitialResponse = () => { + sendMessage(res, { + sync: true, + conversationId, + // messages: previousMessages, + requestMessage, + responseMessage: { + user: req.user.id, + messageId: 
openai.responseMessage.messageId, + parentMessageId: userMessageId, + conversationId, + assistant_id, + thread_id, + model: assistant_id, + }, + }); + }; + + /** @type {RunResponse | typeof StreamRunManager | undefined} */ + let response; + + const processRun = async (retry = false) => { + if (endpoint === EModelEndpoint.azureAssistants) { + body.model = openai._options.model; + openai.attachedFileIds = attachedFileIds; + if (retry) { + response = await runAssistant({ + openai, + thread_id, + run_id, + in_progress: openai.in_progress, + }); + return; + } + + /* NOTE: + * By default, a Run will use the model and tools configuration specified in Assistant object, + * but you can override most of these when creating the Run for added flexibility: + */ + const run = await createRun({ + openai, + thread_id, + body, + }); + + run_id = run.id; + await cache.set(cacheKey, `${thread_id}:${run_id}`, ten_minutes); + sendInitialResponse(); + + // todo: retry logic + response = await runAssistant({ openai, thread_id, run_id }); + return; + } + + /** @type {{[AssistantStreamEvents.ThreadRunCreated]: (event: ThreadRunCreated) => Promise}} */ + const handlers = { + [AssistantStreamEvents.ThreadRunCreated]: async (event) => { + await cache.set(cacheKey, `${thread_id}:${event.data.id}`, ten_minutes); + run_id = event.data.id; + sendInitialResponse(); + }, + }; + + const streamRunManager = new StreamRunManager({ + req, + res, + openai, + handlers, + thread_id, + attachedFileIds, + responseMessage: openai.responseMessage, + // streamOptions: { + + // }, + }); + + await streamRunManager.runAssistant({ + thread_id, + body, + }); + + response = streamRunManager; + }; + + await processRun(); + logger.debug('[/assistants/chat/] response', { + run: response.run, + steps: response.steps, + }); + + if (response.run.status === RunStatus.CANCELLED) { + logger.debug('[/assistants/chat/] Run cancelled, handled by `abortRun`'); + return res.end(); + } + + if (response.run.status === 
RunStatus.IN_PROGRESS) { + processRun(true); + } + + completedRun = response.run; + + /** @type {ResponseMessage} */ + const responseMessage = { + ...(response.responseMessage ?? response.finalMessage), + parentMessageId: userMessageId, + conversationId, + user: req.user.id, + assistant_id, + thread_id, + model: assistant_id, + endpoint, + }; + + sendMessage(res, { + final: true, + conversation, + requestMessage: { + parentMessageId, + thread_id, + }, + }); + res.end(); + + await saveAssistantMessage({ ...responseMessage, model }); + + if (parentMessageId === Constants.NO_PARENT && !_thread_id) { + addTitle(req, { + text, + responseText: response.text, + conversationId, + client, + }); + } + + await addThreadMetadata({ + openai, + thread_id, + messageId: responseMessage.messageId, + messages: response.messages, + }); + + if (!response.run.usage) { + await sleep(3000); + completedRun = await openai.beta.threads.runs.retrieve(thread_id, response.run.id); + if (completedRun.usage) { + await recordUsage({ + ...completedRun.usage, + user: req.user.id, + model: completedRun.model ?? model, + conversationId, + }); + } + } else { + await recordUsage({ + ...response.run.usage, + user: req.user.id, + model: response.run.model ?? 
model, + conversationId, + }); + } + } catch (error) { + await handleError(error); + } +}; + +module.exports = chatV2; diff --git a/api/server/controllers/assistants/helpers.js b/api/server/controllers/assistants/helpers.js new file mode 100644 index 000000000..f8c9efde4 --- /dev/null +++ b/api/server/controllers/assistants/helpers.js @@ -0,0 +1,158 @@ +const { EModelEndpoint, CacheKeys, defaultAssistantsVersion } = require('librechat-data-provider'); +const { + initializeClient: initAzureClient, +} = require('~/server/services/Endpoints/azureAssistants'); +const { initializeClient } = require('~/server/services/Endpoints/assistants'); +const { getLogStores } = require('~/cache'); + +/** + * @param {Express.Request} req + * @param {string} [endpoint] + * @returns {Promise} + */ +const getCurrentVersion = async (req, endpoint) => { + const index = req.baseUrl.lastIndexOf('/v'); + let version = index !== -1 ? req.baseUrl.substring(index + 1, index + 3) : null; + if (!version && req.body.version) { + version = `v${req.body.version}`; + } + if (!version && endpoint) { + const cache = getLogStores(CacheKeys.CONFIG_STORE); + const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG); + version = `v${ + cachedEndpointsConfig?.[endpoint]?.version ?? defaultAssistantsVersion[endpoint] + }`; + } + if (!version?.startsWith('v') && version.length !== 2) { + throw new Error(`[${req.baseUrl}] Invalid version: ${version}`); + } + return version; +}; + +/** + * Asynchronously lists assistants based on provided query parameters. + * + * Initializes the client with the current request and response objects and lists assistants + * according to the query parameters. This function abstracts the logic for non-Azure paths. + * + * @async + * @param {object} params - The parameters object. + * @param {object} params.req - The request object, used for initializing the client. + * @param {object} params.res - The response object, used for initializing the client. 
+ * @param {string} params.version - The API version to use. + * @param {object} params.query - The query parameters to list assistants (e.g., limit, order). + * @returns {Promise} A promise that resolves to the response from the `openai.beta.assistants.list` method call. + */ +const listAssistants = async ({ req, res, version, query }) => { + const { openai } = await getOpenAIClient({ req, res, version }); + return openai.beta.assistants.list(query); +}; + +/** + * Asynchronously lists assistants for Azure configured groups. + * + * Iterates through Azure configured assistant groups, initializes the client with the current request and response objects, + * lists assistants based on the provided query parameters, and merges their data alongside the model information into a single array. + * + * @async + * @param {object} params - The parameters object. + * @param {object} params.req - The request object, used for initializing the client and manipulating the request body. + * @param {object} params.res - The response object, used for initializing the client. + * @param {string} params.version - The API version to use. + * @param {TAzureConfig} params.azureConfig - The Azure configuration object containing assistantGroups and groupMap. + * @param {object} params.query - The query parameters to list assistants (e.g., limit, order). + * @returns {Promise} A promise that resolves to an array of assistant data merged with their respective model information. 
+ */ +const listAssistantsForAzure = async ({ req, res, version, azureConfig = {}, query }) => { + /** @type {Array<[string, TAzureModelConfig]>} */ + const groupModelTuples = []; + const promises = []; + /** @type {Array} */ + const groups = []; + + const { groupMap, assistantGroups } = azureConfig; + + for (const groupName of assistantGroups) { + const group = groupMap[groupName]; + groups.push(group); + + const currentModelTuples = Object.entries(group?.models); + groupModelTuples.push(currentModelTuples); + + /* The specified model is only necessary to + fetch assistants for the shared instance */ + req.body.model = currentModelTuples[0][0]; + promises.push(listAssistants({ req, res, version, query })); + } + + const resolvedQueries = await Promise.all(promises); + const data = resolvedQueries.flatMap((res, i) => + res.data.map((assistant) => { + const deploymentName = assistant.model; + const currentGroup = groups[i]; + const currentModelTuples = groupModelTuples[i]; + const firstModel = currentModelTuples[0][0]; + + if (currentGroup.deploymentName === deploymentName) { + return { ...assistant, model: firstModel }; + } + + for (const [model, modelConfig] of currentModelTuples) { + if (modelConfig.deploymentName === deploymentName) { + return { ...assistant, model }; + } + } + + return { ...assistant, model: firstModel }; + }), + ); + + return { + first_id: data[0]?.id, + last_id: data[data.length - 1]?.id, + object: 'list', + has_more: false, + data, + }; +}; + +async function getOpenAIClient({ req, res, endpointOption, initAppClient, overrideEndpoint }) { + let endpoint = overrideEndpoint ?? req.body.endpoint ?? 
req.query.endpoint; + const version = await getCurrentVersion(req, endpoint); + if (!endpoint) { + throw new Error(`[${req.baseUrl}] Endpoint is required`); + } + + let result; + if (endpoint === EModelEndpoint.assistants) { + result = await initializeClient({ req, res, version, endpointOption, initAppClient }); + } else if (endpoint === EModelEndpoint.azureAssistants) { + result = await initAzureClient({ req, res, version, endpointOption, initAppClient }); + } + + return result; +} + +const fetchAssistants = async (req, res) => { + const { limit = 100, order = 'desc', after, before, endpoint } = req.query; + const version = await getCurrentVersion(req, endpoint); + const query = { limit, order, after, before }; + + /** @type {AssistantListResponse} */ + let body; + + if (endpoint === EModelEndpoint.assistants) { + ({ body } = await listAssistants({ req, res, version, query })); + } else if (endpoint === EModelEndpoint.azureAssistants) { + const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI]; + body = await listAssistantsForAzure({ req, res, version, azureConfig, query }); + } + + return body; +}; + +module.exports = { + getOpenAIClient, + fetchAssistants, + getCurrentVersion, +}; diff --git a/api/server/routes/assistants/assistants.js b/api/server/controllers/assistants/v1.js similarity index 75% rename from api/server/routes/assistants/assistants.js rename to api/server/controllers/assistants/v1.js index 67f200f6b..3bbd6b63d 100644 --- a/api/server/routes/assistants/assistants.js +++ b/api/server/controllers/assistants/v1.js @@ -1,34 +1,11 @@ -const multer = require('multer'); -const express = require('express'); -const { FileContext, EModelEndpoint } = require('librechat-data-provider'); -const { - initializeClient, - listAssistantsForAzure, - listAssistants, -} = require('~/server/services/Endpoints/assistants'); +const { FileContext } = require('librechat-data-provider'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); 
+const { deleteAssistantActions } = require('~/server/services/ActionService'); const { uploadImageBuffer } = require('~/server/services/Files/process'); const { updateAssistant, getAssistants } = require('~/models/Assistant'); +const { getOpenAIClient, fetchAssistants } = require('./helpers'); const { deleteFileByFilter } = require('~/models/File'); const { logger } = require('~/config'); -const actions = require('./actions'); -const tools = require('./tools'); - -const upload = multer(); -const router = express.Router(); - -/** - * Assistant actions route. - * @route GET|POST /assistants/actions - */ -router.use('/actions', actions); - -/** - * Create an assistant. - * @route GET /assistants/tools - * @returns {TPlugin[]} 200 - application/json - */ -router.use('/tools', tools); /** * Create an assistant. @@ -36,12 +13,11 @@ router.use('/tools', tools); * @param {AssistantCreateParams} req.body - The assistant creation parameters. * @returns {Assistant} 201 - success response - application/json */ -router.post('/', async (req, res) => { +const createAssistant = async (req, res) => { try { - /** @type {{ openai: OpenAI }} */ - const { openai } = await initializeClient({ req, res }); + const { openai } = await getOpenAIClient({ req, res }); - const { tools = [], ...assistantData } = req.body; + const { tools = [], endpoint, ...assistantData } = req.body; assistantData.tools = tools .map((tool) => { if (typeof tool !== 'string') { @@ -52,18 +28,28 @@ router.post('/', async (req, res) => { }) .filter((tool) => tool); + let azureModelIdentifier = null; if (openai.locals?.azureOptions) { + azureModelIdentifier = assistantData.model; assistantData.model = openai.locals.azureOptions.azureOpenAIApiDeploymentName; } + assistantData.metadata = { + author: req.user.id, + endpoint, + }; + const assistant = await openai.beta.assistants.create(assistantData); + if (azureModelIdentifier) { + assistant.model = azureModelIdentifier; + } logger.debug('/assistants/', assistant); 
res.status(201).json(assistant); } catch (error) { logger.error('[/assistants] Error creating assistant', error); res.status(500).json({ error: error.message }); } -}); +}; /** * Retrieves an assistant. @@ -71,10 +57,10 @@ router.post('/', async (req, res) => { * @param {string} req.params.id - Assistant identifier. * @returns {Assistant} 200 - success response - application/json */ -router.get('/:id', async (req, res) => { +const retrieveAssistant = async (req, res) => { try { - /** @type {{ openai: OpenAI }} */ - const { openai } = await initializeClient({ req, res }); + /* NOTE: not actually being used right now */ + const { openai } = await getOpenAIClient({ req, res }); const assistant_id = req.params.id; const assistant = await openai.beta.assistants.retrieve(assistant_id); @@ -83,22 +69,23 @@ router.get('/:id', async (req, res) => { logger.error('[/assistants/:id] Error retrieving assistant', error); res.status(500).json({ error: error.message }); } -}); +}; /** * Modifies an assistant. * @route PATCH /assistants/:id + * @param {object} req - Express Request + * @param {object} req.params - Request params * @param {string} req.params.id - Assistant identifier. * @param {AssistantUpdateParams} req.body - The assistant update parameters. * @returns {Assistant} 200 - success response - application/json */ -router.patch('/:id', async (req, res) => { +const patchAssistant = async (req, res) => { try { - /** @type {{ openai: OpenAI }} */ - const { openai } = await initializeClient({ req, res }); + const { openai } = await getOpenAIClient({ req, res }); const assistant_id = req.params.id; - const updateData = req.body; + const { endpoint: _e, ...updateData } = req.body; updateData.tools = (updateData.tools ?? 
[]) .map((tool) => { if (typeof tool !== 'string') { @@ -119,52 +106,46 @@ router.patch('/:id', async (req, res) => { logger.error('[/assistants/:id] Error updating assistant', error); res.status(500).json({ error: error.message }); } -}); +}; /** * Deletes an assistant. * @route DELETE /assistants/:id + * @param {object} req - Express Request + * @param {object} req.params - Request params * @param {string} req.params.id - Assistant identifier. * @returns {Assistant} 200 - success response - application/json */ -router.delete('/:id', async (req, res) => { +const deleteAssistant = async (req, res) => { try { - /** @type {{ openai: OpenAI }} */ - const { openai } = await initializeClient({ req, res }); + const { openai } = await getOpenAIClient({ req, res }); const assistant_id = req.params.id; const deletionStatus = await openai.beta.assistants.del(assistant_id); + if (deletionStatus?.deleted) { + await deleteAssistantActions({ req, assistant_id }); + } res.json(deletionStatus); } catch (error) { logger.error('[/assistants/:id] Error deleting assistant', error); res.status(500).json({ error: 'Error deleting assistant' }); } -}); +}; /** * Returns a list of assistants. * @route GET /assistants + * @param {object} req - Express Request * @param {AssistantListParams} req.query - The assistant list parameters for pagination and sorting. 
* @returns {AssistantListResponse} 200 - success response - application/json */ -router.get('/', async (req, res) => { +const listAssistants = async (req, res) => { try { - const { limit = 100, order = 'desc', after, before } = req.query; - const query = { limit, order, after, before }; + const body = await fetchAssistants(req, res); - const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI]; - /** @type {AssistantListResponse} */ - let body; - - if (azureConfig?.assistants) { - body = await listAssistantsForAzure({ req, res, azureConfig, query }); - } else { - ({ body } = await listAssistants({ req, res, query })); - } - - if (req.app.locals?.[EModelEndpoint.assistants]) { + if (req.app.locals?.[req.query.endpoint]) { /** @type {Partial} */ - const assistantsConfig = req.app.locals[EModelEndpoint.assistants]; + const assistantsConfig = req.app.locals[req.query.endpoint]; const { supportedIds, excludedIds } = assistantsConfig; if (supportedIds?.length) { body.data = body.data.filter((assistant) => supportedIds.includes(assistant.id)); @@ -178,31 +159,34 @@ router.get('/', async (req, res) => { logger.error('[/assistants] Error listing assistants', error); res.status(500).json({ message: 'Error listing assistants' }); } -}); +}; /** * Returns a list of the user's assistant documents (metadata saved to database). * @route GET /assistants/documents * @returns {AssistantDocument[]} 200 - success response - application/json */ -router.get('/documents', async (req, res) => { +const getAssistantDocuments = async (req, res) => { try { res.json(await getAssistants({ user: req.user.id })); } catch (error) { logger.error('[/assistants/documents] Error listing assistant documents', error); res.status(500).json({ error: error.message }); } -}); +}; /** * Uploads and updates an avatar for a specific assistant. 
* @route POST /avatar/:assistant_id + * @param {object} req - Express Request + * @param {object} req.params - Request params * @param {string} req.params.assistant_id - The ID of the assistant. * @param {Express.Multer.File} req.file - The avatar image file. + * @param {object} req.body - Request body * @param {string} [req.body.metadata] - Optional metadata for the assistant's avatar. * @returns {Object} 200 - success response - application/json */ -router.post('/avatar/:assistant_id', upload.single('file'), async (req, res) => { +const uploadAssistantAvatar = async (req, res) => { try { const { assistant_id } = req.params; if (!assistant_id) { @@ -210,8 +194,7 @@ router.post('/avatar/:assistant_id', upload.single('file'), async (req, res) => } let { metadata: _metadata = '{}' } = req.body; - /** @type {{ openai: OpenAI }} */ - const { openai } = await initializeClient({ req, res }); + const { openai } = await getOpenAIClient({ req, res }); const image = await uploadImageBuffer({ req, @@ -266,6 +249,14 @@ router.post('/avatar/:assistant_id', upload.single('file'), async (req, res) => logger.error(message, error); res.status(500).json({ message }); } -}); +}; -module.exports = router; +module.exports = { + createAssistant, + retrieveAssistant, + patchAssistant, + deleteAssistant, + listAssistants, + getAssistantDocuments, + uploadAssistantAvatar, +}; diff --git a/api/server/controllers/assistants/v2.js b/api/server/controllers/assistants/v2.js new file mode 100644 index 000000000..81f55607a --- /dev/null +++ b/api/server/controllers/assistants/v2.js @@ -0,0 +1,208 @@ +const { ToolCallTypes } = require('librechat-data-provider'); +const { validateAndUpdateTool } = require('~/server/services/ActionService'); +const { getOpenAIClient } = require('./helpers'); +const { logger } = require('~/config'); + +/** + * Create an assistant. + * @route POST /assistants + * @param {AssistantCreateParams} req.body - The assistant creation parameters. 
+ * @returns {Assistant} 201 - success response - application/json + */ +const createAssistant = async (req, res) => { + try { + /** @type {{ openai: OpenAIClient }} */ + const { openai } = await getOpenAIClient({ req, res }); + + const { tools = [], endpoint, ...assistantData } = req.body; + assistantData.tools = tools + .map((tool) => { + if (typeof tool !== 'string') { + return tool; + } + + return req.app.locals.availableTools[tool]; + }) + .filter((tool) => tool); + + let azureModelIdentifier = null; + if (openai.locals?.azureOptions) { + azureModelIdentifier = assistantData.model; + assistantData.model = openai.locals.azureOptions.azureOpenAIApiDeploymentName; + } + + assistantData.metadata = { + author: req.user.id, + endpoint, + }; + + const assistant = await openai.beta.assistants.create(assistantData); + if (azureModelIdentifier) { + assistant.model = azureModelIdentifier; + } + logger.debug('/assistants/', assistant); + res.status(201).json(assistant); + } catch (error) { + logger.error('[/assistants] Error creating assistant', error); + res.status(500).json({ error: error.message }); + } +}; + +/** + * Modifies an assistant. + * @param {object} params + * @param {Express.Request} params.req + * @param {OpenAIClient} params.openai + * @param {string} params.assistant_id + * @param {AssistantUpdateParams} params.updateData + * @returns {Promise} The updated assistant. + */ +const updateAssistant = async ({ req, openai, assistant_id, updateData }) => { + const tools = []; + + let hasFileSearch = false; + for (const tool of updateData.tools ?? []) { + let actualTool = typeof tool === 'string' ? 
req.app.locals.availableTools[tool] : tool; + + if (!actualTool) { + continue; + } + + if (actualTool.type === ToolCallTypes.FILE_SEARCH) { + hasFileSearch = true; + } + + if (!actualTool.function) { + tools.push(actualTool); + continue; + } + + const updatedTool = await validateAndUpdateTool({ req, tool: actualTool, assistant_id }); + if (updatedTool) { + tools.push(updatedTool); + } + } + + if (hasFileSearch && !updateData.tool_resources) { + const assistant = await openai.beta.assistants.retrieve(assistant_id); + updateData.tool_resources = assistant.tool_resources ?? null; + } + + if (hasFileSearch && !updateData.tool_resources?.file_search) { + updateData.tool_resources = { + ...(updateData.tool_resources ?? {}), + file_search: { + vector_store_ids: [], + }, + }; + } + + updateData.tools = tools; + + if (openai.locals?.azureOptions && updateData.model) { + updateData.model = openai.locals.azureOptions.azureOpenAIApiDeploymentName; + } + + return await openai.beta.assistants.update(assistant_id, updateData); +}; + +/** + * Modifies an assistant with the resource file id. + * @param {object} params + * @param {Express.Request} params.req + * @param {OpenAIClient} params.openai + * @param {string} params.assistant_id + * @param {string} params.tool_resource + * @param {string} params.file_id + * @param {AssistantUpdateParams} params.updateData + * @returns {Promise} The updated assistant. 
+ */ +const addResourceFileId = async ({ req, openai, assistant_id, tool_resource, file_id }) => { + const assistant = await openai.beta.assistants.retrieve(assistant_id); + const { tool_resources = {} } = assistant; + if (tool_resources[tool_resource]) { + tool_resources[tool_resource].file_ids.push(file_id); + } else { + tool_resources[tool_resource] = { file_ids: [file_id] }; + } + + delete assistant.id; + return await updateAssistant({ + req, + openai, + assistant_id, + updateData: { tools: assistant.tools, tool_resources }, + }); +}; + +/** + * Deletes a file ID from an assistant's resource. + * @param {object} params + * @param {Express.Request} params.req + * @param {OpenAIClient} params.openai + * @param {string} params.assistant_id + * @param {string} [params.tool_resource] + * @param {string} params.file_id + * @param {AssistantUpdateParams} params.updateData + * @returns {Promise} The updated assistant. + */ +const deleteResourceFileId = async ({ req, openai, assistant_id, tool_resource, file_id }) => { + const assistant = await openai.beta.assistants.retrieve(assistant_id); + const { tool_resources = {} } = assistant; + + if (tool_resource && tool_resources[tool_resource]) { + const resource = tool_resources[tool_resource]; + const index = resource.file_ids.indexOf(file_id); + if (index !== -1) { + resource.file_ids.splice(index, 1); + } + } else { + for (const resourceKey in tool_resources) { + const resource = tool_resources[resourceKey]; + const index = resource.file_ids.indexOf(file_id); + if (index !== -1) { + resource.file_ids.splice(index, 1); + break; + } + } + } + + delete assistant.id; + return await updateAssistant({ + req, + openai, + assistant_id, + updateData: { tools: assistant.tools, tool_resources }, + }); +}; + +/** + * Modifies an assistant. + * @route PATCH /assistants/:id + * @param {object} req - Express Request + * @param {object} req.params - Request params + * @param {string} req.params.id - Assistant identifier. 
+ * @param {AssistantUpdateParams} req.body - The assistant update parameters. + * @returns {Assistant} 200 - success response - application/json + */ +const patchAssistant = async (req, res) => { + try { + const { openai } = await getOpenAIClient({ req, res }); + const assistant_id = req.params.id; + const { endpoint: _e, ...updateData } = req.body; + updateData.tools = updateData.tools ?? []; + const updatedAssistant = await updateAssistant({ req, openai, assistant_id, updateData }); + res.json(updatedAssistant); + } catch (error) { + logger.error('[/assistants/:id] Error updating assistant', error); + res.status(500).json({ error: error.message }); + } +}; + +module.exports = { + patchAssistant, + createAssistant, + updateAssistant, + addResourceFileId, + deleteResourceFileId, +}; diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index a868b107b..69df9619c 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -1,4 +1,4 @@ -const { EModelEndpoint } = require('librechat-data-provider'); +const { isAssistantsEndpoint } = require('librechat-data-provider'); const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils'); const { truncateText, smartTruncateText } = require('~/app/clients/prompts'); const { saveMessage, getConvo, getConvoTitle } = require('~/models'); @@ -15,7 +15,7 @@ async function abortMessage(req, res) { abortKey = conversationId; } - if (endpoint === EModelEndpoint.assistants) { + if (isAssistantsEndpoint(endpoint)) { return await abortRun(req, res); } diff --git a/api/server/middleware/abortRun.js b/api/server/middleware/abortRun.js index 6db6329d4..6522d6746 100644 --- a/api/server/middleware/abortRun.js +++ b/api/server/middleware/abortRun.js @@ -10,7 +10,7 @@ const three_minutes = 1000 * 60 * 3; async function abortRun(req, res) { res.setHeader('Content-Type', 'application/json'); - const { abortKey } = req.body; + const { 
abortKey, endpoint } = req.body; const [conversationId, latestMessageId] = abortKey.split(':'); const conversation = await getConvo(req.user.id, conversationId); @@ -68,9 +68,10 @@ async function abortRun(req, res) { runMessages = await checkMessageGaps({ openai, - latestMessageId, + endpoint, thread_id, run_id, + latestMessageId, conversationId, }); diff --git a/api/server/middleware/buildEndpointOption.js b/api/server/middleware/buildEndpointOption.js index 3de13ed2e..ddaaa35a3 100644 --- a/api/server/middleware/buildEndpointOption.js +++ b/api/server/middleware/buildEndpointOption.js @@ -1,5 +1,6 @@ const { parseConvo, EModelEndpoint } = require('librechat-data-provider'); const { getModelsConfig } = require('~/server/controllers/ModelController'); +const azureAssistants = require('~/server/services/Endpoints/azureAssistants'); const assistants = require('~/server/services/Endpoints/assistants'); const gptPlugins = require('~/server/services/Endpoints/gptPlugins'); const { processFiles } = require('~/server/services/Files/process'); @@ -18,6 +19,7 @@ const buildFunction = { [EModelEndpoint.anthropic]: anthropic.buildOptions, [EModelEndpoint.gptPlugins]: gptPlugins.buildOptions, [EModelEndpoint.assistants]: assistants.buildOptions, + [EModelEndpoint.azureAssistants]: azureAssistants.buildOptions, }; async function buildEndpointOption(req, res, next) { diff --git a/api/server/routes/assistants/actions.js b/api/server/routes/assistants/actions.js index 9cf47c869..515035761 100644 --- a/api/server/routes/assistants/actions.js +++ b/api/server/routes/assistants/actions.js @@ -2,7 +2,7 @@ const { v4 } = require('uuid'); const express = require('express'); const { encryptMetadata, domainParser } = require('~/server/services/ActionService'); const { actionDelimiter, EModelEndpoint } = require('librechat-data-provider'); -const { initializeClient } = require('~/server/services/Endpoints/assistants'); +const { getOpenAIClient } = 
require('~/server/controllers/assistants/helpers'); const { updateAction, getActions, deleteAction } = require('~/models/Action'); const { updateAssistant, getAssistant } = require('~/models/Assistant'); const { logger } = require('~/config'); @@ -45,7 +45,6 @@ router.post('/:assistant_id', async (req, res) => { let metadata = encryptMetadata(_metadata); let { domain } = metadata; - /* Azure doesn't support periods in function names */ domain = await domainParser(req, domain, true); if (!domain) { @@ -55,8 +54,7 @@ router.post('/:assistant_id', async (req, res) => { const action_id = _action_id ?? v4(); const initialPromises = []; - /** @type {{ openai: OpenAI }} */ - const { openai } = await initializeClient({ req, res }); + const { openai } = await getOpenAIClient({ req, res }); initialPromises.push(getAssistant({ assistant_id })); initialPromises.push(openai.beta.assistants.retrieve(assistant_id)); @@ -157,9 +155,7 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => { try { const { assistant_id, action_id, model } = req.params; req.body.model = model; - - /** @type {{ openai: OpenAI }} */ - const { openai } = await initializeClient({ req, res }); + const { openai } = await getOpenAIClient({ req, res }); const initialPromises = []; initialPromises.push(getAssistant({ assistant_id })); diff --git a/api/server/routes/assistants/chatV1.js b/api/server/routes/assistants/chatV1.js new file mode 100644 index 000000000..99de23c20 --- /dev/null +++ b/api/server/routes/assistants/chatV1.js @@ -0,0 +1,25 @@ +const express = require('express'); + +const router = express.Router(); +const { + setHeaders, + handleAbort, + validateModel, + // validateEndpoint, + buildEndpointOption, +} = require('~/server/middleware'); +const chatController = require('~/server/controllers/assistants/chatV1'); + +router.post('/abort', handleAbort()); + +/** + * @route POST / + * @desc Chat with an assistant + * @access Public + * @param {express.Request} req - The request 
object, containing the request data. + * @param {express.Response} res - The response object, used to send back a response. + * @returns {void} + */ +router.post('/', validateModel, buildEndpointOption, setHeaders, chatController); + +module.exports = router; diff --git a/api/server/routes/assistants/chatV2.js b/api/server/routes/assistants/chatV2.js new file mode 100644 index 000000000..e0ef2e0b2 --- /dev/null +++ b/api/server/routes/assistants/chatV2.js @@ -0,0 +1,25 @@ +const express = require('express'); + +const router = express.Router(); +const { + setHeaders, + handleAbort, + validateModel, + // validateEndpoint, + buildEndpointOption, +} = require('~/server/middleware'); +const chatController = require('~/server/controllers/assistants/chatV2'); + +router.post('/abort', handleAbort()); + +/** + * @route POST / + * @desc Chat with an assistant + * @access Public + * @param {express.Request} req - The request object, containing the request data. + * @param {express.Response} res - The response object, used to send back a response. 
+ * @returns {void} + */ +router.post('/', validateModel, buildEndpointOption, setHeaders, chatController); + +module.exports = router; diff --git a/api/server/routes/assistants/index.js b/api/server/routes/assistants/index.js index a47a768f9..6613177e7 100644 --- a/api/server/routes/assistants/index.js +++ b/api/server/routes/assistants/index.js @@ -7,16 +7,19 @@ const { // concurrentLimiter, // messageIpLimiter, // messageUserLimiter, -} = require('../../middleware'); +} = require('~/server/middleware'); -const assistants = require('./assistants'); -const chat = require('./chat'); +const v1 = require('./v1'); +const chatV1 = require('./chatV1'); +const v2 = require('./v2'); +const chatV2 = require('./chatV2'); router.use(requireJwtAuth); router.use(checkBan); router.use(uaParser); - -router.use('/', assistants); -router.use('/chat', chat); +router.use('/v1/', v1); +router.use('/v1/chat', chatV1); +router.use('/v2/', v2); +router.use('/v2/chat', chatV2); module.exports = router; diff --git a/api/server/routes/assistants/v1.js b/api/server/routes/assistants/v1.js new file mode 100644 index 000000000..184450887 --- /dev/null +++ b/api/server/routes/assistants/v1.js @@ -0,0 +1,81 @@ +const multer = require('multer'); +const express = require('express'); +const controllers = require('~/server/controllers/assistants/v1'); +const actions = require('./actions'); +const tools = require('./tools'); + +const upload = multer(); +const router = express.Router(); + +/** + * Assistant actions route. + * @route GET|POST /assistants/actions + */ +router.use('/actions', actions); + +/** + * Create an assistant. + * @route GET /assistants/tools + * @returns {TPlugin[]} 200 - application/json + */ +router.use('/tools', tools); + +/** + * Create an assistant. + * @route POST /assistants + * @param {AssistantCreateParams} req.body - The assistant creation parameters. 
+ * @returns {Assistant} 201 - success response - application/json + */ +router.post('/', controllers.createAssistant); + +/** + * Retrieves an assistant. + * @route GET /assistants/:id + * @param {string} req.params.id - Assistant identifier. + * @returns {Assistant} 200 - success response - application/json + */ +router.get('/:id', controllers.retrieveAssistant); + +/** + * Modifies an assistant. + * @route PATCH /assistants/:id + * @param {string} req.params.id - Assistant identifier. + * @param {AssistantUpdateParams} req.body - The assistant update parameters. + * @returns {Assistant} 200 - success response - application/json + */ +router.patch('/:id', controllers.patchAssistant); + +/** + * Deletes an assistant. + * @route DELETE /assistants/:id + * @param {string} req.params.id - Assistant identifier. + * @returns {Assistant} 200 - success response - application/json + */ +router.delete('/:id', controllers.deleteAssistant); + +/** + * Returns a list of assistants. + * @route GET /assistants + * @param {AssistantListParams} req.query - The assistant list parameters for pagination and sorting. + * @returns {AssistantListResponse} 200 - success response - application/json + */ +router.get('/', controllers.listAssistants); + +/** + * Returns a list of the user's assistant documents (metadata saved to database). + * @route GET /assistants/documents + * @returns {AssistantDocument[]} 200 - success response - application/json + */ +router.get('/documents', controllers.getAssistantDocuments); + +/** + * Uploads and updates an avatar for a specific assistant. + * @route POST /avatar/:assistant_id + * @param {string} req.params.assistant_id - The ID of the assistant. + * @param {Express.Multer.File} req.file - The avatar image file. + * @param {string} [req.body.metadata] - Optional metadata for the assistant's avatar. 
+ * @returns {Object} 200 - success response - application/json + */ +router.post('/avatar/:assistant_id', upload.single('file'), controllers.uploadAssistantAvatar); + +module.exports = router; diff --git a/api/server/routes/assistants/v2.js b/api/server/routes/assistants/v2.js new file mode 100644 index 000000000..3c70c623a --- /dev/null +++ b/api/server/routes/assistants/v2.js @@ -0,0 +1,82 @@ +const multer = require('multer'); +const express = require('express'); +const v1 = require('~/server/controllers/assistants/v1'); +const v2 = require('~/server/controllers/assistants/v2'); +const actions = require('./actions'); +const tools = require('./tools'); + +const upload = multer(); +const router = express.Router(); + +/** + * Assistant actions route. + * @route GET|POST /assistants/actions + */ +router.use('/actions', actions); + +/** + * Create an assistant. + * @route GET /assistants/tools + * @returns {TPlugin[]} 200 - application/json + */ +router.use('/tools', tools); + +/** + * Create an assistant. + * @route POST /assistants + * @param {AssistantCreateParams} req.body - The assistant creation parameters. + * @returns {Assistant} 201 - success response - application/json + */ +router.post('/', v2.createAssistant); + +/** + * Retrieves an assistant. + * @route GET /assistants/:id + * @param {string} req.params.id - Assistant identifier. + * @returns {Assistant} 200 - success response - application/json + */ +router.get('/:id', v1.retrieveAssistant); + +/** + * Modifies an assistant. + * @route PATCH /assistants/:id + * @param {string} req.params.id - Assistant identifier. + * @param {AssistantUpdateParams} req.body - The assistant update parameters. + * @returns {Assistant} 200 - success response - application/json + */ +router.patch('/:id', v2.patchAssistant); + +/** + * Deletes an assistant. + * @route DELETE /assistants/:id + * @param {string} req.params.id - Assistant identifier. 
+ * @returns {Assistant} 200 - success response - application/json + */ +router.delete('/:id', v1.deleteAssistant); + +/** + * Returns a list of assistants. + * @route GET /assistants + * @param {AssistantListParams} req.query - The assistant list parameters for pagination and sorting. + * @returns {AssistantListResponse} 200 - success response - application/json + */ +router.get('/', v1.listAssistants); + +/** + * Returns a list of the user's assistant documents (metadata saved to database). + * @route GET /assistants/documents + * @returns {AssistantDocument[]} 200 - success response - application/json + */ +router.get('/documents', v1.getAssistantDocuments); + +/** + * Uploads and updates an avatar for a specific assistant. + * @route POST /avatar/:assistant_id + * @param {string} req.params.assistant_id - The ID of the assistant. + * @param {Express.Multer.File} req.file - The avatar image file. + * @param {string} [req.body.metadata] - Optional metadata for the assistant's avatar. + * @returns {Object} 200 - success response - application/json + */ +router.post('/avatar/:assistant_id', upload.single('file'), v1.uploadAssistantAvatar); + +module.exports = router; diff --git a/api/server/routes/files/files.js b/api/server/routes/files/files.js index 812d4bd33..565893af3 100644 --- a/api/server/routes/files/files.js +++ b/api/server/routes/files/files.js @@ -1,6 +1,6 @@ const fs = require('fs').promises; const express = require('express'); -const { isUUID, FileSources } = require('librechat-data-provider'); +const { isUUID, checkOpenAIStorage } = require('librechat-data-provider'); const { filterFile, processFileUpload, @@ -89,7 +89,7 @@ router.get('/download/:userId/:file_id', async (req, res) => { return res.status(403).send('Forbidden'); } - if (file.source === FileSources.openai && !file.model) { + if (checkOpenAIStorage(file.source) && !file.model) { logger.warn(`${errorPrefix} has no associated model: ${file_id}`); return res.status(400).send('The model 
used when creating this file is not available'); } @@ -110,7 +110,8 @@ router.get('/download/:userId/:file_id', async (req, res) => { let passThrough; /** @type {ReadableStream | undefined} */ let fileStream; - if (file.source === FileSources.openai) { + + if (checkOpenAIStorage(file.source)) { req.body = { model: file.model }; const { openai } = await initializeClient({ req, res }); logger.debug(`Downloading file ${file_id} from OpenAI`); diff --git a/api/server/routes/search.js b/api/server/routes/search.js index 2197b38ce..68cff7532 100644 --- a/api/server/routes/search.js +++ b/api/server/routes/search.js @@ -41,29 +41,10 @@ router.get('/', async function (req, res) { return; } - const messages = ( - await Message.meiliSearch( - q, - { - attributesToHighlight: ['text'], - highlightPreTag: '**', - highlightPostTag: '**', - }, - true, - ) - ).hits.map((message) => { - const { _formatted, ...rest } = message; - return { - ...rest, - searchResult: true, - text: _formatted.text, - }; - }); + const messages = (await Message.meiliSearch(q, undefined, true)).hits; const titles = (await Conversation.meiliSearch(q)).hits; + const sortedHits = reduceHits(messages, titles); - // debugging: - // logger.debug('user:', user, 'message hits:', messages.length, 'convo hits:', titles.length); - // logger.debug('sorted hits:', sortedHits.length); const result = await getConvosQueried(user, sortedHits, pageNumber); const activeMessages = []; @@ -86,8 +67,7 @@ router.get('/', async function (req, res) { delete result.cache; } delete result.convoMap; - // for debugging - // logger.debug(result, messages.length); + res.status(200).send(result); } catch (error) { logger.error('[/search] Error while searching messages & conversations', error); diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js index 344a6570b..6f832bce1 100644 --- a/api/server/services/ActionService.js +++ b/api/server/services/ActionService.js @@ -1,20 +1,59 @@ const { - 
AuthTypeEnum, - EModelEndpoint, - actionDomainSeparator, CacheKeys, Constants, + AuthTypeEnum, + actionDelimiter, + isImageVisionTool, + actionDomainSeparator, } = require('librechat-data-provider'); const { encryptV2, decryptV2 } = require('~/server/utils/crypto'); -const { getActions } = require('~/models/Action'); +const { getActions, deleteActions } = require('~/models/Action'); +const { deleteAssistant } = require('~/models/Assistant'); const { getLogStores } = require('~/cache'); const { logger } = require('~/config'); +const toolNameRegex = /^[a-zA-Z0-9_-]+$/; + +/** + * Validates tool name against regex pattern and updates if necessary. + * @param {object} params - The parameters for the function. + * @param {object} params.req - Express Request. + * @param {FunctionTool} params.tool - The tool object. + * @param {string} params.assistant_id - The assistant ID + * @returns {object|null} - Updated tool object or null if invalid and not an action. + */ +const validateAndUpdateTool = async ({ req, tool, assistant_id }) => { + let actions; + if (isImageVisionTool(tool)) { + return null; + } + if (!toolNameRegex.test(tool.function.name)) { + const [functionName, domain] = tool.function.name.split(actionDelimiter); + actions = await getActions({ assistant_id, user: req.user.id }, true); + const matchingActions = actions.filter((action) => { + const metadata = action.metadata; + return metadata && metadata.domain === domain; + }); + const action = matchingActions[0]; + if (!action) { + return null; + } + + const parsedDomain = await domainParser(req, domain, true); + + if (!parsedDomain) { + return null; + } + + tool.function.name = `${functionName}${actionDelimiter}${parsedDomain}`; + } + return tool; +}; + /** * Encodes or decodes a domain name to/from base64, or replacing periods with a custom separator. 
 *
- * Necessary because Azure OpenAI Assistants API doesn't support periods in function
- * names due to `[a-zA-Z0-9_-]*` Regex Validation, limited to a 64-character maximum.
+ * Necessary due to `[a-zA-Z0-9_-]*` Regex Validation, limited to a 64-character maximum.
 *
 * @param {Express.Request} req - The Express Request object.
 * @param {string} domain - The domain name to encode/decode.
@@ -26,10 +65,6 @@ async function domainParser(req, domain, inverse = false) {
     return;
   }
 
-  if (!req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) {
-    return domain;
-  }
-
   const domainsCache = getLogStores(CacheKeys.ENCODED_DOMAINS);
   const cachedDomain = await domainsCache.get(domain);
   if (inverse && cachedDomain) {
@@ -170,10 +205,29 @@ function decryptMetadata(metadata) {
   return decryptedMetadata;
 }
 
+/**
+ * Deletes an action and its corresponding assistant.
+ * @param {Object} params - The parameters for the function.
+ * @param {Express.Request} params.req - The Express Request object.
+ * @param {string} params.assistant_id - The ID of the assistant. 
+ */ +const deleteAssistantActions = async ({ req, assistant_id }) => { + try { + await deleteActions({ assistant_id, user: req.user.id }); + await deleteAssistant({ assistant_id, user: req.user.id }); + } catch (error) { + const message = 'Trouble deleting Assistant Actions for Assistant ID: ' + assistant_id; + logger.error(message, error); + throw new Error(message); + } +}; + module.exports = { - loadActionSets, + deleteAssistantActions, + validateAndUpdateTool, createActionTool, encryptMetadata, decryptMetadata, + loadActionSets, domainParser, }; diff --git a/api/server/services/ActionService.spec.js b/api/server/services/ActionService.spec.js index 57f998896..a9650d603 100644 --- a/api/server/services/ActionService.spec.js +++ b/api/server/services/ActionService.spec.js @@ -73,12 +73,12 @@ describe('domainParser', () => { const TLD = '.com'; // Non-azure request - it('returns domain as is if not azure', async () => { + it('does not return domain as is if not azure', async () => { const domain = `example.com${actionDomainSeparator}test${actionDomainSeparator}`; const result1 = await domainParser(reqNoAzure, domain, false); const result2 = await domainParser(reqNoAzure, domain, true); - expect(result1).toEqual(domain); - expect(result2).toEqual(domain); + expect(result1).not.toEqual(domain); + expect(result2).not.toEqual(domain); }); // Test for Empty or Null Inputs diff --git a/api/server/services/AppService.js b/api/server/services/AppService.js index 4163a3df8..b4d35f136 100644 --- a/api/server/services/AppService.js +++ b/api/server/services/AppService.js @@ -72,7 +72,14 @@ const AppService = async (app) => { } if (config?.endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) { - endpointLocals[EModelEndpoint.assistants] = azureAssistantsDefaults(); + endpointLocals[EModelEndpoint.azureAssistants] = azureAssistantsDefaults(); + } + + if (config?.endpoints?.[EModelEndpoint.azureAssistants]) { + endpointLocals[EModelEndpoint.azureAssistants] = 
assistantsConfigSetup( + config, + endpointLocals[EModelEndpoint.azureAssistants], + ); } if (config?.endpoints?.[EModelEndpoint.assistants]) { diff --git a/api/server/services/AppService.spec.js b/api/server/services/AppService.spec.js index e55bff994..602ef43f8 100644 --- a/api/server/services/AppService.spec.js +++ b/api/server/services/AppService.spec.js @@ -253,8 +253,8 @@ describe('AppService', () => { process.env.EASTUS_API_KEY = 'eastus-key'; await AppService(app); - expect(app.locals).toHaveProperty(EModelEndpoint.assistants); - expect(app.locals[EModelEndpoint.assistants].capabilities.length).toEqual(3); + expect(app.locals).toHaveProperty(EModelEndpoint.azureAssistants); + expect(app.locals[EModelEndpoint.azureAssistants].capabilities.length).toEqual(3); }); it('should correctly configure Azure OpenAI endpoint based on custom config', async () => { diff --git a/api/server/services/AssistantService.js b/api/server/services/AssistantService.js index 41e88dc8b..2db0a56b6 100644 --- a/api/server/services/AssistantService.js +++ b/api/server/services/AssistantService.js @@ -78,7 +78,7 @@ async function createOnTextProgress({ * @return {Promise} */ async function getResponse({ openai, run_id, thread_id }) { - const run = await waitForRun({ openai, run_id, thread_id, pollIntervalMs: 500 }); + const run = await waitForRun({ openai, run_id, thread_id, pollIntervalMs: 2000 }); if (run.status === RunStatus.COMPLETED) { const messages = await openai.beta.threads.messages.list(thread_id, defaultOrderQuery); @@ -393,8 +393,9 @@ async function runAssistant({ }, }); + const { endpoint = EModelEndpoint.azureAssistants } = openai.req.body; /** @type {TCustomConfig.endpoints.assistants} */ - const assistantsEndpointConfig = openai.req.app.locals?.[EModelEndpoint.assistants] ?? {}; + const assistantsEndpointConfig = openai.req.app.locals?.[endpoint] ?? 
{}; const { pollIntervalMs, timeoutMs } = assistantsEndpointConfig; const run = await waitForRun({ diff --git a/api/server/services/Config/EndpointService.js b/api/server/services/Config/EndpointService.js index 987fbb885..438cb81e8 100644 --- a/api/server/services/Config/EndpointService.js +++ b/api/server/services/Config/EndpointService.js @@ -3,6 +3,7 @@ const { isUserProvided, generateConfig } = require('~/server/utils'); const { OPENAI_API_KEY: openAIApiKey, + AZURE_ASSISTANTS_API_KEY: azureAssistantsApiKey, ASSISTANTS_API_KEY: assistantsApiKey, AZURE_API_KEY: azureOpenAIApiKey, ANTHROPIC_API_KEY: anthropicApiKey, @@ -13,6 +14,7 @@ const { OPENAI_REVERSE_PROXY, AZURE_OPENAI_BASEURL, ASSISTANTS_BASE_URL, + AZURE_ASSISTANTS_BASE_URL, } = process.env ?? {}; const useAzurePlugins = !!PLUGINS_USE_AZURE; @@ -28,11 +30,20 @@ module.exports = { useAzurePlugins, userProvidedOpenAI, googleKey, - [EModelEndpoint.openAI]: generateConfig(openAIApiKey, OPENAI_REVERSE_PROXY), - [EModelEndpoint.assistants]: generateConfig(assistantsApiKey, ASSISTANTS_BASE_URL, true), - [EModelEndpoint.azureOpenAI]: generateConfig(azureOpenAIApiKey, AZURE_OPENAI_BASEURL), - [EModelEndpoint.chatGPTBrowser]: generateConfig(chatGPTToken), - [EModelEndpoint.anthropic]: generateConfig(anthropicApiKey), [EModelEndpoint.bingAI]: generateConfig(bingToken), + [EModelEndpoint.anthropic]: generateConfig(anthropicApiKey), + [EModelEndpoint.chatGPTBrowser]: generateConfig(chatGPTToken), + [EModelEndpoint.openAI]: generateConfig(openAIApiKey, OPENAI_REVERSE_PROXY), + [EModelEndpoint.azureOpenAI]: generateConfig(azureOpenAIApiKey, AZURE_OPENAI_BASEURL), + [EModelEndpoint.assistants]: generateConfig( + assistantsApiKey, + ASSISTANTS_BASE_URL, + EModelEndpoint.assistants, + ), + [EModelEndpoint.azureAssistants]: generateConfig( + azureAssistantsApiKey, + AZURE_ASSISTANTS_BASE_URL, + EModelEndpoint.azureAssistants, + ), }, }; diff --git a/api/server/services/Config/loadConfigEndpoints.js 
b/api/server/services/Config/loadConfigEndpoints.js index cd05cb9ac..203a461b0 100644 --- a/api/server/services/Config/loadConfigEndpoints.js +++ b/api/server/services/Config/loadConfigEndpoints.js @@ -53,7 +53,7 @@ async function loadConfigEndpoints(req) { if (req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) { /** @type {Omit} */ - endpointsConfig[EModelEndpoint.assistants] = { + endpointsConfig[EModelEndpoint.azureAssistants] = { userProvide: false, }; } diff --git a/api/server/services/Config/loadConfigModels.js b/api/server/services/Config/loadConfigModels.js index b3997a2ad..cb0b800d7 100644 --- a/api/server/services/Config/loadConfigModels.js +++ b/api/server/services/Config/loadConfigModels.js @@ -30,7 +30,7 @@ async function loadConfigModels(req) { } if (azureEndpoint?.assistants && azureConfig.assistantModels) { - modelsConfig[EModelEndpoint.assistants] = azureConfig.assistantModels; + modelsConfig[EModelEndpoint.azureAssistants] = azureConfig.assistantModels; } if (!Array.isArray(endpoints[EModelEndpoint.custom])) { diff --git a/api/server/services/Config/loadDefaultEConfig.js b/api/server/services/Config/loadDefaultEConfig.js index 960dfb4c7..379bd4250 100644 --- a/api/server/services/Config/loadDefaultEConfig.js +++ b/api/server/services/Config/loadDefaultEConfig.js @@ -9,13 +9,15 @@ const { config } = require('./EndpointService'); */ async function loadDefaultEndpointsConfig(req) { const { google, gptPlugins } = await loadAsyncEndpoints(req); - const { openAI, assistants, bingAI, anthropic, azureOpenAI, chatGPTBrowser } = config; + const { openAI, assistants, azureAssistants, bingAI, anthropic, azureOpenAI, chatGPTBrowser } = + config; const enabledEndpoints = getEnabledEndpoints(); const endpointConfig = { [EModelEndpoint.openAI]: openAI, [EModelEndpoint.assistants]: assistants, + [EModelEndpoint.azureAssistants]: azureAssistants, [EModelEndpoint.azureOpenAI]: azureOpenAI, [EModelEndpoint.google]: google, [EModelEndpoint.bingAI]: bingAI, diff 
--git a/api/server/services/Config/loadDefaultModels.js b/api/server/services/Config/loadDefaultModels.js index e0b2ca0e4..c550fbebb 100644 --- a/api/server/services/Config/loadDefaultModels.js +++ b/api/server/services/Config/loadDefaultModels.js @@ -25,6 +25,7 @@ async function loadDefaultModels(req) { plugins: true, }); const assistants = await getOpenAIModels({ assistants: true }); + const azureAssistants = await getOpenAIModels({ azureAssistants: true }); return { [EModelEndpoint.openAI]: openAI, @@ -35,6 +36,7 @@ async function loadDefaultModels(req) { [EModelEndpoint.bingAI]: ['BingAI', 'Sydney'], [EModelEndpoint.chatGPTBrowser]: chatGPTBrowser, [EModelEndpoint.assistants]: assistants, + [EModelEndpoint.azureAssistants]: azureAssistants, }; } diff --git a/api/server/services/Endpoints/assistants/index.js b/api/server/services/Endpoints/assistants/index.js index 10e94f2cd..772b1efb1 100644 --- a/api/server/services/Endpoints/assistants/index.js +++ b/api/server/services/Endpoints/assistants/index.js @@ -2,95 +2,8 @@ const addTitle = require('./addTitle'); const buildOptions = require('./buildOptions'); const initializeClient = require('./initializeClient'); -/** - * Asynchronously lists assistants based on provided query parameters. - * - * Initializes the client with the current request and response objects and lists assistants - * according to the query parameters. This function abstracts the logic for non-Azure paths. - * - * @async - * @param {object} params - The parameters object. - * @param {object} params.req - The request object, used for initializing the client. - * @param {object} params.res - The response object, used for initializing the client. - * @param {object} params.query - The query parameters to list assistants (e.g., limit, order). - * @returns {Promise} A promise that resolves to the response from the `openai.beta.assistants.list` method call. 
- */ -const listAssistants = async ({ req, res, query }) => { - const { openai } = await initializeClient({ req, res }); - return openai.beta.assistants.list(query); -}; - -/** - * Asynchronously lists assistants for Azure configured groups. - * - * Iterates through Azure configured assistant groups, initializes the client with the current request and response objects, - * lists assistants based on the provided query parameters, and merges their data alongside the model information into a single array. - * - * @async - * @param {object} params - The parameters object. - * @param {object} params.req - The request object, used for initializing the client and manipulating the request body. - * @param {object} params.res - The response object, used for initializing the client. - * @param {TAzureConfig} params.azureConfig - The Azure configuration object containing assistantGroups and groupMap. - * @param {object} params.query - The query parameters to list assistants (e.g., limit, order). - * @returns {Promise} A promise that resolves to an array of assistant data merged with their respective model information. 
- */ -const listAssistantsForAzure = async ({ req, res, azureConfig = {}, query }) => { - /** @type {Array<[string, TAzureModelConfig]>} */ - const groupModelTuples = []; - const promises = []; - /** @type {Array} */ - const groups = []; - - const { groupMap, assistantGroups } = azureConfig; - - for (const groupName of assistantGroups) { - const group = groupMap[groupName]; - groups.push(group); - - const currentModelTuples = Object.entries(group?.models); - groupModelTuples.push(currentModelTuples); - - /* The specified model is only necessary to - fetch assistants for the shared instance */ - req.body.model = currentModelTuples[0][0]; - promises.push(listAssistants({ req, res, query })); - } - - const resolvedQueries = await Promise.all(promises); - const data = resolvedQueries.flatMap((res, i) => - res.data.map((assistant) => { - const deploymentName = assistant.model; - const currentGroup = groups[i]; - const currentModelTuples = groupModelTuples[i]; - const firstModel = currentModelTuples[0][0]; - - if (currentGroup.deploymentName === deploymentName) { - return { ...assistant, model: firstModel }; - } - - for (const [model, modelConfig] of currentModelTuples) { - if (modelConfig.deploymentName === deploymentName) { - return { ...assistant, model }; - } - } - - return { ...assistant, model: firstModel }; - }), - ); - - return { - first_id: data[0]?.id, - last_id: data[data.length - 1]?.id, - object: 'list', - has_more: false, - data, - }; -}; - module.exports = { addTitle, buildOptions, initializeClient, - listAssistants, - listAssistantsForAzure, }; diff --git a/api/server/services/Endpoints/assistants/initializeClient.js b/api/server/services/Endpoints/assistants/initializeClient.js index c44bc66f3..5dadd54d1 100644 --- a/api/server/services/Endpoints/assistants/initializeClient.js +++ b/api/server/services/Endpoints/assistants/initializeClient.js @@ -1,11 +1,6 @@ const OpenAI = require('openai'); const { HttpsProxyAgent } = require('https-proxy-agent'); 
-const { - ErrorTypes, - EModelEndpoint, - resolveHeaders, - mapModelToAzureConfig, -} = require('librechat-data-provider'); +const { ErrorTypes, EModelEndpoint } = require('librechat-data-provider'); const { getUserKeyValues, getUserKeyExpiry, @@ -13,9 +8,8 @@ const { } = require('~/server/services/UserService'); const OpenAIClient = require('~/app/clients/OpenAIClient'); const { isUserProvided } = require('~/server/utils'); -const { constructAzureURL } = require('~/utils'); -const initializeClient = async ({ req, res, endpointOption, initAppClient = false }) => { +const initializeClient = async ({ req, res, endpointOption, version, initAppClient = false }) => { const { PROXY, OPENAI_ORGANIZATION, ASSISTANTS_API_KEY, ASSISTANTS_BASE_URL } = process.env; const userProvidesKey = isUserProvided(ASSISTANTS_API_KEY); @@ -34,7 +28,11 @@ const initializeClient = async ({ req, res, endpointOption, initAppClient = fals let apiKey = userProvidesKey ? userValues.apiKey : ASSISTANTS_API_KEY; let baseURL = userProvidesURL ? userValues.baseURL : ASSISTANTS_BASE_URL; - const opts = {}; + const opts = { + defaultHeaders: { + 'OpenAI-Beta': `assistants=${version}`, + }, + }; const clientOptions = { reverseProxyUrl: baseURL ?? null, @@ -44,54 +42,6 @@ const initializeClient = async ({ req, res, endpointOption, initAppClient = fals ...endpointOption, }; - /** @type {TAzureConfig | undefined} */ - const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI]; - - /** @type {AzureOptions | undefined} */ - let azureOptions; - - if (azureConfig && azureConfig.assistants) { - const { modelGroupMap, groupMap, assistantModels } = azureConfig; - const modelName = req.body.model ?? req.query.model ?? assistantModels[0]; - const { - azureOptions: currentOptions, - baseURL: azureBaseURL, - headers = {}, - serverless, - } = mapModelToAzureConfig({ - modelName, - modelGroupMap, - groupMap, - }); - - azureOptions = currentOptions; - - baseURL = constructAzureURL({ - baseURL: azureBaseURL ?? 
'https://${INSTANCE_NAME}.openai.azure.com/openai', - azureOptions, - }); - - apiKey = azureOptions.azureOpenAIApiKey; - opts.defaultQuery = { 'api-version': azureOptions.azureOpenAIApiVersion }; - opts.defaultHeaders = resolveHeaders({ ...headers, 'api-key': apiKey }); - opts.model = azureOptions.azureOpenAIApiDeploymentName; - - if (initAppClient) { - clientOptions.titleConvo = azureConfig.titleConvo; - clientOptions.titleModel = azureConfig.titleModel; - clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion'; - - const groupName = modelGroupMap[modelName].group; - clientOptions.addParams = azureConfig.groupMap[groupName].addParams; - clientOptions.dropParams = azureConfig.groupMap[groupName].dropParams; - clientOptions.forcePrompt = azureConfig.groupMap[groupName].forcePrompt; - - clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl; - clientOptions.headers = opts.defaultHeaders; - clientOptions.azure = !serverless && azureOptions; - } - } - if (userProvidesKey & !apiKey) { throw new Error( JSON.stringify({ @@ -125,10 +75,6 @@ const initializeClient = async ({ req, res, endpointOption, initAppClient = fals openai.req = req; openai.res = res; - if (azureOptions) { - openai.locals = { ...(openai.locals ?? 
{}), azureOptions }; - } - if (endpointOption && initAppClient) { const client = new OpenAIClient(apiKey, clientOptions); return { diff --git a/api/server/services/Endpoints/azureAssistants/buildOptions.js b/api/server/services/Endpoints/azureAssistants/buildOptions.js new file mode 100644 index 000000000..047663c4e --- /dev/null +++ b/api/server/services/Endpoints/azureAssistants/buildOptions.js @@ -0,0 +1,19 @@ +const buildOptions = (endpoint, parsedBody) => { + // eslint-disable-next-line no-unused-vars + const { promptPrefix, assistant_id, iconURL, greeting, spec, ...rest } = parsedBody; + const endpointOption = { + endpoint, + promptPrefix, + assistant_id, + iconURL, + greeting, + spec, + modelOptions: { + ...rest, + }, + }; + + return endpointOption; +}; + +module.exports = buildOptions; diff --git a/api/server/services/Endpoints/azureAssistants/index.js b/api/server/services/Endpoints/azureAssistants/index.js new file mode 100644 index 000000000..399446830 --- /dev/null +++ b/api/server/services/Endpoints/azureAssistants/index.js @@ -0,0 +1,7 @@ +const buildOptions = require('./buildOptions'); +const initializeClient = require('./initializeClient'); + +module.exports = { + buildOptions, + initializeClient, +}; diff --git a/api/server/services/Endpoints/azureAssistants/initializeClient.js b/api/server/services/Endpoints/azureAssistants/initializeClient.js new file mode 100644 index 000000000..69a55c74b --- /dev/null +++ b/api/server/services/Endpoints/azureAssistants/initializeClient.js @@ -0,0 +1,195 @@ +const OpenAI = require('openai'); +const { HttpsProxyAgent } = require('https-proxy-agent'); +const { + ErrorTypes, + EModelEndpoint, + resolveHeaders, + mapModelToAzureConfig, +} = require('librechat-data-provider'); +const { + getUserKeyValues, + getUserKeyExpiry, + checkUserKeyExpiry, +} = require('~/server/services/UserService'); +const OpenAIClient = require('~/app/clients/OpenAIClient'); +const { isUserProvided } = require('~/server/utils'); +const { 
constructAzureURL } = require('~/utils'); + +class Files { + constructor(client) { + this._client = client; + } + /** + * Create an assistant file by attaching a + * [File](https://platform.openai.com/docs/api-reference/files) to an + * [assistant](https://platform.openai.com/docs/api-reference/assistants). + */ + create(assistantId, body, options) { + return this._client.post(`/assistants/${assistantId}/files`, { + body, + ...options, + headers: { 'OpenAI-Beta': 'assistants=v1', ...options?.headers }, + }); + } + + /** + * Retrieves an AssistantFile. + */ + retrieve(assistantId, fileId, options) { + return this._client.get(`/assistants/${assistantId}/files/${fileId}`, { + ...options, + headers: { 'OpenAI-Beta': 'assistants=v1', ...options?.headers }, + }); + } + + /** + * Delete an assistant file. + */ + del(assistantId, fileId, options) { + return this._client.delete(`/assistants/${assistantId}/files/${fileId}`, { + ...options, + headers: { 'OpenAI-Beta': 'assistants=v1', ...options?.headers }, + }); + } +} + +const initializeClient = async ({ req, res, version, endpointOption, initAppClient = false }) => { + const { PROXY, OPENAI_ORGANIZATION, AZURE_ASSISTANTS_API_KEY, AZURE_ASSISTANTS_BASE_URL } = + process.env; + + const userProvidesKey = isUserProvided(AZURE_ASSISTANTS_API_KEY); + const userProvidesURL = isUserProvided(AZURE_ASSISTANTS_BASE_URL); + + let userValues = null; + if (userProvidesKey || userProvidesURL) { + const expiresAt = await getUserKeyExpiry({ + userId: req.user.id, + name: EModelEndpoint.azureAssistants, + }); + checkUserKeyExpiry(expiresAt, EModelEndpoint.azureAssistants); + userValues = await getUserKeyValues({ + userId: req.user.id, + name: EModelEndpoint.azureAssistants, + }); + } + + let apiKey = userProvidesKey ? userValues.apiKey : AZURE_ASSISTANTS_API_KEY; + let baseURL = userProvidesURL ? userValues.baseURL : AZURE_ASSISTANTS_BASE_URL; + + const opts = {}; + + const clientOptions = { + reverseProxyUrl: baseURL ?? 
null,
+    proxy: PROXY ?? null,
+    req,
+    res,
+    ...endpointOption,
+  };
+
+  /** @type {TAzureConfig | undefined} */
+  const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI];
+
+  /** @type {AzureOptions | undefined} */
+  let azureOptions;
+
+  if (azureConfig && azureConfig.assistants) {
+    const { modelGroupMap, groupMap, assistantModels } = azureConfig;
+    const modelName = req.body.model ?? req.query.model ?? assistantModels[0];
+    const {
+      azureOptions: currentOptions,
+      baseURL: azureBaseURL,
+      headers = {},
+      serverless,
+    } = mapModelToAzureConfig({
+      modelName,
+      modelGroupMap,
+      groupMap,
+    });
+
+    azureOptions = currentOptions;
+
+    baseURL = constructAzureURL({
+      baseURL: azureBaseURL ?? 'https://${INSTANCE_NAME}.openai.azure.com/openai',
+      azureOptions,
+    });
+
+    apiKey = azureOptions.azureOpenAIApiKey;
+    opts.defaultQuery = { 'api-version': azureOptions.azureOpenAIApiVersion };
+    opts.defaultHeaders = resolveHeaders({
+      ...headers,
+      'api-key': apiKey,
+      'OpenAI-Beta': `assistants=${version}`,
+    });
+    opts.model = azureOptions.azureOpenAIApiDeploymentName;
+
+    if (initAppClient) {
+      clientOptions.titleConvo = azureConfig.titleConvo;
+      clientOptions.titleModel = azureConfig.titleModel;
+      clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion';
+
+      const groupName = modelGroupMap[modelName].group;
+      clientOptions.addParams = azureConfig.groupMap[groupName].addParams;
+      clientOptions.dropParams = azureConfig.groupMap[groupName].dropParams;
+      clientOptions.forcePrompt = azureConfig.groupMap[groupName].forcePrompt;
+
+      clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl;
+      clientOptions.headers = opts.defaultHeaders;
+      clientOptions.azure = !serverless && azureOptions;
+    }
+  }
+
+  if (userProvidesKey && !apiKey) {
+    throw new Error(
+      JSON.stringify({
+        type: ErrorTypes.NO_USER_KEY,
+      }),
+    );
+  }
+
+  if (!apiKey) {
+    throw new Error('Assistants API key not provided. 
Please provide it again.'); + } + + if (baseURL) { + opts.baseURL = baseURL; + } + + if (PROXY) { + opts.httpAgent = new HttpsProxyAgent(PROXY); + } + + if (OPENAI_ORGANIZATION) { + opts.organization = OPENAI_ORGANIZATION; + } + + /** @type {OpenAIClient} */ + const openai = new OpenAI({ + apiKey, + ...opts, + }); + + openai.beta.assistants.files = new Files(openai); + + openai.req = req; + openai.res = res; + + if (azureOptions) { + openai.locals = { ...(openai.locals ?? {}), azureOptions }; + } + + if (endpointOption && initAppClient) { + const client = new OpenAIClient(apiKey, clientOptions); + return { + client, + openai, + openAIApiKey: apiKey, + }; + } + + return { + openai, + openAIApiKey: apiKey, + }; +}; + +module.exports = initializeClient; diff --git a/api/server/services/Endpoints/azureAssistants/initializeClient.spec.js b/api/server/services/Endpoints/azureAssistants/initializeClient.spec.js new file mode 100644 index 000000000..6dc4a6d47 --- /dev/null +++ b/api/server/services/Endpoints/azureAssistants/initializeClient.spec.js @@ -0,0 +1,112 @@ +// const OpenAI = require('openai'); +const { HttpsProxyAgent } = require('https-proxy-agent'); +const { ErrorTypes } = require('librechat-data-provider'); +const { getUserKey, getUserKeyExpiry, getUserKeyValues } = require('~/server/services/UserService'); +const initializeClient = require('./initializeClient'); +// const { OpenAIClient } = require('~/app'); + +jest.mock('~/server/services/UserService', () => ({ + getUserKey: jest.fn(), + getUserKeyExpiry: jest.fn(), + getUserKeyValues: jest.fn(), + checkUserKeyExpiry: jest.requireActual('~/server/services/UserService').checkUserKeyExpiry, +})); + +const today = new Date(); +const tenDaysFromToday = new Date(today.setDate(today.getDate() + 10)); +const isoString = tenDaysFromToday.toISOString(); + +describe('initializeClient', () => { + // Set up environment variables + const originalEnvironment = process.env; + const app = { + locals: {}, + }; + + 
beforeEach(() => { + jest.resetModules(); // Clears the cache + process.env = { ...originalEnvironment }; // Make a copy + }); + + afterAll(() => { + process.env = originalEnvironment; // Restore original env vars + }); + + test('initializes OpenAI client with default API key and URL', async () => { + process.env.AZURE_ASSISTANTS_API_KEY = 'default-api-key'; + process.env.AZURE_ASSISTANTS_BASE_URL = 'https://default.api.url'; + + // Assuming 'isUserProvided' to return false for this test case + jest.mock('~/server/utils', () => ({ + isUserProvided: jest.fn().mockReturnValueOnce(false), + })); + + const req = { user: { id: 'user123' }, app }; + const res = {}; + + const { openai, openAIApiKey } = await initializeClient({ req, res }); + expect(openai.apiKey).toBe('default-api-key'); + expect(openAIApiKey).toBe('default-api-key'); + expect(openai.baseURL).toBe('https://default.api.url'); + }); + + test('initializes OpenAI client with user-provided API key and URL', async () => { + process.env.AZURE_ASSISTANTS_API_KEY = 'user_provided'; + process.env.AZURE_ASSISTANTS_BASE_URL = 'user_provided'; + + getUserKeyValues.mockResolvedValue({ apiKey: 'user-api-key', baseURL: 'https://user.api.url' }); + getUserKeyExpiry.mockResolvedValue(isoString); + + const req = { user: { id: 'user123' }, app }; + const res = {}; + + const { openai, openAIApiKey } = await initializeClient({ req, res }); + expect(openAIApiKey).toBe('user-api-key'); + expect(openai.apiKey).toBe('user-api-key'); + expect(openai.baseURL).toBe('https://user.api.url'); + }); + + test('throws error for invalid JSON in user-provided values', async () => { + process.env.AZURE_ASSISTANTS_API_KEY = 'user_provided'; + getUserKey.mockResolvedValue('invalid-json'); + getUserKeyExpiry.mockResolvedValue(isoString); + getUserKeyValues.mockImplementation(() => { + let userValues = getUserKey(); + try { + userValues = JSON.parse(userValues); + } catch (e) { + throw new Error( + JSON.stringify({ + type: 
ErrorTypes.INVALID_USER_KEY, + }), + ); + } + return userValues; + }); + + const req = { user: { id: 'user123' } }; + const res = {}; + + await expect(initializeClient({ req, res })).rejects.toThrow(/invalid_user_key/); + }); + + test('throws error if API key is not provided', async () => { + delete process.env.AZURE_ASSISTANTS_API_KEY; // Simulate missing API key + + const req = { user: { id: 'user123' }, app }; + const res = {}; + + await expect(initializeClient({ req, res })).rejects.toThrow(/Assistants API key not/); + }); + + test('initializes OpenAI client with proxy configuration', async () => { + process.env.AZURE_ASSISTANTS_API_KEY = 'test-key'; + process.env.PROXY = 'http://proxy.server'; + + const req = { user: { id: 'user123' }, app }; + const res = {}; + + const { openai } = await initializeClient({ req, res }); + expect(openai.httpAgent).toBeInstanceOf(HttpsProxyAgent); + }); +}); diff --git a/api/server/services/Files/Firebase/crud.js b/api/server/services/Files/Firebase/crud.js index 43b5ec9b2..c4d1d05bf 100644 --- a/api/server/services/Files/Firebase/crud.js +++ b/api/server/services/Files/Firebase/crud.js @@ -180,7 +180,15 @@ const deleteFirebaseFile = async (req, file) => { if (!fileName.includes(req.user.id)) { throw new Error('Invalid file path'); } - await deleteFile('', fileName); + try { + await deleteFile('', fileName); + } catch (error) { + logger.error('Error deleting file from Firebase:', error); + if (error.code === 'storage/object-not-found') { + return; + } + throw error; + } }; /** diff --git a/api/server/services/Files/OpenAI/crud.js b/api/server/services/Files/OpenAI/crud.js index 346259e82..881b2063b 100644 --- a/api/server/services/Files/OpenAI/crud.js +++ b/api/server/services/Files/OpenAI/crud.js @@ -14,9 +14,11 @@ const { logger } = require('~/config'); * @returns {Promise} */ async function uploadOpenAIFile({ req, file, openai }) { + const { height, width } = req.body; + const isImage = height && width; const uploadedFile = 
await openai.files.create({ file: fs.createReadStream(file.path), - purpose: FilePurpose.Assistants, + purpose: isImage ? FilePurpose.Vision : FilePurpose.Assistants, }); logger.debug( @@ -34,7 +36,7 @@ async function uploadOpenAIFile({ req, file, openai }) { await sleep(sleepTime); } - return uploadedFile; + return isImage ? { ...uploadedFile, height, width } : uploadedFile; } /** diff --git a/api/server/services/Files/process.js b/api/server/services/Files/process.js index 7f91d481a..197fd160c 100644 --- a/api/server/services/Files/process.js +++ b/api/server/services/Files/process.js @@ -10,10 +10,13 @@ const { EModelEndpoint, mergeFileConfig, hostImageIdSuffix, + checkOpenAIStorage, hostImageNamePrefix, + isAssistantsEndpoint, } = require('librechat-data-provider'); +const { addResourceFileId, deleteResourceFileId } = require('~/server/controllers/assistants/v2'); const { convertImage, resizeAndConvert } = require('~/server/services/Files/images'); -const { initializeClient } = require('~/server/services/Endpoints/assistants'); +const { getOpenAIClient } = require('~/server/controllers/assistants/helpers'); const { createFile, updateFileUsage, deleteFiles } = require('~/models/File'); const { LB_QueueAsyncCall } = require('~/server/utils/queue'); const { getStrategyFunctions } = require('./strategies'); @@ -34,14 +37,16 @@ const processFiles = async (files) => { /** * Enqueues the delete operation to the leaky bucket queue if necessary, or adds it directly to promises. * - * @param {Express.Request} req - The express request object. - * @param {MongoFile} file - The file object to delete. - * @param {Function} deleteFile - The delete file function. - * @param {Promise[]} promises - The array of promises to await. - * @param {OpenAI | undefined} [openai] - If an OpenAI file, the initialized OpenAI client. + * @param {object} params - The passed parameters. + * @param {Express.Request} params.req - The express request object. 
+ * @param {MongoFile} params.file - The file object to delete. + * @param {Function} params.deleteFile - The delete file function. + * @param {Promise[]} params.promises - The array of promises to await. + * @param {string[]} params.resolvedFileIds - The array of file IDs that were successfully deleted. + * @param {OpenAI | undefined} [params.openai] - If an OpenAI file, the initialized OpenAI client. */ -function enqueueDeleteOperation(req, file, deleteFile, promises, openai) { - if (file.source === FileSources.openai) { +function enqueueDeleteOperation({ req, file, deleteFile, promises, resolvedFileIds, openai }) { + if (checkOpenAIStorage(file.source)) { // Enqueue to leaky bucket promises.push( new Promise((resolve, reject) => { @@ -53,6 +58,7 @@ function enqueueDeleteOperation(req, file, deleteFile, promises, openai) { logger.error('Error deleting file from OpenAI source', err); reject(err); } else { + resolvedFileIds.push(file.file_id); resolve(result); } }, @@ -62,10 +68,12 @@ function enqueueDeleteOperation(req, file, deleteFile, promises, openai) { } else { // Add directly to promises promises.push( - deleteFile(req, file).catch((err) => { - logger.error('Error deleting file', err); - return Promise.reject(err); - }), + deleteFile(req, file) + .then(() => resolvedFileIds.push(file.file_id)) + .catch((err) => { + logger.error('Error deleting file', err); + return Promise.reject(err); + }), ); } } @@ -80,35 +88,71 @@ function enqueueDeleteOperation(req, file, deleteFile, promises, openai) { * @param {Express.Request} params.req - The express request object. * @param {DeleteFilesBody} params.req.body - The request body. * @param {string} [params.req.body.assistant_id] - The assistant ID if file uploaded is associated to an assistant. + * @param {string} [params.req.body.tool_resource] - The tool resource if assistant file uploaded is associated to a tool resource. 
* * @returns {Promise} */ const processDeleteRequest = async ({ req, files }) => { - const file_ids = files.map((file) => file.file_id); - + const resolvedFileIds = []; const deletionMethods = {}; const promises = []; - promises.push(deleteFiles(file_ids)); - /** @type {OpenAI | undefined} */ - let openai; - if (req.body.assistant_id) { - ({ openai } = await initializeClient({ req })); + /** @type {Record} */ + const client = { [FileSources.openai]: undefined, [FileSources.azure]: undefined }; + const initializeClients = async () => { + const openAIClient = await getOpenAIClient({ + req, + overrideEndpoint: EModelEndpoint.assistants, + }); + client[FileSources.openai] = openAIClient.openai; + + if (!req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) { + return; + } + + const azureClient = await getOpenAIClient({ + req, + overrideEndpoint: EModelEndpoint.azureAssistants, + }); + client[FileSources.azure] = azureClient.openai; + }; + + if (req.body.assistant_id !== undefined) { + await initializeClients(); } for (const file of files) { const source = file.source ?? 
FileSources.local; - if (source === FileSources.openai && !openai) { - ({ openai } = await initializeClient({ req })); + if (checkOpenAIStorage(source) && !client[source]) { + await initializeClients(); } - if (req.body.assistant_id) { + const openai = client[source]; + + if (req.body.assistant_id && req.body.tool_resource) { + promises.push( + deleteResourceFileId({ + req, + openai, + file_id: file.file_id, + assistant_id: req.body.assistant_id, + tool_resource: req.body.tool_resource, + }), + ); + } else if (req.body.assistant_id) { promises.push(openai.beta.assistants.files.del(req.body.assistant_id, file.file_id)); } if (deletionMethods[source]) { - enqueueDeleteOperation(req, file, deletionMethods[source], promises, openai); + enqueueDeleteOperation({ + req, + file, + deleteFile: deletionMethods[source], + promises, + resolvedFileIds, + openai, + }); continue; } @@ -118,10 +162,11 @@ const processDeleteRequest = async ({ req, files }) => { } deletionMethods[source] = deleteFile; - enqueueDeleteOperation(req, file, deleteFile, promises, openai); + enqueueDeleteOperation({ req, file, deleteFile, promises, resolvedFileIds, openai }); } await Promise.allSettled(promises); + await deleteFiles(resolvedFileIds); }; /** @@ -180,12 +225,13 @@ const processFileURL = async ({ fileStrategy, userId, URL, fileName, basePath, c * * @param {Object} params - The parameters object. * @param {Express.Request} params.req - The Express request object. - * @param {Express.Response} params.res - The Express response object. + * @param {Express.Response} [params.res] - The Express response object. * @param {Express.Multer.File} params.file - The uploaded file. * @param {ImageMetadata} params.metadata - Additional metadata for the file. + * @param {boolean} params.returnFile - Whether to return the file metadata or return response as normal. 
* @returns {Promise} */ -const processImageFile = async ({ req, res, file, metadata }) => { +const processImageFile = async ({ req, res, file, metadata, returnFile = false }) => { const source = req.app.locals.fileStrategy; const { handleImageUpload } = getStrategyFunctions(source); const { file_id, temp_file_id, endpoint } = metadata; @@ -213,6 +259,10 @@ const processImageFile = async ({ req, res, file, metadata }) => { }, true, ); + + if (returnFile) { + return result; + } res.status(200).json({ message: 'File uploaded and processed successfully', ...result }); }; @@ -274,28 +324,57 @@ const uploadImageBuffer = async ({ req, context, metadata = {}, resize = true }) * @returns {Promise} */ const processFileUpload = async ({ req, res, file, metadata }) => { - const isAssistantUpload = metadata.endpoint === EModelEndpoint.assistants; - const source = isAssistantUpload ? FileSources.openai : FileSources.vectordb; + const isAssistantUpload = isAssistantsEndpoint(metadata.endpoint); + const assistantSource = + metadata.endpoint === EModelEndpoint.azureAssistants ? FileSources.azure : FileSources.openai; + const source = isAssistantUpload ? 
assistantSource : FileSources.vectordb; const { handleFileUpload } = getStrategyFunctions(source); const { file_id, temp_file_id } = metadata; /** @type {OpenAI | undefined} */ let openai; - if (source === FileSources.openai) { - ({ openai } = await initializeClient({ req })); + if (checkOpenAIStorage(source)) { + ({ openai } = await getOpenAIClient({ req })); } - const { id, bytes, filename, filepath, embedded } = await handleFileUpload({ + const { + id, + bytes, + filename, + filepath: _filepath, + embedded, + height, + width, + } = await handleFileUpload({ req, file, file_id, openai, }); - if (isAssistantUpload && !metadata.message_file) { + if (isAssistantUpload && !metadata.message_file && !metadata.tool_resource) { await openai.beta.assistants.files.create(metadata.assistant_id, { file_id: id, }); + } else if (isAssistantUpload && !metadata.message_file) { + await addResourceFileId({ + req, + openai, + file_id: id, + assistant_id: metadata.assistant_id, + tool_resource: metadata.tool_resource, + }); + } + + let filepath = isAssistantUpload ? `${openai.baseURL}/files/${id}` : _filepath; + if (isAssistantUpload && file.mimetype.startsWith('image')) { + const result = await processImageFile({ + req, + file, + metadata: { file_id: v4() }, + returnFile: true, + }); + filepath = result.filepath; } const result = await createFile( @@ -304,13 +383,15 @@ const processFileUpload = async ({ req, res, file, metadata }) => { file_id: id ?? file_id, temp_file_id, bytes, + filepath, filename: filename ?? file.originalname, - filepath: isAssistantUpload ? `${openai.baseURL}/files/${id}` : filepath, context: isAssistantUpload ? FileContext.assistants : FileContext.message_attachment, model: isAssistantUpload ? req.body.model : undefined, type: file.mimetype, embedded, source, + height, + width, }, true, ); @@ -340,7 +421,10 @@ const processOpenAIFile = async ({ originalName ? `/${originalName}` : '' }`; const type = mime.getType(originalName ?? 
file_id); - + const source = + openai.req.body.endpoint === EModelEndpoint.azureAssistants + ? FileSources.azure + : FileSources.openai; const file = { ..._file, type, @@ -349,7 +433,7 @@ usage: 1, user: userId, context: _file.purpose, - source: FileSources.openai, + source, model: openai.req.body.model, filename: originalName ?? file_id, }; @@ -394,12 +478,14 @@ const processOpenAIImageOutput = async ({ req, buffer, file_id, filename, fileEx filename: `${hostImageNamePrefix}${filename}`, }; createFile(file, true); + const source = + req.body.endpoint === EModelEndpoint.azureAssistants ? FileSources.azure : FileSources.openai; createFile( { ...file, file_id, filename, - source: FileSources.openai, + source, type: mime.getType(fileExt), }, true, ); @@ -500,7 +586,12 @@ async function retrieveAndProcessFile({ * Filters a file based on its size and the endpoint origin. * * @param {Object} params - The parameters for the function. - * @param {Express.Request} params.req - The request object from Express. + * @param {object} params.req - The request object from Express. + * @param {string} [params.req.endpoint] + * @param {string} [params.req.file_id] + * @param {number} [params.req.width] + * @param {number} [params.req.height] + * @param {number} [params.req.version] * @param {Express.Multer.File} params.file - The file uploaded to the server via multer. * @param {boolean} [params.image] - Whether the file expected is an image. 
* @returns {void} diff --git a/api/server/services/Files/strategies.js b/api/server/services/Files/strategies.js index 96733e403..fa4e456fc 100644 --- a/api/server/services/Files/strategies.js +++ b/api/server/services/Files/strategies.js @@ -111,6 +111,8 @@ const getStrategyFunctions = (fileSource) => { return localStrategy(); } else if (fileSource === FileSources.openai) { return openAIStrategy(); + } else if (fileSource === FileSources.azure) { + return openAIStrategy(); } else if (fileSource === FileSources.vectordb) { return vectorStrategy(); } else { diff --git a/api/server/services/ModelService.js b/api/server/services/ModelService.js index 3c560b297..b6ca6e4f4 100644 --- a/api/server/services/ModelService.js +++ b/api/server/services/ModelService.js @@ -167,6 +167,8 @@ const getOpenAIModels = async (opts) => { if (opts.assistants) { models = defaultModels[EModelEndpoint.assistants]; + } else if (opts.azure) { + models = defaultModels[EModelEndpoint.azureAssistants]; } if (opts.plugins) { diff --git a/api/server/services/Runs/handle.js b/api/server/services/Runs/handle.js index 8b73b099e..dd048219b 100644 --- a/api/server/services/Runs/handle.js +++ b/api/server/services/Runs/handle.js @@ -55,7 +55,7 @@ async function createRun({ openai, thread_id, body }) { * @param {string} params.run_id - The ID of the run to wait for. * @param {string} params.thread_id - The ID of the thread associated with the run. * @param {RunManager} params.runManager - The RunManager instance to manage run steps. - * @param {number} [params.pollIntervalMs=750] - The interval for polling the run status; default is 750 milliseconds. + * @param {number} [params.pollIntervalMs=2000] - The interval for polling the run status; default is 2000 milliseconds. * @param {number} [params.timeout=180000] - The period to wait until timing out polling; default is 3 minutes (in ms). * @return {Promise} A promise that resolves to the last fetched run object. 
*/ @@ -64,7 +64,7 @@ async function waitForRun({ run_id, thread_id, runManager, - pollIntervalMs = 750, + pollIntervalMs = 2000, timeout = 60000 * 3, }) { let timeElapsed = 0; @@ -233,7 +233,7 @@ async function _handleRun({ openai, run_id, thread_id }) { run_id, thread_id, runManager, - pollIntervalMs: 750, + pollIntervalMs: 2000, timeout: 60000, }); const actions = []; diff --git a/api/server/services/Threads/manage.js b/api/server/services/Threads/manage.js index f875b1084..fb151cee9 100644 --- a/api/server/services/Threads/manage.js +++ b/api/server/services/Threads/manage.js @@ -3,7 +3,6 @@ const { v4 } = require('uuid'); const { Constants, ContentTypes, - EModelEndpoint, AnnotationTypes, defaultOrderQuery, } = require('librechat-data-provider'); @@ -50,6 +49,7 @@ async function initThread({ openai, body, thread_id: _thread_id }) { * @param {string} params.assistant_id - The current assistant Id. * @param {string} params.thread_id - The thread Id. * @param {string} params.conversationId - The message's conversationId + * @param {string} params.endpoint - The conversation endpoint * @param {string} [params.parentMessageId] - Optional if initial message. * Defaults to Constants.NO_PARENT. * @param {string} [params.instructions] - Optional: from preset for `instructions` field. @@ -82,7 +82,7 @@ async function saveUserMessage(params) { const userMessage = { user: params.user, - endpoint: EModelEndpoint.assistants, + endpoint: params.endpoint, messageId: params.messageId, conversationId: params.conversationId, parentMessageId: params.parentMessageId ?? Constants.NO_PARENT, @@ -96,7 +96,7 @@ async function saveUserMessage(params) { }; const convo = { - endpoint: EModelEndpoint.assistants, + endpoint: params.endpoint, conversationId: params.conversationId, promptPrefix: params.promptPrefix, instructions: params.instructions, @@ -126,6 +126,7 @@ async function saveUserMessage(params) { * @param {string} params.model - The model used by the assistant. 
* @param {ContentPart[]} params.content - The message content parts. * @param {string} params.conversationId - The message's conversationId + * @param {string} params.endpoint - The conversation endpoint * @param {string} params.parentMessageId - The latest user message that triggered this response. * @param {string} [params.instructions] - Optional: from preset for `instructions` field. * Overrides the instructions of the assistant. @@ -145,7 +146,7 @@ async function saveAssistantMessage(params) { const message = await recordMessage({ user: params.user, - endpoint: EModelEndpoint.assistants, + endpoint: params.endpoint, messageId: params.messageId, conversationId: params.conversationId, parentMessageId: params.parentMessageId, @@ -160,7 +161,7 @@ async function saveAssistantMessage(params) { }); await saveConvo(params.user, { - endpoint: EModelEndpoint.assistants, + endpoint: params.endpoint, conversationId: params.conversationId, promptPrefix: params.promptPrefix, instructions: params.instructions, @@ -205,20 +206,22 @@ async function addThreadMetadata({ openai, thread_id, messageId, messages }) { * * @param {Object} params - The parameters for synchronizing messages. * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {string} params.endpoint - The current endpoint. + * @param {string} params.thread_id - The current thread ID. * @param {TMessage[]} params.dbMessages - The LibreChat DB messages. * @param {ThreadMessage[]} params.apiMessages - The thread messages from the API. - * @param {string} params.conversationId - The current conversation ID. - * @param {string} params.thread_id - The current thread ID. * @param {string} [params.assistant_id] - The current assistant ID. + * @param {string} params.conversationId - The current conversation ID. 
* @return {Promise} A promise that resolves to the updated messages */ async function syncMessages({ openai, - apiMessages, - dbMessages, - conversationId, + endpoint, thread_id, + dbMessages, + apiMessages, assistant_id, + conversationId, }) { let result = []; let dbMessageMap = new Map(dbMessages.map((msg) => [msg.messageId, msg])); @@ -290,7 +293,7 @@ async function syncMessages({ thread_id, conversationId, messageId: v4(), - endpoint: EModelEndpoint.assistants, + endpoint, parentMessageId: lastMessage ? lastMessage.messageId : Constants.NO_PARENT, role: apiMessage.role, isCreatedByUser: apiMessage.role === 'user', @@ -382,13 +385,21 @@ function mapMessagesToSteps(steps, messages) { * * @param {Object} params - The parameters for initializing a thread. * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {string} params.endpoint - The current endpoint. * @param {string} [params.latestMessageId] - Optional: The latest message ID from LibreChat. * @param {string} params.thread_id - Response thread ID. * @param {string} params.run_id - Response Run ID. * @param {string} params.conversationId - LibreChat conversation ID. 
* @return {Promise} A promise that resolves to the updated messages */ -async function checkMessageGaps({ openai, latestMessageId, thread_id, run_id, conversationId }) { +async function checkMessageGaps({ + openai, + endpoint, + latestMessageId, + thread_id, + run_id, + conversationId, +}) { const promises = []; promises.push(openai.beta.threads.messages.list(thread_id, defaultOrderQuery)); promises.push(openai.beta.threads.runs.steps.list(thread_id, run_id)); @@ -406,6 +417,7 @@ async function checkMessageGaps({ openai, latestMessageId, thread_id, run_id, co role: 'assistant', run_id, thread_id, + endpoint, metadata: { messageId: latestMessageId, }, @@ -452,11 +464,12 @@ async function checkMessageGaps({ openai, latestMessageId, thread_id, run_id, co const syncedMessages = await syncMessages({ openai, + endpoint, + thread_id, dbMessages, apiMessages, - thread_id, - conversationId, assistant_id, + conversationId, }); return Object.values( @@ -498,41 +511,62 @@ const recordUsage = async ({ }; /** - * Safely replaces the annotated text within the specified range denoted by start_index and end_index, - * after verifying that the text within that range matches the given annotation text. - * Proceeds with the replacement even if a mismatch is found, but logs a warning. + * Creates a replaceAnnotation function with internal state for tracking the index offset. * - * @param {string} originalText The original text content. - * @param {number} start_index The starting index where replacement should begin. - * @param {number} end_index The ending index where replacement should end. - * @param {string} expectedText The text expected to be found in the specified range. - * @param {string} replacementText The text to insert in place of the existing content. - * @returns {string} The text with the replacement applied, regardless of text match. + * @returns {function} The replaceAnnotation function with closure for index offset. 
*/ -function replaceAnnotation(originalText, start_index, end_index, expectedText, replacementText) { - if (start_index < 0 || end_index > originalText.length || start_index > end_index) { - logger.warn(`Invalid range specified for annotation replacement. - Attempting replacement with \`replace\` method instead... - length: ${originalText.length} - start_index: ${start_index} - end_index: ${end_index}`); - return originalText.replace(originalText, replacementText); +function createReplaceAnnotation() { + let indexOffset = 0; + + /** + * Safely replaces the annotated text within the specified range denoted by start_index and end_index, + * after verifying that the text within that range matches the given annotation text. + * Proceeds with the replacement even if a mismatch is found, but logs a warning. + * + * @param {object} params The annotation replacement parameters. + * @param {string} params.currentText The current text content, with/without replacements. + * @param {number} params.start_index The starting index where replacement should begin. + * @param {number} params.end_index The ending index where replacement should end. + * @param {string} params.expectedText The text expected to be found in the specified range. + * @param {string} params.replacementText The text to insert in place of the existing content. + * @returns {string} The text with the replacement applied, regardless of text match. + */ + function replaceAnnotation({ + currentText, + start_index, + end_index, + expectedText, + replacementText, + }) { + const adjustedStartIndex = start_index + indexOffset; + const adjustedEndIndex = end_index + indexOffset; + + if ( + adjustedStartIndex < 0 || + adjustedEndIndex > currentText.length || + adjustedStartIndex > adjustedEndIndex + ) { + logger.warn(`Invalid range specified for annotation replacement. + Attempting replacement with \`replace\` method instead... 
+ length: ${currentText.length} + start_index: ${adjustedStartIndex} + end_index: ${adjustedEndIndex}`); + return currentText.replace(expectedText, replacementText); + } + + if (currentText.substring(adjustedStartIndex, adjustedEndIndex) !== expectedText) { + return currentText.replace(expectedText, replacementText); + } + + indexOffset += replacementText.length - (adjustedEndIndex - adjustedStartIndex); + return ( + currentText.slice(0, adjustedStartIndex) + + replacementText + + currentText.slice(adjustedEndIndex) + ); } - const actualTextInRange = originalText.substring(start_index, end_index); - - if (actualTextInRange !== expectedText) { - logger.warn(`The text within the specified range does not match the expected annotation text. - Attempting replacement with \`replace\` method instead... - Expected: ${expectedText} - Actual: ${actualTextInRange}`); - - return originalText.replace(originalText, replacementText); - } - - const beforeText = originalText.substring(0, start_index); - const afterText = originalText.substring(end_index); - return beforeText + replacementText + afterText; + return replaceAnnotation; } /** @@ -581,6 +615,11 @@ async function processMessages({ openai, client, messages = [] }) { continue; } + const originalText = currentText; + text += originalText; + + const replaceAnnotation = createReplaceAnnotation(); + logger.debug('[processMessages] Processing annotations:', annotations); for (const annotation of annotations) { let file; @@ -589,14 +628,16 @@ async function processMessages({ openai, client, messages = [] }) { const file_id = annotationType?.file_id; const alreadyProcessed = client.processedFileIds.has(file_id); - const replaceCurrentAnnotation = (replacement = '') => { - currentText = replaceAnnotation( + const replaceCurrentAnnotation = (replacementText = '') => { + const { start_index, end_index, text: expectedText } = annotation; + currentText = replaceAnnotation({ + originalText, currentText, - annotation.start_index, - 
annotation.end_index, - annotation.text, - replacement, - ); + start_index, + end_index, + expectedText, + replacementText, + }); edited = true; }; @@ -623,7 +664,7 @@ async function processMessages({ openai, client, messages = [] }) { replaceCurrentAnnotation(`^${sources.length}^`); } - text += currentText + ' '; + text = currentText; if (!file) { continue; diff --git a/api/server/services/start/assistants.js b/api/server/services/start/assistants.js index dfef99e59..394d7d1a3 100644 --- a/api/server/services/start/assistants.js +++ b/api/server/services/start/assistants.js @@ -2,6 +2,7 @@ const { Capabilities, EModelEndpoint, assistantEndpointSchema, + defaultAssistantsVersion, } = require('librechat-data-provider'); const { logger } = require('~/config'); @@ -12,6 +13,7 @@ const { logger } = require('~/config'); function azureAssistantsDefaults() { return { capabilities: [Capabilities.tools, Capabilities.actions, Capabilities.code_interpreter], + version: defaultAssistantsVersion.azureAssistants, }; } diff --git a/api/server/services/start/azureOpenAI.js b/api/server/services/start/azureOpenAI.js index 3b5c44620..565c8f691 100644 --- a/api/server/services/start/azureOpenAI.js +++ b/api/server/services/start/azureOpenAI.js @@ -41,6 +41,17 @@ function azureConfigSetup(config) { ); } + if ( + azureConfiguration.assistants && + process.env.ENDPOINTS && + !process.env.ENDPOINTS.includes(EModelEndpoint.azureAssistants) + ) { + logger.warn( + `Azure Assistants are configured, but the endpoint will not be accessible as it's not included in the ENDPOINTS environment variable. 
+ Please add the value "${EModelEndpoint.azureAssistants}" to the ENDPOINTS list if expected.`, + ); + } + return { modelNames, modelGroupMap, diff --git a/api/server/utils/handleText.js b/api/server/utils/handleText.js index bfa37e279..70dc16b93 100644 --- a/api/server/utils/handleText.js +++ b/api/server/utils/handleText.js @@ -1,4 +1,10 @@ -const { Capabilities, defaultRetrievalModels } = require('librechat-data-provider'); +const { + Capabilities, + EModelEndpoint, + isAssistantsEndpoint, + defaultRetrievalModels, + defaultAssistantsVersion, +} = require('librechat-data-provider'); const { getCitations, citeText } = require('./citations'); const partialRight = require('lodash/partialRight'); const { sendMessage } = require('./streamResponse'); @@ -154,9 +160,10 @@ const isUserProvided = (value) => value === 'user_provided'; * Generate the configuration for a given key and base URL. * @param {string} key * @param {string} baseURL + * @param {string} endpoint * @returns {boolean | { userProvide: boolean, userProvideURL?: boolean }} */ -function generateConfig(key, baseURL, assistants = false) { +function generateConfig(key, baseURL, endpoint) { if (!key) { return false; } @@ -168,6 +175,8 @@ function generateConfig(key, baseURL, assistants = false) { config.userProvideURL = isUserProvided(baseURL); } + const assistants = isAssistantsEndpoint(endpoint); + if (assistants) { config.retrievalModels = defaultRetrievalModels; config.capabilities = [ @@ -179,6 +188,12 @@ function generateConfig(key, baseURL, assistants = false) { ]; } + if (assistants && endpoint === EModelEndpoint.azureAssistants) { + config.version = defaultAssistantsVersion.azureAssistants; + } else if (assistants) { + config.version = defaultAssistantsVersion.assistants; + } + return config; } diff --git a/api/typedefs.js b/api/typedefs.js index f7970be4f..5c83cab15 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -657,6 +657,12 @@ * @memberof typedefs */ +/** + * @exports OpenAISpecClient + * 
@typedef {import('./app/clients/OpenAIClient')} OpenAISpecClient + * @memberof typedefs + */ + /** * @exports ImportBatchBuilder * @typedef {import('./server/utils/import/importBatchBuilder.js').ImportBatchBuilder} ImportBatchBuilder diff --git a/client/src/common/assistants-types.ts b/client/src/common/assistants-types.ts index 3b9ad27da..e4edf025e 100644 --- a/client/src/common/assistants-types.ts +++ b/client/src/common/assistants-types.ts @@ -4,7 +4,11 @@ import type { Option, ExtendedFile } from './types'; export type TAssistantOption = | string - | (Option & Assistant & { files?: Array<[string, ExtendedFile]> }); + | (Option & + Assistant & { + files?: Array<[string, ExtendedFile]>; + code_files?: Array<[string, ExtendedFile]>; + }); export type Actions = { [Capabilities.code_interpreter]: boolean; diff --git a/client/src/common/types.ts b/client/src/common/types.ts index e574e90d8..62aae7f14 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -8,10 +8,12 @@ import type { TPreset, TPlugin, TMessage, + Assistant, TLoginUser, AuthTypeEnum, TConversation, EModelEndpoint, + AssistantsEndpoint, AuthorizationTypeEnum, TSetOption as SetOption, TokenExchangeMethodEnum, @@ -19,6 +21,13 @@ import type { import type { UseMutationResult } from '@tanstack/react-query'; import type { LucideIcon } from 'lucide-react'; +export type AssistantListItem = { + id: string; + name: string; + metadata: Assistant['metadata']; + model: string; +}; + export type TPluginMap = Record; export type GenericSetter = (value: T | ((currentValue: T) => T)) => void; @@ -101,6 +110,8 @@ export type AssistantPanelProps = { actions?: Action[]; assistant_id?: string; activePanel?: string; + endpoint: AssistantsEndpoint; + version: number | string; setAction: React.Dispatch>; setCurrentAssistantId: React.Dispatch>; setActivePanel: React.Dispatch>; @@ -315,6 +326,7 @@ export type IconProps = Pick & iconURL?: string; message?: boolean; className?: string; + iconClassName?: 
string; endpoint?: EModelEndpoint | string | null; endpointType?: EModelEndpoint | null; assistantName?: string; @@ -327,7 +339,11 @@ export type Option = Record & { }; export type OptionWithIcon = Option & { icon?: React.ReactNode }; -export type MentionOption = OptionWithIcon & { type: string; value: string; description?: string }; +export type MentionOption = OptionWithIcon & { + type: string; + value: string; + description?: string; +}; export type TOptionSettings = { showExamples?: boolean; diff --git a/client/src/components/Chat/Input/ChatForm.tsx b/client/src/components/Chat/Input/ChatForm.tsx index f05fd7279..f12284cc7 100644 --- a/client/src/components/Chat/Input/ChatForm.tsx +++ b/client/src/components/Chat/Input/ChatForm.tsx @@ -3,8 +3,8 @@ import { useForm } from 'react-hook-form'; import { memo, useCallback, useRef, useMemo } from 'react'; import { supportsFiles, - EModelEndpoint, mergeFileConfig, + isAssistantsEndpoint, fileConfig as defaultFileConfig, } from 'librechat-data-provider'; import { useChatContext, useAssistantsMapContext } from '~/Providers'; @@ -74,8 +74,9 @@ const ChatForm = ({ index = 0 }) => { const endpointFileConfig = fileConfig.endpoints[endpoint ?? '']; const invalidAssistant = useMemo( () => - conversation?.endpoint === EModelEndpoint.assistants && - (!conversation?.assistant_id || !assistantMap?.[conversation?.assistant_id ?? '']), + isAssistantsEndpoint(conversation?.endpoint) && + (!conversation?.assistant_id || + !assistantMap?.[conversation?.endpoint ?? '']?.[conversation?.assistant_id ?? 
'']), [conversation?.assistant_id, conversation?.endpoint, assistantMap], ); const disableInputs = useMemo( diff --git a/client/src/components/Chat/Input/Files/FilePreview.tsx b/client/src/components/Chat/Input/Files/FilePreview.tsx index 55c66b9a8..e1060e897 100644 --- a/client/src/components/Chat/Input/Files/FilePreview.tsx +++ b/client/src/components/Chat/Input/Files/FilePreview.tsx @@ -2,6 +2,7 @@ import type { TFile } from 'librechat-data-provider'; import type { ExtendedFile } from '~/common'; import FileIcon from '~/components/svg/Files/FileIcon'; import ProgressCircle from './ProgressCircle'; +import SourceIcon from './SourceIcon'; import { useProgress } from '~/hooks'; import { cn } from '~/utils'; @@ -20,8 +21,7 @@ const FilePreview = ({ }) => { const radius = 55; // Radius of the SVG circle const circumference = 2 * Math.PI * radius; - const progress = useProgress(file?.['progress'] ?? 1, 0.001, file?.size ?? 1); - console.log(progress); + const progress = useProgress(file?.['progress'] ?? 1, 0.001, (file as ExtendedFile)?.size ?? 1); // Calculate the offset based on the loading progress const offset = circumference - progress * circumference; @@ -32,6 +32,7 @@ const FilePreview = ({ return (
+ {progress < 1 && ( >; fileFilter?: (file: ExtendedFile) => boolean; assistant_id?: string; + tool_resource?: EToolResources; Wrapper?: React.FC<{ children: React.ReactNode }>; }) { const files = Array.from(_files.values()).filter((file) => @@ -25,7 +28,8 @@ export default function FileRow({ ); const { mutateAsync } = useDeleteFilesMutation({ - onMutate: async () => console.log('Deleting files: assistant_id', assistant_id), + onMutate: async () => + console.log('Deleting files: assistant_id, tool_resource', assistant_id, tool_resource), onSuccess: () => { console.log('Files deleted'); }, @@ -34,7 +38,7 @@ export default function FileRow({ }, }); - const { deleteFile } = useFileDeletion({ mutateAsync, assistant_id }); + const { deleteFile } = useFileDeletion({ mutateAsync, assistant_id, tool_resource }); useEffect(() => { if (!files) { @@ -82,6 +86,7 @@ export default function FileRow({ url={file.preview} onDelete={handleDelete} progress={file.progress} + source={file.source} /> ); } diff --git a/client/src/components/Chat/Input/Files/FilesView.tsx b/client/src/components/Chat/Input/Files/FilesView.tsx index efd9ec2a8..8791e6c91 100644 --- a/client/src/components/Chat/Input/Files/FilesView.tsx +++ b/client/src/components/Chat/Input/Files/FilesView.tsx @@ -12,16 +12,9 @@ export default function Files({ open, onOpenChange }) { const { data: files = [] } = useGetFiles({ select: (files) => files.map((file) => { - if (file.source === FileSources.local || file.source === FileSources.openai) { - file.context = file.context ?? FileContext.unknown; - return file; - } else { - return { - ...file, - context: file.context ?? FileContext.unknown, - source: FileSources.local, - }; - } + file.context = file.context ?? FileContext.unknown; + file.filterSource = file.source === FileSources.firebase ? 
FileSources.local : file.source; + return file; }), }); diff --git a/client/src/components/Chat/Input/Files/Image.tsx b/client/src/components/Chat/Input/Files/Image.tsx index 1cd13c833..22c03b537 100644 --- a/client/src/components/Chat/Input/Files/Image.tsx +++ b/client/src/components/Chat/Input/Files/Image.tsx @@ -1,3 +1,4 @@ +import { FileSources } from 'librechat-data-provider'; import ImagePreview from './ImagePreview'; import RemoveFile from './RemoveFile'; @@ -6,16 +7,18 @@ const Image = ({ url, onDelete, progress = 1, + source = FileSources.local, }: { imageBase64?: string; url?: string; onDelete: () => void; progress: number; // between 0 and 1 + source?: FileSources; }) => { return (
- +
diff --git a/client/src/components/Chat/Input/Files/ImagePreview.tsx b/client/src/components/Chat/Input/Files/ImagePreview.tsx index 479481235..2876c2aef 100644 --- a/client/src/components/Chat/Input/Files/ImagePreview.tsx +++ b/client/src/components/Chat/Input/Files/ImagePreview.tsx @@ -1,4 +1,6 @@ +import { FileSources } from 'librechat-data-provider'; import ProgressCircle from './ProgressCircle'; +import SourceIcon from './SourceIcon'; import { cn } from '~/utils'; type styleProps = { @@ -13,11 +15,13 @@ const ImagePreview = ({ url, progress = 1, className = '', + source, }: { imageBase64?: string; url?: string; progress?: number; // between 0 and 1 className?: string; + source?: FileSources; }) => { let style: styleProps = { backgroundSize: 'cover', @@ -65,6 +69,7 @@ const ImagePreview = ({ circleCSSProperties={circleCSSProperties} /> )} +
); }; diff --git a/client/src/components/Chat/Input/Files/SourceIcon.tsx b/client/src/components/Chat/Input/Files/SourceIcon.tsx new file mode 100644 index 000000000..23cc4d816 --- /dev/null +++ b/client/src/components/Chat/Input/Files/SourceIcon.tsx @@ -0,0 +1,45 @@ +import { EModelEndpoint, FileSources } from 'librechat-data-provider'; +import { MinimalIcon } from '~/components/Endpoints'; +import { cn } from '~/utils'; + +const sourceToEndpoint = { + [FileSources.openai]: EModelEndpoint.openAI, + [FileSources.azure]: EModelEndpoint.azureOpenAI, +}; +const sourceToClassname = { + [FileSources.openai]: 'bg-black/65', + [FileSources.azure]: 'azure-bg-color opacity-85', +}; + +const defaultClassName = + 'absolute right-0 bottom-0 rounded-full p-[0.15rem] text-gray-600 transition-colors'; + +export default function SourceIcon({ + source, + className = defaultClassName, +}: { + source?: FileSources; + className?: string; +}) { + if (source === FileSources.local || source === FileSources.firebase) { + return null; + } + + const endpoint = sourceToEndpoint[source ?? 
'']; + + if (!endpoint) { + return null; + } + return ( + + ); +} diff --git a/client/src/components/Chat/Input/Files/Table/Columns.tsx b/client/src/components/Chat/Input/Files/Table/Columns.tsx index 5b53a06f4..7284f2931 100644 --- a/client/src/components/Chat/Input/Files/Table/Columns.tsx +++ b/client/src/components/Chat/Input/Files/Table/Columns.tsx @@ -7,6 +7,7 @@ import ImagePreview from '~/components/Chat/Input/Files/ImagePreview'; import FilePreview from '~/components/Chat/Input/Files/FilePreview'; import { SortFilterHeader } from './SortFilterHeader'; import { OpenAIMinimalIcon } from '~/components/svg'; +import { AzureMinimalIcon } from '~/components/svg'; import { Button, Checkbox } from '~/components/ui'; import { formatDate, getFileType } from '~/utils'; import useLocalize from '~/hooks/useLocalize'; @@ -71,10 +72,11 @@ export const columns: ColumnDef[] = [ const file = row.original; if (file.type?.startsWith('image')) { return ( -
+
{file.filename}
@@ -84,7 +86,7 @@ export const columns: ColumnDef[] = [ const fileType = getFileType(file.type); return (
- {fileType && } + {fileType && } {file.filename}
); @@ -108,7 +110,7 @@ export const columns: ColumnDef[] = [ cell: ({ row }) => formatDate(row.original.updatedAt), }, { - accessorKey: 'source', + accessorKey: 'filterSource', header: ({ column }) => { const localize = useLocalize(); return ( @@ -117,10 +119,14 @@ export const columns: ColumnDef[] = [ title={localize('com_ui_storage')} filters={{ Storage: Object.values(FileSources).filter( - (value) => value === FileSources.local || value === FileSources.openai, + (value) => + value === FileSources.local || + value === FileSources.openai || + value === FileSources.azure, ), }} valueMap={{ + [FileSources.azure]: 'Azure', [FileSources.openai]: 'OpenAI', [FileSources.local]: 'com_ui_host', }} @@ -137,6 +143,13 @@ export const columns: ColumnDef[] = [ {'OpenAI'}
); + } else if (source === FileSources.azure) { + return ( +
+ + {'Azure'} +
+ ); } return (
diff --git a/client/src/components/Chat/Input/Files/Table/DataTable.tsx b/client/src/components/Chat/Input/Files/Table/DataTable.tsx index 347006b48..1886ffc87 100644 --- a/client/src/components/Chat/Input/Files/Table/DataTable.tsx +++ b/client/src/components/Chat/Input/Files/Table/DataTable.tsx @@ -48,7 +48,12 @@ const contextMap = { [FileContext.bytes]: 'com_ui_size', }; -type Style = { width?: number | string; maxWidth?: number | string; minWidth?: number | string }; +type Style = { + width?: number | string; + maxWidth?: number | string; + minWidth?: number | string; + zIndex?: number; +}; export default function DataTable({ columns, data }: DataTableProps) { const localize = useLocalize(); @@ -142,7 +147,7 @@ export default function DataTable({ columns, data }: DataTablePro {table.getHeaderGroups().map((headerGroup) => ( {headerGroup.headers.map((header, index) => { - const style: Style = { maxWidth: '32px', minWidth: '125px' }; + const style: Style = { maxWidth: '32px', minWidth: '125px', zIndex: 50 }; if (header.id === 'filename') { style.maxWidth = '50%'; style.width = '50%'; diff --git a/client/src/components/Chat/Input/Mention.tsx b/client/src/components/Chat/Input/Mention.tsx index 229dd5a54..93fec7430 100644 --- a/client/src/components/Chat/Input/Mention.tsx +++ b/client/src/components/Chat/Input/Mention.tsx @@ -17,7 +17,9 @@ export default function Mention({ }) { const localize = useLocalize(); const assistantMap = useAssistantsMapContext(); - const { options, modelsConfig, assistants, onSelectMention } = useMentions({ assistantMap }); + const { options, modelsConfig, assistantListMap, onSelectMention } = useMentions({ + assistantMap, + }); const [activeIndex, setActiveIndex] = useState(0); const timeoutRef = useRef(null); @@ -47,7 +49,12 @@ export default function Mention({ if (mention.type === 'endpoint' && mention.value === EModelEndpoint.assistants) { setSearchValue(''); - setInputOptions(assistants); + 
setInputOptions(assistantListMap[EModelEndpoint.assistants]); + setActiveIndex(0); + inputRef.current?.focus(); + } else if (mention.type === 'endpoint' && mention.value === EModelEndpoint.azureAssistants) { + setSearchValue(''); + setInputOptions(assistantListMap[EModelEndpoint.azureAssistants]); setActiveIndex(0); inputRef.current?.focus(); } else if (mention.type === 'endpoint') { diff --git a/client/src/components/Chat/Landing.tsx b/client/src/components/Chat/Landing.tsx index 5e2392bfc..2202a28ad 100644 --- a/client/src/components/Chat/Landing.tsx +++ b/client/src/components/Chat/Landing.tsx @@ -1,4 +1,4 @@ -import { EModelEndpoint } from 'librechat-data-provider'; +import { EModelEndpoint, isAssistantsEndpoint } from 'librechat-data-provider'; import { useGetEndpointsQuery, useGetStartupConfig } from 'librechat-data-provider/react-query'; import type { ReactNode } from 'react'; import { TooltipProvider, Tooltip, TooltipTrigger, TooltipContent } from '~/components/ui'; @@ -30,7 +30,8 @@ export default function Landing({ Header }: { Header?: ReactNode }) { const iconURL = conversation?.iconURL; endpoint = getIconEndpoint({ endpointsConfig, iconURL, endpoint }); - const assistant = endpoint === EModelEndpoint.assistants && assistantMap?.[assistant_id ?? '']; + const isAssistant = isAssistantsEndpoint(endpoint); + const assistant = isAssistant && assistantMap?.[endpoint]?.[assistant_id ?? '']; const assistantName = (assistant && assistant?.name) || ''; const assistantDesc = (assistant && assistant?.description) || ''; const avatar = (assistant && (assistant?.metadata?.avatar as string)) || ''; @@ -77,7 +78,7 @@ export default function Landing({ Header }: { Header?: ReactNode }) {
) : (
- {endpoint === EModelEndpoint.assistants + {isAssistant ? conversation?.greeting ?? localize('com_nav_welcome_assistant') : conversation?.greeting ?? localize('com_nav_welcome_message')}
diff --git a/client/src/components/Chat/Menus/Endpoints/Icons.tsx b/client/src/components/Chat/Menus/Endpoints/Icons.tsx index 4a700bd97..4e88cceef 100644 --- a/client/src/components/Chat/Menus/Endpoints/Icons.tsx +++ b/client/src/components/Chat/Menus/Endpoints/Icons.tsx @@ -15,6 +15,24 @@ import { import UnknownIcon from './UnknownIcon'; import { cn } from '~/utils'; +const AssistantAvatar = ({ className = '', assistantName, avatar, size }: IconMapProps) => { + if (assistantName && avatar) { + return ( + {assistantName} + ); + } else if (assistantName) { + return ; + } + + return ; +}; + export const icons = { [EModelEndpoint.azureOpenAI]: AzureMinimalIcon, [EModelEndpoint.openAI]: GPTIcon, @@ -24,22 +42,7 @@ export const icons = { [EModelEndpoint.google]: GoogleMinimalIcon, [EModelEndpoint.bingAI]: BingAIMinimalIcon, [EModelEndpoint.custom]: CustomMinimalIcon, - [EModelEndpoint.assistants]: ({ className = '', assistantName, avatar, size }: IconMapProps) => { - if (assistantName && avatar) { - return ( - {assistantName} - ); - } else if (assistantName) { - return ; - } - - return ; - }, + [EModelEndpoint.assistants]: AssistantAvatar, + [EModelEndpoint.azureAssistants]: AssistantAvatar, unknown: UnknownIcon, }; diff --git a/client/src/components/Chat/Menus/EndpointsMenu.tsx b/client/src/components/Chat/Menus/EndpointsMenu.tsx index 6d73c80e7..ab5eb4633 100644 --- a/client/src/components/Chat/Menus/EndpointsMenu.tsx +++ b/client/src/components/Chat/Menus/EndpointsMenu.tsx @@ -1,5 +1,5 @@ import { Content, Portal, Root } from '@radix-ui/react-popover'; -import { alternateName, EModelEndpoint } from 'librechat-data-provider'; +import { alternateName, isAssistantsEndpoint } from 'librechat-data-provider'; import { useGetEndpointsQuery } from 'librechat-data-provider/react-query'; import type { FC } from 'react'; import { useChatContext, useAssistantsMapContext } from '~/Providers'; @@ -16,7 +16,8 @@ const EndpointsMenu: FC = () => { const { endpoint = '', assistant_id 
= null } = conversation ?? {}; const assistantMap = useAssistantsMapContext(); - const assistant = endpoint === EModelEndpoint.assistants && assistantMap?.[assistant_id ?? '']; + const assistant = + isAssistantsEndpoint(endpoint) && assistantMap?.[endpoint ?? '']?.[assistant_id ?? '']; const assistantName = (assistant && assistant?.name) || 'Assistant'; if (!endpoint) { diff --git a/client/src/components/Chat/Messages/Content/CodeAnalyze.tsx b/client/src/components/Chat/Messages/Content/CodeAnalyze.tsx index c0ab640dd..492d86d62 100644 --- a/client/src/components/Chat/Messages/Content/CodeAnalyze.tsx +++ b/client/src/components/Chat/Messages/Content/CodeAnalyze.tsx @@ -1,6 +1,7 @@ import { useState } from 'react'; import { useRecoilValue } from 'recoil'; import ProgressCircle from './ProgressCircle'; +import CancelledIcon from './CancelledIcon'; import ProgressText from './ProgressText'; import FinishedIcon from './FinishedIcon'; import MarkdownLite from './MarkdownLite'; @@ -11,10 +12,12 @@ export default function CodeAnalyze({ initialProgress = 0.1, code, outputs = [], + isSubmitting, }: { initialProgress: number; code: string; outputs: Record[]; + isSubmitting: boolean; }) { const showCodeDefault = useRecoilValue(store.showCode); const [showCode, setShowCode] = useState(showCodeDefault); @@ -35,7 +38,13 @@ export default function CodeAnalyze({
{progress < 1 ? ( - + ) : ( )} @@ -74,18 +83,25 @@ const CodeInProgress = ({ offset, circumference, radius, + isSubmitting, + progress, }: { + progress: number; offset: number; circumference: number; radius: number; + isSubmitting: boolean; }) => { + if (progress < 1 && !isSubmitting) { + return ; + } return (
-
+
); } else if ( part.type === ContentTypes.TOOL_CALL && - part[ContentTypes.TOOL_CALL].type === ToolCallTypes.RETRIEVAL + (part[ContentTypes.TOOL_CALL].type === ToolCallTypes.RETRIEVAL || + part[ContentTypes.TOOL_CALL].type === ToolCallTypes.FILE_SEARCH) ) { const toolCall = part[ContentTypes.TOOL_CALL]; return ; diff --git a/client/src/components/Chat/Messages/HoverButtons.tsx b/client/src/components/Chat/Messages/HoverButtons.tsx index 7a593202b..35fa10df7 100644 --- a/client/src/components/Chat/Messages/HoverButtons.tsx +++ b/client/src/components/Chat/Messages/HoverButtons.tsx @@ -1,5 +1,4 @@ import { useState } from 'react'; -import { EModelEndpoint } from 'librechat-data-provider'; import type { TConversation, TMessage } from 'librechat-data-provider'; import { Clipboard, CheckMark, EditIcon, RegenerateIcon, ContinueIcon } from '~/components/svg'; import { useGenerationsByLatest, useLocalize } from '~/hooks'; @@ -35,14 +34,19 @@ export default function HoverButtons({ const { endpoint: _endpoint, endpointType } = conversation ?? {}; const endpoint = endpointType ?? _endpoint; const [isCopied, setIsCopied] = useState(false); - const { hideEditButton, regenerateEnabled, continueSupported, forkingSupported } = - useGenerationsByLatest({ - isEditing, - isSubmitting, - message, - endpoint: endpoint ?? '', - latestMessage, - }); + const { + hideEditButton, + regenerateEnabled, + continueSupported, + forkingSupported, + isEditableEndpoint, + } = useGenerationsByLatest({ + isEditing, + isSubmitting, + message, + endpoint: endpoint ?? '', + latestMessage, + }); if (!conversation) { return null; } @@ -58,7 +62,7 @@ export default function HoverButtons({ return (
- {endpoint !== EModelEndpoint.assistants && ( + {isEditableEndpoint && (
diff --git a/client/src/components/SidePanel/Builder/AssistantAvatar.tsx b/client/src/components/SidePanel/Builder/AssistantAvatar.tsx index 863e9bfc2..e5a736b34 100644 --- a/client/src/components/SidePanel/Builder/AssistantAvatar.tsx +++ b/client/src/components/SidePanel/Builder/AssistantAvatar.tsx @@ -10,9 +10,10 @@ import { import type { UseMutationResult } from '@tanstack/react-query'; import type { Metadata, - AssistantListResponse, Assistant, + AssistantsEndpoint, AssistantCreateParams, + AssistantListResponse, } from 'librechat-data-provider'; import { useUploadAssistantAvatarMutation, useGetFileConfig } from '~/data-provider'; import { AssistantAvatar, NoImage, AvatarMenu } from './Images'; @@ -22,10 +23,14 @@ import { useLocalize } from '~/hooks'; // import { cn } from '~/utils/'; function Avatar({ + endpoint, + version, assistant_id, metadata, createMutation, }: { + endpoint: AssistantsEndpoint; + version: number | string; assistant_id: string | null; metadata: null | Metadata; createMutation: UseMutationResult; @@ -46,8 +51,8 @@ function Avatar({ const { showToast } = useToastContext(); const activeModel = useMemo(() => { - return assistantsMap[assistant_id ?? '']?.model ?? ''; - }, [assistant_id, assistantsMap]); + return assistantsMap[endpoint][assistant_id ?? '']?.model ?? ''; + }, [assistantsMap, endpoint, assistant_id]); const { mutate: uploadAvatar } = useUploadAssistantAvatarMutation({ onMutate: () => { @@ -65,6 +70,7 @@ function Avatar({ const res = queryClient.getQueryData([ QueryKeys.assistants, + endpoint, defaultOrderQuery, ]); @@ -83,10 +89,13 @@ function Avatar({ return assistant; }) ?? 
[]; - queryClient.setQueryData([QueryKeys.assistants, defaultOrderQuery], { - ...res, - data: assistants, - }); + queryClient.setQueryData( + [QueryKeys.assistants, endpoint, defaultOrderQuery], + { + ...res, + data: assistants, + }, + ); setProgress(1); }, @@ -149,9 +158,20 @@ function Avatar({ model: activeModel, postCreation: true, formData, + endpoint, + version, }); } - }, [createMutation.data, createMutation.isSuccess, input, previewUrl, uploadAvatar, activeModel]); + }, [ + createMutation.data, + createMutation.isSuccess, + input, + previewUrl, + uploadAvatar, + activeModel, + endpoint, + version, + ]); const handleFileChange = (event: React.ChangeEvent): void => { const file = event.target.files?.[0]; @@ -183,6 +203,8 @@ function Avatar({ assistant_id, model: activeModel, formData, + endpoint, + version, }); } else { showToast({ diff --git a/client/src/components/SidePanel/Builder/AssistantPanel.tsx b/client/src/components/SidePanel/Builder/AssistantPanel.tsx index 6399c5d34..aca34fc8e 100644 --- a/client/src/components/SidePanel/Builder/AssistantPanel.tsx +++ b/client/src/components/SidePanel/Builder/AssistantPanel.tsx @@ -1,23 +1,23 @@ -import { useState, useMemo, useEffect } from 'react'; +import { useState, useMemo } from 'react'; import { useQueryClient } from '@tanstack/react-query'; import { useForm, FormProvider, Controller, useWatch } from 'react-hook-form'; -import { useGetModelsQuery, useGetEndpointsQuery } from 'librechat-data-provider/react-query'; +import { useGetModelsQuery } from 'librechat-data-provider/react-query'; import { Tools, QueryKeys, Capabilities, - EModelEndpoint, actionDelimiter, ImageVisionTool, defaultAssistantFormValues, } from 'librechat-data-provider'; +import type { FunctionTool, TConfig, TPlugin } from 'librechat-data-provider'; import type { AssistantForm, AssistantPanelProps } from '~/common'; -import type { FunctionTool, TPlugin, TEndpointsConfig } from 'librechat-data-provider'; import { useCreateAssistantMutation, 
useUpdateAssistantMutation } from '~/data-provider'; -import { SelectDropDown, Checkbox, QuestionMark } from '~/components/ui'; import { useAssistantsMapContext, useToastContext } from '~/Providers'; import { useSelectAssistant, useLocalize } from '~/hooks'; import { ToolSelectDialog } from '~/components/Tools'; +import CapabilitiesForm from './CapabilitiesForm'; +import { SelectDropDown } from '~/components/ui'; import AssistantAvatar from './AssistantAvatar'; import AssistantSelect from './AssistantSelect'; import AssistantAction from './AssistantAction'; @@ -35,17 +35,20 @@ const inputClass = export default function AssistantPanel({ // index = 0, setAction, + endpoint, actions = [], setActivePanel, assistant_id: current_assistant_id, setCurrentAssistantId, -}: AssistantPanelProps) { + assistantsConfig, + version, +}: AssistantPanelProps & { assistantsConfig?: TConfig | null }) { const queryClient = useQueryClient(); const modelsQuery = useGetModelsQuery(); const assistantMap = useAssistantsMapContext(); - const { data: endpointsConfig = {} as TEndpointsConfig } = useGetEndpointsQuery(); + const allTools = queryClient.getQueryData([QueryKeys.tools]) ?? 
[]; - const { onSelect: onSelectAssistant } = useSelectAssistant(); + const { onSelect: onSelectAssistant } = useSelectAssistant(endpoint); const { showToast } = useToastContext(); const localize = useLocalize(); @@ -55,44 +58,31 @@ export default function AssistantPanel({ const [showToolDialog, setShowToolDialog] = useState(false); - const { control, handleSubmit, reset, setValue, getValues } = methods; + const { control, handleSubmit, reset } = methods; const assistant = useWatch({ control, name: 'assistant' }); const functions = useWatch({ control, name: 'functions' }); const assistant_id = useWatch({ control, name: 'id' }); - const model = useWatch({ control, name: 'model' }); const activeModel = useMemo(() => { - return assistantMap?.[assistant_id]?.model; - }, [assistantMap, assistant_id]); + return assistantMap?.[endpoint]?.[assistant_id]?.model; + }, [assistantMap, endpoint, assistant_id]); - const assistants = useMemo(() => endpointsConfig?.[EModelEndpoint.assistants], [endpointsConfig]); - const retrievalModels = useMemo(() => new Set(assistants?.retrievalModels ?? 
[]), [assistants]); const toolsEnabled = useMemo( - () => assistants?.capabilities?.includes(Capabilities.tools), - [assistants], + () => assistantsConfig?.capabilities?.includes(Capabilities.tools), + [assistantsConfig], ); const actionsEnabled = useMemo( - () => assistants?.capabilities?.includes(Capabilities.actions), - [assistants], + () => assistantsConfig?.capabilities?.includes(Capabilities.actions), + [assistantsConfig], ); const retrievalEnabled = useMemo( - () => assistants?.capabilities?.includes(Capabilities.retrieval), - [assistants], + () => assistantsConfig?.capabilities?.includes(Capabilities.retrieval), + [assistantsConfig], ); const codeEnabled = useMemo( - () => assistants?.capabilities?.includes(Capabilities.code_interpreter), - [assistants], + () => assistantsConfig?.capabilities?.includes(Capabilities.code_interpreter), + [assistantsConfig], ); - const imageVisionEnabled = useMemo( - () => assistants?.capabilities?.includes(Capabilities.image_vision), - [assistants], - ); - - useEffect(() => { - if (model && !retrievalModels.has(model)) { - setValue(Capabilities.retrieval, false); - } - }, [model, setValue, retrievalModels]); /* Mutations */ const update = useUpdateAssistantMutation({ @@ -145,7 +135,7 @@ export default function AssistantPanel({ if (!functionName.includes(actionDelimiter)) { return functionName; } else { - const assistant = assistantMap?.[assistant_id]; + const assistant = assistantMap?.[endpoint]?.[assistant_id]; const tool = assistant?.tools?.find((tool) => tool.function?.name === functionName); if (assistant && tool) { return tool; @@ -160,7 +150,7 @@ export default function AssistantPanel({ tools.push({ type: Tools.code_interpreter }); } if (data.retrieval) { - tools.push({ type: Tools.retrieval }); + tools.push({ type: version == 2 ? 
Tools.file_search : Tools.retrieval }); } if (data.image_vision) { tools.push(ImageVisionTool); @@ -183,6 +173,7 @@ export default function AssistantPanel({ instructions, model, tools, + endpoint, }, }); return; @@ -194,6 +185,8 @@ export default function AssistantPanel({ instructions, model, tools, + endpoint, + version, }); }; @@ -211,6 +204,7 @@ export default function AssistantPanel({
{/* Knowledge */} - {(codeEnabled || retrievalEnabled) && ( - + {(codeEnabled || retrievalEnabled) && version == 1 && ( + )} {/* Capabilities */} -
-
- - - -
-
- {codeEnabled && ( -
- ( - - )} - /> - -
- )} - {imageVisionEnabled && ( -
- ( - - )} - /> - -
- )} - {retrievalEnabled && ( -
- ( - - )} - /> - -
- )} -
-
+ {/* Tools */}