diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index f4c42351e3..dee4224e18 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -6,8 +6,10 @@ const { Tokenizer, createFetch, resolveHeaders, + shouldUseEntraId, constructAzureURL, getModelMaxTokens, + getEntraIdAccessToken, genAzureChatCompletion, getModelMaxOutputTokens, createStreamEventHandlers, @@ -837,7 +839,14 @@ class OpenAIClient extends BaseClient { this.options.defaultQuery = azureOptions.azureOpenAIApiVersion ? { 'api-version': azureOptions.azureOpenAIApiVersion } : undefined; - this.options.headers['api-key'] = this.apiKey; + if (shouldUseEntraId()) { + this.options.headers = { + ...this.options.headers, + Authorization: `Bearer ${await getEntraIdAccessToken()}`, + }; + } else { + this.options.headers['api-key'] = this.apiKey; + } } } @@ -858,7 +867,14 @@ class OpenAIClient extends BaseClient { : this.azureEndpoint.split(/(? { + const actual = jest.requireActual('@librechat/api'); + return { + ...actual, + getEntraIdAccessToken: jest.fn(), + shouldUseEntraId: jest.fn(() => actual.shouldUseEntraId()), + }; +}); +const { getEntraIdAccessToken, shouldUseEntraId } = require('@librechat/api'); // Mock getUserKey since it's the only function we want to mock jest.mock('~/server/services/UserService', () => ({ @@ -428,4 +437,31 @@ describe('initializeClient', () => { expect(result.openAIApiKey).toBe('test'); expect(result.client.options.reverseProxyUrl).toBe('https://user-provided-url.com'); }); + + test('should use Entra ID authentication when AZURE_OPENAI_USE_ENTRA_ID is enabled', async () => { + shouldUseEntraId.mockReturnValue(true); + getEntraIdAccessToken.mockResolvedValue('entra-token'); + process.env.AZURE_OPENAI_USE_ENTRA_ID = 'true'; + process.env.AZURE_API_KEY = 'test-azure-api-key'; + process.env.AZURE_OPENAI_API_INSTANCE_NAME = 'test-instance'; + process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'test-deployment'; + process.env.AZURE_OPENAI_API_VERSION = '2024-12-01-preview'; + + const req = { + body: { + key: null, + endpoint: EModelEndpoint.azureOpenAI, + model: 'gpt-4-vision-preview', + }, + user: { id: '123' }, + app: { locals: {} }, + config: mockAppConfig, + }; + const res = {}; + const endpointOption = {}; + + const result = await initializeClient({ req, res, endpointOption }); + + expect(result.openAIApiKey).toBeTruthy(); + }); }); diff --git a/packages/api/src/endpoints/openai/config.spec.ts b/packages/api/src/endpoints/openai/config.spec.ts index 405873490f..82c6ec657b 100644 --- a/packages/api/src/endpoints/openai/config.spec.ts +++ b/packages/api/src/endpoints/openai/config.spec.ts @@ -1862,4 +1862,22 @@ describe('getOpenAIConfig', () => { }); }); }); + + describe('Entra ID Authentication', () => { + it('should handle Entra ID authentication in Azure configuration', () => { + const azure = { + azureOpenAIApiInstanceName: 'test-instance', + azureOpenAIApiDeploymentName: 'test-deployment', + azureOpenAIApiVersion: '2023-05-15', + azureOpenAIApiKey: 'entra-id-placeholder', + }; + + const result = getOpenAIConfig(mockApiKey, { azure }); + + expect(result.llmConfig).toMatchObject({ + ...azure, + model: 'test-deployment', + }); + }); + }); }); diff --git a/packages/api/src/endpoints/openai/initialize.ts b/packages/api/src/endpoints/openai/initialize.ts index 9b1c5dd131..d2ba469fdf 100644 --- a/packages/api/src/endpoints/openai/initialize.ts +++ b/packages/api/src/endpoints/openai/initialize.ts @@ -5,7 +5,7 @@ import type { LLMConfigResult, UserKeyValues, } from '~/types'; -import { getAzureCredentials } from '~/utils/azure'; +import { getAzureCredentials, getEntraIdAccessToken, shouldUseEntraId } from '~/utils/azure'; import { isUserProvided } from '~/utils/common'; import { resolveHeaders } from '~/utils/env'; import { getOpenAIConfig } from './config'; @@ -98,7 +98,7 @@ export const initializeOpenAI = async ({ clientOptions.dropParams = groupMap[groupName]?.dropParams; } - apiKey = azureOptions.azureOpenAIApiKey; + apiKey = shouldUseEntraId() ? 'entra-id-placeholder' : azureOptions.azureOpenAIApiKey; clientOptions.azure = !serverless ? azureOptions : undefined; if (serverless === true) { @@ -109,12 +109,30 @@ export const initializeOpenAI = async ({ if (!clientOptions.headers) { clientOptions.headers = {}; } - clientOptions.headers['api-key'] = apiKey; + if (shouldUseEntraId()) { + clientOptions.headers['Authorization'] = `Bearer ${await getEntraIdAccessToken()}`; + } else { + clientOptions.headers['api-key'] = apiKey || ''; + } + } else { + apiKey = azureOptions.azureOpenAIApiKey || ''; + clientOptions.azure = azureOptions; + if (shouldUseEntraId()) { + apiKey = 'entra-id-placeholder'; + clientOptions.headers['Authorization'] = `Bearer ${await getEntraIdAccessToken()}`; + } } } else if (isAzureOpenAI) { clientOptions.azure = userProvidesKey && userValues?.apiKey ? JSON.parse(userValues.apiKey) : getAzureCredentials(); - apiKey = clientOptions.azure ? clientOptions.azure.azureOpenAIApiKey : undefined; + if (shouldUseEntraId()) { + clientOptions.headers = { + ...clientOptions.headers, + Authorization: `Bearer ${await getEntraIdAccessToken()}`, + }; + } else { + apiKey = clientOptions.azure ? clientOptions.azure.azureOpenAIApiKey : undefined; + } } if (userProvidesKey && !apiKey) { diff --git a/packages/api/src/utils/azure.ts b/packages/api/src/utils/azure.ts index 1bbd0e29b2..896029ef34 100644 --- a/packages/api/src/utils/azure.ts +++ b/packages/api/src/utils/azure.ts @@ -1,5 +1,7 @@ import { isEnabled } from './common'; import type { AzureOptions, GenericClient } from '~/types'; +import { DefaultAzureCredential } from '@azure/identity'; +import { logger } from '@librechat/data-schemas'; /** * Sanitizes the model name to be used in the URL by removing or replacing disallowed characters. @@ -124,3 +126,43 @@ export function constructAzureURL({ return finalURL; } + +/** + * Checks if Entra ID authentication should be used based on environment variables. + * @returns {boolean} True if Entra ID authentication should be used + */ +export const shouldUseEntraId = (): boolean => { + return process.env.AZURE_OPENAI_USE_ENTRA_ID === 'true'; +}; + +/** + * Creates an Azure credential for Entra ID authentication. + * Uses DefaultAzureCredential which supports multiple authentication methods: + * - Managed Identity (when running in Azure) + * - Service Principal (when environment variables are set) + * - Azure CLI (for local development) + * - Visual Studio Code (for local development) + * + * @returns DefaultAzureCredential instance + */ + +export const createEntraIdCredential = (): DefaultAzureCredential => { + return new DefaultAzureCredential(); +}; + +/** + * Gets the access token for Entra ID authentication from azure/identity. + * @returns {Promise} The access token + */ +export const getEntraIdAccessToken = async (): Promise => { + try { + const credential = createEntraIdCredential(); + + const tokenResponse = await credential.getToken('https://cognitiveservices.azure.com/.default'); + + return tokenResponse.token; + } catch (error) { + logger.error('[ENTRA_ID_DEBUG] Failed to get Entra ID access token:', error); + throw error; + } +}; diff --git a/packages/data-provider/specs/azure.spec.ts b/packages/data-provider/specs/azure.spec.ts index 5628d3f24b..a4cfa9764d 100644 --- a/packages/data-provider/specs/azure.spec.ts +++ b/packages/data-provider/specs/azure.spec.ts @@ -842,3 +842,33 @@ describe('mapGroupToAzureConfig', () => { }).toThrow(`Group named "${groupName}" not found in configuration.`); }); }); + +describe('Entra ID Authentication', () => { + it('should handle Entra ID placeholder in Azure configuration', () => { + const configs = [ + { + group: 'entra-id-group', + apiKey: 'entra-id-placeholder', + instanceName: 'entra-instance', + deploymentName: 'entra-deployment', + version: '2024-12-01-preview', + models: { + 'gpt-4': { + deploymentName: 'gpt-4-deployment', + version: '2024-12-01-preview', + }, + }, + }, + ]; + const { isValid, modelNames, modelGroupMap, groupMap } = validateAzureGroups(configs); + expect(isValid).toBe(true); + expect(modelNames).toEqual(['gpt-4']); + + const { azureOptions } = mapModelToAzureConfig({ + modelName: 'gpt-4', + modelGroupMap, + groupMap, + }); + expect(azureOptions.azureOpenAIApiKey).toBe('entra-id-placeholder'); + }); +});