LibreChat/packages/data-provider/src/bedrock.ts
Danny Avila 1a1e6850a3
🪨 fix: Minor AWS Bedrock/Misc. Improvements (#3974)
* refactor(EditMessage): avoid manipulation of native paste handling, leverage react-hook-form for textarea changes

* style: apply better theming for MinimalIcon

* fix(useVoicesQuery/useCustomConfigSpeechQuery): make sure to only try request once per render

* feat: edit message content parts

* fix(useCopyToClipboard): handle both assistants and agents content blocks

* refactor: remove save & submit and update text content correctly

* chore(.env.example/config): exclude unsupported bedrock models

* feat: artifacts for aws bedrock

* fix: export options for bedrock conversations
2024-09-10 12:56:19 -04:00

147 lines
3.9 KiB
TypeScript

import { z } from 'zod';
import * as s from './schemas';
/**
 * Zod schema describing AWS Bedrock (Converse) input, built by picking a subset of
 * `tConversationSchema`: LibreChat conversation-level options plus Bedrock model params.
 *
 * The parsed object is passed through `s.removeNullishValues` (presumably strips
 * null/undefined entries — confirm against `./schemas`). Any validation failure is
 * swallowed by `.catch`, so parsing never throws and yields `{}` instead.
 */
export const bedrockInputSchema = s.tConversationSchema
.pick({
/* LibreChat params; optionType: 'conversation' */
modelLabel: true,
promptPrefix: true,
resendFiles: true,
iconURL: true,
greeting: true,
spec: true,
maxOutputTokens: true,
maxContextTokens: true,
artifacts: true,
/* Bedrock params; optionType: 'model' */
region: true,
system: true,
model: true,
maxTokens: true,
temperature: true,
topP: true,
stop: true,
/* Catch-all fields */
topK: true,
additionalModelRequestFields: true,
})
.transform(s.removeNullishValues)
// Fall back to an empty object rather than surfacing validation errors.
.catch(() => ({}));
/** Parsed Bedrock Converse input shape, inferred from `bedrockInputSchema`. */
export type BedrockConverseInput = z.infer<typeof bedrockInputSchema>;
/**
 * Parses a conversation/options object into AWS Bedrock (Converse) input.
 *
 * - Recognized LibreChat and Bedrock params stay at the root of the result.
 * - Every other key (accepted via `.catchall`) is removed from the root and merged into
 *   `additionalModelRequestFields`, overriding existing entries with the same name.
 *   `topK` is forwarded there under the snake_case name `top_k`.
 * - `maxOutputTokens` and `maxTokens` are kept in sync; `maxOutputTokens` wins when both are set.
 * - The result is passed through `s.removeNullishValues`; any parse failure yields `{}`.
 */
export const bedrockInputParser = s.tConversationSchema
  .pick({
    /* LibreChat params; optionType: 'conversation' */
    modelLabel: true,
    promptPrefix: true,
    resendFiles: true,
    iconURL: true,
    greeting: true,
    spec: true,
    artifacts: true,
    maxOutputTokens: true,
    maxContextTokens: true,
    /* Bedrock params; optionType: 'model' */
    region: true,
    model: true,
    maxTokens: true,
    temperature: true,
    topP: true,
    stop: true,
    /* Catch-all fields */
    topK: true,
    additionalModelRequestFields: true,
  })
  .catchall(z.any())
  .transform((data) => {
    // Keys that remain at the root; anything else is forwarded to the model
    // via `additionalModelRequestFields`.
    const knownKeys = new Set<string>([
      'modelLabel',
      'promptPrefix',
      'resendFiles',
      'iconURL',
      'greeting',
      'spec',
      'maxOutputTokens',
      // LibreChat-only setting (picked above as a conversation param); it must not
      // leak into `additionalModelRequestFields`.
      'maxContextTokens',
      'artifacts',
      'additionalModelRequestFields',
      'region',
      'model',
      'maxTokens',
      'temperature',
      'topP',
      'stop',
    ]);
    const additionalFields: Record<string, unknown> = {};
    const typedData = data as Record<string, unknown>;
    // Object.entries returns a snapshot, so deleting keys while iterating is safe.
    for (const [key, value] of Object.entries(typedData)) {
      if (knownKeys.has(key)) {
        continue;
      }
      additionalFields[key === 'topK' ? 'top_k' : key] = value;
      delete typedData[key];
    }
    if (Object.keys(additionalFields).length > 0) {
      typedData.additionalModelRequestFields = {
        ...((typedData.additionalModelRequestFields as Record<string, unknown> | undefined) || {}),
        ...additionalFields,
      };
    }
    // Mirror the token limit so both names are populated; maxOutputTokens takes precedence.
    if (typedData.maxOutputTokens !== undefined) {
      typedData.maxTokens = typedData.maxOutputTokens;
    } else if (typedData.maxTokens !== undefined) {
      typedData.maxOutputTokens = typedData.maxTokens;
    }
    return s.removeNullishValues(typedData) as BedrockConverseInput;
  })
  .catch(() => ({}));
/**
 * Converts stored Bedrock options back into conversation fields.
 *
 * Root keys belonging to `tConversationSchema` (plus `topK` / `top_k`) are copied first.
 * Matching keys inside `additionalModelRequestFields` are then lifted to the root, with
 * `top_k` renamed to `topK`, and override any root value of the same name. When only one
 * of `maxTokens` / `maxOutputTokens` is set, it is mirrored into the other.
 * `additionalModelRequestFields` itself is never part of the result.
 *
 * @param data - Options object to convert; it is only read, never mutated.
 * @returns A new object holding only the recognized fields.
 */
export const bedrockOutputParser = (data: Record<string, unknown>) => {
  const allowed = new Set<string>([
    ...Object.keys(s.tConversationSchema.shape),
    'topK',
    'top_k',
  ]);
  const output: Record<string, unknown> = {};

  // Root-level fields.
  for (const [field, fieldValue] of Object.entries(data)) {
    if (allowed.has(field)) {
      output[field] = fieldValue;
    }
  }

  // Fields nested in additionalModelRequestFields take priority over root-level ones.
  const extra = data.additionalModelRequestFields;
  if (extra !== null && typeof extra === 'object') {
    for (const [field, fieldValue] of Object.entries(extra as Record<string, unknown>)) {
      if (!allowed.has(field)) {
        continue;
      }
      output[field === 'top_k' ? 'topK' : field] = fieldValue;
    }
  }

  const hasMaxTokens = output.maxTokens !== undefined;
  const hasMaxOutputTokens = output.maxOutputTokens !== undefined;
  if (hasMaxTokens && !hasMaxOutputTokens) {
    output.maxOutputTokens = output.maxTokens;
  } else if (hasMaxOutputTokens && !hasMaxTokens) {
    output.maxTokens = output.maxOutputTokens;
  }

  delete output.additionalModelRequestFields;
  return output;
};