diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js index c46166cc0..9230560b3 100644 --- a/api/cache/getLogStores.js +++ b/api/cache/getLogStores.js @@ -47,7 +47,7 @@ const namespaces = { concurrent: createViolationInstance('concurrent'), non_browser: createViolationInstance('non_browser'), message_limit: createViolationInstance('message_limit'), - token_balance: createViolationInstance('token_balance'), + token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE), registrations: createViolationInstance('registrations'), [ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT), [ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance( diff --git a/api/models/Transaction.js b/api/models/Transaction.js index ba9c10c1c..b88c96d63 100644 --- a/api/models/Transaction.js +++ b/api/models/Transaction.js @@ -50,4 +50,23 @@ transactionSchema.statics.create = async function (transactionData) { }; }; -module.exports = mongoose.model('Transaction', transactionSchema); +const Transaction = mongoose.model('Transaction', transactionSchema); + +/** + * Queries and retrieves transactions based on a given filter. + * @async + * @function getTransactions + * @param {Object} filter - MongoDB filter object to apply when querying transactions. + * @returns {Promise} A promise that resolves to an array of matched transactions. + * @throws {Error} Throws an error if querying the database fails. + */ +async function getTransactions(filter) { + try { + return await Transaction.find(filter).lean(); + } catch (error) { + console.error('Error querying transactions:', error); + throw error; + } +} + +module.exports = { Transaction, getTransactions }; diff --git a/api/models/checkBalance.js b/api/models/checkBalance.js index 87798166e..5af77bbb1 100644 --- a/api/models/checkBalance.js +++ b/api/models/checkBalance.js @@ -1,5 +1,6 @@ +const { ViolationTypes } = require('librechat-data-provider'); +const { logViolation } = require('~/cache'); const Balance = require('./Balance'); -const { logViolation } = require('../cache'); /** * Checks the balance for a user and determines if they can spend a certain amount. * If the user cannot spend the amount, it logs a violation and denies the request. @@ -25,7 +26,7 @@ const checkBalance = async ({ req, res, txData }) => { return true; } - const type = 'token_balance'; + const type = ViolationTypes.TOKEN_BALANCE; const errorMessage = { type, balance, diff --git a/api/models/index.js b/api/models/index.js index f1b51d5ef..bf8819382 100644 --- a/api/models/index.js +++ b/api/models/index.js @@ -22,14 +22,12 @@ const Key = require('./Key'); const User = require('./User'); const Session = require('./Session'); const Balance = require('./Balance'); -const Transaction = require('./Transaction'); module.exports = { User, Key, Session, Balance, - Transaction, hashPassword, updateUser, diff --git a/api/models/spendTokens.js b/api/models/spendTokens.js index ac4adeca0..95d46c743 100644 --- a/api/models/spendTokens.js +++ b/api/models/spendTokens.js @@ -1,4 +1,4 @@ -const Transaction = require('./Transaction'); +const { Transaction } = require('./Transaction'); const { logger } = require('~/config'); /** diff --git a/api/server/routes/assistants/chat.js b/api/server/routes/assistants/chat.js index 47a5609a8..fd6c5f7bd 100644 --- a/api/server/routes/assistants/chat.js +++ b/api/server/routes/assistants/chat.js @@ -1,6 +1,12 @@ const { v4 } = require('uuid'); const express = require('express'); -const { EModelEndpoint, Constants, RunStatus, CacheKeys } = require('librechat-data-provider'); +const { + Constants, + RunStatus, + CacheKeys, + EModelEndpoint, + ViolationTypes, +} = require('librechat-data-provider'); const { initThread, recordUsage, @@ -11,10 +17,13 @@ const { } = require('~/server/services/Threads'); const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService'); const { addTitle, initializeClient } = require('~/server/services/Endpoints/assistants'); -const { sendResponse, sendMessage, sleep } = require('~/server/utils'); +const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils'); +const { getTransactions } = require('~/models/Transaction'); const { createRun } = require('~/server/services/Runs'); +const checkBalance = require('~/models/checkBalance'); const { getConvo } = require('~/models/Conversation'); const getLogStores = require('~/cache/getLogStores'); +const { getModelMaxTokens } = require('~/utils'); const { logger } = require('~/config'); const router = express.Router(); @@ -128,6 +137,8 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res : '' }`; return sendResponse(res, messageData, errorMessage); + } else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) { + return sendResponse(res, messageData, error.message); } else { logger.error('[/assistants/chat/]', error); } @@ -207,6 +218,38 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res throw new Error('Missing assistant_id'); } + if (isEnabled(process.env.CHECK_BALANCE)) { + const transactions = + (await getTransactions({ + user: req.user.id, + context: 'message', + conversationId, + })) ?? []; + + const totalPreviousTokens = Math.abs( + transactions.reduce((acc, curr) => acc + curr.rawAmount, 0), + ); + + // TODO: make promptBuffer a config option; buffer for titles, needs buffer for system instructions + const promptBuffer = parentMessageId === Constants.NO_PARENT && !_thread_id ? 200 : 0; + // 5 is added for labels + let promptTokens = (await countTokens(text + (promptPrefix ?? ''))) + 5; + promptTokens += totalPreviousTokens + promptBuffer; + // Count tokens up to the current context window + promptTokens = Math.min(promptTokens, getModelMaxTokens(model)); + + await checkBalance({ + req, + res, + txData: { + model, + user: req.user.id, + tokenType: 'prompt', + amount: promptTokens, + }, + }); + } + /** @type {{ openai: OpenAIClient }} */ const { openai: _openai, client } = await initializeClient({ req, diff --git a/client/src/components/Messages/Content/Error.tsx b/client/src/components/Messages/Content/Error.tsx index 9505d4eba..3d7f4541f 100644 --- a/client/src/components/Messages/Content/Error.tsx +++ b/client/src/components/Messages/Content/Error.tsx @@ -1,6 +1,5 @@ // file deepcode ignore HardcodedNonCryptoSecret: No hardcoded secrets - -import React from 'react'; +import { ViolationTypes } from 'librechat-data-provider'; import type { TOpenAIMessage } from 'librechat-data-provider'; import { formatJSON, extractJson, isJson } from '~/utils/json'; import CodeBlock from './CodeBlock'; @@ -15,7 +14,7 @@ type TMessageLimit = { }; type TTokenBalance = { - type: 'token_balance'; + type: ViolationTypes; balance: number; tokenCost: number; promptTokens: number; diff --git a/client/src/components/SidePanel/Switcher.tsx b/client/src/components/SidePanel/Switcher.tsx index d9cdab852..584389ec8 100644 --- a/client/src/components/SidePanel/Switcher.tsx +++ b/client/src/components/SidePanel/Switcher.tsx @@ -52,7 +52,7 @@ export default function Switcher({ isCollapsed }: SwitcherProps) { const currentAssistant = assistantMap?.[selectedAssistant ?? '']; return ( - span]:line-clamp-1 [&>span]:flex [&>span]:w-full [&>span]:items-center [&>span]:gap-1 [&>span]:truncate [&_svg]:h-4 [&_svg]:w-4 [&_svg]:shrink-0', diff --git a/client/src/hooks/useNewConvo.ts b/client/src/hooks/useNewConvo.ts index 58c11f122..52ac45174 100644 --- a/client/src/hooks/useNewConvo.ts +++ b/client/src/hooks/useNewConvo.ts @@ -15,7 +15,12 @@ import type { TModelsConfig, TEndpointsConfig, } from 'librechat-data-provider'; -import { buildDefaultConvo, getDefaultEndpoint, getEndpointField } from '~/utils'; +import { + buildDefaultConvo, + getDefaultEndpoint, + getEndpointField, + updateLastSelectedModel, +} from '~/utils'; import { useDeleteFilesMutation, useListAssistantsQuery } from '~/data-provider'; import useOriginNavigate from './useOriginNavigate'; import useSetStorage from './useSetStorage'; @@ -32,7 +37,8 @@ const useNewConvo = (index = 0) => { const { data: endpointsConfig = {} as TEndpointsConfig } = useGetEndpointsQuery(); const { data: assistants = [] } = useListAssistantsQuery(defaultOrderQuery, { - select: (res) => res.data.map(({ id, name, metadata }) => ({ id, name, metadata })), + select: (res) => + res.data.map(({ id, name, metadata, model }) => ({ id, name, metadata, model })), }); const { mutateAsync } = useDeleteFilesMutation({ @@ -81,10 +87,30 @@ const useNewConvo = (index = 0) => { conversation.endpointType = undefined; } - if (!conversation.assistant_id && defaultEndpoint === EModelEndpoint.assistants) { - const assistant_id = + const isAssistantEndpoint = defaultEndpoint === EModelEndpoint.assistants; + + if (!conversation.assistant_id && isAssistantEndpoint) { + conversation.assistant_id = localStorage.getItem(`assistant_id__${index}`) ?? assistants[0]?.id; - conversation.assistant_id = assistant_id; + } + + if ( + conversation.assistant_id && + isAssistantEndpoint && + conversation.conversationId === 'new' + ) { + const assistant = assistants.find( + (assistant) => assistant.id === conversation.assistant_id, + ); + conversation.model = assistant?.model; + updateLastSelectedModel({ + endpoint: EModelEndpoint.assistants, + model: conversation.model, + }); + } + + if (conversation.assistant_id && !isAssistantEndpoint) { + conversation.assistant_id = undefined; } const models = modelsConfig?.[defaultEndpoint] ?? []; diff --git a/client/src/utils/endpoints.ts b/client/src/utils/endpoints.ts index e9004bf8a..37ae4fe9e 100644 --- a/client/src/utils/endpoints.ts +++ b/client/src/utils/endpoints.ts @@ -56,3 +56,18 @@ export function mapEndpoints(endpointsConfig: TEndpointsConfig) { (a, b) => (endpointsConfig?.[a]?.order ?? 0) - (endpointsConfig?.[b]?.order ?? 0), ); } + +export function updateLastSelectedModel({ + endpoint, + model, +}: { + endpoint: string; + model: string | undefined; +}) { + if (!model) { + return; + } + const lastSelectedModels = JSON.parse(localStorage.getItem('lastSelectedModel') || '{}'); + lastSelectedModels[endpoint] = model; + localStorage.setItem('lastSelectedModel', JSON.stringify(lastSelectedModels)); +} diff --git a/config/add-balance.js b/config/add-balance.js index fd86b811f..d525e7df5 100644 --- a/config/add-balance.js +++ b/config/add-balance.js @@ -1,7 +1,7 @@ const path = require('path'); require('module-alias')({ base: path.resolve(__dirname, '..', 'api') }); const { askQuestion, silentExit } = require('./helpers'); -const Transaction = require('~/models/Transaction'); +const { Transaction } = require('~/models/Transaction'); const User = require('~/models/User'); const connect = require('./connect'); diff --git a/packages/data-provider/package.json b/packages/data-provider/package.json index 887c6732e..b3a0d45b7 100644 --- a/packages/data-provider/package.json +++ b/packages/data-provider/package.json @@ -1,6 +1,6 @@ { "name": "librechat-data-provider", - "version": "0.4.8", + "version": "0.4.9", "description": "data services for librechat apps", "main": "dist/index.js", "module": "dist/index.es.js", diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index be8048666..4badaf04b 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -342,11 +342,12 @@ export const modularEndpoints = new Set([ ]); export const supportsBalanceCheck = { + [EModelEndpoint.custom]: true, [EModelEndpoint.openAI]: true, [EModelEndpoint.anthropic]: true, - [EModelEndpoint.azureOpenAI]: true, [EModelEndpoint.gptPlugins]: true, - [EModelEndpoint.custom]: true, + [EModelEndpoint.assistants]: true, + [EModelEndpoint.azureOpenAI]: true, }; export const visionModels = ['gpt-4-vision', 'llava-13b', 'gemini-pro-vision', 'claude-3']; @@ -436,6 +437,10 @@ export enum ViolationTypes { * Illegal Model Request (not available). */ ILLEGAL_MODEL_REQUEST = 'illegal_model_request', + /** + * Token Limit Violation. + */ + TOKEN_BALANCE = 'token_balance', } /**