From 4beb06aa4bf8767d9edcc8a24b22d558bc525044 Mon Sep 17 00:00:00 2001 From: Danny Avila <110412045+danny-avila@users.noreply.github.com> Date: Sun, 21 May 2023 12:43:06 -0400 Subject: [PATCH] Minor fixes: tokenizer, default Bing toneStyle, SiblingSwitch (#348) * fix: tokenizer will count completion tokens correctly, remove global var, will allow unofficial models for alternative endpoints * refactor(askBingAI.js, Settings.jsx, types.ts, cleanupPreset.js, getDefaultConversation.js, handleSubmit.js): change default toneStyle to 'creative' instead of 'fast' for Bing AI endpoint. * fix(SiblingSwitch): correctly appears now style(HoverButtons.jsx): add 'active' class to hover buttons --- api/app/clients/chatgpt-client.js | 33 ++++---- api/server/routes/ask/askBingAI.js | 6 +- api/utils/tiktokenModels.js | 40 ++++++++++ .../components/Endpoints/BingAI/Settings.jsx | 4 +- .../src/components/Messages/HoverButtons.jsx | 4 +- client/src/data-provider/types.ts | 26 +++---- client/src/mobile.css | 78 +++++++++++-------- client/src/utils/cleanupPreset.js | 2 +- client/src/utils/getDefaultConversation.js | 8 +- client/src/utils/handleSubmit.js | 2 +- 10 files changed, 125 insertions(+), 78 deletions(-) create mode 100644 api/utils/tiktokenModels.js diff --git a/api/app/clients/chatgpt-client.js b/api/app/clients/chatgpt-client.js index 2ce7367a8..a5e9a05e6 100644 --- a/api/app/clients/chatgpt-client.js +++ b/api/app/clients/chatgpt-client.js @@ -2,6 +2,7 @@ require('dotenv').config(); const { KeyvFile } = require('keyv-file'); const { genAzureEndpoint } = require('../../utils/genAzureEndpoints'); const tiktoken = require('@dqbd/tiktoken'); +const tiktokenModels = require('../../utils/tiktokenModels'); const encoding_for_model = tiktoken.encoding_for_model; const askClient = async ({ @@ -26,9 +27,8 @@ const askClient = async ({ }; const azure = process.env.AZURE_OPENAI_API_KEY ? true : false; - if (promptPrefix == null) { - promptText = 'You are ChatGPT, a large language model trained by OpenAI.'; - } else { + let promptText = 'You are ChatGPT, a large language model trained by OpenAI.'; + if (promptPrefix) { promptText = promptPrefix; } const maxContextTokens = model === 'gpt-4' ? 8191 : model === 'gpt-4-32k' ? 32767 : 4095; // 1 less than maximum @@ -68,25 +68,18 @@ const askClient = async ({ ...(parentMessageId && conversationId ? { parentMessageId, conversationId } : {}) }; - const enc = encoding_for_model(model); - const text_tokens = enc.encode(text); - const prompt_tokens = enc.encode(promptText); - // console.log("Prompt tokens = ", prompt_tokens.length); - // console.log("Message Tokens = ", text_tokens.length); - + const enc = encoding_for_model(tiktokenModels.has(model) ? model : 'gpt-3.5-turbo'); + const usage = { + prompt_tokens: (enc.encode(promptText)).length + (enc.encode(text)).length, + } + const res = await client.sendMessage(text, { ...options, userId }); - // return res; - // create a new response object that includes the token counts - const newRes = { + usage.completion_tokens = (enc.encode(res.response)).length; + usage.total_tokens = usage.prompt_tokens + usage.completion_tokens; + return { ...res, - usage: { - prompt_tokens: prompt_tokens.length, - completion_tokens: text_tokens.length, - total_tokens: prompt_tokens.length + text_tokens.length - } - }; - - return newRes; + usage, + } }; module.exports = { askClient }; diff --git a/api/server/routes/ask/askBingAI.js b/api/server/routes/ask/askBingAI.js index 25e61bcfc..a4a962be7 100644 --- a/api/server/routes/ask/askBingAI.js +++ b/api/server/routes/ask/askBingAI.js @@ -40,7 +40,7 @@ router.post('/', requireJwtAuth, async (req, res) => { jailbreakConversationId: req.body?.jailbreakConversationId ?? null, systemMessage: req.body?.systemMessage ?? null, context: req.body?.context ?? null, - toneStyle: req.body?.toneStyle ?? 'fast', + toneStyle: req.body?.toneStyle ?? 'creative', token: req.body?.token ?? null }; else @@ -51,7 +51,7 @@ router.post('/', requireJwtAuth, async (req, res) => { conversationSignature: req.body?.conversationSignature ?? null, clientId: req.body?.clientId ?? null, invocationId: req.body?.invocationId ?? null, - toneStyle: req.body?.toneStyle ?? 'fast', + toneStyle: req.body?.toneStyle ?? 'creative', token: req.body?.token ?? null }; @@ -110,7 +110,7 @@ const ask = async ({ try { let lastSavedTimestamp = 0; - const { onProgress: progressCallback, getPartialText } = createOnProgress({ + const { onProgress: progressCallback } = createOnProgress({ onProgress: ({ text }) => { const currentTimestamp = Date.now(); if (currentTimestamp - lastSavedTimestamp > 500) { diff --git a/api/utils/tiktokenModels.js b/api/utils/tiktokenModels.js new file mode 100644 index 000000000..3107cfff0 --- /dev/null +++ b/api/utils/tiktokenModels.js @@ -0,0 +1,40 @@ +const models = [ + 'text-davinci-003', + 'text-davinci-002', + 'text-davinci-001', + 'text-curie-001', + 'text-babbage-001', + 'text-ada-001', + 'davinci', + 'curie', + 'babbage', + 'ada', + 'code-davinci-002', + 'code-davinci-001', + 'code-cushman-002', + 'code-cushman-001', + 'davinci-codex', + 'cushman-codex', + 'text-davinci-edit-001', + 'code-davinci-edit-001', + 'text-embedding-ada-002', + 'text-similarity-davinci-001', + 'text-similarity-curie-001', + 'text-similarity-babbage-001', + 'text-similarity-ada-001', + 'text-search-davinci-doc-001', + 'text-search-curie-doc-001', + 'text-search-babbage-doc-001', + 'text-search-ada-doc-001', + 'code-search-babbage-code-001', + 'code-search-ada-code-001', + 'gpt2', + 'gpt-4', + 'gpt-4-0314', + 'gpt-4-32k', + 'gpt-4-32k-0314', + 'gpt-3.5-turbo', + 'gpt-3.5-turbo-0301' +]; + +module.exports = new Set(models); diff --git a/client/src/components/Endpoints/BingAI/Settings.jsx b/client/src/components/Endpoints/BingAI/Settings.jsx index 18c2b4ce9..26e030a89 100644 --- a/client/src/components/Endpoints/BingAI/Settings.jsx +++ b/client/src/components/Endpoints/BingAI/Settings.jsx @@ -47,7 +47,7 @@ function Settings(props) {
System Message {' '} diff --git a/client/src/components/Messages/HoverButtons.jsx b/client/src/components/Messages/HoverButtons.jsx index d7755e7ab..334976109 100644 --- a/client/src/components/Messages/HoverButtons.jsx +++ b/client/src/components/Messages/HoverButtons.jsx @@ -53,7 +53,7 @@ export default function HoverButtons({ ) : null} {regenerateEnabled ? (