🤖 feat: Custom Endpoint Agents (experimental) (#4627)

* wip: first pass, custom endpoint agents

* chore: imports

* chore: consolidate exports

* fix: imports

* feat: convert message.content array to strings for legacy format handling (deepseek/groq)

* refactor: normalize ollama endpoint name

* refactor: update mocking in isDomainAllowed.spec.js

* refactor: update deepseekModels in tokens.js and tokens.spec.js
This commit is contained in:
Danny Avila 2024-11-04 12:59:04 -05:00 committed by GitHub
parent 9437e95315
commit 2e519f9b57
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 230 additions and 73 deletions

View file

@ -1,16 +1,3 @@
// const {
// ErrorTypes,
// EModelEndpoint,
// resolveHeaders,
// mapModelToAzureConfig,
// } = require('librechat-data-provider');
// const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
// const { isEnabled, isUserProvided } = require('~/server/utils');
// const { getAzureCredentials } = require('~/utils');
// const { OpenAIClient } = require('~/app');
const { z } = require('zod');
const { tool } = require('@langchain/core/tools');
const { createContentAggregator, Providers } = require('@librechat/agents');
const {
EModelEndpoint,
@ -25,30 +12,11 @@ const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize'
const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options');
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
const initCustom = require('~/server/services/Endpoints/custom/initialize');
const { getCustomEndpointConfig } = require('~/server/services/Config');
const { loadAgentTools } = require('~/server/services/ToolService');
const AgentClient = require('~/server/controllers/agents/client');
const { getModelMaxTokens } = require('~/utils');
/**
 * Weather tool kept around for testing error handling.
 *
 * The handler returns a canned forecast only when `location` is exactly
 * 'SAN FRANCISCO'. Every other input throws: a casing-specific error when the
 * value matches 'san francisco' case-insensitively, a generic error otherwise.
 */
const _getWeather = tool(
  async (input) => {
    const { location } = input;

    if (location === 'SAN FRANCISCO') {
      return 'It\'s 60 degrees and foggy';
    }

    // Right city, wrong casing: reported with its own message.
    const isMiscasedCity = location.toLowerCase() === 'san francisco';
    if (isMiscasedCity) {
      throw new Error('Input queries must be all capitals');
    }

    throw new Error('Invalid input.');
  },
  {
    name: 'get_weather',
    description: 'Call to get the current weather',
    schema: z.object({ location: z.string() }),
  },
);
const providerConfigMap = {
[EModelEndpoint.openAI]: initOpenAI,
[EModelEndpoint.azureOpenAI]: initOpenAI,
@ -85,18 +53,25 @@ const initializeClient = async ({ req, res, endpointOption }) => {
if (!agent) {
throw new Error('Agent not found');
}
const { tools, toolMap } = await loadAgentTools({
req,
tools: agent.tools,
agent_id: agent.id,
tool_resources: agent.tool_resources,
// openAIApiKey: process.env.OPENAI_API_KEY,
});
const provider = agent.provider;
let modelOptions = { model: agent.model };
let getOptions = providerConfigMap[agent.provider];
let getOptions = providerConfigMap[provider];
if (!getOptions) {
throw new Error(`Provider ${agent.provider} not supported`);
const customEndpointConfig = await getCustomEndpointConfig(provider);
if (!customEndpointConfig) {
throw new Error(`Provider ${provider} not supported`);
}
getOptions = initCustom;
agent.provider = Providers.OPENAI;
agent.endpoint = provider.toLowerCase();
}
// TODO: pass-in override settings that are specific to current run
@ -106,10 +81,14 @@ const initializeClient = async ({ req, res, endpointOption }) => {
res,
endpointOption,
optionsOnly: true,
overrideEndpoint: agent.provider,
overrideEndpoint: provider,
overrideModel: agent.model,
});
modelOptions = Object.assign(modelOptions, options.llmConfig);
if (options.configOptions) {
modelOptions.configuration = options.configOptions;
}
const sender = getResponseSender({
...endpointOption,
@ -128,11 +107,11 @@ const initializeClient = async ({ req, res, endpointOption }) => {
collectedUsage,
artifactPromises,
endpoint: EModelEndpoint.agents,
configOptions: options.configOptions,
attachments: endpointOption.attachments,
maxContextTokens:
agent.max_context_tokens ??
getModelMaxTokens(modelOptions.model, providerEndpointMap[agent.provider]),
getModelMaxTokens(modelOptions.model, providerEndpointMap[provider]) ??
4000,
});
return { client };
};

View file

@ -2,17 +2,17 @@ const {
CacheKeys,
ErrorTypes,
envVarRegex,
EModelEndpoint,
FetchTokenConfig,
extractEnvVariable,
} = require('librechat-data-provider');
const { Providers } = require('@librechat/agents');
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm');
const { getCustomEndpointConfig } = require('~/server/services/Config');
const { fetchModels } = require('~/server/services/ModelService');
const getLogStores = require('~/cache/getLogStores');
const { isUserProvided } = require('~/server/utils');
const { OpenAIClient } = require('~/app');
const { Providers } = require('@librechat/agents');
const { PROXY } = process.env;
@ -20,15 +20,11 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
const { key: expiresAt } = req.body;
const endpoint = overrideEndpoint ?? req.body.endpoint;
const customConfig = await getCustomConfig();
if (!customConfig) {
const endpointConfig = await getCustomEndpointConfig(endpoint);
if (!endpointConfig) {
throw new Error(`Config not found for the ${endpoint} custom endpoint.`);
}
const { endpoints = {} } = customConfig;
const customEndpoints = endpoints[EModelEndpoint.custom] ?? [];
const endpointConfig = customEndpoints.find((endpointConfig) => endpointConfig.name === endpoint);
const CUSTOM_API_KEY = extractEnvVariable(endpointConfig.apiKey);
const CUSTOM_BASE_URL = extractEnvVariable(endpointConfig.baseURL);
@ -138,10 +134,21 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
if (optionsOnly) {
const modelOptions = endpointOption.model_parameters;
if (endpoint === Providers.OLLAMA && clientOptions.reverseProxyUrl) {
if (endpoint !== Providers.OLLAMA) {
const requestOptions = Object.assign(
{
modelOptions,
},
clientOptions,
);
return getLLMConfig(apiKey, requestOptions);
}
if (clientOptions.reverseProxyUrl) {
modelOptions.baseUrl = clientOptions.reverseProxyUrl.split('/v1')[0];
delete clientOptions.reverseProxyUrl;
}
return {
llmConfig: modelOptions,
};