feat: Refactor Token Rates Configuration and Introduce Custom Overrides

This commit is contained in:
Ruben Talstra 2025-02-27 10:57:42 +01:00
parent 7dfb386f5a
commit 262e6aa4c9
No known key found for this signature in database
GPG key ID: 2A5A7174A60F3BEA
4 changed files with 106 additions and 36 deletions

View file

@ -1,6 +1,49 @@
const { matchModelName } = require('../utils');
const defaultRate = 6;
const customTokenOverrides = {};
const customCacheOverrides = {};
/**
* Allows overriding the default token multipliers.
*
* @param {Object} overrides - An object mapping model keys to their custom token multipliers.
* @param {Object} overrides.<model> - An object containing custom multipliers for the model.
* @param {number} overrides.<model>.prompt - The custom prompt multiplier for the model.
* @param {number} overrides.<model>.completion - The custom completion multiplier for the model.
*
* @example
* // Override the multipliers for "gpt-4o-mini" and "gpt-3.5":
* setCustomTokenOverrides({
* "gpt-4o-mini": { prompt: 0.2, completion: 0.5 },
* "gpt-3.5": { prompt: 1.0, completion: 2.0 }
* });
*/
const setCustomTokenOverrides = (overrides) => {
Object.assign(customTokenOverrides, overrides);
};
/**
* Allows overriding the default cache multipliers.
* The override values should be nested under a key named "Cache".
*
* @param {Object} overrides - An object mapping model keys to their custom cache multipliers.
* @param {Object} overrides.<model> - An object that must include a "Cache" property.
* @param {Object} overrides.<model>.Cache - An object containing custom cache multipliers for the model.
* @param {number} overrides.<model>.Cache.write - The custom cache write multiplier for the model.
* @param {number} overrides.<model>.Cache.read - The custom cache read multiplier for the model.
*
* @example
* // Override the cache multipliers for "gpt-4o-mini" and "gpt-3.5":
* setCustomCacheOverrides({
* "gpt-4o-mini": { cache: { write: 0.2, read: 0.5 } },
* "gpt-3.5": { cache: { write: 1.0, read: 1.5 } }
* });
*/
const setCustomCacheOverrides = (overrides) => {
Object.assign(customCacheOverrides, overrides);
};
/**
* AWS Bedrock pricing
* source: https://aws.amazon.com/bedrock/pricing/
@ -243,20 +286,23 @@ const getCacheMultiplier = ({ valueKey, cacheType, model, endpoint, endpointToke
return endpointTokenConfig?.[model]?.[cacheType] ?? null;
}
if (valueKey && cacheType) {
return cacheTokenValues[valueKey]?.[cacheType] ?? null;
if (!valueKey && model) {
valueKey = getValueKey(model, endpoint);
}
if (!cacheType || !model) {
return null;
}
valueKey = getValueKey(model, endpoint);
if (!valueKey) {
return null;
}
// If we got this far, and values[cacheType] is undefined somehow, return a rough average of default multipliers
// Check for custom cache overrides under the "cache" property.
if (
customCacheOverrides[valueKey] &&
customCacheOverrides[valueKey].cache &&
customCacheOverrides[valueKey].cache[cacheType] != null
) {
return customCacheOverrides[valueKey].cache[cacheType];
}
// Fallback to the default cacheTokenValues.
return cacheTokenValues[valueKey]?.[cacheType] ?? null;
};
@ -267,4 +313,6 @@ module.exports = {
getCacheMultiplier,
defaultRate,
cacheTokenValues,
setCustomTokenOverrides,
setCustomCacheOverrides,
};

View file

@ -25,10 +25,7 @@ const AppService = async (app) => {
/** @type {TCustomConfig} */
const config = (await loadCustomConfig()) ?? {};
const configDefaults = getConfigDefaults();
const tokenRatesConfig = loadTokenRatesConfig(config, configDefaults);
//
// // Set the global token rates configuration so that it can be used by the tx.js functions.
// setTokenRatesConfig(tokenRatesConfig);
loadTokenRatesConfig(config, configDefaults);
const filteredTools = config.filteredTools;
const includedTools = config.includedTools;

View file

