From 9a5d7eaa4ef1dbae21d06ea82e587c54e867b48b Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Tue, 10 Mar 2026 23:14:52 -0400 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=20refactor:=20Replace=20`tiktoken`=20?= =?UTF-8?q?with=20`ai-tokenizer`=20(#12175)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore: Update dependencies by adding ai-tokenizer and removing tiktoken - Added ai-tokenizer version 1.0.6 to package.json and package-lock.json across multiple packages. - Removed tiktoken version 1.0.15 from package.json and package-lock.json in the same locations, streamlining dependency management. * refactor: replace tiktoken with ai-tokenizer - Added support for 'claude' encoding in the AgentClient class to improve model compatibility. - Updated Tokenizer class to utilize 'ai-tokenizer' for both 'o200k_base' and 'claude' encodings, replacing the previous 'tiktoken' dependency. - Refactored tests to reflect changes in tokenizer behavior and ensure accurate token counting for both encoding types. - Removed deprecated references to 'tiktoken' and adjusted related tests for improved clarity and functionality. * chore: remove tiktoken mocks from DALLE3 tests - Eliminated mock implementations of 'tiktoken' from DALLE3-related test files to streamline test setup and align with recent dependency updates. - Adjusted related test structures to ensure compatibility with the new tokenizer implementation. * chore: Add distinct encoding support for Anthropic Claude models - Introduced a new method `getEncoding` in the AgentClient class to handle the specific BPE tokenizer for Claude models, ensuring compatibility with the distinct encoding requirements. - Updated documentation to clarify the encoding logic for Claude and other models. 
* docs: Update return type documentation for getEncoding method in AgentClient - Clarified the return type of the getEncoding method to specify that it can return an EncodingName or undefined, enhancing code readability and type safety. * refactor: Tokenizer class and error handling - Exported the EncodingName type for broader usage. - Renamed encodingMap to encodingData for clarity. - Improved error handling in getTokenCount method to ensure recovery attempts are logged and return 0 on failure. - Updated countTokens function documentation to specify the use of 'o200k_base' encoding. * refactor: Simplify encoding documentation and export type - Updated the getEncoding method documentation to clarify the default behavior for non-Anthropic Claude models. - Exported the EncodingName type separately from the Tokenizer module for improved clarity and usage. * test: Update text processing tests for token limits - Adjusted test cases to handle smaller text sizes, changing scenarios from ~120k tokens to ~20k tokens for both the real tokenizer and countTokens functions. - Updated token limits in tests to reflect new constraints, ensuring tests accurately assess performance and call reduction. - Enhanced console log messages for clarity regarding token counts and reductions in the updated scenarios. * refactor: Update Tokenizer imports and exports - Moved Tokenizer and countTokens exports to the tokenizer module for better organization. - Adjusted imports in memory.ts to reflect the new structure, ensuring consistent usage across the codebase. - Updated memory.test.ts to mock the Tokenizer from the correct module path, enhancing test accuracy. * refactor: Tokenizer initialization and error handling - Introduced an async `initEncoding` method to preload tokenizers, improving performance and accuracy in token counting. - Updated `getTokenCount` to handle uninitialized tokenizers more gracefully, ensuring proper recovery and logging on errors. 
- Removed deprecated synchronous tokenizer retrieval, streamlining the overall tokenizer management process. * test: Enhance tokenizer tests with initialization and encoding checks - Added `beforeAll` hooks to initialize tokenizers for 'o200k_base' and 'claude' encodings before running tests, ensuring proper setup. - Updated tests to validate the loading of encodings and the correctness of token counts for both 'o200k_base' and 'claude'. - Improved test structure to deduplicate concurrent initialization calls, enhancing performance and reliability. --- .../structured/specs/DALLE3-proxy.spec.js | 1 - .../tools/structured/specs/DALLE3.spec.js | 9 -- api/package.json | 2 +- api/server/controllers/agents/client.js | 4 + api/strategies/samlStrategy.spec.js | 1 - package-lock.json | 23 ++- packages/api/package.json | 2 +- .../api/src/agents/__tests__/memory.test.ts | 5 +- packages/api/src/agents/memory.ts | 3 +- packages/api/src/index.ts | 2 + packages/api/src/utils/index.ts | 1 - packages/api/src/utils/text.spec.ts | 62 +++----- packages/api/src/utils/tokenizer.spec.ts | 137 ++++-------------- packages/api/src/utils/tokenizer.ts | 98 +++++-------- packages/api/src/utils/tokens.ts | 39 ----- 15 files changed, 112 insertions(+), 277 deletions(-) diff --git a/api/app/clients/tools/structured/specs/DALLE3-proxy.spec.js b/api/app/clients/tools/structured/specs/DALLE3-proxy.spec.js index 4481a7d70f..262842b3c2 100644 --- a/api/app/clients/tools/structured/specs/DALLE3-proxy.spec.js +++ b/api/app/clients/tools/structured/specs/DALLE3-proxy.spec.js @@ -1,7 +1,6 @@ const DALLE3 = require('../DALLE3'); const { ProxyAgent } = require('undici'); -jest.mock('tiktoken'); const processFileURL = jest.fn(); describe('DALLE3 Proxy Configuration', () => { diff --git a/api/app/clients/tools/structured/specs/DALLE3.spec.js b/api/app/clients/tools/structured/specs/DALLE3.spec.js index d2040989f9..6071929bfc 100644 --- a/api/app/clients/tools/structured/specs/DALLE3.spec.js +++ 
b/api/app/clients/tools/structured/specs/DALLE3.spec.js @@ -14,15 +14,6 @@ jest.mock('@librechat/data-schemas', () => { }; }); -jest.mock('tiktoken', () => { - return { - encoding_for_model: jest.fn().mockReturnValue({ - encode: jest.fn(), - decode: jest.fn(), - }), - }; -}); - const processFileURL = jest.fn(); const generate = jest.fn(); diff --git a/api/package.json b/api/package.json index fcd353af57..1618481b58 100644 --- a/api/package.json +++ b/api/package.json @@ -51,6 +51,7 @@ "@modelcontextprotocol/sdk": "^1.27.1", "@node-saml/passport-saml": "^5.1.0", "@smithy/node-http-handler": "^4.4.5", + "ai-tokenizer": "^1.0.6", "axios": "^1.13.5", "bcryptjs": "^2.4.3", "compression": "^1.8.1", @@ -106,7 +107,6 @@ "pdfjs-dist": "^5.4.624", "rate-limit-redis": "^4.2.0", "sharp": "^0.33.5", - "tiktoken": "^1.0.15", "traverse": "^0.6.7", "ua-parser-js": "^1.0.36", "undici": "^7.18.2", diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 5f99a0762b..0ecd62b819 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -1172,7 +1172,11 @@ class AgentClient extends BaseClient { } } + /** Anthropic Claude models use a distinct BPE tokenizer; all others default to o200k_base. 
*/ getEncoding() { + if (this.model && this.model.toLowerCase().includes('claude')) { + return 'claude'; + } return 'o200k_base'; } diff --git a/api/strategies/samlStrategy.spec.js b/api/strategies/samlStrategy.spec.js index 06c969ce46..1d16719b87 100644 --- a/api/strategies/samlStrategy.spec.js +++ b/api/strategies/samlStrategy.spec.js @@ -1,5 +1,4 @@ // --- Mocks --- -jest.mock('tiktoken'); jest.mock('fs'); jest.mock('path'); jest.mock('node-fetch'); diff --git a/package-lock.json b/package-lock.json index 09c5219afb..a2db2df389 100644 --- a/package-lock.json +++ b/package-lock.json @@ -66,6 +66,7 @@ "@modelcontextprotocol/sdk": "^1.27.1", "@node-saml/passport-saml": "^5.1.0", "@smithy/node-http-handler": "^4.4.5", + "ai-tokenizer": "^1.0.6", "axios": "^1.13.5", "bcryptjs": "^2.4.3", "compression": "^1.8.1", @@ -121,7 +122,6 @@ "pdfjs-dist": "^5.4.624", "rate-limit-redis": "^4.2.0", "sharp": "^0.33.5", - "tiktoken": "^1.0.15", "traverse": "^0.6.7", "ua-parser-js": "^1.0.36", "undici": "^7.18.2", @@ -22230,6 +22230,20 @@ "node": ">= 14" } }, + "node_modules/ai-tokenizer": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/ai-tokenizer/-/ai-tokenizer-1.0.6.tgz", + "integrity": "sha512-GaakQFxen0pRH/HIA4v68ZM40llCH27HUYUSBLK+gVuZ57e53pYJe1xFvSTj4sJJjbWU92m1X6NjPWyeWkFDow==", + "license": "MIT", + "peerDependencies": { + "ai": "^5.0.0" + }, + "peerDependenciesMeta": { + "ai": { + "optional": true + } + } + }, "node_modules/ajv": { "version": "8.18.0", "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz", @@ -41485,11 +41499,6 @@ "node": ">=0.8" } }, - "node_modules/tiktoken": { - "version": "1.0.15", - "resolved": "https://registry.npmjs.org/tiktoken/-/tiktoken-1.0.15.tgz", - "integrity": "sha512-sCsrq/vMWUSEW29CJLNmPvWxlVp7yh2tlkAjpJltIKqp5CKf98ZNpdeHRmAlPVFlGEbswDc6SmI8vz64W/qErw==" - }, "node_modules/timers-browserify": { "version": "2.0.12", "resolved": "https://registry.npmjs.org/timers-browserify/-/timers-browserify-2.0.12.tgz", @@ 
-44200,6 +44209,7 @@ "@librechat/data-schemas": "*", "@modelcontextprotocol/sdk": "^1.27.1", "@smithy/node-http-handler": "^4.4.5", + "ai-tokenizer": "^1.0.6", "axios": "^1.13.5", "connect-redis": "^8.1.0", "eventsource": "^3.0.2", @@ -44222,7 +44232,6 @@ "node-fetch": "2.7.0", "pdfjs-dist": "^5.4.624", "rate-limit-redis": "^4.2.0", - "tiktoken": "^1.0.15", "undici": "^7.18.2", "zod": "^3.22.4" } diff --git a/packages/api/package.json b/packages/api/package.json index 46587797a5..966447c51b 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -94,6 +94,7 @@ "@librechat/data-schemas": "*", "@modelcontextprotocol/sdk": "^1.27.1", "@smithy/node-http-handler": "^4.4.5", + "ai-tokenizer": "^1.0.6", "axios": "^1.13.5", "connect-redis": "^8.1.0", "eventsource": "^3.0.2", @@ -116,7 +117,6 @@ "node-fetch": "2.7.0", "pdfjs-dist": "^5.4.624", "rate-limit-redis": "^4.2.0", - "tiktoken": "^1.0.15", "undici": "^7.18.2", "zod": "^3.22.4" } diff --git a/packages/api/src/agents/__tests__/memory.test.ts b/packages/api/src/agents/__tests__/memory.test.ts index 74cd0f4354..dabe6de629 100644 --- a/packages/api/src/agents/__tests__/memory.test.ts +++ b/packages/api/src/agents/__tests__/memory.test.ts @@ -22,8 +22,9 @@ jest.mock('winston', () => ({ })); // Mock the Tokenizer -jest.mock('~/utils', () => ({ - Tokenizer: { +jest.mock('~/utils/tokenizer', () => ({ + __esModule: true, + default: { getTokenCount: jest.fn((text: string) => text.length), // Simple mock: 1 char = 1 token }, })); diff --git a/packages/api/src/agents/memory.ts b/packages/api/src/agents/memory.ts index b8f65a9772..b7ae8a8123 100644 --- a/packages/api/src/agents/memory.ts +++ b/packages/api/src/agents/memory.ts @@ -19,7 +19,8 @@ import type { TAttachment, MemoryArtifact } from 'librechat-data-provider'; import type { BaseMessage, ToolMessage } from '@langchain/core/messages'; import type { Response as ServerResponse } from 'express'; import { GenerationJobManager } from 
'~/stream/GenerationJobManager'; -import { Tokenizer, resolveHeaders, createSafeUser } from '~/utils'; +import { resolveHeaders, createSafeUser } from '~/utils'; +import Tokenizer from '~/utils/tokenizer'; type RequiredMemoryMethods = Pick< MemoryMethods, diff --git a/packages/api/src/index.ts b/packages/api/src/index.ts index a7edb3882d..687ee7aa49 100644 --- a/packages/api/src/index.ts +++ b/packages/api/src/index.ts @@ -15,6 +15,8 @@ export * from './mcp/errors'; /* Utilities */ export * from './mcp/utils'; export * from './utils'; +export { default as Tokenizer, countTokens } from './utils/tokenizer'; +export type { EncodingName } from './utils/tokenizer'; export * from './db/utils'; /* OAuth */ export * from './oauth'; diff --git a/packages/api/src/utils/index.ts b/packages/api/src/utils/index.ts index 470780cd5c..441c2e02d7 100644 --- a/packages/api/src/utils/index.ts +++ b/packages/api/src/utils/index.ts @@ -19,7 +19,6 @@ export * from './promise'; export * from './sanitizeTitle'; export * from './tempChatRetention'; export * from './text'; -export { default as Tokenizer, countTokens } from './tokenizer'; export * from './yaml'; export * from './http'; export * from './tokens'; diff --git a/packages/api/src/utils/text.spec.ts b/packages/api/src/utils/text.spec.ts index 1b8d8aac98..30185f9da7 100644 --- a/packages/api/src/utils/text.spec.ts +++ b/packages/api/src/utils/text.spec.ts @@ -65,7 +65,7 @@ const createRealTokenCounter = () => { let callCount = 0; const tokenCountFn = (text: string): number => { callCount++; - return Tokenizer.getTokenCount(text, 'cl100k_base'); + return Tokenizer.getTokenCount(text, 'o200k_base'); }; return { tokenCountFn, @@ -590,9 +590,9 @@ describe('processTextWithTokenLimit', () => { }); }); - describe('direct comparison with REAL tiktoken tokenizer', () => { - beforeEach(() => { - Tokenizer.freeAndResetAllEncoders(); + describe('direct comparison with REAL ai-tokenizer', () => { + beforeAll(async () => { + await 
Tokenizer.initEncoding('o200k_base'); }); it('should produce valid truncation with real tokenizer', async () => { @@ -611,7 +611,7 @@ describe('processTextWithTokenLimit', () => { expect(result.text.length).toBeLessThan(text.length); }); - it('should use fewer tiktoken calls than old implementation (realistic text)', async () => { + it('should use fewer tokenizer calls than old implementation (realistic text)', async () => { const oldCounter = createRealTokenCounter(); const newCounter = createRealTokenCounter(); const text = createRealisticText(15000); @@ -623,8 +623,6 @@ describe('processTextWithTokenLimit', () => { tokenCountFn: oldCounter.tokenCountFn, }); - Tokenizer.freeAndResetAllEncoders(); - await processTextWithTokenLimit({ text, tokenLimit, @@ -634,17 +632,17 @@ describe('processTextWithTokenLimit', () => { const oldCalls = oldCounter.getCallCount(); const newCalls = newCounter.getCallCount(); - console.log(`[Real tiktoken ~15k tokens] OLD: ${oldCalls} calls, NEW: ${newCalls} calls`); - console.log(`[Real tiktoken] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`); + console.log(`[Real tokenizer ~15k tokens] OLD: ${oldCalls} calls, NEW: ${newCalls} calls`); + console.log(`[Real tokenizer] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`); expect(newCalls).toBeLessThan(oldCalls); }); - it('should handle the reported user scenario with real tokenizer (~120k tokens)', async () => { + it('should handle large text with real tokenizer (~20k tokens)', async () => { const oldCounter = createRealTokenCounter(); const newCounter = createRealTokenCounter(); - const text = createRealisticText(120000); - const tokenLimit = 100000; + const text = createRealisticText(20000); + const tokenLimit = 15000; const startOld = performance.now(); await processTextWithTokenLimitOLD({ @@ -654,8 +652,6 @@ describe('processTextWithTokenLimit', () => { }); const timeOld = performance.now() - startOld; - Tokenizer.freeAndResetAllEncoders(); - const startNew = 
performance.now(); const result = await processTextWithTokenLimit({ text, @@ -667,9 +663,9 @@ describe('processTextWithTokenLimit', () => { const oldCalls = oldCounter.getCallCount(); const newCalls = newCounter.getCallCount(); - console.log(`\n[REAL TIKTOKEN - User reported scenario: ~120k tokens]`); - console.log(`OLD implementation: ${oldCalls} tiktoken calls, ${timeOld.toFixed(0)}ms`); - console.log(`NEW implementation: ${newCalls} tiktoken calls, ${timeNew.toFixed(0)}ms`); + console.log(`\n[REAL TOKENIZER - ~20k tokens]`); + console.log(`OLD implementation: ${oldCalls} tokenizer calls, ${timeOld.toFixed(0)}ms`); + console.log(`NEW implementation: ${newCalls} tokenizer calls, ${timeNew.toFixed(0)}ms`); console.log(`Call reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`); console.log(`Time reduction: ${((1 - timeNew / timeOld) * 100).toFixed(1)}%`); console.log( @@ -684,8 +680,8 @@ describe('processTextWithTokenLimit', () => { it('should achieve at least 70% reduction with real tokenizer', async () => { const oldCounter = createRealTokenCounter(); const newCounter = createRealTokenCounter(); - const text = createRealisticText(50000); - const tokenLimit = 10000; + const text = createRealisticText(15000); + const tokenLimit = 5000; await processTextWithTokenLimitOLD({ text, @@ -693,8 +689,6 @@ describe('processTextWithTokenLimit', () => { tokenCountFn: oldCounter.tokenCountFn, }); - Tokenizer.freeAndResetAllEncoders(); - await processTextWithTokenLimit({ text, tokenLimit, @@ -706,7 +700,7 @@ describe('processTextWithTokenLimit', () => { const reduction = 1 - newCalls / oldCalls; console.log( - `[Real tiktoken 50k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`, + `[Real tokenizer 15k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`, ); expect(reduction).toBeGreaterThanOrEqual(0.7); @@ -714,10 +708,6 @@ describe('processTextWithTokenLimit', () => { }); describe('using 
countTokens async function from @librechat/api', () => { - beforeEach(() => { - Tokenizer.freeAndResetAllEncoders(); - }); - it('countTokens should return correct token count', async () => { const text = 'Hello, world!'; const count = await countTokens(text); @@ -759,8 +749,6 @@ describe('processTextWithTokenLimit', () => { tokenCountFn: oldCounter.tokenCountFn, }); - Tokenizer.freeAndResetAllEncoders(); - await processTextWithTokenLimit({ text, tokenLimit, @@ -776,11 +764,11 @@ describe('processTextWithTokenLimit', () => { expect(newCalls).toBeLessThan(oldCalls); }); - it('should handle user reported scenario with countTokens (~120k tokens)', async () => { + it('should handle large text with countTokens (~20k tokens)', async () => { const oldCounter = createCountTokensCounter(); const newCounter = createCountTokensCounter(); - const text = createRealisticText(120000); - const tokenLimit = 100000; + const text = createRealisticText(20000); + const tokenLimit = 15000; const startOld = performance.now(); await processTextWithTokenLimitOLD({ @@ -790,8 +778,6 @@ describe('processTextWithTokenLimit', () => { }); const timeOld = performance.now() - startOld; - Tokenizer.freeAndResetAllEncoders(); - const startNew = performance.now(); const result = await processTextWithTokenLimit({ text, @@ -803,7 +789,7 @@ describe('processTextWithTokenLimit', () => { const oldCalls = oldCounter.getCallCount(); const newCalls = newCounter.getCallCount(); - console.log(`\n[countTokens - User reported scenario: ~120k tokens]`); + console.log(`\n[countTokens - ~20k tokens]`); console.log(`OLD implementation: ${oldCalls} countTokens calls, ${timeOld.toFixed(0)}ms`); console.log(`NEW implementation: ${newCalls} countTokens calls, ${timeNew.toFixed(0)}ms`); console.log(`Call reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`); @@ -820,8 +806,8 @@ describe('processTextWithTokenLimit', () => { it('should achieve at least 70% reduction with countTokens', async () => { const oldCounter 
= createCountTokensCounter(); const newCounter = createCountTokensCounter(); - const text = createRealisticText(50000); - const tokenLimit = 10000; + const text = createRealisticText(15000); + const tokenLimit = 5000; await processTextWithTokenLimitOLD({ text, @@ -829,8 +815,6 @@ describe('processTextWithTokenLimit', () => { tokenCountFn: oldCounter.tokenCountFn, }); - Tokenizer.freeAndResetAllEncoders(); - await processTextWithTokenLimit({ text, tokenLimit, @@ -842,7 +826,7 @@ describe('processTextWithTokenLimit', () => { const reduction = 1 - newCalls / oldCalls; console.log( - `[countTokens 50k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`, + `[countTokens 15k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`, ); expect(reduction).toBeGreaterThanOrEqual(0.7); diff --git a/packages/api/src/utils/tokenizer.spec.ts b/packages/api/src/utils/tokenizer.spec.ts index edd6fe14de..b8c1bd8d98 100644 --- a/packages/api/src/utils/tokenizer.spec.ts +++ b/packages/api/src/utils/tokenizer.spec.ts @@ -1,12 +1,3 @@ -/** - * @file Tokenizer.spec.cjs - * - * Tests the real TokenizerSingleton (no mocking of `tiktoken`). - * Make sure to install `tiktoken` and have it configured properly. 
- */ - -import { logger } from '@librechat/data-schemas'; -import type { Tiktoken } from 'tiktoken'; import Tokenizer from './tokenizer'; jest.mock('@librechat/data-schemas', () => ({ @@ -17,127 +8,49 @@ jest.mock('@librechat/data-schemas', () => ({ describe('Tokenizer', () => { it('should be a singleton (same instance)', async () => { - const AnotherTokenizer = await import('./tokenizer'); // same path + const AnotherTokenizer = await import('./tokenizer'); expect(Tokenizer).toBe(AnotherTokenizer.default); }); - describe('getTokenizer', () => { - it('should create an encoder for an explicit model name (e.g., "gpt-4")', () => { - // The real `encoding_for_model` will be called internally - // as soon as we pass isModelName = true. - const tokenizer = Tokenizer.getTokenizer('gpt-4', true); - - // Basic sanity checks - expect(tokenizer).toBeDefined(); - // You can optionally check certain properties from `tiktoken` if they exist - // e.g., expect(typeof tokenizer.encode).toBe('function'); + describe('initEncoding', () => { + it('should load o200k_base encoding', async () => { + await Tokenizer.initEncoding('o200k_base'); + const count = Tokenizer.getTokenCount('Hello, world!', 'o200k_base'); + expect(count).toBeGreaterThan(0); }); - it('should create an encoder for a known encoding (e.g., "cl100k_base")', () => { - // The real `get_encoding` will be called internally - // as soon as we pass isModelName = false. 
- const tokenizer = Tokenizer.getTokenizer('cl100k_base', false); - - expect(tokenizer).toBeDefined(); - // e.g., expect(typeof tokenizer.encode).toBe('function'); + it('should load claude encoding', async () => { + await Tokenizer.initEncoding('claude'); + const count = Tokenizer.getTokenCount('Hello, world!', 'claude'); + expect(count).toBeGreaterThan(0); }); - it('should return cached tokenizer if previously fetched', () => { - const tokenizer1 = Tokenizer.getTokenizer('cl100k_base', false); - const tokenizer2 = Tokenizer.getTokenizer('cl100k_base', false); - // Should be the exact same instance from the cache - expect(tokenizer1).toBe(tokenizer2); - }); - }); - - describe('freeAndResetAllEncoders', () => { - beforeEach(() => { - jest.clearAllMocks(); - }); - - it('should free all encoders and reset tokenizerCallsCount to 1', () => { - // By creating two different encodings, we populate the cache - Tokenizer.getTokenizer('cl100k_base', false); - Tokenizer.getTokenizer('r50k_base', false); - - // Now free them - Tokenizer.freeAndResetAllEncoders(); - - // The internal cache is cleared - expect(Tokenizer.tokenizersCache['cl100k_base']).toBeUndefined(); - expect(Tokenizer.tokenizersCache['r50k_base']).toBeUndefined(); - - // tokenizerCallsCount is reset to 1 - expect(Tokenizer.tokenizerCallsCount).toBe(1); - }); - - it('should catch and log errors if freeing fails', () => { - // Mock logger.error before the test - const mockLoggerError = jest.spyOn(logger, 'error'); - - // Set up a problematic tokenizer in the cache - Tokenizer.tokenizersCache['cl100k_base'] = { - free() { - throw new Error('Intentional free error'); - }, - } as unknown as Tiktoken; - - // Should not throw uncaught errors - Tokenizer.freeAndResetAllEncoders(); - - // Verify logger.error was called with correct arguments - expect(mockLoggerError).toHaveBeenCalledWith( - '[Tokenizer] Free and reset encoders error', - expect.any(Error), - ); - - // Clean up - mockLoggerError.mockRestore(); - 
Tokenizer.tokenizersCache = {}; + it('should deduplicate concurrent init calls', async () => { + const [, , count] = await Promise.all([ + Tokenizer.initEncoding('o200k_base'), + Tokenizer.initEncoding('o200k_base'), + Tokenizer.initEncoding('o200k_base').then(() => + Tokenizer.getTokenCount('test', 'o200k_base'), + ), + ]); + expect(count).toBeGreaterThan(0); }); }); describe('getTokenCount', () => { - beforeEach(() => { - jest.clearAllMocks(); - Tokenizer.freeAndResetAllEncoders(); + beforeAll(async () => { + await Tokenizer.initEncoding('o200k_base'); + await Tokenizer.initEncoding('claude'); }); it('should return the number of tokens in the given text', () => { - const text = 'Hello, world!'; - const count = Tokenizer.getTokenCount(text, 'cl100k_base'); + const count = Tokenizer.getTokenCount('Hello, world!', 'o200k_base'); expect(count).toBeGreaterThan(0); }); - it('should reset encoders if an error is thrown', () => { - // We can simulate an error by temporarily overriding the selected tokenizer's `encode` method. 
- const tokenizer = Tokenizer.getTokenizer('cl100k_base', false); - const originalEncode = tokenizer.encode; - tokenizer.encode = () => { - throw new Error('Forced error'); - }; - - // Despite the forced error, the code should catch and reset, then re-encode - const count = Tokenizer.getTokenCount('Hello again', 'cl100k_base'); + it('should count tokens using claude encoding', () => { + const count = Tokenizer.getTokenCount('Hello, world!', 'claude'); expect(count).toBeGreaterThan(0); - - // Restore the original encode - tokenizer.encode = originalEncode; - }); - - it('should reset tokenizers after 25 calls', () => { - // Spy on freeAndResetAllEncoders - const resetSpy = jest.spyOn(Tokenizer, 'freeAndResetAllEncoders'); - - // Make 24 calls; should NOT reset yet - for (let i = 0; i < 24; i++) { - Tokenizer.getTokenCount('test text', 'cl100k_base'); - } - expect(resetSpy).not.toHaveBeenCalled(); - - // 25th call triggers the reset - Tokenizer.getTokenCount('the 25th call!', 'cl100k_base'); - expect(resetSpy).toHaveBeenCalledTimes(1); }); }); }); diff --git a/packages/api/src/utils/tokenizer.ts b/packages/api/src/utils/tokenizer.ts index 0b0282d36b..4c638c948e 100644 --- a/packages/api/src/utils/tokenizer.ts +++ b/packages/api/src/utils/tokenizer.ts @@ -1,74 +1,46 @@ import { logger } from '@librechat/data-schemas'; -import { encoding_for_model as encodingForModel, get_encoding as getEncoding } from 'tiktoken'; -import type { Tiktoken, TiktokenModel, TiktokenEncoding } from 'tiktoken'; +import { Tokenizer as AiTokenizer } from 'ai-tokenizer'; -interface TokenizerOptions { - debug?: boolean; -} +export type EncodingName = 'o200k_base' | 'claude'; + +type EncodingData = ConstructorParameters[0]; class Tokenizer { - tokenizersCache: Record; - tokenizerCallsCount: number; - private options?: TokenizerOptions; + private tokenizersCache: Partial> = {}; + private loadingPromises: Partial>> = {}; - constructor() { - this.tokenizersCache = {}; - this.tokenizerCallsCount = 0; 
- } - - getTokenizer( - encoding: TiktokenModel | TiktokenEncoding, - isModelName = false, - extendSpecialTokens: Record = {}, - ): Tiktoken { - let tokenizer: Tiktoken; + /** Pre-loads an encoding so that subsequent getTokenCount calls are accurate. */ + async initEncoding(encoding: EncodingName): Promise { if (this.tokenizersCache[encoding]) { - tokenizer = this.tokenizersCache[encoding]; - } else { - if (isModelName) { - tokenizer = encodingForModel(encoding as TiktokenModel, extendSpecialTokens); - } else { - tokenizer = getEncoding(encoding as TiktokenEncoding, extendSpecialTokens); - } - this.tokenizersCache[encoding] = tokenizer; + return; } - return tokenizer; + if (this.loadingPromises[encoding]) { + return this.loadingPromises[encoding]; + } + this.loadingPromises[encoding] = (async () => { + const data: EncodingData = + encoding === 'claude' + ? await import('ai-tokenizer/encoding/claude') + : await import('ai-tokenizer/encoding/o200k_base'); + this.tokenizersCache[encoding] = new AiTokenizer(data); + })(); + return this.loadingPromises[encoding]; } - freeAndResetAllEncoders(): void { + getTokenCount(text: string, encoding: EncodingName = 'o200k_base'): number { + const tokenizer = this.tokenizersCache[encoding]; + if (!tokenizer) { + this.initEncoding(encoding); + return Math.ceil(text.length / 4); + } try { - Object.keys(this.tokenizersCache).forEach((key) => { - if (this.tokenizersCache[key]) { - this.tokenizersCache[key].free(); - delete this.tokenizersCache[key]; - } - }); - this.tokenizerCallsCount = 1; - } catch (error) { - logger.error('[Tokenizer] Free and reset encoders error', error); - } - } - - resetTokenizersIfNecessary(): void { - if (this.tokenizerCallsCount >= 25) { - if (this.options?.debug) { - logger.debug('[Tokenizer] freeAndResetAllEncoders: reached 25 encodings, resetting...'); - } - this.freeAndResetAllEncoders(); - } - this.tokenizerCallsCount++; - } - - getTokenCount(text: string, encoding: TiktokenModel | TiktokenEncoding = 
'cl100k_base'): number { - this.resetTokenizersIfNecessary(); - try { - const tokenizer = this.getTokenizer(encoding); - return tokenizer.encode(text, 'all').length; + return tokenizer.count(text); } catch (error) { logger.error('[Tokenizer] Error getting token count:', error); - this.freeAndResetAllEncoders(); - const tokenizer = this.getTokenizer(encoding); - return tokenizer.encode(text, 'all').length; + delete this.tokenizersCache[encoding]; + delete this.loadingPromises[encoding]; + this.initEncoding(encoding); + return Math.ceil(text.length / 4); } } } @@ -76,13 +48,13 @@ class Tokenizer { const TokenizerSingleton = new Tokenizer(); /** - * Counts the number of tokens in a given text using tiktoken. - * This is an async wrapper around Tokenizer.getTokenCount for compatibility. - * @param text - The text to be tokenized. Defaults to an empty string if not provided. + * Counts the number of tokens in a given text using ai-tokenizer with o200k_base encoding. + * @param text - The text to count tokens in. Defaults to an empty string. * @returns The number of tokens in the provided text. 
*/ export async function countTokens(text = ''): Promise { - return TokenizerSingleton.getTokenCount(text, 'cl100k_base'); + await TokenizerSingleton.initEncoding('o200k_base'); + return TokenizerSingleton.getTokenCount(text, 'o200k_base'); } export default TokenizerSingleton; diff --git a/packages/api/src/utils/tokens.ts b/packages/api/src/utils/tokens.ts index 32b2fc6036..ae09da4f28 100644 --- a/packages/api/src/utils/tokens.ts +++ b/packages/api/src/utils/tokens.ts @@ -593,42 +593,3 @@ export function processModelData(input: z.infer): EndpointTo return tokenConfig; } - -export const tiktokenModels = new Set([ - 'text-davinci-003', - 'text-davinci-002', - 'text-davinci-001', - 'text-curie-001', - 'text-babbage-001', - 'text-ada-001', - 'davinci', - 'curie', - 'babbage', - 'ada', - 'code-davinci-002', - 'code-davinci-001', - 'code-cushman-002', - 'code-cushman-001', - 'davinci-codex', - 'cushman-codex', - 'text-davinci-edit-001', - 'code-davinci-edit-001', - 'text-embedding-ada-002', - 'text-similarity-davinci-001', - 'text-similarity-curie-001', - 'text-similarity-babbage-001', - 'text-similarity-ada-001', - 'text-search-davinci-doc-001', - 'text-search-curie-doc-001', - 'text-search-babbage-doc-001', - 'text-search-ada-doc-001', - 'code-search-babbage-code-001', - 'code-search-ada-code-001', - 'gpt2', - 'gpt-4', - 'gpt-4-0314', - 'gpt-4-32k', - 'gpt-4-32k-0314', - 'gpt-3.5-turbo', - 'gpt-3.5-turbo-0301', -]);