🔄 refactor: Consolidate Tokenizer; Fix Jest Open Handles (#5175)

* refactor: consolidate tokenizer to singleton

* fix: remove legacy tokenizer code, add Tokenizer singleton tests

* ci: fix jest open handles
This commit is contained in:
Danny Avila 2025-01-03 18:11:14 -05:00 committed by GitHub
parent bf0a84e45a
commit c26b54c74d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 202 additions and 221 deletions

View file

@ -1,5 +1,7 @@
jest.mock('~/cache/getLogStores');
require('dotenv').config();
const OpenAI = require('openai');
const getLogStores = require('~/cache/getLogStores');
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
const { genAzureChatCompletion } = require('~/utils/azureUtils');
const OpenAIClient = require('../OpenAIClient');
@ -134,7 +136,13 @@ OpenAI.mockImplementation(() => ({
}));
describe('OpenAIClient', () => {
let client, client2;
const mockSet = jest.fn();
const mockCache = { set: mockSet };
beforeEach(() => {
getLogStores.mockReturnValue(mockCache);
});
let client;
const model = 'gpt-4';
const parentMessageId = '1';
const messages = [
@ -176,7 +184,6 @@ describe('OpenAIClient', () => {
beforeEach(() => {
const options = { ...defaultOptions };
client = new OpenAIClient('test-api-key', options);
client2 = new OpenAIClient('test-api-key', options);
client.summarizeMessages = jest.fn().mockResolvedValue({
role: 'assistant',
content: 'Refined answer',
@ -185,7 +192,6 @@ describe('OpenAIClient', () => {
client.buildPrompt = jest
.fn()
.mockResolvedValue({ prompt: messages.map((m) => m.text).join('\n') });
client.constructor.freeAndResetAllEncoders();
client.getMessages = jest.fn().mockResolvedValue([]);
});
@ -335,77 +341,11 @@ describe('OpenAIClient', () => {
});
});
describe('selectTokenizer', () => {
it('should get the correct tokenizer based on the instance state', () => {
const tokenizer = client.selectTokenizer();
expect(tokenizer).toBeDefined();
});
});
describe('freeAllTokenizers', () => {
it('should free all tokenizers', () => {
// Create a tokenizer
const tokenizer = client.selectTokenizer();
// Mock 'free' method on the tokenizer
tokenizer.free = jest.fn();
client.constructor.freeAndResetAllEncoders();
// Check if 'free' method has been called on the tokenizer
expect(tokenizer.free).toHaveBeenCalled();
});
});
describe('getTokenCount', () => {
it('should return the correct token count', () => {
const count = client.getTokenCount('Hello, world!');
expect(count).toBeGreaterThan(0);
});
it('should reset the encoder and count when count reaches 25', () => {
const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders');
// Call getTokenCount 25 times
for (let i = 0; i < 25; i++) {
client.getTokenCount('test text');
}
expect(freeAndResetEncoderSpy).toHaveBeenCalled();
});
it('should not reset the encoder and count when count is less than 25', () => {
const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders');
freeAndResetEncoderSpy.mockClear();
// Call getTokenCount 24 times
for (let i = 0; i < 24; i++) {
client.getTokenCount('test text');
}
expect(freeAndResetEncoderSpy).not.toHaveBeenCalled();
});
it('should handle errors and reset the encoder', () => {
const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders');
// Mock encode function to throw an error
client.selectTokenizer().encode = jest.fn().mockImplementation(() => {
throw new Error('Test error');
});
client.getTokenCount('test text');
expect(freeAndResetEncoderSpy).toHaveBeenCalled();
});
it('should not throw null pointer error when freeing the same encoder twice', () => {
client.constructor.freeAndResetAllEncoders();
client2.constructor.freeAndResetAllEncoders();
const count = client2.getTokenCount('test text');
expect(count).toBeGreaterThan(0);
});
});
describe('getSaveOptions', () => {
@ -548,7 +488,6 @@ describe('OpenAIClient', () => {
testCases.forEach((testCase) => {
it(`should return ${testCase.expected} tokens for model ${testCase.model}`, () => {
client.modelOptions.model = testCase.model;
client.selectTokenizer();
// 3 tokens for assistant label
let totalTokens = 3;
for (let message of example_messages) {
@ -582,7 +521,6 @@ describe('OpenAIClient', () => {
it(`should return ${expectedTokens} tokens for model ${visionModel} (Vision Request)`, () => {
client.modelOptions.model = visionModel;
client.selectTokenizer();
// 3 tokens for assistant label
let totalTokens = 3;
for (let message of vision_request) {