mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-18 01:10:14 +01:00
🅰️ feat: Azure AI Studio, Models as a Service Support (#1902)
* feat(data-provider): add Azure serverless inference handling through librechat.yaml * feat(azureOpenAI): serverless inference handling in api * docs: update docs with new azureOpenAI endpoint config fields and serverless inference endpoint setup * chore: remove unnecessary checks for apiKey as schema would not allow apiKey to be undefined * ci(azureOpenAI): update tests for serverless configurations
This commit is contained in:
parent
6d6b3c9c1d
commit
08d4b3cc8a
9 changed files with 460 additions and 26 deletions
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "librechat-data-provider",
|
||||
"version": "0.4.4",
|
||||
"version": "0.4.5",
|
||||
"description": "data services for librechat apps",
|
||||
"main": "dist/index.js",
|
||||
"module": "dist/index.es.js",
|
||||
|
|
|
|||
|
|
@ -188,13 +188,147 @@ describe('validateAzureGroups', () => {
|
|||
},
|
||||
},
|
||||
];
|
||||
// @ts-expect-error This error is expected because the 'instanceName' property is intentionally left out.
|
||||
const { isValid, errors } = validateAzureGroups(configs);
|
||||
expect(isValid).toBe(false);
|
||||
expect(errors.length).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateAzureGroups for Serverless Configurations', () => {
|
||||
const originalEnv = process.env;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.resetModules();
|
||||
process.env = { ...originalEnv };
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
it('should validate a correct serverless configuration', () => {
|
||||
const configs = [
|
||||
{
|
||||
group: 'serverless-group',
|
||||
apiKey: '${SERVERLESS_API_KEY}',
|
||||
baseURL: 'https://serverless.example.com/v1/completions',
|
||||
serverless: true,
|
||||
models: {
|
||||
'model-serverless': true,
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const { isValid, errors } = validateAzureGroups(configs);
|
||||
|
||||
expect(isValid).toBe(true);
|
||||
expect(errors.length).toBe(0);
|
||||
});
|
||||
|
||||
it('should return invalid for a serverless configuration missing baseURL', () => {
|
||||
const configs = [
|
||||
{
|
||||
group: 'serverless-group',
|
||||
apiKey: '${SERVERLESS_API_KEY}',
|
||||
serverless: true,
|
||||
models: {
|
||||
'model-serverless': true,
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const { isValid, errors } = validateAzureGroups(configs);
|
||||
expect(isValid).toBe(false);
|
||||
expect(errors).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.stringContaining(
|
||||
'Group "serverless-group" is serverless but missing mandatory "baseURL."',
|
||||
),
|
||||
]),
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error when environment variable for apiKey is not set', () => {
|
||||
process.env.SERVERLESS_API_KEY = '';
|
||||
|
||||
expect(() => {
|
||||
mapModelToAzureConfig({
|
||||
modelName: 'model-serverless',
|
||||
modelGroupMap: {
|
||||
'model-serverless': {
|
||||
group: 'serverless-group',
|
||||
},
|
||||
},
|
||||
groupMap: {
|
||||
'serverless-group': {
|
||||
apiKey: '${SERVERLESS_API_KEY}',
|
||||
baseURL: 'https://serverless.example.com/v1/completions',
|
||||
serverless: true,
|
||||
models: { 'model-serverless': true },
|
||||
},
|
||||
},
|
||||
});
|
||||
}).toThrow('Azure configuration environment variable "${SERVERLESS_API_KEY}" was not found.');
|
||||
});
|
||||
|
||||
it('should correctly extract environment variables and prepare serverless config', () => {
|
||||
process.env.SERVERLESS_API_KEY = 'abc123';
|
||||
|
||||
const { azureOptions, baseURL, serverless } = mapModelToAzureConfig({
|
||||
modelName: 'model-serverless',
|
||||
modelGroupMap: {
|
||||
'model-serverless': {
|
||||
group: 'serverless-group',
|
||||
},
|
||||
},
|
||||
groupMap: {
|
||||
'serverless-group': {
|
||||
apiKey: '${SERVERLESS_API_KEY}',
|
||||
baseURL: 'https://serverless.example.com/v1/completions',
|
||||
serverless: true,
|
||||
models: { 'model-serverless': true },
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
expect(azureOptions.azureOpenAIApiKey).toEqual('abc123');
|
||||
expect(baseURL).toEqual('https://serverless.example.com/v1/completions');
|
||||
expect(serverless).toBe(true);
|
||||
});
|
||||
|
||||
it('should ensure serverless flag triggers appropriate validations and mappings', () => {
|
||||
const configs = [
|
||||
{
|
||||
group: 'serverless-group-2',
|
||||
apiKey: '${NEW_SERVERLESS_API_KEY}',
|
||||
baseURL: 'https://new-serverless.example.com/v1/completions',
|
||||
serverless: true,
|
||||
models: {
|
||||
'new-model-serverless': true,
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
process.env.NEW_SERVERLESS_API_KEY = 'def456';
|
||||
|
||||
const { isValid, errors, modelGroupMap, groupMap } = validateAzureGroups(configs);
|
||||
expect(isValid).toBe(true);
|
||||
expect(errors.length).toBe(0);
|
||||
|
||||
const { azureOptions, baseURL, serverless } = mapModelToAzureConfig({
|
||||
modelName: 'new-model-serverless',
|
||||
modelGroupMap,
|
||||
groupMap,
|
||||
});
|
||||
|
||||
expect(azureOptions).toEqual({
|
||||
azureOpenAIApiKey: 'def456',
|
||||
});
|
||||
expect(baseURL).toEqual('https://new-serverless.example.com/v1/completions');
|
||||
expect(serverless).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateAzureGroups with modelGroupMap and groupMap', () => {
|
||||
const originalEnv = process.env;
|
||||
|
||||
|
|
@ -396,6 +530,8 @@ describe('validateAzureGroups with modelGroupMap and groupMap', () => {
|
|||
it('should list all expected models in both modelGroupMap and groupMap', () => {
|
||||
process.env.WESTUS_API_KEY = 'westus-key';
|
||||
process.env.EASTUS_API_KEY = 'eastus-key';
|
||||
process.env.AZURE_MISTRAL_API_KEY = 'mistral-key';
|
||||
process.env.AZURE_LLAMA2_70B_API_KEY = 'llama-key';
|
||||
|
||||
const validConfigs: TAzureGroups = [
|
||||
{
|
||||
|
|
@ -436,6 +572,26 @@ describe('validateAzureGroups with modelGroupMap and groupMap', () => {
|
|||
'x-api-key': 'x-api-key-value',
|
||||
},
|
||||
},
|
||||
{
|
||||
group: 'mistral-inference',
|
||||
apiKey: '${AZURE_MISTRAL_API_KEY}',
|
||||
baseURL:
|
||||
'https://Mistral-large-vnpet-serverless.region.inference.ai.azure.com/v1/chat/completions',
|
||||
serverless: true,
|
||||
models: {
|
||||
'mistral-large': true,
|
||||
},
|
||||
},
|
||||
{
|
||||
group: 'llama-70b-chat',
|
||||
apiKey: '${AZURE_LLAMA2_70B_API_KEY}',
|
||||
baseURL:
|
||||
'https://Llama-2-70b-chat-qmvyb-serverless.region.inference.ai.azure.com/v1/chat/completions',
|
||||
serverless: true,
|
||||
models: {
|
||||
'llama-70b-chat': true,
|
||||
},
|
||||
},
|
||||
];
|
||||
const { isValid, modelGroupMap, groupMap, modelNames } = validateAzureGroups(validConfigs);
|
||||
expect(isValid).toBe(true);
|
||||
|
|
@ -446,6 +602,8 @@ describe('validateAzureGroups with modelGroupMap and groupMap', () => {
|
|||
'gpt-4',
|
||||
'gpt-4-1106-preview',
|
||||
'gpt-4-turbo',
|
||||
'mistral-large',
|
||||
'llama-70b-chat',
|
||||
]);
|
||||
|
||||
// Check modelGroupMap
|
||||
|
|
@ -484,6 +642,34 @@ describe('validateAzureGroups with modelGroupMap and groupMap', () => {
|
|||
}),
|
||||
);
|
||||
|
||||
// Check groupMap for 'mistral-inference'
|
||||
expect(groupMap).toHaveProperty('mistral-inference');
|
||||
expect(groupMap['mistral-inference']).toEqual(
|
||||
expect.objectContaining({
|
||||
apiKey: '${AZURE_MISTRAL_API_KEY}',
|
||||
baseURL:
|
||||
'https://Mistral-large-vnpet-serverless.region.inference.ai.azure.com/v1/chat/completions',
|
||||
serverless: true,
|
||||
models: expect.objectContaining({
|
||||
'mistral-large': true,
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
// Check groupMap for 'llama-70b-chat'
|
||||
expect(groupMap).toHaveProperty('llama-70b-chat');
|
||||
expect(groupMap['llama-70b-chat']).toEqual(
|
||||
expect.objectContaining({
|
||||
apiKey: '${AZURE_LLAMA2_70B_API_KEY}',
|
||||
baseURL:
|
||||
'https://Llama-2-70b-chat-qmvyb-serverless.region.inference.ai.azure.com/v1/chat/completions',
|
||||
serverless: true,
|
||||
models: expect.objectContaining({
|
||||
'llama-70b-chat': true,
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
const { azureOptions: azureOptions1 } = mapModelToAzureConfig({
|
||||
modelName: 'gpt-4-vision-preview',
|
||||
modelGroupMap,
|
||||
|
|
@ -563,5 +749,39 @@ describe('validateAzureGroups with modelGroupMap and groupMap', () => {
|
|||
azureOpenAIApiDeploymentName: 'gpt-4-1106-preview',
|
||||
azureOpenAIApiVersion: '2023-12-01-preview',
|
||||
});
|
||||
|
||||
const {
|
||||
azureOptions: azureOptions7,
|
||||
serverless: serverlessMistral,
|
||||
baseURL: mistralEndpoint,
|
||||
} = mapModelToAzureConfig({
|
||||
modelName: 'mistral-large',
|
||||
modelGroupMap,
|
||||
groupMap,
|
||||
});
|
||||
expect(serverlessMistral).toBe(true);
|
||||
expect(mistralEndpoint).toBe(
|
||||
'https://Mistral-large-vnpet-serverless.region.inference.ai.azure.com/v1/chat/completions',
|
||||
);
|
||||
expect(azureOptions7).toEqual({
|
||||
azureOpenAIApiKey: 'mistral-key',
|
||||
});
|
||||
|
||||
const {
|
||||
azureOptions: azureOptions8,
|
||||
serverless: serverlessLlama,
|
||||
baseURL: llamaEndpoint,
|
||||
} = mapModelToAzureConfig({
|
||||
modelName: 'llama-70b-chat',
|
||||
modelGroupMap,
|
||||
groupMap,
|
||||
});
|
||||
expect(serverlessLlama).toBe(true);
|
||||
expect(llamaEndpoint).toBe(
|
||||
'https://Llama-2-70b-chat-qmvyb-serverless.region.inference.ai.azure.com/v1/chat/completions',
|
||||
);
|
||||
expect(azureOptions8).toEqual({
|
||||
azureOpenAIApiKey: 'llama-key',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -71,6 +71,8 @@ export function validateAzureGroups(configs: TAzureGroups): TValidatedAzureConfi
|
|||
baseURL,
|
||||
additionalHeaders,
|
||||
models,
|
||||
serverless,
|
||||
...rest
|
||||
} = group;
|
||||
|
||||
if (groupMap[groupName]) {
|
||||
|
|
@ -78,6 +80,18 @@ export function validateAzureGroups(configs: TAzureGroups): TValidatedAzureConfi
|
|||
return { isValid: false, modelNames, modelGroupMap, groupMap, errors };
|
||||
}
|
||||
|
||||
if (serverless && !baseURL) {
|
||||
errors.push(`Group "${groupName}" is serverless but missing mandatory "baseURL."`);
|
||||
return { isValid: false, modelNames, modelGroupMap, groupMap, errors };
|
||||
}
|
||||
|
||||
if (!instanceName && !serverless) {
|
||||
errors.push(
|
||||
`Group "${groupName}" is missing an "instanceName" for non-serverless configuration.`,
|
||||
);
|
||||
return { isValid: false, modelNames, modelGroupMap, groupMap, errors };
|
||||
}
|
||||
|
||||
groupMap[groupName] = {
|
||||
apiKey,
|
||||
instanceName,
|
||||
|
|
@ -86,6 +100,8 @@ export function validateAzureGroups(configs: TAzureGroups): TValidatedAzureConfi
|
|||
baseURL,
|
||||
additionalHeaders,
|
||||
models,
|
||||
serverless,
|
||||
...rest,
|
||||
};
|
||||
|
||||
for (const modelName in group.models) {
|
||||
|
|
@ -99,6 +115,13 @@ export function validateAzureGroups(configs: TAzureGroups): TValidatedAzureConfi
|
|||
return { isValid: false, modelNames, modelGroupMap, groupMap, errors };
|
||||
}
|
||||
|
||||
if (serverless) {
|
||||
modelGroupMap[modelName] = {
|
||||
group: groupName,
|
||||
};
|
||||
continue;
|
||||
}
|
||||
|
||||
if (typeof model === 'boolean') {
|
||||
// For boolean models, check if group-level deploymentName and version are present.
|
||||
if (!group.deploymentName || !group.version) {
|
||||
|
|
@ -138,15 +161,16 @@ export function validateAzureGroups(configs: TAzureGroups): TValidatedAzureConfi
|
|||
|
||||
type AzureOptions = {
|
||||
azureOpenAIApiKey: string;
|
||||
azureOpenAIApiInstanceName: string;
|
||||
azureOpenAIApiDeploymentName: string;
|
||||
azureOpenAIApiVersion: string;
|
||||
azureOpenAIApiInstanceName?: string;
|
||||
azureOpenAIApiDeploymentName?: string;
|
||||
azureOpenAIApiVersion?: string;
|
||||
};
|
||||
|
||||
type MappedAzureConfig = {
|
||||
azureOptions: AzureOptions;
|
||||
baseURL?: string;
|
||||
headers?: Record<string, string>;
|
||||
serverless?: boolean;
|
||||
};
|
||||
|
||||
export function mapModelToAzureConfig({
|
||||
|
|
@ -168,6 +192,47 @@ export function mapModelToAzureConfig({
|
|||
);
|
||||
}
|
||||
|
||||
const instanceName = groupConfig.instanceName;
|
||||
|
||||
if (!instanceName && !groupConfig.serverless) {
|
||||
throw new Error(
|
||||
`Group "${modelConfig.group}" is missing an instanceName for non-serverless configuration.`,
|
||||
);
|
||||
}
|
||||
|
||||
if (groupConfig.serverless && !groupConfig.baseURL) {
|
||||
throw new Error(
|
||||
`Group "${modelConfig.group}" is missing the required base URL for serverless configuration.`,
|
||||
);
|
||||
}
|
||||
|
||||
if (groupConfig.serverless) {
|
||||
const result: MappedAzureConfig = {
|
||||
azureOptions: {
|
||||
azureOpenAIApiKey: extractEnvVariable(groupConfig.apiKey),
|
||||
},
|
||||
baseURL: extractEnvVariable(groupConfig.baseURL as string),
|
||||
serverless: true,
|
||||
};
|
||||
|
||||
const apiKeyValue = result.azureOptions.azureOpenAIApiKey;
|
||||
if (typeof apiKeyValue === 'string' && envVarRegex.test(apiKeyValue)) {
|
||||
throw new Error(`Azure configuration environment variable "${apiKeyValue}" was not found.`);
|
||||
}
|
||||
|
||||
if (groupConfig.additionalHeaders) {
|
||||
result.headers = groupConfig.additionalHeaders;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
if (!instanceName) {
|
||||
throw new Error(
|
||||
`Group "${modelConfig.group}" is missing an instanceName for non-serverless configuration.`,
|
||||
);
|
||||
}
|
||||
|
||||
const modelDetails = groupConfig.models[modelName];
|
||||
const deploymentName =
|
||||
typeof modelDetails === 'object'
|
||||
|
|
@ -186,7 +251,7 @@ export function mapModelToAzureConfig({
|
|||
|
||||
const azureOptions: AzureOptions = {
|
||||
azureOpenAIApiKey: extractEnvVariable(groupConfig.apiKey),
|
||||
azureOpenAIApiInstanceName: extractEnvVariable(groupConfig.instanceName),
|
||||
azureOpenAIApiInstanceName: extractEnvVariable(instanceName),
|
||||
azureOpenAIApiDeploymentName: extractEnvVariable(deploymentName),
|
||||
azureOpenAIApiVersion: extractEnvVariable(version),
|
||||
};
|
||||
|
|
|
|||
|
|
@ -19,8 +19,12 @@ export type TAzureModelConfig = z.infer<typeof modelConfigSchema>;
|
|||
|
||||
export const azureBaseSchema = z.object({
|
||||
apiKey: z.string(),
|
||||
instanceName: z.string(),
|
||||
serverless: z.boolean().optional(),
|
||||
instanceName: z.string().optional(),
|
||||
deploymentName: z.string().optional(),
|
||||
addParams: z.record(z.any()).optional(),
|
||||
dropParams: z.array(z.string()).optional(),
|
||||
forcePrompt: z.boolean().optional(),
|
||||
version: z.string().optional(),
|
||||
baseURL: z.string().optional(),
|
||||
additionalHeaders: z.record(z.any()).optional(),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue