🎭 feat: Override Custom Endpoint Schema with Specified Params Endpoint (#11788)

* 🔧 refactor: Simplify payload parsing and enhance getSaveOptions logic

- Removed unused bedrockInputSchema from payloadParser, streamlining the function.
- Updated payloadParser to handle optional chaining for model parameters.
- Enhanced getSaveOptions to ensure runOptions defaults to an empty object if parsing fails, improving robustness.
- Adjusted the assignment of maxContextTokens to use the instance variable for consistency.

* 🔧 fix: Update maxContextTokens assignment logic in initializeAgent function

- Enhanced the maxContextTokens assignment to allow for user-defined values, ensuring it defaults to a calculated value only when not provided or invalid. This change improves flexibility in agent initialization.

* 🧪 test: Add unit tests for initializeAgent function

- Introduced comprehensive unit tests for the initializeAgent function, focusing on maxContextTokens behavior.
- Tests cover scenarios for user-defined values, fallback calculations, and edge cases such as zero and negative values, enhancing overall test coverage and reliability of agent initialization logic.

* refactor: Default params endpoint configuration handling

- Integrated `getEndpointsConfig` to fetch endpoint configurations, allowing for dynamic handling of `defaultParamsEndpoint`.
- Updated `buildEndpointOption` to pass `defaultParamsEndpoint` to `parseCompactConvo`, ensuring correct parameter handling based on endpoint type.
- Added comprehensive unit tests for `buildDefaultConvo` and `cleanupPreset` to validate behavior with `defaultParamsEndpoint`, covering various scenarios and edge cases.
- Refactored related hooks and utility functions to support the new configuration structure, improving overall flexibility and maintainability.

* refactor: Centralize defaultParamsEndpoint retrieval

- Introduced `getDefaultParamsEndpoint` function to streamline the retrieval of `defaultParamsEndpoint` across various hooks and middleware.
- Updated multiple files to utilize the new function, enhancing code consistency and maintainability.
- Removed redundant logic for fetching `defaultParamsEndpoint`, simplifying the codebase.
This commit is contained in:
Danny Avila 2026-02-13 23:04:51 -05:00 committed by GitHub
parent 6cc6ee3207
commit 467df0f07a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 1234 additions and 45 deletions

View file

@ -0,0 +1,284 @@
import { Providers } from '@librechat/agents';
import { EModelEndpoint } from 'librechat-data-provider';
import type { Agent } from 'librechat-data-provider';
import type { ServerRequest, InitializeResultBase } from '~/types';
import type { InitializeAgentDbMethods } from '../initialize';
// Mock logger so initializeAgent's logging does not require a real winston transport.
jest.mock('winston', () => ({
  createLogger: jest.fn(() => ({
    debug: jest.fn(),
    warn: jest.fn(),
    error: jest.fn(),
  })),
  format: {
    combine: jest.fn(),
    colorize: jest.fn(),
    simple: jest.fn(),
  },
  transports: {
    Console: jest.fn(),
  },
}));

// These mock fns are declared before the jest.mock factories below so the
// factories can close over them. jest.mock calls are hoisted, but the factory
// bodies only run lazily on first import — after these consts are initialized.
const mockExtractLibreChatParams = jest.fn();
const mockGetModelMaxTokens = jest.fn();
const mockOptionalChainWithEmptyCheck = jest.fn();
const mockGetThreadData = jest.fn();

jest.mock('~/utils', () => ({
  extractLibreChatParams: (...args: unknown[]) => mockExtractLibreChatParams(...args),
  getModelMaxTokens: (...args: unknown[]) => mockGetModelMaxTokens(...args),
  optionalChainWithEmptyCheck: (...args: unknown[]) => mockOptionalChainWithEmptyCheck(...args),
  getThreadData: (...args: unknown[]) => mockGetThreadData(...args),
}));

const mockGetProviderConfig = jest.fn();
jest.mock('~/endpoints', () => ({
  getProviderConfig: (...args: unknown[]) => mockGetProviderConfig(...args),
}));

// File filtering is out of scope for these tests; pretend no files match.
jest.mock('~/files', () => ({
  filterFilesByEndpointConfig: jest.fn(() => []),
}));

// Artifacts prompt generation is irrelevant to maxContextTokens behavior.
jest.mock('~/prompts', () => ({
  generateArtifactsPrompt: jest.fn(() => null),
}));

// primeResources resolves with no attachments or tool resources for all tests.
jest.mock('../resources', () => ({
  primeResources: jest.fn().mockResolvedValue({
    attachments: [],
    tool_resources: undefined,
  }),
}));

// Imported after the mocks above so the module under test binds to them.
import { initializeAgent } from '../initialize';
/**
* Creates minimal mock objects for initializeAgent tests.
*/
/**
 * Builds the minimal fixtures (agent, request/response, tools loader, db
 * methods) required to call initializeAgent, and wires the module-level mocks
 * for a single test run.
 *
 * @param overrides - Optional knobs: the user-configured `maxContextTokens`
 *   surfaced by extractLibreChatParams, the model's default context window
 *   (`modelDefault`), and the configured `maxOutputTokens`.
 */
function createMocks(overrides?: {
  maxContextTokens?: number;
  modelDefault?: number;
  maxOutputTokens?: number;
}) {
  const opts = overrides ?? {};
  const maxContextTokens = opts.maxContextTokens;
  const modelDefault = opts.modelDefault ?? 200000;
  const maxOutputTokens = opts.maxOutputTokens ?? 4096;

  const agent = {
    id: 'agent-1',
    model: 'test-model',
    provider: Providers.OPENAI,
    tools: [],
    model_parameters: { model: 'test-model' },
  } as unknown as Agent;

  const req = { user: { id: 'user-1' }, config: {} } as unknown as ServerRequest;
  const res = {} as unknown as import('express').Response;

  // Provider config resolves to a getOptions that echoes the configured output cap.
  const initializeResult = {
    llmConfig: {
      model: 'test-model',
      maxTokens: maxOutputTokens,
    },
    endpointTokenConfig: undefined,
  } satisfies InitializeResultBase;
  const getOptions = jest.fn().mockResolvedValue(initializeResult);
  mockGetProviderConfig.mockReturnValue({
    getOptions,
    overrideProvider: Providers.OPENAI,
  });

  // extractLibreChatParams surfaces maxContextTokens when set in model_parameters.
  mockExtractLibreChatParams.mockReturnValue({
    resendFiles: false,
    maxContextTokens,
    modelOptions: { model: 'test-model' },
  });

  // getModelMaxTokens returns the model's default context window.
  mockGetModelMaxTokens.mockReturnValue(modelDefault);

  // Faithful re-implementation of optionalChainWithEmptyCheck: the first value
  // that is not undefined/null/'' wins; otherwise the final value is returned.
  mockOptionalChainWithEmptyCheck.mockImplementation(
    (...values: (string | number | undefined)[]) =>
      values.find((v) => v !== undefined && v !== null && v !== '') ??
      values[values.length - 1],
  );

  const loadTools = jest.fn().mockResolvedValue({
    tools: [],
    toolContextMap: {},
    userMCPAuthMap: undefined,
    toolRegistry: undefined,
    toolDefinitions: [],
    hasDeferredTools: false,
  });

  const db: InitializeAgentDbMethods = {
    getFiles: jest.fn().mockResolvedValue([]),
    getConvoFiles: jest.fn().mockResolvedValue([]),
    updateFilesUsage: jest.fn().mockResolvedValue([]),
    getUserKey: jest.fn().mockResolvedValue('user-1'),
    getUserKeyValues: jest.fn().mockResolvedValue([]),
    getToolFilesByIds: jest.fn().mockResolvedValue([]),
  };

  return { agent, req, res, loadTools, db };
}
// Verifies how initializeAgent resolves maxContextTokens: a user-configured
// positive value is used as-is; otherwise a formula based on the model's
// context window and the output-token reservation applies.
describe('initializeAgent — maxContextTokens', () => {
  beforeEach(() => {
    // createMocks re-wires every mock per test; start each test from a clean slate.
    jest.clearAllMocks();
  });

  it('uses user-configured maxContextTokens when provided via model_parameters', async () => {
    const userValue = 50000;
    const { agent, req, res, loadTools, db } = createMocks({
      maxContextTokens: userValue,
      modelDefault: 200000,
      maxOutputTokens: 4096,
    });
    const result = await initializeAgent(
      {
        req,
        res,
        agent,
        loadTools,
        endpointOption: {
          endpoint: EModelEndpoint.agents,
          model_parameters: { maxContextTokens: userValue },
        },
        allowedProviders: new Set([Providers.OPENAI]),
        isInitialAgent: true,
      },
      db,
    );
    // The explicit user value must win over the formula-derived default.
    expect(result.maxContextTokens).toBe(userValue);
  });

  it('falls back to formula when maxContextTokens is NOT provided', async () => {
    const modelDefault = 200000;
    const maxOutputTokens = 4096;
    const { agent, req, res, loadTools, db } = createMocks({
      maxContextTokens: undefined,
      modelDefault,
      maxOutputTokens,
    });
    const result = await initializeAgent(
      {
        req,
        res,
        agent,
        loadTools,
        endpointOption: { endpoint: EModelEndpoint.agents },
        allowedProviders: new Set([Providers.OPENAI]),
        isInitialAgent: true,
      },
      db,
    );
    // Formula: 90% of (model context window minus reserved output tokens).
    const expected = Math.round((modelDefault - maxOutputTokens) * 0.9);
    expect(result.maxContextTokens).toBe(expected);
  });

  it('falls back to formula when maxContextTokens is 0', async () => {
    const maxOutputTokens = 4096;
    const { agent, req, res, loadTools, db } = createMocks({
      maxContextTokens: 0,
      modelDefault: 200000,
      maxOutputTokens,
    });
    const result = await initializeAgent(
      {
        req,
        res,
        agent,
        loadTools,
        endpointOption: {
          endpoint: EModelEndpoint.agents,
          model_parameters: { maxContextTokens: 0 },
        },
        allowedProviders: new Set([Providers.OPENAI]),
        isInitialAgent: true,
      },
      db,
    );
    // 0 is not used as-is; the formula kicks in.
    // optionalChainWithEmptyCheck(0, 200000, 18000) returns 0 (not null/undefined),
    // then Number(0) || 18000 = 18000 (the fallback default).
    expect(result.maxContextTokens).not.toBe(0);
    const expected = Math.round((18000 - maxOutputTokens) * 0.9);
    expect(result.maxContextTokens).toBe(expected);
  });

  it('falls back to formula when maxContextTokens is negative', async () => {
    const maxOutputTokens = 4096;
    const { agent, req, res, loadTools, db } = createMocks({
      maxContextTokens: -1,
      modelDefault: 200000,
      maxOutputTokens,
    });
    const result = await initializeAgent(
      {
        req,
        res,
        agent,
        loadTools,
        endpointOption: {
          endpoint: EModelEndpoint.agents,
          model_parameters: { maxContextTokens: -1 },
        },
        allowedProviders: new Set([Providers.OPENAI]),
        isInitialAgent: true,
      },
      db,
    );
    // -1 is not used as-is; the formula kicks in
    expect(result.maxContextTokens).not.toBe(-1);
  });

  it('preserves small user-configured value (e.g. 1000 from modelSpec)', async () => {
    const userValue = 1000;
    const { agent, req, res, loadTools, db } = createMocks({
      maxContextTokens: userValue,
      modelDefault: 128000,
      maxOutputTokens: 4096,
    });
    const result = await initializeAgent(
      {
        req,
        res,
        agent,
        loadTools,
        endpointOption: {
          endpoint: EModelEndpoint.agents,
          model_parameters: { maxContextTokens: userValue },
        },
        allowedProviders: new Set([Providers.OPENAI]),
        isInitialAgent: true,
      },
      db,
    );
    // Should NOT be overridden to Math.round((128000 - 4096) * 0.9) = 111,514
    expect(result.maxContextTokens).toBe(userValue);
  });
});

View file

@ -413,7 +413,10 @@ export async function initializeAgent(
toolContextMap: toolContextMap ?? {},
useLegacyContent: !!options.useLegacyContent,
tools: (tools ?? []) as GenericTool[] & string[],
maxContextTokens: Math.round((agentMaxContextNum - maxOutputTokensNum) * 0.9),
maxContextTokens:
maxContextTokens != null && maxContextTokens > 0
? maxContextTokens
: Math.round((agentMaxContextNum - maxOutputTokensNum) * 0.9),
};
return initializedAgent;

View file

@ -1,4 +1,4 @@
import { replaceSpecialVars, parseCompactConvo, parseTextParts } from '../src/parsers';
import { replaceSpecialVars, parseConvo, parseCompactConvo, parseTextParts } from '../src/parsers';
import { specialVariables } from '../src/config';
import { EModelEndpoint } from '../src/schemas';
import { ContentTypes } from '../src/types/runs';
@ -262,6 +262,257 @@ describe('parseCompactConvo', () => {
});
});
// Validates that a custom endpoint's schema can be overridden by
// defaultParamsEndpoint, while known endpoints keep their own schema.
describe('parseConvo - defaultParamsEndpoint', () => {
  /**
   * Parses a conversation against the custom endpoint type, optionally with a
   * `defaultParamsEndpoint` schema override.
   */
  const parseCustom = (
    conversation: Partial<TConversation>,
    defaultParamsEndpoint?: string | null,
  ) =>
    parseConvo({
      endpoint: 'MyCustomEndpoint' as EModelEndpoint,
      endpointType: EModelEndpoint.custom,
      conversation,
      defaultParamsEndpoint,
    });

  test('should strip maxOutputTokens for custom endpoint without defaultParamsEndpoint', () => {
    const result = parseCustom({
      model: 'anthropic/claude-opus-4.5',
      temperature: 0.7,
      maxOutputTokens: 8192,
      maxContextTokens: 50000,
    });
    expect(result).not.toBeNull();
    expect(result?.temperature).toBe(0.7);
    expect(result?.maxContextTokens).toBe(50000);
    expect(result?.maxOutputTokens).toBeUndefined();
  });

  test('should preserve maxOutputTokens when defaultParamsEndpoint is anthropic', () => {
    const result = parseCustom(
      {
        model: 'anthropic/claude-opus-4.5',
        temperature: 0.7,
        maxOutputTokens: 8192,
        topP: 0.9,
        topK: 40,
        maxContextTokens: 50000,
      },
      EModelEndpoint.anthropic,
    );
    expect(result).not.toBeNull();
    expect(result?.model).toBe('anthropic/claude-opus-4.5');
    expect(result?.temperature).toBe(0.7);
    expect(result?.maxOutputTokens).toBe(8192);
    expect(result?.topP).toBe(0.9);
    expect(result?.topK).toBe(40);
    expect(result?.maxContextTokens).toBe(50000);
  });

  test('should strip OpenAI-specific fields when defaultParamsEndpoint is anthropic', () => {
    const result = parseCustom(
      {
        model: 'anthropic/claude-opus-4.5',
        temperature: 0.7,
        max_tokens: 4096,
        top_p: 0.9,
        presence_penalty: 0.5,
        frequency_penalty: 0.3,
      },
      EModelEndpoint.anthropic,
    );
    expect(result).not.toBeNull();
    expect(result?.temperature).toBe(0.7);
    expect(result?.max_tokens).toBeUndefined();
    expect(result?.top_p).toBeUndefined();
    expect(result?.presence_penalty).toBeUndefined();
    expect(result?.frequency_penalty).toBeUndefined();
  });

  test('should preserve max_tokens when defaultParamsEndpoint is not set (OpenAI default)', () => {
    const result = parseCustom({
      model: 'gpt-4o',
      temperature: 0.7,
      max_tokens: 4096,
      top_p: 0.9,
    });
    expect(result).not.toBeNull();
    expect(result?.max_tokens).toBe(4096);
    expect(result?.top_p).toBe(0.9);
  });

  test('should preserve Google-specific fields when defaultParamsEndpoint is google', () => {
    const result = parseCustom(
      {
        model: 'gemini-pro',
        temperature: 0.7,
        maxOutputTokens: 8192,
        topP: 0.9,
        topK: 40,
      },
      EModelEndpoint.google,
    );
    expect(result).not.toBeNull();
    expect(result?.maxOutputTokens).toBe(8192);
    expect(result?.topP).toBe(0.9);
    expect(result?.topK).toBe(40);
  });

  test('should not strip fields from non-custom endpoints that already have a schema', () => {
    // Known endpoint: its own schema applies and the override must be ignored.
    const result = parseConvo({
      endpoint: EModelEndpoint.openAI,
      conversation: {
        model: 'gpt-4o',
        temperature: 0.7,
        max_tokens: 4096,
        top_p: 0.9,
      },
      defaultParamsEndpoint: EModelEndpoint.anthropic,
    });
    expect(result).not.toBeNull();
    expect(result?.max_tokens).toBe(4096);
    expect(result?.top_p).toBe(0.9);
  });

  test('should not carry bedrock region to custom endpoint without defaultParamsEndpoint', () => {
    const result = parseCustom({
      model: 'gpt-4o',
      temperature: 0.7,
      region: 'us-east-1',
    });
    expect(result).not.toBeNull();
    expect(result?.temperature).toBe(0.7);
    expect(result?.region).toBeUndefined();
  });

  test('should fall back to endpointType schema when defaultParamsEndpoint is invalid', () => {
    const result = parseCustom(
      {
        model: 'gpt-4o',
        temperature: 0.7,
        max_tokens: 4096,
        maxOutputTokens: 8192,
      },
      'nonexistent_endpoint',
    );
    expect(result).not.toBeNull();
    expect(result?.max_tokens).toBe(4096);
    expect(result?.maxOutputTokens).toBeUndefined();
  });
});
// Same override behavior for the compact parser, plus its iconURL stripping.
describe('parseCompactConvo - defaultParamsEndpoint', () => {
  /** Parses via parseCompactConvo against the custom endpoint type. */
  const compactParseCustom = (
    conversation: Partial<TConversation>,
    defaultParamsEndpoint?: string | null,
  ) =>
    parseCompactConvo({
      endpoint: 'MyCustomEndpoint' as EModelEndpoint,
      endpointType: EModelEndpoint.custom,
      conversation,
      defaultParamsEndpoint,
    });

  test('should strip maxOutputTokens for custom endpoint without defaultParamsEndpoint', () => {
    const result = compactParseCustom({
      model: 'anthropic/claude-opus-4.5',
      temperature: 0.7,
      maxOutputTokens: 8192,
    });
    expect(result).not.toBeNull();
    expect(result?.temperature).toBe(0.7);
    expect(result?.maxOutputTokens).toBeUndefined();
  });

  test('should preserve maxOutputTokens when defaultParamsEndpoint is anthropic', () => {
    const result = compactParseCustom(
      {
        model: 'anthropic/claude-opus-4.5',
        temperature: 0.7,
        maxOutputTokens: 8192,
        topP: 0.9,
        maxContextTokens: 50000,
      },
      EModelEndpoint.anthropic,
    );
    expect(result).not.toBeNull();
    expect(result?.maxOutputTokens).toBe(8192);
    expect(result?.topP).toBe(0.9);
    expect(result?.maxContextTokens).toBe(50000);
  });

  test('should strip iconURL even when defaultParamsEndpoint is set', () => {
    const result = compactParseCustom(
      {
        model: 'anthropic/claude-opus-4.5',
        iconURL: 'https://malicious.com/track.png',
        maxOutputTokens: 8192,
      },
      EModelEndpoint.anthropic,
    );
    expect(result).not.toBeNull();
    // The return type omits iconURL; bracket access avoids a compile error.
    expect(result?.['iconURL']).toBeUndefined();
    expect(result?.maxOutputTokens).toBe(8192);
  });

  test('should fall back to endpointType when defaultParamsEndpoint is null', () => {
    const result = compactParseCustom(
      {
        model: 'gpt-4o',
        max_tokens: 4096,
        maxOutputTokens: 8192,
      },
      null,
    );
    expect(result).not.toBeNull();
    expect(result?.max_tokens).toBe(4096);
    expect(result?.maxOutputTokens).toBeUndefined();
  });
});
describe('parseTextParts', () => {
test('should concatenate text parts', () => {
const parts: TMessageContentParts[] = [

View file

@ -1908,3 +1908,14 @@ export function getEndpointField<
}
return config[property];
}
/**
 * Resolves the `defaultParamsEndpoint` for a given endpoint from its custom
 * params config.
 *
 * @param endpointsConfig - The full endpoints config map (may be absent).
 * @param endpoint - The endpoint key to look up (may be absent).
 * @returns The configured override endpoint name, or `undefined` when either
 *   argument is missing or no override is configured.
 */
export function getDefaultParamsEndpoint(
  endpointsConfig: TEndpointsConfig | undefined | null,
  endpoint: string | null | undefined,
): string | undefined {
  const endpointConfig = endpointsConfig && endpoint ? endpointsConfig[endpoint] : undefined;
  return endpointConfig?.customParams?.defaultParamsEndpoint;
}

View file

@ -144,26 +144,25 @@ export const parseConvo = ({
endpointType,
conversation,
possibleValues,
defaultParamsEndpoint,
}: {
endpoint: EndpointSchemaKey;
endpointType?: EndpointSchemaKey | null;
conversation: Partial<s.TConversation | s.TPreset> | null;
possibleValues?: TPossibleValues;
// TODO: POC for default schema
// defaultSchema?: Partial<EndpointSchema>,
defaultParamsEndpoint?: string | null;
}) => {
let schema = endpointSchemas[endpoint] as EndpointSchema | undefined;
if (!schema && !endpointType) {
throw new Error(`Unknown endpoint: ${endpoint}`);
} else if (!schema && endpointType) {
schema = endpointSchemas[endpointType];
} else if (!schema) {
const overrideSchema = defaultParamsEndpoint
? endpointSchemas[defaultParamsEndpoint as EndpointSchemaKey]
: undefined;
schema = overrideSchema ?? (endpointType ? endpointSchemas[endpointType] : undefined);
}
// if (defaultSchema && schemaCreators[endpoint]) {
// schema = schemaCreators[endpoint](defaultSchema);
// }
const convo = schema?.parse(conversation) as s.TConversation | undefined;
const { models } = possibleValues ?? {};
@ -310,13 +309,13 @@ export const parseCompactConvo = ({
endpointType,
conversation,
possibleValues,
defaultParamsEndpoint,
}: {
endpoint?: EndpointSchemaKey;
endpointType?: EndpointSchemaKey | null;
conversation: Partial<s.TConversation | s.TPreset>;
possibleValues?: TPossibleValues;
// TODO: POC for default schema
// defaultSchema?: Partial<EndpointSchema>,
defaultParamsEndpoint?: string | null;
}): Omit<s.TConversation, 'iconURL'> | null => {
if (!endpoint) {
throw new Error(`undefined endpoint: ${endpoint}`);
@ -326,8 +325,11 @@ export const parseCompactConvo = ({
if (!schema && !endpointType) {
throw new Error(`Unknown endpoint: ${endpoint}`);
} else if (!schema && endpointType) {
schema = compactEndpointSchemas[endpointType];
} else if (!schema) {
const overrideSchema = defaultParamsEndpoint
? compactEndpointSchemas[defaultParamsEndpoint as EndpointSchemaKey]
: undefined;
schema = overrideSchema ?? (endpointType ? compactEndpointSchemas[endpointType] : undefined);
}
if (!schema) {