mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-17 08:50:15 +01:00
fix(OpenAIClient): resolve null pointer exception in tokenizer management (#689)
This commit is contained in:
parent
130356654c
commit
197307d514
2 changed files with 97 additions and 48 deletions
|
|
@ -6,7 +6,10 @@ const {
|
||||||
} = require('@dqbd/tiktoken');
|
} = require('@dqbd/tiktoken');
|
||||||
const { maxTokensMap, genAzureChatCompletion } = require('../../utils');
|
const { maxTokensMap, genAzureChatCompletion } = require('../../utils');
|
||||||
|
|
||||||
|
// Cache to store Tiktoken instances
|
||||||
const tokenizersCache = {};
|
const tokenizersCache = {};
|
||||||
|
// Counter for keeping track of the number of tokenizer calls
|
||||||
|
let tokenizerCallsCount = 0;
|
||||||
|
|
||||||
class OpenAIClient extends BaseClient {
|
class OpenAIClient extends BaseClient {
|
||||||
constructor(apiKey, options = {}) {
|
constructor(apiKey, options = {}) {
|
||||||
|
|
@ -89,7 +92,6 @@ class OpenAIClient extends BaseClient {
|
||||||
this.chatGptLabel = this.options.chatGptLabel || 'Assistant';
|
this.chatGptLabel = this.options.chatGptLabel || 'Assistant';
|
||||||
|
|
||||||
this.setupTokens();
|
this.setupTokens();
|
||||||
this.setupTokenizer();
|
|
||||||
|
|
||||||
if (!this.modelOptions.stop) {
|
if (!this.modelOptions.stop) {
|
||||||
const stopTokens = [this.startToken];
|
const stopTokens = [this.startToken];
|
||||||
|
|
@ -133,68 +135,87 @@ class OpenAIClient extends BaseClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
setupTokenizer() {
|
// Selects an appropriate tokenizer based on the current configuration of the client instance.
|
||||||
|
// It takes into account factors such as whether it's a chat completion, an unofficial chat GPT model, etc.
|
||||||
|
selectTokenizer() {
|
||||||
|
let tokenizer;
|
||||||
this.encoding = 'text-davinci-003';
|
this.encoding = 'text-davinci-003';
|
||||||
if (this.isChatCompletion) {
|
if (this.isChatCompletion) {
|
||||||
this.encoding = 'cl100k_base';
|
this.encoding = 'cl100k_base';
|
||||||
this.gptEncoder = this.constructor.getTokenizer(this.encoding);
|
tokenizer = this.constructor.getTokenizer(this.encoding);
|
||||||
} else if (this.isUnofficialChatGptModel) {
|
} else if (this.isUnofficialChatGptModel) {
|
||||||
this.gptEncoder = this.constructor.getTokenizer(this.encoding, true, {
|
const extendSpecialTokens = {
|
||||||
'<|im_start|>': 100264,
|
'<|im_start|>': 100264,
|
||||||
'<|im_end|>': 100265,
|
'<|im_end|>': 100265,
|
||||||
});
|
};
|
||||||
|
tokenizer = this.constructor.getTokenizer(this.encoding, true, extendSpecialTokens);
|
||||||
} else {
|
} else {
|
||||||
try {
|
try {
|
||||||
this.encoding = this.modelOptions.model;
|
this.encoding = this.modelOptions.model;
|
||||||
this.gptEncoder = this.constructor.getTokenizer(this.modelOptions.model, true);
|
tokenizer = this.constructor.getTokenizer(this.modelOptions.model, true);
|
||||||
} catch {
|
} catch {
|
||||||
this.gptEncoder = this.constructor.getTokenizer(this.encoding, true);
|
tokenizer = this.constructor.getTokenizer(this.encoding, true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
|
|
||||||
if (tokenizersCache[encoding]) {
|
|
||||||
return tokenizersCache[encoding];
|
|
||||||
}
|
|
||||||
let tokenizer;
|
|
||||||
if (isModelName) {
|
|
||||||
tokenizer = encodingForModel(encoding, extendSpecialTokens);
|
|
||||||
} else {
|
|
||||||
tokenizer = getEncoding(encoding, extendSpecialTokens);
|
|
||||||
}
|
|
||||||
tokenizersCache[encoding] = tokenizer;
|
|
||||||
return tokenizer;
|
return tokenizer;
|
||||||
}
|
}
|
||||||
|
|
||||||
freeAndResetEncoder() {
|
// Retrieves a tokenizer either from the cache or creates a new one if one doesn't exist in the cache.
|
||||||
try {
|
// If a tokenizer is being created, it's also added to the cache.
|
||||||
if (!this.gptEncoder) {
|
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
|
||||||
return;
|
let tokenizer;
|
||||||
|
if (tokenizersCache[encoding]) {
|
||||||
|
tokenizer = tokenizersCache[encoding];
|
||||||
|
} else {
|
||||||
|
if (isModelName) {
|
||||||
|
tokenizer = encodingForModel(encoding, extendSpecialTokens);
|
||||||
|
} else {
|
||||||
|
tokenizer = getEncoding(encoding, extendSpecialTokens);
|
||||||
}
|
}
|
||||||
this.gptEncoder.free();
|
tokenizersCache[encoding] = tokenizer;
|
||||||
delete tokenizersCache[this.encoding];
|
}
|
||||||
delete tokenizersCache.count;
|
return tokenizer;
|
||||||
this.setupTokenizer();
|
}
|
||||||
|
|
||||||
|
// Frees all encoders in the cache and resets the count.
|
||||||
|
static freeAndResetAllEncoders() {
|
||||||
|
try {
|
||||||
|
Object.keys(tokenizersCache).forEach((key) => {
|
||||||
|
if (tokenizersCache[key]) {
|
||||||
|
tokenizersCache[key].free();
|
||||||
|
delete tokenizersCache[key];
|
||||||
|
}
|
||||||
|
});
|
||||||
|
// Reset count
|
||||||
|
tokenizerCallsCount = 1;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.log('freeAndResetEncoder error');
|
console.log('Free and reset encoders error');
|
||||||
console.error(error);
|
console.error(error);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
getTokenCount(text) {
|
// Checks if the cache of tokenizers has reached a certain size. If it has, it frees and resets all tokenizers.
|
||||||
try {
|
resetTokenizersIfNecessary() {
|
||||||
if (tokenizersCache.count >= 25) {
|
if (tokenizerCallsCount >= 25) {
|
||||||
if (this.options.debug) {
|
if (this.options.debug) {
|
||||||
console.debug('freeAndResetEncoder: reached 25 encodings, reseting...');
|
console.debug('freeAndResetAllEncoders: reached 25 encodings, resetting...');
|
||||||
}
|
|
||||||
this.freeAndResetEncoder();
|
|
||||||
}
|
}
|
||||||
tokenizersCache.count = (tokenizersCache.count || 0) + 1;
|
this.constructor.freeAndResetAllEncoders();
|
||||||
return this.gptEncoder.encode(text, 'all').length;
|
}
|
||||||
|
tokenizerCallsCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the token count of a given text. It also checks and resets the tokenizers if necessary.
|
||||||
|
getTokenCount(text) {
|
||||||
|
this.resetTokenizersIfNecessary();
|
||||||
|
try {
|
||||||
|
const tokenizer = this.selectTokenizer();
|
||||||
|
return tokenizer.encode(text, 'all').length;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
this.freeAndResetEncoder();
|
this.constructor.freeAndResetAllEncoders();
|
||||||
return this.gptEncoder.encode(text, 'all').length;
|
const tokenizer = this.selectTokenizer();
|
||||||
|
return tokenizer.encode(text, 'all').length;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
const OpenAIClient = require('../OpenAIClient');
|
const OpenAIClient = require('../OpenAIClient');
|
||||||
|
|
||||||
describe('OpenAIClient', () => {
|
describe('OpenAIClient', () => {
|
||||||
let client;
|
let client, client2;
|
||||||
const model = 'gpt-4';
|
const model = 'gpt-4';
|
||||||
const parentMessageId = '1';
|
const parentMessageId = '1';
|
||||||
const messages = [
|
const messages = [
|
||||||
|
|
@ -19,11 +19,13 @@ describe('OpenAIClient', () => {
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
client = new OpenAIClient('test-api-key', options);
|
client = new OpenAIClient('test-api-key', options);
|
||||||
|
client2 = new OpenAIClient('test-api-key', options);
|
||||||
client.refineMessages = jest.fn().mockResolvedValue({
|
client.refineMessages = jest.fn().mockResolvedValue({
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: 'Refined answer',
|
content: 'Refined answer',
|
||||||
tokenCount: 30,
|
tokenCount: 30,
|
||||||
});
|
});
|
||||||
|
client.constructor.freeAndResetAllEncoders();
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('setOptions', () => {
|
describe('setOptions', () => {
|
||||||
|
|
@ -34,10 +36,25 @@ describe('OpenAIClient', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('freeAndResetEncoder', () => {
|
describe('selectTokenizer', () => {
|
||||||
it('should reset the encoder', () => {
|
it('should get the correct tokenizer based on the instance state', () => {
|
||||||
client.freeAndResetEncoder();
|
const tokenizer = client.selectTokenizer();
|
||||||
expect(client.gptEncoder).toBeDefined();
|
expect(tokenizer).toBeDefined();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('freeAllTokenizers', () => {
|
||||||
|
it('should free all tokenizers', () => {
|
||||||
|
// Create a tokenizer
|
||||||
|
const tokenizer = client.selectTokenizer();
|
||||||
|
|
||||||
|
// Mock 'free' method on the tokenizer
|
||||||
|
tokenizer.free = jest.fn();
|
||||||
|
|
||||||
|
client.constructor.freeAndResetAllEncoders();
|
||||||
|
|
||||||
|
// Check if 'free' method has been called on the tokenizer
|
||||||
|
expect(tokenizer.free).toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -48,7 +65,7 @@ describe('OpenAIClient', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should reset the encoder and count when count reaches 25', () => {
|
it('should reset the encoder and count when count reaches 25', () => {
|
||||||
const freeAndResetEncoderSpy = jest.spyOn(client, 'freeAndResetEncoder');
|
const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders');
|
||||||
|
|
||||||
// Call getTokenCount 25 times
|
// Call getTokenCount 25 times
|
||||||
for (let i = 0; i < 25; i++) {
|
for (let i = 0; i < 25; i++) {
|
||||||
|
|
@ -59,7 +76,8 @@ describe('OpenAIClient', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should not reset the encoder and count when count is less than 25', () => {
|
it('should not reset the encoder and count when count is less than 25', () => {
|
||||||
const freeAndResetEncoderSpy = jest.spyOn(client, 'freeAndResetEncoder');
|
const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders');
|
||||||
|
freeAndResetEncoderSpy.mockClear();
|
||||||
|
|
||||||
// Call getTokenCount 24 times
|
// Call getTokenCount 24 times
|
||||||
for (let i = 0; i < 24; i++) {
|
for (let i = 0; i < 24; i++) {
|
||||||
|
|
@ -70,8 +88,10 @@ describe('OpenAIClient', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle errors and reset the encoder', () => {
|
it('should handle errors and reset the encoder', () => {
|
||||||
const freeAndResetEncoderSpy = jest.spyOn(client, 'freeAndResetEncoder');
|
const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders');
|
||||||
client.gptEncoder.encode = jest.fn().mockImplementation(() => {
|
|
||||||
|
// Mock encode function to throw an error
|
||||||
|
client.selectTokenizer().encode = jest.fn().mockImplementation(() => {
|
||||||
throw new Error('Test error');
|
throw new Error('Test error');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -79,6 +99,14 @@ describe('OpenAIClient', () => {
|
||||||
|
|
||||||
expect(freeAndResetEncoderSpy).toHaveBeenCalled();
|
expect(freeAndResetEncoderSpy).toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should not throw null pointer error when freeing the same encoder twice', () => {
|
||||||
|
client.constructor.freeAndResetAllEncoders();
|
||||||
|
client2.constructor.freeAndResetAllEncoders();
|
||||||
|
|
||||||
|
const count = client2.getTokenCount('test text');
|
||||||
|
expect(count).toBeGreaterThan(0);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('getSaveOptions', () => {
|
describe('getSaveOptions', () => {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue