mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-04-07 08:25:23 +02:00
refactor: single-pass collectToolData to avoid redundant tool iteration
Extract collectToolData() that builds both the fingerprint names and serialized schemas in one pass over tools + toolDefinitions. getOrComputeToolTokens uses the pre-collected schemas directly on cache miss instead of re-looping. getToolFingerprint and computeToolSchemaTokens delegate to the same shared function for standalone use.
This commit is contained in:
parent
749666503c
commit
90d32b691e
1 changed files with 84 additions and 75 deletions
|
|
@ -23,34 +23,6 @@ function getCache(): Keyv {
|
|||
return toolTokenCache;
|
||||
}
|
||||
|
||||
export function getToolFingerprint(tools?: GenericTool[], toolDefinitions?: LCTool[]): string {
|
||||
const names = new Set<string>();
|
||||
|
||||
if (tools) {
|
||||
for (const tool of tools) {
|
||||
const name = (tool as unknown as Record<string, unknown>).name;
|
||||
if (typeof name === 'string' && name) {
|
||||
names.add(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (toolDefinitions) {
|
||||
for (const def of toolDefinitions) {
|
||||
if (def.name) {
|
||||
names.add(def.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (names.size === 0) {
|
||||
return '';
|
||||
}
|
||||
|
||||
const sorted = Array.from(names).sort();
|
||||
return sorted.join(',') + '|' + sorted.length;
|
||||
}
|
||||
|
||||
function getToolTokenMultiplier(provider: Providers, clientOptions?: ClientOptions): number {
|
||||
const isAnthropic =
|
||||
provider !== Providers.BEDROCK &&
|
||||
|
|
@ -61,6 +33,77 @@ function getToolTokenMultiplier(provider: Providers, clientOptions?: ClientOptio
|
|||
return isAnthropic ? ANTHROPIC_TOOL_TOKEN_MULTIPLIER : DEFAULT_TOOL_TOKEN_MULTIPLIER;
|
||||
}
|
||||
|
||||
/**
|
||||
* Single pass over tools and toolDefinitions. Collects deduplicated sorted
|
||||
* tool names (for fingerprint) and pre-serialized schemas (for token
|
||||
* counting on cache miss), mirroring the dedup logic in
|
||||
* AgentContext.calculateInstructionTokens().
|
||||
*/
|
||||
function collectToolData(
|
||||
tools?: GenericTool[],
|
||||
toolDefinitions?: LCTool[],
|
||||
): { names: string[]; schemas: string[] } {
|
||||
const nameSet = new Set<string>();
|
||||
const countedNames = new Set<string>();
|
||||
const schemas: string[] = [];
|
||||
|
||||
if (tools) {
|
||||
for (const tool of tools) {
|
||||
const genericTool = tool as unknown as Record<string, unknown>;
|
||||
const toolName = (genericTool.name as string | undefined) ?? '';
|
||||
if (toolName) {
|
||||
nameSet.add(toolName);
|
||||
}
|
||||
if (genericTool.schema != null && typeof genericTool.schema === 'object') {
|
||||
schemas.push(
|
||||
JSON.stringify(
|
||||
toJsonSchema(
|
||||
genericTool.schema,
|
||||
toolName,
|
||||
(genericTool.description as string | undefined) ?? '',
|
||||
),
|
||||
),
|
||||
);
|
||||
if (toolName) {
|
||||
countedNames.add(toolName);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (toolDefinitions) {
|
||||
for (const def of toolDefinitions) {
|
||||
if (def.name) {
|
||||
nameSet.add(def.name);
|
||||
}
|
||||
if (countedNames.has(def.name)) {
|
||||
continue;
|
||||
}
|
||||
schemas.push(
|
||||
JSON.stringify({
|
||||
type: 'function',
|
||||
function: {
|
||||
name: def.name,
|
||||
description: def.description ?? '',
|
||||
parameters: def.parameters ?? {},
|
||||
},
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const names = nameSet.size > 0 ? Array.from(nameSet).sort() : [];
|
||||
return { names, schemas };
|
||||
}
|
||||
|
||||
export function getToolFingerprint(tools?: GenericTool[], toolDefinitions?: LCTool[]): string {
|
||||
const { names } = collectToolData(tools, toolDefinitions);
|
||||
if (names.length === 0) {
|
||||
return '';
|
||||
}
|
||||
return names.join(',') + '|' + names.length;
|
||||
}
|
||||
|
||||
export function computeToolSchemaTokens(
|
||||
tools: GenericTool[] | undefined,
|
||||
toolDefinitions: LCTool[] | undefined,
|
||||
|
|
@ -68,44 +111,11 @@ export function computeToolSchemaTokens(
|
|||
clientOptions: ClientOptions | undefined,
|
||||
tokenCounter: TokenCounter,
|
||||
): number {
|
||||
const { schemas } = collectToolData(tools, toolDefinitions);
|
||||
let toolTokens = 0;
|
||||
const countedToolNames = new Set<string>();
|
||||
|
||||
if (tools && tools.length > 0) {
|
||||
for (const tool of tools) {
|
||||
const genericTool = tool as unknown as Record<string, unknown>;
|
||||
if (genericTool.schema != null && typeof genericTool.schema === 'object') {
|
||||
const toolName = (genericTool.name as string | undefined) ?? '';
|
||||
const jsonSchema = toJsonSchema(
|
||||
genericTool.schema,
|
||||
toolName,
|
||||
(genericTool.description as string | undefined) ?? '',
|
||||
);
|
||||
toolTokens += tokenCounter(new SystemMessage(JSON.stringify(jsonSchema)));
|
||||
if (toolName) {
|
||||
countedToolNames.add(toolName);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (const schema of schemas) {
|
||||
toolTokens += tokenCounter(new SystemMessage(schema));
|
||||
}
|
||||
|
||||
if (toolDefinitions && toolDefinitions.length > 0) {
|
||||
for (const def of toolDefinitions) {
|
||||
if (countedToolNames.has(def.name)) {
|
||||
continue;
|
||||
}
|
||||
const schema = {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: def.name,
|
||||
description: def.description ?? '',
|
||||
parameters: def.parameters ?? {},
|
||||
},
|
||||
};
|
||||
toolTokens += tokenCounter(new SystemMessage(JSON.stringify(schema)));
|
||||
}
|
||||
}
|
||||
|
||||
const multiplier = getToolTokenMultiplier(provider, clientOptions);
|
||||
return Math.ceil(toolTokens * multiplier);
|
||||
}
|
||||
|
|
@ -113,8 +123,8 @@ export function computeToolSchemaTokens(
|
|||
/**
|
||||
* Returns cached tool schema tokens or computes them on miss.
|
||||
* Returns 0 if there are no tools.
|
||||
* Cache errors are non-fatal — falls through to compute on read failure,
|
||||
* logs on write failure.
|
||||
* Single pass over tool arrays: builds fingerprint and serialized schemas
|
||||
* together, then only runs the token counter if the cache misses.
|
||||
*/
|
||||
export async function getOrComputeToolTokens({
|
||||
tools,
|
||||
|
|
@ -129,11 +139,12 @@ export async function getOrComputeToolTokens({
|
|||
clientOptions?: ClientOptions;
|
||||
tokenCounter: TokenCounter;
|
||||
}): Promise<number> {
|
||||
const fingerprint = getToolFingerprint(tools, toolDefinitions);
|
||||
if (!fingerprint) {
|
||||
const { names, schemas } = collectToolData(tools, toolDefinitions);
|
||||
if (names.length === 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const fingerprint = names.join(',') + '|' + names.length;
|
||||
const multiplier = getToolTokenMultiplier(provider, clientOptions);
|
||||
const multiplierKey = multiplier === ANTHROPIC_TOOL_TOKEN_MULTIPLIER ? 'anthropic' : 'default';
|
||||
const cacheKey = `${provider}:${multiplierKey}:${fingerprint}`;
|
||||
|
|
@ -148,13 +159,11 @@ export async function getOrComputeToolTokens({
|
|||
logger.debug('[toolTokens] Cache read failed, computing fresh', err);
|
||||
}
|
||||
|
||||
const tokens = computeToolSchemaTokens(
|
||||
tools,
|
||||
toolDefinitions,
|
||||
provider,
|
||||
clientOptions,
|
||||
tokenCounter,
|
||||
);
|
||||
let toolTokens = 0;
|
||||
for (const schema of schemas) {
|
||||
toolTokens += tokenCounter(new SystemMessage(schema));
|
||||
}
|
||||
const tokens = Math.ceil(toolTokens * multiplier);
|
||||
|
||||
if (tokens > 0) {
|
||||
cache.set(cacheKey, tokens).catch((err: unknown) => {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue