🔑 fix: Azure Serverless Support for API Key Header & Version (#4791)

* fix: azure validation/extraction types

* fix: typing, add optional chaining for modelGroup and groupMap properties; expect azureOpenAIApiVersion in serverless tests

* fix: add support for azureOpenAIApiVersion and api-key in serverless mode across clients

* chore: update CONFIG_VERSION to 1.1.8, data-provider bump
This commit is contained in:
Danny Avila 2024-11-25 13:33:06 -05:00 committed by GitHub
parent 07511b3db8
commit e0a5f879b6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 100 additions and 42 deletions

View file

@ -227,6 +227,16 @@ class ChatGPTClient extends BaseClient {
this.azure = !serverless && azureOptions; this.azure = !serverless && azureOptions;
this.azureEndpoint = this.azureEndpoint =
!serverless && genAzureChatCompletion(this.azure, modelOptions.model, this); !serverless && genAzureChatCompletion(this.azure, modelOptions.model, this);
if (serverless === true) {
this.options.defaultQuery = azureOptions.azureOpenAIApiVersion
? { 'api-version': azureOptions.azureOpenAIApiVersion }
: undefined;
this.options.headers['api-key'] = this.apiKey;
}
}
if (this.options.defaultQuery) {
opts.defaultQuery = this.options.defaultQuery;
} }
if (this.options.headers) { if (this.options.headers) {

View file

@ -838,6 +838,12 @@ class OpenAIClient extends BaseClient {
this.options.dropParams = azureConfig.groupMap[groupName].dropParams; this.options.dropParams = azureConfig.groupMap[groupName].dropParams;
this.options.forcePrompt = azureConfig.groupMap[groupName].forcePrompt; this.options.forcePrompt = azureConfig.groupMap[groupName].forcePrompt;
this.azure = !serverless && azureOptions; this.azure = !serverless && azureOptions;
if (serverless === true) {
this.options.defaultQuery = azureOptions.azureOpenAIApiVersion
? { 'api-version': azureOptions.azureOpenAIApiVersion }
: undefined;
this.options.headers['api-key'] = this.apiKey;
}
} }
const titleChatCompletion = async () => { const titleChatCompletion = async () => {
@ -1169,6 +1175,10 @@ ${convo}
opts.defaultHeaders = { ...opts.defaultHeaders, ...this.options.headers }; opts.defaultHeaders = { ...opts.defaultHeaders, ...this.options.headers };
} }
if (this.options.defaultQuery) {
opts.defaultQuery = this.options.defaultQuery;
}
if (this.options.proxy) { if (this.options.proxy) {
opts.httpAgent = new HttpsProxyAgent(this.options.proxy); opts.httpAgent = new HttpsProxyAgent(this.options.proxy);
} }
@ -1207,6 +1217,12 @@ ${convo}
this.azure = !serverless && azureOptions; this.azure = !serverless && azureOptions;
this.azureEndpoint = this.azureEndpoint =
!serverless && genAzureChatCompletion(this.azure, modelOptions.model, this); !serverless && genAzureChatCompletion(this.azure, modelOptions.model, this);
if (serverless === true) {
this.options.defaultQuery = azureOptions.azureOpenAIApiVersion
? { 'api-version': azureOptions.azureOpenAIApiVersion }
: undefined;
this.options.headers['api-key'] = this.apiKey;
}
} }
if (this.azure || this.options.azure) { if (this.azure || this.options.azure) {

View file

@ -135,6 +135,12 @@ const initializeClient = async ({ req, res, version, endpointOption, initAppClie
clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl; clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl;
clientOptions.headers = opts.defaultHeaders; clientOptions.headers = opts.defaultHeaders;
clientOptions.azure = !serverless && azureOptions; clientOptions.azure = !serverless && azureOptions;
if (serverless === true) {
clientOptions.defaultQuery = azureOptions.azureOpenAIApiVersion
? { 'api-version': azureOptions.azureOpenAIApiVersion }
: undefined;
clientOptions.headers['api-key'] = apiKey;
}
} }
} }

View file

