mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-04-08 08:54:38 +02:00
Merge e2962f4967 into 8ed0bcf5ca
This commit is contained in:
commit
e3d54652ff
5 changed files with 662 additions and 5 deletions
|
|
@ -19,3 +19,4 @@ export * from './tools';
|
||||||
export * from './validation';
|
export * from './validation';
|
||||||
export * from './added';
|
export * from './added';
|
||||||
export * from './load';
|
export * from './load';
|
||||||
|
export * from './toolTokens';
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import { logger } from '@librechat/data-schemas';
|
||||||
import { Run, Providers, Constants } from '@librechat/agents';
|
import { Run, Providers, Constants } from '@librechat/agents';
|
||||||
import { providerEndpointMap, KnownEndpoints } from 'librechat-data-provider';
|
import { providerEndpointMap, KnownEndpoints } from 'librechat-data-provider';
|
||||||
import type {
|
import type {
|
||||||
|
|
@ -18,6 +19,7 @@ import type { BaseMessage } from '@langchain/core/messages';
|
||||||
import type { IUser } from '@librechat/data-schemas';
|
import type { IUser } from '@librechat/data-schemas';
|
||||||
import type * as t from '~/types';
|
import type * as t from '~/types';
|
||||||
import { resolveHeaders, createSafeUser } from '~/utils/env';
|
import { resolveHeaders, createSafeUser } from '~/utils/env';
|
||||||
|
import { getOrComputeToolTokens } from './toolTokens';
|
||||||
|
|
||||||
/** Expected shape of JSON tool search results */
|
/** Expected shape of JSON tool search results */
|
||||||
interface ToolSearchJsonResult {
|
interface ToolSearchJsonResult {
|
||||||
|
|
@ -295,8 +297,7 @@ export async function createRun({
|
||||||
? extractDiscoveredToolsFromHistory(messages)
|
? extractDiscoveredToolsFromHistory(messages)
|
||||||
: new Set<string>();
|
: new Set<string>();
|
||||||
|
|
||||||
const agentInputs: AgentInputs[] = [];
|
const buildAgentContext = async (agent: RunAgent): Promise<AgentInputs> => {
|
||||||
const buildAgentContext = (agent: RunAgent) => {
|
|
||||||
const provider =
|
const provider =
|
||||||
(providerEndpointMap[
|
(providerEndpointMap[
|
||||||
agent.provider as keyof typeof providerEndpointMap
|
agent.provider as keyof typeof providerEndpointMap
|
||||||
|
|
@ -381,11 +382,24 @@ export async function createRun({
|
||||||
agent.maxContextTokens,
|
agent.maxContextTokens,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let toolSchemaTokens: number | undefined;
|
||||||
|
if (tokenCounter) {
|
||||||
|
toolSchemaTokens = await getOrComputeToolTokens({
|
||||||
|
tools: agent.tools,
|
||||||
|
toolDefinitions,
|
||||||
|
provider,
|
||||||
|
clientOptions: llmConfig,
|
||||||
|
tokenCounter,
|
||||||
|
tenantId: user?.tenantId,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
const reasoningKey = getReasoningKey(provider, llmConfig, agent.endpoint);
|
const reasoningKey = getReasoningKey(provider, llmConfig, agent.endpoint);
|
||||||
const agentInput: AgentInputs = {
|
const agentInput: AgentInputs = {
|
||||||
provider,
|
provider,
|
||||||
reasoningKey,
|
reasoningKey,
|
||||||
toolDefinitions,
|
toolDefinitions,
|
||||||
|
toolSchemaTokens,
|
||||||
agentId: agent.id,
|
agentId: agent.id,
|
||||||
tools: agent.tools,
|
tools: agent.tools,
|
||||||
clientOptions: llmConfig,
|
clientOptions: llmConfig,
|
||||||
|
|
@ -401,11 +415,35 @@ export async function createRun({
|
||||||
contextPruningConfig: summarization.contextPruning,
|
contextPruningConfig: summarization.contextPruning,
|
||||||
maxToolResultChars: agent.maxToolResultChars,
|
maxToolResultChars: agent.maxToolResultChars,
|
||||||
};
|
};
|
||||||
agentInputs.push(agentInput);
|
return agentInput;
|
||||||
};
|
};
|
||||||
|
|
||||||
for (const agent of agents) {
|
const settled = await Promise.allSettled(agents.map(buildAgentContext));
|
||||||
buildAgentContext(agent);
|
const agentInputs: AgentInputs[] = [];
|
||||||
|
for (let i = 0; i < settled.length; i++) {
|
||||||
|
const result = settled[i];
|
||||||
|
if (result.status === 'fulfilled') {
|
||||||
|
agentInputs.push(result.value);
|
||||||
|
} else {
|
||||||
|
logger.error(`[createRun] buildAgentContext failed for agent ${agents[i].id}`, result.reason);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (agentInputs.length === 0) {
|
||||||
|
throw new Error(
|
||||||
|
`[createRun] All ${agents.length} agent(s) failed to initialize; cannot create run`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const hasEdges = (agents[0].edges?.length ?? 0) > 0;
|
||||||
|
if (agentInputs.length < agents.length && hasEdges) {
|
||||||
|
const failedIds = agents
|
||||||
|
.filter((_, i) => settled[i].status === 'rejected')
|
||||||
|
.map((a) => a.id)
|
||||||
|
.join(', ');
|
||||||
|
throw new Error(
|
||||||
|
`[createRun] Agent(s) [${failedIds}] failed in a routed multi-agent run; cannot proceed with partial graph`,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const graphConfig: RunConfig['graphConfig'] = {
|
const graphConfig: RunConfig['graphConfig'] = {
|
||||||
|
|
|
||||||
412
packages/api/src/agents/toolTokens.spec.ts
Normal file
412
packages/api/src/agents/toolTokens.spec.ts
Normal file
|
|
@ -0,0 +1,412 @@
|
||||||
|
import { z } from 'zod';
|
||||||
|
import { DynamicStructuredTool } from '@langchain/core/tools';
|
||||||
|
import {
|
||||||
|
Providers,
|
||||||
|
ANTHROPIC_TOOL_TOKEN_MULTIPLIER,
|
||||||
|
DEFAULT_TOOL_TOKEN_MULTIPLIER,
|
||||||
|
} from '@librechat/agents';
|
||||||
|
|
||||||
|
import type { GenericTool, LCTool, TokenCounter } from '@librechat/agents';
|
||||||
|
|
||||||
|
import { collectToolSchemas, computeToolSchemaTokens, getOrComputeToolTokens } from './toolTokens';
|
||||||
|
|
||||||
|
/* ---------- Mock standardCache with hoisted get/set for per-test overrides ---------- */
|
||||||
|
const mockCacheStore = new Map<string, unknown>();
|
||||||
|
const mockGet = jest.fn((key: string) => Promise.resolve(mockCacheStore.get(key)));
|
||||||
|
const mockSet = jest.fn((key: string, value: unknown) => {
|
||||||
|
mockCacheStore.set(key, value);
|
||||||
|
return Promise.resolve(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
jest.mock('~/cache', () => ({
|
||||||
|
standardCache: jest.fn(() => ({ get: mockGet, set: mockSet })),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
|
logger: { debug: jest.fn(), error: jest.fn(), warn: jest.fn(), info: jest.fn() },
|
||||||
|
}));
|
||||||
|
|
||||||
|
/* ---------- Helpers ---------- */
|
||||||
|
|
||||||
|
function makeTool(name: string, description = `${name} description`): GenericTool {
|
||||||
|
return new DynamicStructuredTool({
|
||||||
|
name,
|
||||||
|
description,
|
||||||
|
schema: z.object({ input: z.string().optional() }),
|
||||||
|
func: async () => 'ok',
|
||||||
|
}) as unknown as GenericTool;
|
||||||
|
}
|
||||||
|
|
||||||
|
function makeMcpTool(name: string): GenericTool {
|
||||||
|
const tool = makeTool(name) as unknown as Record<string, unknown>;
|
||||||
|
tool.mcp = true;
|
||||||
|
return tool as unknown as GenericTool;
|
||||||
|
}
|
||||||
|
|
||||||
|
function makeToolDef(name: string, opts?: Partial<LCTool>): LCTool {
|
||||||
|
return {
|
||||||
|
name,
|
||||||
|
description: opts?.description ?? `${name} description`,
|
||||||
|
parameters: opts?.parameters ?? { type: 'object', properties: { input: { type: 'string' } } },
|
||||||
|
...opts,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Token counter that returns the string length of message content (deterministic). */
|
||||||
|
const fakeTokenCounter: TokenCounter = (msg) => {
|
||||||
|
const content = typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content);
|
||||||
|
return content.length;
|
||||||
|
};
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockCacheStore.clear();
|
||||||
|
mockGet.mockImplementation((key: string) => Promise.resolve(mockCacheStore.get(key)));
|
||||||
|
mockSet.mockImplementation((key: string, value: unknown) => {
|
||||||
|
mockCacheStore.set(key, value);
|
||||||
|
return Promise.resolve(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
/* ========================================================================= */
|
||||||
|
/* collectToolSchemas */
|
||||||
|
/* ========================================================================= */
|
||||||
|
|
||||||
|
describe('collectToolSchemas', () => {
|
||||||
|
it('returns empty array when no tools provided', () => {
|
||||||
|
expect(collectToolSchemas()).toHaveLength(0);
|
||||||
|
expect(collectToolSchemas([], [])).toHaveLength(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('collects entries from GenericTool array', () => {
|
||||||
|
const entries = collectToolSchemas([makeTool('alpha'), makeTool('beta')]);
|
||||||
|
expect(entries).toHaveLength(2);
|
||||||
|
expect(entries.map((e) => e.cacheKey)).toEqual(
|
||||||
|
expect.arrayContaining(['alpha:builtin', 'beta:builtin']),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('collects entries from LCTool definitions with toolType', () => {
|
||||||
|
const defs = [makeToolDef('x', { toolType: 'mcp' }), makeToolDef('y', { toolType: 'action' })];
|
||||||
|
const entries = collectToolSchemas(undefined, defs);
|
||||||
|
expect(entries).toHaveLength(2);
|
||||||
|
expect(entries[0].cacheKey).toBe('x:mcp');
|
||||||
|
expect(entries[1].cacheKey).toBe('y:action');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('defaults toolType to builtin for LCTool without toolType', () => {
|
||||||
|
const entries = collectToolSchemas(undefined, [makeToolDef('z')]);
|
||||||
|
expect(entries[0].cacheKey).toBe('z:builtin');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('uses mcp type for GenericTool with mcp flag', () => {
|
||||||
|
const entries = collectToolSchemas([makeMcpTool('search')]);
|
||||||
|
expect(entries[0].cacheKey).toBe('search:mcp');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('deduplicates: GenericTool takes precedence over matching toolDefinition', () => {
|
||||||
|
const tools = [makeTool('shared')];
|
||||||
|
const defs = [makeToolDef('shared'), makeToolDef('only_def')];
|
||||||
|
const entries = collectToolSchemas(tools, defs);
|
||||||
|
expect(entries).toHaveLength(2);
|
||||||
|
const keys = entries.map((e) => e.cacheKey);
|
||||||
|
expect(keys).toContain('shared:builtin');
|
||||||
|
expect(keys).toContain('only_def:builtin');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
/* ========================================================================= */
|
||||||
|
/* computeToolSchemaTokens */
|
||||||
|
/* ========================================================================= */
|
||||||
|
|
||||||
|
describe('computeToolSchemaTokens', () => {
|
||||||
|
it('returns 0 when no tools provided', () => {
|
||||||
|
expect(
|
||||||
|
computeToolSchemaTokens(undefined, undefined, Providers.OPENAI, undefined, fakeTokenCounter),
|
||||||
|
).toBe(0);
|
||||||
|
expect(computeToolSchemaTokens([], [], Providers.OPENAI, undefined, fakeTokenCounter)).toBe(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('counts tokens from GenericTool schemas', () => {
|
||||||
|
const result = computeToolSchemaTokens(
|
||||||
|
[makeTool('test_tool')],
|
||||||
|
undefined,
|
||||||
|
Providers.OPENAI,
|
||||||
|
undefined,
|
||||||
|
fakeTokenCounter,
|
||||||
|
);
|
||||||
|
expect(result).toBeGreaterThan(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('counts tokens from LCTool definitions', () => {
|
||||||
|
const result = computeToolSchemaTokens(
|
||||||
|
undefined,
|
||||||
|
[makeToolDef('test_def')],
|
||||||
|
Providers.OPENAI,
|
||||||
|
undefined,
|
||||||
|
fakeTokenCounter,
|
||||||
|
);
|
||||||
|
expect(result).toBeGreaterThan(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('deduplicates: tool counted from tools array is skipped in toolDefinitions', () => {
|
||||||
|
const tools = [makeTool('shared')];
|
||||||
|
const defs = [makeToolDef('shared')];
|
||||||
|
|
||||||
|
const toolsOnly = computeToolSchemaTokens(
|
||||||
|
tools,
|
||||||
|
undefined,
|
||||||
|
Providers.OPENAI,
|
||||||
|
undefined,
|
||||||
|
fakeTokenCounter,
|
||||||
|
);
|
||||||
|
const both = computeToolSchemaTokens(
|
||||||
|
tools,
|
||||||
|
defs,
|
||||||
|
Providers.OPENAI,
|
||||||
|
undefined,
|
||||||
|
fakeTokenCounter,
|
||||||
|
);
|
||||||
|
expect(both).toBe(toolsOnly);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('applies Anthropic multiplier for Anthropic provider', () => {
|
||||||
|
const defs = [makeToolDef('tool')];
|
||||||
|
const openai = computeToolSchemaTokens(
|
||||||
|
undefined,
|
||||||
|
defs,
|
||||||
|
Providers.OPENAI,
|
||||||
|
undefined,
|
||||||
|
fakeTokenCounter,
|
||||||
|
);
|
||||||
|
const anthropic = computeToolSchemaTokens(
|
||||||
|
undefined,
|
||||||
|
defs,
|
||||||
|
Providers.ANTHROPIC,
|
||||||
|
undefined,
|
||||||
|
fakeTokenCounter,
|
||||||
|
);
|
||||||
|
const expectedRatio = ANTHROPIC_TOOL_TOKEN_MULTIPLIER / DEFAULT_TOOL_TOKEN_MULTIPLIER;
|
||||||
|
expect(anthropic / openai).toBeCloseTo(expectedRatio, 1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('applies Anthropic multiplier when model name contains "claude"', () => {
|
||||||
|
const defs = [makeToolDef('tool')];
|
||||||
|
const result = computeToolSchemaTokens(
|
||||||
|
undefined,
|
||||||
|
defs,
|
||||||
|
Providers.OPENAI,
|
||||||
|
{ model: 'claude-3-opus' },
|
||||||
|
fakeTokenCounter,
|
||||||
|
);
|
||||||
|
const defaultResult = computeToolSchemaTokens(
|
||||||
|
undefined,
|
||||||
|
defs,
|
||||||
|
Providers.OPENAI,
|
||||||
|
undefined,
|
||||||
|
fakeTokenCounter,
|
||||||
|
);
|
||||||
|
expect(result).toBeGreaterThan(defaultResult);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('does not apply Anthropic multiplier for Bedrock even with claude model', () => {
|
||||||
|
const defs = [makeToolDef('tool')];
|
||||||
|
const bedrock = computeToolSchemaTokens(
|
||||||
|
undefined,
|
||||||
|
defs,
|
||||||
|
Providers.BEDROCK,
|
||||||
|
{ model: 'claude-3-opus' },
|
||||||
|
fakeTokenCounter,
|
||||||
|
);
|
||||||
|
const defaultResult = computeToolSchemaTokens(
|
||||||
|
undefined,
|
||||||
|
defs,
|
||||||
|
Providers.OPENAI,
|
||||||
|
undefined,
|
||||||
|
fakeTokenCounter,
|
||||||
|
);
|
||||||
|
expect(bedrock).toBe(defaultResult);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
/* ========================================================================= */
|
||||||
|
/* getOrComputeToolTokens */
|
||||||
|
/* ========================================================================= */
|
||||||
|
|
||||||
|
describe('getOrComputeToolTokens', () => {
|
||||||
|
it('returns 0 when no tools provided', async () => {
|
||||||
|
const result = await getOrComputeToolTokens({
|
||||||
|
provider: Providers.OPENAI,
|
||||||
|
tokenCounter: fakeTokenCounter,
|
||||||
|
});
|
||||||
|
expect(result).toBe(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('computes and caches each tool individually on first call', async () => {
|
||||||
|
const defs = [makeToolDef('tool_a'), makeToolDef('tool_b')];
|
||||||
|
const result = await getOrComputeToolTokens({
|
||||||
|
toolDefinitions: defs,
|
||||||
|
provider: Providers.OPENAI,
|
||||||
|
tokenCounter: fakeTokenCounter,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toBeGreaterThan(0);
|
||||||
|
expect(mockCacheStore.has('tool_a:builtin')).toBe(true);
|
||||||
|
expect(mockCacheStore.has('tool_b:builtin')).toBe(true);
|
||||||
|
expect(mockCacheStore.size).toBe(2);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('uses cached per-tool values on second call without recomputing', async () => {
|
||||||
|
const defs = [makeToolDef('tool_a')];
|
||||||
|
const counter = jest.fn(fakeTokenCounter);
|
||||||
|
|
||||||
|
const first = await getOrComputeToolTokens({
|
||||||
|
toolDefinitions: defs,
|
||||||
|
provider: Providers.OPENAI,
|
||||||
|
tokenCounter: counter,
|
||||||
|
});
|
||||||
|
|
||||||
|
const callCountAfterFirst = counter.mock.calls.length;
|
||||||
|
|
||||||
|
const second = await getOrComputeToolTokens({
|
||||||
|
toolDefinitions: defs,
|
||||||
|
provider: Providers.OPENAI,
|
||||||
|
tokenCounter: counter,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(second).toBe(first);
|
||||||
|
expect(counter.mock.calls.length).toBe(callCountAfterFirst);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('applies different multipliers for different providers on same cached raw counts', async () => {
|
||||||
|
const defs = [makeToolDef('tool')];
|
||||||
|
|
||||||
|
const openai = await getOrComputeToolTokens({
|
||||||
|
toolDefinitions: defs,
|
||||||
|
provider: Providers.OPENAI,
|
||||||
|
tokenCounter: fakeTokenCounter,
|
||||||
|
});
|
||||||
|
|
||||||
|
const anthropic = await getOrComputeToolTokens({
|
||||||
|
toolDefinitions: defs,
|
||||||
|
provider: Providers.ANTHROPIC,
|
||||||
|
tokenCounter: fakeTokenCounter,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(openai).not.toBe(anthropic);
|
||||||
|
expect(mockCacheStore.size).toBe(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('only computes new tools when tool set grows', async () => {
|
||||||
|
const counter = jest.fn(fakeTokenCounter);
|
||||||
|
|
||||||
|
await getOrComputeToolTokens({
|
||||||
|
toolDefinitions: [makeToolDef('tool_a')],
|
||||||
|
provider: Providers.OPENAI,
|
||||||
|
tokenCounter: counter,
|
||||||
|
});
|
||||||
|
const callsAfterFirst = counter.mock.calls.length;
|
||||||
|
|
||||||
|
await getOrComputeToolTokens({
|
||||||
|
toolDefinitions: [makeToolDef('tool_a'), makeToolDef('tool_b')],
|
||||||
|
provider: Providers.OPENAI,
|
||||||
|
tokenCounter: counter,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(counter.mock.calls.length).toBe(callsAfterFirst + 1);
|
||||||
|
expect(mockCacheStore.size).toBe(2);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('scopes cache keys by tenantId when provided', async () => {
|
||||||
|
const defs = [makeToolDef('tool')];
|
||||||
|
|
||||||
|
await getOrComputeToolTokens({
|
||||||
|
toolDefinitions: defs,
|
||||||
|
provider: Providers.OPENAI,
|
||||||
|
tokenCounter: fakeTokenCounter,
|
||||||
|
tenantId: 'tenant_123',
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(mockCacheStore.has('tenant_123:tool:builtin')).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('separates cache entries for different tenants', async () => {
|
||||||
|
const defs = [makeToolDef('tool')];
|
||||||
|
|
||||||
|
const t1 = await getOrComputeToolTokens({
|
||||||
|
toolDefinitions: defs,
|
||||||
|
provider: Providers.OPENAI,
|
||||||
|
tokenCounter: fakeTokenCounter,
|
||||||
|
tenantId: 'tenant_1',
|
||||||
|
});
|
||||||
|
|
||||||
|
const t2 = await getOrComputeToolTokens({
|
||||||
|
toolDefinitions: defs,
|
||||||
|
provider: Providers.OPENAI,
|
||||||
|
tokenCounter: fakeTokenCounter,
|
||||||
|
tenantId: 'tenant_2',
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(t1).toBe(t2);
|
||||||
|
expect(mockCacheStore.has('tenant_1:tool:builtin')).toBe(true);
|
||||||
|
expect(mockCacheStore.has('tenant_2:tool:builtin')).toBe(true);
|
||||||
|
expect(mockCacheStore.size).toBe(2);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('caches mcp tools with mcp type in key', async () => {
|
||||||
|
const defs = [makeToolDef('search', { toolType: 'mcp' })];
|
||||||
|
|
||||||
|
await getOrComputeToolTokens({
|
||||||
|
toolDefinitions: defs,
|
||||||
|
provider: Providers.OPENAI,
|
||||||
|
tokenCounter: fakeTokenCounter,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(mockCacheStore.has('search:mcp')).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('falls back to compute when cache read throws', async () => {
|
||||||
|
mockGet.mockRejectedValueOnce(new Error('Redis down'));
|
||||||
|
|
||||||
|
const result = await getOrComputeToolTokens({
|
||||||
|
toolDefinitions: [makeToolDef('tool')],
|
||||||
|
provider: Providers.OPENAI,
|
||||||
|
tokenCounter: fakeTokenCounter,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toBeGreaterThan(0);
|
||||||
|
expect(mockGet).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('does not throw when cache write fails', async () => {
|
||||||
|
mockSet.mockRejectedValueOnce(new Error('Redis write error'));
|
||||||
|
|
||||||
|
const result = await getOrComputeToolTokens({
|
||||||
|
toolDefinitions: [makeToolDef('tool_write_fail')],
|
||||||
|
provider: Providers.OPENAI,
|
||||||
|
tokenCounter: fakeTokenCounter,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toBeGreaterThan(0);
|
||||||
|
expect(mockSet).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('matches computeToolSchemaTokens output for same inputs', async () => {
|
||||||
|
const defs = [makeToolDef('a'), makeToolDef('b'), makeToolDef('c')];
|
||||||
|
|
||||||
|
const cached = await getOrComputeToolTokens({
|
||||||
|
toolDefinitions: defs,
|
||||||
|
provider: Providers.OPENAI,
|
||||||
|
tokenCounter: fakeTokenCounter,
|
||||||
|
});
|
||||||
|
|
||||||
|
const direct = computeToolSchemaTokens(
|
||||||
|
undefined,
|
||||||
|
defs,
|
||||||
|
Providers.OPENAI,
|
||||||
|
undefined,
|
||||||
|
fakeTokenCounter,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(cached).toBe(direct);
|
||||||
|
});
|
||||||
|
});
|
||||||
201
packages/api/src/agents/toolTokens.ts
Normal file
201
packages/api/src/agents/toolTokens.ts
Normal file
|
|
@ -0,0 +1,201 @@
|
||||||
|
import { logger } from '@librechat/data-schemas';
|
||||||
|
import { SystemMessage } from '@langchain/core/messages';
|
||||||
|
import { CacheKeys, Time } from 'librechat-data-provider';
|
||||||
|
import {
|
||||||
|
Providers,
|
||||||
|
toJsonSchema,
|
||||||
|
ANTHROPIC_TOOL_TOKEN_MULTIPLIER,
|
||||||
|
DEFAULT_TOOL_TOKEN_MULTIPLIER,
|
||||||
|
} from '@librechat/agents';
|
||||||
|
|
||||||
|
import type { GenericTool, LCTool, TokenCounter, ClientOptions } from '@librechat/agents';
|
||||||
|
import type { Keyv } from 'keyv';
|
||||||
|
|
||||||
|
import { standardCache } from '~/cache';
|
||||||
|
|
||||||
|
interface ToolEntry {
|
||||||
|
cacheKey: string;
|
||||||
|
json: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Module-level cache instance, lazily initialized. */
|
||||||
|
let toolTokenCache: Keyv | undefined;
|
||||||
|
|
||||||
|
function getCache(): Keyv {
|
||||||
|
if (!toolTokenCache) {
|
||||||
|
toolTokenCache = standardCache(CacheKeys.TOOL_TOKENS, Time.THIRTY_MINUTES);
|
||||||
|
}
|
||||||
|
return toolTokenCache;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getToolTokenMultiplier(provider: Providers, clientOptions?: ClientOptions): number {
|
||||||
|
const isAnthropic =
|
||||||
|
provider !== Providers.BEDROCK &&
|
||||||
|
(provider === Providers.ANTHROPIC ||
|
||||||
|
/anthropic|claude/i.test(
|
||||||
|
String((clientOptions as { model?: string } | undefined)?.model ?? ''),
|
||||||
|
));
|
||||||
|
return isAnthropic ? ANTHROPIC_TOOL_TOKEN_MULTIPLIER : DEFAULT_TOOL_TOKEN_MULTIPLIER;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Serializes a GenericTool to a JSON string for token counting. Returns null if no schema. */
|
||||||
|
function serializeGenericTool(tool: GenericTool): { name: string; json: string } | null {
|
||||||
|
const genericTool = tool as unknown as Record<string, unknown>;
|
||||||
|
const toolName = (genericTool.name as string | undefined) ?? '';
|
||||||
|
if (genericTool.schema == null || typeof genericTool.schema !== 'object') {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
const jsonSchema = toJsonSchema(
|
||||||
|
genericTool.schema,
|
||||||
|
toolName,
|
||||||
|
(genericTool.description as string | undefined) ?? '',
|
||||||
|
);
|
||||||
|
return { name: toolName, json: JSON.stringify(jsonSchema) };
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Serializes an LCTool definition to a JSON string for token counting. */
|
||||||
|
function serializeToolDef(def: LCTool): string {
|
||||||
|
return JSON.stringify({
|
||||||
|
type: 'function',
|
||||||
|
function: {
|
||||||
|
name: def.name,
|
||||||
|
description: def.description ?? '',
|
||||||
|
parameters: def.parameters ?? {},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds a list of tool entries with cache keys and serialized schemas.
|
||||||
|
* Deduplicates: a tool present in `tools` (with a schema) takes precedence
|
||||||
|
* over a matching `toolDefinitions` entry.
|
||||||
|
*
|
||||||
|
* Cache key includes toolType when available (from LCTool) to differentiate
|
||||||
|
* builtin/mcp/action tools that may share a name.
|
||||||
|
* GenericTool entries use the `mcp` flag when present.
|
||||||
|
*/
|
||||||
|
export function collectToolSchemas(tools?: GenericTool[], toolDefinitions?: LCTool[]): ToolEntry[] {
|
||||||
|
const seen = new Set<string>();
|
||||||
|
const entries: ToolEntry[] = [];
|
||||||
|
|
||||||
|
if (tools) {
|
||||||
|
for (const tool of tools) {
|
||||||
|
const result = serializeGenericTool(tool);
|
||||||
|
if (!result || !result.name) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
seen.add(result.name);
|
||||||
|
const toolType =
|
||||||
|
(tool as unknown as Record<string, unknown>).mcp === true ? 'mcp' : 'builtin';
|
||||||
|
entries.push({ cacheKey: `${result.name}:${toolType}`, json: result.json });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (toolDefinitions) {
|
||||||
|
for (const def of toolDefinitions) {
|
||||||
|
if (!def.name || seen.has(def.name)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
seen.add(def.name);
|
||||||
|
const toolType = def.toolType ?? 'builtin';
|
||||||
|
entries.push({ cacheKey: `${def.name}:${toolType}`, json: serializeToolDef(def) });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return entries;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes tool schema tokens from scratch using the provided token counter.
|
||||||
|
* Mirrors the logic in AgentContext.calculateInstructionTokens().
|
||||||
|
*/
|
||||||
|
export function computeToolSchemaTokens(
|
||||||
|
tools: GenericTool[] | undefined,
|
||||||
|
toolDefinitions: LCTool[] | undefined,
|
||||||
|
provider: Providers,
|
||||||
|
clientOptions: ClientOptions | undefined,
|
||||||
|
tokenCounter: TokenCounter,
|
||||||
|
): number {
|
||||||
|
const entries = collectToolSchemas(tools, toolDefinitions);
|
||||||
|
let rawTokens = 0;
|
||||||
|
for (const { json } of entries) {
|
||||||
|
rawTokens += tokenCounter(new SystemMessage(json));
|
||||||
|
}
|
||||||
|
const multiplier = getToolTokenMultiplier(provider, clientOptions);
|
||||||
|
return Math.ceil(rawTokens * multiplier);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns tool schema tokens, using per-tool caching to avoid redundant
|
||||||
|
* token counting. Each tool's raw (pre-multiplier) token count is cached
|
||||||
|
* individually, keyed by `{tenantId}:{name}:{toolType}` (or `{name}:{toolType}`
|
||||||
|
* without tenant). The provider-specific multiplier is applied to the sum.
|
||||||
|
*
|
||||||
|
* Returns 0 if there are no tools.
|
||||||
|
*/
|
||||||
|
export async function getOrComputeToolTokens({
|
||||||
|
tools,
|
||||||
|
toolDefinitions,
|
||||||
|
provider,
|
||||||
|
clientOptions,
|
||||||
|
tokenCounter,
|
||||||
|
tenantId,
|
||||||
|
}: {
|
||||||
|
tools?: GenericTool[];
|
||||||
|
toolDefinitions?: LCTool[];
|
||||||
|
provider: Providers;
|
||||||
|
clientOptions?: ClientOptions;
|
||||||
|
tokenCounter: TokenCounter;
|
||||||
|
tenantId?: string;
|
||||||
|
}): Promise<number> {
|
||||||
|
const entries = collectToolSchemas(tools, toolDefinitions);
|
||||||
|
if (entries.length === 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const keyPrefix = tenantId ? `${tenantId}:` : '';
|
||||||
|
|
||||||
|
let cache: Keyv | undefined;
|
||||||
|
try {
|
||||||
|
cache = getCache();
|
||||||
|
} catch (err) {
|
||||||
|
logger.debug('[toolTokens] Cache init failed, computing fresh', err);
|
||||||
|
}
|
||||||
|
|
||||||
|
let rawTotal = 0;
|
||||||
|
const toWrite: Array<{ key: string; value: number }> = [];
|
||||||
|
|
||||||
|
for (const { cacheKey, json } of entries) {
|
||||||
|
const fullKey = `${keyPrefix}${cacheKey}`;
|
||||||
|
let rawCount: number | undefined;
|
||||||
|
|
||||||
|
if (cache) {
|
||||||
|
try {
|
||||||
|
rawCount = (await cache.get(fullKey)) as number | undefined;
|
||||||
|
} catch {
|
||||||
|
// Cache read failed for this tool — will compute fresh
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rawCount == null || rawCount <= 0) {
|
||||||
|
rawCount = tokenCounter(new SystemMessage(json));
|
||||||
|
if (rawCount > 0 && cache) {
|
||||||
|
toWrite.push({ key: fullKey, value: rawCount });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rawTotal += rawCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fire-and-forget cache writes for newly computed tools
|
||||||
|
if (cache && toWrite.length > 0) {
|
||||||
|
for (const { key, value } of toWrite) {
|
||||||
|
cache.set(key, value).catch((err: unknown) => {
|
||||||
|
logger.debug(`[toolTokens] Cache write failed for ${key}`, err);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const multiplier = getToolTokenMultiplier(provider, clientOptions);
|
||||||
|
return Math.ceil(rawTotal * multiplier);
|
||||||
|
}
|
||||||
|
|
@ -1564,6 +1564,11 @@ export enum CacheKeys {
|
||||||
* Key for admin panel OAuth exchange codes (one-time-use, short TTL).
|
* Key for admin panel OAuth exchange codes (one-time-use, short TTL).
|
||||||
*/
|
*/
|
||||||
ADMIN_OAUTH_EXCHANGE = 'ADMIN_OAUTH_EXCHANGE',
|
ADMIN_OAUTH_EXCHANGE = 'ADMIN_OAUTH_EXCHANGE',
|
||||||
|
/**
|
||||||
|
* Key for cached tool schema token counts.
|
||||||
|
* Keyed by provider + tool fingerprint to avoid redundant token counting.
|
||||||
|
*/
|
||||||
|
TOOL_TOKENS = 'TOOL_TOKENS',
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue