🪙 feat: Assistants Token Balance & other improvements (#2114)

* chore: add assistants to supportsBalanceCheck

* feat(Transaction): getTransactions and refactor export of model

* refactor: use enum: ViolationTypes.TOKEN_BALANCE

* feat(assistants): check balance (condensed sketch after this list)

* refactor(assistants): only add promptBuffer if new convo (for title), and remove endpoint definition

* refactor(assistants): Count tokens up to the current context window

* fix(Switcher): make Select list explicitly controlled

* feat(assistants): use the assistant's default model when no model is specified, instead of falling back to the last selected model; prevent assistant_id from being recorded on non-assistant endpoints

* chore(assistants/chat): import order

* chore: bump librechat-data-provider due to changes
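
For context, the new balance check in the assistants chat route follows the same pattern the other endpoints already use: it only runs when the `CHECK_BALANCE` flag is enabled, and it charges an estimated prompt-token amount before the run is created. A condensed sketch, assuming `isEnabled`, `checkBalance`, `model`, and `promptTokens` are in scope as they are in the route diff below:

```js
// Gate on CHECK_BALANCE, then verify the user can afford the estimated prompt tokens
// before the assistant run is created; failures surface as a token_balance violation.
if (isEnabled(process.env.CHECK_BALANCE)) {
  await checkBalance({
    req,
    res,
    txData: {
      model,
      user: req.user.id,
      tokenType: 'prompt',
      amount: promptTokens, // estimated from prior transactions plus the current input
    },
  });
}
```
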
Authored by Danny Avila on 2024-03-15 19:48:42 -04:00; committed by GitHub
parent f848d752e0
commit a9d2d3fe40
13 changed files with 128 additions and 22 deletions

View file

@@ -47,7 +47,7 @@ const namespaces = {
concurrent: createViolationInstance('concurrent'),
non_browser: createViolationInstance('non_browser'),
message_limit: createViolationInstance('message_limit'),
token_balance: createViolationInstance('token_balance'),
token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE),
registrations: createViolationInstance('registrations'),
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(

View file

@@ -50,4 +50,23 @@ transactionSchema.statics.create = async function (transactionData) {
};
};
module.exports = mongoose.model('Transaction', transactionSchema);
const Transaction = mongoose.model('Transaction', transactionSchema);
/**
* Queries and retrieves transactions based on a given filter.
* @async
* @function getTransactions
* @param {Object} filter - MongoDB filter object to apply when querying transactions.
* @returns {Promise<Array>} A promise that resolves to an array of matched transactions.
* @throws {Error} Throws an error if querying the database fails.
*/
async function getTransactions(filter) {
try {
return await Transaction.find(filter).lean();
} catch (error) {
console.error('Error querying transactions:', error);
throw error;
}
}
module.exports = { Transaction, getTransactions };
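
A minimal usage sketch of the new export, mirroring how the chat route further down tallies tokens already spent in a conversation; the `getSpentTokens` wrapper is illustrative, not part of this commit:

```js
const { getTransactions } = require('~/models/Transaction');

// Sum the tokens already charged against a conversation's messages.
// rawAmount is negative for debits, so the absolute value of the sum is the total spent.
async function getSpentTokens(userId, conversationId) {
  const transactions =
    (await getTransactions({ user: userId, context: 'message', conversationId })) ?? [];
  return Math.abs(transactions.reduce((acc, curr) => acc + curr.rawAmount, 0));
}
```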

View file

@@ -1,5 +1,6 @@
const { ViolationTypes } = require('librechat-data-provider');
const { logViolation } = require('~/cache');
const Balance = require('./Balance');
const { logViolation } = require('../cache');
/**
* Checks the balance for a user and determines if they can spend a certain amount.
* If the user cannot spend the amount, it logs a violation and denies the request.
@@ -25,7 +26,7 @@ const checkBalance = async ({ req, res, txData }) => {
return true;
}
const type = 'token_balance';
const type = ViolationTypes.TOKEN_BALANCE;
const errorMessage = {
type,
balance,
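
When the check fails, the violation payload built here is what the client's `TTokenBalance` type (further down in this commit) expects. A sketch with made-up values; the fields beyond `type` and `balance` are cut off in this hunk, so they are inferred from that client type:

```js
// Illustrative only: values are invented, and the exact server-side field list
// is truncated in the hunk above.
const exampleErrorMessage = {
  type: ViolationTypes.TOKEN_BALANCE, // serializes to 'token_balance'
  balance: 1500, // tokens the user has left
  tokenCost: 2200, // estimated cost of the denied request
  promptTokens: 2200, // the prompt amount passed in txData
};
```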

View file

@@ -22,14 +22,12 @@ const Key = require('./Key');
const User = require('./User');
const Session = require('./Session');
const Balance = require('./Balance');
const Transaction = require('./Transaction');
module.exports = {
User,
Key,
Session,
Balance,
Transaction,
hashPassword,
updateUser,

View file

@@ -1,4 +1,4 @@
const Transaction = require('./Transaction');
const { Transaction } = require('./Transaction');
const { logger } = require('~/config');
/**

View file

@@ -1,6 +1,12 @@
const { v4 } = require('uuid');
const express = require('express');
const { EModelEndpoint, Constants, RunStatus, CacheKeys } = require('librechat-data-provider');
const {
Constants,
RunStatus,
CacheKeys,
EModelEndpoint,
ViolationTypes,
} = require('librechat-data-provider');
const {
initThread,
recordUsage,
@@ -11,10 +17,13 @@
} = require('~/server/services/Threads');
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const { addTitle, initializeClient } = require('~/server/services/Endpoints/assistants');
const { sendResponse, sendMessage, sleep } = require('~/server/utils');
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { getTransactions } = require('~/models/Transaction');
const { createRun } = require('~/server/services/Runs');
const checkBalance = require('~/models/checkBalance');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { getModelMaxTokens } = require('~/utils');
const { logger } = require('~/config');
const router = express.Router();
@@ -128,6 +137,8 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
: ''
}`;
return sendResponse(res, messageData, errorMessage);
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
return sendResponse(res, messageData, error.message);
} else {
logger.error('[/assistants/chat/]', error);
}
@@ -207,6 +218,38 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
throw new Error('Missing assistant_id');
}
if (isEnabled(process.env.CHECK_BALANCE)) {
const transactions =
(await getTransactions({
user: req.user.id,
context: 'message',
conversationId,
})) ?? [];
const totalPreviousTokens = Math.abs(
transactions.reduce((acc, curr) => acc + curr.rawAmount, 0),
);
// TODO: make promptBuffer a config option; buffer for titles, needs buffer for system instructions
const promptBuffer = parentMessageId === Constants.NO_PARENT && !_thread_id ? 200 : 0;
// 5 is added for labels
let promptTokens = (await countTokens(text + (promptPrefix ?? ''))) + 5;
promptTokens += totalPreviousTokens + promptBuffer;
// Count tokens up to the current context window
promptTokens = Math.min(promptTokens, getModelMaxTokens(model));
await checkBalance({
req,
res,
txData: {
model,
user: req.user.id,
tokenType: 'prompt',
amount: promptTokens,
},
});
}
/** @type {{ openai: OpenAIClient }} */
const { openai: _openai, client } = await initializeClient({
req,
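
The token arithmetic in the hunk above reads as a single estimate: everything already spent in the conversation, plus the new input (with a small allowance for labels), plus a title buffer for brand-new conversations, capped at the model's context window. A condensed restatement, assuming `countTokens` and `getModelMaxTokens` are in scope exactly as imported in this route; the `estimatePromptTokens` wrapper is illustrative, since the route inlines this logic:

```js
async function estimatePromptTokens({ text, promptPrefix, totalPreviousTokens, isNewConvo, model }) {
  // 200-token buffer only for new conversations (covers the title prompt);
  // the diff notes a TODO to make this configurable.
  const promptBuffer = isNewConvo ? 200 : 0;
  // 5 is added for message labels.
  let promptTokens = (await countTokens(text + (promptPrefix ?? ''))) + 5;
  promptTokens += totalPreviousTokens + promptBuffer;
  // Count tokens only up to the current context window.
  return Math.min(promptTokens, getModelMaxTokens(model));
}
```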

View file

@@ -1,6 +1,5 @@
// file deepcode ignore HardcodedNonCryptoSecret: No hardcoded secrets
import React from 'react';
import { ViolationTypes } from 'librechat-data-provider';
import type { TOpenAIMessage } from 'librechat-data-provider';
import { formatJSON, extractJson, isJson } from '~/utils/json';
import CodeBlock from './CodeBlock';
@@ -15,7 +14,7 @@ type TMessageLimit = {
};
type TTokenBalance = {
type: 'token_balance';
type: ViolationTypes;
balance: number;
tokenCost: number;
promptTokens: number;

View file

@@ -52,7 +52,7 @@ export default function Switcher({ isCollapsed }: SwitcherProps) {
const currentAssistant = assistantMap?.[selectedAssistant ?? ''];
return (
<Select defaultValue={selectedAssistant as string | undefined} onValueChange={onSelect}>
<Select value={selectedAssistant as string | undefined} onValueChange={onSelect}>
<SelectTrigger
className={cn(
'flex items-center gap-2 [&>span]:line-clamp-1 [&>span]:flex [&>span]:w-full [&>span]:items-center [&>span]:gap-1 [&>span]:truncate [&_svg]:h-4 [&_svg]:w-4 [&_svg]:shrink-0',

View file

@@ -15,7 +15,12 @@ import type {
TModelsConfig,
TEndpointsConfig,
} from 'librechat-data-provider';
import { buildDefaultConvo, getDefaultEndpoint, getEndpointField } from '~/utils';
import {
buildDefaultConvo,
getDefaultEndpoint,
getEndpointField,
updateLastSelectedModel,
} from '~/utils';
import { useDeleteFilesMutation, useListAssistantsQuery } from '~/data-provider';
import useOriginNavigate from './useOriginNavigate';
import useSetStorage from './useSetStorage';
@@ -32,7 +37,8 @@ const useNewConvo = (index = 0) => {
const { data: endpointsConfig = {} as TEndpointsConfig } = useGetEndpointsQuery();
const { data: assistants = [] } = useListAssistantsQuery(defaultOrderQuery, {
select: (res) => res.data.map(({ id, name, metadata }) => ({ id, name, metadata })),
select: (res) =>
res.data.map(({ id, name, metadata, model }) => ({ id, name, metadata, model })),
});
const { mutateAsync } = useDeleteFilesMutation({
@@ -81,10 +87,30 @@
conversation.endpointType = undefined;
}
if (!conversation.assistant_id && defaultEndpoint === EModelEndpoint.assistants) {
const assistant_id =
const isAssistantEndpoint = defaultEndpoint === EModelEndpoint.assistants;
if (!conversation.assistant_id && isAssistantEndpoint) {
conversation.assistant_id =
localStorage.getItem(`assistant_id__${index}`) ?? assistants[0]?.id;
conversation.assistant_id = assistant_id;
}
if (
conversation.assistant_id &&
isAssistantEndpoint &&
conversation.conversationId === 'new'
) {
const assistant = assistants.find(
(assistant) => assistant.id === conversation.assistant_id,
);
conversation.model = assistant?.model;
updateLastSelectedModel({
endpoint: EModelEndpoint.assistants,
model: conversation.model,
});
}
if (conversation.assistant_id && !isAssistantEndpoint) {
conversation.assistant_id = undefined;
}
const models = modelsConfig?.[defaultEndpoint] ?? [];
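
In effect, a fresh conversation on the assistants endpoint now resolves its model from the selected assistant rather than from the last manually selected model, and that choice is persisted via `updateLastSelectedModel`. A condensed restatement; the helper name is illustrative, and `storedAssistantId` stands for the `assistant_id__${index}` value read from localStorage:

```js
// Pick the assistant for a new assistants-endpoint conversation, then inherit its model.
function resolveAssistantDefaults({ assistants, storedAssistantId }) {
  const assistant_id = storedAssistantId ?? assistants[0]?.id;
  const assistant = assistants.find((a) => a.id === assistant_id);
  return { assistant_id, model: assistant?.model };
}
```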

View file

@@ -56,3 +56,18 @@ export function mapEndpoints(endpointsConfig: TEndpointsConfig) {
(a, b) => (endpointsConfig?.[a]?.order ?? 0) - (endpointsConfig?.[b]?.order ?? 0),
);
}
export function updateLastSelectedModel({
endpoint,
model,
}: {
endpoint: string;
model: string | undefined;
}) {
if (!model) {
return;
}
const lastSelectedModels = JSON.parse(localStorage.getItem('lastSelectedModel') || '{}');
lastSelectedModels[endpoint] = model;
localStorage.setItem('lastSelectedModel', JSON.stringify(lastSelectedModels));
}
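
A quick usage sketch of the new helper; the endpoint and model values here are illustrative:

```js
// Persist the chosen model as the last selected model for a given endpoint.
updateLastSelectedModel({ endpoint: 'assistants', model: 'gpt-4-turbo-preview' });
// localStorage now holds: lastSelectedModel = '{"assistants":"gpt-4-turbo-preview"}'
```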

View file

@@ -1,7 +1,7 @@
const path = require('path');
require('module-alias')({ base: path.resolve(__dirname, '..', 'api') });
const { askQuestion, silentExit } = require('./helpers');
const Transaction = require('~/models/Transaction');
const { Transaction } = require('~/models/Transaction');
const User = require('~/models/User');
const connect = require('./connect');

View file

@@ -1,6 +1,6 @@
{
"name": "librechat-data-provider",
"version": "0.4.8",
"version": "0.4.9",
"description": "data services for librechat apps",
"main": "dist/index.js",
"module": "dist/index.es.js",

View file

@@ -342,11 +342,12 @@ export const modularEndpoints = new Set<EModelEndpoint | string>([
]);
export const supportsBalanceCheck = {
[EModelEndpoint.custom]: true,
[EModelEndpoint.openAI]: true,
[EModelEndpoint.anthropic]: true,
[EModelEndpoint.azureOpenAI]: true,
[EModelEndpoint.gptPlugins]: true,
[EModelEndpoint.custom]: true,
[EModelEndpoint.assistants]: true,
[EModelEndpoint.azureOpenAI]: true,
};
export const visionModels = ['gpt-4-vision', 'llava-13b', 'gemini-pro-vision', 'claude-3'];
@@ -436,6 +437,10 @@ export enum ViolationTypes {
* Illegal Model Request (not available).
*/
ILLEGAL_MODEL_REQUEST = 'illegal_model_request',
/**
* Token Limit Violation.
*/
TOKEN_BALANCE = 'token_balance',
}
/**