mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-20 02:10:15 +01:00
Merge branch 'main' into feat/user-groups
This commit is contained in:
commit
c696b935b8
108 changed files with 7249 additions and 1732 deletions
|
|
@ -5,6 +5,7 @@ const {
|
|||
isAgentsEndpoint,
|
||||
isParamEndpoint,
|
||||
EModelEndpoint,
|
||||
ContentTypes,
|
||||
excludedKeys,
|
||||
ErrorTypes,
|
||||
Constants,
|
||||
|
|
@ -365,17 +366,14 @@ class BaseClient {
|
|||
* context: TMessage[],
|
||||
* remainingContextTokens: number,
|
||||
* messagesToRefine: TMessage[],
|
||||
* summaryIndex: number,
|
||||
* }>} An object with four properties: `context`, `summaryIndex`, `remainingContextTokens`, and `messagesToRefine`.
|
||||
* }>} An object with three properties: `context`, `remainingContextTokens`, and `messagesToRefine`.
|
||||
* `context` is an array of messages that fit within the token limit.
|
||||
* `summaryIndex` is the index of the first message in the `messagesToRefine` array.
|
||||
* `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context.
|
||||
* `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit.
|
||||
*/
|
||||
async getMessagesWithinTokenLimit({ messages: _messages, maxContextTokens, instructions }) {
|
||||
// Every reply is primed with <|start|>assistant<|message|>, so we
|
||||
// start with 3 tokens for the label after all messages have been counted.
|
||||
let summaryIndex = -1;
|
||||
let currentTokenCount = 3;
|
||||
const instructionsTokenCount = instructions?.tokenCount ?? 0;
|
||||
let remainingContextTokens =
|
||||
|
|
@ -408,14 +406,12 @@ class BaseClient {
|
|||
}
|
||||
|
||||
const prunedMemory = messages;
|
||||
summaryIndex = prunedMemory.length - 1;
|
||||
remainingContextTokens -= currentTokenCount;
|
||||
|
||||
return {
|
||||
context: context.reverse(),
|
||||
remainingContextTokens,
|
||||
messagesToRefine: prunedMemory,
|
||||
summaryIndex,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -458,7 +454,7 @@ class BaseClient {
|
|||
|
||||
let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);
|
||||
|
||||
let { context, remainingContextTokens, messagesToRefine, summaryIndex } =
|
||||
let { context, remainingContextTokens, messagesToRefine } =
|
||||
await this.getMessagesWithinTokenLimit({
|
||||
messages: orderedWithInstructions,
|
||||
instructions,
|
||||
|
|
@ -528,7 +524,7 @@ class BaseClient {
|
|||
}
|
||||
|
||||
// Make sure to only continue summarization logic if the summary message was generated
|
||||
shouldSummarize = summaryMessage && shouldSummarize;
|
||||
shouldSummarize = summaryMessage != null && shouldSummarize === true;
|
||||
|
||||
logger.debug('[BaseClient] Context Count (2/2)', {
|
||||
remainingContextTokens,
|
||||
|
|
@ -538,17 +534,18 @@ class BaseClient {
|
|||
/** @type {Record<string, number> | undefined} */
|
||||
let tokenCountMap;
|
||||
if (buildTokenMap) {
|
||||
tokenCountMap = orderedWithInstructions.reduce((map, message, index) => {
|
||||
const currentPayload = shouldSummarize ? orderedWithInstructions : context;
|
||||
tokenCountMap = currentPayload.reduce((map, message, index) => {
|
||||
const { messageId } = message;
|
||||
if (!messageId) {
|
||||
return map;
|
||||
}
|
||||
|
||||
if (shouldSummarize && index === summaryIndex && !usePrevSummary) {
|
||||
if (shouldSummarize && index === messagesToRefine.length - 1 && !usePrevSummary) {
|
||||
map.summaryMessage = { ...summaryMessage, messageId, tokenCount: summaryTokenCount };
|
||||
}
|
||||
|
||||
map[messageId] = orderedWithInstructions[index].tokenCount;
|
||||
map[messageId] = currentPayload[index].tokenCount;
|
||||
return map;
|
||||
}, {});
|
||||
}
|
||||
|
|
@ -1021,11 +1018,17 @@ class BaseClient {
|
|||
const processValue = (value) => {
|
||||
if (Array.isArray(value)) {
|
||||
for (let item of value) {
|
||||
if (!item || !item.type || item.type === 'image_url') {
|
||||
if (
|
||||
!item ||
|
||||
!item.type ||
|
||||
item.type === ContentTypes.THINK ||
|
||||
item.type === ContentTypes.ERROR ||
|
||||
item.type === ContentTypes.IMAGE_URL
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (item.type === 'tool_call' && item.tool_call != null) {
|
||||
if (item.type === ContentTypes.TOOL_CALL && item.tool_call != null) {
|
||||
const toolName = item.tool_call?.name || '';
|
||||
if (toolName != null && toolName && typeof toolName === 'string') {
|
||||
numTokens += this.getTokenCount(toolName);
|
||||
|
|
@ -1121,9 +1124,13 @@ class BaseClient {
|
|||
return message;
|
||||
}
|
||||
|
||||
const files = await getFiles({
|
||||
file_id: { $in: fileIds },
|
||||
});
|
||||
const files = await getFiles(
|
||||
{
|
||||
file_id: { $in: fileIds },
|
||||
},
|
||||
{},
|
||||
{},
|
||||
);
|
||||
|
||||
await this.addImageURLs(message, files, this.visionMode);
|
||||
|
||||
|
|
|
|||
|
|
@ -1272,6 +1272,29 @@ ${convo}
|
|||
});
|
||||
}
|
||||
|
||||
/** Note: OpenAI Web Search models do not support any known parameters besdies `max_tokens` */
|
||||
if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) {
|
||||
const searchExcludeParams = [
|
||||
'frequency_penalty',
|
||||
'presence_penalty',
|
||||
'temperature',
|
||||
'top_p',
|
||||
'top_k',
|
||||
'stop',
|
||||
'logit_bias',
|
||||
'seed',
|
||||
'response_format',
|
||||
'n',
|
||||
'logprobs',
|
||||
'user',
|
||||
];
|
||||
|
||||
this.options.dropParams = this.options.dropParams || [];
|
||||
this.options.dropParams = [
|
||||
...new Set([...this.options.dropParams, ...searchExcludeParams]),
|
||||
];
|
||||
}
|
||||
|
||||
if (this.options.dropParams && Array.isArray(this.options.dropParams)) {
|
||||
this.options.dropParams.forEach((param) => {
|
||||
delete modelOptions[param];
|
||||
|
|
|
|||
|
|
@ -211,7 +211,7 @@ const formatAgentMessages = (payload) => {
|
|||
} else if (part.type === ContentTypes.THINK) {
|
||||
hasReasoning = true;
|
||||
continue;
|
||||
} else if (part.type === ContentTypes.ERROR) {
|
||||
} else if (part.type === ContentTypes.ERROR || part.type === ContentTypes.AGENT_UPDATE) {
|
||||
continue;
|
||||
} else {
|
||||
currentContent.push(part);
|
||||
|
|
|
|||
|
|
@ -164,7 +164,7 @@ describe('BaseClient', () => {
|
|||
const result = await TestClient.getMessagesWithinTokenLimit({ messages });
|
||||
|
||||
expect(result.context).toEqual(expectedContext);
|
||||
expect(result.summaryIndex).toEqual(expectedIndex);
|
||||
expect(result.messagesToRefine.length - 1).toEqual(expectedIndex);
|
||||
expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens);
|
||||
expect(result.messagesToRefine).toEqual(expectedMessagesToRefine);
|
||||
});
|
||||
|
|
@ -200,7 +200,7 @@ describe('BaseClient', () => {
|
|||
const result = await TestClient.getMessagesWithinTokenLimit({ messages });
|
||||
|
||||
expect(result.context).toEqual(expectedContext);
|
||||
expect(result.summaryIndex).toEqual(expectedIndex);
|
||||
expect(result.messagesToRefine.length - 1).toEqual(expectedIndex);
|
||||
expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens);
|
||||
expect(result.messagesToRefine).toEqual(expectedMessagesToRefine);
|
||||
});
|
||||
|
|
|
|||
|
|
@ -172,7 +172,7 @@ Error Message: ${error.message}`);
|
|||
{
|
||||
type: ContentTypes.IMAGE_URL,
|
||||
image_url: {
|
||||
url: `data:image/jpeg;base64,${base64}`,
|
||||
url: `data:image/png;base64,${base64}`,
|
||||
},
|
||||
},
|
||||
];
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ const {
|
|||
} = require('../');
|
||||
const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
|
||||
const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { createMCPTool } = require('~/server/services/MCP');
|
||||
const { loadSpecs } = require('./loadSpecs');
|
||||
const { logger } = require('~/config');
|
||||
|
|
@ -90,45 +91,6 @@ const validateTools = async (user, tools = []) => {
|
|||
}
|
||||
};
|
||||
|
||||
const loadAuthValues = async ({ userId, authFields, throwError = true }) => {
|
||||
let authValues = {};
|
||||
|
||||
/**
|
||||
* Finds the first non-empty value for the given authentication field, supporting alternate fields.
|
||||
* @param {string[]} fields Array of strings representing the authentication fields. Supports alternate fields delimited by "||".
|
||||
* @returns {Promise<{ authField: string, authValue: string} | null>} An object containing the authentication field and value, or null if not found.
|
||||
*/
|
||||
const findAuthValue = async (fields) => {
|
||||
for (const field of fields) {
|
||||
let value = process.env[field];
|
||||
if (value) {
|
||||
return { authField: field, authValue: value };
|
||||
}
|
||||
try {
|
||||
value = await getUserPluginAuthValue(userId, field, throwError);
|
||||
} catch (err) {
|
||||
if (field === fields[fields.length - 1] && !value) {
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
if (value) {
|
||||
return { authField: field, authValue: value };
|
||||
}
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
for (let authField of authFields) {
|
||||
const fields = authField.split('||');
|
||||
const result = await findAuthValue(fields);
|
||||
if (result) {
|
||||
authValues[result.authField] = result.authValue;
|
||||
}
|
||||
}
|
||||
|
||||
return authValues;
|
||||
};
|
||||
|
||||
/** @typedef {typeof import('@langchain/core/tools').Tool} ToolConstructor */
|
||||
/** @typedef {import('@langchain/core/tools').Tool} Tool */
|
||||
|
||||
|
|
@ -348,7 +310,6 @@ const loadTools = async ({
|
|||
|
||||
module.exports = {
|
||||
loadToolWithAuth,
|
||||
loadAuthValues,
|
||||
validateTools,
|
||||
loadTools,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
const { validateTools, loadTools, loadAuthValues } = require('./handleTools');
|
||||
const { validateTools, loadTools } = require('./handleTools');
|
||||
const handleOpenAIErrors = require('./handleOpenAIErrors');
|
||||
|
||||
module.exports = {
|
||||
handleOpenAIErrors,
|
||||
loadAuthValues,
|
||||
validateTools,
|
||||
loadTools,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
const axios = require('axios');
|
||||
const { EventSource } = require('eventsource');
|
||||
const { Time, CacheKeys } = require('librechat-data-provider');
|
||||
const logger = require('./winston');
|
||||
|
|
@ -47,9 +48,46 @@ const sendEvent = (res, event) => {
|
|||
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates and configures an Axios instance with optional proxy settings.
|
||||
*
|
||||
* @typedef {import('axios').AxiosInstance} AxiosInstance
|
||||
* @typedef {import('axios').AxiosProxyConfig} AxiosProxyConfig
|
||||
*
|
||||
* @returns {AxiosInstance} A configured Axios instance
|
||||
* @throws {Error} If there's an issue creating the Axios instance or parsing the proxy URL
|
||||
*/
|
||||
function createAxiosInstance() {
|
||||
const instance = axios.create();
|
||||
|
||||
if (process.env.proxy) {
|
||||
try {
|
||||
const url = new URL(process.env.proxy);
|
||||
|
||||
/** @type {AxiosProxyConfig} */
|
||||
const proxyConfig = {
|
||||
host: url.hostname.replace(/^\[|\]$/g, ''),
|
||||
protocol: url.protocol.replace(':', ''),
|
||||
};
|
||||
|
||||
if (url.port) {
|
||||
proxyConfig.port = parseInt(url.port, 10);
|
||||
}
|
||||
|
||||
instance.defaults.proxy = proxyConfig;
|
||||
} catch (error) {
|
||||
console.error('Error parsing proxy URL:', error);
|
||||
throw new Error(`Invalid proxy URL: ${process.env.proxy}`);
|
||||
}
|
||||
}
|
||||
|
||||
return instance;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
logger,
|
||||
sendEvent,
|
||||
getMCPManager,
|
||||
createAxiosInstance,
|
||||
getFlowStateManager,
|
||||
};
|
||||
|
|
|
|||
126
api/config/index.spec.js
Normal file
126
api/config/index.spec.js
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
const axios = require('axios');
|
||||
const { createAxiosInstance } = require('./index');
|
||||
|
||||
// Mock axios
|
||||
jest.mock('axios', () => ({
|
||||
interceptors: {
|
||||
request: { use: jest.fn(), eject: jest.fn() },
|
||||
response: { use: jest.fn(), eject: jest.fn() },
|
||||
},
|
||||
create: jest.fn().mockReturnValue({
|
||||
defaults: {
|
||||
proxy: null,
|
||||
},
|
||||
get: jest.fn().mockResolvedValue({ data: {} }),
|
||||
post: jest.fn().mockResolvedValue({ data: {} }),
|
||||
put: jest.fn().mockResolvedValue({ data: {} }),
|
||||
delete: jest.fn().mockResolvedValue({ data: {} }),
|
||||
}),
|
||||
get: jest.fn().mockResolvedValue({ data: {} }),
|
||||
post: jest.fn().mockResolvedValue({ data: {} }),
|
||||
put: jest.fn().mockResolvedValue({ data: {} }),
|
||||
delete: jest.fn().mockResolvedValue({ data: {} }),
|
||||
reset: jest.fn().mockImplementation(function () {
|
||||
this.get.mockClear();
|
||||
this.post.mockClear();
|
||||
this.put.mockClear();
|
||||
this.delete.mockClear();
|
||||
this.create.mockClear();
|
||||
}),
|
||||
}));
|
||||
|
||||
describe('createAxiosInstance', () => {
|
||||
const originalEnv = process.env;
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset mocks
|
||||
jest.clearAllMocks();
|
||||
// Create a clean copy of process.env
|
||||
process.env = { ...originalEnv };
|
||||
// Default: no proxy
|
||||
delete process.env.proxy;
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
// Restore original process.env
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
test('creates an axios instance without proxy when no proxy env is set', () => {
|
||||
const instance = createAxiosInstance();
|
||||
|
||||
expect(axios.create).toHaveBeenCalledTimes(1);
|
||||
expect(instance.defaults.proxy).toBeNull();
|
||||
});
|
||||
|
||||
test('configures proxy correctly with hostname and protocol', () => {
|
||||
process.env.proxy = 'http://example.com';
|
||||
|
||||
const instance = createAxiosInstance();
|
||||
|
||||
expect(axios.create).toHaveBeenCalledTimes(1);
|
||||
expect(instance.defaults.proxy).toEqual({
|
||||
host: 'example.com',
|
||||
protocol: 'http',
|
||||
});
|
||||
});
|
||||
|
||||
test('configures proxy correctly with hostname, protocol and port', () => {
|
||||
process.env.proxy = 'https://proxy.example.com:8080';
|
||||
|
||||
const instance = createAxiosInstance();
|
||||
|
||||
expect(axios.create).toHaveBeenCalledTimes(1);
|
||||
expect(instance.defaults.proxy).toEqual({
|
||||
host: 'proxy.example.com',
|
||||
protocol: 'https',
|
||||
port: 8080,
|
||||
});
|
||||
});
|
||||
|
||||
test('handles proxy URLs with authentication', () => {
|
||||
process.env.proxy = 'http://user:pass@proxy.example.com:3128';
|
||||
|
||||
const instance = createAxiosInstance();
|
||||
|
||||
expect(axios.create).toHaveBeenCalledTimes(1);
|
||||
expect(instance.defaults.proxy).toEqual({
|
||||
host: 'proxy.example.com',
|
||||
protocol: 'http',
|
||||
port: 3128,
|
||||
// Note: The current implementation doesn't handle auth - if needed, add this functionality
|
||||
});
|
||||
});
|
||||
|
||||
test('throws error when proxy URL is invalid', () => {
|
||||
process.env.proxy = 'invalid-url';
|
||||
|
||||
expect(() => createAxiosInstance()).toThrow('Invalid proxy URL');
|
||||
expect(axios.create).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
// If you want to test the actual URL parsing more thoroughly
|
||||
test('handles edge case proxy URLs correctly', () => {
|
||||
// IPv6 address
|
||||
process.env.proxy = 'http://[::1]:8080';
|
||||
|
||||
let instance = createAxiosInstance();
|
||||
|
||||
expect(instance.defaults.proxy).toEqual({
|
||||
host: '::1',
|
||||
protocol: 'http',
|
||||
port: 8080,
|
||||
});
|
||||
|
||||
// URL with path (which should be ignored for proxy config)
|
||||
process.env.proxy = 'http://proxy.example.com:8080/some/path';
|
||||
|
||||
instance = createAxiosInstance();
|
||||
|
||||
expect(instance.defaults.proxy).toEqual({
|
||||
host: 'proxy.example.com',
|
||||
protocol: 'http',
|
||||
port: 8080,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -28,4 +28,4 @@ const getBanner = async (user) => {
|
|||
}
|
||||
};
|
||||
|
||||
module.exports = { getBanner };
|
||||
module.exports = { Banner, getBanner };
|
||||
|
|
|
|||
|
|
@ -15,19 +15,6 @@ const searchConversation = async (conversationId) => {
|
|||
throw new Error('Error searching conversation');
|
||||
}
|
||||
};
|
||||
/**
|
||||
* Searches for a conversation by conversationId and returns associated file ids.
|
||||
* @param {string} conversationId - The conversation's ID.
|
||||
* @returns {Promise<string[] | null>}
|
||||
*/
|
||||
const getConvoFiles = async (conversationId) => {
|
||||
try {
|
||||
return (await Conversation.findOne({ conversationId }, 'files').lean())?.files ?? [];
|
||||
} catch (error) {
|
||||
logger.error('[getConvoFiles] Error getting conversation files', error);
|
||||
throw new Error('Error getting conversation files');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves a single conversation for a given user and conversation ID.
|
||||
|
|
@ -73,6 +60,20 @@ const deleteNullOrEmptyConversations = async () => {
|
|||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Searches for a conversation by conversationId and returns associated file ids.
|
||||
* @param {string} conversationId - The conversation's ID.
|
||||
* @returns {Promise<string[] | null>}
|
||||
*/
|
||||
const getConvoFiles = async (conversationId) => {
|
||||
try {
|
||||
return (await Conversation.findOne({ conversationId }, 'files').lean())?.files ?? [];
|
||||
} catch (error) {
|
||||
logger.error('[getConvoFiles] Error getting conversation files', error);
|
||||
throw new Error('Error getting conversation files');
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
Conversation,
|
||||
getConvoFiles,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { fileSchema } = require('@librechat/data-schemas');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const File = mongoose.model('File', fileSchema);
|
||||
|
||||
|
|
@ -17,11 +18,39 @@ const findFileById = async (file_id, options = {}) => {
|
|||
* Retrieves files matching a given filter, sorted by the most recently updated.
|
||||
* @param {Object} filter - The filter criteria to apply.
|
||||
* @param {Object} [_sortOptions] - Optional sort parameters.
|
||||
* @param {Object|String} [selectFields={ text: 0 }] - Fields to include/exclude in the query results.
|
||||
* Default excludes the 'text' field.
|
||||
* @returns {Promise<Array<IMongoFile>>} A promise that resolves to an array of file documents.
|
||||
*/
|
||||
const getFiles = async (filter, _sortOptions) => {
|
||||
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => {
|
||||
const sortOptions = { updatedAt: -1, ..._sortOptions };
|
||||
return await File.find(filter).sort(sortOptions).lean();
|
||||
return await File.find(filter).select(selectFields).sort(sortOptions).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves tool files (files that are embedded or have a fileIdentifier) from an array of file IDs
|
||||
* @param {string[]} fileIds - Array of file_id strings to search for
|
||||
* @returns {Promise<Array<IMongoFile>>} Files that match the criteria
|
||||
*/
|
||||
const getToolFilesByIds = async (fileIds) => {
|
||||
if (!fileIds || !fileIds.length) {
|
||||
return [];
|
||||
}
|
||||
|
||||
try {
|
||||
const filter = {
|
||||
file_id: { $in: fileIds },
|
||||
$or: [{ embedded: true }, { 'metadata.fileIdentifier': { $exists: true } }],
|
||||
};
|
||||
|
||||
const selectFields = { text: 0 };
|
||||
const sortOptions = { updatedAt: -1 };
|
||||
|
||||
return await getFiles(filter, sortOptions, selectFields);
|
||||
} catch (error) {
|
||||
logger.error('[getToolFilesByIds] Error retrieving tool files:', error);
|
||||
throw new Error('Error retrieving tool files');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -109,6 +138,7 @@ module.exports = {
|
|||
File,
|
||||
findFileById,
|
||||
getFiles,
|
||||
getToolFilesByIds,
|
||||
createFile,
|
||||
updateFile,
|
||||
updateFileUsage,
|
||||
|
|
|
|||
|
|
@ -71,7 +71,42 @@ async function saveMessage(req, params, metadata) {
|
|||
} catch (err) {
|
||||
logger.error('Error saving message:', err);
|
||||
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
|
||||
throw err;
|
||||
|
||||
// Check if this is a duplicate key error (MongoDB error code 11000)
|
||||
if (err.code === 11000 && err.message.includes('duplicate key error')) {
|
||||
// Log the duplicate key error but don't crash the application
|
||||
logger.warn(`Duplicate messageId detected: ${params.messageId}. Continuing execution.`);
|
||||
|
||||
try {
|
||||
// Try to find the existing message with this ID
|
||||
const existingMessage = await Message.findOne({
|
||||
messageId: params.messageId,
|
||||
user: req.user.id,
|
||||
});
|
||||
|
||||
// If we found it, return it
|
||||
if (existingMessage) {
|
||||
return existingMessage.toObject();
|
||||
}
|
||||
|
||||
// If we can't find it (unlikely but possible in race conditions)
|
||||
return {
|
||||
...params,
|
||||
messageId: params.messageId,
|
||||
user: req.user.id,
|
||||
};
|
||||
} catch (findError) {
|
||||
// If the findOne also fails, log it but don't crash
|
||||
logger.warn(`Could not retrieve existing message with ID ${params.messageId}: ${findError.message}`);
|
||||
return {
|
||||
...params,
|
||||
messageId: params.messageId,
|
||||
user: req.user.id,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
throw err; // Re-throw other errors
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ const bedrockValues = {
|
|||
'amazon.nova-micro-v1:0': { prompt: 0.035, completion: 0.14 },
|
||||
'amazon.nova-lite-v1:0': { prompt: 0.06, completion: 0.24 },
|
||||
'amazon.nova-pro-v1:0': { prompt: 0.8, completion: 3.2 },
|
||||
'deepseek.r1': { prompt: 1.35, completion: 5.4 },
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -288,7 +288,7 @@ describe('AWS Bedrock Model Tests', () => {
|
|||
});
|
||||
|
||||
describe('Deepseek Model Tests', () => {
|
||||
const deepseekModels = ['deepseek-chat', 'deepseek-coder', 'deepseek-reasoner'];
|
||||
const deepseekModels = ['deepseek-chat', 'deepseek-coder', 'deepseek-reasoner', 'deepseek.r1'];
|
||||
|
||||
it('should return the correct prompt multipliers for all models', () => {
|
||||
const results = deepseekModels.map((model) => {
|
||||
|
|
|
|||
|
|
@ -35,6 +35,8 @@
|
|||
"homepage": "https://librechat.ai",
|
||||
"dependencies": {
|
||||
"@anthropic-ai/sdk": "^0.37.0",
|
||||
"@aws-sdk/client-s3": "^3.758.0",
|
||||
"@aws-sdk/s3-request-presigner": "^3.758.0",
|
||||
"@azure/search-documents": "^12.0.0",
|
||||
"@google/generative-ai": "^0.23.0",
|
||||
"@googleapis/youtube": "^20.0.0",
|
||||
|
|
@ -42,10 +44,10 @@
|
|||
"@keyv/redis": "^2.8.1",
|
||||
"@langchain/community": "^0.3.34",
|
||||
"@langchain/core": "^0.3.40",
|
||||
"@langchain/google-genai": "^0.1.9",
|
||||
"@langchain/google-vertexai": "^0.2.0",
|
||||
"@langchain/google-genai": "^0.1.11",
|
||||
"@langchain/google-vertexai": "^0.2.2",
|
||||
"@langchain/textsplitters": "^0.1.0",
|
||||
"@librechat/agents": "^2.2.0",
|
||||
"@librechat/agents": "^2.2.8",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@waylaidwanderer/fetch-event-source": "^3.0.1",
|
||||
"axios": "^1.8.2",
|
||||
|
|
@ -82,7 +84,7 @@
|
|||
"memorystore": "^1.6.7",
|
||||
"mime": "^3.0.0",
|
||||
"module-alias": "^2.2.3",
|
||||
"mongoose": "^8.9.5",
|
||||
"mongoose": "^8.12.1",
|
||||
"multer": "^1.4.5-lts.1",
|
||||
"nanoid": "^3.3.7",
|
||||
"nodemailer": "^6.9.15",
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ const {
|
|||
ChatModelStreamHandler,
|
||||
} = require('@librechat/agents');
|
||||
const { processCodeOutput } = require('~/server/services/Files/Code/process');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { saveBase64Image } = require('~/server/services/Files/process');
|
||||
const { loadAuthValues } = require('~/app/clients/tools/util');
|
||||
const { logger, sendEvent } = require('~/config');
|
||||
|
||||
/** @typedef {import('@librechat/agents').Graph} Graph */
|
||||
|
|
|
|||
|
|
@ -7,7 +7,16 @@
|
|||
// validateVisionModel,
|
||||
// mapModelToAzureConfig,
|
||||
// } = require('librechat-data-provider');
|
||||
const { Callback, createMetadataAggregator } = require('@librechat/agents');
|
||||
require('events').EventEmitter.defaultMaxListeners = 100;
|
||||
const {
|
||||
Callback,
|
||||
GraphEvents,
|
||||
formatMessage,
|
||||
formatAgentMessages,
|
||||
formatContentStrings,
|
||||
getTokenCountForMessage,
|
||||
createMetadataAggregator,
|
||||
} = require('@librechat/agents');
|
||||
const {
|
||||
Constants,
|
||||
VisionModes,
|
||||
|
|
@ -17,24 +26,19 @@ const {
|
|||
KnownEndpoints,
|
||||
anthropicSchema,
|
||||
isAgentsEndpoint,
|
||||
AgentCapabilities,
|
||||
bedrockInputSchema,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
formatMessage,
|
||||
addCacheControl,
|
||||
formatAgentMessages,
|
||||
formatContentStrings,
|
||||
createContextHandlers,
|
||||
} = require('~/app/clients/prompts');
|
||||
const { getCustomEndpointConfig, checkCapability } = require('~/server/services/Config');
|
||||
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
||||
const Tokenizer = require('~/server/services/Tokenizer');
|
||||
const BaseClient = require('~/app/clients/BaseClient');
|
||||
const { logger, sendEvent } = require('~/config');
|
||||
const { createRun } = require('./run');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/** @typedef {import('@librechat/agents').MessageContentComplex} MessageContentComplex */
|
||||
/** @typedef {import('@langchain/core/runnables').RunnableConfig} RunnableConfig */
|
||||
|
|
@ -99,6 +103,8 @@ class AgentClient extends BaseClient {
|
|||
this.outputTokensKey = 'output_tokens';
|
||||
/** @type {UsageMetadata} */
|
||||
this.usage;
|
||||
/** @type {Record<string, number>} */
|
||||
this.indexTokenCountMap = {};
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -223,14 +229,23 @@ class AgentClient extends BaseClient {
|
|||
};
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {TMessage} message
|
||||
* @param {Array<MongoFile>} attachments
|
||||
* @returns {Promise<Array<Partial<MongoFile>>>}
|
||||
*/
|
||||
async addImageURLs(message, attachments) {
|
||||
const { files, image_urls } = await encodeAndFormat(
|
||||
const { files, text, image_urls } = await encodeAndFormat(
|
||||
this.options.req,
|
||||
attachments,
|
||||
this.options.agent.provider,
|
||||
VisionModes.agents,
|
||||
);
|
||||
message.image_urls = image_urls.length ? image_urls : undefined;
|
||||
if (text && text.length) {
|
||||
message.ocr = text;
|
||||
}
|
||||
return files;
|
||||
}
|
||||
|
||||
|
|
@ -308,7 +323,21 @@ class AgentClient extends BaseClient {
|
|||
assistantName: this.options?.modelLabel,
|
||||
});
|
||||
|
||||
const needsTokenCount = this.contextStrategy && !orderedMessages[i].tokenCount;
|
||||
if (message.ocr && i !== orderedMessages.length - 1) {
|
||||
if (typeof formattedMessage.content === 'string') {
|
||||
formattedMessage.content = message.ocr + '\n' + formattedMessage.content;
|
||||
} else {
|
||||
const textPart = formattedMessage.content.find((part) => part.type === 'text');
|
||||
textPart
|
||||
? (textPart.text = message.ocr + '\n' + textPart.text)
|
||||
: formattedMessage.content.unshift({ type: 'text', text: message.ocr });
|
||||
}
|
||||
} else if (message.ocr && i === orderedMessages.length - 1) {
|
||||
systemContent = [systemContent, message.ocr].join('\n');
|
||||
}
|
||||
|
||||
const needsTokenCount =
|
||||
(this.contextStrategy && !orderedMessages[i].tokenCount) || message.ocr;
|
||||
|
||||
/* If tokens were never counted, or, is a Vision request and the message has files, count again */
|
||||
if (needsTokenCount || (this.isVisionModel && (message.image_urls || message.files))) {
|
||||
|
|
@ -354,6 +383,10 @@ class AgentClient extends BaseClient {
|
|||
}));
|
||||
}
|
||||
|
||||
for (let i = 0; i < messages.length; i++) {
|
||||
this.indexTokenCountMap[i] = messages[i].tokenCount;
|
||||
}
|
||||
|
||||
const result = {
|
||||
tokenCountMap,
|
||||
prompt: payload,
|
||||
|
|
@ -599,6 +632,9 @@ class AgentClient extends BaseClient {
|
|||
// });
|
||||
// }
|
||||
|
||||
/** @type {TCustomConfig['endpoints']['agents']} */
|
||||
const agentsEConfig = this.options.req.app.locals[EModelEndpoint.agents];
|
||||
|
||||
/** @type {Partial<RunnableConfig> & { version: 'v1' | 'v2'; run_id?: string; streamMode: string }} */
|
||||
const config = {
|
||||
configurable: {
|
||||
|
|
@ -606,19 +642,30 @@ class AgentClient extends BaseClient {
|
|||
last_agent_index: this.agentConfigs?.size ?? 0,
|
||||
hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
|
||||
},
|
||||
recursionLimit: this.options.req.app.locals[EModelEndpoint.agents]?.recursionLimit,
|
||||
recursionLimit: agentsEConfig?.recursionLimit,
|
||||
signal: abortController.signal,
|
||||
streamMode: 'values',
|
||||
version: 'v2',
|
||||
};
|
||||
|
||||
const initialMessages = formatAgentMessages(payload);
|
||||
const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name));
|
||||
let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages(
|
||||
payload,
|
||||
this.indexTokenCountMap,
|
||||
toolSet,
|
||||
);
|
||||
if (legacyContentEndpoints.has(this.options.agent.endpoint)) {
|
||||
formatContentStrings(initialMessages);
|
||||
initialMessages = formatContentStrings(initialMessages);
|
||||
}
|
||||
|
||||
/** @type {ReturnType<createRun>} */
|
||||
let run;
|
||||
const countTokens = ((text) => this.getTokenCount(text)).bind(this);
|
||||
|
||||
/** @type {(message: BaseMessage) => number} */
|
||||
const tokenCounter = (message) => {
|
||||
return getTokenCountForMessage(message, countTokens);
|
||||
};
|
||||
|
||||
/**
|
||||
*
|
||||
|
|
@ -626,12 +673,23 @@ class AgentClient extends BaseClient {
|
|||
* @param {BaseMessage[]} messages
|
||||
* @param {number} [i]
|
||||
* @param {TMessageContentParts[]} [contentData]
|
||||
* @param {Record<string, number>} [currentIndexCountMap]
|
||||
*/
|
||||
const runAgent = async (agent, _messages, i = 0, contentData = []) => {
|
||||
const runAgent = async (agent, _messages, i = 0, contentData = [], _currentIndexCountMap) => {
|
||||
config.configurable.model = agent.model_parameters.model;
|
||||
const currentIndexCountMap = _currentIndexCountMap ?? indexTokenCountMap;
|
||||
if (i > 0) {
|
||||
this.model = agent.model_parameters.model;
|
||||
}
|
||||
if (agent.recursion_limit && typeof agent.recursion_limit === 'number') {
|
||||
config.recursionLimit = agent.recursion_limit;
|
||||
}
|
||||
if (
|
||||
agentsEConfig?.maxRecursionLimit &&
|
||||
config.recursionLimit > agentsEConfig?.maxRecursionLimit
|
||||
) {
|
||||
config.recursionLimit = agentsEConfig?.maxRecursionLimit;
|
||||
}
|
||||
config.configurable.agent_id = agent.id;
|
||||
config.configurable.name = agent.name;
|
||||
config.configurable.agent_index = i;
|
||||
|
|
@ -694,11 +752,29 @@ class AgentClient extends BaseClient {
|
|||
}
|
||||
|
||||
if (contentData.length) {
|
||||
const agentUpdate = {
|
||||
type: ContentTypes.AGENT_UPDATE,
|
||||
[ContentTypes.AGENT_UPDATE]: {
|
||||
index: contentData.length,
|
||||
runId: this.responseMessageId,
|
||||
agentId: agent.id,
|
||||
},
|
||||
};
|
||||
const streamData = {
|
||||
event: GraphEvents.ON_AGENT_UPDATE,
|
||||
data: agentUpdate,
|
||||
};
|
||||
this.options.aggregateContent(streamData);
|
||||
sendEvent(this.options.res, streamData);
|
||||
contentData.push(agentUpdate);
|
||||
run.Graph.contentData = contentData;
|
||||
}
|
||||
|
||||
await run.processStream({ messages }, config, {
|
||||
keepContent: i !== 0,
|
||||
tokenCounter,
|
||||
indexTokenCountMap: currentIndexCountMap,
|
||||
maxContextTokens: agent.maxContextTokens,
|
||||
callbacks: {
|
||||
[Callback.TOOL_ERROR]: (graph, error, toolId) => {
|
||||
logger.error(
|
||||
|
|
@ -712,9 +788,13 @@ class AgentClient extends BaseClient {
|
|||
};
|
||||
|
||||
await runAgent(this.options.agent, initialMessages);
|
||||
|
||||
let finalContentStart = 0;
|
||||
if (this.agentConfigs && this.agentConfigs.size > 0) {
|
||||
if (
|
||||
this.agentConfigs &&
|
||||
this.agentConfigs.size > 0 &&
|
||||
(await checkCapability(this.options.req, AgentCapabilities.chain))
|
||||
) {
|
||||
const windowSize = 5;
|
||||
let latestMessage = initialMessages.pop().content;
|
||||
if (typeof latestMessage !== 'string') {
|
||||
latestMessage = latestMessage[0].text;
|
||||
|
|
@ -722,7 +802,16 @@ class AgentClient extends BaseClient {
|
|||
let i = 1;
|
||||
let runMessages = [];
|
||||
|
||||
const lastFiveMessages = initialMessages.slice(-5);
|
||||
const windowIndexCountMap = {};
|
||||
const windowMessages = initialMessages.slice(-windowSize);
|
||||
let currentIndex = 4;
|
||||
for (let i = initialMessages.length - 1; i >= 0; i--) {
|
||||
windowIndexCountMap[currentIndex] = indexTokenCountMap[i];
|
||||
currentIndex--;
|
||||
if (currentIndex < 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (const [agentId, agent] of this.agentConfigs) {
|
||||
if (abortController.signal.aborted === true) {
|
||||
break;
|
||||
|
|
@ -757,7 +846,9 @@ class AgentClient extends BaseClient {
|
|||
}
|
||||
try {
|
||||
const contextMessages = [];
|
||||
for (const message of lastFiveMessages) {
|
||||
const runIndexCountMap = {};
|
||||
for (let i = 0; i < windowMessages.length; i++) {
|
||||
const message = windowMessages[i];
|
||||
const messageType = message._getType();
|
||||
if (
|
||||
(!agent.tools || agent.tools.length === 0) &&
|
||||
|
|
@ -765,11 +856,13 @@ class AgentClient extends BaseClient {
|
|||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
runIndexCountMap[contextMessages.length] = windowIndexCountMap[i];
|
||||
contextMessages.push(message);
|
||||
}
|
||||
const currentMessages = [...contextMessages, new HumanMessage(bufferString)];
|
||||
await runAgent(agent, currentMessages, i, contentData);
|
||||
const bufferMessage = new HumanMessage(bufferString);
|
||||
runIndexCountMap[contextMessages.length] = tokenCounter(bufferMessage);
|
||||
const currentMessages = [...contextMessages, bufferMessage];
|
||||
await runAgent(agent, currentMessages, i, contentData, runIndexCountMap);
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`[api/server/controllers/agents/client.js #chatCompletion] Error running agent ${agentId} (${i})`,
|
||||
|
|
@ -780,6 +873,7 @@ class AgentClient extends BaseClient {
|
|||
}
|
||||
}
|
||||
|
||||
/** Note: not implemented */
|
||||
if (config.configurable.hide_sequential_outputs !== true) {
|
||||
finalContentStart = 0;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
const fs = require('fs').promises;
|
||||
const { nanoid } = require('nanoid');
|
||||
const {
|
||||
FileContext,
|
||||
Constants,
|
||||
Tools,
|
||||
Constants,
|
||||
FileContext,
|
||||
SystemRoles,
|
||||
EToolResources,
|
||||
actionDelimiter,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
|
|
@ -203,14 +204,21 @@ const duplicateAgentHandler = async (req, res) => {
|
|||
}
|
||||
|
||||
const {
|
||||
_id: __id,
|
||||
id: _id,
|
||||
_id: __id,
|
||||
author: _author,
|
||||
createdAt: _createdAt,
|
||||
updatedAt: _updatedAt,
|
||||
tool_resources: _tool_resources = {},
|
||||
...cloneData
|
||||
} = agent;
|
||||
|
||||
if (_tool_resources?.[EToolResources.ocr]) {
|
||||
cloneData.tool_resources = {
|
||||
[EToolResources.ocr]: _tool_resources[EToolResources.ocr],
|
||||
};
|
||||
}
|
||||
|
||||
const newAgentId = `agent_${nanoid()}`;
|
||||
const newAgentData = Object.assign(cloneData, {
|
||||
id: newAgentId,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ const {
|
|||
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
|
||||
const { processCodeOutput } = require('~/server/services/Files/Code/process');
|
||||
const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall');
|
||||
const { loadAuthValues, loadTools } = require('~/app/clients/tools/util');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { loadTools } = require('~/app/clients/tools/util');
|
||||
const { checkAccess } = require('~/server/middleware');
|
||||
const { getMessage } = require('~/models/Message');
|
||||
const { logger } = require('~/config');
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ const openAI = require('~/server/services/Endpoints/openAI');
|
|||
const agents = require('~/server/services/Endpoints/agents');
|
||||
const custom = require('~/server/services/Endpoints/custom');
|
||||
const google = require('~/server/services/Endpoints/google');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { handleError } = require('~/server/utils');
|
||||
|
||||
const buildFunction = {
|
||||
|
|
@ -87,16 +86,8 @@ async function buildEndpointOption(req, res, next) {
|
|||
|
||||
// TODO: use `getModelsConfig` only when necessary
|
||||
const modelsConfig = await getModelsConfig(req);
|
||||
const { resendFiles = true } = req.body.endpointOption;
|
||||
req.body.endpointOption.modelsConfig = modelsConfig;
|
||||
if (isAgents && resendFiles && req.body.conversationId) {
|
||||
const fileIds = await getConvoFiles(req.body.conversationId);
|
||||
const requestFiles = req.body.files ?? [];
|
||||
if (requestFiles.length || fileIds.length) {
|
||||
req.body.endpointOption.attachments = processFiles(requestFiles, fileIds);
|
||||
}
|
||||
} else if (req.body.files) {
|
||||
// hold the promise
|
||||
if (req.body.files && !isAgents) {
|
||||
req.body.endpointOption.attachments = processFiles(req.body.files);
|
||||
}
|
||||
next();
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ const {
|
|||
} = require('~/server/services/Files/process');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
|
||||
const { loadAuthValues } = require('~/app/clients/tools/util');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
|
|
|
|||
|
|
@ -161,9 +161,9 @@ async function createActionTool({
|
|||
|
||||
if (metadata.auth && metadata.auth.type !== AuthTypeEnum.None) {
|
||||
try {
|
||||
const action_id = action.action_id;
|
||||
const identifier = `${req.user.id}:${action.action_id}`;
|
||||
if (metadata.auth.type === AuthTypeEnum.OAuth && metadata.auth.authorization_url) {
|
||||
const action_id = action.action_id;
|
||||
const identifier = `${req.user.id}:${action.action_id}`;
|
||||
const requestLogin = async () => {
|
||||
const { args: _args, stepId, ...toolCall } = config.toolCall ?? {};
|
||||
if (!stepId) {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,14 @@
|
|||
const { FileSources, EModelEndpoint, getConfigDefaults } = require('librechat-data-provider');
|
||||
const {
|
||||
FileSources,
|
||||
EModelEndpoint,
|
||||
loadOCRConfig,
|
||||
processMCPEnv,
|
||||
getConfigDefaults,
|
||||
} = require('librechat-data-provider');
|
||||
const { checkVariables, checkHealth, checkConfig, checkAzureVariables } = require('./start/checks');
|
||||
const { azureAssistantsDefaults, assistantsConfigSetup } = require('./start/assistants');
|
||||
const { initializeFirebase } = require('./Files/Firebase/initialize');
|
||||
const { initializeS3 } = require('./Files/S3/initialize');
|
||||
const loadCustomConfig = require('./Config/loadCustomConfig');
|
||||
const handleRateLimits = require('./Config/handleRateLimits');
|
||||
const { loadDefaultInterface } = require('./start/interface');
|
||||
|
|
@ -25,6 +32,7 @@ const AppService = async (app) => {
|
|||
const config = (await loadCustomConfig()) ?? {};
|
||||
const configDefaults = getConfigDefaults();
|
||||
|
||||
const ocr = loadOCRConfig(config.ocr);
|
||||
const filteredTools = config.filteredTools;
|
||||
const includedTools = config.includedTools;
|
||||
const fileStrategy = config.fileStrategy ?? configDefaults.fileStrategy;
|
||||
|
|
@ -37,6 +45,8 @@ const AppService = async (app) => {
|
|||
|
||||
if (fileStrategy === FileSources.firebase) {
|
||||
initializeFirebase();
|
||||
} else if (fileStrategy === FileSources.s3) {
|
||||
initializeS3();
|
||||
}
|
||||
|
||||
/** @type {Record<string, FunctionTool} */
|
||||
|
|
@ -48,7 +58,7 @@ const AppService = async (app) => {
|
|||
|
||||
if (config.mcpServers != null) {
|
||||
const mcpManager = await getMCPManager();
|
||||
await mcpManager.initializeMCP(config.mcpServers);
|
||||
await mcpManager.initializeMCP(config.mcpServers, processMCPEnv);
|
||||
await mcpManager.mapAvailableTools(availableTools);
|
||||
}
|
||||
|
||||
|
|
@ -57,6 +67,7 @@ const AppService = async (app) => {
|
|||
const interfaceConfig = await loadDefaultInterface(config, configDefaults);
|
||||
|
||||
const defaultLocals = {
|
||||
ocr,
|
||||
paths,
|
||||
fileStrategy,
|
||||
socialLogins,
|
||||
|
|
|
|||
|
|
@ -120,6 +120,7 @@ describe('AppService', () => {
|
|||
},
|
||||
},
|
||||
paths: expect.anything(),
|
||||
ocr: expect.anything(),
|
||||
imageOutputType: expect.any(String),
|
||||
fileConfig: undefined,
|
||||
secureImageLinks: undefined,
|
||||
|
|
@ -588,4 +589,33 @@ describe('AppService updating app.locals and issuing warnings', () => {
|
|||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('should not parse environment variable references in OCR config', async () => {
|
||||
// Mock custom configuration with env variable references in OCR config
|
||||
const mockConfig = {
|
||||
ocr: {
|
||||
apiKey: '${OCR_API_KEY_CUSTOM_VAR_NAME}',
|
||||
baseURL: '${OCR_BASEURL_CUSTOM_VAR_NAME}',
|
||||
strategy: 'mistral_ocr',
|
||||
mistralModel: 'mistral-medium',
|
||||
},
|
||||
};
|
||||
|
||||
require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve(mockConfig));
|
||||
|
||||
// Set actual environment variables with different values
|
||||
process.env.OCR_API_KEY_CUSTOM_VAR_NAME = 'actual-api-key';
|
||||
process.env.OCR_BASEURL_CUSTOM_VAR_NAME = 'https://actual-ocr-url.com';
|
||||
|
||||
// Initialize app
|
||||
const app = { locals: {} };
|
||||
await AppService(app);
|
||||
|
||||
// Verify that the raw string references were preserved and not interpolated
|
||||
expect(app.locals.ocr).toBeDefined();
|
||||
expect(app.locals.ocr.apiKey).toEqual('${OCR_API_KEY_CUSTOM_VAR_NAME}');
|
||||
expect(app.locals.ocr.baseURL).toEqual('${OCR_BASEURL_CUSTOM_VAR_NAME}');
|
||||
expect(app.locals.ocr.strategy).toEqual('mistral_ocr');
|
||||
expect(app.locals.ocr.mistralModel).toEqual('mistral-medium');
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -72,4 +72,15 @@ async function getEndpointsConfig(req) {
|
|||
return endpointsConfig;
|
||||
}
|
||||
|
||||
module.exports = { getEndpointsConfig };
|
||||
/**
|
||||
* @param {ServerRequest} req
|
||||
* @param {import('librechat-data-provider').AgentCapabilities} capability
|
||||
* @returns {Promise<boolean>}
|
||||
*/
|
||||
const checkCapability = async (req, capability) => {
|
||||
const endpointsConfig = await getEndpointsConfig(req);
|
||||
const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
|
||||
return capabilities.includes(capability);
|
||||
};
|
||||
|
||||
module.exports = { getEndpointsConfig, checkCapability };
|
||||
|
|
|
|||
|
|
@ -2,15 +2,8 @@ const { loadAgent } = require('~/models/Agent');
|
|||
const { logger } = require('~/config');
|
||||
|
||||
const buildOptions = (req, endpoint, parsedBody) => {
|
||||
const {
|
||||
spec,
|
||||
iconURL,
|
||||
agent_id,
|
||||
instructions,
|
||||
maxContextTokens,
|
||||
resendFiles = true,
|
||||
...model_parameters
|
||||
} = parsedBody;
|
||||
const { spec, iconURL, agent_id, instructions, maxContextTokens, ...model_parameters } =
|
||||
parsedBody;
|
||||
const agentPromise = loadAgent({
|
||||
req,
|
||||
agent_id,
|
||||
|
|
@ -24,7 +17,6 @@ const buildOptions = (req, endpoint, parsedBody) => {
|
|||
iconURL,
|
||||
endpoint,
|
||||
agent_id,
|
||||
resendFiles,
|
||||
instructions,
|
||||
maxContextTokens,
|
||||
model_parameters,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ const { createContentAggregator, Providers } = require('@librechat/agents');
|
|||
const {
|
||||
EModelEndpoint,
|
||||
getResponseSender,
|
||||
AgentCapabilities,
|
||||
providerEndpointMap,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
|
|
@ -15,10 +16,14 @@ const initCustom = require('~/server/services/Endpoints/custom/initialize');
|
|||
const initGoogle = require('~/server/services/Endpoints/google/initialize');
|
||||
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
|
||||
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
||||
const { processFiles } = require('~/server/services/Files/process');
|
||||
const { loadAgentTools } = require('~/server/services/ToolService');
|
||||
const AgentClient = require('~/server/controllers/agents/client');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { getToolFilesByIds } = require('~/models/File');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const providerConfigMap = {
|
||||
|
|
@ -34,20 +39,38 @@ const providerConfigMap = {
|
|||
};
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {ServerRequest} req
|
||||
* @param {Promise<Array<MongoFile | null>> | undefined} _attachments
|
||||
* @param {AgentToolResources | undefined} _tool_resources
|
||||
* @returns {Promise<{ attachments: Array<MongoFile | undefined> | undefined, tool_resources: AgentToolResources | undefined }>}
|
||||
*/
|
||||
const primeResources = async (_attachments, _tool_resources) => {
|
||||
const primeResources = async (req, _attachments, _tool_resources) => {
|
||||
try {
|
||||
/** @type {Array<MongoFile | undefined> | undefined} */
|
||||
let attachments;
|
||||
const tool_resources = _tool_resources ?? {};
|
||||
const isOCREnabled = (req.app.locals?.[EModelEndpoint.agents]?.capabilities ?? []).includes(
|
||||
AgentCapabilities.ocr,
|
||||
);
|
||||
if (tool_resources.ocr?.file_ids && isOCREnabled) {
|
||||
const context = await getFiles(
|
||||
{
|
||||
file_id: { $in: tool_resources.ocr.file_ids },
|
||||
},
|
||||
{},
|
||||
{},
|
||||
);
|
||||
attachments = (attachments ?? []).concat(context);
|
||||
}
|
||||
if (!_attachments) {
|
||||
return { attachments: undefined, tool_resources: _tool_resources };
|
||||
return { attachments, tool_resources };
|
||||
}
|
||||
/** @type {Array<MongoFile | undefined> | undefined} */
|
||||
const files = await _attachments;
|
||||
const attachments = [];
|
||||
const tool_resources = _tool_resources ?? {};
|
||||
if (!attachments) {
|
||||
/** @type {Array<MongoFile | undefined>} */
|
||||
attachments = [];
|
||||
}
|
||||
|
||||
for (const file of files) {
|
||||
if (!file) {
|
||||
|
|
@ -82,7 +105,6 @@ const primeResources = async (_attachments, _tool_resources) => {
|
|||
* @param {ServerResponse} params.res
|
||||
* @param {Agent} params.agent
|
||||
* @param {object} [params.endpointOption]
|
||||
* @param {AgentToolResources} [params.tool_resources]
|
||||
* @param {boolean} [params.isInitialAgent]
|
||||
* @returns {Promise<Agent>}
|
||||
*/
|
||||
|
|
@ -91,9 +113,30 @@ const initializeAgentOptions = async ({
|
|||
res,
|
||||
agent,
|
||||
endpointOption,
|
||||
tool_resources,
|
||||
isInitialAgent = false,
|
||||
}) => {
|
||||
let currentFiles;
|
||||
/** @type {Array<MongoFile>} */
|
||||
const requestFiles = req.body.files ?? [];
|
||||
if (
|
||||
isInitialAgent &&
|
||||
req.body.conversationId != null &&
|
||||
(agent.model_parameters?.resendFiles ?? true) === true
|
||||
) {
|
||||
const fileIds = (await getConvoFiles(req.body.conversationId)) ?? [];
|
||||
const toolFiles = await getToolFilesByIds(fileIds);
|
||||
if (requestFiles.length || toolFiles.length) {
|
||||
currentFiles = await processFiles(requestFiles.concat(toolFiles));
|
||||
}
|
||||
} else if (isInitialAgent && requestFiles.length) {
|
||||
currentFiles = await processFiles(requestFiles);
|
||||
}
|
||||
|
||||
const { attachments, tool_resources } = await primeResources(
|
||||
req,
|
||||
currentFiles,
|
||||
agent.tool_resources,
|
||||
);
|
||||
const { tools, toolContextMap } = await loadAgentTools({
|
||||
req,
|
||||
res,
|
||||
|
|
@ -138,6 +181,7 @@ const initializeAgentOptions = async ({
|
|||
agent.provider = options.provider;
|
||||
}
|
||||
|
||||
/** @type {import('@librechat/agents').ClientOptions} */
|
||||
agent.model_parameters = Object.assign(model_parameters, options.llmConfig);
|
||||
if (options.configOptions) {
|
||||
agent.model_parameters.configuration = options.configOptions;
|
||||
|
|
@ -156,15 +200,16 @@ const initializeAgentOptions = async ({
|
|||
|
||||
const tokensModel =
|
||||
agent.provider === EModelEndpoint.azureOpenAI ? agent.model : agent.model_parameters.model;
|
||||
const maxTokens = agent.model_parameters.maxOutputTokens ?? agent.model_parameters.maxTokens ?? 0;
|
||||
|
||||
return {
|
||||
...agent,
|
||||
tools,
|
||||
attachments,
|
||||
toolContextMap,
|
||||
maxContextTokens:
|
||||
agent.max_context_tokens ??
|
||||
getModelMaxTokens(tokensModel, providerEndpointMap[provider]) ??
|
||||
4000,
|
||||
((getModelMaxTokens(tokensModel, providerEndpointMap[provider]) ?? 4000) - maxTokens) * 0.9,
|
||||
};
|
||||
};
|
||||
|
||||
|
|
@ -197,11 +242,6 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
throw new Error('Agent not found');
|
||||
}
|
||||
|
||||
const { attachments, tool_resources } = await primeResources(
|
||||
endpointOption.attachments,
|
||||
primaryAgent.tool_resources,
|
||||
);
|
||||
|
||||
const agentConfigs = new Map();
|
||||
|
||||
// Handle primary agent
|
||||
|
|
@ -210,7 +250,6 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
res,
|
||||
agent: primaryAgent,
|
||||
endpointOption,
|
||||
tool_resources,
|
||||
isInitialAgent: true,
|
||||
});
|
||||
|
||||
|
|
@ -240,18 +279,21 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
|
||||
const client = new AgentClient({
|
||||
req,
|
||||
agent: primaryConfig,
|
||||
res,
|
||||
sender,
|
||||
attachments,
|
||||
contentParts,
|
||||
agentConfigs,
|
||||
eventHandlers,
|
||||
collectedUsage,
|
||||
aggregateContent,
|
||||
artifactPromises,
|
||||
agent: primaryConfig,
|
||||
spec: endpointOption.spec,
|
||||
iconURL: endpointOption.iconURL,
|
||||
agentConfigs,
|
||||
endpoint: EModelEndpoint.agents,
|
||||
attachments: primaryConfig.attachments,
|
||||
maxContextTokens: primaryConfig.maxContextTokens,
|
||||
resendFiles: primaryConfig.model_parameters?.resendFiles ?? true,
|
||||
});
|
||||
|
||||
return { client };
|
||||
|
|
|
|||
|
|
@ -23,8 +23,9 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
const agent = {
|
||||
id: EModelEndpoint.bedrock,
|
||||
name: endpointOption.name,
|
||||
instructions: endpointOption.promptPrefix,
|
||||
provider: EModelEndpoint.bedrock,
|
||||
endpoint: EModelEndpoint.bedrock,
|
||||
instructions: endpointOption.promptPrefix,
|
||||
model: endpointOption.model_parameters.model,
|
||||
model_parameters: endpointOption.model_parameters,
|
||||
};
|
||||
|
|
@ -54,6 +55,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
|
||||
const client = new AgentClient({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
sender,
|
||||
// tools,
|
||||
|
|
|
|||
|
|
@ -135,12 +135,9 @@ const initializeClient = async ({
|
|||
}
|
||||
|
||||
if (optionsOnly) {
|
||||
clientOptions = Object.assign(
|
||||
{
|
||||
modelOptions: endpointOption.model_parameters,
|
||||
},
|
||||
clientOptions,
|
||||
);
|
||||
const modelOptions = endpointOption.model_parameters;
|
||||
modelOptions.model = modelName;
|
||||
clientOptions = Object.assign({ modelOptions }, clientOptions);
|
||||
clientOptions.modelOptions.user = req.user.id;
|
||||
const options = getLLMConfig(apiKey, clientOptions);
|
||||
if (!clientOptions.streamRate) {
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ const { isEnabled } = require('~/server/utils');
|
|||
* @returns {Object} Configuration options for creating an LLM instance.
|
||||
*/
|
||||
function getLLMConfig(apiKey, options = {}, endpoint = null) {
|
||||
const {
|
||||
let {
|
||||
modelOptions = {},
|
||||
reverseProxyUrl,
|
||||
defaultQuery,
|
||||
|
|
@ -50,10 +50,32 @@ function getLLMConfig(apiKey, options = {}, endpoint = null) {
|
|||
if (addParams && typeof addParams === 'object') {
|
||||
Object.assign(llmConfig, addParams);
|
||||
}
|
||||
/** Note: OpenAI Web Search models do not support any known parameters besdies `max_tokens` */
|
||||
if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) {
|
||||
const searchExcludeParams = [
|
||||
'frequency_penalty',
|
||||
'presence_penalty',
|
||||
'temperature',
|
||||
'top_p',
|
||||
'top_k',
|
||||
'stop',
|
||||
'logit_bias',
|
||||
'seed',
|
||||
'response_format',
|
||||
'n',
|
||||
'logprobs',
|
||||
'user',
|
||||
];
|
||||
|
||||
dropParams = dropParams || [];
|
||||
dropParams = [...new Set([...dropParams, ...searchExcludeParams])];
|
||||
}
|
||||
|
||||
if (dropParams && Array.isArray(dropParams)) {
|
||||
dropParams.forEach((param) => {
|
||||
delete llmConfig[param];
|
||||
if (llmConfig[param]) {
|
||||
llmConfig[param] = undefined;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
const axios = require('axios');
|
||||
const FormData = require('form-data');
|
||||
const { getCodeBaseURL } = require('@librechat/agents');
|
||||
const { createAxiosInstance } = require('~/config');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
|
||||
const axios = createAxiosInstance();
|
||||
|
||||
const MAX_FILE_SIZE = 150 * 1024 * 1024;
|
||||
|
||||
/**
|
||||
|
|
@ -27,13 +29,6 @@ async function getCodeOutputDownloadStream(fileIdentifier, apiKey) {
|
|||
timeout: 15000,
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
options.proxy = {
|
||||
host: process.env.PROXY,
|
||||
protocol: process.env.PROXY.startsWith('https') ? 'https' : 'http',
|
||||
};
|
||||
}
|
||||
|
||||
const response = await axios(options);
|
||||
return response;
|
||||
} catch (error) {
|
||||
|
|
@ -79,13 +74,6 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = ''
|
|||
maxBodyLength: MAX_FILE_SIZE,
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
options.proxy = {
|
||||
host: process.env.PROXY,
|
||||
protocol: process.env.PROXY.startsWith('https') ? 'https' : 'http',
|
||||
};
|
||||
}
|
||||
|
||||
const response = await axios.post(`${baseURL}/upload`, form, options);
|
||||
|
||||
/** @type {{ message: string; session_id: string; files: Array<{ fileId: string; filename: string }> }} */
|
||||
|
|
|
|||
207
api/server/services/Files/MistralOCR/crud.js
Normal file
207
api/server/services/Files/MistralOCR/crud.js
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
// ~/server/services/Files/MistralOCR/crud.js
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const FormData = require('form-data');
|
||||
const { FileSources, envVarRegex, extractEnvVariable } = require('librechat-data-provider');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { logger, createAxiosInstance } = require('~/config');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
|
||||
const axios = createAxiosInstance();
|
||||
|
||||
/**
|
||||
* Uploads a document to Mistral API using file streaming to avoid loading the entire file into memory
|
||||
*
|
||||
* @param {Object} params Upload parameters
|
||||
* @param {string} params.filePath The path to the file on disk
|
||||
* @param {string} [params.fileName] Optional filename to use (defaults to the name from filePath)
|
||||
* @param {string} params.apiKey Mistral API key
|
||||
* @param {string} [params.baseURL=https://api.mistral.ai/v1] Mistral API base URL
|
||||
* @returns {Promise<Object>} The response from Mistral API
|
||||
*/
|
||||
async function uploadDocumentToMistral({
|
||||
filePath,
|
||||
fileName = '',
|
||||
apiKey,
|
||||
baseURL = 'https://api.mistral.ai/v1',
|
||||
}) {
|
||||
const form = new FormData();
|
||||
form.append('purpose', 'ocr');
|
||||
const actualFileName = fileName || path.basename(filePath);
|
||||
const fileStream = fs.createReadStream(filePath);
|
||||
form.append('file', fileStream, { filename: actualFileName });
|
||||
|
||||
return axios
|
||||
.post(`${baseURL}/files`, form, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
...form.getHeaders(),
|
||||
},
|
||||
maxBodyLength: Infinity,
|
||||
maxContentLength: Infinity,
|
||||
})
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
logger.error('Error uploading document to Mistral:', error.message);
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
|
||||
async function getSignedUrl({
|
||||
apiKey,
|
||||
fileId,
|
||||
expiry = 24,
|
||||
baseURL = 'https://api.mistral.ai/v1',
|
||||
}) {
|
||||
return axios
|
||||
.get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
})
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
logger.error('Error fetching signed URL:', error.message);
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {Object} params
|
||||
* @param {string} params.apiKey
|
||||
* @param {string} params.documentUrl
|
||||
* @param {string} [params.baseURL]
|
||||
* @returns {Promise<OCRResult>}
|
||||
*/
|
||||
async function performOCR({
|
||||
apiKey,
|
||||
documentUrl,
|
||||
model = 'mistral-ocr-latest',
|
||||
baseURL = 'https://api.mistral.ai/v1',
|
||||
}) {
|
||||
return axios
|
||||
.post(
|
||||
`${baseURL}/ocr`,
|
||||
{
|
||||
model,
|
||||
include_image_base64: false,
|
||||
document: {
|
||||
type: 'document_url',
|
||||
document_url: documentUrl,
|
||||
},
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
},
|
||||
)
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
logger.error('Error performing OCR:', error.message);
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
|
||||
function extractVariableName(str) {
|
||||
const match = str.match(envVarRegex);
|
||||
return match ? match[1] : null;
|
||||
}
|
||||
|
||||
const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => {
|
||||
try {
|
||||
/** @type {TCustomConfig['ocr']} */
|
||||
const ocrConfig = req.app.locals?.ocr;
|
||||
|
||||
const apiKeyConfig = ocrConfig.apiKey || '';
|
||||
const baseURLConfig = ocrConfig.baseURL || '';
|
||||
|
||||
const isApiKeyEnvVar = envVarRegex.test(apiKeyConfig);
|
||||
const isBaseURLEnvVar = envVarRegex.test(baseURLConfig);
|
||||
|
||||
const isApiKeyEmpty = !apiKeyConfig.trim();
|
||||
const isBaseURLEmpty = !baseURLConfig.trim();
|
||||
|
||||
let apiKey, baseURL;
|
||||
|
||||
if (isApiKeyEnvVar || isBaseURLEnvVar || isApiKeyEmpty || isBaseURLEmpty) {
|
||||
const apiKeyVarName = isApiKeyEnvVar ? extractVariableName(apiKeyConfig) : 'OCR_API_KEY';
|
||||
const baseURLVarName = isBaseURLEnvVar ? extractVariableName(baseURLConfig) : 'OCR_BASEURL';
|
||||
|
||||
const authValues = await loadAuthValues({
|
||||
userId: req.user.id,
|
||||
authFields: [baseURLVarName, apiKeyVarName],
|
||||
optional: new Set([baseURLVarName]),
|
||||
});
|
||||
|
||||
apiKey = authValues[apiKeyVarName];
|
||||
baseURL = authValues[baseURLVarName];
|
||||
} else {
|
||||
apiKey = apiKeyConfig;
|
||||
baseURL = baseURLConfig;
|
||||
}
|
||||
|
||||
const mistralFile = await uploadDocumentToMistral({
|
||||
filePath: file.path,
|
||||
fileName: file.originalname,
|
||||
apiKey,
|
||||
baseURL,
|
||||
});
|
||||
|
||||
const modelConfig = ocrConfig.mistralModel || '';
|
||||
const model = envVarRegex.test(modelConfig)
|
||||
? extractEnvVariable(modelConfig)
|
||||
: modelConfig.trim() || 'mistral-ocr-latest';
|
||||
|
||||
const signedUrlResponse = await getSignedUrl({
|
||||
apiKey,
|
||||
baseURL,
|
||||
fileId: mistralFile.id,
|
||||
});
|
||||
|
||||
const ocrResult = await performOCR({
|
||||
apiKey,
|
||||
baseURL,
|
||||
model,
|
||||
documentUrl: signedUrlResponse.url,
|
||||
});
|
||||
|
||||
let aggregatedText = '';
|
||||
const images = [];
|
||||
ocrResult.pages.forEach((page, index) => {
|
||||
if (ocrResult.pages.length > 1) {
|
||||
aggregatedText += `# PAGE ${index + 1}\n`;
|
||||
}
|
||||
|
||||
aggregatedText += page.markdown + '\n\n';
|
||||
|
||||
if (page.images && page.images.length > 0) {
|
||||
page.images.forEach((image) => {
|
||||
if (image.image_base64) {
|
||||
images.push(image.image_base64);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
return {
|
||||
filename: file.originalname,
|
||||
bytes: aggregatedText.length * 4,
|
||||
filepath: FileSources.mistral_ocr,
|
||||
text: aggregatedText,
|
||||
images,
|
||||
};
|
||||
} catch (error) {
|
||||
const message = 'Error uploading document to Mistral OCR API';
|
||||
logAxiosError({ error, message });
|
||||
throw new Error(message);
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
uploadDocumentToMistral,
|
||||
uploadMistralOCR,
|
||||
getSignedUrl,
|
||||
performOCR,
|
||||
};
|
||||
737
api/server/services/Files/MistralOCR/crud.spec.js
Normal file
737
api/server/services/Files/MistralOCR/crud.spec.js
Normal file
|
|
@ -0,0 +1,737 @@
|
|||
const fs = require('fs');
|
||||
|
||||
const mockAxios = {
|
||||
interceptors: {
|
||||
request: { use: jest.fn(), eject: jest.fn() },
|
||||
response: { use: jest.fn(), eject: jest.fn() },
|
||||
},
|
||||
create: jest.fn().mockReturnValue({
|
||||
defaults: {
|
||||
proxy: null,
|
||||
},
|
||||
get: jest.fn().mockResolvedValue({ data: {} }),
|
||||
post: jest.fn().mockResolvedValue({ data: {} }),
|
||||
put: jest.fn().mockResolvedValue({ data: {} }),
|
||||
delete: jest.fn().mockResolvedValue({ data: {} }),
|
||||
}),
|
||||
get: jest.fn().mockResolvedValue({ data: {} }),
|
||||
post: jest.fn().mockResolvedValue({ data: {} }),
|
||||
put: jest.fn().mockResolvedValue({ data: {} }),
|
||||
delete: jest.fn().mockResolvedValue({ data: {} }),
|
||||
reset: jest.fn().mockImplementation(function () {
|
||||
this.get.mockClear();
|
||||
this.post.mockClear();
|
||||
this.put.mockClear();
|
||||
this.delete.mockClear();
|
||||
this.create.mockClear();
|
||||
}),
|
||||
};
|
||||
|
||||
jest.mock('axios', () => mockAxios);
|
||||
jest.mock('fs');
|
||||
jest.mock('~/utils', () => ({
|
||||
logAxiosError: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
},
|
||||
createAxiosInstance: () => mockAxios,
|
||||
}));
|
||||
jest.mock('~/server/services/Tools/credentials', () => ({
|
||||
loadAuthValues: jest.fn(),
|
||||
}));
|
||||
|
||||
const { uploadDocumentToMistral, uploadMistralOCR, getSignedUrl, performOCR } = require('./crud');
|
||||
|
||||
describe('MistralOCR Service', () => {
|
||||
afterEach(() => {
|
||||
mockAxios.reset();
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('uploadDocumentToMistral', () => {
|
||||
beforeEach(() => {
|
||||
// Create a more complete mock for file streams that FormData can work with
|
||||
const mockReadStream = {
|
||||
on: jest.fn().mockImplementation(function (event, handler) {
|
||||
// Simulate immediate 'end' event to make FormData complete processing
|
||||
if (event === 'end') {
|
||||
handler();
|
||||
}
|
||||
return this;
|
||||
}),
|
||||
pipe: jest.fn().mockImplementation(function () {
|
||||
return this;
|
||||
}),
|
||||
pause: jest.fn(),
|
||||
resume: jest.fn(),
|
||||
emit: jest.fn(),
|
||||
once: jest.fn(),
|
||||
destroy: jest.fn(),
|
||||
};
|
||||
|
||||
fs.createReadStream = jest.fn().mockReturnValue(mockReadStream);
|
||||
|
||||
// Mock FormData's append to avoid actual stream processing
|
||||
jest.mock('form-data', () => {
|
||||
const mockFormData = function () {
|
||||
return {
|
||||
append: jest.fn(),
|
||||
getHeaders: jest
|
||||
.fn()
|
||||
.mockReturnValue({ 'content-type': 'multipart/form-data; boundary=---boundary' }),
|
||||
getBuffer: jest.fn().mockReturnValue(Buffer.from('mock-form-data')),
|
||||
getLength: jest.fn().mockReturnValue(100),
|
||||
};
|
||||
};
|
||||
return mockFormData;
|
||||
});
|
||||
});
|
||||
|
||||
it('should upload a document to Mistral API using file streaming', async () => {
|
||||
const mockResponse = { data: { id: 'file-123', purpose: 'ocr' } };
|
||||
mockAxios.post.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
const result = await uploadDocumentToMistral({
|
||||
filePath: '/path/to/test.pdf',
|
||||
fileName: 'test.pdf',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
// Check that createReadStream was called with the correct file path
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/path/to/test.pdf');
|
||||
|
||||
// Since we're mocking FormData, we'll just check that axios was called correctly
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files',
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: 'Bearer test-api-key',
|
||||
}),
|
||||
maxBodyLength: Infinity,
|
||||
maxContentLength: Infinity,
|
||||
}),
|
||||
);
|
||||
expect(result).toEqual(mockResponse.data);
|
||||
});
|
||||
|
||||
it('should handle errors during document upload', async () => {
|
||||
const errorMessage = 'API error';
|
||||
mockAxios.post.mockRejectedValueOnce(new Error(errorMessage));
|
||||
|
||||
await expect(
|
||||
uploadDocumentToMistral({
|
||||
filePath: '/path/to/test.pdf',
|
||||
fileName: 'test.pdf',
|
||||
apiKey: 'test-api-key',
|
||||
}),
|
||||
).rejects.toThrow();
|
||||
|
||||
const { logger } = require('~/config');
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Error uploading document to Mistral:'),
|
||||
expect.any(String),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getSignedUrl', () => {
|
||||
it('should fetch signed URL from Mistral API', async () => {
|
||||
const mockResponse = { data: { url: 'https://document-url.com' } };
|
||||
mockAxios.get.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
const result = await getSignedUrl({
|
||||
fileId: 'file-123',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
expect(mockAxios.get).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files/file-123/url?expiry=24',
|
||||
{
|
||||
headers: {
|
||||
Authorization: 'Bearer test-api-key',
|
||||
},
|
||||
},
|
||||
);
|
||||
expect(result).toEqual(mockResponse.data);
|
||||
});
|
||||
|
||||
it('should handle errors when fetching signed URL', async () => {
|
||||
const errorMessage = 'API error';
|
||||
mockAxios.get.mockRejectedValueOnce(new Error(errorMessage));
|
||||
|
||||
await expect(
|
||||
getSignedUrl({
|
||||
fileId: 'file-123',
|
||||
apiKey: 'test-api-key',
|
||||
}),
|
||||
).rejects.toThrow();
|
||||
|
||||
const { logger } = require('~/config');
|
||||
expect(logger.error).toHaveBeenCalledWith('Error fetching signed URL:', errorMessage);
|
||||
});
|
||||
});
|
||||
|
||||
describe('performOCR', () => {
|
||||
it('should perform OCR using Mistral API', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
pages: [{ markdown: 'Page 1 content' }, { markdown: 'Page 2 content' }],
|
||||
},
|
||||
};
|
||||
mockAxios.post.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
const result = await performOCR({
|
||||
apiKey: 'test-api-key',
|
||||
documentUrl: 'https://document-url.com',
|
||||
model: 'mistral-ocr-latest',
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/ocr',
|
||||
{
|
||||
model: 'mistral-ocr-latest',
|
||||
include_image_base64: false,
|
||||
document: {
|
||||
type: 'document_url',
|
||||
document_url: 'https://document-url.com',
|
||||
},
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: 'Bearer test-api-key',
|
||||
},
|
||||
},
|
||||
);
|
||||
expect(result).toEqual(mockResponse.data);
|
||||
});
|
||||
|
||||
it('should handle errors during OCR processing', async () => {
|
||||
const errorMessage = 'OCR processing error';
|
||||
mockAxios.post.mockRejectedValueOnce(new Error(errorMessage));
|
||||
|
||||
await expect(
|
||||
performOCR({
|
||||
apiKey: 'test-api-key',
|
||||
documentUrl: 'https://document-url.com',
|
||||
}),
|
||||
).rejects.toThrow();
|
||||
|
||||
const { logger } = require('~/config');
|
||||
expect(logger.error).toHaveBeenCalledWith('Error performing OCR:', errorMessage);
|
||||
});
|
||||
});
|
||||
|
||||
describe('uploadMistralOCR', () => {
|
||||
beforeEach(() => {
|
||||
const mockReadStream = {
|
||||
on: jest.fn().mockImplementation(function (event, handler) {
|
||||
if (event === 'end') {
|
||||
handler();
|
||||
}
|
||||
return this;
|
||||
}),
|
||||
pipe: jest.fn().mockImplementation(function () {
|
||||
return this;
|
||||
}),
|
||||
pause: jest.fn(),
|
||||
resume: jest.fn(),
|
||||
emit: jest.fn(),
|
||||
once: jest.fn(),
|
||||
destroy: jest.fn(),
|
||||
};
|
||||
|
||||
fs.createReadStream = jest.fn().mockReturnValue(mockReadStream);
|
||||
});
|
||||
|
||||
it('should process OCR for a file with standard configuration', async () => {
|
||||
// Setup mocks
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'test-api-key',
|
||||
OCR_BASEURL: 'https://api.mistral.ai/v1',
|
||||
});
|
||||
|
||||
// Mock file upload response
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: { id: 'file-123', purpose: 'ocr' },
|
||||
});
|
||||
|
||||
// Mock signed URL response
|
||||
mockAxios.get.mockResolvedValueOnce({
|
||||
data: { url: 'https://signed-url.com' },
|
||||
});
|
||||
|
||||
// Mock OCR response with text and images
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
pages: [
|
||||
{
|
||||
markdown: 'Page 1 content',
|
||||
images: [{ image_base64: 'base64image1' }],
|
||||
},
|
||||
{
|
||||
markdown: 'Page 2 content',
|
||||
images: [{ image_base64: 'base64image2' }],
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
// Use environment variable syntax to ensure loadAuthValues is called
|
||||
apiKey: '${OCR_API_KEY}',
|
||||
baseURL: '${OCR_BASEURL}',
|
||||
mistralModel: 'mistral-medium',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'document.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Verify OCR result
|
||||
expect(result).toEqual({
|
||||
filename: 'document.pdf',
|
||||
bytes: expect.any(Number),
|
||||
filepath: 'mistral_ocr',
|
||||
text: expect.stringContaining('# PAGE 1'),
|
||||
images: ['base64image1', 'base64image2'],
|
||||
});
|
||||
});
|
||||
|
||||
it('should process variable references in configuration', async () => {
|
||||
// Setup mocks with environment variables
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
CUSTOM_API_KEY: 'custom-api-key',
|
||||
CUSTOM_BASEURL: 'https://custom-api.mistral.ai/v1',
|
||||
});
|
||||
|
||||
// Mock API responses
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: { id: 'file-123', purpose: 'ocr' },
|
||||
});
|
||||
mockAxios.get.mockResolvedValueOnce({
|
||||
data: { url: 'https://signed-url.com' },
|
||||
});
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
pages: [{ markdown: 'Content from custom API' }],
|
||||
},
|
||||
});
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
apiKey: '${CUSTOM_API_KEY}',
|
||||
baseURL: '${CUSTOM_BASEURL}',
|
||||
mistralModel: '${CUSTOM_MODEL}',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Set environment variable for model
|
||||
process.env.CUSTOM_MODEL = 'mistral-large';
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'document.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Verify that custom environment variables were extracted and used
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
authFields: ['CUSTOM_BASEURL', 'CUSTOM_API_KEY'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Check that mistral-large was used in the OCR API call
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
model: 'mistral-large',
|
||||
}),
|
||||
expect.anything(),
|
||||
);
|
||||
|
||||
expect(result.text).toEqual('Content from custom API\n\n');
|
||||
});
|
||||
|
||||
it('should fall back to default values when variables are not properly formatted', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'default-api-key',
|
||||
OCR_BASEURL: undefined, // Testing optional parameter
|
||||
});
|
||||
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: { id: 'file-123', purpose: 'ocr' },
|
||||
});
|
||||
mockAxios.get.mockResolvedValueOnce({
|
||||
data: { url: 'https://signed-url.com' },
|
||||
});
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
pages: [{ markdown: 'Default API result' }],
|
||||
},
|
||||
});
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
// Use environment variable syntax to ensure loadAuthValues is called
|
||||
apiKey: '${INVALID_FORMAT}', // Using valid env var format but with an invalid name
|
||||
baseURL: '${OCR_BASEURL}', // Using valid env var format
|
||||
mistralModel: 'mistral-ocr-latest', // Plain string value
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'document.pdf',
|
||||
};
|
||||
|
||||
await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Should use the default values
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
authFields: ['OCR_BASEURL', 'INVALID_FORMAT'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Should use the default model when not using environment variable format
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
model: 'mistral-ocr-latest',
|
||||
}),
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle API errors during OCR process', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'test-api-key',
|
||||
});
|
||||
|
||||
// Mock file upload to fail
|
||||
mockAxios.post.mockRejectedValueOnce(new Error('Upload failed'));
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
apiKey: 'OCR_API_KEY',
|
||||
baseURL: 'OCR_BASEURL',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'document.pdf',
|
||||
};
|
||||
|
||||
await expect(
|
||||
uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
}),
|
||||
).rejects.toThrow('Error uploading document to Mistral OCR API');
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
const { logAxiosError } = require('~/utils');
|
||||
expect(logAxiosError).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle single page documents without page numbering', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'test-api-key',
|
||||
OCR_BASEURL: 'https://api.mistral.ai/v1', // Make sure this is included
|
||||
});
|
||||
|
||||
// Clear all previous mocks
|
||||
mockAxios.post.mockClear();
|
||||
mockAxios.get.mockClear();
|
||||
|
||||
// 1. First mock: File upload response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
|
||||
);
|
||||
|
||||
// 2. Second mock: Signed URL response
|
||||
mockAxios.get.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
|
||||
);
|
||||
|
||||
// 3. Third mock: OCR response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
data: {
|
||||
pages: [{ markdown: 'Single page content' }],
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
apiKey: 'OCR_API_KEY',
|
||||
baseURL: 'OCR_BASEURL',
|
||||
mistralModel: 'mistral-ocr-latest',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'single-page.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Verify that single page documents don't include page numbering
|
||||
expect(result.text).not.toContain('# PAGE');
|
||||
expect(result.text).toEqual('Single page content\n\n');
|
||||
});
|
||||
|
||||
it('should use literal values in configuration when provided directly', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
// We'll still mock this but it should not be used for literal values
|
||||
loadAuthValues.mockResolvedValue({});
|
||||
|
||||
// Clear all previous mocks
|
||||
mockAxios.post.mockClear();
|
||||
mockAxios.get.mockClear();
|
||||
|
||||
// 1. First mock: File upload response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
|
||||
);
|
||||
|
||||
// 2. Second mock: Signed URL response
|
||||
mockAxios.get.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
|
||||
);
|
||||
|
||||
// 3. Third mock: OCR response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
data: {
|
||||
pages: [{ markdown: 'Processed with literal config values' }],
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
// Direct values that should be used as-is, without variable substitution
|
||||
apiKey: 'actual-api-key-value',
|
||||
baseURL: 'https://direct-api-url.mistral.ai/v1',
|
||||
mistralModel: 'mistral-direct-model',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'direct-values.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Verify the correct URL was used with the direct baseURL value
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://direct-api-url.mistral.ai/v1/files',
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: 'Bearer actual-api-key-value',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
// Check the OCR call was made with the direct model value
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://direct-api-url.mistral.ai/v1/ocr',
|
||||
expect.objectContaining({
|
||||
model: 'mistral-direct-model',
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
|
||||
// Verify the result
|
||||
expect(result.text).toEqual('Processed with literal config values\n\n');
|
||||
|
||||
// Verify loadAuthValues was never called since we used direct values
|
||||
expect(loadAuthValues).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle empty configuration values and use defaults', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
// Set up the mock values to be returned by loadAuthValues
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'default-from-env-key',
|
||||
OCR_BASEURL: 'https://default-from-env.mistral.ai/v1',
|
||||
});
|
||||
|
||||
// Clear all previous mocks
|
||||
mockAxios.post.mockClear();
|
||||
mockAxios.get.mockClear();
|
||||
|
||||
// 1. First mock: File upload response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
|
||||
);
|
||||
|
||||
// 2. Second mock: Signed URL response
|
||||
mockAxios.get.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
|
||||
);
|
||||
|
||||
// 3. Third mock: OCR response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
data: {
|
||||
pages: [{ markdown: 'Content from default configuration' }],
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
// Empty string values - should fall back to defaults
|
||||
apiKey: '',
|
||||
baseURL: '',
|
||||
mistralModel: '',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'empty-config.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Verify loadAuthValues was called with the default variable names
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Verify the API calls used the default values from loadAuthValues
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://default-from-env.mistral.ai/v1/files',
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: 'Bearer default-from-env-key',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
// Verify the OCR model defaulted to mistral-ocr-latest
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://default-from-env.mistral.ai/v1/ocr',
|
||||
expect.objectContaining({
|
||||
model: 'mistral-ocr-latest',
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
|
||||
// Check result
|
||||
expect(result.text).toEqual('Content from default configuration\n\n');
|
||||
});
|
||||
});
|
||||
});
|
||||
5
api/server/services/Files/MistralOCR/index.js
Normal file
5
api/server/services/Files/MistralOCR/index.js
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
const crud = require('./crud');
|
||||
|
||||
module.exports = {
|
||||
...crud,
|
||||
};
|
||||
162
api/server/services/Files/S3/crud.js
Normal file
162
api/server/services/Files/S3/crud.js
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const axios = require('axios');
|
||||
const fetch = require('node-fetch');
|
||||
const { getBufferMetadata } = require('~/server/utils');
|
||||
const { initializeS3 } = require('./initialize');
|
||||
const { logger } = require('~/config');
|
||||
const { PutObjectCommand, GetObjectCommand, DeleteObjectCommand } = require('@aws-sdk/client-s3');
|
||||
const { getSignedUrl } = require('@aws-sdk/s3-request-presigner');
|
||||
|
||||
const bucketName = process.env.AWS_BUCKET_NAME;
|
||||
const s3 = initializeS3();
|
||||
const defaultBasePath = 'images';
|
||||
|
||||
/**
|
||||
* Constructs the S3 key based on the base path, user ID, and file name.
|
||||
*/
|
||||
const getS3Key = (basePath, userId, fileName) => `${basePath}/${userId}/${fileName}`;
|
||||
|
||||
/**
|
||||
* Uploads a buffer to S3 and returns a signed URL.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId - The user's unique identifier.
|
||||
* @param {Buffer} params.buffer - The buffer containing file data.
|
||||
* @param {string} params.fileName - The file name to use in S3.
|
||||
* @param {string} [params.basePath='images'] - The base path in the bucket.
|
||||
* @returns {Promise<string>} Signed URL of the uploaded file.
|
||||
*/
|
||||
async function saveBufferToS3({ userId, buffer, fileName, basePath = defaultBasePath }) {
|
||||
const key = getS3Key(basePath, userId, fileName);
|
||||
const params = { Bucket: bucketName, Key: key, Body: buffer };
|
||||
|
||||
try {
|
||||
await s3.send(new PutObjectCommand(params));
|
||||
return await getS3URL({ userId, fileName, basePath });
|
||||
} catch (error) {
|
||||
logger.error('[saveBufferToS3] Error uploading buffer to S3:', error.message);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves a signed URL for a file stored in S3.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId - The user's unique identifier.
|
||||
* @param {string} params.fileName - The file name in S3.
|
||||
* @param {string} [params.basePath='images'] - The base path in the bucket.
|
||||
* @returns {Promise<string>} A signed URL valid for 24 hours.
|
||||
*/
|
||||
async function getS3URL({ userId, fileName, basePath = defaultBasePath }) {
|
||||
const key = getS3Key(basePath, userId, fileName);
|
||||
const params = { Bucket: bucketName, Key: key };
|
||||
|
||||
try {
|
||||
return await getSignedUrl(s3, new GetObjectCommand(params), { expiresIn: 86400 });
|
||||
} catch (error) {
|
||||
logger.error('[getS3URL] Error getting signed URL from S3:', error.message);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves a file from a given URL to S3.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId - The user's unique identifier.
|
||||
* @param {string} params.URL - The source URL of the file.
|
||||
* @param {string} params.fileName - The file name to use in S3.
|
||||
* @param {string} [params.basePath='images'] - The base path in the bucket.
|
||||
* @returns {Promise<string>} Signed URL of the uploaded file.
|
||||
*/
|
||||
async function saveURLToS3({ userId, URL, fileName, basePath = defaultBasePath }) {
|
||||
try {
|
||||
const response = await fetch(URL);
|
||||
const buffer = await response.buffer();
|
||||
// Optionally you can call getBufferMetadata(buffer) if needed.
|
||||
return await saveBufferToS3({ userId, buffer, fileName, basePath });
|
||||
} catch (error) {
|
||||
logger.error('[saveURLToS3] Error uploading file from URL to S3:', error.message);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes a file from S3.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId - The user's unique identifier.
|
||||
* @param {string} params.fileName - The file name in S3.
|
||||
* @param {string} [params.basePath='images'] - The base path in the bucket.
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function deleteFileFromS3({ userId, fileName, basePath = defaultBasePath }) {
|
||||
const key = getS3Key(basePath, userId, fileName);
|
||||
const params = { Bucket: bucketName, Key: key };
|
||||
|
||||
try {
|
||||
await s3.send(new DeleteObjectCommand(params));
|
||||
logger.debug('[deleteFileFromS3] File deleted successfully from S3');
|
||||
} catch (error) {
|
||||
logger.error('[deleteFileFromS3] Error deleting file from S3:', error.message);
|
||||
// If the file is not found, we can safely return.
|
||||
if (error.code === 'NoSuchKey') {
|
||||
return;
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Uploads a local file to S3.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {import('express').Request} params.req - The Express request (must include user).
|
||||
* @param {Express.Multer.File} params.file - The file object from Multer.
|
||||
* @param {string} params.file_id - Unique file identifier.
|
||||
* @param {string} [params.basePath='images'] - The base path in the bucket.
|
||||
* @returns {Promise<{ filepath: string, bytes: number }>}
|
||||
*/
|
||||
async function uploadFileToS3({ req, file, file_id, basePath = defaultBasePath }) {
|
||||
try {
|
||||
const inputFilePath = file.path;
|
||||
const inputBuffer = await fs.promises.readFile(inputFilePath);
|
||||
const bytes = Buffer.byteLength(inputBuffer);
|
||||
const userId = req.user.id;
|
||||
const fileName = `${file_id}__${path.basename(inputFilePath)}`;
|
||||
const fileURL = await saveBufferToS3({ userId, buffer: inputBuffer, fileName, basePath });
|
||||
await fs.promises.unlink(inputFilePath);
|
||||
return { filepath: fileURL, bytes };
|
||||
} catch (error) {
|
||||
logger.error('[uploadFileToS3] Error uploading file to S3:', error.message);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves a readable stream for a file stored in S3.
|
||||
*
|
||||
* @param {string} filePath - The S3 key of the file.
|
||||
* @returns {Promise<NodeJS.ReadableStream>}
|
||||
*/
|
||||
async function getS3FileStream(filePath) {
|
||||
const params = { Bucket: bucketName, Key: filePath };
|
||||
try {
|
||||
const data = await s3.send(new GetObjectCommand(params));
|
||||
return data.Body; // Returns a Node.js ReadableStream.
|
||||
} catch (error) {
|
||||
logger.error('[getS3FileStream] Error retrieving S3 file stream:', error.message);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
saveBufferToS3,
|
||||
saveURLToS3,
|
||||
getS3URL,
|
||||
deleteFileFromS3,
|
||||
uploadFileToS3,
|
||||
getS3FileStream,
|
||||
};
|
||||
118
api/server/services/Files/S3/images.js
Normal file
118
api/server/services/Files/S3/images.js
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const sharp = require('sharp');
|
||||
const { resizeImageBuffer } = require('../images/resize');
|
||||
const { updateUser } = require('~/models/userMethods');
|
||||
const { saveBufferToS3 } = require('./crud');
|
||||
const { updateFile } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const defaultBasePath = 'images';
|
||||
|
||||
/**
|
||||
* Resizes, converts, and uploads an image file to S3.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {import('express').Request} params.req - Express request (expects user and app.locals.imageOutputType).
|
||||
* @param {Express.Multer.File} params.file - File object from Multer.
|
||||
* @param {string} params.file_id - Unique file identifier.
|
||||
* @param {any} params.endpoint - Endpoint identifier used in image processing.
|
||||
* @param {string} [params.resolution='high'] - Desired image resolution.
|
||||
* @param {string} [params.basePath='images'] - Base path in the bucket.
|
||||
* @returns {Promise<{ filepath: string, bytes: number, width: number, height: number }>}
|
||||
*/
|
||||
async function uploadImageToS3({
|
||||
req,
|
||||
file,
|
||||
file_id,
|
||||
endpoint,
|
||||
resolution = 'high',
|
||||
basePath = defaultBasePath,
|
||||
}) {
|
||||
try {
|
||||
const inputFilePath = file.path;
|
||||
const inputBuffer = await fs.promises.readFile(inputFilePath);
|
||||
const {
|
||||
buffer: resizedBuffer,
|
||||
width,
|
||||
height,
|
||||
} = await resizeImageBuffer(inputBuffer, resolution, endpoint);
|
||||
const extension = path.extname(inputFilePath);
|
||||
const userId = req.user.id;
|
||||
|
||||
let processedBuffer;
|
||||
let fileName = `${file_id}__${path.basename(inputFilePath)}`;
|
||||
const targetExtension = `.${req.app.locals.imageOutputType}`;
|
||||
|
||||
if (extension.toLowerCase() === targetExtension) {
|
||||
processedBuffer = resizedBuffer;
|
||||
} else {
|
||||
processedBuffer = await sharp(resizedBuffer)
|
||||
.toFormat(req.app.locals.imageOutputType)
|
||||
.toBuffer();
|
||||
fileName = fileName.replace(new RegExp(path.extname(fileName) + '$'), targetExtension);
|
||||
if (!path.extname(fileName)) {
|
||||
fileName += targetExtension;
|
||||
}
|
||||
}
|
||||
|
||||
const downloadURL = await saveBufferToS3({
|
||||
userId,
|
||||
buffer: processedBuffer,
|
||||
fileName,
|
||||
basePath,
|
||||
});
|
||||
await fs.promises.unlink(inputFilePath);
|
||||
const bytes = Buffer.byteLength(processedBuffer);
|
||||
return { filepath: downloadURL, bytes, width, height };
|
||||
} catch (error) {
|
||||
logger.error('[uploadImageToS3] Error uploading image to S3:', error.message);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates a file record and returns its signed URL.
|
||||
*
|
||||
* @param {import('express').Request} req - Express request.
|
||||
* @param {Object} file - File metadata.
|
||||
* @returns {Promise<[Promise<any>, string]>}
|
||||
*/
|
||||
async function prepareImageURLS3(req, file) {
|
||||
try {
|
||||
const updatePromise = updateFile({ file_id: file.file_id });
|
||||
return Promise.all([updatePromise, file.filepath]);
|
||||
} catch (error) {
|
||||
logger.error('[prepareImageURLS3] Error preparing image URL:', error.message);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Processes a user's avatar image by uploading it to S3 and updating the user's avatar URL if required.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {Buffer} params.buffer - Avatar image buffer.
|
||||
* @param {string} params.userId - User's unique identifier.
|
||||
* @param {string} params.manual - 'true' or 'false' flag for manual update.
|
||||
* @param {string} [params.basePath='images'] - Base path in the bucket.
|
||||
* @returns {Promise<string>} Signed URL of the uploaded avatar.
|
||||
*/
|
||||
async function processS3Avatar({ buffer, userId, manual, basePath = defaultBasePath }) {
|
||||
try {
|
||||
const downloadURL = await saveBufferToS3({ userId, buffer, fileName: 'avatar.png', basePath });
|
||||
if (manual === 'true') {
|
||||
await updateUser(userId, { avatar: downloadURL });
|
||||
}
|
||||
return downloadURL;
|
||||
} catch (error) {
|
||||
logger.error('[processS3Avatar] Error processing S3 avatar:', error.message);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
uploadImageToS3,
|
||||
prepareImageURLS3,
|
||||
processS3Avatar,
|
||||
};
|
||||
9
api/server/services/Files/S3/index.js
Normal file
9
api/server/services/Files/S3/index.js
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
const crud = require('./crud');
|
||||
const images = require('./images');
|
||||
const initialize = require('./initialize');
|
||||
|
||||
module.exports = {
|
||||
...crud,
|
||||
...images,
|
||||
...initialize,
|
||||
};
|
||||
43
api/server/services/Files/S3/initialize.js
Normal file
43
api/server/services/Files/S3/initialize.js
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
const { S3Client } = require('@aws-sdk/client-s3');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
let s3 = null;
|
||||
|
||||
/**
|
||||
* Initializes and returns an instance of the AWS S3 client.
|
||||
*
|
||||
* If AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are provided, they will be used.
|
||||
* Otherwise, the AWS SDK's default credentials chain (including IRSA) is used.
|
||||
*
|
||||
* @returns {S3Client|null} An instance of S3Client if the region is provided; otherwise, null.
|
||||
*/
|
||||
const initializeS3 = () => {
|
||||
if (s3) {
|
||||
return s3;
|
||||
}
|
||||
|
||||
const region = process.env.AWS_REGION;
|
||||
if (!region) {
|
||||
logger.error('[initializeS3] AWS_REGION is not set. Cannot initialize S3.');
|
||||
return null;
|
||||
}
|
||||
|
||||
const accessKeyId = process.env.AWS_ACCESS_KEY_ID;
|
||||
const secretAccessKey = process.env.AWS_SECRET_ACCESS_KEY;
|
||||
|
||||
if (accessKeyId && secretAccessKey) {
|
||||
s3 = new S3Client({
|
||||
region,
|
||||
credentials: { accessKeyId, secretAccessKey },
|
||||
});
|
||||
logger.info('[initializeS3] S3 initialized with provided credentials.');
|
||||
} else {
|
||||
// When using IRSA, credentials are automatically provided via the IAM Role attached to the ServiceAccount.
|
||||
s3 = new S3Client({ region });
|
||||
logger.info('[initializeS3] S3 initialized using default credentials (IRSA).');
|
||||
}
|
||||
|
||||
return s3;
|
||||
};
|
||||
|
||||
module.exports = { initializeS3 };
|
||||
|
|
@ -49,6 +49,7 @@ async function encodeAndFormat(req, files, endpoint, mode) {
|
|||
const promises = [];
|
||||
const encodingMethods = {};
|
||||
const result = {
|
||||
text: '',
|
||||
files: [],
|
||||
image_urls: [],
|
||||
};
|
||||
|
|
@ -59,6 +60,9 @@ async function encodeAndFormat(req, files, endpoint, mode) {
|
|||
|
||||
for (let file of files) {
|
||||
const source = file.source ?? FileSources.local;
|
||||
if (source === FileSources.text && file.text) {
|
||||
result.text += `${!result.text ? 'Attached document(s):\n```md' : '\n\n---\n\n'}# "${file.filename}"\n${file.text}\n`;
|
||||
}
|
||||
|
||||
if (!file.height) {
|
||||
promises.push([file, null]);
|
||||
|
|
@ -85,6 +89,10 @@ async function encodeAndFormat(req, files, endpoint, mode) {
|
|||
promises.push(preparePayload(req, file));
|
||||
}
|
||||
|
||||
if (result.text) {
|
||||
result.text += '\n```';
|
||||
}
|
||||
|
||||
const detail = req.body.imageDetail ?? ImageDetail.auto;
|
||||
|
||||
/** @type {Array<[MongoFile, string]>} */
|
||||
|
|
|
|||
|
|
@ -28,8 +28,8 @@ const { addResourceFileId, deleteResourceFileId } = require('~/server/controller
|
|||
const { addAgentResourceFile, removeAgentResourceFiles } = require('~/models/Agent');
|
||||
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
|
||||
const { createFile, updateFileUsage, deleteFiles } = require('~/models/File');
|
||||
const { getEndpointsConfig } = require('~/server/services/Config');
|
||||
const { loadAuthValues } = require('~/app/clients/tools/util');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { checkCapability } = require('~/server/services/Config');
|
||||
const { LB_QueueAsyncCall } = require('~/server/utils/queue');
|
||||
const { getStrategyFunctions } = require('./strategies');
|
||||
const { determineFileType } = require('~/server/utils');
|
||||
|
|
@ -162,7 +162,6 @@ const processDeleteRequest = async ({ req, files }) => {
|
|||
|
||||
for (const file of files) {
|
||||
const source = file.source ?? FileSources.local;
|
||||
|
||||
if (req.body.agent_id && req.body.tool_resource) {
|
||||
agentFiles.push({
|
||||
tool_resource: req.body.tool_resource,
|
||||
|
|
@ -170,6 +169,11 @@ const processDeleteRequest = async ({ req, files }) => {
|
|||
});
|
||||
}
|
||||
|
||||
if (source === FileSources.text) {
|
||||
resolvedFileIds.push(file.file_id);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (checkOpenAIStorage(source) && !client[source]) {
|
||||
await initializeClients();
|
||||
}
|
||||
|
|
@ -453,17 +457,6 @@ const processFileUpload = async ({ req, res, metadata }) => {
|
|||
res.status(200).json({ message: 'File uploaded and processed successfully', ...result });
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {ServerRequest} req
|
||||
* @param {AgentCapabilities} capability
|
||||
* @returns {Promise<boolean>}
|
||||
*/
|
||||
const checkCapability = async (req, capability) => {
|
||||
const endpointsConfig = await getEndpointsConfig(req);
|
||||
const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
|
||||
return capabilities.includes(capability);
|
||||
};
|
||||
|
||||
/**
|
||||
* Applies the current strategy for file uploads.
|
||||
* Saves file metadata to the database with an expiry TTL.
|
||||
|
|
@ -521,6 +514,52 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
|||
if (!isFileSearchEnabled) {
|
||||
throw new Error('File search is not enabled for Agents');
|
||||
}
|
||||
} else if (tool_resource === EToolResources.ocr) {
|
||||
const isOCREnabled = await checkCapability(req, AgentCapabilities.ocr);
|
||||
if (!isOCREnabled) {
|
||||
throw new Error('OCR capability is not enabled for Agents');
|
||||
}
|
||||
|
||||
const { handleFileUpload } = getStrategyFunctions(
|
||||
req.app.locals?.ocr?.strategy ?? FileSources.mistral_ocr,
|
||||
);
|
||||
const { file_id, temp_file_id } = metadata;
|
||||
|
||||
const {
|
||||
text,
|
||||
bytes,
|
||||
// TODO: OCR images support?
|
||||
images,
|
||||
filename,
|
||||
filepath: ocrFileURL,
|
||||
} = await handleFileUpload({ req, file, file_id, entity_id: agent_id });
|
||||
|
||||
const fileInfo = removeNullishValues({
|
||||
text,
|
||||
bytes,
|
||||
file_id,
|
||||
temp_file_id,
|
||||
user: req.user.id,
|
||||
type: file.mimetype,
|
||||
filepath: ocrFileURL,
|
||||
source: FileSources.text,
|
||||
filename: filename ?? file.originalname,
|
||||
model: messageAttachment ? undefined : req.body.model,
|
||||
context: messageAttachment ? FileContext.message_attachment : FileContext.agents,
|
||||
});
|
||||
|
||||
if (!messageAttachment && tool_resource) {
|
||||
await addAgentResourceFile({
|
||||
req,
|
||||
file_id,
|
||||
agent_id,
|
||||
tool_resource,
|
||||
});
|
||||
}
|
||||
const result = await createFile(fileInfo, true);
|
||||
return res
|
||||
.status(200)
|
||||
.json({ message: 'Agent file uploaded and processed successfully', ...result });
|
||||
}
|
||||
|
||||
const source =
|
||||
|
|
|
|||
|
|
@ -21,9 +21,21 @@ const {
|
|||
processLocalAvatar,
|
||||
getLocalFileStream,
|
||||
} = require('./Local');
|
||||
const {
|
||||
getS3URL,
|
||||
saveURLToS3,
|
||||
saveBufferToS3,
|
||||
getS3FileStream,
|
||||
uploadImageToS3,
|
||||
prepareImageURLS3,
|
||||
deleteFileFromS3,
|
||||
processS3Avatar,
|
||||
uploadFileToS3,
|
||||
} = require('./S3');
|
||||
const { uploadOpenAIFile, deleteOpenAIFile, getOpenAIFileStream } = require('./OpenAI');
|
||||
const { getCodeOutputDownloadStream, uploadCodeEnvFile } = require('./Code');
|
||||
const { uploadVectors, deleteVectors } = require('./VectorDB');
|
||||
const { uploadMistralOCR } = require('./MistralOCR');
|
||||
|
||||
/**
|
||||
* Firebase Storage Strategy Functions
|
||||
|
|
@ -57,6 +69,22 @@ const localStrategy = () => ({
|
|||
getDownloadStream: getLocalFileStream,
|
||||
});
|
||||
|
||||
/**
|
||||
* S3 Storage Strategy Functions
|
||||
*
|
||||
* */
|
||||
const s3Strategy = () => ({
|
||||
handleFileUpload: uploadFileToS3,
|
||||
saveURL: saveURLToS3,
|
||||
getFileURL: getS3URL,
|
||||
deleteFile: deleteFileFromS3,
|
||||
saveBuffer: saveBufferToS3,
|
||||
prepareImagePayload: prepareImageURLS3,
|
||||
processAvatar: processS3Avatar,
|
||||
handleImageUpload: uploadImageToS3,
|
||||
getDownloadStream: getS3FileStream,
|
||||
});
|
||||
|
||||
/**
|
||||
* VectorDB Storage Strategy Functions
|
||||
*
|
||||
|
|
@ -127,6 +155,26 @@ const codeOutputStrategy = () => ({
|
|||
getDownloadStream: getCodeOutputDownloadStream,
|
||||
});
|
||||
|
||||
const mistralOCRStrategy = () => ({
|
||||
/** @type {typeof saveFileFromURL | null} */
|
||||
saveURL: null,
|
||||
/** @type {typeof getLocalFileURL | null} */
|
||||
getFileURL: null,
|
||||
/** @type {typeof saveLocalBuffer | null} */
|
||||
saveBuffer: null,
|
||||
/** @type {typeof processLocalAvatar | null} */
|
||||
processAvatar: null,
|
||||
/** @type {typeof uploadLocalImage | null} */
|
||||
handleImageUpload: null,
|
||||
/** @type {typeof prepareImagesLocal | null} */
|
||||
prepareImagePayload: null,
|
||||
/** @type {typeof deleteLocalFile | null} */
|
||||
deleteFile: null,
|
||||
/** @type {typeof getLocalFileStream | null} */
|
||||
getDownloadStream: null,
|
||||
handleFileUpload: uploadMistralOCR,
|
||||
});
|
||||
|
||||
// Strategy Selector
|
||||
const getStrategyFunctions = (fileSource) => {
|
||||
if (fileSource === FileSources.firebase) {
|
||||
|
|
@ -139,8 +187,12 @@ const getStrategyFunctions = (fileSource) => {
|
|||
return openAIStrategy();
|
||||
} else if (fileSource === FileSources.vectordb) {
|
||||
return vectorStrategy();
|
||||
} else if (fileSource === FileSources.s3) {
|
||||
return s3Strategy();
|
||||
} else if (fileSource === FileSources.execute_code) {
|
||||
return codeOutputStrategy();
|
||||
} else if (fileSource === FileSources.mistral_ocr) {
|
||||
return mistralOCRStrategy();
|
||||
} else {
|
||||
throw new Error('Invalid file source');
|
||||
}
|
||||
|
|
|
|||
|
|
@ -362,7 +362,12 @@ async function processRequiredActions(client, requiredActions) {
|
|||
continue;
|
||||
}
|
||||
|
||||
tool = await createActionTool({ action: actionSet, requestBuilder });
|
||||
tool = await createActionTool({
|
||||
req: client.req,
|
||||
res: client.res,
|
||||
action: actionSet,
|
||||
requestBuilder,
|
||||
});
|
||||
if (!tool) {
|
||||
logger.warn(
|
||||
`Invalid action: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id} | toolName: ${currentAction.tool}`,
|
||||
|
|
|
|||
56
api/server/services/Tools/credentials.js
Normal file
56
api/server/services/Tools/credentials.js
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId
|
||||
* @param {string[]} params.authFields
|
||||
* @param {Set<string>} [params.optional]
|
||||
* @param {boolean} [params.throwError]
|
||||
* @returns
|
||||
*/
|
||||
const loadAuthValues = async ({ userId, authFields, optional, throwError = true }) => {
|
||||
let authValues = {};
|
||||
|
||||
/**
|
||||
* Finds the first non-empty value for the given authentication field, supporting alternate fields.
|
||||
* @param {string[]} fields Array of strings representing the authentication fields. Supports alternate fields delimited by "||".
|
||||
* @returns {Promise<{ authField: string, authValue: string} | null>} An object containing the authentication field and value, or null if not found.
|
||||
*/
|
||||
const findAuthValue = async (fields) => {
|
||||
for (const field of fields) {
|
||||
let value = process.env[field];
|
||||
if (value) {
|
||||
return { authField: field, authValue: value };
|
||||
}
|
||||
try {
|
||||
value = await getUserPluginAuthValue(userId, field, throwError);
|
||||
} catch (err) {
|
||||
if (optional && optional.has(field)) {
|
||||
return { authField: field, authValue: undefined };
|
||||
}
|
||||
if (field === fields[fields.length - 1] && !value) {
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
if (value) {
|
||||
return { authField: field, authValue: value };
|
||||
}
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
for (let authField of authFields) {
|
||||
const fields = authField.split('||');
|
||||
const result = await findAuthValue(fields);
|
||||
if (result) {
|
||||
authValues[result.authField] = result.authValue;
|
||||
}
|
||||
}
|
||||
|
||||
return authValues;
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
loadAuthValues,
|
||||
};
|
||||
|
|
@ -203,6 +203,8 @@ function generateConfig(key, baseURL, endpoint) {
|
|||
AgentCapabilities.artifacts,
|
||||
AgentCapabilities.actions,
|
||||
AgentCapabilities.tools,
|
||||
AgentCapabilities.ocr,
|
||||
AgentCapabilities.chain,
|
||||
];
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,10 @@ jest.mock('winston-daily-rotate-file', () => {
|
|||
});
|
||||
|
||||
jest.mock('~/config', () => {
|
||||
const actualModule = jest.requireActual('~/config');
|
||||
return {
|
||||
sendEvent: actualModule.sendEvent,
|
||||
createAxiosInstance: actualModule.createAxiosInstance,
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
|
|
|
|||
|
|
@ -1787,3 +1787,51 @@
|
|||
* @typedef {Promise<{ message: TMessage, conversation: TConversation }> | undefined} ClientDatabaseSavePromise
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports OCRImage
|
||||
* @typedef {Object} OCRImage
|
||||
* @property {string} id - The identifier of the image.
|
||||
* @property {number} top_left_x - X-coordinate of the top left corner of the image.
|
||||
* @property {number} top_left_y - Y-coordinate of the top left corner of the image.
|
||||
* @property {number} bottom_right_x - X-coordinate of the bottom right corner of the image.
|
||||
* @property {number} bottom_right_y - Y-coordinate of the bottom right corner of the image.
|
||||
* @property {string} image_base64 - Base64-encoded image data.
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports PageDimensions
|
||||
* @typedef {Object} PageDimensions
|
||||
* @property {number} dpi - The dots per inch resolution of the page.
|
||||
* @property {number} height - The height of the page in pixels.
|
||||
* @property {number} width - The width of the page in pixels.
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports OCRPage
|
||||
* @typedef {Object} OCRPage
|
||||
* @property {number} index - The index of the page in the document.
|
||||
* @property {string} markdown - The extracted text content of the page in markdown format.
|
||||
* @property {OCRImage[]} images - Array of images found on the page.
|
||||
* @property {PageDimensions} dimensions - The dimensions of the page.
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports OCRUsageInfo
|
||||
* @typedef {Object} OCRUsageInfo
|
||||
* @property {number} pages_processed - Number of pages processed in the document.
|
||||
* @property {number} doc_size_bytes - Size of the document in bytes.
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports OCRResult
|
||||
* @typedef {Object} OCRResult
|
||||
* @property {OCRPage[]} pages - Array of pages extracted from the document.
|
||||
* @property {string} model - The model used for OCR processing.
|
||||
* @property {OCRUsageInfo} usage_info - Usage information for the OCR operation.
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -92,6 +92,7 @@ const anthropicModels = {
|
|||
const deepseekModels = {
|
||||
'deepseek-reasoner': 63000, // -1000 from max (API)
|
||||
deepseek: 63000, // -1000 from max (API)
|
||||
'deepseek.r1': 127500,
|
||||
};
|
||||
|
||||
const metaModels = {
|
||||
|
|
|
|||
|
|
@ -423,6 +423,9 @@ describe('Meta Models Tests', () => {
|
|||
expect(getModelMaxTokens('deepseek-reasoner')).toBe(
|
||||
maxTokensMap[EModelEndpoint.openAI]['deepseek-reasoner'],
|
||||
);
|
||||
expect(getModelMaxTokens('deepseek.r1')).toBe(
|
||||
maxTokensMap[EModelEndpoint.openAI]['deepseek.r1'],
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue