🤖 feat: Enhance Assistant Model Handling for Model Specs (#4390)

* chore: cleanup type issues in client/src/utils/endpoints

* refactor: use Constant enum for 'new' conversationId

* refactor: select assistant model if not provided for model spec
This commit is contained in:
Danny Avila 2024-10-11 14:20:32 +02:00 committed by GitHub
parent 2846779603
commit bab0152c58
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 23 additions and 12 deletions

View file

@ -1,12 +1,12 @@
import { useMemo } from 'react'; import { useMemo } from 'react';
import { useRecoilValue } from 'recoil'; import { useRecoilValue } from 'recoil';
import { EModelEndpoint } from 'librechat-data-provider';
import { Content, Portal, Root } from '@radix-ui/react-popover'; import { Content, Portal, Root } from '@radix-ui/react-popover';
import { useGetEndpointsQuery } from 'librechat-data-provider/react-query'; import { useGetEndpointsQuery } from 'librechat-data-provider/react-query';
import { EModelEndpoint, isAssistantsEndpoint } from 'librechat-data-provider';
import type { TModelSpec, TConversation, TEndpointsConfig } from 'librechat-data-provider'; import type { TModelSpec, TConversation, TEndpointsConfig } from 'librechat-data-provider';
import { useChatContext, useAssistantsMapContext } from '~/Providers';
import { getConvoSwitchLogic, getModelSpecIconURL } from '~/utils'; import { getConvoSwitchLogic, getModelSpecIconURL } from '~/utils';
import { useDefaultConvo, useNewConvo } from '~/hooks'; import { useDefaultConvo, useNewConvo } from '~/hooks';
import { useChatContext } from '~/Providers';
import MenuButton from './MenuButton'; import MenuButton from './MenuButton';
import ModelSpecs from './ModelSpecs'; import ModelSpecs from './ModelSpecs';
import store from '~/store'; import store from '~/store';
@ -18,6 +18,7 @@ export default function ModelSpecsMenu({ modelSpecs }: { modelSpecs?: TModelSpec
const { data: endpointsConfig = {} as TEndpointsConfig } = useGetEndpointsQuery(); const { data: endpointsConfig = {} as TEndpointsConfig } = useGetEndpointsQuery();
const modularChat = useRecoilValue(store.modularChat); const modularChat = useRecoilValue(store.modularChat);
const getDefaultConversation = useDefaultConvo(); const getDefaultConversation = useDefaultConvo();
const assistantMap = useAssistantsMapContext();
const onSelectSpec = (spec: TModelSpec) => { const onSelectSpec = (spec: TModelSpec) => {
const { preset } = spec; const { preset } = spec;
@ -47,6 +48,10 @@ export default function ModelSpecsMenu({ modelSpecs }: { modelSpecs?: TModelSpec
preset.endpointType = newEndpointType; preset.endpointType = newEndpointType;
} }
if (isAssistantsEndpoint(newEndpoint) && preset.assistant_id != null && !(preset.model ?? '')) {
preset.model = assistantMap?.[newEndpoint]?.[preset.assistant_id]?.model;
}
const isModular = isCurrentModular && isNewModular && shouldSwitch; const isModular = isCurrentModular && isNewModular && shouldSwitch;
if (isExistingConversation && isModular) { if (isExistingConversation && isModular) {
template.endpointType = newEndpointType as EModelEndpoint | undefined; template.endpointType = newEndpointType as EModelEndpoint | undefined;

View file

@ -64,6 +64,10 @@ export default function useSelectMention({
preset.endpointType = newEndpointType; preset.endpointType = newEndpointType;
} }
if (isAssistantsEndpoint(newEndpoint) && preset.assistant_id != null && !(preset.model ?? '')) {
preset.model = assistantMap?.[newEndpoint]?.[preset.assistant_id]?.model;
}
const isModular = isCurrentModular && isNewModular && shouldSwitch; const isModular = isCurrentModular && isNewModular && shouldSwitch;
if (isExistingConversation && isModular) { if (isExistingConversation && isModular) {
template.endpointType = newEndpointType as EModelEndpoint | undefined; template.endpointType = newEndpointType as EModelEndpoint | undefined;
@ -90,7 +94,7 @@ export default function useSelectMention({
keepAddedConvos: isModular, keepAddedConvos: isModular,
}); });
}, },
[conversation, getDefaultConversation, modularChat, newConversation, endpointsConfig], [conversation, getDefaultConversation, modularChat, newConversation, endpointsConfig, assistantMap],
); );
type Kwargs = { type Kwargs = {

View file

@ -6,6 +6,7 @@ import {
} from 'librechat-data-provider/react-query'; } from 'librechat-data-provider/react-query';
import { useNavigate } from 'react-router-dom'; import { useNavigate } from 'react-router-dom';
import { import {
Constants,
FileSources, FileSources,
isParamEndpoint, isParamEndpoint,
LocalStorageKeys, LocalStorageKeys,
@ -116,7 +117,7 @@ const useNewConvo = (index = 0) => {
) ?? assistants[0]?.id; ) ?? assistants[0]?.id;
} }
if (currentAssistantId && isAssistantEndpoint && conversation.conversationId === 'new') { if (currentAssistantId && isAssistantEndpoint && conversation.conversationId === Constants.NEW_CONVO) {
const assistant = assistants.find((asst) => asst.id === currentAssistantId); const assistant = assistants.find((asst) => asst.id === currentAssistantId);
conversation.model = assistant?.model; conversation.model = assistant?.model;
updateLastSelectedModel({ updateLastSelectedModel({
@ -147,12 +148,12 @@ const useNewConvo = (index = 0) => {
clearAllLatestMessages(); clearAllLatestMessages();
} }
if (conversation.conversationId === 'new' && !modelsData) { if (conversation.conversationId === Constants.NEW_CONVO && !modelsData) {
const appTitle = localStorage.getItem(LocalStorageKeys.APP_TITLE) ?? ''; const appTitle = localStorage.getItem(LocalStorageKeys.APP_TITLE) ?? '';
if (appTitle) { if (appTitle) {
document.title = appTitle; document.title = appTitle;
} }
navigate('/c/new'); navigate(`/c/${Constants.NEW_CONVO}`);
} }
clearTimeout(timeoutIdRef.current); clearTimeout(timeoutIdRef.current);
@ -189,12 +190,12 @@ const useNewConvo = (index = 0) => {
isParamEndpoint(_template.endpoint ?? '', _template.endpointType ?? '') === true || isParamEndpoint(_template.endpoint ?? '', _template.endpointType ?? '') === true ||
isParamEndpoint(_preset?.endpoint ?? '', _preset?.endpointType ?? ''); isParamEndpoint(_preset?.endpoint ?? '', _preset?.endpointType ?? '');
const template = const template =
paramEndpoint === true && templateConvoId && templateConvoId === 'new' paramEndpoint === true && templateConvoId && templateConvoId === Constants.NEW_CONVO
? { endpoint: _template.endpoint } ? { endpoint: _template.endpoint }
: _template; : _template;
const conversation = { const conversation = {
conversationId: 'new', conversationId: Constants.NEW_CONVO as string,
title: 'New Chat', title: 'New Chat',
endpoint: null, endpoint: null,
...template, ...template,

View file

@ -87,22 +87,23 @@ const firstLocalConvoKey = LocalStorageKeys.LAST_CONVO_SETUP + '_0';
* update without updating last convo setup when same endpoint */ * update without updating last convo setup when same endpoint */
export function updateLastSelectedModel({ export function updateLastSelectedModel({
endpoint, endpoint,
model, model = '',
}: { }: {
endpoint: string; endpoint: string;
model: string | undefined; model?: string;
}) { }) {
if (!model) { if (!model) {
return; return;
} }
const lastConversationSetup = JSON.parse(localStorage.getItem(firstLocalConvoKey) || '{}'); /* Note: an empty string value is possible */
const lastConversationSetup = JSON.parse((localStorage.getItem(firstLocalConvoKey) ?? '{}') || '{}');
if (lastConversationSetup.endpoint === endpoint) { if (lastConversationSetup.endpoint === endpoint) {
lastConversationSetup.model = model; lastConversationSetup.model = model;
localStorage.setItem(firstLocalConvoKey, JSON.stringify(lastConversationSetup)); localStorage.setItem(firstLocalConvoKey, JSON.stringify(lastConversationSetup));
} }
const lastSelectedModels = JSON.parse(localStorage.getItem(LocalStorageKeys.LAST_MODEL) || '{}'); const lastSelectedModels = JSON.parse((localStorage.getItem(LocalStorageKeys.LAST_MODEL) ?? '{}') || '{}');
lastSelectedModels[endpoint] = model; lastSelectedModels[endpoint] = model;
localStorage.setItem(LocalStorageKeys.LAST_MODEL, JSON.stringify(lastSelectedModels)); localStorage.setItem(LocalStorageKeys.LAST_MODEL, JSON.stringify(lastSelectedModels));
} }