🤖 feat: Support Google Agents, fix Various Provider Configurations (#5126)

* feat: Refactor ModelEndHandler to collect usage metadata only if it exists

* feat: google tool end handling, custom anthropic class for better token ux

* refactor: differentiate between client <> request options

* feat: initial support for google agents

* feat: only cache messages with non-empty text

* feat: Cache non-empty messages in chatV2 controller

* fix: anthropic llm client options llmConfig

* refactor: streamline client options handling in LLM configuration

* fix: VertexAI Agent Auth & Tool Handling

* fix: additional fields for llmConfig, however customHeaders are not supported by langchain, requires PR

* feat: set default location for vertexai LLM configuration

* fix: outdated OpenAI Client options for getLLMConfig

* chore: agent provider options typing

* chore: add note about currently unsupported customHeaders in langchain GenAI client

* fix: skip transaction creation when rawAmount is NaN
Danny Avila, 2024-12-28 17:15:03 -05:00 (committed by GitHub)
parent a423eb8c7b
commit 24cad6bbd4
18 changed files with 429 additions and 363 deletions


@@ -649,15 +649,17 @@ class BaseClient {
     this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
     this.savedMessageIds.add(responseMessage.messageId);
-    const messageCache = getLogStores(CacheKeys.MESSAGES);
-    messageCache.set(
-      responseMessageId,
-      {
-        text: responseMessage.text,
-        complete: true,
-      },
-      Time.FIVE_MINUTES,
-    );
+    if (responseMessage.text) {
+      const messageCache = getLogStores(CacheKeys.MESSAGES);
+      messageCache.set(
+        responseMessageId,
+        {
+          text: responseMessage.text,
+          complete: true,
+        },
+        Time.FIVE_MINUTES,
+      );
+    }
     delete responseMessage.tokenCount;
     return responseMessage;
   }


@@ -256,15 +256,17 @@ class PluginsClient extends OpenAIClient {
     }
     this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
-    const messageCache = getLogStores(CacheKeys.MESSAGES);
-    messageCache.set(
-      responseMessage.messageId,
-      {
-        text: responseMessage.text,
-        complete: true,
-      },
-      Time.FIVE_MINUTES,
-    );
+    if (responseMessage.text) {
+      const messageCache = getLogStores(CacheKeys.MESSAGES);
+      messageCache.set(
+        responseMessage.messageId,
+        {
+          text: responseMessage.text,
+          complete: true,
+        },
+        Time.FIVE_MINUTES,
+      );
+    }
     delete responseMessage.tokenCount;
     return { ...responseMessage, ...result };
   }


@@ -27,6 +27,9 @@ transactionSchema.methods.calculateTokenValue = function () {
  */
 transactionSchema.statics.create = async function (txData) {
   const Transaction = this;
+  if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
+    return;
+  }
 
   const transaction = new Transaction(txData);
   transaction.endpointTokenConfig = txData.endpointTokenConfig;

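For clarity, the guard only calls `isNaN` after a loose null check, so absent amounts still fall through to normal handling. A small illustrative sketch (not part of the commit):

// Hypothetical helper mirroring the guard above.
const shouldSkip = (rawAmount) => rawAmount != null && isNaN(rawAmount);

shouldSkip(NaN);       // true  -> Transaction.create returns early, no balance change
shouldSkip(undefined); // false -> `!= null` filters undefined out before isNaN runs
shouldSkip(null);      // false -> the same loose check covers null
shouldSkip(-150);      // false -> an ordinary token spend proceeds as before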

@@ -1,5 +1,6 @@
 const mongoose = require('mongoose');
 const { MongoMemoryServer } = require('mongodb-memory-server');
+const { Transaction } = require('./Transaction');
 const Balance = require('./Balance');
 const { spendTokens, spendStructuredTokens } = require('./spendTokens');
 const { getMultiplier, getCacheMultiplier } = require('./tx');

@@ -346,3 +347,28 @@ describe('Structured Token Spending Tests', () => {
     expect(result.completion.completion).toBeCloseTo(-50 * 15 * 1.15, 0); // Assuming multiplier is 15 and cancelRate is 1.15
   });
 });
+
+describe('NaN Handling Tests', () => {
+  test('should skip transaction creation when rawAmount is NaN', async () => {
+    const userId = new mongoose.Types.ObjectId();
+    const initialBalance = 10000000;
+    await Balance.create({ user: userId, tokenCredits: initialBalance });
+
+    const model = 'gpt-3.5-turbo';
+    const txData = {
+      user: userId,
+      conversationId: 'test-conversation-id',
+      model,
+      context: 'test',
+      endpointTokenConfig: null,
+      rawAmount: NaN,
+      tokenType: 'prompt',
+    };
+
+    const result = await Transaction.create(txData);
+    expect(result).toBeUndefined();
+
+    const balance = await Balance.findOne({ user: userId });
+    expect(balance.tokenCredits).toBe(initialBalance);
+  });
+});


@@ -44,7 +44,7 @@
     "@langchain/google-genai": "^0.1.4",
     "@langchain/google-vertexai": "^0.1.4",
     "@langchain/textsplitters": "^0.1.0",
-    "@librechat/agents": "^1.8.8",
+    "@librechat/agents": "^1.9.7",
     "axios": "^1.7.7",
     "bcryptjs": "^2.4.3",
     "cheerio": "^1.0.0-rc.12",


@@ -1,8 +1,10 @@
 const { Tools, StepTypes, imageGenTools, FileContext } = require('librechat-data-provider');
 const {
   EnvVar,
+  Providers,
   GraphEvents,
   ToolEndHandler,
+  handleToolCalls,
   ChatModelStreamHandler,
 } = require('@librechat/agents');
 const { processCodeOutput } = require('~/server/services/Files/Code/process');

@@ -57,13 +59,22 @@ class ModelEndHandler {
       return;
     }
 
-    const usage = data?.output?.usage_metadata;
-    if (metadata?.model) {
-      usage.model = metadata.model;
-    }
+    try {
+      if (metadata.provider === Providers.GOOGLE) {
+        handleToolCalls(data?.output?.tool_calls, metadata, graph);
+      }
 
-    if (usage) {
+      const usage = data?.output?.usage_metadata;
+      if (!usage) {
+        return;
+      }
+      if (metadata?.model) {
+        usage.model = metadata.model;
+      }
+
       this.collectedUsage.push(usage);
+    } catch (error) {
+      logger.error('Error handling model end event:', error);
     }
   }
 }

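For reference, the `usage_metadata` collected here is LangChain's standard token-usage object attached to an AIMessage; not every provider or chunk carries it, which is why the handler now returns early when it is missing. A rough sketch of the shape being collected (values illustrative, not from the commit):

const collectedUsage = [];

// Typical LangChain usage_metadata on a model's final message:
const usage = {
  input_tokens: 412,
  output_tokens: 128,
  total_tokens: 540,
};

// The handler tags it with the model name (when metadata.model is set) before collecting it:
usage.model = 'gemini-1.5-pro';
collectedUsage.push(usage);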

@@ -398,15 +398,17 @@ const chatV2 = async (req, res) => {
       response = streamRunManager;
       response.text = streamRunManager.intermediateText;
 
-      const messageCache = getLogStores(CacheKeys.MESSAGES);
-      messageCache.set(
-        responseMessageId,
-        {
-          complete: true,
-          text: response.text,
-        },
-        Time.FIVE_MINUTES,
-      );
+      if (response.text) {
+        const messageCache = getLogStores(CacheKeys.MESSAGES);
+        messageCache.set(
+          responseMessageId,
+          {
+            complete: true,
+            text: response.text,
+          },
+          Time.FIVE_MINUTES,
+        );
+      }
     };
 
     await processRun();


@@ -12,6 +12,7 @@ const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize'
 const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options');
 const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
 const initCustom = require('~/server/services/Endpoints/custom/initialize');
+const initGoogle = require('~/server/services/Endpoints/google/initialize');
 const { getCustomEndpointConfig } = require('~/server/services/Config');
 const { loadAgentTools } = require('~/server/services/ToolService');
 const AgentClient = require('~/server/controllers/agents/client');

@@ -24,6 +25,7 @@ const providerConfigMap = {
   [EModelEndpoint.azureOpenAI]: initOpenAI,
   [EModelEndpoint.anthropic]: initAnthropic,
   [EModelEndpoint.bedrock]: getBedrockOptions,
+  [EModelEndpoint.google]: initGoogle,
   [Providers.OLLAMA]: initCustom,
 };

@@ -116,6 +118,10 @@ const initializeAgentOptions = async ({
     endpointOption: _endpointOption,
   });
 
+  if (options.provider != null) {
+    agent.provider = options.provider;
+  }
+
   agent.model_parameters = Object.assign(model_parameters, options.llmConfig);
   if (options.configOptions) {
     agent.model_parameters.configuration = options.configOptions;

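Only the Google initializer returns a `provider` field (GenAI vs. Vertex AI is decided from the supplied credentials, as shown in the new Google `llm.js` later in this commit), so the `!= null` check keeps the other endpoints from clearing `agent.provider`. A hedged sketch of how the returned value is applied; names follow the hunk above, values are made up:

// Illustrative only.
const agent = { provider: 'google' };
const model_parameters = { temperature: 0.7 };

const options = {
  provider: 'vertexai', // Providers.VERTEXAI when a service-account key is configured
  llmConfig: { model: 'gemini-1.5-pro', location: 'us-central1' },
};

if (options.provider != null) {
  agent.provider = options.provider; // the "google" agent is re-pointed at the Vertex AI provider
}
agent.model_parameters = Object.assign(model_parameters, options.llmConfig);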

@@ -20,7 +20,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
     checkUserKeyExpiry(expiresAt, EModelEndpoint.anthropic);
   }
 
-  const clientOptions = {};
+  let clientOptions = {};
 
   /** @type {undefined | TBaseEndpoint} */
   const anthropicConfig = req.app.locals[EModelEndpoint.anthropic];

@@ -36,7 +36,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
   }
 
   if (optionsOnly) {
-    const requestOptions = Object.assign(
+    clientOptions = Object.assign(
       {
         reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
         proxy: PROXY ?? null,

@@ -45,9 +45,9 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
       clientOptions,
     );
     if (overrideModel) {
-      requestOptions.modelOptions.model = overrideModel;
+      clientOptions.modelOptions.model = overrideModel;
     }
-    return getLLMConfig(anthropicApiKey, requestOptions);
+    return getLLMConfig(anthropicApiKey, clientOptions);
   }
 
   const client = new AnthropicClient(anthropicApiKey, {


@@ -28,28 +28,32 @@ function getLLMConfig(apiKey, options = {}) {
   const mergedOptions = Object.assign(defaultOptions, options.modelOptions);
 
   /** @type {AnthropicClientOptions} */
   const requestOptions = {
     apiKey,
     model: mergedOptions.model,
     stream: mergedOptions.stream,
     temperature: mergedOptions.temperature,
-    top_p: mergedOptions.topP,
-    top_k: mergedOptions.topK,
-    stop_sequences: mergedOptions.stop,
-    max_tokens:
+    topP: mergedOptions.topP,
+    topK: mergedOptions.topK,
+    stopSequences: mergedOptions.stop,
+    maxTokens:
       mergedOptions.maxOutputTokens || anthropicSettings.maxOutputTokens.reset(mergedOptions.model),
+    clientOptions: {},
   };
 
-  const configOptions = {};
   if (options.proxy) {
-    configOptions.httpAgent = new HttpsProxyAgent(options.proxy);
+    requestOptions.clientOptions.httpAgent = new HttpsProxyAgent(options.proxy);
   }
 
   if (options.reverseProxyUrl) {
-    configOptions.baseURL = options.reverseProxyUrl;
+    requestOptions.clientOptions.baseURL = options.reverseProxyUrl;
   }
 
-  return { llmConfig: removeNullishValues(requestOptions), configOptions };
+  return {
+    /** @type {AnthropicClientOptions} */
+    llmConfig: removeNullishValues(requestOptions),
+  };
 }
 
 module.exports = { getLLMConfig };

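To make the new shape concrete, here is a hedged sketch of a call to the reworked `getLLMConfig` (key names follow the diff above; the API key, model, proxy URL, and token count are illustrative):

const { llmConfig } = getLLMConfig('sk-ant-placeholder', {
  modelOptions: { model: 'claude-3-5-sonnet-20241022', temperature: 0.7 },
  reverseProxyUrl: 'https://my-anthropic-proxy.example.com',
});

// llmConfig (subset) now uses camelCase fields and nests SDK settings under clientOptions:
// {
//   apiKey: 'sk-ant-placeholder',
//   model: 'claude-3-5-sonnet-20241022',
//   temperature: 0.7,
//   maxTokens: 8192, // from anthropicSettings.maxOutputTokens.reset(model); actual default depends on the model
//   clientOptions: { baseURL: 'https://my-anthropic-proxy.example.com' },
// }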

@@ -60,42 +60,41 @@ const getOptions = async ({ req, endpointOption }) => {
     streamRate = allConfig.streamRate;
   }
 
-  /** @type {import('@librechat/agents').BedrockConverseClientOptions} */
-  const requestOptions = Object.assign(
-    {
-      model: endpointOption.model,
-      region: BEDROCK_AWS_DEFAULT_REGION,
-      streaming: true,
-      streamUsage: true,
-      callbacks: [
-        {
-          handleLLMNewToken: async () => {
-            if (!streamRate) {
-              return;
-            }
-            await sleep(streamRate);
-          },
-        },
-      ],
-    },
-    endpointOption.model_parameters,
-  );
+  /** @type {BedrockClientOptions} */
+  const requestOptions = {
+    model: endpointOption.model,
+    region: BEDROCK_AWS_DEFAULT_REGION,
+    streaming: true,
+    streamUsage: true,
+    callbacks: [
+      {
+        handleLLMNewToken: async () => {
+          if (!streamRate) {
+            return;
+          }
+          await sleep(streamRate);
+        },
+      },
+    ],
+  };
 
   if (credentials) {
     requestOptions.credentials = credentials;
   }
 
+  if (BEDROCK_REVERSE_PROXY) {
+    requestOptions.endpointHost = BEDROCK_REVERSE_PROXY;
+  }
+
   const configOptions = {};
   if (PROXY) {
+    /** NOTE: NOT SUPPORTED BY BEDROCK */
     configOptions.httpAgent = new HttpsProxyAgent(PROXY);
   }
 
-  if (BEDROCK_REVERSE_PROXY) {
-    configOptions.endpointHost = BEDROCK_REVERSE_PROXY;
-  }
-
   return {
-    llmConfig: removeNullishValues(requestOptions),
+    /** @type {BedrockClientOptions} */
+    llmConfig: removeNullishValues(Object.assign(requestOptions, endpointOption.model_parameters)),
     configOptions,
   };
 };

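Note the merge order in the new return statement: `Object.assign(requestOptions, endpointOption.model_parameters)` applies the user's parameters last, so they still override the defaults built above. A tiny sketch of that precedence (hypothetical values):

// Later sources win in Object.assign, so user-supplied model_parameters take precedence over the defaults.
const requestOptions = { model: 'anthropic.claude-3-sonnet-20240229-v1:0', streaming: true, streamUsage: true };
const model_parameters = { temperature: 0.5, maxTokens: 1024 };

const llmConfig = Object.assign(requestOptions, model_parameters);
// -> model/streaming/streamUsage unchanged, plus temperature and maxTokens from model_parameters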

@@ -123,7 +123,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
     customOptions.streamRate = allConfig.streamRate;
   }
 
-  const clientOptions = {
+  let clientOptions = {
     reverseProxyUrl: baseURL ?? null,
     proxy: PROXY ?? null,
     req,

@@ -135,13 +135,13 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
   }
 
   if (optionsOnly) {
     const modelOptions = endpointOption.model_parameters;
     if (endpoint !== Providers.OLLAMA) {
-      const requestOptions = Object.assign(
+      clientOptions = Object.assign(
         {
           modelOptions,
         },
         clientOptions,
       );
-      const options = getLLMConfig(apiKey, requestOptions);
+      const options = getLLMConfig(apiKey, clientOptions);
       if (!customOptions.streamRate) {
         return options;
       }


@@ -1,9 +1,10 @@
 const { EModelEndpoint, AuthKeys } = require('librechat-data-provider');
 const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
-const { GoogleClient } = require('~/app');
+const { getLLMConfig } = require('~/server/services/Endpoints/google/llm');
 const { isEnabled } = require('~/server/utils');
+const { GoogleClient } = require('~/app');
 
-const initializeClient = async ({ req, res, endpointOption }) => {
+const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => {
   const {
     GOOGLE_KEY,
     GOOGLE_REVERSE_PROXY,

@@ -33,7 +34,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     [AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY,
   };
 
-  const clientOptions = {};
+  let clientOptions = {};
 
   /** @type {undefined | TBaseEndpoint} */
   const allConfig = req.app.locals.all;

@@ -48,7 +49,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     clientOptions.streamRate = allConfig.streamRate;
   }
 
-  const client = new GoogleClient(credentials, {
+  clientOptions = {
     req,
     res,
     reverseProxyUrl: GOOGLE_REVERSE_PROXY ?? null,

@@ -56,7 +57,22 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     proxy: PROXY ?? null,
     ...clientOptions,
     ...endpointOption,
-  });
+  };
+
+  if (optionsOnly) {
+    clientOptions = Object.assign(
+      {
+        modelOptions: endpointOption.model_parameters,
+      },
+      clientOptions,
+    );
+    if (overrideModel) {
+      clientOptions.modelOptions.model = overrideModel;
+    }
+    return getLLMConfig(credentials, clientOptions);
+  }
+
+  const client = new GoogleClient(credentials, clientOptions);
 
   return {
     client,

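With `optionsOnly` set, the Google initializer now short-circuits before constructing a `GoogleClient` and instead returns the agent-oriented config from `getLLMConfig` (defined in the new Google `llm.js` below). A hedged sketch of the two call paths:

// Sketch only: both paths use the same initializeClient signature.
const agentOptions = await initializeClient({ req, res, endpointOption, optionsOnly: true });
// -> { provider, llmConfig } resolved from the user's Google credentials

const { client } = await initializeClient({ req, res, endpointOption });
// -> legacy path: a fully constructed GoogleClient, as before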

@@ -0,0 +1,146 @@
+const { Providers } = require('@librechat/agents');
+const { AuthKeys } = require('librechat-data-provider');
+
+// Example internal constant from your code
+const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/;
+
+function getSafetySettings() {
+  return [
+    {
+      category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
+      threshold: process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
+    },
+    {
+      category: 'HARM_CATEGORY_HATE_SPEECH',
+      threshold: process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
+    },
+    {
+      category: 'HARM_CATEGORY_HARASSMENT',
+      threshold: process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
+    },
+    {
+      category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
+      threshold: process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
+    },
+    {
+      category: 'HARM_CATEGORY_CIVIC_INTEGRITY',
+      threshold: process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'BLOCK_NONE',
+    },
+  ];
+}
+
+/**
+ * Replicates core logic from GoogleClient's constructor and setOptions, plus client determination.
+ * Returns an object with the provider label and the final options that would be passed to createLLM.
+ *
+ * @param {string | object} credentials - Either a JSON string or an object containing Google keys
+ * @param {object} [options={}] - The same shape as the "GoogleClient" constructor options
+ */
+function getLLMConfig(credentials, options = {}) {
+  // 1. Parse credentials
+  let creds = {};
+  if (typeof credentials === 'string') {
+    try {
+      creds = JSON.parse(credentials);
+    } catch (err) {
+      throw new Error(`Error parsing string credentials: ${err.message}`);
+    }
+  } else if (credentials && typeof credentials === 'object') {
+    creds = credentials;
+  }
+
+  // Extract from credentials
+  const serviceKeyRaw = creds[AuthKeys.GOOGLE_SERVICE_KEY] ?? {};
+  const serviceKey =
+    typeof serviceKeyRaw === 'string' ? JSON.parse(serviceKeyRaw) : serviceKeyRaw ?? {};
+  const project_id = serviceKey?.project_id ?? null;
+  const apiKey = creds[AuthKeys.GOOGLE_API_KEY] ?? null;
+
+  const reverseProxyUrl = options.reverseProxyUrl;
+  const authHeader = options.authHeader;
+
+  /** @type {GoogleClientOptions | VertexAIClientOptions} */
+  let llmConfig = {
+    ...(options.modelOptions || {}),
+    safetySettings: getSafetySettings(),
+    maxRetries: 2,
+  };
+
+  const isGenerativeModel = llmConfig.model.includes('gemini');
+  const isChatModel = !isGenerativeModel && llmConfig.model.includes('chat');
+  const isTextModel = !isGenerativeModel && !isChatModel && /code|text/.test(llmConfig.model);
+
+  let provider;
+
+  if (project_id && isTextModel) {
+    provider = Providers.VERTEXAI;
+  } else if (project_id && isChatModel) {
+    provider = Providers.VERTEXAI;
+  } else if (project_id) {
+    provider = Providers.VERTEXAI;
+  } else if (!EXCLUDED_GENAI_MODELS.test(llmConfig.model)) {
+    provider = Providers.GOOGLE;
+  } else {
+    provider = Providers.GOOGLE;
+  }
+
+  // If we have a GCP project => Vertex AI
+  if (project_id && provider === Providers.VERTEXAI) {
+    /** @type {VertexAIClientOptions['authOptions']} */
+    llmConfig.authOptions = {
+      credentials: { ...serviceKey },
+      projectId: project_id,
+    };
+    llmConfig.location = process.env.GOOGLE_LOC || 'us-central1';
+  } else if (apiKey && provider === Providers.GOOGLE) {
+    llmConfig.apiKey = apiKey;
+  }
+
+  /*
+  let legacyOptions = {};
+  // Filter out any "examples" that are empty
+  legacyOptions.examples = (legacyOptions.examples ?? [])
+    .filter(Boolean)
+    .filter((obj) => obj?.input?.content !== '' && obj?.output?.content !== '');
+
+  // If user has "examples" from legacyOptions, push them onto llmConfig
+  if (legacyOptions.examples?.length) {
+    llmConfig.examples = legacyOptions.examples.map((ex) => {
+      const { input, output } = ex;
+      if (!input?.content || !output?.content) {return undefined;}
+      return {
+        input: new HumanMessage(input.content),
+        output: new AIMessage(output.content),
+      };
+    }).filter(Boolean);
+  }
+  */
+
+  if (reverseProxyUrl) {
+    llmConfig.baseUrl = reverseProxyUrl;
+  }
+
+  if (authHeader) {
+    /**
+     * NOTE: NOT SUPPORTED BY LANGCHAIN GENAI CLIENT,
+     * REQUIRES PR IN https://github.com/langchain-ai/langchainjs
+     */
+    llmConfig.customHeaders = {
+      Authorization: `Bearer ${apiKey}`,
+    };
+  }
+
+  // Return the final shape
+  return {
+    /** @type {Providers.GOOGLE | Providers.VERTEXAI} */
+    provider,
+    /** @type {GoogleClientOptions | VertexAIClientOptions} */
+    llmConfig,
+  };
+}
+
+module.exports = {
+  getLLMConfig,
+};

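In practice the branching above reduces to: a service-account key (anything with a `project_id`) selects Vertex AI, otherwise an API key selects the GenAI provider. A hedged usage sketch, assuming the `AuthKeys` enum values are the literal key names and using placeholder credentials:

const vertex = getLLMConfig(
  { GOOGLE_SERVICE_KEY: { project_id: 'my-project', client_email: 'svc@my-project.iam.gserviceaccount.com', private_key: '...' } },
  { modelOptions: { model: 'gemini-1.5-pro' } },
);
// vertex.provider              -> Providers.VERTEXAI
// vertex.llmConfig.authOptions -> { credentials: { ...serviceKey }, projectId: 'my-project' }
// vertex.llmConfig.location    -> process.env.GOOGLE_LOC || 'us-central1'

const genai = getLLMConfig(
  { GOOGLE_API_KEY: 'AIza-placeholder' },
  { modelOptions: { model: 'gemini-1.5-flash' } },
);
// genai.provider         -> Providers.GOOGLE
// genai.llmConfig.apiKey -> 'AIza-placeholder'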

@@ -54,7 +54,7 @@ const initializeClient = async ({
   let apiKey = userProvidesKey ? userValues?.apiKey : credentials[endpoint];
   let baseURL = userProvidesURL ? userValues?.baseURL : baseURLOptions[endpoint];
 
-  const clientOptions = {
+  let clientOptions = {
     contextStrategy,
     proxy: PROXY ?? null,
     debug: isEnabled(DEBUG_OPENAI),

@@ -134,13 +134,13 @@
   }
 
   if (optionsOnly) {
-    const requestOptions = Object.assign(
+    clientOptions = Object.assign(
       {
         modelOptions: endpointOption.model_parameters,
       },
       clientOptions,
     );
-    const options = getLLMConfig(apiKey, requestOptions);
+    const options = getLLMConfig(apiKey, clientOptions);
     if (!clientOptions.streamRate) {
       return options;
     }


@@ -38,6 +38,7 @@ function getLLMConfig(apiKey, options = {}) {
     dropParams,
   } = options;
 
+  /** @type {OpenAIClientOptions} */
   let llmConfig = {
     streaming,
   };

@@ -54,29 +55,28 @@ function getLLMConfig(apiKey, options = {}) {
     });
   }
 
+  /** @type {OpenAIClientOptions['configuration']} */
   const configOptions = {};
 
   // Handle OpenRouter or custom reverse proxy
   if (useOpenRouter || reverseProxyUrl === 'https://openrouter.ai/api/v1') {
-    configOptions.basePath = 'https://openrouter.ai/api/v1';
-    configOptions.baseOptions = {
-      headers: Object.assign(
-        {
-          'HTTP-Referer': 'https://librechat.ai',
-          'X-Title': 'LibreChat',
-        },
-        headers,
-      ),
-    };
+    configOptions.baseURL = 'https://openrouter.ai/api/v1';
+    configOptions.defaultHeaders = Object.assign(
+      {
+        'HTTP-Referer': 'https://librechat.ai',
+        'X-Title': 'LibreChat',
+      },
+      headers,
+    );
   } else if (reverseProxyUrl) {
-    configOptions.basePath = reverseProxyUrl;
+    configOptions.baseURL = reverseProxyUrl;
     if (headers) {
-      configOptions.baseOptions = { headers };
+      configOptions.defaultHeaders = headers;
     }
   }
 
   if (defaultQuery) {
-    configOptions.baseOptions.defaultQuery = defaultQuery;
+    configOptions.defaultQuery = defaultQuery;
   }
 
   if (proxy) {

@@ -97,9 +97,9 @@ function getLLMConfig(apiKey, options = {}) {
       llmConfig.model = process.env.AZURE_OPENAI_DEFAULT_MODEL;
     }
 
-    if (configOptions.basePath) {
+    if (configOptions.baseURL) {
       const azureURL = constructAzureURL({
-        baseURL: configOptions.basePath,
+        baseURL: configOptions.baseURL,
         azureOptions: azure,
       });
       azure.azureOpenAIBasePath = azureURL.split(`/${azure.azureOpenAIApiDeploymentName}`)[0];

@@ -118,7 +118,12 @@ function getLLMConfig(apiKey, options = {}) {
     llmConfig.organization = process.env.OPENAI_ORGANIZATION;
   }
 
-  return { llmConfig, configOptions };
+  return {
+    /** @type {OpenAIClientOptions} */
+    llmConfig,
+    /** @type {OpenAIClientOptions['configuration']} */
+    configOptions,
+  };
 }
 
 module.exports = { getLLMConfig };

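The move from `basePath`/`baseOptions` to `baseURL`/`defaultHeaders` tracks the newer OpenAI SDK configuration consumed by `@langchain/openai`; downstream, `configOptions` ends up as the model's `configuration` (see the agents `initialize.js` hunk earlier, where `agent.model_parameters.configuration = options.configOptions`). A hedged sketch of the OpenRouter case (API key and model are placeholders):

// Illustrative only: rough shape of the result for an OpenRouter reverse proxy.
const { llmConfig, configOptions } = getLLMConfig('sk-or-placeholder', {
  reverseProxyUrl: 'https://openrouter.ai/api/v1',
  modelOptions: { model: 'openai/gpt-4o-mini' },
});

// configOptions ~ {
//   baseURL: 'https://openrouter.ai/api/v1',
//   defaultHeaders: { 'HTTP-Referer': 'https://librechat.ai', 'X-Title': 'LibreChat' },
// }
// llmConfig keeps the model parameters; configOptions is later attached as `configuration`.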

@@ -38,12 +38,36 @@
  * @memberof typedefs
  */
 
 /**
  * @exports OpenAIClientOptions
  * @typedef {import('@librechat/agents').OpenAIClientOptions} OpenAIClientOptions
  * @memberof typedefs
  */
 
+/**
+ * @exports AnthropicClientOptions
+ * @typedef {import('@librechat/agents').AnthropicClientOptions} AnthropicClientOptions
+ * @memberof typedefs
+ */
+
+/**
+ * @exports BedrockClientOptions
+ * @typedef {import('@librechat/agents').BedrockConverseClientOptions} BedrockClientOptions
+ * @memberof typedefs
+ */
+
+/**
+ * @exports VertexAIClientOptions
+ * @typedef {import('@librechat/agents').VertexAIClientOptions} VertexAIClientOptions
+ * @memberof typedefs
+ */
+
+/**
+ * @exports GoogleClientOptions
+ * @typedef {import('@librechat/agents').GoogleClientOptions} GoogleClientOptions
+ * @memberof typedefs
+ */
+
 /**
  * @exports StreamEventData
  * @typedef {import('@librechat/agents').StreamEventData} StreamEventData