diff --git a/api/server/controllers/agents/__tests__/openai.spec.js b/api/server/controllers/agents/__tests__/openai.spec.js new file mode 100644 index 0000000000..03a280b545 --- /dev/null +++ b/api/server/controllers/agents/__tests__/openai.spec.js @@ -0,0 +1,207 @@ +/** + * Unit tests for OpenAI-compatible API controller + * Tests that recordCollectedUsage is called correctly for token spending + */ + +const mockSpendTokens = jest.fn().mockResolvedValue({}); +const mockSpendStructuredTokens = jest.fn().mockResolvedValue({}); +const mockRecordCollectedUsage = jest + .fn() + .mockResolvedValue({ input_tokens: 100, output_tokens: 50 }); +const mockGetBalanceConfig = jest.fn().mockReturnValue({ enabled: true }); +const mockGetTransactionsConfig = jest.fn().mockReturnValue({ enabled: true }); + +jest.mock('nanoid', () => ({ + nanoid: jest.fn(() => 'mock-nanoid-123'), +})); + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + debug: jest.fn(), + error: jest.fn(), + warn: jest.fn(), + }, +})); + +jest.mock('@librechat/agents', () => ({ + Callback: { TOOL_ERROR: 'TOOL_ERROR' }, + ToolEndHandler: jest.fn(), + formatAgentMessages: jest.fn().mockReturnValue({ + messages: [], + indexTokenCountMap: {}, + }), + ChatModelStreamHandler: jest.fn().mockImplementation(() => ({ + handle: jest.fn(), + })), +})); + +jest.mock('@librechat/api', () => ({ + writeSSE: jest.fn(), + createRun: jest.fn().mockResolvedValue({ + processStream: jest.fn().mockResolvedValue(undefined), + }), + createChunk: jest.fn().mockReturnValue({}), + buildToolSet: jest.fn().mockReturnValue(new Set()), + sendFinalChunk: jest.fn(), + createSafeUser: jest.fn().mockReturnValue({ id: 'user-123' }), + validateRequest: jest + .fn() + .mockReturnValue({ request: { model: 'agent-123', messages: [], stream: false } }), + initializeAgent: jest.fn().mockResolvedValue({ + model: 'gpt-4', + model_parameters: {}, + toolRegistry: {}, + }), + getBalanceConfig: mockGetBalanceConfig, + createErrorResponse: jest.fn(), + getTransactionsConfig: mockGetTransactionsConfig, + recordCollectedUsage: mockRecordCollectedUsage, + buildNonStreamingResponse: jest.fn().mockReturnValue({ id: 'resp-123' }), + createOpenAIStreamTracker: jest.fn().mockReturnValue({ + addText: jest.fn(), + addReasoning: jest.fn(), + toolCalls: new Map(), + usage: { promptTokens: 0, completionTokens: 0, reasoningTokens: 0 }, + }), + createOpenAIContentAggregator: jest.fn().mockReturnValue({ + addText: jest.fn(), + addReasoning: jest.fn(), + getText: jest.fn().mockReturnValue(''), + getReasoning: jest.fn().mockReturnValue(''), + toolCalls: new Map(), + usage: { promptTokens: 100, completionTokens: 50, reasoningTokens: 0 }, + }), + createToolExecuteHandler: jest.fn().mockReturnValue({ handle: jest.fn() }), + isChatCompletionValidationFailure: jest.fn().mockReturnValue(false), +})); + +jest.mock('~/server/services/ToolService', () => ({ + loadAgentTools: jest.fn().mockResolvedValue([]), + loadToolsForExecution: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/models/spendTokens', () => ({ + spendTokens: mockSpendTokens, + spendStructuredTokens: mockSpendStructuredTokens, +})); + +jest.mock('~/server/controllers/agents/callbacks', () => ({ + createToolEndCallback: jest.fn().mockReturnValue(jest.fn()), +})); + +jest.mock('~/server/services/PermissionService', () => ({ + findAccessibleResources: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/models/Conversation', () => ({ + getConvoFiles: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/models/Agent', () => ({ + getAgent: jest.fn().mockResolvedValue({ + id: 'agent-123', + provider: 'openAI', + model_parameters: { model: 'gpt-4' }, + }), + getAgents: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/models', () => ({ + getFiles: jest.fn(), + getUserKey: jest.fn(), + getMessages: jest.fn(), + updateFilesUsage: jest.fn(), + getUserKeyValues: jest.fn(), + getUserCodeFiles: jest.fn(), + getToolFilesByIds: jest.fn(), + getCodeGeneratedFiles: jest.fn(), +})); + +describe('OpenAIChatCompletionController', () => { + let OpenAIChatCompletionController; + let req, res; + + beforeEach(() => { + jest.clearAllMocks(); + + const controller = require('../openai'); + OpenAIChatCompletionController = controller.OpenAIChatCompletionController; + + req = { + body: { + model: 'agent-123', + messages: [{ role: 'user', content: 'Hello' }], + stream: false, + }, + user: { id: 'user-123' }, + config: { + endpoints: { + agents: { allowedProviders: ['openAI'] }, + }, + }, + on: jest.fn(), + }; + + res = { + status: jest.fn().mockReturnThis(), + json: jest.fn(), + setHeader: jest.fn(), + flushHeaders: jest.fn(), + end: jest.fn(), + write: jest.fn(), + }; + }); + + describe('token usage recording', () => { + it('should call recordCollectedUsage after successful non-streaming completion', async () => { + await OpenAIChatCompletionController(req, res); + + expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + { spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens }, + expect.objectContaining({ + user: 'user-123', + conversationId: expect.any(String), + collectedUsage: expect.any(Array), + context: 'message', + balance: { enabled: true }, + transactions: { enabled: true }, + }), + ); + }); + + it('should pass balance and transactions config to recordCollectedUsage', async () => { + mockGetBalanceConfig.mockReturnValue({ enabled: true, startBalance: 1000 }); + mockGetTransactionsConfig.mockReturnValue({ enabled: true, rateLimit: 100 }); + + await OpenAIChatCompletionController(req, res); + + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + balance: { enabled: true, startBalance: 1000 }, + transactions: { enabled: true, rateLimit: 100 }, + }), + ); + }); + + it('should pass spendTokens and spendStructuredTokens as dependencies', async () => { + await OpenAIChatCompletionController(req, res); + + const [deps] = mockRecordCollectedUsage.mock.calls[0]; + expect(deps).toHaveProperty('spendTokens', mockSpendTokens); + expect(deps).toHaveProperty('spendStructuredTokens', mockSpendStructuredTokens); + }); + + it('should include model from primaryConfig in recordCollectedUsage params', async () => { + await OpenAIChatCompletionController(req, res); + + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + model: 'gpt-4', + }), + ); + }); + }); +}); diff --git a/api/server/controllers/agents/__tests__/responses.unit.spec.js b/api/server/controllers/agents/__tests__/responses.unit.spec.js new file mode 100644 index 0000000000..25e048f2fa --- /dev/null +++ b/api/server/controllers/agents/__tests__/responses.unit.spec.js @@ -0,0 +1,315 @@ +/** + * Unit tests for Open Responses API controller + * Tests that recordCollectedUsage is called correctly for token spending + */ + +const mockSpendTokens = jest.fn().mockResolvedValue({}); +const mockSpendStructuredTokens = jest.fn().mockResolvedValue({}); +const mockRecordCollectedUsage = jest + .fn() + .mockResolvedValue({ input_tokens: 100, output_tokens: 50 }); +const mockGetBalanceConfig = jest.fn().mockReturnValue({ enabled: true }); +const mockGetTransactionsConfig = jest.fn().mockReturnValue({ enabled: true }); + +jest.mock('nanoid', () => ({ + nanoid: jest.fn(() => 'mock-nanoid-123'), +})); + +jest.mock('uuid', () => ({ + v4: jest.fn(() => 'mock-uuid-456'), +})); + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + debug: jest.fn(), + error: jest.fn(), + warn: jest.fn(), + }, +})); + +jest.mock('@librechat/agents', () => ({ + Callback: { TOOL_ERROR: 'TOOL_ERROR' }, + ToolEndHandler: jest.fn(), + formatAgentMessages: jest.fn().mockReturnValue({ + messages: [], + indexTokenCountMap: {}, + }), + ChatModelStreamHandler: jest.fn().mockImplementation(() => ({ + handle: jest.fn(), + })), +})); + +jest.mock('@librechat/api', () => ({ + createRun: jest.fn().mockResolvedValue({ + processStream: jest.fn().mockResolvedValue(undefined), + }), + buildToolSet: jest.fn().mockReturnValue(new Set()), + createSafeUser: jest.fn().mockReturnValue({ id: 'user-123' }), + initializeAgent: jest.fn().mockResolvedValue({ + model: 'claude-3', + model_parameters: {}, + toolRegistry: {}, + }), + getBalanceConfig: mockGetBalanceConfig, + getTransactionsConfig: mockGetTransactionsConfig, + recordCollectedUsage: mockRecordCollectedUsage, + createToolExecuteHandler: jest.fn().mockReturnValue({ handle: jest.fn() }), + // Responses API + writeDone: jest.fn(), + buildResponse: jest.fn().mockReturnValue({ id: 'resp_123', output: [] }), + generateResponseId: jest.fn().mockReturnValue('resp_mock-123'), + isValidationFailure: jest.fn().mockReturnValue(false), + emitResponseCreated: jest.fn(), + createResponseContext: jest.fn().mockReturnValue({ responseId: 'resp_123' }), + createResponseTracker: jest.fn().mockReturnValue({ + usage: { promptTokens: 100, completionTokens: 50 }, + }), + setupStreamingResponse: jest.fn(), + emitResponseInProgress: jest.fn(), + convertInputToMessages: jest.fn().mockReturnValue([]), + validateResponseRequest: jest.fn().mockReturnValue({ + request: { model: 'agent-123', input: 'Hello', stream: false }, + }), + buildAggregatedResponse: jest.fn().mockReturnValue({ + id: 'resp_123', + status: 'completed', + output: [], + usage: { input_tokens: 100, output_tokens: 50, total_tokens: 150 }, + }), + createResponseAggregator: jest.fn().mockReturnValue({ + usage: { promptTokens: 100, completionTokens: 50 }, + }), + sendResponsesErrorResponse: jest.fn(), + createResponsesEventHandlers: jest.fn().mockReturnValue({ + handlers: { + on_message_delta: { handle: jest.fn() }, + on_reasoning_delta: { handle: jest.fn() }, + on_run_step: { handle: jest.fn() }, + on_run_step_delta: { handle: jest.fn() }, + on_chat_model_end: { handle: jest.fn() }, + }, + finalizeStream: jest.fn(), + }), + createAggregatorEventHandlers: jest.fn().mockReturnValue({ + on_message_delta: { handle: jest.fn() }, + on_reasoning_delta: { handle: jest.fn() }, + on_run_step: { handle: jest.fn() }, + on_run_step_delta: { handle: jest.fn() }, + on_chat_model_end: { handle: jest.fn() }, + }), +})); + +jest.mock('~/server/services/ToolService', () => ({ + loadAgentTools: jest.fn().mockResolvedValue([]), + loadToolsForExecution: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/models/spendTokens', () => ({ + spendTokens: mockSpendTokens, + spendStructuredTokens: mockSpendStructuredTokens, +})); + +jest.mock('~/server/controllers/agents/callbacks', () => ({ + createToolEndCallback: jest.fn().mockReturnValue(jest.fn()), + createResponsesToolEndCallback: jest.fn().mockReturnValue(jest.fn()), +})); + +jest.mock('~/server/services/PermissionService', () => ({ + findAccessibleResources: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/models/Conversation', () => ({ + getConvoFiles: jest.fn().mockResolvedValue([]), + saveConvo: jest.fn().mockResolvedValue({}), + getConvo: jest.fn().mockResolvedValue(null), +})); + +jest.mock('~/models/Agent', () => ({ + getAgent: jest.fn().mockResolvedValue({ + id: 'agent-123', + name: 'Test Agent', + provider: 'anthropic', + model_parameters: { model: 'claude-3' }, + }), + getAgents: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/models', () => ({ + getFiles: jest.fn(), + getUserKey: jest.fn(), + getMessages: jest.fn().mockResolvedValue([]), + saveMessage: jest.fn().mockResolvedValue({}), + updateFilesUsage: jest.fn(), + getUserKeyValues: jest.fn(), + getUserCodeFiles: jest.fn(), + getToolFilesByIds: jest.fn(), + getCodeGeneratedFiles: jest.fn(), +})); + +describe('createResponse controller', () => { + let createResponse; + let req, res; + + beforeEach(() => { + jest.clearAllMocks(); + + const controller = require('../responses'); + createResponse = controller.createResponse; + + req = { + body: { + model: 'agent-123', + input: 'Hello', + stream: false, + }, + user: { id: 'user-123' }, + config: { + endpoints: { + agents: { allowedProviders: ['anthropic'] }, + }, + }, + on: jest.fn(), + }; + + res = { + status: jest.fn().mockReturnThis(), + json: jest.fn(), + setHeader: jest.fn(), + flushHeaders: jest.fn(), + end: jest.fn(), + write: jest.fn(), + }; + }); + + describe('token usage recording - non-streaming', () => { + it('should call recordCollectedUsage after successful non-streaming completion', async () => { + await createResponse(req, res); + + expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + { spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens }, + expect.objectContaining({ + user: 'user-123', + conversationId: expect.any(String), + collectedUsage: expect.any(Array), + context: 'message', + }), + ); + }); + + it('should pass balance and transactions config to recordCollectedUsage', async () => { + mockGetBalanceConfig.mockReturnValue({ enabled: true, startBalance: 2000 }); + mockGetTransactionsConfig.mockReturnValue({ enabled: true }); + + await createResponse(req, res); + + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + balance: { enabled: true, startBalance: 2000 }, + transactions: { enabled: true }, + }), + ); + }); + + it('should pass spendTokens and spendStructuredTokens as dependencies', async () => { + await createResponse(req, res); + + const [deps] = mockRecordCollectedUsage.mock.calls[0]; + expect(deps).toHaveProperty('spendTokens', mockSpendTokens); + expect(deps).toHaveProperty('spendStructuredTokens', mockSpendStructuredTokens); + }); + + it('should include model from primaryConfig in recordCollectedUsage params', async () => { + await createResponse(req, res); + + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + model: 'claude-3', + }), + ); + }); + }); + + describe('token usage recording - streaming', () => { + beforeEach(() => { + req.body.stream = true; + + const api = require('@librechat/api'); + api.validateResponseRequest.mockReturnValue({ + request: { model: 'agent-123', input: 'Hello', stream: true }, + }); + }); + + it('should call recordCollectedUsage after successful streaming completion', async () => { + await createResponse(req, res); + + expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + { spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens }, + expect.objectContaining({ + user: 'user-123', + context: 'message', + }), + ); + }); + }); + + describe('collectedUsage population', () => { + it('should collect usage from on_chat_model_end events', async () => { + const api = require('@librechat/api'); + + let capturedOnChatModelEnd; + api.createAggregatorEventHandlers.mockImplementation(() => { + return { + on_message_delta: { handle: jest.fn() }, + on_reasoning_delta: { handle: jest.fn() }, + on_run_step: { handle: jest.fn() }, + on_run_step_delta: { handle: jest.fn() }, + on_chat_model_end: { + handle: jest.fn((event, data) => { + if (capturedOnChatModelEnd) { + capturedOnChatModelEnd(event, data); + } + }), + }, + }; + }); + + api.createRun.mockImplementation(async ({ customHandlers }) => { + capturedOnChatModelEnd = (event, data) => { + customHandlers.on_chat_model_end.handle(event, data); + }; + + return { + processStream: jest.fn().mockImplementation(async () => { + customHandlers.on_chat_model_end.handle('on_chat_model_end', { + output: { + usage_metadata: { + input_tokens: 150, + output_tokens: 75, + model: 'claude-3', + }, + }, + }); + }), + }; + }); + + await createResponse(req, res); + + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + collectedUsage: expect.arrayContaining([ + expect.objectContaining({ + input_tokens: 150, + output_tokens: 75, + }), + ]), + }), + ); + }); + }); +}); diff --git a/api/server/controllers/agents/openai.js b/api/server/controllers/agents/openai.js index 605dc7c26a..d4dc82174d 100644 --- a/api/server/controllers/agents/openai.js +++ b/api/server/controllers/agents/openai.js @@ -16,16 +16,20 @@ const { createSafeUser, validateRequest, initializeAgent, + getBalanceConfig, createErrorResponse, + recordCollectedUsage, + getTransactionsConfig, + createToolExecuteHandler, buildNonStreamingResponse, createOpenAIStreamTracker, createOpenAIContentAggregator, - createToolExecuteHandler, isChatCompletionValidationFailure, } = require('@librechat/api'); const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService'); const { createToolEndCallback } = require('~/server/controllers/agents/callbacks'); const { findAccessibleResources } = require('~/server/services/PermissionService'); +const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { getConvoFiles } = require('~/models/Conversation'); const { getAgent, getAgents } = require('~/models/Agent'); const db = require('~/models'); @@ -497,6 +501,24 @@ const OpenAIChatCompletionController = async (req, res) => { }, }); + // Record token usage against balance + const balanceConfig = getBalanceConfig(appConfig); + const transactionsConfig = getTransactionsConfig(appConfig); + recordCollectedUsage( + { spendTokens, spendStructuredTokens }, + { + user: userId, + conversationId, + collectedUsage, + context: 'message', + balance: balanceConfig, + transactions: transactionsConfig, + model: primaryConfig.model || agent.model_parameters?.model, + }, + ).catch((err) => { + logger.error('[OpenAI API] Error recording usage:', err); + }); + // Finalize response const duration = Date.now() - requestStartTime; if (isStreaming) { diff --git a/api/server/controllers/agents/responses.js b/api/server/controllers/agents/responses.js index 06d1249ec6..3cd1dff5eb 100644 --- a/api/server/controllers/agents/responses.js +++ b/api/server/controllers/agents/responses.js @@ -13,6 +13,9 @@ const { buildToolSet, createSafeUser, initializeAgent, + getBalanceConfig, + recordCollectedUsage, + getTransactionsConfig, createToolExecuteHandler, // Responses API writeDone, @@ -39,6 +42,7 @@ const { const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService'); const { findAccessibleResources } = require('~/server/services/PermissionService'); const { getConvoFiles, saveConvo, getConvo } = require('~/models/Conversation'); +const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { getAgent, getAgents } = require('~/models/Agent'); const db = require('~/models'); @@ -403,6 +407,9 @@ const createResponse = async (req, res) => { const { handlers: responsesHandlers, finalizeStream } = createResponsesEventHandlers(handlerConfig); + // Collect usage for balance tracking + const collectedUsage = []; + // Built-in handler for processing raw model stream chunks const chatModelStreamHandler = new ChatModelStreamHandler(); @@ -445,7 +452,15 @@ const createResponse = async (req, res) => { on_reasoning_delta: responsesHandlers.on_reasoning_delta, on_run_step: responsesHandlers.on_run_step, on_run_step_delta: responsesHandlers.on_run_step_delta, - on_chat_model_end: responsesHandlers.on_chat_model_end, + on_chat_model_end: { + handle: (event, data) => { + responsesHandlers.on_chat_model_end.handle(event, data); + const usage = data?.output?.usage_metadata; + if (usage) { + collectedUsage.push(usage); + } + }, + }, on_tool_end: new ToolEndHandler(toolEndCallback, logger), on_run_step_completed: { handle: () => {} }, on_chain_stream: { handle: () => {} }, @@ -499,6 +514,24 @@ const createResponse = async (req, res) => { }, }); + // Record token usage against balance + const balanceConfig = getBalanceConfig(req.config); + const transactionsConfig = getTransactionsConfig(req.config); + recordCollectedUsage( + { spendTokens, spendStructuredTokens }, + { + user: userId, + conversationId, + collectedUsage, + context: 'message', + balance: balanceConfig, + transactions: transactionsConfig, + model: primaryConfig.model || agent.model_parameters?.model, + }, + ).catch((err) => { + logger.error('[Responses API] Error recording usage:', err); + }); + // Finalize the stream finalizeStream(); res.end(); @@ -539,6 +572,9 @@ const createResponse = async (req, res) => { const chatModelStreamHandler = new ChatModelStreamHandler(); + // Collect usage for balance tracking + const collectedUsage = []; + /** @type {Promise[]} */ const artifactPromises = []; const toolEndCallback = createToolEndCallback({ req, res, artifactPromises, streamId: null }); @@ -569,7 +605,15 @@ const createResponse = async (req, res) => { on_reasoning_delta: aggregatorHandlers.on_reasoning_delta, on_run_step: aggregatorHandlers.on_run_step, on_run_step_delta: aggregatorHandlers.on_run_step_delta, - on_chat_model_end: aggregatorHandlers.on_chat_model_end, + on_chat_model_end: { + handle: (event, data) => { + aggregatorHandlers.on_chat_model_end.handle(event, data); + const usage = data?.output?.usage_metadata; + if (usage) { + collectedUsage.push(usage); + } + }, + }, on_tool_end: new ToolEndHandler(toolEndCallback, logger), on_run_step_completed: { handle: () => {} }, on_chain_stream: { handle: () => {} }, @@ -621,6 +665,24 @@ const createResponse = async (req, res) => { }, }); + // Record token usage against balance + const balanceConfig = getBalanceConfig(req.config); + const transactionsConfig = getTransactionsConfig(req.config); + recordCollectedUsage( + { spendTokens, spendStructuredTokens }, + { + user: userId, + conversationId, + collectedUsage, + context: 'message', + balance: balanceConfig, + transactions: transactionsConfig, + model: primaryConfig.model || agent.model_parameters?.model, + }, + ).catch((err) => { + logger.error('[Responses API] Error recording usage:', err); + }); + if (artifactPromises.length > 0) { try { await Promise.all(artifactPromises); diff --git a/packages/api/src/agents/index.ts b/packages/api/src/agents/index.ts index 5d2b14920f..a5a0c340fe 100644 --- a/packages/api/src/agents/index.ts +++ b/packages/api/src/agents/index.ts @@ -8,6 +8,7 @@ export * from './legacy'; export * from './memory'; export * from './migration'; export * from './openai'; +export * from './usage'; export * from './resources'; export * from './responses'; export * from './run'; diff --git a/packages/api/src/agents/usage.spec.ts b/packages/api/src/agents/usage.spec.ts new file mode 100644 index 0000000000..9c06567dc4 --- /dev/null +++ b/packages/api/src/agents/usage.spec.ts @@ -0,0 +1,434 @@ +import { recordCollectedUsage } from './usage'; +import type { RecordUsageDeps, RecordUsageParams } from './usage'; +import type { UsageMetadata } from '../stream/interfaces/IJobStore'; + +describe('recordCollectedUsage', () => { + let mockSpendTokens: jest.Mock; + let mockSpendStructuredTokens: jest.Mock; + let deps: RecordUsageDeps; + + const baseParams: Omit = { + user: 'user-123', + conversationId: 'convo-123', + model: 'gpt-4', + context: 'message', + balance: { enabled: true }, + transactions: { enabled: true }, + }; + + beforeEach(() => { + jest.clearAllMocks(); + mockSpendTokens = jest.fn().mockResolvedValue(undefined); + mockSpendStructuredTokens = jest.fn().mockResolvedValue(undefined); + deps = { + spendTokens: mockSpendTokens, + spendStructuredTokens: mockSpendStructuredTokens, + }; + }); + + describe('basic functionality', () => { + it('should return undefined if collectedUsage is empty', async () => { + const result = await recordCollectedUsage(deps, { + ...baseParams, + collectedUsage: [], + }); + + expect(result).toBeUndefined(); + expect(mockSpendTokens).not.toHaveBeenCalled(); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + }); + + it('should return undefined if collectedUsage is null-ish', async () => { + const result = await recordCollectedUsage(deps, { + ...baseParams, + collectedUsage: null as unknown as UsageMetadata[], + }); + + expect(result).toBeUndefined(); + expect(mockSpendTokens).not.toHaveBeenCalled(); + }); + + it('should handle single usage entry correctly', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { + ...baseParams, + collectedUsage, + }); + + expect(mockSpendTokens).toHaveBeenCalledTimes(1); + expect(mockSpendTokens).toHaveBeenCalledWith( + expect.objectContaining({ + user: 'user-123', + conversationId: 'convo-123', + model: 'gpt-4', + context: 'message', + }), + { promptTokens: 100, completionTokens: 50 }, + ); + expect(result).toEqual({ input_tokens: 100, output_tokens: 50 }); + }); + + it('should skip null entries in collectedUsage', async () => { + const collectedUsage = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + null, + { input_tokens: 200, output_tokens: 60, model: 'gpt-4' }, + ] as UsageMetadata[]; + + const result = await recordCollectedUsage(deps, { + ...baseParams, + collectedUsage, + }); + + expect(mockSpendTokens).toHaveBeenCalledTimes(2); + expect(result).toEqual({ input_tokens: 100, output_tokens: 110 }); + }); + }); + + describe('sequential execution (tool calls)', () => { + it('should calculate tokens correctly for sequential tool calls', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 150, output_tokens: 30, model: 'gpt-4' }, + { input_tokens: 180, output_tokens: 20, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { + ...baseParams, + collectedUsage, + }); + + expect(mockSpendTokens).toHaveBeenCalledTimes(3); + expect(result?.output_tokens).toBe(100); // 50 + 30 + 20 + expect(result?.input_tokens).toBe(100); // First entry's input + }); + }); + + describe('parallel execution (multiple agents)', () => { + it('should handle parallel agents with independent input tokens', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 80, output_tokens: 40, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { + ...baseParams, + collectedUsage, + }); + + expect(mockSpendTokens).toHaveBeenCalledTimes(2); + expect(result?.output_tokens).toBe(90); // 50 + 40 + expect(result?.output_tokens).toBeGreaterThan(0); + }); + + it('should NOT produce negative output_tokens for parallel execution', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 200, output_tokens: 100, model: 'gpt-4' }, + { input_tokens: 50, output_tokens: 30, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { + ...baseParams, + collectedUsage, + }); + + expect(result?.output_tokens).toBeGreaterThan(0); + expect(result?.output_tokens).toBe(130); // 100 + 30 + }); + + it('should calculate correct total output for multiple parallel agents', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 120, output_tokens: 60, model: 'gpt-4-turbo' }, + { input_tokens: 80, output_tokens: 40, model: 'claude-3' }, + ]; + + const result = await recordCollectedUsage(deps, { + ...baseParams, + collectedUsage, + }); + + expect(mockSpendTokens).toHaveBeenCalledTimes(3); + expect(result?.output_tokens).toBe(150); // 50 + 60 + 40 + }); + }); + + describe('cache token handling - OpenAI format', () => { + it('should use spendStructuredTokens for cache tokens (input_token_details)', async () => { + const collectedUsage: UsageMetadata[] = [ + { + input_tokens: 100, + output_tokens: 50, + model: 'gpt-4', + input_token_details: { + cache_creation: 20, + cache_read: 10, + }, + }, + ]; + + const result = await recordCollectedUsage(deps, { + ...baseParams, + collectedUsage, + }); + + expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1); + expect(mockSpendTokens).not.toHaveBeenCalled(); + expect(mockSpendStructuredTokens).toHaveBeenCalledWith( + expect.objectContaining({ model: 'gpt-4' }), + { + promptTokens: { input: 100, write: 20, read: 10 }, + completionTokens: 50, + }, + ); + expect(result?.input_tokens).toBe(130); // 100 + 20 + 10 + }); + }); + + describe('cache token handling - Anthropic format', () => { + it('should use spendStructuredTokens for cache tokens (cache_*_input_tokens)', async () => { + const collectedUsage: UsageMetadata[] = [ + { + input_tokens: 100, + output_tokens: 50, + model: 'claude-3', + cache_creation_input_tokens: 25, + cache_read_input_tokens: 15, + }, + ]; + + const result = await recordCollectedUsage(deps, { + ...baseParams, + collectedUsage, + }); + + expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1); + expect(mockSpendTokens).not.toHaveBeenCalled(); + expect(mockSpendStructuredTokens).toHaveBeenCalledWith( + expect.objectContaining({ model: 'claude-3' }), + { + promptTokens: { input: 100, write: 25, read: 15 }, + completionTokens: 50, + }, + ); + expect(result?.input_tokens).toBe(140); // 100 + 25 + 15 + }); + }); + + describe('mixed cache and non-cache entries', () => { + it('should handle mixed entries correctly', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { + input_tokens: 150, + output_tokens: 30, + model: 'gpt-4', + input_token_details: { cache_creation: 10, cache_read: 5 }, + }, + { input_tokens: 200, output_tokens: 20, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { + ...baseParams, + collectedUsage, + }); + + expect(mockSpendTokens).toHaveBeenCalledTimes(2); + expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1); + expect(result?.output_tokens).toBe(100); // 50 + 30 + 20 + }); + }); + + describe('model fallback', () => { + it('should use usage.model when available', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4-turbo' }, + ]; + + await recordCollectedUsage(deps, { + ...baseParams, + model: 'fallback-model', + collectedUsage, + }); + + expect(mockSpendTokens).toHaveBeenCalledWith( + expect.objectContaining({ model: 'gpt-4-turbo' }), + expect.any(Object), + ); + }); + + it('should fallback to param model when usage.model is missing', async () => { + const collectedUsage: UsageMetadata[] = [{ input_tokens: 100, output_tokens: 50 }]; + + await recordCollectedUsage(deps, { + ...baseParams, + model: 'param-model', + collectedUsage, + }); + + expect(mockSpendTokens).toHaveBeenCalledWith( + expect.objectContaining({ model: 'param-model' }), + expect.any(Object), + ); + }); + }); + + describe('real-world scenarios', () => { + it('should correctly sum output tokens for sequential tool calls with growing context', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 31596, output_tokens: 151, model: 'claude-opus' }, + { input_tokens: 35368, output_tokens: 150, model: 'claude-opus' }, + { input_tokens: 58362, output_tokens: 295, model: 'claude-opus' }, + { input_tokens: 112604, output_tokens: 193, model: 'claude-opus' }, + { input_tokens: 257440, output_tokens: 2217, model: 'claude-opus' }, + ]; + + const result = await recordCollectedUsage(deps, { + ...baseParams, + collectedUsage, + }); + + expect(result?.input_tokens).toBe(31596); + expect(result?.output_tokens).toBe(3006); // 151 + 150 + 295 + 193 + 2217 + expect(mockSpendTokens).toHaveBeenCalledTimes(5); + }); + + it('should handle cache tokens with multiple tool calls', async () => { + const collectedUsage: UsageMetadata[] = [ + { + input_tokens: 788, + output_tokens: 163, + model: 'claude-opus', + input_token_details: { cache_read: 0, cache_creation: 30808 }, + }, + { + input_tokens: 3802, + output_tokens: 149, + model: 'claude-opus', + input_token_details: { cache_read: 30808, cache_creation: 768 }, + }, + { + input_tokens: 26808, + output_tokens: 225, + model: 'claude-opus', + input_token_details: { cache_read: 31576, cache_creation: 0 }, + }, + ]; + + const result = await recordCollectedUsage(deps, { + ...baseParams, + collectedUsage, + }); + + // input_tokens = 788 + 30808 + 0 = 31596 + expect(result?.input_tokens).toBe(31596); + // output_tokens = 163 + 149 + 225 = 537 + expect(result?.output_tokens).toBe(537); + expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(3); + expect(mockSpendTokens).not.toHaveBeenCalled(); + }); + }); + + describe('error handling', () => { + it('should catch and log errors from spendTokens without throwing', async () => { + mockSpendTokens.mockRejectedValue(new Error('DB error')); + + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { + ...baseParams, + collectedUsage, + }); + + expect(result).toEqual({ input_tokens: 100, output_tokens: 50 }); + }); + + it('should catch and log errors from spendStructuredTokens without throwing', async () => { + mockSpendStructuredTokens.mockRejectedValue(new Error('DB error')); + + const collectedUsage: UsageMetadata[] = [ + { + input_tokens: 100, + output_tokens: 50, + model: 'gpt-4', + input_token_details: { cache_creation: 20, cache_read: 10 }, + }, + ]; + + const result = await recordCollectedUsage(deps, { + ...baseParams, + collectedUsage, + }); + + expect(result).toEqual({ input_tokens: 130, output_tokens: 50 }); + }); + }); + + describe('transaction metadata', () => { + it('should pass all metadata fields to spend functions', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + const endpointTokenConfig = { 'gpt-4': { prompt: 0.01, completion: 0.03, context: 8192 } }; + + await recordCollectedUsage(deps, { + ...baseParams, + endpointTokenConfig, + collectedUsage, + }); + + expect(mockSpendTokens).toHaveBeenCalledWith( + { + user: 'user-123', + conversationId: 'convo-123', + model: 'gpt-4', + context: 'message', + balance: { enabled: true }, + transactions: { enabled: true }, + endpointTokenConfig, + }, + { promptTokens: 100, completionTokens: 50 }, + ); + }); + + it('should use default context "message" when not provided', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(deps, { + user: 'user-123', + conversationId: 'convo-123', + collectedUsage, + }); + + expect(mockSpendTokens).toHaveBeenCalledWith( + expect.objectContaining({ context: 'message' }), + expect.any(Object), + ); + }); + + it('should allow custom context like "title"', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(deps, { + ...baseParams, + context: 'title', + collectedUsage, + }); + + expect(mockSpendTokens).toHaveBeenCalledWith( + expect.objectContaining({ context: 'title' }), + expect.any(Object), + ); + }); + }); +}); diff --git a/packages/api/src/agents/usage.ts b/packages/api/src/agents/usage.ts new file mode 100644 index 0000000000..545be9195d --- /dev/null +++ b/packages/api/src/agents/usage.ts @@ -0,0 +1,146 @@ +import { logger } from '@librechat/data-schemas'; +import type { TCustomConfig, TTransactionsConfig } from 'librechat-data-provider'; +import type { UsageMetadata } from '../stream/interfaces/IJobStore'; +import type { EndpointTokenConfig } from '../types/tokens'; + +interface TokenUsage { + promptTokens?: number; + completionTokens?: number; +} + +interface StructuredPromptTokens { + input?: number; + write?: number; + read?: number; +} + +interface StructuredTokenUsage { + promptTokens?: StructuredPromptTokens; + completionTokens?: number; +} + +interface TxMetadata { + user: string; + model?: string; + context: string; + conversationId: string; + balance?: Partial | null; + transactions?: Partial; + endpointTokenConfig?: EndpointTokenConfig; +} + +type SpendTokensFn = (txData: TxMetadata, tokenUsage: TokenUsage) => Promise; +type SpendStructuredTokensFn = ( + txData: TxMetadata, + tokenUsage: StructuredTokenUsage, +) => Promise; + +export interface RecordUsageDeps { + spendTokens: SpendTokensFn; + spendStructuredTokens: SpendStructuredTokensFn; +} + +export interface RecordUsageParams { + user: string; + conversationId: string; + collectedUsage: UsageMetadata[]; + model?: string; + context?: string; + balance?: Partial | null; + transactions?: Partial; + endpointTokenConfig?: EndpointTokenConfig; +} + +export interface RecordUsageResult { + input_tokens: number; + output_tokens: number; +} + +/** + * Records token usage for collected LLM calls and spends tokens against balance. + * This handles both sequential execution (tool calls) and parallel execution (multiple agents). + */ +export async function recordCollectedUsage( + deps: RecordUsageDeps, + params: RecordUsageParams, +): Promise { + const { + user, + model, + balance, + transactions, + conversationId, + collectedUsage, + endpointTokenConfig, + context = 'message', + } = params; + + const { spendTokens, spendStructuredTokens } = deps; + + if (!collectedUsage || !collectedUsage.length) { + return; + } + + const firstUsage = collectedUsage[0]; + const input_tokens = + (firstUsage?.input_tokens || 0) + + (Number(firstUsage?.input_token_details?.cache_creation) || + Number(firstUsage?.cache_creation_input_tokens) || + 0) + + (Number(firstUsage?.input_token_details?.cache_read) || + Number(firstUsage?.cache_read_input_tokens) || + 0); + + let total_output_tokens = 0; + + for (const usage of collectedUsage) { + if (!usage) { + continue; + } + + const cache_creation = + Number(usage.input_token_details?.cache_creation) || + Number(usage.cache_creation_input_tokens) || + 0; + const cache_read = + Number(usage.input_token_details?.cache_read) || Number(usage.cache_read_input_tokens) || 0; + + total_output_tokens += Number(usage.output_tokens) || 0; + + const txMetadata: TxMetadata = { + context, + balance, + transactions, + conversationId, + user, + endpointTokenConfig, + model: usage.model ?? model, + }; + + if (cache_creation > 0 || cache_read > 0) { + spendStructuredTokens(txMetadata, { + promptTokens: { + input: usage.input_tokens, + write: cache_creation, + read: cache_read, + }, + completionTokens: usage.output_tokens, + }).catch((err) => { + logger.error('[packages/api #recordCollectedUsage] Error spending structured tokens', err); + }); + continue; + } + + spendTokens(txMetadata, { + promptTokens: usage.input_tokens, + completionTokens: usage.output_tokens, + }).catch((err) => { + logger.error('[packages/api #recordCollectedUsage] Error spending tokens', err); + }); + } + + return { + input_tokens, + output_tokens: total_output_tokens, + }; +}