mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-01-12 21:48:51 +01:00
Merge branch 'main' into feat/Multitenant-login-OIDC
This commit is contained in:
commit
c14751cef5
417 changed files with 28394 additions and 9012 deletions
|
|
@ -150,11 +150,11 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||
} catch (error) {
|
||||
const partialText = getText && getText();
|
||||
handleAbortError(res, req, error, {
|
||||
sender,
|
||||
partialText,
|
||||
conversationId,
|
||||
sender,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: userMessageId ?? parentMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId,
|
||||
}).catch((err) => {
|
||||
logger.error('[AskController] Error in `handleAbortError`', err);
|
||||
});
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ const refreshController = async (req, res) => {
|
|||
|
||||
try {
|
||||
const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
|
||||
const user = await getUserById(payload.id, '-password -__v');
|
||||
const user = await getUserById(payload.id, '-password -__v -totpSecret');
|
||||
if (!user) {
|
||||
return res.status(401).redirect('/login');
|
||||
}
|
||||
|
|
|
|||
|
|
@ -135,11 +135,11 @@ const EditController = async (req, res, next, initializeClient) => {
|
|||
} catch (error) {
|
||||
const partialText = getText();
|
||||
handleAbortError(res, req, error, {
|
||||
sender,
|
||||
partialText,
|
||||
conversationId,
|
||||
sender,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: userMessageId ?? parentMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId,
|
||||
}).catch((err) => {
|
||||
logger.error('[EditController] Error in `handleAbortError`', err);
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { loadDefaultModels, loadConfigModels } = require('~/server/services/Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* @param {ServerRequest} req
|
||||
|
|
@ -36,8 +37,13 @@ async function loadModels(req) {
|
|||
}
|
||||
|
||||
async function modelController(req, res) {
|
||||
const modelConfig = await loadModels(req);
|
||||
res.send(modelConfig);
|
||||
try {
|
||||
const modelConfig = await loadModels(req);
|
||||
res.send(modelConfig);
|
||||
} catch (error) {
|
||||
logger.error('Error fetching models:', error);
|
||||
res.status(500).send({ error: error.message });
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = { modelController, loadModels, getModelsConfig };
|
||||
|
|
|
|||
138
api/server/controllers/TwoFactorController.js
Normal file
138
api/server/controllers/TwoFactorController.js
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
const {
|
||||
generateTOTPSecret,
|
||||
generateBackupCodes,
|
||||
verifyTOTP,
|
||||
verifyBackupCode,
|
||||
getTOTPSecret,
|
||||
} = require('~/server/services/twoFactorService');
|
||||
const { updateUser, getUserById } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
const { encryptV3 } = require('~/server/utils/crypto');
|
||||
|
||||
const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, '');
|
||||
|
||||
/**
|
||||
* Enable 2FA for the user by generating a new TOTP secret and backup codes.
|
||||
* The secret is encrypted and stored, and 2FA is marked as disabled until confirmed.
|
||||
*/
|
||||
const enable2FA = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
const secret = generateTOTPSecret();
|
||||
const { plainCodes, codeObjects } = await generateBackupCodes();
|
||||
|
||||
// Encrypt the secret with v3 encryption before saving.
|
||||
const encryptedSecret = encryptV3(secret);
|
||||
|
||||
// Update the user record: store the secret & backup codes and set twoFactorEnabled to false.
|
||||
const user = await updateUser(userId, {
|
||||
totpSecret: encryptedSecret,
|
||||
backupCodes: codeObjects,
|
||||
twoFactorEnabled: false,
|
||||
});
|
||||
|
||||
const otpauthUrl = `otpauth://totp/${safeAppTitle}:${user.email}?secret=${secret}&issuer=${safeAppTitle}`;
|
||||
|
||||
return res.status(200).json({ otpauthUrl, backupCodes: plainCodes });
|
||||
} catch (err) {
|
||||
logger.error('[enable2FA]', err);
|
||||
return res.status(500).json({ message: err.message });
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Verify a 2FA code (either TOTP or backup code) during setup.
|
||||
*/
|
||||
const verify2FA = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
const { token, backupCode } = req.body;
|
||||
const user = await getUserById(userId);
|
||||
|
||||
if (!user || !user.totpSecret) {
|
||||
return res.status(400).json({ message: '2FA not initiated' });
|
||||
}
|
||||
|
||||
const secret = await getTOTPSecret(user.totpSecret);
|
||||
let isVerified = false;
|
||||
|
||||
if (token) {
|
||||
isVerified = await verifyTOTP(secret, token);
|
||||
} else if (backupCode) {
|
||||
isVerified = await verifyBackupCode({ user, backupCode });
|
||||
}
|
||||
|
||||
if (isVerified) {
|
||||
return res.status(200).json();
|
||||
}
|
||||
return res.status(400).json({ message: 'Invalid token or backup code.' });
|
||||
} catch (err) {
|
||||
logger.error('[verify2FA]', err);
|
||||
return res.status(500).json({ message: err.message });
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Confirm and enable 2FA after a successful verification.
|
||||
*/
|
||||
const confirm2FA = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
const { token } = req.body;
|
||||
const user = await getUserById(userId);
|
||||
|
||||
if (!user || !user.totpSecret) {
|
||||
return res.status(400).json({ message: '2FA not initiated' });
|
||||
}
|
||||
|
||||
const secret = await getTOTPSecret(user.totpSecret);
|
||||
if (await verifyTOTP(secret, token)) {
|
||||
await updateUser(userId, { twoFactorEnabled: true });
|
||||
return res.status(200).json();
|
||||
}
|
||||
return res.status(400).json({ message: 'Invalid token.' });
|
||||
} catch (err) {
|
||||
logger.error('[confirm2FA]', err);
|
||||
return res.status(500).json({ message: err.message });
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Disable 2FA by clearing the stored secret and backup codes.
|
||||
*/
|
||||
const disable2FA = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
await updateUser(userId, { totpSecret: null, backupCodes: [], twoFactorEnabled: false });
|
||||
return res.status(200).json();
|
||||
} catch (err) {
|
||||
logger.error('[disable2FA]', err);
|
||||
return res.status(500).json({ message: err.message });
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Regenerate backup codes for the user.
|
||||
*/
|
||||
const regenerateBackupCodes = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
const { plainCodes, codeObjects } = await generateBackupCodes();
|
||||
await updateUser(userId, { backupCodes: codeObjects });
|
||||
return res.status(200).json({
|
||||
backupCodes: plainCodes,
|
||||
backupCodesHash: codeObjects,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error('[regenerateBackupCodes]', err);
|
||||
return res.status(500).json({ message: err.message });
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
enable2FA,
|
||||
verify2FA,
|
||||
confirm2FA,
|
||||
disable2FA,
|
||||
regenerateBackupCodes,
|
||||
};
|
||||
|
|
@ -19,7 +19,9 @@ const { Transaction } = require('~/models/Transaction');
|
|||
const { logger } = require('~/config');
|
||||
|
||||
const getUserController = async (req, res) => {
|
||||
res.status(200).send(req.user);
|
||||
const userData = req.user.toObject != null ? req.user.toObject() : { ...req.user };
|
||||
delete userData.totpSecret;
|
||||
res.status(200).send(userData);
|
||||
};
|
||||
|
||||
const getTermsStatusController = async (req, res) => {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
const { Tools, StepTypes, imageGenTools, FileContext } = require('librechat-data-provider');
|
||||
const { nanoid } = require('nanoid');
|
||||
const { Tools, StepTypes, FileContext } = require('librechat-data-provider');
|
||||
const {
|
||||
EnvVar,
|
||||
Providers,
|
||||
|
|
@ -9,8 +10,8 @@ const {
|
|||
ChatModelStreamHandler,
|
||||
} = require('@librechat/agents');
|
||||
const { processCodeOutput } = require('~/server/services/Files/Code/process');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { saveBase64Image } = require('~/server/services/Files/process');
|
||||
const { loadAuthValues } = require('~/app/clients/tools/util');
|
||||
const { logger, sendEvent } = require('~/config');
|
||||
|
||||
/** @typedef {import('@librechat/agents').Graph} Graph */
|
||||
|
|
@ -199,6 +200,22 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
aggregateContent({ event, data });
|
||||
},
|
||||
},
|
||||
[GraphEvents.ON_REASONING_DELTA]: {
|
||||
/**
|
||||
* Handle ON_REASONING_DELTA event.
|
||||
* @param {string} event - The event name.
|
||||
* @param {StreamEventData} data - The event data.
|
||||
* @param {GraphRunnableConfig['configurable']} [metadata] The runnable metadata.
|
||||
*/
|
||||
handle: (event, data, metadata) => {
|
||||
if (metadata?.last_agent_index === metadata?.agent_index) {
|
||||
sendEvent(res, { event, data });
|
||||
} else if (!metadata?.hide_sequential_outputs) {
|
||||
sendEvent(res, { event, data });
|
||||
}
|
||||
aggregateContent({ event, data });
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
return handlers;
|
||||
|
|
@ -226,32 +243,6 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
return;
|
||||
}
|
||||
|
||||
if (imageGenTools.has(output.name)) {
|
||||
artifactPromises.push(
|
||||
(async () => {
|
||||
const fileMetadata = Object.assign(output.artifact, {
|
||||
messageId: metadata.run_id,
|
||||
toolCallId: output.tool_call_id,
|
||||
conversationId: metadata.thread_id,
|
||||
});
|
||||
if (!res.headersSent) {
|
||||
return fileMetadata;
|
||||
}
|
||||
|
||||
if (!fileMetadata) {
|
||||
return null;
|
||||
}
|
||||
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(fileMetadata)}\n\n`);
|
||||
return fileMetadata;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing code output:', error);
|
||||
return null;
|
||||
}),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (output.artifact.content) {
|
||||
/** @type {FormattedContent[]} */
|
||||
const content = output.artifact.content;
|
||||
|
|
@ -262,7 +253,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
const { url } = part.image_url;
|
||||
artifactPromises.push(
|
||||
(async () => {
|
||||
const filename = `${output.tool_call_id}-image-${new Date().getTime()}`;
|
||||
const filename = `${output.name}_${output.tool_call_id}_img_${nanoid()}`;
|
||||
const file = await saveBase64Image(url, {
|
||||
req,
|
||||
filename,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,16 @@
|
|||
// validateVisionModel,
|
||||
// mapModelToAzureConfig,
|
||||
// } = require('librechat-data-provider');
|
||||
const { Callback, createMetadataAggregator } = require('@librechat/agents');
|
||||
require('events').EventEmitter.defaultMaxListeners = 100;
|
||||
const {
|
||||
Callback,
|
||||
GraphEvents,
|
||||
formatMessage,
|
||||
formatAgentMessages,
|
||||
formatContentStrings,
|
||||
getTokenCountForMessage,
|
||||
createMetadataAggregator,
|
||||
} = require('@librechat/agents');
|
||||
const {
|
||||
Constants,
|
||||
VisionModes,
|
||||
|
|
@ -17,36 +26,28 @@ const {
|
|||
KnownEndpoints,
|
||||
anthropicSchema,
|
||||
isAgentsEndpoint,
|
||||
bedrockOutputParser,
|
||||
AgentCapabilities,
|
||||
bedrockInputSchema,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
extractBaseURL,
|
||||
// constructAzureURL,
|
||||
// genAzureChatCompletion,
|
||||
} = require('~/utils');
|
||||
const {
|
||||
formatMessage,
|
||||
formatAgentMessages,
|
||||
formatContentStrings,
|
||||
createContextHandlers,
|
||||
} = require('~/app/clients/prompts');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { getCustomEndpointConfig, checkCapability } = require('~/server/services/Config');
|
||||
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const Tokenizer = require('~/server/services/Tokenizer');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const BaseClient = require('~/app/clients/BaseClient');
|
||||
const { logger, sendEvent } = require('~/config');
|
||||
const { createRun } = require('./run');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/** @typedef {import('@librechat/agents').MessageContentComplex} MessageContentComplex */
|
||||
/** @typedef {import('@langchain/core/runnables').RunnableConfig} RunnableConfig */
|
||||
|
||||
const providerParsers = {
|
||||
[EModelEndpoint.openAI]: openAISchema,
|
||||
[EModelEndpoint.azureOpenAI]: openAISchema,
|
||||
[EModelEndpoint.anthropic]: anthropicSchema,
|
||||
[EModelEndpoint.bedrock]: bedrockOutputParser,
|
||||
[EModelEndpoint.openAI]: openAISchema.parse,
|
||||
[EModelEndpoint.azureOpenAI]: openAISchema.parse,
|
||||
[EModelEndpoint.anthropic]: anthropicSchema.parse,
|
||||
[EModelEndpoint.bedrock]: bedrockInputSchema.parse,
|
||||
};
|
||||
|
||||
const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deepseek]);
|
||||
|
|
@ -102,6 +103,8 @@ class AgentClient extends BaseClient {
|
|||
this.outputTokensKey = 'output_tokens';
|
||||
/** @type {UsageMetadata} */
|
||||
this.usage;
|
||||
/** @type {Record<string, number>} */
|
||||
this.indexTokenCountMap = {};
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -191,7 +194,14 @@ class AgentClient extends BaseClient {
|
|||
: {};
|
||||
|
||||
if (parseOptions) {
|
||||
runOptions = parseOptions(this.options.agent.model_parameters);
|
||||
try {
|
||||
runOptions = parseOptions(this.options.agent.model_parameters);
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #getSaveOptions] Error parsing options',
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return removeNullishValues(
|
||||
|
|
@ -219,14 +229,23 @@ class AgentClient extends BaseClient {
|
|||
};
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {TMessage} message
|
||||
* @param {Array<MongoFile>} attachments
|
||||
* @returns {Promise<Array<Partial<MongoFile>>>}
|
||||
*/
|
||||
async addImageURLs(message, attachments) {
|
||||
const { files, image_urls } = await encodeAndFormat(
|
||||
const { files, text, image_urls } = await encodeAndFormat(
|
||||
this.options.req,
|
||||
attachments,
|
||||
this.options.agent.provider,
|
||||
VisionModes.agents,
|
||||
);
|
||||
message.image_urls = image_urls.length ? image_urls : undefined;
|
||||
if (text && text.length) {
|
||||
message.ocr = text;
|
||||
}
|
||||
return files;
|
||||
}
|
||||
|
||||
|
|
@ -304,7 +323,21 @@ class AgentClient extends BaseClient {
|
|||
assistantName: this.options?.modelLabel,
|
||||
});
|
||||
|
||||
const needsTokenCount = this.contextStrategy && !orderedMessages[i].tokenCount;
|
||||
if (message.ocr && i !== orderedMessages.length - 1) {
|
||||
if (typeof formattedMessage.content === 'string') {
|
||||
formattedMessage.content = message.ocr + '\n' + formattedMessage.content;
|
||||
} else {
|
||||
const textPart = formattedMessage.content.find((part) => part.type === 'text');
|
||||
textPart
|
||||
? (textPart.text = message.ocr + '\n' + textPart.text)
|
||||
: formattedMessage.content.unshift({ type: 'text', text: message.ocr });
|
||||
}
|
||||
} else if (message.ocr && i === orderedMessages.length - 1) {
|
||||
systemContent = [systemContent, message.ocr].join('\n');
|
||||
}
|
||||
|
||||
const needsTokenCount =
|
||||
(this.contextStrategy && !orderedMessages[i].tokenCount) || message.ocr;
|
||||
|
||||
/* If tokens were never counted, or, is a Vision request and the message has files, count again */
|
||||
if (needsTokenCount || (this.isVisionModel && (message.image_urls || message.files))) {
|
||||
|
|
@ -350,6 +383,10 @@ class AgentClient extends BaseClient {
|
|||
}));
|
||||
}
|
||||
|
||||
for (let i = 0; i < messages.length; i++) {
|
||||
this.indexTokenCountMap[i] = messages[i].tokenCount;
|
||||
}
|
||||
|
||||
const result = {
|
||||
tokenCountMap,
|
||||
prompt: payload,
|
||||
|
|
@ -384,15 +421,34 @@ class AgentClient extends BaseClient {
|
|||
if (!collectedUsage || !collectedUsage.length) {
|
||||
return;
|
||||
}
|
||||
const input_tokens = collectedUsage[0]?.input_tokens || 0;
|
||||
const input_tokens =
|
||||
(collectedUsage[0]?.input_tokens || 0) +
|
||||
(Number(collectedUsage[0]?.input_token_details?.cache_creation) || 0) +
|
||||
(Number(collectedUsage[0]?.input_token_details?.cache_read) || 0);
|
||||
|
||||
let output_tokens = 0;
|
||||
let previousTokens = input_tokens; // Start with original input
|
||||
for (let i = 0; i < collectedUsage.length; i++) {
|
||||
const usage = collectedUsage[i];
|
||||
if (!usage) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const cache_creation = Number(usage.input_token_details?.cache_creation) || 0;
|
||||
const cache_read = Number(usage.input_token_details?.cache_read) || 0;
|
||||
|
||||
const txMetadata = {
|
||||
context,
|
||||
conversationId: this.conversationId,
|
||||
user: this.user ?? this.options.req.user?.id,
|
||||
endpointTokenConfig: this.options.endpointTokenConfig,
|
||||
model: usage.model ?? model ?? this.model ?? this.options.agent.model_parameters.model,
|
||||
};
|
||||
|
||||
if (i > 0) {
|
||||
// Count new tokens generated (input_tokens minus previous accumulated tokens)
|
||||
output_tokens += (Number(usage.input_tokens) || 0) - previousTokens;
|
||||
output_tokens +=
|
||||
(Number(usage.input_tokens) || 0) + cache_creation + cache_read - previousTokens;
|
||||
}
|
||||
|
||||
// Add this message's output tokens
|
||||
|
|
@ -400,16 +456,26 @@ class AgentClient extends BaseClient {
|
|||
|
||||
// Update previousTokens to include this message's output
|
||||
previousTokens += Number(usage.output_tokens) || 0;
|
||||
spendTokens(
|
||||
{
|
||||
context,
|
||||
conversationId: this.conversationId,
|
||||
user: this.user ?? this.options.req.user?.id,
|
||||
endpointTokenConfig: this.options.endpointTokenConfig,
|
||||
model: usage.model ?? model ?? this.model ?? this.options.agent.model_parameters.model,
|
||||
},
|
||||
{ promptTokens: usage.input_tokens, completionTokens: usage.output_tokens },
|
||||
).catch((err) => {
|
||||
|
||||
if (cache_creation > 0 || cache_read > 0) {
|
||||
spendStructuredTokens(txMetadata, {
|
||||
promptTokens: {
|
||||
input: usage.input_tokens,
|
||||
write: cache_creation,
|
||||
read: cache_read,
|
||||
},
|
||||
completionTokens: usage.output_tokens,
|
||||
}).catch((err) => {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending structured tokens',
|
||||
err,
|
||||
);
|
||||
});
|
||||
}
|
||||
spendTokens(txMetadata, {
|
||||
promptTokens: usage.input_tokens,
|
||||
completionTokens: usage.output_tokens,
|
||||
}).catch((err) => {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending tokens',
|
||||
err,
|
||||
|
|
@ -477,19 +543,6 @@ class AgentClient extends BaseClient {
|
|||
abortController = new AbortController();
|
||||
}
|
||||
|
||||
const baseURL = extractBaseURL(this.completionsUrl);
|
||||
logger.debug('[api/server/controllers/agents/client.js] chatCompletion', {
|
||||
baseURL,
|
||||
payload,
|
||||
});
|
||||
|
||||
// if (this.useOpenRouter) {
|
||||
// opts.defaultHeaders = {
|
||||
// 'HTTP-Referer': 'https://librechat.ai',
|
||||
// 'X-Title': 'LibreChat',
|
||||
// };
|
||||
// }
|
||||
|
||||
// if (this.options.headers) {
|
||||
// opts.defaultHeaders = { ...opts.defaultHeaders, ...this.options.headers };
|
||||
// }
|
||||
|
|
@ -579,6 +632,9 @@ class AgentClient extends BaseClient {
|
|||
// });
|
||||
// }
|
||||
|
||||
/** @type {TCustomConfig['endpoints']['agents']} */
|
||||
const agentsEConfig = this.options.req.app.locals[EModelEndpoint.agents];
|
||||
|
||||
/** @type {Partial<RunnableConfig> & { version: 'v1' | 'v2'; run_id?: string; streamMode: string }} */
|
||||
const config = {
|
||||
configurable: {
|
||||
|
|
@ -586,19 +642,30 @@ class AgentClient extends BaseClient {
|
|||
last_agent_index: this.agentConfigs?.size ?? 0,
|
||||
hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
|
||||
},
|
||||
recursionLimit: this.options.req.app.locals[EModelEndpoint.agents]?.recursionLimit,
|
||||
recursionLimit: agentsEConfig?.recursionLimit,
|
||||
signal: abortController.signal,
|
||||
streamMode: 'values',
|
||||
version: 'v2',
|
||||
};
|
||||
|
||||
const initialMessages = formatAgentMessages(payload);
|
||||
const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name));
|
||||
let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages(
|
||||
payload,
|
||||
this.indexTokenCountMap,
|
||||
toolSet,
|
||||
);
|
||||
if (legacyContentEndpoints.has(this.options.agent.endpoint)) {
|
||||
formatContentStrings(initialMessages);
|
||||
initialMessages = formatContentStrings(initialMessages);
|
||||
}
|
||||
|
||||
/** @type {ReturnType<createRun>} */
|
||||
let run;
|
||||
const countTokens = ((text) => this.getTokenCount(text)).bind(this);
|
||||
|
||||
/** @type {(message: BaseMessage) => number} */
|
||||
const tokenCounter = (message) => {
|
||||
return getTokenCountForMessage(message, countTokens);
|
||||
};
|
||||
|
||||
/**
|
||||
*
|
||||
|
|
@ -606,12 +673,23 @@ class AgentClient extends BaseClient {
|
|||
* @param {BaseMessage[]} messages
|
||||
* @param {number} [i]
|
||||
* @param {TMessageContentParts[]} [contentData]
|
||||
* @param {Record<string, number>} [currentIndexCountMap]
|
||||
*/
|
||||
const runAgent = async (agent, messages, i = 0, contentData = []) => {
|
||||
const runAgent = async (agent, _messages, i = 0, contentData = [], _currentIndexCountMap) => {
|
||||
config.configurable.model = agent.model_parameters.model;
|
||||
const currentIndexCountMap = _currentIndexCountMap ?? indexTokenCountMap;
|
||||
if (i > 0) {
|
||||
this.model = agent.model_parameters.model;
|
||||
}
|
||||
if (agent.recursion_limit && typeof agent.recursion_limit === 'number') {
|
||||
config.recursionLimit = agent.recursion_limit;
|
||||
}
|
||||
if (
|
||||
agentsEConfig?.maxRecursionLimit &&
|
||||
config.recursionLimit > agentsEConfig?.maxRecursionLimit
|
||||
) {
|
||||
config.recursionLimit = agentsEConfig?.maxRecursionLimit;
|
||||
}
|
||||
config.configurable.agent_id = agent.id;
|
||||
config.configurable.name = agent.name;
|
||||
config.configurable.agent_index = i;
|
||||
|
|
@ -626,7 +704,7 @@ class AgentClient extends BaseClient {
|
|||
let systemContent = [
|
||||
systemMessage,
|
||||
agent.instructions ?? '',
|
||||
i !== 0 ? agent.additional_instructions ?? '' : '',
|
||||
i !== 0 ? (agent.additional_instructions ?? '') : '',
|
||||
]
|
||||
.join('\n')
|
||||
.trim();
|
||||
|
|
@ -640,12 +718,21 @@ class AgentClient extends BaseClient {
|
|||
}
|
||||
|
||||
if (noSystemMessages === true && systemContent?.length) {
|
||||
let latestMessage = messages.pop().content;
|
||||
let latestMessage = _messages.pop().content;
|
||||
if (typeof latestMessage !== 'string') {
|
||||
latestMessage = latestMessage[0].text;
|
||||
}
|
||||
latestMessage = [systemContent, latestMessage].join('\n');
|
||||
messages.push(new HumanMessage(latestMessage));
|
||||
_messages.push(new HumanMessage(latestMessage));
|
||||
}
|
||||
|
||||
let messages = _messages;
|
||||
if (
|
||||
agent.model_parameters?.clientOptions?.defaultHeaders?.['anthropic-beta']?.includes(
|
||||
'prompt-caching',
|
||||
)
|
||||
) {
|
||||
messages = addCacheControl(messages);
|
||||
}
|
||||
|
||||
run = await createRun({
|
||||
|
|
@ -665,11 +752,29 @@ class AgentClient extends BaseClient {
|
|||
}
|
||||
|
||||
if (contentData.length) {
|
||||
const agentUpdate = {
|
||||
type: ContentTypes.AGENT_UPDATE,
|
||||
[ContentTypes.AGENT_UPDATE]: {
|
||||
index: contentData.length,
|
||||
runId: this.responseMessageId,
|
||||
agentId: agent.id,
|
||||
},
|
||||
};
|
||||
const streamData = {
|
||||
event: GraphEvents.ON_AGENT_UPDATE,
|
||||
data: agentUpdate,
|
||||
};
|
||||
this.options.aggregateContent(streamData);
|
||||
sendEvent(this.options.res, streamData);
|
||||
contentData.push(agentUpdate);
|
||||
run.Graph.contentData = contentData;
|
||||
}
|
||||
|
||||
await run.processStream({ messages }, config, {
|
||||
keepContent: i !== 0,
|
||||
tokenCounter,
|
||||
indexTokenCountMap: currentIndexCountMap,
|
||||
maxContextTokens: agent.maxContextTokens,
|
||||
callbacks: {
|
||||
[Callback.TOOL_ERROR]: (graph, error, toolId) => {
|
||||
logger.error(
|
||||
|
|
@ -683,9 +788,13 @@ class AgentClient extends BaseClient {
|
|||
};
|
||||
|
||||
await runAgent(this.options.agent, initialMessages);
|
||||
|
||||
let finalContentStart = 0;
|
||||
if (this.agentConfigs && this.agentConfigs.size > 0) {
|
||||
if (
|
||||
this.agentConfigs &&
|
||||
this.agentConfigs.size > 0 &&
|
||||
(await checkCapability(this.options.req, AgentCapabilities.chain))
|
||||
) {
|
||||
const windowSize = 5;
|
||||
let latestMessage = initialMessages.pop().content;
|
||||
if (typeof latestMessage !== 'string') {
|
||||
latestMessage = latestMessage[0].text;
|
||||
|
|
@ -693,7 +802,16 @@ class AgentClient extends BaseClient {
|
|||
let i = 1;
|
||||
let runMessages = [];
|
||||
|
||||
const lastFiveMessages = initialMessages.slice(-5);
|
||||
const windowIndexCountMap = {};
|
||||
const windowMessages = initialMessages.slice(-windowSize);
|
||||
let currentIndex = 4;
|
||||
for (let i = initialMessages.length - 1; i >= 0; i--) {
|
||||
windowIndexCountMap[currentIndex] = indexTokenCountMap[i];
|
||||
currentIndex--;
|
||||
if (currentIndex < 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (const [agentId, agent] of this.agentConfigs) {
|
||||
if (abortController.signal.aborted === true) {
|
||||
break;
|
||||
|
|
@ -728,7 +846,9 @@ class AgentClient extends BaseClient {
|
|||
}
|
||||
try {
|
||||
const contextMessages = [];
|
||||
for (const message of lastFiveMessages) {
|
||||
const runIndexCountMap = {};
|
||||
for (let i = 0; i < windowMessages.length; i++) {
|
||||
const message = windowMessages[i];
|
||||
const messageType = message._getType();
|
||||
if (
|
||||
(!agent.tools || agent.tools.length === 0) &&
|
||||
|
|
@ -736,11 +856,13 @@ class AgentClient extends BaseClient {
|
|||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
runIndexCountMap[contextMessages.length] = windowIndexCountMap[i];
|
||||
contextMessages.push(message);
|
||||
}
|
||||
const currentMessages = [...contextMessages, new HumanMessage(bufferString)];
|
||||
await runAgent(agent, currentMessages, i, contentData);
|
||||
const bufferMessage = new HumanMessage(bufferString);
|
||||
runIndexCountMap[contextMessages.length] = tokenCounter(bufferMessage);
|
||||
const currentMessages = [...contextMessages, bufferMessage];
|
||||
await runAgent(agent, currentMessages, i, contentData, runIndexCountMap);
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`[api/server/controllers/agents/client.js #chatCompletion] Error running agent ${agentId} (${i})`,
|
||||
|
|
@ -751,6 +873,7 @@ class AgentClient extends BaseClient {
|
|||
}
|
||||
}
|
||||
|
||||
/** Note: not implemented */
|
||||
if (config.configurable.hide_sequential_outputs !== true) {
|
||||
finalContentStart = 0;
|
||||
}
|
||||
|
|
@ -774,18 +897,20 @@ class AgentClient extends BaseClient {
|
|||
);
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #sendCompletion] Operation aborted',
|
||||
err,
|
||||
);
|
||||
if (!abortController.signal.aborted) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #sendCompletion] Unhandled error type',
|
||||
err,
|
||||
);
|
||||
throw err;
|
||||
this.contentParts.push({
|
||||
type: ContentTypes.ERROR,
|
||||
[ContentTypes.ERROR]: `An error occurred while processing the request${err?.message ? `: ${err.message}` : ''}`,
|
||||
});
|
||||
}
|
||||
|
||||
logger.warn(
|
||||
'[api/server/controllers/agents/client.js #sendCompletion] Operation aborted',
|
||||
err,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -800,14 +925,20 @@ class AgentClient extends BaseClient {
|
|||
throw new Error('Run not initialized');
|
||||
}
|
||||
const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator();
|
||||
const clientOptions = {};
|
||||
const providerConfig = this.options.req.app.locals[this.options.agent.provider];
|
||||
/** @type {import('@librechat/agents').ClientOptions} */
|
||||
const clientOptions = {
|
||||
maxTokens: 75,
|
||||
};
|
||||
let endpointConfig = this.options.req.app.locals[this.options.agent.endpoint];
|
||||
if (!endpointConfig) {
|
||||
endpointConfig = await getCustomEndpointConfig(this.options.agent.endpoint);
|
||||
}
|
||||
if (
|
||||
providerConfig &&
|
||||
providerConfig.titleModel &&
|
||||
providerConfig.titleModel !== Constants.CURRENT_MODEL
|
||||
endpointConfig &&
|
||||
endpointConfig.titleModel &&
|
||||
endpointConfig.titleModel !== Constants.CURRENT_MODEL
|
||||
) {
|
||||
clientOptions.model = providerConfig.titleModel;
|
||||
clientOptions.model = endpointConfig.titleModel;
|
||||
}
|
||||
try {
|
||||
const titleResult = await this.run.generateTitle({
|
||||
|
|
|
|||
|
|
@ -142,7 +142,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
conversationId,
|
||||
sender,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: userMessageId ?? parentMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId,
|
||||
}).catch((err) => {
|
||||
logger.error('[api/server/controllers/agents/request] Error in `handleAbortError`', err);
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const { Run, Providers } = require('@librechat/agents');
|
||||
const { providerEndpointMap } = require('librechat-data-provider');
|
||||
const { providerEndpointMap, KnownEndpoints } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
* @typedef {import('@librechat/agents').t} t
|
||||
|
|
@ -7,6 +7,7 @@ const { providerEndpointMap } = require('librechat-data-provider');
|
|||
* @typedef {import('@librechat/agents').StreamEventData} StreamEventData
|
||||
* @typedef {import('@librechat/agents').EventHandler} EventHandler
|
||||
* @typedef {import('@librechat/agents').GraphEvents} GraphEvents
|
||||
* @typedef {import('@librechat/agents').LLMConfig} LLMConfig
|
||||
* @typedef {import('@librechat/agents').IState} IState
|
||||
*/
|
||||
|
||||
|
|
@ -32,6 +33,7 @@ async function createRun({
|
|||
streamUsage = true,
|
||||
}) {
|
||||
const provider = providerEndpointMap[agent.provider] ?? agent.provider;
|
||||
/** @type {LLMConfig} */
|
||||
const llmConfig = Object.assign(
|
||||
{
|
||||
provider,
|
||||
|
|
@ -41,6 +43,14 @@ async function createRun({
|
|||
agent.model_parameters,
|
||||
);
|
||||
|
||||
/** @type {'reasoning_content' | 'reasoning'} */
|
||||
let reasoningKey;
|
||||
if (
|
||||
llmConfig.configuration?.baseURL?.includes(KnownEndpoints.openrouter) ||
|
||||
(agent.endpoint && agent.endpoint.toLowerCase().includes(KnownEndpoints.openrouter))
|
||||
) {
|
||||
reasoningKey = 'reasoning';
|
||||
}
|
||||
if (/o1(?!-(?:mini|preview)).*$/.test(llmConfig.model)) {
|
||||
llmConfig.streaming = false;
|
||||
llmConfig.disableStreaming = true;
|
||||
|
|
@ -50,6 +60,7 @@ async function createRun({
|
|||
const graphConfig = {
|
||||
signal,
|
||||
llmConfig,
|
||||
reasoningKey,
|
||||
tools: agent.tools,
|
||||
instructions: agent.instructions,
|
||||
additional_instructions: agent.additional_instructions,
|
||||
|
|
@ -57,7 +68,7 @@ async function createRun({
|
|||
};
|
||||
|
||||
// TEMPORARY FOR TESTING
|
||||
if (agent.provider === Providers.ANTHROPIC) {
|
||||
if (agent.provider === Providers.ANTHROPIC || agent.provider === Providers.BEDROCK) {
|
||||
graphConfig.streamBuffer = 2000;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
const fs = require('fs').promises;
|
||||
const { nanoid } = require('nanoid');
|
||||
const {
|
||||
FileContext,
|
||||
Constants,
|
||||
Tools,
|
||||
Constants,
|
||||
FileContext,
|
||||
SystemRoles,
|
||||
EToolResources,
|
||||
actionDelimiter,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
|
|
@ -203,14 +204,21 @@ const duplicateAgentHandler = async (req, res) => {
|
|||
}
|
||||
|
||||
const {
|
||||
_id: __id,
|
||||
id: _id,
|
||||
_id: __id,
|
||||
author: _author,
|
||||
createdAt: _createdAt,
|
||||
updatedAt: _updatedAt,
|
||||
tool_resources: _tool_resources = {},
|
||||
...cloneData
|
||||
} = agent;
|
||||
|
||||
if (_tool_resources?.[EToolResources.ocr]) {
|
||||
cloneData.tool_resources = {
|
||||
[EToolResources.ocr]: _tool_resources[EToolResources.ocr],
|
||||
};
|
||||
}
|
||||
|
||||
const newAgentId = `agent_${nanoid()}`;
|
||||
const newAgentData = Object.assign(cloneData, {
|
||||
id: newAgentId,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
const { generate2FATempToken } = require('~/server/services/twoFactorService');
|
||||
const { setAuthTokens } = require('~/server/services/AuthService');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
|
|
@ -7,7 +8,12 @@ const loginController = async (req, res) => {
|
|||
return res.status(400).json({ message: 'Invalid credentials' });
|
||||
}
|
||||
|
||||
const { password: _, __v, ...user } = req.user;
|
||||
if (req.user.twoFactorEnabled) {
|
||||
const tempToken = generate2FATempToken(req.user._id);
|
||||
return res.status(200).json({ twoFAPending: true, tempToken });
|
||||
}
|
||||
|
||||
const { password: _p, totpSecret: _t, __v, ...user } = req.user;
|
||||
user.id = user._id.toString();
|
||||
|
||||
const token = await setAuthTokens(req.user._id, res);
|
||||
|
|
|
|||
60
api/server/controllers/auth/TwoFactorAuthController.js
Normal file
60
api/server/controllers/auth/TwoFactorAuthController.js
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
const jwt = require('jsonwebtoken');
|
||||
const {
|
||||
verifyTOTP,
|
||||
verifyBackupCode,
|
||||
getTOTPSecret,
|
||||
} = require('~/server/services/twoFactorService');
|
||||
const { setAuthTokens } = require('~/server/services/AuthService');
|
||||
const { getUserById } = require('~/models/userMethods');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Verifies the 2FA code during login using a temporary token.
|
||||
*/
|
||||
const verify2FAWithTempToken = async (req, res) => {
|
||||
try {
|
||||
const { tempToken, token, backupCode } = req.body;
|
||||
if (!tempToken) {
|
||||
return res.status(400).json({ message: 'Missing temporary token' });
|
||||
}
|
||||
|
||||
let payload;
|
||||
try {
|
||||
payload = jwt.verify(tempToken, process.env.JWT_SECRET);
|
||||
} catch (err) {
|
||||
return res.status(401).json({ message: 'Invalid or expired temporary token' });
|
||||
}
|
||||
|
||||
const user = await getUserById(payload.userId);
|
||||
if (!user || !user.twoFactorEnabled) {
|
||||
return res.status(400).json({ message: '2FA is not enabled for this user' });
|
||||
}
|
||||
|
||||
const secret = await getTOTPSecret(user.totpSecret);
|
||||
let isVerified = false;
|
||||
if (token) {
|
||||
isVerified = await verifyTOTP(secret, token);
|
||||
} else if (backupCode) {
|
||||
isVerified = await verifyBackupCode({ user, backupCode });
|
||||
}
|
||||
|
||||
if (!isVerified) {
|
||||
return res.status(401).json({ message: 'Invalid 2FA code or backup code' });
|
||||
}
|
||||
|
||||
// Prepare user data to return (omit sensitive fields).
|
||||
const userData = user.toObject ? user.toObject() : { ...user };
|
||||
delete userData.password;
|
||||
delete userData.__v;
|
||||
delete userData.totpSecret;
|
||||
userData.id = user._id.toString();
|
||||
|
||||
const authToken = await setAuthTokens(user._id, res);
|
||||
return res.status(200).json({ token: authToken, user: userData });
|
||||
} catch (err) {
|
||||
logger.error('[verify2FAWithTempToken]', err);
|
||||
return res.status(500).json({ message: 'Something went wrong' });
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = { verify2FAWithTempToken };
|
||||
|
|
@ -1,10 +1,18 @@
|
|||
const { nanoid } = require('nanoid');
|
||||
const { EnvVar } = require('@librechat/agents');
|
||||
const { Tools, AuthType, ToolCallTypes } = require('librechat-data-provider');
|
||||
const {
|
||||
Tools,
|
||||
AuthType,
|
||||
Permissions,
|
||||
ToolCallTypes,
|
||||
PermissionTypes,
|
||||
} = require('librechat-data-provider');
|
||||
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
|
||||
const { processCodeOutput } = require('~/server/services/Files/Code/process');
|
||||
const { loadAuthValues, loadTools } = require('~/app/clients/tools/util');
|
||||
const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { loadTools } = require('~/app/clients/tools/util');
|
||||
const { checkAccess } = require('~/server/middleware');
|
||||
const { getMessage } = require('~/models/Message');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
|
|
@ -12,6 +20,10 @@ const fieldsMap = {
|
|||
[Tools.execute_code]: [EnvVar.CODE_API_KEY],
|
||||
};
|
||||
|
||||
const toolAccessPermType = {
|
||||
[Tools.execute_code]: PermissionTypes.RUN_CODE,
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {ServerRequest} req - The request object, containing information about the HTTP request.
|
||||
* @param {ServerResponse} res - The response object, used to send back the desired HTTP response.
|
||||
|
|
@ -58,6 +70,7 @@ const verifyToolAuth = async (req, res) => {
|
|||
/**
|
||||
* @param {ServerRequest} req - The request object, containing information about the HTTP request.
|
||||
* @param {ServerResponse} res - The response object, used to send back the desired HTTP response.
|
||||
* @param {NextFunction} next - The next middleware function to call.
|
||||
* @returns {Promise<void>} A promise that resolves when the function has completed.
|
||||
*/
|
||||
const callTool = async (req, res) => {
|
||||
|
|
@ -83,6 +96,16 @@ const callTool = async (req, res) => {
|
|||
return;
|
||||
}
|
||||
logger.debug(`[${toolId}/call] User: ${req.user.id}`);
|
||||
let hasAccess = true;
|
||||
if (toolAccessPermType[toolId]) {
|
||||
hasAccess = await checkAccess(req.user, toolAccessPermType[toolId], [Permissions.USE]);
|
||||
}
|
||||
if (!hasAccess) {
|
||||
logger.warn(
|
||||
`[${toolAccessPermType[toolId]}] Forbidden: Insufficient permissions for User ${req.user.id}: ${Permissions.USE}`,
|
||||
);
|
||||
return res.status(403).json({ message: 'Forbidden: Insufficient permissions' });
|
||||
}
|
||||
const { loadedTools } = await loadTools({
|
||||
user: req.user.id,
|
||||
tools: [toolId],
|
||||
|
|
|
|||
|
|
@ -22,10 +22,11 @@ const staticCache = require('./utils/staticCache');
|
|||
const noIndex = require('./middleware/noIndex');
|
||||
const routes = require('./routes');
|
||||
|
||||
const { PORT, HOST, ALLOW_SOCIAL_LOGIN, DISABLE_COMPRESSION } = process.env ?? {};
|
||||
const { PORT, HOST, ALLOW_SOCIAL_LOGIN, DISABLE_COMPRESSION, TRUST_PROXY } = process.env ?? {};
|
||||
|
||||
const port = Number(PORT) || 3080;
|
||||
const host = HOST || 'localhost';
|
||||
const trusted_proxy = Number(TRUST_PROXY) || 1; /* trust first proxy by default */
|
||||
|
||||
const startServer = async () => {
|
||||
if (typeof Bun !== 'undefined') {
|
||||
|
|
@ -53,7 +54,7 @@ const startServer = async () => {
|
|||
app.use(staticCache(app.locals.paths.dist));
|
||||
app.use(staticCache(app.locals.paths.fonts));
|
||||
app.use(staticCache(app.locals.paths.assets));
|
||||
app.set('trust proxy', 1); /* trust first proxy */
|
||||
app.set('trust proxy', trusted_proxy);
|
||||
app.use(cors());
|
||||
app.use(cookieParser());
|
||||
|
||||
|
|
@ -145,6 +146,18 @@ process.on('uncaughtException', (err) => {
|
|||
logger.error('There was an uncaught error:', err);
|
||||
}
|
||||
|
||||
if (err.message.includes('abort')) {
|
||||
logger.warn('There was an uncatchable AbortController error.');
|
||||
return;
|
||||
}
|
||||
|
||||
if (err.message.includes('GoogleGenerativeAI')) {
|
||||
logger.warn(
|
||||
'\n\n`GoogleGenerativeAI` errors cannot be caught due to an upstream issue, see: https://github.com/google-gemini/generative-ai-js/issues/303',
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (err.message.includes('fetch failed')) {
|
||||
if (messageCount === 0) {
|
||||
logger.warn('Meilisearch error, search will be disabled');
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
|
|||
{ promptTokens, completionTokens },
|
||||
);
|
||||
|
||||
saveMessage(
|
||||
await saveMessage(
|
||||
req,
|
||||
{ ...responseMessage, user },
|
||||
{ context: 'api/server/middleware/abortMiddleware.js' },
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ const openAI = require('~/server/services/Endpoints/openAI');
|
|||
const agents = require('~/server/services/Endpoints/agents');
|
||||
const custom = require('~/server/services/Endpoints/custom');
|
||||
const google = require('~/server/services/Endpoints/google');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { handleError } = require('~/server/utils');
|
||||
|
||||
const buildFunction = {
|
||||
|
|
@ -87,16 +86,8 @@ async function buildEndpointOption(req, res, next) {
|
|||
|
||||
// TODO: use `getModelsConfig` only when necessary
|
||||
const modelsConfig = await getModelsConfig(req);
|
||||
const { resendFiles = true } = req.body.endpointOption;
|
||||
req.body.endpointOption.modelsConfig = modelsConfig;
|
||||
if (isAgents && resendFiles && req.body.conversationId) {
|
||||
const fileIds = await getConvoFiles(req.body.conversationId);
|
||||
const requestFiles = req.body.files ?? [];
|
||||
if (requestFiles.length || fileIds.length) {
|
||||
req.body.endpointOption.attachments = processFiles(requestFiles, fileIds);
|
||||
}
|
||||
} else if (req.body.files) {
|
||||
// hold the promise
|
||||
if (req.body.files && !isAgents) {
|
||||
req.body.endpointOption.attachments = processFiles(req.body.files);
|
||||
}
|
||||
next();
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ const banResponse = async (req, res) => {
|
|||
* @function
|
||||
* @param {Object} req - Express request object.
|
||||
* @param {Object} res - Express response object.
|
||||
* @param {Function} next - Next middleware function.
|
||||
* @param {import('express').NextFunction} next - Next middleware function.
|
||||
*
|
||||
* @returns {Promise<function|Object>} - Returns a Promise which when resolved calls next middleware if user or source IP is not banned. Otherwise calls `banResponse()` and sets ban details in `banCache`.
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ const {
|
|||
* @function
|
||||
* @param {Object} req - Express request object containing user information.
|
||||
* @param {Object} res - Express response object.
|
||||
* @param {function} next - Express next middleware function.
|
||||
* @param {import('express').NextFunction} next - Next middleware function.
|
||||
* @throws {Error} Throws an error if the user exceeds the concurrent request limit.
|
||||
*/
|
||||
const concurrentLimiter = async (req, res, next) => {
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ const checkInviteUser = require('./checkInviteUser');
|
|||
const requireJwtAuth = require('./requireJwtAuth');
|
||||
const validateModel = require('./validateModel');
|
||||
const moderateText = require('./moderateText');
|
||||
const logHeaders = require('./logHeaders');
|
||||
const setHeaders = require('./setHeaders');
|
||||
const validate = require('./validate');
|
||||
const limiters = require('./limiters');
|
||||
|
|
@ -31,6 +32,7 @@ module.exports = {
|
|||
checkBan,
|
||||
uaParser,
|
||||
setHeaders,
|
||||
logHeaders,
|
||||
moderateText,
|
||||
validateModel,
|
||||
requireJwtAuth,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
const Keyv = require('keyv');
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const keyvRedis = require('~/cache/keyvRedis');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const IMPORT_IP_MAX = parseInt(process.env.IMPORT_IP_MAX) || 100;
|
||||
|
|
@ -48,21 +53,39 @@ const createImportLimiters = () => {
|
|||
const { importIpWindowMs, importIpMax, importUserWindowMs, importUserMax } =
|
||||
getEnvironmentVariables();
|
||||
|
||||
const importIpLimiter = rateLimit({
|
||||
const ipLimiterOptions = {
|
||||
windowMs: importIpWindowMs,
|
||||
max: importIpMax,
|
||||
handler: createImportHandler(),
|
||||
});
|
||||
|
||||
const importUserLimiter = rateLimit({
|
||||
};
|
||||
const userLimiterOptions = {
|
||||
windowMs: importUserWindowMs,
|
||||
max: importUserMax,
|
||||
handler: createImportHandler(false),
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id; // Use the user ID or NULL if not available
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS)) {
|
||||
logger.debug('Using Redis for import rate limiters.');
|
||||
const keyv = new Keyv({ store: keyvRedis });
|
||||
const client = keyv.opts.store.redis;
|
||||
const sendCommand = (...args) => client.call(...args);
|
||||
const ipStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'import_ip_limiter:',
|
||||
});
|
||||
const userStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'import_user_limiter:',
|
||||
});
|
||||
ipLimiterOptions.store = ipStore;
|
||||
userLimiterOptions.store = userStore;
|
||||
}
|
||||
|
||||
const importIpLimiter = rateLimit(ipLimiterOptions);
|
||||
const importUserLimiter = rateLimit(userLimiterOptions);
|
||||
return { importIpLimiter, importUserLimiter };
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,10 @@
|
|||
const Keyv = require('keyv');
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { removePorts } = require('~/server/utils');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { removePorts, isEnabled } = require('~/server/utils');
|
||||
const keyvRedis = require('~/cache/keyvRedis');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env;
|
||||
const windowMs = LOGIN_WINDOW * 60 * 1000;
|
||||
|
|
@ -20,11 +24,25 @@ const handler = async (req, res) => {
|
|||
return res.status(429).json({ message });
|
||||
};
|
||||
|
||||
const loginLimiter = rateLimit({
|
||||
const limiterOptions = {
|
||||
windowMs,
|
||||
max,
|
||||
handler,
|
||||
keyGenerator: removePorts,
|
||||
});
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS)) {
|
||||
logger.debug('Using Redis for login rate limiter.');
|
||||
const keyv = new Keyv({ store: keyvRedis });
|
||||
const client = keyv.opts.store.redis;
|
||||
const sendCommand = (...args) => client.call(...args);
|
||||
const store = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'login_limiter:',
|
||||
});
|
||||
limiterOptions.store = store;
|
||||
}
|
||||
|
||||
const loginLimiter = rateLimit(limiterOptions);
|
||||
|
||||
module.exports = loginLimiter;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
const Keyv = require('keyv');
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const denyRequest = require('~/server/middleware/denyRequest');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const keyvRedis = require('~/cache/keyvRedis');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const {
|
||||
MESSAGE_IP_MAX = 40,
|
||||
|
|
@ -41,25 +46,49 @@ const createHandler = (ip = true) => {
|
|||
};
|
||||
|
||||
/**
|
||||
* Message request rate limiter by IP
|
||||
* Message request rate limiters
|
||||
*/
|
||||
const messageIpLimiter = rateLimit({
|
||||
const ipLimiterOptions = {
|
||||
windowMs: ipWindowMs,
|
||||
max: ipMax,
|
||||
handler: createHandler(),
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Message request rate limiter by userId
|
||||
*/
|
||||
const messageUserLimiter = rateLimit({
|
||||
const userLimiterOptions = {
|
||||
windowMs: userWindowMs,
|
||||
max: userMax,
|
||||
handler: createHandler(false),
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id; // Use the user ID or NULL if not available
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS)) {
|
||||
logger.debug('Using Redis for message rate limiters.');
|
||||
const keyv = new Keyv({ store: keyvRedis });
|
||||
const client = keyv.opts.store.redis;
|
||||
const sendCommand = (...args) => client.call(...args);
|
||||
const ipStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'message_ip_limiter:',
|
||||
});
|
||||
const userStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'message_user_limiter:',
|
||||
});
|
||||
ipLimiterOptions.store = ipStore;
|
||||
userLimiterOptions.store = userStore;
|
||||
}
|
||||
|
||||
/**
|
||||
* Message request rate limiter by IP
|
||||
*/
|
||||
const messageIpLimiter = rateLimit(ipLimiterOptions);
|
||||
|
||||
/**
|
||||
* Message request rate limiter by userId
|
||||
*/
|
||||
const messageUserLimiter = rateLimit(userLimiterOptions);
|
||||
|
||||
module.exports = {
|
||||
messageIpLimiter,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,10 @@
|
|||
const Keyv = require('keyv');
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { removePorts } = require('~/server/utils');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { removePorts, isEnabled } = require('~/server/utils');
|
||||
const keyvRedis = require('~/cache/keyvRedis');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env;
|
||||
const windowMs = REGISTER_WINDOW * 60 * 1000;
|
||||
|
|
@ -20,11 +24,25 @@ const handler = async (req, res) => {
|
|||
return res.status(429).json({ message });
|
||||
};
|
||||
|
||||
const registerLimiter = rateLimit({
|
||||
const limiterOptions = {
|
||||
windowMs,
|
||||
max,
|
||||
handler,
|
||||
keyGenerator: removePorts,
|
||||
});
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS)) {
|
||||
logger.debug('Using Redis for register rate limiter.');
|
||||
const keyv = new Keyv({ store: keyvRedis });
|
||||
const client = keyv.opts.store.redis;
|
||||
const sendCommand = (...args) => client.call(...args);
|
||||
const store = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'register_limiter:',
|
||||
});
|
||||
limiterOptions.store = store;
|
||||
}
|
||||
|
||||
const registerLimiter = rateLimit(limiterOptions);
|
||||
|
||||
module.exports = registerLimiter;
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
const Keyv = require('keyv');
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { removePorts } = require('~/server/utils');
|
||||
const { removePorts, isEnabled } = require('~/server/utils');
|
||||
const keyvRedis = require('~/cache/keyvRedis');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const {
|
||||
RESET_PASSWORD_WINDOW = 2,
|
||||
|
|
@ -25,11 +29,25 @@ const handler = async (req, res) => {
|
|||
return res.status(429).json({ message });
|
||||
};
|
||||
|
||||
const resetPasswordLimiter = rateLimit({
|
||||
const limiterOptions = {
|
||||
windowMs,
|
||||
max,
|
||||
handler,
|
||||
keyGenerator: removePorts,
|
||||
});
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS)) {
|
||||
logger.debug('Using Redis for reset password rate limiter.');
|
||||
const keyv = new Keyv({ store: keyvRedis });
|
||||
const client = keyv.opts.store.redis;
|
||||
const sendCommand = (...args) => client.call(...args);
|
||||
const store = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'reset_password_limiter:',
|
||||
});
|
||||
limiterOptions.store = store;
|
||||
}
|
||||
|
||||
const resetPasswordLimiter = rateLimit(limiterOptions);
|
||||
|
||||
module.exports = resetPasswordLimiter;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
const Keyv = require('keyv');
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const keyvRedis = require('~/cache/keyvRedis');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const STT_IP_MAX = parseInt(process.env.STT_IP_MAX) || 100;
|
||||
|
|
@ -47,20 +52,40 @@ const createSTTHandler = (ip = true) => {
|
|||
const createSTTLimiters = () => {
|
||||
const { sttIpWindowMs, sttIpMax, sttUserWindowMs, sttUserMax } = getEnvironmentVariables();
|
||||
|
||||
const sttIpLimiter = rateLimit({
|
||||
const ipLimiterOptions = {
|
||||
windowMs: sttIpWindowMs,
|
||||
max: sttIpMax,
|
||||
handler: createSTTHandler(),
|
||||
});
|
||||
};
|
||||
|
||||
const sttUserLimiter = rateLimit({
|
||||
const userLimiterOptions = {
|
||||
windowMs: sttUserWindowMs,
|
||||
max: sttUserMax,
|
||||
handler: createSTTHandler(false),
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id; // Use the user ID or NULL if not available
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS)) {
|
||||
logger.debug('Using Redis for STT rate limiters.');
|
||||
const keyv = new Keyv({ store: keyvRedis });
|
||||
const client = keyv.opts.store.redis;
|
||||
const sendCommand = (...args) => client.call(...args);
|
||||
const ipStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'stt_ip_limiter:',
|
||||
});
|
||||
const userStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'stt_user_limiter:',
|
||||
});
|
||||
ipLimiterOptions.store = ipStore;
|
||||
userLimiterOptions.store = userStore;
|
||||
}
|
||||
|
||||
const sttIpLimiter = rateLimit(ipLimiterOptions);
|
||||
const sttUserLimiter = rateLimit(userLimiterOptions);
|
||||
|
||||
return { sttIpLimiter, sttUserLimiter };
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,25 +1,46 @@
|
|||
const Keyv = require('keyv');
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const keyvRedis = require('~/cache/keyvRedis');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const toolCallLimiter = rateLimit({
|
||||
const handler = async (req, res) => {
|
||||
const type = ViolationTypes.TOOL_CALL_LIMIT;
|
||||
const errorMessage = {
|
||||
type,
|
||||
max: 1,
|
||||
limiter: 'user',
|
||||
windowInMinutes: 1,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage, 0);
|
||||
res.status(429).json({ message: 'Too many tool call requests. Try again later' });
|
||||
};
|
||||
|
||||
const limiterOptions = {
|
||||
windowMs: 1000,
|
||||
max: 1,
|
||||
handler: async (req, res) => {
|
||||
const type = ViolationTypes.TOOL_CALL_LIMIT;
|
||||
const errorMessage = {
|
||||
type,
|
||||
max: 1,
|
||||
limiter: 'user',
|
||||
windowInMinutes: 1,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage, 0);
|
||||
res.status(429).json({ message: 'Too many tool call requests. Try again later' });
|
||||
},
|
||||
handler,
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id;
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS)) {
|
||||
logger.debug('Using Redis for tool call rate limiter.');
|
||||
const keyv = new Keyv({ store: keyvRedis });
|
||||
const client = keyv.opts.store.redis;
|
||||
const sendCommand = (...args) => client.call(...args);
|
||||
const store = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'tool_call_limiter:',
|
||||
});
|
||||
limiterOptions.store = store;
|
||||
}
|
||||
|
||||
const toolCallLimiter = rateLimit(limiterOptions);
|
||||
|
||||
module.exports = toolCallLimiter;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
const Keyv = require('keyv');
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const keyvRedis = require('~/cache/keyvRedis');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const TTS_IP_MAX = parseInt(process.env.TTS_IP_MAX) || 100;
|
||||
|
|
@ -47,20 +52,40 @@ const createTTSHandler = (ip = true) => {
|
|||
const createTTSLimiters = () => {
|
||||
const { ttsIpWindowMs, ttsIpMax, ttsUserWindowMs, ttsUserMax } = getEnvironmentVariables();
|
||||
|
||||
const ttsIpLimiter = rateLimit({
|
||||
const ipLimiterOptions = {
|
||||
windowMs: ttsIpWindowMs,
|
||||
max: ttsIpMax,
|
||||
handler: createTTSHandler(),
|
||||
});
|
||||
};
|
||||
|
||||
const ttsUserLimiter = rateLimit({
|
||||
const userLimiterOptions = {
|
||||
windowMs: ttsUserWindowMs,
|
||||
max: ttsUserMax,
|
||||
handler: createTTSHandler(false),
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id; // Use the user ID or NULL if not available
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS)) {
|
||||
logger.debug('Using Redis for TTS rate limiters.');
|
||||
const keyv = new Keyv({ store: keyvRedis });
|
||||
const client = keyv.opts.store.redis;
|
||||
const sendCommand = (...args) => client.call(...args);
|
||||
const ipStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'tts_ip_limiter:',
|
||||
});
|
||||
const userStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'tts_user_limiter:',
|
||||
});
|
||||
ipLimiterOptions.store = ipStore;
|
||||
userLimiterOptions.store = userStore;
|
||||
}
|
||||
|
||||
const ttsIpLimiter = rateLimit(ipLimiterOptions);
|
||||
const ttsUserLimiter = rateLimit(userLimiterOptions);
|
||||
|
||||
return { ttsIpLimiter, ttsUserLimiter };
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
const Keyv = require('keyv');
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const keyvRedis = require('~/cache/keyvRedis');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const FILE_UPLOAD_IP_MAX = parseInt(process.env.FILE_UPLOAD_IP_MAX) || 100;
|
||||
|
|
@ -52,20 +57,40 @@ const createFileLimiters = () => {
|
|||
const { fileUploadIpWindowMs, fileUploadIpMax, fileUploadUserWindowMs, fileUploadUserMax } =
|
||||
getEnvironmentVariables();
|
||||
|
||||
const fileUploadIpLimiter = rateLimit({
|
||||
const ipLimiterOptions = {
|
||||
windowMs: fileUploadIpWindowMs,
|
||||
max: fileUploadIpMax,
|
||||
handler: createFileUploadHandler(),
|
||||
});
|
||||
};
|
||||
|
||||
const fileUploadUserLimiter = rateLimit({
|
||||
const userLimiterOptions = {
|
||||
windowMs: fileUploadUserWindowMs,
|
||||
max: fileUploadUserMax,
|
||||
handler: createFileUploadHandler(false),
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id; // Use the user ID or NULL if not available
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS)) {
|
||||
logger.debug('Using Redis for file upload rate limiters.');
|
||||
const keyv = new Keyv({ store: keyvRedis });
|
||||
const client = keyv.opts.store.redis;
|
||||
const sendCommand = (...args) => client.call(...args);
|
||||
const ipStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'file_upload_ip_limiter:',
|
||||
});
|
||||
const userStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'file_upload_user_limiter:',
|
||||
});
|
||||
ipLimiterOptions.store = ipStore;
|
||||
userLimiterOptions.store = userStore;
|
||||
}
|
||||
|
||||
const fileUploadIpLimiter = rateLimit(ipLimiterOptions);
|
||||
const fileUploadUserLimiter = rateLimit(userLimiterOptions);
|
||||
|
||||
return { fileUploadIpLimiter, fileUploadUserLimiter };
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
const Keyv = require('keyv');
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { removePorts } = require('~/server/utils');
|
||||
const { removePorts, isEnabled } = require('~/server/utils');
|
||||
const keyvRedis = require('~/cache/keyvRedis');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const {
|
||||
VERIFY_EMAIL_WINDOW = 2,
|
||||
|
|
@ -25,11 +29,25 @@ const handler = async (req, res) => {
|
|||
return res.status(429).json({ message });
|
||||
};
|
||||
|
||||
const verifyEmailLimiter = rateLimit({
|
||||
const limiterOptions = {
|
||||
windowMs,
|
||||
max,
|
||||
handler,
|
||||
keyGenerator: removePorts,
|
||||
});
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS)) {
|
||||
logger.debug('Using Redis for verify email rate limiter.');
|
||||
const keyv = new Keyv({ store: keyvRedis });
|
||||
const client = keyv.opts.store.redis;
|
||||
const sendCommand = (...args) => client.call(...args);
|
||||
const store = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'verify_email_limiter:',
|
||||
});
|
||||
limiterOptions.store = store;
|
||||
}
|
||||
|
||||
const verifyEmailLimiter = rateLimit(limiterOptions);
|
||||
|
||||
module.exports = verifyEmailLimiter;
|
||||
|
|
|
|||
32
api/server/middleware/logHeaders.js
Normal file
32
api/server/middleware/logHeaders.js
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Middleware to log Forwarded Headers
|
||||
* @function
|
||||
* @param {ServerRequest} req - Express request object containing user information.
|
||||
* @param {ServerResponse} res - Express response object.
|
||||
* @param {import('express').NextFunction} next - Next middleware function.
|
||||
* @throws {Error} Throws an error if the user exceeds the concurrent request limit.
|
||||
*/
|
||||
const logHeaders = (req, res, next) => {
|
||||
try {
|
||||
const forwardedHeaders = {};
|
||||
if (req.headers['x-forwarded-for']) {
|
||||
forwardedHeaders['x-forwarded-for'] = req.headers['x-forwarded-for'];
|
||||
}
|
||||
if (req.headers['x-forwarded-host']) {
|
||||
forwardedHeaders['x-forwarded-host'] = req.headers['x-forwarded-host'];
|
||||
}
|
||||
if (req.headers['x-forwarded-proto']) {
|
||||
forwardedHeaders['x-forwarded-proto'] = req.headers['x-forwarded-proto'];
|
||||
}
|
||||
if (Object.keys(forwardedHeaders).length > 0) {
|
||||
logger.debug('X-Forwarded headers detected in OAuth request:', forwardedHeaders);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error logging X-Forwarded headers:', error);
|
||||
}
|
||||
next();
|
||||
};
|
||||
|
||||
module.exports = logHeaders;
|
||||
|
|
@ -1,32 +1,18 @@
|
|||
const passport = require('passport');
|
||||
const DebugControl = require('../../utils/debug.js');
|
||||
|
||||
function log({ title, parameters }) {
|
||||
DebugControl.log.functionName(title);
|
||||
if (parameters) {
|
||||
DebugControl.log.parameters(parameters);
|
||||
}
|
||||
}
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const requireLocalAuth = (req, res, next) => {
|
||||
passport.authenticate('local', (err, user, info) => {
|
||||
if (err) {
|
||||
log({
|
||||
title: '(requireLocalAuth) Error at passport.authenticate',
|
||||
parameters: [{ name: 'error', value: err }],
|
||||
});
|
||||
logger.error('[requireLocalAuth] Error at passport.authenticate:', err);
|
||||
return next(err);
|
||||
}
|
||||
if (!user) {
|
||||
log({
|
||||
title: '(requireLocalAuth) Error: No user',
|
||||
});
|
||||
logger.debug('[requireLocalAuth] Error: No user');
|
||||
return res.status(404).send(info);
|
||||
}
|
||||
if (info && info.message) {
|
||||
log({
|
||||
title: '(requireLocalAuth) Error: ' + info.message,
|
||||
});
|
||||
logger.debug('[requireLocalAuth] Error: ' + info.message);
|
||||
return res.status(422).send({ message: info.message });
|
||||
}
|
||||
req.user = user;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,42 @@
|
|||
const { getRoleByName } = require('~/models/Role');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Core function to check if a user has one or more required permissions
|
||||
*
|
||||
* @param {object} user - The user object
|
||||
* @param {PermissionTypes} permissionType - The type of permission to check
|
||||
* @param {Permissions[]} permissions - The list of specific permissions to check
|
||||
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of properties to check
|
||||
* @param {object} [checkObject] - The object to check properties against
|
||||
* @returns {Promise<boolean>} Whether the user has the required permissions
|
||||
*/
|
||||
const checkAccess = async (user, permissionType, permissions, bodyProps = {}, checkObject = {}) => {
|
||||
if (!user) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const role = await getRoleByName(user.role);
|
||||
if (role && role[permissionType]) {
|
||||
const hasAnyPermission = permissions.some((permission) => {
|
||||
if (role[permissionType][permission]) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (bodyProps[permission] && checkObject) {
|
||||
return bodyProps[permission].some((prop) =>
|
||||
Object.prototype.hasOwnProperty.call(checkObject, prop),
|
||||
);
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
return hasAnyPermission;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
/**
|
||||
* Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties.
|
||||
|
|
@ -6,42 +44,35 @@ const { getRoleByName } = require('~/models/Role');
|
|||
* @param {PermissionTypes} permissionType - The type of permission to check.
|
||||
* @param {Permissions[]} permissions - The list of specific permissions to check.
|
||||
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check.
|
||||
* @returns {Function} Express middleware function.
|
||||
* @returns {(req: ServerRequest, res: ServerResponse, next: NextFunction) => Promise<void>} Express middleware function.
|
||||
*/
|
||||
const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => {
|
||||
return async (req, res, next) => {
|
||||
try {
|
||||
const { user } = req;
|
||||
if (!user) {
|
||||
return res.status(401).json({ message: 'Authorization required' });
|
||||
}
|
||||
|
||||
const role = await getRoleByName(user.role);
|
||||
if (role && role[permissionType]) {
|
||||
const hasAnyPermission = permissions.some((permission) => {
|
||||
if (role[permissionType][permission]) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (bodyProps[permission] && req.body) {
|
||||
return bodyProps[permission].some((prop) =>
|
||||
Object.prototype.hasOwnProperty.call(req.body, prop),
|
||||
);
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
if (hasAnyPermission) {
|
||||
return next();
|
||||
}
|
||||
const hasAccess = await checkAccess(
|
||||
req.user,
|
||||
permissionType,
|
||||
permissions,
|
||||
bodyProps,
|
||||
req.body,
|
||||
);
|
||||
|
||||
if (hasAccess) {
|
||||
return next();
|
||||
}
|
||||
|
||||
logger.warn(
|
||||
`[${permissionType}] Forbidden: Insufficient permissions for User ${req.user.id}: ${permissions.join(', ')}`,
|
||||
);
|
||||
return res.status(403).json({ message: 'Forbidden: Insufficient permissions' });
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return res.status(500).json({ message: `Server error: ${error.message}` });
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = generateCheckAccess;
|
||||
module.exports = {
|
||||
checkAccess,
|
||||
generateCheckAccess,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
const checkAdmin = require('./checkAdmin');
|
||||
const generateCheckAccess = require('./generateCheckAccess');
|
||||
const { checkAccess, generateCheckAccess } = require('./generateCheckAccess');
|
||||
|
||||
module.exports = {
|
||||
checkAdmin,
|
||||
checkAccess,
|
||||
generateCheckAccess,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ afterEach(() => {
|
|||
delete process.env.OPENID_ISSUER;
|
||||
delete process.env.OPENID_SESSION_SECRET;
|
||||
delete process.env.OPENID_BUTTON_LABEL;
|
||||
delete process.env.OPENID_AUTO_REDIRECT;
|
||||
delete process.env.OPENID_AUTH_URL;
|
||||
delete process.env.GITHUB_CLIENT_ID;
|
||||
delete process.env.GITHUB_CLIENT_SECRET;
|
||||
|
|
|
|||
|
|
@ -7,8 +7,17 @@ const {
|
|||
} = require('~/server/controllers/AuthController');
|
||||
const { loginController } = require('~/server/controllers/auth/LoginController');
|
||||
const { logoutController } = require('~/server/controllers/auth/LogoutController');
|
||||
const { verify2FAWithTempToken } = require('~/server/controllers/auth/TwoFactorAuthController');
|
||||
const {
|
||||
enable2FA,
|
||||
verify2FA,
|
||||
disable2FA,
|
||||
regenerateBackupCodes,
|
||||
confirm2FA,
|
||||
} = require('~/server/controllers/TwoFactorController');
|
||||
const {
|
||||
checkBan,
|
||||
logHeaders,
|
||||
loginLimiter,
|
||||
requireJwtAuth,
|
||||
checkInviteUser,
|
||||
|
|
@ -27,6 +36,7 @@ const ldapAuth = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE;
|
|||
router.post('/logout', requireJwtAuth, logoutController);
|
||||
router.post(
|
||||
'/login',
|
||||
logHeaders,
|
||||
loginLimiter,
|
||||
checkBan,
|
||||
ldapAuth ? requireLdapAuth : requireLocalAuth,
|
||||
|
|
@ -50,4 +60,11 @@ router.post(
|
|||
);
|
||||
router.post('/resetPassword', checkBan, validatePasswordReset, resetPasswordController);
|
||||
|
||||
router.get('/2fa/enable', requireJwtAuth, enable2FA);
|
||||
router.post('/2fa/verify', requireJwtAuth, verify2FA);
|
||||
router.post('/2fa/verify-temp', checkBan, verify2FAWithTempToken);
|
||||
router.post('/2fa/confirm', requireJwtAuth, confirm2FA);
|
||||
router.post('/2fa/disable', requireJwtAuth, disable2FA);
|
||||
router.post('/2fa/backup/regenerate', requireJwtAuth, regenerateBackupCodes);
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -47,16 +47,17 @@ router.get('/', async function (req, res) {
|
|||
githubLoginEnabled: !!process.env.GITHUB_CLIENT_ID && !!process.env.GITHUB_CLIENT_SECRET,
|
||||
googleLoginEnabled: !!process.env.GOOGLE_CLIENT_ID && !!process.env.GOOGLE_CLIENT_SECRET,
|
||||
appleLoginEnabled:
|
||||
!!process.env.APPLE_CLIENT_ID &&
|
||||
!!process.env.APPLE_TEAM_ID &&
|
||||
!!process.env.APPLE_KEY_ID &&
|
||||
!!process.env.APPLE_PRIVATE_KEY_PATH,
|
||||
!!process.env.APPLE_CLIENT_ID &&
|
||||
!!process.env.APPLE_TEAM_ID &&
|
||||
!!process.env.APPLE_KEY_ID &&
|
||||
!!process.env.APPLE_PRIVATE_KEY_PATH,
|
||||
openidLoginEnabled:
|
||||
!!process.env.OPENID_ENABLED &&
|
||||
!!process.env.OPENID_SESSION_SECRET,
|
||||
openidMultiTenantEnabled: !!process.env.OPENID_MULTI_TENANT,
|
||||
openidLabel: process.env.OPENID_BUTTON_LABEL || 'Continue with OpenID',
|
||||
openidImageUrl: process.env.OPENID_IMAGE_URL,
|
||||
openidAutoRedirect: isEnabled(process.env.OPENID_AUTO_REDIRECT),
|
||||
serverDomain: process.env.DOMAIN_SERVER || 'http://localhost:3080',
|
||||
emailLoginEnabled,
|
||||
registrationEnabled: !ldap?.enabled && isEnabled(process.env.ALLOW_REGISTRATION),
|
||||
|
|
@ -79,6 +80,7 @@ router.get('/', async function (req, res) {
|
|||
publicSharedLinksEnabled,
|
||||
analyticsGtmId: process.env.ANALYTICS_GTM_ID,
|
||||
instanceProjectId: instanceProject._id.toString(),
|
||||
bundlerURL: process.env.SANDPACK_BUNDLER_URL,
|
||||
};
|
||||
|
||||
if (ldap) {
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ const {
|
|||
} = require('~/server/services/Files/process');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
|
||||
const { loadAuthValues } = require('~/app/clients/tools/util');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
// file deepcode ignore NoRateLimitingForLogin: Rate limiting is handled by the `loginLimiter` middleware
|
||||
const express = require('express');
|
||||
const passport = require('passport');
|
||||
const { loginLimiter, checkBan, checkDomainAllowed } = require('~/server/middleware');
|
||||
const { loginLimiter, logHeaders, checkBan, checkDomainAllowed } = require('~/server/middleware');
|
||||
const { setAuthTokens } = require('~/server/services/AuthService');
|
||||
const { logger } = require('~/config');
|
||||
const { chooseOpenIdStrategy } = require('~/server/utils/openidHelper');
|
||||
|
|
@ -13,6 +13,7 @@ const domains = {
|
|||
server: process.env.DOMAIN_SERVER,
|
||||
};
|
||||
|
||||
router.use(logHeaders);
|
||||
router.use(loginLimiter);
|
||||
|
||||
const oauthHandler = async (req, res) => {
|
||||
|
|
@ -31,8 +32,10 @@ const oauthHandler = async (req, res) => {
|
|||
|
||||
router.get('/error', (req, res) => {
|
||||
// A single error message is pushed by passport when authentication fails.
|
||||
logger.error('Error in OAuth authentication:', { message: req.session?.messages?.pop() });
|
||||
res.redirect(`${domains.client}/login`);
|
||||
logger.error('Error in OAuth authentication:', { message: req.session.messages.pop() });
|
||||
|
||||
// Redirect to login page with auth_failed parameter to prevent infinite redirect loops
|
||||
res.redirect(`${domains.client}/login?redirect=false`);
|
||||
});
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -161,9 +161,9 @@ async function createActionTool({
|
|||
|
||||
if (metadata.auth && metadata.auth.type !== AuthTypeEnum.None) {
|
||||
try {
|
||||
const action_id = action.action_id;
|
||||
const identifier = `${req.user.id}:${action.action_id}`;
|
||||
if (metadata.auth.type === AuthTypeEnum.OAuth && metadata.auth.authorization_url) {
|
||||
const action_id = action.action_id;
|
||||
const identifier = `${req.user.id}:${action.action_id}`;
|
||||
const requestLogin = async () => {
|
||||
const { args: _args, stepId, ...toolCall } = config.toolCall ?? {};
|
||||
if (!stepId) {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,15 @@
|
|||
const { FileSources, EModelEndpoint, getConfigDefaults } = require('librechat-data-provider');
|
||||
const {
|
||||
FileSources,
|
||||
EModelEndpoint,
|
||||
loadOCRConfig,
|
||||
processMCPEnv,
|
||||
getConfigDefaults,
|
||||
} = require('librechat-data-provider');
|
||||
const { checkVariables, checkHealth, checkConfig, checkAzureVariables } = require('./start/checks');
|
||||
const { azureAssistantsDefaults, assistantsConfigSetup } = require('./start/assistants');
|
||||
const { initializeAzureBlobService } = require('./Files/Azure/initialize');
|
||||
const { initializeFirebase } = require('./Files/Firebase/initialize');
|
||||
const { initializeS3 } = require('./Files/S3/initialize');
|
||||
const loadCustomConfig = require('./Config/loadCustomConfig');
|
||||
const handleRateLimits = require('./Config/handleRateLimits');
|
||||
const { loadDefaultInterface } = require('./start/interface');
|
||||
|
|
@ -25,6 +33,7 @@ const AppService = async (app) => {
|
|||
const config = (await loadCustomConfig()) ?? {};
|
||||
const configDefaults = getConfigDefaults();
|
||||
|
||||
const ocr = loadOCRConfig(config.ocr);
|
||||
const filteredTools = config.filteredTools;
|
||||
const includedTools = config.includedTools;
|
||||
const fileStrategy = config.fileStrategy ?? configDefaults.fileStrategy;
|
||||
|
|
@ -37,6 +46,10 @@ const AppService = async (app) => {
|
|||
|
||||
if (fileStrategy === FileSources.firebase) {
|
||||
initializeFirebase();
|
||||
} else if (fileStrategy === FileSources.azure) {
|
||||
initializeAzureBlobService();
|
||||
} else if (fileStrategy === FileSources.s3) {
|
||||
initializeS3();
|
||||
}
|
||||
|
||||
/** @type {Record<string, FunctionTool} */
|
||||
|
|
@ -48,7 +61,7 @@ const AppService = async (app) => {
|
|||
|
||||
if (config.mcpServers != null) {
|
||||
const mcpManager = await getMCPManager();
|
||||
await mcpManager.initializeMCP(config.mcpServers);
|
||||
await mcpManager.initializeMCP(config.mcpServers, processMCPEnv);
|
||||
await mcpManager.mapAvailableTools(availableTools);
|
||||
}
|
||||
|
||||
|
|
@ -57,6 +70,7 @@ const AppService = async (app) => {
|
|||
const interfaceConfig = await loadDefaultInterface(config, configDefaults);
|
||||
|
||||
const defaultLocals = {
|
||||
ocr,
|
||||
paths,
|
||||
fileStrategy,
|
||||
socialLogins,
|
||||
|
|
|
|||
|
|
@ -120,6 +120,7 @@ describe('AppService', () => {
|
|||
},
|
||||
},
|
||||
paths: expect.anything(),
|
||||
ocr: expect.anything(),
|
||||
imageOutputType: expect.any(String),
|
||||
fileConfig: undefined,
|
||||
secureImageLinks: undefined,
|
||||
|
|
@ -588,4 +589,33 @@ describe('AppService updating app.locals and issuing warnings', () => {
|
|||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('should not parse environment variable references in OCR config', async () => {
|
||||
// Mock custom configuration with env variable references in OCR config
|
||||
const mockConfig = {
|
||||
ocr: {
|
||||
apiKey: '${OCR_API_KEY_CUSTOM_VAR_NAME}',
|
||||
baseURL: '${OCR_BASEURL_CUSTOM_VAR_NAME}',
|
||||
strategy: 'mistral_ocr',
|
||||
mistralModel: 'mistral-medium',
|
||||
},
|
||||
};
|
||||
|
||||
require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve(mockConfig));
|
||||
|
||||
// Set actual environment variables with different values
|
||||
process.env.OCR_API_KEY_CUSTOM_VAR_NAME = 'actual-api-key';
|
||||
process.env.OCR_BASEURL_CUSTOM_VAR_NAME = 'https://actual-ocr-url.com';
|
||||
|
||||
// Initialize app
|
||||
const app = { locals: {} };
|
||||
await AppService(app);
|
||||
|
||||
// Verify that the raw string references were preserved and not interpolated
|
||||
expect(app.locals.ocr).toBeDefined();
|
||||
expect(app.locals.ocr.apiKey).toEqual('${OCR_API_KEY_CUSTOM_VAR_NAME}');
|
||||
expect(app.locals.ocr.baseURL).toEqual('${OCR_BASEURL_CUSTOM_VAR_NAME}');
|
||||
expect(app.locals.ocr.strategy).toEqual('mistral_ocr');
|
||||
expect(app.locals.ocr.mistralModel).toEqual('mistral-medium');
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -72,4 +72,15 @@ async function getEndpointsConfig(req) {
|
|||
return endpointsConfig;
|
||||
}
|
||||
|
||||
module.exports = { getEndpointsConfig };
|
||||
/**
|
||||
* @param {ServerRequest} req
|
||||
* @param {import('librechat-data-provider').AgentCapabilities} capability
|
||||
* @returns {Promise<boolean>}
|
||||
*/
|
||||
const checkCapability = async (req, capability) => {
|
||||
const endpointsConfig = await getEndpointsConfig(req);
|
||||
const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
|
||||
return capabilities.includes(capability);
|
||||
};
|
||||
|
||||
module.exports = { getEndpointsConfig, checkCapability };
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ async function loadConfigModels(req) {
|
|||
);
|
||||
|
||||
/**
|
||||
* @type {Record<string, string[]>}
|
||||
* @type {Record<string, Promise<string[]>>}
|
||||
* Map for promises keyed by unique combination of baseURL and apiKey */
|
||||
const fetchPromisesMap = {};
|
||||
/**
|
||||
|
|
@ -102,7 +102,7 @@ async function loadConfigModels(req) {
|
|||
|
||||
for (const name of associatedNames) {
|
||||
const endpoint = endpointsMap[name];
|
||||
modelsConfig[name] = !modelData?.length ? endpoint.models.default ?? [] : modelData;
|
||||
modelsConfig[name] = !modelData?.length ? (endpoint.models.default ?? []) : modelData;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ const {
|
|||
getGoogleModels,
|
||||
getBedrockModels,
|
||||
getAnthropicModels,
|
||||
getChatGPTBrowserModels,
|
||||
} = require('~/server/services/ModelService');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Loads the default models for the application.
|
||||
|
|
@ -15,31 +15,68 @@ const {
|
|||
* @param {Express.Request} req - The Express request object.
|
||||
*/
|
||||
async function loadDefaultModels(req) {
|
||||
const google = getGoogleModels();
|
||||
const openAI = await getOpenAIModels({ user: req.user.id });
|
||||
const anthropic = getAnthropicModels();
|
||||
const chatGPTBrowser = getChatGPTBrowserModels();
|
||||
const azureOpenAI = await getOpenAIModels({ user: req.user.id, azure: true });
|
||||
const gptPlugins = await getOpenAIModels({
|
||||
user: req.user.id,
|
||||
azure: useAzurePlugins,
|
||||
plugins: true,
|
||||
});
|
||||
const assistants = await getOpenAIModels({ assistants: true });
|
||||
const azureAssistants = await getOpenAIModels({ azureAssistants: true });
|
||||
try {
|
||||
const [
|
||||
openAI,
|
||||
anthropic,
|
||||
azureOpenAI,
|
||||
gptPlugins,
|
||||
assistants,
|
||||
azureAssistants,
|
||||
google,
|
||||
bedrock,
|
||||
] = await Promise.all([
|
||||
getOpenAIModels({ user: req.user.id }).catch((error) => {
|
||||
logger.error('Error fetching OpenAI models:', error);
|
||||
return [];
|
||||
}),
|
||||
getAnthropicModels({ user: req.user.id }).catch((error) => {
|
||||
logger.error('Error fetching Anthropic models:', error);
|
||||
return [];
|
||||
}),
|
||||
getOpenAIModels({ user: req.user.id, azure: true }).catch((error) => {
|
||||
logger.error('Error fetching Azure OpenAI models:', error);
|
||||
return [];
|
||||
}),
|
||||
getOpenAIModels({ user: req.user.id, azure: useAzurePlugins, plugins: true }).catch(
|
||||
(error) => {
|
||||
logger.error('Error fetching Plugin models:', error);
|
||||
return [];
|
||||
},
|
||||
),
|
||||
getOpenAIModels({ assistants: true }).catch((error) => {
|
||||
logger.error('Error fetching OpenAI Assistants API models:', error);
|
||||
return [];
|
||||
}),
|
||||
getOpenAIModels({ azureAssistants: true }).catch((error) => {
|
||||
logger.error('Error fetching Azure OpenAI Assistants API models:', error);
|
||||
return [];
|
||||
}),
|
||||
Promise.resolve(getGoogleModels()).catch((error) => {
|
||||
logger.error('Error getting Google models:', error);
|
||||
return [];
|
||||
}),
|
||||
Promise.resolve(getBedrockModels()).catch((error) => {
|
||||
logger.error('Error getting Bedrock models:', error);
|
||||
return [];
|
||||
}),
|
||||
]);
|
||||
|
||||
return {
|
||||
[EModelEndpoint.openAI]: openAI,
|
||||
[EModelEndpoint.agents]: openAI,
|
||||
[EModelEndpoint.google]: google,
|
||||
[EModelEndpoint.anthropic]: anthropic,
|
||||
[EModelEndpoint.gptPlugins]: gptPlugins,
|
||||
[EModelEndpoint.azureOpenAI]: azureOpenAI,
|
||||
[EModelEndpoint.chatGPTBrowser]: chatGPTBrowser,
|
||||
[EModelEndpoint.assistants]: assistants,
|
||||
[EModelEndpoint.azureAssistants]: azureAssistants,
|
||||
[EModelEndpoint.bedrock]: getBedrockModels(),
|
||||
};
|
||||
return {
|
||||
[EModelEndpoint.openAI]: openAI,
|
||||
[EModelEndpoint.agents]: openAI,
|
||||
[EModelEndpoint.google]: google,
|
||||
[EModelEndpoint.anthropic]: anthropic,
|
||||
[EModelEndpoint.gptPlugins]: gptPlugins,
|
||||
[EModelEndpoint.azureOpenAI]: azureOpenAI,
|
||||
[EModelEndpoint.assistants]: assistants,
|
||||
[EModelEndpoint.azureAssistants]: azureAssistants,
|
||||
[EModelEndpoint.bedrock]: bedrock,
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('Error fetching default models:', error);
|
||||
throw new Error(`Failed to load default models: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = loadDefaultModels;
|
||||
|
|
|
|||
|
|
@ -2,15 +2,8 @@ const { loadAgent } = require('~/models/Agent');
|
|||
const { logger } = require('~/config');
|
||||
|
||||
const buildOptions = (req, endpoint, parsedBody) => {
|
||||
const {
|
||||
spec,
|
||||
iconURL,
|
||||
agent_id,
|
||||
instructions,
|
||||
maxContextTokens,
|
||||
resendFiles = true,
|
||||
...model_parameters
|
||||
} = parsedBody;
|
||||
const { spec, iconURL, agent_id, instructions, maxContextTokens, ...model_parameters } =
|
||||
parsedBody;
|
||||
const agentPromise = loadAgent({
|
||||
req,
|
||||
agent_id,
|
||||
|
|
@ -24,7 +17,6 @@ const buildOptions = (req, endpoint, parsedBody) => {
|
|||
iconURL,
|
||||
endpoint,
|
||||
agent_id,
|
||||
resendFiles,
|
||||
instructions,
|
||||
maxContextTokens,
|
||||
model_parameters,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ const { createContentAggregator, Providers } = require('@librechat/agents');
|
|||
const {
|
||||
EModelEndpoint,
|
||||
getResponseSender,
|
||||
AgentCapabilities,
|
||||
providerEndpointMap,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
|
|
@ -15,36 +16,61 @@ const initCustom = require('~/server/services/Endpoints/custom/initialize');
|
|||
const initGoogle = require('~/server/services/Endpoints/google/initialize');
|
||||
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
|
||||
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
||||
const { processFiles } = require('~/server/services/Files/process');
|
||||
const { loadAgentTools } = require('~/server/services/ToolService');
|
||||
const AgentClient = require('~/server/controllers/agents/client');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { getToolFilesByIds } = require('~/models/File');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const providerConfigMap = {
|
||||
[Providers.XAI]: initCustom,
|
||||
[Providers.OLLAMA]: initCustom,
|
||||
[Providers.DEEPSEEK]: initCustom,
|
||||
[Providers.OPENROUTER]: initCustom,
|
||||
[EModelEndpoint.openAI]: initOpenAI,
|
||||
[EModelEndpoint.google]: initGoogle,
|
||||
[EModelEndpoint.azureOpenAI]: initOpenAI,
|
||||
[EModelEndpoint.anthropic]: initAnthropic,
|
||||
[EModelEndpoint.bedrock]: getBedrockOptions,
|
||||
[EModelEndpoint.google]: initGoogle,
|
||||
[Providers.OLLAMA]: initCustom,
|
||||
};
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {ServerRequest} req
|
||||
* @param {Promise<Array<MongoFile | null>> | undefined} _attachments
|
||||
* @param {AgentToolResources | undefined} _tool_resources
|
||||
* @returns {Promise<{ attachments: Array<MongoFile | undefined> | undefined, tool_resources: AgentToolResources | undefined }>}
|
||||
*/
|
||||
const primeResources = async (_attachments, _tool_resources) => {
|
||||
const primeResources = async (req, _attachments, _tool_resources) => {
|
||||
try {
|
||||
/** @type {Array<MongoFile | undefined> | undefined} */
|
||||
let attachments;
|
||||
const tool_resources = _tool_resources ?? {};
|
||||
const isOCREnabled = (req.app.locals?.[EModelEndpoint.agents]?.capabilities ?? []).includes(
|
||||
AgentCapabilities.ocr,
|
||||
);
|
||||
if (tool_resources.ocr?.file_ids && isOCREnabled) {
|
||||
const context = await getFiles(
|
||||
{
|
||||
file_id: { $in: tool_resources.ocr.file_ids },
|
||||
},
|
||||
{},
|
||||
{},
|
||||
);
|
||||
attachments = (attachments ?? []).concat(context);
|
||||
}
|
||||
if (!_attachments) {
|
||||
return { attachments: undefined, tool_resources: _tool_resources };
|
||||
return { attachments, tool_resources };
|
||||
}
|
||||
/** @type {Array<MongoFile | undefined> | undefined} */
|
||||
const files = await _attachments;
|
||||
const attachments = [];
|
||||
const tool_resources = _tool_resources ?? {};
|
||||
if (!attachments) {
|
||||
/** @type {Array<MongoFile | undefined>} */
|
||||
attachments = [];
|
||||
}
|
||||
|
||||
for (const file of files) {
|
||||
if (!file) {
|
||||
|
|
@ -79,7 +105,6 @@ const primeResources = async (_attachments, _tool_resources) => {
|
|||
* @param {ServerResponse} params.res
|
||||
* @param {Agent} params.agent
|
||||
* @param {object} [params.endpointOption]
|
||||
* @param {AgentToolResources} [params.tool_resources]
|
||||
* @param {boolean} [params.isInitialAgent]
|
||||
* @returns {Promise<Agent>}
|
||||
*/
|
||||
|
|
@ -88,9 +113,30 @@ const initializeAgentOptions = async ({
|
|||
res,
|
||||
agent,
|
||||
endpointOption,
|
||||
tool_resources,
|
||||
isInitialAgent = false,
|
||||
}) => {
|
||||
let currentFiles;
|
||||
/** @type {Array<MongoFile>} */
|
||||
const requestFiles = req.body.files ?? [];
|
||||
if (
|
||||
isInitialAgent &&
|
||||
req.body.conversationId != null &&
|
||||
(agent.model_parameters?.resendFiles ?? true) === true
|
||||
) {
|
||||
const fileIds = (await getConvoFiles(req.body.conversationId)) ?? [];
|
||||
const toolFiles = await getToolFilesByIds(fileIds);
|
||||
if (requestFiles.length || toolFiles.length) {
|
||||
currentFiles = await processFiles(requestFiles.concat(toolFiles));
|
||||
}
|
||||
} else if (isInitialAgent && requestFiles.length) {
|
||||
currentFiles = await processFiles(requestFiles);
|
||||
}
|
||||
|
||||
const { attachments, tool_resources } = await primeResources(
|
||||
req,
|
||||
currentFiles,
|
||||
agent.tool_resources,
|
||||
);
|
||||
const { tools, toolContextMap } = await loadAgentTools({
|
||||
req,
|
||||
res,
|
||||
|
|
@ -99,18 +145,19 @@ const initializeAgentOptions = async ({
|
|||
});
|
||||
|
||||
const provider = agent.provider;
|
||||
agent.endpoint = provider;
|
||||
let getOptions = providerConfigMap[provider];
|
||||
|
||||
if (!getOptions) {
|
||||
if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) {
|
||||
agent.provider = provider.toLowerCase();
|
||||
getOptions = providerConfigMap[agent.provider];
|
||||
} else if (!getOptions) {
|
||||
const customEndpointConfig = await getCustomEndpointConfig(provider);
|
||||
if (!customEndpointConfig) {
|
||||
throw new Error(`Provider ${provider} not supported`);
|
||||
}
|
||||
getOptions = initCustom;
|
||||
agent.provider = Providers.OPENAI;
|
||||
agent.endpoint = provider.toLowerCase();
|
||||
}
|
||||
|
||||
const model_parameters = Object.assign(
|
||||
{},
|
||||
agent.model_parameters ?? { model: agent.model },
|
||||
|
|
@ -134,6 +181,7 @@ const initializeAgentOptions = async ({
|
|||
agent.provider = options.provider;
|
||||
}
|
||||
|
||||
/** @type {import('@librechat/agents').ClientOptions} */
|
||||
agent.model_parameters = Object.assign(model_parameters, options.llmConfig);
|
||||
if (options.configOptions) {
|
||||
agent.model_parameters.configuration = options.configOptions;
|
||||
|
|
@ -152,15 +200,18 @@ const initializeAgentOptions = async ({
|
|||
|
||||
const tokensModel =
|
||||
agent.provider === EModelEndpoint.azureOpenAI ? agent.model : agent.model_parameters.model;
|
||||
|
||||
const maxTokens = agent.model_parameters.maxOutputTokens ?? agent.model_parameters.maxTokens ?? 0;
|
||||
const maxContextTokens =
|
||||
agent.model_parameters.maxContextTokens ??
|
||||
agent.max_context_tokens ??
|
||||
getModelMaxTokens(tokensModel, providerEndpointMap[provider]) ??
|
||||
4096;
|
||||
return {
|
||||
...agent,
|
||||
tools,
|
||||
attachments,
|
||||
toolContextMap,
|
||||
maxContextTokens:
|
||||
agent.max_context_tokens ??
|
||||
getModelMaxTokens(tokensModel, providerEndpointMap[provider]) ??
|
||||
4000,
|
||||
maxContextTokens: (maxContextTokens - maxTokens) * 0.9,
|
||||
};
|
||||
};
|
||||
|
||||
|
|
@ -193,11 +244,6 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
throw new Error('Agent not found');
|
||||
}
|
||||
|
||||
const { attachments, tool_resources } = await primeResources(
|
||||
endpointOption.attachments,
|
||||
primaryAgent.tool_resources,
|
||||
);
|
||||
|
||||
const agentConfigs = new Map();
|
||||
|
||||
// Handle primary agent
|
||||
|
|
@ -206,7 +252,6 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
res,
|
||||
agent: primaryAgent,
|
||||
endpointOption,
|
||||
tool_resources,
|
||||
isInitialAgent: true,
|
||||
});
|
||||
|
||||
|
|
@ -236,18 +281,21 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
|
||||
const client = new AgentClient({
|
||||
req,
|
||||
agent: primaryConfig,
|
||||
res,
|
||||
sender,
|
||||
attachments,
|
||||
contentParts,
|
||||
agentConfigs,
|
||||
eventHandlers,
|
||||
collectedUsage,
|
||||
aggregateContent,
|
||||
artifactPromises,
|
||||
agent: primaryConfig,
|
||||
spec: endpointOption.spec,
|
||||
iconURL: endpointOption.iconURL,
|
||||
agentConfigs,
|
||||
endpoint: EModelEndpoint.agents,
|
||||
attachments: primaryConfig.attachments,
|
||||
maxContextTokens: primaryConfig.maxContextTokens,
|
||||
resendFiles: primaryConfig.model_parameters?.resendFiles ?? true,
|
||||
});
|
||||
|
||||
return { client };
|
||||
|
|
|
|||
|
|
@ -20,10 +20,19 @@ const addTitle = async (req, { text, response, client }) => {
|
|||
|
||||
const titleCache = getLogStores(CacheKeys.GEN_TITLE);
|
||||
const key = `${req.user.id}-${response.conversationId}`;
|
||||
const responseText =
|
||||
response?.content && Array.isArray(response?.content)
|
||||
? response.content.reduce((acc, block) => {
|
||||
if (block?.type === 'text') {
|
||||
return acc + block.text;
|
||||
}
|
||||
return acc;
|
||||
}, '')
|
||||
: (response?.content ?? response?.text ?? '');
|
||||
|
||||
const title = await client.titleConvo({
|
||||
text,
|
||||
responseText: response?.text ?? '',
|
||||
responseText,
|
||||
conversationId: response.conversationId,
|
||||
});
|
||||
await titleCache.set(key, title, 120000);
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
const { removeNullishValues } = require('librechat-data-provider');
|
||||
const { removeNullishValues, anthropicSettings } = require('librechat-data-provider');
|
||||
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
|
||||
|
||||
const buildOptions = (endpoint, parsedBody) => {
|
||||
|
|
@ -6,8 +6,10 @@ const buildOptions = (endpoint, parsedBody) => {
|
|||
modelLabel,
|
||||
promptPrefix,
|
||||
maxContextTokens,
|
||||
resendFiles = true,
|
||||
promptCache = true,
|
||||
resendFiles = anthropicSettings.resendFiles.default,
|
||||
promptCache = anthropicSettings.promptCache.default,
|
||||
thinking = anthropicSettings.thinking.default,
|
||||
thinkingBudget = anthropicSettings.thinkingBudget.default,
|
||||
iconURL,
|
||||
greeting,
|
||||
spec,
|
||||
|
|
@ -21,6 +23,8 @@ const buildOptions = (endpoint, parsedBody) => {
|
|||
promptPrefix,
|
||||
resendFiles,
|
||||
promptCache,
|
||||
thinking,
|
||||
thinkingBudget,
|
||||
iconURL,
|
||||
greeting,
|
||||
spec,
|
||||
|
|
|
|||
111
api/server/services/Endpoints/anthropic/helpers.js
Normal file
111
api/server/services/Endpoints/anthropic/helpers.js
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
const { EModelEndpoint, anthropicSettings } = require('librechat-data-provider');
|
||||
const { matchModelName } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* @param {string} modelName
|
||||
* @returns {boolean}
|
||||
*/
|
||||
function checkPromptCacheSupport(modelName) {
|
||||
const modelMatch = matchModelName(modelName, EModelEndpoint.anthropic);
|
||||
if (
|
||||
modelMatch.includes('claude-3-5-sonnet-latest') ||
|
||||
modelMatch.includes('claude-3.5-sonnet-latest')
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (
|
||||
modelMatch === 'claude-3-7-sonnet' ||
|
||||
modelMatch === 'claude-3-5-sonnet' ||
|
||||
modelMatch === 'claude-3-5-haiku' ||
|
||||
modelMatch === 'claude-3-haiku' ||
|
||||
modelMatch === 'claude-3-opus' ||
|
||||
modelMatch === 'claude-3.7-sonnet' ||
|
||||
modelMatch === 'claude-3.5-sonnet' ||
|
||||
modelMatch === 'claude-3.5-haiku'
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the appropriate headers for Claude models with cache control
|
||||
* @param {string} model The model name
|
||||
* @param {boolean} supportsCacheControl Whether the model supports cache control
|
||||
* @returns {AnthropicClientOptions['extendedOptions']['defaultHeaders']|undefined} The headers object or undefined if not applicable
|
||||
*/
|
||||
function getClaudeHeaders(model, supportsCacheControl) {
|
||||
if (!supportsCacheControl) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (/claude-3[-.]5-sonnet/.test(model)) {
|
||||
return {
|
||||
'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15,prompt-caching-2024-07-31',
|
||||
};
|
||||
} else if (/claude-3[-.]7/.test(model)) {
|
||||
return {
|
||||
'anthropic-beta':
|
||||
'token-efficient-tools-2025-02-19,output-128k-2025-02-19,prompt-caching-2024-07-31',
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
'anthropic-beta': 'prompt-caching-2024-07-31',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Configures reasoning-related options for Claude models
|
||||
* @param {AnthropicClientOptions & { max_tokens?: number }} anthropicInput The request options object
|
||||
* @param {Object} extendedOptions Additional client configuration options
|
||||
* @param {boolean} extendedOptions.thinking Whether thinking is enabled in client config
|
||||
* @param {number|null} extendedOptions.thinkingBudget The token budget for thinking
|
||||
* @returns {Object} Updated request options
|
||||
*/
|
||||
function configureReasoning(anthropicInput, extendedOptions = {}) {
|
||||
const updatedOptions = { ...anthropicInput };
|
||||
const currentMaxTokens = updatedOptions.max_tokens ?? updatedOptions.maxTokens;
|
||||
if (
|
||||
extendedOptions.thinking &&
|
||||
updatedOptions?.model &&
|
||||
/claude-3[-.]7/.test(updatedOptions.model)
|
||||
) {
|
||||
updatedOptions.thinking = {
|
||||
type: 'enabled',
|
||||
};
|
||||
}
|
||||
|
||||
if (updatedOptions.thinking != null && extendedOptions.thinkingBudget != null) {
|
||||
updatedOptions.thinking = {
|
||||
...updatedOptions.thinking,
|
||||
budget_tokens: extendedOptions.thinkingBudget,
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
updatedOptions.thinking != null &&
|
||||
(currentMaxTokens == null || updatedOptions.thinking.budget_tokens > currentMaxTokens)
|
||||
) {
|
||||
const maxTokens = anthropicSettings.maxOutputTokens.reset(updatedOptions.model);
|
||||
updatedOptions.max_tokens = currentMaxTokens ?? maxTokens;
|
||||
|
||||
logger.warn(
|
||||
updatedOptions.max_tokens === maxTokens
|
||||
? '[AnthropicClient] max_tokens is not defined while thinking is enabled. Setting max_tokens to model default.'
|
||||
: `[AnthropicClient] thinking budget_tokens (${updatedOptions.thinking.budget_tokens}) exceeds max_tokens (${updatedOptions.max_tokens}). Adjusting budget_tokens.`,
|
||||
);
|
||||
|
||||
updatedOptions.thinking.budget_tokens = Math.min(
|
||||
updatedOptions.thinking.budget_tokens,
|
||||
Math.floor(updatedOptions.max_tokens * 0.9),
|
||||
);
|
||||
}
|
||||
|
||||
return updatedOptions;
|
||||
}
|
||||
|
||||
module.exports = { checkPromptCacheSupport, getClaudeHeaders, configureReasoning };
|
||||
|
|
@ -27,6 +27,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
|
|||
|
||||
if (anthropicConfig) {
|
||||
clientOptions.streamRate = anthropicConfig.streamRate;
|
||||
clientOptions.titleModel = anthropicConfig.titleModel;
|
||||
}
|
||||
|
||||
/** @type {undefined | TBaseEndpoint} */
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { anthropicSettings, removeNullishValues } = require('librechat-data-provider');
|
||||
const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = require('./helpers');
|
||||
|
||||
/**
|
||||
* Generates configuration options for creating an Anthropic language model (LLM) instance.
|
||||
|
|
@ -20,6 +21,14 @@ const { anthropicSettings, removeNullishValues } = require('librechat-data-provi
|
|||
* @returns {Object} Configuration options for creating an Anthropic LLM instance, with null and undefined values removed.
|
||||
*/
|
||||
function getLLMConfig(apiKey, options = {}) {
|
||||
const systemOptions = {
|
||||
thinking: options.modelOptions.thinking ?? anthropicSettings.thinking.default,
|
||||
promptCache: options.modelOptions.promptCache ?? anthropicSettings.promptCache.default,
|
||||
thinkingBudget: options.modelOptions.thinkingBudget ?? anthropicSettings.thinkingBudget.default,
|
||||
};
|
||||
for (let key in systemOptions) {
|
||||
delete options.modelOptions[key];
|
||||
}
|
||||
const defaultOptions = {
|
||||
model: anthropicSettings.model.default,
|
||||
maxOutputTokens: anthropicSettings.maxOutputTokens.default,
|
||||
|
|
@ -29,19 +38,34 @@ function getLLMConfig(apiKey, options = {}) {
|
|||
const mergedOptions = Object.assign(defaultOptions, options.modelOptions);
|
||||
|
||||
/** @type {AnthropicClientOptions} */
|
||||
const requestOptions = {
|
||||
let requestOptions = {
|
||||
apiKey,
|
||||
model: mergedOptions.model,
|
||||
stream: mergedOptions.stream,
|
||||
temperature: mergedOptions.temperature,
|
||||
topP: mergedOptions.topP,
|
||||
topK: mergedOptions.topK,
|
||||
stopSequences: mergedOptions.stop,
|
||||
maxTokens:
|
||||
mergedOptions.maxOutputTokens || anthropicSettings.maxOutputTokens.reset(mergedOptions.model),
|
||||
clientOptions: {},
|
||||
};
|
||||
|
||||
requestOptions = configureReasoning(requestOptions, systemOptions);
|
||||
|
||||
if (!/claude-3[-.]7/.test(mergedOptions.model)) {
|
||||
requestOptions.topP = mergedOptions.topP;
|
||||
requestOptions.topK = mergedOptions.topK;
|
||||
} else if (requestOptions.thinking == null) {
|
||||
requestOptions.topP = mergedOptions.topP;
|
||||
requestOptions.topK = mergedOptions.topK;
|
||||
}
|
||||
|
||||
const supportsCacheControl =
|
||||
systemOptions.promptCache === true && checkPromptCacheSupport(requestOptions.model);
|
||||
const headers = getClaudeHeaders(requestOptions.model, supportsCacheControl);
|
||||
if (headers) {
|
||||
requestOptions.clientOptions.defaultHeaders = headers;
|
||||
}
|
||||
|
||||
if (options.proxy) {
|
||||
requestOptions.clientOptions.httpAgent = new HttpsProxyAgent(options.proxy);
|
||||
}
|
||||
|
|
|
|||
153
api/server/services/Endpoints/anthropic/llm.spec.js
Normal file
153
api/server/services/Endpoints/anthropic/llm.spec.js
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
const { anthropicSettings } = require('librechat-data-provider');
|
||||
const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
|
||||
|
||||
jest.mock('https-proxy-agent', () => ({
|
||||
HttpsProxyAgent: jest.fn().mockImplementation((proxy) => ({ proxy })),
|
||||
}));
|
||||
|
||||
describe('getLLMConfig', () => {
|
||||
it('should create a basic configuration with default values', () => {
|
||||
const result = getLLMConfig('test-api-key', { modelOptions: {} });
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('apiKey', 'test-api-key');
|
||||
expect(result.llmConfig).toHaveProperty('model', anthropicSettings.model.default);
|
||||
expect(result.llmConfig).toHaveProperty('stream', true);
|
||||
expect(result.llmConfig).toHaveProperty('maxTokens');
|
||||
});
|
||||
|
||||
it('should include proxy settings when provided', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {},
|
||||
proxy: 'http://proxy:8080',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.clientOptions).toHaveProperty('httpAgent');
|
||||
expect(result.llmConfig.clientOptions.httpAgent).toHaveProperty('proxy', 'http://proxy:8080');
|
||||
});
|
||||
|
||||
it('should include reverse proxy URL when provided', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {},
|
||||
reverseProxyUrl: 'http://reverse-proxy',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'http://reverse-proxy');
|
||||
});
|
||||
|
||||
it('should include topK and topP for non-Claude-3.7 models', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus',
|
||||
topK: 10,
|
||||
topP: 0.9,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('topK', 10);
|
||||
expect(result.llmConfig).toHaveProperty('topP', 0.9);
|
||||
});
|
||||
|
||||
it('should include topK and topP for Claude-3.5 models', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-5-sonnet',
|
||||
topK: 10,
|
||||
topP: 0.9,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('topK', 10);
|
||||
expect(result.llmConfig).toHaveProperty('topP', 0.9);
|
||||
});
|
||||
|
||||
it('should NOT include topK and topP for Claude-3-7 models (hyphen notation)', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-7-sonnet',
|
||||
topK: 10,
|
||||
topP: 0.9,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).not.toHaveProperty('topK');
|
||||
expect(result.llmConfig).not.toHaveProperty('topP');
|
||||
});
|
||||
|
||||
it('should NOT include topK and topP for Claude-3.7 models (decimal notation)', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3.7-sonnet',
|
||||
topK: 10,
|
||||
topP: 0.9,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).not.toHaveProperty('topK');
|
||||
expect(result.llmConfig).not.toHaveProperty('topP');
|
||||
});
|
||||
|
||||
it('should handle custom maxOutputTokens', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus',
|
||||
maxOutputTokens: 2048,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('maxTokens', 2048);
|
||||
});
|
||||
|
||||
it('should handle promptCache setting', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-5-sonnet',
|
||||
promptCache: true,
|
||||
},
|
||||
});
|
||||
|
||||
// We're not checking specific header values since that depends on the actual helper function
|
||||
// Just verifying that the promptCache setting is processed
|
||||
expect(result.llmConfig).toBeDefined();
|
||||
});
|
||||
|
||||
it('should include topK and topP for Claude-3.7 models when thinking is not enabled', () => {
|
||||
// Test with thinking explicitly set to null/undefined
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-7-sonnet',
|
||||
topK: 10,
|
||||
topP: 0.9,
|
||||
thinking: false,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('topK', 10);
|
||||
expect(result.llmConfig).toHaveProperty('topP', 0.9);
|
||||
|
||||
// Test with thinking explicitly set to false
|
||||
const result2 = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-7-sonnet',
|
||||
topK: 10,
|
||||
topP: 0.9,
|
||||
thinking: false,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result2.llmConfig).toHaveProperty('topK', 10);
|
||||
expect(result2.llmConfig).toHaveProperty('topP', 0.9);
|
||||
|
||||
// Test with decimal notation as well
|
||||
const result3 = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3.7-sonnet',
|
||||
topK: 10,
|
||||
topP: 0.9,
|
||||
thinking: false,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result3.llmConfig).toHaveProperty('topK', 10);
|
||||
expect(result3.llmConfig).toHaveProperty('topP', 0.9);
|
||||
});
|
||||
});
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
const { removeNullishValues, bedrockInputParser } = require('librechat-data-provider');
|
||||
const { removeNullishValues } = require('librechat-data-provider');
|
||||
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const buildOptions = (endpoint, parsedBody) => {
|
||||
const {
|
||||
|
|
@ -15,12 +14,6 @@ const buildOptions = (endpoint, parsedBody) => {
|
|||
artifacts,
|
||||
...model_parameters
|
||||
} = parsedBody;
|
||||
let parsedParams = model_parameters;
|
||||
try {
|
||||
parsedParams = bedrockInputParser.parse(model_parameters);
|
||||
} catch (error) {
|
||||
logger.warn('Failed to parse bedrock input', error);
|
||||
}
|
||||
const endpointOption = removeNullishValues({
|
||||
endpoint,
|
||||
name,
|
||||
|
|
@ -31,7 +24,7 @@ const buildOptions = (endpoint, parsedBody) => {
|
|||
spec,
|
||||
promptPrefix,
|
||||
maxContextTokens,
|
||||
model_parameters: parsedParams,
|
||||
model_parameters,
|
||||
});
|
||||
|
||||
if (typeof artifacts === 'string') {
|
||||
|
|
|
|||
|
|
@ -23,8 +23,9 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
const agent = {
|
||||
id: EModelEndpoint.bedrock,
|
||||
name: endpointOption.name,
|
||||
instructions: endpointOption.promptPrefix,
|
||||
provider: EModelEndpoint.bedrock,
|
||||
endpoint: EModelEndpoint.bedrock,
|
||||
instructions: endpointOption.promptPrefix,
|
||||
model: endpointOption.model_parameters.model,
|
||||
model_parameters: endpointOption.model_parameters,
|
||||
};
|
||||
|
|
@ -54,6 +55,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
|
||||
const client = new AgentClient({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
sender,
|
||||
// tools,
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const {
|
||||
EModelEndpoint,
|
||||
Constants,
|
||||
AuthType,
|
||||
Constants,
|
||||
EModelEndpoint,
|
||||
bedrockInputParser,
|
||||
bedrockOutputParser,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
const { sleep } = require('~/server/utils');
|
||||
|
||||
const getOptions = async ({ req, endpointOption }) => {
|
||||
const getOptions = async ({ req, overrideModel, endpointOption }) => {
|
||||
const {
|
||||
BEDROCK_AWS_SECRET_ACCESS_KEY,
|
||||
BEDROCK_AWS_ACCESS_KEY_ID,
|
||||
|
|
@ -62,39 +64,44 @@ const getOptions = async ({ req, endpointOption }) => {
|
|||
|
||||
/** @type {BedrockClientOptions} */
|
||||
const requestOptions = {
|
||||
model: endpointOption.model,
|
||||
model: overrideModel ?? endpointOption.model,
|
||||
region: BEDROCK_AWS_DEFAULT_REGION,
|
||||
streaming: true,
|
||||
streamUsage: true,
|
||||
callbacks: [
|
||||
{
|
||||
handleLLMNewToken: async () => {
|
||||
if (!streamRate) {
|
||||
return;
|
||||
}
|
||||
await sleep(streamRate);
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
if (credentials) {
|
||||
requestOptions.credentials = credentials;
|
||||
}
|
||||
|
||||
if (BEDROCK_REVERSE_PROXY) {
|
||||
requestOptions.endpointHost = BEDROCK_REVERSE_PROXY;
|
||||
}
|
||||
|
||||
const configOptions = {};
|
||||
if (PROXY) {
|
||||
/** NOTE: NOT SUPPORTED BY BEDROCK */
|
||||
configOptions.httpAgent = new HttpsProxyAgent(PROXY);
|
||||
}
|
||||
|
||||
const llmConfig = bedrockOutputParser(
|
||||
bedrockInputParser.parse(
|
||||
removeNullishValues(Object.assign(requestOptions, endpointOption.model_parameters)),
|
||||
),
|
||||
);
|
||||
|
||||
if (credentials) {
|
||||
llmConfig.credentials = credentials;
|
||||
}
|
||||
|
||||
if (BEDROCK_REVERSE_PROXY) {
|
||||
llmConfig.endpointHost = BEDROCK_REVERSE_PROXY;
|
||||
}
|
||||
|
||||
llmConfig.callbacks = [
|
||||
{
|
||||
handleLLMNewToken: async () => {
|
||||
if (!streamRate) {
|
||||
return;
|
||||
}
|
||||
await sleep(streamRate);
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
return {
|
||||
/** @type {BedrockClientOptions} */
|
||||
llmConfig: removeNullishValues(Object.assign(requestOptions, endpointOption.model_parameters)),
|
||||
llmConfig,
|
||||
configOptions,
|
||||
};
|
||||
};
|
||||
|
|
|
|||
|
|
@ -141,7 +141,8 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
|
|||
},
|
||||
clientOptions,
|
||||
);
|
||||
const options = getLLMConfig(apiKey, clientOptions);
|
||||
clientOptions.modelOptions.user = req.user.id;
|
||||
const options = getLLMConfig(apiKey, clientOptions, endpoint);
|
||||
if (!customOptions.streamRate) {
|
||||
return options;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,12 +5,7 @@ const { isEnabled } = require('~/server/utils');
|
|||
const { GoogleClient } = require('~/app');
|
||||
|
||||
const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => {
|
||||
const {
|
||||
GOOGLE_KEY,
|
||||
GOOGLE_REVERSE_PROXY,
|
||||
GOOGLE_AUTH_HEADER,
|
||||
PROXY,
|
||||
} = process.env;
|
||||
const { GOOGLE_KEY, GOOGLE_REVERSE_PROXY, GOOGLE_AUTH_HEADER, PROXY } = process.env;
|
||||
const isUserProvided = GOOGLE_KEY === 'user_provided';
|
||||
const { key: expiresAt } = req.body;
|
||||
|
||||
|
|
@ -43,6 +38,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
|
|||
|
||||
if (googleConfig) {
|
||||
clientOptions.streamRate = googleConfig.streamRate;
|
||||
clientOptions.titleModel = googleConfig.titleModel;
|
||||
}
|
||||
|
||||
if (allConfig) {
|
||||
|
|
|
|||
|
|
@ -113,6 +113,7 @@ const initializeClient = async ({
|
|||
|
||||
if (!isAzureOpenAI && openAIConfig) {
|
||||
clientOptions.streamRate = openAIConfig.streamRate;
|
||||
clientOptions.titleModel = openAIConfig.titleModel;
|
||||
}
|
||||
|
||||
/** @type {undefined | TBaseEndpoint} */
|
||||
|
|
@ -134,12 +135,10 @@ const initializeClient = async ({
|
|||
}
|
||||
|
||||
if (optionsOnly) {
|
||||
clientOptions = Object.assign(
|
||||
{
|
||||
modelOptions: endpointOption.model_parameters,
|
||||
},
|
||||
clientOptions,
|
||||
);
|
||||
const modelOptions = endpointOption.model_parameters;
|
||||
modelOptions.model = modelName;
|
||||
clientOptions = Object.assign({ modelOptions }, clientOptions);
|
||||
clientOptions.modelOptions.user = req.user.id;
|
||||
const options = getLLMConfig(apiKey, clientOptions);
|
||||
if (!clientOptions.streamRate) {
|
||||
return options;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { KnownEndpoints } = require('librechat-data-provider');
|
||||
const { sanitizeModelName, constructAzureURL } = require('~/utils');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
|
||||
|
|
@ -8,6 +9,7 @@ const { isEnabled } = require('~/server/utils');
|
|||
* @param {Object} options - Additional options for configuring the LLM.
|
||||
* @param {Object} [options.modelOptions] - Model-specific options.
|
||||
* @param {string} [options.modelOptions.model] - The name of the model to use.
|
||||
* @param {string} [options.modelOptions.user] - The user ID
|
||||
* @param {number} [options.modelOptions.temperature] - Controls randomness in output generation (0-2).
|
||||
* @param {number} [options.modelOptions.top_p] - Controls diversity via nucleus sampling (0-1).
|
||||
* @param {number} [options.modelOptions.frequency_penalty] - Reduces repetition of token sequences (-2 to 2).
|
||||
|
|
@ -22,13 +24,13 @@ const { isEnabled } = require('~/server/utils');
|
|||
* @param {boolean} [options.streaming] - Whether to use streaming mode.
|
||||
* @param {Object} [options.addParams] - Additional parameters to add to the model options.
|
||||
* @param {string[]} [options.dropParams] - Parameters to remove from the model options.
|
||||
* @param {string|null} [endpoint=null] - The endpoint name
|
||||
* @returns {Object} Configuration options for creating an LLM instance.
|
||||
*/
|
||||
function getLLMConfig(apiKey, options = {}) {
|
||||
const {
|
||||
function getLLMConfig(apiKey, options = {}, endpoint = null) {
|
||||
let {
|
||||
modelOptions = {},
|
||||
reverseProxyUrl,
|
||||
useOpenRouter,
|
||||
defaultQuery,
|
||||
headers,
|
||||
proxy,
|
||||
|
|
@ -48,19 +50,45 @@ function getLLMConfig(apiKey, options = {}) {
|
|||
if (addParams && typeof addParams === 'object') {
|
||||
Object.assign(llmConfig, addParams);
|
||||
}
|
||||
/** Note: OpenAI Web Search models do not support any known parameters besdies `max_tokens` */
|
||||
if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) {
|
||||
const searchExcludeParams = [
|
||||
'frequency_penalty',
|
||||
'presence_penalty',
|
||||
'temperature',
|
||||
'top_p',
|
||||
'top_k',
|
||||
'stop',
|
||||
'logit_bias',
|
||||
'seed',
|
||||
'response_format',
|
||||
'n',
|
||||
'logprobs',
|
||||
'user',
|
||||
];
|
||||
|
||||
dropParams = dropParams || [];
|
||||
dropParams = [...new Set([...dropParams, ...searchExcludeParams])];
|
||||
}
|
||||
|
||||
if (dropParams && Array.isArray(dropParams)) {
|
||||
dropParams.forEach((param) => {
|
||||
delete llmConfig[param];
|
||||
if (llmConfig[param]) {
|
||||
llmConfig[param] = undefined;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let useOpenRouter;
|
||||
/** @type {OpenAIClientOptions['configuration']} */
|
||||
const configOptions = {};
|
||||
|
||||
// Handle OpenRouter or custom reverse proxy
|
||||
if (useOpenRouter || reverseProxyUrl === 'https://openrouter.ai/api/v1') {
|
||||
configOptions.baseURL = 'https://openrouter.ai/api/v1';
|
||||
if (
|
||||
(reverseProxyUrl && reverseProxyUrl.includes(KnownEndpoints.openrouter)) ||
|
||||
(endpoint && endpoint.toLowerCase().includes(KnownEndpoints.openrouter))
|
||||
) {
|
||||
useOpenRouter = true;
|
||||
llmConfig.include_reasoning = true;
|
||||
configOptions.baseURL = reverseProxyUrl;
|
||||
configOptions.defaultHeaders = Object.assign(
|
||||
{
|
||||
'HTTP-Referer': 'https://librechat.ai',
|
||||
|
|
@ -118,6 +146,13 @@ function getLLMConfig(apiKey, options = {}) {
|
|||
llmConfig.organization = process.env.OPENAI_ORGANIZATION;
|
||||
}
|
||||
|
||||
if (useOpenRouter && llmConfig.reasoning_effort != null) {
|
||||
llmConfig.reasoning = {
|
||||
effort: llmConfig.reasoning_effort,
|
||||
};
|
||||
delete llmConfig.reasoning_effort;
|
||||
}
|
||||
|
||||
return {
|
||||
/** @type {OpenAIClientOptions} */
|
||||
llmConfig,
|
||||
|
|
|
|||
196
api/server/services/Files/Azure/crud.js
Normal file
196
api/server/services/Files/Azure/crud.js
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const axios = require('axios');
|
||||
const fetch = require('node-fetch');
|
||||
const { logger } = require('~/config');
|
||||
const { getAzureContainerClient } = require('./initialize');
|
||||
|
||||
const defaultBasePath = 'images';
|
||||
|
||||
/**
|
||||
* Uploads a buffer to Azure Blob Storage.
|
||||
*
|
||||
* Files will be stored at the path: {basePath}/{userId}/{fileName} within the container.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId - The user's id.
|
||||
* @param {Buffer} params.buffer - The buffer to upload.
|
||||
* @param {string} params.fileName - The name of the file.
|
||||
* @param {string} [params.basePath='images'] - The base folder within the container.
|
||||
* @param {string} [params.containerName] - The Azure Blob container name.
|
||||
* @returns {Promise<string>} The URL of the uploaded blob.
|
||||
*/
|
||||
async function saveBufferToAzure({
|
||||
userId,
|
||||
buffer,
|
||||
fileName,
|
||||
basePath = defaultBasePath,
|
||||
containerName,
|
||||
}) {
|
||||
try {
|
||||
const containerClient = getAzureContainerClient(containerName);
|
||||
// Create the container if it doesn't exist. This is done per operation.
|
||||
await containerClient.createIfNotExists({
|
||||
access: process.env.AZURE_STORAGE_PUBLIC_ACCESS ? 'blob' : undefined,
|
||||
});
|
||||
const blobPath = `${basePath}/${userId}/${fileName}`;
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
|
||||
await blockBlobClient.uploadData(buffer);
|
||||
return blockBlobClient.url;
|
||||
} catch (error) {
|
||||
logger.error('[saveBufferToAzure] Error uploading buffer:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves a file from a URL to Azure Blob Storage.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId - The user's id.
|
||||
* @param {string} params.URL - The URL of the file.
|
||||
* @param {string} params.fileName - The name of the file.
|
||||
* @param {string} [params.basePath='images'] - The base folder within the container.
|
||||
* @param {string} [params.containerName] - The Azure Blob container name.
|
||||
* @returns {Promise<string>} The URL of the uploaded blob.
|
||||
*/
|
||||
async function saveURLToAzure({
|
||||
userId,
|
||||
URL,
|
||||
fileName,
|
||||
basePath = defaultBasePath,
|
||||
containerName,
|
||||
}) {
|
||||
try {
|
||||
const response = await fetch(URL);
|
||||
const buffer = await response.buffer();
|
||||
return await saveBufferToAzure({ userId, buffer, fileName, basePath, containerName });
|
||||
} catch (error) {
|
||||
logger.error('[saveURLToAzure] Error uploading file from URL:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves a blob URL from Azure Blob Storage.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.fileName - The file name.
|
||||
* @param {string} [params.basePath='images'] - The base folder used during upload.
|
||||
* @param {string} [params.userId] - If files are stored in a user-specific directory.
|
||||
* @param {string} [params.containerName] - The Azure Blob container name.
|
||||
* @returns {Promise<string>} The blob's URL.
|
||||
*/
|
||||
async function getAzureURL({ fileName, basePath = defaultBasePath, userId, containerName }) {
|
||||
try {
|
||||
const containerClient = getAzureContainerClient(containerName);
|
||||
const blobPath = userId ? `${basePath}/${userId}/${fileName}` : `${basePath}/${fileName}`;
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
|
||||
return blockBlobClient.url;
|
||||
} catch (error) {
|
||||
logger.error('[getAzureURL] Error retrieving blob URL:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes a blob from Azure Blob Storage.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.fileName - The name of the file.
|
||||
* @param {string} [params.basePath='images'] - The base folder where the file is stored.
|
||||
* @param {string} params.userId - The user's id.
|
||||
* @param {string} [params.containerName] - The Azure Blob container name.
|
||||
*/
|
||||
async function deleteFileFromAzure({
|
||||
fileName,
|
||||
basePath = defaultBasePath,
|
||||
userId,
|
||||
containerName,
|
||||
}) {
|
||||
try {
|
||||
const containerClient = getAzureContainerClient(containerName);
|
||||
const blobPath = `${basePath}/${userId}/${fileName}`;
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
|
||||
await blockBlobClient.delete();
|
||||
logger.debug('[deleteFileFromAzure] Blob deleted successfully from Azure Blob Storage');
|
||||
} catch (error) {
|
||||
logger.error('[deleteFileFromAzure] Error deleting blob:', error.message);
|
||||
if (error.statusCode === 404) {
|
||||
return;
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Uploads a file from the local file system to Azure Blob Storage.
|
||||
*
|
||||
* This function reads the file from disk and then uploads it to Azure Blob Storage
|
||||
* at the path: {basePath}/{userId}/{fileName}.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {object} params.req - The Express request object.
|
||||
* @param {Express.Multer.File} params.file - The file object.
|
||||
* @param {string} params.file_id - The file id.
|
||||
* @param {string} [params.basePath='images'] - The base folder within the container.
|
||||
* @param {string} [params.containerName] - The Azure Blob container name.
|
||||
* @returns {Promise<{ filepath: string, bytes: number }>} An object containing the blob URL and its byte size.
|
||||
*/
|
||||
async function uploadFileToAzure({
|
||||
req,
|
||||
file,
|
||||
file_id,
|
||||
basePath = defaultBasePath,
|
||||
containerName,
|
||||
}) {
|
||||
try {
|
||||
const inputFilePath = file.path;
|
||||
const inputBuffer = await fs.promises.readFile(inputFilePath);
|
||||
const bytes = Buffer.byteLength(inputBuffer);
|
||||
const userId = req.user.id;
|
||||
const fileName = `${file_id}__${path.basename(inputFilePath)}`;
|
||||
const fileURL = await saveBufferToAzure({
|
||||
userId,
|
||||
buffer: inputBuffer,
|
||||
fileName,
|
||||
basePath,
|
||||
containerName,
|
||||
});
|
||||
await fs.promises.unlink(inputFilePath);
|
||||
return { filepath: fileURL, bytes };
|
||||
} catch (error) {
|
||||
logger.error('[uploadFileToAzure] Error uploading file:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves a readable stream for a blob from Azure Blob Storage.
|
||||
*
|
||||
* @param {object} _req - The Express request object.
|
||||
* @param {string} fileURL - The URL of the blob.
|
||||
* @returns {Promise<ReadableStream>} A readable stream of the blob.
|
||||
*/
|
||||
async function getAzureFileStream(_req, fileURL) {
|
||||
try {
|
||||
const response = await axios({
|
||||
method: 'get',
|
||||
url: fileURL,
|
||||
responseType: 'stream',
|
||||
});
|
||||
return response.data;
|
||||
} catch (error) {
|
||||
logger.error('[getAzureFileStream] Error getting blob stream:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
saveBufferToAzure,
|
||||
saveURLToAzure,
|
||||
getAzureURL,
|
||||
deleteFileFromAzure,
|
||||
uploadFileToAzure,
|
||||
getAzureFileStream,
|
||||
};
|
||||
124
api/server/services/Files/Azure/images.js
Normal file
124
api/server/services/Files/Azure/images.js
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const sharp = require('sharp');
|
||||
const { resizeImageBuffer } = require('../images/resize');
|
||||
const { updateUser } = require('~/models/userMethods');
|
||||
const { updateFile } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
const { saveBufferToAzure } = require('./crud');
|
||||
|
||||
/**
|
||||
* Uploads an image file to Azure Blob Storage.
|
||||
* It resizes and converts the image similar to your Firebase implementation.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {object} params.req - The Express request object.
|
||||
* @param {Express.Multer.File} params.file - The file object.
|
||||
* @param {string} params.file_id - The file id.
|
||||
* @param {EModelEndpoint} params.endpoint - The endpoint parameters.
|
||||
* @param {string} [params.resolution='high'] - The image resolution.
|
||||
* @param {string} [params.basePath='images'] - The base folder within the container.
|
||||
* @param {string} [params.containerName] - The Azure Blob container name.
|
||||
* @returns {Promise<{ filepath: string, bytes: number, width: number, height: number }>}
|
||||
*/
|
||||
async function uploadImageToAzure({
|
||||
req,
|
||||
file,
|
||||
file_id,
|
||||
endpoint,
|
||||
resolution = 'high',
|
||||
basePath = 'images',
|
||||
containerName,
|
||||
}) {
|
||||
try {
|
||||
const inputFilePath = file.path;
|
||||
const inputBuffer = await fs.promises.readFile(inputFilePath);
|
||||
const {
|
||||
buffer: resizedBuffer,
|
||||
width,
|
||||
height,
|
||||
} = await resizeImageBuffer(inputBuffer, resolution, endpoint);
|
||||
const extension = path.extname(inputFilePath);
|
||||
const userId = req.user.id;
|
||||
let webPBuffer;
|
||||
let fileName = `${file_id}__${path.basename(inputFilePath)}`;
|
||||
const targetExtension = `.${req.app.locals.imageOutputType}`;
|
||||
|
||||
if (extension.toLowerCase() === targetExtension) {
|
||||
webPBuffer = resizedBuffer;
|
||||
} else {
|
||||
webPBuffer = await sharp(resizedBuffer).toFormat(req.app.locals.imageOutputType).toBuffer();
|
||||
const extRegExp = new RegExp(path.extname(fileName) + '$');
|
||||
fileName = fileName.replace(extRegExp, targetExtension);
|
||||
if (!path.extname(fileName)) {
|
||||
fileName += targetExtension;
|
||||
}
|
||||
}
|
||||
const downloadURL = await saveBufferToAzure({
|
||||
userId,
|
||||
buffer: webPBuffer,
|
||||
fileName,
|
||||
basePath,
|
||||
containerName,
|
||||
});
|
||||
await fs.promises.unlink(inputFilePath);
|
||||
const bytes = Buffer.byteLength(webPBuffer);
|
||||
return { filepath: downloadURL, bytes, width, height };
|
||||
} catch (error) {
|
||||
logger.error('[uploadImageToAzure] Error uploading image:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Prepares the image URL and updates the file record.
|
||||
*
|
||||
* @param {object} req - The Express request object.
|
||||
* @param {MongoFile} file - The file object.
|
||||
* @returns {Promise<[MongoFile, string]>}
|
||||
*/
|
||||
async function prepareAzureImageURL(req, file) {
|
||||
const { filepath } = file;
|
||||
const promises = [];
|
||||
promises.push(updateFile({ file_id: file.file_id }));
|
||||
promises.push(filepath);
|
||||
return await Promise.all(promises);
|
||||
}
|
||||
|
||||
/**
|
||||
* Uploads and processes a user's avatar to Azure Blob Storage.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {Buffer} params.buffer - The avatar image buffer.
|
||||
* @param {string} params.userId - The user's id.
|
||||
* @param {string} params.manual - Flag to indicate manual update.
|
||||
* @param {string} [params.basePath='images'] - The base folder within the container.
|
||||
* @param {string} [params.containerName] - The Azure Blob container name.
|
||||
* @returns {Promise<string>} The URL of the avatar.
|
||||
*/
|
||||
async function processAzureAvatar({ buffer, userId, manual, basePath = 'images', containerName }) {
|
||||
try {
|
||||
const downloadURL = await saveBufferToAzure({
|
||||
userId,
|
||||
buffer,
|
||||
fileName: 'avatar.png',
|
||||
basePath,
|
||||
containerName,
|
||||
});
|
||||
const isManual = manual === 'true';
|
||||
const url = `${downloadURL}?manual=${isManual}`;
|
||||
if (isManual) {
|
||||
await updateUser(userId, { avatar: url });
|
||||
}
|
||||
return url;
|
||||
} catch (error) {
|
||||
logger.error('[processAzureAvatar] Error uploading profile picture to Azure:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
uploadImageToAzure,
|
||||
prepareAzureImageURL,
|
||||
processAzureAvatar,
|
||||
};
|
||||
9
api/server/services/Files/Azure/index.js
Normal file
9
api/server/services/Files/Azure/index.js
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
const crud = require('./crud');
|
||||
const images = require('./images');
|
||||
const initialize = require('./initialize');
|
||||
|
||||
module.exports = {
|
||||
...crud,
|
||||
...images,
|
||||
...initialize,
|
||||
};
|
||||
55
api/server/services/Files/Azure/initialize.js
Normal file
55
api/server/services/Files/Azure/initialize.js
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
const { BlobServiceClient } = require('@azure/storage-blob');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
let blobServiceClient = null;
|
||||
let azureWarningLogged = false;
|
||||
|
||||
/**
|
||||
* Initializes the Azure Blob Service client.
|
||||
* This function establishes a connection by checking if a connection string is provided.
|
||||
* If available, the connection string is used; otherwise, Managed Identity (via DefaultAzureCredential) is utilized.
|
||||
* Note: Container creation (and its public access settings) is handled later in the CRUD functions.
|
||||
* @returns {BlobServiceClient|null} The initialized client, or null if the required configuration is missing.
|
||||
*/
|
||||
const initializeAzureBlobService = () => {
|
||||
if (blobServiceClient) {
|
||||
return blobServiceClient;
|
||||
}
|
||||
const connectionString = process.env.AZURE_STORAGE_CONNECTION_STRING;
|
||||
if (connectionString) {
|
||||
blobServiceClient = BlobServiceClient.fromConnectionString(connectionString);
|
||||
logger.info('Azure Blob Service initialized using connection string');
|
||||
} else {
|
||||
const { DefaultAzureCredential } = require('@azure/identity');
|
||||
const accountName = process.env.AZURE_STORAGE_ACCOUNT_NAME;
|
||||
if (!accountName) {
|
||||
if (!azureWarningLogged) {
|
||||
logger.error(
|
||||
'[initializeAzureBlobService] Azure Blob Service not initialized. Connection string missing and AZURE_STORAGE_ACCOUNT_NAME not provided.',
|
||||
);
|
||||
azureWarningLogged = true;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
const url = `https://${accountName}.blob.core.windows.net`;
|
||||
const credential = new DefaultAzureCredential();
|
||||
blobServiceClient = new BlobServiceClient(url, credential);
|
||||
logger.info('Azure Blob Service initialized using Managed Identity');
|
||||
}
|
||||
return blobServiceClient;
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves the Azure ContainerClient for the given container name.
|
||||
* @param {string} [containerName=process.env.AZURE_CONTAINER_NAME || 'files'] - The container name.
|
||||
* @returns {ContainerClient|null} The Azure ContainerClient.
|
||||
*/
|
||||
const getAzureContainerClient = (containerName = process.env.AZURE_CONTAINER_NAME || 'files') => {
|
||||
const serviceClient = initializeAzureBlobService();
|
||||
return serviceClient ? serviceClient.getContainerClient(containerName) : null;
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
initializeAzureBlobService,
|
||||
getAzureContainerClient,
|
||||
};
|
||||
|
|
@ -1,7 +1,9 @@
|
|||
// Code Files
|
||||
const axios = require('axios');
|
||||
const FormData = require('form-data');
|
||||
const { getCodeBaseURL } = require('@librechat/agents');
|
||||
const { createAxiosInstance } = require('~/config');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
|
||||
const axios = createAxiosInstance();
|
||||
|
||||
const MAX_FILE_SIZE = 150 * 1024 * 1024;
|
||||
|
||||
|
|
@ -15,7 +17,8 @@ const MAX_FILE_SIZE = 150 * 1024 * 1024;
|
|||
async function getCodeOutputDownloadStream(fileIdentifier, apiKey) {
|
||||
try {
|
||||
const baseURL = getCodeBaseURL();
|
||||
const response = await axios({
|
||||
/** @type {import('axios').AxiosRequestConfig} */
|
||||
const options = {
|
||||
method: 'get',
|
||||
url: `${baseURL}/download/${fileIdentifier}`,
|
||||
responseType: 'stream',
|
||||
|
|
@ -24,10 +27,15 @@ async function getCodeOutputDownloadStream(fileIdentifier, apiKey) {
|
|||
'X-API-Key': apiKey,
|
||||
},
|
||||
timeout: 15000,
|
||||
});
|
||||
};
|
||||
|
||||
const response = await axios(options);
|
||||
return response;
|
||||
} catch (error) {
|
||||
logAxiosError({
|
||||
message: `Error downloading code environment file stream: ${error.message}`,
|
||||
error,
|
||||
});
|
||||
throw new Error(`Error downloading file: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
|
@ -53,7 +61,8 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = ''
|
|||
form.append('file', stream, filename);
|
||||
|
||||
const baseURL = getCodeBaseURL();
|
||||
const response = await axios.post(`${baseURL}/upload`, form, {
|
||||
/** @type {import('axios').AxiosRequestConfig} */
|
||||
const options = {
|
||||
headers: {
|
||||
...form.getHeaders(),
|
||||
'Content-Type': 'multipart/form-data',
|
||||
|
|
@ -63,7 +72,9 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = ''
|
|||
},
|
||||
maxContentLength: MAX_FILE_SIZE,
|
||||
maxBodyLength: MAX_FILE_SIZE,
|
||||
});
|
||||
};
|
||||
|
||||
const response = await axios.post(`${baseURL}/upload`, form, options);
|
||||
|
||||
/** @type {{ message: string; session_id: string; files: Array<{ fileId: string; filename: string }> }} */
|
||||
const result = response.data;
|
||||
|
|
@ -78,7 +89,11 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = ''
|
|||
|
||||
return `${fileIdentifier}?entity_id=${entity_id}`;
|
||||
} catch (error) {
|
||||
throw new Error(`Error uploading file: ${error.message}`);
|
||||
logAxiosError({
|
||||
message: `Error uploading code environment file: ${error.message}`,
|
||||
error,
|
||||
});
|
||||
throw new Error(`Error uploading code environment file: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ const {
|
|||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { convertImage } = require('~/server/services/Files/images/convert');
|
||||
const { createFile, getFiles, updateFile } = require('~/models/File');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
|
|
@ -85,7 +86,10 @@ const processCodeOutput = async ({
|
|||
/** Note: `messageId` & `toolCallId` are not part of file DB schema; message object records associated file ID */
|
||||
return Object.assign(file, { messageId, toolCallId });
|
||||
} catch (error) {
|
||||
logger.error('Error downloading file:', error);
|
||||
logAxiosError({
|
||||
message: 'Error downloading code environment file',
|
||||
error,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -135,7 +139,10 @@ async function getSessionInfo(fileIdentifier, apiKey) {
|
|||
|
||||
return response.data.find((file) => file.name.startsWith(path))?.lastModified;
|
||||
} catch (error) {
|
||||
logger.error(`Error fetching session info: ${error.message}`, error);
|
||||
logAxiosError({
|
||||
message: `Error fetching session info: ${error.message}`,
|
||||
error,
|
||||
});
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
|
@ -202,7 +209,7 @@ const primeFiles = async (options, apiKey) => {
|
|||
const { handleFileUpload: uploadCodeEnvFile } = getStrategyFunctions(
|
||||
FileSources.execute_code,
|
||||
);
|
||||
const stream = await getDownloadStream(file.filepath);
|
||||
const stream = await getDownloadStream(options.req, file.filepath);
|
||||
const fileIdentifier = await uploadCodeEnvFile({
|
||||
req: options.req,
|
||||
stream,
|
||||
|
|
|
|||
|
|
@ -224,10 +224,11 @@ async function uploadFileToFirebase({ req, file, file_id }) {
|
|||
/**
|
||||
* Retrieves a readable stream for a file from Firebase storage.
|
||||
*
|
||||
* @param {ServerRequest} _req
|
||||
* @param {string} filepath - The filepath.
|
||||
* @returns {Promise<ReadableStream>} A readable stream of the file.
|
||||
*/
|
||||
async function getFirebaseFileStream(filepath) {
|
||||
async function getFirebaseFileStream(_req, filepath) {
|
||||
try {
|
||||
const storage = getFirebaseStorage();
|
||||
if (!storage) {
|
||||
|
|
|
|||
|
|
@ -175,6 +175,17 @@ const isValidPath = (req, base, subfolder, filepath) => {
|
|||
return normalizedFilepath.startsWith(normalizedBase);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {string} filepath
|
||||
*/
|
||||
const unlinkFile = async (filepath) => {
|
||||
try {
|
||||
await fs.promises.unlink(filepath);
|
||||
} catch (error) {
|
||||
logger.error('Error deleting file:', error);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes a file from the filesystem. This function takes a file object, constructs the full path, and
|
||||
* verifies the path's validity before deleting the file. If the path is invalid, an error is thrown.
|
||||
|
|
@ -217,7 +228,7 @@ const deleteLocalFile = async (req, file) => {
|
|||
throw new Error(`Invalid file path: ${file.filepath}`);
|
||||
}
|
||||
|
||||
await fs.promises.unlink(filepath);
|
||||
await unlinkFile(filepath);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -233,7 +244,7 @@ const deleteLocalFile = async (req, file) => {
|
|||
throw new Error('Invalid file path');
|
||||
}
|
||||
|
||||
await fs.promises.unlink(filepath);
|
||||
await unlinkFile(filepath);
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -275,11 +286,31 @@ async function uploadLocalFile({ req, file, file_id }) {
|
|||
/**
|
||||
* Retrieves a readable stream for a file from local storage.
|
||||
*
|
||||
* @param {ServerRequest} req - The request object from Express
|
||||
* @param {string} filepath - The filepath.
|
||||
* @returns {ReadableStream} A readable stream of the file.
|
||||
*/
|
||||
function getLocalFileStream(filepath) {
|
||||
function getLocalFileStream(req, filepath) {
|
||||
try {
|
||||
if (filepath.includes('/uploads/')) {
|
||||
const basePath = filepath.split('/uploads/')[1];
|
||||
|
||||
if (!basePath) {
|
||||
logger.warn(`Invalid base path: ${filepath}`);
|
||||
throw new Error(`Invalid file path: ${filepath}`);
|
||||
}
|
||||
|
||||
const fullPath = path.join(req.app.locals.paths.uploads, basePath);
|
||||
const uploadsDir = req.app.locals.paths.uploads;
|
||||
|
||||
const rel = path.relative(uploadsDir, fullPath);
|
||||
if (rel.startsWith('..') || path.isAbsolute(rel) || rel.includes(`..${path.sep}`)) {
|
||||
logger.warn(`Invalid relative file path: ${filepath}`);
|
||||
throw new Error(`Invalid file path: ${filepath}`);
|
||||
}
|
||||
|
||||
return fs.createReadStream(fullPath);
|
||||
}
|
||||
return fs.createReadStream(filepath);
|
||||
} catch (error) {
|
||||
logger.error('Error getting local file stream:', error);
|
||||
|
|
|
|||
207
api/server/services/Files/MistralOCR/crud.js
Normal file
207
api/server/services/Files/MistralOCR/crud.js
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
// ~/server/services/Files/MistralOCR/crud.js
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const FormData = require('form-data');
|
||||
const { FileSources, envVarRegex, extractEnvVariable } = require('librechat-data-provider');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { logger, createAxiosInstance } = require('~/config');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
|
||||
const axios = createAxiosInstance();
|
||||
|
||||
/**
|
||||
* Uploads a document to Mistral API using file streaming to avoid loading the entire file into memory
|
||||
*
|
||||
* @param {Object} params Upload parameters
|
||||
* @param {string} params.filePath The path to the file on disk
|
||||
* @param {string} [params.fileName] Optional filename to use (defaults to the name from filePath)
|
||||
* @param {string} params.apiKey Mistral API key
|
||||
* @param {string} [params.baseURL=https://api.mistral.ai/v1] Mistral API base URL
|
||||
* @returns {Promise<Object>} The response from Mistral API
|
||||
*/
|
||||
async function uploadDocumentToMistral({
|
||||
filePath,
|
||||
fileName = '',
|
||||
apiKey,
|
||||
baseURL = 'https://api.mistral.ai/v1',
|
||||
}) {
|
||||
const form = new FormData();
|
||||
form.append('purpose', 'ocr');
|
||||
const actualFileName = fileName || path.basename(filePath);
|
||||
const fileStream = fs.createReadStream(filePath);
|
||||
form.append('file', fileStream, { filename: actualFileName });
|
||||
|
||||
return axios
|
||||
.post(`${baseURL}/files`, form, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
...form.getHeaders(),
|
||||
},
|
||||
maxBodyLength: Infinity,
|
||||
maxContentLength: Infinity,
|
||||
})
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
logger.error('Error uploading document to Mistral:', error.message);
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
|
||||
async function getSignedUrl({
|
||||
apiKey,
|
||||
fileId,
|
||||
expiry = 24,
|
||||
baseURL = 'https://api.mistral.ai/v1',
|
||||
}) {
|
||||
return axios
|
||||
.get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
})
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
logger.error('Error fetching signed URL:', error.message);
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {Object} params
|
||||
* @param {string} params.apiKey
|
||||
* @param {string} params.documentUrl
|
||||
* @param {string} [params.baseURL]
|
||||
* @returns {Promise<OCRResult>}
|
||||
*/
|
||||
async function performOCR({
|
||||
apiKey,
|
||||
documentUrl,
|
||||
model = 'mistral-ocr-latest',
|
||||
baseURL = 'https://api.mistral.ai/v1',
|
||||
}) {
|
||||
return axios
|
||||
.post(
|
||||
`${baseURL}/ocr`,
|
||||
{
|
||||
model,
|
||||
include_image_base64: false,
|
||||
document: {
|
||||
type: 'document_url',
|
||||
document_url: documentUrl,
|
||||
},
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
},
|
||||
)
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
logger.error('Error performing OCR:', error.message);
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
|
||||
function extractVariableName(str) {
|
||||
const match = str.match(envVarRegex);
|
||||
return match ? match[1] : null;
|
||||
}
|
||||
|
||||
const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => {
|
||||
try {
|
||||
/** @type {TCustomConfig['ocr']} */
|
||||
const ocrConfig = req.app.locals?.ocr;
|
||||
|
||||
const apiKeyConfig = ocrConfig.apiKey || '';
|
||||
const baseURLConfig = ocrConfig.baseURL || '';
|
||||
|
||||
const isApiKeyEnvVar = envVarRegex.test(apiKeyConfig);
|
||||
const isBaseURLEnvVar = envVarRegex.test(baseURLConfig);
|
||||
|
||||
const isApiKeyEmpty = !apiKeyConfig.trim();
|
||||
const isBaseURLEmpty = !baseURLConfig.trim();
|
||||
|
||||
let apiKey, baseURL;
|
||||
|
||||
if (isApiKeyEnvVar || isBaseURLEnvVar || isApiKeyEmpty || isBaseURLEmpty) {
|
||||
const apiKeyVarName = isApiKeyEnvVar ? extractVariableName(apiKeyConfig) : 'OCR_API_KEY';
|
||||
const baseURLVarName = isBaseURLEnvVar ? extractVariableName(baseURLConfig) : 'OCR_BASEURL';
|
||||
|
||||
const authValues = await loadAuthValues({
|
||||
userId: req.user.id,
|
||||
authFields: [baseURLVarName, apiKeyVarName],
|
||||
optional: new Set([baseURLVarName]),
|
||||
});
|
||||
|
||||
apiKey = authValues[apiKeyVarName];
|
||||
baseURL = authValues[baseURLVarName];
|
||||
} else {
|
||||
apiKey = apiKeyConfig;
|
||||
baseURL = baseURLConfig;
|
||||
}
|
||||
|
||||
const mistralFile = await uploadDocumentToMistral({
|
||||
filePath: file.path,
|
||||
fileName: file.originalname,
|
||||
apiKey,
|
||||
baseURL,
|
||||
});
|
||||
|
||||
const modelConfig = ocrConfig.mistralModel || '';
|
||||
const model = envVarRegex.test(modelConfig)
|
||||
? extractEnvVariable(modelConfig)
|
||||
: modelConfig.trim() || 'mistral-ocr-latest';
|
||||
|
||||
const signedUrlResponse = await getSignedUrl({
|
||||
apiKey,
|
||||
baseURL,
|
||||
fileId: mistralFile.id,
|
||||
});
|
||||
|
||||
const ocrResult = await performOCR({
|
||||
apiKey,
|
||||
baseURL,
|
||||
model,
|
||||
documentUrl: signedUrlResponse.url,
|
||||
});
|
||||
|
||||
let aggregatedText = '';
|
||||
const images = [];
|
||||
ocrResult.pages.forEach((page, index) => {
|
||||
if (ocrResult.pages.length > 1) {
|
||||
aggregatedText += `# PAGE ${index + 1}\n`;
|
||||
}
|
||||
|
||||
aggregatedText += page.markdown + '\n\n';
|
||||
|
||||
if (page.images && page.images.length > 0) {
|
||||
page.images.forEach((image) => {
|
||||
if (image.image_base64) {
|
||||
images.push(image.image_base64);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
return {
|
||||
filename: file.originalname,
|
||||
bytes: aggregatedText.length * 4,
|
||||
filepath: FileSources.mistral_ocr,
|
||||
text: aggregatedText,
|
||||
images,
|
||||
};
|
||||
} catch (error) {
|
||||
const message = 'Error uploading document to Mistral OCR API';
|
||||
logAxiosError({ error, message });
|
||||
throw new Error(message);
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
uploadDocumentToMistral,
|
||||
uploadMistralOCR,
|
||||
getSignedUrl,
|
||||
performOCR,
|
||||
};
|
||||
737
api/server/services/Files/MistralOCR/crud.spec.js
Normal file
737
api/server/services/Files/MistralOCR/crud.spec.js
Normal file
|
|
@ -0,0 +1,737 @@
|
|||
const fs = require('fs');
|
||||
|
||||
const mockAxios = {
|
||||
interceptors: {
|
||||
request: { use: jest.fn(), eject: jest.fn() },
|
||||
response: { use: jest.fn(), eject: jest.fn() },
|
||||
},
|
||||
create: jest.fn().mockReturnValue({
|
||||
defaults: {
|
||||
proxy: null,
|
||||
},
|
||||
get: jest.fn().mockResolvedValue({ data: {} }),
|
||||
post: jest.fn().mockResolvedValue({ data: {} }),
|
||||
put: jest.fn().mockResolvedValue({ data: {} }),
|
||||
delete: jest.fn().mockResolvedValue({ data: {} }),
|
||||
}),
|
||||
get: jest.fn().mockResolvedValue({ data: {} }),
|
||||
post: jest.fn().mockResolvedValue({ data: {} }),
|
||||
put: jest.fn().mockResolvedValue({ data: {} }),
|
||||
delete: jest.fn().mockResolvedValue({ data: {} }),
|
||||
reset: jest.fn().mockImplementation(function () {
|
||||
this.get.mockClear();
|
||||
this.post.mockClear();
|
||||
this.put.mockClear();
|
||||
this.delete.mockClear();
|
||||
this.create.mockClear();
|
||||
}),
|
||||
};
|
||||
|
||||
jest.mock('axios', () => mockAxios);
|
||||
jest.mock('fs');
|
||||
jest.mock('~/utils', () => ({
|
||||
logAxiosError: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
},
|
||||
createAxiosInstance: () => mockAxios,
|
||||
}));
|
||||
jest.mock('~/server/services/Tools/credentials', () => ({
|
||||
loadAuthValues: jest.fn(),
|
||||
}));
|
||||
|
||||
const { uploadDocumentToMistral, uploadMistralOCR, getSignedUrl, performOCR } = require('./crud');
|
||||
|
||||
describe('MistralOCR Service', () => {
|
||||
afterEach(() => {
|
||||
mockAxios.reset();
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('uploadDocumentToMistral', () => {
|
||||
beforeEach(() => {
|
||||
// Create a more complete mock for file streams that FormData can work with
|
||||
const mockReadStream = {
|
||||
on: jest.fn().mockImplementation(function (event, handler) {
|
||||
// Simulate immediate 'end' event to make FormData complete processing
|
||||
if (event === 'end') {
|
||||
handler();
|
||||
}
|
||||
return this;
|
||||
}),
|
||||
pipe: jest.fn().mockImplementation(function () {
|
||||
return this;
|
||||
}),
|
||||
pause: jest.fn(),
|
||||
resume: jest.fn(),
|
||||
emit: jest.fn(),
|
||||
once: jest.fn(),
|
||||
destroy: jest.fn(),
|
||||
};
|
||||
|
||||
fs.createReadStream = jest.fn().mockReturnValue(mockReadStream);
|
||||
|
||||
// Mock FormData's append to avoid actual stream processing
|
||||
jest.mock('form-data', () => {
|
||||
const mockFormData = function () {
|
||||
return {
|
||||
append: jest.fn(),
|
||||
getHeaders: jest
|
||||
.fn()
|
||||
.mockReturnValue({ 'content-type': 'multipart/form-data; boundary=---boundary' }),
|
||||
getBuffer: jest.fn().mockReturnValue(Buffer.from('mock-form-data')),
|
||||
getLength: jest.fn().mockReturnValue(100),
|
||||
};
|
||||
};
|
||||
return mockFormData;
|
||||
});
|
||||
});
|
||||
|
||||
it('should upload a document to Mistral API using file streaming', async () => {
|
||||
const mockResponse = { data: { id: 'file-123', purpose: 'ocr' } };
|
||||
mockAxios.post.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
const result = await uploadDocumentToMistral({
|
||||
filePath: '/path/to/test.pdf',
|
||||
fileName: 'test.pdf',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
// Check that createReadStream was called with the correct file path
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/path/to/test.pdf');
|
||||
|
||||
// Since we're mocking FormData, we'll just check that axios was called correctly
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files',
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: 'Bearer test-api-key',
|
||||
}),
|
||||
maxBodyLength: Infinity,
|
||||
maxContentLength: Infinity,
|
||||
}),
|
||||
);
|
||||
expect(result).toEqual(mockResponse.data);
|
||||
});
|
||||
|
||||
it('should handle errors during document upload', async () => {
|
||||
const errorMessage = 'API error';
|
||||
mockAxios.post.mockRejectedValueOnce(new Error(errorMessage));
|
||||
|
||||
await expect(
|
||||
uploadDocumentToMistral({
|
||||
filePath: '/path/to/test.pdf',
|
||||
fileName: 'test.pdf',
|
||||
apiKey: 'test-api-key',
|
||||
}),
|
||||
).rejects.toThrow();
|
||||
|
||||
const { logger } = require('~/config');
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Error uploading document to Mistral:'),
|
||||
expect.any(String),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getSignedUrl', () => {
|
||||
it('should fetch signed URL from Mistral API', async () => {
|
||||
const mockResponse = { data: { url: 'https://document-url.com' } };
|
||||
mockAxios.get.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
const result = await getSignedUrl({
|
||||
fileId: 'file-123',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
expect(mockAxios.get).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files/file-123/url?expiry=24',
|
||||
{
|
||||
headers: {
|
||||
Authorization: 'Bearer test-api-key',
|
||||
},
|
||||
},
|
||||
);
|
||||
expect(result).toEqual(mockResponse.data);
|
||||
});
|
||||
|
||||
it('should handle errors when fetching signed URL', async () => {
|
||||
const errorMessage = 'API error';
|
||||
mockAxios.get.mockRejectedValueOnce(new Error(errorMessage));
|
||||
|
||||
await expect(
|
||||
getSignedUrl({
|
||||
fileId: 'file-123',
|
||||
apiKey: 'test-api-key',
|
||||
}),
|
||||
).rejects.toThrow();
|
||||
|
||||
const { logger } = require('~/config');
|
||||
expect(logger.error).toHaveBeenCalledWith('Error fetching signed URL:', errorMessage);
|
||||
});
|
||||
});
|
||||
|
||||
describe('performOCR', () => {
|
||||
it('should perform OCR using Mistral API', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
pages: [{ markdown: 'Page 1 content' }, { markdown: 'Page 2 content' }],
|
||||
},
|
||||
};
|
||||
mockAxios.post.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
const result = await performOCR({
|
||||
apiKey: 'test-api-key',
|
||||
documentUrl: 'https://document-url.com',
|
||||
model: 'mistral-ocr-latest',
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/ocr',
|
||||
{
|
||||
model: 'mistral-ocr-latest',
|
||||
include_image_base64: false,
|
||||
document: {
|
||||
type: 'document_url',
|
||||
document_url: 'https://document-url.com',
|
||||
},
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: 'Bearer test-api-key',
|
||||
},
|
||||
},
|
||||
);
|
||||
expect(result).toEqual(mockResponse.data);
|
||||
});
|
||||
|
||||
it('should handle errors during OCR processing', async () => {
|
||||
const errorMessage = 'OCR processing error';
|
||||
mockAxios.post.mockRejectedValueOnce(new Error(errorMessage));
|
||||
|
||||
await expect(
|
||||
performOCR({
|
||||
apiKey: 'test-api-key',
|
||||
documentUrl: 'https://document-url.com',
|
||||
}),
|
||||
).rejects.toThrow();
|
||||
|
||||
const { logger } = require('~/config');
|
||||
expect(logger.error).toHaveBeenCalledWith('Error performing OCR:', errorMessage);
|
||||
});
|
||||
});
|
||||
|
||||
describe('uploadMistralOCR', () => {
|
||||
beforeEach(() => {
|
||||
const mockReadStream = {
|
||||
on: jest.fn().mockImplementation(function (event, handler) {
|
||||
if (event === 'end') {
|
||||
handler();
|
||||
}
|
||||
return this;
|
||||
}),
|
||||
pipe: jest.fn().mockImplementation(function () {
|
||||
return this;
|
||||
}),
|
||||
pause: jest.fn(),
|
||||
resume: jest.fn(),
|
||||
emit: jest.fn(),
|
||||
once: jest.fn(),
|
||||
destroy: jest.fn(),
|
||||
};
|
||||
|
||||
fs.createReadStream = jest.fn().mockReturnValue(mockReadStream);
|
||||
});
|
||||
|
||||
it('should process OCR for a file with standard configuration', async () => {
|
||||
// Setup mocks
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'test-api-key',
|
||||
OCR_BASEURL: 'https://api.mistral.ai/v1',
|
||||
});
|
||||
|
||||
// Mock file upload response
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: { id: 'file-123', purpose: 'ocr' },
|
||||
});
|
||||
|
||||
// Mock signed URL response
|
||||
mockAxios.get.mockResolvedValueOnce({
|
||||
data: { url: 'https://signed-url.com' },
|
||||
});
|
||||
|
||||
// Mock OCR response with text and images
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
pages: [
|
||||
{
|
||||
markdown: 'Page 1 content',
|
||||
images: [{ image_base64: 'base64image1' }],
|
||||
},
|
||||
{
|
||||
markdown: 'Page 2 content',
|
||||
images: [{ image_base64: 'base64image2' }],
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
// Use environment variable syntax to ensure loadAuthValues is called
|
||||
apiKey: '${OCR_API_KEY}',
|
||||
baseURL: '${OCR_BASEURL}',
|
||||
mistralModel: 'mistral-medium',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'document.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Verify OCR result
|
||||
expect(result).toEqual({
|
||||
filename: 'document.pdf',
|
||||
bytes: expect.any(Number),
|
||||
filepath: 'mistral_ocr',
|
||||
text: expect.stringContaining('# PAGE 1'),
|
||||
images: ['base64image1', 'base64image2'],
|
||||
});
|
||||
});
|
||||
|
||||
it('should process variable references in configuration', async () => {
|
||||
// Setup mocks with environment variables
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
CUSTOM_API_KEY: 'custom-api-key',
|
||||
CUSTOM_BASEURL: 'https://custom-api.mistral.ai/v1',
|
||||
});
|
||||
|
||||
// Mock API responses
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: { id: 'file-123', purpose: 'ocr' },
|
||||
});
|
||||
mockAxios.get.mockResolvedValueOnce({
|
||||
data: { url: 'https://signed-url.com' },
|
||||
});
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
pages: [{ markdown: 'Content from custom API' }],
|
||||
},
|
||||
});
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
apiKey: '${CUSTOM_API_KEY}',
|
||||
baseURL: '${CUSTOM_BASEURL}',
|
||||
mistralModel: '${CUSTOM_MODEL}',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Set environment variable for model
|
||||
process.env.CUSTOM_MODEL = 'mistral-large';
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'document.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Verify that custom environment variables were extracted and used
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
authFields: ['CUSTOM_BASEURL', 'CUSTOM_API_KEY'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Check that mistral-large was used in the OCR API call
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
model: 'mistral-large',
|
||||
}),
|
||||
expect.anything(),
|
||||
);
|
||||
|
||||
expect(result.text).toEqual('Content from custom API\n\n');
|
||||
});
|
||||
|
||||
it('should fall back to default values when variables are not properly formatted', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'default-api-key',
|
||||
OCR_BASEURL: undefined, // Testing optional parameter
|
||||
});
|
||||
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: { id: 'file-123', purpose: 'ocr' },
|
||||
});
|
||||
mockAxios.get.mockResolvedValueOnce({
|
||||
data: { url: 'https://signed-url.com' },
|
||||
});
|
||||
mockAxios.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
pages: [{ markdown: 'Default API result' }],
|
||||
},
|
||||
});
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
// Use environment variable syntax to ensure loadAuthValues is called
|
||||
apiKey: '${INVALID_FORMAT}', // Using valid env var format but with an invalid name
|
||||
baseURL: '${OCR_BASEURL}', // Using valid env var format
|
||||
mistralModel: 'mistral-ocr-latest', // Plain string value
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'document.pdf',
|
||||
};
|
||||
|
||||
await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Should use the default values
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
authFields: ['OCR_BASEURL', 'INVALID_FORMAT'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Should use the default model when not using environment variable format
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
model: 'mistral-ocr-latest',
|
||||
}),
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle API errors during OCR process', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'test-api-key',
|
||||
});
|
||||
|
||||
// Mock file upload to fail
|
||||
mockAxios.post.mockRejectedValueOnce(new Error('Upload failed'));
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
apiKey: 'OCR_API_KEY',
|
||||
baseURL: 'OCR_BASEURL',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'document.pdf',
|
||||
};
|
||||
|
||||
await expect(
|
||||
uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
}),
|
||||
).rejects.toThrow('Error uploading document to Mistral OCR API');
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
const { logAxiosError } = require('~/utils');
|
||||
expect(logAxiosError).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle single page documents without page numbering', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'test-api-key',
|
||||
OCR_BASEURL: 'https://api.mistral.ai/v1', // Make sure this is included
|
||||
});
|
||||
|
||||
// Clear all previous mocks
|
||||
mockAxios.post.mockClear();
|
||||
mockAxios.get.mockClear();
|
||||
|
||||
// 1. First mock: File upload response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
|
||||
);
|
||||
|
||||
// 2. Second mock: Signed URL response
|
||||
mockAxios.get.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
|
||||
);
|
||||
|
||||
// 3. Third mock: OCR response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
data: {
|
||||
pages: [{ markdown: 'Single page content' }],
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
apiKey: 'OCR_API_KEY',
|
||||
baseURL: 'OCR_BASEURL',
|
||||
mistralModel: 'mistral-ocr-latest',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'single-page.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Verify that single page documents don't include page numbering
|
||||
expect(result.text).not.toContain('# PAGE');
|
||||
expect(result.text).toEqual('Single page content\n\n');
|
||||
});
|
||||
|
||||
it('should use literal values in configuration when provided directly', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
// We'll still mock this but it should not be used for literal values
|
||||
loadAuthValues.mockResolvedValue({});
|
||||
|
||||
// Clear all previous mocks
|
||||
mockAxios.post.mockClear();
|
||||
mockAxios.get.mockClear();
|
||||
|
||||
// 1. First mock: File upload response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
|
||||
);
|
||||
|
||||
// 2. Second mock: Signed URL response
|
||||
mockAxios.get.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
|
||||
);
|
||||
|
||||
// 3. Third mock: OCR response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
data: {
|
||||
pages: [{ markdown: 'Processed with literal config values' }],
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
// Direct values that should be used as-is, without variable substitution
|
||||
apiKey: 'actual-api-key-value',
|
||||
baseURL: 'https://direct-api-url.mistral.ai/v1',
|
||||
mistralModel: 'mistral-direct-model',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'direct-values.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Verify the correct URL was used with the direct baseURL value
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://direct-api-url.mistral.ai/v1/files',
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: 'Bearer actual-api-key-value',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
// Check the OCR call was made with the direct model value
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://direct-api-url.mistral.ai/v1/ocr',
|
||||
expect.objectContaining({
|
||||
model: 'mistral-direct-model',
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
|
||||
// Verify the result
|
||||
expect(result.text).toEqual('Processed with literal config values\n\n');
|
||||
|
||||
// Verify loadAuthValues was never called since we used direct values
|
||||
expect(loadAuthValues).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle empty configuration values and use defaults', async () => {
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
// Set up the mock values to be returned by loadAuthValues
|
||||
loadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'default-from-env-key',
|
||||
OCR_BASEURL: 'https://default-from-env.mistral.ai/v1',
|
||||
});
|
||||
|
||||
// Clear all previous mocks
|
||||
mockAxios.post.mockClear();
|
||||
mockAxios.get.mockClear();
|
||||
|
||||
// 1. First mock: File upload response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
|
||||
);
|
||||
|
||||
// 2. Second mock: Signed URL response
|
||||
mockAxios.get.mockImplementationOnce(() =>
|
||||
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
|
||||
);
|
||||
|
||||
// 3. Third mock: OCR response
|
||||
mockAxios.post.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
data: {
|
||||
pages: [{ markdown: 'Content from default configuration' }],
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
app: {
|
||||
locals: {
|
||||
ocr: {
|
||||
// Empty string values - should fall back to defaults
|
||||
apiKey: '',
|
||||
baseURL: '',
|
||||
mistralModel: '',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/file.pdf',
|
||||
originalname: 'empty-config.pdf',
|
||||
};
|
||||
|
||||
const result = await uploadMistralOCR({
|
||||
req,
|
||||
file,
|
||||
file_id: 'file123',
|
||||
entity_id: 'entity123',
|
||||
});
|
||||
|
||||
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
|
||||
|
||||
// Verify loadAuthValues was called with the default variable names
|
||||
expect(loadAuthValues).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
|
||||
optional: expect.any(Set),
|
||||
});
|
||||
|
||||
// Verify the API calls used the default values from loadAuthValues
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://default-from-env.mistral.ai/v1/files',
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: 'Bearer default-from-env-key',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
// Verify the OCR model defaulted to mistral-ocr-latest
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://default-from-env.mistral.ai/v1/ocr',
|
||||
expect.objectContaining({
|
||||
model: 'mistral-ocr-latest',
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
|
||||
// Check result
|
||||
expect(result.text).toEqual('Content from default configuration\n\n');
|
||||
});
|
||||
});
|
||||
});
|
||||
5
api/server/services/Files/MistralOCR/index.js
Normal file
5
api/server/services/Files/MistralOCR/index.js
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
const crud = require('./crud');
|
||||
|
||||
module.exports = {
|
||||
...crud,
|
||||
};
|
||||
163
api/server/services/Files/S3/crud.js
Normal file
163
api/server/services/Files/S3/crud.js
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const fetch = require('node-fetch');
|
||||
const { PutObjectCommand, GetObjectCommand, DeleteObjectCommand } = require('@aws-sdk/client-s3');
|
||||
const { getSignedUrl } = require('@aws-sdk/s3-request-presigner');
|
||||
const { initializeS3 } = require('./initialize');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const bucketName = process.env.AWS_BUCKET_NAME;
|
||||
const defaultBasePath = 'images';
|
||||
|
||||
/**
|
||||
* Constructs the S3 key based on the base path, user ID, and file name.
|
||||
*/
|
||||
const getS3Key = (basePath, userId, fileName) => `${basePath}/${userId}/${fileName}`;
|
||||
|
||||
/**
|
||||
* Uploads a buffer to S3 and returns a signed URL.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId - The user's unique identifier.
|
||||
* @param {Buffer} params.buffer - The buffer containing file data.
|
||||
* @param {string} params.fileName - The file name to use in S3.
|
||||
* @param {string} [params.basePath='images'] - The base path in the bucket.
|
||||
* @returns {Promise<string>} Signed URL of the uploaded file.
|
||||
*/
|
||||
async function saveBufferToS3({ userId, buffer, fileName, basePath = defaultBasePath }) {
|
||||
const key = getS3Key(basePath, userId, fileName);
|
||||
const params = { Bucket: bucketName, Key: key, Body: buffer };
|
||||
|
||||
try {
|
||||
const s3 = initializeS3();
|
||||
await s3.send(new PutObjectCommand(params));
|
||||
return await getS3URL({ userId, fileName, basePath });
|
||||
} catch (error) {
|
||||
logger.error('[saveBufferToS3] Error uploading buffer to S3:', error.message);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves a signed URL for a file stored in S3.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId - The user's unique identifier.
|
||||
* @param {string} params.fileName - The file name in S3.
|
||||
* @param {string} [params.basePath='images'] - The base path in the bucket.
|
||||
* @returns {Promise<string>} A signed URL valid for 24 hours.
|
||||
*/
|
||||
async function getS3URL({ userId, fileName, basePath = defaultBasePath }) {
|
||||
const key = getS3Key(basePath, userId, fileName);
|
||||
const params = { Bucket: bucketName, Key: key };
|
||||
|
||||
try {
|
||||
const s3 = initializeS3();
|
||||
return await getSignedUrl(s3, new GetObjectCommand(params), { expiresIn: 86400 });
|
||||
} catch (error) {
|
||||
logger.error('[getS3URL] Error getting signed URL from S3:', error.message);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves a file from a given URL to S3.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId - The user's unique identifier.
|
||||
* @param {string} params.URL - The source URL of the file.
|
||||
* @param {string} params.fileName - The file name to use in S3.
|
||||
* @param {string} [params.basePath='images'] - The base path in the bucket.
|
||||
* @returns {Promise<string>} Signed URL of the uploaded file.
|
||||
*/
|
||||
async function saveURLToS3({ userId, URL, fileName, basePath = defaultBasePath }) {
|
||||
try {
|
||||
const response = await fetch(URL);
|
||||
const buffer = await response.buffer();
|
||||
// Optionally you can call getBufferMetadata(buffer) if needed.
|
||||
return await saveBufferToS3({ userId, buffer, fileName, basePath });
|
||||
} catch (error) {
|
||||
logger.error('[saveURLToS3] Error uploading file from URL to S3:', error.message);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes a file from S3.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId - The user's unique identifier.
|
||||
* @param {string} params.fileName - The file name in S3.
|
||||
* @param {string} [params.basePath='images'] - The base path in the bucket.
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function deleteFileFromS3({ userId, fileName, basePath = defaultBasePath }) {
|
||||
const key = getS3Key(basePath, userId, fileName);
|
||||
const params = { Bucket: bucketName, Key: key };
|
||||
|
||||
try {
|
||||
const s3 = initializeS3();
|
||||
await s3.send(new DeleteObjectCommand(params));
|
||||
logger.debug('[deleteFileFromS3] File deleted successfully from S3');
|
||||
} catch (error) {
|
||||
logger.error('[deleteFileFromS3] Error deleting file from S3:', error.message);
|
||||
// If the file is not found, we can safely return.
|
||||
if (error.code === 'NoSuchKey') {
|
||||
return;
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Uploads a local file to S3.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {import('express').Request} params.req - The Express request (must include user).
|
||||
* @param {Express.Multer.File} params.file - The file object from Multer.
|
||||
* @param {string} params.file_id - Unique file identifier.
|
||||
* @param {string} [params.basePath='images'] - The base path in the bucket.
|
||||
* @returns {Promise<{ filepath: string, bytes: number }>}
|
||||
*/
|
||||
async function uploadFileToS3({ req, file, file_id, basePath = defaultBasePath }) {
|
||||
try {
|
||||
const inputFilePath = file.path;
|
||||
const inputBuffer = await fs.promises.readFile(inputFilePath);
|
||||
const bytes = Buffer.byteLength(inputBuffer);
|
||||
const userId = req.user.id;
|
||||
const fileName = `${file_id}__${path.basename(inputFilePath)}`;
|
||||
const fileURL = await saveBufferToS3({ userId, buffer: inputBuffer, fileName, basePath });
|
||||
await fs.promises.unlink(inputFilePath);
|
||||
return { filepath: fileURL, bytes };
|
||||
} catch (error) {
|
||||
logger.error('[uploadFileToS3] Error uploading file to S3:', error.message);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves a readable stream for a file stored in S3.
|
||||
*
|
||||
* @param {string} filePath - The S3 key of the file.
|
||||
* @returns {Promise<NodeJS.ReadableStream>}
|
||||
*/
|
||||
async function getS3FileStream(filePath) {
|
||||
const params = { Bucket: bucketName, Key: filePath };
|
||||
try {
|
||||
const s3 = initializeS3();
|
||||
const data = await s3.send(new GetObjectCommand(params));
|
||||
return data.Body; // Returns a Node.js ReadableStream.
|
||||
} catch (error) {
|
||||
logger.error('[getS3FileStream] Error retrieving S3 file stream:', error.message);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
saveBufferToS3,
|
||||
saveURLToS3,
|
||||
getS3URL,
|
||||
deleteFileFromS3,
|
||||
uploadFileToS3,
|
||||
getS3FileStream,
|
||||
};
|
||||
118
api/server/services/Files/S3/images.js
Normal file
118
api/server/services/Files/S3/images.js
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const sharp = require('sharp');
|
||||
const { resizeImageBuffer } = require('../images/resize');
|
||||
const { updateUser } = require('~/models/userMethods');
|
||||
const { saveBufferToS3 } = require('./crud');
|
||||
const { updateFile } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const defaultBasePath = 'images';
|
||||
|
||||
/**
|
||||
* Resizes, converts, and uploads an image file to S3.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {import('express').Request} params.req - Express request (expects user and app.locals.imageOutputType).
|
||||
* @param {Express.Multer.File} params.file - File object from Multer.
|
||||
* @param {string} params.file_id - Unique file identifier.
|
||||
* @param {any} params.endpoint - Endpoint identifier used in image processing.
|
||||
* @param {string} [params.resolution='high'] - Desired image resolution.
|
||||
* @param {string} [params.basePath='images'] - Base path in the bucket.
|
||||
* @returns {Promise<{ filepath: string, bytes: number, width: number, height: number }>}
|
||||
*/
|
||||
async function uploadImageToS3({
|
||||
req,
|
||||
file,
|
||||
file_id,
|
||||
endpoint,
|
||||
resolution = 'high',
|
||||
basePath = defaultBasePath,
|
||||
}) {
|
||||
try {
|
||||
const inputFilePath = file.path;
|
||||
const inputBuffer = await fs.promises.readFile(inputFilePath);
|
||||
const {
|
||||
buffer: resizedBuffer,
|
||||
width,
|
||||
height,
|
||||
} = await resizeImageBuffer(inputBuffer, resolution, endpoint);
|
||||
const extension = path.extname(inputFilePath);
|
||||
const userId = req.user.id;
|
||||
|
||||
let processedBuffer;
|
||||
let fileName = `${file_id}__${path.basename(inputFilePath)}`;
|
||||
const targetExtension = `.${req.app.locals.imageOutputType}`;
|
||||
|
||||
if (extension.toLowerCase() === targetExtension) {
|
||||
processedBuffer = resizedBuffer;
|
||||
} else {
|
||||
processedBuffer = await sharp(resizedBuffer)
|
||||
.toFormat(req.app.locals.imageOutputType)
|
||||
.toBuffer();
|
||||
fileName = fileName.replace(new RegExp(path.extname(fileName) + '$'), targetExtension);
|
||||
if (!path.extname(fileName)) {
|
||||
fileName += targetExtension;
|
||||
}
|
||||
}
|
||||
|
||||
const downloadURL = await saveBufferToS3({
|
||||
userId,
|
||||
buffer: processedBuffer,
|
||||
fileName,
|
||||
basePath,
|
||||
});
|
||||
await fs.promises.unlink(inputFilePath);
|
||||
const bytes = Buffer.byteLength(processedBuffer);
|
||||
return { filepath: downloadURL, bytes, width, height };
|
||||
} catch (error) {
|
||||
logger.error('[uploadImageToS3] Error uploading image to S3:', error.message);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates a file record and returns its signed URL.
|
||||
*
|
||||
* @param {import('express').Request} req - Express request.
|
||||
* @param {Object} file - File metadata.
|
||||
* @returns {Promise<[Promise<any>, string]>}
|
||||
*/
|
||||
async function prepareImageURLS3(req, file) {
|
||||
try {
|
||||
const updatePromise = updateFile({ file_id: file.file_id });
|
||||
return Promise.all([updatePromise, file.filepath]);
|
||||
} catch (error) {
|
||||
logger.error('[prepareImageURLS3] Error preparing image URL:', error.message);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Processes a user's avatar image by uploading it to S3 and updating the user's avatar URL if required.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {Buffer} params.buffer - Avatar image buffer.
|
||||
* @param {string} params.userId - User's unique identifier.
|
||||
* @param {string} params.manual - 'true' or 'false' flag for manual update.
|
||||
* @param {string} [params.basePath='images'] - Base path in the bucket.
|
||||
* @returns {Promise<string>} Signed URL of the uploaded avatar.
|
||||
*/
|
||||
async function processS3Avatar({ buffer, userId, manual, basePath = defaultBasePath }) {
|
||||
try {
|
||||
const downloadURL = await saveBufferToS3({ userId, buffer, fileName: 'avatar.png', basePath });
|
||||
if (manual === 'true') {
|
||||
await updateUser(userId, { avatar: downloadURL });
|
||||
}
|
||||
return downloadURL;
|
||||
} catch (error) {
|
||||
logger.error('[processS3Avatar] Error processing S3 avatar:', error.message);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
uploadImageToS3,
|
||||
prepareImageURLS3,
|
||||
processS3Avatar,
|
||||
};
|
||||
9
api/server/services/Files/S3/index.js
Normal file
9
api/server/services/Files/S3/index.js
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
const crud = require('./crud');
|
||||
const images = require('./images');
|
||||
const initialize = require('./initialize');
|
||||
|
||||
module.exports = {
|
||||
...crud,
|
||||
...images,
|
||||
...initialize,
|
||||
};
|
||||
53
api/server/services/Files/S3/initialize.js
Normal file
53
api/server/services/Files/S3/initialize.js
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
const { S3Client } = require('@aws-sdk/client-s3');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
let s3 = null;
|
||||
|
||||
/**
|
||||
* Initializes and returns an instance of the AWS S3 client.
|
||||
*
|
||||
* If AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are provided, they will be used.
|
||||
* Otherwise, the AWS SDK's default credentials chain (including IRSA) is used.
|
||||
*
|
||||
* If AWS_ENDPOINT_URL is provided, it will be used as the endpoint.
|
||||
*
|
||||
* @returns {S3Client|null} An instance of S3Client if the region is provided; otherwise, null.
|
||||
*/
|
||||
const initializeS3 = () => {
|
||||
if (s3) {
|
||||
return s3;
|
||||
}
|
||||
|
||||
const region = process.env.AWS_REGION;
|
||||
if (!region) {
|
||||
logger.error('[initializeS3] AWS_REGION is not set. Cannot initialize S3.');
|
||||
return null;
|
||||
}
|
||||
|
||||
// Read the custom endpoint if provided.
|
||||
const endpoint = process.env.AWS_ENDPOINT_URL;
|
||||
const accessKeyId = process.env.AWS_ACCESS_KEY_ID;
|
||||
const secretAccessKey = process.env.AWS_SECRET_ACCESS_KEY;
|
||||
|
||||
const config = {
|
||||
region,
|
||||
// Conditionally add the endpoint if it is provided
|
||||
...(endpoint ? { endpoint } : {}),
|
||||
};
|
||||
|
||||
if (accessKeyId && secretAccessKey) {
|
||||
s3 = new S3Client({
|
||||
...config,
|
||||
credentials: { accessKeyId, secretAccessKey },
|
||||
});
|
||||
logger.info('[initializeS3] S3 initialized with provided credentials.');
|
||||
} else {
|
||||
// When using IRSA, credentials are automatically provided via the IAM Role attached to the ServiceAccount.
|
||||
s3 = new S3Client(config);
|
||||
logger.info('[initializeS3] S3 initialized using default credentials (IRSA).');
|
||||
}
|
||||
|
||||
return s3;
|
||||
};
|
||||
|
||||
module.exports = { initializeS3 };
|
||||
|
|
@ -37,7 +37,14 @@ const deleteVectors = async (req, file) => {
|
|||
error,
|
||||
message: 'Error deleting vectors',
|
||||
});
|
||||
throw new Error(error.message || 'An error occurred during file deletion.');
|
||||
if (
|
||||
error.response &&
|
||||
error.response.status !== 404 &&
|
||||
(error.response.status < 200 || error.response.status >= 300)
|
||||
) {
|
||||
logger.warn('Error deleting vectors, file will not be deleted');
|
||||
throw new Error(error.message || 'An error occurred during file deletion.');
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ async function encodeAndFormat(req, files, endpoint, mode) {
|
|||
const promises = [];
|
||||
const encodingMethods = {};
|
||||
const result = {
|
||||
text: '',
|
||||
files: [],
|
||||
image_urls: [],
|
||||
};
|
||||
|
|
@ -59,6 +60,9 @@ async function encodeAndFormat(req, files, endpoint, mode) {
|
|||
|
||||
for (let file of files) {
|
||||
const source = file.source ?? FileSources.local;
|
||||
if (source === FileSources.text && file.text) {
|
||||
result.text += `${!result.text ? 'Attached document(s):\n```md' : '\n\n---\n\n'}# "${file.filename}"\n${file.text}\n`;
|
||||
}
|
||||
|
||||
if (!file.height) {
|
||||
promises.push([file, null]);
|
||||
|
|
@ -85,6 +89,10 @@ async function encodeAndFormat(req, files, endpoint, mode) {
|
|||
promises.push(preparePayload(req, file));
|
||||
}
|
||||
|
||||
if (result.text) {
|
||||
result.text += '\n```';
|
||||
}
|
||||
|
||||
const detail = req.body.imageDetail ?? ImageDetail.auto;
|
||||
|
||||
/** @type {Array<[MongoFile, string]>} */
|
||||
|
|
|
|||
|
|
@ -28,8 +28,8 @@ const { addResourceFileId, deleteResourceFileId } = require('~/server/controller
|
|||
const { addAgentResourceFile, removeAgentResourceFiles } = require('~/models/Agent');
|
||||
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
|
||||
const { createFile, updateFileUsage, deleteFiles } = require('~/models/File');
|
||||
const { getEndpointsConfig } = require('~/server/services/Config');
|
||||
const { loadAuthValues } = require('~/app/clients/tools/util');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { checkCapability } = require('~/server/services/Config');
|
||||
const { LB_QueueAsyncCall } = require('~/server/utils/queue');
|
||||
const { getStrategyFunctions } = require('./strategies');
|
||||
const { determineFileType } = require('~/server/utils');
|
||||
|
|
@ -162,7 +162,6 @@ const processDeleteRequest = async ({ req, files }) => {
|
|||
|
||||
for (const file of files) {
|
||||
const source = file.source ?? FileSources.local;
|
||||
|
||||
if (req.body.agent_id && req.body.tool_resource) {
|
||||
agentFiles.push({
|
||||
tool_resource: req.body.tool_resource,
|
||||
|
|
@ -170,6 +169,11 @@ const processDeleteRequest = async ({ req, files }) => {
|
|||
});
|
||||
}
|
||||
|
||||
if (source === FileSources.text) {
|
||||
resolvedFileIds.push(file.file_id);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (checkOpenAIStorage(source) && !client[source]) {
|
||||
await initializeClients();
|
||||
}
|
||||
|
|
@ -347,8 +351,8 @@ const uploadImageBuffer = async ({ req, context, metadata = {}, resize = true })
|
|||
req.app.locals.imageOutputType
|
||||
}`;
|
||||
}
|
||||
|
||||
const filepath = await saveBuffer({ userId: req.user.id, fileName: filename, buffer });
|
||||
const fileName = `${file_id}-${filename}`;
|
||||
const filepath = await saveBuffer({ userId: req.user.id, fileName, buffer });
|
||||
return await createFile(
|
||||
{
|
||||
user: req.user.id,
|
||||
|
|
@ -453,17 +457,6 @@ const processFileUpload = async ({ req, res, metadata }) => {
|
|||
res.status(200).json({ message: 'File uploaded and processed successfully', ...result });
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {ServerRequest} req
|
||||
* @param {AgentCapabilities} capability
|
||||
* @returns {Promise<boolean>}
|
||||
*/
|
||||
const checkCapability = async (req, capability) => {
|
||||
const endpointsConfig = await getEndpointsConfig(req);
|
||||
const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
|
||||
return capabilities.includes(capability);
|
||||
};
|
||||
|
||||
/**
|
||||
* Applies the current strategy for file uploads.
|
||||
* Saves file metadata to the database with an expiry TTL.
|
||||
|
|
@ -521,6 +514,52 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
|||
if (!isFileSearchEnabled) {
|
||||
throw new Error('File search is not enabled for Agents');
|
||||
}
|
||||
} else if (tool_resource === EToolResources.ocr) {
|
||||
const isOCREnabled = await checkCapability(req, AgentCapabilities.ocr);
|
||||
if (!isOCREnabled) {
|
||||
throw new Error('OCR capability is not enabled for Agents');
|
||||
}
|
||||
|
||||
const { handleFileUpload } = getStrategyFunctions(
|
||||
req.app.locals?.ocr?.strategy ?? FileSources.mistral_ocr,
|
||||
);
|
||||
const { file_id, temp_file_id } = metadata;
|
||||
|
||||
const {
|
||||
text,
|
||||
bytes,
|
||||
// TODO: OCR images support?
|
||||
images,
|
||||
filename,
|
||||
filepath: ocrFileURL,
|
||||
} = await handleFileUpload({ req, file, file_id, entity_id: agent_id });
|
||||
|
||||
const fileInfo = removeNullishValues({
|
||||
text,
|
||||
bytes,
|
||||
file_id,
|
||||
temp_file_id,
|
||||
user: req.user.id,
|
||||
type: file.mimetype,
|
||||
filepath: ocrFileURL,
|
||||
source: FileSources.text,
|
||||
filename: filename ?? file.originalname,
|
||||
model: messageAttachment ? undefined : req.body.model,
|
||||
context: messageAttachment ? FileContext.message_attachment : FileContext.agents,
|
||||
});
|
||||
|
||||
if (!messageAttachment && tool_resource) {
|
||||
await addAgentResourceFile({
|
||||
req,
|
||||
file_id,
|
||||
agent_id,
|
||||
tool_resource,
|
||||
});
|
||||
}
|
||||
const result = await createFile(fileInfo, true);
|
||||
return res
|
||||
.status(200)
|
||||
.json({ message: 'Agent file uploaded and processed successfully', ...result });
|
||||
}
|
||||
|
||||
const source =
|
||||
|
|
@ -801,8 +840,7 @@ async function saveBase64Image(
|
|||
{ req, file_id: _file_id, filename: _filename, endpoint, context, resolution = 'high' },
|
||||
) {
|
||||
const file_id = _file_id ?? v4();
|
||||
|
||||
let filename = _filename;
|
||||
let filename = `${file_id}-${_filename}`;
|
||||
const { buffer: inputBuffer, type } = base64ToBuffer(url);
|
||||
if (!path.extname(_filename)) {
|
||||
const extension = mime.getExtension(type);
|
||||
|
|
|
|||
|
|
@ -21,9 +21,32 @@ const {
|
|||
processLocalAvatar,
|
||||
getLocalFileStream,
|
||||
} = require('./Local');
|
||||
const {
|
||||
getS3URL,
|
||||
saveURLToS3,
|
||||
saveBufferToS3,
|
||||
getS3FileStream,
|
||||
uploadImageToS3,
|
||||
prepareImageURLS3,
|
||||
deleteFileFromS3,
|
||||
processS3Avatar,
|
||||
uploadFileToS3,
|
||||
} = require('./S3');
|
||||
const {
|
||||
saveBufferToAzure,
|
||||
saveURLToAzure,
|
||||
getAzureURL,
|
||||
deleteFileFromAzure,
|
||||
uploadFileToAzure,
|
||||
getAzureFileStream,
|
||||
uploadImageToAzure,
|
||||
prepareAzureImageURL,
|
||||
processAzureAvatar,
|
||||
} = require('./Azure');
|
||||
const { uploadOpenAIFile, deleteOpenAIFile, getOpenAIFileStream } = require('./OpenAI');
|
||||
const { getCodeOutputDownloadStream, uploadCodeEnvFile } = require('./Code');
|
||||
const { uploadVectors, deleteVectors } = require('./VectorDB');
|
||||
const { uploadMistralOCR } = require('./MistralOCR');
|
||||
|
||||
/**
|
||||
* Firebase Storage Strategy Functions
|
||||
|
|
@ -57,6 +80,38 @@ const localStrategy = () => ({
|
|||
getDownloadStream: getLocalFileStream,
|
||||
});
|
||||
|
||||
/**
|
||||
* S3 Storage Strategy Functions
|
||||
*
|
||||
* */
|
||||
const s3Strategy = () => ({
|
||||
handleFileUpload: uploadFileToS3,
|
||||
saveURL: saveURLToS3,
|
||||
getFileURL: getS3URL,
|
||||
deleteFile: deleteFileFromS3,
|
||||
saveBuffer: saveBufferToS3,
|
||||
prepareImagePayload: prepareImageURLS3,
|
||||
processAvatar: processS3Avatar,
|
||||
handleImageUpload: uploadImageToS3,
|
||||
getDownloadStream: getS3FileStream,
|
||||
});
|
||||
|
||||
/**
|
||||
* Azure Blob Storage Strategy Functions
|
||||
*
|
||||
* */
|
||||
const azureStrategy = () => ({
|
||||
handleFileUpload: uploadFileToAzure,
|
||||
saveURL: saveURLToAzure,
|
||||
getFileURL: getAzureURL,
|
||||
deleteFile: deleteFileFromAzure,
|
||||
saveBuffer: saveBufferToAzure,
|
||||
prepareImagePayload: prepareAzureImageURL,
|
||||
processAvatar: processAzureAvatar,
|
||||
handleImageUpload: uploadImageToAzure,
|
||||
getDownloadStream: getAzureFileStream,
|
||||
});
|
||||
|
||||
/**
|
||||
* VectorDB Storage Strategy Functions
|
||||
*
|
||||
|
|
@ -127,6 +182,26 @@ const codeOutputStrategy = () => ({
|
|||
getDownloadStream: getCodeOutputDownloadStream,
|
||||
});
|
||||
|
||||
const mistralOCRStrategy = () => ({
|
||||
/** @type {typeof saveFileFromURL | null} */
|
||||
saveURL: null,
|
||||
/** @type {typeof getLocalFileURL | null} */
|
||||
getFileURL: null,
|
||||
/** @type {typeof saveLocalBuffer | null} */
|
||||
saveBuffer: null,
|
||||
/** @type {typeof processLocalAvatar | null} */
|
||||
processAvatar: null,
|
||||
/** @type {typeof uploadLocalImage | null} */
|
||||
handleImageUpload: null,
|
||||
/** @type {typeof prepareImagesLocal | null} */
|
||||
prepareImagePayload: null,
|
||||
/** @type {typeof deleteLocalFile | null} */
|
||||
deleteFile: null,
|
||||
/** @type {typeof getLocalFileStream | null} */
|
||||
getDownloadStream: null,
|
||||
handleFileUpload: uploadMistralOCR,
|
||||
});
|
||||
|
||||
// Strategy Selector
|
||||
const getStrategyFunctions = (fileSource) => {
|
||||
if (fileSource === FileSources.firebase) {
|
||||
|
|
@ -136,11 +211,15 @@ const getStrategyFunctions = (fileSource) => {
|
|||
} else if (fileSource === FileSources.openai) {
|
||||
return openAIStrategy();
|
||||
} else if (fileSource === FileSources.azure) {
|
||||
return openAIStrategy();
|
||||
return azureStrategy();
|
||||
} else if (fileSource === FileSources.vectordb) {
|
||||
return vectorStrategy();
|
||||
} else if (fileSource === FileSources.s3) {
|
||||
return s3Strategy();
|
||||
} else if (fileSource === FileSources.execute_code) {
|
||||
return codeOutputStrategy();
|
||||
} else if (fileSource === FileSources.mistral_ocr) {
|
||||
return mistralOCRStrategy();
|
||||
} else {
|
||||
throw new Error('Invalid file source');
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,11 +37,19 @@ async function createMCPTool({ req, toolKey, provider }) {
|
|||
}
|
||||
|
||||
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
|
||||
/** @type {(toolInput: Object | string) => Promise<unknown>} */
|
||||
const _call = async (toolInput) => {
|
||||
/** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise<unknown>} */
|
||||
const _call = async (toolArguments, config) => {
|
||||
try {
|
||||
const mcpManager = await getMCPManager();
|
||||
const result = await mcpManager.callTool(serverName, toolName, provider, toolInput);
|
||||
const result = await mcpManager.callTool({
|
||||
serverName,
|
||||
toolName,
|
||||
provider,
|
||||
toolArguments,
|
||||
options: {
|
||||
signal: config?.signal,
|
||||
},
|
||||
});
|
||||
if (isAssistantsEndpoint(provider) && Array.isArray(result)) {
|
||||
return result[0];
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
const axios = require('axios');
|
||||
const { Providers } = require('@librechat/agents');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider');
|
||||
const { inputSchema, logAxiosError, extractBaseURL, processModelData } = require('~/utils');
|
||||
const { OllamaClient } = require('~/app/clients/OllamaClient');
|
||||
const { isUserProvided } = require('~/server/utils');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Splits a string by commas and trims each resulting value.
|
||||
|
|
@ -41,7 +44,7 @@ const fetchModels = async ({
|
|||
user,
|
||||
apiKey,
|
||||
baseURL,
|
||||
name = 'OpenAI',
|
||||
name = EModelEndpoint.openAI,
|
||||
azure = false,
|
||||
userIdQuery = false,
|
||||
createTokenConfig = true,
|
||||
|
|
@ -57,18 +60,25 @@ const fetchModels = async ({
|
|||
return models;
|
||||
}
|
||||
|
||||
if (name && name.toLowerCase().startsWith('ollama')) {
|
||||
if (name && name.toLowerCase().startsWith(Providers.OLLAMA)) {
|
||||
return await OllamaClient.fetchModels(baseURL);
|
||||
}
|
||||
|
||||
try {
|
||||
const options = {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
headers: {},
|
||||
timeout: 5000,
|
||||
};
|
||||
|
||||
if (name === EModelEndpoint.anthropic) {
|
||||
options.headers = {
|
||||
'x-api-key': apiKey,
|
||||
'anthropic-version': process.env.ANTHROPIC_VERSION || '2023-06-01',
|
||||
};
|
||||
} else {
|
||||
options.headers.Authorization = `Bearer ${apiKey}`;
|
||||
}
|
||||
|
||||
if (process.env.PROXY) {
|
||||
options.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
|
@ -128,9 +138,6 @@ const fetchOpenAIModels = async (opts, _models = []) => {
|
|||
// .split('/deployments')[0]
|
||||
// .concat(`/models?api-version=${azure.azureOpenAIApiVersion}`);
|
||||
// apiKey = azureOpenAIApiKey;
|
||||
} else if (process.env.OPENROUTER_API_KEY) {
|
||||
reverseProxyUrl = 'https://openrouter.ai/api/v1';
|
||||
apiKey = process.env.OPENROUTER_API_KEY;
|
||||
}
|
||||
|
||||
if (reverseProxyUrl) {
|
||||
|
|
@ -150,7 +157,7 @@ const fetchOpenAIModels = async (opts, _models = []) => {
|
|||
baseURL,
|
||||
azure: opts.azure,
|
||||
user: opts.user,
|
||||
name: baseURL,
|
||||
name: EModelEndpoint.openAI,
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -159,7 +166,7 @@ const fetchOpenAIModels = async (opts, _models = []) => {
|
|||
}
|
||||
|
||||
if (baseURL === openaiBaseURL) {
|
||||
const regex = /(text-davinci-003|gpt-|o\d+-)/;
|
||||
const regex = /(text-davinci-003|gpt-|o\d+)/;
|
||||
const excludeRegex = /audio|realtime/;
|
||||
models = models.filter((model) => regex.test(model) && !excludeRegex.test(model));
|
||||
const instructModels = models.filter((model) => model.includes('instruct'));
|
||||
|
|
@ -217,7 +224,7 @@ const getOpenAIModels = async (opts) => {
|
|||
return models;
|
||||
}
|
||||
|
||||
if (userProvidedOpenAI && !process.env.OPENROUTER_API_KEY) {
|
||||
if (userProvidedOpenAI) {
|
||||
return models;
|
||||
}
|
||||
|
||||
|
|
@ -233,13 +240,71 @@ const getChatGPTBrowserModels = () => {
|
|||
return models;
|
||||
};
|
||||
|
||||
const getAnthropicModels = () => {
|
||||
/**
|
||||
* Fetches models from the Anthropic API.
|
||||
* @async
|
||||
* @function
|
||||
* @param {object} opts - The options for fetching the models.
|
||||
* @param {string} opts.user - The user ID to send to the API.
|
||||
* @param {string[]} [_models=[]] - The models to use as a fallback.
|
||||
*/
|
||||
const fetchAnthropicModels = async (opts, _models = []) => {
|
||||
let models = _models.slice() ?? [];
|
||||
let apiKey = process.env.ANTHROPIC_API_KEY;
|
||||
const anthropicBaseURL = 'https://api.anthropic.com/v1';
|
||||
let baseURL = anthropicBaseURL;
|
||||
let reverseProxyUrl = process.env.ANTHROPIC_REVERSE_PROXY;
|
||||
|
||||
if (reverseProxyUrl) {
|
||||
baseURL = extractBaseURL(reverseProxyUrl);
|
||||
}
|
||||
|
||||
if (!apiKey) {
|
||||
return models;
|
||||
}
|
||||
|
||||
const modelsCache = getLogStores(CacheKeys.MODEL_QUERIES);
|
||||
|
||||
const cachedModels = await modelsCache.get(baseURL);
|
||||
if (cachedModels) {
|
||||
return cachedModels;
|
||||
}
|
||||
|
||||
if (baseURL) {
|
||||
models = await fetchModels({
|
||||
apiKey,
|
||||
baseURL,
|
||||
user: opts.user,
|
||||
name: EModelEndpoint.anthropic,
|
||||
tokenKey: EModelEndpoint.anthropic,
|
||||
});
|
||||
}
|
||||
|
||||
if (models.length === 0) {
|
||||
return _models;
|
||||
}
|
||||
|
||||
await modelsCache.set(baseURL, models);
|
||||
return models;
|
||||
};
|
||||
|
||||
const getAnthropicModels = async (opts = {}) => {
|
||||
let models = defaultModels[EModelEndpoint.anthropic];
|
||||
if (process.env.ANTHROPIC_MODELS) {
|
||||
models = splitAndTrim(process.env.ANTHROPIC_MODELS);
|
||||
return models;
|
||||
}
|
||||
|
||||
return models;
|
||||
if (isUserProvided(process.env.ANTHROPIC_API_KEY)) {
|
||||
return models;
|
||||
}
|
||||
|
||||
try {
|
||||
return await fetchAnthropicModels(opts, models);
|
||||
} catch (error) {
|
||||
logger.error('Error fetching Anthropic models:', error);
|
||||
return models;
|
||||
}
|
||||
};
|
||||
|
||||
const getGoogleModels = () => {
|
||||
|
|
|
|||
|
|
@ -161,22 +161,6 @@ describe('getOpenAIModels', () => {
|
|||
expect(models).toEqual(expect.arrayContaining(['openai-model', 'openai-model-2']));
|
||||
});
|
||||
|
||||
it('attempts to use OPENROUTER_API_KEY if set', async () => {
|
||||
process.env.OPENROUTER_API_KEY = 'test-router-key';
|
||||
const expectedModels = ['model-router-1', 'model-router-2'];
|
||||
|
||||
axios.get.mockResolvedValue({
|
||||
data: {
|
||||
data: expectedModels.map((id) => ({ id })),
|
||||
},
|
||||
});
|
||||
|
||||
const models = await getOpenAIModels({ user: 'user456' });
|
||||
|
||||
expect(models).toEqual(expect.arrayContaining(expectedModels));
|
||||
expect(axios.get).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('utilizes proxy configuration when PROXY is set', async () => {
|
||||
axios.get.mockResolvedValue({
|
||||
data: {
|
||||
|
|
@ -368,15 +352,15 @@ describe('splitAndTrim', () => {
|
|||
});
|
||||
|
||||
describe('getAnthropicModels', () => {
|
||||
it('returns default models when ANTHROPIC_MODELS is not set', () => {
|
||||
it('returns default models when ANTHROPIC_MODELS is not set', async () => {
|
||||
delete process.env.ANTHROPIC_MODELS;
|
||||
const models = getAnthropicModels();
|
||||
const models = await getAnthropicModels();
|
||||
expect(models).toEqual(defaultModels[EModelEndpoint.anthropic]);
|
||||
});
|
||||
|
||||
it('returns models from ANTHROPIC_MODELS when set', () => {
|
||||
it('returns models from ANTHROPIC_MODELS when set', async () => {
|
||||
process.env.ANTHROPIC_MODELS = 'claude-1, claude-2 ';
|
||||
const models = getAnthropicModels();
|
||||
const models = await getAnthropicModels();
|
||||
expect(models).toEqual(['claude-1', 'claude-2']);
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -362,7 +362,12 @@ async function processRequiredActions(client, requiredActions) {
|
|||
continue;
|
||||
}
|
||||
|
||||
tool = await createActionTool({ action: actionSet, requestBuilder });
|
||||
tool = await createActionTool({
|
||||
req: client.req,
|
||||
res: client.res,
|
||||
action: actionSet,
|
||||
requestBuilder,
|
||||
});
|
||||
if (!tool) {
|
||||
logger.warn(
|
||||
`Invalid action: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id} | toolName: ${currentAction.tool}`,
|
||||
|
|
|
|||
56
api/server/services/Tools/credentials.js
Normal file
56
api/server/services/Tools/credentials.js
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId
|
||||
* @param {string[]} params.authFields
|
||||
* @param {Set<string>} [params.optional]
|
||||
* @param {boolean} [params.throwError]
|
||||
* @returns
|
||||
*/
|
||||
const loadAuthValues = async ({ userId, authFields, optional, throwError = true }) => {
|
||||
let authValues = {};
|
||||
|
||||
/**
|
||||
* Finds the first non-empty value for the given authentication field, supporting alternate fields.
|
||||
* @param {string[]} fields Array of strings representing the authentication fields. Supports alternate fields delimited by "||".
|
||||
* @returns {Promise<{ authField: string, authValue: string} | null>} An object containing the authentication field and value, or null if not found.
|
||||
*/
|
||||
const findAuthValue = async (fields) => {
|
||||
for (const field of fields) {
|
||||
let value = process.env[field];
|
||||
if (value) {
|
||||
return { authField: field, authValue: value };
|
||||
}
|
||||
try {
|
||||
value = await getUserPluginAuthValue(userId, field, throwError);
|
||||
} catch (err) {
|
||||
if (optional && optional.has(field)) {
|
||||
return { authField: field, authValue: undefined };
|
||||
}
|
||||
if (field === fields[fields.length - 1] && !value) {
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
if (value) {
|
||||
return { authField: field, authValue: value };
|
||||
}
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
for (let authField of authFields) {
|
||||
const fields = authField.split('||');
|
||||
const result = await findAuthValue(fields);
|
||||
if (result) {
|
||||
authValues[result.authField] = result.authValue;
|
||||
}
|
||||
}
|
||||
|
||||
return authValues;
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
loadAuthValues,
|
||||
};
|
||||
|
|
@ -34,6 +34,8 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol
|
|||
multiConvo: interfaceConfig?.multiConvo ?? defaults.multiConvo,
|
||||
agents: interfaceConfig?.agents ?? defaults.agents,
|
||||
temporaryChat: interfaceConfig?.temporaryChat ?? defaults.temporaryChat,
|
||||
runCode: interfaceConfig?.runCode ?? defaults.runCode,
|
||||
customWelcome: interfaceConfig?.customWelcome ?? defaults.customWelcome,
|
||||
});
|
||||
|
||||
await updateAccessPermissions(roleName, {
|
||||
|
|
@ -41,12 +43,16 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol
|
|||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat },
|
||||
[PermissionTypes.RUN_CODE]: { [Permissions.USE]: loadedInterface.runCode },
|
||||
});
|
||||
await updateAccessPermissions(SystemRoles.ADMIN, {
|
||||
[PermissionTypes.PROMPTS]: { [Permissions.USE]: loadedInterface.prompts },
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat },
|
||||
[PermissionTypes.RUN_CODE]: { [Permissions.USE]: loadedInterface.runCode },
|
||||
});
|
||||
|
||||
let i = 0;
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ describe('loadDefaultInterface', () => {
|
|||
bookmarks: true,
|
||||
multiConvo: true,
|
||||
agents: true,
|
||||
temporaryChat: true,
|
||||
runCode: true,
|
||||
},
|
||||
};
|
||||
const configDefaults = { interface: {} };
|
||||
|
|
@ -25,6 +27,8 @@ describe('loadDefaultInterface', () => {
|
|||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.RUN_CODE]: { [Permissions.USE]: true },
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -35,6 +39,8 @@ describe('loadDefaultInterface', () => {
|
|||
bookmarks: false,
|
||||
multiConvo: false,
|
||||
agents: false,
|
||||
temporaryChat: false,
|
||||
runCode: false,
|
||||
},
|
||||
};
|
||||
const configDefaults = { interface: {} };
|
||||
|
|
@ -46,6 +52,8 @@ describe('loadDefaultInterface', () => {
|
|||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.RUN_CODE]: { [Permissions.USE]: false },
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -60,6 +68,8 @@ describe('loadDefaultInterface', () => {
|
|||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined },
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -70,6 +80,8 @@ describe('loadDefaultInterface', () => {
|
|||
bookmarks: undefined,
|
||||
multiConvo: undefined,
|
||||
agents: undefined,
|
||||
temporaryChat: undefined,
|
||||
runCode: undefined,
|
||||
},
|
||||
};
|
||||
const configDefaults = { interface: {} };
|
||||
|
|
@ -81,6 +93,8 @@ describe('loadDefaultInterface', () => {
|
|||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined },
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -91,6 +105,8 @@ describe('loadDefaultInterface', () => {
|
|||
bookmarks: false,
|
||||
multiConvo: undefined,
|
||||
agents: true,
|
||||
temporaryChat: undefined,
|
||||
runCode: false,
|
||||
},
|
||||
};
|
||||
const configDefaults = { interface: {} };
|
||||
|
|
@ -102,6 +118,8 @@ describe('loadDefaultInterface', () => {
|
|||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.RUN_CODE]: { [Permissions.USE]: false },
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -113,6 +131,8 @@ describe('loadDefaultInterface', () => {
|
|||
bookmarks: true,
|
||||
multiConvo: true,
|
||||
agents: true,
|
||||
temporaryChat: true,
|
||||
runCode: true,
|
||||
},
|
||||
};
|
||||
|
||||
|
|
@ -123,6 +143,8 @@ describe('loadDefaultInterface', () => {
|
|||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.RUN_CODE]: { [Permissions.USE]: true },
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -137,6 +159,8 @@ describe('loadDefaultInterface', () => {
|
|||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined },
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -151,6 +175,8 @@ describe('loadDefaultInterface', () => {
|
|||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined },
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -165,6 +191,8 @@ describe('loadDefaultInterface', () => {
|
|||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined },
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -175,6 +203,8 @@ describe('loadDefaultInterface', () => {
|
|||
bookmarks: false,
|
||||
multiConvo: true,
|
||||
agents: false,
|
||||
temporaryChat: true,
|
||||
runCode: false,
|
||||
},
|
||||
};
|
||||
const configDefaults = { interface: {} };
|
||||
|
|
@ -186,6 +216,8 @@ describe('loadDefaultInterface', () => {
|
|||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.RUN_CODE]: { [Permissions.USE]: false },
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -197,6 +229,8 @@ describe('loadDefaultInterface', () => {
|
|||
bookmarks: true,
|
||||
multiConvo: false,
|
||||
agents: undefined,
|
||||
temporaryChat: undefined,
|
||||
runCode: undefined,
|
||||
},
|
||||
};
|
||||
|
||||
|
|
@ -207,6 +241,8 @@ describe('loadDefaultInterface', () => {
|
|||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false },
|
||||
[PermissionTypes.AGENTS]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined },
|
||||
[PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
224
api/server/services/twoFactorService.js
Normal file
224
api/server/services/twoFactorService.js
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
const { webcrypto } = require('node:crypto');
|
||||
const { decryptV3, decryptV2 } = require('../utils/crypto');
|
||||
const { hashBackupCode } = require('~/server/utils/crypto');
|
||||
|
||||
// Base32 alphabet for TOTP secret encoding.
|
||||
const BASE32_ALPHABET = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567';
|
||||
|
||||
/**
|
||||
* Encodes a Buffer into a Base32 string.
|
||||
* @param {Buffer} buffer
|
||||
* @returns {string}
|
||||
*/
|
||||
const encodeBase32 = (buffer) => {
|
||||
let bits = 0;
|
||||
let value = 0;
|
||||
let output = '';
|
||||
for (const byte of buffer) {
|
||||
value = (value << 8) | byte;
|
||||
bits += 8;
|
||||
while (bits >= 5) {
|
||||
output += BASE32_ALPHABET[(value >>> (bits - 5)) & 31];
|
||||
bits -= 5;
|
||||
}
|
||||
}
|
||||
if (bits > 0) {
|
||||
output += BASE32_ALPHABET[(value << (5 - bits)) & 31];
|
||||
}
|
||||
return output;
|
||||
};
|
||||
|
||||
/**
|
||||
* Decodes a Base32 string into a Buffer.
|
||||
* @param {string} base32Str
|
||||
* @returns {Buffer}
|
||||
*/
|
||||
const decodeBase32 = (base32Str) => {
|
||||
const cleaned = base32Str.replace(/=+$/, '').toUpperCase();
|
||||
let bits = 0;
|
||||
let value = 0;
|
||||
const output = [];
|
||||
for (const char of cleaned) {
|
||||
const idx = BASE32_ALPHABET.indexOf(char);
|
||||
if (idx === -1) {
|
||||
continue;
|
||||
}
|
||||
value = (value << 5) | idx;
|
||||
bits += 5;
|
||||
if (bits >= 8) {
|
||||
output.push((value >>> (bits - 8)) & 0xff);
|
||||
bits -= 8;
|
||||
}
|
||||
}
|
||||
return Buffer.from(output);
|
||||
};
|
||||
|
||||
/**
|
||||
* Generates a new TOTP secret (Base32 encoded).
|
||||
* @returns {string}
|
||||
*/
|
||||
const generateTOTPSecret = () => {
|
||||
const randomArray = new Uint8Array(10);
|
||||
webcrypto.getRandomValues(randomArray);
|
||||
return encodeBase32(Buffer.from(randomArray));
|
||||
};
|
||||
|
||||
/**
|
||||
* Generates a TOTP code based on the secret and time.
|
||||
* Uses a 30-second time step and produces a 6-digit code.
|
||||
* @param {string} secret
|
||||
* @param {number} [forTime=Date.now()]
|
||||
* @returns {Promise<string>}
|
||||
*/
|
||||
const generateTOTP = async (secret, forTime = Date.now()) => {
|
||||
const timeStep = 30; // seconds
|
||||
const counter = Math.floor(forTime / 1000 / timeStep);
|
||||
const counterBuffer = new ArrayBuffer(8);
|
||||
const counterView = new DataView(counterBuffer);
|
||||
counterView.setUint32(4, counter, false);
|
||||
|
||||
const keyBuffer = decodeBase32(secret);
|
||||
const keyArrayBuffer = keyBuffer.buffer.slice(
|
||||
keyBuffer.byteOffset,
|
||||
keyBuffer.byteOffset + keyBuffer.byteLength,
|
||||
);
|
||||
|
||||
const cryptoKey = await webcrypto.subtle.importKey(
|
||||
'raw',
|
||||
keyArrayBuffer,
|
||||
{ name: 'HMAC', hash: 'SHA-1' },
|
||||
false,
|
||||
['sign'],
|
||||
);
|
||||
const signatureBuffer = await webcrypto.subtle.sign('HMAC', cryptoKey, counterBuffer);
|
||||
const hmac = new Uint8Array(signatureBuffer);
|
||||
|
||||
// Dynamic truncation per RFC 4226.
|
||||
const offset = hmac[hmac.length - 1] & 0xf;
|
||||
const slice = hmac.slice(offset, offset + 4);
|
||||
const view = new DataView(slice.buffer, slice.byteOffset, slice.byteLength);
|
||||
const binaryCode = view.getUint32(0, false) & 0x7fffffff;
|
||||
const code = (binaryCode % 1000000).toString().padStart(6, '0');
|
||||
return code;
|
||||
};
|
||||
|
||||
/**
|
||||
* Verifies a TOTP token by checking a ±1 time step window.
|
||||
* @param {string} secret
|
||||
* @param {string} token
|
||||
* @returns {Promise<boolean>}
|
||||
*/
|
||||
const verifyTOTP = async (secret, token) => {
|
||||
const timeStepMS = 30 * 1000;
|
||||
const currentTime = Date.now();
|
||||
for (let offset = -1; offset <= 1; offset++) {
|
||||
const expected = await generateTOTP(secret, currentTime + offset * timeStepMS);
|
||||
if (expected === token) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
/**
|
||||
* Generates backup codes (default count: 10).
|
||||
* Each code is an 8-character hexadecimal string and stored with its SHA-256 hash.
|
||||
* @param {number} [count=10]
|
||||
* @returns {Promise<{ plainCodes: string[], codeObjects: Array<{ codeHash: string, used: boolean, usedAt: Date | null }> }>}
|
||||
*/
|
||||
const generateBackupCodes = async (count = 10) => {
|
||||
const plainCodes = [];
|
||||
const codeObjects = [];
|
||||
const encoder = new TextEncoder();
|
||||
|
||||
for (let i = 0; i < count; i++) {
|
||||
const randomArray = new Uint8Array(4);
|
||||
webcrypto.getRandomValues(randomArray);
|
||||
const code = Array.from(randomArray)
|
||||
.map((b) => b.toString(16).padStart(2, '0'))
|
||||
.join('');
|
||||
plainCodes.push(code);
|
||||
|
||||
const codeBuffer = encoder.encode(code);
|
||||
const hashBuffer = await webcrypto.subtle.digest('SHA-256', codeBuffer);
|
||||
const hashArray = Array.from(new Uint8Array(hashBuffer));
|
||||
const codeHash = hashArray.map((b) => b.toString(16).padStart(2, '0')).join('');
|
||||
codeObjects.push({ codeHash, used: false, usedAt: null });
|
||||
}
|
||||
return { plainCodes, codeObjects };
|
||||
};
|
||||
|
||||
/**
|
||||
* Verifies a backup code and, if valid, marks it as used.
|
||||
* @param {Object} params
|
||||
* @param {Object} params.user
|
||||
* @param {string} params.backupCode
|
||||
* @returns {Promise<boolean>}
|
||||
*/
|
||||
const verifyBackupCode = async ({ user, backupCode }) => {
|
||||
if (!backupCode || !user || !Array.isArray(user.backupCodes)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const hashedInput = await hashBackupCode(backupCode.trim());
|
||||
const matchingCode = user.backupCodes.find(
|
||||
(codeObj) => codeObj.codeHash === hashedInput && !codeObj.used,
|
||||
);
|
||||
|
||||
if (matchingCode) {
|
||||
const updatedBackupCodes = user.backupCodes.map((codeObj) =>
|
||||
codeObj.codeHash === hashedInput && !codeObj.used
|
||||
? { ...codeObj, used: true, usedAt: new Date() }
|
||||
: codeObj,
|
||||
);
|
||||
// Update the user record with the marked backup code.
|
||||
const { updateUser } = require('~/models');
|
||||
await updateUser(user._id, { backupCodes: updatedBackupCodes });
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves and decrypts a stored TOTP secret.
|
||||
* - Uses decryptV3 if the secret has a "v3:" prefix.
|
||||
* - Falls back to decryptV2 for colon-delimited values.
|
||||
* - Assumes a 16-character secret is already plain.
|
||||
* @param {string|null} storedSecret
|
||||
* @returns {Promise<string|null>}
|
||||
*/
|
||||
const getTOTPSecret = async (storedSecret) => {
|
||||
if (!storedSecret) {
|
||||
return null;
|
||||
}
|
||||
if (storedSecret.startsWith('v3:')) {
|
||||
return decryptV3(storedSecret);
|
||||
}
|
||||
if (storedSecret.includes(':')) {
|
||||
return await decryptV2(storedSecret);
|
||||
}
|
||||
if (storedSecret.length === 16) {
|
||||
return storedSecret;
|
||||
}
|
||||
return storedSecret;
|
||||
};
|
||||
|
||||
/**
|
||||
* Generates a temporary JWT token for 2FA verification that expires in 5 minutes.
|
||||
* @param {string} userId
|
||||
* @returns {string}
|
||||
*/
|
||||
const generate2FATempToken = (userId) => {
|
||||
const { sign } = require('jsonwebtoken');
|
||||
return sign({ userId, twoFAPending: true }, process.env.JWT_SECRET, { expiresIn: '5m' });
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
generateTOTPSecret,
|
||||
generateTOTP,
|
||||
verifyTOTP,
|
||||
generateBackupCodes,
|
||||
verifyBackupCode,
|
||||
getTOTPSecret,
|
||||
generate2FATempToken,
|
||||
};
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
const Redis = require('ioredis');
|
||||
const Keyv = require('keyv');
|
||||
const passport = require('passport');
|
||||
const session = require('express-session');
|
||||
const MemoryStore = require('memorystore')(session);
|
||||
|
|
@ -12,12 +12,15 @@ const {
|
|||
appleLogin,
|
||||
} = require('~/strategies');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const keyvRedis = require('~/cache/keyvRedis');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* @param {Express.Application} app
|
||||
*/
|
||||
const configureSocialLogins = (app) => {
|
||||
logger.info('Configuring social logins...');
|
||||
|
||||
if (process.env.GOOGLE_CLIENT_ID && process.env.GOOGLE_CLIENT_SECRET) {
|
||||
passport.use(googleLogin());
|
||||
}
|
||||
|
|
@ -37,18 +40,17 @@ const configureSocialLogins = (app) => {
|
|||
process.env.OPENID_ENABLED &&
|
||||
process.env.OPENID_SESSION_SECRET
|
||||
) {
|
||||
logger.info('Configuring OpenID Connect...');
|
||||
const sessionOptions = {
|
||||
secret: process.env.OPENID_SESSION_SECRET,
|
||||
resave: false,
|
||||
saveUninitialized: false,
|
||||
};
|
||||
if (isEnabled(process.env.USE_REDIS)) {
|
||||
const client = new Redis(process.env.REDIS_URI);
|
||||
client
|
||||
.on('error', (err) => logger.error('ioredis error:', err))
|
||||
.on('ready', () => logger.info('ioredis successfully initialized.'))
|
||||
.on('reconnecting', () => logger.info('ioredis reconnecting...'));
|
||||
sessionOptions.store = new RedisStore({ client, prefix: 'librechat' });
|
||||
logger.debug('Using Redis for session storage in OpenID...');
|
||||
const keyv = new Keyv({ store: keyvRedis });
|
||||
const client = keyv.opts.store.redis;
|
||||
sessionOptions.store = new RedisStore({ client, prefix: 'openid_session' });
|
||||
} else {
|
||||
sessionOptions.store = new MemoryStore({
|
||||
checkPeriod: 86400000, // prune expired entries every 24h
|
||||
|
|
@ -57,7 +59,9 @@ const configureSocialLogins = (app) => {
|
|||
app.use(session(sessionOptions));
|
||||
app.use(passport.session());
|
||||
setupOpenId();
|
||||
|
||||
logger.info('OpenID Connect configured.');
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = configureSocialLogins;
|
||||
module.exports = configureSocialLogins;
|
||||
|
|
|
|||
|
|
@ -1,27 +1,25 @@
|
|||
require('dotenv').config();
|
||||
const crypto = require('node:crypto');
|
||||
const { webcrypto } = crypto;
|
||||
|
||||
const { webcrypto } = require('node:crypto');
|
||||
// Use hex decoding for both key and IV for legacy methods.
|
||||
const key = Buffer.from(process.env.CREDS_KEY, 'hex');
|
||||
const iv = Buffer.from(process.env.CREDS_IV, 'hex');
|
||||
const algorithm = 'AES-CBC';
|
||||
|
||||
// --- Legacy v1/v2 Setup: AES-CBC with fixed key and IV ---
|
||||
|
||||
async function encrypt(value) {
|
||||
const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [
|
||||
'encrypt',
|
||||
]);
|
||||
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(value);
|
||||
|
||||
const encryptedBuffer = await webcrypto.subtle.encrypt(
|
||||
{
|
||||
name: algorithm,
|
||||
iv: iv,
|
||||
},
|
||||
{ name: algorithm, iv: iv },
|
||||
cryptoKey,
|
||||
data,
|
||||
);
|
||||
|
||||
return Buffer.from(encryptedBuffer).toString('hex');
|
||||
}
|
||||
|
||||
|
|
@ -29,73 +27,85 @@ async function decrypt(encryptedValue) {
|
|||
const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [
|
||||
'decrypt',
|
||||
]);
|
||||
|
||||
const encryptedBuffer = Buffer.from(encryptedValue, 'hex');
|
||||
|
||||
const decryptedBuffer = await webcrypto.subtle.decrypt(
|
||||
{
|
||||
name: algorithm,
|
||||
iv: iv,
|
||||
},
|
||||
{ name: algorithm, iv: iv },
|
||||
cryptoKey,
|
||||
encryptedBuffer,
|
||||
);
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
return decoder.decode(decryptedBuffer);
|
||||
}
|
||||
|
||||
// Programmatically generate iv
|
||||
// --- v2: AES-CBC with a random IV per encryption ---
|
||||
|
||||
async function encryptV2(value) {
|
||||
const gen_iv = webcrypto.getRandomValues(new Uint8Array(16));
|
||||
|
||||
const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [
|
||||
'encrypt',
|
||||
]);
|
||||
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(value);
|
||||
|
||||
const encryptedBuffer = await webcrypto.subtle.encrypt(
|
||||
{
|
||||
name: algorithm,
|
||||
iv: gen_iv,
|
||||
},
|
||||
{ name: algorithm, iv: gen_iv },
|
||||
cryptoKey,
|
||||
data,
|
||||
);
|
||||
|
||||
return Buffer.from(gen_iv).toString('hex') + ':' + Buffer.from(encryptedBuffer).toString('hex');
|
||||
}
|
||||
|
||||
async function decryptV2(encryptedValue) {
|
||||
const parts = encryptedValue.split(':');
|
||||
// Already decrypted from an earlier invocation
|
||||
if (parts.length === 1) {
|
||||
return parts[0];
|
||||
}
|
||||
const gen_iv = Buffer.from(parts.shift(), 'hex');
|
||||
const encrypted = parts.join(':');
|
||||
|
||||
const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [
|
||||
'decrypt',
|
||||
]);
|
||||
|
||||
const encryptedBuffer = Buffer.from(encrypted, 'hex');
|
||||
|
||||
const decryptedBuffer = await webcrypto.subtle.decrypt(
|
||||
{
|
||||
name: algorithm,
|
||||
iv: gen_iv,
|
||||
},
|
||||
{ name: algorithm, iv: gen_iv },
|
||||
cryptoKey,
|
||||
encryptedBuffer,
|
||||
);
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
return decoder.decode(decryptedBuffer);
|
||||
}
|
||||
|
||||
// --- v3: AES-256-CTR using Node's crypto functions ---
|
||||
const algorithm_v3 = 'aes-256-ctr';
|
||||
|
||||
/**
|
||||
* Encrypts a value using AES-256-CTR.
|
||||
* Note: AES-256 requires a 32-byte key. Ensure that process.env.CREDS_KEY is a 64-character hex string.
|
||||
*
|
||||
* @param {string} value - The plaintext to encrypt.
|
||||
* @returns {string} The encrypted string with a "v3:" prefix.
|
||||
*/
|
||||
function encryptV3(value) {
|
||||
if (key.length !== 32) {
|
||||
throw new Error(`Invalid key length: expected 32 bytes, got ${key.length} bytes`);
|
||||
}
|
||||
const iv_v3 = crypto.randomBytes(16);
|
||||
const cipher = crypto.createCipheriv(algorithm_v3, key, iv_v3);
|
||||
const encrypted = Buffer.concat([cipher.update(value, 'utf8'), cipher.final()]);
|
||||
return `v3:${iv_v3.toString('hex')}:${encrypted.toString('hex')}`;
|
||||
}
|
||||
|
||||
function decryptV3(encryptedValue) {
|
||||
const parts = encryptedValue.split(':');
|
||||
if (parts[0] !== 'v3') {
|
||||
throw new Error('Not a v3 encrypted value');
|
||||
}
|
||||
const iv_v3 = Buffer.from(parts[1], 'hex');
|
||||
const encryptedText = Buffer.from(parts.slice(2).join(':'), 'hex');
|
||||
const decipher = crypto.createDecipheriv(algorithm_v3, key, iv_v3);
|
||||
const decrypted = Buffer.concat([decipher.update(encryptedText), decipher.final()]);
|
||||
return decrypted.toString('utf8');
|
||||
}
|
||||
|
||||
async function hashToken(str) {
|
||||
const data = new TextEncoder().encode(str);
|
||||
const hashBuffer = await webcrypto.subtle.digest('SHA-256', data);
|
||||
|
|
@ -106,10 +116,32 @@ async function getRandomValues(length) {
|
|||
if (!Number.isInteger(length) || length <= 0) {
|
||||
throw new Error('Length must be a positive integer');
|
||||
}
|
||||
|
||||
const randomValues = new Uint8Array(length);
|
||||
webcrypto.getRandomValues(randomValues);
|
||||
return Buffer.from(randomValues).toString('hex');
|
||||
}
|
||||
|
||||
module.exports = { encrypt, decrypt, encryptV2, decryptV2, hashToken, getRandomValues };
|
||||
/**
|
||||
* Computes SHA-256 hash for the given input.
|
||||
* @param {string} input
|
||||
* @returns {Promise<string>}
|
||||
*/
|
||||
async function hashBackupCode(input) {
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(input);
|
||||
const hashBuffer = await webcrypto.subtle.digest('SHA-256', data);
|
||||
const hashArray = Array.from(new Uint8Array(hashBuffer));
|
||||
return hashArray.map((b) => b.toString(16).padStart(2, '0')).join('');
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
encrypt,
|
||||
decrypt,
|
||||
encryptV2,
|
||||
decryptV2,
|
||||
encryptV3,
|
||||
decryptV3,
|
||||
hashToken,
|
||||
hashBackupCode,
|
||||
getRandomValues,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -203,6 +203,8 @@ function generateConfig(key, baseURL, endpoint) {
|
|||
AgentCapabilities.artifacts,
|
||||
AgentCapabilities.actions,
|
||||
AgentCapabilities.tools,
|
||||
AgentCapabilities.ocr,
|
||||
AgentCapabilities.chain,
|
||||
];
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
const express = require('express');
|
||||
const expressStaticGzip = require('express-static-gzip');
|
||||
|
||||
const oneDayInSeconds = 24 * 60 * 60;
|
||||
|
||||
|
|
@ -6,13 +6,13 @@ const sMaxAge = process.env.STATIC_CACHE_S_MAX_AGE || oneDayInSeconds;
|
|||
const maxAge = process.env.STATIC_CACHE_MAX_AGE || oneDayInSeconds * 2;
|
||||
|
||||
const staticCache = (staticPath) =>
|
||||
express.static(staticPath, {
|
||||
setHeaders: (res) => {
|
||||
if (process.env.NODE_ENV?.toLowerCase() !== 'production') {
|
||||
return;
|
||||
expressStaticGzip(staticPath, {
|
||||
enableBrotli: false, // disable Brotli, only using gzip
|
||||
orderPreference: ['gz'],
|
||||
setHeaders: (res, _path) => {
|
||||
if (process.env.NODE_ENV?.toLowerCase() === 'production') {
|
||||
res.setHeader('Cache-Control', `public, max-age=${maxAge}, s-maxage=${sMaxAge}`);
|
||||
}
|
||||
|
||||
res.setHeader('Cache-Control', `public, max-age=${maxAge}, s-maxage=${sMaxAge}`);
|
||||
},
|
||||
});
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue