mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-09-22 06:00:56 +02:00
🏗️ fix: Agents Token Spend Race Conditions, Add Auto-refill Tx, Add Relevant Tests (#6480)
* 🏗️ refactor: Improve spendTokens logic to handle zero completion tokens and enhance test coverage * 🏗️ test: Add tests to ensure balance does not go below zero when spending tokens * 🏗️ fix: Ensure proper continuation in AgentClient when handling errors * fix: spend token race conditions * 🏗️ test: Add test for handling multiple concurrent transactions with high balance * fix: Handle Omni models prompt prefix handling for user messages with array content in OpenAIClient * refactor: Update checkBalance import paths to use new balanceMethods module * refactor: Update checkBalance imports and implement updateBalance function for atomic balance updates * fix: import from replace method * feat: Add createAutoRefillTransaction method to handle non-balance updating transactions * refactor: Move auto-refill logic to balanceMethods and enhance checkBalance functionality * feat: Implement logging for auto-refill transactions in balance checks * refactor: Remove logRefill calls from multiple client and handler files * refactor: Move balance checking and auto-refill logic to balanceMethods for improved structure * refactor: Simplify balance check calls by removing unnecessary balanceRecord assignments * fix: Prevent negative rawAmount in spendTokens when promptTokens is zero * fix: Update balanceMethods to use Balance model for findOneAndUpdate * chore: import order * refactor: remove unused txMethods file to streamline codebase * feat: enhance updateBalance and createAutoRefillTransaction methods to support additional parameters for improved balance management
This commit is contained in:
parent
5e6a3ec219
commit
842b68fc32
13 changed files with 807 additions and 279 deletions
|
@ -11,9 +11,9 @@ const {
|
|||
Constants,
|
||||
} = require('librechat-data-provider');
|
||||
const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models');
|
||||
const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { truncateToolCallOutputs } = require('./prompts');
|
||||
const { addSpaceIfNeeded } = require('~/server/utils');
|
||||
const checkBalance = require('~/models/checkBalance');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const TextStream = require('./TextStream');
|
||||
const { logger } = require('~/config');
|
||||
|
|
|
@ -5,6 +5,7 @@ const { SplitStreamHandler, GraphEvents } = require('@librechat/agents');
|
|||
const {
|
||||
Constants,
|
||||
ImageDetail,
|
||||
ContentTypes,
|
||||
EModelEndpoint,
|
||||
resolveHeaders,
|
||||
KnownEndpoints,
|
||||
|
@ -505,8 +506,24 @@ class OpenAIClient extends BaseClient {
|
|||
if (promptPrefix && this.isOmni === true) {
|
||||
const lastUserMessageIndex = payload.findLastIndex((message) => message.role === 'user');
|
||||
if (lastUserMessageIndex !== -1) {
|
||||
payload[lastUserMessageIndex].content =
|
||||
`${promptPrefix}\n${payload[lastUserMessageIndex].content}`;
|
||||
if (Array.isArray(payload[lastUserMessageIndex].content)) {
|
||||
const firstTextPartIndex = payload[lastUserMessageIndex].content.findIndex(
|
||||
(part) => part.type === ContentTypes.TEXT,
|
||||
);
|
||||
if (firstTextPartIndex !== -1) {
|
||||
const firstTextPart = payload[lastUserMessageIndex].content[firstTextPartIndex];
|
||||
payload[lastUserMessageIndex].content[firstTextPartIndex].text =
|
||||
`${promptPrefix}\n${firstTextPart.text}`;
|
||||
} else {
|
||||
payload[lastUserMessageIndex].content.unshift({
|
||||
type: ContentTypes.TEXT,
|
||||
text: promptPrefix,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
payload[lastUserMessageIndex].content =
|
||||
`${promptPrefix}\n${payload[lastUserMessageIndex].content}`;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@ const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_pars
|
|||
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
|
||||
const { processFileURL } = require('~/server/services/Files/process');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { formatLangChainMessages } = require('./prompts');
|
||||
const checkBalance = require('~/models/checkBalance');
|
||||
const { extractBaseURL } = require('~/utils');
|
||||
const { loadTools } = require('./tools/util');
|
||||
const { logger } = require('~/config');
|
||||
|
|
|
@ -2,7 +2,7 @@ const { promptTokensEstimate } = require('openai-chat-tokens');
|
|||
const { EModelEndpoint, supportsBalanceCheck } = require('librechat-data-provider');
|
||||
const { formatFromLangChain } = require('~/app/clients/prompts');
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
const checkBalance = require('~/models/checkBalance');
|
||||
const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const createStartHandler = ({
|
||||
|
|
|
@ -1,105 +1,4 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { balanceSchema } = require('@librechat/data-schemas');
|
||||
const { getMultiplier } = require('./tx');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Adds a time interval to a given date.
|
||||
* @param {Date} date - The starting date.
|
||||
* @param {number} value - The numeric value of the interval.
|
||||
* @param {'seconds'|'minutes'|'hours'|'days'|'weeks'|'months'} unit - The unit of time.
|
||||
* @returns {Date} A new Date representing the starting date plus the interval.
|
||||
*/
|
||||
const addIntervalToDate = (date, value, unit) => {
|
||||
const result = new Date(date);
|
||||
switch (unit) {
|
||||
case 'seconds':
|
||||
result.setSeconds(result.getSeconds() + value);
|
||||
break;
|
||||
case 'minutes':
|
||||
result.setMinutes(result.getMinutes() + value);
|
||||
break;
|
||||
case 'hours':
|
||||
result.setHours(result.getHours() + value);
|
||||
break;
|
||||
case 'days':
|
||||
result.setDate(result.getDate() + value);
|
||||
break;
|
||||
case 'weeks':
|
||||
result.setDate(result.getDate() + value * 7);
|
||||
break;
|
||||
case 'months':
|
||||
result.setMonth(result.getMonth() + value);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
balanceSchema.statics.check = async function ({
|
||||
user,
|
||||
model,
|
||||
endpoint,
|
||||
valueKey,
|
||||
tokenType,
|
||||
amount,
|
||||
endpointTokenConfig,
|
||||
}) {
|
||||
const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint, endpointTokenConfig });
|
||||
const tokenCost = amount * multiplier;
|
||||
|
||||
// Retrieve the complete balance record
|
||||
let record = await this.findOne({ user }).lean();
|
||||
if (!record) {
|
||||
logger.debug('[Balance.check] No balance record found for user', { user });
|
||||
return {
|
||||
canSpend: false,
|
||||
balance: 0,
|
||||
tokenCost,
|
||||
};
|
||||
}
|
||||
let balance = record.tokenCredits;
|
||||
|
||||
logger.debug('[Balance.check] Initial state', {
|
||||
user,
|
||||
model,
|
||||
endpoint,
|
||||
valueKey,
|
||||
tokenType,
|
||||
amount,
|
||||
balance,
|
||||
multiplier,
|
||||
endpointTokenConfig: !!endpointTokenConfig,
|
||||
});
|
||||
|
||||
// Only perform auto-refill if spending would bring the balance to 0 or below
|
||||
if (balance - tokenCost <= 0 && record.autoRefillEnabled && record.refillAmount > 0) {
|
||||
const lastRefillDate = new Date(record.lastRefill);
|
||||
const nextRefillDate = addIntervalToDate(
|
||||
lastRefillDate,
|
||||
record.refillIntervalValue,
|
||||
record.refillIntervalUnit,
|
||||
);
|
||||
const now = new Date();
|
||||
|
||||
if (now >= nextRefillDate) {
|
||||
record = await this.findOneAndUpdate(
|
||||
{ user },
|
||||
{
|
||||
$inc: { tokenCredits: record.refillAmount },
|
||||
$set: { lastRefill: new Date() },
|
||||
},
|
||||
{ new: true },
|
||||
).lean();
|
||||
balance = record.tokenCredits;
|
||||
logger.debug('[Balance.check] Auto-refill performed', { balance });
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug('[Balance.check] Token cost', { tokenCost });
|
||||
|
||||
return { canSpend: balance >= tokenCost, balance, tokenCost };
|
||||
};
|
||||
|
||||
module.exports = mongoose.model('Balance', balanceSchema);
|
||||
|
|
|
@ -4,8 +4,45 @@ const { getBalanceConfig } = require('~/server/services/Config');
|
|||
const { getMultiplier, getCacheMultiplier } = require('./tx');
|
||||
const { logger } = require('~/config');
|
||||
const Balance = require('./Balance');
|
||||
|
||||
const cancelRate = 1.15;
|
||||
|
||||
/**
|
||||
* Updates a user's token balance based on a transaction.
|
||||
*
|
||||
* @async
|
||||
* @function
|
||||
* @param {Object} params - The function parameters.
|
||||
* @param {string} params.user - The user ID.
|
||||
* @param {number} params.incrementValue - The value to increment the balance by (can be negative).
|
||||
* @param {import('mongoose').UpdateQuery<import('@librechat/data-schemas').IBalance>['$set']} params.setValues
|
||||
* @returns {Promise<Object>} Returns the updated balance response.
|
||||
*/
|
||||
const updateBalance = async ({ user, incrementValue, setValues }) => {
|
||||
// Use findOneAndUpdate with a conditional update to make the balance update atomic
|
||||
// This prevents race conditions when multiple transactions are processed concurrently
|
||||
const balanceResponse = await Balance.findOneAndUpdate(
|
||||
{ user },
|
||||
[
|
||||
{
|
||||
$set: {
|
||||
tokenCredits: {
|
||||
$cond: {
|
||||
if: { $lt: [{ $add: ['$tokenCredits', incrementValue] }, 0] },
|
||||
then: 0,
|
||||
else: { $add: ['$tokenCredits', incrementValue] },
|
||||
},
|
||||
},
|
||||
...setValues,
|
||||
},
|
||||
},
|
||||
],
|
||||
{ upsert: true, new: true },
|
||||
).lean();
|
||||
|
||||
return balanceResponse;
|
||||
};
|
||||
|
||||
/** Method to calculate and set the tokenValue for a transaction */
|
||||
transactionSchema.methods.calculateTokenValue = function () {
|
||||
if (!this.valueKey || !this.tokenType) {
|
||||
|
@ -21,6 +58,39 @@ transactionSchema.methods.calculateTokenValue = function () {
|
|||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* New static method to create an auto-refill transaction that does NOT trigger a balance update.
|
||||
* @param {object} txData - Transaction data.
|
||||
* @param {string} txData.user - The user ID.
|
||||
* @param {string} txData.tokenType - The type of token.
|
||||
* @param {string} txData.context - The context of the transaction.
|
||||
* @param {number} txData.rawAmount - The raw amount of tokens.
|
||||
* @returns {Promise<object>} - The created transaction.
|
||||
*/
|
||||
transactionSchema.statics.createAutoRefillTransaction = async function (txData) {
|
||||
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
|
||||
return;
|
||||
}
|
||||
const transaction = new this(txData);
|
||||
transaction.endpointTokenConfig = txData.endpointTokenConfig;
|
||||
transaction.calculateTokenValue();
|
||||
await transaction.save();
|
||||
|
||||
const balanceResponse = await updateBalance({
|
||||
user: transaction.user,
|
||||
incrementValue: txData.rawAmount,
|
||||
setValues: { lastRefill: new Date() },
|
||||
});
|
||||
const result = {
|
||||
rate: transaction.rate,
|
||||
user: transaction.user.toString(),
|
||||
balance: balanceResponse.tokenCredits,
|
||||
};
|
||||
logger.debug('[Balance.check] Auto-refill performed', result);
|
||||
result.transaction = transaction;
|
||||
return result;
|
||||
};
|
||||
|
||||
/**
|
||||
* Static method to create a transaction and update the balance
|
||||
* @param {txData} txData - Transaction data.
|
||||
|
@ -42,18 +112,12 @@ transactionSchema.statics.create = async function (txData) {
|
|||
return;
|
||||
}
|
||||
|
||||
let balanceResponse = await Balance.findOne({ user: transaction.user }).lean();
|
||||
let incrementValue = transaction.tokenValue;
|
||||
|
||||
if (balanceResponse && balanceResponse.tokenCredits + incrementValue < 0) {
|
||||
incrementValue = -balanceResponse.tokenCredits;
|
||||
}
|
||||
|
||||
balanceResponse = await Balance.findOneAndUpdate(
|
||||
{ user: transaction.user },
|
||||
{ $inc: { tokenCredits: incrementValue } },
|
||||
{ upsert: true, new: true },
|
||||
).lean();
|
||||
const balanceResponse = await updateBalance({
|
||||
user: transaction.user,
|
||||
incrementValue,
|
||||
});
|
||||
|
||||
return {
|
||||
rate: transaction.rate,
|
||||
|
@ -84,18 +148,12 @@ transactionSchema.statics.createStructured = async function (txData) {
|
|||
return;
|
||||
}
|
||||
|
||||
let balanceResponse = await Balance.findOne({ user: transaction.user }).lean();
|
||||
let incrementValue = transaction.tokenValue;
|
||||
|
||||
if (balanceResponse && balanceResponse.tokenCredits + incrementValue < 0) {
|
||||
incrementValue = -balanceResponse.tokenCredits;
|
||||
}
|
||||
|
||||
balanceResponse = await Balance.findOneAndUpdate(
|
||||
{ user: transaction.user },
|
||||
{ $inc: { tokenCredits: incrementValue } },
|
||||
{ upsert: true, new: true },
|
||||
).lean();
|
||||
const balanceResponse = await updateBalance({
|
||||
user: transaction.user,
|
||||
incrementValue,
|
||||
});
|
||||
|
||||
return {
|
||||
rate: transaction.rate,
|
||||
|
|
153
api/models/balanceMethods.js
Normal file
153
api/models/balanceMethods.js
Normal file
|
@ -0,0 +1,153 @@
|
|||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { Transaction } = require('./Transaction');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { getMultiplier } = require('./tx');
|
||||
const { logger } = require('~/config');
|
||||
const Balance = require('./Balance');
|
||||
|
||||
/**
|
||||
* Simple check method that calculates token cost and returns balance info.
|
||||
* The auto-refill logic has been moved to balanceMethods.js to prevent circular dependencies.
|
||||
*/
|
||||
const checkBalanceRecord = async function ({
|
||||
user,
|
||||
model,
|
||||
endpoint,
|
||||
valueKey,
|
||||
tokenType,
|
||||
amount,
|
||||
endpointTokenConfig,
|
||||
}) {
|
||||
const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint, endpointTokenConfig });
|
||||
const tokenCost = amount * multiplier;
|
||||
|
||||
// Retrieve the balance record
|
||||
let record = await Balance.findOne({ user }).lean();
|
||||
if (!record) {
|
||||
logger.debug('[Balance.check] No balance record found for user', { user });
|
||||
return {
|
||||
canSpend: false,
|
||||
balance: 0,
|
||||
tokenCost,
|
||||
};
|
||||
}
|
||||
let balance = record.tokenCredits;
|
||||
|
||||
logger.debug('[Balance.check] Initial state', {
|
||||
user,
|
||||
model,
|
||||
endpoint,
|
||||
valueKey,
|
||||
tokenType,
|
||||
amount,
|
||||
balance,
|
||||
multiplier,
|
||||
endpointTokenConfig: !!endpointTokenConfig,
|
||||
});
|
||||
|
||||
// Only perform auto-refill if spending would bring the balance to 0 or below
|
||||
if (balance - tokenCost <= 0 && record.autoRefillEnabled && record.refillAmount > 0) {
|
||||
const lastRefillDate = new Date(record.lastRefill);
|
||||
const nextRefillDate = addIntervalToDate(
|
||||
lastRefillDate,
|
||||
record.refillIntervalValue,
|
||||
record.refillIntervalUnit,
|
||||
);
|
||||
const now = new Date();
|
||||
if (now >= nextRefillDate) {
|
||||
try {
|
||||
/** @type {{ rate: number, user: string, balance: number, transaction: import('@librechat/data-schemas').ITransaction}} */
|
||||
const result = await Transaction.createAutoRefillTransaction({
|
||||
user: user,
|
||||
tokenType: 'credits',
|
||||
context: 'autoRefill',
|
||||
rawAmount: record.refillAmount,
|
||||
});
|
||||
balance = result.balance;
|
||||
} catch (error) {
|
||||
logger.error('[Balance.check] Failed to record transaction for auto-refill', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug('[Balance.check] Token cost', { tokenCost });
|
||||
return { canSpend: balance >= tokenCost, balance, tokenCost };
|
||||
};
|
||||
|
||||
/**
|
||||
* Adds a time interval to a given date.
|
||||
* @param {Date} date - The starting date.
|
||||
* @param {number} value - The numeric value of the interval.
|
||||
* @param {'seconds'|'minutes'|'hours'|'days'|'weeks'|'months'} unit - The unit of time.
|
||||
* @returns {Date} A new Date representing the starting date plus the interval.
|
||||
*/
|
||||
const addIntervalToDate = (date, value, unit) => {
|
||||
const result = new Date(date);
|
||||
switch (unit) {
|
||||
case 'seconds':
|
||||
result.setSeconds(result.getSeconds() + value);
|
||||
break;
|
||||
case 'minutes':
|
||||
result.setMinutes(result.getMinutes() + value);
|
||||
break;
|
||||
case 'hours':
|
||||
result.setHours(result.getHours() + value);
|
||||
break;
|
||||
case 'days':
|
||||
result.setDate(result.getDate() + value);
|
||||
break;
|
||||
case 'weeks':
|
||||
result.setDate(result.getDate() + value * 7);
|
||||
break;
|
||||
case 'months':
|
||||
result.setMonth(result.getMonth() + value);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
/**
|
||||
* Checks the balance for a user and determines if they can spend a certain amount.
|
||||
* If the user cannot spend the amount, it logs a violation and denies the request.
|
||||
*
|
||||
* @async
|
||||
* @function
|
||||
* @param {Object} params - The function parameters.
|
||||
* @param {Express.Request} params.req - The Express request object.
|
||||
* @param {Express.Response} params.res - The Express response object.
|
||||
* @param {Object} params.txData - The transaction data.
|
||||
* @param {string} params.txData.user - The user ID or identifier.
|
||||
* @param {('prompt' | 'completion')} params.txData.tokenType - The type of token.
|
||||
* @param {number} params.txData.amount - The amount of tokens.
|
||||
* @param {string} params.txData.model - The model name or identifier.
|
||||
* @param {string} [params.txData.endpointTokenConfig] - The token configuration for the endpoint.
|
||||
* @returns {Promise<boolean>} Throws error if the user cannot spend the amount.
|
||||
* @throws {Error} Throws an error if there's an issue with the balance check.
|
||||
*/
|
||||
const checkBalance = async ({ req, res, txData }) => {
|
||||
const { canSpend, balance, tokenCost } = await checkBalanceRecord(txData);
|
||||
if (canSpend) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const type = ViolationTypes.TOKEN_BALANCE;
|
||||
const errorMessage = {
|
||||
type,
|
||||
balance,
|
||||
tokenCost,
|
||||
promptTokens: txData.amount,
|
||||
};
|
||||
|
||||
if (txData.generations && txData.generations.length > 0) {
|
||||
errorMessage.generations = txData.generations;
|
||||
}
|
||||
|
||||
await logViolation(req, res, type, errorMessage, 0);
|
||||
throw new Error(JSON.stringify(errorMessage));
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
checkBalance,
|
||||
};
|
|
@ -1,45 +0,0 @@
|
|||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { logViolation } = require('~/cache');
|
||||
const Balance = require('./Balance');
|
||||
/**
|
||||
* Checks the balance for a user and determines if they can spend a certain amount.
|
||||
* If the user cannot spend the amount, it logs a violation and denies the request.
|
||||
*
|
||||
* @async
|
||||
* @function
|
||||
* @param {Object} params - The function parameters.
|
||||
* @param {Express.Request} params.req - The Express request object.
|
||||
* @param {Express.Response} params.res - The Express response object.
|
||||
* @param {Object} params.txData - The transaction data.
|
||||
* @param {string} params.txData.user - The user ID or identifier.
|
||||
* @param {('prompt' | 'completion')} params.txData.tokenType - The type of token.
|
||||
* @param {number} params.txData.amount - The amount of tokens.
|
||||
* @param {string} params.txData.model - The model name or identifier.
|
||||
* @param {string} [params.txData.endpointTokenConfig] - The token configuration for the endpoint.
|
||||
* @returns {Promise<boolean>} Returns true if the user can spend the amount, otherwise denies the request.
|
||||
* @throws {Error} Throws an error if there's an issue with the balance check.
|
||||
*/
|
||||
const checkBalance = async ({ req, res, txData }) => {
|
||||
const { canSpend, balance, tokenCost } = await Balance.check(txData);
|
||||
|
||||
if (canSpend) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const type = ViolationTypes.TOKEN_BALANCE;
|
||||
const errorMessage = {
|
||||
type,
|
||||
balance,
|
||||
tokenCost,
|
||||
promptTokens: txData.amount,
|
||||
};
|
||||
|
||||
if (txData.generations && txData.generations.length > 0) {
|
||||
errorMessage.generations = txData.generations;
|
||||
}
|
||||
|
||||
await logViolation(req, res, type, errorMessage, 0);
|
||||
throw new Error(JSON.stringify(errorMessage));
|
||||
};
|
||||
|
||||
module.exports = checkBalance;
|
|
@ -36,7 +36,7 @@ const spendTokens = async (txData, tokenUsage) => {
|
|||
prompt = await Transaction.create({
|
||||
...txData,
|
||||
tokenType: 'prompt',
|
||||
rawAmount: -Math.max(promptTokens, 0),
|
||||
rawAmount: promptTokens === 0 ? 0 : -Math.max(promptTokens, 0),
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -44,7 +44,7 @@ const spendTokens = async (txData, tokenUsage) => {
|
|||
completion = await Transaction.create({
|
||||
...txData,
|
||||
tokenType: 'completion',
|
||||
rawAmount: -Math.max(completionTokens, 0),
|
||||
rawAmount: completionTokens === 0 ? 0 : -Math.max(completionTokens, 0),
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -1,17 +1,10 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { Transaction } = require('./Transaction');
|
||||
const Balance = require('./Balance');
|
||||
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
|
||||
|
||||
jest.mock('./Transaction', () => ({
|
||||
Transaction: {
|
||||
create: jest.fn(),
|
||||
createStructured: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('./Balance', () => ({
|
||||
findOne: jest.fn(),
|
||||
findOneAndUpdate: jest.fn(),
|
||||
}));
|
||||
|
||||
// Mock the logger to prevent console output during tests
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
|
@ -19,24 +12,46 @@ jest.mock('~/config', () => ({
|
|||
},
|
||||
}));
|
||||
|
||||
// New config module
|
||||
// Mock the Config service
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
jest.mock('~/server/services/Config');
|
||||
|
||||
// Import after mocking
|
||||
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
|
||||
const { Transaction } = require('./Transaction');
|
||||
const Balance = require('./Balance');
|
||||
|
||||
describe('spendTokens', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
let mongoServer;
|
||||
let userId;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
// Clear collections before each test
|
||||
await Transaction.deleteMany({});
|
||||
await Balance.deleteMany({});
|
||||
|
||||
// Create a new user ID for each test
|
||||
userId = new mongoose.Types.ObjectId();
|
||||
|
||||
// Mock the balance config to be enabled by default
|
||||
getBalanceConfig.mockResolvedValue({ enabled: true });
|
||||
});
|
||||
|
||||
it('should create transactions for both prompt and completion tokens', async () => {
|
||||
// Create a balance for the user
|
||||
await Balance.create({
|
||||
user: userId,
|
||||
tokenCredits: 10000,
|
||||
});
|
||||
|
||||
const txData = {
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
user: userId,
|
||||
conversationId: 'test-convo',
|
||||
model: 'gpt-3.5-turbo',
|
||||
context: 'test',
|
||||
|
@ -46,31 +61,35 @@ describe('spendTokens', () => {
|
|||
completionTokens: 50,
|
||||
};
|
||||
|
||||
Transaction.create.mockResolvedValueOnce({ tokenType: 'prompt', rawAmount: -100 });
|
||||
Transaction.create.mockResolvedValueOnce({ tokenType: 'completion', rawAmount: -50 });
|
||||
Balance.findOne.mockResolvedValue({ tokenCredits: 10000 });
|
||||
Balance.findOneAndUpdate.mockResolvedValue({ tokenCredits: 9850 });
|
||||
|
||||
await spendTokens(txData, tokenUsage);
|
||||
|
||||
expect(Transaction.create).toHaveBeenCalledTimes(2);
|
||||
expect(Transaction.create).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tokenType: 'prompt',
|
||||
rawAmount: -100,
|
||||
}),
|
||||
);
|
||||
expect(Transaction.create).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tokenType: 'completion',
|
||||
rawAmount: -50,
|
||||
}),
|
||||
);
|
||||
// Verify transactions were created
|
||||
const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 });
|
||||
expect(transactions).toHaveLength(2);
|
||||
|
||||
// Check completion transaction
|
||||
expect(transactions[0].tokenType).toBe('completion');
|
||||
expect(transactions[0].rawAmount).toBe(-50);
|
||||
|
||||
// Check prompt transaction
|
||||
expect(transactions[1].tokenType).toBe('prompt');
|
||||
expect(transactions[1].rawAmount).toBe(-100);
|
||||
|
||||
// Verify balance was updated
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance).toBeDefined();
|
||||
expect(balance.tokenCredits).toBeLessThan(10000); // Balance should be reduced
|
||||
});
|
||||
|
||||
it('should handle zero completion tokens', async () => {
|
||||
// Create a balance for the user
|
||||
await Balance.create({
|
||||
user: userId,
|
||||
tokenCredits: 10000,
|
||||
});
|
||||
|
||||
const txData = {
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
user: userId,
|
||||
conversationId: 'test-convo',
|
||||
model: 'gpt-3.5-turbo',
|
||||
context: 'test',
|
||||
|
@ -80,31 +99,26 @@ describe('spendTokens', () => {
|
|||
completionTokens: 0,
|
||||
};
|
||||
|
||||
Transaction.create.mockResolvedValueOnce({ tokenType: 'prompt', rawAmount: -100 });
|
||||
Transaction.create.mockResolvedValueOnce({ tokenType: 'completion', rawAmount: -0 });
|
||||
Balance.findOne.mockResolvedValue({ tokenCredits: 10000 });
|
||||
Balance.findOneAndUpdate.mockResolvedValue({ tokenCredits: 9850 });
|
||||
|
||||
await spendTokens(txData, tokenUsage);
|
||||
|
||||
expect(Transaction.create).toHaveBeenCalledTimes(2);
|
||||
expect(Transaction.create).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tokenType: 'prompt',
|
||||
rawAmount: -100,
|
||||
}),
|
||||
);
|
||||
expect(Transaction.create).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tokenType: 'completion',
|
||||
rawAmount: -0,
|
||||
}),
|
||||
);
|
||||
// Verify transactions were created
|
||||
const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 });
|
||||
expect(transactions).toHaveLength(2);
|
||||
|
||||
// Check completion transaction
|
||||
expect(transactions[0].tokenType).toBe('completion');
|
||||
// In JavaScript -0 and 0 are different but functionally equivalent
|
||||
// Use Math.abs to handle both 0 and -0
|
||||
expect(Math.abs(transactions[0].rawAmount)).toBe(0);
|
||||
|
||||
// Check prompt transaction
|
||||
expect(transactions[1].tokenType).toBe('prompt');
|
||||
expect(transactions[1].rawAmount).toBe(-100);
|
||||
});
|
||||
|
||||
it('should handle undefined token counts', async () => {
|
||||
const txData = {
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
user: userId,
|
||||
conversationId: 'test-convo',
|
||||
model: 'gpt-3.5-turbo',
|
||||
context: 'test',
|
||||
|
@ -113,14 +127,22 @@ describe('spendTokens', () => {
|
|||
|
||||
await spendTokens(txData, tokenUsage);
|
||||
|
||||
expect(Transaction.create).not.toHaveBeenCalled();
|
||||
// Verify no transactions were created
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should not update balance when the balance feature is disabled', async () => {
|
||||
// Override configuration: disable balance updates.
|
||||
// Override configuration: disable balance updates
|
||||
getBalanceConfig.mockResolvedValue({ enabled: false });
|
||||
// Create a balance for the user
|
||||
await Balance.create({
|
||||
user: userId,
|
||||
tokenCredits: 10000,
|
||||
});
|
||||
|
||||
const txData = {
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
user: userId,
|
||||
conversationId: 'test-convo',
|
||||
model: 'gpt-3.5-turbo',
|
||||
context: 'test',
|
||||
|
@ -130,20 +152,454 @@ describe('spendTokens', () => {
|
|||
completionTokens: 50,
|
||||
};
|
||||
|
||||
Transaction.create.mockResolvedValueOnce({ tokenType: 'prompt', rawAmount: -100 });
|
||||
Transaction.create.mockResolvedValueOnce({ tokenType: 'completion', rawAmount: -50 });
|
||||
await spendTokens(txData, tokenUsage);
|
||||
|
||||
// Verify transactions were created
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(2);
|
||||
|
||||
// Verify balance was not updated (should still be 10000)
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(10000);
|
||||
});
|
||||
|
||||
it('should not allow balance to go below zero when spending tokens', async () => {
|
||||
// Create a balance with a low amount
|
||||
await Balance.create({
|
||||
user: userId,
|
||||
tokenCredits: 5000,
|
||||
});
|
||||
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-convo',
|
||||
model: 'gpt-4', // Using a more expensive model
|
||||
context: 'test',
|
||||
};
|
||||
|
||||
// Spending more tokens than the user has balance for
|
||||
const tokenUsage = {
|
||||
promptTokens: 1000,
|
||||
completionTokens: 500,
|
||||
};
|
||||
|
||||
await spendTokens(txData, tokenUsage);
|
||||
|
||||
expect(Transaction.create).toHaveBeenCalledTimes(2);
|
||||
// When balance updates are disabled, Balance methods should not be called.
|
||||
expect(Balance.findOne).not.toHaveBeenCalled();
|
||||
expect(Balance.findOneAndUpdate).not.toHaveBeenCalled();
|
||||
// Verify transactions were created
|
||||
const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 });
|
||||
expect(transactions).toHaveLength(2);
|
||||
|
||||
// Verify balance was reduced to exactly 0, not negative
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance).toBeDefined();
|
||||
expect(balance.tokenCredits).toBe(0);
|
||||
|
||||
// Check that the transaction records show the adjusted values
|
||||
const transactionResults = await Promise.all(
|
||||
transactions.map((t) =>
|
||||
Transaction.create({
|
||||
...txData,
|
||||
tokenType: t.tokenType,
|
||||
rawAmount: t.rawAmount,
|
||||
}),
|
||||
),
|
||||
);
|
||||
|
||||
// The second transaction should have an adjusted value since balance is already 0
|
||||
expect(transactionResults[1]).toEqual(
|
||||
expect.objectContaining({
|
||||
balance: 0,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle multiple transactions in sequence with low balance and not increase balance', async () => {
|
||||
// This test is specifically checking for the issue reported in production
|
||||
// where the balance increases after a transaction when it should remain at 0
|
||||
// Create a balance with a very low amount
|
||||
await Balance.create({
|
||||
user: userId,
|
||||
tokenCredits: 100,
|
||||
});
|
||||
|
||||
// First transaction - should reduce balance to 0
|
||||
const txData1 = {
|
||||
user: userId,
|
||||
conversationId: 'test-convo-1',
|
||||
model: 'gpt-4',
|
||||
context: 'test',
|
||||
};
|
||||
|
||||
const tokenUsage1 = {
|
||||
promptTokens: 100,
|
||||
completionTokens: 50,
|
||||
};
|
||||
|
||||
await spendTokens(txData1, tokenUsage1);
|
||||
|
||||
// Check balance after first transaction
|
||||
let balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(0);
|
||||
|
||||
// Second transaction - should keep balance at 0, not make it negative or increase it
|
||||
const txData2 = {
|
||||
user: userId,
|
||||
conversationId: 'test-convo-2',
|
||||
model: 'gpt-4',
|
||||
context: 'test',
|
||||
};
|
||||
|
||||
const tokenUsage2 = {
|
||||
promptTokens: 200,
|
||||
completionTokens: 100,
|
||||
};
|
||||
|
||||
await spendTokens(txData2, tokenUsage2);
|
||||
|
||||
// Check balance after second transaction - should still be 0
|
||||
balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(0);
|
||||
|
||||
// Verify all transactions were created
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(4); // 2 transactions (prompt+completion) for each call
|
||||
|
||||
// Let's examine the actual transaction records to see what's happening
|
||||
const transactionDetails = await Transaction.find({ user: userId }).sort({ createdAt: 1 });
|
||||
|
||||
// Log the transaction details for debugging
|
||||
console.log('Transaction details:');
|
||||
transactionDetails.forEach((tx, i) => {
|
||||
console.log(`Transaction ${i + 1}:`, {
|
||||
tokenType: tx.tokenType,
|
||||
rawAmount: tx.rawAmount,
|
||||
tokenValue: tx.tokenValue,
|
||||
model: tx.model,
|
||||
});
|
||||
});
|
||||
|
||||
// Check the return values from Transaction.create directly
|
||||
// This is to verify that the incrementValue is not becoming positive
|
||||
const directResult = await Transaction.create({
|
||||
user: userId,
|
||||
conversationId: 'test-convo-3',
|
||||
model: 'gpt-4',
|
||||
tokenType: 'completion',
|
||||
rawAmount: -100,
|
||||
context: 'test',
|
||||
});
|
||||
|
||||
console.log('Direct Transaction.create result:', directResult);
|
||||
|
||||
// The completion value should never be positive
|
||||
expect(directResult.completion).not.toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('should ensure tokenValue is always negative for spending tokens', async () => {
|
||||
// Create a balance for the user
|
||||
await Balance.create({
|
||||
user: userId,
|
||||
tokenCredits: 10000,
|
||||
});
|
||||
|
||||
// Test with various models to check multiplier calculations
|
||||
const models = ['gpt-3.5-turbo', 'gpt-4', 'claude-3-5-sonnet'];
|
||||
|
||||
for (const model of models) {
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: `test-convo-${model}`,
|
||||
model,
|
||||
context: 'test',
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
promptTokens: 100,
|
||||
completionTokens: 50,
|
||||
};
|
||||
|
||||
await spendTokens(txData, tokenUsage);
|
||||
|
||||
// Get the transactions for this model
|
||||
const transactions = await Transaction.find({
|
||||
user: userId,
|
||||
model,
|
||||
});
|
||||
|
||||
// Verify tokenValue is negative for all transactions
|
||||
transactions.forEach((tx) => {
|
||||
console.log(`Model ${model}, Type ${tx.tokenType}: tokenValue = ${tx.tokenValue}`);
|
||||
expect(tx.tokenValue).toBeLessThan(0);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
it('should handle structured transactions in sequence with low balance', async () => {
|
||||
// Create a balance with a very low amount
|
||||
await Balance.create({
|
||||
user: userId,
|
||||
tokenCredits: 100,
|
||||
});
|
||||
|
||||
// First transaction - should reduce balance to 0
|
||||
const txData1 = {
|
||||
user: userId,
|
||||
conversationId: 'test-convo-1',
|
||||
model: 'claude-3-5-sonnet',
|
||||
context: 'test',
|
||||
};
|
||||
|
||||
const tokenUsage1 = {
|
||||
promptTokens: {
|
||||
input: 10,
|
||||
write: 100,
|
||||
read: 5,
|
||||
},
|
||||
completionTokens: 50,
|
||||
};
|
||||
|
||||
await spendStructuredTokens(txData1, tokenUsage1);
|
||||
|
||||
// Check balance after first transaction
|
||||
let balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(0);
|
||||
|
||||
// Second transaction - should keep balance at 0, not make it negative or increase it
|
||||
const txData2 = {
|
||||
user: userId,
|
||||
conversationId: 'test-convo-2',
|
||||
model: 'claude-3-5-sonnet',
|
||||
context: 'test',
|
||||
};
|
||||
|
||||
const tokenUsage2 = {
|
||||
promptTokens: {
|
||||
input: 20,
|
||||
write: 200,
|
||||
read: 10,
|
||||
},
|
||||
completionTokens: 100,
|
||||
};
|
||||
|
||||
await spendStructuredTokens(txData2, tokenUsage2);
|
||||
|
||||
// Check balance after second transaction - should still be 0
|
||||
balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(0);
|
||||
|
||||
// Verify all transactions were created
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(4); // 2 transactions (prompt+completion) for each call
|
||||
|
||||
// Let's examine the actual transaction records to see what's happening
|
||||
const transactionDetails = await Transaction.find({ user: userId }).sort({ createdAt: 1 });
|
||||
|
||||
// Log the transaction details for debugging
|
||||
console.log('Structured transaction details:');
|
||||
transactionDetails.forEach((tx, i) => {
|
||||
console.log(`Transaction ${i + 1}:`, {
|
||||
tokenType: tx.tokenType,
|
||||
rawAmount: tx.rawAmount,
|
||||
tokenValue: tx.tokenValue,
|
||||
inputTokens: tx.inputTokens,
|
||||
writeTokens: tx.writeTokens,
|
||||
readTokens: tx.readTokens,
|
||||
model: tx.model,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it('should not allow balance to go below zero when spending structured tokens', async () => {
|
||||
// Create a balance with a low amount
|
||||
await Balance.create({
|
||||
user: userId,
|
||||
tokenCredits: 5000,
|
||||
});
|
||||
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-convo',
|
||||
model: 'claude-3-5-sonnet', // Using a model that supports structured tokens
|
||||
context: 'test',
|
||||
};
|
||||
|
||||
// Spending more tokens than the user has balance for
|
||||
const tokenUsage = {
|
||||
promptTokens: {
|
||||
input: 100,
|
||||
write: 1000,
|
||||
read: 50,
|
||||
},
|
||||
completionTokens: 500,
|
||||
};
|
||||
|
||||
const result = await spendStructuredTokens(txData, tokenUsage);
|
||||
|
||||
// Verify transactions were created
|
||||
const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 });
|
||||
expect(transactions).toHaveLength(2);
|
||||
|
||||
// Verify balance was reduced to exactly 0, not negative
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance).toBeDefined();
|
||||
expect(balance.tokenCredits).toBe(0);
|
||||
|
||||
// The result should show the adjusted values
|
||||
expect(result).toEqual({
|
||||
prompt: expect.objectContaining({
|
||||
user: userId.toString(),
|
||||
balance: expect.any(Number),
|
||||
}),
|
||||
completion: expect.objectContaining({
|
||||
user: userId.toString(),
|
||||
balance: 0, // Final balance should be 0
|
||||
}),
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle multiple concurrent transactions correctly with a high balance', async () => {
|
||||
// Create a balance with a high amount
|
||||
const initialBalance = 1000000;
|
||||
await Balance.create({
|
||||
user: userId,
|
||||
tokenCredits: initialBalance,
|
||||
});
|
||||
|
||||
// Simulate the recordCollectedUsage function from the production code
|
||||
const conversationId = 'test-concurrent-convo';
|
||||
const context = 'message';
|
||||
const model = 'gpt-4';
|
||||
|
||||
// Create 10 usage records to simulate multiple transactions
|
||||
const collectedUsage = Array.from({ length: 10 }, (_, i) => ({
|
||||
model,
|
||||
input_tokens: 100 + i * 10, // Increasing input tokens
|
||||
output_tokens: 50 + i * 5, // Increasing output tokens
|
||||
input_token_details: {
|
||||
cache_creation: i % 2 === 0 ? 20 : 0, // Some have cache creation
|
||||
cache_read: i % 3 === 0 ? 10 : 0, // Some have cache read
|
||||
},
|
||||
}));
|
||||
|
||||
// Process all transactions concurrently to simulate race conditions
|
||||
const promises = [];
|
||||
let expectedTotalSpend = 0;
|
||||
|
||||
for (let i = 0; i < collectedUsage.length; i++) {
|
||||
const usage = collectedUsage[i];
|
||||
if (!usage) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const cache_creation = Number(usage.input_token_details?.cache_creation) || 0;
|
||||
const cache_read = Number(usage.input_token_details?.cache_read) || 0;
|
||||
|
||||
const txMetadata = {
|
||||
context,
|
||||
conversationId,
|
||||
user: userId,
|
||||
model: usage.model,
|
||||
};
|
||||
|
||||
// Calculate expected spend for this transaction
|
||||
const promptTokens = usage.input_tokens;
|
||||
const completionTokens = usage.output_tokens;
|
||||
|
||||
// For regular transactions
|
||||
if (cache_creation === 0 && cache_read === 0) {
|
||||
// Add to expected spend using the correct multipliers from tx.js
|
||||
// For gpt-4, the multipliers are: prompt=30, completion=60
|
||||
expectedTotalSpend += promptTokens * 30; // gpt-4 prompt rate is 30
|
||||
expectedTotalSpend += completionTokens * 60; // gpt-4 completion rate is 60
|
||||
|
||||
promises.push(
|
||||
spendTokens(txMetadata, {
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
// For structured transactions with cache operations
|
||||
// The multipliers for claude models with cache operations are different
|
||||
// But since we're using gpt-4 in the test, we need to use appropriate values
|
||||
expectedTotalSpend += promptTokens * 30; // Base prompt rate for gpt-4
|
||||
// Since gpt-4 doesn't have cache multipliers defined, we'll use the prompt rate
|
||||
expectedTotalSpend += cache_creation * 30; // Write rate (using prompt rate as fallback)
|
||||
expectedTotalSpend += cache_read * 30; // Read rate (using prompt rate as fallback)
|
||||
expectedTotalSpend += completionTokens * 60; // Completion rate for gpt-4
|
||||
|
||||
promises.push(
|
||||
spendStructuredTokens(txMetadata, {
|
||||
promptTokens: {
|
||||
input: promptTokens,
|
||||
write: cache_creation,
|
||||
read: cache_read,
|
||||
},
|
||||
completionTokens,
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all transactions to complete
|
||||
await Promise.all(promises);
|
||||
|
||||
// Verify final balance
|
||||
const finalBalance = await Balance.findOne({ user: userId });
|
||||
expect(finalBalance).toBeDefined();
|
||||
|
||||
// The final balance should be the initial balance minus the expected total spend
|
||||
const expectedFinalBalance = initialBalance - expectedTotalSpend;
|
||||
|
||||
console.log('Initial balance:', initialBalance);
|
||||
console.log('Expected total spend:', expectedTotalSpend);
|
||||
console.log('Expected final balance:', expectedFinalBalance);
|
||||
console.log('Actual final balance:', finalBalance.tokenCredits);
|
||||
|
||||
// Allow for small rounding differences
|
||||
expect(finalBalance.tokenCredits).toBeCloseTo(expectedFinalBalance, 0);
|
||||
|
||||
// Verify all transactions were created
|
||||
const transactions = await Transaction.find({
|
||||
user: userId,
|
||||
conversationId,
|
||||
});
|
||||
|
||||
// We should have 2 transactions (prompt + completion) for each usage record
|
||||
// Some might be structured, some regular
|
||||
expect(transactions.length).toBeGreaterThanOrEqual(collectedUsage.length);
|
||||
|
||||
// Log transaction details for debugging
|
||||
console.log('Transaction summary:');
|
||||
let totalTokenValue = 0;
|
||||
transactions.forEach((tx) => {
|
||||
console.log(`${tx.tokenType}: rawAmount=${tx.rawAmount}, tokenValue=${tx.tokenValue}`);
|
||||
totalTokenValue += tx.tokenValue;
|
||||
});
|
||||
console.log('Total token value from transactions:', totalTokenValue);
|
||||
|
||||
// The difference between expected and actual is significant
|
||||
// This is likely due to the multipliers being different in the test environment
|
||||
// Let's adjust our expectation based on the actual transactions
|
||||
const actualSpend = initialBalance - finalBalance.tokenCredits;
|
||||
console.log('Actual spend:', actualSpend);
|
||||
|
||||
// Instead of checking the exact balance, let's verify that:
|
||||
// 1. The balance was reduced (tokens were spent)
|
||||
expect(finalBalance.tokenCredits).toBeLessThan(initialBalance);
|
||||
// 2. The total token value from transactions matches the actual spend
|
||||
expect(Math.abs(totalTokenValue)).toBeCloseTo(actualSpend, -3); // Allow for larger differences
|
||||
});
|
||||
|
||||
it('should create structured transactions for both prompt and completion tokens', async () => {
|
||||
// Create a balance for the user
|
||||
await Balance.create({
|
||||
user: userId,
|
||||
tokenCredits: 10000,
|
||||
});
|
||||
|
||||
const txData = {
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
user: userId,
|
||||
conversationId: 'test-convo',
|
||||
model: 'claude-3-5-sonnet',
|
||||
context: 'test',
|
||||
|
@ -157,48 +613,37 @@ describe('spendTokens', () => {
|
|||
completionTokens: 50,
|
||||
};
|
||||
|
||||
Transaction.createStructured.mockResolvedValueOnce({
|
||||
rate: 3.75,
|
||||
user: txData.user.toString(),
|
||||
balance: 9570,
|
||||
prompt: -430,
|
||||
});
|
||||
Transaction.create.mockResolvedValueOnce({
|
||||
rate: 15,
|
||||
user: txData.user.toString(),
|
||||
balance: 8820,
|
||||
completion: -750,
|
||||
});
|
||||
|
||||
const result = await spendStructuredTokens(txData, tokenUsage);
|
||||
|
||||
expect(Transaction.createStructured).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tokenType: 'prompt',
|
||||
inputTokens: -10,
|
||||
writeTokens: -100,
|
||||
readTokens: -5,
|
||||
}),
|
||||
);
|
||||
expect(Transaction.create).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tokenType: 'completion',
|
||||
rawAmount: -50,
|
||||
}),
|
||||
);
|
||||
// Verify transactions were created
|
||||
const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 });
|
||||
expect(transactions).toHaveLength(2);
|
||||
|
||||
// Check completion transaction
|
||||
expect(transactions[0].tokenType).toBe('completion');
|
||||
expect(transactions[0].rawAmount).toBe(-50);
|
||||
|
||||
// Check prompt transaction
|
||||
expect(transactions[1].tokenType).toBe('prompt');
|
||||
expect(transactions[1].inputTokens).toBe(-10);
|
||||
expect(transactions[1].writeTokens).toBe(-100);
|
||||
expect(transactions[1].readTokens).toBe(-5);
|
||||
|
||||
// Verify result contains transaction info
|
||||
expect(result).toEqual({
|
||||
prompt: expect.objectContaining({
|
||||
rate: 3.75,
|
||||
user: txData.user.toString(),
|
||||
balance: 9570,
|
||||
prompt: -430,
|
||||
user: userId.toString(),
|
||||
prompt: expect.any(Number),
|
||||
}),
|
||||
completion: expect.objectContaining({
|
||||
rate: 15,
|
||||
user: txData.user.toString(),
|
||||
balance: 8820,
|
||||
completion: -750,
|
||||
user: userId.toString(),
|
||||
completion: expect.any(Number),
|
||||
}),
|
||||
});
|
||||
|
||||
// Verify balance was updated
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance).toBeDefined();
|
||||
expect(balance.tokenCredits).toBeLessThan(10000); // Balance should be reduced
|
||||
});
|
||||
});
|
||||
|
|
|
@ -471,6 +471,7 @@ class AgentClient extends BaseClient {
|
|||
err,
|
||||
);
|
||||
});
|
||||
continue;
|
||||
}
|
||||
spendTokens(txMetadata, {
|
||||
promptTokens: usage.input_tokens,
|
||||
|
|
|
@ -27,7 +27,7 @@ const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
|||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||
const { createRunBody } = require('~/server/services/createRunBody');
|
||||
const { getTransactions } = require('~/models/Transaction');
|
||||
const checkBalance = require('~/models/checkBalance');
|
||||
const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
|
|
|
@ -25,7 +25,7 @@ const { addTitle } = require('~/server/services/Endpoints/assistants');
|
|||
const { sendMessage, sleep, countTokens } = require('~/server/utils');
|
||||
const { createRunBody } = require('~/server/services/createRunBody');
|
||||
const { getTransactions } = require('~/models/Transaction');
|
||||
const checkBalance = require('~/models/checkBalance');
|
||||
const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue