diff --git a/.env.example b/.env.example index e235b6cbb9..b86092d56d 100644 --- a/.env.example +++ b/.env.example @@ -473,6 +473,15 @@ FIREBASE_STORAGE_BUCKET= FIREBASE_MESSAGING_SENDER_ID= FIREBASE_APP_ID= +#========================# +# S3 AWS Bucket # +#========================# + +AWS_ACCESS_KEY_ID= +AWS_SECRET_ACCESS_KEY= +AWS_REGION= +AWS_BUCKET_NAME= + #========================# # Shared Links # #========================# diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 61b39a8f6d..d3077e68f5 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -5,6 +5,7 @@ const { isAgentsEndpoint, isParamEndpoint, EModelEndpoint, + ContentTypes, excludedKeys, ErrorTypes, Constants, @@ -365,17 +366,14 @@ class BaseClient { * context: TMessage[], * remainingContextTokens: number, * messagesToRefine: TMessage[], - * summaryIndex: number, - * }>} An object with four properties: `context`, `summaryIndex`, `remainingContextTokens`, and `messagesToRefine`. + * }>} An object with three properties: `context`, `remainingContextTokens`, and `messagesToRefine`. * `context` is an array of messages that fit within the token limit. - * `summaryIndex` is the index of the first message in the `messagesToRefine` array. * `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context. * `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit. */ async getMessagesWithinTokenLimit({ messages: _messages, maxContextTokens, instructions }) { // Every reply is primed with <|start|>assistant<|message|>, so we // start with 3 tokens for the label after all messages have been counted. - let summaryIndex = -1; let currentTokenCount = 3; const instructionsTokenCount = instructions?.tokenCount ?? 
0; let remainingContextTokens = @@ -408,14 +406,12 @@ class BaseClient { } const prunedMemory = messages; - summaryIndex = prunedMemory.length - 1; remainingContextTokens -= currentTokenCount; return { context: context.reverse(), remainingContextTokens, messagesToRefine: prunedMemory, - summaryIndex, }; } @@ -458,7 +454,7 @@ class BaseClient { let orderedWithInstructions = this.addInstructions(orderedMessages, instructions); - let { context, remainingContextTokens, messagesToRefine, summaryIndex } = + let { context, remainingContextTokens, messagesToRefine } = await this.getMessagesWithinTokenLimit({ messages: orderedWithInstructions, instructions, @@ -528,7 +524,7 @@ class BaseClient { } // Make sure to only continue summarization logic if the summary message was generated - shouldSummarize = summaryMessage && shouldSummarize; + shouldSummarize = summaryMessage != null && shouldSummarize === true; logger.debug('[BaseClient] Context Count (2/2)', { remainingContextTokens, @@ -538,17 +534,18 @@ class BaseClient { /** @type {Record | undefined} */ let tokenCountMap; if (buildTokenMap) { - tokenCountMap = orderedWithInstructions.reduce((map, message, index) => { + const currentPayload = shouldSummarize ? 
orderedWithInstructions : context; + tokenCountMap = currentPayload.reduce((map, message, index) => { const { messageId } = message; if (!messageId) { return map; } - if (shouldSummarize && index === summaryIndex && !usePrevSummary) { + if (shouldSummarize && index === messagesToRefine.length - 1 && !usePrevSummary) { map.summaryMessage = { ...summaryMessage, messageId, tokenCount: summaryTokenCount }; } - map[messageId] = orderedWithInstructions[index].tokenCount; + map[messageId] = currentPayload[index].tokenCount; return map; }, {}); } @@ -1021,11 +1018,17 @@ class BaseClient { const processValue = (value) => { if (Array.isArray(value)) { for (let item of value) { - if (!item || !item.type || item.type === 'image_url') { + if ( + !item || + !item.type || + item.type === ContentTypes.THINK || + item.type === ContentTypes.ERROR || + item.type === ContentTypes.IMAGE_URL + ) { continue; } - if (item.type === 'tool_call' && item.tool_call != null) { + if (item.type === ContentTypes.TOOL_CALL && item.tool_call != null) { const toolName = item.tool_call?.name || ''; if (toolName != null && toolName && typeof toolName === 'string') { numTokens += this.getTokenCount(toolName); @@ -1121,9 +1124,13 @@ class BaseClient { return message; } - const files = await getFiles({ - file_id: { $in: fileIds }, - }); + const files = await getFiles( + { + file_id: { $in: fileIds }, + }, + {}, + {}, + ); await this.addImageURLs(message, files, this.visionMode); diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 9a89e34879..a1ab496b5d 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -1272,6 +1272,29 @@ ${convo} }); } + /** Note: OpenAI Web Search models do not support any known parameters besides `max_tokens` */ + if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) { + const searchExcludeParams = [ + 'frequency_penalty', + 'presence_penalty', + 'temperature', + 'top_p', + 'top_k', + 'stop', + 
'logit_bias', + 'seed', + 'response_format', + 'n', + 'logprobs', + 'user', + ]; + + this.options.dropParams = this.options.dropParams || []; + this.options.dropParams = [ + ...new Set([...this.options.dropParams, ...searchExcludeParams]), + ]; + } + if (this.options.dropParams && Array.isArray(this.options.dropParams)) { this.options.dropParams.forEach((param) => { delete modelOptions[param]; diff --git a/api/app/clients/prompts/formatMessages.js b/api/app/clients/prompts/formatMessages.js index 4e8d3bd5a5..9fa0d40497 100644 --- a/api/app/clients/prompts/formatMessages.js +++ b/api/app/clients/prompts/formatMessages.js @@ -211,7 +211,7 @@ const formatAgentMessages = (payload) => { } else if (part.type === ContentTypes.THINK) { hasReasoning = true; continue; - } else if (part.type === ContentTypes.ERROR) { + } else if (part.type === ContentTypes.ERROR || part.type === ContentTypes.AGENT_UPDATE) { continue; } else { currentContent.push(part); diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index 0dae5b14d3..c9be50d3de 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -164,7 +164,7 @@ describe('BaseClient', () => { const result = await TestClient.getMessagesWithinTokenLimit({ messages }); expect(result.context).toEqual(expectedContext); - expect(result.summaryIndex).toEqual(expectedIndex); + expect(result.messagesToRefine.length - 1).toEqual(expectedIndex); expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens); expect(result.messagesToRefine).toEqual(expectedMessagesToRefine); }); @@ -200,7 +200,7 @@ describe('BaseClient', () => { const result = await TestClient.getMessagesWithinTokenLimit({ messages }); expect(result.context).toEqual(expectedContext); - expect(result.summaryIndex).toEqual(expectedIndex); + expect(result.messagesToRefine.length - 1).toEqual(expectedIndex); 
expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens); expect(result.messagesToRefine).toEqual(expectedMessagesToRefine); }); diff --git a/api/app/clients/tools/structured/DALLE3.js b/api/app/clients/tools/structured/DALLE3.js index 81200e3a61..fc0f1851f6 100644 --- a/api/app/clients/tools/structured/DALLE3.js +++ b/api/app/clients/tools/structured/DALLE3.js @@ -172,7 +172,7 @@ Error Message: ${error.message}`); { type: ContentTypes.IMAGE_URL, image_url: { - url: `data:image/jpeg;base64,${base64}`, + url: `data:image/png;base64,${base64}`, }, }, ]; diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index ae19a158ee..063d6e0327 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -21,6 +21,7 @@ const { } = require('../'); const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch'); +const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { createMCPTool } = require('~/server/services/MCP'); const { loadSpecs } = require('./loadSpecs'); const { logger } = require('~/config'); @@ -90,45 +91,6 @@ const validateTools = async (user, tools = []) => { } }; -const loadAuthValues = async ({ userId, authFields, throwError = true }) => { - let authValues = {}; - - /** - * Finds the first non-empty value for the given authentication field, supporting alternate fields. - * @param {string[]} fields Array of strings representing the authentication fields. Supports alternate fields delimited by "||". - * @returns {Promise<{ authField: string, authValue: string} | null>} An object containing the authentication field and value, or null if not found. 
- */ - const findAuthValue = async (fields) => { - for (const field of fields) { - let value = process.env[field]; - if (value) { - return { authField: field, authValue: value }; - } - try { - value = await getUserPluginAuthValue(userId, field, throwError); - } catch (err) { - if (field === fields[fields.length - 1] && !value) { - throw err; - } - } - if (value) { - return { authField: field, authValue: value }; - } - } - return null; - }; - - for (let authField of authFields) { - const fields = authField.split('||'); - const result = await findAuthValue(fields); - if (result) { - authValues[result.authField] = result.authValue; - } - } - - return authValues; -}; - /** @typedef {typeof import('@langchain/core/tools').Tool} ToolConstructor */ /** @typedef {import('@langchain/core/tools').Tool} Tool */ @@ -348,7 +310,6 @@ const loadTools = async ({ module.exports = { loadToolWithAuth, - loadAuthValues, validateTools, loadTools, }; diff --git a/api/app/clients/tools/util/index.js b/api/app/clients/tools/util/index.js index 73d10270b6..ea67bb4ced 100644 --- a/api/app/clients/tools/util/index.js +++ b/api/app/clients/tools/util/index.js @@ -1,9 +1,8 @@ -const { validateTools, loadTools, loadAuthValues } = require('./handleTools'); +const { validateTools, loadTools } = require('./handleTools'); const handleOpenAIErrors = require('./handleOpenAIErrors'); module.exports = { handleOpenAIErrors, - loadAuthValues, validateTools, loadTools, }; diff --git a/api/config/index.js b/api/config/index.js index aaf8bb2764..8f23e404c8 100644 --- a/api/config/index.js +++ b/api/config/index.js @@ -1,3 +1,4 @@ +const axios = require('axios'); const { EventSource } = require('eventsource'); const { Time, CacheKeys } = require('librechat-data-provider'); const logger = require('./winston'); @@ -47,9 +48,46 @@ const sendEvent = (res, event) => { res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); }; +/** + * Creates and configures an Axios instance with optional proxy settings. 
+ * + * @typedef {import('axios').AxiosInstance} AxiosInstance + * @typedef {import('axios').AxiosProxyConfig} AxiosProxyConfig + * + * @returns {AxiosInstance} A configured Axios instance + * @throws {Error} If there's an issue creating the Axios instance or parsing the proxy URL + */ +function createAxiosInstance() { + const instance = axios.create(); + + if (process.env.proxy) { + try { + const url = new URL(process.env.proxy); + + /** @type {AxiosProxyConfig} */ + const proxyConfig = { + host: url.hostname.replace(/^\[|\]$/g, ''), + protocol: url.protocol.replace(':', ''), + }; + + if (url.port) { + proxyConfig.port = parseInt(url.port, 10); + } + + instance.defaults.proxy = proxyConfig; + } catch (error) { + console.error('Error parsing proxy URL:', error); + throw new Error(`Invalid proxy URL: ${process.env.proxy}`); + } + } + + return instance; +} + module.exports = { logger, sendEvent, getMCPManager, + createAxiosInstance, getFlowStateManager, }; diff --git a/api/config/index.spec.js b/api/config/index.spec.js new file mode 100644 index 0000000000..36ed8302f3 --- /dev/null +++ b/api/config/index.spec.js @@ -0,0 +1,126 @@ +const axios = require('axios'); +const { createAxiosInstance } = require('./index'); + +// Mock axios +jest.mock('axios', () => ({ + interceptors: { + request: { use: jest.fn(), eject: jest.fn() }, + response: { use: jest.fn(), eject: jest.fn() }, + }, + create: jest.fn().mockReturnValue({ + defaults: { + proxy: null, + }, + get: jest.fn().mockResolvedValue({ data: {} }), + post: jest.fn().mockResolvedValue({ data: {} }), + put: jest.fn().mockResolvedValue({ data: {} }), + delete: jest.fn().mockResolvedValue({ data: {} }), + }), + get: jest.fn().mockResolvedValue({ data: {} }), + post: jest.fn().mockResolvedValue({ data: {} }), + put: jest.fn().mockResolvedValue({ data: {} }), + delete: jest.fn().mockResolvedValue({ data: {} }), + reset: jest.fn().mockImplementation(function () { + this.get.mockClear(); + this.post.mockClear(); + 
this.put.mockClear(); + this.delete.mockClear(); + this.create.mockClear(); + }), +})); + +describe('createAxiosInstance', () => { + const originalEnv = process.env; + + beforeEach(() => { + // Reset mocks + jest.clearAllMocks(); + // Create a clean copy of process.env + process.env = { ...originalEnv }; + // Default: no proxy + delete process.env.proxy; + }); + + afterAll(() => { + // Restore original process.env + process.env = originalEnv; + }); + + test('creates an axios instance without proxy when no proxy env is set', () => { + const instance = createAxiosInstance(); + + expect(axios.create).toHaveBeenCalledTimes(1); + expect(instance.defaults.proxy).toBeNull(); + }); + + test('configures proxy correctly with hostname and protocol', () => { + process.env.proxy = 'http://example.com'; + + const instance = createAxiosInstance(); + + expect(axios.create).toHaveBeenCalledTimes(1); + expect(instance.defaults.proxy).toEqual({ + host: 'example.com', + protocol: 'http', + }); + }); + + test('configures proxy correctly with hostname, protocol and port', () => { + process.env.proxy = 'https://proxy.example.com:8080'; + + const instance = createAxiosInstance(); + + expect(axios.create).toHaveBeenCalledTimes(1); + expect(instance.defaults.proxy).toEqual({ + host: 'proxy.example.com', + protocol: 'https', + port: 8080, + }); + }); + + test('handles proxy URLs with authentication', () => { + process.env.proxy = 'http://user:pass@proxy.example.com:3128'; + + const instance = createAxiosInstance(); + + expect(axios.create).toHaveBeenCalledTimes(1); + expect(instance.defaults.proxy).toEqual({ + host: 'proxy.example.com', + protocol: 'http', + port: 3128, + // Note: The current implementation doesn't handle auth - if needed, add this functionality + }); + }); + + test('throws error when proxy URL is invalid', () => { + process.env.proxy = 'invalid-url'; + + expect(() => createAxiosInstance()).toThrow('Invalid proxy URL'); + expect(axios.create).toHaveBeenCalledTimes(1); + }); 
+ + // If you want to test the actual URL parsing more thoroughly + test('handles edge case proxy URLs correctly', () => { + // IPv6 address + process.env.proxy = 'http://[::1]:8080'; + + let instance = createAxiosInstance(); + + expect(instance.defaults.proxy).toEqual({ + host: '::1', + protocol: 'http', + port: 8080, + }); + + // URL with path (which should be ignored for proxy config) + process.env.proxy = 'http://proxy.example.com:8080/some/path'; + + instance = createAxiosInstance(); + + expect(instance.defaults.proxy).toEqual({ + host: 'proxy.example.com', + protocol: 'http', + port: 8080, + }); + }); +}); diff --git a/api/models/Banner.js b/api/models/Banner.js index 0f20faeba8..399a8e72ee 100644 --- a/api/models/Banner.js +++ b/api/models/Banner.js @@ -28,4 +28,4 @@ const getBanner = async (user) => { } }; -module.exports = { getBanner }; +module.exports = { Banner, getBanner }; diff --git a/api/models/Conversation.js b/api/models/Conversation.js index 9e51926ebc..dd6ef9bde1 100644 --- a/api/models/Conversation.js +++ b/api/models/Conversation.js @@ -15,19 +15,6 @@ const searchConversation = async (conversationId) => { throw new Error('Error searching conversation'); } }; -/** - * Searches for a conversation by conversationId and returns associated file ids. - * @param {string} conversationId - The conversation's ID. - * @returns {Promise} - */ -const getConvoFiles = async (conversationId) => { - try { - return (await Conversation.findOne({ conversationId }, 'files').lean())?.files ?? []; - } catch (error) { - logger.error('[getConvoFiles] Error getting conversation files', error); - throw new Error('Error getting conversation files'); - } -}; /** * Retrieves a single conversation for a given user and conversation ID. @@ -73,6 +60,20 @@ const deleteNullOrEmptyConversations = async () => { } }; +/** + * Searches for a conversation by conversationId and returns associated file ids. + * @param {string} conversationId - The conversation's ID. 
+ * @returns {Promise} + */ +const getConvoFiles = async (conversationId) => { + try { + return (await Conversation.findOne({ conversationId }, 'files').lean())?.files ?? []; + } catch (error) { + logger.error('[getConvoFiles] Error getting conversation files', error); + throw new Error('Error getting conversation files'); + } +}; + module.exports = { Conversation, getConvoFiles, diff --git a/api/models/File.js b/api/models/File.js index 870a18a7c8..0bde258a54 100644 --- a/api/models/File.js +++ b/api/models/File.js @@ -1,5 +1,6 @@ const mongoose = require('mongoose'); const { fileSchema } = require('@librechat/data-schemas'); +const { logger } = require('~/config'); const File = mongoose.model('File', fileSchema); @@ -17,11 +18,39 @@ const findFileById = async (file_id, options = {}) => { * Retrieves files matching a given filter, sorted by the most recently updated. * @param {Object} filter - The filter criteria to apply. * @param {Object} [_sortOptions] - Optional sort parameters. + * @param {Object|String} [selectFields={ text: 0 }] - Fields to include/exclude in the query results. + * Default excludes the 'text' field. * @returns {Promise>} A promise that resolves to an array of file documents. 
*/ -const getFiles = async (filter, _sortOptions) => { +const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => { const sortOptions = { updatedAt: -1, ..._sortOptions }; - return await File.find(filter).sort(sortOptions).lean(); + return await File.find(filter).select(selectFields).sort(sortOptions).lean(); +}; + +/** + * Retrieves tool files (files that are embedded or have a fileIdentifier) from an array of file IDs + * @param {string[]} fileIds - Array of file_id strings to search for + * @returns {Promise>} Files that match the criteria + */ +const getToolFilesByIds = async (fileIds) => { + if (!fileIds || !fileIds.length) { + return []; + } + + try { + const filter = { + file_id: { $in: fileIds }, + $or: [{ embedded: true }, { 'metadata.fileIdentifier': { $exists: true } }], + }; + + const selectFields = { text: 0 }; + const sortOptions = { updatedAt: -1 }; + + return await getFiles(filter, sortOptions, selectFields); + } catch (error) { + logger.error('[getToolFilesByIds] Error retrieving tool files:', error); + throw new Error('Error retrieving tool files'); + } }; /** @@ -109,6 +138,7 @@ module.exports = { File, findFileById, getFiles, + getToolFilesByIds, createFile, updateFile, updateFileUsage, diff --git a/api/models/Message.js b/api/models/Message.js index e651b20ad0..58068813ef 100644 --- a/api/models/Message.js +++ b/api/models/Message.js @@ -71,7 +71,42 @@ async function saveMessage(req, params, metadata) { } catch (err) { logger.error('Error saving message:', err); logger.info(`---\`saveMessage\` context: ${metadata?.context}`); - throw err; + + // Check if this is a duplicate key error (MongoDB error code 11000) + if (err.code === 11000 && err.message.includes('duplicate key error')) { + // Log the duplicate key error but don't crash the application + logger.warn(`Duplicate messageId detected: ${params.messageId}. 
Continuing execution.`); + + try { + // Try to find the existing message with this ID + const existingMessage = await Message.findOne({ + messageId: params.messageId, + user: req.user.id, + }); + + // If we found it, return it + if (existingMessage) { + return existingMessage.toObject(); + } + + // If we can't find it (unlikely but possible in race conditions) + return { + ...params, + messageId: params.messageId, + user: req.user.id, + }; + } catch (findError) { + // If the findOne also fails, log it but don't crash + logger.warn(`Could not retrieve existing message with ID ${params.messageId}: ${findError.message}`); + return { + ...params, + messageId: params.messageId, + user: req.user.id, + }; + } + } + + throw err; // Re-throw other errors } } diff --git a/api/models/tx.js b/api/models/tx.js index b534e7edc9..67301d0c49 100644 --- a/api/models/tx.js +++ b/api/models/tx.js @@ -61,6 +61,7 @@ const bedrockValues = { 'amazon.nova-micro-v1:0': { prompt: 0.035, completion: 0.14 }, 'amazon.nova-lite-v1:0': { prompt: 0.06, completion: 0.24 }, 'amazon.nova-pro-v1:0': { prompt: 0.8, completion: 3.2 }, + 'deepseek.r1': { prompt: 1.35, completion: 5.4 }, }; /** diff --git a/api/models/tx.spec.js b/api/models/tx.spec.js index b04eacc9f3..f612e222bb 100644 --- a/api/models/tx.spec.js +++ b/api/models/tx.spec.js @@ -288,7 +288,7 @@ describe('AWS Bedrock Model Tests', () => { }); describe('Deepseek Model Tests', () => { - const deepseekModels = ['deepseek-chat', 'deepseek-coder', 'deepseek-reasoner']; + const deepseekModels = ['deepseek-chat', 'deepseek-coder', 'deepseek-reasoner', 'deepseek.r1']; it('should return the correct prompt multipliers for all models', () => { const results = deepseekModels.map((model) => { diff --git a/api/package.json b/api/package.json index cfc9977aaf..36edce6baa 100644 --- a/api/package.json +++ b/api/package.json @@ -35,6 +35,8 @@ "homepage": "https://librechat.ai", "dependencies": { "@anthropic-ai/sdk": "^0.37.0", + "@aws-sdk/client-s3": 
"^3.758.0", + "@aws-sdk/s3-request-presigner": "^3.758.0", "@azure/search-documents": "^12.0.0", "@google/generative-ai": "^0.23.0", "@googleapis/youtube": "^20.0.0", @@ -42,10 +44,10 @@ "@keyv/redis": "^2.8.1", "@langchain/community": "^0.3.34", "@langchain/core": "^0.3.40", - "@langchain/google-genai": "^0.1.9", - "@langchain/google-vertexai": "^0.2.0", + "@langchain/google-genai": "^0.1.11", + "@langchain/google-vertexai": "^0.2.2", "@langchain/textsplitters": "^0.1.0", - "@librechat/agents": "^2.2.0", + "@librechat/agents": "^2.2.8", "@librechat/data-schemas": "*", "@waylaidwanderer/fetch-event-source": "^3.0.1", "axios": "^1.8.2", @@ -82,7 +84,7 @@ "memorystore": "^1.6.7", "mime": "^3.0.0", "module-alias": "^2.2.3", - "mongoose": "^8.9.5", + "mongoose": "^8.12.1", "multer": "^1.4.5-lts.1", "nanoid": "^3.3.7", "nodemailer": "^6.9.15", diff --git a/api/server/controllers/agents/callbacks.js b/api/server/controllers/agents/callbacks.js index 45beefe7e6..6622ec3815 100644 --- a/api/server/controllers/agents/callbacks.js +++ b/api/server/controllers/agents/callbacks.js @@ -10,8 +10,8 @@ const { ChatModelStreamHandler, } = require('@librechat/agents'); const { processCodeOutput } = require('~/server/services/Files/Code/process'); +const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { saveBase64Image } = require('~/server/services/Files/process'); -const { loadAuthValues } = require('~/app/clients/tools/util'); const { logger, sendEvent } = require('~/config'); /** @typedef {import('@librechat/agents').Graph} Graph */ diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 628b62e5ea..4b995bb06a 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -7,7 +7,16 @@ // validateVisionModel, // mapModelToAzureConfig, // } = require('librechat-data-provider'); -const { Callback, createMetadataAggregator } = require('@librechat/agents'); 
+require('events').EventEmitter.defaultMaxListeners = 100; +const { + Callback, + GraphEvents, + formatMessage, + formatAgentMessages, + formatContentStrings, + getTokenCountForMessage, + createMetadataAggregator, +} = require('@librechat/agents'); const { Constants, VisionModes, @@ -17,24 +26,19 @@ const { KnownEndpoints, anthropicSchema, isAgentsEndpoint, + AgentCapabilities, bedrockInputSchema, removeNullishValues, } = require('librechat-data-provider'); -const { - formatMessage, - addCacheControl, - formatAgentMessages, - formatContentStrings, - createContextHandlers, -} = require('~/app/clients/prompts'); +const { getCustomEndpointConfig, checkCapability } = require('~/server/services/Config'); +const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { getBufferString, HumanMessage } = require('@langchain/core/messages'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); -const { getCustomEndpointConfig } = require('~/server/services/Config'); const Tokenizer = require('~/server/services/Tokenizer'); const BaseClient = require('~/app/clients/BaseClient'); +const { logger, sendEvent } = require('~/config'); const { createRun } = require('./run'); -const { logger } = require('~/config'); /** @typedef {import('@librechat/agents').MessageContentComplex} MessageContentComplex */ /** @typedef {import('@langchain/core/runnables').RunnableConfig} RunnableConfig */ @@ -99,6 +103,8 @@ class AgentClient extends BaseClient { this.outputTokensKey = 'output_tokens'; /** @type {UsageMetadata} */ this.usage; + /** @type {Record} */ + this.indexTokenCountMap = {}; } /** @@ -223,14 +229,23 @@ class AgentClient extends BaseClient { }; } + /** + * + * @param {TMessage} message + * @param {Array} attachments + * @returns {Promise>>} + */ async addImageURLs(message, attachments) { - const { files, image_urls } = await encodeAndFormat( + 
const { files, text, image_urls } = await encodeAndFormat( this.options.req, attachments, this.options.agent.provider, VisionModes.agents, ); message.image_urls = image_urls.length ? image_urls : undefined; + if (text && text.length) { + message.ocr = text; + } return files; } @@ -308,7 +323,21 @@ class AgentClient extends BaseClient { assistantName: this.options?.modelLabel, }); - const needsTokenCount = this.contextStrategy && !orderedMessages[i].tokenCount; + if (message.ocr && i !== orderedMessages.length - 1) { + if (typeof formattedMessage.content === 'string') { + formattedMessage.content = message.ocr + '\n' + formattedMessage.content; + } else { + const textPart = formattedMessage.content.find((part) => part.type === 'text'); + textPart + ? (textPart.text = message.ocr + '\n' + textPart.text) + : formattedMessage.content.unshift({ type: 'text', text: message.ocr }); + } + } else if (message.ocr && i === orderedMessages.length - 1) { + systemContent = [systemContent, message.ocr].join('\n'); + } + + const needsTokenCount = + (this.contextStrategy && !orderedMessages[i].tokenCount) || message.ocr; /* If tokens were never counted, or, is a Vision request and the message has files, count again */ if (needsTokenCount || (this.isVisionModel && (message.image_urls || message.files))) { @@ -354,6 +383,10 @@ class AgentClient extends BaseClient { })); } + for (let i = 0; i < messages.length; i++) { + this.indexTokenCountMap[i] = messages[i].tokenCount; + } + const result = { tokenCountMap, prompt: payload, @@ -599,6 +632,9 @@ class AgentClient extends BaseClient { // }); // } + /** @type {TCustomConfig['endpoints']['agents']} */ + const agentsEConfig = this.options.req.app.locals[EModelEndpoint.agents]; + /** @type {Partial & { version: 'v1' | 'v2'; run_id?: string; streamMode: string }} */ const config = { configurable: { @@ -606,19 +642,30 @@ class AgentClient extends BaseClient { last_agent_index: this.agentConfigs?.size ?? 
0, hide_sequential_outputs: this.options.agent.hide_sequential_outputs, }, - recursionLimit: this.options.req.app.locals[EModelEndpoint.agents]?.recursionLimit, + recursionLimit: agentsEConfig?.recursionLimit, signal: abortController.signal, streamMode: 'values', version: 'v2', }; - const initialMessages = formatAgentMessages(payload); + const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name)); + let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages( + payload, + this.indexTokenCountMap, + toolSet, + ); if (legacyContentEndpoints.has(this.options.agent.endpoint)) { - formatContentStrings(initialMessages); + initialMessages = formatContentStrings(initialMessages); } /** @type {ReturnType} */ let run; + const countTokens = ((text) => this.getTokenCount(text)).bind(this); + + /** @type {(message: BaseMessage) => number} */ + const tokenCounter = (message) => { + return getTokenCountForMessage(message, countTokens); + }; /** * @@ -626,12 +673,23 @@ class AgentClient extends BaseClient { * @param {BaseMessage[]} messages * @param {number} [i] * @param {TMessageContentParts[]} [contentData] + * @param {Record} [currentIndexCountMap] */ - const runAgent = async (agent, _messages, i = 0, contentData = []) => { + const runAgent = async (agent, _messages, i = 0, contentData = [], _currentIndexCountMap) => { config.configurable.model = agent.model_parameters.model; + const currentIndexCountMap = _currentIndexCountMap ?? 
indexTokenCountMap; if (i > 0) { this.model = agent.model_parameters.model; } + if (agent.recursion_limit && typeof agent.recursion_limit === 'number') { + config.recursionLimit = agent.recursion_limit; + } + if ( + agentsEConfig?.maxRecursionLimit && + config.recursionLimit > agentsEConfig?.maxRecursionLimit + ) { + config.recursionLimit = agentsEConfig?.maxRecursionLimit; + } config.configurable.agent_id = agent.id; config.configurable.name = agent.name; config.configurable.agent_index = i; @@ -694,11 +752,29 @@ class AgentClient extends BaseClient { } if (contentData.length) { + const agentUpdate = { + type: ContentTypes.AGENT_UPDATE, + [ContentTypes.AGENT_UPDATE]: { + index: contentData.length, + runId: this.responseMessageId, + agentId: agent.id, + }, + }; + const streamData = { + event: GraphEvents.ON_AGENT_UPDATE, + data: agentUpdate, + }; + this.options.aggregateContent(streamData); + sendEvent(this.options.res, streamData); + contentData.push(agentUpdate); run.Graph.contentData = contentData; } await run.processStream({ messages }, config, { keepContent: i !== 0, + tokenCounter, + indexTokenCountMap: currentIndexCountMap, + maxContextTokens: agent.maxContextTokens, callbacks: { [Callback.TOOL_ERROR]: (graph, error, toolId) => { logger.error( @@ -712,9 +788,13 @@ class AgentClient extends BaseClient { }; await runAgent(this.options.agent, initialMessages); - let finalContentStart = 0; - if (this.agentConfigs && this.agentConfigs.size > 0) { + if ( + this.agentConfigs && + this.agentConfigs.size > 0 && + (await checkCapability(this.options.req, AgentCapabilities.chain)) + ) { + const windowSize = 5; let latestMessage = initialMessages.pop().content; if (typeof latestMessage !== 'string') { latestMessage = latestMessage[0].text; @@ -722,7 +802,16 @@ class AgentClient extends BaseClient { let i = 1; let runMessages = []; - const lastFiveMessages = initialMessages.slice(-5); + const windowIndexCountMap = {}; + const windowMessages = 
initialMessages.slice(-windowSize); + let currentIndex = 4; + for (let i = initialMessages.length - 1; i >= 0; i--) { + windowIndexCountMap[currentIndex] = indexTokenCountMap[i]; + currentIndex--; + if (currentIndex < 0) { + break; + } + } for (const [agentId, agent] of this.agentConfigs) { if (abortController.signal.aborted === true) { break; @@ -757,7 +846,9 @@ class AgentClient extends BaseClient { } try { const contextMessages = []; - for (const message of lastFiveMessages) { + const runIndexCountMap = {}; + for (let i = 0; i < windowMessages.length; i++) { + const message = windowMessages[i]; const messageType = message._getType(); if ( (!agent.tools || agent.tools.length === 0) && @@ -765,11 +856,13 @@ class AgentClient extends BaseClient { ) { continue; } - + runIndexCountMap[contextMessages.length] = windowIndexCountMap[i]; contextMessages.push(message); } - const currentMessages = [...contextMessages, new HumanMessage(bufferString)]; - await runAgent(agent, currentMessages, i, contentData); + const bufferMessage = new HumanMessage(bufferString); + runIndexCountMap[contextMessages.length] = tokenCounter(bufferMessage); + const currentMessages = [...contextMessages, bufferMessage]; + await runAgent(agent, currentMessages, i, contentData, runIndexCountMap); } catch (err) { logger.error( `[api/server/controllers/agents/client.js #chatCompletion] Error running agent ${agentId} (${i})`, @@ -780,6 +873,7 @@ class AgentClient extends BaseClient { } } + /** Note: not implemented */ if (config.configurable.hide_sequential_outputs !== true) { finalContentStart = 0; } diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js index 08327ec61c..731dee69a2 100644 --- a/api/server/controllers/agents/v1.js +++ b/api/server/controllers/agents/v1.js @@ -1,10 +1,11 @@ const fs = require('fs').promises; const { nanoid } = require('nanoid'); const { - FileContext, - Constants, Tools, + Constants, + FileContext, SystemRoles, + EToolResources, 
actionDelimiter, } = require('librechat-data-provider'); const { @@ -203,14 +204,21 @@ const duplicateAgentHandler = async (req, res) => { } const { - _id: __id, id: _id, + _id: __id, author: _author, createdAt: _createdAt, updatedAt: _updatedAt, + tool_resources: _tool_resources = {}, ...cloneData } = agent; + if (_tool_resources?.[EToolResources.ocr]) { + cloneData.tool_resources = { + [EToolResources.ocr]: _tool_resources[EToolResources.ocr], + }; + } + const newAgentId = `agent_${nanoid()}`; const newAgentData = Object.assign(cloneData, { id: newAgentId, diff --git a/api/server/controllers/tools.js b/api/server/controllers/tools.js index 1c5330af35..b37b6fcb8c 100644 --- a/api/server/controllers/tools.js +++ b/api/server/controllers/tools.js @@ -10,7 +10,8 @@ const { const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process'); const { processCodeOutput } = require('~/server/services/Files/Code/process'); const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall'); -const { loadAuthValues, loadTools } = require('~/app/clients/tools/util'); +const { loadAuthValues } = require('~/server/services/Tools/credentials'); +const { loadTools } = require('~/app/clients/tools/util'); const { checkAccess } = require('~/server/middleware'); const { getMessage } = require('~/models/Message'); const { logger } = require('~/config'); diff --git a/api/server/middleware/buildEndpointOption.js b/api/server/middleware/buildEndpointOption.js index a0ce754a1c..041864b025 100644 --- a/api/server/middleware/buildEndpointOption.js +++ b/api/server/middleware/buildEndpointOption.js @@ -10,7 +10,6 @@ const openAI = require('~/server/services/Endpoints/openAI'); const agents = require('~/server/services/Endpoints/agents'); const custom = require('~/server/services/Endpoints/custom'); const google = require('~/server/services/Endpoints/google'); -const { getConvoFiles } = require('~/models/Conversation'); const { handleError } = 
require('~/server/utils'); const buildFunction = { @@ -87,16 +86,8 @@ async function buildEndpointOption(req, res, next) { // TODO: use `getModelsConfig` only when necessary const modelsConfig = await getModelsConfig(req); - const { resendFiles = true } = req.body.endpointOption; req.body.endpointOption.modelsConfig = modelsConfig; - if (isAgents && resendFiles && req.body.conversationId) { - const fileIds = await getConvoFiles(req.body.conversationId); - const requestFiles = req.body.files ?? []; - if (requestFiles.length || fileIds.length) { - req.body.endpointOption.attachments = processFiles(requestFiles, fileIds); - } - } else if (req.body.files) { - // hold the promise + if (req.body.files && !isAgents) { req.body.endpointOption.attachments = processFiles(req.body.files); } next(); diff --git a/api/server/routes/files/files.js b/api/server/routes/files/files.js index c320f7705b..c371b8e28e 100644 --- a/api/server/routes/files/files.js +++ b/api/server/routes/files/files.js @@ -16,7 +16,7 @@ const { } = require('~/server/services/Files/process'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { getOpenAIClient } = require('~/server/controllers/assistants/helpers'); -const { loadAuthValues } = require('~/app/clients/tools/util'); +const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { getAgent } = require('~/models/Agent'); const { getFiles } = require('~/models/File'); const { logger } = require('~/config'); diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js index 660e7aeb0d..c332cdfcf1 100644 --- a/api/server/services/ActionService.js +++ b/api/server/services/ActionService.js @@ -161,9 +161,9 @@ async function createActionTool({ if (metadata.auth && metadata.auth.type !== AuthTypeEnum.None) { try { - const action_id = action.action_id; - const identifier = `${req.user.id}:${action.action_id}`; if (metadata.auth.type === AuthTypeEnum.OAuth && 
metadata.auth.authorization_url) { + const action_id = action.action_id; + const identifier = `${req.user.id}:${action.action_id}`; const requestLogin = async () => { const { args: _args, stepId, ...toolCall } = config.toolCall ?? {}; if (!stepId) { diff --git a/api/server/services/AppService.js b/api/server/services/AppService.js index d194d31a6b..925ffe93de 100644 --- a/api/server/services/AppService.js +++ b/api/server/services/AppService.js @@ -1,7 +1,14 @@ -const { FileSources, EModelEndpoint, getConfigDefaults } = require('librechat-data-provider'); +const { + FileSources, + EModelEndpoint, + loadOCRConfig, + processMCPEnv, + getConfigDefaults, +} = require('librechat-data-provider'); const { checkVariables, checkHealth, checkConfig, checkAzureVariables } = require('./start/checks'); const { azureAssistantsDefaults, assistantsConfigSetup } = require('./start/assistants'); const { initializeFirebase } = require('./Files/Firebase/initialize'); +const { initializeS3 } = require('./Files/S3/initialize'); const loadCustomConfig = require('./Config/loadCustomConfig'); const handleRateLimits = require('./Config/handleRateLimits'); const { loadDefaultInterface } = require('./start/interface'); @@ -25,6 +32,7 @@ const AppService = async (app) => { const config = (await loadCustomConfig()) ?? {}; const configDefaults = getConfigDefaults(); + const ocr = loadOCRConfig(config.ocr); const filteredTools = config.filteredTools; const includedTools = config.includedTools; const fileStrategy = config.fileStrategy ?? 
configDefaults.fileStrategy; @@ -37,6 +45,8 @@ const AppService = async (app) => { if (fileStrategy === FileSources.firebase) { initializeFirebase(); + } else if (fileStrategy === FileSources.s3) { + initializeS3(); } /** @type {Record { if (config.mcpServers != null) { const mcpManager = await getMCPManager(); - await mcpManager.initializeMCP(config.mcpServers); + await mcpManager.initializeMCP(config.mcpServers, processMCPEnv); await mcpManager.mapAvailableTools(availableTools); } @@ -57,6 +67,7 @@ const AppService = async (app) => { const interfaceConfig = await loadDefaultInterface(config, configDefaults); const defaultLocals = { + ocr, paths, fileStrategy, socialLogins, diff --git a/api/server/services/AppService.spec.js b/api/server/services/AppService.spec.js index 61ac80fc6c..e47bfe7d5d 100644 --- a/api/server/services/AppService.spec.js +++ b/api/server/services/AppService.spec.js @@ -120,6 +120,7 @@ describe('AppService', () => { }, }, paths: expect.anything(), + ocr: expect.anything(), imageOutputType: expect.any(String), fileConfig: undefined, secureImageLinks: undefined, @@ -588,4 +589,33 @@ describe('AppService updating app.locals and issuing warnings', () => { ); }); }); + + it('should not parse environment variable references in OCR config', async () => { + // Mock custom configuration with env variable references in OCR config + const mockConfig = { + ocr: { + apiKey: '${OCR_API_KEY_CUSTOM_VAR_NAME}', + baseURL: '${OCR_BASEURL_CUSTOM_VAR_NAME}', + strategy: 'mistral_ocr', + mistralModel: 'mistral-medium', + }, + }; + + require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve(mockConfig)); + + // Set actual environment variables with different values + process.env.OCR_API_KEY_CUSTOM_VAR_NAME = 'actual-api-key'; + process.env.OCR_BASEURL_CUSTOM_VAR_NAME = 'https://actual-ocr-url.com'; + + // Initialize app + const app = { locals: {} }; + await AppService(app); + + // Verify that the raw string references were preserved and 
not interpolated + expect(app.locals.ocr).toBeDefined(); + expect(app.locals.ocr.apiKey).toEqual('${OCR_API_KEY_CUSTOM_VAR_NAME}'); + expect(app.locals.ocr.baseURL).toEqual('${OCR_BASEURL_CUSTOM_VAR_NAME}'); + expect(app.locals.ocr.strategy).toEqual('mistral_ocr'); + expect(app.locals.ocr.mistralModel).toEqual('mistral-medium'); + }); }); diff --git a/api/server/services/Config/getEndpointsConfig.js b/api/server/services/Config/getEndpointsConfig.js index 4f8bde68ad..016f5f7445 100644 --- a/api/server/services/Config/getEndpointsConfig.js +++ b/api/server/services/Config/getEndpointsConfig.js @@ -72,4 +72,15 @@ async function getEndpointsConfig(req) { return endpointsConfig; } -module.exports = { getEndpointsConfig }; +/** + * @param {ServerRequest} req + * @param {import('librechat-data-provider').AgentCapabilities} capability + * @returns {Promise} + */ +const checkCapability = async (req, capability) => { + const endpointsConfig = await getEndpointsConfig(req); + const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? 
[]; + return capabilities.includes(capability); +}; + +module.exports = { getEndpointsConfig, checkCapability }; diff --git a/api/server/services/Endpoints/agents/build.js b/api/server/services/Endpoints/agents/build.js index 027937e7fd..999cdc16be 100644 --- a/api/server/services/Endpoints/agents/build.js +++ b/api/server/services/Endpoints/agents/build.js @@ -2,15 +2,8 @@ const { loadAgent } = require('~/models/Agent'); const { logger } = require('~/config'); const buildOptions = (req, endpoint, parsedBody) => { - const { - spec, - iconURL, - agent_id, - instructions, - maxContextTokens, - resendFiles = true, - ...model_parameters - } = parsedBody; + const { spec, iconURL, agent_id, instructions, maxContextTokens, ...model_parameters } = + parsedBody; const agentPromise = loadAgent({ req, agent_id, @@ -24,7 +17,6 @@ const buildOptions = (req, endpoint, parsedBody) => { iconURL, endpoint, agent_id, - resendFiles, instructions, maxContextTokens, model_parameters, diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index 1cf8ad7a67..737165e316 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -2,6 +2,7 @@ const { createContentAggregator, Providers } = require('@librechat/agents'); const { EModelEndpoint, getResponseSender, + AgentCapabilities, providerEndpointMap, } = require('librechat-data-provider'); const { @@ -15,10 +16,14 @@ const initCustom = require('~/server/services/Endpoints/custom/initialize'); const initGoogle = require('~/server/services/Endpoints/google/initialize'); const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); const { getCustomEndpointConfig } = require('~/server/services/Config'); +const { processFiles } = require('~/server/services/Files/process'); const { loadAgentTools } = require('~/server/services/ToolService'); const AgentClient = require('~/server/controllers/agents/client'); 
+const { getConvoFiles } = require('~/models/Conversation'); +const { getToolFilesByIds } = require('~/models/File'); const { getModelMaxTokens } = require('~/utils'); const { getAgent } = require('~/models/Agent'); +const { getFiles } = require('~/models/File'); const { logger } = require('~/config'); const providerConfigMap = { @@ -34,20 +39,38 @@ const providerConfigMap = { }; /** - * + * @param {ServerRequest} req * @param {Promise> | undefined} _attachments * @param {AgentToolResources | undefined} _tool_resources * @returns {Promise<{ attachments: Array | undefined, tool_resources: AgentToolResources | undefined }>} */ -const primeResources = async (_attachments, _tool_resources) => { +const primeResources = async (req, _attachments, _tool_resources) => { try { + /** @type {Array | undefined} */ + let attachments; + const tool_resources = _tool_resources ?? {}; + const isOCREnabled = (req.app.locals?.[EModelEndpoint.agents]?.capabilities ?? []).includes( + AgentCapabilities.ocr, + ); + if (tool_resources.ocr?.file_ids && isOCREnabled) { + const context = await getFiles( + { + file_id: { $in: tool_resources.ocr.file_ids }, + }, + {}, + {}, + ); + attachments = (attachments ?? []).concat(context); + } if (!_attachments) { - return { attachments: undefined, tool_resources: _tool_resources }; + return { attachments, tool_resources }; } /** @type {Array | undefined} */ const files = await _attachments; - const attachments = []; - const tool_resources = _tool_resources ?? 
{}; + if (!attachments) { + /** @type {Array} */ + attachments = []; + } for (const file of files) { if (!file) { @@ -82,7 +105,6 @@ const primeResources = async (_attachments, _tool_resources) => { * @param {ServerResponse} params.res * @param {Agent} params.agent * @param {object} [params.endpointOption] - * @param {AgentToolResources} [params.tool_resources] * @param {boolean} [params.isInitialAgent] * @returns {Promise} */ @@ -91,9 +113,30 @@ const initializeAgentOptions = async ({ res, agent, endpointOption, - tool_resources, isInitialAgent = false, }) => { + let currentFiles; + /** @type {Array} */ + const requestFiles = req.body.files ?? []; + if ( + isInitialAgent && + req.body.conversationId != null && + (agent.model_parameters?.resendFiles ?? true) === true + ) { + const fileIds = (await getConvoFiles(req.body.conversationId)) ?? []; + const toolFiles = await getToolFilesByIds(fileIds); + if (requestFiles.length || toolFiles.length) { + currentFiles = await processFiles(requestFiles.concat(toolFiles)); + } + } else if (isInitialAgent && requestFiles.length) { + currentFiles = await processFiles(requestFiles); + } + + const { attachments, tool_resources } = await primeResources( + req, + currentFiles, + agent.tool_resources, + ); const { tools, toolContextMap } = await loadAgentTools({ req, res, @@ -138,6 +181,7 @@ const initializeAgentOptions = async ({ agent.provider = options.provider; } + /** @type {import('@librechat/agents').ClientOptions} */ agent.model_parameters = Object.assign(model_parameters, options.llmConfig); if (options.configOptions) { agent.model_parameters.configuration = options.configOptions; @@ -156,15 +200,16 @@ const initializeAgentOptions = async ({ const tokensModel = agent.provider === EModelEndpoint.azureOpenAI ? agent.model : agent.model_parameters.model; + const maxTokens = agent.model_parameters.maxOutputTokens ?? agent.model_parameters.maxTokens ?? 
0; return { ...agent, tools, + attachments, toolContextMap, maxContextTokens: agent.max_context_tokens ?? - getModelMaxTokens(tokensModel, providerEndpointMap[provider]) ?? - 4000, + ((getModelMaxTokens(tokensModel, providerEndpointMap[provider]) ?? 4000) - maxTokens) * 0.9, }; }; @@ -197,11 +242,6 @@ const initializeClient = async ({ req, res, endpointOption }) => { throw new Error('Agent not found'); } - const { attachments, tool_resources } = await primeResources( - endpointOption.attachments, - primaryAgent.tool_resources, - ); - const agentConfigs = new Map(); // Handle primary agent @@ -210,7 +250,6 @@ const initializeClient = async ({ req, res, endpointOption }) => { res, agent: primaryAgent, endpointOption, - tool_resources, isInitialAgent: true, }); @@ -240,18 +279,21 @@ const initializeClient = async ({ req, res, endpointOption }) => { const client = new AgentClient({ req, - agent: primaryConfig, + res, sender, - attachments, contentParts, + agentConfigs, eventHandlers, collectedUsage, + aggregateContent, artifactPromises, + agent: primaryConfig, spec: endpointOption.spec, iconURL: endpointOption.iconURL, - agentConfigs, endpoint: EModelEndpoint.agents, + attachments: primaryConfig.attachments, maxContextTokens: primaryConfig.maxContextTokens, + resendFiles: primaryConfig.model_parameters?.resendFiles ?? 
true, }); return { client }; diff --git a/api/server/services/Endpoints/bedrock/initialize.js b/api/server/services/Endpoints/bedrock/initialize.js index 3ffa03393d..4d9ba361cf 100644 --- a/api/server/services/Endpoints/bedrock/initialize.js +++ b/api/server/services/Endpoints/bedrock/initialize.js @@ -23,8 +23,9 @@ const initializeClient = async ({ req, res, endpointOption }) => { const agent = { id: EModelEndpoint.bedrock, name: endpointOption.name, - instructions: endpointOption.promptPrefix, provider: EModelEndpoint.bedrock, + endpoint: EModelEndpoint.bedrock, + instructions: endpointOption.promptPrefix, model: endpointOption.model_parameters.model, model_parameters: endpointOption.model_parameters, }; @@ -54,6 +55,7 @@ const initializeClient = async ({ req, res, endpointOption }) => { const client = new AgentClient({ req, + res, agent, sender, // tools, diff --git a/api/server/services/Endpoints/openAI/initialize.js b/api/server/services/Endpoints/openAI/initialize.js index 5614804b68..4d358cef1a 100644 --- a/api/server/services/Endpoints/openAI/initialize.js +++ b/api/server/services/Endpoints/openAI/initialize.js @@ -135,12 +135,9 @@ const initializeClient = async ({ } if (optionsOnly) { - clientOptions = Object.assign( - { - modelOptions: endpointOption.model_parameters, - }, - clientOptions, - ); + const modelOptions = endpointOption.model_parameters; + modelOptions.model = modelName; + clientOptions = Object.assign({ modelOptions }, clientOptions); clientOptions.modelOptions.user = req.user.id; const options = getLLMConfig(apiKey, clientOptions); if (!clientOptions.streamRate) { diff --git a/api/server/services/Endpoints/openAI/llm.js b/api/server/services/Endpoints/openAI/llm.js index af19ece486..a8aeeb5b9d 100644 --- a/api/server/services/Endpoints/openAI/llm.js +++ b/api/server/services/Endpoints/openAI/llm.js @@ -28,7 +28,7 @@ const { isEnabled } = require('~/server/utils'); * @returns {Object} Configuration options for creating an LLM instance. 
*/ function getLLMConfig(apiKey, options = {}, endpoint = null) { - const { + let { modelOptions = {}, reverseProxyUrl, defaultQuery, @@ -50,10 +50,32 @@ function getLLMConfig(apiKey, options = {}, endpoint = null) { if (addParams && typeof addParams === 'object') { Object.assign(llmConfig, addParams); } + /** Note: OpenAI Web Search models do not support any known parameters besides `max_tokens` */ + if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) { + const searchExcludeParams = [ + 'frequency_penalty', + 'presence_penalty', + 'temperature', + 'top_p', + 'top_k', + 'stop', + 'logit_bias', + 'seed', + 'response_format', + 'n', + 'logprobs', + 'user', + ]; + + dropParams = dropParams || []; + dropParams = [...new Set([...dropParams, ...searchExcludeParams])]; + } if (dropParams && Array.isArray(dropParams)) { dropParams.forEach((param) => { - delete llmConfig[param]; + if (llmConfig[param]) { + llmConfig[param] = undefined; + } }); } diff --git a/api/server/services/Files/Code/crud.js b/api/server/services/Files/Code/crud.js index a467f6a29a..1360cccadb 100644 --- a/api/server/services/Files/Code/crud.js +++ b/api/server/services/Files/Code/crud.js @@ -1,8 +1,10 @@ -const axios = require('axios'); const FormData = require('form-data'); const { getCodeBaseURL } = require('@librechat/agents'); +const { createAxiosInstance } = require('~/config'); const { logAxiosError } = require('~/utils'); +const axios = createAxiosInstance(); + const MAX_FILE_SIZE = 150 * 1024 * 1024; /** @@ -27,13 +29,6 @@ async function getCodeOutputDownloadStream(fileIdentifier, apiKey) { timeout: 15000, }; - if (process.env.PROXY) { - options.proxy = { - host: process.env.PROXY, - protocol: process.env.PROXY.startsWith('https') ? 
'https' : 'http', - }; - } - const response = await axios(options); return response; } catch (error) { @@ -79,13 +74,6 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = '' maxBodyLength: MAX_FILE_SIZE, }; - if (process.env.PROXY) { - options.proxy = { - host: process.env.PROXY, - protocol: process.env.PROXY.startsWith('https') ? 'https' : 'http', - }; - } - const response = await axios.post(`${baseURL}/upload`, form, options); /** @type {{ message: string; session_id: string; files: Array<{ fileId: string; filename: string }> }} */ diff --git a/api/server/services/Files/MistralOCR/crud.js b/api/server/services/Files/MistralOCR/crud.js new file mode 100644 index 0000000000..cef8297519 --- /dev/null +++ b/api/server/services/Files/MistralOCR/crud.js @@ -0,0 +1,207 @@ +// ~/server/services/Files/MistralOCR/crud.js +const fs = require('fs'); +const path = require('path'); +const FormData = require('form-data'); +const { FileSources, envVarRegex, extractEnvVariable } = require('librechat-data-provider'); +const { loadAuthValues } = require('~/server/services/Tools/credentials'); +const { logger, createAxiosInstance } = require('~/config'); +const { logAxiosError } = require('~/utils'); + +const axios = createAxiosInstance(); + +/** + * Uploads a document to Mistral API using file streaming to avoid loading the entire file into memory + * + * @param {Object} params Upload parameters + * @param {string} params.filePath The path to the file on disk + * @param {string} [params.fileName] Optional filename to use (defaults to the name from filePath) + * @param {string} params.apiKey Mistral API key + * @param {string} [params.baseURL=https://api.mistral.ai/v1] Mistral API base URL + * @returns {Promise} The response from Mistral API + */ +async function uploadDocumentToMistral({ + filePath, + fileName = '', + apiKey, + baseURL = 'https://api.mistral.ai/v1', +}) { + const form = new FormData(); + form.append('purpose', 'ocr'); + const 
actualFileName = fileName || path.basename(filePath); + const fileStream = fs.createReadStream(filePath); + form.append('file', fileStream, { filename: actualFileName }); + + return axios + .post(`${baseURL}/files`, form, { + headers: { + Authorization: `Bearer ${apiKey}`, + ...form.getHeaders(), + }, + maxBodyLength: Infinity, + maxContentLength: Infinity, + }) + .then((res) => res.data) + .catch((error) => { + logger.error('Error uploading document to Mistral:', error.message); + throw error; + }); +} + +async function getSignedUrl({ + apiKey, + fileId, + expiry = 24, + baseURL = 'https://api.mistral.ai/v1', +}) { + return axios + .get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, { + headers: { + Authorization: `Bearer ${apiKey}`, + }, + }) + .then((res) => res.data) + .catch((error) => { + logger.error('Error fetching signed URL:', error.message); + throw error; + }); +} + +/** + * @param {Object} params + * @param {string} params.apiKey + * @param {string} params.documentUrl + * @param {string} [params.baseURL] + * @returns {Promise} + */ +async function performOCR({ + apiKey, + documentUrl, + model = 'mistral-ocr-latest', + baseURL = 'https://api.mistral.ai/v1', +}) { + return axios + .post( + `${baseURL}/ocr`, + { + model, + include_image_base64: false, + document: { + type: 'document_url', + document_url: documentUrl, + }, + }, + { + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${apiKey}`, + }, + }, + ) + .then((res) => res.data) + .catch((error) => { + logger.error('Error performing OCR:', error.message); + throw error; + }); +} + +function extractVariableName(str) { + const match = str.match(envVarRegex); + return match ? 
match[1] : null; +} + +const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => { + try { + /** @type {TCustomConfig['ocr']} */ + const ocrConfig = req.app.locals?.ocr; + + const apiKeyConfig = ocrConfig.apiKey || ''; + const baseURLConfig = ocrConfig.baseURL || ''; + + const isApiKeyEnvVar = envVarRegex.test(apiKeyConfig); + const isBaseURLEnvVar = envVarRegex.test(baseURLConfig); + + const isApiKeyEmpty = !apiKeyConfig.trim(); + const isBaseURLEmpty = !baseURLConfig.trim(); + + let apiKey, baseURL; + + if (isApiKeyEnvVar || isBaseURLEnvVar || isApiKeyEmpty || isBaseURLEmpty) { + const apiKeyVarName = isApiKeyEnvVar ? extractVariableName(apiKeyConfig) : 'OCR_API_KEY'; + const baseURLVarName = isBaseURLEnvVar ? extractVariableName(baseURLConfig) : 'OCR_BASEURL'; + + const authValues = await loadAuthValues({ + userId: req.user.id, + authFields: [baseURLVarName, apiKeyVarName], + optional: new Set([baseURLVarName]), + }); + + apiKey = authValues[apiKeyVarName]; + baseURL = authValues[baseURLVarName]; + } else { + apiKey = apiKeyConfig; + baseURL = baseURLConfig; + } + + const mistralFile = await uploadDocumentToMistral({ + filePath: file.path, + fileName: file.originalname, + apiKey, + baseURL, + }); + + const modelConfig = ocrConfig.mistralModel || ''; + const model = envVarRegex.test(modelConfig) + ? 
extractEnvVariable(modelConfig) + : modelConfig.trim() || 'mistral-ocr-latest'; + + const signedUrlResponse = await getSignedUrl({ + apiKey, + baseURL, + fileId: mistralFile.id, + }); + + const ocrResult = await performOCR({ + apiKey, + baseURL, + model, + documentUrl: signedUrlResponse.url, + }); + + let aggregatedText = ''; + const images = []; + ocrResult.pages.forEach((page, index) => { + if (ocrResult.pages.length > 1) { + aggregatedText += `# PAGE ${index + 1}\n`; + } + + aggregatedText += page.markdown + '\n\n'; + + if (page.images && page.images.length > 0) { + page.images.forEach((image) => { + if (image.image_base64) { + images.push(image.image_base64); + } + }); + } + }); + + return { + filename: file.originalname, + bytes: aggregatedText.length * 4, + filepath: FileSources.mistral_ocr, + text: aggregatedText, + images, + }; + } catch (error) { + const message = 'Error uploading document to Mistral OCR API'; + logAxiosError({ error, message }); + throw new Error(message); + } +}; + +module.exports = { + uploadDocumentToMistral, + uploadMistralOCR, + getSignedUrl, + performOCR, +}; diff --git a/api/server/services/Files/MistralOCR/crud.spec.js b/api/server/services/Files/MistralOCR/crud.spec.js new file mode 100644 index 0000000000..80ac6f73a4 --- /dev/null +++ b/api/server/services/Files/MistralOCR/crud.spec.js @@ -0,0 +1,737 @@ +const fs = require('fs'); + +const mockAxios = { + interceptors: { + request: { use: jest.fn(), eject: jest.fn() }, + response: { use: jest.fn(), eject: jest.fn() }, + }, + create: jest.fn().mockReturnValue({ + defaults: { + proxy: null, + }, + get: jest.fn().mockResolvedValue({ data: {} }), + post: jest.fn().mockResolvedValue({ data: {} }), + put: jest.fn().mockResolvedValue({ data: {} }), + delete: jest.fn().mockResolvedValue({ data: {} }), + }), + get: jest.fn().mockResolvedValue({ data: {} }), + post: jest.fn().mockResolvedValue({ data: {} }), + put: jest.fn().mockResolvedValue({ data: {} }), + delete: 
jest.fn().mockResolvedValue({ data: {} }), + reset: jest.fn().mockImplementation(function () { + this.get.mockClear(); + this.post.mockClear(); + this.put.mockClear(); + this.delete.mockClear(); + this.create.mockClear(); + }), +}; + +jest.mock('axios', () => mockAxios); +jest.mock('fs'); +jest.mock('~/utils', () => ({ + logAxiosError: jest.fn(), +})); +jest.mock('~/config', () => ({ + logger: { + error: jest.fn(), + }, + createAxiosInstance: () => mockAxios, +})); +jest.mock('~/server/services/Tools/credentials', () => ({ + loadAuthValues: jest.fn(), +})); + +const { uploadDocumentToMistral, uploadMistralOCR, getSignedUrl, performOCR } = require('./crud'); + +describe('MistralOCR Service', () => { + afterEach(() => { + mockAxios.reset(); + jest.clearAllMocks(); + }); + + describe('uploadDocumentToMistral', () => { + beforeEach(() => { + // Create a more complete mock for file streams that FormData can work with + const mockReadStream = { + on: jest.fn().mockImplementation(function (event, handler) { + // Simulate immediate 'end' event to make FormData complete processing + if (event === 'end') { + handler(); + } + return this; + }), + pipe: jest.fn().mockImplementation(function () { + return this; + }), + pause: jest.fn(), + resume: jest.fn(), + emit: jest.fn(), + once: jest.fn(), + destroy: jest.fn(), + }; + + fs.createReadStream = jest.fn().mockReturnValue(mockReadStream); + + // Mock FormData's append to avoid actual stream processing + jest.mock('form-data', () => { + const mockFormData = function () { + return { + append: jest.fn(), + getHeaders: jest + .fn() + .mockReturnValue({ 'content-type': 'multipart/form-data; boundary=---boundary' }), + getBuffer: jest.fn().mockReturnValue(Buffer.from('mock-form-data')), + getLength: jest.fn().mockReturnValue(100), + }; + }; + return mockFormData; + }); + }); + + it('should upload a document to Mistral API using file streaming', async () => { + const mockResponse = { data: { id: 'file-123', purpose: 'ocr' } }; + 
mockAxios.post.mockResolvedValueOnce(mockResponse); + + const result = await uploadDocumentToMistral({ + filePath: '/path/to/test.pdf', + fileName: 'test.pdf', + apiKey: 'test-api-key', + }); + + // Check that createReadStream was called with the correct file path + expect(fs.createReadStream).toHaveBeenCalledWith('/path/to/test.pdf'); + + // Since we're mocking FormData, we'll just check that axios was called correctly + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://api.mistral.ai/v1/files', + expect.anything(), + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: 'Bearer test-api-key', + }), + maxBodyLength: Infinity, + maxContentLength: Infinity, + }), + ); + expect(result).toEqual(mockResponse.data); + }); + + it('should handle errors during document upload', async () => { + const errorMessage = 'API error'; + mockAxios.post.mockRejectedValueOnce(new Error(errorMessage)); + + await expect( + uploadDocumentToMistral({ + filePath: '/path/to/test.pdf', + fileName: 'test.pdf', + apiKey: 'test-api-key', + }), + ).rejects.toThrow(); + + const { logger } = require('~/config'); + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining('Error uploading document to Mistral:'), + expect.any(String), + ); + }); + }); + + describe('getSignedUrl', () => { + it('should fetch signed URL from Mistral API', async () => { + const mockResponse = { data: { url: 'https://document-url.com' } }; + mockAxios.get.mockResolvedValueOnce(mockResponse); + + const result = await getSignedUrl({ + fileId: 'file-123', + apiKey: 'test-api-key', + }); + + expect(mockAxios.get).toHaveBeenCalledWith( + 'https://api.mistral.ai/v1/files/file-123/url?expiry=24', + { + headers: { + Authorization: 'Bearer test-api-key', + }, + }, + ); + expect(result).toEqual(mockResponse.data); + }); + + it('should handle errors when fetching signed URL', async () => { + const errorMessage = 'API error'; + mockAxios.get.mockRejectedValueOnce(new Error(errorMessage)); 
+ + await expect( + getSignedUrl({ + fileId: 'file-123', + apiKey: 'test-api-key', + }), + ).rejects.toThrow(); + + const { logger } = require('~/config'); + expect(logger.error).toHaveBeenCalledWith('Error fetching signed URL:', errorMessage); + }); + }); + + describe('performOCR', () => { + it('should perform OCR using Mistral API', async () => { + const mockResponse = { + data: { + pages: [{ markdown: 'Page 1 content' }, { markdown: 'Page 2 content' }], + }, + }; + mockAxios.post.mockResolvedValueOnce(mockResponse); + + const result = await performOCR({ + apiKey: 'test-api-key', + documentUrl: 'https://document-url.com', + model: 'mistral-ocr-latest', + }); + + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://api.mistral.ai/v1/ocr', + { + model: 'mistral-ocr-latest', + include_image_base64: false, + document: { + type: 'document_url', + document_url: 'https://document-url.com', + }, + }, + { + headers: { + 'Content-Type': 'application/json', + Authorization: 'Bearer test-api-key', + }, + }, + ); + expect(result).toEqual(mockResponse.data); + }); + + it('should handle errors during OCR processing', async () => { + const errorMessage = 'OCR processing error'; + mockAxios.post.mockRejectedValueOnce(new Error(errorMessage)); + + await expect( + performOCR({ + apiKey: 'test-api-key', + documentUrl: 'https://document-url.com', + }), + ).rejects.toThrow(); + + const { logger } = require('~/config'); + expect(logger.error).toHaveBeenCalledWith('Error performing OCR:', errorMessage); + }); + }); + + describe('uploadMistralOCR', () => { + beforeEach(() => { + const mockReadStream = { + on: jest.fn().mockImplementation(function (event, handler) { + if (event === 'end') { + handler(); + } + return this; + }), + pipe: jest.fn().mockImplementation(function () { + return this; + }), + pause: jest.fn(), + resume: jest.fn(), + emit: jest.fn(), + once: jest.fn(), + destroy: jest.fn(), + }; + + fs.createReadStream = jest.fn().mockReturnValue(mockReadStream); + }); + + 
it('should process OCR for a file with standard configuration', async () => { + // Setup mocks + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + loadAuthValues.mockResolvedValue({ + OCR_API_KEY: 'test-api-key', + OCR_BASEURL: 'https://api.mistral.ai/v1', + }); + + // Mock file upload response + mockAxios.post.mockResolvedValueOnce({ + data: { id: 'file-123', purpose: 'ocr' }, + }); + + // Mock signed URL response + mockAxios.get.mockResolvedValueOnce({ + data: { url: 'https://signed-url.com' }, + }); + + // Mock OCR response with text and images + mockAxios.post.mockResolvedValueOnce({ + data: { + pages: [ + { + markdown: 'Page 1 content', + images: [{ image_base64: 'base64image1' }], + }, + { + markdown: 'Page 2 content', + images: [{ image_base64: 'base64image2' }], + }, + ], + }, + }); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + // Use environment variable syntax to ensure loadAuthValues is called + apiKey: '${OCR_API_KEY}', + baseURL: '${OCR_BASEURL}', + mistralModel: 'mistral-medium', + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'document.pdf', + }; + + const result = await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + expect(loadAuthValues).toHaveBeenCalledWith({ + userId: 'user123', + authFields: ['OCR_BASEURL', 'OCR_API_KEY'], + optional: expect.any(Set), + }); + + // Verify OCR result + expect(result).toEqual({ + filename: 'document.pdf', + bytes: expect.any(Number), + filepath: 'mistral_ocr', + text: expect.stringContaining('# PAGE 1'), + images: ['base64image1', 'base64image2'], + }); + }); + + it('should process variable references in configuration', async () => { + // Setup mocks with environment variables + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + loadAuthValues.mockResolvedValue({ + 
CUSTOM_API_KEY: 'custom-api-key', + CUSTOM_BASEURL: 'https://custom-api.mistral.ai/v1', + }); + + // Mock API responses + mockAxios.post.mockResolvedValueOnce({ + data: { id: 'file-123', purpose: 'ocr' }, + }); + mockAxios.get.mockResolvedValueOnce({ + data: { url: 'https://signed-url.com' }, + }); + mockAxios.post.mockResolvedValueOnce({ + data: { + pages: [{ markdown: 'Content from custom API' }], + }, + }); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + apiKey: '${CUSTOM_API_KEY}', + baseURL: '${CUSTOM_BASEURL}', + mistralModel: '${CUSTOM_MODEL}', + }, + }, + }, + }; + + // Set environment variable for model + process.env.CUSTOM_MODEL = 'mistral-large'; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'document.pdf', + }; + + const result = await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + // Verify that custom environment variables were extracted and used + expect(loadAuthValues).toHaveBeenCalledWith({ + userId: 'user123', + authFields: ['CUSTOM_BASEURL', 'CUSTOM_API_KEY'], + optional: expect.any(Set), + }); + + // Check that mistral-large was used in the OCR API call + expect(mockAxios.post).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ + model: 'mistral-large', + }), + expect.anything(), + ); + + expect(result.text).toEqual('Content from custom API\n\n'); + }); + + it('should fall back to default values when variables are not properly formatted', async () => { + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + loadAuthValues.mockResolvedValue({ + OCR_API_KEY: 'default-api-key', + OCR_BASEURL: undefined, // Testing optional parameter + }); + + mockAxios.post.mockResolvedValueOnce({ + data: { id: 'file-123', purpose: 'ocr' }, + }); + mockAxios.get.mockResolvedValueOnce({ + data: { url: 'https://signed-url.com' }, + }); + 
mockAxios.post.mockResolvedValueOnce({ + data: { + pages: [{ markdown: 'Default API result' }], + }, + }); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + // Use environment variable syntax to ensure loadAuthValues is called + apiKey: '${INVALID_FORMAT}', // Using valid env var format but with an invalid name + baseURL: '${OCR_BASEURL}', // Using valid env var format + mistralModel: 'mistral-ocr-latest', // Plain string value + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'document.pdf', + }; + + await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + // Should use the default values + expect(loadAuthValues).toHaveBeenCalledWith({ + userId: 'user123', + authFields: ['OCR_BASEURL', 'INVALID_FORMAT'], + optional: expect.any(Set), + }); + + // Should use the default model when not using environment variable format + expect(mockAxios.post).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ + model: 'mistral-ocr-latest', + }), + expect.anything(), + ); + }); + + it('should handle API errors during OCR process', async () => { + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + loadAuthValues.mockResolvedValue({ + OCR_API_KEY: 'test-api-key', + }); + + // Mock file upload to fail + mockAxios.post.mockRejectedValueOnce(new Error('Upload failed')); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + apiKey: 'OCR_API_KEY', + baseURL: 'OCR_BASEURL', + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'document.pdf', + }; + + await expect( + uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }), + ).rejects.toThrow('Error uploading document to Mistral OCR API'); + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + const { 
logAxiosError } = require('~/utils'); + expect(logAxiosError).toHaveBeenCalled(); + }); + + it('should handle single page documents without page numbering', async () => { + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + loadAuthValues.mockResolvedValue({ + OCR_API_KEY: 'test-api-key', + OCR_BASEURL: 'https://api.mistral.ai/v1', // Make sure this is included + }); + + // Clear all previous mocks + mockAxios.post.mockClear(); + mockAxios.get.mockClear(); + + // 1. First mock: File upload response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), + ); + + // 2. Second mock: Signed URL response + mockAxios.get.mockImplementationOnce(() => + Promise.resolve({ data: { url: 'https://signed-url.com' } }), + ); + + // 3. Third mock: OCR response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ + data: { + pages: [{ markdown: 'Single page content' }], + }, + }), + ); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + apiKey: 'OCR_API_KEY', + baseURL: 'OCR_BASEURL', + mistralModel: 'mistral-ocr-latest', + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'single-page.pdf', + }; + + const result = await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + // Verify that single page documents don't include page numbering + expect(result.text).not.toContain('# PAGE'); + expect(result.text).toEqual('Single page content\n\n'); + }); + + it('should use literal values in configuration when provided directly', async () => { + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + // We'll still mock this but it should not be used for literal values + loadAuthValues.mockResolvedValue({}); + + // Clear all previous mocks + mockAxios.post.mockClear(); + mockAxios.get.mockClear(); + + 
// 1. First mock: File upload response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), + ); + + // 2. Second mock: Signed URL response + mockAxios.get.mockImplementationOnce(() => + Promise.resolve({ data: { url: 'https://signed-url.com' } }), + ); + + // 3. Third mock: OCR response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ + data: { + pages: [{ markdown: 'Processed with literal config values' }], + }, + }), + ); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + // Direct values that should be used as-is, without variable substitution + apiKey: 'actual-api-key-value', + baseURL: 'https://direct-api-url.mistral.ai/v1', + mistralModel: 'mistral-direct-model', + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'direct-values.pdf', + }; + + const result = await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + // Verify the correct URL was used with the direct baseURL value + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://direct-api-url.mistral.ai/v1/files', + expect.any(Object), + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: 'Bearer actual-api-key-value', + }), + }), + ); + + // Check the OCR call was made with the direct model value + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://direct-api-url.mistral.ai/v1/ocr', + expect.objectContaining({ + model: 'mistral-direct-model', + }), + expect.any(Object), + ); + + // Verify the result + expect(result.text).toEqual('Processed with literal config values\n\n'); + + // Verify loadAuthValues was never called since we used direct values + expect(loadAuthValues).not.toHaveBeenCalled(); + }); + + it('should handle empty configuration values and use defaults', async () => { + const { loadAuthValues } = 
require('~/server/services/Tools/credentials'); + // Set up the mock values to be returned by loadAuthValues + loadAuthValues.mockResolvedValue({ + OCR_API_KEY: 'default-from-env-key', + OCR_BASEURL: 'https://default-from-env.mistral.ai/v1', + }); + + // Clear all previous mocks + mockAxios.post.mockClear(); + mockAxios.get.mockClear(); + + // 1. First mock: File upload response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), + ); + + // 2. Second mock: Signed URL response + mockAxios.get.mockImplementationOnce(() => + Promise.resolve({ data: { url: 'https://signed-url.com' } }), + ); + + // 3. Third mock: OCR response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ + data: { + pages: [{ markdown: 'Content from default configuration' }], + }, + }), + ); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + // Empty string values - should fall back to defaults + apiKey: '', + baseURL: '', + mistralModel: '', + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'empty-config.pdf', + }; + + const result = await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + // Verify loadAuthValues was called with the default variable names + expect(loadAuthValues).toHaveBeenCalledWith({ + userId: 'user123', + authFields: ['OCR_BASEURL', 'OCR_API_KEY'], + optional: expect.any(Set), + }); + + // Verify the API calls used the default values from loadAuthValues + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://default-from-env.mistral.ai/v1/files', + expect.any(Object), + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: 'Bearer default-from-env-key', + }), + }), + ); + + // Verify the OCR model defaulted to mistral-ocr-latest + expect(mockAxios.post).toHaveBeenCalledWith( + 
'https://default-from-env.mistral.ai/v1/ocr', + expect.objectContaining({ + model: 'mistral-ocr-latest', + }), + expect.any(Object), + ); + + // Check result + expect(result.text).toEqual('Content from default configuration\n\n'); + }); + }); +}); diff --git a/api/server/services/Files/MistralOCR/index.js b/api/server/services/Files/MistralOCR/index.js new file mode 100644 index 0000000000..a6223d1ee5 --- /dev/null +++ b/api/server/services/Files/MistralOCR/index.js @@ -0,0 +1,5 @@ +const crud = require('./crud'); + +module.exports = { + ...crud, +}; diff --git a/api/server/services/Files/S3/crud.js b/api/server/services/Files/S3/crud.js new file mode 100644 index 0000000000..701c2327da --- /dev/null +++ b/api/server/services/Files/S3/crud.js @@ -0,0 +1,162 @@ +const fs = require('fs'); +const path = require('path'); +const axios = require('axios'); +const fetch = require('node-fetch'); +const { getBufferMetadata } = require('~/server/utils'); +const { initializeS3 } = require('./initialize'); +const { logger } = require('~/config'); +const { PutObjectCommand, GetObjectCommand, DeleteObjectCommand } = require('@aws-sdk/client-s3'); +const { getSignedUrl } = require('@aws-sdk/s3-request-presigner'); + +const bucketName = process.env.AWS_BUCKET_NAME; +const s3 = initializeS3(); +const defaultBasePath = 'images'; + +/** + * Constructs the S3 key based on the base path, user ID, and file name. + */ +const getS3Key = (basePath, userId, fileName) => `${basePath}/${userId}/${fileName}`; + +/** + * Uploads a buffer to S3 and returns a signed URL. + * + * @param {Object} params + * @param {string} params.userId - The user's unique identifier. + * @param {Buffer} params.buffer - The buffer containing file data. + * @param {string} params.fileName - The file name to use in S3. + * @param {string} [params.basePath='images'] - The base path in the bucket. + * @returns {Promise} Signed URL of the uploaded file. 
+ */ +async function saveBufferToS3({ userId, buffer, fileName, basePath = defaultBasePath }) { + const key = getS3Key(basePath, userId, fileName); + const params = { Bucket: bucketName, Key: key, Body: buffer }; + + try { + await s3.send(new PutObjectCommand(params)); + return await getS3URL({ userId, fileName, basePath }); + } catch (error) { + logger.error('[saveBufferToS3] Error uploading buffer to S3:', error.message); + throw error; + } +} + +/** + * Retrieves a signed URL for a file stored in S3. + * + * @param {Object} params + * @param {string} params.userId - The user's unique identifier. + * @param {string} params.fileName - The file name in S3. + * @param {string} [params.basePath='images'] - The base path in the bucket. + * @returns {Promise} A signed URL valid for 24 hours. + */ +async function getS3URL({ userId, fileName, basePath = defaultBasePath }) { + const key = getS3Key(basePath, userId, fileName); + const params = { Bucket: bucketName, Key: key }; + + try { + return await getSignedUrl(s3, new GetObjectCommand(params), { expiresIn: 86400 }); + } catch (error) { + logger.error('[getS3URL] Error getting signed URL from S3:', error.message); + throw error; + } +} + +/** + * Saves a file from a given URL to S3. + * + * @param {Object} params + * @param {string} params.userId - The user's unique identifier. + * @param {string} params.URL - The source URL of the file. + * @param {string} params.fileName - The file name to use in S3. + * @param {string} [params.basePath='images'] - The base path in the bucket. + * @returns {Promise} Signed URL of the uploaded file. + */ +async function saveURLToS3({ userId, URL, fileName, basePath = defaultBasePath }) { + try { + const response = await fetch(URL); + const buffer = await response.buffer(); + // Optionally you can call getBufferMetadata(buffer) if needed. 
+ return await saveBufferToS3({ userId, buffer, fileName, basePath }); + } catch (error) { + logger.error('[saveURLToS3] Error uploading file from URL to S3:', error.message); + throw error; + } +} + +/** + * Deletes a file from S3. + * + * @param {Object} params + * @param {string} params.userId - The user's unique identifier. + * @param {string} params.fileName - The file name in S3. + * @param {string} [params.basePath='images'] - The base path in the bucket. + * @returns {Promise} + */ +async function deleteFileFromS3({ userId, fileName, basePath = defaultBasePath }) { + const key = getS3Key(basePath, userId, fileName); + const params = { Bucket: bucketName, Key: key }; + + try { + await s3.send(new DeleteObjectCommand(params)); + logger.debug('[deleteFileFromS3] File deleted successfully from S3'); + } catch (error) { + logger.error('[deleteFileFromS3] Error deleting file from S3:', error.message); + // If the file is not found, we can safely return. + if (error.code === 'NoSuchKey') { + return; + } + throw error; + } +} + +/** + * Uploads a local file to S3. + * + * @param {Object} params + * @param {import('express').Request} params.req - The Express request (must include user). + * @param {Express.Multer.File} params.file - The file object from Multer. + * @param {string} params.file_id - Unique file identifier. + * @param {string} [params.basePath='images'] - The base path in the bucket. 
+ * @returns {Promise<{ filepath: string, bytes: number }>} + */ +async function uploadFileToS3({ req, file, file_id, basePath = defaultBasePath }) { + try { + const inputFilePath = file.path; + const inputBuffer = await fs.promises.readFile(inputFilePath); + const bytes = Buffer.byteLength(inputBuffer); + const userId = req.user.id; + const fileName = `${file_id}__${path.basename(inputFilePath)}`; + const fileURL = await saveBufferToS3({ userId, buffer: inputBuffer, fileName, basePath }); + await fs.promises.unlink(inputFilePath); + return { filepath: fileURL, bytes }; + } catch (error) { + logger.error('[uploadFileToS3] Error uploading file to S3:', error.message); + throw error; + } +} + +/** + * Retrieves a readable stream for a file stored in S3. + * + * @param {string} filePath - The S3 key of the file. + * @returns {Promise} + */ +async function getS3FileStream(filePath) { + const params = { Bucket: bucketName, Key: filePath }; + try { + const data = await s3.send(new GetObjectCommand(params)); + return data.Body; // Returns a Node.js ReadableStream. + } catch (error) { + logger.error('[getS3FileStream] Error retrieving S3 file stream:', error.message); + throw error; + } +} + +module.exports = { + saveBufferToS3, + saveURLToS3, + getS3URL, + deleteFileFromS3, + uploadFileToS3, + getS3FileStream, +}; diff --git a/api/server/services/Files/S3/images.js b/api/server/services/Files/S3/images.js new file mode 100644 index 0000000000..378212cb5e --- /dev/null +++ b/api/server/services/Files/S3/images.js @@ -0,0 +1,118 @@ +const fs = require('fs'); +const path = require('path'); +const sharp = require('sharp'); +const { resizeImageBuffer } = require('../images/resize'); +const { updateUser } = require('~/models/userMethods'); +const { saveBufferToS3 } = require('./crud'); +const { updateFile } = require('~/models/File'); +const { logger } = require('~/config'); + +const defaultBasePath = 'images'; + +/** + * Resizes, converts, and uploads an image file to S3. 
+ * + * @param {Object} params + * @param {import('express').Request} params.req - Express request (expects user and app.locals.imageOutputType). + * @param {Express.Multer.File} params.file - File object from Multer. + * @param {string} params.file_id - Unique file identifier. + * @param {any} params.endpoint - Endpoint identifier used in image processing. + * @param {string} [params.resolution='high'] - Desired image resolution. + * @param {string} [params.basePath='images'] - Base path in the bucket. + * @returns {Promise<{ filepath: string, bytes: number, width: number, height: number }>} + */ +async function uploadImageToS3({ + req, + file, + file_id, + endpoint, + resolution = 'high', + basePath = defaultBasePath, +}) { + try { + const inputFilePath = file.path; + const inputBuffer = await fs.promises.readFile(inputFilePath); + const { + buffer: resizedBuffer, + width, + height, + } = await resizeImageBuffer(inputBuffer, resolution, endpoint); + const extension = path.extname(inputFilePath); + const userId = req.user.id; + + let processedBuffer; + let fileName = `${file_id}__${path.basename(inputFilePath)}`; + const targetExtension = `.${req.app.locals.imageOutputType}`; + + if (extension.toLowerCase() === targetExtension) { + processedBuffer = resizedBuffer; + } else { + processedBuffer = await sharp(resizedBuffer) + .toFormat(req.app.locals.imageOutputType) + .toBuffer(); + fileName = fileName.replace(new RegExp(path.extname(fileName) + '$'), targetExtension); + if (!path.extname(fileName)) { + fileName += targetExtension; + } + } + + const downloadURL = await saveBufferToS3({ + userId, + buffer: processedBuffer, + fileName, + basePath, + }); + await fs.promises.unlink(inputFilePath); + const bytes = Buffer.byteLength(processedBuffer); + return { filepath: downloadURL, bytes, width, height }; + } catch (error) { + logger.error('[uploadImageToS3] Error uploading image to S3:', error.message); + throw error; + } +} + +/** + * Updates a file record and returns 
its signed URL. + * + * @param {import('express').Request} req - Express request. + * @param {Object} file - File metadata. + * @returns {Promise<[Promise, string]>} + */ +async function prepareImageURLS3(req, file) { + try { + const updatePromise = updateFile({ file_id: file.file_id }); + return Promise.all([updatePromise, file.filepath]); + } catch (error) { + logger.error('[prepareImageURLS3] Error preparing image URL:', error.message); + throw error; + } +} + +/** + * Processes a user's avatar image by uploading it to S3 and updating the user's avatar URL if required. + * + * @param {Object} params + * @param {Buffer} params.buffer - Avatar image buffer. + * @param {string} params.userId - User's unique identifier. + * @param {string} params.manual - 'true' or 'false' flag for manual update. + * @param {string} [params.basePath='images'] - Base path in the bucket. + * @returns {Promise} Signed URL of the uploaded avatar. + */ +async function processS3Avatar({ buffer, userId, manual, basePath = defaultBasePath }) { + try { + const downloadURL = await saveBufferToS3({ userId, buffer, fileName: 'avatar.png', basePath }); + if (manual === 'true') { + await updateUser(userId, { avatar: downloadURL }); + } + return downloadURL; + } catch (error) { + logger.error('[processS3Avatar] Error processing S3 avatar:', error.message); + throw error; + } +} + +module.exports = { + uploadImageToS3, + prepareImageURLS3, + processS3Avatar, +}; diff --git a/api/server/services/Files/S3/index.js b/api/server/services/Files/S3/index.js new file mode 100644 index 0000000000..27ad97a852 --- /dev/null +++ b/api/server/services/Files/S3/index.js @@ -0,0 +1,9 @@ +const crud = require('./crud'); +const images = require('./images'); +const initialize = require('./initialize'); + +module.exports = { + ...crud, + ...images, + ...initialize, +}; diff --git a/api/server/services/Files/S3/initialize.js b/api/server/services/Files/S3/initialize.js new file mode 100644 index 
0000000000..d85945f708 --- /dev/null +++ b/api/server/services/Files/S3/initialize.js @@ -0,0 +1,43 @@ +const { S3Client } = require('@aws-sdk/client-s3'); +const { logger } = require('~/config'); + +let s3 = null; + +/** + * Initializes and returns an instance of the AWS S3 client. + * + * If AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are provided, they will be used. + * Otherwise, the AWS SDK's default credentials chain (including IRSA) is used. + * + * @returns {S3Client|null} An instance of S3Client if the region is provided; otherwise, null. + */ +const initializeS3 = () => { + if (s3) { + return s3; + } + + const region = process.env.AWS_REGION; + if (!region) { + logger.error('[initializeS3] AWS_REGION is not set. Cannot initialize S3.'); + return null; + } + + const accessKeyId = process.env.AWS_ACCESS_KEY_ID; + const secretAccessKey = process.env.AWS_SECRET_ACCESS_KEY; + + if (accessKeyId && secretAccessKey) { + s3 = new S3Client({ + region, + credentials: { accessKeyId, secretAccessKey }, + }); + logger.info('[initializeS3] S3 initialized with provided credentials.'); + } else { + // When using IRSA, credentials are automatically provided via the IAM Role attached to the ServiceAccount. + s3 = new S3Client({ region }); + logger.info('[initializeS3] S3 initialized using default credentials (IRSA).'); + } + + return s3; +}; + +module.exports = { initializeS3 }; diff --git a/api/server/services/Files/images/encode.js b/api/server/services/Files/images/encode.js index 94153ffc64..707632fb6a 100644 --- a/api/server/services/Files/images/encode.js +++ b/api/server/services/Files/images/encode.js @@ -49,6 +49,7 @@ async function encodeAndFormat(req, files, endpoint, mode) { const promises = []; const encodingMethods = {}; const result = { + text: '', files: [], image_urls: [], }; @@ -59,6 +60,9 @@ async function encodeAndFormat(req, files, endpoint, mode) { for (let file of files) { const source = file.source ?? 
FileSources.local; + if (source === FileSources.text && file.text) { + result.text += `${!result.text ? 'Attached document(s):\n```md' : '\n\n---\n\n'}# "${file.filename}"\n${file.text}\n`; + } if (!file.height) { promises.push([file, null]); @@ -85,6 +89,10 @@ async function encodeAndFormat(req, files, endpoint, mode) { promises.push(preparePayload(req, file)); } + if (result.text) { + result.text += '\n```'; + } + const detail = req.body.imageDetail ?? ImageDetail.auto; /** @type {Array<[MongoFile, string]>} */ diff --git a/api/server/services/Files/process.js b/api/server/services/Files/process.js index 8744eb409b..78a4976e2f 100644 --- a/api/server/services/Files/process.js +++ b/api/server/services/Files/process.js @@ -28,8 +28,8 @@ const { addResourceFileId, deleteResourceFileId } = require('~/server/controller const { addAgentResourceFile, removeAgentResourceFiles } = require('~/models/Agent'); const { getOpenAIClient } = require('~/server/controllers/assistants/helpers'); const { createFile, updateFileUsage, deleteFiles } = require('~/models/File'); -const { getEndpointsConfig } = require('~/server/services/Config'); -const { loadAuthValues } = require('~/app/clients/tools/util'); +const { loadAuthValues } = require('~/server/services/Tools/credentials'); +const { checkCapability } = require('~/server/services/Config'); const { LB_QueueAsyncCall } = require('~/server/utils/queue'); const { getStrategyFunctions } = require('./strategies'); const { determineFileType } = require('~/server/utils'); @@ -162,7 +162,6 @@ const processDeleteRequest = async ({ req, files }) => { for (const file of files) { const source = file.source ?? 
FileSources.local; - if (req.body.agent_id && req.body.tool_resource) { agentFiles.push({ tool_resource: req.body.tool_resource, @@ -170,6 +169,11 @@ const processDeleteRequest = async ({ req, files }) => { }); } + if (source === FileSources.text) { + resolvedFileIds.push(file.file_id); + continue; + } + if (checkOpenAIStorage(source) && !client[source]) { await initializeClients(); } @@ -453,17 +457,6 @@ const processFileUpload = async ({ req, res, metadata }) => { res.status(200).json({ message: 'File uploaded and processed successfully', ...result }); }; -/** - * @param {ServerRequest} req - * @param {AgentCapabilities} capability - * @returns {Promise} - */ -const checkCapability = async (req, capability) => { - const endpointsConfig = await getEndpointsConfig(req); - const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []; - return capabilities.includes(capability); -}; - /** * Applies the current strategy for file uploads. * Saves file metadata to the database with an expiry TTL. @@ -521,6 +514,52 @@ const processAgentFileUpload = async ({ req, res, metadata }) => { if (!isFileSearchEnabled) { throw new Error('File search is not enabled for Agents'); } + } else if (tool_resource === EToolResources.ocr) { + const isOCREnabled = await checkCapability(req, AgentCapabilities.ocr); + if (!isOCREnabled) { + throw new Error('OCR capability is not enabled for Agents'); + } + + const { handleFileUpload } = getStrategyFunctions( + req.app.locals?.ocr?.strategy ?? FileSources.mistral_ocr, + ); + const { file_id, temp_file_id } = metadata; + + const { + text, + bytes, + // TODO: OCR images support? + images, + filename, + filepath: ocrFileURL, + } = await handleFileUpload({ req, file, file_id, entity_id: agent_id }); + + const fileInfo = removeNullishValues({ + text, + bytes, + file_id, + temp_file_id, + user: req.user.id, + type: file.mimetype, + filepath: ocrFileURL, + source: FileSources.text, + filename: filename ?? 
file.originalname, + model: messageAttachment ? undefined : req.body.model, + context: messageAttachment ? FileContext.message_attachment : FileContext.agents, + }); + + if (!messageAttachment && tool_resource) { + await addAgentResourceFile({ + req, + file_id, + agent_id, + tool_resource, + }); + } + const result = await createFile(fileInfo, true); + return res + .status(200) + .json({ message: 'Agent file uploaded and processed successfully', ...result }); } const source = diff --git a/api/server/services/Files/strategies.js b/api/server/services/Files/strategies.js index ddfdd57469..7fcf10af03 100644 --- a/api/server/services/Files/strategies.js +++ b/api/server/services/Files/strategies.js @@ -21,9 +21,21 @@ const { processLocalAvatar, getLocalFileStream, } = require('./Local'); +const { + getS3URL, + saveURLToS3, + saveBufferToS3, + getS3FileStream, + uploadImageToS3, + prepareImageURLS3, + deleteFileFromS3, + processS3Avatar, + uploadFileToS3, +} = require('./S3'); const { uploadOpenAIFile, deleteOpenAIFile, getOpenAIFileStream } = require('./OpenAI'); const { getCodeOutputDownloadStream, uploadCodeEnvFile } = require('./Code'); const { uploadVectors, deleteVectors } = require('./VectorDB'); +const { uploadMistralOCR } = require('./MistralOCR'); /** * Firebase Storage Strategy Functions @@ -57,6 +69,22 @@ const localStrategy = () => ({ getDownloadStream: getLocalFileStream, }); +/** + * S3 Storage Strategy Functions + * + * */ +const s3Strategy = () => ({ + handleFileUpload: uploadFileToS3, + saveURL: saveURLToS3, + getFileURL: getS3URL, + deleteFile: deleteFileFromS3, + saveBuffer: saveBufferToS3, + prepareImagePayload: prepareImageURLS3, + processAvatar: processS3Avatar, + handleImageUpload: uploadImageToS3, + getDownloadStream: getS3FileStream, +}); + /** * VectorDB Storage Strategy Functions * @@ -127,6 +155,26 @@ const codeOutputStrategy = () => ({ getDownloadStream: getCodeOutputDownloadStream, }); +const mistralOCRStrategy = () => ({ + /** @type 
{typeof saveFileFromURL | null} */ + saveURL: null, + /** @type {typeof getLocalFileURL | null} */ + getFileURL: null, + /** @type {typeof saveLocalBuffer | null} */ + saveBuffer: null, + /** @type {typeof processLocalAvatar | null} */ + processAvatar: null, + /** @type {typeof uploadLocalImage | null} */ + handleImageUpload: null, + /** @type {typeof prepareImagesLocal | null} */ + prepareImagePayload: null, + /** @type {typeof deleteLocalFile | null} */ + deleteFile: null, + /** @type {typeof getLocalFileStream | null} */ + getDownloadStream: null, + handleFileUpload: uploadMistralOCR, +}); + // Strategy Selector const getStrategyFunctions = (fileSource) => { if (fileSource === FileSources.firebase) { @@ -139,8 +187,12 @@ const getStrategyFunctions = (fileSource) => { return openAIStrategy(); } else if (fileSource === FileSources.vectordb) { return vectorStrategy(); + } else if (fileSource === FileSources.s3) { + return s3Strategy(); } else if (fileSource === FileSources.execute_code) { return codeOutputStrategy(); + } else if (fileSource === FileSources.mistral_ocr) { + return mistralOCRStrategy(); } else { throw new Error('Invalid file source'); } diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index f3e4efb6e3..969ca8d8ff 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -362,7 +362,12 @@ async function processRequiredActions(client, requiredActions) { continue; } - tool = await createActionTool({ action: actionSet, requestBuilder }); + tool = await createActionTool({ + req: client.req, + res: client.res, + action: actionSet, + requestBuilder, + }); if (!tool) { logger.warn( `Invalid action: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id} | toolName: ${currentAction.tool}`, diff --git a/api/server/services/Tools/credentials.js b/api/server/services/Tools/credentials.js new file mode 100644 index 0000000000..b50a2460d4 
--- /dev/null +++ b/api/server/services/Tools/credentials.js @@ -0,0 +1,56 @@ +const { getUserPluginAuthValue } = require('~/server/services/PluginService'); + +/** + * + * @param {Object} params + * @param {string} params.userId + * @param {string[]} params.authFields + * @param {Set} [params.optional] + * @param {boolean} [params.throwError] + * @returns + */ +const loadAuthValues = async ({ userId, authFields, optional, throwError = true }) => { + let authValues = {}; + + /** + * Finds the first non-empty value for the given authentication field, supporting alternate fields. + * @param {string[]} fields Array of strings representing the authentication fields. Supports alternate fields delimited by "||". + * @returns {Promise<{ authField: string, authValue: string} | null>} An object containing the authentication field and value, or null if not found. + */ + const findAuthValue = async (fields) => { + for (const field of fields) { + let value = process.env[field]; + if (value) { + return { authField: field, authValue: value }; + } + try { + value = await getUserPluginAuthValue(userId, field, throwError); + } catch (err) { + if (optional && optional.has(field)) { + return { authField: field, authValue: undefined }; + } + if (field === fields[fields.length - 1] && !value) { + throw err; + } + } + if (value) { + return { authField: field, authValue: value }; + } + } + return null; + }; + + for (let authField of authFields) { + const fields = authField.split('||'); + const result = await findAuthValue(fields); + if (result) { + authValues[result.authField] = result.authValue; + } + } + + return authValues; +}; + +module.exports = { + loadAuthValues, +}; diff --git a/api/server/utils/handleText.js b/api/server/utils/handleText.js index 8c681d8f4e..f593d6c866 100644 --- a/api/server/utils/handleText.js +++ b/api/server/utils/handleText.js @@ -203,6 +203,8 @@ function generateConfig(key, baseURL, endpoint) { AgentCapabilities.artifacts, AgentCapabilities.actions, 
AgentCapabilities.tools, + AgentCapabilities.ocr, + AgentCapabilities.chain, ]; } diff --git a/api/test/__mocks__/logger.js b/api/test/__mocks__/logger.js index caeb004e39..549c57d5a4 100644 --- a/api/test/__mocks__/logger.js +++ b/api/test/__mocks__/logger.js @@ -39,7 +39,10 @@ jest.mock('winston-daily-rotate-file', () => { }); jest.mock('~/config', () => { + const actualModule = jest.requireActual('~/config'); return { + sendEvent: actualModule.sendEvent, + createAxiosInstance: actualModule.createAxiosInstance, logger: { info: jest.fn(), warn: jest.fn(), diff --git a/api/typedefs.js b/api/typedefs.js index 3045d9543b..21c4f1fecc 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -1787,3 +1787,51 @@ * @typedef {Promise<{ message: TMessage, conversation: TConversation }> | undefined} ClientDatabaseSavePromise * @memberof typedefs */ + +/** + * @exports OCRImage + * @typedef {Object} OCRImage + * @property {string} id - The identifier of the image. + * @property {number} top_left_x - X-coordinate of the top left corner of the image. + * @property {number} top_left_y - Y-coordinate of the top left corner of the image. + * @property {number} bottom_right_x - X-coordinate of the bottom right corner of the image. + * @property {number} bottom_right_y - Y-coordinate of the bottom right corner of the image. + * @property {string} image_base64 - Base64-encoded image data. + * @memberof typedefs + */ + +/** + * @exports PageDimensions + * @typedef {Object} PageDimensions + * @property {number} dpi - The dots per inch resolution of the page. + * @property {number} height - The height of the page in pixels. + * @property {number} width - The width of the page in pixels. + * @memberof typedefs + */ + +/** + * @exports OCRPage + * @typedef {Object} OCRPage + * @property {number} index - The index of the page in the document. + * @property {string} markdown - The extracted text content of the page in markdown format. 
+ * @property {OCRImage[]} images - Array of images found on the page. + * @property {PageDimensions} dimensions - The dimensions of the page. + * @memberof typedefs + */ + +/** + * @exports OCRUsageInfo + * @typedef {Object} OCRUsageInfo + * @property {number} pages_processed - Number of pages processed in the document. + * @property {number} doc_size_bytes - Size of the document in bytes. + * @memberof typedefs + */ + +/** + * @exports OCRResult + * @typedef {Object} OCRResult + * @property {OCRPage[]} pages - Array of pages extracted from the document. + * @property {string} model - The model used for OCR processing. + * @property {OCRUsageInfo} usage_info - Usage information for the OCR operation. + * @memberof typedefs + */ diff --git a/api/utils/tokens.js b/api/utils/tokens.js index 8edfb0a31c..58aaf7051b 100644 --- a/api/utils/tokens.js +++ b/api/utils/tokens.js @@ -92,6 +92,7 @@ const anthropicModels = { const deepseekModels = { 'deepseek-reasoner': 63000, // -1000 from max (API) deepseek: 63000, // -1000 from max (API) + 'deepseek.r1': 127500, }; const metaModels = { diff --git a/api/utils/tokens.spec.js b/api/utils/tokens.spec.js index d4dbb30498..e5ae21b646 100644 --- a/api/utils/tokens.spec.js +++ b/api/utils/tokens.spec.js @@ -423,6 +423,9 @@ describe('Meta Models Tests', () => { expect(getModelMaxTokens('deepseek-reasoner')).toBe( maxTokensMap[EModelEndpoint.openAI]['deepseek-reasoner'], ); + expect(getModelMaxTokens('deepseek.r1')).toBe( + maxTokensMap[EModelEndpoint.openAI]['deepseek.r1'], + ); }); }); diff --git a/client/src/common/agents-types.ts b/client/src/common/agents-types.ts index a9c24106bc..982cbfdb17 100644 --- a/client/src/common/agents-types.ts +++ b/client/src/common/agents-types.ts @@ -5,6 +5,7 @@ import type { OptionWithIcon, ExtendedFile } from './types'; export type TAgentOption = OptionWithIcon & Agent & { knowledge_files?: Array<[string, ExtendedFile]>; + context_files?: Array<[string, ExtendedFile]>; code_files?: Array<[string, 
ExtendedFile]>; }; @@ -27,4 +28,5 @@ export type AgentForm = { provider?: AgentProvider | OptionWithIcon; agent_ids?: string[]; [AgentCapabilities.artifacts]?: ArtifactModes | string; + recursion_limit?: number; } & TAgentCapabilities; diff --git a/client/src/common/types.ts b/client/src/common/types.ts index 380ec573b8..975f468930 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -131,6 +131,7 @@ export interface DataColumnMeta { } export enum Panel { + advanced = 'advanced', builder = 'builder', actions = 'actions', model = 'model', @@ -181,6 +182,7 @@ export type AgentPanelProps = { activePanel?: string; action?: t.Action; actions?: t.Action[]; + createMutation: UseMutationResult; setActivePanel: React.Dispatch>; setAction: React.Dispatch>; endpointsConfig?: t.TEndpointsConfig; @@ -483,6 +485,7 @@ export interface ExtendedFile { attached?: boolean; embedded?: boolean; tool_resource?: string; + metadata?: t.TFile['metadata']; } export type ContextType = { navVisible: boolean; setNavVisible: (visible: boolean) => void }; diff --git a/client/src/components/Chat/Input/Files/AttachFileMenu.tsx b/client/src/components/Chat/Input/Files/AttachFileMenu.tsx index 54a8a595c4..8841a0ae51 100644 --- a/client/src/components/Chat/Input/Files/AttachFileMenu.tsx +++ b/client/src/components/Chat/Input/Files/AttachFileMenu.tsx @@ -1,7 +1,7 @@ import * as Ariakit from '@ariakit/react'; import React, { useRef, useState, useMemo } from 'react'; -import { FileSearch, ImageUpIcon, TerminalSquareIcon } from 'lucide-react'; import { EToolResources, EModelEndpoint } from 'librechat-data-provider'; +import { FileSearch, ImageUpIcon, TerminalSquareIcon, FileType2Icon } from 'lucide-react'; import { FileUpload, TooltipAnchor, DropdownPopup } from '~/components/ui'; import { useGetEndpointsQuery } from '~/data-provider'; import { AttachmentIcon } from '~/components/svg'; @@ -49,6 +49,17 @@ const AttachFile = ({ isRTL, disabled, handleFileChange }: AttachFileProps) => 
{ }, ]; + if (capabilities.includes(EToolResources.ocr)) { + items.push({ + label: localize('com_ui_upload_ocr_text'), + onClick: () => { + setToolResource(EToolResources.ocr); + handleUploadClick(); + }, + icon: , + }); + } + if (capabilities.includes(EToolResources.file_search)) { items.push({ label: localize('com_ui_upload_file_search'), diff --git a/client/src/components/Chat/Input/Files/DragDropModal.tsx b/client/src/components/Chat/Input/Files/DragDropModal.tsx index b252ae1a93..2abc15a45b 100644 --- a/client/src/components/Chat/Input/Files/DragDropModal.tsx +++ b/client/src/components/Chat/Input/Files/DragDropModal.tsx @@ -1,6 +1,6 @@ import React, { useMemo } from 'react'; import { EModelEndpoint, EToolResources } from 'librechat-data-provider'; -import { FileSearch, ImageUpIcon, TerminalSquareIcon } from 'lucide-react'; +import { FileSearch, ImageUpIcon, FileType2Icon, TerminalSquareIcon } from 'lucide-react'; import OGDialogTemplate from '~/components/ui/OGDialogTemplate'; import { useGetEndpointsQuery } from '~/data-provider'; import useLocalize from '~/hooks/useLocalize'; @@ -50,6 +50,12 @@ const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragD value: EToolResources.execute_code, icon: , }); + } else if (capability === EToolResources.ocr) { + _options.push({ + label: localize('com_ui_upload_ocr_text'), + value: EToolResources.ocr, + icon: , + }); } } diff --git a/client/src/components/Chat/Input/Files/FilePreview.tsx b/client/src/components/Chat/Input/Files/FilePreview.tsx index 80933b8503..02851119af 100644 --- a/client/src/components/Chat/Input/Files/FilePreview.tsx +++ b/client/src/components/Chat/Input/Files/FilePreview.tsx @@ -19,7 +19,7 @@ const FilePreview = ({ }; className?: string; }) => { - const radius = 55; // Radius of the SVG circle + const radius = 55; const circumference = 2 * Math.PI * radius; const progress = useProgress( file?.['progress'] ?? 
1, @@ -27,16 +27,15 @@ const FilePreview = ({ (file as ExtendedFile | undefined)?.size ?? 1, ); - // Calculate the offset based on the loading progress const offset = circumference - progress * circumference; const circleCSSProperties = { transition: 'stroke-dashoffset 0.5s linear', }; return ( -
+
- + {progress < 1 && ( + + + +
+ ); + } + + if (source === FileSources.text) { + return ( +
+ + + +
+ ); + } + + if (source === FileSources.vectordb) { + return ( +
+ + + +
+ ); } const endpoint = sourceToEndpoint[source ?? '']; @@ -31,7 +64,7 @@ export default function SourceIcon({ return null; } return ( - +
); } diff --git a/client/src/components/Chat/Menus/Models/ModelSpec.tsx b/client/src/components/Chat/Menus/Models/ModelSpec.tsx index 32415bc557..b57f97820f 100644 --- a/client/src/components/Chat/Menus/Models/ModelSpec.tsx +++ b/client/src/components/Chat/Menus/Models/ModelSpec.tsx @@ -75,7 +75,7 @@ const MenuItem: FC = ({ {showIconInMenu && }
{title} -
{description}
+
{description}
{spec.badges && spec.badges.length > 0 && (
{spec.badges.map((badge, index) => ( diff --git a/client/src/components/Chat/Messages/Content/ContentParts.tsx b/client/src/components/Chat/Messages/Content/ContentParts.tsx index ddf08976eb..3805e0bb41 100644 --- a/client/src/components/Chat/Messages/Content/ContentParts.tsx +++ b/client/src/components/Chat/Messages/Content/ContentParts.tsx @@ -139,6 +139,7 @@ const ContentParts = memo( isSubmitting={isSubmitting} key={`part-${messageId}-${idx}`} isCreatedByUser={isCreatedByUser} + isLast={idx === content.length - 1} showCursor={idx === content.length - 1 && isLast} /> diff --git a/client/src/components/Chat/Messages/Content/Markdown.tsx b/client/src/components/Chat/Messages/Content/Markdown.tsx index e01de091c7..ee134b0e53 100644 --- a/client/src/components/Chat/Messages/Content/Markdown.tsx +++ b/client/src/components/Chat/Messages/Content/Markdown.tsx @@ -166,15 +166,12 @@ export const p: React.ElementType = memo(({ children }: TParagraphProps) => { return

{children}

; }); -const cursor = ' '; - type TContentProps = { content: string; - showCursor?: boolean; isLatestMessage: boolean; }; -const Markdown = memo(({ content = '', showCursor, isLatestMessage }: TContentProps) => { +const Markdown = memo(({ content = '', isLatestMessage }: TContentProps) => { const LaTeXParsing = useRecoilValue(store.LaTeXParsing); const isInitializing = content === ''; @@ -240,7 +237,7 @@ const Markdown = memo(({ content = '', showCursor, isLatestMessage }: TContentPr } } > - {isLatestMessage && (showCursor ?? false) ? currentContent + cursor : currentContent} + {currentContent} diff --git a/client/src/components/Chat/Messages/Content/MessageContent.tsx b/client/src/components/Chat/Messages/Content/MessageContent.tsx index 1547a01d80..f70a15b779 100644 --- a/client/src/components/Chat/Messages/Content/MessageContent.tsx +++ b/client/src/components/Chat/Messages/Content/MessageContent.tsx @@ -83,9 +83,7 @@ const DisplayMessage = ({ text, isCreatedByUser, message, showCursor }: TDisplay let content: React.ReactElement; if (!isCreatedByUser) { - content = ( - - ); + content = ; } else if (enableUserMsgMarkdown) { content = ; } else { diff --git a/client/src/components/Chat/Messages/Content/Part.tsx b/client/src/components/Chat/Messages/Content/Part.tsx index 2430bee6f9..1351efd59c 100644 --- a/client/src/components/Chat/Messages/Content/Part.tsx +++ b/client/src/components/Chat/Messages/Content/Part.tsx @@ -8,9 +8,11 @@ import { import { memo } from 'react'; import type { TMessageContentParts, TAttachment } from 'librechat-data-provider'; import { ErrorMessage } from './MessageContent'; +import AgentUpdate from './Parts/AgentUpdate'; import ExecuteCode from './Parts/ExecuteCode'; import RetrievalCall from './RetrievalCall'; import Reasoning from './Parts/Reasoning'; +import EmptyText from './Parts/EmptyText'; import CodeAnalyze from './CodeAnalyze'; import Container from './Container'; import ToolCall from './ToolCall'; @@ -20,145 +22,159 @@ import 
Image from './Image'; type PartProps = { part?: TMessageContentParts; + isLast?: boolean; isSubmitting: boolean; showCursor: boolean; isCreatedByUser: boolean; attachments?: TAttachment[]; }; -const Part = memo(({ part, isSubmitting, attachments, showCursor, isCreatedByUser }: PartProps) => { - if (!part) { - return null; - } - - if (part.type === ContentTypes.ERROR) { - return ( - - ); - } else if (part.type === ContentTypes.TEXT) { - const text = typeof part.text === 'string' ? part.text : part.text.value; - - if (typeof text !== 'string') { - return null; - } - if (part.tool_call_ids != null && !text) { - return null; - } - return ( - - - - ); - } else if (part.type === ContentTypes.THINK) { - const reasoning = typeof part.think === 'string' ? part.think : part.think.value; - if (typeof reasoning !== 'string') { - return null; - } - return ; - } else if (part.type === ContentTypes.TOOL_CALL) { - const toolCall = part[ContentTypes.TOOL_CALL]; - - if (!toolCall) { +const Part = memo( + ({ part, isSubmitting, attachments, isLast, showCursor, isCreatedByUser }: PartProps) => { + if (!part) { return null; } - const isToolCall = - 'args' in toolCall && (!toolCall.type || toolCall.type === ToolCallTypes.TOOL_CALL); - if (isToolCall && toolCall.name === Tools.execute_code) { + if (part.type === ContentTypes.ERROR) { return ( - ); - } else if (isToolCall) { + } else if (part.type === ContentTypes.AGENT_UPDATE) { return ( - - ); - } else if (toolCall.type === ToolCallTypes.CODE_INTERPRETER) { - const code_interpreter = toolCall[ToolCallTypes.CODE_INTERPRETER]; - return ( - - ); - } else if ( - toolCall.type === ToolCallTypes.RETRIEVAL || - toolCall.type === ToolCallTypes.FILE_SEARCH - ) { - return ( - - ); - } else if ( - toolCall.type === ToolCallTypes.FUNCTION && - ToolCallTypes.FUNCTION in toolCall && - imageGenTools.has(toolCall.function.name) - ) { - return ( - - ); - } else if (toolCall.type === ToolCallTypes.FUNCTION && ToolCallTypes.FUNCTION in toolCall) { - if 
(isImageVisionTool(toolCall)) { - if (isSubmitting && showCursor) { - return ( + <> + + {isLast && showCursor && ( - + - ); - } + )} + + ); + } else if (part.type === ContentTypes.TEXT) { + const text = typeof part.text === 'string' ? part.text : part.text.value; + + if (typeof text !== 'string') { + return null; + } + if (part.tool_call_ids != null && !text) { + return null; + } + return ( + + + + ); + } else if (part.type === ContentTypes.THINK) { + const reasoning = typeof part.think === 'string' ? part.think : part.think.value; + if (typeof reasoning !== 'string') { + return null; + } + return ; + } else if (part.type === ContentTypes.TOOL_CALL) { + const toolCall = part[ContentTypes.TOOL_CALL]; + + if (!toolCall) { return null; } + const isToolCall = + 'args' in toolCall && (!toolCall.type || toolCall.type === ToolCallTypes.TOOL_CALL); + if (isToolCall && toolCall.name === Tools.execute_code) { + return ( + + ); + } else if (isToolCall) { + return ( + + ); + } else if (toolCall.type === ToolCallTypes.CODE_INTERPRETER) { + const code_interpreter = toolCall[ToolCallTypes.CODE_INTERPRETER]; + return ( + + ); + } else if ( + toolCall.type === ToolCallTypes.RETRIEVAL || + toolCall.type === ToolCallTypes.FILE_SEARCH + ) { + return ( + + ); + } else if ( + toolCall.type === ToolCallTypes.FUNCTION && + ToolCallTypes.FUNCTION in toolCall && + imageGenTools.has(toolCall.function.name) + ) { + return ( + + ); + } else if (toolCall.type === ToolCallTypes.FUNCTION && ToolCallTypes.FUNCTION in toolCall) { + if (isImageVisionTool(toolCall)) { + if (isSubmitting && showCursor) { + return ( + + + + ); + } + return null; + } + + return ( + + ); + } + } else if (part.type === ContentTypes.IMAGE_FILE) { + const imageFile = part[ContentTypes.IMAGE_FILE]; + const height = imageFile.height ?? 1920; + const width = imageFile.width ?? 
1080; return ( - ); } - } else if (part.type === ContentTypes.IMAGE_FILE) { - const imageFile = part[ContentTypes.IMAGE_FILE]; - const height = imageFile.height ?? 1920; - const width = imageFile.width ?? 1080; - return ( - - ); - } - return null; -}); + return null; + }, +); export default Part; diff --git a/client/src/components/Chat/Messages/Content/Parts/AgentUpdate.tsx b/client/src/components/Chat/Messages/Content/Parts/AgentUpdate.tsx new file mode 100644 index 0000000000..4dca00107e --- /dev/null +++ b/client/src/components/Chat/Messages/Content/Parts/AgentUpdate.tsx @@ -0,0 +1,39 @@ +import React, { useMemo } from 'react'; +import { EModelEndpoint } from 'librechat-data-provider'; +import { useAgentsMapContext } from '~/Providers'; +import Icon from '~/components/Endpoints/Icon'; + +interface AgentUpdateProps { + currentAgentId: string; +} + +const AgentUpdate: React.FC = ({ currentAgentId }) => { + const agentsMap = useAgentsMapContext() || {}; + const currentAgent = useMemo(() => agentsMap[currentAgentId], [agentsMap, currentAgentId]); + if (!currentAgentId) { + return null; + } + return ( +
+
+
+
+
+
+
+
+
+ +
+
{currentAgent?.name}
+
+
+ ); +}; + +export default AgentUpdate; diff --git a/client/src/components/Chat/Messages/Content/Parts/EmptyText.tsx b/client/src/components/Chat/Messages/Content/Parts/EmptyText.tsx new file mode 100644 index 0000000000..1b514164df --- /dev/null +++ b/client/src/components/Chat/Messages/Content/Parts/EmptyText.tsx @@ -0,0 +1,17 @@ +import { memo } from 'react'; + +const EmptyTextPart = memo(() => { + return ( +
+
+
+

+ +

+
+
+
+ ); +}); + +export default EmptyTextPart; diff --git a/client/src/components/Chat/Messages/Content/Parts/Text.tsx b/client/src/components/Chat/Messages/Content/Parts/Text.tsx index 7c207f1512..d4a605aea5 100644 --- a/client/src/components/Chat/Messages/Content/Parts/Text.tsx +++ b/client/src/components/Chat/Messages/Content/Parts/Text.tsx @@ -29,9 +29,7 @@ const TextPart = memo(({ text, isCreatedByUser, showCursor }: TextPartProps) => const content: ContentType = useMemo(() => { if (!isCreatedByUser) { - return ( - - ); + return ; } else if (enableUserMsgMarkdown) { return ; } else { diff --git a/client/src/components/SidePanel/Agents/AdminSettings.tsx b/client/src/components/SidePanel/Agents/AdminSettings.tsx index 6ca21d1317..5fb13fd045 100644 --- a/client/src/components/SidePanel/Agents/AdminSettings.tsx +++ b/client/src/components/SidePanel/Agents/AdminSettings.tsx @@ -142,7 +142,7 @@ const AdminSettings = () => { + ); +}; + +export default AdvancedButton; diff --git a/client/src/components/SidePanel/Agents/Advanced/AdvancedPanel.tsx b/client/src/components/SidePanel/Agents/Advanced/AdvancedPanel.tsx new file mode 100644 index 0000000000..0ead79cd32 --- /dev/null +++ b/client/src/components/SidePanel/Agents/Advanced/AdvancedPanel.tsx @@ -0,0 +1,55 @@ +import { useMemo } from 'react'; +import { ChevronLeft } from 'lucide-react'; +import { AgentCapabilities } from 'librechat-data-provider'; +import { useFormContext, Controller } from 'react-hook-form'; +import type { AgentForm, AgentPanelProps } from '~/common'; +import MaxAgentSteps from './MaxAgentSteps'; +import AgentChain from './AgentChain'; +import { useLocalize } from '~/hooks'; +import { Panel } from '~/common'; + +export default function AdvancedPanel({ + agentsConfig, + setActivePanel, +}: Pick) { + const localize = useLocalize(); + const methods = useFormContext(); + const { control, watch } = methods; + const currentAgentId = watch('id'); + const chainEnabled = useMemo( + () => 
agentsConfig?.capabilities.includes(AgentCapabilities.chain) ?? false, + [agentsConfig], + ); + + return ( +
+
+
+ +
+
{localize('com_ui_advanced_settings')}
+
+
+ + {chainEnabled && ( + } + /> + )} +
+
+ ); +} diff --git a/client/src/components/SidePanel/Agents/Advanced/AgentChain.tsx b/client/src/components/SidePanel/Agents/Advanced/AgentChain.tsx new file mode 100644 index 0000000000..1380927115 --- /dev/null +++ b/client/src/components/SidePanel/Agents/Advanced/AgentChain.tsx @@ -0,0 +1,179 @@ +import { X, Link2, PlusCircle } from 'lucide-react'; +import { EModelEndpoint } from 'librechat-data-provider'; +import React, { useState, useMemo, useCallback, useEffect } from 'react'; +import type { ControllerRenderProps } from 'react-hook-form'; +import type { AgentForm, OptionWithIcon } from '~/common'; +import ControlCombobox from '~/components/ui/ControlCombobox'; +import { HoverCard, HoverCardPortal, HoverCardContent, HoverCardTrigger } from '~/components/ui'; +import { CircleHelpIcon } from '~/components/svg'; +import { useAgentsMapContext } from '~/Providers'; +import Icon from '~/components/Endpoints/Icon'; +import { useLocalize } from '~/hooks'; +import { ESide } from '~/common'; + +interface AgentChainProps { + field: ControllerRenderProps; + currentAgentId: string; +} + +/** TODO: make configurable */ +const MAX_AGENTS = 10; + +const AgentChain: React.FC = ({ field, currentAgentId }) => { + const localize = useLocalize(); + const [newAgentId, setNewAgentId] = useState(''); + const agentsMap = useAgentsMapContext() || {}; + const agentIds = field.value || []; + + const agents = useMemo(() => Object.values(agentsMap), [agentsMap]); + + const selectableAgents = useMemo( + () => + agents + .filter((agent) => agent?.id !== currentAgentId) + .map( + (agent) => + ({ + label: agent?.name || '', + value: agent?.id, + icon: ( + + ), + }) as OptionWithIcon, + ), + [agents, currentAgentId], + ); + + const getAgentDetails = useCallback((id: string) => agentsMap[id], [agentsMap]); + + useEffect(() => { + if (newAgentId && agentIds.length < MAX_AGENTS) { + field.onChange([...agentIds, newAgentId]); + setNewAgentId(''); + } + }, [newAgentId, agentIds, field]); + + const 
removeAgentAt = (index: number) => { + field.onChange(agentIds.filter((_, i) => i !== index)); + }; + + const updateAgentAt = (index: number, id: string) => { + const updated = [...agentIds]; + updated[index] = id; + field.onChange(updated); + }; + + return ( + +
+
+ + + + +
+
+ {agentIds.length} / {MAX_AGENTS} +
+
+
+ {/* Current fixed agent */} +
+
+
+ +
+
+ {getAgentDetails(currentAgentId)?.name} +
+
+
+ {} + {agentIds.map((agentId, idx) => ( + +
+ updateAgentAt(idx, id)} + selectPlaceholder={localize('com_ui_agent_var', { 0: localize('com_ui_select') })} + searchPlaceholder={localize('com_ui_agent_var', { 0: localize('com_ui_search') })} + items={selectableAgents} + displayValue={getAgentDetails(agentId)?.name ?? ''} + SelectIcon={ + + } + className="flex-1 border-border-heavy" + containerClassName="px-0" + /> + {/* Future Settings button? */} + {/* */} + +
+ {idx < agentIds.length - 1 && ( + + )} +
+ ))} + + {agentIds.length < MAX_AGENTS && ( + <> + {agentIds.length > 0 && } + } + /> + + )} + + {agentIds.length >= MAX_AGENTS && ( +

+ {localize('com_ui_agent_chain_max', { 0: MAX_AGENTS })} +

+ )} +
+ + +
+

{localize('com_ui_agent_chain_info')}

+
+
+
+
+ ); +}; + +export default AgentChain; diff --git a/client/src/components/SidePanel/Agents/Advanced/MaxAgentSteps.tsx b/client/src/components/SidePanel/Agents/Advanced/MaxAgentSteps.tsx new file mode 100644 index 0000000000..5e334282f9 --- /dev/null +++ b/client/src/components/SidePanel/Agents/Advanced/MaxAgentSteps.tsx @@ -0,0 +1,52 @@ +import { useFormContext, Controller } from 'react-hook-form'; +import type { AgentForm } from '~/common'; +import { + HoverCard, + FormInput, + HoverCardPortal, + HoverCardContent, + HoverCardTrigger, +} from '~/components/ui'; +import { CircleHelpIcon } from '~/components/svg'; +import { useLocalize } from '~/hooks'; +import { ESide } from '~/common'; + +export default function AdvancedPanel() { + const localize = useLocalize(); + const methods = useFormContext(); + const { control } = methods; + + return ( + + ( + + + + } + /> + )} + /> + + +
+

+ {localize('com_ui_agent_recursion_limit_info')} +

+
+
+
+
+ ); +} diff --git a/client/src/components/SidePanel/Agents/AgentConfig.tsx b/client/src/components/SidePanel/Agents/AgentConfig.tsx index 9fc7674158..864ecd8173 100644 --- a/client/src/components/SidePanel/Agents/AgentConfig.tsx +++ b/client/src/components/SidePanel/Agents/AgentConfig.tsx @@ -1,31 +1,19 @@ import React, { useState, useMemo, useCallback } from 'react'; import { useQueryClient } from '@tanstack/react-query'; import { Controller, useWatch, useFormContext } from 'react-hook-form'; -import { - QueryKeys, - SystemRoles, - Permissions, - EModelEndpoint, - PermissionTypes, - AgentCapabilities, -} from 'librechat-data-provider'; +import { QueryKeys, EModelEndpoint, AgentCapabilities } from 'librechat-data-provider'; import type { TPlugin } from 'librechat-data-provider'; import type { AgentForm, AgentPanelProps, IconComponentTypes } from '~/common'; import { cn, defaultTextProps, removeFocusOutlines, getEndpointField, getIconKey } from '~/utils'; -import { useCreateAgentMutation, useUpdateAgentMutation } from '~/data-provider'; -import { useLocalize, useAuthContext, useHasAccess } from '~/hooks'; import { useToastContext, useFileMapContext } from '~/Providers'; import { icons } from '~/components/Chat/Menus/Endpoints/Icons'; import Action from '~/components/SidePanel/Builder/Action'; import { ToolSelectDialog } from '~/components/Tools'; -import DuplicateAgent from './DuplicateAgent'; import { processAgentOption } from '~/utils'; -import AdminSettings from './AdminSettings'; -import DeleteButton from './DeleteButton'; import AgentAvatar from './AgentAvatar'; -import { Spinner } from '~/components'; +import FileContext from './FileContext'; +import { useLocalize } from '~/hooks'; import FileSearch from './FileSearch'; -import ShareAgent from './ShareAgent'; import Artifacts from './Artifacts'; import AgentTool from './AgentTool'; import CodeForm from './Code/Form'; @@ -42,11 +30,10 @@ export default function AgentConfig({ setAction, actions = [], 
agentsConfig, - endpointsConfig, + createMutation, setActivePanel, - setCurrentAgentId, + endpointsConfig, }: AgentPanelProps) { - const { user } = useAuthContext(); const fileMap = useFileMapContext(); const queryClient = useQueryClient(); @@ -65,11 +52,6 @@ export default function AgentConfig({ const tools = useWatch({ control, name: 'tools' }); const agent_id = useWatch({ control, name: 'id' }); - const hasAccessToShareAgents = useHasAccess({ - permissionType: PermissionTypes.AGENTS, - permission: Permissions.SHARED_GLOBAL, - }); - const toolsEnabled = useMemo( () => agentsConfig?.capabilities.includes(AgentCapabilities.tools), [agentsConfig], @@ -82,6 +64,10 @@ export default function AgentConfig({ () => agentsConfig?.capabilities.includes(AgentCapabilities.artifacts) ?? false, [agentsConfig], ); + const ocrEnabled = useMemo( + () => agentsConfig?.capabilities.includes(AgentCapabilities.ocr) ?? false, + [agentsConfig], + ); const fileSearchEnabled = useMemo( () => agentsConfig?.capabilities.includes(AgentCapabilities.file_search) ?? false, [agentsConfig], @@ -91,6 +77,26 @@ export default function AgentConfig({ [agentsConfig], ); + const context_files = useMemo(() => { + if (typeof agent === 'string') { + return []; + } + + if (agent?.id !== agent_id) { + return []; + } + + if (agent.context_files) { + return agent.context_files; + } + + const _agent = processAgentOption({ + agent, + fileMap, + }); + return _agent.context_files ?? []; + }, [agent, agent_id, fileMap]); + const knowledge_files = useMemo(() => { if (typeof agent === 'string') { return []; @@ -131,46 +137,6 @@ export default function AgentConfig({ return _agent.code_files ?? []; }, [agent, agent_id, fileMap]); - /* Mutations */ - const update = useUpdateAgentMutation({ - onSuccess: (data) => { - showToast({ - message: `${localize('com_assistants_update_success')} ${ - data.name ?? 
localize('com_ui_agent') - }`, - }); - }, - onError: (err) => { - const error = err as Error; - showToast({ - message: `${localize('com_agents_update_error')}${ - error.message ? ` ${localize('com_ui_error')}: ${error.message}` : '' - }`, - status: 'error', - }); - }, - }); - - const create = useCreateAgentMutation({ - onSuccess: (data) => { - setCurrentAgentId(data.id); - showToast({ - message: `${localize('com_assistants_create_success')} ${ - data.name ?? localize('com_ui_agent') - }`, - }); - }, - onError: (err) => { - const error = err as Error; - showToast({ - message: `${localize('com_agents_create_error')}${ - error.message ? ` ${localize('com_ui_error')}: ${error.message}` : '' - }`, - status: 'error', - }); - }, - }); - const handleAddActions = useCallback(() => { if (!agent_id) { showToast({ @@ -200,26 +166,14 @@ export default function AgentConfig({ Icon = icons[iconKey]; } - const renderSaveButton = () => { - if (create.isLoading || update.isLoading) { - return
- {user?.role === SystemRoles.ADMIN && } - {/* Context Button */} -
- - {(agent?.author === user?.id || user?.role === SystemRoles.ADMIN) && - hasAccessToShareAgents && ( - - )} - {agent && agent.author === user?.id && } - {/* Submit Button */} - -
& { + updateMutation: ReturnType; +}) { + const localize = useLocalize(); + const { user } = useAuthContext(); + + const methods = useFormContext(); + + const { control } = methods; + const agent = useWatch({ control, name: 'agent' }); + const agent_id = useWatch({ control, name: 'id' }); + + const hasAccessToShareAgents = useHasAccess({ + permissionType: PermissionTypes.AGENTS, + permission: Permissions.SHARED_GLOBAL, + }); + + const renderSaveButton = () => { + if (createMutation.isLoading || updateMutation.isLoading) { + return