🪙 feat: Use OpenRouter Model Data for Token Cost and Context (#1703)

* feat: use openrouter data for model token cost/context

* chore: add ttl for tokenConfig and refetch models if cache expired
Danny Avila 2024-02-02 00:42:11 -05:00 committed by GitHub
parent f1d974c513
commit 30e143e96d
14 changed files with 146 additions and 16 deletions
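
For orientation: the whole change keys off the shape of OpenRouter's `/models` response, where each entry carries per-token pricing (as strings) and a context length. A minimal illustrative entry, in the shape the new `modelSchema` below validates (the model id and prices here are hypothetical, not real OpenRouter figures):

```js
// Hypothetical /models payload; only the fields this commit uses are shown.
const sampleModelsResponse = {
  data: [
    {
      id: 'some-provider/some-model',
      pricing: { prompt: '0.0000003', completion: '0.0000006' }, // USD per token, as strings
      context_length: 32768,
    },
  ],
};
```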


@@ -428,7 +428,10 @@ class BaseClient {
await this.saveMessageToDatabase(userMessage, saveOptions, user);
}
if (isEnabled(process.env.CHECK_BALANCE) && supportsBalanceCheck[this.options.endpoint]) {
if (
isEnabled(process.env.CHECK_BALANCE) &&
supportsBalanceCheck[this.options.endpointType ?? this.options.endpoint]
) {
await checkBalance({
req: this.options.req,
res: this.options.res,
@@ -438,6 +441,7 @@ class BaseClient {
amount: promptTokens,
model: this.modelOptions.model,
endpoint: this.options.endpoint,
endpointTokenConfig: this.options.endpointTokenConfig,
},
});
}
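
A sketch of why the `endpointType ?? endpoint` fallback matters (the endpoint names and the exact shape of `supportsBalanceCheck` are assumptions here): a user-defined endpoint reports its own `endpoint` name but an `endpointType` of `custom`, which is the key the balance-check map actually knows about.

```js
// Hypothetical options for a user-defined endpoint backed by OpenRouter.
const options = { endpoint: 'OpenRouter', endpointType: 'custom' };
const supportsBalanceCheck = { openAI: true, azureOpenAI: true, custom: true };

// Keying off `endpoint` alone would skip custom endpoints entirely;
// the fallback resolves to 'custom' and the balance check runs.
const key = options.endpointType ?? options.endpoint;
console.log(Boolean(supportsBalanceCheck[key])); // true
```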


@@ -131,8 +131,13 @@ class OpenAIClient extends BaseClient {
const { isChatGptModel } = this;
this.isUnofficialChatGptModel =
model.startsWith('text-chat') || model.startsWith('text-davinci-002-render');
this.maxContextTokens =
getModelMaxTokens(model, this.options.endpointType ?? this.options.endpoint) ?? 4095; // 1 less than maximum
getModelMaxTokens(
model,
this.options.endpointType ?? this.options.endpoint,
this.options.endpointTokenConfig,
) ?? 4095; // 1 less than maximum
if (this.shouldSummarize) {
this.maxContextTokens = Math.floor(this.maxContextTokens / 2);
@@ -780,7 +785,12 @@ ${convo}
// TODO: remove the gpt fallback and make it specific to endpoint
const { OPENAI_SUMMARY_MODEL = 'gpt-3.5-turbo' } = process.env ?? {};
const model = this.options.summaryModel ?? OPENAI_SUMMARY_MODEL;
const maxContextTokens = getModelMaxTokens(model) ?? 4095;
const maxContextTokens =
getModelMaxTokens(
model,
this.options.endpointType ?? this.options.endpoint,
this.options.endpointTokenConfig,
) ?? 4095; // 1 less than maximum
// 3 tokens for the assistant label, and 98 for the summarizer prompt (101)
let promptBuffer = 101;
@@ -886,6 +896,7 @@ ${convo}
model: this.modelOptions.model,
context: 'message',
conversationId: this.conversationId,
endpointTokenConfig: this.options.endpointTokenConfig,
},
{ promptTokens, completionTokens },
);
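
A rough sketch of how the context window now resolves (the config values are hypothetical; `getModelMaxTokens` is the helper from `tokens.js` further down): a fetched `endpointTokenConfig` takes precedence via its `context` field, and the `?? 4095` default still applies when neither the config nor the static map knows the model.

```js
// Hypothetical token config built from the endpoint's /models response.
const endpointTokenConfig = {
  'some-provider/some-model': { prompt: 1, completion: 2, context: 32768 },
};

getModelMaxTokens('some-provider/some-model', 'custom', endpointTokenConfig); // 32768
getModelMaxTokens('unlisted-model', 'custom', endpointTokenConfig) ?? 4095;   // 4095
```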


@@ -23,6 +23,10 @@ const config = isEnabled(USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.CONFIG_STORE });
const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes
? new Keyv({ store: keyvRedis, ttl: 1800000 })
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: 1800000 });
const namespaces = {
[CacheKeys.CONFIG_STORE]: config,
pending_req,
@@ -34,6 +38,7 @@ const namespaces = {
token_balance: createViolationInstance('token_balance'),
registrations: createViolationInstance('registrations'),
logins: createViolationInstance('logins'),
[CacheKeys.TOKEN_CONFIG]: tokenConfig,
};
/**
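
The 1800000 ms TTL means a cached token config quietly expires after 30 minutes. A minimal standalone Keyv sketch of that behavior (namespace and payload are illustrative):

```js
const Keyv = require('keyv');

// Same 30-minute TTL as the tokenConfig store above.
const tokenConfigCache = new Keyv({ namespace: 'tokenConfig', ttl: 1800000 });

async function demo() {
  await tokenConfigCache.set('OpenRouter', {
    'some/model': { prompt: 1, completion: 2, context: 8192 },
  });
  console.log(await tokenConfigCache.get('OpenRouter')); // the config object, while fresh
  // Thirty minutes later the same get() resolves to undefined,
  // which is what triggers the refetch in initializeClient further down.
}
demo();
```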


@@ -10,8 +10,9 @@ balanceSchema.statics.check = async function ({
valueKey,
tokenType,
amount,
endpointTokenConfig,
}) {
const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint });
const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint, endpointTokenConfig });
const tokenCost = amount * multiplier;
const { tokenCredits: balance } = (await this.findOne({ user }, 'tokenCredits').lean()) ?? {};
@@ -24,6 +25,7 @@ balanceSchema.statics.check = async function ({
amount,
balance,
multiplier,
endpointTokenConfig: !!endpointTokenConfig,
});
if (!balance) {
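
A worked example of the cost math above, with hypothetical rates: if the endpoint's config prices a model's prompt tokens at a multiplier of 3, spending 1,500 prompt tokens costs 4,500 token credits.

```js
// Hypothetical per-model rates supplied by the endpoint token config.
const endpointTokenConfig = { 'some/model': { prompt: 3, completion: 6, context: 32768 } };

const multiplier = endpointTokenConfig['some/model'].prompt; // 3
const amount = 1500;                   // prompt tokens about to be spent
const tokenCost = amount * multiplier; // 4500 token credits

// check() then compares tokenCost against the user's tokenCredits balance.
```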


@@ -10,8 +10,8 @@ transactionSchema.methods.calculateTokenValue = function () {
if (!this.valueKey || !this.tokenType) {
this.tokenValue = this.rawAmount;
}
const { valueKey, tokenType, model } = this;
const multiplier = getMultiplier({ valueKey, tokenType, model });
const { valueKey, tokenType, model, endpointTokenConfig } = this;
const multiplier = getMultiplier({ valueKey, tokenType, model, endpointTokenConfig });
this.rate = multiplier;
this.tokenValue = this.rawAmount * multiplier;
if (this.context && this.tokenType === 'completion' && this.context === 'incomplete') {
@@ -25,6 +25,7 @@ transactionSchema.statics.create = async function (transactionData) {
const Transaction = this;
const transaction = new Transaction(transactionData);
transaction.endpointTokenConfig = transactionData.endpointTokenConfig;
transaction.calculateTokenValue();
// Save the transaction
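
Note that `endpointTokenConfig` is copied onto the transaction instance before `calculateTokenValue()` runs, presumably because it is transient request state rather than a persisted schema field. A sketch of that flow with a hypothetical payload, assuming the `Transaction` mongoose model from this file is in scope:

```js
// Field names mirror the diff; the values are made up.
const txData = {
  user: 'user-id',
  tokenType: 'prompt',
  rawAmount: -1500,
  model: 'some/model',
  endpointTokenConfig: { 'some/model': { prompt: 3, completion: 6, context: 32768 } },
};

const transaction = new Transaction(txData);
transaction.endpointTokenConfig = txData.endpointTokenConfig;
transaction.calculateTokenValue(); // with the logic shown above: rate = 3, tokenValue = -4500
```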


@@ -14,6 +14,7 @@ const { logViolation } = require('../cache');
* @param {('prompt' | 'completion')} params.txData.tokenType - The type of token.
* @param {number} params.txData.amount - The amount of tokens.
* @param {string} params.txData.model - The model name or identifier.
* @param {EndpointTokenConfig} [params.txData.endpointTokenConfig] - The token configuration for the endpoint.
* @returns {Promise<boolean>} Returns true if the user can spend the amount, otherwise denies the request.
* @throws {Error} Throws an error if there's an issue with the balance check.
*/


@@ -11,6 +11,7 @@ const { logger } = require('~/config');
* @param {String} txData.conversationId - The ID of the conversation.
* @param {String} txData.model - The model name.
* @param {String} txData.context - The context in which the transaction is made.
* @param {EndpointTokenConfig} [txData.endpointTokenConfig] - The current endpoint token config.
* @param {String} [txData.valueKey] - The value key (optional).
* @param {Object} tokenUsage - The number of tokens used.
* @param {Number} tokenUsage.promptTokens - The number of prompt tokens used.


@@ -57,9 +57,14 @@ const getValueKey = (model, endpoint) => {
* @param {string} [params.tokenType] - The type of token (e.g., 'prompt' or 'completion').
* @param {string} [params.model] - The model name to derive the value key from if not provided.
* @param {string} [params.endpoint] - The endpoint name to derive the value key from if not provided.
* @param {EndpointTokenConfig} [params.endpointTokenConfig] - The token configuration for the endpoint.
* @returns {number} The multiplier for the given parameters, or a default value if not found.
*/
const getMultiplier = ({ valueKey, tokenType, model, endpoint }) => {
const getMultiplier = ({ valueKey, tokenType, model, endpoint, endpointTokenConfig }) => {
if (endpointTokenConfig) {
return endpointTokenConfig?.[model]?.[tokenType] ?? defaultRate;
}
if (valueKey && tokenType) {
return tokenValues[valueKey][tokenType] ?? defaultRate;
}
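
To illustrate the new precedence (rates hypothetical; `getMultiplier` as defined above): a supplied `endpointTokenConfig` short-circuits the static lookup entirely, and any model missing from it falls back to `defaultRate` rather than throwing.

```js
const endpointTokenConfig = { 'some/model': { prompt: 1.5, completion: 4, context: 16384 } };

getMultiplier({ tokenType: 'completion', model: 'some/model', endpointTokenConfig }); // 4
getMultiplier({ tokenType: 'prompt', model: 'unlisted-model', endpointTokenConfig }); // defaultRate

// Without an endpointTokenConfig, the existing valueKey/tokenValues path is used as before.
```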


@@ -49,7 +49,7 @@ async function loadConfigModels() {
if (models.fetch && !isUserProvided(API_KEY) && !isUserProvided(BASE_URL)) {
fetchPromisesMap[BASE_URL] =
fetchPromisesMap[BASE_URL] || fetchModels({ baseURL: BASE_URL, apiKey: API_KEY });
fetchPromisesMap[BASE_URL] || fetchModels({ baseURL: BASE_URL, apiKey: API_KEY, name });
baseUrlToNameMap[BASE_URL] = baseUrlToNameMap[BASE_URL] || [];
baseUrlToNameMap[BASE_URL].push(name);
continue;


@@ -1,7 +1,9 @@
const { EModelEndpoint } = require('librechat-data-provider');
const { EModelEndpoint, CacheKeys } = require('librechat-data-provider');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const { isUserProvided, extractEnvVariable } = require('~/server/utils');
const { fetchModels } = require('~/server/services/ModelService');
const getCustomConfig = require('~/cache/getCustomConfig');
const getLogStores = require('~/cache/getLogStores');
const { OpenAIClient } = require('~/app');
const envVarRegex = /^\${(.+)}$/;
@@ -37,6 +39,13 @@ const initializeClient = async ({ req, res, endpointOption }) => {
throw new Error(`Missing Base URL for ${endpoint}.`);
}
const cache = getLogStores(CacheKeys.TOKEN_CONFIG);
let endpointTokenConfig = await cache.get(endpoint);
if (!endpointTokenConfig) {
await fetchModels({ apiKey: CUSTOM_API_KEY, baseURL: CUSTOM_BASE_URL, name: endpoint });
endpointTokenConfig = await cache.get(endpoint);
}
const customOptions = {
headers: resolvedHeaders,
addParams: endpointConfig.addParams,
@@ -48,6 +57,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
modelDisplayLabel: endpointConfig.modelDisplayLabel,
titleMethod: endpointConfig.titleMethod ?? 'completion',
contextStrategy: endpointConfig.summarize ? 'summarize' : null,
endpointTokenConfig,
};
const useUserKey = isUserProvided(CUSTOM_API_KEY);
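
The fetch-on-miss pattern works because `fetchModels` (changed below) writes the processed token config into the TOKEN_CONFIG store under the endpoint's `name` as a side effect, so the second `cache.get(endpoint)` picks it up. A condensed sketch using a hypothetical helper name:

```js
// Hypothetical helper illustrating the pattern; not part of the commit.
async function getEndpointTokenConfig(endpoint, { apiKey, baseURL }) {
  const cache = getLogStores(CacheKeys.TOKEN_CONFIG);
  let endpointTokenConfig = await cache.get(endpoint);
  if (!endpointTokenConfig) {
    // fetchModels caches the config under `name` when the response validates.
    await fetchModels({ apiKey, baseURL, name: endpoint });
    endpointTokenConfig = await cache.get(endpoint);
  }
  return endpointTokenConfig; // may still be undefined if the API exposes no pricing data
}
```

Combined with the 30-minute TTL, pricing is refreshed lazily the next time a client is initialized after the cache entry expires.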


@@ -1,10 +1,11 @@
const Keyv = require('keyv');
const axios = require('axios');
const HttpsProxyAgent = require('https-proxy-agent');
const { EModelEndpoint, defaultModels } = require('librechat-data-provider');
const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider');
const { extractBaseURL, inputSchema, processModelData } = require('~/utils');
const getLogStores = require('~/cache/getLogStores');
const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { extractBaseURL } = require('~/utils');
const { logger } = require('~/config');
// const { getAzureCredentials, genAzureChatCompletion } = require('~/utils/');
@@ -32,10 +33,17 @@ const {
* @param {string} params.baseURL - The base path URL for the API.
* @param {string} [params.name='OpenAI'] - The name of the API; defaults to 'OpenAI'.
* @param {boolean} [params.azure=false] - Whether to fetch models from Azure.
* @param {boolean} [params.createTokenConfig=true] - Whether to create a token configuration from the API response.
* @returns {Promise<string[]>} A promise that resolves to an array of model identifiers.
* @async
*/
const fetchModels = async ({ apiKey, baseURL, name = 'OpenAI', azure = false }) => {
const fetchModels = async ({
apiKey,
baseURL,
name = 'OpenAI',
azure = false,
createTokenConfig = true,
}) => {
let models = [];
if (!baseURL && !azure) {
@@ -58,7 +66,16 @@ const fetchModels = async ({ apiKey, baseURL, name = 'OpenAI', azure = false })
}
const res = await axios.get(`${baseURL}${azure ? '' : '/models'}`, payload);
models = res.data.data.map((item) => item.id);
/** @type {z.infer<typeof inputSchema>} */
const input = res.data;
const validationResult = inputSchema.safeParse(input);
if (validationResult.success && createTokenConfig) {
const endpointTokenConfig = processModelData(input);
const cache = getLogStores(CacheKeys.TOKEN_CONFIG);
await cache.set(name, endpointTokenConfig);
}
models = input.data.map((item) => item.id);
} catch (err) {
logger.error(`Failed to fetch models from ${azure ? 'Azure ' : ''}${name} API`, err);
}
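
One consequence of the `safeParse` gate above: if a custom endpoint's `/models` response has no pricing or context_length fields (i.e., it is not OpenRouter-shaped), validation fails quietly, nothing is cached, and cost and context fall back to the existing static tables. A sketch with hypothetical payloads, assuming `inputSchema` from `tokens.js` is in scope:

```js
// OpenRouter-style payload: validates, so a token config is cached under `name`.
inputSchema.safeParse({
  data: [
    { id: 'some/model', pricing: { prompt: '0.000001', completion: '0.000002' }, context_length: 8192 },
  ],
}).success; // true

// Plain OpenAI-style /models payload: no pricing or context_length, so nothing is cached.
inputSchema.safeParse({
  data: [{ id: 'gpt-4', object: 'model', owned_by: 'openai' }],
}).success; // false
```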


@@ -385,3 +385,18 @@
* @property {string} [azureOpenAIApiVersion] - The Azure OpenAI API version.
* @memberof typedefs
*/
/**
* @typedef {Object} TokenConfig
* A per-model configuration object describing its prompt rate, completion rate, and context limit.
* @property {number} prompt - The prompt rate
* @property {number} completion - The completion rate
* @property {number} context - The maximum context length supported by the model.
* @memberof typedefs
*/
/**
* @typedef {Record<string, TokenConfig>} EndpointTokenConfig
* An endpoint's config object mapping model keys to their respective prompt, completion rates, and context limit.
* @memberof typedefs
*/
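
Concretely, an `EndpointTokenConfig` for a two-model endpoint would look like this (values hypothetical):

```js
/** @type {EndpointTokenConfig} */
const openRouterTokenConfig = {
  'some-provider/small-model': { prompt: 0.25, completion: 0.5, context: 8192 },
  'some-provider/large-model': { prompt: 3, completion: 15, context: 128000 },
};
```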


@@ -1,3 +1,4 @@
const z = require('zod');
const { EModelEndpoint } = require('librechat-data-provider');
const models = [
@@ -91,6 +92,7 @@ const maxTokensMap = {
*
* @param {string} modelName - The name of the model to look up.
* @param {string} endpoint - The endpoint (default is 'openAI').
* @param {EndpointTokenConfig} [endpointTokenConfig] - Token Config for current endpoint to use for max tokens lookup
* @returns {number|undefined} The maximum tokens for the given model or undefined if no match is found.
*
* @example
@@ -98,16 +100,21 @@ const maxTokensMap = {
* getModelMaxTokens('gpt-4-32k-unknown'); // Returns 32767
* getModelMaxTokens('unknown-model'); // Returns undefined
*/
function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI) {
function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI, endpointTokenConfig) {
if (typeof modelName !== 'string') {
return undefined;
}
const tokensMap = maxTokensMap[endpoint];
/** @type {EndpointTokenConfig | Record<string, number>} */
const tokensMap = endpointTokenConfig ?? maxTokensMap[endpoint];
if (!tokensMap) {
return undefined;
}
if (tokensMap[modelName]?.context) {
return tokensMap[modelName].context;
}
if (tokensMap[modelName]) {
return tokensMap[modelName];
}
@@ -115,7 +122,8 @@ function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI) {
const keys = Object.keys(tokensMap);
for (let i = keys.length - 1; i >= 0; i--) {
if (modelName.includes(keys[i])) {
return tokensMap[keys[i]];
const result = tokensMap[keys[i]];
return result?.context ?? result;
}
}
@@ -160,9 +168,55 @@ function matchModelName(modelName, endpoint = EModelEndpoint.openAI) {
return modelName;
}
const modelSchema = z.object({
id: z.string(),
pricing: z.object({
prompt: z.string(),
completion: z.string(),
}),
context_length: z.number(),
});
const inputSchema = z.object({
data: z.array(modelSchema),
});
/**
* Processes the model list returned by an API into a token config keyed by model id, with per-model rates and context length.
* @param {{ data: Array<z.infer<typeof modelSchema>> }} input The model data fetched from the API.
* @returns {EndpointTokenConfig} The processed model data.
*/
function processModelData(input) {
const validationResult = inputSchema.safeParse(input);
if (!validationResult.success) {
throw new Error('Invalid input data');
}
const { data } = validationResult.data;
/** @type {EndpointTokenConfig} */
const tokenConfig = {};
for (const model of data) {
const modelKey = model.id;
const prompt = parseFloat(model.pricing.prompt) * 1000000;
const completion = parseFloat(model.pricing.completion) * 1000000;
tokenConfig[modelKey] = {
prompt,
completion,
context: model.context_length,
};
}
return tokenConfig;
}
module.exports = {
tiktokenModels: new Set(models),
maxTokensMap,
inputSchema,
modelSchema,
getModelMaxTokens,
matchModelName,
processModelData,
};
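
A worked example of `processModelData` (input values hypothetical): the pricing strings are parsed as USD per token and scaled by 1,000,000, which puts the stored prompt/completion rates on the same per-million-token scale as the static `tokenValues` map, while `context_length` is carried through unchanged.

```js
const input = {
  data: [
    {
      id: 'some-provider/some-model',
      pricing: { prompt: '0.000003', completion: '0.000015' }, // USD per token
      context_length: 128000,
    },
  ],
};

processModelData(input);
// => { 'some-provider/some-model': { prompt: 3, completion: 15, context: 128000 } }
//    (give or take floating-point rounding on the multiplications)
```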


@@ -168,6 +168,10 @@ export enum CacheKeys {
* Key for the default endpoint config cache.
*/
ENDPOINT_CONFIG = 'endpointsConfig',
/**
* Key for accessing the model token config cache.
*/
TOKEN_CONFIG = 'tokenConfig',
/**
* Key for the custom config cache.
*/