refactor: request is encrypted. response from AI is still saved in plaintext but from the stream the final response is encrypted.

This commit is contained in:
Ruben Talstra 2025-02-16 11:56:40 +01:00
parent 0cc0e5d287
commit 7346d20224
No known key found for this signature in database
GPG key ID: 2A5A7174A60F3BEA

View file

@ -1,16 +1,59 @@
const { getResponseSender, Constants } = require('librechat-data-provider'); const { getResponseSender, Constants } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware'); const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils'); const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage } = require('~/models'); const { saveMessage, getUserById } = require('~/models');
const { logger } = require('~/config'); const { logger } = require('~/config');
let crypto; let crypto;
try { try {
crypto = require('crypto'); crypto = require('crypto');
} catch (err) { } catch (err) {
logger.error('[openidStrategy] crypto support is disabled!', err); logger.error('[AskController] crypto support is disabled!', err);
} }
/**
* Helper function to encrypt plaintext using AES-256-GCM and then RSA-encrypt the AES key.
* @param {string} plainText - The plaintext to encrypt.
* @param {string} pemPublicKey - The RSA public key in PEM format.
* @returns {Object} An object containing the ciphertext, iv, authTag, and encryptedKey.
*/
function encryptText(plainText, pemPublicKey) {
// Generate a random 256-bit AES key and a 12-byte IV.
const aesKey = crypto.randomBytes(32);
const iv = crypto.randomBytes(12);
// Encrypt the plaintext using AES-256-GCM.
const cipher = crypto.createCipheriv('aes-256-gcm', aesKey, iv);
let ciphertext = cipher.update(plainText, 'utf8', 'base64');
ciphertext += cipher.final('base64');
const authTag = cipher.getAuthTag().toString('base64');
// Encrypt the AES key using the user's RSA public key.
const encryptedKey = crypto.publicEncrypt(
{
key: pemPublicKey,
padding: crypto.constants.RSA_PKCS1_OAEP_PADDING,
oaepHash: 'sha256',
},
aesKey,
).toString('base64');
return {
ciphertext,
iv: iv.toString('base64'),
authTag,
encryptedKey,
};
}
/**
* AskController
* - Initializes the client.
* - Obtains the response from the language model.
* - Retrieves the full user record (to get encryption parameters).
* - If the user has encryption enabled (i.e. encryptionPublicKey is provided),
* encrypts both the request (userMessage) and the response before saving.
*/
const AskController = async (req, res, next, initializeClient, addTitle) => { const AskController = async (req, res, next, initializeClient, addTitle) => {
let { let {
text, text,
@ -39,7 +82,17 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
modelDisplayLabel, modelDisplayLabel,
}); });
const newConvo = !conversationId; const newConvo = !conversationId;
const user = req.user.id; const userId = req.user.id; // User ID from authentication
// Retrieve full user record from DB (including encryption parameters)
const dbUser = await getUserById(userId, 'encryptionPublicKey encryptedPrivateKey encryptionSalt encryptionIV');
// If the user has provided an encryption public key, rebuild the PEM format.
let pemPublicKey = null;
if (dbUser?.encryptionPublicKey && dbUser.encryptionPublicKey.trim() !== '') {
const pubKeyBase64 = dbUser.encryptionPublicKey;
pemPublicKey = `-----BEGIN PUBLIC KEY-----\n${pubKeyBase64.match(/.{1,64}/g).join('\n')}\n-----END PUBLIC KEY-----`;
}
const getReqData = (data = {}) => { const getReqData = (data = {}) => {
for (let key in data) { for (let key in data) {
@ -59,11 +112,9 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
}; };
let getText; let getText;
try { try {
const { client } = await initializeClient({ req, res, endpointOption }); const { client } = await initializeClient({ req, res, endpointOption });
const { onProgress: progressCallback, getPartialText } = createOnProgress(); const { onProgress: progressCallback, getPartialText } = createOnProgress();
getText = client.getStreamText != null ? client.getStreamText.bind(client) : getPartialText; getText = client.getStreamText != null ? client.getStreamText.bind(client) : getPartialText;
const getAbortData = () => ({ const getAbortData = () => ({
@ -81,14 +132,14 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
res.on('close', () => { res.on('close', () => {
logger.debug('[AskController] Request closed'); logger.debug('[AskController] Request closed');
if (!abortController) {return;} if (!abortController) { return; }
if (abortController.signal.aborted || abortController.requestCompleted) {return;} if (abortController.signal.aborted || abortController.requestCompleted) { return; }
abortController.abort(); abortController.abort();
logger.debug('[AskController] Request aborted on close'); logger.debug('[AskController] Request aborted on close');
}); });
const messageOptions = { const messageOptions = {
user, user: userId,
parentMessageId, parentMessageId,
conversationId, conversationId,
overrideParentMessageId, overrideParentMessageId,
@ -99,10 +150,11 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
progressOptions: { res }, progressOptions: { res },
}; };
/** @type {TMessage} */ // Get the response from the language model client.
let response = await client.sendMessage(text, messageOptions); let response = await client.sendMessage(text, messageOptions);
response.endpoint = endpointOption.endpoint; response.endpoint = endpointOption.endpoint;
// Ensure the conversation has a title.
const { conversation = {} } = await client.responsePromise; const { conversation = {} } = await client.responsePromise;
conversation.title = conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat'; conversation && !conversation.title ? null : conversation?.title || 'New Chat';
@ -113,54 +165,33 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
delete userMessage.image_urls; delete userMessage.image_urls;
} }
// --- Encryption Branch --- // --- Encrypt the user message if encryption is enabled ---
// Only encrypt if the user has set up encryption (i.e. non-empty encryptionPublicKey) if (pemPublicKey && userMessage && userMessage.text) {
if (
req.user.encryptionPublicKey &&
req.user.encryptionPublicKey.trim() !== '' &&
response.text &&
crypto
) {
try { try {
// Reconstruct the user's RSA public key in PEM format. const { ciphertext, iv, authTag, encryptedKey } = encryptText(userMessage.text, pemPublicKey);
const pubKeyBase64 = req.user.encryptionPublicKey; userMessage.text = ciphertext;
const pemPublicKey = `-----BEGIN PUBLIC KEY-----\n${pubKeyBase64.match(/.{1,64}/g).join('\n')}\n-----END PUBLIC KEY-----`; userMessage.iv = iv;
userMessage.authTag = authTag;
// Generate a random 256-bit AES key and a 12-byte IV. userMessage.encryptedKey = encryptedKey;
const aesKey = crypto.randomBytes(32); logger.debug('[AskController] User message encrypted.');
const iv = crypto.randomBytes(12); } catch (encError) {
logger.error('[AskController] Error encrypting user message:', encError);
// Encrypt the response text using AES-GCM. // Optionally, you could choose to throw an error or fallback.
const cipher = crypto.createCipheriv('aes-256-gcm', aesKey, iv); }
let ciphertext = cipher.update(response.text, 'utf8', 'base64');
ciphertext += cipher.final('base64');
const authTag = cipher.getAuthTag().toString('base64');
// Encrypt the AES key using the client's RSA public key.
let encryptedKey;
try {
encryptedKey = crypto.publicEncrypt(
{
key: pemPublicKey,
padding: crypto.constants.RSA_PKCS1_OAEP_PADDING,
oaepHash: 'sha256',
},
aesKey,
).toString('base64');
} catch (err) {
logger.error('Error encrypting AES key:', err);
throw new Error('Encryption failure');
} }
// Replace the plaintext response with the encrypted payload. // --- Encrypt the AI response if encryption is enabled ---
if (pemPublicKey && response.text) {
try {
const { ciphertext, iv, authTag, encryptedKey } = encryptText(response.text, pemPublicKey);
response.text = ciphertext; response.text = ciphertext;
response.iv = iv.toString('base64'); response.iv = iv;
response.authTag = authTag; response.authTag = authTag;
response.encryptedKey = encryptedKey; response.encryptedKey = encryptedKey;
logger.debug('[AskController] Response message encrypted.'); logger.debug('[AskController] Response message encrypted.');
} catch (encError) { } catch (encError) {
logger.error('[AskController] Error during response encryption:', encError); logger.error('[AskController] Error encrypting response message:', encError);
// Optionally, you may choose to return plaintext if encryption fails. // Optionally, you can choose to send plaintext or handle the error.
} }
} }
// --- End Encryption Branch --- // --- End Encryption Branch ---
@ -178,15 +209,15 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
if (!client.savedMessageIds.has(response.messageId)) { if (!client.savedMessageIds.has(response.messageId)) {
await saveMessage( await saveMessage(
req, req,
{ ...response, user }, { ...response, user: userId },
{ context: 'api/server/controllers/AskController.js - response end' }, { context: 'AskController - response end' },
); );
} }
} }
if (!client.skipSaveUserMessage) { if (!client.skipSaveUserMessage) {
await saveMessage(req, userMessage, { await saveMessage(req, userMessage, {
context: 'api/server/controllers/AskController.js - don\'t skip saving user message', context: 'AskController - save user message',
}); });
} }
@ -206,7 +237,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
messageId: responseMessageId, messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId, parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId,
}).catch((err) => { }).catch((err) => {
logger.error('[AskController] Error in `handleAbortError`', err); logger.error('[AskController] Error in handleAbortError', err);
}); });
} }
}; };