const { logger } = require('@librechat/data-schemas'); const { ViolationTypes } = require('librechat-data-provider'); const { createAutoRefillTransaction } = require('./Transaction'); const { logViolation } = require('~/cache'); const { getMultiplier } = require('./tx'); const { Balance } = require('~/db/models'); function isInvalidDate(date) { return isNaN(date); } /** * Simple check method that calculates token cost and returns balance info. * The auto-refill logic has been moved to balanceMethods.js to prevent circular dependencies. */ const checkBalanceRecord = async function ({ user, model, endpoint, valueKey, tokenType, amount, endpointTokenConfig, }) { const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint, endpointTokenConfig }); const tokenCost = amount * multiplier; // Retrieve the balance record let record = await Balance.findOne({ user }).lean(); if (!record) { logger.debug('[Balance.check] No balance record found for user', { user }); return { canSpend: false, balance: 0, tokenCost, }; } let balance = record.tokenCredits; logger.debug('[Balance.check] Initial state', { user, model, endpoint, valueKey, tokenType, amount, balance, multiplier, endpointTokenConfig: !!endpointTokenConfig, }); // Only perform auto-refill if spending would bring the balance to 0 or below if (balance - tokenCost <= 0 && record.autoRefillEnabled && record.refillAmount > 0) { const lastRefillDate = new Date(record.lastRefill); const now = new Date(); if ( isInvalidDate(lastRefillDate) || now >= addIntervalToDate(lastRefillDate, record.refillIntervalValue, record.refillIntervalUnit) ) { try { /** @type {{ rate: number, user: string, balance: number, transaction: import('@librechat/data-schemas').ITransaction}} */ const result = await createAutoRefillTransaction({ user: user, tokenType: 'credits', context: 'autoRefill', rawAmount: record.refillAmount, }); balance = result.balance; } catch (error) { logger.error('[Balance.check] Failed to record transaction for auto-refill', error); } } } logger.debug('[Balance.check] Token cost', { tokenCost }); return { canSpend: balance >= tokenCost, balance, tokenCost }; }; /** * Adds a time interval to a given date. * @param {Date} date - The starting date. * @param {number} value - The numeric value of the interval. * @param {'seconds'|'minutes'|'hours'|'days'|'weeks'|'months'} unit - The unit of time. * @returns {Date} A new Date representing the starting date plus the interval. */ const addIntervalToDate = (date, value, unit) => { const result = new Date(date); switch (unit) { case 'seconds': result.setSeconds(result.getSeconds() + value); break; case 'minutes': result.setMinutes(result.getMinutes() + value); break; case 'hours': result.setHours(result.getHours() + value); break; case 'days': result.setDate(result.getDate() + value); break; case 'weeks': result.setDate(result.getDate() + value * 7); break; case 'months': result.setMonth(result.getMonth() + value); break; default: break; } return result; }; /** * Checks the balance for a user and determines if they can spend a certain amount. * If the user cannot spend the amount, it logs a violation and denies the request. * * @async * @function * @param {Object} params - The function parameters. * @param {ServerRequest} params.req - The Express request object. * @param {Express.Response} params.res - The Express response object. * @param {Object} params.txData - The transaction data. * @param {string} params.txData.user - The user ID or identifier. * @param {('prompt' | 'completion')} params.txData.tokenType - The type of token. * @param {number} params.txData.amount - The amount of tokens. * @param {string} params.txData.model - The model name or identifier. * @param {string} [params.txData.endpointTokenConfig] - The token configuration for the endpoint. * @returns {Promise} Throws error if the user cannot spend the amount. * @throws {Error} Throws an error if there's an issue with the balance check. */ const checkBalance = async ({ req, res, txData }) => { const { canSpend, balance, tokenCost } = await checkBalanceRecord(txData); if (canSpend) { return true; } const type = ViolationTypes.TOKEN_BALANCE; const errorMessage = { type, balance, tokenCost, promptTokens: txData.amount, }; if (txData.generations && txData.generations.length > 0) { errorMessage.generations = txData.generations; } await logViolation(req, res, type, errorMessage, 0); throw new Error(JSON.stringify(errorMessage)); }; module.exports = { checkBalance, };