💰 fix: Multi-Agent Token Spending & Prevent Double-Spend (#11433)

* fix: Token Spending Logic for Multi-Agents in Abort Scenarios

* Implemented logic to skip the client-side token spending when a conversation is aborted and the abort middleware has already spent the collected usage, preventing double-spending.
* Introduced `spendCollectedUsage` function to handle token spending for multiple models during aborts, ensuring accurate accounting for parallel agents.
* Updated `GenerationJobManager` to store and retrieve collected usage data for improved abort handling (a rough sketch of the abort-side flow follows this list).
* Added comprehensive tests for the new functionality, covering various scenarios including cache token handling and parallel agent usage.
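A rough sketch of the abort-side flow, assuming illustrative wiring: the exact `GenerationJobManager` storage/retrieval API is not shown in this diff, so the call shape below is an assumption, while the destructured fields mirror what `abortMessage` reads from `abortResult` further down.

```js
// Hypothetical sketch: the job manager holds a reference to the same
// collectedUsage array the AgentClient pushes usage metadata into, so an
// abort can hand it back to the middleware for spending.
const { GenerationJobManager } = require('@librechat/api');

async function abortSketch({ userId, messageId }) {
  // abortJob's exact signature is an assumption here; the destructured
  // fields match what abortMessage reads from abortResult in this commit.
  const abortResult = await GenerationJobManager.abortJob(messageId);
  const { jobData, collectedUsage } = abortResult ?? {};

  if (collectedUsage?.length) {
    // spendCollectedUsage (added below in this commit) spends per-model usage
    // and then clears the shared array so the client finally block won't re-spend.
    await spendCollectedUsage({
      userId,
      conversationId: jobData?.conversationId,
      collectedUsage,
      fallbackModel: jobData?.model,
    });
  }
}
```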

* fix: Memory Context Handling for Multi-Agents

* Refactored the `buildMessages` method to pass memory context to parallel agents, ensuring they share the same user context (see the sketch after this list).
* Improved handling of memory context when no existing instructions are present for parallel agents.
* Added comprehensive tests to verify memory context propagation and behavior under various scenarios, including cases with no memory available and empty agent configurations.
* Enhanced logging for better traceability of memory context additions to agents.
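A minimal sketch of the propagation step, under assumed helper and field names; the actual `buildMessages` signature and agent config fields in `AgentClient` may differ.

```js
// Hypothetical helper: prepend the shared memory context to each parallel
// agent's instructions so addedConvo agents see the same user context.
function addMemoryToParallelAgents(agentConfigs, memoryContext) {
  if (!memoryContext || !agentConfigs?.length) {
    return;
  }
  for (const agentConfig of agentConfigs) {
    if (agentConfig.instructions) {
      // Mutated in place: parallel agents hold references to these configs.
      agentConfig.instructions = `${memoryContext}\n\n${agentConfig.instructions}`;
    } else {
      // No existing instructions: the memory context becomes the instructions.
      agentConfig.instructions = memoryContext;
    }
  }
}
```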

* chore: Memory Context Documentation for Parallel Agents

* Updated documentation in the `AgentClient` class to clarify the in-place mutation of agentConfig objects when passing memory context to parallel agents (paraphrased in the sketch below).
* Added notes on the implications of mutating objects directly to ensure all parallel agents receive the correct memory context before execution.
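The added note might read roughly like this (a paraphrase, not the verbatim docblock from the commit):

```js
/**
 * NOTE: agentConfig objects are mutated in place when memory context is
 * attached. Parallel agents (addedConvo) receive references to these same
 * objects, so the memory context must be set before parallel execution
 * starts for every agent to see it.
 */
```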

* chore: UsageMetadata Interface Docs for Token Spending

* Expanded the UsageMetadata interface to support both OpenAI and Anthropic cache token formats (a JSDoc sketch of the shape follows this list).
* Added detailed documentation for cache token properties, including mutually exclusive fields for different model types.
* Improved clarity on how to access cache token details for accurate token spending tracking.
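A JSDoc sketch of the documented shape; the property names follow how the abort middleware reads usage entries below, though the interface in the codebase may include additional fields.

```js
/**
 * Per-model usage metadata; supports both cache-token formats.
 * @typedef {Object} UsageMetadata
 * @property {string} [model] - Model that produced this usage entry.
 * @property {number} input_tokens - Prompt tokens consumed.
 * @property {number} output_tokens - Completion tokens produced.
 * @property {{ cache_creation?: number, cache_read?: number }} [input_token_details] - OpenAI-style cache details.
 * @property {number} [cache_creation_input_tokens] - Anthropic-style cache write tokens (use either this pair or input_token_details, not both).
 * @property {number} [cache_read_input_tokens] - Anthropic-style cache read tokens.
 */
```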

* fix: Enhance Token Spending Logic in Abort Middleware

* Refactored the `spendCollectedUsage` function to use Promise.all for concurrent token spending, improving performance and ensuring all operations complete before the collectedUsage array is cleared.
* Added documentation to clarify the importance of clearing the collectedUsage array to prevent double-spending in abort scenarios.
* Updated tests to verify the correct behavior of the spending logic and the clearing of the array after spending operations.

@@ -7,13 +7,89 @@ const {
sanitizeMessageForTransmit,
} = require('@librechat/api');
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
const clearPendingReq = require('~/cache/clearPendingReq');
const { sendError } = require('~/server/middleware/error');
const { spendTokens } = require('~/models/spendTokens');
const { saveMessage, getConvo } = require('~/models');
const { abortRun } = require('./abortRun');
/**
* Spend tokens for all models from collected usage.
* This handles both sequential and parallel agent execution.
*
* IMPORTANT: After spending, this function clears the collectedUsage array
* to prevent double-spending. The array is shared with AgentClient.collectedUsage,
* so clearing it here prevents the finally block from also spending tokens.
*
* @param {Object} params
* @param {string} params.userId - User ID
* @param {string} params.conversationId - Conversation ID
* @param {Array<Object>} params.collectedUsage - Usage metadata from all models
* @param {string} [params.fallbackModel] - Fallback model name if not in usage
*/
async function spendCollectedUsage({ userId, conversationId, collectedUsage, fallbackModel }) {
if (!collectedUsage || collectedUsage.length === 0) {
return;
}
const spendPromises = [];
for (const usage of collectedUsage) {
if (!usage) {
continue;
}
// Support both OpenAI format (input_token_details) and Anthropic format (cache_*_input_tokens)
const cache_creation =
Number(usage.input_token_details?.cache_creation) ||
Number(usage.cache_creation_input_tokens) ||
0;
const cache_read =
Number(usage.input_token_details?.cache_read) || Number(usage.cache_read_input_tokens) || 0;
const txMetadata = {
context: 'abort',
conversationId,
user: userId,
model: usage.model ?? fallbackModel,
};
if (cache_creation > 0 || cache_read > 0) {
spendPromises.push(
spendStructuredTokens(txMetadata, {
promptTokens: {
input: usage.input_tokens,
write: cache_creation,
read: cache_read,
},
completionTokens: usage.output_tokens,
}).catch((err) => {
logger.error('[abortMiddleware] Error spending structured tokens for abort', err);
}),
);
continue;
}
spendPromises.push(
spendTokens(txMetadata, {
promptTokens: usage.input_tokens,
completionTokens: usage.output_tokens,
}).catch((err) => {
logger.error('[abortMiddleware] Error spending tokens for abort', err);
}),
);
}
// Wait for all token spending to complete
await Promise.all(spendPromises);
// Clear the array to prevent double-spending from the AgentClient finally block.
// The collectedUsage array is shared by reference with AgentClient.collectedUsage,
// so clearing it here ensures recordCollectedUsage() sees an empty array and returns early.
collectedUsage.length = 0;
}
/**
* Abort an active message generation.
* Uses GenerationJobManager for all agent requests.
@@ -39,9 +115,8 @@ async function abortMessage(req, res) {
return;
}
const { jobData, content, text } = abortResult;
const { jobData, content, text, collectedUsage } = abortResult;
// Count tokens and spend them
const completionTokens = await countTokens(text);
const promptTokens = jobData?.promptTokens ?? 0;
@@ -62,10 +137,21 @@
tokenCount: completionTokens,
};
await spendTokens(
{ ...responseMessage, context: 'incomplete', user: userId },
{ promptTokens, completionTokens },
);
// Spend tokens for ALL models from collectedUsage (handles parallel agents/addedConvo)
if (collectedUsage && collectedUsage.length > 0) {
await spendCollectedUsage({
userId,
conversationId: jobData?.conversationId,
collectedUsage,
fallbackModel: jobData?.model,
});
} else {
// Fallback: no collected usage, use text-based token counting for primary model only
await spendTokens(
{ ...responseMessage, context: 'incomplete', user: userId },
{ promptTokens, completionTokens },
);
}
await saveMessage(
req,


@@ -0,0 +1,428 @@
/**
* Tests for abortMiddleware - spendCollectedUsage function
*
* This tests the token spending logic for abort scenarios,
* particularly for parallel agents (addedConvo) where multiple
* models need their tokens spent.
*/
const mockSpendTokens = jest.fn().mockResolvedValue();
const mockSpendStructuredTokens = jest.fn().mockResolvedValue();
jest.mock('~/models/spendTokens', () => ({
spendTokens: (...args) => mockSpendTokens(...args),
spendStructuredTokens: (...args) => mockSpendStructuredTokens(...args),
}));
jest.mock('@librechat/data-schemas', () => ({
logger: {
debug: jest.fn(),
error: jest.fn(),
warn: jest.fn(),
info: jest.fn(),
},
}));
jest.mock('@librechat/api', () => ({
countTokens: jest.fn().mockResolvedValue(100),
isEnabled: jest.fn().mockReturnValue(false),
sendEvent: jest.fn(),
GenerationJobManager: {
abortJob: jest.fn(),
},
sanitizeMessageForTransmit: jest.fn((msg) => msg),
}));
jest.mock('librechat-data-provider', () => ({
isAssistantsEndpoint: jest.fn().mockReturnValue(false),
ErrorTypes: { INVALID_REQUEST: 'INVALID_REQUEST', NO_SYSTEM_MESSAGES: 'NO_SYSTEM_MESSAGES' },
}));
jest.mock('~/app/clients/prompts', () => ({
truncateText: jest.fn((text) => text),
smartTruncateText: jest.fn((text) => text),
}));
jest.mock('~/cache/clearPendingReq', () => jest.fn().mockResolvedValue());
jest.mock('~/server/middleware/error', () => ({
sendError: jest.fn(),
}));
jest.mock('~/models', () => ({
saveMessage: jest.fn().mockResolvedValue(),
getConvo: jest.fn().mockResolvedValue({ title: 'Test Chat' }),
}));
jest.mock('./abortRun', () => ({
abortRun: jest.fn(),
}));
// Mocks above must be set up before any module imports.
// spendCollectedUsage is not exported from the middleware module,
// so the describe block below replicates its logic for direct unit testing.
describe('abortMiddleware - spendCollectedUsage', () => {
beforeEach(() => {
jest.clearAllMocks();
});
describe('spendCollectedUsage logic', () => {
// Since spendCollectedUsage is not exported, we test the logic directly
// by replicating the function here for unit testing
const spendCollectedUsage = async ({
userId,
conversationId,
collectedUsage,
fallbackModel,
}) => {
if (!collectedUsage || collectedUsage.length === 0) {
return;
}
const spendPromises = [];
for (const usage of collectedUsage) {
if (!usage) {
continue;
}
const cache_creation =
Number(usage.input_token_details?.cache_creation) ||
Number(usage.cache_creation_input_tokens) ||
0;
const cache_read =
Number(usage.input_token_details?.cache_read) ||
Number(usage.cache_read_input_tokens) ||
0;
const txMetadata = {
context: 'abort',
conversationId,
user: userId,
model: usage.model ?? fallbackModel,
};
if (cache_creation > 0 || cache_read > 0) {
spendPromises.push(
mockSpendStructuredTokens(txMetadata, {
promptTokens: {
input: usage.input_tokens,
write: cache_creation,
read: cache_read,
},
completionTokens: usage.output_tokens,
}).catch(() => {
// Log error but don't throw
}),
);
continue;
}
spendPromises.push(
mockSpendTokens(txMetadata, {
promptTokens: usage.input_tokens,
completionTokens: usage.output_tokens,
}).catch(() => {
// Log error but don't throw
}),
);
}
// Wait for all token spending to complete
await Promise.all(spendPromises);
// Clear the array to prevent double-spending
collectedUsage.length = 0;
};
it('should return early if collectedUsage is empty', async () => {
await spendCollectedUsage({
userId: 'user-123',
conversationId: 'convo-123',
collectedUsage: [],
fallbackModel: 'gpt-4',
});
expect(mockSpendTokens).not.toHaveBeenCalled();
expect(mockSpendStructuredTokens).not.toHaveBeenCalled();
});
it('should return early if collectedUsage is null', async () => {
await spendCollectedUsage({
userId: 'user-123',
conversationId: 'convo-123',
collectedUsage: null,
fallbackModel: 'gpt-4',
});
expect(mockSpendTokens).not.toHaveBeenCalled();
expect(mockSpendStructuredTokens).not.toHaveBeenCalled();
});
it('should skip null entries in collectedUsage', async () => {
const collectedUsage = [
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
null,
{ input_tokens: 200, output_tokens: 60, model: 'gpt-4' },
];
await spendCollectedUsage({
userId: 'user-123',
conversationId: 'convo-123',
collectedUsage,
fallbackModel: 'gpt-4',
});
expect(mockSpendTokens).toHaveBeenCalledTimes(2);
});
it('should spend tokens for single model', async () => {
const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }];
await spendCollectedUsage({
userId: 'user-123',
conversationId: 'convo-123',
collectedUsage,
fallbackModel: 'gpt-4',
});
expect(mockSpendTokens).toHaveBeenCalledTimes(1);
expect(mockSpendTokens).toHaveBeenCalledWith(
expect.objectContaining({
context: 'abort',
conversationId: 'convo-123',
user: 'user-123',
model: 'gpt-4',
}),
{ promptTokens: 100, completionTokens: 50 },
);
});
it('should spend tokens for multiple models (parallel agents)', async () => {
const collectedUsage = [
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
{ input_tokens: 80, output_tokens: 40, model: 'claude-3' },
{ input_tokens: 120, output_tokens: 60, model: 'gemini-pro' },
];
await spendCollectedUsage({
userId: 'user-123',
conversationId: 'convo-123',
collectedUsage,
fallbackModel: 'gpt-4',
});
expect(mockSpendTokens).toHaveBeenCalledTimes(3);
// Verify each model was called
expect(mockSpendTokens).toHaveBeenNthCalledWith(
1,
expect.objectContaining({ model: 'gpt-4' }),
{ promptTokens: 100, completionTokens: 50 },
);
expect(mockSpendTokens).toHaveBeenNthCalledWith(
2,
expect.objectContaining({ model: 'claude-3' }),
{ promptTokens: 80, completionTokens: 40 },
);
expect(mockSpendTokens).toHaveBeenNthCalledWith(
3,
expect.objectContaining({ model: 'gemini-pro' }),
{ promptTokens: 120, completionTokens: 60 },
);
});
it('should use fallbackModel when usage.model is missing', async () => {
const collectedUsage = [{ input_tokens: 100, output_tokens: 50 }];
await spendCollectedUsage({
userId: 'user-123',
conversationId: 'convo-123',
collectedUsage,
fallbackModel: 'fallback-model',
});
expect(mockSpendTokens).toHaveBeenCalledWith(
expect.objectContaining({ model: 'fallback-model' }),
expect.any(Object),
);
});
it('should use spendStructuredTokens for OpenAI format cache tokens', async () => {
const collectedUsage = [
{
input_tokens: 100,
output_tokens: 50,
model: 'gpt-4',
input_token_details: {
cache_creation: 20,
cache_read: 10,
},
},
];
await spendCollectedUsage({
userId: 'user-123',
conversationId: 'convo-123',
collectedUsage,
fallbackModel: 'gpt-4',
});
expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1);
expect(mockSpendTokens).not.toHaveBeenCalled();
expect(mockSpendStructuredTokens).toHaveBeenCalledWith(
expect.objectContaining({ model: 'gpt-4', context: 'abort' }),
{
promptTokens: {
input: 100,
write: 20,
read: 10,
},
completionTokens: 50,
},
);
});
it('should use spendStructuredTokens for Anthropic format cache tokens', async () => {
const collectedUsage = [
{
input_tokens: 100,
output_tokens: 50,
model: 'claude-3',
cache_creation_input_tokens: 25,
cache_read_input_tokens: 15,
},
];
await spendCollectedUsage({
userId: 'user-123',
conversationId: 'convo-123',
collectedUsage,
fallbackModel: 'claude-3',
});
expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1);
expect(mockSpendTokens).not.toHaveBeenCalled();
expect(mockSpendStructuredTokens).toHaveBeenCalledWith(
expect.objectContaining({ model: 'claude-3' }),
{
promptTokens: {
input: 100,
write: 25,
read: 15,
},
completionTokens: 50,
},
);
});
it('should handle mixed cache and non-cache entries', async () => {
const collectedUsage = [
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
{
input_tokens: 150,
output_tokens: 30,
model: 'claude-3',
cache_creation_input_tokens: 20,
cache_read_input_tokens: 10,
},
{ input_tokens: 200, output_tokens: 20, model: 'gemini-pro' },
];
await spendCollectedUsage({
userId: 'user-123',
conversationId: 'convo-123',
collectedUsage,
fallbackModel: 'gpt-4',
});
expect(mockSpendTokens).toHaveBeenCalledTimes(2);
expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1);
});
it('should handle real-world parallel agent abort scenario', async () => {
// Simulates: Primary agent (gemini) + addedConvo agent (gpt-5) aborted mid-stream
const collectedUsage = [
{ input_tokens: 31596, output_tokens: 151, model: 'gemini-3-flash-preview' },
{ input_tokens: 28000, output_tokens: 120, model: 'gpt-5.2' },
];
await spendCollectedUsage({
userId: 'user-123',
conversationId: 'convo-123',
collectedUsage,
fallbackModel: 'gemini-3-flash-preview',
});
expect(mockSpendTokens).toHaveBeenCalledTimes(2);
// Primary model
expect(mockSpendTokens).toHaveBeenNthCalledWith(
1,
expect.objectContaining({ model: 'gemini-3-flash-preview' }),
{ promptTokens: 31596, completionTokens: 151 },
);
// Parallel model (addedConvo)
expect(mockSpendTokens).toHaveBeenNthCalledWith(
2,
expect.objectContaining({ model: 'gpt-5.2' }),
{ promptTokens: 28000, completionTokens: 120 },
);
});
it('should clear collectedUsage array after spending to prevent double-spending', async () => {
// This tests the race condition fix: after abort middleware spends tokens,
// the collectedUsage array is cleared so AgentClient.recordCollectedUsage()
// (which shares the same array reference) sees an empty array and returns early.
const collectedUsage = [
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
{ input_tokens: 80, output_tokens: 40, model: 'claude-3' },
];
expect(collectedUsage.length).toBe(2);
await spendCollectedUsage({
userId: 'user-123',
conversationId: 'convo-123',
collectedUsage,
fallbackModel: 'gpt-4',
});
expect(mockSpendTokens).toHaveBeenCalledTimes(2);
// The array should be cleared after spending
expect(collectedUsage.length).toBe(0);
});
it('should await all token spending operations before clearing array', async () => {
// Ensure we don't clear the array before spending completes
let spendCallCount = 0;
mockSpendTokens.mockImplementation(async () => {
spendCallCount++;
// Simulate async delay
await new Promise((resolve) => setTimeout(resolve, 10));
});
const collectedUsage = [
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
{ input_tokens: 80, output_tokens: 40, model: 'claude-3' },
];
await spendCollectedUsage({
userId: 'user-123',
conversationId: 'convo-123',
collectedUsage,
fallbackModel: 'gpt-4',
});
// Both spend calls should have completed
expect(spendCallCount).toBe(2);
// Array should be cleared after awaiting
expect(collectedUsage.length).toBe(0);
});
});
});