mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-01-02 00:28:51 +01:00
Merge branch 'dev' into feat/prompt-enhancement
This commit is contained in:
commit
3d261a969d
365 changed files with 23826 additions and 8790 deletions
|
|
@ -220,6 +220,9 @@ function disposeClient(client) {
|
|||
if (client.maxResponseTokens) {
|
||||
client.maxResponseTokens = null;
|
||||
}
|
||||
if (client.processMemory) {
|
||||
client.processMemory = null;
|
||||
}
|
||||
if (client.run) {
|
||||
// Break circular references in run
|
||||
if (client.run.Graph) {
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, AuthType } = require('librechat-data-provider');
|
||||
const { getCustomConfig, getCachedTools } = require('~/server/services/Config');
|
||||
const { getToolkitKey } = require('~/server/services/ToolService');
|
||||
const { getCustomConfig } = require('~/server/services/Config');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { availableTools } = require('~/app/clients/tools');
|
||||
const { getMCPManager } = require('~/config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
* Filters out duplicate plugins from the list of plugins.
|
||||
|
|
@ -84,6 +86,45 @@ const getAvailablePluginsController = async (req, res) => {
|
|||
}
|
||||
};
|
||||
|
||||
function createServerToolsCallback() {
|
||||
/**
|
||||
* @param {string} serverName
|
||||
* @param {TPlugin[] | null} serverTools
|
||||
*/
|
||||
return async function (serverName, serverTools) {
|
||||
try {
|
||||
const mcpToolsCache = getLogStores(CacheKeys.MCP_TOOLS);
|
||||
if (!serverName || !mcpToolsCache) {
|
||||
return;
|
||||
}
|
||||
await mcpToolsCache.set(serverName, serverTools);
|
||||
logger.debug(`MCP tools for ${serverName} added to cache.`);
|
||||
} catch (error) {
|
||||
logger.error('Error retrieving MCP tools from cache:', error);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
function createGetServerTools() {
|
||||
/**
|
||||
* Retrieves cached server tools
|
||||
* @param {string} serverName
|
||||
* @returns {Promise<TPlugin[] | null>}
|
||||
*/
|
||||
return async function (serverName) {
|
||||
try {
|
||||
const mcpToolsCache = getLogStores(CacheKeys.MCP_TOOLS);
|
||||
if (!mcpToolsCache) {
|
||||
return null;
|
||||
}
|
||||
return await mcpToolsCache.get(serverName);
|
||||
} catch (error) {
|
||||
logger.error('Error retrieving MCP tools from cache:', error);
|
||||
return null;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves and returns a list of available tools, either from a cache or by reading a plugin manifest file.
|
||||
*
|
||||
|
|
@ -109,7 +150,16 @@ const getAvailableTools = async (req, res) => {
|
|||
const customConfig = await getCustomConfig();
|
||||
if (customConfig?.mcpServers != null) {
|
||||
const mcpManager = getMCPManager();
|
||||
pluginManifest = await mcpManager.loadManifestTools(pluginManifest);
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = flowsCache ? getFlowStateManager(flowsCache) : null;
|
||||
const serverToolsCallback = createServerToolsCallback();
|
||||
const getServerTools = createGetServerTools();
|
||||
const mcpTools = await mcpManager.loadManifestTools({
|
||||
flowManager,
|
||||
serverToolsCallback,
|
||||
getServerTools,
|
||||
});
|
||||
pluginManifest = [...mcpTools, ...pluginManifest];
|
||||
}
|
||||
|
||||
/** @type {TPlugin[]} */
|
||||
|
|
@ -123,17 +173,57 @@ const getAvailableTools = async (req, res) => {
|
|||
}
|
||||
});
|
||||
|
||||
const toolDefinitions = req.app.locals.availableTools;
|
||||
const tools = authenticatedPlugins.filter(
|
||||
(plugin) =>
|
||||
toolDefinitions[plugin.pluginKey] !== undefined ||
|
||||
(plugin.toolkit === true &&
|
||||
Object.keys(toolDefinitions).some((key) => getToolkitKey(key) === plugin.pluginKey)),
|
||||
);
|
||||
const toolDefinitions = await getCachedTools({ includeGlobal: true });
|
||||
|
||||
await cache.set(CacheKeys.TOOLS, tools);
|
||||
res.status(200).json(tools);
|
||||
const toolsOutput = [];
|
||||
for (const plugin of authenticatedPlugins) {
|
||||
const isToolDefined = toolDefinitions[plugin.pluginKey] !== undefined;
|
||||
const isToolkit =
|
||||
plugin.toolkit === true &&
|
||||
Object.keys(toolDefinitions).some((key) => getToolkitKey(key) === plugin.pluginKey);
|
||||
|
||||
if (!isToolDefined && !isToolkit) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const toolToAdd = { ...plugin };
|
||||
|
||||
if (!plugin.pluginKey.includes(Constants.mcp_delimiter)) {
|
||||
toolsOutput.push(toolToAdd);
|
||||
continue;
|
||||
}
|
||||
|
||||
const parts = plugin.pluginKey.split(Constants.mcp_delimiter);
|
||||
const serverName = parts[parts.length - 1];
|
||||
const serverConfig = customConfig?.mcpServers?.[serverName];
|
||||
|
||||
if (!serverConfig?.customUserVars) {
|
||||
toolsOutput.push(toolToAdd);
|
||||
continue;
|
||||
}
|
||||
|
||||
const customVarKeys = Object.keys(serverConfig.customUserVars);
|
||||
|
||||
if (customVarKeys.length === 0) {
|
||||
toolToAdd.authConfig = [];
|
||||
toolToAdd.authenticated = true;
|
||||
} else {
|
||||
toolToAdd.authConfig = Object.entries(serverConfig.customUserVars).map(([key, value]) => ({
|
||||
authField: key,
|
||||
label: value.title || key,
|
||||
description: value.description || '',
|
||||
}));
|
||||
toolToAdd.authenticated = false;
|
||||
}
|
||||
|
||||
toolsOutput.push(toolToAdd);
|
||||
}
|
||||
|
||||
const finalTools = filterUniquePlugins(toolsOutput);
|
||||
await cache.set(CacheKeys.TOOLS, finalTools);
|
||||
res.status(200).json(finalTools);
|
||||
} catch (error) {
|
||||
logger.error('[getAvailableTools]', error);
|
||||
res.status(500).json({ message: error.message });
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
const { encryptV3 } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
verifyTOTP,
|
||||
|
|
@ -7,7 +8,6 @@ const {
|
|||
generateBackupCodes,
|
||||
} = require('~/server/services/twoFactorService');
|
||||
const { getUserById, updateUser } = require('~/models');
|
||||
const { encryptV3 } = require('~/server/utils/crypto');
|
||||
|
||||
const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, '');
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
FileSources,
|
||||
webSearchKeys,
|
||||
extractWebSearchEnvVars,
|
||||
|
|
@ -21,8 +22,9 @@ const { verifyEmail, resendVerificationEmail } = require('~/server/services/Auth
|
|||
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
|
||||
const { processDeleteRequest } = require('~/server/services/Files/process');
|
||||
const { Transaction, Balance, User } = require('~/db/models');
|
||||
const { deleteAllSharedLinks } = require('~/models/Share');
|
||||
const { deleteToolCalls } = require('~/models/ToolCall');
|
||||
const { deleteAllSharedLinks } = require('~/models');
|
||||
const { getMCPManager } = require('~/config');
|
||||
|
||||
const getUserController = async (req, res) => {
|
||||
/** @type {MongoUser} */
|
||||
|
|
@ -102,10 +104,22 @@ const updateUserPluginsController = async (req, res) => {
|
|||
}
|
||||
|
||||
let keys = Object.keys(auth);
|
||||
if (keys.length === 0 && pluginKey !== Tools.web_search) {
|
||||
const values = Object.values(auth); // Used in 'install' block
|
||||
|
||||
const isMCPTool = pluginKey.startsWith('mcp_') || pluginKey.includes(Constants.mcp_delimiter);
|
||||
|
||||
// Early exit condition:
|
||||
// If keys are empty (meaning auth: {} was likely sent for uninstall, or auth was empty for install)
|
||||
// AND it's not web_search (which has special key handling to populate `keys` for uninstall)
|
||||
// AND it's NOT (an uninstall action FOR an MCP tool - we need to proceed for this case to clear all its auth)
|
||||
// THEN return.
|
||||
if (
|
||||
keys.length === 0 &&
|
||||
pluginKey !== Tools.web_search &&
|
||||
!(action === 'uninstall' && isMCPTool)
|
||||
) {
|
||||
return res.status(200).send();
|
||||
}
|
||||
const values = Object.values(auth);
|
||||
|
||||
/** @type {number} */
|
||||
let status = 200;
|
||||
|
|
@ -132,16 +146,53 @@ const updateUserPluginsController = async (req, res) => {
|
|||
}
|
||||
}
|
||||
} else if (action === 'uninstall') {
|
||||
for (let i = 0; i < keys.length; i++) {
|
||||
authService = await deleteUserPluginAuth(user.id, keys[i]);
|
||||
// const isMCPTool was defined earlier
|
||||
if (isMCPTool && keys.length === 0) {
|
||||
// This handles the case where auth: {} is sent for an MCP tool uninstall.
|
||||
// It means "delete all credentials associated with this MCP pluginKey".
|
||||
authService = await deleteUserPluginAuth(user.id, null, true, pluginKey);
|
||||
if (authService instanceof Error) {
|
||||
logger.error('[authService]', authService);
|
||||
logger.error(
|
||||
`[authService] Error deleting all auth for MCP tool ${pluginKey}:`,
|
||||
authService,
|
||||
);
|
||||
({ status, message } = authService);
|
||||
}
|
||||
} else {
|
||||
// This handles:
|
||||
// 1. Web_search uninstall (keys will be populated with all webSearchKeys if auth was {}).
|
||||
// 2. Other tools uninstall (if keys were provided).
|
||||
// 3. MCP tool uninstall if specific keys were provided in `auth` (not current frontend behavior).
|
||||
// If keys is empty for non-MCP tools (and not web_search), this loop won't run, and nothing is deleted.
|
||||
for (let i = 0; i < keys.length; i++) {
|
||||
authService = await deleteUserPluginAuth(user.id, keys[i]); // Deletes by authField name
|
||||
if (authService instanceof Error) {
|
||||
logger.error('[authService] Error deleting specific auth key:', authService);
|
||||
({ status, message } = authService);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (status === 200) {
|
||||
// If auth was updated successfully, disconnect MCP sessions as they might use these credentials
|
||||
if (pluginKey.startsWith(Constants.mcp_prefix)) {
|
||||
try {
|
||||
const mcpManager = getMCPManager(user.id);
|
||||
if (mcpManager) {
|
||||
logger.info(
|
||||
`[updateUserPluginsController] Disconnecting MCP connections for user ${user.id} after plugin auth update for ${pluginKey}.`,
|
||||
);
|
||||
await mcpManager.disconnectUserConnections(user.id);
|
||||
}
|
||||
} catch (disconnectError) {
|
||||
logger.error(
|
||||
`[updateUserPluginsController] Error disconnecting MCP connections for user ${user.id} after plugin auth update:`,
|
||||
disconnectError,
|
||||
);
|
||||
// Do not fail the request for this, but log it.
|
||||
}
|
||||
}
|
||||
return res.status(status).send();
|
||||
}
|
||||
|
||||
|
|
@ -163,7 +214,11 @@ const deleteUserController = async (req, res) => {
|
|||
await Balance.deleteMany({ user: user._id }); // delete user balances
|
||||
await deletePresets(user.id); // delete user presets
|
||||
/* TODO: Delete Assistant Threads */
|
||||
await deleteConvos(user.id); // delete user convos
|
||||
try {
|
||||
await deleteConvos(user.id); // delete user convos
|
||||
} catch (error) {
|
||||
logger.error('[deleteUserController] Error deleting user convos, likely no convos', error);
|
||||
}
|
||||
await deleteUserPluginAuth(user.id, null, true); // delete user plugin auth
|
||||
await deleteUserById(user.id); // delete user
|
||||
await deleteAllSharedLinks(user.id); // delete user shared links
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
const { nanoid } = require('nanoid');
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Tools, StepTypes, FileContext } = require('librechat-data-provider');
|
||||
const {
|
||||
EnvVar,
|
||||
|
|
@ -12,7 +14,6 @@ const {
|
|||
const { processCodeOutput } = require('~/server/services/Files/Code/process');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { saveBase64Image } = require('~/server/services/Files/process');
|
||||
const { logger, sendEvent } = require('~/config');
|
||||
|
||||
class ModelEndHandler {
|
||||
/**
|
||||
|
|
@ -240,9 +241,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
if (output.artifact[Tools.web_search]) {
|
||||
artifactPromises.push(
|
||||
(async () => {
|
||||
const name = `${output.name}_${output.tool_call_id}_${nanoid()}`;
|
||||
const attachment = {
|
||||
name,
|
||||
type: Tools.web_search,
|
||||
messageId: metadata.run_id,
|
||||
toolCallId: output.tool_call_id,
|
||||
|
|
|
|||
|
|
@ -1,13 +1,12 @@
|
|||
// const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
// const {
|
||||
// Constants,
|
||||
// ImageDetail,
|
||||
// EModelEndpoint,
|
||||
// resolveHeaders,
|
||||
// validateVisionModel,
|
||||
// mapModelToAzureConfig,
|
||||
// } = require('librechat-data-provider');
|
||||
require('events').EventEmitter.defaultMaxListeners = 100;
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
sendEvent,
|
||||
createRun,
|
||||
Tokenizer,
|
||||
memoryInstructions,
|
||||
createMemoryProcessor,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Callback,
|
||||
GraphEvents,
|
||||
|
|
@ -19,25 +18,34 @@ const {
|
|||
} = require('@librechat/agents');
|
||||
const {
|
||||
Constants,
|
||||
Permissions,
|
||||
VisionModes,
|
||||
ContentTypes,
|
||||
EModelEndpoint,
|
||||
KnownEndpoints,
|
||||
PermissionTypes,
|
||||
isAgentsEndpoint,
|
||||
AgentCapabilities,
|
||||
bedrockInputSchema,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const { getCustomEndpointConfig, checkCapability } = require('~/server/services/Config');
|
||||
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { DynamicStructuredTool } = require('@langchain/core/tools');
|
||||
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
|
||||
const {
|
||||
getCustomEndpointConfig,
|
||||
createGetMCPAuthMap,
|
||||
checkCapability,
|
||||
} = require('~/server/services/Config');
|
||||
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
|
||||
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { getFormattedMemories, deleteMemory, setMemory } = require('~/models');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
|
||||
const Tokenizer = require('~/server/services/Tokenizer');
|
||||
const { checkAccess } = require('~/server/middleware/roles/access');
|
||||
const BaseClient = require('~/app/clients/BaseClient');
|
||||
const { logger, sendEvent } = require('~/config');
|
||||
const { createRun } = require('./run');
|
||||
const { loadAgent } = require('~/models/Agent');
|
||||
const { getMCPManager } = require('~/config');
|
||||
|
||||
/**
|
||||
* @param {ServerRequest} req
|
||||
|
|
@ -57,12 +65,8 @@ const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deep
|
|||
|
||||
const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi];
|
||||
|
||||
// const { processMemory, memoryInstructions } = require('~/server/services/Endpoints/agents/memory');
|
||||
// const { getFormattedMemories } = require('~/models/Memory');
|
||||
// const { getCurrentDateTime } = require('~/utils');
|
||||
|
||||
function createTokenCounter(encoding) {
|
||||
return (message) => {
|
||||
return function (message) {
|
||||
const countTokens = (text) => Tokenizer.getTokenCount(text, encoding);
|
||||
return getTokenCountForMessage(message, countTokens);
|
||||
};
|
||||
|
|
@ -123,6 +127,8 @@ class AgentClient extends BaseClient {
|
|||
this.usage;
|
||||
/** @type {Record<string, number>} */
|
||||
this.indexTokenCountMap = {};
|
||||
/** @type {(messages: BaseMessage[]) => Promise<void>} */
|
||||
this.processMemory;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -137,55 +143,10 @@ class AgentClient extends BaseClient {
|
|||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* Checks if the model is a vision model based on request attachments and sets the appropriate options:
|
||||
* - Sets `this.modelOptions.model` to `gpt-4-vision-preview` if the request is a vision request.
|
||||
* - Sets `this.isVisionModel` to `true` if vision request.
|
||||
* - Deletes `this.modelOptions.stop` if vision request.
|
||||
* `AgentClient` is not opinionated about vision requests, so we don't do anything here
|
||||
* @param {MongoFile[]} attachments
|
||||
*/
|
||||
checkVisionRequest(attachments) {
|
||||
// if (!attachments) {
|
||||
// return;
|
||||
// }
|
||||
// const availableModels = this.options.modelsConfig?.[this.options.endpoint];
|
||||
// if (!availableModels) {
|
||||
// return;
|
||||
// }
|
||||
// let visionRequestDetected = false;
|
||||
// for (const file of attachments) {
|
||||
// if (file?.type?.includes('image')) {
|
||||
// visionRequestDetected = true;
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
// if (!visionRequestDetected) {
|
||||
// return;
|
||||
// }
|
||||
// this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });
|
||||
// if (this.isVisionModel) {
|
||||
// delete this.modelOptions.stop;
|
||||
// return;
|
||||
// }
|
||||
// for (const model of availableModels) {
|
||||
// if (!validateVisionModel({ model, availableModels })) {
|
||||
// continue;
|
||||
// }
|
||||
// this.modelOptions.model = model;
|
||||
// this.isVisionModel = true;
|
||||
// delete this.modelOptions.stop;
|
||||
// return;
|
||||
// }
|
||||
// if (!availableModels.includes(this.defaultVisionModel)) {
|
||||
// return;
|
||||
// }
|
||||
// if (!validateVisionModel({ model: this.defaultVisionModel, availableModels })) {
|
||||
// return;
|
||||
// }
|
||||
// this.modelOptions.model = this.defaultVisionModel;
|
||||
// this.isVisionModel = true;
|
||||
// delete this.modelOptions.stop;
|
||||
}
|
||||
checkVisionRequest() {}
|
||||
|
||||
getSaveOptions() {
|
||||
// TODO:
|
||||
|
|
@ -269,24 +230,6 @@ class AgentClient extends BaseClient {
|
|||
.filter(Boolean)
|
||||
.join('\n')
|
||||
.trim();
|
||||
// this.systemMessage = getCurrentDateTime();
|
||||
// const { withKeys, withoutKeys } = await getFormattedMemories({
|
||||
// userId: this.options.req.user.id,
|
||||
// });
|
||||
// processMemory({
|
||||
// userId: this.options.req.user.id,
|
||||
// message: this.options.req.body.text,
|
||||
// parentMessageId,
|
||||
// memory: withKeys,
|
||||
// thread_id: this.conversationId,
|
||||
// }).catch((error) => {
|
||||
// logger.error('Memory Agent failed to process memory', error);
|
||||
// });
|
||||
|
||||
// this.systemMessage += '\n\n' + memoryInstructions;
|
||||
// if (withoutKeys) {
|
||||
// this.systemMessage += `\n\n# Existing memory about the user:\n${withoutKeys}`;
|
||||
// }
|
||||
|
||||
if (this.options.attachments) {
|
||||
const attachments = await this.options.attachments;
|
||||
|
|
@ -370,6 +313,37 @@ class AgentClient extends BaseClient {
|
|||
systemContent = this.augmentedPrompt + systemContent;
|
||||
}
|
||||
|
||||
// Inject MCP server instructions if available
|
||||
const ephemeralAgent = this.options.req.body.ephemeralAgent;
|
||||
let mcpServers = [];
|
||||
|
||||
// Check for ephemeral agent MCP servers
|
||||
if (ephemeralAgent && ephemeralAgent.mcp && ephemeralAgent.mcp.length > 0) {
|
||||
mcpServers = ephemeralAgent.mcp;
|
||||
}
|
||||
// Check for regular agent MCP tools
|
||||
else if (this.options.agent && this.options.agent.tools) {
|
||||
mcpServers = this.options.agent.tools
|
||||
.filter(
|
||||
(tool) =>
|
||||
tool instanceof DynamicStructuredTool && tool.name.includes(Constants.mcp_delimiter),
|
||||
)
|
||||
.map((tool) => tool.name.split(Constants.mcp_delimiter).pop())
|
||||
.filter(Boolean);
|
||||
}
|
||||
|
||||
if (mcpServers.length > 0) {
|
||||
try {
|
||||
const mcpInstructions = getMCPManager().formatInstructionsForContext(mcpServers);
|
||||
if (mcpInstructions) {
|
||||
systemContent = [systemContent, mcpInstructions].filter(Boolean).join('\n\n');
|
||||
logger.debug('[AgentClient] Injected MCP instructions for servers:', mcpServers);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[AgentClient] Failed to inject MCP instructions:', error);
|
||||
}
|
||||
}
|
||||
|
||||
if (systemContent) {
|
||||
this.options.agent.instructions = systemContent;
|
||||
}
|
||||
|
|
@ -399,9 +373,150 @@ class AgentClient extends BaseClient {
|
|||
opts.getReqData({ promptTokens });
|
||||
}
|
||||
|
||||
const withoutKeys = await this.useMemory();
|
||||
if (withoutKeys) {
|
||||
systemContent += `${memoryInstructions}\n\n# Existing memory about the user:\n${withoutKeys}`;
|
||||
}
|
||||
|
||||
if (systemContent) {
|
||||
this.options.agent.instructions = systemContent;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @returns {Promise<string | undefined>}
|
||||
*/
|
||||
async useMemory() {
|
||||
const user = this.options.req.user;
|
||||
if (user.personalization?.memories === false) {
|
||||
return;
|
||||
}
|
||||
const hasAccess = await checkAccess(user, PermissionTypes.MEMORIES, [Permissions.USE]);
|
||||
|
||||
if (!hasAccess) {
|
||||
logger.debug(
|
||||
`[api/server/controllers/agents/client.js #useMemory] User ${user.id} does not have USE permission for memories`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
/** @type {TCustomConfig['memory']} */
|
||||
const memoryConfig = this.options.req?.app?.locals?.memory;
|
||||
if (!memoryConfig || memoryConfig.disabled === true) {
|
||||
return;
|
||||
}
|
||||
|
||||
/** @type {Agent} */
|
||||
let prelimAgent;
|
||||
const allowedProviders = new Set(
|
||||
this.options.req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders,
|
||||
);
|
||||
try {
|
||||
if (memoryConfig.agent?.id != null && memoryConfig.agent.id !== this.options.agent.id) {
|
||||
prelimAgent = await loadAgent({
|
||||
req: this.options.req,
|
||||
agent_id: memoryConfig.agent.id,
|
||||
endpoint: EModelEndpoint.agents,
|
||||
});
|
||||
} else if (
|
||||
memoryConfig.agent?.id == null &&
|
||||
memoryConfig.agent?.model != null &&
|
||||
memoryConfig.agent?.provider != null
|
||||
) {
|
||||
prelimAgent = { id: Constants.EPHEMERAL_AGENT_ID, ...memoryConfig.agent };
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #useMemory] Error loading agent for memory',
|
||||
error,
|
||||
);
|
||||
}
|
||||
|
||||
const agent = await initializeAgent({
|
||||
req: this.options.req,
|
||||
res: this.options.res,
|
||||
agent: prelimAgent,
|
||||
allowedProviders,
|
||||
});
|
||||
|
||||
if (!agent) {
|
||||
logger.warn(
|
||||
'[api/server/controllers/agents/client.js #useMemory] No agent found for memory',
|
||||
memoryConfig,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const llmConfig = Object.assign(
|
||||
{
|
||||
provider: agent.provider,
|
||||
model: agent.model,
|
||||
},
|
||||
agent.model_parameters,
|
||||
);
|
||||
|
||||
/** @type {import('@librechat/api').MemoryConfig} */
|
||||
const config = {
|
||||
validKeys: memoryConfig.validKeys,
|
||||
instructions: agent.instructions,
|
||||
llmConfig,
|
||||
tokenLimit: memoryConfig.tokenLimit,
|
||||
};
|
||||
|
||||
const userId = this.options.req.user.id + '';
|
||||
const messageId = this.responseMessageId + '';
|
||||
const conversationId = this.conversationId + '';
|
||||
const [withoutKeys, processMemory] = await createMemoryProcessor({
|
||||
userId,
|
||||
config,
|
||||
messageId,
|
||||
conversationId,
|
||||
memoryMethods: {
|
||||
setMemory,
|
||||
deleteMemory,
|
||||
getFormattedMemories,
|
||||
},
|
||||
res: this.options.res,
|
||||
});
|
||||
|
||||
this.processMemory = processMemory;
|
||||
return withoutKeys;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {BaseMessage[]} messages
|
||||
* @returns {Promise<void | (TAttachment | null)[]>}
|
||||
*/
|
||||
async runMemory(messages) {
|
||||
try {
|
||||
if (this.processMemory == null) {
|
||||
return;
|
||||
}
|
||||
/** @type {TCustomConfig['memory']} */
|
||||
const memoryConfig = this.options.req?.app?.locals?.memory;
|
||||
const messageWindowSize = memoryConfig?.messageWindowSize ?? 5;
|
||||
|
||||
let messagesToProcess = [...messages];
|
||||
if (messages.length > messageWindowSize) {
|
||||
for (let i = messages.length - messageWindowSize; i >= 0; i--) {
|
||||
const potentialWindow = messages.slice(i, i + messageWindowSize);
|
||||
if (potentialWindow[0]?.role === 'user') {
|
||||
messagesToProcess = [...potentialWindow];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (messagesToProcess.length === messages.length) {
|
||||
messagesToProcess = [...messages.slice(-messageWindowSize)];
|
||||
}
|
||||
}
|
||||
return await this.processMemory(messagesToProcess);
|
||||
} catch (error) {
|
||||
logger.error('Memory Agent failed to process memory', error);
|
||||
}
|
||||
}
|
||||
|
||||
/** @type {sendCompletion} */
|
||||
async sendCompletion(payload, opts = {}) {
|
||||
await this.chatCompletion({
|
||||
|
|
@ -544,100 +659,13 @@ class AgentClient extends BaseClient {
|
|||
let config;
|
||||
/** @type {ReturnType<createRun>} */
|
||||
let run;
|
||||
/** @type {Promise<(TAttachment | null)[] | undefined>} */
|
||||
let memoryPromise;
|
||||
try {
|
||||
if (!abortController) {
|
||||
abortController = new AbortController();
|
||||
}
|
||||
|
||||
// if (this.options.headers) {
|
||||
// opts.defaultHeaders = { ...opts.defaultHeaders, ...this.options.headers };
|
||||
// }
|
||||
|
||||
// if (this.options.proxy) {
|
||||
// opts.httpAgent = new HttpsProxyAgent(this.options.proxy);
|
||||
// }
|
||||
|
||||
// if (this.isVisionModel) {
|
||||
// modelOptions.max_tokens = 4000;
|
||||
// }
|
||||
|
||||
// /** @type {TAzureConfig | undefined} */
|
||||
// const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI];
|
||||
|
||||
// if (
|
||||
// (this.azure && this.isVisionModel && azureConfig) ||
|
||||
// (azureConfig && this.isVisionModel && this.options.endpoint === EModelEndpoint.azureOpenAI)
|
||||
// ) {
|
||||
// const { modelGroupMap, groupMap } = azureConfig;
|
||||
// const {
|
||||
// azureOptions,
|
||||
// baseURL,
|
||||
// headers = {},
|
||||
// serverless,
|
||||
// } = mapModelToAzureConfig({
|
||||
// modelName: modelOptions.model,
|
||||
// modelGroupMap,
|
||||
// groupMap,
|
||||
// });
|
||||
// opts.defaultHeaders = resolveHeaders(headers);
|
||||
// this.langchainProxy = extractBaseURL(baseURL);
|
||||
// this.apiKey = azureOptions.azureOpenAIApiKey;
|
||||
|
||||
// const groupName = modelGroupMap[modelOptions.model].group;
|
||||
// this.options.addParams = azureConfig.groupMap[groupName].addParams;
|
||||
// this.options.dropParams = azureConfig.groupMap[groupName].dropParams;
|
||||
// // Note: `forcePrompt` not re-assigned as only chat models are vision models
|
||||
|
||||
// this.azure = !serverless && azureOptions;
|
||||
// this.azureEndpoint =
|
||||
// !serverless && genAzureChatCompletion(this.azure, modelOptions.model, this);
|
||||
// }
|
||||
|
||||
// if (this.azure || this.options.azure) {
|
||||
// /* Azure Bug, extremely short default `max_tokens` response */
|
||||
// if (!modelOptions.max_tokens && modelOptions.model === 'gpt-4-vision-preview') {
|
||||
// modelOptions.max_tokens = 4000;
|
||||
// }
|
||||
|
||||
// /* Azure does not accept `model` in the body, so we need to remove it. */
|
||||
// delete modelOptions.model;
|
||||
|
||||
// opts.baseURL = this.langchainProxy
|
||||
// ? constructAzureURL({
|
||||
// baseURL: this.langchainProxy,
|
||||
// azureOptions: this.azure,
|
||||
// })
|
||||
// : this.azureEndpoint.split(/(?<!\/)\/(chat|completion)\//)[0];
|
||||
|
||||
// opts.defaultQuery = { 'api-version': this.azure.azureOpenAIApiVersion };
|
||||
// opts.defaultHeaders = { ...opts.defaultHeaders, 'api-key': this.apiKey };
|
||||
// }
|
||||
|
||||
// if (process.env.OPENAI_ORGANIZATION) {
|
||||
// opts.organization = process.env.OPENAI_ORGANIZATION;
|
||||
// }
|
||||
|
||||
// if (this.options.addParams && typeof this.options.addParams === 'object') {
|
||||
// modelOptions = {
|
||||
// ...modelOptions,
|
||||
// ...this.options.addParams,
|
||||
// };
|
||||
// logger.debug('[api/server/controllers/agents/client.js #chatCompletion] added params', {
|
||||
// addParams: this.options.addParams,
|
||||
// modelOptions,
|
||||
// });
|
||||
// }
|
||||
|
||||
// if (this.options.dropParams && Array.isArray(this.options.dropParams)) {
|
||||
// this.options.dropParams.forEach((param) => {
|
||||
// delete modelOptions[param];
|
||||
// });
|
||||
// logger.debug('[api/server/controllers/agents/client.js #chatCompletion] dropped params', {
|
||||
// dropParams: this.options.dropParams,
|
||||
// modelOptions,
|
||||
// });
|
||||
// }
|
||||
|
||||
/** @type {TCustomConfig['endpoints']['agents']} */
|
||||
const agentsEConfig = this.options.req.app.locals[EModelEndpoint.agents];
|
||||
|
||||
|
|
@ -647,6 +675,7 @@ class AgentClient extends BaseClient {
|
|||
last_agent_index: this.agentConfigs?.size ?? 0,
|
||||
user_id: this.user ?? this.options.req.user?.id,
|
||||
hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
|
||||
user: this.options.req.user,
|
||||
},
|
||||
recursionLimit: agentsEConfig?.recursionLimit,
|
||||
signal: abortController.signal,
|
||||
|
|
@ -654,6 +683,8 @@ class AgentClient extends BaseClient {
|
|||
version: 'v2',
|
||||
};
|
||||
|
||||
const getUserMCPAuthMap = await createGetMCPAuthMap();
|
||||
|
||||
const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name));
|
||||
let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages(
|
||||
payload,
|
||||
|
|
@ -734,6 +765,10 @@ class AgentClient extends BaseClient {
|
|||
messages = addCacheControl(messages);
|
||||
}
|
||||
|
||||
if (i === 0) {
|
||||
memoryPromise = this.runMemory(messages);
|
||||
}
|
||||
|
||||
run = await createRun({
|
||||
agent,
|
||||
req: this.options.req,
|
||||
|
|
@ -769,10 +804,23 @@ class AgentClient extends BaseClient {
|
|||
run.Graph.contentData = contentData;
|
||||
}
|
||||
|
||||
const encoding = this.getEncoding();
|
||||
try {
|
||||
if (getUserMCPAuthMap) {
|
||||
config.configurable.userMCPAuthMap = await getUserMCPAuthMap({
|
||||
tools: agent.tools,
|
||||
userId: this.options.req.user.id,
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent ${agent.id}`,
|
||||
err,
|
||||
);
|
||||
}
|
||||
|
||||
await run.processStream({ messages }, config, {
|
||||
keepContent: i !== 0,
|
||||
tokenCounter: createTokenCounter(encoding),
|
||||
tokenCounter: createTokenCounter(this.getEncoding()),
|
||||
indexTokenCountMap: currentIndexCountMap,
|
||||
maxContextTokens: agent.maxContextTokens,
|
||||
callbacks: {
|
||||
|
|
@ -887,6 +935,12 @@ class AgentClient extends BaseClient {
|
|||
});
|
||||
|
||||
try {
|
||||
if (memoryPromise) {
|
||||
const attachments = await memoryPromise;
|
||||
if (attachments && attachments.length > 0) {
|
||||
this.artifactPromises.push(...attachments);
|
||||
}
|
||||
}
|
||||
await this.recordCollectedUsage({ context: 'message' });
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
|
|
@ -895,6 +949,12 @@ class AgentClient extends BaseClient {
|
|||
);
|
||||
}
|
||||
} catch (err) {
|
||||
if (memoryPromise) {
|
||||
const attachments = await memoryPromise;
|
||||
if (attachments && attachments.length > 0) {
|
||||
this.artifactPromises.push(...attachments);
|
||||
}
|
||||
}
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #sendCompletion] Operation aborted',
|
||||
err,
|
||||
|
|
|
|||
|
|
@ -1,94 +0,0 @@
|
|||
const { Run, Providers } = require('@librechat/agents');
|
||||
const { providerEndpointMap, KnownEndpoints } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
* @typedef {import('@librechat/agents').t} t
|
||||
* @typedef {import('@librechat/agents').StandardGraphConfig} StandardGraphConfig
|
||||
* @typedef {import('@librechat/agents').StreamEventData} StreamEventData
|
||||
* @typedef {import('@librechat/agents').EventHandler} EventHandler
|
||||
* @typedef {import('@librechat/agents').GraphEvents} GraphEvents
|
||||
* @typedef {import('@librechat/agents').LLMConfig} LLMConfig
|
||||
* @typedef {import('@librechat/agents').IState} IState
|
||||
*/
|
||||
|
||||
const customProviders = new Set([
|
||||
Providers.XAI,
|
||||
Providers.OLLAMA,
|
||||
Providers.DEEPSEEK,
|
||||
Providers.OPENROUTER,
|
||||
]);
|
||||
|
||||
/**
|
||||
* Creates a new Run instance with custom handlers and configuration.
|
||||
*
|
||||
* @param {Object} options - The options for creating the Run instance.
|
||||
* @param {ServerRequest} [options.req] - The server request.
|
||||
* @param {string | undefined} [options.runId] - Optional run ID; otherwise, a new run ID will be generated.
|
||||
* @param {Agent} options.agent - The agent for this run.
|
||||
* @param {AbortSignal} options.signal - The signal for this run.
|
||||
* @param {Record<GraphEvents, EventHandler> | undefined} [options.customHandlers] - Custom event handlers.
|
||||
* @param {boolean} [options.streaming=true] - Whether to use streaming.
|
||||
* @param {boolean} [options.streamUsage=true] - Whether to stream usage information.
|
||||
* @returns {Promise<Run<IState>>} A promise that resolves to a new Run instance.
|
||||
*/
|
||||
async function createRun({
|
||||
runId,
|
||||
agent,
|
||||
signal,
|
||||
customHandlers,
|
||||
streaming = true,
|
||||
streamUsage = true,
|
||||
}) {
|
||||
const provider = providerEndpointMap[agent.provider] ?? agent.provider;
|
||||
/** @type {LLMConfig} */
|
||||
const llmConfig = Object.assign(
|
||||
{
|
||||
provider,
|
||||
streaming,
|
||||
streamUsage,
|
||||
},
|
||||
agent.model_parameters,
|
||||
);
|
||||
|
||||
/** Resolves issues with new OpenAI usage field */
|
||||
if (
|
||||
customProviders.has(agent.provider) ||
|
||||
(agent.provider === Providers.OPENAI && agent.endpoint !== agent.provider)
|
||||
) {
|
||||
llmConfig.streamUsage = false;
|
||||
llmConfig.usage = true;
|
||||
}
|
||||
|
||||
/** @type {'reasoning_content' | 'reasoning'} */
|
||||
let reasoningKey;
|
||||
if (
|
||||
llmConfig.configuration?.baseURL?.includes(KnownEndpoints.openrouter) ||
|
||||
(agent.endpoint && agent.endpoint.toLowerCase().includes(KnownEndpoints.openrouter))
|
||||
) {
|
||||
reasoningKey = 'reasoning';
|
||||
}
|
||||
|
||||
/** @type {StandardGraphConfig} */
|
||||
const graphConfig = {
|
||||
signal,
|
||||
llmConfig,
|
||||
reasoningKey,
|
||||
tools: agent.tools,
|
||||
instructions: agent.instructions,
|
||||
additional_instructions: agent.additional_instructions,
|
||||
// toolEnd: agent.end_after_tools,
|
||||
};
|
||||
|
||||
// TEMPORARY FOR TESTING
|
||||
if (agent.provider === Providers.ANTHROPIC || agent.provider === Providers.BEDROCK) {
|
||||
graphConfig.streamBuffer = 2000;
|
||||
}
|
||||
|
||||
return Run.create({
|
||||
runId,
|
||||
graphConfig,
|
||||
customHandlers,
|
||||
});
|
||||
}
|
||||
|
||||
module.exports = { createRun };
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
const fs = require('fs').promises;
|
||||
const { nanoid } = require('nanoid');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
FileContext,
|
||||
FileSources,
|
||||
SystemRoles,
|
||||
EToolResources,
|
||||
|
|
@ -16,16 +16,16 @@ const {
|
|||
deleteAgent,
|
||||
getListAgents,
|
||||
} = require('~/models/Agent');
|
||||
const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { resizeAvatar } = require('~/server/services/Files/images/avatar');
|
||||
const { refreshS3Url } = require('~/server/services/Files/S3/crud');
|
||||
const { filterFile } = require('~/server/services/Files/process');
|
||||
const { updateAction, getActions } = require('~/models/Action');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { updateAgentProjects } = require('~/models/Agent');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
const { deleteFileByFilter } = require('~/models/File');
|
||||
const { revertAgentVersion } = require('~/models/Agent');
|
||||
const { logger } = require('~/config');
|
||||
const { deleteFileByFilter } = require('~/models/File');
|
||||
|
||||
const systemTools = {
|
||||
[Tools.execute_code]: true,
|
||||
|
|
@ -47,8 +47,9 @@ const createAgentHandler = async (req, res) => {
|
|||
|
||||
agentData.tools = [];
|
||||
|
||||
const availableTools = await getCachedTools({ includeGlobal: true });
|
||||
for (const tool of tools) {
|
||||
if (req.app.locals.availableTools[tool]) {
|
||||
if (availableTools[tool]) {
|
||||
agentData.tools.push(tool);
|
||||
}
|
||||
|
||||
|
|
@ -169,12 +170,18 @@ const updateAgentHandler = async (req, res) => {
|
|||
});
|
||||
}
|
||||
|
||||
/** @type {boolean} */
|
||||
const isProjectUpdate = (projectIds?.length ?? 0) > 0 || (removeProjectIds?.length ?? 0) > 0;
|
||||
|
||||
let updatedAgent =
|
||||
Object.keys(updateData).length > 0
|
||||
? await updateAgent({ id }, updateData, { updatingUserId: req.user.id })
|
||||
? await updateAgent({ id }, updateData, {
|
||||
updatingUserId: req.user.id,
|
||||
skipVersioning: isProjectUpdate,
|
||||
})
|
||||
: existingAgent;
|
||||
|
||||
if (projectIds || removeProjectIds) {
|
||||
if (isProjectUpdate) {
|
||||
updatedAgent = await updateAgentProjects({
|
||||
user: req.user,
|
||||
agentId: id,
|
||||
|
|
@ -387,6 +394,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
|||
buffer: resizedBuffer,
|
||||
userId: req.user.id,
|
||||
manual: 'false',
|
||||
agentId: agent_id,
|
||||
});
|
||||
|
||||
const image = {
|
||||
|
|
@ -438,7 +446,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
|||
try {
|
||||
await fs.unlink(req.file.path);
|
||||
logger.debug('[/:agent_id/avatar] Temp. image upload file deleted');
|
||||
} catch (error) {
|
||||
} catch {
|
||||
logger.debug('[/:agent_id/avatar] Temp. image upload file already deleted');
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
const fs = require('fs').promises;
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { FileContext } = require('librechat-data-provider');
|
||||
const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
|
|
@ -6,9 +7,9 @@ const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
|||
const { deleteAssistantActions } = require('~/server/services/ActionService');
|
||||
const { updateAssistantDoc, getAssistants } = require('~/models/Assistant');
|
||||
const { getOpenAIClient, fetchAssistants } = require('./helpers');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { manifestToolMap } = require('~/app/clients/tools');
|
||||
const { deleteFileByFilter } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Create an assistant.
|
||||
|
|
@ -30,21 +31,20 @@ const createAssistant = async (req, res) => {
|
|||
delete assistantData.conversation_starters;
|
||||
delete assistantData.append_current_datetime;
|
||||
|
||||
const toolDefinitions = await getCachedTools({ includeGlobal: true });
|
||||
|
||||
assistantData.tools = tools
|
||||
.map((tool) => {
|
||||
if (typeof tool !== 'string') {
|
||||
return tool;
|
||||
}
|
||||
|
||||
const toolDefinitions = req.app.locals.availableTools;
|
||||
const toolDef = toolDefinitions[tool];
|
||||
if (!toolDef && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) {
|
||||
return (
|
||||
Object.entries(toolDefinitions)
|
||||
.filter(([key]) => key.startsWith(`${tool}_`))
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
.map(([_, val]) => val)
|
||||
);
|
||||
return Object.entries(toolDefinitions)
|
||||
.filter(([key]) => key.startsWith(`${tool}_`))
|
||||
|
||||
.map(([_, val]) => val);
|
||||
}
|
||||
|
||||
return toolDef;
|
||||
|
|
@ -135,21 +135,21 @@ const patchAssistant = async (req, res) => {
|
|||
append_current_datetime,
|
||||
...updateData
|
||||
} = req.body;
|
||||
|
||||
const toolDefinitions = await getCachedTools({ includeGlobal: true });
|
||||
|
||||
updateData.tools = (updateData.tools ?? [])
|
||||
.map((tool) => {
|
||||
if (typeof tool !== 'string') {
|
||||
return tool;
|
||||
}
|
||||
|
||||
const toolDefinitions = req.app.locals.availableTools;
|
||||
const toolDef = toolDefinitions[tool];
|
||||
if (!toolDef && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) {
|
||||
return (
|
||||
Object.entries(toolDefinitions)
|
||||
.filter(([key]) => key.startsWith(`${tool}_`))
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
.map(([_, val]) => val)
|
||||
);
|
||||
return Object.entries(toolDefinitions)
|
||||
.filter(([key]) => key.startsWith(`${tool}_`))
|
||||
|
||||
.map(([_, val]) => val);
|
||||
}
|
||||
|
||||
return toolDef;
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ToolCallTypes } = require('librechat-data-provider');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { validateAndUpdateTool } = require('~/server/services/ActionService');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { updateAssistantDoc } = require('~/models/Assistant');
|
||||
const { manifestToolMap } = require('~/app/clients/tools');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Create an assistant.
|
||||
|
|
@ -27,21 +28,20 @@ const createAssistant = async (req, res) => {
|
|||
delete assistantData.conversation_starters;
|
||||
delete assistantData.append_current_datetime;
|
||||
|
||||
const toolDefinitions = await getCachedTools({ includeGlobal: true });
|
||||
|
||||
assistantData.tools = tools
|
||||
.map((tool) => {
|
||||
if (typeof tool !== 'string') {
|
||||
return tool;
|
||||
}
|
||||
|
||||
const toolDefinitions = req.app.locals.availableTools;
|
||||
const toolDef = toolDefinitions[tool];
|
||||
if (!toolDef && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) {
|
||||
return (
|
||||
Object.entries(toolDefinitions)
|
||||
.filter(([key]) => key.startsWith(`${tool}_`))
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
.map(([_, val]) => val)
|
||||
);
|
||||
return Object.entries(toolDefinitions)
|
||||
.filter(([key]) => key.startsWith(`${tool}_`))
|
||||
|
||||
.map(([_, val]) => val);
|
||||
}
|
||||
|
||||
return toolDef;
|
||||
|
|
@ -125,13 +125,13 @@ const updateAssistant = async ({ req, openai, assistant_id, updateData }) => {
|
|||
|
||||
let hasFileSearch = false;
|
||||
for (const tool of updateData.tools ?? []) {
|
||||
const toolDefinitions = req.app.locals.availableTools;
|
||||
const toolDefinitions = await getCachedTools({ includeGlobal: true });
|
||||
let actualTool = typeof tool === 'string' ? toolDefinitions[tool] : tool;
|
||||
|
||||
if (!actualTool && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) {
|
||||
actualTool = Object.entries(toolDefinitions)
|
||||
.filter(([key]) => key.startsWith(`${tool}_`))
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
|
||||
.map(([_, val]) => val);
|
||||
} else if (!actualTool) {
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -1,22 +1,22 @@
|
|||
require('dotenv').config();
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
require('module-alias')({ base: path.resolve(__dirname, '..') });
|
||||
const cors = require('cors');
|
||||
const axios = require('axios');
|
||||
const express = require('express');
|
||||
const compression = require('compression');
|
||||
const passport = require('passport');
|
||||
const mongoSanitize = require('express-mongo-sanitize');
|
||||
const fs = require('fs');
|
||||
const compression = require('compression');
|
||||
const cookieParser = require('cookie-parser');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const mongoSanitize = require('express-mongo-sanitize');
|
||||
const { connectDb, indexSync } = require('~/db');
|
||||
|
||||
const { jwtLogin, passportLogin } = require('~/strategies');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { ldapLogin } = require('~/strategies');
|
||||
const { logger } = require('~/config');
|
||||
const validateImageRequest = require('./middleware/validateImageRequest');
|
||||
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
|
||||
const errorController = require('./controllers/ErrorController');
|
||||
const initializeMCP = require('./services/initializeMCP');
|
||||
const configureSocialLogins = require('./socialLogins');
|
||||
const AppService = require('./services/AppService');
|
||||
const staticCache = require('./utils/staticCache');
|
||||
|
|
@ -39,7 +39,9 @@ const startServer = async () => {
|
|||
await connectDb();
|
||||
|
||||
logger.info('Connected to MongoDB');
|
||||
await indexSync();
|
||||
indexSync().catch((err) => {
|
||||
logger.error('[indexSync] Background sync failed:', err);
|
||||
});
|
||||
|
||||
app.disable('x-powered-by');
|
||||
app.set('trust proxy', trusted_proxy);
|
||||
|
|
@ -117,8 +119,9 @@ const startServer = async () => {
|
|||
app.use('/api/agents', routes.agents);
|
||||
app.use('/api/banner', routes.banner);
|
||||
app.use('/api/bedrock', routes.bedrock);
|
||||
|
||||
app.use('/api/memories', routes.memories);
|
||||
app.use('/api/tags', routes.tags);
|
||||
app.use('/api/mcp', routes.mcp);
|
||||
|
||||
app.use((req, res) => {
|
||||
res.set({
|
||||
|
|
@ -142,6 +145,8 @@ const startServer = async () => {
|
|||
} else {
|
||||
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
|
||||
}
|
||||
|
||||
initializeMCP(app);
|
||||
});
|
||||
};
|
||||
|
||||
|
|
@ -184,5 +189,5 @@ process.on('uncaughtException', (err) => {
|
|||
process.exit(1);
|
||||
});
|
||||
|
||||
// export app for easier testing purposes
|
||||
/** Export app for easier testing purposes */
|
||||
module.exports = app;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const checkAdmin = require('./checkAdmin');
|
||||
const { checkAccess, generateCheckAccess } = require('./generateCheckAccess');
|
||||
const checkAdmin = require('./admin');
|
||||
const { checkAccess, generateCheckAccess } = require('./access');
|
||||
|
||||
module.exports = {
|
||||
checkAdmin,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
const { isEnabled } = require('@librechat/api');
|
||||
const { Constants, ViolationTypes, Time } = require('librechat-data-provider');
|
||||
const { searchConversation } = require('~/models/Conversation');
|
||||
const denyRequest = require('~/server/middleware/denyRequest');
|
||||
const { logViolation, getLogStores } = require('~/cache');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
|
||||
const { USE_REDIS, CONVO_ACCESS_VIOLATION_SCORE: score = 0 } = process.env ?? {};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
const express = require('express');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const { getAccessToken } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { getAccessToken } = require('~/server/services/TokenService');
|
||||
const { logger, getFlowStateManager } = require('~/config');
|
||||
const { findToken, updateToken, createToken } = require('~/models');
|
||||
const { getFlowStateManager } = require('~/config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
const router = express.Router();
|
||||
|
|
@ -28,18 +30,19 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
|
|||
try {
|
||||
decodedState = jwt.verify(state, JWT_SECRET);
|
||||
} catch (err) {
|
||||
logger.error('Error verifying state parameter:', err);
|
||||
await flowManager.failFlow(identifier, 'oauth', 'Invalid or expired state parameter');
|
||||
return res.status(400).send('Invalid or expired state parameter');
|
||||
return res.redirect('/oauth/error?error=invalid_state');
|
||||
}
|
||||
|
||||
if (decodedState.action_id !== action_id) {
|
||||
await flowManager.failFlow(identifier, 'oauth', 'Mismatched action ID in state parameter');
|
||||
return res.status(400).send('Mismatched action ID in state parameter');
|
||||
return res.redirect('/oauth/error?error=invalid_state');
|
||||
}
|
||||
|
||||
if (!decodedState.user) {
|
||||
await flowManager.failFlow(identifier, 'oauth', 'Invalid user ID in state parameter');
|
||||
return res.status(400).send('Invalid user ID in state parameter');
|
||||
return res.redirect('/oauth/error?error=invalid_state');
|
||||
}
|
||||
identifier = `${decodedState.user}:${action_id}`;
|
||||
const flowState = await flowManager.getFlowState(identifier, 'oauth');
|
||||
|
|
@ -47,90 +50,34 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
|
|||
throw new Error('OAuth flow not found');
|
||||
}
|
||||
|
||||
const tokenData = await getAccessToken({
|
||||
code,
|
||||
userId: decodedState.user,
|
||||
identifier,
|
||||
client_url: flowState.metadata.client_url,
|
||||
redirect_uri: flowState.metadata.redirect_uri,
|
||||
/** Encrypted values */
|
||||
encrypted_oauth_client_id: flowState.metadata.encrypted_oauth_client_id,
|
||||
encrypted_oauth_client_secret: flowState.metadata.encrypted_oauth_client_secret,
|
||||
});
|
||||
const tokenData = await getAccessToken(
|
||||
{
|
||||
code,
|
||||
userId: decodedState.user,
|
||||
identifier,
|
||||
client_url: flowState.metadata.client_url,
|
||||
redirect_uri: flowState.metadata.redirect_uri,
|
||||
token_exchange_method: flowState.metadata.token_exchange_method,
|
||||
/** Encrypted values */
|
||||
encrypted_oauth_client_id: flowState.metadata.encrypted_oauth_client_id,
|
||||
encrypted_oauth_client_secret: flowState.metadata.encrypted_oauth_client_secret,
|
||||
},
|
||||
{
|
||||
findToken,
|
||||
updateToken,
|
||||
createToken,
|
||||
},
|
||||
);
|
||||
await flowManager.completeFlow(identifier, 'oauth', tokenData);
|
||||
res.send(`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Authentication Successful</title>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
|
||||
<style>
|
||||
body {
|
||||
font-family: ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont;
|
||||
background-color: rgb(249, 250, 251);
|
||||
margin: 0;
|
||||
padding: 2rem;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
min-height: 100vh;
|
||||
}
|
||||
.card {
|
||||
background-color: white;
|
||||
border-radius: 0.5rem;
|
||||
padding: 2rem;
|
||||
max-width: 28rem;
|
||||
width: 100%;
|
||||
box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
|
||||
text-align: center;
|
||||
}
|
||||
.heading {
|
||||
color: rgb(17, 24, 39);
|
||||
font-size: 1.875rem;
|
||||
font-weight: 700;
|
||||
margin: 0 0 1rem;
|
||||
}
|
||||
.description {
|
||||
color: rgb(75, 85, 99);
|
||||
font-size: 0.875rem;
|
||||
margin: 0.5rem 0;
|
||||
}
|
||||
.countdown {
|
||||
color: rgb(99, 102, 241);
|
||||
font-weight: 500;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="card">
|
||||
<h1 class="heading">Authentication Successful</h1>
|
||||
<p class="description">
|
||||
Your authentication was successful. This window will close in
|
||||
<span class="countdown" id="countdown">3</span> seconds.
|
||||
</p>
|
||||
</div>
|
||||
<script>
|
||||
let secondsLeft = 3;
|
||||
const countdownElement = document.getElementById('countdown');
|
||||
|
||||
const countdown = setInterval(() => {
|
||||
secondsLeft--;
|
||||
countdownElement.textContent = secondsLeft;
|
||||
|
||||
if (secondsLeft <= 0) {
|
||||
clearInterval(countdown);
|
||||
window.close();
|
||||
}
|
||||
}, 1000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
`);
|
||||
|
||||
/** Redirect to React success page */
|
||||
const serverName = flowState.metadata?.action_name || `Action ${action_id}`;
|
||||
const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`;
|
||||
res.redirect(redirectUrl);
|
||||
} catch (error) {
|
||||
logger.error('Error in OAuth callback:', error);
|
||||
await flowManager.failFlow(identifier, 'oauth', error);
|
||||
res.status(500).send('Authentication failed. Please try again.');
|
||||
res.redirect('/oauth/error?error=callback_failed');
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
const express = require('express');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, defaultSocialLogins, Constants } = require('librechat-data-provider');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
const { getLdapConfig } = require('~/server/services/Config/ldap');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const router = express.Router();
|
||||
const emailLoginEnabled =
|
||||
|
|
@ -21,6 +22,7 @@ const publicSharedLinksEnabled =
|
|||
|
||||
router.get('/', async function (req, res) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
|
||||
const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG);
|
||||
if (cachedStartupConfig) {
|
||||
res.send(cachedStartupConfig);
|
||||
|
|
@ -96,6 +98,18 @@ router.get('/', async function (req, res) {
|
|||
bundlerURL: process.env.SANDPACK_BUNDLER_URL,
|
||||
staticBundlerURL: process.env.SANDPACK_STATIC_BUNDLER_URL,
|
||||
};
|
||||
|
||||
payload.mcpServers = {};
|
||||
const config = await getCustomConfig();
|
||||
if (config?.mcpServers != null) {
|
||||
for (const serverName in config.mcpServers) {
|
||||
const serverConfig = config.mcpServers[serverName];
|
||||
payload.mcpServers[serverName] = {
|
||||
customUserVars: serverConfig?.customUserVars || {},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/** @type {TCustomConfig['webSearch']} */
|
||||
const webSearchConfig = req.app.locals.webSearch;
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -65,8 +65,14 @@ router.post('/gen_title', async (req, res) => {
|
|||
let title = await titleCache.get(key);
|
||||
|
||||
if (!title) {
|
||||
await sleep(2500);
|
||||
title = await titleCache.get(key);
|
||||
// Retry every 1s for up to 20s
|
||||
for (let i = 0; i < 20; i++) {
|
||||
await sleep(1000);
|
||||
title = await titleCache.get(key);
|
||||
if (title) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (title) {
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ const fs = require('fs');
|
|||
const path = require('path');
|
||||
const crypto = require('crypto');
|
||||
const multer = require('multer');
|
||||
const { sanitizeFilename } = require('@librechat/api');
|
||||
const { fileConfig: defaultFileConfig, mergeFileConfig } = require('librechat-data-provider');
|
||||
const { sanitizeFilename } = require('~/server/utils/handleText');
|
||||
const { getCustomConfig } = require('~/server/services/Config');
|
||||
|
||||
const storage = multer.diskStorage({
|
||||
|
|
|
|||
571
api/server/routes/files/multer.spec.js
Normal file
571
api/server/routes/files/multer.spec.js
Normal file
|
|
@ -0,0 +1,571 @@
|
|||
/* eslint-disable no-unused-vars */
|
||||
/* eslint-disable jest/no-done-callback */
|
||||
const fs = require('fs');
|
||||
const os = require('os');
|
||||
const path = require('path');
|
||||
const crypto = require('crypto');
|
||||
const { createMulterInstance, storage, importFileFilter } = require('./multer');
|
||||
|
||||
// Mock only the config service that requires external dependencies
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getCustomConfig: jest.fn(() =>
|
||||
Promise.resolve({
|
||||
fileConfig: {
|
||||
endpoints: {
|
||||
openAI: {
|
||||
supportedMimeTypes: ['image/jpeg', 'image/png', 'application/pdf'],
|
||||
},
|
||||
default: {
|
||||
supportedMimeTypes: ['image/jpeg', 'image/png', 'text/plain'],
|
||||
},
|
||||
},
|
||||
serverFileSizeLimit: 10000000, // 10MB
|
||||
},
|
||||
}),
|
||||
),
|
||||
}));
|
||||
|
||||
describe('Multer Configuration', () => {
|
||||
let tempDir;
|
||||
let mockReq;
|
||||
let mockFile;
|
||||
|
||||
beforeEach(() => {
|
||||
// Create a temporary directory for each test
|
||||
tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'multer-test-'));
|
||||
|
||||
mockReq = {
|
||||
user: { id: 'test-user-123' },
|
||||
app: {
|
||||
locals: {
|
||||
paths: {
|
||||
uploads: tempDir,
|
||||
},
|
||||
},
|
||||
},
|
||||
body: {},
|
||||
originalUrl: '/api/files/upload',
|
||||
};
|
||||
|
||||
mockFile = {
|
||||
originalname: 'test-file.jpg',
|
||||
mimetype: 'image/jpeg',
|
||||
size: 1024,
|
||||
};
|
||||
|
||||
// Clear mocks
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
// Clean up temporary directory
|
||||
if (fs.existsSync(tempDir)) {
|
||||
fs.rmSync(tempDir, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
|
||||
describe('Storage Configuration', () => {
|
||||
describe('destination function', () => {
|
||||
it('should create the correct destination path', (done) => {
|
||||
const cb = jest.fn((err, destination) => {
|
||||
expect(err).toBeNull();
|
||||
expect(destination).toBe(path.join(tempDir, 'temp', 'test-user-123'));
|
||||
expect(fs.existsSync(destination)).toBe(true);
|
||||
done();
|
||||
});
|
||||
|
||||
storage.getDestination(mockReq, mockFile, cb);
|
||||
});
|
||||
|
||||
it("should create directory recursively if it doesn't exist", (done) => {
|
||||
const deepPath = path.join(tempDir, 'deep', 'nested', 'path');
|
||||
mockReq.app.locals.paths.uploads = deepPath;
|
||||
|
||||
const cb = jest.fn((err, destination) => {
|
||||
expect(err).toBeNull();
|
||||
expect(destination).toBe(path.join(deepPath, 'temp', 'test-user-123'));
|
||||
expect(fs.existsSync(destination)).toBe(true);
|
||||
done();
|
||||
});
|
||||
|
||||
storage.getDestination(mockReq, mockFile, cb);
|
||||
});
|
||||
});
|
||||
|
||||
describe('filename function', () => {
|
||||
it('should generate a UUID for req.file_id', (done) => {
|
||||
const cb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
expect(mockReq.file_id).toBeDefined();
|
||||
expect(mockReq.file_id).toMatch(
|
||||
/^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i,
|
||||
);
|
||||
done();
|
||||
});
|
||||
|
||||
storage.getFilename(mockReq, mockFile, cb);
|
||||
});
|
||||
|
||||
it('should decode URI components in filename', (done) => {
|
||||
const encodedFile = {
|
||||
...mockFile,
|
||||
originalname: encodeURIComponent('test file with spaces.jpg'),
|
||||
};
|
||||
|
||||
const cb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
expect(encodedFile.originalname).toBe('test file with spaces.jpg');
|
||||
done();
|
||||
});
|
||||
|
||||
storage.getFilename(mockReq, encodedFile, cb);
|
||||
});
|
||||
|
||||
it('should call real sanitizeFilename with properly encoded filename', (done) => {
|
||||
// Test with a properly URI-encoded filename that needs sanitization
|
||||
const unsafeFile = {
|
||||
...mockFile,
|
||||
originalname: encodeURIComponent('test@#$%^&*()file with spaces!.jpg'),
|
||||
};
|
||||
|
||||
const cb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
// The actual sanitizeFilename should have cleaned this up after decoding
|
||||
expect(filename).not.toContain('@');
|
||||
expect(filename).not.toContain('#');
|
||||
expect(filename).not.toContain('*');
|
||||
expect(filename).not.toContain('!');
|
||||
// Should still preserve dots and hyphens
|
||||
expect(filename).toContain('.jpg');
|
||||
done();
|
||||
});
|
||||
|
||||
storage.getFilename(mockReq, unsafeFile, cb);
|
||||
});
|
||||
|
||||
it('should handle very long filenames with actual crypto', (done) => {
|
||||
const longFile = {
|
||||
...mockFile,
|
||||
originalname: 'a'.repeat(300) + '.jpg',
|
||||
};
|
||||
|
||||
const cb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
expect(filename.length).toBeLessThanOrEqual(255);
|
||||
expect(filename).toMatch(/\.jpg$/); // Should still end with .jpg
|
||||
// Should contain a hex suffix if truncated
|
||||
if (filename.length === 255) {
|
||||
expect(filename).toMatch(/-[a-f0-9]{6}\.jpg$/);
|
||||
}
|
||||
done();
|
||||
});
|
||||
|
||||
storage.getFilename(mockReq, longFile, cb);
|
||||
});
|
||||
|
||||
it('should generate unique file_id for each call', (done) => {
|
||||
let firstFileId;
|
||||
|
||||
const firstCb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
firstFileId = mockReq.file_id;
|
||||
|
||||
// Reset req for second call
|
||||
delete mockReq.file_id;
|
||||
|
||||
const secondCb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
expect(mockReq.file_id).toBeDefined();
|
||||
expect(mockReq.file_id).not.toBe(firstFileId);
|
||||
done();
|
||||
});
|
||||
|
||||
storage.getFilename(mockReq, mockFile, secondCb);
|
||||
});
|
||||
|
||||
storage.getFilename(mockReq, mockFile, firstCb);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Import File Filter', () => {
|
||||
it('should accept JSON files by mimetype', (done) => {
|
||||
const jsonFile = {
|
||||
...mockFile,
|
||||
mimetype: 'application/json',
|
||||
originalname: 'data.json',
|
||||
};
|
||||
|
||||
const cb = jest.fn((err, result) => {
|
||||
expect(err).toBeNull();
|
||||
expect(result).toBe(true);
|
||||
done();
|
||||
});
|
||||
|
||||
importFileFilter(mockReq, jsonFile, cb);
|
||||
});
|
||||
|
||||
it('should accept files with .json extension', (done) => {
|
||||
const jsonFile = {
|
||||
...mockFile,
|
||||
mimetype: 'text/plain',
|
||||
originalname: 'data.json',
|
||||
};
|
||||
|
||||
const cb = jest.fn((err, result) => {
|
||||
expect(err).toBeNull();
|
||||
expect(result).toBe(true);
|
||||
done();
|
||||
});
|
||||
|
||||
importFileFilter(mockReq, jsonFile, cb);
|
||||
});
|
||||
|
||||
it('should reject non-JSON files', (done) => {
|
||||
const textFile = {
|
||||
...mockFile,
|
||||
mimetype: 'text/plain',
|
||||
originalname: 'document.txt',
|
||||
};
|
||||
|
||||
const cb = jest.fn((err, result) => {
|
||||
expect(err).toBeInstanceOf(Error);
|
||||
expect(err.message).toBe('Only JSON files are allowed');
|
||||
expect(result).toBe(false);
|
||||
done();
|
||||
});
|
||||
|
||||
importFileFilter(mockReq, textFile, cb);
|
||||
});
|
||||
|
||||
it('should handle files with uppercase .JSON extension', (done) => {
|
||||
const jsonFile = {
|
||||
...mockFile,
|
||||
mimetype: 'text/plain',
|
||||
originalname: 'DATA.JSON',
|
||||
};
|
||||
|
||||
const cb = jest.fn((err, result) => {
|
||||
expect(err).toBeNull();
|
||||
expect(result).toBe(true);
|
||||
done();
|
||||
});
|
||||
|
||||
importFileFilter(mockReq, jsonFile, cb);
|
||||
});
|
||||
});
|
||||
|
||||
describe('File Filter with Real defaultFileConfig', () => {
|
||||
it('should use real fileConfig.checkType for validation', async () => {
|
||||
// Test with actual librechat-data-provider functions
|
||||
const {
|
||||
fileConfig,
|
||||
imageMimeTypes,
|
||||
applicationMimeTypes,
|
||||
} = require('librechat-data-provider');
|
||||
|
||||
// Test that the real checkType function works with regex patterns
|
||||
expect(fileConfig.checkType('image/jpeg', [imageMimeTypes])).toBe(true);
|
||||
expect(fileConfig.checkType('video/mp4', [imageMimeTypes])).toBe(false);
|
||||
expect(fileConfig.checkType('application/pdf', [applicationMimeTypes])).toBe(true);
|
||||
expect(fileConfig.checkType('application/pdf', [])).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle audio files for speech-to-text endpoint with real config', async () => {
|
||||
mockReq.originalUrl = '/api/speech/stt';
|
||||
|
||||
const multerInstance = await createMulterInstance();
|
||||
expect(multerInstance).toBeDefined();
|
||||
expect(typeof multerInstance.single).toBe('function');
|
||||
});
|
||||
|
||||
it('should reject unsupported file types using real config', async () => {
|
||||
// Mock defaultFileConfig for this specific test
|
||||
const originalCheckType = require('librechat-data-provider').fileConfig.checkType;
|
||||
const mockCheckType = jest.fn().mockReturnValue(false);
|
||||
require('librechat-data-provider').fileConfig.checkType = mockCheckType;
|
||||
|
||||
try {
|
||||
const multerInstance = await createMulterInstance();
|
||||
expect(multerInstance).toBeDefined();
|
||||
|
||||
// Test the actual file filter behavior would reject unsupported files
|
||||
expect(mockCheckType).toBeDefined();
|
||||
} finally {
|
||||
// Restore original function
|
||||
require('librechat-data-provider').fileConfig.checkType = originalCheckType;
|
||||
}
|
||||
});
|
||||
|
||||
it('should use real mergeFileConfig function', async () => {
|
||||
const { mergeFileConfig, mbToBytes } = require('librechat-data-provider');
|
||||
|
||||
// Test with actual merge function - note that it converts MB to bytes
|
||||
const testConfig = {
|
||||
serverFileSizeLimit: 5, // 5 MB
|
||||
endpoints: {
|
||||
custom: {
|
||||
supportedMimeTypes: ['text/plain'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = mergeFileConfig(testConfig);
|
||||
|
||||
// The function converts MB to bytes, so 5 MB becomes 5 * 1024 * 1024 bytes
|
||||
expect(result.serverFileSizeLimit).toBe(mbToBytes(5));
|
||||
expect(result.endpoints.custom.supportedMimeTypes).toBeDefined();
|
||||
// Should still have the default endpoints
|
||||
expect(result.endpoints.default).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('createMulterInstance with Real Functions', () => {
|
||||
it('should create a multer instance with correct configuration', async () => {
|
||||
const multerInstance = await createMulterInstance();
|
||||
|
||||
expect(multerInstance).toBeDefined();
|
||||
expect(typeof multerInstance.single).toBe('function');
|
||||
expect(typeof multerInstance.array).toBe('function');
|
||||
expect(typeof multerInstance.fields).toBe('function');
|
||||
});
|
||||
|
||||
it('should use real config merging', async () => {
|
||||
const { getCustomConfig } = require('~/server/services/Config');
|
||||
|
||||
const multerInstance = await createMulterInstance();
|
||||
|
||||
expect(getCustomConfig).toHaveBeenCalled();
|
||||
expect(multerInstance).toBeDefined();
|
||||
});
|
||||
|
||||
it('should create multer instance with expected interface', async () => {
|
||||
const multerInstance = await createMulterInstance();
|
||||
|
||||
expect(multerInstance).toBeDefined();
|
||||
expect(typeof multerInstance.single).toBe('function');
|
||||
expect(typeof multerInstance.array).toBe('function');
|
||||
expect(typeof multerInstance.fields).toBe('function');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Real Crypto Integration', () => {
|
||||
it('should use actual crypto.randomUUID()', (done) => {
|
||||
// Spy on crypto.randomUUID to ensure it's called
|
||||
const uuidSpy = jest.spyOn(crypto, 'randomUUID');
|
||||
|
||||
const cb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
expect(uuidSpy).toHaveBeenCalled();
|
||||
expect(mockReq.file_id).toMatch(
|
||||
/^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i,
|
||||
);
|
||||
|
||||
uuidSpy.mockRestore();
|
||||
done();
|
||||
});
|
||||
|
||||
storage.getFilename(mockReq, mockFile, cb);
|
||||
});
|
||||
|
||||
it('should generate different UUIDs on subsequent calls', (done) => {
|
||||
const uuids = [];
|
||||
let callCount = 0;
|
||||
const totalCalls = 5;
|
||||
|
||||
const cb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
uuids.push(mockReq.file_id);
|
||||
callCount++;
|
||||
|
||||
if (callCount === totalCalls) {
|
||||
// Check that all UUIDs are unique
|
||||
const uniqueUuids = new Set(uuids);
|
||||
expect(uniqueUuids.size).toBe(totalCalls);
|
||||
done();
|
||||
} else {
|
||||
// Reset for next call
|
||||
delete mockReq.file_id;
|
||||
storage.getFilename(mockReq, mockFile, cb);
|
||||
}
|
||||
});
|
||||
|
||||
// Start the chain
|
||||
storage.getFilename(mockReq, mockFile, cb);
|
||||
});
|
||||
|
||||
it('should generate cryptographically secure UUIDs', (done) => {
|
||||
const generatedUuids = new Set();
|
||||
let callCount = 0;
|
||||
const totalCalls = 10;
|
||||
|
||||
const cb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
|
||||
// Verify UUID format and uniqueness
|
||||
expect(mockReq.file_id).toMatch(
|
||||
/^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$/i,
|
||||
);
|
||||
|
||||
generatedUuids.add(mockReq.file_id);
|
||||
callCount++;
|
||||
|
||||
if (callCount === totalCalls) {
|
||||
// All UUIDs should be unique
|
||||
expect(generatedUuids.size).toBe(totalCalls);
|
||||
done();
|
||||
} else {
|
||||
// Reset for next call
|
||||
delete mockReq.file_id;
|
||||
storage.getFilename(mockReq, mockFile, cb);
|
||||
}
|
||||
});
|
||||
|
||||
// Start the chain
|
||||
storage.getFilename(mockReq, mockFile, cb);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should handle CVE-2024-28870: empty field name DoS vulnerability', async () => {
|
||||
// Test for the CVE where empty field name could cause unhandled exception
|
||||
const multerInstance = await createMulterInstance();
|
||||
|
||||
// Create a mock request with empty field name (the vulnerability scenario)
|
||||
const mockReqWithEmptyField = {
|
||||
...mockReq,
|
||||
headers: {
|
||||
'content-type': 'multipart/form-data',
|
||||
},
|
||||
};
|
||||
|
||||
const mockRes = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn(),
|
||||
end: jest.fn(),
|
||||
};
|
||||
|
||||
// This should not crash or throw unhandled exceptions
|
||||
const uploadMiddleware = multerInstance.single(''); // Empty field name
|
||||
|
||||
const mockNext = jest.fn((err) => {
|
||||
// If there's an error, it should be handled gracefully, not crash
|
||||
if (err) {
|
||||
expect(err).toBeInstanceOf(Error);
|
||||
// The error should be handled, not crash the process
|
||||
}
|
||||
});
|
||||
|
||||
// This should complete without crashing the process
|
||||
expect(() => {
|
||||
uploadMiddleware(mockReqWithEmptyField, mockRes, mockNext);
|
||||
}).not.toThrow();
|
||||
});
|
||||
|
||||
it('should handle file system errors when directory creation fails', (done) => {
|
||||
// Test with a non-existent parent directory to simulate fs issues
|
||||
const invalidPath = '/nonexistent/path/that/should/not/exist';
|
||||
mockReq.app.locals.paths.uploads = invalidPath;
|
||||
|
||||
try {
|
||||
// Call getDestination which should fail due to permission/path issues
|
||||
storage.getDestination(mockReq, mockFile, (err, destination) => {
|
||||
// If callback is reached, we didn't get the expected error
|
||||
done(new Error('Expected mkdirSync to throw an error but callback was called'));
|
||||
});
|
||||
// If we get here without throwing, something unexpected happened
|
||||
done(new Error('Expected mkdirSync to throw an error but no error was thrown'));
|
||||
} catch (error) {
|
||||
// This is the expected behavior - mkdirSync throws synchronously for invalid paths
|
||||
expect(error.code).toBe('EACCES');
|
||||
done();
|
||||
}
|
||||
});
|
||||
|
||||
it('should handle malformed filenames with real sanitization', (done) => {
|
||||
const malformedFile = {
|
||||
...mockFile,
|
||||
originalname: null, // This should be handled gracefully
|
||||
};
|
||||
|
||||
const cb = jest.fn((err, filename) => {
|
||||
// The function should handle this gracefully
|
||||
expect(typeof err === 'object' || err === null).toBe(true);
|
||||
done();
|
||||
});
|
||||
|
||||
try {
|
||||
storage.getFilename(mockReq, malformedFile, cb);
|
||||
} catch (error) {
|
||||
// If it throws, that's also acceptable behavior
|
||||
done();
|
||||
}
|
||||
});
|
||||
|
||||
it('should handle edge cases in filename sanitization', (done) => {
|
||||
const edgeCaseFiles = [
|
||||
{ originalname: '', expected: /_/ },
|
||||
{ originalname: '.hidden', expected: /^_\.hidden/ },
|
||||
{ originalname: '../../../etc/passwd', expected: /passwd/ },
|
||||
{ originalname: 'file\x00name.txt', expected: /file_name\.txt/ },
|
||||
];
|
||||
|
||||
let testCount = 0;
|
||||
|
||||
const testNextFile = (fileData) => {
|
||||
const fileToTest = { ...mockFile, originalname: fileData.originalname };
|
||||
|
||||
const cb = jest.fn((err, filename) => {
|
||||
expect(err).toBeNull();
|
||||
expect(filename).toMatch(fileData.expected);
|
||||
|
||||
testCount++;
|
||||
if (testCount === edgeCaseFiles.length) {
|
||||
done();
|
||||
} else {
|
||||
testNextFile(edgeCaseFiles[testCount]);
|
||||
}
|
||||
});
|
||||
|
||||
storage.getFilename(mockReq, fileToTest, cb);
|
||||
};
|
||||
|
||||
testNextFile(edgeCaseFiles[0]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Real Configuration Testing', () => {
|
||||
it('should handle missing custom config gracefully with real mergeFileConfig', async () => {
|
||||
const { getCustomConfig } = require('~/server/services/Config');
|
||||
|
||||
// Mock getCustomConfig to return undefined
|
||||
getCustomConfig.mockResolvedValueOnce(undefined);
|
||||
|
||||
const multerInstance = await createMulterInstance();
|
||||
expect(multerInstance).toBeDefined();
|
||||
expect(typeof multerInstance.single).toBe('function');
|
||||
});
|
||||
|
||||
it('should properly integrate real fileConfig with custom endpoints', async () => {
|
||||
const { getCustomConfig } = require('~/server/services/Config');
|
||||
|
||||
// Mock a custom config with additional endpoints
|
||||
getCustomConfig.mockResolvedValueOnce({
|
||||
fileConfig: {
|
||||
endpoints: {
|
||||
anthropic: {
|
||||
supportedMimeTypes: ['text/plain', 'image/png'],
|
||||
},
|
||||
},
|
||||
serverFileSizeLimit: 20, // 20 MB
|
||||
},
|
||||
});
|
||||
|
||||
const multerInstance = await createMulterInstance();
|
||||
expect(multerInstance).toBeDefined();
|
||||
|
||||
// Verify that getCustomConfig was called (we can't spy on the actual merge function easily)
|
||||
expect(getCustomConfig).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -4,6 +4,7 @@ const tokenizer = require('./tokenizer');
|
|||
const endpoints = require('./endpoints');
|
||||
const staticRoute = require('./static');
|
||||
const messages = require('./messages');
|
||||
const memories = require('./memories');
|
||||
const presets = require('./presets');
|
||||
const prompts = require('./prompts');
|
||||
const balance = require('./balance');
|
||||
|
|
@ -26,6 +27,7 @@ const edit = require('./edit');
|
|||
const keys = require('./keys');
|
||||
const user = require('./user');
|
||||
const ask = require('./ask');
|
||||
const mcp = require('./mcp');
|
||||
|
||||
module.exports = {
|
||||
ask,
|
||||
|
|
@ -51,9 +53,11 @@ module.exports = {
|
|||
presets,
|
||||
balance,
|
||||
messages,
|
||||
memories,
|
||||
endpoints,
|
||||
tokenizer,
|
||||
assistants,
|
||||
categories,
|
||||
staticRoute,
|
||||
mcp,
|
||||
};
|
||||
|
|
|
|||
205
api/server/routes/mcp.js
Normal file
205
api/server/routes/mcp.js
Normal file
|
|
@ -0,0 +1,205 @@
|
|||
const { Router } = require('express');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const { getFlowStateManager } = require('~/config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
const router = Router();
|
||||
|
||||
/**
|
||||
* Initiate OAuth flow
|
||||
* This endpoint is called when the user clicks the auth link in the UI
|
||||
*/
|
||||
router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => {
|
||||
try {
|
||||
const { serverName } = req.params;
|
||||
const { userId, flowId } = req.query;
|
||||
const user = req.user;
|
||||
|
||||
// Verify the userId matches the authenticated user
|
||||
if (userId !== user.id) {
|
||||
return res.status(403).json({ error: 'User mismatch' });
|
||||
}
|
||||
|
||||
logger.debug('[MCP OAuth] Initiate request', { serverName, userId, flowId });
|
||||
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
|
||||
/** Flow state to retrieve OAuth config */
|
||||
const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
if (!flowState) {
|
||||
logger.error('[MCP OAuth] Flow state not found', { flowId });
|
||||
return res.status(404).json({ error: 'Flow not found' });
|
||||
}
|
||||
|
||||
const { serverUrl, oauth: oauthConfig } = flowState.metadata || {};
|
||||
if (!serverUrl || !oauthConfig) {
|
||||
logger.error('[MCP OAuth] Missing server URL or OAuth config in flow state');
|
||||
return res.status(400).json({ error: 'Invalid flow state' });
|
||||
}
|
||||
|
||||
const { authorizationUrl, flowId: oauthFlowId } = await MCPOAuthHandler.initiateOAuthFlow(
|
||||
serverName,
|
||||
serverUrl,
|
||||
userId,
|
||||
oauthConfig,
|
||||
);
|
||||
|
||||
logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl });
|
||||
|
||||
// Redirect user to the authorization URL
|
||||
res.redirect(authorizationUrl);
|
||||
} catch (error) {
|
||||
logger.error('[MCP OAuth] Failed to initiate OAuth', error);
|
||||
res.status(500).json({ error: 'Failed to initiate OAuth' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* OAuth callback handler
|
||||
* This handles the OAuth callback after the user has authorized the application
|
||||
*/
|
||||
router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
try {
|
||||
const { serverName } = req.params;
|
||||
const { code, state, error: oauthError } = req.query;
|
||||
|
||||
logger.debug('[MCP OAuth] Callback received', {
|
||||
serverName,
|
||||
code: code ? 'present' : 'missing',
|
||||
state,
|
||||
error: oauthError,
|
||||
});
|
||||
|
||||
if (oauthError) {
|
||||
logger.error('[MCP OAuth] OAuth error received', { error: oauthError });
|
||||
return res.redirect(`/oauth/error?error=${encodeURIComponent(String(oauthError))}`);
|
||||
}
|
||||
|
||||
if (!code || typeof code !== 'string') {
|
||||
logger.error('[MCP OAuth] Missing or invalid code');
|
||||
return res.redirect('/oauth/error?error=missing_code');
|
||||
}
|
||||
|
||||
if (!state || typeof state !== 'string') {
|
||||
logger.error('[MCP OAuth] Missing or invalid state');
|
||||
return res.redirect('/oauth/error?error=missing_state');
|
||||
}
|
||||
|
||||
// Extract flow ID from state
|
||||
const flowId = state;
|
||||
logger.debug('[MCP OAuth] Using flow ID from state', { flowId });
|
||||
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
|
||||
logger.debug('[MCP OAuth] Getting flow state for flowId: ' + flowId);
|
||||
const flowState = await MCPOAuthHandler.getFlowState(flowId, flowManager);
|
||||
|
||||
if (!flowState) {
|
||||
logger.error('[MCP OAuth] Flow state not found for flowId:', flowId);
|
||||
return res.redirect('/oauth/error?error=invalid_state');
|
||||
}
|
||||
|
||||
logger.debug('[MCP OAuth] Flow state details', {
|
||||
serverName: flowState.serverName,
|
||||
userId: flowState.userId,
|
||||
hasMetadata: !!flowState.metadata,
|
||||
hasClientInfo: !!flowState.clientInfo,
|
||||
hasCodeVerifier: !!flowState.codeVerifier,
|
||||
});
|
||||
|
||||
// Complete the OAuth flow
|
||||
logger.debug('[MCP OAuth] Completing OAuth flow');
|
||||
const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager);
|
||||
logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route');
|
||||
|
||||
// For system-level OAuth, we need to store the tokens and retry the connection
|
||||
if (flowState.userId === 'system') {
|
||||
logger.debug(`[MCP OAuth] System-level OAuth completed for ${serverName}`);
|
||||
}
|
||||
|
||||
/** ID of the flow that the tool/connection is waiting for */
|
||||
const toolFlowId = flowState.metadata?.toolFlowId;
|
||||
if (toolFlowId) {
|
||||
logger.debug('[MCP OAuth] Completing tool flow', { toolFlowId });
|
||||
await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens);
|
||||
}
|
||||
|
||||
/** Redirect to success page with flowId and serverName */
|
||||
const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`;
|
||||
res.redirect(redirectUrl);
|
||||
} catch (error) {
|
||||
logger.error('[MCP OAuth] OAuth callback error', error);
|
||||
res.redirect('/oauth/error?error=callback_failed');
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Get OAuth tokens for a completed flow
|
||||
* This is primarily for user-level OAuth flows
|
||||
*/
|
||||
router.get('/oauth/tokens/:flowId', requireJwtAuth, async (req, res) => {
|
||||
try {
|
||||
const { flowId } = req.params;
|
||||
const user = req.user;
|
||||
|
||||
if (!user?.id) {
|
||||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
|
||||
// Allow system flows or user-owned flows
|
||||
if (!flowId.startsWith(`${user.id}:`) && !flowId.startsWith('system:')) {
|
||||
return res.status(403).json({ error: 'Access denied' });
|
||||
}
|
||||
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
|
||||
const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
if (!flowState) {
|
||||
return res.status(404).json({ error: 'Flow not found' });
|
||||
}
|
||||
|
||||
if (flowState.status !== 'COMPLETED') {
|
||||
return res.status(400).json({ error: 'Flow not completed' });
|
||||
}
|
||||
|
||||
res.json({ tokens: flowState.result });
|
||||
} catch (error) {
|
||||
logger.error('[MCP OAuth] Failed to get tokens', error);
|
||||
res.status(500).json({ error: 'Failed to get tokens' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Check OAuth flow status
|
||||
* This endpoint can be used to poll the status of an OAuth flow
|
||||
*/
|
||||
router.get('/oauth/status/:flowId', async (req, res) => {
|
||||
try {
|
||||
const { flowId } = req.params;
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
|
||||
const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
if (!flowState) {
|
||||
return res.status(404).json({ error: 'Flow not found' });
|
||||
}
|
||||
|
||||
res.json({
|
||||
status: flowState.status,
|
||||
completed: flowState.status === 'COMPLETED',
|
||||
failed: flowState.status === 'FAILED',
|
||||
error: flowState.error,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[MCP OAuth] Failed to get flow status', error);
|
||||
res.status(500).json({ error: 'Failed to get flow status' });
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
231
api/server/routes/memories.js
Normal file
231
api/server/routes/memories.js
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
const express = require('express');
|
||||
const { Tokenizer } = require('@librechat/api');
|
||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const {
|
||||
getAllUserMemories,
|
||||
toggleUserMemories,
|
||||
createMemory,
|
||||
setMemory,
|
||||
deleteMemory,
|
||||
} = require('~/models');
|
||||
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const checkMemoryRead = generateCheckAccess(PermissionTypes.MEMORIES, [
|
||||
Permissions.USE,
|
||||
Permissions.READ,
|
||||
]);
|
||||
const checkMemoryCreate = generateCheckAccess(PermissionTypes.MEMORIES, [
|
||||
Permissions.USE,
|
||||
Permissions.CREATE,
|
||||
]);
|
||||
const checkMemoryUpdate = generateCheckAccess(PermissionTypes.MEMORIES, [
|
||||
Permissions.USE,
|
||||
Permissions.UPDATE,
|
||||
]);
|
||||
const checkMemoryDelete = generateCheckAccess(PermissionTypes.MEMORIES, [
|
||||
Permissions.USE,
|
||||
Permissions.UPDATE,
|
||||
]);
|
||||
const checkMemoryOptOut = generateCheckAccess(PermissionTypes.MEMORIES, [
|
||||
Permissions.USE,
|
||||
Permissions.OPT_OUT,
|
||||
]);
|
||||
|
||||
router.use(requireJwtAuth);
|
||||
|
||||
/**
|
||||
* GET /memories
|
||||
* Returns all memories for the authenticated user, sorted by updated_at (newest first).
|
||||
* Also includes memory usage percentage based on token limit.
|
||||
*/
|
||||
router.get('/', checkMemoryRead, async (req, res) => {
|
||||
try {
|
||||
const memories = await getAllUserMemories(req.user.id);
|
||||
|
||||
const sortedMemories = memories.sort(
|
||||
(a, b) => new Date(b.updated_at).getTime() - new Date(a.updated_at).getTime(),
|
||||
);
|
||||
|
||||
const totalTokens = memories.reduce((sum, memory) => {
|
||||
return sum + (memory.tokenCount || 0);
|
||||
}, 0);
|
||||
|
||||
const memoryConfig = req.app.locals?.memory;
|
||||
const tokenLimit = memoryConfig?.tokenLimit;
|
||||
|
||||
let usagePercentage = null;
|
||||
if (tokenLimit && tokenLimit > 0) {
|
||||
usagePercentage = Math.min(100, Math.round((totalTokens / tokenLimit) * 100));
|
||||
}
|
||||
|
||||
res.json({
|
||||
memories: sortedMemories,
|
||||
totalTokens,
|
||||
tokenLimit: tokenLimit || null,
|
||||
usagePercentage,
|
||||
});
|
||||
} catch (error) {
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* POST /memories
|
||||
* Creates a new memory entry for the authenticated user.
|
||||
* Body: { key: string, value: string }
|
||||
* Returns 201 and { created: true, memory: <createdDoc> } when successful.
|
||||
*/
|
||||
router.post('/', checkMemoryCreate, async (req, res) => {
|
||||
const { key, value } = req.body;
|
||||
|
||||
if (typeof key !== 'string' || key.trim() === '') {
|
||||
return res.status(400).json({ error: 'Key is required and must be a non-empty string.' });
|
||||
}
|
||||
|
||||
if (typeof value !== 'string' || value.trim() === '') {
|
||||
return res.status(400).json({ error: 'Value is required and must be a non-empty string.' });
|
||||
}
|
||||
|
||||
try {
|
||||
const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base');
|
||||
|
||||
const memories = await getAllUserMemories(req.user.id);
|
||||
|
||||
// Check token limit
|
||||
const memoryConfig = req.app.locals?.memory;
|
||||
const tokenLimit = memoryConfig?.tokenLimit;
|
||||
|
||||
if (tokenLimit) {
|
||||
const currentTotalTokens = memories.reduce(
|
||||
(sum, memory) => sum + (memory.tokenCount || 0),
|
||||
0,
|
||||
);
|
||||
if (currentTotalTokens + tokenCount > tokenLimit) {
|
||||
return res.status(400).json({
|
||||
error: `Adding this memory would exceed the token limit of ${tokenLimit}. Current usage: ${currentTotalTokens} tokens.`,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const result = await createMemory({
|
||||
userId: req.user.id,
|
||||
key: key.trim(),
|
||||
value: value.trim(),
|
||||
tokenCount,
|
||||
});
|
||||
|
||||
if (!result.ok) {
|
||||
return res.status(500).json({ error: 'Failed to create memory.' });
|
||||
}
|
||||
|
||||
const updatedMemories = await getAllUserMemories(req.user.id);
|
||||
const newMemory = updatedMemories.find((m) => m.key === key.trim());
|
||||
|
||||
res.status(201).json({ created: true, memory: newMemory });
|
||||
} catch (error) {
|
||||
if (error.message && error.message.includes('already exists')) {
|
||||
return res.status(409).json({ error: 'Memory with this key already exists.' });
|
||||
}
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* PATCH /memories/preferences
|
||||
* Updates the user's memory preferences (e.g., enabling/disabling memories).
|
||||
* Body: { memories: boolean }
|
||||
* Returns 200 and { updated: true, preferences: { memories: boolean } } when successful.
|
||||
*/
|
||||
router.patch('/preferences', checkMemoryOptOut, async (req, res) => {
|
||||
const { memories } = req.body;
|
||||
|
||||
if (typeof memories !== 'boolean') {
|
||||
return res.status(400).json({ error: 'memories must be a boolean value.' });
|
||||
}
|
||||
|
||||
try {
|
||||
const updatedUser = await toggleUserMemories(req.user.id, memories);
|
||||
|
||||
if (!updatedUser) {
|
||||
return res.status(404).json({ error: 'User not found.' });
|
||||
}
|
||||
|
||||
res.json({
|
||||
updated: true,
|
||||
preferences: {
|
||||
memories: updatedUser.personalization?.memories ?? true,
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* PATCH /memories/:key
|
||||
* Updates the value of an existing memory entry for the authenticated user.
|
||||
* Body: { value: string }
|
||||
* Returns 200 and { updated: true, memory: <updatedDoc> } when successful.
|
||||
*/
|
||||
router.patch('/:key', checkMemoryUpdate, async (req, res) => {
|
||||
const { key } = req.params;
|
||||
const { value } = req.body || {};
|
||||
|
||||
if (typeof value !== 'string' || value.trim() === '') {
|
||||
return res.status(400).json({ error: 'Value is required and must be a non-empty string.' });
|
||||
}
|
||||
|
||||
try {
|
||||
const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base');
|
||||
|
||||
const memories = await getAllUserMemories(req.user.id);
|
||||
const existingMemory = memories.find((m) => m.key === key);
|
||||
|
||||
if (!existingMemory) {
|
||||
return res.status(404).json({ error: 'Memory not found.' });
|
||||
}
|
||||
|
||||
const result = await setMemory({
|
||||
userId: req.user.id,
|
||||
key,
|
||||
value,
|
||||
tokenCount,
|
||||
});
|
||||
|
||||
if (!result.ok) {
|
||||
return res.status(500).json({ error: 'Failed to update memory.' });
|
||||
}
|
||||
|
||||
const updatedMemories = await getAllUserMemories(req.user.id);
|
||||
const updatedMemory = updatedMemories.find((m) => m.key === key);
|
||||
|
||||
res.json({ updated: true, memory: updatedMemory });
|
||||
} catch (error) {
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* DELETE /memories/:key
|
||||
* Deletes a memory entry for the authenticated user.
|
||||
* Returns 200 and { deleted: true } when successful.
|
||||
*/
|
||||
router.delete('/:key', checkMemoryDelete, async (req, res) => {
|
||||
const { key } = req.params;
|
||||
|
||||
try {
|
||||
const result = await deleteMemory({ userId: req.user.id, key });
|
||||
|
||||
if (!result.ok) {
|
||||
return res.status(404).json({ error: 'Memory not found.' });
|
||||
}
|
||||
|
||||
res.json({ deleted: true });
|
||||
} catch (error) {
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
|
|
@ -47,7 +47,9 @@ const oauthHandler = async (req, res) => {
|
|||
|
||||
router.get('/error', (req, res) => {
|
||||
// A single error message is pushed by passport when authentication fails.
|
||||
logger.error('Error in OAuth authentication:', { message: req.session.messages.pop() });
|
||||
logger.error('Error in OAuth authentication:', {
|
||||
message: req.session?.messages?.pop() || 'Unknown error',
|
||||
});
|
||||
|
||||
// Redirect to login page with auth_failed parameter to prevent infinite redirect loops
|
||||
res.redirect(`${domains.client}/login?redirect=false`);
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
const express = require('express');
|
||||
const {
|
||||
promptPermissionsSchema,
|
||||
memoryPermissionsSchema,
|
||||
agentPermissionsSchema,
|
||||
PermissionTypes,
|
||||
roleDefaults,
|
||||
|
|
@ -118,4 +119,43 @@ router.put('/:roleName/agents', checkAdmin, async (req, res) => {
|
|||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* PUT /api/roles/:roleName/memories
|
||||
* Update memory permissions for a specific role
|
||||
*/
|
||||
router.put('/:roleName/memories', checkAdmin, async (req, res) => {
|
||||
const { roleName: _r } = req.params;
|
||||
// TODO: TEMP, use a better parsing for roleName
|
||||
const roleName = _r.toUpperCase();
|
||||
/** @type {TRole['permissions']['MEMORIES']} */
|
||||
const updates = req.body;
|
||||
|
||||
try {
|
||||
const parsedUpdates = memoryPermissionsSchema.partial().parse(updates);
|
||||
|
||||
const role = await getRoleByName(roleName);
|
||||
if (!role) {
|
||||
return res.status(404).send({ message: 'Role not found' });
|
||||
}
|
||||
|
||||
const currentPermissions =
|
||||
role.permissions?.[PermissionTypes.MEMORIES] || role[PermissionTypes.MEMORIES] || {};
|
||||
|
||||
const mergedUpdates = {
|
||||
permissions: {
|
||||
...role.permissions,
|
||||
[PermissionTypes.MEMORIES]: {
|
||||
...currentPermissions,
|
||||
...parsedUpdates,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const updatedRole = await updateRoleByName(roleName, mergedUpdates);
|
||||
res.status(200).send(updatedRole);
|
||||
} catch (error) {
|
||||
return res.status(400).send({ message: 'Invalid memory permissions.', error: error.errors });
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
const express = require('express');
|
||||
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
getSharedLink,
|
||||
getSharedMessages,
|
||||
createSharedLink,
|
||||
updateSharedLink,
|
||||
getSharedLinks,
|
||||
deleteSharedLink,
|
||||
} = require('~/models/Share');
|
||||
getSharedLinks,
|
||||
getSharedLink,
|
||||
} = require('~/models');
|
||||
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const router = express.Router();
|
||||
|
||||
/**
|
||||
|
|
@ -35,6 +35,7 @@ if (allowSharedLinks) {
|
|||
res.status(404).end();
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error getting shared messages:', error);
|
||||
res.status(500).json({ message: 'Error getting shared messages' });
|
||||
}
|
||||
},
|
||||
|
|
@ -54,9 +55,7 @@ router.get('/', requireJwtAuth, async (req, res) => {
|
|||
sortDirection: ['asc', 'desc'].includes(req.query.sortDirection)
|
||||
? req.query.sortDirection
|
||||
: 'desc',
|
||||
search: req.query.search
|
||||
? decodeURIComponent(req.query.search.trim())
|
||||
: undefined,
|
||||
search: req.query.search ? decodeURIComponent(req.query.search.trim()) : undefined,
|
||||
};
|
||||
|
||||
const result = await getSharedLinks(
|
||||
|
|
@ -75,7 +74,7 @@ router.get('/', requireJwtAuth, async (req, res) => {
|
|||
hasNextPage: result.hasNextPage,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Error getting shared links:', error);
|
||||
logger.error('Error getting shared links:', error);
|
||||
res.status(500).json({
|
||||
message: 'Error getting shared links',
|
||||
error: error.message,
|
||||
|
|
@ -93,6 +92,7 @@ router.get('/link/:conversationId', requireJwtAuth, async (req, res) => {
|
|||
conversationId: req.params.conversationId,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error getting shared link:', error);
|
||||
res.status(500).json({ message: 'Error getting shared link' });
|
||||
}
|
||||
});
|
||||
|
|
@ -106,6 +106,7 @@ router.post('/:conversationId', requireJwtAuth, async (req, res) => {
|
|||
res.status(404).end();
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error creating shared link:', error);
|
||||
res.status(500).json({ message: 'Error creating shared link' });
|
||||
}
|
||||
});
|
||||
|
|
@ -119,6 +120,7 @@ router.patch('/:shareId', requireJwtAuth, async (req, res) => {
|
|||
res.status(404).end();
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error updating shared link:', error);
|
||||
res.status(500).json({ message: 'Error updating shared link' });
|
||||
}
|
||||
});
|
||||
|
|
@ -133,7 +135,8 @@ router.delete('/:shareId', requireJwtAuth, async (req, res) => {
|
|||
|
||||
return res.status(200).json(result);
|
||||
} catch (error) {
|
||||
return res.status(400).json({ message: error.message });
|
||||
logger.error('Error deleting shared link:', error);
|
||||
return res.status(400).json({ message: 'Error deleting shared link' });
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,15 @@
|
|||
const jwt = require('jsonwebtoken');
|
||||
const { nanoid } = require('nanoid');
|
||||
const { tool } = require('@langchain/core/tools');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { GraphEvents, sleep } = require('@librechat/agents');
|
||||
const {
|
||||
sendEvent,
|
||||
encryptV2,
|
||||
decryptV2,
|
||||
logAxiosError,
|
||||
refreshAccessToken,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
CacheKeys,
|
||||
|
|
@ -12,14 +20,11 @@ const {
|
|||
isImageVisionTool,
|
||||
actionDomainSeparator,
|
||||
} = require('librechat-data-provider');
|
||||
const { refreshAccessToken } = require('~/server/services/TokenService');
|
||||
const { logger, getFlowStateManager, sendEvent } = require('~/config');
|
||||
const { encryptV2, decryptV2 } = require('~/server/utils/crypto');
|
||||
const { findToken, updateToken, createToken } = require('~/models');
|
||||
const { getActions, deleteActions } = require('~/models/Action');
|
||||
const { deleteAssistant } = require('~/models/Assistant');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
const { getFlowStateManager } = require('~/config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { findToken } = require('~/models');
|
||||
|
||||
const JWT_SECRET = process.env.JWT_SECRET;
|
||||
const toolNameRegex = /^[a-zA-Z0-9_-]+$/;
|
||||
|
|
@ -208,6 +213,7 @@ async function createActionTool({
|
|||
userId: userId,
|
||||
client_url: metadata.auth.client_url,
|
||||
redirect_uri: `${process.env.DOMAIN_SERVER}/api/actions/${action_id}/oauth/callback`,
|
||||
token_exchange_method: metadata.auth.token_exchange_method,
|
||||
/** Encrypted values */
|
||||
encrypted_oauth_client_id: encrypted.oauth_client_id,
|
||||
encrypted_oauth_client_secret: encrypted.oauth_client_secret,
|
||||
|
|
@ -256,14 +262,22 @@ async function createActionTool({
|
|||
try {
|
||||
const refresh_token = await decryptV2(refreshTokenData.token);
|
||||
const refreshTokens = async () =>
|
||||
await refreshAccessToken({
|
||||
userId,
|
||||
identifier,
|
||||
refresh_token,
|
||||
client_url: metadata.auth.client_url,
|
||||
encrypted_oauth_client_id: encrypted.oauth_client_id,
|
||||
encrypted_oauth_client_secret: encrypted.oauth_client_secret,
|
||||
});
|
||||
await refreshAccessToken(
|
||||
{
|
||||
userId,
|
||||
identifier,
|
||||
refresh_token,
|
||||
client_url: metadata.auth.client_url,
|
||||
encrypted_oauth_client_id: encrypted.oauth_client_id,
|
||||
token_exchange_method: metadata.auth.token_exchange_method,
|
||||
encrypted_oauth_client_secret: encrypted.oauth_client_secret,
|
||||
},
|
||||
{
|
||||
findToken,
|
||||
updateToken,
|
||||
createToken,
|
||||
},
|
||||
);
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
const refreshData = await flowManager.createFlowWithHandler(
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
const {
|
||||
FileSources,
|
||||
loadOCRConfig,
|
||||
processMCPEnv,
|
||||
EModelEndpoint,
|
||||
loadMemoryConfig,
|
||||
getConfigDefaults,
|
||||
loadWebSearchConfig,
|
||||
} = require('librechat-data-provider');
|
||||
const { agentsConfigSetup } = require('@librechat/api');
|
||||
const {
|
||||
checkHealth,
|
||||
checkConfig,
|
||||
|
|
@ -24,10 +25,9 @@ const { azureConfigSetup } = require('./start/azureOpenAI');
|
|||
const { processModelSpecs } = require('./start/modelSpecs');
|
||||
const { initializeS3 } = require('./Files/S3/initialize');
|
||||
const { loadAndFormatTools } = require('./ToolService');
|
||||
const { agentsConfigSetup } = require('./start/agents');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { initializeRoles } = require('~/models');
|
||||
const { getMCPManager } = require('~/config');
|
||||
const { setCachedTools } = require('./Config');
|
||||
const paths = require('~/config/paths');
|
||||
|
||||
/**
|
||||
|
|
@ -44,6 +44,7 @@ const AppService = async (app) => {
|
|||
const ocr = loadOCRConfig(config.ocr);
|
||||
const webSearch = loadWebSearchConfig(config.webSearch);
|
||||
checkWebSearchConfig(webSearch);
|
||||
const memory = loadMemoryConfig(config.memory);
|
||||
const filteredTools = config.filteredTools;
|
||||
const includedTools = config.includedTools;
|
||||
const fileStrategy = config.fileStrategy ?? configDefaults.fileStrategy;
|
||||
|
|
@ -74,11 +75,10 @@ const AppService = async (app) => {
|
|||
directory: paths.structuredTools,
|
||||
});
|
||||
|
||||
if (config.mcpServers != null) {
|
||||
const mcpManager = getMCPManager();
|
||||
await mcpManager.initializeMCP(config.mcpServers, processMCPEnv);
|
||||
await mcpManager.mapAvailableTools(availableTools);
|
||||
}
|
||||
await setCachedTools(availableTools, { isGlobal: true });
|
||||
|
||||
// Store MCP config for later initialization
|
||||
const mcpConfig = config.mcpServers || null;
|
||||
|
||||
const socialLogins =
|
||||
config?.registration?.socialLogins ?? configDefaults?.registration?.socialLogins;
|
||||
|
|
@ -88,20 +88,26 @@ const AppService = async (app) => {
|
|||
const defaultLocals = {
|
||||
ocr,
|
||||
paths,
|
||||
memory,
|
||||
webSearch,
|
||||
fileStrategy,
|
||||
socialLogins,
|
||||
filteredTools,
|
||||
includedTools,
|
||||
availableTools,
|
||||
imageOutputType,
|
||||
interfaceConfig,
|
||||
turnstileConfig,
|
||||
balance,
|
||||
mcpConfig,
|
||||
};
|
||||
|
||||
const agentsDefaults = agentsConfigSetup(config);
|
||||
|
||||
if (!Object.keys(config).length) {
|
||||
app.locals = defaultLocals;
|
||||
app.locals = {
|
||||
...defaultLocals,
|
||||
[EModelEndpoint.agents]: agentsDefaults,
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -136,9 +142,7 @@ const AppService = async (app) => {
|
|||
);
|
||||
}
|
||||
|
||||
if (endpoints?.[EModelEndpoint.agents]) {
|
||||
endpointLocals[EModelEndpoint.agents] = agentsConfigSetup(config);
|
||||
}
|
||||
endpointLocals[EModelEndpoint.agents] = agentsConfigSetup(config, agentsDefaults);
|
||||
|
||||
const endpointKeys = [
|
||||
EModelEndpoint.openAI,
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@ const {
|
|||
FileSources,
|
||||
EModelEndpoint,
|
||||
EImageOutputType,
|
||||
AgentCapabilities,
|
||||
defaultSocialLogins,
|
||||
validateAzureGroups,
|
||||
defaultAgentCapabilities,
|
||||
deprecatedAzureVariables,
|
||||
conflictingAzureVariables,
|
||||
} = require('librechat-data-provider');
|
||||
|
|
@ -30,6 +32,25 @@ jest.mock('~/models', () => ({
|
|||
jest.mock('~/models/Role', () => ({
|
||||
updateAccessPermissions: jest.fn(),
|
||||
}));
|
||||
jest.mock('./Config', () => ({
|
||||
setCachedTools: jest.fn(),
|
||||
getCachedTools: jest.fn().mockResolvedValue({
|
||||
ExampleTool: {
|
||||
type: 'function',
|
||||
function: {
|
||||
description: 'Example tool function',
|
||||
name: 'exampleFunction',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
param1: { type: 'string', description: 'An example parameter' },
|
||||
},
|
||||
required: ['param1'],
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
}));
|
||||
jest.mock('./ToolService', () => ({
|
||||
loadAndFormatTools: jest.fn().mockReturnValue({
|
||||
ExampleTool: {
|
||||
|
|
@ -119,22 +140,9 @@ describe('AppService', () => {
|
|||
sidePanel: true,
|
||||
presets: true,
|
||||
}),
|
||||
mcpConfig: null,
|
||||
turnstileConfig: mockedTurnstileConfig,
|
||||
modelSpecs: undefined,
|
||||
availableTools: {
|
||||
ExampleTool: {
|
||||
type: 'function',
|
||||
function: expect.objectContaining({
|
||||
description: 'Example tool function',
|
||||
name: 'exampleFunction',
|
||||
parameters: expect.objectContaining({
|
||||
type: 'object',
|
||||
properties: expect.any(Object),
|
||||
required: expect.arrayContaining(['param1']),
|
||||
}),
|
||||
}),
|
||||
},
|
||||
},
|
||||
paths: expect.anything(),
|
||||
ocr: expect.anything(),
|
||||
imageOutputType: expect.any(String),
|
||||
|
|
@ -151,6 +159,11 @@ describe('AppService', () => {
|
|||
safeSearch: 1,
|
||||
serperApiKey: '${SERPER_API_KEY}',
|
||||
},
|
||||
memory: undefined,
|
||||
agents: {
|
||||
disableBuilder: false,
|
||||
capabilities: expect.arrayContaining([...defaultAgentCapabilities]),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -216,14 +229,41 @@ describe('AppService', () => {
|
|||
|
||||
it('should load and format tools accurately with defined structure', async () => {
|
||||
const { loadAndFormatTools } = require('./ToolService');
|
||||
const { setCachedTools, getCachedTools } = require('./Config');
|
||||
|
||||
await AppService(app);
|
||||
|
||||
expect(loadAndFormatTools).toHaveBeenCalledWith({
|
||||
adminFilter: undefined,
|
||||
adminIncluded: undefined,
|
||||
directory: expect.anything(),
|
||||
});
|
||||
|
||||
expect(app.locals.availableTools.ExampleTool).toBeDefined();
|
||||
expect(app.locals.availableTools.ExampleTool).toEqual({
|
||||
// Verify setCachedTools was called with the tools
|
||||
expect(setCachedTools).toHaveBeenCalledWith(
|
||||
{
|
||||
ExampleTool: {
|
||||
type: 'function',
|
||||
function: {
|
||||
description: 'Example tool function',
|
||||
name: 'exampleFunction',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
param1: { type: 'string', description: 'An example parameter' },
|
||||
},
|
||||
required: ['param1'],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{ isGlobal: true },
|
||||
);
|
||||
|
||||
// Verify we can retrieve the tools from cache
|
||||
const cachedTools = await getCachedTools({ includeGlobal: true });
|
||||
expect(cachedTools.ExampleTool).toBeDefined();
|
||||
expect(cachedTools.ExampleTool).toEqual({
|
||||
type: 'function',
|
||||
function: {
|
||||
description: 'Example tool function',
|
||||
|
|
@ -268,6 +308,71 @@ describe('AppService', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('should correctly configure Agents endpoint based on custom config', async () => {
|
||||
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
endpoints: {
|
||||
[EModelEndpoint.agents]: {
|
||||
disableBuilder: true,
|
||||
recursionLimit: 10,
|
||||
maxRecursionLimit: 20,
|
||||
allowedProviders: ['openai', 'anthropic'],
|
||||
capabilities: [AgentCapabilities.tools, AgentCapabilities.actions],
|
||||
},
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
await AppService(app);
|
||||
|
||||
expect(app.locals).toHaveProperty(EModelEndpoint.agents);
|
||||
expect(app.locals[EModelEndpoint.agents]).toEqual(
|
||||
expect.objectContaining({
|
||||
disableBuilder: true,
|
||||
recursionLimit: 10,
|
||||
maxRecursionLimit: 20,
|
||||
allowedProviders: expect.arrayContaining(['openai', 'anthropic']),
|
||||
capabilities: expect.arrayContaining([AgentCapabilities.tools, AgentCapabilities.actions]),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should configure Agents endpoint with defaults when no config is provided', async () => {
|
||||
require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({}));
|
||||
|
||||
await AppService(app);
|
||||
|
||||
expect(app.locals).toHaveProperty(EModelEndpoint.agents);
|
||||
expect(app.locals[EModelEndpoint.agents]).toEqual(
|
||||
expect.objectContaining({
|
||||
disableBuilder: false,
|
||||
capabilities: expect.arrayContaining([...defaultAgentCapabilities]),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should configure Agents endpoint with defaults when endpoints exist but agents is not defined', async () => {
|
||||
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
endpoints: {
|
||||
[EModelEndpoint.openAI]: {
|
||||
titleConvo: true,
|
||||
},
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
await AppService(app);
|
||||
|
||||
expect(app.locals).toHaveProperty(EModelEndpoint.agents);
|
||||
expect(app.locals[EModelEndpoint.agents]).toEqual(
|
||||
expect.objectContaining({
|
||||
disableBuilder: false,
|
||||
capabilities: expect.arrayContaining([...defaultAgentCapabilities]),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should correctly configure minimum Azure OpenAI Assistant values', async () => {
|
||||
const assistantGroups = [azureGroups[0], { ...azureGroups[1], assistants: true }];
|
||||
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
|
||||
|
|
@ -463,7 +568,6 @@ describe('AppService updating app.locals and issuing warnings', () => {
|
|||
|
||||
expect(app.locals).toBeDefined();
|
||||
expect(app.locals.paths).toBeDefined();
|
||||
expect(app.locals.availableTools).toBeDefined();
|
||||
expect(app.locals.fileStrategy).toEqual(FileSources.local);
|
||||
expect(app.locals.socialLogins).toEqual(defaultSocialLogins);
|
||||
expect(app.locals.balance).toEqual(
|
||||
|
|
@ -496,7 +600,6 @@ describe('AppService updating app.locals and issuing warnings', () => {
|
|||
|
||||
expect(app.locals).toBeDefined();
|
||||
expect(app.locals.paths).toBeDefined();
|
||||
expect(app.locals.availableTools).toBeDefined();
|
||||
expect(app.locals.fileStrategy).toEqual(customConfig.fileStrategy);
|
||||
expect(app.locals.socialLogins).toEqual(customConfig.registration.socialLogins);
|
||||
expect(app.locals.balance).toEqual(customConfig.balance);
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
const bcrypt = require('bcryptjs');
|
||||
const { webcrypto } = require('node:crypto');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { SystemRoles, errorsToString } = require('librechat-data-provider');
|
||||
const {
|
||||
findUser,
|
||||
|
|
@ -17,11 +19,10 @@ const {
|
|||
deleteUserById,
|
||||
generateRefreshToken,
|
||||
} = require('~/models');
|
||||
const { isEnabled, checkEmailConfig, sendEmail } = require('~/server/utils');
|
||||
const { isEmailDomainAllowed } = require('~/server/services/domains');
|
||||
const { checkEmailConfig, sendEmail } = require('~/server/utils');
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
const { registerSchema } = require('~/strategies/validators');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const domains = {
|
||||
client: process.env.DOMAIN_CLIENT,
|
||||
|
|
@ -409,7 +410,9 @@ const setOpenIDAuthTokens = (tokenset, res) => {
|
|||
return;
|
||||
}
|
||||
const { REFRESH_TOKEN_EXPIRY } = process.env ?? {};
|
||||
const expiryInMilliseconds = eval(REFRESH_TOKEN_EXPIRY) ?? 1000 * 60 * 60 * 24 * 7; // 7 days default
|
||||
const expiryInMilliseconds = REFRESH_TOKEN_EXPIRY
|
||||
? eval(REFRESH_TOKEN_EXPIRY)
|
||||
: 1000 * 60 * 60 * 24 * 7; // 7 days default
|
||||
const expirationDate = new Date(Date.now() + expiryInMilliseconds);
|
||||
if (tokenset == null) {
|
||||
logger.error('[setOpenIDAuthTokens] No tokenset found in request');
|
||||
|
|
|
|||
258
api/server/services/Config/getCachedTools.js
Normal file
258
api/server/services/Config/getCachedTools.js
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
/**
|
||||
* Cache key generators for different tool access patterns
|
||||
* These will support future permission-based caching
|
||||
*/
|
||||
const ToolCacheKeys = {
|
||||
/** Global tools available to all users */
|
||||
GLOBAL: 'tools:global',
|
||||
/** Tools available to a specific user */
|
||||
USER: (userId) => `tools:user:${userId}`,
|
||||
/** Tools available to a specific role */
|
||||
ROLE: (roleId) => `tools:role:${roleId}`,
|
||||
/** Tools available to a specific group */
|
||||
GROUP: (groupId) => `tools:group:${groupId}`,
|
||||
/** Combined effective tools for a user (computed from all sources) */
|
||||
EFFECTIVE: (userId) => `tools:effective:${userId}`,
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves available tools from cache
|
||||
* @function getCachedTools
|
||||
* @param {Object} options - Options for retrieving tools
|
||||
* @param {string} [options.userId] - User ID for user-specific tools
|
||||
* @param {string[]} [options.roleIds] - Role IDs for role-based tools
|
||||
* @param {string[]} [options.groupIds] - Group IDs for group-based tools
|
||||
* @param {boolean} [options.includeGlobal=true] - Whether to include global tools
|
||||
* @returns {Promise<Object|null>} The available tools object or null if not cached
|
||||
*/
|
||||
async function getCachedTools(options = {}) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const { userId, roleIds = [], groupIds = [], includeGlobal = true } = options;
|
||||
|
||||
// For now, return global tools (current behavior)
|
||||
// This will be expanded to merge tools from different sources
|
||||
if (!userId && includeGlobal) {
|
||||
return await cache.get(ToolCacheKeys.GLOBAL);
|
||||
}
|
||||
|
||||
// Future implementation will merge tools from multiple sources
|
||||
// based on user permissions, roles, and groups
|
||||
if (userId) {
|
||||
// Check if we have pre-computed effective tools for this user
|
||||
const effectiveTools = await cache.get(ToolCacheKeys.EFFECTIVE(userId));
|
||||
if (effectiveTools) {
|
||||
return effectiveTools;
|
||||
}
|
||||
|
||||
// Otherwise, compute from individual sources
|
||||
const toolSources = [];
|
||||
|
||||
if (includeGlobal) {
|
||||
const globalTools = await cache.get(ToolCacheKeys.GLOBAL);
|
||||
if (globalTools) {
|
||||
toolSources.push(globalTools);
|
||||
}
|
||||
}
|
||||
|
||||
// User-specific tools
|
||||
const userTools = await cache.get(ToolCacheKeys.USER(userId));
|
||||
if (userTools) {
|
||||
toolSources.push(userTools);
|
||||
}
|
||||
|
||||
// Role-based tools
|
||||
for (const roleId of roleIds) {
|
||||
const roleTools = await cache.get(ToolCacheKeys.ROLE(roleId));
|
||||
if (roleTools) {
|
||||
toolSources.push(roleTools);
|
||||
}
|
||||
}
|
||||
|
||||
// Group-based tools
|
||||
for (const groupId of groupIds) {
|
||||
const groupTools = await cache.get(ToolCacheKeys.GROUP(groupId));
|
||||
if (groupTools) {
|
||||
toolSources.push(groupTools);
|
||||
}
|
||||
}
|
||||
|
||||
// Merge all tool sources (for now, simple merge - future will handle conflicts)
|
||||
if (toolSources.length > 0) {
|
||||
return mergeToolSources(toolSources);
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets available tools in cache
|
||||
* @function setCachedTools
|
||||
* @param {Object} tools - The tools object to cache
|
||||
* @param {Object} options - Options for caching tools
|
||||
* @param {string} [options.userId] - User ID for user-specific tools
|
||||
* @param {string} [options.roleId] - Role ID for role-based tools
|
||||
* @param {string} [options.groupId] - Group ID for group-based tools
|
||||
* @param {boolean} [options.isGlobal=false] - Whether these are global tools
|
||||
* @param {number} [options.ttl] - Time to live in milliseconds
|
||||
* @returns {Promise<boolean>} Whether the operation was successful
|
||||
*/
|
||||
async function setCachedTools(tools, options = {}) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const { userId, roleId, groupId, isGlobal = false, ttl } = options;
|
||||
|
||||
let cacheKey;
|
||||
if (isGlobal || (!userId && !roleId && !groupId)) {
|
||||
cacheKey = ToolCacheKeys.GLOBAL;
|
||||
} else if (userId) {
|
||||
cacheKey = ToolCacheKeys.USER(userId);
|
||||
} else if (roleId) {
|
||||
cacheKey = ToolCacheKeys.ROLE(roleId);
|
||||
} else if (groupId) {
|
||||
cacheKey = ToolCacheKeys.GROUP(groupId);
|
||||
}
|
||||
|
||||
if (!cacheKey) {
|
||||
throw new Error('Invalid cache key options provided');
|
||||
}
|
||||
|
||||
return await cache.set(cacheKey, tools, ttl);
|
||||
}
|
||||
|
||||
/**
|
||||
* Invalidates cached tools
|
||||
* @function invalidateCachedTools
|
||||
* @param {Object} options - Options for invalidating tools
|
||||
* @param {string} [options.userId] - User ID to invalidate
|
||||
* @param {string} [options.roleId] - Role ID to invalidate
|
||||
* @param {string} [options.groupId] - Group ID to invalidate
|
||||
* @param {boolean} [options.invalidateGlobal=false] - Whether to invalidate global tools
|
||||
* @param {boolean} [options.invalidateEffective=true] - Whether to invalidate effective tools
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function invalidateCachedTools(options = {}) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const { userId, roleId, groupId, invalidateGlobal = false, invalidateEffective = true } = options;
|
||||
|
||||
const keysToDelete = [];
|
||||
|
||||
if (invalidateGlobal) {
|
||||
keysToDelete.push(ToolCacheKeys.GLOBAL);
|
||||
}
|
||||
|
||||
if (userId) {
|
||||
keysToDelete.push(ToolCacheKeys.USER(userId));
|
||||
if (invalidateEffective) {
|
||||
keysToDelete.push(ToolCacheKeys.EFFECTIVE(userId));
|
||||
}
|
||||
}
|
||||
|
||||
if (roleId) {
|
||||
keysToDelete.push(ToolCacheKeys.ROLE(roleId));
|
||||
// TODO: In future, invalidate all users with this role
|
||||
}
|
||||
|
||||
if (groupId) {
|
||||
keysToDelete.push(ToolCacheKeys.GROUP(groupId));
|
||||
// TODO: In future, invalidate all users in this group
|
||||
}
|
||||
|
||||
await Promise.all(keysToDelete.map((key) => cache.delete(key)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes and caches effective tools for a user
|
||||
* @function computeEffectiveTools
|
||||
* @param {string} userId - The user ID
|
||||
* @param {Object} context - Context containing user's roles and groups
|
||||
* @param {string[]} [context.roleIds=[]] - User's role IDs
|
||||
* @param {string[]} [context.groupIds=[]] - User's group IDs
|
||||
* @param {number} [ttl] - Time to live for the computed result
|
||||
* @returns {Promise<Object>} The computed effective tools
|
||||
*/
|
||||
async function computeEffectiveTools(userId, context = {}, ttl) {
|
||||
const { roleIds = [], groupIds = [] } = context;
|
||||
|
||||
// Get all tool sources
|
||||
const tools = await getCachedTools({
|
||||
userId,
|
||||
roleIds,
|
||||
groupIds,
|
||||
includeGlobal: true,
|
||||
});
|
||||
|
||||
if (tools) {
|
||||
// Cache the computed result
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
await cache.set(ToolCacheKeys.EFFECTIVE(userId), tools, ttl);
|
||||
}
|
||||
|
||||
return tools;
|
||||
}
|
||||
|
||||
/**
|
||||
* Merges multiple tool sources into a single tools object
|
||||
* @function mergeToolSources
|
||||
* @param {Object[]} sources - Array of tool objects to merge
|
||||
* @returns {Object} Merged tools object
|
||||
*/
|
||||
function mergeToolSources(sources) {
|
||||
// For now, simple merge that combines all tools
|
||||
// Future implementation will handle:
|
||||
// - Permission precedence (deny > allow)
|
||||
// - Tool property conflicts
|
||||
// - Metadata merging
|
||||
const merged = {};
|
||||
|
||||
for (const source of sources) {
|
||||
if (!source || typeof source !== 'object') {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const [toolId, toolConfig] of Object.entries(source)) {
|
||||
// Simple last-write-wins for now
|
||||
// Future: merge based on permission levels
|
||||
merged[toolId] = toolConfig;
|
||||
}
|
||||
}
|
||||
|
||||
return merged;
|
||||
}
|
||||
|
||||
/**
|
||||
* Middleware-friendly function to get tools for a request
|
||||
* @function getToolsForRequest
|
||||
* @param {Object} req - Express request object
|
||||
* @returns {Promise<Object|null>} Available tools for the request
|
||||
*/
|
||||
async function getToolsForRequest(req) {
|
||||
const userId = req.user?.id;
|
||||
|
||||
// For now, return global tools if no user
|
||||
if (!userId) {
|
||||
return getCachedTools({ includeGlobal: true });
|
||||
}
|
||||
|
||||
// Future: Extract roles and groups from req.user
|
||||
const roleIds = req.user?.roles || [];
|
||||
const groupIds = req.user?.groups || [];
|
||||
|
||||
return getCachedTools({
|
||||
userId,
|
||||
roleIds,
|
||||
groupIds,
|
||||
includeGlobal: true,
|
||||
});
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
ToolCacheKeys,
|
||||
getCachedTools,
|
||||
setCachedTools,
|
||||
getToolsForRequest,
|
||||
invalidateCachedTools,
|
||||
computeEffectiveTools,
|
||||
};
|
||||
|
|
@ -1,6 +1,10 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getUserMCPAuthMap } = require('@librechat/api');
|
||||
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
|
||||
const { normalizeEndpointName, isEnabled } = require('~/server/utils');
|
||||
const loadCustomConfig = require('./loadCustomConfig');
|
||||
const { getCachedTools } = require('./getCachedTools');
|
||||
const { findPluginAuthsByKeys } = require('~/models');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
/**
|
||||
|
|
@ -50,4 +54,46 @@ const getCustomEndpointConfig = async (endpoint) => {
|
|||
);
|
||||
};
|
||||
|
||||
module.exports = { getCustomConfig, getBalanceConfig, getCustomEndpointConfig };
|
||||
async function createGetMCPAuthMap() {
|
||||
const customConfig = await getCustomConfig();
|
||||
const mcpServers = customConfig?.mcpServers;
|
||||
const hasCustomUserVars = Object.values(mcpServers ?? {}).some((server) => server.customUserVars);
|
||||
if (!hasCustomUserVars) {
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {Object} params
|
||||
* @param {GenericTool[]} [params.tools]
|
||||
* @param {string} params.userId
|
||||
* @returns {Promise<Record<string, Record<string, string>> | undefined>}
|
||||
*/
|
||||
return async function ({ tools, userId }) {
|
||||
try {
|
||||
if (!tools || tools.length === 0) {
|
||||
return;
|
||||
}
|
||||
const appTools = await getCachedTools({
|
||||
userId,
|
||||
});
|
||||
return await getUserMCPAuthMap({
|
||||
tools,
|
||||
userId,
|
||||
appTools,
|
||||
findPluginAuthsByKeys,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent`,
|
||||
err,
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getCustomConfig,
|
||||
getBalanceConfig,
|
||||
createGetMCPAuthMap,
|
||||
getCustomEndpointConfig,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
const { config } = require('./EndpointService');
|
||||
const getCachedTools = require('./getCachedTools');
|
||||
const getCustomConfig = require('./getCustomConfig');
|
||||
const loadCustomConfig = require('./loadCustomConfig');
|
||||
const loadConfigModels = require('./loadConfigModels');
|
||||
|
|
@ -14,6 +15,7 @@ module.exports = {
|
|||
loadDefaultModels,
|
||||
loadOverrideConfig,
|
||||
loadAsyncEndpoints,
|
||||
...getCachedTools,
|
||||
...getCustomConfig,
|
||||
...getEndpointsConfig,
|
||||
};
|
||||
|
|
|
|||
196
api/server/services/Endpoints/agents/agent.js
Normal file
196
api/server/services/Endpoints/agents/agent.js
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
const { Providers } = require('@librechat/agents');
|
||||
const { primeResources, optionalChainWithEmptyCheck } = require('@librechat/api');
|
||||
const {
|
||||
ErrorTypes,
|
||||
EModelEndpoint,
|
||||
EToolResources,
|
||||
replaceSpecialVars,
|
||||
providerEndpointMap,
|
||||
} = require('librechat-data-provider');
|
||||
const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize');
|
||||
const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options');
|
||||
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
|
||||
const initCustom = require('~/server/services/Endpoints/custom/initialize');
|
||||
const initGoogle = require('~/server/services/Endpoints/google/initialize');
|
||||
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
|
||||
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
||||
const { processFiles } = require('~/server/services/Files/process');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { getToolFilesByIds } = require('~/models/File');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getFiles } = require('~/models/File');
|
||||
|
||||
const providerConfigMap = {
|
||||
[Providers.XAI]: initCustom,
|
||||
[Providers.OLLAMA]: initCustom,
|
||||
[Providers.DEEPSEEK]: initCustom,
|
||||
[Providers.OPENROUTER]: initCustom,
|
||||
[EModelEndpoint.openAI]: initOpenAI,
|
||||
[EModelEndpoint.google]: initGoogle,
|
||||
[EModelEndpoint.azureOpenAI]: initOpenAI,
|
||||
[EModelEndpoint.anthropic]: initAnthropic,
|
||||
[EModelEndpoint.bedrock]: getBedrockOptions,
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {ServerResponse} params.res
|
||||
* @param {Agent} params.agent
|
||||
* @param {string | null} [params.conversationId]
|
||||
* @param {Array<IMongoFile>} [params.requestFiles]
|
||||
* @param {typeof import('~/server/services/ToolService').loadAgentTools | undefined} [params.loadTools]
|
||||
* @param {TEndpointOption} [params.endpointOption]
|
||||
* @param {Set<string>} [params.allowedProviders]
|
||||
* @param {boolean} [params.isInitialAgent]
|
||||
* @returns {Promise<Agent & { tools: StructuredTool[], attachments: Array<MongoFile>, toolContextMap: Record<string, unknown>, maxContextTokens: number }>}
|
||||
*/
|
||||
const initializeAgent = async ({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
loadTools,
|
||||
requestFiles,
|
||||
conversationId,
|
||||
endpointOption,
|
||||
allowedProviders,
|
||||
isInitialAgent = false,
|
||||
}) => {
|
||||
if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) {
|
||||
throw new Error(
|
||||
`{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`,
|
||||
);
|
||||
}
|
||||
let currentFiles;
|
||||
|
||||
const _modelOptions = structuredClone(
|
||||
Object.assign(
|
||||
{ model: agent.model },
|
||||
agent.model_parameters ?? { model: agent.model },
|
||||
isInitialAgent === true ? endpointOption?.model_parameters : {},
|
||||
),
|
||||
);
|
||||
|
||||
const { resendFiles = true, ...modelOptions } = _modelOptions;
|
||||
|
||||
if (isInitialAgent && conversationId != null && resendFiles) {
|
||||
const fileIds = (await getConvoFiles(conversationId)) ?? [];
|
||||
/** @type {Set<EToolResources>} */
|
||||
const toolResourceSet = new Set();
|
||||
for (const tool of agent.tools) {
|
||||
if (EToolResources[tool]) {
|
||||
toolResourceSet.add(EToolResources[tool]);
|
||||
}
|
||||
}
|
||||
const toolFiles = await getToolFilesByIds(fileIds, toolResourceSet);
|
||||
if (requestFiles.length || toolFiles.length) {
|
||||
currentFiles = await processFiles(requestFiles.concat(toolFiles));
|
||||
}
|
||||
} else if (isInitialAgent && requestFiles.length) {
|
||||
currentFiles = await processFiles(requestFiles);
|
||||
}
|
||||
|
||||
const { attachments, tool_resources } = await primeResources({
|
||||
req,
|
||||
getFiles,
|
||||
attachments: currentFiles,
|
||||
tool_resources: agent.tool_resources,
|
||||
requestFileSet: new Set(requestFiles?.map((file) => file.file_id)),
|
||||
});
|
||||
|
||||
const provider = agent.provider;
|
||||
const { tools, toolContextMap } =
|
||||
(await loadTools?.({
|
||||
req,
|
||||
res,
|
||||
provider,
|
||||
agentId: agent.id,
|
||||
tools: agent.tools,
|
||||
model: agent.model,
|
||||
tool_resources,
|
||||
})) ?? {};
|
||||
|
||||
agent.endpoint = provider;
|
||||
let getOptions = providerConfigMap[provider];
|
||||
if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) {
|
||||
agent.provider = provider.toLowerCase();
|
||||
getOptions = providerConfigMap[agent.provider];
|
||||
} else if (!getOptions) {
|
||||
const customEndpointConfig = await getCustomEndpointConfig(provider);
|
||||
if (!customEndpointConfig) {
|
||||
throw new Error(`Provider ${provider} not supported`);
|
||||
}
|
||||
getOptions = initCustom;
|
||||
agent.provider = Providers.OPENAI;
|
||||
}
|
||||
|
||||
const _endpointOption =
|
||||
isInitialAgent === true
|
||||
? Object.assign({}, endpointOption, { model_parameters: modelOptions })
|
||||
: { model_parameters: modelOptions };
|
||||
|
||||
const options = await getOptions({
|
||||
req,
|
||||
res,
|
||||
optionsOnly: true,
|
||||
overrideEndpoint: provider,
|
||||
overrideModel: agent.model,
|
||||
endpointOption: _endpointOption,
|
||||
});
|
||||
|
||||
const tokensModel =
|
||||
agent.provider === EModelEndpoint.azureOpenAI ? agent.model : modelOptions.model;
|
||||
const maxTokens = optionalChainWithEmptyCheck(
|
||||
modelOptions.maxOutputTokens,
|
||||
modelOptions.maxTokens,
|
||||
0,
|
||||
);
|
||||
const maxContextTokens = optionalChainWithEmptyCheck(
|
||||
modelOptions.maxContextTokens,
|
||||
modelOptions.max_context_tokens,
|
||||
getModelMaxTokens(tokensModel, providerEndpointMap[provider]),
|
||||
4096,
|
||||
);
|
||||
|
||||
if (
|
||||
agent.endpoint === EModelEndpoint.azureOpenAI &&
|
||||
options.llmConfig?.azureOpenAIApiInstanceName == null
|
||||
) {
|
||||
agent.provider = Providers.OPENAI;
|
||||
}
|
||||
|
||||
if (options.provider != null) {
|
||||
agent.provider = options.provider;
|
||||
}
|
||||
|
||||
/** @type {import('@librechat/agents').ClientOptions} */
|
||||
agent.model_parameters = { ...options.llmConfig };
|
||||
if (options.configOptions) {
|
||||
agent.model_parameters.configuration = options.configOptions;
|
||||
}
|
||||
|
||||
if (agent.instructions && agent.instructions !== '') {
|
||||
agent.instructions = replaceSpecialVars({
|
||||
text: agent.instructions,
|
||||
user: req.user,
|
||||
});
|
||||
}
|
||||
|
||||
if (typeof agent.artifacts === 'string' && agent.artifacts !== '') {
|
||||
agent.additional_instructions = generateArtifactsPrompt({
|
||||
endpoint: agent.provider,
|
||||
artifacts: agent.artifacts,
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
...agent,
|
||||
tools,
|
||||
attachments,
|
||||
resendFiles,
|
||||
toolContextMap,
|
||||
maxContextTokens: (maxContextTokens - maxTokens) * 0.9,
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = { initializeAgent };
|
||||
|
|
@ -1,294 +1,41 @@
|
|||
const { createContentAggregator, Providers } = require('@librechat/agents');
|
||||
const {
|
||||
Constants,
|
||||
ErrorTypes,
|
||||
EModelEndpoint,
|
||||
EToolResources,
|
||||
getResponseSender,
|
||||
AgentCapabilities,
|
||||
replaceSpecialVars,
|
||||
providerEndpointMap,
|
||||
} = require('librechat-data-provider');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { createContentAggregator } = require('@librechat/agents');
|
||||
const { Constants, EModelEndpoint, getResponseSender } = require('librechat-data-provider');
|
||||
const {
|
||||
getDefaultHandlers,
|
||||
createToolEndCallback,
|
||||
} = require('~/server/controllers/agents/callbacks');
|
||||
const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize');
|
||||
const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options');
|
||||
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
|
||||
const initCustom = require('~/server/services/Endpoints/custom/initialize');
|
||||
const initGoogle = require('~/server/services/Endpoints/google/initialize');
|
||||
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
|
||||
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
||||
const { processFiles } = require('~/server/services/Files/process');
|
||||
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
|
||||
const { loadAgentTools } = require('~/server/services/ToolService');
|
||||
const AgentClient = require('~/server/controllers/agents/client');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { getToolFilesByIds } = require('~/models/File');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const providerConfigMap = {
|
||||
[Providers.XAI]: initCustom,
|
||||
[Providers.OLLAMA]: initCustom,
|
||||
[Providers.DEEPSEEK]: initCustom,
|
||||
[Providers.OPENROUTER]: initCustom,
|
||||
[EModelEndpoint.openAI]: initOpenAI,
|
||||
[EModelEndpoint.google]: initGoogle,
|
||||
[EModelEndpoint.azureOpenAI]: initOpenAI,
|
||||
[EModelEndpoint.anthropic]: initAnthropic,
|
||||
[EModelEndpoint.bedrock]: getBedrockOptions,
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {Object} params
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {Promise<Array<MongoFile | null>> | undefined} [params.attachments]
|
||||
* @param {Set<string>} params.requestFileSet
|
||||
* @param {AgentToolResources | undefined} [params.tool_resources]
|
||||
* @returns {Promise<{ attachments: Array<MongoFile | undefined> | undefined, tool_resources: AgentToolResources | undefined }>}
|
||||
*/
|
||||
const primeResources = async ({
|
||||
req,
|
||||
attachments: _attachments,
|
||||
tool_resources: _tool_resources,
|
||||
requestFileSet,
|
||||
}) => {
|
||||
try {
|
||||
/** @type {Array<MongoFile | undefined> | undefined} */
|
||||
let attachments;
|
||||
const tool_resources = _tool_resources ?? {};
|
||||
const isOCREnabled = (req.app.locals?.[EModelEndpoint.agents]?.capabilities ?? []).includes(
|
||||
AgentCapabilities.ocr,
|
||||
);
|
||||
if (tool_resources[EToolResources.ocr]?.file_ids && isOCREnabled) {
|
||||
const context = await getFiles(
|
||||
{
|
||||
file_id: { $in: tool_resources.ocr.file_ids },
|
||||
},
|
||||
{},
|
||||
{},
|
||||
);
|
||||
attachments = (attachments ?? []).concat(context);
|
||||
function createToolLoader() {
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {ServerResponse} params.res
|
||||
* @param {string} params.agentId
|
||||
* @param {string[]} params.tools
|
||||
* @param {string} params.provider
|
||||
* @param {string} params.model
|
||||
* @param {AgentToolResources} params.tool_resources
|
||||
* @returns {Promise<{ tools: StructuredTool[], toolContextMap: Record<string, unknown> } | undefined>}
|
||||
*/
|
||||
return async function loadTools({ req, res, agentId, tools, provider, model, tool_resources }) {
|
||||
const agent = { id: agentId, tools, provider, model };
|
||||
try {
|
||||
return await loadAgentTools({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
tool_resources,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error loading tools for agent ' + agentId, error);
|
||||
}
|
||||
if (!_attachments) {
|
||||
return { attachments, tool_resources };
|
||||
}
|
||||
/** @type {Array<MongoFile | undefined> | undefined} */
|
||||
const files = await _attachments;
|
||||
if (!attachments) {
|
||||
/** @type {Array<MongoFile | undefined>} */
|
||||
attachments = [];
|
||||
}
|
||||
|
||||
for (const file of files) {
|
||||
if (!file) {
|
||||
continue;
|
||||
}
|
||||
if (file.metadata?.fileIdentifier) {
|
||||
const execute_code = tool_resources[EToolResources.execute_code] ?? {};
|
||||
if (!execute_code.files) {
|
||||
tool_resources[EToolResources.execute_code] = { ...execute_code, files: [] };
|
||||
}
|
||||
tool_resources[EToolResources.execute_code].files.push(file);
|
||||
} else if (file.embedded === true) {
|
||||
const file_search = tool_resources[EToolResources.file_search] ?? {};
|
||||
if (!file_search.files) {
|
||||
tool_resources[EToolResources.file_search] = { ...file_search, files: [] };
|
||||
}
|
||||
tool_resources[EToolResources.file_search].files.push(file);
|
||||
} else if (
|
||||
requestFileSet.has(file.file_id) &&
|
||||
file.type.startsWith('image') &&
|
||||
file.height &&
|
||||
file.width
|
||||
) {
|
||||
const image_edit = tool_resources[EToolResources.image_edit] ?? {};
|
||||
if (!image_edit.files) {
|
||||
tool_resources[EToolResources.image_edit] = { ...image_edit, files: [] };
|
||||
}
|
||||
tool_resources[EToolResources.image_edit].files.push(file);
|
||||
}
|
||||
|
||||
attachments.push(file);
|
||||
}
|
||||
return { attachments, tool_resources };
|
||||
} catch (error) {
|
||||
logger.error('Error priming resources', error);
|
||||
return { attachments: _attachments, tool_resources: _tool_resources };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {...string | number} values
|
||||
* @returns {string | number | undefined}
|
||||
*/
|
||||
function optionalChainWithEmptyCheck(...values) {
|
||||
for (const value of values) {
|
||||
if (value !== undefined && value !== null && value !== '') {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
return values[values.length - 1];
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {ServerResponse} params.res
|
||||
* @param {Agent} params.agent
|
||||
* @param {Set<string>} [params.allowedProviders]
|
||||
* @param {object} [params.endpointOption]
|
||||
* @param {boolean} [params.isInitialAgent]
|
||||
* @returns {Promise<Agent>}
|
||||
*/
|
||||
const initializeAgentOptions = async ({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
endpointOption,
|
||||
allowedProviders,
|
||||
isInitialAgent = false,
|
||||
}) => {
|
||||
if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) {
|
||||
throw new Error(
|
||||
`{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`,
|
||||
);
|
||||
}
|
||||
let currentFiles;
|
||||
/** @type {Array<MongoFile>} */
|
||||
const requestFiles = req.body.files ?? [];
|
||||
if (
|
||||
isInitialAgent &&
|
||||
req.body.conversationId != null &&
|
||||
(agent.model_parameters?.resendFiles ?? true) === true
|
||||
) {
|
||||
const fileIds = (await getConvoFiles(req.body.conversationId)) ?? [];
|
||||
/** @type {Set<EToolResources>} */
|
||||
const toolResourceSet = new Set();
|
||||
for (const tool of agent.tools) {
|
||||
if (EToolResources[tool]) {
|
||||
toolResourceSet.add(EToolResources[tool]);
|
||||
}
|
||||
}
|
||||
const toolFiles = await getToolFilesByIds(fileIds, toolResourceSet);
|
||||
if (requestFiles.length || toolFiles.length) {
|
||||
currentFiles = await processFiles(requestFiles.concat(toolFiles));
|
||||
}
|
||||
} else if (isInitialAgent && requestFiles.length) {
|
||||
currentFiles = await processFiles(requestFiles);
|
||||
}
|
||||
|
||||
const { attachments, tool_resources } = await primeResources({
|
||||
req,
|
||||
attachments: currentFiles,
|
||||
tool_resources: agent.tool_resources,
|
||||
requestFileSet: new Set(requestFiles.map((file) => file.file_id)),
|
||||
});
|
||||
|
||||
const provider = agent.provider;
|
||||
const { tools, toolContextMap } = await loadAgentTools({
|
||||
req,
|
||||
res,
|
||||
agent: {
|
||||
id: agent.id,
|
||||
tools: agent.tools,
|
||||
provider,
|
||||
model: agent.model,
|
||||
},
|
||||
tool_resources,
|
||||
});
|
||||
|
||||
agent.endpoint = provider;
|
||||
let getOptions = providerConfigMap[provider];
|
||||
if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) {
|
||||
agent.provider = provider.toLowerCase();
|
||||
getOptions = providerConfigMap[agent.provider];
|
||||
} else if (!getOptions) {
|
||||
const customEndpointConfig = await getCustomEndpointConfig(provider);
|
||||
if (!customEndpointConfig) {
|
||||
throw new Error(`Provider ${provider} not supported`);
|
||||
}
|
||||
getOptions = initCustom;
|
||||
agent.provider = Providers.OPENAI;
|
||||
}
|
||||
const model_parameters = Object.assign(
|
||||
{},
|
||||
agent.model_parameters ?? { model: agent.model },
|
||||
isInitialAgent === true ? endpointOption?.model_parameters : {},
|
||||
);
|
||||
const _endpointOption =
|
||||
isInitialAgent === true
|
||||
? Object.assign({}, endpointOption, { model_parameters })
|
||||
: { model_parameters };
|
||||
|
||||
const options = await getOptions({
|
||||
req,
|
||||
res,
|
||||
optionsOnly: true,
|
||||
overrideEndpoint: provider,
|
||||
overrideModel: agent.model,
|
||||
endpointOption: _endpointOption,
|
||||
});
|
||||
|
||||
if (
|
||||
agent.endpoint === EModelEndpoint.azureOpenAI &&
|
||||
options.llmConfig?.azureOpenAIApiInstanceName == null
|
||||
) {
|
||||
agent.provider = Providers.OPENAI;
|
||||
}
|
||||
|
||||
if (options.provider != null) {
|
||||
agent.provider = options.provider;
|
||||
}
|
||||
|
||||
/** @type {import('@librechat/agents').ClientOptions} */
|
||||
agent.model_parameters = Object.assign(model_parameters, options.llmConfig);
|
||||
if (options.configOptions) {
|
||||
agent.model_parameters.configuration = options.configOptions;
|
||||
}
|
||||
|
||||
if (!agent.model_parameters.model) {
|
||||
agent.model_parameters.model = agent.model;
|
||||
}
|
||||
|
||||
if (agent.instructions && agent.instructions !== '') {
|
||||
agent.instructions = replaceSpecialVars({
|
||||
text: agent.instructions,
|
||||
user: req.user,
|
||||
});
|
||||
}
|
||||
|
||||
if (typeof agent.artifacts === 'string' && agent.artifacts !== '') {
|
||||
agent.additional_instructions = generateArtifactsPrompt({
|
||||
endpoint: agent.provider,
|
||||
artifacts: agent.artifacts,
|
||||
});
|
||||
}
|
||||
|
||||
const tokensModel =
|
||||
agent.provider === EModelEndpoint.azureOpenAI ? agent.model : agent.model_parameters.model;
|
||||
const maxTokens = optionalChainWithEmptyCheck(
|
||||
agent.model_parameters.maxOutputTokens,
|
||||
agent.model_parameters.maxTokens,
|
||||
0,
|
||||
);
|
||||
const maxContextTokens = optionalChainWithEmptyCheck(
|
||||
agent.model_parameters.maxContextTokens,
|
||||
agent.max_context_tokens,
|
||||
getModelMaxTokens(tokensModel, providerEndpointMap[provider]),
|
||||
4096,
|
||||
);
|
||||
return {
|
||||
...agent,
|
||||
tools,
|
||||
attachments,
|
||||
toolContextMap,
|
||||
maxContextTokens: (maxContextTokens - maxTokens) * 0.9,
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
if (!endpointOption) {
|
||||
|
|
@ -313,7 +60,6 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
throw new Error('No agent promise provided');
|
||||
}
|
||||
|
||||
// Initialize primary agent
|
||||
const primaryAgent = await endpointOption.agent;
|
||||
if (!primaryAgent) {
|
||||
throw new Error('Agent not found');
|
||||
|
|
@ -323,10 +69,18 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
/** @type {Set<string>} */
|
||||
const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders);
|
||||
|
||||
// Handle primary agent
|
||||
const primaryConfig = await initializeAgentOptions({
|
||||
const loadTools = createToolLoader();
|
||||
/** @type {Array<MongoFile>} */
|
||||
const requestFiles = req.body.files ?? [];
|
||||
/** @type {string} */
|
||||
const conversationId = req.body.conversationId;
|
||||
|
||||
const primaryConfig = await initializeAgent({
|
||||
req,
|
||||
res,
|
||||
loadTools,
|
||||
requestFiles,
|
||||
conversationId,
|
||||
agent: primaryAgent,
|
||||
endpointOption,
|
||||
allowedProviders,
|
||||
|
|
@ -340,10 +94,13 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
if (!agent) {
|
||||
throw new Error(`Agent ${agentId} not found`);
|
||||
}
|
||||
const config = await initializeAgentOptions({
|
||||
const config = await initializeAgent({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
loadTools,
|
||||
requestFiles,
|
||||
conversationId,
|
||||
endpointOption,
|
||||
allowedProviders,
|
||||
});
|
||||
|
|
@ -373,8 +130,8 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
iconURL: endpointOption.iconURL,
|
||||
attachments: primaryConfig.attachments,
|
||||
endpointType: endpointOption.endpointType,
|
||||
resendFiles: primaryConfig.resendFiles ?? true,
|
||||
maxContextTokens: primaryConfig.maxContextTokens,
|
||||
resendFiles: primaryConfig.model_parameters?.resendFiles ?? true,
|
||||
endpoint:
|
||||
primaryConfig.id === Constants.EPHEMERAL_AGENT_ID
|
||||
? primaryConfig.endpoint
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { ProxyAgent } = require('undici');
|
||||
const { anthropicSettings, removeNullishValues } = require('librechat-data-provider');
|
||||
const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = require('./helpers');
|
||||
|
||||
|
|
@ -67,7 +67,10 @@ function getLLMConfig(apiKey, options = {}) {
|
|||
}
|
||||
|
||||
if (options.proxy) {
|
||||
requestOptions.clientOptions.httpAgent = new HttpsProxyAgent(options.proxy);
|
||||
const proxyAgent = new ProxyAgent(options.proxy);
|
||||
requestOptions.clientOptions.fetchOptions = {
|
||||
dispatcher: proxyAgent,
|
||||
};
|
||||
}
|
||||
|
||||
if (options.reverseProxyUrl) {
|
||||
|
|
|
|||
|
|
@ -21,8 +21,12 @@ describe('getLLMConfig', () => {
|
|||
proxy: 'http://proxy:8080',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.clientOptions).toHaveProperty('httpAgent');
|
||||
expect(result.llmConfig.clientOptions.httpAgent).toHaveProperty('proxy', 'http://proxy:8080');
|
||||
expect(result.llmConfig.clientOptions).toHaveProperty('fetchOptions');
|
||||
expect(result.llmConfig.clientOptions.fetchOptions).toHaveProperty('dispatcher');
|
||||
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher).toBeDefined();
|
||||
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher.constructor.name).toBe(
|
||||
'ProxyAgent',
|
||||
);
|
||||
});
|
||||
|
||||
it('should include reverse proxy URL when provided', () => {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
const OpenAI = require('openai');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { constructAzureURL, isUserProvided } = require('@librechat/api');
|
||||
const {
|
||||
ErrorTypes,
|
||||
EModelEndpoint,
|
||||
|
|
@ -12,8 +13,6 @@ const {
|
|||
checkUserKeyExpiry,
|
||||
} = require('~/server/services/UserService');
|
||||
const OpenAIClient = require('~/app/clients/OpenAIClient');
|
||||
const { isUserProvided } = require('~/server/utils');
|
||||
const { constructAzureURL } = require('~/utils');
|
||||
|
||||
class Files {
|
||||
constructor(client) {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { createHandleLLMNewToken } = require('@librechat/api');
|
||||
const {
|
||||
AuthType,
|
||||
Constants,
|
||||
|
|
@ -8,7 +9,6 @@ const {
|
|||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
const { createHandleLLMNewToken } = require('~/app/clients/generators');
|
||||
|
||||
const getOptions = async ({ req, overrideModel, endpointOption }) => {
|
||||
const {
|
||||
|
|
|
|||
|
|
@ -6,10 +6,9 @@ const {
|
|||
extractEnvVariable,
|
||||
} = require('librechat-data-provider');
|
||||
const { Providers } = require('@librechat/agents');
|
||||
const { getOpenAIConfig, createHandleLLMNewToken } = require('@librechat/api');
|
||||
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm');
|
||||
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
||||
const { createHandleLLMNewToken } = require('~/app/clients/generators');
|
||||
const { fetchModels } = require('~/server/services/ModelService');
|
||||
const OpenAIClient = require('~/app/clients/OpenAIClient');
|
||||
const { isUserProvided } = require('~/server/utils');
|
||||
|
|
@ -144,7 +143,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
|
|||
clientOptions,
|
||||
);
|
||||
clientOptions.modelOptions.user = req.user.id;
|
||||
const options = getLLMConfig(apiKey, clientOptions, endpoint);
|
||||
const options = getOpenAIConfig(apiKey, clientOptions, endpoint);
|
||||
if (!customOptions.streamRate) {
|
||||
return options;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,9 +25,9 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
|
|||
const credentials = isUserProvided
|
||||
? userKey
|
||||
: {
|
||||
[AuthKeys.GOOGLE_SERVICE_KEY]: serviceKey,
|
||||
[AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY,
|
||||
};
|
||||
[AuthKeys.GOOGLE_SERVICE_KEY]: serviceKey,
|
||||
[AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY,
|
||||
};
|
||||
|
||||
let clientOptions = {};
|
||||
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ function getLLMConfig(credentials, options = {}) {
|
|||
// Extract from credentials
|
||||
const serviceKeyRaw = creds[AuthKeys.GOOGLE_SERVICE_KEY] ?? {};
|
||||
const serviceKey =
|
||||
typeof serviceKeyRaw === 'string' ? JSON.parse(serviceKeyRaw) : serviceKeyRaw ?? {};
|
||||
typeof serviceKeyRaw === 'string' ? JSON.parse(serviceKeyRaw) : (serviceKeyRaw ?? {});
|
||||
|
||||
const project_id = serviceKey?.project_id ?? null;
|
||||
const apiKey = creds[AuthKeys.GOOGLE_API_KEY] ?? null;
|
||||
|
|
@ -156,10 +156,6 @@ function getLLMConfig(credentials, options = {}) {
|
|||
}
|
||||
|
||||
if (authHeader) {
|
||||
/**
|
||||
* NOTE: NOT SUPPORTED BY LANGCHAIN GENAI CLIENT,
|
||||
* REQUIRES PR IN https://github.com/langchain-ai/langchainjs
|
||||
*/
|
||||
llmConfig.customHeaders = {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
const {
|
||||
EModelEndpoint,
|
||||
mapModelToAzureConfig,
|
||||
resolveHeaders,
|
||||
mapModelToAzureConfig,
|
||||
} = require('librechat-data-provider');
|
||||
const { isEnabled, isUserProvided, getAzureCredentials } = require('@librechat/api');
|
||||
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
const { isEnabled, isUserProvided } = require('~/server/utils');
|
||||
const { getAzureCredentials } = require('~/utils');
|
||||
const { PluginsClient } = require('~/app');
|
||||
|
||||
const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
|
|
|
|||
|
|
@ -114,11 +114,11 @@ describe('gptPlugins/initializeClient', () => {
|
|||
test('should initialize PluginsClient with Azure credentials when PLUGINS_USE_AZURE is true', async () => {
|
||||
process.env.AZURE_API_KEY = 'test-azure-api-key';
|
||||
(process.env.AZURE_OPENAI_API_INSTANCE_NAME = 'some-value'),
|
||||
(process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'some-value'),
|
||||
(process.env.AZURE_OPENAI_API_VERSION = 'some-value'),
|
||||
(process.env.AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME = 'some-value'),
|
||||
(process.env.AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME = 'some-value'),
|
||||
(process.env.PLUGINS_USE_AZURE = 'true');
|
||||
(process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'some-value'),
|
||||
(process.env.AZURE_OPENAI_API_VERSION = 'some-value'),
|
||||
(process.env.AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME = 'some-value'),
|
||||
(process.env.AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME = 'some-value'),
|
||||
(process.env.PLUGINS_USE_AZURE = 'true');
|
||||
process.env.DEBUG_PLUGINS = 'false';
|
||||
process.env.OPENAI_SUMMARIZE = 'false';
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,15 @@ const {
|
|||
resolveHeaders,
|
||||
mapModelToAzureConfig,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
isEnabled,
|
||||
isUserProvided,
|
||||
getOpenAIConfig,
|
||||
getAzureCredentials,
|
||||
createHandleLLMNewToken,
|
||||
} = require('@librechat/api');
|
||||
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm');
|
||||
const { createHandleLLMNewToken } = require('~/app/clients/generators');
|
||||
const { isEnabled, isUserProvided } = require('~/server/utils');
|
||||
const OpenAIClient = require('~/app/clients/OpenAIClient');
|
||||
const { getAzureCredentials } = require('~/utils');
|
||||
|
||||
const initializeClient = async ({
|
||||
req,
|
||||
|
|
@ -140,7 +143,7 @@ const initializeClient = async ({
|
|||
modelOptions.model = modelName;
|
||||
clientOptions = Object.assign({ modelOptions }, clientOptions);
|
||||
clientOptions.modelOptions.user = req.user.id;
|
||||
const options = getLLMConfig(apiKey, clientOptions);
|
||||
const options = getOpenAIConfig(apiKey, clientOptions);
|
||||
const streamRate = clientOptions.streamRate;
|
||||
if (!streamRate) {
|
||||
return options;
|
||||
|
|
|
|||
|
|
@ -1,170 +0,0 @@
|
|||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { KnownEndpoints } = require('librechat-data-provider');
|
||||
const { sanitizeModelName, constructAzureURL } = require('~/utils');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
|
||||
/**
|
||||
* Generates configuration options for creating a language model (LLM) instance.
|
||||
* @param {string} apiKey - The API key for authentication.
|
||||
* @param {Object} options - Additional options for configuring the LLM.
|
||||
* @param {Object} [options.modelOptions] - Model-specific options.
|
||||
* @param {string} [options.modelOptions.model] - The name of the model to use.
|
||||
* @param {string} [options.modelOptions.user] - The user ID
|
||||
* @param {number} [options.modelOptions.temperature] - Controls randomness in output generation (0-2).
|
||||
* @param {number} [options.modelOptions.top_p] - Controls diversity via nucleus sampling (0-1).
|
||||
* @param {number} [options.modelOptions.frequency_penalty] - Reduces repetition of token sequences (-2 to 2).
|
||||
* @param {number} [options.modelOptions.presence_penalty] - Encourages discussing new topics (-2 to 2).
|
||||
* @param {number} [options.modelOptions.max_tokens] - The maximum number of tokens to generate.
|
||||
* @param {string[]} [options.modelOptions.stop] - Sequences where the API will stop generating further tokens.
|
||||
* @param {string} [options.reverseProxyUrl] - URL for a reverse proxy, if used.
|
||||
* @param {boolean} [options.useOpenRouter] - Flag to use OpenRouter API.
|
||||
* @param {Object} [options.headers] - Additional headers for API requests.
|
||||
* @param {string} [options.proxy] - Proxy server URL.
|
||||
* @param {Object} [options.azure] - Azure-specific configurations.
|
||||
* @param {boolean} [options.streaming] - Whether to use streaming mode.
|
||||
* @param {Object} [options.addParams] - Additional parameters to add to the model options.
|
||||
* @param {string[]} [options.dropParams] - Parameters to remove from the model options.
|
||||
* @param {string|null} [endpoint=null] - The endpoint name
|
||||
* @returns {Object} Configuration options for creating an LLM instance.
|
||||
*/
|
||||
function getLLMConfig(apiKey, options = {}, endpoint = null) {
|
||||
let {
|
||||
modelOptions = {},
|
||||
reverseProxyUrl,
|
||||
defaultQuery,
|
||||
headers,
|
||||
proxy,
|
||||
azure,
|
||||
streaming = true,
|
||||
addParams,
|
||||
dropParams,
|
||||
} = options;
|
||||
|
||||
/** @type {OpenAIClientOptions} */
|
||||
let llmConfig = {
|
||||
streaming,
|
||||
};
|
||||
|
||||
Object.assign(llmConfig, modelOptions);
|
||||
|
||||
if (addParams && typeof addParams === 'object') {
|
||||
Object.assign(llmConfig, addParams);
|
||||
}
|
||||
/** Note: OpenAI Web Search models do not support any known parameters besdies `max_tokens` */
|
||||
if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) {
|
||||
const searchExcludeParams = [
|
||||
'frequency_penalty',
|
||||
'presence_penalty',
|
||||
'temperature',
|
||||
'top_p',
|
||||
'top_k',
|
||||
'stop',
|
||||
'logit_bias',
|
||||
'seed',
|
||||
'response_format',
|
||||
'n',
|
||||
'logprobs',
|
||||
'user',
|
||||
];
|
||||
|
||||
dropParams = dropParams || [];
|
||||
dropParams = [...new Set([...dropParams, ...searchExcludeParams])];
|
||||
}
|
||||
|
||||
if (dropParams && Array.isArray(dropParams)) {
|
||||
dropParams.forEach((param) => {
|
||||
if (llmConfig[param]) {
|
||||
llmConfig[param] = undefined;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let useOpenRouter;
|
||||
/** @type {OpenAIClientOptions['configuration']} */
|
||||
const configOptions = {};
|
||||
if (
|
||||
(reverseProxyUrl && reverseProxyUrl.includes(KnownEndpoints.openrouter)) ||
|
||||
(endpoint && endpoint.toLowerCase().includes(KnownEndpoints.openrouter))
|
||||
) {
|
||||
useOpenRouter = true;
|
||||
llmConfig.include_reasoning = true;
|
||||
configOptions.baseURL = reverseProxyUrl;
|
||||
configOptions.defaultHeaders = Object.assign(
|
||||
{
|
||||
'HTTP-Referer': 'https://librechat.ai',
|
||||
'X-Title': 'LibreChat',
|
||||
},
|
||||
headers,
|
||||
);
|
||||
} else if (reverseProxyUrl) {
|
||||
configOptions.baseURL = reverseProxyUrl;
|
||||
if (headers) {
|
||||
configOptions.defaultHeaders = headers;
|
||||
}
|
||||
}
|
||||
|
||||
if (defaultQuery) {
|
||||
configOptions.defaultQuery = defaultQuery;
|
||||
}
|
||||
|
||||
if (proxy) {
|
||||
const proxyAgent = new HttpsProxyAgent(proxy);
|
||||
Object.assign(configOptions, {
|
||||
httpAgent: proxyAgent,
|
||||
httpsAgent: proxyAgent,
|
||||
});
|
||||
}
|
||||
|
||||
if (azure) {
|
||||
const useModelName = isEnabled(process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME);
|
||||
azure.azureOpenAIApiDeploymentName = useModelName
|
||||
? sanitizeModelName(llmConfig.model)
|
||||
: azure.azureOpenAIApiDeploymentName;
|
||||
|
||||
if (process.env.AZURE_OPENAI_DEFAULT_MODEL) {
|
||||
llmConfig.model = process.env.AZURE_OPENAI_DEFAULT_MODEL;
|
||||
}
|
||||
|
||||
if (configOptions.baseURL) {
|
||||
const azureURL = constructAzureURL({
|
||||
baseURL: configOptions.baseURL,
|
||||
azureOptions: azure,
|
||||
});
|
||||
azure.azureOpenAIBasePath = azureURL.split(`/${azure.azureOpenAIApiDeploymentName}`)[0];
|
||||
}
|
||||
|
||||
Object.assign(llmConfig, azure);
|
||||
llmConfig.model = llmConfig.azureOpenAIApiDeploymentName;
|
||||
} else {
|
||||
llmConfig.apiKey = apiKey;
|
||||
// Object.assign(llmConfig, {
|
||||
// configuration: { apiKey },
|
||||
// });
|
||||
}
|
||||
|
||||
if (process.env.OPENAI_ORGANIZATION && this.azure) {
|
||||
llmConfig.organization = process.env.OPENAI_ORGANIZATION;
|
||||
}
|
||||
|
||||
if (useOpenRouter && llmConfig.reasoning_effort != null) {
|
||||
llmConfig.reasoning = {
|
||||
effort: llmConfig.reasoning_effort,
|
||||
};
|
||||
delete llmConfig.reasoning_effort;
|
||||
}
|
||||
|
||||
if (llmConfig?.['max_tokens'] != null) {
|
||||
/** @type {number} */
|
||||
llmConfig.maxTokens = llmConfig['max_tokens'];
|
||||
delete llmConfig['max_tokens'];
|
||||
}
|
||||
|
||||
return {
|
||||
/** @type {OpenAIClientOptions} */
|
||||
llmConfig,
|
||||
/** @type {OpenAIClientOptions['configuration']} */
|
||||
configOptions,
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = { getLLMConfig };
|
||||
|
|
@ -2,9 +2,9 @@ const axios = require('axios');
|
|||
const fs = require('fs').promises;
|
||||
const FormData = require('form-data');
|
||||
const { Readable } = require('stream');
|
||||
const { genAzureEndpoint } = require('@librechat/api');
|
||||
const { extractEnvVariable, STTProviders } = require('librechat-data-provider');
|
||||
const { getCustomConfig } = require('~/server/services/Config');
|
||||
const { genAzureEndpoint } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
const axios = require('axios');
|
||||
const { genAzureEndpoint } = require('@librechat/api');
|
||||
const { extractEnvVariable, TTSProviders } = require('librechat-data-provider');
|
||||
const { getRandomVoiceId, createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
|
||||
const { getCustomConfig } = require('~/server/services/Config');
|
||||
const { genAzureEndpoint } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -91,15 +91,28 @@ async function prepareAzureImageURL(req, file) {
|
|||
* @param {Buffer} params.buffer - The avatar image buffer.
|
||||
* @param {string} params.userId - The user's id.
|
||||
* @param {string} params.manual - Flag to indicate manual update.
|
||||
* @param {string} [params.agentId] - Optional agent ID if this is an agent avatar.
|
||||
* @param {string} [params.basePath='images'] - The base folder within the container.
|
||||
* @param {string} [params.containerName] - The Azure Blob container name.
|
||||
* @returns {Promise<string>} The URL of the avatar.
|
||||
*/
|
||||
async function processAzureAvatar({ buffer, userId, manual, basePath = 'images', containerName }) {
|
||||
async function processAzureAvatar({
|
||||
buffer,
|
||||
userId,
|
||||
manual,
|
||||
agentId,
|
||||
basePath = 'images',
|
||||
containerName,
|
||||
}) {
|
||||
try {
|
||||
const metadata = await sharp(buffer).metadata();
|
||||
const extension = metadata.format === 'gif' ? 'gif' : 'png';
|
||||
const fileName = `avatar.${extension}`;
|
||||
const timestamp = new Date().getTime();
|
||||
|
||||
/** Unique filename with timestamp and optional agent ID */
|
||||
const fileName = agentId
|
||||
? `agent-${agentId}-avatar-${timestamp}.${extension}`
|
||||
: `avatar-${timestamp}.${extension}`;
|
||||
|
||||
const downloadURL = await saveBufferToAzure({
|
||||
userId,
|
||||
|
|
@ -110,9 +123,12 @@ async function processAzureAvatar({ buffer, userId, manual, basePath = 'images',
|
|||
});
|
||||
const isManual = manual === 'true';
|
||||
const url = `${downloadURL}?manual=${isManual}`;
|
||||
if (isManual) {
|
||||
|
||||
// Only update user record if this is a user avatar (manual === 'true')
|
||||
if (isManual && !agentId) {
|
||||
await updateUser(userId, { avatar: url });
|
||||
}
|
||||
|
||||
return url;
|
||||
} catch (error) {
|
||||
logger.error('[processAzureAvatar] Error uploading profile picture to Azure:', error);
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
const FormData = require('form-data');
|
||||
const { getCodeBaseURL } = require('@librechat/agents');
|
||||
const { createAxiosInstance } = require('~/config');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
const { createAxiosInstance, logAxiosError } = require('@librechat/api');
|
||||
|
||||
const axios = createAxiosInstance();
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
const path = require('path');
|
||||
const { v4 } = require('uuid');
|
||||
const axios = require('axios');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getCodeBaseURL } = require('@librechat/agents');
|
||||
const {
|
||||
Tools,
|
||||
|
|
@ -12,8 +14,6 @@ const {
|
|||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { convertImage } = require('~/server/services/Files/images/convert');
|
||||
const { createFile, getFiles, updateFile } = require('~/models/File');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Process OpenAI image files, convert to target format, save and return file metadata.
|
||||
|
|
|
|||
|
|
@ -82,14 +82,20 @@ async function prepareImageURL(req, file) {
|
|||
* @param {Buffer} params.buffer - The Buffer containing the avatar image.
|
||||
* @param {string} params.userId - The user ID.
|
||||
* @param {string} params.manual - A string flag indicating whether the update is manual ('true' or 'false').
|
||||
* @param {string} [params.agentId] - Optional agent ID if this is an agent avatar.
|
||||
* @returns {Promise<string>} - A promise that resolves with the URL of the uploaded avatar.
|
||||
* @throws {Error} - Throws an error if Firebase is not initialized or if there is an error in uploading.
|
||||
*/
|
||||
async function processFirebaseAvatar({ buffer, userId, manual }) {
|
||||
async function processFirebaseAvatar({ buffer, userId, manual, agentId }) {
|
||||
try {
|
||||
const metadata = await sharp(buffer).metadata();
|
||||
const extension = metadata.format === 'gif' ? 'gif' : 'png';
|
||||
const fileName = `avatar.${extension}`;
|
||||
const timestamp = new Date().getTime();
|
||||
|
||||
/** Unique filename with timestamp and optional agent ID */
|
||||
const fileName = agentId
|
||||
? `agent-${agentId}-avatar-${timestamp}.${extension}`
|
||||
: `avatar-${timestamp}.${extension}`;
|
||||
|
||||
const downloadURL = await saveBufferToFirebase({
|
||||
userId,
|
||||
|
|
@ -98,10 +104,10 @@ async function processFirebaseAvatar({ buffer, userId, manual }) {
|
|||
});
|
||||
|
||||
const isManual = manual === 'true';
|
||||
|
||||
const url = `${downloadURL}?manual=${isManual}`;
|
||||
|
||||
if (isManual) {
|
||||
// Only update user record if this is a user avatar (manual === 'true')
|
||||
if (isManual && !agentId) {
|
||||
await updateUser(userId, { avatar: url });
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -201,6 +201,10 @@ const unlinkFile = async (filepath) => {
|
|||
*/
|
||||
const deleteLocalFile = async (req, file) => {
|
||||
const { publicPath, uploads } = req.app.locals.paths;
|
||||
|
||||
/** Filepath stripped of query parameters (e.g., ?manual=true) */
|
||||
const cleanFilepath = file.filepath.split('?')[0];
|
||||
|
||||
if (file.embedded && process.env.RAG_API_URL) {
|
||||
const jwtToken = req.headers.authorization.split(' ')[1];
|
||||
axios.delete(`${process.env.RAG_API_URL}/documents`, {
|
||||
|
|
@ -213,32 +217,32 @@ const deleteLocalFile = async (req, file) => {
|
|||
});
|
||||
}
|
||||
|
||||
if (file.filepath.startsWith(`/uploads/${req.user.id}`)) {
|
||||
if (cleanFilepath.startsWith(`/uploads/${req.user.id}`)) {
|
||||
const userUploadDir = path.join(uploads, req.user.id);
|
||||
const basePath = file.filepath.split(`/uploads/${req.user.id}/`)[1];
|
||||
const basePath = cleanFilepath.split(`/uploads/${req.user.id}/`)[1];
|
||||
|
||||
if (!basePath) {
|
||||
throw new Error(`Invalid file path: ${file.filepath}`);
|
||||
throw new Error(`Invalid file path: ${cleanFilepath}`);
|
||||
}
|
||||
|
||||
const filepath = path.join(userUploadDir, basePath);
|
||||
|
||||
const rel = path.relative(userUploadDir, filepath);
|
||||
if (rel.startsWith('..') || path.isAbsolute(rel) || rel.includes(`..${path.sep}`)) {
|
||||
throw new Error(`Invalid file path: ${file.filepath}`);
|
||||
throw new Error(`Invalid file path: ${cleanFilepath}`);
|
||||
}
|
||||
|
||||
await unlinkFile(filepath);
|
||||
return;
|
||||
}
|
||||
|
||||
const parts = file.filepath.split(path.sep);
|
||||
const parts = cleanFilepath.split(path.sep);
|
||||
const subfolder = parts[1];
|
||||
if (!subfolder && parts[0] === EModelEndpoint.agents) {
|
||||
logger.warn(`Agent File ${file.file_id} is missing filepath, may have been deleted already`);
|
||||
return;
|
||||
}
|
||||
const filepath = path.join(publicPath, file.filepath);
|
||||
const filepath = path.join(publicPath, cleanFilepath);
|
||||
|
||||
if (!isValidPath(req, publicPath, subfolder, filepath)) {
|
||||
throw new Error('Invalid file path');
|
||||
|
|
|
|||
|
|
@ -112,10 +112,11 @@ async function prepareImagesLocal(req, file) {
|
|||
* @param {Buffer} params.buffer - The Buffer containing the avatar image.
|
||||
* @param {string} params.userId - The user ID.
|
||||
* @param {string} params.manual - A string flag indicating whether the update is manual ('true' or 'false').
|
||||
* @param {string} [params.agentId] - Optional agent ID if this is an agent avatar.
|
||||
* @returns {Promise<string>} - A promise that resolves with the URL of the uploaded avatar.
|
||||
* @throws {Error} - Throws an error if Firebase is not initialized or if there is an error in uploading.
|
||||
*/
|
||||
async function processLocalAvatar({ buffer, userId, manual }) {
|
||||
async function processLocalAvatar({ buffer, userId, manual, agentId }) {
|
||||
const userDir = path.resolve(
|
||||
__dirname,
|
||||
'..',
|
||||
|
|
@ -132,7 +133,11 @@ async function processLocalAvatar({ buffer, userId, manual }) {
|
|||
const metadata = await sharp(buffer).metadata();
|
||||
const extension = metadata.format === 'gif' ? 'gif' : 'png';
|
||||
|
||||
const fileName = `avatar-${new Date().getTime()}.${extension}`;
|
||||
const timestamp = new Date().getTime();
|
||||
/** Unique filename with timestamp and optional agent ID */
|
||||
const fileName = agentId
|
||||
? `agent-${agentId}-avatar-${timestamp}.${extension}`
|
||||
: `avatar-${timestamp}.${extension}`;
|
||||
const urlRoute = `/images/${userId}/${fileName}`;
|
||||
const avatarPath = path.join(userDir, fileName);
|
||||
|
||||
|
|
@ -142,7 +147,8 @@ async function processLocalAvatar({ buffer, userId, manual }) {
|
|||
const isManual = manual === 'true';
|
||||
let url = `${urlRoute}?manual=${isManual}`;
|
||||
|
||||
if (isManual) {
|
||||
// Only update user record if this is a user avatar (manual === 'true')
|
||||
if (isManual && !agentId) {
|
||||
await updateUser(userId, { avatar: url });
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,238 +0,0 @@
|
|||
// ~/server/services/Files/MistralOCR/crud.js
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const FormData = require('form-data');
|
||||
const {
|
||||
FileSources,
|
||||
envVarRegex,
|
||||
extractEnvVariable,
|
||||
extractVariableName,
|
||||
} = require('librechat-data-provider');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { logger, createAxiosInstance } = require('~/config');
|
||||
const { logAxiosError } = require('~/utils/axios');
|
||||
|
||||
const axios = createAxiosInstance();
|
||||
|
||||
/**
|
||||
* Uploads a document to Mistral API using file streaming to avoid loading the entire file into memory
|
||||
*
|
||||
* @param {Object} params Upload parameters
|
||||
* @param {string} params.filePath The path to the file on disk
|
||||
* @param {string} [params.fileName] Optional filename to use (defaults to the name from filePath)
|
||||
* @param {string} params.apiKey Mistral API key
|
||||
* @param {string} [params.baseURL=https://api.mistral.ai/v1] Mistral API base URL
|
||||
* @returns {Promise<Object>} The response from Mistral API
|
||||
*/
|
||||
async function uploadDocumentToMistral({
|
||||
filePath,
|
||||
fileName = '',
|
||||
apiKey,
|
||||
baseURL = 'https://api.mistral.ai/v1',
|
||||
}) {
|
||||
const form = new FormData();
|
||||
form.append('purpose', 'ocr');
|
||||
const actualFileName = fileName || path.basename(filePath);
|
||||
const fileStream = fs.createReadStream(filePath);
|
||||
form.append('file', fileStream, { filename: actualFileName });
|
||||
|
||||
return axios
|
||||
.post(`${baseURL}/files`, form, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
...form.getHeaders(),
|
||||
},
|
||||
maxBodyLength: Infinity,
|
||||
maxContentLength: Infinity,
|
||||
})
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
|
||||
async function getSignedUrl({
|
||||
apiKey,
|
||||
fileId,
|
||||
expiry = 24,
|
||||
baseURL = 'https://api.mistral.ai/v1',
|
||||
}) {
|
||||
return axios
|
||||
.get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
})
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
logger.error('Error fetching signed URL:', error.message);
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {Object} params
|
||||
* @param {string} params.apiKey
|
||||
* @param {string} params.url - The document or image URL
|
||||
* @param {string} [params.documentType='document_url'] - 'document_url' or 'image_url'
|
||||
* @param {string} [params.model]
|
||||
* @param {string} [params.baseURL]
|
||||
* @returns {Promise<OCRResult>}
|
||||
*/
|
||||
async function performOCR({
|
||||
apiKey,
|
||||
url,
|
||||
documentType = 'document_url',
|
||||
model = 'mistral-ocr-latest',
|
||||
baseURL = 'https://api.mistral.ai/v1',
|
||||
}) {
|
||||
const documentKey = documentType === 'image_url' ? 'image_url' : 'document_url';
|
||||
return axios
|
||||
.post(
|
||||
`${baseURL}/ocr`,
|
||||
{
|
||||
model,
|
||||
image_limit: 0,
|
||||
include_image_base64: false,
|
||||
document: {
|
||||
type: documentType,
|
||||
[documentKey]: url,
|
||||
},
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
},
|
||||
)
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
logger.error('Error performing OCR:', error.message);
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Uploads a file to the Mistral OCR API and processes the OCR result.
|
||||
*
|
||||
* @param {Object} params - The params object.
|
||||
* @param {ServerRequest} params.req - The request object from Express. It should have a `user` property with an `id`
|
||||
* representing the user
|
||||
* @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should
|
||||
* have a `mimetype` property that tells us the file type
|
||||
* @param {string} params.file_id - The file ID.
|
||||
* @param {string} [params.entity_id] - The entity ID, not used here but passed for consistency.
|
||||
* @returns {Promise<{ filepath: string, bytes: number }>} - The result object containing the processed `text` and `images` (not currently used),
|
||||
* along with the `filename` and `bytes` properties.
|
||||
*/
|
||||
const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => {
|
||||
try {
|
||||
/** @type {TCustomConfig['ocr']} */
|
||||
const ocrConfig = req.app.locals?.ocr;
|
||||
|
||||
const apiKeyConfig = ocrConfig.apiKey || '';
|
||||
const baseURLConfig = ocrConfig.baseURL || '';
|
||||
|
||||
const isApiKeyEnvVar = envVarRegex.test(apiKeyConfig);
|
||||
const isBaseURLEnvVar = envVarRegex.test(baseURLConfig);
|
||||
|
||||
const isApiKeyEmpty = !apiKeyConfig.trim();
|
||||
const isBaseURLEmpty = !baseURLConfig.trim();
|
||||
|
||||
let apiKey, baseURL;
|
||||
|
||||
if (isApiKeyEnvVar || isBaseURLEnvVar || isApiKeyEmpty || isBaseURLEmpty) {
|
||||
const apiKeyVarName = isApiKeyEnvVar ? extractVariableName(apiKeyConfig) : 'OCR_API_KEY';
|
||||
const baseURLVarName = isBaseURLEnvVar ? extractVariableName(baseURLConfig) : 'OCR_BASEURL';
|
||||
|
||||
const authValues = await loadAuthValues({
|
||||
userId: req.user.id,
|
||||
authFields: [baseURLVarName, apiKeyVarName],
|
||||
optional: new Set([baseURLVarName]),
|
||||
});
|
||||
|
||||
apiKey = authValues[apiKeyVarName];
|
||||
baseURL = authValues[baseURLVarName];
|
||||
} else {
|
||||
apiKey = apiKeyConfig;
|
||||
baseURL = baseURLConfig;
|
||||
}
|
||||
|
||||
const mistralFile = await uploadDocumentToMistral({
|
||||
filePath: file.path,
|
||||
fileName: file.originalname,
|
||||
apiKey,
|
||||
baseURL,
|
||||
});
|
||||
|
||||
const modelConfig = ocrConfig.mistralModel || '';
|
||||
const model = envVarRegex.test(modelConfig)
|
||||
? extractEnvVariable(modelConfig)
|
||||
: modelConfig.trim() || 'mistral-ocr-latest';
|
||||
|
||||
const signedUrlResponse = await getSignedUrl({
|
||||
apiKey,
|
||||
baseURL,
|
||||
fileId: mistralFile.id,
|
||||
});
|
||||
|
||||
const mimetype = (file.mimetype || '').toLowerCase();
|
||||
const originalname = file.originalname || '';
|
||||
const isImage =
|
||||
mimetype.startsWith('image') || /\.(png|jpe?g|gif|bmp|webp|tiff?)$/i.test(originalname);
|
||||
const documentType = isImage ? 'image_url' : 'document_url';
|
||||
|
||||
const ocrResult = await performOCR({
|
||||
apiKey,
|
||||
baseURL,
|
||||
model,
|
||||
url: signedUrlResponse.url,
|
||||
documentType,
|
||||
});
|
||||
|
||||
let aggregatedText = '';
|
||||
const images = [];
|
||||
ocrResult.pages.forEach((page, index) => {
|
||||
if (ocrResult.pages.length > 1) {
|
||||
aggregatedText += `# PAGE ${index + 1}\n`;
|
||||
}
|
||||
|
||||
aggregatedText += page.markdown + '\n\n';
|
||||
|
||||
if (page.images && page.images.length > 0) {
|
||||
page.images.forEach((image) => {
|
||||
if (image.image_base64) {
|
||||
images.push(image.image_base64);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
return {
|
||||
filename: file.originalname,
|
||||
bytes: aggregatedText.length * 4,
|
||||
filepath: FileSources.mistral_ocr,
|
||||
text: aggregatedText,
|
||||
images,
|
||||
};
|
||||
} catch (error) {
|
||||
let message = 'Error uploading document to Mistral OCR API';
|
||||
const detail = error?.response?.data?.detail;
|
||||
if (detail && detail !== '') {
|
||||
message = detail;
|
||||
}
|
||||
|
||||
const responseMessage = error?.response?.data?.message;
|
||||
throw new Error(
|
||||
`${logAxiosError({ error, message })}${responseMessage && responseMessage !== '' ? ` - ${responseMessage}` : ''}`,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
uploadDocumentToMistral,
|
||||
uploadMistralOCR,
|
||||
getSignedUrl,
|
||||
performOCR,
|
||||
};
|
||||
|
|
@ -1,848 +0,0 @@
|
|||
const fs = require('fs');
|
||||
|
||||
const mockAxios = {
|
||||
interceptors: {
|
||||
request: { use: jest.fn(), eject: jest.fn() },
|
||||
response: { use: jest.fn(), eject: jest.fn() },
|
||||
},
|
||||
create: jest.fn().mockReturnValue({
|
||||
defaults: {
|
||||
proxy: null,
|
||||
},
|
||||
get: jest.fn().mockResolvedValue({ data: {} }),
|
||||
post: jest.fn().mockResolvedValue({ data: {} }),
|
||||
put: jest.fn().mockResolvedValue({ data: {} }),
|
||||
delete: jest.fn().mockResolvedValue({ data: {} }),
|
||||
}),
|
||||
get: jest.fn().mockResolvedValue({ data: {} }),
|
||||
post: jest.fn().mockResolvedValue({ data: {} }),
|
||||
put: jest.fn().mockResolvedValue({ data: {} }),
|
||||
delete: jest.fn().mockResolvedValue({ data: {} }),
|
||||
reset: jest.fn().mockImplementation(function () {
|
||||
this.get.mockClear();
|
||||
this.post.mockClear();
|
||||
this.put.mockClear();
|
||||
this.delete.mockClear();
|
||||
this.create.mockClear();
|
||||
}),
|
||||
};
|
||||
|
||||
jest.mock('axios', () => mockAxios);
|
||||
jest.mock('fs');
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
},
|
||||
createAxiosInstance: () => mockAxios,
|
||||
}));
|
||||
jest.mock('~/server/services/Tools/credentials', () => ({
|
||||
loadAuthValues: jest.fn(),
|
||||
}));
|
||||
|
||||
const { uploadDocumentToMistral, uploadMistralOCR, getSignedUrl, performOCR } = require('./crud');
|
||||
|
||||
describe('MistralOCR Service', () => {
|
||||
afterEach(() => {
|
||||
mockAxios.reset();
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('uploadDocumentToMistral', () => {
|
||||
beforeEach(() => {
|
||||
// Create a more complete mock for file streams that FormData can work with
|
||||
const mockReadStream = {
|
||||
on: jest.fn().mockImplementation(function (event, handler) {
|
||||
// Simulate immediate 'end' event to make FormData complete processing
|
||||
if (event === 'end') {
|
||||
handler();
|
||||
}
|
||||
return this;
|
||||
}),
|
||||
pipe: jest.fn().mockImplementation(function () {
|
||||
return this;
|
||||
}),
|
||||
pause: jest.fn(),
|
||||
resume: jest.fn(),
|
||||
emit: jest.fn(),
|
||||
once: jest.fn(),
|
||||
destroy: jest.fn(),
|
||||
};
|
||||
|
||||
fs.createReadStream = jest.fn().mockReturnValue(mockReadStream);
|
||||
|
||||
// Mock FormData's append to avoid actual stream processing
|
||||
jest.mock('form-data', () => {
|
||||
const mockFormData = function () {
|
||||
return {
|
||||
append: jest.fn(),
|
||||
getHeaders: jest
|
||||
.fn()
|
||||
.mockReturnValue({ 'content-type': 'multipart/form-data; boundary=---boundary' }),
|
||||
getBuffer: jest.fn().mockReturnValue(Buffer.from('mock-form-data')),
|
||||
getLength: jest.fn().mockReturnValue(100),
|
||||
};
|
||||
};
|
||||
return mockFormData;
|
||||
});
|
||||
});
|
||||
|
||||
it('should upload a document to Mistral API using file streaming', async () => {
|
||||
const mockResponse = { data: { id: 'file-123', purpose: 'ocr' } };
|
||||
mockAxios.post.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
const result = await uploadDocumentToMistral({
|
||||
filePath: '/path/to/test.pdf',
|
||||
fileName: 'test.pdf',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
// Check that createReadStream was called with the correct file path
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/path/to/test.pdf');
|
||||
|
||||
// Since we're mocking FormData, we'll just check that axios was called correctly
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files',
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: 'Bearer test-api-key',
|
||||
}),
|
||||
maxBodyLength: Infinity,
|
||||
maxContentLength: Infinity,
|
||||
}),
|
||||
);
|
||||
expect(result).toEqual(mockResponse.data);
|
||||
});
|
||||
|
||||
it('should handle errors during document upload', async () => {
|
||||
const errorMessage = 'API error';
|
||||
mockAxios.post.mockRejectedValueOnce(new Error(errorMessage));
|
||||
|
||||
await expect(
|
||||
uploadDocumentToMistral({
|
||||
filePath: '/path/to/test.pdf',
|
||||
fileName: 'test.pdf',
|
||||
apiKey: 'test-api-key',
|
||||
}),
|
||||
).rejects.toThrow(errorMessage);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getSignedUrl', () => {
|
||||
it('should fetch signed URL from Mistral API', async () => {
|
||||
const mockResponse = { data: { url: 'https://document-url.com' } };
|
||||
mockAxios.get.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
const result = await getSignedUrl({
|
||||
fileId: 'file-123',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
expect(mockAxios.get).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files/file-123/url?expiry=24',
|
||||
{
|
||||
headers: {
|
||||
Authorization: 'Bearer test-api-key',
|
||||
},
|
||||
},
|
||||
);
|
||||
expect(result).toEqual(mockResponse.data);
|
||||
});
|
||||
|
||||
it('should handle errors when fetching signed URL', async () => {
|
||||
const errorMessage = 'API error';
|
||||
mockAxios.get.mockRejectedValueOnce(new Error(errorMessage));
|
||||
|
||||
await expect(
|
||||
getSignedUrl({
|
||||
fileId: 'file-123',
|
||||
apiKey: 'test-api-key',
|
||||
}),
|
||||
).rejects.toThrow();
|
||||
|
||||
const { logger } = require('~/config');
|
||||
expect(logger.error).toHaveBeenCalledWith('Error fetching signed URL:', errorMessage);
|
||||
});
|
||||
});
|
||||
|
||||
describe('performOCR', () => {
|
||||
it('should perform OCR using Mistral API (document_url)', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
pages: [{ markdown: 'Page 1 content' }, { markdown: 'Page 2 content' }],
|
||||
},
|
||||
};
|
||||
mockAxios.post.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
const result = await performOCR({
|
||||
apiKey: 'test-api-key',
|
||||
url: 'https://document-url.com',
|
||||
model: 'mistral-ocr-latest',
|
||||
documentType: 'document_url',
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/ocr',
|
||||
{
|
||||
model: 'mistral-ocr-latest',
|
||||
include_image_base64: false,
|
||||
image_limit: 0,
|
||||
document: {
|
||||
type: 'document_url',
|
||||
document_url: 'https://document-url.com',
|
||||
},
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: 'Bearer test-api-key',
|
||||
},
|
||||
},
|
||||
);
|
||||
expect(result).toEqual(mockResponse.data);
|
||||
});
|
||||
|
||||
it('should perform OCR using Mistral API (image_url)', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
pages: [{ markdown: 'Image OCR content' }],
|
||||
},
|
||||
};
|
||||
mockAxios.post.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
const result = await performOCR({
|
||||
apiKey: 'test-api-key',
|
||||
url: 'https://image-url.com/image.png',
|
||||
model: 'mistral-ocr-latest',
|
||||
documentType: 'image_url',
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/ocr',
|
||||
{
|
||||
model: 'mistral-ocr-latest',
|
||||
include_image_base64: false,
|
||||
image_limit: 0,
|
||||
document: {
|
||||
type: 'image_url',
|
||||
image_url: 'https://image-url.com/image.png',
|
||||
},
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: 'Bearer test-api-key',
|
||||
},
|
||||
},
|
||||
);
|
||||
expect(result).toEqual(mockResponse.data);
|
||||
});
|
||||
|
||||
it('should handle errors during OCR processing', async () => {
|
||||
const errorMessage = 'OCR processing error';
|
||||
mockAxios.post.mockRejectedValueOnce(new Error(errorMessage));
|
||||
|
||||
await expect(
|
||||
performOCR({
|
||||
apiKey: 'test-api-key',
|
||||
url: 'https://document-url.com',
|
||||
}),
|
||||
).rejects.toThrow();
|
||||
|
||||
const { logger } = require('~/config');
|
||||
expect(logger.error).toHaveBeenCalledWith('Error performing OCR:', errorMessage);
|
||||
});
|
||||
});
|
||||
|
||||
describe('uploadMistralOCR', () => {
|
||||
beforeEach(() => {
|
||||
const mockReadStream = {
|
||||
on: jest.fn().mockImplementation(function (event, handler) {
|
||||
if (event === 'end') {
|
||||
handler();
|
||||
}
|
||||
return this;
|
||||
}),
|
||||
pipe: jest.fn().mockImplementation(function () {
|
||||
return this;
|
||||
}),
|
||||
pause: jest.fn(),
|
||||
resume: jest.fn(),
|
||||
emit: jest.fn(),
|
||||
once: jest.fn(),
|
||||
destroy: jest.fn(),
|
||||
};
|
||||
|
||||
fs.createReadStream = jest.fn().mockReturnValue(mockReadStream);
|
||||
});
|
||||
|
||||
it('should process OCR for a file with standard configuration', async () => {
|
||||
// Setup mocks
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'test-api-key',
|
||||
OCR_BASEURL: 'https://api.mistral.ai/v1',
|
||||
});
|
||||
|
||||
// Mock file upload response
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: { id: 'file-123', purpose: 'ocr' },
|
||||
});
|
||||
|
||||
// Mock signed URL response
|
||||
mockAxios.get.mockResolvedValueOnce({
|
||||
data: { url: 'https://signed-url.com' },
|
||||
});
|
||||
|
||||
// Mock OCR response with text and images
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
pages: [
|
||||
{
|
||||
markdown: 'Page 1 content',
|
||||
images: [{ image_base64: 'base64image1' }],
|
||||
},
|
||||
{
|
||||
markdown: 'Page 2 content',
|
||||
images: [{ image_base64: 'base64image2' }],
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
// Use environment variable syntax to ensure loadAuthValues is called
|
||||
apiKey: '${OCR_API_KEY}',
|
||||
baseURL: '${OCR_BASEURL}',
|
||||
mistralModel: 'mistral-medium',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'document.pdf',
|
||||
mimetype: 'application/pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Verify OCR result
|
||||
expect(result).toEqual({
|
||||
filename: 'document.pdf',
|
||||
bytes: expect.any(Number),
|
||||
filepath: 'mistral_ocr',
|
||||
text: expect.stringContaining('# PAGE 1'),
|
||||
images: ['base64image1', 'base64image2'],
|
||||
});
|
||||
});
|
||||
|
||||
it('should process OCR for an image file and use image_url type', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'test-api-key',
|
||||
OCR_BASEURL: 'https://api.mistral.ai/v1',
|
||||
});
|
||||
|
||||
// Mock file upload response
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: { id: 'file-456', purpose: 'ocr' },
|
||||
});
|
||||
|
||||
// Mock signed URL response
|
||||
mockAxios.get.mockResolvedValueOnce({
|
||||
data: { url: 'https://signed-url.com/image.png' },
|
||||
});
|
||||
|
||||
// Mock OCR response for image
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
pages: [
|
||||
{
|
||||
markdown: 'Image OCR result',
|
||||
images: [{ image_base64: 'imgbase64' }],
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
const req = {
|
||||
user: { id: 'user456' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
apiKey: '${OCR_API_KEY}',
|
||||
baseURL: '${OCR_BASEURL}',
|
||||
mistralModel: 'mistral-medium',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/image.png',
|
||||
originalname: 'image.png',
|
||||
mimetype: 'image/png',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file456',
|
||||
entity_id: 'entity456',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/image.png');
|
||||
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user456',
|
||||
authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Check that the OCR API was called with image_url type
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/ocr',
|
||||
expect.objectContaining({
|
||||
document: expect.objectContaining({
|
||||
type: 'image_url',
|
||||
image_url: 'https://signed-url.com/image.png',
|
||||
}),
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
filename: 'image.png',
|
||||
bytes: expect.any(Number),
|
||||
filepath: 'mistral_ocr',
|
||||
text: expect.stringContaining('Image OCR result'),
|
||||
images: ['imgbase64'],
|
||||
});
|
||||
});
|
||||
|
||||
it('should process variable references in configuration', async () => {
|
||||
// Setup mocks with environment variables
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
CUSTOM_API_KEY: 'custom-api-key',
|
||||
CUSTOM_BASEURL: 'https://custom-api.mistral.ai/v1',
|
||||
});
|
||||
|
||||
// Mock API responses
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: { id: 'file-123', purpose: 'ocr' },
|
||||
});
|
||||
mockAxios.get.mockResolvedValueOnce({
|
||||
data: { url: 'https://signed-url.com' },
|
||||
});
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
pages: [{ markdown: 'Content from custom API' }],
|
||||
},
|
||||
});
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
apiKey: '${CUSTOM_API_KEY}',
|
||||
baseURL: '${CUSTOM_BASEURL}',
|
||||
mistralModel: '${CUSTOM_MODEL}',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Set environment variable for model
|
||||
process.env.CUSTOM_MODEL = 'mistral-large';
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'document.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Verify that custom environment variables were extracted and used
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
authFields: ['CUSTOM_BASEURL', 'CUSTOM_API_KEY'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Check that mistral-large was used in the OCR API call
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
model: 'mistral-large',
|
||||
}),
|
||||
expect.anything(),
|
||||
);
|
||||
|
||||
expect(result.text).toEqual('Content from custom API\n\n');
|
||||
});
|
||||
|
||||
it('should fall back to default values when variables are not properly formatted', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'default-api-key',
|
||||
OCR_BASEURL: undefined, // Testing optional parameter
|
||||
});
|
||||
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: { id: 'file-123', purpose: 'ocr' },
|
||||
});
|
||||
mockAxios.get.mockResolvedValueOnce({
|
||||
data: { url: 'https://signed-url.com' },
|
||||
});
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
pages: [{ markdown: 'Default API result' }],
|
||||
},
|
||||
});
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
// Use environment variable syntax to ensure loadAuthValues is called
|
||||
apiKey: '${INVALID_FORMAT}', // Using valid env var format but with an invalid name
|
||||
baseURL: '${OCR_BASEURL}', // Using valid env var format
|
||||
mistralModel: 'mistral-ocr-latest', // Plain string value
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'document.pdf',
|
||||
};
|
||||
|
||||
await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Should use the default values
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
authFields: ['OCR_BASEURL', 'INVALID_FORMAT'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Should use the default model when not using environment variable format
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
model: 'mistral-ocr-latest',
|
||||
}),
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle API errors during OCR process', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'test-api-key',
|
||||
});
|
||||
|
||||
// Mock file upload to fail
|
||||
mockAxios.post.mockRejectedValueOnce(new Error('Upload failed'));
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
apiKey: 'OCR_API_KEY',
|
||||
baseURL: 'OCR_BASEURL',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'document.pdf',
|
||||
};
|
||||
|
||||
await expect(
|
||||
uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
}),
|
||||
).rejects.toThrow('Error uploading document to Mistral OCR API');
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
});
|
||||
|
||||
it('should handle single page documents without page numbering', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'test-api-key',
|
||||
OCR_BASEURL: 'https://api.mistral.ai/v1', // Make sure this is included
|
||||
});
|
||||
|
||||
// Clear all previous mocks
|
||||
mockAxios.post.mockClear();
|
||||
mockAxios.get.mockClear();
|
||||
|
||||
// 1. First mock: File upload response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
|
||||
);
|
||||
|
||||
// 2. Second mock: Signed URL response
|
||||
mockAxios.get.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
|
||||
);
|
||||
|
||||
// 3. Third mock: OCR response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
data: {
|
||||
pages: [{ markdown: 'Single page content' }],
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
apiKey: 'OCR_API_KEY',
|
||||
baseURL: 'OCR_BASEURL',
|
||||
mistralModel: 'mistral-ocr-latest',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'single-page.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Verify that single page documents don't include page numbering
|
||||
expect(result.text).not.toContain('# PAGE');
|
||||
expect(result.text).toEqual('Single page content\n\n');
|
||||
});
|
||||
|
||||
it('should use literal values in configuration when provided directly', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
// We'll still mock this but it should not be used for literal values
|
||||
loadAuthValues.mockResolvedValue({});
|
||||
|
||||
// Clear all previous mocks
|
||||
mockAxios.post.mockClear();
|
||||
mockAxios.get.mockClear();
|
||||
|
||||
// 1. First mock: File upload response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
|
||||
);
|
||||
|
||||
// 2. Second mock: Signed URL response
|
||||
mockAxios.get.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
|
||||
);
|
||||
|
||||
// 3. Third mock: OCR response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
data: {
|
||||
pages: [{ markdown: 'Processed with literal config values' }],
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
// Direct values that should be used as-is, without variable substitution
|
||||
apiKey: 'actual-api-key-value',
|
||||
baseURL: 'https://direct-api-url.mistral.ai/v1',
|
||||
mistralModel: 'mistral-direct-model',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'direct-values.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Verify the correct URL was used with the direct baseURL value
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://direct-api-url.mistral.ai/v1/files',
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: 'Bearer actual-api-key-value',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
// Check the OCR call was made with the direct model value
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://direct-api-url.mistral.ai/v1/ocr',
|
||||
expect.objectContaining({
|
||||
model: 'mistral-direct-model',
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
|
||||
// Verify the result
|
||||
expect(result.text).toEqual('Processed with literal config values\n\n');
|
||||
|
||||
// Verify loadAuthValues was never called since we used direct values
|
||||
expect(loadAuthValues).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle empty configuration values and use defaults', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
// Set up the mock values to be returned by loadAuthValues
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'default-from-env-key',
|
||||
OCR_BASEURL: 'https://default-from-env.mistral.ai/v1',
|
||||
});
|
||||
|
||||
// Clear all previous mocks
|
||||
mockAxios.post.mockClear();
|
||||
mockAxios.get.mockClear();
|
||||
|
||||
// 1. First mock: File upload response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
|
||||
);
|
||||
|
||||
// 2. Second mock: Signed URL response
|
||||
mockAxios.get.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
|
||||
);
|
||||
|
||||
// 3. Third mock: OCR response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
data: {
|
||||
pages: [{ markdown: 'Content from default configuration' }],
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
// Empty string values - should fall back to defaults
|
||||
apiKey: '',
|
||||
baseURL: '',
|
||||
mistralModel: '',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'empty-config.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Verify loadAuthValues was called with the default variable names
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Verify the API calls used the default values from loadAuthValues
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://default-from-env.mistral.ai/v1/files',
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: 'Bearer default-from-env-key',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
// Verify the OCR model defaulted to mistral-ocr-latest
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://default-from-env.mistral.ai/v1/ocr',
|
||||
expect.objectContaining({
|
||||
model: 'mistral-ocr-latest',
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
|
||||
// Check result
|
||||
expect(result.text).toEqual('Content from default configuration\n\n');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
const crud = require('./crud');
|
||||
|
||||
module.exports = {
|
||||
...crud,
|
||||
};
|
||||
|
|
@ -94,19 +94,28 @@ async function prepareImageURLS3(req, file) {
|
|||
* @param {Buffer} params.buffer - Avatar image buffer.
|
||||
* @param {string} params.userId - User's unique identifier.
|
||||
* @param {string} params.manual - 'true' or 'false' flag for manual update.
|
||||
* @param {string} [params.agentId] - Optional agent ID if this is an agent avatar.
|
||||
* @param {string} [params.basePath='images'] - Base path in the bucket.
|
||||
* @returns {Promise<string>} Signed URL of the uploaded avatar.
|
||||
*/
|
||||
async function processS3Avatar({ buffer, userId, manual, basePath = defaultBasePath }) {
|
||||
async function processS3Avatar({ buffer, userId, manual, agentId, basePath = defaultBasePath }) {
|
||||
try {
|
||||
const metadata = await sharp(buffer).metadata();
|
||||
const extension = metadata.format === 'gif' ? 'gif' : 'png';
|
||||
const fileName = `avatar.${extension}`;
|
||||
const timestamp = new Date().getTime();
|
||||
|
||||
/** Unique filename with timestamp and optional agent ID */
|
||||
const fileName = agentId
|
||||
? `agent-${agentId}-avatar-${timestamp}.${extension}`
|
||||
: `avatar-${timestamp}.${extension}`;
|
||||
|
||||
const downloadURL = await saveBufferToS3({ userId, buffer, fileName, basePath });
|
||||
if (manual === 'true') {
|
||||
|
||||
// Only update user record if this is a user avatar (manual === 'true')
|
||||
if (manual === 'true' && !agentId) {
|
||||
await updateUser(userId, { avatar: downloadURL });
|
||||
}
|
||||
|
||||
return downloadURL;
|
||||
} catch (error) {
|
||||
logger.error('[processS3Avatar] Error processing S3 avatar:', error.message);
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
const fs = require('fs');
|
||||
const axios = require('axios');
|
||||
const FormData = require('form-data');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { FileSources } = require('librechat-data-provider');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Deletes a file from the vector database. This function takes a file object, constructs the full path, and
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
const axios = require('axios');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const {
|
||||
FileSources,
|
||||
VisionModes,
|
||||
|
|
@ -7,8 +8,6 @@ const {
|
|||
EModelEndpoint,
|
||||
} = require('librechat-data-provider');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Converts a readable stream to a base64 encoded string.
|
||||
|
|
|
|||
|
|
@ -522,7 +522,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
|||
throw new Error('OCR capability is not enabled for Agents');
|
||||
}
|
||||
|
||||
const { handleFileUpload: uploadMistralOCR } = getStrategyFunctions(
|
||||
const { handleFileUpload: uploadOCR } = getStrategyFunctions(
|
||||
req.app.locals?.ocr?.strategy ?? FileSources.mistral_ocr,
|
||||
);
|
||||
const { file_id, temp_file_id } = metadata;
|
||||
|
|
@ -534,7 +534,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
|||
images,
|
||||
filename,
|
||||
filepath: ocrFileURL,
|
||||
} = await uploadMistralOCR({ req, file, file_id, entity_id: agent_id, basePath });
|
||||
} = await uploadOCR({ req, file, loadAuthValues });
|
||||
|
||||
const fileInfo = removeNullishValues({
|
||||
text,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
const { FileSources } = require('librechat-data-provider');
|
||||
const { uploadMistralOCR, uploadAzureMistralOCR } = require('@librechat/api');
|
||||
const {
|
||||
getFirebaseURL,
|
||||
prepareImageURL,
|
||||
|
|
@ -46,7 +47,6 @@ const {
|
|||
const { uploadOpenAIFile, deleteOpenAIFile, getOpenAIFileStream } = require('./OpenAI');
|
||||
const { getCodeOutputDownloadStream, uploadCodeEnvFile } = require('./Code');
|
||||
const { uploadVectors, deleteVectors } = require('./VectorDB');
|
||||
const { uploadMistralOCR } = require('./MistralOCR');
|
||||
|
||||
/**
|
||||
* Firebase Storage Strategy Functions
|
||||
|
|
@ -202,6 +202,26 @@ const mistralOCRStrategy = () => ({
|
|||
handleFileUpload: uploadMistralOCR,
|
||||
});
|
||||
|
||||
const azureMistralOCRStrategy = () => ({
|
||||
/** @type {typeof saveFileFromURL | null} */
|
||||
saveURL: null,
|
||||
/** @type {typeof getLocalFileURL | null} */
|
||||
getFileURL: null,
|
||||
/** @type {typeof saveLocalBuffer | null} */
|
||||
saveBuffer: null,
|
||||
/** @type {typeof processLocalAvatar | null} */
|
||||
processAvatar: null,
|
||||
/** @type {typeof uploadLocalImage | null} */
|
||||
handleImageUpload: null,
|
||||
/** @type {typeof prepareImagesLocal | null} */
|
||||
prepareImagePayload: null,
|
||||
/** @type {typeof deleteLocalFile | null} */
|
||||
deleteFile: null,
|
||||
/** @type {typeof getLocalFileStream | null} */
|
||||
getDownloadStream: null,
|
||||
handleFileUpload: uploadAzureMistralOCR,
|
||||
});
|
||||
|
||||
// Strategy Selector
|
||||
const getStrategyFunctions = (fileSource) => {
|
||||
if (fileSource === FileSources.firebase) {
|
||||
|
|
@ -222,6 +242,8 @@ const getStrategyFunctions = (fileSource) => {
|
|||
return codeOutputStrategy();
|
||||
} else if (fileSource === FileSources.mistral_ocr) {
|
||||
return mistralOCRStrategy();
|
||||
} else if (fileSource === FileSources.azure_mistral_ocr) {
|
||||
return azureMistralOCRStrategy();
|
||||
} else {
|
||||
throw new Error('Invalid file source');
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,27 +1,111 @@
|
|||
const { z } = require('zod');
|
||||
const { tool } = require('@langchain/core/tools');
|
||||
const { normalizeServerName } = require('librechat-mcp');
|
||||
const { Constants: AgentConstants, Providers } = require('@librechat/agents');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Time, CacheKeys, StepTypes } = require('librechat-data-provider');
|
||||
const { sendEvent, normalizeServerName, MCPOAuthHandler } = require('@librechat/api');
|
||||
const { Constants: AgentConstants, Providers, GraphEvents } = require('@librechat/agents');
|
||||
const {
|
||||
Constants,
|
||||
ContentTypes,
|
||||
isAssistantsEndpoint,
|
||||
convertJsonSchemaToZod,
|
||||
} = require('librechat-data-provider');
|
||||
const { logger, getMCPManager } = require('~/config');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { findToken, createToken, updateToken } = require('~/models');
|
||||
const { getCachedTools } = require('./Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||
* @param {string} params.stepId - The ID of the step in the flow.
|
||||
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
|
||||
* @param {string} params.loginFlowId - The ID of the login flow.
|
||||
* @param {FlowStateManager<any>} params.flowManager - The flow manager instance.
|
||||
*/
|
||||
function createOAuthStart({ res, stepId, toolCall, loginFlowId, flowManager, signal }) {
|
||||
/**
|
||||
* Creates a function to handle OAuth login requests.
|
||||
* @param {string} authURL - The URL to redirect the user for OAuth authentication.
|
||||
* @returns {Promise<boolean>} Returns true to indicate the event was sent successfully.
|
||||
*/
|
||||
return async function (authURL) {
|
||||
/** @type {{ id: string; delta: AgentToolCallDelta }} */
|
||||
const data = {
|
||||
id: stepId,
|
||||
delta: {
|
||||
type: StepTypes.TOOL_CALLS,
|
||||
tool_calls: [{ ...toolCall, args: '' }],
|
||||
auth: authURL,
|
||||
expires_at: Date.now() + Time.TWO_MINUTES,
|
||||
},
|
||||
};
|
||||
/** Used to ensure the handler (use of `sendEvent`) is only invoked once */
|
||||
await flowManager.createFlowWithHandler(
|
||||
loginFlowId,
|
||||
'oauth_login',
|
||||
async () => {
|
||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
|
||||
logger.debug('Sent OAuth login request to client');
|
||||
return true;
|
||||
},
|
||||
signal,
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||
* @param {string} params.stepId - The ID of the step in the flow.
|
||||
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
|
||||
* @param {string} params.loginFlowId - The ID of the login flow.
|
||||
* @param {FlowStateManager<any>} params.flowManager - The flow manager instance.
|
||||
*/
|
||||
function createOAuthEnd({ res, stepId, toolCall }) {
|
||||
return async function () {
|
||||
/** @type {{ id: string; delta: AgentToolCallDelta }} */
|
||||
const data = {
|
||||
id: stepId,
|
||||
delta: {
|
||||
type: StepTypes.TOOL_CALLS,
|
||||
tool_calls: [{ ...toolCall }],
|
||||
},
|
||||
};
|
||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
|
||||
logger.debug('Sent OAuth login success to client');
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {string} params.userId - The ID of the user.
|
||||
* @param {string} params.serverName - The name of the server.
|
||||
* @param {string} params.toolName - The name of the tool.
|
||||
* @param {FlowStateManager<any>} params.flowManager - The flow manager instance.
|
||||
*/
|
||||
function createAbortHandler({ userId, serverName, toolName, flowManager }) {
|
||||
return function () {
|
||||
logger.info(`[MCP][User: ${userId}][${serverName}][${toolName}] Tool call aborted`);
|
||||
const flowId = MCPOAuthHandler.generateFlowId(userId, serverName);
|
||||
flowManager.failFlow(flowId, 'mcp_oauth', new Error('Tool call aborted'));
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a general tool for an entire action set.
|
||||
*
|
||||
* @param {Object} params - The parameters for loading action sets.
|
||||
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
|
||||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||
* @param {string} params.toolKey - The toolKey for the tool.
|
||||
* @param {import('@librechat/agents').Providers | EModelEndpoint} params.provider - The provider for the tool.
|
||||
* @param {string} params.model - The model for the tool.
|
||||
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function createMCPTool({ req, toolKey, provider: _provider }) {
|
||||
const toolDefinition = req.app.locals.availableTools[toolKey]?.function;
|
||||
async function createMCPTool({ req, res, toolKey, provider: _provider }) {
|
||||
const availableTools = await getCachedTools({ includeGlobal: true });
|
||||
const toolDefinition = availableTools?.[toolKey]?.function;
|
||||
if (!toolDefinition) {
|
||||
logger.error(`Tool ${toolKey} not found in available tools`);
|
||||
return null;
|
||||
|
|
@ -50,19 +134,61 @@ async function createMCPTool({ req, toolKey, provider: _provider }) {
|
|||
|
||||
/** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise<unknown>} */
|
||||
const _call = async (toolArguments, config) => {
|
||||
const userId = config?.configurable?.user?.id || config?.configurable?.user_id;
|
||||
/** @type {ReturnType<typeof createAbortHandler>} */
|
||||
let abortHandler = null;
|
||||
/** @type {AbortSignal} */
|
||||
let derivedSignal = null;
|
||||
|
||||
try {
|
||||
const derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined;
|
||||
const mcpManager = getMCPManager(config?.configurable?.user_id);
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined;
|
||||
const mcpManager = getMCPManager(userId);
|
||||
const provider = (config?.metadata?.provider || _provider)?.toLowerCase();
|
||||
|
||||
const { args: _args, stepId, ...toolCall } = config.toolCall ?? {};
|
||||
const loginFlowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`;
|
||||
const oauthStart = createOAuthStart({
|
||||
res,
|
||||
stepId,
|
||||
toolCall,
|
||||
loginFlowId,
|
||||
flowManager,
|
||||
signal: derivedSignal,
|
||||
});
|
||||
const oauthEnd = createOAuthEnd({
|
||||
res,
|
||||
stepId,
|
||||
toolCall,
|
||||
});
|
||||
|
||||
if (derivedSignal) {
|
||||
abortHandler = createAbortHandler({ userId, serverName, toolName, flowManager });
|
||||
derivedSignal.addEventListener('abort', abortHandler, { once: true });
|
||||
}
|
||||
|
||||
const customUserVars =
|
||||
config?.configurable?.userMCPAuthMap?.[`${Constants.mcp_prefix}${serverName}`];
|
||||
|
||||
const result = await mcpManager.callTool({
|
||||
serverName,
|
||||
toolName,
|
||||
provider,
|
||||
toolArguments,
|
||||
options: {
|
||||
userId: config?.configurable?.user_id,
|
||||
signal: derivedSignal,
|
||||
},
|
||||
user: config?.configurable?.user,
|
||||
customUserVars,
|
||||
flowManager,
|
||||
tokenMethods: {
|
||||
findToken,
|
||||
createToken,
|
||||
updateToken,
|
||||
},
|
||||
oauthStart,
|
||||
oauthEnd,
|
||||
});
|
||||
|
||||
if (isAssistantsEndpoint(provider) && Array.isArray(result)) {
|
||||
|
|
@ -74,12 +200,31 @@ async function createMCPTool({ req, toolKey, provider: _provider }) {
|
|||
return result;
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`[MCP][User: ${config?.configurable?.user_id}][${serverName}] Error calling "${toolName}" MCP tool:`,
|
||||
`[MCP][User: ${userId}][${serverName}] Error calling "${toolName}" MCP tool:`,
|
||||
error,
|
||||
);
|
||||
|
||||
/** OAuth error, provide a helpful message */
|
||||
const isOAuthError =
|
||||
error.message?.includes('401') ||
|
||||
error.message?.includes('OAuth') ||
|
||||
error.message?.includes('authentication') ||
|
||||
error.message?.includes('Non-200 status code (401)');
|
||||
|
||||
if (isOAuthError) {
|
||||
throw new Error(
|
||||
`OAuth authentication required for ${serverName}. Please check the server logs for the authentication URL.`,
|
||||
);
|
||||
}
|
||||
|
||||
throw new Error(
|
||||
`"${toolKey}" tool call failed${error?.message ? `: ${error?.message}` : '.'}`,
|
||||
);
|
||||
} finally {
|
||||
// Clean up abort handler to prevent memory leaks
|
||||
if (abortHandler && derivedSignal) {
|
||||
derivedSignal.removeEventListener('abort', abortHandler);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
const axios = require('axios');
|
||||
const { Providers } = require('@librechat/agents');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider');
|
||||
const { inputSchema, logAxiosError, extractBaseURL, processModelData } = require('~/utils');
|
||||
const { inputSchema, extractBaseURL, processModelData } = require('~/utils');
|
||||
const { OllamaClient } = require('~/app/clients/OllamaClient');
|
||||
const { isUserProvided } = require('~/server/utils');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Splits a string by commas and trims each resulting value.
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const axios = require('axios');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { EModelEndpoint, defaultModels } = require('librechat-data-provider');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const {
|
||||
fetchModels,
|
||||
|
|
@ -28,7 +28,8 @@ jest.mock('~/cache/getLogStores', () =>
|
|||
set: jest.fn().mockResolvedValue(true),
|
||||
})),
|
||||
);
|
||||
jest.mock('~/config', () => ({
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
...jest.requireActual('@librechat/data-schemas'),
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const { encrypt, decrypt } = require('~/server/utils/crypto');
|
||||
const { PluginAuth } = require('~/db/models');
|
||||
const { logger } = require('~/config');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { encrypt, decrypt } = require('@librechat/api');
|
||||
const { findOnePluginAuth, updatePluginAuth, deletePluginAuth } = require('~/models');
|
||||
|
||||
/**
|
||||
* Asynchronously retrieves and decrypts the authentication value for a user's plugin, based on a specified authentication field.
|
||||
|
|
@ -25,7 +25,7 @@ const { logger } = require('~/config');
|
|||
*/
|
||||
const getUserPluginAuthValue = async (userId, authField, throwError = true) => {
|
||||
try {
|
||||
const pluginAuth = await PluginAuth.findOne({ userId, authField }).lean();
|
||||
const pluginAuth = await findOnePluginAuth({ userId, authField });
|
||||
if (!pluginAuth) {
|
||||
throw new Error(`No plugin auth ${authField} found for user ${userId}`);
|
||||
}
|
||||
|
|
@ -79,23 +79,12 @@ const getUserPluginAuthValue = async (userId, authField, throwError = true) => {
|
|||
const updateUserPluginAuth = async (userId, authField, pluginKey, value) => {
|
||||
try {
|
||||
const encryptedValue = await encrypt(value);
|
||||
const pluginAuth = await PluginAuth.findOne({ userId, authField }).lean();
|
||||
if (pluginAuth) {
|
||||
return await PluginAuth.findOneAndUpdate(
|
||||
{ userId, authField },
|
||||
{ $set: { value: encryptedValue } },
|
||||
{ new: true, upsert: true },
|
||||
).lean();
|
||||
} else {
|
||||
const newPluginAuth = await new PluginAuth({
|
||||
userId,
|
||||
authField,
|
||||
value: encryptedValue,
|
||||
pluginKey,
|
||||
});
|
||||
await newPluginAuth.save();
|
||||
return newPluginAuth.toObject();
|
||||
}
|
||||
return await updatePluginAuth({
|
||||
userId,
|
||||
authField,
|
||||
pluginKey,
|
||||
value: encryptedValue,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error('[updateUserPluginAuth]', err);
|
||||
return err;
|
||||
|
|
@ -105,26 +94,25 @@ const updateUserPluginAuth = async (userId, authField, pluginKey, value) => {
|
|||
/**
|
||||
* @async
|
||||
* @param {string} userId
|
||||
* @param {string} authField
|
||||
* @param {boolean} [all]
|
||||
* @param {string | null} authField - The specific authField to delete, or null if `all` is true.
|
||||
* @param {boolean} [all=false] - Whether to delete all auths for the user (or for a specific pluginKey if provided).
|
||||
* @param {string} [pluginKey] - Optional. If `all` is true and `pluginKey` is provided, delete all auths for this user and pluginKey.
|
||||
* @returns {Promise<import('mongoose').DeleteResult>}
|
||||
* @throws {Error}
|
||||
*/
|
||||
const deleteUserPluginAuth = async (userId, authField, all = false) => {
|
||||
if (all) {
|
||||
try {
|
||||
const response = await PluginAuth.deleteMany({ userId });
|
||||
return response;
|
||||
} catch (err) {
|
||||
logger.error('[deleteUserPluginAuth]', err);
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
const deleteUserPluginAuth = async (userId, authField, all = false, pluginKey) => {
|
||||
try {
|
||||
return await PluginAuth.deleteOne({ userId, authField });
|
||||
return await deletePluginAuth({
|
||||
userId,
|
||||
authField,
|
||||
pluginKey,
|
||||
all,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error('[deleteUserPluginAuth]', err);
|
||||
logger.error(
|
||||
`[deleteUserPluginAuth] Error deleting ${all ? 'all' : 'single'} auth(s) for userId: ${userId}${pluginKey ? ` and pluginKey: ${pluginKey}` : ''}`,
|
||||
err,
|
||||
);
|
||||
return err;
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const axios = require('axios');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
|
||||
/**
|
||||
* @typedef {Object} RetrieveOptions
|
||||
|
|
|
|||
|
|
@ -1,172 +0,0 @@
|
|||
const axios = require('axios');
|
||||
const { handleOAuthToken } = require('~/models/Token');
|
||||
const { decryptV2 } = require('~/server/utils/crypto');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Processes the access tokens and stores them in the database.
|
||||
* @param {object} tokenData
|
||||
* @param {string} tokenData.access_token
|
||||
* @param {number} tokenData.expires_in
|
||||
* @param {string} [tokenData.refresh_token]
|
||||
* @param {number} [tokenData.refresh_token_expires_in]
|
||||
* @param {object} metadata
|
||||
* @param {string} metadata.userId
|
||||
* @param {string} metadata.identifier
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function processAccessTokens(tokenData, { userId, identifier }) {
|
||||
const { access_token, expires_in = 3600, refresh_token, refresh_token_expires_in } = tokenData;
|
||||
if (!access_token) {
|
||||
logger.error('Access token not found: ', tokenData);
|
||||
throw new Error('Access token not found');
|
||||
}
|
||||
await handleOAuthToken({
|
||||
identifier,
|
||||
token: access_token,
|
||||
expiresIn: expires_in,
|
||||
userId,
|
||||
});
|
||||
|
||||
if (refresh_token != null) {
|
||||
logger.debug('Processing refresh token');
|
||||
await handleOAuthToken({
|
||||
token: refresh_token,
|
||||
type: 'oauth_refresh',
|
||||
userId,
|
||||
identifier: `${identifier}:refresh`,
|
||||
expiresIn: refresh_token_expires_in ?? null,
|
||||
});
|
||||
}
|
||||
logger.debug('Access tokens processed');
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the access token using the refresh token.
|
||||
* @param {object} fields
|
||||
* @param {string} fields.userId - The ID of the user.
|
||||
* @param {string} fields.client_url - The URL of the OAuth provider.
|
||||
* @param {string} fields.identifier - The identifier for the token.
|
||||
* @param {string} fields.refresh_token - The refresh token to use.
|
||||
* @param {string} fields.encrypted_oauth_client_id - The client ID for the OAuth provider.
|
||||
* @param {string} fields.encrypted_oauth_client_secret - The client secret for the OAuth provider.
|
||||
* @returns {Promise<{
|
||||
* access_token: string,
|
||||
* expires_in: number,
|
||||
* refresh_token?: string,
|
||||
* refresh_token_expires_in?: number,
|
||||
* }>}
|
||||
*/
|
||||
const refreshAccessToken = async ({
|
||||
userId,
|
||||
client_url,
|
||||
identifier,
|
||||
refresh_token,
|
||||
encrypted_oauth_client_id,
|
||||
encrypted_oauth_client_secret,
|
||||
}) => {
|
||||
try {
|
||||
const oauth_client_id = await decryptV2(encrypted_oauth_client_id);
|
||||
const oauth_client_secret = await decryptV2(encrypted_oauth_client_secret);
|
||||
const params = new URLSearchParams({
|
||||
client_id: oauth_client_id,
|
||||
client_secret: oauth_client_secret,
|
||||
grant_type: 'refresh_token',
|
||||
refresh_token,
|
||||
});
|
||||
|
||||
const response = await axios({
|
||||
method: 'POST',
|
||||
url: client_url,
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Accept: 'application/json',
|
||||
},
|
||||
data: params.toString(),
|
||||
});
|
||||
await processAccessTokens(response.data, {
|
||||
userId,
|
||||
identifier,
|
||||
});
|
||||
logger.debug(`Access token refreshed successfully for ${identifier}`);
|
||||
return response.data;
|
||||
} catch (error) {
|
||||
const message = 'Error refreshing OAuth tokens';
|
||||
throw new Error(
|
||||
logAxiosError({
|
||||
message,
|
||||
error,
|
||||
}),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Handles the OAuth callback and exchanges the authorization code for tokens.
|
||||
* @param {object} fields
|
||||
* @param {string} fields.code - The authorization code returned by the provider.
|
||||
* @param {string} fields.userId - The ID of the user.
|
||||
* @param {string} fields.identifier - The identifier for the token.
|
||||
* @param {string} fields.client_url - The URL of the OAuth provider.
|
||||
* @param {string} fields.redirect_uri - The redirect URI for the OAuth provider.
|
||||
* @param {string} fields.encrypted_oauth_client_id - The client ID for the OAuth provider.
|
||||
* @param {string} fields.encrypted_oauth_client_secret - The client secret for the OAuth provider.
|
||||
* @returns {Promise<{
|
||||
* access_token: string,
|
||||
* expires_in: number,
|
||||
* refresh_token?: string,
|
||||
* refresh_token_expires_in?: number,
|
||||
* }>}
|
||||
*/
|
||||
const getAccessToken = async ({
|
||||
code,
|
||||
userId,
|
||||
identifier,
|
||||
client_url,
|
||||
redirect_uri,
|
||||
encrypted_oauth_client_id,
|
||||
encrypted_oauth_client_secret,
|
||||
}) => {
|
||||
const oauth_client_id = await decryptV2(encrypted_oauth_client_id);
|
||||
const oauth_client_secret = await decryptV2(encrypted_oauth_client_secret);
|
||||
const params = new URLSearchParams({
|
||||
code,
|
||||
client_id: oauth_client_id,
|
||||
client_secret: oauth_client_secret,
|
||||
grant_type: 'authorization_code',
|
||||
redirect_uri,
|
||||
});
|
||||
|
||||
try {
|
||||
const response = await axios({
|
||||
method: 'POST',
|
||||
url: client_url,
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Accept: 'application/json',
|
||||
},
|
||||
data: params.toString(),
|
||||
});
|
||||
|
||||
await processAccessTokens(response.data, {
|
||||
userId,
|
||||
identifier,
|
||||
});
|
||||
logger.debug(`Access tokens successfully created for ${identifier}`);
|
||||
return response.data;
|
||||
} catch (error) {
|
||||
const message = 'Error exchanging OAuth code';
|
||||
throw new Error(
|
||||
logAxiosError({
|
||||
message,
|
||||
error,
|
||||
}),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getAccessToken,
|
||||
refreshAccessToken,
|
||||
};
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
class Tokenizer {
|
||||
constructor() {
|
||||
this.tokenizersCache = {};
|
||||
this.tokenizerCallsCount = 0;
|
||||
}
|
||||
|
||||
getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
|
||||
let tokenizer;
|
||||
if (this.tokenizersCache[encoding]) {
|
||||
tokenizer = this.tokenizersCache[encoding];
|
||||
} else {
|
||||
if (isModelName) {
|
||||
tokenizer = encodingForModel(encoding, extendSpecialTokens);
|
||||
} else {
|
||||
tokenizer = getEncoding(encoding, extendSpecialTokens);
|
||||
}
|
||||
this.tokenizersCache[encoding] = tokenizer;
|
||||
}
|
||||
return tokenizer;
|
||||
}
|
||||
|
||||
freeAndResetAllEncoders() {
|
||||
try {
|
||||
Object.keys(this.tokenizersCache).forEach((key) => {
|
||||
if (this.tokenizersCache[key]) {
|
||||
this.tokenizersCache[key].free();
|
||||
delete this.tokenizersCache[key];
|
||||
}
|
||||
});
|
||||
this.tokenizerCallsCount = 1;
|
||||
} catch (error) {
|
||||
logger.error('[Tokenizer] Free and reset encoders error', error);
|
||||
}
|
||||
}
|
||||
|
||||
resetTokenizersIfNecessary() {
|
||||
if (this.tokenizerCallsCount >= 25) {
|
||||
if (this.options?.debug) {
|
||||
logger.debug('[Tokenizer] freeAndResetAllEncoders: reached 25 encodings, resetting...');
|
||||
}
|
||||
this.freeAndResetAllEncoders();
|
||||
}
|
||||
this.tokenizerCallsCount++;
|
||||
}
|
||||
|
||||
getTokenCount(text, encoding = 'cl100k_base') {
|
||||
this.resetTokenizersIfNecessary();
|
||||
try {
|
||||
const tokenizer = this.getTokenizer(encoding);
|
||||
return tokenizer.encode(text, 'all').length;
|
||||
} catch (error) {
|
||||
this.freeAndResetAllEncoders();
|
||||
const tokenizer = this.getTokenizer(encoding);
|
||||
return tokenizer.encode(text, 'all').length;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const TokenizerSingleton = new Tokenizer();
|
||||
|
||||
module.exports = TokenizerSingleton;
|
||||
|
|
@ -1,136 +0,0 @@
|
|||
/**
|
||||
* @file Tokenizer.spec.cjs
|
||||
*
|
||||
* Tests the real TokenizerSingleton (no mocking of `tiktoken`).
|
||||
* Make sure to install `tiktoken` and have it configured properly.
|
||||
*/
|
||||
|
||||
const Tokenizer = require('./Tokenizer'); // <-- Adjust path to your singleton file
|
||||
const { logger } = require('~/config');
|
||||
|
||||
describe('Tokenizer', () => {
|
||||
it('should be a singleton (same instance)', () => {
|
||||
const AnotherTokenizer = require('./Tokenizer'); // same path
|
||||
expect(Tokenizer).toBe(AnotherTokenizer);
|
||||
});
|
||||
|
||||
describe('getTokenizer', () => {
|
||||
it('should create an encoder for an explicit model name (e.g., "gpt-4")', () => {
|
||||
// The real `encoding_for_model` will be called internally
|
||||
// as soon as we pass isModelName = true.
|
||||
const tokenizer = Tokenizer.getTokenizer('gpt-4', true);
|
||||
|
||||
// Basic sanity checks
|
||||
expect(tokenizer).toBeDefined();
|
||||
// You can optionally check certain properties from `tiktoken` if they exist
|
||||
// e.g., expect(typeof tokenizer.encode).toBe('function');
|
||||
});
|
||||
|
||||
it('should create an encoder for a known encoding (e.g., "cl100k_base")', () => {
|
||||
// The real `get_encoding` will be called internally
|
||||
// as soon as we pass isModelName = false.
|
||||
const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
|
||||
expect(tokenizer).toBeDefined();
|
||||
// e.g., expect(typeof tokenizer.encode).toBe('function');
|
||||
});
|
||||
|
||||
it('should return cached tokenizer if previously fetched', () => {
|
||||
const tokenizer1 = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
const tokenizer2 = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
// Should be the exact same instance from the cache
|
||||
expect(tokenizer1).toBe(tokenizer2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('freeAndResetAllEncoders', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should free all encoders and reset tokenizerCallsCount to 1', () => {
|
||||
// By creating two different encodings, we populate the cache
|
||||
Tokenizer.getTokenizer('cl100k_base', false);
|
||||
Tokenizer.getTokenizer('r50k_base', false);
|
||||
|
||||
// Now free them
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
|
||||
// The internal cache is cleared
|
||||
expect(Tokenizer.tokenizersCache['cl100k_base']).toBeUndefined();
|
||||
expect(Tokenizer.tokenizersCache['r50k_base']).toBeUndefined();
|
||||
|
||||
// tokenizerCallsCount is reset to 1
|
||||
expect(Tokenizer.tokenizerCallsCount).toBe(1);
|
||||
});
|
||||
|
||||
it('should catch and log errors if freeing fails', () => {
|
||||
// Mock logger.error before the test
|
||||
const mockLoggerError = jest.spyOn(logger, 'error');
|
||||
|
||||
// Set up a problematic tokenizer in the cache
|
||||
Tokenizer.tokenizersCache['cl100k_base'] = {
|
||||
free() {
|
||||
throw new Error('Intentional free error');
|
||||
},
|
||||
};
|
||||
|
||||
// Should not throw uncaught errors
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
|
||||
// Verify logger.error was called with correct arguments
|
||||
expect(mockLoggerError).toHaveBeenCalledWith(
|
||||
'[Tokenizer] Free and reset encoders error',
|
||||
expect.any(Error),
|
||||
);
|
||||
|
||||
// Clean up
|
||||
mockLoggerError.mockRestore();
|
||||
Tokenizer.tokenizersCache = {};
|
||||
});
|
||||
});
|
||||
|
||||
describe('getTokenCount', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
});
|
||||
|
||||
it('should return the number of tokens in the given text', () => {
|
||||
const text = 'Hello, world!';
|
||||
const count = Tokenizer.getTokenCount(text, 'cl100k_base');
|
||||
expect(count).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('should reset encoders if an error is thrown', () => {
|
||||
// We can simulate an error by temporarily overriding the selected tokenizer’s `encode` method.
|
||||
const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
const originalEncode = tokenizer.encode;
|
||||
tokenizer.encode = () => {
|
||||
throw new Error('Forced error');
|
||||
};
|
||||
|
||||
// Despite the forced error, the code should catch and reset, then re-encode
|
||||
const count = Tokenizer.getTokenCount('Hello again', 'cl100k_base');
|
||||
expect(count).toBeGreaterThan(0);
|
||||
|
||||
// Restore the original encode
|
||||
tokenizer.encode = originalEncode;
|
||||
});
|
||||
|
||||
it('should reset tokenizers after 25 calls', () => {
|
||||
// Spy on freeAndResetAllEncoders
|
||||
const resetSpy = jest.spyOn(Tokenizer, 'freeAndResetAllEncoders');
|
||||
|
||||
// Make 24 calls; should NOT reset yet
|
||||
for (let i = 0; i < 24; i++) {
|
||||
Tokenizer.getTokenCount('test text', 'cl100k_base');
|
||||
}
|
||||
expect(resetSpy).not.toHaveBeenCalled();
|
||||
|
||||
// 25th call triggers the reset
|
||||
Tokenizer.getTokenCount('the 25th call!', 'cl100k_base');
|
||||
expect(resetSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,5 +1,7 @@
|
|||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { zodToJsonSchema } = require('zod-to-json-schema');
|
||||
const { Calculator } = require('@langchain/community/tools/calculator');
|
||||
const { tool: toolFn, Tool, DynamicStructuredTool } = require('@langchain/core/tools');
|
||||
|
|
@ -31,14 +33,12 @@ const {
|
|||
toolkits,
|
||||
} = require('~/app/clients/tools');
|
||||
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
|
||||
const { getEndpointsConfig, getCachedTools } = require('~/server/services/Config');
|
||||
const { createOnSearchResults } = require('~/server/services/Tools/search');
|
||||
const { isActionDomainAllowed } = require('~/server/services/domains');
|
||||
const { getEndpointsConfig } = require('~/server/services/Config');
|
||||
const { recordUsage } = require('~/server/services/Threads');
|
||||
const { loadTools } = require('~/app/clients/tools/util');
|
||||
const { redactMessage } = require('~/config/parsers');
|
||||
const { sleep } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* @param {string} toolName
|
||||
|
|
@ -226,7 +226,7 @@ async function processRequiredActions(client, requiredActions) {
|
|||
`[required actions] user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`,
|
||||
requiredActions,
|
||||
);
|
||||
const toolDefinitions = client.req.app.locals.availableTools;
|
||||
const toolDefinitions = await getCachedTools({ includeGlobal: true });
|
||||
const seenToolkits = new Set();
|
||||
const tools = requiredActions
|
||||
.map((action) => {
|
||||
|
|
@ -500,6 +500,8 @@ async function processRequiredActions(client, requiredActions) {
|
|||
async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) {
|
||||
if (!agent.tools || agent.tools.length === 0) {
|
||||
return {};
|
||||
} else if (agent.tools && agent.tools.length === 1 && agent.tools[0] === AgentCapabilities.ocr) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const endpointsConfig = await getEndpointsConfig(req);
|
||||
|
|
@ -551,6 +553,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
|
|||
tools: _agentTools,
|
||||
options: {
|
||||
req,
|
||||
res,
|
||||
openAIApiKey,
|
||||
tool_resources,
|
||||
processFileURL,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { encrypt, decrypt } = require('@librechat/api');
|
||||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { encrypt, decrypt } = require('~/server/utils/crypto');
|
||||
const { updateUser } = require('~/models');
|
||||
const { Key } = require('~/db/models');
|
||||
|
||||
|
|
@ -70,6 +70,7 @@ const getUserKeyValues = async ({ userId, name }) => {
|
|||
try {
|
||||
userValues = JSON.parse(userValues);
|
||||
} catch (e) {
|
||||
logger.error('[getUserKeyValues]', e);
|
||||
throw new Error(
|
||||
JSON.stringify({
|
||||
type: ErrorTypes.INVALID_USER_KEY,
|
||||
|
|
|
|||
54
api/server/services/initializeMCP.js
Normal file
54
api/server/services/initializeMCP.js
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, processMCPEnv } = require('librechat-data-provider');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { getCachedTools, setCachedTools } = require('./Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||
|
||||
/**
|
||||
* Initialize MCP servers
|
||||
* @param {import('express').Application} app - Express app instance
|
||||
*/
|
||||
async function initializeMCP(app) {
|
||||
const mcpServers = app.locals.mcpConfig;
|
||||
if (!mcpServers) {
|
||||
return;
|
||||
}
|
||||
|
||||
logger.info('Initializing MCP servers...');
|
||||
const mcpManager = getMCPManager();
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = flowsCache ? getFlowStateManager(flowsCache) : null;
|
||||
|
||||
try {
|
||||
await mcpManager.initializeMCP({
|
||||
mcpServers,
|
||||
flowManager,
|
||||
tokenMethods: {
|
||||
findToken,
|
||||
updateToken,
|
||||
createToken,
|
||||
deleteTokens,
|
||||
},
|
||||
processMCPEnv,
|
||||
});
|
||||
|
||||
delete app.locals.mcpConfig;
|
||||
const availableTools = await getCachedTools();
|
||||
|
||||
if (!availableTools) {
|
||||
logger.warn('No available tools found in cache during MCP initialization');
|
||||
return;
|
||||
}
|
||||
|
||||
const toolsCopy = { ...availableTools };
|
||||
await mcpManager.mapAvailableTools(toolsCopy, flowManager);
|
||||
await setCachedTools(toolsCopy, { isGlobal: true });
|
||||
|
||||
logger.info('MCP servers initialized successfully');
|
||||
} catch (error) {
|
||||
logger.error('Failed to initialize MCP servers:', error);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = initializeMCP;
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
const { EModelEndpoint, agentsEndpointSChema } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
* Sets up the Agents configuration from the config (`librechat.yaml`) file.
|
||||
* @param {TCustomConfig} config - The loaded custom configuration.
|
||||
* @returns {Partial<TAgentsEndpoint>} The Agents endpoint configuration.
|
||||
*/
|
||||
function agentsConfigSetup(config) {
|
||||
const agentsConfig = config.endpoints[EModelEndpoint.agents];
|
||||
const parsedConfig = agentsEndpointSChema.parse(agentsConfig);
|
||||
return parsedConfig;
|
||||
}
|
||||
|
||||
module.exports = { agentsConfigSetup };
|
||||
|
|
@ -2,6 +2,7 @@ const {
|
|||
SystemRoles,
|
||||
Permissions,
|
||||
PermissionTypes,
|
||||
isMemoryEnabled,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const { updateAccessPermissions } = require('~/models/Role');
|
||||
|
|
@ -20,6 +21,14 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol
|
|||
const hasModelSpecs = config?.modelSpecs?.list?.length > 0;
|
||||
const includesAddedEndpoints = config?.modelSpecs?.addedEndpoints?.length > 0;
|
||||
|
||||
const memoryConfig = config?.memory;
|
||||
const memoryEnabled = isMemoryEnabled(memoryConfig);
|
||||
/** Only disable memories if memory config is present but disabled/invalid */
|
||||
const shouldDisableMemories = memoryConfig && !memoryEnabled;
|
||||
/** Check if personalization is enabled (defaults to true if memory is configured and enabled) */
|
||||
const isPersonalizationEnabled =
|
||||
memoryConfig && memoryEnabled && memoryConfig.personalize !== false;
|
||||
|
||||
/** @type {TCustomConfig['interface']} */
|
||||
const loadedInterface = removeNullishValues({
|
||||
endpointsMenu:
|
||||
|
|
@ -33,6 +42,7 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol
|
|||
privacyPolicy: interfaceConfig?.privacyPolicy ?? defaults.privacyPolicy,
|
||||
termsOfService: interfaceConfig?.termsOfService ?? defaults.termsOfService,
|
||||
bookmarks: interfaceConfig?.bookmarks ?? defaults.bookmarks,
|
||||
memories: shouldDisableMemories ? false : (interfaceConfig?.memories ?? defaults.memories),
|
||||
prompts: interfaceConfig?.prompts ?? defaults.prompts,
|
||||
multiConvo: interfaceConfig?.multiConvo ?? defaults.multiConvo,
|
||||
agents: interfaceConfig?.agents ?? defaults.agents,
|
||||
|
|
@ -45,6 +55,10 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol
|
|||
await updateAccessPermissions(roleName, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: loadedInterface.prompts },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks },
|
||||
[PermissionTypes.MEMORIES]: {
|
||||
[Permissions.USE]: loadedInterface.memories,
|
||||
[Permissions.OPT_OUT]: isPersonalizationEnabled,
|
||||
},
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat },
|
||||
|
|
@ -54,6 +68,10 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol
|
|||
await updateAccessPermissions(SystemRoles.ADMIN, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: loadedInterface.prompts },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks },
|
||||
[PermissionTypes.MEMORIES]: {
|
||||
[Permissions.USE]: loadedInterface.memories,
|
||||
[Permissions.OPT_OUT]: isPersonalizationEnabled,
|
||||
},
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat },
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ describe('loadDefaultInterface', () => {
|
|||
interface: {
|
||||
prompts: true,
|
||||
bookmarks: true,
|
||||
memories: true,
|
||||
multiConvo: true,
|
||||
agents: true,
|
||||
temporaryChat: true,
|
||||
|
|
@ -26,6 +27,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true },
|
||||
|
|
@ -39,6 +41,7 @@ describe('loadDefaultInterface', () => {
|
|||
interface: {
|
||||
prompts: false,
|
||||
bookmarks: false,
|
||||
memories: false,
|
||||
multiConvo: false,
|
||||
agents: false,
|
||||
temporaryChat: false,
|
||||
|
|
@ -53,6 +56,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: false },
|
||||
|
|
@ -70,6 +74,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
|
|
@ -83,6 +88,7 @@ describe('loadDefaultInterface', () => {
|
|||
interface: {
|
||||
prompts: undefined,
|
||||
bookmarks: undefined,
|
||||
memories: undefined,
|
||||
multiConvo: undefined,
|
||||
agents: undefined,
|
||||
temporaryChat: undefined,
|
||||
|
|
@ -97,6 +103,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
|
|
@ -110,6 +117,7 @@ describe('loadDefaultInterface', () => {
|
|||
interface: {
|
||||
prompts: true,
|
||||
bookmarks: false,
|
||||
memories: true,
|
||||
multiConvo: undefined,
|
||||
agents: true,
|
||||
temporaryChat: undefined,
|
||||
|
|
@ -124,6 +132,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
|
|
@ -138,6 +147,7 @@ describe('loadDefaultInterface', () => {
|
|||
interface: {
|
||||
prompts: true,
|
||||
bookmarks: true,
|
||||
memories: true,
|
||||
multiConvo: true,
|
||||
agents: true,
|
||||
temporaryChat: true,
|
||||
|
|
@ -151,6 +161,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true },
|
||||
|
|
@ -168,6 +179,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
|
|
@ -185,6 +197,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
|
|
@ -202,6 +215,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
|
|
@ -215,6 +229,7 @@ describe('loadDefaultInterface', () => {
|
|||
interface: {
|
||||
prompts: true,
|
||||
bookmarks: false,
|
||||
memories: true,
|
||||
multiConvo: true,
|
||||
agents: false,
|
||||
temporaryChat: true,
|
||||
|
|
@ -228,6 +243,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true },
|
||||
|
|
@ -242,6 +258,7 @@ describe('loadDefaultInterface', () => {
|
|||
interface: {
|
||||
prompts: true,
|
||||
bookmarks: true,
|
||||
memories: false,
|
||||
multiConvo: false,
|
||||
agents: undefined,
|
||||
temporaryChat: undefined,
|
||||
|
|
@ -255,6 +272,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
|
|
@ -268,6 +286,7 @@ describe('loadDefaultInterface', () => {
|
|||
interface: {
|
||||
prompts: true,
|
||||
bookmarks: false,
|
||||
memories: true,
|
||||
multiConvo: true,
|
||||
agents: false,
|
||||
temporaryChat: true,
|
||||
|
|
@ -281,6 +300,7 @@ describe('loadDefaultInterface', () => {
|
|||
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.MEMORIES]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true },
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const { webcrypto } = require('node:crypto');
|
||||
const { hashBackupCode, decryptV3, decryptV2 } = require('~/server/utils/crypto');
|
||||
const { hashBackupCode, decryptV3, decryptV2 } = require('@librechat/api');
|
||||
const { updateUser } = require('~/models');
|
||||
|
||||
// Base32 alphabet for TOTP secret encoding.
|
||||
|
|
|
|||
|
|
@ -1,140 +0,0 @@
|
|||
require('dotenv').config();
|
||||
const crypto = require('node:crypto');
|
||||
const { webcrypto } = crypto;
|
||||
|
||||
// Use hex decoding for both key and IV for legacy methods.
|
||||
const key = Buffer.from(process.env.CREDS_KEY, 'hex');
|
||||
const iv = Buffer.from(process.env.CREDS_IV, 'hex');
|
||||
const algorithm = 'AES-CBC';
|
||||
|
||||
// --- Legacy v1/v2 Setup: AES-CBC with fixed key and IV ---
|
||||
|
||||
async function encrypt(value) {
|
||||
const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [
|
||||
'encrypt',
|
||||
]);
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(value);
|
||||
const encryptedBuffer = await webcrypto.subtle.encrypt(
|
||||
{ name: algorithm, iv: iv },
|
||||
cryptoKey,
|
||||
data,
|
||||
);
|
||||
return Buffer.from(encryptedBuffer).toString('hex');
|
||||
}
|
||||
|
||||
async function decrypt(encryptedValue) {
|
||||
const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [
|
||||
'decrypt',
|
||||
]);
|
||||
const encryptedBuffer = Buffer.from(encryptedValue, 'hex');
|
||||
const decryptedBuffer = await webcrypto.subtle.decrypt(
|
||||
{ name: algorithm, iv: iv },
|
||||
cryptoKey,
|
||||
encryptedBuffer,
|
||||
);
|
||||
const decoder = new TextDecoder();
|
||||
return decoder.decode(decryptedBuffer);
|
||||
}
|
||||
|
||||
// --- v2: AES-CBC with a random IV per encryption ---
|
||||
|
||||
async function encryptV2(value) {
|
||||
const gen_iv = webcrypto.getRandomValues(new Uint8Array(16));
|
||||
const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [
|
||||
'encrypt',
|
||||
]);
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(value);
|
||||
const encryptedBuffer = await webcrypto.subtle.encrypt(
|
||||
{ name: algorithm, iv: gen_iv },
|
||||
cryptoKey,
|
||||
data,
|
||||
);
|
||||
return Buffer.from(gen_iv).toString('hex') + ':' + Buffer.from(encryptedBuffer).toString('hex');
|
||||
}
|
||||
|
||||
async function decryptV2(encryptedValue) {
|
||||
const parts = encryptedValue.split(':');
|
||||
if (parts.length === 1) {
|
||||
return parts[0];
|
||||
}
|
||||
const gen_iv = Buffer.from(parts.shift(), 'hex');
|
||||
const encrypted = parts.join(':');
|
||||
const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [
|
||||
'decrypt',
|
||||
]);
|
||||
const encryptedBuffer = Buffer.from(encrypted, 'hex');
|
||||
const decryptedBuffer = await webcrypto.subtle.decrypt(
|
||||
{ name: algorithm, iv: gen_iv },
|
||||
cryptoKey,
|
||||
encryptedBuffer,
|
||||
);
|
||||
const decoder = new TextDecoder();
|
||||
return decoder.decode(decryptedBuffer);
|
||||
}
|
||||
|
||||
// --- v3: AES-256-CTR using Node's crypto functions ---
|
||||
const algorithm_v3 = 'aes-256-ctr';
|
||||
|
||||
/**
|
||||
* Encrypts a value using AES-256-CTR.
|
||||
* Note: AES-256 requires a 32-byte key. Ensure that process.env.CREDS_KEY is a 64-character hex string.
|
||||
*
|
||||
* @param {string} value - The plaintext to encrypt.
|
||||
* @returns {string} The encrypted string with a "v3:" prefix.
|
||||
*/
|
||||
function encryptV3(value) {
|
||||
if (key.length !== 32) {
|
||||
throw new Error(`Invalid key length: expected 32 bytes, got ${key.length} bytes`);
|
||||
}
|
||||
const iv_v3 = crypto.randomBytes(16);
|
||||
const cipher = crypto.createCipheriv(algorithm_v3, key, iv_v3);
|
||||
const encrypted = Buffer.concat([cipher.update(value, 'utf8'), cipher.final()]);
|
||||
return `v3:${iv_v3.toString('hex')}:${encrypted.toString('hex')}`;
|
||||
}
|
||||
|
||||
function decryptV3(encryptedValue) {
|
||||
const parts = encryptedValue.split(':');
|
||||
if (parts[0] !== 'v3') {
|
||||
throw new Error('Not a v3 encrypted value');
|
||||
}
|
||||
const iv_v3 = Buffer.from(parts[1], 'hex');
|
||||
const encryptedText = Buffer.from(parts.slice(2).join(':'), 'hex');
|
||||
const decipher = crypto.createDecipheriv(algorithm_v3, key, iv_v3);
|
||||
const decrypted = Buffer.concat([decipher.update(encryptedText), decipher.final()]);
|
||||
return decrypted.toString('utf8');
|
||||
}
|
||||
|
||||
async function getRandomValues(length) {
|
||||
if (!Number.isInteger(length) || length <= 0) {
|
||||
throw new Error('Length must be a positive integer');
|
||||
}
|
||||
const randomValues = new Uint8Array(length);
|
||||
webcrypto.getRandomValues(randomValues);
|
||||
return Buffer.from(randomValues).toString('hex');
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes SHA-256 hash for the given input.
|
||||
* @param {string} input
|
||||
* @returns {Promise<string>}
|
||||
*/
|
||||
async function hashBackupCode(input) {
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(input);
|
||||
const hashBuffer = await webcrypto.subtle.digest('SHA-256', data);
|
||||
const hashArray = Array.from(new Uint8Array(hashBuffer));
|
||||
return hashArray.map((b) => b.toString(16).padStart(2, '0')).join('');
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
encrypt,
|
||||
decrypt,
|
||||
encryptV2,
|
||||
decryptV2,
|
||||
encryptV3,
|
||||
decryptV3,
|
||||
hashBackupCode,
|
||||
getRandomValues,
|
||||
};
|
||||
|
|
@ -1,5 +1,3 @@
|
|||
const path = require('path');
|
||||
const crypto = require('crypto');
|
||||
const {
|
||||
Capabilities,
|
||||
EModelEndpoint,
|
||||
|
|
@ -218,38 +216,6 @@ function normalizeEndpointName(name = '') {
|
|||
return name.toLowerCase() === Providers.OLLAMA ? Providers.OLLAMA : name;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitize a filename by removing any directory components, replacing non-alphanumeric characters
|
||||
* @param {string} inputName
|
||||
* @returns {string}
|
||||
*/
|
||||
function sanitizeFilename(inputName) {
|
||||
// Remove any directory components
|
||||
let name = path.basename(inputName);
|
||||
|
||||
// Replace any non-alphanumeric characters except for '.' and '-'
|
||||
name = name.replace(/[^a-zA-Z0-9.-]/g, '_');
|
||||
|
||||
// Ensure the name doesn't start with a dot (hidden file in Unix-like systems)
|
||||
if (name.startsWith('.') || name === '') {
|
||||
name = '_' + name;
|
||||
}
|
||||
|
||||
// Limit the length of the filename
|
||||
const MAX_LENGTH = 255;
|
||||
if (name.length > MAX_LENGTH) {
|
||||
const ext = path.extname(name);
|
||||
const nameWithoutExt = path.basename(name, ext);
|
||||
name =
|
||||
nameWithoutExt.slice(0, MAX_LENGTH - ext.length - 7) +
|
||||
'-' +
|
||||
crypto.randomBytes(3).toString('hex') +
|
||||
ext;
|
||||
}
|
||||
|
||||
return name;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
isEnabled,
|
||||
handleText,
|
||||
|
|
@ -260,6 +226,5 @@ module.exports = {
|
|||
generateConfig,
|
||||
addSpaceIfNeeded,
|
||||
createOnProgress,
|
||||
sanitizeFilename,
|
||||
normalizeEndpointName,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,103 +0,0 @@
|
|||
const { isEnabled, sanitizeFilename } = require('./handleText');
|
||||
|
||||
describe('isEnabled', () => {
|
||||
test('should return true when input is "true"', () => {
|
||||
expect(isEnabled('true')).toBe(true);
|
||||
});
|
||||
|
||||
test('should return true when input is "TRUE"', () => {
|
||||
expect(isEnabled('TRUE')).toBe(true);
|
||||
});
|
||||
|
||||
test('should return true when input is true', () => {
|
||||
expect(isEnabled(true)).toBe(true);
|
||||
});
|
||||
|
||||
test('should return false when input is "false"', () => {
|
||||
expect(isEnabled('false')).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is false', () => {
|
||||
expect(isEnabled(false)).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is null', () => {
|
||||
expect(isEnabled(null)).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is undefined', () => {
|
||||
expect(isEnabled()).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is an empty string', () => {
|
||||
expect(isEnabled('')).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is a whitespace string', () => {
|
||||
expect(isEnabled(' ')).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is a number', () => {
|
||||
expect(isEnabled(123)).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is an object', () => {
|
||||
expect(isEnabled({})).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false when input is an array', () => {
|
||||
expect(isEnabled([])).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
jest.mock('crypto', () => {
|
||||
const actualModule = jest.requireActual('crypto');
|
||||
return {
|
||||
...actualModule,
|
||||
randomBytes: jest.fn().mockReturnValue(Buffer.from('abc123', 'hex')),
|
||||
};
|
||||
});
|
||||
|
||||
describe('sanitizeFilename', () => {
|
||||
test('removes directory components (1/2)', () => {
|
||||
expect(sanitizeFilename('/path/to/file.txt')).toBe('file.txt');
|
||||
});
|
||||
|
||||
test('removes directory components (2/2)', () => {
|
||||
expect(sanitizeFilename('../../../../file.txt')).toBe('file.txt');
|
||||
});
|
||||
|
||||
test('replaces non-alphanumeric characters', () => {
|
||||
expect(sanitizeFilename('file name@#$.txt')).toBe('file_name___.txt');
|
||||
});
|
||||
|
||||
test('preserves dots and hyphens', () => {
|
||||
expect(sanitizeFilename('file-name.with.dots.txt')).toBe('file-name.with.dots.txt');
|
||||
});
|
||||
|
||||
test('prepends underscore to filenames starting with a dot', () => {
|
||||
expect(sanitizeFilename('.hiddenfile')).toBe('_.hiddenfile');
|
||||
});
|
||||
|
||||
test('truncates long filenames', () => {
|
||||
const longName = 'a'.repeat(300) + '.txt';
|
||||
const result = sanitizeFilename(longName);
|
||||
expect(result.length).toBe(255);
|
||||
expect(result).toMatch(/^a+-abc123\.txt$/);
|
||||
});
|
||||
|
||||
test('handles filenames with no extension', () => {
|
||||
const longName = 'a'.repeat(300);
|
||||
const result = sanitizeFilename(longName);
|
||||
expect(result.length).toBe(255);
|
||||
expect(result).toMatch(/^a+-abc123$/);
|
||||
});
|
||||
|
||||
test('handles empty input', () => {
|
||||
expect(sanitizeFilename('')).toBe('_');
|
||||
});
|
||||
|
||||
test('handles input with only special characters', () => {
|
||||
expect(sanitizeFilename('@#$%^&*')).toBe('_______');
|
||||
});
|
||||
});
|
||||
|
|
@ -3,7 +3,6 @@ const removePorts = require('./removePorts');
|
|||
const countTokens = require('./countTokens');
|
||||
const handleText = require('./handleText');
|
||||
const sendEmail = require('./sendEmail');
|
||||
const cryptoUtils = require('./crypto');
|
||||
const queue = require('./queue');
|
||||
const files = require('./files');
|
||||
const math = require('./math');
|
||||
|
|
@ -13,18 +12,24 @@ const math = require('./math');
|
|||
* @returns {Boolean}
|
||||
*/
|
||||
function checkEmailConfig() {
|
||||
return (
|
||||
// Check if Mailgun is configured
|
||||
const hasMailgunConfig =
|
||||
!!process.env.MAILGUN_API_KEY && !!process.env.MAILGUN_DOMAIN && !!process.env.EMAIL_FROM;
|
||||
|
||||
// Check if SMTP is configured
|
||||
const hasSMTPConfig =
|
||||
(!!process.env.EMAIL_SERVICE || !!process.env.EMAIL_HOST) &&
|
||||
!!process.env.EMAIL_USERNAME &&
|
||||
!!process.env.EMAIL_PASSWORD &&
|
||||
!!process.env.EMAIL_FROM
|
||||
);
|
||||
!!process.env.EMAIL_FROM;
|
||||
|
||||
// Return true if either Mailgun or SMTP is properly configured
|
||||
return hasMailgunConfig || hasSMTPConfig;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
...streamResponse,
|
||||
checkEmailConfig,
|
||||
...cryptoUtils,
|
||||
...handleText,
|
||||
countTokens,
|
||||
removePorts,
|
||||
|
|
|
|||
|
|
@ -1,9 +1,69 @@
|
|||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const axios = require('axios');
|
||||
const FormData = require('form-data');
|
||||
const nodemailer = require('nodemailer');
|
||||
const handlebars = require('handlebars');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { isEnabled } = require('~/server/utils/handleText');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
/**
|
||||
* Sends an email using Mailgun API.
|
||||
*
|
||||
* @async
|
||||
* @function sendEmailViaMailgun
|
||||
* @param {Object} params - The parameters for sending the email.
|
||||
* @param {string} params.to - The recipient's email address.
|
||||
* @param {string} params.from - The sender's email address.
|
||||
* @param {string} params.subject - The subject of the email.
|
||||
* @param {string} params.html - The HTML content of the email.
|
||||
* @returns {Promise<Object>} - A promise that resolves to the response from Mailgun API.
|
||||
*/
|
||||
const sendEmailViaMailgun = async ({ to, from, subject, html }) => {
|
||||
const mailgunApiKey = process.env.MAILGUN_API_KEY;
|
||||
const mailgunDomain = process.env.MAILGUN_DOMAIN;
|
||||
const mailgunHost = process.env.MAILGUN_HOST || 'https://api.mailgun.net';
|
||||
|
||||
if (!mailgunApiKey || !mailgunDomain) {
|
||||
throw new Error('Mailgun API key and domain are required');
|
||||
}
|
||||
|
||||
const formData = new FormData();
|
||||
formData.append('from', from);
|
||||
formData.append('to', to);
|
||||
formData.append('subject', subject);
|
||||
formData.append('html', html);
|
||||
formData.append('o:tracking-clicks', 'no');
|
||||
|
||||
try {
|
||||
const response = await axios.post(`${mailgunHost}/v3/${mailgunDomain}/messages`, formData, {
|
||||
headers: {
|
||||
...formData.getHeaders(),
|
||||
Authorization: `Basic ${Buffer.from(`api:${mailgunApiKey}`).toString('base64')}`,
|
||||
},
|
||||
});
|
||||
|
||||
return response.data;
|
||||
} catch (error) {
|
||||
throw new Error(logAxiosError({ error, message: 'Failed to send email via Mailgun' }));
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Sends an email using SMTP via Nodemailer.
|
||||
*
|
||||
* @async
|
||||
* @function sendEmailViaSMTP
|
||||
* @param {Object} params - The parameters for sending the email.
|
||||
* @param {Object} params.transporterOptions - The transporter configuration options.
|
||||
* @param {Object} params.mailOptions - The email options.
|
||||
* @returns {Promise<Object>} - A promise that resolves to the info object of the sent email.
|
||||
*/
|
||||
const sendEmailViaSMTP = async ({ transporterOptions, mailOptions }) => {
|
||||
const transporter = nodemailer.createTransport(transporterOptions);
|
||||
return await transporter.sendMail(mailOptions);
|
||||
};
|
||||
|
||||
/**
|
||||
* Sends an email using the specified template, subject, and payload.
|
||||
|
|
@ -34,6 +94,30 @@ const logger = require('~/config/winston');
|
|||
*/
|
||||
const sendEmail = async ({ email, subject, payload, template, throwError = true }) => {
|
||||
try {
|
||||
// Read and compile the email template
|
||||
const source = fs.readFileSync(path.join(__dirname, 'emails', template), 'utf8');
|
||||
const compiledTemplate = handlebars.compile(source);
|
||||
const html = compiledTemplate(payload);
|
||||
|
||||
// Prepare common email data
|
||||
const fromName = process.env.EMAIL_FROM_NAME || process.env.APP_TITLE;
|
||||
const fromEmail = process.env.EMAIL_FROM;
|
||||
const fromAddress = `"${fromName}" <${fromEmail}>`;
|
||||
const toAddress = `"${payload.name}" <${email}>`;
|
||||
|
||||
// Check if Mailgun is configured
|
||||
if (process.env.MAILGUN_API_KEY && process.env.MAILGUN_DOMAIN) {
|
||||
logger.debug('[sendEmail] Using Mailgun provider');
|
||||
return await sendEmailViaMailgun({
|
||||
from: fromAddress,
|
||||
to: toAddress,
|
||||
subject: subject,
|
||||
html: html,
|
||||
});
|
||||
}
|
||||
|
||||
// Default to SMTP
|
||||
logger.debug('[sendEmail] Using SMTP provider');
|
||||
const transporterOptions = {
|
||||
// Use STARTTLS by default instead of obligatory TLS
|
||||
secure: process.env.EMAIL_ENCRYPTION === 'tls',
|
||||
|
|
@ -62,30 +146,21 @@ const sendEmail = async ({ email, subject, payload, template, throwError = true
|
|||
transporterOptions.port = process.env.EMAIL_PORT ?? 25;
|
||||
}
|
||||
|
||||
const transporter = nodemailer.createTransport(transporterOptions);
|
||||
|
||||
const source = fs.readFileSync(path.join(__dirname, 'emails', template), 'utf8');
|
||||
const compiledTemplate = handlebars.compile(source);
|
||||
const options = () => {
|
||||
return {
|
||||
// Header address should contain name-addr
|
||||
from:
|
||||
`"${process.env.EMAIL_FROM_NAME || process.env.APP_TITLE}"` +
|
||||
`<${process.env.EMAIL_FROM}>`,
|
||||
to: `"${payload.name}" <${email}>`,
|
||||
envelope: {
|
||||
// Envelope from should contain addr-spec
|
||||
// Mistake in the Nodemailer documentation?
|
||||
from: process.env.EMAIL_FROM,
|
||||
to: email,
|
||||
},
|
||||
subject: subject,
|
||||
html: compiledTemplate(payload),
|
||||
};
|
||||
const mailOptions = {
|
||||
// Header address should contain name-addr
|
||||
from: fromAddress,
|
||||
to: toAddress,
|
||||
envelope: {
|
||||
// Envelope from should contain addr-spec
|
||||
// Mistake in the Nodemailer documentation?
|
||||
from: fromEmail,
|
||||
to: email,
|
||||
},
|
||||
subject: subject,
|
||||
html: html,
|
||||
};
|
||||
|
||||
// Send email
|
||||
return await transporter.sendMail(options());
|
||||
return await sendEmailViaSMTP({ transporterOptions, mailOptions });
|
||||
} catch (error) {
|
||||
if (throwError) {
|
||||
throw error;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue