🛠️ fix(Azure/Assistants): Handle Long Domain Names & Other Minor chores (#2475)

* chore: replace violation cache accessors with enum

* chore: fix test

* chore(fileSchema): index timestamps

* fix(ActionService): use encoding/caching strategy for handling assistant function character length limit

* refactor(actions): async `domainParser` also resolve retrieved model (which is deployment name) to user-defined model

* style(AssistantAction): add `whitespace-nowrap` for ellipsis

* refactor(ActionService): if domain is less than or equal to encoded domain fixed length, return domain with replacement of separator

* refactor(actions): use sessions/transactions for updating Assistant Action database records

* chore: remove TTL from ENCODED_DOMAINS cache

* refactor(domainParser): minor optimization and add tests

* fix(spendTokens): use txData.user for token usage logging

* refactor(actions): add helper function `withSession` for database operations with sessions/transactions

* fix(PluginsClient): logger debug `message` field edge case
This commit is contained in:
Danny Avila 2024-04-20 15:02:56 -04:00 committed by GitHub
parent 5d642d0187
commit 8c22bb1d3d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 365 additions and 63 deletions

View file

@ -244,7 +244,7 @@ class PluginsClient extends OpenAIClient {
this.setOptions(opts); this.setOptions(opts);
return super.sendMessage(message, opts); return super.sendMessage(message, opts);
} }
logger.debug('[PluginsClient] sendMessage', { message, opts }); logger.debug('[PluginsClient] sendMessage', { userMessageText: message, opts });
const { const {
user, user,
isEdited, isEdited,

View file

@ -1,6 +1,7 @@
const Session = require('~/models/Session'); const { ViolationTypes } = require('librechat-data-provider');
const getLogStores = require('./getLogStores');
const { isEnabled, math, removePorts } = require('~/server/utils'); const { isEnabled, math, removePorts } = require('~/server/utils');
const getLogStores = require('./getLogStores');
const Session = require('~/models/Session');
const { logger } = require('~/config'); const { logger } = require('~/config');
const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {}; const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {};
@ -48,7 +49,7 @@ const banViolation = async (req, res, errorMessage) => {
await Session.deleteAllUserSessions(user_id); await Session.deleteAllUserSessions(user_id);
res.clearCookie('refreshToken'); res.clearCookie('refreshToken');
const banLogs = getLogStores('ban'); const banLogs = getLogStores(ViolationTypes.BAN);
const duration = errorMessage.duration || banLogs.opts.ttl; const duration = errorMessage.duration || banLogs.opts.ttl;
if (duration <= 0) { if (duration <= 0) {

View file

@ -6,6 +6,7 @@ jest.mock('../models/Session');
jest.mock('./getLogStores', () => { jest.mock('./getLogStores', () => {
return jest.fn().mockImplementation(() => { return jest.fn().mockImplementation(() => {
const EventEmitter = require('events'); const EventEmitter = require('events');
const { CacheKeys } = require('librechat-data-provider');
const math = require('../server/utils/math'); const math = require('../server/utils/math');
const mockGet = jest.fn(); const mockGet = jest.fn();
const mockSet = jest.fn(); const mockSet = jest.fn();
@ -33,7 +34,7 @@ jest.mock('./getLogStores', () => {
} }
return new KeyvMongo('', { return new KeyvMongo('', {
namespace: 'bans', namespace: CacheKeys.BANS,
ttl: math(process.env.BAN_DURATION, 7200000), ttl: math(process.env.BAN_DURATION, 7200000),
}); });
}); });

View file

@ -6,6 +6,7 @@ const keyvRedis = require('./keyvRedis');
const keyvMongo = require('./keyvMongo'); const keyvMongo = require('./keyvMongo');
const { BAN_DURATION, USE_REDIS } = process.env ?? {}; const { BAN_DURATION, USE_REDIS } = process.env ?? {};
const THIRTY_MINUTES = 1800000;
const duration = math(BAN_DURATION, 7200000); const duration = math(BAN_DURATION, 7200000);
@ -24,8 +25,8 @@ const config = isEnabled(USE_REDIS)
: new Keyv({ namespace: CacheKeys.CONFIG_STORE }); : new Keyv({ namespace: CacheKeys.CONFIG_STORE });
const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes
? new Keyv({ store: keyvRedis, ttl: 1800000 }) ? new Keyv({ store: keyvRedis, ttl: THIRTY_MINUTES })
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: 1800000 }); : new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: THIRTY_MINUTES });
const genTitle = isEnabled(USE_REDIS) // ttl: 2 minutes const genTitle = isEnabled(USE_REDIS) // ttl: 2 minutes
? new Keyv({ store: keyvRedis, ttl: 120000 }) ? new Keyv({ store: keyvRedis, ttl: 120000 })
@ -42,7 +43,12 @@ const abortKeys = isEnabled(USE_REDIS)
const namespaces = { const namespaces = {
[CacheKeys.CONFIG_STORE]: config, [CacheKeys.CONFIG_STORE]: config,
pending_req, pending_req,
ban: new Keyv({ store: keyvMongo, namespace: 'bans', ttl: duration }), [ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }),
[CacheKeys.ENCODED_DOMAINS]: new Keyv({
store: keyvMongo,
namespace: CacheKeys.ENCODED_DOMAINS,
ttl: 0,
}),
general: new Keyv({ store: logFile, namespace: 'violations' }), general: new Keyv({ store: logFile, namespace: 'violations' }),
concurrent: createViolationInstance('concurrent'), concurrent: createViolationInstance('concurrent'),
non_browser: createViolationInstance('non_browser'), non_browser: createViolationInstance('non_browser'),

View file

@ -5,19 +5,18 @@ const Action = mongoose.model('action', actionSchema);
/** /**
* Update an action with new data without overwriting existing properties, * Update an action with new data without overwriting existing properties,
* or create a new action if it doesn't exist. * or create a new action if it doesn't exist, within a transaction session if provided.
* *
* @param {Object} searchParams - The search parameters to find the action to update. * @param {Object} searchParams - The search parameters to find the action to update.
* @param {string} searchParams.action_id - The ID of the action to update. * @param {string} searchParams.action_id - The ID of the action to update.
* @param {string} searchParams.user - The user ID of the action's author. * @param {string} searchParams.user - The user ID of the action's author.
* @param {Object} updateData - An object containing the properties to update. * @param {Object} updateData - An object containing the properties to update.
* @param {mongoose.ClientSession} [session] - The transaction session to use.
* @returns {Promise<Object>} The updated or newly created action document as a plain object. * @returns {Promise<Object>} The updated or newly created action document as a plain object.
*/ */
const updateAction = async (searchParams, updateData) => { const updateAction = async (searchParams, updateData, session = null) => {
return await Action.findOneAndUpdate(searchParams, updateData, { const options = { new: true, upsert: true, session };
new: true, return await Action.findOneAndUpdate(searchParams, updateData, options).lean();
upsert: true,
}).lean();
}; };
/** /**
@ -50,15 +49,17 @@ const getActions = async (searchParams, includeSensitive = false) => {
}; };
/** /**
* Deletes an action by its ID. * Deletes an action by params, within a transaction session if provided.
* *
* @param {Object} searchParams - The search parameters to find the action to update. * @param {Object} searchParams - The search parameters to find the action to delete.
* @param {string} searchParams.action_id - The ID of the action to update. * @param {string} searchParams.action_id - The ID of the action to delete.
* @param {string} searchParams.user - The user ID of the action's author. * @param {string} searchParams.user - The user ID of the action's author.
* @param {mongoose.ClientSession} [session] - The transaction session to use (optional).
* @returns {Promise<Object>} A promise that resolves to the deleted action document as a plain object, or null if no document was found. * @returns {Promise<Object>} A promise that resolves to the deleted action document as a plain object, or null if no document was found.
*/ */
const deleteAction = async (searchParams) => { const deleteAction = async (searchParams, session = null) => {
return await Action.findOneAndDelete(searchParams).lean(); const options = session ? { session } : {};
return await Action.findOneAndDelete(searchParams, options).lean();
}; };
module.exports = { module.exports = {

View file

@ -5,19 +5,18 @@ const Assistant = mongoose.model('assistant', assistantSchema);
/** /**
* Update an assistant with new data without overwriting existing properties, * Update an assistant with new data without overwriting existing properties,
* or create a new assistant if it doesn't exist. * or create a new assistant if it doesn't exist, within a transaction session if provided.
* *
* @param {Object} searchParams - The search parameters to find the assistant to update. * @param {Object} searchParams - The search parameters to find the assistant to update.
* @param {string} searchParams.assistant_id - The ID of the assistant to update. * @param {string} searchParams.assistant_id - The ID of the assistant to update.
* @param {string} searchParams.user - The user ID of the assistant's author. * @param {string} searchParams.user - The user ID of the assistant's author.
* @param {Object} updateData - An object containing the properties to update. * @param {Object} updateData - An object containing the properties to update.
* @param {mongoose.ClientSession} [session] - The transaction session to use (optional).
* @returns {Promise<Object>} The updated or newly created assistant document as a plain object. * @returns {Promise<Object>} The updated or newly created assistant document as a plain object.
*/ */
const updateAssistant = async (searchParams, updateData) => { const updateAssistant = async (searchParams, updateData, session = null) => {
return await Assistant.findOneAndUpdate(searchParams, updateData, { const options = { new: true, upsert: true, session };
new: true, return await Assistant.findOneAndUpdate(searchParams, updateData, options).lean();
upsert: true,
}).lean();
}; };
/** /**

View file

@ -99,4 +99,6 @@ const fileSchema = mongoose.Schema(
}, },
); );
fileSchema.index({ createdAt: 1, updatedAt: 1 });
module.exports = fileSchema; module.exports = fileSchema;

View file

@ -54,7 +54,7 @@ const spendTokens = async (txData, tokenUsage) => {
prompt && prompt &&
completion && completion &&
logger.debug('[spendTokens] Transaction data record against balance:', { logger.debug('[spendTokens] Transaction data record against balance:', {
user: prompt.user, user: txData.user,
prompt: prompt.prompt, prompt: prompt.prompt,
promptRate: prompt.rate, promptRate: prompt.rate,
completion: completion.completion, completion: completion.completion,

View file

@ -1,14 +1,15 @@
const Keyv = require('keyv'); const Keyv = require('keyv');
const uap = require('ua-parser-js'); const uap = require('ua-parser-js');
const denyRequest = require('./denyRequest'); const { ViolationTypes } = require('librechat-data-provider');
const { getLogStores } = require('../../cache');
const { isEnabled, removePorts } = require('../utils'); const { isEnabled, removePorts } = require('../utils');
const keyvRedis = require('../../cache/keyvRedis'); const keyvRedis = require('~/cache/keyvRedis');
const User = require('../../models/User'); const denyRequest = require('./denyRequest');
const { getLogStores } = require('~/cache');
const User = require('~/models/User');
const banCache = isEnabled(process.env.USE_REDIS) const banCache = isEnabled(process.env.USE_REDIS)
? new Keyv({ store: keyvRedis }) ? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: 'bans', ttl: 0 }); : new Keyv({ namespace: ViolationTypes.BAN, ttl: 0 });
const message = 'Your account has been temporarily banned due to violations of our service.'; const message = 'Your account has been temporarily banned due to violations of our service.';
/** /**
@ -28,7 +29,7 @@ const banResponse = async (req, res) => {
if (!ua.browser.name) { if (!ua.browser.name) {
return res.status(403).json({ message }); return res.status(403).json({ message });
} else if (baseUrl === '/api/ask' || baseUrl === '/api/edit') { } else if (baseUrl === '/api/ask' || baseUrl === '/api/edit') {
return await denyRequest(req, res, { type: 'ban' }); return await denyRequest(req, res, { type: ViolationTypes.BAN });
} }
return res.status(403).json({ message }); return res.status(403).json({ message });
@ -87,7 +88,7 @@ const checkBan = async (req, res, next = () => {}) => {
return await banResponse(req, res); return await banResponse(req, res);
} }
const banLogs = getLogStores('ban'); const banLogs = getLogStores(ViolationTypes.BAN);
const duration = banLogs.opts.ttl; const duration = banLogs.opts.ttl;
if (duration <= 0) { if (duration <= 0) {

View file

@ -1,10 +1,11 @@
const { v4 } = require('uuid'); const { v4 } = require('uuid');
const express = require('express'); const express = require('express');
const { actionDelimiter } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/assistants');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService'); const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { actionDelimiter, EModelEndpoint } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/assistants');
const { updateAction, getActions, deleteAction } = require('~/models/Action'); const { updateAction, getActions, deleteAction } = require('~/models/Action');
const { updateAssistant, getAssistant } = require('~/models/Assistant'); const { updateAssistant, getAssistant } = require('~/models/Assistant');
const { withSession } = require('~/server/utils');
const { logger } = require('~/config'); const { logger } = require('~/config');
const router = express.Router(); const router = express.Router();
@ -46,7 +47,7 @@ router.post('/:assistant_id', async (req, res) => {
let { domain } = metadata; let { domain } = metadata;
/* Azure doesn't support periods in function names */ /* Azure doesn't support periods in function names */
domain = domainParser(req, domain, true); domain = await domainParser(req, domain, true);
if (!domain) { if (!domain) {
return res.status(400).json({ message: 'No domain provided' }); return res.status(400).json({ message: 'No domain provided' });
@ -110,7 +111,8 @@ router.post('/:assistant_id', async (req, res) => {
const promises = []; const promises = [];
promises.push( promises.push(
updateAssistant( withSession(
updateAssistant,
{ assistant_id }, { assistant_id },
{ {
actions, actions,
@ -119,7 +121,9 @@ router.post('/:assistant_id', async (req, res) => {
), ),
); );
promises.push(openai.beta.assistants.update(assistant_id, { tools })); promises.push(openai.beta.assistants.update(assistant_id, { tools }));
promises.push(updateAction({ action_id }, { metadata, assistant_id, user: req.user.id })); promises.push(
withSession(updateAction, { action_id }, { metadata, assistant_id, user: req.user.id }),
);
/** @type {[AssistantDocument, Assistant, Action]} */ /** @type {[AssistantDocument, Assistant, Action]} */
const resolved = await Promise.all(promises); const resolved = await Promise.all(promises);
@ -129,6 +133,15 @@ router.post('/:assistant_id', async (req, res) => {
delete resolved[2].metadata[field]; delete resolved[2].metadata[field];
} }
} }
/* Map Azure OpenAI model to the assistant as defined by config */
if (req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) {
resolved[1] = {
...resolved[1],
model: req.body.model,
};
}
res.json(resolved); res.json(resolved);
} catch (error) { } catch (error) {
const message = 'Trouble updating the Assistant Action'; const message = 'Trouble updating the Assistant Action';
@ -171,7 +184,7 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => {
return true; return true;
}); });
domain = domainParser(req, domain, true); domain = await domainParser(req, domain, true);
const updatedTools = tools.filter( const updatedTools = tools.filter(
(tool) => !(tool.function && tool.function.name.includes(domain)), (tool) => !(tool.function && tool.function.name.includes(domain)),
@ -179,7 +192,8 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => {
const promises = []; const promises = [];
promises.push( promises.push(
updateAssistant( withSession(
updateAssistant,
{ assistant_id }, { assistant_id },
{ {
actions: updatedActions, actions: updatedActions,
@ -188,7 +202,7 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => {
), ),
); );
promises.push(openai.beta.assistants.update(assistant_id, { tools: updatedTools })); promises.push(openai.beta.assistants.update(assistant_id, { tools: updatedTools }));
promises.push(deleteAction({ action_id })); promises.push(withSession(deleteAction, { action_id }));
await Promise.all(promises); await Promise.all(promises);
res.status(200).json({ message: 'Action deleted successfully' }); res.status(200).json({ message: 'Action deleted successfully' });

View file

@ -1,20 +1,27 @@
const { AuthTypeEnum, EModelEndpoint, actionDomainSeparator } = require('librechat-data-provider'); const {
AuthTypeEnum,
EModelEndpoint,
actionDomainSeparator,
CacheKeys,
Constants,
} = require('librechat-data-provider');
const { encryptV2, decryptV2 } = require('~/server/utils/crypto'); const { encryptV2, decryptV2 } = require('~/server/utils/crypto');
const { getActions } = require('~/models/Action'); const { getActions } = require('~/models/Action');
const { getLogStores } = require('~/cache');
const { logger } = require('~/config'); const { logger } = require('~/config');
/** /**
* Parses the domain for an action. * Encodes or decodes a domain name to/from base64, or replacing periods with a custom separator.
* *
* Azure OpenAI Assistants API doesn't support periods in function * Necessary because Azure OpenAI Assistants API doesn't support periods in function
* names due to `[a-zA-Z0-9_-]*` Regex Validation. * names due to `[a-zA-Z0-9_-]*` Regex Validation, limited to a 64-character maximum.
* *
* @param {Express.Request} req - Express Request object * @param {Express.Request} req - The Express Request object.
* @param {string} domain - The domain for the actoin * @param {string} domain - The domain name to encode/decode.
* @param {boolean} inverse - If true, replaces periods with `actionDomainSeparator` * @param {boolean} inverse - False to decode from base64, true to encode to base64.
* @returns {string} The parsed domain * @returns {Promise<string>} Encoded or decoded domain string.
*/ */
function domainParser(req, domain, inverse = false) { async function domainParser(req, domain, inverse = false) {
if (!domain) { if (!domain) {
return; return;
} }
@ -23,11 +30,35 @@ function domainParser(req, domain, inverse = false) {
return domain; return domain;
} }
if (inverse) { const domainsCache = getLogStores(CacheKeys.ENCODED_DOMAINS);
const cachedDomain = await domainsCache.get(domain);
if (inverse && cachedDomain) {
return domain;
}
if (inverse && domain.length <= Constants.ENCODED_DOMAIN_LENGTH) {
return domain.replace(/\./g, actionDomainSeparator); return domain.replace(/\./g, actionDomainSeparator);
} }
return domain.replace(actionDomainSeparator, '.'); if (inverse) {
const modifiedDomain = Buffer.from(domain).toString('base64');
const key = modifiedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH);
await domainsCache.set(key, modifiedDomain);
return key;
}
const replaceSeparatorRegex = new RegExp(actionDomainSeparator, 'g');
if (!cachedDomain) {
return domain.replace(replaceSeparatorRegex, '.');
}
try {
return Buffer.from(cachedDomain, 'base64').toString('utf-8');
} catch (error) {
logger.error(`Failed to parse domain (possibly not base64): ${domain}`, error);
return domain;
}
} }
/** /**

View file

@ -0,0 +1,196 @@
const { Constants, EModelEndpoint, actionDomainSeparator } = require('librechat-data-provider');
const { domainParser } = require('./ActionService');
jest.mock('keyv');
const globalCache = {};
jest.mock('~/cache/getLogStores', () => {
return jest.fn().mockImplementation(() => {
const EventEmitter = require('events');
const { CacheKeys } = require('librechat-data-provider');
class KeyvMongo extends EventEmitter {
constructor(url = 'mongodb://127.0.0.1:27017', options) {
super();
this.ttlSupport = false;
url = url ?? {};
if (typeof url === 'string') {
url = { url };
}
if (url.uri) {
url = { url: url.uri, ...url };
}
this.opts = {
url,
collection: 'keyv',
...url,
...options,
};
}
get = async (key) => {
return new Promise((resolve) => {
resolve(globalCache[key] || null);
});
};
set = async (key, value) => {
return new Promise((resolve) => {
globalCache[key] = value;
resolve(true);
});
};
}
return new KeyvMongo('', {
namespace: CacheKeys.ENCODED_DOMAINS,
ttl: 0,
});
});
});
describe('domainParser', () => {
const req = {
app: {
locals: {
[EModelEndpoint.azureOpenAI]: {
assistants: true,
},
},
},
};
const reqNoAzure = {
app: {
locals: {
[EModelEndpoint.azureOpenAI]: {
assistants: false,
},
},
},
};
const TLD = '.com';
// Non-azure request
it('returns domain as is if not azure', async () => {
const domain = `example.com${actionDomainSeparator}test${actionDomainSeparator}`;
const result1 = await domainParser(reqNoAzure, domain, false);
const result2 = await domainParser(reqNoAzure, domain, true);
expect(result1).toEqual(domain);
expect(result2).toEqual(domain);
});
// Test for Empty or Null Inputs
it('returns undefined for null domain input', async () => {
const result = await domainParser(req, null, true);
expect(result).toBeUndefined();
});
it('returns undefined for empty domain input', async () => {
const result = await domainParser(req, '', true);
expect(result).toBeUndefined();
});
// Verify Correct Caching Behavior
it('caches encoded domain correctly', async () => {
const domain = 'longdomainname.com';
const encodedDomain = Buffer.from(domain)
.toString('base64')
.substring(0, Constants.ENCODED_DOMAIN_LENGTH);
await domainParser(req, domain, true);
const cachedValue = await globalCache[encodedDomain];
expect(cachedValue).toEqual(Buffer.from(domain).toString('base64'));
});
// Test for Edge Cases Around Length Threshold
it('encodes domain exactly at threshold without modification', async () => {
const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - TLD.length) + TLD;
const expected = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, domain, true);
expect(result).toEqual(expected);
});
it('encodes domain just below threshold without modification', async () => {
const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - 1 - TLD.length) + TLD;
const expected = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, domain, true);
expect(result).toEqual(expected);
});
// Test for Unicode Domain Names
it('handles unicode characters in domain names correctly when encoding', async () => {
const unicodeDomain = 'täst.example.com';
const encodedDomain = Buffer.from(unicodeDomain)
.toString('base64')
.substring(0, Constants.ENCODED_DOMAIN_LENGTH);
const result = await domainParser(req, unicodeDomain, true);
expect(result).toEqual(encodedDomain);
});
it('decodes unicode domain names correctly', async () => {
const unicodeDomain = 'täst.example.com';
const encodedDomain = Buffer.from(unicodeDomain).toString('base64');
globalCache[encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH)] = encodedDomain; // Simulate caching
const result = await domainParser(
req,
encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH),
false,
);
expect(result).toEqual(unicodeDomain);
});
// Core Functionality Tests
it('returns domain with replaced separators if no cached domain exists', async () => {
const domain = 'example.com';
const withSeparator = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, withSeparator, false);
expect(result).toEqual(domain);
});
it('returns domain with replaced separators when inverse is false and under encoding length', async () => {
const domain = 'examp.com';
const withSeparator = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, withSeparator, false);
expect(result).toEqual(domain);
});
it('replaces periods with actionDomainSeparator when inverse is true and under encoding length', async () => {
const domain = 'examp.com';
const expected = domain.replace(/\./g, actionDomainSeparator);
const result = await domainParser(req, domain, true);
expect(result).toEqual(expected);
});
it('encodes domain when length is above threshold and inverse is true', async () => {
const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH + 1).concat('.com');
const result = await domainParser(req, domain, true);
expect(result).not.toEqual(domain);
expect(result.length).toBeLessThanOrEqual(Constants.ENCODED_DOMAIN_LENGTH);
});
it('returns encoded value if no encoded value is cached, and inverse is false', async () => {
const originalDomain = 'example.com';
const encodedDomain = Buffer.from(
originalDomain.replace(/\./g, actionDomainSeparator),
).toString('base64');
const result = await domainParser(req, encodedDomain, false);
expect(result).toEqual(encodedDomain);
});
it('decodes encoded value if cached and encoded value is provided, and inverse is false', async () => {
const originalDomain = 'example.com';
const encodedDomain = await domainParser(req, originalDomain, true);
const result = await domainParser(req, encodedDomain, false);
expect(result).toEqual(originalDomain);
});
it('handles invalid base64 encoded values gracefully', async () => {
const invalidBase64Domain = 'not_base64_encoded';
const result = await domainParser(req, invalidBase64Domain, false);
expect(result).toEqual(invalidBase64Domain);
});
});

View file

@ -274,9 +274,16 @@ async function processRequiredActions(client, requiredActions) {
})) ?? []; })) ?? [];
} }
const actionSet = actionSets.find((action) => let actionSet = null;
currentAction.tool.includes(domainParser(client.req, action.metadata.domain, true)), let currentDomain = '';
); for (let action of actionSets) {
const domain = await domainParser(client.req, action.metadata.domain, true);
if (currentAction.tool.includes(domain)) {
currentDomain = domain;
actionSet = action;
break;
}
}
if (!actionSet) { if (!actionSet) {
// TODO: try `function` if no action set is found // TODO: try `function` if no action set is found
@ -298,10 +305,8 @@ async function processRequiredActions(client, requiredActions) {
builders = requestBuilders; builders = requestBuilders;
} }
const functionName = currentAction.tool.replace( const functionName = currentAction.tool.replace(`${actionDelimiter}${currentDomain}`, '');
`${actionDelimiter}${domainParser(client.req, actionSet.metadata.domain, true)}`,
'',
);
const requestBuilder = builders[functionName]; const requestBuilder = builders[functionName];
if (!requestBuilder) { if (!requestBuilder) {

View file

@ -5,6 +5,7 @@ const handleText = require('./handleText');
const cryptoUtils = require('./crypto'); const cryptoUtils = require('./crypto');
const citations = require('./citations'); const citations = require('./citations');
const sendEmail = require('./sendEmail'); const sendEmail = require('./sendEmail');
const mongoose = require('./mongoose');
const queue = require('./queue'); const queue = require('./queue');
const files = require('./files'); const files = require('./files');
const math = require('./math'); const math = require('./math');
@ -14,6 +15,7 @@ module.exports = {
...cryptoUtils, ...cryptoUtils,
...handleText, ...handleText,
...citations, ...citations,
...mongoose,
countTokens, countTokens,
removePorts, removePorts,
sendEmail, sendEmail,

View file

@ -0,0 +1,25 @@
const mongoose = require('mongoose');
/**
* Executes a database operation within a session.
* @param {() => Promise<any>} method - The method to execute. This method must accept a session as its first argument.
* @param {...any} args - Additional arguments to pass to the method.
* @returns {Promise<any>} - The result of the executed method.
*/
async function withSession(method, ...args) {
const session = await mongoose.startSession();
session.startTransaction();
try {
const result = await method(...args, session);
await session.commitTransaction();
return result;
} catch (error) {
if (session.inTransaction()) {
await session.abortTransaction();
}
throw error;
} finally {
await session.endSession();
}
}
module.exports = { withSession };

View file

@ -1,5 +1,5 @@
// import { useState, useEffect } from 'react'; // import { useState, useEffect } from 'react';
import { actionDelimiter, actionDomainSeparator } from 'librechat-data-provider'; import { actionDelimiter, actionDomainSeparator, Constants } from 'librechat-data-provider';
import * as Popover from '@radix-ui/react-popover'; import * as Popover from '@radix-ui/react-popover';
import useLocalize from '~/hooks/useLocalize'; import useLocalize from '~/hooks/useLocalize';
import ProgressCircle from './ProgressCircle'; import ProgressCircle from './ProgressCircle';
@ -63,7 +63,7 @@ export default function ToolCall({
onClick={() => ({})} onClick={() => ({})}
inProgressText={localize('com_assistants_running_action')} inProgressText={localize('com_assistants_running_action')}
finishedText={ finishedText={
domain domain && domain.length !== Constants.ENCODED_DOMAIN_LENGTH
? localize('com_assistants_completed_action', domain) ? localize('com_assistants_completed_action', domain)
: localize('com_assistants_completed_function', function_name) : localize('com_assistants_completed_function', function_name)
} }

View file

@ -1,4 +1,5 @@
import { Fragment } from 'react'; import { Fragment } from 'react';
import { ViolationTypes } from 'librechat-data-provider';
import type { TResPlugin } from 'librechat-data-provider'; import type { TResPlugin } from 'librechat-data-provider';
import type { TMessageContentProps, TText, TDisplayProps } from '~/common'; import type { TMessageContentProps, TText, TDisplayProps } from '~/common';
import { useAuthContext } from '~/hooks'; import { useAuthContext } from '~/hooks';
@ -12,7 +13,7 @@ import Error from './Error';
const ErrorMessage = ({ text }: TText) => { const ErrorMessage = ({ text }: TText) => {
const { logout } = useAuthContext(); const { logout } = useAuthContext();
if (text.includes('ban')) { if (text.includes(ViolationTypes.BAN)) {
logout(); logout();
return null; return null;
} }

View file

@ -15,7 +15,7 @@ export default function AssistantAction({
className="border-token-border-medium flex w-full rounded-lg border text-sm hover:cursor-pointer" className="border-token-border-medium flex w-full rounded-lg border text-sm hover:cursor-pointer"
> >
<div <div
className="h-9 grow px-3 py-2" className="h-9 grow whitespace-nowrap px-3 py-2"
style={{ textOverflow: 'ellipsis', wordBreak: 'break-all', overflow: 'hidden' }} style={{ textOverflow: 'ellipsis', wordBreak: 'break-all', overflow: 'hidden' }}
> >
{action.metadata.domain} {action.metadata.domain}

View file

@ -482,6 +482,15 @@ export enum CacheKeys {
* Key for the override config cache. * Key for the override config cache.
*/ */
OVERRIDE_CONFIG = 'overrideConfig', OVERRIDE_CONFIG = 'overrideConfig',
/**
* Key for the bans cache.
*/
BANS = 'bans',
/**
* Key for the encoded domains cache.
* Used by Azure OpenAI Assistants.
*/
ENCODED_DOMAINS = 'encoded_domains',
} }
/** /**
@ -500,6 +509,10 @@ export enum ViolationTypes {
* Token Limit Violation. * Token Limit Violation.
*/ */
TOKEN_BALANCE = 'token_balance', TOKEN_BALANCE = 'token_balance',
/**
* An issued ban.
*/
BAN = 'ban',
} }
/** /**
@ -580,6 +593,10 @@ export enum Constants {
* Standard value for the first message's `parentMessageId` value, to indicate no parent exists. * Standard value for the first message's `parentMessageId` value, to indicate no parent exists.
*/ */
NO_PARENT = '00000000-0000-0000-0000-000000000000', NO_PARENT = '00000000-0000-0000-0000-000000000000',
/**
* Fixed, encoded domain length for Azure OpenAI Assistants Function name parsing.
*/
ENCODED_DOMAIN_LENGTH = 10,
} }
/** /**