diff --git a/packages/api/src/agents/toolTokens.spec.ts b/packages/api/src/agents/toolTokens.spec.ts index 27461e0d67..8d324ea58b 100644 --- a/packages/api/src/agents/toolTokens.spec.ts +++ b/packages/api/src/agents/toolTokens.spec.ts @@ -37,11 +37,18 @@ function makeTool(name: string, description = `${name} description`): GenericToo }) as unknown as GenericTool; } -function makeToolDef(name: string, description?: string): LCTool { +function makeMcpTool(name: string): GenericTool { + const tool = makeTool(name) as unknown as Record; + tool.mcp = true; + return tool as unknown as GenericTool; +} + +function makeToolDef(name: string, opts?: Partial): LCTool { return { name, - description: description ?? `${name} description`, - parameters: { type: 'object', properties: { input: { type: 'string' } } }, + description: opts?.description ?? `${name} description`, + parameters: opts?.parameters ?? { type: 'object', properties: { input: { type: 'string' } } }, + ...opts, }; } @@ -65,34 +72,45 @@ beforeEach(() => { /* ========================================================================= */ describe('collectToolSchemas', () => { - it('returns empty map when no tools provided', () => { - expect(collectToolSchemas().size).toBe(0); - expect(collectToolSchemas([], []).size).toBe(0); + it('returns empty array when no tools provided', () => { + expect(collectToolSchemas()).toHaveLength(0); + expect(collectToolSchemas([], [])).toHaveLength(0); }); - it('collects schemas from GenericTool array keyed by name', () => { - const tools = [makeTool('alpha'), makeTool('beta')]; - const schemas = collectToolSchemas(tools); - expect(schemas.size).toBe(2); - expect(schemas.has('alpha')).toBe(true); - expect(schemas.has('beta')).toBe(true); + it('collects entries from GenericTool array', () => { + const entries = collectToolSchemas([makeTool('alpha'), makeTool('beta')]); + expect(entries).toHaveLength(2); + expect(entries.map((e) => e.cacheKey)).toEqual( + expect.arrayContaining(['alpha:builtin', 'beta:builtin']), + ); }); - it('collects schemas from LCTool definitions', () => { - const defs = [makeToolDef('x'), makeToolDef('y')]; - const schemas = collectToolSchemas(undefined, defs); - expect(schemas.size).toBe(2); - expect(schemas.has('x')).toBe(true); - expect(schemas.has('y')).toBe(true); + it('collects entries from LCTool definitions with toolType', () => { + const defs = [makeToolDef('x', { toolType: 'mcp' }), makeToolDef('y', { toolType: 'action' })]; + const entries = collectToolSchemas(undefined, defs); + expect(entries).toHaveLength(2); + expect(entries[0].cacheKey).toBe('x:mcp'); + expect(entries[1].cacheKey).toBe('y:action'); + }); + + it('defaults toolType to builtin for LCTool without toolType', () => { + const entries = collectToolSchemas(undefined, [makeToolDef('z')]); + expect(entries[0].cacheKey).toBe('z:builtin'); + }); + + it('uses mcp type for GenericTool with mcp flag', () => { + const entries = collectToolSchemas([makeMcpTool('search')]); + expect(entries[0].cacheKey).toBe('search:mcp'); }); it('deduplicates: GenericTool takes precedence over matching toolDefinition', () => { const tools = [makeTool('shared')]; const defs = [makeToolDef('shared'), makeToolDef('only_def')]; - const schemas = collectToolSchemas(tools, defs); - expect(schemas.size).toBe(2); - expect(schemas.has('shared')).toBe(true); - expect(schemas.has('only_def')).toBe(true); + const entries = collectToolSchemas(tools, defs); + expect(entries).toHaveLength(2); + const keys = entries.map((e) => e.cacheKey); + expect(keys).toContain('shared:builtin'); + expect(keys).toContain('only_def:builtin'); }); }); @@ -109,9 +127,8 @@ describe('computeToolSchemaTokens', () => { }); it('counts tokens from GenericTool schemas', () => { - const tools = [makeTool('test_tool')]; const result = computeToolSchemaTokens( - tools, + [makeTool('test_tool')], undefined, Providers.OPENAI, undefined, @@ -121,10 +138,9 @@ describe('computeToolSchemaTokens', () => { }); it('counts tokens from LCTool definitions', () => { - const defs = [makeToolDef('test_def')]; const result = computeToolSchemaTokens( undefined, - defs, + [makeToolDef('test_def')], Providers.OPENAI, undefined, fakeTokenCounter, @@ -150,7 +166,6 @@ describe('computeToolSchemaTokens', () => { undefined, fakeTokenCounter, ); - expect(both).toBe(toolsOnly); }); @@ -170,19 +185,17 @@ describe('computeToolSchemaTokens', () => { undefined, fakeTokenCounter, ); - const expectedRatio = ANTHROPIC_TOOL_TOKEN_MULTIPLIER / DEFAULT_TOOL_TOKEN_MULTIPLIER; expect(anthropic / openai).toBeCloseTo(expectedRatio, 1); }); it('applies Anthropic multiplier when model name contains "claude"', () => { const defs = [makeToolDef('tool')]; - const clientOptions = { model: 'claude-3-opus' }; const result = computeToolSchemaTokens( undefined, defs, Providers.OPENAI, - clientOptions, + { model: 'claude-3-opus' }, fakeTokenCounter, ); const defaultResult = computeToolSchemaTokens( @@ -197,12 +210,11 @@ describe('computeToolSchemaTokens', () => { it('does not apply Anthropic multiplier for Bedrock even with claude model', () => { const defs = [makeToolDef('tool')]; - const clientOptions = { model: 'claude-3-opus' }; const bedrock = computeToolSchemaTokens( undefined, defs, Providers.BEDROCK, - clientOptions, + { model: 'claude-3-opus' }, fakeTokenCounter, ); const defaultResult = computeToolSchemaTokens( @@ -212,7 +224,6 @@ describe('computeToolSchemaTokens', () => { undefined, fakeTokenCounter, ); - expect(bedrock).toBe(defaultResult); }); }); @@ -239,8 +250,8 @@ describe('getOrComputeToolTokens', () => { }); expect(result).toBeGreaterThan(0); - expect(mockCacheStore.has('tool_a')).toBe(true); - expect(mockCacheStore.has('tool_b')).toBe(true); + expect(mockCacheStore.has('tool_a:builtin')).toBe(true); + expect(mockCacheStore.has('tool_b:builtin')).toBe(true); expect(mockCacheStore.size).toBe(2); }); @@ -282,7 +293,6 @@ describe('getOrComputeToolTokens', () => { }); expect(openai).not.toBe(anthropic); - // Only one cache entry — raw count is provider-agnostic expect(mockCacheStore.size).toBe(1); }); @@ -302,17 +312,63 @@ describe('getOrComputeToolTokens', () => { tokenCounter: counter, }); - // Only one new tokenCounter call for tool_b expect(counter.mock.calls.length).toBe(callsAfterFirst + 1); expect(mockCacheStore.size).toBe(2); }); + it('scopes cache keys by tenantId when provided', async () => { + const defs = [makeToolDef('tool')]; + + await getOrComputeToolTokens({ + toolDefinitions: defs, + provider: Providers.OPENAI, + tokenCounter: fakeTokenCounter, + tenantId: 'tenant_123', + }); + + expect(mockCacheStore.has('tenant_123:tool:builtin')).toBe(true); + }); + + it('separates cache entries for different tenants', async () => { + const defs = [makeToolDef('tool')]; + + const t1 = await getOrComputeToolTokens({ + toolDefinitions: defs, + provider: Providers.OPENAI, + tokenCounter: fakeTokenCounter, + tenantId: 'tenant_1', + }); + + const t2 = await getOrComputeToolTokens({ + toolDefinitions: defs, + provider: Providers.OPENAI, + tokenCounter: fakeTokenCounter, + tenantId: 'tenant_2', + }); + + expect(t1).toBe(t2); + expect(mockCacheStore.has('tenant_1:tool:builtin')).toBe(true); + expect(mockCacheStore.has('tenant_2:tool:builtin')).toBe(true); + expect(mockCacheStore.size).toBe(2); + }); + + it('caches mcp tools with mcp type in key', async () => { + const defs = [makeToolDef('search', { toolType: 'mcp' })]; + + await getOrComputeToolTokens({ + toolDefinitions: defs, + provider: Providers.OPENAI, + tokenCounter: fakeTokenCounter, + }); + + expect(mockCacheStore.has('search:mcp')).toBe(true); + }); + it('falls back to compute when cache read throws', async () => { mockGet.mockRejectedValueOnce(new Error('Redis down')); - const defs = [makeToolDef('tool')]; const result = await getOrComputeToolTokens({ - toolDefinitions: defs, + toolDefinitions: [makeToolDef('tool')], provider: Providers.OPENAI, tokenCounter: fakeTokenCounter, }); @@ -324,9 +380,8 @@ describe('getOrComputeToolTokens', () => { it('does not throw when cache write fails', async () => { mockSet.mockRejectedValueOnce(new Error('Redis write error')); - const defs = [makeToolDef('tool_write_fail')]; const result = await getOrComputeToolTokens({ - toolDefinitions: defs, + toolDefinitions: [makeToolDef('tool_write_fail')], provider: Providers.OPENAI, tokenCounter: fakeTokenCounter, }); @@ -335,20 +390,6 @@ describe('getOrComputeToolTokens', () => { expect(mockSet).toHaveBeenCalled(); }); - it('uses GenericTool tools for per-tool caching', async () => { - const tools = [makeTool('alpha'), makeTool('beta')]; - - const result = await getOrComputeToolTokens({ - tools, - provider: Providers.OPENAI, - tokenCounter: fakeTokenCounter, - }); - - expect(result).toBeGreaterThan(0); - expect(mockCacheStore.has('alpha')).toBe(true); - expect(mockCacheStore.has('beta')).toBe(true); - }); - it('matches computeToolSchemaTokens output for same inputs', async () => { const defs = [makeToolDef('a'), makeToolDef('b'), makeToolDef('c')]; diff --git a/packages/api/src/agents/toolTokens.ts b/packages/api/src/agents/toolTokens.ts index 22d88d0285..4a95b07600 100644 --- a/packages/api/src/agents/toolTokens.ts +++ b/packages/api/src/agents/toolTokens.ts @@ -13,6 +13,11 @@ import type { Keyv } from 'keyv'; import { standardCache } from '~/cache'; +interface ToolEntry { + cacheKey: string; + json: string; +} + /** Module-level cache instance, lazily initialized. */ let toolTokenCache: Keyv | undefined; @@ -61,35 +66,43 @@ function serializeToolDef(def: LCTool): string { } /** - * Builds a map of tool name → serialized schema JSON. Deduplicates: a tool - * present in `tools` (with a schema) takes precedence over a matching - * `toolDefinitions` entry, mirroring AgentContext.calculateInstructionTokens(). + * Builds a list of tool entries with cache keys and serialized schemas. + * Deduplicates: a tool present in `tools` (with a schema) takes precedence + * over a matching `toolDefinitions` entry. + * + * Cache key includes toolType when available (from LCTool) to differentiate + * builtin/mcp/action tools that may share a name. + * GenericTool entries use the `mcp` flag when present. */ -export function collectToolSchemas( - tools?: GenericTool[], - toolDefinitions?: LCTool[], -): Map { - const schemas = new Map(); +export function collectToolSchemas(tools?: GenericTool[], toolDefinitions?: LCTool[]): ToolEntry[] { + const seen = new Set(); + const entries: ToolEntry[] = []; if (tools) { for (const tool of tools) { const result = serializeGenericTool(tool); - if (result && result.name) { - schemas.set(result.name, result.json); + if (!result || !result.name) { + continue; } + seen.add(result.name); + const toolType = + (tool as unknown as Record).mcp === true ? 'mcp' : 'builtin'; + entries.push({ cacheKey: `${result.name}:${toolType}`, json: result.json }); } } if (toolDefinitions) { for (const def of toolDefinitions) { - if (!def.name || schemas.has(def.name)) { + if (!def.name || seen.has(def.name)) { continue; } - schemas.set(def.name, serializeToolDef(def)); + seen.add(def.name); + const toolType = def.toolType ?? 'builtin'; + entries.push({ cacheKey: `${def.name}:${toolType}`, json: serializeToolDef(def) }); } } - return schemas; + return entries; } /** @@ -103,9 +116,9 @@ export function computeToolSchemaTokens( clientOptions: ClientOptions | undefined, tokenCounter: TokenCounter, ): number { - const schemas = collectToolSchemas(tools, toolDefinitions); + const entries = collectToolSchemas(tools, toolDefinitions); let rawTokens = 0; - for (const json of schemas.values()) { + for (const { json } of entries) { rawTokens += tokenCounter(new SystemMessage(json)); } const multiplier = getToolTokenMultiplier(provider, clientOptions); @@ -115,8 +128,8 @@ export function computeToolSchemaTokens( /** * Returns tool schema tokens, using per-tool caching to avoid redundant * token counting. Each tool's raw (pre-multiplier) token count is cached - * individually by name, so adding/removing a tool only requires computing - * the new one. The provider-specific multiplier is applied to the sum. + * individually, keyed by `{tenantId}:{name}:{toolType}` (or `{name}:{toolType}` + * without tenant). The provider-specific multiplier is applied to the sum. * * Returns 0 if there are no tools. */ @@ -135,8 +148,8 @@ export async function getOrComputeToolTokens({ tokenCounter: TokenCounter; tenantId?: string; }): Promise { - const schemas = collectToolSchemas(tools, toolDefinitions); - if (schemas.size === 0) { + const entries = collectToolSchemas(tools, toolDefinitions); + if (entries.length === 0) { return 0; } @@ -152,13 +165,13 @@ export async function getOrComputeToolTokens({ let rawTotal = 0; const toWrite: Array<{ key: string; value: number }> = []; - for (const [name, json] of schemas) { - const cacheKey = `${keyPrefix}${name}`; + for (const { cacheKey, json } of entries) { + const fullKey = `${keyPrefix}${cacheKey}`; let rawCount: number | undefined; if (cache) { try { - rawCount = (await cache.get(cacheKey)) as number | undefined; + rawCount = (await cache.get(fullKey)) as number | undefined; } catch { // Cache read failed for this tool — will compute fresh } @@ -167,7 +180,7 @@ export async function getOrComputeToolTokens({ if (rawCount == null || rawCount <= 0) { rawCount = tokenCounter(new SystemMessage(json)); if (rawCount > 0 && cache) { - toWrite.push({ key: cacheKey, value: rawCount }); + toWrite.push({ key: fullKey, value: rawCount }); } }