Mirror of https://github.com/danny-avila/LibreChat.git (synced 2025-12-18 01:10:14 +01:00)
fix: Match OpenAI Token Counting Strategy 🪙 (#945)
* wip token fix
* fix: complete token count refactor to match OpenAI example
* chore: add back sendPayload method (accidentally deleted)
* chore: revise JSDoc for getTokenCountForMessage
This commit is contained in:
parent b3afd562b9
commit 9491b753c3

11 changed files with 115 additions and 76 deletions
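The refactor drops the old gpt-4 branching and follows the OpenAI cookbook accounting: a fixed per-message overhead (3 tokens, or 4 for gpt-3.5-turbo-0301), a per-`name` adjustment, and 3 tokens of assistant-label priming added once after all messages are counted. A minimal standalone sketch of that scheme, for orientation only; `countPromptTokens` and `countTokens` are illustrative names and stand in for the clients' tokenizer-backed `getTokenCount`, they are not part of this codebase:

// Illustrative sketch of the adopted counting scheme (see the hunks below for the real code).
function countPromptTokens(messages, model, countTokens) {
  // gpt-3.5-turbo-0301 is the only model with different per-message constants.
  const isLegacyTurbo = model === 'gpt-3.5-turbo-0301';
  const tokensPerMessage = isLegacyTurbo ? 4 : 3;
  const tokensPerName = isLegacyTurbo ? -1 : 1;

  let numTokens = 0;
  for (const message of messages) {
    numTokens += tokensPerMessage;
    for (const [key, value] of Object.entries(message)) {
      numTokens += countTokens(value);
      if (key === 'name') {
        numTokens += tokensPerName;
      }
    }
  }
  // Every reply is primed with <|start|>assistant<|message|>.
  return numTokens + 3;
}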
@@ -1,9 +1,6 @@
 // const { Agent, ProxyAgent } = require('undici');
 const BaseClient = require('./BaseClient');
-const {
-  encoding_for_model: encodingForModel,
-  get_encoding: getEncoding,
-} = require('@dqbd/tiktoken');
+const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const Anthropic = require('@anthropic-ai/sdk');

 const HUMAN_PROMPT = '\n\nHuman:';
@@ -272,7 +272,9 @@ class BaseClient {
   * @returns {Object} An object with three properties: `context`, `remainingContextTokens`, and `messagesToRefine`. `context` is an array of messages that fit within the token limit. `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context. `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit.
   */
  async getMessagesWithinTokenLimit(messages) {
-    let currentTokenCount = 0;
+    // Every reply is primed with <|start|>assistant<|message|>, so we
+    // start with 3 tokens for the label after all messages have been counted.
+    let currentTokenCount = 3;
    let context = [];
    let messagesToRefine = [];
    let refineIndex = -1;
@@ -562,44 +564,29 @@ class BaseClient {
   * Algorithm adapted from "6. Counting tokens for chat API calls" of
   * https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
   *
-  * An additional 2 tokens need to be added for metadata after all messages have been counted.
+  * An additional 3 tokens need to be added for assistant label priming after all messages have been counted.
   *
-  * @param {*} message
+  * @param {Object} message
   */
  getTokenCountForMessage(message) {
-    let tokensPerMessage;
-    let nameAdjustment;
-    if (this.modelOptions.model.startsWith('gpt-4')) {
-      tokensPerMessage = 3;
-      nameAdjustment = 1;
-    } else {
-      tokensPerMessage = 4;
-      nameAdjustment = -1;
-    }
-
-    if (this.options.debug) {
-      console.debug('getTokenCountForMessage', message);
-    }
-
-    // Map each property of the message to the number of tokens it contains
-    const propertyTokenCounts = Object.entries(message).map(([key, value]) => {
-      if (key === 'tokenCount' || typeof value !== 'string') {
-        return 0;
-      }
-      // Count the number of tokens in the property value
-      const numTokens = this.getTokenCount(value);
-
-      // Adjust by `nameAdjustment` tokens if the property key is 'name'
-      const adjustment = key === 'name' ? nameAdjustment : 0;
-      return numTokens + adjustment;
-    });
-
-    if (this.options.debug) {
-      console.debug('propertyTokenCounts', propertyTokenCounts);
-    }
-
-    // Sum the number of tokens in all properties and add `tokensPerMessage` for metadata
-    return propertyTokenCounts.reduce((a, b) => a + b, tokensPerMessage);
+    // Note: gpt-3.5-turbo and gpt-4 may update over time. Use default for these as well as for unknown models
+    let tokensPerMessage = 3;
+    let tokensPerName = 1;
+
+    if (this.modelOptions.model === 'gpt-3.5-turbo-0301') {
+      tokensPerMessage = 4;
+      tokensPerName = -1;
+    }
+
+    let numTokens = tokensPerMessage;
+    for (let [key, value] of Object.entries(message)) {
+      numTokens += this.getTokenCount(value);
+      if (key === 'name') {
+        numTokens += tokensPerName;
+      }
+    }
+
+    return numTokens;
  }

  async sendPayload(payload, opts = {}) {
@@ -1,9 +1,6 @@
 const crypto = require('crypto');
 const Keyv = require('keyv');
-const {
-  encoding_for_model: encodingForModel,
-  get_encoding: getEncoding,
-} = require('@dqbd/tiktoken');
+const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
 const { Agent, ProxyAgent } = require('undici');
 const BaseClient = require('./BaseClient');
@@ -526,8 +523,8 @@ ${botMessage.message}
    const prompt = `${promptBody}${promptSuffix}`;
    if (isChatGptModel) {
      messagePayload.content = prompt;
-      // Add 2 tokens for metadata after all messages have been counted.
-      currentTokenCount += 2;
+      // Add 3 tokens for Assistant Label priming after all messages have been counted.
+      currentTokenCount += 3;
    }

    // Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
@@ -554,33 +551,29 @@ ${botMessage.message}
   * Algorithm adapted from "6. Counting tokens for chat API calls" of
   * https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
   *
-  * An additional 2 tokens need to be added for metadata after all messages have been counted.
+  * An additional 3 tokens need to be added for assistant label priming after all messages have been counted.
   *
-  * @param {*} message
+  * @param {Object} message
   */
  getTokenCountForMessage(message) {
-    let tokensPerMessage;
-    let nameAdjustment;
-    if (this.modelOptions.model.startsWith('gpt-4')) {
-      tokensPerMessage = 3;
-      nameAdjustment = 1;
-    } else {
-      tokensPerMessage = 4;
-      nameAdjustment = -1;
-    }
-
-    // Map each property of the message to the number of tokens it contains
-    const propertyTokenCounts = Object.entries(message).map(([key, value]) => {
-      // Count the number of tokens in the property value
-      const numTokens = this.getTokenCount(value);
-
-      // Adjust by `nameAdjustment` tokens if the property key is 'name'
-      const adjustment = key === 'name' ? nameAdjustment : 0;
-      return numTokens + adjustment;
-    });
-
-    // Sum the number of tokens in all properties and add `tokensPerMessage` for metadata
-    return propertyTokenCounts.reduce((a, b) => a + b, tokensPerMessage);
+    // Note: gpt-3.5-turbo and gpt-4 may update over time. Use default for these as well as for unknown models
+    let tokensPerMessage = 3;
+    let tokensPerName = 1;
+
+    if (this.modelOptions.model === 'gpt-3.5-turbo-0301') {
+      tokensPerMessage = 4;
+      tokensPerName = -1;
+    }
+
+    let numTokens = tokensPerMessage;
+    for (let [key, value] of Object.entries(message)) {
+      numTokens += this.getTokenCount(value);
+      if (key === 'name') {
+        numTokens += tokensPerName;
+      }
+    }
+
+    return numTokens;
  }
 }
@@ -1,10 +1,7 @@
 const BaseClient = require('./BaseClient');
 const { google } = require('googleapis');
 const { Agent, ProxyAgent } = require('undici');
-const {
-  encoding_for_model: encodingForModel,
-  get_encoding: getEncoding,
-} = require('@dqbd/tiktoken');
+const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');

 const tokenizersCache = {};
@@ -1,9 +1,6 @@
 const BaseClient = require('./BaseClient');
 const ChatGPTClient = require('./ChatGPTClient');
-const {
-  encoding_for_model: encodingForModel,
-  get_encoding: getEncoding,
-} = require('@dqbd/tiktoken');
+const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const { maxTokensMap, genAzureChatCompletion } = require('../../utils');
 const { runTitleChain } = require('./chains');
 const { createLLM } = require('./llm');
@@ -138,7 +138,8 @@ describe('BaseClient', () => {
        { role: 'assistant', content: 'How can I help you?', tokenCount: 19 },
        { role: 'user', content: 'I have a question.', tokenCount: 18 },
      ];
-      const expectedRemainingContextTokens = 58; // 100 - 5 - 19 - 18
+      // Subtract 3 tokens for Assistant Label priming after all messages have been counted.
+      const expectedRemainingContextTokens = 58 - 3; // (100 - 5 - 19 - 18) - 3
      const expectedMessagesToRefine = [];

      const result = await TestClient.getMessagesWithinTokenLimit(messages);
@@ -168,7 +169,9 @@ describe('BaseClient', () => {
        { role: 'assistant', content: 'How can I help you?', tokenCount: 19 },
        { role: 'user', content: 'I have a question.', tokenCount: 18 },
      ];
-      const expectedRemainingContextTokens = 8; // 50 - 18 - 19 - 5
+      // Subtract 3 tokens for Assistant Label priming after all messages have been counted.
+      const expectedRemainingContextTokens = 8 - 3; // (50 - 18 - 19 - 5) - 3
      const expectedMessagesToRefine = [
        { role: 'user', content: 'I need a coffee, stat!', tokenCount: 30 },
        { role: 'assistant', content: 'Sure, I can help with that.', tokenCount: 30 },
@@ -213,4 +213,63 @@ describe('OpenAIClient', () => {
      expect(result.prompt).toEqual([]);
    });
  });
+
+  describe('getTokenCountForMessage', () => {
+    const example_messages = [
+      {
+        role: 'system',
+        content:
+          'You are a helpful, pattern-following assistant that translates corporate jargon into plain English.',
+      },
+      {
+        role: 'system',
+        name: 'example_user',
+        content: 'New synergies will help drive top-line growth.',
+      },
+      {
+        role: 'system',
+        name: 'example_assistant',
+        content: 'Things working well together will increase revenue.',
+      },
+      {
+        role: 'system',
+        name: 'example_user',
+        content:
+          'Let\'s circle back when we have more bandwidth to touch base on opportunities for increased leverage.',
+      },
+      {
+        role: 'system',
+        name: 'example_assistant',
+        content: 'Let\'s talk later when we\'re less busy about how to do better.',
+      },
+      {
+        role: 'user',
+        content:
+          'This late pivot means we don\'t have time to boil the ocean for the client deliverable.',
+      },
+    ];
+
+    const testCases = [
+      { model: 'gpt-3.5-turbo-0301', expected: 127 },
+      { model: 'gpt-3.5-turbo-0613', expected: 129 },
+      { model: 'gpt-3.5-turbo', expected: 129 },
+      { model: 'gpt-4-0314', expected: 129 },
+      { model: 'gpt-4-0613', expected: 129 },
+      { model: 'gpt-4', expected: 129 },
+      { model: 'unknown', expected: 129 },
+    ];
+
+    testCases.forEach((testCase) => {
+      it(`should return ${testCase.expected} tokens for model ${testCase.model}`, () => {
+        client.modelOptions.model = testCase.model;
+        client.selectTokenizer();
+        // 3 tokens for assistant label
+        let totalTokens = 3;
+        for (let message of example_messages) {
+          totalTokens += client.getTokenCountForMessage(message);
+        }
+        expect(totalTokens).toBe(testCase.expected);
+      });
+    });
+  });
 });
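The expected totals differ only for gpt-3.5-turbo-0301, and the gap follows directly from the per-message and per-name constants in the refactored method. A quick back-of-the-envelope check, derived from the diff rather than taken from the test file:

// 6 messages in the cookbook example, 4 of which carry a `name` field;
// content tokens and the 3 priming tokens are identical across models.
// gpt-3.5-turbo-0301: 6 * 4 per-message + 4 * (-1) per-name = 20 overhead tokens
// all other models:   6 * 3 per-message + 4 * (+1) per-name = 22 overhead tokens
// Hence the 2-token gap: 129 - 2 = 127 for gpt-3.5-turbo-0301.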
@@ -23,7 +23,6 @@
   "dependencies": {
     "@anthropic-ai/sdk": "^0.5.4",
     "@azure/search-documents": "^11.3.2",
-    "@dqbd/tiktoken": "^1.0.7",
     "@keyv/mongo": "^2.1.8",
     "@waylaidwanderer/chatgpt-api": "^1.37.2",
     "axios": "^1.3.4",
@@ -60,6 +59,7 @@
     "passport-local": "^1.0.0",
     "pino": "^8.12.1",
     "sharp": "^0.32.5",
+    "tiktoken": "^1.0.10",
     "ua-parser-js": "^1.0.36",
     "zod": "^3.22.2"
   },
@ -1,9 +1,9 @@
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
const { Tiktoken } = require('@dqbd/tiktoken/lite');
|
const { Tiktoken } = require('tiktoken/lite');
|
||||||
const { load } = require('@dqbd/tiktoken/load');
|
const { load } = require('tiktoken/load');
|
||||||
const registry = require('@dqbd/tiktoken/registry.json');
|
const registry = require('tiktoken/registry.json');
|
||||||
const models = require('@dqbd/tiktoken/model_to_encoding.json');
|
const models = require('tiktoken/model_to_encoding.json');
|
||||||
const requireJwtAuth = require('../middleware/requireJwtAuth');
|
const requireJwtAuth = require('../middleware/requireJwtAuth');
|
||||||
|
|
||||||
router.post('/', requireJwtAuth, async (req, res) => {
|
router.post('/', requireJwtAuth, async (req, res) => {
|
||||||
|
|
|
||||||
|
|
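The body of this route is not shown in the hunk; after the rename it would presumably keep the lite-loading pattern the tiktoken package documents. The following is a sketch under that assumption, not the file's actual contents, with the model name chosen only for illustration:

// Hypothetical sketch of counting tokens with tiktoken's lite build.
router.post('/', requireJwtAuth, async (req, res) => {
  const { text } = req.body;
  // Load the BPE ranks for a model from the bundled registry.
  const model = await load(registry[models['gpt-3.5-turbo']]);
  const encoder = new Tiktoken(model.bpe_ranks, model.special_tokens, model.pat_str);
  const tokens = encoder.encode(text);
  encoder.free();
  res.send({ count: tokens.length });
});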
@@ -41,6 +41,7 @@ const maxTokensMap = {
  'gpt-4': 8191,
  'gpt-4-0613': 8191,
  'gpt-4-32k': 32767,
+ 'gpt-4-32k-0314': 32767,
  'gpt-4-32k-0613': 32767,
  'gpt-3.5-turbo': 4095,
  'gpt-3.5-turbo-0613': 4095,
package-lock.json (generated): 7 changed lines
@@ -44,7 +44,6 @@
     "dependencies": {
       "@anthropic-ai/sdk": "^0.5.4",
       "@azure/search-documents": "^11.3.2",
-      "@dqbd/tiktoken": "^1.0.7",
       "@keyv/mongo": "^2.1.8",
       "@waylaidwanderer/chatgpt-api": "^1.37.2",
       "axios": "^1.3.4",
@@ -81,6 +80,7 @@
       "passport-local": "^1.0.0",
       "pino": "^8.12.1",
       "sharp": "^0.32.5",
+      "tiktoken": "^1.0.10",
       "ua-parser-js": "^1.0.36",
       "zod": "^3.22.2"
     },
@@ -21886,6 +21886,11 @@
         "real-require": "^0.2.0"
       }
     },
+    "node_modules/tiktoken": {
+      "version": "1.0.10",
+      "resolved": "https://registry.npmjs.org/tiktoken/-/tiktoken-1.0.10.tgz",
+      "integrity": "sha512-gF8ndTCNu7WcRFbl1UUWaFIB4CTXmHzS3tRYdyUYF7x3C6YR6Evoao4zhKDmWIwv2PzNbzoQMV8Pxt+17lEDbA=="
+    },
     "node_modules/tmp": {
       "version": "0.0.33",
       "resolved": "https://registry.npmjs.org/tmp/-/tmp-0.0.33.tgz",