diff --git a/packages/api/src/agents/index.ts b/packages/api/src/agents/index.ts index 53f7f60a93..fbc46bfb3e 100644 --- a/packages/api/src/agents/index.ts +++ b/packages/api/src/agents/index.ts @@ -19,3 +19,4 @@ export * from './tools'; export * from './validation'; export * from './added'; export * from './load'; +export * from './toolTokens'; diff --git a/packages/api/src/agents/run.ts b/packages/api/src/agents/run.ts index b6b5e6a14d..1b67f2b337 100644 --- a/packages/api/src/agents/run.ts +++ b/packages/api/src/agents/run.ts @@ -1,3 +1,4 @@ +import { logger } from '@librechat/data-schemas'; import { Run, Providers, Constants } from '@librechat/agents'; import { providerEndpointMap, KnownEndpoints } from 'librechat-data-provider'; import type { @@ -18,6 +19,7 @@ import type { BaseMessage } from '@langchain/core/messages'; import type { IUser } from '@librechat/data-schemas'; import type * as t from '~/types'; import { resolveHeaders, createSafeUser } from '~/utils/env'; +import { getOrComputeToolTokens } from './toolTokens'; /** Expected shape of JSON tool search results */ interface ToolSearchJsonResult { @@ -295,8 +297,7 @@ export async function createRun({ ? 
extractDiscoveredToolsFromHistory(messages) : new Set(); - const agentInputs: AgentInputs[] = []; - const buildAgentContext = (agent: RunAgent) => { + const buildAgentContext = async (agent: RunAgent): Promise<AgentInputs> => { const provider = (providerEndpointMap[ agent.provider as keyof typeof providerEndpointMap @@ -381,11 +382,24 @@ export async function createRun({ agent.maxContextTokens, ); + let toolSchemaTokens: number | undefined; + if (tokenCounter) { + toolSchemaTokens = await getOrComputeToolTokens({ + tools: agent.tools, + toolDefinitions, + provider, + clientOptions: llmConfig, + tokenCounter, + tenantId: user?.tenantId, + }); + } + const reasoningKey = getReasoningKey(provider, llmConfig, agent.endpoint); const agentInput: AgentInputs = { provider, reasoningKey, toolDefinitions, + toolSchemaTokens, agentId: agent.id, tools: agent.tools, clientOptions: llmConfig, @@ -401,11 +415,35 @@ export async function createRun({ contextPruningConfig: summarization.contextPruning, maxToolResultChars: agent.maxToolResultChars, }; - agentInputs.push(agentInput); + return agentInput; }; - for (const agent of agents) { - buildAgentContext(agent); + const settled = await Promise.allSettled(agents.map(buildAgentContext)); + const agentInputs: AgentInputs[] = []; + for (let i = 0; i < settled.length; i++) { + const result = settled[i]; + if (result.status === 'fulfilled') { + agentInputs.push(result.value); + } else { + logger.error(`[createRun] buildAgentContext failed for agent ${agents[i].id}`, result.reason); + } + } + + if (agentInputs.length === 0) { + throw new Error( + `[createRun] All ${agents.length} agent(s) failed to initialize; cannot create run`, + ); + } + + const hasEdges = (agents[0].edges?.length ?? 
0) > 0; + if (agentInputs.length < agents.length && hasEdges) { + const failedIds = agents + .filter((_, i) => settled[i].status === 'rejected') + .map((a) => a.id) + .join(', '); + throw new Error( + `[createRun] Agent(s) [${failedIds}] failed in a routed multi-agent run; cannot proceed with partial graph`, + ); } const graphConfig: RunConfig['graphConfig'] = { diff --git a/packages/api/src/agents/toolTokens.spec.ts b/packages/api/src/agents/toolTokens.spec.ts new file mode 100644 index 0000000000..8d324ea58b --- /dev/null +++ b/packages/api/src/agents/toolTokens.spec.ts @@ -0,0 +1,412 @@ +import { z } from 'zod'; +import { DynamicStructuredTool } from '@langchain/core/tools'; +import { + Providers, + ANTHROPIC_TOOL_TOKEN_MULTIPLIER, + DEFAULT_TOOL_TOKEN_MULTIPLIER, +} from '@librechat/agents'; + +import type { GenericTool, LCTool, TokenCounter } from '@librechat/agents'; + +import { collectToolSchemas, computeToolSchemaTokens, getOrComputeToolTokens } from './toolTokens'; + +/* ---------- Mock standardCache with hoisted get/set for per-test overrides ---------- */ +const mockCacheStore = new Map<string, unknown>(); +const mockGet = jest.fn((key: string) => Promise.resolve(mockCacheStore.get(key))); +const mockSet = jest.fn((key: string, value: unknown) => { + mockCacheStore.set(key, value); + return Promise.resolve(true); +}); + +jest.mock('~/cache', () => ({ + standardCache: jest.fn(() => ({ get: mockGet, set: mockSet })), +})); + +jest.mock('@librechat/data-schemas', () => ({ + logger: { debug: jest.fn(), error: jest.fn(), warn: jest.fn(), info: jest.fn() }, +})); + +/* ---------- Helpers ---------- */ + +function makeTool(name: string, description = `${name} description`): GenericTool { + return new DynamicStructuredTool({ + name, + description, + schema: z.object({ input: z.string().optional() }), + func: async () => 'ok', + }) as unknown as GenericTool; +} + +function makeMcpTool(name: string): GenericTool { + const tool = makeTool(name) as unknown as Record<string, unknown>; + tool.mcp = 
true; + return tool as unknown as GenericTool; +} + +function makeToolDef(name: string, opts?: Partial<LCTool>): LCTool { + return { + name, + description: opts?.description ?? `${name} description`, + parameters: opts?.parameters ?? { type: 'object', properties: { input: { type: 'string' } } }, + ...opts, + }; +} + +/** Token counter that returns the string length of message content (deterministic). */ +const fakeTokenCounter: TokenCounter = (msg) => { + const content = typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content); + return content.length; +}; + +beforeEach(() => { + mockCacheStore.clear(); + mockGet.mockImplementation((key: string) => Promise.resolve(mockCacheStore.get(key))); + mockSet.mockImplementation((key: string, value: unknown) => { + mockCacheStore.set(key, value); + return Promise.resolve(true); + }); +}); + +/* ========================================================================= */ +/* collectToolSchemas */ +/* ========================================================================= */ + +describe('collectToolSchemas', () => { + it('returns empty array when no tools provided', () => { + expect(collectToolSchemas()).toHaveLength(0); + expect(collectToolSchemas([], [])).toHaveLength(0); + }); + + it('collects entries from GenericTool array', () => { + const entries = collectToolSchemas([makeTool('alpha'), makeTool('beta')]); + expect(entries).toHaveLength(2); + expect(entries.map((e) => e.cacheKey)).toEqual( + expect.arrayContaining(['alpha:builtin', 'beta:builtin']), + ); + }); + + it('collects entries from LCTool definitions with toolType', () => { + const defs = [makeToolDef('x', { toolType: 'mcp' }), makeToolDef('y', { toolType: 'action' })]; + const entries = collectToolSchemas(undefined, defs); + expect(entries).toHaveLength(2); + expect(entries[0].cacheKey).toBe('x:mcp'); + expect(entries[1].cacheKey).toBe('y:action'); + }); + + it('defaults toolType to builtin for LCTool without toolType', () => { + const entries = 
collectToolSchemas(undefined, [makeToolDef('z')]); + expect(entries[0].cacheKey).toBe('z:builtin'); + }); + + it('uses mcp type for GenericTool with mcp flag', () => { + const entries = collectToolSchemas([makeMcpTool('search')]); + expect(entries[0].cacheKey).toBe('search:mcp'); + }); + + it('deduplicates: GenericTool takes precedence over matching toolDefinition', () => { + const tools = [makeTool('shared')]; + const defs = [makeToolDef('shared'), makeToolDef('only_def')]; + const entries = collectToolSchemas(tools, defs); + expect(entries).toHaveLength(2); + const keys = entries.map((e) => e.cacheKey); + expect(keys).toContain('shared:builtin'); + expect(keys).toContain('only_def:builtin'); + }); +}); + +/* ========================================================================= */ +/* computeToolSchemaTokens */ +/* ========================================================================= */ + +describe('computeToolSchemaTokens', () => { + it('returns 0 when no tools provided', () => { + expect( + computeToolSchemaTokens(undefined, undefined, Providers.OPENAI, undefined, fakeTokenCounter), + ).toBe(0); + expect(computeToolSchemaTokens([], [], Providers.OPENAI, undefined, fakeTokenCounter)).toBe(0); + }); + + it('counts tokens from GenericTool schemas', () => { + const result = computeToolSchemaTokens( + [makeTool('test_tool')], + undefined, + Providers.OPENAI, + undefined, + fakeTokenCounter, + ); + expect(result).toBeGreaterThan(0); + }); + + it('counts tokens from LCTool definitions', () => { + const result = computeToolSchemaTokens( + undefined, + [makeToolDef('test_def')], + Providers.OPENAI, + undefined, + fakeTokenCounter, + ); + expect(result).toBeGreaterThan(0); + }); + + it('deduplicates: tool counted from tools array is skipped in toolDefinitions', () => { + const tools = [makeTool('shared')]; + const defs = [makeToolDef('shared')]; + + const toolsOnly = computeToolSchemaTokens( + tools, + undefined, + Providers.OPENAI, + undefined, + 
fakeTokenCounter, + ); + const both = computeToolSchemaTokens( + tools, + defs, + Providers.OPENAI, + undefined, + fakeTokenCounter, + ); + expect(both).toBe(toolsOnly); + }); + + it('applies Anthropic multiplier for Anthropic provider', () => { + const defs = [makeToolDef('tool')]; + const openai = computeToolSchemaTokens( + undefined, + defs, + Providers.OPENAI, + undefined, + fakeTokenCounter, + ); + const anthropic = computeToolSchemaTokens( + undefined, + defs, + Providers.ANTHROPIC, + undefined, + fakeTokenCounter, + ); + const expectedRatio = ANTHROPIC_TOOL_TOKEN_MULTIPLIER / DEFAULT_TOOL_TOKEN_MULTIPLIER; + expect(anthropic / openai).toBeCloseTo(expectedRatio, 1); + }); + + it('applies Anthropic multiplier when model name contains "claude"', () => { + const defs = [makeToolDef('tool')]; + const result = computeToolSchemaTokens( + undefined, + defs, + Providers.OPENAI, + { model: 'claude-3-opus' }, + fakeTokenCounter, + ); + const defaultResult = computeToolSchemaTokens( + undefined, + defs, + Providers.OPENAI, + undefined, + fakeTokenCounter, + ); + expect(result).toBeGreaterThan(defaultResult); + }); + + it('does not apply Anthropic multiplier for Bedrock even with claude model', () => { + const defs = [makeToolDef('tool')]; + const bedrock = computeToolSchemaTokens( + undefined, + defs, + Providers.BEDROCK, + { model: 'claude-3-opus' }, + fakeTokenCounter, + ); + const defaultResult = computeToolSchemaTokens( + undefined, + defs, + Providers.OPENAI, + undefined, + fakeTokenCounter, + ); + expect(bedrock).toBe(defaultResult); + }); +}); + +/* ========================================================================= */ +/* getOrComputeToolTokens */ +/* ========================================================================= */ + +describe('getOrComputeToolTokens', () => { + it('returns 0 when no tools provided', async () => { + const result = await getOrComputeToolTokens({ + provider: Providers.OPENAI, + tokenCounter: fakeTokenCounter, + }); + 
expect(result).toBe(0); + }); + + it('computes and caches each tool individually on first call', async () => { + const defs = [makeToolDef('tool_a'), makeToolDef('tool_b')]; + const result = await getOrComputeToolTokens({ + toolDefinitions: defs, + provider: Providers.OPENAI, + tokenCounter: fakeTokenCounter, + }); + + expect(result).toBeGreaterThan(0); + expect(mockCacheStore.has('tool_a:builtin')).toBe(true); + expect(mockCacheStore.has('tool_b:builtin')).toBe(true); + expect(mockCacheStore.size).toBe(2); + }); + + it('uses cached per-tool values on second call without recomputing', async () => { + const defs = [makeToolDef('tool_a')]; + const counter = jest.fn(fakeTokenCounter); + + const first = await getOrComputeToolTokens({ + toolDefinitions: defs, + provider: Providers.OPENAI, + tokenCounter: counter, + }); + + const callCountAfterFirst = counter.mock.calls.length; + + const second = await getOrComputeToolTokens({ + toolDefinitions: defs, + provider: Providers.OPENAI, + tokenCounter: counter, + }); + + expect(second).toBe(first); + expect(counter.mock.calls.length).toBe(callCountAfterFirst); + }); + + it('applies different multipliers for different providers on same cached raw counts', async () => { + const defs = [makeToolDef('tool')]; + + const openai = await getOrComputeToolTokens({ + toolDefinitions: defs, + provider: Providers.OPENAI, + tokenCounter: fakeTokenCounter, + }); + + const anthropic = await getOrComputeToolTokens({ + toolDefinitions: defs, + provider: Providers.ANTHROPIC, + tokenCounter: fakeTokenCounter, + }); + + expect(openai).not.toBe(anthropic); + expect(mockCacheStore.size).toBe(1); + }); + + it('only computes new tools when tool set grows', async () => { + const counter = jest.fn(fakeTokenCounter); + + await getOrComputeToolTokens({ + toolDefinitions: [makeToolDef('tool_a')], + provider: Providers.OPENAI, + tokenCounter: counter, + }); + const callsAfterFirst = counter.mock.calls.length; + + await getOrComputeToolTokens({ + 
toolDefinitions: [makeToolDef('tool_a'), makeToolDef('tool_b')], + provider: Providers.OPENAI, + tokenCounter: counter, + }); + + expect(counter.mock.calls.length).toBe(callsAfterFirst + 1); + expect(mockCacheStore.size).toBe(2); + }); + + it('scopes cache keys by tenantId when provided', async () => { + const defs = [makeToolDef('tool')]; + + await getOrComputeToolTokens({ + toolDefinitions: defs, + provider: Providers.OPENAI, + tokenCounter: fakeTokenCounter, + tenantId: 'tenant_123', + }); + + expect(mockCacheStore.has('tenant_123:tool:builtin')).toBe(true); + }); + + it('separates cache entries for different tenants', async () => { + const defs = [makeToolDef('tool')]; + + const t1 = await getOrComputeToolTokens({ + toolDefinitions: defs, + provider: Providers.OPENAI, + tokenCounter: fakeTokenCounter, + tenantId: 'tenant_1', + }); + + const t2 = await getOrComputeToolTokens({ + toolDefinitions: defs, + provider: Providers.OPENAI, + tokenCounter: fakeTokenCounter, + tenantId: 'tenant_2', + }); + + expect(t1).toBe(t2); + expect(mockCacheStore.has('tenant_1:tool:builtin')).toBe(true); + expect(mockCacheStore.has('tenant_2:tool:builtin')).toBe(true); + expect(mockCacheStore.size).toBe(2); + }); + + it('caches mcp tools with mcp type in key', async () => { + const defs = [makeToolDef('search', { toolType: 'mcp' })]; + + await getOrComputeToolTokens({ + toolDefinitions: defs, + provider: Providers.OPENAI, + tokenCounter: fakeTokenCounter, + }); + + expect(mockCacheStore.has('search:mcp')).toBe(true); + }); + + it('falls back to compute when cache read throws', async () => { + mockGet.mockRejectedValueOnce(new Error('Redis down')); + + const result = await getOrComputeToolTokens({ + toolDefinitions: [makeToolDef('tool')], + provider: Providers.OPENAI, + tokenCounter: fakeTokenCounter, + }); + + expect(result).toBeGreaterThan(0); + expect(mockGet).toHaveBeenCalled(); + }); + + it('does not throw when cache write fails', async () => { + mockSet.mockRejectedValueOnce(new 
Error('Redis write error')); + + const result = await getOrComputeToolTokens({ + toolDefinitions: [makeToolDef('tool_write_fail')], + provider: Providers.OPENAI, + tokenCounter: fakeTokenCounter, + }); + + expect(result).toBeGreaterThan(0); + expect(mockSet).toHaveBeenCalled(); + }); + + it('matches computeToolSchemaTokens output for same inputs', async () => { + const defs = [makeToolDef('a'), makeToolDef('b'), makeToolDef('c')]; + + const cached = await getOrComputeToolTokens({ + toolDefinitions: defs, + provider: Providers.OPENAI, + tokenCounter: fakeTokenCounter, + }); + + const direct = computeToolSchemaTokens( + undefined, + defs, + Providers.OPENAI, + undefined, + fakeTokenCounter, + ); + + expect(cached).toBe(direct); + }); +}); diff --git a/packages/api/src/agents/toolTokens.ts b/packages/api/src/agents/toolTokens.ts new file mode 100644 index 0000000000..4a95b07600 --- /dev/null +++ b/packages/api/src/agents/toolTokens.ts @@ -0,0 +1,201 @@ +import { logger } from '@librechat/data-schemas'; +import { SystemMessage } from '@langchain/core/messages'; +import { CacheKeys, Time } from 'librechat-data-provider'; +import { + Providers, + toJsonSchema, + ANTHROPIC_TOOL_TOKEN_MULTIPLIER, + DEFAULT_TOOL_TOKEN_MULTIPLIER, +} from '@librechat/agents'; + +import type { GenericTool, LCTool, TokenCounter, ClientOptions } from '@librechat/agents'; +import type { Keyv } from 'keyv'; + +import { standardCache } from '~/cache'; + +interface ToolEntry { + cacheKey: string; + json: string; +} + +/** Module-level cache instance, lazily initialized. 
*/ +let toolTokenCache: Keyv | undefined; + +function getCache(): Keyv { + if (!toolTokenCache) { + toolTokenCache = standardCache(CacheKeys.TOOL_TOKENS, Time.THIRTY_MINUTES); + } + return toolTokenCache; +} + +export function getToolTokenMultiplier(provider: Providers, clientOptions?: ClientOptions): number { + const isAnthropic = + provider !== Providers.BEDROCK && + (provider === Providers.ANTHROPIC || + /anthropic|claude/i.test( + String((clientOptions as { model?: string } | undefined)?.model ?? ''), + )); + return isAnthropic ? ANTHROPIC_TOOL_TOKEN_MULTIPLIER : DEFAULT_TOOL_TOKEN_MULTIPLIER; +} + +/** Serializes a GenericTool to a JSON string for token counting. Returns null if no schema. */ +function serializeGenericTool(tool: GenericTool): { name: string; json: string } | null { + const genericTool = tool as unknown as Record<string, unknown>; + const toolName = (genericTool.name as string | undefined) ?? ''; + if (genericTool.schema == null || typeof genericTool.schema !== 'object') { + return null; + } + const jsonSchema = toJsonSchema( + genericTool.schema, + toolName, + (genericTool.description as string | undefined) ?? '', + ); + return { name: toolName, json: JSON.stringify(jsonSchema) }; +} + +/** Serializes an LCTool definition to a JSON string for token counting. */ +function serializeToolDef(def: LCTool): string { + return JSON.stringify({ + type: 'function', + function: { + name: def.name, + description: def.description ?? '', + parameters: def.parameters ?? {}, + }, + }); +} + +/** + * Builds a list of tool entries with cache keys and serialized schemas. + * Deduplicates: a tool present in `tools` (with a schema) takes precedence + * over a matching `toolDefinitions` entry. + * + * Cache key includes toolType when available (from LCTool) to differentiate + * builtin/mcp/action tools that may share a name. + * GenericTool entries use the `mcp` flag when present. 
+ */ +export function collectToolSchemas(tools?: GenericTool[], toolDefinitions?: LCTool[]): ToolEntry[] { + const seen = new Set<string>(); + const entries: ToolEntry[] = []; + + if (tools) { + for (const tool of tools) { + const result = serializeGenericTool(tool); + if (!result || !result.name) { + continue; + } + seen.add(result.name); + const toolType = + (tool as unknown as Record<string, unknown>).mcp === true ? 'mcp' : 'builtin'; + entries.push({ cacheKey: `${result.name}:${toolType}`, json: result.json }); + } + } + + if (toolDefinitions) { + for (const def of toolDefinitions) { + if (!def.name || seen.has(def.name)) { + continue; + } + seen.add(def.name); + const toolType = def.toolType ?? 'builtin'; + entries.push({ cacheKey: `${def.name}:${toolType}`, json: serializeToolDef(def) }); + } + } + + return entries; +} + +/** + * Computes tool schema tokens from scratch using the provided token counter. + * Mirrors the logic in AgentContext.calculateInstructionTokens(). + */ +export function computeToolSchemaTokens( + tools: GenericTool[] | undefined, + toolDefinitions: LCTool[] | undefined, + provider: Providers, + clientOptions: ClientOptions | undefined, + tokenCounter: TokenCounter, +): number { + const entries = collectToolSchemas(tools, toolDefinitions); + let rawTokens = 0; + for (const { json } of entries) { + rawTokens += tokenCounter(new SystemMessage(json)); + } + const multiplier = getToolTokenMultiplier(provider, clientOptions); + return Math.ceil(rawTokens * multiplier); +} + +/** + * Returns tool schema tokens, using per-tool caching to avoid redundant + * token counting. Each tool's raw (pre-multiplier) token count is cached + * individually, keyed by `{tenantId}:{name}:{toolType}` (or `{name}:{toolType}` + * without tenant). The provider-specific multiplier is applied to the sum. + * + * Returns 0 if there are no tools. 
+ */ +export async function getOrComputeToolTokens({ + tools, + toolDefinitions, + provider, + clientOptions, + tokenCounter, + tenantId, +}: { + tools?: GenericTool[]; + toolDefinitions?: LCTool[]; + provider: Providers; + clientOptions?: ClientOptions; + tokenCounter: TokenCounter; + tenantId?: string; +}): Promise<number> { + const entries = collectToolSchemas(tools, toolDefinitions); + if (entries.length === 0) { + return 0; + } + + const keyPrefix = tenantId ? `${tenantId}:` : ''; + + let cache: Keyv | undefined; + try { + cache = getCache(); + } catch (err) { + logger.debug('[toolTokens] Cache init failed, computing fresh', err); + } + + let rawTotal = 0; + const toWrite: Array<{ key: string; value: number }> = []; + + for (const { cacheKey, json } of entries) { + const fullKey = `${keyPrefix}${cacheKey}`; + let rawCount: number | undefined; + + if (cache) { + try { + rawCount = (await cache.get(fullKey)) as number | undefined; + } catch (err) { + logger.debug(`[toolTokens] Cache read failed for ${fullKey}, computing fresh`, err); + } + } + + if (rawCount == null || rawCount <= 0) { + rawCount = tokenCounter(new SystemMessage(json)); + if (rawCount > 0 && cache) { + toWrite.push({ key: fullKey, value: rawCount }); + } + } + + rawTotal += rawCount; + } + + // Fire-and-forget cache writes for newly computed tools + if (cache && toWrite.length > 0) { + for (const { key, value } of toWrite) { + cache.set(key, value).catch((err: unknown) => { + logger.debug(`[toolTokens] Cache write failed for ${key}`, err); + }); + } + } + + const multiplier = getToolTokenMultiplier(provider, clientOptions); + return Math.ceil(rawTotal * multiplier); +} diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index ca40ec2c8c..dbed903a69 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -1564,6 +1564,11 @@ export enum CacheKeys { * Key for admin panel OAuth exchange codes (one-time-use, short TTL). 
*/ ADMIN_OAUTH_EXCHANGE = 'ADMIN_OAUTH_EXCHANGE', + /** + * Key for cached per-tool schema token counts (raw, before the provider multiplier). + * Entries are keyed by optional tenant ID, tool name, and tool type. + */ + TOOL_TOKENS = 'TOOL_TOKENS', } /**