mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-04-07 00:15:23 +02:00
fix: include toolType in per-tool cache key
Cache key is now {tenantId}:{name}:{toolType} (or {name}:{toolType}
without tenant). This differentiates builtin/mcp/action tools that
may share a name but have different schemas.
GenericTool entries derive type from the mcp flag; LCTool entries
use the toolType field (defaulting to builtin).
Also refactors collectToolSchemas to return ToolEntry[] with
pre-built cache keys instead of Map<name, json>.
This commit is contained in:
parent
8db4f21f97
commit
e2962f4967
2 changed files with 132 additions and 78 deletions
|
|
@ -37,11 +37,18 @@ function makeTool(name: string, description = `${name} description`): GenericToo
|
|||
}) as unknown as GenericTool;
|
||||
}
|
||||
|
||||
function makeToolDef(name: string, description?: string): LCTool {
|
||||
function makeMcpTool(name: string): GenericTool {
|
||||
const tool = makeTool(name) as unknown as Record<string, unknown>;
|
||||
tool.mcp = true;
|
||||
return tool as unknown as GenericTool;
|
||||
}
|
||||
|
||||
function makeToolDef(name: string, opts?: Partial<LCTool>): LCTool {
|
||||
return {
|
||||
name,
|
||||
description: description ?? `${name} description`,
|
||||
parameters: { type: 'object', properties: { input: { type: 'string' } } },
|
||||
description: opts?.description ?? `${name} description`,
|
||||
parameters: opts?.parameters ?? { type: 'object', properties: { input: { type: 'string' } } },
|
||||
...opts,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -65,34 +72,45 @@ beforeEach(() => {
|
|||
/* ========================================================================= */
|
||||
|
||||
describe('collectToolSchemas', () => {
|
||||
it('returns empty map when no tools provided', () => {
|
||||
expect(collectToolSchemas().size).toBe(0);
|
||||
expect(collectToolSchemas([], []).size).toBe(0);
|
||||
it('returns empty array when no tools provided', () => {
|
||||
expect(collectToolSchemas()).toHaveLength(0);
|
||||
expect(collectToolSchemas([], [])).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('collects schemas from GenericTool array keyed by name', () => {
|
||||
const tools = [makeTool('alpha'), makeTool('beta')];
|
||||
const schemas = collectToolSchemas(tools);
|
||||
expect(schemas.size).toBe(2);
|
||||
expect(schemas.has('alpha')).toBe(true);
|
||||
expect(schemas.has('beta')).toBe(true);
|
||||
it('collects entries from GenericTool array', () => {
|
||||
const entries = collectToolSchemas([makeTool('alpha'), makeTool('beta')]);
|
||||
expect(entries).toHaveLength(2);
|
||||
expect(entries.map((e) => e.cacheKey)).toEqual(
|
||||
expect.arrayContaining(['alpha:builtin', 'beta:builtin']),
|
||||
);
|
||||
});
|
||||
|
||||
it('collects schemas from LCTool definitions', () => {
|
||||
const defs = [makeToolDef('x'), makeToolDef('y')];
|
||||
const schemas = collectToolSchemas(undefined, defs);
|
||||
expect(schemas.size).toBe(2);
|
||||
expect(schemas.has('x')).toBe(true);
|
||||
expect(schemas.has('y')).toBe(true);
|
||||
it('collects entries from LCTool definitions with toolType', () => {
|
||||
const defs = [makeToolDef('x', { toolType: 'mcp' }), makeToolDef('y', { toolType: 'action' })];
|
||||
const entries = collectToolSchemas(undefined, defs);
|
||||
expect(entries).toHaveLength(2);
|
||||
expect(entries[0].cacheKey).toBe('x:mcp');
|
||||
expect(entries[1].cacheKey).toBe('y:action');
|
||||
});
|
||||
|
||||
it('defaults toolType to builtin for LCTool without toolType', () => {
|
||||
const entries = collectToolSchemas(undefined, [makeToolDef('z')]);
|
||||
expect(entries[0].cacheKey).toBe('z:builtin');
|
||||
});
|
||||
|
||||
it('uses mcp type for GenericTool with mcp flag', () => {
|
||||
const entries = collectToolSchemas([makeMcpTool('search')]);
|
||||
expect(entries[0].cacheKey).toBe('search:mcp');
|
||||
});
|
||||
|
||||
it('deduplicates: GenericTool takes precedence over matching toolDefinition', () => {
|
||||
const tools = [makeTool('shared')];
|
||||
const defs = [makeToolDef('shared'), makeToolDef('only_def')];
|
||||
const schemas = collectToolSchemas(tools, defs);
|
||||
expect(schemas.size).toBe(2);
|
||||
expect(schemas.has('shared')).toBe(true);
|
||||
expect(schemas.has('only_def')).toBe(true);
|
||||
const entries = collectToolSchemas(tools, defs);
|
||||
expect(entries).toHaveLength(2);
|
||||
const keys = entries.map((e) => e.cacheKey);
|
||||
expect(keys).toContain('shared:builtin');
|
||||
expect(keys).toContain('only_def:builtin');
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -109,9 +127,8 @@ describe('computeToolSchemaTokens', () => {
|
|||
});
|
||||
|
||||
it('counts tokens from GenericTool schemas', () => {
|
||||
const tools = [makeTool('test_tool')];
|
||||
const result = computeToolSchemaTokens(
|
||||
tools,
|
||||
[makeTool('test_tool')],
|
||||
undefined,
|
||||
Providers.OPENAI,
|
||||
undefined,
|
||||
|
|
@ -121,10 +138,9 @@ describe('computeToolSchemaTokens', () => {
|
|||
});
|
||||
|
||||
it('counts tokens from LCTool definitions', () => {
|
||||
const defs = [makeToolDef('test_def')];
|
||||
const result = computeToolSchemaTokens(
|
||||
undefined,
|
||||
defs,
|
||||
[makeToolDef('test_def')],
|
||||
Providers.OPENAI,
|
||||
undefined,
|
||||
fakeTokenCounter,
|
||||
|
|
@ -150,7 +166,6 @@ describe('computeToolSchemaTokens', () => {
|
|||
undefined,
|
||||
fakeTokenCounter,
|
||||
);
|
||||
|
||||
expect(both).toBe(toolsOnly);
|
||||
});
|
||||
|
||||
|
|
@ -170,19 +185,17 @@ describe('computeToolSchemaTokens', () => {
|
|||
undefined,
|
||||
fakeTokenCounter,
|
||||
);
|
||||
|
||||
const expectedRatio = ANTHROPIC_TOOL_TOKEN_MULTIPLIER / DEFAULT_TOOL_TOKEN_MULTIPLIER;
|
||||
expect(anthropic / openai).toBeCloseTo(expectedRatio, 1);
|
||||
});
|
||||
|
||||
it('applies Anthropic multiplier when model name contains "claude"', () => {
|
||||
const defs = [makeToolDef('tool')];
|
||||
const clientOptions = { model: 'claude-3-opus' };
|
||||
const result = computeToolSchemaTokens(
|
||||
undefined,
|
||||
defs,
|
||||
Providers.OPENAI,
|
||||
clientOptions,
|
||||
{ model: 'claude-3-opus' },
|
||||
fakeTokenCounter,
|
||||
);
|
||||
const defaultResult = computeToolSchemaTokens(
|
||||
|
|
@ -197,12 +210,11 @@ describe('computeToolSchemaTokens', () => {
|
|||
|
||||
it('does not apply Anthropic multiplier for Bedrock even with claude model', () => {
|
||||
const defs = [makeToolDef('tool')];
|
||||
const clientOptions = { model: 'claude-3-opus' };
|
||||
const bedrock = computeToolSchemaTokens(
|
||||
undefined,
|
||||
defs,
|
||||
Providers.BEDROCK,
|
||||
clientOptions,
|
||||
{ model: 'claude-3-opus' },
|
||||
fakeTokenCounter,
|
||||
);
|
||||
const defaultResult = computeToolSchemaTokens(
|
||||
|
|
@ -212,7 +224,6 @@ describe('computeToolSchemaTokens', () => {
|
|||
undefined,
|
||||
fakeTokenCounter,
|
||||
);
|
||||
|
||||
expect(bedrock).toBe(defaultResult);
|
||||
});
|
||||
});
|
||||
|
|
@ -239,8 +250,8 @@ describe('getOrComputeToolTokens', () => {
|
|||
});
|
||||
|
||||
expect(result).toBeGreaterThan(0);
|
||||
expect(mockCacheStore.has('tool_a')).toBe(true);
|
||||
expect(mockCacheStore.has('tool_b')).toBe(true);
|
||||
expect(mockCacheStore.has('tool_a:builtin')).toBe(true);
|
||||
expect(mockCacheStore.has('tool_b:builtin')).toBe(true);
|
||||
expect(mockCacheStore.size).toBe(2);
|
||||
});
|
||||
|
||||
|
|
@ -282,7 +293,6 @@ describe('getOrComputeToolTokens', () => {
|
|||
});
|
||||
|
||||
expect(openai).not.toBe(anthropic);
|
||||
// Only one cache entry — raw count is provider-agnostic
|
||||
expect(mockCacheStore.size).toBe(1);
|
||||
});
|
||||
|
||||
|
|
@ -302,17 +312,63 @@ describe('getOrComputeToolTokens', () => {
|
|||
tokenCounter: counter,
|
||||
});
|
||||
|
||||
// Only one new tokenCounter call for tool_b
|
||||
expect(counter.mock.calls.length).toBe(callsAfterFirst + 1);
|
||||
expect(mockCacheStore.size).toBe(2);
|
||||
});
|
||||
|
||||
it('scopes cache keys by tenantId when provided', async () => {
|
||||
const defs = [makeToolDef('tool')];
|
||||
|
||||
await getOrComputeToolTokens({
|
||||
toolDefinitions: defs,
|
||||
provider: Providers.OPENAI,
|
||||
tokenCounter: fakeTokenCounter,
|
||||
tenantId: 'tenant_123',
|
||||
});
|
||||
|
||||
expect(mockCacheStore.has('tenant_123:tool:builtin')).toBe(true);
|
||||
});
|
||||
|
||||
it('separates cache entries for different tenants', async () => {
|
||||
const defs = [makeToolDef('tool')];
|
||||
|
||||
const t1 = await getOrComputeToolTokens({
|
||||
toolDefinitions: defs,
|
||||
provider: Providers.OPENAI,
|
||||
tokenCounter: fakeTokenCounter,
|
||||
tenantId: 'tenant_1',
|
||||
});
|
||||
|
||||
const t2 = await getOrComputeToolTokens({
|
||||
toolDefinitions: defs,
|
||||
provider: Providers.OPENAI,
|
||||
tokenCounter: fakeTokenCounter,
|
||||
tenantId: 'tenant_2',
|
||||
});
|
||||
|
||||
expect(t1).toBe(t2);
|
||||
expect(mockCacheStore.has('tenant_1:tool:builtin')).toBe(true);
|
||||
expect(mockCacheStore.has('tenant_2:tool:builtin')).toBe(true);
|
||||
expect(mockCacheStore.size).toBe(2);
|
||||
});
|
||||
|
||||
it('caches mcp tools with mcp type in key', async () => {
|
||||
const defs = [makeToolDef('search', { toolType: 'mcp' })];
|
||||
|
||||
await getOrComputeToolTokens({
|
||||
toolDefinitions: defs,
|
||||
provider: Providers.OPENAI,
|
||||
tokenCounter: fakeTokenCounter,
|
||||
});
|
||||
|
||||
expect(mockCacheStore.has('search:mcp')).toBe(true);
|
||||
});
|
||||
|
||||
it('falls back to compute when cache read throws', async () => {
|
||||
mockGet.mockRejectedValueOnce(new Error('Redis down'));
|
||||
|
||||
const defs = [makeToolDef('tool')];
|
||||
const result = await getOrComputeToolTokens({
|
||||
toolDefinitions: defs,
|
||||
toolDefinitions: [makeToolDef('tool')],
|
||||
provider: Providers.OPENAI,
|
||||
tokenCounter: fakeTokenCounter,
|
||||
});
|
||||
|
|
@ -324,9 +380,8 @@ describe('getOrComputeToolTokens', () => {
|
|||
it('does not throw when cache write fails', async () => {
|
||||
mockSet.mockRejectedValueOnce(new Error('Redis write error'));
|
||||
|
||||
const defs = [makeToolDef('tool_write_fail')];
|
||||
const result = await getOrComputeToolTokens({
|
||||
toolDefinitions: defs,
|
||||
toolDefinitions: [makeToolDef('tool_write_fail')],
|
||||
provider: Providers.OPENAI,
|
||||
tokenCounter: fakeTokenCounter,
|
||||
});
|
||||
|
|
@ -335,20 +390,6 @@ describe('getOrComputeToolTokens', () => {
|
|||
expect(mockSet).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('uses GenericTool tools for per-tool caching', async () => {
|
||||
const tools = [makeTool('alpha'), makeTool('beta')];
|
||||
|
||||
const result = await getOrComputeToolTokens({
|
||||
tools,
|
||||
provider: Providers.OPENAI,
|
||||
tokenCounter: fakeTokenCounter,
|
||||
});
|
||||
|
||||
expect(result).toBeGreaterThan(0);
|
||||
expect(mockCacheStore.has('alpha')).toBe(true);
|
||||
expect(mockCacheStore.has('beta')).toBe(true);
|
||||
});
|
||||
|
||||
it('matches computeToolSchemaTokens output for same inputs', async () => {
|
||||
const defs = [makeToolDef('a'), makeToolDef('b'), makeToolDef('c')];
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,11 @@ import type { Keyv } from 'keyv';
|
|||
|
||||
import { standardCache } from '~/cache';
|
||||
|
||||
interface ToolEntry {
|
||||
cacheKey: string;
|
||||
json: string;
|
||||
}
|
||||
|
||||
/** Module-level cache instance, lazily initialized. */
|
||||
let toolTokenCache: Keyv | undefined;
|
||||
|
||||
|
|
@ -61,35 +66,43 @@ function serializeToolDef(def: LCTool): string {
|
|||
}
|
||||
|
||||
/**
|
||||
* Builds a map of tool name → serialized schema JSON. Deduplicates: a tool
|
||||
* present in `tools` (with a schema) takes precedence over a matching
|
||||
* `toolDefinitions` entry, mirroring AgentContext.calculateInstructionTokens().
|
||||
* Builds a list of tool entries with cache keys and serialized schemas.
|
||||
* Deduplicates: a tool present in `tools` (with a schema) takes precedence
|
||||
* over a matching `toolDefinitions` entry.
|
||||
*
|
||||
* Cache key includes toolType when available (from LCTool) to differentiate
|
||||
* builtin/mcp/action tools that may share a name.
|
||||
* GenericTool entries use the `mcp` flag when present.
|
||||
*/
|
||||
export function collectToolSchemas(
|
||||
tools?: GenericTool[],
|
||||
toolDefinitions?: LCTool[],
|
||||
): Map<string, string> {
|
||||
const schemas = new Map<string, string>();
|
||||
export function collectToolSchemas(tools?: GenericTool[], toolDefinitions?: LCTool[]): ToolEntry[] {
|
||||
const seen = new Set<string>();
|
||||
const entries: ToolEntry[] = [];
|
||||
|
||||
if (tools) {
|
||||
for (const tool of tools) {
|
||||
const result = serializeGenericTool(tool);
|
||||
if (result && result.name) {
|
||||
schemas.set(result.name, result.json);
|
||||
if (!result || !result.name) {
|
||||
continue;
|
||||
}
|
||||
seen.add(result.name);
|
||||
const toolType =
|
||||
(tool as unknown as Record<string, unknown>).mcp === true ? 'mcp' : 'builtin';
|
||||
entries.push({ cacheKey: `${result.name}:${toolType}`, json: result.json });
|
||||
}
|
||||
}
|
||||
|
||||
if (toolDefinitions) {
|
||||
for (const def of toolDefinitions) {
|
||||
if (!def.name || schemas.has(def.name)) {
|
||||
if (!def.name || seen.has(def.name)) {
|
||||
continue;
|
||||
}
|
||||
schemas.set(def.name, serializeToolDef(def));
|
||||
seen.add(def.name);
|
||||
const toolType = def.toolType ?? 'builtin';
|
||||
entries.push({ cacheKey: `${def.name}:${toolType}`, json: serializeToolDef(def) });
|
||||
}
|
||||
}
|
||||
|
||||
return schemas;
|
||||
return entries;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -103,9 +116,9 @@ export function computeToolSchemaTokens(
|
|||
clientOptions: ClientOptions | undefined,
|
||||
tokenCounter: TokenCounter,
|
||||
): number {
|
||||
const schemas = collectToolSchemas(tools, toolDefinitions);
|
||||
const entries = collectToolSchemas(tools, toolDefinitions);
|
||||
let rawTokens = 0;
|
||||
for (const json of schemas.values()) {
|
||||
for (const { json } of entries) {
|
||||
rawTokens += tokenCounter(new SystemMessage(json));
|
||||
}
|
||||
const multiplier = getToolTokenMultiplier(provider, clientOptions);
|
||||
|
|
@ -115,8 +128,8 @@ export function computeToolSchemaTokens(
|
|||
/**
|
||||
* Returns tool schema tokens, using per-tool caching to avoid redundant
|
||||
* token counting. Each tool's raw (pre-multiplier) token count is cached
|
||||
* individually by name, so adding/removing a tool only requires computing
|
||||
* the new one. The provider-specific multiplier is applied to the sum.
|
||||
* individually, keyed by `{tenantId}:{name}:{toolType}` (or `{name}:{toolType}`
|
||||
* without tenant). The provider-specific multiplier is applied to the sum.
|
||||
*
|
||||
* Returns 0 if there are no tools.
|
||||
*/
|
||||
|
|
@ -135,8 +148,8 @@ export async function getOrComputeToolTokens({
|
|||
tokenCounter: TokenCounter;
|
||||
tenantId?: string;
|
||||
}): Promise<number> {
|
||||
const schemas = collectToolSchemas(tools, toolDefinitions);
|
||||
if (schemas.size === 0) {
|
||||
const entries = collectToolSchemas(tools, toolDefinitions);
|
||||
if (entries.length === 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
@ -152,13 +165,13 @@ export async function getOrComputeToolTokens({
|
|||
let rawTotal = 0;
|
||||
const toWrite: Array<{ key: string; value: number }> = [];
|
||||
|
||||
for (const [name, json] of schemas) {
|
||||
const cacheKey = `${keyPrefix}${name}`;
|
||||
for (const { cacheKey, json } of entries) {
|
||||
const fullKey = `${keyPrefix}${cacheKey}`;
|
||||
let rawCount: number | undefined;
|
||||
|
||||
if (cache) {
|
||||
try {
|
||||
rawCount = (await cache.get(cacheKey)) as number | undefined;
|
||||
rawCount = (await cache.get(fullKey)) as number | undefined;
|
||||
} catch {
|
||||
// Cache read failed for this tool — will compute fresh
|
||||
}
|
||||
|
|
@ -167,7 +180,7 @@ export async function getOrComputeToolTokens({
|
|||
if (rawCount == null || rawCount <= 0) {
|
||||
rawCount = tokenCounter(new SystemMessage(json));
|
||||
if (rawCount > 0 && cache) {
|
||||
toWrite.push({ key: cacheKey, value: rawCount });
|
||||
toWrite.push({ key: fullKey, value: rawCount });
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue