mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-01-28 13:16:13 +01:00
Merge branch 'main' into feat/openid-custom-data
This commit is contained in: commit f0a42d20a2
296 changed files with 9736 additions and 4122 deletions
@@ -4,6 +4,7 @@ const {
  Constants,
  ErrorTypes,
  EModelEndpoint,
  parseTextParts,
  anthropicSettings,
  getResponseSender,
  validateVisionModel,

@@ -696,15 +697,8 @@ class AnthropicClient extends BaseClient {
      if (msg.text != null && msg.text && msg.text.startsWith(':::thinking')) {
        msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim();
      } else if (msg.content != null) {
        /** @type {import('@librechat/agents').MessageContentComplex} */
        const newContent = [];
        for (let part of msg.content) {
          if (part.think != null) {
            continue;
          }
          newContent.push(part);
        }
        msg.content = newContent;
        msg.text = parseTextParts(msg.content, true);
        delete msg.content;
      }

      return msg;
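A minimal, self-contained sketch of the message-mapping behavior this hunk simplifies: reasoning output is either stripped from msg.text via the ':::thinking' fence regex, or filtered out of msg.content parts that carry a think field. Here extractText() is a hypothetical stand-in for librechat-data-provider's parseTextParts().

// Sketch only: mirrors the mapping logic from the hunk above, with a
// hypothetical extractText() standing in for parseTextParts().
function extractText(content) {
  // Assumption: parts shaped like { type: 'text', text: '...' }
  return content
    .filter((part) => part.type === 'text')
    .map((part) => part.text)
    .join(' ');
}

function stripReasoning(msg) {
  if (msg.text != null && msg.text.startsWith(':::thinking')) {
    // Remove the fenced reasoning block, keep the visible answer.
    msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim();
  } else if (msg.content != null) {
    // Drop structured "think" parts, then flatten what remains to text.
    msg.content = msg.content.filter((part) => part.think == null);
    msg.text = extractText(msg.content);
    delete msg.content;
  }
  return msg;
}

// Usage:
console.log(stripReasoning({ text: ':::thinking plan...::: Final answer' }).text);
// -> 'Final answer'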
@@ -676,7 +676,8 @@ class BaseClient {
      responseMessage.text = addSpaceIfNeeded(generation) + completion;
    } else if (
      Array.isArray(completion) &&
      isParamEndpoint(this.options.endpoint, this.options.endpointType)
      (this.clientName === EModelEndpoint.agents ||
        isParamEndpoint(this.options.endpoint, this.options.endpointType))
    ) {
      responseMessage.text = '';
      responseMessage.content = completion;

@@ -879,13 +880,14 @@ class BaseClient {
        : await getConvo(this.options.req?.user?.id, message.conversationId);

    const unsetFields = {};
    const exceptions = new Set(['spec', 'iconURL']);
    if (existingConvo != null) {
      this.fetchedConvo = true;
      for (const key in existingConvo) {
        if (!key) {
          continue;
        }
        if (excludedKeys.has(key)) {
        if (excludedKeys.has(key) && !exceptions.has(key)) {
          continue;
        }
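The second BaseClient hunk changes a plain exclusion check into an exclusion-with-exceptions check: keys in excludedKeys are still skipped, unless they also appear in a small exceptions set ('spec', 'iconURL'), which are now allowed through. A sketch of that filtering pattern; the contents of excludedKeys below are illustrative, only the two exception keys come from the diff:

// Sketch of the skip-unless-excepted pattern from the hunk above.
const excludedKeys = new Set(['messages', 'title', 'spec', 'iconURL']); // illustrative contents
const exceptions = new Set(['spec', 'iconURL']);

function fieldsToSync(existingConvo) {
  const result = {};
  for (const key in existingConvo) {
    if (!key) {
      continue;
    }
    // Skip excluded keys, but let the excepted ones through.
    if (excludedKeys.has(key) && !exceptions.has(key)) {
      continue;
    }
    result[key] = existingConvo[key];
  }
  return result;
}

console.log(fieldsToSync({ title: 'Chat', spec: 'gpt-custom', model: 'gpt-4o' }));
// -> { spec: 'gpt-custom', model: 'gpt-4o' }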
@@ -9,6 +9,7 @@ const {
  validateVisionModel,
  getResponseSender,
  endpointSettings,
  parseTextParts,
  EModelEndpoint,
  ContentTypes,
  VisionModes,

@@ -198,7 +199,11 @@ class GoogleClient extends BaseClient {
   */
  checkVisionRequest(attachments) {
    /* Validation vision request */
    this.defaultVisionModel = this.options.visionModel ?? 'gemini-pro-vision';
    this.defaultVisionModel =
      this.options.visionModel ??
      (!EXCLUDED_GENAI_MODELS.test(this.modelOptions.model)
        ? this.modelOptions.model
        : 'gemini-pro-vision');
    const availableModels = this.options.modelsConfig?.[EModelEndpoint.google];
    this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });

@@ -770,6 +775,22 @@ class GoogleClient extends BaseClient {
    return this.usage;
  }

  getMessageMapMethod() {
    /**
     * @param {TMessage} msg
     */
    return (msg) => {
      if (msg.text != null && msg.text && msg.text.startsWith(':::thinking')) {
        msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim();
      } else if (msg.content != null) {
        msg.text = parseTextParts(msg.content, true);
        delete msg.content;
      }

      return msg;
    };
  }

  /**
   * Calculates the correct token count for the current user message based on the token count map and API usage.
   * Edge case: If the calculation results in a negative value, it returns the original estimate.
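The new default-vision-model logic prefers an explicitly configured vision model, then the currently selected model when it is not excluded by a regex, and only then falls back to 'gemini-pro-vision'. A sketch of that fallback chain; EXCLUDED_GENAI_MODELS below is an illustrative regex, not the project's exact one:

// Sketch of the fallback chain in checkVisionRequest().
const EXCLUDED_GENAI_MODELS = /gemma|text-bison/i; // illustrative

function pickDefaultVisionModel({ visionModel, currentModel }) {
  return (
    visionModel ??
    (!EXCLUDED_GENAI_MODELS.test(currentModel) ? currentModel : 'gemini-pro-vision')
  );
}

console.log(pickDefaultVisionModel({ currentModel: 'gemini-1.5-pro' })); // 'gemini-1.5-pro'
console.log(pickDefaultVisionModel({ currentModel: 'gemma-7b' })); // 'gemini-pro-vision'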
@@ -6,6 +6,7 @@ const {
  Constants,
  ImageDetail,
  ContentTypes,
  parseTextParts,
  EModelEndpoint,
  resolveHeaders,
  KnownEndpoints,

@@ -226,10 +227,6 @@ class OpenAIClient extends BaseClient {
      logger.debug('Using Azure endpoint');
    }

    if (this.useOpenRouter) {
      this.completionsUrl = 'https://openrouter.ai/api/v1/chat/completions';
    }

    return this;
  }

@@ -1125,15 +1122,8 @@ ${convo}
      if (msg.text != null && msg.text && msg.text.startsWith(':::thinking')) {
        msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim();
      } else if (msg.content != null) {
        /** @type {import('@librechat/agents').MessageContentComplex} */
        const newContent = [];
        for (let part of msg.content) {
          if (part.think != null) {
            continue;
          }
          newContent.push(part);
        }
        msg.content = newContent;
        msg.text = parseTextParts(msg.content, true);
        delete msg.content;
      }

      return msg;
@@ -34,6 +34,7 @@ function createLLM({
  let credentials = { openAIApiKey };
  let configuration = {
    apiKey: openAIApiKey,
    ...(configOptions.basePath && { baseURL: configOptions.basePath }),
  };

  /** @type {AzureOptions} */
api/cache/getLogStores.js (vendored, 5 changes)

@@ -49,6 +49,10 @@ const genTitle = isRedisEnabled
  ? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
  : new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES });

const s3ExpiryInterval = isRedisEnabled
  ? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
  : new Keyv({ namespace: CacheKeys.S3_EXPIRY_INTERVAL, ttl: Time.THIRTY_MINUTES });

const modelQueries = isEnabled(process.env.USE_REDIS)
  ? new Keyv({ store: keyvRedis })
  : new Keyv({ namespace: CacheKeys.MODEL_QUERIES });

@@ -89,6 +93,7 @@ const namespaces = {
  [CacheKeys.ABORT_KEYS]: abortKeys,
  [CacheKeys.TOKEN_CONFIG]: tokenConfig,
  [CacheKeys.GEN_TITLE]: genTitle,
  [CacheKeys.S3_EXPIRY_INTERVAL]: s3ExpiryInterval,
  [CacheKeys.MODEL_QUERIES]: modelQueries,
  [CacheKeys.AUDIO_RUNS]: audioRuns,
  [CacheKeys.MESSAGES]: messages,
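The cache additions follow one pattern throughout this file: back each named cache with Redis when it is enabled, otherwise with an in-memory Keyv namespace, keeping the same TTL either way. A sketch of that pattern; keyv and @keyv/redis are real packages, while the cache name and environment wiring below are illustrative:

const Keyv = require('keyv');
const KeyvRedis = require('@keyv/redis');

const THIRTY_MINUTES = 30 * 60 * 1000;
const isRedisEnabled = process.env.USE_REDIS === 'true';
// Shared Redis store, created once and reused by every cache.
const keyvRedis = isRedisEnabled ? new KeyvRedis(process.env.REDIS_URI) : null;

// Same construction the diff uses for genTitle, s3ExpiryInterval, etc.
const s3ExpiryInterval = isRedisEnabled
  ? new Keyv({ store: keyvRedis, ttl: THIRTY_MINUTES })
  : new Keyv({ namespace: 'S3_EXPIRY_INTERVAL', ttl: THIRTY_MINUTES });

// Usage: entries expire after the TTL regardless of backend.
(async () => {
  await s3ExpiryInterval.set('last-refresh', Date.now());
  console.log(await s3ExpiryInterval.get('last-refresh'));
})();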
@@ -1,6 +1,8 @@
const axios = require('axios');
const { EventSource } = require('eventsource');
const { Time, CacheKeys } = require('librechat-data-provider');
const { MCPManager, FlowStateManager } = require('librechat-mcp');
const logger = require('./winston');

global.EventSource = EventSource;

@@ -9,11 +10,10 @@ let mcpManager = null;
let flowManager = null;

/**
 * @returns {Promise<MCPManager>}
 * @returns {MCPManager}
 */
async function getMCPManager() {
function getMCPManager() {
  if (!mcpManager) {
    const { MCPManager } = await import('librechat-mcp');
    mcpManager = MCPManager.getInstance(logger);
  }
  return mcpManager;

@@ -21,11 +21,10 @@ async function getMCPManager() {

/**
 * @param {(key: string) => Keyv} getLogStores
 * @returns {Promise<FlowStateManager>}
 * @returns {FlowStateManager}
 */
async function getFlowStateManager(getLogStores) {
function getFlowStateManager(getLogStores) {
  if (!flowManager) {
    const { FlowStateManager } = await import('librechat-mcp');
    flowManager = new FlowStateManager(getLogStores(CacheKeys.FLOWS), {
      ttl: Time.ONE_MINUTE * 3,
      logger,
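This change moves the librechat-mcp import to the top of the file, so the singleton getters no longer need await import(...) and become synchronous. The lazy-singleton shape is kept; a sketch with a hypothetical Manager class standing in for MCPManager / FlowStateManager:

// Sketch of the synchronous lazy-singleton pattern after the change.
class Manager {
  static getInstance(logger) {
    logger.info('creating manager');
    return new Manager();
  }
}

let instance = null;

/** @returns {Manager} */
function getManager(logger) {
  if (!instance) {
    // The module was require()'d at the top of the file, so no await is needed here.
    instance = Manager.getInstance(logger);
  }
  return instance;
}

const logger = { info: console.log };
console.log(getManager(logger) === getManager(logger)); // true: same instance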
@@ -4,7 +4,11 @@ require('winston-daily-rotate-file');

const logDir = path.join(__dirname, '..', 'logs');

const { NODE_ENV } = process.env;
const { NODE_ENV, DEBUG_LOGGING = false } = process.env;

const useDebugLogging =
  (typeof DEBUG_LOGGING === 'string' && DEBUG_LOGGING?.toLowerCase() === 'true') ||
  DEBUG_LOGGING === true;

const levels = {
  error: 0,

@@ -36,9 +40,10 @@ const fileFormat = winston.format.combine(
  winston.format.splat(),
);

const logLevel = useDebugLogging ? 'debug' : 'error';
const transports = [
  new winston.transports.DailyRotateFile({
    level: 'debug',
    level: logLevel,
    filename: `${logDir}/meiliSync-%DATE%.log`,
    datePattern: 'YYYY-MM-DD',
    zippedArchive: true,

@@ -48,14 +53,6 @@ const transports = [
  }),
];

// if (NODE_ENV !== 'production') {
//   transports.push(
//     new winston.transports.Console({
//       format: winston.format.combine(winston.format.colorize(), winston.format.simple()),
//     }),
//   );
// }

const consoleFormat = winston.format.combine(
  winston.format.colorize({ all: true }),
  winston.format.timestamp({ format: 'YYYY-MM-DD HH:mm:ss' }),
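Both logger configs repeat the same defensive check for boolean-ish environment variables, accepting either the string 'true' or a literal true default. That check could be factored into a tiny helper; a sketch:

// Sketch of the env-flag check repeated across the logger configs.
function isFlagEnabled(value) {
  return (typeof value === 'string' && value.toLowerCase() === 'true') || value === true;
}

const { DEBUG_LOGGING = false, DEBUG_CONSOLE = false } = process.env;

const useDebugLogging = isFlagEnabled(DEBUG_LOGGING);
const useDebugConsole = isFlagEnabled(DEBUG_CONSOLE);

// The file transport level then derives from the flag, as in the hunk above:
const logLevel = useDebugLogging ? 'debug' : 'error';
console.log({ useDebugLogging, useDebugConsole, logLevel });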
@@ -5,7 +5,7 @@ const { redactFormat, redactMessage, debugTraverse, jsonTruncateFormat } = requi

const logDir = path.join(__dirname, '..', 'logs');

const { NODE_ENV, DEBUG_LOGGING = true, DEBUG_CONSOLE = false, CONSOLE_JSON = false } = process.env;
const { NODE_ENV, DEBUG_LOGGING = true, CONSOLE_JSON = false, DEBUG_CONSOLE = false } = process.env;

const useConsoleJson =
  (typeof CONSOLE_JSON === 'string' && CONSOLE_JSON?.toLowerCase() === 'true') ||

@@ -15,6 +15,10 @@ const useDebugConsole =
  (typeof DEBUG_CONSOLE === 'string' && DEBUG_CONSOLE?.toLowerCase() === 'true') ||
  DEBUG_CONSOLE === true;

const useDebugLogging =
  (typeof DEBUG_LOGGING === 'string' && DEBUG_LOGGING?.toLowerCase() === 'true') ||
  DEBUG_LOGGING === true;

const levels = {
  error: 0,
  warn: 1,

@@ -57,28 +61,9 @@ const transports = [
    maxFiles: '14d',
    format: fileFormat,
  }),
  // new winston.transports.DailyRotateFile({
  //   level: 'info',
  //   filename: `${logDir}/info-%DATE%.log`,
  //   datePattern: 'YYYY-MM-DD',
  //   zippedArchive: true,
  //   maxSize: '20m',
  //   maxFiles: '14d',
  // }),
];

// if (NODE_ENV !== 'production') {
//   transports.push(
//     new winston.transports.Console({
//       format: winston.format.combine(winston.format.colorize(), winston.format.simple()),
//     }),
//   );
// }

if (
  (typeof DEBUG_LOGGING === 'string' && DEBUG_LOGGING?.toLowerCase() === 'true') ||
  DEBUG_LOGGING === true
) {
if (useDebugLogging) {
  transports.push(
    new winston.transports.DailyRotateFile({
      level: 'debug',

@@ -107,10 +92,16 @@ const consoleFormat = winston.format.combine(
  }),
);

// Determine console log level
let consoleLogLevel = 'info';
if (useDebugConsole) {
  consoleLogLevel = 'debug';
}

if (useDebugConsole) {
  transports.push(
    new winston.transports.Console({
      level: 'debug',
      level: consoleLogLevel,
      format: useConsoleJson
        ? winston.format.combine(fileFormat, jsonTruncateFormat(), winston.format.json())
        : winston.format.combine(fileFormat, debugTraverse),

@@ -119,14 +110,14 @@ if (useDebugConsole) {
} else if (useConsoleJson) {
  transports.push(
    new winston.transports.Console({
      level: 'info',
      level: consoleLogLevel,
      format: winston.format.combine(fileFormat, jsonTruncateFormat(), winston.format.json()),
    }),
  );
} else {
  transports.push(
    new winston.transports.Console({
      level: 'info',
      level: consoleLogLevel,
      format: consoleFormat,
    }),
  );
@@ -1,6 +1,8 @@
const mongoose = require('mongoose');
const { SystemRoles } = require('librechat-data-provider');
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
const { agentSchema } = require('@librechat/data-schemas');
const { SystemRoles, Tools } = require('librechat-data-provider');
const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_delimiter } =
  require('librechat-data-provider').Constants;
const { CONFIG_STORE, STARTUP_CONFIG } = require('librechat-data-provider').CacheKeys;
const {
  getProjectByName,

@@ -9,7 +11,6 @@ const {
  removeAgentFromAllProjects,
} = require('./Project');
const getLogStores = require('~/cache/getLogStores');
const { agentSchema } = require('@librechat/data-schemas');

const Agent = mongoose.model('agent', agentSchema);

@@ -39,13 +40,69 @@ const getAgent = async (searchParameter) => await Agent.findOne(searchParameter)
 * @param {Object} params
 * @param {ServerRequest} params.req
 * @param {string} params.agent_id
 * @param {string} params.endpoint
 * @param {import('@librechat/agents').ClientOptions} [params.model_parameters]
 * @returns {Agent|null} The agent document as a plain object, or null if not found.
 */
const loadEphemeralAgent = ({ req, agent_id, endpoint, model_parameters: _m }) => {
  const { model, ...model_parameters } = _m;
  /** @type {Record<string, FunctionTool>} */
  const availableTools = req.app.locals.availableTools;
  const mcpServers = new Set(req.body.ephemeralAgent?.mcp);
  /** @type {string[]} */
  const tools = [];
  if (req.body.ephemeralAgent?.execute_code === true) {
    tools.push(Tools.execute_code);
  }

  if (mcpServers.size > 0) {
    for (const toolName of Object.keys(availableTools)) {
      if (!toolName.includes(mcp_delimiter)) {
        continue;
      }
      const mcpServer = toolName.split(mcp_delimiter)?.[1];
      if (mcpServer && mcpServers.has(mcpServer)) {
        tools.push(toolName);
      }
    }
  }

  const instructions = req.body.promptPrefix;
  return {
    id: agent_id,
    instructions,
    provider: endpoint,
    model_parameters,
    model,
    tools,
  };
};

/**
 * Load an agent based on the provided ID
 *
 * @param {Object} params
 * @param {ServerRequest} params.req
 * @param {string} params.agent_id
 * @param {string} params.endpoint
 * @param {import('@librechat/agents').ClientOptions} [params.model_parameters]
 * @returns {Promise<Agent|null>} The agent document as a plain object, or null if not found.
 */
const loadAgent = async ({ req, agent_id }) => {
const loadAgent = async ({ req, agent_id, endpoint, model_parameters }) => {
  if (!agent_id) {
    return null;
  }
  if (agent_id === EPHEMERAL_AGENT_ID) {
    return loadEphemeralAgent({ req, agent_id, endpoint, model_parameters });
  }
  const agent = await getAgent({
    id: agent_id,
  });

  if (!agent) {
    return null;
  }

  if (agent.author.toString() === req.user.id) {
    return agent;
  }
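loadEphemeralAgent selects tools for an on-the-fly agent by matching tool names against the requested MCP servers: each MCP tool name embeds its server after a delimiter, so filtering is a split-and-set-lookup. A sketch, using '__' as an illustrative delimiter (the real mcp_delimiter constant comes from librechat-data-provider):

// Sketch of the MCP tool selection in loadEphemeralAgent.
// The delimiter and tool names below are illustrative.
const MCP_DELIMITER = '__';

function selectMcpTools(availableToolNames, requestedServers) {
  const servers = new Set(requestedServers);
  const tools = [];
  for (const toolName of availableToolNames) {
    if (!toolName.includes(MCP_DELIMITER)) {
      continue; // not an MCP tool
    }
    // Tool names are shaped like `<tool>__<server>`; take the server part.
    const server = toolName.split(MCP_DELIMITER)[1];
    if (server && servers.has(server)) {
      tools.push(toolName);
    }
  }
  return tools;
}

console.log(selectMcpTools(['search__github', 'fetch__weather', 'calculator'], ['github']));
// -> ['search__github']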
@@ -122,16 +179,17 @@ const addAgentResourceFile = async ({ agent_id, tool_resource, file_id }) => {
};

/**
 * Removes multiple resource files from an agent in a single update.
 * Removes multiple resource files from an agent using atomic operations.
 * @param {object} params
 * @param {string} params.agent_id
 * @param {Array<{tool_resource: string, file_id: string}>} params.files
 * @returns {Promise<Agent>} The updated agent.
 * @throws {Error} If the agent is not found or update fails.
 */
const removeAgentResourceFiles = async ({ agent_id, files }) => {
  const searchParameter = { id: agent_id };

  // associate each tool resource with the respective file ids array
  // Group files to remove by resource
  const filesByResource = files.reduce((acc, { tool_resource, file_id }) => {
    if (!acc[tool_resource]) {
      acc[tool_resource] = [];

@@ -140,42 +198,35 @@ const removeAgentResourceFiles = async ({ agent_id, files }) => {
    return acc;
  }, {});

  // build the update aggregation pipeline which removes file ids from tool resources array
  // and eventually deletes empty tool resources
  const updateData = [];
  Object.entries(filesByResource).forEach(([resource, fileIds]) => {
    const toolResourcePath = `tool_resources.${resource}`;
    const fileIdsPath = `${toolResourcePath}.file_ids`;

    // file ids removal stage
    updateData.push({
      $set: {
        [fileIdsPath]: {
          $filter: {
            input: `$${fileIdsPath}`,
            cond: { $not: [{ $in: ['$$this', fileIds] }] },
          },
        },
      },
    });

    // empty tool resource deletion stage
    updateData.push({
      $set: {
        [toolResourcePath]: {
          $cond: [{ $eq: [`$${fileIdsPath}`, []] }, '$$REMOVE', `$${toolResourcePath}`],
        },
      },
    });
  });

  // return the updated agent or throw if no agent matches
  const updatedAgent = await updateAgent(searchParameter, updateData);
  if (updatedAgent) {
    return updatedAgent;
  } else {
    throw new Error('Agent not found for removing resource files');
  // Step 1: Atomically remove file IDs using $pull
  const pullOps = {};
  const resourcesToCheck = new Set();
  for (const [resource, fileIds] of Object.entries(filesByResource)) {
    const fileIdsPath = `tool_resources.${resource}.file_ids`;
    pullOps[fileIdsPath] = { $in: fileIds };
    resourcesToCheck.add(resource);
  }

  const updatePullData = { $pull: pullOps };
  const agentAfterPull = await Agent.findOneAndUpdate(searchParameter, updatePullData, {
    new: true,
  }).lean();

  if (!agentAfterPull) {
    // Agent might have been deleted concurrently, or never existed.
    // Check if it existed before trying to throw.
    const agentExists = await getAgent(searchParameter);
    if (!agentExists) {
      throw new Error('Agent not found for removing resource files');
    }
    // If it existed but findOneAndUpdate returned null, something else went wrong.
    throw new Error('Failed to update agent during file removal (pull step)');
  }

  // Return the agent state directly after the $pull operation.
  // Skipping the $unset step for now to simplify and test core $pull atomicity.
  // Empty arrays might remain, but the removal itself should be correct.
  return agentAfterPull;
};

/**
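The rewrite trades a multi-stage aggregation-pipeline update for a single $pull, which MongoDB applies atomically per document, so concurrent removals cannot clobber each other's writes. A minimal sketch of the operator shape; the field names follow the diff, while the function wrapper and IDs are illustrative:

// Sketch: atomic removal of file IDs from nested arrays via $pull.
// Assumes a mongoose model `Agent` as in the diff.
async function pullFileIds(Agent, agentId, filesByResource) {
  const pullOps = {};
  for (const [resource, fileIds] of Object.entries(filesByResource)) {
    // e.g. { 'tool_resources.test_tool.file_ids': { $in: ['id1', 'id2'] } }
    pullOps[`tool_resources.${resource}.file_ids`] = { $in: fileIds };
  }
  // One atomic update: MongoDB removes all matching array elements in place,
  // so interleaved concurrent $pulls never lose each other's removals.
  return Agent.findOneAndUpdate({ id: agentId }, { $pull: pullOps }, { new: true }).lean();
}

// Usage (illustrative IDs):
// await pullFileIds(Agent, 'agent_123', { test_tool: ['file_a', 'file_b'] });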
@@ -157,4 +157,134 @@ describe('Agent Resource File Operations', () => {
    expect(updatedAgent.tool_resources[tool].file_ids).toHaveLength(5);
  });
});

  test('should handle concurrent duplicate additions', async () => {
    const agent = await createBasicAgent();
    const fileId = uuidv4();

    // Concurrent additions of the same file
    const additionPromises = Array.from({ length: 5 }).map(() =>
      addAgentResourceFile({
        agent_id: agent.id,
        tool_resource: 'test_tool',
        file_id: fileId,
      }),
    );

    await Promise.all(additionPromises);

    const updatedAgent = await Agent.findOne({ id: agent.id });
    expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined();
    // Should only contain one instance of the fileId
    expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(1);
    expect(updatedAgent.tool_resources.test_tool.file_ids[0]).toBe(fileId);
  });

  test('should handle concurrent add and remove of the same file', async () => {
    const agent = await createBasicAgent();
    const fileId = uuidv4();

    // First, ensure the file exists (or test might be trivial if remove runs first)
    await addAgentResourceFile({
      agent_id: agent.id,
      tool_resource: 'test_tool',
      file_id: fileId,
    });

    // Concurrent add (which should be ignored) and remove
    const operations = [
      addAgentResourceFile({
        agent_id: agent.id,
        tool_resource: 'test_tool',
        file_id: fileId,
      }),
      removeAgentResourceFiles({
        agent_id: agent.id,
        files: [{ tool_resource: 'test_tool', file_id: fileId }],
      }),
    ];

    await Promise.all(operations);

    const updatedAgent = await Agent.findOne({ id: agent.id });
    // The final state should ideally be that the file is removed,
    // but the key point is consistency (not duplicated or error state).
    // Depending on execution order, the file might remain if the add operation's
    // findOneAndUpdate runs after the remove operation completes.
    // A more robust check might be that the length is <= 1.
    // Given the remove uses an update pipeline, it might be more likely to win.
    // The final state depends on race condition timing (add or remove might "win").
    // The critical part is that the state is consistent (no duplicates, no errors).
    // Assert that the fileId is either present exactly once or not present at all.
    expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined();
    const finalFileIds = updatedAgent.tool_resources.test_tool.file_ids;
    const count = finalFileIds.filter((id) => id === fileId).length;
    expect(count).toBeLessThanOrEqual(1); // Should be 0 or 1, never more
    // Optional: Check overall length is consistent with the count
    if (count === 0) {
      expect(finalFileIds).toHaveLength(0);
    } else {
      expect(finalFileIds).toHaveLength(1);
      expect(finalFileIds[0]).toBe(fileId);
    }
  });

  test('should handle concurrent duplicate removals', async () => {
    const agent = await createBasicAgent();
    const fileId = uuidv4();

    // Add the file first
    await addAgentResourceFile({
      agent_id: agent.id,
      tool_resource: 'test_tool',
      file_id: fileId,
    });

    // Concurrent removals of the same file
    const removalPromises = Array.from({ length: 5 }).map(() =>
      removeAgentResourceFiles({
        agent_id: agent.id,
        files: [{ tool_resource: 'test_tool', file_id: fileId }],
      }),
    );

    await Promise.all(removalPromises);

    const updatedAgent = await Agent.findOne({ id: agent.id });
    // Check if the array is empty or the tool resource itself is removed
    const fileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? [];
    expect(fileIds).toHaveLength(0);
    expect(fileIds).not.toContain(fileId);
  });

  test('should handle concurrent removals of different files', async () => {
    const agent = await createBasicAgent();
    const fileIds = Array.from({ length: 10 }, () => uuidv4());

    // Add all files first
    await Promise.all(
      fileIds.map((fileId) =>
        addAgentResourceFile({
          agent_id: agent.id,
          tool_resource: 'test_tool',
          file_id: fileId,
        }),
      ),
    );

    // Concurrently remove all files
    const removalPromises = fileIds.map((fileId) =>
      removeAgentResourceFiles({
        agent_id: agent.id,
        files: [{ tool_resource: 'test_tool', file_id: fileId }],
      }),
    );

    await Promise.all(removalPromises);

    const updatedAgent = await Agent.findOne({ id: agent.id });
    // Check if the array is empty or the tool resource itself is removed
    const finalFileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? [];
    expect(finalFileIds).toHaveLength(0);
  });
});
@@ -134,6 +134,28 @@ const deleteFiles = async (file_ids, user) => {
  return await File.deleteMany(deleteQuery);
};

/**
 * Batch updates files with new signed URLs in MongoDB
 *
 * @param {MongoFile[]} updates - Array of updates in the format { file_id, filepath }
 * @returns {Promise<void>}
 */
async function batchUpdateFiles(updates) {
  if (!updates || updates.length === 0) {
    return;
  }

  const bulkOperations = updates.map((update) => ({
    updateOne: {
      filter: { file_id: update.file_id },
      update: { $set: { filepath: update.filepath } },
    },
  }));

  const result = await File.bulkWrite(bulkOperations);
  logger.info(`Updated ${result.modifiedCount} files with new S3 URLs`);
}

module.exports = {
  File,
  findFileById,

@@ -145,4 +167,5 @@ module.exports = {
  deleteFile,
  deleteFiles,
  deleteFileByFilter,
  batchUpdateFiles,
};
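bulkWrite sends all the updateOne operations in one round trip instead of issuing an update per file, which is the point of batching here. A sketch of the payload shape batchUpdateFiles builds; the file IDs and URLs below are illustrative:

// Sketch: building a bulkWrite payload like batchUpdateFiles does.
const updates = [
  { file_id: 'file-1', filepath: 'https://bucket.s3.amazonaws.com/a?sig=new1' },
  { file_id: 'file-2', filepath: 'https://bucket.s3.amazonaws.com/b?sig=new2' },
];

const bulkOperations = updates.map((update) => ({
  updateOne: {
    filter: { file_id: update.file_id }, // match by application-level ID
    update: { $set: { filepath: update.filepath } }, // only touch the URL field
  },
}));

// One round trip instead of N: await File.bulkWrite(bulkOperations);
console.log(JSON.stringify(bulkOperations, null, 2));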
@@ -4,13 +4,8 @@ const {
  SystemRoles,
  roleDefaults,
  PermissionTypes,
  permissionsSchema,
  removeNullishValues,
  agentPermissionsSchema,
  promptPermissionsSchema,
  runCodePermissionsSchema,
  bookmarkPermissionsSchema,
  multiConvoPermissionsSchema,
  temporaryChatPermissionsSchema,
} = require('librechat-data-provider');
const getLogStores = require('~/cache/getLogStores');
const { roleSchema } = require('@librechat/data-schemas');

@@ -20,15 +15,16 @@ const Role = mongoose.model('Role', roleSchema);

/**
 * Retrieve a role by name and convert the found role document to a plain object.
 * If the role with the given name doesn't exist and the name is a system defined role, create it and return the lean version.
 * If the role with the given name doesn't exist and the name is a system defined role,
 * create it and return the lean version.
 *
 * @param {string} roleName - The name of the role to find or create.
 * @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
 * @returns {Promise<Object>} A plain object representing the role document.
 */
const getRoleByName = async function (roleName, fieldsToSelect = null) {
  const cache = getLogStores(CacheKeys.ROLES);
  try {
    const cache = getLogStores(CacheKeys.ROLES);
    const cachedRole = await cache.get(roleName);
    if (cachedRole) {
      return cachedRole;

@@ -40,8 +36,7 @@ const getRoleByName = async function (roleName, fieldsToSelect = null) {
    let role = await query.lean().exec();

    if (!role && SystemRoles[roleName]) {
      role = roleDefaults[roleName];
      role = await new Role(role).save();
      role = await new Role(roleDefaults[roleName]).save();
      await cache.set(roleName, role);
      return role.toObject();
    }

@@ -60,8 +55,8 @@ const getRoleByName = async function (roleName, fieldsToSelect = null) {
 * @returns {Promise<TRole>} Updated role document.
 */
const updateRoleByName = async function (roleName, updates) {
  const cache = getLogStores(CacheKeys.ROLES);
  try {
    const cache = getLogStores(CacheKeys.ROLES);
    const role = await Role.findOneAndUpdate(
      { name: roleName },
      { $set: updates },

@@ -77,29 +72,20 @@ const updateRoleByName = async function (roleName, updates) {
  }
};

const permissionSchemas = {
  [PermissionTypes.AGENTS]: agentPermissionsSchema,
  [PermissionTypes.PROMPTS]: promptPermissionsSchema,
  [PermissionTypes.BOOKMARKS]: bookmarkPermissionsSchema,
  [PermissionTypes.MULTI_CONVO]: multiConvoPermissionsSchema,
  [PermissionTypes.TEMPORARY_CHAT]: temporaryChatPermissionsSchema,
  [PermissionTypes.RUN_CODE]: runCodePermissionsSchema,
};

/**
 * Updates access permissions for a specific role and multiple permission types.
 * @param {SystemRoles} roleName - The role to update.
 * @param {string} roleName - The role to update.
 * @param {Object.<PermissionTypes, Object.<Permissions, boolean>>} permissionsUpdate - Permissions to update and their values.
 */
async function updateAccessPermissions(roleName, permissionsUpdate) {
  // Filter and clean the permission updates based on our schema definition.
  const updates = {};
  for (const [permissionType, permissions] of Object.entries(permissionsUpdate)) {
    if (permissionSchemas[permissionType]) {
    if (permissionsSchema.shape && permissionsSchema.shape[permissionType]) {
      updates[permissionType] = removeNullishValues(permissions);
    }
  }

  if (Object.keys(updates).length === 0) {
  if (!Object.keys(updates).length) {
    return;
  }

@@ -109,26 +95,75 @@ async function updateAccessPermissions(roleName, permissionsUpdate) {
    return;
  }

  const updatedPermissions = {};
  const currentPermissions = role.permissions || {};
  const updatedPermissions = { ...currentPermissions };
  let hasChanges = false;

  const unsetFields = {};
  const permissionTypes = Object.keys(permissionsSchema.shape || {});
  for (const permType of permissionTypes) {
    if (role[permType] && typeof role[permType] === 'object') {
      logger.info(
        `Migrating '${roleName}' role from old schema: found '${permType}' at top level`,
      );

      updatedPermissions[permType] = {
        ...updatedPermissions[permType],
        ...role[permType],
      };

      unsetFields[permType] = 1;
      hasChanges = true;
    }
  }

  // Process the current updates
  for (const [permissionType, permissions] of Object.entries(updates)) {
    const currentPermissions = role[permissionType] || {};
    updatedPermissions[permissionType] = { ...currentPermissions };
    const currentTypePermissions = currentPermissions[permissionType] || {};
    updatedPermissions[permissionType] = { ...currentTypePermissions };

    for (const [permission, value] of Object.entries(permissions)) {
      if (currentPermissions[permission] !== value) {
      if (currentTypePermissions[permission] !== value) {
        updatedPermissions[permissionType][permission] = value;
        hasChanges = true;
        logger.info(
          `Updating '${roleName}' role ${permissionType} '${permission}' permission from ${currentPermissions[permission]} to: ${value}`,
          `Updating '${roleName}' role permission '${permissionType}' '${permission}' from ${currentTypePermissions[permission]} to: ${value}`,
        );
      }
    }
  }

  if (hasChanges) {
    await updateRoleByName(roleName, updatedPermissions);
    const updateObj = { permissions: updatedPermissions };

    if (Object.keys(unsetFields).length > 0) {
      logger.info(
        `Unsetting old schema fields for '${roleName}' role: ${Object.keys(unsetFields).join(', ')}`,
      );

      try {
        await Role.updateOne(
          { name: roleName },
          {
            $set: updateObj,
            $unset: unsetFields,
          },
        );

        const cache = getLogStores(CacheKeys.ROLES);
        const updatedRole = await Role.findOne({ name: roleName }).select('-__v').lean().exec();
        await cache.set(roleName, updatedRole);

        logger.info(`Updated role '${roleName}' and removed old schema fields`);
      } catch (updateError) {
        logger.error(`Error during role migration update: ${updateError.message}`);
        throw updateError;
      }
    } else {
      // Standard update if no migration needed
      await updateRoleByName(roleName, updateObj);
    }

    logger.info(`Updated '${roleName}' role permissions`);
  } else {
    logger.info(`No changes needed for '${roleName}' role permissions`);

@@ -146,34 +181,111 @@ async function updateAccessPermissions(roleName, permissionsUpdate) {
 * @returns {Promise<void>}
 */
const initializeRoles = async function () {
  const defaultRoles = [SystemRoles.ADMIN, SystemRoles.USER];

  for (const roleName of defaultRoles) {
  for (const roleName of [SystemRoles.ADMIN, SystemRoles.USER]) {
    let role = await Role.findOne({ name: roleName });
    const defaultPerms = roleDefaults[roleName].permissions;

    if (!role) {
      // Create new role if it doesn't exist
      // Create new role if it doesn't exist.
      role = new Role(roleDefaults[roleName]);
    } else {
      // Add missing permission types
      let isUpdated = false;
      for (const permType of Object.values(PermissionTypes)) {
        if (!role[permType]) {
          role[permType] = roleDefaults[roleName][permType];
          isUpdated = true;
      // Ensure role.permissions is defined.
      role.permissions = role.permissions || {};
      // For each permission type in defaults, add it if missing.
      for (const permType of Object.keys(defaultPerms)) {
        if (role.permissions[permType] == null) {
          role.permissions[permType] = defaultPerms[permType];
        }
      }
      if (isUpdated) {
        await role.save();
      }
    }
    await role.save();
  }
};

/**
 * Migrates roles from old schema to new schema structure.
 * This can be called directly to fix existing roles.
 *
 * @param {string} [roleName] - Optional specific role to migrate. If not provided, migrates all roles.
 * @returns {Promise<number>} Number of roles migrated.
 */
const migrateRoleSchema = async function (roleName) {
  try {
    // Get roles to migrate
    let roles;
    if (roleName) {
      const role = await Role.findOne({ name: roleName });
      roles = role ? [role] : [];
    } else {
      roles = await Role.find({});
    }

    logger.info(`Migrating ${roles.length} roles to new schema structure`);
    let migratedCount = 0;

    for (const role of roles) {
      const permissionTypes = Object.keys(permissionsSchema.shape || {});
      const unsetFields = {};
      let hasOldSchema = false;

      // Check for old schema fields
      for (const permType of permissionTypes) {
        if (role[permType] && typeof role[permType] === 'object') {
          hasOldSchema = true;

          // Ensure permissions object exists
          role.permissions = role.permissions || {};

          // Migrate permissions from old location to new
          role.permissions[permType] = {
            ...role.permissions[permType],
            ...role[permType],
          };

          // Mark field for removal
          unsetFields[permType] = 1;
        }
      }

      if (hasOldSchema) {
        try {
          logger.info(`Migrating role '${role.name}' from old schema structure`);

          // Simple update operation
          await Role.updateOne(
            { _id: role._id },
            {
              $set: { permissions: role.permissions },
              $unset: unsetFields,
            },
          );

          // Refresh cache
          const cache = getLogStores(CacheKeys.ROLES);
          const updatedRole = await Role.findById(role._id).lean().exec();
          await cache.set(role.name, updatedRole);

          migratedCount++;
          logger.info(`Migrated role '${role.name}'`);
        } catch (error) {
          logger.error(`Failed to migrate role '${role.name}': ${error.message}`);
        }
      }
    }

    logger.info(`Migration complete: ${migratedCount} roles migrated`);
    return migratedCount;
  } catch (error) {
    logger.error(`Role schema migration failed: ${error.message}`);
    throw error;
  }
};

module.exports = {
  Role,
  getRoleByName,
  initializeRoles,
  updateRoleByName,
  updateAccessPermissions,
  migrateRoleSchema,
};
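The migration moves permission maps from top-level fields (role[PermissionTypes.PROMPTS]) into a nested role.permissions object, then removes the old fields in the same update via $set plus $unset. A minimal sketch of that reshaping on a plain object plus the resulting update payload; the field names follow the diff, while the sample document is illustrative:

// Sketch: reshaping an old-schema role document and building the update payload.
const permissionTypes = ['PROMPTS', 'BOOKMARKS']; // illustrative subset

function buildMigrationUpdate(role) {
  const permissions = { ...(role.permissions || {}) };
  const unsetFields = {};

  for (const permType of permissionTypes) {
    if (role[permType] && typeof role[permType] === 'object') {
      // Merge the top-level map into the nested permissions object...
      permissions[permType] = { ...permissions[permType], ...role[permType] };
      // ...and schedule the old top-level field for removal.
      unsetFields[permType] = 1;
    }
  }

  return { $set: { permissions }, $unset: unsetFields };
}

const oldRole = { name: 'USER', PROMPTS: { USE: true, CREATE: false } };
console.log(buildMigrationUpdate(oldRole));
// -> { $set: { permissions: { PROMPTS: { USE: true, CREATE: false } } }, $unset: { PROMPTS: 1 } }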
@@ -2,22 +2,21 @@ const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const {
  SystemRoles,
  PermissionTypes,
  roleDefaults,
  Permissions,
  roleDefaults,
  PermissionTypes,
} = require('librechat-data-provider');
const { updateAccessPermissions, initializeRoles } = require('~/models/Role');
const { Role, getRoleByName, updateAccessPermissions, initializeRoles } = require('~/models/Role');
const getLogStores = require('~/cache/getLogStores');
const { Role } = require('~/models/Role');

// Mock the cache
jest.mock('~/cache/getLogStores', () => {
  return jest.fn().mockReturnValue({
jest.mock('~/cache/getLogStores', () =>
  jest.fn().mockReturnValue({
    get: jest.fn(),
    set: jest.fn(),
    del: jest.fn(),
  });
});
  }),
);

let mongoServer;

@@ -41,10 +40,12 @@ describe('updateAccessPermissions', () => {
  it('should update permissions when changes are needed', async () => {
    await new Role({
      name: SystemRoles.USER,
      [PermissionTypes.PROMPTS]: {
        CREATE: true,
        USE: true,
        SHARED_GLOBAL: false,
      permissions: {
        [PermissionTypes.PROMPTS]: {
          CREATE: true,
          USE: true,
          SHARED_GLOBAL: false,
        },
      },
    }).save();

@@ -56,8 +57,8 @@ describe('updateAccessPermissions', () => {
      },
    });

    const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
    expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
    const updatedRole = await getRoleByName(SystemRoles.USER);
    expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
      CREATE: true,
      USE: true,
      SHARED_GLOBAL: true,

@@ -67,10 +68,12 @@ describe('updateAccessPermissions', () => {
  it('should not update permissions when no changes are needed', async () => {
    await new Role({
      name: SystemRoles.USER,
      [PermissionTypes.PROMPTS]: {
        CREATE: true,
        USE: true,
        SHARED_GLOBAL: false,
      permissions: {
        [PermissionTypes.PROMPTS]: {
          CREATE: true,
          USE: true,
          SHARED_GLOBAL: false,
        },
      },
    }).save();

@@ -82,8 +85,8 @@ describe('updateAccessPermissions', () => {
      },
    });

    const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
    expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
    const updatedRole = await getRoleByName(SystemRoles.USER);
    expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
      CREATE: true,
      USE: true,
      SHARED_GLOBAL: false,

@@ -92,11 +95,8 @@ describe('updateAccessPermissions', () => {

  it('should handle non-existent roles', async () => {
    await updateAccessPermissions('NON_EXISTENT_ROLE', {
      [PermissionTypes.PROMPTS]: {
        CREATE: true,
      },
      [PermissionTypes.PROMPTS]: { CREATE: true },
    });

    const role = await Role.findOne({ name: 'NON_EXISTENT_ROLE' });
    expect(role).toBeNull();
  });

@@ -104,21 +104,21 @@ describe('updateAccessPermissions', () => {
  it('should update only specified permissions', async () => {
    await new Role({
      name: SystemRoles.USER,
      [PermissionTypes.PROMPTS]: {
        CREATE: true,
        USE: true,
        SHARED_GLOBAL: false,
      permissions: {
        [PermissionTypes.PROMPTS]: {
          CREATE: true,
          USE: true,
          SHARED_GLOBAL: false,
        },
      },
    }).save();

    await updateAccessPermissions(SystemRoles.USER, {
      [PermissionTypes.PROMPTS]: {
        SHARED_GLOBAL: true,
      },
      [PermissionTypes.PROMPTS]: { SHARED_GLOBAL: true },
    });

    const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
    expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
    const updatedRole = await getRoleByName(SystemRoles.USER);
    expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
      CREATE: true,
      USE: true,
      SHARED_GLOBAL: true,

@@ -128,21 +128,21 @@ describe('updateAccessPermissions', () => {
  it('should handle partial updates', async () => {
    await new Role({
      name: SystemRoles.USER,
      [PermissionTypes.PROMPTS]: {
        CREATE: true,
        USE: true,
        SHARED_GLOBAL: false,
      permissions: {
        [PermissionTypes.PROMPTS]: {
          CREATE: true,
          USE: true,
          SHARED_GLOBAL: false,
        },
      },
    }).save();

    await updateAccessPermissions(SystemRoles.USER, {
      [PermissionTypes.PROMPTS]: {
        USE: false,
      },
      [PermissionTypes.PROMPTS]: { USE: false },
    });

    const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
    expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
    const updatedRole = await getRoleByName(SystemRoles.USER);
    expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
      CREATE: true,
      USE: false,
      SHARED_GLOBAL: false,

@@ -152,13 +152,9 @@ describe('updateAccessPermissions', () => {
  it('should update multiple permission types at once', async () => {
    await new Role({
      name: SystemRoles.USER,
      [PermissionTypes.PROMPTS]: {
        CREATE: true,
        USE: true,
        SHARED_GLOBAL: false,
      },
      [PermissionTypes.BOOKMARKS]: {
        USE: true,
      permissions: {
        [PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARED_GLOBAL: false },
        [PermissionTypes.BOOKMARKS]: { USE: true },
      },
    }).save();

@@ -167,24 +163,20 @@ describe('updateAccessPermissions', () => {
      [PermissionTypes.BOOKMARKS]: { USE: false },
    });

    const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
    expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
    const updatedRole = await getRoleByName(SystemRoles.USER);
    expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
      CREATE: true,
      USE: false,
      SHARED_GLOBAL: true,
    });
    expect(updatedRole[PermissionTypes.BOOKMARKS]).toEqual({
      USE: false,
    });
    expect(updatedRole.permissions[PermissionTypes.BOOKMARKS]).toEqual({ USE: false });
  });

  it('should handle updates for a single permission type', async () => {
    await new Role({
      name: SystemRoles.USER,
      [PermissionTypes.PROMPTS]: {
        CREATE: true,
        USE: true,
        SHARED_GLOBAL: false,
      permissions: {
        [PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARED_GLOBAL: false },
      },
    }).save();

@@ -192,8 +184,8 @@ describe('updateAccessPermissions', () => {
      [PermissionTypes.PROMPTS]: { USE: false, SHARED_GLOBAL: true },
    });

    const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
    expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
    const updatedRole = await getRoleByName(SystemRoles.USER);
    expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
      CREATE: true,
      USE: false,
      SHARED_GLOBAL: true,

@@ -203,33 +195,25 @@ describe('updateAccessPermissions', () => {
  it('should update MULTI_CONVO permissions', async () => {
    await new Role({
      name: SystemRoles.USER,
      [PermissionTypes.MULTI_CONVO]: {
        USE: false,
      permissions: {
        [PermissionTypes.MULTI_CONVO]: { USE: false },
      },
    }).save();

    await updateAccessPermissions(SystemRoles.USER, {
      [PermissionTypes.MULTI_CONVO]: {
        USE: true,
      },
      [PermissionTypes.MULTI_CONVO]: { USE: true },
    });

    const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
    expect(updatedRole[PermissionTypes.MULTI_CONVO]).toEqual({
      USE: true,
    });
    const updatedRole = await getRoleByName(SystemRoles.USER);
    expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO]).toEqual({ USE: true });
  });

  it('should update MULTI_CONVO permissions along with other permission types', async () => {
    await new Role({
      name: SystemRoles.USER,
      [PermissionTypes.PROMPTS]: {
        CREATE: true,
        USE: true,
        SHARED_GLOBAL: false,
      },
      [PermissionTypes.MULTI_CONVO]: {
        USE: false,
      permissions: {
        [PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARED_GLOBAL: false },
        [PermissionTypes.MULTI_CONVO]: { USE: false },
      },
    }).save();

@@ -238,35 +222,29 @@ describe('updateAccessPermissions', () => {
      [PermissionTypes.MULTI_CONVO]: { USE: true },
    });

    const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
    expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
    const updatedRole = await getRoleByName(SystemRoles.USER);
    expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
      CREATE: true,
      USE: true,
      SHARED_GLOBAL: true,
    });
    expect(updatedRole[PermissionTypes.MULTI_CONVO]).toEqual({
      USE: true,
    });
    expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO]).toEqual({ USE: true });
  });

  it('should not update MULTI_CONVO permissions when no changes are needed', async () => {
    await new Role({
      name: SystemRoles.USER,
      [PermissionTypes.MULTI_CONVO]: {
        USE: true,
      permissions: {
        [PermissionTypes.MULTI_CONVO]: { USE: true },
      },
    }).save();

    await updateAccessPermissions(SystemRoles.USER, {
      [PermissionTypes.MULTI_CONVO]: {
        USE: true,
      },
      [PermissionTypes.MULTI_CONVO]: { USE: true },
    });

    const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
    expect(updatedRole[PermissionTypes.MULTI_CONVO]).toEqual({
      USE: true,
    });
    const updatedRole = await getRoleByName(SystemRoles.USER);
    expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO]).toEqual({ USE: true });
  });
});

@@ -278,65 +256,69 @@ describe('initializeRoles', () => {
  it('should create default roles if they do not exist', async () => {
    await initializeRoles();

    const adminRole = await Role.findOne({ name: SystemRoles.ADMIN }).lean();
    const userRole = await Role.findOne({ name: SystemRoles.USER }).lean();
    const adminRole = await getRoleByName(SystemRoles.ADMIN);
    const userRole = await getRoleByName(SystemRoles.USER);

    expect(adminRole).toBeTruthy();
    expect(userRole).toBeTruthy();

    // Check if all permission types exist
    // Check if all permission types exist in the permissions field
    Object.values(PermissionTypes).forEach((permType) => {
      expect(adminRole[permType]).toBeDefined();
      expect(userRole[permType]).toBeDefined();
      expect(adminRole.permissions[permType]).toBeDefined();
      expect(userRole.permissions[permType]).toBeDefined();
    });

    // Check if permissions match defaults (example for ADMIN role)
    expect(adminRole[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBe(true);
    expect(adminRole[PermissionTypes.BOOKMARKS].USE).toBe(true);
    expect(adminRole[PermissionTypes.AGENTS].CREATE).toBe(true);
    // Example: Check default values for ADMIN role
    expect(adminRole.permissions[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBe(true);
    expect(adminRole.permissions[PermissionTypes.BOOKMARKS].USE).toBe(true);
    expect(adminRole.permissions[PermissionTypes.AGENTS].CREATE).toBe(true);
  });

  it('should not modify existing permissions for existing roles', async () => {
    const customUserRole = {
      name: SystemRoles.USER,
      [PermissionTypes.PROMPTS]: {
        [Permissions.USE]: false,
        [Permissions.CREATE]: true,
        [Permissions.SHARED_GLOBAL]: true,
      },
      [PermissionTypes.BOOKMARKS]: {
        [Permissions.USE]: false,
      permissions: {
        [PermissionTypes.PROMPTS]: {
          [Permissions.USE]: false,
          [Permissions.CREATE]: true,
          [Permissions.SHARED_GLOBAL]: true,
        },
        [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false },
      },
    };

    await new Role(customUserRole).save();

    await initializeRoles();

    const userRole = await Role.findOne({ name: SystemRoles.USER }).lean();

    expect(userRole[PermissionTypes.PROMPTS]).toEqual(customUserRole[PermissionTypes.PROMPTS]);
    expect(userRole[PermissionTypes.BOOKMARKS]).toEqual(customUserRole[PermissionTypes.BOOKMARKS]);
    expect(userRole[PermissionTypes.AGENTS]).toBeDefined();
    const userRole = await getRoleByName(SystemRoles.USER);
    expect(userRole.permissions[PermissionTypes.PROMPTS]).toEqual(
      customUserRole.permissions[PermissionTypes.PROMPTS],
    );
    expect(userRole.permissions[PermissionTypes.BOOKMARKS]).toEqual(
      customUserRole.permissions[PermissionTypes.BOOKMARKS],
    );
    expect(userRole.permissions[PermissionTypes.AGENTS]).toBeDefined();
  });

  it('should add new permission types to existing roles', async () => {
    const partialUserRole = {
      name: SystemRoles.USER,
      [PermissionTypes.PROMPTS]: roleDefaults[SystemRoles.USER][PermissionTypes.PROMPTS],
      [PermissionTypes.BOOKMARKS]: roleDefaults[SystemRoles.USER][PermissionTypes.BOOKMARKS],
      permissions: {
        [PermissionTypes.PROMPTS]:
          roleDefaults[SystemRoles.USER].permissions[PermissionTypes.PROMPTS],
        [PermissionTypes.BOOKMARKS]:
          roleDefaults[SystemRoles.USER].permissions[PermissionTypes.BOOKMARKS],
      },
    };

    await new Role(partialUserRole).save();

    await initializeRoles();

    const userRole = await Role.findOne({ name: SystemRoles.USER }).lean();

    expect(userRole[PermissionTypes.AGENTS]).toBeDefined();
    expect(userRole[PermissionTypes.AGENTS].CREATE).toBeDefined();
    expect(userRole[PermissionTypes.AGENTS].USE).toBeDefined();
    expect(userRole[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined();
    const userRole = await getRoleByName(SystemRoles.USER);
    expect(userRole.permissions[PermissionTypes.AGENTS]).toBeDefined();
    expect(userRole.permissions[PermissionTypes.AGENTS].CREATE).toBeDefined();
    expect(userRole.permissions[PermissionTypes.AGENTS].USE).toBeDefined();
    expect(userRole.permissions[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined();
  });

  it('should handle multiple runs without duplicating or modifying data', async () => {

@@ -349,72 +331,73 @@ describe('initializeRoles', () => {
    expect(adminRoles).toHaveLength(1);
    expect(userRoles).toHaveLength(1);

    const adminRole = adminRoles[0].toObject();
    const userRole = userRoles[0].toObject();

    // Check if all permission types exist
    const adminPerms = adminRoles[0].toObject().permissions;
    const userPerms = userRoles[0].toObject().permissions;
    Object.values(PermissionTypes).forEach((permType) => {
      expect(adminRole[permType]).toBeDefined();
      expect(userRole[permType]).toBeDefined();
      expect(adminPerms[permType]).toBeDefined();
      expect(userPerms[permType]).toBeDefined();
    });
  });

  it('should update roles with missing permission types from roleDefaults', async () => {
    const partialAdminRole = {
      name: SystemRoles.ADMIN,
      [PermissionTypes.PROMPTS]: {
        [Permissions.USE]: false,
        [Permissions.CREATE]: false,
        [Permissions.SHARED_GLOBAL]: false,
      permissions: {
        [PermissionTypes.PROMPTS]: {
          [Permissions.USE]: false,
          [Permissions.CREATE]: false,
          [Permissions.SHARED_GLOBAL]: false,
        },
        [PermissionTypes.BOOKMARKS]:
          roleDefaults[SystemRoles.ADMIN].permissions[PermissionTypes.BOOKMARKS],
      },
      [PermissionTypes.BOOKMARKS]: roleDefaults[SystemRoles.ADMIN][PermissionTypes.BOOKMARKS],
    };

    await new Role(partialAdminRole).save();

    await initializeRoles();

    const adminRole = await Role.findOne({ name: SystemRoles.ADMIN }).lean();

    expect(adminRole[PermissionTypes.PROMPTS]).toEqual(partialAdminRole[PermissionTypes.PROMPTS]);
    expect(adminRole[PermissionTypes.AGENTS]).toBeDefined();
    expect(adminRole[PermissionTypes.AGENTS].CREATE).toBeDefined();
    expect(adminRole[PermissionTypes.AGENTS].USE).toBeDefined();
    expect(adminRole[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined();
    const adminRole = await getRoleByName(SystemRoles.ADMIN);
    expect(adminRole.permissions[PermissionTypes.PROMPTS]).toEqual(
      partialAdminRole.permissions[PermissionTypes.PROMPTS],
    );
    expect(adminRole.permissions[PermissionTypes.AGENTS]).toBeDefined();
    expect(adminRole.permissions[PermissionTypes.AGENTS].CREATE).toBeDefined();
    expect(adminRole.permissions[PermissionTypes.AGENTS].USE).toBeDefined();
    expect(adminRole.permissions[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined();
  });

  it('should include MULTI_CONVO permissions when creating default roles', async () => {
    await initializeRoles();

    const adminRole = await Role.findOne({ name: SystemRoles.ADMIN }).lean();
    const userRole = await Role.findOne({ name: SystemRoles.USER }).lean();
    const adminRole = await getRoleByName(SystemRoles.ADMIN);
    const userRole = await getRoleByName(SystemRoles.USER);

    expect(adminRole[PermissionTypes.MULTI_CONVO]).toBeDefined();
    expect(userRole[PermissionTypes.MULTI_CONVO]).toBeDefined();

    // Check if MULTI_CONVO permissions match defaults
    expect(adminRole[PermissionTypes.MULTI_CONVO].USE).toBe(
      roleDefaults[SystemRoles.ADMIN][PermissionTypes.MULTI_CONVO].USE,
    expect(adminRole.permissions[PermissionTypes.MULTI_CONVO]).toBeDefined();
    expect(userRole.permissions[PermissionTypes.MULTI_CONVO]).toBeDefined();
    expect(adminRole.permissions[PermissionTypes.MULTI_CONVO].USE).toBe(
      roleDefaults[SystemRoles.ADMIN].permissions[PermissionTypes.MULTI_CONVO].USE,
    );
    expect(userRole[PermissionTypes.MULTI_CONVO].USE).toBe(
      roleDefaults[SystemRoles.USER][PermissionTypes.MULTI_CONVO].USE,
    expect(userRole.permissions[PermissionTypes.MULTI_CONVO].USE).toBe(
      roleDefaults[SystemRoles.USER].permissions[PermissionTypes.MULTI_CONVO].USE,
    );
  });

  it('should add MULTI_CONVO permissions to existing roles without them', async () => {
    const partialUserRole = {
      name: SystemRoles.USER,
      [PermissionTypes.PROMPTS]: roleDefaults[SystemRoles.USER][PermissionTypes.PROMPTS],
      [PermissionTypes.BOOKMARKS]: roleDefaults[SystemRoles.USER][PermissionTypes.BOOKMARKS],
      permissions: {
        [PermissionTypes.PROMPTS]:
          roleDefaults[SystemRoles.USER].permissions[PermissionTypes.PROMPTS],
        [PermissionTypes.BOOKMARKS]:
          roleDefaults[SystemRoles.USER].permissions[PermissionTypes.BOOKMARKS],
      },
    };

    await new Role(partialUserRole).save();

    await initializeRoles();

    const userRole = await Role.findOne({ name: SystemRoles.USER }).lean();

    expect(userRole[PermissionTypes.MULTI_CONVO]).toBeDefined();
    expect(userRole[PermissionTypes.MULTI_CONVO].USE).toBeDefined();
    const userRole = await getRoleByName(SystemRoles.USER);
    expect(userRole.permissions[PermissionTypes.MULTI_CONVO]).toBeDefined();
    expect(userRole.permissions[PermissionTypes.MULTI_CONVO].USE).toBeDefined();
  });
});
@@ -8,39 +8,135 @@ const Balance = require('./Balance');
 const cancelRate = 1.15;

 /**
- * Updates a user's token balance based on a transaction.
- *
+ * Updates a user's token balance based on a transaction using optimistic concurrency control
+ * without schema changes. Compatible with DocumentDB.
  * @async
  * @function
  * @param {Object} params - The function parameters.
- * @param {string} params.user - The user ID.
+ * @param {string|mongoose.Types.ObjectId} params.user - The user ID.
  * @param {number} params.incrementValue - The value to increment the balance by (can be negative).
- * @param {import('mongoose').UpdateQuery<import('@librechat/data-schemas').IBalance>['$set']} params.setValues
- * @returns {Promise<Object>} Returns the updated balance response.
+ * @param {import('mongoose').UpdateQuery<import('@librechat/data-schemas').IBalance>['$set']} [params.setValues] - Optional additional fields to set.
+ * @returns {Promise<Object>} Returns the updated balance document (lean).
+ * @throws {Error} Throws an error if the update fails after multiple retries.
  */
 const updateBalance = async ({ user, incrementValue, setValues }) => {
-  // Use findOneAndUpdate with a conditional update to make the balance update atomic
-  // This prevents race conditions when multiple transactions are processed concurrently
-  const balanceResponse = await Balance.findOneAndUpdate(
-    { user },
-    [
-      {
-        $set: {
-          tokenCredits: {
-            $cond: {
-              if: { $lt: [{ $add: ['$tokenCredits', incrementValue] }, 0] },
-              then: 0,
-              else: { $add: ['$tokenCredits', incrementValue] },
-            },
-          },
-          ...setValues,
-        },
-      },
-    ],
-    { upsert: true, new: true },
-  ).lean();
-
-  return balanceResponse;
+  let maxRetries = 10; // Number of times to retry on conflict
+  let delay = 50; // Initial retry delay in ms
+  let lastError = null;
+
+  for (let attempt = 1; attempt <= maxRetries; attempt++) {
+    let currentBalanceDoc;
+    try {
+      // 1. Read the current document state
+      currentBalanceDoc = await Balance.findOne({ user }).lean();
+      const currentCredits = currentBalanceDoc ? currentBalanceDoc.tokenCredits : 0;
+
+      // 2. Calculate the desired new state
+      const potentialNewCredits = currentCredits + incrementValue;
+      const newCredits = Math.max(0, potentialNewCredits); // Ensure balance doesn't go below zero
+
+      // 3. Prepare the update payload
+      const updatePayload = {
+        $set: {
+          tokenCredits: newCredits,
+          ...(setValues || {}), // Merge other values to set
+        },
+      };
+
+      // 4. Attempt the conditional update or upsert
+      let updatedBalance = null;
+      if (currentBalanceDoc) {
+        // --- Document Exists: Perform Conditional Update ---
+        // Try to update only if the tokenCredits match the value we read (currentCredits)
+        updatedBalance = await Balance.findOneAndUpdate(
+          {
+            user: user,
+            tokenCredits: currentCredits, // Optimistic lock: condition based on the read value
+          },
+          updatePayload,
+          {
+            new: true, // Return the modified document
+            // lean: true, // .lean() is applied after query execution in Mongoose >= 6
+          },
+        ).lean(); // Use lean() for plain JS object
+
+        if (updatedBalance) {
+          // Success! The update was applied based on the expected current state.
+          return updatedBalance;
+        }
+        // If updatedBalance is null, it means tokenCredits changed between read and write (conflict).
+        lastError = new Error(`Concurrency conflict for user ${user} on attempt ${attempt}.`);
+        // Proceed to retry logic below.
+      } else {
+        // --- Document Does Not Exist: Perform Conditional Upsert ---
+        // Try to insert the document, but only if it still doesn't exist.
+        // Using tokenCredits: {$exists: false} helps prevent race conditions where
+        // another process creates the doc between our findOne and findOneAndUpdate.
+        try {
+          updatedBalance = await Balance.findOneAndUpdate(
+            {
+              user: user,
+              // Attempt to match only if the document doesn't exist OR was just created
+              // without tokenCredits (less likely but possible). A simple { user } filter
+              // might also work, relying on the retry for conflicts.
+              // Let's use a simpler filter and rely on retry for races.
+              // tokenCredits: { $exists: false } // This condition might be too strict if doc exists with 0 credits
+            },
+            updatePayload,
+            {
+              upsert: true, // Create if doesn't exist
+              new: true, // Return the created/updated document
+              // setDefaultsOnInsert: true, // Ensure schema defaults are applied on insert
+              // lean: true,
+            },
+          ).lean();
+
+          if (updatedBalance) {
+            // Upsert succeeded (likely created the document)
+            return updatedBalance;
+          }
+          // If null, potentially a rare race condition during upsert. Retry should handle it.
+          lastError = new Error(
+            `Upsert race condition suspected for user ${user} on attempt ${attempt}.`,
+          );
+        } catch (error) {
+          if (error.code === 11000) {
+            // E11000 duplicate key error on index
+            // This means another process created the document *just* before our upsert.
+            // It's a concurrency conflict during creation. We should retry.
+            lastError = error; // Store the error
+            // Proceed to retry logic below.
+          } else {
+            // Different error, rethrow
+            throw error;
+          }
+        }
+      } // End if/else (document exists?)
+    } catch (error) {
+      // Catch errors from findOne or unexpected findOneAndUpdate errors
+      logger.error(`[updateBalance] Error during attempt ${attempt} for user ${user}:`, error);
+      lastError = error; // Store the error
+      // Consider stopping retries for non-transient errors, but for now, we retry.
+    }
+
+    // If we reached here, it means the update failed (conflict or error), wait and retry
+    if (attempt < maxRetries) {
+      const jitter = Math.random() * delay * 0.5; // Add jitter to delay
+      await new Promise((resolve) => setTimeout(resolve, delay + jitter));
+      delay = Math.min(delay * 2, 2000); // Exponential backoff with cap
+    }
+  } // End for loop (retries)
+
+  // If loop finishes without success, throw the last encountered error or a generic one
+  logger.error(
+    `[updateBalance] Failed to update balance for user ${user} after ${maxRetries} attempts.`,
+  );
+  throw (
+    lastError ||
+    new Error(
+      `Failed to update balance for user ${user} after maximum retries due to persistent conflicts.`,
+    )
+  );
 };
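A minimal usage sketch of the helper above (illustrative; assumes the models are loaded and a Mongo connection is open, and `someUserId` is a hypothetical user id):

// Deduct 500 credits; the helper retries internally on write conflicts,
// clamps the result at zero, and throws if conflicts persist past maxRetries.
const updated = await updateBalance({
  user: someUserId, // hypothetical ObjectId or string id
  incrementValue: -500,
  setValues: { lastRefill: new Date() }, // optional extra fields to $set
});
console.log(updated.tokenCredits);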

/** Method to calculate and set the tokenValue for a transaction */
@@ -5,6 +5,10 @@ const { getMultiplier } = require('./tx');
 const { logger } = require('~/config');
 const Balance = require('./Balance');

+function isInvalidDate(date) {
+  return isNaN(date);
+}
+
 /**
  * Simple check method that calculates token cost and returns balance info.
  * The auto-refill logic has been moved to balanceMethods.js to prevent circular dependencies.
@@ -48,13 +52,12 @@ const checkBalanceRecord = async function ({
   // Only perform auto-refill if spending would bring the balance to 0 or below
   if (balance - tokenCost <= 0 && record.autoRefillEnabled && record.refillAmount > 0) {
     const lastRefillDate = new Date(record.lastRefill);
-    const nextRefillDate = addIntervalToDate(
-      lastRefillDate,
-      record.refillIntervalValue,
-      record.refillIntervalUnit,
-    );
     const now = new Date();
-    if (now >= nextRefillDate) {
+    if (
+      isInvalidDate(lastRefillDate) ||
+      now >=
+        addIntervalToDate(lastRefillDate, record.refillIntervalValue, record.refillIntervalUnit)
+    ) {
       try {
         /** @type {{ rate: number, user: string, balance: number, transaction: import('@librechat/data-schemas').ITransaction}} */
         const result = await Transaction.createAutoRefillTransaction({
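The guard above now treats a missing or unparsable `lastRefill` as "refill due". A small illustration of why the `isInvalidDate` branch matters (sketch only, based on standard Date coercion):

const lastRefillDate = new Date(undefined); // e.g., record.lastRefill was never set
console.log(isNaN(lastRefillDate)); // true — the Date coerces to NaN
// Without the isInvalidDate() check, `now >= addIntervalToDate(lastRefillDate, ...)`
// would always be false for such records (NaN comparisons are false),
// so a balance record with no lastRefill would never auto-refill.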
@@ -1,6 +1,7 @@
 const _ = require('lodash');
 const mongoose = require('mongoose');
 const { MeiliSearch } = require('meilisearch');
+const { parseTextParts, ContentTypes } = require('librechat-data-provider');
 const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc');
 const logger = require('~/config/meiliLogger');

@@ -238,10 +239,7 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) {
     }

     if (object.content && Array.isArray(object.content)) {
-      object.text = object.content
-        .filter((item) => item.type === 'text' && item.text && item.text.value)
-        .map((item) => item.text.value)
-        .join(' ');
+      object.text = parseTextParts(object.content);
       delete object.content;
     }
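`parseTextParts` stands in for the removed inline filter/map/join. A sketch of the equivalent behavior, reconstructed from the deleted lines (the helper's exact options are not shown in this diff):

// Roughly what the removed inline code did: keep text parts and join their values.
const content = [
  { type: 'text', text: { value: 'Hello' } },
  { type: 'image_url', image_url: { url: 'https://example.com/x.png' } }, // skipped
  { type: 'text', text: { value: 'world' } },
];
const text = content
  .filter((item) => item.type === 'text' && item.text && item.text.value)
  .map((item) => item.text.value)
  .join(' ');
// => 'Hello world'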
@@ -459,7 +459,7 @@ describe('spendTokens', () => {

  it('should handle multiple concurrent transactions correctly with a high balance', async () => {
    // Create a balance with a high amount
-    const initialBalance = 1000000;
+    const initialBalance = 10000000;
    await Balance.create({
      user: userId,
      tokenCredits: initialBalance,
@@ -470,8 +470,9 @@ describe('spendTokens', () => {
    const context = 'message';
    const model = 'gpt-4';

-    // Create 10 usage records to simulate multiple transactions
-    const collectedUsage = Array.from({ length: 10 }, (_, i) => ({
+    const amount = 50;
+    // Create `amount` of usage records to simulate multiple transactions
+    const collectedUsage = Array.from({ length: amount }, (_, i) => ({
      model,
      input_tokens: 100 + i * 10, // Increasing input tokens
      output_tokens: 50 + i * 5, // Increasing output tokens
@@ -591,6 +592,80 @@ describe('spendTokens', () => {
    expect(Math.abs(totalTokenValue)).toBeCloseTo(actualSpend, -3); // Allow for larger differences
  });

+  // Add this new test case
+  it('should handle multiple concurrent balance increases correctly', async () => {
+    // Start with zero balance
+    const initialBalance = 0;
+    await Balance.create({
+      user: userId,
+      tokenCredits: initialBalance,
+    });
+
+    const numberOfRefills = 25;
+    const refillAmount = 1000;
+
+    const promises = [];
+    for (let i = 0; i < numberOfRefills; i++) {
+      promises.push(
+        Transaction.createAutoRefillTransaction({
+          user: userId,
+          tokenType: 'credits',
+          context: 'concurrent-refill-test',
+          rawAmount: refillAmount,
+        }),
+      );
+    }
+
+    // Wait for all refill transactions to complete
+    const results = await Promise.all(promises);
+
+    // Verify final balance
+    const finalBalance = await Balance.findOne({ user: userId });
+    expect(finalBalance).toBeDefined();
+
+    // The final balance should be the initial balance plus the sum of all refills
+    const expectedFinalBalance = initialBalance + numberOfRefills * refillAmount;
+
+    console.log('Initial balance (Increase Test):', initialBalance);
+    console.log(`Performed ${numberOfRefills} refills of ${refillAmount} each.`);
+    console.log('Expected final balance (Increase Test):', expectedFinalBalance);
+    console.log('Actual final balance (Increase Test):', finalBalance.tokenCredits);
+
+    // Use toBeCloseTo for safety, though toBe should work for integer math
+    expect(finalBalance.tokenCredits).toBeCloseTo(expectedFinalBalance, 0);
+
+    // Verify all transactions were created
+    const transactions = await Transaction.find({
+      user: userId,
+      context: 'concurrent-refill-test',
+    });
+
+    // We should have one transaction for each refill attempt
+    expect(transactions.length).toBe(numberOfRefills);
+
+    // Optional: Verify the sum of increments from the results matches the balance change
+    const totalIncrementReported = results.reduce((sum, result) => {
+      // Assuming createAutoRefillTransaction returns an object with the increment amount
+      // Adjust this based on the actual return structure.
+      // Let's assume it returns { balance: newBalance, transaction: { rawAmount: ... } }
+      // Or perhaps we check the transaction.rawAmount directly
+      return sum + (result?.transaction?.rawAmount || 0);
+    }, 0);
+    console.log('Total increment reported by results:', totalIncrementReported);
+    expect(totalIncrementReported).toBe(expectedFinalBalance - initialBalance);
+
+    // Optional: Check the sum of tokenValue from saved transactions
+    let totalTokenValueFromDb = 0;
+    transactions.forEach((tx) => {
+      // For refills, rawAmount is positive, and tokenValue might be calculated based on it
+      // Let's assume tokenValue directly reflects the increment for simplicity here
+      // If calculation is involved, adjust accordingly
+      totalTokenValueFromDb += tx.rawAmount; // Or tx.tokenValue if that holds the increment
+    });
+    console.log('Total rawAmount from DB transactions:', totalTokenValueFromDb);
+    expect(totalTokenValueFromDb).toBeCloseTo(expectedFinalBalance - initialBalance, 0);
+  });
+
  it('should create structured transactions for both prompt and completion tokens', async () => {
    // Create a balance for the user
    await Balance.create({
@@ -107,8 +107,10 @@ const tokenValues = Object.assign(
  so this was from https://artificialanalysis.ai/models/command-light/providers */
  command: { prompt: 0.38, completion: 0.38 },
+  'gemini-2.0-flash-lite': { prompt: 0.075, completion: 0.3 },
-  'gemini-2.0-flash': { prompt: 0.1, completion: 0.7 },
+  'gemini-2.0-flash': { prompt: 0.1, completion: 0.4 },
  'gemini-2.0': { prompt: 0, completion: 0 }, // https://ai.google.dev/pricing
+  'gemini-2.5-pro-preview-03-25': { prompt: 1.25, completion: 10 },
  'gemini-2.5': { prompt: 0, completion: 0 }, // Free for a period of time
  'gemini-1.5-flash-8b': { prompt: 0.075, completion: 0.3 },
  'gemini-1.5-flash': { prompt: 0.15, completion: 0.6 },
  'gemini-1.5': { prompt: 2.5, completion: 10 },
@@ -122,6 +124,12 @@ const tokenValues = Object.assign(
  'grok-2-latest': { prompt: 2.0, completion: 10.0 },
  'grok-2': { prompt: 2.0, completion: 10.0 },
  'grok-beta': { prompt: 5.0, completion: 15.0 },
+  'mistral-large': { prompt: 2.0, completion: 6.0 },
+  'pixtral-large': { prompt: 2.0, completion: 6.0 },
+  'mistral-saba': { prompt: 0.2, completion: 0.6 },
+  codestral: { prompt: 0.3, completion: 0.9 },
+  'ministral-8b': { prompt: 0.1, completion: 0.1 },
+  'ministral-3b': { prompt: 0.04, completion: 0.04 },
  },
  bedrockValues,
);
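For orientation, each entry pairs a `prompt` and `completion` rate for a model key. A hedged sketch of how such rates combine with token counts (the units and the actual multiplier application live elsewhere in `tx.js` and are not shown in this hunk):

// Illustrative only: look up a model's rates and weight them by token usage.
const rates = { prompt: 0.1, completion: 0.4 }; // e.g., 'gemini-2.0-flash' above
const usage = { input_tokens: 1200, output_tokens: 300 };
const weightedCost = usage.input_tokens * rates.prompt + usage.output_tokens * rates.completion;
// weightedCost is a relative figure; conversion to token credits is handled by
// getMultiplier and the transaction methods defined in this module.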
@@ -44,12 +44,12 @@
    "@googleapis/youtube": "^20.0.0",
    "@keyv/mongo": "^2.1.8",
    "@keyv/redis": "^2.8.1",
-    "@langchain/community": "^0.3.34",
-    "@langchain/core": "^0.3.40",
-    "@langchain/google-genai": "^0.1.11",
-    "@langchain/google-vertexai": "^0.2.2",
+    "@langchain/community": "^0.3.39",
+    "@langchain/core": "^0.3.43",
+    "@langchain/google-genai": "^0.2.2",
+    "@langchain/google-vertexai": "^0.2.3",
    "@langchain/textsplitters": "^0.1.0",
-    "@librechat/agents": "^2.3.94",
+    "@librechat/agents": "^2.4.12",
    "@librechat/data-schemas": "*",
    "@waylaidwanderer/fetch-event-source": "^3.0.1",
    "@microsoft/microsoft-graph-client": "^3.0.7",
@@ -105,7 +105,7 @@
    "passport-ldapauth": "^3.0.1",
    "passport-local": "^1.0.0",
    "rate-limit-redis": "^4.2.0",
-    "sharp": "^0.32.6",
+    "sharp": "^0.33.5",
    "tiktoken": "^1.0.15",
    "traverse": "^0.6.7",
    "ua-parser-js": "^1.0.36",
@@ -108,7 +108,7 @@ const getAvailableTools = async (req, res) => {
  const pluginManifest = availableTools;
  const customConfig = await getCustomConfig();
  if (customConfig?.mcpServers != null) {
-    const mcpManager = await getMCPManager();
+    const mcpManager = getMCPManager();
    await mcpManager.loadManifestTools(pluginManifest);
  }
@@ -1,6 +1,8 @@
 const { FileSources } = require('librechat-data-provider');
 const {
+  Balance,
   getFiles,
   updateUser,
   deleteFiles,
   deleteConvos,
   deletePresets,
@@ -12,6 +14,7 @@ const User = require('~/models/User');
 const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
 const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService');
 const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
+const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
 const { processDeleteRequest } = require('~/server/services/Files/process');
 const { deleteAllSharedLinks } = require('~/models/Share');
 const { deleteToolCalls } = require('~/models/ToolCall');
@@ -19,8 +22,23 @@ const { Transaction } = require('~/models/Transaction');
 const { logger } = require('~/config');

 const getUserController = async (req, res) => {
   /** @type {MongoUser} */
   const userData = req.user.toObject != null ? req.user.toObject() : { ...req.user };
   delete userData.totpSecret;
+  if (req.app.locals.fileStrategy === FileSources.s3 && userData.avatar) {
+    const avatarNeedsRefresh = needsRefresh(userData.avatar, 3600);
+    if (!avatarNeedsRefresh) {
+      return res.status(200).send(userData);
+    }
+    const originalAvatar = userData.avatar;
+    try {
+      userData.avatar = await getNewS3URL(userData.avatar);
+      await updateUser(userData.id, { avatar: userData.avatar });
+    } catch (error) {
+      userData.avatar = originalAvatar;
+      logger.error('Error getting new S3 URL for avatar:', error);
+    }
+  }
   res.status(200).send(userData);
 };
@@ -20,11 +20,9 @@
 const {
   Constants,
   VisionModes,
-  openAISchema,
   ContentTypes,
   EModelEndpoint,
   KnownEndpoints,
-  anthropicSchema,
   isAgentsEndpoint,
   AgentCapabilities,
   bedrockInputSchema,
@@ -43,11 +41,18 @@ const { createRun } = require('./run');
 /** @typedef {import('@librechat/agents').MessageContentComplex} MessageContentComplex */
 /** @typedef {import('@langchain/core/runnables').RunnableConfig} RunnableConfig */

-const providerParsers = {
-  [EModelEndpoint.openAI]: openAISchema.parse,
-  [EModelEndpoint.azureOpenAI]: openAISchema.parse,
-  [EModelEndpoint.anthropic]: anthropicSchema.parse,
-  [EModelEndpoint.bedrock]: bedrockInputSchema.parse,
+/**
+ * @param {ServerRequest} req
+ * @param {Agent} agent
+ * @param {string} endpoint
+ */
+const payloadParser = ({ req, agent, endpoint }) => {
+  if (isAgentsEndpoint(endpoint)) {
+    return { model: undefined };
+  } else if (endpoint === EModelEndpoint.bedrock) {
+    return bedrockInputSchema.parse(agent.model_parameters);
+  }
+  return req.body.endpointOption.model_parameters;
+};
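A hedged sketch of how `payloadParser` resolves save options per endpoint (the `req` shape is assumed from its use of `req.body.endpointOption`, and `isAgentsEndpoint` is assumed to accept the agents endpoint constant):

// Agents endpoint: the model is intentionally undefined in saved options.
payloadParser({ req, agent, endpoint: EModelEndpoint.agents }); // => { model: undefined }
// Bedrock: model_parameters are validated/normalized through the zod schema.
payloadParser({ req, agent, endpoint: EModelEndpoint.bedrock }); // => bedrockInputSchema.parse(agent.model_parameters)
// Any other endpoint: pass through whatever the request's endpointOption carried.
payloadParser({ req, agent, endpoint: 'openAI' }); // => req.body.endpointOption.model_parameters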

const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deepseek]);
@@ -180,28 +185,19 @@ class AgentClient extends BaseClient {
  }

  getSaveOptions() {
-    const parseOptions = providerParsers[this.options.endpoint];
-    let runOptions =
-      this.options.endpoint === EModelEndpoint.agents
-        ? {
-            model: undefined,
-            // TODO:
-            // would need to be override settings; otherwise, model needs to be undefined
-            // model: this.override.model,
-            // instructions: this.override.instructions,
-            // additional_instructions: this.override.additional_instructions,
-          }
-        : {};
-
-    if (parseOptions) {
-      try {
-        runOptions = parseOptions(this.options.agent.model_parameters);
-      } catch (error) {
-        logger.error(
-          '[api/server/controllers/agents/client.js #getSaveOptions] Error parsing options',
-          error,
-        );
-      }
+    // TODO:
+    // would need to be override settings; otherwise, model needs to be undefined
+    // model: this.override.model,
+    // instructions: this.override.instructions,
+    // additional_instructions: this.override.additional_instructions,
+    let runOptions = {};
+    try {
+      runOptions = payloadParser(this.options);
+    } catch (error) {
+      logger.error(
+        '[api/server/controllers/agents/client.js #getSaveOptions] Error parsing options',
+        error,
+      );
+    }

    return removeNullishValues(
@@ -932,7 +928,14 @@ class AgentClient extends BaseClient {
    };
    let endpointConfig = this.options.req.app.locals[this.options.agent.endpoint];
    if (!endpointConfig) {
-      endpointConfig = await getCustomEndpointConfig(this.options.agent.endpoint);
+      try {
+        endpointConfig = await getCustomEndpointConfig(this.options.agent.endpoint);
+      } catch (err) {
+        logger.error(
+          '[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config',
+          err,
+        );
+      }
    }
    if (
      endpointConfig &&
@@ -11,6 +11,13 @@ const { providerEndpointMap, KnownEndpoints } = require('librechat-data-provider
 * @typedef {import('@librechat/agents').IState} IState
 */

+const customProviders = new Set([
+  Providers.XAI,
+  Providers.OLLAMA,
+  Providers.DEEPSEEK,
+  Providers.OPENROUTER,
+]);
+
/**
 * Creates a new Run instance with custom handlers and configuration.
 *
@@ -43,6 +50,15 @@ async function createRun({
    agent.model_parameters,
  );

+  /** Resolves issues with new OpenAI usage field */
+  if (
+    customProviders.has(agent.provider) ||
+    (agent.provider === Providers.OPENAI && agent.endpoint !== agent.provider)
+  ) {
+    llmConfig.streamUsage = false;
+    llmConfig.usage = true;
+  }
+
  /** @type {'reasoning_content' | 'reasoning'} */
  let reasoningKey;
  if (
@@ -4,6 +4,7 @@
  Tools,
  Constants,
  FileContext,
+  FileSources,
  SystemRoles,
  EToolResources,
  actionDelimiter,
@@ -17,9 +18,10 @@
 } = require('~/models/Agent');
 const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process');
 const { getStrategyFunctions } = require('~/server/services/Files/strategies');
+const { refreshS3Url } = require('~/server/services/Files/S3/crud');
 const { updateAction, getActions } = require('~/models/Action');
-const { getProjectByName } = require('~/models/Project');
 const { updateAgentProjects } = require('~/models/Agent');
+const { getProjectByName } = require('~/models/Project');
 const { deleteFileByFilter } = require('~/models/File');
 const { logger } = require('~/config');

@@ -102,6 +104,14 @@ const getAgentHandler = async (req, res) => {
    return res.status(404).json({ error: 'Agent not found' });
  }

+  if (agent.avatar && agent.avatar?.source === FileSources.s3) {
+    const originalUrl = agent.avatar.filepath;
+    agent.avatar.filepath = await refreshS3Url(agent.avatar);
+    if (originalUrl !== agent.avatar.filepath) {
+      await updateAgent({ id }, { avatar: agent.avatar });
+    }
+  }
+
  agent.author = agent.author.toString();
  agent.isCollaborative = !!agent.isCollaborative;
@@ -148,6 +148,13 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
  return { abortController, onStart };
};

+/**
+ * @param {ServerResponse} res
+ * @param {ServerRequest} req
+ * @param {Error | unknown} error
+ * @param {Partial<TMessage> & { partialText?: string }} data
+ * @returns { Promise<void> }
+ */
const handleAbortError = async (res, req, error, data) => {
  if (error?.message?.includes('base64')) {
    logger.error('[handleAbortError] Error in base64 encoding', {
@@ -178,17 +185,30 @@ const handleAbortError = async (res, req, error, data) => {
    errorText = `{"type":"${ErrorTypes.NO_SYSTEM_MESSAGES}"}`;
  }

+  /**
+   * @param {string} partialText
+   * @returns {Promise<void>}
+   */
  const respondWithError = async (partialText) => {
+    const endpointOption = req.body?.endpointOption;
    let options = {
      sender,
      messageId,
      conversationId,
      parentMessageId,
      text: errorText,
-      shouldSaveMessage: true,
      user: req.user.id,
+      shouldSaveMessage: true,
+      spec: endpointOption?.spec,
+      iconURL: endpointOption?.iconURL,
+      modelLabel: endpointOption?.modelLabel,
+      model: endpointOption?.modelOptions?.model || req.body?.model,
    };

+    if (req.body?.agent_id) {
+      options.agent_id = req.body.agent_id;
+    }
+
    if (partialText) {
      options = {
        ...options,
@@ -1,6 +1,11 @@
-const { parseCompactConvo, EModelEndpoint, isAgentsEndpoint } = require('librechat-data-provider');
-const { getModelsConfig } = require('~/server/controllers/ModelController');
+const {
+  parseCompactConvo,
+  EModelEndpoint,
+  isAgentsEndpoint,
+  EndpointURLs,
+} = require('librechat-data-provider');
 const azureAssistants = require('~/server/services/Endpoints/azureAssistants');
+const { getModelsConfig } = require('~/server/controllers/ModelController');
 const assistants = require('~/server/services/Endpoints/assistants');
 const gptPlugins = require('~/server/services/Endpoints/gptPlugins');
 const { processFiles } = require('~/server/services/Files/process');

@@ -77,8 +82,9 @@ async function buildEndpointOption(req, res, next) {
  }

  try {
-    const isAgents = isAgentsEndpoint(endpoint);
-    const endpointFn = buildFunction[endpointType ?? endpoint];
+    const isAgents =
+      isAgentsEndpoint(endpoint) || req.baseUrl.startsWith(EndpointURLs[EModelEndpoint.agents]);
+    const endpointFn = buildFunction[isAgents ? EModelEndpoint.agents : (endpointType ?? endpoint)];
    const builder = isAgents ? (...args) => endpointFn(req, ...args) : endpointFn;

    // TODO: use object params
@@ -8,6 +8,7 @@ const concurrentLimiter = require('./concurrentLimiter');
 const validateEndpoint = require('./validateEndpoint');
 const requireLocalAuth = require('./requireLocalAuth');
 const canDeleteAccount = require('./canDeleteAccount');
+const setBalanceConfig = require('./setBalanceConfig');
 const requireLdapAuth = require('./requireLdapAuth');
 const abortMiddleware = require('./abortMiddleware');
 const checkInviteUser = require('./checkInviteUser');

@@ -41,6 +42,7 @@ module.exports = {
  requireLocalAuth,
  canDeleteAccount,
  validateEndpoint,
+  setBalanceConfig,
  concurrentLimiter,
  checkDomainAllowed,
  validateMessageReq,
@@ -1,39 +1,41 @@
 const axios = require('axios');
 const { ErrorTypes } = require('librechat-data-provider');
+const { isEnabled } = require('~/server/utils');
 const denyRequest = require('./denyRequest');
 const { logger } = require('~/config');

 async function moderateText(req, res, next) {
-  if (process.env.OPENAI_MODERATION === 'true') {
-    try {
-      const { text } = req.body;
+  if (!isEnabled(process.env.OPENAI_MODERATION)) {
+    return next();
+  }
+  try {
+    const { text } = req.body;

-      const response = await axios.post(
-        process.env.OPENAI_MODERATION_REVERSE_PROXY || 'https://api.openai.com/v1/moderations',
-        {
-          input: text,
-        },
-        {
-          headers: {
-            'Content-Type': 'application/json',
-            Authorization: `Bearer ${process.env.OPENAI_MODERATION_API_KEY}`,
-          },
-        },
-      );
+    const response = await axios.post(
+      process.env.OPENAI_MODERATION_REVERSE_PROXY || 'https://api.openai.com/v1/moderations',
+      {
+        input: text,
+      },
+      {
+        headers: {
+          'Content-Type': 'application/json',
+          Authorization: `Bearer ${process.env.OPENAI_MODERATION_API_KEY}`,
+        },
+      },
+    );

-      const results = response.data.results;
-      const flagged = results.some((result) => result.flagged);
+    const results = response.data.results;
+    const flagged = results.some((result) => result.flagged);

-      if (flagged) {
-        const type = ErrorTypes.MODERATION;
-        const errorMessage = { type };
-        return await denyRequest(req, res, errorMessage);
-      }
-    } catch (error) {
-      logger.error('Error in moderateText:', error);
-      const errorMessage = 'error in moderation check';
-      return await denyRequest(req, res, errorMessage);
-    }
+    if (flagged) {
+      const type = ErrorTypes.MODERATION;
+      const errorMessage = { type };
+      return await denyRequest(req, res, errorMessage);
+    }
+  } catch (error) {
+    logger.error('Error in moderateText:', error);
+    const errorMessage = 'error in moderation check';
+    return await denyRequest(req, res, errorMessage);
  }
  next();
}
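For context, the OpenAI moderation endpoint returns an object whose `results` array carries a per-input `flagged` boolean, which is what the `some()` check above inspects. A small illustration (the payload below is a hand-written sample, not an actual API response):

const sampleResponse = {
  results: [
    { flagged: false }, // first input passed moderation
    { flagged: true }, // second input was flagged
  ],
};
const flagged = sampleResponse.results.some((result) => result.flagged);
// => true, so the middleware would deny the request with ErrorTypes.MODERATION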
@@ -17,9 +17,9 @@ const checkAccess = async (user, permissionType, permissions, bodyProps = {}, ch
  }

  const role = await getRoleByName(user.role);
-  if (role && role[permissionType]) {
+  if (role && role.permissions && role.permissions[permissionType]) {
    const hasAnyPermission = permissions.some((permission) => {
-      if (role[permissionType][permission]) {
+      if (role.permissions[permissionType][permission]) {
        return true;
      }
api/server/middleware/setBalanceConfig.js (91 lines, Normal file)
@@ -0,0 +1,91 @@
const { getBalanceConfig } = require('~/server/services/Config');
const Balance = require('~/models/Balance');
const { logger } = require('~/config');

/**
 * Middleware to synchronize user balance settings with current balance configuration.
 * @function
 * @param {Object} req - Express request object containing user information.
 * @param {Object} res - Express response object.
 * @param {import('express').NextFunction} next - Next middleware function.
 */
const setBalanceConfig = async (req, res, next) => {
  try {
    const balanceConfig = await getBalanceConfig();
    if (!balanceConfig?.enabled) {
      return next();
    }
    if (balanceConfig.startBalance == null) {
      return next();
    }

    const userId = req.user._id;
    const userBalanceRecord = await Balance.findOne({ user: userId }).lean();
    const updateFields = buildUpdateFields(balanceConfig, userBalanceRecord);

    if (Object.keys(updateFields).length === 0) {
      return next();
    }

    await Balance.findOneAndUpdate(
      { user: userId },
      { $set: updateFields },
      { upsert: true, new: true },
    );

    next();
  } catch (error) {
    logger.error('Error setting user balance:', error);
    next(error);
  }
};

/**
 * Build an object containing fields that need updating
 * @param {Object} config - The balance configuration
 * @param {Object|null} userRecord - The user's current balance record, if any
 * @returns {Object} Fields that need updating
 */
function buildUpdateFields(config, userRecord) {
  const updateFields = {};

  // Ensure user record has the required fields
  if (!userRecord) {
    updateFields.user = userRecord?.user;
    updateFields.tokenCredits = config.startBalance;
  }

  if (userRecord?.tokenCredits == null && config.startBalance != null) {
    updateFields.tokenCredits = config.startBalance;
  }

  const isAutoRefillConfigValid =
    config.autoRefillEnabled &&
    config.refillIntervalValue != null &&
    config.refillIntervalUnit != null &&
    config.refillAmount != null;

  if (!isAutoRefillConfigValid) {
    return updateFields;
  }

  if (userRecord?.autoRefillEnabled !== config.autoRefillEnabled) {
    updateFields.autoRefillEnabled = config.autoRefillEnabled;
  }

  if (userRecord?.refillIntervalValue !== config.refillIntervalValue) {
    updateFields.refillIntervalValue = config.refillIntervalValue;
  }

  if (userRecord?.refillIntervalUnit !== config.refillIntervalUnit) {
    updateFields.refillIntervalUnit = config.refillIntervalUnit;
  }

  if (userRecord?.refillAmount !== config.refillAmount) {
    updateFields.refillAmount = config.refillAmount;
  }

  return updateFields;
}

module.exports = setBalanceConfig;
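A sketch of where this middleware slots in, mirroring the route wiring elsewhere in this changeset (the '/login' path and the exact middleware order here are assumptions; the auth routes hunk below shows the real insertion point, including an LDAP/local switch):

// Sync the user's Balance record with configured start balance and
// auto-refill settings right after authentication, before the controller runs.
router.post('/login', loginLimiter, checkBan, requireLocalAuth, setBalanceConfig, loginController);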
@@ -20,7 +20,7 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
  const { action_id } = req.params;
  const { code, state } = req.query;

-  const flowManager = await getFlowStateManager(getLogStores);
+  const flowManager = getFlowStateManager(getLogStores);
  let identifier = action_id;
  try {
    let decodedState;
@@ -3,6 +3,7 @@ const { PermissionTypes, Permissions } = require('librechat-data-provider');
 const {
   setHeaders,
   handleAbort,
+  moderateText,
   // validateModel,
   generateCheckAccess,
   validateConvoAccess,
@@ -14,28 +15,38 @@ const addTitle = require('~/server/services/Endpoints/agents/title');

 const router = express.Router();

+router.use(moderateText);
 router.post('/abort', handleAbort());

 const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);

+router.use(checkAgentAccess);
+router.use(validateConvoAccess);
+router.use(buildEndpointOption);
+router.use(setHeaders);
+
+const controller = async (req, res, next) => {
+  await AgentController(req, res, next, initializeClient, addTitle);
+};
+
 /**
- * @route POST /
+ * @route POST / (regular endpoint)
  * @desc Chat with an assistant
  * @access Public
  * @param {express.Request} req - The request object, containing the request data.
  * @param {express.Response} res - The response object, used to send back a response.
  * @returns {void}
  */
-router.post(
-  '/',
-  // validateModel,
-  checkAgentAccess,
-  validateConvoAccess,
-  buildEndpointOption,
-  setHeaders,
-  async (req, res, next) => {
-    await AgentController(req, res, next, initializeClient, addTitle);
-  },
-);
+router.post('/', controller);

+/**
+ * @route POST /:endpoint (ephemeral agents)
+ * @desc Chat with an assistant
+ * @access Public
+ * @param {express.Request} req - The request object, containing the request data.
+ * @param {express.Response} res - The response object, used to send back a response.
+ * @returns {void}
+ */
+router.post('/:endpoint', controller);

 module.exports = router;
@@ -1,21 +1,40 @@
 const express = require('express');
-const router = express.Router();
 const {
   uaParser,
   checkBan,
   requireJwtAuth,
-  // concurrentLimiter,
-  // messageIpLimiter,
-  // messageUserLimiter,
+  messageIpLimiter,
+  concurrentLimiter,
+  messageUserLimiter,
 } = require('~/server/middleware');
+const { isEnabled } = require('~/server/utils');
 const { v1 } = require('./v1');
 const chat = require('./chat');

+const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
+
+const router = express.Router();
+
 router.use(requireJwtAuth);
 router.use(checkBan);
 router.use(uaParser);

 router.use('/', v1);
-router.use('/chat', chat);
+
+const chatRouter = express.Router();
+if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
+  chatRouter.use(concurrentLimiter);
+}
+
+if (isEnabled(LIMIT_MESSAGE_IP)) {
+  chatRouter.use(messageIpLimiter);
+}
+
+if (isEnabled(LIMIT_MESSAGE_USER)) {
+  chatRouter.use(messageUserLimiter);
+}
+
+chatRouter.use('/', chat);
+router.use('/chat', chatRouter);

 module.exports = router;
@@ -1,10 +1,4 @@
 const express = require('express');
-const openAI = require('./openAI');
-const custom = require('./custom');
-const google = require('./google');
-const anthropic = require('./anthropic');
-const gptPlugins = require('./gptPlugins');
-const { isEnabled } = require('~/server/utils');
-const { EModelEndpoint } = require('librechat-data-provider');
 const {
   uaParser,
@@ -15,6 +9,12 @@
   messageUserLimiter,
   validateConvoAccess,
 } = require('~/server/middleware');
+const { isEnabled } = require('~/server/utils');
+const gptPlugins = require('./gptPlugins');
+const anthropic = require('./anthropic');
+const custom = require('./custom');
+const google = require('./google');
+const openAI = require('./openAI');

 const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
@@ -23,6 +23,7 @@
  checkInviteUser,
  registerLimiter,
  requireLdapAuth,
+  setBalanceConfig,
  requireLocalAuth,
  resetPasswordLimiter,
  validateRegistration,
@@ -40,6 +41,7 @@ router.post(
  loginLimiter,
  checkBan,
  ldapAuth ? requireLdapAuth : requireLocalAuth,
+  setBalanceConfig,
  loginController,
);
router.post('/refresh', refreshController);
@@ -4,6 +4,7 @@ const router = express.Router();
 const {
   setHeaders,
   handleAbort,
+  moderateText,
   // validateModel,
   // validateEndpoint,
   buildEndpointOption,
@@ -12,6 +13,7 @@ const { initializeClient } = require('~/server/services/Endpoints/bedrock');
 const AgentController = require('~/server/controllers/agents/request');
 const addTitle = require('~/server/services/Endpoints/agents/title');

+router.use(moderateText);
 router.post('/abort', handleAbort());

 /**
@@ -1,19 +1,35 @@
 const express = require('express');
-const router = express.Router();
 const {
   uaParser,
   checkBan,
   requireJwtAuth,
-  // concurrentLimiter,
-  // messageIpLimiter,
-  // messageUserLimiter,
+  messageIpLimiter,
+  concurrentLimiter,
+  messageUserLimiter,
 } = require('~/server/middleware');
+const { isEnabled } = require('~/server/utils');
 const chat = require('./chat');

+const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
+
+const router = express.Router();
+
 router.use(requireJwtAuth);
 router.use(checkBan);
 router.use(uaParser);

+if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
+  router.use(concurrentLimiter);
+}
+
+if (isEnabled(LIMIT_MESSAGE_IP)) {
+  router.use(messageIpLimiter);
+}
+
+if (isEnabled(LIMIT_MESSAGE_USER)) {
+  router.use(messageUserLimiter);
+}
+
 router.use('/chat', chat);

 module.exports = router;
@@ -2,7 +2,9 @@ const fs = require('fs').promises;
 const express = require('express');
 const { EnvVar } = require('@librechat/agents');
 const {
+  Time,
   isUUID,
+  CacheKeys,
   FileSources,
   EModelEndpoint,
   isAgentsEndpoint,
@@ -17,8 +19,10 @@
 const { getStrategyFunctions } = require('~/server/services/Files/strategies');
 const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
 const { loadAuthValues } = require('~/server/services/Tools/credentials');
+const { refreshS3FileUrls } = require('~/server/services/Files/S3/crud');
+const { getFiles, batchUpdateFiles } = require('~/models/File');
 const { getAgent } = require('~/models/Agent');
-const { getFiles } = require('~/models/File');
 const { getLogStores } = require('~/cache');
 const { logger } = require('~/config');

 const router = express.Router();

@@ -26,6 +30,18 @@ router.get('/', async (req, res) => {
  try {
    const files = await getFiles({ user: req.user.id });
+    if (req.app.locals.fileStrategy === FileSources.s3) {
+      try {
+        const cache = getLogStores(CacheKeys.S3_EXPIRY_INTERVAL);
+        const alreadyChecked = await cache.get(req.user.id);
+        if (!alreadyChecked) {
+          await refreshS3FileUrls(files, batchUpdateFiles);
+          await cache.set(req.user.id, true, Time.THIRTY_MINUTES);
+        }
+      } catch (error) {
+        logger.warn('[/files] Error refreshing S3 file URLs:', error);
+      }
+    }
    res.status(200).send(files);
  } catch (error) {
    logger.error('[/files] Error getting files:', error);
@@ -1,7 +1,13 @@
 // file deepcode ignore NoRateLimitingForLogin: Rate limiting is handled by the `loginLimiter` middleware
 const express = require('express');
 const passport = require('passport');
-const { loginLimiter, logHeaders, checkBan, checkDomainAllowed } = require('~/server/middleware');
+const {
+  checkBan,
+  logHeaders,
+  loginLimiter,
+  setBalanceConfig,
+  checkDomainAllowed,
+} = require('~/server/middleware');
 const { setAuthTokens } = require('~/server/services/AuthService');
 const { logger } = require('~/config');

@@ -56,6 +62,7 @@ router.get(
    session: false,
    scope: ['openid', 'profile', 'email'],
  }),
+  setBalanceConfig,
  oauthHandler,
);

@@ -80,6 +87,7 @@ router.get(
    scope: ['public_profile'],
    profileFields: ['id', 'email', 'name'],
  }),
+  setBalanceConfig,
  oauthHandler,
);

@@ -100,6 +108,7 @@ router.get(
    failureMessage: true,
    session: false,
  }),
+  setBalanceConfig,
  oauthHandler,
);

@@ -122,6 +131,7 @@ router.get(
    session: false,
    scope: ['user:email', 'read:user'],
  }),
+  setBalanceConfig,
  oauthHandler,
);

@@ -144,6 +154,7 @@ router.get(
    session: false,
    scope: ['identify', 'email'],
  }),
+  setBalanceConfig,
  oauthHandler,
);

@@ -164,6 +175,7 @@ router.post(
    failureMessage: true,
    session: false,
  }),
+  setBalanceConfig,
  oauthHandler,
);
@@ -48,7 +48,7 @@ router.put('/:roleName/prompts', checkAdmin, async (req, res) => {
  const { roleName: _r } = req.params;
  // TODO: TEMP, use a better parsing for roleName
  const roleName = _r.toUpperCase();
-  /** @type {TRole['PROMPTS']} */
+  /** @type {TRole['permissions']['PROMPTS']} */
  const updates = req.body;

  try {
@@ -59,10 +59,16 @@ router.put('/:roleName/prompts', checkAdmin, async (req, res) => {
      return res.status(404).send({ message: 'Role not found' });
    }

+    const currentPermissions =
+      role.permissions?.[PermissionTypes.PROMPTS] || role[PermissionTypes.PROMPTS] || {};
+
    const mergedUpdates = {
-      [PermissionTypes.PROMPTS]: {
-        ...role[PermissionTypes.PROMPTS],
-        ...parsedUpdates,
+      permissions: {
+        ...role.permissions,
+        [PermissionTypes.PROMPTS]: {
+          ...currentPermissions,
+          ...parsedUpdates,
+        },
      },
    };

@@ -81,7 +87,7 @@ router.put('/:roleName/agents', checkAdmin, async (req, res) => {
  const { roleName: _r } = req.params;
  // TODO: TEMP, use a better parsing for roleName
  const roleName = _r.toUpperCase();
-  /** @type {TRole['AGENTS']} */
+  /** @type {TRole['permissions']['AGENTS']} */
  const updates = req.body;

  try {
@@ -92,17 +98,23 @@ router.put('/:roleName/agents', checkAdmin, async (req, res) => {
      return res.status(404).send({ message: 'Role not found' });
    }

+    const currentPermissions =
+      role.permissions?.[PermissionTypes.AGENTS] || role[PermissionTypes.AGENTS] || {};
+
    const mergedUpdates = {
-      [PermissionTypes.AGENTS]: {
-        ...role[PermissionTypes.AGENTS],
-        ...parsedUpdates,
+      permissions: {
+        ...role.permissions,
+        [PermissionTypes.AGENTS]: {
+          ...currentPermissions,
+          ...parsedUpdates,
+        },
      },
    };

    const updatedRole = await updateRoleByName(roleName, mergedUpdates);
    res.status(200).send(updatedRole);
  } catch (error) {
-    return res.status(400).send({ message: 'Invalid prompt permissions.', error: error.errors });
+    return res.status(400).send({ message: 'Invalid agent permissions.', error: error.errors });
  }
});
@@ -13,7 +13,6 @@ const {
  actionDomainSeparator,
} = require('librechat-data-provider');
const { refreshAccessToken } = require('~/server/services/TokenService');
-const { isActionDomainAllowed } = require('~/server/services/domains');
const { logger, getFlowStateManager, sendEvent } = require('~/config');
const { encryptV2, decryptV2 } = require('~/server/utils/crypto');
const { getActions, deleteActions } = require('~/models/Action');

@@ -130,6 +129,7 @@ async function loadActionSets(searchParams) {
 * @param {string | undefined} [params.name] - The name of the tool.
 * @param {string | undefined} [params.description] - The description for the tool.
 * @param {import('zod').ZodTypeAny | undefined} [params.zodSchema] - The Zod schema for tool input validation/definition
+ * @param {{ oauth_client_id?: string; oauth_client_secret?: string; }} params.encrypted - The encrypted values for the action.
 * @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
 */
async function createActionTool({
@@ -140,17 +140,8 @@ async function createActionTool({
  zodSchema,
  name,
  description,
+  encrypted,
}) {
-  const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain);
-  if (!isDomainAllowed) {
-    return null;
-  }
-  const encrypted = {
-    oauth_client_id: action.metadata.oauth_client_id,
-    oauth_client_secret: action.metadata.oauth_client_secret,
-  };
  action.metadata = await decryptMetadata(action.metadata);

  /** @type {(toolInput: Object | string, config: GraphRunnableConfig) => Promise<unknown>} */
  const _call = async (toolInput, config) => {
    try {
@@ -198,26 +189,32 @@ async function createActionTool({
          expires_at: Date.now() + Time.TWO_MINUTES,
        },
      };
-      const flowManager = await getFlowStateManager(getLogStores);
+      const flowManager = getFlowStateManager(getLogStores);
      await flowManager.createFlowWithHandler(
-        `${identifier}:login`,
+        `${identifier}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`,
        'oauth_login',
        async () => {
          sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
          logger.debug('Sent OAuth login request to client', { action_id, identifier });
          return true;
        },
+        config?.signal,
      );
      logger.debug('Waiting for OAuth Authorization response', { action_id, identifier });
-      const result = await flowManager.createFlow(identifier, 'oauth', {
-        state: stateToken,
-        userId: req.user.id,
-        client_url: metadata.auth.client_url,
-        redirect_uri: `${process.env.DOMAIN_CLIENT}/api/actions/${action_id}/oauth/callback`,
-        /** Encrypted values */
-        encrypted_oauth_client_id: encrypted.oauth_client_id,
-        encrypted_oauth_client_secret: encrypted.oauth_client_secret,
-      });
+      const result = await flowManager.createFlow(
+        identifier,
+        'oauth',
+        {
+          state: stateToken,
+          userId: req.user.id,
+          client_url: metadata.auth.client_url,
+          redirect_uri: `${process.env.DOMAIN_CLIENT}/api/actions/${action_id}/oauth/callback`,
+          /** Encrypted values */
+          encrypted_oauth_client_id: encrypted.oauth_client_id,
+          encrypted_oauth_client_secret: encrypted.oauth_client_secret,
+        },
+        config?.signal,
+      );
      logger.debug('Received OAuth Authorization response', { action_id, identifier });
      data.delta.auth = undefined;
      data.delta.expires_at = undefined;
@@ -268,11 +265,12 @@ async function createActionTool({
        encrypted_oauth_client_id: encrypted.oauth_client_id,
        encrypted_oauth_client_secret: encrypted.oauth_client_secret,
      });
-      const flowManager = await getFlowStateManager(getLogStores);
+      const flowManager = getFlowStateManager(getLogStores);
      const refreshData = await flowManager.createFlowWithHandler(
        `${identifier}:refresh`,
        'oauth_refresh',
        refreshTokens,
+        config?.signal,
      );
      metadata.oauth_access_token = refreshData.access_token;
      if (refreshData.refresh_token) {
@@ -308,9 +306,8 @@ async function createActionTool({
      }
      return response.data;
    } catch (error) {
-      const logMessage = `API call to ${action.metadata.domain} failed`;
-      logAxiosError({ message: logMessage, error });
-      throw error;
+      const message = `API call to ${action.metadata.domain} failed:`;
+      return logAxiosError({ message, error });
    }
  };

@@ -327,6 +324,27 @@ async function createActionTool({
  };
}

+/**
+ * Encrypts a sensitive value.
+ * @param {string} value
+ * @returns {Promise<string>}
+ */
+async function encryptSensitiveValue(value) {
+  // Encode API key to handle special characters like ":"
+  const encodedValue = encodeURIComponent(value);
+  return await encryptV2(encodedValue);
+}
+
+/**
+ * Decrypts a sensitive value.
+ * @param {string} value
+ * @returns {Promise<string>}
+ */
+async function decryptSensitiveValue(value) {
+  const decryptedValue = await decryptV2(value);
+  return decodeURIComponent(decryptedValue);
+}
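A quick round-trip sketch of the pair above (illustrative, to be run inside an async context; `encryptV2`/`decryptV2` are the crypto helpers imported at the top of this file):

// URI-encoding first means values containing ':' or other delimiters
// survive the encrypt/decrypt cycle intact.
const original = 'sk-abc:123';
const stored = await encryptSensitiveValue(original); // encodeURIComponent, then encryptV2
const restored = await decryptSensitiveValue(stored); // decryptV2, then decodeURIComponent
console.assert(restored === original);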

/**
 * Encrypts sensitive metadata values for an action.
 *
@@ -339,17 +357,19 @@ async function encryptMetadata(metadata) {
  // ServiceHttp
  if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) {
    if (metadata.api_key) {
-      encryptedMetadata.api_key = await encryptV2(metadata.api_key);
+      encryptedMetadata.api_key = await encryptSensitiveValue(metadata.api_key);
    }
  }

  // OAuth
  else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) {
    if (metadata.oauth_client_id) {
-      encryptedMetadata.oauth_client_id = await encryptV2(metadata.oauth_client_id);
+      encryptedMetadata.oauth_client_id = await encryptSensitiveValue(metadata.oauth_client_id);
    }
    if (metadata.oauth_client_secret) {
-      encryptedMetadata.oauth_client_secret = await encryptV2(metadata.oauth_client_secret);
+      encryptedMetadata.oauth_client_secret = await encryptSensitiveValue(
+        metadata.oauth_client_secret,
+      );
    }
  }

@@ -368,17 +388,19 @@ async function decryptMetadata(metadata) {
  // ServiceHttp
  if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) {
    if (metadata.api_key) {
-      decryptedMetadata.api_key = await decryptV2(metadata.api_key);
+      decryptedMetadata.api_key = await decryptSensitiveValue(metadata.api_key);
    }
  }

  // OAuth
  else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) {
    if (metadata.oauth_client_id) {
-      decryptedMetadata.oauth_client_id = await decryptV2(metadata.oauth_client_id);
+      decryptedMetadata.oauth_client_id = await decryptSensitiveValue(metadata.oauth_client_id);
    }
    if (metadata.oauth_client_secret) {
-      decryptedMetadata.oauth_client_secret = await decryptV2(metadata.oauth_client_secret);
+      decryptedMetadata.oauth_client_secret = await decryptSensitiveValue(
+        metadata.oauth_client_secret,
+      );
    }
  }
@@ -52,7 +52,7 @@ const AppService = async (app) => {

  if (fileStrategy === FileSources.firebase) {
    initializeFirebase();
-  } else if (fileStrategy === FileSources.azure) {
+  } else if (fileStrategy === FileSources.azure_blob) {
    initializeAzureBlobService();
  } else if (fileStrategy === FileSources.s3) {
    initializeS3();
@@ -66,7 +66,7 @@ const AppService = async (app) => {
  });

  if (config.mcpServers != null) {
-    const mcpManager = await getMCPManager();
+    const mcpManager = getMCPManager();
    await mcpManager.initializeMCP(config.mcpServers, processMCPEnv);
    await mcpManager.mapAvailableTools(availableTools);
  }
@@ -146,7 +146,7 @@ const AppService = async (app) => {
    ...defaultLocals,
    fileConfig: config?.fileConfig,
    secureImageLinks: config?.secureImageLinks,
-    modelSpecs: processModelSpecs(endpoints, config.modelSpecs),
+    modelSpecs: processModelSpecs(endpoints, config.modelSpecs, interfaceConfig),
    ...endpointLocals,
  };
};
@@ -91,7 +91,7 @@ const sendVerificationEmail = async (user) => {
    subject: 'Verify your email',
    payload: {
      appName: process.env.APP_TITLE || 'LibreChat',
-      name: user.name,
+      name: user.name || user.username || user.email,
      verificationLink: verificationLink,
      year: new Date().getFullYear(),
    },
@@ -278,7 +278,7 @@ const requestPasswordReset = async (req) => {
    subject: 'Password Reset Request',
    payload: {
      appName: process.env.APP_TITLE || 'LibreChat',
-      name: user.name,
+      name: user.name || user.username || user.email,
      link: link,
      year: new Date().getFullYear(),
    },
@@ -331,7 +331,7 @@ const resetPassword = async (userId, token, password) => {
    subject: 'Password Reset Successfully',
    payload: {
      appName: process.env.APP_TITLE || 'LibreChat',
-      name: user.name,
+      name: user.name || user.username || user.email,
      year: new Date().getFullYear(),
    },
    template: 'passwordReset.handlebars',
@@ -414,7 +414,7 @@ const resendVerificationEmail = async (req) => {
    subject: 'Verify your email',
    payload: {
      appName: process.env.APP_TITLE || 'LibreChat',
-      name: user.name,
+      name: user.name || user.username || user.email,
      verificationLink: verificationLink,
      year: new Date().getFullYear(),
    },

@@ -31,19 +31,16 @@ async function getCustomConfig() {
async function getBalanceConfig() {
const isLegacyEnabled = isEnabled(process.env.CHECK_BALANCE);
const startBalance = process.env.START_BALANCE;
if (isLegacyEnabled || (startBalance != null && startBalance)) {
/** @type {TCustomConfig['balance']} */
const config = {
enabled: isLegacyEnabled,
startBalance: startBalance ? parseInt(startBalance, 10) : undefined,
};
return config;
}
/** @type {TCustomConfig['balance']} */
const config = {
enabled: isLegacyEnabled,
startBalance: startBalance != null && startBalance ? parseInt(startBalance, 10) : undefined,
};
const customConfig = await getCustomConfig();
if (!customConfig) {
return null;
return config;
}
return customConfig?.['balance'] ?? null;
return { ...config, ...(customConfig?.['balance'] ?? {}) };
}

/**
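
Note on the hunk above: after this change, `getBalanceConfig` no longer returns `null` when no `librechat.yaml` is present; env-derived defaults are built first, and any `balance` block from the YAML config is spread over them, so YAML keys win on conflict. A minimal sketch of that precedence (a standalone, illustrative helper, not the project's API):

// Illustrative only: mirrors the merge order { ...fromEnv, ...fromYaml }.
function resolveBalanceConfig(env, yamlBalance) {
  const enabled = String(env.CHECK_BALANCE).toLowerCase() === 'true'; // legacy flag
  const startBalance =
    env.START_BALANCE != null && env.START_BALANCE !== ''
      ? parseInt(env.START_BALANCE, 10)
      : undefined;
  return { enabled, startBalance, ...(yamlBalance ?? {}) };
}

// Example: YAML overrides the deprecated env vars.
resolveBalanceConfig({ CHECK_BALANCE: 'true', START_BALANCE: '1000' }, { enabled: false });
// => { enabled: false, startBalance: 1000 }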

@@ -33,10 +33,12 @@ async function getEndpointsConfig(req) {
};
}
if (mergedConfig[EModelEndpoint.agents] && req.app.locals?.[EModelEndpoint.agents]) {
const { disableBuilder, capabilities, ..._rest } = req.app.locals[EModelEndpoint.agents];
const { disableBuilder, capabilities, allowedProviders, ..._rest } =
req.app.locals[EModelEndpoint.agents];

mergedConfig[EModelEndpoint.agents] = {
...mergedConfig[EModelEndpoint.agents],
allowedProviders,
disableBuilder,
capabilities,
};

@@ -1,12 +1,15 @@
const { isAgentsEndpoint, Constants } = require('librechat-data-provider');
const { loadAgent } = require('~/models/Agent');
const { logger } = require('~/config');

const buildOptions = (req, endpoint, parsedBody) => {
const buildOptions = (req, endpoint, parsedBody, endpointType) => {
const { spec, iconURL, agent_id, instructions, maxContextTokens, ...model_parameters } =
parsedBody;
const agentPromise = loadAgent({
req,
agent_id,
agent_id: isAgentsEndpoint(endpoint) ? agent_id : Constants.EPHEMERAL_AGENT_ID,
endpoint,
model_parameters,
}).catch((error) => {
logger.error(`[/agents/:${agent_id}] Error retrieving agent during build options step`, error);
return undefined;

@@ -17,6 +20,7 @@ const buildOptions = (req, endpoint, parsedBody) => {
iconURL,
endpoint,
agent_id,
endpointType,
instructions,
maxContextTokens,
model_parameters,

@@ -1,5 +1,7 @@
const { createContentAggregator, Providers } = require('@librechat/agents');
const {
Constants,
ErrorTypes,
EModelEndpoint,
getResponseSender,
AgentCapabilities,

@@ -117,6 +119,7 @@ function optionalChainWithEmptyCheck(...values) {
* @param {ServerRequest} params.req
* @param {ServerResponse} params.res
* @param {Agent} params.agent
* @param {Set<string>} [params.allowedProviders]
* @param {object} [params.endpointOption]
* @param {boolean} [params.isInitialAgent]
* @returns {Promise<Agent>}

@@ -126,8 +129,14 @@ const initializeAgentOptions = async ({
res,
agent,
endpointOption,
allowedProviders,
isInitialAgent = false,
}) => {
if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) {
throw new Error(
`{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`,
);
}
let currentFiles;
/** @type {Array<MongoFile>} */
const requestFiles = req.body.files ?? [];

@@ -263,6 +272,8 @@ const initializeClient = async ({ req, res, endpointOption }) => {
}

const agentConfigs = new Map();
/** @type {Set<string>} */
const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders);

// Handle primary agent
const primaryConfig = await initializeAgentOptions({

@@ -270,6 +281,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
res,
agent: primaryAgent,
endpointOption,
allowedProviders,
isInitialAgent: true,
});

@@ -285,6 +297,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
res,
agent,
endpointOption,
allowedProviders,
});
agentConfigs.set(agentId, config);
}

@@ -310,10 +323,14 @@ const initializeClient = async ({ req, res, endpointOption }) => {
agent: primaryConfig,
spec: endpointOption.spec,
iconURL: endpointOption.iconURL,
endpoint: EModelEndpoint.agents,
attachments: primaryConfig.attachments,
endpointType: endpointOption.endpointType,
maxContextTokens: primaryConfig.maxContextTokens,
resendFiles: primaryConfig.model_parameters?.resendFiles ?? true,
endpoint:
primaryConfig.id === Constants.EPHEMERAL_AGENT_ID
? primaryConfig.endpoint
: EModelEndpoint.agents,
});

return { client };

@@ -1,4 +1,10 @@
const { CacheKeys, findLastSeparatorIndex, SEPARATORS, Time } = require('librechat-data-provider');
const {
Time,
CacheKeys,
SEPARATORS,
parseTextParts,
findLastSeparatorIndex,
} = require('librechat-data-provider');
const { getMessage } = require('~/models/Message');
const { getLogStores } = require('~/cache');

@@ -84,10 +90,11 @@ function createChunkProcessor(user, messageId) {
notFoundCount++;
return [];
} else {
const text = message.content?.length > 0 ? parseTextParts(message.content) : message.text;
messageCache.set(
messageId,
{
text: message.text,
text,
complete: true,
},
Time.FIVE_MINUTES,

@@ -95,7 +102,7 @@ function createChunkProcessor(user, messageId) {
}

const text = typeof message === 'string' ? message : message.text;
const complete = typeof message === 'string' ? false : message.complete ?? true;
const complete = typeof message === 'string' ? false : (message.complete ?? true);

if (text === processedText) {
noChangeCount++;

@@ -1,11 +1,13 @@
const fs = require('fs');
const path = require('path');
const mime = require('mime');
const axios = require('axios');
const fetch = require('node-fetch');
const { logger } = require('~/config');
const { getAzureContainerClient } = require('./initialize');

const defaultBasePath = 'images';
const { AZURE_STORAGE_PUBLIC_ACCESS = 'true', AZURE_CONTAINER_NAME = 'files' } = process.env;

/**
* Uploads a buffer to Azure Blob Storage.

@@ -29,10 +31,9 @@ async function saveBufferToAzure({
}) {
try {
const containerClient = getAzureContainerClient(containerName);
const access = AZURE_STORAGE_PUBLIC_ACCESS?.toLowerCase() === 'true' ? 'blob' : undefined;
// Create the container if it doesn't exist. This is done per operation.
await containerClient.createIfNotExists({
access: process.env.AZURE_STORAGE_PUBLIC_ACCESS ? 'blob' : undefined,
});
await containerClient.createIfNotExists({ access });
const blobPath = `${basePath}/${userId}/${fileName}`;
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
await blockBlobClient.uploadData(buffer);

@@ -97,25 +98,21 @@ async function getAzureURL({ fileName, basePath = defaultBasePath, userId, conta
* Deletes a blob from Azure Blob Storage.
*
* @param {Object} params
* @param {string} params.fileName - The name of the file.
* @param {string} [params.basePath='images'] - The base folder where the file is stored.
* @param {string} params.userId - The user's id.
* @param {string} [params.containerName] - The Azure Blob container name.
* @param {ServerRequest} params.req - The Express request object.
* @param {MongoFile} params.file - The file object.
*/
async function deleteFileFromAzure({
fileName,
basePath = defaultBasePath,
userId,
containerName,
}) {
async function deleteFileFromAzure(req, file) {
try {
const containerClient = getAzureContainerClient(containerName);
const blobPath = `${basePath}/${userId}/${fileName}`;
const containerClient = getAzureContainerClient(AZURE_CONTAINER_NAME);
const blobPath = file.filepath.split(`${AZURE_CONTAINER_NAME}/`)[1];
if (!blobPath.includes(req.user.id)) {
throw new Error('User ID not found in blob path');
}
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
await blockBlobClient.delete();
logger.debug('[deleteFileFromAzure] Blob deleted successfully from Azure Blob Storage');
} catch (error) {
logger.error('[deleteFileFromAzure] Error deleting blob:', error.message);
logger.error('[deleteFileFromAzure] Error deleting blob:', error);
if (error.statusCode === 404) {
return;
}

@@ -123,6 +120,65 @@ async function deleteFileFromAzure({
}
}

/**
* Streams a file from disk directly to Azure Blob Storage without loading
* the entire file into memory.
*
* @param {Object} params
* @param {string} params.userId - The user's id.
* @param {string} params.filePath - The local file path to upload.
* @param {string} params.fileName - The name of the file in Azure.
* @param {string} [params.basePath='images'] - The base folder within the container.
* @param {string} [params.containerName] - The Azure Blob container name.
* @returns {Promise<string>} The URL of the uploaded blob.
*/
async function streamFileToAzure({
userId,
filePath,
fileName,
basePath = defaultBasePath,
containerName,
}) {
try {
const containerClient = getAzureContainerClient(containerName);
const access = AZURE_STORAGE_PUBLIC_ACCESS?.toLowerCase() === 'true' ? 'blob' : undefined;

// Create the container if it doesn't exist
await containerClient.createIfNotExists({ access });

const blobPath = `${basePath}/${userId}/${fileName}`;
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);

// Get file size for proper content length
const stats = await fs.promises.stat(filePath);

// Create read stream from the file
const fileStream = fs.createReadStream(filePath);

const blobContentType = mime.getType(fileName);
await blockBlobClient.uploadStream(
fileStream,
undefined, // Use default concurrency (5)
undefined, // Use default buffer size (8MB)
{
blobHTTPHeaders: {
blobContentType,
},
onProgress: (progress) => {
logger.debug(
`[streamFileToAzure] Upload progress: ${progress.loadedBytes} bytes of ${stats.size}`,
);
},
},
);

return blockBlobClient.url;
} catch (error) {
logger.error('[streamFileToAzure] Error streaming file:', error);
throw error;
}
}

/**
* Uploads a file from the local file system to Azure Blob Storage.
*

@@ -146,18 +202,19 @@ async function uploadFileToAzure({
}) {
try {
const inputFilePath = file.path;
const inputBuffer = await fs.promises.readFile(inputFilePath);
const bytes = Buffer.byteLength(inputBuffer);
const stats = await fs.promises.stat(inputFilePath);
const bytes = stats.size;
const userId = req.user.id;
const fileName = `${file_id}__${path.basename(inputFilePath)}`;
const fileURL = await saveBufferToAzure({

const fileURL = await streamFileToAzure({
userId,
buffer: inputBuffer,
filePath: inputFilePath,
fileName,
basePath,
containerName,
});
await fs.promises.unlink(inputFilePath);

return { filepath: fileURL, bytes };
} catch (error) {
logger.error('[uploadFileToAzure] Error uploading file:', error);
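
For context, the new `streamFileToAzure` above replaces the buffer-based path in `uploadFileToAzure`, so large files are piped from disk instead of being read fully into memory. A minimal usage sketch, assuming the Azure crud module above is in scope (the user id and file path below are made-up values):

// Illustrative only: stream a local file to Azure Blob Storage.
async function exampleAzureUpload() {
  // Pipes ./tmp/report.pdf to <container>/uploads/user-123/report.pdf
  // without buffering the whole file in memory.
  const url = await streamFileToAzure({
    userId: 'user-123', // hypothetical user id
    filePath: './tmp/report.pdf', // hypothetical local path
    fileName: 'report.pdf',
    basePath: 'uploads',
  });
  return url; // the blob's URL, from blockBlobClient.url
}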

@@ -32,11 +32,12 @@ async function getCodeOutputDownloadStream(fileIdentifier, apiKey) {
const response = await axios(options);
return response;
} catch (error) {
logAxiosError({
message: `Error downloading code environment file stream: ${error.message}`,
error,
});
throw new Error(`Error downloading file: ${error.message}`);
throw new Error(
logAxiosError({
message: `Error downloading code environment file stream: ${error.message}`,
error,
}),
);
}
}

@@ -89,11 +90,12 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = ''

return `${fileIdentifier}?entity_id=${entity_id}`;
} catch (error) {
logAxiosError({
message: `Error uploading code environment file: ${error.message}`,
error,
});
throw new Error(`Error uploading code environment file: ${error.message}`);
throw new Error(
logAxiosError({
message: `Error uploading code environment file: ${error.message}`,
error,
}),
);
}
}

@@ -5,7 +5,7 @@ const FormData = require('form-data');
const { FileSources, envVarRegex, extractEnvVariable } = require('librechat-data-provider');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { logger, createAxiosInstance } = require('~/config');
const { logAxiosError } = require('~/utils');
const { logAxiosError } = require('~/utils/axios');

const axios = createAxiosInstance();

@@ -194,8 +194,7 @@ const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => {
};
} catch (error) {
const message = 'Error uploading document to Mistral OCR API';
logAxiosError({ error, message });
throw new Error(message);
throw new Error(logAxiosError({ error, message }));
}
};

@@ -29,9 +29,6 @@ const mockAxios = {

jest.mock('axios', () => mockAxios);
jest.mock('fs');
jest.mock('~/utils', () => ({
logAxiosError: jest.fn(),
}));
jest.mock('~/config', () => ({
logger: {
error: jest.fn(),

@@ -494,9 +491,6 @@ describe('MistralOCR Service', () => {
}),
).rejects.toThrow('Error uploading document to Mistral OCR API');
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');

const { logAxiosError } = require('~/utils');
expect(logAxiosError).toHaveBeenCalled();
});

it('should handle single page documents without page numbering', async () => {

@@ -1,7 +1,13 @@
const fs = require('fs');
const path = require('path');
const fetch = require('node-fetch');
const { PutObjectCommand, GetObjectCommand, DeleteObjectCommand } = require('@aws-sdk/client-s3');
const { FileSources } = require('librechat-data-provider');
const {
PutObjectCommand,
GetObjectCommand,
HeadObjectCommand,
DeleteObjectCommand,
} = require('@aws-sdk/client-s3');
const { getSignedUrl } = require('@aws-sdk/s3-request-presigner');
const { initializeS3 } = require('./initialize');
const { logger } = require('~/config');

@@ -9,6 +15,34 @@ const { logger } = require('~/config');
const bucketName = process.env.AWS_BUCKET_NAME;
const defaultBasePath = 'images';

let s3UrlExpirySeconds = 7 * 24 * 60 * 60;
let s3RefreshExpiryMs = null;

if (process.env.S3_URL_EXPIRY_SECONDS !== undefined) {
const parsed = parseInt(process.env.S3_URL_EXPIRY_SECONDS, 10);

if (!isNaN(parsed) && parsed > 0) {
s3UrlExpirySeconds = Math.min(parsed, 7 * 24 * 60 * 60);
} else {
logger.warn(
`[S3] Invalid S3_URL_EXPIRY_SECONDS value: "${process.env.S3_URL_EXPIRY_SECONDS}". Using 7-day expiry.`,
);
}
}

if (process.env.S3_REFRESH_EXPIRY_MS !== null && process.env.S3_REFRESH_EXPIRY_MS) {
const parsed = parseInt(process.env.S3_REFRESH_EXPIRY_MS, 10);

if (!isNaN(parsed) && parsed > 0) {
s3RefreshExpiryMs = parsed;
logger.info(`[S3] Using custom refresh expiry time: ${s3RefreshExpiryMs}ms`);
} else {
logger.warn(
`[S3] Invalid S3_REFRESH_EXPIRY_MS value: "${process.env.S3_REFRESH_EXPIRY_MS}". Using default refresh logic.`,
);
}
}

/**
* Constructs the S3 key based on the base path, user ID, and file name.
*/

@@ -39,13 +73,14 @@ async function saveBufferToS3({ userId, buffer, fileName, basePath = defaultBase
}

/**
* Retrieves a signed URL for a file stored in S3.
* Retrieves a URL for a file stored in S3.
* Returns a signed URL with expiration time or a proxy URL based on config
*
* @param {Object} params
* @param {string} params.userId - The user's unique identifier.
* @param {string} params.fileName - The file name in S3.
* @param {string} [params.basePath='images'] - The base path in the bucket.
* @returns {Promise<string>} A signed URL valid for 24 hours.
* @returns {Promise<string>} A URL to access the S3 object
*/
async function getS3URL({ userId, fileName, basePath = defaultBasePath }) {
const key = getS3Key(basePath, userId, fileName);

@@ -53,7 +88,7 @@ async function getS3URL({ userId, fileName, basePath = defaultBasePath }) {

try {
const s3 = initializeS3();
return await getSignedUrl(s3, new GetObjectCommand(params), { expiresIn: 86400 });
return await getSignedUrl(s3, new GetObjectCommand(params), { expiresIn: s3UrlExpirySeconds });
} catch (error) {
logger.error('[getS3URL] Error getting signed URL from S3:', error.message);
throw error;

@@ -86,21 +121,51 @@ async function saveURLToS3({ userId, URL, fileName, basePath = defaultBasePath }
* Deletes a file from S3.
*
* @param {Object} params
* @param {string} params.userId - The user's unique identifier.
* @param {string} params.fileName - The file name in S3.
* @param {string} [params.basePath='images'] - The base path in the bucket.
* @param {ServerRequest} params.req
* @param {MongoFile} params.file - The file object to delete.
* @returns {Promise<void>}
*/
async function deleteFileFromS3({ userId, fileName, basePath = defaultBasePath }) {
const key = getS3Key(basePath, userId, fileName);
async function deleteFileFromS3(req, file) {
const key = extractKeyFromS3Url(file.filepath);
const params = { Bucket: bucketName, Key: key };
if (!key.includes(req.user.id)) {
const message = `[deleteFileFromS3] User ID mismatch: ${req.user.id} vs ${key}`;
logger.error(message);
throw new Error(message);
}

try {
const s3 = initializeS3();
await s3.send(new DeleteObjectCommand(params));
logger.debug('[deleteFileFromS3] File deleted successfully from S3');

try {
const headCommand = new HeadObjectCommand(params);
await s3.send(headCommand);
logger.debug('[deleteFileFromS3] File exists, proceeding with deletion');
} catch (headErr) {
if (headErr.name === 'NotFound') {
logger.warn(`[deleteFileFromS3] File does not exist: ${key}`);
return;
}
}

const deleteResult = await s3.send(new DeleteObjectCommand(params));
logger.debug('[deleteFileFromS3] Delete command response:', JSON.stringify(deleteResult));
try {
await s3.send(new HeadObjectCommand(params));
logger.error('[deleteFileFromS3] File still exists after deletion!');
} catch (verifyErr) {
if (verifyErr.name === 'NotFound') {
logger.debug(`[deleteFileFromS3] Verified file is deleted: ${key}`);
} else {
logger.error('[deleteFileFromS3] Error verifying deletion:', verifyErr);
}
}

logger.debug('[deleteFileFromS3] S3 File deletion completed');
} catch (error) {
logger.error('[deleteFileFromS3] Error deleting file from S3:', error.message);
logger.error(`[deleteFileFromS3] Error deleting file from S3: ${error.message}`);
logger.error(error.stack);

// If the file is not found, we can safely return.
if (error.code === 'NoSuchKey') {
return;

@@ -110,7 +175,7 @@ async function deleteFileFromS3({ userId, fileName, basePath = defaultBasePath }
}

/**
* Uploads a local file to S3.
* Uploads a local file to S3 by streaming it directly without loading into memory.
*
* @param {Object} params
* @param {import('express').Request} params.req - The Express request (must include user).

@@ -122,37 +187,272 @@ async function deleteFileFromS3({ userId, fileName, basePath = defaultBasePath }
async function uploadFileToS3({ req, file, file_id, basePath = defaultBasePath }) {
try {
const inputFilePath = file.path;
const inputBuffer = await fs.promises.readFile(inputFilePath);
const bytes = Buffer.byteLength(inputBuffer);
const userId = req.user.id;
const fileName = `${file_id}__${path.basename(inputFilePath)}`;
const fileURL = await saveBufferToS3({ userId, buffer: inputBuffer, fileName, basePath });
await fs.promises.unlink(inputFilePath);
const key = getS3Key(basePath, userId, fileName);

const stats = await fs.promises.stat(inputFilePath);
const bytes = stats.size;
const fileStream = fs.createReadStream(inputFilePath);

const s3 = initializeS3();
const uploadParams = {
Bucket: bucketName,
Key: key,
Body: fileStream,
};

await s3.send(new PutObjectCommand(uploadParams));
const fileURL = await getS3URL({ userId, fileName, basePath });
return { filepath: fileURL, bytes };
} catch (error) {
logger.error('[uploadFileToS3] Error uploading file to S3:', error.message);
logger.error('[uploadFileToS3] Error streaming file to S3:', error);
try {
if (file && file.path) {
await fs.promises.unlink(file.path);
}
} catch (unlinkError) {
logger.error(
'[uploadFileToS3] Error deleting temporary file, likely already deleted:',
unlinkError.message,
);
}
throw error;
}
}

/**
* Extracts the S3 key from a URL or returns the key if already properly formatted
*
* @param {string} fileUrlOrKey - The file URL or key
* @returns {string} The S3 key
*/
function extractKeyFromS3Url(fileUrlOrKey) {
if (!fileUrlOrKey) {
throw new Error('Invalid input: URL or key is empty');
}

try {
const url = new URL(fileUrlOrKey);
return url.pathname.substring(1);
} catch (error) {
const parts = fileUrlOrKey.split('/');

if (parts.length >= 3 && !fileUrlOrKey.startsWith('http') && !fileUrlOrKey.startsWith('/')) {
return fileUrlOrKey;
}

return fileUrlOrKey.startsWith('/') ? fileUrlOrKey.substring(1) : fileUrlOrKey;
}
}

/**
* Retrieves a readable stream for a file stored in S3.
*
* @param {ServerRequest} req - Server request object.
* @param {string} filePath - The S3 key of the file.
* @returns {Promise<NodeJS.ReadableStream>}
*/
async function getS3FileStream(filePath) {
const params = { Bucket: bucketName, Key: filePath };
async function getS3FileStream(_req, filePath) {
try {
const Key = extractKeyFromS3Url(filePath);
const params = { Bucket: bucketName, Key };
const s3 = initializeS3();
const data = await s3.send(new GetObjectCommand(params));
return data.Body; // Returns a Node.js ReadableStream.
} catch (error) {
logger.error('[getS3FileStream] Error retrieving S3 file stream:', error.message);
logger.error('[getS3FileStream] Error retrieving S3 file stream:', error);
throw error;
}
}

/**
* Determines if a signed S3 URL is close to expiration
*
* @param {string} signedUrl - The signed S3 URL
* @param {number} bufferSeconds - Buffer time in seconds
* @returns {boolean} True if the URL needs refreshing
*/
function needsRefresh(signedUrl, bufferSeconds) {
try {
// Parse the URL
const url = new URL(signedUrl);

// Check if it has the signature parameters that indicate it's a signed URL
// X-Amz-Signature is the most reliable indicator for AWS signed URLs
if (!url.searchParams.has('X-Amz-Signature')) {
// Not a signed URL, so no expiration to check (or it's already a proxy URL)
return false;
}

// Extract the expiration time from the URL
const expiresParam = url.searchParams.get('X-Amz-Expires');
const dateParam = url.searchParams.get('X-Amz-Date');

if (!expiresParam || !dateParam) {
// Missing expiration information, assume it needs refresh to be safe
return true;
}

// Parse the AWS date format (YYYYMMDDTHHMMSSZ)
const year = dateParam.substring(0, 4);
const month = dateParam.substring(4, 6);
const day = dateParam.substring(6, 8);
const hour = dateParam.substring(9, 11);
const minute = dateParam.substring(11, 13);
const second = dateParam.substring(13, 15);

const dateObj = new Date(`${year}-${month}-${day}T${hour}:${minute}:${second}Z`);
const expiresAtDate = new Date(dateObj.getTime() + parseInt(expiresParam) * 1000);

// Check if it's close to expiration
const now = new Date();

// If S3_REFRESH_EXPIRY_MS is set, use it to determine if URL is expired
if (s3RefreshExpiryMs !== null) {
const urlCreationTime = dateObj.getTime();
const urlAge = now.getTime() - urlCreationTime;
return urlAge >= s3RefreshExpiryMs;
}

// Otherwise use the default buffer-based logic
const bufferTime = new Date(now.getTime() + bufferSeconds * 1000);
return expiresAtDate <= bufferTime;
} catch (error) {
logger.error('Error checking URL expiration:', error);
// If we can't determine, assume it needs refresh to be safe
return true;
}
}

/**
* Generates a new URL for an expired S3 URL
* @param {string} currentURL - The current file URL
* @returns {Promise<string | undefined>}
*/
async function getNewS3URL(currentURL) {
try {
const s3Key = extractKeyFromS3Url(currentURL);
if (!s3Key) {
return;
}
const keyParts = s3Key.split('/');
if (keyParts.length < 3) {
return;
}

const basePath = keyParts[0];
const userId = keyParts[1];
const fileName = keyParts.slice(2).join('/');

return await getS3URL({
userId,
fileName,
basePath,
});
} catch (error) {
logger.error('Error getting new S3 URL:', error);
}
}

/**
* Refreshes S3 URLs for an array of files if they're expired or close to expiring
*
* @param {IMongoFile[]} files - Array of file documents
* @param {(files: MongoFile[]) => Promise<void>} batchUpdateFiles - Function to update files in the database
* @param {number} [bufferSeconds=3600] - Buffer time in seconds to check for expiration
* @returns {Promise<IMongoFile[]>} The files with refreshed URLs if needed
*/
async function refreshS3FileUrls(files, batchUpdateFiles, bufferSeconds = 3600) {
if (!files || !Array.isArray(files) || files.length === 0) {
return files;
}

const filesToUpdate = [];

for (let i = 0; i < files.length; i++) {
const file = files[i];
if (!file?.file_id) {
continue;
}
if (file.source !== FileSources.s3) {
continue;
}
if (!file.filepath) {
continue;
}
if (!needsRefresh(file.filepath, bufferSeconds)) {
continue;
}
try {
const newURL = await getNewS3URL(file.filepath);
if (!newURL) {
continue;
}
filesToUpdate.push({
file_id: file.file_id,
filepath: newURL,
});
files[i].filepath = newURL;
} catch (error) {
logger.error(`Error refreshing S3 URL for file ${file.file_id}:`, error);
}
}

if (filesToUpdate.length > 0) {
await batchUpdateFiles(filesToUpdate);
}

return files;
}

/**
* Refreshes a single S3 URL if it's expired or close to expiring
*
* @param {{ filepath: string, source: string }} fileObj - Simple file object containing filepath and source
* @param {number} [bufferSeconds=3600] - Buffer time in seconds to check for expiration
* @returns {Promise<string>} The refreshed URL or the original URL if no refresh needed
*/
async function refreshS3Url(fileObj, bufferSeconds = 3600) {
if (!fileObj || fileObj.source !== FileSources.s3 || !fileObj.filepath) {
return fileObj?.filepath || '';
}

if (!needsRefresh(fileObj.filepath, bufferSeconds)) {
return fileObj.filepath;
}

try {
const s3Key = extractKeyFromS3Url(fileObj.filepath);
if (!s3Key) {
logger.warn(`Unable to extract S3 key from URL: ${fileObj.filepath}`);
return fileObj.filepath;
}

const keyParts = s3Key.split('/');
if (keyParts.length < 3) {
logger.warn(`Invalid S3 key format: ${s3Key}`);
return fileObj.filepath;
}

const basePath = keyParts[0];
const userId = keyParts[1];
const fileName = keyParts.slice(2).join('/');

const newUrl = await getS3URL({
userId,
fileName,
basePath,
});

logger.debug(`Refreshed S3 URL for key: ${s3Key}`);
return newUrl;
} catch (error) {
logger.error(`Error refreshing S3 URL: ${error.message}`);
return fileObj.filepath;
}
}

module.exports = {
saveBufferToS3,
saveURLToS3,

@@ -160,4 +460,8 @@ module.exports = {
deleteFileFromS3,
uploadFileToS3,
getS3FileStream,
refreshS3FileUrls,
refreshS3Url,
needsRefresh,
getNewS3URL,
};
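
A note on the refresh logic above: `needsRefresh` derives expiry entirely from the presigned URL's own query string (`X-Amz-Date` is the issue time, `X-Amz-Expires` the lifetime in seconds). A quick sketch of that arithmetic with a fabricated URL (all parameter values are placeholders):

// Illustrative only: compute when a presigned S3 URL expires.
const signed = new URL(
  'https://bucket.s3.amazonaws.com/images/user-123/file.png' +
    '?X-Amz-Date=20250101T120000Z&X-Amz-Expires=604800&X-Amz-Signature=abc',
);
const d = signed.searchParams.get('X-Amz-Date'); // YYYYMMDDTHHMMSSZ
const issuedAt = new Date(
  `${d.slice(0, 4)}-${d.slice(4, 6)}-${d.slice(6, 8)}T${d.slice(9, 11)}:${d.slice(11, 13)}:${d.slice(13, 15)}Z`,
);
const expiresAt = new Date(
  issuedAt.getTime() + 1000 * parseInt(signed.searchParams.get('X-Amz-Expires'), 10),
);
console.log(expiresAt.toISOString()); // 2025-01-08T12:00:00.000Z (7 days after issue)
// needsRefresh() flags the URL once `now + bufferSeconds` reaches this instant.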

@@ -7,6 +7,7 @@ const {
EModelEndpoint,
} = require('librechat-data-provider');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { logAxiosError } = require('~/utils');
const { logger } = require('~/config');

/**

@@ -24,8 +25,8 @@ async function fetchImageToBase64(url) {
});
return Buffer.from(response.data).toString('base64');
} catch (error) {
logger.error('Error fetching image to convert to base64', error);
throw error;
const message = 'Error fetching image to convert to base64';
throw new Error(logAxiosError({ message, error }));
}
}

@@ -37,17 +38,21 @@ const base64Only = new Set([
EModelEndpoint.bedrock,
]);

const blobStorageSources = new Set([FileSources.azure_blob, FileSources.s3]);

/**
* Encodes and formats the given files.
* @param {Express.Request} req - The request object.
* @param {Array<MongoFile>} files - The array of files to encode and format.
* @param {EModelEndpoint} [endpoint] - Optional: The endpoint for the image.
* @param {string} [mode] - Optional: The endpoint mode for the image.
* @returns {Promise<Object>} - A promise that resolves to the result object containing the encoded images and file details.
* @returns {Promise<{ text: string; files: MongoFile[]; image_urls: MessageContentImageUrl[] }>} - A promise that resolves to the result object containing the encoded images and file details.
*/
async function encodeAndFormat(req, files, endpoint, mode) {
const promises = [];
/** @type {Record<FileSources, Pick<ReturnType<typeof getStrategyFunctions>, 'prepareImagePayload' | 'getDownloadStream'>>} */
const encodingMethods = {};
/** @type {{ text: string; files: MongoFile[]; image_urls: MessageContentImageUrl[] }} */
const result = {
text: '',
files: [],

@@ -59,6 +64,7 @@ async function encodeAndFormat(req, files, endpoint, mode) {
}

for (let file of files) {
/** @type {FileSources} */
const source = file.source ?? FileSources.local;
if (source === FileSources.text && file.text) {
result.text += `${!result.text ? 'Attached document(s):\n```md' : '\n\n---\n\n'}# "${file.filename}"\n${file.text}\n`;

@@ -70,18 +76,52 @@ async function encodeAndFormat(req, files, endpoint, mode) {
}

if (!encodingMethods[source]) {
const { prepareImagePayload } = getStrategyFunctions(source);
const { prepareImagePayload, getDownloadStream } = getStrategyFunctions(source);
if (!prepareImagePayload) {
throw new Error(`Encoding function not implemented for ${source}`);
}

encodingMethods[source] = prepareImagePayload;
encodingMethods[source] = { prepareImagePayload, getDownloadStream };
}

const preparePayload = encodingMethods[source];
const preparePayload = encodingMethods[source].prepareImagePayload;
/* We need to fetch the image and convert it to base64 if we are using S3/Azure Blob storage. */
if (blobStorageSources.has(source)) {
try {
const downloadStream = encodingMethods[source].getDownloadStream;
const stream = await downloadStream(req, file.filepath);
const streamPromise = new Promise((resolve, reject) => {
/** @type {Uint8Array[]} */
const chunks = [];
stream.on('readable', () => {
let chunk;
while (null !== (chunk = stream.read())) {
chunks.push(chunk);
}
});

/* Google & Anthropic don't support passing URLs to payload */
if (source !== FileSources.local && base64Only.has(endpoint)) {
stream.on('end', () => {
const buffer = Buffer.concat(chunks);
const base64Data = buffer.toString('base64');
resolve(base64Data);
});
stream.on('error', (error) => {
reject(error);
});
});
const base64Data = await streamPromise;
promises.push([file, base64Data]);
continue;
} catch (error) {
logger.error(
`Error processing blob storage file stream for ${file.name} base64 payload:`,
error,
);
continue;
}

/* Google & Anthropic don't support passing URLs to payload */
} else if (source !== FileSources.local && base64Only.has(endpoint)) {
const [_file, imageURL] = await preparePayload(req, file);
promises.push([_file, await fetchImageToBase64(imageURL)]);
continue;

@@ -492,7 +492,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {

let fileInfoMetadata;
const entity_id = messageAttachment === true ? undefined : agent_id;

const basePath = mime.getType(file.originalname)?.startsWith('image') ? 'images' : 'uploads';
if (tool_resource === EToolResources.execute_code) {
const isCodeEnabled = await checkCapability(req, AgentCapabilities.execute_code);
if (!isCodeEnabled) {

@@ -532,7 +532,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
images,
filename,
filepath: ocrFileURL,
} = await handleFileUpload({ req, file, file_id, entity_id: agent_id });
} = await handleFileUpload({ req, file, file_id, entity_id: agent_id, basePath });

const fileInfo = removeNullishValues({
text,

@@ -582,6 +582,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
file,
file_id,
entity_id,
basePath,
});

let filepath = _filepath;

@@ -211,6 +211,8 @@ const getStrategyFunctions = (fileSource) => {
} else if (fileSource === FileSources.openai) {
return openAIStrategy();
} else if (fileSource === FileSources.azure) {
return openAIStrategy();
} else if (fileSource === FileSources.azure_blob) {
return azureStrategy();
} else if (fileSource === FileSources.vectordb) {
return vectorStrategy();

@@ -13,7 +13,7 @@ const { logger, getMCPManager } = require('~/config');
* Creates a general tool for an entire action set.
*
* @param {Object} params - The parameters for loading action sets.
* @param {ServerRequest} params.req - The name of the tool.
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {string} params.toolKey - The toolKey for the tool.
* @param {import('@librechat/agents').Providers | EModelEndpoint} params.provider - The provider for the tool.
* @param {string} params.model - The model for the tool.

@@ -37,19 +37,30 @@ async function createMCPTool({ req, toolKey, provider }) {
}

const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
const userId = req.user?.id;

if (!userId) {
logger.error(
`[MCP][${serverName}][${toolName}] User ID not found on request. Cannot create tool.`,
);
throw new Error(`User ID not found on request. Cannot create tool for ${toolKey}.`);
}

/** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise<unknown>} */
const _call = async (toolArguments, config) => {
try {
const mcpManager = await getMCPManager();
const mcpManager = getMCPManager();
const result = await mcpManager.callTool({
serverName,
toolName,
provider,
toolArguments,
options: {
userId,
signal: config?.signal,
},
});

if (isAssistantsEndpoint(provider) && Array.isArray(result)) {
return result[0];
}

@@ -58,8 +69,13 @@ async function createMCPTool({ req, toolKey, provider }) {
}
return result;
} catch (error) {
logger.error(`${toolName} MCP server tool call failed`, error);
return `${toolName} MCP server tool call failed.`;
logger.error(
`[MCP][User: ${userId}][${serverName}] Error calling "${toolName}" MCP tool:`,
error,
);
throw new Error(
`"${toolKey}" tool call failed${error?.message ? `: ${error?.message}` : '.'}`,
);
}
};

@@ -55,8 +55,7 @@ async function retrieveRun({ thread_id, run_id, timeout, openai }) {
return response.data;
} catch (error) {
const message = '[retrieveRun] Failed to retrieve run data:';
logAxiosError({ message, error });
throw error;
throw new Error(logAxiosError({ message, error }));
}
}

@@ -93,11 +93,12 @@ const refreshAccessToken = async ({
return response.data;
} catch (error) {
const message = 'Error refreshing OAuth tokens';
logAxiosError({
message,
error,
});
throw new Error(message);
throw new Error(
logAxiosError({
message,
error,
}),
);
}
};

@@ -156,11 +157,12 @@ const getAccessToken = async ({
return response.data;
} catch (error) {
const message = 'Error exchanging OAuth code';
logAxiosError({
message,
error,
});
throw new Error(message);
throw new Error(
logAxiosError({
message,
error,
}),
);
}
};

@@ -15,9 +15,15 @@ const {
AgentCapabilities,
validateAndParseOpenAPISpec,
} = require('librechat-data-provider');
const {
loadActionSets,
createActionTool,
decryptMetadata,
domainParser,
} = require('./ActionService');
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
const { createYouTubeTools, manifestToolMap, toolkits } = require('~/app/clients/tools');
const { loadActionSets, createActionTool, domainParser } = require('./ActionService');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { getEndpointsConfig } = require('~/server/services/Config');
const { recordUsage } = require('~/server/services/Threads');
const { loadTools } = require('~/app/clients/tools/util');

@@ -315,58 +321,95 @@ async function processRequiredActions(client, requiredActions) {
if (!tool) {
// throw new Error(`Tool ${currentAction.tool} not found.`);

// Load all action sets once if not already loaded
if (!actionSets.length) {
actionSets =
(await loadActionSets({
assistant_id: client.req.body.assistant_id,
})) ?? [];

// Process all action sets once
// Map domains to their processed action sets
const processedDomains = new Map();
const domainMap = new Map();

for (const action of actionSets) {
const domain = await domainParser(client.req, action.metadata.domain, true);
domainMap.set(domain, action);

// Check if domain is allowed
const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain);
if (!isDomainAllowed) {
continue;
}

// Validate and parse OpenAPI spec
const validationResult = validateAndParseOpenAPISpec(action.metadata.raw_spec);
if (!validationResult.spec) {
throw new Error(
`Invalid spec: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`,
);
}

// Process the OpenAPI spec
const { requestBuilders } = openapiToFunction(validationResult.spec);

// Store encrypted values for OAuth flow
const encrypted = {
oauth_client_id: action.metadata.oauth_client_id,
oauth_client_secret: action.metadata.oauth_client_secret,
};

// Decrypt metadata
const decryptedAction = { ...action };
decryptedAction.metadata = await decryptMetadata(action.metadata);

processedDomains.set(domain, {
action: decryptedAction,
requestBuilders,
encrypted,
});

// Store builders for reuse
ActionBuildersMap[action.metadata.domain] = requestBuilders;
}

// Update actionSets reference to use the domain map
actionSets = { domainMap, processedDomains };
}

let actionSet = null;
// Find the matching domain for this tool
let currentDomain = '';
for (let action of actionSets) {
const domain = await domainParser(client.req, action.metadata.domain, true);
for (const domain of actionSets.domainMap.keys()) {
if (currentAction.tool.includes(domain)) {
currentDomain = domain;
actionSet = action;
break;
}
}

if (!actionSet) {
if (!currentDomain || !actionSets.processedDomains.has(currentDomain)) {
// TODO: try `function` if no action set is found
// throw new Error(`Tool ${currentAction.tool} not found.`);
continue;
}

let builders = ActionBuildersMap[actionSet.metadata.domain];

if (!builders) {
const validationResult = validateAndParseOpenAPISpec(actionSet.metadata.raw_spec);
if (!validationResult.spec) {
throw new Error(
`Invalid spec: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`,
);
}
const { requestBuilders } = openapiToFunction(validationResult.spec);
ActionToolMap[actionSet.metadata.domain] = requestBuilders;
builders = requestBuilders;
}

const { action, requestBuilders, encrypted } = actionSets.processedDomains.get(currentDomain);
const functionName = currentAction.tool.replace(`${actionDelimiter}${currentDomain}`, '');

const requestBuilder = builders[functionName];
const requestBuilder = requestBuilders[functionName];

if (!requestBuilder) {
// throw new Error(`Tool ${currentAction.tool} not found.`);
continue;
}

// We've already decrypted the metadata, so we can pass it directly
tool = await createActionTool({
req: client.req,
res: client.res,
action: actionSet,
action,
requestBuilder,
// Note: intentionally not passing zodSchema, name, and description for assistants API
encrypted, // Pass the encrypted values for OAuth flow
});
if (!tool) {
logger.warn(

@@ -430,10 +473,10 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
const areToolsEnabled = checkCapability(AgentCapabilities.tools);

const _agentTools = agent.tools?.filter((tool) => {
if (tool === Tools.file_search && !checkCapability(AgentCapabilities.file_search)) {
return false;
} else if (tool === Tools.execute_code && !checkCapability(AgentCapabilities.execute_code)) {
return false;
if (tool === Tools.file_search) {
return checkCapability(AgentCapabilities.file_search);
} else if (tool === Tools.execute_code) {
return checkCapability(AgentCapabilities.execute_code);
} else if (!areToolsEnabled && !tool.includes(actionDelimiter)) {
return false;
}

@@ -511,7 +554,62 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
};
}

let actionSets = [];
const actionSets = (await loadActionSets({ agent_id: agent.id })) ?? [];
if (actionSets.length === 0) {
if (_agentTools.length > 0 && agentTools.length === 0) {
logger.warn(`No tools found for the specified tool calls: ${_agentTools.join(', ')}`);
}
return {
tools: agentTools,
toolContextMap,
};
}

// Process each action set once (validate spec, decrypt metadata)
const processedActionSets = new Map();
const domainMap = new Map();

for (const action of actionSets) {
const domain = await domainParser(req, action.metadata.domain, true);
domainMap.set(domain, action);

// Check if domain is allowed (do this once per action set)
const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain);
if (!isDomainAllowed) {
continue;
}

// Validate and parse OpenAPI spec once per action set
const validationResult = validateAndParseOpenAPISpec(action.metadata.raw_spec);
if (!validationResult.spec) {
continue;
}

const encrypted = {
oauth_client_id: action.metadata.oauth_client_id,
oauth_client_secret: action.metadata.oauth_client_secret,
};

// Decrypt metadata once per action set
const decryptedAction = { ...action };
decryptedAction.metadata = await decryptMetadata(action.metadata);

// Process the OpenAPI spec once per action set
const { requestBuilders, functionSignatures, zodSchemas } = openapiToFunction(
validationResult.spec,
true,
);

processedActionSets.set(domain, {
action: decryptedAction,
requestBuilders,
functionSignatures,
zodSchemas,
encrypted,
});
}

// Now map tools to the processed action sets
const ActionToolMap = {};

for (const toolName of _agentTools) {

@@ -519,55 +617,47 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
continue;
}

if (!actionSets.length) {
actionSets = (await loadActionSets({ agent_id: agent.id })) ?? [];
}

let actionSet = null;
// Find the matching domain for this tool
let currentDomain = '';
for (let action of actionSets) {
const domain = await domainParser(req, action.metadata.domain, true);
for (const domain of domainMap.keys()) {
if (toolName.includes(domain)) {
currentDomain = domain;
actionSet = action;
break;
}
}

if (!actionSet) {
if (!currentDomain || !processedActionSets.has(currentDomain)) {
continue;
}

const validationResult = validateAndParseOpenAPISpec(actionSet.metadata.raw_spec);
if (validationResult.spec) {
const { requestBuilders, functionSignatures, zodSchemas } = openapiToFunction(
validationResult.spec,
true,
);
const functionName = toolName.replace(`${actionDelimiter}${currentDomain}`, '');
const functionSig = functionSignatures.find((sig) => sig.name === functionName);
const requestBuilder = requestBuilders[functionName];
const zodSchema = zodSchemas[functionName];
const { action, encrypted, zodSchemas, requestBuilders, functionSignatures } =
processedActionSets.get(currentDomain);
const functionName = toolName.replace(`${actionDelimiter}${currentDomain}`, '');
const functionSig = functionSignatures.find((sig) => sig.name === functionName);
const requestBuilder = requestBuilders[functionName];
const zodSchema = zodSchemas[functionName];

if (requestBuilder) {
const tool = await createActionTool({
req,
res,
action: actionSet,
requestBuilder,
zodSchema,
name: toolName,
description: functionSig.description,
});
if (!tool) {
logger.warn(
`Invalid action: user: ${req.user.id} | agent_id: ${agent.id} | toolName: ${toolName}`,
);
throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
}
agentTools.push(tool);
ActionToolMap[toolName] = tool;
if (requestBuilder) {
const tool = await createActionTool({
req,
res,
action,
requestBuilder,
zodSchema,
encrypted,
name: toolName,
description: functionSig.description,
});

if (!tool) {
logger.warn(
`Invalid action: user: ${req.user.id} | agent_id: ${agent.id} | toolName: ${toolName}`,
);
throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
}

agentTools.push(tool);
ActionToolMap[toolName] = tool;
}
}

@@ -13,6 +13,24 @@ const secretDefaults = {
JWT_REFRESH_SECRET: 'eaa5191f2914e30b9387fd84e254e4ba6fc51b4654968a9b0803b456a54b8418',
};

const deprecatedVariables = [
{
key: 'CHECK_BALANCE',
description:
'Please use the `balance` field in the `librechat.yaml` config file instead.\nMore info: https://librechat.ai/docs/configuration/librechat_yaml/object_structure/balance#overview',
},
{
key: 'START_BALANCE',
description:
'Please use the `balance` field in the `librechat.yaml` config file instead.\nMore info: https://librechat.ai/docs/configuration/librechat_yaml/object_structure/balance#overview',
},
{
key: 'GOOGLE_API_KEY',
description:
'Please use the `GOOGLE_SEARCH_API_KEY` environment variable for the Google Search Tool instead.',
},
];

/**
* Checks environment variables for default secrets and deprecated variables.
* Logs warnings for any default secret values being used and for usage of deprecated `GOOGLE_API_KEY`.

@@ -37,19 +55,11 @@ function checkVariables() {
\u200B`);
}

if (process.env.GOOGLE_API_KEY) {
logger.warn(
'The `GOOGLE_API_KEY` environment variable is deprecated.\nPlease use the `GOOGLE_SEARCH_API_KEY` environment variable instead.',
);
}

if (process.env.OPENROUTER_API_KEY) {
logger.warn(
`The \`OPENROUTER_API_KEY\` environment variable is deprecated and its functionality will be removed soon.
Use of this environment variable is highly discouraged as it can lead to unexpected errors when using custom endpoints.
Please use the config (\`librechat.yaml\`) file for setting up OpenRouter, and use \`OPENROUTER_KEY\` or another environment variable instead.`,
);
}
deprecatedVariables.forEach(({ key, description }) => {
if (process.env[key]) {
logger.warn(`The \`${key}\` environment variable is deprecated. ${description}`);
}
});

checkPasswordReset();
}

@@ -18,12 +18,15 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol
const { interface: interfaceConfig } = config ?? {};
const { interface: defaults } = configDefaults;
const hasModelSpecs = config?.modelSpecs?.list?.length > 0;
const includesAddedEndpoints = config?.modelSpecs?.addedEndpoints?.length > 0;

/** @type {TCustomConfig['interface']} */
const loadedInterface = removeNullishValues({
endpointsMenu:
interfaceConfig?.endpointsMenu ?? (hasModelSpecs ? false : defaults.endpointsMenu),
modelSelect: interfaceConfig?.modelSelect ?? (hasModelSpecs ? false : defaults.modelSelect),
modelSelect:
interfaceConfig?.modelSelect ??
(hasModelSpecs ? includesAddedEndpoints : defaults.modelSelect),
parameters: interfaceConfig?.parameters ?? (hasModelSpecs ? false : defaults.parameters),
presets: interfaceConfig?.presets ?? (hasModelSpecs ? false : defaults.presets),
sidePanel: interfaceConfig?.sidePanel ?? defaults.sidePanel,
@ -6,9 +6,10 @@ const { logger } = require('~/config');
* Sets up Model Specs from the config (`librechat.yaml`) file.
* @param {TCustomConfig['endpoints']} [endpoints] - The loaded custom configuration for endpoints.
* @param {TCustomConfig['modelSpecs'] | undefined} [modelSpecs] - The loaded custom configuration for model specs.
* @param {TCustomConfig['interface'] | undefined} [interfaceConfig] - The loaded interface configuration.
* @returns {TCustomConfig['modelSpecs'] | undefined} The processed model specs, if any.
*/
function processModelSpecs(endpoints, _modelSpecs) {
function processModelSpecs(endpoints, _modelSpecs, interfaceConfig) {
if (!_modelSpecs) {
return undefined;
}

@ -20,6 +21,19 @@ function processModelSpecs(endpoints, _modelSpecs) {

const customEndpoints = endpoints?.[EModelEndpoint.custom] ?? [];

if (interfaceConfig.modelSelect !== true && (_modelSpecs.addedEndpoints?.length ?? 0) > 0) {
logger.warn(
`To utilize \`addedEndpoints\`, which allows provider/model selections alongside model specs, set \`modelSelect: true\` in the interface configuration.

Example:
\`\`\`yaml
interface:
modelSelect: true
\`\`\`
`,
);
}

for (const spec of list) {
if (EModelEndpoint[spec.preset.endpoint] && spec.preset.endpoint !== EModelEndpoint.custom) {
modelSpecs.push(spec);
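A self-contained reproduction of the new guard, showing the configuration shape that triggers the warning (values are hypothetical):

```js
const logger = { warn: console.warn }; // stand-in for the project logger

const interfaceConfig = { modelSelect: false };
const _modelSpecs = { list: [], addedEndpoints: ['openAI'] };

// Fires whenever addedEndpoints is non-empty but modelSelect is not enabled.
if (interfaceConfig.modelSelect !== true && (_modelSpecs.addedEndpoints?.length ?? 0) > 0) {
  logger.warn('To utilize `addedEndpoints`, set `modelSelect: true` in the interface configuration.');
}
```

Setting `modelSelect: true` in `librechat.yaml`, as the warning's embedded YAML example shows, silences it.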
@ -403,6 +403,12 @@
* @memberof typedefs
*/

/**
* @exports MessageContentImageUrl
* @typedef {import('librechat-data-provider').Agents.MessageContentImageUrl} MessageContentImageUrl
* @memberof typedefs
*/

/** Prompts */
/**
* @exports TPrompt

@ -760,6 +766,23 @@
* @memberof typedefs
*/

/**
* @exports MongoFile
* @typedef {import('@librechat/data-schemas').IMongoFile} MongoFile
* @memberof typedefs
*/
/**
* @exports IBalance
* @typedef {import('@librechat/data-schemas').IBalance} IBalance
* @memberof typedefs
*/

/**
* @exports MongoUser
* @typedef {import('@librechat/data-schemas').IUser} MongoUser
* @memberof typedefs
*/

/**
* @exports ObjectId
* @typedef {import('mongoose').Types.ObjectId} ObjectId

@ -805,6 +828,12 @@
* @memberof typedefs
*/

/**
* @exports TEndpointOption
* @typedef {import('librechat-data-provider').TEndpointOption} TEndpointOption
* @memberof typedefs
*/

/**
* @exports TAttachment
* @typedef {import('librechat-data-provider').TAttachment} TAttachment

@ -6,32 +6,41 @@ const { logger } = require('~/config');
* @param {Object} options - The options object.
* @param {string} options.message - The custom message to be logged.
* @param {import('axios').AxiosError} options.error - The Axios error object.
* @returns {string} The log message.
*/
const logAxiosError = ({ message, error }) => {
let logMessage = message;
try {
const stack = error.stack || 'No stack trace available';

if (error.response?.status) {
const { status, headers, data } = error.response;
logger.error(`${message} The server responded with status ${status}: ${error.message}`, {
logMessage = `${message} The server responded with status ${status}: ${error.message}`;
logger.error(logMessage, {
status,
headers,
data,
stack,
});
} else if (error.request) {
const { method, url } = error.config || {};
logger.error(
`${message} No response received for ${method ? method.toUpperCase() : ''} ${url || ''}: ${error.message}`,
{ requestInfo: { method, url } },
);
logMessage = `${message} No response received for ${method ? method.toUpperCase() : ''} ${url || ''}: ${error.message}`;
logger.error(logMessage, {
requestInfo: { method, url },
stack,
});
} else if (error?.message?.includes('Cannot read properties of undefined (reading \'status\')')) {
logger.error(
`${message} It appears the request timed out or was unsuccessful: ${error.message}`,
);
logMessage = `${message} It appears the request timed out or was unsuccessful: ${error.message}`;
logger.error(logMessage, { stack });
} else {
logger.error(`${message} An error occurred while setting up the request: ${error.message}`);
logMessage = `${message} An error occurred while setting up the request: ${error.message}`;
logger.error(logMessage, { stack });
}
} catch (err) {
logger.error(`Error in logAxiosError: ${err.message}`);
logMessage = `Error in logAxiosError: ${err.message}`;
logger.error(logMessage, { stack: err.stack || 'No stack trace available' });
}
return logMessage;
};

module.exports = { logAxiosError };
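The net effect of this refactor is that `logAxiosError` now returns the message it logs, so callers can reuse it. A hedged usage sketch (the require path is assumed for illustration, not taken from the diff):

```js
const axios = require('axios');
const { logAxiosError } = require('~/utils/axios'); // path assumed

axios.get('https://example.invalid/resource').catch((error) => {
  const message = logAxiosError({ message: 'Failed to fetch resource.', error });
  // `message` can now be surfaced to the caller (e.g., in an API error
  // response) instead of living only in the server logs.
  console.log(message);
});
```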
@ -34,8 +34,14 @@ const mistralModels = {
'mistral-7b': 31990, // -10 from max
'mistral-small': 31990, // -10 from max
'mixtral-8x7b': 31990, // -10 from max
'mistral-large': 131000,
'mistral-large-2402': 127500,
'mistral-large-2407': 127500,
'pixtral-large': 131000,
'mistral-saba': 32000,
codestral: 256000,
'ministral-8b': 131000,
'ministral-3b': 131000,
};

const cohereModels = {

@ -52,6 +58,7 @@ const googleModels = {
gemini: 30720, // -2048 from max
'gemini-pro-vision': 12288,
'gemini-exp': 2000000,
'gemini-2.5': 1000000, // 1M input tokens, 64k output tokens
'gemini-2.0': 2000000,
'gemini-2.0-flash': 1000000,
'gemini-2.0-flash-lite': 1000000,
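These maps pair model-name keys with context-window sizes; a concrete model name is resolved against them by partial key matching. The sketch below approximates that idea — the real resolver's matching rules may differ:

```js
const googleModels = {
  gemini: 30720,
  'gemini-2.5': 1000000,
  'gemini-2.0': 2000000,
  'gemini-2.0-flash': 1000000,
};

// Pick the longest key contained in the model name, so 'gemini-2.0-flash-001'
// matches 'gemini-2.0-flash' (1000000) rather than the generic 'gemini' entry.
function maxTokensFor(model, map) {
  const match = Object.keys(map)
    .filter((key) => model.includes(key))
    .sort((a, b) => b.length - a.length)[0];
  return match ? map[match] : undefined;
}

console.log(maxTokensFor('gemini-2.0-flash-001', googleModels)); // 1000000
```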