Mirror of https://github.com/danny-avila/LibreChat.git (synced 2025-12-17 00:40:14 +01:00)
🦙 fix: Ollama Custom Headers (#10314)
* 🦙 fix: Ollama Custom Headers
* chore: Correct import order for resolveHeaders in OllamaClient.js
* fix: Improve error logging for Ollama API model fetch failure
* ci: update Ollama model fetch tests
* ci: Add unit test for passing headers and user object to Ollama fetchModels
This commit is contained in:
parent 5e35b7d09d
commit d904b281f1

5 changed files with 107 additions and 56 deletions
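In short: headers configured on an Ollama endpoint are now passed through when fetching the model list, and resolved against the requesting user. A minimal sketch of the new call shape, using only names that appear in the diffs below (the import path and wrapper function are placeholders, not an excerpt from the repository):

    // Sketch only: OllamaClient.fetchModels now accepts an options object.
    const { OllamaClient } = require('./OllamaClient'); // placeholder path

    async function listOllamaModels(endpoint, user) {
      // `headers` comes from the custom endpoint config; `user` is the requesting
      // user, passed along for header resolution via resolveHeaders().
      return OllamaClient.fetchModels(endpoint.baseURL, {
        headers: endpoint.headers,
        user,
      });
    }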
OllamaClient.js:

@@ -2,7 +2,7 @@ const { z } = require('zod');
 const axios = require('axios');
 const { Ollama } = require('ollama');
 const { sleep } = require('@librechat/agents');
-const { logAxiosError } = require('@librechat/api');
+const { resolveHeaders } = require('@librechat/api');
 const { logger } = require('@librechat/data-schemas');
 const { Constants } = require('librechat-data-provider');
 const { deriveBaseURL } = require('~/utils');
@@ -44,6 +44,7 @@ class OllamaClient {
   constructor(options = {}) {
     const host = deriveBaseURL(options.baseURL ?? 'http://localhost:11434');
     this.streamRate = options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
+    this.headers = options.headers ?? {};
     /** @type {Ollama} */
     this.client = new Ollama({ host });
   }
@@ -51,27 +52,32 @@ class OllamaClient {
   /**
    * Fetches Ollama models from the specified base API path.
    * @param {string} baseURL
+   * @param {Object} [options] - Optional configuration
+   * @param {Partial<IUser>} [options.user] - User object for header resolution
+   * @param {Record<string, string>} [options.headers] - Headers to include in the request
    * @returns {Promise<string[]>} The Ollama models.
+   * @throws {Error} Throws if the Ollama API request fails
    */
-  static async fetchModels(baseURL) {
-    let models = [];
+  static async fetchModels(baseURL, options = {}) {
     if (!baseURL) {
-      return models;
-    }
-    try {
-      const ollamaEndpoint = deriveBaseURL(baseURL);
-      /** @type {Promise<AxiosResponse<OllamaListResponse>>} */
-      const response = await axios.get(`${ollamaEndpoint}/api/tags`, {
-        timeout: 5000,
-      });
-      models = response.data.models.map((tag) => tag.name);
-      return models;
-    } catch (error) {
-      const logMessage =
-        "Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn't start with `ollama` (case-insensitive).";
-      logAxiosError({ message: logMessage, error });
       return [];
     }
+
+    const ollamaEndpoint = deriveBaseURL(baseURL);
+
+    const resolvedHeaders = resolveHeaders({
+      headers: options.headers,
+      user: options.user,
+    });
+
+    /** @type {Promise<AxiosResponse<OllamaListResponse>>} */
+    const response = await axios.get(`${ollamaEndpoint}/api/tags`, {
+      headers: resolvedHeaders,
+      timeout: 5000,
+    });
+
+    const models = response.data.models.map((tag) => tag.name);
+    return models;
   }

   /**
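For reference, the resolved headers are what ends up on the request to `/api/tags`. A rough illustration of the intended shape, based on how resolveHeaders is mocked in the updated tests further down (all values invented):

    // resolveHeaders() receives the configured headers plus the user object and
    // returns the headers to attach to the request; the spec below mocks it as
    // `(options) => options?.headers || {}`.
    const resolvedHeaders = resolveHeaders({
      headers: { Authorization: 'Bearer custom-token' }, // from the endpoint config
      user: { id: 'user789', email: 'test@example.com' }, // Partial<IUser>
    });
    // axios.get(`${ollamaEndpoint}/api/tags`, { headers: resolvedHeaders, timeout: 5000 });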
loadConfigModels:

@@ -57,7 +57,7 @@ async function loadConfigModels(req) {
   for (let i = 0; i < customEndpoints.length; i++) {
     const endpoint = customEndpoints[i];
-    const { models, name: configName, baseURL, apiKey } = endpoint;
+    const { models, name: configName, baseURL, apiKey, headers: endpointHeaders } = endpoint;
     const name = normalizeEndpointName(configName);
     endpointsMap[name] = endpoint;

@@ -76,6 +76,8 @@ async function loadConfigModels(req) {
       apiKey: API_KEY,
       baseURL: BASE_URL,
       user: req.user.id,
+      userObject: req.user,
+      headers: endpointHeaders,
       direct: endpoint.directEndpoint,
       userIdQuery: models.userIdQuery,
     });
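For orientation, `endpoint` here is an entry from the custom-endpoints configuration; its `headers` are destructured as `endpointHeaders` and forwarded to fetchModels together with `req.user`. A hypothetical entry, with field names taken from the destructuring above and invented values:

    // Hypothetical custom endpoint entry; only the field names come from the diff.
    const endpoint = {
      name: 'Ollama',
      baseURL: 'http://localhost:11434',
      apiKey: 'ollama',
      headers: { Authorization: 'Bearer my-token' },
      models: { userIdQuery: false },
      directEndpoint: false,
    };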
initializeClient (custom endpoint initialization):

@@ -1,4 +1,3 @@
-const { Providers } = require('@librechat/agents');
 const {
   resolveHeaders,
   isUserProvided,

@@ -143,39 +142,27 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
   if (optionsOnly) {
     const modelOptions = endpointOption?.model_parameters ?? {};
-    if (endpoint !== Providers.OLLAMA) {
-      clientOptions = Object.assign(
-        {
-          modelOptions,
-        },
-        clientOptions,
-      );
-      clientOptions.modelOptions.user = req.user.id;
-      const options = getOpenAIConfig(apiKey, clientOptions, endpoint);
-      if (options != null) {
-        options.useLegacyContent = true;
-        options.endpointTokenConfig = endpointTokenConfig;
-      }
-      if (!clientOptions.streamRate) {
-        return options;
-      }
-      options.llmConfig.callbacks = [
-        {
-          handleLLMNewToken: createHandleLLMNewToken(clientOptions.streamRate),
-        },
-      ];
-      return options;
-    }
-
-    if (clientOptions.reverseProxyUrl) {
-      modelOptions.baseUrl = clientOptions.reverseProxyUrl.split('/v1')[0];
-      delete clientOptions.reverseProxyUrl;
-    }
-
-    return {
-      useLegacyContent: true,
-      llmConfig: modelOptions,
-    };
+    clientOptions = Object.assign(
+      {
+        modelOptions,
+      },
+      clientOptions,
+    );
+    clientOptions.modelOptions.user = req.user.id;
+    const options = getOpenAIConfig(apiKey, clientOptions, endpoint);
+    if (options != null) {
+      options.useLegacyContent = true;
+      options.endpointTokenConfig = endpointTokenConfig;
+    }
+    if (!clientOptions.streamRate) {
+      return options;
+    }
+    options.llmConfig.callbacks = [
+      {
+        handleLLMNewToken: createHandleLLMNewToken(clientOptions.streamRate),
+      },
+    ];
+    return options;
   }
 
   const client = new OpenAIClient(apiKey, clientOptions);
fetchModels (model service):

@@ -39,6 +39,8 @@ const { openAIApiKey, userProvidedOpenAI } = require('./Config/EndpointService');
  * @param {boolean} [params.userIdQuery=false] - Whether to send the user ID as a query parameter.
  * @param {boolean} [params.createTokenConfig=true] - Whether to create a token configuration from the API response.
  * @param {string} [params.tokenKey] - The cache key to save the token configuration. Uses `name` if omitted.
+ * @param {Record<string, string>} [params.headers] - Optional headers for the request.
+ * @param {Partial<IUser>} [params.userObject] - Optional user object for header resolution.
  * @returns {Promise<string[]>} A promise that resolves to an array of model identifiers.
  * @async
  */
@@ -52,6 +54,8 @@ const fetchModels = async ({
   userIdQuery = false,
   createTokenConfig = true,
   tokenKey,
+  headers,
+  userObject,
 }) => {
   let models = [];
   const baseURL = direct ? extractBaseURL(_baseURL) : _baseURL;
@@ -65,7 +69,13 @@ const fetchModels = async ({
   }
 
   if (name && name.toLowerCase().startsWith(Providers.OLLAMA)) {
-    return await OllamaClient.fetchModels(baseURL);
+    try {
+      return await OllamaClient.fetchModels(baseURL, { headers, user: userObject });
+    } catch (ollamaError) {
+      const logMessage =
+        'Failed to fetch models from Ollama API. Attempting to fetch via OpenAI-compatible endpoint.';
+      logAxiosError({ message: logMessage, error: ollamaError });
+    }
   }
 
   try {
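Note the behavioral change in error handling: when the native Ollama listing fails, the error is now logged and control falls through to the OpenAI-compatible model fetch further down in fetchModels, instead of returning an empty list. A simplified sketch of that order (assumed glue code; only the try/catch mirrors the hunk above, and `openAICompatibleFetch` is a stand-in for the existing generic fetch):

    const { logAxiosError } = require('@librechat/api');

    // Simplified fallback order: native Ollama listing first, then the
    // OpenAI-style listing handled later in fetchModels.
    async function listModelsWithFallback(baseURL, { headers, userObject }, OllamaClient, openAICompatibleFetch) {
      try {
        // Native Ollama listing: GET `${baseURL}/api/tags` with resolved headers.
        return await OllamaClient.fetchModels(baseURL, { headers, user: userObject });
      } catch (ollamaError) {
        logAxiosError({
          message:
            'Failed to fetch models from Ollama API. Attempting to fetch via OpenAI-compatible endpoint.',
          error: ollamaError,
        });
      }
      return openAICompatibleFetch(baseURL, headers);
    }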
fetchModels tests:

@@ -1,5 +1,5 @@
 const axios = require('axios');
-const { logger } = require('@librechat/data-schemas');
+const { logAxiosError, resolveHeaders } = require('@librechat/api');
 const { EModelEndpoint, defaultModels } = require('librechat-data-provider');
 
 const {
@@ -18,6 +18,8 @@ jest.mock('@librechat/api', () => {
     processModelData: jest.fn((...args) => {
       return originalUtils.processModelData(...args);
     }),
+    logAxiosError: jest.fn(),
+    resolveHeaders: jest.fn((options) => options?.headers || {}),
   };
 });
@@ -277,12 +279,51 @@ describe('fetchModels with Ollama specific logic', () => {
 
     expect(models).toEqual(['Ollama-Base', 'Ollama-Advanced']);
     expect(axios.get).toHaveBeenCalledWith('https://api.ollama.test.com/api/tags', {
+      headers: {},
       timeout: 5000,
     });
   });
 
-  it('should handle errors gracefully when fetching Ollama models fails', async () => {
-    axios.get.mockRejectedValue(new Error('Network error'));
+  it('should pass headers and user object to Ollama fetchModels', async () => {
+    const customHeaders = {
+      'Content-Type': 'application/json',
+      Authorization: 'Bearer custom-token',
+    };
+    const userObject = {
+      id: 'user789',
+      email: 'test@example.com',
+    };
+
+    resolveHeaders.mockReturnValueOnce(customHeaders);
+
+    const models = await fetchModels({
+      user: 'user789',
+      apiKey: 'testApiKey',
+      baseURL: 'https://api.ollama.test.com',
+      name: 'ollama',
+      headers: customHeaders,
+      userObject,
+    });
+
+    expect(models).toEqual(['Ollama-Base', 'Ollama-Advanced']);
+    expect(resolveHeaders).toHaveBeenCalledWith({
+      headers: customHeaders,
+      user: userObject,
+    });
+    expect(axios.get).toHaveBeenCalledWith('https://api.ollama.test.com/api/tags', {
+      headers: customHeaders,
+      timeout: 5000,
+    });
+  });
+
+  it('should handle errors gracefully when fetching Ollama models fails and fallback to OpenAI-compatible fetch', async () => {
+    axios.get.mockRejectedValueOnce(new Error('Ollama API error'));
+    axios.get.mockResolvedValueOnce({
+      data: {
+        data: [{ id: 'fallback-model-1' }, { id: 'fallback-model-2' }],
+      },
+    });
+
     const models = await fetchModels({
       user: 'user789',
       apiKey: 'testApiKey',
@@ -290,8 +331,13 @@ describe('fetchModels with Ollama specific logic', () => {
       name: 'OllamaAPI',
     });
 
-    expect(models).toEqual([]);
-    expect(logger.error).toHaveBeenCalled();
+    expect(models).toEqual(['fallback-model-1', 'fallback-model-2']);
+    expect(logAxiosError).toHaveBeenCalledWith({
+      message:
+        'Failed to fetch models from Ollama API. Attempting to fetch via OpenAI-compatible endpoint.',
+      error: expect.any(Error),
+    });
+    expect(axios.get).toHaveBeenCalledTimes(2);
   });
 
   it('should return an empty array if no baseURL is provided', async () => {