From 52f146dd97354791d6d144a6098ae5e39206fcae Mon Sep 17 00:00:00 2001
From: Danny Avila
Date: Thu, 17 Apr 2025 00:40:26 -0400
Subject: [PATCH] =?UTF-8?q?=F0=9F=A4=96=20feat:=20Support=20`o4-mini`=20an?=
 =?UTF-8?q?d=20`o3`=20Models=20(#6928)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* feat: Add support for new OpenAI models (o4-mini, o3) and update related logic

* 🔧 fix: Rename 'resubmitFiles' to 'isResubmission' for consistency across types and hooks

* 🔧 fix: Replace hardcoded 'pending_req' with CacheKeys.PENDING_REQ for consistency in cache handling

* 🔧 fix: Update cache handling to use Time.ONE_MINUTE instead of hardcoded TTL and streamline imports

* 🔧 fix: Enhance message handling logic to correctly identify parent messages and streamline imports in useSSE
---
 api/app/clients/OpenAIClient.js               |  5 ++-
 api/cache/clearPendingReq.js                  |  9 ++---
 api/cache/getLogStores.js                     |  4 +--
 api/models/tx.js                              |  2 ++
 api/models/tx.spec.js                         |  9 +++++
 api/server/controllers/agents/client.js       |  4 +--
 api/server/middleware/concurrentLimiter.js    |  4 +--
 api/utils/tokens.js                           |  2 ++
 api/utils/tokens.spec.js                      |  9 +++++
 client/src/common/types.ts                    |  2 +-
 client/src/components/Chat/Input/ChatForm.tsx |  1 +
 .../Chat/Messages/Content/EditMessage.tsx     |  2 +-
 .../Endpoints/MessageEndpointIcon.tsx         |  2 +-
 client/src/hooks/Chat/useChatFunctions.ts     |  5 +--
 client/src/hooks/Input/useAutoSave.ts         |  8 +++--
 client/src/hooks/SSE/useEventHandlers.ts      | 11 +++++-
 client/src/hooks/SSE/useSSE.ts                | 36 ++-----------------
 packages/data-provider/src/config.ts          |  6 ++++
 packages/data-provider/src/types.ts           |  1 +
 19 files changed, 69 insertions(+), 53 deletions(-)

diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js
index 91c99e438..dd437f0b9 100644
--- a/api/app/clients/OpenAIClient.js
+++ b/api/app/clients/OpenAIClient.js
@@ -108,7 +108,7 @@ class OpenAIClient extends BaseClient {
       this.checkVisionRequest(this.options.attachments);
     }
 
-    const omniPattern = /\b(o1|o3)\b/i;
+    const omniPattern = /\b(o\d)\b/i;
     this.isOmni = omniPattern.test(this.modelOptions.model);
 
     const { OPENAI_FORCE_PROMPT } = process.env ?? {};
@@ -1237,6 +1237,9 @@ ${convo}
       modelOptions.max_completion_tokens = modelOptions.max_tokens;
       delete modelOptions.max_tokens;
     }
+    if (this.isOmni === true && modelOptions.temperature != null) {
+      delete modelOptions.temperature;
+    }
 
     if (process.env.OPENAI_ORGANIZATION) {
       opts.organization = process.env.OPENAI_ORGANIZATION;
diff --git a/api/cache/clearPendingReq.js b/api/cache/clearPendingReq.js
index 122638d7f..54db8e969 100644
--- a/api/cache/clearPendingReq.js
+++ b/api/cache/clearPendingReq.js
@@ -1,7 +1,8 @@
+const { Time, CacheKeys } = require('librechat-data-provider');
+const { isEnabled } = require('~/server/utils');
 const getLogStores = require('./getLogStores');
-const { isEnabled } = require('../server/utils');
+
 const { USE_REDIS, LIMIT_CONCURRENT_MESSAGES } = process.env ?? {};
-const ttl = 1000 * 60 * 1;
 /**
  * Clear or decrement pending requests from the cache.
  *
@@ -28,7 +29,7 @@ const clearPendingReq = async ({ userId, cache: _cache }) => {
     return;
   }
 
-  const namespace = 'pending_req';
+  const namespace = CacheKeys.PENDING_REQ;
   const cache = _cache ?? getLogStores(namespace);
 
   if (!cache) {
@@ -39,7 +40,7 @@ const clearPendingReq = async ({ userId, cache: _cache }) => {
   const currentReq = +((await cache.get(key)) ?? 0);
 
   if (currentReq && currentReq >= 1) {
-    await cache.set(key, currentReq - 1, ttl);
+    await cache.set(key, currentReq - 1, Time.ONE_MINUTE);
   } else {
     await cache.delete(key);
   }
diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js
index e652cfdee..612638b97 100644
--- a/api/cache/getLogStores.js
+++ b/api/cache/getLogStores.js
@@ -19,7 +19,7 @@ const createViolationInstance = (namespace) => {
 // Serve cache from memory so no need to clear it on startup/exit
 const pending_req = isRedisEnabled
   ? new Keyv({ store: keyvRedis })
-  : new Keyv({ namespace: 'pending_req' });
+  : new Keyv({ namespace: CacheKeys.PENDING_REQ });
 
 const config = isRedisEnabled
   ? new Keyv({ store: keyvRedis })
@@ -64,7 +64,7 @@ const abortKeys = isRedisEnabled
 const namespaces = {
   [CacheKeys.ROLES]: roles,
   [CacheKeys.CONFIG_STORE]: config,
-  pending_req,
+  [CacheKeys.PENDING_REQ]: pending_req,
   [ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }),
   [CacheKeys.ENCODED_DOMAINS]: new Keyv({
     store: keyvMongo,
diff --git a/api/models/tx.js b/api/models/tx.js
index c141cd0d2..67d954a9a 100644
--- a/api/models/tx.js
+++ b/api/models/tx.js
@@ -76,7 +76,9 @@ const tokenValues = Object.assign(
    '4k': { prompt: 1.5, completion: 2 },
    '16k': { prompt: 3, completion: 4 },
    'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 },
+    'o4-mini': { prompt: 1.1, completion: 4.4 },
    'o3-mini': { prompt: 1.1, completion: 4.4 },
+    o3: { prompt: 10, completion: 40 },
    'o1-mini': { prompt: 1.1, completion: 4.4 },
    'o1-preview': { prompt: 15, completion: 60 },
    o1: { prompt: 15, completion: 60 },
diff --git a/api/models/tx.spec.js b/api/models/tx.spec.js
index c2b4326ab..5e1681072 100644
--- a/api/models/tx.spec.js
+++ b/api/models/tx.spec.js
@@ -165,6 +165,15 @@ describe('getMultiplier', () => {
     );
   });
 
+  it('should return correct multipliers for o4-mini and o3', () => {
+    ['o4-mini', 'o3'].forEach((model) => {
+      const prompt = getMultiplier({ model, tokenType: 'prompt' });
+      const completion = getMultiplier({ model, tokenType: 'completion' });
+      expect(prompt).toBe(tokenValues[model].prompt);
+      expect(completion).toBe(tokenValues[model].completion);
+    });
+  });
+
   it('should return defaultRate if tokenType is provided but not found in tokenValues', () => {
     expect(getMultiplier({ valueKey: '8k', tokenType: 'unknownType' })).toBe(defaultRate);
   });
diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js
index 09290b59f..b462a8a0c 100644
--- a/api/server/controllers/agents/client.js
+++ b/api/server/controllers/agents/client.js
@@ -58,7 +58,7 @@ const payloadParser = ({ req, agent, endpoint }) => {
 
 const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deepseek]);
 
-const noSystemModelRegex = [/\bo1\b/gi];
+const noSystemModelRegex = [/\b(o\d)\b/gi];
 
 // const { processMemory, memoryInstructions } = require('~/server/services/Endpoints/agents/memory');
 // const { getFormattedMemories } = require('~/models/Memory');
@@ -975,7 +975,7 @@ class AgentClient extends BaseClient {
           })
         )?.llmConfig ?? clientOptions;
     }
-    if (/\b(o1|o3)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
+    if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
       delete clientOptions.maxTokens;
     }
     try {
diff --git a/api/server/middleware/concurrentLimiter.js b/api/server/middleware/concurrentLimiter.js
index 21b3a8690..73de65dd2 100644
--- a/api/server/middleware/concurrentLimiter.js
+++ b/api/server/middleware/concurrentLimiter.js
@@ -1,4 +1,4 @@
-const { Time } = require('librechat-data-provider');
+const { Time, CacheKeys } = require('librechat-data-provider');
 const clearPendingReq = require('~/cache/clearPendingReq');
 const { logViolation, getLogStores } = require('~/cache');
 const { isEnabled } = require('~/server/utils');
@@ -25,7 +25,7 @@ const {
  * @throws {Error} Throws an error if the user exceeds the concurrent request limit.
  */
 const concurrentLimiter = async (req, res, next) => {
-  const namespace = 'pending_req';
+  const namespace = CacheKeys.PENDING_REQ;
   const cache = getLogStores(namespace);
   if (!cache) {
     return next();
diff --git a/api/utils/tokens.js b/api/utils/tokens.js
index 0e505b00d..6faa097b7 100644
--- a/api/utils/tokens.js
+++ b/api/utils/tokens.js
@@ -2,7 +2,9 @@ const z = require('zod');
 const { EModelEndpoint } = require('librechat-data-provider');
 
 const openAIModels = {
+  'o4-mini': 200000,
   'o3-mini': 195000, // -5000 from max
+  o3: 200000,
   o1: 195000, // -5000 from max
   'o1-mini': 127500, // -500 from max
   'o1-preview': 127500, // -500 from max
diff --git a/api/utils/tokens.spec.js b/api/utils/tokens.spec.js
index 6bb4967cb..57a9f72e8 100644
--- a/api/utils/tokens.spec.js
+++ b/api/utils/tokens.spec.js
@@ -340,6 +340,15 @@ describe('getModelMaxTokens', () => {
     expect(getModelMaxTokens('o1-preview-something')).toBe(o1PreviewTokens);
     expect(getModelMaxTokens('openai/o1-preview-something')).toBe(o1PreviewTokens);
   });
+
+  test('should return correct max context tokens for o4-mini and o3', () => {
+    const o4MiniTokens = maxTokensMap[EModelEndpoint.openAI]['o4-mini'];
+    const o3Tokens = maxTokensMap[EModelEndpoint.openAI]['o3'];
+    expect(getModelMaxTokens('o4-mini')).toBe(o4MiniTokens);
+    expect(getModelMaxTokens('openai/o4-mini')).toBe(o4MiniTokens);
+    expect(getModelMaxTokens('o3')).toBe(o3Tokens);
+    expect(getModelMaxTokens('openai/o3')).toBe(o3Tokens);
+  });
 });
 
 describe('matchModelName', () => {
diff --git a/client/src/common/types.ts b/client/src/common/types.ts
index 02eeccb62..1c0191477 100644
--- a/client/src/common/types.ts
+++ b/client/src/common/types.ts
@@ -306,7 +306,7 @@ export type TAskProps = {
 export type TOptions = {
   editedMessageId?: string | null;
   editedText?: string | null;
-  resubmitFiles?: boolean;
+  isResubmission?: boolean;
   isRegenerate?: boolean;
   isContinued?: boolean;
   isEdited?: boolean;
diff --git a/client/src/components/Chat/Input/ChatForm.tsx b/client/src/components/Chat/Input/ChatForm.tsx
index 1aa72a0db..3e4e4f698 100644
--- a/client/src/components/Chat/Input/ChatForm.tsx
+++ b/client/src/components/Chat/Input/ChatForm.tsx
@@ -121,6 +121,7 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
     setFiles,
     textAreaRef,
     conversationId,
+    isSubmitting: isSubmitting || isSubmittingAdded,
   });
 
   const { submitMessage, submitPrompt } = useSubmitMessage();
diff --git a/client/src/components/Chat/Messages/Content/EditMessage.tsx b/client/src/components/Chat/Messages/Content/EditMessage.tsx
index aa2eca2d1..b422ede8d 100644
--- a/client/src/components/Chat/Messages/Content/EditMessage.tsx
+++ b/client/src/components/Chat/Messages/Content/EditMessage.tsx
@@ -60,7 +60,7 @@ const EditMessage = ({
         conversationId,
       },
       {
-        resubmitFiles: true,
+        isResubmission: true,
       },
     );
 
diff --git a/client/src/components/Endpoints/MessageEndpointIcon.tsx b/client/src/components/Endpoints/MessageEndpointIcon.tsx
index be30ccc67..5005641e0 100644
--- a/client/src/components/Endpoints/MessageEndpointIcon.tsx
+++ b/client/src/components/Endpoints/MessageEndpointIcon.tsx
@@ -25,7 +25,7 @@ type EndpointIcon = {
 
 function getOpenAIColor(_model: string | null | undefined) {
   const model = _model?.toLowerCase() ?? '';
-  if (model && /\b(o1|o3)\b/i.test(model)) {
+  if (model && /\b(o\d)\b/i.test(model)) {
     return '#000000';
   }
   return model.includes('gpt-4') ? '#AB68FF' : '#19C37D';
diff --git a/client/src/hooks/Chat/useChatFunctions.ts b/client/src/hooks/Chat/useChatFunctions.ts
index c8a001856..dca77edf5 100644
--- a/client/src/hooks/Chat/useChatFunctions.ts
+++ b/client/src/hooks/Chat/useChatFunctions.ts
@@ -90,7 +90,7 @@ export default function useChatFunctions({
     {
       editedText = null,
       editedMessageId = null,
-      resubmitFiles = false,
+      isResubmission = false,
       isRegenerate = false,
       isContinued = false,
       isEdited = false,
@@ -202,7 +202,7 @@ export default function useChatFunctions({
     };
 
     const reuseFiles =
-      (isRegenerate || resubmitFiles) && parentMessage?.files && parentMessage.files.length > 0;
+      (isRegenerate || isResubmission) && parentMessage?.files && parentMessage.files.length > 0;
     if (setFiles && reuseFiles === true) {
       currentMsg.files = parentMessage.files;
       setFiles(new Map());
@@ -298,6 +298,7 @@ export default function useChatFunctions({
       isEdited: isEditOrContinue,
       isContinued,
       isRegenerate,
+      isResubmission,
       initialResponse,
       isTemporary,
       ephemeralAgent,
diff --git a/client/src/hooks/Input/useAutoSave.ts b/client/src/hooks/Input/useAutoSave.ts
index 93232bc91..642087f44 100644
--- a/client/src/hooks/Input/useAutoSave.ts
+++ b/client/src/hooks/Input/useAutoSave.ts
@@ -1,7 +1,8 @@
 import debounce from 'lodash/debounce';
 import { SetterOrUpdater, useRecoilValue } from 'recoil';
 import { useState, useEffect, useMemo, useCallback } from 'react';
-import { LocalStorageKeys, TFile } from 'librechat-data-provider';
+import { LocalStorageKeys, Constants } from 'librechat-data-provider';
+import type { TFile } from 'librechat-data-provider';
 import type { ExtendedFile } from '~/common';
 import { useChatFormContext } from '~/Providers';
 import { useGetFiles } from '~/data-provider';
@@ -34,11 +35,13 @@ const decodeBase64 = (base64String: string): string => {
 };
 
 export const useAutoSave = ({
-  conversationId,
+  isSubmitting,
+  conversationId: _conversationId,
   textAreaRef,
   setFiles,
   files,
 }: {
+  isSubmitting?: boolean;
   conversationId?: string | null;
   textAreaRef?: React.RefObject<HTMLTextAreaElement>;
   files: Map<string, ExtendedFile>;
@@ -47,6 +50,7 @@ export const useAutoSave = ({
   // setting for auto-save
   const { setValue } = useChatFormContext();
   const saveDrafts = useRecoilValue(store.saveDrafts);
+  const conversationId = isSubmitting ? Constants.PENDING_CONVO : _conversationId;
 
   const [currentConversationId, setCurrentConversationId] = useState<string | null>(null);
   const fileIds = useMemo(() => Array.from(files.keys()), [files]);
diff --git a/client/src/hooks/SSE/useEventHandlers.ts b/client/src/hooks/SSE/useEventHandlers.ts
index cddf17067..260973cd9 100644
--- a/client/src/hooks/SSE/useEventHandlers.ts
+++ b/client/src/hooks/SSE/useEventHandlers.ts
@@ -613,8 +613,17 @@ export default function useEventHandlers({
       messages?.[messages.length - 1] != null &&
       messages[messages.length - 2] != null
     ) {
-      const requestMessage = messages[messages.length - 2];
+      let requestMessage = messages[messages.length - 2];
       const responseMessage = messages[messages.length - 1];
+      if (requestMessage.messageId !== responseMessage.parentMessageId) {
+        // the request message is the parent of response, which we search for backwards
+        for (let i = messages.length - 3; i >= 0; i--) {
+          if (messages[i].messageId === responseMessage.parentMessageId) {
+            requestMessage = messages[i];
+            break;
+          }
+        }
+      }
       finalHandler(
         {
           conversation: {
diff --git a/client/src/hooks/SSE/useSSE.ts b/client/src/hooks/SSE/useSSE.ts
index 404247941..7faad07e3 100644
--- a/client/src/hooks/SSE/useSSE.ts
+++ b/client/src/hooks/SSE/useSSE.ts
@@ -2,7 +2,6 @@ import { useEffect, useState } from 'react';
 import { v4 } from 'uuid';
 import { SSE } from 'sse.js';
 import { useSetRecoilState } from 'recoil';
-import { useQueryClient } from '@tanstack/react-query';
 import {
   request,
   Constants,
@@ -13,18 +12,12 @@ import {
   removeNullishValues,
   isAssistantsEndpoint,
 } from 'librechat-data-provider';
-import type {
-  EventSubmission,
-  TConversation,
-  TMessage,
-  TPayload,
-  TSubmission,
-} from 'librechat-data-provider';
+import type { TMessage, TPayload, TSubmission, EventSubmission } from 'librechat-data-provider';
 import type { EventHandlerParams } from './useEventHandlers';
 import type { TResData } from '~/common';
 import { useGenTitleMutation, useGetStartupConfig, useGetUserBalance } from '~/data-provider';
-import useEventHandlers, { getConvoTitle } from './useEventHandlers';
 import { useAuthContext } from '~/hooks/AuthContext';
+import useEventHandlers from './useEventHandlers';
 import store from '~/store';
 
 const clearDraft = (conversationId?: string | null) => {
@@ -53,7 +46,6 @@ export default function useSSE(
   isAddedRequest = false,
   runIndex = 0,
 ) {
-  const queryClient = useQueryClient();
   const genTitle = useGenTitleMutation();
   const setActiveRunId = useSetRecoilState(store.activeRunFamily(runIndex));
 
@@ -107,30 +99,6 @@ export default function useSSE(
     let { userMessage } = submission;
     const payloadData = createPayload(submission);
 
-    /**
-     * Helps clear text immediately on submission instead of
-     * restoring draft, which gets deleted on generation end
-     * */
-    const parentId = submission?.isRegenerate
-      ? userMessage.overrideParentMessageId
-      : userMessage.parentMessageId;
-    setConversation?.((prev: TConversation | null) => {
-      if (!prev) {
-        return null;
-      }
-      const title =
-        getConvoTitle({
-          parentId,
-          queryClient,
-          currentTitle: prev?.title,
-          conversationId: prev?.conversationId,
-        }) ?? '';
-      return {
-        ...prev,
-        title,
-        conversationId: Constants.PENDING_CONVO as string,
-      };
-    });
     let { payload } = payloadData;
     if (isAssistantsEndpoint(payload.endpoint) || isAgentsEndpoint(payload.endpoint)) {
       payload = removeNullishValues(payload) as TPayload;
diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts
index 9742114a3..95bf24a3a 100644
--- a/packages/data-provider/src/config.ts
+++ b/packages/data-provider/src/config.ts
@@ -856,6 +856,8 @@ export const visionModels = [
   'gpt-4o',
   'gpt-4-turbo',
   'gpt-4-vision',
+  'o4-mini',
+  'o3',
   'o1',
   'gpt-4.1',
   'gpt-4.5',
@@ -1011,6 +1013,10 @@ export enum CacheKeys {
   /**
    * Key for in-progress flow states.
    */
   FLOWS = 'flows',
+  /**
+   * Key for pending chat requests (concurrency check)
+   */
+  PENDING_REQ = 'pending_req',
   /**
    * Key for s3 check intervals per user
    */
diff --git a/packages/data-provider/src/types.ts b/packages/data-provider/src/types.ts
index 1ac3d1f46..b365c4e58 100644
--- a/packages/data-provider/src/types.ts
+++ b/packages/data-provider/src/types.ts
@@ -65,6 +65,7 @@ export type TSubmission = {
   isTemporary: boolean;
   messages: TMessage[];
   isRegenerate?: boolean;
+  isResubmission?: boolean;
   initialResponse?: TMessage;
   conversation: Partial<TConversation>;
   endpointOption: TEndpointOption;
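
Notes (editor sketches; illustrative, not part of the applied diff):

1) Omni-model parameter handling. The widened pattern /\b(o\d)\b/i treats any "o" + single-digit family (o1, o3, o4-mini, o1-preview) as an omni/reasoning model; the word boundary before the hyphen is what lets it match 'o4-mini'. These models reject `temperature` and take `max_completion_tokens` in place of `max_tokens`, which is what the OpenAIClient.js hunks enforce. A self-contained TypeScript sketch of that behavior, where `ModelOptions` and `stripUnsupportedParams` are hypothetical names for illustration (the patch itself edits OpenAIClient inline):

const omniPattern = /\b(o\d)\b/i; // matches 'o1', 'o3', 'o4-mini', 'o1-preview', ...

interface ModelOptions {
  model: string;
  temperature?: number;
  max_tokens?: number;
  max_completion_tokens?: number;
}

function stripUnsupportedParams(modelOptions: ModelOptions): ModelOptions {
  if (!omniPattern.test(modelOptions.model)) {
    return modelOptions;
  }
  const opts = { ...modelOptions };
  // o-series models take `max_completion_tokens` rather than `max_tokens`
  if (opts.max_tokens != null) {
    opts.max_completion_tokens = opts.max_tokens;
    delete opts.max_tokens;
  }
  // ...and reject `temperature`, so drop it entirely
  if (opts.temperature != null) {
    delete opts.temperature;
  }
  return opts;
}

// stripUnsupportedParams({ model: 'o4-mini', temperature: 0.7, max_tokens: 1024 })
// -> { model: 'o4-mini', max_completion_tokens: 1024 }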
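2) Pending-request cache. The PENDING_REQ change is a naming consolidation: the same 'pending_req' string now lives in the CacheKeys enum, so the limiter middleware and the cache store can no longer drift apart. The decrement path in clearPendingReq.js also refreshes a one-minute TTL via Time.ONE_MINUTE instead of the hand-rolled `1000 * 60 * 1`. A minimal sketch of that counter logic, assuming a Keyv-style async cache and a hypothetical key shape:

const ONE_MINUTE = 60 * 1000; // the value Time.ONE_MINUTE resolves to

interface PendingReqCache {
  get(key: string): Promise<number | undefined>;
  set(key: string, value: number, ttl?: number): Promise<void>;
  delete(key: string): Promise<void>;
}

async function decrementPendingReq(userId: string, cache: PendingReqCache): Promise<void> {
  const key = `${userId}:pendingRequests`; // hypothetical key shape, for illustration
  const currentReq = +((await cache.get(key)) ?? 0);
  if (currentReq >= 1) {
    // decrement and refresh the TTL so a stuck counter expires on its own
    await cache.set(key, currentReq - 1, ONE_MINUTE);
  } else {
    // nothing pending (or a non-numeric value): clear the entry
    await cache.delete(key);
  }
}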
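3) Pricing entries. The tx.js multipliers appear to correspond to USD per 1M tokens (o1 at 15/60 matches OpenAI's posted rates), so the new entries track the launch pricing for these models: o3 at 10/40 and o4-mini at 1.1/4.4, mirroring o3-mini. A quick worked example of the arithmetic those multipliers imply; `costUSD` is a hypothetical helper, not LibreChat's actual transaction code:

const tokenValues: Record<string, { prompt: number; completion: number }> = {
  'o4-mini': { prompt: 1.1, completion: 4.4 },
  o3: { prompt: 10, completion: 40 },
};

function costUSD(model: string, promptTokens: number, completionTokens: number): number {
  const { prompt, completion } = tokenValues[model];
  return (promptTokens * prompt + completionTokens * completion) / 1_000_000;
}

// A 10,000-token prompt with a 2,000-token completion on o3:
// (10_000 * 10 + 2_000 * 40) / 1e6 = 0.18 USD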