fix: Match OpenAI Token Counting Strategy 🪙 (#945)

* wip token fix

* fix: complete token count refactor to match OpenAI example

* chore: add back sendPayload method (accidentally deleted)

* chore: revise JSDoc for getTokenCountForMessage
Danny Avila 2023-09-14 19:40:21 -04:00 committed by GitHub
parent b3afd562b9
commit 9491b753c3
11 changed files with 115 additions and 76 deletions


@@ -1,9 +1,6 @@
 // const { Agent, ProxyAgent } = require('undici');
 const BaseClient = require('./BaseClient');
-const {
-  encoding_for_model: encodingForModel,
-  get_encoding: getEncoding,
-} = require('@dqbd/tiktoken');
+const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const Anthropic = require('@anthropic-ai/sdk');
 const HUMAN_PROMPT = '\n\nHuman:';
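
This hunk, like the import hunks in the other clients below, is the same dependency swap: `@dqbd/tiktoken` is replaced by the unscoped `tiktoken` package, which exposes the same named exports. A minimal sketch of the encoder lifecycle under the new package (not part of this commit; the model name is chosen only for illustration):

const { encoding_for_model: encodingForModel } = require('tiktoken');

const enc = encodingForModel('gpt-3.5-turbo'); // example model
const tokens = enc.encode('New synergies will help drive top-line growth.');
console.log(tokens.length); // token count for the string
enc.free(); // WASM-backed encoders must be freed explicitly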


@@ -272,7 +272,9 @@ class BaseClient {
   * @returns {Object} An object with three properties: `context`, `remainingContextTokens`, and `messagesToRefine`. `context` is an array of messages that fit within the token limit. `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context. `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit.
   */
  async getMessagesWithinTokenLimit(messages) {
-    let currentTokenCount = 0;
+    // Every reply is primed with <|start|>assistant<|message|>, so we
+    // start with 3 tokens for the label after all messages have been counted.
+    let currentTokenCount = 3;
     let context = [];
     let messagesToRefine = [];
     let refineIndex = -1;
@@ -562,44 +564,29 @@ class BaseClient {
   * Algorithm adapted from "6. Counting tokens for chat API calls" of
   * https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
   *
-  * An additional 2 tokens need to be added for metadata after all messages have been counted.
+  * An additional 3 tokens need to be added for assistant label priming after all messages have been counted.
   *
-  * @param {*} message
+  * @param {Object} message
   */
  getTokenCountForMessage(message) {
-    let tokensPerMessage;
-    let nameAdjustment;
-    if (this.modelOptions.model.startsWith('gpt-4')) {
-      tokensPerMessage = 3;
-      nameAdjustment = 1;
-    } else {
-      tokensPerMessage = 4;
-      nameAdjustment = -1;
-    }
-
-    if (this.options.debug) {
-      console.debug('getTokenCountForMessage', message);
-    }
-
-    // Map each property of the message to the number of tokens it contains
-    const propertyTokenCounts = Object.entries(message).map(([key, value]) => {
-      if (key === 'tokenCount' || typeof value !== 'string') {
-        return 0;
-      }
-      // Count the number of tokens in the property value
-      const numTokens = this.getTokenCount(value);
-
-      // Adjust by `nameAdjustment` tokens if the property key is 'name'
-      const adjustment = key === 'name' ? nameAdjustment : 0;
-      return numTokens + adjustment;
-    });
-
-    if (this.options.debug) {
-      console.debug('propertyTokenCounts', propertyTokenCounts);
-    }
-
-    // Sum the number of tokens in all properties and add `tokensPerMessage` for metadata
-    return propertyTokenCounts.reduce((a, b) => a + b, tokensPerMessage);
+    // Note: gpt-3.5-turbo and gpt-4 may update over time. Use default for these as well as for unknown models
+    let tokensPerMessage = 3;
+    let tokensPerName = 1;
+
+    if (this.modelOptions.model === 'gpt-3.5-turbo-0301') {
+      tokensPerMessage = 4;
+      tokensPerName = -1;
+    }
+
+    let numTokens = tokensPerMessage;
+    for (let [key, value] of Object.entries(message)) {
+      numTokens += this.getTokenCount(value);
+      if (key === 'name') {
+        numTokens += tokensPerName;
+      }
+    }
+    return numTokens;
  }

  async sendPayload(payload, opts = {}) {
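
For reference, the rewritten getTokenCountForMessage follows the cookbook's num_tokens_from_messages recipe: 3 overhead tokens per message (4 for the legacy gpt-3.5-turbo-0301 snapshot, which also subtracts one when a name replaces the role), plus 3 tokens to prime the assistant reply. A standalone sketch of the full calculation, not part of this commit, assuming the `tiktoken` package and a caller-supplied model name:

const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');

function countTokensForMessages(messages, model = 'gpt-3.5-turbo') {
  let enc;
  try {
    enc = encodingForModel(model);
  } catch {
    enc = getEncoding('cl100k_base'); // fall back for unknown model names
  }

  // Defaults per the cookbook; only the legacy 0301 snapshot differs.
  let tokensPerMessage = 3;
  let tokensPerName = 1;
  if (model === 'gpt-3.5-turbo-0301') {
    tokensPerMessage = 4; // every message follows <|start|>{role/name}\n{content}<|end|>\n
    tokensPerName = -1; // if there's a name, the role is omitted
  }

  let numTokens = 0;
  for (const message of messages) {
    numTokens += tokensPerMessage;
    for (const [key, value] of Object.entries(message)) {
      numTokens += enc.encode(value).length;
      if (key === 'name') {
        numTokens += tokensPerName;
      }
    }
  }
  numTokens += 3; // every reply is primed with <|start|>assistant<|message|>
  enc.free(); // release the WASM-backed encoder
  return numTokens;
}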


@@ -1,9 +1,6 @@
 const crypto = require('crypto');
 const Keyv = require('keyv');
-const {
-  encoding_for_model: encodingForModel,
-  get_encoding: getEncoding,
-} = require('@dqbd/tiktoken');
+const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
 const { Agent, ProxyAgent } = require('undici');
 const BaseClient = require('./BaseClient');
@@ -526,8 +523,8 @@ ${botMessage.message}
     const prompt = `${promptBody}${promptSuffix}`;
     if (isChatGptModel) {
       messagePayload.content = prompt;
-      // Add 2 tokens for metadata after all messages have been counted.
-      currentTokenCount += 2;
+      // Add 3 tokens for Assistant Label priming after all messages have been counted.
+      currentTokenCount += 3;
     }

     // Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
@@ -554,33 +551,29 @@ ${botMessage.message}
   * Algorithm adapted from "6. Counting tokens for chat API calls" of
   * https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
   *
-  * An additional 2 tokens need to be added for metadata after all messages have been counted.
+  * An additional 3 tokens need to be added for assistant label priming after all messages have been counted.
   *
-  * @param {*} message
+  * @param {Object} message
   */
  getTokenCountForMessage(message) {
-    let tokensPerMessage;
-    let nameAdjustment;
-    if (this.modelOptions.model.startsWith('gpt-4')) {
-      tokensPerMessage = 3;
-      nameAdjustment = 1;
-    } else {
-      tokensPerMessage = 4;
-      nameAdjustment = -1;
-    }
-
-    // Map each property of the message to the number of tokens it contains
-    const propertyTokenCounts = Object.entries(message).map(([key, value]) => {
-      // Count the number of tokens in the property value
-      const numTokens = this.getTokenCount(value);
-
-      // Adjust by `nameAdjustment` tokens if the property key is 'name'
-      const adjustment = key === 'name' ? nameAdjustment : 0;
-      return numTokens + adjustment;
-    });
-
-    // Sum the number of tokens in all properties and add `tokensPerMessage` for metadata
-    return propertyTokenCounts.reduce((a, b) => a + b, tokensPerMessage);
+    // Note: gpt-3.5-turbo and gpt-4 may update over time. Use default for these as well as for unknown models
+    let tokensPerMessage = 3;
+    let tokensPerName = 1;
+
+    if (this.modelOptions.model === 'gpt-3.5-turbo-0301') {
+      tokensPerMessage = 4;
+      tokensPerName = -1;
+    }
+
+    let numTokens = tokensPerMessage;
+    for (let [key, value] of Object.entries(message)) {
+      numTokens += this.getTokenCount(value);
+      if (key === 'name') {
+        numTokens += tokensPerName;
+      }
+    }
+    return numTokens;
  }
 }


@@ -1,10 +1,7 @@
 const BaseClient = require('./BaseClient');
 const { google } = require('googleapis');
 const { Agent, ProxyAgent } = require('undici');
-const {
-  encoding_for_model: encodingForModel,
-  get_encoding: getEncoding,
-} = require('@dqbd/tiktoken');
+const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const tokenizersCache = {};


@@ -1,9 +1,6 @@
 const BaseClient = require('./BaseClient');
 const ChatGPTClient = require('./ChatGPTClient');
-const {
-  encoding_for_model: encodingForModel,
-  get_encoding: getEncoding,
-} = require('@dqbd/tiktoken');
+const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const { maxTokensMap, genAzureChatCompletion } = require('../../utils');
 const { runTitleChain } = require('./chains');
 const { createLLM } = require('./llm');


@@ -138,7 +138,8 @@ describe('BaseClient', () => {
       { role: 'assistant', content: 'How can I help you?', tokenCount: 19 },
       { role: 'user', content: 'I have a question.', tokenCount: 18 },
     ];
-    const expectedRemainingContextTokens = 58; // 100 - 5 - 19 - 18
+    // Subtract 3 tokens for Assistant Label priming after all messages have been counted.
+    const expectedRemainingContextTokens = 58 - 3; // (100 - 5 - 19 - 18) - 3
     const expectedMessagesToRefine = [];

     const result = await TestClient.getMessagesWithinTokenLimit(messages);
@@ -168,7 +169,9 @@ describe('BaseClient', () => {
       { role: 'assistant', content: 'How can I help you?', tokenCount: 19 },
       { role: 'user', content: 'I have a question.', tokenCount: 18 },
     ];
-    const expectedRemainingContextTokens = 8; // 50 - 18 - 19 - 5
+    // Subtract 3 tokens for Assistant Label priming after all messages have been counted.
+    const expectedRemainingContextTokens = 8 - 3; // (50 - 18 - 19 - 5) - 3
     const expectedMessagesToRefine = [
       { role: 'user', content: 'I need a coffee, stat!', tokenCount: 30 },
       { role: 'assistant', content: 'Sure, I can help with that.', tokenCount: 30 },
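
Spelled out, the budget arithmetic these tests now assert (a restatement for clarity, not code from the commit): the limit starts at maxContextTokens, 3 tokens are reserved up front for the assistant-label priming, and each message kept in context subtracts its tokenCount.

// First test case above, restated:
const maxContextTokens = 100;
const included = [5, 19, 18]; // tokenCount of each message kept in context
const remaining = maxContextTokens - 3 - included.reduce((a, b) => a + b, 0);
console.log(remaining); // 55, i.e. 58 - 3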


@@ -213,4 +213,63 @@ describe('OpenAIClient', () => {
       expect(result.prompt).toEqual([]);
     });
   });
+
+  describe('getTokenCountForMessage', () => {
+    const example_messages = [
+      {
+        role: 'system',
+        content:
+          'You are a helpful, pattern-following assistant that translates corporate jargon into plain English.',
+      },
+      {
+        role: 'system',
+        name: 'example_user',
+        content: 'New synergies will help drive top-line growth.',
+      },
+      {
+        role: 'system',
+        name: 'example_assistant',
+        content: 'Things working well together will increase revenue.',
+      },
+      {
+        role: 'system',
+        name: 'example_user',
+        content:
+          'Let\'s circle back when we have more bandwidth to touch base on opportunities for increased leverage.',
+      },
+      {
+        role: 'system',
+        name: 'example_assistant',
+        content: 'Let\'s talk later when we\'re less busy about how to do better.',
+      },
+      {
+        role: 'user',
+        content:
+          'This late pivot means we don\'t have time to boil the ocean for the client deliverable.',
+      },
+    ];
+
+    const testCases = [
+      { model: 'gpt-3.5-turbo-0301', expected: 127 },
+      { model: 'gpt-3.5-turbo-0613', expected: 129 },
+      { model: 'gpt-3.5-turbo', expected: 129 },
+      { model: 'gpt-4-0314', expected: 129 },
+      { model: 'gpt-4-0613', expected: 129 },
+      { model: 'gpt-4', expected: 129 },
+      { model: 'unknown', expected: 129 },
+    ];
+
+    testCases.forEach((testCase) => {
+      it(`should return ${testCase.expected} tokens for model ${testCase.model}`, () => {
+        client.modelOptions.model = testCase.model;
+        client.selectTokenizer();
+        // 3 tokens for assistant label
+        let totalTokens = 3;
+        for (let message of example_messages) {
+          totalTokens += client.getTokenCountForMessage(message);
+        }
+        expect(totalTokens).toBe(testCase.expected);
+      });
+    });
+  });
 });


@@ -23,7 +23,6 @@
   "dependencies": {
     "@anthropic-ai/sdk": "^0.5.4",
     "@azure/search-documents": "^11.3.2",
-    "@dqbd/tiktoken": "^1.0.7",
     "@keyv/mongo": "^2.1.8",
     "@waylaidwanderer/chatgpt-api": "^1.37.2",
     "axios": "^1.3.4",
@@ -60,6 +59,7 @@
     "passport-local": "^1.0.0",
     "pino": "^8.12.1",
     "sharp": "^0.32.5",
+    "tiktoken": "^1.0.10",
     "ua-parser-js": "^1.0.36",
     "zod": "^3.22.2"
   },


@@ -1,9 +1,9 @@
 const express = require('express');
 const router = express.Router();
-const { Tiktoken } = require('@dqbd/tiktoken/lite');
-const { load } = require('@dqbd/tiktoken/load');
-const registry = require('@dqbd/tiktoken/registry.json');
-const models = require('@dqbd/tiktoken/model_to_encoding.json');
+const { Tiktoken } = require('tiktoken/lite');
+const { load } = require('tiktoken/load');
+const registry = require('tiktoken/registry.json');
+const models = require('tiktoken/model_to_encoding.json');
 const requireJwtAuth = require('../middleware/requireJwtAuth');

 router.post('/', requireJwtAuth, async (req, res) => {
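
The tokenizer route keeps the same lite-bundle pattern under the new package name; roughly, the flow looks like the sketch below (this helper is illustrative, not the route's actual handler):

const { Tiktoken } = require('tiktoken/lite');
const { load } = require('tiktoken/load');
const registry = require('tiktoken/registry.json');
const models = require('tiktoken/model_to_encoding.json');

// Illustrative helper: count tokens for arbitrary text with the lite build.
async function countTokens(text, model = 'gpt-3.5-turbo') {
  const spec = await load(registry[models[model]]); // fetch/look up the encoding definition
  const encoder = new Tiktoken(spec.bpe_ranks, spec.special_tokens, spec.pat_str);
  const count = encoder.encode(text).length;
  encoder.free(); // release the encoder's WASM memory
  return count;
}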


@@ -41,6 +41,7 @@ const maxTokensMap = {
   'gpt-4': 8191,
   'gpt-4-0613': 8191,
   'gpt-4-32k': 32767,
+  'gpt-4-32k-0314': 32767,
   'gpt-4-32k-0613': 32767,
   'gpt-3.5-turbo': 4095,
   'gpt-3.5-turbo-0613': 4095,
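
The new 'gpt-4-32k-0314' entry just extends the context-window lookup table; a hypothetical consumer of that map (the fallback value is chosen only for illustration, not taken from the code):

// Hypothetical lookup: resolve a model's context window, with a conservative fallback.
const getModelMaxTokens = (model) => maxTokensMap[model] ?? 4095;

getModelMaxTokens('gpt-4-32k-0314'); // 32767
getModelMaxTokens('some-unknown-model'); // 4095 (fallback)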

package-lock.json (generated): 7 changed lines

@@ -44,7 +44,6 @@
   "dependencies": {
     "@anthropic-ai/sdk": "^0.5.4",
     "@azure/search-documents": "^11.3.2",
-    "@dqbd/tiktoken": "^1.0.7",
     "@keyv/mongo": "^2.1.8",
     "@waylaidwanderer/chatgpt-api": "^1.37.2",
     "axios": "^1.3.4",
@@ -81,6 +80,7 @@
     "passport-local": "^1.0.0",
     "pino": "^8.12.1",
     "sharp": "^0.32.5",
+    "tiktoken": "^1.0.10",
     "ua-parser-js": "^1.0.36",
     "zod": "^3.22.2"
   },
@@ -21886,6 +21886,11 @@
         "real-require": "^0.2.0"
       }
     },
+    "node_modules/tiktoken": {
+      "version": "1.0.10",
+      "resolved": "https://registry.npmjs.org/tiktoken/-/tiktoken-1.0.10.tgz",
+      "integrity": "sha512-gF8ndTCNu7WcRFbl1UUWaFIB4CTXmHzS3tRYdyUYF7x3C6YR6Evoao4zhKDmWIwv2PzNbzoQMV8Pxt+17lEDbA=="
+    },
     "node_modules/tmp": {
       "version": "0.0.33",
       "resolved": "https://registry.npmjs.org/tmp/-/tmp-0.0.33.tgz",