🔧 refactor: Improve Agent Context & Minor Fixes (#5349)

* refactor: Improve Context for Agents

* 🔧 fix: Safeguard against undefined properties in OpenAIClient response handling

* refactor: log error before re-throwing for original stack trace

* refactor: remove toolResource state from useFileHandling, allow svg files

* refactor: prevent verbose logs from axios errors when using actions

* refactor: add silent method recordTokenUsage in AgentClient

* refactor: streamline token count assignment in BaseClient

* refactor: enhance safety settings handling for Gemini 2.0 model

* fix: capabilities structure in MCPConnection

* refactor: simplify civic integrity threshold handling in GoogleClient and llm

* refactor: update token count retrieval method in BaseClient tests

* ci: fix test for svg
This commit is contained in:
Danny Avila 2025-01-17 12:55:48 -05:00 committed by GitHub
parent e309c6abef
commit b35a8b78e2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 324 additions and 112 deletions

View file

@ -4,6 +4,7 @@ const {
supportsBalanceCheck, supportsBalanceCheck,
isAgentsEndpoint, isAgentsEndpoint,
isParamEndpoint, isParamEndpoint,
EModelEndpoint,
ErrorTypes, ErrorTypes,
Constants, Constants,
CacheKeys, CacheKeys,
@ -11,6 +12,7 @@ const {
} = require('librechat-data-provider'); } = require('librechat-data-provider');
const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models'); const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils'); const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
const { truncateToolCallOutputs } = require('./prompts');
const checkBalance = require('~/models/checkBalance'); const checkBalance = require('~/models/checkBalance');
const { getFiles } = require('~/models/File'); const { getFiles } = require('~/models/File');
const { getLogStores } = require('~/cache'); const { getLogStores } = require('~/cache');
@ -95,7 +97,7 @@ class BaseClient {
* @returns {number} * @returns {number}
*/ */
getTokenCountForResponse(responseMessage) { getTokenCountForResponse(responseMessage) {
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', responseMessage); logger.debug('[BaseClient] `recordTokenUsage` not implemented.', responseMessage);
} }
/** /**
@ -106,7 +108,7 @@ class BaseClient {
* @returns {Promise<void>} * @returns {Promise<void>}
*/ */
async recordTokenUsage({ promptTokens, completionTokens }) { async recordTokenUsage({ promptTokens, completionTokens }) {
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', { logger.debug('[BaseClient] `recordTokenUsage` not implemented.', {
promptTokens, promptTokens,
completionTokens, completionTokens,
}); });
@ -287,6 +289,9 @@ class BaseClient {
} }
async handleTokenCountMap(tokenCountMap) { async handleTokenCountMap(tokenCountMap) {
if (this.clientName === EModelEndpoint.agents) {
return;
}
if (this.currentMessages.length === 0) { if (this.currentMessages.length === 0) {
return; return;
} }
@ -394,6 +399,21 @@ class BaseClient {
_instructions && logger.debug('[BaseClient] instructions tokenCount: ' + tokenCount); _instructions && logger.debug('[BaseClient] instructions tokenCount: ' + tokenCount);
let payload = this.addInstructions(formattedMessages, _instructions); let payload = this.addInstructions(formattedMessages, _instructions);
let orderedWithInstructions = this.addInstructions(orderedMessages, instructions); let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);
if (this.clientName === EModelEndpoint.agents) {
const { dbMessages, editedIndices } = truncateToolCallOutputs(
orderedWithInstructions,
this.maxContextTokens,
this.getTokenCountForMessage.bind(this),
);
if (editedIndices.length > 0) {
logger.debug('[BaseClient] Truncated tool call outputs:', editedIndices);
for (const index of editedIndices) {
payload[index].content = dbMessages[index].content;
}
orderedWithInstructions = dbMessages;
}
}
let { context, remainingContextTokens, messagesToRefine, summaryIndex } = let { context, remainingContextTokens, messagesToRefine, summaryIndex } =
await this.getMessagesWithinTokenLimit(orderedWithInstructions); await this.getMessagesWithinTokenLimit(orderedWithInstructions);
@ -625,7 +645,7 @@ class BaseClient {
await this.updateUserMessageTokenCount({ usage, tokenCountMap, userMessage, opts }); await this.updateUserMessageTokenCount({ usage, tokenCountMap, userMessage, opts });
} else { } else {
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage); responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
completionTokens = this.getTokenCount(completion); completionTokens = responseMessage.tokenCount;
} }
await this.recordTokenUsage({ promptTokens, completionTokens, usage }); await this.recordTokenUsage({ promptTokens, completionTokens, usage });

View file

