🔧 fix: Improve Endpoint Handling and Address Edge Cases (#1486)

* fix(TEndpointsConfig): resolve property access issues with typesafe helper function

* fix: undefined or null endpoint edge case

* refactor(mapEndpoints -> endpoints): renamed module to be more general for endpoint handling, wrote unit tests, export all helpers
This commit is contained in:
Danny Avila 2024-01-04 10:17:15 -05:00 committed by GitHub
parent 42f2353509
commit 9864fc8700
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 275 additions and 99 deletions

View file

@ -45,19 +45,23 @@ export default function ChatForm({ index = 0 }) {
<div className="flex w-full items-center"> <div className="flex w-full items-center">
<div className="[&:has(textarea:focus)]:border-token-border-xheavy border-token-border-heavy shadow-xs dark:shadow-xs relative flex w-full flex-grow flex-col overflow-hidden rounded-2xl border border-black/10 bg-white shadow-[0_0_0_2px_rgba(255,255,255,0.95)] dark:border-gray-600 dark:bg-gray-800 dark:text-white dark:shadow-[0_0_0_2px_rgba(52,53,65,0.95)] [&:has(textarea:focus)]:shadow-[0_2px_6px_rgba(0,0,0,.05)]"> <div className="[&:has(textarea:focus)]:border-token-border-xheavy border-token-border-heavy shadow-xs dark:shadow-xs relative flex w-full flex-grow flex-col overflow-hidden rounded-2xl border border-black/10 bg-white shadow-[0_0_0_2px_rgba(255,255,255,0.95)] dark:border-gray-600 dark:bg-gray-800 dark:text-white dark:shadow-[0_0_0_2px_rgba(52,53,65,0.95)] [&:has(textarea:focus)]:shadow-[0_2px_6px_rgba(0,0,0,.05)]">
<Images files={files} setFiles={setFiles} setFilesLoading={setFilesLoading} /> <Images files={files} setFiles={setFiles} setFilesLoading={setFilesLoading} />
<Textarea {endpoint && (
value={text} <Textarea
disabled={requiresKey} value={text}
onChange={(e: ChangeEvent<HTMLTextAreaElement>) => setText(e.target.value)} disabled={requiresKey}
setText={setText} onChange={(e: ChangeEvent<HTMLTextAreaElement>) => setText(e.target.value)}
submitMessage={submitMessage} setText={setText}
endpoint={endpoint} submitMessage={submitMessage}
/> endpoint={endpoint}
/>
)}
<AttachFile endpoint={endpoint ?? ''} disabled={requiresKey} /> <AttachFile endpoint={endpoint ?? ''} disabled={requiresKey} />
{isSubmitting && showStopButton ? ( {isSubmitting && showStopButton ? (
<StopButton stop={handleStopGenerating} setShowStopButton={setShowStopButton} /> <StopButton stop={handleStopGenerating} setShowStopButton={setShowStopButton} />
) : ( ) : (
<SendButton text={text} disabled={filesLoading || isSubmitting || requiresKey} /> endpoint && (
<SendButton text={text} disabled={filesLoading || isSubmitting || requiresKey} />
)
)} )}
</div> </div>
</div> </div>

View file

@ -3,6 +3,7 @@ import { useGetEndpointsQuery } from 'librechat-data-provider/react-query';
import { EModelEndpoint } from 'librechat-data-provider'; import { EModelEndpoint } from 'librechat-data-provider';
import { icons } from './Menus/Endpoints/Icons'; import { icons } from './Menus/Endpoints/Icons';
import { useChatContext } from '~/Providers'; import { useChatContext } from '~/Providers';
import { getEndpointField } from '~/utils';
import { useLocalize } from '~/hooks'; import { useLocalize } from '~/hooks';
export default function Landing({ Header }: { Header?: ReactNode }) { export default function Landing({ Header }: { Header?: ReactNode }) {
@ -19,7 +20,9 @@ export default function Landing({ Header }: { Header?: ReactNode }) {
endpoint = EModelEndpoint.openAI; endpoint = EModelEndpoint.openAI;
} }
const iconKey = endpointsConfig?.[endpoint ?? '']?.type ? 'unknown' : endpoint ?? 'unknown'; const endpointType = getEndpointField(endpointsConfig, endpoint, 'type');
const iconURL = getEndpointField(endpointsConfig, endpoint, 'iconURL');
const iconKey = endpointType ? 'unknown' : endpoint ?? 'unknown';
return ( return (
<div className="relative h-full"> <div className="relative h-full">
@ -27,13 +30,14 @@ export default function Landing({ Header }: { Header?: ReactNode }) {
<div className="flex h-full flex-col items-center justify-center"> <div className="flex h-full flex-col items-center justify-center">
<div className="mb-3 h-[72px] w-[72px]"> <div className="mb-3 h-[72px] w-[72px]">
<div className="gizmo-shadow-stroke relative flex h-full items-center justify-center rounded-full bg-white text-black"> <div className="gizmo-shadow-stroke relative flex h-full items-center justify-center rounded-full bg-white text-black">
{icons[iconKey]({ {endpoint &&
size: 41, icons[iconKey]({
context: 'landing', size: 41,
className: 'h-2/3 w-2/3', context: 'landing',
endpoint: endpoint as EModelEndpoint | string, className: 'h-2/3 w-2/3',
iconURL: endpointsConfig?.[endpoint ?? ''].iconURL, endpoint: endpoint,
})} iconURL: iconURL,
})}
</div> </div>
</div> </div>
<div className="mb-5 text-2xl font-medium dark:text-white"> <div className="mb-5 text-2xl font-medium dark:text-white">

View file

@ -7,10 +7,10 @@ import type { FC } from 'react';
import type { TPreset } from 'librechat-data-provider'; import type { TPreset } from 'librechat-data-provider';
import { useLocalize, useUserKey, useDefaultConvo } from '~/hooks'; import { useLocalize, useUserKey, useDefaultConvo } from '~/hooks';
import { SetKeyDialog } from '~/components/Input/SetKeyDialog'; import { SetKeyDialog } from '~/components/Input/SetKeyDialog';
import { cn, getEndpointField } from '~/utils';
import { useChatContext } from '~/Providers'; import { useChatContext } from '~/Providers';
import store from '~/store';
import { icons } from './Icons'; import { icons } from './Icons';
import { cn } from '~/utils'; import store from '~/store';
type MenuItemProps = { type MenuItemProps = {
title: string; title: string;
@ -50,7 +50,7 @@ const MenuItem: FC<MenuItemProps> = ({
const template: Partial<TPreset> = { endpoint: newEndpoint, conversationId: 'new' }; const template: Partial<TPreset> = { endpoint: newEndpoint, conversationId: 'new' };
const { conversationId } = conversation ?? {}; const { conversationId } = conversation ?? {};
if (modularChat && conversationId && conversationId !== 'new') { if (modularChat && conversationId && conversationId !== 'new') {
template.endpointType = endpointsConfig?.[newEndpoint]?.type; template.endpointType = getEndpointField(endpointsConfig, newEndpoint, 'type');
const currentConvo = getDefaultConversation({ const currentConvo = getDefaultConversation({
/* target endpointType is necessary to avoid endpoint mixing */ /* target endpointType is necessary to avoid endpoint mixing */
@ -66,7 +66,7 @@ const MenuItem: FC<MenuItemProps> = ({
} }
}; };
const endpointType = endpointsConfig?.[endpoint ?? '']?.type; const endpointType = getEndpointField(endpointsConfig, endpoint, 'type');
const iconKey = endpointType ? 'unknown' : endpoint ?? 'unknown'; const iconKey = endpointType ? 'unknown' : endpoint ?? 'unknown';
const Icon = icons[iconKey]; const Icon = icons[iconKey];
@ -88,7 +88,7 @@ const MenuItem: FC<MenuItemProps> = ({
endpoint={endpoint} endpoint={endpoint}
context={'menu-item'} context={'menu-item'}
className="icon-md shrink-0 dark:text-white" className="icon-md shrink-0 dark:text-white"
iconURL={endpointsConfig?.[endpoint ?? '']?.iconURL} iconURL={getEndpointField(endpointsConfig, endpoint, 'iconURL')}
/> />
} }
<div> <div>
@ -167,7 +167,7 @@ const MenuItem: FC<MenuItemProps> = ({
endpoint={endpoint} endpoint={endpoint}
endpointType={endpointType} endpointType={endpointType}
onOpenChange={setDialogOpen} onOpenChange={setDialogOpen}
userProvideURL={endpointsConfig?.[endpoint ?? '']?.userProvideURL} userProvideURL={getEndpointField(endpointsConfig, endpoint, 'userProvideURL')}
/> />
)} )}
</> </>

View file

@ -3,6 +3,7 @@ import { Close } from '@radix-ui/react-popover';
import { EModelEndpoint, alternateName } from 'librechat-data-provider'; import { EModelEndpoint, alternateName } from 'librechat-data-provider';
import { useGetEndpointsQuery } from 'librechat-data-provider/react-query'; import { useGetEndpointsQuery } from 'librechat-data-provider/react-query';
import MenuSeparator from '../UI/MenuSeparator'; import MenuSeparator from '../UI/MenuSeparator';
import { getEndpointField } from '~/utils';
import MenuItem from './MenuItem'; import MenuItem from './MenuItem';
const EndpointItems: FC<{ const EndpointItems: FC<{
@ -19,7 +20,11 @@ const EndpointItems: FC<{
} else if (!endpointsConfig?.[endpoint]) { } else if (!endpointsConfig?.[endpoint]) {
return null; return null;
} }
const userProvidesKey = endpointsConfig?.[endpoint]?.userProvide; const userProvidesKey: boolean | null | undefined = getEndpointField(
endpointsConfig,
endpoint,
'userProvide',
);
return ( return (
<Close asChild key={`endpoint-${endpoint}`}> <Close asChild key={`endpoint-${endpoint}`}>
<div key={`endpoint-${endpoint}`}> <div key={`endpoint-${endpoint}`}>

View file

@ -8,10 +8,10 @@ import type { TPreset } from 'librechat-data-provider';
import FileUpload from '~/components/Input/EndpointMenu/FileUpload'; import FileUpload from '~/components/Input/EndpointMenu/FileUpload';
import { PinIcon, EditIcon, TrashIcon } from '~/components/svg'; import { PinIcon, EditIcon, TrashIcon } from '~/components/svg';
import DialogTemplate from '~/components/ui/DialogTemplate'; import DialogTemplate from '~/components/ui/DialogTemplate';
import { getPresetTitle, getEndpointField } from '~/utils';
import { Dialog, DialogTrigger } from '~/components/ui/'; import { Dialog, DialogTrigger } from '~/components/ui/';
import { MenuSeparator, MenuItem } from '../UI'; import { MenuSeparator, MenuItem } from '../UI';
import { icons } from '../Endpoints/Icons'; import { icons } from '../Endpoints/Icons';
import { getPresetTitle } from '~/utils';
import { useLocalize } from '~/hooks'; import { useLocalize } from '~/hooks';
import store from '~/store'; import store from '~/store';
@ -95,7 +95,7 @@ const PresetItems: FC<{
return null; return null;
} }
const iconKey = endpointsConfig?.[preset.endpoint ?? '']?.type const iconKey = getEndpointField(endpointsConfig, preset.endpoint, 'type')
? 'unknown' ? 'unknown'
: preset.endpoint ?? 'unknown'; : preset.endpoint ?? 'unknown';
@ -111,7 +111,7 @@ const PresetItems: FC<{
onClick={() => onSelectPreset(preset)} onClick={() => onSelectPreset(preset)}
icon={icons[iconKey]({ icon={icons[iconKey]({
context: 'menu-item', context: 'menu-item',
iconURL: endpointsConfig?.[preset.endpoint ?? ''].iconURL, iconURL: getEndpointField(endpointsConfig, preset.endpoint, 'iconURL'),
className: 'icon-md mr-1 dark:text-white', className: 'icon-md mr-1 dark:text-white',
endpoint: preset.endpoint, endpoint: preset.endpoint,
})} })}

View file

@ -5,12 +5,14 @@ import {
useGetEndpointsQuery, useGetEndpointsQuery,
useUpdateConversationMutation, useUpdateConversationMutation,
} from 'librechat-data-provider/react-query'; } from 'librechat-data-provider/react-query';
import { EModelEndpoint } from 'librechat-data-provider';
import type { MouseEvent, FocusEvent, KeyboardEvent } from 'react'; import type { MouseEvent, FocusEvent, KeyboardEvent } from 'react';
import { useConversations, useNavigateToConvo } from '~/hooks'; import { useConversations, useNavigateToConvo } from '~/hooks';
import { MinimalIcon } from '~/components/Endpoints'; import { MinimalIcon } from '~/components/Endpoints';
import { NotificationSeverity } from '~/common'; import { NotificationSeverity } from '~/common';
import { useToastContext } from '~/Providers'; import { useToastContext } from '~/Providers';
import DeleteButton from './NewDeleteButton'; import DeleteButton from './NewDeleteButton';
import { getEndpointField } from '~/utils';
import RenameButton from './RenameButton'; import RenameButton from './RenameButton';
import store from '~/store'; import store from '~/store';
@ -41,7 +43,7 @@ export default function Conversation({ conversation, retainView, toggleNav, i })
document.title = title; document.title = title;
// set conversation to the new conversation // set conversation to the new conversation
if (conversation?.endpoint === 'gptPlugins') { if (conversation?.endpoint === EModelEndpoint.gptPlugins) {
let lastSelectedTools = []; let lastSelectedTools = [];
try { try {
lastSelectedTools = JSON.parse(localStorage.getItem('lastSelectedTools') ?? '') ?? []; lastSelectedTools = JSON.parse(localStorage.getItem('lastSelectedTools') ?? '') ?? [];
@ -90,7 +92,7 @@ export default function Conversation({ conversation, retainView, toggleNav, i })
const icon = MinimalIcon({ const icon = MinimalIcon({
size: 20, size: 20,
iconURL: endpointsConfig?.[conversation.endpoint ?? '']?.iconURL, iconURL: getEndpointField(endpointsConfig, conversation.endpoint, 'iconURL'),
endpoint: conversation.endpoint, endpoint: conversation.endpoint,
endpointType: conversation.endpointType, endpointType: conversation.endpointType,
model: conversation.model, model: conversation.model,

View file

@ -4,9 +4,9 @@ import { alternateName } from 'librechat-data-provider';
import { useGetEndpointsQuery } from 'librechat-data-provider/react-query'; import { useGetEndpointsQuery } from 'librechat-data-provider/react-query';
import { DropdownMenuRadioItem } from '~/components'; import { DropdownMenuRadioItem } from '~/components';
import { SetKeyDialog } from '../SetKeyDialog'; import { SetKeyDialog } from '../SetKeyDialog';
import { cn, getEndpointField } from '~/utils';
import { Icon } from '~/components/Endpoints'; import { Icon } from '~/components/Endpoints';
import { useLocalize } from '~/hooks'; import { useLocalize } from '~/hooks';
import { cn } from '~/utils';
export default function ModelItem({ export default function ModelItem({
endpoint, endpoint,
@ -29,7 +29,11 @@ export default function ModelItem({
isCreatedByUser: false, isCreatedByUser: false,
}); });
const userProvidesKey = endpointsConfig?.[endpoint]?.userProvide; const userProvidesKey: boolean | null | undefined = getEndpointField(
endpointsConfig,
endpoint,
'userProvide',
);
const localize = useLocalize(); const localize = useLocalize();
// regular model // regular model

View file

@ -1,13 +1,14 @@
import React, { useEffect, useContext, useRef, useState, useCallback } from 'react';
import TextareaAutosize from 'react-textarea-autosize'; import TextareaAutosize from 'react-textarea-autosize';
import { useRecoilValue, useRecoilState, useSetRecoilState } from 'recoil'; import { useRecoilValue, useRecoilState, useSetRecoilState } from 'recoil';
import SubmitButton from './SubmitButton'; import React, { useEffect, useContext, useRef, useState, useCallback } from 'react';
import OptionsBar from './OptionsBar';
import { EndpointMenu } from './EndpointMenu'; import { EndpointMenu } from './EndpointMenu';
import SubmitButton from './SubmitButton';
import OptionsBar from './OptionsBar';
import Footer from './Footer'; import Footer from './Footer';
import { useMessageHandler, ThemeContext } from '~/hooks'; import { useMessageHandler, ThemeContext } from '~/hooks';
import { cn } from '~/utils'; import { cn, getEndpointField } from '~/utils';
import store from '~/store'; import store from '~/store';
interface TextChatProps { interface TextChatProps {
@ -195,7 +196,7 @@ export default function TextChat({ isSearchView = false }: TextChatProps) {
isSubmitting={isSubmitting} isSubmitting={isSubmitting}
userProvidesKey={ userProvidesKey={
conversation?.endpoint conversation?.endpoint
? endpointsConfig?.[conversation.endpoint]?.userProvide ? getEndpointField(endpointsConfig, conversation.endpoint, 'userProvide')
: undefined : undefined
} }
hasText={hasText} hasText={hasText}

View file

@ -7,7 +7,7 @@ export default function useGetSender() {
const { data: endpointsConfig = {} as TEndpointsConfig } = useGetEndpointsQuery(); const { data: endpointsConfig = {} as TEndpointsConfig } = useGetEndpointsQuery();
return useCallback( return useCallback(
(endpointOption: TEndpointOption) => { (endpointOption: TEndpointOption) => {
const { modelDisplayLabel } = endpointsConfig[endpointOption.endpoint ?? ''] ?? {}; const { modelDisplayLabel } = endpointsConfig?.[endpointOption.endpoint ?? ''] ?? {};
return getResponseSender({ ...endpointOption, modelDisplayLabel }); return getResponseSender({ ...endpointOption, modelDisplayLabel });
}, },
[endpointsConfig], [endpointsConfig],

View file

@ -13,11 +13,11 @@ import {
} from '~/data-provider'; } from '~/data-provider';
import { useChatContext, useToastContext } from '~/Providers'; import { useChatContext, useToastContext } from '~/Providers';
import useNavigateToConvo from '~/hooks/useNavigateToConvo'; import useNavigateToConvo from '~/hooks/useNavigateToConvo';
import { cleanupPreset, getEndpointField } from '~/utils';
import useDefaultConvo from '~/hooks/useDefaultConvo'; import useDefaultConvo from '~/hooks/useDefaultConvo';
import { useAuthContext } from '~/hooks/AuthContext'; import { useAuthContext } from '~/hooks/AuthContext';
import { NotificationSeverity } from '~/common'; import { NotificationSeverity } from '~/common';
import useLocalize from '~/hooks/useLocalize'; import useLocalize from '~/hooks/useLocalize';
import { cleanupPreset } from '~/utils';
import store from '~/store'; import store from '~/store';
export default function usePresets() { export default function usePresets() {
@ -162,12 +162,13 @@ export default function usePresets() {
const endpointsConfig = queryClient.getQueryData<TEndpointsConfig>([QueryKeys.endpoints]); const endpointsConfig = queryClient.getQueryData<TEndpointsConfig>([QueryKeys.endpoints]);
const currentEndpointType = endpointsConfig?.[endpoint ?? '']?.type ?? ''; const currentEndpointType = getEndpointField(endpointsConfig, endpoint, 'type');
const endpointType = endpointsConfig?.[newPreset?.endpoint ?? '']?.type; const endpointType = getEndpointField(endpointsConfig, newPreset.endpoint, 'type');
if ( if (
(modularEndpoints.has(endpoint ?? '') || modularEndpoints.has(currentEndpointType)) && (modularEndpoints.has(endpoint ?? '') || modularEndpoints.has(currentEndpointType ?? '')) &&
(modularEndpoints.has(newPreset?.endpoint ?? '') || modularEndpoints.has(endpointType)) && (modularEndpoints.has(newPreset?.endpoint ?? '') ||
modularEndpoints.has(endpointType ?? '')) &&
(endpoint === newPreset?.endpoint || modularChat) (endpoint === newPreset?.endpoint || modularChat)
) { ) {
const currentConvo = getDefaultConversation({ const currentConvo = getDefaultConversation({

View file

@ -1,12 +1,17 @@
import { useGetEndpointsQuery } from 'librechat-data-provider/react-query'; import { useGetEndpointsQuery } from 'librechat-data-provider/react-query';
import { useChatContext } from '~/Providers/ChatContext'; import { useChatContext } from '~/Providers/ChatContext';
import { getEndpointField } from '~/utils';
import useUserKey from './useUserKey'; import useUserKey from './useUserKey';
export default function useRequiresKey() { export default function useRequiresKey() {
const { conversation } = useChatContext(); const { conversation } = useChatContext();
const { data: endpointsConfig } = useGetEndpointsQuery(); const { data: endpointsConfig } = useGetEndpointsQuery();
const { endpoint } = conversation || {}; const { endpoint } = conversation || {};
const userProvidesKey = endpointsConfig?.[endpoint ?? '']?.userProvide; const userProvidesKey: boolean | null | undefined = getEndpointField(
endpointsConfig,
endpoint,
'userProvide',
);
const { getExpiry } = useUserKey(endpoint ?? ''); const { getExpiry } = useUserKey(endpoint ?? '');
const expiryTime = getExpiry(); const expiryTime = getExpiry();
const requiresKey = !expiryTime && userProvidesKey; const requiresKey = !expiryTime && userProvidesKey;

View file

@ -8,7 +8,7 @@ import {
const useUserKey = (endpoint: string) => { const useUserKey = (endpoint: string) => {
const { data: endpointsConfig } = useGetEndpointsQuery(); const { data: endpointsConfig } = useGetEndpointsQuery();
const config = endpointsConfig?.[endpoint]; const config = endpointsConfig?.[endpoint ?? ''];
const { azure } = config ?? {}; const { azure } = config ?? {};
let keyName = endpoint; let keyName = endpoint;

View file

@ -5,6 +5,7 @@ import type { TMessage } from 'librechat-data-provider';
import type { TMessageProps } from '~/common'; import type { TMessageProps } from '~/common';
import Icon from '~/components/Endpoints/Icon'; import Icon from '~/components/Endpoints/Icon';
import { useChatContext } from '~/Providers'; import { useChatContext } from '~/Providers';
import { getEndpointField } from '~/utils';
export default function useMessageHelpers(props: TMessageProps) { export default function useMessageHelpers(props: TMessageProps) {
const latestText = useRef(''); const latestText = useRef('');
@ -53,7 +54,7 @@ export default function useMessageHelpers(props: TMessageProps) {
const icon = Icon({ const icon = Icon({
...conversation, ...conversation,
...(message as TMessage), ...(message as TMessage),
iconURL: endpointsConfig?.[conversation?.endpoint ?? '']?.iconURL, iconURL: getEndpointField(endpointsConfig, conversation?.endpoint, 'iconURL'),
model: message?.model ?? conversation?.model, model: message?.model ?? conversation?.model,
size: 28.8, size: 28.8,
}); });

View file

@ -161,7 +161,7 @@ export default function useChatHelpers(index = 0, paramId: string | undefined) {
conversation: conversation ?? {}, conversation: conversation ?? {},
}); });
const { modelDisplayLabel } = endpointsConfig[endpoint ?? ''] ?? {}; const { modelDisplayLabel } = endpointsConfig?.[endpoint ?? ''] ?? {};
const endpointOption = { const endpointOption = {
...convo, ...convo,
endpoint, endpoint,

View file

@ -9,7 +9,7 @@ import type {
TModelsConfig, TModelsConfig,
TEndpointsConfig, TEndpointsConfig,
} from 'librechat-data-provider'; } from 'librechat-data-provider';
import { buildDefaultConvo, getDefaultEndpoint } from '~/utils'; import { buildDefaultConvo, getDefaultEndpoint, getEndpointField } from '~/utils';
import useOriginNavigate from './useOriginNavigate'; import useOriginNavigate from './useOriginNavigate';
import store from '~/store'; import store from '~/store';
@ -38,8 +38,9 @@ const useConversation = () => {
endpointsConfig, endpointsConfig,
}); });
if (!conversation.endpointType && endpointsConfig[defaultEndpoint]?.type) { const endpointType = getEndpointField(endpointsConfig, defaultEndpoint, 'type');
conversation.endpointType = endpointsConfig[defaultEndpoint]?.type; if (!conversation.endpointType && endpointType) {
conversation.endpointType = endpointType;
} }
const models = modelsConfig?.[defaultEndpoint] ?? []; const models = modelsConfig?.[defaultEndpoint] ?? [];

View file

@ -1,11 +1,15 @@
import { useQueryClient } from '@tanstack/react-query';
import { useSetRecoilState, useResetRecoilState } from 'recoil'; import { useSetRecoilState, useResetRecoilState } from 'recoil';
import type { TConversation } from 'librechat-data-provider'; import { QueryKeys } from 'librechat-data-provider';
import type { TConversation, TEndpointsConfig, TModelsConfig } from 'librechat-data-provider';
import { buildDefaultConvo, getDefaultEndpoint, getEndpointField } from '~/utils';
import useOriginNavigate from './useOriginNavigate'; import useOriginNavigate from './useOriginNavigate';
import useSetStorage from './useSetStorage'; import useSetStorage from './useSetStorage';
import store from '~/store'; import store from '~/store';
const useNavigateToConvo = (index = 0) => { const useNavigateToConvo = (index = 0) => {
const setStorage = useSetStorage(); const setStorage = useSetStorage();
const queryClient = useQueryClient();
const navigate = useOriginNavigate(); const navigate = useOriginNavigate();
const { setConversation } = store.useCreateConversationAtom(index); const { setConversation } = store.useCreateConversationAtom(index);
const setSubmission = useSetRecoilState(store.submissionByIndex(index)); const setSubmission = useSetRecoilState(store.submissionByIndex(index));
@ -21,9 +25,34 @@ const useNavigateToConvo = (index = 0) => {
if (_resetLatestMessage) { if (_resetLatestMessage) {
resetLatestMessage(); resetLatestMessage();
} }
setStorage(conversation);
setConversation(conversation); let convo = { ...conversation };
navigate(conversation?.conversationId); if (!convo?.endpoint) {
/* undefined endpoint edge case */
const modelsConfig = queryClient.getQueryData<TModelsConfig>([QueryKeys.models]);
const endpointsConfig = queryClient.getQueryData<TEndpointsConfig>([QueryKeys.endpoints]);
const defaultEndpoint = getDefaultEndpoint({
convoSetup: conversation,
endpointsConfig,
});
const endpointType = getEndpointField(endpointsConfig, defaultEndpoint, 'type');
if (!conversation.endpointType && endpointType) {
conversation.endpointType = endpointType;
}
const models = modelsConfig?.[defaultEndpoint ?? ''] ?? [];
convo = buildDefaultConvo({
conversation,
endpoint: defaultEndpoint,
lastConversationSetup: conversation,
models,
});
}
setStorage(convo);
setConversation(convo);
navigate(convo?.conversationId);
}; };
return { return {

View file

@ -14,7 +14,7 @@ import type {
TModelsConfig, TModelsConfig,
TEndpointsConfig, TEndpointsConfig,
} from 'librechat-data-provider'; } from 'librechat-data-provider';
import { buildDefaultConvo, getDefaultEndpoint } from '~/utils'; import { buildDefaultConvo, getDefaultEndpoint, getEndpointField } from '~/utils';
import { useDeleteFilesMutation } from '~/data-provider'; import { useDeleteFilesMutation } from '~/data-provider';
import useOriginNavigate from './useOriginNavigate'; import useOriginNavigate from './useOriginNavigate';
import useSetStorage from './useSetStorage'; import useSetStorage from './useSetStorage';
@ -69,8 +69,9 @@ const useNewConvo = (index = 0) => {
endpointsConfig, endpointsConfig,
}); });
if (!conversation.endpointType && endpointsConfig[defaultEndpoint]?.type) { const endpointType = getEndpointField(endpointsConfig, defaultEndpoint, 'type');
conversation.endpointType = endpointsConfig[defaultEndpoint]?.type; if (!conversation.endpointType && endpointType) {
conversation.endpointType = endpointType;
} }
const models = modelsConfig?.[defaultEndpoint] ?? []; const models = modelsConfig?.[defaultEndpoint] ?? [];

View file

@ -11,6 +11,7 @@ const buildDefaultConvo = ({
conversation: TConversation; conversation: TConversation;
endpoint: EModelEndpoint; endpoint: EModelEndpoint;
models: string[]; models: string[];
// TODO: fix this type as we should allow undefined
lastConversationSetup: TConversation; lastConversationSetup: TConversation;
}) => { }) => {
const { lastSelectedModel, lastSelectedTools, lastBingSettings } = getLocalStorageItems(); const { lastSelectedModel, lastSelectedTools, lastBingSettings } = getLocalStorageItems();

View file

@ -0,0 +1,94 @@
import { EModelEndpoint } from 'librechat-data-provider';
import type { TEndpointsConfig, TConfig } from 'librechat-data-provider';
import {
getEndpointField,
getAvailableEndpoints,
getEndpointsFilter,
mapEndpoints,
} from './endpoints';
const mockEndpointsConfig: TEndpointsConfig = {
[EModelEndpoint.openAI]: { type: undefined, iconURL: 'openAI_icon.png', order: 0 },
[EModelEndpoint.google]: { type: undefined, iconURL: 'google_icon.png', order: 1 },
Mistral: { type: EModelEndpoint.custom, iconURL: 'custom_icon.png', order: 2 },
};
describe('getEndpointField', () => {
it('returns undefined if endpointsConfig is undefined', () => {
expect(getEndpointField(undefined, EModelEndpoint.openAI, 'type')).toBeUndefined();
});
it('returns undefined if endpoint is null', () => {
expect(getEndpointField(mockEndpointsConfig, null, 'type')).toBeUndefined();
});
it('returns undefined if endpoint is undefined', () => {
expect(getEndpointField(mockEndpointsConfig, undefined, 'type')).toBeUndefined();
});
it('returns undefined if the endpoint does not exist in endpointsConfig', () => {
expect(getEndpointField(mockEndpointsConfig, EModelEndpoint.bingAI, 'type')).toBeUndefined();
});
it('returns the correct value for a valid endpoint and property', () => {
expect(getEndpointField(mockEndpointsConfig, EModelEndpoint.openAI, 'order')).toEqual(0);
expect(getEndpointField(mockEndpointsConfig, EModelEndpoint.google, 'iconURL')).toEqual(
'google_icon.png',
);
});
it('returns undefined for a valid endpoint but an invalid property', () => {
/* Type assertion as 'nonexistentProperty' is intentionally not a valid property of TConfig */
expect(
getEndpointField(
mockEndpointsConfig,
EModelEndpoint.openAI,
'nonexistentProperty' as keyof TConfig,
),
).toBeUndefined();
});
it('returns the correct value for a non-enum endpoint and valid property', () => {
expect(getEndpointField(mockEndpointsConfig, 'Mistral', 'type')).toEqual(EModelEndpoint.custom);
});
it('returns undefined for a non-enum endpoint with an invalid property', () => {
expect(
getEndpointField(mockEndpointsConfig, 'Mistral', 'nonexistentProperty' as keyof TConfig),
).toBeUndefined();
});
});
describe('getEndpointsFilter', () => {
it('returns an empty object if endpointsConfig is undefined', () => {
expect(getEndpointsFilter(undefined)).toEqual({});
});
it('returns a filter object based on endpointsConfig', () => {
const expectedFilter = {
[EModelEndpoint.openAI]: true,
[EModelEndpoint.google]: true,
Mistral: true,
};
expect(getEndpointsFilter(mockEndpointsConfig)).toEqual(expectedFilter);
});
});
describe('getAvailableEndpoints', () => {
it('returns available endpoints based on filter and config', () => {
const filter = {
[EModelEndpoint.openAI]: true,
[EModelEndpoint.google]: false,
Mistral: true,
};
const expectedEndpoints = [EModelEndpoint.openAI, 'Mistral'];
expect(getAvailableEndpoints(filter, mockEndpointsConfig)).toEqual(expectedEndpoints);
});
});
describe('mapEndpoints', () => {
it('returns sorted available endpoints', () => {
const expectedOrder = [EModelEndpoint.openAI, EModelEndpoint.google, 'Mistral'];
expect(mapEndpoints(mockEndpointsConfig)).toEqual(expectedOrder);
});
});

View file

@ -0,0 +1,58 @@
import { defaultEndpoints } from 'librechat-data-provider';
import type { EModelEndpoint, TEndpointsConfig, TConfig } from 'librechat-data-provider';
export const getEndpointsFilter = (endpointsConfig: TEndpointsConfig) => {
const filter: Record<string, boolean> = {};
if (!endpointsConfig) {
return filter;
}
for (const key of Object.keys(endpointsConfig)) {
filter[key] = !!endpointsConfig[key];
}
return filter;
};
export const getAvailableEndpoints = (
filter: Record<string, boolean>,
endpointsConfig: TEndpointsConfig,
) => {
const defaultSet = new Set(defaultEndpoints);
const availableEndpoints: EModelEndpoint[] = [];
for (const endpoint in endpointsConfig) {
// Check if endpoint is in the filter or its type is in defaultEndpoints
if (
filter[endpoint] ||
(endpointsConfig[endpoint]?.type &&
defaultSet.has(endpointsConfig[endpoint]?.type as EModelEndpoint))
) {
availableEndpoints.push(endpoint as EModelEndpoint);
}
}
return availableEndpoints;
};
export function getEndpointField<K extends keyof TConfig>(
endpointsConfig: TEndpointsConfig | undefined,
endpoint: EModelEndpoint | string | null | undefined,
property: K,
): TConfig[K] | undefined {
if (!endpointsConfig || endpoint === null || endpoint === undefined) {
return undefined;
}
const config = endpointsConfig[endpoint];
if (!config) {
return undefined;
}
return config[property];
}
export function mapEndpoints(endpointsConfig: TEndpointsConfig) {
const filter = getEndpointsFilter(endpointsConfig);
return getAvailableEndpoints(filter, endpointsConfig).sort(
(a, b) => (endpointsConfig?.[a]?.order ?? 0) - (endpointsConfig?.[b]?.order ?? 0),
);
}

View file

@ -5,7 +5,7 @@ import type {
EModelEndpoint, EModelEndpoint,
} from 'librechat-data-provider'; } from 'librechat-data-provider';
import getLocalStorageItems from './getLocalStorageItems'; import getLocalStorageItems from './getLocalStorageItems';
import mapEndpoints from './mapEndpoints'; import { mapEndpoints } from './endpoints';
type TConvoSetup = Partial<TPreset> | Partial<TConversation>; type TConvoSetup = Partial<TPreset> | Partial<TConversation>;
@ -13,7 +13,7 @@ type TDefaultEndpoint = { convoSetup: TConvoSetup; endpointsConfig: TEndpointsCo
const getEndpointFromSetup = (convoSetup: TConvoSetup, endpointsConfig: TEndpointsConfig) => { const getEndpointFromSetup = (convoSetup: TConvoSetup, endpointsConfig: TEndpointsConfig) => {
const { endpoint: targetEndpoint } = convoSetup || {}; const { endpoint: targetEndpoint } = convoSetup || {};
if (targetEndpoint && endpointsConfig?.[targetEndpoint]) { if (targetEndpoint && endpointsConfig?.[targetEndpoint ?? '']) {
return targetEndpoint; return targetEndpoint;
} else if (targetEndpoint) { } else if (targetEndpoint) {
console.warn(`Illegal target endpoint ${targetEndpoint} ${endpointsConfig}`); console.warn(`Illegal target endpoint ${targetEndpoint} ${endpointsConfig}`);
@ -35,7 +35,7 @@ const getEndpointFromLocalStorage = (endpointsConfig: TEndpointsConfig) => {
return endpoint; return endpoint;
} }
return endpoint && endpointsConfig[endpoint] ? endpoint : null; return endpoint && endpointsConfig?.[endpoint ?? ''] ? endpoint : null;
} catch (error) { } catch (error) {
console.error(error); console.error(error);
return null; return null;

View file

@ -1,9 +1,9 @@
export * from './json'; export * from './json';
export * from './presets'; export * from './presets';
export * from './languages'; export * from './languages';
export * from './endpoints';
export { default as cn } from './cn'; export { default as cn } from './cn';
export { default as buildTree } from './buildTree'; export { default as buildTree } from './buildTree';
export { default as mapEndpoints } from './mapEndpoints';
export { default as getLoginError } from './getLoginError'; export { default as getLoginError } from './getLoginError';
export { default as cleanupPreset } from './cleanupPreset'; export { default as cleanupPreset } from './cleanupPreset';
export { default as validateIframe } from './validateIframe'; export { default as validateIframe } from './validateIframe';

View file

@ -1,37 +0,0 @@
import { defaultEndpoints } from 'librechat-data-provider';
import type { EModelEndpoint, TEndpointsConfig } from 'librechat-data-provider';
const getEndpointsFilter = (endpointsConfig: TEndpointsConfig) => {
const filter: Record<string, boolean> = {};
for (const key of Object.keys(endpointsConfig)) {
filter[key] = !!endpointsConfig[key];
}
return filter;
};
const getAvailableEndpoints = (
filter: Record<string, boolean>,
endpointsConfig: TEndpointsConfig,
) => {
const defaultSet = new Set(defaultEndpoints);
const availableEndpoints: EModelEndpoint[] = [];
for (const endpoint in endpointsConfig) {
// Check if endpoint is in the filter or its type is in defaultEndpoints
if (
filter[endpoint] ||
(endpointsConfig[endpoint]?.type && defaultSet.has(endpointsConfig[endpoint].type))
) {
availableEndpoints.push(endpoint as EModelEndpoint);
}
}
return availableEndpoints;
};
export default function mapEndpoints(endpointsConfig: TEndpointsConfig) {
const filter = getEndpointsFilter(endpointsConfig);
return getAvailableEndpoints(filter, endpointsConfig).sort(
(a, b) => (endpointsConfig[a]?.order ?? 0) - (endpointsConfig[b]?.order ?? 0),
);
}

View file

@ -139,9 +139,11 @@ export type TConfig = {
userProvideURL?: boolean | null; userProvideURL?: boolean | null;
}; };
export type TModelsConfig = Record<string, string[]>; export type TEndpointsConfig =
| Record<EModelEndpoint | string, TConfig | null | undefined>
| undefined;
export type TEndpointsConfig = Record<EModelEndpoint, TConfig | null>; export type TModelsConfig = Record<string, string[]>;
export type TUpdateTokenCountResponse = { export type TUpdateTokenCountResponse = {
count: number; count: number;