diff --git a/.env.example b/.env.example index 07c81066f..c3793945a 100644 --- a/.env.example +++ b/.env.example @@ -119,6 +119,14 @@ DEBUG_OPENAI=false # OPENAI_ORGANIZATION= +#====================# +# Assistants API # +#====================# + +# ASSISTANTS_API_KEY= +# ASSISTANTS_BASE_URL= +# ASSISTANTS_MODELS=gpt-3.5-turbo-0125,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-16k,gpt-3.5-turbo,gpt-4,gpt-4-0314,gpt-4-32k-0314,gpt-4-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-1106,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview + #============# # OpenRouter # #============# diff --git a/.eslintrc.js b/.eslintrc.js index a3d71acd6..6d8e08518 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -131,6 +131,12 @@ module.exports = { }, ], }, + { + files: ['./packages/data-provider/specs/**/*.ts'], + parserOptions: { + project: './packages/data-provider/tsconfig.spec.json', + }, + }, ], settings: { react: { diff --git a/.github/workflows/backend-review.yml b/.github/workflows/backend-review.yml index f5f8201ef..2d5cf387b 100644 --- a/.github/workflows/backend-review.yml +++ b/.github/workflows/backend-review.yml @@ -39,6 +39,9 @@ jobs: - name: Run unit tests run: cd api && npm run test:ci + - name: Run librechat-data-provider unit tests + run: cd packages/data-provider && npm run test:ci + - name: Run linters uses: wearerequired/lint-action@v2 with: diff --git a/.gitignore b/.gitignore index 765de5cb7..17c18d9a7 100644 --- a/.gitignore +++ b/.gitignore @@ -88,4 +88,7 @@ auth.json /packages/ux-shared/ /images -!client/src/components/Nav/SettingsTabs/Data/ \ No newline at end of file +!client/src/components/Nav/SettingsTabs/Data/ + +# User uploads +uploads/ \ No newline at end of file diff --git a/api/app/chatgpt-browser.js b/api/app/chatgpt-browser.js index 467e67785..818661555 100644 --- a/api/app/chatgpt-browser.js +++ b/api/app/chatgpt-browser.js @@ -1,5 +1,6 @@ require('dotenv').config(); const { KeyvFile } = require('keyv-file'); +const { Constants } = require('librechat-data-provider'); const { getUserKey, checkUserKeyExpiry } = require('../server/services/UserService'); const browserClient = async ({ @@ -48,7 +49,7 @@ const browserClient = async ({ options = { ...options, parentMessageId, conversationId }; } - if (parentMessageId === '00000000-0000-0000-0000-000000000000') { + if (parentMessageId === Constants.NO_PARENT) { delete options.conversationId; } diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 3b919c92f..6009515f2 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -1,5 +1,5 @@ const crypto = require('crypto'); -const { supportsBalanceCheck } = require('librechat-data-provider'); +const { supportsBalanceCheck, Constants } = require('librechat-data-provider'); const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('~/models'); const { addSpaceIfNeeded, isEnabled } = require('~/server/utils'); const checkBalance = require('~/models/checkBalance'); @@ -77,7 +77,7 @@ class BaseClient { const saveOptions = this.getSaveOptions(); this.abortController = opts.abortController ?? new AbortController(); const conversationId = opts.conversationId ?? crypto.randomUUID(); - const parentMessageId = opts.parentMessageId ?? '00000000-0000-0000-0000-000000000000'; + const parentMessageId = opts.parentMessageId ?? Constants.NO_PARENT; const userMessageId = opts.overrideParentMessageId ?? crypto.randomUUID(); let responseMessageId = opts.responseMessageId ?? crypto.randomUUID(); let head = isEdited ? 
responseMessageId : parentMessageId; @@ -552,7 +552,7 @@ class BaseClient { * * Each message object should have an 'id' or 'messageId' property and may have a 'parentMessageId' property. * The 'parentMessageId' is the ID of the message that the current message is a reply to. - * If 'parentMessageId' is not present, null, or is '00000000-0000-0000-0000-000000000000', + * If 'parentMessageId' is not present, null, or is Constants.NO_PARENT, * the message is considered a root message. * * @param {Object} options - The options for the function. @@ -607,9 +607,7 @@ class BaseClient { } currentMessageId = - message.parentMessageId === '00000000-0000-0000-0000-000000000000' - ? null - : message.parentMessageId; + message.parentMessageId === Constants.NO_PARENT ? null : message.parentMessageId; } orderedMessages.reverse(); diff --git a/api/app/clients/GoogleClient.js b/api/app/clients/GoogleClient.js index 950cc8d11..dedda1fc8 100644 --- a/api/app/clients/GoogleClient.js +++ b/api/app/clients/GoogleClient.js @@ -4,12 +4,13 @@ const { GoogleVertexAI } = require('langchain/llms/googlevertexai'); const { ChatGoogleGenerativeAI } = require('@langchain/google-genai'); const { ChatGoogleVertexAI } = require('langchain/chat_models/googlevertexai'); const { AIMessage, HumanMessage, SystemMessage } = require('langchain/schema'); -const { encodeAndFormat, validateVisionModel } = require('~/server/services/Files/images'); +const { encodeAndFormat } = require('~/server/services/Files/images'); const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); const { + validateVisionModel, getResponseSender, - EModelEndpoint, endpointSettings, + EModelEndpoint, AuthKeys, } = require('librechat-data-provider'); const { getModelMaxTokens } = require('~/utils'); diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index f9c551097..0e61fdcba 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -1,14 +1,19 @@ const OpenAI = require('openai'); const { HttpsProxyAgent } = require('https-proxy-agent'); -const { getResponseSender, ImageDetailCost, ImageDetail } = require('librechat-data-provider'); +const { + getResponseSender, + validateVisionModel, + ImageDetailCost, + ImageDetail, +} = require('librechat-data-provider'); const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); const { - getModelMaxTokens, - genAzureChatCompletion, extractBaseURL, constructAzureURL, + getModelMaxTokens, + genAzureChatCompletion, } = require('~/utils'); -const { encodeAndFormat, validateVisionModel } = require('~/server/services/Files/images'); +const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts'); const { handleOpenAIErrors } = require('./tools/util'); const spendTokens = require('~/models/spendTokens'); @@ -630,6 +635,7 @@ class OpenAIClient extends BaseClient { context, tokenBuffer, initialMessageCount, + conversationId, }) { const modelOptions = { modelName: modelName ?? model, @@ -677,7 +683,7 @@ class OpenAIClient extends BaseClient { callbacks: runManager.createCallbacks({ context, tokenBuffer, - conversationId: this.conversationId, + conversationId: this.conversationId ?? conversationId, initialMessageCount, }), }); @@ -693,12 +699,13 @@ class OpenAIClient extends BaseClient { * * @param {Object} params - The parameters for the conversation title generation. 
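// Illustrative usage sketch (the `client` instance and argument values are hypothetical):
// with `conversationId` now threaded through initializeLLM and titleConvo, callers can pass
// it explicitly so the title run's callbacks are tied to the right conversation even when
// the client was initialized without one.
async function addTitle(client, { userText, aiResponseText, conversationId }) {
  // Resolves to the generated title, or the default 'New Chat' if generation fails.
  return await client.titleConvo({
    text: userText,
    responseText: aiResponseText,
    conversationId,
  });
}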
* @param {string} params.text - The user's input. + * @param {string} [params.conversationId] - The current conversationId, if not already defined on client initialization. * @param {string} [params.responseText=''] - The AI's immediate response to the user. * * @returns {Promise} A promise that resolves to the generated conversation title. * In case of failure, it will return the default title, "New Chat". */ - async titleConvo({ text, responseText = '' }) { + async titleConvo({ text, conversationId, responseText = '' }) { let title = 'New Chat'; const convo = `||>User: "${truncateText(text)}" @@ -758,7 +765,12 @@ ${convo} try { this.abortController = new AbortController(); - const llm = this.initializeLLM({ ...modelOptions, context: 'title', tokenBuffer: 150 }); + const llm = this.initializeLLM({ + ...modelOptions, + conversationId, + context: 'title', + tokenBuffer: 150, + }); title = await runTitleChain({ llm, text, convo, signal: this.abortController.signal }); } catch (e) { if (e?.message?.toLowerCase()?.includes('abort')) { diff --git a/api/app/clients/PluginsClient.js b/api/app/clients/PluginsClient.js index ab04dd133..8a19ab4a2 100644 --- a/api/app/clients/PluginsClient.js +++ b/api/app/clients/PluginsClient.js @@ -3,6 +3,7 @@ const { CallbackManager } = require('langchain/callbacks'); const { BufferMemory, ChatMessageHistory } = require('langchain/memory'); const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents'); const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers'); +const { processFileURL } = require('~/server/services/Files/process'); const { EModelEndpoint } = require('librechat-data-provider'); const { formatLangChainMessages } = require('./prompts'); const checkBalance = require('~/models/checkBalance'); @@ -113,6 +114,7 @@ class PluginsClient extends OpenAIClient { openAIApiKey: this.openAIApiKey, conversationId: this.conversationId, fileStrategy: this.options.req.app.locals.fileStrategy, + processFileURL, message, }, }); diff --git a/api/app/clients/prompts/formatMessages.spec.js b/api/app/clients/prompts/formatMessages.spec.js index 636cdb1c8..8d4956b38 100644 --- a/api/app/clients/prompts/formatMessages.spec.js +++ b/api/app/clients/prompts/formatMessages.spec.js @@ -1,5 +1,6 @@ -const { formatMessage, formatLangChainMessages, formatFromLangChain } = require('./formatMessages'); +const { Constants } = require('librechat-data-provider'); const { HumanMessage, AIMessage, SystemMessage } = require('langchain/schema'); +const { formatMessage, formatLangChainMessages, formatFromLangChain } = require('./formatMessages'); describe('formatMessage', () => { it('formats user message', () => { @@ -61,7 +62,7 @@ describe('formatMessage', () => { isCreatedByUser: true, isEdited: false, model: null, - parentMessageId: '00000000-0000-0000-0000-000000000000', + parentMessageId: Constants.NO_PARENT, sender: 'User', text: 'hi', tokenCount: 5, diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index 889499fbc..9ffa7e04f 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -1,3 +1,4 @@ +const { Constants } = require('librechat-data-provider'); const { initializeFakeClient } = require('./FakeClient'); jest.mock('../../../lib/db/connectDb'); @@ -307,7 +308,7 @@ describe('BaseClient', () => { const unorderedMessages = [ { id: '3', parentMessageId: '2', text: 'Message 3' }, { id: '2', parentMessageId: '1', text: 'Message 2' }, - { id: '1', 
parentMessageId: '00000000-0000-0000-0000-000000000000', text: 'Message 1' }, + { id: '1', parentMessageId: Constants.NO_PARENT, text: 'Message 1' }, ]; it('should return ordered messages based on parentMessageId', () => { @@ -316,7 +317,7 @@ describe('BaseClient', () => { parentMessageId: '3', }); expect(result).toEqual([ - { id: '1', parentMessageId: '00000000-0000-0000-0000-000000000000', text: 'Message 1' }, + { id: '1', parentMessageId: Constants.NO_PARENT, text: 'Message 1' }, { id: '2', parentMessageId: '1', text: 'Message 2' }, { id: '3', parentMessageId: '2', text: 'Message 3' }, ]); diff --git a/api/app/clients/specs/PluginsClient.test.js b/api/app/clients/specs/PluginsClient.test.js index 9a4663934..dfd57b23b 100644 --- a/api/app/clients/specs/PluginsClient.test.js +++ b/api/app/clients/specs/PluginsClient.test.js @@ -1,6 +1,7 @@ +const crypto = require('crypto'); +const { Constants } = require('librechat-data-provider'); const { HumanChatMessage, AIChatMessage } = require('langchain/schema'); const PluginsClient = require('../PluginsClient'); -const crypto = require('crypto'); jest.mock('~/lib/db/connectDb'); jest.mock('~/models/Conversation', () => { @@ -66,7 +67,7 @@ describe('PluginsClient', () => { TestAgent.setOptions(opts); } const conversationId = opts.conversationId || crypto.randomUUID(); - const parentMessageId = opts.parentMessageId || '00000000-0000-0000-0000-000000000000'; + const parentMessageId = opts.parentMessageId || Constants.NO_PARENT; const userMessageId = opts.overrideParentMessageId || crypto.randomUUID(); this.pastMessages = await TestAgent.loadHistory( conversationId, diff --git a/api/app/clients/tools/DALL-E.js b/api/app/clients/tools/DALL-E.js index d3cdaa713..4600bdb02 100644 --- a/api/app/clients/tools/DALL-E.js +++ b/api/app/clients/tools/DALL-E.js @@ -3,8 +3,8 @@ const OpenAI = require('openai'); const { v4: uuidv4 } = require('uuid'); const { Tool } = require('langchain/tools'); const { HttpsProxyAgent } = require('https-proxy-agent'); +const { FileContext } = require('librechat-data-provider'); const { getImageBasename } = require('~/server/services/Files/images'); -const { processFileURL } = require('~/server/services/Files/process'); const extractBaseURL = require('~/utils/extractBaseURL'); const { logger } = require('~/config'); @@ -14,6 +14,9 @@ class OpenAICreateImage extends Tool { this.userId = fields.userId; this.fileStrategy = fields.fileStrategy; + if (fields.processFileURL) { + this.processFileURL = fields.processFileURL.bind(this); + } let apiKey = fields.DALLE2_API_KEY ?? fields.DALLE_API_KEY ?? this.getApiKey(); const config = { apiKey }; @@ -80,13 +83,21 @@ Guidelines: } async _call(input) { - const resp = await this.openai.images.generate({ - prompt: this.replaceUnwantedChars(input), - // TODO: Future idea -- could we ask an LLM to extract these arguments from an input that might contain them? - n: 1, - // size: '1024x1024' - size: '512x512', - }); + let resp; + + try { + resp = await this.openai.images.generate({ + prompt: this.replaceUnwantedChars(input), + // TODO: Future idea -- could we ask an LLM to extract these arguments from an input that might contain them? + n: 1, + // size: '1024x1024' + size: '512x512', + }); + } catch (error) { + logger.error('[DALL-E] Problem generating the image:', error); + return `Something went wrong when trying to generate the image. 
The DALL-E API may be unavailable: +Error Message: ${error.message}`; + } const theImageUrl = resp.data[0].url; @@ -110,15 +121,16 @@ Guidelines: }); try { - const result = await processFileURL({ + const result = await this.processFileURL({ fileStrategy: this.fileStrategy, userId: this.userId, URL: theImageUrl, fileName: imageName, basePath: 'images', + context: FileContext.image_generation, }); - this.result = this.wrapInMarkdown(result); + this.result = this.wrapInMarkdown(result.filepath); } catch (error) { logger.error('Error while saving the image:', error); this.result = `Failed to save the image locally. ${error.message}`; diff --git a/api/app/clients/tools/index.js b/api/app/clients/tools/index.js index f5410e89e..60a741990 100644 --- a/api/app/clients/tools/index.js +++ b/api/app/clients/tools/index.js @@ -1,35 +1,42 @@ +const availableTools = require('./manifest.json'); +// Basic Tools +const CodeBrew = require('./CodeBrew'); const GoogleSearchAPI = require('./GoogleSearch'); -const OpenAICreateImage = require('./DALL-E'); -const DALLE3 = require('./structured/DALLE3'); -const StructuredSD = require('./structured/StableDiffusion'); -const StableDiffusionAPI = require('./StableDiffusion'); const WolframAlphaAPI = require('./Wolfram'); -const StructuredWolfram = require('./structured/Wolfram'); -const SelfReflectionTool = require('./SelfReflection'); const AzureAiSearch = require('./AzureAiSearch'); -const StructuredACS = require('./structured/AzureAISearch'); +const OpenAICreateImage = require('./DALL-E'); +const StableDiffusionAPI = require('./StableDiffusion'); +const SelfReflectionTool = require('./SelfReflection'); + +// Structured Tools +const DALLE3 = require('./structured/DALLE3'); const ChatTool = require('./structured/ChatTool'); const E2BTools = require('./structured/E2BTools'); const CodeSherpa = require('./structured/CodeSherpa'); +const StructuredSD = require('./structured/StableDiffusion'); +const StructuredACS = require('./structured/AzureAISearch'); const CodeSherpaTools = require('./structured/CodeSherpaTools'); -const availableTools = require('./manifest.json'); -const CodeBrew = require('./CodeBrew'); +const StructuredWolfram = require('./structured/Wolfram'); +const TavilySearchResults = require('./structured/TavilySearchResults'); module.exports = { availableTools, - GoogleSearchAPI, - OpenAICreateImage, - DALLE3, - StableDiffusionAPI, - StructuredSD, - WolframAlphaAPI, - StructuredWolfram, - SelfReflectionTool, - AzureAiSearch, - StructuredACS, - E2BTools, - ChatTool, - CodeSherpa, - CodeSherpaTools, + // Basic Tools CodeBrew, + AzureAiSearch, + GoogleSearchAPI, + WolframAlphaAPI, + OpenAICreateImage, + StableDiffusionAPI, + SelfReflectionTool, + // Structured Tools + DALLE3, + ChatTool, + E2BTools, + CodeSherpa, + StructuredSD, + StructuredACS, + CodeSherpaTools, + StructuredWolfram, + TavilySearchResults, }; diff --git a/api/app/clients/tools/manifest.json b/api/app/clients/tools/manifest.json index 1a03895a8..3a79ec02b 100644 --- a/api/app/clients/tools/manifest.json +++ b/api/app/clients/tools/manifest.json @@ -108,6 +108,19 @@ } ] }, + { + "name": "Tavily Search", + "pluginKey": "tavily_search_results_json", + "description": "Tavily Search is a robust search API tailored specifically for LLM Agents. 
It seamlessly integrates with diverse data sources to ensure a superior, relevant search experience.", + "icon": "https://tavily.com/favicon.ico", + "authConfig": [ + { + "authField": "TAVILY_API_KEY", + "label": "Tavily API Key", + "description": "Get your API key here: https://app.tavily.com/" + } + ] + }, { "name": "Calculator", "pluginKey": "calculator", diff --git a/api/app/clients/tools/structured/AzureAISearch.js b/api/app/clients/tools/structured/AzureAISearch.js index 9b50aa2c4..0ce7b43fb 100644 --- a/api/app/clients/tools/structured/AzureAISearch.js +++ b/api/app/clients/tools/structured/AzureAISearch.js @@ -19,6 +19,13 @@ class AzureAISearch extends StructuredTool { this.name = 'azure-ai-search'; this.description = 'Use the \'azure-ai-search\' tool to retrieve search results relevant to your input'; + /* Used to initialize the Tool without necessary variables. */ + this.override = fields.override ?? false; + + // Define schema + this.schema = z.object({ + query: z.string().describe('Search word or phrase to Azure AI Search'), + }); // Initialize properties using helper function this.serviceEndpoint = this._initializeField( @@ -51,12 +58,16 @@ class AzureAISearch extends StructuredTool { ); // Check for required fields - if (!this.serviceEndpoint || !this.indexName || !this.apiKey) { + if (!this.override && (!this.serviceEndpoint || !this.indexName || !this.apiKey)) { throw new Error( 'Missing AZURE_AI_SEARCH_SERVICE_ENDPOINT, AZURE_AI_SEARCH_INDEX_NAME, or AZURE_AI_SEARCH_API_KEY environment variable.', ); } + if (this.override) { + return; + } + // Create SearchClient this.client = new SearchClient( this.serviceEndpoint, @@ -64,11 +75,6 @@ class AzureAISearch extends StructuredTool { new AzureKeyCredential(this.apiKey), { apiVersion: this.apiVersion }, ); - - // Define schema - this.schema = z.object({ - query: z.string().describe('Search word or phrase to Azure AI Search'), - }); } // Improved error handling and logging diff --git a/api/app/clients/tools/structured/DALLE3.js b/api/app/clients/tools/structured/DALLE3.js index 92cf5b2a7..e3c0f7010 100644 --- a/api/app/clients/tools/structured/DALLE3.js +++ b/api/app/clients/tools/structured/DALLE3.js @@ -4,17 +4,25 @@ const OpenAI = require('openai'); const { v4: uuidv4 } = require('uuid'); const { Tool } = require('langchain/tools'); const { HttpsProxyAgent } = require('https-proxy-agent'); +const { FileContext } = require('librechat-data-provider'); const { getImageBasename } = require('~/server/services/Files/images'); -const { processFileURL } = require('~/server/services/Files/process'); const extractBaseURL = require('~/utils/extractBaseURL'); const { logger } = require('~/config'); class DALLE3 extends Tool { constructor(fields = {}) { super(); + /* Used to initialize the Tool without necessary variables. */ + this.override = fields.override ?? false; + /* Necessary for output to contain all image metadata. */ + this.returnMetadata = fields.returnMetadata ?? false; this.userId = fields.userId; this.fileStrategy = fields.fileStrategy; + if (fields.processFileURL) { + this.processFileURL = fields.processFileURL.bind(this); + } + let apiKey = fields.DALLE3_API_KEY ?? fields.DALLE_API_KEY ?? this.getApiKey(); const config = { apiKey }; if (process.env.DALLE_REVERSE_PROXY) { @@ -81,7 +89,7 @@ class DALLE3 extends Tool { getApiKey() { const apiKey = process.env.DALLE3_API_KEY ?? process.env.DALLE_API_KEY ?? 
''; - if (!apiKey) { + if (!apiKey && !this.override) { throw new Error('Missing DALLE_API_KEY environment variable.'); } return apiKey; @@ -115,6 +123,7 @@ class DALLE3 extends Tool { n: 1, }); } catch (error) { + logger.error('[DALL-E-3] Problem generating the image:', error); return `Something went wrong when trying to generate the image. The DALL-E API may be unavailable: Error Message: ${error.message}`; } @@ -145,15 +154,26 @@ Error Message: ${error.message}`; }); try { - const result = await processFileURL({ + const result = await this.processFileURL({ fileStrategy: this.fileStrategy, userId: this.userId, URL: theImageUrl, fileName: imageName, basePath: 'images', + context: FileContext.image_generation, }); - this.result = this.wrapInMarkdown(result); + if (this.returnMetadata) { + this.result = { + file_id: result.file_id, + filename: result.filename, + filepath: result.filepath, + height: result.height, + width: result.width, + }; + } else { + this.result = this.wrapInMarkdown(result.filepath); + } } catch (error) { logger.error('Error while saving the image:', error); this.result = `Failed to save the image locally. ${error.message}`; diff --git a/api/app/clients/tools/structured/StableDiffusion.js b/api/app/clients/tools/structured/StableDiffusion.js index 1fc509673..dc479037b 100644 --- a/api/app/clients/tools/structured/StableDiffusion.js +++ b/api/app/clients/tools/structured/StableDiffusion.js @@ -10,6 +10,9 @@ const { logger } = require('~/config'); class StableDiffusionAPI extends StructuredTool { constructor(fields) { super(); + /* Used to initialize the Tool without necessary variables. */ + this.override = fields.override ?? false; + this.name = 'stable-diffusion'; this.url = fields.SD_WEBUI_URL || this.getServerURL(); this.description_for_model = `// Generate images and visuals using text. @@ -52,7 +55,7 @@ class StableDiffusionAPI extends StructuredTool { getServerURL() { const url = process.env.SD_WEBUI_URL || ''; - if (!url) { + if (!url && !this.override) { throw new Error('Missing SD_WEBUI_URL environment variable.'); } return url; diff --git a/api/app/clients/tools/structured/TavilySearchResults.js b/api/app/clients/tools/structured/TavilySearchResults.js new file mode 100644 index 000000000..3945ac1d0 --- /dev/null +++ b/api/app/clients/tools/structured/TavilySearchResults.js @@ -0,0 +1,92 @@ +const { z } = require('zod'); +const { Tool } = require('@langchain/core/tools'); +const { getEnvironmentVariable } = require('@langchain/core/utils/env'); + +class TavilySearchResults extends Tool { + static lc_name() { + return 'TavilySearchResults'; + } + + constructor(fields = {}) { + super(fields); + this.envVar = 'TAVILY_API_KEY'; + /* Used to initialize the Tool without necessary variables. */ + this.override = fields.override ?? false; + this.apiKey = fields.apiKey ?? this.getApiKey(); + + this.kwargs = fields?.kwargs ?? {}; + this.name = 'tavily_search_results_json'; + this.description = + 'A search engine optimized for comprehensive, accurate, and trusted results. Useful for when you need to answer questions about current events.'; + + this.schema = z.object({ + query: z.string().min(1).describe('The search query string.'), + max_results: z + .number() + .min(1) + .max(10) + .optional() + .describe('The maximum number of search results to return. Defaults to 5.'), + search_depth: z + .enum(['basic', 'advanced']) + .optional() + .describe( + 'The depth of the search, affecting result quality and response time (`basic` or `advanced`). 
Default is basic for quick results and advanced for indepth high quality results but longer response time. Advanced calls equals 2 requests.', + ), + include_images: z + .boolean() + .optional() + .describe( + 'Whether to include a list of query-related images in the response. Default is False.', + ), + include_answer: z + .boolean() + .optional() + .describe('Whether to include answers in the search results. Default is False.'), + // include_raw_content: z.boolean().optional().describe('Whether to include raw content in the search results. Default is False.'), + // include_domains: z.array(z.string()).optional().describe('A list of domains to specifically include in the search results.'), + // exclude_domains: z.array(z.string()).optional().describe('A list of domains to specifically exclude from the search results.'), + }); + } + + getApiKey() { + const apiKey = getEnvironmentVariable(this.envVar); + if (!apiKey && !this.override) { + throw new Error(`Missing ${this.envVar} environment variable.`); + } + return apiKey; + } + + async _call(input) { + const validationResult = this.schema.safeParse(input); + if (!validationResult.success) { + throw new Error(`Validation failed: ${JSON.stringify(validationResult.error.issues)}`); + } + + const { query, ...rest } = validationResult.data; + + const requestBody = { + api_key: this.apiKey, + query, + ...rest, + ...this.kwargs, + }; + + const response = await fetch('https://api.tavily.com/search', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(requestBody), + }); + + const json = await response.json(); + if (!response.ok) { + throw new Error(`Request failed with status ${response.status}: ${json.error}`); + } + + return JSON.stringify(json); + } +} + +module.exports = TavilySearchResults; diff --git a/api/app/clients/tools/structured/Wolfram.js b/api/app/clients/tools/structured/Wolfram.js index 2c5c6e023..fc857b35c 100644 --- a/api/app/clients/tools/structured/Wolfram.js +++ b/api/app/clients/tools/structured/Wolfram.js @@ -7,6 +7,9 @@ const { logger } = require('~/config'); class WolframAlphaAPI extends StructuredTool { constructor(fields) { super(); + /* Used to initialize the Tool without necessary variables. */ + this.override = fields.override ?? false; + this.name = 'wolfram'; this.apiKey = fields.WOLFRAM_APP_ID || this.getAppId(); this.description_for_model = `// Access dynamic computation and curated data from WolframAlpha and Wolfram Cloud. 
@@ -55,7 +58,7 @@ class WolframAlphaAPI extends StructuredTool { getAppId() { const appId = process.env.WOLFRAM_APP_ID || ''; - if (!appId) { + if (!appId && !this.override) { throw new Error('Missing WOLFRAM_APP_ID environment variable.'); } return appId; diff --git a/api/app/clients/tools/structured/specs/DALLE3.spec.js b/api/app/clients/tools/structured/specs/DALLE3.spec.js index 58771b145..1b28de2fa 100644 --- a/api/app/clients/tools/structured/specs/DALLE3.spec.js +++ b/api/app/clients/tools/structured/specs/DALLE3.spec.js @@ -1,14 +1,11 @@ const OpenAI = require('openai'); const DALLE3 = require('../DALLE3'); -const { processFileURL } = require('~/server/services/Files/process'); const { logger } = require('~/config'); jest.mock('openai'); -jest.mock('~/server/services/Files/process', () => ({ - processFileURL: jest.fn(), -})); +const processFileURL = jest.fn(); jest.mock('~/server/services/Files/images', () => ({ getImageBasename: jest.fn().mockImplementation((url) => { @@ -69,7 +66,7 @@ describe('DALLE3', () => { jest.resetModules(); process.env = { ...originalEnv, DALLE_API_KEY: mockApiKey }; // Instantiate DALLE3 for tests that do not depend on DALLE3_SYSTEM_PROMPT - dalle = new DALLE3(); + dalle = new DALLE3({ processFileURL }); }); afterEach(() => { @@ -78,7 +75,8 @@ describe('DALLE3', () => { process.env = originalEnv; }); - it('should throw an error if DALLE_API_KEY is missing', () => { + it('should throw an error if all potential API keys are missing', () => { + delete process.env.DALLE3_API_KEY; delete process.env.DALLE_API_KEY; expect(() => new DALLE3()).toThrow('Missing DALLE_API_KEY environment variable.'); }); @@ -112,7 +110,9 @@ describe('DALLE3', () => { }; generate.mockResolvedValue(mockResponse); - processFileURL.mockResolvedValue('http://example.com/img-test.png'); + processFileURL.mockResolvedValue({ + filepath: 'http://example.com/img-test.png', + }); const result = await dalle._call(mockData); diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index 4a1e4e09b..2733791e3 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -6,19 +6,22 @@ const { OpenAIEmbeddings } = require('langchain/embeddings/openai'); const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { availableTools, + // Basic Tools + CodeBrew, + AzureAISearch, GoogleSearchAPI, WolframAlphaAPI, - StructuredWolfram, OpenAICreateImage, StableDiffusionAPI, + // Structured Tools DALLE3, - StructuredSD, - AzureAISearch, - StructuredACS, E2BTools, CodeSherpa, + StructuredSD, + StructuredACS, CodeSherpaTools, - CodeBrew, + StructuredWolfram, + TavilySearchResults, } = require('../'); const { loadToolSuite } = require('./loadToolSuite'); const { loadSpecs } = require('./loadSpecs'); @@ -151,8 +154,10 @@ const loadTools = async ({ returnMap = false, tools = [], options = {}, + skipSpecs = false, }) => { const toolConstructors = { + tavily_search_results_json: TavilySearchResults, calculator: Calculator, google: GoogleSearchAPI, wolfram: functions ? 
StructuredWolfram : WolframAlphaAPI, @@ -229,10 +234,17 @@ const loadTools = async ({ toolConstructors.codesherpa = CodeSherpa; } + const imageGenOptions = { + fileStrategy: options.fileStrategy, + processFileURL: options.processFileURL, + returnMetadata: options.returnMetadata, + }; + const toolOptions = { serpapi: { location: 'Austin,Texas,United States', hl: 'en', gl: 'us' }, - dalle: { fileStrategy: options.fileStrategy }, - 'dall-e': { fileStrategy: options.fileStrategy }, + dalle: imageGenOptions, + 'dall-e': imageGenOptions, + 'stable-diffusion': imageGenOptions, }; const toolAuthFields = {}; @@ -271,7 +283,7 @@ const loadTools = async ({ } let specs = null; - if (functions && remainingTools.length > 0) { + if (functions && remainingTools.length > 0 && skipSpecs !== true) { specs = await loadSpecs({ llm: model, user, @@ -298,6 +310,9 @@ const loadTools = async ({ let result = []; for (const tool of tools) { const validTool = requestedTools[tool]; + if (!validTool) { + continue; + } const plugin = await validTool(); if (Array.isArray(plugin)) { diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js index 894edd216..1e614cde5 100644 --- a/api/cache/getLogStores.js +++ b/api/cache/getLogStores.js @@ -33,7 +33,11 @@ const genTitle = isEnabled(USE_REDIS) // ttl: 2 minutes const modelQueries = isEnabled(process.env.USE_REDIS) ? new Keyv({ store: keyvRedis }) - : new Keyv({ namespace: 'models' }); + : new Keyv({ namespace: CacheKeys.MODEL_QUERIES }); + +const abortKeys = isEnabled(USE_REDIS) + ? new Keyv({ store: keyvRedis }) + : new Keyv({ namespace: CacheKeys.ABORT_KEYS }); const namespaces = { [CacheKeys.CONFIG_STORE]: config, @@ -45,7 +49,9 @@ const namespaces = { message_limit: createViolationInstance('message_limit'), token_balance: createViolationInstance('token_balance'), registrations: createViolationInstance('registrations'), + [CacheKeys.FILE_UPLOAD_LIMIT]: createViolationInstance(CacheKeys.FILE_UPLOAD_LIMIT), logins: createViolationInstance('logins'), + [CacheKeys.ABORT_KEYS]: abortKeys, [CacheKeys.TOKEN_CONFIG]: tokenConfig, [CacheKeys.GEN_TITLE]: genTitle, [CacheKeys.MODEL_QUERIES]: modelQueries, diff --git a/api/config/parsers.js b/api/config/parsers.js index 59685eab0..16c85cba4 100644 --- a/api/config/parsers.js +++ b/api/config/parsers.js @@ -33,6 +33,10 @@ function getMatchingSensitivePatterns(valueStr) { * @returns {string} - The redacted console message. 
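// Illustrative sketch: the empty-input guard added to redactMessage in this hunk means log
// calls that pass no message string resolve to '' instead of forwarding undefined to
// getMatchingSensitivePatterns.
const { redactMessage } = require('~/config/parsers');

const safe = redactMessage(undefined); // ''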
*/ function redactMessage(str) { + if (!str) { + return ''; + } + const patterns = getMatchingSensitivePatterns(str); if (patterns.length === 0) { diff --git a/api/config/paths.js b/api/config/paths.js index 41e3ac505..92921218e 100644 --- a/api/config/paths.js +++ b/api/config/paths.js @@ -1,7 +1,10 @@ const path = require('path'); module.exports = { + uploads: path.resolve(__dirname, '..', '..', 'uploads'), dist: path.resolve(__dirname, '..', '..', 'client', 'dist'), publicPath: path.resolve(__dirname, '..', '..', 'client', 'public'), imageOutput: path.resolve(__dirname, '..', '..', 'client', 'public', 'images'), + structuredTools: path.resolve(__dirname, '..', 'app', 'clients', 'tools', 'structured'), + pluginManifest: path.resolve(__dirname, '..', 'app', 'clients', 'tools', 'manifest.json'), }; diff --git a/api/models/Action.js b/api/models/Action.js new file mode 100644 index 000000000..5141569c1 --- /dev/null +++ b/api/models/Action.js @@ -0,0 +1,68 @@ +const mongoose = require('mongoose'); +const actionSchema = require('./schema/action'); + +const Action = mongoose.model('action', actionSchema); + +/** + * Update an action with new data without overwriting existing properties, + * or create a new action if it doesn't exist. + * + * @param {Object} searchParams - The search parameters to find the action to update. + * @param {string} searchParams.action_id - The ID of the action to update. + * @param {string} searchParams.user - The user ID of the action's author. + * @param {Object} updateData - An object containing the properties to update. + * @returns {Promise} The updated or newly created action document as a plain object. + */ +const updateAction = async (searchParams, updateData) => { + return await Action.findOneAndUpdate(searchParams, updateData, { + new: true, + upsert: true, + }).lean(); +}; + +/** + * Retrieves all actions that match the given search parameters. + * + * @param {Object} searchParams - The search parameters to find matching actions. + * @param {boolean} includeSensitive - Flag to include sensitive data in the metadata. + * @returns {Promise>} A promise that resolves to an array of action documents as plain objects. + */ +const getActions = async (searchParams, includeSensitive = false) => { + const actions = await Action.find(searchParams).lean(); + + if (!includeSensitive) { + for (let i = 0; i < actions.length; i++) { + const metadata = actions[i].metadata; + if (!metadata) { + continue; + } + + const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret']; + for (let field of sensitiveFields) { + if (metadata[field]) { + delete metadata[field]; + } + } + } + } + + return actions; +}; + +/** + * Deletes an action by its ID. + * + * @param {Object} searchParams - The search parameters to find the action to update. + * @param {string} searchParams.action_id - The ID of the action to update. + * @param {string} searchParams.user - The user ID of the action's author. + * @returns {Promise} A promise that resolves to the deleted action document as a plain object, or null if no document was found. 
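// Illustrative usage sketch of the new Action helpers in this file (the IDs, domain, and
// `req` object are hypothetical; requires use the project's `~` alias):
const { updateAction, getActions, deleteAction } = require('~/models/Action');

async function exampleActionLifecycle(req) {
  // Upsert an action scoped to the requesting user.
  const action = await updateAction(
    { action_id: 'action_abc123', user: req.user.id },
    { assistant_id: 'asst_abc123', metadata: { domain: 'api.example.com' } },
  );

  // List the user's actions; api_key and oauth_client_id/secret are stripped from metadata
  // unless the second argument (includeSensitive) is true.
  const actions = await getActions({ user: req.user.id });

  // Remove an action that is no longer referenced.
  await deleteAction({ action_id: action.action_id, user: req.user.id });
  return actions;
}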
+ */ +const deleteAction = async (searchParams) => { + return await Action.findOneAndDelete(searchParams).lean(); +}; + +module.exports = { + updateAction, + getActions, + deleteAction, +}; diff --git a/api/models/Assistant.js b/api/models/Assistant.js new file mode 100644 index 000000000..fa6192eee --- /dev/null +++ b/api/models/Assistant.js @@ -0,0 +1,47 @@ +const mongoose = require('mongoose'); +const assistantSchema = require('./schema/assistant'); + +const Assistant = mongoose.model('assistant', assistantSchema); + +/** + * Update an assistant with new data without overwriting existing properties, + * or create a new assistant if it doesn't exist. + * + * @param {Object} searchParams - The search parameters to find the assistant to update. + * @param {string} searchParams.assistant_id - The ID of the assistant to update. + * @param {string} searchParams.user - The user ID of the assistant's author. + * @param {Object} updateData - An object containing the properties to update. + * @returns {Promise} The updated or newly created assistant document as a plain object. + */ +const updateAssistant = async (searchParams, updateData) => { + return await Assistant.findOneAndUpdate(searchParams, updateData, { + new: true, + upsert: true, + }).lean(); +}; + +/** + * Retrieves an assistant document based on the provided ID. + * + * @param {Object} searchParams - The search parameters to find the assistant to update. + * @param {string} searchParams.assistant_id - The ID of the assistant to update. + * @param {string} searchParams.user - The user ID of the assistant's author. + * @returns {Promise} The assistant document as a plain object, or null if not found. + */ +const getAssistant = async (searchParams) => await Assistant.findOne(searchParams).lean(); + +/** + * Retrieves all assistants that match the given search parameters. + * + * @param {Object} searchParams - The search parameters to find matching assistants. + * @returns {Promise>} A promise that resolves to an array of action documents as plain objects. + */ +const getAssistants = async (searchParams) => { + return await Assistant.find(searchParams).lean(); +}; + +module.exports = { + updateAssistant, + getAssistants, + getAssistant, +}; diff --git a/api/models/File.js b/api/models/File.js index 4c353fd70..fa14af3b2 100644 --- a/api/models/File.js +++ b/api/models/File.js @@ -14,24 +14,32 @@ const findFileById = async (file_id, options = {}) => { }; /** - * Retrieves files matching a given filter. + * Retrieves files matching a given filter, sorted by the most recently updated. * @param {Object} filter - The filter criteria to apply. + * @param {Object} [_sortOptions] - Optional sort parameters. * @returns {Promise>} A promise that resolves to an array of file documents. */ -const getFiles = async (filter) => { - return await File.find(filter).lean(); +const getFiles = async (filter, _sortOptions) => { + const sortOptions = { updatedAt: -1, ..._sortOptions }; + return await File.find(filter).sort(sortOptions).lean(); }; /** * Creates a new file with a TTL of 1 hour. * @param {MongoFile} data - The file data to be created, must contain file_id. + * @param {boolean} disableTTL - Whether to disable the TTL. * @returns {Promise} A promise that resolves to the created file document. 
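// Illustrative usage sketch of the updated File helpers (fileData is assumed to already
// contain the required MongoFile fields such as file_id, filepath, filename, bytes, and type):
const { createFile, getFiles } = require('~/models/File');

async function exampleFileUsage(userId, fileData) {
  // Passing `true` as the second argument disables the default 1-hour TTL (expiresAt).
  const saved = await createFile({ ...fileData, user: userId }, true);

  // Files come back sorted by most recently updated first unless a sort override is given.
  const recent = await getFiles({ user: userId });
  return { saved, recent };
}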
*/ -const createFile = async (data) => { +const createFile = async (data, disableTTL) => { const fileData = { ...data, expiresAt: new Date(Date.now() + 3600 * 1000), }; + + if (disableTTL) { + delete fileData.expiresAt; + } + return await File.findOneAndUpdate({ file_id: data.file_id }, fileData, { new: true, upsert: true, @@ -75,6 +83,15 @@ const deleteFile = async (file_id) => { return await File.findOneAndDelete({ file_id }).lean(); }; +/** + * Deletes a file identified by a filter. + * @param {object} filter - The filter criteria to apply. + * @returns {Promise} A promise that resolves to the deleted file document or null. + */ +const deleteFileByFilter = async (filter) => { + return await File.findOneAndDelete(filter).lean(); +}; + /** * Deletes multiple files identified by an array of file_ids. * @param {Array} file_ids - The unique identifiers of the files to delete. @@ -93,4 +110,5 @@ module.exports = { updateFileUsage, deleteFile, deleteFiles, + deleteFileByFilter, }; diff --git a/api/models/Message.js b/api/models/Message.js index fe615f328..a8e1acdf1 100644 --- a/api/models/Message.js +++ b/api/models/Message.js @@ -72,11 +72,49 @@ module.exports = { throw new Error('Failed to save message.'); } }, + /** + * Records a message in the database. + * + * @async + * @function recordMessage + * @param {Object} params - The message data object. + * @param {string} params.user - The identifier of the user. + * @param {string} params.endpoint - The endpoint where the message originated. + * @param {string} params.messageId - The unique identifier for the message. + * @param {string} params.conversationId - The identifier of the conversation. + * @param {string} [params.parentMessageId] - The identifier of the parent message, if any. + * @param {Partial} rest - Any additional properties from the TMessage typedef not explicitly listed. + * @returns {Promise} The updated or newly inserted message document. + * @throws {Error} If there is an error in saving the message. 
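// Illustrative usage sketch of recordMessage (IDs and text are hypothetical): it upserts by
// { user, messageId }, so assistant-run messages that carry a thread_id can be stored without
// conversationId parsing.
const { recordMessage } = require('~/models');

async function exampleRecord(req) {
  return await recordMessage({
    user: req.user.id,
    endpoint: 'assistants',
    messageId: 'msg_abc123',
    conversationId: 'convo-uuid-here',
    parentMessageId: 'parent-uuid-here',
    thread_id: 'thread_abc123',
    text: 'Response text recorded from the run.',
  });
}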
+ */ + async recordMessage({ user, endpoint, messageId, conversationId, parentMessageId, ...rest }) { + try { + // No parsing of convoId as may use threadId + const message = { + user, + endpoint, + messageId, + conversationId, + parentMessageId, + ...rest, + }; + + return await Message.findOneAndUpdate({ user, messageId }, message, { + upsert: true, + new: true, + }); + } catch (err) { + logger.error('Error saving message:', err); + throw new Error('Failed to save message.'); + } + }, async updateMessage(message) { try { const { messageId, ...update } = message; update.isEdited = true; - const updatedMessage = await Message.findOneAndUpdate({ messageId }, update, { new: true }); + const updatedMessage = await Message.findOneAndUpdate({ messageId }, update, { + new: true, + }); if (!updatedMessage) { throw new Error('Message not found.'); diff --git a/api/models/index.js b/api/models/index.js index 1fa751354..f1b51d5ef 100644 --- a/api/models/index.js +++ b/api/models/index.js @@ -1,12 +1,14 @@ const { getMessages, saveMessage, + recordMessage, updateMessage, deleteMessagesSince, deleteMessages, } = require('./Message'); const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation'); const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset'); +const { hashPassword, getUser, updateUser } = require('./userMethods'); const { findFileById, createFile, @@ -29,8 +31,13 @@ module.exports = { Balance, Transaction, + hashPassword, + updateUser, + getUser, + getMessages, saveMessage, + recordMessage, updateMessage, deleteMessagesSince, deleteMessages, diff --git a/api/models/plugins/mongoMeili.js b/api/models/plugins/mongoMeili.js index abba84861..79dd30b11 100644 --- a/api/models/plugins/mongoMeili.js +++ b/api/models/plugins/mongoMeili.js @@ -183,6 +183,15 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { if (object.conversationId && object.conversationId.includes('|')) { object.conversationId = object.conversationId.replace(/\|/g, '--'); } + + if (object.content && Array.isArray(object.content)) { + object.text = object.content + .filter((item) => item.type === 'text' && item.text && item.text.value) + .map((item) => item.text.value) + .join(' '); + delete object.content; + } + return object; } diff --git a/api/models/schema/action.js b/api/models/schema/action.js new file mode 100644 index 000000000..fdafd2ec2 --- /dev/null +++ b/api/models/schema/action.js @@ -0,0 +1,60 @@ +const mongoose = require('mongoose'); + +const { Schema } = mongoose; + +const AuthSchema = new Schema( + { + authorization_type: String, + custom_auth_header: String, + type: { + type: String, + enum: ['service_http', 'oauth', 'none'], + }, + authorization_content_type: String, + authorization_url: String, + client_url: String, + scope: String, + token_exchange_method: { + type: String, + enum: ['default_post', 'basic_auth_header', null], + }, + }, + { _id: false }, +); + +const actionSchema = new Schema({ + user: { + type: mongoose.Schema.Types.ObjectId, + ref: 'User', + index: true, + required: true, + }, + action_id: { + type: String, + index: true, + required: true, + }, + type: { + type: String, + default: 'action_prototype', + }, + settings: Schema.Types.Mixed, + assistant_id: String, + metadata: { + api_key: String, // private, encrypted + auth: AuthSchema, + domain: { + type: String, + unique: true, + required: true, + }, + // json_schema: Schema.Types.Mixed, + privacy_policy_url: String, + raw_spec: String, + oauth_client_id: String, // 
private, encrypted + oauth_client_secret: String, // private, encrypted + }, +}); +// }, { minimize: false }); // Prevent removal of empty objects + +module.exports = actionSchema; diff --git a/api/models/schema/assistant.js b/api/models/schema/assistant.js new file mode 100644 index 000000000..a4ec36e19 --- /dev/null +++ b/api/models/schema/assistant.js @@ -0,0 +1,34 @@ +const mongoose = require('mongoose'); + +const assistantSchema = mongoose.Schema( + { + user: { + type: mongoose.Schema.Types.ObjectId, + ref: 'User', + required: true, + }, + assistant_id: { + type: String, + unique: true, + index: true, + required: true, + }, + avatar: { + type: { + filepath: String, + source: String, + }, + default: undefined, + }, + access_level: { + type: Number, + }, + file_ids: { type: [String], default: undefined }, + actions: { type: [String], default: undefined }, + }, + { + timestamps: true, + }, +); + +module.exports = assistantSchema; diff --git a/api/models/schema/defaults.js b/api/models/schema/defaults.js index 39a6430f4..fc0add4e0 100644 --- a/api/models/schema/defaults.js +++ b/api/models/schema/defaults.js @@ -11,152 +11,133 @@ const conversationPreset = { // for azureOpenAI, openAI, chatGPTBrowser only model: { type: String, - // default: null, required: false, }, // for azureOpenAI, openAI only chatGptLabel: { type: String, - // default: null, required: false, }, // for google only modelLabel: { type: String, - // default: null, required: false, }, promptPrefix: { type: String, - // default: null, required: false, }, temperature: { type: Number, - // default: 1, required: false, }, top_p: { type: Number, - // default: 1, required: false, }, // for google only topP: { type: Number, - // default: 0.95, required: false, }, topK: { type: Number, - // default: 40, required: false, }, maxOutputTokens: { type: Number, - // default: 1024, required: false, }, presence_penalty: { type: Number, - // default: 0, required: false, }, frequency_penalty: { type: Number, - // default: 0, required: false, }, // for bingai only jailbreak: { type: Boolean, - // default: false, }, context: { type: String, - // default: null, }, systemMessage: { type: String, - // default: null, }, toneStyle: { type: String, - // default: null, }, + file_ids: { type: [{ type: String }], default: undefined }, + // vision resendImages: { type: Boolean, }, imageDetail: { type: String, }, + /* assistants */ + assistant_id: { + type: String, + }, + instructions: { + type: String, + }, }; const agentOptions = { model: { type: String, - // default: null, required: false, }, // for azureOpenAI, openAI only chatGptLabel: { type: String, - // default: null, required: false, }, modelLabel: { type: String, - // default: null, required: false, }, promptPrefix: { type: String, - // default: null, required: false, }, temperature: { type: Number, - // default: 1, required: false, }, top_p: { type: Number, - // default: 1, required: false, }, // for google only topP: { type: Number, - // default: 0.95, required: false, }, topK: { type: Number, - // default: 40, required: false, }, maxOutputTokens: { type: Number, - // default: 1024, required: false, }, presence_penalty: { type: Number, - // default: 0, required: false, }, frequency_penalty: { type: Number, - // default: 0, required: false, }, context: { type: String, - // default: null, }, systemMessage: { type: String, - // default: null, }, }; diff --git a/api/models/schema/fileSchema.js b/api/models/schema/fileSchema.js index 471b7bfd7..e470a8d7e 100644 --- 
a/api/models/schema/fileSchema.js +++ b/api/models/schema/fileSchema.js @@ -3,6 +3,8 @@ const mongoose = require('mongoose'); /** * @typedef {Object} MongoFile + * @property {mongoose.Schema.Types.ObjectId} [_id] - MongoDB Document ID + * @property {number} [__v] - MongoDB Version Key * @property {mongoose.Schema.Types.ObjectId} user - User ID * @property {string} [conversationId] - Optional conversation ID * @property {string} file_id - File identifier @@ -17,6 +19,8 @@ const mongoose = require('mongoose'); * @property {number} [width] - Optional width of the file * @property {number} [height] - Optional height of the file * @property {Date} [expiresAt] - Optional height of the file + * @property {Date} [createdAt] - Date when the file was created + * @property {Date} [updatedAt] - Date when the file was updated */ const fileSchema = mongoose.Schema( { @@ -61,6 +65,10 @@ const fileSchema = mongoose.Schema( type: String, required: true, }, + context: { + type: String, + // required: true, + }, usage: { type: Number, required: true, diff --git a/api/models/schema/messageSchema.js b/api/models/schema/messageSchema.js index 06da19e47..fc745499f 100644 --- a/api/models/schema/messageSchema.js +++ b/api/models/schema/messageSchema.js @@ -17,6 +17,7 @@ const messageSchema = mongoose.Schema( user: { type: String, index: true, + required: true, default: null, }, model: { @@ -46,12 +47,10 @@ const messageSchema = mongoose.Schema( }, sender: { type: String, - required: true, meiliIndex: true, }, text: { type: String, - required: true, meiliIndex: true, }, summary: { @@ -103,6 +102,14 @@ const messageSchema = mongoose.Schema( default: undefined, }, plugins: { type: [{ type: mongoose.Schema.Types.Mixed }], default: undefined }, + content: { + type: [{ type: mongoose.Schema.Types.Mixed }], + default: undefined, + meiliIndex: true, + }, + thread_id: { + type: String, + }, }, { timestamps: true }, ); diff --git a/api/models/spendTokens.js b/api/models/spendTokens.js index 3687d5512..f52075b13 100644 --- a/api/models/spendTokens.js +++ b/api/models/spendTokens.js @@ -21,6 +21,10 @@ const { logger } = require('~/config'); */ const spendTokens = async (txData, tokenUsage) => { const { promptTokens, completionTokens } = tokenUsage; + logger.debug(`[spendTokens] conversationId: ${txData.conversationId} | Token usage: `, { + promptTokens, + completionTokens, + }); let prompt, completion; try { if (promptTokens >= 0) { @@ -42,7 +46,12 @@ const spendTokens = async (txData, tokenUsage) => { rawAmount: -completionTokens, }); - logger.debug('[spendTokens] post-transaction', { prompt, completion }); + prompt && + completion && + logger.debug('[spendTokens] Transaction data record against balance:', { + prompt, + completion, + }); } catch (err) { logger.error('[spendTokens]', err); } diff --git a/api/models/userMethods.js b/api/models/userMethods.js new file mode 100644 index 000000000..c1ccce5b5 --- /dev/null +++ b/api/models/userMethods.js @@ -0,0 +1,46 @@ +const bcrypt = require('bcryptjs'); +const User = require('./User'); + +const hashPassword = async (password) => { + const hashedPassword = await new Promise((resolve, reject) => { + bcrypt.hash(password, 10, function (err, hash) { + if (err) { + reject(err); + } else { + resolve(hash); + } + }); + }); + + return hashedPassword; +}; + +/** + * Retrieve a user by ID and convert the found user document to a plain object. + * + * @param {string} userId - The ID of the user to find and return as a plain object. 
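// Illustrative usage sketch of the new user helpers (userId and the new password value are
// hypothetical): hash a plaintext password, then persist it on the user document.
const { hashPassword, getUser, updateUser } = require('~/models');

async function exampleUserUpdate(userId, newPassword) {
  const user = await getUser(userId); // plain object, or null if not found
  if (!user) {
    return null;
  }
  const password = await hashPassword(newPassword);
  // Returns the updated user as a plain object (lean query), with validators applied.
  return await updateUser(userId, { password });
}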
+ * @returns {Promise} A plain object representing the user document, or `null` if no user is found. + */ +const getUser = async function (userId) { + return await User.findById(userId).lean(); +}; + +/** + * Update a user with new data without overwriting existing properties. + * + * @param {string} userId - The ID of the user to update. + * @param {Object} updateData - An object containing the properties to update. + * @returns {Promise} The updated user document as a plain object, or `null` if no user is found. + */ +const updateUser = async function (userId, updateData) { + return await User.findByIdAndUpdate(userId, updateData, { + new: true, + runValidators: true, + }).lean(); +}; + +module.exports = { + hashPassword, + updateUser, + getUser, +}; diff --git a/api/package.json b/api/package.json index e35c01e64..1119bff05 100644 --- a/api/package.json +++ b/api/package.json @@ -31,6 +31,7 @@ "@azure/search-documents": "^12.0.0", "@keyv/mongo": "^2.1.8", "@keyv/redis": "^2.8.1", + "@langchain/community": "^0.0.17", "@langchain/google-genai": "^0.0.8", "axios": "^1.3.4", "bcryptjs": "^2.4.3", @@ -44,6 +45,7 @@ "express-mongo-sanitize": "^2.2.0", "express-rate-limit": "^6.9.0", "express-session": "^1.17.3", + "file-type": "^18.7.0", "firebase": "^10.6.0", "googleapis": "^126.0.1", "handlebars": "^4.7.7", @@ -58,6 +60,7 @@ "librechat-data-provider": "*", "lodash": "^4.17.21", "meilisearch": "^0.33.0", + "mime": "^3.0.0", "module-alias": "^2.2.3", "mongoose": "^7.1.1", "multer": "^1.4.5-lts.1", diff --git a/api/server/controllers/AskController.js b/api/server/controllers/AskController.js index 67d7c67e9..52ed0d30f 100644 --- a/api/server/controllers/AskController.js +++ b/api/server/controllers/AskController.js @@ -1,4 +1,4 @@ -const { getResponseSender } = require('librechat-data-provider'); +const { getResponseSender, Constants } = require('librechat-data-provider'); const { sendMessage, createOnProgress } = require('~/server/utils'); const { saveMessage, getConvoTitle, getConvo } = require('~/models'); const { createAbortController, handleAbortError } = require('~/server/middleware'); @@ -140,7 +140,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { await saveMessage(userMessage); - if (addTitle && parentMessageId === '00000000-0000-0000-0000-000000000000' && newConvo) { + if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) { addTitle(req, { text, response, diff --git a/api/server/controllers/EndpointController.js b/api/server/controllers/EndpointController.js index 5069bb33e..3a0db0222 100644 --- a/api/server/controllers/EndpointController.js +++ b/api/server/controllers/EndpointController.js @@ -1,4 +1,4 @@ -const { CacheKeys } = require('librechat-data-provider'); +const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); const { loadDefaultEndpointsConfig, loadConfigEndpoints } = require('~/server/services/Config'); const { getLogStores } = require('~/cache'); @@ -14,6 +14,10 @@ async function endpointController(req, res) { const customConfigEndpoints = await loadConfigEndpoints(); const endpointsConfig = { ...defaultEndpointsConfig, ...customConfigEndpoints }; + if (endpointsConfig[EModelEndpoint.assistants] && req.app.locals?.[EModelEndpoint.assistants]) { + endpointsConfig[EModelEndpoint.assistants].disableBuilder = + req.app.locals[EModelEndpoint.assistants].disableBuilder; + } await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig); res.send(JSON.stringify(endpointsConfig)); diff --git 
a/api/server/controllers/PluginController.js b/api/server/controllers/PluginController.js index b3c4d31ae..803d89923 100644 --- a/api/server/controllers/PluginController.js +++ b/api/server/controllers/PluginController.js @@ -1,4 +1,3 @@ -const path = require('path'); const { promises: fs } = require('fs'); const { CacheKeys } = require('librechat-data-provider'); const { addOpenAPISpecs } = require('~/app/clients/tools/util/addOpenAPISpecs'); @@ -56,12 +55,10 @@ const getAvailablePluginsController = async (req, res) => { return; } - const manifestFile = await fs.readFile( - path.join(__dirname, '..', '..', 'app', 'clients', 'tools', 'manifest.json'), - 'utf8', - ); + const pluginManifest = await fs.readFile(req.app.locals.paths.pluginManifest, 'utf8'); - const jsonData = JSON.parse(manifestFile); + const jsonData = JSON.parse(pluginManifest); + /** @type {TPlugin[]} */ const uniquePlugins = filterUniquePlugins(jsonData); const authenticatedPlugins = uniquePlugins.map((plugin) => { if (isPluginAuthenticated(plugin)) { @@ -78,6 +75,53 @@ const getAvailablePluginsController = async (req, res) => { } }; +/** + * Retrieves and returns a list of available tools, either from a cache or by reading a plugin manifest file. + * + * This function first attempts to retrieve the list of tools from a cache. If the tools are not found in the cache, + * it reads a plugin manifest file, filters for unique plugins, and determines if each plugin is authenticated. + * Only plugins that are marked as available in the application's local state are included in the final list. + * The resulting list of tools is then cached and sent to the client. + * + * @param {object} req - The request object, containing information about the HTTP request. + * @param {object} res - The response object, used to send back the desired HTTP response. + * @returns {Promise} A promise that resolves when the function has completed. 
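// Illustrative wiring sketch (the route path is hypothetical; assumes app.locals.paths and
// app.locals.availableTools are populated at startup as done elsewhere in this changeset):
const express = require('express');
const { getAvailableTools } = require('~/server/controllers/PluginController');

const router = express.Router();
router.get('/tools', getAvailableTools);

module.exports = router;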
+ */ +const getAvailableTools = async (req, res) => { + try { + const cache = getLogStores(CacheKeys.CONFIG_STORE); + const cachedTools = await cache.get(CacheKeys.TOOLS); + if (cachedTools) { + res.status(200).json(cachedTools); + return; + } + + const pluginManifest = await fs.readFile(req.app.locals.paths.pluginManifest, 'utf8'); + + const jsonData = JSON.parse(pluginManifest); + /** @type {TPlugin[]} */ + const uniquePlugins = filterUniquePlugins(jsonData); + + const authenticatedPlugins = uniquePlugins.map((plugin) => { + if (isPluginAuthenticated(plugin)) { + return { ...plugin, authenticated: true }; + } else { + return plugin; + } + }); + + const tools = authenticatedPlugins.filter( + (plugin) => req.app.locals.availableTools[plugin.pluginKey] !== undefined, + ); + + await cache.set(CacheKeys.TOOLS, tools); + res.status(200).json(tools); + } catch (error) { + res.status(500).json({ message: error.message }); + } +}; + module.exports = { + getAvailableTools, getAvailablePluginsController, }; diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index fa08cd545..ac20ca627 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -8,16 +8,19 @@ const getUserController = async (req, res) => { const updateUserPluginsController = async (req, res) => { const { user } = req; - const { pluginKey, action, auth } = req.body; + const { pluginKey, action, auth, isAssistantTool } = req.body; let authService; try { - const userPluginsService = await updateUserPluginsService(user, pluginKey, action); + if (!isAssistantTool) { + const userPluginsService = await updateUserPluginsService(user, pluginKey, action); - if (userPluginsService instanceof Error) { - logger.error('[userPluginsService]', userPluginsService); - const { status, message } = userPluginsService; - res.status(status).send({ message }); + if (userPluginsService instanceof Error) { + logger.error('[userPluginsService]', userPluginsService); + const { status, message } = userPluginsService; + res.status(status).send({ message }); + } } + if (auth) { const keys = Object.keys(auth); const values = Object.values(auth); diff --git a/api/server/index.js b/api/server/index.js index c08415e8f..5d35434bd 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -76,7 +76,7 @@ const startServer = async () => { app.use('/api/plugins', routes.plugins); app.use('/api/config', routes.config); app.use('/api/assistants', routes.assistants); - app.use('/api/files', routes.files); + app.use('/api/files', await routes.files.initialize()); app.use((req, res) => { res.status(404).sendFile(path.join(app.locals.paths.dist, 'index.html')); diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index cc9b9fc05..08f614286 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -1,18 +1,24 @@ +const { EModelEndpoint } = require('librechat-data-provider'); const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils'); const { saveMessage, getConvo, getConvoTitle } = require('~/models'); const clearPendingReq = require('~/cache/clearPendingReq'); const abortControllers = require('./abortControllers'); const { redactMessage } = require('~/config/parsers'); const spendTokens = require('~/models/spendTokens'); +const { abortRun } = require('./abortRun'); const { logger } = require('~/config'); async function abortMessage(req, res) { - let { abortKey, 
conversationId } = req.body; + let { abortKey, conversationId, endpoint } = req.body; if (!abortKey && conversationId) { abortKey = conversationId; } + if (endpoint === EModelEndpoint.assistants) { + return await abortRun(req, res); + } + if (!abortControllers.has(abortKey) && !res.headersSent) { return res.status(204).send({ message: 'Request not found' }); } diff --git a/api/server/middleware/abortRun.js b/api/server/middleware/abortRun.js new file mode 100644 index 000000000..b93eb8c21 --- /dev/null +++ b/api/server/middleware/abortRun.js @@ -0,0 +1,87 @@ +const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider'); +const { initializeClient } = require('~/server/services/Endpoints/assistant'); +const { checkMessageGaps, recordUsage } = require('~/server/services/Threads'); +const { getConvo } = require('~/models/Conversation'); +const getLogStores = require('~/cache/getLogStores'); +const { sendMessage } = require('~/server/utils'); +// const spendTokens = require('~/models/spendTokens'); +const { logger } = require('~/config'); + +async function abortRun(req, res) { + res.setHeader('Content-Type', 'application/json'); + const { abortKey } = req.body; + const [conversationId, latestMessageId] = abortKey.split(':'); + + if (!isUUID.safeParse(conversationId).success) { + logger.error('[abortRun] Invalid conversationId', { conversationId }); + return res.status(400).send({ message: 'Invalid conversationId' }); + } + + const cacheKey = `${req.user.id}:${conversationId}`; + const cache = getLogStores(CacheKeys.ABORT_KEYS); + const runValues = await cache.get(cacheKey); + const [thread_id, run_id] = runValues.split(':'); + + if (!run_id) { + logger.warn('[abortRun] Couldn\'t find run for cancel request', { thread_id }); + return res.status(204).send({ message: 'Run not found' }); + } else if (run_id === 'cancelled') { + logger.warn('[abortRun] Run already cancelled', { thread_id }); + return res.status(204).send({ message: 'Run already cancelled' }); + } + + let runMessages = []; + /** @type {{ openai: OpenAI }} */ + const { openai } = await initializeClient({ req, res }); + + try { + await cache.set(cacheKey, 'cancelled'); + const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id); + logger.debug('Cancelled run:', cancelledRun); + } catch (error) { + logger.error('[abortRun] Error cancelling run', error); + if ( + error?.message?.includes(RunStatus.CANCELLED) || + error?.message?.includes(RunStatus.CANCELLING) + ) { + return res.end(); + } + } + + try { + const run = await openai.beta.threads.runs.retrieve(thread_id, run_id); + await recordUsage({ + ...run.usage, + model: run.model, + user: req.user.id, + conversationId, + }); + } catch (error) { + logger.error('[abortRun] Error fetching or processing run', error); + } + + runMessages = await checkMessageGaps({ + openai, + latestMessageId, + thread_id, + run_id, + conversationId, + }); + + const finalEvent = { + title: 'New Chat', + final: true, + conversation: await getConvo(req.user.id, conversationId), + runMessages, + }; + + if (res.headersSent && finalEvent) { + return sendMessage(res, finalEvent); + } + + res.json(finalEvent); +} + +module.exports = { + abortRun, +}; diff --git a/api/server/middleware/buildEndpointOption.js b/api/server/middleware/buildEndpointOption.js index 91d0cacea..40ad9dadb 100644 --- a/api/server/middleware/buildEndpointOption.js +++ b/api/server/middleware/buildEndpointOption.js @@ -5,6 +5,7 @@ const anthropic = require('~/server/services/Endpoints/anthropic'); const openAI 
= require('~/server/services/Endpoints/openAI'); const custom = require('~/server/services/Endpoints/custom'); const google = require('~/server/services/Endpoints/google'); +const assistant = require('~/server/services/Endpoints/assistant'); const buildFunction = { [EModelEndpoint.openAI]: openAI.buildOptions, @@ -13,6 +14,7 @@ const buildFunction = { [EModelEndpoint.azureOpenAI]: openAI.buildOptions, [EModelEndpoint.anthropic]: anthropic.buildOptions, [EModelEndpoint.gptPlugins]: gptPlugins.buildOptions, + [EModelEndpoint.assistants]: assistant.buildOptions, }; function buildEndpointOption(req, res, next) { diff --git a/api/server/middleware/denyRequest.js b/api/server/middleware/denyRequest.js index 8000aa2b1..37952176b 100644 --- a/api/server/middleware/denyRequest.js +++ b/api/server/middleware/denyRequest.js @@ -1,7 +1,7 @@ const crypto = require('crypto'); -const { saveMessage } = require('~/models'); +const { getResponseSender, Constants } = require('librechat-data-provider'); const { sendMessage, sendError } = require('~/server/utils'); -const { getResponseSender } = require('librechat-data-provider'); +const { saveMessage } = require('~/models'); /** * Denies a request by sending an error message and optionally saves the user's message. @@ -38,8 +38,7 @@ const denyRequest = async (req, res, errorMessage) => { }; sendMessage(res, { message: userMessage, created: true }); - const shouldSaveMessage = - _convoId && parentMessageId && parentMessageId !== '00000000-0000-0000-0000-000000000000'; + const shouldSaveMessage = _convoId && parentMessageId && parentMessageId !== Constants.NO_PARENT; if (shouldSaveMessage) { await saveMessage({ ...userMessage, user: req.user.id }); diff --git a/api/server/middleware/index.js b/api/server/middleware/index.js index 77afd9716..5b257c9a4 100644 --- a/api/server/middleware/index.js +++ b/api/server/middleware/index.js @@ -4,6 +4,7 @@ const uaParser = require('./uaParser'); const setHeaders = require('./setHeaders'); const loginLimiter = require('./loginLimiter'); const requireJwtAuth = require('./requireJwtAuth'); +const uploadLimiters = require('./uploadLimiters'); const registerLimiter = require('./registerLimiter'); const messageLimiters = require('./messageLimiters'); const requireLocalAuth = require('./requireLocalAuth'); @@ -16,6 +17,7 @@ const moderateText = require('./moderateText'); const noIndex = require('./noIndex'); module.exports = { + ...uploadLimiters, ...abortMiddleware, ...messageLimiters, checkBan, diff --git a/api/server/middleware/uploadLimiters.js b/api/server/middleware/uploadLimiters.js new file mode 100644 index 000000000..80544a580 --- /dev/null +++ b/api/server/middleware/uploadLimiters.js @@ -0,0 +1,75 @@ +const rateLimit = require('express-rate-limit'); +const { CacheKeys } = require('librechat-data-provider'); +const logViolation = require('~/cache/logViolation'); + +const getEnvironmentVariables = () => { + const FILE_UPLOAD_IP_MAX = parseInt(process.env.FILE_UPLOAD_IP_MAX) || 100; + const FILE_UPLOAD_IP_WINDOW = parseInt(process.env.FILE_UPLOAD_IP_WINDOW) || 15; + const FILE_UPLOAD_USER_MAX = parseInt(process.env.FILE_UPLOAD_USER_MAX) || 50; + const FILE_UPLOAD_USER_WINDOW = parseInt(process.env.FILE_UPLOAD_USER_WINDOW) || 15; + + const fileUploadIpWindowMs = FILE_UPLOAD_IP_WINDOW * 60 * 1000; + const fileUploadIpMax = FILE_UPLOAD_IP_MAX; + const fileUploadIpWindowInMinutes = fileUploadIpWindowMs / 60000; + + const fileUploadUserWindowMs = FILE_UPLOAD_USER_WINDOW * 60 * 1000; + const fileUploadUserMax = 
FILE_UPLOAD_USER_MAX; + const fileUploadUserWindowInMinutes = fileUploadUserWindowMs / 60000; + + return { + fileUploadIpWindowMs, + fileUploadIpMax, + fileUploadIpWindowInMinutes, + fileUploadUserWindowMs, + fileUploadUserMax, + fileUploadUserWindowInMinutes, + }; +}; + +const createFileUploadHandler = (ip = true) => { + const { + fileUploadIpMax, + fileUploadIpWindowInMinutes, + fileUploadUserMax, + fileUploadUserWindowInMinutes, + } = getEnvironmentVariables(); + + return async (req, res) => { + const type = CacheKeys.FILE_UPLOAD_LIMIT; + const errorMessage = { + type, + max: ip ? fileUploadIpMax : fileUploadUserMax, + limiter: ip ? 'ip' : 'user', + windowInMinutes: ip ? fileUploadIpWindowInMinutes : fileUploadUserWindowInMinutes, + }; + + await logViolation(req, res, type, errorMessage); + res.status(429).json({ message: 'Too many file upload requests. Try again later' }); + }; +}; + +const createFileLimiters = () => { + const { fileUploadIpWindowMs, fileUploadIpMax, fileUploadUserWindowMs, fileUploadUserMax } = + getEnvironmentVariables(); + + const fileUploadIpLimiter = rateLimit({ + windowMs: fileUploadIpWindowMs, + max: fileUploadIpMax, + handler: createFileUploadHandler(), + }); + + const fileUploadUserLimiter = rateLimit({ + windowMs: fileUploadUserWindowMs, + max: fileUploadUserMax, + handler: createFileUploadHandler(false), + keyGenerator: function (req) { + return req.user?.id; // Use the user ID or NULL if not available + }, + }); + + return { fileUploadIpLimiter, fileUploadUserLimiter }; +}; + +module.exports = { + createFileLimiters, +}; diff --git a/api/server/routes/ask/askChatGPTBrowser.js b/api/server/routes/ask/askChatGPTBrowser.js index 34f1096a8..4ce1770b8 100644 --- a/api/server/routes/ask/askChatGPTBrowser.js +++ b/api/server/routes/ask/askChatGPTBrowser.js @@ -1,5 +1,6 @@ const crypto = require('crypto'); const express = require('express'); +const { Constants } = require('librechat-data-provider'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('~/models'); const { handleError, sendMessage, createOnProgress, handleText } = require('~/server/utils'); const { setHeaders } = require('~/server/middleware'); @@ -27,7 +28,7 @@ router.post('/', setHeaders, async (req, res) => { const conversationId = oldConversationId || crypto.randomUUID(); const isNewConversation = !oldConversationId; const userMessageId = crypto.randomUUID(); - const userParentMessageId = parentMessageId || '00000000-0000-0000-0000-000000000000'; + const userParentMessageId = parentMessageId || Constants.NO_PARENT; const userMessage = { messageId: userMessageId, sender: 'User', @@ -209,7 +210,7 @@ const ask = async ({ }); res.end(); - if (userParentMessageId == '00000000-0000-0000-0000-000000000000') { + if (userParentMessageId == Constants.NO_PARENT) { // const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage }); const title = await response.details.title; await saveConvo(user, { diff --git a/api/server/routes/ask/bingAI.js b/api/server/routes/ask/bingAI.js index 1281b56ae..916cda4b1 100644 --- a/api/server/routes/ask/bingAI.js +++ b/api/server/routes/ask/bingAI.js @@ -1,5 +1,6 @@ -const express = require('express'); const crypto = require('crypto'); +const express = require('express'); +const { Constants } = require('librechat-data-provider'); const { handleError, sendMessage, createOnProgress, handleText } = require('~/server/utils'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('~/models'); const { setHeaders 
} = require('~/server/middleware'); @@ -28,7 +29,7 @@ router.post('/', setHeaders, async (req, res) => { const conversationId = oldConversationId || crypto.randomUUID(); const isNewConversation = !oldConversationId; const userMessageId = messageId; - const userParentMessageId = parentMessageId || '00000000-0000-0000-0000-000000000000'; + const userParentMessageId = parentMessageId || Constants.NO_PARENT; let userMessage = { messageId: userMessageId, sender: 'User', @@ -238,7 +239,7 @@ const ask = async ({ }); res.end(); - if (userParentMessageId == '00000000-0000-0000-0000-000000000000') { + if (userParentMessageId == Constants.NO_PARENT) { const title = await titleConvoBing({ text, response: responseMessage, diff --git a/api/server/routes/ask/gptPlugins.js b/api/server/routes/ask/gptPlugins.js index 85616cd1b..80817e5a4 100644 --- a/api/server/routes/ask/gptPlugins.js +++ b/api/server/routes/ask/gptPlugins.js @@ -1,6 +1,6 @@ const express = require('express'); const router = express.Router(); -const { getResponseSender } = require('librechat-data-provider'); +const { getResponseSender, Constants } = require('librechat-data-provider'); const { validateTools } = require('~/app'); const { addTitle } = require('~/server/services/Endpoints/openAI'); const { initializeClient } = require('~/server/services/Endpoints/gptPlugins'); @@ -204,7 +204,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, }); res.end(); - if (parentMessageId === '00000000-0000-0000-0000-000000000000' && newConvo) { + if (parentMessageId === Constants.NO_PARENT && newConvo) { addTitle(req, { text, response, diff --git a/api/server/routes/assistants/actions.js b/api/server/routes/assistants/actions.js new file mode 100644 index 000000000..9a10be9f0 --- /dev/null +++ b/api/server/routes/assistants/actions.js @@ -0,0 +1,201 @@ +const { v4 } = require('uuid'); +const express = require('express'); +const { actionDelimiter } = require('librechat-data-provider'); +const { initializeClient } = require('~/server/services/Endpoints/assistant'); +const { updateAction, getActions, deleteAction } = require('~/models/Action'); +const { updateAssistant, getAssistant } = require('~/models/Assistant'); +const { encryptMetadata } = require('~/server/services/ActionService'); +const { logger } = require('~/config'); + +const router = express.Router(); + +/** + * Retrieves all user's actions + * @route GET /actions/ + * @param {string} req.params.id - Assistant identifier. + * @returns {Action[]} 200 - success response - application/json + */ +router.get('/', async (req, res) => { + try { + res.json(await getActions({ user: req.user.id })); + } catch (error) { + res.status(500).json({ error: error.message }); + } +}); + +/** + * Adds or updates actions for a specific assistant. + * @route POST /actions/:assistant_id + * @param {string} req.params.assistant_id - The ID of the assistant. + * @param {FunctionTool[]} req.body.functions - The functions to be added or updated. + * @param {string} [req.body.action_id] - Optional ID for the action. + * @param {ActionMetadata} req.body.metadata - Metadata for the action. 
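// For orientation, a hypothetical request body for the add/update action route,
// with field names taken from the handler that follows; treat this as a sketch,
// not a schema, and note the auth enum string is an assumption.
const exampleActionPayload = {
  // Omit action_id to create a new action; pass an existing id to update it.
  metadata: {
    domain: 'api.example.com', // required; requests without a domain are rejected
    auth: { type: 'service_http' }, // exact AuthTypeEnum value is an assumption
    api_key: 'example-key', // encrypted server-side via encryptMetadata
  },
  functions: [
    {
      type: 'function',
      function: {
        name: 'getWeather',
        description: 'Example function exposed by this action set',
        parameters: {
          type: 'object',
          properties: { city: { type: 'string' } },
          required: ['city'],
        },
      },
    },
  ],
};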
+ * @returns {Object} 200 - success response - application/json + */ +router.post('/:assistant_id', async (req, res) => { + try { + const { assistant_id } = req.params; + + /** @type {{ functions: FunctionTool[], action_id: string, metadata: ActionMetadata }} */ + const { functions, action_id: _action_id, metadata: _metadata } = req.body; + if (!functions.length) { + return res.status(400).json({ message: 'No functions provided' }); + } + + let metadata = encryptMetadata(_metadata); + + const { domain } = metadata; + if (!domain) { + return res.status(400).json({ message: 'No domain provided' }); + } + + const action_id = _action_id ?? v4(); + const initialPromises = []; + + /** @type {{ openai: OpenAI }} */ + const { openai } = await initializeClient({ req, res }); + + initialPromises.push(getAssistant({ assistant_id, user: req.user.id })); + initialPromises.push(openai.beta.assistants.retrieve(assistant_id)); + !!_action_id && initialPromises.push(getActions({ user: req.user.id, action_id }, true)); + + /** @type {[AssistantDocument, Assistant, [Action|undefined]]} */ + const [assistant_data, assistant, actions_result] = await Promise.all(initialPromises); + + if (actions_result && actions_result.length) { + const action = actions_result[0]; + metadata = { ...action.metadata, ...metadata }; + } + + if (!assistant) { + return res.status(404).json({ message: 'Assistant not found' }); + } + + const { actions: _actions = [] } = assistant_data ?? {}; + const actions = []; + for (const action of _actions) { + const [action_domain, current_action_id] = action.split(actionDelimiter); + if (action_domain === domain && !_action_id) { + // TODO: dupe check on the frontend + return res.status(400).json({ + message: `Action sets cannot have duplicate domains - ${domain} already exists on another action`, + }); + } + + if (current_action_id === action_id) { + continue; + } + + actions.push(action); + } + + actions.push(`${domain}${actionDelimiter}${action_id}`); + + /** @type {{ tools: FunctionTool[] | { type: 'code_interpreter'|'retrieval'}[]}} */ + const { tools: _tools = [] } = assistant; + + const tools = _tools + .filter( + (tool) => + !( + tool.function && + (tool.function.name.includes(domain) || tool.function.name.includes(action_id)) + ), + ) + .concat( + functions.map((tool) => ({ + ...tool, + function: { + ...tool.function, + name: `${tool.function.name}${actionDelimiter}${domain}`, + }, + })), + ); + + const promises = []; + promises.push( + updateAssistant( + { assistant_id, user: req.user.id }, + { + actions, + }, + ), + ); + promises.push(openai.beta.assistants.update(assistant_id, { tools })); + promises.push(updateAction({ action_id, user: req.user.id }, { metadata, assistant_id })); + + /** @type {[AssistantDocument, Assistant, Action]} */ + const resolved = await Promise.all(promises); + const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret']; + for (let field of sensitiveFields) { + if (resolved[2].metadata[field]) { + delete resolved[2].metadata[field]; + } + } + res.json(resolved); + } catch (error) { + const message = 'Trouble updating the Assistant Action'; + logger.error(message, error); + res.status(500).json({ message }); + } +}); + +/** + * Deletes an action for a specific assistant. + * @route DELETE /actions/:assistant_id/:action_id + * @param {string} req.params.assistant_id - The ID of the assistant. + * @param {string} req.params.action_id - The ID of the action to delete. 
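// The handlers above and below rely on one naming convention: an action's
// functions are stored on the assistant as `<name><actionDelimiter><domain>`,
// and the assistant document tracks `<domain><actionDelimiter><action_id>`,
// so everything belonging to a domain can be replaced or removed as a group.
// Sketch only; the delimiter value shown is a placeholder, not the real constant.
const actionDelimiter = '_action_';
const domain = 'api.example.com';

const storedToolName = ['getWeather', domain].join(actionDelimiter);
// -> 'getWeather_action_api.example.com'

const [functionName, toolDomain] = storedToolName.split(actionDelimiter);
// -> ['getWeather', 'api.example.com']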
+ * @returns {Object} 200 - success response - application/json + */ +router.delete('/:assistant_id/:action_id', async (req, res) => { + try { + const { assistant_id, action_id } = req.params; + + /** @type {{ openai: OpenAI }} */ + const { openai } = await initializeClient({ req, res }); + + const initialPromises = []; + initialPromises.push(getAssistant({ assistant_id, user: req.user.id })); + initialPromises.push(openai.beta.assistants.retrieve(assistant_id)); + + /** @type {[AssistantDocument, Assistant]} */ + const [assistant_data, assistant] = await Promise.all(initialPromises); + + const { actions } = assistant_data ?? {}; + const { tools = [] } = assistant ?? {}; + + let domain = ''; + const updatedActions = actions.filter((action) => { + if (action.includes(action_id)) { + [domain] = action.split(actionDelimiter); + return false; + } + return true; + }); + + const updatedTools = tools.filter( + (tool) => !(tool.function && tool.function.name.includes(domain)), + ); + + const promises = []; + promises.push( + updateAssistant( + { assistant_id, user: req.user.id }, + { + actions: updatedActions, + }, + ), + ); + promises.push(openai.beta.assistants.update(assistant_id, { tools: updatedTools })); + promises.push(deleteAction({ action_id, user: req.user.id })); + + await Promise.all(promises); + res.status(200).json({ message: 'Action deleted successfully' }); + } catch (error) { + const message = 'Trouble deleting the Assistant Action'; + logger.error(message, error); + res.status(500).json({ message }); + } +}); + +module.exports = router; diff --git a/api/server/routes/assistants/assistants.js b/api/server/routes/assistants/assistants.js index b911c685a..0f12e2ec7 100644 --- a/api/server/routes/assistants/assistants.js +++ b/api/server/routes/assistants/assistants.js @@ -1,9 +1,31 @@ -const OpenAI = require('openai'); +const multer = require('multer'); const express = require('express'); +const { FileContext, EModelEndpoint } = require('librechat-data-provider'); +const { updateAssistant, getAssistants } = require('~/models/Assistant'); +const { initializeClient } = require('~/server/services/Endpoints/assistant'); +const { getStrategyFunctions } = require('~/server/services/Files/strategies'); +const { uploadImageBuffer } = require('~/server/services/Files/process'); +const { deleteFileByFilter } = require('~/models/File'); const { logger } = require('~/config'); +const actions = require('./actions'); +const tools = require('./tools'); +const upload = multer(); const router = express.Router(); +/** + * Assistant actions route. + * @route GET|POST /assistants/actions + */ +router.use('/actions', actions); + +/** + * Create an assistant. + * @route GET /assistants/tools + * @returns {TPlugin[]} 200 - application/json + */ +router.use('/tools', tools); + /** * Create an assistant. 
* @route POST /assistants @@ -12,12 +34,25 @@ const router = express.Router(); */ router.post('/', async (req, res) => { try { - const openai = new OpenAI(process.env.OPENAI_API_KEY); - const assistantData = req.body; + /** @type {{ openai: OpenAI }} */ + const { openai } = await initializeClient({ req, res }); + + const { tools = [], ...assistantData } = req.body; + assistantData.tools = tools + .map((tool) => { + if (typeof tool !== 'string') { + return tool; + } + + return req.app.locals.availableTools[tool]; + }) + .filter((tool) => tool); + const assistant = await openai.beta.assistants.create(assistantData); logger.debug('/assistants/', assistant); res.status(201).json(assistant); } catch (error) { + logger.error('[/assistants] Error creating assistant', error); res.status(500).json({ error: error.message }); } }); @@ -30,11 +65,14 @@ router.post('/', async (req, res) => { */ router.get('/:id', async (req, res) => { try { - const openai = new OpenAI(process.env.OPENAI_API_KEY); + /** @type {{ openai: OpenAI }} */ + const { openai } = await initializeClient({ req, res }); + const assistant_id = req.params.id; const assistant = await openai.beta.assistants.retrieve(assistant_id); res.json(assistant); } catch (error) { + logger.error('[/assistants/:id] Error retrieving assistant', error); res.status(500).json({ error: error.message }); } }); @@ -48,12 +86,25 @@ router.get('/:id', async (req, res) => { */ router.patch('/:id', async (req, res) => { try { - const openai = new OpenAI(process.env.OPENAI_API_KEY); + /** @type {{ openai: OpenAI }} */ + const { openai } = await initializeClient({ req, res }); + const assistant_id = req.params.id; const updateData = req.body; + updateData.tools = (updateData.tools ?? []) + .map((tool) => { + if (typeof tool !== 'string') { + return tool; + } + + return req.app.locals.availableTools[tool]; + }) + .filter((tool) => tool); + const updatedAssistant = await openai.beta.assistants.update(assistant_id, updateData); res.json(updatedAssistant); } catch (error) { + logger.error('[/assistants/:id] Error updating assistant', error); res.status(500).json({ error: error.message }); } }); @@ -66,12 +117,15 @@ router.patch('/:id', async (req, res) => { */ router.delete('/:id', async (req, res) => { try { - const openai = new OpenAI(process.env.OPENAI_API_KEY); + /** @type {{ openai: OpenAI }} */ + const { openai } = await initializeClient({ req, res }); + const assistant_id = req.params.id; const deletionStatus = await openai.beta.assistants.del(assistant_id); res.json(deletionStatus); } catch (error) { - res.status(500).json({ error: error.message }); + logger.error('[/assistants/:id] Error deleting assistant', error); + res.status(500).json({ error: 'Error deleting assistant' }); } }); @@ -79,22 +133,121 @@ router.delete('/:id', async (req, res) => { * Returns a list of assistants. * @route GET /assistants * @param {AssistantListParams} req.query - The assistant list parameters for pagination and sorting. 
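// Both the create and update handlers above normalize `tools` the same way:
// plain string keys are resolved against the tool definitions loaded into
// app.locals at startup, and unknown keys are dropped. A condensed sketch:
const availableTools = {
  exampleFunction: { type: 'function', function: { name: 'exampleFunction' } },
};

function resolveTools(requestedTools = []) {
  return requestedTools
    .map((tool) => (typeof tool === 'string' ? availableTools[tool] : tool))
    .filter(Boolean); // drop keys that are not registered as available tools
}

// resolveTools(['exampleFunction', { type: 'code_interpreter' }])
//   -> [{ type: 'function', ... }, { type: 'code_interpreter' }]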
- * @returns {Array} 200 - success response - application/json + * @returns {AssistantListResponse} 200 - success response - application/json */ router.get('/', async (req, res) => { try { - const openai = new OpenAI(process.env.OPENAI_API_KEY); + /** @type {{ openai: OpenAI }} */ + const { openai } = await initializeClient({ req, res }); + const { limit, order, after, before } = req.query; - const assistants = await openai.beta.assistants.list({ + const response = await openai.beta.assistants.list({ limit, order, after, before, }); - res.json(assistants); + + /** @type {AssistantListResponse} */ + let body = response.body; + + if (req.app.locals?.[EModelEndpoint.assistants]) { + /** @type {Partial} */ + const assistantsConfig = req.app.locals[EModelEndpoint.assistants]; + const { supportedIds, excludedIds } = assistantsConfig; + if (supportedIds?.length) { + body.data = body.data.filter((assistant) => supportedIds.includes(assistant.id)); + } else if (excludedIds?.length) { + body.data = body.data.filter((assistant) => !excludedIds.includes(assistant.id)); + } + } + + res.json(body); } catch (error) { + logger.error('[/assistants] Error listing assistants', error); res.status(500).json({ error: error.message }); } }); +/** + * Returns a list of the user's assistant documents (metadata saved to database). + * @route GET /assistants/documents + * @returns {AssistantDocument[]} 200 - success response - application/json + */ +router.get('/documents', async (req, res) => { + try { + res.json(await getAssistants({ user: req.user.id })); + } catch (error) { + logger.error('[/assistants/documents] Error listing assistant documents', error); + res.status(500).json({ error: error.message }); + } +}); + +/** + * Uploads and updates an avatar for a specific assistant. + * @route POST /avatar/:assistant_id + * @param {string} req.params.assistant_id - The ID of the assistant. + * @param {Express.Multer.File} req.file - The avatar image file. + * @param {string} [req.body.metadata] - Optional metadata for the assistant's avatar. 
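// The assistant list filter above in isolation: `supportedIds` acts as an
// allow-list and takes precedence; otherwise `excludedIds` acts as a deny-list.
// Minimal sketch of the same logic:
function filterAssistants(assistants, { supportedIds = [], excludedIds = [] } = {}) {
  if (supportedIds.length) {
    return assistants.filter((assistant) => supportedIds.includes(assistant.id));
  }
  if (excludedIds.length) {
    return assistants.filter((assistant) => !excludedIds.includes(assistant.id));
  }
  return assistants;
}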
+ * @returns {Object} 200 - success response - application/json + */ +router.post('/avatar/:assistant_id', upload.single('file'), async (req, res) => { + try { + const { assistant_id } = req.params; + if (!assistant_id) { + return res.status(400).json({ message: 'Assistant ID is required' }); + } + + let { metadata: _metadata = '{}' } = req.body; + /** @type {{ openai: OpenAI }} */ + const { openai } = await initializeClient({ req, res }); + + const image = await uploadImageBuffer({ req, context: FileContext.avatar }); + + try { + _metadata = JSON.parse(_metadata); + } catch (error) { + logger.error('[/avatar/:assistant_id] Error parsing metadata', error); + _metadata = {}; + } + + if (_metadata.avatar && _metadata.avatar_source) { + const { deleteFile } = getStrategyFunctions(_metadata.avatar_source); + try { + await deleteFile(req, { filepath: _metadata.avatar }); + await deleteFileByFilter({ filepath: _metadata.avatar }); + } catch (error) { + logger.error('[/avatar/:assistant_id] Error deleting old avatar', error); + } + } + + const metadata = { + ..._metadata, + avatar: image.filepath, + avatar_source: req.app.locals.fileStrategy, + }; + + const promises = []; + promises.push( + updateAssistant( + { assistant_id, user: req.user.id }, + { + avatar: { + filepath: image.filepath, + source: req.app.locals.fileStrategy, + }, + }, + ), + ); + promises.push(openai.beta.assistants.update(assistant_id, { metadata })); + + const resolved = await Promise.all(promises); + res.status(201).json(resolved[1]); + } catch (error) { + const message = 'An error occurred while updating the Assistant Avatar'; + logger.error(message, error); + res.status(500).json({ message }); + } +}); + module.exports = router; diff --git a/api/server/routes/assistants/chat.js b/api/server/routes/assistants/chat.js index e45bad191..57a3d4d28 100644 --- a/api/server/routes/assistants/chat.js +++ b/api/server/routes/assistants/chat.js @@ -1,64 +1,217 @@ -const crypto = require('crypto'); -const OpenAI = require('openai'); -const { logger } = require('~/config'); -const { sendMessage } = require('../../utils'); -const { initThread, createRun, handleRun } = require('../../services/AssistantService'); +const { v4 } = require('uuid'); const express = require('express'); +const { EModelEndpoint, Constants, RunStatus, CacheKeys } = require('librechat-data-provider'); +const { + initThread, + recordUsage, + saveUserMessage, + checkMessageGaps, + addThreadMetadata, + saveAssistantMessage, +} = require('~/server/services/Threads'); +const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService'); +const { addTitle, initializeClient } = require('~/server/services/Endpoints/assistant'); +const { createRun, sleep } = require('~/server/services/Runs'); +const { getConvo } = require('~/models/Conversation'); +const getLogStores = require('~/cache/getLogStores'); +const { sendMessage } = require('~/server/utils'); +const { logger } = require('~/config'); + const router = express.Router(); const { setHeaders, - // handleAbort, - // handleAbortError, + handleAbort, + handleAbortError, // validateEndpoint, - // buildEndpointOption, - // createAbortController, -} = require('../../middleware'); + buildEndpointOption, +} = require('~/server/middleware'); -// const thread = { -// id: 'thread_LexzJUVugYFqfslS7c7iL3Zo', -// "thread_nZoiCbPauU60LqY1Q0ME1elg" -// }; +router.post('/abort', handleAbort()); /** - * Chat with an assistant. 
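// The avatar route above persists the new image in two places at once: the
// local assistant document and the OpenAI assistant's metadata. A condensed
// sketch, assuming the `updateAssistant` model helper from this diff is passed in:
async function persistAvatar({ openai, updateAssistant, assistant_id, userId, filepath, source }) {
  const [, openAIAssistant] = await Promise.all([
    updateAssistant({ assistant_id, user: userId }, { avatar: { filepath, source } }),
    openai.beta.assistants.update(assistant_id, {
      metadata: { avatar: filepath, avatar_source: source },
    }),
  ]);
  return openAIAssistant; // the route responds with the OpenAI-side result
}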
+ * @route POST / + * @desc Chat with an assistant + * @access Public + * @param {express.Request} req - The request object, containing the request data. + * @param {express.Response} res - The response object, used to send back a response. + * @returns {void} */ -router.post('/', setHeaders, async (req, res) => { - try { - logger.debug('[/assistants/chat/] req.body', req.body); - // test message: - // How many polls of 500 ms intervals are there in 18 seconds? +router.post('/', buildEndpointOption, setHeaders, async (req, res) => { + logger.debug('[/assistants/chat/] req.body', req.body); + const { + text, + model, + files = [], + promptPrefix, + assistant_id, + instructions, + thread_id: _thread_id, + messageId: _messageId, + conversationId: convoId, + parentMessageId: _parentId = Constants.NO_PARENT, + } = req.body; - const { assistant_id, messages, text: userMessage, messageId } = req.body; - const conversationId = req.body.conversationId || crypto.randomUUID(); - // let thread_id = req.body.thread_id ?? 'thread_nZoiCbPauU60LqY1Q0ME1elg'; // for testing - let thread_id = req.body.thread_id; + /** @type {Partial} */ + const assistantsConfig = req.app.locals?.[EModelEndpoint.assistants]; + + if (assistantsConfig) { + const { supportedIds, excludedIds } = assistantsConfig; + const error = { message: 'Assistant not supported' }; + if (supportedIds?.length && !supportedIds.includes(assistant_id)) { + return await handleAbortError(res, req, error, { + sender: 'System', + conversationId: convoId, + messageId: v4(), + parentMessageId: _messageId, + error, + }); + } else if (excludedIds?.length && excludedIds.includes(assistant_id)) { + return await handleAbortError(res, req, error, { + sender: 'System', + conversationId: convoId, + messageId: v4(), + parentMessageId: _messageId, + }); + } + } + + /** @type {OpenAIClient} */ + let openai; + /** @type {string|undefined} - the current thread id */ + let thread_id = _thread_id; + /** @type {string|undefined} - the current run id */ + let run_id; + /** @type {string|undefined} - the parent messageId */ + let parentMessageId = _parentId; + /** @type {TMessage[]} */ + let previousMessages = []; + + const userMessageId = v4(); + const responseMessageId = v4(); + + /** @type {string} - The conversation UUID - created if undefined */ + const conversationId = convoId ?? 
v4(); + + const cache = getLogStores(CacheKeys.ABORT_KEYS); + const cacheKey = `${req.user.id}:${conversationId}`; + + try { + if (convoId && !_thread_id) { + throw new Error('Missing thread_id for existing conversation'); + } if (!assistant_id) { throw new Error('Missing assistant_id'); } - const openai = new OpenAI(process.env.OPENAI_API_KEY); - console.log(messages); + /** @type {{ openai: OpenAIClient }} */ + const { openai: _openai, client } = await initializeClient({ + req, + res, + endpointOption: req.body.endpointOption, + initAppClient: true, + }); - const initThreadBody = { - messages: [ - { - role: 'user', - content: userMessage, - metadata: { - messageId, - }, - }, - ], + openai = _openai; + + // if (thread_id) { + // previousMessages = await checkMessageGaps({ openai, thread_id, conversationId }); + // } + + if (previousMessages.length) { + parentMessageId = previousMessages[previousMessages.length - 1].messageId; + } + + const userMessage = { + role: 'user', + content: text, metadata: { + messageId: userMessageId, + }, + }; + + let thread_file_ids = []; + if (convoId) { + const convo = await getConvo(req.user.id, convoId); + if (convo && convo.file_ids) { + thread_file_ids = convo.file_ids; + } + } + + const file_ids = files.map(({ file_id }) => file_id); + if (file_ids.length || thread_file_ids.length) { + userMessage.file_ids = file_ids; + openai.attachedFileIds = new Set([...file_ids, ...thread_file_ids]); + } + + // TODO: may allow multiple messages to be created beforehand in a future update + const initThreadBody = { + messages: [userMessage], + metadata: { + user: req.user.id, conversationId, }, }; const result = await initThread({ openai, body: initThreadBody, thread_id }); - // const { messages: _messages } = result; thread_id = result.thread_id; + createOnTextProgress({ + openai, + conversationId, + userMessageId, + messageId: responseMessageId, + thread_id, + }); + + const requestMessage = { + user: req.user.id, + text, + messageId: userMessageId, + parentMessageId, + // TODO: make sure client sends correct format for `files`, use zod + files, + file_ids, + conversationId, + isCreatedByUser: true, + assistant_id, + thread_id, + model: assistant_id, + }; + + previousMessages.push(requestMessage); + + await saveUserMessage({ ...requestMessage, model }); + + const conversation = { + conversationId, + // TODO: title feature + title: 'New Chat', + endpoint: EModelEndpoint.assistants, + promptPrefix: promptPrefix, + instructions: instructions, + assistant_id, + // model, + }; + + if (file_ids.length) { + conversation.file_ids = file_ids; + } + + /** @type {CreateRunBody} */ + const body = { + assistant_id, + model, + }; + + if (promptPrefix) { + body.additional_instructions = promptPrefix; + } + + if (instructions) { + body.instructions = instructions; + } + /* NOTE: * By default, a Run will use the model and tools configuration specified in Assistant object, * but you can override most of these when creating the Run for added flexibility: @@ -66,43 +219,160 @@ router.post('/', setHeaders, async (req, res) => { const run = await createRun({ openai, thread_id, - body: { assistant_id, model: 'gpt-3.5-turbo-1106' }, + body, }); - const response = await handleRun({ openai, thread_id, run_id: run.id }); + run_id = run.id; + await cache.set(cacheKey, `${thread_id}:${run_id}`); + + sendMessage(res, { + sync: true, + conversationId, + // messages: previousMessages, + requestMessage, + responseMessage: { + user: req.user.id, + messageId: openai.responseMessage.messageId, + 
parentMessageId: userMessageId, + conversationId, + assistant_id, + thread_id, + model: assistant_id, + }, + }); + + // todo: retry logic + let response = await runAssistant({ openai, thread_id, run_id }); + logger.debug('[/assistants/chat/] response', response); + + if (response.run.status === RunStatus.IN_PROGRESS) { + response = await runAssistant({ + openai, + thread_id, + run_id, + in_progress: openai.in_progress, + }); + } + + /** @type {ResponseMessage} */ + const responseMessage = { + ...openai.responseMessage, + parentMessageId: userMessageId, + conversationId, + user: req.user.id, + assistant_id, + thread_id, + model: assistant_id, + }; + + // TODO: token count from usage returned in run // TODO: parse responses, save to db, send to user sendMessage(res, { title: 'New Chat', final: true, - conversation: { - conversationId: 'fake-convo-id', - title: 'New Chat', - }, + conversation, requestMessage: { - messageId: 'fake-user-message-id', - parentMessageId: '00000000-0000-0000-0000-000000000000', - conversationId: 'fake-convo-id', - sender: 'User', - text: req.body.text, - isCreatedByUser: true, - }, - responseMessage: { - messageId: 'fake-response-id', - conversationId: 'fake-convo-id', - parentMessageId: 'fake-user-message-id', - isCreatedByUser: false, - isEdited: false, - model: 'gpt-3.5-turbo-1106', - sender: 'Assistant', - text: response.choices[0].text, + parentMessageId, + thread_id, }, }); res.end(); + + await saveAssistantMessage({ ...responseMessage, model }); + + if (parentMessageId === Constants.NO_PARENT && !_thread_id) { + addTitle(req, { + text, + responseText: openai.responseText, + conversationId, + client, + }); + } + + await addThreadMetadata({ + openai, + thread_id, + messageId: responseMessage.messageId, + messages: response.messages, + }); + + if (!response.run.usage) { + await sleep(3000); + const completedRun = await openai.beta.threads.runs.retrieve(thread_id, run.id); + if (completedRun.usage) { + await recordUsage({ + ...completedRun.usage, + user: req.user.id, + model: completedRun.model ?? model, + conversationId, + }); + } + } else { + await recordUsage({ + ...response.run.usage, + user: req.user.id, + model: response.run.model ?? 
model, + conversationId, + }); + } } catch (error) { - // res.status(500).json({ error: error.message }); + if (error.message === 'Run cancelled') { + return res.end(); + } + logger.error('[/assistants/chat/]', error); - res.end(); + + if (!openai || !thread_id || !run_id) { + return res.status(500).json({ error: 'The Assistant run failed to initialize' }); + } + + try { + await cache.delete(cacheKey); + const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id); + logger.debug('Cancelled run:', cancelledRun); + } catch (error) { + logger.error('[abortRun] Error cancelling run', error); + } + + await sleep(2000); + try { + const run = await openai.beta.threads.runs.retrieve(thread_id, run_id); + await recordUsage({ + ...run.usage, + model: run.model, + user: req.user.id, + conversationId, + }); + } catch (error) { + logger.error('[/assistants/chat/] Error fetching or processing run', error); + } + + try { + const runMessages = await checkMessageGaps({ + openai, + run_id, + thread_id, + conversationId, + latestMessageId: responseMessageId, + }); + + const finalEvent = { + title: 'New Chat', + final: true, + conversation: await getConvo(req.user.id, conversationId), + runMessages, + }; + + if (res.headersSent && finalEvent) { + return sendMessage(res, finalEvent); + } + + res.json(finalEvent); + } catch (error) { + logger.error('[/assistants/chat/] Error finalizing error process', error); + return res.status(500).json({ error: 'The Assistant run failed' }); + } } }); diff --git a/api/server/routes/assistants/tools.js b/api/server/routes/assistants/tools.js new file mode 100644 index 000000000..324b62095 --- /dev/null +++ b/api/server/routes/assistants/tools.js @@ -0,0 +1,8 @@ +const express = require('express'); +const { getAvailableTools } = require('~/server/controllers/PluginController'); + +const router = express.Router(); + +router.get('/', getAvailableTools); + +module.exports = router; diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index 6213cfd2c..2af2e1054 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -1,10 +1,10 @@ const express = require('express'); const { CacheKeys } = require('librechat-data-provider'); -const { getConvosByPage, deleteConvos } = require('~/models/Conversation'); +const { initializeClient } = require('~/server/services/Endpoints/assistant'); +const { getConvosByPage, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation'); const requireJwtAuth = require('~/server/middleware/requireJwtAuth'); -const { sleep } = require('~/server/services/AssistantService'); +const { sleep } = require('~/server/services/Runs/handle'); const getLogStores = require('~/cache/getLogStores'); -const { getConvo, saveConvo } = require('~/models'); const { logger } = require('~/config'); const router = express.Router(); @@ -47,28 +47,37 @@ router.post('/gen_title', async (req, res) => { await titleCache.delete(key); res.status(200).json({ title }); } else { - res - .status(404) - .json({ - message: 'Title not found or method not implemented for the conversation\'s endpoint', - }); + res.status(404).json({ + message: 'Title not found or method not implemented for the conversation\'s endpoint', + }); } }); router.post('/clear', async (req, res) => { let filter = {}; - const { conversationId, source } = req.body.arg; + const { conversationId, source, thread_id } = req.body.arg; if (conversationId) { filter = { conversationId }; } - // for debugging deletion source - // logger.debug('source:', source); - 
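// The abort and error paths above all finish the same way: if the response is
// already streaming (headers sent), the final payload goes out as one last
// server-sent event; otherwise it is returned as a plain JSON body. Condensed:
function sendFinalEvent(res, finalEvent, sendMessage) {
  if (res.headersSent) {
    return sendMessage(res, finalEvent); // streaming path: emit a final event
  }
  return res.json(finalEvent); // non-streaming path: normal JSON response
}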
if (source === 'button' && !conversationId) { return res.status(200).send('No conversationId provided'); } + if (thread_id) { + /** @type {{ openai: OpenAI}} */ + const { openai } = await initializeClient({ req, res }); + try { + const response = await openai.beta.threads.del(thread_id); + logger.debug('Deleted OpenAI thread:', response); + } catch (error) { + logger.error('Error deleting OpenAI thread:', error); + } + } + + // for debugging deletion source + // logger.debug('source:', source); + try { const dbResponse = await deleteConvos(req.user.id, filter); res.status(201).json(dbResponse); diff --git a/api/server/routes/files/avatar.js b/api/server/routes/files/avatar.js index 5abba85f9..71ade965c 100644 --- a/api/server/routes/files/avatar.js +++ b/api/server/routes/files/avatar.js @@ -1,38 +1,36 @@ -const express = require('express'); const multer = require('multer'); - -const uploadAvatar = require('~/server/services/Files/images/avatar'); -const { requireJwtAuth } = require('~/server/middleware/'); -const User = require('~/models/User'); +const express = require('express'); +const { getStrategyFunctions } = require('~/server/services/Files/strategies'); +const { resizeAvatar } = require('~/server/services/Files/images/avatar'); +const { logger } = require('~/config'); const upload = multer(); const router = express.Router(); -router.post('/', requireJwtAuth, upload.single('input'), async (req, res) => { +router.post('/', upload.single('input'), async (req, res) => { try { const userId = req.user.id; const { manual } = req.body; const input = req.file.buffer; + if (!userId) { throw new Error('User ID is undefined'); } - // TODO: do not use Model directly, instead use a service method that uses the model - const user = await User.findById(userId).lean(); - - if (!user) { - throw new Error('User not found'); - } - const url = await uploadAvatar({ - input, + const fileStrategy = req.app.locals.fileStrategy; + const webPBuffer = await resizeAvatar({ userId, - manual, - fileStrategy: req.app.locals.fileStrategy, + input, }); + const { processAvatar } = getStrategyFunctions(fileStrategy); + const url = await processAvatar({ buffer: webPBuffer, userId, manual }); + res.json({ url }); } catch (error) { - res.status(500).json({ message: 'An error occurred while uploading the profile picture' }); + const message = 'An error occurred while uploading the profile picture'; + logger.error(message, error); + res.status(500).json({ message }); } }); diff --git a/api/server/routes/files/files.js b/api/server/routes/files/files.js index 3fea2e5d0..d44df747a 100644 --- a/api/server/routes/files/files.js +++ b/api/server/routes/files/files.js @@ -1,14 +1,17 @@ -const { z } = require('zod'); +const axios = require('axios'); +const fs = require('fs').promises; const express = require('express'); -const { FileSources } = require('librechat-data-provider'); -const { getStrategyFunctions } = require('~/server/services/Files/strategies'); -const { deleteFiles, getFiles } = require('~/models'); +const { isUUID } = require('librechat-data-provider'); +const { + filterFile, + processFileUpload, + processDeleteRequest, +} = require('~/server/services/Files/process'); +const { getFiles } = require('~/models/File'); const { logger } = require('~/config'); const router = express.Router(); -const isUUID = z.string().uuid(); - router.get('/', async (req, res) => { try { const files = await getFiles({ user: req.user.id }); @@ -19,6 +22,15 @@ router.get('/', async (req, res) => { } }); +router.get('/config', async 
(req, res) => { + try { + res.status(200).json(req.app.locals.fileConfig); + } catch (error) { + logger.error('[/files] Error getting fileConfig', error); + res.status(400).json({ message: 'Error in request', error: error.message }); + } +}); + router.delete('/', async (req, res) => { try { const { files: _files } = req.body; @@ -31,6 +43,11 @@ router.delete('/', async (req, res) => { if (!file.filepath) { return false; } + + if (/^file-/.test(file.file_id)) { + return true; + } + return isUUID.safeParse(file.file_id).success; }); @@ -39,29 +56,8 @@ router.delete('/', async (req, res) => { return; } - const file_ids = files.map((file) => file.file_id); - const deletionMethods = {}; - const promises = []; - promises.push(await deleteFiles(file_ids)); + await processDeleteRequest({ req, files }); - for (const file of files) { - const source = file.source ?? FileSources.local; - - if (deletionMethods[source]) { - promises.push(deletionMethods[source](req, file)); - continue; - } - - const { deleteFile } = getStrategyFunctions(source); - if (!deleteFile) { - throw new Error(`Delete function not implemented for ${source}`); - } - - deletionMethods[source] = deleteFile; - promises.push(deleteFile(req, file)); - } - - await Promise.all(promises); res.status(200).json({ message: 'Files deleted successfully' }); } catch (error) { logger.error('[/files] Error deleting files:', error); @@ -69,4 +65,69 @@ router.delete('/', async (req, res) => { } }); +router.get('/download/:fileId', async (req, res) => { + try { + const { fileId } = req.params; + + const options = { + headers: { + // TODO: Client initialization for OpenAI API Authentication + Authorization: `Bearer ${process.env.OPENAI_API_KEY}`, + }, + responseType: 'stream', + }; + + const fileResponse = await axios.get(`https://api.openai.com/v1/files/${fileId}`, { + headers: options.headers, + }); + const { filename } = fileResponse.data; + + const response = await axios.get(`https://api.openai.com/v1/files/${fileId}/content`, options); + res.setHeader('Content-Disposition', `attachment; filename="${filename}"`); + response.data.pipe(res); + } catch (error) { + console.error('Error downloading file:', error); + res.status(500).send('Error downloading file'); + } +}); + +router.post('/', async (req, res) => { + const file = req.file; + const metadata = req.body; + let cleanup = true; + + try { + filterFile({ req, file }); + + metadata.temp_file_id = metadata.file_id; + metadata.file_id = req.file_id; + + await processFileUpload({ req, res, file, metadata }); + } catch (error) { + let message = 'Error processing file'; + logger.error('[/files] Error processing file:', error); + cleanup = false; + + if (error.message?.includes('file_ids')) { + message += ': ' + error.message; + } + + // TODO: delete remote file if it exists + try { + await fs.unlink(file.path); + } catch (error) { + logger.error('[/files] Error deleting file:', error); + } + res.status(500).json({ message }); + } + + if (cleanup) { + try { + await fs.unlink(file.path); + } catch (error) { + logger.error('[/files/images] Error deleting file after file processing:', error); + } + } +}); + module.exports = router; diff --git a/api/server/routes/files/images.js b/api/server/routes/files/images.js index d1016f98f..374711c4a 100644 --- a/api/server/routes/files/images.js +++ b/api/server/routes/files/images.js @@ -1,49 +1,29 @@ -const { z } = require('zod'); const path = require('path'); const fs = require('fs').promises; const express = require('express'); -const upload = 
require('./multer'); -const { processImageUpload } = require('~/server/services/Files/process'); +const { filterFile, processImageFile } = require('~/server/services/Files/process'); const { logger } = require('~/config'); const router = express.Router(); -router.post('/', upload.single('file'), async (req, res) => { - const file = req.file; +router.post('/', async (req, res) => { const metadata = req.body; - // TODO: add file size/type validation - - const uuidSchema = z.string().uuid(); try { - if (!file) { - throw new Error('No file provided'); - } + filterFile({ req, file: req.file, image: true }); - if (!metadata.file_id) { - throw new Error('No file_id provided'); - } - - if (!metadata.width) { - throw new Error('No width provided'); - } - - if (!metadata.height) { - throw new Error('No height provided'); - } - /* parse to validate api call */ - uuidSchema.parse(metadata.file_id); metadata.temp_file_id = metadata.file_id; metadata.file_id = req.file_id; - await processImageUpload({ req, res, file, metadata }); + await processImageFile({ req, res, file: req.file, metadata }); } catch (error) { + // TODO: delete remote file if it exists logger.error('[/files/images] Error processing file:', error); try { const filepath = path.join( req.app.locals.paths.imageOutput, req.user.id, - path.basename(file.filename), + path.basename(req.file.filename), ); await fs.unlink(filepath); } catch (error) { @@ -51,16 +31,6 @@ router.post('/', upload.single('file'), async (req, res) => { } res.status(500).json({ message: 'Error processing file' }); } - - // do this if strategy is not local - // finally { - // try { - // // await fs.unlink(file.path); - // } catch (error) { - // logger.error('[/files/images] Error deleting file:', error); - - // } - // } }); module.exports = router; diff --git a/api/server/routes/files/index.js b/api/server/routes/files/index.js index 9afb900bb..c9f5ce167 100644 --- a/api/server/routes/files/index.js +++ b/api/server/routes/files/index.js @@ -1,24 +1,27 @@ const express = require('express'); -const router = express.Router(); -const { - uaParser, - checkBan, - requireJwtAuth, - // concurrentLimiter, - // messageIpLimiter, - // messageUserLimiter, -} = require('../../middleware'); +const createMulterInstance = require('./multer'); +const { uaParser, checkBan, requireJwtAuth, createFileLimiters } = require('~/server/middleware'); const files = require('./files'); const images = require('./images'); const avatar = require('./avatar'); -router.use(requireJwtAuth); -router.use(checkBan); -router.use(uaParser); +const initialize = async () => { + const router = express.Router(); + router.use(requireJwtAuth); + router.use(checkBan); + router.use(uaParser); -router.use('/', files); -router.use('/images', images); -router.use('/images/avatar', avatar); + const upload = await createMulterInstance(); + const { fileUploadIpLimiter, fileUploadUserLimiter } = createFileLimiters(); + router.post('*', fileUploadIpLimiter, fileUploadUserLimiter); + router.post('/', upload.single('file')); + router.post('/images', upload.single('file')); -module.exports = router; + router.use('/', files); + router.use('/images', images); + router.use('/images/avatar', avatar); + return router; +}; + +module.exports = { initialize }; diff --git a/api/server/routes/files/multer.js b/api/server/routes/files/multer.js index d5aea05a3..71a820ba5 100644 --- a/api/server/routes/files/multer.js +++ b/api/server/routes/files/multer.js @@ -2,13 +2,12 @@ const fs = require('fs'); const path = require('path'); const 
crypto = require('crypto'); const multer = require('multer'); - -const supportedTypes = ['image/jpeg', 'image/jpg', 'image/png', 'image/webp']; -const sizeLimit = 20 * 1024 * 1024; // 20 MB +const { fileConfig: defaultFileConfig, mergeFileConfig } = require('librechat-data-provider'); +const getCustomConfig = require('~/server/services/Config/getCustomConfig'); const storage = multer.diskStorage({ destination: function (req, file, cb) { - const outputPath = path.join(req.app.locals.paths.imageOutput, 'temp'); + const outputPath = path.join(req.app.locals.paths.uploads, 'temp', req.user.id); if (!fs.existsSync(outputPath)) { fs.mkdirSync(outputPath, { recursive: true }); } @@ -16,22 +15,30 @@ const storage = multer.diskStorage({ }, filename: function (req, file, cb) { req.file_id = crypto.randomUUID(); - const fileExt = path.extname(file.originalname); - cb(null, `img-${req.file_id}${fileExt}`); + cb(null, `${file.originalname}`); }, }); const fileFilter = (req, file, cb) => { - if (!supportedTypes.includes(file.mimetype)) { - return cb( - new Error('Unsupported file type. Only JPEG, JPG, PNG, and WEBP files are allowed.'), - false, - ); + if (!file) { + return cb(new Error('No file provided'), false); + } + + if (!defaultFileConfig.checkType(file.mimetype)) { + return cb(new Error('Unsupported file type: ' + file.mimetype), false); } cb(null, true); }; -const upload = multer({ storage, fileFilter, limits: { fileSize: sizeLimit } }); +const createMulterInstance = async () => { + const customConfig = await getCustomConfig(); + const fileConfig = mergeFileConfig(customConfig?.fileConfig); + return multer({ + storage, + fileFilter, + limits: { fileSize: fileConfig.serverFileSizeLimit }, + }); +}; -module.exports = upload; +module.exports = createMulterInstance; diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js new file mode 100644 index 000000000..76e846af8 --- /dev/null +++ b/api/server/services/ActionService.js @@ -0,0 +1,118 @@ +const { AuthTypeEnum } = require('librechat-data-provider'); +const { encryptV2, decryptV2 } = require('~/server/utils/crypto'); +const { getActions } = require('~/models/Action'); +const { logger } = require('~/config'); + +/** + * Loads action sets based on the user and assistant ID. + * + * @param {Object} params - The parameters for loading action sets. + * @param {string} params.user - The user identifier. + * @param {string} params.assistant_id - The assistant identifier. + * @returns {Promise} A promise that resolves to an array of actions or `null` if no match. + */ +async function loadActionSets({ user, assistant_id }) { + return await getActions({ user, assistant_id }, true); +} + +/** + * Creates a general tool for an entire action set. + * + * @param {Object} params - The parameters for loading action sets. + * @param {Action} params.action - The action set. Necessary for decrypting authentication values. + * @param {ActionRequest} params.requestBuilder - The ActionRequest builder class to execute the API call. + * @returns { { _call: (toolInput: Object) => unknown} } An object with `_call` method to execute the tool input. 
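// Hedged sketch of how the helpers in this service compose: load a user's
// action sets for an assistant, then wrap one in a callable tool. The request
// builder factory here is a placeholder for whatever implements the
// ActionRequest interface referenced above.
async function buildFirstActionTool({ user, assistant_id, createRequestBuilder }) {
  const actions = await loadActionSets({ user, assistant_id });
  if (!actions?.length) {
    return null;
  }
  const requestBuilder = createRequestBuilder(actions[0]); // placeholder factory
  return createActionTool({ action: actions[0], requestBuilder });
}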
+ */ +function createActionTool({ action, requestBuilder }) { + action.metadata = decryptMetadata(action.metadata); + const _call = async (toolInput) => { + try { + requestBuilder.setParams(toolInput); + if (action.metadata.auth && action.metadata.auth.type !== AuthTypeEnum.None) { + await requestBuilder.setAuth(action.metadata); + } + const res = await requestBuilder.execute(); + if (typeof res.data === 'object') { + return JSON.stringify(res.data); + } + return res.data; + } catch (error) { + logger.error(`API call to ${action.metadata.domain} failed`, error); + if (error.response) { + const { status, data } = error.response; + return `API call to ${action.metadata.domain} failed with status ${status}: ${data}`; + } + + return `API call to ${action.metadata.domain} failed.`; + } + }; + + return { + _call, + }; +} + +/** + * Encrypts sensitive metadata values for an action. + * + * @param {ActionMetadata} metadata - The action metadata to encrypt. + * @returns {ActionMetadata} The updated action metadata with encrypted values. + */ +function encryptMetadata(metadata) { + const encryptedMetadata = { ...metadata }; + + // ServiceHttp + if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) { + if (metadata.api_key) { + encryptedMetadata.api_key = encryptV2(metadata.api_key); + } + } + + // OAuth + else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) { + if (metadata.oauth_client_id) { + encryptedMetadata.oauth_client_id = encryptV2(metadata.oauth_client_id); + } + if (metadata.oauth_client_secret) { + encryptedMetadata.oauth_client_secret = encryptV2(metadata.oauth_client_secret); + } + } + + return encryptedMetadata; +} + +/** + * Decrypts sensitive metadata values for an action. + * + * @param {ActionMetadata} metadata - The action metadata to decrypt. + * @returns {ActionMetadata} The updated action metadata with decrypted values. + */ +function decryptMetadata(metadata) { + const decryptedMetadata = { ...metadata }; + + // ServiceHttp + if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) { + if (metadata.api_key) { + decryptedMetadata.api_key = decryptV2(metadata.api_key); + } + } + + // OAuth + else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) { + if (metadata.oauth_client_id) { + decryptedMetadata.oauth_client_id = decryptV2(metadata.oauth_client_id); + } + if (metadata.oauth_client_secret) { + decryptedMetadata.oauth_client_secret = decryptV2(metadata.oauth_client_secret); + } + } + + return decryptedMetadata; +} + +module.exports = { + loadActionSets, + createActionTool, + encryptMetadata, + decryptMetadata, +}; diff --git a/api/server/services/AppService.js b/api/server/services/AppService.js index b62d274d5..5089392ed 100644 --- a/api/server/services/AppService.js +++ b/api/server/services/AppService.js @@ -1,7 +1,10 @@ -const { FileSources } = require('librechat-data-provider'); +const { FileSources, EModelEndpoint, Constants } = require('librechat-data-provider'); const { initializeFirebase } = require('./Files/Firebase/initialize'); const loadCustomConfig = require('./Config/loadCustomConfig'); +const handleRateLimits = require('./Config/handleRateLimits'); +const { loadAndFormatTools } = require('./ToolService'); const paths = require('~/config/paths'); +const { logger } = require('~/config'); /** * @@ -12,13 +15,7 @@ const paths = require('~/config/paths'); const AppService = async (app) => { /** @type {TCustomConfig}*/ const config = (await loadCustomConfig()) ?? 
{}; - const socialLogins = config?.registration?.socialLogins ?? [ - 'google', - 'facebook', - 'openid', - 'github', - 'discord', - ]; + const fileStrategy = config.fileStrategy ?? FileSources.local; process.env.CDN_PROVIDER = fileStrategy; @@ -26,10 +23,71 @@ const AppService = async (app) => { initializeFirebase(); } + /** @type {Record} */ + endpointLocals[EModelEndpoint.assistants] = { + disableBuilder, + pollIntervalMs, + timeoutMs, + supportedIds, + excludedIds, + }; + } + app.locals = { socialLogins, + availableTools, fileStrategy, + fileConfig: config?.fileConfig, paths, + ...endpointLocals, }; }; diff --git a/api/server/services/AppService.spec.js b/api/server/services/AppService.spec.js index 1f3f2245b..2bd33cbb9 100644 --- a/api/server/services/AppService.spec.js +++ b/api/server/services/AppService.spec.js @@ -13,6 +13,24 @@ jest.mock('./Config/loadCustomConfig', () => { jest.mock('./Files/Firebase/initialize', () => ({ initializeFirebase: jest.fn(), })); +jest.mock('./ToolService', () => ({ + loadAndFormatTools: jest.fn().mockReturnValue({ + ExampleTool: { + type: 'function', + function: { + description: 'Example tool function', + name: 'exampleFunction', + parameters: { + type: 'object', + properties: { + param1: { type: 'string', description: 'An example parameter' }, + }, + required: ['param1'], + }, + }, + }, + }), +})); describe('AppService', () => { let app; @@ -30,10 +48,39 @@ describe('AppService', () => { expect(app.locals).toEqual({ socialLogins: ['testLogin'], fileStrategy: 'testStrategy', + availableTools: { + ExampleTool: { + type: 'function', + function: expect.objectContaining({ + description: 'Example tool function', + name: 'exampleFunction', + parameters: expect.objectContaining({ + type: 'object', + properties: expect.any(Object), + required: expect.arrayContaining(['param1']), + }), + }), + }, + }, paths: expect.anything(), }); }); + it('should log a warning if the config version is outdated', async () => { + require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve({ + version: '0.9.0', // An outdated version for this test + registration: { socialLogins: ['testLogin'] }, + fileStrategy: 'testStrategy', + }), + ); + + await AppService(app); + + const { logger } = require('~/config'); + expect(logger.info).toHaveBeenCalledWith(expect.stringContaining('Outdated Config version')); + }); + it('should initialize Firebase when fileStrategy is firebase', async () => { require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({ @@ -48,4 +95,217 @@ describe('AppService', () => { expect(process.env.CDN_PROVIDER).toEqual(FileSources.firebase); }); + + it('should load and format tools accurately with defined structure', async () => { + const { loadAndFormatTools } = require('./ToolService'); + await AppService(app); + + expect(loadAndFormatTools).toHaveBeenCalledWith({ + directory: expect.anything(), + filter: expect.anything(), + }); + + expect(app.locals.availableTools.ExampleTool).toBeDefined(); + expect(app.locals.availableTools.ExampleTool).toEqual({ + type: 'function', + function: { + description: 'Example tool function', + name: 'exampleFunction', + parameters: { + type: 'object', + properties: { + param1: { type: 'string', description: 'An example parameter' }, + }, + required: ['param1'], + }, + }, + }); + }); + + it('should correctly configure endpoints based on custom config', async () => { + require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve({ + endpoints: { + assistants: { + 
disableBuilder: true, + pollIntervalMs: 5000, + timeoutMs: 30000, + supportedIds: ['id1', 'id2'], + }, + }, + }), + ); + + await AppService(app); + + expect(app.locals).toHaveProperty('assistants'); + expect(app.locals.assistants).toEqual( + expect.objectContaining({ + disableBuilder: true, + pollIntervalMs: 5000, + timeoutMs: 30000, + supportedIds: expect.arrayContaining(['id1', 'id2']), + }), + ); + }); + + it('should not modify FILE_UPLOAD environment variables without rate limits', async () => { + // Setup initial environment variables + process.env.FILE_UPLOAD_IP_MAX = '10'; + process.env.FILE_UPLOAD_IP_WINDOW = '15'; + process.env.FILE_UPLOAD_USER_MAX = '5'; + process.env.FILE_UPLOAD_USER_WINDOW = '20'; + + const initialEnv = { ...process.env }; + + await AppService(app); + + // Expect environment variables to remain unchanged + expect(process.env.FILE_UPLOAD_IP_MAX).toEqual(initialEnv.FILE_UPLOAD_IP_MAX); + expect(process.env.FILE_UPLOAD_IP_WINDOW).toEqual(initialEnv.FILE_UPLOAD_IP_WINDOW); + expect(process.env.FILE_UPLOAD_USER_MAX).toEqual(initialEnv.FILE_UPLOAD_USER_MAX); + expect(process.env.FILE_UPLOAD_USER_WINDOW).toEqual(initialEnv.FILE_UPLOAD_USER_WINDOW); + }); + + it('should correctly set FILE_UPLOAD environment variables based on rate limits', async () => { + // Define and mock a custom configuration with rate limits + const rateLimitsConfig = { + rateLimits: { + fileUploads: { + ipMax: '100', + ipWindowInMinutes: '60', + userMax: '50', + userWindowInMinutes: '30', + }, + }, + }; + + require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve(rateLimitsConfig), + ); + + await AppService(app); + + // Verify that process.env has been updated according to the rate limits config + expect(process.env.FILE_UPLOAD_IP_MAX).toEqual('100'); + expect(process.env.FILE_UPLOAD_IP_WINDOW).toEqual('60'); + expect(process.env.FILE_UPLOAD_USER_MAX).toEqual('50'); + expect(process.env.FILE_UPLOAD_USER_WINDOW).toEqual('30'); + }); + + it('should fallback to default FILE_UPLOAD environment variables when rate limits are unspecified', async () => { + // Setup initial environment variables to non-default values + process.env.FILE_UPLOAD_IP_MAX = 'initialMax'; + process.env.FILE_UPLOAD_IP_WINDOW = 'initialWindow'; + process.env.FILE_UPLOAD_USER_MAX = 'initialUserMax'; + process.env.FILE_UPLOAD_USER_WINDOW = 'initialUserWindow'; + + // Mock a custom configuration without specific rate limits + require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({})); + + await AppService(app); + + // Verify that process.env falls back to the initial values + expect(process.env.FILE_UPLOAD_IP_MAX).toEqual('initialMax'); + expect(process.env.FILE_UPLOAD_IP_WINDOW).toEqual('initialWindow'); + expect(process.env.FILE_UPLOAD_USER_MAX).toEqual('initialUserMax'); + expect(process.env.FILE_UPLOAD_USER_WINDOW).toEqual('initialUserWindow'); + }); +}); + +describe('AppService updating app.locals', () => { + let app; + let initialEnv; + + beforeEach(() => { + // Store initial environment variables to restore them after each test + initialEnv = { ...process.env }; + + app = { locals: {} }; + process.env.CDN_PROVIDER = undefined; + }); + + afterEach(() => { + // Restore initial environment variables + process.env = { ...initialEnv }; + }); + + it('should update app.locals with default values if loadCustomConfig returns undefined', async () => { + // Mock loadCustomConfig to return undefined + require('./Config/loadCustomConfig').mockImplementationOnce(() => 
Promise.resolve(undefined)); + + await AppService(app); + + expect(app.locals).toBeDefined(); + expect(app.locals.paths).toBeDefined(); + expect(app.locals.availableTools).toBeDefined(); + expect(app.locals.fileStrategy).toEqual(FileSources.local); + }); + + it('should update app.locals with values from loadCustomConfig', async () => { + // Mock loadCustomConfig to return a specific config object + const customConfig = { + fileStrategy: 'firebase', + registration: { socialLogins: ['testLogin'] }, + }; + require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve(customConfig), + ); + + await AppService(app); + + expect(app.locals).toBeDefined(); + expect(app.locals.paths).toBeDefined(); + expect(app.locals.availableTools).toBeDefined(); + expect(app.locals.fileStrategy).toEqual(customConfig.fileStrategy); + expect(app.locals.socialLogins).toEqual(customConfig.registration.socialLogins); + }); + + it('should apply the assistants endpoint configuration correctly to app.locals', async () => { + const mockConfig = { + endpoints: { + assistants: { + disableBuilder: true, + pollIntervalMs: 5000, + timeoutMs: 30000, + supportedIds: ['id1', 'id2'], + }, + }, + }; + require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve(mockConfig)); + + const app = { locals: {} }; + await AppService(app); + + expect(app.locals).toHaveProperty('assistants'); + const { assistants } = app.locals; + expect(assistants.disableBuilder).toBe(true); + expect(assistants.pollIntervalMs).toBe(5000); + expect(assistants.timeoutMs).toBe(30000); + expect(assistants.supportedIds).toEqual(['id1', 'id2']); + expect(assistants.excludedIds).toBeUndefined(); + }); + + it('should log a warning when both supportedIds and excludedIds are provided', async () => { + const mockConfig = { + endpoints: { + assistants: { + disableBuilder: false, + pollIntervalMs: 3000, + timeoutMs: 20000, + supportedIds: ['id1', 'id2'], + excludedIds: ['id3'], + }, + }, + }; + require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve(mockConfig)); + + const app = { locals: {} }; + await require('./AppService')(app); + + const { logger } = require('~/config'); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Both `supportedIds` and `excludedIds` are defined'), + ); + }); }); diff --git a/api/server/services/AssistantService.js b/api/server/services/AssistantService.js index b6eabb753..5ddd9a4c6 100644 --- a/api/server/services/AssistantService.js +++ b/api/server/services/AssistantService.js @@ -1,256 +1,93 @@ -const RunManager = require('./Runs/RunMananger'); +const path = require('path'); +const { klona } = require('klona'); +const { + StepTypes, + RunStatus, + StepStatus, + FilePurpose, + ContentTypes, + ToolCallTypes, + imageExtRegex, + imageGenTools, + EModelEndpoint, + defaultOrderQuery, +} = require('librechat-data-provider'); +const { retrieveAndProcessFile } = require('~/server/services/Files/process'); +const { RunManager, waitForRun, sleep } = require('~/server/services/Runs'); +const { processRequiredActions } = require('~/server/services/ToolService'); +const { createOnProgress, sendMessage } = require('~/server/utils'); +const { TextStream } = require('~/app/clients'); +const { logger } = require('~/config'); /** - * @typedef {Object} Message - * @property {string} id - The identifier of the message. - * @property {string} object - The object type, always 'thread.message'. 
- * @property {number} created_at - The Unix timestamp (in seconds) for when the message was created. - * @property {string} thread_id - The thread ID that this message belongs to. - * @property {string} role - The entity that produced the message. One of 'user' or 'assistant'. - * @property {Object[]} content - The content of the message in an array of text and/or images. - * @property {string} content[].type - The type of content, either 'text' or 'image_file'. - * @property {Object} [content[].text] - The text content, present if type is 'text'. - * @property {string} content[].text.value - The data that makes up the text. - * @property {Object[]} [content[].text.annotations] - Annotations for the text content. - * @property {Object} [content[].image_file] - The image file content, present if type is 'image_file'. - * @property {string} content[].image_file.file_id - The File ID of the image in the message content. - * @property {string[]} [file_ids] - Optional list of File IDs for the message. - * @property {string|null} [assistant_id] - If applicable, the ID of the assistant that authored this message. - * @property {string|null} [run_id] - If applicable, the ID of the run associated with the authoring of this message. - * @property {Object} [metadata] - Optional metadata for the message, a map of key-value pairs. - */ - -/** - * @typedef {Object} FunctionTool - * @property {string} type - The type of tool, 'function'. - * @property {Object} function - The function definition. - * @property {string} function.description - A description of what the function does. - * @property {string} function.name - The name of the function to be called. - * @property {Object} function.parameters - The parameters the function accepts, described as a JSON Schema object. - */ - -/** - * @typedef {Object} Tool - * @property {string} type - The type of tool, can be 'code_interpreter', 'retrieval', or 'function'. - * @property {FunctionTool} [function] - The function tool, present if type is 'function'. - */ - -/** - * @typedef {Object} Run - * @property {string} id - The identifier of the run. - * @property {string} object - The object type, always 'thread.run'. - * @property {number} created_at - The Unix timestamp (in seconds) for when the run was created. - * @property {string} thread_id - The ID of the thread that was executed on as a part of this run. - * @property {string} assistant_id - The ID of the assistant used for execution of this run. - * @property {string} status - The status of the run (e.g., 'queued', 'completed'). - * @property {Object} [required_action] - Details on the action required to continue the run. - * @property {string} required_action.type - The type of required action, always 'submit_tool_outputs'. - * @property {Object} required_action.submit_tool_outputs - Details on the tool outputs needed for the run to continue. - * @property {Object[]} required_action.submit_tool_outputs.tool_calls - A list of the relevant tool calls. - * @property {string} required_action.submit_tool_outputs.tool_calls[].id - The ID of the tool call. - * @property {string} required_action.submit_tool_outputs.tool_calls[].type - The type of tool call the output is required for, always 'function'. - * @property {Object} required_action.submit_tool_outputs.tool_calls[].function - The function definition. - * @property {string} required_action.submit_tool_outputs.tool_calls[].function.name - The name of the function. 
- * @property {string} required_action.submit_tool_outputs.tool_calls[].function.arguments - The arguments that the model expects you to pass to the function. - * @property {Object} [last_error] - The last error associated with this run. - * @property {string} last_error.code - One of 'server_error' or 'rate_limit_exceeded'. - * @property {string} last_error.message - A human-readable description of the error. - * @property {number} [expires_at] - The Unix timestamp (in seconds) for when the run will expire. - * @property {number} [started_at] - The Unix timestamp (in seconds) for when the run was started. - * @property {number} [cancelled_at] - The Unix timestamp (in seconds) for when the run was cancelled. - * @property {number} [failed_at] - The Unix timestamp (in seconds) for when the run failed. - * @property {number} [completed_at] - The Unix timestamp (in seconds) for when the run was completed. - * @property {string} [model] - The model that the assistant used for this run. - * @property {string} [instructions] - The instructions that the assistant used for this run. - * @property {Tool[]} [tools] - The list of tools used for this run. - * @property {string[]} [file_ids] - The list of File IDs used for this run. - * @property {Object} [metadata] - Metadata associated with this run. - */ - -/** - * @typedef {Object} RunStep - * @property {string} id - The identifier of the run step. - * @property {string} object - The object type, always 'thread.run.step'. - * @property {number} created_at - The Unix timestamp (in seconds) for when the run step was created. - * @property {string} assistant_id - The ID of the assistant associated with the run step. - * @property {string} thread_id - The ID of the thread that was run. - * @property {string} run_id - The ID of the run that this run step is a part of. - * @property {string} type - The type of run step, either 'message_creation' or 'tool_calls'. - * @property {string} status - The status of the run step, can be 'in_progress', 'cancelled', 'failed', 'completed', or 'expired'. - * @property {Object} step_details - The details of the run step. - * @property {Object} [last_error] - The last error associated with this run step. - * @property {string} last_error.code - One of 'server_error' or 'rate_limit_exceeded'. - * @property {string} last_error.message - A human-readable description of the error. - * @property {number} [expired_at] - The Unix timestamp (in seconds) for when the run step expired. - * @property {number} [cancelled_at] - The Unix timestamp (in seconds) for when the run step was cancelled. - * @property {number} [failed_at] - The Unix timestamp (in seconds) for when the run step failed. - * @property {number} [completed_at] - The Unix timestamp (in seconds) for when the run step completed. - * @property {Object} [metadata] - Metadata associated with this run step, a map of up to 16 key-value pairs. - */ - -/** - * @typedef {Object} StepMessage - * @property {Message} message - The complete message object created by the step. - * @property {string} id - The identifier of the run step. - * @property {string} object - The object type, always 'thread.run.step'. - * @property {number} created_at - The Unix timestamp (in seconds) for when the run step was created. - * @property {string} assistant_id - The ID of the assistant associated with the run step. - * @property {string} thread_id - The ID of the thread that was run. - * @property {string} run_id - The ID of the run that this run step is a part of. 
- * @property {string} type - The type of run step, either 'message_creation' or 'tool_calls'. - * @property {string} status - The status of the run step, can be 'in_progress', 'cancelled', 'failed', 'completed', or 'expired'. - * @property {Object} step_details - The details of the run step. - * @property {Object} [last_error] - The last error associated with this run step. - * @property {string} last_error.code - One of 'server_error' or 'rate_limit_exceeded'. - * @property {string} last_error.message - A human-readable description of the error. - * @property {number} [expired_at] - The Unix timestamp (in seconds) for when the run step expired. - * @property {number} [cancelled_at] - The Unix timestamp (in seconds) for when the run step was cancelled. - * @property {number} [failed_at] - The Unix timestamp (in seconds) for when the run step failed. - * @property {number} [completed_at] - The Unix timestamp (in seconds) for when the run step completed. - * @property {Object} [metadata] - Metadata associated with this run step, a map of up to 16 key-value pairs. - */ - -/** - * Initializes a new thread or adds messages to an existing thread. + * Sorts, processes, and flattens messages to a single string. * - * @param {Object} params - The parameters for initializing a thread. - * @param {OpenAI} params.openai - The OpenAI client instance. - * @param {Object} params.body - The body of the request. - * @param {Message[]} params.body.messages - A list of messages to start the thread with. - * @param {Object} [params.body.metadata] - Optional metadata for the thread. - * @param {string} [params.thread_id] - Optional existing thread ID. If provided, a message will be added to this thread. - * @return {Promise} A promise that resolves to the newly created thread object or the updated thread object. + * @param {Object} params - Params for creating the onTextProgress function. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {string} params.conversationId - The current conversation ID. + * @param {string} params.userMessageId - The user message ID; response's `parentMessageId`. + * @param {string} params.messageId - The response message ID. + * @param {string} params.thread_id - The current thread ID. + * @returns {void} */ -async function initThread({ openai, body, thread_id: _thread_id }) { - let thread = {}; - const messages = []; - if (_thread_id) { - const message = await openai.beta.threads.messages.create(_thread_id, body.messages[0]); - messages.push(message); - } else { - thread = await openai.beta.threads.create(body); - } +async function createOnTextProgress({ + openai, + conversationId, + userMessageId, + messageId, + thread_id, +}) { + openai.responseMessage = { + conversationId, + parentMessageId: userMessageId, + role: 'assistant', + messageId, + content: [], + }; - const thread_id = _thread_id ?? thread.id; - return { messages, thread_id, ...thread }; -} + openai.responseText = ''; -/** - * Creates a run on a thread using the OpenAI API. - * - * @param {Object} params - The parameters for creating a run. - * @param {OpenAI} params.openai - The OpenAI client instance. - * @param {string} params.thread_id - The ID of the thread to run. - * @param {Object} params.body - The body of the request to create a run. - * @param {string} params.body.assistant_id - The ID of the assistant to use for this run. - * @param {string} [params.body.model] - Optional. The ID of the model to be used for this run. - * @param {string} [params.body.instructions] - Optional. 
Override the default system message of the assistant. - * @param {Object[]} [params.body.tools] - Optional. Override the tools the assistant can use for this run. - * @param {string[]} [params.body.file_ids] - Optional. List of File IDs the assistant can use for this run. - * @param {Object} [params.body.metadata] - Optional. Metadata for the run. - * @return {Promise} A promise that resolves to the created run object. - */ -async function createRun({ openai, thread_id, body }) { - const run = await openai.beta.threads.runs.create(thread_id, body); - return run; -} + openai.addContentData = (data) => { + const { type, index } = data; + openai.responseMessage.content[index] = { type, [type]: data[type] }; -// /** -// * Retrieves all steps of a run. -// * -// * @param {Object} params - The parameters for the retrieveRunSteps function. -// * @param {OpenAI} params.openai - The OpenAI client instance. -// * @param {string} params.thread_id - The ID of the thread associated with the run. -// * @param {string} params.run_id - The ID of the run to retrieve steps for. -// * @return {Promise} A promise that resolves to an array of RunStep objects. -// */ -// async function retrieveRunSteps({ openai, thread_id, run_id }) { -// const runSteps = await openai.beta.threads.runs.steps.list(thread_id, run_id); -// return runSteps; -// } - -/** - * Delays the execution for a specified number of milliseconds. - * - * @param {number} ms - The number of milliseconds to delay. - * @return {Promise} A promise that resolves after the specified delay. - */ -function sleep(ms) { - return new Promise((resolve) => setTimeout(resolve, ms)); -} - -/** - * Waits for a run to complete by repeatedly checking its status. It uses a RunManager instance to fetch and manage run steps based on the run status. - * - * @param {Object} params - The parameters for the waitForRun function. - * @param {OpenAI} params.openai - The OpenAI client instance. - * @param {string} params.run_id - The ID of the run to wait for. - * @param {string} params.thread_id - The ID of the thread associated with the run. - * @param {RunManager} params.runManager - The RunManager instance to manage run steps. - * @param {number} params.pollIntervalMs - The interval for polling the run status, default is 500 milliseconds. - * @return {Promise} A promise that resolves to the last fetched run object. - */ -async function waitForRun({ openai, run_id, thread_id, runManager, pollIntervalMs = 500 }) { - const timeout = 18000; // 18 seconds - let timeElapsed = 0; - let run; - - // this runManager will be passed in from the caller - // const runManager = new RunManager({ - // 'in_progress': (step) => { /* ... */ }, - // 'queued': (step) => { /* ... 
*/ }, - // }); - - while (timeElapsed < timeout) { - run = await openai.beta.threads.runs.retrieve(thread_id, run_id); - console.log(`Run status: ${run.status}`); - - if (!['in_progress', 'queued'].includes(run.status)) { - await runManager.fetchRunSteps({ - openai, - thread_id: thread_id, - run_id: run_id, - runStatus: run.status, - final: true, - }); - break; + if (type === ContentTypes.TEXT) { + openai.responseText += data[type].value; + return; } - // may use in future - // await runManager.fetchRunSteps({ - // openai, - // thread_id: thread_id, - // run_id: run_id, - // runStatus: run.status, - // }); + const contentData = { + index, + type, + [type]: data[type], + messageId, + thread_id, + conversationId, + }; - await sleep(pollIntervalMs); - timeElapsed += pollIntervalMs; - } - - return run; + logger.debug('Content data:', contentData); + sendMessage(openai.res, contentData); + }; } /** * Retrieves the response from an OpenAI run. * * @param {Object} params - The parameters for getting the response. - * @param {OpenAI} params.openai - The OpenAI client instance. + * @param {OpenAIClient} params.openai - The OpenAI client instance. * @param {string} params.run_id - The ID of the run to get the response for. * @param {string} params.thread_id - The ID of the thread associated with the run. - * @return {Promise} + * @return {Promise} */ async function getResponse({ openai, run_id, thread_id }) { const run = await waitForRun({ openai, run_id, thread_id, pollIntervalMs: 500 }); - if (run.status === 'completed') { - const messages = await openai.beta.threads.messages.list(thread_id, { - order: 'asc', - }); + if (run.status === RunStatus.COMPLETED) { + const messages = await openai.beta.threads.messages.list(thread_id, defaultOrderQuery); const newMessages = messages.data.filter((msg) => msg.run_id === run_id); return newMessages; - } else if (run.status === 'requires_action') { + } else if (run.status === RunStatus.REQUIRES_ACTION) { const actions = []; run.required_action?.submit_tool_outputs.tool_calls.forEach((item) => { const functionCall = item.function; @@ -259,7 +96,6 @@ async function getResponse({ openai, run_id, thread_id }) { tool: functionCall.name, toolInput: args, toolCallId: item.id, - log: '', run_id, thread_id, }); @@ -273,90 +109,432 @@ async function getResponse({ openai, run_id, thread_id }) { } /** - * Initializes a RunManager with handlers, then invokes waitForRun to monitor and manage an OpenAI run. - * - * @param {Object} params - The parameters for managing and monitoring the run. - * @param {OpenAI} params.openai - The OpenAI client instance. - * @param {string} params.run_id - The ID of the run to manage and monitor. - * @param {string} params.thread_id - The ID of the thread associated with the run. - * @return {Promise} A promise that resolves to an object containing the run and managed steps. + * Filters the steps to keep only the most recent instance of each unique step. + * @param {RunStep[]} steps - The array of RunSteps to filter. + * @return {RunStep[]} The filtered array of RunSteps. 
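+ *
+ * @example
+ * // Illustrative only: two snapshots of the same step id; the snapshot whose latest
+ * // timestamp (created/expired/cancelled/failed/completed) is greater wins.
+ * const deduped = filterSteps([
+ *   { id: 'step_abc', created_at: 1700000000, completed_at: 0 },
+ *   { id: 'step_abc', created_at: 1700000000, completed_at: 1700000005 },
+ * ]);
+ * // deduped.length === 1 && deduped[0].completed_at === 1700000005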
*/ -async function handleRun({ openai, run_id, thread_id }) { - let steps; - let messages; - const runManager = new RunManager({ - // 'in_progress': async ({ step, final, isLast }) => { - // // Define logic for handling steps with 'in_progress' status - // }, - // 'queued': async ({ step, final, isLast }) => { - // // Define logic for handling steps with 'queued' status - // }, - final: async ({ step, runStatus, stepsByStatus }) => { - console.log(`Final step for ${run_id} with status ${runStatus}`); - console.dir(step, { depth: null }); +function filterSteps(steps = []) { + if (steps.length <= 1) { + return steps; + } + const stepMap = new Map(); - const promises = []; - promises.push( - openai.beta.threads.messages.list(thread_id, { - order: 'asc', - }), - ); + steps.forEach((step) => { + if (!step) { + return; + } - const finalSteps = stepsByStatus[runStatus]; + const effectiveTimestamp = Math.max( + step.created_at, + step.expired_at || 0, + step.cancelled_at || 0, + step.failed_at || 0, + step.completed_at || 0, + ); - // loop across all statuses, may use in the future - // for (const [_status, stepsPromises] of Object.entries(stepsByStatus)) { - // promises.push(...stepsPromises); - // } - for (const stepPromise of finalSteps) { - promises.push(stepPromise); + if (!stepMap.has(step.id) || effectiveTimestamp > stepMap.get(step.id).effectiveTimestamp) { + const latestStep = { ...step, effectiveTimestamp }; + if (latestStep.last_error) { + // testing to see if we ever step into this } - - const resolved = await Promise.all(promises); - const res = resolved.shift(); - messages = res.data.filter((msg) => msg.run_id === run_id); - resolved.push(step); - steps = resolved; - }, + stepMap.set(step.id, latestStep); + } }); - const run = await waitForRun({ openai, run_id, thread_id, runManager, pollIntervalMs: 500 }); - - return { run, steps, messages }; -} - -/** - * Maps messages to their corresponding steps. Steps with message creation will be paired with their messages, - * while steps without message creation will be returned as is. - * - * @param {RunStep[]} steps - An array of steps from the run. - * @param {Message[]} messages - An array of message objects. - * @returns {(StepMessage | RunStep)[]} An array where each element is either a step with its corresponding message (StepMessage) or a step without a message (RunStep). - */ -function mapMessagesToSteps(steps, messages) { - // Create a map of messages indexed by their IDs for efficient lookup - const messageMap = messages.reduce((acc, msg) => { - acc[msg.id] = msg; - return acc; - }, {}); - - // Map each step to its corresponding message, or return the step as is if no message ID is present - return steps.map((step) => { - const messageId = step.step_details?.message_creation?.message_id; - - if (messageId && messageMap[messageId]) { - return { step, message: messageMap[messageId] }; - } + return Array.from(stepMap.values()).map((step) => { + delete step.effectiveTimestamp; return step; }); } +/** + * @callback InProgressFunction + * @param {Object} params - The parameters for the in progress step. + * @param {RunStep} params.step - The step object with details about the message creation. + * @returns {Promise} - A promise that resolves when the step is processed. 
+ */ + +function hasToolCallChanged(previousCall, currentCall) { + return JSON.stringify(previousCall) !== JSON.stringify(currentCall); +} + +/** + * Creates a handler function for steps in progress, specifically for + * processing messages and managing seen completed messages. + * + * @param {OpenAIClient} openai - The OpenAI client instance. + * @param {string} thread_id - The ID of the thread the run is in. + * @param {ThreadMessage[]} messages - The accumulated messages for the run. + * @return {InProgressFunction} a function to handle steps in progress. + */ +function createInProgressHandler(openai, thread_id, messages) { + openai.index = 0; + openai.mappedOrder = new Map(); + openai.seenToolCalls = new Map(); + openai.processedFileIds = new Set(); + openai.completeToolCallSteps = new Set(); + openai.seenCompletedMessages = new Set(); + + /** + * The in_progress function for handling message creation steps. + * + * @type {InProgressFunction} + */ + async function in_progress({ step }) { + if (step.type === StepTypes.TOOL_CALLS) { + const { tool_calls } = step.step_details; + + for (const _toolCall of tool_calls) { + /** @type {StepToolCall} */ + const toolCall = _toolCall; + const previousCall = openai.seenToolCalls.get(toolCall.id); + + // If the tool call isn't new and hasn't changed + if (previousCall && !hasToolCallChanged(previousCall, toolCall)) { + continue; + } + + let toolCallIndex = openai.mappedOrder.get(toolCall.id); + if (toolCallIndex === undefined) { + // New tool call + toolCallIndex = openai.index; + openai.mappedOrder.set(toolCall.id, openai.index); + openai.index++; + } + + if (step.status === StepStatus.IN_PROGRESS) { + toolCall.progress = + previousCall && previousCall.progress + ? Math.min(previousCall.progress + 0.2, 0.95) + : 0.01; + } else { + toolCall.progress = 1; + openai.completeToolCallSteps.add(step.id); + } + + if ( + toolCall.type === ToolCallTypes.CODE_INTERPRETER && + step.status === StepStatus.COMPLETED + ) { + const { outputs } = toolCall[toolCall.type]; + + for (const output of outputs) { + if (output.type !== 'image') { + continue; + } + + if (openai.processedFileIds.has(output.image?.file_id)) { + continue; + } + + const { file_id } = output.image; + const file = await retrieveAndProcessFile({ + openai, + file_id, + basename: `${file_id}.png`, + }); + // toolCall.asset_pointer = file.filepath; + const prelimImage = { + file_id, + filename: path.basename(file.filepath), + filepath: file.filepath, + height: file.height, + width: file.width, + }; + // check if every key has a value before adding to content + const prelimImageKeys = Object.keys(prelimImage); + const validImageFile = prelimImageKeys.every((key) => prelimImage[key]); + + if (!validImageFile) { + continue; + } + + const image_file = { + [ContentTypes.IMAGE_FILE]: prelimImage, + type: ContentTypes.IMAGE_FILE, + index: openai.index, + }; + openai.addContentData(image_file); + openai.processedFileIds.add(file_id); + openai.index++; + } + } else if ( + toolCall.type === ToolCallTypes.FUNCTION && + step.status === StepStatus.COMPLETED && + imageGenTools.has(toolCall[toolCall.type].name) + ) { + /* If a change is detected, skip image generation tools as already processed */ + openai.seenToolCalls.set(toolCall.id, toolCall); + continue; + } + + openai.addContentData({ + [ContentTypes.TOOL_CALL]: toolCall, + index: toolCallIndex, + type: ContentTypes.TOOL_CALL, + }); + + // Update the stored tool call + openai.seenToolCalls.set(toolCall.id, toolCall); + } + } else if (step.type === 
StepTypes.MESSAGE_CREATION && step.status === StepStatus.COMPLETED) { + const { message_id } = step.step_details.message_creation; + if (openai.seenCompletedMessages.has(message_id)) { + return; + } + + openai.seenCompletedMessages.add(message_id); + + const message = await openai.beta.threads.messages.retrieve(thread_id, message_id); + messages.push(message); + + let messageIndex = openai.mappedOrder.get(step.id); + if (messageIndex === undefined) { + // New message + messageIndex = openai.index; + openai.mappedOrder.set(step.id, openai.index); + openai.index++; + } + + const result = await processMessages(openai, [message]); + openai.addContentData({ + [ContentTypes.TEXT]: { value: result.text }, + type: ContentTypes.TEXT, + index: messageIndex, + }); + + // Create the Factory Function to stream the message + const { onProgress: progressCallback } = createOnProgress({ + // todo: add option to save partialText to db + // onProgress: () => {}, + }); + + // This creates a function that attaches all of the parameters + // specified here to each SSE message generated by the TextStream + const onProgress = progressCallback({ + res: openai.res, + index: messageIndex, + messageId: openai.responseMessage.messageId, + type: ContentTypes.TEXT, + stream: true, + thread_id, + }); + + // Create a small buffer before streaming begins + await sleep(500); + + const stream = new TextStream(result.text, { delay: 9 }); + await stream.processTextStream(onProgress); + } + } + + return in_progress; +} + +/** + * Initializes a RunManager with handlers, then invokes waitForRun to monitor and manage an OpenAI run. + * + * @param {Object} params - The parameters for managing and monitoring the run. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {string} params.run_id - The ID of the run to manage and monitor. + * @param {string} params.thread_id - The ID of the thread associated with the run. + * @param {RunStep[]} params.accumulatedSteps - The accumulated steps for the run. + * @param {ThreadMessage[]} params.accumulatedMessages - The accumulated messages for the run. + * @param {InProgressFunction} [params.in_progress] - The `in_progress` function from a previous run. + * @return {Promise} A promise that resolves to an object containing the run and managed steps. + */ +async function runAssistant({ + openai, + run_id, + thread_id, + accumulatedSteps = [], + accumulatedMessages = [], + in_progress: inProgress, +}) { + let steps = accumulatedSteps; + let messages = accumulatedMessages; + const in_progress = inProgress ?? 
createInProgressHandler(openai, thread_id, messages); + openai.in_progress = in_progress; + + const runManager = new RunManager({ + in_progress, + final: async ({ step, runStatus, stepsByStatus }) => { + logger.debug(`[runAssistant] Final step for ${run_id} with status ${runStatus}`, step); + + const promises = []; + // promises.push( + // openai.beta.threads.messages.list(thread_id, defaultOrderQuery), + // ); + + // const finalSteps = stepsByStatus[runStatus]; + // for (const stepPromise of finalSteps) { + // promises.push(stepPromise); + // } + + // loop across all statuses + for (const [_status, stepsPromises] of Object.entries(stepsByStatus)) { + promises.push(...stepsPromises); + } + + const resolved = await Promise.all(promises); + const finalSteps = filterSteps(steps.concat(resolved)); + + if (step.type === StepTypes.MESSAGE_CREATION) { + const incompleteToolCallSteps = finalSteps.filter( + (s) => s && s.type === StepTypes.TOOL_CALLS && !openai.completeToolCallSteps.has(s.id), + ); + for (const incompleteToolCallStep of incompleteToolCallSteps) { + await in_progress({ step: incompleteToolCallStep }); + } + } + await in_progress({ step }); + // const res = resolved.shift(); + // messages = messages.concat(res.data.filter((msg) => msg && msg.run_id === run_id)); + resolved.push(step); + /* Note: no issues without deep cloning, but it's safer to do so */ + steps = klona(finalSteps); + }, + }); + + /** @type {TCustomConfig.endpoints.assistants} */ + const assistantsEndpointConfig = openai.req.app.locals?.[EModelEndpoint.assistants] ?? {}; + const { pollIntervalMs, timeoutMs } = assistantsEndpointConfig; + + const run = await waitForRun({ + openai, + run_id, + thread_id, + runManager, + pollIntervalMs, + timeout: timeoutMs, + }); + + if (!run.required_action) { + // const { messages: sortedMessages, text } = await processMessages(openai, messages); + // return { run, steps, messages: sortedMessages, text }; + const sortedMessages = messages.sort((a, b) => a.created_at - b.created_at); + return { run, steps, messages: sortedMessages }; + } + + const { submit_tool_outputs } = run.required_action; + const actions = submit_tool_outputs.tool_calls.map((item) => { + const functionCall = item.function; + const args = JSON.parse(functionCall.arguments); + return { + tool: functionCall.name, + toolInput: args, + toolCallId: item.id, + run_id, + thread_id, + }; + }); + + const outputs = await processRequiredActions(openai, actions); + + const toolRun = await openai.beta.threads.runs.submitToolOutputs(run.thread_id, run.id, outputs); + + // Recursive call with accumulated steps and messages + return await runAssistant({ + openai, + run_id: toolRun.id, + thread_id, + accumulatedSteps: steps, + accumulatedMessages: messages, + in_progress, + }); +} + +/** + * Sorts, processes, and flattens messages to a single string. + * + * @param {OpenAIClient} openai - The OpenAI client instance. + * @param {ThreadMessage[]} messages - An array of messages. + * @returns {Promise<{messages: ThreadMessage[], text: string}>} The sorted messages and the flattened text. 
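+ *
+ * @example
+ * // Hedged sketch: assumes `openai` came from initializeClient and already carries a
+ * // `processedFileIds` Set (as set up by createInProgressHandler), and that `thread_id`
+ * // refers to an existing thread.
+ * const page = await openai.beta.threads.messages.list(thread_id, defaultOrderQuery);
+ * const { messages: sorted, text } = await processMessages(openai, page.data);
+ * logger.debug('Flattened assistant text:', text);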
+ */ +async function processMessages(openai, messages = []) { + const sorted = messages.sort((a, b) => a.created_at - b.created_at); + + let text = ''; + for (const message of sorted) { + message.files = []; + for (const content of message.content) { + const processImageFile = + content.type === 'image_file' && !openai.processedFileIds.has(content.image_file?.file_id); + if (processImageFile) { + const { file_id } = content.image_file; + + const file = await retrieveAndProcessFile({ openai, file_id, basename: `${file_id}.png` }); + openai.processedFileIds.add(file_id); + message.files.push(file); + continue; + } + + text += (content.text?.value ?? '') + ' '; + + // Process annotations if they exist + if (!content.text?.annotations) { + continue; + } + + logger.debug('Processing annotations:', content.text.annotations); + for (const annotation of content.text.annotations) { + logger.debug('Current annotation:', annotation); + let file; + const processFilePath = + annotation.file_path && !openai.processedFileIds.has(annotation.file_path?.file_id); + + if (processFilePath) { + const basename = imageExtRegex.test(annotation.text) + ? path.basename(annotation.text) + : null; + file = await retrieveAndProcessFile({ + openai, + file_id: annotation.file_path.file_id, + basename, + }); + openai.processedFileIds.add(annotation.file_path.file_id); + } + + const processFileCitation = + annotation.file_citation && + !openai.processedFileIds.has(annotation.file_citation?.file_id); + + if (processFileCitation) { + file = await retrieveAndProcessFile({ + openai, + file_id: annotation.file_citation.file_id, + unknownType: true, + }); + openai.processedFileIds.add(annotation.file_citation.file_id); + } + + if (!file && (annotation.file_path || annotation.file_citation)) { + const { file_id } = annotation.file_citation || annotation.file_path || {}; + file = await retrieveAndProcessFile({ openai, file_id, unknownType: true }); + openai.processedFileIds.add(file_id); + } + + if (!file) { + continue; + } + + if (file.purpose && file.purpose === FilePurpose.Assistants) { + text = text.replace(annotation.text, file.filename); + } else if (file.filepath) { + text = text.replace(annotation.text, file.filepath); + } + + message.files.push(file); + } + } + } + + return { messages: sorted, text }; +} + module.exports = { - initThread, - createRun, - waitForRun, getResponse, - handleRun, - sleep, - mapMessagesToSteps, + runAssistant, + processMessages, + createOnTextProgress, }; diff --git a/api/server/services/Config/EndpointService.js b/api/server/services/Config/EndpointService.js index 998e7a83d..eeab6a4c7 100644 --- a/api/server/services/Config/EndpointService.js +++ b/api/server/services/Config/EndpointService.js @@ -2,6 +2,7 @@ const { EModelEndpoint } = require('librechat-data-provider'); const { OPENAI_API_KEY: openAIApiKey, + ASSISTANTS_API_KEY: assistantsApiKey, AZURE_API_KEY: azureOpenAIApiKey, ANTHROPIC_API_KEY: anthropicApiKey, CHATGPT_TOKEN: chatGPTToken, @@ -28,7 +29,7 @@ module.exports = { userProvidedOpenAI, googleKey, [EModelEndpoint.openAI]: isUserProvided(openAIApiKey), - [EModelEndpoint.assistant]: isUserProvided(openAIApiKey), + [EModelEndpoint.assistants]: isUserProvided(assistantsApiKey), [EModelEndpoint.azureOpenAI]: isUserProvided(azureOpenAIApiKey), [EModelEndpoint.chatGPTBrowser]: isUserProvided(chatGPTToken), [EModelEndpoint.anthropic]: isUserProvided(anthropicApiKey), diff --git a/api/server/services/Config/handleRateLimits.js b/api/server/services/Config/handleRateLimits.js new file 
mode 100644 index 000000000..d40ccfb4f --- /dev/null +++ b/api/server/services/Config/handleRateLimits.js @@ -0,0 +1,22 @@ +/** + * + * @param {TCustomConfig['rateLimits'] | undefined} rateLimits + */ +const handleRateLimits = (rateLimits) => { + if (!rateLimits) { + return; + } + const { fileUploads } = rateLimits; + if (!fileUploads) { + return; + } + + process.env.FILE_UPLOAD_IP_MAX = fileUploads.ipMax ?? process.env.FILE_UPLOAD_IP_MAX; + process.env.FILE_UPLOAD_IP_WINDOW = + fileUploads.ipWindowInMinutes ?? process.env.FILE_UPLOAD_IP_WINDOW; + process.env.FILE_UPLOAD_USER_MAX = fileUploads.userMax ?? process.env.FILE_UPLOAD_USER_MAX; + process.env.FILE_UPLOAD_USER_WINDOW = + fileUploads.userWindowInMinutes ?? process.env.FILE_UPLOAD_USER_WINDOW; +}; + +module.exports = handleRateLimits; diff --git a/api/server/services/Config/loadCustomConfig.js b/api/server/services/Config/loadCustomConfig.js index a90083b2f..eecf8eff8 100644 --- a/api/server/services/Config/loadCustomConfig.js +++ b/api/server/services/Config/loadCustomConfig.js @@ -7,17 +7,22 @@ const { logger } = require('~/config'); const projectRoot = path.resolve(__dirname, '..', '..', '..', '..'); const configPath = path.resolve(projectRoot, 'librechat.yaml'); +let i = 0; + /** * Load custom configuration files and caches the object if the `cache` field at root is true. * Validation via parsing the config file with the config schema. * @function loadCustomConfig * @returns {Promise} A promise that resolves to null or the custom config object. * */ - async function loadCustomConfig() { const customConfig = loadYaml(configPath); if (!customConfig) { - logger.info('Custom config file missing or YAML format invalid.'); + i === 0 && + logger.info( + 'Custom config file missing or YAML format invalid.\n\nCheck out the latest config file guide for configurable options and features.\nhttps://docs.librechat.ai/install/configuration/custom_config.html\n\n', + ); + i === 0 && i++; return null; } @@ -28,6 +33,7 @@ async function loadCustomConfig() { } else { logger.info('Custom config file loaded:'); logger.info(JSON.stringify(customConfig, null, 2)); + logger.debug('Custom config:', customConfig); } if (customConfig.cache) { diff --git a/api/server/services/Config/loadDefaultEConfig.js b/api/server/services/Config/loadDefaultEConfig.js index 34ab05d8a..0f1c7dcbb 100644 --- a/api/server/services/Config/loadDefaultEConfig.js +++ b/api/server/services/Config/loadDefaultEConfig.js @@ -9,10 +9,11 @@ const { config } = require('./EndpointService'); */ async function loadDefaultEndpointsConfig() { const { google, gptPlugins } = await loadAsyncEndpoints(); - const { openAI, bingAI, anthropic, azureOpenAI, chatGPTBrowser } = config; + const { openAI, assistants, bingAI, anthropic, azureOpenAI, chatGPTBrowser } = config; let enabledEndpoints = [ EModelEndpoint.openAI, + EModelEndpoint.assistants, EModelEndpoint.azureOpenAI, EModelEndpoint.google, EModelEndpoint.bingAI, @@ -31,6 +32,7 @@ async function loadDefaultEndpointsConfig() { const endpointConfig = { [EModelEndpoint.openAI]: openAI, + [EModelEndpoint.assistants]: assistants, [EModelEndpoint.azureOpenAI]: azureOpenAI, [EModelEndpoint.google]: google, [EModelEndpoint.bingAI]: bingAI, diff --git a/api/server/services/Config/loadDefaultModels.js b/api/server/services/Config/loadDefaultModels.js index e22a9b276..29be47822 100644 --- a/api/server/services/Config/loadDefaultModels.js +++ b/api/server/services/Config/loadDefaultModels.js @@ -7,10 +7,6 @@ const { getChatGPTBrowserModels, } = 
require('~/server/services/ModelService'); -const fitlerAssistantModels = (str) => { - return /gpt-4|gpt-3\\.5/i.test(str) && !/vision|instruct/i.test(str); -}; - /** * Loads the default models for the application. * @async @@ -28,6 +24,7 @@ async function loadDefaultModels(req) { azure: useAzurePlugins, plugins: true, }); + const assistant = await getOpenAIModels({ assistants: true }); return { [EModelEndpoint.openAI]: openAI, @@ -37,7 +34,7 @@ async function loadDefaultModels(req) { [EModelEndpoint.azureOpenAI]: azureOpenAI, [EModelEndpoint.bingAI]: ['BingAI', 'Sydney'], [EModelEndpoint.chatGPTBrowser]: chatGPTBrowser, - [EModelEndpoint.assistant]: openAI.filter(fitlerAssistantModels), + [EModelEndpoint.assistants]: assistant, }; } diff --git a/api/server/services/Endpoints/assistant/addTitle.js b/api/server/services/Endpoints/assistant/addTitle.js new file mode 100644 index 000000000..691153915 --- /dev/null +++ b/api/server/services/Endpoints/assistant/addTitle.js @@ -0,0 +1,28 @@ +const { CacheKeys } = require('librechat-data-provider'); +const { saveConvo } = require('~/models/Conversation'); +const getLogStores = require('~/cache/getLogStores'); +const { isEnabled } = require('~/server/utils'); + +const addTitle = async (req, { text, responseText, conversationId, client }) => { + const { TITLE_CONVO = 'true' } = process.env ?? {}; + if (!isEnabled(TITLE_CONVO)) { + return; + } + + if (client.options.titleConvo === false) { + return; + } + + const titleCache = getLogStores(CacheKeys.GEN_TITLE); + const key = `${req.user.id}-${conversationId}`; + + const title = await client.titleConvo({ text, conversationId, responseText }); + await titleCache.set(key, title); + + await saveConvo(req.user.id, { + conversationId, + title, + }); +}; + +module.exports = addTitle; diff --git a/api/server/services/Endpoints/assistant/buildOptions.js b/api/server/services/Endpoints/assistant/buildOptions.js new file mode 100644 index 000000000..4197d976b --- /dev/null +++ b/api/server/services/Endpoints/assistant/buildOptions.js @@ -0,0 +1,15 @@ +const buildOptions = (endpoint, parsedBody) => { + // eslint-disable-next-line no-unused-vars + const { promptPrefix, chatGptLabel, resendImages, imageDetail, ...rest } = parsedBody; + const endpointOption = { + endpoint, + promptPrefix, + modelOptions: { + ...rest, + }, + }; + + return endpointOption; +}; + +module.exports = buildOptions; diff --git a/api/server/services/Endpoints/assistant/index.js b/api/server/services/Endpoints/assistant/index.js new file mode 100644 index 000000000..772b1efb1 --- /dev/null +++ b/api/server/services/Endpoints/assistant/index.js @@ -0,0 +1,9 @@ +const addTitle = require('./addTitle'); +const buildOptions = require('./buildOptions'); +const initializeClient = require('./initializeClient'); + +module.exports = { + addTitle, + buildOptions, + initializeClient, +}; diff --git a/api/server/services/Endpoints/assistant/initializeClient.js b/api/server/services/Endpoints/assistant/initializeClient.js new file mode 100644 index 000000000..886a037ad --- /dev/null +++ b/api/server/services/Endpoints/assistant/initializeClient.js @@ -0,0 +1,80 @@ +const OpenAI = require('openai'); +const { HttpsProxyAgent } = require('https-proxy-agent'); +const { EModelEndpoint } = require('librechat-data-provider'); +const { + getUserKey, + getUserKeyExpiry, + checkUserKeyExpiry, +} = require('~/server/services/UserService'); +const OpenAIClient = require('~/app/clients/OpenAIClient'); + +const initializeClient = async ({ req, res, endpointOption, 
initAppClient = false }) => { + const { PROXY, OPENAI_ORGANIZATION, ASSISTANTS_API_KEY, ASSISTANTS_BASE_URL } = process.env; + + const opts = {}; + const baseURL = ASSISTANTS_BASE_URL ?? null; + + if (baseURL) { + opts.baseURL = baseURL; + } + + if (PROXY) { + opts.httpAgent = new HttpsProxyAgent(PROXY); + } + + if (OPENAI_ORGANIZATION) { + opts.organization = OPENAI_ORGANIZATION; + } + + const credentials = ASSISTANTS_API_KEY; + + const isUserProvided = credentials === 'user_provided'; + + let userKey = null; + if (isUserProvided) { + const expiresAt = getUserKeyExpiry({ userId: req.user.id, name: EModelEndpoint.assistants }); + checkUserKeyExpiry( + expiresAt, + 'Your Assistants API key has expired. Please provide your API key again.', + ); + userKey = await getUserKey({ userId: req.user.id, name: EModelEndpoint.assistants }); + } + + let apiKey = isUserProvided ? userKey : credentials; + + if (!apiKey) { + throw new Error('API key not provided.'); + } + + /** @type {OpenAIClient} */ + const openai = new OpenAI({ + apiKey, + ...opts, + }); + openai.req = req; + openai.res = res; + + if (endpointOption && initAppClient) { + const clientOptions = { + reverseProxyUrl: baseURL, + proxy: PROXY ?? null, + req, + res, + ...endpointOption, + }; + + const client = new OpenAIClient(apiKey, clientOptions); + return { + client, + openai, + openAIApiKey: apiKey, + }; + } + + return { + openai, + openAIApiKey: apiKey, + }; +}; + +module.exports = initializeClient; diff --git a/api/server/services/Files/Firebase/crud.js b/api/server/services/Files/Firebase/crud.js index 68f534bcb..0567d2afa 100644 --- a/api/server/services/Files/Firebase/crud.js +++ b/api/server/services/Files/Firebase/crud.js @@ -1,5 +1,6 @@ const fetch = require('node-fetch'); const { ref, uploadBytes, getDownloadURL, deleteObject } = require('firebase/storage'); +const { getBufferMetadata } = require('~/server/utils'); const { getFirebaseStorage } = require('./initialize'); /** @@ -41,9 +42,8 @@ async function deleteFile(basePath, fileName) { * @param {string} [params.basePath='images'] - Optional. The base basePath in Firebase Storage where the file will * be stored. Defaults to 'images' if not specified. * - * @returns {Promise} - * A promise that resolves to the file name if the file is successfully uploaded, or null if there - * is an error in initialization or upload. + * @returns {Promise<{ bytes: number, type: string, dimensions: Record} | null>} + * A promise that resolves to the file metadata if the file is successfully saved, or null if there is an error. 
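+ *
+ * @example
+ * // Illustrative only; assumes initializeFirebase() has already run and the URL is reachable.
+ * const metadata = await saveURLToFirebase({
+ *   userId: req.user.id,
+ *   URL: 'https://example.com/generated.png',
+ *   fileName: 'generated.png',
+ * });
+ * // Resolves to { bytes, type, dimensions } on success, or null if the upload fails.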
*/ async function saveURLToFirebase({ userId, URL, fileName, basePath = 'images' }) { const storage = getFirebaseStorage(); @@ -53,10 +53,12 @@ async function saveURLToFirebase({ userId, URL, fileName, basePath = 'images' }) } const storageRef = ref(storage, `${basePath}/${userId.toString()}/${fileName}`); + const response = await fetch(URL); + const buffer = await response.buffer(); try { - await uploadBytes(storageRef, await fetch(URL).then((response) => response.buffer())); - return fileName; + await uploadBytes(storageRef, buffer); + return await getBufferMetadata(buffer); } catch (error) { console.error('Error uploading file to Firebase Storage:', error.message); return null; diff --git a/api/server/services/Files/Firebase/images.js b/api/server/services/Files/Firebase/images.js index 95b600962..7d45f22d7 100644 --- a/api/server/services/Files/Firebase/images.js +++ b/api/server/services/Files/Firebase/images.js @@ -1,7 +1,8 @@ const fs = require('fs'); const path = require('path'); const sharp = require('sharp'); -const { resizeImage } = require('../images/resize'); +const { resizeImageBuffer } = require('../images/resize'); +const { updateUser } = require('~/models/userMethods'); const { saveBufferToFirebase } = require('./crud'); const { updateFile } = require('~/models/File'); const { logger } = require('~/config'); @@ -11,7 +12,7 @@ const { logger } = require('~/config'); * resolution. * * - * @param {Object} req - The request object from Express. It should have a `user` property with an `id` + * @param {Express.Request} req - The request object from Express. It should have a `user` property with an `id` * representing the user, and an `app.locals.paths` object with an `imageOutput` path. * @param {Express.Multer.File} file - The file object, which is part of the request. The file object should * have a `path` property that points to the location of the uploaded file. @@ -26,7 +27,8 @@ const { logger } = require('~/config'); */ async function uploadImageToFirebase(req, file, resolution = 'high') { const inputFilePath = file.path; - const { buffer: resizedBuffer, width, height } = await resizeImage(inputFilePath, resolution); + const inputBuffer = await fs.promises.readFile(inputFilePath); + const { buffer: resizedBuffer, width, height } = await resizeImageBuffer(inputBuffer, resolution); const extension = path.extname(inputFilePath); const userId = req.user.id; @@ -73,15 +75,15 @@ async function prepareImageURL(req, file) { * * @param {object} params - The parameters object. * @param {Buffer} params.buffer - The Buffer containing the avatar image in WebP format. - * @param {object} params.User - The User document (mongoose); TODO: remove direct use of Model, `User` + * @param {string} params.userId - The user ID. * @param {string} params.manual - A string flag indicating whether the update is manual ('true' or 'false'). * @returns {Promise} - A promise that resolves with the URL of the uploaded avatar. * @throws {Error} - Throws an error if Firebase is not initialized or if there is an error in uploading. 
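+ *
+ * @example
+ * // Sketch of the updated signature (a userId string instead of a User document);
+ * // `buffer` is assumed to be a prepared WebP avatar buffer.
+ * const url = await processFirebaseAvatar({ buffer, userId: req.user.id, manual: 'true' });
+ * // When manual === 'true', the user's `avatar` field is persisted via updateUser.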
*/ -async function processFirebaseAvatar({ buffer, User, manual }) { +async function processFirebaseAvatar({ buffer, userId, manual }) { try { const downloadURL = await saveBufferToFirebase({ - userId: User._id.toString(), + userId, buffer, fileName: 'avatar.png', }); @@ -91,8 +93,7 @@ async function processFirebaseAvatar({ buffer, User, manual }) { const url = `${downloadURL}?manual=${isManual}`; if (isManual) { - User.avatar = url; - await User.save(); + await updateUser(userId, { avatar: url }); } return url; diff --git a/api/server/services/Files/Local/crud.js b/api/server/services/Files/Local/crud.js index cd8bdbc5d..a60038e8e 100644 --- a/api/server/services/Files/Local/crud.js +++ b/api/server/services/Files/Local/crud.js @@ -1,8 +1,9 @@ const fs = require('fs'); const path = require('path'); const axios = require('axios'); -const { logger } = require('~/config'); +const { getBufferMetadata } = require('~/server/utils'); const paths = require('~/config/paths'); +const { logger } = require('~/config'); /** * Saves a file to a specified output path with a new filename. @@ -13,7 +14,7 @@ const paths = require('~/config/paths'); * @returns {Promise} The full path of the saved file. * @throws Will throw an error if the file saving process fails. */ -async function saveFile(file, outputPath, outputFilename) { +async function saveLocalFile(file, outputPath, outputFilename) { try { if (!fs.existsSync(outputPath)) { fs.mkdirSync(outputPath, { recursive: true }); @@ -44,9 +45,41 @@ async function saveFile(file, outputPath, outputFilename) { const saveLocalImage = async (req, file, filename) => { const imagePath = req.app.locals.paths.imageOutput; const outputPath = path.join(imagePath, req.user.id ?? ''); - await saveFile(file, outputPath, filename); + await saveLocalFile(file, outputPath, filename); }; +/** + * Saves a buffer to a specified directory on the local file system. + * + * @param {Object} params - The parameters object. + * @param {string} params.userId - The user's unique identifier. This is used to create a user-specific directory. + * @param {Buffer} params.buffer - The buffer to be saved. + * @param {string} params.fileName - The name of the file to be saved. + * @param {string} [params.basePath='images'] - Optional. The base path where the file will be stored. + * Defaults to 'images' if not specified. + * @returns {Promise} - A promise that resolves to the path of the saved file. + */ +async function saveLocalBuffer({ userId, buffer, fileName, basePath = 'images' }) { + try { + const { publicPath, uploads } = paths; + + const directoryPath = path.join(basePath === 'images' ? publicPath : uploads, basePath, userId); + + if (!fs.existsSync(directoryPath)) { + fs.mkdirSync(directoryPath, { recursive: true }); + } + + fs.writeFileSync(path.join(directoryPath, fileName), buffer); + + const filePath = path.posix.join('/', basePath, userId, fileName); + + return filePath; + } catch (error) { + logger.error('[saveLocalBuffer] Error while saving the buffer:', error); + throw error; + } +} + /** * Saves a file from a given URL to a local directory. The function fetches the file using the provided URL, * determines the content type, and saves it to a specified local directory with the correct file extension. @@ -62,20 +95,18 @@ const saveLocalImage = async (req, file, filename) => { * @param {string} [params.basePath='images'] - Optional. The base directory where the file will be saved. * Defaults to 'images' if not specified. 
* - * @returns {Promise} - * A promise that resolves to the file name if the file is successfully saved, or null if there is an error. + * @returns {Promise<{ bytes: number, type: string, dimensions: Record} | null>} + * A promise that resolves to the file metadata if the file is successfully saved, or null if there is an error. */ async function saveFileFromURL({ userId, URL, fileName, basePath = 'images' }) { try { - // Fetch the file from the URL const response = await axios({ url: URL, - responseType: 'stream', + responseType: 'arraybuffer', }); - // Get the content type from the response headers - const contentType = response.headers['content-type']; - let extension = contentType.split('/').pop(); + const buffer = Buffer.from(response.data, 'binary'); + const { bytes, type, dimensions, extension } = await getBufferMetadata(buffer); // Construct the outputPath based on the basePath and userId const outputPath = path.join(paths.publicPath, basePath, userId.toString()); @@ -92,17 +123,15 @@ async function saveFileFromURL({ userId, URL, fileName, basePath = 'images' }) { fileName += `.${extension}`; } - // Create a writable stream for the output path - const outputFilePath = path.join(outputPath, path.basename(fileName)); - const writer = fs.createWriteStream(outputFilePath); + // Save the file to the output path + const outputFilePath = path.join(outputPath, fileName); + fs.writeFileSync(outputFilePath, buffer); - // Pipe the response data to the output file - response.data.pipe(writer); - - return new Promise((resolve, reject) => { - writer.on('finish', () => resolve(fileName)); - writer.on('error', reject); - }); + return { + bytes, + type, + dimensions, + }; } catch (error) { logger.error('[saveFileFromURL] Error while saving the file:', error); return null; @@ -171,4 +200,11 @@ const deleteLocalFile = async (req, file) => { await fs.promises.unlink(filepath); }; -module.exports = { saveFile, saveLocalImage, saveFileFromURL, getLocalFileURL, deleteLocalFile }; +module.exports = { + saveLocalFile, + saveLocalImage, + saveLocalBuffer, + saveFileFromURL, + getLocalFileURL, + deleteLocalFile, +}; diff --git a/api/server/services/Files/Local/images.js b/api/server/services/Files/Local/images.js index 63ed5b2f6..f30594e92 100644 --- a/api/server/services/Files/Local/images.js +++ b/api/server/services/Files/Local/images.js @@ -1,7 +1,8 @@ const fs = require('fs'); const path = require('path'); const sharp = require('sharp'); -const { resizeImage } = require('../images/resize'); +const { resizeImageBuffer } = require('../images/resize'); +const { updateUser } = require('~/models/userMethods'); const { updateFile } = require('~/models/File'); /** @@ -28,7 +29,8 @@ const { updateFile } = require('~/models/File'); */ async function uploadLocalImage(req, file, resolution = 'high') { const inputFilePath = file.path; - const { buffer: resizedBuffer, width, height } = await resizeImage(inputFilePath, resolution); + const inputBuffer = await fs.promises.readFile(inputFilePath); + const { buffer: resizedBuffer, width, height } = await resizeImageBuffer(inputBuffer, resolution); const extension = path.extname(inputFilePath); const { imageOutput } = req.app.locals.paths; @@ -96,17 +98,17 @@ async function prepareImagesLocal(req, file) { } /** - * Uploads a user's avatar to Firebase Storage and returns the URL. + * Uploads a user's avatar to local server storage and returns the URL. * If the 'manual' flag is set to 'true', it also updates the user's avatar URL in the database. 
* * @param {object} params - The parameters object. * @param {Buffer} params.buffer - The Buffer containing the avatar image in WebP format. - * @param {object} params.User - The User document (mongoose); TODO: remove direct use of Model, `User` + * @param {string} params.userId - The user ID. * @param {string} params.manual - A string flag indicating whether the update is manual ('true' or 'false'). * @returns {Promise} - A promise that resolves with the URL of the uploaded avatar. * @throws {Error} - Throws an error if Firebase is not initialized or if there is an error in uploading. */ -async function processLocalAvatar({ buffer, User, manual }) { +async function processLocalAvatar({ buffer, userId, manual }) { const userDir = path.resolve( __dirname, '..', @@ -117,10 +119,11 @@ async function processLocalAvatar({ buffer, User, manual }) { 'client', 'public', 'images', - User._id.toString(), + userId, ); + const fileName = `avatar-${new Date().getTime()}.png`; - const urlRoute = `/images/${User._id.toString()}/${fileName}`; + const urlRoute = `/images/${userId}/${fileName}`; const avatarPath = path.join(userDir, fileName); await fs.promises.mkdir(userDir, { recursive: true }); @@ -130,8 +133,7 @@ async function processLocalAvatar({ buffer, User, manual }) { let url = `${urlRoute}?manual=${isManual}`; if (isManual) { - User.avatar = url; - await User.save(); + await updateUser(userId, { avatar: url }); } return url; diff --git a/api/server/services/Files/OpenAI/crud.js b/api/server/services/Files/OpenAI/crud.js new file mode 100644 index 000000000..6cad1603a --- /dev/null +++ b/api/server/services/Files/OpenAI/crud.js @@ -0,0 +1,49 @@ +const fs = require('fs'); + +/** + * Uploads a file that can be used across various OpenAI services. + * + * @param {Express.Request} req - The request object from Express. It should have a `user` property with an `id` + * representing the user, and an `app.locals.paths` object with an `imageOutput` path. + * @param {Express.Multer.File} file - The file uploaded to the server via multer. + * @param {OpenAI} openai - The initialized OpenAI client. + * @returns {Promise} + */ +async function uploadOpenAIFile(req, file, openai) { + try { + const uploadedFile = await openai.files.create({ + file: fs.createReadStream(file.path), + purpose: 'assistants', + }); + + console.log('File uploaded successfully to OpenAI'); + + return uploadedFile; + } catch (error) { + console.error('Error uploading file to OpenAI:', error.message); + throw error; + } +} + +/** + * Deletes a file previously uploaded to OpenAI. + * + * @param {Express.Request} req - The request object from Express. + * @param {MongoFile} file - The database representation of the uploaded file. + * @param {OpenAI} openai - The initialized OpenAI client. 
+ * @returns {Promise} + */ +async function deleteOpenAIFile(req, file, openai) { + try { + const res = await openai.files.del(file.file_id); + if (!res.deleted) { + throw new Error('OpenAI returned `false` for deleted status'); + } + console.log('File deleted successfully from OpenAI'); + } catch (error) { + console.error('Error deleting file from OpenAI:', error.message); + throw error; + } +} + +module.exports = { uploadOpenAIFile, deleteOpenAIFile }; diff --git a/api/server/services/Files/OpenAI/index.js b/api/server/services/Files/OpenAI/index.js new file mode 100644 index 000000000..a6223d1ee --- /dev/null +++ b/api/server/services/Files/OpenAI/index.js @@ -0,0 +1,5 @@ +const crud = require('./crud'); + +module.exports = { + ...crud, +}; diff --git a/api/server/services/Files/images/avatar.js b/api/server/services/Files/images/avatar.js index 490fc8617..8f4f65b8e 100644 --- a/api/server/services/Files/images/avatar.js +++ b/api/server/services/Files/images/avatar.js @@ -1,42 +1,29 @@ const sharp = require('sharp'); const fs = require('fs').promises; const fetch = require('node-fetch'); -const User = require('~/models/User'); -const { getStrategyFunctions } = require('~/server/services/Files/strategies'); +const { resizeAndConvert } = require('./resize'); const { logger } = require('~/config'); -async function convertToWebP(inputBuffer) { - return sharp(inputBuffer).resize({ width: 150 }).toFormat('webp').toBuffer(); -} - /** * Uploads an avatar image for a user. This function can handle various types of input (URL, Buffer, or File object), - * processes the image to a square format, converts it to WebP format, and then uses a specified file strategy for - * further processing. It performs validation on the user ID and the input type. The function can throw errors for - * invalid input types, fetching issues, or other processing errors. + * processes the image to a square format, converts it to WebP format, and returns the resized buffer. * * @param {Object} params - The parameters object. * @param {string} params.userId - The unique identifier of the user for whom the avatar is being uploaded. - * @param {FileSources} params.fileStrategy - The file handling strategy to use, determining how the avatar is processed. * @param {(string|Buffer|File)} params.input - The input representing the avatar image. Can be a URL (string), * a Buffer, or a File object. - * @param {string} params.manual - A string flag indicating whether the upload process is manual. * * @returns {Promise} - * A promise that resolves to the result of the `processAvatar` function, specific to the chosen file - * strategy. Throws an error if any step in the process fails. + * A promise that resolves to a resized buffer. * * @throws {Error} Throws an error if the user ID is undefined, the input type is invalid, the image fetching fails, * or any other error occurs during the processing. 
*/ -async function uploadAvatar({ userId, fileStrategy, input, manual }) { +async function resizeAvatar({ userId, input }) { try { if (userId === undefined) { throw new Error('User ID is undefined'); } - const _id = userId; - // TODO: remove direct use of Model, `User` - const oldUser = await User.findOne({ _id }); let imageBuffer; if (typeof input === 'string') { @@ -66,13 +53,12 @@ async function uploadAvatar({ userId, fileStrategy, input, manual }) { }) .toBuffer(); - const webPBuffer = await convertToWebP(squaredBuffer); - const { processAvatar } = getStrategyFunctions(fileStrategy); - return await processAvatar({ buffer: webPBuffer, User: oldUser, manual }); + const { buffer } = await resizeAndConvert(squaredBuffer); + return buffer; } catch (error) { logger.error('Error uploading the avatar:', error); throw error; } } -module.exports = uploadAvatar; +module.exports = { resizeAvatar }; diff --git a/api/server/services/Files/images/convert.js b/api/server/services/Files/images/convert.js new file mode 100644 index 000000000..2c5a6ab30 --- /dev/null +++ b/api/server/services/Files/images/convert.js @@ -0,0 +1,69 @@ +const fs = require('fs'); +const path = require('path'); +const sharp = require('sharp'); +const { resizeImageBuffer } = require('./resize'); +const { getStrategyFunctions } = require('../strategies'); + +/** + * Converts an image file or buffer to WebP format with specified resolution. + * + * @param {Express.Request} req - The request object, containing user and app configuration data. + * @param {Buffer | Express.Multer.File} file - The file object, containing either a path or a buffer. + * @param {'low' | 'high'} [resolution='high'] - The desired resolution for the output image. + * @param {string} [basename=''] - The basename of the input file, if it is a buffer. + * @returns {Promise<{filepath: string, bytes: number, width: number, height: number}>} An object containing the path, size, and dimensions of the converted image. + * @throws Throws an error if there is an issue during the conversion process. + */ +async function convertToWebP(req, file, resolution = 'high', basename = '') { + try { + let inputBuffer; + let outputBuffer; + let extension = path.extname(file.path ?? basename).toLowerCase(); + + // Check if the input is a buffer or a file path + if (Buffer.isBuffer(file)) { + inputBuffer = file; + } else if (file && file.path) { + const inputFilePath = file.path; + inputBuffer = await fs.promises.readFile(inputFilePath); + } else { + throw new Error('Invalid input: file must be a buffer or contain a valid path.'); + } + + // Resize the image buffer + const { + buffer: resizedBuffer, + width, + height, + } = await resizeImageBuffer(inputBuffer, resolution); + + // Check if the file is already in WebP format + // If it isn't, convert it: + if (extension === '.webp') { + outputBuffer = resizedBuffer; + } else { + outputBuffer = await sharp(resizedBuffer).toFormat('webp').toBuffer(); + extension = '.webp'; + } + + // Generate a new filename for the output file + const newFileName = + path.basename(file.path ?? basename, path.extname(file.path ?? 
basename)) + extension; + + const { saveBuffer } = getStrategyFunctions(req.app.locals.fileStrategy); + + const savedFilePath = await saveBuffer({ + userId: req.user.id, + buffer: outputBuffer, + fileName: newFileName, + }); + + const bytes = Buffer.byteLength(outputBuffer); + return { filepath: savedFilePath, bytes, width, height }; + } catch (err) { + console.error(err); + throw err; + } +} + +module.exports = { convertToWebP }; diff --git a/api/server/services/Files/images/index.js b/api/server/services/Files/images/index.js index 1438887e6..889b19f20 100644 --- a/api/server/services/Files/images/index.js +++ b/api/server/services/Files/images/index.js @@ -1,13 +1,13 @@ const avatar = require('./avatar'); +const convert = require('./convert'); const encode = require('./encode'); const parse = require('./parse'); const resize = require('./resize'); -const validate = require('./validate'); module.exports = { + ...convert, ...encode, ...parse, ...resize, - ...validate, avatar, }; diff --git a/api/server/services/Files/images/resize.js b/api/server/services/Files/images/resize.js index dd6f24cee..f27173c9d 100644 --- a/api/server/services/Files/images/resize.js +++ b/api/server/services/Files/images/resize.js @@ -1,6 +1,16 @@ const sharp = require('sharp'); -async function resizeImage(inputFilePath, resolution) { +/** + * Resizes an image from a given buffer based on the specified resolution. + * + * @param {Buffer} inputBuffer - The buffer of the image to be resized. + * @param {'low' | 'high'} resolution - The resolution to resize the image to. + * 'low' for a maximum of 512x512 resolution, + * 'high' for a maximum of 768x2000 resolution. + * @returns {Promise<{buffer: Buffer, width: number, height: number}>} An object containing the resized image buffer and its dimensions. + * @throws Will throw an error if the resolution parameter is invalid. + */ +async function resizeImageBuffer(inputBuffer, resolution) { const maxLowRes = 512; const maxShortSideHighRes = 768; const maxLongSideHighRes = 2000; @@ -12,7 +22,7 @@ async function resizeImage(inputFilePath, resolution) { resizeOptions.width = maxLowRes; resizeOptions.height = maxLowRes; } else if (resolution === 'high') { - const metadata = await sharp(inputFilePath).metadata(); + const metadata = await sharp(inputBuffer).metadata(); const isWidthShorter = metadata.width < metadata.height; if (isWidthShorter) { @@ -43,10 +53,28 @@ async function resizeImage(inputFilePath, resolution) { throw new Error('Invalid resolution parameter'); } - const resizedBuffer = await sharp(inputFilePath).rotate().resize(resizeOptions).toBuffer(); + const resizedBuffer = await sharp(inputBuffer).rotate().resize(resizeOptions).toBuffer(); const resizedMetadata = await sharp(resizedBuffer).metadata(); return { buffer: resizedBuffer, width: resizedMetadata.width, height: resizedMetadata.height }; } -module.exports = { resizeImage }; +/** + * Resizes an image buffer to a width of 150 px and converts it to WebP format. + * + * @param {Buffer} inputBuffer - The buffer of the image to be resized. + * @returns {Promise<{ buffer: Buffer, width: number, height: number, bytes: number }>} An object containing the resized image buffer, its size and dimensions. + * @throws Will throw an error if the input buffer is not a valid image.
+ */ +async function resizeAndConvert(inputBuffer) { + const resizedBuffer = await sharp(inputBuffer).resize({ width: 150 }).toFormat('webp').toBuffer(); + const resizedMetadata = await sharp(resizedBuffer).metadata(); + return { + buffer: resizedBuffer, + width: resizedMetadata.width, + height: resizedMetadata.height, + bytes: Buffer.byteLength(resizedBuffer), + }; +} + +module.exports = { resizeImageBuffer, resizeAndConvert }; diff --git a/api/server/services/Files/images/validate.js b/api/server/services/Files/images/validate.js deleted file mode 100644 index 97ae73cf9..000000000 --- a/api/server/services/Files/images/validate.js +++ /dev/null @@ -1,13 +0,0 @@ -const { visionModels } = require('librechat-data-provider'); - -function validateVisionModel(model) { - if (!model) { - return false; - } - - return visionModels.some((visionModel) => model.includes(visionModel)); -} - -module.exports = { - validateVisionModel, -}; diff --git a/api/server/services/Files/process.js b/api/server/services/Files/process.js index 4ee9510b4..5f0d83997 100644 --- a/api/server/services/Files/process.js +++ b/api/server/services/Files/process.js @@ -1,7 +1,25 @@ -const { updateFileUsage, createFile } = require('~/models'); +const path = require('path'); +const { v4 } = require('uuid'); +const mime = require('mime/lite'); +const { + isUUID, + megabyte, + FileContext, + FileSources, + imageExtRegex, + EModelEndpoint, + mergeFileConfig, +} = require('librechat-data-provider'); +const { convertToWebP, resizeAndConvert } = require('~/server/services/Files/images'); +const { initializeClient } = require('~/server/services/Endpoints/assistant'); +const { createFile, updateFileUsage, deleteFiles } = require('~/models/File'); +const { isEnabled, determineFileType } = require('~/server/utils'); +const { LB_QueueAsyncCall } = require('~/server/utils/queue'); const { getStrategyFunctions } = require('./strategies'); const { logger } = require('~/config'); +const { GPTS_DOWNLOAD_IMAGES = 'true' } = process.env; + const processFiles = async (files) => { const promises = []; for (let file of files) { @@ -13,6 +31,99 @@ const processFiles = async (files) => { return await Promise.all(promises); }; +/** + * Enqueues the delete operation to the leaky bucket queue if necessary, or adds it directly to promises. + * + * @param {Express.Request} req - The express request object. + * @param {MongoFile} file - The file object to delete. + * @param {Function} deleteFile - The delete file function. + * @param {Promise[]} promises - The array of promises to await. + * @param {OpenAI | undefined} [openai] - If an OpenAI file, the initialized OpenAI client. + */ +function enqueueDeleteOperation(req, file, deleteFile, promises, openai) { + if (file.source === FileSources.openai) { + // Enqueue to leaky bucket + promises.push( + new Promise((resolve, reject) => { + LB_QueueAsyncCall( + () => deleteFile(req, file, openai), + [], + (err, result) => { + if (err) { + logger.error('Error deleting file from OpenAI source', err); + reject(err); + } else { + resolve(result); + } + }, + ); + }), + ); + } else { + // Add directly to promises + promises.push( + deleteFile(req, file).catch((err) => { + logger.error('Error deleting file', err); + return Promise.reject(err); + }), + ); + } +} + +// TODO: refactor as currently only image files can be deleted this way +// as other filetypes will not reside in public path +/** + * Deletes a list of files from the server filesystem and the database. + * + * @param {Object} params - The params object. 
+ * @param {MongoFile[]} params.files - The file objects to delete. + * @param {Express.Request} params.req - The express request object. + * @param {DeleteFilesBody} params.req.body - The request body. + * @param {string} [params.req.body.assistant_id] - The assistant ID if file uploaded is associated to an assistant. + * + * @returns {Promise} + */ +const processDeleteRequest = async ({ req, files }) => { + const file_ids = files.map((file) => file.file_id); + + const deletionMethods = {}; + const promises = []; + promises.push(deleteFiles(file_ids)); + + /** @type {OpenAI | undefined} */ + let openai; + if (req.body.assistant_id) { + ({ openai } = await initializeClient({ req })); + } + + for (const file of files) { + const source = file.source ?? FileSources.local; + + if (source === FileSources.openai && !openai) { + ({ openai } = await initializeClient({ req })); + } + + if (req.body.assistant_id) { + promises.push(openai.beta.assistants.files.del(req.body.assistant_id, file.file_id)); + } + + if (deletionMethods[source]) { + enqueueDeleteOperation(req, file, deletionMethods[source], promises, openai); + continue; + } + + const { deleteFile } = getStrategyFunctions(source); + if (!deleteFile) { + throw new Error(`Delete function not implemented for ${source}`); + } + + deletionMethods[source] = deleteFile; + enqueueDeleteOperation(req, file, deleteFile, promises, openai); + } + + await Promise.allSettled(promises); +}; + /** * Processes a file URL using a specified file handling strategy. This function accepts a strategy name, * fetches the corresponding file processing functions (for saving and retrieving file URLs), and then @@ -21,25 +132,38 @@ const processFiles = async (files) => { * exception with an appropriate message. * * @param {Object} params - The parameters object. - * @param {FileSources} params.fileStrategy - The file handling strategy to use. Must be a value from the - * `FileSources` enum, which defines different file handling - * strategies (like saving to Firebase, local storage, etc.). + * @param {FileSources} params.fileStrategy - The file handling strategy to use. + * Must be a value from the `FileSources` enum, which defines different file + * handling strategies (like saving to Firebase, local storage, etc.). * @param {string} params.userId - The user's unique identifier. Used for creating user-specific paths or - * references in the file handling process. + * references in the file handling process. * @param {string} params.URL - The URL of the file to be processed. - * @param {string} params.fileName - The name that will be used to save the file. This should include the - * file extension. + * @param {string} params.fileName - The name that will be used to save the file (including extension) * @param {string} params.basePath - The base path or directory where the file will be saved or retrieved from. - * - * @returns {Promise} - * A promise that resolves to the URL of the processed file. It throws an error if the file processing - * fails at any stage. + * @param {FileContext} params.context - The context of the file (e.g., 'avatar', 'image_generation', etc.) + * @returns {Promise} A promise that resolves to the DB representation (MongoFile) + * of the processed file. It throws an error if the file processing fails at any stage. 
*/ -const processFileURL = async ({ fileStrategy, userId, URL, fileName, basePath }) => { +const processFileURL = async ({ fileStrategy, userId, URL, fileName, basePath, context }) => { const { saveURL, getFileURL } = getStrategyFunctions(fileStrategy); try { - await saveURL({ userId, URL, fileName, basePath }); - return await getFileURL({ fileName: `${userId}/${fileName}`, basePath }); + const { bytes, type, dimensions } = await saveURL({ userId, URL, fileName, basePath }); + const filepath = await getFileURL({ fileName: `${userId}/${fileName}`, basePath }); + return await createFile( + { + user: userId, + file_id: v4(), + bytes, + filepath, + filename: fileName, + source: fileStrategy, + type, + context, + width: dimensions.width, + height: dimensions.height, + }, + true, + ); } catch (error) { logger.error(`Error while processing the image with ${fileStrategy}:`, error); throw new Error(`Failed to process the image with ${fileStrategy}. ${error.message}`); @@ -49,7 +173,6 @@ const processFileURL = async ({ fileStrategy, userId, URL, fileName, basePath }) /** * Applies the current strategy for image uploads. * Saves file metadata to the database with an expiry TTL. - * Files must be deleted from the server filesystem manually. * * @param {Object} params - The parameters object. * @param {Express.Request} params.req - The Express request object. @@ -58,7 +181,7 @@ const processFileURL = async ({ fileStrategy, userId, URL, fileName, basePath }) * @param {ImageMetadata} params.metadata - Additional metadata for the file. * @returns {Promise} */ -const processImageUpload = async ({ req, res, file, metadata }) => { +const processImageFile = async ({ req, res, file, metadata }) => { const source = req.app.locals.fileStrategy; const { handleImageUpload } = getStrategyFunctions(source); const { file_id, temp_file_id } = metadata; @@ -71,6 +194,7 @@ const processImageUpload = async ({ req, res, file, metadata }) => { bytes, filepath, filename: file.originalname, + context: FileContext.message_attachment, source, type: 'image/webp', width, @@ -81,8 +205,271 @@ const processImageUpload = async ({ req, res, file, metadata }) => { res.status(200).json({ message: 'File uploaded and processed successfully', ...result }); }; +/** + * Applies the current strategy for image uploads and + * returns minimal file metadata, without saving to the database. + * + * @param {Object} params - The parameters object. + * @param {Express.Request} params.req - The Express request object. + * @param {FileContext} params.context - The context of the file (e.g., 'avatar', 'image_generation', etc.) + * @returns {Promise<{ filepath: string, filename: string, source: string, type: 'image/webp'}>} + */ +const uploadImageBuffer = async ({ req, context }) => { + const source = req.app.locals.fileStrategy; + const { saveBuffer } = getStrategyFunctions(source); + const { buffer, width, height, bytes } = await resizeAndConvert(req.file.buffer); + const file_id = v4(); + const fileName = `img-${file_id}.webp`; + + const filepath = await saveBuffer({ userId: req.user.id, fileName, buffer }); + return await createFile( + { + user: req.user.id, + file_id, + bytes, + filepath, + filename: req.file.originalname, + context, + source, + type: 'image/webp', + width, + height, + }, + true, + ); +}; + +/** + * Applies the current strategy for file uploads. + * Saves file metadata to the database with an expiry TTL. + * Files must be deleted from the server filesystem manually. + * + * @param {Object} params - The parameters object. 
+ * @param {Express.Request} params.req - The Express request object. + * @param {Express.Response} params.res - The Express response object. + * @param {Express.Multer.File} params.file - The uploaded file. + * @param {FileMetadata} params.metadata - Additional metadata for the file. + * @returns {Promise} + */ +const processFileUpload = async ({ req, res, file, metadata }) => { + const isAssistantUpload = metadata.endpoint === EModelEndpoint.assistants; + const source = isAssistantUpload ? FileSources.openai : req.app.locals.fileStrategy; + const { handleFileUpload } = getStrategyFunctions(source); + const { file_id, temp_file_id } = metadata; + + /** @type {OpenAI | undefined} */ + let openai; + if (source === FileSources.openai) { + ({ openai } = await initializeClient({ req })); + } + + const { id, bytes, filename, filepath } = await handleFileUpload(req, file, openai); + + if (isAssistantUpload && !metadata.message_file) { + await openai.beta.assistants.files.create(metadata.assistant_id, { + file_id: id, + }); + } + + const result = await createFile( + { + user: req.user.id, + file_id: id ?? file_id, + temp_file_id, + bytes, + filepath: isAssistantUpload ? `https://api.openai.com/v1/files/${id}` : filepath, + filename: filename ?? file.originalname, + context: isAssistantUpload ? FileContext.assistants : FileContext.message_attachment, + source, + type: file.mimetype, + }, + true, + ); + res.status(200).json({ message: 'File uploaded and processed successfully', ...result }); +}; + +/** + * Retrieves and processes an OpenAI file based on its type. + * + * @param {Object} params - The params passed to the function. + * @param {OpenAIClient} params.openai - The initialized OpenAI client. + * @param {string} params.file_id - The ID of the file to retrieve. + * @param {string} params.basename - The basename of the file (if image); e.g., 'image.jpg'. + * @param {boolean} [params.unknownType] - Whether the file type is unknown. + * @returns {Promise<{file_id: string, filepath: string, source: string, bytes?: number, width?: number, height?: number} | null>} + * - Returns null if `file_id` is not defined; else, the file metadata if successfully retrieved and processed. + */ +async function retrieveAndProcessFile({ openai, file_id, basename: _basename, unknownType }) { + if (!file_id) { + return null; + } + + if (openai.attachedFileIds?.has(file_id)) { + return { + file_id, + // filepath: TODO: local source filepath?, + source: FileSources.openai, + }; + } + + let basename = _basename; + const downloadImages = isEnabled(GPTS_DOWNLOAD_IMAGES); + + /** + * @param {string} file_id - The ID of the file to retrieve. + * @param {boolean} [save] - Whether to save the file metadata to the database. + */ + const retrieveFile = async (file_id, save = false) => { + const _file = await openai.files.retrieve(file_id); + const filepath = `/api/files/download/${file_id}`; + const file = { + ..._file, + type: mime.getType(_file.filename), + filepath, + usage: 1, + file_id, + context: _file.purpose ??
FileContext.message_attachment, + source: FileSources.openai, + }; + + if (save) { + await createFile(file, true); + } else { + try { + await updateFileUsage({ file_id }); + } catch (error) { + logger.error('Error updating file usage', error); + } + } + + return file; + }; + + // If image downloads are not enabled or no basename provided, return only the file metadata + if (!downloadImages || (!basename && !downloadImages)) { + return await retrieveFile(file_id, true); + } + + let data; + try { + const response = await openai.files.content(file_id); + data = await response.arrayBuffer(); + } catch (error) { + logger.error('Error downloading file from OpenAI:', error); + return await retrieveFile(file_id); + } + + if (!data) { + return await retrieveFile(file_id); + } + const dataBuffer = Buffer.from(data); + + /** + * @param {Buffer} dataBuffer + * @param {string} fileExt + */ + const processAsImage = async (dataBuffer, fileExt) => { + // Logic to process image files, convert to webp, etc. + const _file = await convertToWebP(openai.req, dataBuffer, 'high', `${file_id}${fileExt}`); + const file = { + ..._file, + type: 'image/webp', + usage: 1, + file_id, + source: FileSources.openai, + }; + createFile(file, true); + return file; + }; + + /** @param {Buffer} dataBuffer */ + const processOtherFileTypes = async (dataBuffer) => { + // Logic to handle other file types + logger.debug('[retrieveAndProcessFile] Non-image file type detected'); + return { filepath: `/api/files/download/${file_id}`, bytes: dataBuffer.length }; + }; + + // If the filetype is unknown, inspect the file + if (unknownType || !path.extname(basename)) { + const detectedExt = await determineFileType(dataBuffer); + if (detectedExt && imageExtRegex.test('.' + detectedExt)) { + return await processAsImage(dataBuffer, detectedExt); + } else { + return await processOtherFileTypes(dataBuffer); + } + } + + // Existing logic for processing known image types + if (downloadImages && basename && path.extname(basename) && imageExtRegex.test(basename)) { + return await processAsImage(dataBuffer, path.extname(basename)); + } else { + logger.debug('[retrieveAndProcessFile] Not an image or invalid extension: ', basename); + return await processOtherFileTypes(dataBuffer); + } +} + +/** + * Filters a file based on its size and the endpoint origin. + * + * @param {Object} params - The parameters for the function. + * @param {Express.Request} params.req - The request object from Express. + * @param {Express.Multer.File} params.file - The file uploaded to the server via multer. + * @param {boolean} [params.image] - Whether the file expected is an image. + * @returns {void} + * + * @throws {Error} If a file exception is caught (invalid file size or type, lack of metadata). + */ +function filterFile({ req, file, image }) { + const { endpoint, file_id, width, height } = req.body; + + if (!file_id) { + throw new Error('No file_id provided'); + } + + /* parse to validate api call, throws error on fail */ + isUUID.parse(file_id); + + if (!endpoint) { + throw new Error('No endpoint provided'); + } + + const fileConfig = mergeFileConfig(req.app.locals.fileConfig); + + const { fileSizeLimit, supportedMimeTypes } = + fileConfig.endpoints[endpoint] ?? 
fileConfig.endpoints.default; + + if (file.size > fileSizeLimit) { + throw new Error( + `File size limit of ${fileSizeLimit / megabyte} MB exceeded for ${endpoint} endpoint`, + ); + } + + const isSupportedMimeType = fileConfig.checkType(file.mimetype, supportedMimeTypes); + + if (!isSupportedMimeType) { + throw new Error('Unsupported file type'); + } + + if (!image) { + return; + } + + if (!width) { + throw new Error('No width provided'); + } + + if (!height) { + throw new Error('No height provided'); + } +} + module.exports = { - processImageUpload, + filterFile, processFiles, processFileURL, + processImageFile, + uploadImageBuffer, + processFileUpload, + processDeleteRequest, + retrieveAndProcessFile, }; diff --git a/api/server/services/Files/strategies.js b/api/server/services/Files/strategies.js index 4e2018604..e69251a2c 100644 --- a/api/server/services/Files/strategies.js +++ b/api/server/services/Files/strategies.js @@ -4,46 +4,82 @@ const { prepareImageURL, saveURLToFirebase, deleteFirebaseFile, + saveBufferToFirebase, uploadImageToFirebase, processFirebaseAvatar, } = require('./Firebase'); const { + // saveLocalFile, getLocalFileURL, saveFileFromURL, + saveLocalBuffer, deleteLocalFile, uploadLocalImage, prepareImagesLocal, processLocalAvatar, } = require('./Local'); +const { uploadOpenAIFile, deleteOpenAIFile } = require('./OpenAI'); -// Firebase Strategy Functions +/** + * Firebase Storage Strategy Functions + * + * */ const firebaseStrategy = () => ({ // saveFile: saveURL: saveURLToFirebase, getFileURL: getFirebaseURL, deleteFile: deleteFirebaseFile, + saveBuffer: saveBufferToFirebase, prepareImagePayload: prepareImageURL, processAvatar: processFirebaseAvatar, handleImageUpload: uploadImageToFirebase, }); -// Local Strategy Functions +/** + * Local Server Storage Strategy Functions + * + * */ const localStrategy = () => ({ - // saveFile: , + // saveFile: saveLocalFile, saveURL: saveFileFromURL, getFileURL: getLocalFileURL, + saveBuffer: saveLocalBuffer, deleteFile: deleteLocalFile, processAvatar: processLocalAvatar, handleImageUpload: uploadLocalImage, prepareImagePayload: prepareImagesLocal, }); +/** + * OpenAI Strategy Functions + * + * Note: null values mean that the strategy is not supported. + * */ +const openAIStrategy = () => ({ + /** @type {typeof saveFileFromURL | null} */ + saveURL: null, + /** @type {typeof getLocalFileURL | null} */ + getFileURL: null, + /** @type {typeof saveLocalBuffer | null} */ + saveBuffer: null, + /** @type {typeof processLocalAvatar | null} */ + processAvatar: null, + /** @type {typeof uploadLocalImage | null} */ + handleImageUpload: null, + /** @type {typeof prepareImagesLocal | null} */ + prepareImagePayload: null, + deleteFile: deleteOpenAIFile, + handleFileUpload: uploadOpenAIFile, +}); + // Strategy Selector const getStrategyFunctions = (fileSource) => { if (fileSource === FileSources.firebase) { return firebaseStrategy(); } else if (fileSource === FileSources.local) { return localStrategy(); + } else if (fileSource === FileSources.openai) { + return openAIStrategy(); } else { throw new Error('Invalid file source'); } diff --git a/api/server/services/ModelService.js b/api/server/services/ModelService.js index 45ae11aa1..107411e4c 100644 --- a/api/server/services/ModelService.js +++ b/api/server/services/ModelService.js @@ -69,8 +69,33 @@ const fetchModels = async ({ await cache.set(name, endpointTokenConfig); } models = input.data.map((item) => item.id); - } catch (err) { - logger.error(`Failed to fetch models from ${azure ? 
'Azure ' : ''}${name} API`, err); + } catch (error) { + const logMessage = `Failed to fetch models from ${azure ? 'Azure ' : ''}${name} API`; + if (error.response) { + logger.error( + `${logMessage} The request was made and the server responded with a status code that falls out of the range of 2xx: ${ + error.message ? error.message : '' + }`, + { + headers: error.response.headers, + status: error.response.status, + data: error.response.data, + }, + ); + } else if (error.request) { + logger.error( + `${logMessage} The request was made but no response was received: ${ + error.message ? error.message : '' + }`, + { + request: error.request, + }, + ); + } else { + logger.error(`${logMessage} Something happened in setting up the request`, { + message: error.message ? error.message : '', + }); + } } return models; @@ -131,6 +156,9 @@ const fetchOpenAIModels = async (opts, _models = []) => { if (baseURL === openaiBaseURL) { const regex = /(text-davinci-003|gpt-)/; models = models.filter((model) => regex.test(model)); + const instructModels = models.filter((model) => model.includes('instruct')); + const otherModels = models.filter((model) => !model.includes('instruct')); + models = otherModels.concat(instructModels); } await modelsCache.set(baseURL, models); @@ -147,7 +175,11 @@ const fetchOpenAIModels = async (opts, _models = []) => { * @param {boolean} [opts.plugins=false] - Whether to fetch models from the plugins. */ const getOpenAIModels = async (opts) => { - let models = defaultModels.openAI; + let models = defaultModels[EModelEndpoint.openAI]; + + if (opts.assistants) { + models = defaultModels[EModelEndpoint.assistants]; + } if (opts.plugins) { models = models.filter( @@ -161,7 +193,9 @@ const getOpenAIModels = async (opts) => { } let key; - if (opts.azure) { + if (opts.assistants) { + key = 'ASSISTANTS_MODELS'; + } else if (opts.azure) { key = 'AZURE_OPENAI_MODELS'; } else if (opts.plugins) { key = 'PLUGIN_MODELS'; @@ -178,6 +212,10 @@ const getOpenAIModels = async (opts) => { return models; } + if (opts.assistants) { + return models; + } + return await fetchOpenAIModels(opts, models); }; diff --git a/api/server/services/ModelService.spec.js b/api/server/services/ModelService.spec.js index 19fa407ce..7c1d326fa 100644 --- a/api/server/services/ModelService.spec.js +++ b/api/server/services/ModelService.spec.js @@ -210,3 +210,49 @@ describe('getOpenAIModels with mocked config', () => { expect(models).toContain('some-default-model'); }); }); + +describe('getOpenAIModels sorting behavior', () => { + beforeEach(() => { + axios.get.mockResolvedValue({ + data: { + data: [ + { id: 'gpt-3.5-turbo-instruct-0914' }, + { id: 'gpt-3.5-turbo-instruct' }, + { id: 'gpt-3.5-turbo' }, + { id: 'gpt-4-0314' }, + { id: 'gpt-4-turbo-preview' }, + ], + }, + }); + }); + + it('ensures instruct models are listed last', async () => { + const models = await getOpenAIModels({ user: 'user456' }); + + // Check if the last model is an "instruct" model + expect(models[models.length - 1]).toMatch(/instruct/); + + // Check if the "instruct" models are placed at the end + const instructIndexes = models + .map((model, index) => (model.includes('instruct') ? index : -1)) + .filter((index) => index !== -1); + const nonInstructIndexes = models + .map((model, index) => (!model.includes('instruct') ? 
index : -1)) + .filter((index) => index !== -1); + + expect(Math.max(...nonInstructIndexes)).toBeLessThan(Math.min(...instructIndexes)); + + const expectedOrder = [ + 'gpt-3.5-turbo', + 'gpt-4-0314', + 'gpt-4-turbo-preview', + 'gpt-3.5-turbo-instruct-0914', + 'gpt-3.5-turbo-instruct', + ]; + expect(models).toEqual(expectedOrder); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); +}); diff --git a/api/server/services/PluginService.js b/api/server/services/PluginService.js index 615823829..efe0bb03f 100644 --- a/api/server/services/PluginService.js +++ b/api/server/services/PluginService.js @@ -90,8 +90,7 @@ const updateUserPluginAuth = async (userId, authField, pluginKey, value) => { const deleteUserPluginAuth = async (userId, authField) => { try { - const response = await PluginAuth.deleteOne({ userId, authField }); - return response; + return await PluginAuth.deleteOne({ userId, authField }); } catch (err) { logger.error('[deleteUserPluginAuth]', err); return err; diff --git a/api/server/services/Runs/RunMananger.js b/api/server/services/Runs/RunManager.js similarity index 88% rename from api/server/services/Runs/RunMananger.js rename to api/server/services/Runs/RunManager.js index 67a3624c1..b7f7400c3 100644 --- a/api/server/services/Runs/RunMananger.js +++ b/api/server/services/Runs/RunManager.js @@ -44,15 +44,23 @@ class RunManager { */ async fetchRunSteps({ openai, thread_id, run_id, runStatus, final = false }) { // const { data: steps, first_id, last_id, has_more } = await openai.beta.threads.runs.steps.list(thread_id, run_id); - const { data: _steps } = await openai.beta.threads.runs.steps.list(thread_id, run_id); + const { data: _steps } = await openai.beta.threads.runs.steps.list( + thread_id, + run_id, + {}, + { + timeout: 3000, + maxRetries: 5, + }, + ); const steps = _steps.sort((a, b) => a.created_at - b.created_at); for (const [i, step] of steps.entries()) { - if (this.seenSteps.has(step.id)) { + if (!final && this.seenSteps.has(`${step.id}-${step.status}`)) { continue; } const isLast = i === steps.length - 1; - this.seenSteps.add(step.id); + this.seenSteps.add(`${step.id}-${step.status}`); this.stepsByStatus[runStatus] = this.stepsByStatus[runStatus] || []; const currentStepPromise = (async () => { @@ -64,6 +72,13 @@ class RunManager { return await currentStepPromise; } + if (step.type === 'tool_calls') { + await currentStepPromise; + } + if (step.type === 'message_creation' && step.status === 'completed') { + await currentStepPromise; + } + this.lastStepPromiseByStatus[runStatus] = currentStepPromise; this.stepsByStatus[runStatus].push(currentStepPromise); } @@ -79,7 +94,7 @@ class RunManager { */ async handleStep({ step, runStatus, final, isLast }) { if (this.handlers[runStatus]) { - return this.handlers[runStatus]({ step, final, isLast }); + return await this.handlers[runStatus]({ step, final, isLast }); } if (final && isLast && this.handlers['final']) { diff --git a/api/server/services/Runs/handle.js b/api/server/services/Runs/handle.js new file mode 100644 index 000000000..231891d71 --- /dev/null +++ b/api/server/services/Runs/handle.js @@ -0,0 +1,270 @@ +const { RunStatus, defaultOrderQuery, CacheKeys } = require('librechat-data-provider'); +const getLogStores = require('~/cache/getLogStores'); +const { retrieveRun } = require('./methods'); +const RunManager = require('./RunManager'); +const { logger } = require('~/config'); + +async function withTimeout(promise, timeoutMs, timeoutMessage) { + let timeoutHandle; + + const timeoutPromise = new Promise((_, 
reject) => { + timeoutHandle = setTimeout(() => { + logger.debug(timeoutMessage); + reject(new Error('Operation timed out')); + }, timeoutMs); + }); + + try { + return await Promise.race([promise, timeoutPromise]); + } finally { + clearTimeout(timeoutHandle); + } +} + +/** + * Creates a run on a thread using the OpenAI API. + * + * @param {Object} params - The parameters for creating a run. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {string} params.thread_id - The ID of the thread to run. + * @param {Object} params.body - The body of the request to create a run. + * @param {string} params.body.assistant_id - The ID of the assistant to use for this run. + * @param {string} [params.body.model] - Optional. The ID of the model to be used for this run. + * @param {string} [params.body.instructions] - Optional. Override the default system message of the assistant. + * @param {string} [params.body.additional_instructions] - Optional. Appends additional instructions + * at the end of the instructions for the run. This is useful for modifying + * the behavior on a per-run basis without overriding other instructions. + * @param {Object[]} [params.body.tools] - Optional. Override the tools the assistant can use for this run. + * @param {string[]} [params.body.file_ids] - Optional. + * List of File IDs the assistant can use for this run. + * + * **Note:** The API seems to prefer files added to messages, not runs. + * @param {Object} [params.body.metadata] - Optional. Metadata for the run. + * @return {Promise} A promise that resolves to the created run object. + */ +async function createRun({ openai, thread_id, body }) { + return await openai.beta.threads.runs.create(thread_id, body); +} + +/** + * Delays the execution for a specified number of milliseconds. + * + * @param {number} ms - The number of milliseconds to delay. + * @return {Promise} A promise that resolves after the specified delay. + */ +function sleep(ms) { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + +/** + * Waits for a run to complete by repeatedly checking its status. It uses a RunManager instance to fetch and manage run steps based on the run status. + * + * @param {Object} params - The parameters for the waitForRun function. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {string} params.run_id - The ID of the run to wait for. + * @param {string} params.thread_id - The ID of the thread associated with the run. + * @param {RunManager} params.runManager - The RunManager instance to manage run steps. + * @param {number} [params.pollIntervalMs=750] - The interval for polling the run status; default is 750 milliseconds. + * @param {number} [params.timeout=180000] - The period to wait until timing out polling; default is 3 minutes (in ms). + * @return {Promise} A promise that resolves to the last fetched run object.
+ */ +async function waitForRun({ + openai, + run_id, + thread_id, + runManager, + pollIntervalMs = 750, + timeout = 60000 * 3, +}) { + let timeElapsed = 0; + let run; + + const cache = getLogStores(CacheKeys.ABORT_KEYS); + const cacheKey = `${openai.req.user.id}:${openai.responseMessage.conversationId}`; + + let i = 0; + let lastSeenStatus = null; + const runIdLog = `run_id: ${run_id}`; + const runInfo = `user: ${openai.req.user.id} | thread_id: ${thread_id} | ${runIdLog}`; + const raceTimeoutMs = 3000; + let maxRetries = 5; + let attempt = 0; + while (timeElapsed < timeout) { + i++; + logger.debug(`[heartbeat ${i}] ${runIdLog} | Retrieving run status...`); + let updatedRun; + + const startTime = Date.now(); + while (!updatedRun && attempt < maxRetries) { + try { + updatedRun = await withTimeout( + retrieveRun({ thread_id, run_id, timeout: raceTimeoutMs, openai }), + raceTimeoutMs, + `[heartbeat ${i}] ${runIdLog} | Run retrieval timed out at ${timeElapsed} ms. Trying again (attempt ${ + attempt + 1 + } of ${maxRetries})...`, + ); + attempt++; + } catch (error) { + logger.warn(`${runIdLog} | Error retrieving run status: ${error}`); + } + } + const endTime = Date.now(); + logger.debug( + `[heartbeat ${i}] ${runIdLog} | Elapsed run retrieval time: ${endTime - startTime}`, + ); + if (!updatedRun) { + const errorMessage = `[waitForRun] ${runIdLog} | Run retrieval failed after ${maxRetries} attempts`; + throw new Error(errorMessage); + } + run = updatedRun; + attempt = 0; + const runStatus = `${runInfo} | status: ${run.status}`; + + if (run.status !== lastSeenStatus) { + logger.debug(`[${run.status}] ${runInfo}`); + lastSeenStatus = run.status; + } + + logger.debug(`[heartbeat ${i}] ${runStatus}`); + + let cancelStatus; + try { + const timeoutMessage = `[heartbeat ${i}] ${runIdLog} | Cancel Status check operation timed out.`; + cancelStatus = await withTimeout(cache.get(cacheKey), raceTimeoutMs, timeoutMessage); + } catch (error) { + logger.warn(`Error retrieving cancel status: ${error}`); + } + + if (cancelStatus === 'cancelled') { + logger.warn(`[waitForRun] ${runStatus} | RUN CANCELLED`); + throw new Error('Run cancelled'); + } + + if (![RunStatus.IN_PROGRESS, RunStatus.QUEUED].includes(run.status)) { + logger.debug(`[FINAL] ${runInfo} | status: ${run.status}`); + await runManager.fetchRunSteps({ + openai, + thread_id: thread_id, + run_id: run_id, + runStatus: run.status, + final: true, + }); + break; + } + + // may use in future; for now, just fetch from the final status + await runManager.fetchRunSteps({ + openai, + thread_id: thread_id, + run_id: run_id, + runStatus: run.status, + }); + + await sleep(pollIntervalMs); + timeElapsed += pollIntervalMs; + } + + if (timeElapsed >= timeout) { + const timeoutMessage = `[waitForRun] ${runInfo} | status: ${run.status} | timed out after ${timeout} ms`; + logger.warn(timeoutMessage); + throw new Error(timeoutMessage); + } + + return run; +} + +/** + * Retrieves all steps of a run. + * + * @deprecated: Steps are handled with runAssistant now. + * @param {Object} params - The parameters for the retrieveRunSteps function. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {string} params.thread_id - The ID of the thread associated with the run. + * @param {string} params.run_id - The ID of the run to retrieve steps for. + * @return {Promise} A promise that resolves to an array of RunStep objects. 
+ */ +async function _retrieveRunSteps({ openai, thread_id, run_id }) { + const runSteps = await openai.beta.threads.runs.steps.list(thread_id, run_id); + return runSteps; +} + +/** + * Initializes a RunManager with handlers, then invokes waitForRun to monitor and manage an OpenAI run. + * + * @deprecated Use runAssistant instead. + * @param {Object} params - The parameters for managing and monitoring the run. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {string} params.run_id - The ID of the run to manage and monitor. + * @param {string} params.thread_id - The ID of the thread associated with the run. + * @return {Promise} A promise that resolves to an object containing the run and managed steps. + */ +async function _handleRun({ openai, run_id, thread_id }) { + let steps = []; + let messages = []; + const runManager = new RunManager({ + // 'in_progress': async ({ step, final, isLast }) => { + // // Define logic for handling steps with 'in_progress' status + // }, + // 'queued': async ({ step, final, isLast }) => { + // // Define logic for handling steps with 'queued' status + // }, + final: async ({ step, runStatus, stepsByStatus }) => { + console.log(`Final step for ${run_id} with status ${runStatus}`); + console.dir(step, { depth: null }); + + const promises = []; + promises.push(openai.beta.threads.messages.list(thread_id, defaultOrderQuery)); + + // const finalSteps = stepsByStatus[runStatus]; + // for (const stepPromise of finalSteps) { + // promises.push(stepPromise); + // } + + // loop across all statuses + for (const [_status, stepsPromises] of Object.entries(stepsByStatus)) { + promises.push(...stepsPromises); + } + + const resolved = await Promise.all(promises); + const res = resolved.shift(); + messages = res.data.filter((msg) => msg.run_id === run_id); + resolved.push(step); + steps = resolved; + }, + }); + + const run = await waitForRun({ + openai, + run_id, + thread_id, + runManager, + pollIntervalMs: 750, + timeout: 60000, + }); + const actions = []; + if (run.required_action) { + const { submit_tool_outputs } = run.required_action; + submit_tool_outputs.tool_calls.forEach((item) => { + const functionCall = item.function; + const args = JSON.parse(functionCall.arguments); + actions.push({ + tool: functionCall.name, + toolInput: args, + toolCallId: item.id, + run_id, + thread_id, + }); + }); + } + + return { run, steps, messages, actions }; +} + +module.exports = { + sleep, + createRun, + waitForRun, + // _handleRun, + // retrieveRunSteps, +}; diff --git a/api/server/services/Runs/index.js b/api/server/services/Runs/index.js new file mode 100644 index 000000000..2cb06d467 --- /dev/null +++ b/api/server/services/Runs/index.js @@ -0,0 +1,9 @@ +const handle = require('./handle'); +const methods = require('./methods'); +const RunManager = require('./RunManager'); + +module.exports = { + ...handle, + ...methods, + RunManager, +}; diff --git a/api/server/services/Runs/methods.js b/api/server/services/Runs/methods.js new file mode 100644 index 000000000..ee74a9bd1 --- /dev/null +++ b/api/server/services/Runs/methods.js @@ -0,0 +1,76 @@ +const axios = require('axios'); +const { logger } = require('~/config'); + +/** + * @typedef {Object} RetrieveOptions + * @property {string} thread_id - The ID of the thread to retrieve. + * @property {string} run_id - The ID of the run to retrieve. + * @property {number} [timeout] - Optional timeout for the API call. 
+ * @property {number} [maxRetries] - TODO: not yet implemented; Optional maximum number of retries for the API call. + * @property {OpenAIClient} openai - Configuration and credentials for OpenAI API access. + */ + +/** + * Asynchronously retrieves data from an API endpoint based on provided thread and run IDs. + * + * @param {RetrieveOptions} options - The options for the retrieve operation. + * @returns {Promise} The data retrieved from the API. + */ +async function retrieveRun({ thread_id, run_id, timeout, openai }) { + const { apiKey, baseURL, httpAgent, organization } = openai; + const url = `${baseURL}/threads/${thread_id}/runs/${run_id}`; + + const headers = { + Authorization: `Bearer ${apiKey}`, + 'OpenAI-Beta': 'assistants=v1', + }; + + if (organization) { + headers['OpenAI-Organization'] = organization; + } + + try { + const axiosConfig = { + headers: headers, + timeout: timeout, + }; + + if (httpAgent) { + axiosConfig.httpAgent = httpAgent; + axiosConfig.httpsAgent = httpAgent; + } + + const response = await axios.get(url, axiosConfig); + return response.data; + } catch (error) { + const logMessage = '[retrieveRun] Failed to retrieve run data:'; + if (error.response) { + logger.error( + `${logMessage} The request was made and the server responded with a status code that falls out of the range of 2xx: ${ + error.message ? error.message : '' + }`, + { + headers: error.response.headers, + status: error.response.status, + data: error.response.data, + }, + ); + } else if (error.request) { + logger.error( + `${logMessage} The request was made but no response was received: ${ + error.message ? error.message : '' + }`, + { + request: error.request, + }, + ); + } else { + logger.error(`${logMessage} Something happened in setting up the request`, { + message: error.message ? error.message : '', + }); + } + throw error; + } +} + +module.exports = { retrieveRun }; diff --git a/api/server/services/Threads/index.js b/api/server/services/Threads/index.js new file mode 100644 index 000000000..850cddc4e --- /dev/null +++ b/api/server/services/Threads/index.js @@ -0,0 +1,5 @@ +const manage = require('./manage'); + +module.exports = { + ...manage, +}; diff --git a/api/server/services/Threads/manage.js b/api/server/services/Threads/manage.js new file mode 100644 index 000000000..125277860 --- /dev/null +++ b/api/server/services/Threads/manage.js @@ -0,0 +1,495 @@ +const { v4 } = require('uuid'); +const { + EModelEndpoint, + Constants, + defaultOrderQuery, + ContentTypes, +} = require('librechat-data-provider'); +const { recordMessage, getMessages } = require('~/models/Message'); +const { saveConvo } = require('~/models/Conversation'); +const spendTokens = require('~/models/spendTokens'); +const { countTokens } = require('~/server/utils'); + +/** + * Initializes a new thread or adds messages to an existing thread. + * + * @param {Object} params - The parameters for initializing a thread. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {Object} params.body - The body of the request. + * @param {ThreadMessage[]} params.body.messages - A list of messages to start the thread with. + * @param {Object} [params.body.metadata] - Optional metadata for the thread. + * @param {string} [params.thread_id] - Optional existing thread ID. If provided, a message will be added to this thread. + * @return {Promise} A promise that resolves to the newly created thread object or the updated thread object. 
+ */ +async function initThread({ openai, body, thread_id: _thread_id }) { + let thread = {}; + const messages = []; + if (_thread_id) { + const message = await openai.beta.threads.messages.create(_thread_id, body.messages[0]); + messages.push(message); + } else { + thread = await openai.beta.threads.create(body); + } + + const thread_id = _thread_id ?? thread.id; + return { messages, thread_id, ...thread }; +} + +/** + * Saves a user message to the DB in the Assistants endpoint format. + * + * @param {Object} params - The parameters of the user message + * @param {string} params.user - The user's ID. + * @param {string} params.text - The user's prompt. + * @param {string} params.messageId - The user message Id. + * @param {string} params.model - The model used by the assistant. + * @param {string} params.assistant_id - The current assistant Id. + * @param {string} params.thread_id - The thread Id. + * @param {string} params.conversationId - The message's conversationId + * @param {string} [params.parentMessageId] - Optional if initial message. + * Defaults to Constants.NO_PARENT. + * @param {string} [params.instructions] - Optional: from preset for `instructions` field. + * Overrides the instructions of the assistant. + * @param {string} [params.promptPrefix] - Optional: from preset for `additional_instructions` field. + * @param {import('librechat-data-provider').TFile[]} [params.files] - Optional. List of Attached File Objects. + * @param {string[]} [params.file_ids] - Optional. List of File IDs attached to the userMessage. + * @return {Promise} A promise that resolves to the created run object. + */ +async function saveUserMessage(params) { + const tokenCount = await countTokens(params.text); + + // todo: do this on the frontend + // const { file_ids = [] } = params; + // let content; + // if (file_ids.length) { + // content = [ + // { + // value: params.text, + // }, + // ...( + // file_ids + // .filter(f => f) + // .map((file_id) => ({ + // file_id, + // })) + // ), + // ]; + // } + + const userMessage = { + user: params.user, + endpoint: EModelEndpoint.assistants, + messageId: params.messageId, + conversationId: params.conversationId, + parentMessageId: params.parentMessageId ?? Constants.NO_PARENT, + /* For messages, use the assistant_id instead of model */ + model: params.assistant_id, + thread_id: params.thread_id, + sender: 'User', + text: params.text, + isCreatedByUser: true, + tokenCount, + }; + + const convo = { + endpoint: EModelEndpoint.assistants, + conversationId: params.conversationId, + promptPrefix: params.promptPrefix, + instructions: params.instructions, + assistant_id: params.assistant_id, + model: params.model, + }; + + if (params.files?.length) { + userMessage.files = params.files.map(({ file_id }) => ({ file_id })); + convo.file_ids = params.file_ids; + } + + const message = await recordMessage(userMessage); + await saveConvo(params.user, convo); + + return message; +} + +/** + * Saves an Assistant message to the DB in the Assistants endpoint format. + * + * @param {Object} params - The parameters of the Assistant message + * @param {string} params.user - The user's ID. + * @param {string} params.messageId - The message Id. + * @param {string} params.assistant_id - The assistant Id. + * @param {string} params.thread_id - The thread Id. + * @param {string} params.model - The model used by the assistant. + * @param {ContentPart[]} params.content - The message content parts. 
+ * @param {string} params.conversationId - The message's conversationId + * @param {string} params.parentMessageId - The latest user message that triggered this response. + * @param {string} [params.instructions] - Optional: from preset for `instructions` field. + * Overrides the instructions of the assistant. + * @param {string} [params.promptPrefix] - Optional: from preset for `additional_instructions` field. + * @return {Promise} A promise that resolves to the created run object. + */ +async function saveAssistantMessage(params) { + const text = params.content.reduce((acc, part) => { + if (!part.value) { + return acc; + } + + return acc + ' ' + part.value; + }, ''); + + // const tokenCount = // TODO: need to count each content part + + const message = await recordMessage({ + user: params.user, + endpoint: EModelEndpoint.assistants, + messageId: params.messageId, + conversationId: params.conversationId, + parentMessageId: params.parentMessageId, + thread_id: params.thread_id, + /* For messages, use the assistant_id instead of model */ + model: params.assistant_id, + content: params.content, + sender: 'Assistant', + isCreatedByUser: false, + text: text.trim(), + // tokenCount, + }); + + await saveConvo(params.user, { + endpoint: EModelEndpoint.assistants, + conversationId: params.conversationId, + promptPrefix: params.promptPrefix, + instructions: params.instructions, + assistant_id: params.assistant_id, + model: params.model, + }); + + return message; +} + +/** + * Records LibreChat messageId to all response messages' metadata + * + * @param {Object} params - The parameters for initializing a thread. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {string} params.thread_id - Response thread ID. + * @param {string} params.messageId - The response `messageId` generated by LibreChat. + * @param {StepMessage[] | Message[]} params.messages - A list of messages to start the thread with. + * @return {Promise} A promise that resolves to the updated messages + */ +async function addThreadMetadata({ openai, thread_id, messageId, messages }) { + const promises = []; + for (const message of messages) { + promises.push( + openai.beta.threads.messages.update(thread_id, message.id, { + metadata: { + messageId, + }, + }), + ); + } + + return await Promise.all(promises); +} + +/** + * Synchronizes LibreChat messages to Thread Messages. + * Updates the LibreChat DB with any missing Thread Messages and + * updates the missing Thread Messages' metadata with their corresponding db messageId's. + * + * Also updates the existing conversation's file_ids with any new file_ids. + * + * @param {Object} params - The parameters for synchronizing messages. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {TMessage[]} params.dbMessages - The LibreChat DB messages. + * @param {ThreadMessage[]} params.apiMessages - The thread messages from the API. + * @param {string} params.conversationId - The current conversation ID. + * @param {string} params.thread_id - The current thread ID. + * @param {string} [params.assistant_id] - The current assistant ID. 
+ * @return {Promise} A promise that resolves to the updated messages + */ +async function syncMessages({ + openai, + apiMessages, + dbMessages, + conversationId, + thread_id, + assistant_id, +}) { + let result = []; + let dbMessageMap = new Map(dbMessages.map((msg) => [msg.messageId, msg])); + + const modifyPromises = []; + const recordPromises = []; + + /** + * + * Modify API message and save newMessage to DB + * + * @param {Object} params - The parameters object + * @param {TMessage} params.dbMessage + * @param {dbMessage} params.apiMessage + */ + const processNewMessage = async ({ dbMessage, apiMessage }) => { + recordPromises.push(recordMessage({ ...dbMessage, user: openai.req.user.id })); + + if (!apiMessage.id.includes('msg_')) { + return; + } + + if (dbMessage.aggregateMessages?.length > 1) { + modifyPromises.push( + addThreadMetadata({ + openai, + thread_id, + messageId: dbMessage.messageId, + messages: dbMessage.aggregateMessages, + }), + ); + return; + } + + modifyPromises.push( + openai.beta.threads.messages.update(thread_id, apiMessage.id, { + metadata: { + messageId: dbMessage.messageId, + }, + }), + ); + }; + + let lastMessage = null; + + for (let i = 0; i < apiMessages.length; i++) { + const apiMessage = apiMessages[i]; + + // Check if the message exists in the database based on metadata + const dbMessageId = apiMessage.metadata && apiMessage.metadata.messageId; + let dbMessage = dbMessageMap.get(dbMessageId); + + if (dbMessage) { + // If message exists in DB, use its messageId and update parentMessageId + dbMessage.parentMessageId = lastMessage ? lastMessage.messageId : Constants.NO_PARENT; + lastMessage = dbMessage; + result.push(dbMessage); + continue; + } + + if (apiMessage.role === 'assistant' && lastMessage && lastMessage.role === 'assistant') { + // Aggregate assistant messages + lastMessage.content = [...lastMessage.content, ...apiMessage.content]; + lastMessage.files = [...(lastMessage.files ?? []), ...(apiMessage.files ?? [])]; + lastMessage.aggregateMessages.push({ id: apiMessage.id }); + } else { + // Handle new or missing message + const newMessage = { + thread_id, + conversationId, + messageId: v4(), + endpoint: EModelEndpoint.assistants, + parentMessageId: lastMessage ? lastMessage.messageId : Constants.NO_PARENT, + role: apiMessage.role, + isCreatedByUser: apiMessage.role === 'user', + // TODO: process generated files in content parts + content: apiMessage.content, + aggregateMessages: [{ id: apiMessage.id }], + model: apiMessage.role === 'user' ? 
null : apiMessage.assistant_id, + user: openai.req.user.id, + }; + + if (apiMessage.file_ids?.length) { + // TODO: retrieve file objects from API + newMessage.files = apiMessage.file_ids.map((file_id) => ({ file_id })); + } + + /* Assign assistant_id if defined */ + if (assistant_id && apiMessage.role === 'assistant' && !newMessage.model) { + apiMessage.model = assistant_id; + newMessage.model = assistant_id; + } + + result.push(newMessage); + lastMessage = newMessage; + + if (apiMessage.role === 'user') { + processNewMessage({ dbMessage: newMessage, apiMessage }); + continue; + } + } + + const nextMessage = apiMessages[i + 1]; + const processAssistant = !nextMessage || nextMessage.role === 'user'; + + if (apiMessage.role === 'assistant' && processAssistant) { + processNewMessage({ dbMessage: lastMessage, apiMessage }); + } + } + + const attached_file_ids = apiMessages.reduce((acc, msg) => { + if (msg.role === 'user' && msg.file_ids?.length) { + return [...acc, ...msg.file_ids]; + } + + return acc; + }, []); + + await Promise.all(modifyPromises); + await Promise.all(recordPromises); + + await saveConvo(openai.req.user.id, { + conversationId, + file_ids: attached_file_ids, + }); + + return result; +} + +/** + * Maps messages to their corresponding steps. Steps with message creation will be paired with their messages, + * while steps without message creation will be returned as is. + * + * @param {RunStep[]} steps - An array of steps from the run. + * @param {Message[]} messages - An array of message objects. + * @returns {(StepMessage | RunStep)[]} An array where each element is either a step with its corresponding message (StepMessage) or a step without a message (RunStep). + */ +function mapMessagesToSteps(steps, messages) { + // Create a map of messages indexed by their IDs for efficient lookup + const messageMap = messages.reduce((acc, msg) => { + acc[msg.id] = msg; + return acc; + }, {}); + + // Map each step to its corresponding message, or return the step as is if no message ID is present + return steps + .sort((a, b) => a.created_at - b.created_at) + .map((step) => { + const messageId = step.step_details?.message_creation?.message_id; + + if (messageId && messageMap[messageId]) { + return { step, message: messageMap[messageId] }; + } + return step; + }); +} + +/** + * Checks for any missing messages; if missing, + * synchronizes LibreChat messages to Thread Messages + * + * @param {Object} params - The parameters for initializing a thread. + * @param {OpenAIClient} params.openai - The OpenAI client instance. + * @param {string} [params.latestMessageId] - Optional: The latest message ID from LibreChat. + * @param {string} params.thread_id - Response thread ID. + * @param {string} params.run_id - Response Run ID. + * @param {string} params.conversationId - LibreChat conversation ID. 
+ * @return {Promise} A promise that resolves to the updated messages + */ +async function checkMessageGaps({ openai, latestMessageId, thread_id, run_id, conversationId }) { + const promises = []; + promises.push(openai.beta.threads.messages.list(thread_id, defaultOrderQuery)); + promises.push(openai.beta.threads.runs.steps.list(thread_id, run_id)); + /** @type {[{ data: ThreadMessage[] }, { data: RunStep[] }]} */ + const [response, stepsResponse] = await Promise.all(promises); + + const steps = mapMessagesToSteps(stepsResponse.data, response.data); + /** @type {ThreadMessage} */ + const currentMessage = { + id: v4(), + content: [], + assistant_id: null, + created_at: Math.floor(new Date().getTime() / 1000), + object: 'thread.message', + role: 'assistant', + run_id, + thread_id, + metadata: { + messageId: latestMessageId, + }, + }; + + for (const step of steps) { + if (!currentMessage.assistant_id && step.assistant_id) { + currentMessage.assistant_id = step.assistant_id; + } + if (step.message) { + currentMessage.id = step.message.id; + currentMessage.created_at = step.message.created_at; + currentMessage.content = currentMessage.content.concat(step.message.content); + } else if (step.step_details?.type === 'tool_calls' && step.step_details?.tool_calls?.length) { + currentMessage.content = currentMessage.content.concat( + step.step_details?.tool_calls.map((toolCall) => ({ + [ContentTypes.TOOL_CALL]: { + ...toolCall, + progress: 2, + }, + type: ContentTypes.TOOL_CALL, + })), + ); + } + } + + let addedCurrentMessage = false; + const apiMessages = response.data.map((msg) => { + if (msg.id === currentMessage.id) { + addedCurrentMessage = true; + return currentMessage; + } + return msg; + }); + + if (!addedCurrentMessage) { + apiMessages.push(currentMessage); + } + + const dbMessages = await getMessages({ conversationId }); + const assistant_id = dbMessages?.[0]?.model; + + const syncedMessages = await syncMessages({ + openai, + dbMessages, + apiMessages, + thread_id, + conversationId, + assistant_id, + }); + + return Object.values( + [...dbMessages, ...syncedMessages].reduce( + (acc, message) => ({ ...acc, [message.messageId]: message }), + {}, + ), + ); +} + +/** + * Records token usage for a given completion request. + * + * @param {Object} params - The parameters for initializing a thread. + * @param {number} params.prompt_tokens - The number of prompt tokens used. + * @param {number} params.completion_tokens - The number of completion tokens used. + * @param {string} params.model - The model used by the assistant run. + * @param {string} params.user - The user's ID. + * @param {string} params.conversationId - LibreChat conversation ID. 
+ * @return {Promise} A promise that resolves to the updated messages + */ +const recordUsage = async ({ prompt_tokens, completion_tokens, model, user, conversationId }) => { + await spendTokens( + { + user, + model, + context: 'message', + conversationId, + }, + { promptTokens: prompt_tokens, completionTokens: completion_tokens }, + ); +}; + +module.exports = { + initThread, + recordUsage, + saveUserMessage, + checkMessageGaps, + addThreadMetadata, + mapMessagesToSteps, + saveAssistantMessage, +}; diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js new file mode 100644 index 000000000..c24d31fcd --- /dev/null +++ b/api/server/services/ToolService.js @@ -0,0 +1,317 @@ +const fs = require('fs'); +const path = require('path'); +const { StructuredTool } = require('langchain/tools'); +const { zodToJsonSchema } = require('zod-to-json-schema'); +const { Calculator } = require('langchain/tools/calculator'); +const { + ContentTypes, + imageGenTools, + openapiToFunction, + validateAndParseOpenAPISpec, + actionDelimiter, +} = require('librechat-data-provider'); +const { loadActionSets, createActionTool } = require('./ActionService'); +const { processFileURL } = require('~/server/services/Files/process'); +const { loadTools } = require('~/app/clients/tools/util'); +const { redactMessage } = require('~/config/parsers'); +const { sleep } = require('./Runs/handle'); +const { logger } = require('~/config'); + +/** + * Loads and formats tools from the specified tool directory. + * + * The directory is scanned for JavaScript files, excluding any files in the filter set. + * For each file, it attempts to load the file as a module and instantiate a class, if it's a subclass of `StructuredTool`. + * Each tool instance is then formatted to be compatible with the OpenAI Assistant. + * Additionally, instances of LangChain Tools are included in the result. + * + * @param {object} params - The parameters for the function. + * @param {string} params.directory - The directory path where the tools are located. + * @param {Set} [params.filter=new Set()] - A set of filenames to exclude from loading. + * @returns {Record} An object mapping each tool's plugin key to its instance. 
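+ *
+ * @example
+ * // Minimal usage sketch; the directory and filter values are illustrative,
+ * // not prescribed by this change.
+ * const toolDefinitions = loadAndFormatTools({
+ *   directory: path.join(__dirname, 'tools', 'structured'),
+ *   filter: new Set(['index.js']),
+ * });
+ * // => { calculator: { type: 'function', function: { name, description, parameters } }, ... }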
+ */ +function loadAndFormatTools({ directory, filter = new Set() }) { + const tools = []; + /* Structured Tools Directory */ + const files = fs.readdirSync(directory); + + for (const file of files) { + if (file.endsWith('.js') && !filter.has(file)) { + const filePath = path.join(directory, file); + let ToolClass = null; + try { + ToolClass = require(filePath); + } catch (error) { + logger.error(`[loadAndFormatTools] Error loading tool from ${filePath}:`, error); + continue; + } + + if (!ToolClass) { + continue; + } + + if (ToolClass.prototype instanceof StructuredTool) { + /** @type {StructuredTool | null} */ + let toolInstance = null; + try { + toolInstance = new ToolClass({ override: true }); + } catch (error) { + logger.error( + `[loadAndFormatTools] Error initializing \`${file}\` tool; if it requires authentication, is the \`override\` field configured?`, + error, + ); + continue; + } + + if (!toolInstance) { + continue; + } + + const formattedTool = formatToOpenAIAssistantTool(toolInstance); + tools.push(formattedTool); + } + } + } + + /** + * Basic Tools; schema: { input: string } + */ + const basicToolInstances = [new Calculator()]; + + for (const toolInstance of basicToolInstances) { + const formattedTool = formatToOpenAIAssistantTool(toolInstance); + tools.push(formattedTool); + } + + return tools.reduce((map, tool) => { + map[tool.function.name] = tool; + return map; + }, {}); +} + +/** + * Formats a `StructuredTool` instance into a format that is compatible + * with OpenAI's ChatCompletionFunctions. It uses the `zodToJsonSchema` + * function to convert the schema of the `StructuredTool` into a JSON + * schema, which is then used as the parameters for the OpenAI function. + * + * @param {StructuredTool} tool - The StructuredTool to format. + * @returns {FunctionTool} The OpenAI Assistant Tool. + */ +function formatToOpenAIAssistantTool(tool) { + return { + type: 'function', + function: { + name: tool.name, + description: tool.description, + parameters: zodToJsonSchema(tool.schema), + }, + }; +} + +/** + * Processes return required actions from run. + * + * @param {OpenAIClient} openai - OpenAI Client. + * @param {RequiredAction[]} requiredActions - The required actions to submit outputs for. + * @returns {Promise} The outputs of the tools. + * + */ +async function processRequiredActions(openai, requiredActions) { + logger.debug( + `[required actions] user: ${openai.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`, + requiredActions, + ); + const tools = requiredActions.map((action) => action.tool); + const loadedTools = await loadTools({ + user: openai.req.user.id, + model: openai.req.body.model ?? 'gpt-3.5-turbo-1106', + tools, + functions: true, + options: { + processFileURL, + openAIApiKey: openai.apiKey, + fileStrategy: openai.req.app.locals.fileStrategy, + returnMetadata: true, + }, + skipSpecs: true, + }); + + const ToolMap = loadedTools.reduce((map, tool) => { + map[tool.name] = tool; + return map; + }, {}); + + const promises = []; + + /** @type {Action[]} */ + let actionSets = []; + let isActionTool = false; + const ActionToolMap = {}; + const ActionBuildersMap = {}; + + for (let i = 0; i < requiredActions.length; i++) { + const currentAction = requiredActions[i]; + let tool = ToolMap[currentAction.tool] ?? 
ActionToolMap[currentAction.tool]; + + const handleToolOutput = async (output) => { + requiredActions[i].output = output; + + /** @type {FunctionToolCall & PartMetadata} */ + const toolCall = { + function: { + name: currentAction.tool, + arguments: JSON.stringify(currentAction.toolInput), + output, + }, + id: currentAction.toolCallId, + type: 'function', + progress: 1, + action: isActionTool, + }; + + const toolCallIndex = openai.mappedOrder.get(toolCall.id); + + if (imageGenTools.has(currentAction.tool)) { + const imageOutput = output; + toolCall.function.output = `${currentAction.tool} displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.`; + + // Streams the "Finished" state of the tool call in the UI + openai.addContentData({ + [ContentTypes.TOOL_CALL]: toolCall, + index: toolCallIndex, + type: ContentTypes.TOOL_CALL, + }); + + await sleep(500); + + /** @type {ImageFile} */ + const imageDetails = { + ...imageOutput, + ...currentAction.toolInput, + }; + + const image_file = { + [ContentTypes.IMAGE_FILE]: imageDetails, + type: ContentTypes.IMAGE_FILE, + // Replace the tool call output with Image file + index: toolCallIndex, + }; + + openai.addContentData(image_file); + + // Update the stored tool call + openai.seenToolCalls.set(toolCall.id, toolCall); + + return { + tool_call_id: currentAction.toolCallId, + output: toolCall.function.output, + }; + } + + openai.seenToolCalls.set(toolCall.id, toolCall); + openai.addContentData({ + [ContentTypes.TOOL_CALL]: toolCall, + index: toolCallIndex, + type: ContentTypes.TOOL_CALL, + // TODO: to append tool properties to stream, pass metadata rest to addContentData + // result: tool.result, + }); + + return { + tool_call_id: currentAction.toolCallId, + output, + }; + }; + + if (!tool) { + // throw new Error(`Tool ${currentAction.tool} not found.`); + + if (!actionSets.length) { + actionSets = + (await loadActionSets({ + user: openai.req.user.id, + assistant_id: openai.req.body.assistant_id, + })) ?? 
[]; + } + + const actionSet = actionSets.find((action) => + currentAction.tool.includes(action.metadata.domain), + ); + + if (!actionSet) { + // TODO: try `function` if no action set is found + // throw new Error(`Tool ${currentAction.tool} not found.`); + continue; + } + + let builders = ActionBuildersMap[actionSet.metadata.domain]; + + if (!builders) { + const validationResult = validateAndParseOpenAPISpec(actionSet.metadata.raw_spec); + if (!validationResult.spec) { + throw new Error( + `Invalid spec: user: ${openai.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`, + ); + } + const { requestBuilders } = openapiToFunction(validationResult.spec); + ActionToolMap[actionSet.metadata.domain] = requestBuilders; + builders = requestBuilders; + } + + const functionName = currentAction.tool.replace( + `${actionDelimiter}${actionSet.metadata.domain}`, + '', + ); + const requestBuilder = builders[functionName]; + + if (!requestBuilder) { + // throw new Error(`Tool ${currentAction.tool} not found.`); + continue; + } + + tool = createActionTool({ action: actionSet, requestBuilder }); + isActionTool = !!tool; + ActionToolMap[currentAction.tool] = tool; + } + + if (currentAction.tool === 'calculator') { + currentAction.toolInput = currentAction.toolInput.input; + } + + try { + const promise = tool + ._call(currentAction.toolInput) + .then(handleToolOutput) + .catch((error) => { + logger.error(`Error processing tool ${currentAction.tool}`, error); + return { + tool_call_id: currentAction.toolCallId, + output: `Error processing tool ${currentAction.tool}: ${redactMessage(error.message)}`, + }; + }); + promises.push(promise); + } catch (error) { + logger.error( + `tool_call_id: ${currentAction.toolCallId} | Error processing tool ${currentAction.tool}`, + error, + ); + promises.push( + Promise.resolve({ + tool_call_id: currentAction.toolCallId, + error: error.message, + }), + ); + } + } + + return { + tool_outputs: await Promise.all(promises), + }; +} + +module.exports = { + formatToOpenAIAssistantTool, + loadAndFormatTools, + processRequiredActions, +}; diff --git a/api/server/utils/countTokens.js b/api/server/utils/countTokens.js index 34c070aa8..641e38610 100644 --- a/api/server/utils/countTokens.js +++ b/api/server/utils/countTokens.js @@ -3,6 +3,20 @@ const p50k_base = require('tiktoken/encoders/p50k_base.json'); const cl100k_base = require('tiktoken/encoders/cl100k_base.json'); const logger = require('~/config/winston'); +/** + * Counts the number of tokens in a given text using a specified encoding model. + * + * This function utilizes the 'Tiktoken' library to encode text based on the selected model. + * It supports two models, 'text-davinci-003' and 'gpt-3.5-turbo', each with its own encoding strategy. + * For 'text-davinci-003', the 'p50k_base' encoder is used, whereas for other models, the 'cl100k_base' encoder is applied. + * In case of an error during encoding, the error is logged, and the function returns 0. + * + * @async + * @param {string} text - The text to be tokenized. Defaults to an empty string if not provided. + * @param {string} modelName - The name of the model used for tokenizing. Defaults to 'gpt-3.5-turbo'. + * @returns {Promise} The number of tokens in the provided text. Returns 0 if an error occurs. + * @throws Logs the error to a logger and rethrows if any error occurs during tokenization. 
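+ *
+ * @example
+ * // Usage sketch based on the behavior described above:
+ * const total = await countTokens('Hello, world!'); // defaults to 'gpt-3.5-turbo' (cl100k_base)
+ * const davinci = await countTokens('Hello, world!', 'text-davinci-003'); // uses p50k_base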
+ */ const countTokens = async (text = '', modelName = 'gpt-3.5-turbo') => { let encoder = null; try { diff --git a/api/server/utils/crypto.js b/api/server/utils/crypto.js index 9b5fed67c..911819634 100644 --- a/api/server/utils/crypto.js +++ b/api/server/utils/crypto.js @@ -19,4 +19,23 @@ function decrypt(encryptedValue) { return decrypted; } -module.exports = { encrypt, decrypt }; +// Programatically generate iv +function encryptV2(value) { + const gen_iv = crypto.randomBytes(16); + const cipher = crypto.createCipheriv(algorithm, key, gen_iv); + let encrypted = cipher.update(value, 'utf8', 'hex'); + encrypted += cipher.final('hex'); + return gen_iv.toString('hex') + ':' + encrypted; +} + +function decryptV2(encryptedValue) { + const parts = encryptedValue.split(':'); + const gen_iv = Buffer.from(parts.shift(), 'hex'); + const encrypted = parts.join(':'); + const decipher = crypto.createDecipheriv(algorithm, key, gen_iv); + let decrypted = decipher.update(encrypted, 'hex', 'utf8'); + decrypted += decipher.final('utf8'); + return decrypted; +} + +module.exports = { encrypt, decrypt, encryptV2, decryptV2 }; diff --git a/api/server/utils/files.js b/api/server/utils/files.js new file mode 100644 index 000000000..63cf95d3a --- /dev/null +++ b/api/server/utils/files.js @@ -0,0 +1,47 @@ +const sharp = require('sharp'); + +/** + * Determines the file type of a buffer + * @param {Buffer} dataBuffer + * @param {boolean} [returnFileType=false] - Optional. If true, returns the file type instead of the file extension. + * @returns {Promise} - Returns the file extension if found, else null + * */ +const determineFileType = async (dataBuffer, returnFileType) => { + const fileType = await import('file-type'); + const type = await fileType.fileTypeFromBuffer(dataBuffer); + if (returnFileType) { + return type; + } + return type ? type.ext : null; // Returns extension if found, else null +}; + +/** + * Get buffer metadata + * @param {Buffer} buffer + * @returns {Promise<{ bytes: number, type: string, dimensions: Record, extension: string}>} + */ +const getBufferMetadata = async (buffer) => { + const fileType = await determineFileType(buffer, true); + const bytes = buffer.length; + let extension = fileType ? fileType.ext : 'unknown'; + + /** @type {Record} */ + let dimensions = {}; + + if (fileType && fileType.mime.startsWith('image/') && extension !== 'unknown') { + const imageMetadata = await sharp(buffer).metadata(); + dimensions = { + width: imageMetadata.width, + height: imageMetadata.height, + }; + } + + return { + bytes, + type: fileType?.mime ?? 'unknown', + dimensions, + extension, + }; +}; + +module.exports = { determineFileType, getBufferMetadata }; diff --git a/api/server/utils/index.js b/api/server/utils/index.js index d51cdd1d4..c1637a678 100644 --- a/api/server/utils/index.js +++ b/api/server/utils/index.js @@ -5,6 +5,7 @@ const handleText = require('./handleText'); const cryptoUtils = require('./crypto'); const citations = require('./citations'); const sendEmail = require('./sendEmail'); +const files = require('./files'); const math = require('./math'); module.exports = { @@ -15,5 +16,6 @@ module.exports = { countTokens, removePorts, sendEmail, + ...files, math, }; diff --git a/api/server/utils/queue.js b/api/server/utils/queue.js new file mode 100644 index 000000000..73d819205 --- /dev/null +++ b/api/server/utils/queue.js @@ -0,0 +1,58 @@ +/** + * A leaky bucket queue structure to manage API requests. 
+ * @type {{queue: Array, interval: NodeJS.Timer | null}} + */ +const _LB = { + queue: [], + interval: null, +}; + +/** + * Interval in milliseconds to control the rate of API requests. + * Adjust the interval according to your rate limit needs. + */ +const _LB_INTERVAL_MS = Math.ceil(1000 / 60); // 60 req/s + +/** + * Executes the next function in the leaky bucket queue. + * This function is called at regular intervals defined by _LB_INTERVAL_MS. + */ +const _LB_EXEC_NEXT = async () => { + if (_LB.queue.length === 0) { + clearInterval(_LB.interval); + _LB.interval = null; + return; + } + + const next = _LB.queue.shift(); + if (!next) { + return; + } + + const { asyncFunc, args, callback } = next; + + try { + const data = await asyncFunc(...args); + callback(null, data); + } catch (e) { + callback(e); + } +}; + +/** + * Adds an async function call to the leaky bucket queue. + * @param {Function} asyncFunc - The async function to be executed. + * @param {Array} args - Arguments to pass to the async function. + * @param {Function} callback - Callback function for handling the result or error. + */ +function LB_QueueAsyncCall(asyncFunc, args, callback) { + _LB.queue.push({ asyncFunc, args, callback }); + + if (_LB.interval === null) { + _LB.interval = setInterval(_LB_EXEC_NEXT, _LB_INTERVAL_MS); + } +} + +module.exports = { + LB_QueueAsyncCall, +}; diff --git a/api/server/utils/streamResponse.js b/api/server/utils/streamResponse.js index 3511f144c..109a40746 100644 --- a/api/server/utils/streamResponse.js +++ b/api/server/utils/streamResponse.js @@ -16,12 +16,12 @@ const handleError = (res, message) => { /** * Sends message data in Server Sent Events format. - * @param {object} res - - The server response. - * @param {string} message - The message to be sent. + * @param {Express.Response} res - - The server response. + * @param {string | Object} message - The message to be sent. * @param {'message' | 'error' | 'cancel'} event - [Optional] The type of event. Default is 'message'. 
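+ * @example
+ * // Usage sketch (assumes `res` is an Express response already set up for SSE):
+ * sendMessage(res, { text: 'partial response text' }); // default 'message' event
+ * sendMessage(res, 'Request cancelled', 'cancel'); // named event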
*/ const sendMessage = (res, message, event = 'message') => { - if (message.length === 0) { + if (typeof message === 'string' && message.length === 0) { return; } res.write(`event: ${event}\ndata: ${JSON.stringify(message)}\n\n`); diff --git a/api/strategies/process.js b/api/strategies/process.js index 570637eec..9b7910231 100644 --- a/api/strategies/process.js +++ b/api/strategies/process.js @@ -1,5 +1,6 @@ const { FileSources } = require('librechat-data-provider'); -const uploadAvatar = require('~/server/services/Files/images/avatar'); +const { getStrategyFunctions } = require('~/server/services/Files/strategies'); +const { resizeAvatar } = require('~/server/services/Files/images/avatar'); const User = require('~/models/User'); /** @@ -24,8 +25,12 @@ const handleExistingUser = async (oldUser, avatarUrl) => { await oldUser.save(); } else if (!isLocal && (oldUser.avatar === null || !oldUser.avatar.includes('?manual=true'))) { const userId = oldUser._id; - const newavatarUrl = await uploadAvatar({ userId, input: avatarUrl, fileStrategy }); - oldUser.avatar = newavatarUrl; + const webPBuffer = await resizeAvatar({ + userId, + input: avatarUrl, + }); + const { processAvatar } = getStrategyFunctions(fileStrategy); + oldUser.avatar = await processAvatar({ buffer: webPBuffer, userId }); await oldUser.save(); } }; @@ -78,8 +83,12 @@ const createNewUser = async ({ if (!isLocal) { const userId = newUser._id; - const newavatarUrl = await uploadAvatar({ userId, input: avatarUrl, fileStrategy }); - newUser.avatar = newavatarUrl; + const webPBuffer = await resizeAvatar({ + userId, + input: avatarUrl, + }); + const { processAvatar } = getStrategyFunctions(fileStrategy); + newUser.avatar = await processAvatar({ buffer: webPBuffer, userId }); await newUser.save(); } diff --git a/api/typedefs.js b/api/typedefs.js index bb1f68cc8..6bc792cbe 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -14,6 +14,12 @@ * @memberof typedefs */ +/** + * @exports AssistantDocument + * @typedef {import('librechat-data-provider').AssistantDocument} AssistantDocument + * @memberof typedefs + */ + /** * @exports OpenAIFile * @typedef {import('librechat-data-provider').File} OpenAIFile @@ -44,6 +50,12 @@ * @memberof typedefs */ +/** + * @exports TPlugin + * @typedef {import('librechat-data-provider').TPlugin} TPlugin + * @memberof typedefs + */ + /** * @exports FileSources * @typedef {import('librechat-data-provider').FileSources} FileSources @@ -51,12 +63,63 @@ */ /** - * @exports ImageMetadata - * @typedef {Object} ImageMetadata + * @exports TMessage + * @typedef {import('librechat-data-provider').TMessage} TMessage + * @memberof typedefs + */ + +/** + * @exports ImageFile + * @typedef {import('librechat-data-provider').ImageFile} ImageFile + * @memberof typedefs + */ + +/** + * @exports ActionRequest + * @typedef {import('librechat-data-provider').ActionRequest} ActionRequest + * @memberof typedefs + */ + +/** + * @exports Action + * @typedef {import('librechat-data-provider').Action} Action + * @memberof typedefs + */ + +/** + * @exports ActionMetadata + * @typedef {import('librechat-data-provider').ActionMetadata} ActionMetadata + * @memberof typedefs + */ + +/** + * @exports ActionAuth + * @typedef {import('librechat-data-provider').ActionAuth} ActionAuth + * @memberof typedefs + */ + +/** + * @exports DeleteFilesBody + * @typedef {import('librechat-data-provider').DeleteFilesBody} DeleteFilesBody + * @memberof typedefs + */ + +/** + * @exports FileMetadata + * @typedef {Object} FileMetadata * @property {string} 
file_id - The identifier of the file. * @property {string} [temp_file_id] - The temporary identifier of the file. + * @property {string} endpoint - The conversation endpoint origin for the file upload. + * @property {string} [assistant_id] - The assistant ID if file upload is in the `knowledge` context. + * @memberof typedefs + */ + +/** + * @typedef {Object} ImageOnlyMetadata * @property {number} width - The width of the image. * @property {number} height - The height of the image. + * + * @typedef {FileMetadata & ImageOnlyMetadata} ImageMetadata * @memberof typedefs */ @@ -90,6 +153,36 @@ * @memberof typedefs */ +/** + * @exports ContentPart + * @typedef {import('librechat-data-provider').ContentPart} ContentPart + * @memberof typedefs + */ + +/** + * @exports StepTypes + * @typedef {import('librechat-data-provider').StepTypes} StepTypes + * @memberof typedefs + */ + +/** + * @exports TContentData + * @typedef {import('librechat-data-provider').TContentData} TContentData + * @memberof typedefs + */ + +/** + * @exports ContentPart + * @typedef {import('librechat-data-provider').ContentPart} ContentPart + * @memberof typedefs + */ + +/** + * @exports PartMetadata + * @typedef {import('librechat-data-provider').PartMetadata} PartMetadata + * @memberof typedefs + */ + /** * @exports ThreadMessage * @typedef {import('openai').OpenAI.Beta.Threads.ThreadMessage} ThreadMessage @@ -97,14 +190,111 @@ */ /** + * @exports TAssistantEndpoint + * @typedef {import('librechat-data-provider').TAssistantEndpoint} TAssistantEndpoint + * @memberof typedefs + */ + +/** + * Represents details of the message creation by the run step, including the ID of the created message. + * + * @exports MessageCreationStepDetails + * @typedef {Object} MessageCreationStepDetails + * @property {Object} message_creation - Details of the message creation. + * @property {string} message_creation.message_id - The ID of the message that was created by this run step. + * @property {'message_creation'} type - Always 'message_creation'. + * @memberof typedefs + */ + +/** + * Represents a text log output from the Code Interpreter tool call. + * @typedef {Object} CodeLogOutput + * @property {'logs'} type - Always 'logs'. + * @property {string} logs - The text output from the Code Interpreter tool call. + */ + +/** + * Represents an image output from the Code Interpreter tool call. + * @typedef {Object} CodeImageOutput + * @property {'image'} type - Always 'image'. + * @property {Object} image - The image object. + * @property {string} image.file_id - The file ID of the image. + */ + +/** + * Details of the Code Interpreter tool call the run step was involved in. + * Includes the tool call ID, the code interpreter definition, and the type of tool call. + * + * @typedef {Object} CodeToolCall + * @property {string} id - The ID of the tool call. + * @property {Object} code_interpreter - The Code Interpreter tool call definition. + * @property {string} code_interpreter.input - The input to the Code Interpreter tool call. + * @property {Array<(CodeLogOutput | CodeImageOutput)>} code_interpreter.outputs - The outputs from the Code Interpreter tool call. + * @property {'code_interpreter'} type - The type of tool call, always 'code_interpreter'. + * @memberof typedefs + */ + +/** + * Details of a Function tool call the run step was involved in. + * Includes the tool call ID, the function definition, and the type of tool call. + * + * @typedef {Object} FunctionToolCall + * @property {string} id - The ID of the tool call object. 
+ * @property {Object} function - The definition of the function that was called. + * @property {string} function.arguments - The arguments passed to the function. + * @property {string} function.name - The name of the function. + * @property {string|null} function.output - The output of the function, null if not submitted. + * @property {'function'} type - The type of tool call, always 'function'. + * @memberof typedefs + */ + +/** + * Details of a Retrieval tool call the run step was involved in. + * Includes the tool call ID and the type of tool call. + * + * @typedef {Object} RetrievalToolCall + * @property {string} id - The ID of the tool call object. + * @property {unknown} retrieval - An empty object for now. + * @property {'retrieval'} type - The type of tool call, always 'retrieval'. + * @memberof typedefs + */ + +/** + * Details of the tool calls involved in a run step. + * Can be associated with one of three types of tools: `code_interpreter`, `retrieval`, or `function`. + * + * @typedef {Object} ToolCallsStepDetails + * @property {Array} tool_calls - An array of tool calls the run step was involved in. + * @property {'tool_calls'} type - Always 'tool_calls'. + * @memberof typedefs + */ + +/** + * Details of the tool calls involved in a run step. + * Can be associated with one of three types of tools: `code_interpreter`, `retrieval`, or `function`. + * + * @exports StepToolCall + * @typedef {(CodeToolCall | RetrievalToolCall | FunctionToolCall) & PartMetadata} StepToolCall + * @memberof typedefs + */ + +/** + * Represents a tool call object required for certain actions in the OpenAI API, + * including the function definition and type of the tool call. + * * @exports RequiredActionFunctionToolCall - * @typedef {import('openai').OpenAI.Beta.Threads.RequiredActionFunctionToolCall} RequiredActionFunctionToolCall + * @typedef {Object} RequiredActionFunctionToolCall + * @property {string} id - The ID of the tool call, referenced when submitting tool outputs. + * @property {Object} function - The function definition associated with the tool call. + * @property {string} function.arguments - The arguments that the model expects to be passed to the function. + * @property {string} function.name - The name of the function. + * @property {'function'} type - The type of tool call the output is required for, currently always 'function'. * @memberof typedefs */ /** * @exports RunManager - * @typedef {import('./server/services/Runs/RunMananger.js').RunManager} RunManager + * @typedef {import('./server/services/Runs/RunManager.js').RunManager} RunManager * @memberof typedefs */ @@ -112,7 +302,7 @@ * @exports Thread * @typedef {Object} Thread * @property {string} id - The identifier of the thread. - * @property {string} object - The object type, always 'thread'. + * @property {'thread'} object - The object type, always 'thread'. * @property {number} created_at - The Unix timestamp (in seconds) for when the thread was created. * @property {Object} [metadata] - Optional metadata associated with the thread. * @property {Message[]} [messages] - An array of messages associated with the thread. @@ -123,12 +313,12 @@ * @exports Message * @typedef {Object} Message * @property {string} id - The identifier of the message. - * @property {string} object - The object type, always 'thread.message'. + * @property {'thread.message'} object - The object type, always 'thread.message'. * @property {number} created_at - The Unix timestamp (in seconds) for when the message was created. 
* @property {string} thread_id - The thread ID that this message belongs to. - * @property {string} role - The entity that produced the message. One of 'user' or 'assistant'. + * @property {'user'|'assistant'} role - The entity that produced the message. One of 'user' or 'assistant'. * @property {Object[]} content - The content of the message in an array of text and/or images. - * @property {string} content[].type - The type of content, either 'text' or 'image_file'. + * @property {'text'|'image_file'} content[].type - The type of content, either 'text' or 'image_file'. * @property {Object} [content[].text] - The text content, present if type is 'text'. * @property {string} content[].text.value - The data that makes up the text. * @property {Object[]} [content[].text.annotations] - Annotations for the text content. @@ -170,7 +360,7 @@ /** * @exports FunctionTool * @typedef {Object} FunctionTool - * @property {string} type - The type of tool, 'function'. + * @property {'function'} type - The type of tool, 'function'. * @property {Object} function - The function definition. * @property {string} function.description - A description of what the function does. * @property {string} function.name - The name of the function to be called. @@ -181,7 +371,7 @@ /** * @exports Tool * @typedef {Object} Tool - * @property {string} type - The type of tool, can be 'code_interpreter', 'retrieval', or 'function'. + * @property {'code_interpreter'|'retrieval'|'function'} type - The type of tool, can be 'code_interpreter', 'retrieval', or 'function'. * @property {FunctionTool} [function] - The function tool, present if type is 'function'. * @memberof typedefs */ @@ -194,7 +384,7 @@ * @property {number} created_at - The Unix timestamp (in seconds) for when the run was created. * @property {string} thread_id - The ID of the thread that was executed on as a part of this run. * @property {string} assistant_id - The ID of the assistant used for execution of this run. - * @property {string} status - The status of the run (e.g., 'queued', 'completed'). + * @property {'queued'|'in_progress'|'requires_action'|'cancelling'|'cancelled'|'failed'|'completed'|'expired'} status - The status of the run: queued, in_progress, requires_action, cancelling, cancelled, failed, completed, or expired. * @property {Object} [required_action] - Details on the action required to continue the run. * @property {string} required_action.type - The type of required action, always 'submit_tool_outputs'. * @property {Object} required_action.submit_tool_outputs - Details on the tool outputs needed for the run to continue. @@ -214,9 +404,15 @@ * @property {number} [completed_at] - The Unix timestamp (in seconds) for when the run was completed. * @property {string} [model] - The model that the assistant used for this run. * @property {string} [instructions] - The instructions that the assistant used for this run. + * @property {string} [additional_instructions] - Optional. Appends additional instructions + * at the end of the instructions for the run. This is useful for modifying behavior on a per-run basis without overriding other instructions. * @property {Tool[]} [tools] - The list of tools used for this run. * @property {string[]} [file_ids] - The list of File IDs used for this run. * @property {Object} [metadata] - Metadata associated with this run. + * @property {Object} [usage] - Usage statistics related to the run. This value will be `null` if the run is not in a terminal state (i.e. `in_progress`, `queued`, etc.).
+ * @property {number} [usage.completion_tokens] - Number of completion tokens used over the course of the run. + * @property {number} [usage.prompt_tokens] - Number of prompt tokens used over the course of the run. + * @property {number} [usage.total_tokens] - Total number of tokens used (prompt + completion). * @memberof typedefs */ @@ -229,11 +425,11 @@ * @property {string} assistant_id - The ID of the assistant associated with the run step. * @property {string} thread_id - The ID of the thread that was run. * @property {string} run_id - The ID of the run that this run step is a part of. - * @property {string} type - The type of run step, either 'message_creation' or 'tool_calls'. - * @property {string} status - The status of the run step, can be 'in_progress', 'cancelled', 'failed', 'completed', or 'expired'. - * @property {Object} step_details - The details of the run step. + * @property {'message_creation' | 'tool_calls'} type - The type of run step. + * @property {'in_progress' | 'cancelled' | 'failed' | 'completed' | 'expired'} status - The status of the run step. + * @property {MessageCreationStepDetails | ToolCallsStepDetails} step_details - The details of the run step. * @property {Object} [last_error] - The last error associated with this run step. - * @property {string} last_error.code - One of 'server_error' or 'rate_limit_exceeded'. + * @property {'server_error' | 'rate_limit_exceeded'} last_error.code - One of 'server_error' or 'rate_limit_exceeded'. * @property {string} last_error.message - A human-readable description of the error. * @property {number} [expired_at] - The Unix timestamp (in seconds) for when the run step expired. * @property {number} [cancelled_at] - The Unix timestamp (in seconds) for when the run step was cancelled. @@ -253,8 +449,8 @@ * @property {string} assistant_id - The ID of the assistant associated with the run step. * @property {string} thread_id - The ID of the thread that was run. * @property {string} run_id - The ID of the run that this run step is a part of. - * @property {string} type - The type of run step, either 'message_creation' or 'tool_calls'. - * @property {string} status - The status of the run step, can be 'in_progress', 'cancelled', 'failed', 'completed', or 'expired'. + * @property {'message_creation'|'tool_calls'} type - The type of run step, either 'message_creation' or 'tool_calls'. + * @property {'in_progress'|'cancelled'|'failed'|'completed'|'expired'} status - The status of the run step, can be 'in_progress', 'cancelled', 'failed', 'completed', or 'expired'. * @property {Object} step_details - The details of the run step. * @property {Object} [last_error] - The last error associated with this run step. * @property {string} last_error.code - One of 'server_error' or 'rate_limit_exceeded'. @@ -362,6 +558,41 @@ * @memberof typedefs */ +/** + * @exports RequiredAction + * @typedef {Object} RequiredAction + * @property {string} tool - The name of the function. + * @property {Object} toolInput - The args to invoke the function with. + * @property {string} toolCallId - The ID of the tool call. + * @property {Run['id']} run_id - Run identifier. + * @property {Thread['id']} thread_id - Thread identifier. + * @memberof typedefs + */ + +/** + * @exports StructuredTool + * @typedef {Object} StructuredTool + * @property {string} name - The name of the function. + * @property {string} description - The description of the function. + * @property {import('zod').ZodTypeAny} schema - The structured zod schema. 
+ * @memberof typedefs + */ + +/** + * @exports ToolOutput + * @typedef {Object} ToolOutput + * @property {string} tool_call_id - The ID of the tool call. + * @property {Object} output - The output of the tool, which can vary in structure. + * @memberof typedefs + */ + +/** + * @exports ToolOutputs + * @typedef {Object} ToolOutputs + * @property {ToolOutput[]} tool_outputs - Array of tool outputs. + * @memberof typedefs + */ + /** * @typedef {Object} ModelOptions * @property {string} modelName - The name of the model. @@ -412,3 +643,96 @@ * An endpoint's config object mapping model keys to their respective prompt, completion rates, and context limit. * @memberof typedefs */ + +/** + * @typedef {Object} ResponseMessage + * @property {string} conversationId - The ID of the conversation. + * @property {string} thread_id - The ID of the thread. + * @property {string} messageId - The ID of the message (from LibreChat). + * @property {string} parentMessageId - The ID of the parent message. + * @property {string} user - The ID of the user. + * @property {string} assistant_id - The ID of the assistant. + * @property {string} role - The role of the response. + * @property {string} model - The model used in the response. + * @property {ContentPart[]} content - The content parts accumulated from the run. + * @memberof typedefs + */ + +/** + * @typedef {Object} RunResponse + * @property {Run} run - The detailed information about the run. + * @property {RunStep[]} steps - An array of steps taken during the run. + * @property {StepMessage[]} messages - An array of messages related to the run. + * @memberof typedefs + */ + +/** + * @callback InProgressFunction + * @param {Object} params - The parameters for the in progress step. + * @param {RunStep} params.step - The step object with details about the message creation. + * @returns {Promise} - A promise that resolves when the step is processed. + * @memberof typedefs + */ + +// /** +// * @typedef {OpenAI & { +// * req: Express.Request, +// * res: Express.Response +// * getPartialText: () => string, +// * processedFileIds: Set, +// * mappedOrder: Map, +// * completeToolCallSteps: Set, +// * seenCompletedMessages: Set, +// * seenToolCalls: Map, +// * progressCallback: (options: Object) => void, +// * addContentData: (data: TContentData) => void, +// * responseMessage: ResponseMessage, +// * }} OpenAIClient - for reference only +// */ + +/** + * @typedef {Object} OpenAIClientType + * + * @property {Express.Request} req - The Express request object. + * @property {Express.Response} res - The Express response object. + * @property {?import('https-proxy-agent').HttpsProxyAgent} httpAgent - An optional HTTP proxy agent for the request. + + * @property {() => string} getPartialText - Retrieves the current tokens accumulated by `progressCallback`. + * + * Note: not used until real streaming is implemented by OpenAI. + * + * @property {string} responseText -The accumulated text values for the current run. + * @property {Set} processedFileIds - A set of IDs for processed files. + * @property {Map} mappedOrder - A map to maintain the order of individual `tool_calls` and `steps`. + * @property {Set} [attachedFileIds] - A set of user attached file ids; necessary to track which files are downloadable. + * @property {Set} completeToolCallSteps - A set of completed tool call steps. + * @property {Set} seenCompletedMessages - A set of completed messages that have been seen/processed. + * @property {Map} seenToolCalls - A map of tool calls that have been seen/processed. 
+ * @property {(data: TContentData) => void} addContentData - Updates the response message's relevant + * @property {InProgressFunction} in_progress - Updates the response message's relevant + * content array with the part by index & sends intermediate SSE message with content data. + * + * Note: does not send intermediate SSE message for messages, which are streamed + * (may soon be streamed) directly from OpenAI API. + * + * @property {ResponseMessage} responseMessage - A message object for responses. + * + * @typedef {OpenAI & OpenAIClientType} OpenAIClient + */ + +/** + * The body of the request to create a run, specifying the assistant, model, + * instructions, and any additional parameters needed for the run. + * + * @typedef {Object} CreateRunBody + * @property {string} assistant_id - The ID of the assistant to use for this run. + * @property {string} [model] - Optional. The ID of the model to be used for this run. + * @property {string} [instructions] - Optional. Override the default system message of the assistant. + * @property {string} [additional_instructions] - Optional. Appends additional instructions + * at the end of the instructions for the run. Useful for modifying behavior on a per-run basis without overriding other instructions. + * @property {Object[]} [tools] - Optional. Override the tools the assistant can use for this run. Should include tool call ID and the type of tool call. + * @property {string[]} [file_ids] - Optional. List of File IDs the assistant can use for this run. + * **Note:** The API seems to prefer files added to messages, not runs. + * @property {Object} [metadata] - Optional. Metadata for the run. + * @memberof typedefs + */ diff --git a/client/package.json b/client/package.json index cc54a3f11..84b37c7db 100644 --- a/client/package.json +++ b/client/package.json @@ -28,14 +28,18 @@ "homepage": "https://github.com/danny-avila/LibreChat#readme", "dependencies": { "@headlessui/react": "^1.7.13", + "@radix-ui/react-accordion": "^1.1.2", "@radix-ui/react-alert-dialog": "^1.0.2", "@radix-ui/react-checkbox": "^1.0.3", + "@radix-ui/react-collapsible": "^1.0.3", "@radix-ui/react-dialog": "^1.0.2", "@radix-ui/react-dropdown-menu": "^2.0.2", "@radix-ui/react-hover-card": "^1.0.5", "@radix-ui/react-icons": "^1.3.0", "@radix-ui/react-label": "^2.0.0", "@radix-ui/react-popover": "^1.0.7", + "@radix-ui/react-radio-group": "^1.1.3", + "@radix-ui/react-select": "^2.0.0", "@radix-ui/react-separator": "^1.0.3", "@radix-ui/react-slider": "^1.1.1", "@radix-ui/react-switch": "^1.0.3", @@ -43,6 +47,7 @@ "@radix-ui/react-toast": "^1.1.5", "@radix-ui/react-tooltip": "^1.0.6", "@tanstack/react-query": "^4.28.0", + "@tanstack/react-table": "^8.11.7", "@zattoo/use-double-click": "1.2.0", "axios": "^1.3.4", "class-variance-authority": "^0.6.0", @@ -67,6 +72,7 @@ "react-hook-form": "^7.43.9", "react-lazy-load-image-component": "^1.6.0", "react-markdown": "^8.0.6", + "react-resizable-panels": "^1.0.9", "react-router-dom": "^6.11.2", "react-textarea-autosize": "^8.4.0", "react-transition-group": "^4.4.5", @@ -116,7 +122,7 @@ "tailwindcss": "^3.4.1", "ts-jest": "^29.1.0", "typescript": "^5.0.4", - "vite": "^5.0.7", + "vite": "^5.1.1", "vite-plugin-html": "^3.2.0", "vite-plugin-node-polyfills": "^0.17.0" } diff --git a/client/src/App.jsx b/client/src/App.jsx index 10c9ab9b5..ce2ec3b6d 100644 --- a/client/src/App.jsx +++ b/client/src/App.jsx @@ -6,7 +6,7 @@ import { HTML5Backend } from 'react-dnd-html5-backend'; import { ReactQueryDevtools } from '@tanstack/react-query-devtools'; 
import { QueryClient, QueryClientProvider, QueryCache } from '@tanstack/react-query'; import { ScreenshotProvider, ThemeProvider, useApiErrorBoundary } from './hooks'; -import { ToastProvider, AssistantsProvider } from './Providers'; +import { ToastProvider } from './Providers'; import Toast from './components/ui/Toast'; import { router } from './routes'; @@ -29,14 +29,12 @@ const App = () => { - - - - - - - - + + + + + + diff --git a/client/src/Providers/AssistantsContext.tsx b/client/src/Providers/AssistantsContext.tsx index 515618879..10079083a 100644 --- a/client/src/Providers/AssistantsContext.tsx +++ b/client/src/Providers/AssistantsContext.tsx @@ -1,14 +1,10 @@ +import { useForm, FormProvider } from 'react-hook-form'; import { createContext, useContext } from 'react'; +import { defaultAssistantFormValues } from 'librechat-data-provider'; import type { UseFormReturn } from 'react-hook-form'; -import type { CreationForm } from '~/common'; -import useCreationForm from './useCreationForm'; +import type { AssistantForm } from '~/common'; -// type AssistantsContextType = { -// // open: boolean; -// // setOpen: Dispatch>; -// form: UseFormReturn; -// }; -type AssistantsContextType = UseFormReturn; +type AssistantsContextType = UseFormReturn; export const AssistantsContext = createContext({} as AssistantsContextType); @@ -23,7 +19,9 @@ export function useAssistantsContext() { } export default function AssistantsProvider({ children }) { - const hookValues = useCreationForm(); + const methods = useForm({ + defaultValues: defaultAssistantFormValues, + }); - return {children}; + return {children}; } diff --git a/client/src/Providers/AssistantsMapContext.tsx b/client/src/Providers/AssistantsMapContext.tsx new file mode 100644 index 000000000..850e7d312 --- /dev/null +++ b/client/src/Providers/AssistantsMapContext.tsx @@ -0,0 +1,8 @@ +import { createContext, useContext } from 'react'; +import { useAssistantsMap } from '~/hooks/Assistants'; +type AssistantsMapContextType = ReturnType; + +export const AssistantsMapContext = createContext( + {} as AssistantsMapContextType, +); +export const useAssistantsMapContext = () => useContext(AssistantsMapContext); diff --git a/client/src/Providers/FileMapContext.tsx b/client/src/Providers/FileMapContext.tsx new file mode 100644 index 000000000..2e189cacb --- /dev/null +++ b/client/src/Providers/FileMapContext.tsx @@ -0,0 +1,6 @@ +import { createContext, useContext } from 'react'; +import { useFileMap } from '~/hooks/Files'; +type FileMapContextType = ReturnType; + +export const FileMapContext = createContext({} as FileMapContextType); +export const useFileMapContext = () => useContext(FileMapContext); diff --git a/client/src/Providers/index.ts b/client/src/Providers/index.ts index ab8b65d78..32e5c25dc 100644 --- a/client/src/Providers/index.ts +++ b/client/src/Providers/index.ts @@ -2,4 +2,6 @@ export { default as ToastProvider } from './ToastContext'; export { default as AssistantsProvider } from './AssistantsContext'; export * from './ChatContext'; export * from './ToastContext'; +export * from './FileMapContext'; export * from './AssistantsContext'; +export * from './AssistantsMapContext'; diff --git a/client/src/Providers/useCreationForm.ts b/client/src/Providers/useCreationForm.ts deleted file mode 100644 index 6fadf4c94..000000000 --- a/client/src/Providers/useCreationForm.ts +++ /dev/null @@ -1,19 +0,0 @@ -// import { useState } from 'react'; -import { useForm } from 'react-hook-form'; -import type { CreationForm } from '~/common'; - -export default 
function useViewPromptForm() { - return useForm({ - defaultValues: { - assistant: '', - id: '', - name: '', - description: '', - instructions: '', - model: 'gpt-3.5-turbo-1106', - function: false, - code_interpreter: false, - retrieval: false, - }, - }); -} diff --git a/client/src/common/assistants-types.ts b/client/src/common/assistants-types.ts index 7dc6906e7..c748de0de 100644 --- a/client/src/common/assistants-types.ts +++ b/client/src/common/assistants-types.ts @@ -1,19 +1,21 @@ -import type { Option } from './types'; import type { Assistant } from 'librechat-data-provider'; +import type { Option, ExtendedFile } from './types'; -export type TAssistantOption = string | (Option & Assistant); +export type TAssistantOption = + | string + | (Option & Assistant & { files?: Array<[string, ExtendedFile]> }); export type Actions = { - function: boolean; code_interpreter: boolean; retrieval: boolean; }; -export type CreationForm = { +export type AssistantForm = { assistant: TAssistantOption; id: string; name: string | null; description: string | null; instructions: string | null; model: string; + functions: string[]; } & Actions; diff --git a/client/src/common/types.ts b/client/src/common/types.ts index 1ca169a0c..babbe5579 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -1,4 +1,6 @@ import { FileSources } from 'librechat-data-provider'; +import type { ColumnDef } from '@tanstack/react-table'; +import type { SetterOrUpdater } from 'recoil'; import type { TConversation, TMessage, @@ -6,10 +8,80 @@ import type { TLoginUser, TUser, EModelEndpoint, + Action, + AuthTypeEnum, + AuthorizationTypeEnum, + TokenExchangeMethodEnum, } from 'librechat-data-provider'; import type { UseMutationResult } from '@tanstack/react-query'; +import type { LucideIcon } from 'lucide-react'; -export type TSetOption = (param: number | string) => (newValue: number | string | boolean) => void; +export type GenericSetter = (value: T | ((currentValue: T) => T)) => void; + +export type NavLink = { + title: string; + label?: string; + icon: LucideIcon; + Component?: React.ComponentType; + variant?: 'default' | 'ghost'; + id: string; +}; + +export interface NavProps { + isCollapsed: boolean; + links: NavLink[]; + resize?: (size: number) => void; + defaultActive?: string; +} + +interface ColumnMeta { + meta: { + size: number | string; + }; +} + +export enum Panel { + builder = 'builder', + actions = 'actions', +} + +export type FileSetter = + | SetterOrUpdater> + | React.Dispatch>>; + +export type ActionAuthForm = { + /* General */ + type: AuthTypeEnum; + saved_auth_fields: boolean; + /* API key */ + api_key: string; // not nested + authorization_type: AuthorizationTypeEnum; + custom_auth_header: string; + /* OAuth */ + oauth_client_id: string; // not nested + oauth_client_secret: string; // not nested + authorization_url: string; + client_url: string; + scope: string; + token_exchange_method: TokenExchangeMethodEnum; +}; + +export type AssistantPanelProps = { + index?: number; + action?: Action; + actions?: Action[]; + assistant_id?: string; + activePanel?: string; + setAction: React.Dispatch>; + setCurrentAssistantId: React.Dispatch>; + setActivePanel: React.Dispatch>; +}; + +export type AugmentedColumnDef = ColumnDef & ColumnMeta; + +export type TSetOption = ( + param: number | string, +) => (newValue: number | string | boolean | Partial) => void; export type TSetExample = ( i: number, type: string, @@ -72,7 +144,7 @@ export type TSetOptionsPayload = { setAgentOption: TSetOption; // getConversation: 
() => TConversation | TPreset | null; checkPluginSelection: (value: string) => boolean; - setTools: (newValue: string) => void; + setTools: (newValue: string, remove?: boolean) => void; }; export type TPresetItemProps = { @@ -136,7 +208,7 @@ export type TAdditionalProps = { setSiblingIdx: (value: number) => void; }; -export type TMessageContent = TInitialProps & TAdditionalProps; +export type TMessageContentProps = TInitialProps & TAdditionalProps; export type TText = Pick; export type TEditProps = Pick & @@ -172,6 +244,11 @@ export type TDialogProps = { onOpenChange: (open: boolean) => void; }; +export type TPluginStoreDialogProps = { + isOpen: boolean; + setIsOpen: (open: boolean) => void; +}; + export type TResError = { response: { data: { message: string } }; message: string; @@ -198,7 +275,7 @@ export type TAuthConfig = { test?: boolean; }; -export type IconProps = Pick & +export type IconProps = Pick & Pick & { size?: number; button?: boolean; @@ -207,6 +284,8 @@ export type IconProps = Pick & className?: string; endpoint?: EModelEndpoint | string | null; endpointType?: EModelEndpoint | null; + assistantName?: string; + error?: boolean; }; export type Option = Record & { @@ -220,7 +299,7 @@ export type TOptionSettings = { }; export interface ExtendedFile { - file: File; + file?: File; file_id: string; temp_file_id?: string; type?: string; @@ -229,9 +308,10 @@ export interface ExtendedFile { width?: number; height?: number; size: number; - preview: string; + preview?: string; progress: number; source?: FileSources; + attached?: boolean; } export type ContextType = { navVisible: boolean; setNavVisible: (visible: boolean) => void }; diff --git a/client/src/components/Chat/ChatView.tsx b/client/src/components/Chat/ChatView.tsx index 30a7edc18..604c8f1e7 100644 --- a/client/src/components/Chat/ChatView.tsx +++ b/client/src/components/Chat/ChatView.tsx @@ -2,16 +2,13 @@ import { memo } from 'react'; import { useRecoilValue } from 'recoil'; import { useParams } from 'react-router-dom'; import { useGetMessagesByConvoId } from 'librechat-data-provider/react-query'; -import { useChatHelpers, useSSE } from '~/hooks'; -// import GenerationButtons from './Input/GenerationButtons'; +import { ChatContext, useFileMapContext } from '~/Providers'; import MessagesView from './Messages/MessagesView'; -// import OptionsBar from './Input/OptionsBar'; -import { useGetFiles } from '~/data-provider'; -import { buildTree, mapFiles } from '~/utils'; +import { useChatHelpers, useSSE } from '~/hooks'; import { Spinner } from '~/components/svg'; -import { ChatContext } from '~/Providers'; import Presentation from './Presentation'; import ChatForm from './Input/ChatForm'; +import { buildTree } from '~/utils'; import Landing from './Landing'; import Header from './Header'; import Footer from './Footer'; @@ -22,9 +19,7 @@ function ChatView({ index = 0 }: { index?: number }) { const submissionAtIndex = useRecoilValue(store.submissionByIndex(0)); useSSE(submissionAtIndex); - const { data: fileMap } = useGetFiles({ - select: mapFiles, - }); + const fileMap = useFileMapContext(); const { data: messagesTree = null, isLoading } = useGetMessagesByConvoId(conversationId ?? '', { select: (data) => { @@ -38,7 +33,7 @@ function ChatView({ index = 0 }: { index?: number }) { return ( - + {isLoading && conversationId !== 'new' ? (
@@ -48,8 +43,6 @@ function ChatView({ index = 0 }: { index?: number }) { ) : ( } /> )} - {/* */} - {/* */}