diff --git a/api/server/services/Config/loadConfigModels.js b/api/server/services/Config/loadConfigModels.js index 26920459a4..aebd41e19a 100644 --- a/api/server/services/Config/loadConfigModels.js +++ b/api/server/services/Config/loadConfigModels.js @@ -1,6 +1,6 @@ const { EModelEndpoint, extractEnvVariable } = require('librechat-data-provider'); -const { isUserProvided } = require('~/server/utils'); const { fetchModels } = require('~/server/services/ModelService'); +const { isUserProvided } = require('~/server/utils'); const getCustomConfig = require('./getCustomConfig'); /** @@ -41,8 +41,8 @@ async function loadConfigModels(req) { (endpoint.models.fetch || endpoint.models.default), ); - const fetchPromisesMap = {}; // Map for promises keyed by baseURL - const baseUrlToNameMap = {}; // Map to associate baseURLs with names + const fetchPromisesMap = {}; // Map for promises keyed by unique combination of baseURL and apiKey + const uniqueKeyToNameMap = {}; // Map to associate unique keys with endpoint names for (let i = 0; i < customEndpoints.length; i++) { const endpoint = customEndpoints[i]; @@ -51,11 +51,13 @@ async function loadConfigModels(req) { const API_KEY = extractEnvVariable(apiKey); const BASE_URL = extractEnvVariable(baseURL); + const uniqueKey = `${BASE_URL}__${API_KEY}`; + modelsConfig[name] = []; if (models.fetch && !isUserProvided(API_KEY) && !isUserProvided(BASE_URL)) { - fetchPromisesMap[BASE_URL] = - fetchPromisesMap[BASE_URL] || + fetchPromisesMap[uniqueKey] = + fetchPromisesMap[uniqueKey] || fetchModels({ user: req.user.id, baseURL: BASE_URL, @@ -63,8 +65,8 @@ async function loadConfigModels(req) { name, userIdQuery: models.userIdQuery, }); - baseUrlToNameMap[BASE_URL] = baseUrlToNameMap[BASE_URL] || []; - baseUrlToNameMap[BASE_URL].push(name); + uniqueKeyToNameMap[uniqueKey] = uniqueKeyToNameMap[uniqueKey] || []; + uniqueKeyToNameMap[uniqueKey].push(name); continue; } @@ -74,12 +76,12 @@ async function loadConfigModels(req) { } const fetchedData = await Promise.all(Object.values(fetchPromisesMap)); - const baseUrls = Object.keys(fetchPromisesMap); + const uniqueKeys = Object.keys(fetchPromisesMap); for (let i = 0; i < fetchedData.length; i++) { - const currentBaseUrl = baseUrls[i]; + const currentKey = uniqueKeys[i]; const modelData = fetchedData[i]; - const associatedNames = baseUrlToNameMap[currentBaseUrl]; + const associatedNames = uniqueKeyToNameMap[currentKey]; for (const name of associatedNames) { modelsConfig[name] = modelData; diff --git a/api/server/services/Config/loadConfigModels.spec.js b/api/server/services/Config/loadConfigModels.spec.js new file mode 100644 index 0000000000..b49a0121de --- /dev/null +++ b/api/server/services/Config/loadConfigModels.spec.js @@ -0,0 +1,265 @@ +const { fetchModels } = require('~/server/services/ModelService'); +const loadConfigModels = require('./loadConfigModels'); +const getCustomConfig = require('./getCustomConfig'); + +jest.mock('~/server/services/ModelService'); +jest.mock('./getCustomConfig'); + +const exampleConfig = { + endpoints: { + custom: [ + { + name: 'Mistral', + apiKey: '${MY_PRECIOUS_MISTRAL_KEY}', + baseURL: 'https://api.mistral.ai/v1', + models: { + default: ['mistral-tiny', 'mistral-small', 'mistral-medium', 'mistral-large-latest'], + fetch: true, + }, + dropParams: ['stop', 'user', 'frequency_penalty', 'presence_penalty'], + }, + { + name: 'OpenRouter', + apiKey: '${MY_OPENROUTER_API_KEY}', + baseURL: 'https://openrouter.ai/api/v1', + models: { + default: ['gpt-3.5-turbo'], + fetch: true, + }, + dropParams: ['stop'], + }, + { + name: 'groq', + apiKey: 'user_provided', + baseURL: 'https://api.groq.com/openai/v1/', + models: { + default: ['llama2-70b-4096', 'mixtral-8x7b-32768'], + fetch: false, + }, + }, + { + name: 'Ollama', + apiKey: 'user_provided', + baseURL: 'http://localhost:11434/v1/', + models: { + default: ['mistral', 'llama2:13b'], + fetch: false, + }, + }, + ], + }, +}; + +describe('loadConfigModels', () => { + const mockRequest = { app: { locals: {} }, user: { id: 'testUserId' } }; + + const originalEnv = process.env; + + beforeEach(() => { + jest.resetAllMocks(); + jest.resetModules(); + process.env = { ...originalEnv }; + }); + + afterEach(() => { + process.env = originalEnv; + }); + + it('should return an empty object if customConfig is null', async () => { + getCustomConfig.mockResolvedValue(null); + const result = await loadConfigModels(mockRequest); + expect(result).toEqual({}); + }); + + it('handles azure models and endpoint correctly', async () => { + mockRequest.app.locals.azureOpenAI = { modelNames: ['model1', 'model2'] }; + getCustomConfig.mockResolvedValue({ + endpoints: { + azureOpenAI: { + models: ['model1', 'model2'], + }, + }, + }); + + const result = await loadConfigModels(mockRequest); + expect(result.azureOpenAI).toEqual(['model1', 'model2']); + }); + + it('fetches custom models based on the unique key', async () => { + process.env.BASE_URL = 'http://example.com'; + process.env.API_KEY = 'some-api-key'; + const customEndpoints = { + custom: [ + { + baseURL: '${BASE_URL}', + apiKey: '${API_KEY}', + name: 'CustomModel', + models: { fetch: true }, + }, + ], + }; + + getCustomConfig.mockResolvedValue({ endpoints: customEndpoints }); + fetchModels.mockResolvedValue(['customModel1', 'customModel2']); + + const result = await loadConfigModels(mockRequest); + expect(fetchModels).toHaveBeenCalled(); + expect(result.CustomModel).toEqual(['customModel1', 'customModel2']); + }); + + it('correctly associates models to names using unique keys', async () => { + getCustomConfig.mockResolvedValue({ + endpoints: { + custom: [ + { + baseURL: 'http://example.com', + apiKey: 'API_KEY1', + name: 'Model1', + models: { fetch: true }, + }, + { + baseURL: 'http://example.com', + apiKey: 'API_KEY2', + name: 'Model2', + models: { fetch: true }, + }, + ], + }, + }); + fetchModels.mockImplementation(({ apiKey }) => + Promise.resolve(apiKey === 'API_KEY1' ? ['model1Data'] : ['model2Data']), + ); + + const result = await loadConfigModels(mockRequest); + expect(result.Model1).toEqual(['model1Data']); + expect(result.Model2).toEqual(['model2Data']); + }); + + it('correctly handles multiple endpoints with the same baseURL but different apiKeys', async () => { + // Mock the custom configuration to simulate the user's scenario + getCustomConfig.mockResolvedValue({ + endpoints: { + custom: [ + { + name: 'LiteLLM', + apiKey: '${LITELLM_ALL_MODELS}', + baseURL: '${LITELLM_HOST}', + models: { fetch: true }, + }, + { + name: 'OpenAI', + apiKey: '${LITELLM_OPENAI_MODELS}', + baseURL: '${LITELLM_SECOND_HOST}', + models: { fetch: true }, + }, + { + name: 'Google', + apiKey: '${LITELLM_GOOGLE_MODELS}', + baseURL: '${LITELLM_SECOND_HOST}', + models: { fetch: true }, + }, + ], + }, + }); + + // Mock `fetchModels` to return different models based on the apiKey + fetchModels.mockImplementation(({ apiKey }) => { + switch (apiKey) { + case '${LITELLM_ALL_MODELS}': + return Promise.resolve(['AllModel1', 'AllModel2']); + case '${LITELLM_OPENAI_MODELS}': + return Promise.resolve(['OpenAIModel']); + case '${LITELLM_GOOGLE_MODELS}': + return Promise.resolve(['GoogleModel']); + default: + return Promise.resolve([]); + } + }); + + const result = await loadConfigModels(mockRequest); + + // Assert that the models are correctly fetched and mapped based on unique keys + expect(result.LiteLLM).toEqual(['AllModel1', 'AllModel2']); + expect(result.OpenAI).toEqual(['OpenAIModel']); + expect(result.Google).toEqual(['GoogleModel']); + + // Ensure that fetchModels was called with correct parameters + expect(fetchModels).toHaveBeenCalledTimes(3); + expect(fetchModels).toHaveBeenCalledWith( + expect.objectContaining({ apiKey: '${LITELLM_ALL_MODELS}' }), + ); + expect(fetchModels).toHaveBeenCalledWith( + expect.objectContaining({ apiKey: '${LITELLM_OPENAI_MODELS}' }), + ); + expect(fetchModels).toHaveBeenCalledWith( + expect.objectContaining({ apiKey: '${LITELLM_GOOGLE_MODELS}' }), + ); + }); + + it('loads models based on custom endpoint configuration respecting fetch rules', async () => { + process.env.MY_PRECIOUS_MISTRAL_KEY = 'actual_mistral_api_key'; + process.env.MY_OPENROUTER_API_KEY = 'actual_openrouter_api_key'; + // Setup custom configuration with specific API keys for Mistral and OpenRouter + // and "user_provided" for groq and Ollama, indicating no fetch for the latter two + getCustomConfig.mockResolvedValue(exampleConfig); + + // Assuming fetchModels would be called only for Mistral and OpenRouter + fetchModels.mockImplementation(({ name }) => { + switch (name) { + case 'Mistral': + return Promise.resolve([ + 'mistral-tiny', + 'mistral-small', + 'mistral-medium', + 'mistral-large-latest', + ]); + case 'OpenRouter': + return Promise.resolve(['gpt-3.5-turbo']); + default: + return Promise.resolve([]); + } + }); + + const result = await loadConfigModels(mockRequest); + + // Since fetch is true and apiKey is not "user_provided", fetching occurs for Mistral and OpenRouter + expect(result.Mistral).toEqual([ + 'mistral-tiny', + 'mistral-small', + 'mistral-medium', + 'mistral-large-latest', + ]); + expect(fetchModels).toHaveBeenCalledWith( + expect.objectContaining({ + name: 'Mistral', + apiKey: process.env.MY_PRECIOUS_MISTRAL_KEY, + }), + ); + + expect(result.OpenRouter).toEqual(['gpt-3.5-turbo']); + expect(fetchModels).toHaveBeenCalledWith( + expect.objectContaining({ + name: 'OpenRouter', + apiKey: process.env.MY_OPENROUTER_API_KEY, + }), + ); + + // For groq and Ollama, since the apiKey is "user_provided", models should not be fetched + // Depending on your implementation's behavior regarding "default" models without fetching, + // you may need to adjust the following assertions: + expect(result.groq).toBe(exampleConfig.endpoints.custom[2].models.default); + expect(result.Ollama).toBe(exampleConfig.endpoints.custom[3].models.default); + + // Verifying fetchModels was not called for groq and Ollama + expect(fetchModels).not.toHaveBeenCalledWith( + expect.objectContaining({ + name: 'groq', + }), + ); + expect(fetchModels).not.toHaveBeenCalledWith( + expect.objectContaining({ + name: 'Ollama', + }), + ); + }); +});