fix: include toolType in per-tool cache key

Cache key is now {tenantId}:{name}:{toolType} (or {name}:{toolType}
without tenant). This differentiates builtin/mcp/action tools that
may share a name but have different schemas.

GenericTool entries derive type from the mcp flag; LCTool entries
use the toolType field (defaulting to builtin).

Also refactors collectToolSchemas to return ToolEntry[] with
pre-built cache keys instead of Map<name, json>.
This commit is contained in:
Danny Avila 2026-04-03 14:08:03 -04:00
parent 8db4f21f97
commit e2962f4967
2 changed files with 132 additions and 78 deletions

View file

@ -37,11 +37,18 @@ function makeTool(name: string, description = `${name} description`): GenericToo
}) as unknown as GenericTool;
}
function makeToolDef(name: string, description?: string): LCTool {
function makeMcpTool(name: string): GenericTool {
const tool = makeTool(name) as unknown as Record<string, unknown>;
tool.mcp = true;
return tool as unknown as GenericTool;
}
function makeToolDef(name: string, opts?: Partial<LCTool>): LCTool {
return {
name,
description: description ?? `${name} description`,
parameters: { type: 'object', properties: { input: { type: 'string' } } },
description: opts?.description ?? `${name} description`,
parameters: opts?.parameters ?? { type: 'object', properties: { input: { type: 'string' } } },
...opts,
};
}
@ -65,34 +72,45 @@ beforeEach(() => {
/* ========================================================================= */
describe('collectToolSchemas', () => {
it('returns empty map when no tools provided', () => {
expect(collectToolSchemas().size).toBe(0);
expect(collectToolSchemas([], []).size).toBe(0);
it('returns empty array when no tools provided', () => {
expect(collectToolSchemas()).toHaveLength(0);
expect(collectToolSchemas([], [])).toHaveLength(0);
});
it('collects schemas from GenericTool array keyed by name', () => {
const tools = [makeTool('alpha'), makeTool('beta')];
const schemas = collectToolSchemas(tools);
expect(schemas.size).toBe(2);
expect(schemas.has('alpha')).toBe(true);
expect(schemas.has('beta')).toBe(true);
it('collects entries from GenericTool array', () => {
const entries = collectToolSchemas([makeTool('alpha'), makeTool('beta')]);
expect(entries).toHaveLength(2);
expect(entries.map((e) => e.cacheKey)).toEqual(
expect.arrayContaining(['alpha:builtin', 'beta:builtin']),
);
});
it('collects schemas from LCTool definitions', () => {
const defs = [makeToolDef('x'), makeToolDef('y')];
const schemas = collectToolSchemas(undefined, defs);
expect(schemas.size).toBe(2);
expect(schemas.has('x')).toBe(true);
expect(schemas.has('y')).toBe(true);
it('collects entries from LCTool definitions with toolType', () => {
const defs = [makeToolDef('x', { toolType: 'mcp' }), makeToolDef('y', { toolType: 'action' })];
const entries = collectToolSchemas(undefined, defs);
expect(entries).toHaveLength(2);
expect(entries[0].cacheKey).toBe('x:mcp');
expect(entries[1].cacheKey).toBe('y:action');
});
it('defaults toolType to builtin for LCTool without toolType', () => {
const entries = collectToolSchemas(undefined, [makeToolDef('z')]);
expect(entries[0].cacheKey).toBe('z:builtin');
});
it('uses mcp type for GenericTool with mcp flag', () => {
const entries = collectToolSchemas([makeMcpTool('search')]);
expect(entries[0].cacheKey).toBe('search:mcp');
});
it('deduplicates: GenericTool takes precedence over matching toolDefinition', () => {
const tools = [makeTool('shared')];
const defs = [makeToolDef('shared'), makeToolDef('only_def')];
const schemas = collectToolSchemas(tools, defs);
expect(schemas.size).toBe(2);
expect(schemas.has('shared')).toBe(true);
expect(schemas.has('only_def')).toBe(true);
const entries = collectToolSchemas(tools, defs);
expect(entries).toHaveLength(2);
const keys = entries.map((e) => e.cacheKey);
expect(keys).toContain('shared:builtin');
expect(keys).toContain('only_def:builtin');
});
});
@ -109,9 +127,8 @@ describe('computeToolSchemaTokens', () => {
});
it('counts tokens from GenericTool schemas', () => {
const tools = [makeTool('test_tool')];
const result = computeToolSchemaTokens(
tools,
[makeTool('test_tool')],
undefined,
Providers.OPENAI,
undefined,
@ -121,10 +138,9 @@ describe('computeToolSchemaTokens', () => {
});
it('counts tokens from LCTool definitions', () => {
const defs = [makeToolDef('test_def')];
const result = computeToolSchemaTokens(
undefined,
defs,
[makeToolDef('test_def')],
Providers.OPENAI,
undefined,
fakeTokenCounter,
@ -150,7 +166,6 @@ describe('computeToolSchemaTokens', () => {
undefined,
fakeTokenCounter,
);
expect(both).toBe(toolsOnly);
});
@ -170,19 +185,17 @@ describe('computeToolSchemaTokens', () => {
undefined,
fakeTokenCounter,
);
const expectedRatio = ANTHROPIC_TOOL_TOKEN_MULTIPLIER / DEFAULT_TOOL_TOKEN_MULTIPLIER;
expect(anthropic / openai).toBeCloseTo(expectedRatio, 1);
});
it('applies Anthropic multiplier when model name contains "claude"', () => {
const defs = [makeToolDef('tool')];
const clientOptions = { model: 'claude-3-opus' };
const result = computeToolSchemaTokens(
undefined,
defs,
Providers.OPENAI,
clientOptions,
{ model: 'claude-3-opus' },
fakeTokenCounter,
);
const defaultResult = computeToolSchemaTokens(
@ -197,12 +210,11 @@ describe('computeToolSchemaTokens', () => {
it('does not apply Anthropic multiplier for Bedrock even with claude model', () => {
const defs = [makeToolDef('tool')];
const clientOptions = { model: 'claude-3-opus' };
const bedrock = computeToolSchemaTokens(
undefined,
defs,
Providers.BEDROCK,
clientOptions,
{ model: 'claude-3-opus' },
fakeTokenCounter,
);
const defaultResult = computeToolSchemaTokens(
@ -212,7 +224,6 @@ describe('computeToolSchemaTokens', () => {
undefined,
fakeTokenCounter,
);
expect(bedrock).toBe(defaultResult);
});
});
@ -239,8 +250,8 @@ describe('getOrComputeToolTokens', () => {
});
expect(result).toBeGreaterThan(0);
expect(mockCacheStore.has('tool_a')).toBe(true);
expect(mockCacheStore.has('tool_b')).toBe(true);
expect(mockCacheStore.has('tool_a:builtin')).toBe(true);
expect(mockCacheStore.has('tool_b:builtin')).toBe(true);
expect(mockCacheStore.size).toBe(2);
});
@ -282,7 +293,6 @@ describe('getOrComputeToolTokens', () => {
});
expect(openai).not.toBe(anthropic);
// Only one cache entry — raw count is provider-agnostic
expect(mockCacheStore.size).toBe(1);
});
@ -302,17 +312,63 @@ describe('getOrComputeToolTokens', () => {
tokenCounter: counter,
});
// Only one new tokenCounter call for tool_b
expect(counter.mock.calls.length).toBe(callsAfterFirst + 1);
expect(mockCacheStore.size).toBe(2);
});
it('scopes cache keys by tenantId when provided', async () => {
const defs = [makeToolDef('tool')];
await getOrComputeToolTokens({
toolDefinitions: defs,
provider: Providers.OPENAI,
tokenCounter: fakeTokenCounter,
tenantId: 'tenant_123',
});
expect(mockCacheStore.has('tenant_123:tool:builtin')).toBe(true);
});
it('separates cache entries for different tenants', async () => {
const defs = [makeToolDef('tool')];
const t1 = await getOrComputeToolTokens({
toolDefinitions: defs,
provider: Providers.OPENAI,
tokenCounter: fakeTokenCounter,
tenantId: 'tenant_1',
});
const t2 = await getOrComputeToolTokens({
toolDefinitions: defs,
provider: Providers.OPENAI,
tokenCounter: fakeTokenCounter,
tenantId: 'tenant_2',
});
expect(t1).toBe(t2);
expect(mockCacheStore.has('tenant_1:tool:builtin')).toBe(true);
expect(mockCacheStore.has('tenant_2:tool:builtin')).toBe(true);
expect(mockCacheStore.size).toBe(2);
});
it('caches mcp tools with mcp type in key', async () => {
const defs = [makeToolDef('search', { toolType: 'mcp' })];
await getOrComputeToolTokens({
toolDefinitions: defs,
provider: Providers.OPENAI,
tokenCounter: fakeTokenCounter,
});
expect(mockCacheStore.has('search:mcp')).toBe(true);
});
it('falls back to compute when cache read throws', async () => {
mockGet.mockRejectedValueOnce(new Error('Redis down'));
const defs = [makeToolDef('tool')];
const result = await getOrComputeToolTokens({
toolDefinitions: defs,
toolDefinitions: [makeToolDef('tool')],
provider: Providers.OPENAI,
tokenCounter: fakeTokenCounter,
});
@ -324,9 +380,8 @@ describe('getOrComputeToolTokens', () => {
it('does not throw when cache write fails', async () => {
mockSet.mockRejectedValueOnce(new Error('Redis write error'));
const defs = [makeToolDef('tool_write_fail')];
const result = await getOrComputeToolTokens({
toolDefinitions: defs,
toolDefinitions: [makeToolDef('tool_write_fail')],
provider: Providers.OPENAI,
tokenCounter: fakeTokenCounter,
});
@ -335,20 +390,6 @@ describe('getOrComputeToolTokens', () => {
expect(mockSet).toHaveBeenCalled();
});
it('uses GenericTool tools for per-tool caching', async () => {
const tools = [makeTool('alpha'), makeTool('beta')];
const result = await getOrComputeToolTokens({
tools,
provider: Providers.OPENAI,
tokenCounter: fakeTokenCounter,
});
expect(result).toBeGreaterThan(0);
expect(mockCacheStore.has('alpha')).toBe(true);
expect(mockCacheStore.has('beta')).toBe(true);
});
it('matches computeToolSchemaTokens output for same inputs', async () => {
const defs = [makeToolDef('a'), makeToolDef('b'), makeToolDef('c')];

View file

@ -13,6 +13,11 @@ import type { Keyv } from 'keyv';
import { standardCache } from '~/cache';
interface ToolEntry {
cacheKey: string;
json: string;
}
/** Module-level cache instance, lazily initialized. */
let toolTokenCache: Keyv | undefined;
@ -61,35 +66,43 @@ function serializeToolDef(def: LCTool): string {
}
/**
* Builds a map of tool name serialized schema JSON. Deduplicates: a tool
* present in `tools` (with a schema) takes precedence over a matching
* `toolDefinitions` entry, mirroring AgentContext.calculateInstructionTokens().
* Builds a list of tool entries with cache keys and serialized schemas.
* Deduplicates: a tool present in `tools` (with a schema) takes precedence
* over a matching `toolDefinitions` entry.
*
* Cache key includes toolType when available (from LCTool) to differentiate
* builtin/mcp/action tools that may share a name.
* GenericTool entries use the `mcp` flag when present.
*/
export function collectToolSchemas(
tools?: GenericTool[],
toolDefinitions?: LCTool[],
): Map<string, string> {
const schemas = new Map<string, string>();
export function collectToolSchemas(tools?: GenericTool[], toolDefinitions?: LCTool[]): ToolEntry[] {
const seen = new Set<string>();
const entries: ToolEntry[] = [];
if (tools) {
for (const tool of tools) {
const result = serializeGenericTool(tool);
if (result && result.name) {
schemas.set(result.name, result.json);
if (!result || !result.name) {
continue;
}
seen.add(result.name);
const toolType =
(tool as unknown as Record<string, unknown>).mcp === true ? 'mcp' : 'builtin';
entries.push({ cacheKey: `${result.name}:${toolType}`, json: result.json });
}
}
if (toolDefinitions) {
for (const def of toolDefinitions) {
if (!def.name || schemas.has(def.name)) {
if (!def.name || seen.has(def.name)) {
continue;
}
schemas.set(def.name, serializeToolDef(def));
seen.add(def.name);
const toolType = def.toolType ?? 'builtin';
entries.push({ cacheKey: `${def.name}:${toolType}`, json: serializeToolDef(def) });
}
}
return schemas;
return entries;
}
/**
@ -103,9 +116,9 @@ export function computeToolSchemaTokens(
clientOptions: ClientOptions | undefined,
tokenCounter: TokenCounter,
): number {
const schemas = collectToolSchemas(tools, toolDefinitions);
const entries = collectToolSchemas(tools, toolDefinitions);
let rawTokens = 0;
for (const json of schemas.values()) {
for (const { json } of entries) {
rawTokens += tokenCounter(new SystemMessage(json));
}
const multiplier = getToolTokenMultiplier(provider, clientOptions);
@ -115,8 +128,8 @@ export function computeToolSchemaTokens(
/**
* Returns tool schema tokens, using per-tool caching to avoid redundant
* token counting. Each tool's raw (pre-multiplier) token count is cached
* individually by name, so adding/removing a tool only requires computing
* the new one. The provider-specific multiplier is applied to the sum.
* individually, keyed by `{tenantId}:{name}:{toolType}` (or `{name}:{toolType}`
* without tenant). The provider-specific multiplier is applied to the sum.
*
* Returns 0 if there are no tools.
*/
@ -135,8 +148,8 @@ export async function getOrComputeToolTokens({
tokenCounter: TokenCounter;
tenantId?: string;
}): Promise<number> {
const schemas = collectToolSchemas(tools, toolDefinitions);
if (schemas.size === 0) {
const entries = collectToolSchemas(tools, toolDefinitions);
if (entries.length === 0) {
return 0;
}
@ -152,13 +165,13 @@ export async function getOrComputeToolTokens({
let rawTotal = 0;
const toWrite: Array<{ key: string; value: number }> = [];
for (const [name, json] of schemas) {
const cacheKey = `${keyPrefix}${name}`;
for (const { cacheKey, json } of entries) {
const fullKey = `${keyPrefix}${cacheKey}`;
let rawCount: number | undefined;
if (cache) {
try {
rawCount = (await cache.get(cacheKey)) as number | undefined;
rawCount = (await cache.get(fullKey)) as number | undefined;
} catch {
// Cache read failed for this tool — will compute fresh
}
@ -167,7 +180,7 @@ export async function getOrComputeToolTokens({
if (rawCount == null || rawCount <= 0) {
rawCount = tokenCounter(new SystemMessage(json));
if (rawCount > 0 && cache) {
toWrite.push({ key: cacheKey, value: rawCount });
toWrite.push({ key: fullKey, value: rawCount });
}
}