From d32f34e5d7f41536a030f66db24a036b3a4b7cfb Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Tue, 15 Apr 2025 18:03:56 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9D=20fix:=20Mistral=20OCR=20Image=20S?= =?UTF-8?q?upport=20and=20Azure=20Agent=20Titles=20(#6901)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: azure title model * refactor: typing for uploadMistralOCR * fix: update conversation ID handling in useSSE for better state management, only use PENDING_CONVO for new conversations * fix: streamline conversation ID handling in useSSE for simplicity, only needs state update to prevent draft from applying * fix: update performOCR and tests to support document and image URLs with appropriate types --- api/server/controllers/agents/client.js | 31 ++++- api/server/services/Files/MistralOCR/crud.js | 34 ++++- .../services/Files/MistralOCR/crud.spec.js | 127 +++++++++++++++++- api/server/services/Files/process.js | 6 +- client/src/hooks/SSE/useSSE.ts | 2 +- 5 files changed, 185 insertions(+), 15 deletions(-) diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 8a128bcdba..09290b59f9 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -33,6 +33,7 @@ const { addCacheControl, createContextHandlers } = require('~/app/clients/prompt const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { getBufferString, HumanMessage } = require('@langchain/core/messages'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); +const initOpenAI = require('~/server/services/Endpoints/openAI/initialize'); const Tokenizer = require('~/server/services/Tokenizer'); const BaseClient = require('~/app/clients/BaseClient'); const { logger, sendEvent } = require('~/config'); @@ -931,14 +932,16 @@ class AgentClient extends BaseClient { throw new Error('Run not initialized'); } const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator(); + const endpoint = this.options.agent.endpoint; + const { req, res } = this.options; /** @type {import('@librechat/agents').ClientOptions} */ - const clientOptions = { + let clientOptions = { maxTokens: 75, }; - let endpointConfig = this.options.req.app.locals[this.options.agent.endpoint]; + let endpointConfig = req.app.locals[endpoint]; if (!endpointConfig) { try { - endpointConfig = await getCustomEndpointConfig(this.options.agent.endpoint); + endpointConfig = await getCustomEndpointConfig(endpoint); } catch (err) { logger.error( '[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config', @@ -953,6 +956,28 @@ class AgentClient extends BaseClient { ) { clientOptions.model = endpointConfig.titleModel; } + if ( + endpoint === EModelEndpoint.azureOpenAI && + clientOptions.model && + this.options.agent.model_parameters.model !== clientOptions.model + ) { + clientOptions = + ( + await initOpenAI({ + req, + res, + optionsOnly: true, + overrideModel: clientOptions.model, + overrideEndpoint: endpoint, + endpointOption: { + model_parameters: clientOptions, + }, + }) + )?.llmConfig ?? clientOptions; + } + if (/\b(o1|o3)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) { + delete clientOptions.maxTokens; + } try { const titleResult = await this.run.generateTitle({ inputText: text, diff --git a/api/server/services/Files/MistralOCR/crud.js b/api/server/services/Files/MistralOCR/crud.js index 689e4152ba..0c544b9eb4 100644 --- a/api/server/services/Files/MistralOCR/crud.js +++ b/api/server/services/Files/MistralOCR/crud.js @@ -69,16 +69,20 @@ async function getSignedUrl({ /** * @param {Object} params * @param {string} params.apiKey - * @param {string} params.documentUrl + * @param {string} params.url - The document or image URL + * @param {string} [params.documentType='document_url'] - 'document_url' or 'image_url' + * @param {string} [params.model] * @param {string} [params.baseURL] * @returns {Promise} */ async function performOCR({ apiKey, - documentUrl, + url, + documentType = 'document_url', model = 'mistral-ocr-latest', baseURL = 'https://api.mistral.ai/v1', }) { + const documentKey = documentType === 'image_url' ? 'image_url' : 'document_url'; return axios .post( `${baseURL}/ocr`, @@ -86,8 +90,8 @@ async function performOCR({ model, include_image_base64: false, document: { - type: 'document_url', - document_url: documentUrl, + type: documentType, + [documentKey]: url, }, }, { @@ -109,6 +113,19 @@ function extractVariableName(str) { return match ? match[1] : null; } +/** + * Uploads a file to the Mistral OCR API and processes the OCR result. + * + * @param {Object} params - The params object. + * @param {ServerRequest} params.req - The request object from Express. It should have a `user` property with an `id` + * representing the user + * @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should + * have a `mimetype` property that tells us the file type + * @param {string} params.file_id - The file ID. + * @param {string} [params.entity_id] - The entity ID, not used here but passed for consistency. + * @returns {Promise<{ filepath: string, bytes: number }>} - The result object containing the processed `text` and `images` (not currently used), + * along with the `filename` and `bytes` properties. + */ const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => { try { /** @type {TCustomConfig['ocr']} */ @@ -160,11 +177,18 @@ const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => { fileId: mistralFile.id, }); + const mimetype = (file.mimetype || '').toLowerCase(); + const originalname = file.originalname || ''; + const isImage = + mimetype.startsWith('image') || /\.(png|jpe?g|gif|bmp|webp|tiff?)$/i.test(originalname); + const documentType = isImage ? 'image_url' : 'document_url'; + const ocrResult = await performOCR({ apiKey, baseURL, model, - documentUrl: signedUrlResponse.url, + url: signedUrlResponse.url, + documentType, }); let aggregatedText = ''; diff --git a/api/server/services/Files/MistralOCR/crud.spec.js b/api/server/services/Files/MistralOCR/crud.spec.js index 6d0b321bbf..c3d2f46c40 100644 --- a/api/server/services/Files/MistralOCR/crud.spec.js +++ b/api/server/services/Files/MistralOCR/crud.spec.js @@ -172,7 +172,7 @@ describe('MistralOCR Service', () => { }); describe('performOCR', () => { - it('should perform OCR using Mistral API', async () => { + it('should perform OCR using Mistral API (document_url)', async () => { const mockResponse = { data: { pages: [{ markdown: 'Page 1 content' }, { markdown: 'Page 2 content' }], @@ -182,8 +182,9 @@ describe('MistralOCR Service', () => { const result = await performOCR({ apiKey: 'test-api-key', - documentUrl: 'https://document-url.com', + url: 'https://document-url.com', model: 'mistral-ocr-latest', + documentType: 'document_url', }); expect(mockAxios.post).toHaveBeenCalledWith( @@ -206,6 +207,41 @@ describe('MistralOCR Service', () => { expect(result).toEqual(mockResponse.data); }); + it('should perform OCR using Mistral API (image_url)', async () => { + const mockResponse = { + data: { + pages: [{ markdown: 'Image OCR content' }], + }, + }; + mockAxios.post.mockResolvedValueOnce(mockResponse); + + const result = await performOCR({ + apiKey: 'test-api-key', + url: 'https://image-url.com/image.png', + model: 'mistral-ocr-latest', + documentType: 'image_url', + }); + + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://api.mistral.ai/v1/ocr', + { + model: 'mistral-ocr-latest', + include_image_base64: false, + document: { + type: 'image_url', + image_url: 'https://image-url.com/image.png', + }, + }, + { + headers: { + 'Content-Type': 'application/json', + Authorization: 'Bearer test-api-key', + }, + }, + ); + expect(result).toEqual(mockResponse.data); + }); + it('should handle errors during OCR processing', async () => { const errorMessage = 'OCR processing error'; mockAxios.post.mockRejectedValueOnce(new Error(errorMessage)); @@ -213,7 +249,7 @@ describe('MistralOCR Service', () => { await expect( performOCR({ apiKey: 'test-api-key', - documentUrl: 'https://document-url.com', + url: 'https://document-url.com', }), ).rejects.toThrow(); @@ -295,6 +331,7 @@ describe('MistralOCR Service', () => { const file = { path: '/tmp/upload/file.pdf', originalname: 'document.pdf', + mimetype: 'application/pdf', }; const result = await uploadMistralOCR({ @@ -322,6 +359,90 @@ describe('MistralOCR Service', () => { }); }); + it('should process OCR for an image file and use image_url type', async () => { + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + loadAuthValues.mockResolvedValue({ + OCR_API_KEY: 'test-api-key', + OCR_BASEURL: 'https://api.mistral.ai/v1', + }); + + // Mock file upload response + mockAxios.post.mockResolvedValueOnce({ + data: { id: 'file-456', purpose: 'ocr' }, + }); + + // Mock signed URL response + mockAxios.get.mockResolvedValueOnce({ + data: { url: 'https://signed-url.com/image.png' }, + }); + + // Mock OCR response for image + mockAxios.post.mockResolvedValueOnce({ + data: { + pages: [ + { + markdown: 'Image OCR result', + images: [{ image_base64: 'imgbase64' }], + }, + ], + }, + }); + + const req = { + user: { id: 'user456' }, + app: { + locals: { + ocr: { + apiKey: '${OCR_API_KEY}', + baseURL: '${OCR_BASEURL}', + mistralModel: 'mistral-medium', + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/image.png', + originalname: 'image.png', + mimetype: 'image/png', + }; + + const result = await uploadMistralOCR({ + req, + file, + file_id: 'file456', + entity_id: 'entity456', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/image.png'); + + expect(loadAuthValues).toHaveBeenCalledWith({ + userId: 'user456', + authFields: ['OCR_BASEURL', 'OCR_API_KEY'], + optional: expect.any(Set), + }); + + // Check that the OCR API was called with image_url type + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://api.mistral.ai/v1/ocr', + expect.objectContaining({ + document: expect.objectContaining({ + type: 'image_url', + image_url: 'https://signed-url.com/image.png', + }), + }), + expect.any(Object), + ); + + expect(result).toEqual({ + filename: 'image.png', + bytes: expect.any(Number), + filepath: 'mistral_ocr', + text: expect.stringContaining('Image OCR result'), + images: ['imgbase64'], + }); + }); + it('should process variable references in configuration', async () => { // Setup mocks with environment variables const { loadAuthValues } = require('~/server/services/Tools/credentials'); diff --git a/api/server/services/Files/process.js b/api/server/services/Files/process.js index 384955dabf..81a4f52855 100644 --- a/api/server/services/Files/process.js +++ b/api/server/services/Files/process.js @@ -520,7 +520,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => { throw new Error('OCR capability is not enabled for Agents'); } - const { handleFileUpload } = getStrategyFunctions( + const { handleFileUpload: uploadMistralOCR } = getStrategyFunctions( req.app.locals?.ocr?.strategy ?? FileSources.mistral_ocr, ); const { file_id, temp_file_id } = metadata; @@ -532,7 +532,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => { images, filename, filepath: ocrFileURL, - } = await handleFileUpload({ req, file, file_id, entity_id: agent_id, basePath }); + } = await uploadMistralOCR({ req, file, file_id, entity_id: agent_id, basePath }); const fileInfo = removeNullishValues({ text, @@ -540,7 +540,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => { file_id, temp_file_id, user: req.user.id, - type: file.mimetype, + type: 'text/plain', filepath: ocrFileURL, source: FileSources.text, filename: filename ?? file.originalname, diff --git a/client/src/hooks/SSE/useSSE.ts b/client/src/hooks/SSE/useSSE.ts index 4042479415..2e66bfdefe 100644 --- a/client/src/hooks/SSE/useSSE.ts +++ b/client/src/hooks/SSE/useSSE.ts @@ -128,7 +128,7 @@ export default function useSSE( return { ...prev, title, - conversationId: Constants.PENDING_CONVO as string, + conversationId: prev?.conversationId, }; }); let { payload } = payloadData;