diff --git a/.env.example b/.env.example index 876535b345..2745881992 100644 --- a/.env.example +++ b/.env.example @@ -58,7 +58,7 @@ DEBUG_CONSOLE=false # Endpoints # #===================================================# -# ENDPOINTS=openAI,assistants,azureOpenAI,google,gptPlugins,anthropic +# ENDPOINTS=openAI,assistants,azureOpenAI,google,anthropic PROXY= @@ -142,10 +142,10 @@ GOOGLE_KEY=user_provided # GOOGLE_AUTH_HEADER=true # Gemini API (AI Studio) -# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002 +# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash,gemini-2.0-flash-lite # Vertex AI -# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002 +# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash-001,gemini-2.0-flash-lite-001 # GOOGLE_TITLE_MODEL=gemini-2.0-flash-lite-001 @@ -349,6 +349,11 @@ REGISTRATION_VIOLATION_SCORE=1 CONCURRENT_VIOLATION_SCORE=1 MESSAGE_VIOLATION_SCORE=1 NON_BROWSER_VIOLATION_SCORE=20 +TTS_VIOLATION_SCORE=0 +STT_VIOLATION_SCORE=0 +FORK_VIOLATION_SCORE=0 +IMPORT_VIOLATION_SCORE=0 +FILE_UPLOAD_VIOLATION_SCORE=0 LOGIN_MAX=7 LOGIN_WINDOW=5 @@ -453,8 +458,8 @@ OPENID_REUSE_TOKENS= OPENID_JWKS_URL_CACHE_ENABLED= OPENID_JWKS_URL_CACHE_TIME= # 600000 ms eq to 10 minutes leave empty to disable caching #Set to true to trigger token exchange flow to acquire access token for the userinfo endpoint. -OPENID_ON_BEHALF_FLOW_FOR_USERINFRO_REQUIRED= -OPENID_ON_BEHALF_FLOW_USERINFRO_SCOPE = "user.read" # example for Scope Needed for Microsoft Graph API +OPENID_ON_BEHALF_FLOW_FOR_USERINFO_REQUIRED= +OPENID_ON_BEHALF_FLOW_USERINFO_SCOPE="user.read" # example for Scope Needed for Microsoft Graph API # Set to true to use the OpenID Connect end session endpoint for logout OPENID_USE_END_SESSION_ENDPOINT= @@ -657,4 +662,4 @@ OPENWEATHER_API_KEY= # Reranker (Required) # JINA_API_KEY=your_jina_api_key # or -# COHERE_API_KEY=your_cohere_api_key \ No newline at end of file +# COHERE_API_KEY=your_cohere_api_key diff --git a/Dockerfile b/Dockerfile index 393b35354d..02bcb7da1f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -# v0.7.8 +# v0.7.9-rc1 # Base node image FROM node:20-alpine AS node diff --git a/Dockerfile.multi b/Dockerfile.multi index 17a9847323..9738f4e1f3 100644 --- a/Dockerfile.multi +++ b/Dockerfile.multi @@ -1,5 +1,5 @@ # Dockerfile.multi -# v0.7.8 +# v0.7.9-rc1 # Base for all builds FROM node:20-alpine AS base-min diff --git a/README.md b/README.md index d6bd19ab43..8de9ce5ef3 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ - 🖥️ **UI & Experience** inspired by ChatGPT with enhanced design and features - 🤖 **AI Model Selection**: - - Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Assistants API (incl. Azure) + - Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Responses API (incl. Azure) - [Custom Endpoints](https://www.librechat.ai/docs/quick_start/custom_endpoints): Use any OpenAI-compatible API with LibreChat, no proxy required - Compatible with [Local & Remote AI Providers](https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints): - Ollama, groq, Cohere, Mistral AI, Apple MLX, koboldcpp, together.ai, @@ -66,10 +66,9 @@ - 🔦 **Agents & Tools Integration**: - **[LibreChat Agents](https://www.librechat.ai/docs/features/agents)**: - No-Code Custom Assistants: Build specialized, AI-driven helpers without coding - - Flexible & Extensible: Attach tools like DALL-E-3, file search, code execution, and more - - Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, and more + - Flexible & Extensible: Use MCP Servers, tools, file search, code execution, and more + - Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, Google, Vertex AI, Responses API, and more - [Model Context Protocol (MCP) Support](https://modelcontextprotocol.io/clients#librechat) for Tools - - Use LibreChat Agents and OpenAI Assistants with Files, Code Interpreter, Tools, and API Actions - 🔍 **Web Search**: - Search the internet and retrieve relevant information to enhance your AI context diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 55b8780180..0598f0da21 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -13,7 +13,6 @@ const { const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models'); const { checkBalance } = require('~/models/balanceMethods'); const { truncateToolCallOutputs } = require('./prompts'); -const { addSpaceIfNeeded } = require('~/server/utils'); const { getFiles } = require('~/models/File'); const TextStream = require('./TextStream'); const { logger } = require('~/config'); @@ -572,7 +571,7 @@ class BaseClient { }); } - const { generation = '' } = opts; + const { editedContent } = opts; // It's not necessary to push to currentMessages // depending on subclass implementation of handling messages @@ -587,11 +586,21 @@ class BaseClient { isCreatedByUser: false, model: this.modelOptions?.model ?? this.model, sender: this.sender, - text: generation, }; this.currentMessages.push(userMessage, latestMessage); - } else { - latestMessage.text = generation; + } else if (editedContent != null) { + // Handle editedContent for content parts + if (editedContent && latestMessage.content && Array.isArray(latestMessage.content)) { + const { index, text, type } = editedContent; + if (index >= 0 && index < latestMessage.content.length) { + const contentPart = latestMessage.content[index]; + if (type === ContentTypes.THINK && contentPart.type === ContentTypes.THINK) { + contentPart[ContentTypes.THINK] = text; + } else if (type === ContentTypes.TEXT && contentPart.type === ContentTypes.TEXT) { + contentPart[ContentTypes.TEXT] = text; + } + } + } } this.continued = true; } else { @@ -672,16 +681,32 @@ class BaseClient { }; if (typeof completion === 'string') { - responseMessage.text = addSpaceIfNeeded(generation) + completion; + responseMessage.text = completion; } else if ( Array.isArray(completion) && (this.clientName === EModelEndpoint.agents || isParamEndpoint(this.options.endpoint, this.options.endpointType)) ) { responseMessage.text = ''; - responseMessage.content = completion; + + if (!opts.editedContent || this.currentMessages.length === 0) { + responseMessage.content = completion; + } else { + const latestMessage = this.currentMessages[this.currentMessages.length - 1]; + if (!latestMessage?.content) { + responseMessage.content = completion; + } else { + const existingContent = [...latestMessage.content]; + const { type: editedType } = opts.editedContent; + responseMessage.content = this.mergeEditedContent( + existingContent, + completion, + editedType, + ); + } + } } else if (Array.isArray(completion)) { - responseMessage.text = addSpaceIfNeeded(generation) + completion.join(''); + responseMessage.text = completion.join(''); } if ( @@ -792,7 +817,8 @@ class BaseClient { userMessage.tokenCount = userMessageTokenCount; /* - Note: `AskController` saves the user message, so we update the count of its `userMessage` reference + Note: `AgentController` saves the user message if not saved here + (noted by `savedMessageIds`), so we update the count of its `userMessage` reference */ if (typeof opts?.getReqData === 'function') { opts.getReqData({ @@ -801,7 +827,8 @@ class BaseClient { } /* Note: we update the user message to be sure it gets the calculated token count; - though `AskController` saves the user message, EditController does not + though `AgentController` saves the user message if not saved here + (noted by `savedMessageIds`), EditController does not */ await userMessagePromise; await this.updateMessageInDatabase({ @@ -1093,6 +1120,50 @@ class BaseClient { return numTokens; } + /** + * Merges completion content with existing content when editing TEXT or THINK types + * @param {Array} existingContent - The existing content array + * @param {Array} newCompletion - The new completion content + * @param {string} editedType - The type of content being edited + * @returns {Array} The merged content array + */ + mergeEditedContent(existingContent, newCompletion, editedType) { + if (!newCompletion.length) { + return existingContent.concat(newCompletion); + } + + if (editedType !== ContentTypes.TEXT && editedType !== ContentTypes.THINK) { + return existingContent.concat(newCompletion); + } + + const lastIndex = existingContent.length - 1; + const lastExisting = existingContent[lastIndex]; + const firstNew = newCompletion[0]; + + if (lastExisting?.type !== firstNew?.type || firstNew?.type !== editedType) { + return existingContent.concat(newCompletion); + } + + const mergedContent = [...existingContent]; + if (editedType === ContentTypes.TEXT) { + mergedContent[lastIndex] = { + ...mergedContent[lastIndex], + [ContentTypes.TEXT]: + (mergedContent[lastIndex][ContentTypes.TEXT] || '') + (firstNew[ContentTypes.TEXT] || ''), + }; + } else { + mergedContent[lastIndex] = { + ...mergedContent[lastIndex], + [ContentTypes.THINK]: + (mergedContent[lastIndex][ContentTypes.THINK] || '') + + (firstNew[ContentTypes.THINK] || ''), + }; + } + + // Add remaining completion items + return mergedContent.concat(newCompletion.slice(1)); + } + async sendPayload(payload, opts = {}) { if (opts && typeof opts === 'object') { this.setOptions(opts); diff --git a/api/app/clients/ChatGPTClient.js b/api/app/clients/ChatGPTClient.js deleted file mode 100644 index 555028dc3f..0000000000 --- a/api/app/clients/ChatGPTClient.js +++ /dev/null @@ -1,804 +0,0 @@ -const { Keyv } = require('keyv'); -const crypto = require('crypto'); -const { CohereClient } = require('cohere-ai'); -const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source'); -const { constructAzureURL, genAzureChatCompletion } = require('@librechat/api'); -const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); -const { - ImageDetail, - EModelEndpoint, - resolveHeaders, - CohereConstants, - mapModelToAzureConfig, -} = require('librechat-data-provider'); -const { createContextHandlers } = require('./prompts'); -const { createCoherePayload } = require('./llm'); -const { extractBaseURL } = require('~/utils'); -const BaseClient = require('./BaseClient'); -const { logger } = require('~/config'); - -const CHATGPT_MODEL = 'gpt-3.5-turbo'; -const tokenizersCache = {}; - -class ChatGPTClient extends BaseClient { - constructor(apiKey, options = {}, cacheOptions = {}) { - super(apiKey, options, cacheOptions); - - cacheOptions.namespace = cacheOptions.namespace || 'chatgpt'; - this.conversationsCache = new Keyv(cacheOptions); - this.setOptions(options); - } - - setOptions(options) { - if (this.options && !this.options.replaceOptions) { - // nested options aren't spread properly, so we need to do this manually - this.options.modelOptions = { - ...this.options.modelOptions, - ...options.modelOptions, - }; - delete options.modelOptions; - // now we can merge options - this.options = { - ...this.options, - ...options, - }; - } else { - this.options = options; - } - - if (this.options.openaiApiKey) { - this.apiKey = this.options.openaiApiKey; - } - - const modelOptions = this.options.modelOptions || {}; - this.modelOptions = { - ...modelOptions, - // set some good defaults (check for undefined in some cases because they may be 0) - model: modelOptions.model || CHATGPT_MODEL, - temperature: typeof modelOptions.temperature === 'undefined' ? 0.8 : modelOptions.temperature, - top_p: typeof modelOptions.top_p === 'undefined' ? 1 : modelOptions.top_p, - presence_penalty: - typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty, - stop: modelOptions.stop, - }; - - this.isChatGptModel = this.modelOptions.model.includes('gpt-'); - const { isChatGptModel } = this; - this.isUnofficialChatGptModel = - this.modelOptions.model.startsWith('text-chat') || - this.modelOptions.model.startsWith('text-davinci-002-render'); - const { isUnofficialChatGptModel } = this; - - // Davinci models have a max context length of 4097 tokens. - this.maxContextTokens = this.options.maxContextTokens || (isChatGptModel ? 4095 : 4097); - // I decided to reserve 1024 tokens for the response. - // The max prompt tokens is determined by the max context tokens minus the max response tokens. - // Earlier messages will be dropped until the prompt is within the limit. - this.maxResponseTokens = this.modelOptions.max_tokens || 1024; - this.maxPromptTokens = - this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens; - - if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) { - throw new Error( - `maxPromptTokens + max_tokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${ - this.maxPromptTokens + this.maxResponseTokens - }) must be less than or equal to maxContextTokens (${this.maxContextTokens})`, - ); - } - - this.userLabel = this.options.userLabel || 'User'; - this.chatGptLabel = this.options.chatGptLabel || 'ChatGPT'; - - if (isChatGptModel) { - // Use these faux tokens to help the AI understand the context since we are building the chat log ourselves. - // Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason, - // without tripping the stop sequences, so I'm using "||>" instead. - this.startToken = '||>'; - this.endToken = ''; - this.gptEncoder = this.constructor.getTokenizer('cl100k_base'); - } else if (isUnofficialChatGptModel) { - this.startToken = '<|im_start|>'; - this.endToken = '<|im_end|>'; - this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, { - '<|im_start|>': 100264, - '<|im_end|>': 100265, - }); - } else { - // Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting - // system that causes only the first "<|endoftext|>" to be counted as 1 token, and the rest are not treated - // as a single token. So we're using this instead. - this.startToken = '||>'; - this.endToken = ''; - try { - this.gptEncoder = this.constructor.getTokenizer(this.modelOptions.model, true); - } catch { - this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true); - } - } - - if (!this.modelOptions.stop) { - const stopTokens = [this.startToken]; - if (this.endToken && this.endToken !== this.startToken) { - stopTokens.push(this.endToken); - } - stopTokens.push(`\n${this.userLabel}:`); - stopTokens.push('<|diff_marker|>'); - // I chose not to do one for `chatGptLabel` because I've never seen it happen - this.modelOptions.stop = stopTokens; - } - - if (this.options.reverseProxyUrl) { - this.completionsUrl = this.options.reverseProxyUrl; - } else if (isChatGptModel) { - this.completionsUrl = 'https://api.openai.com/v1/chat/completions'; - } else { - this.completionsUrl = 'https://api.openai.com/v1/completions'; - } - - return this; - } - - static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) { - if (tokenizersCache[encoding]) { - return tokenizersCache[encoding]; - } - let tokenizer; - if (isModelName) { - tokenizer = encodingForModel(encoding, extendSpecialTokens); - } else { - tokenizer = getEncoding(encoding, extendSpecialTokens); - } - tokenizersCache[encoding] = tokenizer; - return tokenizer; - } - - /** @type {getCompletion} */ - async getCompletion(input, onProgress, onTokenProgress, abortController = null) { - if (!abortController) { - abortController = new AbortController(); - } - - let modelOptions = { ...this.modelOptions }; - if (typeof onProgress === 'function') { - modelOptions.stream = true; - } - if (this.isChatGptModel) { - modelOptions.messages = input; - } else { - modelOptions.prompt = input; - } - - if (this.useOpenRouter && modelOptions.prompt) { - delete modelOptions.stop; - } - - const { debug } = this.options; - let baseURL = this.completionsUrl; - if (debug) { - console.debug(); - console.debug(baseURL); - console.debug(modelOptions); - console.debug(); - } - - const opts = { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - }; - - if (this.isVisionModel) { - modelOptions.max_tokens = 4000; - } - - /** @type {TAzureConfig | undefined} */ - const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI]; - - const isAzure = this.azure || this.options.azure; - if ( - (isAzure && this.isVisionModel && azureConfig) || - (azureConfig && this.isVisionModel && this.options.endpoint === EModelEndpoint.azureOpenAI) - ) { - const { modelGroupMap, groupMap } = azureConfig; - const { - azureOptions, - baseURL, - headers = {}, - serverless, - } = mapModelToAzureConfig({ - modelName: modelOptions.model, - modelGroupMap, - groupMap, - }); - opts.headers = resolveHeaders(headers); - this.langchainProxy = extractBaseURL(baseURL); - this.apiKey = azureOptions.azureOpenAIApiKey; - - const groupName = modelGroupMap[modelOptions.model].group; - this.options.addParams = azureConfig.groupMap[groupName].addParams; - this.options.dropParams = azureConfig.groupMap[groupName].dropParams; - // Note: `forcePrompt` not re-assigned as only chat models are vision models - - this.azure = !serverless && azureOptions; - this.azureEndpoint = - !serverless && genAzureChatCompletion(this.azure, modelOptions.model, this); - if (serverless === true) { - this.options.defaultQuery = azureOptions.azureOpenAIApiVersion - ? { 'api-version': azureOptions.azureOpenAIApiVersion } - : undefined; - this.options.headers['api-key'] = this.apiKey; - } - } - - if (this.options.defaultQuery) { - opts.defaultQuery = this.options.defaultQuery; - } - - if (this.options.headers) { - opts.headers = { ...opts.headers, ...this.options.headers }; - } - - if (isAzure) { - // Azure does not accept `model` in the body, so we need to remove it. - delete modelOptions.model; - - baseURL = this.langchainProxy - ? constructAzureURL({ - baseURL: this.langchainProxy, - azureOptions: this.azure, - }) - : this.azureEndpoint.split(/(? msg.role === 'system'); - - if (systemMessageIndex > 0) { - const [systemMessage] = messages.splice(systemMessageIndex, 1); - messages.unshift(systemMessage); - } - - modelOptions.messages = messages; - - if (messages.length === 1 && messages[0].role === 'system') { - modelOptions.messages[0].role = 'user'; - } - } - - if (this.options.addParams && typeof this.options.addParams === 'object') { - modelOptions = { - ...modelOptions, - ...this.options.addParams, - }; - logger.debug('[ChatGPTClient] chatCompletion: added params', { - addParams: this.options.addParams, - modelOptions, - }); - } - - if (this.options.dropParams && Array.isArray(this.options.dropParams)) { - this.options.dropParams.forEach((param) => { - delete modelOptions[param]; - }); - logger.debug('[ChatGPTClient] chatCompletion: dropped params', { - dropParams: this.options.dropParams, - modelOptions, - }); - } - - if (baseURL.startsWith(CohereConstants.API_URL)) { - const payload = createCoherePayload({ modelOptions }); - return await this.cohereChatCompletion({ payload, onTokenProgress }); - } - - if (baseURL.includes('v1') && !baseURL.includes('/completions') && !this.isChatCompletion) { - baseURL = baseURL.split('v1')[0] + 'v1/completions'; - } else if ( - baseURL.includes('v1') && - !baseURL.includes('/chat/completions') && - this.isChatCompletion - ) { - baseURL = baseURL.split('v1')[0] + 'v1/chat/completions'; - } - - const BASE_URL = new URL(baseURL); - if (opts.defaultQuery) { - Object.entries(opts.defaultQuery).forEach(([key, value]) => { - BASE_URL.searchParams.append(key, value); - }); - delete opts.defaultQuery; - } - - const completionsURL = BASE_URL.toString(); - opts.body = JSON.stringify(modelOptions); - - if (modelOptions.stream) { - return new Promise(async (resolve, reject) => { - try { - let done = false; - await fetchEventSource(completionsURL, { - ...opts, - signal: abortController.signal, - async onopen(response) { - if (response.status === 200) { - return; - } - if (debug) { - console.debug(response); - } - let error; - try { - const body = await response.text(); - error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`); - error.status = response.status; - error.json = JSON.parse(body); - } catch { - error = error || new Error(`Failed to send message. HTTP ${response.status}`); - } - throw error; - }, - onclose() { - if (debug) { - console.debug('Server closed the connection unexpectedly, returning...'); - } - // workaround for private API not sending [DONE] event - if (!done) { - onProgress('[DONE]'); - resolve(); - } - }, - onerror(err) { - if (debug) { - console.debug(err); - } - // rethrow to stop the operation - throw err; - }, - onmessage(message) { - if (debug) { - console.debug(message); - } - if (!message.data || message.event === 'ping') { - return; - } - if (message.data === '[DONE]') { - onProgress('[DONE]'); - resolve(); - done = true; - return; - } - onProgress(JSON.parse(message.data)); - }, - }); - } catch (err) { - reject(err); - } - }); - } - const response = await fetch(completionsURL, { - ...opts, - signal: abortController.signal, - }); - if (response.status !== 200) { - const body = await response.text(); - const error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`); - error.status = response.status; - try { - error.json = JSON.parse(body); - } catch { - error.body = body; - } - throw error; - } - return response.json(); - } - - /** @type {cohereChatCompletion} */ - async cohereChatCompletion({ payload, onTokenProgress }) { - const cohere = new CohereClient({ - token: this.apiKey, - environment: this.completionsUrl, - }); - - if (!payload.stream) { - const chatResponse = await cohere.chat(payload); - return chatResponse.text; - } - - const chatStream = await cohere.chatStream(payload); - let reply = ''; - for await (const message of chatStream) { - if (!message) { - continue; - } - - if (message.eventType === 'text-generation' && message.text) { - onTokenProgress(message.text); - reply += message.text; - } - /* - Cohere API Chinese Unicode character replacement hotfix. - Should be un-commented when the following issue is resolved: - https://github.com/cohere-ai/cohere-typescript/issues/151 - - else if (message.eventType === 'stream-end' && message.response) { - reply = message.response.text; - } - */ - } - - return reply; - } - - async generateTitle(userMessage, botMessage) { - const instructionsPayload = { - role: 'system', - content: `Write an extremely concise subtitle for this conversation with no more than a few words. All words should be capitalized. Exclude punctuation. - -||>Message: -${userMessage.message} -||>Response: -${botMessage.message} - -||>Title:`, - }; - - const titleGenClientOptions = JSON.parse(JSON.stringify(this.options)); - titleGenClientOptions.modelOptions = { - model: 'gpt-3.5-turbo', - temperature: 0, - presence_penalty: 0, - frequency_penalty: 0, - }; - const titleGenClient = new ChatGPTClient(this.apiKey, titleGenClientOptions); - const result = await titleGenClient.getCompletion([instructionsPayload], null); - // remove any non-alphanumeric characters, replace multiple spaces with 1, and then trim - return result.choices[0].message.content - .replace(/[^a-zA-Z0-9' ]/g, '') - .replace(/\s+/g, ' ') - .trim(); - } - - async sendMessage(message, opts = {}) { - if (opts.clientOptions && typeof opts.clientOptions === 'object') { - this.setOptions(opts.clientOptions); - } - - const conversationId = opts.conversationId || crypto.randomUUID(); - const parentMessageId = opts.parentMessageId || crypto.randomUUID(); - - let conversation = - typeof opts.conversation === 'object' - ? opts.conversation - : await this.conversationsCache.get(conversationId); - - let isNewConversation = false; - if (!conversation) { - conversation = { - messages: [], - createdAt: Date.now(), - }; - isNewConversation = true; - } - - const shouldGenerateTitle = opts.shouldGenerateTitle && isNewConversation; - - const userMessage = { - id: crypto.randomUUID(), - parentMessageId, - role: 'User', - message, - }; - conversation.messages.push(userMessage); - - // Doing it this way instead of having each message be a separate element in the array seems to be more reliable, - // especially when it comes to keeping the AI in character. It also seems to improve coherency and context retention. - const { prompt: payload, context } = await this.buildPrompt( - conversation.messages, - userMessage.id, - { - isChatGptModel: this.isChatGptModel, - promptPrefix: opts.promptPrefix, - }, - ); - - if (this.options.keepNecessaryMessagesOnly) { - conversation.messages = context; - } - - let reply = ''; - let result = null; - if (typeof opts.onProgress === 'function') { - await this.getCompletion( - payload, - (progressMessage) => { - if (progressMessage === '[DONE]') { - return; - } - const token = this.isChatGptModel - ? progressMessage.choices[0].delta.content - : progressMessage.choices[0].text; - // first event's delta content is always undefined - if (!token) { - return; - } - if (this.options.debug) { - console.debug(token); - } - if (token === this.endToken) { - return; - } - opts.onProgress(token); - reply += token; - }, - opts.abortController || new AbortController(), - ); - } else { - result = await this.getCompletion( - payload, - null, - opts.abortController || new AbortController(), - ); - if (this.options.debug) { - console.debug(JSON.stringify(result)); - } - if (this.isChatGptModel) { - reply = result.choices[0].message.content; - } else { - reply = result.choices[0].text.replace(this.endToken, ''); - } - } - - // avoids some rendering issues when using the CLI app - if (this.options.debug) { - console.debug(); - } - - reply = reply.trim(); - - const replyMessage = { - id: crypto.randomUUID(), - parentMessageId: userMessage.id, - role: 'ChatGPT', - message: reply, - }; - conversation.messages.push(replyMessage); - - const returnData = { - response: replyMessage.message, - conversationId, - parentMessageId: replyMessage.parentMessageId, - messageId: replyMessage.id, - details: result || {}, - }; - - if (shouldGenerateTitle) { - conversation.title = await this.generateTitle(userMessage, replyMessage); - returnData.title = conversation.title; - } - - await this.conversationsCache.set(conversationId, conversation); - - if (this.options.returnConversation) { - returnData.conversation = conversation; - } - - return returnData; - } - - async buildPrompt(messages, { isChatGptModel = false, promptPrefix = null }) { - promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim(); - - // Handle attachments and create augmentedPrompt - if (this.options.attachments) { - const attachments = await this.options.attachments; - const lastMessage = messages[messages.length - 1]; - - if (this.message_file_map) { - this.message_file_map[lastMessage.messageId] = attachments; - } else { - this.message_file_map = { - [lastMessage.messageId]: attachments, - }; - } - - const files = await this.addImageURLs(lastMessage, attachments); - this.options.attachments = files; - - this.contextHandlers = createContextHandlers(this.options.req, lastMessage.text); - } - - if (this.message_file_map) { - this.contextHandlers = createContextHandlers( - this.options.req, - messages[messages.length - 1].text, - ); - } - - // Calculate image token cost and process embedded files - messages.forEach((message, i) => { - if (this.message_file_map && this.message_file_map[message.messageId]) { - const attachments = this.message_file_map[message.messageId]; - for (const file of attachments) { - if (file.embedded) { - this.contextHandlers?.processFile(file); - continue; - } - - messages[i].tokenCount = - (messages[i].tokenCount || 0) + - this.calculateImageTokenCost({ - width: file.width, - height: file.height, - detail: this.options.imageDetail ?? ImageDetail.auto, - }); - } - } - }); - - if (this.contextHandlers) { - this.augmentedPrompt = await this.contextHandlers.createContext(); - promptPrefix = this.augmentedPrompt + promptPrefix; - } - - if (promptPrefix) { - // If the prompt prefix doesn't end with the end token, add it. - if (!promptPrefix.endsWith(`${this.endToken}`)) { - promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`; - } - promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`; - } - const promptSuffix = `${this.startToken}${this.chatGptLabel}:\n`; // Prompt ChatGPT to respond. - - const instructionsPayload = { - role: 'system', - content: promptPrefix, - }; - - const messagePayload = { - role: 'system', - content: promptSuffix, - }; - - let currentTokenCount; - if (isChatGptModel) { - currentTokenCount = - this.getTokenCountForMessage(instructionsPayload) + - this.getTokenCountForMessage(messagePayload); - } else { - currentTokenCount = this.getTokenCount(`${promptPrefix}${promptSuffix}`); - } - let promptBody = ''; - const maxTokenCount = this.maxPromptTokens; - - const context = []; - - // Iterate backwards through the messages, adding them to the prompt until we reach the max token count. - // Do this within a recursive async function so that it doesn't block the event loop for too long. - const buildPromptBody = async () => { - if (currentTokenCount < maxTokenCount && messages.length > 0) { - const message = messages.pop(); - const roleLabel = - message?.isCreatedByUser || message?.role?.toLowerCase() === 'user' - ? this.userLabel - : this.chatGptLabel; - const messageString = `${this.startToken}${roleLabel}:\n${ - message?.text ?? message?.message - }${this.endToken}\n`; - let newPromptBody; - if (promptBody || isChatGptModel) { - newPromptBody = `${messageString}${promptBody}`; - } else { - // Always insert prompt prefix before the last user message, if not gpt-3.5-turbo. - // This makes the AI obey the prompt instructions better, which is important for custom instructions. - // After a bunch of testing, it doesn't seem to cause the AI any confusion, even if you ask it things - // like "what's the last thing I wrote?". - newPromptBody = `${promptPrefix}${messageString}${promptBody}`; - } - - context.unshift(message); - - const tokenCountForMessage = this.getTokenCount(messageString); - const newTokenCount = currentTokenCount + tokenCountForMessage; - if (newTokenCount > maxTokenCount) { - if (promptBody) { - // This message would put us over the token limit, so don't add it. - return false; - } - // This is the first message, so we can't add it. Just throw an error. - throw new Error( - `Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`, - ); - } - promptBody = newPromptBody; - currentTokenCount = newTokenCount; - // wait for next tick to avoid blocking the event loop - await new Promise((resolve) => setImmediate(resolve)); - return buildPromptBody(); - } - return true; - }; - - await buildPromptBody(); - - const prompt = `${promptBody}${promptSuffix}`; - if (isChatGptModel) { - messagePayload.content = prompt; - // Add 3 tokens for Assistant Label priming after all messages have been counted. - currentTokenCount += 3; - } - - // Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response. - this.modelOptions.max_tokens = Math.min( - this.maxContextTokens - currentTokenCount, - this.maxResponseTokens, - ); - - if (isChatGptModel) { - return { prompt: [instructionsPayload, messagePayload], context }; - } - return { prompt, context, promptTokens: currentTokenCount }; - } - - getTokenCount(text) { - return this.gptEncoder.encode(text, 'all').length; - } - - /** - * Algorithm adapted from "6. Counting tokens for chat API calls" of - * https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb - * - * An additional 3 tokens need to be added for assistant label priming after all messages have been counted. - * - * @param {Object} message - */ - getTokenCountForMessage(message) { - // Note: gpt-3.5-turbo and gpt-4 may update over time. Use default for these as well as for unknown models - let tokensPerMessage = 3; - let tokensPerName = 1; - - if (this.modelOptions.model === 'gpt-3.5-turbo-0301') { - tokensPerMessage = 4; - tokensPerName = -1; - } - - let numTokens = tokensPerMessage; - for (let [key, value] of Object.entries(message)) { - numTokens += this.getTokenCount(value); - if (key === 'name') { - numTokens += tokensPerName; - } - } - - return numTokens; - } -} - -module.exports = ChatGPTClient; diff --git a/api/app/clients/GoogleClient.js b/api/app/clients/GoogleClient.js index 817239d14f..2ec23a0a06 100644 --- a/api/app/clients/GoogleClient.js +++ b/api/app/clients/GoogleClient.js @@ -1,7 +1,7 @@ const { google } = require('googleapis'); -const { Tokenizer } = require('@librechat/api'); const { concat } = require('@langchain/core/utils/stream'); const { ChatVertexAI } = require('@langchain/google-vertexai'); +const { Tokenizer, getSafetySettings } = require('@librechat/api'); const { ChatGoogleGenerativeAI } = require('@langchain/google-genai'); const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai'); const { HumanMessage, SystemMessage } = require('@langchain/core/messages'); @@ -12,13 +12,13 @@ const { endpointSettings, parseTextParts, EModelEndpoint, + googleSettings, ContentTypes, VisionModes, ErrorTypes, Constants, AuthKeys, } = require('librechat-data-provider'); -const { getSafetySettings } = require('~/server/services/Endpoints/google/llm'); const { encodeAndFormat } = require('~/server/services/Files/images'); const { spendTokens } = require('~/models/spendTokens'); const { getModelMaxTokens } = require('~/utils'); @@ -166,6 +166,16 @@ class GoogleClient extends BaseClient { ); } + // Add thinking configuration + this.modelOptions.thinkingConfig = { + thinkingBudget: + (this.modelOptions.thinking ?? googleSettings.thinking.default) + ? this.modelOptions.thinkingBudget + : 0, + }; + delete this.modelOptions.thinking; + delete this.modelOptions.thinkingBudget; + this.sender = this.options.sender ?? getResponseSender({ diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 2d4146bd9c..2eda322640 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -5,6 +5,7 @@ const { isEnabled, Tokenizer, createFetch, + resolveHeaders, constructAzureURL, genAzureChatCompletion, createStreamEventHandlers, @@ -15,7 +16,6 @@ const { ContentTypes, parseTextParts, EModelEndpoint, - resolveHeaders, KnownEndpoints, openAISettings, ImageDetailCost, @@ -37,7 +37,6 @@ const { addSpaceIfNeeded, sleep } = require('~/server/utils'); const { spendTokens } = require('~/models/spendTokens'); const { handleOpenAIErrors } = require('./tools/util'); const { createLLM, RunManager } = require('./llm'); -const ChatGPTClient = require('./ChatGPTClient'); const { summaryBuffer } = require('./memory'); const { runTitleChain } = require('./chains'); const { tokenSplit } = require('./document'); @@ -47,12 +46,6 @@ const { logger } = require('~/config'); class OpenAIClient extends BaseClient { constructor(apiKey, options = {}) { super(apiKey, options); - this.ChatGPTClient = new ChatGPTClient(); - this.buildPrompt = this.ChatGPTClient.buildPrompt.bind(this); - /** @type {getCompletion} */ - this.getCompletion = this.ChatGPTClient.getCompletion.bind(this); - /** @type {cohereChatCompletion} */ - this.cohereChatCompletion = this.ChatGPTClient.cohereChatCompletion.bind(this); this.contextStrategy = options.contextStrategy ? options.contextStrategy.toLowerCase() : 'discard'; @@ -379,23 +372,12 @@ class OpenAIClient extends BaseClient { return files; } - async buildMessages( - messages, - parentMessageId, - { isChatCompletion = false, promptPrefix = null }, - opts, - ) { + async buildMessages(messages, parentMessageId, { promptPrefix = null }, opts) { let orderedMessages = this.constructor.getMessagesForConversation({ messages, parentMessageId, summary: this.shouldSummarize, }); - if (!isChatCompletion) { - return await this.buildPrompt(orderedMessages, { - isChatGptModel: isChatCompletion, - promptPrefix, - }); - } let payload; let instructions; diff --git a/api/app/clients/PluginsClient.js b/api/app/clients/PluginsClient.js deleted file mode 100644 index d0ffe2ef75..0000000000 --- a/api/app/clients/PluginsClient.js +++ /dev/null @@ -1,542 +0,0 @@ -const OpenAIClient = require('./OpenAIClient'); -const { CallbackManager } = require('@langchain/core/callbacks/manager'); -const { BufferMemory, ChatMessageHistory } = require('langchain/memory'); -const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers'); -const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents'); -const { processFileURL } = require('~/server/services/Files/process'); -const { EModelEndpoint } = require('librechat-data-provider'); -const { checkBalance } = require('~/models/balanceMethods'); -const { formatLangChainMessages } = require('./prompts'); -const { extractBaseURL } = require('~/utils'); -const { loadTools } = require('./tools/util'); -const { logger } = require('~/config'); - -class PluginsClient extends OpenAIClient { - constructor(apiKey, options = {}) { - super(apiKey, options); - this.sender = options.sender ?? 'Assistant'; - this.tools = []; - this.actions = []; - this.setOptions(options); - this.openAIApiKey = this.apiKey; - this.executor = null; - } - - setOptions(options) { - this.agentOptions = { ...options.agentOptions }; - this.functionsAgent = this.agentOptions?.agent === 'functions'; - this.agentIsGpt3 = this.agentOptions?.model?.includes('gpt-3'); - - super.setOptions(options); - - this.isGpt3 = this.modelOptions?.model?.includes('gpt-3'); - - if (this.options.reverseProxyUrl) { - this.langchainProxy = extractBaseURL(this.options.reverseProxyUrl); - } - } - - getSaveOptions() { - return { - artifacts: this.options.artifacts, - chatGptLabel: this.options.chatGptLabel, - modelLabel: this.options.modelLabel, - promptPrefix: this.options.promptPrefix, - tools: this.options.tools, - ...this.modelOptions, - agentOptions: this.agentOptions, - iconURL: this.options.iconURL, - greeting: this.options.greeting, - spec: this.options.spec, - }; - } - - saveLatestAction(action) { - this.actions.push(action); - } - - getFunctionModelName(input) { - if (/-(?!0314)\d{4}/.test(input)) { - return input; - } else if (input.includes('gpt-3.5-turbo')) { - return 'gpt-3.5-turbo'; - } else if (input.includes('gpt-4')) { - return 'gpt-4'; - } else { - return 'gpt-3.5-turbo'; - } - } - - getBuildMessagesOptions(opts) { - return { - isChatCompletion: true, - promptPrefix: opts.promptPrefix, - abortController: opts.abortController, - }; - } - - async initialize({ user, message, onAgentAction, onChainEnd, signal }) { - const modelOptions = { - modelName: this.agentOptions.model, - temperature: this.agentOptions.temperature, - }; - - const model = this.initializeLLM({ - ...modelOptions, - context: 'plugins', - initialMessageCount: this.currentMessages.length + 1, - }); - - logger.debug( - `[PluginsClient] Agent Model: ${model.modelName} | Temp: ${model.temperature} | Functions: ${this.functionsAgent}`, - ); - - // Map Messages to Langchain format - const pastMessages = formatLangChainMessages(this.currentMessages.slice(0, -1), { - userName: this.options?.name, - }); - logger.debug('[PluginsClient] pastMessages: ' + pastMessages.length); - - // TODO: use readOnly memory, TokenBufferMemory? (both unavailable in LangChainJS) - const memory = new BufferMemory({ - llm: model, - chatHistory: new ChatMessageHistory(pastMessages), - }); - - const { loadedTools } = await loadTools({ - user, - model, - tools: this.options.tools, - functions: this.functionsAgent, - options: { - memory, - signal: this.abortController.signal, - openAIApiKey: this.openAIApiKey, - conversationId: this.conversationId, - fileStrategy: this.options.req.app.locals.fileStrategy, - processFileURL, - message, - }, - useSpecs: true, - }); - - if (loadedTools.length === 0) { - return; - } - - this.tools = loadedTools; - - logger.debug('[PluginsClient] Requested Tools', this.options.tools); - logger.debug( - '[PluginsClient] Loaded Tools', - this.tools.map((tool) => tool.name), - ); - - const handleAction = (action, runId, callback = null) => { - this.saveLatestAction(action); - - logger.debug('[PluginsClient] Latest Agent Action ', this.actions[this.actions.length - 1]); - - if (typeof callback === 'function') { - callback(action, runId); - } - }; - - // initialize agent - const initializer = this.functionsAgent ? initializeFunctionsAgent : initializeCustomAgent; - - let customInstructions = (this.options.promptPrefix ?? '').trim(); - if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) { - customInstructions = `${customInstructions ?? ''}\n${this.options.artifactsPrompt}`.trim(); - } - - this.executor = await initializer({ - model, - signal, - pastMessages, - tools: this.tools, - customInstructions, - verbose: this.options.debug, - returnIntermediateSteps: true, - customName: this.options.chatGptLabel, - currentDateString: this.currentDateString, - callbackManager: CallbackManager.fromHandlers({ - async handleAgentAction(action, runId) { - handleAction(action, runId, onAgentAction); - }, - async handleChainEnd(action) { - if (typeof onChainEnd === 'function') { - onChainEnd(action); - } - }, - }), - }); - - logger.debug('[PluginsClient] Loaded agent.'); - } - - async executorCall(message, { signal, stream, onToolStart, onToolEnd }) { - let errorMessage = ''; - const maxAttempts = 1; - - for (let attempts = 1; attempts <= maxAttempts; attempts++) { - const errorInput = buildErrorInput({ - message, - errorMessage, - actions: this.actions, - functionsAgent: this.functionsAgent, - }); - const input = attempts > 1 ? errorInput : message; - - logger.debug(`[PluginsClient] Attempt ${attempts} of ${maxAttempts}`); - - if (errorMessage.length > 0) { - logger.debug('[PluginsClient] Caught error, input: ' + JSON.stringify(input)); - } - - try { - this.result = await this.executor.call({ input, signal }, [ - { - async handleToolStart(...args) { - await onToolStart(...args); - }, - async handleToolEnd(...args) { - await onToolEnd(...args); - }, - async handleLLMEnd(output) { - const { generations } = output; - const { text } = generations[0][0]; - if (text && typeof stream === 'function') { - await stream(text); - } - }, - }, - ]); - break; // Exit the loop if the function call is successful - } catch (err) { - logger.error('[PluginsClient] executorCall error:', err); - if (attempts === maxAttempts) { - const { run } = this.runManager.getRunByConversationId(this.conversationId); - const defaultOutput = `Encountered an error while attempting to respond: ${err.message}`; - this.result.output = run && run.error ? run.error : defaultOutput; - this.result.errorMessage = run && run.error ? run.error : err.message; - this.result.intermediateSteps = this.actions; - break; - } - } - } - } - - /** - * - * @param {TMessage} responseMessage - * @param {Partial} saveOptions - * @param {string} user - * @returns - */ - async handleResponseMessage(responseMessage, saveOptions, user) { - const { output, errorMessage, ...result } = this.result; - logger.debug('[PluginsClient][handleResponseMessage] Output:', { - output, - errorMessage, - ...result, - }); - const { error } = responseMessage; - if (!error) { - responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage); - responseMessage.completionTokens = this.getTokenCount(responseMessage.text); - } - - // Record usage only when completion is skipped as it is already recorded in the agent phase. - if (!this.agentOptions.skipCompletion && !error) { - await this.recordTokenUsage(responseMessage); - } - - const databasePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user); - delete responseMessage.tokenCount; - return { ...responseMessage, ...result, databasePromise }; - } - - async sendMessage(message, opts = {}) { - /** @type {Promise} */ - let userMessagePromise; - /** @type {{ filteredTools: string[], includedTools: string[] }} */ - const { filteredTools = [], includedTools = [] } = this.options.req.app.locals; - - if (includedTools.length > 0) { - const tools = this.options.tools.filter((plugin) => includedTools.includes(plugin)); - this.options.tools = tools; - } else { - const tools = this.options.tools.filter((plugin) => !filteredTools.includes(plugin)); - this.options.tools = tools; - } - - // If a message is edited, no tools can be used. - const completionMode = this.options.tools.length === 0 || opts.isEdited; - if (completionMode) { - this.setOptions(opts); - return super.sendMessage(message, opts); - } - - logger.debug('[PluginsClient] sendMessage', { userMessageText: message, opts }); - const { - user, - conversationId, - responseMessageId, - saveOptions, - userMessage, - onAgentAction, - onChainEnd, - onToolStart, - onToolEnd, - } = await this.handleStartMethods(message, opts); - - if (opts.progressCallback) { - opts.onProgress = opts.progressCallback.call(null, { - ...(opts.progressOptions ?? {}), - parentMessageId: userMessage.messageId, - messageId: responseMessageId, - }); - } - - this.currentMessages.push(userMessage); - - let { - prompt: payload, - tokenCountMap, - promptTokens, - } = await this.buildMessages( - this.currentMessages, - userMessage.messageId, - this.getBuildMessagesOptions({ - promptPrefix: null, - abortController: this.abortController, - }), - ); - - if (tokenCountMap) { - logger.debug('[PluginsClient] tokenCountMap', { tokenCountMap }); - if (tokenCountMap[userMessage.messageId]) { - userMessage.tokenCount = tokenCountMap[userMessage.messageId]; - logger.debug('[PluginsClient] userMessage.tokenCount', userMessage.tokenCount); - } - this.handleTokenCountMap(tokenCountMap); - } - - this.result = {}; - if (payload) { - this.currentMessages = payload; - } - - if (!this.skipSaveUserMessage) { - userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user); - if (typeof opts?.getReqData === 'function') { - opts.getReqData({ - userMessagePromise, - }); - } - } - - const balance = this.options.req?.app?.locals?.balance; - if (balance?.enabled) { - await checkBalance({ - req: this.options.req, - res: this.options.res, - txData: { - user: this.user, - tokenType: 'prompt', - amount: promptTokens, - debug: this.options.debug, - model: this.modelOptions.model, - endpoint: EModelEndpoint.openAI, - }, - }); - } - - const responseMessage = { - endpoint: EModelEndpoint.gptPlugins, - iconURL: this.options.iconURL, - messageId: responseMessageId, - conversationId, - parentMessageId: userMessage.messageId, - isCreatedByUser: false, - model: this.modelOptions.model, - sender: this.sender, - promptTokens, - }; - - await this.initialize({ - user, - message, - onAgentAction, - onChainEnd, - signal: this.abortController.signal, - onProgress: opts.onProgress, - }); - - // const stream = async (text) => { - // await this.generateTextStream.call(this, text, opts.onProgress, { delay: 1 }); - // }; - await this.executorCall(message, { - signal: this.abortController.signal, - // stream, - onToolStart, - onToolEnd, - }); - - // If message was aborted mid-generation - if (this.result?.errorMessage?.length > 0 && this.result?.errorMessage?.includes('cancel')) { - responseMessage.text = 'Cancelled.'; - return await this.handleResponseMessage(responseMessage, saveOptions, user); - } - - // If error occurred during generation (likely token_balance) - if (this.result?.errorMessage?.length > 0) { - responseMessage.error = true; - responseMessage.text = this.result.output; - return await this.handleResponseMessage(responseMessage, saveOptions, user); - } - - if (this.agentOptions.skipCompletion && this.result.output && this.functionsAgent) { - const partialText = opts.getPartialText(); - const trimmedPartial = opts.getPartialText().replaceAll(':::plugin:::\n', ''); - responseMessage.text = - trimmedPartial.length === 0 ? `${partialText}${this.result.output}` : partialText; - addImages(this.result.intermediateSteps, responseMessage); - await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 }); - return await this.handleResponseMessage(responseMessage, saveOptions, user); - } - - if (this.agentOptions.skipCompletion && this.result.output) { - responseMessage.text = this.result.output; - addImages(this.result.intermediateSteps, responseMessage); - await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 }); - return await this.handleResponseMessage(responseMessage, saveOptions, user); - } - - logger.debug('[PluginsClient] Completion phase: this.result', this.result); - - const promptPrefix = buildPromptPrefix({ - result: this.result, - message, - functionsAgent: this.functionsAgent, - }); - - logger.debug('[PluginsClient]', { promptPrefix }); - - payload = await this.buildCompletionPrompt({ - messages: this.currentMessages, - promptPrefix, - }); - - logger.debug('[PluginsClient] buildCompletionPrompt Payload', payload); - responseMessage.text = await this.sendCompletion(payload, opts); - return await this.handleResponseMessage(responseMessage, saveOptions, user); - } - - async buildCompletionPrompt({ messages, promptPrefix: _promptPrefix }) { - logger.debug('[PluginsClient] buildCompletionPrompt messages', messages); - - const orderedMessages = messages; - let promptPrefix = _promptPrefix.trim(); - // If the prompt prefix doesn't end with the end token, add it. - if (!promptPrefix.endsWith(`${this.endToken}`)) { - promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`; - } - promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`; - const promptSuffix = `${this.startToken}${this.chatGptLabel ?? 'Assistant'}:\n`; - - const instructionsPayload = { - role: 'system', - content: promptPrefix, - }; - - const messagePayload = { - role: 'system', - content: promptSuffix, - }; - - if (this.isGpt3) { - instructionsPayload.role = 'user'; - messagePayload.role = 'user'; - instructionsPayload.content += `\n${promptSuffix}`; - } - - // testing if this works with browser endpoint - if (!this.isGpt3 && this.options.reverseProxyUrl) { - instructionsPayload.role = 'user'; - } - - let currentTokenCount = - this.getTokenCountForMessage(instructionsPayload) + - this.getTokenCountForMessage(messagePayload); - - let promptBody = ''; - const maxTokenCount = this.maxPromptTokens; - // Iterate backwards through the messages, adding them to the prompt until we reach the max token count. - // Do this within a recursive async function so that it doesn't block the event loop for too long. - const buildPromptBody = async () => { - if (currentTokenCount < maxTokenCount && orderedMessages.length > 0) { - const message = orderedMessages.pop(); - const isCreatedByUser = message.isCreatedByUser || message.role?.toLowerCase() === 'user'; - const roleLabel = isCreatedByUser ? this.userLabel : this.chatGptLabel; - let messageString = `${this.startToken}${roleLabel}:\n${ - message.text ?? message.content ?? '' - }${this.endToken}\n`; - let newPromptBody = `${messageString}${promptBody}`; - - const tokenCountForMessage = this.getTokenCount(messageString); - const newTokenCount = currentTokenCount + tokenCountForMessage; - if (newTokenCount > maxTokenCount) { - if (promptBody) { - // This message would put us over the token limit, so don't add it. - return false; - } - // This is the first message, so we can't add it. Just throw an error. - throw new Error( - `Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`, - ); - } - promptBody = newPromptBody; - currentTokenCount = newTokenCount; - // wait for next tick to avoid blocking the event loop - await new Promise((resolve) => setTimeout(resolve, 0)); - return buildPromptBody(); - } - return true; - }; - - await buildPromptBody(); - const prompt = promptBody; - messagePayload.content = prompt; - // Add 2 tokens for metadata after all messages have been counted. - currentTokenCount += 2; - - if (this.isGpt3 && messagePayload.content.length > 0) { - const context = 'Chat History:\n'; - messagePayload.content = `${context}${prompt}`; - currentTokenCount += this.getTokenCount(context); - } - - // Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response. - this.modelOptions.max_tokens = Math.min( - this.maxContextTokens - currentTokenCount, - this.maxResponseTokens, - ); - - if (this.isGpt3) { - messagePayload.content += promptSuffix; - return [instructionsPayload, messagePayload]; - } - - const result = [messagePayload, instructionsPayload]; - - if (this.functionsAgent && !this.isGpt3) { - result[1].content = `${result[1].content}\n${this.startToken}${this.chatGptLabel}:\nSure thing! Here is the output you requested:\n`; - } - - return result.filter((message) => message.content.length > 0); - } -} - -module.exports = PluginsClient; diff --git a/api/app/clients/index.js b/api/app/clients/index.js index a5e8eee504..d8b2bae27b 100644 --- a/api/app/clients/index.js +++ b/api/app/clients/index.js @@ -1,15 +1,11 @@ -const ChatGPTClient = require('./ChatGPTClient'); const OpenAIClient = require('./OpenAIClient'); -const PluginsClient = require('./PluginsClient'); const GoogleClient = require('./GoogleClient'); const TextStream = require('./TextStream'); const AnthropicClient = require('./AnthropicClient'); const toolUtils = require('./tools/util'); module.exports = { - ChatGPTClient, OpenAIClient, - PluginsClient, GoogleClient, TextStream, AnthropicClient, diff --git a/api/app/clients/prompts/createContextHandlers.js b/api/app/clients/prompts/createContextHandlers.js index 4dcfaf68e4..b3ea9164e7 100644 --- a/api/app/clients/prompts/createContextHandlers.js +++ b/api/app/clients/prompts/createContextHandlers.js @@ -1,6 +1,7 @@ const axios = require('axios'); -const { isEnabled } = require('~/server/utils'); -const { logger } = require('~/config'); +const { isEnabled } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); +const { generateShortLivedToken } = require('~/server/services/AuthService'); const footer = `Use the context as your learned knowledge to better answer the user. @@ -18,7 +19,7 @@ function createContextHandlers(req, userMessageContent) { const queryPromises = []; const processedFiles = []; const processedIds = new Set(); - const jwtToken = req.headers.authorization.split(' ')[1]; + const jwtToken = generateShortLivedToken(req.user.id); const useFullContext = isEnabled(process.env.RAG_USE_FULL_CONTEXT); const query = async (file) => { @@ -96,35 +97,35 @@ function createContextHandlers(req, userMessageContent) { resolvedQueries.length === 0 ? '\n\tThe semantic search did not return any results.' : resolvedQueries - .map((queryResult, index) => { - const file = processedFiles[index]; - let contextItems = queryResult.data; + .map((queryResult, index) => { + const file = processedFiles[index]; + let contextItems = queryResult.data; - const generateContext = (currentContext) => - ` + const generateContext = (currentContext) => + ` ${file.filename} ${currentContext} `; - if (useFullContext) { - return generateContext(`\n${contextItems}`); - } + if (useFullContext) { + return generateContext(`\n${contextItems}`); + } - contextItems = queryResult.data - .map((item) => { - const pageContent = item[0].page_content; - return ` + contextItems = queryResult.data + .map((item) => { + const pageContent = item[0].page_content; + return ` `; - }) - .join(''); + }) + .join(''); - return generateContext(contextItems); - }) - .join(''); + return generateContext(contextItems); + }) + .join(''); if (useFullContext) { const prompt = `${header} diff --git a/api/app/clients/specs/OpenAIClient.test.js b/api/app/clients/specs/OpenAIClient.test.js index cc4aa84d5d..efca66a867 100644 --- a/api/app/clients/specs/OpenAIClient.test.js +++ b/api/app/clients/specs/OpenAIClient.test.js @@ -531,44 +531,6 @@ describe('OpenAIClient', () => { }); }); - describe('sendMessage/getCompletion/chatCompletion', () => { - afterEach(() => { - delete process.env.AZURE_OPENAI_DEFAULT_MODEL; - delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME; - }); - - it('should call getCompletion and fetchEventSource when using a text/instruct model', async () => { - const model = 'text-davinci-003'; - const onProgress = jest.fn().mockImplementation(() => ({})); - - const testClient = new OpenAIClient('test-api-key', { - ...defaultOptions, - modelOptions: { model }, - }); - - const getCompletion = jest.spyOn(testClient, 'getCompletion'); - await testClient.sendMessage('Hi mom!', { onProgress }); - - expect(getCompletion).toHaveBeenCalled(); - expect(getCompletion.mock.calls.length).toBe(1); - - expect(getCompletion.mock.calls[0][0]).toBe('||>User:\nHi mom!\n||>Assistant:\n'); - - expect(fetchEventSource).toHaveBeenCalled(); - expect(fetchEventSource.mock.calls.length).toBe(1); - - // Check if the first argument (url) is correct - const firstCallArgs = fetchEventSource.mock.calls[0]; - - const expectedURL = 'https://api.openai.com/v1/completions'; - expect(firstCallArgs[0]).toBe(expectedURL); - - const requestBody = JSON.parse(firstCallArgs[1].body); - expect(requestBody).toHaveProperty('model'); - expect(requestBody.model).toBe(model); - }); - }); - describe('checkVisionRequest functionality', () => { let client; const attachments = [{ type: 'image/png' }]; diff --git a/api/app/clients/specs/PluginsClient.test.js b/api/app/clients/specs/PluginsClient.test.js deleted file mode 100644 index 4928acefd1..0000000000 --- a/api/app/clients/specs/PluginsClient.test.js +++ /dev/null @@ -1,314 +0,0 @@ -const crypto = require('crypto'); -const { Constants } = require('librechat-data-provider'); -const { HumanMessage, AIMessage } = require('@langchain/core/messages'); -const PluginsClient = require('../PluginsClient'); - -jest.mock('~/db/connect'); -jest.mock('~/models/Conversation', () => { - return function () { - return { - save: jest.fn(), - deleteConvos: jest.fn(), - }; - }; -}); - -const defaultAzureOptions = { - azureOpenAIApiInstanceName: 'your-instance-name', - azureOpenAIApiDeploymentName: 'your-deployment-name', - azureOpenAIApiVersion: '2020-07-01-preview', -}; - -describe('PluginsClient', () => { - let TestAgent; - let options = { - tools: [], - modelOptions: { - model: 'gpt-3.5-turbo', - temperature: 0, - max_tokens: 2, - }, - agentOptions: { - model: 'gpt-3.5-turbo', - }, - }; - let parentMessageId; - let conversationId; - const fakeMessages = []; - const userMessage = 'Hello, ChatGPT!'; - const apiKey = 'fake-api-key'; - - beforeEach(() => { - TestAgent = new PluginsClient(apiKey, options); - TestAgent.loadHistory = jest - .fn() - .mockImplementation((conversationId, parentMessageId = null) => { - if (!conversationId) { - TestAgent.currentMessages = []; - return Promise.resolve([]); - } - - const orderedMessages = TestAgent.constructor.getMessagesForConversation({ - messages: fakeMessages, - parentMessageId, - }); - - const chatMessages = orderedMessages.map((msg) => - msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user' - ? new HumanMessage(msg.text) - : new AIMessage(msg.text), - ); - - TestAgent.currentMessages = orderedMessages; - return Promise.resolve(chatMessages); - }); - TestAgent.sendMessage = jest.fn().mockImplementation(async (message, opts = {}) => { - if (opts && typeof opts === 'object') { - TestAgent.setOptions(opts); - } - const conversationId = opts.conversationId || crypto.randomUUID(); - const parentMessageId = opts.parentMessageId || Constants.NO_PARENT; - const userMessageId = opts.overrideParentMessageId || crypto.randomUUID(); - this.pastMessages = await TestAgent.loadHistory( - conversationId, - TestAgent.options?.parentMessageId, - ); - - const userMessage = { - text: message, - sender: 'ChatGPT', - isCreatedByUser: true, - messageId: userMessageId, - parentMessageId, - conversationId, - }; - - const response = { - sender: 'ChatGPT', - text: 'Hello, User!', - isCreatedByUser: false, - messageId: crypto.randomUUID(), - parentMessageId: userMessage.messageId, - conversationId, - }; - - fakeMessages.push(userMessage); - fakeMessages.push(response); - return response; - }); - }); - - test('initializes PluginsClient without crashing', () => { - expect(TestAgent).toBeInstanceOf(PluginsClient); - }); - - test('check setOptions function', () => { - expect(TestAgent.agentIsGpt3).toBe(true); - }); - - describe('sendMessage', () => { - test('sendMessage should return a response message', async () => { - const expectedResult = expect.objectContaining({ - sender: 'ChatGPT', - text: expect.any(String), - isCreatedByUser: false, - messageId: expect.any(String), - parentMessageId: expect.any(String), - conversationId: expect.any(String), - }); - - const response = await TestAgent.sendMessage(userMessage); - parentMessageId = response.messageId; - conversationId = response.conversationId; - expect(response).toEqual(expectedResult); - }); - - test('sendMessage should work with provided conversationId and parentMessageId', async () => { - const userMessage = 'Second message in the conversation'; - const opts = { - conversationId, - parentMessageId, - }; - - const expectedResult = expect.objectContaining({ - sender: 'ChatGPT', - text: expect.any(String), - isCreatedByUser: false, - messageId: expect.any(String), - parentMessageId: expect.any(String), - conversationId: opts.conversationId, - }); - - const response = await TestAgent.sendMessage(userMessage, opts); - parentMessageId = response.messageId; - expect(response.conversationId).toEqual(conversationId); - expect(response).toEqual(expectedResult); - }); - - test('should return chat history', async () => { - const chatMessages = await TestAgent.loadHistory(conversationId, parentMessageId); - expect(TestAgent.currentMessages).toHaveLength(4); - expect(chatMessages[0].text).toEqual(userMessage); - }); - }); - - describe('getFunctionModelName', () => { - let client; - - beforeEach(() => { - client = new PluginsClient('dummy_api_key'); - }); - - test('should return the input when it includes a dash followed by four digits', () => { - expect(client.getFunctionModelName('-1234')).toBe('-1234'); - expect(client.getFunctionModelName('gpt-4-5678-preview')).toBe('gpt-4-5678-preview'); - }); - - test('should return the input for all function-capable models (`0613` models and above)', () => { - expect(client.getFunctionModelName('gpt-4-0613')).toBe('gpt-4-0613'); - expect(client.getFunctionModelName('gpt-4-32k-0613')).toBe('gpt-4-32k-0613'); - expect(client.getFunctionModelName('gpt-3.5-turbo-0613')).toBe('gpt-3.5-turbo-0613'); - expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0613')).toBe('gpt-3.5-turbo-16k-0613'); - expect(client.getFunctionModelName('gpt-3.5-turbo-1106')).toBe('gpt-3.5-turbo-1106'); - expect(client.getFunctionModelName('gpt-4-1106-preview')).toBe('gpt-4-1106-preview'); - expect(client.getFunctionModelName('gpt-4-1106')).toBe('gpt-4-1106'); - }); - - test('should return the corresponding model if input is non-function capable (`0314` models)', () => { - expect(client.getFunctionModelName('gpt-4-0314')).toBe('gpt-4'); - expect(client.getFunctionModelName('gpt-4-32k-0314')).toBe('gpt-4'); - expect(client.getFunctionModelName('gpt-3.5-turbo-0314')).toBe('gpt-3.5-turbo'); - expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0314')).toBe('gpt-3.5-turbo'); - }); - - test('should return "gpt-3.5-turbo" when the input includes "gpt-3.5-turbo"', () => { - expect(client.getFunctionModelName('test gpt-3.5-turbo model')).toBe('gpt-3.5-turbo'); - }); - - test('should return "gpt-4" when the input includes "gpt-4"', () => { - expect(client.getFunctionModelName('testing gpt-4')).toBe('gpt-4'); - }); - - test('should return "gpt-3.5-turbo" for input that does not meet any specific condition', () => { - expect(client.getFunctionModelName('random string')).toBe('gpt-3.5-turbo'); - expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo'); - }); - }); - - describe('Azure OpenAI tests specific to Plugins', () => { - // TODO: add more tests for Azure OpenAI integration with Plugins - // let client; - // beforeEach(() => { - // client = new PluginsClient('dummy_api_key'); - // }); - - test('should not call getFunctionModelName when azure options are set', () => { - const spy = jest.spyOn(PluginsClient.prototype, 'getFunctionModelName'); - const model = 'gpt-4-turbo'; - - // note, without the azure change in PR #1766, `getFunctionModelName` is called twice - const testClient = new PluginsClient('dummy_api_key', { - agentOptions: { - model, - agent: 'functions', - }, - azure: defaultAzureOptions, - }); - - expect(spy).not.toHaveBeenCalled(); - expect(testClient.agentOptions.model).toBe(model); - - spy.mockRestore(); - }); - }); - - describe('sendMessage with filtered tools', () => { - let TestAgent; - const apiKey = 'fake-api-key'; - const mockTools = [{ name: 'tool1' }, { name: 'tool2' }, { name: 'tool3' }, { name: 'tool4' }]; - - beforeEach(() => { - TestAgent = new PluginsClient(apiKey, { - tools: mockTools, - modelOptions: { - model: 'gpt-3.5-turbo', - temperature: 0, - max_tokens: 2, - }, - agentOptions: { - model: 'gpt-3.5-turbo', - }, - }); - - TestAgent.options.req = { - app: { - locals: {}, - }, - }; - - TestAgent.sendMessage = jest.fn().mockImplementation(async () => { - const { filteredTools = [], includedTools = [] } = TestAgent.options.req.app.locals; - - if (includedTools.length > 0) { - const tools = TestAgent.options.tools.filter((plugin) => - includedTools.includes(plugin.name), - ); - TestAgent.options.tools = tools; - } else { - const tools = TestAgent.options.tools.filter( - (plugin) => !filteredTools.includes(plugin.name), - ); - TestAgent.options.tools = tools; - } - - return { - text: 'Mocked response', - tools: TestAgent.options.tools, - }; - }); - }); - - test('should filter out tools when filteredTools is provided', async () => { - TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3']; - const response = await TestAgent.sendMessage('Test message'); - expect(response.tools).toHaveLength(2); - expect(response.tools).toEqual( - expect.arrayContaining([ - expect.objectContaining({ name: 'tool2' }), - expect.objectContaining({ name: 'tool4' }), - ]), - ); - }); - - test('should only include specified tools when includedTools is provided', async () => { - TestAgent.options.req.app.locals.includedTools = ['tool2', 'tool4']; - const response = await TestAgent.sendMessage('Test message'); - expect(response.tools).toHaveLength(2); - expect(response.tools).toEqual( - expect.arrayContaining([ - expect.objectContaining({ name: 'tool2' }), - expect.objectContaining({ name: 'tool4' }), - ]), - ); - }); - - test('should prioritize includedTools over filteredTools', async () => { - TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3']; - TestAgent.options.req.app.locals.includedTools = ['tool1', 'tool2']; - const response = await TestAgent.sendMessage('Test message'); - expect(response.tools).toHaveLength(2); - expect(response.tools).toEqual( - expect.arrayContaining([ - expect.objectContaining({ name: 'tool1' }), - expect.objectContaining({ name: 'tool2' }), - ]), - ); - }); - - test('should not modify tools when no filters are provided', async () => { - const response = await TestAgent.sendMessage('Test message'); - expect(response.tools).toHaveLength(4); - expect(response.tools).toEqual(expect.arrayContaining(mockTools)); - }); - }); -}); diff --git a/api/app/clients/tools/structured/OpenAIImageTools.js b/api/app/clients/tools/structured/OpenAIImageTools.js index 08e15a7fad..411db1edf9 100644 --- a/api/app/clients/tools/structured/OpenAIImageTools.js +++ b/api/app/clients/tools/structured/OpenAIImageTools.js @@ -107,6 +107,12 @@ const getImageEditPromptDescription = () => { return process.env.IMAGE_EDIT_OAI_PROMPT_DESCRIPTION || DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION; }; +function createAbortHandler() { + return function () { + logger.debug('[ImageGenOAI] Image generation aborted'); + }; +} + /** * Creates OpenAI Image tools (generation and editing) * @param {Object} fields - Configuration fields @@ -201,10 +207,18 @@ function createOpenAIImageTools(fields = {}) { } let resp; + /** @type {AbortSignal} */ + let derivedSignal = null; + /** @type {() => void} */ + let abortHandler = null; + try { - const derivedSignal = runnableConfig?.signal - ? AbortSignal.any([runnableConfig.signal]) - : undefined; + if (runnableConfig?.signal) { + derivedSignal = AbortSignal.any([runnableConfig.signal]); + abortHandler = createAbortHandler(); + derivedSignal.addEventListener('abort', abortHandler, { once: true }); + } + resp = await openai.images.generate( { model: 'gpt-image-1', @@ -228,6 +242,10 @@ function createOpenAIImageTools(fields = {}) { logAxiosError({ error, message }); return returnValue(`Something went wrong when trying to generate the image. The OpenAI API may be unavailable: Error Message: ${error.message}`); + } finally { + if (abortHandler && derivedSignal) { + derivedSignal.removeEventListener('abort', abortHandler); + } } if (!resp) { @@ -409,10 +427,17 @@ Error Message: ${error.message}`); headers['Authorization'] = `Bearer ${apiKey}`; } + /** @type {AbortSignal} */ + let derivedSignal = null; + /** @type {() => void} */ + let abortHandler = null; + try { - const derivedSignal = runnableConfig?.signal - ? AbortSignal.any([runnableConfig.signal]) - : undefined; + if (runnableConfig?.signal) { + derivedSignal = AbortSignal.any([runnableConfig.signal]); + abortHandler = createAbortHandler(); + derivedSignal.addEventListener('abort', abortHandler, { once: true }); + } /** @type {import('axios').AxiosRequestConfig} */ const axiosConfig = { @@ -467,6 +492,10 @@ Error Message: ${error.message}`); logAxiosError({ error, message }); return returnValue(`Something went wrong when trying to edit the image. The OpenAI API may be unavailable: Error Message: ${error.message || 'Unknown error'}`); + } finally { + if (abortHandler && derivedSignal) { + derivedSignal.removeEventListener('abort', abortHandler); + } } }, { diff --git a/api/app/clients/tools/util/fileSearch.js b/api/app/clients/tools/util/fileSearch.js index 19d3a79edb..050a0fd896 100644 --- a/api/app/clients/tools/util/fileSearch.js +++ b/api/app/clients/tools/util/fileSearch.js @@ -1,9 +1,10 @@ const { z } = require('zod'); const axios = require('axios'); const { tool } = require('@langchain/core/tools'); +const { logger } = require('@librechat/data-schemas'); const { Tools, EToolResources } = require('librechat-data-provider'); +const { generateShortLivedToken } = require('~/server/services/AuthService'); const { getFiles } = require('~/models/File'); -const { logger } = require('~/config'); /** * @@ -59,7 +60,7 @@ const createFileSearchTool = async ({ req, files, entity_id }) => { if (files.length === 0) { return 'No files to search. Instruct the user to add files for the search.'; } - const jwtToken = req.headers.authorization.split(' ')[1]; + const jwtToken = generateShortLivedToken(req.user.id); if (!jwtToken) { return 'There was an error authenticating the file search request.'; } diff --git a/api/cache/banViolation.js b/api/cache/banViolation.js index 17b23f1c12..3a2d9791b4 100644 --- a/api/cache/banViolation.js +++ b/api/cache/banViolation.js @@ -1,7 +1,8 @@ const { logger } = require('@librechat/data-schemas'); +const { isEnabled, math } = require('@librechat/api'); const { ViolationTypes } = require('librechat-data-provider'); -const { isEnabled, math, removePorts } = require('~/server/utils'); const { deleteAllUserSessions } = require('~/models'); +const { removePorts } = require('~/server/utils'); const getLogStores = require('./getLogStores'); const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {}; diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js index 06cadf9f64..0eef7d3fb4 100644 --- a/api/cache/getLogStores.js +++ b/api/cache/getLogStores.js @@ -1,7 +1,7 @@ const { Keyv } = require('keyv'); +const { isEnabled, math } = require('@librechat/api'); const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider'); const { logFile, violationFile } = require('./keyvFiles'); -const { isEnabled, math } = require('~/server/utils'); const keyvRedis = require('./keyvRedis'); const keyvMongo = require('./keyvMongo'); diff --git a/api/cache/logViolation.js b/api/cache/logViolation.js index a3162bbfac..5d785480b9 100644 --- a/api/cache/logViolation.js +++ b/api/cache/logViolation.js @@ -9,7 +9,7 @@ const banViolation = require('./banViolation'); * @param {Object} res - Express response object. * @param {string} type - The type of violation. * @param {Object} errorMessage - The error message to log. - * @param {number} [score=1] - The severity of the violation. Defaults to 1 + * @param {number | string} [score=1] - The severity of the violation. Defaults to 1 */ const logViolation = async (req, res, type, errorMessage, score = 1) => { const userId = req.user?.id ?? req.user?._id; diff --git a/api/models/Agent.js b/api/models/Agent.js index d33ca8a8bf..dcb646f039 100644 --- a/api/models/Agent.js +++ b/api/models/Agent.js @@ -70,6 +70,9 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _ if (ephemeralAgent?.execute_code === true) { tools.push(Tools.execute_code); } + if (ephemeralAgent?.file_search === true) { + tools.push(Tools.file_search); + } if (ephemeralAgent?.web_search === true) { tools.push(Tools.web_search); } @@ -87,7 +90,7 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _ } const instructions = req.body.promptPrefix; - return { + const result = { id: agent_id, instructions, provider: endpoint, @@ -95,6 +98,11 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _ model, tools, }; + + if (ephemeralAgent?.artifacts != null && ephemeralAgent.artifacts) { + result.artifacts = ephemeralAgent.artifacts; + } + return result; }; /** diff --git a/api/models/Agent.spec.js b/api/models/Agent.spec.js index 0b0646f524..8953ae0482 100644 --- a/api/models/Agent.spec.js +++ b/api/models/Agent.spec.js @@ -43,7 +43,7 @@ describe('models/Agent', () => { const mongoUri = mongoServer.getUri(); Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); await mongoose.connect(mongoUri); - }); + }, 20000); afterAll(async () => { await mongoose.disconnect(); @@ -413,7 +413,7 @@ describe('models/Agent', () => { const mongoUri = mongoServer.getUri(); Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); await mongoose.connect(mongoUri); - }); + }, 20000); afterAll(async () => { await mongoose.disconnect(); @@ -670,7 +670,7 @@ describe('models/Agent', () => { const mongoUri = mongoServer.getUri(); Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); await mongoose.connect(mongoUri); - }); + }, 20000); afterAll(async () => { await mongoose.disconnect(); @@ -1332,7 +1332,7 @@ describe('models/Agent', () => { const mongoUri = mongoServer.getUri(); Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); await mongoose.connect(mongoUri); - }); + }, 20000); afterAll(async () => { await mongoose.disconnect(); @@ -1514,7 +1514,7 @@ describe('models/Agent', () => { const mongoUri = mongoServer.getUri(); Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); await mongoose.connect(mongoUri); - }); + }, 20000); afterAll(async () => { await mongoose.disconnect(); @@ -1798,7 +1798,7 @@ describe('models/Agent', () => { const mongoUri = mongoServer.getUri(); Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); await mongoose.connect(mongoUri); - }); + }, 20000); afterAll(async () => { await mongoose.disconnect(); @@ -2350,7 +2350,7 @@ describe('models/Agent', () => { const mongoUri = mongoServer.getUri(); Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); await mongoose.connect(mongoUri); - }); + }, 20000); afterAll(async () => { await mongoose.disconnect(); diff --git a/api/models/Conversation.js b/api/models/Conversation.js index 38e2cbb448..b237c41e93 100644 --- a/api/models/Conversation.js +++ b/api/models/Conversation.js @@ -1,4 +1,6 @@ const { logger } = require('@librechat/data-schemas'); +const { createTempChatExpirationDate } = require('@librechat/api'); +const getCustomConfig = require('~/server/services/Config/getCustomConfig'); const { getMessages, deleteMessages } = require('./Message'); const { Conversation } = require('~/db/models'); @@ -98,10 +100,15 @@ module.exports = { update.conversationId = newConversationId; } - if (req.body.isTemporary) { - const expiredAt = new Date(); - expiredAt.setDate(expiredAt.getDate() + 30); - update.expiredAt = expiredAt; + if (req?.body?.isTemporary) { + try { + const customConfig = await getCustomConfig(); + update.expiredAt = createTempChatExpirationDate(customConfig); + } catch (err) { + logger.error('Error creating temporary chat expiration date:', err); + logger.info(`---\`saveConvo\` context: ${metadata?.context}`); + update.expiredAt = null; + } } else { update.expiredAt = null; } diff --git a/api/models/File.js b/api/models/File.js index ff509539e3..1ee943131d 100644 --- a/api/models/File.js +++ b/api/models/File.js @@ -1,5 +1,5 @@ const { logger } = require('@librechat/data-schemas'); -const { EToolResources } = require('librechat-data-provider'); +const { EToolResources, FileContext } = require('librechat-data-provider'); const { File } = require('~/db/models'); /** @@ -32,19 +32,19 @@ const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => { * @returns {Promise>} Files that match the criteria */ const getToolFilesByIds = async (fileIds, toolResourceSet) => { - if (!fileIds || !fileIds.length) { + if (!fileIds || !fileIds.length || !toolResourceSet?.size) { return []; } try { const filter = { file_id: { $in: fileIds }, + $or: [], }; - if (toolResourceSet.size) { - filter.$or = []; + if (toolResourceSet.has(EToolResources.ocr)) { + filter.$or.push({ text: { $exists: true, $ne: null }, context: FileContext.agents }); } - if (toolResourceSet.has(EToolResources.file_search)) { filter.$or.push({ embedded: true }); } diff --git a/api/models/Message.js b/api/models/Message.js index abd538084e..3d5eee6ec9 100644 --- a/api/models/Message.js +++ b/api/models/Message.js @@ -1,5 +1,7 @@ const { z } = require('zod'); const { logger } = require('@librechat/data-schemas'); +const { createTempChatExpirationDate } = require('@librechat/api'); +const getCustomConfig = require('~/server/services/Config/getCustomConfig'); const { Message } = require('~/db/models'); const idSchema = z.string().uuid(); @@ -54,9 +56,14 @@ async function saveMessage(req, params, metadata) { }; if (req?.body?.isTemporary) { - const expiredAt = new Date(); - expiredAt.setDate(expiredAt.getDate() + 30); - update.expiredAt = expiredAt; + try { + const customConfig = await getCustomConfig(); + update.expiredAt = createTempChatExpirationDate(customConfig); + } catch (err) { + logger.error('Error creating temporary chat expiration date:', err); + logger.info(`---\`saveMessage\` context: ${metadata?.context}`); + update.expiredAt = null; + } } else { update.expiredAt = null; } diff --git a/api/package.json b/api/package.json index 6633a99c3f..25cde92056 100644 --- a/api/package.json +++ b/api/package.json @@ -1,6 +1,6 @@ { "name": "@librechat/backend", - "version": "v0.7.8", + "version": "v0.7.9-rc1", "description": "", "scripts": { "start": "echo 'please run this from the root directory'", @@ -48,14 +48,13 @@ "@langchain/google-genai": "^0.2.13", "@langchain/google-vertexai": "^0.2.13", "@langchain/textsplitters": "^0.1.0", - "@librechat/agents": "^2.4.41", + "@librechat/agents": "^2.4.56", "@librechat/api": "*", "@librechat/data-schemas": "*", "@node-saml/passport-saml": "^5.0.0", "@waylaidwanderer/fetch-event-source": "^3.0.1", "axios": "^1.8.2", "bcryptjs": "^2.4.3", - "cohere-ai": "^7.9.1", "compression": "^1.7.4", "connect-redis": "^7.1.0", "cookie": "^0.7.2", diff --git a/api/server/cleanup.js b/api/server/cleanup.js index de7450cea0..84164eb641 100644 --- a/api/server/cleanup.js +++ b/api/server/cleanup.js @@ -169,9 +169,6 @@ function disposeClient(client) { client.isGenerativeModel = null; } // Properties specific to OpenAIClient - if (client.ChatGPTClient) { - client.ChatGPTClient = null; - } if (client.completionsUrl) { client.completionsUrl = null; } diff --git a/api/server/controllers/AskController.js b/api/server/controllers/AskController.js deleted file mode 100644 index 40b209ef35..0000000000 --- a/api/server/controllers/AskController.js +++ /dev/null @@ -1,282 +0,0 @@ -const { getResponseSender, Constants } = require('librechat-data-provider'); -const { - handleAbortError, - createAbortController, - cleanupAbortController, -} = require('~/server/middleware'); -const { - disposeClient, - processReqData, - clientRegistry, - requestDataMap, -} = require('~/server/cleanup'); -const { sendMessage, createOnProgress } = require('~/server/utils'); -const { saveMessage } = require('~/models'); -const { logger } = require('~/config'); - -const AskController = async (req, res, next, initializeClient, addTitle) => { - let { - text, - endpointOption, - conversationId, - modelDisplayLabel, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - - let client = null; - let abortKey = null; - let cleanupHandlers = []; - let clientRef = null; - - logger.debug('[AskController]', { - text, - conversationId, - ...endpointOption, - modelsConfig: endpointOption?.modelsConfig ? 'exists' : '', - }); - - let userMessage = null; - let userMessagePromise = null; - let promptTokens = null; - let userMessageId = null; - let responseMessageId = null; - let getAbortData = null; - - const sender = getResponseSender({ - ...endpointOption, - model: endpointOption.modelOptions.model, - modelDisplayLabel, - }); - const initialConversationId = conversationId; - const newConvo = !initialConversationId; - const userId = req.user.id; - - let reqDataContext = { - userMessage, - userMessagePromise, - responseMessageId, - promptTokens, - conversationId, - userMessageId, - }; - - const updateReqData = (data = {}) => { - reqDataContext = processReqData(data, reqDataContext); - abortKey = reqDataContext.abortKey; - userMessage = reqDataContext.userMessage; - userMessagePromise = reqDataContext.userMessagePromise; - responseMessageId = reqDataContext.responseMessageId; - promptTokens = reqDataContext.promptTokens; - conversationId = reqDataContext.conversationId; - userMessageId = reqDataContext.userMessageId; - }; - - let { onProgress: progressCallback, getPartialText } = createOnProgress(); - - const performCleanup = () => { - logger.debug('[AskController] Performing cleanup'); - if (Array.isArray(cleanupHandlers)) { - for (const handler of cleanupHandlers) { - try { - if (typeof handler === 'function') { - handler(); - } - } catch (e) { - // Ignore - } - } - } - - if (abortKey) { - logger.debug('[AskController] Cleaning up abort controller'); - cleanupAbortController(abortKey); - abortKey = null; - } - - if (client) { - disposeClient(client); - client = null; - } - - reqDataContext = null; - userMessage = null; - userMessagePromise = null; - promptTokens = null; - getAbortData = null; - progressCallback = null; - endpointOption = null; - cleanupHandlers = null; - addTitle = null; - - if (requestDataMap.has(req)) { - requestDataMap.delete(req); - } - logger.debug('[AskController] Cleanup completed'); - }; - - try { - ({ client } = await initializeClient({ req, res, endpointOption })); - if (clientRegistry && client) { - clientRegistry.register(client, { userId }, client); - } - - if (client) { - requestDataMap.set(req, { client }); - } - - clientRef = new WeakRef(client); - - getAbortData = () => { - const currentClient = clientRef?.deref(); - const currentText = - currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText(); - - return { - sender, - conversationId, - messageId: reqDataContext.responseMessageId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: currentText, - userMessage: userMessage, - userMessagePromise: userMessagePromise, - promptTokens: reqDataContext.promptTokens, - }; - }; - - const { onStart, abortController } = createAbortController( - req, - res, - getAbortData, - updateReqData, - ); - - const closeHandler = () => { - logger.debug('[AskController] Request closed'); - if (!abortController || abortController.signal.aborted || abortController.requestCompleted) { - return; - } - abortController.abort(); - logger.debug('[AskController] Request aborted on close'); - }; - - res.on('close', closeHandler); - cleanupHandlers.push(() => { - try { - res.removeListener('close', closeHandler); - } catch (e) { - // Ignore - } - }); - - const messageOptions = { - user: userId, - parentMessageId, - conversationId: reqDataContext.conversationId, - overrideParentMessageId, - getReqData: updateReqData, - onStart, - abortController, - progressCallback, - progressOptions: { - res, - }, - }; - - /** @type {TMessage} */ - let response = await client.sendMessage(text, messageOptions); - response.endpoint = endpointOption.endpoint; - - const databasePromise = response.databasePromise; - delete response.databasePromise; - - const { conversation: convoData = {} } = await databasePromise; - const conversation = { ...convoData }; - conversation.title = - conversation && !conversation.title ? null : conversation?.title || 'New Chat'; - - const latestUserMessage = reqDataContext.userMessage; - - if (client?.options?.attachments && latestUserMessage) { - latestUserMessage.files = client.options.attachments; - if (endpointOption?.modelOptions?.model) { - conversation.model = endpointOption.modelOptions.model; - } - delete latestUserMessage.image_urls; - } - - if (!abortController.signal.aborted) { - const finalResponseMessage = { ...response }; - - sendMessage(res, { - final: true, - conversation, - title: conversation.title, - requestMessage: latestUserMessage, - responseMessage: finalResponseMessage, - }); - res.end(); - - if (client?.savedMessageIds && !client.savedMessageIds.has(response.messageId)) { - await saveMessage( - req, - { ...finalResponseMessage, user: userId }, - { context: 'api/server/controllers/AskController.js - response end' }, - ); - } - } - - if (!client?.skipSaveUserMessage && latestUserMessage) { - await saveMessage(req, latestUserMessage, { - context: "api/server/controllers/AskController.js - don't skip saving user message", - }); - } - - if (typeof addTitle === 'function' && parentMessageId === Constants.NO_PARENT && newConvo) { - addTitle(req, { - text, - response: { ...response }, - client, - }) - .then(() => { - logger.debug('[AskController] Title generation started'); - }) - .catch((err) => { - logger.error('[AskController] Error in title generation', err); - }) - .finally(() => { - logger.debug('[AskController] Title generation completed'); - performCleanup(); - }); - } else { - performCleanup(); - } - } catch (error) { - logger.error('[AskController] Error handling request', error); - let partialText = ''; - try { - const currentClient = clientRef?.deref(); - partialText = - currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText(); - } catch (getTextError) { - logger.error('[AskController] Error calling getText() during error handling', getTextError); - } - - handleAbortError(res, req, error, { - sender, - partialText, - conversationId: reqDataContext.conversationId, - messageId: reqDataContext.responseMessageId, - parentMessageId: overrideParentMessageId ?? reqDataContext.userMessageId ?? parentMessageId, - userMessageId: reqDataContext.userMessageId, - }) - .catch((err) => { - logger.error('[AskController] Error in `handleAbortError` during catch block', err); - }) - .finally(() => { - performCleanup(); - }); - } -}; - -module.exports = AskController; diff --git a/api/server/controllers/AuthController.js b/api/server/controllers/AuthController.js index 0f8152de3e..3dbb1a2f31 100644 --- a/api/server/controllers/AuthController.js +++ b/api/server/controllers/AuthController.js @@ -1,17 +1,17 @@ const cookies = require('cookie'); const jwt = require('jsonwebtoken'); const openIdClient = require('openid-client'); +const { isEnabled } = require('@librechat/api'); const { logger } = require('@librechat/data-schemas'); const { - registerUser, - resetPassword, - setAuthTokens, requestPasswordReset, setOpenIDAuthTokens, + resetPassword, + setAuthTokens, + registerUser, } = require('~/server/services/AuthService'); const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models'); const { getOpenIdConfig } = require('~/strategies'); -const { isEnabled } = require('~/server/utils'); const registrationController = async (req, res) => { try { diff --git a/api/server/controllers/EditController.js b/api/server/controllers/EditController.js index d142d474df..d24e87ce3a 100644 --- a/api/server/controllers/EditController.js +++ b/api/server/controllers/EditController.js @@ -1,3 +1,5 @@ +const { sendEvent } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { getResponseSender } = require('librechat-data-provider'); const { handleAbortError, @@ -10,9 +12,8 @@ const { clientRegistry, requestDataMap, } = require('~/server/cleanup'); -const { sendMessage, createOnProgress } = require('~/server/utils'); +const { createOnProgress } = require('~/server/utils'); const { saveMessage } = require('~/models'); -const { logger } = require('~/config'); const EditController = async (req, res, next, initializeClient) => { let { @@ -84,7 +85,7 @@ const EditController = async (req, res, next, initializeClient) => { } if (abortKey) { - logger.debug('[AskController] Cleaning up abort controller'); + logger.debug('[EditController] Cleaning up abort controller'); cleanupAbortController(abortKey); abortKey = null; } @@ -198,7 +199,7 @@ const EditController = async (req, res, next, initializeClient) => { const finalUserMessage = reqDataContext.userMessage; const finalResponseMessage = { ...response }; - sendMessage(res, { + sendEvent(res, { final: true, conversation, title: conversation.title, diff --git a/api/server/controllers/ErrorController.js b/api/server/controllers/ErrorController.js index 234cb90fb3..6907e63d9c 100644 --- a/api/server/controllers/ErrorController.js +++ b/api/server/controllers/ErrorController.js @@ -24,17 +24,23 @@ const handleValidationError = (err, res) => { } }; -// eslint-disable-next-line no-unused-vars -module.exports = (err, req, res, next) => { +module.exports = (err, _req, res, _next) => { try { if (err.name === 'ValidationError') { - return (err = handleValidationError(err, res)); + return handleValidationError(err, res); } if (err.code && err.code == 11000) { - return (err = handleDuplicateKeyError(err, res)); + return handleDuplicateKeyError(err, res); } - } catch (err) { + // Special handling for errors like SyntaxError + if (err.statusCode && err.body) { + return res.status(err.statusCode).send(err.body); + } + logger.error('ErrorController => error', err); - res.status(500).send('An unknown error occurred.'); + return res.status(500).send('An unknown error occurred.'); + } catch (err) { + logger.error('ErrorController => processing error', err); + return res.status(500).send('Processing error in ErrorController.'); } }; diff --git a/api/server/controllers/ErrorController.spec.js b/api/server/controllers/ErrorController.spec.js new file mode 100644 index 0000000000..c46315a5e5 --- /dev/null +++ b/api/server/controllers/ErrorController.spec.js @@ -0,0 +1,241 @@ +const errorController = require('./ErrorController'); +const { logger } = require('~/config'); + +// Mock the logger +jest.mock('~/config', () => ({ + logger: { + error: jest.fn(), + }, +})); + +describe('ErrorController', () => { + let mockReq, mockRes, mockNext; + + beforeEach(() => { + mockReq = {}; + mockRes = { + status: jest.fn().mockReturnThis(), + send: jest.fn(), + }; + mockNext = jest.fn(); + logger.error.mockClear(); + }); + + describe('ValidationError handling', () => { + it('should handle ValidationError with single error', () => { + const validationError = { + name: 'ValidationError', + errors: { + email: { message: 'Email is required', path: 'email' }, + }, + }; + + errorController(validationError, mockReq, mockRes, mockNext); + + expect(mockRes.status).toHaveBeenCalledWith(400); + expect(mockRes.send).toHaveBeenCalledWith({ + messages: '["Email is required"]', + fields: '["email"]', + }); + expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors); + }); + + it('should handle ValidationError with multiple errors', () => { + const validationError = { + name: 'ValidationError', + errors: { + email: { message: 'Email is required', path: 'email' }, + password: { message: 'Password is required', path: 'password' }, + }, + }; + + errorController(validationError, mockReq, mockRes, mockNext); + + expect(mockRes.status).toHaveBeenCalledWith(400); + expect(mockRes.send).toHaveBeenCalledWith({ + messages: '"Email is required Password is required"', + fields: '["email","password"]', + }); + expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors); + }); + + it('should handle ValidationError with empty errors object', () => { + const validationError = { + name: 'ValidationError', + errors: {}, + }; + + errorController(validationError, mockReq, mockRes, mockNext); + + expect(mockRes.status).toHaveBeenCalledWith(400); + expect(mockRes.send).toHaveBeenCalledWith({ + messages: '[]', + fields: '[]', + }); + }); + }); + + describe('Duplicate key error handling', () => { + it('should handle duplicate key error (code 11000)', () => { + const duplicateKeyError = { + code: 11000, + keyValue: { email: 'test@example.com' }, + }; + + errorController(duplicateKeyError, mockReq, mockRes, mockNext); + + expect(mockRes.status).toHaveBeenCalledWith(409); + expect(mockRes.send).toHaveBeenCalledWith({ + messages: 'An document with that ["email"] already exists.', + fields: '["email"]', + }); + expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue); + }); + + it('should handle duplicate key error with multiple fields', () => { + const duplicateKeyError = { + code: 11000, + keyValue: { email: 'test@example.com', username: 'testuser' }, + }; + + errorController(duplicateKeyError, mockReq, mockRes, mockNext); + + expect(mockRes.status).toHaveBeenCalledWith(409); + expect(mockRes.send).toHaveBeenCalledWith({ + messages: 'An document with that ["email","username"] already exists.', + fields: '["email","username"]', + }); + expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue); + }); + + it('should handle error with code 11000 as string', () => { + const duplicateKeyError = { + code: '11000', + keyValue: { email: 'test@example.com' }, + }; + + errorController(duplicateKeyError, mockReq, mockRes, mockNext); + + expect(mockRes.status).toHaveBeenCalledWith(409); + expect(mockRes.send).toHaveBeenCalledWith({ + messages: 'An document with that ["email"] already exists.', + fields: '["email"]', + }); + }); + }); + + describe('SyntaxError handling', () => { + it('should handle errors with statusCode and body', () => { + const syntaxError = { + statusCode: 400, + body: 'Invalid JSON syntax', + }; + + errorController(syntaxError, mockReq, mockRes, mockNext); + + expect(mockRes.status).toHaveBeenCalledWith(400); + expect(mockRes.send).toHaveBeenCalledWith('Invalid JSON syntax'); + }); + + it('should handle errors with different statusCode and body', () => { + const customError = { + statusCode: 422, + body: { error: 'Unprocessable entity' }, + }; + + errorController(customError, mockReq, mockRes, mockNext); + + expect(mockRes.status).toHaveBeenCalledWith(422); + expect(mockRes.send).toHaveBeenCalledWith({ error: 'Unprocessable entity' }); + }); + + it('should handle error with statusCode but no body', () => { + const partialError = { + statusCode: 400, + }; + + errorController(partialError, mockReq, mockRes, mockNext); + + expect(mockRes.status).toHaveBeenCalledWith(500); + expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.'); + }); + + it('should handle error with body but no statusCode', () => { + const partialError = { + body: 'Some error message', + }; + + errorController(partialError, mockReq, mockRes, mockNext); + + expect(mockRes.status).toHaveBeenCalledWith(500); + expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.'); + }); + }); + + describe('Unknown error handling', () => { + it('should handle unknown errors', () => { + const unknownError = new Error('Some unknown error'); + + errorController(unknownError, mockReq, mockRes, mockNext); + + expect(mockRes.status).toHaveBeenCalledWith(500); + expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.'); + expect(logger.error).toHaveBeenCalledWith('ErrorController => error', unknownError); + }); + + it('should handle errors with code other than 11000', () => { + const mongoError = { + code: 11100, + message: 'Some MongoDB error', + }; + + errorController(mongoError, mockReq, mockRes, mockNext); + + expect(mockRes.status).toHaveBeenCalledWith(500); + expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.'); + expect(logger.error).toHaveBeenCalledWith('ErrorController => error', mongoError); + }); + + it('should handle null/undefined errors', () => { + errorController(null, mockReq, mockRes, mockNext); + + expect(mockRes.status).toHaveBeenCalledWith(500); + expect(mockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.'); + expect(logger.error).toHaveBeenCalledWith( + 'ErrorController => processing error', + expect.any(Error), + ); + }); + }); + + describe('Catch block handling', () => { + beforeEach(() => { + // Restore logger mock to normal behavior for these tests + logger.error.mockRestore(); + logger.error = jest.fn(); + }); + + it('should handle errors when logger.error throws', () => { + // Create fresh mocks for this test + const freshMockRes = { + status: jest.fn().mockReturnThis(), + send: jest.fn(), + }; + + // Mock logger to throw on the first call, succeed on the second + logger.error + .mockImplementationOnce(() => { + throw new Error('Logger error'); + }) + .mockImplementation(() => {}); + + const testError = new Error('Test error'); + + errorController(testError, mockReq, freshMockRes, mockNext); + + expect(freshMockRes.status).toHaveBeenCalledWith(500); + expect(freshMockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.'); + expect(logger.error).toHaveBeenCalledTimes(2); + }); + }); +}); diff --git a/api/server/controllers/agents/__tests__/v1.spec.js b/api/server/controllers/agents/__tests__/v1.spec.js new file mode 100644 index 0000000000..b097cd98ce --- /dev/null +++ b/api/server/controllers/agents/__tests__/v1.spec.js @@ -0,0 +1,195 @@ +const { duplicateAgent } = require('../v1'); +const { getAgent, createAgent } = require('~/models/Agent'); +const { getActions } = require('~/models/Action'); +const { nanoid } = require('nanoid'); + +jest.mock('~/models/Agent'); +jest.mock('~/models/Action'); +jest.mock('nanoid'); + +describe('duplicateAgent', () => { + let req, res; + + beforeEach(() => { + req = { + params: { id: 'agent_123' }, + user: { id: 'user_456' }, + }; + res = { + status: jest.fn().mockReturnThis(), + json: jest.fn(), + }; + jest.clearAllMocks(); + }); + + it('should duplicate an agent successfully', async () => { + const mockAgent = { + id: 'agent_123', + name: 'Test Agent', + description: 'Test Description', + instructions: 'Test Instructions', + provider: 'openai', + model: 'gpt-4', + tools: ['file_search'], + actions: [], + author: 'user_789', + versions: [{ name: 'Test Agent', version: 1 }], + __v: 0, + }; + + const mockNewAgent = { + id: 'agent_new_123', + name: 'Test Agent (1/2/23, 12:34)', + description: 'Test Description', + instructions: 'Test Instructions', + provider: 'openai', + model: 'gpt-4', + tools: ['file_search'], + actions: [], + author: 'user_456', + versions: [ + { + name: 'Test Agent (1/2/23, 12:34)', + description: 'Test Description', + instructions: 'Test Instructions', + provider: 'openai', + model: 'gpt-4', + tools: ['file_search'], + actions: [], + createdAt: new Date(), + updatedAt: new Date(), + }, + ], + }; + + getAgent.mockResolvedValue(mockAgent); + getActions.mockResolvedValue([]); + nanoid.mockReturnValue('new_123'); + createAgent.mockResolvedValue(mockNewAgent); + + await duplicateAgent(req, res); + + expect(getAgent).toHaveBeenCalledWith({ id: 'agent_123' }); + expect(getActions).toHaveBeenCalledWith({ agent_id: 'agent_123' }, true); + expect(createAgent).toHaveBeenCalledWith( + expect.objectContaining({ + id: 'agent_new_123', + author: 'user_456', + name: expect.stringContaining('Test Agent ('), + description: 'Test Description', + instructions: 'Test Instructions', + provider: 'openai', + model: 'gpt-4', + tools: ['file_search'], + actions: [], + }), + ); + + expect(createAgent).toHaveBeenCalledWith( + expect.not.objectContaining({ + versions: expect.anything(), + __v: expect.anything(), + }), + ); + + expect(res.status).toHaveBeenCalledWith(201); + expect(res.json).toHaveBeenCalledWith({ + agent: mockNewAgent, + actions: [], + }); + }); + + it('should ensure duplicated agent has clean versions array without nested fields', async () => { + const mockAgent = { + id: 'agent_123', + name: 'Test Agent', + description: 'Test Description', + versions: [ + { + name: 'Test Agent', + versions: [{ name: 'Nested' }], + __v: 1, + }, + ], + __v: 2, + }; + + const mockNewAgent = { + id: 'agent_new_123', + name: 'Test Agent (1/2/23, 12:34)', + description: 'Test Description', + versions: [ + { + name: 'Test Agent (1/2/23, 12:34)', + description: 'Test Description', + createdAt: new Date(), + updatedAt: new Date(), + }, + ], + }; + + getAgent.mockResolvedValue(mockAgent); + getActions.mockResolvedValue([]); + nanoid.mockReturnValue('new_123'); + createAgent.mockResolvedValue(mockNewAgent); + + await duplicateAgent(req, res); + + expect(mockNewAgent.versions).toHaveLength(1); + + const firstVersion = mockNewAgent.versions[0]; + expect(firstVersion).not.toHaveProperty('versions'); + expect(firstVersion).not.toHaveProperty('__v'); + + expect(mockNewAgent).not.toHaveProperty('__v'); + + expect(res.status).toHaveBeenCalledWith(201); + }); + + it('should return 404 if agent not found', async () => { + getAgent.mockResolvedValue(null); + + await duplicateAgent(req, res); + + expect(res.status).toHaveBeenCalledWith(404); + expect(res.json).toHaveBeenCalledWith({ + error: 'Agent not found', + status: 'error', + }); + }); + + it('should handle tool_resources.ocr correctly', async () => { + const mockAgent = { + id: 'agent_123', + name: 'Test Agent', + tool_resources: { + ocr: { enabled: true, config: 'test' }, + other: { should: 'not be copied' }, + }, + }; + + getAgent.mockResolvedValue(mockAgent); + getActions.mockResolvedValue([]); + nanoid.mockReturnValue('new_123'); + createAgent.mockResolvedValue({ id: 'agent_new_123' }); + + await duplicateAgent(req, res); + + expect(createAgent).toHaveBeenCalledWith( + expect.objectContaining({ + tool_resources: { + ocr: { enabled: true, config: 'test' }, + }, + }), + ); + }); + + it('should handle errors gracefully', async () => { + getAgent.mockRejectedValue(new Error('Database error')); + + await duplicateAgent(req, res); + + expect(res.status).toHaveBeenCalledWith(500); + expect(res.json).toHaveBeenCalledWith({ error: 'Database error' }); + }); +}); diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 6769348d95..1bdf809d91 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -4,11 +4,13 @@ const { sendEvent, createRun, Tokenizer, + checkAccess, memoryInstructions, createMemoryProcessor, } = require('@librechat/api'); const { Callback, + Providers, GraphEvents, formatMessage, formatAgentMessages, @@ -31,22 +33,29 @@ const { } = require('librechat-data-provider'); const { DynamicStructuredTool } = require('@langchain/core/tools'); const { getBufferString, HumanMessage } = require('@langchain/core/messages'); -const { - getCustomEndpointConfig, - createGetMCPAuthMap, - checkCapability, -} = require('~/server/services/Config'); +const { createGetMCPAuthMap, checkCapability } = require('~/server/services/Config'); const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts'); const { initializeAgent } = require('~/server/services/Endpoints/agents/agent'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { getFormattedMemories, deleteMemory, setMemory } = require('~/models'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); -const initOpenAI = require('~/server/services/Endpoints/openAI/initialize'); -const { checkAccess } = require('~/server/middleware/roles/access'); +const { getProviderConfig } = require('~/server/services/Endpoints'); const BaseClient = require('~/app/clients/BaseClient'); +const { getRoleByName } = require('~/models/Role'); const { loadAgent } = require('~/models/Agent'); const { getMCPManager } = require('~/config'); +const omitTitleOptions = new Set([ + 'stream', + 'thinking', + 'streaming', + 'clientOptions', + 'thinkingConfig', + 'thinkingBudget', + 'includeThoughts', + 'maxOutputTokens', +]); + /** * @param {ServerRequest} req * @param {Agent} agent @@ -393,7 +402,12 @@ class AgentClient extends BaseClient { if (user.personalization?.memories === false) { return; } - const hasAccess = await checkAccess(user, PermissionTypes.MEMORIES, [Permissions.USE]); + const hasAccess = await checkAccess({ + user, + permissionType: PermissionTypes.MEMORIES, + permissions: [Permissions.USE], + getRoleByName, + }); if (!hasAccess) { logger.debug( @@ -511,7 +525,10 @@ class AgentClient extends BaseClient { messagesToProcess = [...messages.slice(-messageWindowSize)]; } } - return await this.processMemory(messagesToProcess); + + const bufferString = getBufferString(messagesToProcess); + const bufferMessage = new HumanMessage(`# Current Chat:\n\n${bufferString}`); + return await this.processMemory([bufferMessage]); } catch (error) { logger.error('Memory Agent failed to process memory', error); } @@ -677,7 +694,7 @@ class AgentClient extends BaseClient { hide_sequential_outputs: this.options.agent.hide_sequential_outputs, user: this.options.req.user, }, - recursionLimit: agentsEConfig?.recursionLimit, + recursionLimit: agentsEConfig?.recursionLimit ?? 25, signal: abortController.signal, streamMode: 'values', version: 'v2', @@ -983,23 +1000,26 @@ class AgentClient extends BaseClient { throw new Error('Run not initialized'); } const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator(); - const endpoint = this.options.agent.endpoint; - const { req, res } = this.options; + const { req, res, agent } = this.options; + const endpoint = agent.endpoint; + /** @type {import('@librechat/agents').ClientOptions} */ let clientOptions = { maxTokens: 75, + model: agent.model_parameters.model, }; - let endpointConfig = req.app.locals[endpoint]; + + const { getOptions, overrideProvider, customEndpointConfig } = + await getProviderConfig(endpoint); + + /** @type {TEndpoint | undefined} */ + const endpointConfig = req.app.locals[endpoint] ?? customEndpointConfig; if (!endpointConfig) { - try { - endpointConfig = await getCustomEndpointConfig(endpoint); - } catch (err) { - logger.error( - '[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config', - err, - ); - } + logger.warn( + '[api/server/controllers/agents/client.js #titleConvo] Error getting endpoint config', + ); } + if ( endpointConfig && endpointConfig.titleModel && @@ -1007,30 +1027,50 @@ class AgentClient extends BaseClient { ) { clientOptions.model = endpointConfig.titleModel; } + + const options = await getOptions({ + req, + res, + optionsOnly: true, + overrideEndpoint: endpoint, + overrideModel: clientOptions.model, + endpointOption: { model_parameters: clientOptions }, + }); + + let provider = options.provider ?? overrideProvider ?? agent.provider; if ( endpoint === EModelEndpoint.azureOpenAI && - clientOptions.model && - this.options.agent.model_parameters.model !== clientOptions.model + options.llmConfig?.azureOpenAIApiInstanceName == null ) { - clientOptions = - ( - await initOpenAI({ - req, - res, - optionsOnly: true, - overrideModel: clientOptions.model, - overrideEndpoint: endpoint, - endpointOption: { - model_parameters: clientOptions, - }, - }) - )?.llmConfig ?? clientOptions; + provider = Providers.OPENAI; } - if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) { + + /** @type {import('@librechat/agents').ClientOptions} */ + clientOptions = { ...options.llmConfig }; + if (options.configOptions) { + clientOptions.configuration = options.configOptions; + } + + // Ensure maxTokens is set for non-o1 models + if (!/\b(o\d)\b/i.test(clientOptions.model) && !clientOptions.maxTokens) { + clientOptions.maxTokens = 75; + } else if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) { delete clientOptions.maxTokens; } + + clientOptions = Object.assign( + Object.fromEntries( + Object.entries(clientOptions).filter(([key]) => !omitTitleOptions.has(key)), + ), + ); + + if (provider === Providers.GOOGLE) { + clientOptions.json = true; + } + try { const titleResult = await this.run.generateTitle({ + provider, inputText: text, contentParts: this.contentParts, clientOptions, @@ -1048,8 +1088,10 @@ class AgentClient extends BaseClient { let input_tokens, output_tokens; if (item.usage) { - input_tokens = item.usage.input_tokens || item.usage.inputTokens; - output_tokens = item.usage.output_tokens || item.usage.outputTokens; + input_tokens = + item.usage.prompt_tokens || item.usage.input_tokens || item.usage.inputTokens; + output_tokens = + item.usage.completion_tokens || item.usage.output_tokens || item.usage.outputTokens; } else if (item.tokenUsage) { input_tokens = item.tokenUsage.promptTokens; output_tokens = item.tokenUsage.completionTokens; diff --git a/api/server/controllers/agents/errors.js b/api/server/controllers/agents/errors.js index fb4de45085..b3bb1cea65 100644 --- a/api/server/controllers/agents/errors.js +++ b/api/server/controllers/agents/errors.js @@ -1,10 +1,10 @@ // errorHandler.js -const { logger } = require('~/config'); -const getLogStores = require('~/cache/getLogStores'); +const { logger } = require('@librechat/data-schemas'); const { CacheKeys, ViolationTypes } = require('librechat-data-provider'); +const { sendResponse } = require('~/server/middleware/error'); const { recordUsage } = require('~/server/services/Threads'); const { getConvo } = require('~/models/Conversation'); -const { sendResponse } = require('~/server/utils'); +const getLogStores = require('~/cache/getLogStores'); /** * @typedef {Object} ErrorHandlerContext @@ -75,7 +75,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch } else if (/Files.*are invalid/.test(error.message)) { const errorMessage = `Files are invalid, or may not have uploaded yet.${ endpoint === 'azureAssistants' - ? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.' + ? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload." : '' }`; return sendResponse(req, res, messageData, errorMessage); diff --git a/api/server/controllers/agents/llm.js b/api/server/controllers/agents/llm.js deleted file mode 100644 index 438a38b6cb..0000000000 --- a/api/server/controllers/agents/llm.js +++ /dev/null @@ -1,106 +0,0 @@ -const { HttpsProxyAgent } = require('https-proxy-agent'); -const { resolveHeaders } = require('librechat-data-provider'); -const { createLLM } = require('~/app/clients/llm'); - -/** - * Initializes and returns a Language Learning Model (LLM) instance. - * - * @param {Object} options - Configuration options for the LLM. - * @param {string} options.model - The model identifier. - * @param {string} options.modelName - The specific name of the model. - * @param {number} options.temperature - The temperature setting for the model. - * @param {number} options.presence_penalty - The presence penalty for the model. - * @param {number} options.frequency_penalty - The frequency penalty for the model. - * @param {number} options.max_tokens - The maximum number of tokens for the model output. - * @param {boolean} options.streaming - Whether to use streaming for the model output. - * @param {Object} options.context - The context for the conversation. - * @param {number} options.tokenBuffer - The token buffer size. - * @param {number} options.initialMessageCount - The initial message count. - * @param {string} options.conversationId - The ID of the conversation. - * @param {string} options.user - The user identifier. - * @param {string} options.langchainProxy - The langchain proxy URL. - * @param {boolean} options.useOpenRouter - Whether to use OpenRouter. - * @param {Object} options.options - Additional options. - * @param {Object} options.options.headers - Custom headers for the request. - * @param {string} options.options.proxy - Proxy URL. - * @param {Object} options.options.req - The request object. - * @param {Object} options.options.res - The response object. - * @param {boolean} options.options.debug - Whether to enable debug mode. - * @param {string} options.apiKey - The API key for authentication. - * @param {Object} options.azure - Azure-specific configuration. - * @param {Object} options.abortController - The AbortController instance. - * @returns {Object} The initialized LLM instance. - */ -function initializeLLM(options) { - const { - model, - modelName, - temperature, - presence_penalty, - frequency_penalty, - max_tokens, - streaming, - user, - langchainProxy, - useOpenRouter, - options: { headers, proxy }, - apiKey, - azure, - } = options; - - const modelOptions = { - modelName: modelName || model, - temperature, - presence_penalty, - frequency_penalty, - user, - }; - - if (max_tokens) { - modelOptions.max_tokens = max_tokens; - } - - const configOptions = {}; - - if (langchainProxy) { - configOptions.basePath = langchainProxy; - } - - if (useOpenRouter) { - configOptions.basePath = 'https://openrouter.ai/api/v1'; - configOptions.baseOptions = { - headers: { - 'HTTP-Referer': 'https://librechat.ai', - 'X-Title': 'LibreChat', - }, - }; - } - - if (headers && typeof headers === 'object' && !Array.isArray(headers)) { - configOptions.baseOptions = { - headers: resolveHeaders({ - ...headers, - ...configOptions?.baseOptions?.headers, - }), - }; - } - - if (proxy) { - configOptions.httpAgent = new HttpsProxyAgent(proxy); - configOptions.httpsAgent = new HttpsProxyAgent(proxy); - } - - const llm = createLLM({ - modelOptions, - configOptions, - openAIApiKey: apiKey, - azure, - streaming, - }); - - return llm; -} - -module.exports = { - initializeLLM, -}; diff --git a/api/server/controllers/agents/request.js b/api/server/controllers/agents/request.js index 24b7822c1f..2c8e424b5d 100644 --- a/api/server/controllers/agents/request.js +++ b/api/server/controllers/agents/request.js @@ -1,3 +1,5 @@ +const { sendEvent } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { Constants } = require('librechat-data-provider'); const { handleAbortError, @@ -5,17 +7,18 @@ const { cleanupAbortController, } = require('~/server/middleware'); const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup'); -const { sendMessage } = require('~/server/utils'); const { saveMessage } = require('~/models'); -const { logger } = require('~/config'); const AgentController = async (req, res, next, initializeClient, addTitle) => { let { text, endpointOption, conversationId, + isContinued = false, + editedContent = null, parentMessageId = null, overrideParentMessageId = null, + responseMessageId: editedResponseMessageId = null, } = req.body; let sender; @@ -67,7 +70,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { handler(); } } catch (e) { - // Ignore cleanup errors + logger.error('[AgentController] Error in cleanup handler', e); } } } @@ -155,7 +158,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { try { res.removeListener('close', closeHandler); } catch (e) { - // Ignore + logger.error('[AgentController] Error removing close listener', e); } }); @@ -163,10 +166,14 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { user: userId, onStart, getReqData, + isContinued, + editedContent, conversationId, parentMessageId, abortController, overrideParentMessageId, + isEdited: !!editedContent, + responseMessageId: editedResponseMessageId, progressOptions: { res, }, @@ -206,7 +213,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { // Create a new response object with minimal copies const finalResponse = { ...response }; - sendMessage(res, { + sendEvent(res, { final: true, conversation, title: conversation.title, diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js index 18bd7190f0..4aa50521cf 100644 --- a/api/server/controllers/agents/v1.js +++ b/api/server/controllers/agents/v1.js @@ -1,6 +1,8 @@ +const { z } = require('zod'); const fs = require('fs').promises; const { nanoid } = require('nanoid'); const { logger } = require('@librechat/data-schemas'); +const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api'); const { Tools, Constants, @@ -8,6 +10,7 @@ const { SystemRoles, EToolResources, actionDelimiter, + removeNullishValues, } = require('librechat-data-provider'); const { getAgent, @@ -30,6 +33,7 @@ const { deleteFileByFilter } = require('~/models/File'); const systemTools = { [Tools.execute_code]: true, [Tools.file_search]: true, + [Tools.web_search]: true, }; /** @@ -42,9 +46,13 @@ const systemTools = { */ const createAgentHandler = async (req, res) => { try { - const { tools = [], provider, name, description, instructions, model, ...agentData } = req.body; + const validatedData = agentCreateSchema.parse(req.body); + const { tools = [], ...agentData } = removeNullishValues(validatedData); + const { id: userId } = req.user; + agentData.id = `agent_${nanoid()}`; + agentData.author = userId; agentData.tools = []; const availableTools = await getCachedTools({ includeGlobal: true }); @@ -58,19 +66,13 @@ const createAgentHandler = async (req, res) => { } } - Object.assign(agentData, { - author: userId, - name, - description, - instructions, - provider, - model, - }); - - agentData.id = `agent_${nanoid()}`; const agent = await createAgent(agentData); res.status(201).json(agent); } catch (error) { + if (error instanceof z.ZodError) { + logger.error('[/Agents] Validation error', error.errors); + return res.status(400).json({ error: 'Invalid request data', details: error.errors }); + } logger.error('[/Agents] Error creating agent', error); res.status(500).json({ error: error.message }); } @@ -154,14 +156,16 @@ const getAgentHandler = async (req, res) => { const updateAgentHandler = async (req, res) => { try { const id = req.params.id; - const { projectIds, removeProjectIds, ...updateData } = req.body; + const validatedData = agentUpdateSchema.parse(req.body); + const { projectIds, removeProjectIds, ...updateData } = removeNullishValues(validatedData); const isAdmin = req.user.role === SystemRoles.ADMIN; const existingAgent = await getAgent({ id }); - const isAuthor = existingAgent.author.toString() === req.user.id; if (!existingAgent) { return res.status(404).json({ error: 'Agent not found' }); } + + const isAuthor = existingAgent.author.toString() === req.user.id; const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor; if (!hasEditPermission) { @@ -200,6 +204,11 @@ const updateAgentHandler = async (req, res) => { return res.json(updatedAgent); } catch (error) { + if (error instanceof z.ZodError) { + logger.error('[/Agents/:id] Validation error', error.errors); + return res.status(400).json({ error: 'Invalid request data', details: error.errors }); + } + logger.error('[/Agents/:id] Error updating Agent', error); if (error.statusCode === 409) { @@ -242,6 +251,8 @@ const duplicateAgentHandler = async (req, res) => { createdAt: _createdAt, updatedAt: _updatedAt, tool_resources: _tool_resources = {}, + versions: _versions, + __v: _v, ...cloneData } = agent; cloneData.name = `${agent.name} (${new Date().toLocaleString('en-US', { diff --git a/api/server/controllers/agents/v1.spec.js b/api/server/controllers/agents/v1.spec.js new file mode 100644 index 0000000000..5ac2645c04 --- /dev/null +++ b/api/server/controllers/agents/v1.spec.js @@ -0,0 +1,659 @@ +const mongoose = require('mongoose'); +const { v4: uuidv4 } = require('uuid'); +const { MongoMemoryServer } = require('mongodb-memory-server'); +const { agentSchema } = require('@librechat/data-schemas'); + +// Only mock the dependencies that are not database-related +jest.mock('~/server/services/Config', () => ({ + getCachedTools: jest.fn().mockResolvedValue({ + web_search: true, + execute_code: true, + file_search: true, + }), +})); + +jest.mock('~/models/Project', () => ({ + getProjectByName: jest.fn().mockResolvedValue(null), +})); + +jest.mock('~/server/services/Files/strategies', () => ({ + getStrategyFunctions: jest.fn(), +})); + +jest.mock('~/server/services/Files/images/avatar', () => ({ + resizeAvatar: jest.fn(), +})); + +jest.mock('~/server/services/Files/S3/crud', () => ({ + refreshS3Url: jest.fn(), +})); + +jest.mock('~/server/services/Files/process', () => ({ + filterFile: jest.fn(), +})); + +jest.mock('~/models/Action', () => ({ + updateAction: jest.fn(), + getActions: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/models/File', () => ({ + deleteFileByFilter: jest.fn(), +})); + +const { createAgent: createAgentHandler, updateAgent: updateAgentHandler } = require('./v1'); + +/** + * @type {import('mongoose').Model} + */ +let Agent; + +describe('Agent Controllers - Mass Assignment Protection', () => { + let mongoServer; + let mockReq; + let mockRes; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + await mongoose.connect(mongoUri); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + }, 20000); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Agent.deleteMany({}); + + // Reset all mocks + jest.clearAllMocks(); + + // Setup mock request and response objects + mockReq = { + user: { + id: new mongoose.Types.ObjectId().toString(), + role: 'USER', + }, + body: {}, + params: {}, + app: { + locals: { + fileStrategy: 'local', + }, + }, + }; + + mockRes = { + status: jest.fn().mockReturnThis(), + json: jest.fn().mockReturnThis(), + }; + }); + + describe('createAgentHandler', () => { + test('should create agent with allowed fields only', async () => { + const validData = { + name: 'Test Agent', + description: 'A test agent', + instructions: 'Be helpful', + provider: 'openai', + model: 'gpt-4', + tools: ['web_search'], + model_parameters: { temperature: 0.7 }, + tool_resources: { + file_search: { file_ids: ['file1', 'file2'] }, + }, + }; + + mockReq.body = validData; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + expect(mockRes.json).toHaveBeenCalled(); + + const createdAgent = mockRes.json.mock.calls[0][0]; + expect(createdAgent.name).toBe('Test Agent'); + expect(createdAgent.description).toBe('A test agent'); + expect(createdAgent.provider).toBe('openai'); + expect(createdAgent.model).toBe('gpt-4'); + expect(createdAgent.author.toString()).toBe(mockReq.user.id); + expect(createdAgent.tools).toContain('web_search'); + + // Verify in database + const agentInDb = await Agent.findOne({ id: createdAgent.id }); + expect(agentInDb).toBeDefined(); + expect(agentInDb.name).toBe('Test Agent'); + expect(agentInDb.author.toString()).toBe(mockReq.user.id); + }); + + test('should reject creation with unauthorized fields (mass assignment protection)', async () => { + const maliciousData = { + // Required fields + provider: 'openai', + model: 'gpt-4', + name: 'Malicious Agent', + + // Unauthorized fields that should be stripped + author: new mongoose.Types.ObjectId().toString(), // Should not be able to set author + authorName: 'Hacker', // Should be stripped + isCollaborative: true, // Should be stripped on creation + versions: [], // Should be stripped + _id: new mongoose.Types.ObjectId(), // Should be stripped + id: 'custom_agent_id', // Should be overridden + createdAt: new Date('2020-01-01'), // Should be stripped + updatedAt: new Date('2020-01-01'), // Should be stripped + }; + + mockReq.body = maliciousData; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + + const createdAgent = mockRes.json.mock.calls[0][0]; + + // Verify unauthorized fields were not set + expect(createdAgent.author.toString()).toBe(mockReq.user.id); // Should be the request user, not the malicious value + expect(createdAgent.authorName).toBeUndefined(); + expect(createdAgent.isCollaborative).toBeFalsy(); + expect(createdAgent.versions).toHaveLength(1); // Should have exactly 1 version from creation + expect(createdAgent.id).not.toBe('custom_agent_id'); // Should have generated ID + expect(createdAgent.id).toMatch(/^agent_/); // Should have proper prefix + + // Verify timestamps are recent (not the malicious dates) + const createdTime = new Date(createdAgent.createdAt).getTime(); + const now = Date.now(); + expect(now - createdTime).toBeLessThan(5000); // Created within last 5 seconds + + // Verify in database + const agentInDb = await Agent.findOne({ id: createdAgent.id }); + expect(agentInDb.author.toString()).toBe(mockReq.user.id); + expect(agentInDb.authorName).toBeUndefined(); + }); + + test('should validate required fields', async () => { + const invalidData = { + name: 'Missing Required Fields', + // Missing provider and model + }; + + mockReq.body = invalidData; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(400); + expect(mockRes.json).toHaveBeenCalledWith( + expect.objectContaining({ + error: 'Invalid request data', + details: expect.any(Array), + }), + ); + + // Verify nothing was created in database + const count = await Agent.countDocuments(); + expect(count).toBe(0); + }); + + test('should handle tool_resources validation', async () => { + const dataWithInvalidToolResources = { + provider: 'openai', + model: 'gpt-4', + name: 'Agent with Tool Resources', + tool_resources: { + // Valid resources + file_search: { + file_ids: ['file1', 'file2'], + vector_store_ids: ['vs1'], + }, + execute_code: { + file_ids: ['file3'], + }, + // Invalid resource (should be stripped by schema) + invalid_resource: { + file_ids: ['file4'], + }, + }, + }; + + mockReq.body = dataWithInvalidToolResources; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + + const createdAgent = mockRes.json.mock.calls[0][0]; + expect(createdAgent.tool_resources).toBeDefined(); + expect(createdAgent.tool_resources.file_search).toBeDefined(); + expect(createdAgent.tool_resources.execute_code).toBeDefined(); + expect(createdAgent.tool_resources.invalid_resource).toBeUndefined(); // Should be stripped + + // Verify in database + const agentInDb = await Agent.findOne({ id: createdAgent.id }); + expect(agentInDb.tool_resources.invalid_resource).toBeUndefined(); + }); + + test('should handle avatar validation', async () => { + const dataWithAvatar = { + provider: 'openai', + model: 'gpt-4', + name: 'Agent with Avatar', + avatar: { + filepath: 'https://example.com/avatar.png', + source: 's3', + }, + }; + + mockReq.body = dataWithAvatar; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + + const createdAgent = mockRes.json.mock.calls[0][0]; + expect(createdAgent.avatar).toEqual({ + filepath: 'https://example.com/avatar.png', + source: 's3', + }); + }); + + test('should handle invalid avatar format', async () => { + const dataWithInvalidAvatar = { + provider: 'openai', + model: 'gpt-4', + name: 'Agent with Invalid Avatar', + avatar: 'just-a-string', // Invalid format + }; + + mockReq.body = dataWithInvalidAvatar; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(400); + expect(mockRes.json).toHaveBeenCalledWith( + expect.objectContaining({ + error: 'Invalid request data', + }), + ); + }); + }); + + describe('updateAgentHandler', () => { + let existingAgentId; + let existingAgentAuthorId; + + beforeEach(async () => { + // Create an existing agent for update tests + existingAgentAuthorId = new mongoose.Types.ObjectId(); + const agent = await Agent.create({ + id: `agent_${uuidv4()}`, + name: 'Original Agent', + provider: 'openai', + model: 'gpt-3.5-turbo', + author: existingAgentAuthorId, + description: 'Original description', + isCollaborative: false, + versions: [ + { + name: 'Original Agent', + provider: 'openai', + model: 'gpt-3.5-turbo', + description: 'Original description', + createdAt: new Date(), + updatedAt: new Date(), + }, + ], + }); + existingAgentId = agent.id; + }); + + test('should update agent with allowed fields only', async () => { + mockReq.user.id = existingAgentAuthorId.toString(); // Set as author + mockReq.params.id = existingAgentId; + mockReq.body = { + name: 'Updated Agent', + description: 'Updated description', + model: 'gpt-4', + isCollaborative: true, // This IS allowed in updates + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.status).not.toHaveBeenCalledWith(400); + expect(mockRes.status).not.toHaveBeenCalledWith(403); + expect(mockRes.json).toHaveBeenCalled(); + + const updatedAgent = mockRes.json.mock.calls[0][0]; + expect(updatedAgent.name).toBe('Updated Agent'); + expect(updatedAgent.description).toBe('Updated description'); + expect(updatedAgent.model).toBe('gpt-4'); + expect(updatedAgent.isCollaborative).toBe(true); + expect(updatedAgent.author).toBe(existingAgentAuthorId.toString()); + + // Verify in database + const agentInDb = await Agent.findOne({ id: existingAgentId }); + expect(agentInDb.name).toBe('Updated Agent'); + expect(agentInDb.isCollaborative).toBe(true); + }); + + test('should reject update with unauthorized fields (mass assignment protection)', async () => { + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { + name: 'Updated Name', + + // Unauthorized fields that should be stripped + author: new mongoose.Types.ObjectId().toString(), // Should not be able to change author + authorName: 'Hacker', // Should be stripped + id: 'different_agent_id', // Should be stripped + _id: new mongoose.Types.ObjectId(), // Should be stripped + versions: [], // Should be stripped + createdAt: new Date('2020-01-01'), // Should be stripped + updatedAt: new Date('2020-01-01'), // Should be stripped + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + + const updatedAgent = mockRes.json.mock.calls[0][0]; + + // Verify unauthorized fields were not changed + expect(updatedAgent.author).toBe(existingAgentAuthorId.toString()); // Should not have changed + expect(updatedAgent.authorName).toBeUndefined(); + expect(updatedAgent.id).toBe(existingAgentId); // Should not have changed + expect(updatedAgent.name).toBe('Updated Name'); // Only this should have changed + + // Verify in database + const agentInDb = await Agent.findOne({ id: existingAgentId }); + expect(agentInDb.author.toString()).toBe(existingAgentAuthorId.toString()); + expect(agentInDb.id).toBe(existingAgentId); + }); + + test('should reject update from non-author when not collaborative', async () => { + const differentUserId = new mongoose.Types.ObjectId().toString(); + mockReq.user.id = differentUserId; // Different user + mockReq.params.id = existingAgentId; + mockReq.body = { + name: 'Unauthorized Update', + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(403); + expect(mockRes.json).toHaveBeenCalledWith({ + error: 'You do not have permission to modify this non-collaborative agent', + }); + + // Verify agent was not modified in database + const agentInDb = await Agent.findOne({ id: existingAgentId }); + expect(agentInDb.name).toBe('Original Agent'); + }); + + test('should allow update from non-author when collaborative', async () => { + // First make the agent collaborative + await Agent.updateOne({ id: existingAgentId }, { isCollaborative: true }); + + const differentUserId = new mongoose.Types.ObjectId().toString(); + mockReq.user.id = differentUserId; // Different user + mockReq.params.id = existingAgentId; + mockReq.body = { + name: 'Collaborative Update', + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.status).not.toHaveBeenCalledWith(403); + expect(mockRes.json).toHaveBeenCalled(); + + const updatedAgent = mockRes.json.mock.calls[0][0]; + expect(updatedAgent.name).toBe('Collaborative Update'); + // Author field should be removed for non-author + expect(updatedAgent.author).toBeUndefined(); + + // Verify in database + const agentInDb = await Agent.findOne({ id: existingAgentId }); + expect(agentInDb.name).toBe('Collaborative Update'); + }); + + test('should allow admin to update any agent', async () => { + const adminUserId = new mongoose.Types.ObjectId().toString(); + mockReq.user.id = adminUserId; + mockReq.user.role = 'ADMIN'; // Set as admin + mockReq.params.id = existingAgentId; + mockReq.body = { + name: 'Admin Update', + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.status).not.toHaveBeenCalledWith(403); + expect(mockRes.json).toHaveBeenCalled(); + + const updatedAgent = mockRes.json.mock.calls[0][0]; + expect(updatedAgent.name).toBe('Admin Update'); + }); + + test('should handle projectIds updates', async () => { + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + + const projectId1 = new mongoose.Types.ObjectId().toString(); + const projectId2 = new mongoose.Types.ObjectId().toString(); + + mockReq.body = { + projectIds: [projectId1, projectId2], + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + + const updatedAgent = mockRes.json.mock.calls[0][0]; + expect(updatedAgent).toBeDefined(); + // Note: updateAgentProjects requires more setup, so we just verify the handler doesn't crash + }); + + test('should validate tool_resources in updates', async () => { + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { + tool_resources: { + ocr: { + file_ids: ['ocr1', 'ocr2'], + }, + execute_code: { + file_ids: ['img1'], + }, + // Invalid tool resource + invalid_tool: { + file_ids: ['invalid'], + }, + }, + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + + const updatedAgent = mockRes.json.mock.calls[0][0]; + expect(updatedAgent.tool_resources).toBeDefined(); + expect(updatedAgent.tool_resources.ocr).toBeDefined(); + expect(updatedAgent.tool_resources.execute_code).toBeDefined(); + expect(updatedAgent.tool_resources.invalid_tool).toBeUndefined(); + }); + + test('should return 404 for non-existent agent', async () => { + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = `agent_${uuidv4()}`; // Non-existent ID + mockReq.body = { + name: 'Update Non-existent', + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(404); + expect(mockRes.json).toHaveBeenCalledWith({ error: 'Agent not found' }); + }); + + test('should handle validation errors properly', async () => { + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { + model_parameters: 'invalid-not-an-object', // Should be an object + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(400); + expect(mockRes.json).toHaveBeenCalledWith( + expect.objectContaining({ + error: 'Invalid request data', + details: expect.any(Array), + }), + ); + }); + }); + + describe('Mass Assignment Attack Scenarios', () => { + test('should prevent setting system fields during creation', async () => { + const systemFields = { + provider: 'openai', + model: 'gpt-4', + name: 'System Fields Test', + + // System fields that should never be settable by users + __v: 99, + _id: new mongoose.Types.ObjectId(), + versions: [ + { + name: 'Fake Version', + provider: 'fake', + model: 'fake-model', + }, + ], + }; + + mockReq.body = systemFields; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + + const createdAgent = mockRes.json.mock.calls[0][0]; + + // Verify system fields were not affected + expect(createdAgent.__v).not.toBe(99); + expect(createdAgent.versions).toHaveLength(1); // Should only have the auto-created version + expect(createdAgent.versions[0].name).toBe('System Fields Test'); // From actual creation + expect(createdAgent.versions[0].provider).toBe('openai'); // From actual creation + + // Verify in database + const agentInDb = await Agent.findOne({ id: createdAgent.id }); + expect(agentInDb.__v).not.toBe(99); + }); + + test('should prevent privilege escalation through isCollaborative', async () => { + // Create a non-collaborative agent + const authorId = new mongoose.Types.ObjectId(); + const agent = await Agent.create({ + id: `agent_${uuidv4()}`, + name: 'Private Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + isCollaborative: false, + versions: [ + { + name: 'Private Agent', + provider: 'openai', + model: 'gpt-4', + createdAt: new Date(), + updatedAt: new Date(), + }, + ], + }); + + // Try to make it collaborative as a different user + const attackerId = new mongoose.Types.ObjectId().toString(); + mockReq.user.id = attackerId; + mockReq.params.id = agent.id; + mockReq.body = { + isCollaborative: true, // Trying to escalate privileges + }; + + await updateAgentHandler(mockReq, mockRes); + + // Should be rejected + expect(mockRes.status).toHaveBeenCalledWith(403); + + // Verify in database that it's still not collaborative + const agentInDb = await Agent.findOne({ id: agent.id }); + expect(agentInDb.isCollaborative).toBe(false); + }); + + test('should prevent author hijacking', async () => { + const originalAuthorId = new mongoose.Types.ObjectId(); + const attackerId = new mongoose.Types.ObjectId(); + + // Admin creates an agent + mockReq.user.id = originalAuthorId.toString(); + mockReq.user.role = 'ADMIN'; + mockReq.body = { + provider: 'openai', + model: 'gpt-4', + name: 'Admin Agent', + author: attackerId.toString(), // Trying to set different author + }; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + + const createdAgent = mockRes.json.mock.calls[0][0]; + + // Author should be the actual user, not the attempted value + expect(createdAgent.author.toString()).toBe(originalAuthorId.toString()); + expect(createdAgent.author.toString()).not.toBe(attackerId.toString()); + + // Verify in database + const agentInDb = await Agent.findOne({ id: createdAgent.id }); + expect(agentInDb.author.toString()).toBe(originalAuthorId.toString()); + }); + + test('should strip unknown fields to prevent future vulnerabilities', async () => { + mockReq.body = { + provider: 'openai', + model: 'gpt-4', + name: 'Future Proof Test', + + // Unknown fields that might be added in future + superAdminAccess: true, + bypassAllChecks: true, + internalFlag: 'secret', + futureFeature: 'exploit', + }; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + + const createdAgent = mockRes.json.mock.calls[0][0]; + + // Verify unknown fields were stripped + expect(createdAgent.superAdminAccess).toBeUndefined(); + expect(createdAgent.bypassAllChecks).toBeUndefined(); + expect(createdAgent.internalFlag).toBeUndefined(); + expect(createdAgent.futureFeature).toBeUndefined(); + + // Also check in database + const agentInDb = await Agent.findOne({ id: createdAgent.id }).lean(); + expect(agentInDb.superAdminAccess).toBeUndefined(); + expect(agentInDb.bypassAllChecks).toBeUndefined(); + expect(agentInDb.internalFlag).toBeUndefined(); + expect(agentInDb.futureFeature).toBeUndefined(); + }); + }); +}); diff --git a/api/server/controllers/assistants/chatV1.js b/api/server/controllers/assistants/chatV1.js index 9129a6a1c1..b4fe0d9013 100644 --- a/api/server/controllers/assistants/chatV1.js +++ b/api/server/controllers/assistants/chatV1.js @@ -1,4 +1,7 @@ const { v4 } = require('uuid'); +const { sleep } = require('@librechat/agents'); +const { sendEvent } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { Time, Constants, @@ -19,20 +22,20 @@ const { addThreadMetadata, saveAssistantMessage, } = require('~/server/services/Threads'); -const { sendResponse, sendMessage, sleep, countTokens } = require('~/server/utils'); const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService'); const validateAuthor = require('~/server/middleware/assistants/validateAuthor'); const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts'); const { createRun, StreamRunManager } = require('~/server/services/Runs'); const { addTitle } = require('~/server/services/Endpoints/assistants'); const { createRunBody } = require('~/server/services/createRunBody'); +const { sendResponse } = require('~/server/middleware/error'); const { getTransactions } = require('~/models/Transaction'); const { checkBalance } = require('~/models/balanceMethods'); const { getConvo } = require('~/models/Conversation'); const getLogStores = require('~/cache/getLogStores'); +const { countTokens } = require('~/server/utils'); const { getModelMaxTokens } = require('~/utils'); const { getOpenAIClient } = require('./helpers'); -const { logger } = require('~/config'); /** * @route POST / @@ -471,7 +474,7 @@ const chatV1 = async (req, res) => { await Promise.all(promises); const sendInitialResponse = () => { - sendMessage(res, { + sendEvent(res, { sync: true, conversationId, // messages: previousMessages, @@ -587,7 +590,7 @@ const chatV1 = async (req, res) => { iconURL: endpointOption.iconURL, }; - sendMessage(res, { + sendEvent(res, { final: true, conversation, requestMessage: { diff --git a/api/server/controllers/assistants/chatV2.js b/api/server/controllers/assistants/chatV2.js index 309e5a86c4..e1ba93bc21 100644 --- a/api/server/controllers/assistants/chatV2.js +++ b/api/server/controllers/assistants/chatV2.js @@ -1,4 +1,7 @@ const { v4 } = require('uuid'); +const { sleep } = require('@librechat/agents'); +const { sendEvent } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { Time, Constants, @@ -22,15 +25,14 @@ const { createErrorHandler } = require('~/server/controllers/assistants/errors') const validateAuthor = require('~/server/middleware/assistants/validateAuthor'); const { createRun, StreamRunManager } = require('~/server/services/Runs'); const { addTitle } = require('~/server/services/Endpoints/assistants'); -const { sendMessage, sleep, countTokens } = require('~/server/utils'); const { createRunBody } = require('~/server/services/createRunBody'); const { getTransactions } = require('~/models/Transaction'); const { checkBalance } = require('~/models/balanceMethods'); const { getConvo } = require('~/models/Conversation'); const getLogStores = require('~/cache/getLogStores'); +const { countTokens } = require('~/server/utils'); const { getModelMaxTokens } = require('~/utils'); const { getOpenAIClient } = require('./helpers'); -const { logger } = require('~/config'); /** * @route POST / @@ -309,7 +311,7 @@ const chatV2 = async (req, res) => { await Promise.all(promises); const sendInitialResponse = () => { - sendMessage(res, { + sendEvent(res, { sync: true, conversationId, // messages: previousMessages, @@ -432,7 +434,7 @@ const chatV2 = async (req, res) => { iconURL: endpointOption.iconURL, }; - sendMessage(res, { + sendEvent(res, { final: true, conversation, requestMessage: { diff --git a/api/server/controllers/assistants/errors.js b/api/server/controllers/assistants/errors.js index a4b880bf04..182b230fba 100644 --- a/api/server/controllers/assistants/errors.js +++ b/api/server/controllers/assistants/errors.js @@ -1,10 +1,10 @@ // errorHandler.js -const { sendResponse } = require('~/server/utils'); -const { logger } = require('~/config'); -const getLogStores = require('~/cache/getLogStores'); +const { logger } = require('@librechat/data-schemas'); const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider'); -const { getConvo } = require('~/models/Conversation'); const { recordUsage, checkMessageGaps } = require('~/server/services/Threads'); +const { sendResponse } = require('~/server/middleware/error'); +const { getConvo } = require('~/models/Conversation'); +const getLogStores = require('~/cache/getLogStores'); /** * @typedef {Object} ErrorHandlerContext @@ -78,7 +78,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch } else if (/Files.*are invalid/.test(error.message)) { const errorMessage = `Files are invalid, or may not have uploaded yet.${ endpoint === 'azureAssistants' - ? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.' + ? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload." : '' }`; return sendResponse(req, res, messageData, errorMessage); diff --git a/api/server/controllers/tools.js b/api/server/controllers/tools.js index 254ecb4f94..8d5d2e9ce6 100644 --- a/api/server/controllers/tools.js +++ b/api/server/controllers/tools.js @@ -1,5 +1,7 @@ const { nanoid } = require('nanoid'); const { EnvVar } = require('@librechat/agents'); +const { checkAccess } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { Tools, AuthType, @@ -13,9 +15,8 @@ const { processCodeOutput } = require('~/server/services/Files/Code/process'); const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { loadTools } = require('~/app/clients/tools/util'); -const { checkAccess } = require('~/server/middleware'); +const { getRoleByName } = require('~/models/Role'); const { getMessage } = require('~/models/Message'); -const { logger } = require('~/config'); const fieldsMap = { [Tools.execute_code]: [EnvVar.CODE_API_KEY], @@ -79,6 +80,7 @@ const verifyToolAuth = async (req, res) => { throwError: false, }); } catch (error) { + logger.error('Error loading auth values', error); res.status(200).json({ authenticated: false, message: AuthType.USER_PROVIDED }); return; } @@ -132,7 +134,12 @@ const callTool = async (req, res) => { logger.debug(`[${toolId}/call] User: ${req.user.id}`); let hasAccess = true; if (toolAccessPermType[toolId]) { - hasAccess = await checkAccess(req.user, toolAccessPermType[toolId], [Permissions.USE]); + hasAccess = await checkAccess({ + user: req.user, + permissionType: toolAccessPermType[toolId], + permissions: [Permissions.USE], + getRoleByName, + }); } if (!hasAccess) { logger.warn( diff --git a/api/server/index.js b/api/server/index.js index 8c7db3e226..2da1adfcde 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -55,7 +55,6 @@ const startServer = async () => { /* Middleware */ app.use(noIndex); - app.use(errorController); app.use(express.json({ limit: '3mb' })); app.use(express.urlencoded({ extended: true, limit: '3mb' })); app.use(mongoSanitize()); @@ -97,7 +96,6 @@ const startServer = async () => { app.use('/api/actions', routes.actions); app.use('/api/keys', routes.keys); app.use('/api/user', routes.user); - app.use('/api/ask', routes.ask); app.use('/api/search', routes.search); app.use('/api/edit', routes.edit); app.use('/api/messages', routes.messages); @@ -118,11 +116,13 @@ const startServer = async () => { app.use('/api/roles', routes.roles); app.use('/api/agents', routes.agents); app.use('/api/banner', routes.banner); - app.use('/api/bedrock', routes.bedrock); app.use('/api/memories', routes.memories); app.use('/api/tags', routes.tags); app.use('/api/mcp', routes.mcp); + // Add the error controller one more time after all routes + app.use(errorController); + app.use((req, res) => { res.set({ 'Cache-Control': process.env.INDEX_CACHE_CONTROL || 'no-cache, no-store, must-revalidate', diff --git a/api/server/index.spec.js b/api/server/index.spec.js index 25b5ab9f03..43ad57108f 100644 --- a/api/server/index.spec.js +++ b/api/server/index.spec.js @@ -1,5 +1,4 @@ const fs = require('fs'); -const path = require('path'); const request = require('supertest'); const { MongoMemoryServer } = require('mongodb-memory-server'); const mongoose = require('mongoose'); @@ -59,6 +58,30 @@ describe('Server Configuration', () => { expect(response.headers['pragma']).toBe('no-cache'); expect(response.headers['expires']).toBe('0'); }); + + it('should return 500 for unknown errors via ErrorController', async () => { + // Testing the error handling here on top of unit tests to ensure the middleware is correctly integrated + + // Mock MongoDB operations to fail + const originalFindOne = mongoose.models.User.findOne; + const mockError = new Error('MongoDB operation failed'); + mongoose.models.User.findOne = jest.fn().mockImplementation(() => { + throw mockError; + }); + + try { + const response = await request(app).post('/api/auth/login').send({ + email: 'test@example.com', + password: 'password123', + }); + + expect(response.status).toBe(500); + expect(response.text).toBe('An unknown error occurred.'); + } finally { + // Restore original function + mongoose.models.User.findOne = originalFindOne; + } + }); }); // Polls the /health endpoint every 30ms for up to 10 seconds to wait for the server to start completely diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index 94d69004bd..c5fc48780c 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -1,13 +1,13 @@ -// abortMiddleware.js +const { logger } = require('@librechat/data-schemas'); +const { countTokens, isEnabled, sendEvent } = require('@librechat/api'); const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider'); -const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils'); const { truncateText, smartTruncateText } = require('~/app/clients/prompts'); const clearPendingReq = require('~/cache/clearPendingReq'); +const { sendError } = require('~/server/middleware/error'); const { spendTokens } = require('~/models/spendTokens'); const abortControllers = require('./abortControllers'); const { saveMessage, getConvo } = require('~/models'); const { abortRun } = require('./abortRun'); -const { logger } = require('~/config'); const abortDataMap = new WeakMap(); @@ -101,7 +101,7 @@ async function abortMessage(req, res) { cleanupAbortController(abortKey); if (res.headersSent && finalEvent) { - return sendMessage(res, finalEvent); + return sendEvent(res, finalEvent); } res.setHeader('Content-Type', 'application/json'); @@ -174,7 +174,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => { * @param {string} responseMessageId */ const onStart = (userMessage, responseMessageId) => { - sendMessage(res, { message: userMessage, created: true }); + sendEvent(res, { message: userMessage, created: true }); const abortKey = userMessage?.conversationId ?? req.user.id; getReqData({ abortKey }); diff --git a/api/server/middleware/abortRun.js b/api/server/middleware/abortRun.js index 01b34aacc2..2846c6eefc 100644 --- a/api/server/middleware/abortRun.js +++ b/api/server/middleware/abortRun.js @@ -1,11 +1,11 @@ +const { sendEvent } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider'); const { initializeClient } = require('~/server/services/Endpoints/assistants'); const { checkMessageGaps, recordUsage } = require('~/server/services/Threads'); const { deleteMessages } = require('~/models/Message'); const { getConvo } = require('~/models/Conversation'); const getLogStores = require('~/cache/getLogStores'); -const { sendMessage } = require('~/server/utils'); -const { logger } = require('~/config'); const three_minutes = 1000 * 60 * 3; @@ -34,7 +34,7 @@ async function abortRun(req, res) { const [thread_id, run_id] = runValues.split(':'); if (!run_id) { - logger.warn('[abortRun] Couldn\'t find run for cancel request', { thread_id }); + logger.warn("[abortRun] Couldn't find run for cancel request", { thread_id }); return res.status(204).send({ message: 'Run not found' }); } else if (run_id === 'cancelled') { logger.warn('[abortRun] Run already cancelled', { thread_id }); @@ -93,7 +93,7 @@ async function abortRun(req, res) { }; if (res.headersSent && finalEvent) { - return sendMessage(res, finalEvent); + return sendEvent(res, finalEvent); } res.json(finalEvent); diff --git a/api/server/middleware/buildEndpointOption.js b/api/server/middleware/buildEndpointOption.js index 8394223b5e..d302bf8743 100644 --- a/api/server/middleware/buildEndpointOption.js +++ b/api/server/middleware/buildEndpointOption.js @@ -1,13 +1,12 @@ +const { logger } = require('@librechat/data-schemas'); const { - parseCompactConvo, + EndpointURLs, EModelEndpoint, isAgentsEndpoint, - EndpointURLs, + parseCompactConvo, } = require('librechat-data-provider'); const azureAssistants = require('~/server/services/Endpoints/azureAssistants'); -const { getModelsConfig } = require('~/server/controllers/ModelController'); const assistants = require('~/server/services/Endpoints/assistants'); -const gptPlugins = require('~/server/services/Endpoints/gptPlugins'); const { processFiles } = require('~/server/services/Files/process'); const anthropic = require('~/server/services/Endpoints/anthropic'); const bedrock = require('~/server/services/Endpoints/bedrock'); @@ -25,7 +24,6 @@ const buildFunction = { [EModelEndpoint.bedrock]: bedrock.buildOptions, [EModelEndpoint.azureOpenAI]: openAI.buildOptions, [EModelEndpoint.anthropic]: anthropic.buildOptions, - [EModelEndpoint.gptPlugins]: gptPlugins.buildOptions, [EModelEndpoint.assistants]: assistants.buildOptions, [EModelEndpoint.azureAssistants]: azureAssistants.buildOptions, }; @@ -36,6 +34,9 @@ async function buildEndpointOption(req, res, next) { try { parsedBody = parseCompactConvo({ endpoint, endpointType, conversation: req.body }); } catch (error) { + logger.warn( + `Error parsing conversation for endpoint ${endpoint}${error?.message ? `: ${error.message}` : ''}`, + ); return handleError(res, { text: 'Error parsing conversation' }); } @@ -57,15 +58,6 @@ async function buildEndpointOption(req, res, next) { return handleError(res, { text: 'Model spec mismatch' }); } - if ( - currentModelSpec.preset.endpoint !== EModelEndpoint.gptPlugins && - currentModelSpec.preset.tools - ) { - return handleError(res, { - text: `Only the "${EModelEndpoint.gptPlugins}" endpoint can have tools defined in the preset`, - }); - } - try { currentModelSpec.preset.spec = spec; if (currentModelSpec.iconURL != null && currentModelSpec.iconURL !== '') { @@ -77,6 +69,7 @@ async function buildEndpointOption(req, res, next) { conversation: currentModelSpec.preset, }); } catch (error) { + logger.error(`Error parsing model spec for endpoint ${endpoint}`, error); return handleError(res, { text: 'Error parsing model spec' }); } } @@ -84,20 +77,23 @@ async function buildEndpointOption(req, res, next) { try { const isAgents = isAgentsEndpoint(endpoint) || req.baseUrl.startsWith(EndpointURLs[EModelEndpoint.agents]); - const endpointFn = buildFunction[isAgents ? EModelEndpoint.agents : (endpointType ?? endpoint)]; - const builder = isAgents ? (...args) => endpointFn(req, ...args) : endpointFn; + const builder = isAgents + ? (...args) => buildFunction[EModelEndpoint.agents](req, ...args) + : buildFunction[endpointType ?? endpoint]; // TODO: use object params req.body.endpointOption = await builder(endpoint, parsedBody, endpointType); - // TODO: use `getModelsConfig` only when necessary - const modelsConfig = await getModelsConfig(req); - req.body.endpointOption.modelsConfig = modelsConfig; if (req.body.files && !isAgents) { req.body.endpointOption.attachments = processFiles(req.body.files); } + next(); } catch (error) { + logger.error( + `Error building endpoint option for endpoint ${endpoint} with type ${endpointType}`, + error, + ); return handleError(res, { text: 'Error building endpoint option' }); } } diff --git a/api/server/middleware/checkBan.js b/api/server/middleware/checkBan.js index 91c31ab66a..ad4e4c86ec 100644 --- a/api/server/middleware/checkBan.js +++ b/api/server/middleware/checkBan.js @@ -18,7 +18,6 @@ const message = 'Your account has been temporarily banned due to violations of o * @function * @param {Object} req - Express Request object. * @param {Object} res - Express Response object. - * @param {String} errorMessage - Error message to be displayed in case of /api/ask or /api/edit request. * * @returns {Promise} - Returns a Promise which when resolved sends a response status of 403 with a specific message if request is not of api/ask or api/edit types. If it is, calls `denyRequest()` function. */ @@ -135,6 +134,7 @@ const checkBan = async (req, res, next = () => {}) => { return await banResponse(req, res); } catch (error) { logger.error('Error in checkBan middleware:', error); + return next(error); } }; diff --git a/api/server/middleware/denyRequest.js b/api/server/middleware/denyRequest.js index 62efb1aeaf..20360519cf 100644 --- a/api/server/middleware/denyRequest.js +++ b/api/server/middleware/denyRequest.js @@ -1,6 +1,7 @@ const crypto = require('crypto'); +const { sendEvent } = require('@librechat/api'); const { getResponseSender, Constants } = require('librechat-data-provider'); -const { sendMessage, sendError } = require('~/server/utils'); +const { sendError } = require('~/server/middleware/error'); const { saveMessage } = require('~/models'); /** @@ -36,7 +37,7 @@ const denyRequest = async (req, res, errorMessage) => { isCreatedByUser: true, text, }; - sendMessage(res, { message: userMessage, created: true }); + sendEvent(res, { message: userMessage, created: true }); const shouldSaveMessage = _convoId && parentMessageId && parentMessageId !== Constants.NO_PARENT; diff --git a/api/server/utils/streamResponse.js b/api/server/middleware/error.js similarity index 76% rename from api/server/utils/streamResponse.js rename to api/server/middleware/error.js index bb8d63b229..db445c1d43 100644 --- a/api/server/utils/streamResponse.js +++ b/api/server/middleware/error.js @@ -1,31 +1,9 @@ const crypto = require('crypto'); +const { logger } = require('@librechat/data-schemas'); const { parseConvo } = require('librechat-data-provider'); +const { sendEvent, handleError } = require('@librechat/api'); const { saveMessage, getMessages } = require('~/models/Message'); const { getConvo } = require('~/models/Conversation'); -const { logger } = require('~/config'); - -/** - * Sends error data in Server Sent Events format and ends the response. - * @param {object} res - The server response. - * @param {string} message - The error message. - */ -const handleError = (res, message) => { - res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`); - res.end(); -}; - -/** - * Sends message data in Server Sent Events format. - * @param {Express.Response} res - - The server response. - * @param {string | Object} message - The message to be sent. - * @param {'message' | 'error' | 'cancel'} event - [Optional] The type of event. Default is 'message'. - */ -const sendMessage = (res, message, event = 'message') => { - if (typeof message === 'string' && message.length === 0) { - return; - } - res.write(`event: ${event}\ndata: ${JSON.stringify(message)}\n\n`); -}; /** * Processes an error with provided options, saves the error message and sends a corresponding SSE response @@ -91,7 +69,7 @@ const sendError = async (req, res, options, callback) => { convo = parseConvo(errorMessage); } - return sendMessage(res, { + return sendEvent(res, { final: true, requestMessage: query?.[0] ? query[0] : requestMessage, responseMessage: errorMessage, @@ -120,12 +98,10 @@ const sendResponse = (req, res, data, errorMessage) => { if (errorMessage) { return sendError(req, res, { ...data, text: errorMessage }); } - return sendMessage(res, data); + return sendEvent(res, data); }; module.exports = { - sendResponse, - handleError, - sendMessage, sendError, + sendResponse, }; diff --git a/api/server/middleware/limiters/forkLimiters.js b/api/server/middleware/limiters/forkLimiters.js new file mode 100644 index 0000000000..8a07cefab6 --- /dev/null +++ b/api/server/middleware/limiters/forkLimiters.js @@ -0,0 +1,95 @@ +const rateLimit = require('express-rate-limit'); +const { isEnabled } = require('@librechat/api'); +const { RedisStore } = require('rate-limit-redis'); +const { logger } = require('@librechat/data-schemas'); +const { ViolationTypes } = require('librechat-data-provider'); +const ioredisClient = require('~/cache/ioredisClient'); +const logViolation = require('~/cache/logViolation'); + +const getEnvironmentVariables = () => { + const FORK_IP_MAX = parseInt(process.env.FORK_IP_MAX) || 30; + const FORK_IP_WINDOW = parseInt(process.env.FORK_IP_WINDOW) || 1; + const FORK_USER_MAX = parseInt(process.env.FORK_USER_MAX) || 7; + const FORK_USER_WINDOW = parseInt(process.env.FORK_USER_WINDOW) || 1; + const FORK_VIOLATION_SCORE = process.env.FORK_VIOLATION_SCORE; + + const forkIpWindowMs = FORK_IP_WINDOW * 60 * 1000; + const forkIpMax = FORK_IP_MAX; + const forkIpWindowInMinutes = forkIpWindowMs / 60000; + + const forkUserWindowMs = FORK_USER_WINDOW * 60 * 1000; + const forkUserMax = FORK_USER_MAX; + const forkUserWindowInMinutes = forkUserWindowMs / 60000; + + return { + forkIpWindowMs, + forkIpMax, + forkIpWindowInMinutes, + forkUserWindowMs, + forkUserMax, + forkUserWindowInMinutes, + forkViolationScore: FORK_VIOLATION_SCORE, + }; +}; + +const createForkHandler = (ip = true) => { + const { + forkIpMax, + forkUserMax, + forkViolationScore, + forkIpWindowInMinutes, + forkUserWindowInMinutes, + } = getEnvironmentVariables(); + + return async (req, res) => { + const type = ViolationTypes.FILE_UPLOAD_LIMIT; + const errorMessage = { + type, + max: ip ? forkIpMax : forkUserMax, + limiter: ip ? 'ip' : 'user', + windowInMinutes: ip ? forkIpWindowInMinutes : forkUserWindowInMinutes, + }; + + await logViolation(req, res, type, errorMessage, forkViolationScore); + res.status(429).json({ message: 'Too many conversation fork requests. Try again later' }); + }; +}; + +const createForkLimiters = () => { + const { forkIpWindowMs, forkIpMax, forkUserWindowMs, forkUserMax } = getEnvironmentVariables(); + + const ipLimiterOptions = { + windowMs: forkIpWindowMs, + max: forkIpMax, + handler: createForkHandler(), + }; + const userLimiterOptions = { + windowMs: forkUserWindowMs, + max: forkUserMax, + handler: createForkHandler(false), + keyGenerator: function (req) { + return req.user?.id; + }, + }; + + if (isEnabled(process.env.USE_REDIS) && ioredisClient) { + logger.debug('Using Redis for fork rate limiters.'); + const sendCommand = (...args) => ioredisClient.call(...args); + const ipStore = new RedisStore({ + sendCommand, + prefix: 'fork_ip_limiter:', + }); + const userStore = new RedisStore({ + sendCommand, + prefix: 'fork_user_limiter:', + }); + ipLimiterOptions.store = ipStore; + userLimiterOptions.store = userStore; + } + + const forkIpLimiter = rateLimit(ipLimiterOptions); + const forkUserLimiter = rateLimit(userLimiterOptions); + return { forkIpLimiter, forkUserLimiter }; +}; + +module.exports = { createForkLimiters }; diff --git a/api/server/middleware/limiters/importLimiters.js b/api/server/middleware/limiters/importLimiters.js index f353f5e996..7ff48af5eb 100644 --- a/api/server/middleware/limiters/importLimiters.js +++ b/api/server/middleware/limiters/importLimiters.js @@ -1,16 +1,17 @@ const rateLimit = require('express-rate-limit'); +const { isEnabled } = require('@librechat/api'); const { RedisStore } = require('rate-limit-redis'); +const { logger } = require('@librechat/data-schemas'); const { ViolationTypes } = require('librechat-data-provider'); const ioredisClient = require('~/cache/ioredisClient'); const logViolation = require('~/cache/logViolation'); -const { isEnabled } = require('~/server/utils'); -const { logger } = require('~/config'); const getEnvironmentVariables = () => { const IMPORT_IP_MAX = parseInt(process.env.IMPORT_IP_MAX) || 100; const IMPORT_IP_WINDOW = parseInt(process.env.IMPORT_IP_WINDOW) || 15; const IMPORT_USER_MAX = parseInt(process.env.IMPORT_USER_MAX) || 50; const IMPORT_USER_WINDOW = parseInt(process.env.IMPORT_USER_WINDOW) || 15; + const IMPORT_VIOLATION_SCORE = process.env.IMPORT_VIOLATION_SCORE; const importIpWindowMs = IMPORT_IP_WINDOW * 60 * 1000; const importIpMax = IMPORT_IP_MAX; @@ -27,12 +28,18 @@ const getEnvironmentVariables = () => { importUserWindowMs, importUserMax, importUserWindowInMinutes, + importViolationScore: IMPORT_VIOLATION_SCORE, }; }; const createImportHandler = (ip = true) => { - const { importIpMax, importIpWindowInMinutes, importUserMax, importUserWindowInMinutes } = - getEnvironmentVariables(); + const { + importIpMax, + importUserMax, + importViolationScore, + importIpWindowInMinutes, + importUserWindowInMinutes, + } = getEnvironmentVariables(); return async (req, res) => { const type = ViolationTypes.FILE_UPLOAD_LIMIT; @@ -43,7 +50,7 @@ const createImportHandler = (ip = true) => { windowInMinutes: ip ? importIpWindowInMinutes : importUserWindowInMinutes, }; - await logViolation(req, res, type, errorMessage); + await logViolation(req, res, type, errorMessage, importViolationScore); res.status(429).json({ message: 'Too many conversation import requests. Try again later' }); }; }; diff --git a/api/server/middleware/limiters/index.js b/api/server/middleware/limiters/index.js index d1c11e0a12..ab110443dc 100644 --- a/api/server/middleware/limiters/index.js +++ b/api/server/middleware/limiters/index.js @@ -4,6 +4,7 @@ const createSTTLimiters = require('./sttLimiters'); const loginLimiter = require('./loginLimiter'); const importLimiters = require('./importLimiters'); const uploadLimiters = require('./uploadLimiters'); +const forkLimiters = require('./forkLimiters'); const registerLimiter = require('./registerLimiter'); const toolCallLimiter = require('./toolCallLimiter'); const messageLimiters = require('./messageLimiters'); @@ -14,6 +15,7 @@ module.exports = { ...uploadLimiters, ...importLimiters, ...messageLimiters, + ...forkLimiters, loginLimiter, registerLimiter, toolCallLimiter, diff --git a/api/server/middleware/limiters/messageLimiters.js b/api/server/middleware/limiters/messageLimiters.js index 4191c9fe7c..cd409fa528 100644 --- a/api/server/middleware/limiters/messageLimiters.js +++ b/api/server/middleware/limiters/messageLimiters.js @@ -11,6 +11,7 @@ const { MESSAGE_IP_WINDOW = 1, MESSAGE_USER_MAX = 40, MESSAGE_USER_WINDOW = 1, + MESSAGE_VIOLATION_SCORE: score, } = process.env; const ipWindowMs = MESSAGE_IP_WINDOW * 60 * 1000; @@ -39,7 +40,7 @@ const createHandler = (ip = true) => { windowInMinutes: ip ? ipWindowInMinutes : userWindowInMinutes, }; - await logViolation(req, res, type, errorMessage); + await logViolation(req, res, type, errorMessage, score); return await denyRequest(req, res, errorMessage); }; }; diff --git a/api/server/middleware/limiters/sttLimiters.js b/api/server/middleware/limiters/sttLimiters.js index 72ed3af6a3..79305bf5d3 100644 --- a/api/server/middleware/limiters/sttLimiters.js +++ b/api/server/middleware/limiters/sttLimiters.js @@ -11,6 +11,7 @@ const getEnvironmentVariables = () => { const STT_IP_WINDOW = parseInt(process.env.STT_IP_WINDOW) || 1; const STT_USER_MAX = parseInt(process.env.STT_USER_MAX) || 50; const STT_USER_WINDOW = parseInt(process.env.STT_USER_WINDOW) || 1; + const STT_VIOLATION_SCORE = process.env.STT_VIOLATION_SCORE; const sttIpWindowMs = STT_IP_WINDOW * 60 * 1000; const sttIpMax = STT_IP_MAX; @@ -27,11 +28,12 @@ const getEnvironmentVariables = () => { sttUserWindowMs, sttUserMax, sttUserWindowInMinutes, + sttViolationScore: STT_VIOLATION_SCORE, }; }; const createSTTHandler = (ip = true) => { - const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes } = + const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes, sttViolationScore } = getEnvironmentVariables(); return async (req, res) => { @@ -43,7 +45,7 @@ const createSTTHandler = (ip = true) => { windowInMinutes: ip ? sttIpWindowInMinutes : sttUserWindowInMinutes, }; - await logViolation(req, res, type, errorMessage); + await logViolation(req, res, type, errorMessage, sttViolationScore); res.status(429).json({ message: 'Too many STT requests. Try again later' }); }; }; diff --git a/api/server/middleware/limiters/toolCallLimiter.js b/api/server/middleware/limiters/toolCallLimiter.js index 482744a3e9..b14ca55d81 100644 --- a/api/server/middleware/limiters/toolCallLimiter.js +++ b/api/server/middleware/limiters/toolCallLimiter.js @@ -6,6 +6,8 @@ const logViolation = require('~/cache/logViolation'); const { isEnabled } = require('~/server/utils'); const { logger } = require('~/config'); +const { TOOL_CALL_VIOLATION_SCORE: score } = process.env; + const handler = async (req, res) => { const type = ViolationTypes.TOOL_CALL_LIMIT; const errorMessage = { @@ -15,7 +17,7 @@ const handler = async (req, res) => { windowInMinutes: 1, }; - await logViolation(req, res, type, errorMessage, 0); + await logViolation(req, res, type, errorMessage, score); res.status(429).json({ message: 'Too many tool call requests. Try again later' }); }; diff --git a/api/server/middleware/limiters/ttsLimiters.js b/api/server/middleware/limiters/ttsLimiters.js index 9054a6beb1..93dd6eb992 100644 --- a/api/server/middleware/limiters/ttsLimiters.js +++ b/api/server/middleware/limiters/ttsLimiters.js @@ -11,6 +11,7 @@ const getEnvironmentVariables = () => { const TTS_IP_WINDOW = parseInt(process.env.TTS_IP_WINDOW) || 1; const TTS_USER_MAX = parseInt(process.env.TTS_USER_MAX) || 50; const TTS_USER_WINDOW = parseInt(process.env.TTS_USER_WINDOW) || 1; + const TTS_VIOLATION_SCORE = process.env.TTS_VIOLATION_SCORE; const ttsIpWindowMs = TTS_IP_WINDOW * 60 * 1000; const ttsIpMax = TTS_IP_MAX; @@ -27,11 +28,12 @@ const getEnvironmentVariables = () => { ttsUserWindowMs, ttsUserMax, ttsUserWindowInMinutes, + ttsViolationScore: TTS_VIOLATION_SCORE, }; }; const createTTSHandler = (ip = true) => { - const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes } = + const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes, ttsViolationScore } = getEnvironmentVariables(); return async (req, res) => { @@ -43,7 +45,7 @@ const createTTSHandler = (ip = true) => { windowInMinutes: ip ? ttsIpWindowInMinutes : ttsUserWindowInMinutes, }; - await logViolation(req, res, type, errorMessage); + await logViolation(req, res, type, errorMessage, ttsViolationScore); res.status(429).json({ message: 'Too many TTS requests. Try again later' }); }; }; diff --git a/api/server/middleware/limiters/uploadLimiters.js b/api/server/middleware/limiters/uploadLimiters.js index d9049f898e..84eb6c0717 100644 --- a/api/server/middleware/limiters/uploadLimiters.js +++ b/api/server/middleware/limiters/uploadLimiters.js @@ -11,6 +11,7 @@ const getEnvironmentVariables = () => { const FILE_UPLOAD_IP_WINDOW = parseInt(process.env.FILE_UPLOAD_IP_WINDOW) || 15; const FILE_UPLOAD_USER_MAX = parseInt(process.env.FILE_UPLOAD_USER_MAX) || 50; const FILE_UPLOAD_USER_WINDOW = parseInt(process.env.FILE_UPLOAD_USER_WINDOW) || 15; + const FILE_UPLOAD_VIOLATION_SCORE = process.env.FILE_UPLOAD_VIOLATION_SCORE; const fileUploadIpWindowMs = FILE_UPLOAD_IP_WINDOW * 60 * 1000; const fileUploadIpMax = FILE_UPLOAD_IP_MAX; @@ -27,6 +28,7 @@ const getEnvironmentVariables = () => { fileUploadUserWindowMs, fileUploadUserMax, fileUploadUserWindowInMinutes, + fileUploadViolationScore: FILE_UPLOAD_VIOLATION_SCORE, }; }; @@ -36,6 +38,7 @@ const createFileUploadHandler = (ip = true) => { fileUploadIpWindowInMinutes, fileUploadUserMax, fileUploadUserWindowInMinutes, + fileUploadViolationScore, } = getEnvironmentVariables(); return async (req, res) => { @@ -47,7 +50,7 @@ const createFileUploadHandler = (ip = true) => { windowInMinutes: ip ? fileUploadIpWindowInMinutes : fileUploadUserWindowInMinutes, }; - await logViolation(req, res, type, errorMessage); + await logViolation(req, res, type, errorMessage, fileUploadViolationScore); res.status(429).json({ message: 'Too many file upload requests. Try again later' }); }; }; diff --git a/api/server/middleware/roles/access.js b/api/server/middleware/roles/access.js deleted file mode 100644 index cabbd405b0..0000000000 --- a/api/server/middleware/roles/access.js +++ /dev/null @@ -1,78 +0,0 @@ -const { getRoleByName } = require('~/models/Role'); -const { logger } = require('~/config'); - -/** - * Core function to check if a user has one or more required permissions - * - * @param {object} user - The user object - * @param {PermissionTypes} permissionType - The type of permission to check - * @param {Permissions[]} permissions - The list of specific permissions to check - * @param {Record} [bodyProps] - An optional object where keys are permissions and values are arrays of properties to check - * @param {object} [checkObject] - The object to check properties against - * @returns {Promise} Whether the user has the required permissions - */ -const checkAccess = async (user, permissionType, permissions, bodyProps = {}, checkObject = {}) => { - if (!user) { - return false; - } - - const role = await getRoleByName(user.role); - if (role && role.permissions && role.permissions[permissionType]) { - const hasAnyPermission = permissions.some((permission) => { - if (role.permissions[permissionType][permission]) { - return true; - } - - if (bodyProps[permission] && checkObject) { - return bodyProps[permission].some((prop) => - Object.prototype.hasOwnProperty.call(checkObject, prop), - ); - } - - return false; - }); - - return hasAnyPermission; - } - - return false; -}; - -/** - * Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties. - * - * @param {PermissionTypes} permissionType - The type of permission to check. - * @param {Permissions[]} permissions - The list of specific permissions to check. - * @param {Record} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check. - * @returns {(req: ServerRequest, res: ServerResponse, next: NextFunction) => Promise} Express middleware function. - */ -const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => { - return async (req, res, next) => { - try { - const hasAccess = await checkAccess( - req.user, - permissionType, - permissions, - bodyProps, - req.body, - ); - - if (hasAccess) { - return next(); - } - - logger.warn( - `[${permissionType}] Forbidden: Insufficient permissions for User ${req.user.id}: ${permissions.join(', ')}`, - ); - return res.status(403).json({ message: 'Forbidden: Insufficient permissions' }); - } catch (error) { - logger.error(error); - return res.status(500).json({ message: `Server error: ${error.message}` }); - } - }; -}; - -module.exports = { - checkAccess, - generateCheckAccess, -}; diff --git a/api/server/middleware/roles/index.js b/api/server/middleware/roles/index.js index ebc0043f2f..f01b884e5a 100644 --- a/api/server/middleware/roles/index.js +++ b/api/server/middleware/roles/index.js @@ -1,8 +1,5 @@ const checkAdmin = require('./admin'); -const { checkAccess, generateCheckAccess } = require('./access'); module.exports = { checkAdmin, - checkAccess, - generateCheckAccess, }; diff --git a/api/server/routes/agents/actions.js b/api/server/routes/agents/actions.js index 89d6a9dc42..2f11486a0e 100644 --- a/api/server/routes/agents/actions.js +++ b/api/server/routes/agents/actions.js @@ -1,14 +1,28 @@ const express = require('express'); const { nanoid } = require('nanoid'); -const { actionDelimiter, SystemRoles, removeNullishValues } = require('librechat-data-provider'); +const { logger } = require('@librechat/data-schemas'); +const { generateCheckAccess } = require('@librechat/api'); +const { + SystemRoles, + Permissions, + PermissionTypes, + actionDelimiter, + removeNullishValues, +} = require('librechat-data-provider'); const { encryptMetadata, domainParser } = require('~/server/services/ActionService'); const { updateAction, getActions, deleteAction } = require('~/models/Action'); const { isActionDomainAllowed } = require('~/server/services/domains'); const { getAgent, updateAgent } = require('~/models/Agent'); -const { logger } = require('~/config'); +const { getRoleByName } = require('~/models/Role'); const router = express.Router(); +const checkAgentCreate = generateCheckAccess({ + permissionType: PermissionTypes.AGENTS, + permissions: [Permissions.USE, Permissions.CREATE], + getRoleByName, +}); + // If the user has ADMIN role // then action edition is possible even if not owner of the assistant const isAdmin = (req) => { @@ -41,7 +55,7 @@ router.get('/', async (req, res) => { * @param {ActionMetadata} req.body.metadata - Metadata for the action. * @returns {Object} 200 - success response - application/json */ -router.post('/:agent_id', async (req, res) => { +router.post('/:agent_id', checkAgentCreate, async (req, res) => { try { const { agent_id } = req.params; @@ -149,7 +163,7 @@ router.post('/:agent_id', async (req, res) => { * @param {string} req.params.action_id - The ID of the action to delete. * @returns {Object} 200 - success response - application/json */ -router.delete('/:agent_id/:action_id', async (req, res) => { +router.delete('/:agent_id/:action_id', checkAgentCreate, async (req, res) => { try { const { agent_id, action_id } = req.params; const admin = isAdmin(req); diff --git a/api/server/routes/agents/chat.js b/api/server/routes/agents/chat.js index ef66ef7896..0e07c83bd1 100644 --- a/api/server/routes/agents/chat.js +++ b/api/server/routes/agents/chat.js @@ -1,22 +1,28 @@ const express = require('express'); +const { generateCheckAccess, skipAgentCheck } = require('@librechat/api'); const { PermissionTypes, Permissions } = require('librechat-data-provider'); const { setHeaders, moderateText, // validateModel, - generateCheckAccess, validateConvoAccess, buildEndpointOption, } = require('~/server/middleware'); const { initializeClient } = require('~/server/services/Endpoints/agents'); const AgentController = require('~/server/controllers/agents/request'); const addTitle = require('~/server/services/Endpoints/agents/title'); +const { getRoleByName } = require('~/models/Role'); const router = express.Router(); router.use(moderateText); -const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]); +const checkAgentAccess = generateCheckAccess({ + permissionType: PermissionTypes.AGENTS, + permissions: [Permissions.USE], + skipCheck: skipAgentCheck, + getRoleByName, +}); router.use(checkAgentAccess); router.use(validateConvoAccess); diff --git a/api/server/routes/agents/v1.js b/api/server/routes/agents/v1.js index 657aa79414..0455b23948 100644 --- a/api/server/routes/agents/v1.js +++ b/api/server/routes/agents/v1.js @@ -1,29 +1,36 @@ const express = require('express'); +const { generateCheckAccess } = require('@librechat/api'); const { PermissionTypes, Permissions } = require('librechat-data-provider'); -const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware'); +const { requireJwtAuth } = require('~/server/middleware'); const v1 = require('~/server/controllers/agents/v1'); +const { getRoleByName } = require('~/models/Role'); const actions = require('./actions'); const tools = require('./tools'); const router = express.Router(); const avatar = express.Router(); -const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]); -const checkAgentCreate = generateCheckAccess(PermissionTypes.AGENTS, [ - Permissions.USE, - Permissions.CREATE, -]); +const checkAgentAccess = generateCheckAccess({ + permissionType: PermissionTypes.AGENTS, + permissions: [Permissions.USE], + getRoleByName, +}); +const checkAgentCreate = generateCheckAccess({ + permissionType: PermissionTypes.AGENTS, + permissions: [Permissions.USE, Permissions.CREATE], + getRoleByName, +}); -const checkGlobalAgentShare = generateCheckAccess( - PermissionTypes.AGENTS, - [Permissions.USE, Permissions.CREATE], - { +const checkGlobalAgentShare = generateCheckAccess({ + permissionType: PermissionTypes.AGENTS, + permissions: [Permissions.USE, Permissions.CREATE], + bodyProps: { [Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'], }, -); + getRoleByName, +}); router.use(requireJwtAuth); -router.use(checkAgentAccess); /** * Agent actions route. diff --git a/api/server/routes/ask/addToCache.js b/api/server/routes/ask/addToCache.js deleted file mode 100644 index a2f427098f..0000000000 --- a/api/server/routes/ask/addToCache.js +++ /dev/null @@ -1,63 +0,0 @@ -const { Keyv } = require('keyv'); -const { KeyvFile } = require('keyv-file'); -const { logger } = require('~/config'); - -const addToCache = async ({ endpoint, endpointOption, userMessage, responseMessage }) => { - try { - const conversationsCache = new Keyv({ - store: new KeyvFile({ filename: './data/cache.json' }), - namespace: 'chatgpt', // should be 'bing' for bing/sydney - }); - - const { - conversationId, - messageId: userMessageId, - parentMessageId: userParentMessageId, - text: userText, - } = userMessage; - const { - messageId: responseMessageId, - parentMessageId: responseParentMessageId, - text: responseText, - } = responseMessage; - - let conversation = await conversationsCache.get(conversationId); - // used to generate a title for the conversation if none exists - // let isNewConversation = false; - if (!conversation) { - conversation = { - messages: [], - createdAt: Date.now(), - }; - // isNewConversation = true; - } - - const roles = (options) => { - if (endpoint === 'openAI') { - return options?.chatGptLabel || 'ChatGPT'; - } - }; - - let _userMessage = { - id: userMessageId, - parentMessageId: userParentMessageId, - role: 'User', - message: userText, - }; - - let _responseMessage = { - id: responseMessageId, - parentMessageId: responseParentMessageId, - role: roles(endpointOption), - message: responseText, - }; - - conversation.messages.push(_userMessage, _responseMessage); - - await conversationsCache.set(conversationId, conversation); - } catch (error) { - logger.error('[addToCache] Error adding conversation to cache', error); - } -}; - -module.exports = addToCache; diff --git a/api/server/routes/ask/anthropic.js b/api/server/routes/ask/anthropic.js deleted file mode 100644 index afe1720d84..0000000000 --- a/api/server/routes/ask/anthropic.js +++ /dev/null @@ -1,25 +0,0 @@ -const express = require('express'); -const AskController = require('~/server/controllers/AskController'); -const { addTitle, initializeClient } = require('~/server/services/Endpoints/anthropic'); -const { - setHeaders, - handleAbort, - validateModel, - validateEndpoint, - buildEndpointOption, -} = require('~/server/middleware'); - -const router = express.Router(); - -router.post( - '/', - validateEndpoint, - validateModel, - buildEndpointOption, - setHeaders, - async (req, res, next) => { - await AskController(req, res, next, initializeClient, addTitle); - }, -); - -module.exports = router; diff --git a/api/server/routes/ask/custom.js b/api/server/routes/ask/custom.js deleted file mode 100644 index 8fc343cf17..0000000000 --- a/api/server/routes/ask/custom.js +++ /dev/null @@ -1,25 +0,0 @@ -const express = require('express'); -const AskController = require('~/server/controllers/AskController'); -const { initializeClient } = require('~/server/services/Endpoints/custom'); -const { addTitle } = require('~/server/services/Endpoints/openAI'); -const { - setHeaders, - validateModel, - validateEndpoint, - buildEndpointOption, -} = require('~/server/middleware'); - -const router = express.Router(); - -router.post( - '/', - validateEndpoint, - validateModel, - buildEndpointOption, - setHeaders, - async (req, res, next) => { - await AskController(req, res, next, initializeClient, addTitle); - }, -); - -module.exports = router; diff --git a/api/server/routes/ask/google.js b/api/server/routes/ask/google.js deleted file mode 100644 index 16c7e265f4..0000000000 --- a/api/server/routes/ask/google.js +++ /dev/null @@ -1,24 +0,0 @@ -const express = require('express'); -const AskController = require('~/server/controllers/AskController'); -const { initializeClient, addTitle } = require('~/server/services/Endpoints/google'); -const { - setHeaders, - validateModel, - validateEndpoint, - buildEndpointOption, -} = require('~/server/middleware'); - -const router = express.Router(); - -router.post( - '/', - validateEndpoint, - validateModel, - buildEndpointOption, - setHeaders, - async (req, res, next) => { - await AskController(req, res, next, initializeClient, addTitle); - }, -); - -module.exports = router; diff --git a/api/server/routes/ask/gptPlugins.js b/api/server/routes/ask/gptPlugins.js deleted file mode 100644 index a40022848a..0000000000 --- a/api/server/routes/ask/gptPlugins.js +++ /dev/null @@ -1,241 +0,0 @@ -const express = require('express'); -const { getResponseSender, Constants } = require('librechat-data-provider'); -const { initializeClient } = require('~/server/services/Endpoints/gptPlugins'); -const { sendMessage, createOnProgress } = require('~/server/utils'); -const { addTitle } = require('~/server/services/Endpoints/openAI'); -const { saveMessage, updateMessage } = require('~/models'); -const { - handleAbort, - createAbortController, - handleAbortError, - setHeaders, - validateModel, - validateEndpoint, - buildEndpointOption, - moderateText, -} = require('~/server/middleware'); -const { validateTools } = require('~/app'); -const { logger } = require('~/config'); - -const router = express.Router(); - -router.use(moderateText); - -router.post( - '/', - validateEndpoint, - validateModel, - buildEndpointOption, - setHeaders, - async (req, res) => { - let { - text, - endpointOption, - conversationId, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - - logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption }); - - let userMessage; - let userMessagePromise; - let promptTokens; - let userMessageId; - let responseMessageId; - const sender = getResponseSender({ - ...endpointOption, - model: endpointOption.modelOptions.model, - }); - const newConvo = !conversationId; - const user = req.user.id; - - const plugins = []; - - const getReqData = (data = {}) => { - for (let key in data) { - if (key === 'userMessage') { - userMessage = data[key]; - userMessageId = data[key].messageId; - } else if (key === 'userMessagePromise') { - userMessagePromise = data[key]; - } else if (key === 'responseMessageId') { - responseMessageId = data[key]; - } else if (key === 'promptTokens') { - promptTokens = data[key]; - } else if (!conversationId && key === 'conversationId') { - conversationId = data[key]; - } - } - }; - - let streaming = null; - let timer = null; - - const { - onProgress: progressCallback, - sendIntermediateMessage, - getPartialText, - } = createOnProgress({ - onProgress: () => { - if (timer) { - clearTimeout(timer); - } - - streaming = new Promise((resolve) => { - timer = setTimeout(() => { - resolve(); - }, 250); - }); - }, - }); - - const pluginMap = new Map(); - const onAgentAction = async (action, runId) => { - pluginMap.set(runId, action.tool); - sendIntermediateMessage(res, { - plugins, - parentMessageId: userMessage.messageId, - messageId: responseMessageId, - }); - }; - - const onToolStart = async (tool, input, runId, parentRunId) => { - const pluginName = pluginMap.get(parentRunId); - const latestPlugin = { - runId, - loading: true, - inputs: [input], - latest: pluginName, - outputs: null, - }; - - if (streaming) { - await streaming; - } - const extraTokens = ':::plugin:::\n'; - plugins.push(latestPlugin); - sendIntermediateMessage( - res, - { plugins, parentMessageId: userMessage.messageId, messageId: responseMessageId }, - extraTokens, - ); - }; - - const onToolEnd = async (output, runId) => { - if (streaming) { - await streaming; - } - - const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId); - - if (pluginIndex !== -1) { - plugins[pluginIndex].loading = false; - plugins[pluginIndex].outputs = output; - } - }; - - const getAbortData = () => ({ - sender, - conversationId, - userMessagePromise, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - plugins: plugins.map((p) => ({ ...p, loading: false })), - userMessage, - promptTokens, - }); - const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData); - - try { - endpointOption.tools = await validateTools(user, endpointOption.tools); - const { client } = await initializeClient({ req, res, endpointOption }); - - const onChainEnd = () => { - if (!client.skipSaveUserMessage) { - saveMessage( - req, - { ...userMessage, user }, - { context: 'api/server/routes/ask/gptPlugins.js - onChainEnd' }, - ); - } - sendIntermediateMessage(res, { - plugins, - parentMessageId: userMessage.messageId, - messageId: responseMessageId, - }); - }; - - let response = await client.sendMessage(text, { - user, - conversationId, - parentMessageId, - overrideParentMessageId, - getReqData, - onAgentAction, - onChainEnd, - onToolStart, - onToolEnd, - onStart, - getPartialText, - ...endpointOption, - progressCallback, - progressOptions: { - res, - // parentMessageId: overrideParentMessageId || userMessageId, - plugins, - }, - abortController, - }); - - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } - - logger.debug('[/ask/gptPlugins]', response); - - const { conversation = {} } = await response.databasePromise; - delete response.databasePromise; - conversation.title = - conversation && !conversation.title ? null : conversation?.title || 'New Chat'; - - sendMessage(res, { - title: conversation.title, - final: true, - conversation, - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); - - if (parentMessageId === Constants.NO_PARENT && newConvo) { - addTitle(req, { - text, - response, - client, - }); - } - - response.plugins = plugins.map((p) => ({ ...p, loading: false })); - if (response.plugins?.length > 0) { - await updateMessage( - req, - { ...response, user }, - { context: 'api/server/routes/ask/gptPlugins.js - save plugins used' }, - ); - } - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender, - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, - }); - } - }, -); - -module.exports = router; diff --git a/api/server/routes/ask/index.js b/api/server/routes/ask/index.js deleted file mode 100644 index 525bd8e29d..0000000000 --- a/api/server/routes/ask/index.js +++ /dev/null @@ -1,47 +0,0 @@ -const express = require('express'); -const { EModelEndpoint } = require('librechat-data-provider'); -const { - uaParser, - checkBan, - requireJwtAuth, - messageIpLimiter, - concurrentLimiter, - messageUserLimiter, - validateConvoAccess, -} = require('~/server/middleware'); -const { isEnabled } = require('~/server/utils'); -const gptPlugins = require('./gptPlugins'); -const anthropic = require('./anthropic'); -const custom = require('./custom'); -const google = require('./google'); -const openAI = require('./openAI'); - -const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {}; - -const router = express.Router(); - -router.use(requireJwtAuth); -router.use(checkBan); -router.use(uaParser); - -if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) { - router.use(concurrentLimiter); -} - -if (isEnabled(LIMIT_MESSAGE_IP)) { - router.use(messageIpLimiter); -} - -if (isEnabled(LIMIT_MESSAGE_USER)) { - router.use(messageUserLimiter); -} - -router.use(validateConvoAccess); - -router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI); -router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins); -router.use(`/${EModelEndpoint.anthropic}`, anthropic); -router.use(`/${EModelEndpoint.google}`, google); -router.use(`/${EModelEndpoint.custom}`, custom); - -module.exports = router; diff --git a/api/server/routes/ask/openAI.js b/api/server/routes/ask/openAI.js deleted file mode 100644 index dadf00def4..0000000000 --- a/api/server/routes/ask/openAI.js +++ /dev/null @@ -1,27 +0,0 @@ -const express = require('express'); -const AskController = require('~/server/controllers/AskController'); -const { addTitle, initializeClient } = require('~/server/services/Endpoints/openAI'); -const { - handleAbort, - setHeaders, - validateModel, - validateEndpoint, - buildEndpointOption, - moderateText, -} = require('~/server/middleware'); - -const router = express.Router(); -router.use(moderateText); - -router.post( - '/', - validateEndpoint, - validateModel, - buildEndpointOption, - setHeaders, - async (req, res, next) => { - await AskController(req, res, next, initializeClient, addTitle); - }, -); - -module.exports = router; diff --git a/api/server/routes/bedrock/chat.js b/api/server/routes/bedrock/chat.js deleted file mode 100644 index 263ca96002..0000000000 --- a/api/server/routes/bedrock/chat.js +++ /dev/null @@ -1,37 +0,0 @@ -const express = require('express'); - -const router = express.Router(); -const { - setHeaders, - handleAbort, - moderateText, - // validateModel, - // validateEndpoint, - buildEndpointOption, -} = require('~/server/middleware'); -const { initializeClient } = require('~/server/services/Endpoints/bedrock'); -const AgentController = require('~/server/controllers/agents/request'); -const addTitle = require('~/server/services/Endpoints/agents/title'); - -router.use(moderateText); - -/** - * @route POST / - * @desc Chat with an assistant - * @access Public - * @param {express.Request} req - The request object, containing the request data. - * @param {express.Response} res - The response object, used to send back a response. - * @returns {void} - */ -router.post( - '/', - // validateModel, - // validateEndpoint, - buildEndpointOption, - setHeaders, - async (req, res, next) => { - await AgentController(req, res, next, initializeClient, addTitle); - }, -); - -module.exports = router; diff --git a/api/server/routes/bedrock/index.js b/api/server/routes/bedrock/index.js deleted file mode 100644 index ce440a7c0e..0000000000 --- a/api/server/routes/bedrock/index.js +++ /dev/null @@ -1,35 +0,0 @@ -const express = require('express'); -const { - uaParser, - checkBan, - requireJwtAuth, - messageIpLimiter, - concurrentLimiter, - messageUserLimiter, -} = require('~/server/middleware'); -const { isEnabled } = require('~/server/utils'); -const chat = require('./chat'); - -const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {}; - -const router = express.Router(); - -router.use(requireJwtAuth); -router.use(checkBan); -router.use(uaParser); - -if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) { - router.use(concurrentLimiter); -} - -if (isEnabled(LIMIT_MESSAGE_IP)) { - router.use(messageIpLimiter); -} - -if (isEnabled(LIMIT_MESSAGE_USER)) { - router.use(messageUserLimiter); -} - -router.use('/chat', chat); - -module.exports = router; diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index eb7e2c5c27..18dbf8db0a 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -1,16 +1,17 @@ const multer = require('multer'); const express = require('express'); +const { sleep } = require('@librechat/agents'); +const { isEnabled } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); const { getConvosByCursor, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation'); const { forkConversation, duplicateConversation } = require('~/server/utils/import/fork'); +const { createImportLimiters, createForkLimiters } = require('~/server/middleware'); const { storage, importFileFilter } = require('~/server/routes/files/multer'); const requireJwtAuth = require('~/server/middleware/requireJwtAuth'); const { importConversations } = require('~/server/utils/import'); -const { createImportLimiters } = require('~/server/middleware'); const { deleteToolCalls } = require('~/models/ToolCall'); -const { isEnabled, sleep } = require('~/server/utils'); const getLogStores = require('~/cache/getLogStores'); -const { logger } = require('~/config'); const assistantClients = { [EModelEndpoint.azureAssistants]: require('~/server/services/Endpoints/azureAssistants'), @@ -43,6 +44,7 @@ router.get('/', async (req, res) => { }); res.status(200).json(result); } catch (error) { + logger.error('Error fetching conversations', error); res.status(500).json({ error: 'Error fetching conversations' }); } }); @@ -156,6 +158,7 @@ router.post('/update', async (req, res) => { }); const { importIpLimiter, importUserLimiter } = createImportLimiters(); +const { forkIpLimiter, forkUserLimiter } = createForkLimiters(); const upload = multer({ storage: storage, fileFilter: importFileFilter }); /** @@ -189,7 +192,7 @@ router.post( * @param {express.Response} res - Express response object. * @returns {Promise} - The response after forking the conversation. */ -router.post('/fork', async (req, res) => { +router.post('/fork', forkIpLimiter, forkUserLimiter, async (req, res) => { try { /** @type {TForkConvoRequest} */ const { conversationId, messageId, option, splitAtTarget, latestMessageId } = req.body; diff --git a/api/server/routes/edit/gptPlugins.js b/api/server/routes/edit/gptPlugins.js deleted file mode 100644 index 94d9b91d0b..0000000000 --- a/api/server/routes/edit/gptPlugins.js +++ /dev/null @@ -1,207 +0,0 @@ -const express = require('express'); -const { getResponseSender } = require('librechat-data-provider'); -const { - setHeaders, - moderateText, - validateModel, - handleAbortError, - validateEndpoint, - buildEndpointOption, - createAbortController, -} = require('~/server/middleware'); -const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils'); -const { initializeClient } = require('~/server/services/Endpoints/gptPlugins'); -const { saveMessage, updateMessage } = require('~/models'); -const { validateTools } = require('~/app'); -const { logger } = require('~/config'); - -const router = express.Router(); - -router.use(moderateText); - -router.post( - '/', - validateEndpoint, - validateModel, - buildEndpointOption, - setHeaders, - async (req, res) => { - let { - text, - generation, - endpointOption, - conversationId, - responseMessageId, - isContinued = false, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - - logger.debug('[/edit/gptPlugins]', { - text, - generation, - isContinued, - conversationId, - ...endpointOption, - }); - - let userMessage; - let userMessagePromise; - let promptTokens; - const sender = getResponseSender({ - ...endpointOption, - model: endpointOption.modelOptions.model, - }); - const userMessageId = parentMessageId; - const user = req.user.id; - - const plugin = { - loading: true, - inputs: [], - latest: null, - outputs: null, - }; - - const getReqData = (data = {}) => { - for (let key in data) { - if (key === 'userMessage') { - userMessage = data[key]; - } else if (key === 'userMessagePromise') { - userMessagePromise = data[key]; - } else if (key === 'responseMessageId') { - responseMessageId = data[key]; - } else if (key === 'promptTokens') { - promptTokens = data[key]; - } - } - }; - - const { - onProgress: progressCallback, - sendIntermediateMessage, - getPartialText, - } = createOnProgress({ - generation, - onProgress: () => { - if (plugin.loading === true) { - plugin.loading = false; - } - }, - }); - - const onChainEnd = (data) => { - let { intermediateSteps: steps } = data; - plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.'; - plugin.loading = false; - saveMessage( - req, - { ...userMessage, user }, - { context: 'api/server/routes/ask/gptPlugins.js - onChainEnd' }, - ); - sendIntermediateMessage(res, { - plugin, - parentMessageId: userMessage.messageId, - messageId: responseMessageId, - }); - // logger.debug('CHAIN END', plugin.outputs); - }; - - const getAbortData = () => ({ - sender, - conversationId, - userMessagePromise, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - plugin: { ...plugin, loading: false }, - userMessage, - promptTokens, - }); - const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData); - - try { - endpointOption.tools = await validateTools(user, endpointOption.tools); - const { client } = await initializeClient({ req, res, endpointOption }); - - const onAgentAction = (action, start = false) => { - const formattedAction = formatAction(action); - plugin.inputs.push(formattedAction); - plugin.latest = formattedAction.plugin; - if (!start && !client.skipSaveUserMessage) { - saveMessage( - req, - { ...userMessage, user }, - { context: 'api/server/routes/ask/gptPlugins.js - onAgentAction' }, - ); - } - sendIntermediateMessage(res, { - plugin, - parentMessageId: userMessage.messageId, - messageId: responseMessageId, - }); - // logger.debug('PLUGIN ACTION', formattedAction); - }; - - let response = await client.sendMessage(text, { - user, - generation, - isContinued, - isEdited: true, - conversationId, - parentMessageId, - responseMessageId, - overrideParentMessageId, - getReqData, - onAgentAction, - onChainEnd, - onStart, - ...endpointOption, - progressCallback, - progressOptions: { - res, - plugin, - // parentMessageId: overrideParentMessageId || userMessageId, - }, - abortController, - }); - - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } - - logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response); - - const { conversation = {} } = await response.databasePromise; - delete response.databasePromise; - conversation.title = - conversation && !conversation.title ? null : conversation?.title || 'New Chat'; - - sendMessage(res, { - title: conversation.title, - final: true, - conversation, - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); - - response.plugin = { ...plugin, loading: false }; - await updateMessage( - req, - { ...response, user }, - { context: 'api/server/routes/edit/gptPlugins.js' }, - ); - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender, - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, - }); - } - }, -); - -module.exports = router; diff --git a/api/server/routes/edit/index.js b/api/server/routes/edit/index.js index f1d47af3f9..92a1e63f63 100644 --- a/api/server/routes/edit/index.js +++ b/api/server/routes/edit/index.js @@ -3,7 +3,6 @@ const openAI = require('./openAI'); const custom = require('./custom'); const google = require('./google'); const anthropic = require('./anthropic'); -const gptPlugins = require('./gptPlugins'); const { isEnabled } = require('~/server/utils'); const { EModelEndpoint } = require('librechat-data-provider'); const { @@ -39,7 +38,6 @@ if (isEnabled(LIMIT_MESSAGE_USER)) { router.use(validateConvoAccess); router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI); -router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins); router.use(`/${EModelEndpoint.anthropic}`, anthropic); router.use(`/${EModelEndpoint.google}`, google); router.use(`/${EModelEndpoint.custom}`, custom); diff --git a/api/server/routes/files/files.js b/api/server/routes/files/files.js index bb2ae0bbe5..bdfdca65cf 100644 --- a/api/server/routes/files/files.js +++ b/api/server/routes/files/files.js @@ -283,7 +283,10 @@ router.post('/', async (req, res) => { message += ': ' + error.message; } - if (error.message?.includes('Invalid file format')) { + if ( + error.message?.includes('Invalid file format') || + error.message?.includes('No OCR result') + ) { message = error.message; } diff --git a/api/server/routes/files/multer.spec.js b/api/server/routes/files/multer.spec.js index 0324262a71..2fb9147aef 100644 --- a/api/server/routes/files/multer.spec.js +++ b/api/server/routes/files/multer.spec.js @@ -477,7 +477,9 @@ describe('Multer Configuration', () => { done(new Error('Expected mkdirSync to throw an error but no error was thrown')); } catch (error) { // This is the expected behavior - mkdirSync throws synchronously for invalid paths - expect(error.code).toBe('EACCES'); + // On Linux, this typically returns EACCES (permission denied) + // On macOS/Darwin, this returns ENOENT (no such file or directory) + expect(['EACCES', 'ENOENT']).toContain(error.code); done(); } }); diff --git a/api/server/routes/index.js b/api/server/routes/index.js index 7c1b5de0fa..ec97ba3986 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -9,7 +9,6 @@ const presets = require('./presets'); const prompts = require('./prompts'); const balance = require('./balance'); const plugins = require('./plugins'); -const bedrock = require('./bedrock'); const actions = require('./actions'); const banner = require('./banner'); const search = require('./search'); @@ -26,11 +25,9 @@ const auth = require('./auth'); const edit = require('./edit'); const keys = require('./keys'); const user = require('./user'); -const ask = require('./ask'); const mcp = require('./mcp'); module.exports = { - ask, edit, auth, keys, @@ -46,7 +43,6 @@ module.exports = { search, config, models, - bedrock, prompts, plugins, actions, diff --git a/api/server/routes/memories.js b/api/server/routes/memories.js index 86065fecaa..a136bc8e61 100644 --- a/api/server/routes/memories.js +++ b/api/server/routes/memories.js @@ -1,37 +1,43 @@ const express = require('express'); -const { Tokenizer } = require('@librechat/api'); +const { Tokenizer, generateCheckAccess } = require('@librechat/api'); const { PermissionTypes, Permissions } = require('librechat-data-provider'); const { getAllUserMemories, toggleUserMemories, createMemory, - setMemory, deleteMemory, + setMemory, } = require('~/models'); -const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware'); +const { requireJwtAuth } = require('~/server/middleware'); +const { getRoleByName } = require('~/models/Role'); const router = express.Router(); -const checkMemoryRead = generateCheckAccess(PermissionTypes.MEMORIES, [ - Permissions.USE, - Permissions.READ, -]); -const checkMemoryCreate = generateCheckAccess(PermissionTypes.MEMORIES, [ - Permissions.USE, - Permissions.CREATE, -]); -const checkMemoryUpdate = generateCheckAccess(PermissionTypes.MEMORIES, [ - Permissions.USE, - Permissions.UPDATE, -]); -const checkMemoryDelete = generateCheckAccess(PermissionTypes.MEMORIES, [ - Permissions.USE, - Permissions.UPDATE, -]); -const checkMemoryOptOut = generateCheckAccess(PermissionTypes.MEMORIES, [ - Permissions.USE, - Permissions.OPT_OUT, -]); +const checkMemoryRead = generateCheckAccess({ + permissionType: PermissionTypes.MEMORIES, + permissions: [Permissions.USE, Permissions.READ], + getRoleByName, +}); +const checkMemoryCreate = generateCheckAccess({ + permissionType: PermissionTypes.MEMORIES, + permissions: [Permissions.USE, Permissions.CREATE], + getRoleByName, +}); +const checkMemoryUpdate = generateCheckAccess({ + permissionType: PermissionTypes.MEMORIES, + permissions: [Permissions.USE, Permissions.UPDATE], + getRoleByName, +}); +const checkMemoryDelete = generateCheckAccess({ + permissionType: PermissionTypes.MEMORIES, + permissions: [Permissions.USE, Permissions.UPDATE], + getRoleByName, +}); +const checkMemoryOptOut = generateCheckAccess({ + permissionType: PermissionTypes.MEMORIES, + permissions: [Permissions.USE, Permissions.OPT_OUT], + getRoleByName, +}); router.use(requireJwtAuth); @@ -166,40 +172,68 @@ router.patch('/preferences', checkMemoryOptOut, async (req, res) => { /** * PATCH /memories/:key * Updates the value of an existing memory entry for the authenticated user. - * Body: { value: string } + * Body: { key?: string, value: string } * Returns 200 and { updated: true, memory: } when successful. */ router.patch('/:key', checkMemoryUpdate, async (req, res) => { - const { key } = req.params; - const { value } = req.body || {}; + const { key: urlKey } = req.params; + const { key: bodyKey, value } = req.body || {}; if (typeof value !== 'string' || value.trim() === '') { return res.status(400).json({ error: 'Value is required and must be a non-empty string.' }); } + // Use the key from the body if provided, otherwise use the key from the URL + const newKey = bodyKey || urlKey; + try { const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base'); const memories = await getAllUserMemories(req.user.id); - const existingMemory = memories.find((m) => m.key === key); + const existingMemory = memories.find((m) => m.key === urlKey); if (!existingMemory) { return res.status(404).json({ error: 'Memory not found.' }); } - const result = await setMemory({ - userId: req.user.id, - key, - value, - tokenCount, - }); + // If the key is changing, we need to handle it specially + if (newKey !== urlKey) { + const keyExists = memories.find((m) => m.key === newKey); + if (keyExists) { + return res.status(409).json({ error: 'Memory with this key already exists.' }); + } - if (!result.ok) { - return res.status(500).json({ error: 'Failed to update memory.' }); + const createResult = await createMemory({ + userId: req.user.id, + key: newKey, + value, + tokenCount, + }); + + if (!createResult.ok) { + return res.status(500).json({ error: 'Failed to create new memory.' }); + } + + const deleteResult = await deleteMemory({ userId: req.user.id, key: urlKey }); + if (!deleteResult.ok) { + return res.status(500).json({ error: 'Failed to delete old memory.' }); + } + } else { + // Key is not changing, just update the value + const result = await setMemory({ + userId: req.user.id, + key: newKey, + value, + tokenCount, + }); + + if (!result.ok) { + return res.status(500).json({ error: 'Failed to update memory.' }); + } } const updatedMemories = await getAllUserMemories(req.user.id); - const updatedMemory = updatedMemories.find((m) => m.key === key); + const updatedMemory = updatedMemories.find((m) => m.key === newKey); res.json({ updated: true, memory: updatedMemory }); } catch (error) { diff --git a/api/server/routes/messages.js b/api/server/routes/messages.js index 356dd25097..0a277a1bd6 100644 --- a/api/server/routes/messages.js +++ b/api/server/routes/messages.js @@ -235,12 +235,13 @@ router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) = return res.status(400).json({ error: 'Content part not found' }); } - if (updatedContent[index].type !== ContentTypes.TEXT) { + const currentPartType = updatedContent[index].type; + if (currentPartType !== ContentTypes.TEXT && currentPartType !== ContentTypes.THINK) { return res.status(400).json({ error: 'Cannot update non-text content' }); } - const oldText = updatedContent[index].text; - updatedContent[index] = { type: ContentTypes.TEXT, text }; + const oldText = updatedContent[index][currentPartType]; + updatedContent[index] = { type: currentPartType, [currentPartType]: text }; let tokenCount = message.tokenCount; if (tokenCount !== undefined) { diff --git a/api/server/routes/prompts.js b/api/server/routes/prompts.js index 499c4cc9ea..3d4b96aa1e 100644 --- a/api/server/routes/prompts.js +++ b/api/server/routes/prompts.js @@ -1,5 +1,7 @@ const express = require('express'); -const { PermissionTypes, Permissions, SystemRoles } = require('librechat-data-provider'); +const { logger } = require('@librechat/data-schemas'); +const { generateCheckAccess } = require('@librechat/api'); +const { Permissions, SystemRoles, PermissionTypes } = require('librechat-data-provider'); const { getPrompt, getPrompts, @@ -16,23 +18,30 @@ const { } = require('~/models/Prompt'); const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware'); const { getUserById, updateUser } = require('~/models'); +const { getRoleByName } = require('~/models/Role'); const { logger } = require('~/config'); const router = express.Router(); -const checkPromptAccess = generateCheckAccess(PermissionTypes.PROMPTS, [Permissions.USE]); -const checkPromptCreate = generateCheckAccess(PermissionTypes.PROMPTS, [ - Permissions.USE, - Permissions.CREATE, -]); +const checkPromptAccess = generateCheckAccess({ + permissionType: PermissionTypes.PROMPTS, + permissions: [Permissions.USE], + getRoleByName, +}); +const checkPromptCreate = generateCheckAccess({ + permissionType: PermissionTypes.PROMPTS, + permissions: [Permissions.USE, Permissions.CREATE], + getRoleByName, +}); -const checkGlobalPromptShare = generateCheckAccess( - PermissionTypes.PROMPTS, - [Permissions.USE, Permissions.CREATE], - { +const checkGlobalPromptShare = generateCheckAccess({ + permissionType: PermissionTypes.PROMPTS, + permissions: [Permissions.USE, Permissions.CREATE], + bodyProps: { [Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'], }, -); + getRoleByName, +}); router.use(requireJwtAuth); router.use(checkPromptAccess); diff --git a/api/server/routes/tags.js b/api/server/routes/tags.js index d3e27d3711..0a4ee5084c 100644 --- a/api/server/routes/tags.js +++ b/api/server/routes/tags.js @@ -1,18 +1,24 @@ const express = require('express'); +const { logger } = require('@librechat/data-schemas'); +const { generateCheckAccess } = require('@librechat/api'); const { PermissionTypes, Permissions } = require('librechat-data-provider'); const { - getConversationTags, + updateTagsForConversation, updateConversationTag, createConversationTag, deleteConversationTag, - updateTagsForConversation, + getConversationTags, } = require('~/models/ConversationTag'); -const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware'); -const { logger } = require('~/config'); +const { requireJwtAuth } = require('~/server/middleware'); +const { getRoleByName } = require('~/models/Role'); const router = express.Router(); -const checkBookmarkAccess = generateCheckAccess(PermissionTypes.BOOKMARKS, [Permissions.USE]); +const checkBookmarkAccess = generateCheckAccess({ + permissionType: PermissionTypes.BOOKMARKS, + permissions: [Permissions.USE], + getRoleByName, +}); router.use(requireJwtAuth); router.use(checkBookmarkAccess); diff --git a/api/server/services/AppService.spec.js b/api/server/services/AppService.spec.js index 7edccc2c0d..678e8a90db 100644 --- a/api/server/services/AppService.spec.js +++ b/api/server/services/AppService.spec.js @@ -152,12 +152,14 @@ describe('AppService', () => { filteredTools: undefined, includedTools: undefined, webSearch: { + safeSearch: 1, + jinaApiKey: '${JINA_API_KEY}', cohereApiKey: '${COHERE_API_KEY}', + serperApiKey: '${SERPER_API_KEY}', + searxngApiKey: '${SEARXNG_API_KEY}', firecrawlApiKey: '${FIRECRAWL_API_KEY}', firecrawlApiUrl: '${FIRECRAWL_API_URL}', - jinaApiKey: '${JINA_API_KEY}', - safeSearch: 1, - serperApiKey: '${SERPER_API_KEY}', + searxngInstanceUrl: '${SEARXNG_INSTANCE_URL}', }, memory: undefined, agents: { diff --git a/api/server/services/AssistantService.js b/api/server/services/AssistantService.js index 2db0a56b6b..5354b2e33a 100644 --- a/api/server/services/AssistantService.js +++ b/api/server/services/AssistantService.js @@ -1,4 +1,7 @@ const { klona } = require('klona'); +const { sleep } = require('@librechat/agents'); +const { sendEvent } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { StepTypes, RunStatus, @@ -11,11 +14,10 @@ const { } = require('librechat-data-provider'); const { retrieveAndProcessFile } = require('~/server/services/Files/process'); const { processRequiredActions } = require('~/server/services/ToolService'); -const { createOnProgress, sendMessage, sleep } = require('~/server/utils'); const { RunManager, waitForRun } = require('~/server/services/Runs'); const { processMessages } = require('~/server/services/Threads'); +const { createOnProgress } = require('~/server/utils'); const { TextStream } = require('~/app/clients'); -const { logger } = require('~/config'); /** * Sorts, processes, and flattens messages to a single string. @@ -64,7 +66,7 @@ async function createOnTextProgress({ }; logger.debug('Content data:', contentData); - sendMessage(openai.res, contentData); + sendEvent(openai.res, contentData); }; } diff --git a/api/server/services/AuthService.js b/api/server/services/AuthService.js index 6061277437..8c7cbf7d92 100644 --- a/api/server/services/AuthService.js +++ b/api/server/services/AuthService.js @@ -1,4 +1,5 @@ const bcrypt = require('bcryptjs'); +const jwt = require('jsonwebtoken'); const { webcrypto } = require('node:crypto'); const { isEnabled } = require('@librechat/api'); const { logger } = require('@librechat/data-schemas'); @@ -499,6 +500,18 @@ const resendVerificationEmail = async (req) => { }; } }; +/** + * Generate a short-lived JWT token + * @param {String} userId - The ID of the user + * @param {String} [expireIn='5m'] - The expiration time for the token (default is 5 minutes) + * @returns {String} - The generated JWT token + */ +const generateShortLivedToken = (userId, expireIn = '5m') => { + return jwt.sign({ id: userId }, process.env.JWT_SECRET, { + expiresIn: expireIn, + algorithm: 'HS256', + }); +}; module.exports = { logoutUser, @@ -506,7 +519,8 @@ module.exports = { registerUser, setAuthTokens, resetPassword, + setOpenIDAuthTokens, requestPasswordReset, resendVerificationEmail, - setOpenIDAuthTokens, + generateShortLivedToken, }; diff --git a/api/server/services/Config/EndpointService.js b/api/server/services/Config/EndpointService.js index 1f38b70a62..d8277dd67f 100644 --- a/api/server/services/Config/EndpointService.js +++ b/api/server/services/Config/EndpointService.js @@ -1,5 +1,6 @@ +const { isUserProvided } = require('@librechat/api'); const { EModelEndpoint } = require('librechat-data-provider'); -const { isUserProvided, generateConfig } = require('~/server/utils'); +const { generateConfig } = require('~/server/utils/handleText'); const { OPENAI_API_KEY: openAIApiKey, diff --git a/api/server/services/Config/getCustomConfig.js b/api/server/services/Config/getCustomConfig.js index d1ee5c3278..f3fb6f26b4 100644 --- a/api/server/services/Config/getCustomConfig.js +++ b/api/server/services/Config/getCustomConfig.js @@ -40,6 +40,7 @@ async function getBalanceConfig() { /** * * @param {string | EModelEndpoint} endpoint + * @returns {Promise} */ const getCustomEndpointConfig = async (endpoint) => { const customConfig = await getCustomConfig(); diff --git a/api/server/services/Config/getEndpointsConfig.js b/api/server/services/Config/getEndpointsConfig.js index 8ae022e4b3..670bc22d11 100644 --- a/api/server/services/Config/getEndpointsConfig.js +++ b/api/server/services/Config/getEndpointsConfig.js @@ -1,4 +1,10 @@ -const { CacheKeys, EModelEndpoint, orderEndpointsConfig } = require('librechat-data-provider'); +const { + CacheKeys, + EModelEndpoint, + isAgentsEndpoint, + orderEndpointsConfig, + defaultAgentCapabilities, +} = require('librechat-data-provider'); const loadDefaultEndpointsConfig = require('./loadDefaultEConfig'); const loadConfigEndpoints = require('./loadConfigEndpoints'); const getLogStores = require('~/cache/getLogStores'); @@ -80,8 +86,12 @@ async function getEndpointsConfig(req) { * @returns {Promise} */ const checkCapability = async (req, capability) => { + const isAgents = isAgentsEndpoint(req.body?.original_endpoint || req.body?.endpoint); const endpointsConfig = await getEndpointsConfig(req); - const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []; + const capabilities = + isAgents || endpointsConfig?.[EModelEndpoint.agents]?.capabilities != null + ? (endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []) + : defaultAgentCapabilities; return capabilities.includes(capability); }; diff --git a/api/server/services/Config/loadAsyncEndpoints.js b/api/server/services/Config/loadAsyncEndpoints.js index 0282146cd1..b88744e9ad 100644 --- a/api/server/services/Config/loadAsyncEndpoints.js +++ b/api/server/services/Config/loadAsyncEndpoints.js @@ -1,5 +1,7 @@ +const path = require('path'); +const { logger } = require('@librechat/data-schemas'); +const { loadServiceKey, isUserProvided } = require('@librechat/api'); const { EModelEndpoint } = require('librechat-data-provider'); -const { isUserProvided } = require('~/server/utils'); const { config } = require('./EndpointService'); const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, googleKey } = config; @@ -9,37 +11,41 @@ const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, go * @param {Express.Request} req - The request object */ async function loadAsyncEndpoints(req) { - let i = 0; let serviceKey, googleUserProvides; - try { - serviceKey = require('~/data/auth.json'); - } catch (e) { - if (i === 0) { - i++; + + /** Check if GOOGLE_KEY is provided at all(including 'user_provided') */ + const isGoogleKeyProvided = googleKey && googleKey.trim() !== ''; + + if (isGoogleKeyProvided) { + /** If GOOGLE_KEY is provided, check if it's user_provided */ + googleUserProvides = isUserProvided(googleKey); + } else { + /** Only attempt to load service key if GOOGLE_KEY is not provided */ + const serviceKeyPath = + process.env.GOOGLE_SERVICE_KEY_FILE || path.join(__dirname, '../../..', 'data', 'auth.json'); + + try { + serviceKey = await loadServiceKey(serviceKeyPath); + } catch (error) { + logger.error('Error loading service key', error); + serviceKey = null; } } - if (isUserProvided(googleKey)) { - googleUserProvides = true; - if (i <= 1) { - i++; - } - } - - const google = serviceKey || googleKey ? { userProvide: googleUserProvides } : false; + const google = serviceKey || isGoogleKeyProvided ? { userProvide: googleUserProvides } : false; const useAzure = req.app.locals[EModelEndpoint.azureOpenAI]?.plugins; const gptPlugins = useAzure || openAIApiKey || azureOpenAIApiKey ? { - availableAgents: ['classic', 'functions'], - userProvide: useAzure ? false : userProvidedOpenAI, - userProvideURL: useAzure - ? false - : config[EModelEndpoint.openAI]?.userProvideURL || + availableAgents: ['classic', 'functions'], + userProvide: useAzure ? false : userProvidedOpenAI, + userProvideURL: useAzure + ? false + : config[EModelEndpoint.openAI]?.userProvideURL || config[EModelEndpoint.azureOpenAI]?.userProvideURL, - azure: useAzurePlugins || useAzure, - } + azure: useAzurePlugins || useAzure, + } : false; return { google, gptPlugins }; diff --git a/api/server/services/Config/loadCustomConfig.js b/api/server/services/Config/loadCustomConfig.js index 18f3a44748..393281daf2 100644 --- a/api/server/services/Config/loadCustomConfig.js +++ b/api/server/services/Config/loadCustomConfig.js @@ -1,18 +1,18 @@ const path = require('path'); -const { - CacheKeys, - configSchema, - EImageOutputType, - validateSettingDefinitions, - agentParamSettings, - paramSettings, -} = require('librechat-data-provider'); -const getLogStores = require('~/cache/getLogStores'); -const loadYaml = require('~/utils/loadYaml'); -const { logger } = require('~/config'); const axios = require('axios'); const yaml = require('js-yaml'); const keyBy = require('lodash/keyBy'); +const { loadYaml } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); +const { + CacheKeys, + configSchema, + paramSettings, + EImageOutputType, + agentParamSettings, + validateSettingDefinitions, +} = require('librechat-data-provider'); +const getLogStores = require('~/cache/getLogStores'); const projectRoot = path.resolve(__dirname, '..', '..', '..', '..'); const defaultConfigPath = path.resolve(projectRoot, 'librechat.yaml'); diff --git a/api/server/services/Config/loadCustomConfig.spec.js b/api/server/services/Config/loadCustomConfig.spec.js index ed698e57f1..9b905181c5 100644 --- a/api/server/services/Config/loadCustomConfig.spec.js +++ b/api/server/services/Config/loadCustomConfig.spec.js @@ -1,6 +1,9 @@ jest.mock('axios'); jest.mock('~/cache/getLogStores'); -jest.mock('~/utils/loadYaml'); +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + loadYaml: jest.fn(), +})); jest.mock('librechat-data-provider', () => { const actual = jest.requireActual('librechat-data-provider'); return { @@ -30,11 +33,22 @@ jest.mock('librechat-data-provider', () => { }; }); +jest.mock('@librechat/data-schemas', () => { + return { + logger: { + info: jest.fn(), + warn: jest.fn(), + debug: jest.fn(), + error: jest.fn(), + }, + }; +}); + const axios = require('axios'); +const { loadYaml } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const loadCustomConfig = require('./loadCustomConfig'); const getLogStores = require('~/cache/getLogStores'); -const loadYaml = require('~/utils/loadYaml'); -const { logger } = require('~/config'); describe('loadCustomConfig', () => { const mockSet = jest.fn(); diff --git a/api/server/services/Endpoints/agents/agent.js b/api/server/services/Endpoints/agents/agent.js index e135401467..6fde14b366 100644 --- a/api/server/services/Endpoints/agents/agent.js +++ b/api/server/services/Endpoints/agents/agent.js @@ -1,5 +1,9 @@ const { Providers } = require('@librechat/agents'); -const { primeResources, optionalChainWithEmptyCheck } = require('@librechat/api'); +const { + primeResources, + extractLibreChatParams, + optionalChainWithEmptyCheck, +} = require('@librechat/api'); const { ErrorTypes, EModelEndpoint, @@ -7,30 +11,12 @@ const { replaceSpecialVars, providerEndpointMap, } = require('librechat-data-provider'); -const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize'); -const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options'); -const initOpenAI = require('~/server/services/Endpoints/openAI/initialize'); -const initCustom = require('~/server/services/Endpoints/custom/initialize'); -const initGoogle = require('~/server/services/Endpoints/google/initialize'); +const { getProviderConfig } = require('~/server/services/Endpoints'); const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); -const { getCustomEndpointConfig } = require('~/server/services/Config'); const { processFiles } = require('~/server/services/Files/process'); +const { getFiles, getToolFilesByIds } = require('~/models/File'); const { getConvoFiles } = require('~/models/Conversation'); -const { getToolFilesByIds } = require('~/models/File'); const { getModelMaxTokens } = require('~/utils'); -const { getFiles } = require('~/models/File'); - -const providerConfigMap = { - [Providers.XAI]: initCustom, - [Providers.OLLAMA]: initCustom, - [Providers.DEEPSEEK]: initCustom, - [Providers.OPENROUTER]: initCustom, - [EModelEndpoint.openAI]: initOpenAI, - [EModelEndpoint.google]: initGoogle, - [EModelEndpoint.azureOpenAI]: initOpenAI, - [EModelEndpoint.anthropic]: initAnthropic, - [EModelEndpoint.bedrock]: getBedrockOptions, -}; /** * @param {object} params @@ -71,7 +57,7 @@ const initializeAgent = async ({ ), ); - const { resendFiles = true, ...modelOptions } = _modelOptions; + const { resendFiles, maxContextTokens, modelOptions } = extractLibreChatParams(_modelOptions); if (isInitialAgent && conversationId != null && resendFiles) { const fileIds = (await getConvoFiles(conversationId)) ?? []; @@ -99,7 +85,7 @@ const initializeAgent = async ({ }); const provider = agent.provider; - const { tools, toolContextMap } = + const { tools: structuredTools, toolContextMap } = (await loadTools?.({ req, res, @@ -111,17 +97,9 @@ const initializeAgent = async ({ })) ?? {}; agent.endpoint = provider; - let getOptions = providerConfigMap[provider]; - if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) { - agent.provider = provider.toLowerCase(); - getOptions = providerConfigMap[agent.provider]; - } else if (!getOptions) { - const customEndpointConfig = await getCustomEndpointConfig(provider); - if (!customEndpointConfig) { - throw new Error(`Provider ${provider} not supported`); - } - getOptions = initCustom; - agent.provider = Providers.OPENAI; + const { getOptions, overrideProvider } = await getProviderConfig(provider); + if (overrideProvider) { + agent.provider = overrideProvider; } const _endpointOption = @@ -145,9 +123,8 @@ const initializeAgent = async ({ modelOptions.maxTokens, 0, ); - const maxContextTokens = optionalChainWithEmptyCheck( - modelOptions.maxContextTokens, - modelOptions.max_context_tokens, + const agentMaxContextTokens = optionalChainWithEmptyCheck( + maxContextTokens, getModelMaxTokens(tokensModel, providerEndpointMap[provider]), 4096, ); @@ -163,6 +140,24 @@ const initializeAgent = async ({ agent.provider = options.provider; } + /** @type {import('@librechat/agents').GenericTool[]} */ + let tools = options.tools?.length ? options.tools : structuredTools; + if ( + (agent.provider === Providers.GOOGLE || agent.provider === Providers.VERTEXAI) && + options.tools?.length && + structuredTools?.length + ) { + throw new Error(`{ "type": "${ErrorTypes.GOOGLE_TOOL_CONFLICT}"}`); + } else if ( + (agent.provider === Providers.OPENAI || + agent.provider === Providers.AZURE || + agent.provider === Providers.ANTHROPIC) && + options.tools?.length && + structuredTools?.length + ) { + tools = structuredTools.concat(options.tools); + } + /** @type {import('@librechat/agents').ClientOptions} */ agent.model_parameters = { ...options.llmConfig }; if (options.configOptions) { @@ -185,11 +180,11 @@ const initializeAgent = async ({ return { ...agent, - tools, attachments, resendFiles, toolContextMap, - maxContextTokens: (maxContextTokens - maxTokens) * 0.9, + tools, + maxContextTokens: (agentMaxContextTokens - maxTokens) * 0.9, }; }; diff --git a/api/server/services/Endpoints/agents/build.js b/api/server/services/Endpoints/agents/build.js index 77ebbc58dc..143dde9459 100644 --- a/api/server/services/Endpoints/agents/build.js +++ b/api/server/services/Endpoints/agents/build.js @@ -1,10 +1,9 @@ -const { isAgentsEndpoint, Constants } = require('librechat-data-provider'); +const { isAgentsEndpoint, removeNullishValues, Constants } = require('librechat-data-provider'); const { loadAgent } = require('~/models/Agent'); const { logger } = require('~/config'); const buildOptions = (req, endpoint, parsedBody, endpointType) => { - const { spec, iconURL, agent_id, instructions, maxContextTokens, ...model_parameters } = - parsedBody; + const { spec, iconURL, agent_id, instructions, ...model_parameters } = parsedBody; const agentPromise = loadAgent({ req, agent_id: isAgentsEndpoint(endpoint) ? agent_id : Constants.EPHEMERAL_AGENT_ID, @@ -15,19 +14,16 @@ const buildOptions = (req, endpoint, parsedBody, endpointType) => { return undefined; }); - const endpointOption = { + return removeNullishValues({ spec, iconURL, endpoint, agent_id, endpointType, instructions, - maxContextTokens, model_parameters, agent: agentPromise, - }; - - return endpointOption; + }); }; module.exports = { buildOptions }; diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index e4ffcf4730..94af3bdd3b 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -1,11 +1,17 @@ const { logger } = require('@librechat/data-schemas'); const { createContentAggregator } = require('@librechat/agents'); -const { Constants, EModelEndpoint, getResponseSender } = require('librechat-data-provider'); const { - getDefaultHandlers, + Constants, + EModelEndpoint, + isAgentsEndpoint, + getResponseSender, +} = require('librechat-data-provider'); +const { createToolEndCallback, + getDefaultHandlers, } = require('~/server/controllers/agents/callbacks'); const { initializeAgent } = require('~/server/services/Endpoints/agents/agent'); +const { getCustomEndpointConfig } = require('~/server/services/Config'); const { loadAgentTools } = require('~/server/services/ToolService'); const AgentClient = require('~/server/controllers/agents/client'); const { getAgent } = require('~/models/Agent'); @@ -61,6 +67,7 @@ const initializeClient = async ({ req, res, endpointOption }) => { } const primaryAgent = await endpointOption.agent; + delete endpointOption.agent; if (!primaryAgent) { throw new Error('Agent not found'); } @@ -108,11 +115,25 @@ const initializeClient = async ({ req, res, endpointOption }) => { } } + let endpointConfig = req.app.locals[primaryConfig.endpoint]; + if (!isAgentsEndpoint(primaryConfig.endpoint) && !endpointConfig) { + try { + endpointConfig = await getCustomEndpointConfig(primaryConfig.endpoint); + } catch (err) { + logger.error( + '[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config', + err, + ); + } + } + const sender = primaryAgent.name ?? getResponseSender({ ...endpointOption, model: endpointOption.model_parameters.model, + modelDisplayLabel: endpointConfig?.modelDisplayLabel, + modelLabel: endpointOption.model_parameters.modelLabel, }); const client = new AgentClient({ diff --git a/api/server/services/Endpoints/agents/title.js b/api/server/services/Endpoints/agents/title.js index ab171bc79d..2e5f00ecd0 100644 --- a/api/server/services/Endpoints/agents/title.js +++ b/api/server/services/Endpoints/agents/title.js @@ -23,7 +23,7 @@ const addTitle = async (req, { text, response, client }) => { let timeoutId; try { const timeoutPromise = new Promise((_, reject) => { - timeoutId = setTimeout(() => reject(new Error('Title generation timeout')), 25000); + timeoutId = setTimeout(() => reject(new Error('Title generation timeout')), 45000); }).catch((error) => { logger.error('Title error:', error); }); diff --git a/api/server/services/Endpoints/anthropic/initialize.js b/api/server/services/Endpoints/anthropic/initialize.js index d4c6dd1795..4546fc634c 100644 --- a/api/server/services/Endpoints/anthropic/initialize.js +++ b/api/server/services/Endpoints/anthropic/initialize.js @@ -41,7 +41,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio { reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null, proxy: PROXY ?? null, - modelOptions: endpointOption.model_parameters, + modelOptions: endpointOption?.model_parameters ?? {}, }, clientOptions, ); diff --git a/api/server/services/Endpoints/anthropic/llm.js b/api/server/services/Endpoints/anthropic/llm.js index 66496f00fd..8355b8aa26 100644 --- a/api/server/services/Endpoints/anthropic/llm.js +++ b/api/server/services/Endpoints/anthropic/llm.js @@ -75,9 +75,20 @@ function getLLMConfig(apiKey, options = {}) { if (options.reverseProxyUrl) { requestOptions.clientOptions.baseURL = options.reverseProxyUrl; + requestOptions.anthropicApiUrl = options.reverseProxyUrl; + } + + const tools = []; + + if (mergedOptions.web_search) { + tools.push({ + type: 'web_search_20250305', + name: 'web_search', + }); } return { + tools, /** @type {AnthropicClientOptions} */ llmConfig: removeNullishValues(requestOptions), }; diff --git a/api/server/services/Endpoints/anthropic/llm.spec.js b/api/server/services/Endpoints/anthropic/llm.spec.js index f3f77ee897..cd29975e0a 100644 --- a/api/server/services/Endpoints/anthropic/llm.spec.js +++ b/api/server/services/Endpoints/anthropic/llm.spec.js @@ -1,11 +1,45 @@ -const { anthropicSettings } = require('librechat-data-provider'); +const { anthropicSettings, removeNullishValues } = require('librechat-data-provider'); const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm'); +const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = require('./helpers'); jest.mock('https-proxy-agent', () => ({ HttpsProxyAgent: jest.fn().mockImplementation((proxy) => ({ proxy })), })); +jest.mock('./helpers', () => ({ + checkPromptCacheSupport: jest.fn(), + getClaudeHeaders: jest.fn(), + configureReasoning: jest.fn((requestOptions) => requestOptions), +})); + +jest.mock('librechat-data-provider', () => ({ + anthropicSettings: { + model: { default: 'claude-3-opus-20240229' }, + maxOutputTokens: { default: 4096, reset: jest.fn(() => 4096) }, + thinking: { default: false }, + promptCache: { default: false }, + thinkingBudget: { default: null }, + }, + removeNullishValues: jest.fn((obj) => { + const result = {}; + for (const key in obj) { + if (obj[key] !== null && obj[key] !== undefined) { + result[key] = obj[key]; + } + } + return result; + }), +})); + describe('getLLMConfig', () => { + beforeEach(() => { + jest.clearAllMocks(); + checkPromptCacheSupport.mockReturnValue(false); + getClaudeHeaders.mockReturnValue(undefined); + configureReasoning.mockImplementation((requestOptions) => requestOptions); + anthropicSettings.maxOutputTokens.reset.mockReturnValue(4096); + }); + it('should create a basic configuration with default values', () => { const result = getLLMConfig('test-api-key', { modelOptions: {} }); @@ -36,6 +70,7 @@ describe('getLLMConfig', () => { }); expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'http://reverse-proxy'); + expect(result.llmConfig).toHaveProperty('anthropicApiUrl', 'http://reverse-proxy'); }); it('should include topK and topP for non-Claude-3.7 models', () => { @@ -65,6 +100,11 @@ describe('getLLMConfig', () => { }); it('should NOT include topK and topP for Claude-3-7 models (hyphen notation)', () => { + configureReasoning.mockImplementation((requestOptions) => { + requestOptions.thinking = { type: 'enabled' }; + return requestOptions; + }); + const result = getLLMConfig('test-api-key', { modelOptions: { model: 'claude-3-7-sonnet', @@ -78,6 +118,11 @@ describe('getLLMConfig', () => { }); it('should NOT include topK and topP for Claude-3.7 models (decimal notation)', () => { + configureReasoning.mockImplementation((requestOptions) => { + requestOptions.thinking = { type: 'enabled' }; + return requestOptions; + }); + const result = getLLMConfig('test-api-key', { modelOptions: { model: 'claude-3.7-sonnet', @@ -154,4 +199,160 @@ describe('getLLMConfig', () => { expect(result3.llmConfig).toHaveProperty('topK', 10); expect(result3.llmConfig).toHaveProperty('topP', 0.9); }); + + describe('Edge cases', () => { + it('should handle missing apiKey', () => { + const result = getLLMConfig(undefined, { modelOptions: {} }); + expect(result.llmConfig).not.toHaveProperty('apiKey'); + }); + + it('should handle empty modelOptions', () => { + expect(() => { + getLLMConfig('test-api-key', {}); + }).toThrow("Cannot read properties of undefined (reading 'thinking')"); + }); + + it('should handle no options parameter', () => { + expect(() => { + getLLMConfig('test-api-key'); + }).toThrow("Cannot read properties of undefined (reading 'thinking')"); + }); + + it('should handle temperature, stop sequences, and stream settings', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: { + temperature: 0.7, + stop: ['\n\n', 'END'], + stream: false, + }, + }); + + expect(result.llmConfig).toHaveProperty('temperature', 0.7); + expect(result.llmConfig).toHaveProperty('stopSequences', ['\n\n', 'END']); + expect(result.llmConfig).toHaveProperty('stream', false); + }); + + it('should handle maxOutputTokens when explicitly set to falsy value', () => { + anthropicSettings.maxOutputTokens.reset.mockReturnValue(8192); + const result = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-opus', + maxOutputTokens: null, + }, + }); + + expect(anthropicSettings.maxOutputTokens.reset).toHaveBeenCalledWith('claude-3-opus'); + expect(result.llmConfig).toHaveProperty('maxTokens', 8192); + }); + + it('should handle both proxy and reverseProxyUrl', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: {}, + proxy: 'http://proxy:8080', + reverseProxyUrl: 'https://reverse-proxy.com', + }); + + expect(result.llmConfig.clientOptions).toHaveProperty('fetchOptions'); + expect(result.llmConfig.clientOptions.fetchOptions).toHaveProperty('dispatcher'); + expect(result.llmConfig.clientOptions.fetchOptions.dispatcher).toBeDefined(); + expect(result.llmConfig.clientOptions.fetchOptions.dispatcher.constructor.name).toBe( + 'ProxyAgent', + ); + expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'https://reverse-proxy.com'); + expect(result.llmConfig).toHaveProperty('anthropicApiUrl', 'https://reverse-proxy.com'); + }); + + it('should handle prompt cache with supported model', () => { + checkPromptCacheSupport.mockReturnValue(true); + getClaudeHeaders.mockReturnValue({ 'anthropic-beta': 'prompt-caching-2024-07-31' }); + + const result = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-5-sonnet', + promptCache: true, + }, + }); + + expect(checkPromptCacheSupport).toHaveBeenCalledWith('claude-3-5-sonnet'); + expect(getClaudeHeaders).toHaveBeenCalledWith('claude-3-5-sonnet', true); + expect(result.llmConfig.clientOptions.defaultHeaders).toEqual({ + 'anthropic-beta': 'prompt-caching-2024-07-31', + }); + }); + + it('should handle thinking and thinkingBudget options', () => { + configureReasoning.mockImplementation((requestOptions, systemOptions) => { + if (systemOptions.thinking) { + requestOptions.thinking = { type: 'enabled' }; + } + if (systemOptions.thinkingBudget) { + requestOptions.thinking = { + ...requestOptions.thinking, + budget_tokens: systemOptions.thinkingBudget, + }; + } + return requestOptions; + }); + + getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-7-sonnet', + thinking: true, + thinkingBudget: 5000, + }, + }); + + expect(configureReasoning).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + thinking: true, + promptCache: false, + thinkingBudget: 5000, + }), + ); + }); + + it('should remove system options from modelOptions', () => { + const modelOptions = { + model: 'claude-3-opus', + thinking: true, + promptCache: true, + thinkingBudget: 1000, + temperature: 0.5, + }; + + getLLMConfig('test-api-key', { modelOptions }); + + expect(modelOptions).not.toHaveProperty('thinking'); + expect(modelOptions).not.toHaveProperty('promptCache'); + expect(modelOptions).not.toHaveProperty('thinkingBudget'); + expect(modelOptions).toHaveProperty('temperature', 0.5); + }); + + it('should handle all nullish values removal', () => { + removeNullishValues.mockImplementation((obj) => { + const cleaned = {}; + Object.entries(obj).forEach(([key, value]) => { + if (value !== null && value !== undefined) { + cleaned[key] = value; + } + }); + return cleaned; + }); + + const result = getLLMConfig('test-api-key', { + modelOptions: { + temperature: null, + topP: undefined, + topK: 0, + stop: [], + }, + }); + + expect(result.llmConfig).not.toHaveProperty('temperature'); + expect(result.llmConfig).not.toHaveProperty('topP'); + expect(result.llmConfig).toHaveProperty('topK', 0); + expect(result.llmConfig).toHaveProperty('stopSequences', []); + }); + }); }); diff --git a/api/server/services/Endpoints/azureAssistants/initialize.js b/api/server/services/Endpoints/azureAssistants/initialize.js index 88acef23e5..132c123e7e 100644 --- a/api/server/services/Endpoints/azureAssistants/initialize.js +++ b/api/server/services/Endpoints/azureAssistants/initialize.js @@ -1,12 +1,7 @@ const OpenAI = require('openai'); const { HttpsProxyAgent } = require('https-proxy-agent'); -const { constructAzureURL, isUserProvided } = require('@librechat/api'); -const { - ErrorTypes, - EModelEndpoint, - resolveHeaders, - mapModelToAzureConfig, -} = require('librechat-data-provider'); +const { constructAzureURL, isUserProvided, resolveHeaders } = require('@librechat/api'); +const { ErrorTypes, EModelEndpoint, mapModelToAzureConfig } = require('librechat-data-provider'); const { getUserKeyValues, getUserKeyExpiry, @@ -114,11 +109,14 @@ const initializeClient = async ({ req, res, version, endpointOption, initAppClie apiKey = azureOptions.azureOpenAIApiKey; opts.defaultQuery = { 'api-version': azureOptions.azureOpenAIApiVersion }; - opts.defaultHeaders = resolveHeaders({ - ...headers, - 'api-key': apiKey, - 'OpenAI-Beta': `assistants=${version}`, - }); + opts.defaultHeaders = resolveHeaders( + { + ...headers, + 'api-key': apiKey, + 'OpenAI-Beta': `assistants=${version}`, + }, + req.user, + ); opts.model = azureOptions.azureOpenAIApiDeploymentName; if (initAppClient) { diff --git a/api/server/services/Endpoints/bedrock/options.js b/api/server/services/Endpoints/bedrock/options.js index fc5536abbf..a31d6e10c4 100644 --- a/api/server/services/Endpoints/bedrock/options.js +++ b/api/server/services/Endpoints/bedrock/options.js @@ -64,7 +64,7 @@ const getOptions = async ({ req, overrideModel, endpointOption }) => { /** @type {BedrockClientOptions} */ const requestOptions = { - model: overrideModel ?? endpointOption.model, + model: overrideModel ?? endpointOption?.model, region: BEDROCK_AWS_DEFAULT_REGION, }; @@ -76,7 +76,7 @@ const getOptions = async ({ req, overrideModel, endpointOption }) => { const llmConfig = bedrockOutputParser( bedrockInputParser.parse( - removeNullishValues(Object.assign(requestOptions, endpointOption.model_parameters)), + removeNullishValues(Object.assign(requestOptions, endpointOption?.model_parameters ?? {})), ), ); diff --git a/api/server/services/Endpoints/custom/initialize.js b/api/server/services/Endpoints/custom/initialize.js index 754abef5a8..4fcbe76ea6 100644 --- a/api/server/services/Endpoints/custom/initialize.js +++ b/api/server/services/Endpoints/custom/initialize.js @@ -6,7 +6,7 @@ const { extractEnvVariable, } = require('librechat-data-provider'); const { Providers } = require('@librechat/agents'); -const { getOpenAIConfig, createHandleLLMNewToken } = require('@librechat/api'); +const { getOpenAIConfig, createHandleLLMNewToken, resolveHeaders } = require('@librechat/api'); const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService'); const { getCustomEndpointConfig } = require('~/server/services/Config'); const { fetchModels } = require('~/server/services/ModelService'); @@ -28,12 +28,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid const CUSTOM_API_KEY = extractEnvVariable(endpointConfig.apiKey); const CUSTOM_BASE_URL = extractEnvVariable(endpointConfig.baseURL); - let resolvedHeaders = {}; - if (endpointConfig.headers && typeof endpointConfig.headers === 'object') { - Object.keys(endpointConfig.headers).forEach((key) => { - resolvedHeaders[key] = extractEnvVariable(endpointConfig.headers[key]); - }); - } + let resolvedHeaders = resolveHeaders(endpointConfig.headers, req.user); if (CUSTOM_API_KEY.match(envVarRegex)) { throw new Error(`Missing API Key for ${endpoint}.`); @@ -134,7 +129,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid }; if (optionsOnly) { - const modelOptions = endpointOption.model_parameters; + const modelOptions = endpointOption?.model_parameters ?? {}; if (endpoint !== Providers.OLLAMA) { clientOptions = Object.assign( { diff --git a/api/server/services/Endpoints/custom/initialize.spec.js b/api/server/services/Endpoints/custom/initialize.spec.js new file mode 100644 index 0000000000..7e28995127 --- /dev/null +++ b/api/server/services/Endpoints/custom/initialize.spec.js @@ -0,0 +1,93 @@ +const initializeClient = require('./initialize'); + +jest.mock('@librechat/api', () => ({ + resolveHeaders: jest.fn(), + getOpenAIConfig: jest.fn(), + createHandleLLMNewToken: jest.fn(), +})); + +jest.mock('librechat-data-provider', () => ({ + CacheKeys: { TOKEN_CONFIG: 'token_config' }, + ErrorTypes: { NO_USER_KEY: 'NO_USER_KEY', NO_BASE_URL: 'NO_BASE_URL' }, + envVarRegex: /\$\{([^}]+)\}/, + FetchTokenConfig: {}, + extractEnvVariable: jest.fn((value) => value), +})); + +jest.mock('@librechat/agents', () => ({ + Providers: { OLLAMA: 'ollama' }, +})); + +jest.mock('~/server/services/UserService', () => ({ + getUserKeyValues: jest.fn(), + checkUserKeyExpiry: jest.fn(), +})); + +jest.mock('~/server/services/Config', () => ({ + getCustomEndpointConfig: jest.fn().mockResolvedValue({ + apiKey: 'test-key', + baseURL: 'https://test.com', + headers: { 'x-user': '{{LIBRECHAT_USER_ID}}', 'x-email': '{{LIBRECHAT_USER_EMAIL}}' }, + models: { default: ['test-model'] }, + }), +})); + +jest.mock('~/server/services/ModelService', () => ({ + fetchModels: jest.fn(), +})); + +jest.mock('~/app/clients/OpenAIClient', () => { + return jest.fn().mockImplementation(() => ({ + options: {}, + })); +}); + +jest.mock('~/server/utils', () => ({ + isUserProvided: jest.fn().mockReturnValue(false), +})); + +jest.mock('~/cache/getLogStores', () => + jest.fn().mockReturnValue({ + get: jest.fn(), + }), +); + +describe('custom/initializeClient', () => { + const mockRequest = { + body: { endpoint: 'test-endpoint' }, + user: { id: 'user-123', email: 'test@example.com' }, + app: { locals: {} }, + }; + const mockResponse = {}; + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('calls resolveHeaders with headers and user', async () => { + const { resolveHeaders } = require('@librechat/api'); + await initializeClient({ req: mockRequest, res: mockResponse, optionsOnly: true }); + expect(resolveHeaders).toHaveBeenCalledWith( + { 'x-user': '{{LIBRECHAT_USER_ID}}', 'x-email': '{{LIBRECHAT_USER_EMAIL}}' }, + { id: 'user-123', email: 'test@example.com' }, + ); + }); + + it('throws if endpoint config is missing', async () => { + const { getCustomEndpointConfig } = require('~/server/services/Config'); + getCustomEndpointConfig.mockResolvedValueOnce(null); + await expect( + initializeClient({ req: mockRequest, res: mockResponse, optionsOnly: true }), + ).rejects.toThrow('Config not found for the test-endpoint custom endpoint.'); + }); + + it('throws if user is missing', async () => { + await expect( + initializeClient({ + req: { ...mockRequest, user: undefined }, + res: mockResponse, + optionsOnly: true, + }), + ).rejects.toThrow("Cannot read properties of undefined (reading 'id')"); + }); +}); diff --git a/api/server/services/Endpoints/google/initialize.js b/api/server/services/Endpoints/google/initialize.js index b6bc2d6a79..75a31a8c09 100644 --- a/api/server/services/Endpoints/google/initialize.js +++ b/api/server/services/Endpoints/google/initialize.js @@ -1,7 +1,7 @@ +const path = require('path'); const { EModelEndpoint, AuthKeys } = require('librechat-data-provider'); +const { getGoogleConfig, isEnabled, loadServiceKey } = require('@librechat/api'); const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); -const { getLLMConfig } = require('~/server/services/Endpoints/google/llm'); -const { isEnabled } = require('~/server/utils'); const { GoogleClient } = require('~/app'); const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => { @@ -16,10 +16,25 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio } let serviceKey = {}; - try { - serviceKey = require('~/data/auth.json'); - } catch (e) { - // Do nothing + + /** Check if GOOGLE_KEY is provided at all (including 'user_provided') */ + const isGoogleKeyProvided = + (GOOGLE_KEY && GOOGLE_KEY.trim() !== '') || (isUserProvided && userKey != null); + + if (!isGoogleKeyProvided) { + /** Only attempt to load service key if GOOGLE_KEY is not provided */ + try { + const serviceKeyPath = + process.env.GOOGLE_SERVICE_KEY_FILE || + path.join(__dirname, '../../../..', 'data', 'auth.json'); + serviceKey = await loadServiceKey(serviceKeyPath); + if (!serviceKey) { + serviceKey = {}; + } + } catch (_e) { + // Service key loading failed, but that's okay if not required + serviceKey = {}; + } } const credentials = isUserProvided @@ -58,14 +73,14 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio if (optionsOnly) { clientOptions = Object.assign( { - modelOptions: endpointOption.model_parameters, + modelOptions: endpointOption?.model_parameters ?? {}, }, clientOptions, ); if (overrideModel) { clientOptions.modelOptions.model = overrideModel; } - return getLLMConfig(credentials, clientOptions); + return getGoogleConfig(credentials, clientOptions); } const client = new GoogleClient(credentials, clientOptions); diff --git a/api/server/services/Endpoints/gptPlugins/build.js b/api/server/services/Endpoints/gptPlugins/build.js deleted file mode 100644 index 0d1ec097ad..0000000000 --- a/api/server/services/Endpoints/gptPlugins/build.js +++ /dev/null @@ -1,41 +0,0 @@ -const { removeNullishValues } = require('librechat-data-provider'); -const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); - -const buildOptions = (endpoint, parsedBody) => { - const { - modelLabel, - chatGptLabel, - promptPrefix, - agentOptions, - tools = [], - iconURL, - greeting, - spec, - maxContextTokens, - artifacts, - ...modelOptions - } = parsedBody; - const endpointOption = removeNullishValues({ - endpoint, - tools: tools - .map((tool) => tool?.pluginKey ?? tool) - .filter((toolName) => typeof toolName === 'string'), - modelLabel, - chatGptLabel, - promptPrefix, - agentOptions, - iconURL, - greeting, - spec, - maxContextTokens, - modelOptions, - }); - - if (typeof artifacts === 'string') { - endpointOption.artifactsPrompt = generateArtifactsPrompt({ endpoint, artifacts }); - } - - return endpointOption; -}; - -module.exports = buildOptions; diff --git a/api/server/services/Endpoints/gptPlugins/index.js b/api/server/services/Endpoints/gptPlugins/index.js deleted file mode 100644 index 202cb0e4d7..0000000000 --- a/api/server/services/Endpoints/gptPlugins/index.js +++ /dev/null @@ -1,7 +0,0 @@ -const buildOptions = require('./build'); -const initializeClient = require('./initialize'); - -module.exports = { - buildOptions, - initializeClient, -}; diff --git a/api/server/services/Endpoints/gptPlugins/initialize.js b/api/server/services/Endpoints/gptPlugins/initialize.js deleted file mode 100644 index d2af6c757e..0000000000 --- a/api/server/services/Endpoints/gptPlugins/initialize.js +++ /dev/null @@ -1,134 +0,0 @@ -const { - EModelEndpoint, - resolveHeaders, - mapModelToAzureConfig, -} = require('librechat-data-provider'); -const { isEnabled, isUserProvided, getAzureCredentials } = require('@librechat/api'); -const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService'); -const { PluginsClient } = require('~/app'); - -const initializeClient = async ({ req, res, endpointOption }) => { - const { - PROXY, - OPENAI_API_KEY, - AZURE_API_KEY, - PLUGINS_USE_AZURE, - OPENAI_REVERSE_PROXY, - AZURE_OPENAI_BASEURL, - OPENAI_SUMMARIZE, - DEBUG_PLUGINS, - } = process.env; - - const { key: expiresAt, model: modelName } = req.body; - const contextStrategy = isEnabled(OPENAI_SUMMARIZE) ? 'summarize' : null; - - let useAzure = isEnabled(PLUGINS_USE_AZURE); - let endpoint = useAzure ? EModelEndpoint.azureOpenAI : EModelEndpoint.openAI; - - /** @type {false | TAzureConfig} */ - const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI]; - useAzure = useAzure || azureConfig?.plugins; - - if (useAzure && endpoint !== EModelEndpoint.azureOpenAI) { - endpoint = EModelEndpoint.azureOpenAI; - } - - const credentials = { - [EModelEndpoint.openAI]: OPENAI_API_KEY, - [EModelEndpoint.azureOpenAI]: AZURE_API_KEY, - }; - - const baseURLOptions = { - [EModelEndpoint.openAI]: OPENAI_REVERSE_PROXY, - [EModelEndpoint.azureOpenAI]: AZURE_OPENAI_BASEURL, - }; - - const userProvidesKey = isUserProvided(credentials[endpoint]); - const userProvidesURL = isUserProvided(baseURLOptions[endpoint]); - - let userValues = null; - if (expiresAt && (userProvidesKey || userProvidesURL)) { - checkUserKeyExpiry(expiresAt, endpoint); - userValues = await getUserKeyValues({ userId: req.user.id, name: endpoint }); - } - - let apiKey = userProvidesKey ? userValues?.apiKey : credentials[endpoint]; - let baseURL = userProvidesURL ? userValues?.baseURL : baseURLOptions[endpoint]; - - const clientOptions = { - contextStrategy, - debug: isEnabled(DEBUG_PLUGINS), - reverseProxyUrl: baseURL ? baseURL : null, - proxy: PROXY ?? null, - req, - res, - ...endpointOption, - }; - - if (useAzure && azureConfig) { - const { modelGroupMap, groupMap } = azureConfig; - const { - azureOptions, - baseURL, - headers = {}, - serverless, - } = mapModelToAzureConfig({ - modelName, - modelGroupMap, - groupMap, - }); - - clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl; - clientOptions.headers = resolveHeaders({ ...headers, ...(clientOptions.headers ?? {}) }); - - clientOptions.titleConvo = azureConfig.titleConvo; - clientOptions.titleModel = azureConfig.titleModel; - clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion'; - - const azureRate = modelName.includes('gpt-4') ? 30 : 17; - clientOptions.streamRate = azureConfig.streamRate ?? azureRate; - - const groupName = modelGroupMap[modelName].group; - clientOptions.addParams = azureConfig.groupMap[groupName].addParams; - clientOptions.dropParams = azureConfig.groupMap[groupName].dropParams; - clientOptions.forcePrompt = azureConfig.groupMap[groupName].forcePrompt; - - apiKey = azureOptions.azureOpenAIApiKey; - clientOptions.azure = !serverless && azureOptions; - if (serverless === true) { - clientOptions.defaultQuery = azureOptions.azureOpenAIApiVersion - ? { 'api-version': azureOptions.azureOpenAIApiVersion } - : undefined; - clientOptions.headers['api-key'] = apiKey; - } - } else if (useAzure || (apiKey && apiKey.includes('{"azure') && !clientOptions.azure)) { - clientOptions.azure = userProvidesKey ? JSON.parse(userValues.apiKey) : getAzureCredentials(); - apiKey = clientOptions.azure.azureOpenAIApiKey; - } - - /** @type {undefined | TBaseEndpoint} */ - const pluginsConfig = req.app.locals[EModelEndpoint.gptPlugins]; - - if (!useAzure && pluginsConfig) { - clientOptions.streamRate = pluginsConfig.streamRate; - } - - /** @type {undefined | TBaseEndpoint} */ - const allConfig = req.app.locals.all; - if (allConfig) { - clientOptions.streamRate = allConfig.streamRate; - } - - if (!apiKey) { - throw new Error(`${endpoint} API key not provided. Please provide it again.`); - } - - const client = new PluginsClient(apiKey, clientOptions); - return { - client, - azure: clientOptions.azure, - openAIApiKey: apiKey, - }; -}; - -module.exports = initializeClient; diff --git a/api/server/services/Endpoints/gptPlugins/initialize.spec.js b/api/server/services/Endpoints/gptPlugins/initialize.spec.js deleted file mode 100644 index f9cb2750a4..0000000000 --- a/api/server/services/Endpoints/gptPlugins/initialize.spec.js +++ /dev/null @@ -1,410 +0,0 @@ -// gptPlugins/initializeClient.spec.js -jest.mock('~/cache/getLogStores'); -const { EModelEndpoint, ErrorTypes, validateAzureGroups } = require('librechat-data-provider'); -const { getUserKey, getUserKeyValues } = require('~/server/services/UserService'); -const initializeClient = require('./initialize'); -const { PluginsClient } = require('~/app'); - -// Mock getUserKey since it's the only function we want to mock -jest.mock('~/server/services/UserService', () => ({ - getUserKey: jest.fn(), - getUserKeyValues: jest.fn(), - checkUserKeyExpiry: jest.requireActual('~/server/services/UserService').checkUserKeyExpiry, -})); - -describe('gptPlugins/initializeClient', () => { - // Set up environment variables - const originalEnvironment = process.env; - const app = { - locals: {}, - }; - - const validAzureConfigs = [ - { - group: 'librechat-westus', - apiKey: 'WESTUS_API_KEY', - instanceName: 'librechat-westus', - version: '2023-12-01-preview', - models: { - 'gpt-4-vision-preview': { - deploymentName: 'gpt-4-vision-preview', - version: '2024-02-15-preview', - }, - 'gpt-3.5-turbo': { - deploymentName: 'gpt-35-turbo', - }, - 'gpt-3.5-turbo-1106': { - deploymentName: 'gpt-35-turbo-1106', - }, - 'gpt-4': { - deploymentName: 'gpt-4', - }, - 'gpt-4-1106-preview': { - deploymentName: 'gpt-4-1106-preview', - }, - }, - }, - { - group: 'librechat-eastus', - apiKey: 'EASTUS_API_KEY', - instanceName: 'librechat-eastus', - deploymentName: 'gpt-4-turbo', - version: '2024-02-15-preview', - models: { - 'gpt-4-turbo': true, - }, - baseURL: 'https://eastus.example.com', - additionalHeaders: { - 'x-api-key': 'x-api-key-value', - }, - }, - { - group: 'mistral-inference', - apiKey: 'AZURE_MISTRAL_API_KEY', - baseURL: - 'https://Mistral-large-vnpet-serverless.region.inference.ai.azure.com/v1/chat/completions', - serverless: true, - models: { - 'mistral-large': true, - }, - }, - { - group: 'llama-70b-chat', - apiKey: 'AZURE_LLAMA2_70B_API_KEY', - baseURL: - 'https://Llama-2-70b-chat-qmvyb-serverless.region.inference.ai.azure.com/v1/chat/completions', - serverless: true, - models: { - 'llama-70b-chat': true, - }, - }, - ]; - - const { modelNames, modelGroupMap, groupMap } = validateAzureGroups(validAzureConfigs); - - beforeEach(() => { - jest.resetModules(); // Clears the cache - process.env = { ...originalEnvironment }; // Make a copy - }); - - afterAll(() => { - process.env = originalEnvironment; // Restore original env vars - }); - - test('should initialize PluginsClient with OpenAI API key and default options', async () => { - process.env.OPENAI_API_KEY = 'test-openai-api-key'; - process.env.PLUGINS_USE_AZURE = 'false'; - process.env.DEBUG_PLUGINS = 'false'; - process.env.OPENAI_SUMMARIZE = 'false'; - - const req = { - body: { key: null }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'default-model' } }; - - const { client, openAIApiKey } = await initializeClient({ req, res, endpointOption }); - - expect(openAIApiKey).toBe('test-openai-api-key'); - expect(client).toBeInstanceOf(PluginsClient); - }); - - test('should initialize PluginsClient with Azure credentials when PLUGINS_USE_AZURE is true', async () => { - process.env.AZURE_API_KEY = 'test-azure-api-key'; - (process.env.AZURE_OPENAI_API_INSTANCE_NAME = 'some-value'), - (process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'some-value'), - (process.env.AZURE_OPENAI_API_VERSION = 'some-value'), - (process.env.AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME = 'some-value'), - (process.env.AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME = 'some-value'), - (process.env.PLUGINS_USE_AZURE = 'true'); - process.env.DEBUG_PLUGINS = 'false'; - process.env.OPENAI_SUMMARIZE = 'false'; - - const req = { - body: { key: null }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'test-model' } }; - - const { client, azure } = await initializeClient({ req, res, endpointOption }); - - expect(azure.azureOpenAIApiKey).toBe('test-azure-api-key'); - expect(client).toBeInstanceOf(PluginsClient); - }); - - test('should use the debug option when DEBUG_PLUGINS is enabled', async () => { - process.env.OPENAI_API_KEY = 'test-openai-api-key'; - process.env.DEBUG_PLUGINS = 'true'; - - const req = { - body: { key: null }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'default-model' } }; - - const { client } = await initializeClient({ req, res, endpointOption }); - - expect(client.options.debug).toBe(true); - }); - - test('should set contextStrategy to summarize when OPENAI_SUMMARIZE is enabled', async () => { - process.env.OPENAI_API_KEY = 'test-openai-api-key'; - process.env.OPENAI_SUMMARIZE = 'true'; - - const req = { - body: { key: null }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'default-model' } }; - - const { client } = await initializeClient({ req, res, endpointOption }); - - expect(client.options.contextStrategy).toBe('summarize'); - }); - - // ... additional tests for reverseProxyUrl, proxy, user-provided keys, etc. - - test('should throw an error if no API keys are provided in the environment', async () => { - // Clear the environment variables for API keys - delete process.env.OPENAI_API_KEY; - delete process.env.AZURE_API_KEY; - - const req = { - body: { key: null }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'default-model' } }; - - await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow( - `${EModelEndpoint.openAI} API key not provided.`, - ); - }); - - // Additional tests for gptPlugins/initializeClient.spec.js - - // ... (previous test setup code) - - test('should handle user-provided OpenAI keys and check expiry', async () => { - process.env.OPENAI_API_KEY = 'user_provided'; - process.env.PLUGINS_USE_AZURE = 'false'; - - const futureDate = new Date(Date.now() + 10000).toISOString(); - const req = { - body: { key: futureDate }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'default-model' } }; - - getUserKeyValues.mockResolvedValue({ apiKey: 'test-user-provided-openai-api-key' }); - - const { openAIApiKey } = await initializeClient({ req, res, endpointOption }); - - expect(openAIApiKey).toBe('test-user-provided-openai-api-key'); - }); - - test('should handle user-provided Azure keys and check expiry', async () => { - process.env.AZURE_API_KEY = 'user_provided'; - process.env.PLUGINS_USE_AZURE = 'true'; - - const futureDate = new Date(Date.now() + 10000).toISOString(); - const req = { - body: { key: futureDate }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'test-model' } }; - - getUserKeyValues.mockResolvedValue({ - apiKey: JSON.stringify({ - azureOpenAIApiKey: 'test-user-provided-azure-api-key', - azureOpenAIApiDeploymentName: 'test-deployment', - }), - }); - - const { azure } = await initializeClient({ req, res, endpointOption }); - - expect(azure.azureOpenAIApiKey).toBe('test-user-provided-azure-api-key'); - }); - - test('should throw an error if the user-provided key has expired', async () => { - process.env.OPENAI_API_KEY = 'user_provided'; - process.env.PLUGINS_USE_AZURE = 'FALSE'; - const expiresAt = new Date(Date.now() - 10000).toISOString(); // Expired - const req = { - body: { key: expiresAt }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'default-model' } }; - - await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow( - /expired_user_key/, - ); - }); - - test('should throw an error if the user-provided Azure key is invalid JSON', async () => { - process.env.AZURE_API_KEY = 'user_provided'; - process.env.PLUGINS_USE_AZURE = 'true'; - - const req = { - body: { key: new Date(Date.now() + 10000).toISOString() }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'default-model' } }; - - // Simulate an invalid JSON string returned from getUserKey - getUserKey.mockResolvedValue('invalid-json'); - getUserKeyValues.mockImplementation(() => { - let userValues = getUserKey(); - try { - userValues = JSON.parse(userValues); - } catch (e) { - throw new Error( - JSON.stringify({ - type: ErrorTypes.INVALID_USER_KEY, - }), - ); - } - return userValues; - }); - - await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow( - /invalid_user_key/, - ); - }); - - test('should correctly handle the presence of a reverse proxy', async () => { - process.env.OPENAI_REVERSE_PROXY = 'http://reverse.proxy'; - process.env.PROXY = 'http://proxy'; - process.env.OPENAI_API_KEY = 'test-openai-api-key'; - - const req = { - body: { key: null }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'default-model' } }; - - const { client } = await initializeClient({ req, res, endpointOption }); - - expect(client.options.reverseProxyUrl).toBe('http://reverse.proxy'); - expect(client.options.proxy).toBe('http://proxy'); - }); - - test('should throw an error when user-provided values are not valid JSON', async () => { - process.env.OPENAI_API_KEY = 'user_provided'; - const req = { - body: { key: new Date(Date.now() + 10000).toISOString(), endpoint: 'openAI' }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = {}; - - // Mock getUserKey to return a non-JSON string - getUserKey.mockResolvedValue('not-a-json'); - getUserKeyValues.mockImplementation(() => { - let userValues = getUserKey(); - try { - userValues = JSON.parse(userValues); - } catch (e) { - throw new Error( - JSON.stringify({ - type: ErrorTypes.INVALID_USER_KEY, - }), - ); - } - return userValues; - }); - - await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow( - /invalid_user_key/, - ); - }); - - test('should initialize client correctly for Azure OpenAI with valid configuration', async () => { - const req = { - body: { - key: null, - endpoint: EModelEndpoint.gptPlugins, - model: modelNames[0], - }, - user: { id: '123' }, - app: { - locals: { - [EModelEndpoint.azureOpenAI]: { - plugins: true, - modelNames, - modelGroupMap, - groupMap, - }, - }, - }, - }; - const res = {}; - const endpointOption = {}; - - const client = await initializeClient({ req, res, endpointOption }); - expect(client.client.options.azure).toBeDefined(); - }); - - test('should initialize client with default options when certain env vars are not set', async () => { - delete process.env.OPENAI_SUMMARIZE; - process.env.OPENAI_API_KEY = 'some-api-key'; - - const req = { - body: { key: null, endpoint: EModelEndpoint.gptPlugins }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = {}; - - const client = await initializeClient({ req, res, endpointOption }); - expect(client.client.options.contextStrategy).toBe(null); - }); - - test('should correctly use user-provided apiKey and baseURL when provided', async () => { - process.env.OPENAI_API_KEY = 'user_provided'; - process.env.OPENAI_REVERSE_PROXY = 'user_provided'; - const req = { - body: { - key: new Date(Date.now() + 10000).toISOString(), - endpoint: 'openAI', - }, - user: { - id: '123', - }, - app, - }; - const res = {}; - const endpointOption = {}; - - getUserKeyValues.mockResolvedValue({ - apiKey: 'test', - baseURL: 'https://user-provided-url.com', - }); - - const result = await initializeClient({ req, res, endpointOption }); - - expect(result.openAIApiKey).toBe('test'); - expect(result.client.options.reverseProxyUrl).toBe('https://user-provided-url.com'); - }); -}); diff --git a/api/server/services/Endpoints/index.js b/api/server/services/Endpoints/index.js new file mode 100644 index 0000000000..71d04a1ae2 --- /dev/null +++ b/api/server/services/Endpoints/index.js @@ -0,0 +1,75 @@ +const { Providers } = require('@librechat/agents'); +const { EModelEndpoint } = require('librechat-data-provider'); +const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize'); +const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options'); +const initOpenAI = require('~/server/services/Endpoints/openAI/initialize'); +const initCustom = require('~/server/services/Endpoints/custom/initialize'); +const initGoogle = require('~/server/services/Endpoints/google/initialize'); +const { getCustomEndpointConfig } = require('~/server/services/Config'); + +/** Check if the provider is a known custom provider + * @param {string | undefined} [provider] - The provider string + * @returns {boolean} - True if the provider is a known custom provider, false otherwise + */ +function isKnownCustomProvider(provider) { + return [Providers.XAI, Providers.OLLAMA, Providers.DEEPSEEK, Providers.OPENROUTER].includes( + provider?.toLowerCase() || '', + ); +} + +const providerConfigMap = { + [Providers.XAI]: initCustom, + [Providers.OLLAMA]: initCustom, + [Providers.DEEPSEEK]: initCustom, + [Providers.OPENROUTER]: initCustom, + [EModelEndpoint.openAI]: initOpenAI, + [EModelEndpoint.google]: initGoogle, + [EModelEndpoint.azureOpenAI]: initOpenAI, + [EModelEndpoint.anthropic]: initAnthropic, + [EModelEndpoint.bedrock]: getBedrockOptions, +}; + +/** + * Get the provider configuration and override endpoint based on the provider string + * @param {string} provider - The provider string + * @returns {Promise<{ + * getOptions: Function, + * overrideProvider?: string, + * customEndpointConfig?: TEndpoint + * }>} + */ +async function getProviderConfig(provider) { + let getOptions = providerConfigMap[provider]; + let overrideProvider; + /** @type {TEndpoint | undefined} */ + let customEndpointConfig; + + if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) { + overrideProvider = provider.toLowerCase(); + getOptions = providerConfigMap[overrideProvider]; + } else if (!getOptions) { + customEndpointConfig = await getCustomEndpointConfig(provider); + if (!customEndpointConfig) { + throw new Error(`Provider ${provider} not supported`); + } + getOptions = initCustom; + overrideProvider = Providers.OPENAI; + } + + if (isKnownCustomProvider(overrideProvider || provider) && !customEndpointConfig) { + customEndpointConfig = await getCustomEndpointConfig(provider); + if (!customEndpointConfig) { + throw new Error(`Provider ${provider} not supported`); + } + } + + return { + getOptions, + overrideProvider, + customEndpointConfig, + }; +} + +module.exports = { + getProviderConfig, +}; diff --git a/api/server/services/Endpoints/openAI/initialize.js b/api/server/services/Endpoints/openAI/initialize.js index bc0907b3de..e86596181a 100644 --- a/api/server/services/Endpoints/openAI/initialize.js +++ b/api/server/services/Endpoints/openAI/initialize.js @@ -1,11 +1,7 @@ -const { - ErrorTypes, - EModelEndpoint, - resolveHeaders, - mapModelToAzureConfig, -} = require('librechat-data-provider'); +const { ErrorTypes, EModelEndpoint, mapModelToAzureConfig } = require('librechat-data-provider'); const { isEnabled, + resolveHeaders, isUserProvided, getOpenAIConfig, getAzureCredentials, @@ -84,7 +80,10 @@ const initializeClient = async ({ }); clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl; - clientOptions.headers = resolveHeaders({ ...headers, ...(clientOptions.headers ?? {}) }); + clientOptions.headers = resolveHeaders( + { ...headers, ...(clientOptions.headers ?? {}) }, + req.user, + ); clientOptions.titleConvo = azureConfig.titleConvo; clientOptions.titleModel = azureConfig.titleModel; @@ -139,7 +138,7 @@ const initializeClient = async ({ } if (optionsOnly) { - const modelOptions = endpointOption.model_parameters; + const modelOptions = endpointOption?.model_parameters ?? {}; modelOptions.model = modelName; clientOptions = Object.assign({ modelOptions }, clientOptions); clientOptions.modelOptions.user = req.user.id; diff --git a/api/server/services/Files/Local/crud.js b/api/server/services/Files/Local/crud.js index 7df528c5e1..455d4e0c4f 100644 --- a/api/server/services/Files/Local/crud.js +++ b/api/server/services/Files/Local/crud.js @@ -1,10 +1,11 @@ const fs = require('fs'); const path = require('path'); const axios = require('axios'); +const { logger } = require('@librechat/data-schemas'); const { EModelEndpoint } = require('librechat-data-provider'); +const { generateShortLivedToken } = require('~/server/services/AuthService'); const { getBufferMetadata } = require('~/server/utils'); const paths = require('~/config/paths'); -const { logger } = require('~/config'); /** * Saves a file to a specified output path with a new filename. @@ -206,7 +207,7 @@ const deleteLocalFile = async (req, file) => { const cleanFilepath = file.filepath.split('?')[0]; if (file.embedded && process.env.RAG_API_URL) { - const jwtToken = req.headers.authorization.split(' ')[1]; + const jwtToken = generateShortLivedToken(req.user.id); axios.delete(`${process.env.RAG_API_URL}/documents`, { headers: { Authorization: `Bearer ${jwtToken}`, diff --git a/api/server/services/Files/VectorDB/crud.js b/api/server/services/Files/VectorDB/crud.js index 1aeabc6c46..d7018f7669 100644 --- a/api/server/services/Files/VectorDB/crud.js +++ b/api/server/services/Files/VectorDB/crud.js @@ -4,6 +4,7 @@ const FormData = require('form-data'); const { logAxiosError } = require('@librechat/api'); const { logger } = require('@librechat/data-schemas'); const { FileSources } = require('librechat-data-provider'); +const { generateShortLivedToken } = require('~/server/services/AuthService'); /** * Deletes a file from the vector database. This function takes a file object, constructs the full path, and @@ -23,7 +24,8 @@ const deleteVectors = async (req, file) => { return; } try { - const jwtToken = req.headers.authorization.split(' ')[1]; + const jwtToken = generateShortLivedToken(req.user.id); + return await axios.delete(`${process.env.RAG_API_URL}/documents`, { headers: { Authorization: `Bearer ${jwtToken}`, @@ -70,7 +72,7 @@ async function uploadVectors({ req, file, file_id, entity_id }) { } try { - const jwtToken = req.headers.authorization.split(' ')[1]; + const jwtToken = generateShortLivedToken(req.user.id); const formData = new FormData(); formData.append('file_id', file_id); formData.append('file', fs.createReadStream(file.path)); diff --git a/api/server/services/Files/process.js b/api/server/services/Files/process.js index 8910163047..38ccdafdd7 100644 --- a/api/server/services/Files/process.js +++ b/api/server/services/Files/process.js @@ -55,7 +55,9 @@ const processFiles = async (files, fileIds) => { } if (!fileIds) { - return await Promise.all(promises); + const results = await Promise.all(promises); + // Filter out null results from failed updateFileUsage calls + return results.filter((result) => result != null); } for (let file_id of fileIds) { @@ -67,7 +69,9 @@ const processFiles = async (files, fileIds) => { } // TODO: calculate token cost when image is first uploaded - return await Promise.all(promises); + const results = await Promise.all(promises); + // Filter out null results from failed updateFileUsage calls + return results.filter((result) => result != null); }; /** diff --git a/api/server/services/Files/processFiles.test.js b/api/server/services/Files/processFiles.test.js new file mode 100644 index 0000000000..8665d33665 --- /dev/null +++ b/api/server/services/Files/processFiles.test.js @@ -0,0 +1,208 @@ +// Mock the updateFileUsage function before importing the actual processFiles +jest.mock('~/models/File', () => ({ + updateFileUsage: jest.fn(), +})); + +// Mock winston and logger configuration to avoid dependency issues +jest.mock('~/config', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + debug: jest.fn(), + error: jest.fn(), + }, +})); + +// Mock all other dependencies that might cause issues +jest.mock('librechat-data-provider', () => ({ + isUUID: { parse: jest.fn() }, + megabyte: 1024 * 1024, + FileContext: { message_attachment: 'message_attachment' }, + FileSources: { local: 'local' }, + EModelEndpoint: { assistants: 'assistants' }, + EToolResources: { file_search: 'file_search' }, + mergeFileConfig: jest.fn(), + removeNullishValues: jest.fn((obj) => obj), + isAssistantsEndpoint: jest.fn(), +})); + +jest.mock('~/server/services/Files/images', () => ({ + convertImage: jest.fn(), + resizeAndConvert: jest.fn(), + resizeImageBuffer: jest.fn(), +})); + +jest.mock('~/server/controllers/assistants/v2', () => ({ + addResourceFileId: jest.fn(), + deleteResourceFileId: jest.fn(), +})); + +jest.mock('~/models/Agent', () => ({ + addAgentResourceFile: jest.fn(), + removeAgentResourceFiles: jest.fn(), +})); + +jest.mock('~/server/controllers/assistants/helpers', () => ({ + getOpenAIClient: jest.fn(), +})); + +jest.mock('~/server/services/Tools/credentials', () => ({ + loadAuthValues: jest.fn(), +})); + +jest.mock('~/server/services/Config', () => ({ + checkCapability: jest.fn(), +})); + +jest.mock('~/server/utils/queue', () => ({ + LB_QueueAsyncCall: jest.fn(), +})); + +jest.mock('./strategies', () => ({ + getStrategyFunctions: jest.fn(), +})); + +jest.mock('~/server/utils', () => ({ + determineFileType: jest.fn(), +})); + +// Import the actual processFiles function after all mocks are set up +const { processFiles } = require('./process'); +const { updateFileUsage } = require('~/models/File'); + +describe('processFiles', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + describe('null filtering functionality', () => { + it('should filter out null results from updateFileUsage when files do not exist', async () => { + const mockFiles = [ + { file_id: 'existing-file-1' }, + { file_id: 'non-existent-file' }, + { file_id: 'existing-file-2' }, + ]; + + // Mock updateFileUsage to return null for non-existent files + updateFileUsage.mockImplementation(({ file_id }) => { + if (file_id === 'non-existent-file') { + return Promise.resolve(null); // Simulate file not found in the database + } + return Promise.resolve({ file_id, usage: 1 }); + }); + + const result = await processFiles(mockFiles); + + expect(updateFileUsage).toHaveBeenCalledTimes(3); + expect(result).toEqual([ + { file_id: 'existing-file-1', usage: 1 }, + { file_id: 'existing-file-2', usage: 1 }, + ]); + + // Critical test - ensure no null values in result + expect(result).not.toContain(null); + expect(result).not.toContain(undefined); + expect(result.length).toBe(2); // Only valid files should be returned + }); + + it('should return empty array when all updateFileUsage calls return null', async () => { + const mockFiles = [{ file_id: 'non-existent-1' }, { file_id: 'non-existent-2' }]; + + // All updateFileUsage calls return null + updateFileUsage.mockResolvedValue(null); + + const result = await processFiles(mockFiles); + + expect(updateFileUsage).toHaveBeenCalledTimes(2); + expect(result).toEqual([]); + expect(result).not.toContain(null); + expect(result.length).toBe(0); + }); + + it('should work correctly when all files exist', async () => { + const mockFiles = [{ file_id: 'file-1' }, { file_id: 'file-2' }]; + + updateFileUsage.mockImplementation(({ file_id }) => { + return Promise.resolve({ file_id, usage: 1 }); + }); + + const result = await processFiles(mockFiles); + + expect(result).toEqual([ + { file_id: 'file-1', usage: 1 }, + { file_id: 'file-2', usage: 1 }, + ]); + expect(result).not.toContain(null); + expect(result.length).toBe(2); + }); + + it('should handle fileIds parameter and filter nulls correctly', async () => { + const mockFiles = [{ file_id: 'file-1' }]; + const mockFileIds = ['file-2', 'non-existent-file']; + + updateFileUsage.mockImplementation(({ file_id }) => { + if (file_id === 'non-existent-file') { + return Promise.resolve(null); + } + return Promise.resolve({ file_id, usage: 1 }); + }); + + const result = await processFiles(mockFiles, mockFileIds); + + expect(result).toEqual([ + { file_id: 'file-1', usage: 1 }, + { file_id: 'file-2', usage: 1 }, + ]); + expect(result).not.toContain(null); + expect(result).not.toContain(undefined); + expect(result.length).toBe(2); + }); + + it('should handle duplicate file_ids correctly', async () => { + const mockFiles = [ + { file_id: 'duplicate-file' }, + { file_id: 'duplicate-file' }, // Duplicate should be ignored + { file_id: 'unique-file' }, + ]; + + updateFileUsage.mockImplementation(({ file_id }) => { + return Promise.resolve({ file_id, usage: 1 }); + }); + + const result = await processFiles(mockFiles); + + // Should only call updateFileUsage twice (duplicate ignored) + expect(updateFileUsage).toHaveBeenCalledTimes(2); + expect(result).toEqual([ + { file_id: 'duplicate-file', usage: 1 }, + { file_id: 'unique-file', usage: 1 }, + ]); + expect(result.length).toBe(2); + }); + }); + + describe('edge cases', () => { + it('should handle empty files array', async () => { + const result = await processFiles([]); + expect(result).toEqual([]); + expect(updateFileUsage).not.toHaveBeenCalled(); + }); + + it('should handle mixed null and undefined returns from updateFileUsage', async () => { + const mockFiles = [{ file_id: 'file-1' }, { file_id: 'file-2' }, { file_id: 'file-3' }]; + + updateFileUsage.mockImplementation(({ file_id }) => { + if (file_id === 'file-1') return Promise.resolve(null); + if (file_id === 'file-2') return Promise.resolve(undefined); + return Promise.resolve({ file_id, usage: 1 }); + }); + + const result = await processFiles(mockFiles); + + expect(result).toEqual([{ file_id: 'file-3', usage: 1 }]); + expect(result).not.toContain(null); + expect(result).not.toContain(undefined); + expect(result.length).toBe(1); + }); + }); +}); diff --git a/api/server/services/Files/strategies.js b/api/server/services/Files/strategies.js index 41dcd5518a..4f8067142b 100644 --- a/api/server/services/Files/strategies.js +++ b/api/server/services/Files/strategies.js @@ -1,5 +1,9 @@ const { FileSources } = require('librechat-data-provider'); -const { uploadMistralOCR, uploadAzureMistralOCR } = require('@librechat/api'); +const { + uploadMistralOCR, + uploadAzureMistralOCR, + uploadGoogleVertexMistralOCR, +} = require('@librechat/api'); const { getFirebaseURL, prepareImageURL, @@ -222,6 +226,26 @@ const azureMistralOCRStrategy = () => ({ handleFileUpload: uploadAzureMistralOCR, }); +const vertexMistralOCRStrategy = () => ({ + /** @type {typeof saveFileFromURL | null} */ + saveURL: null, + /** @type {typeof getLocalFileURL | null} */ + getFileURL: null, + /** @type {typeof saveLocalBuffer | null} */ + saveBuffer: null, + /** @type {typeof processLocalAvatar | null} */ + processAvatar: null, + /** @type {typeof uploadLocalImage | null} */ + handleImageUpload: null, + /** @type {typeof prepareImagesLocal | null} */ + prepareImagePayload: null, + /** @type {typeof deleteLocalFile | null} */ + deleteFile: null, + /** @type {typeof getLocalFileStream | null} */ + getDownloadStream: null, + handleFileUpload: uploadGoogleVertexMistralOCR, +}); + // Strategy Selector const getStrategyFunctions = (fileSource) => { if (fileSource === FileSources.firebase) { @@ -244,6 +268,8 @@ const getStrategyFunctions = (fileSource) => { return mistralOCRStrategy(); } else if (fileSource === FileSources.azure_mistral_ocr) { return azureMistralOCRStrategy(); + } else if (fileSource === FileSources.vertexai_mistral_ocr) { + return vertexMistralOCRStrategy(); } else { throw new Error('Invalid file source'); } diff --git a/api/server/services/Runs/StreamRunManager.js b/api/server/services/Runs/StreamRunManager.js index 4bab7326bb..4f6994e0cb 100644 --- a/api/server/services/Runs/StreamRunManager.js +++ b/api/server/services/Runs/StreamRunManager.js @@ -1,3 +1,6 @@ +const { sleep } = require('@librechat/agents'); +const { sendEvent } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { Constants, StepTypes, @@ -8,9 +11,8 @@ const { } = require('librechat-data-provider'); const { retrieveAndProcessFile } = require('~/server/services/Files/process'); const { processRequiredActions } = require('~/server/services/ToolService'); -const { createOnProgress, sendMessage, sleep } = require('~/server/utils'); const { processMessages } = require('~/server/services/Threads'); -const { logger } = require('~/config'); +const { createOnProgress } = require('~/server/utils'); /** * Implements the StreamRunManager functionality for managing the streaming @@ -126,7 +128,7 @@ class StreamRunManager { conversationId: this.finalMessage.conversationId, }; - sendMessage(this.res, contentData); + sendEvent(this.res, contentData); } /* <------------------ Misc. Helpers ------------------> */ @@ -302,7 +304,7 @@ class StreamRunManager { for (const d of delta[key]) { if (typeof d === 'object' && !Object.prototype.hasOwnProperty.call(d, 'index')) { - logger.warn('Expected an object with an \'index\' for array updates but got:', d); + logger.warn("Expected an object with an 'index' for array updates but got:", d); continue; } diff --git a/api/server/services/initializeMCP.js b/api/server/services/initializeMCP.js index d7c5ab7d8a..98b87d156e 100644 --- a/api/server/services/initializeMCP.js +++ b/api/server/services/initializeMCP.js @@ -1,9 +1,9 @@ const { logger } = require('@librechat/data-schemas'); -const { CacheKeys, processMCPEnv } = require('librechat-data-provider'); +const { CacheKeys } = require('librechat-data-provider'); +const { findToken, updateToken, createToken, deleteTokens } = require('~/models'); const { getMCPManager, getFlowStateManager } = require('~/config'); const { getCachedTools, setCachedTools } = require('./Config'); const { getLogStores } = require('~/cache'); -const { findToken, updateToken, createToken, deleteTokens } = require('~/models'); /** * Initialize MCP servers @@ -30,7 +30,6 @@ async function initializeMCP(app) { createToken, deleteTokens, }, - processMCPEnv, }); delete app.locals.mcpConfig; diff --git a/api/server/services/start/interface.js b/api/server/services/start/interface.js index c98fdb60bc..5c08b1af2e 100644 --- a/api/server/services/start/interface.js +++ b/api/server/services/start/interface.js @@ -41,6 +41,7 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol sidePanel: interfaceConfig?.sidePanel ?? defaults.sidePanel, privacyPolicy: interfaceConfig?.privacyPolicy ?? defaults.privacyPolicy, termsOfService: interfaceConfig?.termsOfService ?? defaults.termsOfService, + mcpServers: interfaceConfig?.mcpServers ?? defaults.mcpServers, bookmarks: interfaceConfig?.bookmarks ?? defaults.bookmarks, memories: shouldDisableMemories ? false : (interfaceConfig?.memories ?? defaults.memories), prompts: interfaceConfig?.prompts ?? defaults.prompts, diff --git a/api/server/utils/handleText.js b/api/server/utils/handleText.js index 680da5da44..36671c44ff 100644 --- a/api/server/utils/handleText.js +++ b/api/server/utils/handleText.js @@ -7,9 +7,9 @@ const { defaultAssistantsVersion, defaultAgentCapabilities, } = require('librechat-data-provider'); +const { sendEvent } = require('@librechat/api'); const { Providers } = require('@librechat/agents'); const partialRight = require('lodash/partialRight'); -const { sendMessage } = require('./streamResponse'); /** Helper function to escape special characters in regex * @param {string} string - The string to escape. @@ -37,7 +37,7 @@ const createOnProgress = ( basePayload.text = basePayload.text + chunk; const payload = Object.assign({}, basePayload, rest); - sendMessage(res, payload); + sendEvent(res, payload); if (_onProgress) { _onProgress(payload); } @@ -50,7 +50,7 @@ const createOnProgress = ( const sendIntermediateMessage = (res, payload, extraTokens = '') => { basePayload.text = basePayload.text + extraTokens; const message = Object.assign({}, basePayload, payload); - sendMessage(res, message); + sendEvent(res, message); if (i === 0) { basePayload.initial = false; } diff --git a/api/server/utils/import/importers-timestamp.spec.js b/api/server/utils/import/importers-timestamp.spec.js new file mode 100644 index 0000000000..2ce00de82b --- /dev/null +++ b/api/server/utils/import/importers-timestamp.spec.js @@ -0,0 +1,280 @@ +const { Constants } = require('librechat-data-provider'); +const { ImportBatchBuilder } = require('./importBatchBuilder'); +const { getImporter } = require('./importers'); + +// Mock the database methods +jest.mock('~/models/Conversation', () => ({ + bulkSaveConvos: jest.fn(), +})); +jest.mock('~/models/Message', () => ({ + bulkSaveMessages: jest.fn(), +})); +jest.mock('~/cache/getLogStores'); +const getLogStores = require('~/cache/getLogStores'); +const mockedCacheGet = jest.fn(); +getLogStores.mockImplementation(() => ({ + get: mockedCacheGet, +})); + +describe('Import Timestamp Ordering', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockedCacheGet.mockResolvedValue(null); + }); + + describe('LibreChat Import - Timestamp Issues', () => { + test('should maintain proper timestamp order between parent and child messages', async () => { + // Create a LibreChat export with out-of-order timestamps + const jsonData = { + conversationId: 'test-convo-123', + title: 'Test Conversation', + messages: [ + { + messageId: 'parent-1', + parentMessageId: Constants.NO_PARENT, + text: 'Parent Message', + sender: 'user', + isCreatedByUser: true, + createdAt: '2023-01-01T00:02:00Z', // Parent created AFTER child + }, + { + messageId: 'child-1', + parentMessageId: 'parent-1', + text: 'Child Message', + sender: 'assistant', + isCreatedByUser: false, + createdAt: '2023-01-01T00:01:00Z', // Child created BEFORE parent + }, + { + messageId: 'grandchild-1', + parentMessageId: 'child-1', + text: 'Grandchild Message', + sender: 'user', + isCreatedByUser: true, + createdAt: '2023-01-01T00:00:30Z', // Even earlier + }, + ], + }; + + const requestUserId = 'user-123'; + const importBatchBuilder = new ImportBatchBuilder(requestUserId); + jest.spyOn(importBatchBuilder, 'saveMessage'); + + const importer = getImporter(jsonData); + await importer(jsonData, requestUserId, () => importBatchBuilder); + + // Check the actual messages stored in the builder + const savedMessages = importBatchBuilder.messages; + + const parent = savedMessages.find((msg) => msg.text === 'Parent Message'); + const child = savedMessages.find((msg) => msg.text === 'Child Message'); + const grandchild = savedMessages.find((msg) => msg.text === 'Grandchild Message'); + + // Verify all messages were found + expect(parent).toBeDefined(); + expect(child).toBeDefined(); + expect(grandchild).toBeDefined(); + + // FIXED behavior: timestamps ARE corrected + expect(new Date(child.createdAt).getTime()).toBeGreaterThan( + new Date(parent.createdAt).getTime(), + ); + expect(new Date(grandchild.createdAt).getTime()).toBeGreaterThan( + new Date(child.createdAt).getTime(), + ); + }); + + test('should handle complex multi-branch scenario with out-of-order timestamps', async () => { + const jsonData = { + conversationId: 'complex-test-123', + title: 'Complex Test', + messages: [ + // Branch 1: Root -> A -> B with reversed timestamps + { + messageId: 'root-1', + parentMessageId: Constants.NO_PARENT, + text: 'Root 1', + sender: 'user', + isCreatedByUser: true, + createdAt: '2023-01-01T00:03:00Z', + }, + { + messageId: 'a-1', + parentMessageId: 'root-1', + text: 'A1', + sender: 'assistant', + isCreatedByUser: false, + createdAt: '2023-01-01T00:02:00Z', // Before parent + }, + { + messageId: 'b-1', + parentMessageId: 'a-1', + text: 'B1', + sender: 'user', + isCreatedByUser: true, + createdAt: '2023-01-01T00:01:00Z', // Before grandparent + }, + // Branch 2: Root -> C -> D with mixed timestamps + { + messageId: 'root-2', + parentMessageId: Constants.NO_PARENT, + text: 'Root 2', + sender: 'user', + isCreatedByUser: true, + createdAt: '2023-01-01T00:00:30Z', // Earlier than branch 1 + }, + { + messageId: 'c-2', + parentMessageId: 'root-2', + text: 'C2', + sender: 'assistant', + isCreatedByUser: false, + createdAt: '2023-01-01T00:04:00Z', // Much later + }, + { + messageId: 'd-2', + parentMessageId: 'c-2', + text: 'D2', + sender: 'user', + isCreatedByUser: true, + createdAt: '2023-01-01T00:02:30Z', // Between root and parent + }, + ], + }; + + const requestUserId = 'user-123'; + const importBatchBuilder = new ImportBatchBuilder(requestUserId); + jest.spyOn(importBatchBuilder, 'saveMessage'); + + const importer = getImporter(jsonData); + await importer(jsonData, requestUserId, () => importBatchBuilder); + + const savedMessages = importBatchBuilder.messages; + + // Verify that timestamps are preserved as-is (not corrected) + const root1 = savedMessages.find((msg) => msg.text === 'Root 1'); + const a1 = savedMessages.find((msg) => msg.text === 'A1'); + const b1 = savedMessages.find((msg) => msg.text === 'B1'); + const root2 = savedMessages.find((msg) => msg.text === 'Root 2'); + const c2 = savedMessages.find((msg) => msg.text === 'C2'); + const d2 = savedMessages.find((msg) => msg.text === 'D2'); + + // Branch 1: timestamps should now be in correct order + expect(new Date(a1.createdAt).getTime()).toBeGreaterThan(new Date(root1.createdAt).getTime()); + expect(new Date(b1.createdAt).getTime()).toBeGreaterThan(new Date(a1.createdAt).getTime()); + + // Branch 2: all timestamps should be properly ordered + expect(new Date(c2.createdAt).getTime()).toBeGreaterThan(new Date(root2.createdAt).getTime()); + expect(new Date(d2.createdAt).getTime()).toBeGreaterThan(new Date(c2.createdAt).getTime()); + }); + + test('recursive format should NOW have timestamp protection', async () => { + // Create a recursive LibreChat export with out-of-order timestamps + const jsonData = { + conversationId: 'recursive-test-123', + title: 'Recursive Test', + recursive: true, + messages: [ + { + messageId: 'parent-1', + parentMessageId: Constants.NO_PARENT, + text: 'Parent Message', + sender: 'User', + isCreatedByUser: true, + createdAt: '2023-01-01T00:02:00Z', // Parent created AFTER child + children: [ + { + messageId: 'child-1', + parentMessageId: 'parent-1', + text: 'Child Message', + sender: 'Assistant', + isCreatedByUser: false, + createdAt: '2023-01-01T00:01:00Z', // Child created BEFORE parent + children: [ + { + messageId: 'grandchild-1', + parentMessageId: 'child-1', + text: 'Grandchild Message', + sender: 'User', + isCreatedByUser: true, + createdAt: '2023-01-01T00:00:30Z', // Even earlier + children: [], + }, + ], + }, + ], + }, + ], + }; + + const requestUserId = 'user-123'; + const importBatchBuilder = new ImportBatchBuilder(requestUserId); + + const importer = getImporter(jsonData); + await importer(jsonData, requestUserId, () => importBatchBuilder); + + const savedMessages = importBatchBuilder.messages; + + // Messages should be saved + expect(savedMessages).toHaveLength(3); + + // In recursive format, timestamps are NOT included in the saved messages + // The saveMessage method doesn't receive createdAt for recursive imports + const parent = savedMessages.find((msg) => msg.text === 'Parent Message'); + const child = savedMessages.find((msg) => msg.text === 'Child Message'); + const grandchild = savedMessages.find((msg) => msg.text === 'Grandchild Message'); + + expect(parent).toBeDefined(); + expect(child).toBeDefined(); + expect(grandchild).toBeDefined(); + + // Recursive imports NOW preserve and correct timestamps + expect(parent.createdAt).toBeDefined(); + expect(child.createdAt).toBeDefined(); + expect(grandchild.createdAt).toBeDefined(); + + // Timestamps should be corrected to maintain proper order + expect(new Date(child.createdAt).getTime()).toBeGreaterThan( + new Date(parent.createdAt).getTime(), + ); + expect(new Date(grandchild.createdAt).getTime()).toBeGreaterThan( + new Date(child.createdAt).getTime(), + ); + }); + }); + + describe('Comparison with Fork Functionality', () => { + test('fork functionality correctly handles timestamp issues (for comparison)', async () => { + const { cloneMessagesWithTimestamps } = require('./fork'); + + const messagesToClone = [ + { + messageId: 'parent', + parentMessageId: Constants.NO_PARENT, + text: 'Parent Message', + createdAt: '2023-01-01T00:02:00Z', // Parent created AFTER child + }, + { + messageId: 'child', + parentMessageId: 'parent', + text: 'Child Message', + createdAt: '2023-01-01T00:01:00Z', // Child created BEFORE parent + }, + ]; + + const importBatchBuilder = new ImportBatchBuilder('user-123'); + jest.spyOn(importBatchBuilder, 'saveMessage'); + + cloneMessagesWithTimestamps(messagesToClone, importBatchBuilder); + + const savedMessages = importBatchBuilder.messages; + const parent = savedMessages.find((msg) => msg.text === 'Parent Message'); + const child = savedMessages.find((msg) => msg.text === 'Child Message'); + + // Fork functionality DOES correct the timestamps + expect(new Date(child.createdAt).getTime()).toBeGreaterThan( + new Date(parent.createdAt).getTime(), + ); + }); + }); +}); diff --git a/api/server/utils/import/importers.js b/api/server/utils/import/importers.js index b828fed021..ce5ab62454 100644 --- a/api/server/utils/import/importers.js +++ b/api/server/utils/import/importers.js @@ -1,6 +1,7 @@ const { v4: uuidv4 } = require('uuid'); const { EModelEndpoint, Constants, openAISettings, CacheKeys } = require('librechat-data-provider'); const { createImportBatchBuilder } = require('./importBatchBuilder'); +const { cloneMessagesWithTimestamps } = require('./fork'); const getLogStores = require('~/cache/getLogStores'); const logger = require('~/config/winston'); @@ -107,67 +108,47 @@ async function importLibreChatConvo( if (jsonData.recursive) { /** - * Recursively traverse the messages tree and save each message to the database. + * Flatten the recursive message tree into a flat array * @param {TMessage[]} messages * @param {string} parentMessageId + * @param {TMessage[]} flatMessages */ - const traverseMessages = async (messages, parentMessageId = null) => { + const flattenMessages = ( + messages, + parentMessageId = Constants.NO_PARENT, + flatMessages = [], + ) => { for (const message of messages) { if (!message.text && !message.content) { continue; } - let savedMessage; - if (message.sender?.toLowerCase() === 'user' || message.isCreatedByUser) { - savedMessage = await importBatchBuilder.saveMessage({ - text: message.text, - content: message.content, - sender: 'user', - isCreatedByUser: true, - parentMessageId: parentMessageId, - }); - } else { - savedMessage = await importBatchBuilder.saveMessage({ - text: message.text, - content: message.content, - sender: message.sender, - isCreatedByUser: false, - model: options.model, - parentMessageId: parentMessageId, - }); - } + const flatMessage = { + ...message, + parentMessageId: parentMessageId, + children: undefined, // Remove children from flat structure + }; + flatMessages.push(flatMessage); if (!firstMessageDate && message.createdAt) { firstMessageDate = new Date(message.createdAt); } if (message.children && message.children.length > 0) { - await traverseMessages(message.children, savedMessage.messageId); + flattenMessages(message.children, message.messageId, flatMessages); } } + return flatMessages; }; - await traverseMessages(messagesToImport); + const flatMessages = flattenMessages(messagesToImport); + cloneMessagesWithTimestamps(flatMessages, importBatchBuilder); } else if (messagesToImport) { - const idMapping = new Map(); - + cloneMessagesWithTimestamps(messagesToImport, importBatchBuilder); for (const message of messagesToImport) { if (!firstMessageDate && message.createdAt) { firstMessageDate = new Date(message.createdAt); } - const newMessageId = uuidv4(); - idMapping.set(message.messageId, newMessageId); - - const clonedMessage = { - ...message, - messageId: newMessageId, - parentMessageId: - message.parentMessageId && message.parentMessageId !== Constants.NO_PARENT - ? idMapping.get(message.parentMessageId) || Constants.NO_PARENT - : Constants.NO_PARENT, - }; - - importBatchBuilder.saveMessage(clonedMessage); } } else { throw new Error('Invalid LibreChat file format'); diff --git a/api/server/utils/import/importers.spec.js b/api/server/utils/import/importers.spec.js index f08644d5c0..23b7e70901 100644 --- a/api/server/utils/import/importers.spec.js +++ b/api/server/utils/import/importers.spec.js @@ -175,36 +175,60 @@ describe('importLibreChatConvo', () => { jest.spyOn(importBatchBuilder, 'saveMessage'); jest.spyOn(importBatchBuilder, 'saveBatch'); - // When const importer = getImporter(jsonData); await importer(jsonData, requestUserId, () => importBatchBuilder); - // Create a map to track original message IDs to new UUIDs - const idToUUIDMap = new Map(); - importBatchBuilder.saveMessage.mock.calls.forEach((call) => { - const message = call[0]; - idToUUIDMap.set(message.originalMessageId, message.messageId); + // Get the imported messages + const messages = importBatchBuilder.messages; + expect(messages.length).toBeGreaterThan(0); + + // Build maps for verification + const textToMessageMap = new Map(); + const messageIdToMessage = new Map(); + messages.forEach((msg) => { + if (msg.text) { + // For recursive imports, text might be very long, so just use the first 100 chars as key + const textKey = msg.text.substring(0, 100); + textToMessageMap.set(textKey, msg); + } + messageIdToMessage.set(msg.messageId, msg); }); - const checkChildren = (children, parentId) => { - children.forEach((child) => { - const childUUID = idToUUIDMap.get(child.messageId); - const expectedParentId = idToUUIDMap.get(parentId) ?? null; - const messageCall = importBatchBuilder.saveMessage.mock.calls.find( - (call) => call[0].messageId === childUUID, - ); - - const actualParentId = messageCall[0].parentMessageId; - expect(actualParentId).toBe(expectedParentId); - - if (child.children && child.children.length > 0) { - checkChildren(child.children, child.messageId); + // Count expected messages from the tree + const countMessagesInTree = (nodes) => { + let count = 0; + nodes.forEach((node) => { + if (node.text || node.content) { + count++; + } + if (node.children && node.children.length > 0) { + count += countMessagesInTree(node.children); } }); + return count; }; - // Start hierarchy validation from root messages - checkChildren(jsonData.messages, null); + const expectedMessageCount = countMessagesInTree(jsonData.messages); + expect(messages.length).toBe(expectedMessageCount); + + // Verify all messages have valid parent relationships + messages.forEach((msg) => { + if (msg.parentMessageId !== Constants.NO_PARENT) { + const parent = messageIdToMessage.get(msg.parentMessageId); + expect(parent).toBeDefined(); + + // Verify timestamp ordering + if (msg.createdAt && parent.createdAt) { + expect(new Date(msg.createdAt).getTime()).toBeGreaterThanOrEqual( + new Date(parent.createdAt).getTime(), + ); + } + } + }); + + // Verify at least one root message exists + const rootMessages = messages.filter((msg) => msg.parentMessageId === Constants.NO_PARENT); + expect(rootMessages.length).toBeGreaterThan(0); expect(importBatchBuilder.saveBatch).toHaveBeenCalled(); }); diff --git a/api/server/utils/index.js b/api/server/utils/index.js index 2661ff75e1..2672f4f2ea 100644 --- a/api/server/utils/index.js +++ b/api/server/utils/index.js @@ -1,11 +1,9 @@ -const streamResponse = require('./streamResponse'); const removePorts = require('./removePorts'); const countTokens = require('./countTokens'); const handleText = require('./handleText'); const sendEmail = require('./sendEmail'); const queue = require('./queue'); const files = require('./files'); -const math = require('./math'); /** * Check if email configuration is set @@ -28,7 +26,6 @@ function checkEmailConfig() { } module.exports = { - ...streamResponse, checkEmailConfig, ...handleText, countTokens, @@ -36,5 +33,4 @@ module.exports = { sendEmail, ...files, ...queue, - math, }; diff --git a/api/strategies/openIdJwtStrategy.js b/api/strategies/openIdJwtStrategy.js index dae8d17bc6..cc90e20036 100644 --- a/api/strategies/openIdJwtStrategy.js +++ b/api/strategies/openIdJwtStrategy.js @@ -1,4 +1,5 @@ const { SystemRoles } = require('librechat-data-provider'); +const { HttpsProxyAgent } = require('https-proxy-agent'); const { Strategy: JwtStrategy, ExtractJwt } = require('passport-jwt'); const { updateUser, findUser } = require('~/models'); const { logger } = require('~/config'); @@ -13,17 +14,23 @@ const { isEnabled } = require('~/server/utils'); * The strategy extracts the JWT from the Authorization header as a Bearer token. * The JWT is then verified using the signing key, and the user is retrieved from the database. */ -const openIdJwtLogin = (openIdConfig) => - new JwtStrategy( +const openIdJwtLogin = (openIdConfig) => { + let jwksRsaOptions = { + cache: isEnabled(process.env.OPENID_JWKS_URL_CACHE_ENABLED) || true, + cacheMaxAge: process.env.OPENID_JWKS_URL_CACHE_TIME + ? eval(process.env.OPENID_JWKS_URL_CACHE_TIME) + : 60000, + jwksUri: openIdConfig.serverMetadata().jwks_uri, + }; + + if (process.env.PROXY) { + jwksRsaOptions.requestAgent = new HttpsProxyAgent(process.env.PROXY); + } + + return new JwtStrategy( { jwtFromRequest: ExtractJwt.fromAuthHeaderAsBearerToken(), - secretOrKeyProvider: jwksRsa.passportJwtSecret({ - cache: isEnabled(process.env.OPENID_JWKS_URL_CACHE_ENABLED) || true, - cacheMaxAge: process.env.OPENID_JWKS_URL_CACHE_TIME - ? eval(process.env.OPENID_JWKS_URL_CACHE_TIME) - : 60000, - jwksUri: openIdConfig.serverMetadata().jwks_uri, - }), + secretOrKeyProvider: jwksRsa.passportJwtSecret(jwksRsaOptions), }, async (payload, done) => { try { @@ -48,5 +55,6 @@ const openIdJwtLogin = (openIdConfig) => } }, ); +}; module.exports = openIdJwtLogin; diff --git a/api/strategies/openidStrategy.js b/api/strategies/openidStrategy.js index 2449872a9d..563ac8257e 100644 --- a/api/strategies/openidStrategy.js +++ b/api/strategies/openidStrategy.js @@ -49,7 +49,7 @@ async function customFetch(url, options) { logger.info(`[openidStrategy] proxy agent configured: ${process.env.PROXY}`); fetchOptions = { ...options, - dispatcher: new HttpsProxyAgent(process.env.PROXY), + dispatcher: new undici.ProxyAgent(process.env.PROXY), }; } @@ -118,7 +118,7 @@ class CustomOpenIDStrategy extends OpenIDStrategy { */ const exchangeAccessTokenIfNeeded = async (config, accessToken, sub, fromCache = false) => { const tokensCache = getLogStores(CacheKeys.OPENID_EXCHANGED_TOKENS); - const onBehalfFlowRequired = isEnabled(process.env.OPENID_ON_BEHALF_FLOW_FOR_USERINFRO_REQUIRED); + const onBehalfFlowRequired = isEnabled(process.env.OPENID_ON_BEHALF_FLOW_FOR_USERINFO_REQUIRED); if (onBehalfFlowRequired) { if (fromCache) { const cachedToken = await tokensCache.get(sub); @@ -130,7 +130,7 @@ const exchangeAccessTokenIfNeeded = async (config, accessToken, sub, fromCache = config, 'urn:ietf:params:oauth:grant-type:jwt-bearer', { - scope: process.env.OPENID_ON_BEHALF_FLOW_USERINFRO_SCOPE || 'user.read', + scope: process.env.OPENID_ON_BEHALF_FLOW_USERINFO_SCOPE || 'user.read', assertion: accessToken, requested_token_use: 'on_behalf_of', }, diff --git a/api/typedefs.js b/api/typedefs.js index 58cd802425..c0e0dd5f46 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -1503,7 +1503,6 @@ * @property {boolean|{userProvide: boolean}} [anthropic] - Flag to indicate if Anthropic endpoint is user provided, or its configuration. * @property {boolean|{userProvide: boolean}} [google] - Flag to indicate if Google endpoint is user provided, or its configuration. * @property {boolean|{userProvide: boolean, userProvideURL: boolean, name: string}} [custom] - Custom Endpoint configuration. - * @property {boolean|GptPlugins} [gptPlugins] - Configuration for GPT plugins. * @memberof typedefs */ diff --git a/api/utils/index.js b/api/utils/index.js index 50b8c46d99..b80c9b0c31 100644 --- a/api/utils/index.js +++ b/api/utils/index.js @@ -1,11 +1,9 @@ -const loadYaml = require('./loadYaml'); const tokenHelpers = require('./tokens'); const deriveBaseURL = require('./deriveBaseURL'); const extractBaseURL = require('./extractBaseURL'); const findMessageContent = require('./findMessageContent'); module.exports = { - loadYaml, deriveBaseURL, extractBaseURL, ...tokenHelpers, diff --git a/api/utils/loadYaml.js b/api/utils/loadYaml.js deleted file mode 100644 index 50e5d23ec3..0000000000 --- a/api/utils/loadYaml.js +++ /dev/null @@ -1,13 +0,0 @@ -const fs = require('fs'); -const yaml = require('js-yaml'); - -function loadYaml(filepath) { - try { - let fileContents = fs.readFileSync(filepath, 'utf8'); - return yaml.load(fileContents); - } catch (e) { - return e; - } -} - -module.exports = loadYaml; diff --git a/client/package.json b/client/package.json index 67cbec2820..9c86cd5d4d 100644 --- a/client/package.json +++ b/client/package.json @@ -1,6 +1,6 @@ { "name": "@librechat/frontend", - "version": "v0.7.8", + "version": "v0.7.9-rc1", "description": "", "type": "module", "scripts": { diff --git a/client/src/Providers/ActivePanelContext.tsx b/client/src/Providers/ActivePanelContext.tsx new file mode 100644 index 0000000000..4a8d6ccfc4 --- /dev/null +++ b/client/src/Providers/ActivePanelContext.tsx @@ -0,0 +1,37 @@ +import { createContext, useContext, useState, ReactNode } from 'react'; + +interface ActivePanelContextType { + active: string | undefined; + setActive: (id: string) => void; +} + +const ActivePanelContext = createContext(undefined); + +export function ActivePanelProvider({ + children, + defaultActive, +}: { + children: ReactNode; + defaultActive?: string; +}) { + const [active, _setActive] = useState(defaultActive); + + const setActive = (id: string) => { + localStorage.setItem('side:active-panel', id); + _setActive(id); + }; + + return ( + + {children} + + ); +} + +export function useActivePanel() { + const context = useContext(ActivePanelContext); + if (context === undefined) { + throw new Error('useActivePanel must be used within an ActivePanelProvider'); + } + return context; +} diff --git a/client/src/Providers/AgentPanelContext.tsx b/client/src/Providers/AgentPanelContext.tsx index 2cc64ba3ed..409d8998fb 100644 --- a/client/src/Providers/AgentPanelContext.tsx +++ b/client/src/Providers/AgentPanelContext.tsx @@ -1,9 +1,9 @@ import React, { createContext, useContext, useState } from 'react'; import { Constants, EModelEndpoint } from 'librechat-data-provider'; -import type { TPlugin, AgentToolType, Action, MCP } from 'librechat-data-provider'; +import type { MCP, Action, TPlugin, AgentToolType } from 'librechat-data-provider'; import type { AgentPanelContextType } from '~/common'; import { useAvailableToolsQuery, useGetActionsQuery } from '~/data-provider'; -import { useLocalize } from '~/hooks'; +import { useLocalize, useGetAgentsConfig } from '~/hooks'; import { Panel } from '~/common'; const AgentPanelContext = createContext(undefined); @@ -40,57 +40,60 @@ export function AgentPanelProvider({ children }: { children: React.ReactNode }) agent_id: agent_id || '', })) || []; - const groupedTools = - tools?.reduce( - (acc, tool) => { - if (tool.tool_id.includes(Constants.mcp_delimiter)) { - const [_toolName, serverName] = tool.tool_id.split(Constants.mcp_delimiter); - const groupKey = `${serverName.toLowerCase()}`; - if (!acc[groupKey]) { - acc[groupKey] = { - tool_id: groupKey, - metadata: { - name: `${serverName}`, - pluginKey: groupKey, - description: `${localize('com_ui_tool_collection_prefix')} ${serverName}`, - icon: tool.metadata.icon || '', - } as TPlugin, - agent_id: agent_id || '', - tools: [], - }; - } - acc[groupKey].tools?.push({ - tool_id: tool.tool_id, - metadata: tool.metadata, - agent_id: agent_id || '', - }); - } else { - acc[tool.tool_id] = { - tool_id: tool.tool_id, - metadata: tool.metadata, + const groupedTools = tools?.reduce( + (acc, tool) => { + if (tool.tool_id.includes(Constants.mcp_delimiter)) { + const [_toolName, serverName] = tool.tool_id.split(Constants.mcp_delimiter); + const groupKey = `${serverName.toLowerCase()}`; + if (!acc[groupKey]) { + acc[groupKey] = { + tool_id: groupKey, + metadata: { + name: `${serverName}`, + pluginKey: groupKey, + description: `${localize('com_ui_tool_collection_prefix')} ${serverName}`, + icon: tool.metadata.icon || '', + } as TPlugin, agent_id: agent_id || '', + tools: [], }; } - return acc; - }, - {} as Record, - ) || {}; + acc[groupKey].tools?.push({ + tool_id: tool.tool_id, + metadata: tool.metadata, + agent_id: agent_id || '', + }); + } else { + acc[tool.tool_id] = { + tool_id: tool.tool_id, + metadata: tool.metadata, + agent_id: agent_id || '', + }; + } + return acc; + }, + {} as Record, + ); - const value = { - action, - setAction, + const { agentsConfig, endpointsConfig } = useGetAgentsConfig(); + + const value: AgentPanelContextType = { mcp, - setMcp, mcps, - setMcps, - activePanel, - setActivePanel, - setCurrentAgentId, - agent_id, - groupedTools, /** Query data for actions and tools */ - actions, tools, + action, + setMcp, + actions, + setMcps, + agent_id, + setAction, + activePanel, + groupedTools, + agentsConfig, + setActivePanel, + endpointsConfig, + setCurrentAgentId, }; return {children}; diff --git a/client/src/Providers/BadgeRowContext.tsx b/client/src/Providers/BadgeRowContext.tsx new file mode 100644 index 0000000000..e54411ed84 --- /dev/null +++ b/client/src/Providers/BadgeRowContext.tsx @@ -0,0 +1,188 @@ +import React, { createContext, useContext, useEffect, useRef } from 'react'; +import { useSetRecoilState } from 'recoil'; +import { Tools, Constants, LocalStorageKeys, AgentCapabilities } from 'librechat-data-provider'; +import type { TAgentsEndpoint } from 'librechat-data-provider'; +import { + useSearchApiKeyForm, + useGetAgentsConfig, + useCodeApiKeyForm, + useToolToggle, + useMCPSelect, +} from '~/hooks'; +import { useGetStartupConfig } from '~/data-provider'; +import { ephemeralAgentByConvoId } from '~/store'; + +interface BadgeRowContextType { + conversationId?: string | null; + agentsConfig?: TAgentsEndpoint | null; + mcpSelect: ReturnType; + webSearch: ReturnType; + artifacts: ReturnType; + fileSearch: ReturnType; + codeInterpreter: ReturnType; + codeApiKeyForm: ReturnType; + searchApiKeyForm: ReturnType; + startupConfig: ReturnType['data']; +} + +const BadgeRowContext = createContext(undefined); + +export function useBadgeRowContext() { + const context = useContext(BadgeRowContext); + if (context === undefined) { + throw new Error('useBadgeRowContext must be used within a BadgeRowProvider'); + } + return context; +} + +interface BadgeRowProviderProps { + children: React.ReactNode; + isSubmitting?: boolean; + conversationId?: string | null; +} + +export default function BadgeRowProvider({ + children, + isSubmitting, + conversationId, +}: BadgeRowProviderProps) { + const hasInitializedRef = useRef(false); + const lastKeyRef = useRef(''); + const { agentsConfig } = useGetAgentsConfig(); + const key = conversationId ?? Constants.NEW_CONVO; + const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(key)); + + /** Initialize ephemeralAgent from localStorage on mount and when conversation changes */ + useEffect(() => { + if (isSubmitting) { + return; + } + // Check if this is a new conversation or the first load + if (!hasInitializedRef.current || lastKeyRef.current !== key) { + hasInitializedRef.current = true; + lastKeyRef.current = key; + + // Load all localStorage values + const codeToggleKey = `${LocalStorageKeys.LAST_CODE_TOGGLE_}${key}`; + const webSearchToggleKey = `${LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_}${key}`; + const fileSearchToggleKey = `${LocalStorageKeys.LAST_FILE_SEARCH_TOGGLE_}${key}`; + const artifactsToggleKey = `${LocalStorageKeys.LAST_ARTIFACTS_TOGGLE_}${key}`; + + const codeToggleValue = localStorage.getItem(codeToggleKey); + const webSearchToggleValue = localStorage.getItem(webSearchToggleKey); + const fileSearchToggleValue = localStorage.getItem(fileSearchToggleKey); + const artifactsToggleValue = localStorage.getItem(artifactsToggleKey); + + const initialValues: Record = {}; + + if (codeToggleValue !== null) { + try { + initialValues[Tools.execute_code] = JSON.parse(codeToggleValue); + } catch (e) { + console.error('Failed to parse code toggle value:', e); + } + } + + if (webSearchToggleValue !== null) { + try { + initialValues[Tools.web_search] = JSON.parse(webSearchToggleValue); + } catch (e) { + console.error('Failed to parse web search toggle value:', e); + } + } + + if (fileSearchToggleValue !== null) { + try { + initialValues[Tools.file_search] = JSON.parse(fileSearchToggleValue); + } catch (e) { + console.error('Failed to parse file search toggle value:', e); + } + } + + if (artifactsToggleValue !== null) { + try { + initialValues[AgentCapabilities.artifacts] = JSON.parse(artifactsToggleValue); + } catch (e) { + console.error('Failed to parse artifacts toggle value:', e); + } + } + + // Always set values for all tools (use defaults if not in localStorage) + // If ephemeralAgent is null, create a new object with just our tool values + setEphemeralAgent((prev) => ({ + ...(prev || {}), + [Tools.execute_code]: initialValues[Tools.execute_code] ?? false, + [Tools.web_search]: initialValues[Tools.web_search] ?? false, + [Tools.file_search]: initialValues[Tools.file_search] ?? false, + [AgentCapabilities.artifacts]: initialValues[AgentCapabilities.artifacts] ?? false, + })); + } + }, [key, isSubmitting, setEphemeralAgent]); + + /** Startup config */ + const { data: startupConfig } = useGetStartupConfig(); + + /** MCPSelect hook */ + const mcpSelect = useMCPSelect({ conversationId }); + + /** CodeInterpreter hooks */ + const codeApiKeyForm = useCodeApiKeyForm({}); + const { setIsDialogOpen: setCodeDialogOpen } = codeApiKeyForm; + + const codeInterpreter = useToolToggle({ + conversationId, + setIsDialogOpen: setCodeDialogOpen, + toolKey: Tools.execute_code, + localStorageKey: LocalStorageKeys.LAST_CODE_TOGGLE_, + authConfig: { + toolId: Tools.execute_code, + queryOptions: { retry: 1 }, + }, + }); + + /** WebSearch hooks */ + const searchApiKeyForm = useSearchApiKeyForm({}); + const { setIsDialogOpen: setWebSearchDialogOpen } = searchApiKeyForm; + + const webSearch = useToolToggle({ + conversationId, + toolKey: Tools.web_search, + localStorageKey: LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_, + setIsDialogOpen: setWebSearchDialogOpen, + authConfig: { + toolId: Tools.web_search, + queryOptions: { retry: 1 }, + }, + }); + + /** FileSearch hook */ + const fileSearch = useToolToggle({ + conversationId, + toolKey: Tools.file_search, + localStorageKey: LocalStorageKeys.LAST_FILE_SEARCH_TOGGLE_, + isAuthenticated: true, + }); + + /** Artifacts hook - using a custom key since it's not a Tool but a capability */ + const artifacts = useToolToggle({ + conversationId, + toolKey: AgentCapabilities.artifacts, + localStorageKey: LocalStorageKeys.LAST_ARTIFACTS_TOGGLE_, + isAuthenticated: true, + }); + + const value: BadgeRowContextType = { + mcpSelect, + webSearch, + artifacts, + fileSearch, + agentsConfig, + startupConfig, + conversationId, + codeApiKeyForm, + codeInterpreter, + searchApiKeyForm, + }; + + return {children}; +} diff --git a/client/src/Providers/index.ts b/client/src/Providers/index.ts index 41c9cdceb3..b455cb3f1e 100644 --- a/client/src/Providers/index.ts +++ b/client/src/Providers/index.ts @@ -1,6 +1,7 @@ export { default as AssistantsProvider } from './AssistantsContext'; export { default as AgentsProvider } from './AgentsContext'; export { default as ToastProvider } from './ToastContext'; +export * from './ActivePanelContext'; export * from './AgentPanelContext'; export * from './ChatContext'; export * from './ShareContext'; @@ -22,3 +23,5 @@ export * from './CodeBlockContext'; export * from './ToolCallsMapContext'; export * from './SetConvoContext'; export * from './SearchContext'; +export * from './BadgeRowContext'; +export { default as BadgeRowProvider } from './BadgeRowContext'; diff --git a/client/src/common/types.ts b/client/src/common/types.ts index 214dc349b5..52575e180d 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -206,9 +206,7 @@ export type AgentPanelProps = { setActivePanel: React.Dispatch>; setMcp: React.Dispatch>; setAction: React.Dispatch>; - endpointsConfig?: t.TEndpointsConfig; setCurrentAgentId: React.Dispatch>; - agentsConfig?: t.TAgentsEndpoint | null; }; export type AgentPanelContextType = { @@ -219,12 +217,14 @@ export type AgentPanelContextType = { mcps?: t.MCP[]; setMcp: React.Dispatch>; setMcps: React.Dispatch>; - groupedTools: Record; tools: t.AgentToolType[]; activePanel?: string; setActivePanel: React.Dispatch>; setCurrentAgentId: React.Dispatch>; + groupedTools?: Record; agent_id?: string; + agentsConfig?: t.TAgentsEndpoint | null; + endpointsConfig?: t.TEndpointsConfig | null; }; export type AgentModelPanelProps = { @@ -336,6 +336,11 @@ export type TAskProps = { export type TOptions = { editedMessageId?: string | null; editedText?: string | null; + editedContent?: { + index: number; + text: string; + type: 'text' | 'think'; + }; isRegenerate?: boolean; isContinued?: boolean; isEdited?: boolean; diff --git a/client/src/components/Chat/ExportAndShareMenu.tsx b/client/src/components/Chat/ExportAndShareMenu.tsx index 0ac0144da3..aa9ee8be9b 100644 --- a/client/src/components/Chat/ExportAndShareMenu.tsx +++ b/client/src/components/Chat/ExportAndShareMenu.tsx @@ -68,6 +68,7 @@ export default function ExportAndShareMenu({ return ( <> (() => { + if (typeof toggleState === 'string' && toggleState) { + return { enabled: true, mode: toggleState }; + } + return { enabled: false, mode: '' }; + }, [toggleState]); + + const isEnabled = currentState.enabled; + const isShadcnEnabled = currentState.mode === ArtifactModes.SHADCNUI; + const isCustomEnabled = currentState.mode === ArtifactModes.CUSTOM; + + const handleToggle = useCallback(() => { + if (isEnabled) { + debouncedChange({ value: '' }); + } else { + debouncedChange({ value: ArtifactModes.DEFAULT }); + } + }, [isEnabled, debouncedChange]); + + const handleShadcnToggle = useCallback(() => { + if (isShadcnEnabled) { + debouncedChange({ value: ArtifactModes.DEFAULT }); + } else { + debouncedChange({ value: ArtifactModes.SHADCNUI }); + } + }, [isShadcnEnabled, debouncedChange]); + + const handleCustomToggle = useCallback(() => { + if (isCustomEnabled) { + debouncedChange({ value: ArtifactModes.DEFAULT }); + } else { + debouncedChange({ value: ArtifactModes.CUSTOM }); + } + }, [isCustomEnabled, debouncedChange]); + + if (!isEnabled && !isPinned) { + return null; + } + + return ( +
+ } + /> + + {isEnabled && ( + + e.stopPropagation()} + > + + + + +
+
+ {localize('com_ui_artifacts_options')} +
+ + {/* Include shadcn/ui Option */} + { + event.preventDefault(); + event.stopPropagation(); + handleShadcnToggle(); + }} + disabled={isCustomEnabled} + className={cn( + 'mb-1 flex items-center justify-between rounded-lg px-2 py-2', + 'cursor-pointer outline-none transition-colors', + 'hover:bg-black/[0.075] dark:hover:bg-white/10', + 'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10', + isCustomEnabled && 'cursor-not-allowed opacity-50', + )} + > +
+ + {localize('com_ui_include_shadcnui' as any)} +
+
+ + {/* Custom Prompt Mode Option */} + { + event.preventDefault(); + event.stopPropagation(); + handleCustomToggle(); + }} + className={cn( + 'flex items-center justify-between rounded-lg px-2 py-2', + 'cursor-pointer outline-none transition-colors', + 'hover:bg-black/[0.075] dark:hover:bg-white/10', + 'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10', + )} + > +
+ + {localize('com_ui_custom_prompt_mode' as any)} +
+
+
+
+
+ )} +
+ ); +} + +export default memo(Artifacts); diff --git a/client/src/components/Chat/Input/ArtifactsSubMenu.tsx b/client/src/components/Chat/Input/ArtifactsSubMenu.tsx new file mode 100644 index 0000000000..944ecb66c7 --- /dev/null +++ b/client/src/components/Chat/Input/ArtifactsSubMenu.tsx @@ -0,0 +1,147 @@ +import React from 'react'; +import * as Ariakit from '@ariakit/react'; +import { ChevronRight, WandSparkles } from 'lucide-react'; +import { ArtifactModes } from 'librechat-data-provider'; +import { PinIcon } from '~/components/svg'; +import { useLocalize } from '~/hooks'; +import { cn } from '~/utils'; + +interface ArtifactsSubMenuProps { + isArtifactsPinned: boolean; + setIsArtifactsPinned: (value: boolean) => void; + artifactsMode: string; + handleArtifactsToggle: () => void; + handleShadcnToggle: () => void; + handleCustomToggle: () => void; +} + +const ArtifactsSubMenu = ({ + isArtifactsPinned, + setIsArtifactsPinned, + artifactsMode, + handleArtifactsToggle, + handleShadcnToggle, + handleCustomToggle, + ...props +}: ArtifactsSubMenuProps) => { + const localize = useLocalize(); + + const menuStore = Ariakit.useMenuStore({ + focusLoop: true, + showTimeout: 100, + placement: 'right', + }); + + const isEnabled = artifactsMode !== '' && artifactsMode !== undefined; + const isShadcnEnabled = artifactsMode === ArtifactModes.SHADCNUI; + const isCustomEnabled = artifactsMode === ArtifactModes.CUSTOM; + + return ( + + ) => { + e.stopPropagation(); + handleArtifactsToggle(); + }} + onMouseEnter={() => { + if (isEnabled) { + menuStore.show(); + } + }} + className="flex w-full cursor-pointer items-center justify-between rounded-lg p-2 hover:bg-surface-hover" + /> + } + > +
+ + {localize('com_ui_artifacts')} + {isEnabled && } +
+ +
+ + {isEnabled && ( + +
+
+ {localize('com_ui_artifacts_options')} +
+ + {/* Include shadcn/ui Option */} + { + event.preventDefault(); + event.stopPropagation(); + handleShadcnToggle(); + }} + disabled={isCustomEnabled} + className={cn( + 'mb-1 flex items-center justify-between rounded-lg px-2 py-2', + 'cursor-pointer text-text-primary outline-none transition-colors', + 'hover:bg-black/[0.075] dark:hover:bg-white/10', + 'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10', + isCustomEnabled && 'cursor-not-allowed opacity-50', + )} + > +
+ + {localize('com_ui_include_shadcnui' as any)} +
+
+ + {/* Custom Prompt Mode Option */} + { + event.preventDefault(); + event.stopPropagation(); + handleCustomToggle(); + }} + className={cn( + 'flex items-center justify-between rounded-lg px-2 py-2', + 'cursor-pointer text-text-primary outline-none transition-colors', + 'hover:bg-black/[0.075] dark:hover:bg-white/10', + 'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10', + )} + > +
+ + {localize('com_ui_custom_prompt_mode' as any)} +
+
+
+
+ )} +
+ ); +}; + +export default React.memo(ArtifactsSubMenu); diff --git a/client/src/components/Chat/Input/BadgeRow.tsx b/client/src/components/Chat/Input/BadgeRow.tsx index ed9f4b82c2..d77fc5426c 100644 --- a/client/src/components/Chat/Input/BadgeRow.tsx +++ b/client/src/components/Chat/Input/BadgeRow.tsx @@ -1,19 +1,24 @@ import React, { memo, - useState, useRef, - useEffect, - useCallback, useMemo, + useState, + useEffect, forwardRef, useReducer, + useCallback, } from 'react'; import { useRecoilValue, useRecoilCallback } from 'recoil'; import type { LucideIcon } from 'lucide-react'; import CodeInterpreter from './CodeInterpreter'; +import { BadgeRowProvider } from '~/Providers'; +import ToolsDropdown from './ToolsDropdown'; import type { BadgeItem } from '~/common'; import { useChatBadges } from '~/hooks'; import { Badge } from '~/components/ui'; +import ToolDialogs from './ToolDialogs'; +import FileSearch from './FileSearch'; +import Artifacts from './Artifacts'; import MCPSelect from './MCPSelect'; import WebSearch from './WebSearch'; import store from '~/store'; @@ -23,6 +28,7 @@ interface BadgeRowProps { onChange: (badges: Pick[]) => void; onToggle?: (badgeId: string, currentActive: boolean) => void; conversationId?: string | null; + isSubmitting?: boolean; isInChat: boolean; } @@ -136,6 +142,7 @@ const dragReducer = (state: DragState, action: DragAction): DragState => { function BadgeRow({ showEphemeralBadges, conversationId, + isSubmitting, onChange, onToggle, isInChat, @@ -313,78 +320,84 @@ function BadgeRow({ }, [dragState.draggedBadge, handleMouseMove, handleMouseUp]); return ( -
- {tempBadges.map((badge, index) => ( - - {dragState.draggedBadge && dragState.insertIndex === index && ghostBadge && ( -
- -
- )} - -
- ))} - {dragState.draggedBadge && dragState.insertIndex === tempBadges.length && ghostBadge && ( -
- -
- )} - {showEphemeralBadges === true && ( - <> - - - - - )} - {ghostBadge && ( -
- -
- )} -
+ +
+ {showEphemeralBadges === true && } + {tempBadges.map((badge, index) => ( + + {dragState.draggedBadge && dragState.insertIndex === index && ghostBadge && ( +
+ +
+ )} + +
+ ))} + {dragState.draggedBadge && dragState.insertIndex === tempBadges.length && ghostBadge && ( +
+ +
+ )} + {showEphemeralBadges === true && ( + <> + + + + + + + )} + {ghostBadge && ( +
+ +
+ )} +
+ +
); } diff --git a/client/src/components/Chat/Input/ChatForm.tsx b/client/src/components/Chat/Input/ChatForm.tsx index 23bece3626..0ca6448094 100644 --- a/client/src/components/Chat/Input/ChatForm.tsx +++ b/client/src/components/Chat/Input/ChatForm.tsx @@ -305,6 +305,7 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => { { - if (rawCurrentValue) { - try { - const currentValue = rawCurrentValue?.trim() ?? ''; - if (currentValue === 'true' && value === false) { - return true; - } - } catch (e) { - console.error(e); - } - } - return value !== undefined && value !== null && value !== '' && value !== false; -}; - -function CodeInterpreter({ conversationId }: { conversationId?: string | null }) { - const triggerRef = useRef(null); +function CodeInterpreter() { const localize = useLocalize(); - const key = conversationId ?? Constants.NEW_CONVO; + const { codeInterpreter, codeApiKeyForm } = useBadgeRowContext(); + const { toggleState: runCode, debouncedChange, isPinned } = codeInterpreter; + const { badgeTriggerRef } = codeApiKeyForm; const canRunCode = useHasAccess({ permissionType: PermissionTypes.RUN_CODE, permission: Permissions.USE, }); - const [ephemeralAgent, setEphemeralAgent] = useRecoilState(ephemeralAgentByConvoId(key)); - const isCodeToggleEnabled = useMemo(() => { - return ephemeralAgent?.execute_code ?? false; - }, [ephemeralAgent?.execute_code]); - - const { data } = useVerifyAgentToolAuth( - { toolId: Tools.execute_code }, - { - retry: 1, - }, - ); - const authType = useMemo(() => data?.message ?? false, [data?.message]); - const isAuthenticated = useMemo(() => data?.authenticated ?? false, [data?.authenticated]); - const { methods, onSubmit, isDialogOpen, setIsDialogOpen, handleRevokeApiKey } = - useCodeApiKeyForm({}); - - const setValue = useCallback( - (isChecked: boolean) => { - setEphemeralAgent((prev) => ({ - ...prev, - execute_code: isChecked, - })); - }, - [setEphemeralAgent], - ); - - const [runCode, setRunCode] = useLocalStorage( - `${LocalStorageKeys.LAST_CODE_TOGGLE_}${key}`, - isCodeToggleEnabled, - setValue, - storageCondition, - ); - - const handleChange = useCallback( - (e: React.ChangeEvent, isChecked: boolean) => { - if (!isAuthenticated) { - setIsDialogOpen(true); - e.preventDefault(); - return; - } - setRunCode(isChecked); - }, - [setRunCode, setIsDialogOpen, isAuthenticated], - ); - - const debouncedChange = useMemo( - () => debounce(handleChange, 50, { leading: true }), - [handleChange], - ); if (!canRunCode) { return null; } return ( - <> + (runCode || isPinned) && ( } /> - - + ) ); } diff --git a/client/src/components/Chat/Input/FileSearch.tsx b/client/src/components/Chat/Input/FileSearch.tsx new file mode 100644 index 0000000000..a4952d1fd1 --- /dev/null +++ b/client/src/components/Chat/Input/FileSearch.tsx @@ -0,0 +1,28 @@ +import React, { memo } from 'react'; +import CheckboxButton from '~/components/ui/CheckboxButton'; +import { useBadgeRowContext } from '~/Providers'; +import { VectorIcon } from '~/components/svg'; +import { useLocalize } from '~/hooks'; + +function FileSearch() { + const localize = useLocalize(); + const { fileSearch } = useBadgeRowContext(); + const { toggleState: fileSearchEnabled, debouncedChange, isPinned } = fileSearch; + + return ( + <> + {(fileSearchEnabled || isPinned) && ( + } + /> + )} + + ); +} + +export default memo(FileSearch); diff --git a/client/src/components/Chat/Input/Files/AttachFileChat.tsx b/client/src/components/Chat/Input/Files/AttachFileChat.tsx index 11bca082fe..6bdecca22a 100644 --- a/client/src/components/Chat/Input/Files/AttachFileChat.tsx +++ b/client/src/components/Chat/Input/Files/AttachFileChat.tsx @@ -1,50 +1,44 @@ import { memo, useMemo } from 'react'; -import { useRecoilValue } from 'recoil'; import { Constants, supportsFiles, mergeFileConfig, isAgentsEndpoint, - isEphemeralAgent, - EndpointFileConfig, + isAssistantsEndpoint, fileConfig as defaultFileConfig, } from 'librechat-data-provider'; -import { useChatContext } from '~/Providers'; +import type { EndpointFileConfig } from 'librechat-data-provider'; import { useGetFileConfig } from '~/data-provider'; -import { ephemeralAgentByConvoId } from '~/store'; import AttachFileMenu from './AttachFileMenu'; +import { useChatContext } from '~/Providers'; import AttachFile from './AttachFile'; function AttachFileChat({ disableInputs }: { disableInputs: boolean }) { const { conversation } = useChatContext(); - - const { endpoint: _endpoint, endpointType } = conversation ?? { endpoint: null }; - - const key = conversation?.conversationId ?? Constants.NEW_CONVO; - const ephemeralAgent = useRecoilValue(ephemeralAgentByConvoId(key)); - const isAgents = useMemo( - () => isAgentsEndpoint(_endpoint) || isEphemeralAgent(_endpoint, ephemeralAgent), - [_endpoint, ephemeralAgent], - ); + const conversationId = conversation?.conversationId ?? Constants.NEW_CONVO; + const { endpoint, endpointType } = conversation ?? { endpoint: null }; + const isAgents = useMemo(() => isAgentsEndpoint(endpoint), [endpoint]); + const isAssistants = useMemo(() => isAssistantsEndpoint(endpoint), [endpoint]); const { data: fileConfig = defaultFileConfig } = useGetFileConfig({ select: (data) => mergeFileConfig(data), }); - const endpointFileConfig = fileConfig.endpoints[_endpoint ?? ''] as - | EndpointFileConfig - | undefined; - - const endpointSupportsFiles: boolean = supportsFiles[endpointType ?? _endpoint ?? ''] ?? false; + const endpointFileConfig = fileConfig.endpoints[endpoint ?? ''] as EndpointFileConfig | undefined; + const endpointSupportsFiles: boolean = supportsFiles[endpointType ?? endpoint ?? ''] ?? false; const isUploadDisabled = (disableInputs || endpointFileConfig?.disabled) ?? false; - if (isAgents) { - return ; - } - if (endpointSupportsFiles && !isUploadDisabled) { + if (isAssistants && endpointSupportsFiles && !isUploadDisabled) { return ; + } else if (isAgents || (endpointSupportsFiles && !isUploadDisabled)) { + return ( + + ); } - return null; } diff --git a/client/src/components/Chat/Input/Files/AttachFileMenu.tsx b/client/src/components/Chat/Input/Files/AttachFileMenu.tsx index 85df07f24f..9fe2988606 100644 --- a/client/src/components/Chat/Input/Files/AttachFileMenu.tsx +++ b/client/src/components/Chat/Input/Files/AttachFileMenu.tsx @@ -1,35 +1,38 @@ +import { useSetRecoilState } from 'recoil'; import * as Ariakit from '@ariakit/react'; import React, { useRef, useState, useMemo } from 'react'; import { FileSearch, ImageUpIcon, TerminalSquareIcon, FileType2Icon } from 'lucide-react'; import { EToolResources, EModelEndpoint, defaultAgentCapabilities } from 'librechat-data-provider'; +import type { EndpointFileConfig } from 'librechat-data-provider'; +import { useLocalize, useGetAgentsConfig, useFileHandling, useAgentCapabilities } from '~/hooks'; import { FileUpload, TooltipAnchor, DropdownPopup, AttachmentIcon } from '~/components'; -import { useGetEndpointsQuery } from '~/data-provider'; -import { useLocalize, useFileHandling } from '~/hooks'; +import { ephemeralAgentByConvoId } from '~/store'; import { cn } from '~/utils'; -interface AttachFileProps { +interface AttachFileMenuProps { + conversationId: string; disabled?: boolean | null; + endpointFileConfig?: EndpointFileConfig; } -const AttachFile = ({ disabled }: AttachFileProps) => { +const AttachFileMenu = ({ disabled, conversationId, endpointFileConfig }: AttachFileMenuProps) => { const localize = useLocalize(); const isUploadDisabled = disabled ?? false; const inputRef = useRef(null); const [isPopoverActive, setIsPopoverActive] = useState(false); + const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(conversationId)); const [toolResource, setToolResource] = useState(); - const { data: endpointsConfig } = useGetEndpointsQuery(); const { handleFileChange } = useFileHandling({ overrideEndpoint: EModelEndpoint.agents, + overrideEndpointFileConfig: endpointFileConfig, }); + const { agentsConfig } = useGetAgentsConfig(); /** TODO: Ephemeral Agent Capabilities * Allow defining agent capabilities on a per-endpoint basis * Use definition for agents endpoint for ephemeral agents * */ - const capabilities = useMemo( - () => endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [], - [endpointsConfig], - ); + const capabilities = useAgentCapabilities(agentsConfig?.capabilities ?? defaultAgentCapabilities); const handleUploadClick = (isImage?: boolean) => { if (!inputRef.current) { @@ -53,7 +56,7 @@ const AttachFile = ({ disabled }: AttachFileProps) => { }, ]; - if (capabilities.includes(EToolResources.ocr)) { + if (capabilities.ocrEnabled) { items.push({ label: localize('com_ui_upload_ocr_text'), onClick: () => { @@ -64,22 +67,27 @@ const AttachFile = ({ disabled }: AttachFileProps) => { }); } - if (capabilities.includes(EToolResources.file_search)) { + if (capabilities.fileSearchEnabled) { items.push({ label: localize('com_ui_upload_file_search'), onClick: () => { setToolResource(EToolResources.file_search); + /** File search is not automatically enabled to simulate legacy behavior */ handleUploadClick(); }, icon: , }); } - if (capabilities.includes(EToolResources.execute_code)) { + if (capabilities.codeEnabled) { items.push({ label: localize('com_ui_upload_code_files'), onClick: () => { setToolResource(EToolResources.execute_code); + setEphemeralAgent((prev) => ({ + ...prev, + [EToolResources.execute_code]: true, + })); handleUploadClick(); }, icon: , @@ -87,7 +95,7 @@ const AttachFile = ({ disabled }: AttachFileProps) => { } return items; - }, [capabilities, localize, setToolResource]); + }, [capabilities, localize, setToolResource, setEphemeralAgent]); const menuTrigger = ( { ); }; -export default React.memo(AttachFile); +export default React.memo(AttachFileMenu); diff --git a/client/src/components/Chat/Input/Files/DragDropModal.tsx b/client/src/components/Chat/Input/Files/DragDropModal.tsx index 784116dc65..5606b4d30c 100644 --- a/client/src/components/Chat/Input/Files/DragDropModal.tsx +++ b/client/src/components/Chat/Input/Files/DragDropModal.tsx @@ -7,7 +7,7 @@ import useLocalize from '~/hooks/useLocalize'; import { OGDialog } from '~/components/ui'; interface DragDropModalProps { - onOptionSelect: (option: string | undefined) => void; + onOptionSelect: (option: EToolResources | undefined) => void; files: File[]; isVisible: boolean; setShowModal: (showModal: boolean) => void; diff --git a/client/src/components/Chat/Input/MCPSelect.tsx b/client/src/components/Chat/Input/MCPSelect.tsx index ebe56c8024..0a03decd53 100644 --- a/client/src/components/Chat/Input/MCPSelect.tsx +++ b/client/src/components/Chat/Input/MCPSelect.tsx @@ -1,74 +1,27 @@ -import React, { memo, useRef, useMemo, useEffect, useCallback, useState } from 'react'; -import { useRecoilState } from 'recoil'; -import { Settings2 } from 'lucide-react'; +import React, { memo, useCallback, useState } from 'react'; +import { SettingsIcon } from 'lucide-react'; +import { Constants } from 'librechat-data-provider'; import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query'; -import { Constants, EModelEndpoint, LocalStorageKeys } from 'librechat-data-provider'; -import type { TPlugin, TPluginAuthConfig, TUpdateUserPlugins } from 'librechat-data-provider'; +import type { TUpdateUserPlugins, TPlugin } from 'librechat-data-provider'; import MCPConfigDialog, { type ConfigFieldDetail } from '~/components/ui/MCPConfigDialog'; -import { useAvailableToolsQuery } from '~/data-provider'; -import useLocalStorage from '~/hooks/useLocalStorageAlt'; +import { useToastContext, useBadgeRowContext } from '~/Providers'; import MultiSelect from '~/components/ui/MultiSelect'; -import { ephemeralAgentByConvoId } from '~/store'; -import { useToastContext } from '~/Providers'; -import MCPIcon from '~/components/ui/MCPIcon'; +import { MCPIcon } from '~/components/svg'; import { useLocalize } from '~/hooks'; -interface McpServerInfo { - name: string; - pluginKey: string; - authConfig?: TPluginAuthConfig[]; - authenticated?: boolean; -} - -// Helper function to extract mcp_serverName from a full pluginKey like action_mcp_serverName const getBaseMCPPluginKey = (fullPluginKey: string): string => { const parts = fullPluginKey.split(Constants.mcp_delimiter); return Constants.mcp_prefix + parts[parts.length - 1]; }; -const storageCondition = (value: unknown, rawCurrentValue?: string | null) => { - if (rawCurrentValue) { - try { - const currentValue = rawCurrentValue?.trim() ?? ''; - if (currentValue.length > 2) { - return true; - } - } catch (e) { - console.error(e); - } - } - return Array.isArray(value) && value.length > 0; -}; - -function MCPSelect({ conversationId }: { conversationId?: string | null }) { +function MCPSelect() { const localize = useLocalize(); const { showToast } = useToastContext(); - const key = conversationId ?? Constants.NEW_CONVO; - const hasSetFetched = useRef(null); - const [isConfigModalOpen, setIsConfigModalOpen] = useState(false); - const [selectedToolForConfig, setSelectedToolForConfig] = useState(null); + const { mcpSelect, startupConfig } = useBadgeRowContext(); + const { mcpValues, setMCPValues, mcpServerNames, mcpToolDetails, isPinned } = mcpSelect; - const { data: mcpToolDetails, isFetched } = useAvailableToolsQuery(EModelEndpoint.agents, { - select: (data: TPlugin[]) => { - const mcpToolsMap = new Map(); - data.forEach((tool) => { - const isMCP = tool.pluginKey.includes(Constants.mcp_delimiter); - if (isMCP && tool.chatMenu !== false) { - const parts = tool.pluginKey.split(Constants.mcp_delimiter); - const serverName = parts[parts.length - 1]; - if (!mcpToolsMap.has(serverName)) { - mcpToolsMap.set(serverName, { - name: serverName, - pluginKey: tool.pluginKey, - authConfig: tool.authConfig, - authenticated: tool.authenticated, - }); - } - } - }); - return Array.from(mcpToolsMap.values()); - }, - }); + const [isConfigModalOpen, setIsConfigModalOpen] = useState(false); + const [selectedToolForConfig, setSelectedToolForConfig] = useState(null); const updateUserPluginsMutation = useUpdateUserPluginsMutation({ onSuccess: () => { @@ -84,48 +37,6 @@ function MCPSelect({ conversationId }: { conversationId?: string | null }) { }, }); - const [ephemeralAgent, setEphemeralAgent] = useRecoilState(ephemeralAgentByConvoId(key)); - const mcpState = useMemo(() => { - return ephemeralAgent?.mcp ?? []; - }, [ephemeralAgent?.mcp]); - - const setSelectedValues = useCallback( - (values: string[] | null | undefined) => { - if (!values) { - return; - } - if (!Array.isArray(values)) { - return; - } - setEphemeralAgent((prev) => ({ - ...prev, - mcp: values, - })); - }, - [setEphemeralAgent], - ); - const [mcpValues, setMCPValues] = useLocalStorage( - `${LocalStorageKeys.LAST_MCP_}${key}`, - mcpState, - setSelectedValues, - storageCondition, - ); - - useEffect(() => { - if (hasSetFetched.current === key) { - return; - } - if (!isFetched) { - return; - } - hasSetFetched.current = key; - if ((mcpToolDetails?.length ?? 0) > 0) { - setMCPValues(mcpValues.filter((mcp) => mcpToolDetails?.some((tool) => tool.name === mcp))); - return; - } - setMCPValues([]); - }, [isFetched, setMCPValues, mcpToolDetails, key, mcpValues]); - const renderSelectedValues = useCallback( (values: string[], placeholder?: string) => { if (values.length === 0) { @@ -139,10 +50,6 @@ function MCPSelect({ conversationId }: { conversationId?: string | null }) { [localize], ); - const mcpServerNames = useMemo(() => { - return (mcpToolDetails ?? []).map((tool) => tool.name); - }, [mcpToolDetails]); - const handleConfigSave = useCallback( (targetName: string, authData: Record) => { if (selectedToolForConfig && selectedToolForConfig.name === targetName) { @@ -198,10 +105,10 @@ function MCPSelect({ conversationId }: { conversationId?: string | null }) { setSelectedToolForConfig(tool); setIsConfigModalOpen(true); }} - className="ml-2 flex h-6 w-6 items-center justify-center rounded p-1 hover:bg-black/10 dark:hover:bg-white/10" + className="ml-2 flex h-6 w-6 items-center justify-center rounded p-1 hover:bg-surface-secondary" aria-label={`Configure ${serverName}`} > - + ); @@ -212,10 +119,17 @@ function MCPSelect({ conversationId }: { conversationId?: string | null }) { [mcpToolDetails, setSelectedToolForConfig, setIsConfigModalOpen], ); + // Don't render if no servers are selected and not pinned + if ((!mcpValues || mcpValues.length === 0) && !isPinned) { + return null; + } + if (!mcpToolDetails || mcpToolDetails.length === 0) { return null; } + const placeholderText = + startupConfig?.interface?.mcpServers?.placeholder || localize('com_ui_mcp_servers'); return ( <> } diff --git a/client/src/components/Chat/Input/MCPSubMenu.tsx b/client/src/components/Chat/Input/MCPSubMenu.tsx new file mode 100644 index 0000000000..fd6bd7ad4a --- /dev/null +++ b/client/src/components/Chat/Input/MCPSubMenu.tsx @@ -0,0 +1,103 @@ +import React from 'react'; +import * as Ariakit from '@ariakit/react'; +import { ChevronRight } from 'lucide-react'; +import { PinIcon, MCPIcon } from '~/components/svg'; +import { useLocalize } from '~/hooks'; +import { cn } from '~/utils'; + +interface MCPSubMenuProps { + isMCPPinned: boolean; + setIsMCPPinned: (value: boolean) => void; + mcpValues?: string[]; + mcpServerNames: string[]; + handleMCPToggle: (serverName: string) => void; + placeholder?: string; +} + +const MCPSubMenu = ({ + mcpValues, + isMCPPinned, + mcpServerNames, + setIsMCPPinned, + handleMCPToggle, + placeholder, + ...props +}: MCPSubMenuProps) => { + const localize = useLocalize(); + + const menuStore = Ariakit.useMenuStore({ + focusLoop: true, + showTimeout: 100, + placement: 'right', + }); + + return ( + + ) => { + e.stopPropagation(); + menuStore.toggle(); + }} + className="flex w-full cursor-pointer items-center justify-between rounded-lg p-2 hover:bg-surface-hover" + /> + } + > +
+ + {placeholder || localize('com_ui_mcp_servers')} + +
+ +
+ + {mcpServerNames.map((serverName) => ( + { + event.preventDefault(); + handleMCPToggle(serverName); + }} + className={cn( + 'flex items-center gap-2 rounded-lg px-2 py-1.5 text-text-primary hover:cursor-pointer', + 'scroll-m-1 outline-none transition-colors', + 'hover:bg-black/[0.075] dark:hover:bg-white/10', + 'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10', + 'w-full min-w-0 text-sm', + )} + > + + {serverName} + + ))} + +
+ ); +}; + +export default React.memo(MCPSubMenu); diff --git a/client/src/components/Chat/Input/ToolDialogs.tsx b/client/src/components/Chat/Input/ToolDialogs.tsx new file mode 100644 index 0000000000..d9f2122fca --- /dev/null +++ b/client/src/components/Chat/Input/ToolDialogs.tsx @@ -0,0 +1,66 @@ +import React, { useMemo } from 'react'; +import { AuthType } from 'librechat-data-provider'; +import SearchApiKeyDialog from '~/components/SidePanel/Agents/Search/ApiKeyDialog'; +import CodeApiKeyDialog from '~/components/SidePanel/Agents/Code/ApiKeyDialog'; +import { useBadgeRowContext } from '~/Providers'; + +function ToolDialogs() { + const { webSearch, codeInterpreter, searchApiKeyForm, codeApiKeyForm } = useBadgeRowContext(); + const { authData: webSearchAuthData } = webSearch; + const { authData: codeAuthData } = codeInterpreter; + + const { + methods: searchMethods, + onSubmit: searchOnSubmit, + isDialogOpen: searchDialogOpen, + setIsDialogOpen: setSearchDialogOpen, + handleRevokeApiKey: searchHandleRevoke, + badgeTriggerRef: searchBadgeTriggerRef, + menuTriggerRef: searchMenuTriggerRef, + } = searchApiKeyForm; + + const { + methods: codeMethods, + onSubmit: codeOnSubmit, + isDialogOpen: codeDialogOpen, + setIsDialogOpen: setCodeDialogOpen, + handleRevokeApiKey: codeHandleRevoke, + badgeTriggerRef: codeBadgeTriggerRef, + menuTriggerRef: codeMenuTriggerRef, + } = codeApiKeyForm; + + const searchAuthTypes = useMemo( + () => webSearchAuthData?.authTypes ?? [], + [webSearchAuthData?.authTypes], + ); + const codeAuthType = useMemo(() => codeAuthData?.message ?? false, [codeAuthData?.message]); + + return ( + <> + + + + ); +} + +export default ToolDialogs; diff --git a/client/src/components/Chat/Input/ToolsDropdown.tsx b/client/src/components/Chat/Input/ToolsDropdown.tsx new file mode 100644 index 0000000000..859a7be745 --- /dev/null +++ b/client/src/components/Chat/Input/ToolsDropdown.tsx @@ -0,0 +1,354 @@ +import React, { useState, useMemo, useCallback } from 'react'; +import * as Ariakit from '@ariakit/react'; +import { Globe, Settings, Settings2, TerminalSquareIcon } from 'lucide-react'; +import type { MenuItemProps } from '~/common'; +import { + AuthType, + Permissions, + ArtifactModes, + PermissionTypes, + defaultAgentCapabilities, +} from 'librechat-data-provider'; +import { TooltipAnchor, DropdownPopup } from '~/components'; +import { useLocalize, useHasAccess, useAgentCapabilities } from '~/hooks'; +import ArtifactsSubMenu from '~/components/Chat/Input/ArtifactsSubMenu'; +import MCPSubMenu from '~/components/Chat/Input/MCPSubMenu'; +import { PinIcon, VectorIcon } from '~/components/svg'; +import { useBadgeRowContext } from '~/Providers'; +import { cn } from '~/utils'; + +interface ToolsDropdownProps { + disabled?: boolean; +} + +const ToolsDropdown = ({ disabled }: ToolsDropdownProps) => { + const localize = useLocalize(); + const isDisabled = disabled ?? false; + const [isPopoverActive, setIsPopoverActive] = useState(false); + const { + webSearch, + mcpSelect, + artifacts, + fileSearch, + agentsConfig, + startupConfig, + codeApiKeyForm, + codeInterpreter, + searchApiKeyForm, + } = useBadgeRowContext(); + const { codeEnabled, webSearchEnabled, artifactsEnabled, fileSearchEnabled } = + useAgentCapabilities(agentsConfig?.capabilities ?? defaultAgentCapabilities); + + const { setIsDialogOpen: setIsCodeDialogOpen, menuTriggerRef: codeMenuTriggerRef } = + codeApiKeyForm; + const { setIsDialogOpen: setIsSearchDialogOpen, menuTriggerRef: searchMenuTriggerRef } = + searchApiKeyForm; + const { + isPinned: isSearchPinned, + setIsPinned: setIsSearchPinned, + authData: webSearchAuthData, + } = webSearch; + const { + isPinned: isCodePinned, + setIsPinned: setIsCodePinned, + authData: codeAuthData, + } = codeInterpreter; + const { isPinned: isFileSearchPinned, setIsPinned: setIsFileSearchPinned } = fileSearch; + const { isPinned: isArtifactsPinned, setIsPinned: setIsArtifactsPinned } = artifacts; + const { + mcpValues, + mcpServerNames, + isPinned: isMCPPinned, + setIsPinned: setIsMCPPinned, + } = mcpSelect; + + const canUseWebSearch = useHasAccess({ + permissionType: PermissionTypes.WEB_SEARCH, + permission: Permissions.USE, + }); + + const canRunCode = useHasAccess({ + permissionType: PermissionTypes.RUN_CODE, + permission: Permissions.USE, + }); + + const showWebSearchSettings = useMemo(() => { + const authTypes = webSearchAuthData?.authTypes ?? []; + if (authTypes.length === 0) return true; + return !authTypes.every(([, authType]) => authType === AuthType.SYSTEM_DEFINED); + }, [webSearchAuthData?.authTypes]); + + const showCodeSettings = useMemo( + () => codeAuthData?.message !== AuthType.SYSTEM_DEFINED, + [codeAuthData?.message], + ); + + const handleWebSearchToggle = useCallback(() => { + const newValue = !webSearch.toggleState; + webSearch.debouncedChange({ value: newValue }); + }, [webSearch]); + + const handleCodeInterpreterToggle = useCallback(() => { + const newValue = !codeInterpreter.toggleState; + codeInterpreter.debouncedChange({ value: newValue }); + }, [codeInterpreter]); + + const handleFileSearchToggle = useCallback(() => { + const newValue = !fileSearch.toggleState; + fileSearch.debouncedChange({ value: newValue }); + }, [fileSearch]); + + const handleArtifactsToggle = useCallback(() => { + const currentState = artifacts.toggleState; + if (!currentState || currentState === '') { + artifacts.debouncedChange({ value: ArtifactModes.DEFAULT }); + } else { + artifacts.debouncedChange({ value: '' }); + } + }, [artifacts]); + + const handleShadcnToggle = useCallback(() => { + const currentState = artifacts.toggleState; + if (currentState === ArtifactModes.SHADCNUI) { + artifacts.debouncedChange({ value: ArtifactModes.DEFAULT }); + } else { + artifacts.debouncedChange({ value: ArtifactModes.SHADCNUI }); + } + }, [artifacts]); + + const handleCustomToggle = useCallback(() => { + const currentState = artifacts.toggleState; + if (currentState === ArtifactModes.CUSTOM) { + artifacts.debouncedChange({ value: ArtifactModes.DEFAULT }); + } else { + artifacts.debouncedChange({ value: ArtifactModes.CUSTOM }); + } + }, [artifacts]); + + const handleMCPToggle = useCallback( + (serverName: string) => { + const currentValues = mcpSelect.mcpValues ?? []; + const newValues = currentValues.includes(serverName) + ? currentValues.filter((v) => v !== serverName) + : [...currentValues, serverName]; + mcpSelect.setMCPValues(newValues); + }, + [mcpSelect], + ); + + const mcpPlaceholder = startupConfig?.interface?.mcpServers?.placeholder; + + const dropdownItems: MenuItemProps[] = []; + + if (fileSearchEnabled) { + dropdownItems.push({ + onClick: handleFileSearchToggle, + hideOnClick: false, + render: (props) => ( +
+
+ + {localize('com_assistants_file_search')} +
+ +
+ ), + }); + } + + if (canUseWebSearch && webSearchEnabled) { + dropdownItems.push({ + onClick: handleWebSearchToggle, + hideOnClick: false, + render: (props) => ( +
+
+ + {localize('com_ui_web_search')} +
+
+ {showWebSearchSettings && ( + + )} + +
+
+ ), + }); + } + + if (canRunCode && codeEnabled) { + dropdownItems.push({ + onClick: handleCodeInterpreterToggle, + hideOnClick: false, + render: (props) => ( +
+
+ + {localize('com_assistants_code_interpreter')} +
+
+ {showCodeSettings && ( + + )} + +
+
+ ), + }); + } + + if (artifactsEnabled) { + dropdownItems.push({ + hideOnClick: false, + render: (props) => ( + + ), + }); + } + + if (mcpServerNames && mcpServerNames.length > 0) { + dropdownItems.push({ + hideOnClick: false, + render: (props) => ( + + ), + }); + } + + const menuTrigger = ( + +
+ +
+ + } + id="tools-dropdown-button" + description={localize('com_ui_tools')} + disabled={isDisabled} + /> + ); + + return ( + + ); +}; + +export default React.memo(ToolsDropdown); diff --git a/client/src/components/Chat/Input/WebSearch.tsx b/client/src/components/Chat/Input/WebSearch.tsx index 6844ee1da0..f5139509fc 100644 --- a/client/src/components/Chat/Input/WebSearch.tsx +++ b/client/src/components/Chat/Input/WebSearch.tsx @@ -1,122 +1,37 @@ -import React, { memo, useRef, useMemo, useCallback } from 'react'; +import React, { memo } from 'react'; import { Globe } from 'lucide-react'; -import debounce from 'lodash/debounce'; -import { useRecoilState } from 'recoil'; -import { - Tools, - AuthType, - Constants, - Permissions, - PermissionTypes, - LocalStorageKeys, -} from 'librechat-data-provider'; -import ApiKeyDialog from '~/components/SidePanel/Agents/Search/ApiKeyDialog'; -import { useLocalize, useHasAccess, useSearchApiKeyForm } from '~/hooks'; +import { Permissions, PermissionTypes } from 'librechat-data-provider'; import CheckboxButton from '~/components/ui/CheckboxButton'; -import useLocalStorage from '~/hooks/useLocalStorageAlt'; -import { useVerifyAgentToolAuth } from '~/data-provider'; -import { ephemeralAgentByConvoId } from '~/store'; +import { useLocalize, useHasAccess } from '~/hooks'; +import { useBadgeRowContext } from '~/Providers'; -const storageCondition = (value: unknown, rawCurrentValue?: string | null) => { - if (rawCurrentValue) { - try { - const currentValue = rawCurrentValue?.trim() ?? ''; - if (currentValue === 'true' && value === false) { - return true; - } - } catch (e) { - console.error(e); - } - } - return value !== undefined && value !== null && value !== '' && value !== false; -}; - -function WebSearch({ conversationId }: { conversationId?: string | null }) { - const triggerRef = useRef(null); +function WebSearch() { const localize = useLocalize(); - const key = conversationId ?? Constants.NEW_CONVO; + const { webSearch: webSearchData, searchApiKeyForm } = useBadgeRowContext(); + const { toggleState: webSearch, debouncedChange, isPinned, authData } = webSearchData; + const { badgeTriggerRef } = searchApiKeyForm; const canUseWebSearch = useHasAccess({ permissionType: PermissionTypes.WEB_SEARCH, permission: Permissions.USE, }); - const [ephemeralAgent, setEphemeralAgent] = useRecoilState(ephemeralAgentByConvoId(key)); - const isWebSearchToggleEnabled = useMemo(() => { - return ephemeralAgent?.web_search ?? false; - }, [ephemeralAgent?.web_search]); - - const { data } = useVerifyAgentToolAuth( - { toolId: Tools.web_search }, - { - retry: 1, - }, - ); - const authTypes = useMemo(() => data?.authTypes ?? [], [data?.authTypes]); - const isAuthenticated = useMemo(() => data?.authenticated ?? false, [data?.authenticated]); - const { methods, onSubmit, isDialogOpen, setIsDialogOpen, handleRevokeApiKey } = - useSearchApiKeyForm({}); - - const setValue = useCallback( - (isChecked: boolean) => { - setEphemeralAgent((prev) => ({ - ...prev, - web_search: isChecked, - })); - }, - [setEphemeralAgent], - ); - - const [webSearch, setWebSearch] = useLocalStorage( - `${LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_}${key}`, - isWebSearchToggleEnabled, - setValue, - storageCondition, - ); - - const handleChange = useCallback( - (e: React.ChangeEvent, isChecked: boolean) => { - if (!isAuthenticated) { - setIsDialogOpen(true); - e.preventDefault(); - return; - } - setWebSearch(isChecked); - }, - [setWebSearch, setIsDialogOpen, isAuthenticated], - ); - - const debouncedChange = useMemo( - () => debounce(handleChange, 50, { leading: true }), - [handleChange], - ); if (!canUseWebSearch) { return null; } return ( - <> + (isPinned || (webSearch && authData?.authenticated)) && ( } /> - - + ) ); } diff --git a/client/src/components/Chat/Menus/Endpoints/utils.ts b/client/src/components/Chat/Menus/Endpoints/utils.ts index 87c0133cf5..5ed155c6a0 100644 --- a/client/src/components/Chat/Menus/Endpoints/utils.ts +++ b/client/src/components/Chat/Menus/Endpoints/utils.ts @@ -83,7 +83,7 @@ export function filterModels( let modelName = modelId; if (isAgentsEndpoint(endpoint.value) && agentsMap && agentsMap[modelId]) { - modelName = agentsMap[modelId].name || modelId; + modelName = agentsMap[modelId]?.name || modelId; } else if ( isAssistantsEndpoint(endpoint.value) && assistantsMap && diff --git a/client/src/components/Chat/Messages/Content/ContentParts.tsx b/client/src/components/Chat/Messages/Content/ContentParts.tsx index 0a1b4616a0..49f6be255a 100644 --- a/client/src/components/Chat/Messages/Content/ContentParts.tsx +++ b/client/src/components/Chat/Messages/Content/ContentParts.tsx @@ -81,14 +81,23 @@ const ContentParts = memo( return ( <> {content.map((part, idx) => { - if (part?.type !== ContentTypes.TEXT || typeof part.text !== 'string') { + if (!part) { + return null; + } + const isTextPart = + part?.type === ContentTypes.TEXT || + typeof (part as unknown as Agents.MessageContentText)?.text !== 'string'; + const isThinkPart = + part?.type === ContentTypes.THINK || + typeof (part as unknown as Agents.ReasoningDeltaUpdate)?.think !== 'string'; + if (!isTextPart && !isThinkPart) { return null; } return ( { + if (quality === 'high') { + return 'bg-green-100 text-green-800'; + } + if (quality === 'low') { + return 'bg-orange-100 text-orange-800'; + } + return 'bg-gray-100 text-gray-800'; +}; + export default function DialogImage({ isOpen, onOpenChange, src = '', downloadImage, args }) { const localize = useLocalize(); const [isPromptOpen, setIsPromptOpen] = useState(false); const [imageSize, setImageSize] = useState(null); - const getImageSize = async (url: string) => { + // Zoom and pan state + const [zoom, setZoom] = useState(1); + const [panX, setPanX] = useState(0); + const [panY, setPanY] = useState(0); + const [isDragging, setIsDragging] = useState(false); + const [dragStart, setDragStart] = useState({ x: 0, y: 0 }); + + const containerRef = useRef(null); + + const getImageSize = useCallback(async (url: string) => { try { const response = await fetch(url, { method: 'HEAD' }); const contentLength = response.headers.get('Content-Length'); @@ -25,7 +44,7 @@ export default function DialogImage({ isOpen, onOpenChange, src = '', downloadIm console.error('Error getting image size:', error); return null; } - }; + }, []); const formatFileSize = (bytes: number): string => { if (bytes === 0) return '0 Bytes'; @@ -37,11 +56,129 @@ export default function DialogImage({ isOpen, onOpenChange, src = '', downloadIm return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]; }; + const getImageMaxWidth = () => { + // On mobile (when panel overlays), use full width minus padding + // On desktop, account for the side panel width + if (isPromptOpen) { + return window.innerWidth >= 640 ? 'calc(100vw - 22rem)' : 'calc(100vw - 2rem)'; + } + return 'calc(100vw - 2rem)'; + }; + + const resetZoom = useCallback(() => { + setZoom(1); + setPanX(0); + setPanY(0); + }, []); + + const getCursor = () => { + if (zoom <= 1) return 'default'; + return isDragging ? 'grabbing' : 'grab'; + }; + + const handleDoubleClick = useCallback(() => { + if (zoom > 1) { + resetZoom(); + } else { + // Zoom in to 2x on double click when at normal zoom + setZoom(2); + } + }, [zoom, resetZoom]); + + const handleWheel = useCallback( + (e: React.WheelEvent) => { + e.preventDefault(); + if (!containerRef.current) return; + + const rect = containerRef.current.getBoundingClientRect(); + const mouseX = e.clientX - rect.left; + const mouseY = e.clientY - rect.top; + + // Calculate zoom factor + const zoomFactor = e.deltaY > 0 ? 0.9 : 1.1; + const newZoom = Math.min(Math.max(zoom * zoomFactor, 1), 5); + + if (newZoom === zoom) return; + + // If zooming back to 1, reset pan to center the image + if (newZoom === 1) { + setZoom(1); + setPanX(0); + setPanY(0); + return; + } + + // Calculate the zoom center relative to the current viewport + const containerCenterX = rect.width / 2; + const containerCenterY = rect.height / 2; + + // Calculate new pan position to zoom towards mouse cursor + const zoomRatio = newZoom / zoom; + const deltaX = (mouseX - containerCenterX - panX) * (zoomRatio - 1); + const deltaY = (mouseY - containerCenterY - panY) * (zoomRatio - 1); + + setZoom(newZoom); + setPanX(panX - deltaX); + setPanY(panY - deltaY); + }, + [zoom, panX, panY], + ); + + const handleMouseDown = useCallback( + (e: React.MouseEvent) => { + e.preventDefault(); + if (zoom <= 1) return; + setIsDragging(true); + setDragStart({ + x: e.clientX - panX, + y: e.clientY - panY, + }); + }, + [zoom, panX, panY], + ); + + const handleMouseMove = useCallback( + (e: React.MouseEvent) => { + if (!isDragging || zoom <= 1) return; + const newPanX = e.clientX - dragStart.x; + const newPanY = e.clientY - dragStart.y; + setPanX(newPanX); + setPanY(newPanY); + }, + [isDragging, dragStart, zoom], + ); + const handleMouseUp = useCallback(() => { + setIsDragging(false); + }, []); + + useEffect(() => { + const onKey = (e: KeyboardEvent) => e.key === 'Escape' && resetZoom(); + document.addEventListener('keydown', onKey); + return () => document.removeEventListener('keydown', onKey); + }, [resetZoom]); + useEffect(() => { if (isOpen && src) { getImageSize(src).then(setImageSize); + resetZoom(); } - }, [isOpen, src]); + }, [isOpen, src, getImageSize, resetZoom]); + + // Ensure image is centered when zoom changes to 1 + useEffect(() => { + if (zoom === 1) { + setPanX(0); + setPanY(0); + } + }, [zoom]); + + // Reset pan when panel opens/closes to maintain centering + useEffect(() => { + if (zoom === 1) { + setPanX(0); + setPanY(0); + } + }, [isPromptOpen, zoom]); return ( @@ -52,7 +189,7 @@ export default function DialogImage({ isOpen, onOpenChange, src = '', downloadIm overlayClassName="bg-surface-primary opacity-95 z-50" >
- + } /> -
+
+ {zoom > 1 && ( + + + + } + /> + )} {isPromptOpen ? ( - + ) : ( - + )} } @@ -100,36 +247,81 @@ export default function DialogImage({ isOpen, onOpenChange, src = '', downloadIm {/* Main content area with image */}
-
- Image 1 ? 'hidden' : 'visible', + minHeight: 0, // Allow flexbox to shrink + }} + > +
+ > + Image +
{/* Side Panel */}
-
-
+ {/* Mobile pull handle - removed for cleaner look */} + +
+ {/* Mobile close button */} +
+

+ {localize('com_ui_image_details')} +

+ +
+ +

{localize('com_ui_image_details')}

-
+
{/* Prompt Section */}

@@ -157,13 +349,7 @@ export default function DialogImage({ isOpen, onOpenChange, src = '', downloadIm
{localize('com_ui_quality')}: {args?.quality || 'Standard'} diff --git a/client/src/components/Chat/Messages/Content/Parts/EditTextPart.tsx b/client/src/components/Chat/Messages/Content/Parts/EditTextPart.tsx index 1ce207fe1c..ab15355d1e 100644 --- a/client/src/components/Chat/Messages/Content/Parts/EditTextPart.tsx +++ b/client/src/components/Chat/Messages/Content/Parts/EditTextPart.tsx @@ -1,8 +1,9 @@ +import { useRef, useEffect, useCallback, useMemo } from 'react'; import { useForm } from 'react-hook-form'; import { ContentTypes } from 'librechat-data-provider'; import { useRecoilState, useRecoilValue } from 'recoil'; -import { useRef, useEffect, useCallback, useMemo } from 'react'; import { useUpdateMessageContentMutation } from 'librechat-data-provider/react-query'; +import type { Agents } from 'librechat-data-provider'; import type { TEditProps } from '~/common'; import Container from '~/components/Chat/Messages/Content/Container'; import { useChatContext, useAddedChatContext } from '~/Providers'; @@ -12,18 +13,19 @@ import { useLocalize } from '~/hooks'; import store from '~/store'; const EditTextPart = ({ - text, + part, index, messageId, isSubmitting, enterEdit, -}: Omit & { +}: Omit & { index: number; messageId: string; + part: Agents.MessageContentText | Agents.ReasoningDeltaUpdate; }) => { const localize = useLocalize(); const { addedIndex } = useAddedChatContext(); - const { getMessages, setMessages, conversation } = useChatContext(); + const { ask, getMessages, setMessages, conversation } = useChatContext(); const [latestMultiMessage, setLatestMultiMessage] = useRecoilState( store.latestMessageFamily(addedIndex), ); @@ -34,15 +36,16 @@ const EditTextPart = ({ [getMessages, messageId], ); + const chatDirection = useRecoilValue(store.chatDirection); + const textAreaRef = useRef(null); const updateMessageContentMutation = useUpdateMessageContentMutation(conversationId ?? ''); - const chatDirection = useRecoilValue(store.chatDirection).toLowerCase(); - const isRTL = chatDirection === 'rtl'; + const isRTL = chatDirection?.toLowerCase() === 'rtl'; const { register, handleSubmit, setValue } = useForm({ defaultValues: { - text: text ?? '', + text: (ContentTypes.THINK in part ? part.think : part.text) || '', }, }); @@ -55,15 +58,7 @@ const EditTextPart = ({ } }, []); - /* - const resubmitMessage = () => { - showToast({ - status: 'warning', - message: localize('com_warning_resubmit_unsupported'), - }); - - // const resubmitMessage = (data: { text: string }) => { - // Not supported by AWS Bedrock + const resubmitMessage = (data: { text: string }) => { const messages = getMessages(); const parentMessage = messages?.find((msg) => msg.messageId === message?.parentMessageId); @@ -73,17 +68,19 @@ const EditTextPart = ({ ask( { ...parentMessage }, { - editedText: data.text, + editedContent: { + index, + text: data.text, + type: part.type, + }, editedMessageId: messageId, isRegenerate: true, isEdited: true, }, ); - setSiblingIdx((siblingIdx ?? 0) - 1); enterEdit(true); }; - */ const updateMessage = (data: { text: string }) => { const messages = getMessages(); @@ -167,13 +164,13 @@ const EditTextPart = ({ />
- {/* */} + -
- ); -} diff --git a/client/src/components/Nav/SettingsTabs/Beta/CodeArtifacts.tsx b/client/src/components/Nav/SettingsTabs/Beta/CodeArtifacts.tsx deleted file mode 100644 index dd985a86af..0000000000 --- a/client/src/components/Nav/SettingsTabs/Beta/CodeArtifacts.tsx +++ /dev/null @@ -1,95 +0,0 @@ -import { useRecoilState } from 'recoil'; -import HoverCardSettings from '~/components/Nav/SettingsTabs/HoverCardSettings'; -import { Switch } from '~/components/ui'; -import { useLocalize } from '~/hooks'; -import store from '~/store'; - -export default function CodeArtifacts() { - const [codeArtifacts, setCodeArtifacts] = useRecoilState(store.codeArtifacts); - const [includeShadcnui, setIncludeShadcnui] = useRecoilState(store.includeShadcnui); - const [customPromptMode, setCustomPromptMode] = useRecoilState(store.customPromptMode); - const localize = useLocalize(); - - const handleCodeArtifactsChange = (value: boolean) => { - setCodeArtifacts(value); - if (!value) { - setIncludeShadcnui(false); - setCustomPromptMode(false); - } - }; - - const handleIncludeShadcnuiChange = (value: boolean) => { - setIncludeShadcnui(value); - }; - - const handleCustomPromptModeChange = (value: boolean) => { - setCustomPromptMode(value); - if (value) { - setIncludeShadcnui(false); - } - }; - - return ( -
-

{localize('com_ui_artifacts')}

-
- - - -
-
- ); -} - -function SwitchItem({ - id, - label, - checked, - onCheckedChange, - hoverCardText, - disabled = false, -}: { - id: string; - label: string; - checked: boolean; - onCheckedChange: (value: boolean) => void; - hoverCardText: string; - disabled?: boolean; -}) { - return ( -
-
-
{label}
- -
- -
- ); -} diff --git a/client/src/components/Nav/SettingsTabs/General/General.tsx b/client/src/components/Nav/SettingsTabs/General/General.tsx index 58239ee37d..4b87e74765 100644 --- a/client/src/components/Nav/SettingsTabs/General/General.tsx +++ b/client/src/components/Nav/SettingsTabs/General/General.tsx @@ -86,6 +86,7 @@ export const LangSelector = ({ { value: 'fr-FR', label: localize('com_nav_lang_french') }, { value: 'he-HE', label: localize('com_nav_lang_hebrew') }, { value: 'hu-HU', label: localize('com_nav_lang_hungarian') }, + { value: 'hy-AM', label: localize('com_nav_lang_armenian') }, { value: 'it-IT', label: localize('com_nav_lang_italian') }, { value: 'pl-PL', label: localize('com_nav_lang_polish') }, { value: 'pt-BR', label: localize('com_nav_lang_brazilian_portuguese') }, @@ -96,9 +97,11 @@ export const LangSelector = ({ { value: 'cs-CZ', label: localize('com_nav_lang_czech') }, { value: 'sv-SE', label: localize('com_nav_lang_swedish') }, { value: 'ko-KR', label: localize('com_nav_lang_korean') }, + { value: 'lv-LV', label: localize('com_nav_lang_latvian') }, { value: 'vi-VN', label: localize('com_nav_lang_vietnamese') }, { value: 'th-TH', label: localize('com_nav_lang_thai') }, { value: 'tr-TR', label: localize('com_nav_lang_turkish') }, + { value: 'ug', label: localize('com_nav_lang_uyghur') }, { value: 'nl-NL', label: localize('com_nav_lang_dutch') }, { value: 'id-ID', label: localize('com_nav_lang_indonesia') }, { value: 'fi-FI', label: localize('com_nav_lang_finnish') }, diff --git a/client/src/components/Nav/SettingsTabs/index.ts b/client/src/components/Nav/SettingsTabs/index.ts index b3398431f5..9eab047c86 100644 --- a/client/src/components/Nav/SettingsTabs/index.ts +++ b/client/src/components/Nav/SettingsTabs/index.ts @@ -1,7 +1,6 @@ export { default as General } from './General/General'; export { default as Chat } from './Chat/Chat'; export { default as Data } from './Data/Data'; -export { default as Beta } from './Beta/Beta'; export { default as Commands } from './Commands/Commands'; export { RevokeKeysButton } from './Data/RevokeKeysButton'; export { default as Account } from './Account/Account'; diff --git a/client/src/components/SidePanel/Agents/Advanced/AdvancedPanel.tsx b/client/src/components/SidePanel/Agents/Advanced/AdvancedPanel.tsx index 0ead79cd32..a1a1c91ee6 100644 --- a/client/src/components/SidePanel/Agents/Advanced/AdvancedPanel.tsx +++ b/client/src/components/SidePanel/Agents/Advanced/AdvancedPanel.tsx @@ -2,20 +2,20 @@ import { useMemo } from 'react'; import { ChevronLeft } from 'lucide-react'; import { AgentCapabilities } from 'librechat-data-provider'; import { useFormContext, Controller } from 'react-hook-form'; -import type { AgentForm, AgentPanelProps } from '~/common'; +import type { AgentForm } from '~/common'; +import { useAgentPanelContext } from '~/Providers'; import MaxAgentSteps from './MaxAgentSteps'; -import AgentChain from './AgentChain'; import { useLocalize } from '~/hooks'; +import AgentChain from './AgentChain'; import { Panel } from '~/common'; -export default function AdvancedPanel({ - agentsConfig, - setActivePanel, -}: Pick) { +export default function AdvancedPanel() { const localize = useLocalize(); const methods = useFormContext(); const { control, watch } = methods; const currentAgentId = watch('id'); + + const { agentsConfig, setActivePanel } = useAgentPanelContext(); const chainEnabled = useMemo( () => agentsConfig?.capabilities.includes(AgentCapabilities.chain) ?? false, [agentsConfig], diff --git a/client/src/components/SidePanel/Agents/AgentConfig.tsx b/client/src/components/SidePanel/Agents/AgentConfig.tsx index 2afa56601c..c2b621e35b 100644 --- a/client/src/components/SidePanel/Agents/AgentConfig.tsx +++ b/client/src/components/SidePanel/Agents/AgentConfig.tsx @@ -1,9 +1,10 @@ import React, { useState, useMemo, useCallback } from 'react'; +import { EModelEndpoint } from 'librechat-data-provider'; import { Controller, useWatch, useFormContext } from 'react-hook-form'; -import { EModelEndpoint, AgentCapabilities } from 'librechat-data-provider'; import type { AgentForm, AgentPanelProps, IconComponentTypes } from '~/common'; import { cn, defaultTextProps, removeFocusOutlines, getEndpointField, getIconKey } from '~/utils'; import { useToastContext, useFileMapContext, useAgentPanelContext } from '~/Providers'; +import useAgentCapabilities from '~/hooks/Agents/useAgentCapabilities'; import Action from '~/components/SidePanel/Builder/Action'; import { ToolSelectDialog } from '~/components/Tools'; import { icons } from '~/hooks/Endpoint/Icons'; @@ -26,17 +27,20 @@ const inputClass = cn( removeFocusOutlines, ); -export default function AgentConfig({ - agentsConfig, - createMutation, - endpointsConfig, -}: Pick) { +export default function AgentConfig({ createMutation }: Pick) { const localize = useLocalize(); const fileMap = useFileMapContext(); const { showToast } = useToastContext(); const methods = useFormContext(); const [showToolDialog, setShowToolDialog] = useState(false); - const { actions, setAction, groupedTools: allTools, setActivePanel } = useAgentPanelContext(); + const { + actions, + setAction, + agentsConfig, + setActivePanel, + endpointsConfig, + groupedTools: allTools, + } = useAgentPanelContext(); const { control } = methods; const provider = useWatch({ control, name: 'provider' }); @@ -45,34 +49,15 @@ export default function AgentConfig({ const tools = useWatch({ control, name: 'tools' }); const agent_id = useWatch({ control, name: 'id' }); - const toolsEnabled = useMemo( - () => agentsConfig?.capabilities?.includes(AgentCapabilities.tools) ?? false, - [agentsConfig], - ); - const actionsEnabled = useMemo( - () => agentsConfig?.capabilities?.includes(AgentCapabilities.actions) ?? false, - [agentsConfig], - ); - const artifactsEnabled = useMemo( - () => agentsConfig?.capabilities?.includes(AgentCapabilities.artifacts) ?? false, - [agentsConfig], - ); - const ocrEnabled = useMemo( - () => agentsConfig?.capabilities?.includes(AgentCapabilities.ocr) ?? false, - [agentsConfig], - ); - const fileSearchEnabled = useMemo( - () => agentsConfig?.capabilities?.includes(AgentCapabilities.file_search) ?? false, - [agentsConfig], - ); - const webSearchEnabled = useMemo( - () => agentsConfig?.capabilities?.includes(AgentCapabilities.web_search) ?? false, - [agentsConfig], - ); - const codeEnabled = useMemo( - () => agentsConfig?.capabilities?.includes(AgentCapabilities.execute_code) ?? false, - [agentsConfig], - ); + const { + ocrEnabled, + codeEnabled, + toolsEnabled, + actionsEnabled, + artifactsEnabled, + webSearchEnabled, + fileSearchEnabled, + } = useAgentCapabilities(agentsConfig?.capabilities); const context_files = useMemo(() => { if (typeof agent === 'string') { @@ -168,7 +153,7 @@ export default function AgentConfig({ const visibleToolIds = new Set(selectedToolIds); // Check what group parent tools should be shown if any subtool is present - Object.entries(allTools).forEach(([toolId, toolObj]) => { + Object.entries(allTools ?? {}).forEach(([toolId, toolObj]) => { if (toolObj.tools?.length) { // if any subtool of this group is selected, ensure group parent tool rendered if (toolObj.tools.some((st) => selectedToolIds.includes(st.tool_id))) { @@ -299,6 +284,7 @@ export default function AgentConfig({
{/* // Render all visible IDs (including groups with subtools selected) */} {[...visibleToolIds].map((toolId, i) => { + if (!allTools) return null; const tool = allTools[toolId]; if (!tool) return null; return ( diff --git a/client/src/components/SidePanel/Agents/AgentPanel.tsx b/client/src/components/SidePanel/Agents/AgentPanel.tsx index 78874e41c5..1df54f829d 100644 --- a/client/src/components/SidePanel/Agents/AgentPanel.tsx +++ b/client/src/components/SidePanel/Agents/AgentPanel.tsx @@ -7,8 +7,6 @@ import { Constants, SystemRoles, EModelEndpoint, - TAgentsEndpoint, - TEndpointsConfig, isAssistantsEndpoint, } from 'librechat-data-provider'; import type { AgentForm, StringOption } from '~/common'; @@ -30,19 +28,15 @@ import { Button } from '~/components'; import ModelPanel from './ModelPanel'; import { Panel } from '~/common'; -export default function AgentPanel({ - agentsConfig, - endpointsConfig, -}: { - agentsConfig: TAgentsEndpoint | null; - endpointsConfig: TEndpointsConfig; -}) { +export default function AgentPanel() { const localize = useLocalize(); const { user } = useAuthContext(); const { showToast } = useToastContext(); const { activePanel, + agentsConfig, setActivePanel, + endpointsConfig, setCurrentAgentId, agent_id: current_agent_id, } = useAgentPanelContext(); @@ -323,14 +317,10 @@ export default function AgentPanel({ )} {canEditAgent && !agentQuery.isInitialLoading && activePanel === Panel.builder && ( - + )} {canEditAgent && !agentQuery.isInitialLoading && activePanel === Panel.advanced && ( - + )} {canEditAgent && !agentQuery.isInitialLoading && ( (() => { - const config = endpointsConfig?.[EModelEndpoint.agents] ?? null; - if (!config) return null; - - return { - ...(config as TConfig), - capabilities: Array.isArray(config.capabilities) - ? config.capabilities.map((cap) => cap as unknown as AgentCapabilities) - : ([] as AgentCapabilities[]), - } as TAgentsEndpoint; - }, [endpointsConfig]); - useEffect(() => { const agent_id = conversation?.agent_id ?? ''; if (agent_id) { @@ -57,5 +39,5 @@ function AgentPanelSwitchWithContext() { if (activePanel === Panel.mcp) { return ; } - return ; + return ; } diff --git a/client/src/components/SidePanel/Agents/AgentTool.tsx b/client/src/components/SidePanel/Agents/AgentTool.tsx index 4876f447fb..5703cede0a 100644 --- a/client/src/components/SidePanel/Agents/AgentTool.tsx +++ b/client/src/components/SidePanel/Agents/AgentTool.tsx @@ -19,7 +19,7 @@ export default function AgentTool({ allTools, }: { tool: string; - allTools: Record; + allTools?: Record; agent_id?: string; }) { const [isHovering, setIsHovering] = useState(false); @@ -30,8 +30,10 @@ export default function AgentTool({ const { showToast } = useToastContext(); const updateUserPlugins = useUpdateUserPluginsMutation(); const { getValues, setValue } = useFormContext(); + if (!allTools) { + return null; + } const currentTool = allTools[tool]; - const getSelectedTools = () => { if (!currentTool?.tools) return []; const formTools = getValues('tools') || []; @@ -224,7 +226,7 @@ export default function AgentTool({ }} className={cn( 'h-4 w-4 rounded border border-gray-300 transition-all duration-200 hover:border-gray-400 dark:border-gray-600 dark:hover:border-gray-500', - isExpanded ? 'opacity-100' : 'opacity-0', + isExpanded ? 'visible' : 'pointer-events-none invisible', )} onClick={(e) => e.stopPropagation()} onKeyDown={(e) => { diff --git a/client/src/components/SidePanel/Agents/Artifacts.tsx b/client/src/components/SidePanel/Agents/Artifacts.tsx index 2a814cc7f1..a8b0bba7c6 100644 --- a/client/src/components/SidePanel/Agents/Artifacts.tsx +++ b/client/src/components/SidePanel/Agents/Artifacts.tsx @@ -60,7 +60,7 @@ export default function Artifacts() { /> void; @@ -24,7 +25,8 @@ export default function ApiKeyDialog({ isToolAuthenticated: boolean; register: UseFormRegister; handleSubmit: UseFormHandleSubmit; - triggerRef?: RefObject; + triggerRef?: RefObject; + triggerRefs?: RefObject[]; }) { const localize = useLocalize(); const languageIcons = [ @@ -41,7 +43,12 @@ export default function ApiKeyDialog({ ]; return ( - + void; @@ -30,311 +34,188 @@ export default function ApiKeyDialog({ isToolAuthenticated: boolean; register: UseFormRegister; handleSubmit: UseFormHandleSubmit; - triggerRef?: React.RefObject; + triggerRef?: React.RefObject; + triggerRefs?: React.RefObject[]; }) { const localize = useLocalize(); const { data: config } = useGetStartupConfig(); - const [selectedReranker, setSelectedReranker] = useState< - RerankerTypes.JINA | RerankerTypes.COHERE - >( - config?.webSearch?.rerankerType === RerankerTypes.COHERE - ? RerankerTypes.COHERE - : RerankerTypes.JINA, + + const [selectedProvider, setSelectedProvider] = useState( + config?.webSearch?.searchProvider || SearchProviders.SERPER, ); + const [selectedReranker, setSelectedReranker] = useState( + config?.webSearch?.rerankerType || RerankerTypes.JINA, + ); + const [selectedScraper, setSelectedScraper] = useState(ScraperTypes.FIRECRAWL); - const [providerDropdownOpen, setProviderDropdownOpen] = useState(false); - const [scraperDropdownOpen, setScraperDropdownOpen] = useState(false); - const [rerankerDropdownOpen, setRerankerDropdownOpen] = useState(false); - - const providerItems: MenuItemProps[] = [ + const providerOptions: DropdownOption[] = [ { + key: SearchProviders.SERPER, label: localize('com_ui_web_search_provider_serper'), - onClick: () => {}, + inputs: { + serperApiKey: { + placeholder: localize('com_ui_enter_api_key'), + type: 'password' as const, + link: { + url: 'https://serper.dev/api-keys', + text: localize('com_ui_web_search_provider_serper_key'), + }, + }, + }, + }, + { + key: SearchProviders.SEARXNG, + label: localize('com_ui_web_search_provider_searxng'), + inputs: { + searxngInstanceUrl: { + placeholder: localize('com_ui_web_search_searxng_instance_url'), + type: 'text' as const, + }, + searxngApiKey: { + placeholder: localize('com_ui_web_search_searxng_api_key'), + type: 'password' as const, + }, + }, }, ]; - const scraperItems: MenuItemProps[] = [ - { - label: localize('com_ui_web_search_scraper_firecrawl'), - onClick: () => {}, - }, - ]; - - const rerankerItems: MenuItemProps[] = [ + const rerankerOptions: DropdownOption[] = [ { + key: RerankerTypes.JINA, label: localize('com_ui_web_search_reranker_jina'), - onClick: () => setSelectedReranker(RerankerTypes.JINA), + inputs: { + jinaApiKey: { + placeholder: localize('com_ui_web_search_jina_key'), + type: 'password' as const, + link: { + url: 'https://jina.ai/api-dashboard/', + text: localize('com_ui_web_search_reranker_jina_key'), + }, + }, + }, }, { + key: RerankerTypes.COHERE, label: localize('com_ui_web_search_reranker_cohere'), - onClick: () => setSelectedReranker(RerankerTypes.COHERE), + inputs: { + cohereApiKey: { + placeholder: localize('com_ui_web_search_cohere_key'), + type: 'password' as const, + link: { + url: 'https://dashboard.cohere.com/welcome/login', + text: localize('com_ui_web_search_reranker_cohere_key'), + }, + }, + }, }, ]; - const showProviderDropdown = !config?.webSearch?.searchProvider; - const showScraperDropdown = !config?.webSearch?.scraperType; - const showRerankerDropdown = !config?.webSearch?.rerankerType; + const scraperOptions: DropdownOption[] = [ + { + key: ScraperTypes.FIRECRAWL, + label: localize('com_ui_web_search_scraper_firecrawl'), + inputs: { + firecrawlApiUrl: { + placeholder: localize('com_ui_web_search_firecrawl_url'), + type: 'text' as const, + }, + firecrawlApiKey: { + placeholder: localize('com_ui_enter_api_key'), + type: 'password' as const, + link: { + url: 'https://docs.firecrawl.dev/introduction#api-key', + text: localize('com_ui_web_search_scraper_firecrawl_key'), + }, + }, + }, + }, + ]; + + const [dropdownOpen, setDropdownOpen] = useState({ + provider: false, + reranker: false, + scraper: false, + }); - // Determine which categories are SYSTEM_DEFINED const providerAuthType = authTypes.find(([cat]) => cat === SearchCategories.PROVIDERS)?.[1]; const scraperAuthType = authTypes.find(([cat]) => cat === SearchCategories.SCRAPERS)?.[1]; const rerankerAuthType = authTypes.find(([cat]) => cat === SearchCategories.RERANKERS)?.[1]; - function renderRerankerInput() { - if (config?.webSearch?.rerankerType === RerankerTypes.JINA) { - return ( - <> - (e.target.readOnly = false)} - {...register('jinaApiKey')} - /> - - - ); - } - if (config?.webSearch?.rerankerType === RerankerTypes.COHERE) { - return ( - <> - (e.target.readOnly = false)} - {...register('cohereApiKey')} - /> - - - ); - } - if (!config?.webSearch?.rerankerType && selectedReranker === RerankerTypes.JINA) { - return ( - <> - (e.target.readOnly = false)} - {...register('jinaApiKey')} - /> - - - ); - } - if (!config?.webSearch?.rerankerType && selectedReranker === RerankerTypes.COHERE) { - return ( - <> - (e.target.readOnly = false)} - {...register('cohereApiKey')} - /> - - - ); - } - return null; - } + const handleProviderChange = (key: string) => { + setSelectedProvider(key as SearchProviders); + }; + + const handleRerankerChange = (key: string) => { + setSelectedReranker(key as RerankerTypes); + }; + + const handleScraperChange = (key: string) => { + setSelectedScraper(key as ScraperTypes); + }; return ( - +
{localize('com_ui_web_search')}
-
- {localize('com_ui_web_search_api_subtitle')} -
- {/* Search Provider Section */} + {/* Provider Section */} {providerAuthType !== AuthType.SYSTEM_DEFINED && ( -
-
- - {showProviderDropdown ? ( - setProviderDropdownOpen(!providerDropdownOpen)} - className="flex items-center rounded-md border border-border-light px-3 py-1 text-sm text-text-secondary" - > - {localize('com_ui_web_search_provider_serper')} - - - } - /> - ) : ( -
- {localize('com_ui_web_search_provider_serper')} -
- )} -
- (e.target.readOnly = false)} - {...register('serperApiKey', { required: true })} - /> - -
+ + setDropdownOpen((prev) => ({ ...prev, provider: open })) + } + dropdownKey="provider" + /> )} {/* Scraper Section */} {scraperAuthType !== AuthType.SYSTEM_DEFINED && ( -
-
- - {showScraperDropdown ? ( - setScraperDropdownOpen(!scraperDropdownOpen)} - className="flex items-center rounded-md border border-border-light px-3 py-1 text-sm text-text-secondary" - > - {localize('com_ui_web_search_scraper_firecrawl')} - - - } - /> - ) : ( -
- {localize('com_ui_web_search_scraper_firecrawl')} -
- )} -
- (e.target.readOnly = false)} - className="mb-2" - {...register('firecrawlApiKey')} - /> - - -
+ + setDropdownOpen((prev) => ({ ...prev, scraper: open })) + } + dropdownKey="scraper" + /> )} {/* Reranker Section */} {rerankerAuthType !== AuthType.SYSTEM_DEFINED && ( -
-
- - {showRerankerDropdown && ( - setRerankerDropdownOpen(!rerankerDropdownOpen)} - className="flex items-center rounded-md border border-border-light px-3 py-1 text-sm text-text-secondary" - > - {selectedReranker === RerankerTypes.JINA - ? localize('com_ui_web_search_reranker_jina') - : localize('com_ui_web_search_reranker_cohere')} - - - } - /> - )} - {!showRerankerDropdown && ( -
- {config?.webSearch?.rerankerType === RerankerTypes.COHERE - ? localize('com_ui_web_search_reranker_cohere') - : localize('com_ui_web_search_reranker_jina')} -
- )} -
- {renderRerankerInput()} -
+ + setDropdownOpen((prev) => ({ ...prev, reranker: open })) + } + dropdownKey="reranker" + /> )} @@ -346,10 +227,7 @@ export default function ApiKeyDialog({ }} buttons={ isToolAuthenticated && ( - ) diff --git a/client/src/components/SidePanel/Agents/Search/InputSection.tsx b/client/src/components/SidePanel/Agents/Search/InputSection.tsx new file mode 100644 index 0000000000..e80e442603 --- /dev/null +++ b/client/src/components/SidePanel/Agents/Search/InputSection.tsx @@ -0,0 +1,144 @@ +import { useState } from 'react'; +import { ChevronDown, Eye, EyeOff } from 'lucide-react'; +import * as Menu from '@ariakit/react/menu'; +import type { UseFormRegister } from 'react-hook-form'; +import type { SearchApiKeyFormData } from '~/hooks/Plugins/useAuthSearchTool'; +import type { MenuItemProps } from '~/common'; +import { Input, Label } from '~/components/ui'; +import DropdownPopup from '~/components/ui/DropdownPopup'; +import { useLocalize } from '~/hooks'; + +interface InputConfig { + placeholder: string; + type?: 'text' | 'password'; + link?: { + url: string; + text: string; + }; +} + +interface DropdownOption { + key: string; + label: string; + inputs?: Record; +} + +interface InputSectionProps { + title: string; + selectedKey: string; + onSelectionChange: (key: string) => void; + dropdownOptions: DropdownOption[]; + showDropdown: boolean; + register: UseFormRegister; + dropdownOpen: boolean; + setDropdownOpen: (open: boolean) => void; + dropdownKey: string; +} + +export default function InputSection({ + title, + selectedKey, + onSelectionChange, + dropdownOptions, + showDropdown, + register, + dropdownOpen, + setDropdownOpen, + dropdownKey, +}: InputSectionProps) { + const localize = useLocalize(); + const [passwordVisibility, setPasswordVisibility] = useState>({}); + const selectedOption = dropdownOptions.find((opt) => opt.key === selectedKey); + const dropdownItems: MenuItemProps[] = dropdownOptions.map((option) => ({ + label: option.label, + onClick: () => onSelectionChange(option.key), + })); + + const togglePasswordVisibility = (fieldName: string) => { + setPasswordVisibility((prev) => ({ + ...prev, + [fieldName]: !prev[fieldName], + })); + }; + + return ( +
+
+ + {showDropdown ? ( + setDropdownOpen(!dropdownOpen)} + className="flex items-center rounded-md border border-border-light px-3 py-1 text-sm text-text-secondary" + > + {selectedOption?.label} + + + } + /> + ) : ( +
{selectedOption?.label}
+ )} +
+ {selectedOption?.inputs && + Object.entries(selectedOption.inputs).map(([name, config], index) => ( +
+
+ (e.target.readOnly = false) : undefined + } + className={`${index > 0 ? 'mb-2' : 'mb-2'} ${ + config.type === 'password' ? 'pr-10' : '' + }`} + {...register(name as keyof SearchApiKeyFormData)} + /> + {config.type === 'password' && ( + + )} +
+ {config.link && ( + + )} +
+ ))} +
+ ); +} + +export type { InputConfig, DropdownOption }; diff --git a/client/src/components/SidePanel/Builder/AssistantPanel.tsx b/client/src/components/SidePanel/Builder/AssistantPanel.tsx index c78d456ff1..4c3a794823 100644 --- a/client/src/components/SidePanel/Builder/AssistantPanel.tsx +++ b/client/src/components/SidePanel/Builder/AssistantPanel.tsx @@ -17,9 +17,9 @@ import { } from '~/data-provider'; import { cn, cardStyle, defaultTextProps, removeFocusOutlines } from '~/utils'; import AssistantConversationStarters from './AssistantConversationStarters'; +import AssistantToolsDialog from '~/components/Tools/AssistantToolsDialog'; import { useAssistantsMapContext, useToastContext } from '~/Providers'; import { useSelectAssistant, useLocalize } from '~/hooks'; -import { ToolSelectDialog } from '~/components/Tools'; import AppendDateCheckbox from './AppendDateCheckbox'; import CapabilitiesForm from './CapabilitiesForm'; import { SelectDropDown } from '~/components/ui'; @@ -468,11 +468,10 @@ export default function AssistantPanel({

- diff --git a/client/src/components/SidePanel/Memories/MemoryCreateDialog.tsx b/client/src/components/SidePanel/Memories/MemoryCreateDialog.tsx index 1670ba6f60..3f916872cb 100644 --- a/client/src/components/SidePanel/Memories/MemoryCreateDialog.tsx +++ b/client/src/components/SidePanel/Memories/MemoryCreateDialog.tsx @@ -52,6 +52,10 @@ export default function MemoryCreateDialog({ if (axiosError.response?.status === 409 || errorMessage.includes('already exists')) { errorMessage = localize('com_ui_memory_key_exists'); } + // Check for key validation error (lowercase and underscores only) + else if (errorMessage.includes('lowercase letters and underscores')) { + errorMessage = localize('com_ui_memory_key_validation'); + } } } else if (error.message) { errorMessage = error.message; diff --git a/client/src/components/SidePanel/Memories/MemoryEditDialog.tsx b/client/src/components/SidePanel/Memories/MemoryEditDialog.tsx index db6a0ab68e..6793bf3d6b 100644 --- a/client/src/components/SidePanel/Memories/MemoryEditDialog.tsx +++ b/client/src/components/SidePanel/Memories/MemoryEditDialog.tsx @@ -44,9 +44,29 @@ export default function MemoryEditDialog({ status: 'success', }); }, - onError: () => { + onError: (error: Error) => { + let errorMessage = localize('com_ui_error'); + + if (error && typeof error === 'object' && 'response' in error) { + const axiosError = error as any; + if (axiosError.response?.data?.error) { + errorMessage = axiosError.response.data.error; + + // Check for duplicate key error + if (axiosError.response?.status === 409 || errorMessage.includes('already exists')) { + errorMessage = localize('com_ui_memory_key_exists'); + } + // Check for key validation error (lowercase and underscores only) + else if (errorMessage.includes('lowercase letters and underscores')) { + errorMessage = localize('com_ui_memory_key_validation'); + } + } + } else if (error.message) { + errorMessage = error.message; + } + showToast({ - message: localize('com_ui_error'), + message: errorMessage, status: 'error', }); }, diff --git a/client/src/components/SidePanel/Nav.tsx b/client/src/components/SidePanel/Nav.tsx index d901d6b47a..fa6d8751b1 100644 --- a/client/src/components/SidePanel/Nav.tsx +++ b/client/src/components/SidePanel/Nav.tsx @@ -1,21 +1,15 @@ -import { useState } from 'react'; import * as AccordionPrimitive from '@radix-ui/react-accordion'; import type { NavLink, NavProps } from '~/common'; -import { Accordion, AccordionItem, AccordionContent } from '~/components/ui/Accordion'; -import { TooltipAnchor, Button } from '~/components'; +import { AccordionContent, AccordionItem, TooltipAnchor, Accordion, Button } from '~/components/ui'; +import { ActivePanelProvider, useActivePanel } from '~/Providers'; import { useLocalize } from '~/hooks'; import { cn } from '~/utils'; -export default function Nav({ links, isCollapsed, resize, defaultActive }: NavProps) { +function NavContent({ links, isCollapsed, resize }: Omit) { const localize = useLocalize(); - const [active, _setActive] = useState(defaultActive); + const { active, setActive } = useActivePanel(); const getVariant = (link: NavLink) => (link.id === active ? 'default' : 'ghost'); - const setActive = (id: string) => { - localStorage.setItem('side:active-panel', id + ''); - _setActive(id); - }; - return (
); } + +export default function Nav({ links, isCollapsed, resize, defaultActive }: NavProps) { + return ( + + + + ); +} diff --git a/client/src/components/SidePanel/Parameters/DynamicInput.tsx b/client/src/components/SidePanel/Parameters/DynamicInput.tsx index 71714d050e..57e55d75ca 100644 --- a/client/src/components/SidePanel/Parameters/DynamicInput.tsx +++ b/client/src/components/SidePanel/Parameters/DynamicInput.tsx @@ -46,6 +46,10 @@ function DynamicInput({ setInputValue(e, !isNaN(Number(e.target.value))); }; + const placeholderText = placeholderCode + ? localize(placeholder as TranslationKeys) || placeholder + : placeholder; + return (
{ if (isEnum && options) { - return options.reduce((acc, mapping, index) => { - acc[mapping] = index; - return acc; - }, {} as Record); + return options.reduce( + (acc, mapping, index) => { + acc[mapping] = index; + return acc; + }, + {} as Record, + ); } return {}; }, [isEnum, options]); const valueToEnumOption = useMemo(() => { if (isEnum && options) { - return options.reduce((acc, option, index) => { - acc[index] = option; - return acc; - }, {} as Record); + return options.reduce( + (acc, option, index) => { + acc[index] = option; + return acc; + }, + {} as Record, + ); } return {}; }, [isEnum, options]); + const getDisplayValue = useCallback( + (value: string | number | undefined | null): string => { + if (isEnum && enumMappings && value != null) { + const stringValue = String(value); + // Check if the value exists in enumMappings + if (stringValue in enumMappings) { + const mappedValue = String(enumMappings[stringValue]); + // Check if the mapped value is a localization key + if (mappedValue.startsWith('com_')) { + return localize(mappedValue as TranslationKeys) ?? mappedValue; + } + return mappedValue; + } + } + // Always return a string for Input component compatibility + if (value != null) { + return String(value); + } + return String(defaultValue ?? ''); + }, + [isEnum, enumMappings, defaultValue, localize], + ); + + const getDefaultDisplayValue = useCallback((): string => { + if (defaultValue != null && enumMappings) { + const stringDefault = String(defaultValue); + if (stringDefault in enumMappings) { + const mappedValue = String(enumMappings[stringDefault]); + // Check if the mapped value is a localization key + if (mappedValue.startsWith('com_')) { + return localize(mappedValue as TranslationKeys) ?? mappedValue; + } + return mappedValue; + } + } + return String(defaultValue ?? ''); + }, [defaultValue, enumMappings, localize]); + const handleValueChange = useCallback( (value: number) => { if (isEnum) { @@ -115,12 +160,12 @@ function DynamicSlider({
@@ -132,13 +177,13 @@ function DynamicSlider({ onChange={(value) => setInputValue(Number(value))} max={range ? range.max : (options?.length ?? 0) - 1} min={range ? range.min : 0} - step={range ? range.step ?? 1 : 1} + step={range ? (range.step ?? 1) : 1} controls={false} className={cn( defaultTextProps, cn( optionText, - 'reset-rc-number-input reset-rc-number-input-text-right h-auto w-12 border-0 group-hover/temp:border-gray-200', + 'reset-rc-number-input reset-rc-number-input-text-right h-auto w-12 border-0 py-1 text-xs group-hover/temp:border-gray-200', ), )} /> @@ -146,13 +191,13 @@ function DynamicSlider({ ({})} className={cn( defaultTextProps, cn( optionText, - 'reset-rc-number-input reset-rc-number-input-text-right h-auto w-12 border-0 group-hover/temp:border-gray-200', + 'reset-rc-number-input h-auto w-14 border-0 py-1 pl-1 text-center text-xs group-hover/temp:border-gray-200', ), )} /> @@ -164,19 +209,23 @@ function DynamicSlider({ value={[ isEnum ? enumToNumeric[(selectedValue as number) ?? ''] - : (inputValue as number) ?? (defaultValue as number), + : ((inputValue as number) ?? (defaultValue as number)), ]} onValueChange={(value) => handleValueChange(value[0])} onDoubleClick={() => setInputValue(defaultValue as string | number)} max={max} min={range ? range.min : 0} - step={range ? range.step ?? 1 : 1} + step={range ? (range.step ?? 1) : 1} className="flex h-4 w-full" /> {description && ( )} diff --git a/client/src/components/SidePanel/Parameters/DynamicSwitch.tsx b/client/src/components/SidePanel/Parameters/DynamicSwitch.tsx index a603ffe89d..eff11f3453 100644 --- a/client/src/components/SidePanel/Parameters/DynamicSwitch.tsx +++ b/client/src/components/SidePanel/Parameters/DynamicSwitch.tsx @@ -50,7 +50,7 @@ function DynamicSwitch({