feat: Accurate Token Usage Tracking & Optional Balance (#1018)

* refactor(Chains/llms): allow passing callbacks

* refactor(BaseClient): accurately count completion tokens as generation only

* refactor(OpenAIClient): remove unused getTokenCountForResponse, pass streaming var and callbacks in initializeLLM

* wip: summary prompt tokens

* refactor(summarizeMessages): new cut-off strategy that generates a better summary by adding context from beginning, truncating the middle, and providing the end
wip: draft out relevant providers and variables for token tracing

* refactor(createLLM): make streaming prop false by default

* chore: remove use of getTokenCountForResponse

* refactor(agents): use BufferMemory as ConversationSummaryBufferMemory token usage not easy to trace

* chore: remove passing of streaming prop, also console log useful vars for tracing

* feat: formatFromLangChain helper function to count tokens for ChatModelStart

* refactor(initializeLLM): add role for LLM tracing

* chore(formatFromLangChain): update JSDoc

* feat(formatMessages): formats langChain messages into OpenAI payload format

* chore: install openai-chat-tokens

* refactor(formatMessage): optimize conditional langChain logic
fix(formatFromLangChain): fix destructuring

* feat: accurate prompt tokens for ChatModelStart before generation

* refactor(handleChatModelStart): move to callbacks dir, use factory function

* refactor(initializeLLM): rename 'role' to 'context'

* feat(Balance/Transaction): new schema/models for tracking token spend
refactor(Key): factor out model export to separate file

* refactor(initializeClient): add req,res objects to client options

* feat: add-balance script to add to an existing user's token balance
refactor(Transaction): use multiplier map/function, return balance update

* refactor(Tx): update enum for tokenType, return 1 for multiplier if no map match

* refactor(Tx): add fair fallback value multiplier in case the config result is undefined

* refactor(Balance): rename 'tokens' to 'tokenCredits'

* feat: balance check, add tx.js for new tx-related methods and tests

* chore(summaryPrompts): update prompt token count

* refactor(callbacks): pass req, res
wip: check balance

* refactor(Tx): make convoId a String type, fix(calculateTokenValue)

* refactor(BaseClient): add conversationId as client prop when assigned

* feat(RunManager): track LLM runs with manager, track token spend from LLM,
refactor(OpenAIClient): use RunManager to create callbacks, pass user prop to langchain api calls

* feat(spendTokens): helper to spend prompt/completion tokens

* feat(checkBalance): add helper to check, log, deny request if balance doesn't have enough funds
refactor(Balance): static check method to return object instead of boolean now
wip(OpenAIClient): implement use of checkBalance

* refactor(initializeLLM): add token buffer to assure summary isn't generated when subsequent payload is too large
refactor(OpenAIClient): add checkBalance
refactor(createStartHandler): add checkBalance

* chore: remove prompt and completion token logging from route handler

* chore(spendTokens): add JSDoc

* feat(logTokenCost): record transactions for basic api calls

* chore(ask/edit): invoke getResponseSender only once per API call

* refactor(ask/edit): pass promptTokens to getIds and include in abort data

* refactor(getIds -> getReqData): rename function

* refactor(Tx): increase value if incomplete message

* feat: record tokenUsage when message is aborted

* refactor: subtract tokens when payload includes function_call

* refactor: add namespace for token_balance

* fix(spendTokens): only execute if corresponding token type amounts are defined

* refactor(checkBalance): throws Error if not enough token credits

* refactor(runTitleChain): pass and use signal, spread object props in create helpers, and use 'call' instead of 'run'

* fix(abortMiddleware): circular dependency, and default to empty string for completionTokens

* fix: properly cancel title requests when there aren't enough tokens to generate

* feat(predictNewSummary): custom chain for summaries to allow signal passing
refactor(summaryBuffer): use new custom chain

* feat(RunManager): add getRunByConversationId method, refactor: remove run and throw llm error on handleLLMError

* refactor(createStartHandler): if summary, add error details to runs

* fix(OpenAIClient): support aborting from summarization & showing error to user
refactor(summarizeMessages): remove unnecessary operations counting summaryPromptTokens and note for alternative, pass signal to summaryBuffer

* refactor(logTokenCost -> recordTokenUsage): rename

* refactor(checkBalance): include promptTokens in errorMessage

* refactor(checkBalance/spendTokens): move to models dir

* fix(createLanguageChain): correctly pass config

* refactor(initializeLLM/title): add tokenBuffer of 150 for balance check

* refactor(openAPIPlugin): pass signal and memory, filter functions by the one being called

* refactor(createStartHandler): add error to run if context is plugins as well

* refactor(RunManager/handleLLMError): throw error immediately if plugins, don't remove run

* refactor(PluginsClient): pass memory and signal to tools, cleanup error handling logic

* chore: use strict equality for addTitle condition

* refactor(checkBalance): move checkBalance to execute after userMessage and tokenCounts are saved, also make conditional

* style: icon changes to match official

* fix(BaseClient): getTokenCountForResponse -> getTokenCount

* fix(formatLangChainMessages): add kwargs as fallback prop from lc_kwargs, update JSDoc

* refactor(Tx.create): does not update balance if CHECK_BALANCE is not enabled

* fix(e2e/cleanUp): cleanup new collections, import all model methods from index

* fix(config/add-balance): add uncaughtException listener

* fix: circular dependency

* refactor(initializeLLM/checkBalance): append new generations to errorMessage if cost exceeds balance

* fix(handleResponseMessage): only record token usage in this method if not error and completion is not skipped

* fix(createStartHandler): correct condition for generations

* chore: bump postcss due to moderate severity vulnerability

* chore: bump zod due to low severity vulnerability

* chore: bump openai & data-provider version

* feat(types): OpenAI Message types

* chore: update bun lockfile

* refactor(CodeBlock): add error block formatting

* refactor(utils/Plugin): factor out formatJSON and cn to separate files (json.ts and cn.ts), add extractJSON

* chore(logViolation): delete user_id after error is logged

* refactor(getMessageError -> Error): change to React.FC, add token_balance handling, use extractJSON to determine JSON instead of regex

* fix(DALL-E): use latest openai SDK

* chore: reorganize imports, fix type issue

* feat(server): add balance route

* fix(api/models): add auth

* feat(data-provider): /api/balance query

* feat: show balance if checking is enabled, refetch on final message or error

* chore: update docs, .env.example with token_usage info, add balance script command

* fix(Balance): fallback to empty obj for balance query

* style: slight adjustment of balance element

* docs(token_usage): add PR notes
This commit is contained in:
Danny Avila 2023-10-05 18:34:10 -04:00 committed by GitHub
parent be71a1947b
commit 365c39c405
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
81 changed files with 1606 additions and 293 deletions

38
api/models/Balance.js Normal file
View file

@ -0,0 +1,38 @@
const mongoose = require('mongoose');
const balanceSchema = require('./schema/balance');
const { getMultiplier } = require('./tx');
balanceSchema.statics.check = async function ({ user, model, valueKey, tokenType, amount, debug }) {
const multiplier = getMultiplier({ valueKey, tokenType, model });
const tokenCost = amount * multiplier;
const { tokenCredits: balance } = (await this.findOne({ user }, 'tokenCredits').lean()) ?? {};
if (debug) {
console.log('balance check', {
user,
model,
valueKey,
tokenType,
amount,
debug,
balance,
multiplier,
});
}
if (!balance) {
return {
canSpend: false,
balance: 0,
tokenCost,
};
}
if (debug) {
console.log('balance check', { tokenCost });
}
return { canSpend: balance >= tokenCost, balance, tokenCost };
};
module.exports = mongoose.model('Balance', balanceSchema);

4
api/models/Key.js Normal file
View file

@ -0,0 +1,4 @@
const mongoose = require('mongoose');
const keySchema = require('./schema/key');
module.exports = mongoose.model('Key', keySchema);

42
api/models/Transaction.js Normal file
View file

@ -0,0 +1,42 @@
const mongoose = require('mongoose');
const { isEnabled } = require('../server/utils/handleText');
const transactionSchema = require('./schema/transaction');
const { getMultiplier } = require('./tx');
const Balance = require('./Balance');
// Method to calculate and set the tokenValue for a transaction
transactionSchema.methods.calculateTokenValue = function () {
if (!this.valueKey || !this.tokenType) {
this.tokenValue = this.rawAmount;
}
const { valueKey, tokenType, model } = this;
const multiplier = getMultiplier({ valueKey, tokenType, model });
this.tokenValue = this.rawAmount * multiplier;
if (this.context && this.tokenType === 'completion' && this.context === 'incomplete') {
this.tokenValue = Math.floor(this.tokenValue * 1.15);
}
};
// Static method to create a transaction and update the balance
transactionSchema.statics.create = async function (transactionData) {
const Transaction = this;
const transaction = new Transaction(transactionData);
transaction.calculateTokenValue();
// Save the transaction
await transaction.save();
if (!isEnabled(process.env.CHECK_BALANCE)) {
return;
}
// Adjust the user's balance
return await Balance.findOneAndUpdate(
{ user: transaction.user },
{ $inc: { tokenCredits: transaction.tokenValue } },
{ upsert: true, new: true },
);
};
module.exports = mongoose.model('Transaction', transactionSchema);

View file

@ -0,0 +1,44 @@
const Balance = require('./Balance');
const { logViolation } = require('../cache');
/**
* Checks the balance for a user and determines if they can spend a certain amount.
* If the user cannot spend the amount, it logs a violation and denies the request.
*
* @async
* @function
* @param {Object} params - The function parameters.
* @param {Object} params.req - The Express request object.
* @param {Object} params.res - The Express response object.
* @param {Object} params.txData - The transaction data.
* @param {string} params.txData.user - The user ID or identifier.
* @param {('prompt' | 'completion')} params.txData.tokenType - The type of token.
* @param {number} params.txData.amount - The amount of tokens.
* @param {boolean} params.txData.debug - Debug flag.
* @param {string} params.txData.model - The model name or identifier.
* @returns {Promise<boolean>} Returns true if the user can spend the amount, otherwise denies the request.
* @throws {Error} Throws an error if there's an issue with the balance check.
*/
const checkBalance = async ({ req, res, txData }) => {
const { canSpend, balance, tokenCost } = await Balance.check(txData);
if (canSpend) {
return true;
}
const type = 'token_balance';
const errorMessage = {
type,
balance,
tokenCost,
promptTokens: txData.amount,
};
if (txData.generations && txData.generations.length > 0) {
errorMessage.generations = txData.generations;
}
await logViolation(req, res, type, errorMessage, 0);
throw new Error(JSON.stringify(errorMessage));
};
module.exports = checkBalance;

View file

@ -5,14 +5,20 @@ const {
deleteMessagesSince,
deleteMessages,
} = require('./Message');
const { getConvoTitle, getConvo, saveConvo } = require('./Conversation');
const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation');
const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset');
const Key = require('./Key');
const User = require('./User');
const Key = require('./schema/keySchema');
const Session = require('./Session');
const Balance = require('./Balance');
const Transaction = require('./Transaction');
module.exports = {
User,
Key,
Session,
Balance,
Transaction,
getMessages,
saveMessage,
@ -23,6 +29,7 @@ module.exports = {
getConvoTitle,
getConvo,
saveConvo,
deleteConvos,
getPreset,
getPresets,

View file

@ -0,0 +1,17 @@
const mongoose = require('mongoose');
const balanceSchema = mongoose.Schema({
user: {
type: mongoose.Schema.Types.ObjectId,
ref: 'User',
index: true,
required: true,
},
// 1000 tokenCredits = 1 mill ($0.001 USD)
tokenCredits: {
type: Number,
default: 0,
},
});
module.exports = balanceSchema;

View file

@ -22,4 +22,4 @@ const keySchema = mongoose.Schema({
keySchema.index({ expiresAt: 1 }, { expireAfterSeconds: 0 });
module.exports = mongoose.model('Key', keySchema);
module.exports = keySchema;

View file

@ -0,0 +1,33 @@
const mongoose = require('mongoose');
const transactionSchema = mongoose.Schema({
user: {
type: mongoose.Schema.Types.ObjectId,
ref: 'User',
index: true,
required: true,
},
conversationId: {
type: String,
ref: 'Conversation',
index: true,
},
tokenType: {
type: String,
enum: ['prompt', 'completion', 'credits'],
required: true,
},
model: {
type: String,
},
context: {
type: String,
},
valueKey: {
type: String,
},
rawAmount: Number,
tokenValue: Number,
});
module.exports = transactionSchema;

49
api/models/spendTokens.js Normal file
View file

@ -0,0 +1,49 @@
const Transaction = require('./Transaction');
/**
* Creates up to two transactions to record the spending of tokens.
*
* @function
* @async
* @param {Object} txData - Transaction data.
* @param {mongoose.Schema.Types.ObjectId} txData.user - The user ID.
* @param {String} txData.conversationId - The ID of the conversation.
* @param {String} txData.model - The model name.
* @param {String} txData.context - The context in which the transaction is made.
* @param {String} [txData.valueKey] - The value key (optional).
* @param {Object} tokenUsage - The number of tokens used.
* @param {Number} tokenUsage.promptTokens - The number of prompt tokens used.
* @param {Number} tokenUsage.completionTokens - The number of completion tokens used.
* @returns {Promise<void>} - Returns nothing.
* @throws {Error} - Throws an error if there's an issue creating the transactions.
*/
const spendTokens = async (txData, tokenUsage) => {
const { promptTokens, completionTokens } = tokenUsage;
let prompt, completion;
try {
if (promptTokens >= 0) {
prompt = await Transaction.create({
...txData,
tokenType: 'prompt',
rawAmount: -promptTokens,
});
}
if (!completionTokens) {
this.debug && console.dir({ prompt, completion }, { depth: null });
return;
}
completion = await Transaction.create({
...txData,
tokenType: 'completion',
rawAmount: -completionTokens,
});
this.debug && console.dir({ prompt, completion }, { depth: null });
} catch (err) {
console.error(err);
}
};
module.exports = spendTokens;

67
api/models/tx.js Normal file
View file

@ -0,0 +1,67 @@
const { matchModelName } = require('../utils');
/**
* Mapping of model token sizes to their respective multipliers for prompt and completion.
* @type {Object.<string, {prompt: number, completion: number}>}
*/
const tokenValues = {
'8k': { prompt: 3, completion: 6 },
'32k': { prompt: 6, completion: 12 },
'4k': { prompt: 1.5, completion: 2 },
'16k': { prompt: 3, completion: 4 },
};
/**
* Retrieves the key associated with a given model name.
*
* @param {string} model - The model name to match.
* @returns {string|undefined} The key corresponding to the model name, or undefined if no match is found.
*/
const getValueKey = (model) => {
const modelName = matchModelName(model);
if (!modelName) {
return undefined;
}
if (modelName.includes('gpt-3.5-turbo-16k')) {
return '16k';
} else if (modelName.includes('gpt-3.5')) {
return '4k';
} else if (modelName.includes('gpt-4-32k')) {
return '32k';
} else if (modelName.includes('gpt-4')) {
return '8k';
}
return undefined;
};
/**
* Retrieves the multiplier for a given value key and token type. If no value key is provided,
* it attempts to derive it from the model name.
*
* @param {Object} params - The parameters for the function.
* @param {string} [params.valueKey] - The key corresponding to the model name.
* @param {string} [params.tokenType] - The type of token (e.g., 'prompt' or 'completion').
* @param {string} [params.model] - The model name to derive the value key from if not provided.
* @returns {number} The multiplier for the given parameters, or a default value if not found.
*/
const getMultiplier = ({ valueKey, tokenType, model }) => {
if (valueKey && tokenType) {
return tokenValues[valueKey][tokenType] ?? 4.5;
}
if (!tokenType || !model) {
return 1;
}
valueKey = getValueKey(model);
if (!valueKey) {
return 4.5;
}
// If we got this far, and values[tokenType] is undefined somehow, return a rough average of default multipliers
return tokenValues[valueKey][tokenType] ?? 4.5;
};
module.exports = { tokenValues, getValueKey, getMultiplier };

47
api/models/tx.spec.js Normal file
View file

@ -0,0 +1,47 @@
const { getValueKey, getMultiplier } = require('./tx');
describe('getValueKey', () => {
it('should return "16k" for model name containing "gpt-3.5-turbo-16k"', () => {
expect(getValueKey('gpt-3.5-turbo-16k-some-other-info')).toBe('16k');
});
it('should return "4k" for model name containing "gpt-3.5"', () => {
expect(getValueKey('gpt-3.5-some-other-info')).toBe('4k');
});
it('should return "32k" for model name containing "gpt-4-32k"', () => {
expect(getValueKey('gpt-4-32k-some-other-info')).toBe('32k');
});
it('should return "8k" for model name containing "gpt-4"', () => {
expect(getValueKey('gpt-4-some-other-info')).toBe('8k');
});
it('should return undefined for model names that do not match any known patterns', () => {
expect(getValueKey('gpt-5-some-other-info')).toBeUndefined();
});
});
describe('getMultiplier', () => {
it('should return the correct multiplier for a given valueKey and tokenType', () => {
expect(getMultiplier({ valueKey: '8k', tokenType: 'prompt' })).toBe(3);
expect(getMultiplier({ valueKey: '8k', tokenType: 'completion' })).toBe(6);
});
it('should return 4.5 if tokenType is provided but not found in tokenValues', () => {
expect(getMultiplier({ valueKey: '8k', tokenType: 'unknownType' })).toBe(4.5);
});
it('should derive the valueKey from the model if not provided', () => {
expect(getMultiplier({ tokenType: 'prompt', model: 'gpt-4-some-other-info' })).toBe(3);
});
it('should return 1 if only model or tokenType is missing', () => {
expect(getMultiplier({ tokenType: 'prompt' })).toBe(1);
expect(getMultiplier({ model: 'gpt-4-some-other-info' })).toBe(1);
});
it('should return 4.5 if derived valueKey does not match any known patterns', () => {
expect(getMultiplier({ tokenType: 'prompt', model: 'gpt-5-some-other-info' })).toBe(4.5);
});
});