🚀 feat: o1 (#4019)

* feat: o1 default response sender string

* feat: add o1 models to default openai models list, add `no_system_messages` error type; refactor: use error type as localization key

* refactor(MessageEndpointIcon): differentiate OpenAI icon color for o1 models

* refactor(AnthropicClient): use new input/output token keys; add prompt caching for claude-3-opus

* refactor(BaseClient): use new input/output token keys; update typedefs

* feat: initial o1 model handling, including token cost complexity

* EXPERIMENTAL: special handling for o1 model with custom instructions
Danny Avila 2024-09-12 18:15:43 -04:00 committed by GitHub
parent 9a393be012
commit 45b42830a5
12 changed files with 229 additions and 36 deletions

View file

@@ -64,6 +64,12 @@ class AnthropicClient extends BaseClient {
/** Whether or not the model supports Prompt Caching
* @type {boolean} */
this.supportsCacheControl;
/** The key for the usage object's input tokens
* @type {string} */
this.inputTokensKey = 'input_tokens';
/** The key for the usage object's output tokens
* @type {string} */
this.outputTokensKey = 'output_tokens';
}
setOptions(options) {
@@ -200,7 +206,7 @@ class AnthropicClient extends BaseClient {
}
/**
- * Calculates the correct token count for the current message based on the token count map and API usage.
+ * Calculates the correct token count for the current user message based on the token count map and API usage.
* Edge case: If the calculation results in a negative value, it returns the original estimate.
* If revisiting a conversation with a chat history entirely composed of token estimates,
* the cumulative token count going forward should become more accurate as the conversation progresses.
@@ -208,7 +214,7 @@
* @param {Record<string, number>} params.tokenCountMap - A map of message IDs to their token counts.
* @param {string} params.currentMessageId - The ID of the current message to calculate.
* @param {AnthropicStreamUsage} params.usage - The usage object returned by the API.
- * @returns {number} The correct token count for the current message.
+ * @returns {number} The correct token count for the current user message.
*/
calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) {
const originalEstimate = tokenCountMap[currentMessageId] || 0;
@@ -680,7 +686,11 @@ class AnthropicClient extends BaseClient {
*/
checkPromptCacheSupport(modelName) {
const modelMatch = matchModelName(modelName, EModelEndpoint.anthropic);
- if (modelMatch === 'claude-3-5-sonnet' || modelMatch === 'claude-3-haiku') {
+ if (
+ modelMatch === 'claude-3-5-sonnet' ||
+ modelMatch === 'claude-3-haiku' ||
+ modelMatch === 'claude-3-opus'
+ ) {
return true;
}
return false;
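
A minimal sketch of how the expanded check behaves, assuming matchModelName normalizes dated releases to their family name (model strings below are illustrative):

// Hypothetical calls; exact normalization depends on matchModelName's patterns.
client.checkPromptCacheSupport('claude-3-opus-20240229'); // true (newly supported by this commit)
client.checkPromptCacheSupport('claude-3-5-sonnet-20240620'); // true (already supported)
client.checkPromptCacheSupport('claude-2.1'); // false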

View file

@@ -42,6 +42,12 @@ class BaseClient {
this.conversationId;
/** @type {string} */
this.responseMessageId;
/** The key for the usage object's input tokens
* @type {string} */
this.inputTokensKey = 'prompt_tokens';
/** The key for the usage object's output tokens
* @type {string} */
this.outputTokensKey = 'completion_tokens';
}
setOptions() {
@@ -604,8 +610,8 @@
* @type {StreamUsage | null} */
const usage = this.getStreamUsage != null ? this.getStreamUsage() : null;
- if (usage != null && Number(usage.output_tokens) > 0) {
- responseMessage.tokenCount = usage.output_tokens;
+ if (usage != null && Number(usage[this.outputTokensKey]) > 0) {
+ responseMessage.tokenCount = usage[this.outputTokensKey];
completionTokens = responseMessage.tokenCount;
await this.updateUserMessageTokenCount({ usage, tokenCountMap, userMessage, opts });
} else {
@@ -655,7 +661,7 @@
/** @type {boolean} */
const shouldUpdateCount =
this.calculateCurrentTokenCount != null &&
- Number(usage.input_tokens) > 0 &&
+ Number(usage[this.inputTokensKey]) > 0 &&
(this.options.resendFiles ||
(!this.options.resendFiles && !this.options.attachments?.length)) &&
!this.options.promptPrefix;
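
Together with the AnthropicClient hunk above, the new key fields keep provider-specific usage shapes in the subclasses while BaseClient reads usage generically. A minimal sketch of the pattern, assuming an OpenAI-shaped usage object:

// BaseClient defaults: inputTokensKey = 'prompt_tokens', outputTokensKey = 'completion_tokens'
// AnthropicClient overrides: inputTokensKey = 'input_tokens', outputTokensKey = 'output_tokens'
const usage = { prompt_tokens: 260, completion_tokens: 180 };
const promptTokens = Number(usage[this.inputTokensKey]) || 0; // 260
const completionTokens = Number(usage[this.outputTokensKey]) || 0; // 180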

View file

@@ -19,6 +19,7 @@ const {
constructAzureURL,
getModelMaxTokens,
genAzureChatCompletion,
getModelMaxOutputTokens,
} = require('~/utils');
const {
truncateText,
@@ -64,6 +65,9 @@ class OpenAIClient extends BaseClient {
/** @type {string | undefined} - The API Completions URL */
this.completionsUrl;
/** @type {OpenAIUsageMetadata | undefined} */
this.usage;
}
// TODO: PluginsClient calls this 3x, unneeded
@@ -138,7 +142,8 @@
const { model } = this.modelOptions;
- this.isChatCompletion = this.useOpenRouter || !!reverseProxy || model.includes('gpt');
+ this.isChatCompletion =
+ /\bo1\b/i.test(model) || model.includes('gpt') || this.useOpenRouter || !!reverseProxy;
this.isChatGptModel = this.isChatCompletion;
if (
model.includes('text-davinci') ||
@@ -169,7 +174,14 @@
logger.debug('[OpenAIClient] maxContextTokens', this.maxContextTokens);
}
- this.maxResponseTokens = this.modelOptions.max_tokens || 1024;
+ this.maxResponseTokens =
+ this.modelOptions.max_tokens ??
+ getModelMaxOutputTokens(
+ model,
+ this.options.endpointType ?? this.options.endpoint,
+ this.options.endpointTokenConfig,
+ ) ??
+ 1024;
this.maxPromptTokens =
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
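
The new fallback chain resolves the response-token budget as: explicit max_tokens from the request, else the per-model output cap, else the old default of 1024. A worked example against the modelMaxOutputs values added later in this commit (openAI endpoint assumed):

// o1-preview with no max_tokens set:
//   undefined ?? getModelMaxOutputTokens('o1-preview', 'openAI') ?? 1024 → 32268
// gpt-4o with no max_tokens set and no entry in modelMaxOutputs:
//   getModelMaxOutputTokens returns the map's system_default → 1024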
@@ -533,7 +545,8 @@
promptPrefix = this.augmentedPrompt + promptPrefix;
}
- if (promptPrefix) {
+ const isO1Model = /\bo1\b/i.test(this.modelOptions.model);
+ if (promptPrefix && !isO1Model) {
promptPrefix = `Instructions:\n${promptPrefix.trim()}`;
instructions = {
role: 'system',
@@ -561,6 +574,16 @@
messages,
};
/** EXPERIMENTAL */
if (promptPrefix && isO1Model) {
const lastUserMessageIndex = payload.findLastIndex((message) => message.role === 'user');
if (lastUserMessageIndex !== -1) {
payload[
lastUserMessageIndex
].content = `${promptPrefix}\n${payload[lastUserMessageIndex].content}`;
}
}
if (tokenCountMap) {
tokenCountMap.instructions = instructions?.tokenCount;
result.tokenCountMap = tokenCountMap;
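
Because o1 models rejected the system role at the time of this commit, the experimental block above splices custom instructions into the most recent user turn instead. A before/after sketch with illustrative content:

// promptPrefix = 'Answer in French.'
// before: [{ role: 'user', content: 'Hi' },
//          { role: 'assistant', content: 'Bonjour' },
//          { role: 'user', content: 'How are you?' }]
// after: the last user message becomes
//        { role: 'user', content: 'Answer in French.\nHow are you?' }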
@@ -885,6 +908,56 @@ ${convo}
return title;
}
/**
* Get stream usage as returned by this client's API response.
* @returns {OpenAIUsageMetadata} The stream usage object.
*/
getStreamUsage() {
if (
typeof this.usage === 'object' &&
typeof this.usage.completion_tokens_details === 'object'
) {
const outputTokens = Math.abs(
this.usage.completion_tokens_details.reasoning_tokens - this.usage[this.outputTokensKey],
);
return {
...this.usage.completion_tokens_details,
[this.inputTokensKey]: this.usage[this.inputTokensKey],
[this.outputTokensKey]: outputTokens,
};
}
return this.usage;
}
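
The subtraction separates visible output from hidden reasoning, since the API's completion_tokens figure includes both. A worked example with assumed numbers:

// this.usage = { prompt_tokens: 260, completion_tokens: 500,
//                completion_tokens_details: { reasoning_tokens: 320 } }
// outputTokens = |320 - 500| = 180 visible completion tokens
// returned: { reasoning_tokens: 320, prompt_tokens: 260, completion_tokens: 180 }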
/**
* Calculates the correct token count for the current user message based on the token count map and API usage.
* Edge case: If the calculation results in a negative value, it returns the original estimate.
* If revisiting a conversation with a chat history entirely composed of token estimates,
* the cumulative token count going forward should become more accurate as the conversation progresses.
* @param {Object} params - The parameters for the calculation.
* @param {Record<string, number>} params.tokenCountMap - A map of message IDs to their token counts.
* @param {string} params.currentMessageId - The ID of the current message to calculate.
* @param {OpenAIUsageMetadata} params.usage - The usage object returned by the API.
* @returns {number} The correct token count for the current user message.
*/
calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) {
const originalEstimate = tokenCountMap[currentMessageId] || 0;
if (!usage || typeof usage[this.inputTokensKey] !== 'number') {
return originalEstimate;
}
tokenCountMap[currentMessageId] = 0;
const totalTokensFromMap = Object.values(tokenCountMap).reduce((sum, count) => {
const numCount = Number(count);
return sum + (isNaN(numCount) ? 0 : numCount);
}, 0);
const totalInputTokens = usage[this.inputTokensKey] ?? 0;
const currentMessageTokens = totalInputTokens - totalTokensFromMap;
return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate;
}
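
In effect, the current message is charged whatever portion of the API-reported prompt total is not already attributed to prior history. A worked example with assumed counts:

// tokenCountMap = { msgA: 120, msgB: 80, current: 50 } (50 is the local estimate)
// usage.prompt_tokens = 260
// after zeroing current: totalTokensFromMap = 120 + 80 = 200
// currentMessageTokens = 260 - 200 = 60, which replaces the 50-token estimate
// a negative result would instead fall back to the original estimate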
async summarizeMessages({ messagesToRefine, remainingContextTokens }) {
logger.debug('[OpenAIClient] Summarizing messages...');
let context = messagesToRefine;
@@ -1000,7 +1073,16 @@ ${convo}
}
}
- async recordTokenUsage({ promptTokens, completionTokens, context = 'message' }) {
+ /**
+ * @param {object} params
+ * @param {number} params.promptTokens
+ * @param {number} params.completionTokens
+ * @param {OpenAIUsageMetadata} [params.usage]
+ * @param {string} [params.model]
+ * @param {string} [params.context='message']
+ * @returns {Promise<void>}
+ */
+ async recordTokenUsage({ promptTokens, completionTokens, usage, context = 'message' }) {
await spendTokens(
{
context,
@@ -1011,6 +1093,19 @@ ${convo}
},
{ promptTokens, completionTokens },
);
if (typeof usage === 'object' && typeof usage.reasoning_tokens === 'number') {
await spendTokens(
{
context: 'reasoning',
model: this.modelOptions.model,
conversationId: this.conversationId,
user: this.user ?? this.options.req.user?.id,
endpointTokenConfig: this.options.endpointTokenConfig,
},
{ completionTokens: usage.reasoning_tokens },
);
}
}
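
With the optional usage argument, one response can now produce two transactions: the normal message spend plus a separate 'reasoning' spend. An illustrative call, reusing the numbers above:

// Records 260 prompt / 180 completion tokens under context 'message',
// then 320 completion tokens under context 'reasoning'.
await client.recordTokenUsage({
  promptTokens: 260,
  completionTokens: 180,
  usage: { reasoning_tokens: 320 },
});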
getTokenCountForResponse(response) {
@@ -1191,6 +1286,10 @@ ${convo}
/** @type {(value: void | PromiseLike<void>) => void} */
let streamResolve;
if (modelOptions.stream && /\bo1\b/i.test(modelOptions.model)) {
delete modelOptions.stream;
}
if (modelOptions.stream) {
streamPromise = new Promise((resolve) => {
streamResolve = resolve;
@@ -1269,6 +1368,8 @@ ${convo}
}
const { choices } = chatCompletion;
this.usage = chatCompletion.usage;
if (!Array.isArray(choices) || choices.length === 0) {
logger.warn('[OpenAIClient] Chat completion response has no choices');
return intermediateReply.join('');

View file

@@ -37,6 +37,9 @@ const tokenValues = Object.assign(
'4k': { prompt: 1.5, completion: 2 },
'16k': { prompt: 3, completion: 4 },
'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 },
'o1-preview': { prompt: 15, completion: 60 },
'o1-mini': { prompt: 3, completion: 12 },
o1: { prompt: 15, completion: 60 },
'gpt-4o-2024-08-06': { prompt: 2.5, completion: 10 },
'gpt-4o-mini': { prompt: 0.15, completion: 0.6 },
'gpt-4o': { prompt: 5, completion: 15 },
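
These rates track OpenAI's published o1 pricing, which suggests the units are USD per million tokens (an assumption, but one consistent with the neighboring entries). A sample cost calculation:

// o1-preview: prompt 15, completion 60 (assumed USD per 1M tokens)
// a request with 10,000 prompt + 2,000 completion tokens:
// (10000 / 1e6) * 15 + (2000 / 1e6) * 60 = 0.15 + 0.12 = $0.27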
@@ -95,6 +98,12 @@ const getValueKey = (model, endpoint) => {
return 'gpt-3.5-turbo-1106';
} else if (modelName.includes('gpt-3.5')) {
return '4k';
} else if (modelName.includes('o1-preview')) {
return 'o1-preview';
} else if (modelName.includes('o1-mini')) {
return 'o1-mini';
} else if (modelName.includes('o1')) {
return 'o1';
} else if (modelName.includes('gpt-4o-2024-08-06')) {
return 'gpt-4o-2024-08-06';
} else if (modelName.includes('gpt-4o-mini')) {
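
Order matters in this chain: 'o1-preview' and 'o1-mini' are tested before the bare 'o1' substring, so dated variants resolve to their own rates. For example:

// getValueKey('o1-preview-2024-09-12') → 'o1-preview'
// getValueKey('o1-mini-2024-09-12') → 'o1-mini'
// getValueKey('o1') → 'o1'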

View file

@@ -173,6 +173,10 @@ const handleAbortError = async (res, req, error, data) => {
errorText = `{"type":"${ErrorTypes.INVALID_REQUEST}"}`;
}
if (error?.message?.includes('does not support \'system\'')) {
errorText = `{"type":"${ErrorTypes.NO_SYSTEM_MESSAGES}"}`;
}
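
When the upstream API rejects a system message (as o1 did at launch), the middleware now emits a typed error payload that the client maps to a localized string; see the client-side hunks below. Roughly:

// API error message contains: does not support 'system'
// → errorText = '{"type":"no_system_messages"}'
// → client: errorMessages[ErrorTypes.NO_SYSTEM_MESSAGES] → 'com_error_no_system_messages'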
const respondWithError = async (partialText) => {
let options = {
sender,

View file

@@ -1443,7 +1443,19 @@
*/
/**
- * @typedef {AnthropicStreamUsage} StreamUsage - Stream usage for all providers (currently only Anthropic)
+ * @exports OpenAIUsageMetadata
+ * @typedef {Object} OpenAIUsageMetadata - Usage statistics related to the run. This value will be `null` if the run is not in a terminal state (i.e. `in_progress`, `queued`, etc.).
+ * @property {number} [usage.completion_tokens] - Number of completion tokens used over the course of the run.
+ * @property {number} [usage.prompt_tokens] - Number of prompt tokens used over the course of the run.
+ * @property {number} [usage.total_tokens] - Total number of tokens used (prompt + completion).
+ * @property {number} [usage.reasoning_tokens] - Total number of tokens used for reasoning (OpenAI o1 models).
+ * @property {Object} [usage.completion_tokens_details] - Further details on the completion tokens used (OpenAI o1 models).
+ * @property {number} [usage.completion_tokens_details.reasoning_tokens] - Total number of tokens used for reasoning (OpenAI o1 models).
+ * @memberof typedefs
*/
+ /**
+ * @typedef {AnthropicStreamUsage | OpenAIUsageMetadata | UsageMetadata} StreamUsage - Stream usage for all providers (currently only Anthropic, OpenAI, LangChain)
+ */
/* Native app/client methods */

View file

@@ -2,6 +2,9 @@ const z = require('zod');
const { EModelEndpoint } = require('librechat-data-provider');
const openAIModels = {
o1: 127500, // -500 from max
'o1-mini': 127500, // -500 from max
'o1-preview': 127500, // -500 from max
'gpt-4': 8187, // -5 from max
'gpt-4-0613': 8187, // -5 from max
'gpt-4-32k': 32758, // -10 from max
@@ -113,6 +116,19 @@ const maxTokensMap = {
[EModelEndpoint.bedrock]: bedrockModels,
};
const modelMaxOutputs = {
o1: 32268, // -500 from max: 32,768
'o1-mini': 65136, // -400 from max: 65,536
'o1-preview': 32268, // -500 from max: 32,768
system_default: 1024,
};
const maxOutputTokensMap = {
[EModelEndpoint.azureOpenAI]: modelMaxOutputs,
[EModelEndpoint.openAI]: modelMaxOutputs,
[EModelEndpoint.custom]: modelMaxOutputs,
};
/**
* Finds the first matching pattern in the tokens map.
* @param {string} modelName
@@ -132,27 +148,15 @@ function findMatchingPattern(modelName, tokensMap) {
}
/**
- * Retrieves the maximum tokens for a given model name. If the exact model name isn't found,
- * it searches for partial matches within the model name, checking keys in reverse order.
+ * Retrieves a token value for a given model name from a tokens map.
*
* @param {string} modelName - The name of the model to look up.
- * @param {string} endpoint - The endpoint (default is 'openAI').
- * @param {EndpointTokenConfig} [endpointTokenConfig] - Token Config for current endpoint to use for max tokens lookup
- * @returns {number|undefined} The maximum tokens for the given model or undefined if no match is found.
- *
- * @example
- * getModelMaxTokens('gpt-4-32k-0613'); // Returns 32767
- * getModelMaxTokens('gpt-4-32k-unknown'); // Returns 32767
- * getModelMaxTokens('unknown-model'); // Returns undefined
+ * @param {EndpointTokenConfig | Record<string, number>} tokensMap - The map of model names to token values.
+ * @param {string} [key='context'] - The key to look up in the tokens map.
+ * @returns {number|undefined} The token value for the given model or undefined if no match is found.
*/
- function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI, endpointTokenConfig) {
- if (typeof modelName !== 'string') {
- return undefined;
- }
- /** @type {EndpointTokenConfig | Record<string, number>} */
- const tokensMap = endpointTokenConfig ?? maxTokensMap[endpoint];
- if (!tokensMap) {
+ function getModelTokenValue(modelName, tokensMap, key = 'context') {
+ if (typeof modelName !== 'string' || !tokensMap) {
return undefined;
}
@@ -168,10 +172,36 @@ function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI, endpointTokenConfig) {
if (matchedPattern) {
const result = tokensMap[matchedPattern];
- return result?.context ?? result;
+ return result?.[key] ?? result ?? tokensMap.system_default;
}
- return undefined;
+ return tokensMap.system_default;
}
/**
* Retrieves the maximum tokens for a given model name.
*
* @param {string} modelName - The name of the model to look up.
* @param {string} endpoint - The endpoint (default is 'openAI').
* @param {EndpointTokenConfig} [endpointTokenConfig] - Token Config for current endpoint to use for max tokens lookup
* @returns {number|undefined} The maximum tokens for the given model or undefined if no match is found.
*/
function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI, endpointTokenConfig) {
const tokensMap = endpointTokenConfig ?? maxTokensMap[endpoint];
return getModelTokenValue(modelName, tokensMap);
}
/**
* Retrieves the maximum output tokens for a given model name.
*
* @param {string} modelName - The name of the model to look up.
* @param {string} endpoint - The endpoint (default is 'openAI').
* @param {EndpointTokenConfig} [endpointTokenConfig] - Token Config for current endpoint to use for max tokens lookup
* @returns {number|undefined} The maximum output tokens for the given model or undefined if no match is found.
*/
function getModelMaxOutputTokens(modelName, endpoint = EModelEndpoint.openAI, endpointTokenConfig) {
const tokensMap = endpointTokenConfig ?? maxOutputTokensMap[endpoint];
return getModelTokenValue(modelName, tokensMap, 'output');
}
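
Both public helpers now delegate to getModelTokenValue; the output variant reads the 'output' key and inherits the system_default fallback. Illustrative lookups against the maps above:

// getModelMaxTokens('o1-preview') → 127500 (context window)
// getModelMaxOutputTokens('o1-mini') → 65136
// getModelMaxOutputTokens('gpt-4o') → 1024 (no pattern match, falls back to system_default)
// getModelMaxOutputTokens('claude-3-opus', 'anthropic') → undefined (no output map for that endpoint)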
/**
@@ -298,7 +328,8 @@ module.exports = {
maxTokensMap,
inputSchema,
modelSchema,
- getModelMaxTokens,
matchModelName,
processModelData,
+ getModelMaxTokens,
+ getModelMaxOutputTokens,
};

View file

@@ -17,6 +17,14 @@ import {
import { IconProps } from '~/common';
import { cn } from '~/utils';
function getOpenAIColor(_model: string | null | undefined) {
const model = _model?.toLowerCase() ?? '';
if (model && /\bo1\b/i.test(model)) {
return '#000000';
}
return model.includes('gpt-4') ? '#AB68FF' : '#19C37D';
}
function getGoogleIcon(model: string | null | undefined, size: number) {
if (model?.toLowerCase().includes('code') === true) {
return <CodeyIcon size={size * 0.75} />;
@@ -119,8 +127,7 @@ const MessageEndpointIcon: React.FC<IconProps> = (props) => {
},
[EModelEndpoint.openAI]: {
icon: <GPTIcon size={size * 0.5555555555555556} />,
- bg:
- typeof model === 'string' && model.toLowerCase().includes('gpt-4') ? '#AB68FF' : '#19C37D',
+ bg: getOpenAIColor(model),
name: 'ChatGPT',
},
[EModelEndpoint.gptPlugins]: {

View file

@@ -42,7 +42,8 @@ const errorMessages = {
[ErrorTypes.NO_USER_KEY]: 'com_error_no_user_key',
[ErrorTypes.INVALID_USER_KEY]: 'com_error_invalid_user_key',
[ErrorTypes.NO_BASE_URL]: 'com_error_no_base_url',
- [ErrorTypes.INVALID_REQUEST]: 'com_error_invalid_request',
+ [ErrorTypes.INVALID_REQUEST]: `com_error_${ErrorTypes.INVALID_REQUEST}`,
+ [ErrorTypes.NO_SYSTEM_MESSAGES]: `com_error_${ErrorTypes.NO_SYSTEM_MESSAGES}`,
[ErrorTypes.EXPIRED_USER_KEY]: (json: TExpiredKey, localize: LocalizeFunction) => {
const { expiredAt, endpoint } = json;
return localize('com_error_expired_user_key', endpoint, expiredAt);

View file

@@ -24,8 +24,10 @@ export default {
com_error_no_base_url: 'No base URL found. Please provide one and try again.',
com_warning_resubmit_unsupported:
'Resubmitting the AI message is not supported for this endpoint.',
- com_error_invalid_request:
+ com_error_invalid_request_error:
'The AI service rejected the request due to an error. This could be caused by an invalid API key or an improperly formatted request.',
+ com_error_no_system_messages:
+ 'The selected AI service or model does not support system messages. Try using prompts instead of custom instructions.',
com_error_invalid_user_key: 'Invalid key provided. Please provide a valid key and try again.',
com_error_expired_user_key:
'Provided key for {0} expired at {1}. Please provide a new key and try again.',

View file

@@ -11,6 +11,10 @@ export const defaultSocialLogins = ['google', 'facebook', 'openid', 'github', 'discord'];
export const defaultRetrievalModels = [
'gpt-4o',
'o1-preview-2024-09-12',
'o1-preview',
'o1-mini-2024-09-12',
'o1-mini',
'chatgpt-4o-latest',
'gpt-4o-2024-05-13',
'gpt-4o-2024-08-06',
@@ -951,6 +955,10 @@ export enum ErrorTypes {
* Invalid request error, API rejected request
*/
INVALID_REQUEST = 'invalid_request_error',
/**
* Endpoint or model does not support system messages
*/
NO_SYSTEM_MESSAGES = 'no_system_messages',
}
/**

View file

@@ -240,6 +240,8 @@ export const getResponseSender = (endpointOption: t.TEndpointOption): string =>
) {
if (chatGptLabel) {
return chatGptLabel;
} else if (model && /\bo1\b/i.test(model)) {
return 'o1';
} else if (model && model.includes('gpt-3')) {
return 'GPT-3.5';
} else if (model && model.includes('gpt-4o')) {
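
With the new branch, any o1-family model without a custom label is reported simply as 'o1'. Illustrative calls (option shapes assumed):

// getResponseSender({ endpoint: 'openAI', model: 'o1-mini-2024-09-12' }) → 'o1'
// getResponseSender({ endpoint: 'openAI', model: 'o1-preview', chatGptLabel: 'My Bot' }) → 'My Bot'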