🔀 fix: Endpoint Type Mismatch when Switching Conversations (#1834)

* refactor(useUpdateUserKeysMutation): only invalidate the endpoint whose key is being updated by user

* fix(assistants): await `getUserKeyExpiry` call

* chore: fix spinner loading color

* refactor(initializeClient): make known which endpoint api Key is missing

* fix: prevent an `endpointType` mismatch by making it impossible to assign when the `endpointsConfig` doesn't have a `type` defined, also prefer `getQueryData` call to useQuery in useChatHelpers
This commit is contained in:
Danny Avila 2024-02-19 01:31:38 -05:00 committed by GitHub
parent d1eb7fcfc7
commit 5291d18f38
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 28 additions and 18 deletions

View file

@ -32,7 +32,10 @@ const initializeClient = async ({ req, res, endpointOption, initAppClient = fals
let userKey = null; let userKey = null;
if (isUserProvided) { if (isUserProvided) {
const expiresAt = getUserKeyExpiry({ userId: req.user.id, name: EModelEndpoint.assistants }); const expiresAt = await getUserKeyExpiry({
userId: req.user.id,
name: EModelEndpoint.assistants,
});
checkUserKeyExpiry( checkUserKeyExpiry(
expiresAt, expiresAt,
'Your Assistants API key has expired. Please provide your API key again.', 'Your Assistants API key has expired. Please provide your API key again.',
@ -43,7 +46,7 @@ const initializeClient = async ({ req, res, endpointOption, initAppClient = fals
let apiKey = isUserProvided ? userKey : credentials; let apiKey = isUserProvided ? userKey : credentials;
if (!apiKey) { if (!apiKey) {
throw new Error('API key not provided.'); throw new Error(`${EModelEndpoint.assistants} API key not provided.`);
} }
/** @type {OpenAIClient} */ /** @type {OpenAIClient} */

View file

@ -66,7 +66,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
} }
if (!apiKey) { if (!apiKey) {
throw new Error('API key not provided.'); throw new Error(`${endpoint} API key not provided.`);
} }
const client = new PluginsClient(apiKey, clientOptions); const client = new PluginsClient(apiKey, clientOptions);

View file

@ -1,7 +1,8 @@
// gptPlugins/initializeClient.spec.js // gptPlugins/initializeClient.spec.js
const { PluginsClient } = require('~/app'); const { EModelEndpoint } = require('librechat-data-provider');
const { getUserKey } = require('~/server/services/UserService');
const initializeClient = require('./initializeClient'); const initializeClient = require('./initializeClient');
const { getUserKey } = require('../../UserService'); const { PluginsClient } = require('~/app');
// Mock getUserKey since it's the only function we want to mock // Mock getUserKey since it's the only function we want to mock
jest.mock('~/server/services/UserService', () => ({ jest.mock('~/server/services/UserService', () => ({
@ -112,7 +113,7 @@ describe('gptPlugins/initializeClient', () => {
const endpointOption = { modelOptions: { model: 'default-model' } }; const endpointOption = { modelOptions: { model: 'default-model' } };
await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow( await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(
'API key not provided.', `${EModelEndpoint.openAI} API key not provided.`,
); );
}); });

View file

@ -58,7 +58,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
} }
if (!apiKey) { if (!apiKey) {
throw new Error('API key not provided.'); throw new Error(`${endpoint} API key not provided.`);
} }
const client = new OpenAIClient(apiKey, clientOptions); const client = new OpenAIClient(apiKey, clientOptions);

View file

@ -1,6 +1,7 @@
const { OpenAIClient } = require('~/app'); const { EModelEndpoint } = require('librechat-data-provider');
const initializeClient = require('./initializeClient');
const { getUserKey } = require('~/server/services/UserService'); const { getUserKey } = require('~/server/services/UserService');
const initializeClient = require('./initializeClient');
const { OpenAIClient } = require('~/app');
// Mock getUserKey since it's the only function we want to mock // Mock getUserKey since it's the only function we want to mock
jest.mock('~/server/services/UserService', () => ({ jest.mock('~/server/services/UserService', () => ({
@ -145,7 +146,7 @@ describe('initializeClient', () => {
const endpointOption = {}; const endpointOption = {};
await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow( await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(
'API key not provided.', `${EModelEndpoint.openAI} API key not provided.`,
); );
}); });

View file

@ -3,7 +3,7 @@ import { cn } from '~/utils/';
export default function Spinner({ className = 'm-auto', size = '1em' }) { export default function Spinner({ className = 'm-auto', size = '1em' }) {
return ( return (
<svg <svg
stroke="#ffffff" stroke="currentColor"
fill="none" fill="none"
strokeWidth="2" strokeWidth="2"
viewBox="0 0 24 24" viewBox="0 0 24 24"

View file

@ -9,7 +9,7 @@ import {
ContentTypes, ContentTypes,
} from 'librechat-data-provider'; } from 'librechat-data-provider';
import { useRecoilState, useResetRecoilState, useSetRecoilState } from 'recoil'; import { useRecoilState, useResetRecoilState, useSetRecoilState } from 'recoil';
import { useGetMessagesByConvoId, useGetEndpointsQuery } from 'librechat-data-provider/react-query'; import { useGetMessagesByConvoId } from 'librechat-data-provider/react-query';
import type { import type {
TMessage, TMessage,
TSubmission, TSubmission,
@ -21,12 +21,12 @@ import useSetFilesToDelete from './Files/useSetFilesToDelete';
import useGetSender from './Conversations/useGetSender'; import useGetSender from './Conversations/useGetSender';
import { useAuthContext } from './AuthContext'; import { useAuthContext } from './AuthContext';
import useUserKey from './Input/useUserKey'; import useUserKey from './Input/useUserKey';
import { getEndpointField } from '~/utils';
import useNewConvo from './useNewConvo'; import useNewConvo from './useNewConvo';
import store from '~/store'; import store from '~/store';
// this to be set somewhere else // this to be set somewhere else
export default function useChatHelpers(index = 0, paramId: string | undefined) { export default function useChatHelpers(index = 0, paramId: string | undefined) {
const { data: endpointsConfig = {} as TEndpointsConfig } = useGetEndpointsQuery();
const setShowStopButton = useSetRecoilState(store.showStopButtonByIndex(index)); const setShowStopButton = useSetRecoilState(store.showStopButtonByIndex(index));
const [files, setFiles] = useRecoilState(store.filesByIndex(index)); const [files, setFiles] = useRecoilState(store.filesByIndex(index));
const [filesLoading, setFilesLoading] = useState(false); const [filesLoading, setFilesLoading] = useState(false);
@ -39,7 +39,7 @@ export default function useChatHelpers(index = 0, paramId: string | undefined) {
const { newConversation } = useNewConvo(index); const { newConversation } = useNewConvo(index);
const { useCreateConversationAtom } = store; const { useCreateConversationAtom } = store;
const { conversation, setConversation } = useCreateConversationAtom(index); const { conversation, setConversation } = useCreateConversationAtom(index);
const { conversationId, endpoint, endpointType } = conversation ?? {}; const { conversationId, endpoint } = conversation ?? {};
const queryParam = paramId === 'new' ? paramId : conversationId ?? paramId ?? ''; const queryParam = paramId === 'new' ? paramId : conversationId ?? paramId ?? '';
@ -142,6 +142,9 @@ export default function useChatHelpers(index = 0, paramId: string | undefined) {
const thread_id = parentMessage?.thread_id ?? latestMessage?.thread_id; const thread_id = parentMessage?.thread_id ?? latestMessage?.thread_id;
const endpointsConfig = queryClient.getQueryData<TEndpointsConfig>([QueryKeys.endpoints]);
const endpointType = getEndpointField(endpointsConfig, endpoint, 'type');
// set the endpoint option // set the endpoint option
const convo = parseCompactConvo({ const convo = parseCompactConvo({
endpoint, endpoint,

View file

@ -77,6 +77,8 @@ const useNewConvo = (index = 0) => {
const endpointType = getEndpointField(endpointsConfig, defaultEndpoint, 'type'); const endpointType = getEndpointField(endpointsConfig, defaultEndpoint, 'type');
if (!conversation.endpointType && endpointType) { if (!conversation.endpointType && endpointType) {
conversation.endpointType = endpointType; conversation.endpointType = endpointType;
} else if (conversation.endpointType && !endpointType) {
conversation.endpointType = undefined;
} }
if (!conversation.assistant_id && defaultEndpoint === EModelEndpoint.assistants) { if (!conversation.assistant_id && defaultEndpoint === EModelEndpoint.assistants) {

View file

@ -67,7 +67,7 @@ export default function ChatRoute() {
}, [initialConvoQuery.data, modelsQuery.data, endpointsQuery.data]); }, [initialConvoQuery.data, modelsQuery.data, endpointsQuery.data]);
if (endpointsQuery.isLoading || modelsQuery.isLoading) { if (endpointsQuery.isLoading || modelsQuery.isLoading) {
return <Spinner className="m-auto dark:text-white" />; return <Spinner className="m-auto text-black dark:text-white" />;
} }
if (!isAuthenticated) { if (!isAuthenticated) {

View file

@ -117,8 +117,8 @@ export const useUpdateUserKeysMutation = (): UseMutationResult<
> => { > => {
const queryClient = useQueryClient(); const queryClient = useQueryClient();
return useMutation((payload: t.TUpdateUserKeyRequest) => dataService.updateUserKey(payload), { return useMutation((payload: t.TUpdateUserKeyRequest) => dataService.updateUserKey(payload), {
onSuccess: () => { onSuccess: (data, variables) => {
queryClient.invalidateQueries([QueryKeys.name]); queryClient.invalidateQueries([QueryKeys.name, variables.name]);
}, },
}); });
}; };
@ -136,7 +136,7 @@ export const useRevokeUserKeyMutation = (name: string): UseMutationResult<unknow
const queryClient = useQueryClient(); const queryClient = useQueryClient();
return useMutation(() => dataService.revokeUserKey(name), { return useMutation(() => dataService.revokeUserKey(name), {
onSuccess: () => { onSuccess: () => {
queryClient.invalidateQueries([QueryKeys.name]); queryClient.invalidateQueries([QueryKeys.name, name]);
if (name === s.EModelEndpoint.assistants) { if (name === s.EModelEndpoint.assistants) {
queryClient.invalidateQueries([QueryKeys.assistants, defaultOrderQuery]); queryClient.invalidateQueries([QueryKeys.assistants, defaultOrderQuery]);
queryClient.invalidateQueries([QueryKeys.assistantDocs]); queryClient.invalidateQueries([QueryKeys.assistantDocs]);