diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 2b254036c5..12d717c6dc 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -6,8 +6,10 @@ const { Tokenizer, createFetch, resolveHeaders, + shouldUseEntraId, constructAzureURL, getModelMaxTokens, + getEntraIdAccessToken, genAzureChatCompletion, getModelMaxOutputTokens, createStreamEventHandlers, @@ -614,7 +616,7 @@ class OpenAIClient extends BaseClient { return (reply ?? '').trim(); } - initializeLLM({ + async initializeLLM({ model = openAISettings.model.default, modelName, temperature = 0.2, @@ -753,7 +755,14 @@ class OpenAIClient extends BaseClient { this.options.defaultQuery = azureOptions.azureOpenAIApiVersion ? { 'api-version': azureOptions.azureOpenAIApiVersion } : undefined; - this.options.headers['api-key'] = this.apiKey; + if (shouldUseEntraId()) { + this.options.headers = { + ...this.options.headers, + Authorization: `Bearer ${await getEntraIdAccessToken()}`, + }; + } else { + this.options.headers['api-key'] = this.apiKey; + } } } @@ -812,7 +821,7 @@ ${convo} try { this.abortController = new AbortController(); - const llm = this.initializeLLM({ + const llm = await this.initializeLLM({ ...modelOptions, conversationId, context: 'title', @@ -961,7 +970,7 @@ ${convo} const initialPromptTokens = this.maxContextTokens - remainingContextTokens; logger.debug('[OpenAIClient] initialPromptTokens', initialPromptTokens); - const llm = this.initializeLLM({ + const llm = await this.initializeLLM({ model, temperature: 0.2, context: 'summary', @@ -1187,7 +1196,14 @@ ${convo} this.options.defaultQuery = azureOptions.azureOpenAIApiVersion ? 
{ 'api-version': azureOptions.azureOpenAIApiVersion } : undefined; - this.options.headers['api-key'] = this.apiKey; + if (shouldUseEntraId()) { + this.options.headers = { + ...this.options.headers, + Authorization: `Bearer ${await getEntraIdAccessToken()}`, + }; + } else { + this.options.headers['api-key'] = this.apiKey; + } } } @@ -1208,7 +1224,14 @@ ${convo} : this.azureEndpoint.split(/(? { expect(result.openAIApiKey).toBe('test'); expect(result.client.options.reverseProxyUrl).toBe('https://user-provided-url.com'); }); + + test('should use Entra ID authentication when AZURE_OPENAI_USE_ENTRA_ID is enabled', async () => { + process.env.AZURE_OPENAI_USE_ENTRA_ID = 'true'; + process.env.AZURE_API_KEY = 'test-azure-api-key'; + process.env.AZURE_OPENAI_API_INSTANCE_NAME = 'test-instance'; + process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'test-deployment'; + process.env.AZURE_OPENAI_API_VERSION = '2024-12-01-preview'; + + const req = { + body: { + key: null, + endpoint: EModelEndpoint.azureOpenAI, + model: 'gpt-4-vision-preview', + }, + user: { id: '123' }, + app: { locals: {} }, + config: mockAppConfig, + }; + const res = {}; + const endpointOption = {}; + + const result = await initializeClient({ req, res, endpointOption }); + + expect(result.openAIApiKey).toBeTruthy(); + }); }); diff --git a/packages/api/src/endpoints/openai/config.spec.ts b/packages/api/src/endpoints/openai/config.spec.ts index fa718f1043..0b7b5453a0 100644 --- a/packages/api/src/endpoints/openai/config.spec.ts +++ b/packages/api/src/endpoints/openai/config.spec.ts @@ -1549,4 +1549,22 @@ describe('getOpenAIConfig', () => { }); }); }); + + describe('Entra ID Authentication', () => { + it('should handle Entra ID authentication in Azure configuration', () => { + const azure = { + azureOpenAIApiInstanceName: 'test-instance', + azureOpenAIApiDeploymentName: 'test-deployment', + azureOpenAIApiVersion: '2023-05-15', + azureOpenAIApiKey: 'entra-id-placeholder', + }; + + const result = 
getOpenAIConfig(mockApiKey, { azure }); + + expect(result.llmConfig).toMatchObject({ + ...azure, + model: 'test-deployment', + }); + }); + }); }); diff --git a/packages/api/src/endpoints/openai/initialize.ts b/packages/api/src/endpoints/openai/initialize.ts index b313c28bf9..fa2e07a7f6 100644 --- a/packages/api/src/endpoints/openai/initialize.ts +++ b/packages/api/src/endpoints/openai/initialize.ts @@ -6,7 +6,7 @@ import type { UserKeyValues, } from '~/types'; import { createHandleLLMNewToken } from '~/utils/generators'; -import { getAzureCredentials } from '~/utils/azure'; +import { getAzureCredentials, getEntraIdAccessToken, shouldUseEntraId } from '~/utils/azure'; import { isUserProvided } from '~/utils/common'; import { resolveHeaders } from '~/utils/env'; import { getOpenAIConfig } from './config'; @@ -110,12 +110,30 @@ export const initializeOpenAI = async ({ if (!clientOptions.headers) { clientOptions.headers = {}; } - clientOptions.headers['api-key'] = apiKey; + if (shouldUseEntraId()) { + clientOptions.headers['Authorization'] = `Bearer ${await getEntraIdAccessToken()}`; + } else { + clientOptions.headers['api-key'] = apiKey || ''; + } + } else { + apiKey = azureOptions.azureOpenAIApiKey || ''; + clientOptions.azure = azureOptions; + if (shouldUseEntraId()) { + apiKey = 'entra-id-placeholder'; + clientOptions.headers['Authorization'] = `Bearer ${await getEntraIdAccessToken()}`; + } } } else if (isAzureOpenAI) { clientOptions.azure = userProvidesKey && userValues?.apiKey ? JSON.parse(userValues.apiKey) : getAzureCredentials(); - apiKey = clientOptions.azure ? clientOptions.azure.azureOpenAIApiKey : undefined; + if (shouldUseEntraId()) { + clientOptions.headers = { + ...clientOptions.headers, + Authorization: `Bearer ${await getEntraIdAccessToken()}`, + }; + } else { + apiKey = clientOptions.azure ? 
clientOptions.azure.azureOpenAIApiKey : undefined; + } } if (userProvidesKey && !apiKey) { diff --git a/packages/api/src/utils/azure.ts b/packages/api/src/utils/azure.ts index b4051d3d80..2821e8b0f9 100644 --- a/packages/api/src/utils/azure.ts +++ b/packages/api/src/utils/azure.ts @@ -1,5 +1,7 @@ import { isEnabled } from './common'; import type { AzureOptions, GenericClient } from '~/types'; +import { DefaultAzureCredential } from '@azure/identity'; +import { logger } from '@librechat/data-schemas'; /** * Sanitizes the model name to be used in the URL by removing or replacing disallowed characters. @@ -118,3 +120,43 @@ export function constructAzureURL({ return finalURL; } + +/** + * Checks if Entra ID authentication should be used based on environment variables. + * @returns {boolean} True if Entra ID authentication should be used + */ +export const shouldUseEntraId = (): boolean => { + return process.env.AZURE_OPENAI_USE_ENTRA_ID === 'true'; +}; + +/** + * Creates an Azure credential for Entra ID authentication. + * Uses DefaultAzureCredential which supports multiple authentication methods: + * - Managed Identity (when running in Azure) + * - Service Principal (when environment variables are set) + * - Azure CLI (for local development) + * - Visual Studio Code (for local development) + * + * @returns DefaultAzureCredential instance + */ + +export const createEntraIdCredential = (): DefaultAzureCredential => { + return new DefaultAzureCredential(); +}; + +/** + * Gets the access token for Entra ID authentication from azure/identity. 
+ * @returns {Promise<string>} The access token + */ +export const getEntraIdAccessToken = async (): Promise<string> => { + try { + const credential = createEntraIdCredential(); + + const tokenResponse = await credential.getToken('https://cognitiveservices.azure.com/.default'); + + return tokenResponse.token; + } catch (error) { + logger.error('[ENTRA_ID_DEBUG] Failed to get Entra ID access token:', error); + throw error; + } +}; diff --git a/packages/data-provider/specs/azure.spec.ts b/packages/data-provider/specs/azure.spec.ts index 5628d3f24b..a4cfa9764d 100644 --- a/packages/data-provider/specs/azure.spec.ts +++ b/packages/data-provider/specs/azure.spec.ts @@ -842,3 +842,33 @@ describe('mapGroupToAzureConfig', () => { }).toThrow(`Group named "${groupName}" not found in configuration.`); }); }); + +describe('Entra ID Authentication', () => { + it('should handle Entra ID placeholder in Azure configuration', () => { + const configs = [ + { + group: 'entra-id-group', + apiKey: 'entra-id-placeholder', + instanceName: 'entra-instance', + deploymentName: 'entra-deployment', + version: '2024-12-01-preview', + models: { + 'gpt-4': { + deploymentName: 'gpt-4-deployment', + version: '2024-12-01-preview', + }, + }, + }, + ]; + const { isValid, modelNames, modelGroupMap, groupMap } = validateAzureGroups(configs); + expect(isValid).toBe(true); + expect(modelNames).toEqual(['gpt-4']); + + const { azureOptions } = mapModelToAzureConfig({ + modelName: 'gpt-4', + modelGroupMap, + groupMap, + }); + expect(azureOptions.azureOpenAIApiKey).toBe('entra-id-placeholder'); + }); +});