From b35a8b78e2884854376d553dd53bc6616ae80dea Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Fri, 17 Jan 2025 12:55:48 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20refactor:=20Improve=20Agent=20Co?= =?UTF-8?q?ntext=20&=20Minor=20Fixes=20(#5349)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: Improve Context for Agents * 🔧 fix: Safeguard against undefined properties in OpenAIClient response handling * refactor: log error before re-throwing for original stack trace * refactor: remove toolResource state from useFileHandling, allow svg files * refactor: prevent verbose logs from axios errors when using actions * refactor: add silent method recordTokenUsage in AgentClient * refactor: streamline token count assignment in BaseClient * refactor: enhance safety settings handling for Gemini 2.0 model * fix: capabilities structure in MCPConnection * refactor: simplify civic integrity threshold handling in GoogleClient and llm * refactor: update token count retrieval method in BaseClient tests * ci: fix test for svg --- api/app/clients/BaseClient.js | 26 +++- api/app/clients/GoogleClient.js | 28 +++-- api/app/clients/OpenAIClient.js | 2 +- api/app/clients/prompts/index.js | 4 +- api/app/clients/prompts/truncate.js | 115 ++++++++++++++++++ api/app/clients/prompts/truncateText.js | 40 ------ api/app/clients/specs/BaseClient.test.js | 4 +- .../clients/tools/util/handleOpenAIErrors.js | 2 + api/server/controllers/agents/client.js | 112 +++++++++++++++-- api/server/services/ActionService.js | 12 +- api/server/services/Endpoints/google/llm.js | 36 ++++-- .../Chat/Input/Files/AttachFileMenu.tsx | 19 +-- .../Chat/Input/Files/FileFormWrapper.tsx | 3 +- client/src/hooks/Files/useFileHandling.ts | 13 +- package-lock.json | 2 +- packages/data-provider/package.json | 2 +- .../data-provider/specs/filetypes.spec.ts | 8 +- packages/data-provider/src/file-config.ts | 4 + packages/mcp/src/connection.ts | 4 +- 19 files changed, 324 
insertions(+), 112 deletions(-) create mode 100644 api/app/clients/prompts/truncate.js delete mode 100644 api/app/clients/prompts/truncateText.js diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 5abdad686..5b73ae6b8 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -4,6 +4,7 @@ const { supportsBalanceCheck, isAgentsEndpoint, isParamEndpoint, + EModelEndpoint, ErrorTypes, Constants, CacheKeys, @@ -11,6 +12,7 @@ const { } = require('librechat-data-provider'); const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models'); const { addSpaceIfNeeded, isEnabled } = require('~/server/utils'); +const { truncateToolCallOutputs } = require('./prompts'); const checkBalance = require('~/models/checkBalance'); const { getFiles } = require('~/models/File'); const { getLogStores } = require('~/cache'); @@ -95,7 +97,7 @@ class BaseClient { * @returns {number} */ getTokenCountForResponse(responseMessage) { - logger.debug('`[BaseClient] recordTokenUsage` not implemented.', responseMessage); + logger.debug('[BaseClient] `recordTokenUsage` not implemented.', responseMessage); } /** @@ -106,7 +108,7 @@ class BaseClient { * @returns {Promise} */ async recordTokenUsage({ promptTokens, completionTokens }) { - logger.debug('`[BaseClient] recordTokenUsage` not implemented.', { + logger.debug('[BaseClient] `recordTokenUsage` not implemented.', { promptTokens, completionTokens, }); @@ -287,6 +289,9 @@ class BaseClient { } async handleTokenCountMap(tokenCountMap) { + if (this.clientName === EModelEndpoint.agents) { + return; + } if (this.currentMessages.length === 0) { return; } @@ -394,6 +399,21 @@ class BaseClient { _instructions && logger.debug('[BaseClient] instructions tokenCount: ' + tokenCount); let payload = this.addInstructions(formattedMessages, _instructions); let orderedWithInstructions = this.addInstructions(orderedMessages, instructions); + if (this.clientName === EModelEndpoint.agents) { + const 
{ dbMessages, editedIndices } = truncateToolCallOutputs( + orderedWithInstructions, + this.maxContextTokens, + this.getTokenCountForMessage.bind(this), + ); + + if (editedIndices.length > 0) { + logger.debug('[BaseClient] Truncated tool call outputs:', editedIndices); + for (const index of editedIndices) { + payload[index].content = dbMessages[index].content; + } + orderedWithInstructions = dbMessages; + } + } let { context, remainingContextTokens, messagesToRefine, summaryIndex } = await this.getMessagesWithinTokenLimit(orderedWithInstructions); @@ -625,7 +645,7 @@ class BaseClient { await this.updateUserMessageTokenCount({ usage, tokenCountMap, userMessage, opts }); } else { responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage); - completionTokens = this.getTokenCount(completion); + completionTokens = responseMessage.tokenCount; } await this.recordTokenUsage({ promptTokens, completionTokens, usage }); diff --git a/api/app/clients/GoogleClient.js b/api/app/clients/GoogleClient.js index 4b966646c..5601e0e3c 100644 --- a/api/app/clients/GoogleClient.js +++ b/api/app/clients/GoogleClient.js @@ -886,32 +886,42 @@ class GoogleClient extends BaseClient { } getSafetySettings() { + const isGemini2 = this.modelOptions.model.includes('gemini-2.0'); + const mapThreshold = (value) => { + if (isGemini2 && value === 'BLOCK_NONE') { + return 'OFF'; + } + return value; + }; + return [ { category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', - threshold: + threshold: mapThreshold( process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + ), }, { category: 'HARM_CATEGORY_HATE_SPEECH', - threshold: process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + threshold: mapThreshold( + process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + ), }, { category: 'HARM_CATEGORY_HARASSMENT', - threshold: process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + threshold: mapThreshold( + 
process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + ), }, { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', - threshold: + threshold: mapThreshold( process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + ), }, { category: 'HARM_CATEGORY_CIVIC_INTEGRITY', - /** - * Note: this was added since `gemini-2.0-flash-thinking-exp-1219` does not - * accept 'HARM_BLOCK_THRESHOLD_UNSPECIFIED' for 'HARM_CATEGORY_CIVIC_INTEGRITY' - * */ - threshold: process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'BLOCK_NONE', + threshold: mapThreshold(process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'BLOCK_NONE'), }, ]; } diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 15fd20aef..17eeeb239 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -1293,7 +1293,7 @@ ${convo} }); for await (const chunk of stream) { - const token = chunk.choices[0]?.delta?.content || ''; + const token = chunk?.choices?.[0]?.delta?.content || ''; intermediateReply.push(token); onProgress(token); if (abortController.signal.aborted) { diff --git a/api/app/clients/prompts/index.js b/api/app/clients/prompts/index.js index 364ad34b5..2549ccda5 100644 --- a/api/app/clients/prompts/index.js +++ b/api/app/clients/prompts/index.js @@ -4,7 +4,7 @@ const summaryPrompts = require('./summaryPrompts'); const handleInputs = require('./handleInputs'); const instructions = require('./instructions'); const titlePrompts = require('./titlePrompts'); -const truncateText = require('./truncateText'); +const truncate = require('./truncate'); const createVisionPrompt = require('./createVisionPrompt'); const createContextHandlers = require('./createContextHandlers'); @@ -15,7 +15,7 @@ module.exports = { ...handleInputs, ...instructions, ...titlePrompts, - ...truncateText, + ...truncate, createVisionPrompt, createContextHandlers, }; diff --git a/api/app/clients/prompts/truncate.js b/api/app/clients/prompts/truncate.js new file mode 
100644 index 000000000..564b39efe --- /dev/null +++ b/api/app/clients/prompts/truncate.js @@ -0,0 +1,115 @@ +const MAX_CHAR = 255; + +/** + * Truncates a given text to a specified maximum length, appending ellipsis and a notification + * if the original text exceeds the maximum length. + * + * @param {string} text - The text to be truncated. + * @param {number} [maxLength=MAX_CHAR] - The maximum length of the text after truncation. Defaults to MAX_CHAR. + * @returns {string} The truncated text if the original text length exceeds maxLength, otherwise returns the original text. + */ +function truncateText(text, maxLength = MAX_CHAR) { + if (text.length > maxLength) { + return `${text.slice(0, maxLength)}... [text truncated for brevity]`; + } + return text; +} + +/** + * Truncates a given text to a specified maximum length by showing the first half and the last half of the text, + * separated by ellipsis. This method ensures the output does not exceed the maximum length, including the addition + * of ellipsis and notification if the original text exceeds the maximum length. + * + * @param {string} text - The text to be truncated. + * @param {number} [maxLength=MAX_CHAR] - The maximum length of the output text after truncation. Defaults to MAX_CHAR. + * @returns {string} The truncated text showing the first half and the last half, or the original text if it does not exceed maxLength. 
+ */ +function smartTruncateText(text, maxLength = MAX_CHAR) { + const ellipsis = '...'; + const notification = ' [text truncated for brevity]'; + const halfMaxLength = Math.floor((maxLength - ellipsis.length - notification.length) / 2); + + if (text.length > maxLength) { + const startLastHalf = text.length - halfMaxLength; + return `${text.slice(0, halfMaxLength)}${ellipsis}${text.slice(startLastHalf)}${notification}`; + } + + return text; +} + +/** + * @param {TMessage[]} _messages + * @param {number} maxContextTokens + * @param {function({role: string, content: TMessageContent[]}): number} getTokenCountForMessage + * + * @returns {{ + * dbMessages: TMessage[], + * editedIndices: number[] + * }} + */ +function truncateToolCallOutputs(_messages, maxContextTokens, getTokenCountForMessage) { + const THRESHOLD_PERCENTAGE = 0.5; + const targetTokenLimit = maxContextTokens * THRESHOLD_PERCENTAGE; + + let currentTokenCount = 3; + const messages = [..._messages]; + const processedMessages = []; + let currentIndex = messages.length; + const editedIndices = new Set(); + while (messages.length > 0) { + currentIndex--; + const message = messages.pop(); + currentTokenCount += message.tokenCount; + if (currentTokenCount < targetTokenLimit) { + processedMessages.push(message); + continue; + } + + if (!message.content || !Array.isArray(message.content)) { + processedMessages.push(message); + continue; + } + + const toolCallIndices = message.content + .map((item, index) => (item.type === 'tool_call' ? 
index : -1)) + .filter((index) => index !== -1) + .reverse(); + + if (toolCallIndices.length === 0) { + processedMessages.push(message); + continue; + } + + const newContent = [...message.content]; + + // Truncate all tool outputs since we're over threshold + for (const index of toolCallIndices) { + const toolCall = newContent[index].tool_call; + if (!toolCall || !toolCall.output) { + continue; + } + + editedIndices.add(currentIndex); + + newContent[index] = { + ...newContent[index], + tool_call: { + ...toolCall, + output: '[OUTPUT_OMITTED_FOR_BREVITY]', + }, + }; + } + + const truncatedMessage = { + ...message, + content: newContent, + tokenCount: getTokenCountForMessage({ role: 'assistant', content: newContent }), + }; + + processedMessages.push(truncatedMessage); + } + + return { dbMessages: processedMessages.reverse(), editedIndices: Array.from(editedIndices) }; +} + +module.exports = { truncateText, smartTruncateText, truncateToolCallOutputs }; diff --git a/api/app/clients/prompts/truncateText.js b/api/app/clients/prompts/truncateText.js deleted file mode 100644 index e744b40da..000000000 --- a/api/app/clients/prompts/truncateText.js +++ /dev/null @@ -1,40 +0,0 @@ -const MAX_CHAR = 255; - -/** - * Truncates a given text to a specified maximum length, appending ellipsis and a notification - * if the original text exceeds the maximum length. - * - * @param {string} text - The text to be truncated. - * @param {number} [maxLength=MAX_CHAR] - The maximum length of the text after truncation. Defaults to MAX_CHAR. - * @returns {string} The truncated text if the original text length exceeds maxLength, otherwise returns the original text. - */ -function truncateText(text, maxLength = MAX_CHAR) { - if (text.length > maxLength) { - return `${text.slice(0, maxLength)}... [text truncated for brevity]`; - } - return text; -} - -/** - * Truncates a given text to a specified maximum length by showing the first half and the last half of the text, - * separated by ellipsis. 
This method ensures the output does not exceed the maximum length, including the addition - * of ellipsis and notification if the original text exceeds the maximum length. - * - * @param {string} text - The text to be truncated. - * @param {number} [maxLength=MAX_CHAR] - The maximum length of the output text after truncation. Defaults to MAX_CHAR. - * @returns {string} The truncated text showing the first half and the last half, or the original text if it does not exceed maxLength. - */ -function smartTruncateText(text, maxLength = MAX_CHAR) { - const ellipsis = '...'; - const notification = ' [text truncated for brevity]'; - const halfMaxLength = Math.floor((maxLength - ellipsis.length - notification.length) / 2); - - if (text.length > maxLength) { - const startLastHalf = text.length - halfMaxLength; - return `${text.slice(0, halfMaxLength)}${ellipsis}${text.slice(startLastHalf)}${notification}`; - } - - return text; -} - -module.exports = { truncateText, smartTruncateText }; diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index 4db1c9822..a6925759a 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -615,9 +615,9 @@ describe('BaseClient', () => { test('getTokenCount for response is called with the correct arguments', async () => { const tokenCountMap = {}; // Mock tokenCountMap TestClient.buildMessages.mockReturnValue({ prompt: [], tokenCountMap }); - TestClient.getTokenCount = jest.fn(); + TestClient.getTokenCountForResponse = jest.fn(); const response = await TestClient.sendMessage('Hello, world!', {}); - expect(TestClient.getTokenCount).toHaveBeenCalledWith(response.text); + expect(TestClient.getTokenCountForResponse).toHaveBeenCalledWith(response); }); test('returns an object with the correct shape', async () => { diff --git a/api/app/clients/tools/util/handleOpenAIErrors.js b/api/app/clients/tools/util/handleOpenAIErrors.js index 53a4f37ac..490f3882a 100644 
--- a/api/app/clients/tools/util/handleOpenAIErrors.js +++ b/api/app/clients/tools/util/handleOpenAIErrors.js @@ -23,6 +23,8 @@ async function handleOpenAIErrors(err, errorCallback, context = 'stream') { logger.warn(`[OpenAIClient.chatCompletion][${context}] Unhandled error type`); } + logger.error(err); + if (errorCallback) { errorCallback(err); } diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index e3c6dd567..a8e9ad82f 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -60,6 +60,9 @@ const noSystemModelRegex = [/\bo1\b/gi]; class AgentClient extends BaseClient { constructor(options = {}) { super(null, options); + /** The current client class + * @type {string} */ + this.clientName = EModelEndpoint.agents; /** @type {'discard' | 'summarize'} */ this.contextStrategy = 'discard'; @@ -91,6 +94,14 @@ class AgentClient extends BaseClient { this.options = Object.assign({ endpoint: options.endpoint }, clientOptions); /** @type {string} */ this.model = this.options.agent.model_parameters.model; + /** The key for the usage object's input tokens + * @type {string} */ + this.inputTokensKey = 'input_tokens'; + /** The key for the usage object's output tokens + * @type {string} */ + this.outputTokensKey = 'output_tokens'; + /** @type {UsageMetadata} */ + this.usage; } /** @@ -329,16 +340,18 @@ class AgentClient extends BaseClient { this.options.agent.instructions = systemContent; } + /** @type {Record | undefined} */ + let tokenCountMap; + if (this.contextStrategy) { - ({ payload, promptTokens, messages } = await this.handleContextStrategy({ + ({ payload, promptTokens, tokenCountMap, messages } = await this.handleContextStrategy({ orderedMessages, formattedMessages, - /* prefer usage_metadata from final message */ - buildTokenMap: false, })); } const result = { + tokenCountMap, prompt: payload, promptTokens, messages, @@ -368,8 +381,26 @@ class AgentClient extends BaseClient { * 
@param {UsageMetadata[]} [params.collectedUsage=this.collectedUsage] */ async recordCollectedUsage({ model, context = 'message', collectedUsage = this.collectedUsage }) { - for (const usage of collectedUsage) { - await spendTokens( + if (!collectedUsage || !collectedUsage.length) { + return; + } + const input_tokens = collectedUsage[0]?.input_tokens || 0; + + let output_tokens = 0; + let previousTokens = input_tokens; // Start with original input + for (let i = 0; i < collectedUsage.length; i++) { + const usage = collectedUsage[i]; + if (i > 0) { + // Count new tokens generated (input_tokens minus previous accumulated tokens) + output_tokens += (Number(usage.input_tokens) || 0) - previousTokens; + } + + // Add this message's output tokens + output_tokens += Number(usage.output_tokens) || 0; + + // Update previousTokens to include this message's output + previousTokens += Number(usage.output_tokens) || 0; + spendTokens( { context, conversationId: this.conversationId, @@ -378,8 +409,66 @@ class AgentClient extends BaseClient { model: usage.model ?? model ?? this.model ?? this.options.agent.model_parameters.model, }, { promptTokens: usage.input_tokens, completionTokens: usage.output_tokens }, - ); + ).catch((err) => { + logger.error( + '[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending tokens', + err, + ); + }); } + + this.usage = { + input_tokens, + output_tokens, + }; + } + + /** + * Get stream usage as returned by this client's API response. + * @returns {UsageMetadata} The stream usage object. + */ + getStreamUsage() { + return this.usage; + } + + /** + * @param {TMessage} responseMessage + * @returns {number} + */ + getTokenCountForResponse({ content }) { + return this.getTokenCountForMessage({ + role: 'assistant', + content, + }); + } + + /** + * Calculates the correct token count for the current user message based on the token count map and API usage. 
+ * Edge case: If the calculation results in zero or a negative value, it returns the original estimate. + * If revisiting a conversation with a chat history entirely composed of token estimates, + * the cumulative token count going forward should become more accurate as the conversation progresses. + * @param {Object} params - The parameters for the calculation. + * @param {Record} params.tokenCountMap - A map of message IDs to their token counts. + * @param {string} params.currentMessageId - The ID of the current message to calculate. + * @param {OpenAIUsageMetadata} params.usage - The usage object returned by the API. + * @returns {number} The correct token count for the current user message. + */ + calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) { + const originalEstimate = tokenCountMap[currentMessageId] || 0; + + if (!usage || typeof usage[this.inputTokensKey] !== 'number') { + return originalEstimate; + } + + tokenCountMap[currentMessageId] = 0; + const totalTokensFromMap = Object.values(tokenCountMap).reduce((sum, count) => { + const numCount = Number(count); + return sum + (isNaN(numCount) ? 0 : numCount); + }, 0); + const totalInputTokens = usage[this.inputTokensKey] ?? 0; + + const currentMessageTokens = totalInputTokens - totalTokensFromMap; + return currentMessageTokens > 0 ? 
currentMessageTokens : originalEstimate; } async chatCompletion({ payload, abortController = null }) { @@ -676,12 +765,14 @@ class AgentClient extends BaseClient { ); }); - this.recordCollectedUsage({ context: 'message' }).catch((err) => { + try { + await this.recordCollectedUsage({ context: 'message' }); + } catch (err) { logger.error( '[api/server/controllers/agents/client.js #chatCompletion] Error recording collected usage', err, ); - }); + } } catch (err) { if (!abortController.signal.aborted) { logger.error( @@ -767,8 +858,11 @@ class AgentClient extends BaseClient { } } + /** Silent method, as `recordCollectedUsage` is used instead */ + async recordTokenUsage() {} + getEncoding() { - return this.model?.includes('gpt-4o') ? 'o200k_base' : 'cl100k_base'; + return 'o200k_base'; } /** diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js index 068e96948..712157bf2 100644 --- a/api/server/services/ActionService.js +++ b/api/server/services/ActionService.js @@ -11,6 +11,7 @@ const { isActionDomainAllowed } = require('~/server/services/domains'); const { encryptV2, decryptV2 } = require('~/server/utils/crypto'); const { getActions, deleteActions } = require('~/models/Action'); const { deleteAssistant } = require('~/models/Assistant'); +const { logAxiosError } = require('~/utils'); const { getLogStores } = require('~/cache'); const { logger } = require('~/config'); @@ -146,15 +147,8 @@ async function createActionTool({ action, requestBuilder, zodSchema, name, descr } return res.data; } catch (error) { - logger.error(`API call to ${action.metadata.domain} failed`, error); - if (error.response) { - const { status, data } = error.response; - return `API call to ${ - action.metadata.domain - } failed with status ${status}: ${JSON.stringify(data)}`; - } - - return `API call to ${action.metadata.domain} failed.`; + const logMessage = `API call to ${action.metadata.domain} failed`; + logAxiosError({ message: logMessage, error }); } }; diff 
--git a/api/server/services/Endpoints/google/llm.js b/api/server/services/Endpoints/google/llm.js index 959e9a494..92eca9a6a 100644 --- a/api/server/services/Endpoints/google/llm.js +++ b/api/server/services/Endpoints/google/llm.js @@ -4,27 +4,47 @@ const { AuthKeys } = require('librechat-data-provider'); // Example internal constant from your code const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/; -function getSafetySettings() { +/** + * + * @param {boolean} isGemini2 + * @returns {Array<{category: string, threshold: string}>} + */ +function getSafetySettings(isGemini2) { + const mapThreshold = (value) => { + if (isGemini2 && value === 'BLOCK_NONE') { + return 'OFF'; + } + return value; + }; + return [ { category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', - threshold: process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + threshold: mapThreshold( + process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + ), }, { category: 'HARM_CATEGORY_HATE_SPEECH', - threshold: process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + threshold: mapThreshold( + process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + ), }, { category: 'HARM_CATEGORY_HARASSMENT', - threshold: process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + threshold: mapThreshold( + process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + ), }, { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', - threshold: process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + threshold: mapThreshold( + process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + ), }, { category: 'HARM_CATEGORY_CIVIC_INTEGRITY', - threshold: process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'BLOCK_NONE', + threshold: mapThreshold(process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'BLOCK_NONE'), }, ]; } @@ -64,14 +84,16 @@ function getLLMConfig(credentials, 
options = {}) { /** @type {GoogleClientOptions | VertexAIClientOptions} */ let llmConfig = { ...(options.modelOptions || {}), - safetySettings: getSafetySettings(), maxRetries: 2, }; + const isGemini2 = llmConfig.model.includes('gemini-2.0'); const isGenerativeModel = llmConfig.model.includes('gemini'); const isChatModel = !isGenerativeModel && llmConfig.model.includes('chat'); const isTextModel = !isGenerativeModel && !isChatModel && /code|text/.test(llmConfig.model); + llmConfig.safetySettings = getSafetySettings(isGemini2); + let provider; if (project_id && isTextModel) { diff --git a/client/src/components/Chat/Input/Files/AttachFileMenu.tsx b/client/src/components/Chat/Input/Files/AttachFileMenu.tsx index 088cdfaa9..c3d9809d5 100644 --- a/client/src/components/Chat/Input/Files/AttachFileMenu.tsx +++ b/client/src/components/Chat/Input/Files/AttachFileMenu.tsx @@ -11,15 +11,15 @@ import { cn } from '~/utils'; interface AttachFileProps { isRTL: boolean; disabled?: boolean | null; - handleFileChange: (event: React.ChangeEvent) => void; - setToolResource?: React.Dispatch>; + handleFileChange: (event: React.ChangeEvent, toolResource?: string) => void; } -const AttachFile = ({ isRTL, disabled, setToolResource, handleFileChange }: AttachFileProps) => { +const AttachFile = ({ isRTL, disabled, handleFileChange }: AttachFileProps) => { const localize = useLocalize(); const isUploadDisabled = disabled ?? 
false; const inputRef = useRef(null); const [isPopoverActive, setIsPopoverActive] = useState(false); + const [toolResource, setToolResource] = useState(); const { data: endpointsConfig } = useGetEndpointsQuery(); const capabilities = useMemo( @@ -42,7 +42,7 @@ const AttachFile = ({ isRTL, disabled, setToolResource, handleFileChange }: Atta { label: localize('com_ui_upload_image_input'), onClick: () => { - setToolResource?.(undefined); + setToolResource(undefined); handleUploadClick(true); }, icon: , @@ -53,7 +53,7 @@ const AttachFile = ({ isRTL, disabled, setToolResource, handleFileChange }: Atta items.push({ label: localize('com_ui_upload_file_search'), onClick: () => { - setToolResource?.(EToolResources.file_search); + setToolResource(EToolResources.file_search); handleUploadClick(); }, icon: , @@ -64,7 +64,7 @@ const AttachFile = ({ isRTL, disabled, setToolResource, handleFileChange }: Atta items.push({ label: localize('com_ui_upload_code_files'), onClick: () => { - setToolResource?.(EToolResources.execute_code); + setToolResource(EToolResources.execute_code); handleUploadClick(); }, icon: , @@ -98,7 +98,12 @@ const AttachFile = ({ isRTL, disabled, setToolResource, handleFileChange }: Atta ); return ( - + { + handleFileChange(e, toolResource); + }} + >
isAgentsEndpoint(_endpoint), [_endpoint]); - const { handleFileChange, abortUpload, setToolResource } = useFileHandling(); + const { handleFileChange, abortUpload } = useFileHandling(); const { data: fileConfig = defaultFileConfig } = useGetFileConfig({ select: (data) => mergeFileConfig(data), @@ -48,7 +48,6 @@ function FileFormWrapper({ ); diff --git a/client/src/hooks/Files/useFileHandling.ts b/client/src/hooks/Files/useFileHandling.ts index d1e71d8a6..2723474d4 100644 --- a/client/src/hooks/Files/useFileHandling.ts +++ b/client/src/hooks/Files/useFileHandling.ts @@ -39,7 +39,6 @@ const useFileHandling = (params?: UseFileHandling) => { const [errors, setErrors] = useState([]); const abortControllerRef = useRef(null); const { startUploadTimer, clearUploadTimer } = useDelayedUploadToast(); - const [toolResource, setToolResource] = useState(); const { files, setFiles, setFilesLoading, conversation } = useChatContext(); const setError = (error: string) => setErrors((prevErrors) => [...prevErrors, error]); const { addFile, replaceFile, updateFileById, deleteFileById } = useUpdateFiles( @@ -149,9 +148,6 @@ const useFileHandling = (params?: UseFileHandling) => { : error?.response?.data?.message ?? 'com_error_files_upload'; setError(errorMessage); }, - onMutate: () => { - setToolResource(undefined); - }, }, abortControllerRef.current?.signal, ); @@ -187,7 +183,7 @@ const useFileHandling = (params?: UseFileHandling) => { if (!agent_id) { formData.append('message_file', 'true'); } - const tool_resource = extendedFile.tool_resource ?? toolResource; + const tool_resource = extendedFile.tool_resource; if (tool_resource != null) { formData.append('tool_resource', tool_resource); } @@ -365,7 +361,7 @@ const useFileHandling = (params?: UseFileHandling) => { const isImage = originalFile.type.split('/')[0] === 'image'; const tool_resource = - extendedFile.tool_resource ?? params?.additionalMetadata?.tool_resource ?? toolResource; + extendedFile.tool_resource ?? 
params?.additionalMetadata?.tool_resource; if (isAgentsEndpoint(endpoint) && !isImage && tool_resource == null) { /** Note: this needs to be removed when we can support files to providers */ setError('com_error_files_unsupported_capability'); @@ -388,11 +384,11 @@ const useFileHandling = (params?: UseFileHandling) => { } }; - const handleFileChange = (event: React.ChangeEvent) => { + const handleFileChange = (event: React.ChangeEvent, _toolResource?: string) => { event.stopPropagation(); if (event.target.files) { setFilesLoading(true); - handleFiles(event.target.files); + handleFiles(event.target.files, _toolResource); // reset the input event.target.value = ''; } @@ -408,7 +404,6 @@ const useFileHandling = (params?: UseFileHandling) => { return { handleFileChange, - setToolResource, handleFiles, abortUpload, setFiles, diff --git a/package-lock.json b/package-lock.json index 04813ade9..60c8e8871 100644 --- a/package-lock.json +++ b/package-lock.json @@ -36322,7 +36322,7 @@ }, "packages/data-provider": { "name": "librechat-data-provider", - "version": "0.7.692", + "version": "0.7.693", "license": "ISC", "dependencies": { "axios": "^1.7.7", diff --git a/packages/data-provider/package.json b/packages/data-provider/package.json index 7d7bd73b2..4e5b2e343 100644 --- a/packages/data-provider/package.json +++ b/packages/data-provider/package.json @@ -1,6 +1,6 @@ { "name": "librechat-data-provider", - "version": "0.7.692", + "version": "0.7.693", "description": "data services for librechat apps", "main": "dist/index.js", "module": "dist/index.es.js", diff --git a/packages/data-provider/specs/filetypes.spec.ts b/packages/data-provider/specs/filetypes.spec.ts index de511fc68..e37baca59 100644 --- a/packages/data-provider/specs/filetypes.spec.ts +++ b/packages/data-provider/specs/filetypes.spec.ts @@ -14,13 +14,7 @@ import { } from '../src/file-config'; describe('MIME Type Regex Patterns', () => { - const unsupportedMimeTypes = [ - 'text/x-unknown', - 'application/unknown', - 
'image/bmp', - 'image/svg', - 'audio/mp3', - ]; + const unsupportedMimeTypes = ['text/x-unknown', 'application/unknown', 'image/bmp', 'audio/mp3']; // Testing general supported MIME types fullMimeTypesList.forEach((mimeType) => { diff --git a/packages/data-provider/src/file-config.ts b/packages/data-provider/src/file-config.ts index 5fb22b2f6..a34d44ff3 100644 --- a/packages/data-provider/src/file-config.ts +++ b/packages/data-provider/src/file-config.ts @@ -54,6 +54,8 @@ export const fullMimeTypesList = [ 'application/typescript', 'application/xml', 'application/zip', + 'image/svg', + 'image/svg+xml', ...excelFileTypes, ]; @@ -122,6 +124,8 @@ export const supportedMimeTypes = [ excelMimeTypes, applicationMimeTypes, imageMimeTypes, + /** Supported by LC Code Interpreter API */ + /^image\/(svg|svg\+xml)$/, ]; export const codeInterpreterMimeTypes = [ diff --git a/packages/mcp/src/connection.ts b/packages/mcp/src/connection.ts index ac860d1c4..68b5c5ed5 100644 --- a/packages/mcp/src/connection.ts +++ b/packages/mcp/src/connection.ts @@ -55,9 +55,7 @@ export class MCPConnection extends EventEmitter { version: '1.0.0', }, { - capabilities: { - tools: {}, - }, + capabilities: {}, }, );