🤖 feat: Custom Endpoint Agents (experimental) (#4627)

* wip: first pass, custom endpoint agents

* chore: imports

* chore: consolidate exports

* fix: imports

* feat: convert message.content array to strings for legacy format handling (deepseek/groq)

* refactor: normalize ollama endpoint name

* refactor: update mocking in isDomainAllowed.spec.js

* refactor: update deepseekModels in tokens.js and tokens.spec.js
Danny Avila, 2024-11-04 12:59:04 -05:00 (committed by GitHub)
parent 9437e95315
commit 2e519f9b57

23 changed files with 230 additions and 73 deletions

@@ -217,9 +217,41 @@ const formatAgentMessages = (payload) => {
return messages;
};
/**
 * Formats an array of messages for LangChain, making sure all content fields are strings
 * @param {Array<(HumanMessage|AIMessage|SystemMessage|ToolMessage)>} payload - The array of messages to format.
 * @returns {Array<(HumanMessage|AIMessage|SystemMessage|ToolMessage)>} - The same array, with each message's array content flattened to a string in place.
 */
const formatContentStrings = (payload) => {
  for (const message of payload) {
    if (typeof message.content === 'string') {
      continue;
    }

    if (!Array.isArray(message.content)) {
      continue;
    }

    // Reduce text types to a single string, ignore all other types
    const content = message.content.reduce((acc, curr) => {
      if (curr.type === ContentTypes.TEXT) {
        return `${acc}${curr[ContentTypes.TEXT]}\n`;
      }
      return acc;
    }, '');

    message.content = content.trim();
  }

  return payload;
};
module.exports = {
formatMessage,
formatFromLangChain,
formatAgentMessages,
formatContentStrings,
formatLangChainMessages,
};
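
To illustrate the flattening, here is a minimal sketch with hypothetical message content, assuming `ContentTypes.TEXT` resolves to the string `'text'`:

```js
const { HumanMessage } = require('@langchain/core/messages');

// A message whose content arrived as an array of typed parts
const message = new HumanMessage({
  content: [
    { type: 'text', text: 'Hello' },
    { type: 'text', text: 'world' },
    // non-text parts (e.g. image_url) are dropped by the reducer
  ],
});

formatContentStrings([message]);
// message.content is now the plain string 'Hello\nworld' —
// the legacy string format expected by deepseek/groq
```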

@@ -13,6 +13,7 @@ const {
VisionModes,
openAISchema,
EModelEndpoint,
KnownEndpoints,
anthropicSchema,
bedrockOutputParser,
removeNullishValues,
@@ -25,6 +26,7 @@ const {
const {
formatMessage,
formatAgentMessages,
formatContentStrings,
createContextHandlers,
} = require('~/app/clients/prompts');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
@@ -44,6 +46,8 @@ const providerParsers = {
[EModelEndpoint.bedrock]: bedrockOutputParser,
};
const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deepseek]);
class AgentClient extends BaseClient {
constructor(options = {}) {
super(null, options);
@@ -74,6 +78,7 @@ class AgentClient extends BaseClient {
this.collectedUsage = collectedUsage;
/** @type {ArtifactPromises} */
this.artifactPromises = artifactPromises;
/** @type {AgentClientOptions} */
this.options = Object.assign({ endpoint: options.endpoint }, clientOptions);
}
@@ -478,6 +483,9 @@
this.run = run;
const messages = formatAgentMessages(payload);
if (legacyContentEndpoints.has(this.options.agent.endpoint)) {
formatContentStrings(messages);
}
await run.processStream({ messages }, config, {
[Callback.TOOL_ERROR]: (graph, error, toolId) => {
logger.error(
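
Note: `formatContentStrings` flattens each `message.content` in place, so the call above does not use its return value; the conversion runs only for endpoints in `legacyContentEndpoints` (groq and deepseek), which expect plain-string content.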

@@ -3,7 +3,7 @@ const path = require('path');
const crypto = require('crypto');
const multer = require('multer');
const { fileConfig: defaultFileConfig, mergeFileConfig } = require('librechat-data-provider');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { getCustomConfig } = require('~/server/services/Config');
const storage = multer.diskStorage({
destination: function (req, file, cb) {

@@ -1,4 +1,4 @@
const { CacheKeys } = require('librechat-data-provider');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const loadCustomConfig = require('./loadCustomConfig');
const getLogStores = require('~/cache/getLogStores');
@@ -22,4 +22,19 @@ async function getCustomConfig() {
return customConfig;
}
module.exports = getCustomConfig;
/**
 * Looks up a custom endpoint's configuration by name from the cached custom config
 * @param {string | EModelEndpoint} endpoint
 */
const getCustomEndpointConfig = async (endpoint) => {
const customConfig = await getCustomConfig();
if (!customConfig) {
throw new Error(`Config not found for the ${endpoint} custom endpoint.`);
}
const { endpoints = {} } = customConfig;
const customEndpoints = endpoints[EModelEndpoint.custom] ?? [];
return customEndpoints.find((endpointConfig) => endpointConfig.name === endpoint);
};
module.exports = { getCustomConfig, getCustomEndpointConfig };
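
A usage sketch (the endpoint name `'groq'` here is illustrative and assumes a librechat.yaml that defines such a custom endpoint):

```js
const { getCustomEndpointConfig } = require('~/server/services/Config');

async function example() {
  // Throws if no custom config is loaded at all;
  // resolves to undefined if no custom endpoint matches the name
  const endpointConfig = await getCustomEndpointConfig('groq');
  if (endpointConfig) {
    const { apiKey, baseURL, models } = endpointConfig;
    // apiKey/baseURL may reference env variables, e.g. '${GROQ_API_KEY}'
  }
}
```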

@@ -10,12 +10,12 @@ const loadDefaultEndpointsConfig = require('./loadDefaultEConfig');
module.exports = {
config,
getCustomConfig,
loadCustomConfig,
loadConfigModels,
loadDefaultModels,
loadOverrideConfig,
loadAsyncEndpoints,
...getCustomConfig,
loadConfigEndpoints,
loadDefaultEndpointsConfig,
};

@@ -1,6 +1,6 @@
const { EModelEndpoint, extractEnvVariable } = require('librechat-data-provider');
const { getCustomConfig } = require('./getCustomConfig');
const { isUserProvided } = require('~/server/utils');
const getCustomConfig = require('./getCustomConfig');
/**
* Load config endpoints from the cached configuration object

@@ -1,7 +1,16 @@
const { Providers } = require('@librechat/agents');
const { EModelEndpoint, extractEnvVariable } = require('librechat-data-provider');
const { fetchModels } = require('~/server/services/ModelService');
const { getCustomConfig } = require('./getCustomConfig');
const { isUserProvided } = require('~/server/utils');
const getCustomConfig = require('./getCustomConfig');
/**
* @param {string} name
* @returns {string}
*/
function normalizeEndpointName(name = '') {
return name.toLowerCase() === Providers.OLLAMA ? Providers.OLLAMA : name;
}
/**
* Load config endpoints from the cached configuration object
@@ -61,7 +70,8 @@ async function loadConfigModels(req) {
for (let i = 0; i < customEndpoints.length; i++) {
const endpoint = customEndpoints[i];
const { models, name, baseURL, apiKey } = endpoint;
const { models, name: configName, baseURL, apiKey } = endpoint;
const name = normalizeEndpointName(configName);
endpointsMap[name] = endpoint;
const API_KEY = extractEnvVariable(apiKey);
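
Assuming `Providers.OLLAMA` is the lowercase string `'ollama'`, the normalization collapses any casing of that one name while leaving other endpoints untouched:

```js
normalizeEndpointName('Ollama'); // => 'ollama'
normalizeEndpointName('OLLAMA'); // => 'ollama'
normalizeEndpointName('groq');   // => 'groq' (unchanged)
normalizeEndpointName();         // => ''     (default parameter)
```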

@@ -1,6 +1,6 @@
const { fetchModels } = require('~/server/services/ModelService');
const { getCustomConfig } = require('./getCustomConfig');
const loadConfigModels = require('./loadConfigModels');
const getCustomConfig = require('./getCustomConfig');
jest.mock('~/server/services/ModelService');
jest.mock('./getCustomConfig');
@@ -253,13 +253,13 @@ describe('loadConfigModels', () => {
}),
);
// For groq and Ollama, since the apiKey is "user_provided", models should not be fetched
// For groq and ollama, since the apiKey is "user_provided", models should not be fetched
// Depending on your implementation's behavior regarding "default" models without fetching,
// you may need to adjust the following assertions:
expect(result.groq).toBe(exampleConfig.endpoints.custom[2].models.default);
expect(result.Ollama).toBe(exampleConfig.endpoints.custom[3].models.default);
expect(result.ollama).toBe(exampleConfig.endpoints.custom[3].models.default);
// Verifying fetchModels was not called for groq and Ollama
// Verifying fetchModels was not called for groq and ollama
expect(fetchModels).not.toHaveBeenCalledWith(
expect.objectContaining({
name: 'groq',
@@ -267,7 +267,7 @@
);
expect(fetchModels).not.toHaveBeenCalledWith(
expect.objectContaining({
name: 'Ollama',
name: 'ollama',
}),
);
});
@@ -335,4 +335,68 @@ describe('loadConfigModels', () => {
expect(result.FalsyFetchModel).toEqual(['defaultModel1', 'defaultModel2']);
});
it('normalizes Ollama endpoint name to lowercase', async () => {
const testCases = [
{
name: 'Ollama',
apiKey: 'user_provided',
baseURL: 'http://localhost:11434/v1/',
models: {
default: ['mistral', 'llama2'],
fetch: false,
},
},
{
name: 'OLLAMA',
apiKey: 'user_provided',
baseURL: 'http://localhost:11434/v1/',
models: {
default: ['mixtral', 'codellama'],
fetch: false,
},
},
{
name: 'OLLaMA',
apiKey: 'user_provided',
baseURL: 'http://localhost:11434/v1/',
models: {
default: ['phi', 'neural-chat'],
fetch: false,
},
},
];
getCustomConfig.mockResolvedValue({
endpoints: {
custom: testCases,
},
});
const result = await loadConfigModels(mockRequest);
// All variations of "Ollama" should be normalized to lowercase "ollama"
// and the last config in the array should override previous ones
expect(result.Ollama).toBeUndefined();
expect(result.OLLAMA).toBeUndefined();
expect(result.OLLaMA).toBeUndefined();
expect(result.ollama).toEqual(['phi', 'neural-chat']);
// Verify fetchModels was not called since these are user_provided
expect(fetchModels).not.toHaveBeenCalledWith(
expect.objectContaining({
name: 'Ollama',
}),
);
expect(fetchModels).not.toHaveBeenCalledWith(
expect.objectContaining({
name: 'OLLAMA',
}),
);
expect(fetchModels).not.toHaveBeenCalledWith(
expect.objectContaining({
name: 'OLLaMA',
}),
);
});
});

@@ -1,16 +1,3 @@
// const {
// ErrorTypes,
// EModelEndpoint,
// resolveHeaders,
// mapModelToAzureConfig,
// } = require('librechat-data-provider');
// const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
// const { isEnabled, isUserProvided } = require('~/server/utils');
// const { getAzureCredentials } = require('~/utils');
// const { OpenAIClient } = require('~/app');
const { z } = require('zod');
const { tool } = require('@langchain/core/tools');
const { createContentAggregator, Providers } = require('@librechat/agents');
const {
EModelEndpoint,
@@ -25,30 +12,11 @@ const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize'
const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options');
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
const initCustom = require('~/server/services/Endpoints/custom/initialize');
const { getCustomEndpointConfig } = require('~/server/services/Config');
const { loadAgentTools } = require('~/server/services/ToolService');
const AgentClient = require('~/server/controllers/agents/client');
const { getModelMaxTokens } = require('~/utils');
/* For testing errors */
const _getWeather = tool(
async ({ location }) => {
if (location === 'SAN FRANCISCO') {
return 'It\'s 60 degrees and foggy';
} else if (location.toLowerCase() === 'san francisco') {
throw new Error('Input queries must be all capitals');
} else {
throw new Error('Invalid input.');
}
},
{
name: 'get_weather',
description: 'Call to get the current weather',
schema: z.object({
location: z.string(),
}),
},
);
const providerConfigMap = {
[EModelEndpoint.openAI]: initOpenAI,
[EModelEndpoint.azureOpenAI]: initOpenAI,
@@ -85,18 +53,25 @@ const initializeClient = async ({ req, res, endpointOption }) => {
if (!agent) {
throw new Error('Agent not found');
}
const { tools, toolMap } = await loadAgentTools({
req,
tools: agent.tools,
agent_id: agent.id,
tool_resources: agent.tool_resources,
// openAIApiKey: process.env.OPENAI_API_KEY,
});
const provider = agent.provider;
let modelOptions = { model: agent.model };
let getOptions = providerConfigMap[agent.provider];
let getOptions = providerConfigMap[provider];
if (!getOptions) {
throw new Error(`Provider ${agent.provider} not supported`);
const customEndpointConfig = await getCustomEndpointConfig(provider);
if (!customEndpointConfig) {
throw new Error(`Provider ${provider} not supported`);
}
getOptions = initCustom;
agent.provider = Providers.OPENAI;
agent.endpoint = provider.toLowerCase();
}
// TODO: pass-in override settings that are specific to current run
@@ -106,10 +81,14 @@ const initializeClient = async ({ req, res, endpointOption }) => {
res,
endpointOption,
optionsOnly: true,
overrideEndpoint: agent.provider,
overrideEndpoint: provider,
overrideModel: agent.model,
});
modelOptions = Object.assign(modelOptions, options.llmConfig);
if (options.configOptions) {
modelOptions.configuration = options.configOptions;
}
const sender = getResponseSender({
...endpointOption,
@@ -128,11 +107,11 @@ const initializeClient = async ({ req, res, endpointOption }) => {
collectedUsage,
artifactPromises,
endpoint: EModelEndpoint.agents,
configOptions: options.configOptions,
attachments: endpointOption.attachments,
maxContextTokens:
agent.max_context_tokens ??
getModelMaxTokens(modelOptions.model, providerEndpointMap[agent.provider]),
getModelMaxTokens(modelOptions.model, providerEndpointMap[provider]) ??
4000,
});
return { client };
};

@@ -2,17 +2,17 @@ const {
CacheKeys,
ErrorTypes,
envVarRegex,
EModelEndpoint,
FetchTokenConfig,
extractEnvVariable,
} = require('librechat-data-provider');
const { Providers } = require('@librechat/agents');
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm');
const { getCustomEndpointConfig } = require('~/server/services/Config');
const { fetchModels } = require('~/server/services/ModelService');
const getLogStores = require('~/cache/getLogStores');
const { isUserProvided } = require('~/server/utils');
const { OpenAIClient } = require('~/app');
const { Providers } = require('@librechat/agents');
const { PROXY } = process.env;
@@ -20,15 +20,11 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
const { key: expiresAt } = req.body;
const endpoint = overrideEndpoint ?? req.body.endpoint;
const customConfig = await getCustomConfig();
if (!customConfig) {
const endpointConfig = await getCustomEndpointConfig(endpoint);
if (!endpointConfig) {
throw new Error(`Config not found for the ${endpoint} custom endpoint.`);
}
const { endpoints = {} } = customConfig;
const customEndpoints = endpoints[EModelEndpoint.custom] ?? [];
const endpointConfig = customEndpoints.find((endpointConfig) => endpointConfig.name === endpoint);
const CUSTOM_API_KEY = extractEnvVariable(endpointConfig.apiKey);
const CUSTOM_BASE_URL = extractEnvVariable(endpointConfig.baseURL);
@@ -138,10 +134,21 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
if (optionsOnly) {
const modelOptions = endpointOption.model_parameters;
if (endpoint === Providers.OLLAMA && clientOptions.reverseProxyUrl) {
if (endpoint !== Providers.OLLAMA) {
const requestOptions = Object.assign(
{
modelOptions,
},
clientOptions,
);
return getLLMConfig(apiKey, requestOptions);
}
if (clientOptions.reverseProxyUrl) {
modelOptions.baseUrl = clientOptions.reverseProxyUrl.split('/v1')[0];
delete clientOptions.reverseProxyUrl;
}
return {
llmConfig: modelOptions,
};
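
For the Ollama branch, the base URL is derived by cutting the reverse-proxy URL at its `/v1` segment; with an illustrative value:

```js
const reverseProxyUrl = 'http://localhost:11434/v1/chat/completions';
const baseUrl = reverseProxyUrl.split('/v1')[0];
// => 'http://localhost:11434' — handed to the Ollama client as modelOptions.baseUrl
```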

@@ -2,7 +2,7 @@ const axios = require('axios');
const FormData = require('form-data');
const { Readable } = require('stream');
const { extractEnvVariable, STTProviders } = require('librechat-data-provider');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { getCustomConfig } = require('~/server/services/Config');
const { genAzureEndpoint } = require('~/utils');
const { logger } = require('~/config');

@@ -1,9 +1,9 @@
const axios = require('axios');
const { extractEnvVariable, TTSProviders } = require('librechat-data-provider');
const { logger } = require('~/config');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { genAzureEndpoint } = require('~/utils');
const { getRandomVoiceId, createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
const { getCustomConfig } = require('~/server/services/Config');
const { genAzureEndpoint } = require('~/utils');
const { logger } = require('~/config');
/**
* Service class for handling Text-to-Speech (TTS) operations.

@@ -1,4 +1,4 @@
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { getCustomConfig } = require('~/server/services/Config');
const { logger } = require('~/config');
/**

@@ -1,5 +1,5 @@
const { TTSProviders } = require('librechat-data-provider');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { getCustomConfig } = require('~/server/services/Config');
const { getProvider } = require('./TTSService');
/**

@@ -1,7 +1,7 @@
const getVoices = require('./getVoices');
const getCustomConfigSpeech = require('./getCustomConfigSpeech');
const TTSService = require('./TTSService');
const STTService = require('./STTService');
const getVoices = require('./getVoices');
module.exports = {
getVoices,

@@ -1,4 +1,4 @@
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { getCustomConfig } = require('~/server/services/Config');
async function isDomainAllowed(email) {
if (!email) {

@@ -1,7 +1,9 @@
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { getCustomConfig } = require('~/server/services/Config');
const isDomainAllowed = require('./isDomainAllowed');
jest.mock('~/server/services/Config/getCustomConfig', () => jest.fn());
jest.mock('~/server/services/Config', () => ({
getCustomConfig: jest.fn(),
}));
describe('isDomainAllowed', () => {
afterEach(() => {

@@ -942,6 +942,29 @@
* @memberof typedefs
*/
/**
* @typedef {Object} AgentClientOptions
* @property {Agent} agent - The agent configuration object
* @property {string} endpoint - The endpoint identifier for the agent
* @property {Object} req - The request object
* @property {string} [name] - The username
* @property {string} [modelLabel] - The label for the model being used
* @property {number} [maxContextTokens] - Maximum number of tokens allowed in context
* @property {Object} [endpointTokenConfig] - Token configuration for the endpoint
* @property {boolean} [resendFiles] - Whether to resend files
* @property {string} [imageDetail] - Detail level for image processing
* @property {Object} [spec] - Specification object
* @property {Promise<MongoFile[]>} [attachments] - Promise resolving to file attachments
* @property {Object} [headers] - Additional headers for requests
* @property {string} [proxy] - Proxy configuration
* @property {Object} [tools] - Available tools for the agent
* @property {Object} [toolMap] - Mapping of tool configurations
* @property {Object} [eventHandlers] - Custom event handlers
* @property {Object} [addParams] - Additional parameters to add to requests
* @property {string[]} [dropParams] - Parameters to remove from requests
* @memberof typedefs
*/
/**
* @exports ImportBatchBuilder
* @typedef {import('./server/utils/import/importBatchBuilder.js').ImportBatchBuilder} ImportBatchBuilder

@@ -76,6 +76,10 @@ const anthropicModels = {
'claude-3.5-sonnet-latest': 200000,
};
const deepseekModels = {
deepseek: 127500,
};
const metaModels = {
llama3: 8000,
llama2: 4000,
@@ -117,6 +121,7 @@ const bedrockModels = {
...mistralModels,
...cohereModels,
...ollamaModels,
...deepseekModels,
...metaModels,
...ai21Models,
...amazonModels,

@@ -357,6 +357,11 @@ describe('Meta Models Tests', () => {
expect(getModelMaxTokens('meta/llama3')).toBe(8000);
expect(getModelMaxTokens('meta/llama2')).toBe(4000);
});
test('should match Deepseek model variations', () => {
expect(getModelMaxTokens('deepseek-chat')).toBe(127500);
expect(getModelMaxTokens('deepseek-coder')).toBe(127500);
});
});
describe('matchModelName', () => {
@@ -383,6 +388,11 @@ describe('Meta Models Tests', () => {
expect(matchModelName('llama3', EModelEndpoint.bedrock)).toBe('llama3');
expect(matchModelName('llama3.1:8b', EModelEndpoint.bedrock)).toBe('llama3.1:8b');
});
test('should match Deepseek model variations', () => {
expect(matchModelName('deepseek-chat')).toBe('deepseek');
expect(matchModelName('deepseek-coder')).toBe('deepseek');
});
});
describe('processModelData with Meta models', () => {

package-lock.json (generated)

@@ -36283,7 +36283,7 @@
},
"packages/data-provider": {
"name": "librechat-data-provider",
"version": "0.7.51",
"version": "0.7.52",
"license": "ISC",
"dependencies": {
"@types/js-yaml": "^4.0.9",

@@ -1,6 +1,6 @@
{
"name": "librechat-data-provider",
"version": "0.7.51",
"version": "0.7.52",
"description": "data services for librechat apps",
"main": "dist/index.js",
"module": "dist/index.es.js",

@@ -184,6 +184,8 @@ export type Agent = {
id: string;
name: string | null;
author?: string | null;
/** The original custom endpoint name, lowercased */
endpoint?: string | null;
authorName?: string | null;
description: string | null;
created_at: number;