diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js
index aa39084b9..3b919c92f 100644
--- a/api/app/clients/BaseClient.js
+++ b/api/app/clients/BaseClient.js
@@ -428,7 +428,10 @@ class BaseClient {
       await this.saveMessageToDatabase(userMessage, saveOptions, user);
     }
 
-    if (isEnabled(process.env.CHECK_BALANCE) && supportsBalanceCheck[this.options.endpoint]) {
+    if (
+      isEnabled(process.env.CHECK_BALANCE) &&
+      supportsBalanceCheck[this.options.endpointType ?? this.options.endpoint]
+    ) {
       await checkBalance({
         req: this.options.req,
         res: this.options.res,
@@ -438,6 +441,7 @@ class BaseClient {
           amount: promptTokens,
           model: this.modelOptions.model,
          endpoint: this.options.endpoint,
+          endpointTokenConfig: this.options.endpointTokenConfig,
         },
       });
     }
diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js
index 96f4bb1b0..f9c551097 100644
--- a/api/app/clients/OpenAIClient.js
+++ b/api/app/clients/OpenAIClient.js
@@ -131,8 +131,13 @@ class OpenAIClient extends BaseClient {
     const { isChatGptModel } = this;
     this.isUnofficialChatGptModel =
       model.startsWith('text-chat') || model.startsWith('text-davinci-002-render');
+
     this.maxContextTokens =
-      getModelMaxTokens(model, this.options.endpointType ?? this.options.endpoint) ?? 4095; // 1 less than maximum
+      getModelMaxTokens(
+        model,
+        this.options.endpointType ?? this.options.endpoint,
+        this.options.endpointTokenConfig,
+      ) ?? 4095; // 1 less than maximum
 
     if (this.shouldSummarize) {
       this.maxContextTokens = Math.floor(this.maxContextTokens / 2);
@@ -780,7 +785,12 @@ ${convo}
     // TODO: remove the gpt fallback and make it specific to endpoint
     const { OPENAI_SUMMARY_MODEL = 'gpt-3.5-turbo' } = process.env ?? {};
     const model = this.options.summaryModel ?? OPENAI_SUMMARY_MODEL;
-    const maxContextTokens = getModelMaxTokens(model) ?? 4095;
+    const maxContextTokens =
+      getModelMaxTokens(
+        model,
+        this.options.endpointType ?? this.options.endpoint,
+        this.options.endpointTokenConfig,
+      ) ?? 4095; // 1 less than maximum
 
     // 3 tokens for the assistant label, and 98 for the summarizer prompt (101)
     let promptBuffer = 101;
@@ -886,6 +896,7 @@ ${convo}
         model: this.modelOptions.model,
         context: 'message',
         conversationId: this.conversationId,
+        endpointTokenConfig: this.options.endpointTokenConfig,
       },
       { promptTokens, completionTokens },
     );
diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js
index 016c77000..03aaf8470 100644
--- a/api/cache/getLogStores.js
+++ b/api/cache/getLogStores.js
@@ -23,6 +23,10 @@ const config = isEnabled(USE_REDIS)
   ? new Keyv({ store: keyvRedis })
   : new Keyv({ namespace: CacheKeys.CONFIG_STORE });
 
+const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes
+  ? new Keyv({ store: keyvRedis, ttl: 1800000 })
+  : new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: 1800000 });
+
 const namespaces = {
   [CacheKeys.CONFIG_STORE]: config,
   pending_req,
@@ -34,6 +38,7 @@ const namespaces = {
   token_balance: createViolationInstance('token_balance'),
   registrations: createViolationInstance('registrations'),
   logins: createViolationInstance('logins'),
+  [CacheKeys.TOKEN_CONFIG]: tokenConfig,
 };
 
 /**
diff --git a/api/models/Balance.js b/api/models/Balance.js
index 45dec6963..24d9087b7 100644
--- a/api/models/Balance.js
+++ b/api/models/Balance.js
@@ -10,8 +10,9 @@ balanceSchema.statics.check = async function ({
   valueKey,
   tokenType,
   amount,
+  endpointTokenConfig,
 }) {
-  const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint });
+  const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint, endpointTokenConfig });
   const tokenCost = amount * multiplier;
   const { tokenCredits: balance } = (await this.findOne({ user }, 'tokenCredits').lean()) ?? {};
 
@@ -24,6 +25,7 @@ balanceSchema.statics.check = async function ({
     amount,
     balance,
     multiplier,
+    endpointTokenConfig: !!endpointTokenConfig,
   });
 
   if (!balance) {
diff --git a/api/models/Transaction.js b/api/models/Transaction.js
index 0bc26fc37..e60820359 100644
--- a/api/models/Transaction.js
+++ b/api/models/Transaction.js
@@ -10,8 +10,8 @@ transactionSchema.methods.calculateTokenValue = function () {
   if (!this.valueKey || !this.tokenType) {
     this.tokenValue = this.rawAmount;
   }
-  const { valueKey, tokenType, model } = this;
-  const multiplier = getMultiplier({ valueKey, tokenType, model });
+  const { valueKey, tokenType, model, endpointTokenConfig } = this;
+  const multiplier = getMultiplier({ valueKey, tokenType, model, endpointTokenConfig });
   this.rate = multiplier;
   this.tokenValue = this.rawAmount * multiplier;
   if (this.context && this.tokenType === 'completion' && this.context === 'incomplete') {
@@ -25,6 +25,7 @@ transactionSchema.statics.create = async function (transactionData) {
   const Transaction = this;
 
   const transaction = new Transaction(transactionData);
+  transaction.endpointTokenConfig = transactionData.endpointTokenConfig;
   transaction.calculateTokenValue();
 
   // Save the transaction
diff --git a/api/models/checkBalance.js b/api/models/checkBalance.js
index c0bbd060b..87798166e 100644
--- a/api/models/checkBalance.js
+++ b/api/models/checkBalance.js
@@ -14,6 +14,7 @@ const { logViolation } = require('../cache');
  * @param {('prompt' | 'completion')} params.txData.tokenType - The type of token.
  * @param {number} params.txData.amount - The amount of tokens.
  * @param {string} params.txData.model - The model name or identifier.
+ * @param {EndpointTokenConfig} [params.txData.endpointTokenConfig] - The token configuration for the endpoint.
  * @returns {Promise<boolean>} Returns true if the user can spend the amount, otherwise denies the request.
  * @throws {Error} Throws an error if there's an issue with the balance check.
  */
diff --git a/api/models/spendTokens.js b/api/models/spendTokens.js
index fe3a2be87..3687d5512 100644
--- a/api/models/spendTokens.js
+++ b/api/models/spendTokens.js
@@ -11,6 +11,7 @@ const { logger } = require('~/config');
  * @param {String} txData.conversationId - The ID of the conversation.
  * @param {String} txData.model - The model name.
  * @param {String} txData.context - The context in which the transaction is made.
+ * @param {EndpointTokenConfig} [txData.endpointTokenConfig] - The current endpoint token config.
  * @param {String} [txData.valueKey] - The value key (optional).
  * @param {Object} tokenUsage - The number of tokens used.
  * @param {Number} tokenUsage.promptTokens - The number of prompt tokens used.
diff --git a/api/models/tx.js b/api/models/tx.js
index c0283de0b..d3be0d869 100644
--- a/api/models/tx.js
+++ b/api/models/tx.js
@@ -57,9 +57,14 @@ const getValueKey = (model, endpoint) => {
  * @param {string} [params.tokenType] - The type of token (e.g., 'prompt' or 'completion').
  * @param {string} [params.model] - The model name to derive the value key from if not provided.
  * @param {string} [params.endpoint] - The endpoint name to derive the value key from if not provided.
+ * @param {EndpointTokenConfig} [params.endpointTokenConfig] - The token configuration for the endpoint.
  * @returns {number} The multiplier for the given parameters, or a default value if not found.
  */
-const getMultiplier = ({ valueKey, tokenType, model, endpoint }) => {
+const getMultiplier = ({ valueKey, tokenType, model, endpoint, endpointTokenConfig }) => {
+  if (endpointTokenConfig) {
+    return endpointTokenConfig?.[model]?.[tokenType] ?? defaultRate;
+  }
+
   if (valueKey && tokenType) {
     return tokenValues[valueKey][tokenType] ?? defaultRate;
   }
diff --git a/api/server/services/Config/loadConfigModels.js b/api/server/services/Config/loadConfigModels.js
index 0abe15a8a..aff861493 100644
--- a/api/server/services/Config/loadConfigModels.js
+++ b/api/server/services/Config/loadConfigModels.js
@@ -49,7 +49,7 @@ async function loadConfigModels() {
 
     if (models.fetch && !isUserProvided(API_KEY) && !isUserProvided(BASE_URL)) {
       fetchPromisesMap[BASE_URL] =
-        fetchPromisesMap[BASE_URL] || fetchModels({ baseURL: BASE_URL, apiKey: API_KEY });
+        fetchPromisesMap[BASE_URL] || fetchModels({ baseURL: BASE_URL, apiKey: API_KEY, name });
       baseUrlToNameMap[BASE_URL] = baseUrlToNameMap[BASE_URL] || [];
       baseUrlToNameMap[BASE_URL].push(name);
       continue;
diff --git a/api/server/services/Endpoints/custom/initializeClient.js b/api/server/services/Endpoints/custom/initializeClient.js
index 978506b7b..e5c9a62e8 100644
--- a/api/server/services/Endpoints/custom/initializeClient.js
+++ b/api/server/services/Endpoints/custom/initializeClient.js
@@ -1,7 +1,9 @@
-const { EModelEndpoint } = require('librechat-data-provider');
+const { EModelEndpoint, CacheKeys } = require('librechat-data-provider');
 const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
 const { isUserProvided, extractEnvVariable } = require('~/server/utils');
+const { fetchModels } = require('~/server/services/ModelService');
 const getCustomConfig = require('~/cache/getCustomConfig');
+const getLogStores = require('~/cache/getLogStores');
 const { OpenAIClient } = require('~/app');
 
 const envVarRegex = /^\${(.+)}$/;
@@ -37,6 +39,13 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     throw new Error(`Missing Base URL for ${endpoint}.`);
   }
 
+  const cache = getLogStores(CacheKeys.TOKEN_CONFIG);
+  let endpointTokenConfig = await cache.get(endpoint);
+  if (!endpointTokenConfig) {
+    await fetchModels({ apiKey: CUSTOM_API_KEY, baseURL: CUSTOM_BASE_URL, name: endpoint });
+    endpointTokenConfig = await cache.get(endpoint);
+  }
+
   const customOptions = {
     headers: resolvedHeaders,
     addParams: endpointConfig.addParams,
@@ -48,6 +57,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     modelDisplayLabel: endpointConfig.modelDisplayLabel,
     titleMethod: endpointConfig.titleMethod ?? 'completion',
     contextStrategy: endpointConfig.summarize ? 'summarize' : null,
+    endpointTokenConfig,
   };
 
   const useUserKey = isUserProvided(CUSTOM_API_KEY);
diff --git a/api/server/services/ModelService.js b/api/server/services/ModelService.js
index 76ac06154..9a9f5238f 100644
--- a/api/server/services/ModelService.js
+++ b/api/server/services/ModelService.js
@@ -1,10 +1,11 @@
 const Keyv = require('keyv');
 const axios = require('axios');
 const HttpsProxyAgent = require('https-proxy-agent');
-const { EModelEndpoint, defaultModels } = require('librechat-data-provider');
+const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider');
+const { extractBaseURL, inputSchema, processModelData } = require('~/utils');
+const getLogStores = require('~/cache/getLogStores');
 const { isEnabled } = require('~/server/utils');
 const keyvRedis = require('~/cache/keyvRedis');
-const { extractBaseURL } = require('~/utils');
 const { logger } = require('~/config');
 
 // const { getAzureCredentials, genAzureChatCompletion } = require('~/utils/');
@@ -32,10 +33,17 @@ const {
  * @param {string} params.baseURL - The base path URL for the API.
  * @param {string} [params.name='OpenAI'] - The name of the API; defaults to 'OpenAI'.
  * @param {boolean} [params.azure=false] - Whether to fetch models from Azure.
+ * @param {boolean} [params.createTokenConfig=true] - Whether to create a token configuration from the API response.
  * @returns {Promise<string[]>} A promise that resolves to an array of model identifiers.
  * @async
  */
-const fetchModels = async ({ apiKey, baseURL, name = 'OpenAI', azure = false }) => {
+const fetchModels = async ({
+  apiKey,
+  baseURL,
+  name = 'OpenAI',
+  azure = false,
+  createTokenConfig = true,
+}) => {
   let models = [];
 
   if (!baseURL && !azure) {
@@ -58,7 +66,16 @@ const fetchModels = async ({ apiKey, baseURL, name = 'OpenAI', azure = false })
     }
 
     const res = await axios.get(`${baseURL}${azure ? '' : '/models'}`, payload);
-    models = res.data.data.map((item) => item.id);
+    /** @type {z.infer<typeof inputSchema>} */
+    const input = res.data;
+
+    const validationResult = inputSchema.safeParse(input);
+    if (validationResult.success && createTokenConfig) {
+      const endpointTokenConfig = processModelData(input);
+      const cache = getLogStores(CacheKeys.TOKEN_CONFIG);
+      await cache.set(name, endpointTokenConfig);
+    }
+    models = input.data.map((item) => item.id);
   } catch (err) {
     logger.error(`Failed to fetch models from ${azure ? 'Azure ' : ''}${name} API`, err);
   }
diff --git a/api/typedefs.js b/api/typedefs.js
index 7bb956c9a..35b4e993f 100644
--- a/api/typedefs.js
+++ b/api/typedefs.js
@@ -385,3 +385,18 @@
  * @property {string} [azureOpenAIApiVersion] - The Azure OpenAI API version.
  * @memberof typedefs
  */
+
+/**
+ * @typedef {Object} TokenConfig
+ * A configuration object for a single model, holding its prompt rate, completion rate, and context limit.
+ * @property {number} prompt - The prompt rate.
+ * @property {number} completion - The completion rate.
+ * @property {number} context - The maximum context length supported by the model.
+ * @memberof typedefs
+ */
+
+/**
+ * @typedef {Record<string, TokenConfig>} EndpointTokenConfig
+ * An endpoint's config object mapping model keys to their respective prompt, completion rates, and context limit.
+ * @memberof typedefs
+ */
diff --git a/api/utils/tokens.js b/api/utils/tokens.js
index fb6e363d8..d2bc3d9bb 100644
--- a/api/utils/tokens.js
+++ b/api/utils/tokens.js
@@ -1,3 +1,4 @@
+const z = require('zod');
 const { EModelEndpoint } = require('librechat-data-provider');
 
 const models = [
@@ -91,6 +92,7 @@ const maxTokensMap = {
  *
  * @param {string} modelName - The name of the model to look up.
  * @param {string} endpoint - The endpoint (default is 'openAI').
+ * @param {EndpointTokenConfig} [endpointTokenConfig] - Token config for the current endpoint to use for the max tokens lookup.
  * @returns {number|undefined} The maximum tokens for the given model or undefined if no match is found.
  *
  * @example
@@ -98,16 +100,21 @@ const maxTokensMap = {
 * getModelMaxTokens('gpt-4-32k-unknown'); // Returns 32767
 * getModelMaxTokens('unknown-model'); // Returns undefined
 */
-function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI) {
+function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI, endpointTokenConfig) {
   if (typeof modelName !== 'string') {
     return undefined;
   }
 
-  const tokensMap = maxTokensMap[endpoint];
+  /** @type {EndpointTokenConfig | Record<string, number>} */
+  const tokensMap = endpointTokenConfig ?? maxTokensMap[endpoint];
   if (!tokensMap) {
     return undefined;
   }
 
+  if (tokensMap[modelName]?.context) {
+    return tokensMap[modelName].context;
+  }
+
   if (tokensMap[modelName]) {
     return tokensMap[modelName];
   }
@@ -115,7 +122,8 @@ function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI) {
   const keys = Object.keys(tokensMap);
   for (let i = keys.length - 1; i >= 0; i--) {
     if (modelName.includes(keys[i])) {
-      return tokensMap[keys[i]];
+      const result = tokensMap[keys[i]];
+      return result?.context ?? result;
     }
   }
 
@@ -160,9 +168,55 @@ function matchModelName(modelName, endpoint = EModelEndpoint.openAI) {
   return modelName;
 }
 
+const modelSchema = z.object({
+  id: z.string(),
+  pricing: z.object({
+    prompt: z.string(),
+    completion: z.string(),
+  }),
+  context_length: z.number(),
+});
+
+const inputSchema = z.object({
+  data: z.array(modelSchema),
+});
+
+/**
+ * Processes a list of model data from an API and organizes it into an endpoint token config keyed by model, capturing rates and context limits.
+ * @param {{ data: Array<z.infer<typeof modelSchema>> }} input - The input object containing the model data fetched from the API.
+ * @returns {EndpointTokenConfig} The processed model data.
+ */
+function processModelData(input) {
+  const validationResult = inputSchema.safeParse(input);
+  if (!validationResult.success) {
+    throw new Error('Invalid input data');
+  }
+  const { data } = validationResult.data;
+
+  /** @type {EndpointTokenConfig} */
+  const tokenConfig = {};
+
+  for (const model of data) {
+    const modelKey = model.id;
+    const prompt = parseFloat(model.pricing.prompt) * 1000000;
+    const completion = parseFloat(model.pricing.completion) * 1000000;
+
+    tokenConfig[modelKey] = {
+      prompt,
+      completion,
+      context: model.context_length,
+    };
+  }
+
+  return tokenConfig;
+}
+
 module.exports = {
   tiktokenModels: new Set(models),
   maxTokensMap,
+  inputSchema,
+  modelSchema,
   getModelMaxTokens,
   matchModelName,
+  processModelData,
 };
diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts
index 74e4a330b..e21bf8fef 100644
--- a/packages/data-provider/src/config.ts
+++ b/packages/data-provider/src/config.ts
@@ -168,6 +168,10 @@ export enum CacheKeys {
   /**
    * Key for the default endpoint config cache.
    */
   ENDPOINT_CONFIG = 'endpointsConfig',
+  /**
+   * Key for accessing the model token config cache.
+   */
+  TOKEN_CONFIG = 'tokenConfig',
   /**
    * Key for the custom config cache.
    */
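
To illustrate the rate-resolution order the updated getMultiplier in api/models/tx.js implements: a per-endpoint token config, when present, takes precedence over the static tokenValues table, with defaultRate as the final fallback. The sketch below is a self-contained approximation rather than the module itself; the model name and rates are hypothetical.

// Minimal sketch of getMultiplier's resolution order (hypothetical rates).
const defaultRate = 6;
const tokenValues = { '8k': { prompt: 30, completion: 60 } };

function getMultiplier({ valueKey, tokenType, model, endpointTokenConfig }) {
  if (endpointTokenConfig) {
    // Custom endpoint: look up the model's rate in the fetched config.
    return endpointTokenConfig?.[model]?.[tokenType] ?? defaultRate;
  }
  if (valueKey && tokenType) {
    return tokenValues[valueKey][tokenType] ?? defaultRate;
  }
  return defaultRate;
}

const endpointTokenConfig = { 'vendor/model-a': { prompt: 15, completion: 75, context: 32768 } };
console.log(getMultiplier({ tokenType: 'prompt', model: 'vendor/model-a', endpointTokenConfig })); // 15
console.log(getMultiplier({ tokenType: 'completion', model: 'not-listed', endpointTokenConfig })); // 6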
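
The pricing arithmetic in processModelData assumes OpenRouter-style /models entries, where pricing.prompt and pricing.completion are USD-per-token strings; multiplying by 1,000,000 converts them to per-million-token rates. A worked example with a made-up entry, assuming the project's ~/utils module alias is configured:

const { processModelData } = require('~/utils');

// Hypothetical /models entry in the shape modelSchema validates.
const input = {
  data: [
    {
      id: 'vendor/model-a',
      pricing: { prompt: '0.000015', completion: '0.000075' }, // USD per token
      context_length: 32768,
    },
  ],
};

// 0.000015 * 1000000 = 15 and 0.000075 * 1000000 = 75 (up to float rounding).
console.log(processModelData(input));
// => { 'vendor/model-a': { prompt: 15, completion: 75, context: 32768 } }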
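
Similarly, when an endpointTokenConfig is passed, getModelMaxTokens returns each entry's context field instead of consulting the static maxTokensMap, and partial key matches resolve the same way. A usage sketch with the same hypothetical config:

const { getModelMaxTokens } = require('~/utils');

const endpointTokenConfig = {
  'vendor/model-a': { prompt: 15, completion: 75, context: 32768 },
};

// Exact key: returns the config's context value.
getModelMaxTokens('vendor/model-a', 'myEndpoint', endpointTokenConfig); // 32768

// Partial match: 'vendor/model-a-0613' includes 'vendor/model-a'.
getModelMaxTokens('vendor/model-a-0613', 'myEndpoint', endpointTokenConfig); // 32768

// No config passed: falls back to the static map for the endpoint.
getModelMaxTokens('gpt-4-32k', 'openAI'); // 32767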
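
Tying the pieces together: fetchModels writes the processed config to the TOKEN_CONFIG cache under the endpoint's name as a side effect, and the custom-endpoint initializer reads it back, re-fetching once on a miss; entries expire after the 30-minute TTL set in getLogStores. A condensed sketch of that read-through pattern (the helper name is hypothetical):

const { CacheKeys } = require('librechat-data-provider');
const { fetchModels } = require('~/server/services/ModelService');
const getLogStores = require('~/cache/getLogStores');

// Hypothetical helper mirroring the initializer's cache logic.
async function getTokenConfig(endpoint, apiKey, baseURL) {
  const cache = getLogStores(CacheKeys.TOKEN_CONFIG);
  let endpointTokenConfig = await cache.get(endpoint);
  if (!endpointTokenConfig) {
    // fetchModels caches the processed config under `name` when validation succeeds.
    await fetchModels({ apiKey, baseURL, name: endpoint });
    endpointTokenConfig = await cache.get(endpoint);
  }
  return endpointTokenConfig; // may remain undefined if the response failed validation
}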