mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-23 03:40:14 +01:00
225 lines
6.7 KiB
TypeScript
225 lines
6.7 KiB
TypeScript
|
|
import React, { startTransition, createContext, useContext, useState, useMemo } from 'react';
|
||
|
|
import { EModelEndpoint, isAgentsEndpoint, isAssistantsEndpoint } from 'librechat-data-provider';
|
||
|
|
import type * as t from 'librechat-data-provider';
|
||
|
|
import type { Endpoint, SelectedValues } from '~/common';
|
||
|
|
import { useAgentsMapContext, useAssistantsMapContext, useChatContext } from '~/Providers';
|
||
|
|
import { useEndpoints, useSelectorEffects, useKeyDialog, useLocalize } from '~/hooks';
|
||
|
|
import useSelectMention from '~/hooks/Input/useSelectMention';
|
||
|
|
import { useGetEndpointsQuery } from '~/data-provider';
|
||
|
|
import { filterItems } from './utils';
|
||
|
|
|
||
|
|
type ModelSelectorContextType = {
|
||
|
|
// State
|
||
|
|
searchValue: string;
|
||
|
|
selectedValues: SelectedValues;
|
||
|
|
endpointSearchValues: Record<string, string>;
|
||
|
|
searchResults: (t.TModelSpec | Endpoint)[] | null;
|
||
|
|
// LibreChat
|
||
|
|
modelSpecs: t.TModelSpec[];
|
||
|
|
mappedEndpoints: Endpoint[];
|
||
|
|
agentsMap: t.TAgentsMap | undefined;
|
||
|
|
assistantsMap: t.TAssistantsMap | undefined;
|
||
|
|
endpointsConfig: t.TEndpointsConfig;
|
||
|
|
|
||
|
|
// Functions
|
||
|
|
getDisplayValue: () => string;
|
||
|
|
endpointRequiresUserKey: (endpoint: string) => boolean;
|
||
|
|
setSelectedValues: React.Dispatch<React.SetStateAction<SelectedValues>>;
|
||
|
|
setSearchValue: (value: string) => void;
|
||
|
|
setEndpointSearchValue: (endpoint: string, value: string) => void;
|
||
|
|
handleSelectSpec: (spec: t.TModelSpec) => void;
|
||
|
|
handleSelectEndpoint: (endpoint: Endpoint) => void;
|
||
|
|
handleSelectModel: (endpoint: Endpoint, model: string) => void;
|
||
|
|
} & ReturnType<typeof useKeyDialog>;
|
||
|
|
|
||
|
|
const ModelSelectorContext = createContext<ModelSelectorContextType | undefined>(undefined);
|
||
|
|
|
||
|
|
export function useModelSelectorContext() {
|
||
|
|
const context = useContext(ModelSelectorContext);
|
||
|
|
if (context === undefined) {
|
||
|
|
throw new Error('useModelSelectorContext must be used within a ModelSelectorProvider');
|
||
|
|
}
|
||
|
|
return context;
|
||
|
|
}
|
||
|
|
|
||
|
|
interface ModelSelectorProviderProps {
|
||
|
|
children: React.ReactNode;
|
||
|
|
modelSpecs: t.TModelSpec[];
|
||
|
|
interfaceConfig: t.TInterfaceConfig;
|
||
|
|
}
|
||
|
|
|
||
|
|
export function ModelSelectorProvider({
|
||
|
|
children,
|
||
|
|
modelSpecs,
|
||
|
|
interfaceConfig,
|
||
|
|
}: ModelSelectorProviderProps) {
|
||
|
|
const localize = useLocalize();
|
||
|
|
const agentsMap = useAgentsMapContext();
|
||
|
|
const assistantsMap = useAssistantsMapContext();
|
||
|
|
const { data: endpointsConfig } = useGetEndpointsQuery();
|
||
|
|
const { conversation, newConversation } = useChatContext();
|
||
|
|
const { mappedEndpoints, endpointRequiresUserKey } = useEndpoints({
|
||
|
|
agentsMap,
|
||
|
|
assistantsMap,
|
||
|
|
endpointsConfig,
|
||
|
|
interfaceConfig,
|
||
|
|
});
|
||
|
|
const { onSelectEndpoint, onSelectSpec } = useSelectMention({
|
||
|
|
// presets,
|
||
|
|
modelSpecs,
|
||
|
|
assistantsMap,
|
||
|
|
endpointsConfig,
|
||
|
|
newConversation,
|
||
|
|
returnHandlers: true,
|
||
|
|
});
|
||
|
|
|
||
|
|
// State
|
||
|
|
const [selectedValues, setSelectedValues] = useState<SelectedValues>({
|
||
|
|
endpoint: conversation?.endpoint || '',
|
||
|
|
model: conversation?.model || '',
|
||
|
|
modelSpec: conversation?.spec || '',
|
||
|
|
});
|
||
|
|
useSelectorEffects({
|
||
|
|
agentsMap,
|
||
|
|
conversation,
|
||
|
|
assistantsMap,
|
||
|
|
setSelectedValues,
|
||
|
|
});
|
||
|
|
|
||
|
|
const [searchValue, setSearchValueState] = useState('');
|
||
|
|
const [endpointSearchValues, setEndpointSearchValues] = useState<Record<string, string>>({});
|
||
|
|
|
||
|
|
const keyProps = useKeyDialog();
|
||
|
|
|
||
|
|
// Memoized search results
|
||
|
|
const searchResults = useMemo(() => {
|
||
|
|
if (!searchValue) {
|
||
|
|
return null;
|
||
|
|
}
|
||
|
|
const allItems = [...modelSpecs, ...mappedEndpoints];
|
||
|
|
return filterItems(allItems, searchValue, agentsMap, assistantsMap || {});
|
||
|
|
}, [searchValue, modelSpecs, mappedEndpoints, agentsMap, assistantsMap]);
|
||
|
|
|
||
|
|
// Functions
|
||
|
|
const setSearchValue = (value: string) => {
|
||
|
|
startTransition(() => setSearchValueState(value));
|
||
|
|
};
|
||
|
|
|
||
|
|
const setEndpointSearchValue = (endpoint: string, value: string) => {
|
||
|
|
setEndpointSearchValues((prev) => ({
|
||
|
|
...prev,
|
||
|
|
[endpoint]: value,
|
||
|
|
}));
|
||
|
|
};
|
||
|
|
|
||
|
|
const handleSelectSpec = (spec: t.TModelSpec) => {
|
||
|
|
onSelectSpec?.(spec);
|
||
|
|
setSelectedValues({
|
||
|
|
endpoint: spec.preset.endpoint,
|
||
|
|
model: spec.preset.model ?? null,
|
||
|
|
modelSpec: spec.name,
|
||
|
|
});
|
||
|
|
};
|
||
|
|
|
||
|
|
const handleSelectEndpoint = (endpoint: Endpoint) => {
|
||
|
|
if (!endpoint.hasModels) {
|
||
|
|
if (endpoint.value) {
|
||
|
|
onSelectEndpoint?.(endpoint.value);
|
||
|
|
}
|
||
|
|
setSelectedValues({
|
||
|
|
endpoint: endpoint.value,
|
||
|
|
model: '',
|
||
|
|
modelSpec: '',
|
||
|
|
});
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
const handleSelectModel = (endpoint: Endpoint, model: string) => {
|
||
|
|
if (isAgentsEndpoint(endpoint.value)) {
|
||
|
|
onSelectEndpoint?.(endpoint.value, {
|
||
|
|
agent_id: model,
|
||
|
|
});
|
||
|
|
} else if (isAssistantsEndpoint(endpoint.value)) {
|
||
|
|
onSelectEndpoint?.(endpoint.value, {
|
||
|
|
assistant_id: model,
|
||
|
|
model: assistantsMap?.[endpoint.value]?.[model]?.model ?? '',
|
||
|
|
});
|
||
|
|
} else if (endpoint.value) {
|
||
|
|
onSelectEndpoint?.(endpoint.value, { model });
|
||
|
|
}
|
||
|
|
setSelectedValues({
|
||
|
|
endpoint: endpoint.value,
|
||
|
|
model: model,
|
||
|
|
modelSpec: '',
|
||
|
|
});
|
||
|
|
};
|
||
|
|
|
||
|
|
const getDisplayValue = () => {
|
||
|
|
if (selectedValues.modelSpec) {
|
||
|
|
const spec = modelSpecs.find((s) => s.name === selectedValues.modelSpec);
|
||
|
|
return spec?.label || localize('com_endpoint_select_model');
|
||
|
|
}
|
||
|
|
|
||
|
|
if (selectedValues.model && selectedValues.endpoint) {
|
||
|
|
const endpoint = mappedEndpoints.find((e) => e.value === selectedValues.endpoint);
|
||
|
|
if (!endpoint) {
|
||
|
|
return localize('com_endpoint_select_model');
|
||
|
|
}
|
||
|
|
|
||
|
|
if (
|
||
|
|
endpoint.value === EModelEndpoint.agents &&
|
||
|
|
endpoint.agentNames &&
|
||
|
|
endpoint.agentNames[selectedValues.model]
|
||
|
|
) {
|
||
|
|
return endpoint.agentNames[selectedValues.model];
|
||
|
|
}
|
||
|
|
|
||
|
|
if (
|
||
|
|
(endpoint.value === EModelEndpoint.assistants ||
|
||
|
|
endpoint.value === EModelEndpoint.azureAssistants) &&
|
||
|
|
endpoint.assistantNames &&
|
||
|
|
endpoint.assistantNames[selectedValues.model]
|
||
|
|
) {
|
||
|
|
return endpoint.assistantNames[selectedValues.model];
|
||
|
|
}
|
||
|
|
|
||
|
|
return selectedValues.model;
|
||
|
|
}
|
||
|
|
|
||
|
|
if (selectedValues.endpoint) {
|
||
|
|
const endpoint = mappedEndpoints.find((e) => e.value === selectedValues.endpoint);
|
||
|
|
return endpoint?.label || localize('com_endpoint_select_model');
|
||
|
|
}
|
||
|
|
|
||
|
|
return localize('com_endpoint_select_model');
|
||
|
|
};
|
||
|
|
|
||
|
|
const value = {
|
||
|
|
// State
|
||
|
|
searchValue,
|
||
|
|
searchResults,
|
||
|
|
selectedValues,
|
||
|
|
endpointSearchValues,
|
||
|
|
// LibreChat
|
||
|
|
agentsMap,
|
||
|
|
modelSpecs,
|
||
|
|
assistantsMap,
|
||
|
|
mappedEndpoints,
|
||
|
|
endpointsConfig,
|
||
|
|
|
||
|
|
// Functions
|
||
|
|
setSearchValue,
|
||
|
|
getDisplayValue,
|
||
|
|
handleSelectSpec,
|
||
|
|
handleSelectModel,
|
||
|
|
setSelectedValues,
|
||
|
|
handleSelectEndpoint,
|
||
|
|
setEndpointSearchValue,
|
||
|
|
endpointRequiresUserKey,
|
||
|
|
// Dialog
|
||
|
|
...keyProps,
|
||
|
|
};
|
||
|
|
|
||
|
|
return <ModelSelectorContext.Provider value={value}>{children}</ModelSelectorContext.Provider>;
|
||
|
|
}
|