diff --git a/api/server/services/Files/strategies.js b/api/server/services/Files/strategies.js index 41dcd5518a..4f8067142b 100644 --- a/api/server/services/Files/strategies.js +++ b/api/server/services/Files/strategies.js @@ -1,5 +1,9 @@ const { FileSources } = require('librechat-data-provider'); -const { uploadMistralOCR, uploadAzureMistralOCR } = require('@librechat/api'); +const { + uploadMistralOCR, + uploadAzureMistralOCR, + uploadGoogleVertexMistralOCR, +} = require('@librechat/api'); const { getFirebaseURL, prepareImageURL, @@ -222,6 +226,26 @@ const azureMistralOCRStrategy = () => ({ handleFileUpload: uploadAzureMistralOCR, }); +const vertexMistralOCRStrategy = () => ({ + /** @type {typeof saveFileFromURL | null} */ + saveURL: null, + /** @type {typeof getLocalFileURL | null} */ + getFileURL: null, + /** @type {typeof saveLocalBuffer | null} */ + saveBuffer: null, + /** @type {typeof processLocalAvatar | null} */ + processAvatar: null, + /** @type {typeof uploadLocalImage | null} */ + handleImageUpload: null, + /** @type {typeof prepareImagesLocal | null} */ + prepareImagePayload: null, + /** @type {typeof deleteLocalFile | null} */ + deleteFile: null, + /** @type {typeof getLocalFileStream | null} */ + getDownloadStream: null, + handleFileUpload: uploadGoogleVertexMistralOCR, +}); + // Strategy Selector const getStrategyFunctions = (fileSource) => { if (fileSource === FileSources.firebase) { @@ -244,6 +268,8 @@ const getStrategyFunctions = (fileSource) => { return mistralOCRStrategy(); } else if (fileSource === FileSources.azure_mistral_ocr) { return azureMistralOCRStrategy(); + } else if (fileSource === FileSources.vertexai_mistral_ocr) { + return vertexMistralOCRStrategy(); } else { throw new Error('Invalid file source'); } diff --git a/packages/api/src/files/mistral/crud.ts b/packages/api/src/files/mistral/crud.ts index d89be8f14d..f3ad74b731 100644 --- a/packages/api/src/files/mistral/crud.ts +++ b/packages/api/src/files/mistral/crud.ts @@ -32,6 
+32,13 @@ interface AuthConfig { baseURL: string; } +/** Helper type for Google service account */ +interface GoogleServiceAccount { + client_email?: string; + private_key?: string; + project_id?: string; +} + /** Helper type for OCR request context */ interface OCRContext { req: Pick & { @@ -424,3 +431,214 @@ export const uploadAzureMistralOCR = async ( throw createOCRError(error, 'Error uploading document to Azure Mistral OCR API:'); } }; + +/** + * Loads Google service account configuration + */ +async function loadGoogleAuthConfig(): Promise<{ + serviceAccount: GoogleServiceAccount; + accessToken: string; +}> { + /** Path from current file to project root auth.json */ + const authJsonPath = path.join(__dirname, '..', '..', '..', 'api', 'data', 'auth.json'); + + let serviceKey: GoogleServiceAccount; + try { + const authJsonContent = fs.readFileSync(authJsonPath, 'utf8'); + serviceKey = JSON.parse(authJsonContent) as GoogleServiceAccount; + } catch { + throw new Error(`Google service account not found at ${authJsonPath}`); + } + + if (!serviceKey.client_email || !serviceKey.private_key || !serviceKey.project_id) { + throw new Error('Invalid Google service account configuration'); + } + + const jwt = await createJWT(serviceKey); + const accessToken = await exchangeJWTForAccessToken(jwt); + + return { + serviceAccount: serviceKey, + accessToken, + }; +} + +/** + * Creates a JWT token manually + */ +async function createJWT(serviceKey: GoogleServiceAccount): Promise<string> { + const crypto = await import('crypto'); + + const header = { + alg: 'RS256', + typ: 'JWT', + }; + + const now = Math.floor(Date.now() / 1000); + const payload = { + iss: serviceKey.client_email, + scope: 'https://www.googleapis.com/auth/cloud-platform', + aud: 'https://oauth2.googleapis.com/token', + exp: now + 3600, + iat: now, + }; + + const encodedHeader = Buffer.from(JSON.stringify(header)).toString('base64url'); + const encodedPayload = Buffer.from(JSON.stringify(payload)).toString('base64url'); 
+ + const signatureInput = `${encodedHeader}.${encodedPayload}`; + + const sign = crypto.createSign('RSA-SHA256'); + sign.update(signatureInput); + sign.end(); + + const signature = sign.sign(serviceKey.private_key!, 'base64url'); + + return `${signatureInput}.${signature}`; +} + +/** + * Exchanges JWT for access token + */ +async function exchangeJWTForAccessToken(jwt: string): Promise<string> { + const response = await axios.post( + 'https://oauth2.googleapis.com/token', + new URLSearchParams({ + grant_type: 'urn:ietf:params:oauth:grant-type:jwt-bearer', + assertion: jwt, + }), + { + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + }, + ); + + if (!response.data?.access_token) { + throw new Error('No access token in response'); + } + + return response.data.access_token; +} + +/** + * Performs OCR using Google Vertex AI + */ +async function performGoogleVertexOCR({ + url, + accessToken, + projectId, + model, + documentType = 'document_url', +}: { + url: string; + accessToken: string; + projectId: string; + model: string; + documentType?: 'document_url' | 'image_url'; +}): Promise<OCRResult> { + const location = process.env.GOOGLE_LOC || 'us-central1'; + const modelId = model || 'mistral-ocr-2505'; + + let baseURL: string; + if (location === 'global') { + baseURL = `https://aiplatform.googleapis.com/v1/projects/${projectId}/locations/global/publishers/mistralai/models/${modelId}:rawPredict`; + } else { + baseURL = `https://${location}-aiplatform.googleapis.com/v1/projects/${projectId}/locations/${location}/publishers/mistralai/models/${modelId}:rawPredict`; + } + + const documentKey = documentType === 'image_url' ? 
'image_url' : 'document_url'; + + const requestBody = { + model: modelId, + document: { + type: documentType, + [documentKey]: url, + }, + include_image_base64: true, + }; + + logger.debug('Sending request to Google Vertex AI:', { + url: baseURL, + body: { + ...requestBody, + document: { ...requestBody.document, [documentKey]: 'base64_data_hidden' }, + }, + }); + + return axios + .post(baseURL, requestBody, { + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${accessToken}`, + Accept: 'application/json', + }, + }) + .then((res) => { + logger.debug('Google Vertex AI response received'); + return res.data; + }) + .catch((error) => { + if (error.response?.data) { + logger.error('Vertex AI error response: ' + JSON.stringify(error.response.data, null, 2)); + } + throw new Error( + logAxiosError({ + error: error as AxiosError, + message: 'Error calling Google Vertex AI Mistral OCR', + }), + ); + }); +} + +/** + * Use Google Vertex AI Mistral OCR API to process the OCR result. + * + * @param params - The params object. + * @param params.req - The request object from Express. It should have a `user` property with an `id` + * representing the user + * @param params.file - The file object, which is part of the request. The file object should + * have a `mimetype` property that tells us the file type + * @param params.loadAuthValues - Function to load authentication values + * @returns - The result object containing the processed `text` and `images` (not currently used), + * along with the `filename` and `bytes` properties. 
+ */ +export const uploadGoogleVertexMistralOCR = async ( + context: OCRContext, +): Promise<OCRUploadResult> => { + try { + const { serviceAccount, accessToken } = await loadGoogleAuthConfig(); + const model = getModelConfig(context.req.app.locals?.ocr); + + const buffer = fs.readFileSync(context.file.path); + const base64 = buffer.toString('base64'); + const base64Prefix = `data:${context.file.mimetype || 'application/pdf'};base64,`; + + const documentType = getDocumentType(context.file); + const ocrResult = await performGoogleVertexOCR({ + url: `${base64Prefix}${base64}`, + accessToken, + projectId: serviceAccount.project_id!, + model, + documentType, + }); + + if (!ocrResult || !ocrResult.pages || ocrResult.pages.length === 0) { + throw new Error( + 'No OCR result returned from service, may be down or the file is not supported.', + ); + } + + const { text, images } = processOCRResult(ocrResult); + + return { + filename: context.file.originalname, + bytes: text.length * 4, + filepath: FileSources.vertexai_mistral_ocr as string, + text, + images, + }; + } catch (error) { + throw createOCRError(error, 'Error uploading document to Google Vertex AI Mistral OCR:'); + } +}; diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index 4d154f4958..cf69603bf1 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -615,6 +615,7 @@ export enum OCRStrategy { MISTRAL_OCR = 'mistral_ocr', CUSTOM_OCR = 'custom_ocr', AZURE_MISTRAL_OCR = 'azure_mistral_ocr', + VERTEXAI_MISTRAL_OCR = 'vertexai_mistral_ocr', } export enum SearchCategories { diff --git a/packages/data-provider/src/types/files.ts b/packages/data-provider/src/types/files.ts index 95b74a4216..fd60278053 100644 --- a/packages/data-provider/src/types/files.ts +++ b/packages/data-provider/src/types/files.ts @@ -11,6 +11,7 @@ export enum FileSources { execute_code = 'execute_code', mistral_ocr = 'mistral_ocr', azure_mistral_ocr = 'azure_mistral_ocr', + 
vertexai_mistral_ocr = 'vertexai_mistral_ocr', text = 'text', }