@ -886,32 +886,42 @@ class GoogleClient extends BaseClient {
} }
getSafetySettings() { getSafetySettings() {
const isGemini2 = this.modelOptions.model.includes('gemini-2.0');
const mapThreshold = (value) => {
if (isGemini2 && value === 'BLOCK_NONE') {
return 'OFF';
}
return value;
};
return [ return [
{ {
category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
threshold: threshold: mapThreshold(
process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
),
}, },
{ {
category: 'HARM_CATEGORY_HATE_SPEECH', category: 'HARM_CATEGORY_HATE_SPEECH',
threshold: process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', threshold: mapThreshold(
process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
),
}, },
{ {
category: 'HARM_CATEGORY_HARASSMENT', category: 'HARM_CATEGORY_HARASSMENT',
threshold: process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', threshold: mapThreshold(
process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
),
}, },
{ {
category: 'HARM_CATEGORY_DANGEROUS_CONTENT', category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
threshold: threshold: mapThreshold(
process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
),
}, },
{ {
category: 'HARM_CATEGORY_CIVIC_INTEGRITY', category: 'HARM_CATEGORY_CIVIC_INTEGRITY',
/** threshold: mapThreshold(process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'BLOCK_NONE'),
* Note: this was added since `gemini-2.0-flash-thinking-exp-1219` does not
* accept 'HARM_BLOCK_THRESHOLD_UNSPECIFIED' for 'HARM_CATEGORY_CIVIC_INTEGRITY'
* */
threshold: process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'BLOCK_NONE',
}, },
]; ];
} }

View file

