🦙 feat: Fetch list of Ollama Models (#2565)

* 🦙 feat: Fetch list of Ollama Models

* style: better Tag text styling for light mode
This commit is contained in:
Danny Avila 2024-04-27 18:27:04 -04:00 committed by GitHub
parent 8a78500fe2
commit 63ef15ab63
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 210 additions and 4 deletions

View file

@ -3,9 +3,59 @@ const { HttpsProxyAgent } = require('https-proxy-agent');
const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider');
const { extractBaseURL, inputSchema, processModelData, logAxiosError } = require('~/utils');
const getLogStores = require('~/cache/getLogStores');
const { logger } = require('~/config');
const { openAIApiKey, userProvidedOpenAI } = require('./Config/EndpointService').config;
/**
* Extracts the base URL from the provided URL.
* @param {string} fullURL - The full URL.
* @returns {string} The base URL.
*/
function deriveBaseURL(fullURL) {
try {
const parsedUrl = new URL(fullURL);
const protocol = parsedUrl.protocol;
const hostname = parsedUrl.hostname;
const port = parsedUrl.port;
// Check if the parsed URL components are meaningful
if (!protocol || !hostname) {
return fullURL;
}
// Reconstruct the base URL
return `${protocol}//${hostname}${port ? `:${port}` : ''}`;
} catch (error) {
logger.error('Failed to derive base URL', error);
return fullURL; // Return the original URL in case of any exception
}
}
/**
* Fetches Ollama models from the specified base API path.
* @param {string} baseURL
* @returns {Promise<string[]>} The Ollama models.
*/
const fetchOllamaModels = async (baseURL) => {
let models = [];
if (!baseURL) {
return models;
}
try {
const ollamaEndpoint = deriveBaseURL(baseURL);
/** @type {Promise<AxiosResponse<OllamaListResponse>>} */
const response = await axios.get(`${ollamaEndpoint}/api/tags`);
models = response.data.models.map((tag) => tag.name);
return models;
} catch (error) {
const logMessage =
'Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn\'t start with `ollama` (case-insensitive).';
logger.error(logMessage, error);
return [];
}
};
/**
* Fetches OpenAI models from the specified base API path or Azure, based on the provided configuration.
*
@ -41,6 +91,10 @@ const fetchModels = async ({
return models;
}
if (name && name.toLowerCase().startsWith('ollama')) {
return await fetchOllamaModels(baseURL);
}
try {
const options = {
headers: {
@ -227,6 +281,7 @@ const getGoogleModels = () => {
module.exports = {
fetchModels,
deriveBaseURL,
getOpenAIModels,
getChatGPTBrowserModels,
getAnthropicModels,

View file

@ -1,6 +1,7 @@
const axios = require('axios');
const { logger } = require('~/config');
const { fetchModels, getOpenAIModels } = require('./ModelService');
const { fetchModels, getOpenAIModels, deriveBaseURL } = require('./ModelService');
jest.mock('~/utils', () => {
const originalUtils = jest.requireActual('~/utils');
return {
@ -256,3 +257,119 @@ describe('getOpenAIModels sorting behavior', () => {
jest.clearAllMocks();
});
});
describe('fetchModels with Ollama specific logic', () => {
const mockOllamaData = {
data: {
models: [{ name: 'Ollama-Base' }, { name: 'Ollama-Advanced' }],
},
};
beforeEach(() => {
axios.get.mockResolvedValue(mockOllamaData);
});
afterEach(() => {
jest.clearAllMocks();
});
it('should fetch Ollama models when name starts with "ollama"', async () => {
const models = await fetchModels({
user: 'user789',
apiKey: 'testApiKey',
baseURL: 'https://api.ollama.test.com',
name: 'OllamaAPI',
});
expect(models).toEqual(['Ollama-Base', 'Ollama-Advanced']);
expect(axios.get).toHaveBeenCalledWith('https://api.ollama.test.com/api/tags'); // Adjusted to expect only one argument if no options are passed
});
it('should handle errors gracefully when fetching Ollama models fails', async () => {
axios.get.mockRejectedValue(new Error('Network error'));
const models = await fetchModels({
user: 'user789',
apiKey: 'testApiKey',
baseURL: 'https://api.ollama.test.com',
name: 'OllamaAPI',
});
expect(models).toEqual([]);
expect(logger.error).toHaveBeenCalled();
});
it('should return an empty array if no baseURL is provided', async () => {
const models = await fetchModels({
user: 'user789',
apiKey: 'testApiKey',
name: 'OllamaAPI',
});
expect(models).toEqual([]);
});
it('should not fetch Ollama models if the name does not start with "ollama"', async () => {
// Mock axios to return a different set of models for non-Ollama API calls
axios.get.mockResolvedValue({
data: {
data: [{ id: 'model-1' }, { id: 'model-2' }],
},
});
const models = await fetchModels({
user: 'user789',
apiKey: 'testApiKey',
baseURL: 'https://api.test.com',
name: 'TestAPI',
});
expect(models).toEqual(['model-1', 'model-2']);
expect(axios.get).toHaveBeenCalledWith(
'https://api.test.com/models', // Ensure the correct API endpoint is called
expect.any(Object), // Ensuring some object (headers, etc.) is passed
);
});
});
describe('deriveBaseURL', () => {
it('should extract the base URL correctly from a full URL with a port', () => {
const fullURL = 'https://example.com:8080/path?query=123';
const baseURL = deriveBaseURL(fullURL);
expect(baseURL).toEqual('https://example.com:8080');
});
it('should extract the base URL correctly from a full URL without a port', () => {
const fullURL = 'https://example.com/path?query=123';
const baseURL = deriveBaseURL(fullURL);
expect(baseURL).toEqual('https://example.com');
});
it('should handle URLs using the HTTP protocol', () => {
const fullURL = 'http://example.com:3000/path?query=123';
const baseURL = deriveBaseURL(fullURL);
expect(baseURL).toEqual('http://example.com:3000');
});
it('should return only the protocol and hostname if no port is specified', () => {
const fullURL = 'http://example.com/path?query=123';
const baseURL = deriveBaseURL(fullURL);
expect(baseURL).toEqual('http://example.com');
});
it('should handle URLs with uncommon protocols', () => {
const fullURL = 'ftp://example.com:2121/path?query=123';
const baseURL = deriveBaseURL(fullURL);
expect(baseURL).toEqual('ftp://example.com:2121');
});
it('should handle edge case where URL ends with a slash', () => {
const fullURL = 'https://example.com/';
const baseURL = deriveBaseURL(fullURL);
expect(baseURL).toEqual('https://example.com');
});
it('should return the original URL if the URL is invalid', () => {
const invalidURL = 'htp:/example.com:8080';
const result = deriveBaseURL(invalidURL);
expect(result).toBe(invalidURL);
});
});