mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-09-22 06:00:56 +02:00
🪙 feat: Assistants Token Balance & other improvements (#2114)
* chore: add assistants to supportsBalanceCheck * feat(Transaction): getTransactions and refactor export of model * refactor: use enum: ViolationTypes.TOKEN_BALANCE * feat(assistants): check balance * refactor(assistants): only add promptBuffer if new convo (for title), and remove endpoint definition * refactor(assistants): Count tokens up to the current context window * fix(Switcher): make Select list explicitly controlled * feat(assistants): use assistant's default model when no model is specified instead of the last selected assistant, prevent assistant_id from being recorded in non-assistant endpoints * chore(assistants/chat): import order * chore: bump librechat-data-provider due to changes
This commit is contained in:
parent
f848d752e0
commit
a9d2d3fe40
13 changed files with 128 additions and 22 deletions
2
api/cache/getLogStores.js
vendored
2
api/cache/getLogStores.js
vendored
|
@ -47,7 +47,7 @@ const namespaces = {
|
|||
concurrent: createViolationInstance('concurrent'),
|
||||
non_browser: createViolationInstance('non_browser'),
|
||||
message_limit: createViolationInstance('message_limit'),
|
||||
token_balance: createViolationInstance('token_balance'),
|
||||
token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE),
|
||||
registrations: createViolationInstance('registrations'),
|
||||
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
|
||||
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
|
||||
|
|
|
@ -50,4 +50,23 @@ transactionSchema.statics.create = async function (transactionData) {
|
|||
};
|
||||
};
|
||||
|
||||
module.exports = mongoose.model('Transaction', transactionSchema);
|
||||
const Transaction = mongoose.model('Transaction', transactionSchema);
|
||||
|
||||
/**
|
||||
* Queries and retrieves transactions based on a given filter.
|
||||
* @async
|
||||
* @function getTransactions
|
||||
* @param {Object} filter - MongoDB filter object to apply when querying transactions.
|
||||
* @returns {Promise<Array>} A promise that resolves to an array of matched transactions.
|
||||
* @throws {Error} Throws an error if querying the database fails.
|
||||
*/
|
||||
async function getTransactions(filter) {
|
||||
try {
|
||||
return await Transaction.find(filter).lean();
|
||||
} catch (error) {
|
||||
console.error('Error querying transactions:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = { Transaction, getTransactions };
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { logViolation } = require('~/cache');
|
||||
const Balance = require('./Balance');
|
||||
const { logViolation } = require('../cache');
|
||||
/**
|
||||
* Checks the balance for a user and determines if they can spend a certain amount.
|
||||
* If the user cannot spend the amount, it logs a violation and denies the request.
|
||||
|
@ -25,7 +26,7 @@ const checkBalance = async ({ req, res, txData }) => {
|
|||
return true;
|
||||
}
|
||||
|
||||
const type = 'token_balance';
|
||||
const type = ViolationTypes.TOKEN_BALANCE;
|
||||
const errorMessage = {
|
||||
type,
|
||||
balance,
|
||||
|
|
|
@ -22,14 +22,12 @@ const Key = require('./Key');
|
|||
const User = require('./User');
|
||||
const Session = require('./Session');
|
||||
const Balance = require('./Balance');
|
||||
const Transaction = require('./Transaction');
|
||||
|
||||
module.exports = {
|
||||
User,
|
||||
Key,
|
||||
Session,
|
||||
Balance,
|
||||
Transaction,
|
||||
|
||||
hashPassword,
|
||||
updateUser,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
const Transaction = require('./Transaction');
|
||||
const { Transaction } = require('./Transaction');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
|
|
|
@ -1,6 +1,12 @@
|
|||
const { v4 } = require('uuid');
|
||||
const express = require('express');
|
||||
const { EModelEndpoint, Constants, RunStatus, CacheKeys } = require('librechat-data-provider');
|
||||
const {
|
||||
Constants,
|
||||
RunStatus,
|
||||
CacheKeys,
|
||||
EModelEndpoint,
|
||||
ViolationTypes,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
initThread,
|
||||
recordUsage,
|
||||
|
@ -11,10 +17,13 @@ const {
|
|||
} = require('~/server/services/Threads');
|
||||
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
||||
const { addTitle, initializeClient } = require('~/server/services/Endpoints/assistants');
|
||||
const { sendResponse, sendMessage, sleep } = require('~/server/utils');
|
||||
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
|
||||
const { getTransactions } = require('~/models/Transaction');
|
||||
const { createRun } = require('~/server/services/Runs');
|
||||
const checkBalance = require('~/models/checkBalance');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const router = express.Router();
|
||||
|
@ -128,6 +137,8 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
|||
: ''
|
||||
}`;
|
||||
return sendResponse(res, messageData, errorMessage);
|
||||
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
|
||||
return sendResponse(res, messageData, error.message);
|
||||
} else {
|
||||
logger.error('[/assistants/chat/]', error);
|
||||
}
|
||||
|
@ -207,6 +218,38 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
|||
throw new Error('Missing assistant_id');
|
||||
}
|
||||
|
||||
if (isEnabled(process.env.CHECK_BALANCE)) {
|
||||
const transactions =
|
||||
(await getTransactions({
|
||||
user: req.user.id,
|
||||
context: 'message',
|
||||
conversationId,
|
||||
})) ?? [];
|
||||
|
||||
const totalPreviousTokens = Math.abs(
|
||||
transactions.reduce((acc, curr) => acc + curr.rawAmount, 0),
|
||||
);
|
||||
|
||||
// TODO: make promptBuffer a config option; buffer for titles, needs buffer for system instructions
|
||||
const promptBuffer = parentMessageId === Constants.NO_PARENT && !_thread_id ? 200 : 0;
|
||||
// 5 is added for labels
|
||||
let promptTokens = (await countTokens(text + (promptPrefix ?? ''))) + 5;
|
||||
promptTokens += totalPreviousTokens + promptBuffer;
|
||||
// Count tokens up to the current context window
|
||||
promptTokens = Math.min(promptTokens, getModelMaxTokens(model));
|
||||
|
||||
await checkBalance({
|
||||
req,
|
||||
res,
|
||||
txData: {
|
||||
model,
|
||||
user: req.user.id,
|
||||
tokenType: 'prompt',
|
||||
amount: promptTokens,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
/** @type {{ openai: OpenAIClient }} */
|
||||
const { openai: _openai, client } = await initializeClient({
|
||||
req,
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
// file deepcode ignore HardcodedNonCryptoSecret: No hardcoded secrets
|
||||
|
||||
import React from 'react';
|
||||
import { ViolationTypes } from 'librechat-data-provider';
|
||||
import type { TOpenAIMessage } from 'librechat-data-provider';
|
||||
import { formatJSON, extractJson, isJson } from '~/utils/json';
|
||||
import CodeBlock from './CodeBlock';
|
||||
|
@ -15,7 +14,7 @@ type TMessageLimit = {
|
|||
};
|
||||
|
||||
type TTokenBalance = {
|
||||
type: 'token_balance';
|
||||
type: ViolationTypes;
|
||||
balance: number;
|
||||
tokenCost: number;
|
||||
promptTokens: number;
|
||||
|
|
|
@ -52,7 +52,7 @@ export default function Switcher({ isCollapsed }: SwitcherProps) {
|
|||
const currentAssistant = assistantMap?.[selectedAssistant ?? ''];
|
||||
|
||||
return (
|
||||
<Select defaultValue={selectedAssistant as string | undefined} onValueChange={onSelect}>
|
||||
<Select value={selectedAssistant as string | undefined} onValueChange={onSelect}>
|
||||
<SelectTrigger
|
||||
className={cn(
|
||||
'flex items-center gap-2 [&>span]:line-clamp-1 [&>span]:flex [&>span]:w-full [&>span]:items-center [&>span]:gap-1 [&>span]:truncate [&_svg]:h-4 [&_svg]:w-4 [&_svg]:shrink-0',
|
||||
|
|
|
@ -15,7 +15,12 @@ import type {
|
|||
TModelsConfig,
|
||||
TEndpointsConfig,
|
||||
} from 'librechat-data-provider';
|
||||
import { buildDefaultConvo, getDefaultEndpoint, getEndpointField } from '~/utils';
|
||||
import {
|
||||
buildDefaultConvo,
|
||||
getDefaultEndpoint,
|
||||
getEndpointField,
|
||||
updateLastSelectedModel,
|
||||
} from '~/utils';
|
||||
import { useDeleteFilesMutation, useListAssistantsQuery } from '~/data-provider';
|
||||
import useOriginNavigate from './useOriginNavigate';
|
||||
import useSetStorage from './useSetStorage';
|
||||
|
@ -32,7 +37,8 @@ const useNewConvo = (index = 0) => {
|
|||
const { data: endpointsConfig = {} as TEndpointsConfig } = useGetEndpointsQuery();
|
||||
|
||||
const { data: assistants = [] } = useListAssistantsQuery(defaultOrderQuery, {
|
||||
select: (res) => res.data.map(({ id, name, metadata }) => ({ id, name, metadata })),
|
||||
select: (res) =>
|
||||
res.data.map(({ id, name, metadata, model }) => ({ id, name, metadata, model })),
|
||||
});
|
||||
|
||||
const { mutateAsync } = useDeleteFilesMutation({
|
||||
|
@ -81,10 +87,30 @@ const useNewConvo = (index = 0) => {
|
|||
conversation.endpointType = undefined;
|
||||
}
|
||||
|
||||
if (!conversation.assistant_id && defaultEndpoint === EModelEndpoint.assistants) {
|
||||
const assistant_id =
|
||||
const isAssistantEndpoint = defaultEndpoint === EModelEndpoint.assistants;
|
||||
|
||||
if (!conversation.assistant_id && isAssistantEndpoint) {
|
||||
conversation.assistant_id =
|
||||
localStorage.getItem(`assistant_id__${index}`) ?? assistants[0]?.id;
|
||||
conversation.assistant_id = assistant_id;
|
||||
}
|
||||
|
||||
if (
|
||||
conversation.assistant_id &&
|
||||
isAssistantEndpoint &&
|
||||
conversation.conversationId === 'new'
|
||||
) {
|
||||
const assistant = assistants.find(
|
||||
(assistant) => assistant.id === conversation.assistant_id,
|
||||
);
|
||||
conversation.model = assistant?.model;
|
||||
updateLastSelectedModel({
|
||||
endpoint: EModelEndpoint.assistants,
|
||||
model: conversation.model,
|
||||
});
|
||||
}
|
||||
|
||||
if (conversation.assistant_id && !isAssistantEndpoint) {
|
||||
conversation.assistant_id = undefined;
|
||||
}
|
||||
|
||||
const models = modelsConfig?.[defaultEndpoint] ?? [];
|
||||
|
|
|
@ -56,3 +56,18 @@ export function mapEndpoints(endpointsConfig: TEndpointsConfig) {
|
|||
(a, b) => (endpointsConfig?.[a]?.order ?? 0) - (endpointsConfig?.[b]?.order ?? 0),
|
||||
);
|
||||
}
|
||||
|
||||
export function updateLastSelectedModel({
|
||||
endpoint,
|
||||
model,
|
||||
}: {
|
||||
endpoint: string;
|
||||
model: string | undefined;
|
||||
}) {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
const lastSelectedModels = JSON.parse(localStorage.getItem('lastSelectedModel') || '{}');
|
||||
lastSelectedModels[endpoint] = model;
|
||||
localStorage.setItem('lastSelectedModel', JSON.stringify(lastSelectedModels));
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
const path = require('path');
|
||||
require('module-alias')({ base: path.resolve(__dirname, '..', 'api') });
|
||||
const { askQuestion, silentExit } = require('./helpers');
|
||||
const Transaction = require('~/models/Transaction');
|
||||
const { Transaction } = require('~/models/Transaction');
|
||||
const User = require('~/models/User');
|
||||
const connect = require('./connect');
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "librechat-data-provider",
|
||||
"version": "0.4.8",
|
||||
"version": "0.4.9",
|
||||
"description": "data services for librechat apps",
|
||||
"main": "dist/index.js",
|
||||
"module": "dist/index.es.js",
|
||||
|
|
|
@ -342,11 +342,12 @@ export const modularEndpoints = new Set<EModelEndpoint | string>([
|
|||
]);
|
||||
|
||||
export const supportsBalanceCheck = {
|
||||
[EModelEndpoint.custom]: true,
|
||||
[EModelEndpoint.openAI]: true,
|
||||
[EModelEndpoint.anthropic]: true,
|
||||
[EModelEndpoint.azureOpenAI]: true,
|
||||
[EModelEndpoint.gptPlugins]: true,
|
||||
[EModelEndpoint.custom]: true,
|
||||
[EModelEndpoint.assistants]: true,
|
||||
[EModelEndpoint.azureOpenAI]: true,
|
||||
};
|
||||
|
||||
export const visionModels = ['gpt-4-vision', 'llava-13b', 'gemini-pro-vision', 'claude-3'];
|
||||
|
@ -436,6 +437,10 @@ export enum ViolationTypes {
|
|||
* Illegal Model Request (not available).
|
||||
*/
|
||||
ILLEGAL_MODEL_REQUEST = 'illegal_model_request',
|
||||
/**
|
||||
* Token Limit Violation.
|
||||
*/
|
||||
TOKEN_BALANCE = 'token_balance',
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue