fix(OpenAIClient): resolve null pointer exception in tokenizer management (#689)

This commit is contained in:
Youngwook Kim 2023-07-24 00:59:11 +09:00 committed by GitHub
parent 130356654c
commit 197307d514
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 97 additions and 48 deletions

View file

@ -6,7 +6,10 @@ const {
} = require('@dqbd/tiktoken'); } = require('@dqbd/tiktoken');
const { maxTokensMap, genAzureChatCompletion } = require('../../utils'); const { maxTokensMap, genAzureChatCompletion } = require('../../utils');
// Cache to store Tiktoken instances
const tokenizersCache = {}; const tokenizersCache = {};
// Counter for keeping track of the number of tokenizer calls
let tokenizerCallsCount = 0;
class OpenAIClient extends BaseClient { class OpenAIClient extends BaseClient {
constructor(apiKey, options = {}) { constructor(apiKey, options = {}) {
@ -89,7 +92,6 @@ class OpenAIClient extends BaseClient {
this.chatGptLabel = this.options.chatGptLabel || 'Assistant'; this.chatGptLabel = this.options.chatGptLabel || 'Assistant';
this.setupTokens(); this.setupTokens();
this.setupTokenizer();
if (!this.modelOptions.stop) { if (!this.modelOptions.stop) {
const stopTokens = [this.startToken]; const stopTokens = [this.startToken];
@ -133,68 +135,87 @@ class OpenAIClient extends BaseClient {
} }
} }
setupTokenizer() { // Selects an appropriate tokenizer based on the current configuration of the client instance.
// It takes into account factors such as whether it's a chat completion, an unofficial chat GPT model, etc.
selectTokenizer() {
let tokenizer;
this.encoding = 'text-davinci-003'; this.encoding = 'text-davinci-003';
if (this.isChatCompletion) { if (this.isChatCompletion) {
this.encoding = 'cl100k_base'; this.encoding = 'cl100k_base';
this.gptEncoder = this.constructor.getTokenizer(this.encoding); tokenizer = this.constructor.getTokenizer(this.encoding);
} else if (this.isUnofficialChatGptModel) { } else if (this.isUnofficialChatGptModel) {
this.gptEncoder = this.constructor.getTokenizer(this.encoding, true, { const extendSpecialTokens = {
'<|im_start|>': 100264, '<|im_start|>': 100264,
'<|im_end|>': 100265, '<|im_end|>': 100265,
}); };
tokenizer = this.constructor.getTokenizer(this.encoding, true, extendSpecialTokens);
} else { } else {
try { try {
this.encoding = this.modelOptions.model; this.encoding = this.modelOptions.model;
this.gptEncoder = this.constructor.getTokenizer(this.modelOptions.model, true); tokenizer = this.constructor.getTokenizer(this.modelOptions.model, true);
} catch { } catch {
this.gptEncoder = this.constructor.getTokenizer(this.encoding, true); tokenizer = this.constructor.getTokenizer(this.encoding, true);
}
} }
} }
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) { return tokenizer;
if (tokenizersCache[encoding]) {
return tokenizersCache[encoding];
} }
// Returns the tokenizer cached for the given encoding, or creates a new one if none is cached.
// If a tokenizer is being created, it's also added to the cache.
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
let tokenizer; let tokenizer;
if (tokenizersCache[encoding]) {
tokenizer = tokenizersCache[encoding];
} else {
if (isModelName) { if (isModelName) {
tokenizer = encodingForModel(encoding, extendSpecialTokens); tokenizer = encodingForModel(encoding, extendSpecialTokens);
} else { } else {
tokenizer = getEncoding(encoding, extendSpecialTokens); tokenizer = getEncoding(encoding, extendSpecialTokens);
} }
tokenizersCache[encoding] = tokenizer; tokenizersCache[encoding] = tokenizer;
}
return tokenizer; return tokenizer;
} }
freeAndResetEncoder() { // Frees all encoders in the cache and resets the count.
static freeAndResetAllEncoders() {
try { try {
if (!this.gptEncoder) { Object.keys(tokenizersCache).forEach((key) => {
return; if (tokenizersCache[key]) {
tokenizersCache[key].free();
delete tokenizersCache[key];
} }
this.gptEncoder.free(); });
delete tokenizersCache[this.encoding]; // Reset count
delete tokenizersCache.count; tokenizerCallsCount = 1;
this.setupTokenizer();
} catch (error) { } catch (error) {
console.log('freeAndResetEncoder error'); console.log('Free and reset encoders error');
console.error(error); console.error(error);
} }
} }
getTokenCount(text) { // Checks whether the tokenizer call counter has reached 25. If it has, it frees and resets all cached tokenizers.
try { resetTokenizersIfNecessary() {
if (tokenizersCache.count >= 25) { if (tokenizerCallsCount >= 25) {
if (this.options.debug) { if (this.options.debug) {
console.debug('freeAndResetEncoder: reached 25 encodings, reseting...'); console.debug('freeAndResetAllEncoders: reached 25 encodings, resetting...');
} }
this.freeAndResetEncoder(); this.constructor.freeAndResetAllEncoders();
} }
tokenizersCache.count = (tokenizersCache.count || 0) + 1; tokenizerCallsCount++;
return this.gptEncoder.encode(text, 'all').length; }
// Returns the token count of a given text. It also checks and resets the tokenizers if necessary.
getTokenCount(text) {
this.resetTokenizersIfNecessary();
try {
const tokenizer = this.selectTokenizer();
return tokenizer.encode(text, 'all').length;
} catch (error) { } catch (error) {
this.freeAndResetEncoder(); this.constructor.freeAndResetAllEncoders();
return this.gptEncoder.encode(text, 'all').length; const tokenizer = this.selectTokenizer();
return tokenizer.encode(text, 'all').length;
} }
} }

View file

@ -1,7 +1,7 @@
const OpenAIClient = require('../OpenAIClient'); const OpenAIClient = require('../OpenAIClient');
describe('OpenAIClient', () => { describe('OpenAIClient', () => {
let client; let client, client2;
const model = 'gpt-4'; const model = 'gpt-4';
const parentMessageId = '1'; const parentMessageId = '1';
const messages = [ const messages = [
@ -19,11 +19,13 @@ describe('OpenAIClient', () => {
}, },
}; };
client = new OpenAIClient('test-api-key', options); client = new OpenAIClient('test-api-key', options);
client2 = new OpenAIClient('test-api-key', options);
client.refineMessages = jest.fn().mockResolvedValue({ client.refineMessages = jest.fn().mockResolvedValue({
role: 'assistant', role: 'assistant',
content: 'Refined answer', content: 'Refined answer',
tokenCount: 30, tokenCount: 30,
}); });
client.constructor.freeAndResetAllEncoders();
}); });
describe('setOptions', () => { describe('setOptions', () => {
@ -34,10 +36,25 @@ describe('OpenAIClient', () => {
}); });
}); });
describe('freeAndResetEncoder', () => { describe('selectTokenizer', () => {
it('should reset the encoder', () => { it('should get the correct tokenizer based on the instance state', () => {
client.freeAndResetEncoder(); const tokenizer = client.selectTokenizer();
expect(client.gptEncoder).toBeDefined(); expect(tokenizer).toBeDefined();
});
});
describe('freeAllTokenizers', () => {
it('should free all tokenizers', () => {
// Create a tokenizer
const tokenizer = client.selectTokenizer();
// Mock 'free' method on the tokenizer
tokenizer.free = jest.fn();
client.constructor.freeAndResetAllEncoders();
// Check if 'free' method has been called on the tokenizer
expect(tokenizer.free).toHaveBeenCalled();
}); });
}); });
@ -48,7 +65,7 @@ describe('OpenAIClient', () => {
}); });
it('should reset the encoder and count when count reaches 25', () => { it('should reset the encoder and count when count reaches 25', () => {
const freeAndResetEncoderSpy = jest.spyOn(client, 'freeAndResetEncoder'); const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders');
// Call getTokenCount 25 times // Call getTokenCount 25 times
for (let i = 0; i < 25; i++) { for (let i = 0; i < 25; i++) {
@ -59,7 +76,8 @@ describe('OpenAIClient', () => {
}); });
it('should not reset the encoder and count when count is less than 25', () => { it('should not reset the encoder and count when count is less than 25', () => {
const freeAndResetEncoderSpy = jest.spyOn(client, 'freeAndResetEncoder'); const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders');
freeAndResetEncoderSpy.mockClear();
// Call getTokenCount 24 times // Call getTokenCount 24 times
for (let i = 0; i < 24; i++) { for (let i = 0; i < 24; i++) {
@ -70,8 +88,10 @@ describe('OpenAIClient', () => {
}); });
it('should handle errors and reset the encoder', () => { it('should handle errors and reset the encoder', () => {
const freeAndResetEncoderSpy = jest.spyOn(client, 'freeAndResetEncoder'); const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders');
client.gptEncoder.encode = jest.fn().mockImplementation(() => {
// Mock encode function to throw an error
client.selectTokenizer().encode = jest.fn().mockImplementation(() => {
throw new Error('Test error'); throw new Error('Test error');
}); });
@ -79,6 +99,14 @@ describe('OpenAIClient', () => {
expect(freeAndResetEncoderSpy).toHaveBeenCalled(); expect(freeAndResetEncoderSpy).toHaveBeenCalled();
}); });
it('should not throw null pointer error when freeing the same encoder twice', () => {
client.constructor.freeAndResetAllEncoders();
client2.constructor.freeAndResetAllEncoders();
const count = client2.getTokenCount('test text');
expect(count).toBeGreaterThan(0);
});
}); });
describe('getSaveOptions', () => { describe('getSaveOptions', () => {