@ -96,6 +96,12 @@ const initializeClient = async ({ req, res, endpointOption }) => {
apiKey = azureOptions.azureOpenAIApiKey; apiKey = azureOptions.azureOpenAIApiKey;
clientOptions.azure = !serverless && azureOptions; clientOptions.azure = !serverless && azureOptions;
if (serverless === true) {
clientOptions.defaultQuery = azureOptions.azureOpenAIApiVersion
? { 'api-version': azureOptions.azureOpenAIApiVersion }
: undefined;
clientOptions.headers['api-key'] = apiKey;
}
} else if (useAzure || (apiKey && apiKey.includes('{"azure') && !clientOptions.azure)) { } else if (useAzure || (apiKey && apiKey.includes('{"azure') && !clientOptions.azure)) {
clientOptions.azure = userProvidesKey ? JSON.parse(userValues.apiKey) : getAzureCredentials(); clientOptions.azure = userProvidesKey ? JSON.parse(userValues.apiKey) : getAzureCredentials();
apiKey = clientOptions.azure.azureOpenAIApiKey; apiKey = clientOptions.azure.azureOpenAIApiKey;

View file

@ -97,6 +97,12 @@ const initializeClient = async ({
apiKey = azureOptions.azureOpenAIApiKey; apiKey = azureOptions.azureOpenAIApiKey;
clientOptions.azure = !serverless && azureOptions; clientOptions.azure = !serverless && azureOptions;
if (serverless === true) {
clientOptions.defaultQuery = azureOptions.azureOpenAIApiVersion
? { 'api-version': azureOptions.azureOpenAIApiVersion }
: undefined;
clientOptions.headers['api-key'] = apiKey;
}
} else if (isAzureOpenAI) { } else if (isAzureOpenAI) {
clientOptions.azure = userProvidesKey ? JSON.parse(userValues.apiKey) : getAzureCredentials(); clientOptions.azure = userProvidesKey ? JSON.parse(userValues.apiKey) : getAzureCredentials();
apiKey = clientOptions.azure.azureOpenAIApiKey; apiKey = clientOptions.azure.azureOpenAIApiKey;

View file

@ -29,6 +29,7 @@ function getLLMConfig(apiKey, options = {}) {
modelOptions = {}, modelOptions = {},
reverseProxyUrl, reverseProxyUrl,
useOpenRouter, useOpenRouter,
defaultQuery,
headers, headers,
proxy, proxy,
azure, azure,
@ -74,6 +75,10 @@ function getLLMConfig(apiKey, options = {}) {
} }
} }
if (defaultQuery) {
configOptions.baseOptions.defaultQuery = defaultQuery;
}
if (proxy) { if (proxy) {
const proxyAgent = new HttpsProxyAgent(proxy); const proxyAgent = new HttpsProxyAgent(proxy);
Object.assign(configOptions, { Object.assign(configOptions, {

2
package-lock.json generated
View file

@ -36137,7 +36137,7 @@
}, },
"packages/data-provider": { "packages/data-provider": {
"name": "librechat-data-provider", "name": "librechat-data-provider",
"version": "0.7.55", "version": "0.7.56",
"license": "ISC", "license": "ISC",
"dependencies": { "dependencies": {
"@types/js-yaml": "^4.0.9", "@types/js-yaml": "^4.0.9",

View file

@ -1,6 +1,6 @@
{ {
"name": "librechat-data-provider", "name": "librechat-data-provider",
"version": "0.7.55", "version": "0.7.56",
"description": "data services for librechat apps", "description": "data services for librechat apps",
"main": "dist/index.js", "main": "dist/index.js",
"module": "dist/index.es.js", "module": "dist/index.es.js",

View file

@ -94,8 +94,8 @@ describe('validateAzureGroups', () => {
expect(isValid).toBe(true); expect(isValid).toBe(true);
const modelGroup = modelGroupMap['gpt-5-turbo']; const modelGroup = modelGroupMap['gpt-5-turbo'];
expect(modelGroup).toBeDefined(); expect(modelGroup).toBeDefined();
expect(modelGroup.group).toBe('japan-east'); expect(modelGroup?.group).toBe('japan-east');
expect(groupMap[modelGroup.group]).toBeDefined(); expect(groupMap[modelGroup?.group ?? '']).toBeDefined();
expect(modelNames).toContain('gpt-5-turbo'); expect(modelNames).toContain('gpt-5-turbo');
const { azureOptions } = mapModelToAzureConfig({ const { azureOptions } = mapModelToAzureConfig({
modelName: 'gpt-5-turbo', modelName: 'gpt-5-turbo',
@ -323,6 +323,7 @@ describe('validateAzureGroups for Serverless Configurations', () => {
expect(azureOptions).toEqual({ expect(azureOptions).toEqual({
azureOpenAIApiKey: 'def456', azureOpenAIApiKey: 'def456',
azureOpenAIApiVersion: '',
}); });
expect(baseURL).toEqual('https://new-serverless.example.com/v1/completions'); expect(baseURL).toEqual('https://new-serverless.example.com/v1/completions');
expect(serverless).toBe(true); expect(serverless).toBe(true);
@ -381,10 +382,10 @@ describe('validateAzureGroups with modelGroupMap and groupMap', () => {
const { isValid, modelGroupMap, groupMap } = validateAzureGroups(validConfigs); const { isValid, modelGroupMap, groupMap } = validateAzureGroups(validConfigs);
expect(isValid).toBe(true); expect(isValid).toBe(true);
expect(modelGroupMap['gpt-4-turbo']).toBeDefined(); expect(modelGroupMap['gpt-4-turbo']).toBeDefined();
expect(modelGroupMap['gpt-4-turbo'].group).toBe('us-east'); expect(modelGroupMap['gpt-4-turbo']?.group).toBe('us-east');
expect(groupMap['us-east']).toBeDefined(); expect(groupMap['us-east']).toBeDefined();
expect(groupMap['us-east'].apiKey).toBe('prod-1234'); expect(groupMap['us-east']?.apiKey).toBe('prod-1234');
expect(groupMap['us-east'].models['gpt-4-turbo']).toBeDefined(); expect(groupMap['us-east']?.models['gpt-4-turbo']).toBeDefined();
const { azureOptions, baseURL, headers } = mapModelToAzureConfig({ const { azureOptions, baseURL, headers } = mapModelToAzureConfig({
modelName: 'gpt-4-turbo', modelName: 'gpt-4-turbo',
modelGroupMap, modelGroupMap,
@ -765,6 +766,7 @@ describe('validateAzureGroups with modelGroupMap and groupMap', () => {
); );
expect(azureOptions7).toEqual({ expect(azureOptions7).toEqual({
azureOpenAIApiKey: 'mistral-key', azureOpenAIApiKey: 'mistral-key',
azureOpenAIApiVersion: '',
}); });
const { const {
@ -782,6 +784,7 @@ describe('validateAzureGroups with modelGroupMap and groupMap', () => {
); );
expect(azureOptions8).toEqual({ expect(azureOptions8).toEqual({
azureOpenAIApiKey: 'llama-key', azureOpenAIApiKey: 'llama-key',
azureOpenAIApiVersion: '',
}); });
}); });
}); });

View file

@ -63,13 +63,13 @@ export function validateAzureGroups(configs: TAzureGroups): TAzureConfigValidati
const { const {
group: groupName, group: groupName,
apiKey, apiKey,
instanceName, instanceName = '',
deploymentName, deploymentName = '',
version, version = '',
baseURL, baseURL = '',
additionalHeaders, additionalHeaders,
models, models,
serverless, serverless = false,
...rest ...rest
} = group; } = group;
@ -120,9 +120,11 @@ export function validateAzureGroups(configs: TAzureGroups): TAzureConfigValidati
continue; continue;
} }
const groupDeploymentName = group.deploymentName ?? '';
const groupVersion = group.version ?? '';
if (typeof model === 'boolean') { if (typeof model === 'boolean') {
// For boolean models, check if group-level deploymentName and version are present. // For boolean models, check if group-level deploymentName and version are present.
if (!group.deploymentName || !group.version) { if (!groupDeploymentName || !groupVersion) {
errors.push( errors.push(
`Model "${modelName}" in group "${groupName}" is missing a deploymentName or version.`, `Model "${modelName}" in group "${groupName}" is missing a deploymentName or version.`,
); );
@ -133,11 +135,10 @@ export function validateAzureGroups(configs: TAzureGroups): TAzureConfigValidati
group: groupName, group: groupName,
}; };
} else { } else {
const modelDeploymentName = model.deploymentName ?? '';
const modelVersion = model.version ?? '';
// For object models, check if deploymentName and version are required but missing. // For object models, check if deploymentName and version are required but missing.
if ( if ((!modelDeploymentName && !groupDeploymentName) || (!modelVersion && !groupVersion)) {
(!model.deploymentName && !group.deploymentName) ||
(!model.version && !group.version)
) {
errors.push( errors.push(
`Model "${modelName}" in group "${groupName}" is missing a required deploymentName or version.`, `Model "${modelName}" in group "${groupName}" is missing a required deploymentName or version.`,
); );
@ -146,8 +147,8 @@ export function validateAzureGroups(configs: TAzureGroups): TAzureConfigValidati
modelGroupMap[modelName] = { modelGroupMap[modelName] = {
group: groupName, group: groupName,
// deploymentName: model.deploymentName || group.deploymentName, // deploymentName: modelDeploymentName || groupDeploymentName,
// version: model.version || group.version, // version: modelVersion || groupVersion,
}; };
} }
} }
@ -190,26 +191,28 @@ export function mapModelToAzureConfig({
); );
} }
const instanceName = groupConfig.instanceName; const instanceName = groupConfig.instanceName ?? '';
if (!instanceName && !groupConfig.serverless) { if (!instanceName && groupConfig.serverless !== true) {
throw new Error( throw new Error(
`Group "${modelConfig.group}" is missing an instanceName for non-serverless configuration.`, `Group "${modelConfig.group}" is missing an instanceName for non-serverless configuration.`,
); );
} }
if (groupConfig.serverless && !groupConfig.baseURL) { const baseURL = groupConfig.baseURL ?? '';
if (groupConfig.serverless === true && !baseURL) {
throw new Error( throw new Error(
`Group "${modelConfig.group}" is missing the required base URL for serverless configuration.`, `Group "${modelConfig.group}" is missing the required base URL for serverless configuration.`,
); );
} }
if (groupConfig.serverless) { if (groupConfig.serverless === true) {
const result: MappedAzureConfig = { const result: MappedAzureConfig = {
azureOptions: { azureOptions: {
azureOpenAIApiVersion: extractEnvVariable(groupConfig.version ?? ''),
azureOpenAIApiKey: extractEnvVariable(groupConfig.apiKey), azureOpenAIApiKey: extractEnvVariable(groupConfig.apiKey),
}, },
baseURL: extractEnvVariable(groupConfig.baseURL as string), baseURL: extractEnvVariable(baseURL),
serverless: true, serverless: true,
}; };
@ -232,11 +235,11 @@ export function mapModelToAzureConfig({
} }
const modelDetails = groupConfig.models[modelName]; const modelDetails = groupConfig.models[modelName];
const { deploymentName, version } = const { deploymentName = '', version = '' } =
typeof modelDetails === 'object' typeof modelDetails === 'object'
? { ? {
deploymentName: modelDetails.deploymentName || groupConfig.deploymentName, deploymentName: modelDetails.deploymentName ?? groupConfig.deploymentName,
version: modelDetails.version || groupConfig.version, version: modelDetails.version ?? groupConfig.version,
} }
: { : {
deploymentName: groupConfig.deploymentName, deploymentName: groupConfig.deploymentName,
@ -264,8 +267,8 @@ export function mapModelToAzureConfig({
const result: MappedAzureConfig = { azureOptions }; const result: MappedAzureConfig = { azureOptions };
if (groupConfig.baseURL) { if (baseURL) {
result.baseURL = extractEnvVariable(groupConfig.baseURL); result.baseURL = extractEnvVariable(baseURL);
} }
if (groupConfig.additionalHeaders) { if (groupConfig.additionalHeaders) {
@ -287,15 +290,17 @@ export function mapGroupToAzureConfig({
throw new Error(`Group named "${groupName}" not found in configuration.`); throw new Error(`Group named "${groupName}" not found in configuration.`);
} }
const instanceName = groupConfig.instanceName as string; const instanceName = groupConfig.instanceName ?? '';
const serverless = groupConfig.serverless ?? false;
const baseURL = groupConfig.baseURL ?? '';
if (!instanceName && !groupConfig.serverless) { if (!instanceName && !serverless) {
throw new Error( throw new Error(
`Group "${groupName}" is missing an instanceName for non-serverless configuration.`, `Group "${groupName}" is missing an instanceName for non-serverless configuration.`,
); );
} }
if (groupConfig.serverless && !groupConfig.baseURL) { if (serverless && !baseURL) {
throw new Error( throw new Error(
`Group "${groupName}" is missing the required base URL for serverless configuration.`, `Group "${groupName}" is missing the required base URL for serverless configuration.`,
); );
@ -311,25 +316,26 @@ export function mapGroupToAzureConfig({
const modelDetails = groupConfig.models[firstModelName]; const modelDetails = groupConfig.models[firstModelName];
const azureOptions: AzureOptions = { const azureOptions: AzureOptions = {
azureOpenAIApiVersion: extractEnvVariable(groupConfig.version ?? ''),
azureOpenAIApiKey: extractEnvVariable(groupConfig.apiKey), azureOpenAIApiKey: extractEnvVariable(groupConfig.apiKey),
azureOpenAIApiInstanceName: extractEnvVariable(instanceName), azureOpenAIApiInstanceName: extractEnvVariable(instanceName),
// DeploymentName and Version set below // DeploymentName and Version set below
}; };
if (groupConfig.serverless) { if (serverless) {
return { return {
azureOptions, azureOptions,
baseURL: extractEnvVariable(groupConfig.baseURL ?? ''), baseURL: extractEnvVariable(baseURL),
serverless: true, serverless: true,
...(groupConfig.additionalHeaders && { headers: groupConfig.additionalHeaders }), ...(groupConfig.additionalHeaders && { headers: groupConfig.additionalHeaders }),
}; };
} }
const { deploymentName, version } = const { deploymentName = '', version = '' } =
typeof modelDetails === 'object' typeof modelDetails === 'object'
? { ? {
deploymentName: modelDetails.deploymentName || groupConfig.deploymentName, deploymentName: modelDetails.deploymentName ?? groupConfig.deploymentName,
version: modelDetails.version || groupConfig.version, version: modelDetails.version ?? groupConfig.version,
} }
: { : {
deploymentName: groupConfig.deploymentName, deploymentName: groupConfig.deploymentName,
@ -347,8 +353,8 @@ export function mapGroupToAzureConfig({
const result: MappedAzureConfig = { azureOptions }; const result: MappedAzureConfig = { azureOptions };
if (groupConfig.baseURL) { if (baseURL) {
result.baseURL = extractEnvVariable(groupConfig.baseURL); result.baseURL = extractEnvVariable(baseURL);
} }
if (groupConfig.additionalHeaders) { if (groupConfig.additionalHeaders) {

View file

@ -114,10 +114,10 @@ export type TAzureModelMapSchema = {
group: string; group: string;
}; };
export type TAzureModelGroupMap = Record<string, TAzureModelMapSchema>; export type TAzureModelGroupMap = Record<string, TAzureModelMapSchema | undefined>;
export type TAzureGroupMap = Record< export type TAzureGroupMap = Record<
string, string,
TAzureBaseSchema & { models: Record<string, TAzureModelConfig> } (TAzureBaseSchema & { models: Record<string, TAzureModelConfig | undefined> }) | undefined
>; >;
export type TValidatedAzureConfig = { export type TValidatedAzureConfig = {
@ -1080,7 +1080,7 @@ export enum Constants {
/** Key for the app's version. */ /** Key for the app's version. */
VERSION = 'v0.7.5', VERSION = 'v0.7.5',
/** Key for the Custom Config's version (librechat.yaml). */ /** Key for the Custom Config's version (librechat.yaml). */
CONFIG_VERSION = '1.1.7', CONFIG_VERSION = '1.1.8',
/** Standard value for the first message's `parentMessageId` value, to indicate no parent exists. */ /** Standard value for the first message's `parentMessageId` value, to indicate no parent exists. */
NO_PARENT = '00000000-0000-0000-0000-000000000000', NO_PARENT = '00000000-0000-0000-0000-000000000000',
/** Standard value for the initial conversationId before a request is sent */ /** Standard value for the initial conversationId before a request is sent */