refactor: single-pass collectToolData to avoid redundant tool iteration

Extract collectToolData() that builds both the fingerprint names and
serialized schemas in one pass over tools + toolDefinitions.
getOrComputeToolTokens uses the pre-collected schemas directly on
cache miss instead of re-looping. getToolFingerprint and
computeToolSchemaTokens delegate to the same shared function for
standalone use.
This commit is contained in:
Danny Avila 2026-04-01 22:49:05 -04:00
parent 749666503c
commit 90d32b691e

View file

@ -23,34 +23,6 @@ function getCache(): Keyv {
return toolTokenCache;
}
export function getToolFingerprint(tools?: GenericTool[], toolDefinitions?: LCTool[]): string {
const names = new Set<string>();
if (tools) {
for (const tool of tools) {
const name = (tool as unknown as Record<string, unknown>).name;
if (typeof name === 'string' && name) {
names.add(name);
}
}
}
if (toolDefinitions) {
for (const def of toolDefinitions) {
if (def.name) {
names.add(def.name);
}
}
}
if (names.size === 0) {
return '';
}
const sorted = Array.from(names).sort();
return sorted.join(',') + '|' + sorted.length;
}
function getToolTokenMultiplier(provider: Providers, clientOptions?: ClientOptions): number {
const isAnthropic =
provider !== Providers.BEDROCK &&
@ -61,6 +33,77 @@ function getToolTokenMultiplier(provider: Providers, clientOptions?: ClientOptio
return isAnthropic ? ANTHROPIC_TOOL_TOKEN_MULTIPLIER : DEFAULT_TOOL_TOKEN_MULTIPLIER;
}
/**
* Single pass over tools and toolDefinitions. Collects deduplicated sorted
* tool names (for fingerprint) and pre-serialized schemas (for token
* counting on cache miss), mirroring the dedup logic in
* AgentContext.calculateInstructionTokens().
*/
function collectToolData(
tools?: GenericTool[],
toolDefinitions?: LCTool[],
): { names: string[]; schemas: string[] } {
const nameSet = new Set<string>();
const countedNames = new Set<string>();
const schemas: string[] = [];
if (tools) {
for (const tool of tools) {
const genericTool = tool as unknown as Record<string, unknown>;
const toolName = (genericTool.name as string | undefined) ?? '';
if (toolName) {
nameSet.add(toolName);
}
if (genericTool.schema != null && typeof genericTool.schema === 'object') {
schemas.push(
JSON.stringify(
toJsonSchema(
genericTool.schema,
toolName,
(genericTool.description as string | undefined) ?? '',
),
),
);
if (toolName) {
countedNames.add(toolName);
}
}
}
}
if (toolDefinitions) {
for (const def of toolDefinitions) {
if (def.name) {
nameSet.add(def.name);
}
if (countedNames.has(def.name)) {
continue;
}
schemas.push(
JSON.stringify({
type: 'function',
function: {
name: def.name,
description: def.description ?? '',
parameters: def.parameters ?? {},
},
}),
);
}
}
const names = nameSet.size > 0 ? Array.from(nameSet).sort() : [];
return { names, schemas };
}
export function getToolFingerprint(tools?: GenericTool[], toolDefinitions?: LCTool[]): string {
const { names } = collectToolData(tools, toolDefinitions);
if (names.length === 0) {
return '';
}
return names.join(',') + '|' + names.length;
}
export function computeToolSchemaTokens(
tools: GenericTool[] | undefined,
toolDefinitions: LCTool[] | undefined,
@ -68,44 +111,11 @@ export function computeToolSchemaTokens(
clientOptions: ClientOptions | undefined,
tokenCounter: TokenCounter,
): number {
const { schemas } = collectToolData(tools, toolDefinitions);
let toolTokens = 0;
const countedToolNames = new Set<string>();
if (tools && tools.length > 0) {
for (const tool of tools) {
const genericTool = tool as unknown as Record<string, unknown>;
if (genericTool.schema != null && typeof genericTool.schema === 'object') {
const toolName = (genericTool.name as string | undefined) ?? '';
const jsonSchema = toJsonSchema(
genericTool.schema,
toolName,
(genericTool.description as string | undefined) ?? '',
);
toolTokens += tokenCounter(new SystemMessage(JSON.stringify(jsonSchema)));
if (toolName) {
countedToolNames.add(toolName);
}
}
}
for (const schema of schemas) {
toolTokens += tokenCounter(new SystemMessage(schema));
}
if (toolDefinitions && toolDefinitions.length > 0) {
for (const def of toolDefinitions) {
if (countedToolNames.has(def.name)) {
continue;
}
const schema = {
type: 'function',
function: {
name: def.name,
description: def.description ?? '',
parameters: def.parameters ?? {},
},
};
toolTokens += tokenCounter(new SystemMessage(JSON.stringify(schema)));
}
}
const multiplier = getToolTokenMultiplier(provider, clientOptions);
return Math.ceil(toolTokens * multiplier);
}
@ -113,8 +123,8 @@ export function computeToolSchemaTokens(
/**
* Returns cached tool schema tokens or computes them on miss.
* Returns 0 if there are no tools.
* Cache errors are non-fatal falls through to compute on read failure,
* logs on write failure.
* Single pass over tool arrays: builds fingerprint and serialized schemas
* together, then only runs the token counter if the cache misses.
*/
export async function getOrComputeToolTokens({
tools,
@ -129,11 +139,12 @@ export async function getOrComputeToolTokens({
clientOptions?: ClientOptions;
tokenCounter: TokenCounter;
}): Promise<number> {
const fingerprint = getToolFingerprint(tools, toolDefinitions);
if (!fingerprint) {
const { names, schemas } = collectToolData(tools, toolDefinitions);
if (names.length === 0) {
return 0;
}
const fingerprint = names.join(',') + '|' + names.length;
const multiplier = getToolTokenMultiplier(provider, clientOptions);
const multiplierKey = multiplier === ANTHROPIC_TOOL_TOKEN_MULTIPLIER ? 'anthropic' : 'default';
const cacheKey = `${provider}:${multiplierKey}:${fingerprint}`;
@ -148,13 +159,11 @@ export async function getOrComputeToolTokens({
logger.debug('[toolTokens] Cache read failed, computing fresh', err);
}
const tokens = computeToolSchemaTokens(
tools,
toolDefinitions,
provider,
clientOptions,
tokenCounter,
);
let toolTokens = 0;
for (const schema of schemas) {
toolTokens += tokenCounter(new SystemMessage(schema));
}
const tokens = Math.ceil(toolTokens * multiplier);
if (tokens > 0) {
cache.set(cacheKey, tokens).catch((err: unknown) => {