📝 fix: Mistral OCR Image Support and Azure Agent Titles (#6901)

* fix: azure title model

* refactor: typing for uploadMistralOCR

* fix: update conversation ID handling in useSSE for better state management, only use PENDING_CONVO for new conversations

* fix: streamline conversation ID handling in useSSE for simplicity, only needs state update to prevent draft from applying

* fix: update performOCR and tests to support document and image URLs with appropriate types
This commit is contained in:
Danny Avila 2025-04-15 18:03:56 -04:00 committed by GitHub
parent 650e9b4f6c
commit d32f34e5d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 185 additions and 15 deletions

View file

@ -33,6 +33,7 @@ const { addCacheControl, createContextHandlers } = require('~/app/clients/prompt
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { getBufferString, HumanMessage } = require('@langchain/core/messages'); const { getBufferString, HumanMessage } = require('@langchain/core/messages');
const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
const Tokenizer = require('~/server/services/Tokenizer'); const Tokenizer = require('~/server/services/Tokenizer');
const BaseClient = require('~/app/clients/BaseClient'); const BaseClient = require('~/app/clients/BaseClient');
const { logger, sendEvent } = require('~/config'); const { logger, sendEvent } = require('~/config');
@ -931,14 +932,16 @@ class AgentClient extends BaseClient {
throw new Error('Run not initialized'); throw new Error('Run not initialized');
} }
const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator(); const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator();
const endpoint = this.options.agent.endpoint;
const { req, res } = this.options;
/** @type {import('@librechat/agents').ClientOptions} */ /** @type {import('@librechat/agents').ClientOptions} */
const clientOptions = { let clientOptions = {
maxTokens: 75, maxTokens: 75,
}; };
let endpointConfig = this.options.req.app.locals[this.options.agent.endpoint]; let endpointConfig = req.app.locals[endpoint];
if (!endpointConfig) { if (!endpointConfig) {
try { try {
endpointConfig = await getCustomEndpointConfig(this.options.agent.endpoint); endpointConfig = await getCustomEndpointConfig(endpoint);
} catch (err) { } catch (err) {
logger.error( logger.error(
'[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config', '[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config',
@ -953,6 +956,28 @@ class AgentClient extends BaseClient {
) { ) {
clientOptions.model = endpointConfig.titleModel; clientOptions.model = endpointConfig.titleModel;
} }
if (
endpoint === EModelEndpoint.azureOpenAI &&
clientOptions.model &&
this.options.agent.model_parameters.model !== clientOptions.model
) {
clientOptions =
(
await initOpenAI({
req,
res,
optionsOnly: true,
overrideModel: clientOptions.model,
overrideEndpoint: endpoint,
endpointOption: {
model_parameters: clientOptions,
},
})
)?.llmConfig ?? clientOptions;
}
if (/\b(o1|o3)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
delete clientOptions.maxTokens;
}
try { try {
const titleResult = await this.run.generateTitle({ const titleResult = await this.run.generateTitle({
inputText: text, inputText: text,

View file

@ -69,16 +69,20 @@ async function getSignedUrl({
/** /**
* @param {Object} params * @param {Object} params
* @param {string} params.apiKey * @param {string} params.apiKey
* @param {string} params.documentUrl * @param {string} params.url - The document or image URL
* @param {string} [params.documentType='document_url'] - 'document_url' or 'image_url'
* @param {string} [params.model]
* @param {string} [params.baseURL] * @param {string} [params.baseURL]
* @returns {Promise<OCRResult>} * @returns {Promise<OCRResult>}
*/ */
async function performOCR({ async function performOCR({
apiKey, apiKey,
documentUrl, url,
documentType = 'document_url',
model = 'mistral-ocr-latest', model = 'mistral-ocr-latest',
baseURL = 'https://api.mistral.ai/v1', baseURL = 'https://api.mistral.ai/v1',
}) { }) {
const documentKey = documentType === 'image_url' ? 'image_url' : 'document_url';
return axios return axios
.post( .post(
`${baseURL}/ocr`, `${baseURL}/ocr`,
@ -86,8 +90,8 @@ async function performOCR({
model, model,
include_image_base64: false, include_image_base64: false,
document: { document: {
type: 'document_url', type: documentType,
document_url: documentUrl, [documentKey]: url,
}, },
}, },
{ {
@ -109,6 +113,19 @@ function extractVariableName(str) {
return match ? match[1] : null; return match ? match[1] : null;
} }
/**
* Uploads a file to the Mistral OCR API and processes the OCR result.
*
* @param {Object} params - The params object.
* @param {ServerRequest} params.req - The request object from Express. It should have a `user` property with an `id`
* representing the user
* @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should
* have a `mimetype` property that tells us the file type
* @param {string} params.file_id - The file ID.
* @param {string} [params.entity_id] - The entity ID, not used here but passed for consistency.
* @returns {Promise<{ filepath: string, bytes: number }>} - The result object containing the processed `text` and `images` (not currently used),
* along with the `filename` and `bytes` properties.
*/
const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => { const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => {
try { try {
/** @type {TCustomConfig['ocr']} */ /** @type {TCustomConfig['ocr']} */
@ -160,11 +177,18 @@ const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => {
fileId: mistralFile.id, fileId: mistralFile.id,
}); });
const mimetype = (file.mimetype || '').toLowerCase();
const originalname = file.originalname || '';
const isImage =
mimetype.startsWith('image') || /\.(png|jpe?g|gif|bmp|webp|tiff?)$/i.test(originalname);
const documentType = isImage ? 'image_url' : 'document_url';
const ocrResult = await performOCR({ const ocrResult = await performOCR({
apiKey, apiKey,
baseURL, baseURL,
model, model,
documentUrl: signedUrlResponse.url, url: signedUrlResponse.url,
documentType,
}); });
let aggregatedText = ''; let aggregatedText = '';

View file

@ -172,7 +172,7 @@ describe('MistralOCR Service', () => {
}); });
describe('performOCR', () => { describe('performOCR', () => {
it('should perform OCR using Mistral API', async () => { it('should perform OCR using Mistral API (document_url)', async () => {
const mockResponse = { const mockResponse = {
data: { data: {
pages: [{ markdown: 'Page 1 content' }, { markdown: 'Page 2 content' }], pages: [{ markdown: 'Page 1 content' }, { markdown: 'Page 2 content' }],
@ -182,8 +182,9 @@ describe('MistralOCR Service', () => {
const result = await performOCR({ const result = await performOCR({
apiKey: 'test-api-key', apiKey: 'test-api-key',
documentUrl: 'https://document-url.com', url: 'https://document-url.com',
model: 'mistral-ocr-latest', model: 'mistral-ocr-latest',
documentType: 'document_url',
}); });
expect(mockAxios.post).toHaveBeenCalledWith( expect(mockAxios.post).toHaveBeenCalledWith(
@ -206,6 +207,41 @@ describe('MistralOCR Service', () => {
expect(result).toEqual(mockResponse.data); expect(result).toEqual(mockResponse.data);
}); });
it('should perform OCR using Mistral API (image_url)', async () => {
const mockResponse = {
data: {
pages: [{ markdown: 'Image OCR content' }],
},
};
mockAxios.post.mockResolvedValueOnce(mockResponse);
const result = await performOCR({
apiKey: 'test-api-key',
url: 'https://image-url.com/image.png',
model: 'mistral-ocr-latest',
documentType: 'image_url',
});
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/ocr',
{
model: 'mistral-ocr-latest',
include_image_base64: false,
document: {
type: 'image_url',
image_url: 'https://image-url.com/image.png',
},
},
{
headers: {
'Content-Type': 'application/json',
Authorization: 'Bearer test-api-key',
},
},
);
expect(result).toEqual(mockResponse.data);
});
it('should handle errors during OCR processing', async () => { it('should handle errors during OCR processing', async () => {
const errorMessage = 'OCR processing error'; const errorMessage = 'OCR processing error';
mockAxios.post.mockRejectedValueOnce(new Error(errorMessage)); mockAxios.post.mockRejectedValueOnce(new Error(errorMessage));
@ -213,7 +249,7 @@ describe('MistralOCR Service', () => {
await expect( await expect(
performOCR({ performOCR({
apiKey: 'test-api-key', apiKey: 'test-api-key',
documentUrl: 'https://document-url.com', url: 'https://document-url.com',
}), }),
).rejects.toThrow(); ).rejects.toThrow();
@ -295,6 +331,7 @@ describe('MistralOCR Service', () => {
const file = { const file = {
path: '/tmp/upload/file.pdf', path: '/tmp/upload/file.pdf',
originalname: 'document.pdf', originalname: 'document.pdf',
mimetype: 'application/pdf',
}; };
const result = await uploadMistralOCR({ const result = await uploadMistralOCR({
@ -322,6 +359,90 @@ describe('MistralOCR Service', () => {
}); });
}); });
it('should process OCR for an image file and use image_url type', async () => {
const { loadAuthValues } = require('~/server/services/Tools/credentials');
loadAuthValues.mockResolvedValue({
OCR_API_KEY: 'test-api-key',
OCR_BASEURL: 'https://api.mistral.ai/v1',
});
// Mock file upload response
mockAxios.post.mockResolvedValueOnce({
data: { id: 'file-456', purpose: 'ocr' },
});
// Mock signed URL response
mockAxios.get.mockResolvedValueOnce({
data: { url: 'https://signed-url.com/image.png' },
});
// Mock OCR response for image
mockAxios.post.mockResolvedValueOnce({
data: {
pages: [
{
markdown: 'Image OCR result',
images: [{ image_base64: 'imgbase64' }],
},
],
},
});
const req = {
user: { id: 'user456' },
app: {
locals: {
ocr: {
apiKey: '${OCR_API_KEY}',
baseURL: '${OCR_BASEURL}',
mistralModel: 'mistral-medium',
},
},
},
};
const file = {
path: '/tmp/upload/image.png',
originalname: 'image.png',
mimetype: 'image/png',
};
const result = await uploadMistralOCR({
req,
file,
file_id: 'file456',
entity_id: 'entity456',
});
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/image.png');
expect(loadAuthValues).toHaveBeenCalledWith({
userId: 'user456',
authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
optional: expect.any(Set),
});
// Check that the OCR API was called with image_url type
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/ocr',
expect.objectContaining({
document: expect.objectContaining({
type: 'image_url',
image_url: 'https://signed-url.com/image.png',
}),
}),
expect.any(Object),
);
expect(result).toEqual({
filename: 'image.png',
bytes: expect.any(Number),
filepath: 'mistral_ocr',
text: expect.stringContaining('Image OCR result'),
images: ['imgbase64'],
});
});
it('should process variable references in configuration', async () => { it('should process variable references in configuration', async () => {
// Setup mocks with environment variables // Setup mocks with environment variables
const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { loadAuthValues } = require('~/server/services/Tools/credentials');

View file

@ -520,7 +520,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
throw new Error('OCR capability is not enabled for Agents'); throw new Error('OCR capability is not enabled for Agents');
} }
const { handleFileUpload } = getStrategyFunctions( const { handleFileUpload: uploadMistralOCR } = getStrategyFunctions(
req.app.locals?.ocr?.strategy ?? FileSources.mistral_ocr, req.app.locals?.ocr?.strategy ?? FileSources.mistral_ocr,
); );
const { file_id, temp_file_id } = metadata; const { file_id, temp_file_id } = metadata;
@ -532,7 +532,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
images, images,
filename, filename,
filepath: ocrFileURL, filepath: ocrFileURL,
} = await handleFileUpload({ req, file, file_id, entity_id: agent_id, basePath }); } = await uploadMistralOCR({ req, file, file_id, entity_id: agent_id, basePath });
const fileInfo = removeNullishValues({ const fileInfo = removeNullishValues({
text, text,
@ -540,7 +540,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
file_id, file_id,
temp_file_id, temp_file_id,
user: req.user.id, user: req.user.id,
type: file.mimetype, type: 'text/plain',
filepath: ocrFileURL, filepath: ocrFileURL,
source: FileSources.text, source: FileSources.text,
filename: filename ?? file.originalname, filename: filename ?? file.originalname,

View file

@ -128,7 +128,7 @@ export default function useSSE(
return { return {
...prev, ...prev,
title, title,
conversationId: Constants.PENDING_CONVO as string, conversationId: prev?.conversationId,
}; };
}); });
let { payload } = payloadData; let { payload } = payloadData;