🔄 refactor: Consolidate Tokenizer; Fix Jest Open Handles (#5175)

* refactor: consolidate tokenizer to singleton

* fix: remove legacy tokenizer code, add Tokenizer singleton tests

* ci: fix jest open handles
Danny Avila, 2025-01-03 18:11:14 -05:00, committed by GitHub
parent bf0a84e45a
commit c26b54c74d
11 changed files with 202 additions and 221 deletions


@@ -1,6 +1,5 @@
 const Anthropic = require('@anthropic-ai/sdk');
 const { HttpsProxyAgent } = require('https-proxy-agent');
-const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const {
   Constants,
   EModelEndpoint,
@@ -19,6 +18,7 @@ const {
 } = require('./prompts');
 const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils');
 const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
+const Tokenizer = require('~/server/services/Tokenizer');
 const { sleep } = require('~/server/utils');
 const BaseClient = require('./BaseClient');
 const { logger } = require('~/config');
@@ -26,8 +26,6 @@ const { logger } = require('~/config');
 const HUMAN_PROMPT = '\n\nHuman:';
 const AI_PROMPT = '\n\nAssistant:';

-const tokenizersCache = {};
-
 /** Helper function to introduce a delay before retrying */
 function delayBeforeRetry(attempts, baseDelay = 1000) {
   return new Promise((resolve) => setTimeout(resolve, baseDelay * attempts));
@@ -149,7 +147,6 @@ class AnthropicClient extends BaseClient {
     this.startToken = '||>';
     this.endToken = '';
-    this.gptEncoder = this.constructor.getTokenizer('cl100k_base');

     return this;
   }
@@ -849,22 +846,18 @@ class AnthropicClient extends BaseClient {
     logger.debug('AnthropicClient doesn\'t use getBuildMessagesOptions');
   }

-  static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
-    if (tokenizersCache[encoding]) {
-      return tokenizersCache[encoding];
-    }
-    let tokenizer;
-    if (isModelName) {
-      tokenizer = encodingForModel(encoding, extendSpecialTokens);
-    } else {
-      tokenizer = getEncoding(encoding, extendSpecialTokens);
-    }
-    tokenizersCache[encoding] = tokenizer;
-    return tokenizer;
+  getEncoding() {
+    return 'cl100k_base';
   }

+  /**
+   * Returns the token count of a given text. It also checks and resets the tokenizers if necessary.
+   * @param {string} text - The text to get the token count for.
+   * @returns {number} The token count of the given text.
+   */
   getTokenCount(text) {
-    return this.gptEncoder.encode(text, 'all').length;
+    const encoding = this.getEncoding();
+    return Tokenizer.getTokenCount(text, encoding);
   }

   /**


@@ -6,7 +6,6 @@ const { ChatGoogleVertexAI } = require('@langchain/google-vertexai');
 const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
 const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
 const { AIMessage, HumanMessage, SystemMessage } = require('@langchain/core/messages');
-const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const {
   validateVisionModel,
   getResponseSender,
@@ -17,6 +16,7 @@ const {
   AuthKeys,
 } = require('librechat-data-provider');
 const { encodeAndFormat } = require('~/server/services/Files/images');
+const Tokenizer = require('~/server/services/Tokenizer');
 const { getModelMaxTokens } = require('~/utils');
 const { sleep } = require('~/server/utils');
 const { logger } = require('~/config');
@@ -31,7 +31,6 @@ const BaseClient = require('./BaseClient');
 const loc = process.env.GOOGLE_LOC || 'us-central1';
 const publisher = 'google';
 const endpointPrefix = `${loc}-aiplatform.googleapis.com`;
-const tokenizersCache = {};
 const settings = endpointSettings[EModelEndpoint.google];

 const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/;
@@ -177,25 +176,15 @@ class GoogleClient extends BaseClient {
       // without tripping the stop sequences, so I'm using "||>" instead.
       this.startToken = '||>';
       this.endToken = '';
-      this.gptEncoder = this.constructor.getTokenizer('cl100k_base');
     } else if (isTextModel) {
       this.startToken = '||>';
       this.endToken = '';
-      this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, {
-        '<|im_start|>': 100264,
-        '<|im_end|>': 100265,
-      });
     } else {
       // Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting
       // system that causes only the first "<|endoftext|>" to be counted as 1 token, and the rest are not treated
       // as a single token. So we're using this instead.
       this.startToken = '||>';
       this.endToken = '';
-      try {
-        this.gptEncoder = this.constructor.getTokenizer(this.modelOptions.model, true);
-      } catch {
-        this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true);
-      }
     }

     if (!this.modelOptions.stop) {
@@ -926,23 +915,18 @@ class GoogleClient extends BaseClient {
     ];
   }

-  /* TO-DO: Handle tokens with Google tokenization NOTE: these are required */
-  static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
-    if (tokenizersCache[encoding]) {
-      return tokenizersCache[encoding];
-    }
-    let tokenizer;
-    if (isModelName) {
-      tokenizer = encodingForModel(encoding, extendSpecialTokens);
-    } else {
-      tokenizer = getEncoding(encoding, extendSpecialTokens);
-    }
-    tokenizersCache[encoding] = tokenizer;
-    return tokenizer;
+  getEncoding() {
+    return 'cl100k_base';
   }

+  /**
+   * Returns the token count of a given text. It also checks and resets the tokenizers if necessary.
+   * @param {string} text - The text to get the token count for.
+   * @returns {number} The token count of the given text.
+   */
   getTokenCount(text) {
-    return this.gptEncoder.encode(text, 'all').length;
+    const encoding = this.getEncoding();
+    return Tokenizer.getTokenCount(text, encoding);
   }
 }


@@ -13,7 +13,6 @@ const {
   validateVisionModel,
   mapModelToAzureConfig,
 } = require('librechat-data-provider');
-const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const {
   extractBaseURL,
   constructAzureURL,
@@ -29,6 +28,7 @@ const {
   createContextHandlers,
 } = require('./prompts');
 const { encodeAndFormat } = require('~/server/services/Files/images/encode');
+const Tokenizer = require('~/server/services/Tokenizer');
 const { spendTokens } = require('~/models/spendTokens');
 const { isEnabled, sleep } = require('~/server/utils');
 const { handleOpenAIErrors } = require('./tools/util');
@@ -40,11 +40,6 @@ const { tokenSplit } = require('./document');
 const BaseClient = require('./BaseClient');
 const { logger } = require('~/config');

-// Cache to store Tiktoken instances
-const tokenizersCache = {};
-// Counter for keeping track of the number of tokenizer calls
-let tokenizerCallsCount = 0;
-
 class OpenAIClient extends BaseClient {
   constructor(apiKey, options = {}) {
     super(apiKey, options);
@@ -307,75 +302,8 @@ class OpenAIClient extends BaseClient {
     }
   }

-  // Selects an appropriate tokenizer based on the current configuration of the client instance.
-  // It takes into account factors such as whether it's a chat completion, an unofficial chat GPT model, etc.
-  selectTokenizer() {
-    let tokenizer;
-    this.encoding = 'text-davinci-003';
-    if (this.isChatCompletion) {
-      this.encoding = this.modelOptions.model.includes('gpt-4o') ? 'o200k_base' : 'cl100k_base';
-      tokenizer = this.constructor.getTokenizer(this.encoding);
-    } else if (this.isUnofficialChatGptModel) {
-      const extendSpecialTokens = {
-        '<|im_start|>': 100264,
-        '<|im_end|>': 100265,
-      };
-      tokenizer = this.constructor.getTokenizer(this.encoding, true, extendSpecialTokens);
-    } else {
-      try {
-        const { model } = this.modelOptions;
-        this.encoding = model.includes('instruct') ? 'text-davinci-003' : model;
-        tokenizer = this.constructor.getTokenizer(this.encoding, true);
-      } catch {
-        tokenizer = this.constructor.getTokenizer('text-davinci-003', true);
-      }
-    }
-    return tokenizer;
-  }
-
-  // Retrieves a tokenizer either from the cache or creates a new one if one doesn't exist in the cache.
-  // If a tokenizer is being created, it's also added to the cache.
-  static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
-    let tokenizer;
-    if (tokenizersCache[encoding]) {
-      tokenizer = tokenizersCache[encoding];
-    } else {
-      if (isModelName) {
-        tokenizer = encodingForModel(encoding, extendSpecialTokens);
-      } else {
-        tokenizer = getEncoding(encoding, extendSpecialTokens);
-      }
-      tokenizersCache[encoding] = tokenizer;
-    }
-    return tokenizer;
-  }
-
-  // Frees all encoders in the cache and resets the count.
-  static freeAndResetAllEncoders() {
-    try {
-      Object.keys(tokenizersCache).forEach((key) => {
-        if (tokenizersCache[key]) {
-          tokenizersCache[key].free();
-          delete tokenizersCache[key];
-        }
-      });
-      // Reset count
-      tokenizerCallsCount = 1;
-    } catch (error) {
-      logger.error('[OpenAIClient] Free and reset encoders error', error);
-    }
-  }
-
-  // Checks if the cache of tokenizers has reached a certain size. If it has, it frees and resets all tokenizers.
-  resetTokenizersIfNecessary() {
-    if (tokenizerCallsCount >= 25) {
-      if (this.options.debug) {
-        logger.debug('[OpenAIClient] freeAndResetAllEncoders: reached 25 encodings, resetting...');
-      }
-      this.constructor.freeAndResetAllEncoders();
-    }
-    tokenizerCallsCount++;
+  getEncoding() {
+    return this.model?.includes('gpt-4o') ? 'o200k_base' : 'cl100k_base';
   }

   /**
@@ -384,15 +312,8 @@ class OpenAIClient extends BaseClient {
    * @returns {number} The token count of the given text.
    */
   getTokenCount(text) {
-    this.resetTokenizersIfNecessary();
-    try {
-      const tokenizer = this.selectTokenizer();
-      return tokenizer.encode(text, 'all').length;
-    } catch (error) {
-      this.constructor.freeAndResetAllEncoders();
-      const tokenizer = this.selectTokenizer();
-      return tokenizer.encode(text, 'all').length;
-    }
+    const encoding = this.getEncoding();
+    return Tokenizer.getTokenCount(text, encoding);
   }

   /**
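With this refactor, all three clients (Anthropic, Google, OpenAI) reduce token counting to the same two steps: resolve an encoding name via getEncoding(), then delegate the actual encoding work to the shared Tokenizer service. A minimal usage sketch of that call pattern follows; the model name and prompt are illustrative values, not taken from this diff:

const Tokenizer = require('~/server/services/Tokenizer');

// Roughly what OpenAIClient.getTokenCount boils down to after this change:
const model = 'gpt-4o-mini'; // hypothetical model name, for illustration only
const encoding = model.includes('gpt-4o') ? 'o200k_base' : 'cl100k_base';
const promptTokens = Tokenizer.getTokenCount('Hello, world!', encoding);
console.log(`prompt uses ${promptTokens} tokens with ${encoding}`);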


@@ -1,5 +1,7 @@
+jest.mock('~/cache/getLogStores');
 require('dotenv').config();
 const OpenAI = require('openai');
+const getLogStores = require('~/cache/getLogStores');
 const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
 const { genAzureChatCompletion } = require('~/utils/azureUtils');
 const OpenAIClient = require('../OpenAIClient');
@@ -134,7 +136,13 @@ OpenAI.mockImplementation(() => ({
 }));

 describe('OpenAIClient', () => {
-  let client, client2;
+  const mockSet = jest.fn();
+  const mockCache = { set: mockSet };
+  beforeEach(() => {
+    getLogStores.mockReturnValue(mockCache);
+  });
+
+  let client;
   const model = 'gpt-4';
   const parentMessageId = '1';
   const messages = [
@@ -176,7 +184,6 @@ describe('OpenAIClient', () => {
   beforeEach(() => {
     const options = { ...defaultOptions };
     client = new OpenAIClient('test-api-key', options);
-    client2 = new OpenAIClient('test-api-key', options);
     client.summarizeMessages = jest.fn().mockResolvedValue({
       role: 'assistant',
       content: 'Refined answer',
@@ -185,7 +192,6 @@ describe('OpenAIClient', () => {
     client.buildPrompt = jest
       .fn()
       .mockResolvedValue({ prompt: messages.map((m) => m.text).join('\n') });
-    client.constructor.freeAndResetAllEncoders();
     client.getMessages = jest.fn().mockResolvedValue([]);
   });
@@ -335,77 +341,11 @@ describe('OpenAIClient', () => {
     });
   });

-  describe('selectTokenizer', () => {
-    it('should get the correct tokenizer based on the instance state', () => {
-      const tokenizer = client.selectTokenizer();
-      expect(tokenizer).toBeDefined();
-    });
-  });
-
-  describe('freeAllTokenizers', () => {
-    it('should free all tokenizers', () => {
-      // Create a tokenizer
-      const tokenizer = client.selectTokenizer();
-
-      // Mock 'free' method on the tokenizer
-      tokenizer.free = jest.fn();
-
-      client.constructor.freeAndResetAllEncoders();
-
-      // Check if 'free' method has been called on the tokenizer
-      expect(tokenizer.free).toHaveBeenCalled();
-    });
-  });
-
   describe('getTokenCount', () => {
     it('should return the correct token count', () => {
       const count = client.getTokenCount('Hello, world!');
       expect(count).toBeGreaterThan(0);
     });
-
-    it('should reset the encoder and count when count reaches 25', () => {
-      const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders');
-
-      // Call getTokenCount 25 times
-      for (let i = 0; i < 25; i++) {
-        client.getTokenCount('test text');
-      }
-
-      expect(freeAndResetEncoderSpy).toHaveBeenCalled();
-    });
-
-    it('should not reset the encoder and count when count is less than 25', () => {
-      const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders');
-      freeAndResetEncoderSpy.mockClear();
-
-      // Call getTokenCount 24 times
-      for (let i = 0; i < 24; i++) {
-        client.getTokenCount('test text');
-      }
-
-      expect(freeAndResetEncoderSpy).not.toHaveBeenCalled();
-    });
-
-    it('should handle errors and reset the encoder', () => {
-      const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders');
-
-      // Mock encode function to throw an error
-      client.selectTokenizer().encode = jest.fn().mockImplementation(() => {
-        throw new Error('Test error');
-      });
-
-      client.getTokenCount('test text');
-      expect(freeAndResetEncoderSpy).toHaveBeenCalled();
-    });
-
-    it('should not throw null pointer error when freeing the same encoder twice', () => {
-      client.constructor.freeAndResetAllEncoders();
-      client2.constructor.freeAndResetAllEncoders();
-      const count = client2.getTokenCount('test text');
-      expect(count).toBeGreaterThan(0);
-    });
   });

   describe('getSaveOptions', () => {
@@ -548,7 +488,6 @@ describe('OpenAIClient', () => {
     testCases.forEach((testCase) => {
       it(`should return ${testCase.expected} tokens for model ${testCase.model}`, () => {
         client.modelOptions.model = testCase.model;
-        client.selectTokenizer();
         // 3 tokens for assistant label
         let totalTokens = 3;
         for (let message of example_messages) {
@@ -582,7 +521,6 @@ describe('OpenAIClient', () => {
       it(`should return ${expectedTokens} tokens for model ${visionModel} (Vision Request)`, () => {
         client.modelOptions.model = visionModel;
-        client.selectTokenizer();
         // 3 tokens for assistant label
         let totalTokens = 3;
         for (let message of vision_request) {


@@ -5,7 +5,7 @@ const { math, isEnabled } = require('~/server/utils');
 const keyvRedis = require('./keyvRedis');
 const keyvMongo = require('./keyvMongo');

-const { BAN_DURATION, USE_REDIS, DEBUG_MEMORY_CACHE } = process.env ?? {};
+const { BAN_DURATION, USE_REDIS, DEBUG_MEMORY_CACHE, CI } = process.env ?? {};

 const duration = math(BAN_DURATION, 7200000);
 const isRedisEnabled = isEnabled(USE_REDIS);
@@ -95,10 +95,8 @@ const namespaces = {
  * @returns {Keyv[]}
  */
 function getTTLStores() {
-  return Object.values(namespaces).filter((store) =>
-    store instanceof Keyv &&
-    typeof store.opts?.ttl === 'number' &&
-    store.opts.ttl > 0,
+  return Object.values(namespaces).filter(
+    (store) => store instanceof Keyv && typeof store.opts?.ttl === 'number' && store.opts.ttl > 0,
   );
 }
@@ -125,23 +123,28 @@ async function clearExpiredFromCache(cache) {
   for (const key of keys) {
     try {
       const raw = cache.opts.store.get(key);
-      if (!raw) {continue;}
+      if (!raw) {
+        continue;
+      }
       const data = cache.opts.deserialize(raw);
       // Check if the entry is older than TTL
       if (data?.expires && data.expires <= expiryTime) {
         const deleted = await cache.opts.store.delete(key);
         if (!deleted) {
-          debugMemoryCache && console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
+          debugMemoryCache &&
+            console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
           continue;
         }
         cleared++;
       }
     } catch (error) {
-      debugMemoryCache && console.log(`[Cache] Error processing entry from ${cache.opts.namespace}:`, error);
+      debugMemoryCache &&
+        console.log(`[Cache] Error processing entry from ${cache.opts.namespace}:`, error);
       const deleted = await cache.opts.store.delete(key);
       if (!deleted) {
-        debugMemoryCache && console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
+        debugMemoryCache &&
+          console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
        continue;
       }
       cleared++;
@@ -149,7 +152,10 @@ async function clearExpiredFromCache(cache) {
   }

   if (cleared > 0) {
-    debugMemoryCache && console.log(`[Cache] Cleared ${cleared} entries older than ${ttl}ms from ${cache.opts.namespace}`);
+    debugMemoryCache &&
+      console.log(
+        `[Cache] Cleared ${cleared} entries older than ${ttl}ms from ${cache.opts.namespace}`,
+      );
   }
 }
@@ -157,7 +163,7 @@ const auditCache = () => {
   const ttlStores = getTTLStores();
   console.log('[Cache] Starting audit');

-  ttlStores.forEach(store => {
+  ttlStores.forEach((store) => {
     if (!store?.opts?.store?.entries) {
       return;
     }
@@ -166,8 +172,7 @@ const auditCache = () => {
       count: store.opts.store.size,
       ttl: store.opts.ttl,
       keys: Array.from(store.opts.store.keys()),
-      entriesWithTimestamps: Array.from(store.opts.store.entries())
-        .map(([key, value]) => ({
+      entriesWithTimestamps: Array.from(store.opts.store.entries()).map(([key, value]) => ({
         key,
         value,
       })),
@@ -180,7 +185,7 @@ const auditCache = () => {
  */
 async function clearAllExpiredFromCache() {
   const ttlStores = getTTLStores();
-  await Promise.all(ttlStores.map(store => clearExpiredFromCache(store)));
+  await Promise.all(ttlStores.map((store) => clearExpiredFromCache(store)));

   // Force garbage collection if available (Node.js with --expose-gc flag)
   if (global.gc) {
@@ -188,7 +193,7 @@ async function clearAllExpiredFromCache() {
   }
 }

-if (!isRedisEnabled) {
+if (!isRedisEnabled && !isEnabled(CI)) {
   /** @type {Set<NodeJS.Timeout>} */
   const cleanupIntervals = new Set();
@@ -221,7 +226,7 @@ if (!isRedisEnabled) {
   const dispose = () => {
     debugMemoryCache && console.log('[Cache] Cleaning up and shutting down...');
-    cleanupIntervals.forEach(interval => clearInterval(interval));
+    cleanupIntervals.forEach((interval) => clearInterval(interval));
     cleanupIntervals.clear();

     // One final cleanup before exit
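The open-handles fix itself is the `!isEnabled(CI)` guard above: when the test environment sets CI=true (see the jest setup change at the end of this commit), the periodic cache-cleanup setInterval timers are never scheduled, so Jest has no lingering handles keeping the process alive. A simplified sketch of the guard pattern follows; the interval period, callback body, and signal hooks are placeholders for illustration, not this file's actual schedule:

const { isEnabled } = require('~/server/utils');
const { USE_REDIS, CI } = process.env ?? {};

// Only schedule background cleanup in long-running server processes;
// skip when Redis owns expiry and in CI/test runs (avoids Jest open handles).
if (!isEnabled(USE_REDIS) && !isEnabled(CI)) {
  const cleanupIntervals = new Set();

  const interval = setInterval(() => {
    // placeholder: clear expired in-memory cache entries
  }, 60_000);
  cleanupIntervals.add(interval);

  const dispose = () => {
    cleanupIntervals.forEach((i) => clearInterval(i));
    cleanupIntervals.clear();
  };

  process.on('SIGINT', dispose);
  process.on('SIGTERM', dispose);
}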


@@ -1,3 +1,4 @@
+jest.mock('~/cache/getLogStores');
 const request = require('supertest');
 const express = require('express');
 const routes = require('../');


@@ -1,4 +1,5 @@
 // gptPlugins/initializeClient.spec.js
+jest.mock('~/cache/getLogStores');
 const { EModelEndpoint, ErrorTypes, validateAzureGroups } = require('librechat-data-provider');
 const { getUserKey, getUserKeyValues } = require('~/server/services/UserService');
 const initializeClient = require('./initialize');


@@ -1,3 +1,4 @@
+jest.mock('~/cache/getLogStores');
 const { EModelEndpoint, ErrorTypes, validateAzureGroups } = require('librechat-data-provider');
 const { getUserKey, getUserKeyValues } = require('~/server/services/UserService');
 const initializeClient = require('./initialize');


@@ -59,6 +59,6 @@ class Tokenizer {
   }
 }

-const tokenizerService = new Tokenizer();
-module.exports = tokenizerService;
+const TokenizerSingleton = new Tokenizer();
+module.exports = TokenizerSingleton;
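Only the export rename is visible in this hunk. Pieced together from the OpenAIClient code removed above and the behaviors the new spec below asserts (per-encoding caching, a reset after 25 calls, error recovery, and the "[Tokenizer] Free and reset encoders error" log), the consolidated service plausibly looks something like the following sketch; it is an approximation for orientation, not the committed file:

const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const { logger } = require('~/config');

class Tokenizer {
  constructor() {
    this.tokenizersCache = {};
    this.tokenizerCallsCount = 0;
  }

  /** Returns a cached Tiktoken encoder, creating it on first use. */
  getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
    if (this.tokenizersCache[encoding]) {
      return this.tokenizersCache[encoding];
    }
    const tokenizer = isModelName
      ? encodingForModel(encoding, extendSpecialTokens)
      : getEncoding(encoding, extendSpecialTokens);
    this.tokenizersCache[encoding] = tokenizer;
    return tokenizer;
  }

  /** Frees every cached encoder and resets the call counter. */
  freeAndResetAllEncoders() {
    try {
      Object.keys(this.tokenizersCache).forEach((key) => {
        this.tokenizersCache[key]?.free();
        delete this.tokenizersCache[key];
      });
      this.tokenizerCallsCount = 1;
    } catch (error) {
      logger.error('[Tokenizer] Free and reset encoders error', error);
    }
  }

  /** Periodically frees encoders to cap memory held by cached encoder instances. */
  resetTokenizersIfNecessary() {
    if (this.tokenizerCallsCount >= 25) {
      this.freeAndResetAllEncoders();
    }
    this.tokenizerCallsCount++;
  }

  /** Counts tokens for `text` with the named encoding, recovering from encoder errors. */
  getTokenCount(text, encoding = 'cl100k_base') {
    this.resetTokenizersIfNecessary();
    try {
      return this.getTokenizer(encoding).encode(text, 'all').length;
    } catch (error) {
      this.freeAndResetAllEncoders();
      return this.getTokenizer(encoding).encode(text, 'all').length;
    }
  }
}

const TokenizerSingleton = new Tokenizer();
module.exports = TokenizerSingleton;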


@@ -0,0 +1,136 @@
/**
 * @file Tokenizer.spec.cjs
 *
 * Tests the real TokenizerSingleton (no mocking of `tiktoken`).
 * Make sure to install `tiktoken` and have it configured properly.
 */
const Tokenizer = require('./Tokenizer'); // <-- Adjust path to your singleton file
const { logger } = require('~/config');

describe('Tokenizer', () => {
  it('should be a singleton (same instance)', () => {
    const AnotherTokenizer = require('./Tokenizer'); // same path
    expect(Tokenizer).toBe(AnotherTokenizer);
  });

  describe('getTokenizer', () => {
    it('should create an encoder for an explicit model name (e.g., "gpt-4")', () => {
      // The real `encoding_for_model` will be called internally
      // as soon as we pass isModelName = true.
      const tokenizer = Tokenizer.getTokenizer('gpt-4', true);

      // Basic sanity checks
      expect(tokenizer).toBeDefined();
      // You can optionally check certain properties from `tiktoken` if they exist
      // e.g., expect(typeof tokenizer.encode).toBe('function');
    });

    it('should create an encoder for a known encoding (e.g., "cl100k_base")', () => {
      // The real `get_encoding` will be called internally
      // as soon as we pass isModelName = false.
      const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
      expect(tokenizer).toBeDefined();
      // e.g., expect(typeof tokenizer.encode).toBe('function');
    });

    it('should return cached tokenizer if previously fetched', () => {
      const tokenizer1 = Tokenizer.getTokenizer('cl100k_base', false);
      const tokenizer2 = Tokenizer.getTokenizer('cl100k_base', false);
      // Should be the exact same instance from the cache
      expect(tokenizer1).toBe(tokenizer2);
    });
  });

  describe('freeAndResetAllEncoders', () => {
    beforeEach(() => {
      jest.clearAllMocks();
    });

    it('should free all encoders and reset tokenizerCallsCount to 1', () => {
      // By creating two different encodings, we populate the cache
      Tokenizer.getTokenizer('cl100k_base', false);
      Tokenizer.getTokenizer('r50k_base', false);

      // Now free them
      Tokenizer.freeAndResetAllEncoders();

      // The internal cache is cleared
      expect(Tokenizer.tokenizersCache['cl100k_base']).toBeUndefined();
      expect(Tokenizer.tokenizersCache['r50k_base']).toBeUndefined();

      // tokenizerCallsCount is reset to 1
      expect(Tokenizer.tokenizerCallsCount).toBe(1);
    });

    it('should catch and log errors if freeing fails', () => {
      // Mock logger.error before the test
      const mockLoggerError = jest.spyOn(logger, 'error');

      // Set up a problematic tokenizer in the cache
      Tokenizer.tokenizersCache['cl100k_base'] = {
        free() {
          throw new Error('Intentional free error');
        },
      };

      // Should not throw uncaught errors
      Tokenizer.freeAndResetAllEncoders();

      // Verify logger.error was called with correct arguments
      expect(mockLoggerError).toHaveBeenCalledWith(
        '[Tokenizer] Free and reset encoders error',
        expect.any(Error),
      );

      // Clean up
      mockLoggerError.mockRestore();
      Tokenizer.tokenizersCache = {};
    });
  });

  describe('getTokenCount', () => {
    beforeEach(() => {
      jest.clearAllMocks();
      Tokenizer.freeAndResetAllEncoders();
    });

    it('should return the number of tokens in the given text', () => {
      const text = 'Hello, world!';
      const count = Tokenizer.getTokenCount(text, 'cl100k_base');
      expect(count).toBeGreaterThan(0);
    });

    it('should reset encoders if an error is thrown', () => {
      // We can simulate an error by temporarily overriding the selected tokenizers `encode` method.
      const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
      const originalEncode = tokenizer.encode;
      tokenizer.encode = () => {
        throw new Error('Forced error');
      };

      // Despite the forced error, the code should catch and reset, then re-encode
      const count = Tokenizer.getTokenCount('Hello again', 'cl100k_base');
      expect(count).toBeGreaterThan(0);

      // Restore the original encode
      tokenizer.encode = originalEncode;
    });

    it('should reset tokenizers after 25 calls', () => {
      // Spy on freeAndResetAllEncoders
      const resetSpy = jest.spyOn(Tokenizer, 'freeAndResetAllEncoders');

      // Make 24 calls; should NOT reset yet
      for (let i = 0; i < 24; i++) {
        Tokenizer.getTokenCount('test text', 'cl100k_base');
      }
      expect(resetSpy).not.toHaveBeenCalled();

      // 25th call triggers the reset
      Tokenizer.getTokenCount('the 25th call!', 'cl100k_base');
      expect(resetSpy).toHaveBeenCalledTimes(1);
    });
  });
});


@@ -5,3 +5,4 @@ process.env.MONGO_URI = 'mongodb://127.0.0.1:27017/dummy-uri';
 process.env.BAN_VIOLATIONS = 'true';
 process.env.BAN_DURATION = '7200000';
 process.env.BAN_INTERVAL = '20';
+process.env.CI = 'true';