🧠 feat: Bedrock Anthropic Reasoning & Update Endpoint Handling (#6163)

* feat: Add thinking and thinkingBudget parameters for Bedrock Anthropic models (see the sketch after this list)

* chore: Update @librechat/agents to version 2.1.8

* refactor: change region order in params

* refactor: Add maxTokens parameter to conversation preset schema

* refactor: Update agent client to use bedrockInputSchema and improve error handling for model parameters

* refactor: streamline/optimize llmConfig initialization and saving for bedrock

* fix: ensure config titleModel is used for all endpoints

* refactor: enhance OpenAIClient and agent initialization to support endpoint checks for OpenRouter

* chore: bump @google/generative-ai
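To make the headline change concrete, here is a minimal sketch of how the two new parameters are meant to flow into a Bedrock Converse request. The field names come from the schema changes in this commit; the literal values and the `expectedLlmConfig` object are illustrative assumptions, not captured output:

```ts
// Conversation-level input: the two new fields added by this commit.
const modelParameters = {
  model: 'anthropic.claude-3-7-sonnet-20250219-v1:0',
  maxTokens: 8192,
  thinking: true, // enable Anthropic extended reasoning
  thinkingBudget: 2000, // tokens reserved for reasoning
};

// After bedrockInputParser.parse() + bedrockOutputParser() (see the
// librechat-data-provider diff below), the request should carry the
// provider-native shape instead:
const expectedLlmConfig = {
  model: 'anthropic.claude-3-7-sonnet-20250219-v1:0',
  maxTokens: 8192,
  additionalModelRequestFields: {
    thinking: { type: 'enabled', budget_tokens: 2000 },
    anthropic_beta: ['output-128k-2025-02-19'],
  },
};
```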
Danny Avila, 2025-03-03 19:09:22 -05:00, committed by GitHub
parent 3accf91094
commit ceb0da874b
17 changed files with 2224 additions and 667 deletions


@@ -827,7 +827,8 @@ class GoogleClient extends BaseClient {
     let reply = '';
     const { abortController } = options;
-    const model = this.modelOptions.modelName ?? this.modelOptions.model ?? '';
+    const model =
+      this.options.titleModel ?? this.modelOptions.modelName ?? this.modelOptions.model ?? '';
     const safetySettings = getSafetySettings(model);
     if (!EXCLUDED_GENAI_MODELS.test(model) && !this.project_id) {
       logger.debug('Identified titling model as GenAI version');


@@ -112,7 +112,12 @@ class OpenAIClient extends BaseClient {
     const { OPENAI_FORCE_PROMPT } = process.env ?? {};
     const { reverseProxyUrl: reverseProxy } = this.options;
-    if (!this.useOpenRouter && reverseProxy && reverseProxy.includes(KnownEndpoints.openrouter)) {
+    if (
+      !this.useOpenRouter &&
+      ((reverseProxy && reverseProxy.includes(KnownEndpoints.openrouter)) ||
+        (this.options.endpoint &&
+          this.options.endpoint.toLowerCase().includes(KnownEndpoints.openrouter)))
+    ) {
       this.useOpenRouter = true;
     }


@@ -56,6 +56,10 @@ const conversationPreset = {
     type: Number,
     required: false,
   },
+  maxTokens: {
+    type: Number,
+    required: false,
+  },
   presence_penalty: {
     type: Number,
     required: false,


@@ -36,7 +36,7 @@
   "dependencies": {
     "@anthropic-ai/sdk": "^0.37.0",
     "@azure/search-documents": "^12.0.0",
-    "@google/generative-ai": "^0.21.0",
+    "@google/generative-ai": "^0.23.0",
     "@googleapis/youtube": "^20.0.0",
     "@keyv/mongo": "^2.1.8",
     "@keyv/redis": "^2.8.1",
@@ -45,7 +45,7 @@
     "@langchain/google-genai": "^0.1.9",
     "@langchain/google-vertexai": "^0.2.0",
     "@langchain/textsplitters": "^0.1.0",
-    "@librechat/agents": "^2.1.7",
+    "@librechat/agents": "^2.1.8",
     "@waylaidwanderer/fetch-event-source": "^3.0.1",
     "axios": "1.7.8",
     "bcryptjs": "^2.4.3",


@@ -17,7 +17,7 @@ const {
   KnownEndpoints,
   anthropicSchema,
   isAgentsEndpoint,
-  bedrockOutputParser,
+  bedrockInputSchema,
   removeNullishValues,
 } = require('librechat-data-provider');
 const {
@@ -30,6 +30,7 @@ const {
 const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
 const { getBufferString, HumanMessage } = require('@langchain/core/messages');
 const { encodeAndFormat } = require('~/server/services/Files/images/encode');
+const { getCustomEndpointConfig } = require('~/server/services/Config');
 const Tokenizer = require('~/server/services/Tokenizer');
 const BaseClient = require('~/app/clients/BaseClient');
 const { createRun } = require('./run');
@@ -39,10 +40,10 @@ const { logger } = require('~/config');
 /** @typedef {import('@langchain/core/runnables').RunnableConfig} RunnableConfig */

 const providerParsers = {
-  [EModelEndpoint.openAI]: openAISchema,
-  [EModelEndpoint.azureOpenAI]: openAISchema,
-  [EModelEndpoint.anthropic]: anthropicSchema,
-  [EModelEndpoint.bedrock]: bedrockOutputParser,
+  [EModelEndpoint.openAI]: openAISchema.parse,
+  [EModelEndpoint.azureOpenAI]: openAISchema.parse,
+  [EModelEndpoint.anthropic]: anthropicSchema.parse,
+  [EModelEndpoint.bedrock]: bedrockInputSchema.parse,
 };

 const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deepseek]);
@@ -187,7 +188,14 @@ class AgentClient extends BaseClient {
       : {};

     if (parseOptions) {
-      runOptions = parseOptions(this.options.agent.model_parameters);
+      try {
+        runOptions = parseOptions(this.options.agent.model_parameters);
+      } catch (error) {
+        logger.error(
+          '[api/server/controllers/agents/client.js #getSaveOptions] Error parsing options',
+          error,
+        );
+      }
     }

     return removeNullishValues(
@@ -824,13 +832,16 @@
     const clientOptions = {
       maxTokens: 75,
     };
-    const providerConfig = this.options.req.app.locals[this.options.agent.provider];
+    let endpointConfig = this.options.req.app.locals[this.options.agent.endpoint];
+    if (!endpointConfig) {
+      endpointConfig = await getCustomEndpointConfig(this.options.agent.endpoint);
+    }
     if (
-      providerConfig &&
-      providerConfig.titleModel &&
-      providerConfig.titleModel !== Constants.CURRENT_MODEL
+      endpointConfig &&
+      endpointConfig.titleModel &&
+      endpointConfig.titleModel !== Constants.CURRENT_MODEL
     ) {
-      clientOptions.model = providerConfig.titleModel;
+      clientOptions.model = endpointConfig.titleModel;
     }
     try {
       const titleResult = await this.run.generateTitle({
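Restating the title-model lookup above as a standalone sketch: built-in endpoints are found on `app.locals`, and custom endpoints fall back to `getCustomEndpointConfig`. The helper name and types here are hypothetical; only the lookup order and the `CURRENT_MODEL` guard mirror the diff:

```ts
// Hypothetical helper: pick the model used for conversation titling.
type EndpointConfig = { titleModel?: string } | undefined;

async function resolveTitleModel(
  locals: Record<string, EndpointConfig>,
  endpoint: string,
  getCustomEndpointConfig: (name: string) => Promise<EndpointConfig>,
  CURRENT_MODEL: string,
): Promise<string | undefined> {
  // Built-in endpoint config first, then custom endpoint config.
  const config = locals[endpoint] ?? (await getCustomEndpointConfig(endpoint));
  if (config?.titleModel && config.titleModel !== CURRENT_MODEL) {
    return config.titleModel; // an explicit titleModel wins
  }
  return undefined; // otherwise keep the conversation's current model
}
```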


@@ -45,7 +45,10 @@ async function createRun({
   /** @type {'reasoning_content' | 'reasoning'} */
   let reasoningKey;
-  if (llmConfig.configuration?.baseURL?.includes(KnownEndpoints.openrouter)) {
+  if (
+    llmConfig.configuration?.baseURL?.includes(KnownEndpoints.openrouter) ||
+    (agent.endpoint && agent.endpoint.toLowerCase().includes(KnownEndpoints.openrouter))
+  ) {
     reasoningKey = 'reasoning';
   }
   if (/o1(?!-(?:mini|preview)).*$/.test(llmConfig.model)) {


@@ -101,6 +101,7 @@ const initializeAgentOptions = async ({
   });

   const provider = agent.provider;
+  agent.endpoint = provider;
   let getOptions = providerConfigMap[provider];
   if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) {
     agent.provider = provider.toLowerCase();
@@ -112,9 +113,7 @@
     }
     getOptions = initCustom;
     agent.provider = Providers.OPENAI;
-    agent.endpoint = provider.toLowerCase();
   }

   const model_parameters = Object.assign(
     {},
     agent.model_parameters ?? { model: agent.model },
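The hunk above changes where `agent.endpoint` is assigned: it is now set unconditionally from the incoming provider, while `agent.provider` may still be normalized for custom endpoints. A sketch of the resulting state, under the assumption that 'OpenRouter' is configured as a custom endpoint:

```ts
// Sketch of the provider/endpoint split after initializeAgentOptions.
const agent: { provider: string; endpoint?: string } = { provider: 'OpenRouter' };

agent.endpoint = agent.provider; // always recorded, per the first hunk
const matchedBuiltInProvider = false; // assume no built-in config matched
if (!matchedBuiltInProvider) {
  agent.provider = 'openAI'; // Providers.OPENAI in the real code
}
// agent is now { provider: 'openAI', endpoint: 'OpenRouter' }, so later
// checks like createRun's reasoningKey logic can still see the endpoint name.
```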


@@ -27,6 +27,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
   if (anthropicConfig) {
     clientOptions.streamRate = anthropicConfig.streamRate;
+    clientOptions.titleModel = anthropicConfig.titleModel;
   }

   /** @type {undefined | TBaseEndpoint} */


@ -1,6 +1,5 @@
const { removeNullishValues, bedrockInputParser } = require('librechat-data-provider'); const { removeNullishValues } = require('librechat-data-provider');
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
const { logger } = require('~/config');
const buildOptions = (endpoint, parsedBody) => { const buildOptions = (endpoint, parsedBody) => {
const { const {
@ -15,12 +14,6 @@ const buildOptions = (endpoint, parsedBody) => {
artifacts, artifacts,
...model_parameters ...model_parameters
} = parsedBody; } = parsedBody;
let parsedParams = model_parameters;
try {
parsedParams = bedrockInputParser.parse(model_parameters);
} catch (error) {
logger.warn('Failed to parse bedrock input', error);
}
const endpointOption = removeNullishValues({ const endpointOption = removeNullishValues({
endpoint, endpoint,
name, name,
@ -31,7 +24,7 @@ const buildOptions = (endpoint, parsedBody) => {
spec, spec,
promptPrefix, promptPrefix,
maxContextTokens, maxContextTokens,
model_parameters: parsedParams, model_parameters,
}); });
if (typeof artifacts === 'string') { if (typeof artifacts === 'string') {


@ -1,14 +1,16 @@
const { HttpsProxyAgent } = require('https-proxy-agent'); const { HttpsProxyAgent } = require('https-proxy-agent');
const { const {
EModelEndpoint,
Constants,
AuthType, AuthType,
Constants,
EModelEndpoint,
bedrockInputParser,
bedrockOutputParser,
removeNullishValues, removeNullishValues,
} = require('librechat-data-provider'); } = require('librechat-data-provider');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const { sleep } = require('~/server/utils'); const { sleep } = require('~/server/utils');
const getOptions = async ({ req, endpointOption }) => { const getOptions = async ({ req, overrideModel, endpointOption }) => {
const { const {
BEDROCK_AWS_SECRET_ACCESS_KEY, BEDROCK_AWS_SECRET_ACCESS_KEY,
BEDROCK_AWS_ACCESS_KEY_ID, BEDROCK_AWS_ACCESS_KEY_ID,
@ -62,11 +64,31 @@ const getOptions = async ({ req, endpointOption }) => {
/** @type {BedrockClientOptions} */ /** @type {BedrockClientOptions} */
const requestOptions = { const requestOptions = {
model: endpointOption.model, model: overrideModel ?? endpointOption.model,
region: BEDROCK_AWS_DEFAULT_REGION, region: BEDROCK_AWS_DEFAULT_REGION,
streaming: true, };
streamUsage: true,
callbacks: [ const configOptions = {};
if (PROXY) {
/** NOTE: NOT SUPPORTED BY BEDROCK */
configOptions.httpAgent = new HttpsProxyAgent(PROXY);
}
const llmConfig = bedrockOutputParser(
bedrockInputParser.parse(
removeNullishValues(Object.assign(requestOptions, endpointOption.model_parameters)),
),
);
if (credentials) {
llmConfig.credentials = credentials;
}
if (BEDROCK_REVERSE_PROXY) {
llmConfig.endpointHost = BEDROCK_REVERSE_PROXY;
}
llmConfig.callbacks = [
{ {
handleLLMNewToken: async () => { handleLLMNewToken: async () => {
if (!streamRate) { if (!streamRate) {
@ -75,26 +97,11 @@ const getOptions = async ({ req, endpointOption }) => {
await sleep(streamRate); await sleep(streamRate);
}, },
}, },
], ];
};
if (credentials) {
requestOptions.credentials = credentials;
}
if (BEDROCK_REVERSE_PROXY) {
requestOptions.endpointHost = BEDROCK_REVERSE_PROXY;
}
const configOptions = {};
if (PROXY) {
/** NOTE: NOT SUPPORTED BY BEDROCK */
configOptions.httpAgent = new HttpsProxyAgent(PROXY);
}
return { return {
/** @type {BedrockClientOptions} */ /** @type {BedrockClientOptions} */
llmConfig: removeNullishValues(Object.assign(requestOptions, endpointOption.model_parameters)), llmConfig,
configOptions, configOptions,
}; };
}; };
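The rewritten getOptions builds `llmConfig` by running the merged options through both Bedrock parsers before attaching credentials, the reverse-proxy host, and stream callbacks. A condensed sketch of that pipeline; the parser names come from the diff, and the option values are placeholders:

```ts
import {
  bedrockInputParser,
  bedrockOutputParser,
  removeNullishValues,
} from 'librechat-data-provider';

// Placeholder inputs standing in for endpointOption/env-derived values.
const requestOptions = {
  model: 'anthropic.claude-3-7-sonnet-20250219-v1:0',
  region: 'us-east-1',
};
const modelParameters = { maxTokens: 8192, thinking: true, thinkingBudget: 2000 };

// 1) merge defaults with user params, 2) validate/normalize the input,
// 3) emit the provider-native output shape (thinking config, maxTokens, etc.).
const llmConfig = bedrockOutputParser(
  bedrockInputParser.parse(
    removeNullishValues(Object.assign(requestOptions, modelParameters)),
  ),
) as Record<string, unknown>;
// credentials, endpointHost, and callbacks are attached to llmConfig afterwards.
```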


@@ -141,7 +141,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
     },
     clientOptions,
   );
-  const options = getLLMConfig(apiKey, clientOptions);
+  const options = getLLMConfig(apiKey, clientOptions, endpoint);
   if (!customOptions.streamRate) {
     return options;
   }


@@ -5,12 +5,7 @@ const { isEnabled } = require('~/server/utils');
 const { GoogleClient } = require('~/app');

 const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => {
-  const {
-    GOOGLE_KEY,
-    GOOGLE_REVERSE_PROXY,
-    GOOGLE_AUTH_HEADER,
-    PROXY,
-  } = process.env;
+  const { GOOGLE_KEY, GOOGLE_REVERSE_PROXY, GOOGLE_AUTH_HEADER, PROXY } = process.env;
   const isUserProvided = GOOGLE_KEY === 'user_provided';
   const { key: expiresAt } = req.body;
@@ -43,6 +38,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
   if (googleConfig) {
     clientOptions.streamRate = googleConfig.streamRate;
+    clientOptions.titleModel = googleConfig.titleModel;
   }

   if (allConfig) {


@@ -113,6 +113,7 @@ const initializeClient = async ({
   if (!isAzureOpenAI && openAIConfig) {
     clientOptions.streamRate = openAIConfig.streamRate;
+    clientOptions.titleModel = openAIConfig.titleModel;
   }

   /** @type {undefined | TBaseEndpoint} */


@@ -23,9 +23,10 @@ const { isEnabled } = require('~/server/utils');
  * @param {boolean} [options.streaming] - Whether to use streaming mode.
  * @param {Object} [options.addParams] - Additional parameters to add to the model options.
  * @param {string[]} [options.dropParams] - Parameters to remove from the model options.
+ * @param {string|null} [endpoint=null] - The endpoint name
  * @returns {Object} Configuration options for creating an LLM instance.
  */
-function getLLMConfig(apiKey, options = {}) {
+function getLLMConfig(apiKey, options = {}, endpoint = null) {
   const {
     modelOptions = {},
     reverseProxyUrl,
@@ -58,7 +59,10 @@ function getLLMConfig(apiKey, options = {}) {
   let useOpenRouter;
   /** @type {OpenAIClientOptions['configuration']} */
   const configOptions = {};
-  if (reverseProxyUrl && reverseProxyUrl.includes(KnownEndpoints.openrouter)) {
+  if (
+    (reverseProxyUrl && reverseProxyUrl.includes(KnownEndpoints.openrouter)) ||
+    (endpoint && endpoint.toLowerCase().includes(KnownEndpoints.openrouter))
+  ) {
     useOpenRouter = true;
     llmConfig.include_reasoning = true;
     configOptions.baseURL = reverseProxyUrl;
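Because detection now also keys off the endpoint name, a custom endpoint simply named 'OpenRouter' opts into `include_reasoning` even when its reverse-proxy URL does not contain 'openrouter'. The check in isolation, assuming `KnownEndpoints.openrouter === 'openrouter'`:

```ts
const openrouter = 'openrouter'; // assumed value of KnownEndpoints.openrouter

function isOpenRouter(reverseProxyUrl?: string | null, endpoint?: string | null): boolean {
  return Boolean(
    (reverseProxyUrl && reverseProxyUrl.includes(openrouter)) ||
      (endpoint && endpoint.toLowerCase().includes(openrouter)),
  );
}

isOpenRouter('https://openrouter.ai/api/v1', null); // true, as before (URL match)
isOpenRouter(null, 'OpenRouter'); // true, new in this commit (endpoint name match)
isOpenRouter(null, 'my-proxy'); // false
```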


@@ -553,8 +553,10 @@ const bedrockAnthropic: SettingsConfiguration = [
   bedrock.topP,
   bedrock.topK,
   baseDefinitions.stop,
-  bedrock.region,
   librechat.resendFiles,
+  bedrock.region,
+  anthropic.thinking,
+  anthropic.thinkingBudget,
 ];

 const bedrockMistral: SettingsConfiguration = [
@@ -564,8 +566,8 @@ const bedrockMistral: SettingsConfiguration = [
   bedrock.maxTokens,
   mistral.temperature,
   mistral.topP,
-  bedrock.region,
   librechat.resendFiles,
+  bedrock.region,
 ];

 const bedrockCohere: SettingsConfiguration = [
@@ -575,8 +577,8 @@ const bedrockCohere: SettingsConfiguration = [
   bedrock.maxTokens,
   cohere.temperature,
   cohere.topP,
-  bedrock.region,
   librechat.resendFiles,
+  bedrock.region,
 ];

 const bedrockGeneral: SettingsConfiguration = [
@@ -585,8 +587,8 @@ const bedrockGeneral: SettingsConfiguration = [
   librechat.maxContextTokens,
   meta.temperature,
   meta.topP,
-  bedrock.region,
   librechat.resendFiles,
+  bedrock.region,
 ];

 const bedrockAnthropicCol1: SettingsConfiguration = [
@@ -602,8 +604,10 @@ const bedrockAnthropicCol2: SettingsConfiguration = [
   bedrock.temperature,
   bedrock.topP,
   bedrock.topK,
-  bedrock.region,
   librechat.resendFiles,
+  bedrock.region,
+  anthropic.thinking,
+  anthropic.thinkingBudget,
 ];

 const bedrockMistralCol1: SettingsConfiguration = [
@@ -617,8 +621,8 @@ const bedrockMistralCol2: SettingsConfiguration = [
   bedrock.maxTokens,
   mistral.temperature,
   mistral.topP,
-  bedrock.region,
   librechat.resendFiles,
+  bedrock.region,
 ];

 const bedrockCohereCol1: SettingsConfiguration = [
@@ -632,8 +636,8 @@ const bedrockCohereCol2: SettingsConfiguration = [
   bedrock.maxTokens,
   cohere.temperature,
   cohere.topP,
-  bedrock.region,
   librechat.resendFiles,
+  bedrock.region,
 ];

 const bedrockGeneralCol1: SettingsConfiguration = [
@@ -646,8 +650,8 @@ const bedrockGeneralCol2: SettingsConfiguration = [
   librechat.maxContextTokens,
   meta.temperature,
   meta.topP,
-  bedrock.region,
   librechat.resendFiles,
+  bedrock.region,
 ];

 export const settings: Record<string, SettingsConfiguration | undefined> = {

package-lock.json (generated, 2636 changed lines): file diff suppressed because it is too large.


@@ -1,6 +1,20 @@
 import { z } from 'zod';
 import * as s from './schemas';

+type ThinkingConfig = {
+  type: 'enabled';
+  budget_tokens: number;
+};
+
+type AnthropicReasoning = {
+  thinking?: ThinkingConfig | boolean;
+  thinkingBudget?: number;
+};
+
+type AnthropicInput = BedrockConverseInput & {
+  additionalModelRequestFields: BedrockConverseInput['additionalModelRequestFields'] &
+    AnthropicReasoning;
+};
+
 export const bedrockInputSchema = s.tConversationSchema
   .pick({
     /* LibreChat params; optionType: 'conversation' */
@@ -21,11 +35,24 @@ export const bedrockInputSchema = s.tConversationSchema
     temperature: true,
     topP: true,
     stop: true,
+    thinking: true,
+    thinkingBudget: true,
     /* Catch-all fields */
     topK: true,
     additionalModelRequestFields: true,
   })
-  .transform((obj) => s.removeNullishValues(obj))
+  .transform((obj) => {
+    if ((obj as AnthropicInput).additionalModelRequestFields?.thinking != null) {
+      const _obj = obj as AnthropicInput;
+      obj.thinking = !!_obj.additionalModelRequestFields.thinking;
+      obj.thinkingBudget =
+        typeof _obj.additionalModelRequestFields.thinking === 'object'
+          ? (_obj.additionalModelRequestFields.thinking as ThinkingConfig)?.budget_tokens
+          : undefined;
+      delete obj.additionalModelRequestFields;
+    }
+    return s.removeNullishValues(obj);
+  })
   .catch(() => ({}));

 export type BedrockConverseInput = z.infer<typeof bedrockInputSchema>;
@@ -49,6 +76,8 @@ export const bedrockInputParser = s.tConversationSchema
     temperature: true,
     topP: true,
     stop: true,
+    thinking: true,
+    thinkingBudget: true,
     /* Catch-all fields */
     topK: true,
     additionalModelRequestFields: true,
@@ -87,6 +116,27 @@ export const bedrockInputParser = s.tConversationSchema
     }
   });

+  /** Default thinking and thinkingBudget for 'anthropic.claude-3-7-sonnet' models, if not defined */
+  if (
+    typeof typedData.model === 'string' &&
+    typedData.model.includes('anthropic.claude-3-7-sonnet')
+  ) {
+    if (additionalFields.thinking === undefined) {
+      additionalFields.thinking = true;
+    } else if (additionalFields.thinking === false) {
+      delete additionalFields.thinking;
+      delete additionalFields.thinkingBudget;
+    }
+    if (additionalFields.thinking === true && additionalFields.thinkingBudget === undefined) {
+      additionalFields.thinkingBudget = 2000;
+    }
+    additionalFields.anthropic_beta = ['output-128k-2025-02-19'];
+  } else if (additionalFields.thinking != null || additionalFields.thinkingBudget != null) {
+    delete additionalFields.thinking;
+    delete additionalFields.thinkingBudget;
+  }
+
   if (Object.keys(additionalFields).length > 0) {
     typedData.additionalModelRequestFields = {
       ...((typedData.additionalModelRequestFields as Record<string, unknown> | undefined) || {}),
@@ -104,9 +154,34 @@ export const bedrockInputParser = s.tConversationSchema
   })
   .catch(() => ({}));

+/**
+ * Configures the "thinking" parameter based on given input and thinking options.
+ *
+ * @param data - The parsed Bedrock request options object
+ * @returns The object with thinking configured appropriately
+ */
+function configureThinking(data: AnthropicInput): AnthropicInput {
+  const updatedData = { ...data };
+  if (updatedData.additionalModelRequestFields?.thinking === true) {
+    updatedData.maxTokens = updatedData.maxTokens ?? updatedData.maxOutputTokens ?? 8192;
+    delete updatedData.maxOutputTokens;
+    const thinkingConfig: AnthropicReasoning['thinking'] = {
+      type: 'enabled',
+      budget_tokens: updatedData.additionalModelRequestFields.thinkingBudget ?? 2000,
+    };
+    if (thinkingConfig.budget_tokens > updatedData.maxTokens) {
+      thinkingConfig.budget_tokens = Math.floor(updatedData.maxTokens * 0.9);
+    }
+    updatedData.additionalModelRequestFields.thinking = thinkingConfig;
+    delete updatedData.additionalModelRequestFields.thinkingBudget;
+  }
+  return updatedData;
+}
+
 export const bedrockOutputParser = (data: Record<string, unknown>) => {
   const knownKeys = [...Object.keys(s.tConversationSchema.shape), 'topK', 'top_k'];
-  const result: Record<string, unknown> = {};
+  let result: Record<string, unknown> = {};

   // Extract known fields from the root level
   Object.entries(data).forEach(([key, value]) => {
@@ -125,6 +200,8 @@ export const bedrockOutputParser = (data: Record<string, unknown>) => {
     if (knownKeys.includes(key)) {
       if (key === 'top_k') {
         result['topK'] = value;
+      } else if (key === 'thinking' || key === 'thinkingBudget') {
+        return;
       } else {
         result[key] = value;
       }
@@ -140,8 +217,11 @@ export const bedrockOutputParser = (data: Record<string, unknown>) => {
     result.maxTokens = result.maxOutputTokens;
   }

-  // Remove additionalModelRequestFields from the result
+  result = configureThinking(result as AnthropicInput);
+  // Remove additionalModelRequestFields from the result if it doesn't have a thinking config
+  if ((result as AnthropicInput).additionalModelRequestFields?.thinking == null) {
     delete result.additionalModelRequestFields;
+  }

   return result;
 };
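Finally, a worked example of the clamp in configureThinking. With thinking enabled, maxTokens defaults to 8192 when unset, and a thinkingBudget larger than maxTokens is cut to 90% of it; the numbers below follow directly from the code above:

```ts
// Worked example of configureThinking's budget clamp.
const maxTokens = 8192; // default applied when thinking is enabled and none is set
const thinkingBudget = 10000; // user asked for more budget than maxTokens allows

const budget_tokens =
  thinkingBudget > maxTokens ? Math.floor(maxTokens * 0.9) : thinkingBudget;
// Math.floor(8192 * 0.9) === 7372, so the outgoing request carries:
// additionalModelRequestFields.thinking = { type: 'enabled', budget_tokens: 7372 }
```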