mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-09-22 06:00:56 +02:00
🤖 fix: GoogleClient Context Handling & GenAI Parameters (#5503)
* fix: remove legacy code for GoogleClient and fix model parameters for GenAI * refactor: streamline client init logic * refactor: remove legacy vertex clients, WIP remote vertex token count * refactor: enhance GoogleClient with improved type definitions and streamline token count method * refactor: remove unused methods and consolidate methods * refactor: remove examples * refactor: improve input handling logic in DynamicInput component * refactor: enhance GoogleClient with token usage tracking and context handling improvements * refactor: update GoogleClient to support 'learnlm' model and streamline model checks * refactor: remove unused text model handling in GoogleClient * refactor: record token usage for GoogleClient titles and handle edge cases * chore: remove unused undici, addresses verbose version warning
This commit is contained in:
parent
47b72e8159
commit
528ee62eb1
12 changed files with 277 additions and 270 deletions
|
@ -13,7 +13,6 @@ const {
|
|||
const { extractBaseURL, constructAzureURL, genAzureChatCompletion } = require('~/utils');
|
||||
const { createContextHandlers } = require('./prompts');
|
||||
const { createCoherePayload } = require('./llm');
|
||||
const { Agent, ProxyAgent } = require('undici');
|
||||
const BaseClient = require('./BaseClient');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
|
@ -186,10 +185,6 @@ class ChatGPTClient extends BaseClient {
|
|||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
dispatcher: new Agent({
|
||||
bodyTimeout: 0,
|
||||
headersTimeout: 0,
|
||||
}),
|
||||
};
|
||||
|
||||
if (this.isVisionModel) {
|
||||
|
@ -275,10 +270,6 @@ class ChatGPTClient extends BaseClient {
|
|||
opts.headers['X-Title'] = 'LibreChat';
|
||||
}
|
||||
|
||||
if (this.options.proxy) {
|
||||
opts.dispatcher = new ProxyAgent(this.options.proxy);
|
||||
}
|
||||
|
||||
/* hacky fixes for Mistral AI API:
|
||||
- Re-orders system message to the top of the messages payload, as not allowed anywhere else
|
||||
- If there is only one message and it's a system message, change the role to user
|
||||
|
|
|
@ -1,22 +1,23 @@
|
|||
const { google } = require('googleapis');
|
||||
const { Agent, ProxyAgent } = require('undici');
|
||||
const { concat } = require('@langchain/core/utils/stream');
|
||||
const { ChatVertexAI } = require('@langchain/google-vertexai');
|
||||
const { GoogleVertexAI } = require('@langchain/google-vertexai');
|
||||
const { ChatGoogleVertexAI } = require('@langchain/google-vertexai');
|
||||
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
|
||||
const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
|
||||
const { AIMessage, HumanMessage, SystemMessage } = require('@langchain/core/messages');
|
||||
const { HumanMessage, SystemMessage } = require('@langchain/core/messages');
|
||||
const {
|
||||
googleGenConfigSchema,
|
||||
validateVisionModel,
|
||||
getResponseSender,
|
||||
endpointSettings,
|
||||
EModelEndpoint,
|
||||
VisionModes,
|
||||
ErrorTypes,
|
||||
Constants,
|
||||
AuthKeys,
|
||||
} = require('librechat-data-provider');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images');
|
||||
const Tokenizer = require('~/server/services/Tokenizer');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { sleep } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
@ -49,9 +50,10 @@ class GoogleClient extends BaseClient {
|
|||
const serviceKey = creds[AuthKeys.GOOGLE_SERVICE_KEY] ?? {};
|
||||
this.serviceKey =
|
||||
serviceKey && typeof serviceKey === 'string' ? JSON.parse(serviceKey) : serviceKey ?? {};
|
||||
/** @type {string | null | undefined} */
|
||||
this.project_id = this.serviceKey.project_id;
|
||||
this.client_email = this.serviceKey.client_email;
|
||||
this.private_key = this.serviceKey.private_key;
|
||||
this.project_id = this.serviceKey.project_id;
|
||||
this.access_token = null;
|
||||
|
||||
this.apiKey = creds[AuthKeys.GOOGLE_API_KEY];
|
||||
|
@ -60,6 +62,15 @@ class GoogleClient extends BaseClient {
|
|||
|
||||
this.authHeader = options.authHeader;
|
||||
|
||||
/** @type {UsageMetadata | undefined} */
|
||||
this.usage;
|
||||
/** The key for the usage object's input tokens
|
||||
* @type {string} */
|
||||
this.inputTokensKey = 'input_tokens';
|
||||
/** The key for the usage object's output tokens
|
||||
* @type {string} */
|
||||
this.outputTokensKey = 'output_tokens';
|
||||
|
||||
if (options.skipSetOptions) {
|
||||
return;
|
||||
}
|
||||
|
@ -119,22 +130,13 @@ class GoogleClient extends BaseClient {
|
|||
this.options = options;
|
||||
}
|
||||
|
||||
this.options.examples = (this.options.examples ?? [])
|
||||
.filter((ex) => ex)
|
||||
.filter((obj) => obj.input.content !== '' && obj.output.content !== '');
|
||||
|
||||
this.modelOptions = this.options.modelOptions || {};
|
||||
|
||||
this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments));
|
||||
|
||||
/** @type {boolean} Whether using a "GenerativeAI" Model */
|
||||
this.isGenerativeModel = this.modelOptions.model.includes('gemini');
|
||||
const { isGenerativeModel } = this;
|
||||
this.isChatModel = !isGenerativeModel && this.modelOptions.model.includes('chat');
|
||||
const { isChatModel } = this;
|
||||
this.isTextModel =
|
||||
!isGenerativeModel && !isChatModel && /code|text/.test(this.modelOptions.model);
|
||||
const { isTextModel } = this;
|
||||
this.isGenerativeModel =
|
||||
this.modelOptions.model.includes('gemini') || this.modelOptions.model.includes('learnlm');
|
||||
|
||||
this.maxContextTokens =
|
||||
this.options.maxContextTokens ??
|
||||
|
@ -170,40 +172,18 @@ class GoogleClient extends BaseClient {
|
|||
this.userLabel = this.options.userLabel || 'User';
|
||||
this.modelLabel = this.options.modelLabel || 'Assistant';
|
||||
|
||||
if (isChatModel || isGenerativeModel) {
|
||||
// Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
|
||||
// Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
|
||||
// without tripping the stop sequences, so I'm using "||>" instead.
|
||||
this.startToken = '||>';
|
||||
this.endToken = '';
|
||||
} else if (isTextModel) {
|
||||
this.startToken = '||>';
|
||||
this.endToken = '';
|
||||
} else {
|
||||
// Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting
|
||||
// system that causes only the first "<|endoftext|>" to be counted as 1 token, and the rest are not treated
|
||||
// as a single token. So we're using this instead.
|
||||
this.startToken = '||>';
|
||||
this.endToken = '';
|
||||
}
|
||||
|
||||
if (!this.modelOptions.stop) {
|
||||
const stopTokens = [this.startToken];
|
||||
if (this.endToken && this.endToken !== this.startToken) {
|
||||
stopTokens.push(this.endToken);
|
||||
}
|
||||
stopTokens.push(`\n${this.userLabel}:`);
|
||||
stopTokens.push('<|diff_marker|>');
|
||||
// I chose not to do one for `modelLabel` because I've never seen it happen
|
||||
this.modelOptions.stop = stopTokens;
|
||||
}
|
||||
|
||||
if (this.options.reverseProxyUrl) {
|
||||
this.completionsUrl = this.options.reverseProxyUrl;
|
||||
} else {
|
||||
this.completionsUrl = this.constructUrl();
|
||||
}
|
||||
|
||||
let promptPrefix = (this.options.promptPrefix ?? '').trim();
|
||||
if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
|
||||
promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim();
|
||||
}
|
||||
this.options.promptPrefix = promptPrefix;
|
||||
this.initializeClient();
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -336,7 +316,6 @@ class GoogleClient extends BaseClient {
|
|||
messages: [new HumanMessage(formatMessage({ message: latestMessage }))],
|
||||
},
|
||||
],
|
||||
parameters: this.modelOptions,
|
||||
};
|
||||
return { prompt: payload };
|
||||
}
|
||||
|
@ -352,23 +331,58 @@ class GoogleClient extends BaseClient {
|
|||
return { prompt: formattedMessages };
|
||||
}
|
||||
|
||||
async buildMessages(messages = [], parentMessageId) {
|
||||
/**
|
||||
* @param {TMessage[]} [messages=[]]
|
||||
* @param {string} [parentMessageId]
|
||||
*/
|
||||
async buildMessages(_messages = [], parentMessageId) {
|
||||
if (!this.isGenerativeModel && !this.project_id) {
|
||||
throw new Error(
|
||||
'[GoogleClient] a Service Account JSON Key is required for PaLM 2 and Codey models (Vertex AI)',
|
||||
);
|
||||
throw new Error('[GoogleClient] PaLM 2 and Codey models are no longer supported.');
|
||||
}
|
||||
|
||||
if (this.options.promptPrefix) {
|
||||
const instructionsTokenCount = this.getTokenCount(this.options.promptPrefix);
|
||||
|
||||
this.maxContextTokens = this.maxContextTokens - instructionsTokenCount;
|
||||
if (this.maxContextTokens < 0) {
|
||||
const info = `${instructionsTokenCount} / ${this.maxContextTokens}`;
|
||||
const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
|
||||
logger.warn(`Instructions token count exceeds max context (${info}).`);
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
}
|
||||
|
||||
for (let i = 0; i < _messages.length; i++) {
|
||||
const message = _messages[i];
|
||||
if (!message.tokenCount) {
|
||||
_messages[i].tokenCount = this.getTokenCountForMessage({
|
||||
role: message.isCreatedByUser ? 'user' : 'assistant',
|
||||
content: message.content ?? message.text,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const {
|
||||
payload: messages,
|
||||
tokenCountMap,
|
||||
promptTokens,
|
||||
} = await this.handleContextStrategy({
|
||||
orderedMessages: _messages,
|
||||
formattedMessages: _messages,
|
||||
});
|
||||
|
||||
if (!this.project_id && !EXCLUDED_GENAI_MODELS.test(this.modelOptions.model)) {
|
||||
return await this.buildGenerativeMessages(messages);
|
||||
const result = await this.buildGenerativeMessages(messages);
|
||||
result.tokenCountMap = tokenCountMap;
|
||||
result.promptTokens = promptTokens;
|
||||
return result;
|
||||
}
|
||||
|
||||
if (this.options.attachments && this.isGenerativeModel) {
|
||||
return this.buildVisionMessages(messages, parentMessageId);
|
||||
}
|
||||
|
||||
if (this.isTextModel) {
|
||||
return this.buildMessagesPrompt(messages, parentMessageId);
|
||||
const result = this.buildVisionMessages(messages, parentMessageId);
|
||||
result.tokenCountMap = tokenCountMap;
|
||||
result.promptTokens = promptTokens;
|
||||
return result;
|
||||
}
|
||||
|
||||
let payload = {
|
||||
|
@ -380,25 +394,14 @@ class GoogleClient extends BaseClient {
|
|||
.map((message) => formatMessage({ message, langChain: true })),
|
||||
},
|
||||
],
|
||||
parameters: this.modelOptions,
|
||||
};
|
||||
|
||||
let promptPrefix = (this.options.promptPrefix ?? '').trim();
|
||||
if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
|
||||
promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim();
|
||||
}
|
||||
|
||||
if (promptPrefix) {
|
||||
payload.instances[0].context = promptPrefix;
|
||||
}
|
||||
|
||||
if (this.options.examples.length > 0) {
|
||||
payload.instances[0].examples = this.options.examples;
|
||||
if (this.options.promptPrefix) {
|
||||
payload.instances[0].context = this.options.promptPrefix;
|
||||
}
|
||||
|
||||
logger.debug('[GoogleClient] buildMessages', payload);
|
||||
|
||||
return { prompt: payload };
|
||||
return { prompt: payload, tokenCountMap, promptTokens };
|
||||
}
|
||||
|
||||
async buildMessagesPrompt(messages, parentMessageId) {
|
||||
|
@ -412,10 +415,7 @@ class GoogleClient extends BaseClient {
|
|||
parentMessageId,
|
||||
});
|
||||
|
||||
const formattedMessages = orderedMessages.map((message) => ({
|
||||
author: message.isCreatedByUser ? this.userLabel : this.modelLabel,
|
||||
content: message?.content ?? message.text,
|
||||
}));
|
||||
const formattedMessages = orderedMessages.map(this.formatMessages());
|
||||
|
||||
let lastAuthor = '';
|
||||
let groupedMessages = [];
|
||||
|
@ -444,16 +444,6 @@ class GoogleClient extends BaseClient {
|
|||
}
|
||||
|
||||
let promptPrefix = (this.options.promptPrefix ?? '').trim();
|
||||
if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
|
||||
promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim();
|
||||
}
|
||||
if (promptPrefix) {
|
||||
// If the prompt prefix doesn't end with the end token, add it.
|
||||
if (!promptPrefix.endsWith(`${this.endToken}`)) {
|
||||
promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
|
||||
}
|
||||
promptPrefix = `\nContext:\n${promptPrefix}`;
|
||||
}
|
||||
|
||||
if (identityPrefix) {
|
||||
promptPrefix = `${identityPrefix}${promptPrefix}`;
|
||||
|
@ -490,7 +480,7 @@ class GoogleClient extends BaseClient {
|
|||
isCreatedByUser || !isEdited
|
||||
? `\n\n${message.author}:`
|
||||
: `${promptPrefix}\n\n${message.author}:`;
|
||||
const messageString = `${messagePrefix}\n${message.content}${this.endToken}\n`;
|
||||
const messageString = `${messagePrefix}\n${message.content}\n`;
|
||||
let newPromptBody = `${messageString}${promptBody}`;
|
||||
|
||||
context.unshift(message);
|
||||
|
@ -556,34 +546,6 @@ class GoogleClient extends BaseClient {
|
|||
return { prompt, context };
|
||||
}
|
||||
|
||||
async _getCompletion(payload, abortController = null) {
|
||||
if (!abortController) {
|
||||
abortController = new AbortController();
|
||||
}
|
||||
const { debug } = this.options;
|
||||
const url = this.completionsUrl;
|
||||
if (debug) {
|
||||
logger.debug('GoogleClient _getCompletion', { url, payload });
|
||||
}
|
||||
const opts = {
|
||||
method: 'POST',
|
||||
agent: new Agent({
|
||||
bodyTimeout: 0,
|
||||
headersTimeout: 0,
|
||||
}),
|
||||
signal: abortController.signal,
|
||||
};
|
||||
|
||||
if (this.options.proxy) {
|
||||
opts.agent = new ProxyAgent(this.options.proxy);
|
||||
}
|
||||
|
||||
const client = await this.getClient();
|
||||
const res = await client.request({ url, method: 'POST', data: payload });
|
||||
logger.debug('GoogleClient _getCompletion', { res });
|
||||
return res.data;
|
||||
}
|
||||
|
||||
createLLM(clientOptions) {
|
||||
const model = clientOptions.modelName ?? clientOptions.model;
|
||||
clientOptions.location = loc;
|
||||
|
@ -602,33 +564,20 @@ class GoogleClient extends BaseClient {
|
|||
}
|
||||
}
|
||||
|
||||
if (this.project_id && this.isTextModel) {
|
||||
logger.debug('Creating Google VertexAI client');
|
||||
return new GoogleVertexAI(clientOptions);
|
||||
} else if (this.project_id && this.isChatModel) {
|
||||
logger.debug('Creating Chat Google VertexAI client');
|
||||
return new ChatGoogleVertexAI(clientOptions);
|
||||
} else if (this.project_id) {
|
||||
if (this.project_id != null) {
|
||||
logger.debug('Creating VertexAI client');
|
||||
return new ChatVertexAI(clientOptions);
|
||||
} else if (!EXCLUDED_GENAI_MODELS.test(model)) {
|
||||
logger.debug('Creating GenAI client');
|
||||
return new GenAI(this.apiKey).getGenerativeModel({ ...clientOptions, model }, requestOptions);
|
||||
return new GenAI(this.apiKey).getGenerativeModel({ model }, requestOptions);
|
||||
}
|
||||
|
||||
logger.debug('Creating Chat Google Generative AI client');
|
||||
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
|
||||
}
|
||||
|
||||
async getCompletion(_payload, options = {}) {
|
||||
const { parameters, instances } = _payload;
|
||||
const { onProgress, abortController } = options;
|
||||
const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
|
||||
const { messages: _messages, context, examples: _examples } = instances?.[0] ?? {};
|
||||
|
||||
let examples;
|
||||
|
||||
let clientOptions = { ...parameters, maxRetries: 2 };
|
||||
initializeClient() {
|
||||
let clientOptions = { ...this.modelOptions, maxRetries: 2 };
|
||||
|
||||
if (this.project_id) {
|
||||
clientOptions['authOptions'] = {
|
||||
|
@ -639,53 +588,34 @@ class GoogleClient extends BaseClient {
|
|||
};
|
||||
}
|
||||
|
||||
if (!parameters) {
|
||||
clientOptions = { ...clientOptions, ...this.modelOptions };
|
||||
}
|
||||
|
||||
if (this.isGenerativeModel && !this.project_id) {
|
||||
clientOptions.modelName = clientOptions.model;
|
||||
delete clientOptions.model;
|
||||
}
|
||||
|
||||
if (_examples && _examples.length) {
|
||||
examples = _examples
|
||||
.map((ex) => {
|
||||
const { input, output } = ex;
|
||||
if (!input || !output) {
|
||||
return undefined;
|
||||
}
|
||||
return {
|
||||
input: new HumanMessage(input.content),
|
||||
output: new AIMessage(output.content),
|
||||
};
|
||||
})
|
||||
.filter((ex) => ex);
|
||||
this.client = this.createLLM(clientOptions);
|
||||
return this.client;
|
||||
}
|
||||
|
||||
clientOptions.examples = examples;
|
||||
}
|
||||
|
||||
const model = this.createLLM(clientOptions);
|
||||
async getCompletion(_payload, options = {}) {
|
||||
const safetySettings = this.getSafetySettings();
|
||||
const { onProgress, abortController } = options;
|
||||
const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
|
||||
const modelName = this.modelOptions.modelName ?? this.modelOptions.model ?? '';
|
||||
|
||||
let reply = '';
|
||||
const messages = this.isTextModel ? _payload.trim() : _messages;
|
||||
|
||||
if (!this.isVisionModel && context && messages?.length > 0) {
|
||||
messages.unshift(new SystemMessage(context));
|
||||
}
|
||||
|
||||
const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
|
||||
if (!EXCLUDED_GENAI_MODELS.test(modelName) && !this.project_id) {
|
||||
const client = model;
|
||||
/** @type {GenAI} */
|
||||
const client = this.client;
|
||||
/** @type {GenerateContentRequest} */
|
||||
const requestOptions = {
|
||||
safetySettings,
|
||||
contents: _payload,
|
||||
generationConfig: googleGenConfigSchema.parse(this.modelOptions),
|
||||
};
|
||||
|
||||
let promptPrefix = (this.options.promptPrefix ?? '').trim();
|
||||
if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
|
||||
promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim();
|
||||
}
|
||||
|
||||
const promptPrefix = (this.options.promptPrefix ?? '').trim();
|
||||
if (promptPrefix.length) {
|
||||
requestOptions.systemInstruction = {
|
||||
parts: [
|
||||
|
@ -696,11 +626,14 @@ class GoogleClient extends BaseClient {
|
|||
};
|
||||
}
|
||||
|
||||
requestOptions.safetySettings = _payload.safetySettings;
|
||||
|
||||
const delay = modelName.includes('flash') ? 8 : 15;
|
||||
/** @type {GenAIUsageMetadata} */
|
||||
let usageMetadata;
|
||||
const result = await client.generateContentStream(requestOptions);
|
||||
for await (const chunk of result.stream) {
|
||||
usageMetadata = !usageMetadata
|
||||
? chunk?.usageMetadata
|
||||
: Object.assign(usageMetadata, chunk?.usageMetadata);
|
||||
const chunkText = chunk.text();
|
||||
await this.generateTextStream(chunkText, onProgress, {
|
||||
delay,
|
||||
|
@ -708,12 +641,29 @@ class GoogleClient extends BaseClient {
|
|||
reply += chunkText;
|
||||
await sleep(streamRate);
|
||||
}
|
||||
|
||||
if (usageMetadata) {
|
||||
this.usage = {
|
||||
input_tokens: usageMetadata.promptTokenCount,
|
||||
output_tokens: usageMetadata.candidatesTokenCount,
|
||||
};
|
||||
}
|
||||
return reply;
|
||||
}
|
||||
|
||||
const stream = await model.stream(messages, {
|
||||
const { instances } = _payload;
|
||||
const { messages: messages, context } = instances?.[0] ?? {};
|
||||
|
||||
if (!this.isVisionModel && context && messages?.length > 0) {
|
||||
messages.unshift(new SystemMessage(context));
|
||||
}
|
||||
|
||||
/** @type {import('@langchain/core/messages').AIMessageChunk['usage_metadata']} */
|
||||
let usageMetadata;
|
||||
const stream = await this.client.stream(messages, {
|
||||
signal: abortController.signal,
|
||||
safetySettings: _payload.safetySettings,
|
||||
streamUsage: true,
|
||||
safetySettings,
|
||||
});
|
||||
|
||||
let delay = this.options.streamRate || 8;
|
||||
|
@ -728,6 +678,9 @@ class GoogleClient extends BaseClient {
|
|||
}
|
||||
|
||||
for await (const chunk of stream) {
|
||||
usageMetadata = !usageMetadata
|
||||
? chunk?.usage_metadata
|
||||
: concat(usageMetadata, chunk?.usage_metadata);
|
||||
const chunkText = chunk?.content ?? chunk;
|
||||
await this.generateTextStream(chunkText, onProgress, {
|
||||
delay,
|
||||
|
@ -735,88 +688,114 @@ class GoogleClient extends BaseClient {
|
|||
reply += chunkText;
|
||||
}
|
||||
|
||||
if (usageMetadata) {
|
||||
this.usage = usageMetadata;
|
||||
}
|
||||
return reply;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get stream usage as returned by this client's API response.
|
||||
* @returns {UsageMetadata} The stream usage object.
|
||||
*/
|
||||
getStreamUsage() {
|
||||
return this.usage;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates the correct token count for the current user message based on the token count map and API usage.
|
||||
* Edge case: If the calculation results in a negative value, it returns the original estimate.
|
||||
* If revisiting a conversation with a chat history entirely composed of token estimates,
|
||||
* the cumulative token count going forward should become more accurate as the conversation progresses.
|
||||
* @param {Object} params - The parameters for the calculation.
|
||||
* @param {Record<string, number>} params.tokenCountMap - A map of message IDs to their token counts.
|
||||
* @param {string} params.currentMessageId - The ID of the current message to calculate.
|
||||
* @param {UsageMetadata} params.usage - The usage object returned by the API.
|
||||
* @returns {number} The correct token count for the current user message.
|
||||
*/
|
||||
calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) {
|
||||
const originalEstimate = tokenCountMap[currentMessageId] || 0;
|
||||
|
||||
if (!usage || typeof usage.input_tokens !== 'number') {
|
||||
return originalEstimate;
|
||||
}
|
||||
|
||||
tokenCountMap[currentMessageId] = 0;
|
||||
const totalTokensFromMap = Object.values(tokenCountMap).reduce((sum, count) => {
|
||||
const numCount = Number(count);
|
||||
return sum + (isNaN(numCount) ? 0 : numCount);
|
||||
}, 0);
|
||||
const totalInputTokens = usage.input_tokens ?? 0;
|
||||
const currentMessageTokens = totalInputTokens - totalTokensFromMap;
|
||||
return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {number} params.promptTokens
|
||||
* @param {number} params.completionTokens
|
||||
* @param {UsageMetadata} [params.usage]
|
||||
* @param {string} [params.model]
|
||||
* @param {string} [params.context='message']
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async recordTokenUsage({ promptTokens, completionTokens, model, context = 'message' }) {
|
||||
await spendTokens(
|
||||
{
|
||||
context,
|
||||
user: this.user ?? this.options.req?.user?.id,
|
||||
conversationId: this.conversationId,
|
||||
model: model ?? this.modelOptions.model,
|
||||
endpointTokenConfig: this.options.endpointTokenConfig,
|
||||
},
|
||||
{ promptTokens, completionTokens },
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stripped-down logic for generating a title. This uses the non-streaming APIs, since the user does not see titles streaming
|
||||
*/
|
||||
async titleChatCompletion(_payload, options = {}) {
|
||||
const { abortController } = options;
|
||||
const { parameters, instances } = _payload;
|
||||
const { messages: _messages, examples: _examples } = instances?.[0] ?? {};
|
||||
|
||||
let clientOptions = { ...parameters, maxRetries: 2 };
|
||||
|
||||
logger.debug('Initialized title client options');
|
||||
|
||||
if (this.project_id) {
|
||||
clientOptions['authOptions'] = {
|
||||
credentials: {
|
||||
...this.serviceKey,
|
||||
},
|
||||
projectId: this.project_id,
|
||||
};
|
||||
}
|
||||
|
||||
if (!parameters) {
|
||||
clientOptions = { ...clientOptions, ...this.modelOptions };
|
||||
}
|
||||
|
||||
if (this.isGenerativeModel && !this.project_id) {
|
||||
clientOptions.modelName = clientOptions.model;
|
||||
delete clientOptions.model;
|
||||
}
|
||||
|
||||
const model = this.createLLM(clientOptions);
|
||||
const safetySettings = this.getSafetySettings();
|
||||
|
||||
let reply = '';
|
||||
const messages = this.isTextModel ? _payload.trim() : _messages;
|
||||
|
||||
const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
|
||||
if (!EXCLUDED_GENAI_MODELS.test(modelName) && !this.project_id) {
|
||||
const model = this.modelOptions.modelName ?? this.modelOptions.model ?? '';
|
||||
if (!EXCLUDED_GENAI_MODELS.test(model) && !this.project_id) {
|
||||
logger.debug('Identified titling model as GenAI version');
|
||||
/** @type {GenerativeModel} */
|
||||
const client = model;
|
||||
const client = this.client;
|
||||
const requestOptions = {
|
||||
contents: _payload,
|
||||
safetySettings,
|
||||
generationConfig: {
|
||||
temperature: 0.5,
|
||||
},
|
||||
};
|
||||
|
||||
let promptPrefix = (this.options.promptPrefix ?? '').trim();
|
||||
if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
|
||||
promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim();
|
||||
}
|
||||
|
||||
if (this.options?.promptPrefix?.length) {
|
||||
requestOptions.systemInstruction = {
|
||||
parts: [
|
||||
{
|
||||
text: promptPrefix,
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
const safetySettings = _payload.safetySettings;
|
||||
requestOptions.safetySettings = safetySettings;
|
||||
|
||||
const result = await client.generateContent(requestOptions);
|
||||
|
||||
reply = result.response?.text();
|
||||
|
||||
return reply;
|
||||
} else {
|
||||
logger.debug('Beginning titling');
|
||||
const safetySettings = _payload.safetySettings;
|
||||
|
||||
const titleResponse = await model.invoke(messages, {
|
||||
const { instances } = _payload;
|
||||
const { messages } = instances?.[0] ?? {};
|
||||
const titleResponse = await this.client.invoke(messages, {
|
||||
signal: abortController.signal,
|
||||
timeout: 7000,
|
||||
safetySettings: safetySettings,
|
||||
safetySettings,
|
||||
});
|
||||
|
||||
if (titleResponse.usage_metadata) {
|
||||
await this.recordTokenUsage({
|
||||
model,
|
||||
promptTokens: titleResponse.usage_metadata.input_tokens,
|
||||
completionTokens: titleResponse.usage_metadata.output_tokens,
|
||||
context: 'title',
|
||||
});
|
||||
}
|
||||
|
||||
reply = titleResponse.content;
|
||||
// TODO: RECORD TOKEN USAGE
|
||||
return reply;
|
||||
}
|
||||
}
|
||||
|
@ -840,15 +819,19 @@ class GoogleClient extends BaseClient {
|
|||
},
|
||||
]);
|
||||
|
||||
const model = process.env.GOOGLE_TITLE_MODEL ?? this.modelOptions.model;
|
||||
const availableModels = this.options.modelsConfig?.[EModelEndpoint.google];
|
||||
this.isVisionModel = validateVisionModel({ model, availableModels });
|
||||
|
||||
if (this.isVisionModel) {
|
||||
logger.warn(
|
||||
`Current vision model does not support titling without an attachment; falling back to default model ${settings.model.default}`,
|
||||
);
|
||||
|
||||
payload.parameters = { ...payload.parameters, model: settings.model.default };
|
||||
this.modelOptions.model = settings.model.default;
|
||||
}
|
||||
|
||||
try {
|
||||
this.initializeClient();
|
||||
title = await this.titleChatCompletion(payload, {
|
||||
abortController: new AbortController(),
|
||||
onProgress: () => {},
|
||||
|
@ -865,6 +848,7 @@ class GoogleClient extends BaseClient {
|
|||
endpointType: null,
|
||||
artifacts: this.options.artifacts,
|
||||
promptPrefix: this.options.promptPrefix,
|
||||
maxContextTokens: this.options.maxContextTokens,
|
||||
modelLabel: this.options.modelLabel,
|
||||
iconURL: this.options.iconURL,
|
||||
greeting: this.options.greeting,
|
||||
|
@ -878,8 +862,6 @@ class GoogleClient extends BaseClient {
|
|||
}
|
||||
|
||||
async sendCompletion(payload, opts = {}) {
|
||||
payload.safetySettings = this.getSafetySettings();
|
||||
|
||||
let reply = '';
|
||||
reply = await this.getCompletion(payload, opts);
|
||||
return reply.trim();
|
||||
|
@ -931,6 +913,22 @@ class GoogleClient extends BaseClient {
|
|||
return 'cl100k_base';
|
||||
}
|
||||
|
||||
async getVertexTokenCount(text) {
|
||||
/** @type {ChatVertexAI} */
|
||||
const client = this.client ?? this.initializeClient();
|
||||
const connection = client.connection;
|
||||
const gAuthClient = connection.client;
|
||||
const tokenEndpoint = `https://${connection._endpoint}/${connection.apiVersion}/projects/${this.project_id}/locations/${connection._location}/publishers/google/models/${connection.model}/:countTokens`;
|
||||
const result = await gAuthClient.request({
|
||||
url: tokenEndpoint,
|
||||
method: 'POST',
|
||||
data: {
|
||||
contents: [{ role: 'user', parts: [{ text }] }],
|
||||
},
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the token count of a given text. It also checks and resets the tokenizers if necessary.
|
||||
* @param {string} text - The text to get the token count for.
|
||||
|
|
|
@ -102,7 +102,6 @@
|
|||
"tiktoken": "^1.0.15",
|
||||
"traverse": "^0.6.7",
|
||||
"ua-parser-js": "^1.0.36",
|
||||
"undici": "^7.2.3",
|
||||
"winston": "^3.11.0",
|
||||
"winston-daily-rotate-file": "^4.7.1",
|
||||
"zod": "^3.22.4"
|
||||
|
|
|
@ -11,6 +11,7 @@ const buildOptions = (endpoint, parsedBody) => {
|
|||
greeting,
|
||||
spec,
|
||||
artifacts,
|
||||
maxContextTokens,
|
||||
...modelOptions
|
||||
} = parsedBody;
|
||||
const endpointOption = removeNullishValues({
|
||||
|
@ -22,6 +23,7 @@ const buildOptions = (endpoint, parsedBody) => {
|
|||
iconURL,
|
||||
greeting,
|
||||
spec,
|
||||
maxContextTokens,
|
||||
modelOptions,
|
||||
});
|
||||
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
const { Providers } = require('@librechat/agents');
|
||||
const { AuthKeys } = require('librechat-data-provider');
|
||||
|
||||
// Example internal constant from your code
|
||||
const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/;
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {boolean} isGemini2
|
||||
|
@ -89,22 +86,12 @@ function getLLMConfig(credentials, options = {}) {
|
|||
|
||||
/** Used only for Safety Settings */
|
||||
const isGemini2 = llmConfig.model.includes('gemini-2.0') && !llmConfig.model.includes('thinking');
|
||||
const isGenerativeModel = llmConfig.model.includes('gemini');
|
||||
const isChatModel = !isGenerativeModel && llmConfig.model.includes('chat');
|
||||
const isTextModel = !isGenerativeModel && !isChatModel && /code|text/.test(llmConfig.model);
|
||||
|
||||
llmConfig.safetySettings = getSafetySettings(isGemini2);
|
||||
|
||||
let provider;
|
||||
|
||||
if (project_id && isTextModel) {
|
||||
if (project_id) {
|
||||
provider = Providers.VERTEXAI;
|
||||
} else if (project_id && isChatModel) {
|
||||
provider = Providers.VERTEXAI;
|
||||
} else if (project_id) {
|
||||
provider = Providers.VERTEXAI;
|
||||
} else if (!EXCLUDED_GENAI_MODELS.test(llmConfig.model)) {
|
||||
provider = Providers.GOOGLE;
|
||||
} else {
|
||||
provider = Providers.GOOGLE;
|
||||
}
|
||||
|
|
|
@ -155,6 +155,18 @@
|
|||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports GenerateContentRequest
|
||||
* @typedef {import('@google/generative-ai').GenerateContentRequest} GenerateContentRequest
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports GenAIUsageMetadata
|
||||
* @typedef {import('@google/generative-ai').UsageMetadata} GenAIUsageMetadata
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports AssistantStreamEvent
|
||||
* @typedef {import('openai').default.Beta.AssistantStreamEvent} AssistantStreamEvent
|
||||
|
|
|
@ -34,7 +34,10 @@ function getOpenAIColor(_model: string | null | undefined) {
|
|||
function getGoogleIcon(model: string | null | undefined, size: number) {
|
||||
if (model?.toLowerCase().includes('code') === true) {
|
||||
return <CodeyIcon size={size * 0.75} />;
|
||||
} else if (model?.toLowerCase().includes('gemini') === true) {
|
||||
} else if (
|
||||
model?.toLowerCase().includes('gemini') === true ||
|
||||
model?.toLowerCase().includes('learnlm') === true
|
||||
) {
|
||||
return <GeminiIcon size={size * 0.7} />;
|
||||
} else {
|
||||
return <PaLMIcon size={size * 0.7} />;
|
||||
|
@ -44,7 +47,10 @@ function getGoogleIcon(model: string | null | undefined, size: number) {
|
|||
function getGoogleModelName(model: string | null | undefined) {
|
||||
if (model?.toLowerCase().includes('code') === true) {
|
||||
return 'Codey';
|
||||
} else if (model?.toLowerCase().includes('gemini') === true) {
|
||||
} else if (
|
||||
model?.toLowerCase().includes('gemini') === true ||
|
||||
model?.toLowerCase().includes('learnlm') === true
|
||||
) {
|
||||
return 'Gemini';
|
||||
} else {
|
||||
return 'PaLM2';
|
||||
|
|
|
@ -48,12 +48,15 @@ function DynamicInput({
|
|||
|
||||
const handleInputChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const value = e.target.value;
|
||||
if (type === 'number') {
|
||||
if (!isNaN(Number(value))) {
|
||||
setInputValue(e, true);
|
||||
}
|
||||
} else {
|
||||
if (type !== 'number') {
|
||||
setInputValue(e);
|
||||
return;
|
||||
}
|
||||
|
||||
if (value === '') {
|
||||
setInputValue(e);
|
||||
} else if (!isNaN(Number(value))) {
|
||||
setInputValue(e, true);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
12
package-lock.json
generated
12
package-lock.json
generated
|
@ -111,7 +111,6 @@
|
|||
"tiktoken": "^1.0.15",
|
||||
"traverse": "^0.6.7",
|
||||
"ua-parser-js": "^1.0.36",
|
||||
"undici": "^7.2.3",
|
||||
"winston": "^3.11.0",
|
||||
"winston-daily-rotate-file": "^4.7.1",
|
||||
"zod": "^3.22.4"
|
||||
|
@ -33055,15 +33054,6 @@
|
|||
"integrity": "sha512-WxONCrssBM8TSPRqN5EmsjVrsv4A8X12J4ArBiiayv3DyyG3ZlIg6yysuuSYdZsVz3TKcTg2fd//Ujd4CHV1iA==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/undici": {
|
||||
"version": "7.2.3",
|
||||
"resolved": "https://registry.npmjs.org/undici/-/undici-7.2.3.tgz",
|
||||
"integrity": "sha512-2oSLHaDalSt2/O/wHA9M+/ZPAOcU2yrSP/cdBYJ+YxZskiPYDSqHbysLSlD7gq3JMqOoJI5O31RVU3BxX/MnAA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=20.18.1"
|
||||
}
|
||||
},
|
||||
"node_modules/undici-types": {
|
||||
"version": "5.26.5",
|
||||
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
|
||||
|
@ -35030,7 +35020,7 @@
|
|||
},
|
||||
"packages/data-provider": {
|
||||
"name": "librechat-data-provider",
|
||||
"version": "0.7.694",
|
||||
"version": "0.7.695",
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"axios": "^1.7.7",
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "librechat-data-provider",
|
||||
"version": "0.7.694",
|
||||
"version": "0.7.695",
|
||||
"description": "data services for librechat apps",
|
||||
"main": "dist/index.js",
|
||||
"module": "dist/index.es.js",
|
||||
|
|
|
@ -272,7 +272,7 @@ export const getResponseSender = (endpointOption: t.TEndpointOption): string =>
|
|||
if (endpoint === EModelEndpoint.google) {
|
||||
if (modelLabel) {
|
||||
return modelLabel;
|
||||
} else if (model && model.includes('gemini')) {
|
||||
} else if (model && (model.includes('gemini') || model.includes('learnlm'))) {
|
||||
return 'Gemini';
|
||||
} else if (model && model.includes('code')) {
|
||||
return 'Codey';
|
||||
|
|
|
@ -788,6 +788,25 @@ export const googleSchema = tConversationSchema
|
|||
maxContextTokens: undefined,
|
||||
}));
|
||||
|
||||
/**
|
||||
* TODO: Map the following fields:
|
||||
- presence_penalty -> presencePenalty
|
||||
- frequency_penalty -> frequencyPenalty
|
||||
- stop -> stopSequences
|
||||
*/
|
||||
export const googleGenConfigSchema = z
|
||||
.object({
|
||||
maxOutputTokens: coerceNumber.optional(),
|
||||
temperature: coerceNumber.optional(),
|
||||
topP: coerceNumber.optional(),
|
||||
topK: coerceNumber.optional(),
|
||||
presencePenalty: coerceNumber.optional(),
|
||||
frequencyPenalty: coerceNumber.optional(),
|
||||
stopSequences: z.array(z.string()).optional(),
|
||||
})
|
||||
.strip()
|
||||
.optional();
|
||||
|
||||
export const bingAISchema = tConversationSchema
|
||||
.pick({
|
||||
jailbreak: true,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue