mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-09-22 08:12:00 +02:00
🪙 feat: Assistants Token Balance & other improvements (#2114)
* chore: add assistants to supportsBalanceCheck
* feat(Transaction): getTransactions and refactor export of model
* refactor: use enum: ViolationTypes.TOKEN_BALANCE
* feat(assistants): check balance
* refactor(assistants): only add promptBuffer if new convo (for title), and remove endpoint definition
* refactor(assistants): Count tokens up to the current context window
* fix(Switcher): make Select list explicitly controlled
* feat(assistants): use assistant's default model when no model is specified instead of the last selected assistant, prevent assistant_id from being recorded in non-assistant endpoints
* chore(assistants/chat): import order
* chore: bump librechat-data-provider due to changes
This commit is contained in:
parent
f848d752e0
commit
a9d2d3fe40
13 changed files with 128 additions and 22 deletions
2
api/cache/getLogStores.js
vendored
2
api/cache/getLogStores.js
vendored
|
@ -47,7 +47,7 @@ const namespaces = {
|
||||||
concurrent: createViolationInstance('concurrent'),
|
concurrent: createViolationInstance('concurrent'),
|
||||||
non_browser: createViolationInstance('non_browser'),
|
non_browser: createViolationInstance('non_browser'),
|
||||||
message_limit: createViolationInstance('message_limit'),
|
message_limit: createViolationInstance('message_limit'),
|
||||||
token_balance: createViolationInstance('token_balance'),
|
token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE),
|
||||||
registrations: createViolationInstance('registrations'),
|
registrations: createViolationInstance('registrations'),
|
||||||
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
|
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
|
||||||
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
|
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
|
||||||
|
|
|
@ -50,4 +50,23 @@ transactionSchema.statics.create = async function (transactionData) {
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
module.exports = mongoose.model('Transaction', transactionSchema);
|
const Transaction = mongoose.model('Transaction', transactionSchema);
|
||||||
|
|
||||||
|
/**
 * Looks up transactions that match the supplied filter.
 *
 * The query is run through `.lean()` (presumably yielding plain objects rather
 * than full Mongoose documents — confirm against the Mongoose docs).
 *
 * @async
 * @function getTransactions
 * @param {Object} filter - MongoDB filter object used to select transactions.
 * @returns {Promise<Array>} Resolves to the array of transactions matching `filter`.
 * @throws {Error} Re-throws any error raised while querying, after logging it.
 */
async function getTransactions(filter) {
  let matches;
  try {
    // Keep the whole query inside the try so every failure is logged before it propagates.
    matches = await Transaction.find(filter).lean();
  } catch (queryError) {
    console.error('Error querying transactions:', queryError);
    throw queryError;
  }
  return matches;
}
|
||||||
|
|
||||||
|
module.exports = { Transaction, getTransactions };
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
|
const { ViolationTypes } = require('librechat-data-provider');
|
||||||
|
const { logViolation } = require('~/cache');
|
||||||
const Balance = require('./Balance');
|
const Balance = require('./Balance');
|
||||||
const { logViolation } = require('../cache');
|
|
||||||
/**
|
/**
|
||||||
* Checks the balance for a user and determines if they can spend a certain amount.
|
* Checks the balance for a user and determines if they can spend a certain amount.
|
||||||
* If the user cannot spend the amount, it logs a violation and denies the request.
|
* If the user cannot spend the amount, it logs a violation and denies the request.
|
||||||
|
@ -25,7 +26,7 @@ const checkBalance = async ({ req, res, txData }) => {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
const type = 'token_balance';
|
const type = ViolationTypes.TOKEN_BALANCE;
|
||||||
const errorMessage = {
|
const errorMessage = {
|
||||||
type,
|
type,
|
||||||
balance,
|
balance,
|
||||||
|
|
|
@ -22,14 +22,12 @@ const Key = require('./Key');
|
||||||
const User = require('./User');
|
const User = require('./User');
|
||||||
const Session = require('./Session');
|
const Session = require('./Session');
|
||||||
const Balance = require('./Balance');
|
const Balance = require('./Balance');
|
||||||
const Transaction = require('./Transaction');
|
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
User,
|
User,
|
||||||
Key,
|
Key,
|
||||||
Session,
|
Session,
|
||||||
Balance,
|
Balance,
|
||||||
Transaction,
|
|
||||||
|
|
||||||
hashPassword,
|
hashPassword,
|
||||||
updateUser,
|
updateUser,
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
const Transaction = require('./Transaction');
|
const { Transaction } = require('./Transaction');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -1,6 +1,12 @@
|
||||||
const { v4 } = require('uuid');
|
const { v4 } = require('uuid');
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const { EModelEndpoint, Constants, RunStatus, CacheKeys } = require('librechat-data-provider');
|
const {
|
||||||
|
Constants,
|
||||||
|
RunStatus,
|
||||||
|
CacheKeys,
|
||||||
|
EModelEndpoint,
|
||||||
|
ViolationTypes,
|
||||||
|
} = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
initThread,
|
initThread,
|
||||||
recordUsage,
|
recordUsage,
|
||||||
|
@ -11,10 +17,13 @@ const {
|
||||||
} = require('~/server/services/Threads');
|
} = require('~/server/services/Threads');
|
||||||
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
||||||
const { addTitle, initializeClient } = require('~/server/services/Endpoints/assistants');
|
const { addTitle, initializeClient } = require('~/server/services/Endpoints/assistants');
|
||||||
const { sendResponse, sendMessage, sleep } = require('~/server/utils');
|
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
|
||||||
|
const { getTransactions } = require('~/models/Transaction');
|
||||||
const { createRun } = require('~/server/services/Runs');
|
const { createRun } = require('~/server/services/Runs');
|
||||||
|
const checkBalance = require('~/models/checkBalance');
|
||||||
const { getConvo } = require('~/models/Conversation');
|
const { getConvo } = require('~/models/Conversation');
|
||||||
const getLogStores = require('~/cache/getLogStores');
|
const getLogStores = require('~/cache/getLogStores');
|
||||||
|
const { getModelMaxTokens } = require('~/utils');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
|
@ -128,6 +137,8 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
||||||
: ''
|
: ''
|
||||||
}`;
|
}`;
|
||||||
return sendResponse(res, messageData, errorMessage);
|
return sendResponse(res, messageData, errorMessage);
|
||||||
|
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
|
||||||
|
return sendResponse(res, messageData, error.message);
|
||||||
} else {
|
} else {
|
||||||
logger.error('[/assistants/chat/]', error);
|
logger.error('[/assistants/chat/]', error);
|
||||||
}
|
}
|
||||||
|
@ -207,6 +218,38 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
||||||
throw new Error('Missing assistant_id');
|
throw new Error('Missing assistant_id');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isEnabled(process.env.CHECK_BALANCE)) {
|
||||||
|
const transactions =
|
||||||
|
(await getTransactions({
|
||||||
|
user: req.user.id,
|
||||||
|
context: 'message',
|
||||||
|
conversationId,
|
||||||
|
})) ?? [];
|
||||||
|
|
||||||
|
const totalPreviousTokens = Math.abs(
|
||||||
|
transactions.reduce((acc, curr) => acc + curr.rawAmount, 0),
|
||||||
|
);
|
||||||
|
|
||||||
|
// TODO: make promptBuffer a config option; buffer for titles, needs buffer for system instructions
|
||||||
|
const promptBuffer = parentMessageId === Constants.NO_PARENT && !_thread_id ? 200 : 0;
|
||||||
|
// 5 is added for labels
|
||||||
|
let promptTokens = (await countTokens(text + (promptPrefix ?? ''))) + 5;
|
||||||
|
promptTokens += totalPreviousTokens + promptBuffer;
|
||||||
|
// Count tokens up to the current context window
|
||||||
|
promptTokens = Math.min(promptTokens, getModelMaxTokens(model));
|
||||||
|
|
||||||
|
await checkBalance({
|
||||||
|
req,
|
||||||
|
res,
|
||||||
|
txData: {
|
||||||
|
model,
|
||||||
|
user: req.user.id,
|
||||||
|
tokenType: 'prompt',
|
||||||
|
amount: promptTokens,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
/** @type {{ openai: OpenAIClient }} */
|
/** @type {{ openai: OpenAIClient }} */
|
||||||
const { openai: _openai, client } = await initializeClient({
|
const { openai: _openai, client } = await initializeClient({
|
||||||
req,
|
req,
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
// file deepcode ignore HardcodedNonCryptoSecret: No hardcoded secrets
|
// file deepcode ignore HardcodedNonCryptoSecret: No hardcoded secrets
|
||||||
|
import { ViolationTypes } from 'librechat-data-provider';
|
||||||
import React from 'react';
|
|
||||||
import type { TOpenAIMessage } from 'librechat-data-provider';
|
import type { TOpenAIMessage } from 'librechat-data-provider';
|
||||||
import { formatJSON, extractJson, isJson } from '~/utils/json';
|
import { formatJSON, extractJson, isJson } from '~/utils/json';
|
||||||
import CodeBlock from './CodeBlock';
|
import CodeBlock from './CodeBlock';
|
||||||
|
@ -15,7 +14,7 @@ type TMessageLimit = {
|
||||||
};
|
};
|
||||||
|
|
||||||
type TTokenBalance = {
|
type TTokenBalance = {
|
||||||
type: 'token_balance';
|
type: ViolationTypes;
|
||||||
balance: number;
|
balance: number;
|
||||||
tokenCost: number;
|
tokenCost: number;
|
||||||
promptTokens: number;
|
promptTokens: number;
|
||||||
|
|
|
@ -52,7 +52,7 @@ export default function Switcher({ isCollapsed }: SwitcherProps) {
|
||||||
const currentAssistant = assistantMap?.[selectedAssistant ?? ''];
|
const currentAssistant = assistantMap?.[selectedAssistant ?? ''];
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Select defaultValue={selectedAssistant as string | undefined} onValueChange={onSelect}>
|
<Select value={selectedAssistant as string | undefined} onValueChange={onSelect}>
|
||||||
<SelectTrigger
|
<SelectTrigger
|
||||||
className={cn(
|
className={cn(
|
||||||
'flex items-center gap-2 [&>span]:line-clamp-1 [&>span]:flex [&>span]:w-full [&>span]:items-center [&>span]:gap-1 [&>span]:truncate [&_svg]:h-4 [&_svg]:w-4 [&_svg]:shrink-0',
|
'flex items-center gap-2 [&>span]:line-clamp-1 [&>span]:flex [&>span]:w-full [&>span]:items-center [&>span]:gap-1 [&>span]:truncate [&_svg]:h-4 [&_svg]:w-4 [&_svg]:shrink-0',
|
||||||
|
|
|
@ -15,7 +15,12 @@ import type {
|
||||||
TModelsConfig,
|
TModelsConfig,
|
||||||
TEndpointsConfig,
|
TEndpointsConfig,
|
||||||
} from 'librechat-data-provider';
|
} from 'librechat-data-provider';
|
||||||
import { buildDefaultConvo, getDefaultEndpoint, getEndpointField } from '~/utils';
|
import {
|
||||||
|
buildDefaultConvo,
|
||||||
|
getDefaultEndpoint,
|
||||||
|
getEndpointField,
|
||||||
|
updateLastSelectedModel,
|
||||||
|
} from '~/utils';
|
||||||
import { useDeleteFilesMutation, useListAssistantsQuery } from '~/data-provider';
|
import { useDeleteFilesMutation, useListAssistantsQuery } from '~/data-provider';
|
||||||
import useOriginNavigate from './useOriginNavigate';
|
import useOriginNavigate from './useOriginNavigate';
|
||||||
import useSetStorage from './useSetStorage';
|
import useSetStorage from './useSetStorage';
|
||||||
|
@ -32,7 +37,8 @@ const useNewConvo = (index = 0) => {
|
||||||
const { data: endpointsConfig = {} as TEndpointsConfig } = useGetEndpointsQuery();
|
const { data: endpointsConfig = {} as TEndpointsConfig } = useGetEndpointsQuery();
|
||||||
|
|
||||||
const { data: assistants = [] } = useListAssistantsQuery(defaultOrderQuery, {
|
const { data: assistants = [] } = useListAssistantsQuery(defaultOrderQuery, {
|
||||||
select: (res) => res.data.map(({ id, name, metadata }) => ({ id, name, metadata })),
|
select: (res) =>
|
||||||
|
res.data.map(({ id, name, metadata, model }) => ({ id, name, metadata, model })),
|
||||||
});
|
});
|
||||||
|
|
||||||
const { mutateAsync } = useDeleteFilesMutation({
|
const { mutateAsync } = useDeleteFilesMutation({
|
||||||
|
@ -81,10 +87,30 @@ const useNewConvo = (index = 0) => {
|
||||||
conversation.endpointType = undefined;
|
conversation.endpointType = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!conversation.assistant_id && defaultEndpoint === EModelEndpoint.assistants) {
|
const isAssistantEndpoint = defaultEndpoint === EModelEndpoint.assistants;
|
||||||
const assistant_id =
|
|
||||||
|
if (!conversation.assistant_id && isAssistantEndpoint) {
|
||||||
|
conversation.assistant_id =
|
||||||
localStorage.getItem(`assistant_id__${index}`) ?? assistants[0]?.id;
|
localStorage.getItem(`assistant_id__${index}`) ?? assistants[0]?.id;
|
||||||
conversation.assistant_id = assistant_id;
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
conversation.assistant_id &&
|
||||||
|
isAssistantEndpoint &&
|
||||||
|
conversation.conversationId === 'new'
|
||||||
|
) {
|
||||||
|
const assistant = assistants.find(
|
||||||
|
(assistant) => assistant.id === conversation.assistant_id,
|
||||||
|
);
|
||||||
|
conversation.model = assistant?.model;
|
||||||
|
updateLastSelectedModel({
|
||||||
|
endpoint: EModelEndpoint.assistants,
|
||||||
|
model: conversation.model,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (conversation.assistant_id && !isAssistantEndpoint) {
|
||||||
|
conversation.assistant_id = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
const models = modelsConfig?.[defaultEndpoint] ?? [];
|
const models = modelsConfig?.[defaultEndpoint] ?? [];
|
||||||
|
|
|
@ -56,3 +56,18 @@ export function mapEndpoints(endpointsConfig: TEndpointsConfig) {
|
||||||
(a, b) => (endpointsConfig?.[a]?.order ?? 0) - (endpointsConfig?.[b]?.order ?? 0),
|
(a, b) => (endpointsConfig?.[a]?.order ?? 0) - (endpointsConfig?.[b]?.order ?? 0),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Records `model` as the last selected model for `endpoint` in localStorage.
 *
 * The `lastSelectedModel` entry is a JSON object keyed by endpoint. Only the key
 * for `endpoint` is overwritten; entries for other endpoints are preserved.
 * Nothing is read or written when `model` is falsy.
 *
 * @param params.endpoint - Endpoint key under which the model is stored.
 * @param params.model - Model name to store; `undefined` or an empty string is ignored.
 */
export function updateLastSelectedModel({
  endpoint,
  model,
}: {
  endpoint: string;
  model: string | undefined;
}) {
  if (!model) {
    return;
  }

  // A malformed or non-object stored value must not throw for callers:
  // fall back to an empty map instead of letting JSON.parse or the assignment fail.
  let lastSelectedModels: Record<string, string> = {};
  try {
    const parsed = JSON.parse(localStorage.getItem('lastSelectedModel') || '{}');
    if (parsed !== null && typeof parsed === 'object' && !Array.isArray(parsed)) {
      lastSelectedModels = parsed;
    }
  } catch {
    // Malformed JSON in storage: start over from an empty map.
  }

  lastSelectedModels[endpoint] = model;
  localStorage.setItem('lastSelectedModel', JSON.stringify(lastSelectedModels));
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
const path = require('path');
|
const path = require('path');
|
||||||
require('module-alias')({ base: path.resolve(__dirname, '..', 'api') });
|
require('module-alias')({ base: path.resolve(__dirname, '..', 'api') });
|
||||||
const { askQuestion, silentExit } = require('./helpers');
|
const { askQuestion, silentExit } = require('./helpers');
|
||||||
const Transaction = require('~/models/Transaction');
|
const { Transaction } = require('~/models/Transaction');
|
||||||
const User = require('~/models/User');
|
const User = require('~/models/User');
|
||||||
const connect = require('./connect');
|
const connect = require('./connect');
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
{
|
{
|
||||||
"name": "librechat-data-provider",
|
"name": "librechat-data-provider",
|
||||||
"version": "0.4.8",
|
"version": "0.4.9",
|
||||||
"description": "data services for librechat apps",
|
"description": "data services for librechat apps",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"module": "dist/index.es.js",
|
"module": "dist/index.es.js",
|
||||||
|
|
|
@ -342,11 +342,12 @@ export const modularEndpoints = new Set<EModelEndpoint | string>([
|
||||||
]);
|
]);
|
||||||
|
|
||||||
export const supportsBalanceCheck = {
|
export const supportsBalanceCheck = {
|
||||||
|
[EModelEndpoint.custom]: true,
|
||||||
[EModelEndpoint.openAI]: true,
|
[EModelEndpoint.openAI]: true,
|
||||||
[EModelEndpoint.anthropic]: true,
|
[EModelEndpoint.anthropic]: true,
|
||||||
[EModelEndpoint.azureOpenAI]: true,
|
|
||||||
[EModelEndpoint.gptPlugins]: true,
|
[EModelEndpoint.gptPlugins]: true,
|
||||||
[EModelEndpoint.custom]: true,
|
[EModelEndpoint.assistants]: true,
|
||||||
|
[EModelEndpoint.azureOpenAI]: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
export const visionModels = ['gpt-4-vision', 'llava-13b', 'gemini-pro-vision', 'claude-3'];
|
export const visionModels = ['gpt-4-vision', 'llava-13b', 'gemini-pro-vision', 'claude-3'];
|
||||||
|
@ -436,6 +437,10 @@ export enum ViolationTypes {
|
||||||
* Illegal Model Request (not available).
|
* Illegal Model Request (not available).
|
||||||
*/
|
*/
|
||||||
ILLEGAL_MODEL_REQUEST = 'illegal_model_request',
|
ILLEGAL_MODEL_REQUEST = 'illegal_model_request',
|
||||||
|
/**
|
||||||
|
* Token Balance Violation (the user cannot afford the requested token spend).
|
||||||
|
*/
|
||||||
|
TOKEN_BALANCE = 'token_balance',
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue