diff --git a/api/models/tx.js b/api/models/tx.js index 82ae9fb034..7bc3a77600 100644 --- a/api/models/tx.js +++ b/api/models/tx.js @@ -1,6 +1,49 @@ const { matchModelName } = require('../utils'); const defaultRate = 6; +const customTokenOverrides = {}; +const customCacheOverrides = {}; + +/** + * Allows overriding the default token multipliers. + * + * @param {Object} overrides - An object mapping model keys to their custom token multipliers. + * @param {Object} overrides. - An object containing custom multipliers for the model. + * @param {number} overrides..prompt - The custom prompt multiplier for the model. + * @param {number} overrides..completion - The custom completion multiplier for the model. + * + * @example + * // Override the multipliers for "gpt-4o-mini" and "gpt-3.5": + * setCustomTokenOverrides({ + * "gpt-4o-mini": { prompt: 0.2, completion: 0.5 }, + * "gpt-3.5": { prompt: 1.0, completion: 2.0 } + * }); + */ +const setCustomTokenOverrides = (overrides) => { + Object.assign(customTokenOverrides, overrides); +}; + +/** + * Allows overriding the default cache multipliers. + * The override values should be nested under a key named "Cache". + * + * @param {Object} overrides - An object mapping model keys to their custom cache multipliers. + * @param {Object} overrides. - An object that must include a "Cache" property. + * @param {Object} overrides..Cache - An object containing custom cache multipliers for the model. + * @param {number} overrides..Cache.write - The custom cache write multiplier for the model. + * @param {number} overrides..Cache.read - The custom cache read multiplier for the model. + * + * @example + * // Override the cache multipliers for "gpt-4o-mini" and "gpt-3.5": + * setCustomCacheOverrides({ + * "gpt-4o-mini": { cache: { write: 0.2, read: 0.5 } }, + * "gpt-3.5": { cache: { write: 1.0, read: 1.5 } } + * }); + */ +const setCustomCacheOverrides = (overrides) => { + Object.assign(customCacheOverrides, overrides); +}; + /** * AWS Bedrock pricing * source: https://aws.amazon.com/bedrock/pricing/ @@ -243,20 +286,23 @@ const getCacheMultiplier = ({ valueKey, cacheType, model, endpoint, endpointToke return endpointTokenConfig?.[model]?.[cacheType] ?? null; } - if (valueKey && cacheType) { - return cacheTokenValues[valueKey]?.[cacheType] ?? null; + if (!valueKey && model) { + valueKey = getValueKey(model, endpoint); } - - if (!cacheType || !model) { - return null; - } - - valueKey = getValueKey(model, endpoint); if (!valueKey) { return null; } - // If we got this far, and values[cacheType] is undefined somehow, return a rough average of default multipliers + // Check for custom cache overrides under the "cache" property. + if ( + customCacheOverrides[valueKey] && + customCacheOverrides[valueKey].cache && + customCacheOverrides[valueKey].cache[cacheType] != null + ) { + return customCacheOverrides[valueKey].cache[cacheType]; + } + + // Fallback to the default cacheTokenValues. return cacheTokenValues[valueKey]?.[cacheType] ?? null; }; @@ -267,4 +313,6 @@ module.exports = { getCacheMultiplier, defaultRate, cacheTokenValues, + setCustomTokenOverrides, + setCustomCacheOverrides, }; diff --git a/api/server/services/AppService.js b/api/server/services/AppService.js index 09827c9244..041bd36163 100644 --- a/api/server/services/AppService.js +++ b/api/server/services/AppService.js @@ -25,10 +25,7 @@ const AppService = async (app) => { /** @type {TCustomConfig} */ const config = (await loadCustomConfig()) ?? {}; const configDefaults = getConfigDefaults(); - const tokenRatesConfig = loadTokenRatesConfig(config, configDefaults); - // - // // Set the global token rates configuration so that it can be used by the tx.js functions. - // setTokenRatesConfig(tokenRatesConfig); + loadTokenRatesConfig(config, configDefaults); const filteredTools = config.filteredTools; const includedTools = config.includedTools; diff --git a/api/server/services/Config/loadTokenRatesConfig.js b/api/server/services/Config/loadTokenRatesConfig.js index cc7d03b8a0..8513453e4e 100644 --- a/api/server/services/Config/loadTokenRatesConfig.js +++ b/api/server/services/Config/loadTokenRatesConfig.js @@ -1,8 +1,9 @@ const { removeNullishValues } = require('librechat-data-provider'); const { logger } = require('~/config'); +const { setCustomTokenOverrides, setCustomCacheOverrides } = require('~/models/tx'); /** - * Loads custom token rates from the user's YAML config, merging with default token rates if available. + * Loads token rates from the user's configuration, merging with default token rates if available. * * @param {TCustomConfig | undefined} config - The loaded custom configuration. * @param {TConfigDefaults} [configDefaults] - Optional default configuration values. @@ -13,6 +14,8 @@ function loadTokenRatesConfig(config, configDefaults) { if (!configDefaults?.tokenRates) { logger.info(`User tokenRates configuration:\n${JSON.stringify(userTokenRates, null, 2)}`); + // Apply custom token rates even if there are no defaults + applyCustomTokenRates(userTokenRates); return userTokenRates; } @@ -20,8 +23,49 @@ function loadTokenRatesConfig(config, configDefaults) { const defaultTokenRates = removeNullishValues(configDefaults.tokenRates); const merged = { ...defaultTokenRates, ...userTokenRates }; + // Apply custom token rates configuration + applyCustomTokenRates(merged); + logger.info(`Merged tokenRates configuration:\n${JSON.stringify(merged, null, 2)}`); return merged; } +/** + * Processes the token rates configuration to set up custom overrides for each model. + * + * The configuration is expected to be specified per model: + * + * For each model in the tokenRates configuration, this function will call the tx.js + * override functions to apply the custom token and cache multipliers. + * + * @param {TModelTokenRates} tokenRates - The token rates configuration mapping models to token costs. + */ +function applyCustomTokenRates(tokenRates) { + // Iterate over each model in the tokenRates configuration. + Object.keys(tokenRates).forEach((model) => { + const rate = tokenRates[model]; + // If token multipliers are provided, set custom token overrides. + if (rate.prompt != null || rate.completion != null) { + setCustomTokenOverrides({ + [model]: { + prompt: rate.prompt, + completion: rate.completion, + }, + }); + } + // Check for cache overrides. + const cacheOverrides = rate.cache; + if (cacheOverrides && (cacheOverrides.write != null || cacheOverrides.read != null)) { + setCustomCacheOverrides({ + [model]: { + cache: { + write: cacheOverrides.write, + read: cacheOverrides.read, + }, + }, + }); + } + }); +} + module.exports = { loadTokenRatesConfig }; \ No newline at end of file diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index 5dbf8f90ee..55097b90e7 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -505,7 +505,7 @@ export type TStartupConfig = { helpAndFaqURL: string; customFooter?: string; modelSpecs?: TSpecsConfig; - tokenRates?: TTokenRates; + tokenRates?: TModelTokenRates; sharedLinksEnabled: boolean; publicSharedLinksEnabled: boolean; analyticsGtmId?: string; @@ -523,16 +523,7 @@ export type TTokenCost = { }; // Endpoint token rates schema type -export type TEndpointTokenRates = Record; - -// Token rates schema type -export type TTokenRates = { - openAI?: TEndpointTokenRates; - google?: TEndpointTokenRates; - anthropic?: TEndpointTokenRates; - bedrock?: TEndpointTokenRates; - custom?: TEndpointTokenRates; -}; +export type TModelTokenRates = Record; const tokenCostSchema = z.object({ prompt: z.number().optional(), // e.g. 1.5 => $1.50 / 1M tokens @@ -545,16 +536,6 @@ const tokenCostSchema = z.object({ .optional(), }); -const endpointTokenRatesSchema = z.record(z.string(), tokenCostSchema); - -const tokenRatesSchema = z.object({ - openAI: endpointTokenRatesSchema.optional(), - google: endpointTokenRatesSchema.optional(), - anthropic: endpointTokenRatesSchema.optional(), - bedrock: endpointTokenRatesSchema.optional(), - custom: endpointTokenRatesSchema.optional(), -}); - export const configSchema = z.object({ version: z.string(), cache: z.boolean().default(true), @@ -586,7 +567,7 @@ export const configSchema = z.object({ rateLimits: rateLimitSchema.optional(), fileConfig: fileConfigSchema.optional(), modelSpecs: specsConfigSchema.optional(), - tokenRates: tokenRatesSchema.optional(), + tokenRates: tokenCostSchema.optional(), endpoints: z .object({ all: baseEndpointSchema.optional(),