🤖 feat: Support o4-mini and o3 Models (#6928)

* feat: Add support for new OpenAI models (o4-mini, o3) and update related logic

* 🔧 fix: Rename 'resubmitFiles' to 'isResubmission' for consistency across types and hooks

* 🔧 fix: Replace hardcoded 'pending_req' with CacheKeys.PENDING_REQ for consistency in cache handling

* 🔧 fix: Update cache handling to use Time.ONE_MINUTE instead of hardcoded TTL and streamline imports

* 🔧 fix: Enhance message handling logic to correctly identify parent messages and streamline imports in useSSE
Authored by Danny Avila on 2025-04-17 00:40:26 -04:00; committed by GitHub
parent 88f4ad7c47
commit 52f146dd97
GPG key ID: B5690EEEBB952194
19 changed files with 69 additions and 53 deletions

View file

@@ -108,7 +108,7 @@ class OpenAIClient extends BaseClient {
       this.checkVisionRequest(this.options.attachments);
     }
-    const omniPattern = /\b(o1|o3)\b/i;
+    const omniPattern = /\b(o\d)\b/i;
     this.isOmni = omniPattern.test(this.modelOptions.model);
     const { OPENAI_FORCE_PROMPT } = process.env ?? {};
@@ -1237,6 +1237,9 @@ ${convo}
       modelOptions.max_completion_tokens = modelOptions.max_tokens;
       delete modelOptions.max_tokens;
     }
+    if (this.isOmni === true && modelOptions.temperature != null) {
+      delete modelOptions.temperature;
+    }
     if (process.env.OPENAI_ORGANIZATION) {
       opts.organization = process.env.OPENAI_ORGANIZATION;

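The broadened `/\b(o\d)\b/i` pattern flags any o-series model (o1, o3, o4-mini, and future oN) as "omni", and the new guard drops `temperature` the same way `max_tokens` is already converted. A minimal standalone sketch of that sanitization — not the actual `OpenAIClient` code; the helper name is illustrative:

```js
// Minimal sketch: treat any o-series model as "omni" and strip options the
// reasoning models reject. `sanitizeModelOptions` is a hypothetical helper.
const omniPattern = /\b(o\d)\b/i;

function sanitizeModelOptions(modelOptions) {
  const isOmni = omniPattern.test(modelOptions.model ?? '');
  const sanitized = { ...modelOptions };
  if (isOmni && sanitized.max_tokens != null) {
    // o-series models expect `max_completion_tokens` instead of `max_tokens`
    sanitized.max_completion_tokens = sanitized.max_tokens;
    delete sanitized.max_tokens;
  }
  if (isOmni && sanitized.temperature != null) {
    // temperature is not supported by these models, so it is dropped entirely
    delete sanitized.temperature;
  }
  return sanitized;
}

// sanitizeModelOptions({ model: 'o4-mini', max_tokens: 1024, temperature: 0.7 })
// -> { model: 'o4-mini', max_completion_tokens: 1024 }
```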
View file

@@ -1,7 +1,8 @@
+const { Time, CacheKeys } = require('librechat-data-provider');
+const { isEnabled } = require('~/server/utils');
 const getLogStores = require('./getLogStores');
-const { isEnabled } = require('../server/utils');
 const { USE_REDIS, LIMIT_CONCURRENT_MESSAGES } = process.env ?? {};
-const ttl = 1000 * 60 * 1;
 /**
  * Clear or decrement pending requests from the cache.
@@ -28,7 +29,7 @@ const clearPendingReq = async ({ userId, cache: _cache }) => {
     return;
   }
-  const namespace = 'pending_req';
+  const namespace = CacheKeys.PENDING_REQ;
   const cache = _cache ?? getLogStores(namespace);
   if (!cache) {
@@ -39,7 +40,7 @@ const clearPendingReq = async ({ userId, cache: _cache }) => {
   const currentReq = +((await cache.get(key)) ?? 0);
   if (currentReq && currentReq >= 1) {
-    await cache.set(key, currentReq - 1, ttl);
+    await cache.set(key, currentReq - 1, Time.ONE_MINUTE);
   } else {
     await cache.delete(key);
   }

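The pending-request counter is decremented against a Keyv-style cache and now refreshed with `Time.ONE_MINUTE` instead of the hardcoded `1000 * 60 * 1`. A rough sketch of that flow, assuming a hypothetical per-user key format (the real code derives the key differently):

```js
// Rough sketch of the decrement logic, assuming a Keyv-compatible cache.
const { Time, CacheKeys } = require('librechat-data-provider');

async function decrementPendingReq(cache, userId) {
  const key = `${CacheKeys.PENDING_REQ}:${userId}`; // hypothetical key format
  const currentReq = +((await cache.get(key)) ?? 0);
  if (currentReq && currentReq >= 1) {
    // keep the counter alive for one more minute rather than a magic TTL number
    await cache.set(key, currentReq - 1, Time.ONE_MINUTE);
  } else {
    await cache.delete(key);
  }
}
```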
View file

@@ -19,7 +19,7 @@ const createViolationInstance = (namespace) => {
 // Serve cache from memory so no need to clear it on startup/exit
 const pending_req = isRedisEnabled
   ? new Keyv({ store: keyvRedis })
-  : new Keyv({ namespace: 'pending_req' });
+  : new Keyv({ namespace: CacheKeys.PENDING_REQ });
 const config = isRedisEnabled
   ? new Keyv({ store: keyvRedis })
@@ -64,7 +64,7 @@ const abortKeys = isRedisEnabled
 const namespaces = {
   [CacheKeys.ROLES]: roles,
   [CacheKeys.CONFIG_STORE]: config,
-  pending_req,
+  [CacheKeys.PENDING_REQ]: pending_req,
   [ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }),
   [CacheKeys.ENCODED_DOMAINS]: new Keyv({
     store: keyvMongo,

View file

@@ -76,7 +76,9 @@ const tokenValues = Object.assign(
     '4k': { prompt: 1.5, completion: 2 },
     '16k': { prompt: 3, completion: 4 },
     'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 },
+    'o4-mini': { prompt: 1.1, completion: 4.4 },
     'o3-mini': { prompt: 1.1, completion: 4.4 },
+    o3: { prompt: 10, completion: 40 },
     'o1-mini': { prompt: 1.1, completion: 4.4 },
     'o1-preview': { prompt: 15, completion: 60 },
     o1: { prompt: 15, completion: 60 },

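The new `o4-mini` and `o3` entries reuse the same prompt/completion multiplier shape as the existing models. A hedged sketch of how those multipliers translate into relative spend — the rate units are whatever `tx.js` uses internally, and the cost helper below is illustrative only:

```js
// Hedged sketch: estimate relative spend from the multipliers above,
// assuming cost scales linearly with token counts.
const tokenValues = {
  'o4-mini': { prompt: 1.1, completion: 4.4 },
  o3: { prompt: 10, completion: 40 },
};

function estimateCost(model, promptTokens, completionTokens) {
  const { prompt, completion } = tokenValues[model];
  return promptTokens * prompt + completionTokens * completion;
}

// o3 works out to roughly 9x the spend of o4-mini for the same traffic:
console.log(estimateCost('o3', 1000, 500));      // 30000
console.log(estimateCost('o4-mini', 1000, 500)); // 3300
```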
View file

@@ -165,6 +165,15 @@ describe('getMultiplier', () => {
     );
   });
+  it('should return correct multipliers for o4-mini and o3', () => {
+    ['o4-mini', 'o3'].forEach((model) => {
+      const prompt = getMultiplier({ model, tokenType: 'prompt' });
+      const completion = getMultiplier({ model, tokenType: 'completion' });
+      expect(prompt).toBe(tokenValues[model].prompt);
+      expect(completion).toBe(tokenValues[model].completion);
+    });
+  });
+
   it('should return defaultRate if tokenType is provided but not found in tokenValues', () => {
     expect(getMultiplier({ valueKey: '8k', tokenType: 'unknownType' })).toBe(defaultRate);
   });

View file

@@ -58,7 +58,7 @@ const payloadParser = ({ req, agent, endpoint }) => {
 const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deepseek]);
-const noSystemModelRegex = [/\bo1\b/gi];
+const noSystemModelRegex = [/\b(o\d)\b/gi];
 // const { processMemory, memoryInstructions } = require('~/server/services/Endpoints/agents/memory');
 // const { getFormattedMemories } = require('~/models/Memory');
@@ -975,7 +975,7 @@ class AgentClient extends BaseClient {
         })
       )?.llmConfig ?? clientOptions;
     }
-    if (/\b(o1|o3)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
+    if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
       delete clientOptions.maxTokens;
     }
     try {

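Both the `noSystemModelRegex` list and the `maxTokens` guard now key off the same o-series pattern. A small sketch under assumptions — the helper name and simplified options object are illustrative, not the actual agent client:

```js
// Sketch: the same o-series pattern gates both behaviors in this diff —
// matched models get no system message and no `maxTokens` in client options.
const omniRegex = /\b(o\d)\b/i;

function prepareClientOptions(clientOptions) {
  const options = { ...clientOptions };
  if (omniRegex.test(options.model ?? '') && options.maxTokens != null) {
    // reasoning models reject `maxTokens`, so it is removed before the request
    delete options.maxTokens;
  }
  return options;
}
```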
View file

@@ -1,4 +1,4 @@
-const { Time } = require('librechat-data-provider');
+const { Time, CacheKeys } = require('librechat-data-provider');
 const clearPendingReq = require('~/cache/clearPendingReq');
 const { logViolation, getLogStores } = require('~/cache');
 const { isEnabled } = require('~/server/utils');
@@ -25,7 +25,7 @@ const {
  * @throws {Error} Throws an error if the user exceeds the concurrent request limit.
  */
 const concurrentLimiter = async (req, res, next) => {
-  const namespace = 'pending_req';
+  const namespace = CacheKeys.PENDING_REQ;
   const cache = getLogStores(namespace);
   if (!cache) {
     return next();

View file

@@ -2,7 +2,9 @@ const z = require('zod');
 const { EModelEndpoint } = require('librechat-data-provider');
 const openAIModels = {
+  'o4-mini': 200000,
   'o3-mini': 195000, // -5000 from max
+  o3: 200000,
   o1: 195000, // -5000 from max
   'o1-mini': 127500, // -500 from max
   'o1-preview': 127500, // -500 from max

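With `o4-mini` and `o3` added to `openAIModels`, prefixed names such as `openai/o4-mini` should resolve to the same limits. A hedged sketch of the kind of lookup the tests below exercise; the real `getModelMaxTokens` in this repo may match model names differently:

```js
// Hedged sketch: try an exact key, then fall back to the longest key
// contained in the model name. `resolveMaxTokens` is a hypothetical helper.
const openAIModels = { 'o4-mini': 200000, 'o3-mini': 195000, o3: 200000, o1: 195000 };

function resolveMaxTokens(modelName) {
  if (openAIModels[modelName] != null) {
    return openAIModels[modelName];
  }
  const match = Object.keys(openAIModels)
    .filter((key) => modelName.includes(key))
    .sort((a, b) => b.length - a.length)[0];
  return match ? openAIModels[match] : undefined;
}

// resolveMaxTokens('openai/o4-mini') -> 200000
// resolveMaxTokens('openai/o3')      -> 200000
```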
View file

@@ -340,6 +340,15 @@ describe('getModelMaxTokens', () => {
     expect(getModelMaxTokens('o1-preview-something')).toBe(o1PreviewTokens);
     expect(getModelMaxTokens('openai/o1-preview-something')).toBe(o1PreviewTokens);
   });
+  test('should return correct max context tokens for o4-mini and o3', () => {
+    const o4MiniTokens = maxTokensMap[EModelEndpoint.openAI]['o4-mini'];
+    const o3Tokens = maxTokensMap[EModelEndpoint.openAI]['o3'];
+    expect(getModelMaxTokens('o4-mini')).toBe(o4MiniTokens);
+    expect(getModelMaxTokens('openai/o4-mini')).toBe(o4MiniTokens);
+    expect(getModelMaxTokens('o3')).toBe(o3Tokens);
+    expect(getModelMaxTokens('openai/o3')).toBe(o3Tokens);
+  });
 });
 describe('matchModelName', () => {

View file

@@ -306,7 +306,7 @@ export type TAskProps = {
 export type TOptions = {
   editedMessageId?: string | null;
   editedText?: string | null;
-  resubmitFiles?: boolean;
+  isResubmission?: boolean;
   isRegenerate?: boolean;
   isContinued?: boolean;
   isEdited?: boolean;

View file

@@ -121,6 +121,7 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
     setFiles,
     textAreaRef,
     conversationId,
+    isSubmitting: isSubmitting || isSubmittingAdded,
   });
   const { submitMessage, submitPrompt } = useSubmitMessage();

View file

@@ -60,7 +60,7 @@ const EditMessage = ({
         conversationId,
       },
       {
-        resubmitFiles: true,
+        isResubmission: true,
       },
     );

View file

@@ -25,7 +25,7 @@ type EndpointIcon = {
 function getOpenAIColor(_model: string | null | undefined) {
   const model = _model?.toLowerCase() ?? '';
-  if (model && /\b(o1|o3)\b/i.test(model)) {
+  if (model && /\b(o\d)\b/i.test(model)) {
     return '#000000';
   }
   return model.includes('gpt-4') ? '#AB68FF' : '#19C37D';

View file

@@ -90,7 +90,7 @@ export default function useChatFunctions({
     {
       editedText = null,
       editedMessageId = null,
-      resubmitFiles = false,
+      isResubmission = false,
       isRegenerate = false,
       isContinued = false,
       isEdited = false,
@@ -202,7 +202,7 @@ export default function useChatFunctions({
     };
     const reuseFiles =
-      (isRegenerate || resubmitFiles) && parentMessage?.files && parentMessage.files.length > 0;
+      (isRegenerate || isResubmission) && parentMessage?.files && parentMessage.files.length > 0;
     if (setFiles && reuseFiles === true) {
       currentMsg.files = parentMessage.files;
       setFiles(new Map());
@@ -298,6 +298,7 @@ export default function useChatFunctions({
       isEdited: isEditOrContinue,
       isContinued,
       isRegenerate,
+      isResubmission,
       initialResponse,
       isTemporary,
       ephemeralAgent,

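`isResubmission` now travels with the submission, and the reuse check treats a resubmitted edit like a regenerate. A simplified, illustrative version of that decision (not the hook itself; the `pickedFiles` conversion is an assumption):

```js
// Illustrative sketch of the file-reuse decision: a regenerate or a
// resubmission pulls attachments from the parent message instead of the
// (already cleared) file picker state.
function resolveFiles({ isRegenerate, isResubmission, parentMessage, pickedFiles }) {
  const reuseFiles =
    (isRegenerate || isResubmission) && parentMessage?.files && parentMessage.files.length > 0;
  if (reuseFiles) {
    return parentMessage.files;
  }
  return Array.from(pickedFiles.values()); // hypothetical Map-to-array conversion
}
```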
View file

@@ -1,7 +1,8 @@
 import debounce from 'lodash/debounce';
 import { SetterOrUpdater, useRecoilValue } from 'recoil';
 import { useState, useEffect, useMemo, useCallback } from 'react';
-import { LocalStorageKeys, TFile } from 'librechat-data-provider';
+import { LocalStorageKeys, Constants } from 'librechat-data-provider';
+import type { TFile } from 'librechat-data-provider';
 import type { ExtendedFile } from '~/common';
 import { useChatFormContext } from '~/Providers';
 import { useGetFiles } from '~/data-provider';
@@ -34,11 +35,13 @@ const decodeBase64 = (base64String: string): string => {
 };
 export const useAutoSave = ({
-  conversationId,
+  isSubmitting,
+  conversationId: _conversationId,
   textAreaRef,
   setFiles,
   files,
 }: {
+  isSubmitting?: boolean;
   conversationId?: string | null;
   textAreaRef?: React.RefObject<HTMLTextAreaElement>;
   files: Map<string, ExtendedFile>;
@@ -47,6 +50,7 @@ export const useAutoSave = ({
   // setting for auto-save
   const { setValue } = useChatFormContext();
   const saveDrafts = useRecoilValue<boolean>(store.saveDrafts);
+  const conversationId = isSubmitting ? Constants.PENDING_CONVO : _conversationId;
   const [currentConversationId, setCurrentConversationId] = useState<string | null>(null);
   const fileIds = useMemo(() => Array.from(files.keys()), [files]);

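While a message is in flight, drafts are keyed under `Constants.PENDING_CONVO` rather than the real conversation id, so text typed during submission is not restored as a stale draft afterwards. A sketch of the id selection, with a hypothetical draft-key format (the actual hook builds keys from `LocalStorageKeys` and debounces writes):

```js
// Sketch: choose the id under which a draft is stored while submitting.
const { Constants } = require('librechat-data-provider');

function getDraftConversationId({ isSubmitting, conversationId }) {
  return isSubmitting ? Constants.PENDING_CONVO : conversationId;
}

function getDraftKey(args) {
  return `draft:${getDraftConversationId(args)}`; // hypothetical key format
}
```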
View file

@@ -613,8 +613,17 @@ export default function useEventHandlers({
       messages?.[messages.length - 1] != null &&
       messages[messages.length - 2] != null
     ) {
-      const requestMessage = messages[messages.length - 2];
+      let requestMessage = messages[messages.length - 2];
       const responseMessage = messages[messages.length - 1];
+      if (requestMessage.messageId !== responseMessage.parentMessageId) {
+        // the request message is the parent of the response, which we search for backwards
+        for (let i = messages.length - 3; i >= 0; i--) {
+          if (messages[i].messageId === responseMessage.parentMessageId) {
+            requestMessage = messages[i];
+            break;
+          }
+        }
+      }
       finalHandler(
         {
           conversation: {

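The final handler no longer assumes the request message is always the second-to-last entry; it walks backwards for the message whose `messageId` matches the response's `parentMessageId`. The same lookup as a standalone function (assumes at least two messages, as the hook's guard above already ensures):

```js
// Stand-alone version of the parent lookup added in this hunk: start from the
// second-to-last message, then search backwards for the true parent.
function findRequestMessage(messages) {
  const responseMessage = messages[messages.length - 1];
  let requestMessage = messages[messages.length - 2];
  if (requestMessage.messageId !== responseMessage.parentMessageId) {
    for (let i = messages.length - 3; i >= 0; i--) {
      if (messages[i].messageId === responseMessage.parentMessageId) {
        requestMessage = messages[i];
        break;
      }
    }
  }
  return requestMessage;
}
```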
View file

@@ -2,7 +2,6 @@ import { useEffect, useState } from 'react';
 import { v4 } from 'uuid';
 import { SSE } from 'sse.js';
 import { useSetRecoilState } from 'recoil';
-import { useQueryClient } from '@tanstack/react-query';
 import {
   request,
   Constants,
@@ -13,18 +12,12 @@ import {
   removeNullishValues,
   isAssistantsEndpoint,
 } from 'librechat-data-provider';
-import type {
-  EventSubmission,
-  TConversation,
-  TMessage,
-  TPayload,
-  TSubmission,
-} from 'librechat-data-provider';
+import type { TMessage, TPayload, TSubmission, EventSubmission } from 'librechat-data-provider';
 import type { EventHandlerParams } from './useEventHandlers';
 import type { TResData } from '~/common';
 import { useGenTitleMutation, useGetStartupConfig, useGetUserBalance } from '~/data-provider';
-import useEventHandlers, { getConvoTitle } from './useEventHandlers';
 import { useAuthContext } from '~/hooks/AuthContext';
+import useEventHandlers from './useEventHandlers';
 import store from '~/store';
 const clearDraft = (conversationId?: string | null) => {
@@ -53,7 +46,6 @@ export default function useSSE(
   isAddedRequest = false,
   runIndex = 0,
 ) {
-  const queryClient = useQueryClient();
   const genTitle = useGenTitleMutation();
   const setActiveRunId = useSetRecoilState(store.activeRunFamily(runIndex));
@@ -107,30 +99,6 @@ export default function useSSE(
     let { userMessage } = submission;
     const payloadData = createPayload(submission);
-    /**
-     * Helps clear text immediately on submission instead of
-     * restoring draft, which gets deleted on generation end
-     * */
-    const parentId = submission?.isRegenerate
-      ? userMessage.overrideParentMessageId
-      : userMessage.parentMessageId;
-    setConversation?.((prev: TConversation | null) => {
-      if (!prev) {
-        return null;
-      }
-      const title =
-        getConvoTitle({
-          parentId,
-          queryClient,
-          currentTitle: prev?.title,
-          conversationId: prev?.conversationId,
-        }) ?? '';
-      return {
-        ...prev,
-        title,
-        conversationId: Constants.PENDING_CONVO as string,
-      };
-    });
     let { payload } = payloadData;
     if (isAssistantsEndpoint(payload.endpoint) || isAgentsEndpoint(payload.endpoint)) {
       payload = removeNullishValues(payload) as TPayload;

View file

@@ -856,6 +856,8 @@ export const visionModels = [
   'gpt-4o',
   'gpt-4-turbo',
   'gpt-4-vision',
+  'o4-mini',
+  'o3',
   'o1',
   'gpt-4.1',
   'gpt-4.5',
@@ -1011,6 +1013,10 @@ export enum CacheKeys {
    * Key for in-progress flow states.
    */
   FLOWS = 'flows',
+  /**
+   * Key for pending chat requests (concurrency check)
+   */
+  PENDING_REQ = 'pending_req',
   /**
    * Key for s3 check intervals per user
    */

View file

@@ -65,6 +65,7 @@ export type TSubmission = {
   isTemporary: boolean;
   messages: TMessage[];
   isRegenerate?: boolean;
+  isResubmission?: boolean;
   initialResponse?: TMessage;
   conversation: Partial<TConversation>;
   endpointOption: TEndpointOption;