Mirror of https://github.com/danny-avila/LibreChat.git

🪙 feat: Use OpenRouter Model Data for Token Cost and Context (#1703)

* feat: use openrouter data for model token cost/context
* chore: add ttl for tokenConfig and refetch models if cache expired

Commit 30e143e96d (parent f1d974c513). 14 changed files with 146 additions and 16 deletions.
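At a glance: fetchModels now parses the response from a custom endpoint's /models route, derives each model's prompt rate, completion rate, and context length, and caches the result for 30 minutes under the endpoint's name; balance checks, transaction pricing, and context-window lookups all consult that cached config. A minimal sketch of the flow, with an illustrative endpoint name, model id, and OpenRouter base URL (the '~/utils' export path for getModelMaxTokens is assumed from the imports shown below):

// Sketch only; mirrors the call sites added in this commit.
const { CacheKeys } = require('librechat-data-provider');
const { fetchModels } = require('~/server/services/ModelService');
const getLogStores = require('~/cache/getLogStores');
const { getModelMaxTokens } = require('~/utils'); // assumed export path

async function demo() {
  // Side effect: validates the /models payload and caches an EndpointTokenConfig under `name`
  await fetchModels({
    apiKey: process.env.OPENROUTER_KEY,
    baseURL: 'https://openrouter.ai/api/v1',
    name: 'OpenRouter',
  });
  const cache = getLogStores(CacheKeys.TOKEN_CONFIG);
  const endpointTokenConfig = await cache.get('OpenRouter');
  // Context window comes from the fetched data instead of the static maxTokensMap
  return getModelMaxTokens('mistralai/mixtral-8x7b-instruct', 'OpenRouter', endpointTokenConfig);
}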
@@ -428,7 +428,10 @@ class BaseClient {
       await this.saveMessageToDatabase(userMessage, saveOptions, user);
     }
 
-    if (isEnabled(process.env.CHECK_BALANCE) && supportsBalanceCheck[this.options.endpoint]) {
+    if (
+      isEnabled(process.env.CHECK_BALANCE) &&
+      supportsBalanceCheck[this.options.endpointType ?? this.options.endpoint]
+    ) {
       await checkBalance({
         req: this.options.req,
         res: this.options.res,

@@ -438,6 +441,7 @@ class BaseClient {
           amount: promptTokens,
           model: this.modelOptions.model,
           endpoint: this.options.endpoint,
+          endpointTokenConfig: this.options.endpointTokenConfig,
         },
       });
     }
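The balance check now resolves the endpoint key as endpointType ?? endpoint. A plausible reading, consistent with the custom-endpoint changes below: user-defined endpoints carry an arbitrary config name in endpoint, so matching supportsBalanceCheck on their endpointType is what lets them qualify, while built-in endpoints (which have no endpointType) behave as before. The new endpointTokenConfig field carries the fetched per-model rates into checkBalance.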
@@ -131,8 +131,13 @@ class OpenAIClient extends BaseClient {
     const { isChatGptModel } = this;
     this.isUnofficialChatGptModel =
       model.startsWith('text-chat') || model.startsWith('text-davinci-002-render');
 
     this.maxContextTokens =
-      getModelMaxTokens(model, this.options.endpointType ?? this.options.endpoint) ?? 4095; // 1 less than maximum
+      getModelMaxTokens(
+        model,
+        this.options.endpointType ?? this.options.endpoint,
+        this.options.endpointTokenConfig,
+      ) ?? 4095; // 1 less than maximum
+
     if (this.shouldSummarize) {
       this.maxContextTokens = Math.floor(this.maxContextTokens / 2);

@@ -780,7 +785,12 @@ ${convo}
     // TODO: remove the gpt fallback and make it specific to endpoint
     const { OPENAI_SUMMARY_MODEL = 'gpt-3.5-turbo' } = process.env ?? {};
     const model = this.options.summaryModel ?? OPENAI_SUMMARY_MODEL;
-    const maxContextTokens = getModelMaxTokens(model) ?? 4095;
+    const maxContextTokens =
+      getModelMaxTokens(
+        model,
+        this.options.endpointType ?? this.options.endpoint,
+        this.options.endpointTokenConfig,
+      ) ?? 4095; // 1 less than maximum
 
     // 3 tokens for the assistant label, and 98 for the summarizer prompt (101)
     let promptBuffer = 101;

@@ -886,6 +896,7 @@ ${convo}
         model: this.modelOptions.model,
         context: 'message',
         conversationId: this.conversationId,
+        endpointTokenConfig: this.options.endpointTokenConfig,
       },
       { promptTokens, completionTokens },
     );
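With the third argument threaded through, both the main context-window calculation and the summarizer's token budget respect the context length reported by the endpoint rather than the hardcoded map, and what appears to be the usage-recording call at the end passes the same config along so completions are priced consistently.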
api/cache/getLogStores.js (vendored, 5 changes)
@@ -23,6 +23,10 @@ const config = isEnabled(USE_REDIS)
   ? new Keyv({ store: keyvRedis })
   : new Keyv({ namespace: CacheKeys.CONFIG_STORE });
 
+const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes
+  ? new Keyv({ store: keyvRedis, ttl: 1800000 })
+  : new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: 1800000 });
+
 const namespaces = {
   [CacheKeys.CONFIG_STORE]: config,
   pending_req,

@@ -34,6 +38,7 @@ const namespaces = {
   token_balance: createViolationInstance('token_balance'),
   registrations: createViolationInstance('registrations'),
   logins: createViolationInstance('logins'),
+  [CacheKeys.TOKEN_CONFIG]: tokenConfig,
 };
 
 /**
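A TTL of 1800000 ms is 30 minutes: entries in this store silently expire, which is what later forces a refetch of the model list once pricing data goes stale. The Keyv semantics relied on, in brief (illustrative key and values):

// Inside an async context:
const Keyv = require('keyv');
const store = new Keyv({ ttl: 1800000 }); // values expire 30 minutes after set
await store.set('OpenRouter', { 'some/model': { prompt: 3, completion: 6, context: 32768 } });
// ...31 minutes later:
await store.get('OpenRouter'); // undefined, so the caller must refetch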
@@ -10,8 +10,9 @@ balanceSchema.statics.check = async function ({
   valueKey,
   tokenType,
   amount,
+  endpointTokenConfig,
 }) {
-  const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint });
+  const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint, endpointTokenConfig });
   const tokenCost = amount * multiplier;
   const { tokenCredits: balance } = (await this.findOne({ user }, 'tokenCredits').lean()) ?? {};
 

@@ -24,6 +25,7 @@ balanceSchema.statics.check = async function ({
     amount,
     balance,
     multiplier,
+    endpointTokenConfig: !!endpointTokenConfig,
   });
 
   if (!balance) {
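The !!endpointTokenConfig coercion records only whether per-endpoint pricing was in effect, keeping the full rate table out of what appears to be a debug log of the balance-check inputs.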
@@ -10,8 +10,8 @@ transactionSchema.methods.calculateTokenValue = function () {
   if (!this.valueKey || !this.tokenType) {
     this.tokenValue = this.rawAmount;
   }
-  const { valueKey, tokenType, model } = this;
-  const multiplier = getMultiplier({ valueKey, tokenType, model });
+  const { valueKey, tokenType, model, endpointTokenConfig } = this;
+  const multiplier = getMultiplier({ valueKey, tokenType, model, endpointTokenConfig });
   this.rate = multiplier;
   this.tokenValue = this.rawAmount * multiplier;
   if (this.context && this.tokenType === 'completion' && this.context === 'incomplete') {

@@ -25,6 +25,7 @@ transactionSchema.statics.create = async function (transactionData) {
   const Transaction = this;
 
   const transaction = new Transaction(transactionData);
+  transaction.endpointTokenConfig = transactionData.endpointTokenConfig;
   transaction.calculateTokenValue();
 
   // Save the transaction
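endpointTokenConfig is assigned onto the instance after construction, which suggests it is not a declared schema path: it rides along in memory so calculateTokenValue can apply the per-endpoint rates, without being persisted with the transaction.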
@@ -14,6 +14,7 @@ const { logViolation } = require('../cache');
  * @param {('prompt' | 'completion')} params.txData.tokenType - The type of token.
  * @param {number} params.txData.amount - The amount of tokens.
  * @param {string} params.txData.model - The model name or identifier.
+ * @param {string} [params.txData.endpointTokenConfig] - The token configuration for the endpoint.
  * @returns {Promise<boolean>} Returns true if the user can spend the amount, otherwise denies the request.
  * @throws {Error} Throws an error if there's an issue with the balance check.
  */
@@ -11,6 +11,7 @@ const { logger } = require('~/config');
  * @param {String} txData.conversationId - The ID of the conversation.
  * @param {String} txData.model - The model name.
  * @param {String} txData.context - The context in which the transaction is made.
+ * @param {String} [txData.endpointTokenConfig] - The current endpoint token config.
  * @param {String} [txData.valueKey] - The value key (optional).
  * @param {Object} tokenUsage - The number of tokens used.
  * @param {Number} tokenUsage.promptTokens - The number of prompt tokens used.
@@ -57,9 +57,14 @@ const getValueKey = (model, endpoint) => {
  * @param {string} [params.tokenType] - The type of token (e.g., 'prompt' or 'completion').
  * @param {string} [params.model] - The model name to derive the value key from if not provided.
  * @param {string} [params.endpoint] - The endpoint name to derive the value key from if not provided.
+ * @param {EndpointTokenConfig} [params.endpointTokenConfig] - The token configuration for the endpoint.
  * @returns {number} The multiplier for the given parameters, or a default value if not found.
  */
-const getMultiplier = ({ valueKey, tokenType, model, endpoint }) => {
+const getMultiplier = ({ valueKey, tokenType, model, endpoint, endpointTokenConfig }) => {
+  if (endpointTokenConfig) {
+    return endpointTokenConfig?.[model]?.[tokenType] ?? defaultRate;
+  }
+
   if (valueKey && tokenType) {
     return tokenValues[valueKey][tokenType] ?? defaultRate;
   }
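When a token config is present it short-circuits the static tokenValues table entirely. A usage sketch against the code above (the model name and rate are illustrative, not from the diff):

const endpointTokenConfig = {
  'mistralai/mixtral-8x7b-instruct': { prompt: 0.24, completion: 0.24, context: 32768 },
};

// Takes the rate from the endpoint config:
getMultiplier({ model: 'mistralai/mixtral-8x7b-instruct', tokenType: 'prompt', endpointTokenConfig }); // 0.24
// An unlisted model under a config falls back to defaultRate rather than tokenValues:
getMultiplier({ model: 'unlisted-model', tokenType: 'prompt', endpointTokenConfig }); // defaultRate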
@@ -49,7 +49,7 @@ async function loadConfigModels() {
 
     if (models.fetch && !isUserProvided(API_KEY) && !isUserProvided(BASE_URL)) {
       fetchPromisesMap[BASE_URL] =
-        fetchPromisesMap[BASE_URL] || fetchModels({ baseURL: BASE_URL, apiKey: API_KEY });
+        fetchPromisesMap[BASE_URL] || fetchModels({ baseURL: BASE_URL, apiKey: API_KEY, name });
       baseUrlToNameMap[BASE_URL] = baseUrlToNameMap[BASE_URL] || [];
       baseUrlToNameMap[BASE_URL].push(name);
       continue;
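Passing name through matters because fetchModels (below) uses it as the token-config cache key, so each custom endpoint's rates land under its own config name rather than the shared 'OpenAI' default.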
@@ -1,7 +1,9 @@
-const { EModelEndpoint } = require('librechat-data-provider');
+const { EModelEndpoint, CacheKeys } = require('librechat-data-provider');
 const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
 const { isUserProvided, extractEnvVariable } = require('~/server/utils');
+const { fetchModels } = require('~/server/services/ModelService');
 const getCustomConfig = require('~/cache/getCustomConfig');
+const getLogStores = require('~/cache/getLogStores');
 const { OpenAIClient } = require('~/app');
 
 const envVarRegex = /^\${(.+)}$/;

@@ -37,6 +39,13 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     throw new Error(`Missing Base URL for ${endpoint}.`);
   }
 
+  const cache = getLogStores(CacheKeys.TOKEN_CONFIG);
+  let endpointTokenConfig = await cache.get(endpoint);
+  if (!endpointTokenConfig) {
+    await fetchModels({ apiKey: CUSTOM_API_KEY, baseURL: CUSTOM_BASE_URL, name: endpoint });
+    endpointTokenConfig = await cache.get(endpoint);
+  }
+
   const customOptions = {
     headers: resolvedHeaders,
     addParams: endpointConfig.addParams,

@@ -48,6 +57,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     modelDisplayLabel: endpointConfig.modelDisplayLabel,
     titleMethod: endpointConfig.titleMethod ?? 'completion',
     contextStrategy: endpointConfig.summarize ? 'summarize' : null,
+    endpointTokenConfig,
   };
 
   const useUserKey = isUserProvided(CUSTOM_API_KEY);
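This is the refetch half of the cache change: when the 30-minute TTL evicts the entry, cache.get returns undefined, initializeClient re-runs fetchModels (which rewrites the config as a side effect), and the second get picks it up. If the fetch fails, fetchModels logs the error and endpointTokenConfig stays undefined, so pricing falls back to the static tokenValues lookup in getMultiplier.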
@@ -1,10 +1,11 @@
 const Keyv = require('keyv');
 const axios = require('axios');
 const HttpsProxyAgent = require('https-proxy-agent');
-const { EModelEndpoint, defaultModels } = require('librechat-data-provider');
+const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider');
+const { extractBaseURL, inputSchema, processModelData } = require('~/utils');
+const getLogStores = require('~/cache/getLogStores');
 const { isEnabled } = require('~/server/utils');
 const keyvRedis = require('~/cache/keyvRedis');
-const { extractBaseURL } = require('~/utils');
 const { logger } = require('~/config');
 
 // const { getAzureCredentials, genAzureChatCompletion } = require('~/utils/');

@@ -32,10 +33,17 @@ const {
  * @param {string} params.baseURL - The base path URL for the API.
  * @param {string} [params.name='OpenAI'] - The name of the API; defaults to 'OpenAI'.
  * @param {boolean} [params.azure=false] - Whether to fetch models from Azure.
+ * @param {boolean} [params.createTokenConfig=true] - Whether to create a token configuration from the API response.
  * @returns {Promise<string[]>} A promise that resolves to an array of model identifiers.
  * @async
  */
-const fetchModels = async ({ apiKey, baseURL, name = 'OpenAI', azure = false }) => {
+const fetchModels = async ({
+  apiKey,
+  baseURL,
+  name = 'OpenAI',
+  azure = false,
+  createTokenConfig = true,
+}) => {
   let models = [];
 
   if (!baseURL && !azure) {

@@ -58,7 +66,16 @@ const fetchModels = async ({ apiKey, baseURL, name = 'OpenAI', azure = false })
     }
 
     const res = await axios.get(`${baseURL}${azure ? '' : '/models'}`, payload);
-    models = res.data.data.map((item) => item.id);
+    /** @type {z.infer<typeof inputSchema>} */
+    const input = res.data;
+
+    const validationResult = inputSchema.safeParse(input);
+    if (validationResult.success && createTokenConfig) {
+      const endpointTokenConfig = processModelData(input);
+      const cache = getLogStores(CacheKeys.TOKEN_CONFIG);
+      await cache.set(name, endpointTokenConfig);
+    }
+    models = input.data.map((item) => item.id);
   } catch (err) {
     logger.error(`Failed to fetch models from ${azure ? 'Azure ' : ''}${name} API`, err);
   }
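Note that inputSchema (defined in the tokens utility further down) validates only the fields this feature needs, so any OpenAI-style /models response that also carries pricing and context_length, as OpenRouter's does, will yield a token config; responses without them simply skip the cache write. The expected shape, with invented values:

// Shape fetchModels expects from GET `${baseURL}/models`:
const sampleResponse = {
  data: [
    {
      id: 'mistralai/mixtral-8x7b-instruct',
      pricing: { prompt: '0.00000024', completion: '0.00000024' }, // USD per token, as strings
      context_length: 32768,
    },
  ],
};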
@@ -385,3 +385,18 @@
  * @property {string} [azureOpenAIApiVersion] - The Azure OpenAI API version.
  * @memberof typedefs
  */
+
+/**
+ * @typedef {Object} TokenConfig
+ * A configuration object mapping model keys to their respective prompt, completion rates, and context limit.
+ * @property {number} prompt - The prompt rate
+ * @property {number} completion - The completion rate
+ * @property {number} context - The maximum context length supported by the model.
+ * @memberof typedefs
+ */
+
+/**
+ * @typedef {Record<string, TokenConfig>} EndpointTokenConfig
+ * An endpoint's config object mapping model keys to their respective prompt, completion rates, and context limit.
+ * @memberof typedefs
+ */
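Concretely, a value satisfying EndpointTokenConfig maps model ids to TokenConfig objects (ids and numbers illustrative):

/** @type {EndpointTokenConfig} */
const example = {
  'openai/gpt-3.5-turbo': { prompt: 1, completion: 2, context: 16385 },
  'mistralai/mixtral-8x7b-instruct': { prompt: 0.24, completion: 0.24, context: 32768 },
};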
@@ -1,3 +1,4 @@
+const z = require('zod');
 const { EModelEndpoint } = require('librechat-data-provider');
 
 const models = [

@@ -91,6 +92,7 @@ const maxTokensMap = {
  *
  * @param {string} modelName - The name of the model to look up.
  * @param {string} endpoint - The endpoint (default is 'openAI').
+ * @param {EndpointTokenConfig} [endpointTokenConfig] - Token Config for current endpoint to use for max tokens lookup
  * @returns {number|undefined} The maximum tokens for the given model or undefined if no match is found.
  *
  * @example

@@ -98,16 +100,21 @@ const maxTokensMap = {
  * getModelMaxTokens('gpt-4-32k-unknown'); // Returns 32767
  * getModelMaxTokens('unknown-model'); // Returns undefined
  */
-function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI) {
+function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI, endpointTokenConfig) {
   if (typeof modelName !== 'string') {
     return undefined;
   }
 
-  const tokensMap = maxTokensMap[endpoint];
+  /** @type {EndpointTokenConfig | Record<string, number>} */
+  const tokensMap = endpointTokenConfig ?? maxTokensMap[endpoint];
   if (!tokensMap) {
     return undefined;
   }
 
+  if (tokensMap[modelName]?.context) {
+    return tokensMap[modelName].context;
+  }
+
   if (tokensMap[modelName]) {
     return tokensMap[modelName];
   }

@@ -115,7 +122,8 @@ function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI) {
   const keys = Object.keys(tokensMap);
   for (let i = keys.length - 1; i >= 0; i--) {
     if (modelName.includes(keys[i])) {
-      return tokensMap[keys[i]];
+      const result = tokensMap[keys[i]];
+      return result?.context ?? result;
     }
   }
 

@@ -160,9 +168,55 @@ function matchModelName(modelName, endpoint = EModelEndpoint.openAI) {
   return modelName;
 }
 
+const modelSchema = z.object({
+  id: z.string(),
+  pricing: z.object({
+    prompt: z.string(),
+    completion: z.string(),
+  }),
+  context_length: z.number(),
+});
+
+const inputSchema = z.object({
+  data: z.array(modelSchema),
+});
+
+/**
+ * Processes a list of model data from an API and organizes it into structured data based on URL and specifics of rates and context.
+ * @param {{ data: Array<z.infer<typeof modelSchema>> }} input The input object containing base URL and data fetched from the API.
+ * @returns {EndpointTokenConfig} The processed model data.
+ */
+function processModelData(input) {
+  const validationResult = inputSchema.safeParse(input);
+  if (!validationResult.success) {
+    throw new Error('Invalid input data');
+  }
+  const { data } = validationResult.data;
+
+  /** @type {EndpointTokenConfig} */
+  const tokenConfig = {};
+
+  for (const model of data) {
+    const modelKey = model.id;
+    const prompt = parseFloat(model.pricing.prompt) * 1000000;
+    const completion = parseFloat(model.pricing.completion) * 1000000;
+
+    tokenConfig[modelKey] = {
+      prompt,
+      completion,
+      context: model.context_length,
+    };
+  }
+
+  return tokenConfig;
+}
+
 module.exports = {
   tiktokenModels: new Set(models),
   maxTokensMap,
+  inputSchema,
+  modelSchema,
   getModelMaxTokens,
   matchModelName,
+  processModelData,
 };
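The multiplication by 1,000,000 converts the per-token USD price strings into a per-1M-token rate, matching the scale the existing tokenValues table already uses: a pricing.prompt of '0.00000024' becomes 0.24. A quick worked example against the code above (values invented):

const config = processModelData({
  data: [
    {
      id: 'mistralai/mixtral-8x7b-instruct',
      pricing: { prompt: '0.00000024', completion: '0.00000024' },
      context_length: 32768,
    },
  ],
});
// config['mistralai/mixtral-8x7b-instruct'] => { prompt: 0.24, completion: 0.24, context: 32768 }

// getModelMaxTokens prefers the config's `context` over maxTokensMap:
getModelMaxTokens('mistralai/mixtral-8x7b-instruct', 'SomeEndpoint', config); // 32768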
@@ -168,6 +168,10 @@ export enum CacheKeys {
    * Key for the default endpoint config cache.
    */
   ENDPOINT_CONFIG = 'endpointsConfig',
+  /**
+   * Key for accessing the model token config cache.
+   */
+  TOKEN_CONFIG = 'tokenConfig',
   /**
    * Key for the custom config cache.
    */