🦙 feat: Fetch list of Ollama Models (#2565)

* 🦙 feat: Fetch list of Ollama Models

* style: better Tag text styling for light mode
This commit is contained in:
Danny Avila 2024-04-27 18:27:04 -04:00 committed by GitHub
parent 8a78500fe2
commit 63ef15ab63
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 210 additions and 4 deletions

View file

@ -1,6 +1,7 @@
const axios = require('axios');
const { logger } = require('~/config');
const { fetchModels, getOpenAIModels } = require('./ModelService');
const { fetchModels, getOpenAIModels, deriveBaseURL } = require('./ModelService');
jest.mock('~/utils', () => {
const originalUtils = jest.requireActual('~/utils');
return {
@ -256,3 +257,119 @@ describe('getOpenAIModels sorting behavior', () => {
jest.clearAllMocks();
});
});
describe('fetchModels with Ollama specific logic', () => {
const mockOllamaData = {
data: {
models: [{ name: 'Ollama-Base' }, { name: 'Ollama-Advanced' }],
},
};
beforeEach(() => {
axios.get.mockResolvedValue(mockOllamaData);
});
afterEach(() => {
jest.clearAllMocks();
});
it('should fetch Ollama models when name starts with "ollama"', async () => {
const models = await fetchModels({
user: 'user789',
apiKey: 'testApiKey',
baseURL: 'https://api.ollama.test.com',
name: 'OllamaAPI',
});
expect(models).toEqual(['Ollama-Base', 'Ollama-Advanced']);
expect(axios.get).toHaveBeenCalledWith('https://api.ollama.test.com/api/tags'); // Adjusted to expect only one argument if no options are passed
});
it('should handle errors gracefully when fetching Ollama models fails', async () => {
axios.get.mockRejectedValue(new Error('Network error'));
const models = await fetchModels({
user: 'user789',
apiKey: 'testApiKey',
baseURL: 'https://api.ollama.test.com',
name: 'OllamaAPI',
});
expect(models).toEqual([]);
expect(logger.error).toHaveBeenCalled();
});
it('should return an empty array if no baseURL is provided', async () => {
const models = await fetchModels({
user: 'user789',
apiKey: 'testApiKey',
name: 'OllamaAPI',
});
expect(models).toEqual([]);
});
it('should not fetch Ollama models if the name does not start with "ollama"', async () => {
// Mock axios to return a different set of models for non-Ollama API calls
axios.get.mockResolvedValue({
data: {
data: [{ id: 'model-1' }, { id: 'model-2' }],
},
});
const models = await fetchModels({
user: 'user789',
apiKey: 'testApiKey',
baseURL: 'https://api.test.com',
name: 'TestAPI',
});
expect(models).toEqual(['model-1', 'model-2']);
expect(axios.get).toHaveBeenCalledWith(
'https://api.test.com/models', // Ensure the correct API endpoint is called
expect.any(Object), // Ensuring some object (headers, etc.) is passed
);
});
});
describe('deriveBaseURL', () => {
it('should extract the base URL correctly from a full URL with a port', () => {
const fullURL = 'https://example.com:8080/path?query=123';
const baseURL = deriveBaseURL(fullURL);
expect(baseURL).toEqual('https://example.com:8080');
});
it('should extract the base URL correctly from a full URL without a port', () => {
const fullURL = 'https://example.com/path?query=123';
const baseURL = deriveBaseURL(fullURL);
expect(baseURL).toEqual('https://example.com');
});
it('should handle URLs using the HTTP protocol', () => {
const fullURL = 'http://example.com:3000/path?query=123';
const baseURL = deriveBaseURL(fullURL);
expect(baseURL).toEqual('http://example.com:3000');
});
it('should return only the protocol and hostname if no port is specified', () => {
const fullURL = 'http://example.com/path?query=123';
const baseURL = deriveBaseURL(fullURL);
expect(baseURL).toEqual('http://example.com');
});
it('should handle URLs with uncommon protocols', () => {
const fullURL = 'ftp://example.com:2121/path?query=123';
const baseURL = deriveBaseURL(fullURL);
expect(baseURL).toEqual('ftp://example.com:2121');
});
it('should handle edge case where URL ends with a slash', () => {
const fullURL = 'https://example.com/';
const baseURL = deriveBaseURL(fullURL);
expect(baseURL).toEqual('https://example.com');
});
it('should return the original URL if the URL is invalid', () => {
const invalidURL = 'htp:/example.com:8080';
const result = deriveBaseURL(invalidURL);
expect(result).toBe(invalidURL);
});
});