🎚️ feat: Custom Parameters (#7342)

* #

* - refactor: simplified getCustomConfig func

* #

* - feature: persist values for parameters with optionType of custom

* #

* - refactor: moved `Parameters/settings.ts` into `data-provider` so that both frontend and backend code can use it.

* - feature: loadCustomConfig can now parse and validate customParams property for `endpoints.custom` in `librechat.yaml`

* # fixed linter

* # removed .strict() in config.ts

* change: added packages/data-provider/src to SOURCE_DIRS for i18n check

* # removed unnecessary lodash imports

* # addressed PR comments
# fixed lint for updated files

* # better import for lodash (w/o relying on tree-shaking)
This commit is contained in:
Theo N. Truong 2025-05-19 17:33:25 -06:00 committed by GitHub
parent c79ee32006
commit 7ce782fec6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 340 additions and 132 deletions

View file

@ -10,17 +10,7 @@ const getLogStores = require('~/cache/getLogStores');
* */
async function getCustomConfig() {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
let customConfig = await cache.get(CacheKeys.CUSTOM_CONFIG);
if (!customConfig) {
customConfig = await loadCustomConfig();
}
if (!customConfig) {
return null;
}
return customConfig;
return (await cache.get(CacheKeys.CUSTOM_CONFIG)) || (await loadCustomConfig());
}
/**

View file

@ -29,7 +29,14 @@ async function loadConfigEndpoints(req) {
for (let i = 0; i < customEndpoints.length; i++) {
const endpoint = customEndpoints[i];
const { baseURL, apiKey, name: configName, iconURL, modelDisplayLabel } = endpoint;
const {
baseURL,
apiKey,
name: configName,
iconURL,
modelDisplayLabel,
customParams,
} = endpoint;
const name = normalizeEndpointName(configName);
const resolvedApiKey = extractEnvVariable(apiKey);
@ -41,6 +48,7 @@ async function loadConfigEndpoints(req) {
userProvideURL: isUserProvided(resolvedBaseURL),
modelDisplayLabel,
iconURL,
customParams,
};
}
}

View file

@ -1,10 +1,18 @@
const path = require('path');
const { CacheKeys, configSchema, EImageOutputType } = require('librechat-data-provider');
const {
CacheKeys,
configSchema,
EImageOutputType,
validateSettingDefinitions,
agentParamSettings,
paramSettings,
} = require('librechat-data-provider');
const getLogStores = require('~/cache/getLogStores');
const loadYaml = require('~/utils/loadYaml');
const { logger } = require('~/config');
const axios = require('axios');
const yaml = require('js-yaml');
const keyBy = require('lodash/keyBy');
const projectRoot = path.resolve(__dirname, '..', '..', '..', '..');
const defaultConfigPath = path.resolve(projectRoot, 'librechat.yaml');
@ -105,6 +113,10 @@ https://www.librechat.ai/docs/configuration/stt_tts`);
logger.debug('Custom config:', customConfig);
}
(customConfig.endpoints?.custom ?? [])
.filter((endpoint) => endpoint.customParams)
.forEach((endpoint) => parseCustomParams(endpoint.name, endpoint.customParams));
if (customConfig.cache) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.set(CacheKeys.CUSTOM_CONFIG, customConfig);
@ -117,4 +129,52 @@ https://www.librechat.ai/docs/configuration/stt_tts`);
return customConfig;
}
// Validate and fill out missing values for custom parameters
function parseCustomParams(endpointName, customParams) {
const paramEndpoint = customParams.defaultParamsEndpoint;
customParams.paramDefinitions = customParams.paramDefinitions || [];
// Checks if `defaultParamsEndpoint` is a key in `paramSettings`.
const validEndpoints = new Set([
...Object.keys(paramSettings),
...Object.keys(agentParamSettings),
]);
if (!validEndpoints.has(paramEndpoint)) {
throw new Error(
`defaultParamsEndpoint of "${endpointName}" endpoint is invalid. ` +
`Valid options are ${Array.from(validEndpoints).join(', ')}`,
);
}
// creates default param maps
const regularParams = paramSettings[paramEndpoint] ?? [];
const agentParams = agentParamSettings[paramEndpoint] ?? [];
const defaultParams = regularParams.concat(agentParams);
const defaultParamsMap = keyBy(defaultParams, 'key');
// TODO: Remove this check once we support new parameters not part of default parameters.
// Checks if every key in `paramDefinitions` is valid.
const validKeys = new Set(Object.keys(defaultParamsMap));
const paramKeys = customParams.paramDefinitions.map((param) => param.key);
if (paramKeys.some((key) => !validKeys.has(key))) {
throw new Error(
`paramDefinitions of "${endpointName}" endpoint contains invalid key(s). ` +
`Valid parameter keys are ${Array.from(validKeys).join(', ')}`,
);
}
// Fill out missing values for custom param definitions
customParams.paramDefinitions = customParams.paramDefinitions.map((param) => {
return { ...defaultParamsMap[param.key], ...param, optionType: 'custom' };
});
try {
validateSettingDefinitions(customParams.paramDefinitions);
} catch (e) {
throw new Error(
`Custom parameter definitions for "${endpointName}" endpoint is malformed: ${e.message}`,
);
}
}
module.exports = loadCustomConfig;

View file

@ -1,6 +1,34 @@
jest.mock('axios');
jest.mock('~/cache/getLogStores');
jest.mock('~/utils/loadYaml');
jest.mock('librechat-data-provider', () => {
const actual = jest.requireActual('librechat-data-provider');
return {
...actual,
paramSettings: { foo: {}, bar: {}, custom: {} },
agentParamSettings: {
custom: [],
google: [
{
key: 'pressure',
type: 'string',
component: 'input',
},
{
key: 'temperature',
type: 'number',
component: 'slider',
default: 0.5,
range: {
min: 0,
max: 2,
step: 0.01,
},
},
],
},
};
});
const axios = require('axios');
const loadCustomConfig = require('./loadCustomConfig');
@ -150,4 +178,126 @@ describe('loadCustomConfig', () => {
expect(logger.info).toHaveBeenCalledWith(JSON.stringify(mockConfig, null, 2));
expect(logger.debug).toHaveBeenCalledWith('Custom config:', mockConfig);
});
describe('parseCustomParams', () => {
const mockConfig = {
version: '1.0',
cache: false,
endpoints: {
custom: [
{
name: 'Google',
apiKey: 'user_provided',
customParams: {},
},
],
},
};
async function loadCustomParams(customParams) {
mockConfig.endpoints.custom[0].customParams = customParams;
loadYaml.mockReturnValue(mockConfig);
return await loadCustomConfig();
}
beforeEach(() => {
jest.resetAllMocks();
process.env.CONFIG_PATH = 'validConfig.yaml';
});
it('returns no error when customParams is undefined', async () => {
const result = await loadCustomParams(undefined);
expect(result).toEqual(mockConfig);
});
it('returns no error when customParams is valid', async () => {
const result = await loadCustomParams({
defaultParamsEndpoint: 'google',
paramDefinitions: [
{
key: 'temperature',
default: 0.5,
},
],
});
expect(result).toEqual(mockConfig);
});
it('throws an error when paramDefinitions contain unsupported keys', async () => {
const malformedCustomParams = {
defaultParamsEndpoint: 'google',
paramDefinitions: [
{ key: 'temperature', default: 0.5 },
{ key: 'unsupportedKey', range: 0.5 },
],
};
await expect(loadCustomParams(malformedCustomParams)).rejects.toThrow(
'paramDefinitions of "Google" endpoint contains invalid key(s). Valid parameter keys are pressure, temperature',
);
});
it('throws an error when paramDefinitions is malformed', async () => {
const malformedCustomParams = {
defaultParamsEndpoint: 'google',
paramDefinitions: [
{
key: 'temperature',
type: 'noomba',
component: 'inpoot',
optionType: 'custom',
},
],
};
await expect(loadCustomParams(malformedCustomParams)).rejects.toThrow(
/Custom parameter definitions for "Google" endpoint is malformed:/,
);
});
it('throws an error when defaultParamsEndpoint is not provided', async () => {
const malformedCustomParams = { defaultParamsEndpoint: undefined };
await expect(loadCustomParams(malformedCustomParams)).rejects.toThrow(
'defaultParamsEndpoint of "Google" endpoint is invalid. Valid options are foo, bar, custom, google',
);
});
it('fills the paramDefinitions with missing values', async () => {
const customParams = {
defaultParamsEndpoint: 'google',
paramDefinitions: [
{ key: 'temperature', default: 0.7, range: { min: 0.1, max: 0.9, step: 0.1 } },
{ key: 'pressure', component: 'textarea' },
],
};
const parsedConfig = await loadCustomParams(customParams);
const paramDefinitions = parsedConfig.endpoints.custom[0].customParams.paramDefinitions;
expect(paramDefinitions).toEqual([
{
columnSpan: 1,
component: 'slider',
default: 0.7, // overridden
includeInput: true,
key: 'temperature',
label: 'temperature',
optionType: 'custom',
range: {
// overridden
max: 0.9,
min: 0.1,
step: 0.1,
},
type: 'number',
},
{
columnSpan: 1,
component: 'textarea', // overridden
key: 'pressure',
label: 'pressure',
optionType: 'custom',
placeholder: '',
type: 'string',
},
]);
});
});
});