🔧 refactor: Improve Agent Context & Minor Fixes (#5349)

* refactor: Improve Context for Agents

* 🔧 fix: Safeguard against undefined properties in OpenAIClient response handling

* refactor: log error before re-throwing for original stack trace

* refactor: remove toolResource state from useFileHandling, allow svg files

* refactor: prevent verbose logs from axios errors when using actions

* refactor: add silent method recordTokenUsage in AgentClient

* refactor: streamline token count assignment in BaseClient

* refactor: enhance safety settings handling for Gemini 2.0 model

* fix: capabilities structure in MCPConnection

* refactor: simplify civic integrity threshold handling in GoogleClient and llm

* refactor: update token count retrieval method in BaseClient tests

* ci: fix test for svg
This commit is contained in:
Danny Avila 2025-01-17 12:55:48 -05:00 committed by GitHub
parent e309c6abef
commit b35a8b78e2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 324 additions and 112 deletions

View file

@ -60,6 +60,9 @@ const noSystemModelRegex = [/\bo1\b/gi];
class AgentClient extends BaseClient {
constructor(options = {}) {
super(null, options);
/** The current client class
* @type {string} */
this.clientName = EModelEndpoint.agents;
/** @type {'discard' | 'summarize'} */
this.contextStrategy = 'discard';
@ -91,6 +94,14 @@ class AgentClient extends BaseClient {
this.options = Object.assign({ endpoint: options.endpoint }, clientOptions);
/** @type {string} */
this.model = this.options.agent.model_parameters.model;
/** The key for the usage object's input tokens
* @type {string} */
this.inputTokensKey = 'input_tokens';
/** The key for the usage object's output tokens
* @type {string} */
this.outputTokensKey = 'output_tokens';
/** @type {UsageMetadata} */
this.usage;
}
/**
@ -329,16 +340,18 @@ class AgentClient extends BaseClient {
this.options.agent.instructions = systemContent;
}
/** @type {Record<string, number> | undefined} */
let tokenCountMap;
if (this.contextStrategy) {
({ payload, promptTokens, messages } = await this.handleContextStrategy({
({ payload, promptTokens, tokenCountMap, messages } = await this.handleContextStrategy({
orderedMessages,
formattedMessages,
/* prefer usage_metadata from final message */
buildTokenMap: false,
}));
}
const result = {
tokenCountMap,
prompt: payload,
promptTokens,
messages,
@ -368,8 +381,26 @@ class AgentClient extends BaseClient {
* @param {UsageMetadata[]} [params.collectedUsage=this.collectedUsage]
*/
async recordCollectedUsage({ model, context = 'message', collectedUsage = this.collectedUsage }) {
for (const usage of collectedUsage) {
await spendTokens(
if (!collectedUsage || !collectedUsage.length) {
return;
}
const input_tokens = collectedUsage[0]?.input_tokens || 0;
let output_tokens = 0;
let previousTokens = input_tokens; // Start with original input
for (let i = 0; i < collectedUsage.length; i++) {
const usage = collectedUsage[i];
if (i > 0) {
// Count new tokens generated (input_tokens minus previous accumulated tokens)
output_tokens += (Number(usage.input_tokens) || 0) - previousTokens;
}
// Add this message's output tokens
output_tokens += Number(usage.output_tokens) || 0;
// Update previousTokens to include this message's output
previousTokens += Number(usage.output_tokens) || 0;
spendTokens(
{
context,
conversationId: this.conversationId,
@ -378,8 +409,66 @@ class AgentClient extends BaseClient {
model: usage.model ?? model ?? this.model ?? this.options.agent.model_parameters.model,
},
{ promptTokens: usage.input_tokens, completionTokens: usage.output_tokens },
);
).catch((err) => {
logger.error(
'[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending tokens',
err,
);
});
}
this.usage = {
input_tokens,
output_tokens,
};
}
/**
* Get stream usage as returned by this client's API response.
* @returns {UsageMetadata} The stream usage object.
*/
getStreamUsage() {
return this.usage;
}
/**
* @param {TMessage} responseMessage
* @returns {number}
*/
getTokenCountForResponse({ content }) {
return this.getTokenCountForMessage({
role: 'assistant',
content,
});
}
/**
* Calculates the correct token count for the current user message based on the token count map and API usage.
* Edge case: If the calculation results in a negative value, it returns the original estimate.
* If revisiting a conversation with a chat history entirely composed of token estimates,
* the cumulative token count going forward should become more accurate as the conversation progresses.
* @param {Object} params - The parameters for the calculation.
* @param {Record<string, number>} params.tokenCountMap - A map of message IDs to their token counts.
* @param {string} params.currentMessageId - The ID of the current message to calculate.
* @param {OpenAIUsageMetadata} params.usage - The usage object returned by the API.
* @returns {number} The correct token count for the current user message.
*/
calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) {
const originalEstimate = tokenCountMap[currentMessageId] || 0;
if (!usage || typeof usage[this.inputTokensKey] !== 'number') {
return originalEstimate;
}
tokenCountMap[currentMessageId] = 0;
const totalTokensFromMap = Object.values(tokenCountMap).reduce((sum, count) => {
const numCount = Number(count);
return sum + (isNaN(numCount) ? 0 : numCount);
}, 0);
const totalInputTokens = usage[this.inputTokensKey] ?? 0;
const currentMessageTokens = totalInputTokens - totalTokensFromMap;
return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate;
}
async chatCompletion({ payload, abortController = null }) {
@ -676,12 +765,14 @@ class AgentClient extends BaseClient {
);
});
this.recordCollectedUsage({ context: 'message' }).catch((err) => {
try {
await this.recordCollectedUsage({ context: 'message' });
} catch (err) {
logger.error(
'[api/server/controllers/agents/client.js #chatCompletion] Error recording collected usage',
err,
);
});
}
} catch (err) {
if (!abortController.signal.aborted) {
logger.error(
@ -767,8 +858,11 @@ class AgentClient extends BaseClient {
}
}
/** Silent method, as `recordCollectedUsage` is used instead */
async recordTokenUsage() {}
getEncoding() {
return this.model?.includes('gpt-4o') ? 'o200k_base' : 'cl100k_base';
return 'o200k_base';
}
/**