@ -1293,7 +1293,7 @@ ${convo}
}); });
for await (const chunk of stream) { for await (const chunk of stream) {
const token = chunk.choices[0]?.delta?.content || ''; const token = chunk?.choices?.[0]?.delta?.content || '';
intermediateReply.push(token); intermediateReply.push(token);
onProgress(token); onProgress(token);
if (abortController.signal.aborted) { if (abortController.signal.aborted) {

View file

@ -4,7 +4,7 @@ const summaryPrompts = require('./summaryPrompts');
const handleInputs = require('./handleInputs'); const handleInputs = require('./handleInputs');
const instructions = require('./instructions'); const instructions = require('./instructions');
const titlePrompts = require('./titlePrompts'); const titlePrompts = require('./titlePrompts');
const truncateText = require('./truncateText'); const truncate = require('./truncate');
const createVisionPrompt = require('./createVisionPrompt'); const createVisionPrompt = require('./createVisionPrompt');
const createContextHandlers = require('./createContextHandlers'); const createContextHandlers = require('./createContextHandlers');
@ -15,7 +15,7 @@ module.exports = {
...handleInputs, ...handleInputs,
...instructions, ...instructions,
...titlePrompts, ...titlePrompts,
...truncateText, ...truncate,
createVisionPrompt, createVisionPrompt,
createContextHandlers, createContextHandlers,
}; };

View file

@ -0,0 +1,115 @@
const MAX_CHAR = 255;
/**
* Truncates a given text to a specified maximum length, appending ellipsis and a notification
* if the original text exceeds the maximum length.
*
* @param {string} text - The text to be truncated.
* @param {number} [maxLength=MAX_CHAR] - The maximum length of the text after truncation. Defaults to MAX_CHAR.
* @returns {string} The truncated text if the original text length exceeds maxLength, otherwise returns the original text.
*/
/**
 * Shortens `text` to at most `maxLength` characters, appending an ellipsis
 * and a truncation notice when the input exceeded the limit.
 *
 * @param {string} text - The text to be truncated.
 * @param {number} [maxLength=MAX_CHAR] - Length cap applied before the suffix is appended.
 * @returns {string} The original text when it fits, otherwise the truncated form.
 */
function truncateText(text, maxLength = MAX_CHAR) {
  const fitsWithinLimit = text.length <= maxLength;
  if (fitsWithinLimit) {
    return text;
  }
  const head = text.slice(0, maxLength);
  return `${head}... [text truncated for brevity]`;
}
/**
* Truncates a given text to a specified maximum length by showing the first half and the last half of the text,
* separated by ellipsis. This method ensures the output does not exceed the maximum length, including the addition
* of ellipsis and notification if the original text exceeds the maximum length.
*
* @param {string} text - The text to be truncated.
* @param {number} [maxLength=MAX_CHAR] - The maximum length of the output text after truncation. Defaults to MAX_CHAR.
* @returns {string} The truncated text showing the first half and the last half, or the original text if it does not exceed maxLength.
*/
/**
 * Truncates a given text to a specified maximum length by keeping the first half
 * and the last half of the text, separated by ellipsis, followed by a truncation
 * notice. For typical `maxLength` values the output stays within `maxLength`,
 * including the ellipsis and notice.
 *
 * @param {string} text - The text to be truncated.
 * @param {number} [maxLength=MAX_CHAR] - The maximum length of the output text after truncation. Defaults to MAX_CHAR.
 * @returns {string} The truncated text showing the first half and the last half, or the original text if it does not exceed maxLength.
 */
function smartTruncateText(text, maxLength = MAX_CHAR) {
  const ellipsis = '...';
  const notification = ' [text truncated for brevity]';
  // Clamp to 0 so a maxLength smaller than the ellipsis+notice overhead cannot
  // produce a negative half-length: negative slice indices would otherwise
  // mangle the output (trimming from the end instead of truncating).
  const halfMaxLength = Math.max(
    0,
    Math.floor((maxLength - ellipsis.length - notification.length) / 2),
  );
  if (text.length > maxLength) {
    const startLastHalf = text.length - halfMaxLength;
    return `${text.slice(0, halfMaxLength)}${ellipsis}${text.slice(startLastHalf)}${notification}`;
  }
  return text;
}
/**
* @param {TMessage[]} _messages
* @param {number} maxContextTokens
* @param {function({role: string, content: TMessageContent[]}): number} getTokenCountForMessage
*
* @returns {{
* dbMessages: TMessage[],
* editedIndices: number[]
* }}
*/
/**
 * Walks the conversation from newest to oldest, keeping messages intact until a
 * token budget (a fixed percentage of `maxContextTokens`) is consumed; once the
 * budget is exceeded, tool-call outputs on the remaining (older) messages are
 * replaced with a short placeholder and each edited message's token count is
 * re-estimated.
 *
 * @param {TMessage[]} _messages - Messages in conversation order (oldest first); not mutated.
 * @param {number} maxContextTokens - Full context window size in tokens.
 * @param {function({role: string, content: TMessageContent[]}): number} getTokenCountForMessage
 *
 * @returns {{
 *  dbMessages: TMessage[],
 *  editedIndices: number[]
 * }} The messages in original order (edited where applicable) and the indices that changed.
 */
function truncateToolCallOutputs(_messages, maxContextTokens, getTokenCountForMessage) {
  const THRESHOLD_PERCENTAGE = 0.5;
  const targetTokenLimit = maxContextTokens * THRESHOLD_PERCENTAGE;
  // Base overhead of 3 tokens, mirroring chat payload priming.
  let runningTokens = 3;
  const reversedResult = [];
  const editedIndices = new Set();

  for (let index = _messages.length - 1; index >= 0; index--) {
    const message = _messages[index];
    runningTokens += message.tokenCount;

    const underBudget = runningTokens < targetTokenLimit;
    const hasArrayContent = Boolean(message.content) && Array.isArray(message.content);
    if (underBudget || !hasArrayContent) {
      reversedResult.push(message);
      continue;
    }

    // Collect tool_call part indices in descending order.
    const toolCallIndices = [];
    for (let i = message.content.length - 1; i >= 0; i--) {
      if (message.content[i].type === 'tool_call') {
        toolCallIndices.push(i);
      }
    }
    if (toolCallIndices.length === 0) {
      reversedResult.push(message);
      continue;
    }

    // Over threshold: omit every tool output on this message.
    const newContent = message.content.slice();
    for (const i of toolCallIndices) {
      const toolCall = newContent[i].tool_call;
      if (!toolCall || !toolCall.output) {
        continue;
      }
      editedIndices.add(index);
      newContent[i] = {
        ...newContent[i],
        tool_call: {
          ...toolCall,
          output: '[OUTPUT_OMITTED_FOR_BREVITY]',
        },
      };
    }

    reversedResult.push({
      ...message,
      content: newContent,
      tokenCount: getTokenCountForMessage({ role: 'assistant', content: newContent }),
    });
  }

  return { dbMessages: reversedResult.reverse(), editedIndices: Array.from(editedIndices) };
}
module.exports = { truncateText, smartTruncateText, truncateToolCallOutputs };

View file

@ -1,40 +0,0 @@
const MAX_CHAR = 255;
/**
* Truncates a given text to a specified maximum length, appending ellipsis and a notification
* if the original text exceeds the maximum length.
*
* @param {string} text - The text to be truncated.
* @param {number} [maxLength=MAX_CHAR] - The maximum length of the text after truncation. Defaults to MAX_CHAR.
* @returns {string} The truncated text if the original text length exceeds maxLength, otherwise returns the original text.
*/
function truncateText(text, maxLength = MAX_CHAR) {
if (text.length > maxLength) {
return `${text.slice(0, maxLength)}... [text truncated for brevity]`;
}
return text;
}
/**
* Truncates a given text to a specified maximum length by showing the first half and the last half of the text,
* separated by ellipsis. This method ensures the output does not exceed the maximum length, including the addition
* of ellipsis and notification if the original text exceeds the maximum length.
*
* @param {string} text - The text to be truncated.
* @param {number} [maxLength=MAX_CHAR] - The maximum length of the output text after truncation. Defaults to MAX_CHAR.
* @returns {string} The truncated text showing the first half and the last half, or the original text if it does not exceed maxLength.
*/
function smartTruncateText(text, maxLength = MAX_CHAR) {
const ellipsis = '...';
const notification = ' [text truncated for brevity]';
const halfMaxLength = Math.floor((maxLength - ellipsis.length - notification.length) / 2);
if (text.length > maxLength) {
const startLastHalf = text.length - halfMaxLength;
return `${text.slice(0, halfMaxLength)}${ellipsis}${text.slice(startLastHalf)}${notification}`;
}
return text;
}
module.exports = { truncateText, smartTruncateText };

View file

@ -615,9 +615,9 @@ describe('BaseClient', () => {
test('getTokenCount for response is called with the correct arguments', async () => { test('getTokenCount for response is called with the correct arguments', async () => {
const tokenCountMap = {}; // Mock tokenCountMap const tokenCountMap = {}; // Mock tokenCountMap
TestClient.buildMessages.mockReturnValue({ prompt: [], tokenCountMap }); TestClient.buildMessages.mockReturnValue({ prompt: [], tokenCountMap });
TestClient.getTokenCount = jest.fn(); TestClient.getTokenCountForResponse = jest.fn();
const response = await TestClient.sendMessage('Hello, world!', {}); const response = await TestClient.sendMessage('Hello, world!', {});
expect(TestClient.getTokenCount).toHaveBeenCalledWith(response.text); expect(TestClient.getTokenCountForResponse).toHaveBeenCalledWith(response);
}); });
test('returns an object with the correct shape', async () => { test('returns an object with the correct shape', async () => {

View file

@ -23,6 +23,8 @@ async function handleOpenAIErrors(err, errorCallback, context = 'stream') {
logger.warn(`[OpenAIClient.chatCompletion][${context}] Unhandled error type`); logger.warn(`[OpenAIClient.chatCompletion][${context}] Unhandled error type`);
} }
logger.error(err);
if (errorCallback) { if (errorCallback) {
errorCallback(err); errorCallback(err);
} }

View file

@ -60,6 +60,9 @@ const noSystemModelRegex = [/\bo1\b/gi];
class AgentClient extends BaseClient { class AgentClient extends BaseClient {
constructor(options = {}) { constructor(options = {}) {
super(null, options); super(null, options);
/** The current client class
* @type {string} */
this.clientName = EModelEndpoint.agents;
/** @type {'discard' | 'summarize'} */ /** @type {'discard' | 'summarize'} */
this.contextStrategy = 'discard'; this.contextStrategy = 'discard';
@ -91,6 +94,14 @@ class AgentClient extends BaseClient {
this.options = Object.assign({ endpoint: options.endpoint }, clientOptions); this.options = Object.assign({ endpoint: options.endpoint }, clientOptions);
/** @type {string} */ /** @type {string} */
this.model = this.options.agent.model_parameters.model; this.model = this.options.agent.model_parameters.model;
/** The key for the usage object's input tokens
* @type {string} */
this.inputTokensKey = 'input_tokens';
/** The key for the usage object's output tokens
* @type {string} */
this.outputTokensKey = 'output_tokens';
/** @type {UsageMetadata} */
this.usage;
} }
/** /**
@ -329,16 +340,18 @@ class AgentClient extends BaseClient {
this.options.agent.instructions = systemContent; this.options.agent.instructions = systemContent;
} }
/** @type {Record<string, number> | undefined} */
let tokenCountMap;
if (this.contextStrategy) { if (this.contextStrategy) {
({ payload, promptTokens, messages } = await this.handleContextStrategy({ ({ payload, promptTokens, tokenCountMap, messages } = await this.handleContextStrategy({
orderedMessages, orderedMessages,
formattedMessages, formattedMessages,
/* prefer usage_metadata from final message */
buildTokenMap: false,
})); }));
} }
const result = { const result = {
tokenCountMap,
prompt: payload, prompt: payload,
promptTokens, promptTokens,
messages, messages,
@ -368,8 +381,26 @@ class AgentClient extends BaseClient {
* @param {UsageMetadata[]} [params.collectedUsage=this.collectedUsage] * @param {UsageMetadata[]} [params.collectedUsage=this.collectedUsage]
*/ */
async recordCollectedUsage({ model, context = 'message', collectedUsage = this.collectedUsage }) { async recordCollectedUsage({ model, context = 'message', collectedUsage = this.collectedUsage }) {
for (const usage of collectedUsage) { if (!collectedUsage || !collectedUsage.length) {
await spendTokens( return;
}
const input_tokens = collectedUsage[0]?.input_tokens || 0;
let output_tokens = 0;
let previousTokens = input_tokens; // Start with original input
for (let i = 0; i < collectedUsage.length; i++) {
const usage = collectedUsage[i];
if (i > 0) {
// Count new tokens generated (input_tokens minus previous accumulated tokens)
output_tokens += (Number(usage.input_tokens) || 0) - previousTokens;
}
// Add this message's output tokens
output_tokens += Number(usage.output_tokens) || 0;
// Update previousTokens to include this message's output
previousTokens += Number(usage.output_tokens) || 0;
spendTokens(
{ {
context, context,
conversationId: this.conversationId, conversationId: this.conversationId,
@ -378,8 +409,66 @@ class AgentClient extends BaseClient {
model: usage.model ?? model ?? this.model ?? this.options.agent.model_parameters.model, model: usage.model ?? model ?? this.model ?? this.options.agent.model_parameters.model,
}, },
{ promptTokens: usage.input_tokens, completionTokens: usage.output_tokens }, { promptTokens: usage.input_tokens, completionTokens: usage.output_tokens },
).catch((err) => {
logger.error(
'[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending tokens',
err,
); );
});
} }
this.usage = {
input_tokens,
output_tokens,
};
}
/**
* Get stream usage as returned by this client's API response.
* @returns {UsageMetadata} The stream usage object.
*/
getStreamUsage() {
return this.usage;
}
/**
* @param {TMessage} responseMessage
* @returns {number}
*/
getTokenCountForResponse({ content }) {
return this.getTokenCountForMessage({
role: 'assistant',
content,
});
}
/**
* Calculates the correct token count for the current user message based on the token count map and API usage.
* Edge case: If the calculation results in a negative value, it returns the original estimate.
* If revisiting a conversation with a chat history entirely composed of token estimates,
* the cumulative token count going forward should become more accurate as the conversation progresses.
* @param {Object} params - The parameters for the calculation.
* @param {Record<string, number>} params.tokenCountMap - A map of message IDs to their token counts.
* @param {string} params.currentMessageId - The ID of the current message to calculate.
* @param {OpenAIUsageMetadata} params.usage - The usage object returned by the API.
* @returns {number} The correct token count for the current user message.
*/
calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) {
const originalEstimate = tokenCountMap[currentMessageId] || 0;
if (!usage || typeof usage[this.inputTokensKey] !== 'number') {
return originalEstimate;
}
tokenCountMap[currentMessageId] = 0;
const totalTokensFromMap = Object.values(tokenCountMap).reduce((sum, count) => {
const numCount = Number(count);
return sum + (isNaN(numCount) ? 0 : numCount);
}, 0);
const totalInputTokens = usage[this.inputTokensKey] ?? 0;
const currentMessageTokens = totalInputTokens - totalTokensFromMap;
return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate;
} }
async chatCompletion({ payload, abortController = null }) { async chatCompletion({ payload, abortController = null }) {
@ -676,12 +765,14 @@ class AgentClient extends BaseClient {
); );
}); });
this.recordCollectedUsage({ context: 'message' }).catch((err) => { try {
await this.recordCollectedUsage({ context: 'message' });
} catch (err) {
logger.error( logger.error(
'[api/server/controllers/agents/client.js #chatCompletion] Error recording collected usage', '[api/server/controllers/agents/client.js #chatCompletion] Error recording collected usage',
err, err,
); );
}); }
} catch (err) { } catch (err) {
if (!abortController.signal.aborted) { if (!abortController.signal.aborted) {
logger.error( logger.error(
@ -767,8 +858,11 @@ class AgentClient extends BaseClient {
} }
} }
/** Silent method, as `recordCollectedUsage` is used instead */
async recordTokenUsage() {}
getEncoding() { getEncoding() {
return this.model?.includes('gpt-4o') ? 'o200k_base' : 'cl100k_base'; return 'o200k_base';
} }
/** /**

View file

@ -11,6 +11,7 @@ const { isActionDomainAllowed } = require('~/server/services/domains');
const { encryptV2, decryptV2 } = require('~/server/utils/crypto'); const { encryptV2, decryptV2 } = require('~/server/utils/crypto');
const { getActions, deleteActions } = require('~/models/Action'); const { getActions, deleteActions } = require('~/models/Action');
const { deleteAssistant } = require('~/models/Assistant'); const { deleteAssistant } = require('~/models/Assistant');
const { logAxiosError } = require('~/utils');
const { getLogStores } = require('~/cache'); const { getLogStores } = require('~/cache');
const { logger } = require('~/config'); const { logger } = require('~/config');
@ -146,15 +147,8 @@ async function createActionTool({ action, requestBuilder, zodSchema, name, descr
} }
return res.data; return res.data;
} catch (error) { } catch (error) {
logger.error(`API call to ${action.metadata.domain} failed`, error); const logMessage = `API call to ${action.metadata.domain} failed`;
if (error.response) { logAxiosError({ message: logMessage, error });
const { status, data } = error.response;
return `API call to ${
action.metadata.domain
} failed with status ${status}: ${JSON.stringify(data)}`;
}
return `API call to ${action.metadata.domain} failed.`;
} }
}; };

View file

@ -4,27 +4,47 @@ const { AuthKeys } = require('librechat-data-provider');
// Example internal constant from your code // Example internal constant from your code
const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/; const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/;
function getSafetySettings() { /**
*
* @param {boolean} isGemini2
* @returns {Array<{category: string, threshold: string}>}
*/
function getSafetySettings(isGemini2) {
const mapThreshold = (value) => {
if (isGemini2 && value === 'BLOCK_NONE') {
return 'OFF';
}
return value;
};
return [ return [
{ {
category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
threshold: process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', threshold: mapThreshold(
process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
),
}, },
{ {
category: 'HARM_CATEGORY_HATE_SPEECH', category: 'HARM_CATEGORY_HATE_SPEECH',
threshold: process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', threshold: mapThreshold(
process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
),
}, },
{ {
category: 'HARM_CATEGORY_HARASSMENT', category: 'HARM_CATEGORY_HARASSMENT',
threshold: process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', threshold: mapThreshold(
process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
),
}, },
{ {
category: 'HARM_CATEGORY_DANGEROUS_CONTENT', category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
threshold: process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', threshold: mapThreshold(
process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
),
}, },
{ {
category: 'HARM_CATEGORY_CIVIC_INTEGRITY', category: 'HARM_CATEGORY_CIVIC_INTEGRITY',
threshold: process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'BLOCK_NONE', threshold: mapThreshold(process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'BLOCK_NONE'),
}, },
]; ];
} }
@ -64,14 +84,16 @@ function getLLMConfig(credentials, options = {}) {
/** @type {GoogleClientOptions | VertexAIClientOptions} */ /** @type {GoogleClientOptions | VertexAIClientOptions} */
let llmConfig = { let llmConfig = {
...(options.modelOptions || {}), ...(options.modelOptions || {}),
safetySettings: getSafetySettings(),
maxRetries: 2, maxRetries: 2,
}; };
const isGemini2 = llmConfig.model.includes('gemini-2.0');
const isGenerativeModel = llmConfig.model.includes('gemini'); const isGenerativeModel = llmConfig.model.includes('gemini');
const isChatModel = !isGenerativeModel && llmConfig.model.includes('chat'); const isChatModel = !isGenerativeModel && llmConfig.model.includes('chat');
const isTextModel = !isGenerativeModel && !isChatModel && /code|text/.test(llmConfig.model); const isTextModel = !isGenerativeModel && !isChatModel && /code|text/.test(llmConfig.model);
llmConfig.safetySettings = getSafetySettings(isGemini2);
let provider; let provider;
if (project_id && isTextModel) { if (project_id && isTextModel) {

View file

@ -11,15 +11,15 @@ import { cn } from '~/utils';
interface AttachFileProps { interface AttachFileProps {
isRTL: boolean; isRTL: boolean;
disabled?: boolean | null; disabled?: boolean | null;
handleFileChange: (event: React.ChangeEvent<HTMLInputElement>) => void; handleFileChange: (event: React.ChangeEvent<HTMLInputElement>, toolResource?: string) => void;
setToolResource?: React.Dispatch<React.SetStateAction<string | undefined>>;
} }
const AttachFile = ({ isRTL, disabled, setToolResource, handleFileChange }: AttachFileProps) => { const AttachFile = ({ isRTL, disabled, handleFileChange }: AttachFileProps) => {
const localize = useLocalize(); const localize = useLocalize();
const isUploadDisabled = disabled ?? false; const isUploadDisabled = disabled ?? false;
const inputRef = useRef<HTMLInputElement>(null); const inputRef = useRef<HTMLInputElement>(null);
const [isPopoverActive, setIsPopoverActive] = useState(false); const [isPopoverActive, setIsPopoverActive] = useState(false);
const [toolResource, setToolResource] = useState<EToolResources | undefined>();
const { data: endpointsConfig } = useGetEndpointsQuery(); const { data: endpointsConfig } = useGetEndpointsQuery();
const capabilities = useMemo( const capabilities = useMemo(
@ -42,7 +42,7 @@ const AttachFile = ({ isRTL, disabled, setToolResource, handleFileChange }: Atta
{ {
label: localize('com_ui_upload_image_input'), label: localize('com_ui_upload_image_input'),
onClick: () => { onClick: () => {
setToolResource?.(undefined); setToolResource(undefined);
handleUploadClick(true); handleUploadClick(true);
}, },
icon: <ImageUpIcon className="icon-md" />, icon: <ImageUpIcon className="icon-md" />,
@ -53,7 +53,7 @@ const AttachFile = ({ isRTL, disabled, setToolResource, handleFileChange }: Atta
items.push({ items.push({
label: localize('com_ui_upload_file_search'), label: localize('com_ui_upload_file_search'),
onClick: () => { onClick: () => {
setToolResource?.(EToolResources.file_search); setToolResource(EToolResources.file_search);
handleUploadClick(); handleUploadClick();
}, },
icon: <FileSearch className="icon-md" />, icon: <FileSearch className="icon-md" />,
@ -64,7 +64,7 @@ const AttachFile = ({ isRTL, disabled, setToolResource, handleFileChange }: Atta
items.push({ items.push({
label: localize('com_ui_upload_code_files'), label: localize('com_ui_upload_code_files'),
onClick: () => { onClick: () => {
setToolResource?.(EToolResources.execute_code); setToolResource(EToolResources.execute_code);
handleUploadClick(); handleUploadClick();
}, },
icon: <TerminalSquareIcon className="icon-md" />, icon: <TerminalSquareIcon className="icon-md" />,
@ -98,7 +98,12 @@ const AttachFile = ({ isRTL, disabled, setToolResource, handleFileChange }: Atta
); );
return ( return (
<FileUpload ref={inputRef} handleFileChange={handleFileChange}> <FileUpload
ref={inputRef}
handleFileChange={(e) => {
handleFileChange(e, toolResource);
}}
>
<div className="relative select-none"> <div className="relative select-none">
<DropdownPopup <DropdownPopup
menuId="attach-file-menu" menuId="attach-file-menu"

View file

@ -27,7 +27,7 @@ function FileFormWrapper({
const { endpoint: _endpoint, endpointType } = conversation ?? { endpoint: null }; const { endpoint: _endpoint, endpointType } = conversation ?? { endpoint: null };
const isAgents = useMemo(() => isAgentsEndpoint(_endpoint), [_endpoint]); const isAgents = useMemo(() => isAgentsEndpoint(_endpoint), [_endpoint]);
const { handleFileChange, abortUpload, setToolResource } = useFileHandling(); const { handleFileChange, abortUpload } = useFileHandling();
const { data: fileConfig = defaultFileConfig } = useGetFileConfig({ const { data: fileConfig = defaultFileConfig } = useGetFileConfig({
select: (data) => mergeFileConfig(data), select: (data) => mergeFileConfig(data),
@ -48,7 +48,6 @@ function FileFormWrapper({
<AttachFileMenu <AttachFileMenu
isRTL={isRTL} isRTL={isRTL}
disabled={disableInputs} disabled={disableInputs}
setToolResource={setToolResource}
handleFileChange={handleFileChange} handleFileChange={handleFileChange}
/> />
); );

View file

@ -39,7 +39,6 @@ const useFileHandling = (params?: UseFileHandling) => {
const [errors, setErrors] = useState<string[]>([]); const [errors, setErrors] = useState<string[]>([]);
const abortControllerRef = useRef<AbortController | null>(null); const abortControllerRef = useRef<AbortController | null>(null);
const { startUploadTimer, clearUploadTimer } = useDelayedUploadToast(); const { startUploadTimer, clearUploadTimer } = useDelayedUploadToast();
const [toolResource, setToolResource] = useState<string | undefined>();
const { files, setFiles, setFilesLoading, conversation } = useChatContext(); const { files, setFiles, setFilesLoading, conversation } = useChatContext();
const setError = (error: string) => setErrors((prevErrors) => [...prevErrors, error]); const setError = (error: string) => setErrors((prevErrors) => [...prevErrors, error]);
const { addFile, replaceFile, updateFileById, deleteFileById } = useUpdateFiles( const { addFile, replaceFile, updateFileById, deleteFileById } = useUpdateFiles(
@ -149,9 +148,6 @@ const useFileHandling = (params?: UseFileHandling) => {
: error?.response?.data?.message ?? 'com_error_files_upload'; : error?.response?.data?.message ?? 'com_error_files_upload';
setError(errorMessage); setError(errorMessage);
}, },
onMutate: () => {
setToolResource(undefined);
},
}, },
abortControllerRef.current?.signal, abortControllerRef.current?.signal,
); );
@ -187,7 +183,7 @@ const useFileHandling = (params?: UseFileHandling) => {
if (!agent_id) { if (!agent_id) {
formData.append('message_file', 'true'); formData.append('message_file', 'true');
} }
const tool_resource = extendedFile.tool_resource ?? toolResource; const tool_resource = extendedFile.tool_resource;
if (tool_resource != null) { if (tool_resource != null) {
formData.append('tool_resource', tool_resource); formData.append('tool_resource', tool_resource);
} }
@ -365,7 +361,7 @@ const useFileHandling = (params?: UseFileHandling) => {
const isImage = originalFile.type.split('/')[0] === 'image'; const isImage = originalFile.type.split('/')[0] === 'image';
const tool_resource = const tool_resource =
extendedFile.tool_resource ?? params?.additionalMetadata?.tool_resource ?? toolResource; extendedFile.tool_resource ?? params?.additionalMetadata?.tool_resource;
if (isAgentsEndpoint(endpoint) && !isImage && tool_resource == null) { if (isAgentsEndpoint(endpoint) && !isImage && tool_resource == null) {
/** Note: this needs to be removed when we can support files to providers */ /** Note: this needs to be removed when we can support files to providers */
setError('com_error_files_unsupported_capability'); setError('com_error_files_unsupported_capability');
@ -388,11 +384,11 @@ const useFileHandling = (params?: UseFileHandling) => {
} }
}; };
const handleFileChange = (event: React.ChangeEvent<HTMLInputElement>) => { const handleFileChange = (event: React.ChangeEvent<HTMLInputElement>, _toolResource?: string) => {
event.stopPropagation(); event.stopPropagation();
if (event.target.files) { if (event.target.files) {
setFilesLoading(true); setFilesLoading(true);
handleFiles(event.target.files); handleFiles(event.target.files, _toolResource);
// reset the input // reset the input
event.target.value = ''; event.target.value = '';
} }
@ -408,7 +404,6 @@ const useFileHandling = (params?: UseFileHandling) => {
return { return {
handleFileChange, handleFileChange,
setToolResource,
handleFiles, handleFiles,
abortUpload, abortUpload,
setFiles, setFiles,

2
package-lock.json generated
View file

@ -36322,7 +36322,7 @@
}, },
"packages/data-provider": { "packages/data-provider": {
"name": "librechat-data-provider", "name": "librechat-data-provider",
"version": "0.7.692", "version": "0.7.693",
"license": "ISC", "license": "ISC",
"dependencies": { "dependencies": {
"axios": "^1.7.7", "axios": "^1.7.7",

View file

@ -1,6 +1,6 @@
{ {
"name": "librechat-data-provider", "name": "librechat-data-provider",
"version": "0.7.692", "version": "0.7.693",
"description": "data services for librechat apps", "description": "data services for librechat apps",
"main": "dist/index.js", "main": "dist/index.js",
"module": "dist/index.es.js", "module": "dist/index.es.js",

View file

@ -14,13 +14,7 @@ import {
} from '../src/file-config'; } from '../src/file-config';
describe('MIME Type Regex Patterns', () => { describe('MIME Type Regex Patterns', () => {
const unsupportedMimeTypes = [ const unsupportedMimeTypes = ['text/x-unknown', 'application/unknown', 'image/bmp', 'audio/mp3'];
'text/x-unknown',
'application/unknown',
'image/bmp',
'image/svg',
'audio/mp3',
];
// Testing general supported MIME types // Testing general supported MIME types
fullMimeTypesList.forEach((mimeType) => { fullMimeTypesList.forEach((mimeType) => {

View file

@ -54,6 +54,8 @@ export const fullMimeTypesList = [
'application/typescript', 'application/typescript',
'application/xml', 'application/xml',
'application/zip', 'application/zip',
'image/svg',
'image/svg+xml',
...excelFileTypes, ...excelFileTypes,
]; ];
@ -122,6 +124,8 @@ export const supportedMimeTypes = [
excelMimeTypes, excelMimeTypes,
applicationMimeTypes, applicationMimeTypes,
imageMimeTypes, imageMimeTypes,
/** Supported by LC Code Interpreter PAI */
/^image\/(svg|svg\+xml)$/,
]; ];
export const codeInterpreterMimeTypes = [ export const codeInterpreterMimeTypes = [

View file

@ -55,9 +55,7 @@ export class MCPConnection extends EventEmitter {
version: '1.0.0', version: '1.0.0',
}, },
{ {
capabilities: { capabilities: {},
tools: {},
},
}, },
); );