diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index c0d9169b51..f4a69be229 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -2,6 +2,7 @@ const crypto = require('crypto'); const fetch = require('node-fetch'); const { logger } = require('@librechat/data-schemas'); const { + countTokens, getBalanceConfig, extractFileContext, encodeAndFormatAudios, @@ -23,7 +24,6 @@ const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { checkBalance } = require('~/models/balanceMethods'); const { truncateToolCallOutputs } = require('./prompts'); -const countTokens = require('~/server/utils/countTokens'); const { getFiles } = require('~/models/File'); const TextStream = require('./TextStream'); diff --git a/api/models/Prompt.js b/api/models/Prompt.js index fbc161e97d..bde911b23a 100644 --- a/api/models/Prompt.js +++ b/api/models/Prompt.js @@ -1,4 +1,5 @@ const { ObjectId } = require('mongodb'); +const { escapeRegExp } = require('@librechat/api'); const { logger } = require('@librechat/data-schemas'); const { Constants, @@ -14,7 +15,6 @@ const { } = require('./Project'); const { removeAllPermissions } = require('~/server/services/PermissionService'); const { PromptGroup, Prompt, AclEntry } = require('~/db/models'); -const { escapeRegExp } = require('~/server/utils'); /** * Create a pipeline for the aggregation to get prompt groups diff --git a/api/server/controllers/assistants/chatV1.js b/api/server/controllers/assistants/chatV1.js index 4bd49e04dd..91759bed37 100644 --- a/api/server/controllers/assistants/chatV1.js +++ b/api/server/controllers/assistants/chatV1.js @@ -1,7 +1,7 @@ const { v4 } = require('uuid'); const { sleep } = require('@librechat/agents'); const { logger } = require('@librechat/data-schemas'); -const { sendEvent, getBalanceConfig, getModelMaxTokens } = require('@librechat/api'); +const { sendEvent, getBalanceConfig, getModelMaxTokens, countTokens } = require('@librechat/api'); const { Time, Constants, @@ -33,7 +33,6 @@ const { getTransactions } = require('~/models/Transaction'); const { checkBalance } = require('~/models/balanceMethods'); const { getConvo } = require('~/models/Conversation'); const getLogStores = require('~/cache/getLogStores'); -const { countTokens } = require('~/server/utils'); const { getOpenAIClient } = require('./helpers'); /** diff --git a/api/server/controllers/assistants/chatV2.js b/api/server/controllers/assistants/chatV2.js index 20b3398ee2..2dcfef2846 100644 --- a/api/server/controllers/assistants/chatV2.js +++ b/api/server/controllers/assistants/chatV2.js @@ -1,7 +1,7 @@ const { v4 } = require('uuid'); const { sleep } = require('@librechat/agents'); const { logger } = require('@librechat/data-schemas'); -const { sendEvent, getBalanceConfig, getModelMaxTokens } = require('@librechat/api'); +const { sendEvent, getBalanceConfig, getModelMaxTokens, countTokens } = require('@librechat/api'); const { Time, Constants, @@ -30,7 +30,6 @@ const { getTransactions } = require('~/models/Transaction'); const { checkBalance } = require('~/models/balanceMethods'); const { getConvo } = require('~/models/Conversation'); const getLogStores = require('~/cache/getLogStores'); -const { countTokens } = require('~/server/utils'); const { getOpenAIClient } = require('./helpers'); /** diff --git a/api/server/experimental.js b/api/server/experimental.js index 2e7f5dff63..0ceb58de22 100644 --- 
a/api/server/experimental.js +++ b/api/server/experimental.js @@ -292,7 +292,6 @@ if (cluster.isMaster) { app.use('/api/presets', routes.presets); app.use('/api/prompts', routes.prompts); app.use('/api/categories', routes.categories); - app.use('/api/tokenizer', routes.tokenizer); app.use('/api/endpoints', routes.endpoints); app.use('/api/balance', routes.balance); app.use('/api/models', routes.models); diff --git a/api/server/index.js b/api/server/index.js index d0bb64405f..767847c286 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -128,7 +128,6 @@ const startServer = async () => { app.use('/api/presets', routes.presets); app.use('/api/prompts', routes.prompts); app.use('/api/categories', routes.categories); - app.use('/api/tokenizer', routes.tokenizer); app.use('/api/endpoints', routes.endpoints); app.use('/api/balance', routes.balance); app.use('/api/models', routes.models); diff --git a/api/server/routes/index.js b/api/server/routes/index.js index adaca3859a..e8250a1f4d 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -1,7 +1,6 @@ const accessPermissions = require('./accessPermissions'); const assistants = require('./assistants'); const categories = require('./categories'); -const tokenizer = require('./tokenizer'); const endpoints = require('./endpoints'); const staticRoute = require('./static'); const messages = require('./messages'); @@ -53,7 +52,6 @@ module.exports = { messages, memories, endpoints, - tokenizer, assistants, categories, staticRoute, diff --git a/api/server/routes/messages.js b/api/server/routes/messages.js index 1e214278c9..901dd8961f 100644 --- a/api/server/routes/messages.js +++ b/api/server/routes/messages.js @@ -1,7 +1,7 @@ const express = require('express'); -const { unescapeLaTeX } = require('@librechat/api'); const { logger } = require('@librechat/data-schemas'); const { ContentTypes } = require('librechat-data-provider'); +const { unescapeLaTeX, countTokens } = require('@librechat/api'); const { saveConvo, getMessage, @@ -14,7 +14,6 @@ const { findAllArtifacts, replaceArtifactContent } = require('~/server/services/ const { requireJwtAuth, validateMessageReq } = require('~/server/middleware'); const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc'); const { getConvosQueried } = require('~/models/Conversation'); -const { countTokens } = require('~/server/utils'); const { Message } = require('~/db/models'); const router = express.Router(); diff --git a/api/server/routes/tokenizer.js b/api/server/routes/tokenizer.js deleted file mode 100644 index 62eb31b70e..0000000000 --- a/api/server/routes/tokenizer.js +++ /dev/null @@ -1,19 +0,0 @@ -const express = require('express'); -const { logger } = require('@librechat/data-schemas'); -const requireJwtAuth = require('~/server/middleware/requireJwtAuth'); -const { countTokens } = require('~/server/utils'); - -const router = express.Router(); - -router.post('/', requireJwtAuth, async (req, res) => { - try { - const { arg } = req.body; - const count = await countTokens(arg?.text ?? 
arg); - res.send({ count }); - } catch (e) { - logger.error('[/tokenizer] Error counting tokens', e); - res.status(500).json('Error counting tokens'); - } -}); - -module.exports = router; diff --git a/api/server/services/Threads/manage.js b/api/server/services/Threads/manage.js index 4cc1e107ed..627dba1a35 100644 --- a/api/server/services/Threads/manage.js +++ b/api/server/services/Threads/manage.js @@ -1,5 +1,6 @@ const path = require('path'); const { v4 } = require('uuid'); +const { countTokens, escapeRegExp } = require('@librechat/api'); const { Constants, ContentTypes, @@ -8,7 +9,6 @@ const { } = require('librechat-data-provider'); const { retrieveAndProcessFile } = require('~/server/services/Files/process'); const { recordMessage, getMessages } = require('~/models/Message'); -const { countTokens, escapeRegExp } = require('~/server/utils'); const { spendTokens } = require('~/models/spendTokens'); const { saveConvo } = require('~/models/Conversation'); diff --git a/api/server/utils/countTokens.js b/api/server/utils/countTokens.js deleted file mode 100644 index 504de26a5e..0000000000 --- a/api/server/utils/countTokens.js +++ /dev/null @@ -1,37 +0,0 @@ -const { Tiktoken } = require('tiktoken/lite'); -const { logger } = require('@librechat/data-schemas'); -const p50k_base = require('tiktoken/encoders/p50k_base.json'); -const cl100k_base = require('tiktoken/encoders/cl100k_base.json'); - -/** - * Counts the number of tokens in a given text using a specified encoding model. - * - * This function utilizes the 'Tiktoken' library to encode text based on the selected model. - * It supports two models, 'text-davinci-003' and 'gpt-3.5-turbo', each with its own encoding strategy. - * For 'text-davinci-003', the 'p50k_base' encoder is used, whereas for other models, the 'cl100k_base' encoder is applied. - * In case of an error during encoding, the error is logged, and the function returns 0. - * - * @async - * @param {string} text - The text to be tokenized. Defaults to an empty string if not provided. - * @param {string} modelName - The name of the model used for tokenizing. Defaults to 'gpt-3.5-turbo'. - * @returns {Promise} The number of tokens in the provided text. Returns 0 if an error occurs. - * @throws Logs the error to a logger and rethrows if any error occurs during tokenization. - */ -const countTokens = async (text = '', modelName = 'gpt-3.5-turbo') => { - let encoder = null; - try { - const model = modelName.includes('text-davinci-003') ? p50k_base : cl100k_base; - encoder = new Tiktoken(model.bpe_ranks, model.special_tokens, model.pat_str); - const tokens = encoder.encode(text); - encoder.free(); - return tokens.length; - } catch (e) { - logger.error('[countTokens]', e); - if (encoder) { - encoder.free(); - } - return 0; - } -}; - -module.exports = countTokens; diff --git a/api/server/utils/handleText.js b/api/server/utils/handleText.js index 15c2db3fcc..a798dc99bd 100644 --- a/api/server/utils/handleText.js +++ b/api/server/utils/handleText.js @@ -10,14 +10,6 @@ const { const { sendEvent } = require('@librechat/api'); const partialRight = require('lodash/partialRight'); -/** Helper function to escape special characters in regex - * @param {string} string - The string to escape. - * @returns {string} The escaped string. - */ -function escapeRegExp(string) { - return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); -} - const addSpaceIfNeeded = (text) => (text.length > 0 && !text.endsWith(' ') ? 
text + ' ' : text); const base = { message: true, initial: true }; @@ -181,7 +173,6 @@ function generateConfig(key, baseURL, endpoint) { module.exports = { handleText, formatSteps, - escapeRegExp, formatAction, isUserProvided, generateConfig, diff --git a/api/server/utils/index.js b/api/server/utils/index.js index 7e29b9f518..918ab54f85 100644 --- a/api/server/utils/index.js +++ b/api/server/utils/index.js @@ -1,5 +1,4 @@ const removePorts = require('./removePorts'); -const countTokens = require('./countTokens'); const handleText = require('./handleText'); const sendEmail = require('./sendEmail'); const queue = require('./queue'); @@ -7,7 +6,6 @@ const files = require('./files'); module.exports = { ...handleText, - countTokens, removePorts, sendEmail, ...files, diff --git a/packages/api/src/prompts/format.ts b/packages/api/src/prompts/format.ts index ad6f4ec237..df2b11b59a 100644 --- a/packages/api/src/prompts/format.ts +++ b/packages/api/src/prompts/format.ts @@ -2,6 +2,7 @@ import { SystemCategories } from 'librechat-data-provider'; import type { IPromptGroupDocument as IPromptGroup } from '@librechat/data-schemas'; import type { Types } from 'mongoose'; import type { PromptGroupsListResponse } from '~/types'; +import { escapeRegExp } from '~/utils/common'; /** * Formats prompt groups for the paginated /groups endpoint response @@ -101,7 +102,6 @@ export function buildPromptGroupFilter({ // Handle name filter - convert to regex for case-insensitive search if (name) { - const escapeRegExp = (str: string) => str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); filter.name = new RegExp(escapeRegExp(name), 'i'); } diff --git a/packages/api/src/utils/common.ts b/packages/api/src/utils/common.ts index a5860b0a69..6f4871b741 100644 --- a/packages/api/src/utils/common.ts +++ b/packages/api/src/utils/common.ts @@ -48,3 +48,12 @@ export function optionalChainWithEmptyCheck( } return values[values.length - 1]; } + +/** + * Escapes special characters in a string for use in a regular expression. + * @param str - The string to escape. + * @returns The escaped string safe for use in RegExp. 
+ */
+export function escapeRegExp(str: string): string {
+  return str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
+}
diff --git a/packages/api/src/utils/index.ts b/packages/api/src/utils/index.ts
index 050f42796b..76f11289cb 100644
--- a/packages/api/src/utils/index.ts
+++ b/packages/api/src/utils/index.ts
@@ -17,7 +17,7 @@ export * from './promise';
 export * from './sanitizeTitle';
 export * from './tempChatRetention';
 export * from './text';
-export { default as Tokenizer } from './tokenizer';
+export { default as Tokenizer, countTokens } from './tokenizer';
 export * from './yaml';
 export * from './http';
 export * from './tokens';
diff --git a/packages/api/src/utils/text.spec.ts b/packages/api/src/utils/text.spec.ts
new file mode 100644
index 0000000000..1b8d8aac98
--- /dev/null
+++ b/packages/api/src/utils/text.spec.ts
@@ -0,0 +1,851 @@
+import { processTextWithTokenLimit, TokenCountFn } from './text';
+import Tokenizer, { countTokens } from './tokenizer';
+
+jest.mock('@librechat/data-schemas', () => ({
+  logger: {
+    debug: jest.fn(),
+    warn: jest.fn(),
+    error: jest.fn(),
+  },
+}));
+
+/**
+ * OLD IMPLEMENTATION (Binary Search) - kept for comparison testing
+ * This is the original algorithm that caused CPU spikes
+ */
+async function processTextWithTokenLimitOLD({
+  text,
+  tokenLimit,
+  tokenCountFn,
+}: {
+  text: string;
+  tokenLimit: number;
+  tokenCountFn: TokenCountFn;
+}): Promise<{ text: string; tokenCount: number; wasTruncated: boolean }> {
+  const originalTokenCount = await tokenCountFn(text);
+
+  if (originalTokenCount <= tokenLimit) {
+    return {
+      text,
+      tokenCount: originalTokenCount,
+      wasTruncated: false,
+    };
+  }
+
+  let low = 0;
+  let high = text.length;
+  let bestText = '';
+
+  while (low <= high) {
+    const mid = Math.floor((low + high) / 2);
+    const truncatedText = text.substring(0, mid);
+    const tokenCount = await tokenCountFn(truncatedText);
+
+    if (tokenCount <= tokenLimit) {
+      bestText = truncatedText;
+      low = mid + 1;
+    } else {
+      high = mid - 1;
+    }
+  }
+
+  const finalTokenCount = await tokenCountFn(bestText);
+
+  return {
+    text: bestText,
+    tokenCount: finalTokenCount,
+    wasTruncated: true,
+  };
+}
+
+/**
+ * Creates a wrapper around Tokenizer.getTokenCount that tracks call count
+ */
+const createRealTokenCounter = () => {
+  let callCount = 0;
+  const tokenCountFn = (text: string): number => {
+    callCount++;
+    return Tokenizer.getTokenCount(text, 'cl100k_base');
+  };
+  return {
+    tokenCountFn,
+    getCallCount: () => callCount,
+    resetCallCount: () => {
+      callCount = 0;
+    },
+  };
+};
+
+/**
+ * Creates a wrapper around the async countTokens function that tracks call count
+ */
+const createCountTokensCounter = () => {
+  let callCount = 0;
+  const tokenCountFn = async (text: string): Promise<number> => {
+    callCount++;
+    return countTokens(text);
+  };
+  return {
+    tokenCountFn,
+    getCallCount: () => callCount,
+    resetCallCount: () => {
+      callCount = 0;
+    },
+  };
+};
+
+describe('processTextWithTokenLimit', () => {
+  /**
+   * Creates a mock token count function that simulates realistic token counting.
+   * Roughly 4 characters per token (common for English text).
+   * Tracks call count to verify efficiency.
+   */
+  const createMockTokenCounter = () => {
+    let callCount = 0;
+    const tokenCountFn = (text: string): number => {
+      callCount++;
+      return Math.ceil(text.length / 4);
+    };
+    return {
+      tokenCountFn,
+      getCallCount: () => callCount,
+      resetCallCount: () => {
+        callCount = 0;
+      },
+    };
+  };
+
+  /** Creates a string of specified character length */
+  const createTextOfLength = (charLength: number): string => {
+    return 'a'.repeat(charLength);
+  };
+
+  /** Creates realistic text content with varied token density */
+  const createRealisticText = (approximateTokens: number): string => {
+    const words = [
+      'the',
+      'quick',
+      'brown',
+      'fox',
+      'jumps',
+      'over',
+      'lazy',
+      'dog',
+      'lorem',
+      'ipsum',
+      'dolor',
+      'sit',
+      'amet',
+      'consectetur',
+      'adipiscing',
+      'elit',
+      'sed',
+      'do',
+      'eiusmod',
+      'tempor',
+      'incididunt',
+      'ut',
+      'labore',
+      'et',
+      'dolore',
+      'magna',
+      'aliqua',
+      'enim',
+      'ad',
+      'minim',
+      'veniam',
+      'authentication',
+      'implementation',
+      'configuration',
+      'documentation',
+    ];
+    const result: string[] = [];
+    for (let i = 0; i < approximateTokens; i++) {
+      result.push(words[i % words.length]);
+    }
+    return result.join(' ');
+  };
+
+  describe('tokenCountFn flexibility (sync and async)', () => {
+    it('should work with synchronous tokenCountFn', async () => {
+      const syncTokenCountFn = (text: string): number => Math.ceil(text.length / 4);
+      const text = 'Hello, world! This is a test message.';
+      const tokenLimit = 5;
+
+      const result = await processTextWithTokenLimit({
+        text,
+        tokenLimit,
+        tokenCountFn: syncTokenCountFn,
+      });
+
+      expect(result.wasTruncated).toBe(true);
+      expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit);
+    });
+
+    it('should work with asynchronous tokenCountFn', async () => {
+      const asyncTokenCountFn = async (text: string): Promise<number> => {
+        await new Promise((resolve) => setTimeout(resolve, 1));
+        return Math.ceil(text.length / 4);
+      };
+      const text = 'Hello, world! This is a test message.';
+      const tokenLimit = 5;
+
+      const result = await processTextWithTokenLimit({
+        text,
+        tokenLimit,
+        tokenCountFn: asyncTokenCountFn,
+      });
+
+      expect(result.wasTruncated).toBe(true);
+      expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit);
+    });
+
+    it('should produce equivalent results with sync and async tokenCountFn', async () => {
+      const syncTokenCountFn = (text: string): number => Math.ceil(text.length / 4);
+      const asyncTokenCountFn = async (text: string): Promise<number> => Math.ceil(text.length / 4);
+      const text = 'a'.repeat(8000);
+      const tokenLimit = 1000;
+
+      const syncResult = await processTextWithTokenLimit({
+        text,
+        tokenLimit,
+        tokenCountFn: syncTokenCountFn,
+      });
+
+      const asyncResult = await processTextWithTokenLimit({
+        text,
+        tokenLimit,
+        tokenCountFn: asyncTokenCountFn,
+      });
+
+      expect(syncResult.tokenCount).toBe(asyncResult.tokenCount);
+      expect(syncResult.wasTruncated).toBe(asyncResult.wasTruncated);
+      expect(syncResult.text.length).toBe(asyncResult.text.length);
+    });
+  });
+
+  describe('when text is under the token limit', () => {
+    it('should return original text unchanged', async () => {
+      const { tokenCountFn } = createMockTokenCounter();
+      const text = 'Hello, world!';
+      const tokenLimit = 100;
+
+      const result = await processTextWithTokenLimit({
+        text,
+        tokenLimit,
+        tokenCountFn,
+      });
+
+      expect(result.text).toBe(text);
+      expect(result.wasTruncated).toBe(false);
+    });
+
+    it('should return correct token count', async () => {
+      const { tokenCountFn } = createMockTokenCounter();
+      const text = 'Hello, world!';
+      const tokenLimit = 100;
+
+      const result = await processTextWithTokenLimit({
+        text,
+        tokenLimit,
+        tokenCountFn,
+      });
+
+      expect(result.tokenCount).toBe(Math.ceil(text.length / 4));
+    });
+
+    it('should only call tokenCountFn once when under limit', async () => {
+      const { tokenCountFn, getCallCount } = createMockTokenCounter();
+      const text = 'Hello, world!';
+      const tokenLimit = 100;
+
+      await processTextWithTokenLimit({
+        text,
+        tokenLimit,
+        tokenCountFn,
+      });
+
+      expect(getCallCount()).toBe(1);
+    });
+  });
+
+  describe('when text is exactly at the token limit', () => {
+    it('should return original text unchanged', async () => {
+      const { tokenCountFn } = createMockTokenCounter();
+      const text = createTextOfLength(400);
+      const tokenLimit = 100;
+
+      const result = await processTextWithTokenLimit({
+        text,
+        tokenLimit,
+        tokenCountFn,
+      });
+
+      expect(result.text).toBe(text);
+      expect(result.wasTruncated).toBe(false);
+      expect(result.tokenCount).toBe(tokenLimit);
+    });
+  });
+
+  describe('when text exceeds the token limit', () => {
+    it('should truncate text to fit within limit', async () => {
+      const { tokenCountFn } = createMockTokenCounter();
+      const text = createTextOfLength(8000);
+      const tokenLimit = 1000;
+
+      const result = await processTextWithTokenLimit({
+        text,
+        tokenLimit,
+        tokenCountFn,
+      });
+
+      expect(result.wasTruncated).toBe(true);
+      expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit);
+      expect(result.text.length).toBeLessThan(text.length);
+    });
+
+    it('should truncate text to be close to but not exceed the limit', async () => {
+      const { tokenCountFn } = createMockTokenCounter();
+      const text = createTextOfLength(8000);
+      const tokenLimit = 1000;
+
+      const result = await processTextWithTokenLimit({
+        text,
+        tokenLimit,
+        tokenCountFn,
+      });
+
+      expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit);
+      expect(result.tokenCount).toBeGreaterThan(tokenLimit * 0.9);
+    });
+  });
+
+
+ 
describe('efficiency - tokenCountFn call count', () => { + it('should call tokenCountFn at most 7 times for large text (vs ~17 for binary search)', async () => { + const { tokenCountFn, getCallCount } = createMockTokenCounter(); + const text = createTextOfLength(400000); + const tokenLimit = 50000; + + await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn, + }); + + expect(getCallCount()).toBeLessThanOrEqual(7); + }); + + it('should typically call tokenCountFn only 2-3 times for standard truncation', async () => { + const { tokenCountFn, getCallCount } = createMockTokenCounter(); + const text = createTextOfLength(40000); + const tokenLimit = 5000; + + await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn, + }); + + expect(getCallCount()).toBeLessThanOrEqual(3); + }); + + it('should call tokenCountFn only once when text is under limit', async () => { + const { tokenCountFn, getCallCount } = createMockTokenCounter(); + const text = createTextOfLength(1000); + const tokenLimit = 10000; + + await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn, + }); + + expect(getCallCount()).toBe(1); + }); + + it('should handle very large text (100k+ tokens) efficiently', async () => { + const { tokenCountFn, getCallCount } = createMockTokenCounter(); + const text = createTextOfLength(500000); + const tokenLimit = 100000; + + await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn, + }); + + expect(getCallCount()).toBeLessThanOrEqual(7); + }); + }); + + describe('edge cases', () => { + it('should handle empty text', async () => { + const { tokenCountFn } = createMockTokenCounter(); + const text = ''; + const tokenLimit = 100; + + const result = await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn, + }); + + expect(result.text).toBe(''); + expect(result.tokenCount).toBe(0); + expect(result.wasTruncated).toBe(false); + }); + + it('should handle token limit of 1', async () => { + const { tokenCountFn } = createMockTokenCounter(); + const text = createTextOfLength(1000); + const tokenLimit = 1; + + const result = await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn, + }); + + expect(result.wasTruncated).toBe(true); + expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit); + }); + + it('should handle text that is just slightly over the limit', async () => { + const { tokenCountFn } = createMockTokenCounter(); + const text = createTextOfLength(404); + const tokenLimit = 100; + + const result = await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn, + }); + + expect(result.wasTruncated).toBe(true); + expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit); + }); + }); + + describe('correctness with variable token density', () => { + it('should handle text with varying token density', async () => { + const variableDensityTokenCounter = (text: string): number => { + const shortWords = (text.match(/\s+/g) || []).length; + return Math.ceil(text.length / 4) + shortWords; + }; + + const text = 'This is a test with many short words and some longer concatenated words too'; + const tokenLimit = 10; + + const result = await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn: variableDensityTokenCounter, + }); + + expect(result.wasTruncated).toBe(true); + expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit); + }); + }); + + describe('direct comparison with OLD binary search implementation', () => { + it('should produce equivalent results to the old implementation', async () => { + const 
oldCounter = createMockTokenCounter(); + const newCounter = createMockTokenCounter(); + const text = createTextOfLength(8000); + const tokenLimit = 1000; + + const oldResult = await processTextWithTokenLimitOLD({ + text, + tokenLimit, + tokenCountFn: oldCounter.tokenCountFn, + }); + + const newResult = await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn: newCounter.tokenCountFn, + }); + + expect(newResult.wasTruncated).toBe(oldResult.wasTruncated); + expect(newResult.tokenCount).toBeLessThanOrEqual(tokenLimit); + expect(oldResult.tokenCount).toBeLessThanOrEqual(tokenLimit); + }); + + it('should use significantly fewer tokenCountFn calls than old implementation (400k chars)', async () => { + const oldCounter = createMockTokenCounter(); + const newCounter = createMockTokenCounter(); + const text = createTextOfLength(400000); + const tokenLimit = 50000; + + await processTextWithTokenLimitOLD({ + text, + tokenLimit, + tokenCountFn: oldCounter.tokenCountFn, + }); + + await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn: newCounter.tokenCountFn, + }); + + const oldCalls = oldCounter.getCallCount(); + const newCalls = newCounter.getCallCount(); + + console.log( + `[400k chars] OLD implementation: ${oldCalls} calls, NEW implementation: ${newCalls} calls`, + ); + console.log(`[400k chars] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`); + + expect(newCalls).toBeLessThan(oldCalls); + expect(newCalls).toBeLessThanOrEqual(7); + }); + + it('should use significantly fewer tokenCountFn calls than old implementation (500k chars, 100k token limit)', async () => { + const oldCounter = createMockTokenCounter(); + const newCounter = createMockTokenCounter(); + const text = createTextOfLength(500000); + const tokenLimit = 100000; + + await processTextWithTokenLimitOLD({ + text, + tokenLimit, + tokenCountFn: oldCounter.tokenCountFn, + }); + + await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn: newCounter.tokenCountFn, + }); + + const oldCalls = oldCounter.getCallCount(); + const newCalls = newCounter.getCallCount(); + + console.log( + `[500k chars] OLD implementation: ${oldCalls} calls, NEW implementation: ${newCalls} calls`, + ); + console.log(`[500k chars] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`); + + expect(newCalls).toBeLessThan(oldCalls); + }); + + it('should achieve at least 70% reduction in tokenCountFn calls', async () => { + const oldCounter = createMockTokenCounter(); + const newCounter = createMockTokenCounter(); + const text = createTextOfLength(500000); + const tokenLimit = 100000; + + await processTextWithTokenLimitOLD({ + text, + tokenLimit, + tokenCountFn: oldCounter.tokenCountFn, + }); + + await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn: newCounter.tokenCountFn, + }); + + const oldCalls = oldCounter.getCallCount(); + const newCalls = newCounter.getCallCount(); + const reduction = 1 - newCalls / oldCalls; + + console.log( + `Efficiency improvement: ${(reduction * 100).toFixed(1)}% fewer tokenCountFn calls`, + ); + + expect(reduction).toBeGreaterThanOrEqual(0.7); + }); + + it('should simulate the reported scenario (122k tokens, 100k limit)', async () => { + const oldCounter = createMockTokenCounter(); + const newCounter = createMockTokenCounter(); + const text = createTextOfLength(489564); + const tokenLimit = 100000; + + await processTextWithTokenLimitOLD({ + text, + tokenLimit, + tokenCountFn: oldCounter.tokenCountFn, + }); + + await processTextWithTokenLimit({ + text, + 
tokenLimit, + tokenCountFn: newCounter.tokenCountFn, + }); + + const oldCalls = oldCounter.getCallCount(); + const newCalls = newCounter.getCallCount(); + + console.log(`[User reported scenario: ~122k tokens]`); + console.log(`OLD implementation: ${oldCalls} tokenCountFn calls`); + console.log(`NEW implementation: ${newCalls} tokenCountFn calls`); + console.log(`Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`); + + expect(newCalls).toBeLessThan(oldCalls); + expect(newCalls).toBeLessThanOrEqual(7); + }); + }); + + describe('direct comparison with REAL tiktoken tokenizer', () => { + beforeEach(() => { + Tokenizer.freeAndResetAllEncoders(); + }); + + it('should produce valid truncation with real tokenizer', async () => { + const counter = createRealTokenCounter(); + const text = createRealisticText(5000); + const tokenLimit = 1000; + + const result = await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn: counter.tokenCountFn, + }); + + expect(result.wasTruncated).toBe(true); + expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit); + expect(result.text.length).toBeLessThan(text.length); + }); + + it('should use fewer tiktoken calls than old implementation (realistic text)', async () => { + const oldCounter = createRealTokenCounter(); + const newCounter = createRealTokenCounter(); + const text = createRealisticText(15000); + const tokenLimit = 5000; + + await processTextWithTokenLimitOLD({ + text, + tokenLimit, + tokenCountFn: oldCounter.tokenCountFn, + }); + + Tokenizer.freeAndResetAllEncoders(); + + await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn: newCounter.tokenCountFn, + }); + + const oldCalls = oldCounter.getCallCount(); + const newCalls = newCounter.getCallCount(); + + console.log(`[Real tiktoken ~15k tokens] OLD: ${oldCalls} calls, NEW: ${newCalls} calls`); + console.log(`[Real tiktoken] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`); + + expect(newCalls).toBeLessThan(oldCalls); + }); + + it('should handle the reported user scenario with real tokenizer (~120k tokens)', async () => { + const oldCounter = createRealTokenCounter(); + const newCounter = createRealTokenCounter(); + const text = createRealisticText(120000); + const tokenLimit = 100000; + + const startOld = performance.now(); + await processTextWithTokenLimitOLD({ + text, + tokenLimit, + tokenCountFn: oldCounter.tokenCountFn, + }); + const timeOld = performance.now() - startOld; + + Tokenizer.freeAndResetAllEncoders(); + + const startNew = performance.now(); + const result = await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn: newCounter.tokenCountFn, + }); + const timeNew = performance.now() - startNew; + + const oldCalls = oldCounter.getCallCount(); + const newCalls = newCounter.getCallCount(); + + console.log(`\n[REAL TIKTOKEN - User reported scenario: ~120k tokens]`); + console.log(`OLD implementation: ${oldCalls} tiktoken calls, ${timeOld.toFixed(0)}ms`); + console.log(`NEW implementation: ${newCalls} tiktoken calls, ${timeNew.toFixed(0)}ms`); + console.log(`Call reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`); + console.log(`Time reduction: ${((1 - timeNew / timeOld) * 100).toFixed(1)}%`); + console.log( + `Result: truncated=${result.wasTruncated}, tokens=${result.tokenCount}/${tokenLimit}\n`, + ); + + expect(newCalls).toBeLessThan(oldCalls); + expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit); + expect(newCalls).toBeLessThanOrEqual(7); + }); + + it('should achieve at least 70% reduction with real tokenizer', 
async () => { + const oldCounter = createRealTokenCounter(); + const newCounter = createRealTokenCounter(); + const text = createRealisticText(50000); + const tokenLimit = 10000; + + await processTextWithTokenLimitOLD({ + text, + tokenLimit, + tokenCountFn: oldCounter.tokenCountFn, + }); + + Tokenizer.freeAndResetAllEncoders(); + + await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn: newCounter.tokenCountFn, + }); + + const oldCalls = oldCounter.getCallCount(); + const newCalls = newCounter.getCallCount(); + const reduction = 1 - newCalls / oldCalls; + + console.log( + `[Real tiktoken 50k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`, + ); + + expect(reduction).toBeGreaterThanOrEqual(0.7); + }); + }); + + describe('using countTokens async function from @librechat/api', () => { + beforeEach(() => { + Tokenizer.freeAndResetAllEncoders(); + }); + + it('countTokens should return correct token count', async () => { + const text = 'Hello, world!'; + const count = await countTokens(text); + + expect(count).toBeGreaterThan(0); + expect(typeof count).toBe('number'); + }); + + it('countTokens should handle empty string', async () => { + const count = await countTokens(''); + expect(count).toBe(0); + }); + + it('should work with processTextWithTokenLimit using countTokens', async () => { + const counter = createCountTokensCounter(); + const text = createRealisticText(5000); + const tokenLimit = 1000; + + const result = await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn: counter.tokenCountFn, + }); + + expect(result.wasTruncated).toBe(true); + expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit); + expect(result.text.length).toBeLessThan(text.length); + }); + + it('should use fewer countTokens calls than old implementation', async () => { + const oldCounter = createCountTokensCounter(); + const newCounter = createCountTokensCounter(); + const text = createRealisticText(15000); + const tokenLimit = 5000; + + await processTextWithTokenLimitOLD({ + text, + tokenLimit, + tokenCountFn: oldCounter.tokenCountFn, + }); + + Tokenizer.freeAndResetAllEncoders(); + + await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn: newCounter.tokenCountFn, + }); + + const oldCalls = oldCounter.getCallCount(); + const newCalls = newCounter.getCallCount(); + + console.log(`[countTokens ~15k tokens] OLD: ${oldCalls} calls, NEW: ${newCalls} calls`); + console.log(`[countTokens] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`); + + expect(newCalls).toBeLessThan(oldCalls); + }); + + it('should handle user reported scenario with countTokens (~120k tokens)', async () => { + const oldCounter = createCountTokensCounter(); + const newCounter = createCountTokensCounter(); + const text = createRealisticText(120000); + const tokenLimit = 100000; + + const startOld = performance.now(); + await processTextWithTokenLimitOLD({ + text, + tokenLimit, + tokenCountFn: oldCounter.tokenCountFn, + }); + const timeOld = performance.now() - startOld; + + Tokenizer.freeAndResetAllEncoders(); + + const startNew = performance.now(); + const result = await processTextWithTokenLimit({ + text, + tokenLimit, + tokenCountFn: newCounter.tokenCountFn, + }); + const timeNew = performance.now() - startNew; + + const oldCalls = oldCounter.getCallCount(); + const newCalls = newCounter.getCallCount(); + + console.log(`\n[countTokens - User reported scenario: ~120k tokens]`); + console.log(`OLD implementation: ${oldCalls} countTokens calls, 
${timeOld.toFixed(0)}ms`);
+      console.log(`NEW implementation: ${newCalls} countTokens calls, ${timeNew.toFixed(0)}ms`);
+      console.log(`Call reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);
+      console.log(`Time reduction: ${((1 - timeNew / timeOld) * 100).toFixed(1)}%`);
+      console.log(
+        `Result: truncated=${result.wasTruncated}, tokens=${result.tokenCount}/${tokenLimit}\n`,
+      );
+
+      expect(newCalls).toBeLessThan(oldCalls);
+      expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit);
+      expect(newCalls).toBeLessThanOrEqual(7);
+    });
+
+    it('should achieve at least 70% reduction with countTokens', async () => {
+      const oldCounter = createCountTokensCounter();
+      const newCounter = createCountTokensCounter();
+      const text = createRealisticText(50000);
+      const tokenLimit = 10000;
+
+      await processTextWithTokenLimitOLD({
+        text,
+        tokenLimit,
+        tokenCountFn: oldCounter.tokenCountFn,
+      });
+
+      Tokenizer.freeAndResetAllEncoders();
+
+      await processTextWithTokenLimit({
+        text,
+        tokenLimit,
+        tokenCountFn: newCounter.tokenCountFn,
+      });
+
+      const oldCalls = oldCounter.getCallCount();
+      const newCalls = newCounter.getCallCount();
+      const reduction = 1 - newCalls / oldCalls;
+
+      console.log(
+        `[countTokens 50k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`,
+      );
+
+      expect(reduction).toBeGreaterThanOrEqual(0.7);
+    });
+  });
+});
diff --git a/packages/api/src/utils/text.ts b/packages/api/src/utils/text.ts
index 3de343bd32..3099c2bbc4 100644
--- a/packages/api/src/utils/text.ts
+++ b/packages/api/src/utils/text.ts
@@ -1,11 +1,39 @@
 import { logger } from '@librechat/data-schemas';
 
+/** Token count function that can be sync or async */
+export type TokenCountFn = (text: string) => number | Promise<number>;
+
+/**
+ * Safety buffer multiplier applied to character position estimates during truncation.
+ *
+ * We use 98% (0.98) rather than 100% to intentionally undershoot the target on the first attempt.
+ * This is necessary because:
+ * - Token density varies across text (some regions may have more tokens per character than the average)
+ * - The ratio-based estimate assumes uniform token distribution, which is rarely true
+ * - Undershooting is safer than overshooting: exceeding the limit requires another iteration,
+ *   while being slightly under is acceptable
+ * - In practice, this buffer reduces refinement iterations from 2-3 down to 0-1 in most cases
+ *
+ * @example
+ * // If text has 1000 chars and 250 tokens (4 chars/token average), targeting 100 tokens:
+ * // Without buffer: estimate = 1000 * (100/250) = 400 chars → might yield 105 tokens (over!)
+ * // With 0.98 buffer: estimate = 400 * 0.98 = 392 chars → likely yields 97-99 tokens (safe)
+ */
+const TRUNCATION_SAFETY_BUFFER = 0.98;
+
 /**
  * Processes text content by counting tokens and truncating if it exceeds the specified limit.
+ * Uses ratio-based estimation to minimize expensive tokenCountFn calls.
+ *
  * @param text - The text content to process
  * @param tokenLimit - The maximum number of tokens allowed
- * @param tokenCountFn - Function to count tokens
+ * @param tokenCountFn - Function to count tokens (can be sync or async)
  * @returns Promise resolving to object with processed text, token count, and truncation status
+ *
+ * @remarks
+ * This function uses a ratio-based estimation algorithm instead of binary search.
+ * Binary search would require O(log n) tokenCountFn calls (~17 for 100k chars),
+ * while this approach typically requires only 2-3 calls for a 90%+ reduction in CPU usage.
  */
 export async function processTextWithTokenLimit({
   text,
@@ -14,7 +42,7 @@ export async function processTextWithTokenLimit({
 }: {
   text: string;
   tokenLimit: number;
-  tokenCountFn: (text: string) => number;
+  tokenCountFn: TokenCountFn;
 }): Promise<{ text: string; tokenCount: number; wasTruncated: boolean }> {
   const originalTokenCount = await tokenCountFn(text);
 
@@ -26,40 +54,34 @@ export async function processTextWithTokenLimit({
     };
   }
 
-  /**
-   * Doing binary search here to find the truncation point efficiently
-   * (May be a better way to go about this)
-   */
-  let low = 0;
-  let high = text.length;
-  let bestText = '';
-
   logger.debug(
     `[textTokenLimiter] Text content exceeds token limit: ${originalTokenCount} > ${tokenLimit}, truncating...`,
   );
 
-  while (low <= high) {
-    const mid = Math.floor((low + high) / 2);
-    const truncatedText = text.substring(0, mid);
-    const tokenCount = await tokenCountFn(truncatedText);
+  const ratio = tokenLimit / originalTokenCount;
+  let charPosition = Math.floor(text.length * ratio * TRUNCATION_SAFETY_BUFFER);
 
-    if (tokenCount <= tokenLimit) {
-      bestText = truncatedText;
-      low = mid + 1;
-    } else {
-      high = mid - 1;
-    }
+  let truncatedText = text.substring(0, charPosition);
+  let tokenCount = await tokenCountFn(truncatedText);
+
+  const maxIterations = 5;
+  let iterations = 0;
+
+  while (tokenCount > tokenLimit && iterations < maxIterations && charPosition > 0) {
+    const overageRatio = tokenLimit / tokenCount;
+    charPosition = Math.floor(charPosition * overageRatio * TRUNCATION_SAFETY_BUFFER);
+    truncatedText = text.substring(0, charPosition);
+    tokenCount = await tokenCountFn(truncatedText);
+    iterations++;
   }
 
-  const finalTokenCount = await tokenCountFn(bestText);
-
   logger.warn(
-    `[textTokenLimiter] Text truncated from ${originalTokenCount} to ${finalTokenCount} tokens (limit: ${tokenLimit})`,
+    `[textTokenLimiter] Text truncated from ${originalTokenCount} to ${tokenCount} tokens (limit: ${tokenLimit})`,
   );
 
   return {
-    text: bestText,
-    tokenCount: finalTokenCount,
+    text: truncatedText,
+    tokenCount,
     wasTruncated: true,
   };
 }
diff --git a/packages/api/src/utils/tokenizer.ts b/packages/api/src/utils/tokenizer.ts
index 2a2088cad3..0b0282d36b 100644
--- a/packages/api/src/utils/tokenizer.ts
+++ b/packages/api/src/utils/tokenizer.ts
@@ -75,4 +75,14 @@ class Tokenizer {
 
 const TokenizerSingleton = new Tokenizer();
 
+/**
+ * Counts the number of tokens in a given text using tiktoken.
+ * This is an async wrapper around Tokenizer.getTokenCount for compatibility.
+ * @param text - The text to be tokenized. Defaults to an empty string if not provided.
+ * @returns The number of tokens in the provided text.
+ */
+export async function countTokens(text = ''): Promise<number> {
+  return TokenizerSingleton.getTokenCount(text, 'cl100k_base');
+}
+
 export default TokenizerSingleton;
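
A minimal usage sketch of the relocated helpers (hypothetical caller code, not part of the changeset above). It only assumes what the diff itself exports: `countTokens` and `processTextWithTokenLimit` are both re-exported from `@librechat/api` (via `utils/index.ts` and `export * from './text'`); the function name `capFileText` and the 100k token budget are illustrative.

import { countTokens, processTextWithTokenLimit } from '@librechat/api';

/** Hypothetical helper: cap extracted text at a token budget before sending it to a model. */
async function capFileText(rawText: string, tokenLimit = 100000): Promise<string> {
  // countTokens is the async cl100k_base wrapper added in packages/api/src/utils/tokenizer.ts;
  // processTextWithTokenLimit accepts it directly, since TokenCountFn may be sync or async.
  const { text, tokenCount, wasTruncated } = await processTextWithTokenLimit({
    text: rawText,
    tokenLimit,
    tokenCountFn: countTokens,
  });

  if (wasTruncated) {
    console.warn(`capFileText: truncated to ${tokenCount}/${tokenLimit} tokens`);
  }
  return text;
}

Because countTokens runs in-process, callers that previously posted to the removed /api/tokenizer route or required ~/server/utils/countTokens can count and truncate text with one ratio-estimation pass instead of a binary search.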