mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-02-02 15:51:49 +01:00
📉 feat: Add Token Usage Tracking for Agents API Routes (#11600)
* feat: Implement token usage tracking for OpenAI and Responses controllers - Added functionality to record token usage against user balances in OpenAIChatCompletionController and createResponse functions. - Introduced new utility functions for managing token spending and structured token usage. - Enhanced error handling for token recording to improve logging and debugging capabilities. - Updated imports to include new usage tracking methods and configurations. * test: Add unit tests for recordCollectedUsage function in usage.spec.ts - Introduced comprehensive tests for the recordCollectedUsage function, covering various scenarios including handling empty and null collectedUsage, single and multiple usage entries, and sequential and parallel execution cases. - Enhanced token handling tests to ensure correct calculations for both OpenAI and Anthropic formats, including cache token management. - Improved overall test coverage for usage tracking functionality, ensuring robust validation of expected behaviors and outcomes. * test: Add unit tests for OpenAI and Responses API controllers - Introduced comprehensive unit tests for the OpenAIChatCompletionController and createResponse functions, focusing on the correct invocation of recordCollectedUsage for token spending. - Enhanced tests to validate the passing of balance and transactions configuration to the recordCollectedUsage function. - Ensured proper dependency injection of spendTokens and spendStructuredTokens in the usage recording process. - Improved overall test coverage for token usage tracking, ensuring robust validation of expected behaviors and outcomes.
This commit is contained in:
parent
d13037881a
commit
9a38af5875
7 changed files with 1190 additions and 3 deletions
207
api/server/controllers/agents/__tests__/openai.spec.js
Normal file
207
api/server/controllers/agents/__tests__/openai.spec.js
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
/**
|
||||
* Unit tests for OpenAI-compatible API controller
|
||||
* Tests that recordCollectedUsage is called correctly for token spending
|
||||
*/
|
||||
|
||||
const mockSpendTokens = jest.fn().mockResolvedValue({});
|
||||
const mockSpendStructuredTokens = jest.fn().mockResolvedValue({});
|
||||
const mockRecordCollectedUsage = jest
|
||||
.fn()
|
||||
.mockResolvedValue({ input_tokens: 100, output_tokens: 50 });
|
||||
const mockGetBalanceConfig = jest.fn().mockReturnValue({ enabled: true });
|
||||
const mockGetTransactionsConfig = jest.fn().mockReturnValue({ enabled: true });
|
||||
|
||||
jest.mock('nanoid', () => ({
|
||||
nanoid: jest.fn(() => 'mock-nanoid-123'),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/agents', () => ({
|
||||
Callback: { TOOL_ERROR: 'TOOL_ERROR' },
|
||||
ToolEndHandler: jest.fn(),
|
||||
formatAgentMessages: jest.fn().mockReturnValue({
|
||||
messages: [],
|
||||
indexTokenCountMap: {},
|
||||
}),
|
||||
ChatModelStreamHandler: jest.fn().mockImplementation(() => ({
|
||||
handle: jest.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
writeSSE: jest.fn(),
|
||||
createRun: jest.fn().mockResolvedValue({
|
||||
processStream: jest.fn().mockResolvedValue(undefined),
|
||||
}),
|
||||
createChunk: jest.fn().mockReturnValue({}),
|
||||
buildToolSet: jest.fn().mockReturnValue(new Set()),
|
||||
sendFinalChunk: jest.fn(),
|
||||
createSafeUser: jest.fn().mockReturnValue({ id: 'user-123' }),
|
||||
validateRequest: jest
|
||||
.fn()
|
||||
.mockReturnValue({ request: { model: 'agent-123', messages: [], stream: false } }),
|
||||
initializeAgent: jest.fn().mockResolvedValue({
|
||||
model: 'gpt-4',
|
||||
model_parameters: {},
|
||||
toolRegistry: {},
|
||||
}),
|
||||
getBalanceConfig: mockGetBalanceConfig,
|
||||
createErrorResponse: jest.fn(),
|
||||
getTransactionsConfig: mockGetTransactionsConfig,
|
||||
recordCollectedUsage: mockRecordCollectedUsage,
|
||||
buildNonStreamingResponse: jest.fn().mockReturnValue({ id: 'resp-123' }),
|
||||
createOpenAIStreamTracker: jest.fn().mockReturnValue({
|
||||
addText: jest.fn(),
|
||||
addReasoning: jest.fn(),
|
||||
toolCalls: new Map(),
|
||||
usage: { promptTokens: 0, completionTokens: 0, reasoningTokens: 0 },
|
||||
}),
|
||||
createOpenAIContentAggregator: jest.fn().mockReturnValue({
|
||||
addText: jest.fn(),
|
||||
addReasoning: jest.fn(),
|
||||
getText: jest.fn().mockReturnValue(''),
|
||||
getReasoning: jest.fn().mockReturnValue(''),
|
||||
toolCalls: new Map(),
|
||||
usage: { promptTokens: 100, completionTokens: 50, reasoningTokens: 0 },
|
||||
}),
|
||||
createToolExecuteHandler: jest.fn().mockReturnValue({ handle: jest.fn() }),
|
||||
isChatCompletionValidationFailure: jest.fn().mockReturnValue(false),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/ToolService', () => ({
|
||||
loadAgentTools: jest.fn().mockResolvedValue([]),
|
||||
loadToolsForExecution: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/spendTokens', () => ({
|
||||
spendTokens: mockSpendTokens,
|
||||
spendStructuredTokens: mockSpendStructuredTokens,
|
||||
}));
|
||||
|
||||
jest.mock('~/server/controllers/agents/callbacks', () => ({
|
||||
createToolEndCallback: jest.fn().mockReturnValue(jest.fn()),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/PermissionService', () => ({
|
||||
findAccessibleResources: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Conversation', () => ({
|
||||
getConvoFiles: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Agent', () => ({
|
||||
getAgent: jest.fn().mockResolvedValue({
|
||||
id: 'agent-123',
|
||||
provider: 'openAI',
|
||||
model_parameters: { model: 'gpt-4' },
|
||||
}),
|
||||
getAgents: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
getFiles: jest.fn(),
|
||||
getUserKey: jest.fn(),
|
||||
getMessages: jest.fn(),
|
||||
updateFilesUsage: jest.fn(),
|
||||
getUserKeyValues: jest.fn(),
|
||||
getUserCodeFiles: jest.fn(),
|
||||
getToolFilesByIds: jest.fn(),
|
||||
getCodeGeneratedFiles: jest.fn(),
|
||||
}));
|
||||
|
||||
describe('OpenAIChatCompletionController', () => {
|
||||
let OpenAIChatCompletionController;
|
||||
let req, res;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
const controller = require('../openai');
|
||||
OpenAIChatCompletionController = controller.OpenAIChatCompletionController;
|
||||
|
||||
req = {
|
||||
body: {
|
||||
model: 'agent-123',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
stream: false,
|
||||
},
|
||||
user: { id: 'user-123' },
|
||||
config: {
|
||||
endpoints: {
|
||||
agents: { allowedProviders: ['openAI'] },
|
||||
},
|
||||
},
|
||||
on: jest.fn(),
|
||||
};
|
||||
|
||||
res = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn(),
|
||||
setHeader: jest.fn(),
|
||||
flushHeaders: jest.fn(),
|
||||
end: jest.fn(),
|
||||
write: jest.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
describe('token usage recording', () => {
|
||||
it('should call recordCollectedUsage after successful non-streaming completion', async () => {
|
||||
await OpenAIChatCompletionController(req, res);
|
||||
|
||||
expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
|
||||
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
||||
{ spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens },
|
||||
expect.objectContaining({
|
||||
user: 'user-123',
|
||||
conversationId: expect.any(String),
|
||||
collectedUsage: expect.any(Array),
|
||||
context: 'message',
|
||||
balance: { enabled: true },
|
||||
transactions: { enabled: true },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should pass balance and transactions config to recordCollectedUsage', async () => {
|
||||
mockGetBalanceConfig.mockReturnValue({ enabled: true, startBalance: 1000 });
|
||||
mockGetTransactionsConfig.mockReturnValue({ enabled: true, rateLimit: 100 });
|
||||
|
||||
await OpenAIChatCompletionController(req, res);
|
||||
|
||||
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
balance: { enabled: true, startBalance: 1000 },
|
||||
transactions: { enabled: true, rateLimit: 100 },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should pass spendTokens and spendStructuredTokens as dependencies', async () => {
|
||||
await OpenAIChatCompletionController(req, res);
|
||||
|
||||
const [deps] = mockRecordCollectedUsage.mock.calls[0];
|
||||
expect(deps).toHaveProperty('spendTokens', mockSpendTokens);
|
||||
expect(deps).toHaveProperty('spendStructuredTokens', mockSpendStructuredTokens);
|
||||
});
|
||||
|
||||
it('should include model from primaryConfig in recordCollectedUsage params', async () => {
|
||||
await OpenAIChatCompletionController(req, res);
|
||||
|
||||
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
model: 'gpt-4',
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
315
api/server/controllers/agents/__tests__/responses.unit.spec.js
Normal file
315
api/server/controllers/agents/__tests__/responses.unit.spec.js
Normal file
|
|
@ -0,0 +1,315 @@
|
|||
/**
|
||||
* Unit tests for Open Responses API controller
|
||||
* Tests that recordCollectedUsage is called correctly for token spending
|
||||
*/
|
||||
|
||||
const mockSpendTokens = jest.fn().mockResolvedValue({});
|
||||
const mockSpendStructuredTokens = jest.fn().mockResolvedValue({});
|
||||
const mockRecordCollectedUsage = jest
|
||||
.fn()
|
||||
.mockResolvedValue({ input_tokens: 100, output_tokens: 50 });
|
||||
const mockGetBalanceConfig = jest.fn().mockReturnValue({ enabled: true });
|
||||
const mockGetTransactionsConfig = jest.fn().mockReturnValue({ enabled: true });
|
||||
|
||||
jest.mock('nanoid', () => ({
|
||||
nanoid: jest.fn(() => 'mock-nanoid-123'),
|
||||
}));
|
||||
|
||||
jest.mock('uuid', () => ({
|
||||
v4: jest.fn(() => 'mock-uuid-456'),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/agents', () => ({
|
||||
Callback: { TOOL_ERROR: 'TOOL_ERROR' },
|
||||
ToolEndHandler: jest.fn(),
|
||||
formatAgentMessages: jest.fn().mockReturnValue({
|
||||
messages: [],
|
||||
indexTokenCountMap: {},
|
||||
}),
|
||||
ChatModelStreamHandler: jest.fn().mockImplementation(() => ({
|
||||
handle: jest.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
createRun: jest.fn().mockResolvedValue({
|
||||
processStream: jest.fn().mockResolvedValue(undefined),
|
||||
}),
|
||||
buildToolSet: jest.fn().mockReturnValue(new Set()),
|
||||
createSafeUser: jest.fn().mockReturnValue({ id: 'user-123' }),
|
||||
initializeAgent: jest.fn().mockResolvedValue({
|
||||
model: 'claude-3',
|
||||
model_parameters: {},
|
||||
toolRegistry: {},
|
||||
}),
|
||||
getBalanceConfig: mockGetBalanceConfig,
|
||||
getTransactionsConfig: mockGetTransactionsConfig,
|
||||
recordCollectedUsage: mockRecordCollectedUsage,
|
||||
createToolExecuteHandler: jest.fn().mockReturnValue({ handle: jest.fn() }),
|
||||
// Responses API
|
||||
writeDone: jest.fn(),
|
||||
buildResponse: jest.fn().mockReturnValue({ id: 'resp_123', output: [] }),
|
||||
generateResponseId: jest.fn().mockReturnValue('resp_mock-123'),
|
||||
isValidationFailure: jest.fn().mockReturnValue(false),
|
||||
emitResponseCreated: jest.fn(),
|
||||
createResponseContext: jest.fn().mockReturnValue({ responseId: 'resp_123' }),
|
||||
createResponseTracker: jest.fn().mockReturnValue({
|
||||
usage: { promptTokens: 100, completionTokens: 50 },
|
||||
}),
|
||||
setupStreamingResponse: jest.fn(),
|
||||
emitResponseInProgress: jest.fn(),
|
||||
convertInputToMessages: jest.fn().mockReturnValue([]),
|
||||
validateResponseRequest: jest.fn().mockReturnValue({
|
||||
request: { model: 'agent-123', input: 'Hello', stream: false },
|
||||
}),
|
||||
buildAggregatedResponse: jest.fn().mockReturnValue({
|
||||
id: 'resp_123',
|
||||
status: 'completed',
|
||||
output: [],
|
||||
usage: { input_tokens: 100, output_tokens: 50, total_tokens: 150 },
|
||||
}),
|
||||
createResponseAggregator: jest.fn().mockReturnValue({
|
||||
usage: { promptTokens: 100, completionTokens: 50 },
|
||||
}),
|
||||
sendResponsesErrorResponse: jest.fn(),
|
||||
createResponsesEventHandlers: jest.fn().mockReturnValue({
|
||||
handlers: {
|
||||
on_message_delta: { handle: jest.fn() },
|
||||
on_reasoning_delta: { handle: jest.fn() },
|
||||
on_run_step: { handle: jest.fn() },
|
||||
on_run_step_delta: { handle: jest.fn() },
|
||||
on_chat_model_end: { handle: jest.fn() },
|
||||
},
|
||||
finalizeStream: jest.fn(),
|
||||
}),
|
||||
createAggregatorEventHandlers: jest.fn().mockReturnValue({
|
||||
on_message_delta: { handle: jest.fn() },
|
||||
on_reasoning_delta: { handle: jest.fn() },
|
||||
on_run_step: { handle: jest.fn() },
|
||||
on_run_step_delta: { handle: jest.fn() },
|
||||
on_chat_model_end: { handle: jest.fn() },
|
||||
}),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/ToolService', () => ({
|
||||
loadAgentTools: jest.fn().mockResolvedValue([]),
|
||||
loadToolsForExecution: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/spendTokens', () => ({
|
||||
spendTokens: mockSpendTokens,
|
||||
spendStructuredTokens: mockSpendStructuredTokens,
|
||||
}));
|
||||
|
||||
jest.mock('~/server/controllers/agents/callbacks', () => ({
|
||||
createToolEndCallback: jest.fn().mockReturnValue(jest.fn()),
|
||||
createResponsesToolEndCallback: jest.fn().mockReturnValue(jest.fn()),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/PermissionService', () => ({
|
||||
findAccessibleResources: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Conversation', () => ({
|
||||
getConvoFiles: jest.fn().mockResolvedValue([]),
|
||||
saveConvo: jest.fn().mockResolvedValue({}),
|
||||
getConvo: jest.fn().mockResolvedValue(null),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Agent', () => ({
|
||||
getAgent: jest.fn().mockResolvedValue({
|
||||
id: 'agent-123',
|
||||
name: 'Test Agent',
|
||||
provider: 'anthropic',
|
||||
model_parameters: { model: 'claude-3' },
|
||||
}),
|
||||
getAgents: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
getFiles: jest.fn(),
|
||||
getUserKey: jest.fn(),
|
||||
getMessages: jest.fn().mockResolvedValue([]),
|
||||
saveMessage: jest.fn().mockResolvedValue({}),
|
||||
updateFilesUsage: jest.fn(),
|
||||
getUserKeyValues: jest.fn(),
|
||||
getUserCodeFiles: jest.fn(),
|
||||
getToolFilesByIds: jest.fn(),
|
||||
getCodeGeneratedFiles: jest.fn(),
|
||||
}));
|
||||
|
||||
describe('createResponse controller', () => {
|
||||
let createResponse;
|
||||
let req, res;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
const controller = require('../responses');
|
||||
createResponse = controller.createResponse;
|
||||
|
||||
req = {
|
||||
body: {
|
||||
model: 'agent-123',
|
||||
input: 'Hello',
|
||||
stream: false,
|
||||
},
|
||||
user: { id: 'user-123' },
|
||||
config: {
|
||||
endpoints: {
|
||||
agents: { allowedProviders: ['anthropic'] },
|
||||
},
|
||||
},
|
||||
on: jest.fn(),
|
||||
};
|
||||
|
||||
res = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn(),
|
||||
setHeader: jest.fn(),
|
||||
flushHeaders: jest.fn(),
|
||||
end: jest.fn(),
|
||||
write: jest.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
describe('token usage recording - non-streaming', () => {
|
||||
it('should call recordCollectedUsage after successful non-streaming completion', async () => {
|
||||
await createResponse(req, res);
|
||||
|
||||
expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
|
||||
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
||||
{ spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens },
|
||||
expect.objectContaining({
|
||||
user: 'user-123',
|
||||
conversationId: expect.any(String),
|
||||
collectedUsage: expect.any(Array),
|
||||
context: 'message',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should pass balance and transactions config to recordCollectedUsage', async () => {
|
||||
mockGetBalanceConfig.mockReturnValue({ enabled: true, startBalance: 2000 });
|
||||
mockGetTransactionsConfig.mockReturnValue({ enabled: true });
|
||||
|
||||
await createResponse(req, res);
|
||||
|
||||
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
balance: { enabled: true, startBalance: 2000 },
|
||||
transactions: { enabled: true },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should pass spendTokens and spendStructuredTokens as dependencies', async () => {
|
||||
await createResponse(req, res);
|
||||
|
||||
const [deps] = mockRecordCollectedUsage.mock.calls[0];
|
||||
expect(deps).toHaveProperty('spendTokens', mockSpendTokens);
|
||||
expect(deps).toHaveProperty('spendStructuredTokens', mockSpendStructuredTokens);
|
||||
});
|
||||
|
||||
it('should include model from primaryConfig in recordCollectedUsage params', async () => {
|
||||
await createResponse(req, res);
|
||||
|
||||
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
model: 'claude-3',
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('token usage recording - streaming', () => {
|
||||
beforeEach(() => {
|
||||
req.body.stream = true;
|
||||
|
||||
const api = require('@librechat/api');
|
||||
api.validateResponseRequest.mockReturnValue({
|
||||
request: { model: 'agent-123', input: 'Hello', stream: true },
|
||||
});
|
||||
});
|
||||
|
||||
it('should call recordCollectedUsage after successful streaming completion', async () => {
|
||||
await createResponse(req, res);
|
||||
|
||||
expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
|
||||
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
||||
{ spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens },
|
||||
expect.objectContaining({
|
||||
user: 'user-123',
|
||||
context: 'message',
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('collectedUsage population', () => {
|
||||
it('should collect usage from on_chat_model_end events', async () => {
|
||||
const api = require('@librechat/api');
|
||||
|
||||
let capturedOnChatModelEnd;
|
||||
api.createAggregatorEventHandlers.mockImplementation(() => {
|
||||
return {
|
||||
on_message_delta: { handle: jest.fn() },
|
||||
on_reasoning_delta: { handle: jest.fn() },
|
||||
on_run_step: { handle: jest.fn() },
|
||||
on_run_step_delta: { handle: jest.fn() },
|
||||
on_chat_model_end: {
|
||||
handle: jest.fn((event, data) => {
|
||||
if (capturedOnChatModelEnd) {
|
||||
capturedOnChatModelEnd(event, data);
|
||||
}
|
||||
}),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
api.createRun.mockImplementation(async ({ customHandlers }) => {
|
||||
capturedOnChatModelEnd = (event, data) => {
|
||||
customHandlers.on_chat_model_end.handle(event, data);
|
||||
};
|
||||
|
||||
return {
|
||||
processStream: jest.fn().mockImplementation(async () => {
|
||||
customHandlers.on_chat_model_end.handle('on_chat_model_end', {
|
||||
output: {
|
||||
usage_metadata: {
|
||||
input_tokens: 150,
|
||||
output_tokens: 75,
|
||||
model: 'claude-3',
|
||||
},
|
||||
},
|
||||
});
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
await createResponse(req, res);
|
||||
|
||||
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
collectedUsage: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
input_tokens: 150,
|
||||
output_tokens: 75,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -16,16 +16,20 @@ const {
|
|||
createSafeUser,
|
||||
validateRequest,
|
||||
initializeAgent,
|
||||
getBalanceConfig,
|
||||
createErrorResponse,
|
||||
recordCollectedUsage,
|
||||
getTransactionsConfig,
|
||||
createToolExecuteHandler,
|
||||
buildNonStreamingResponse,
|
||||
createOpenAIStreamTracker,
|
||||
createOpenAIContentAggregator,
|
||||
createToolExecuteHandler,
|
||||
isChatCompletionValidationFailure,
|
||||
} = require('@librechat/api');
|
||||
const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService');
|
||||
const { createToolEndCallback } = require('~/server/controllers/agents/callbacks');
|
||||
const { findAccessibleResources } = require('~/server/services/PermissionService');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { getAgent, getAgents } = require('~/models/Agent');
|
||||
const db = require('~/models');
|
||||
|
|
@ -497,6 +501,24 @@ const OpenAIChatCompletionController = async (req, res) => {
|
|||
},
|
||||
});
|
||||
|
||||
// Record token usage against balance
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
const transactionsConfig = getTransactionsConfig(appConfig);
|
||||
recordCollectedUsage(
|
||||
{ spendTokens, spendStructuredTokens },
|
||||
{
|
||||
user: userId,
|
||||
conversationId,
|
||||
collectedUsage,
|
||||
context: 'message',
|
||||
balance: balanceConfig,
|
||||
transactions: transactionsConfig,
|
||||
model: primaryConfig.model || agent.model_parameters?.model,
|
||||
},
|
||||
).catch((err) => {
|
||||
logger.error('[OpenAI API] Error recording usage:', err);
|
||||
});
|
||||
|
||||
// Finalize response
|
||||
const duration = Date.now() - requestStartTime;
|
||||
if (isStreaming) {
|
||||
|
|
|
|||
|
|
@ -13,6 +13,9 @@ const {
|
|||
buildToolSet,
|
||||
createSafeUser,
|
||||
initializeAgent,
|
||||
getBalanceConfig,
|
||||
recordCollectedUsage,
|
||||
getTransactionsConfig,
|
||||
createToolExecuteHandler,
|
||||
// Responses API
|
||||
writeDone,
|
||||
|
|
@ -39,6 +42,7 @@ const {
|
|||
const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService');
|
||||
const { findAccessibleResources } = require('~/server/services/PermissionService');
|
||||
const { getConvoFiles, saveConvo, getConvo } = require('~/models/Conversation');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { getAgent, getAgents } = require('~/models/Agent');
|
||||
const db = require('~/models');
|
||||
|
||||
|
|
@ -403,6 +407,9 @@ const createResponse = async (req, res) => {
|
|||
const { handlers: responsesHandlers, finalizeStream } =
|
||||
createResponsesEventHandlers(handlerConfig);
|
||||
|
||||
// Collect usage for balance tracking
|
||||
const collectedUsage = [];
|
||||
|
||||
// Built-in handler for processing raw model stream chunks
|
||||
const chatModelStreamHandler = new ChatModelStreamHandler();
|
||||
|
||||
|
|
@ -445,7 +452,15 @@ const createResponse = async (req, res) => {
|
|||
on_reasoning_delta: responsesHandlers.on_reasoning_delta,
|
||||
on_run_step: responsesHandlers.on_run_step,
|
||||
on_run_step_delta: responsesHandlers.on_run_step_delta,
|
||||
on_chat_model_end: responsesHandlers.on_chat_model_end,
|
||||
on_chat_model_end: {
|
||||
handle: (event, data) => {
|
||||
responsesHandlers.on_chat_model_end.handle(event, data);
|
||||
const usage = data?.output?.usage_metadata;
|
||||
if (usage) {
|
||||
collectedUsage.push(usage);
|
||||
}
|
||||
},
|
||||
},
|
||||
on_tool_end: new ToolEndHandler(toolEndCallback, logger),
|
||||
on_run_step_completed: { handle: () => {} },
|
||||
on_chain_stream: { handle: () => {} },
|
||||
|
|
@ -499,6 +514,24 @@ const createResponse = async (req, res) => {
|
|||
},
|
||||
});
|
||||
|
||||
// Record token usage against balance
|
||||
const balanceConfig = getBalanceConfig(req.config);
|
||||
const transactionsConfig = getTransactionsConfig(req.config);
|
||||
recordCollectedUsage(
|
||||
{ spendTokens, spendStructuredTokens },
|
||||
{
|
||||
user: userId,
|
||||
conversationId,
|
||||
collectedUsage,
|
||||
context: 'message',
|
||||
balance: balanceConfig,
|
||||
transactions: transactionsConfig,
|
||||
model: primaryConfig.model || agent.model_parameters?.model,
|
||||
},
|
||||
).catch((err) => {
|
||||
logger.error('[Responses API] Error recording usage:', err);
|
||||
});
|
||||
|
||||
// Finalize the stream
|
||||
finalizeStream();
|
||||
res.end();
|
||||
|
|
@ -539,6 +572,9 @@ const createResponse = async (req, res) => {
|
|||
|
||||
const chatModelStreamHandler = new ChatModelStreamHandler();
|
||||
|
||||
// Collect usage for balance tracking
|
||||
const collectedUsage = [];
|
||||
|
||||
/** @type {Promise<import('librechat-data-provider').TAttachment | null>[]} */
|
||||
const artifactPromises = [];
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises, streamId: null });
|
||||
|
|
@ -569,7 +605,15 @@ const createResponse = async (req, res) => {
|
|||
on_reasoning_delta: aggregatorHandlers.on_reasoning_delta,
|
||||
on_run_step: aggregatorHandlers.on_run_step,
|
||||
on_run_step_delta: aggregatorHandlers.on_run_step_delta,
|
||||
on_chat_model_end: aggregatorHandlers.on_chat_model_end,
|
||||
on_chat_model_end: {
|
||||
handle: (event, data) => {
|
||||
aggregatorHandlers.on_chat_model_end.handle(event, data);
|
||||
const usage = data?.output?.usage_metadata;
|
||||
if (usage) {
|
||||
collectedUsage.push(usage);
|
||||
}
|
||||
},
|
||||
},
|
||||
on_tool_end: new ToolEndHandler(toolEndCallback, logger),
|
||||
on_run_step_completed: { handle: () => {} },
|
||||
on_chain_stream: { handle: () => {} },
|
||||
|
|
@ -621,6 +665,24 @@ const createResponse = async (req, res) => {
|
|||
},
|
||||
});
|
||||
|
||||
// Record token usage against balance
|
||||
const balanceConfig = getBalanceConfig(req.config);
|
||||
const transactionsConfig = getTransactionsConfig(req.config);
|
||||
recordCollectedUsage(
|
||||
{ spendTokens, spendStructuredTokens },
|
||||
{
|
||||
user: userId,
|
||||
conversationId,
|
||||
collectedUsage,
|
||||
context: 'message',
|
||||
balance: balanceConfig,
|
||||
transactions: transactionsConfig,
|
||||
model: primaryConfig.model || agent.model_parameters?.model,
|
||||
},
|
||||
).catch((err) => {
|
||||
logger.error('[Responses API] Error recording usage:', err);
|
||||
});
|
||||
|
||||
if (artifactPromises.length > 0) {
|
||||
try {
|
||||
await Promise.all(artifactPromises);
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ export * from './legacy';
|
|||
export * from './memory';
|
||||
export * from './migration';
|
||||
export * from './openai';
|
||||
export * from './usage';
|
||||
export * from './resources';
|
||||
export * from './responses';
|
||||
export * from './run';
|
||||
|
|
|
|||
434
packages/api/src/agents/usage.spec.ts
Normal file
434
packages/api/src/agents/usage.spec.ts
Normal file
|
|
@ -0,0 +1,434 @@
|
|||
import { recordCollectedUsage } from './usage';
|
||||
import type { RecordUsageDeps, RecordUsageParams } from './usage';
|
||||
import type { UsageMetadata } from '../stream/interfaces/IJobStore';
|
||||
|
||||
describe('recordCollectedUsage', () => {
|
||||
let mockSpendTokens: jest.Mock;
|
||||
let mockSpendStructuredTokens: jest.Mock;
|
||||
let deps: RecordUsageDeps;
|
||||
|
||||
const baseParams: Omit<RecordUsageParams, 'collectedUsage'> = {
|
||||
user: 'user-123',
|
||||
conversationId: 'convo-123',
|
||||
model: 'gpt-4',
|
||||
context: 'message',
|
||||
balance: { enabled: true },
|
||||
transactions: { enabled: true },
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockSpendTokens = jest.fn().mockResolvedValue(undefined);
|
||||
mockSpendStructuredTokens = jest.fn().mockResolvedValue(undefined);
|
||||
deps = {
|
||||
spendTokens: mockSpendTokens,
|
||||
spendStructuredTokens: mockSpendStructuredTokens,
|
||||
};
|
||||
});
|
||||
|
||||
describe('basic functionality', () => {
|
||||
it('should return undefined if collectedUsage is empty', async () => {
|
||||
const result = await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
collectedUsage: [],
|
||||
});
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
expect(mockSpendTokens).not.toHaveBeenCalled();
|
||||
expect(mockSpendStructuredTokens).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return undefined if collectedUsage is null-ish', async () => {
|
||||
const result = await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
collectedUsage: null as unknown as UsageMetadata[],
|
||||
});
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
expect(mockSpendTokens).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle single usage entry correctly', async () => {
|
||||
const collectedUsage: UsageMetadata[] = [
|
||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
||||
];
|
||||
|
||||
const result = await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
expect(mockSpendTokens).toHaveBeenCalledTimes(1);
|
||||
expect(mockSpendTokens).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
user: 'user-123',
|
||||
conversationId: 'convo-123',
|
||||
model: 'gpt-4',
|
||||
context: 'message',
|
||||
}),
|
||||
{ promptTokens: 100, completionTokens: 50 },
|
||||
);
|
||||
expect(result).toEqual({ input_tokens: 100, output_tokens: 50 });
|
||||
});
|
||||
|
||||
it('should skip null entries in collectedUsage', async () => {
|
||||
const collectedUsage = [
|
||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
||||
null,
|
||||
{ input_tokens: 200, output_tokens: 60, model: 'gpt-4' },
|
||||
] as UsageMetadata[];
|
||||
|
||||
const result = await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
expect(mockSpendTokens).toHaveBeenCalledTimes(2);
|
||||
expect(result).toEqual({ input_tokens: 100, output_tokens: 110 });
|
||||
});
|
||||
});
|
||||
|
||||
describe('sequential execution (tool calls)', () => {
|
||||
it('should calculate tokens correctly for sequential tool calls', async () => {
|
||||
const collectedUsage: UsageMetadata[] = [
|
||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
||||
{ input_tokens: 150, output_tokens: 30, model: 'gpt-4' },
|
||||
{ input_tokens: 180, output_tokens: 20, model: 'gpt-4' },
|
||||
];
|
||||
|
||||
const result = await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
expect(mockSpendTokens).toHaveBeenCalledTimes(3);
|
||||
expect(result?.output_tokens).toBe(100); // 50 + 30 + 20
|
||||
expect(result?.input_tokens).toBe(100); // First entry's input
|
||||
});
|
||||
});
|
||||
|
||||
describe('parallel execution (multiple agents)', () => {
|
||||
it('should handle parallel agents with independent input tokens', async () => {
|
||||
const collectedUsage: UsageMetadata[] = [
|
||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
||||
{ input_tokens: 80, output_tokens: 40, model: 'gpt-4' },
|
||||
];
|
||||
|
||||
const result = await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
expect(mockSpendTokens).toHaveBeenCalledTimes(2);
|
||||
expect(result?.output_tokens).toBe(90); // 50 + 40
|
||||
expect(result?.output_tokens).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('should NOT produce negative output_tokens for parallel execution', async () => {
|
||||
const collectedUsage: UsageMetadata[] = [
|
||||
{ input_tokens: 200, output_tokens: 100, model: 'gpt-4' },
|
||||
{ input_tokens: 50, output_tokens: 30, model: 'gpt-4' },
|
||||
];
|
||||
|
||||
const result = await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
expect(result?.output_tokens).toBeGreaterThan(0);
|
||||
expect(result?.output_tokens).toBe(130); // 100 + 30
|
||||
});
|
||||
|
||||
it('should calculate correct total output for multiple parallel agents', async () => {
|
||||
const collectedUsage: UsageMetadata[] = [
|
||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
||||
{ input_tokens: 120, output_tokens: 60, model: 'gpt-4-turbo' },
|
||||
{ input_tokens: 80, output_tokens: 40, model: 'claude-3' },
|
||||
];
|
||||
|
||||
const result = await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
expect(mockSpendTokens).toHaveBeenCalledTimes(3);
|
||||
expect(result?.output_tokens).toBe(150); // 50 + 60 + 40
|
||||
});
|
||||
});
|
||||
|
||||
describe('cache token handling - OpenAI format', () => {
|
||||
it('should use spendStructuredTokens for cache tokens (input_token_details)', async () => {
|
||||
const collectedUsage: UsageMetadata[] = [
|
||||
{
|
||||
input_tokens: 100,
|
||||
output_tokens: 50,
|
||||
model: 'gpt-4',
|
||||
input_token_details: {
|
||||
cache_creation: 20,
|
||||
cache_read: 10,
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const result = await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1);
|
||||
expect(mockSpendTokens).not.toHaveBeenCalled();
|
||||
expect(mockSpendStructuredTokens).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ model: 'gpt-4' }),
|
||||
{
|
||||
promptTokens: { input: 100, write: 20, read: 10 },
|
||||
completionTokens: 50,
|
||||
},
|
||||
);
|
||||
expect(result?.input_tokens).toBe(130); // 100 + 20 + 10
|
||||
});
|
||||
});
|
||||
|
||||
describe('cache token handling - Anthropic format', () => {
|
||||
it('should use spendStructuredTokens for cache tokens (cache_*_input_tokens)', async () => {
|
||||
const collectedUsage: UsageMetadata[] = [
|
||||
{
|
||||
input_tokens: 100,
|
||||
output_tokens: 50,
|
||||
model: 'claude-3',
|
||||
cache_creation_input_tokens: 25,
|
||||
cache_read_input_tokens: 15,
|
||||
},
|
||||
];
|
||||
|
||||
const result = await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1);
|
||||
expect(mockSpendTokens).not.toHaveBeenCalled();
|
||||
expect(mockSpendStructuredTokens).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ model: 'claude-3' }),
|
||||
{
|
||||
promptTokens: { input: 100, write: 25, read: 15 },
|
||||
completionTokens: 50,
|
||||
},
|
||||
);
|
||||
expect(result?.input_tokens).toBe(140); // 100 + 25 + 15
|
||||
});
|
||||
});
|
||||
|
||||
describe('mixed cache and non-cache entries', () => {
|
||||
it('should handle mixed entries correctly', async () => {
|
||||
const collectedUsage: UsageMetadata[] = [
|
||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
||||
{
|
||||
input_tokens: 150,
|
||||
output_tokens: 30,
|
||||
model: 'gpt-4',
|
||||
input_token_details: { cache_creation: 10, cache_read: 5 },
|
||||
},
|
||||
{ input_tokens: 200, output_tokens: 20, model: 'gpt-4' },
|
||||
];
|
||||
|
||||
const result = await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
expect(mockSpendTokens).toHaveBeenCalledTimes(2);
|
||||
expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1);
|
||||
expect(result?.output_tokens).toBe(100); // 50 + 30 + 20
|
||||
});
|
||||
});
|
||||
|
||||
describe('model fallback', () => {
|
||||
it('should use usage.model when available', async () => {
|
||||
const collectedUsage: UsageMetadata[] = [
|
||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4-turbo' },
|
||||
];
|
||||
|
||||
await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
model: 'fallback-model',
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
expect(mockSpendTokens).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ model: 'gpt-4-turbo' }),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should fallback to param model when usage.model is missing', async () => {
|
||||
const collectedUsage: UsageMetadata[] = [{ input_tokens: 100, output_tokens: 50 }];
|
||||
|
||||
await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
model: 'param-model',
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
expect(mockSpendTokens).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ model: 'param-model' }),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('real-world scenarios', () => {
|
||||
it('should correctly sum output tokens for sequential tool calls with growing context', async () => {
|
||||
const collectedUsage: UsageMetadata[] = [
|
||||
{ input_tokens: 31596, output_tokens: 151, model: 'claude-opus' },
|
||||
{ input_tokens: 35368, output_tokens: 150, model: 'claude-opus' },
|
||||
{ input_tokens: 58362, output_tokens: 295, model: 'claude-opus' },
|
||||
{ input_tokens: 112604, output_tokens: 193, model: 'claude-opus' },
|
||||
{ input_tokens: 257440, output_tokens: 2217, model: 'claude-opus' },
|
||||
];
|
||||
|
||||
const result = await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
expect(result?.input_tokens).toBe(31596);
|
||||
expect(result?.output_tokens).toBe(3006); // 151 + 150 + 295 + 193 + 2217
|
||||
expect(mockSpendTokens).toHaveBeenCalledTimes(5);
|
||||
});
|
||||
|
||||
it('should handle cache tokens with multiple tool calls', async () => {
|
||||
const collectedUsage: UsageMetadata[] = [
|
||||
{
|
||||
input_tokens: 788,
|
||||
output_tokens: 163,
|
||||
model: 'claude-opus',
|
||||
input_token_details: { cache_read: 0, cache_creation: 30808 },
|
||||
},
|
||||
{
|
||||
input_tokens: 3802,
|
||||
output_tokens: 149,
|
||||
model: 'claude-opus',
|
||||
input_token_details: { cache_read: 30808, cache_creation: 768 },
|
||||
},
|
||||
{
|
||||
input_tokens: 26808,
|
||||
output_tokens: 225,
|
||||
model: 'claude-opus',
|
||||
input_token_details: { cache_read: 31576, cache_creation: 0 },
|
||||
},
|
||||
];
|
||||
|
||||
const result = await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
// input_tokens = 788 + 30808 + 0 = 31596
|
||||
expect(result?.input_tokens).toBe(31596);
|
||||
// output_tokens = 163 + 149 + 225 = 537
|
||||
expect(result?.output_tokens).toBe(537);
|
||||
expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(3);
|
||||
expect(mockSpendTokens).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('error handling', () => {
|
||||
it('should catch and log errors from spendTokens without throwing', async () => {
|
||||
mockSpendTokens.mockRejectedValue(new Error('DB error'));
|
||||
|
||||
const collectedUsage: UsageMetadata[] = [
|
||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
||||
];
|
||||
|
||||
const result = await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
expect(result).toEqual({ input_tokens: 100, output_tokens: 50 });
|
||||
});
|
||||
|
||||
it('should catch and log errors from spendStructuredTokens without throwing', async () => {
|
||||
mockSpendStructuredTokens.mockRejectedValue(new Error('DB error'));
|
||||
|
||||
const collectedUsage: UsageMetadata[] = [
|
||||
{
|
||||
input_tokens: 100,
|
||||
output_tokens: 50,
|
||||
model: 'gpt-4',
|
||||
input_token_details: { cache_creation: 20, cache_read: 10 },
|
||||
},
|
||||
];
|
||||
|
||||
const result = await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
expect(result).toEqual({ input_tokens: 130, output_tokens: 50 });
|
||||
});
|
||||
});
|
||||
|
||||
describe('transaction metadata', () => {
|
||||
it('should pass all metadata fields to spend functions', async () => {
|
||||
const collectedUsage: UsageMetadata[] = [
|
||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
||||
];
|
||||
|
||||
const endpointTokenConfig = { 'gpt-4': { prompt: 0.01, completion: 0.03, context: 8192 } };
|
||||
|
||||
await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
endpointTokenConfig,
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
expect(mockSpendTokens).toHaveBeenCalledWith(
|
||||
{
|
||||
user: 'user-123',
|
||||
conversationId: 'convo-123',
|
||||
model: 'gpt-4',
|
||||
context: 'message',
|
||||
balance: { enabled: true },
|
||||
transactions: { enabled: true },
|
||||
endpointTokenConfig,
|
||||
},
|
||||
{ promptTokens: 100, completionTokens: 50 },
|
||||
);
|
||||
});
|
||||
|
||||
it('should use default context "message" when not provided', async () => {
|
||||
const collectedUsage: UsageMetadata[] = [
|
||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
||||
];
|
||||
|
||||
await recordCollectedUsage(deps, {
|
||||
user: 'user-123',
|
||||
conversationId: 'convo-123',
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
expect(mockSpendTokens).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ context: 'message' }),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow custom context like "title"', async () => {
|
||||
const collectedUsage: UsageMetadata[] = [
|
||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
||||
];
|
||||
|
||||
await recordCollectedUsage(deps, {
|
||||
...baseParams,
|
||||
context: 'title',
|
||||
collectedUsage,
|
||||
});
|
||||
|
||||
expect(mockSpendTokens).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ context: 'title' }),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
146
packages/api/src/agents/usage.ts
Normal file
146
packages/api/src/agents/usage.ts
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import type { TCustomConfig, TTransactionsConfig } from 'librechat-data-provider';
|
||||
import type { UsageMetadata } from '../stream/interfaces/IJobStore';
|
||||
import type { EndpointTokenConfig } from '../types/tokens';
|
||||
|
||||
interface TokenUsage {
|
||||
promptTokens?: number;
|
||||
completionTokens?: number;
|
||||
}
|
||||
|
||||
interface StructuredPromptTokens {
|
||||
input?: number;
|
||||
write?: number;
|
||||
read?: number;
|
||||
}
|
||||
|
||||
interface StructuredTokenUsage {
|
||||
promptTokens?: StructuredPromptTokens;
|
||||
completionTokens?: number;
|
||||
}
|
||||
|
||||
interface TxMetadata {
|
||||
user: string;
|
||||
model?: string;
|
||||
context: string;
|
||||
conversationId: string;
|
||||
balance?: Partial<TCustomConfig['balance']> | null;
|
||||
transactions?: Partial<TTransactionsConfig>;
|
||||
endpointTokenConfig?: EndpointTokenConfig;
|
||||
}
|
||||
|
||||
type SpendTokensFn = (txData: TxMetadata, tokenUsage: TokenUsage) => Promise<unknown>;
|
||||
type SpendStructuredTokensFn = (
|
||||
txData: TxMetadata,
|
||||
tokenUsage: StructuredTokenUsage,
|
||||
) => Promise<unknown>;
|
||||
|
||||
export interface RecordUsageDeps {
|
||||
spendTokens: SpendTokensFn;
|
||||
spendStructuredTokens: SpendStructuredTokensFn;
|
||||
}
|
||||
|
||||
export interface RecordUsageParams {
|
||||
user: string;
|
||||
conversationId: string;
|
||||
collectedUsage: UsageMetadata[];
|
||||
model?: string;
|
||||
context?: string;
|
||||
balance?: Partial<TCustomConfig['balance']> | null;
|
||||
transactions?: Partial<TTransactionsConfig>;
|
||||
endpointTokenConfig?: EndpointTokenConfig;
|
||||
}
|
||||
|
||||
export interface RecordUsageResult {
|
||||
input_tokens: number;
|
||||
output_tokens: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Records token usage for collected LLM calls and spends tokens against balance.
|
||||
* This handles both sequential execution (tool calls) and parallel execution (multiple agents).
|
||||
*/
|
||||
export async function recordCollectedUsage(
|
||||
deps: RecordUsageDeps,
|
||||
params: RecordUsageParams,
|
||||
): Promise<RecordUsageResult | undefined> {
|
||||
const {
|
||||
user,
|
||||
model,
|
||||
balance,
|
||||
transactions,
|
||||
conversationId,
|
||||
collectedUsage,
|
||||
endpointTokenConfig,
|
||||
context = 'message',
|
||||
} = params;
|
||||
|
||||
const { spendTokens, spendStructuredTokens } = deps;
|
||||
|
||||
if (!collectedUsage || !collectedUsage.length) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstUsage = collectedUsage[0];
|
||||
const input_tokens =
|
||||
(firstUsage?.input_tokens || 0) +
|
||||
(Number(firstUsage?.input_token_details?.cache_creation) ||
|
||||
Number(firstUsage?.cache_creation_input_tokens) ||
|
||||
0) +
|
||||
(Number(firstUsage?.input_token_details?.cache_read) ||
|
||||
Number(firstUsage?.cache_read_input_tokens) ||
|
||||
0);
|
||||
|
||||
let total_output_tokens = 0;
|
||||
|
||||
for (const usage of collectedUsage) {
|
||||
if (!usage) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const cache_creation =
|
||||
Number(usage.input_token_details?.cache_creation) ||
|
||||
Number(usage.cache_creation_input_tokens) ||
|
||||
0;
|
||||
const cache_read =
|
||||
Number(usage.input_token_details?.cache_read) || Number(usage.cache_read_input_tokens) || 0;
|
||||
|
||||
total_output_tokens += Number(usage.output_tokens) || 0;
|
||||
|
||||
const txMetadata: TxMetadata = {
|
||||
context,
|
||||
balance,
|
||||
transactions,
|
||||
conversationId,
|
||||
user,
|
||||
endpointTokenConfig,
|
||||
model: usage.model ?? model,
|
||||
};
|
||||
|
||||
if (cache_creation > 0 || cache_read > 0) {
|
||||
spendStructuredTokens(txMetadata, {
|
||||
promptTokens: {
|
||||
input: usage.input_tokens,
|
||||
write: cache_creation,
|
||||
read: cache_read,
|
||||
},
|
||||
completionTokens: usage.output_tokens,
|
||||
}).catch((err) => {
|
||||
logger.error('[packages/api #recordCollectedUsage] Error spending structured tokens', err);
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
spendTokens(txMetadata, {
|
||||
promptTokens: usage.input_tokens,
|
||||
completionTokens: usage.output_tokens,
|
||||
}).catch((err) => {
|
||||
logger.error('[packages/api #recordCollectedUsage] Error spending tokens', err);
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
};
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue