From 197307d514c93669ca84d45ee000e39b74369d00 Mon Sep 17 00:00:00 2001 From: Youngwook Kim Date: Mon, 24 Jul 2023 00:59:11 +0900 Subject: [PATCH] fix(OpenAIClient): resolve null pointer exception in tokenizer management (#689) --- api/app/clients/OpenAIClient.js | 99 +++++++++++++--------- api/app/clients/specs/OpenAIClient.test.js | 46 ++++++++-- 2 files changed, 97 insertions(+), 48 deletions(-) diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 97f7851071..53f4815d74 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -6,7 +6,10 @@ const { } = require('@dqbd/tiktoken'); const { maxTokensMap, genAzureChatCompletion } = require('../../utils'); +// Cache to store Tiktoken instances const tokenizersCache = {}; +// Counter for keeping track of the number of tokenizer calls +let tokenizerCallsCount = 0; class OpenAIClient extends BaseClient { constructor(apiKey, options = {}) { @@ -89,7 +92,6 @@ class OpenAIClient extends BaseClient { this.chatGptLabel = this.options.chatGptLabel || 'Assistant'; this.setupTokens(); - this.setupTokenizer(); if (!this.modelOptions.stop) { const stopTokens = [this.startToken]; @@ -133,68 +135,87 @@ class OpenAIClient extends BaseClient { } } - setupTokenizer() { + // Selects an appropriate tokenizer based on the current configuration of the client instance. + // It takes into account factors such as whether it's a chat completion, an unofficial chat GPT model, etc. + selectTokenizer() { + let tokenizer; this.encoding = 'text-davinci-003'; if (this.isChatCompletion) { this.encoding = 'cl100k_base'; - this.gptEncoder = this.constructor.getTokenizer(this.encoding); + tokenizer = this.constructor.getTokenizer(this.encoding); } else if (this.isUnofficialChatGptModel) { - this.gptEncoder = this.constructor.getTokenizer(this.encoding, true, { + const extendSpecialTokens = { '<|im_start|>': 100264, '<|im_end|>': 100265, - }); + }; + tokenizer = this.constructor.getTokenizer(this.encoding, true, extendSpecialTokens); } else { try { this.encoding = this.modelOptions.model; - this.gptEncoder = this.constructor.getTokenizer(this.modelOptions.model, true); + tokenizer = this.constructor.getTokenizer(this.modelOptions.model, true); } catch { - this.gptEncoder = this.constructor.getTokenizer(this.encoding, true); + tokenizer = this.constructor.getTokenizer(this.encoding, true); } } - } - static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) { - if (tokenizersCache[encoding]) { - return tokenizersCache[encoding]; - } - let tokenizer; - if (isModelName) { - tokenizer = encodingForModel(encoding, extendSpecialTokens); - } else { - tokenizer = getEncoding(encoding, extendSpecialTokens); - } - tokenizersCache[encoding] = tokenizer; return tokenizer; } - freeAndResetEncoder() { - try { - if (!this.gptEncoder) { - return; + // Retrieves a tokenizer either from the cache or creates a new one if one doesn't exist in the cache. + // If a tokenizer is being created, it's also added to the cache. + static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) { + let tokenizer; + if (tokenizersCache[encoding]) { + tokenizer = tokenizersCache[encoding]; + } else { + if (isModelName) { + tokenizer = encodingForModel(encoding, extendSpecialTokens); + } else { + tokenizer = getEncoding(encoding, extendSpecialTokens); } - this.gptEncoder.free(); - delete tokenizersCache[this.encoding]; - delete tokenizersCache.count; - this.setupTokenizer(); + tokenizersCache[encoding] = tokenizer; + } + return tokenizer; + } + + // Frees all encoders in the cache and resets the count. + static freeAndResetAllEncoders() { + try { + Object.keys(tokenizersCache).forEach((key) => { + if (tokenizersCache[key]) { + tokenizersCache[key].free(); + delete tokenizersCache[key]; + } + }); + // Reset count + tokenizerCallsCount = 1; } catch (error) { - console.log('freeAndResetEncoder error'); + console.log('Free and reset encoders error'); console.error(error); } } - getTokenCount(text) { - try { - if (tokenizersCache.count >= 25) { - if (this.options.debug) { - console.debug('freeAndResetEncoder: reached 25 encodings, reseting...'); - } - this.freeAndResetEncoder(); + // Checks if the cache of tokenizers has reached a certain size. If it has, it frees and resets all tokenizers. + resetTokenizersIfNecessary() { + if (tokenizerCallsCount >= 25) { + if (this.options.debug) { + console.debug('freeAndResetAllEncoders: reached 25 encodings, resetting...'); } - tokenizersCache.count = (tokenizersCache.count || 0) + 1; - return this.gptEncoder.encode(text, 'all').length; + this.constructor.freeAndResetAllEncoders(); + } + tokenizerCallsCount++; + } + + // Returns the token count of a given text. It also checks and resets the tokenizers if necessary. + getTokenCount(text) { + this.resetTokenizersIfNecessary(); + try { + const tokenizer = this.selectTokenizer(); + return tokenizer.encode(text, 'all').length; } catch (error) { - this.freeAndResetEncoder(); - return this.gptEncoder.encode(text, 'all').length; + this.constructor.freeAndResetAllEncoders(); + const tokenizer = this.selectTokenizer(); + return tokenizer.encode(text, 'all').length; } } diff --git a/api/app/clients/specs/OpenAIClient.test.js b/api/app/clients/specs/OpenAIClient.test.js index 0badda47ee..41aeb4f3b4 100644 --- a/api/app/clients/specs/OpenAIClient.test.js +++ b/api/app/clients/specs/OpenAIClient.test.js @@ -1,7 +1,7 @@ const OpenAIClient = require('../OpenAIClient'); describe('OpenAIClient', () => { - let client; + let client, client2; const model = 'gpt-4'; const parentMessageId = '1'; const messages = [ @@ -19,11 +19,13 @@ describe('OpenAIClient', () => { }, }; client = new OpenAIClient('test-api-key', options); + client2 = new OpenAIClient('test-api-key', options); client.refineMessages = jest.fn().mockResolvedValue({ role: 'assistant', content: 'Refined answer', tokenCount: 30, }); + client.constructor.freeAndResetAllEncoders(); }); describe('setOptions', () => { @@ -34,10 +36,25 @@ describe('OpenAIClient', () => { }); }); - describe('freeAndResetEncoder', () => { - it('should reset the encoder', () => { - client.freeAndResetEncoder(); - expect(client.gptEncoder).toBeDefined(); + describe('selectTokenizer', () => { + it('should get the correct tokenizer based on the instance state', () => { + const tokenizer = client.selectTokenizer(); + expect(tokenizer).toBeDefined(); + }); + }); + + describe('freeAllTokenizers', () => { + it('should free all tokenizers', () => { + // Create a tokenizer + const tokenizer = client.selectTokenizer(); + + // Mock 'free' method on the tokenizer + tokenizer.free = jest.fn(); + + client.constructor.freeAndResetAllEncoders(); + + // Check if 'free' method has been called on the tokenizer + expect(tokenizer.free).toHaveBeenCalled(); }); }); @@ -48,7 +65,7 @@ describe('OpenAIClient', () => { }); it('should reset the encoder and count when count reaches 25', () => { - const freeAndResetEncoderSpy = jest.spyOn(client, 'freeAndResetEncoder'); + const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders'); // Call getTokenCount 25 times for (let i = 0; i < 25; i++) { @@ -59,7 +76,8 @@ describe('OpenAIClient', () => { }); it('should not reset the encoder and count when count is less than 25', () => { - const freeAndResetEncoderSpy = jest.spyOn(client, 'freeAndResetEncoder'); + const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders'); + freeAndResetEncoderSpy.mockClear(); // Call getTokenCount 24 times for (let i = 0; i < 24; i++) { @@ -70,8 +88,10 @@ describe('OpenAIClient', () => { }); it('should handle errors and reset the encoder', () => { - const freeAndResetEncoderSpy = jest.spyOn(client, 'freeAndResetEncoder'); - client.gptEncoder.encode = jest.fn().mockImplementation(() => { + const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders'); + + // Mock encode function to throw an error + client.selectTokenizer().encode = jest.fn().mockImplementation(() => { throw new Error('Test error'); }); @@ -79,6 +99,14 @@ describe('OpenAIClient', () => { expect(freeAndResetEncoderSpy).toHaveBeenCalled(); }); + + it('should not throw null pointer error when freeing the same encoder twice', () => { + client.constructor.freeAndResetAllEncoders(); + client2.constructor.freeAndResetAllEncoders(); + + const count = client2.getTokenCount('test text'); + expect(count).toBeGreaterThan(0); + }); }); describe('getSaveOptions', () => {