Merge branch 'main' into feat/Custom-Token-Rates-for-Endpoints

Ruben Talstra 2025-05-14 21:20:25 +02:00 committed by GitHub
commit 9486599268
588 changed files with 35845 additions and 13907 deletions

View file

@ -13,7 +13,6 @@ const {
actionDomainSeparator,
} = require('librechat-data-provider');
const { refreshAccessToken } = require('~/server/services/TokenService');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { logger, getFlowStateManager, sendEvent } = require('~/config');
const { encryptV2, decryptV2 } = require('~/server/utils/crypto');
const { getActions, deleteActions } = require('~/models/Action');
@ -51,7 +50,7 @@ const validateAndUpdateTool = async ({ req, tool, assistant_id }) => {
return null;
}
const parsedDomain = await domainParser(req, domain, true);
const parsedDomain = await domainParser(domain, true);
if (!parsedDomain) {
return null;
@ -67,16 +66,14 @@ const validateAndUpdateTool = async ({ req, tool, assistant_id }) => {
*
* Necessary due to `[a-zA-Z0-9_-]*` Regex Validation, limited to a 64-character maximum.
*
* @param {Express.Request} req - The Express Request object.
* @param {string} domain - The domain name to encode/decode.
* @param {boolean} inverse - False to decode from base64, true to encode to base64.
* @returns {Promise<string>} Encoded or decoded domain string.
*/
async function domainParser(req, domain, inverse = false) {
async function domainParser(domain, inverse = false) {
if (!domain) {
return;
}
const domainsCache = getLogStores(CacheKeys.ENCODED_DOMAINS);
const cachedDomain = await domainsCache.get(domain);
if (inverse && cachedDomain) {
@ -123,47 +120,39 @@ async function loadActionSets(searchParams) {
* Creates a general tool for an entire action set.
*
* @param {Object} params - The parameters for loading action sets.
* @param {ServerRequest} params.req
* @param {string} params.userId
* @param {ServerResponse} params.res
* @param {Action} params.action - The action set. Necessary for decrypting authentication values.
* @param {ActionRequest} params.requestBuilder - The ActionRequest builder class to execute the API call.
* @param {string | undefined} [params.name] - The name of the tool.
* @param {string | undefined} [params.description] - The description for the tool.
* @param {import('zod').ZodTypeAny | undefined} [params.zodSchema] - The Zod schema for tool input validation/definition
* @param {{ oauth_client_id?: string; oauth_client_secret?: string; }} params.encrypted - The encrypted values for the action.
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
*/
async function createActionTool({
req,
userId,
res,
action,
requestBuilder,
zodSchema,
name,
description,
encrypted,
}) {
const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain);
if (!isDomainAllowed) {
return null;
}
const encrypted = {
oauth_client_id: action.metadata.oauth_client_id,
oauth_client_secret: action.metadata.oauth_client_secret,
};
action.metadata = await decryptMetadata(action.metadata);
/** @type {(toolInput: Object | string, config: GraphRunnableConfig) => Promise<unknown>} */
const _call = async (toolInput, config) => {
try {
/** @type {import('librechat-data-provider').ActionMetadataRuntime} */
const metadata = action.metadata;
const executor = requestBuilder.createExecutor();
const preparedExecutor = executor.setParams(toolInput);
const preparedExecutor = executor.setParams(toolInput ?? {});
if (metadata.auth && metadata.auth.type !== AuthTypeEnum.None) {
try {
const action_id = action.action_id;
const identifier = `${req.user.id}:${action.action_id}`;
if (metadata.auth.type === AuthTypeEnum.OAuth && metadata.auth.authorization_url) {
const action_id = action.action_id;
const identifier = `${userId}:${action.action_id}`;
const requestLogin = async () => {
const { args: _args, stepId, ...toolCall } = config.toolCall ?? {};
if (!stepId) {
@ -171,7 +160,7 @@ async function createActionTool({
}
const statePayload = {
nonce: nanoid(),
user: req.user.id,
user: userId,
action_id,
};
@ -198,26 +187,33 @@ async function createActionTool({
expires_at: Date.now() + Time.TWO_MINUTES,
},
};
const flowManager = await getFlowStateManager(getLogStores);
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
await flowManager.createFlowWithHandler(
`${identifier}:login`,
`${identifier}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`,
'oauth_login',
async () => {
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
logger.debug('Sent OAuth login request to client', { action_id, identifier });
return true;
},
config?.signal,
);
logger.debug('Waiting for OAuth Authorization response', { action_id, identifier });
const result = await flowManager.createFlow(identifier, 'oauth', {
state: stateToken,
userId: req.user.id,
client_url: metadata.auth.client_url,
redirect_uri: `${process.env.DOMAIN_CLIENT}/api/actions/${action_id}/oauth/callback`,
/** Encrypted values */
encrypted_oauth_client_id: encrypted.oauth_client_id,
encrypted_oauth_client_secret: encrypted.oauth_client_secret,
});
const result = await flowManager.createFlow(
identifier,
'oauth',
{
state: stateToken,
userId: userId,
client_url: metadata.auth.client_url,
redirect_uri: `${process.env.DOMAIN_CLIENT}/api/actions/${action_id}/oauth/callback`,
/** Encrypted values */
encrypted_oauth_client_id: encrypted.oauth_client_id,
encrypted_oauth_client_secret: encrypted.oauth_client_secret,
},
config?.signal,
);
logger.debug('Received OAuth Authorization response', { action_id, identifier });
data.delta.auth = undefined;
data.delta.expires_at = undefined;
@ -235,10 +231,10 @@ async function createActionTool({
};
const tokenPromises = [];
tokenPromises.push(findToken({ userId: req.user.id, type: 'oauth', identifier }));
tokenPromises.push(findToken({ userId, type: 'oauth', identifier }));
tokenPromises.push(
findToken({
userId: req.user.id,
userId,
type: 'oauth_refresh',
identifier: `${identifier}:refresh`,
}),
@ -261,18 +257,20 @@ async function createActionTool({
const refresh_token = await decryptV2(refreshTokenData.token);
const refreshTokens = async () =>
await refreshAccessToken({
userId,
identifier,
refresh_token,
userId: req.user.id,
client_url: metadata.auth.client_url,
encrypted_oauth_client_id: encrypted.oauth_client_id,
encrypted_oauth_client_secret: encrypted.oauth_client_secret,
});
const flowManager = await getFlowStateManager(getLogStores);
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
const refreshData = await flowManager.createFlowWithHandler(
`${identifier}:refresh`,
'oauth_refresh',
refreshTokens,
config?.signal,
);
metadata.oauth_access_token = refreshData.access_token;
if (refreshData.refresh_token) {
@ -308,9 +306,8 @@ async function createActionTool({
}
return response.data;
} catch (error) {
const logMessage = `API call to ${action.metadata.domain} failed`;
logAxiosError({ message: logMessage, error });
throw error;
const message = `API call to ${action.metadata.domain} failed:`;
return logAxiosError({ message, error });
}
};
@ -327,6 +324,27 @@ async function createActionTool({
};
}
/**
* Encrypts a sensitive value.
* @param {string} value
* @returns {Promise<string>}
*/
async function encryptSensitiveValue(value) {
// Encode API key to handle special characters like ":"
const encodedValue = encodeURIComponent(value);
return await encryptV2(encodedValue);
}
/**
* Decrypts a sensitive value.
* @param {string} value
* @returns {Promise<string>}
*/
async function decryptSensitiveValue(value) {
const decryptedValue = await decryptV2(value);
return decodeURIComponent(decryptedValue);
}
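A minimal round-trip sketch of these helpers (illustrative value; assumes `encryptV2`/`decryptV2` are inverses, as they are used elsewhere in this file):
const original = 'sk-test:key:with:colons'; // hypothetical secret containing ':'
const stored = await encryptSensitiveValue(original); // encodeURIComponent, then encryptV2
const restored = await decryptSensitiveValue(stored); // decryptV2, then decodeURIComponent
// restored === original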
/**
* Encrypts sensitive metadata values for an action.
*
@ -339,17 +357,19 @@ async function encryptMetadata(metadata) {
// ServiceHttp
if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) {
if (metadata.api_key) {
encryptedMetadata.api_key = await encryptV2(metadata.api_key);
encryptedMetadata.api_key = await encryptSensitiveValue(metadata.api_key);
}
}
// OAuth
else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) {
if (metadata.oauth_client_id) {
encryptedMetadata.oauth_client_id = await encryptV2(metadata.oauth_client_id);
encryptedMetadata.oauth_client_id = await encryptSensitiveValue(metadata.oauth_client_id);
}
if (metadata.oauth_client_secret) {
encryptedMetadata.oauth_client_secret = await encryptV2(metadata.oauth_client_secret);
encryptedMetadata.oauth_client_secret = await encryptSensitiveValue(
metadata.oauth_client_secret,
);
}
}
@ -368,17 +388,19 @@ async function decryptMetadata(metadata) {
// ServiceHttp
if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) {
if (metadata.api_key) {
decryptedMetadata.api_key = await decryptV2(metadata.api_key);
decryptedMetadata.api_key = await decryptSensitiveValue(metadata.api_key);
}
}
// OAuth
else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) {
if (metadata.oauth_client_id) {
decryptedMetadata.oauth_client_id = await decryptV2(metadata.oauth_client_id);
decryptedMetadata.oauth_client_id = await decryptSensitiveValue(metadata.oauth_client_id);
}
if (metadata.oauth_client_secret) {
decryptedMetadata.oauth_client_secret = await decryptV2(metadata.oauth_client_secret);
decryptedMetadata.oauth_client_secret = await decryptSensitiveValue(
metadata.oauth_client_secret,
);
}
}

View file

@ -78,20 +78,20 @@ describe('domainParser', () => {
// Non-azure request
it('does not return domain as is if not azure', async () => {
const domain = `example.com${actionDomainSeparator}test${actionDomainSeparator}`;
const result1 = await domainParser(reqNoAzure, domain, false);
const result2 = await domainParser(reqNoAzure, domain, true);
const result1 = await domainParser(domain, false);
const result2 = await domainParser(domain, true);
expect(result1).not.toEqual(domain);
expect(result2).not.toEqual(domain);
});
// Test for Empty or Null Inputs
it('returns undefined for null domain input', async () => {
const result = await domainParser(req, null, true);
const result = await domainParser(null, true);
expect(result).toBeUndefined();
});
it('returns undefined for empty domain input', async () => {
const result = await domainParser(req, '', true);
const result = await domainParser('', true);
expect(result).toBeUndefined();
});
@ -102,7 +102,7 @@ describe('domainParser', () => {
.toString('base64')
.substring(0, Constants.ENCODED_DOMAIN_LENGTH);
await domainParser(req, domain, true);
await domainParser(domain, true);
const cachedValue = await globalCache[encodedDomain];
expect(cachedValue).toEqual(Buffer.from(domain).toString('base64'));
@ -112,14 +112,14 @@ describe('domainParser', () => {
it('encodes domain exactly at threshold without modification', async () => {
const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - TLD.length) + TLD;
const expected = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, domain, true);
const result = await domainParser(domain, true);
expect(result).toEqual(expected);
});
it('encodes domain just below threshold without modification', async () => {
const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - 1 - TLD.length) + TLD;
const expected = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, domain, true);
const result = await domainParser(domain, true);
expect(result).toEqual(expected);
});
@ -129,7 +129,7 @@ describe('domainParser', () => {
const encodedDomain = Buffer.from(unicodeDomain)
.toString('base64')
.substring(0, Constants.ENCODED_DOMAIN_LENGTH);
const result = await domainParser(req, unicodeDomain, true);
const result = await domainParser(unicodeDomain, true);
expect(result).toEqual(encodedDomain);
});
@ -139,7 +139,6 @@ describe('domainParser', () => {
globalCache[encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH)] = encodedDomain; // Simulate caching
const result = await domainParser(
req,
encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH),
false,
);
@ -150,27 +149,27 @@ describe('domainParser', () => {
it('returns domain with replaced separators if no cached domain exists', async () => {
const domain = 'example.com';
const withSeparator = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, withSeparator, false);
const result = await domainParser(withSeparator, false);
expect(result).toEqual(domain);
});
it('returns domain with replaced separators when inverse is false and under encoding length', async () => {
const domain = 'examp.com';
const withSeparator = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, withSeparator, false);
const result = await domainParser(withSeparator, false);
expect(result).toEqual(domain);
});
it('replaces periods with actionDomainSeparator when inverse is true and under encoding length', async () => {
const domain = 'examp.com';
const expected = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, domain, true);
const result = await domainParser(domain, true);
expect(result).toEqual(expected);
});
it('encodes domain when length is above threshold and inverse is true', async () => {
const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH + 1).concat('.com');
const result = await domainParser(req, domain, true);
const result = await domainParser(domain, true);
expect(result).not.toEqual(domain);
expect(result.length).toBeLessThanOrEqual(Constants.ENCODED_DOMAIN_LENGTH);
});
@ -180,20 +179,20 @@ describe('domainParser', () => {
const encodedDomain = Buffer.from(
originalDomain.replace(/\./g, actionDomainSeparator),
).toString('base64');
const result = await domainParser(req, encodedDomain, false);
const result = await domainParser(encodedDomain, false);
expect(result).toEqual(encodedDomain);
});
it('decodes encoded value if cached and encoded value is provided, and inverse is false', async () => {
const originalDomain = 'example.com';
const encodedDomain = await domainParser(req, originalDomain, true);
const result = await domainParser(req, encodedDomain, false);
const encodedDomain = await domainParser(originalDomain, true);
const result = await domainParser(encodedDomain, false);
expect(result).toEqual(originalDomain);
});
it('handles invalid base64 encoded values gracefully', async () => {
const invalidBase64Domain = 'not_base64_encoded';
const result = await domainParser(req, invalidBase64Domain, false);
const result = await domainParser(invalidBase64Domain, false);
expect(result).toEqual(invalidBase64Domain);
});
});

View file

@ -1,15 +1,24 @@
const { FileSources, EModelEndpoint, getConfigDefaults } = require('librechat-data-provider');
const {
FileSources,
EModelEndpoint,
loadOCRConfig,
processMCPEnv,
getConfigDefaults,
} = require('librechat-data-provider');
const { checkVariables, checkHealth, checkConfig, checkAzureVariables } = require('./start/checks');
const { azureAssistantsDefaults, assistantsConfigSetup } = require('./start/assistants');
const { initializeAzureBlobService } = require('./Files/Azure/initialize');
const { initializeFirebase } = require('./Files/Firebase/initialize');
const loadCustomConfig = require('./Config/loadCustomConfig');
const handleRateLimits = require('./Config/handleRateLimits');
const { loadDefaultInterface } = require('./start/interface');
const { azureConfigSetup } = require('./start/azureOpenAI');
const { processModelSpecs } = require('./start/modelSpecs');
const { initializeS3 } = require('./Files/S3/initialize');
const { loadAndFormatTools } = require('./ToolService');
const { agentsConfigSetup } = require('./start/agents');
const { initializeRoles } = require('~/models/Role');
const { isEnabled } = require('~/server/utils');
const { getMCPManager } = require('~/config');
const paths = require('~/config/paths');
const { loadTokenRatesConfig } = require('./Config/loadTokenRatesConfig');
@ -27,9 +36,15 @@ const AppService = async (app) => {
const configDefaults = getConfigDefaults();
loadTokenRatesConfig(config, configDefaults);
const ocr = loadOCRConfig(config.ocr);
const filteredTools = config.filteredTools;
const includedTools = config.includedTools;
const fileStrategy = config.fileStrategy ?? configDefaults.fileStrategy;
const startBalance = process.env.START_BALANCE;
const balance = config.balance ?? {
enabled: isEnabled(process.env.CHECK_BALANCE),
startBalance: startBalance ? parseInt(startBalance, 10) : undefined,
};
const imageOutputType = config?.imageOutputType ?? configDefaults.imageOutputType;
process.env.CDN_PROVIDER = fileStrategy;
@ -39,9 +54,13 @@ const AppService = async (app) => {
if (fileStrategy === FileSources.firebase) {
initializeFirebase();
} else if (fileStrategy === FileSources.azure_blob) {
initializeAzureBlobService();
} else if (fileStrategy === FileSources.s3) {
initializeS3();
}
/** @type {Record<string, FunctionTool} */
/** @type {Record<string, FunctionTool>} */
const availableTools = loadAndFormatTools({
adminFilter: filteredTools,
adminIncluded: includedTools,
@ -49,8 +68,8 @@ const AppService = async (app) => {
});
if (config.mcpServers != null) {
const mcpManager = await getMCPManager();
await mcpManager.initializeMCP(config.mcpServers);
const mcpManager = getMCPManager();
await mcpManager.initializeMCP(config.mcpServers, processMCPEnv);
await mcpManager.mapAvailableTools(availableTools);
}
@ -59,6 +78,7 @@ const AppService = async (app) => {
const interfaceConfig = await loadDefaultInterface(config, configDefaults);
const defaultLocals = {
ocr,
paths,
fileStrategy,
socialLogins,
@ -67,6 +87,7 @@ const AppService = async (app) => {
availableTools,
imageOutputType,
interfaceConfig,
balance,
};
if (!Object.keys(config).length) {
@ -127,7 +148,7 @@ const AppService = async (app) => {
...defaultLocals,
fileConfig: config?.fileConfig,
secureImageLinks: config?.secureImageLinks,
modelSpecs: processModelSpecs(endpoints, config.modelSpecs),
modelSpecs: processModelSpecs(endpoints, config.modelSpecs, interfaceConfig),
...endpointLocals,
};
};

View file

@ -15,6 +15,9 @@ jest.mock('./Config/loadCustomConfig', () => {
Promise.resolve({
registration: { socialLogins: ['testLogin'] },
fileStrategy: 'testStrategy',
balance: {
enabled: true,
},
}),
);
});
@ -120,9 +123,13 @@ describe('AppService', () => {
},
},
paths: expect.anything(),
ocr: expect.anything(),
imageOutputType: expect.any(String),
fileConfig: undefined,
secureImageLinks: undefined,
balance: { enabled: true },
filteredTools: undefined,
includedTools: undefined,
});
});
@ -340,9 +347,6 @@ describe('AppService', () => {
process.env.FILE_UPLOAD_USER_MAX = 'initialUserMax';
process.env.FILE_UPLOAD_USER_WINDOW = 'initialUserWindow';
// Mock a custom configuration without specific rate limits
require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({}));
await AppService(app);
// Verify that process.env falls back to the initial values
@ -403,9 +407,6 @@ describe('AppService', () => {
process.env.IMPORT_USER_MAX = 'initialUserMax';
process.env.IMPORT_USER_WINDOW = 'initialUserWindow';
// Mock a custom configuration without specific rate limits
require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({}));
await AppService(app);
// Verify that process.env falls back to the initial values
@ -444,13 +445,27 @@ describe('AppService updating app.locals and issuing warnings', () => {
expect(app.locals.availableTools).toBeDefined();
expect(app.locals.fileStrategy).toEqual(FileSources.local);
expect(app.locals.socialLogins).toEqual(defaultSocialLogins);
expect(app.locals.balance).toEqual(
expect.objectContaining({
enabled: false,
startBalance: undefined,
}),
);
});
it('should update app.locals with values from loadCustomConfig', async () => {
// Mock loadCustomConfig to return a specific config object
// Mock loadCustomConfig to return a specific config object with a complete balance config
const customConfig = {
fileStrategy: 'firebase',
registration: { socialLogins: ['testLogin'] },
balance: {
enabled: false,
startBalance: 5000,
autoRefillEnabled: true,
refillIntervalValue: 15,
refillIntervalUnit: 'hours',
refillAmount: 5000,
},
};
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
Promise.resolve(customConfig),
@ -463,6 +478,7 @@ describe('AppService updating app.locals and issuing warnings', () => {
expect(app.locals.availableTools).toBeDefined();
expect(app.locals.fileStrategy).toEqual(customConfig.fileStrategy);
expect(app.locals.socialLogins).toEqual(customConfig.registration.socialLogins);
expect(app.locals.balance).toEqual(customConfig.balance);
});
it('should apply the assistants endpoint configuration correctly to app.locals', async () => {
@ -588,4 +604,33 @@ describe('AppService updating app.locals and issuing warnings', () => {
);
});
});
it('should not parse environment variable references in OCR config', async () => {
// Mock custom configuration with env variable references in OCR config
const mockConfig = {
ocr: {
apiKey: '${OCR_API_KEY_CUSTOM_VAR_NAME}',
baseURL: '${OCR_BASEURL_CUSTOM_VAR_NAME}',
strategy: 'mistral_ocr',
mistralModel: 'mistral-medium',
},
};
require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve(mockConfig));
// Set actual environment variables with different values
process.env.OCR_API_KEY_CUSTOM_VAR_NAME = 'actual-api-key';
process.env.OCR_BASEURL_CUSTOM_VAR_NAME = 'https://actual-ocr-url.com';
// Initialize app
const app = { locals: {} };
await AppService(app);
// Verify that the raw string references were preserved and not interpolated
expect(app.locals.ocr).toBeDefined();
expect(app.locals.ocr.apiKey).toEqual('${OCR_API_KEY_CUSTOM_VAR_NAME}');
expect(app.locals.ocr.baseURL).toEqual('${OCR_BASEURL_CUSTOM_VAR_NAME}');
expect(app.locals.ocr.strategy).toEqual('mistral_ocr');
expect(app.locals.ocr.mistralModel).toEqual('mistral-medium');
});
});

View file

@ -56,7 +56,7 @@ const logoutUser = async (req, refreshToken) => {
try {
req.session.destroy();
} catch (destroyErr) {
logger.error('[logoutUser] Failed to destroy session.', destroyErr);
logger.debug('[logoutUser] Failed to destroy session.', destroyErr);
}
return { status: 200, message: 'Logout successful' };
@ -91,7 +91,7 @@ const sendVerificationEmail = async (user) => {
subject: 'Verify your email',
payload: {
appName: process.env.APP_TITLE || 'LibreChat',
name: user.name,
name: user.name || user.username || user.email,
verificationLink: verificationLink,
year: new Date().getFullYear(),
},
@ -278,7 +278,7 @@ const requestPasswordReset = async (req) => {
subject: 'Password Reset Request',
payload: {
appName: process.env.APP_TITLE || 'LibreChat',
name: user.name,
name: user.name || user.username || user.email,
link: link,
year: new Date().getFullYear(),
},
@ -331,7 +331,7 @@ const resetPassword = async (userId, token, password) => {
subject: 'Password Reset Successfully',
payload: {
appName: process.env.APP_TITLE || 'LibreChat',
name: user.name,
name: user.name || user.username || user.email,
year: new Date().getFullYear(),
},
template: 'passwordReset.handlebars',
@ -414,7 +414,7 @@ const resendVerificationEmail = async (req) => {
subject: 'Verify your email',
payload: {
appName: process.env.APP_TITLE || 'LibreChat',
name: user.name,
name: user.name || user.username || user.email,
verificationLink: verificationLink,
year: new Date().getFullYear(),
},

View file

@ -1,5 +1,5 @@
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const { normalizeEndpointName } = require('~/server/utils');
const { normalizeEndpointName, isEnabled } = require('~/server/utils');
const loadCustomConfig = require('./loadCustomConfig');
const getLogStores = require('~/cache/getLogStores');
@ -23,6 +23,26 @@ async function getCustomConfig() {
return customConfig;
}
/**
* Retrieves the configuration object
* @function getBalanceConfig
* @returns {Promise<TCustomConfig['balance'] | null>}
* */
async function getBalanceConfig() {
const isLegacyEnabled = isEnabled(process.env.CHECK_BALANCE);
const startBalance = process.env.START_BALANCE;
/** @type {TCustomConfig['balance']} */
const config = {
enabled: isLegacyEnabled,
startBalance: startBalance != null && startBalance ? parseInt(startBalance, 10) : undefined,
};
const customConfig = await getCustomConfig();
if (!customConfig) {
return config;
}
return { ...config, ...(customConfig?.['balance'] ?? {}) };
}
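A hedged precedence example (all values hypothetical): because the env-derived defaults are spread first, keys from the custom config override them.
// process.env.CHECK_BALANCE = 'true', process.env.START_BALANCE = '1000'
// librechat.yaml: balance: { enabled: false }
// await getBalanceConfig() → { enabled: false, startBalance: 1000 }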
/**
*
* @param {string | EModelEndpoint} endpoint
@ -40,4 +60,4 @@ const getCustomEndpointConfig = async (endpoint) => {
);
};
module.exports = { getCustomConfig, getCustomEndpointConfig };
module.exports = { getCustomConfig, getBalanceConfig, getCustomEndpointConfig };

View file

@ -33,10 +33,12 @@ async function getEndpointsConfig(req) {
};
}
if (mergedConfig[EModelEndpoint.agents] && req.app.locals?.[EModelEndpoint.agents]) {
const { disableBuilder, capabilities, ..._rest } = req.app.locals[EModelEndpoint.agents];
const { disableBuilder, capabilities, allowedProviders, ..._rest } =
req.app.locals[EModelEndpoint.agents];
mergedConfig[EModelEndpoint.agents] = {
...mergedConfig[EModelEndpoint.agents],
allowedProviders,
disableBuilder,
capabilities,
};
@ -72,4 +74,15 @@ async function getEndpointsConfig(req) {
return endpointsConfig;
}
module.exports = { getEndpointsConfig };
/**
* @param {ServerRequest} req
* @param {import('librechat-data-provider').AgentCapabilities} capability
* @returns {Promise<boolean>}
*/
const checkCapability = async (req, capability) => {
const endpointsConfig = await getEndpointsConfig(req);
const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
return capabilities.includes(capability);
};
module.exports = { getEndpointsConfig, checkCapability };
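A usage sketch for the new `checkCapability` helper (route shape and error payload are hypothetical; `AgentCapabilities` comes from `librechat-data-provider`):
const { AgentCapabilities } = require('librechat-data-provider');
// e.g., inside an agents route handler:
if (!(await checkCapability(req, AgentCapabilities.ocr))) {
  return res.status(400).json({ error: 'OCR capability is not enabled' });
}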

View file

@ -1,19 +1,15 @@
const { isAgentsEndpoint, Constants } = require('librechat-data-provider');
const { loadAgent } = require('~/models/Agent');
const { logger } = require('~/config');
const buildOptions = (req, endpoint, parsedBody) => {
const {
spec,
iconURL,
agent_id,
instructions,
maxContextTokens,
resendFiles = true,
...model_parameters
} = parsedBody;
const buildOptions = (req, endpoint, parsedBody, endpointType) => {
const { spec, iconURL, agent_id, instructions, maxContextTokens, ...model_parameters } =
parsedBody;
const agentPromise = loadAgent({
req,
agent_id,
agent_id: isAgentsEndpoint(endpoint) ? agent_id : Constants.EPHEMERAL_AGENT_ID,
endpoint,
model_parameters,
}).catch((error) => {
logger.error(`[/agents/:${agent_id}] Error retrieving agent during build options step`, error);
return undefined;
@ -24,7 +20,7 @@ const buildOptions = (req, endpoint, parsedBody) => {
iconURL,
endpoint,
agent_id,
resendFiles,
endpointType,
instructions,
maxContextTokens,
model_parameters,

View file

@ -1,7 +1,12 @@
const { createContentAggregator, Providers } = require('@librechat/agents');
const {
Constants,
ErrorTypes,
EModelEndpoint,
EToolResources,
getResponseSender,
AgentCapabilities,
replaceSpecialVars,
providerEndpointMap,
} = require('librechat-data-provider');
const {
@ -15,10 +20,14 @@ const initCustom = require('~/server/services/Endpoints/custom/initialize');
const initGoogle = require('~/server/services/Endpoints/google/initialize');
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
const { getCustomEndpointConfig } = require('~/server/services/Config');
const { processFiles } = require('~/server/services/Files/process');
const { loadAgentTools } = require('~/server/services/ToolService');
const AgentClient = require('~/server/controllers/agents/client');
const { getConvoFiles } = require('~/models/Conversation');
const { getToolFilesByIds } = require('~/models/File');
const { getModelMaxTokens } = require('~/utils');
const { getAgent } = require('~/models/Agent');
const { getFiles } = require('~/models/File');
const { logger } = require('~/config');
const providerConfigMap = {
@ -34,37 +43,73 @@ const providerConfigMap = {
};
/**
*
* @param {Promise<Array<MongoFile | null>> | undefined} _attachments
* @param {AgentToolResources | undefined} _tool_resources
* @param {Object} params
* @param {ServerRequest} params.req
* @param {Promise<Array<MongoFile | null>> | undefined} [params.attachments]
* @param {Set<string>} params.requestFileSet
* @param {AgentToolResources | undefined} [params.tool_resources]
* @returns {Promise<{ attachments: Array<MongoFile | undefined> | undefined, tool_resources: AgentToolResources | undefined }>}
*/
const primeResources = async (_attachments, _tool_resources) => {
const primeResources = async ({
req,
attachments: _attachments,
tool_resources: _tool_resources,
requestFileSet,
}) => {
try {
/** @type {Array<MongoFile | undefined> | undefined} */
let attachments;
const tool_resources = _tool_resources ?? {};
const isOCREnabled = (req.app.locals?.[EModelEndpoint.agents]?.capabilities ?? []).includes(
AgentCapabilities.ocr,
);
if (tool_resources[EToolResources.ocr]?.file_ids && isOCREnabled) {
const context = await getFiles(
{
file_id: { $in: tool_resources.ocr.file_ids },
},
{},
{},
);
attachments = (attachments ?? []).concat(context);
}
if (!_attachments) {
return { attachments: undefined, tool_resources: _tool_resources };
return { attachments, tool_resources };
}
/** @type {Array<MongoFile | undefined> | undefined} */
const files = await _attachments;
const attachments = [];
const tool_resources = _tool_resources ?? {};
if (!attachments) {
/** @type {Array<MongoFile | undefined>} */
attachments = [];
}
for (const file of files) {
if (!file) {
continue;
}
if (file.metadata?.fileIdentifier) {
const execute_code = tool_resources.execute_code ?? {};
const execute_code = tool_resources[EToolResources.execute_code] ?? {};
if (!execute_code.files) {
tool_resources.execute_code = { ...execute_code, files: [] };
tool_resources[EToolResources.execute_code] = { ...execute_code, files: [] };
}
tool_resources.execute_code.files.push(file);
tool_resources[EToolResources.execute_code].files.push(file);
} else if (file.embedded === true) {
const file_search = tool_resources.file_search ?? {};
const file_search = tool_resources[EToolResources.file_search] ?? {};
if (!file_search.files) {
tool_resources.file_search = { ...file_search, files: [] };
tool_resources[EToolResources.file_search] = { ...file_search, files: [] };
}
tool_resources.file_search.files.push(file);
tool_resources[EToolResources.file_search].files.push(file);
} else if (
requestFileSet.has(file.file_id) &&
file.type.startsWith('image') &&
file.height &&
file.width
) {
const image_edit = tool_resources[EToolResources.image_edit] ?? {};
if (!image_edit.files) {
tool_resources[EToolResources.image_edit] = { ...image_edit, files: [] };
}
tool_resources[EToolResources.image_edit].files.push(file);
}
attachments.push(file);
@ -76,13 +121,26 @@ const primeResources = async (_attachments, _tool_resources) => {
}
};
/**
* @param {...string | number} values
* @returns {string | number | undefined}
*/
function optionalChainWithEmptyCheck(...values) {
for (const value of values) {
if (value !== undefined && value !== null && value !== '') {
return value;
}
}
return values[values.length - 1];
}
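Unlike `??`, this helper also treats the empty string as missing, and it falls back to the last argument when nothing qualifies:
optionalChainWithEmptyCheck(undefined, '', 4096); // → 4096 ('' is skipped; `undefined ?? '' ?? 4096` would yield '')
optionalChainWithEmptyCheck('', null, undefined); // → undefined (last value is the fallback)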
/**
* @param {object} params
* @param {ServerRequest} params.req
* @param {ServerResponse} params.res
* @param {Agent} params.agent
* @param {Set<string>} [params.allowedProviders]
* @param {object} [params.endpointOption]
* @param {AgentToolResources} [params.tool_resources]
* @param {boolean} [params.isInitialAgent]
* @returns {Promise<Agent>}
*/
@ -91,17 +149,58 @@ const initializeAgentOptions = async ({
res,
agent,
endpointOption,
tool_resources,
allowedProviders,
isInitialAgent = false,
}) => {
const { tools, toolContextMap } = await loadAgentTools({
if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) {
throw new Error(
`{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`,
);
}
let currentFiles;
/** @type {Array<MongoFile>} */
const requestFiles = req.body.files ?? [];
if (
isInitialAgent &&
req.body.conversationId != null &&
(agent.model_parameters?.resendFiles ?? true) === true
) {
const fileIds = (await getConvoFiles(req.body.conversationId)) ?? [];
/** @type {Set<EToolResources>} */
const toolResourceSet = new Set();
for (const tool of agent.tools) {
if (EToolResources[tool]) {
toolResourceSet.add(EToolResources[tool]);
}
}
const toolFiles = await getToolFilesByIds(fileIds, toolResourceSet);
if (requestFiles.length || toolFiles.length) {
currentFiles = await processFiles(requestFiles.concat(toolFiles));
}
} else if (isInitialAgent && requestFiles.length) {
currentFiles = await processFiles(requestFiles);
}
const { attachments, tool_resources } = await primeResources({
req,
res,
agent,
tool_resources,
attachments: currentFiles,
tool_resources: agent.tool_resources,
requestFileSet: new Set(requestFiles.map((file) => file.file_id)),
});
const provider = agent.provider;
const { tools, toolContextMap } = await loadAgentTools({
req,
res,
agent: {
id: agent.id,
tools: agent.tools,
provider,
model: agent.model,
},
tool_resources,
});
agent.endpoint = provider;
let getOptions = providerConfigMap[provider];
if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) {
@ -134,10 +233,18 @@ const initializeAgentOptions = async ({
endpointOption: _endpointOption,
});
if (
agent.endpoint === EModelEndpoint.azureOpenAI &&
options.llmConfig?.azureOpenAIApiInstanceName == null
) {
agent.provider = Providers.OPENAI;
}
if (options.provider != null) {
agent.provider = options.provider;
}
/** @type {import('@librechat/agents').ClientOptions} */
agent.model_parameters = Object.assign(model_parameters, options.llmConfig);
if (options.configOptions) {
agent.model_parameters.configuration = options.configOptions;
@ -147,6 +254,13 @@ const initializeAgentOptions = async ({
agent.model_parameters.model = agent.model;
}
if (agent.instructions && agent.instructions !== '') {
agent.instructions = replaceSpecialVars({
text: agent.instructions,
user: req.user,
});
}
if (typeof agent.artifacts === 'string' && agent.artifacts !== '') {
agent.additional_instructions = generateArtifactsPrompt({
endpoint: agent.provider,
@ -156,15 +270,23 @@ const initializeAgentOptions = async ({
const tokensModel =
agent.provider === EModelEndpoint.azureOpenAI ? agent.model : agent.model_parameters.model;
const maxTokens = optionalChainWithEmptyCheck(
agent.model_parameters.maxOutputTokens,
agent.model_parameters.maxTokens,
0,
);
const maxContextTokens = optionalChainWithEmptyCheck(
agent.model_parameters.maxContextTokens,
agent.max_context_tokens,
getModelMaxTokens(tokensModel, providerEndpointMap[provider]),
4096,
);
return {
...agent,
tools,
attachments,
toolContextMap,
maxContextTokens:
agent.max_context_tokens ??
getModelMaxTokens(tokensModel, providerEndpointMap[provider]) ??
4000,
maxContextTokens: (maxContextTokens - maxTokens) * 0.9,
};
};
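As a worked example with assumed numbers: if the agent sets `maxTokens = 1024` and the model resolves to an 8192-token context window, the effective budget is (8192 - 1024) * 0.9 = 6451.2 tokens; if nothing resolves at all, the fallbacks give (4096 - 0) * 0.9 = 3686.4.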
@ -197,12 +319,9 @@ const initializeClient = async ({ req, res, endpointOption }) => {
throw new Error('Agent not found');
}
const { attachments, tool_resources } = await primeResources(
endpointOption.attachments,
primaryAgent.tool_resources,
);
const agentConfigs = new Map();
/** @type {Set<string>} */
const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders);
// Handle primary agent
const primaryConfig = await initializeAgentOptions({
@ -210,7 +329,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
res,
agent: primaryAgent,
endpointOption,
tool_resources,
allowedProviders,
isInitialAgent: true,
});
@ -226,6 +345,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
res,
agent,
endpointOption,
allowedProviders,
});
agentConfigs.set(agentId, config);
}
@ -240,18 +360,25 @@ const initializeClient = async ({ req, res, endpointOption }) => {
const client = new AgentClient({
req,
agent: primaryConfig,
res,
sender,
attachments,
contentParts,
agentConfigs,
eventHandlers,
collectedUsage,
aggregateContent,
artifactPromises,
agent: primaryConfig,
spec: endpointOption.spec,
iconURL: endpointOption.iconURL,
agentConfigs,
endpoint: EModelEndpoint.agents,
attachments: primaryConfig.attachments,
endpointType: endpointOption.endpointType,
maxContextTokens: primaryConfig.maxContextTokens,
resendFiles: primaryConfig.model_parameters?.resendFiles ?? true,
endpoint:
primaryConfig.id === Constants.EPHEMERAL_AGENT_ID
? primaryConfig.endpoint
: EModelEndpoint.agents,
});
return { client };

View file

@ -2,7 +2,11 @@ const { CacheKeys } = require('librechat-data-provider');
const getLogStores = require('~/cache/getLogStores');
const { isEnabled } = require('~/server/utils');
const { saveConvo } = require('~/models');
const { logger } = require('~/config');
/**
* Add title to conversation in a way that avoids memory retention
*/
const addTitle = async (req, { text, response, client }) => {
const { TITLE_CONVO = true } = process.env ?? {};
if (!isEnabled(TITLE_CONVO)) {
@ -13,37 +17,55 @@ const addTitle = async (req, { text, response, client }) => {
return;
}
// If the request was aborted, don't generate the title.
if (client.abortController.signal.aborted) {
return;
}
const titleCache = getLogStores(CacheKeys.GEN_TITLE);
const key = `${req.user.id}-${response.conversationId}`;
const responseText =
response?.content && Array.isArray(response?.content)
? response.content.reduce((acc, block) => {
if (block?.type === 'text') {
return acc + block.text;
}
return acc;
}, '')
: (response?.content ?? response?.text ?? '');
/** @type {NodeJS.Timeout} */
let timeoutId;
try {
const timeoutPromise = new Promise((_, reject) => {
timeoutId = setTimeout(() => reject(new Error('Title generation timeout')), 25000);
}).catch((error) => {
logger.error('Title error:', error);
});
const title = await client.titleConvo({
text,
responseText,
conversationId: response.conversationId,
});
await titleCache.set(key, title, 120000);
await saveConvo(
req,
{
conversationId: response.conversationId,
title,
},
{ context: 'api/server/services/Endpoints/agents/title.js' },
);
let titlePromise;
let abortController = new AbortController();
if (client && typeof client.titleConvo === 'function') {
titlePromise = Promise.race([
client
.titleConvo({
text,
abortController,
})
.catch((error) => {
logger.error('Client title error:', error);
}),
timeoutPromise,
]);
} else {
return;
}
const title = await titlePromise;
if (!abortController.signal.aborted) {
abortController.abort();
}
if (timeoutId) {
clearTimeout(timeoutId);
}
await titleCache.set(key, title, 120000);
await saveConvo(
req,
{
conversationId: response.conversationId,
title,
},
{ context: 'api/server/services/Endpoints/agents/title.js' },
);
} catch (error) {
logger.error('Error generating title:', error);
}
};
module.exports = addTitle;
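The guard above is a standard race-with-timeout pattern; a reduced sketch (`generateTitle` is a hypothetical stand-in for `client.titleConvo`):
let timeoutId;
const timeout = new Promise((_, reject) => {
  timeoutId = setTimeout(() => reject(new Error('Title generation timeout')), 25000);
}).catch((error) => logger.error('Title error:', error)); // rejection becomes a resolved `undefined`
const title = await Promise.race([generateTitle(), timeout]); // `undefined` if the timeout fires first
clearTimeout(timeoutId);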

View file

@ -1,7 +1,7 @@
const { EModelEndpoint } = require('librechat-data-provider');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
const { AnthropicClient } = require('~/app');
const AnthropicClient = require('~/app/clients/AnthropicClient');
const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => {
const { ANTHROPIC_API_KEY, ANTHROPIC_REVERSE_PROXY, PROXY } = process.env;

View file

@ -13,11 +13,6 @@ const addTitle = async (req, { text, response, client }) => {
return;
}
// If the request was aborted, don't generate the title.
if (client.abortController.signal.aborted) {
return;
}
const titleCache = getLogStores(CacheKeys.GEN_TITLE);
const key = `${req.user.id}-${response.conversationId}`;

View file

@ -3,7 +3,6 @@ const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
const { getAssistant } = require('~/models/Assistant');
const buildOptions = async (endpoint, parsedBody) => {
const { promptPrefix, assistant_id, iconURL, greeting, spec, artifacts, ...modelOptions } =
parsedBody;
const endpointOption = removeNullishValues({

View file

@ -23,8 +23,9 @@ const initializeClient = async ({ req, res, endpointOption }) => {
const agent = {
id: EModelEndpoint.bedrock,
name: endpointOption.name,
instructions: endpointOption.promptPrefix,
provider: EModelEndpoint.bedrock,
endpoint: EModelEndpoint.bedrock,
instructions: endpointOption.promptPrefix,
model: endpointOption.model_parameters.model,
model_parameters: endpointOption.model_parameters,
};
@ -54,6 +55,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
const client = new AgentClient({
req,
res,
agent,
sender,
// tools,

View file

@ -8,7 +8,7 @@ const {
removeNullishValues,
} = require('librechat-data-provider');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const { sleep } = require('~/server/utils');
const { createHandleLLMNewToken } = require('~/app/clients/generators');
const getOptions = async ({ req, overrideModel, endpointOption }) => {
const {
@ -90,12 +90,7 @@ const getOptions = async ({ req, overrideModel, endpointOption }) => {
llmConfig.callbacks = [
{
handleLLMNewToken: async () => {
if (!streamRate) {
return;
}
await sleep(streamRate);
},
handleLLMNewToken: createHandleLLMNewToken(streamRate),
},
];

View file

@ -9,10 +9,11 @@ const { Providers } = require('@librechat/agents');
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm');
const { getCustomEndpointConfig } = require('~/server/services/Config');
const { createHandleLLMNewToken } = require('~/app/clients/generators');
const { fetchModels } = require('~/server/services/ModelService');
const { isUserProvided, sleep } = require('~/server/utils');
const OpenAIClient = require('~/app/clients/OpenAIClient');
const { isUserProvided } = require('~/server/utils');
const getLogStores = require('~/cache/getLogStores');
const { OpenAIClient } = require('~/app');
const { PROXY } = process.env;
@ -148,9 +149,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
}
options.llmConfig.callbacks = [
{
handleLLMNewToken: async () => {
await sleep(customOptions.streamRate);
},
handleLLMNewToken: createHandleLLMNewToken(clientOptions.streamRate),
},
];
return options;

View file

@ -6,9 +6,10 @@ const {
} = require('librechat-data-provider');
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm');
const { isEnabled, isUserProvided, sleep } = require('~/server/utils');
const { createHandleLLMNewToken } = require('~/app/clients/generators');
const { isEnabled, isUserProvided } = require('~/server/utils');
const OpenAIClient = require('~/app/clients/OpenAIClient');
const { getAzureCredentials } = require('~/utils');
const { OpenAIClient } = require('~/app');
const initializeClient = async ({
req,
@ -135,22 +136,18 @@ const initializeClient = async ({
}
if (optionsOnly) {
clientOptions = Object.assign(
{
modelOptions: endpointOption.model_parameters,
},
clientOptions,
);
const modelOptions = endpointOption.model_parameters;
modelOptions.model = modelName;
clientOptions = Object.assign({ modelOptions }, clientOptions);
clientOptions.modelOptions.user = req.user.id;
const options = getLLMConfig(apiKey, clientOptions);
if (!clientOptions.streamRate) {
const streamRate = clientOptions.streamRate;
if (!streamRate) {
return options;
}
options.llmConfig.callbacks = [
{
handleLLMNewToken: async () => {
await sleep(clientOptions.streamRate);
},
handleLLMNewToken: createHandleLLMNewToken(streamRate),
},
];
return options;

View file

@ -28,7 +28,7 @@ const { isEnabled } = require('~/server/utils');
* @returns {Object} Configuration options for creating an LLM instance.
*/
function getLLMConfig(apiKey, options = {}, endpoint = null) {
const {
let {
modelOptions = {},
reverseProxyUrl,
defaultQuery,
@ -50,10 +50,32 @@ function getLLMConfig(apiKey, options = {}, endpoint = null) {
if (addParams && typeof addParams === 'object') {
Object.assign(llmConfig, addParams);
}
/** Note: OpenAI Web Search models do not support any known parameters besides `max_tokens` */
if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) {
const searchExcludeParams = [
'frequency_penalty',
'presence_penalty',
'temperature',
'top_p',
'top_k',
'stop',
'logit_bias',
'seed',
'response_format',
'n',
'logprobs',
'user',
];
dropParams = dropParams || [];
dropParams = [...new Set([...dropParams, ...searchExcludeParams])];
}
if (dropParams && Array.isArray(dropParams)) {
dropParams.forEach((param) => {
delete llmConfig[param];
if (llmConfig[param]) {
llmConfig[param] = undefined;
}
});
}
@ -114,7 +136,7 @@ function getLLMConfig(apiKey, options = {}, endpoint = null) {
Object.assign(llmConfig, azure);
llmConfig.model = llmConfig.azureOpenAIApiDeploymentName;
} else {
llmConfig.openAIApiKey = apiKey;
llmConfig.apiKey = apiKey;
// Object.assign(llmConfig, {
// configuration: { apiKey },
// });
@ -131,6 +153,12 @@ function getLLMConfig(apiKey, options = {}, endpoint = null) {
delete llmConfig.reasoning_effort;
}
if (llmConfig?.['max_tokens'] != null) {
/** @type {number} */
llmConfig.maxTokens = llmConfig['max_tokens'];
delete llmConfig['max_tokens'];
}
return {
/** @type {OpenAIClientOptions} */
llmConfig,

View file

@ -13,11 +13,6 @@ const addTitle = async (req, { text, response, client }) => {
return;
}
// If the request was aborted and is not azure, don't generate the title.
if (!client.azure && client.abortController.signal.aborted) {
return;
}
const titleCache = getLogStores(CacheKeys.GEN_TITLE);
const key = `${req.user.id}-${response.conversationId}`;

View file

@ -7,6 +7,78 @@ const { getCustomConfig } = require('~/server/services/Config');
const { genAzureEndpoint } = require('~/utils');
const { logger } = require('~/config');
/**
* Maps MIME types to their corresponding file extensions for audio files.
* @type {Object}
*/
const MIME_TO_EXTENSION_MAP = {
// MP4 container formats
'audio/mp4': 'm4a',
'audio/x-m4a': 'm4a',
// Ogg formats
'audio/ogg': 'ogg',
'audio/vorbis': 'ogg',
'application/ogg': 'ogg',
// Wave formats
'audio/wav': 'wav',
'audio/x-wav': 'wav',
'audio/wave': 'wav',
// MP3 formats
'audio/mp3': 'mp3',
'audio/mpeg': 'mp3',
'audio/mpeg3': 'mp3',
// WebM formats
'audio/webm': 'webm',
// Additional formats
'audio/flac': 'flac',
'audio/x-flac': 'flac',
};
/**
* Gets the file extension from the MIME type.
* @param {string} mimeType - The MIME type.
* @returns {string} The file extension.
*/
function getFileExtensionFromMime(mimeType) {
// Default fallback
if (!mimeType) {
return 'webm';
}
// Direct lookup (fastest)
const extension = MIME_TO_EXTENSION_MAP[mimeType];
if (extension) {
return extension;
}
// Try to extract subtype as fallback
const subtype = mimeType.split('/')[1]?.toLowerCase();
// If subtype matches a known extension
if (['mp3', 'mp4', 'ogg', 'wav', 'webm', 'm4a', 'flac'].includes(subtype)) {
return subtype === 'mp4' ? 'm4a' : subtype;
}
// Generic checks for partial matches
if (subtype?.includes('mp4') || subtype?.includes('m4a')) {
return 'm4a';
}
if (subtype?.includes('ogg')) {
return 'ogg';
}
if (subtype?.includes('wav')) {
return 'wav';
}
if (subtype?.includes('mp3') || subtype?.includes('mpeg')) {
return 'mp3';
}
if (subtype?.includes('webm')) {
return 'webm';
}
return 'webm'; // Default fallback
}
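A few illustrative calls (MIME strings hypothetical) showing the lookup order:
getFileExtensionFromMime('audio/x-m4a'); // 'm4a' via the direct map
getFileExtensionFromMime('audio/mp4;codecs=opus'); // 'm4a' via the partial-match fallback
getFileExtensionFromMime(undefined); // 'webm' default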
/**
* Service class for handling Speech-to-Text (STT) operations.
* @class
@ -170,8 +242,10 @@ class STTService {
throw new Error('Invalid provider');
}
const fileExtension = getFileExtensionFromMime(audioFile.mimetype);
const audioReadStream = Readable.from(audioBuffer);
audioReadStream.path = 'audio.wav';
audioReadStream.path = `audio.${fileExtension}`;
const [url, data, headers] = strategy.call(this, sttSchema, audioReadStream, audioFile);

View file

@ -1,4 +1,10 @@
const { CacheKeys, findLastSeparatorIndex, SEPARATORS, Time } = require('librechat-data-provider');
const {
Time,
CacheKeys,
SEPARATORS,
parseTextParts,
findLastSeparatorIndex,
} = require('librechat-data-provider');
const { getMessage } = require('~/models/Message');
const { getLogStores } = require('~/cache');
@ -84,10 +90,11 @@ function createChunkProcessor(user, messageId) {
notFoundCount++;
return [];
} else {
const text = message.content?.length > 0 ? parseTextParts(message.content) : message.text;
messageCache.set(
messageId,
{
text: message.text,
text,
complete: true,
},
Time.FIVE_MINUTES,
@ -95,7 +102,7 @@ function createChunkProcessor(user, messageId) {
}
const text = typeof message === 'string' ? message : message.text;
const complete = typeof message === 'string' ? false : message.complete ?? true;
const complete = typeof message === 'string' ? false : (message.complete ?? true);
if (text === processedText) {
noChangeCount++;

View file

@ -0,0 +1,253 @@
const fs = require('fs');
const path = require('path');
const mime = require('mime');
const axios = require('axios');
const fetch = require('node-fetch');
const { logger } = require('~/config');
const { getAzureContainerClient } = require('./initialize');
const defaultBasePath = 'images';
const { AZURE_STORAGE_PUBLIC_ACCESS = 'true', AZURE_CONTAINER_NAME = 'files' } = process.env;
/**
* Uploads a buffer to Azure Blob Storage.
*
* Files will be stored at the path: {basePath}/{userId}/{fileName} within the container.
*
* @param {Object} params
* @param {string} params.userId - The user's id.
* @param {Buffer} params.buffer - The buffer to upload.
* @param {string} params.fileName - The name of the file.
* @param {string} [params.basePath='images'] - The base folder within the container.
* @param {string} [params.containerName] - The Azure Blob container name.
* @returns {Promise<string>} The URL of the uploaded blob.
*/
async function saveBufferToAzure({
userId,
buffer,
fileName,
basePath = defaultBasePath,
containerName,
}) {
try {
const containerClient = getAzureContainerClient(containerName);
const access = AZURE_STORAGE_PUBLIC_ACCESS?.toLowerCase() === 'true' ? 'blob' : undefined;
// Create the container if it doesn't exist. This is done per operation.
await containerClient.createIfNotExists({ access });
const blobPath = `${basePath}/${userId}/${fileName}`;
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
await blockBlobClient.uploadData(buffer);
return blockBlobClient.url;
} catch (error) {
logger.error('[saveBufferToAzure] Error uploading buffer:', error);
throw error;
}
}
/**
* Saves a file from a URL to Azure Blob Storage.
*
* @param {Object} params
* @param {string} params.userId - The user's id.
* @param {string} params.URL - The URL of the file.
* @param {string} params.fileName - The name of the file.
* @param {string} [params.basePath='images'] - The base folder within the container.
* @param {string} [params.containerName] - The Azure Blob container name.
* @returns {Promise<string>} The URL of the uploaded blob.
*/
async function saveURLToAzure({
userId,
URL,
fileName,
basePath = defaultBasePath,
containerName,
}) {
try {
const response = await fetch(URL);
const buffer = await response.buffer();
return await saveBufferToAzure({ userId, buffer, fileName, basePath, containerName });
} catch (error) {
logger.error('[saveURLToAzure] Error uploading file from URL:', error);
throw error;
}
}
/**
* Retrieves a blob URL from Azure Blob Storage.
*
* @param {Object} params
* @param {string} params.fileName - The file name.
* @param {string} [params.basePath='images'] - The base folder used during upload.
* @param {string} [params.userId] - If files are stored in a user-specific directory.
* @param {string} [params.containerName] - The Azure Blob container name.
* @returns {Promise<string>} The blob's URL.
*/
async function getAzureURL({ fileName, basePath = defaultBasePath, userId, containerName }) {
try {
const containerClient = getAzureContainerClient(containerName);
const blobPath = userId ? `${basePath}/${userId}/${fileName}` : `${basePath}/${fileName}`;
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
return blockBlobClient.url;
} catch (error) {
logger.error('[getAzureURL] Error retrieving blob URL:', error);
throw error;
}
}
/**
* Deletes a blob from Azure Blob Storage.
*
* @param {Object} params
* @param {ServerRequest} params.req - The Express request object.
* @param {MongoFile} params.file - The file object.
*/
async function deleteFileFromAzure(req, file) {
try {
const containerClient = getAzureContainerClient(AZURE_CONTAINER_NAME);
const blobPath = file.filepath.split(`${AZURE_CONTAINER_NAME}/`)[1];
if (!blobPath.includes(req.user.id)) {
throw new Error('User ID not found in blob path');
}
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
await blockBlobClient.delete();
logger.debug('[deleteFileFromAzure] Blob deleted successfully from Azure Blob Storage');
} catch (error) {
logger.error('[deleteFileFromAzure] Error deleting blob:', error);
if (error.statusCode === 404) {
return;
}
throw error;
}
}
/**
* Streams a file from disk directly to Azure Blob Storage without loading
* the entire file into memory.
*
* @param {Object} params
* @param {string} params.userId - The user's id.
* @param {string} params.filePath - The local file path to upload.
* @param {string} params.fileName - The name of the file in Azure.
* @param {string} [params.basePath='images'] - The base folder within the container.
* @param {string} [params.containerName] - The Azure Blob container name.
* @returns {Promise<string>} The URL of the uploaded blob.
*/
async function streamFileToAzure({
userId,
filePath,
fileName,
basePath = defaultBasePath,
containerName,
}) {
try {
const containerClient = getAzureContainerClient(containerName);
const access = AZURE_STORAGE_PUBLIC_ACCESS?.toLowerCase() === 'true' ? 'blob' : undefined;
// Create the container if it doesn't exist
await containerClient.createIfNotExists({ access });
const blobPath = `${basePath}/${userId}/${fileName}`;
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
// Get file size for proper content length
const stats = await fs.promises.stat(filePath);
// Create read stream from the file
const fileStream = fs.createReadStream(filePath);
const blobContentType = mime.getType(fileName);
await blockBlobClient.uploadStream(
fileStream,
undefined, // Use default concurrency (5)
undefined, // Use default buffer size (8MB)
{
blobHTTPHeaders: {
blobContentType,
},
onProgress: (progress) => {
logger.debug(
`[streamFileToAzure] Upload progress: ${progress.loadedBytes} bytes of ${stats.size}`,
);
},
},
);
return blockBlobClient.url;
} catch (error) {
logger.error('[streamFileToAzure] Error streaming file:', error);
throw error;
}
}
/**
* Uploads a file from the local file system to Azure Blob Storage.
*
* This function reads the file from disk and then uploads it to Azure Blob Storage
* at the path: {basePath}/{userId}/{fileName}.
*
* @param {Object} params
* @param {object} params.req - The Express request object.
* @param {Express.Multer.File} params.file - The file object.
* @param {string} params.file_id - The file id.
* @param {string} [params.basePath='images'] - The base folder within the container.
* @param {string} [params.containerName] - The Azure Blob container name.
* @returns {Promise<{ filepath: string, bytes: number }>} An object containing the blob URL and its byte size.
*/
async function uploadFileToAzure({
req,
file,
file_id,
basePath = defaultBasePath,
containerName,
}) {
try {
const inputFilePath = file.path;
const stats = await fs.promises.stat(inputFilePath);
const bytes = stats.size;
const userId = req.user.id;
const fileName = `${file_id}__${path.basename(inputFilePath)}`;
const fileURL = await streamFileToAzure({
userId,
filePath: inputFilePath,
fileName,
basePath,
containerName,
});
return { filepath: fileURL, bytes };
} catch (error) {
logger.error('[uploadFileToAzure] Error uploading file:', error);
throw error;
}
}
/**
* Retrieves a readable stream for a blob from Azure Blob Storage.
*
* @param {object} _req - The Express request object.
* @param {string} fileURL - The URL of the blob.
* @returns {Promise<ReadableStream>} A readable stream of the blob.
*/
async function getAzureFileStream(_req, fileURL) {
try {
const response = await axios({
method: 'get',
url: fileURL,
responseType: 'stream',
});
return response.data;
} catch (error) {
logger.error('[getAzureFileStream] Error getting blob stream:', error);
throw error;
}
}
module.exports = {
saveBufferToAzure,
saveURLToAzure,
getAzureURL,
deleteFileFromAzure,
uploadFileToAzure,
getAzureFileStream,
};
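A minimal upload sketch against the new module (require path and values are assumptions; container setup is handled by `getAzureContainerClient`):
const { saveBufferToAzure } = require('~/server/services/Files/Azure/crud'); // path assumed
const url = await saveBufferToAzure({
  userId: req.user.id,
  buffer: Buffer.from('hello world'),
  fileName: 'note.txt',
}); // → https://<account>.blob.core.windows.net/files/images/<userId>/note.txt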

View file

@ -0,0 +1,124 @@
const fs = require('fs');
const path = require('path');
const sharp = require('sharp');
const { resizeImageBuffer } = require('../images/resize');
const { updateUser } = require('~/models/userMethods');
const { updateFile } = require('~/models/File');
const { logger } = require('~/config');
const { saveBufferToAzure } = require('./crud');
/**
* Uploads an image file to Azure Blob Storage.
 * It resizes and converts the image, mirroring the Firebase implementation.
*
* @param {Object} params
* @param {object} params.req - The Express request object.
* @param {Express.Multer.File} params.file - The file object.
* @param {string} params.file_id - The file id.
* @param {EModelEndpoint} params.endpoint - The endpoint parameters.
* @param {string} [params.resolution='high'] - The image resolution.
* @param {string} [params.basePath='images'] - The base folder within the container.
* @param {string} [params.containerName] - The Azure Blob container name.
* @returns {Promise<{ filepath: string, bytes: number, width: number, height: number }>}
*/
async function uploadImageToAzure({
req,
file,
file_id,
endpoint,
resolution = 'high',
basePath = 'images',
containerName,
}) {
try {
const inputFilePath = file.path;
const inputBuffer = await fs.promises.readFile(inputFilePath);
const {
buffer: resizedBuffer,
width,
height,
} = await resizeImageBuffer(inputBuffer, resolution, endpoint);
const extension = path.extname(inputFilePath);
const userId = req.user.id;
let webPBuffer;
let fileName = `${file_id}__${path.basename(inputFilePath)}`;
const targetExtension = `.${req.app.locals.imageOutputType}`;
if (extension.toLowerCase() === targetExtension) {
webPBuffer = resizedBuffer;
} else {
webPBuffer = await sharp(resizedBuffer).toFormat(req.app.locals.imageOutputType).toBuffer();
const extRegExp = new RegExp(path.extname(fileName) + '$');
fileName = fileName.replace(extRegExp, targetExtension);
if (!path.extname(fileName)) {
fileName += targetExtension;
}
}
const downloadURL = await saveBufferToAzure({
userId,
buffer: webPBuffer,
fileName,
basePath,
containerName,
});
await fs.promises.unlink(inputFilePath);
const bytes = Buffer.byteLength(webPBuffer);
return { filepath: downloadURL, bytes, width, height };
} catch (error) {
logger.error('[uploadImageToAzure] Error uploading image:', error);
throw error;
}
}
/**
* Prepares the image URL and updates the file record.
*
* @param {object} req - The Express request object.
* @param {MongoFile} file - The file object.
* @returns {Promise<[MongoFile, string]>}
*/
async function prepareAzureImageURL(req, file) {
const { filepath } = file;
const promises = [];
promises.push(updateFile({ file_id: file.file_id }));
promises.push(filepath);
return await Promise.all(promises);
}
/**
* Uploads and processes a user's avatar to Azure Blob Storage.
*
* @param {Object} params
* @param {Buffer} params.buffer - The avatar image buffer.
* @param {string} params.userId - The user's id.
 * @param {string} params.manual - 'true' or 'false' flag indicating a manual update.
* @param {string} [params.basePath='images'] - The base folder within the container.
* @param {string} [params.containerName] - The Azure Blob container name.
* @returns {Promise<string>} The URL of the avatar.
*/
async function processAzureAvatar({ buffer, userId, manual, basePath = 'images', containerName }) {
try {
const downloadURL = await saveBufferToAzure({
userId,
buffer,
fileName: 'avatar.png',
basePath,
containerName,
});
const isManual = manual === 'true';
const url = `${downloadURL}?manual=${isManual}`;
if (isManual) {
await updateUser(userId, { avatar: url });
}
return url;
} catch (error) {
logger.error('[processAzureAvatar] Error uploading profile picture to Azure:', error);
throw error;
}
}
module.exports = {
uploadImageToAzure,
prepareAzureImageURL,
processAzureAvatar,
};
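/*
 * Usage sketch (illustrative): uploading a user avatar and persisting its URL.
 *
 *   const { processAzureAvatar } = require('./images');
 *   const url = await processAzureAvatar({
 *     buffer,            // PNG image buffer
 *     userId: 'user123', // hypothetical user id
 *     manual: 'true',    // 'true' also updates the user's avatar field
 *   });
 *   // url ends with `?manual=true`, e.g. .../images/user123/avatar.png?manual=true
 */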

View file

@ -0,0 +1,9 @@
const crud = require('./crud');
const images = require('./images');
const initialize = require('./initialize');
module.exports = {
...crud,
...images,
...initialize,
};

View file

@ -0,0 +1,55 @@
const { BlobServiceClient } = require('@azure/storage-blob');
const { logger } = require('~/config');
let blobServiceClient = null;
let azureWarningLogged = false;
/**
* Initializes the Azure Blob Service client.
* This function establishes a connection by checking if a connection string is provided.
* If available, the connection string is used; otherwise, Managed Identity (via DefaultAzureCredential) is utilized.
* Note: Container creation (and its public access settings) is handled later in the CRUD functions.
* @returns {BlobServiceClient|null} The initialized client, or null if the required configuration is missing.
*/
const initializeAzureBlobService = () => {
if (blobServiceClient) {
return blobServiceClient;
}
const connectionString = process.env.AZURE_STORAGE_CONNECTION_STRING;
if (connectionString) {
blobServiceClient = BlobServiceClient.fromConnectionString(connectionString);
logger.info('Azure Blob Service initialized using connection string');
} else {
const { DefaultAzureCredential } = require('@azure/identity');
const accountName = process.env.AZURE_STORAGE_ACCOUNT_NAME;
if (!accountName) {
if (!azureWarningLogged) {
logger.error(
'[initializeAzureBlobService] Azure Blob Service not initialized. Connection string missing and AZURE_STORAGE_ACCOUNT_NAME not provided.',
);
azureWarningLogged = true;
}
return null;
}
const url = `https://${accountName}.blob.core.windows.net`;
const credential = new DefaultAzureCredential();
blobServiceClient = new BlobServiceClient(url, credential);
logger.info('Azure Blob Service initialized using Managed Identity');
}
return blobServiceClient;
};
/**
* Retrieves the Azure ContainerClient for the given container name.
* @param {string} [containerName=process.env.AZURE_CONTAINER_NAME || 'files'] - The container name.
 * @returns {ContainerClient|null} The Azure ContainerClient, or null if the Blob Service client could not be initialized.
*/
const getAzureContainerClient = (containerName = process.env.AZURE_CONTAINER_NAME || 'files') => {
const serviceClient = initializeAzureBlobService();
return serviceClient ? serviceClient.getContainerClient(containerName) : null;
};
module.exports = {
initializeAzureBlobService,
getAzureContainerClient,
};
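/*
 * Configuration sketch (illustrative): the client initializes from one of two
 * environment setups, checked in this order:
 *
 *   # 1. Connection string (takes precedence)
 *   AZURE_STORAGE_CONNECTION_STRING=DefaultEndpointsProtocol=https;AccountName=...
 *
 *   # 2. Managed Identity via DefaultAzureCredential
 *   AZURE_STORAGE_ACCOUNT_NAME=myaccount
 *
 * Optional: AZURE_CONTAINER_NAME selects the container (defaults to 'files').
 */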

View file

@ -1,8 +1,10 @@
const axios = require('axios');
const FormData = require('form-data');
const { getCodeBaseURL } = require('@librechat/agents');
const { createAxiosInstance } = require('~/config');
const { logAxiosError } = require('~/utils');
const axios = createAxiosInstance();
const MAX_FILE_SIZE = 150 * 1024 * 1024;
/**
@ -27,21 +29,15 @@ async function getCodeOutputDownloadStream(fileIdentifier, apiKey) {
timeout: 15000,
};
if (process.env.PROXY) {
options.proxy = {
host: process.env.PROXY,
protocol: process.env.PROXY.startsWith('https') ? 'https' : 'http',
};
}
const response = await axios(options);
return response;
} catch (error) {
logAxiosError({
message: `Error downloading code environment file stream: ${error.message}`,
error,
});
throw new Error(`Error downloading file: ${error.message}`);
throw new Error(
logAxiosError({
message: `Error downloading code environment file stream: ${error.message}`,
error,
}),
);
}
}
@ -79,13 +75,6 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = ''
maxBodyLength: MAX_FILE_SIZE,
};
if (process.env.PROXY) {
options.proxy = {
host: process.env.PROXY,
protocol: process.env.PROXY.startsWith('https') ? 'https' : 'http',
};
}
const response = await axios.post(`${baseURL}/upload`, form, options);
/** @type {{ message: string; session_id: string; files: Array<{ fileId: string; filename: string }> }} */
@ -101,11 +90,12 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = ''
return `${fileIdentifier}?entity_id=${entity_id}`;
} catch (error) {
logAxiosError({
message: `Error uploading code environment file: ${error.message}`,
error,
});
throw new Error(`Error uploading code environment file: ${error.message}`);
throw new Error(
logAxiosError({
message: `Error uploading code environment file: ${error.message}`,
error,
}),
);
}
}

View file

@ -309,6 +309,24 @@ function getLocalFileStream(req, filepath) {
throw new Error(`Invalid file path: ${filepath}`);
}
return fs.createReadStream(fullPath);
} else if (filepath.includes('/images/')) {
const basePath = filepath.split('/images/')[1];
if (!basePath) {
logger.warn(`Invalid base path: ${filepath}`);
throw new Error(`Invalid file path: ${filepath}`);
}
const fullPath = path.join(req.app.locals.paths.imageOutput, basePath);
const publicDir = req.app.locals.paths.imageOutput;
const rel = path.relative(publicDir, fullPath);
if (rel.startsWith('..') || path.isAbsolute(rel) || rel.includes(`..${path.sep}`)) {
logger.warn(`Invalid relative file path: ${filepath}`);
throw new Error(`Invalid file path: ${filepath}`);
}
return fs.createReadStream(fullPath);
}
return fs.createReadStream(filepath);

View file

@ -0,0 +1,230 @@
// ~/server/services/Files/MistralOCR/crud.js
const fs = require('fs');
const path = require('path');
const FormData = require('form-data');
const { FileSources, envVarRegex, extractEnvVariable } = require('librechat-data-provider');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { logger, createAxiosInstance } = require('~/config');
const { logAxiosError } = require('~/utils/axios');
const axios = createAxiosInstance();
/**
* Uploads a document to Mistral API using file streaming to avoid loading the entire file into memory
*
* @param {Object} params Upload parameters
* @param {string} params.filePath The path to the file on disk
* @param {string} [params.fileName] Optional filename to use (defaults to the name from filePath)
* @param {string} params.apiKey Mistral API key
 * @param {string} [params.baseURL='https://api.mistral.ai/v1'] Mistral API base URL
 * @returns {Promise<Object>} The response from the Mistral API
*/
async function uploadDocumentToMistral({
filePath,
fileName = '',
apiKey,
baseURL = 'https://api.mistral.ai/v1',
}) {
const form = new FormData();
form.append('purpose', 'ocr');
const actualFileName = fileName || path.basename(filePath);
const fileStream = fs.createReadStream(filePath);
form.append('file', fileStream, { filename: actualFileName });
return axios
.post(`${baseURL}/files`, form, {
headers: {
Authorization: `Bearer ${apiKey}`,
...form.getHeaders(),
},
maxBodyLength: Infinity,
maxContentLength: Infinity,
})
.then((res) => res.data)
.catch((error) => {
logger.error('Error uploading document to Mistral:', error.message);
throw error;
});
}
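/**
 * Retrieves a signed, time-limited download URL for an uploaded Mistral file.
 *
 * @param {Object} params
 * @param {string} params.apiKey - Mistral API key.
 * @param {string} params.fileId - The ID of the uploaded file.
 * @param {number} [params.expiry=24] - URL expiry in hours.
 * @param {string} [params.baseURL='https://api.mistral.ai/v1'] - Mistral API base URL.
 * @returns {Promise<Object>} The response from the Mistral API containing the signed URL.
 */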
async function getSignedUrl({
apiKey,
fileId,
expiry = 24,
baseURL = 'https://api.mistral.ai/v1',
}) {
return axios
.get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, {
headers: {
Authorization: `Bearer ${apiKey}`,
},
})
.then((res) => res.data)
.catch((error) => {
logger.error('Error fetching signed URL:', error.message);
throw error;
});
}
/**
* @param {Object} params
* @param {string} params.apiKey
* @param {string} params.url - The document or image URL
* @param {string} [params.documentType='document_url'] - 'document_url' or 'image_url'
* @param {string} [params.model]
* @param {string} [params.baseURL]
* @returns {Promise<OCRResult>}
*/
async function performOCR({
apiKey,
url,
documentType = 'document_url',
model = 'mistral-ocr-latest',
baseURL = 'https://api.mistral.ai/v1',
}) {
const documentKey = documentType === 'image_url' ? 'image_url' : 'document_url';
return axios
.post(
`${baseURL}/ocr`,
{
model,
include_image_base64: false,
document: {
type: documentType,
[documentKey]: url,
},
},
{
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
},
},
)
.then((res) => res.data)
.catch((error) => {
logger.error('Error performing OCR:', error.message);
throw error;
});
}
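/**
 * Extracts the variable name from an environment variable reference (e.g. '${OCR_API_KEY}').
 *
 * @param {string} str - The string to match against `envVarRegex`.
 * @returns {string|null} The captured variable name, or null if there is no match.
 */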
function extractVariableName(str) {
const match = str.match(envVarRegex);
return match ? match[1] : null;
}
/**
* Uploads a file to the Mistral OCR API and processes the OCR result.
*
* @param {Object} params - The params object.
* @param {ServerRequest} params.req - The request object from Express. It should have a `user` property with an `id`
* representing the user
* @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should
* have a `mimetype` property that tells us the file type
* @param {string} params.file_id - The file ID.
* @param {string} [params.entity_id] - The entity ID, not used here but passed for consistency.
 * @returns {Promise<{ filename: string, bytes: number, filepath: string, text: string, images: string[] }>} - The result object containing the aggregated OCR `text` and any extracted base64 `images`,
 * along with the `filename` and `bytes` properties.
*/
const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => {
try {
/** @type {TCustomConfig['ocr']} */
const ocrConfig = req.app.locals?.ocr ?? {};
const apiKeyConfig = ocrConfig.apiKey || '';
const baseURLConfig = ocrConfig.baseURL || '';
const isApiKeyEnvVar = envVarRegex.test(apiKeyConfig);
const isBaseURLEnvVar = envVarRegex.test(baseURLConfig);
const isApiKeyEmpty = !apiKeyConfig.trim();
const isBaseURLEmpty = !baseURLConfig.trim();
let apiKey, baseURL;
if (isApiKeyEnvVar || isBaseURLEnvVar || isApiKeyEmpty || isBaseURLEmpty) {
const apiKeyVarName = isApiKeyEnvVar ? extractVariableName(apiKeyConfig) : 'OCR_API_KEY';
const baseURLVarName = isBaseURLEnvVar ? extractVariableName(baseURLConfig) : 'OCR_BASEURL';
const authValues = await loadAuthValues({
userId: req.user.id,
authFields: [baseURLVarName, apiKeyVarName],
optional: new Set([baseURLVarName]),
});
apiKey = authValues[apiKeyVarName];
baseURL = authValues[baseURLVarName];
} else {
apiKey = apiKeyConfig;
baseURL = baseURLConfig;
}
const mistralFile = await uploadDocumentToMistral({
filePath: file.path,
fileName: file.originalname,
apiKey,
baseURL,
});
const modelConfig = ocrConfig.mistralModel || '';
const model = envVarRegex.test(modelConfig)
? extractEnvVariable(modelConfig)
: modelConfig.trim() || 'mistral-ocr-latest';
const signedUrlResponse = await getSignedUrl({
apiKey,
baseURL,
fileId: mistralFile.id,
});
const mimetype = (file.mimetype || '').toLowerCase();
const originalname = file.originalname || '';
const isImage =
mimetype.startsWith('image') || /\.(png|jpe?g|gif|bmp|webp|tiff?)$/i.test(originalname);
const documentType = isImage ? 'image_url' : 'document_url';
const ocrResult = await performOCR({
apiKey,
baseURL,
model,
url: signedUrlResponse.url,
documentType,
});
let aggregatedText = '';
const images = [];
ocrResult.pages.forEach((page, index) => {
if (ocrResult.pages.length > 1) {
aggregatedText += `# PAGE ${index + 1}\n`;
}
aggregatedText += page.markdown + '\n\n';
if (page.images && page.images.length > 0) {
page.images.forEach((image) => {
if (image.image_base64) {
images.push(image.image_base64);
}
});
}
});
return {
filename: file.originalname,
bytes: aggregatedText.length * 4,
filepath: FileSources.mistral_ocr,
text: aggregatedText,
images,
};
} catch (error) {
const message = 'Error uploading document to Mistral OCR API';
throw new Error(logAxiosError({ error, message }));
}
};
module.exports = {
uploadDocumentToMistral,
uploadMistralOCR,
getSignedUrl,
performOCR,
};
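/*
 * Flow sketch (illustrative): `uploadMistralOCR` above chains the three exported
 * helpers as upload -> signed URL -> OCR:
 *
 *   const uploaded = await uploadDocumentToMistral({ filePath, apiKey });
 *   const { url } = await getSignedUrl({ apiKey, fileId: uploaded.id });
 *   const ocr = await performOCR({ apiKey, url, documentType: 'document_url' });
 *   // ocr.pages[n].markdown holds the extracted text per page
 */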

View file

@ -0,0 +1,852 @@
const fs = require('fs');
const mockAxios = {
interceptors: {
request: { use: jest.fn(), eject: jest.fn() },
response: { use: jest.fn(), eject: jest.fn() },
},
create: jest.fn().mockReturnValue({
defaults: {
proxy: null,
},
get: jest.fn().mockResolvedValue({ data: {} }),
post: jest.fn().mockResolvedValue({ data: {} }),
put: jest.fn().mockResolvedValue({ data: {} }),
delete: jest.fn().mockResolvedValue({ data: {} }),
}),
get: jest.fn().mockResolvedValue({ data: {} }),
post: jest.fn().mockResolvedValue({ data: {} }),
put: jest.fn().mockResolvedValue({ data: {} }),
delete: jest.fn().mockResolvedValue({ data: {} }),
reset: jest.fn().mockImplementation(function () {
this.get.mockClear();
this.post.mockClear();
this.put.mockClear();
this.delete.mockClear();
this.create.mockClear();
}),
};
jest.mock('axios', () => mockAxios);
jest.mock('fs');
jest.mock('~/config', () => ({
logger: {
error: jest.fn(),
},
createAxiosInstance: () => mockAxios,
}));
jest.mock('~/server/services/Tools/credentials', () => ({
loadAuthValues: jest.fn(),
}));
const { uploadDocumentToMistral, uploadMistralOCR, getSignedUrl, performOCR } = require('./crud');
describe('MistralOCR Service', () => {
afterEach(() => {
mockAxios.reset();
jest.clearAllMocks();
});
describe('uploadDocumentToMistral', () => {
beforeEach(() => {
// Create a more complete mock for file streams that FormData can work with
const mockReadStream = {
on: jest.fn().mockImplementation(function (event, handler) {
// Simulate immediate 'end' event to make FormData complete processing
if (event === 'end') {
handler();
}
return this;
}),
pipe: jest.fn().mockImplementation(function () {
return this;
}),
pause: jest.fn(),
resume: jest.fn(),
emit: jest.fn(),
once: jest.fn(),
destroy: jest.fn(),
};
fs.createReadStream = jest.fn().mockReturnValue(mockReadStream);
// Mock FormData's append to avoid actual stream processing
jest.mock('form-data', () => {
const mockFormData = function () {
return {
append: jest.fn(),
getHeaders: jest
.fn()
.mockReturnValue({ 'content-type': 'multipart/form-data; boundary=---boundary' }),
getBuffer: jest.fn().mockReturnValue(Buffer.from('mock-form-data')),
getLength: jest.fn().mockReturnValue(100),
};
};
return mockFormData;
});
});
it('should upload a document to Mistral API using file streaming', async () => {
const mockResponse = { data: { id: 'file-123', purpose: 'ocr' } };
mockAxios.post.mockResolvedValueOnce(mockResponse);
const result = await uploadDocumentToMistral({
filePath: '/path/to/test.pdf',
fileName: 'test.pdf',
apiKey: 'test-api-key',
});
// Check that createReadStream was called with the correct file path
expect(fs.createReadStream).toHaveBeenCalledWith('/path/to/test.pdf');
// Since we're mocking FormData, we'll just check that axios was called correctly
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/files',
expect.anything(),
expect.objectContaining({
headers: expect.objectContaining({
Authorization: 'Bearer test-api-key',
}),
maxBodyLength: Infinity,
maxContentLength: Infinity,
}),
);
expect(result).toEqual(mockResponse.data);
});
it('should handle errors during document upload', async () => {
const errorMessage = 'API error';
mockAxios.post.mockRejectedValueOnce(new Error(errorMessage));
await expect(
uploadDocumentToMistral({
filePath: '/path/to/test.pdf',
fileName: 'test.pdf',
apiKey: 'test-api-key',
}),
).rejects.toThrow();
const { logger } = require('~/config');
expect(logger.error).toHaveBeenCalledWith(
expect.stringContaining('Error uploading document to Mistral:'),
expect.any(String),
);
});
});
describe('getSignedUrl', () => {
it('should fetch signed URL from Mistral API', async () => {
const mockResponse = { data: { url: 'https://document-url.com' } };
mockAxios.get.mockResolvedValueOnce(mockResponse);
const result = await getSignedUrl({
fileId: 'file-123',
apiKey: 'test-api-key',
});
expect(mockAxios.get).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/files/file-123/url?expiry=24',
{
headers: {
Authorization: 'Bearer test-api-key',
},
},
);
expect(result).toEqual(mockResponse.data);
});
it('should handle errors when fetching signed URL', async () => {
const errorMessage = 'API error';
mockAxios.get.mockRejectedValueOnce(new Error(errorMessage));
await expect(
getSignedUrl({
fileId: 'file-123',
apiKey: 'test-api-key',
}),
).rejects.toThrow();
const { logger } = require('~/config');
expect(logger.error).toHaveBeenCalledWith('Error fetching signed URL:', errorMessage);
});
});
describe('performOCR', () => {
it('should perform OCR using Mistral API (document_url)', async () => {
const mockResponse = {
data: {
pages: [{ markdown: 'Page 1 content' }, { markdown: 'Page 2 content' }],
},
};
mockAxios.post.mockResolvedValueOnce(mockResponse);
const result = await performOCR({
apiKey: 'test-api-key',
url: 'https://document-url.com',
model: 'mistral-ocr-latest',
documentType: 'document_url',
});
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/ocr',
{
model: 'mistral-ocr-latest',
include_image_base64: false,
document: {
type: 'document_url',
document_url: 'https://document-url.com',
},
},
{
headers: {
'Content-Type': 'application/json',
Authorization: 'Bearer test-api-key',
},
},
);
expect(result).toEqual(mockResponse.data);
});
it('should perform OCR using Mistral API (image_url)', async () => {
const mockResponse = {
data: {
pages: [{ markdown: 'Image OCR content' }],
},
};
mockAxios.post.mockResolvedValueOnce(mockResponse);
const result = await performOCR({
apiKey: 'test-api-key',
url: 'https://image-url.com/image.png',
model: 'mistral-ocr-latest',
documentType: 'image_url',
});
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/ocr',
{
model: 'mistral-ocr-latest',
include_image_base64: false,
document: {
type: 'image_url',
image_url: 'https://image-url.com/image.png',
},
},
{
headers: {
'Content-Type': 'application/json',
Authorization: 'Bearer test-api-key',
},
},
);
expect(result).toEqual(mockResponse.data);
});
it('should handle errors during OCR processing', async () => {
const errorMessage = 'OCR processing error';
mockAxios.post.mockRejectedValueOnce(new Error(errorMessage));
await expect(
performOCR({
apiKey: 'test-api-key',
url: 'https://document-url.com',
}),
).rejects.toThrow();
const { logger } = require('~/config');
expect(logger.error).toHaveBeenCalledWith('Error performing OCR:', errorMessage);
});
});
describe('uploadMistralOCR', () => {
beforeEach(() => {
const mockReadStream = {
on: jest.fn().mockImplementation(function (event, handler) {
if (event === 'end') {
handler();
}
return this;
}),
pipe: jest.fn().mockImplementation(function () {
return this;
}),
pause: jest.fn(),
resume: jest.fn(),
emit: jest.fn(),
once: jest.fn(),
destroy: jest.fn(),
};
fs.createReadStream = jest.fn().mockReturnValue(mockReadStream);
});
it('should process OCR for a file with standard configuration', async () => {
// Setup mocks
const { loadAuthValues } = require('~/server/services/Tools/credentials');
loadAuthValues.mockResolvedValue({
OCR_API_KEY: 'test-api-key',
OCR_BASEURL: 'https://api.mistral.ai/v1',
});
// Mock file upload response
mockAxios.post.mockResolvedValueOnce({
data: { id: 'file-123', purpose: 'ocr' },
});
// Mock signed URL response
mockAxios.get.mockResolvedValueOnce({
data: { url: 'https://signed-url.com' },
});
// Mock OCR response with text and images
mockAxios.post.mockResolvedValueOnce({
data: {
pages: [
{
markdown: 'Page 1 content',
images: [{ image_base64: 'base64image1' }],
},
{
markdown: 'Page 2 content',
images: [{ image_base64: 'base64image2' }],
},
],
},
});
const req = {
user: { id: 'user123' },
app: {
locals: {
ocr: {
// Use environment variable syntax to ensure loadAuthValues is called
apiKey: '${OCR_API_KEY}',
baseURL: '${OCR_BASEURL}',
mistralModel: 'mistral-medium',
},
},
},
};
const file = {
path: '/tmp/upload/file.pdf',
originalname: 'document.pdf',
mimetype: 'application/pdf',
};
const result = await uploadMistralOCR({
req,
file,
file_id: 'file123',
entity_id: 'entity123',
});
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
expect(loadAuthValues).toHaveBeenCalledWith({
userId: 'user123',
authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
optional: expect.any(Set),
});
// Verify OCR result
expect(result).toEqual({
filename: 'document.pdf',
bytes: expect.any(Number),
filepath: 'mistral_ocr',
text: expect.stringContaining('# PAGE 1'),
images: ['base64image1', 'base64image2'],
});
});
it('should process OCR for an image file and use image_url type', async () => {
const { loadAuthValues } = require('~/server/services/Tools/credentials');
loadAuthValues.mockResolvedValue({
OCR_API_KEY: 'test-api-key',
OCR_BASEURL: 'https://api.mistral.ai/v1',
});
// Mock file upload response
mockAxios.post.mockResolvedValueOnce({
data: { id: 'file-456', purpose: 'ocr' },
});
// Mock signed URL response
mockAxios.get.mockResolvedValueOnce({
data: { url: 'https://signed-url.com/image.png' },
});
// Mock OCR response for image
mockAxios.post.mockResolvedValueOnce({
data: {
pages: [
{
markdown: 'Image OCR result',
images: [{ image_base64: 'imgbase64' }],
},
],
},
});
const req = {
user: { id: 'user456' },
app: {
locals: {
ocr: {
apiKey: '${OCR_API_KEY}',
baseURL: '${OCR_BASEURL}',
mistralModel: 'mistral-medium',
},
},
},
};
const file = {
path: '/tmp/upload/image.png',
originalname: 'image.png',
mimetype: 'image/png',
};
const result = await uploadMistralOCR({
req,
file,
file_id: 'file456',
entity_id: 'entity456',
});
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/image.png');
expect(loadAuthValues).toHaveBeenCalledWith({
userId: 'user456',
authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
optional: expect.any(Set),
});
// Check that the OCR API was called with image_url type
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/ocr',
expect.objectContaining({
document: expect.objectContaining({
type: 'image_url',
image_url: 'https://signed-url.com/image.png',
}),
}),
expect.any(Object),
);
expect(result).toEqual({
filename: 'image.png',
bytes: expect.any(Number),
filepath: 'mistral_ocr',
text: expect.stringContaining('Image OCR result'),
images: ['imgbase64'],
});
});
it('should process variable references in configuration', async () => {
// Setup mocks with environment variables
const { loadAuthValues } = require('~/server/services/Tools/credentials');
loadAuthValues.mockResolvedValue({
CUSTOM_API_KEY: 'custom-api-key',
CUSTOM_BASEURL: 'https://custom-api.mistral.ai/v1',
});
// Mock API responses
mockAxios.post.mockResolvedValueOnce({
data: { id: 'file-123', purpose: 'ocr' },
});
mockAxios.get.mockResolvedValueOnce({
data: { url: 'https://signed-url.com' },
});
mockAxios.post.mockResolvedValueOnce({
data: {
pages: [{ markdown: 'Content from custom API' }],
},
});
const req = {
user: { id: 'user123' },
app: {
locals: {
ocr: {
apiKey: '${CUSTOM_API_KEY}',
baseURL: '${CUSTOM_BASEURL}',
mistralModel: '${CUSTOM_MODEL}',
},
},
},
};
// Set environment variable for model
process.env.CUSTOM_MODEL = 'mistral-large';
const file = {
path: '/tmp/upload/file.pdf',
originalname: 'document.pdf',
};
const result = await uploadMistralOCR({
req,
file,
file_id: 'file123',
entity_id: 'entity123',
});
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
// Verify that custom environment variables were extracted and used
expect(loadAuthValues).toHaveBeenCalledWith({
userId: 'user123',
authFields: ['CUSTOM_BASEURL', 'CUSTOM_API_KEY'],
optional: expect.any(Set),
});
// Check that mistral-large was used in the OCR API call
expect(mockAxios.post).toHaveBeenCalledWith(
expect.anything(),
expect.objectContaining({
model: 'mistral-large',
}),
expect.anything(),
);
expect(result.text).toEqual('Content from custom API\n\n');
});
it('should fall back to default values when variables are not properly formatted', async () => {
const { loadAuthValues } = require('~/server/services/Tools/credentials');
loadAuthValues.mockResolvedValue({
OCR_API_KEY: 'default-api-key',
OCR_BASEURL: undefined, // Testing optional parameter
});
mockAxios.post.mockResolvedValueOnce({
data: { id: 'file-123', purpose: 'ocr' },
});
mockAxios.get.mockResolvedValueOnce({
data: { url: 'https://signed-url.com' },
});
mockAxios.post.mockResolvedValueOnce({
data: {
pages: [{ markdown: 'Default API result' }],
},
});
const req = {
user: { id: 'user123' },
app: {
locals: {
ocr: {
// Use environment variable syntax to ensure loadAuthValues is called
apiKey: '${INVALID_FORMAT}', // Using valid env var format but with an invalid name
baseURL: '${OCR_BASEURL}', // Using valid env var format
mistralModel: 'mistral-ocr-latest', // Plain string value
},
},
},
};
const file = {
path: '/tmp/upload/file.pdf',
originalname: 'document.pdf',
};
await uploadMistralOCR({
req,
file,
file_id: 'file123',
entity_id: 'entity123',
});
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
// Should use the default values
expect(loadAuthValues).toHaveBeenCalledWith({
userId: 'user123',
authFields: ['OCR_BASEURL', 'INVALID_FORMAT'],
optional: expect.any(Set),
});
// Should use the default model when not using environment variable format
expect(mockAxios.post).toHaveBeenCalledWith(
expect.anything(),
expect.objectContaining({
model: 'mistral-ocr-latest',
}),
expect.anything(),
);
});
it('should handle API errors during OCR process', async () => {
const { loadAuthValues } = require('~/server/services/Tools/credentials');
loadAuthValues.mockResolvedValue({
OCR_API_KEY: 'test-api-key',
});
// Mock file upload to fail
mockAxios.post.mockRejectedValueOnce(new Error('Upload failed'));
const req = {
user: { id: 'user123' },
app: {
locals: {
ocr: {
apiKey: 'OCR_API_KEY',
baseURL: 'OCR_BASEURL',
},
},
},
};
const file = {
path: '/tmp/upload/file.pdf',
originalname: 'document.pdf',
};
await expect(
uploadMistralOCR({
req,
file,
file_id: 'file123',
entity_id: 'entity123',
}),
).rejects.toThrow('Error uploading document to Mistral OCR API');
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
});
it('should handle single page documents without page numbering', async () => {
const { loadAuthValues } = require('~/server/services/Tools/credentials');
loadAuthValues.mockResolvedValue({
OCR_API_KEY: 'test-api-key',
OCR_BASEURL: 'https://api.mistral.ai/v1', // Make sure this is included
});
// Clear all previous mocks
mockAxios.post.mockClear();
mockAxios.get.mockClear();
// 1. First mock: File upload response
mockAxios.post.mockImplementationOnce(() =>
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
);
// 2. Second mock: Signed URL response
mockAxios.get.mockImplementationOnce(() =>
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
);
// 3. Third mock: OCR response
mockAxios.post.mockImplementationOnce(() =>
Promise.resolve({
data: {
pages: [{ markdown: 'Single page content' }],
},
}),
);
const req = {
user: { id: 'user123' },
app: {
locals: {
ocr: {
apiKey: 'OCR_API_KEY',
baseURL: 'OCR_BASEURL',
mistralModel: 'mistral-ocr-latest',
},
},
},
};
const file = {
path: '/tmp/upload/file.pdf',
originalname: 'single-page.pdf',
};
const result = await uploadMistralOCR({
req,
file,
file_id: 'file123',
entity_id: 'entity123',
});
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
// Verify that single page documents don't include page numbering
expect(result.text).not.toContain('# PAGE');
expect(result.text).toEqual('Single page content\n\n');
});
it('should use literal values in configuration when provided directly', async () => {
const { loadAuthValues } = require('~/server/services/Tools/credentials');
// We'll still mock this but it should not be used for literal values
loadAuthValues.mockResolvedValue({});
// Clear all previous mocks
mockAxios.post.mockClear();
mockAxios.get.mockClear();
// 1. First mock: File upload response
mockAxios.post.mockImplementationOnce(() =>
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
);
// 2. Second mock: Signed URL response
mockAxios.get.mockImplementationOnce(() =>
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
);
// 3. Third mock: OCR response
mockAxios.post.mockImplementationOnce(() =>
Promise.resolve({
data: {
pages: [{ markdown: 'Processed with literal config values' }],
},
}),
);
const req = {
user: { id: 'user123' },
app: {
locals: {
ocr: {
// Direct values that should be used as-is, without variable substitution
apiKey: 'actual-api-key-value',
baseURL: 'https://direct-api-url.mistral.ai/v1',
mistralModel: 'mistral-direct-model',
},
},
},
};
const file = {
path: '/tmp/upload/file.pdf',
originalname: 'direct-values.pdf',
};
const result = await uploadMistralOCR({
req,
file,
file_id: 'file123',
entity_id: 'entity123',
});
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
// Verify the correct URL was used with the direct baseURL value
expect(mockAxios.post).toHaveBeenCalledWith(
'https://direct-api-url.mistral.ai/v1/files',
expect.any(Object),
expect.objectContaining({
headers: expect.objectContaining({
Authorization: 'Bearer actual-api-key-value',
}),
}),
);
// Check the OCR call was made with the direct model value
expect(mockAxios.post).toHaveBeenCalledWith(
'https://direct-api-url.mistral.ai/v1/ocr',
expect.objectContaining({
model: 'mistral-direct-model',
}),
expect.any(Object),
);
// Verify the result
expect(result.text).toEqual('Processed with literal config values\n\n');
// Verify loadAuthValues was never called since we used direct values
expect(loadAuthValues).not.toHaveBeenCalled();
});
it('should handle empty configuration values and use defaults', async () => {
const { loadAuthValues } = require('~/server/services/Tools/credentials');
// Set up the mock values to be returned by loadAuthValues
loadAuthValues.mockResolvedValue({
OCR_API_KEY: 'default-from-env-key',
OCR_BASEURL: 'https://default-from-env.mistral.ai/v1',
});
// Clear all previous mocks
mockAxios.post.mockClear();
mockAxios.get.mockClear();
// 1. First mock: File upload response
mockAxios.post.mockImplementationOnce(() =>
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
);
// 2. Second mock: Signed URL response
mockAxios.get.mockImplementationOnce(() =>
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
);
// 3. Third mock: OCR response
mockAxios.post.mockImplementationOnce(() =>
Promise.resolve({
data: {
pages: [{ markdown: 'Content from default configuration' }],
},
}),
);
const req = {
user: { id: 'user123' },
app: {
locals: {
ocr: {
// Empty string values - should fall back to defaults
apiKey: '',
baseURL: '',
mistralModel: '',
},
},
},
};
const file = {
path: '/tmp/upload/file.pdf',
originalname: 'empty-config.pdf',
};
const result = await uploadMistralOCR({
req,
file,
file_id: 'file123',
entity_id: 'entity123',
});
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
// Verify loadAuthValues was called with the default variable names
expect(loadAuthValues).toHaveBeenCalledWith({
userId: 'user123',
authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
optional: expect.any(Set),
});
// Verify the API calls used the default values from loadAuthValues
expect(mockAxios.post).toHaveBeenCalledWith(
'https://default-from-env.mistral.ai/v1/files',
expect.any(Object),
expect.objectContaining({
headers: expect.objectContaining({
Authorization: 'Bearer default-from-env-key',
}),
}),
);
// Verify the OCR model defaulted to mistral-ocr-latest
expect(mockAxios.post).toHaveBeenCalledWith(
'https://default-from-env.mistral.ai/v1/ocr',
expect.objectContaining({
model: 'mistral-ocr-latest',
}),
expect.any(Object),
);
// Check result
expect(result.text).toEqual('Content from default configuration\n\n');
});
});
});

View file

@ -0,0 +1,5 @@
const crud = require('./crud');
module.exports = {
...crud,
};

View file

@ -0,0 +1,467 @@
const fs = require('fs');
const path = require('path');
const fetch = require('node-fetch');
const { FileSources } = require('librechat-data-provider');
const {
PutObjectCommand,
GetObjectCommand,
HeadObjectCommand,
DeleteObjectCommand,
} = require('@aws-sdk/client-s3');
const { getSignedUrl } = require('@aws-sdk/s3-request-presigner');
const { initializeS3 } = require('./initialize');
const { logger } = require('~/config');
const bucketName = process.env.AWS_BUCKET_NAME;
const defaultBasePath = 'images';
let s3UrlExpirySeconds = 7 * 24 * 60 * 60;
let s3RefreshExpiryMs = null;
if (process.env.S3_URL_EXPIRY_SECONDS !== undefined) {
const parsed = parseInt(process.env.S3_URL_EXPIRY_SECONDS, 10);
if (!isNaN(parsed) && parsed > 0) {
s3UrlExpirySeconds = Math.min(parsed, 7 * 24 * 60 * 60);
} else {
logger.warn(
`[S3] Invalid S3_URL_EXPIRY_SECONDS value: "${process.env.S3_URL_EXPIRY_SECONDS}". Using 7-day expiry.`,
);
}
}
// Environment variables are strings or undefined, never null, so a truthiness check suffices
if (process.env.S3_REFRESH_EXPIRY_MS) {
const parsed = parseInt(process.env.S3_REFRESH_EXPIRY_MS, 10);
if (!isNaN(parsed) && parsed > 0) {
s3RefreshExpiryMs = parsed;
logger.info(`[S3] Using custom refresh expiry time: ${s3RefreshExpiryMs}ms`);
} else {
logger.warn(
`[S3] Invalid S3_REFRESH_EXPIRY_MS value: "${process.env.S3_REFRESH_EXPIRY_MS}". Using default refresh logic.`,
);
}
}
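// Example configuration (illustrative): cap signed URLs at 1 hour and force a
// refresh once a URL is 30 minutes old:
//   S3_URL_EXPIRY_SECONDS=3600
//   S3_REFRESH_EXPIRY_MS=1800000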
/**
* Constructs the S3 key based on the base path, user ID, and file name.
*/
const getS3Key = (basePath, userId, fileName) => `${basePath}/${userId}/${fileName}`;
/**
* Uploads a buffer to S3 and returns a signed URL.
*
* @param {Object} params
* @param {string} params.userId - The user's unique identifier.
* @param {Buffer} params.buffer - The buffer containing file data.
* @param {string} params.fileName - The file name to use in S3.
* @param {string} [params.basePath='images'] - The base path in the bucket.
* @returns {Promise<string>} Signed URL of the uploaded file.
*/
async function saveBufferToS3({ userId, buffer, fileName, basePath = defaultBasePath }) {
const key = getS3Key(basePath, userId, fileName);
const params = { Bucket: bucketName, Key: key, Body: buffer };
try {
const s3 = initializeS3();
await s3.send(new PutObjectCommand(params));
return await getS3URL({ userId, fileName, basePath });
} catch (error) {
logger.error('[saveBufferToS3] Error uploading buffer to S3:', error.message);
throw error;
}
}
/**
* Retrieves a URL for a file stored in S3.
 * Returns a signed URL with an expiration time, or a proxy URL, depending on configuration.
*
* @param {Object} params
* @param {string} params.userId - The user's unique identifier.
* @param {string} params.fileName - The file name in S3.
* @param {string} [params.basePath='images'] - The base path in the bucket.
* @returns {Promise<string>} A URL to access the S3 object
*/
async function getS3URL({ userId, fileName, basePath = defaultBasePath }) {
const key = getS3Key(basePath, userId, fileName);
const params = { Bucket: bucketName, Key: key };
try {
const s3 = initializeS3();
return await getSignedUrl(s3, new GetObjectCommand(params), { expiresIn: s3UrlExpirySeconds });
} catch (error) {
logger.error('[getS3URL] Error getting signed URL from S3:', error.message);
throw error;
}
}
/**
* Saves a file from a given URL to S3.
*
* @param {Object} params
* @param {string} params.userId - The user's unique identifier.
* @param {string} params.URL - The source URL of the file.
* @param {string} params.fileName - The file name to use in S3.
* @param {string} [params.basePath='images'] - The base path in the bucket.
* @returns {Promise<string>} Signed URL of the uploaded file.
*/
async function saveURLToS3({ userId, URL, fileName, basePath = defaultBasePath }) {
try {
const response = await fetch(URL);
const buffer = await response.buffer();
    // Optionally, call getBufferMetadata(buffer) here if buffer metadata is needed.
return await saveBufferToS3({ userId, buffer, fileName, basePath });
} catch (error) {
logger.error('[saveURLToS3] Error uploading file from URL to S3:', error.message);
throw error;
}
}
/**
* Deletes a file from S3.
*
* @param {Object} params
* @param {ServerRequest} params.req
* @param {MongoFile} params.file - The file object to delete.
* @returns {Promise<void>}
*/
async function deleteFileFromS3(req, file) {
const key = extractKeyFromS3Url(file.filepath);
const params = { Bucket: bucketName, Key: key };
if (!key.includes(req.user.id)) {
const message = `[deleteFileFromS3] User ID mismatch: ${req.user.id} vs ${key}`;
logger.error(message);
throw new Error(message);
}
try {
const s3 = initializeS3();
try {
const headCommand = new HeadObjectCommand(params);
await s3.send(headCommand);
logger.debug('[deleteFileFromS3] File exists, proceeding with deletion');
} catch (headErr) {
if (headErr.name === 'NotFound') {
logger.warn(`[deleteFileFromS3] File does not exist: ${key}`);
return;
}
}
const deleteResult = await s3.send(new DeleteObjectCommand(params));
logger.debug('[deleteFileFromS3] Delete command response:', JSON.stringify(deleteResult));
try {
await s3.send(new HeadObjectCommand(params));
logger.error('[deleteFileFromS3] File still exists after deletion!');
} catch (verifyErr) {
if (verifyErr.name === 'NotFound') {
logger.debug(`[deleteFileFromS3] Verified file is deleted: ${key}`);
} else {
logger.error('[deleteFileFromS3] Error verifying deletion:', verifyErr);
}
}
logger.debug('[deleteFileFromS3] S3 File deletion completed');
} catch (error) {
logger.error(`[deleteFileFromS3] Error deleting file from S3: ${error.message}`);
logger.error(error.stack);
// If the file is not found, we can safely return.
if (error.code === 'NoSuchKey') {
return;
}
throw error;
}
}
/**
* Uploads a local file to S3 by streaming it directly without loading into memory.
*
* @param {Object} params
* @param {import('express').Request} params.req - The Express request (must include user).
* @param {Express.Multer.File} params.file - The file object from Multer.
* @param {string} params.file_id - Unique file identifier.
* @param {string} [params.basePath='images'] - The base path in the bucket.
* @returns {Promise<{ filepath: string, bytes: number }>}
*/
async function uploadFileToS3({ req, file, file_id, basePath = defaultBasePath }) {
try {
const inputFilePath = file.path;
const userId = req.user.id;
const fileName = `${file_id}__${path.basename(inputFilePath)}`;
const key = getS3Key(basePath, userId, fileName);
const stats = await fs.promises.stat(inputFilePath);
const bytes = stats.size;
const fileStream = fs.createReadStream(inputFilePath);
const s3 = initializeS3();
const uploadParams = {
Bucket: bucketName,
Key: key,
Body: fileStream,
};
await s3.send(new PutObjectCommand(uploadParams));
const fileURL = await getS3URL({ userId, fileName, basePath });
return { filepath: fileURL, bytes };
} catch (error) {
logger.error('[uploadFileToS3] Error streaming file to S3:', error);
try {
if (file && file.path) {
await fs.promises.unlink(file.path);
}
} catch (unlinkError) {
logger.error(
'[uploadFileToS3] Error deleting temporary file, likely already deleted:',
unlinkError.message,
);
}
throw error;
}
}
/**
* Extracts the S3 key from a URL or returns the key if already properly formatted
*
* @param {string} fileUrlOrKey - The file URL or key
* @returns {string} The S3 key
*/
function extractKeyFromS3Url(fileUrlOrKey) {
if (!fileUrlOrKey) {
throw new Error('Invalid input: URL or key is empty');
}
try {
const url = new URL(fileUrlOrKey);
return url.pathname.substring(1);
} catch (error) {
const parts = fileUrlOrKey.split('/');
if (parts.length >= 3 && !fileUrlOrKey.startsWith('http') && !fileUrlOrKey.startsWith('/')) {
return fileUrlOrKey;
}
return fileUrlOrKey.startsWith('/') ? fileUrlOrKey.substring(1) : fileUrlOrKey;
}
}
/**
* Retrieves a readable stream for a file stored in S3.
*
 * @param {ServerRequest} _req - Server request object (unused).
* @param {string} filePath - The S3 key of the file.
* @returns {Promise<NodeJS.ReadableStream>}
*/
async function getS3FileStream(_req, filePath) {
try {
const Key = extractKeyFromS3Url(filePath);
const params = { Bucket: bucketName, Key };
const s3 = initializeS3();
const data = await s3.send(new GetObjectCommand(params));
return data.Body; // Returns a Node.js ReadableStream.
} catch (error) {
logger.error('[getS3FileStream] Error retrieving S3 file stream:', error);
throw error;
}
}
/**
* Determines if a signed S3 URL is close to expiration
*
* @param {string} signedUrl - The signed S3 URL
* @param {number} bufferSeconds - Buffer time in seconds
* @returns {boolean} True if the URL needs refreshing
*/
function needsRefresh(signedUrl, bufferSeconds) {
try {
// Parse the URL
const url = new URL(signedUrl);
// Check if it has the signature parameters that indicate it's a signed URL
// X-Amz-Signature is the most reliable indicator for AWS signed URLs
if (!url.searchParams.has('X-Amz-Signature')) {
// Not a signed URL, so no expiration to check (or it's already a proxy URL)
return false;
}
// Extract the expiration time from the URL
const expiresParam = url.searchParams.get('X-Amz-Expires');
const dateParam = url.searchParams.get('X-Amz-Date');
if (!expiresParam || !dateParam) {
// Missing expiration information, assume it needs refresh to be safe
return true;
}
// Parse the AWS date format (YYYYMMDDTHHMMSSZ)
const year = dateParam.substring(0, 4);
const month = dateParam.substring(4, 6);
const day = dateParam.substring(6, 8);
const hour = dateParam.substring(9, 11);
const minute = dateParam.substring(11, 13);
const second = dateParam.substring(13, 15);
const dateObj = new Date(`${year}-${month}-${day}T${hour}:${minute}:${second}Z`);
const expiresAtDate = new Date(dateObj.getTime() + parseInt(expiresParam, 10) * 1000);
// Check if it's close to expiration
const now = new Date();
// If S3_REFRESH_EXPIRY_MS is set, use it to determine if URL is expired
if (s3RefreshExpiryMs !== null) {
const urlCreationTime = dateObj.getTime();
const urlAge = now.getTime() - urlCreationTime;
return urlAge >= s3RefreshExpiryMs;
}
// Otherwise use the default buffer-based logic
const bufferTime = new Date(now.getTime() + bufferSeconds * 1000);
return expiresAtDate <= bufferTime;
} catch (error) {
logger.error('Error checking URL expiration:', error);
// If we can't determine, assume it needs refresh to be safe
return true;
}
}
/**
* Generates a new URL for an expired S3 URL
* @param {string} currentURL - The current file URL
* @returns {Promise<string | undefined>}
*/
async function getNewS3URL(currentURL) {
try {
const s3Key = extractKeyFromS3Url(currentURL);
if (!s3Key) {
return;
}
const keyParts = s3Key.split('/');
if (keyParts.length < 3) {
return;
}
const basePath = keyParts[0];
const userId = keyParts[1];
const fileName = keyParts.slice(2).join('/');
return await getS3URL({
userId,
fileName,
basePath,
});
} catch (error) {
logger.error('Error getting new S3 URL:', error);
}
}
/**
* Refreshes S3 URLs for an array of files if they're expired or close to expiring
*
* @param {MongoFile[]} files - Array of file documents
* @param {(files: MongoFile[]) => Promise<void>} batchUpdateFiles - Function to update files in the database
* @param {number} [bufferSeconds=3600] - Buffer time in seconds to check for expiration
* @returns {Promise<MongoFile[]>} The files with refreshed URLs if needed
*/
async function refreshS3FileUrls(files, batchUpdateFiles, bufferSeconds = 3600) {
if (!files || !Array.isArray(files) || files.length === 0) {
return files;
}
const filesToUpdate = [];
for (let i = 0; i < files.length; i++) {
const file = files[i];
if (!file?.file_id) {
continue;
}
if (file.source !== FileSources.s3) {
continue;
}
if (!file.filepath) {
continue;
}
if (!needsRefresh(file.filepath, bufferSeconds)) {
continue;
}
try {
const newURL = await getNewS3URL(file.filepath);
if (!newURL) {
continue;
}
filesToUpdate.push({
file_id: file.file_id,
filepath: newURL,
});
files[i].filepath = newURL;
} catch (error) {
logger.error(`Error refreshing S3 URL for file ${file.file_id}:`, error);
}
}
if (filesToUpdate.length > 0) {
await batchUpdateFiles(filesToUpdate);
}
return files;
}
/**
* Refreshes a single S3 URL if it's expired or close to expiring
*
* @param {{ filepath: string, source: string }} fileObj - Simple file object containing filepath and source
* @param {number} [bufferSeconds=3600] - Buffer time in seconds to check for expiration
* @returns {Promise<string>} The refreshed URL or the original URL if no refresh needed
*/
async function refreshS3Url(fileObj, bufferSeconds = 3600) {
if (!fileObj || fileObj.source !== FileSources.s3 || !fileObj.filepath) {
return fileObj?.filepath || '';
}
if (!needsRefresh(fileObj.filepath, bufferSeconds)) {
return fileObj.filepath;
}
try {
const s3Key = extractKeyFromS3Url(fileObj.filepath);
if (!s3Key) {
logger.warn(`Unable to extract S3 key from URL: ${fileObj.filepath}`);
return fileObj.filepath;
}
const keyParts = s3Key.split('/');
if (keyParts.length < 3) {
logger.warn(`Invalid S3 key format: ${s3Key}`);
return fileObj.filepath;
}
const basePath = keyParts[0];
const userId = keyParts[1];
const fileName = keyParts.slice(2).join('/');
const newUrl = await getS3URL({
userId,
fileName,
basePath,
});
logger.debug(`Refreshed S3 URL for key: ${s3Key}`);
return newUrl;
} catch (error) {
logger.error(`Error refreshing S3 URL: ${error.message}`);
return fileObj.filepath;
}
}
module.exports = {
saveBufferToS3,
saveURLToS3,
getS3URL,
deleteFileFromS3,
uploadFileToS3,
getS3FileStream,
refreshS3FileUrls,
refreshS3Url,
needsRefresh,
getNewS3URL,
};
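/*
 * Usage sketch (illustrative): refreshing a possibly-expired signed URL before
 * handing it to a client, with the default 1-hour expiry buffer:
 *
 *   const { refreshS3Url } = require('./crud');
 *   const url = await refreshS3Url(
 *     { filepath: file.filepath, source: file.source },
 *     3600,
 *   );
 *   // Returns the original URL unchanged when it is not close to expiring.
 */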

View file

@ -0,0 +1,118 @@
const fs = require('fs');
const path = require('path');
const sharp = require('sharp');
const { resizeImageBuffer } = require('../images/resize');
const { updateUser } = require('~/models/userMethods');
const { saveBufferToS3 } = require('./crud');
const { updateFile } = require('~/models/File');
const { logger } = require('~/config');
const defaultBasePath = 'images';
/**
* Resizes, converts, and uploads an image file to S3.
*
* @param {Object} params
* @param {import('express').Request} params.req - Express request (expects user and app.locals.imageOutputType).
* @param {Express.Multer.File} params.file - File object from Multer.
* @param {string} params.file_id - Unique file identifier.
* @param {any} params.endpoint - Endpoint identifier used in image processing.
* @param {string} [params.resolution='high'] - Desired image resolution.
* @param {string} [params.basePath='images'] - Base path in the bucket.
* @returns {Promise<{ filepath: string, bytes: number, width: number, height: number }>}
*/
async function uploadImageToS3({
req,
file,
file_id,
endpoint,
resolution = 'high',
basePath = defaultBasePath,
}) {
try {
const inputFilePath = file.path;
const inputBuffer = await fs.promises.readFile(inputFilePath);
const {
buffer: resizedBuffer,
width,
height,
} = await resizeImageBuffer(inputBuffer, resolution, endpoint);
const extension = path.extname(inputFilePath);
const userId = req.user.id;
let processedBuffer;
let fileName = `${file_id}__${path.basename(inputFilePath)}`;
const targetExtension = `.${req.app.locals.imageOutputType}`;
if (extension.toLowerCase() === targetExtension) {
processedBuffer = resizedBuffer;
} else {
processedBuffer = await sharp(resizedBuffer)
.toFormat(req.app.locals.imageOutputType)
.toBuffer();
fileName = fileName.replace(new RegExp(path.extname(fileName) + '$'), targetExtension);
if (!path.extname(fileName)) {
fileName += targetExtension;
}
}
const downloadURL = await saveBufferToS3({
userId,
buffer: processedBuffer,
fileName,
basePath,
});
await fs.promises.unlink(inputFilePath);
const bytes = Buffer.byteLength(processedBuffer);
return { filepath: downloadURL, bytes, width, height };
} catch (error) {
logger.error('[uploadImageToS3] Error uploading image to S3:', error.message);
throw error;
}
}
/**
* Updates a file record and returns its signed URL.
*
* @param {import('express').Request} req - Express request.
* @param {Object} file - File metadata.
 * @returns {Promise<[any, string]>} The file update result and the file's URL.
*/
async function prepareImageURLS3(req, file) {
try {
const updatePromise = updateFile({ file_id: file.file_id });
return Promise.all([updatePromise, file.filepath]);
} catch (error) {
logger.error('[prepareImageURLS3] Error preparing image URL:', error.message);
throw error;
}
}
/**
* Processes a user's avatar image by uploading it to S3 and updating the user's avatar URL if required.
*
* @param {Object} params
* @param {Buffer} params.buffer - Avatar image buffer.
* @param {string} params.userId - User's unique identifier.
* @param {string} params.manual - 'true' or 'false' flag for manual update.
* @param {string} [params.basePath='images'] - Base path in the bucket.
* @returns {Promise<string>} Signed URL of the uploaded avatar.
*/
async function processS3Avatar({ buffer, userId, manual, basePath = defaultBasePath }) {
try {
const downloadURL = await saveBufferToS3({ userId, buffer, fileName: 'avatar.png', basePath });
if (manual === 'true') {
await updateUser(userId, { avatar: downloadURL });
}
return downloadURL;
} catch (error) {
logger.error('[processS3Avatar] Error processing S3 avatar:', error.message);
throw error;
}
}
module.exports = {
uploadImageToS3,
prepareImageURLS3,
processS3Avatar,
};
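/*
 * Usage sketch (illustrative): resizing and re-encoding an uploaded image. The
 * output format follows `req.app.locals.imageOutputType` (e.g. 'webp'):
 *
 *   const { uploadImageToS3 } = require('./images');
 *   const { filepath, bytes, width, height } = await uploadImageToS3({
 *     req,
 *     file,               // Multer file with a local `path`
 *     file_id: 'abc123',  // hypothetical value
 *     endpoint,           // endpoint identifier consumed by resizeImageBuffer
 *     resolution: 'high',
 *   });
 */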

View file

@ -0,0 +1,9 @@
const crud = require('./crud');
const images = require('./images');
const initialize = require('./initialize');
module.exports = {
...crud,
...images,
...initialize,
};

View file

@ -0,0 +1,53 @@
const { S3Client } = require('@aws-sdk/client-s3');
const { logger } = require('~/config');
let s3 = null;
/**
* Initializes and returns an instance of the AWS S3 client.
*
* If AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are provided, they will be used.
* Otherwise, the AWS SDK's default credentials chain (including IRSA) is used.
*
* If AWS_ENDPOINT_URL is provided, it will be used as the endpoint.
*
* @returns {S3Client|null} An instance of S3Client if the region is provided; otherwise, null.
*/
const initializeS3 = () => {
if (s3) {
return s3;
}
const region = process.env.AWS_REGION;
if (!region) {
logger.error('[initializeS3] AWS_REGION is not set. Cannot initialize S3.');
return null;
}
// Read the custom endpoint if provided.
const endpoint = process.env.AWS_ENDPOINT_URL;
const accessKeyId = process.env.AWS_ACCESS_KEY_ID;
const secretAccessKey = process.env.AWS_SECRET_ACCESS_KEY;
const config = {
region,
// Conditionally add the endpoint if it is provided
...(endpoint ? { endpoint } : {}),
};
if (accessKeyId && secretAccessKey) {
s3 = new S3Client({
...config,
credentials: { accessKeyId, secretAccessKey },
});
logger.info('[initializeS3] S3 initialized with provided credentials.');
} else {
// When using IRSA, credentials are automatically provided via the IAM Role attached to the ServiceAccount.
s3 = new S3Client(config);
logger.info('[initializeS3] S3 initialized using default credentials (IRSA).');
}
return s3;
};
module.exports = { initializeS3 };
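/*
 * Configuration sketch (illustrative): two supported credential setups.
 *
 *   # 1. Static credentials
 *   AWS_REGION=us-east-1
 *   AWS_ACCESS_KEY_ID=...
 *   AWS_SECRET_ACCESS_KEY=...
 *
 *   # 2. IRSA / default credential chain (region only)
 *   AWS_REGION=us-east-1
 *
 * Optional: AWS_ENDPOINT_URL for S3-compatible stores (e.g. MinIO).
 */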

View file

@ -7,8 +7,47 @@ const {
EModelEndpoint,
} = require('librechat-data-provider');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { logAxiosError } = require('~/utils');
const { logger } = require('~/config');
/**
* Converts a readable stream to a base64 encoded string.
*
* @param {NodeJS.ReadableStream} stream - The readable stream to convert.
* @param {boolean} [destroyStream=true] - Whether to destroy the stream after processing.
* @returns {Promise<string>} - Promise resolving to the base64 encoded content.
*/
async function streamToBase64(stream, destroyStream = true) {
return new Promise((resolve, reject) => {
const chunks = [];
stream.on('data', (chunk) => {
chunks.push(chunk);
});
stream.on('end', () => {
try {
const buffer = Buffer.concat(chunks);
const base64Data = buffer.toString('base64');
chunks.length = 0; // Clear the array
resolve(base64Data);
} catch (err) {
reject(err);
}
});
stream.on('error', (error) => {
chunks.length = 0;
reject(error);
});
}).finally(() => {
// Clean up the stream if required
if (destroyStream && stream.destroy && typeof stream.destroy === 'function') {
stream.destroy();
}
});
}
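/*
 * Usage sketch (illustrative): converting a download stream to base64; the
 * source stream is destroyed after processing by default, as `encodeAndFormat`
 * below relies on:
 *
 *   const stream = await getDownloadStream(req, file.filepath);
 *   const base64Data = await streamToBase64(stream);
 */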
/**
* Fetches an image from a URL and returns its base64 representation.
*
@ -22,10 +61,12 @@ async function fetchImageToBase64(url) {
const response = await axios.get(url, {
responseType: 'arraybuffer',
});
return Buffer.from(response.data).toString('base64');
const base64Data = Buffer.from(response.data).toString('base64');
response.data = null;
return base64Data;
} catch (error) {
logger.error('Error fetching image to convert to base64', error);
throw error;
const message = 'Error fetching image to convert to base64';
throw new Error(logAxiosError({ message, error }));
}
}
@ -37,18 +78,23 @@ const base64Only = new Set([
EModelEndpoint.bedrock,
]);
const blobStorageSources = new Set([FileSources.azure_blob, FileSources.s3]);
/**
* Encodes and formats the given files.
* @param {Express.Request} req - The request object.
* @param {Array<MongoFile>} files - The array of files to encode and format.
* @param {EModelEndpoint} [endpoint] - Optional: The endpoint for the image.
* @param {string} [mode] - Optional: The endpoint mode for the image.
* @returns {Promise<Object>} - A promise that resolves to the result object containing the encoded images and file details.
* @returns {Promise<{ text: string; files: MongoFile[]; image_urls: MessageContentImageUrl[] }>} - A promise that resolves to the result object containing the encoded images and file details.
*/
async function encodeAndFormat(req, files, endpoint, mode) {
const promises = [];
/** @type {Record<FileSources, Pick<ReturnType<typeof getStrategyFunctions>, 'prepareImagePayload' | 'getDownloadStream'>>} */
const encodingMethods = {};
/** @type {{ text: string; files: MongoFile[]; image_urls: MessageContentImageUrl[] }} */
const result = {
text: '',
files: [],
image_urls: [],
};
@ -58,7 +104,11 @@ async function encodeAndFormat(req, files, endpoint, mode) {
}
for (let file of files) {
/** @type {FileSources} */
const source = file.source ?? FileSources.local;
if (source === FileSources.text && file.text) {
result.text += `${!result.text ? 'Attached document(s):\n```md' : '\n\n---\n\n'}# "${file.filename}"\n${file.text}\n`;
}
if (!file.height) {
promises.push([file, null]);
@ -66,18 +116,29 @@ async function encodeAndFormat(req, files, endpoint, mode) {
}
if (!encodingMethods[source]) {
const { prepareImagePayload } = getStrategyFunctions(source);
const { prepareImagePayload, getDownloadStream } = getStrategyFunctions(source);
if (!prepareImagePayload) {
throw new Error(`Encoding function not implemented for ${source}`);
}
encodingMethods[source] = prepareImagePayload;
encodingMethods[source] = { prepareImagePayload, getDownloadStream };
}
const preparePayload = encodingMethods[source];
/* Google & Anthropic don't support passing URLs to payload */
if (source !== FileSources.local && base64Only.has(endpoint)) {
const preparePayload = encodingMethods[source].prepareImagePayload;
/* We need to fetch the image and convert it to base64 if we are using S3/Azure Blob storage. */
if (blobStorageSources.has(source)) {
try {
const downloadStream = encodingMethods[source].getDownloadStream;
let stream = await downloadStream(req, file.filepath);
let base64Data = await streamToBase64(stream);
stream = null;
promises.push([file, base64Data]);
base64Data = null;
continue;
} catch (error) {
        logger.error('Error fetching blob storage file as base64; falling back to default payload preparation:', error);
}
} else if (source !== FileSources.local && base64Only.has(endpoint)) {
const [_file, imageURL] = await preparePayload(req, file);
promises.push([_file, await fetchImageToBase64(imageURL)]);
continue;
@ -85,10 +146,15 @@ async function encodeAndFormat(req, files, endpoint, mode) {
promises.push(preparePayload(req, file));
}
if (result.text) {
result.text += '\n```';
}
const detail = req.body.imageDetail ?? ImageDetail.auto;
/** @type {Array<[MongoFile, string]>} */
const formattedImages = await Promise.all(promises);
promises.length = 0;
for (const [file, imageContent] of formattedImages) {
const fileMetadata = {
@ -121,8 +187,8 @@ async function encodeAndFormat(req, files, endpoint, mode) {
};
if (mode === VisionModes.agents) {
result.image_urls.push(imagePart);
result.files.push(fileMetadata);
result.image_urls.push({ ...imagePart });
result.files.push({ ...fileMetadata });
continue;
}
@ -144,10 +210,11 @@ async function encodeAndFormat(req, files, endpoint, mode) {
delete imagePart.image_url;
}
result.image_urls.push(imagePart);
result.files.push(fileMetadata);
result.image_urls.push({ ...imagePart });
result.files.push({ ...fileMetadata });
}
return result;
formattedImages.length = 0;
return { ...result };
}
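A sketch of a call site, assuming `conversationFiles` is the hydrated MongoFile array for the request; the result destructures into the three documented fields:

async function formatExample(req, conversationFiles) {
  const { text, files, image_urls } = await encodeAndFormat(
    req,
    conversationFiles,
    EModelEndpoint.anthropic, // a base64-only endpoint, so images are inlined
    VisionModes.agents,
  );
  return { text, files, image_urls };
}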
module.exports = {

View file

@ -28,8 +28,8 @@ const { addResourceFileId, deleteResourceFileId } = require('~/server/controller
const { addAgentResourceFile, removeAgentResourceFiles } = require('~/models/Agent');
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
const { createFile, updateFileUsage, deleteFiles } = require('~/models/File');
const { getEndpointsConfig } = require('~/server/services/Config');
const { loadAuthValues } = require('~/app/clients/tools/util');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { checkCapability } = require('~/server/services/Config');
const { LB_QueueAsyncCall } = require('~/server/utils/queue');
const { getStrategyFunctions } = require('./strategies');
const { determineFileType } = require('~/server/utils');
@ -162,7 +162,6 @@ const processDeleteRequest = async ({ req, files }) => {
for (const file of files) {
const source = file.source ?? FileSources.local;
if (req.body.agent_id && req.body.tool_resource) {
agentFiles.push({
tool_resource: req.body.tool_resource,
@ -170,6 +169,11 @@ const processDeleteRequest = async ({ req, files }) => {
});
}
if (source === FileSources.text) {
resolvedFileIds.push(file.file_id);
continue;
}
if (checkOpenAIStorage(source) && !client[source]) {
await initializeClients();
}
@ -453,17 +457,6 @@ const processFileUpload = async ({ req, res, metadata }) => {
res.status(200).json({ message: 'File uploaded and processed successfully', ...result });
};
/**
* @param {ServerRequest} req
* @param {AgentCapabilities} capability
* @returns {Promise<boolean>}
*/
const checkCapability = async (req, capability) => {
const endpointsConfig = await getEndpointsConfig(req);
const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
return capabilities.includes(capability);
};
/**
* Applies the current strategy for file uploads.
* Saves file metadata to the database with an expiry TTL.
@ -499,7 +492,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
let fileInfoMetadata;
const entity_id = messageAttachment === true ? undefined : agent_id;
const basePath = mime.getType(file.originalname)?.startsWith('image') ? 'images' : 'uploads';
if (tool_resource === EToolResources.execute_code) {
const isCodeEnabled = await checkCapability(req, AgentCapabilities.execute_code);
if (!isCodeEnabled) {
@ -521,6 +514,52 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
if (!isFileSearchEnabled) {
throw new Error('File search is not enabled for Agents');
}
} else if (tool_resource === EToolResources.ocr) {
const isOCREnabled = await checkCapability(req, AgentCapabilities.ocr);
if (!isOCREnabled) {
throw new Error('OCR capability is not enabled for Agents');
}
const { handleFileUpload: uploadMistralOCR } = getStrategyFunctions(
req.app.locals?.ocr?.strategy ?? FileSources.mistral_ocr,
);
const { file_id, temp_file_id } = metadata;
const {
text,
bytes,
// TODO: OCR images support?
images,
filename,
filepath: ocrFileURL,
} = await uploadMistralOCR({ req, file, file_id, entity_id: agent_id, basePath });
const fileInfo = removeNullishValues({
text,
bytes,
file_id,
temp_file_id,
user: req.user.id,
type: 'text/plain',
filepath: ocrFileURL,
source: FileSources.text,
filename: filename ?? file.originalname,
model: messageAttachment ? undefined : req.body.model,
context: messageAttachment ? FileContext.message_attachment : FileContext.agents,
});
if (!messageAttachment && tool_resource) {
await addAgentResourceFile({
req,
file_id,
agent_id,
tool_resource,
});
}
const result = await createFile(fileInfo, true);
return res
.status(200)
.json({ message: 'Agent file uploaded and processed successfully', ...result });
}
const source =
@ -543,6 +582,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
file,
file_id,
entity_id,
basePath,
});
let filepath = _filepath;

View file

@ -21,9 +21,32 @@ const {
processLocalAvatar,
getLocalFileStream,
} = require('./Local');
const {
getS3URL,
saveURLToS3,
saveBufferToS3,
getS3FileStream,
uploadImageToS3,
prepareImageURLS3,
deleteFileFromS3,
processS3Avatar,
uploadFileToS3,
} = require('./S3');
const {
saveBufferToAzure,
saveURLToAzure,
getAzureURL,
deleteFileFromAzure,
uploadFileToAzure,
getAzureFileStream,
uploadImageToAzure,
prepareAzureImageURL,
processAzureAvatar,
} = require('./Azure');
const { uploadOpenAIFile, deleteOpenAIFile, getOpenAIFileStream } = require('./OpenAI');
const { getCodeOutputDownloadStream, uploadCodeEnvFile } = require('./Code');
const { uploadVectors, deleteVectors } = require('./VectorDB');
const { uploadMistralOCR } = require('./MistralOCR');
/**
* Firebase Storage Strategy Functions
@ -57,6 +80,38 @@ const localStrategy = () => ({
getDownloadStream: getLocalFileStream,
});
/**
* S3 Storage Strategy Functions
*
* */
const s3Strategy = () => ({
handleFileUpload: uploadFileToS3,
saveURL: saveURLToS3,
getFileURL: getS3URL,
deleteFile: deleteFileFromS3,
saveBuffer: saveBufferToS3,
prepareImagePayload: prepareImageURLS3,
processAvatar: processS3Avatar,
handleImageUpload: uploadImageToS3,
getDownloadStream: getS3FileStream,
});
/**
* Azure Blob Storage Strategy Functions
*
* */
const azureStrategy = () => ({
handleFileUpload: uploadFileToAzure,
saveURL: saveURLToAzure,
getFileURL: getAzureURL,
deleteFile: deleteFileFromAzure,
saveBuffer: saveBufferToAzure,
prepareImagePayload: prepareAzureImageURL,
processAvatar: processAzureAvatar,
handleImageUpload: uploadImageToAzure,
getDownloadStream: getAzureFileStream,
});
/**
* VectorDB Storage Strategy Functions
*
@ -127,6 +182,26 @@ const codeOutputStrategy = () => ({
getDownloadStream: getCodeOutputDownloadStream,
});
const mistralOCRStrategy = () => ({
/** @type {typeof saveFileFromURL | null} */
saveURL: null,
/** @type {typeof getLocalFileURL | null} */
getFileURL: null,
/** @type {typeof saveLocalBuffer | null} */
saveBuffer: null,
/** @type {typeof processLocalAvatar | null} */
processAvatar: null,
/** @type {typeof uploadLocalImage | null} */
handleImageUpload: null,
/** @type {typeof prepareImagesLocal | null} */
prepareImagePayload: null,
/** @type {typeof deleteLocalFile | null} */
deleteFile: null,
/** @type {typeof getLocalFileStream | null} */
getDownloadStream: null,
handleFileUpload: uploadMistralOCR,
});
// Strategy Selector
const getStrategyFunctions = (fileSource) => {
if (fileSource === FileSources.firebase) {
@ -137,10 +212,16 @@ const getStrategyFunctions = (fileSource) => {
return openAIStrategy();
} else if (fileSource === FileSources.azure) {
return openAIStrategy();
} else if (fileSource === FileSources.azure_blob) {
return azureStrategy();
} else if (fileSource === FileSources.vectordb) {
return vectorStrategy();
} else if (fileSource === FileSources.s3) {
return s3Strategy();
} else if (fileSource === FileSources.execute_code) {
return codeOutputStrategy();
} else if (fileSource === FileSources.mistral_ocr) {
return mistralOCRStrategy();
} else {
throw new Error('Invalid file source');
}
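A defensive usage sketch; `req` and `file` are hypothetical, and since some strategies (e.g. mistral_ocr above) expose only a subset of functions, callers should guard optional members:

const strategy = getStrategyFunctions(file.source ?? FileSources.local);
if (typeof strategy.getDownloadStream !== 'function') {
  throw new Error(`Download streams not supported for source: ${file.source}`);
}
const stream = await strategy.getDownloadStream(req, file.filepath);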

View file

@ -13,13 +13,13 @@ const { logger, getMCPManager } = require('~/config');
* Creates a general tool for an entire action set.
*
* @param {Object} params - The parameters for loading action sets.
* @param {ServerRequest} params.req - The name of the tool.
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {string} params.toolKey - The toolKey for the tool.
* @param {import('@librechat/agents').Providers | EModelEndpoint} params.provider - The provider for the tool.
* @param {string} params.model - The model for the tool.
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
*/
async function createMCPTool({ req, toolKey, provider }) {
async function createMCPTool({ req, toolKey, provider: _provider }) {
const toolDefinition = req.app.locals.availableTools[toolKey]?.function;
if (!toolDefinition) {
logger.error(`Tool ${toolKey} not found in available tools`);
@ -27,9 +27,10 @@ async function createMCPTool({ req, toolKey, provider }) {
}
/** @type {LCTool} */
const { description, parameters } = toolDefinition;
const isGoogle = provider === Providers.VERTEXAI || provider === Providers.GOOGLE;
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
let schema = convertJsonSchemaToZod(parameters, {
allowEmptyObject: !isGoogle,
transformOneOfAnyOf: true,
});
if (!schema) {
@ -37,11 +38,31 @@ async function createMCPTool({ req, toolKey, provider }) {
}
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
/** @type {(toolInput: Object | string) => Promise<unknown>} */
const _call = async (toolInput) => {
if (!req.user?.id) {
logger.error(
`[MCP][${serverName}][${toolName}] User ID not found on request. Cannot create tool.`,
);
throw new Error(`User ID not found on request. Cannot create tool for ${toolKey}.`);
}
/** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise<unknown>} */
const _call = async (toolArguments, config) => {
try {
const mcpManager = await getMCPManager();
const result = await mcpManager.callTool(serverName, toolName, provider, toolInput);
const derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined;
const mcpManager = getMCPManager(config?.configurable?.user_id);
const provider = (config?.metadata?.provider || _provider)?.toLowerCase();
const result = await mcpManager.callTool({
serverName,
toolName,
provider,
toolArguments,
options: {
userId: config?.configurable?.user_id,
signal: derivedSignal,
},
});
if (isAssistantsEndpoint(provider) && Array.isArray(result)) {
return result[0];
}
@ -50,8 +71,13 @@ async function createMCPTool({ req, toolKey, provider }) {
}
return result;
} catch (error) {
logger.error(`${toolName} MCP server tool call failed`, error);
return `${toolName} MCP server tool call failed.`;
logger.error(
`[MCP][User: ${config?.configurable?.user_id}][${serverName}] Error calling "${toolName}" MCP tool:`,
error,
);
throw new Error(
`"${toolKey}" tool call failed${error?.message ? `: ${error?.message}` : '.'}`,
);
}
};
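The derived signal above isolates cancellation: aborting the run's signal aborts the MCP call without mutating the original. A standalone sketch (AbortSignal.any requires Node 20.3+):

const controller = new AbortController();
const derived = AbortSignal.any([controller.signal]);
derived.addEventListener('abort', () => console.log('MCP call cancelled'));
controller.abort(); // derived aborts too; the source signal object is untouched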

View file

@ -55,8 +55,7 @@ async function retrieveRun({ thread_id, run_id, timeout, openai }) {
return response.data;
} catch (error) {
const message = '[retrieveRun] Failed to retrieve run data:';
logAxiosError({ message, error });
throw error;
throw new Error(logAxiosError({ message, error }));
}
}

View file

@ -132,6 +132,8 @@ async function saveUserMessage(req, params) {
* @param {string} params.endpoint - The conversation endpoint
* @param {string} params.parentMessageId - The latest user message that triggered this response.
* @param {string} [params.instructions] - Optional: from preset for `instructions` field.
* Overrides the instructions of the assistant.
* @param {string} [params.spec] - Optional: Model spec identifier.
* @param {string} [params.iconURL]
* @param {string} [params.promptPrefix] - Optional: from preset for `additional_instructions` field.
* @return {Promise<Run>} A promise that resolves to the created run object.
@ -154,6 +156,8 @@ async function saveAssistantMessage(req, params) {
text: params.text,
unfinished: false,
// tokenCount,
iconURL: params.iconURL,
spec: params.spec,
});
await saveConvo(
@ -165,6 +169,8 @@ async function saveAssistantMessage(req, params) {
instructions: params.instructions,
assistant_id: params.assistant_id,
model: params.model,
iconURL: params.iconURL,
spec: params.spec,
},
{ context: 'api/server/services/Threads/manage.js #saveAssistantMessage' },
);

View file

@ -93,11 +93,12 @@ const refreshAccessToken = async ({
return response.data;
} catch (error) {
const message = 'Error refreshing OAuth tokens';
logAxiosError({
message,
error,
});
throw new Error(message);
throw new Error(
logAxiosError({
message,
error,
}),
);
}
};
@ -156,11 +157,12 @@ const getAccessToken = async ({
return response.data;
} catch (error) {
const message = 'Error exchanging OAuth code';
logAxiosError({
message,
error,
});
throw new Error(message);
throw new Error(
logAxiosError({
message,
error,
}),
);
}
};

View file

@ -8,6 +8,7 @@ const {
ErrorTypes,
ContentTypes,
imageGenTools,
EToolResources,
EModelEndpoint,
actionDelimiter,
ImageVisionTool,
@ -15,9 +16,20 @@ const {
AgentCapabilities,
validateAndParseOpenAPISpec,
} = require('librechat-data-provider');
const {
createActionTool,
decryptMetadata,
loadActionSets,
domainParser,
} = require('./ActionService');
const {
createOpenAIImageTools,
createYouTubeTools,
manifestToolMap,
toolkits,
} = require('~/app/clients/tools');
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
const { createYouTubeTools, manifestToolMap, toolkits } = require('~/app/clients/tools');
const { loadActionSets, createActionTool, domainParser } = require('./ActionService');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { getEndpointsConfig } = require('~/server/services/Config');
const { recordUsage } = require('~/server/services/Threads');
const { loadTools } = require('~/app/clients/tools/util');
@ -25,6 +37,30 @@ const { redactMessage } = require('~/config/parsers');
const { sleep } = require('~/server/utils');
const { logger } = require('~/config');
/**
* @param {string} toolName
* @returns {string | undefined} toolKey
*/
function getToolkitKey(toolName) {
/** @type {string|undefined} */
let toolkitKey;
for (const toolkit of toolkits) {
if (toolName.startsWith(EToolResources.image_edit)) {
const splitMatches = toolkit.pluginKey.split('_');
const suffix = splitMatches[splitMatches.length - 1];
if (toolName.endsWith(suffix)) {
toolkitKey = toolkit.pluginKey;
break;
}
}
if (toolName.startsWith(toolkit.pluginKey)) {
toolkitKey = toolkit.pluginKey;
break;
}
}
return toolkitKey;
}
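Behavior sketch with hypothetical toolkit entries; real pluginKeys come from ~/app/clients/tools:

// Assuming toolkits = [{ pluginKey: 'image_edit_oai' }, { pluginKey: 'youtube' }]:
// getToolkitKey('youtube_search');     // -> 'youtube' (prefix match)
// getToolkitKey('image_edit_gen_oai'); // -> 'image_edit_oai' (image_edit suffix match on 'oai')
// getToolkitKey('calculator');         // -> undefined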
/**
* Loads and formats tools from the specified tool directory.
*
@ -97,14 +133,16 @@ function loadAndFormatTools({ directory, adminFilter = [], adminIncluded = [] })
tools.push(formattedTool);
}
/** Basic Tools; schema: { input: string } */
const basicToolInstances = [new Calculator(), ...createYouTubeTools({ override: true })];
/** Basic Tools & Toolkits; schema: { input: string } */
const basicToolInstances = [
new Calculator(),
...createOpenAIImageTools({ override: true }),
...createYouTubeTools({ override: true }),
];
for (const toolInstance of basicToolInstances) {
const formattedTool = formatToOpenAIAssistantTool(toolInstance);
let toolName = formattedTool[Tools.function].name;
toolName = toolkits.some((toolkit) => toolName.startsWith(toolkit.pluginKey))
? toolName.split('_')[0]
: toolName;
toolName = getToolkitKey(toolName) ?? toolName;
if (filter.has(toolName) && included.size === 0) {
continue;
}
@ -315,54 +353,96 @@ async function processRequiredActions(client, requiredActions) {
if (!tool) {
// throw new Error(`Tool ${currentAction.tool} not found.`);
// Load all action sets once if not already loaded
if (!actionSets.length) {
actionSets =
(await loadActionSets({
assistant_id: client.req.body.assistant_id,
})) ?? [];
// Process all action sets once
// Map domains to their processed action sets
const processedDomains = new Map();
const domainMap = new Map();
for (const action of actionSets) {
const domain = await domainParser(action.metadata.domain, true);
domainMap.set(domain, action);
// Check if domain is allowed
const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain);
if (!isDomainAllowed) {
continue;
}
// Validate and parse OpenAPI spec
const validationResult = validateAndParseOpenAPISpec(action.metadata.raw_spec);
if (!validationResult.spec) {
throw new Error(
`Invalid spec: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`,
);
}
// Process the OpenAPI spec
const { requestBuilders } = openapiToFunction(validationResult.spec);
// Store encrypted values for OAuth flow
const encrypted = {
oauth_client_id: action.metadata.oauth_client_id,
oauth_client_secret: action.metadata.oauth_client_secret,
};
// Decrypt metadata
const decryptedAction = { ...action };
decryptedAction.metadata = await decryptMetadata(action.metadata);
processedDomains.set(domain, {
action: decryptedAction,
requestBuilders,
encrypted,
});
// Store builders for reuse
ActionBuildersMap[action.metadata.domain] = requestBuilders;
}
// Update actionSets reference to use the domain map
actionSets = { domainMap, processedDomains };
}
let actionSet = null;
// Find the matching domain for this tool
let currentDomain = '';
for (let action of actionSets) {
const domain = await domainParser(client.req, action.metadata.domain, true);
for (const domain of actionSets.domainMap.keys()) {
if (currentAction.tool.includes(domain)) {
currentDomain = domain;
actionSet = action;
break;
}
}
if (!actionSet) {
if (!currentDomain || !actionSets.processedDomains.has(currentDomain)) {
// TODO: try `function` if no action set is found
// throw new Error(`Tool ${currentAction.tool} not found.`);
continue;
}
let builders = ActionBuildersMap[actionSet.metadata.domain];
if (!builders) {
const validationResult = validateAndParseOpenAPISpec(actionSet.metadata.raw_spec);
if (!validationResult.spec) {
throw new Error(
`Invalid spec: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`,
);
}
const { requestBuilders } = openapiToFunction(validationResult.spec);
ActionToolMap[actionSet.metadata.domain] = requestBuilders;
builders = requestBuilders;
}
const { action, requestBuilders, encrypted } = actionSets.processedDomains.get(currentDomain);
const functionName = currentAction.tool.replace(`${actionDelimiter}${currentDomain}`, '');
const requestBuilder = builders[functionName];
const requestBuilder = requestBuilders[functionName];
if (!requestBuilder) {
// throw new Error(`Tool ${currentAction.tool} not found.`);
continue;
}
tool = await createActionTool({ action: actionSet, requestBuilder });
// We've already decrypted the metadata, so we can pass it directly
tool = await createActionTool({
userId: client.req.user.id,
res: client.res,
action,
requestBuilder,
// Note: intentionally not passing zodSchema, name, and description for assistants API
encrypted, // Pass the encrypted values for OAuth flow
});
if (!tool) {
logger.warn(
`Invalid action: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id} | toolName: ${currentAction.tool}`,
@ -410,7 +490,7 @@ async function processRequiredActions(client, requiredActions) {
* @param {Object} params - Run params containing user and request information.
* @param {ServerRequest} params.req - The request object.
* @param {ServerResponse} params.res - The request object.
* @param {Agent} params.agent - The agent to load tools for.
* @param {Pick<Agent, 'id' | 'provider' | 'model' | 'tools'>} params.agent - The agent to load tools for.
* @param {string | undefined} [params.openAIApiKey] - The OpenAI API key.
* @returns {Promise<{ tools?: StructuredTool[] }>} The agent tools.
*/
@ -420,21 +500,16 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
}
const endpointsConfig = await getEndpointsConfig(req);
const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
const areToolsEnabled = capabilities.includes(AgentCapabilities.tools);
if (!areToolsEnabled) {
logger.debug('Tools are not enabled for this agent.');
return {};
}
const isFileSearchEnabled = capabilities.includes(AgentCapabilities.file_search);
const isCodeEnabled = capabilities.includes(AgentCapabilities.execute_code);
const areActionsEnabled = capabilities.includes(AgentCapabilities.actions);
const enabledCapabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []);
const checkCapability = (capability) => enabledCapabilities.has(capability);
const areToolsEnabled = checkCapability(AgentCapabilities.tools);
const _agentTools = agent.tools?.filter((tool) => {
if (tool === Tools.file_search && !isFileSearchEnabled) {
return false;
} else if (tool === Tools.execute_code && !isCodeEnabled) {
if (tool === Tools.file_search) {
return checkCapability(AgentCapabilities.file_search);
} else if (tool === Tools.execute_code) {
return checkCapability(AgentCapabilities.execute_code);
} else if (!areToolsEnabled && !tool.includes(actionDelimiter)) {
return false;
}
return true;
@ -468,6 +543,10 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
continue;
}
if (!areToolsEnabled) {
continue;
}
if (tool.mcp === true) {
agentTools.push(tool);
continue;
@ -500,14 +579,69 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
return map;
}, {});
if (!areActionsEnabled) {
if (!checkCapability(AgentCapabilities.actions)) {
return {
tools: agentTools,
toolContextMap,
};
}
let actionSets = [];
const actionSets = (await loadActionSets({ agent_id: agent.id })) ?? [];
if (actionSets.length === 0) {
if (_agentTools.length > 0 && agentTools.length === 0) {
logger.warn(`No tools found for the specified tool calls: ${_agentTools.join(', ')}`);
}
return {
tools: agentTools,
toolContextMap,
};
}
// Process each action set once (validate spec, decrypt metadata)
const processedActionSets = new Map();
const domainMap = new Map();
for (const action of actionSets) {
const domain = await domainParser(action.metadata.domain, true);
domainMap.set(domain, action);
// Check if domain is allowed (do this once per action set)
const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain);
if (!isDomainAllowed) {
continue;
}
// Validate and parse OpenAPI spec once per action set
const validationResult = validateAndParseOpenAPISpec(action.metadata.raw_spec);
if (!validationResult.spec) {
continue;
}
const encrypted = {
oauth_client_id: action.metadata.oauth_client_id,
oauth_client_secret: action.metadata.oauth_client_secret,
};
// Decrypt metadata once per action set
const decryptedAction = { ...action };
decryptedAction.metadata = await decryptMetadata(action.metadata);
// Process the OpenAPI spec once per action set
const { requestBuilders, functionSignatures, zodSchemas } = openapiToFunction(
validationResult.spec,
true,
);
processedActionSets.set(domain, {
action: decryptedAction,
requestBuilders,
functionSignatures,
zodSchemas,
encrypted,
});
}
// Now map tools to the processed action sets
const ActionToolMap = {};
for (const toolName of _agentTools) {
@ -515,55 +649,47 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
continue;
}
if (!actionSets.length) {
actionSets = (await loadActionSets({ agent_id: agent.id })) ?? [];
}
let actionSet = null;
// Find the matching domain for this tool
let currentDomain = '';
for (let action of actionSets) {
const domain = await domainParser(req, action.metadata.domain, true);
for (const domain of domainMap.keys()) {
if (toolName.includes(domain)) {
currentDomain = domain;
actionSet = action;
break;
}
}
if (!actionSet) {
if (!currentDomain || !processedActionSets.has(currentDomain)) {
continue;
}
const validationResult = validateAndParseOpenAPISpec(actionSet.metadata.raw_spec);
if (validationResult.spec) {
const { requestBuilders, functionSignatures, zodSchemas } = openapiToFunction(
validationResult.spec,
true,
);
const functionName = toolName.replace(`${actionDelimiter}${currentDomain}`, '');
const functionSig = functionSignatures.find((sig) => sig.name === functionName);
const requestBuilder = requestBuilders[functionName];
const zodSchema = zodSchemas[functionName];
const { action, encrypted, zodSchemas, requestBuilders, functionSignatures } =
processedActionSets.get(currentDomain);
const functionName = toolName.replace(`${actionDelimiter}${currentDomain}`, '');
const functionSig = functionSignatures.find((sig) => sig.name === functionName);
const requestBuilder = requestBuilders[functionName];
const zodSchema = zodSchemas[functionName];
if (requestBuilder) {
const tool = await createActionTool({
req,
res,
action: actionSet,
requestBuilder,
zodSchema,
name: toolName,
description: functionSig.description,
});
if (!tool) {
logger.warn(
`Invalid action: user: ${req.user.id} | agent_id: ${agent.id} | toolName: ${toolName}`,
);
throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
}
agentTools.push(tool);
ActionToolMap[toolName] = tool;
if (requestBuilder) {
const tool = await createActionTool({
userId: req.user.id,
res,
action,
requestBuilder,
zodSchema,
encrypted,
name: toolName,
description: functionSig.description,
});
if (!tool) {
logger.warn(
`Invalid action: user: ${req.user.id} | agent_id: ${agent.id} | toolName: ${toolName}`,
);
throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
}
agentTools.push(tool);
ActionToolMap[toolName] = tool;
}
}
@ -579,6 +705,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey })
}
module.exports = {
getToolkitKey,
loadAgentTools,
loadAndFormatTools,
processRequiredActions,

View file

@ -0,0 +1,56 @@
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
/**
*
* @param {Object} params
* @param {string} params.userId
* @param {string[]} params.authFields
* @param {Set<string>} [params.optional]
* @param {boolean} [params.throwError]
* @returns {Promise<Record<string, string>>} A map of resolved auth field names to their values.
*/
const loadAuthValues = async ({ userId, authFields, optional, throwError = true }) => {
let authValues = {};
/**
* Finds the first non-empty value for the given authentication field, supporting alternate fields.
* @param {string[]} fields Array of strings representing the authentication fields. Supports alternate fields delimited by "||".
* @returns {Promise<{ authField: string, authValue: string} | null>} An object containing the authentication field and value, or null if not found.
*/
const findAuthValue = async (fields) => {
for (const field of fields) {
let value = process.env[field];
if (value) {
return { authField: field, authValue: value };
}
try {
value = await getUserPluginAuthValue(userId, field, throwError);
} catch (err) {
if (optional && optional.has(field)) {
return { authField: field, authValue: undefined };
}
if (field === fields[fields.length - 1] && !value) {
throw err;
}
}
if (value) {
return { authField: field, authValue: value };
}
}
return null;
};
for (let authField of authFields) {
const fields = authField.split('||');
const result = await findAuthValue(fields);
if (result) {
authValues[result.authField] = result.authValue;
}
}
return authValues;
};
module.exports = {
loadAuthValues,
};
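A usage sketch with hypothetical field names; environment variables win over per-user stored values, and `||` lists alternates tried in order:

async function authExample(userId) {
  const values = await loadAuthValues({
    userId,
    authFields: ['MY_TOOL_API_KEY||MY_TOOL_TOKEN'],
    optional: new Set(['MY_TOOL_TOKEN']), // optional misses resolve to undefined instead of throwing
  });
  return values; // e.g. { MY_TOOL_API_KEY: 'sk-...' }
}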

View file

@ -13,6 +13,24 @@ const secretDefaults = {
JWT_REFRESH_SECRET: 'eaa5191f2914e30b9387fd84e254e4ba6fc51b4654968a9b0803b456a54b8418',
};
const deprecatedVariables = [
{
key: 'CHECK_BALANCE',
description:
'Please use the `balance` field in the `librechat.yaml` config file instead.\nMore info: https://librechat.ai/docs/configuration/librechat_yaml/object_structure/balance#overview',
},
{
key: 'START_BALANCE',
description:
'Please use the `balance` field in the `librechat.yaml` config file instead.\nMore info: https://librechat.ai/docs/configuration/librechat_yaml/object_structure/balance#overview',
},
{
key: 'GOOGLE_API_KEY',
description:
'Please use the `GOOGLE_SEARCH_API_KEY` environment variable for the Google Search Tool instead.',
},
];
/**
* Checks environment variables for default secrets and deprecated variables.
* Logs warnings for any default secret values being used and for usage of deprecated `GOOGLE_API_KEY`.
@ -37,19 +55,11 @@ function checkVariables() {
\u200B`);
}
if (process.env.GOOGLE_API_KEY) {
logger.warn(
'The `GOOGLE_API_KEY` environment variable is deprecated.\nPlease use the `GOOGLE_SEARCH_API_KEY` environment variable instead.',
);
}
if (process.env.OPENROUTER_API_KEY) {
logger.warn(
`The \`OPENROUTER_API_KEY\` environment variable is deprecated and its functionality will be removed soon.
Use of this environment variable is highly discouraged as it can lead to unexpected errors when using custom endpoints.
Please use the config (\`librechat.yaml\`) file for setting up OpenRouter, and use \`OPENROUTER_KEY\` or another environment variable instead.`,
);
}
deprecatedVariables.forEach(({ key, description }) => {
if (process.env[key]) {
logger.warn(`The \`${key}\` environment variable is deprecated. ${description}`);
}
});
checkPasswordReset();
}

View file

@ -18,12 +18,15 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol
const { interface: interfaceConfig } = config ?? {};
const { interface: defaults } = configDefaults;
const hasModelSpecs = config?.modelSpecs?.list?.length > 0;
const includesAddedEndpoints = config?.modelSpecs?.addedEndpoints?.length > 0;
/** @type {TCustomConfig['interface']} */
const loadedInterface = removeNullishValues({
endpointsMenu:
interfaceConfig?.endpointsMenu ?? (hasModelSpecs ? false : defaults.endpointsMenu),
modelSelect: interfaceConfig?.modelSelect ?? (hasModelSpecs ? false : defaults.modelSelect),
modelSelect:
interfaceConfig?.modelSelect ??
(hasModelSpecs ? includesAddedEndpoints : defaults.modelSelect),
parameters: interfaceConfig?.parameters ?? (hasModelSpecs ? false : defaults.parameters),
presets: interfaceConfig?.presets ?? (hasModelSpecs ? false : defaults.presets),
sidePanel: interfaceConfig?.sidePanel ?? defaults.sidePanel,

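The modelSelect precedence above condenses to cfg ?? (hasModelSpecs ? includesAddedEndpoints : default); a minimal restatement under that assumption:

const resolveModelSelect = (cfg, hasSpecs, hasAdded, dflt) => cfg ?? (hasSpecs ? hasAdded : dflt);
console.assert(resolveModelSelect(undefined, true, true, true) === true); // specs + addedEndpoints
console.assert(resolveModelSelect(undefined, true, false, true) === false); // specs without addedEndpoints
console.assert(resolveModelSelect(false, true, true, true) === false); // explicit config always wins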
View file

@ -6,9 +6,10 @@ const { logger } = require('~/config');
* Sets up Model Specs from the config (`librechat.yaml`) file.
* @param {TCustomConfig['endpoints']} [endpoints] - The loaded custom configuration for endpoints.
* @param {TCustomConfig['modelSpecs'] | undefined} [modelSpecs] - The loaded custom configuration for model specs.
* @param {TCustomConfig['interface'] | undefined} [interfaceConfig] - The loaded interface configuration.
* @returns {TCustomConfig['modelSpecs'] | undefined} The processed model specs, if any.
*/
function processModelSpecs(endpoints, _modelSpecs) {
function processModelSpecs(endpoints, _modelSpecs, interfaceConfig) {
if (!_modelSpecs) {
return undefined;
}
@ -20,6 +21,19 @@ function processModelSpecs(endpoints, _modelSpecs) {
const customEndpoints = endpoints?.[EModelEndpoint.custom] ?? [];
if (interfaceConfig.modelSelect !== true && (_modelSpecs.addedEndpoints?.length ?? 0) > 0) {
logger.warn(
`To utilize \`addedEndpoints\`, which allows provider/model selections alongside model specs, set \`modelSelect: true\` in the interface configuration.
Example:
\`\`\`yaml
interface:
modelSelect: true
\`\`\`
`,
);
}
for (const spec of list) {
if (EModelEndpoint[spec.preset.endpoint] && spec.preset.endpoint !== EModelEndpoint.custom) {
modelSpecs.push(spec);

View file

@ -1,15 +1,14 @@
const { sign } = require('jsonwebtoken');
const { webcrypto } = require('node:crypto');
const { hashBackupCode, decryptV2 } = require('~/server/utils/crypto');
const { updateUser } = require('~/models/userMethods');
const { decryptV3, decryptV2 } = require('../utils/crypto');
const { hashBackupCode } = require('~/server/utils/crypto');
// Base32 alphabet for TOTP secret encoding.
const BASE32_ALPHABET = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567';
/**
* Encodes a Buffer into a Base32 string using the RFC 4648 alphabet.
*
* @param {Buffer} buffer - The buffer to encode.
* @returns {string} The Base32 encoded string.
* Encodes a Buffer into a Base32 string.
* @param {Buffer} buffer
* @returns {string}
*/
const encodeBase32 = (buffer) => {
let bits = 0;
@ -30,10 +29,9 @@ const encodeBase32 = (buffer) => {
};
/**
* Decodes a Base32-encoded string back into a Buffer.
*
* @param {string} base32Str - The Base32-encoded string.
* @returns {Buffer} The decoded buffer.
* Decodes a Base32 string into a Buffer.
* @param {string} base32Str
* @returns {Buffer}
*/
const decodeBase32 = (base32Str) => {
const cleaned = base32Str.replace(/=+$/, '').toUpperCase();
@ -56,20 +54,8 @@ const decodeBase32 = (base32Str) => {
};
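The two helpers are inverses for whole buffers; a round-trip check:

const buf = Buffer.from('0123456789'); // 10 bytes, the same size as a TOTP secret
console.assert(decodeBase32(encodeBase32(buf)).equals(buf));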
/**
* Generates a temporary token for 2FA verification.
* The token is signed with the JWT_SECRET and expires in 5 minutes.
*
* @param {string} userId - The unique identifier of the user.
* @returns {string} The signed JWT token.
*/
const generate2FATempToken = (userId) =>
sign({ userId, twoFAPending: true }, process.env.JWT_SECRET, { expiresIn: '5m' });
/**
* Generates a TOTP secret.
* Creates 10 random bytes using WebCrypto and encodes them into a Base32 string.
*
* @returns {string} A Base32-encoded secret for TOTP.
* Generates a new TOTP secret (Base32 encoded).
* @returns {string}
*/
const generateTOTPSecret = () => {
const randomArray = new Uint8Array(10);
@ -78,29 +64,25 @@ const generateTOTPSecret = () => {
};
/**
* Generates a Time-based One-Time Password (TOTP) based on the provided secret and time.
* This implementation uses a 30-second time step and produces a 6-digit code.
*
* @param {string} secret - The Base32-encoded TOTP secret.
* @param {number} [forTime=Date.now()] - The time (in milliseconds) for which to generate the TOTP.
* @returns {Promise<string>} A promise that resolves to the 6-digit TOTP code.
* Generates a TOTP code based on the secret and time.
* Uses a 30-second time step and produces a 6-digit code.
* @param {string} secret
* @param {number} [forTime=Date.now()]
* @returns {Promise<string>}
*/
const generateTOTP = async (secret, forTime = Date.now()) => {
const timeStep = 30; // seconds
const counter = Math.floor(forTime / 1000 / timeStep);
const counterBuffer = new ArrayBuffer(8);
const counterView = new DataView(counterBuffer);
// Write counter into the last 4 bytes (big-endian)
counterView.setUint32(4, counter, false);
// Decode the secret into an ArrayBuffer
const keyBuffer = decodeBase32(secret);
const keyArrayBuffer = keyBuffer.buffer.slice(
keyBuffer.byteOffset,
keyBuffer.byteOffset + keyBuffer.byteLength,
);
// Import the key for HMAC-SHA1 signing
const cryptoKey = await webcrypto.subtle.importKey(
'raw',
keyArrayBuffer,
@ -108,12 +90,10 @@ const generateTOTP = async (secret, forTime = Date.now()) => {
false,
['sign'],
);
// Generate HMAC signature
const signatureBuffer = await webcrypto.subtle.sign('HMAC', cryptoKey, counterBuffer);
const hmac = new Uint8Array(signatureBuffer);
// Dynamic truncation as per RFC 4226
// Dynamic truncation per RFC 4226.
const offset = hmac[hmac.length - 1] & 0xf;
const slice = hmac.slice(offset, offset + 4);
const view = new DataView(slice.buffer, slice.byteOffset, slice.byteLength);
@ -123,12 +103,10 @@ const generateTOTP = async (secret, forTime = Date.now()) => {
};
/**
* Verifies a provided TOTP token against the secret.
* It allows for a ±1 time-step window to account for slight clock discrepancies.
*
* @param {string} secret - The Base32-encoded TOTP secret.
* @param {string} token - The TOTP token provided by the user.
* @returns {Promise<boolean>} A promise that resolves to true if the token is valid; otherwise, false.
* Verifies a TOTP token by checking a ±1 time step window.
* @param {string} secret
* @param {string} token
* @returns {Promise<boolean>}
*/
const verifyTOTP = async (secret, token) => {
const timeStepMS = 30 * 1000;
@ -143,27 +121,24 @@ const verifyTOTP = async (secret, token) => {
};
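An end-to-end sketch of the three TOTP helpers:

async function totpExample() {
  const secret = generateTOTPSecret(); // 16-char Base32 string from 10 random bytes
  const code = await generateTOTP(secret); // 6-digit code for the current 30s window
  return verifyTOTP(secret, code); // resolves true within the ±1-step window
}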
/**
* Generates backup codes for two-factor authentication.
* Each backup code is an 8-character hexadecimal string along with its SHA-256 hash.
* The plain codes are returned for one-time download, while the hashed objects are meant for secure storage.
*
* @param {number} [count=10] - The number of backup codes to generate.
* Generates backup codes (default count: 10).
* Each code is an 8-character hexadecimal string stored alongside its SHA-256 hash.
* @param {number} [count=10]
* @returns {Promise<{ plainCodes: string[], codeObjects: Array<{ codeHash: string, used: boolean, usedAt: Date | null }> }>}
* A promise that resolves to an object containing both plain backup codes and their corresponding code objects.
*/
const generateBackupCodes = async (count = 10) => {
const plainCodes = [];
const codeObjects = [];
const encoder = new TextEncoder();
for (let i = 0; i < count; i++) {
const randomArray = new Uint8Array(4);
webcrypto.getRandomValues(randomArray);
const code = Array.from(randomArray)
.map((b) => b.toString(16).padStart(2, '0'))
.join(''); // 8-character hex code
.join('');
plainCodes.push(code);
// Compute SHA-256 hash of the code using WebCrypto
const codeBuffer = encoder.encode(code);
const hashBuffer = await webcrypto.subtle.digest('SHA-256', codeBuffer);
const hashArray = Array.from(new Uint8Array(hashBuffer));
@ -174,12 +149,11 @@ const generateBackupCodes = async (count = 10) => {
};
/**
* Verifies a backup code for a user and updates its status as used if valid.
*
* @param {Object} params - The parameters object.
* @param {TUser | undefined} [params.user] - The user object containing backup codes.
* @param {string | undefined} [params.backupCode] - The backup code to verify.
* @returns {Promise<boolean>} A promise that resolves to true if the backup code is valid and updated; otherwise, false.
* Verifies a backup code and, if valid, marks it as used.
* @param {Object} params
* @param {Object} params.user
* @param {string} params.backupCode
* @returns {Promise<boolean>}
*/
const verifyBackupCode = async ({ user, backupCode }) => {
if (!backupCode || !user || !Array.isArray(user.backupCodes)) {
@ -197,42 +171,54 @@ const verifyBackupCode = async ({ user, backupCode }) => {
? { ...codeObj, used: true, usedAt: new Date() }
: codeObj,
);
// Update the user record with the marked backup code.
const { updateUser } = require('~/models');
await updateUser(user._id, { backupCodes: updatedBackupCodes });
return true;
}
return false;
};
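A setup-then-login sketch; `user` is a hypothetical hydrated user document:

async function backupCodeExample(user) {
  const { plainCodes, codeObjects } = await generateBackupCodes();
  user.backupCodes = codeObjects; // persisted via updateUser in the real flow
  const ok = await verifyBackupCode({ user, backupCode: plainCodes[0] });
  return ok; // true; the matching entry is marked used with a usedAt timestamp
}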
/**
* Retrieves and, if necessary, decrypts a stored TOTP secret.
* If the secret contains a colon, it is assumed to be in the format "iv:encryptedData" and will be decrypted.
* If the secret is exactly 16 characters long, it is assumed to be a legacy plain secret.
*
* @param {string|null} storedSecret - The stored TOTP secret (which may be encrypted).
* @returns {Promise<string|null>} A promise that resolves to the plain TOTP secret, or null if none is provided.
* Retrieves and decrypts a stored TOTP secret.
* - Uses decryptV3 if the secret has a "v3:" prefix.
* - Falls back to decryptV2 for colon-delimited values.
* - Assumes a 16-character secret is already plain.
* @param {string|null} storedSecret
* @returns {Promise<string|null>}
*/
const getTOTPSecret = async (storedSecret) => {
if (!storedSecret) { return null; }
// Check for a colon marker (encrypted secrets are stored as "iv:encryptedData")
if (!storedSecret) {
return null;
}
if (storedSecret.startsWith('v3:')) {
return decryptV3(storedSecret);
}
if (storedSecret.includes(':')) {
return await decryptV2(storedSecret);
}
// If it's exactly 16 characters, assume it's already plain (legacy secret)
if (storedSecret.length === 16) {
return storedSecret;
}
// Fallback in case it doesn't meet our criteria.
return storedSecret;
};
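A dispatch summary for the stored formats, with a hypothetical call:

// 'v3:<payload>'        -> decryptV3
// '<iv>:<cipherText>'   -> decryptV2
// 16-char legacy secret -> returned as-is
const plainSecret = await getTOTPSecret(user.totpSecret); // user.totpSecret is hypothetical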
/**
* Generates a temporary JWT token for 2FA verification that expires in 5 minutes.
* @param {string} userId
* @returns {string}
*/
const generate2FATempToken = (userId) => {
const { sign } = require('jsonwebtoken');
return sign({ userId, twoFAPending: true }, process.env.JWT_SECRET, { expiresIn: '5m' });
};
module.exports = {
verifyTOTP,
generateTOTP,
getTOTPSecret,
verifyBackupCode,
generateTOTPSecret,
generateTOTP,
verifyTOTP,
generateBackupCodes,
verifyBackupCode,
getTOTPSecret,
generate2FATempToken,
};