@ -1,8 +1,9 @@
const { removeNullishValues } = require('librechat-data-provider');
const { logger } = require('~/config');
const { setCustomTokenOverrides, setCustomCacheOverrides } = require('~/models/tx');
/**
* Loads custom token rates from the user's YAML config, merging with default token rates if available.
* Loads token rates from the user's configuration, merging with default token rates if available.
*
* @param {TCustomConfig | undefined} config - The loaded custom configuration.
* @param {TConfigDefaults} [configDefaults] - Optional default configuration values.
@ -13,6 +14,8 @@ function loadTokenRatesConfig(config, configDefaults) {
if (!configDefaults?.tokenRates) {
logger.info(`User tokenRates configuration:\n${JSON.stringify(userTokenRates, null, 2)}`);
// Apply custom token rates even if there are no defaults
applyCustomTokenRates(userTokenRates);
return userTokenRates;
}
@ -20,8 +23,49 @@ function loadTokenRatesConfig(config, configDefaults) {
const defaultTokenRates = removeNullishValues(configDefaults.tokenRates);
const merged = { ...defaultTokenRates, ...userTokenRates };
// Apply custom token rates configuration
applyCustomTokenRates(merged);
logger.info(`Merged tokenRates configuration:\n${JSON.stringify(merged, null, 2)}`);
return merged;
}
/**
* Processes the token rates configuration to set up custom overrides for each model.
*
* The configuration is expected to be specified per model:
*
* For each model in the tokenRates configuration, this function will call the tx.js
* override functions to apply the custom token and cache multipliers.
*
* @param {TModelTokenRates} tokenRates - The token rates configuration mapping models to token costs.
*/
function applyCustomTokenRates(tokenRates) {
// Iterate over each model in the tokenRates configuration.
Object.keys(tokenRates).forEach((model) => {
const rate = tokenRates[model];
// If token multipliers are provided, set custom token overrides.
if (rate.prompt != null || rate.completion != null) {
setCustomTokenOverrides({
[model]: {
prompt: rate.prompt,
completion: rate.completion,
},
});
}
// Check for cache overrides.
const cacheOverrides = rate.cache;
if (cacheOverrides && (cacheOverrides.write != null || cacheOverrides.read != null)) {
setCustomCacheOverrides({
[model]: {
cache: {
write: cacheOverrides.write,
read: cacheOverrides.read,
},
},
});
}
});
}
module.exports = { loadTokenRatesConfig };

View file

@ -505,7 +505,7 @@ export type TStartupConfig = {
helpAndFaqURL: string;
customFooter?: string;
modelSpecs?: TSpecsConfig;
tokenRates?: TTokenRates;
tokenRates?: TModelTokenRates;
sharedLinksEnabled: boolean;
publicSharedLinksEnabled: boolean;
analyticsGtmId?: string;
@ -523,16 +523,7 @@ export type TTokenCost = {
};
// Endpoint token rates schema type
export type TEndpointTokenRates = Record<string, TTokenCost>;
// Token rates schema type
export type TTokenRates = {
openAI?: TEndpointTokenRates;
google?: TEndpointTokenRates;
anthropic?: TEndpointTokenRates;
bedrock?: TEndpointTokenRates;
custom?: TEndpointTokenRates;
};
export type TModelTokenRates = Record<string, TTokenCost>;
const tokenCostSchema = z.object({
prompt: z.number().optional(), // e.g. 1.5 => $1.50 / 1M tokens
@ -545,16 +536,6 @@ const tokenCostSchema = z.object({
.optional(),
});
const endpointTokenRatesSchema = z.record(z.string(), tokenCostSchema);
const tokenRatesSchema = z.object({
openAI: endpointTokenRatesSchema.optional(),
google: endpointTokenRatesSchema.optional(),
anthropic: endpointTokenRatesSchema.optional(),
bedrock: endpointTokenRatesSchema.optional(),
custom: endpointTokenRatesSchema.optional(),
});
export const configSchema = z.object({
version: z.string(),
cache: z.boolean().default(true),
@ -586,7 +567,7 @@ export const configSchema = z.object({
rateLimits: rateLimitSchema.optional(),
fileConfig: fileConfigSchema.optional(),
modelSpecs: specsConfigSchema.optional(),
tokenRates: tokenRatesSchema.optional(),
tokenRates: tokenCostSchema.optional(),
endpoints: z
.object({
all: baseEndpointSchema.optional(),