🔧 fix: Improve Endpoint Handling and Address Edge Cases (#1486)

* fix(TEndpointsConfig): resolve property access issues with typesafe helper function

* fix: undefined or null endpoint edge case

* refactor(mapEndpoints -> endpoints): renamed module to be more general for endpoint handling, wrote unit tests, export all helpers
This commit is contained in:
Danny Avila 2024-01-04 10:17:15 -05:00 committed by GitHub
parent 42f2353509
commit 9864fc8700
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 275 additions and 99 deletions

View file

@ -11,6 +11,7 @@ const buildDefaultConvo = ({
conversation: TConversation;
endpoint: EModelEndpoint;
models: string[];
// TODO: fix this type as we should allow undefined
lastConversationSetup: TConversation;
}) => {
const { lastSelectedModel, lastSelectedTools, lastBingSettings } = getLocalStorageItems();

View file

@ -0,0 +1,94 @@
import { EModelEndpoint } from 'librechat-data-provider';
import type { TEndpointsConfig, TConfig } from 'librechat-data-provider';
import {
getEndpointField,
getAvailableEndpoints,
getEndpointsFilter,
mapEndpoints,
} from './endpoints';
const mockEndpointsConfig: TEndpointsConfig = {
[EModelEndpoint.openAI]: { type: undefined, iconURL: 'openAI_icon.png', order: 0 },
[EModelEndpoint.google]: { type: undefined, iconURL: 'google_icon.png', order: 1 },
Mistral: { type: EModelEndpoint.custom, iconURL: 'custom_icon.png', order: 2 },
};
describe('getEndpointField', () => {
it('returns undefined if endpointsConfig is undefined', () => {
expect(getEndpointField(undefined, EModelEndpoint.openAI, 'type')).toBeUndefined();
});
it('returns undefined if endpoint is null', () => {
expect(getEndpointField(mockEndpointsConfig, null, 'type')).toBeUndefined();
});
it('returns undefined if endpoint is undefined', () => {
expect(getEndpointField(mockEndpointsConfig, undefined, 'type')).toBeUndefined();
});
it('returns undefined if the endpoint does not exist in endpointsConfig', () => {
expect(getEndpointField(mockEndpointsConfig, EModelEndpoint.bingAI, 'type')).toBeUndefined();
});
it('returns the correct value for a valid endpoint and property', () => {
expect(getEndpointField(mockEndpointsConfig, EModelEndpoint.openAI, 'order')).toEqual(0);
expect(getEndpointField(mockEndpointsConfig, EModelEndpoint.google, 'iconURL')).toEqual(
'google_icon.png',
);
});
it('returns undefined for a valid endpoint but an invalid property', () => {
/* Type assertion as 'nonexistentProperty' is intentionally not a valid property of TConfig */
expect(
getEndpointField(
mockEndpointsConfig,
EModelEndpoint.openAI,
'nonexistentProperty' as keyof TConfig,
),
).toBeUndefined();
});
it('returns the correct value for a non-enum endpoint and valid property', () => {
expect(getEndpointField(mockEndpointsConfig, 'Mistral', 'type')).toEqual(EModelEndpoint.custom);
});
it('returns undefined for a non-enum endpoint with an invalid property', () => {
expect(
getEndpointField(mockEndpointsConfig, 'Mistral', 'nonexistentProperty' as keyof TConfig),
).toBeUndefined();
});
});
describe('getEndpointsFilter', () => {
it('returns an empty object if endpointsConfig is undefined', () => {
expect(getEndpointsFilter(undefined)).toEqual({});
});
it('returns a filter object based on endpointsConfig', () => {
const expectedFilter = {
[EModelEndpoint.openAI]: true,
[EModelEndpoint.google]: true,
Mistral: true,
};
expect(getEndpointsFilter(mockEndpointsConfig)).toEqual(expectedFilter);
});
});
describe('getAvailableEndpoints', () => {
it('returns available endpoints based on filter and config', () => {
const filter = {
[EModelEndpoint.openAI]: true,
[EModelEndpoint.google]: false,
Mistral: true,
};
const expectedEndpoints = [EModelEndpoint.openAI, 'Mistral'];
expect(getAvailableEndpoints(filter, mockEndpointsConfig)).toEqual(expectedEndpoints);
});
});
describe('mapEndpoints', () => {
it('returns sorted available endpoints', () => {
const expectedOrder = [EModelEndpoint.openAI, EModelEndpoint.google, 'Mistral'];
expect(mapEndpoints(mockEndpointsConfig)).toEqual(expectedOrder);
});
});

View file

@ -0,0 +1,58 @@
import { defaultEndpoints } from 'librechat-data-provider';
import type { EModelEndpoint, TEndpointsConfig, TConfig } from 'librechat-data-provider';
export const getEndpointsFilter = (endpointsConfig: TEndpointsConfig) => {
const filter: Record<string, boolean> = {};
if (!endpointsConfig) {
return filter;
}
for (const key of Object.keys(endpointsConfig)) {
filter[key] = !!endpointsConfig[key];
}
return filter;
};
export const getAvailableEndpoints = (
filter: Record<string, boolean>,
endpointsConfig: TEndpointsConfig,
) => {
const defaultSet = new Set(defaultEndpoints);
const availableEndpoints: EModelEndpoint[] = [];
for (const endpoint in endpointsConfig) {
// Check if endpoint is in the filter or its type is in defaultEndpoints
if (
filter[endpoint] ||
(endpointsConfig[endpoint]?.type &&
defaultSet.has(endpointsConfig[endpoint]?.type as EModelEndpoint))
) {
availableEndpoints.push(endpoint as EModelEndpoint);
}
}
return availableEndpoints;
};
export function getEndpointField<K extends keyof TConfig>(
endpointsConfig: TEndpointsConfig | undefined,
endpoint: EModelEndpoint | string | null | undefined,
property: K,
): TConfig[K] | undefined {
if (!endpointsConfig || endpoint === null || endpoint === undefined) {
return undefined;
}
const config = endpointsConfig[endpoint];
if (!config) {
return undefined;
}
return config[property];
}
export function mapEndpoints(endpointsConfig: TEndpointsConfig) {
const filter = getEndpointsFilter(endpointsConfig);
return getAvailableEndpoints(filter, endpointsConfig).sort(
(a, b) => (endpointsConfig?.[a]?.order ?? 0) - (endpointsConfig?.[b]?.order ?? 0),
);
}

View file

@ -5,7 +5,7 @@ import type {
EModelEndpoint,
} from 'librechat-data-provider';
import getLocalStorageItems from './getLocalStorageItems';
import mapEndpoints from './mapEndpoints';
import { mapEndpoints } from './endpoints';
type TConvoSetup = Partial<TPreset> | Partial<TConversation>;
@ -13,7 +13,7 @@ type TDefaultEndpoint = { convoSetup: TConvoSetup; endpointsConfig: TEndpointsCo
const getEndpointFromSetup = (convoSetup: TConvoSetup, endpointsConfig: TEndpointsConfig) => {
const { endpoint: targetEndpoint } = convoSetup || {};
if (targetEndpoint && endpointsConfig?.[targetEndpoint]) {
if (targetEndpoint && endpointsConfig?.[targetEndpoint ?? '']) {
return targetEndpoint;
} else if (targetEndpoint) {
console.warn(`Illegal target endpoint ${targetEndpoint} ${endpointsConfig}`);
@ -35,7 +35,7 @@ const getEndpointFromLocalStorage = (endpointsConfig: TEndpointsConfig) => {
return endpoint;
}
return endpoint && endpointsConfig[endpoint] ? endpoint : null;
return endpoint && endpointsConfig?.[endpoint ?? ''] ? endpoint : null;
} catch (error) {
console.error(error);
return null;

View file

@ -1,9 +1,9 @@
export * from './json';
export * from './presets';
export * from './languages';
export * from './endpoints';
export { default as cn } from './cn';
export { default as buildTree } from './buildTree';
export { default as mapEndpoints } from './mapEndpoints';
export { default as getLoginError } from './getLoginError';
export { default as cleanupPreset } from './cleanupPreset';
export { default as validateIframe } from './validateIframe';

View file

@ -1,37 +0,0 @@
import { defaultEndpoints } from 'librechat-data-provider';
import type { EModelEndpoint, TEndpointsConfig } from 'librechat-data-provider';
const getEndpointsFilter = (endpointsConfig: TEndpointsConfig) => {
const filter: Record<string, boolean> = {};
for (const key of Object.keys(endpointsConfig)) {
filter[key] = !!endpointsConfig[key];
}
return filter;
};
const getAvailableEndpoints = (
filter: Record<string, boolean>,
endpointsConfig: TEndpointsConfig,
) => {
const defaultSet = new Set(defaultEndpoints);
const availableEndpoints: EModelEndpoint[] = [];
for (const endpoint in endpointsConfig) {
// Check if endpoint is in the filter or its type is in defaultEndpoints
if (
filter[endpoint] ||
(endpointsConfig[endpoint]?.type && defaultSet.has(endpointsConfig[endpoint].type))
) {
availableEndpoints.push(endpoint as EModelEndpoint);
}
}
return availableEndpoints;
};
export default function mapEndpoints(endpointsConfig: TEndpointsConfig) {
const filter = getEndpointsFilter(endpointsConfig);
return getAvailableEndpoints(filter, endpointsConfig).sort(
(a, b) => (endpointsConfig[a]?.order ?? 0) - (endpointsConfig[b]?.order ?? 0),
);
}