diff --git a/.env.example b/.env.example index 142fe88202..e235b6cbb9 100644 --- a/.env.example +++ b/.env.example @@ -20,6 +20,11 @@ DOMAIN_CLIENT=http://localhost:3080 DOMAIN_SERVER=http://localhost:3080 NO_INDEX=true +# Use the address that is at most n number of hops away from the Express application. +# req.socket.remoteAddress is the first hop, and the rest are looked for in the X-Forwarded-For header from right to left. +# A value of 0 means that the first untrusted address would be req.socket.remoteAddress, i.e. there is no reverse proxy. +# Defaulted to 1. +TRUST_PROXY=1 #===============# # JSON Logging # @@ -83,7 +88,7 @@ PROXY= #============# ANTHROPIC_API_KEY=user_provided -# ANTHROPIC_MODELS=claude-3-5-haiku-20241022,claude-3-5-sonnet-20241022,claude-3-5-sonnet-latest,claude-3-5-sonnet-20240620,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k +# ANTHROPIC_MODELS=claude-3-7-sonnet-latest,claude-3-7-sonnet-20250219,claude-3-5-haiku-20241022,claude-3-5-sonnet-20241022,claude-3-5-sonnet-latest,claude-3-5-sonnet-20240620,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k # ANTHROPIC_REVERSE_PROXY= #============# @@ -170,7 +175,7 @@ GOOGLE_KEY=user_provided #============# OPENAI_API_KEY=user_provided -# OPENAI_MODELS=o1,o1-mini,o1-preview,gpt-4o,chatgpt-4o-latest,gpt-4o-mini,gpt-3.5-turbo-0125,gpt-3.5-turbo-0301,gpt-3.5-turbo,gpt-4,gpt-4-0613,gpt-4-vision-preview,gpt-3.5-turbo-0613,gpt-3.5-turbo-16k-0613,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-instruct-0914,gpt-3.5-turbo-16k +# 
OPENAI_MODELS=o1,o1-mini,o1-preview,gpt-4o,gpt-4.5-preview,chatgpt-4o-latest,gpt-4o-mini,gpt-3.5-turbo-0125,gpt-3.5-turbo-0301,gpt-3.5-turbo,gpt-4,gpt-4-0613,gpt-4-vision-preview,gpt-3.5-turbo-0613,gpt-3.5-turbo-16k-0613,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-instruct-0914,gpt-3.5-turbo-16k DEBUG_OPENAI=false @@ -204,12 +209,6 @@ ASSISTANTS_API_KEY=user_provided # More info, including how to enable use of Assistants with Azure here: # https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints/azure#using-assistants-with-azure -#============# -# OpenRouter # -#============# -# !!!Warning: Use the variable above instead of this one. Using this one will override the OpenAI endpoint -# OPENROUTER_API_KEY= - #============# # Plugins # #============# @@ -249,6 +248,13 @@ AZURE_AI_SEARCH_SEARCH_OPTION_SELECT= # DALLE3_AZURE_API_VERSION= # DALLE2_AZURE_API_VERSION= +# Flux +#----------------- +FLUX_API_BASE_URL=https://api.us1.bfl.ai +# FLUX_API_BASE_URL = 'https://api.bfl.ml'; + +# Get your API key at https://api.us1.bfl.ai/auth/profile +# FLUX_API_KEY= # Google #----------------- @@ -292,6 +298,10 @@ MEILI_NO_ANALYTICS=true MEILI_HOST=http://0.0.0.0:7700 MEILI_MASTER_KEY=DrhYf7zENyR6AlUCKmnz0eYASOQdl6zxH7s7MKFSfFCt +# Optional: Disable indexing, useful in a multi-node setup +# where only one instance should perform an index sync. 
+# MEILI_NO_SYNC=true + #==================================================# # Speech to Text & Text to Speech # #==================================================# @@ -495,6 +505,16 @@ HELP_AND_FAQ_URL=https://librechat.ai # Google tag manager id #ANALYTICS_GTM_ID=user provided google tag manager id +#===============# +# REDIS Options # +#===============# + +# REDIS_URI=10.10.10.10:6379 +# USE_REDIS=true + +# USE_REDIS_CLUSTER=true +# REDIS_CA=/path/to/ca.crt + #==================================================# # Others # #==================================================# @@ -502,9 +522,6 @@ HELP_AND_FAQ_URL=https://librechat.ai # NODE_ENV= -# REDIS_URI= -# USE_REDIS= - # E2E_USER_EMAIL= # E2E_USER_PASSWORD= diff --git a/.github/ISSUE_TEMPLATE/QUESTION.yml b/.github/ISSUE_TEMPLATE/QUESTION.yml deleted file mode 100644 index c66e6baa3b..0000000000 --- a/.github/ISSUE_TEMPLATE/QUESTION.yml +++ /dev/null @@ -1,50 +0,0 @@ -name: Question -description: Ask your question -title: "[Question]: " -labels: ["❓ question"] -body: - - type: markdown - attributes: - value: | - Thanks for taking the time to fill this! - - type: textarea - id: what-is-your-question - attributes: - label: What is your question? - description: Please give as many details as possible - placeholder: Please give as many details as possible - validations: - required: true - - type: textarea - id: more-details - attributes: - label: More Details - description: Please provide more details if needed. - placeholder: Please provide more details if needed. - validations: - required: true - - type: dropdown - id: browsers - attributes: - label: What is the main subject of your question? - multiple: true - options: - - Documentation - - Installation - - UI - - Endpoints - - User System/OAuth - - Other - - type: textarea - id: screenshots - attributes: - label: Screenshots - description: If applicable, add screenshots to help explain your problem. 
You can drag and drop, paste images directly here or link to them. - - type: checkboxes - id: terms - attributes: - label: Code of Conduct - description: By submitting this issue, you agree to follow our [Code of Conduct](https://github.com/danny-avila/LibreChat/blob/main/.github/CODE_OF_CONDUCT.md) - options: - - label: I agree to follow this project's Code of Conduct - required: true diff --git a/.github/configuration-release.json b/.github/configuration-release.json new file mode 100644 index 0000000000..68fe80ed8f --- /dev/null +++ b/.github/configuration-release.json @@ -0,0 +1,60 @@ +{ + "categories": [ + { + "title": "### ✨ New Features", + "labels": ["feat"] + }, + { + "title": "### 🌍 Internationalization", + "labels": ["i18n"] + }, + { + "title": "### 👐 Accessibility", + "labels": ["a11y"] + }, + { + "title": "### 🔧 Fixes", + "labels": ["Fix", "fix"] + }, + { + "title": "### ⚙️ Other Changes", + "labels": ["ci", "style", "docs", "refactor", "chore"] + } + ], + "ignore_labels": [ + "🔁 duplicate", + "📊 analytics", + "🌱 good first issue", + "🔍 investigation", + "🙏 help wanted", + "❌ invalid", + "❓ question", + "🚫 wontfix", + "🚀 release", + "version" + ], + "base_branches": ["main"], + "sort": { + "order": "ASC", + "on_property": "mergedAt" + }, + "label_extractor": [ + { + "pattern": "^(?:[^A-Za-z0-9]*)(feat|fix|chore|docs|refactor|ci|style|a11y|i18n)\\s*:", + "target": "$1", + "flags": "i", + "on_property": "title", + "method": "match" + }, + { + "pattern": "^(?:[^A-Za-z0-9]*)(v\\d+\\.\\d+\\.\\d+(?:-rc\\d+)?).*", + "target": "version", + "flags": "i", + "on_property": "title", + "method": "match" + } + ], + "template": "## [#{{TO_TAG}}] - #{{TO_TAG_DATE}}\n\nChanges from #{{FROM_TAG}} to #{{TO_TAG}}.\n\n#{{CHANGELOG}}\n\n[See full release details][release-#{{TO_TAG}}]\n\n[release-#{{TO_TAG}}]: https://github.com/#{{OWNER}}/#{{REPO}}/releases/tag/#{{TO_TAG}}\n\n---", + "pr_template": "- #{{TITLE}} by **@#{{AUTHOR}}** in [##{{NUMBER}}](#{{URL}})", + 
"empty_template": "- no changes" +} \ No newline at end of file diff --git a/.github/configuration-unreleased.json b/.github/configuration-unreleased.json new file mode 100644 index 0000000000..29eaf5e13b --- /dev/null +++ b/.github/configuration-unreleased.json @@ -0,0 +1,68 @@ +{ + "categories": [ + { + "title": "### ✨ New Features", + "labels": ["feat"] + }, + { + "title": "### 🌍 Internationalization", + "labels": ["i18n"] + }, + { + "title": "### 👐 Accessibility", + "labels": ["a11y"] + }, + { + "title": "### 🔧 Fixes", + "labels": ["Fix", "fix"] + }, + { + "title": "### ⚙️ Other Changes", + "labels": ["ci", "style", "docs", "refactor", "chore"] + } + ], + "ignore_labels": [ + "🔁 duplicate", + "📊 analytics", + "🌱 good first issue", + "🔍 investigation", + "🙏 help wanted", + "❌ invalid", + "❓ question", + "🚫 wontfix", + "🚀 release", + "version", + "action" + ], + "base_branches": ["main"], + "sort": { + "order": "ASC", + "on_property": "mergedAt" + }, + "label_extractor": [ + { + "pattern": "^(?:[^A-Za-z0-9]*)(feat|fix|chore|docs|refactor|ci|style|a11y|i18n)\\s*:", + "target": "$1", + "flags": "i", + "on_property": "title", + "method": "match" + }, + { + "pattern": "^(?:[^A-Za-z0-9]*)(v\\d+\\.\\d+\\.\\d+(?:-rc\\d+)?).*", + "target": "version", + "flags": "i", + "on_property": "title", + "method": "match" + }, + { + "pattern": "^(?:[^A-Za-z0-9]*)(action)\\b.*", + "target": "action", + "flags": "i", + "on_property": "title", + "method": "match" + } + ], + "template": "## [Unreleased]\n\n#{{CHANGELOG}}\n\n---", + "pr_template": "- #{{TITLE}} by **@#{{AUTHOR}}** in [##{{NUMBER}}](#{{URL}})", + "empty_template": "- no changes" +} \ No newline at end of file diff --git a/.github/workflows/backend-review.yml b/.github/workflows/backend-review.yml index 5bc3d3b2db..8469fc366d 100644 --- a/.github/workflows/backend-review.yml +++ b/.github/workflows/backend-review.yml @@ -61,4 +61,7 @@ jobs: run: cd api && npm run test:ci - name: Run librechat-data-provider unit tests - 
run: cd packages/data-provider && npm run test:ci \ No newline at end of file + run: cd packages/data-provider && npm run test:ci + + - name: Run librechat-mcp unit tests + run: cd packages/mcp && npm run test:ci \ No newline at end of file diff --git a/.github/workflows/generate-release-changelog-pr.yml b/.github/workflows/generate-release-changelog-pr.yml new file mode 100644 index 0000000000..c3bceae9de --- /dev/null +++ b/.github/workflows/generate-release-changelog-pr.yml @@ -0,0 +1,94 @@ +name: Generate Release Changelog PR + +on: + push: + tags: + - 'v*.*.*' + +jobs: + generate-release-changelog-pr: + permissions: + contents: write # Needed for pushing commits and creating branches. + pull-requests: write + runs-on: ubuntu-latest + steps: + # 1. Checkout the repository (with full history). + - name: Checkout Repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + # 2. Generate the release changelog using our custom configuration. + - name: Generate Release Changelog + id: generate_release + uses: mikepenz/release-changelog-builder-action@v5.1.0 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + configuration: ".github/configuration-release.json" + owner: ${{ github.repository_owner }} + repo: ${{ github.event.repository.name }} + outputFile: CHANGELOG-release.md + + # 3. Update the main CHANGELOG.md: + # - If it doesn't exist, create it with a basic header. + # - Remove the "Unreleased" section (if present). + # - Prepend the new release changelog above previous releases. + # - Remove all temporary files before committing. + - name: Update CHANGELOG.md + run: | + # Determine the release tag, e.g. "v1.2.3" + TAG=${GITHUB_REF##*/} + echo "Using release tag: $TAG" + + # Ensure CHANGELOG.md exists; if not, create a basic header. + if [ ! -f CHANGELOG.md ]; then + echo "# Changelog" > CHANGELOG.md + echo "" >> CHANGELOG.md + echo "All notable changes to this project will be documented in this file." 
>> CHANGELOG.md + echo "" >> CHANGELOG.md + fi + + echo "Updating CHANGELOG.md…" + + # Remove the "Unreleased" section (from "## [Unreleased]" until the first occurrence of '---') if it exists. + if grep -q "^## \[Unreleased\]" CHANGELOG.md; then + awk '/^## \[Unreleased\]/{flag=1} flag && /^---/{flag=0; next} !flag' CHANGELOG.md > CHANGELOG.cleaned + else + cp CHANGELOG.md CHANGELOG.cleaned + fi + + # Split the cleaned file into: + # - header.md: content before the first release header ("## [v..."). + # - tail.md: content from the first release header onward. + awk '/^## \[v/{exit} {print}' CHANGELOG.cleaned > header.md + awk 'f{print} /^## \[v/{f=1; print}' CHANGELOG.cleaned > tail.md + + # Combine header, the new release changelog, and the tail. + echo "Combining updated changelog parts..." + cat header.md CHANGELOG-release.md > CHANGELOG.md.new + echo "" >> CHANGELOG.md.new + cat tail.md >> CHANGELOG.md.new + + mv CHANGELOG.md.new CHANGELOG.md + + # Remove temporary files. + rm -f CHANGELOG.cleaned header.md tail.md CHANGELOG-release.md + + echo "Final CHANGELOG.md content:" + cat CHANGELOG.md + + # 4. Create (or update) the Pull Request with the updated CHANGELOG.md. + - name: Create Pull Request + uses: peter-evans/create-pull-request@v7 + with: + token: ${{ secrets.GITHUB_TOKEN }} + sign-commits: true + commit-message: "chore: update CHANGELOG for release ${GITHUB_REF##*/}" + base: main + branch: "changelog/${GITHUB_REF##*/}" + reviewers: danny-avila + title: "chore: update CHANGELOG for release ${GITHUB_REF##*/}" + body: | + **Description**: + - This PR updates the CHANGELOG.md by removing the "Unreleased" section and adding new release notes for release ${GITHUB_REF##*/} above previous releases. 
\ No newline at end of file diff --git a/.github/workflows/generate-unreleased-changelog-pr.yml b/.github/workflows/generate-unreleased-changelog-pr.yml new file mode 100644 index 0000000000..b130e4fb33 --- /dev/null +++ b/.github/workflows/generate-unreleased-changelog-pr.yml @@ -0,0 +1,106 @@ +name: Generate Unreleased Changelog PR + +on: + schedule: + - cron: "0 0 * * 1" # Runs every Monday at 00:00 UTC + +jobs: + generate-unreleased-changelog-pr: + permissions: + contents: write # Needed for pushing commits and creating branches. + pull-requests: write + runs-on: ubuntu-latest + steps: + # 1. Checkout the repository on main. + - name: Checkout Repository on Main + uses: actions/checkout@v4 + with: + ref: main + fetch-depth: 0 + + # 4. Get the latest version tag. + - name: Get Latest Tag + id: get_latest_tag + run: | + LATEST_TAG=$(git describe --tags $(git rev-list --tags --max-count=1) || echo "none") + echo "Latest tag: $LATEST_TAG" + echo "tag=$LATEST_TAG" >> $GITHUB_OUTPUT + + # 5. Generate the Unreleased changelog. + - name: Generate Unreleased Changelog + id: generate_unreleased + uses: mikepenz/release-changelog-builder-action@v5.1.0 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + configuration: ".github/configuration-unreleased.json" + owner: ${{ github.repository_owner }} + repo: ${{ github.event.repository.name }} + outputFile: CHANGELOG-unreleased.md + fromTag: ${{ steps.get_latest_tag.outputs.tag }} + toTag: main + + # 7. Update CHANGELOG.md with the new Unreleased section. + - name: Update CHANGELOG.md + id: update_changelog + run: | + # Create CHANGELOG.md if it doesn't exist. + if [ ! -f CHANGELOG.md ]; then + echo "# Changelog" > CHANGELOG.md + echo "" >> CHANGELOG.md + echo "All notable changes to this project will be documented in this file." >> CHANGELOG.md + echo "" >> CHANGELOG.md + fi + + echo "Updating CHANGELOG.md…" + + # Extract content before the "## [Unreleased]" (or first version header if missing). 
+ if grep -q "^## \[Unreleased\]" CHANGELOG.md; then + awk '/^## \[Unreleased\]/{exit} {print}' CHANGELOG.md > CHANGELOG_TMP.md + else + awk '/^## \[v/{exit} {print}' CHANGELOG.md > CHANGELOG_TMP.md + fi + + # Append the generated Unreleased changelog. + echo "" >> CHANGELOG_TMP.md + cat CHANGELOG-unreleased.md >> CHANGELOG_TMP.md + echo "" >> CHANGELOG_TMP.md + + # Append the remainder of the original changelog (starting from the first version header). + awk 'f{print} /^## \[v/{f=1; print}' CHANGELOG.md >> CHANGELOG_TMP.md + + # Replace the old file with the updated file. + mv CHANGELOG_TMP.md CHANGELOG.md + + # Remove the temporary generated file. + rm -f CHANGELOG-unreleased.md + + echo "Final CHANGELOG.md:" + cat CHANGELOG.md + + # 8. Check if CHANGELOG.md has any updates. + - name: Check for CHANGELOG.md changes + id: changelog_changes + run: | + if git diff --quiet CHANGELOG.md; then + echo "has_changes=false" >> $GITHUB_OUTPUT + else + echo "has_changes=true" >> $GITHUB_OUTPUT + fi + + # 9. Create (or update) the Pull Request only if there are changes. + - name: Create Pull Request + if: steps.changelog_changes.outputs.has_changes == 'true' + uses: peter-evans/create-pull-request@v7 + with: + token: ${{ secrets.GITHUB_TOKEN }} + base: main + branch: "changelog/unreleased-update" + sign-commits: true + commit-message: "action: update Unreleased changelog" + title: "action: update Unreleased changelog" + body: | + **Description**: + - This PR updates the Unreleased section in CHANGELOG.md. + - It compares the current main branch with the latest version tag (determined as ${{ steps.get_latest_tag.outputs.tag }}), + regenerates the Unreleased changelog, removes any old Unreleased block, and inserts the new content. 
\ No newline at end of file diff --git a/.github/workflows/unused-packages.yml b/.github/workflows/unused-packages.yml index 7a95f9c5be..442e70e52c 100644 --- a/.github/workflows/unused-packages.yml +++ b/.github/workflows/unused-packages.yml @@ -1,6 +1,12 @@ name: Detect Unused NPM Packages -on: [pull_request] +on: + pull_request: + paths: + - 'package.json' + - 'package-lock.json' + - 'client/**' + - 'api/**' jobs: detect-unused-packages: diff --git a/LICENSE b/LICENSE index 49a224977b..535850a920 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2024 LibreChat +Copyright (c) 2025 LibreChat Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 2e662ac262..f58b1999e5 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ - [Fork Messages & Conversations](https://www.librechat.ai/docs/features/fork) for Advanced Context control - 💬 **Multimodal & File Interactions**: - - Upload and analyze images with Claude 3, GPT-4o, o1, Llama-Vision, and Gemini 📸 + - Upload and analyze images with Claude 3, GPT-4.5, GPT-4o, o1, Llama-Vision, and Gemini 📸 - Chat with Files using Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, & Google 🗃️ - 🌎 **Multilingual UI**: diff --git a/api/app/clients/AnthropicClient.js b/api/app/clients/AnthropicClient.js index 522b6beb4f..19f4a3930a 100644 --- a/api/app/clients/AnthropicClient.js +++ b/api/app/clients/AnthropicClient.js @@ -7,7 +7,7 @@ const { getResponseSender, validateVisionModel, } = require('librechat-data-provider'); -const { encodeAndFormat } = require('~/server/services/Files/images/encode'); +const { SplitStreamHandler: _Handler, GraphEvents } = require('@librechat/agents'); const { truncateText, formatMessage, @@ -16,16 +16,31 @@ const { parseParamFromPrompt, createContextHandlers, } = require('./prompts'); +const { + getClaudeHeaders, + 
configureReasoning, + checkPromptCacheSupport, +} = require('~/server/services/Endpoints/anthropic/helpers'); const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); +const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const Tokenizer = require('~/server/services/Tokenizer'); +const { logger, sendEvent } = require('~/config'); const { sleep } = require('~/server/utils'); const BaseClient = require('./BaseClient'); -const { logger } = require('~/config'); const HUMAN_PROMPT = '\n\nHuman:'; const AI_PROMPT = '\n\nAssistant:'; +class SplitStreamHandler extends _Handler { + getDeltaContent(chunk) { + return (chunk?.delta?.text ?? chunk?.completion) || ''; + } + getReasoningDelta(chunk) { + return chunk?.delta?.thinking || ''; + } +} + /** Helper function to introduce a delay before retrying */ function delayBeforeRetry(attempts, baseDelay = 1000) { return new Promise((resolve) => setTimeout(resolve, baseDelay * attempts)); @@ -68,6 +83,8 @@ class AnthropicClient extends BaseClient { /** The key for the usage object's output tokens * @type {string} */ this.outputTokensKey = 'output_tokens'; + /** @type {SplitStreamHandler | undefined} */ + this.streamHandler; } setOptions(options) { @@ -97,9 +114,10 @@ class AnthropicClient extends BaseClient { const modelMatch = matchModelName(this.modelOptions.model, EModelEndpoint.anthropic); this.isClaude3 = modelMatch.includes('claude-3'); - this.isLegacyOutput = !modelMatch.includes('claude-3-5-sonnet'); - this.supportsCacheControl = - this.options.promptCache && this.checkPromptCacheSupport(modelMatch); + this.isLegacyOutput = !( + /claude-3[-.]5-sonnet/.test(modelMatch) || /claude-3[-.]7/.test(modelMatch) + ); + this.supportsCacheControl = this.options.promptCache && checkPromptCacheSupport(modelMatch); if ( this.isLegacyOutput && @@ -125,7 +143,7 @@ class AnthropicClient extends BaseClient 
{ this.options.endpointType ?? this.options.endpoint, this.options.endpointTokenConfig, ) ?? - 1500; + anthropicSettings.maxOutputTokens.reset(this.modelOptions.model); this.maxPromptTokens = this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens; @@ -171,18 +189,9 @@ class AnthropicClient extends BaseClient { options.baseURL = this.options.reverseProxyUrl; } - if ( - this.supportsCacheControl && - requestOptions?.model && - requestOptions.model.includes('claude-3-5-sonnet') - ) { - options.defaultHeaders = { - 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15,prompt-caching-2024-07-31', - }; - } else if (this.supportsCacheControl) { - options.defaultHeaders = { - 'anthropic-beta': 'prompt-caching-2024-07-31', - }; + const headers = getClaudeHeaders(requestOptions?.model, this.supportsCacheControl); + if (headers) { + options.defaultHeaders = headers; } return new Anthropic(options); @@ -668,29 +677,38 @@ class AnthropicClient extends BaseClient { * @returns {Promise} The response from the Anthropic client. */ async createResponse(client, options, useMessages) { - return useMessages ?? this.useMessages + return (useMessages ?? this.useMessages) ? await client.messages.create(options) : await client.completions.create(options); } + getMessageMapMethod() { + /** + * @param {TMessage} msg + */ + return (msg) => { + if (msg.text != null && msg.text && msg.text.startsWith(':::thinking')) { + msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim(); + } + + return msg; + }; + } + /** - * @param {string} modelName - * @returns {boolean} + * @param {string[]} [intermediateReply] + * @returns {string} */ - checkPromptCacheSupport(modelName) { - const modelMatch = matchModelName(modelName, EModelEndpoint.anthropic); - if (modelMatch.includes('claude-3-5-sonnet-latest')) { - return false; + getStreamText(intermediateReply) { + if (!this.streamHandler) { + return intermediateReply?.join('') ?? 
''; } - if ( - modelMatch === 'claude-3-5-sonnet' || - modelMatch === 'claude-3-5-haiku' || - modelMatch === 'claude-3-haiku' || - modelMatch === 'claude-3-opus' - ) { - return true; - } - return false; + + const reasoningText = this.streamHandler.reasoningTokens.join(''); + + const reasoningBlock = reasoningText.length > 0 ? `:::thinking\n${reasoningText}\n:::\n` : ''; + + return `${reasoningBlock}${this.streamHandler.tokens.join('')}`; } async sendCompletion(payload, { onProgress, abortController }) { @@ -710,7 +728,6 @@ class AnthropicClient extends BaseClient { user_id: this.user, }; - let text = ''; const { stream, model, @@ -721,22 +738,34 @@ class AnthropicClient extends BaseClient { topK: top_k, } = this.modelOptions; - const requestOptions = { + let requestOptions = { model, stream: stream || true, stop_sequences, temperature, metadata, - top_p, - top_k, }; if (this.useMessages) { requestOptions.messages = payload; - requestOptions.max_tokens = maxOutputTokens || legacy.maxOutputTokens.default; + requestOptions.max_tokens = + maxOutputTokens || anthropicSettings.maxOutputTokens.reset(requestOptions.model); } else { requestOptions.prompt = payload; - requestOptions.max_tokens_to_sample = maxOutputTokens || 1500; + requestOptions.max_tokens_to_sample = maxOutputTokens || legacy.maxOutputTokens.default; + } + + requestOptions = configureReasoning(requestOptions, { + thinking: this.options.thinking, + thinkingBudget: this.options.thinkingBudget, + }); + + if (!/claude-3[-.]7/.test(model)) { + requestOptions.top_p = top_p; + requestOptions.top_k = top_k; + } else if (requestOptions.thinking == null) { + requestOptions.topP = top_p; + requestOptions.topK = top_k; } if (this.systemMessage && this.supportsCacheControl === true) { @@ -756,13 +785,17 @@ class AnthropicClient extends BaseClient { } logger.debug('[AnthropicClient]', { ...requestOptions }); + this.streamHandler = new SplitStreamHandler({ + accumulate: true, + runId: this.responseMessageId, + handlers: 
{ + [GraphEvents.ON_RUN_STEP]: (event) => sendEvent(this.options.res, event), + [GraphEvents.ON_MESSAGE_DELTA]: (event) => sendEvent(this.options.res, event), + [GraphEvents.ON_REASONING_DELTA]: (event) => sendEvent(this.options.res, event), + }, + }); - const handleChunk = (currentChunk) => { - if (currentChunk) { - text += currentChunk; - onProgress(currentChunk); - } - }; + let intermediateReply = this.streamHandler.tokens; const maxRetries = 3; const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE; @@ -783,22 +816,15 @@ class AnthropicClient extends BaseClient { }); for await (const completion of response) { - // Handle each completion as before const type = completion?.type ?? ''; if (tokenEventTypes.has(type)) { logger.debug(`[AnthropicClient] ${type}`, completion); this[type] = completion; } - if (completion?.delta?.text) { - handleChunk(completion.delta.text); - } else if (completion.completion) { - handleChunk(completion.completion); - } - + this.streamHandler.handle(completion); await sleep(streamRate); } - // Successful processing, exit loop break; } catch (error) { attempts += 1; @@ -808,6 +834,10 @@ class AnthropicClient extends BaseClient { if (attempts < maxRetries) { await delayBeforeRetry(attempts, 350); + } else if (this.streamHandler && this.streamHandler.reasoningTokens.length) { + return this.getStreamText(); + } else if (intermediateReply.length > 0) { + return this.getStreamText(intermediateReply); } else { throw new Error(`Operation failed after ${maxRetries} attempts: ${error.message}`); } @@ -823,8 +853,7 @@ class AnthropicClient extends BaseClient { } await processResponse.bind(this)(); - - return text.trim(); + return this.getStreamText(intermediateReply); } getSaveOptions() { @@ -834,6 +863,8 @@ class AnthropicClient extends BaseClient { promptPrefix: this.options.promptPrefix, modelLabel: this.options.modelLabel, promptCache: this.options.promptCache, + thinking: this.options.thinking, + thinkingBudget: 
this.options.thinkingBudget, resendFiles: this.options.resendFiles, iconURL: this.options.iconURL, greeting: this.options.greeting, diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 880559cfca..727bce39b2 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -4,10 +4,11 @@ const { isAgentsEndpoint, isParamEndpoint, EModelEndpoint, + excludedKeys, ErrorTypes, Constants, } = require('librechat-data-provider'); -const { getMessages, saveMessage, updateMessage, saveConvo, getUserById } = require('~/models'); +const { getMessages, saveMessage, updateMessage, saveConvo, getConvo, getUserById } = require('~/models'); const { addSpaceIfNeeded, isEnabled } = require('~/server/utils'); const { truncateToolCallOutputs } = require('./prompts'); const checkBalance = require('~/models/checkBalance'); @@ -96,6 +97,10 @@ class BaseClient { * Flag to determine if the client re-submitted the latest assistant message. * @type {boolean | undefined} */ this.continued; + /** + * Flag to determine if the client has already fetched the conversation while saving new messages. + * @type {boolean | undefined} */ + this.fetchedConvo; /** @type {TMessage[]} */ this.currentMessages = []; /** @type {import('librechat-data-provider').VisionModes | undefined} */ @@ -950,16 +955,39 @@ class BaseClient { return { message: savedMessage }; } - const conversation = await saveConvo( - this.options.req, - { - conversationId: message.conversationId, - endpoint: this.options.endpoint, - endpointType: this.options.endpointType, - ...endpointOptions, - }, - { context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo' }, - ); + const fieldsToKeep = { + conversationId: message.conversationId, + endpoint: this.options.endpoint, + endpointType: this.options.endpointType, + ...endpointOptions, + }; + + const existingConvo = + this.fetchedConvo === true + ? 
null + : await getConvo(this.options.req?.user?.id, message.conversationId); + + const unsetFields = {}; + if (existingConvo != null) { + this.fetchedConvo = true; + for (const key in existingConvo) { + if (!key) { + continue; + } + if (excludedKeys.has(key)) { + continue; + } + + if (endpointOptions?.[key] === undefined) { + unsetFields[key] = 1; + } + } + } + + const conversation = await saveConvo(this.options.req, fieldsToKeep, { + context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo', + unsetFields, + }); return { message: savedMessage, conversation }; } diff --git a/api/app/clients/GoogleClient.js b/api/app/clients/GoogleClient.js index 03461a6796..58ee783d2a 100644 --- a/api/app/clients/GoogleClient.js +++ b/api/app/clients/GoogleClient.js @@ -51,7 +51,7 @@ class GoogleClient extends BaseClient { const serviceKey = creds[AuthKeys.GOOGLE_SERVICE_KEY] ?? {}; this.serviceKey = - serviceKey && typeof serviceKey === 'string' ? JSON.parse(serviceKey) : serviceKey ?? {}; + serviceKey && typeof serviceKey === 'string' ? JSON.parse(serviceKey) : (serviceKey ?? {}); /** @type {string | null | undefined} */ this.project_id = this.serviceKey.project_id; this.client_email = this.serviceKey.client_email; @@ -73,6 +73,8 @@ class GoogleClient extends BaseClient { * @type {string} */ this.outputTokensKey = 'output_tokens'; this.visionMode = VisionModes.generative; + /** @type {string} */ + this.systemMessage; if (options.skipSetOptions) { return; } @@ -184,7 +186,7 @@ class GoogleClient extends BaseClient { if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) { promptPrefix = `${promptPrefix ?? 
''}\n${this.options.artifactsPrompt}`.trim(); } - this.options.promptPrefix = promptPrefix; + this.systemMessage = promptPrefix; this.initializeClient(); return this; } @@ -314,7 +316,7 @@ class GoogleClient extends BaseClient { } this.augmentedPrompt = await this.contextHandlers.createContext(); - this.options.promptPrefix = this.augmentedPrompt + this.options.promptPrefix; + this.systemMessage = this.augmentedPrompt + this.systemMessage; } } @@ -361,8 +363,8 @@ class GoogleClient extends BaseClient { throw new Error('[GoogleClient] PaLM 2 and Codey models are no longer supported.'); } - if (this.options.promptPrefix) { - const instructionsTokenCount = this.getTokenCount(this.options.promptPrefix); + if (this.systemMessage) { + const instructionsTokenCount = this.getTokenCount(this.systemMessage); this.maxContextTokens = this.maxContextTokens - instructionsTokenCount; if (this.maxContextTokens < 0) { @@ -417,8 +419,8 @@ class GoogleClient extends BaseClient { ], }; - if (this.options.promptPrefix) { - payload.instances[0].context = this.options.promptPrefix; + if (this.systemMessage) { + payload.instances[0].context = this.systemMessage; } logger.debug('[GoogleClient] buildMessages', payload); @@ -464,7 +466,7 @@ class GoogleClient extends BaseClient { identityPrefix = `${identityPrefix}\nYou are ${this.options.modelLabel}`; } - let promptPrefix = (this.options.promptPrefix ?? '').trim(); + let promptPrefix = (this.systemMessage ?? 
'').trim(); if (identityPrefix) { promptPrefix = `${identityPrefix}${promptPrefix}`; @@ -639,7 +641,7 @@ class GoogleClient extends BaseClient { let error; try { if (!EXCLUDED_GENAI_MODELS.test(modelName) && !this.project_id) { - /** @type {GenAI} */ + /** @type {GenerativeModel} */ const client = this.client; /** @type {GenerateContentRequest} */ const requestOptions = { @@ -648,7 +650,7 @@ class GoogleClient extends BaseClient { generationConfig: googleGenConfigSchema.parse(this.modelOptions), }; - const promptPrefix = (this.options.promptPrefix ?? '').trim(); + const promptPrefix = (this.systemMessage ?? '').trim(); if (promptPrefix.length) { requestOptions.systemInstruction = { parts: [ @@ -663,7 +665,17 @@ class GoogleClient extends BaseClient { /** @type {GenAIUsageMetadata} */ let usageMetadata; - const result = await client.generateContentStream(requestOptions); + abortController.signal.addEventListener( + 'abort', + () => { + logger.warn('[GoogleClient] Request was aborted', abortController.signal.reason); + }, + { once: true }, + ); + + const result = await client.generateContentStream(requestOptions, { + signal: abortController.signal, + }); for await (const chunk of result.stream) { usageMetadata = !usageMetadata ? chunk?.usageMetadata @@ -815,7 +827,8 @@ class GoogleClient extends BaseClient { let reply = ''; const { abortController } = options; - const model = this.modelOptions.modelName ?? this.modelOptions.model ?? ''; + const model = + this.options.titleModel ?? this.modelOptions.modelName ?? this.modelOptions.model ?? 
''; const safetySettings = getSafetySettings(model); if (!EXCLUDED_GENAI_MODELS.test(model) && !this.project_id) { logger.debug('Identified titling model as GenAI version'); diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 7bd7879dcf..79e65b6d0d 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -109,15 +109,15 @@ class OpenAIClient extends BaseClient { const omniPattern = /\b(o1|o3)\b/i; this.isOmni = omniPattern.test(this.modelOptions.model); - const { OPENROUTER_API_KEY, OPENAI_FORCE_PROMPT } = process.env ?? {}; - if (OPENROUTER_API_KEY && !this.azure) { - this.apiKey = OPENROUTER_API_KEY; - this.useOpenRouter = true; - } - + const { OPENAI_FORCE_PROMPT } = process.env ?? {}; const { reverseProxyUrl: reverseProxy } = this.options; - if (!this.useOpenRouter && reverseProxy && reverseProxy.includes(KnownEndpoints.openrouter)) { + if ( + !this.useOpenRouter && + ((reverseProxy && reverseProxy.includes(KnownEndpoints.openrouter)) || + (this.options.endpoint && + this.options.endpoint.toLowerCase().includes(KnownEndpoints.openrouter))) + ) { this.useOpenRouter = true; } @@ -303,7 +303,9 @@ class OpenAIClient extends BaseClient { } getEncoding() { - return this.model?.includes('gpt-4o') ? 'o200k_base' : 'cl100k_base'; + return this.modelOptions?.model && /gpt-4[^-\s]/.test(this.modelOptions.model) + ? 'o200k_base' + : 'cl100k_base'; } /** @@ -610,7 +612,7 @@ class OpenAIClient extends BaseClient { } initializeLLM({ - model = 'gpt-4o-mini', + model = openAISettings.model.default, modelName, temperature = 0.2, max_tokens, @@ -711,7 +713,7 @@ class OpenAIClient extends BaseClient { const { OPENAI_TITLE_MODEL } = process.env ?? {}; - let model = this.options.titleModel ?? OPENAI_TITLE_MODEL ?? 'gpt-4o-mini'; + let model = this.options.titleModel ?? OPENAI_TITLE_MODEL ?? 
openAISettings.model.default; if (model === Constants.CURRENT_MODEL) { model = this.modelOptions.model; } @@ -904,7 +906,7 @@ ${convo} let prompt; // TODO: remove the gpt fallback and make it specific to endpoint - const { OPENAI_SUMMARY_MODEL = 'gpt-4o-mini' } = process.env ?? {}; + const { OPENAI_SUMMARY_MODEL = openAISettings.model.default } = process.env ?? {}; let model = this.options.summaryModel ?? OPENAI_SUMMARY_MODEL; if (model === Constants.CURRENT_MODEL) { model = this.modelOptions.model; @@ -1314,6 +1316,12 @@ ${convo} modelOptions.include_reasoning = true; reasoningKey = 'reasoning'; } + if (this.useOpenRouter && modelOptions.reasoning_effort != null) { + modelOptions.reasoning = { + effort: modelOptions.reasoning_effort, + }; + delete modelOptions.reasoning_effort; + } this.streamHandler = new SplitStreamHandler({ reasoningKey, diff --git a/api/app/clients/prompts/addCacheControl.js b/api/app/clients/prompts/addCacheControl.js index eed5910dc9..6bfd901a65 100644 --- a/api/app/clients/prompts/addCacheControl.js +++ b/api/app/clients/prompts/addCacheControl.js @@ -1,7 +1,7 @@ /** * Anthropic API: Adds cache control to the appropriate user messages in the payload. - * @param {Array} messages - The array of message objects. - * @returns {Array} - The updated array of message objects with cache control added. + * @param {Array} messages - The array of message objects. + * @returns {Array} - The updated array of message objects with cache control added. 
*/ function addCacheControl(messages) { if (!Array.isArray(messages) || messages.length < 2) { @@ -13,7 +13,9 @@ function addCacheControl(messages) { for (let i = updatedMessages.length - 1; i >= 0 && userMessagesModified < 2; i--) { const message = updatedMessages[i]; - if (message.role !== 'user') { + if (message.getType != null && message.getType() !== 'human') { + continue; + } else if (message.getType == null && message.role !== 'user') { continue; } diff --git a/api/app/clients/specs/AnthropicClient.test.js b/api/app/clients/specs/AnthropicClient.test.js index eef6bb6748..223f3038c0 100644 --- a/api/app/clients/specs/AnthropicClient.test.js +++ b/api/app/clients/specs/AnthropicClient.test.js @@ -1,3 +1,4 @@ +const { SplitStreamHandler } = require('@librechat/agents'); const { anthropicSettings } = require('librechat-data-provider'); const AnthropicClient = require('~/app/clients/AnthropicClient'); @@ -405,4 +406,327 @@ describe('AnthropicClient', () => { expect(Number.isNaN(result)).toBe(false); }); }); + + describe('maxOutputTokens handling for different models', () => { + it('should not cap maxOutputTokens for Claude 3.5 Sonnet models', () => { + const client = new AnthropicClient('test-api-key'); + const highTokenValue = anthropicSettings.legacy.maxOutputTokens.default * 10; + + client.setOptions({ + modelOptions: { + model: 'claude-3-5-sonnet', + maxOutputTokens: highTokenValue, + }, + }); + + expect(client.modelOptions.maxOutputTokens).toBe(highTokenValue); + + // Test with decimal notation + client.setOptions({ + modelOptions: { + model: 'claude-3.5-sonnet', + maxOutputTokens: highTokenValue, + }, + }); + + expect(client.modelOptions.maxOutputTokens).toBe(highTokenValue); + }); + + it('should not cap maxOutputTokens for Claude 3.7 models', () => { + const client = new AnthropicClient('test-api-key'); + const highTokenValue = anthropicSettings.legacy.maxOutputTokens.default * 2; + + client.setOptions({ + modelOptions: { + model: 'claude-3-7-sonnet', + 
maxOutputTokens: highTokenValue, + }, + }); + + expect(client.modelOptions.maxOutputTokens).toBe(highTokenValue); + + // Test with decimal notation + client.setOptions({ + modelOptions: { + model: 'claude-3.7-sonnet', + maxOutputTokens: highTokenValue, + }, + }); + + expect(client.modelOptions.maxOutputTokens).toBe(highTokenValue); + }); + + it('should cap maxOutputTokens for Claude 3.5 Haiku models', () => { + const client = new AnthropicClient('test-api-key'); + const highTokenValue = anthropicSettings.legacy.maxOutputTokens.default * 2; + + client.setOptions({ + modelOptions: { + model: 'claude-3-5-haiku', + maxOutputTokens: highTokenValue, + }, + }); + + expect(client.modelOptions.maxOutputTokens).toBe( + anthropicSettings.legacy.maxOutputTokens.default, + ); + + // Test with decimal notation + client.setOptions({ + modelOptions: { + model: 'claude-3.5-haiku', + maxOutputTokens: highTokenValue, + }, + }); + + expect(client.modelOptions.maxOutputTokens).toBe( + anthropicSettings.legacy.maxOutputTokens.default, + ); + }); + + it('should cap maxOutputTokens for Claude 3 Haiku and Opus models', () => { + const client = new AnthropicClient('test-api-key'); + const highTokenValue = anthropicSettings.legacy.maxOutputTokens.default * 2; + + // Test haiku + client.setOptions({ + modelOptions: { + model: 'claude-3-haiku', + maxOutputTokens: highTokenValue, + }, + }); + + expect(client.modelOptions.maxOutputTokens).toBe( + anthropicSettings.legacy.maxOutputTokens.default, + ); + + // Test opus + client.setOptions({ + modelOptions: { + model: 'claude-3-opus', + maxOutputTokens: highTokenValue, + }, + }); + + expect(client.modelOptions.maxOutputTokens).toBe( + anthropicSettings.legacy.maxOutputTokens.default, + ); + }); + }); + + describe('topK/topP parameters for different models', () => { + beforeEach(() => { + // Mock the SplitStreamHandler + jest.spyOn(SplitStreamHandler.prototype, 'handle').mockImplementation(() => {}); + }); + + afterEach(() => { + 
jest.restoreAllMocks(); + }); + + it('should include top_k and top_p parameters for non-claude-3.7 models', async () => { + const client = new AnthropicClient('test-api-key'); + + // Create a mock async generator function + async function* mockAsyncGenerator() { + yield { type: 'message_start', message: { usage: {} } }; + yield { delta: { text: 'Test response' } }; + yield { type: 'message_delta', usage: {} }; + } + + // Mock createResponse to return the async generator + jest.spyOn(client, 'createResponse').mockImplementation(() => { + return mockAsyncGenerator(); + }); + + client.setOptions({ + modelOptions: { + model: 'claude-3-opus', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + }); + + // Mock getClient to capture the request options + let capturedOptions = null; + jest.spyOn(client, 'getClient').mockImplementation((options) => { + capturedOptions = options; + return {}; + }); + + const payload = [{ role: 'user', content: 'Test message' }]; + await client.sendCompletion(payload, {}); + + // Check the options passed to getClient + expect(capturedOptions).toHaveProperty('top_k', 10); + expect(capturedOptions).toHaveProperty('top_p', 0.9); + }); + + it('should include top_k and top_p parameters for claude-3-5-sonnet models', async () => { + const client = new AnthropicClient('test-api-key'); + + // Create a mock async generator function + async function* mockAsyncGenerator() { + yield { type: 'message_start', message: { usage: {} } }; + yield { delta: { text: 'Test response' } }; + yield { type: 'message_delta', usage: {} }; + } + + // Mock createResponse to return the async generator + jest.spyOn(client, 'createResponse').mockImplementation(() => { + return mockAsyncGenerator(); + }); + + client.setOptions({ + modelOptions: { + model: 'claude-3-5-sonnet', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + }); + + // Mock getClient to capture the request options + let capturedOptions = null; + jest.spyOn(client, 'getClient').mockImplementation((options) => 
{ + capturedOptions = options; + return {}; + }); + + const payload = [{ role: 'user', content: 'Test message' }]; + await client.sendCompletion(payload, {}); + + // Check the options passed to getClient + expect(capturedOptions).toHaveProperty('top_k', 10); + expect(capturedOptions).toHaveProperty('top_p', 0.9); + }); + + it('should not include top_k and top_p parameters for claude-3-7-sonnet models', async () => { + const client = new AnthropicClient('test-api-key'); + + // Create a mock async generator function + async function* mockAsyncGenerator() { + yield { type: 'message_start', message: { usage: {} } }; + yield { delta: { text: 'Test response' } }; + yield { type: 'message_delta', usage: {} }; + } + + // Mock createResponse to return the async generator + jest.spyOn(client, 'createResponse').mockImplementation(() => { + return mockAsyncGenerator(); + }); + + client.setOptions({ + modelOptions: { + model: 'claude-3-7-sonnet', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + }); + + // Mock getClient to capture the request options + let capturedOptions = null; + jest.spyOn(client, 'getClient').mockImplementation((options) => { + capturedOptions = options; + return {}; + }); + + const payload = [{ role: 'user', content: 'Test message' }]; + await client.sendCompletion(payload, {}); + + // Check the options passed to getClient + expect(capturedOptions).not.toHaveProperty('top_k'); + expect(capturedOptions).not.toHaveProperty('top_p'); + }); + + it('should not include top_k and top_p parameters for models with decimal notation (claude-3.7)', async () => { + const client = new AnthropicClient('test-api-key'); + + // Create a mock async generator function + async function* mockAsyncGenerator() { + yield { type: 'message_start', message: { usage: {} } }; + yield { delta: { text: 'Test response' } }; + yield { type: 'message_delta', usage: {} }; + } + + // Mock createResponse to return the async generator + jest.spyOn(client, 
'createResponse').mockImplementation(() => { + return mockAsyncGenerator(); + }); + + client.setOptions({ + modelOptions: { + model: 'claude-3.7-sonnet', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + }); + + // Mock getClient to capture the request options + let capturedOptions = null; + jest.spyOn(client, 'getClient').mockImplementation((options) => { + capturedOptions = options; + return {}; + }); + + const payload = [{ role: 'user', content: 'Test message' }]; + await client.sendCompletion(payload, {}); + + // Check the options passed to getClient + expect(capturedOptions).not.toHaveProperty('top_k'); + expect(capturedOptions).not.toHaveProperty('top_p'); + }); + }); + + it('should include top_k and top_p parameters for Claude-3.7 models when thinking is explicitly disabled', async () => { + const client = new AnthropicClient('test-api-key', { + modelOptions: { + model: 'claude-3-7-sonnet', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + thinking: false, + }); + + async function* mockAsyncGenerator() { + yield { type: 'message_start', message: { usage: {} } }; + yield { delta: { text: 'Test response' } }; + yield { type: 'message_delta', usage: {} }; + } + + jest.spyOn(client, 'createResponse').mockImplementation(() => { + return mockAsyncGenerator(); + }); + + let capturedOptions = null; + jest.spyOn(client, 'getClient').mockImplementation((options) => { + capturedOptions = options; + return {}; + }); + + const payload = [{ role: 'user', content: 'Test message' }]; + await client.sendCompletion(payload, {}); + + expect(capturedOptions).toHaveProperty('topK', 10); + expect(capturedOptions).toHaveProperty('topP', 0.9); + + client.setOptions({ + modelOptions: { + model: 'claude-3.7-sonnet', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + thinking: false, + }); + + await client.sendCompletion(payload, {}); + + expect(capturedOptions).toHaveProperty('topK', 10); + expect(capturedOptions).toHaveProperty('topP', 0.9); + }); }); diff --git 
a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index e899449fb9..0dae5b14d3 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -30,6 +30,8 @@ jest.mock('~/models', () => ({ updateFileUsage: jest.fn(), })); +const { getConvo, saveConvo } = require('~/models'); + jest.mock('@langchain/openai', () => { return { ChatOpenAI: jest.fn().mockImplementation(() => { @@ -540,10 +542,11 @@ describe('BaseClient', () => { test('saveMessageToDatabase is called with the correct arguments', async () => { const saveOptions = TestClient.getSaveOptions(); - const user = {}; // Mock user + const user = {}; const opts = { user }; + const saveSpy = jest.spyOn(TestClient, 'saveMessageToDatabase'); await TestClient.sendMessage('Hello, world!', opts); - expect(TestClient.saveMessageToDatabase).toHaveBeenCalledWith( + expect(saveSpy).toHaveBeenCalledWith( expect.objectContaining({ sender: expect.any(String), text: expect.any(String), @@ -557,6 +560,157 @@ describe('BaseClient', () => { ); }); + test('should handle existing conversation when getConvo retrieves one', async () => { + const existingConvo = { + conversationId: 'existing-convo-id', + endpoint: 'openai', + endpointType: 'openai', + model: 'gpt-3.5-turbo', + messages: [ + { role: 'user', content: 'Existing message 1' }, + { role: 'assistant', content: 'Existing response 1' }, + ], + temperature: 1, + }; + + const { temperature: _temp, ...newConvo } = existingConvo; + + const user = { + id: 'user-id', + }; + + getConvo.mockResolvedValue(existingConvo); + saveConvo.mockResolvedValue(newConvo); + + TestClient = initializeFakeClient( + apiKey, + { + ...options, + req: { + user, + }, + }, + [], + ); + + const saveSpy = jest.spyOn(TestClient, 'saveMessageToDatabase'); + + const newMessage = 'New message in existing conversation'; + const response = await TestClient.sendMessage(newMessage, { + user, + conversationId: existingConvo.conversationId, + 
}); + + expect(getConvo).toHaveBeenCalledWith(user.id, existingConvo.conversationId); + expect(TestClient.conversationId).toBe(existingConvo.conversationId); + expect(response.conversationId).toBe(existingConvo.conversationId); + expect(TestClient.fetchedConvo).toBe(true); + + expect(saveSpy).toHaveBeenCalledWith( + expect.objectContaining({ + conversationId: existingConvo.conversationId, + text: newMessage, + }), + expect.any(Object), + expect.any(Object), + ); + + expect(saveConvo).toHaveBeenCalledTimes(2); + expect(saveConvo).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + conversationId: existingConvo.conversationId, + }), + expect.objectContaining({ + context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo', + unsetFields: { + temperature: 1, + }, + }), + ); + + await TestClient.sendMessage('Another message', { + conversationId: existingConvo.conversationId, + }); + expect(getConvo).toHaveBeenCalledTimes(1); + }); + + test('should correctly handle existing conversation and unset fields appropriately', async () => { + const existingConvo = { + conversationId: 'existing-convo-id', + endpoint: 'openai', + endpointType: 'openai', + model: 'gpt-3.5-turbo', + messages: [ + { role: 'user', content: 'Existing message 1' }, + { role: 'assistant', content: 'Existing response 1' }, + ], + title: 'Existing Conversation', + someExistingField: 'existingValue', + anotherExistingField: 'anotherValue', + temperature: 0.7, + modelLabel: 'GPT-3.5', + }; + + getConvo.mockResolvedValue(existingConvo); + saveConvo.mockResolvedValue(existingConvo); + + TestClient = initializeFakeClient( + apiKey, + { + ...options, + modelOptions: { + model: 'gpt-4', + temperature: 0.5, + }, + }, + [], + ); + + const newMessage = 'New message in existing conversation'; + await TestClient.sendMessage(newMessage, { + conversationId: existingConvo.conversationId, + }); + + expect(saveConvo).toHaveBeenCalledTimes(2); + + const saveConvoCall = 
saveConvo.mock.calls[0]; + const [, savedFields, saveOptions] = saveConvoCall; + + // Instead of checking all excludedKeys, we'll just check specific fields + // that we know should be excluded + expect(savedFields).not.toHaveProperty('messages'); + expect(savedFields).not.toHaveProperty('title'); + + // Only check that someExistingField is in unsetFields + expect(saveOptions.unsetFields).toHaveProperty('someExistingField', 1); + + // Mock saveConvo to return the expected fields + saveConvo.mockImplementation((req, fields) => { + return Promise.resolve({ + ...fields, + endpoint: 'openai', + endpointType: 'openai', + model: 'gpt-4', + temperature: 0.5, + }); + }); + + // Only check the conversationId since that's the only field we can be sure about + expect(savedFields).toHaveProperty('conversationId', 'existing-convo-id'); + + expect(TestClient.fetchedConvo).toBe(true); + + await TestClient.sendMessage('Another message', { + conversationId: existingConvo.conversationId, + }); + + expect(getConvo).toHaveBeenCalledTimes(1); + + const secondSaveConvoCall = saveConvo.mock.calls[1]; + expect(secondSaveConvoCall[2]).toHaveProperty('unsetFields', {}); + }); + test('sendCompletion is called with the correct arguments', async () => { const payload = {}; // Mock payload TestClient.buildMessages.mockReturnValue({ prompt: payload, tokenCountMap: null }); diff --git a/api/app/clients/specs/FakeClient.js b/api/app/clients/specs/FakeClient.js index 7f4b75e1db..a466bb97f9 100644 --- a/api/app/clients/specs/FakeClient.js +++ b/api/app/clients/specs/FakeClient.js @@ -56,7 +56,6 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => { let TestClient = new FakeClient(apiKey); TestClient.options = options; TestClient.abortController = { abort: jest.fn() }; - TestClient.saveMessageToDatabase = jest.fn(); TestClient.loadHistory = jest .fn() .mockImplementation((conversationId, parentMessageId = null) => { @@ -86,7 +85,6 @@ const initializeFakeClient = (apiKey, options, 
fakeMessages) => { return 'Mock response text'; }); - // eslint-disable-next-line no-unused-vars TestClient.getCompletion = jest.fn().mockImplementation(async (..._args) => { return { choices: [ diff --git a/api/app/clients/specs/OpenAIClient.test.js b/api/app/clients/specs/OpenAIClient.test.js index 2aaec518eb..0e811cf38a 100644 --- a/api/app/clients/specs/OpenAIClient.test.js +++ b/api/app/clients/specs/OpenAIClient.test.js @@ -202,14 +202,6 @@ describe('OpenAIClient', () => { expect(client.modelOptions.temperature).toBe(0.7); }); - it('should set apiKey and useOpenRouter if OPENROUTER_API_KEY is present', () => { - process.env.OPENROUTER_API_KEY = 'openrouter-key'; - client.setOptions({}); - expect(client.apiKey).toBe('openrouter-key'); - expect(client.useOpenRouter).toBe(true); - delete process.env.OPENROUTER_API_KEY; // Cleanup - }); - it('should set FORCE_PROMPT based on OPENAI_FORCE_PROMPT or reverseProxyUrl', () => { process.env.OPENAI_FORCE_PROMPT = 'true'; client.setOptions({}); @@ -534,7 +526,6 @@ describe('OpenAIClient', () => { afterEach(() => { delete process.env.AZURE_OPENAI_DEFAULT_MODEL; delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME; - delete process.env.OPENROUTER_API_KEY; }); it('should call getCompletion and fetchEventSource when using a text/instruct model', async () => { diff --git a/api/app/clients/tools/index.js b/api/app/clients/tools/index.js index b8df50c77d..df436fb089 100644 --- a/api/app/clients/tools/index.js +++ b/api/app/clients/tools/index.js @@ -2,9 +2,10 @@ const availableTools = require('./manifest.json'); // Structured Tools const DALLE3 = require('./structured/DALLE3'); +const FluxAPI = require('./structured/FluxAPI'); const OpenWeather = require('./structured/OpenWeather'); -const createYouTubeTools = require('./structured/YouTube'); const StructuredWolfram = require('./structured/Wolfram'); +const createYouTubeTools = require('./structured/YouTube'); const StructuredACS = require('./structured/AzureAISearch'); const 
StructuredSD = require('./structured/StableDiffusion'); const GoogleSearchAPI = require('./structured/GoogleSearch'); @@ -30,6 +31,7 @@ module.exports = { manifestToolMap, // Structured Tools DALLE3, + FluxAPI, OpenWeather, StructuredSD, StructuredACS, diff --git a/api/app/clients/tools/manifest.json b/api/app/clients/tools/manifest.json index 7cb92b8d87..43be7a4e6c 100644 --- a/api/app/clients/tools/manifest.json +++ b/api/app/clients/tools/manifest.json @@ -164,5 +164,19 @@ "description": "Sign up at OpenWeather, then get your key at API keys." } ] + }, + { + "name": "Flux", + "pluginKey": "flux", + "description": "Generate images using text with the Flux API.", + "icon": "https://blackforestlabs.ai/wp-content/uploads/2024/07/bfl_logo_retraced_blk.png", + "isAuthRequired": "true", + "authConfig": [ + { + "authField": "FLUX_API_KEY", + "label": "Your Flux API Key", + "description": "Provide your Flux API key from your user profile." + } + ] } ] diff --git a/api/app/clients/tools/structured/DALLE3.js b/api/app/clients/tools/structured/DALLE3.js index b604ad4ea4..81200e3a61 100644 --- a/api/app/clients/tools/structured/DALLE3.js +++ b/api/app/clients/tools/structured/DALLE3.js @@ -1,14 +1,17 @@ const { z } = require('zod'); const path = require('path'); const OpenAI = require('openai'); +const fetch = require('node-fetch'); const { v4: uuidv4 } = require('uuid'); const { Tool } = require('@langchain/core/tools'); const { HttpsProxyAgent } = require('https-proxy-agent'); -const { FileContext } = require('librechat-data-provider'); +const { FileContext, ContentTypes } = require('librechat-data-provider'); const { getImageBasename } = require('~/server/services/Files/images'); const extractBaseURL = require('~/utils/extractBaseURL'); const { logger } = require('~/config'); +const displayMessage = + 'DALL-E displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. 
Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.'; class DALLE3 extends Tool { constructor(fields = {}) { super(); @@ -114,10 +117,7 @@ class DALLE3 extends Tool { if (this.isAgent === true && typeof value === 'string') { return [value, {}]; } else if (this.isAgent === true && typeof value === 'object') { - return [ - 'DALL-E displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.', - value, - ]; + return [displayMessage, value]; } return value; @@ -160,6 +160,32 @@ Error Message: ${error.message}`); ); } + if (this.isAgent) { + let fetchOptions = {}; + if (process.env.PROXY) { + fetchOptions.agent = new HttpsProxyAgent(process.env.PROXY); + } + const imageResponse = await fetch(theImageUrl, fetchOptions); + const arrayBuffer = await imageResponse.arrayBuffer(); + const base64 = Buffer.from(arrayBuffer).toString('base64'); + const content = [ + { + type: ContentTypes.IMAGE_URL, + image_url: { + url: `data:image/jpeg;base64,${base64}`, + }, + }, + ]; + + const response = [ + { + type: ContentTypes.TEXT, + text: displayMessage, + }, + ]; + return [response, { content }]; + } + const imageBasename = getImageBasename(theImageUrl); const imageExt = path.extname(imageBasename); diff --git a/api/app/clients/tools/structured/FluxAPI.js b/api/app/clients/tools/structured/FluxAPI.js new file mode 100644 index 0000000000..80f9772200 --- /dev/null +++ b/api/app/clients/tools/structured/FluxAPI.js @@ -0,0 +1,554 @@ +const { z } = require('zod'); +const axios = require('axios'); +const fetch = require('node-fetch'); +const { v4: uuidv4 } = require('uuid'); +const { Tool } = require('@langchain/core/tools'); +const 
{ HttpsProxyAgent } = require('https-proxy-agent'); +const { FileContext, ContentTypes } = require('librechat-data-provider'); +const { logger } = require('~/config'); + +const displayMessage = + 'Flux displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.'; + +/** + * FluxAPI - A tool for generating high-quality images from text prompts using the Flux API. + * Each call generates one image. If multiple images are needed, make multiple consecutive calls with the same or varied prompts. + */ +class FluxAPI extends Tool { + // Pricing constants in USD per image + static PRICING = { + FLUX_PRO_1_1_ULTRA: -0.06, // /v1/flux-pro-1.1-ultra + FLUX_PRO_1_1: -0.04, // /v1/flux-pro-1.1 + FLUX_PRO: -0.05, // /v1/flux-pro + FLUX_DEV: -0.025, // /v1/flux-dev + FLUX_PRO_FINETUNED: -0.06, // /v1/flux-pro-finetuned + FLUX_PRO_1_1_ULTRA_FINETUNED: -0.07, // /v1/flux-pro-1.1-ultra-finetuned + }; + + constructor(fields = {}) { + super(); + + /** @type {boolean} Used to initialize the Tool without necessary variables. */ + this.override = fields.override ?? false; + + this.userId = fields.userId; + this.fileStrategy = fields.fileStrategy; + + /** @type {boolean} **/ + this.isAgent = fields.isAgent; + this.returnMetadata = fields.returnMetadata ?? false; + + if (fields.processFileURL) { + /** @type {processFileURL} Necessary for output to contain all image metadata. */ + this.processFileURL = fields.processFileURL.bind(this); + } + + this.apiKey = fields.FLUX_API_KEY || this.getApiKey(); + + this.name = 'flux'; + this.description = + 'Use Flux to generate images from text descriptions. This tool can generate images and list available finetunes. Each generate call creates one image. 
For multiple images, make multiple consecutive calls.'; + + this.description_for_model = `// Transform any image description into a detailed, high-quality prompt. Never submit a prompt under 3 sentences. Follow these core rules: + // 1. ALWAYS enhance basic prompts into 5-10 detailed sentences (e.g., "a cat" becomes: "A close-up photo of a sleek Siamese cat with piercing blue eyes. The cat sits elegantly on a vintage leather armchair, its tail curled gracefully around its paws. Warm afternoon sunlight streams through a nearby window, casting gentle shadows across its face and highlighting the subtle variations in its cream and chocolate-point fur. The background is softly blurred, creating a shallow depth of field that draws attention to the cat's expressive features. The overall composition has a peaceful, contemplative mood with a professional photography style.") + // 2. Each prompt MUST be 3-6 descriptive sentences minimum, focusing on visual elements: lighting, composition, mood, and style + // Use action: 'list_finetunes' to see available custom models. When using finetunes, use endpoint: '/v1/flux-pro-finetuned' (default) or '/v1/flux-pro-1.1-ultra-finetuned' for higher quality and aspect ratio.`; + + // Add base URL from environment variable with fallback + this.baseUrl = process.env.FLUX_API_BASE_URL || 'https://api.us1.bfl.ai'; + + // Define the schema for structured input + this.schema = z.object({ + action: z + .enum(['generate', 'list_finetunes', 'generate_finetuned']) + .default('generate') + .describe( + 'Action to perform: "generate" for image generation, "generate_finetuned" for finetuned model generation, "list_finetunes" to get available custom models', + ), + prompt: z + .string() + .optional() + .describe( + 'Text prompt for image generation. Required when action is "generate". Not used for list_finetunes.', + ), + width: z + .number() + .optional() + .describe( + 'Width of the generated image in pixels. Must be a multiple of 32. 
Default is 1024.', + ), + height: z + .number() + .optional() + .describe( + 'Height of the generated image in pixels. Must be a multiple of 32. Default is 768.', + ), + prompt_upsampling: z + .boolean() + .optional() + .default(false) + .describe('Whether to perform upsampling on the prompt.'), + steps: z + .number() + .int() + .optional() + .describe('Number of steps to run the model for, a number from 1 to 50. Default is 40.'), + seed: z.number().optional().describe('Optional seed for reproducibility.'), + safety_tolerance: z + .number() + .optional() + .default(6) + .describe( + 'Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict.', + ), + endpoint: z + .enum([ + '/v1/flux-pro-1.1', + '/v1/flux-pro', + '/v1/flux-dev', + '/v1/flux-pro-1.1-ultra', + '/v1/flux-pro-finetuned', + '/v1/flux-pro-1.1-ultra-finetuned', + ]) + .optional() + .default('/v1/flux-pro-1.1') + .describe('Endpoint to use for image generation.'), + raw: z + .boolean() + .optional() + .default(false) + .describe( + 'Generate less processed, more natural-looking images. 
Only works for /v1/flux-pro-1.1-ultra.', + ), + finetune_id: z.string().optional().describe('ID of the finetuned model to use'), + finetune_strength: z + .number() + .optional() + .default(1.1) + .describe('Strength of the finetuning effect (typically between 0.1 and 1.2)'), + guidance: z.number().optional().default(2.5).describe('Guidance scale for finetuned models'), + aspect_ratio: z + .string() + .optional() + .default('16:9') + .describe('Aspect ratio for ultra models (e.g., "16:9")'), + }); + } + + getAxiosConfig() { + const config = {}; + if (process.env.PROXY) { + config.httpsAgent = new HttpsProxyAgent(process.env.PROXY); + } + return config; + } + + /** @param {Object|string} value */ + getDetails(value) { + if (typeof value === 'string') { + return value; + } + return JSON.stringify(value, null, 2); + } + + getApiKey() { + const apiKey = process.env.FLUX_API_KEY || ''; + if (!apiKey && !this.override) { + throw new Error('Missing FLUX_API_KEY environment variable.'); + } + return apiKey; + } + + wrapInMarkdown(imageUrl) { + const serverDomain = process.env.DOMAIN_SERVER || 'http://localhost:3080'; + return `![generated image](${serverDomain}${imageUrl})`; + } + + returnValue(value) { + if (this.isAgent === true && typeof value === 'string') { + return [value, {}]; + } else if (this.isAgent === true && typeof value === 'object') { + if (Array.isArray(value)) { + return value; + } + return [displayMessage, value]; + } + return value; + } + + async _call(data) { + const { action = 'generate', ...imageData } = data; + + // Use provided API key for this request if available, otherwise use default + const requestApiKey = this.apiKey || this.getApiKey(); + + // Handle list_finetunes action + if (action === 'list_finetunes') { + return this.getMyFinetunes(requestApiKey); + } + + // Handle finetuned generation + if (action === 'generate_finetuned') { + return this.generateFinetunedImage(imageData, requestApiKey); + } + + // For generate action, ensure prompt is 
provided + if (!imageData.prompt) { + throw new Error('Missing required field: prompt'); + } + + let payload = { + prompt: imageData.prompt, + prompt_upsampling: imageData.prompt_upsampling || false, + safety_tolerance: imageData.safety_tolerance || 6, + output_format: imageData.output_format || 'png', + }; + + // Add optional parameters if provided + if (imageData.width) { + payload.width = imageData.width; + } + if (imageData.height) { + payload.height = imageData.height; + } + if (imageData.steps) { + payload.steps = imageData.steps; + } + if (imageData.seed !== undefined) { + payload.seed = imageData.seed; + } + if (imageData.raw) { + payload.raw = imageData.raw; + } + + const generateUrl = `${this.baseUrl}${imageData.endpoint || '/v1/flux-pro'}`; + const resultUrl = `${this.baseUrl}/v1/get_result`; + + logger.debug('[FluxAPI] Generating image with payload:', payload); + logger.debug('[FluxAPI] Using endpoint:', generateUrl); + + let taskResponse; + try { + taskResponse = await axios.post(generateUrl, payload, { + headers: { + 'x-key': requestApiKey, + 'Content-Type': 'application/json', + Accept: 'application/json', + }, + ...this.getAxiosConfig(), + }); + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while submitting task:', details); + + return this.returnValue( + `Something went wrong when trying to generate the image. 
The Flux API may be unavailable: + Error Message: ${details}`, + ); + } + + const taskId = taskResponse.data.id; + + // Polling for the result + let status = 'Pending'; + let resultData = null; + while (status !== 'Ready' && status !== 'Error') { + try { + // Wait 2 seconds between polls + await new Promise((resolve) => setTimeout(resolve, 2000)); + const resultResponse = await axios.get(resultUrl, { + headers: { + 'x-key': requestApiKey, + Accept: 'application/json', + }, + params: { id: taskId }, + ...this.getAxiosConfig(), + }); + status = resultResponse.data.status; + + if (status === 'Ready') { + resultData = resultResponse.data.result; + break; + } else if (status === 'Error') { + logger.error('[FluxAPI] Error in task:', resultResponse.data); + return this.returnValue('An error occurred during image generation.'); + } + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while getting result:', details); + return this.returnValue('An error occurred while retrieving the image.'); + } + } + + // If no result data + if (!resultData || !resultData.sample) { + logger.error('[FluxAPI] No image data received from API. 
Response:', resultData); + return this.returnValue('No image data received from Flux API.'); + } + + // Try saving the image locally + const imageUrl = resultData.sample; + const imageName = `img-${uuidv4()}.png`; + + if (this.isAgent) { + try { + // Fetch the image and convert to base64 + const fetchOptions = {}; + if (process.env.PROXY) { + fetchOptions.agent = new HttpsProxyAgent(process.env.PROXY); + } + const imageResponse = await fetch(imageUrl, fetchOptions); + const arrayBuffer = await imageResponse.arrayBuffer(); + const base64 = Buffer.from(arrayBuffer).toString('base64'); + const content = [ + { + type: ContentTypes.IMAGE_URL, + image_url: { + url: `data:image/png;base64,${base64}`, + }, + }, + ]; + + const response = [ + { + type: ContentTypes.TEXT, + text: displayMessage, + }, + ]; + return [response, { content }]; + } catch (error) { + logger.error('Error processing image for agent:', error); + return this.returnValue(`Failed to process the image. ${error.message}`); + } + } + + try { + logger.debug('[FluxAPI] Saving image:', imageUrl); + const result = await this.processFileURL({ + fileStrategy: this.fileStrategy, + userId: this.userId, + URL: imageUrl, + fileName: imageName, + basePath: 'images', + context: FileContext.image_generation, + }); + + logger.debug('[FluxAPI] Image saved to path:', result.filepath); + + // Calculate cost based on endpoint + /** + * TODO: Cost handling + const endpoint = imageData.endpoint || '/v1/flux-pro'; + const endpointKey = Object.entries(FluxAPI.PRICING).find(([key, _]) => + endpoint.includes(key.toLowerCase().replace(/_/g, '-')), + )?.[0]; + const cost = FluxAPI.PRICING[endpointKey] || 0; + */ + this.result = this.returnMetadata ? result : this.wrapInMarkdown(result.filepath); + return this.returnValue(this.result); + } catch (error) { + const details = this.getDetails(error?.message ?? 
'No additional error details.'); + logger.error('Error while saving the image:', details); + return this.returnValue(`Failed to save the image locally. ${details}`); + } + } + + async getMyFinetunes(apiKey = null) { + const finetunesUrl = `${this.baseUrl}/v1/my_finetunes`; + const detailsUrl = `${this.baseUrl}/v1/finetune_details`; + + try { + const headers = { + 'x-key': apiKey || this.getApiKey(), + 'Content-Type': 'application/json', + Accept: 'application/json', + }; + + // Get list of finetunes + const response = await axios.get(finetunesUrl, { + headers, + ...this.getAxiosConfig(), + }); + const finetunes = response.data.finetunes; + + // Fetch details for each finetune + const finetuneDetails = await Promise.all( + finetunes.map(async (finetuneId) => { + try { + const detailResponse = await axios.get(`${detailsUrl}?finetune_id=${finetuneId}`, { + headers, + ...this.getAxiosConfig(), + }); + return { + id: finetuneId, + ...detailResponse.data, + }; + } catch (error) { + logger.error(`[FluxAPI] Error fetching details for finetune ${finetuneId}:`, error); + return { + id: finetuneId, + error: 'Failed to fetch details', + }; + } + }), + ); + + if (this.isAgent) { + const formattedDetails = JSON.stringify(finetuneDetails, null, 2); + return [`Here are the available finetunes:\n${formattedDetails}`, null]; + } + return JSON.stringify(finetuneDetails); + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while getting finetunes:', details); + const errorMsg = `Failed to get finetunes: ${details}`; + return this.isAgent ? this.returnValue([errorMsg, {}]) : new Error(errorMsg); + } + } + + async generateFinetunedImage(imageData, requestApiKey) { + if (!imageData.prompt) { + throw new Error('Missing required field: prompt'); + } + + if (!imageData.finetune_id) { + throw new Error( + 'Missing required field: finetune_id for finetuned generation. 
Please supply a finetune_id!', + ); + } + + // Validate endpoint is appropriate for finetuned generation + const validFinetunedEndpoints = ['/v1/flux-pro-finetuned', '/v1/flux-pro-1.1-ultra-finetuned']; + const endpoint = imageData.endpoint || '/v1/flux-pro-finetuned'; + + if (!validFinetunedEndpoints.includes(endpoint)) { + throw new Error( + `Invalid endpoint for finetuned generation. Must be one of: ${validFinetunedEndpoints.join(', ')}`, + ); + } + + let payload = { + prompt: imageData.prompt, + prompt_upsampling: imageData.prompt_upsampling || false, + safety_tolerance: imageData.safety_tolerance || 6, + output_format: imageData.output_format || 'png', + finetune_id: imageData.finetune_id, + finetune_strength: imageData.finetune_strength || 1.0, + guidance: imageData.guidance || 2.5, + }; + + // Add optional parameters if provided + if (imageData.width) { + payload.width = imageData.width; + } + if (imageData.height) { + payload.height = imageData.height; + } + if (imageData.steps) { + payload.steps = imageData.steps; + } + if (imageData.seed !== undefined) { + payload.seed = imageData.seed; + } + if (imageData.raw) { + payload.raw = imageData.raw; + } + + const generateUrl = `${this.baseUrl}${endpoint}`; + const resultUrl = `${this.baseUrl}/v1/get_result`; + + logger.debug('[FluxAPI] Generating finetuned image with payload:', payload); + logger.debug('[FluxAPI] Using endpoint:', generateUrl); + + let taskResponse; + try { + taskResponse = await axios.post(generateUrl, payload, { + headers: { + 'x-key': requestApiKey, + 'Content-Type': 'application/json', + Accept: 'application/json', + }, + ...this.getAxiosConfig(), + }); + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while submitting finetuned task:', details); + return this.returnValue( + `Something went wrong when trying to generate the finetuned image. 
The Flux API may be unavailable: + Error Message: ${details}`, + ); + } + + const taskId = taskResponse.data.id; + + // Polling for the result + let status = 'Pending'; + let resultData = null; + while (status !== 'Ready' && status !== 'Error') { + try { + // Wait 2 seconds between polls + await new Promise((resolve) => setTimeout(resolve, 2000)); + const resultResponse = await axios.get(resultUrl, { + headers: { + 'x-key': requestApiKey, + Accept: 'application/json', + }, + params: { id: taskId }, + ...this.getAxiosConfig(), + }); + status = resultResponse.data.status; + + if (status === 'Ready') { + resultData = resultResponse.data.result; + break; + } else if (status === 'Error') { + logger.error('[FluxAPI] Error in finetuned task:', resultResponse.data); + return this.returnValue('An error occurred during finetuned image generation.'); + } + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while getting finetuned result:', details); + return this.returnValue('An error occurred while retrieving the finetuned image.'); + } + } + + // If no result data + if (!resultData || !resultData.sample) { + logger.error('[FluxAPI] No image data received from API. Response:', resultData); + return this.returnValue('No image data received from Flux API.'); + } + + // Try saving the image locally + const imageUrl = resultData.sample; + const imageName = `img-${uuidv4()}.png`; + + try { + logger.debug('[FluxAPI] Saving finetuned image:', imageUrl); + const result = await this.processFileURL({ + fileStrategy: this.fileStrategy, + userId: this.userId, + URL: imageUrl, + fileName: imageName, + basePath: 'images', + context: FileContext.image_generation, + }); + + logger.debug('[FluxAPI] Finetuned image saved to path:', result.filepath); + + // Calculate cost based on endpoint + const endpointKey = endpoint.includes('ultra') + ? 
'FLUX_PRO_1_1_ULTRA_FINETUNED' + : 'FLUX_PRO_FINETUNED'; + const cost = FluxAPI.PRICING[endpointKey] || 0; + // Return the result based on returnMetadata flag + this.result = this.returnMetadata ? result : this.wrapInMarkdown(result.filepath); + return this.returnValue(this.result); + } catch (error) { + const details = this.getDetails(error?.message ?? 'No additional error details.'); + logger.error('Error while saving the finetuned image:', details); + return this.returnValue(`Failed to save the finetuned image locally. ${details}`); + } + } +} + +module.exports = FluxAPI; diff --git a/api/app/clients/tools/structured/StableDiffusion.js b/api/app/clients/tools/structured/StableDiffusion.js index 6309da35d8..25a9e0abd3 100644 --- a/api/app/clients/tools/structured/StableDiffusion.js +++ b/api/app/clients/tools/structured/StableDiffusion.js @@ -6,10 +6,13 @@ const axios = require('axios'); const sharp = require('sharp'); const { v4: uuidv4 } = require('uuid'); const { Tool } = require('@langchain/core/tools'); -const { FileContext } = require('librechat-data-provider'); +const { FileContext, ContentTypes } = require('librechat-data-provider'); const paths = require('~/config/paths'); const { logger } = require('~/config'); +const displayMessage = + 'Stable Diffusion displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.'; + class StableDiffusionAPI extends Tool { constructor(fields) { super(); @@ -21,6 +24,8 @@ class StableDiffusionAPI extends Tool { this.override = fields.override ?? false; /** @type {boolean} Necessary for output to contain all image metadata. */ this.returnMetadata = fields.returnMetadata ?? 
false; + /** @type {boolean} */ + this.isAgent = fields.isAgent; if (fields.uploadImageBuffer) { /** @type {uploadImageBuffer} Necessary for output to contain all image metadata. */ this.uploadImageBuffer = fields.uploadImageBuffer.bind(this); @@ -66,6 +71,16 @@ class StableDiffusionAPI extends Tool { return `![generated image](/${imageUrl})`; } + returnValue(value) { + if (this.isAgent === true && typeof value === 'string') { + return [value, {}]; + } else if (this.isAgent === true && typeof value === 'object') { + return [displayMessage, value]; + } + + return value; + } + getServerURL() { const url = process.env.SD_WEBUI_URL || ''; if (!url && !this.override) { @@ -113,6 +128,25 @@ class StableDiffusionAPI extends Tool { } try { + if (this.isAgent) { + const content = [ + { + type: ContentTypes.IMAGE_URL, + image_url: { + url: `data:image/png;base64,${image}`, + }, + }, + ]; + + const response = [ + { + type: ContentTypes.TEXT, + text: displayMessage, + }, + ]; + return [response, { content }]; + } + const buffer = Buffer.from(image.split(',', 1)[0], 'base64'); if (this.returnMetadata && this.uploadImageBuffer && this.req) { const file = await this.uploadImageBuffer({ @@ -154,7 +188,7 @@ class StableDiffusionAPI extends Tool { logger.error('[StableDiffusion] Error while saving the image:', error); } - return this.result; + return this.returnValue(this.result); } } diff --git a/api/app/clients/tools/util/fileSearch.js b/api/app/clients/tools/util/fileSearch.js index 23ba58bb5a..54da483362 100644 --- a/api/app/clients/tools/util/fileSearch.js +++ b/api/app/clients/tools/util/fileSearch.js @@ -106,18 +106,21 @@ const createFileSearchTool = async ({ req, files, entity_id }) => { const formattedResults = validResults .flatMap((result) => - result.data.map(([docInfo, relevanceScore]) => ({ + result.data.map(([docInfo, distance]) => ({ filename: docInfo.metadata.source.split('/').pop(), content: docInfo.page_content, - relevanceScore, + distance, })), ) - .sort((a, b) 
=> b.relevanceScore - a.relevanceScore); + // TODO: results should be sorted by relevance, not distance + .sort((a, b) => a.distance - b.distance) + // TODO: make this configurable + .slice(0, 10); const formattedString = formattedResults .map( (result) => - `File: ${result.filename}\nRelevance: ${result.relevanceScore.toFixed(4)}\nContent: ${ + `File: ${result.filename}\nRelevance: ${1.0 - result.distance.toFixed(4)}\nContent: ${ result.content }\n`, ) diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index f1dfa24a49..ae19a158ee 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -10,6 +10,7 @@ const { GoogleSearchAPI, // Structured Tools DALLE3, + FluxAPI, OpenWeather, StructuredSD, StructuredACS, @@ -182,6 +183,7 @@ const loadTools = async ({ returnMap = false, }) => { const toolConstructors = { + flux: FluxAPI, calculator: Calculator, google: GoogleSearchAPI, open_weather: OpenWeather, @@ -230,9 +232,10 @@ const loadTools = async ({ }; const toolOptions = { - serpapi: { location: 'Austin,Texas,United States', hl: 'en', gl: 'us' }, + flux: imageGenOptions, dalle: imageGenOptions, 'stable-diffusion': imageGenOptions, + serpapi: { location: 'Austin,Texas,United States', hl: 'en', gl: 'us' }, }; const toolContextMap = {}; diff --git a/api/cache/keyvRedis.js b/api/cache/keyvRedis.js index d544b50a11..49620c49ae 100644 --- a/api/cache/keyvRedis.js +++ b/api/cache/keyvRedis.js @@ -1,15 +1,81 @@ +const fs = require('fs'); +const ioredis = require('ioredis'); const KeyvRedis = require('@keyv/redis'); const { isEnabled } = require('~/server/utils'); const logger = require('~/config/winston'); -const { REDIS_URI, USE_REDIS } = process.env; +const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_KEY_PREFIX, REDIS_MAX_LISTENERS } = + process.env; let keyvRedis; +const redis_prefix = REDIS_KEY_PREFIX || ''; +const redis_max_listeners = Number(REDIS_MAX_LISTENERS) 
|| 10; + +function mapURI(uri) { + const regex = + /^(?:(?\w+):\/\/)?(?:(?[^:@]+)(?::(?[^@]+))?@)?(?[\w.-]+)(?::(?\d{1,5}))?$/; + const match = uri.match(regex); + + if (match) { + const { scheme, user, password, host, port } = match.groups; + + return { + scheme: scheme || 'none', + user: user || null, + password: password || null, + host: host || null, + port: port || null, + }; + } else { + const parts = uri.split(':'); + if (parts.length === 2) { + return { + scheme: 'none', + user: null, + password: null, + host: parts[0], + port: parts[1], + }; + } + + return { + scheme: 'none', + user: null, + password: null, + host: uri, + port: null, + }; + } +} if (REDIS_URI && isEnabled(USE_REDIS)) { - keyvRedis = new KeyvRedis(REDIS_URI, { useRedisSets: false }); + let redisOptions = null; + let keyvOpts = { + useRedisSets: false, + keyPrefix: redis_prefix, + }; + + if (REDIS_CA) { + const ca = fs.readFileSync(REDIS_CA); + redisOptions = { tls: { ca } }; + } + + if (isEnabled(USE_REDIS_CLUSTER)) { + const hosts = REDIS_URI.split(',').map((item) => { + var value = mapURI(item); + + return { + host: value.host, + port: value.port, + }; + }); + const cluster = new ioredis.Cluster(hosts, { redisOptions }); + keyvRedis = new KeyvRedis(cluster, keyvOpts); + } else { + keyvRedis = new KeyvRedis(REDIS_URI, keyvOpts); + } keyvRedis.on('error', (err) => logger.error('KeyvRedis connection error:', err)); - keyvRedis.setMaxListeners(20); + keyvRedis.setMaxListeners(redis_max_listeners); logger.info( '[Optional] Redis initialized. Note: Redis support is experimental. If you have issues, disable it. 
Cache needs to be flushed for values to refresh.', ); diff --git a/api/lib/db/indexSync.js b/api/lib/db/indexSync.js index 86c909419d..9c40e684d3 100644 --- a/api/lib/db/indexSync.js +++ b/api/lib/db/indexSync.js @@ -1,9 +1,11 @@ const { MeiliSearch } = require('meilisearch'); const Conversation = require('~/models/schema/convoSchema'); const Message = require('~/models/schema/messageSchema'); +const { isEnabled } = require('~/server/utils'); const { logger } = require('~/config'); -const searchEnabled = process.env?.SEARCH?.toLowerCase() === 'true'; +const searchEnabled = isEnabled(process.env.SEARCH); +const indexingDisabled = isEnabled(process.env.MEILI_NO_SYNC); let currentTimeout = null; class MeiliSearchClient { @@ -23,8 +25,7 @@ class MeiliSearchClient { } } -// eslint-disable-next-line no-unused-vars -async function indexSync(req, res, next) { +async function indexSync() { if (!searchEnabled) { return; } @@ -33,10 +34,15 @@ async function indexSync(req, res, next) { const client = MeiliSearchClient.getInstance(); const { status } = await client.health(); - if (status !== 'available' || !process.env.SEARCH) { + if (status !== 'available') { throw new Error('Meilisearch not available'); } + if (indexingDisabled === true) { + logger.info('[indexSync] Indexing is disabled, skipping...'); + return; + } + const messageCount = await Message.countDocuments(); const convoCount = await Conversation.countDocuments(); const messages = await client.index('messages').getStats(); @@ -71,7 +77,6 @@ async function indexSync(req, res, next) { logger.info('[indexSync] Meilisearch not configured, search will be disabled.'); } else { logger.error('[indexSync] error', err); - // res.status(500).json({ error: 'Server error' }); } } } diff --git a/api/models/Agent.js b/api/models/Agent.js index 6fa00f56bc..6ea203113c 100644 --- a/api/models/Agent.js +++ b/api/models/Agent.js @@ -97,11 +97,22 @@ const updateAgent = async (searchParameter, updateData) => { const addAgentResourceFile 
= async ({ agent_id, tool_resource, file_id }) => { const searchParameter = { id: agent_id }; - // build the update to push or create the file ids set const fileIdsPath = `tool_resources.${tool_resource}.file_ids`; + + await Agent.updateOne( + { + id: agent_id, + [`${fileIdsPath}`]: { $exists: false }, + }, + { + $set: { + [`${fileIdsPath}`]: [], + }, + }, + ); + const updateData = { $addToSet: { [fileIdsPath]: file_id } }; - // return the updated agent or throw if no agent matches const updatedAgent = await updateAgent(searchParameter, updateData); if (updatedAgent) { return updatedAgent; @@ -290,6 +301,7 @@ const updateAgentProjects = async ({ user, agentId, projectIds, removeProjectIds }; module.exports = { + Agent, getAgent, loadAgent, createAgent, diff --git a/api/models/Agent.spec.js b/api/models/Agent.spec.js new file mode 100644 index 0000000000..769eda2bb7 --- /dev/null +++ b/api/models/Agent.spec.js @@ -0,0 +1,160 @@ +const mongoose = require('mongoose'); +const { v4: uuidv4 } = require('uuid'); +const { MongoMemoryServer } = require('mongodb-memory-server'); +const { Agent, addAgentResourceFile, removeAgentResourceFiles } = require('./Agent'); + +describe('Agent Resource File Operations', () => { + let mongoServer; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + await mongoose.connect(mongoUri); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Agent.deleteMany({}); + }); + + const createBasicAgent = async () => { + const agentId = `agent_${uuidv4()}`; + const agent = await Agent.create({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: new mongoose.Types.ObjectId(), + }); + return agent; + }; + + test('should handle concurrent file additions', async () => { + const agent = await createBasicAgent(); + const fileIds = Array.from({ length: 10 }, () => 
uuidv4()); + + // Concurrent additions + const additionPromises = fileIds.map((fileId) => + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }), + ); + + await Promise.all(additionPromises); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(10); + expect(new Set(updatedAgent.tool_resources.test_tool.file_ids).size).toBe(10); + }); + + test('should handle concurrent additions and removals', async () => { + const agent = await createBasicAgent(); + const initialFileIds = Array.from({ length: 5 }, () => uuidv4()); + + await Promise.all( + initialFileIds.map((fileId) => + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }), + ), + ); + + const newFileIds = Array.from({ length: 5 }, () => uuidv4()); + const operations = [ + ...newFileIds.map((fileId) => + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }), + ), + ...initialFileIds.map((fileId) => + removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ), + ]; + + await Promise.all(operations); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(5); + }); + + test('should initialize array when adding to non-existent tool resource', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + const updatedAgent = await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'new_tool', + file_id: fileId, + }); + + expect(updatedAgent.tool_resources.new_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.new_tool.file_ids).toHaveLength(1); + 
expect(updatedAgent.tool_resources.new_tool.file_ids[0]).toBe(fileId); + }); + + test('should handle rapid sequential modifications to same tool resource', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + for (let i = 0; i < 10; i++) { + await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: `${fileId}_${i}`, + }); + + if (i % 2 === 0) { + await removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: `${fileId}_${i}` }], + }); + } + } + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(Array.isArray(updatedAgent.tool_resources.test_tool.file_ids)).toBe(true); + }); + + test('should handle multiple tool resources concurrently', async () => { + const agent = await createBasicAgent(); + const toolResources = ['tool1', 'tool2', 'tool3']; + const operations = []; + + toolResources.forEach((tool) => { + const fileIds = Array.from({ length: 5 }, () => uuidv4()); + fileIds.forEach((fileId) => { + operations.push( + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: tool, + file_id: fileId, + }), + ); + }); + }); + + await Promise.all(operations); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + toolResources.forEach((tool) => { + expect(updatedAgent.tool_resources[tool].file_ids).toBeDefined(); + expect(updatedAgent.tool_resources[tool].file_ids).toHaveLength(5); + }); + }); +}); diff --git a/api/models/Categories.js b/api/models/Categories.js index 605b68d176..6fb88fb995 100644 --- a/api/models/Categories.js +++ b/api/models/Categories.js @@ -3,40 +3,40 @@ const { logger } = require('~/config'); const options = [ { - label: 'idea', - value: 'com_ui_idea', + label: 'com_ui_idea', + value: 'idea', }, { - label: 'travel', - value: 'com_ui_travel', + label: 'com_ui_travel', + value: 'travel', }, { - label: 'teach_or_explain', - value: 
'com_ui_teach_or_explain', + label: 'com_ui_teach_or_explain', + value: 'teach_or_explain', }, { - label: 'write', - value: 'com_ui_write', + label: 'com_ui_write', + value: 'write', }, { - label: 'shop', - value: 'com_ui_shop', + label: 'com_ui_shop', + value: 'shop', }, { - label: 'code', - value: 'com_ui_code', + label: 'com_ui_code', + value: 'code', }, { - label: 'misc', - value: 'com_ui_misc', + label: 'com_ui_misc', + value: 'misc', }, { - label: 'roleplay', - value: 'com_ui_roleplay', + label: 'com_ui_roleplay', + value: 'roleplay', }, { - label: 'finance', - value: 'com_ui_finance', + label: 'com_ui_finance', + value: 'finance', }, ]; diff --git a/api/models/Conversation.js b/api/models/Conversation.js index d6365e99ce..9e51926ebc 100644 --- a/api/models/Conversation.js +++ b/api/models/Conversation.js @@ -104,10 +104,16 @@ module.exports = { update.expiredAt = null; } + /** @type {{ $set: Partial; $unset?: Record }} */ + const updateOperation = { $set: update }; + if (metadata && metadata.unsetFields && Object.keys(metadata.unsetFields).length > 0) { + updateOperation.$unset = metadata.unsetFields; + } + /** Note: the resulting Model object is necessary for Meilisearch operations */ const conversation = await Conversation.findOneAndUpdate( { conversationId, user: req.user.id }, - update, + updateOperation, { new: true, upsert: true, diff --git a/api/models/Role.js b/api/models/Role.js index 9c160512b7..3f02570718 100644 --- a/api/models/Role.js +++ b/api/models/Role.js @@ -6,8 +6,10 @@ const { removeNullishValues, agentPermissionsSchema, promptPermissionsSchema, + runCodePermissionsSchema, bookmarkPermissionsSchema, multiConvoPermissionsSchema, + temporaryChatPermissionsSchema, } = require('librechat-data-provider'); const getLogStores = require('~/cache/getLogStores'); const Role = require('~/models/schema/roleSchema'); @@ -77,6 +79,8 @@ const permissionSchemas = { [PermissionTypes.PROMPTS]: promptPermissionsSchema, [PermissionTypes.BOOKMARKS]: 
bookmarkPermissionsSchema, [PermissionTypes.MULTI_CONVO]: multiConvoPermissionsSchema, + [PermissionTypes.TEMPORARY_CHAT]: temporaryChatPermissionsSchema, + [PermissionTypes.RUN_CODE]: runCodePermissionsSchema, }; /** diff --git a/api/models/Token.js b/api/models/Token.js index 210666ddd7..0ed18320ae 100644 --- a/api/models/Token.js +++ b/api/models/Token.js @@ -13,6 +13,13 @@ const Token = mongoose.model('Token', tokenSchema); */ async function fixIndexes() { try { + if ( + process.env.NODE_ENV === 'CI' || + process.env.NODE_ENV === 'development' || + process.env.NODE_ENV === 'test' + ) { + return; + } const indexes = await Token.collection.indexes(); logger.debug('Existing Token Indexes:', JSON.stringify(indexes, null, 2)); const unwantedTTLIndexes = indexes.filter( diff --git a/api/models/schema/convoSchema.js b/api/models/schema/convoSchema.js index 7d8beed6a6..ae50b7cd22 100644 --- a/api/models/schema/convoSchema.js +++ b/api/models/schema/convoSchema.js @@ -20,8 +20,6 @@ const convoSchema = mongoose.Schema( index: true, }, messages: [{ type: mongoose.Schema.Types.ObjectId, ref: 'Message' }], - // google only - examples: { type: [{ type: mongoose.Schema.Types.Mixed }], default: undefined }, agentOptions: { type: mongoose.Schema.Types.Mixed, }, @@ -48,12 +46,12 @@ if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) { convoSchema.plugin(mongoMeili, { host: process.env.MEILI_HOST, apiKey: process.env.MEILI_MASTER_KEY, - indexName: 'convos', // Will get created automatically if it doesn't exist already + /** Note: Will get created automatically if it doesn't exist already */ + indexName: 'convos', primaryKey: 'conversationId', }); } -// Create TTL index convoSchema.index({ expiredAt: 1 }, { expireAfterSeconds: 0 }); convoSchema.index({ createdAt: 1, updatedAt: 1 }); convoSchema.index({ conversationId: 1, user: 1 }, { unique: true }); diff --git a/api/models/schema/defaults.js b/api/models/schema/defaults.js index 73fef00b5a..be2af7fb49 100644 --- 
a/api/models/schema/defaults.js +++ b/api/models/schema/defaults.js @@ -1,3 +1,5 @@ +const mongoose = require('mongoose'); + const conversationPreset = { // endpoint: [azureOpenAI, openAI, anthropic, chatGPTBrowser] endpoint: { @@ -24,6 +26,7 @@ const conversationPreset = { required: false, }, // for google only + examples: { type: [{ type: mongoose.Schema.Types.Mixed }], default: undefined }, modelLabel: { type: String, required: false, @@ -53,6 +56,10 @@ const conversationPreset = { type: Number, required: false, }, + maxTokens: { + type: Number, + required: false, + }, presence_penalty: { type: Number, required: false, @@ -70,6 +77,12 @@ const conversationPreset = { promptCache: { type: Boolean, }, + thinking: { + type: Boolean, + }, + thinkingBudget: { + type: Number, + }, system: { type: String, }, @@ -123,56 +136,6 @@ const conversationPreset = { }, }; -const agentOptions = { - model: { - type: String, - required: false, - }, - // for azureOpenAI, openAI only - chatGptLabel: { - type: String, - required: false, - }, - modelLabel: { - type: String, - required: false, - }, - promptPrefix: { - type: String, - required: false, - }, - temperature: { - type: Number, - required: false, - }, - top_p: { - type: Number, - required: false, - }, - // for google only - topP: { - type: Number, - required: false, - }, - topK: { - type: Number, - required: false, - }, - maxOutputTokens: { - type: Number, - required: false, - }, - presence_penalty: { - type: Number, - required: false, - }, - frequency_penalty: { - type: Number, - required: false, - }, -}; - module.exports = { conversationPreset, - agentOptions, }; diff --git a/api/models/schema/presetSchema.js b/api/models/schema/presetSchema.js index e1c92ab9c0..918e5c4069 100644 --- a/api/models/schema/presetSchema.js +++ b/api/models/schema/presetSchema.js @@ -23,8 +23,6 @@ const presetSchema = mongoose.Schema( order: { type: Number, }, - // google only - examples: [{ type: mongoose.Schema.Types.Mixed }], 
...conversationPreset, agentOptions: { type: mongoose.Schema.Types.Mixed, diff --git a/api/models/schema/roleSchema.js b/api/models/schema/roleSchema.js index 36e9d3f7b6..5e6729bd88 100644 --- a/api/models/schema/roleSchema.js +++ b/api/models/schema/roleSchema.js @@ -48,6 +48,18 @@ const roleSchema = new mongoose.Schema({ default: true, }, }, + [PermissionTypes.TEMPORARY_CHAT]: { + [Permissions.USE]: { + type: Boolean, + default: true, + }, + }, + [PermissionTypes.RUN_CODE]: { + [Permissions.USE]: { + type: Boolean, + default: true, + }, + }, }); const Role = mongoose.model('Role', roleSchema); diff --git a/api/models/schema/userSchema.js b/api/models/schema/userSchema.js index 297fb78e7d..b61a0faa09 100644 --- a/api/models/schema/userSchema.js +++ b/api/models/schema/userSchema.js @@ -43,6 +43,12 @@ const Session = mongoose.Schema({ }, }); +const backupCodeSchema = mongoose.Schema({ + codeHash: { type: String, required: true }, + used: { type: Boolean, default: false }, + usedAt: { type: Date, default: null }, +}); + /** @type {MongooseSchema} */ const userSchema = mongoose.Schema( { @@ -123,7 +129,12 @@ const userSchema = mongoose.Schema( }, plugins: { type: Array, - default: [], + }, + totpSecret: { + type: String, + }, + backupCodes: { + type: [backupCodeSchema], }, refreshToken: { type: [Session], diff --git a/api/models/tx.js b/api/models/tx.js index 05412430c7..b534e7edc9 100644 --- a/api/models/tx.js +++ b/api/models/tx.js @@ -79,6 +79,7 @@ const tokenValues = Object.assign( 'o1-mini': { prompt: 1.1, completion: 4.4 }, 'o1-preview': { prompt: 15, completion: 60 }, o1: { prompt: 15, completion: 60 }, + 'gpt-4.5': { prompt: 75, completion: 150 }, 'gpt-4o-mini': { prompt: 0.15, completion: 0.6 }, 'gpt-4o': { prompt: 2.5, completion: 10 }, 'gpt-4o-2024-05-13': { prompt: 5, completion: 15 }, @@ -88,6 +89,8 @@ const tokenValues = Object.assign( 'claude-3-sonnet': { prompt: 3, completion: 15 }, 'claude-3-5-sonnet': { prompt: 3, completion: 15 }, 
'claude-3.5-sonnet': { prompt: 3, completion: 15 }, + 'claude-3-7-sonnet': { prompt: 3, completion: 15 }, + 'claude-3.7-sonnet': { prompt: 3, completion: 15 }, 'claude-3-5-haiku': { prompt: 0.8, completion: 4 }, 'claude-3.5-haiku': { prompt: 0.8, completion: 4 }, 'claude-3-haiku': { prompt: 0.25, completion: 1.25 }, @@ -110,6 +113,14 @@ const tokenValues = Object.assign( 'gemini-1.5': { prompt: 2.5, completion: 10 }, 'gemini-pro-vision': { prompt: 0.5, completion: 1.5 }, gemini: { prompt: 0.5, completion: 1.5 }, + 'grok-2-vision-1212': { prompt: 2.0, completion: 10.0 }, + 'grok-2-vision-latest': { prompt: 2.0, completion: 10.0 }, + 'grok-2-vision': { prompt: 2.0, completion: 10.0 }, + 'grok-vision-beta': { prompt: 5.0, completion: 15.0 }, + 'grok-2-1212': { prompt: 2.0, completion: 10.0 }, + 'grok-2-latest': { prompt: 2.0, completion: 10.0 }, + 'grok-2': { prompt: 2.0, completion: 10.0 }, + 'grok-beta': { prompt: 5.0, completion: 15.0 }, }, bedrockValues, ); @@ -121,6 +132,8 @@ const tokenValues = Object.assign( * @type {Object.} */ const cacheTokenValues = { + 'claude-3.7-sonnet': { write: 3.75, read: 0.3 }, + 'claude-3-7-sonnet': { write: 3.75, read: 0.3 }, 'claude-3.5-sonnet': { write: 3.75, read: 0.3 }, 'claude-3-5-sonnet': { write: 3.75, read: 0.3 }, 'claude-3.5-haiku': { write: 1, read: 0.08 }, @@ -155,6 +168,8 @@ const getValueKey = (model, endpoint) => { return 'o1-mini'; } else if (modelName.includes('o1')) { return 'o1'; + } else if (modelName.includes('gpt-4.5')) { + return 'gpt-4.5'; } else if (modelName.includes('gpt-4o-2024-05-13')) { return 'gpt-4o-2024-05-13'; } else if (modelName.includes('gpt-4o-mini')) { diff --git a/api/models/tx.spec.js b/api/models/tx.spec.js index d77973a7f5..b04eacc9f3 100644 --- a/api/models/tx.spec.js +++ b/api/models/tx.spec.js @@ -50,6 +50,16 @@ describe('getValueKey', () => { expect(getValueKey('gpt-4-0125')).toBe('gpt-4-1106'); }); + it('should return "gpt-4.5" for model type of "gpt-4.5"', () => { + 
expect(getValueKey('gpt-4.5-preview')).toBe('gpt-4.5'); + expect(getValueKey('gpt-4.5-2024-08-06')).toBe('gpt-4.5'); + expect(getValueKey('gpt-4.5-2024-08-06-0718')).toBe('gpt-4.5'); + expect(getValueKey('openai/gpt-4.5')).toBe('gpt-4.5'); + expect(getValueKey('openai/gpt-4.5-2024-08-06')).toBe('gpt-4.5'); + expect(getValueKey('gpt-4.5-turbo')).toBe('gpt-4.5'); + expect(getValueKey('gpt-4.5-0125')).toBe('gpt-4.5'); + }); + it('should return "gpt-4o" for model type of "gpt-4o"', () => { expect(getValueKey('gpt-4o-2024-08-06')).toBe('gpt-4o'); expect(getValueKey('gpt-4o-2024-08-06-0718')).toBe('gpt-4o'); @@ -80,6 +90,20 @@ describe('getValueKey', () => { expect(getValueKey('chatgpt-4o-latest-0718')).toBe('gpt-4o'); }); + it('should return "claude-3-7-sonnet" for model type of "claude-3-7-sonnet-"', () => { + expect(getValueKey('claude-3-7-sonnet-20240620')).toBe('claude-3-7-sonnet'); + expect(getValueKey('anthropic/claude-3-7-sonnet')).toBe('claude-3-7-sonnet'); + expect(getValueKey('claude-3-7-sonnet-turbo')).toBe('claude-3-7-sonnet'); + expect(getValueKey('claude-3-7-sonnet-0125')).toBe('claude-3-7-sonnet'); + }); + + it('should return "claude-3.7-sonnet" for model type of "claude-3.7-sonnet-"', () => { + expect(getValueKey('claude-3.7-sonnet-20240620')).toBe('claude-3.7-sonnet'); + expect(getValueKey('anthropic/claude-3.7-sonnet')).toBe('claude-3.7-sonnet'); + expect(getValueKey('claude-3.7-sonnet-turbo')).toBe('claude-3.7-sonnet'); + expect(getValueKey('claude-3.7-sonnet-0125')).toBe('claude-3.7-sonnet'); + }); + it('should return "claude-3-5-sonnet" for model type of "claude-3-5-sonnet-"', () => { expect(getValueKey('claude-3-5-sonnet-20240620')).toBe('claude-3-5-sonnet'); expect(getValueKey('anthropic/claude-3-5-sonnet')).toBe('claude-3-5-sonnet'); @@ -458,3 +482,30 @@ describe('Google Model Tests', () => { }); }); }); + +describe('Grok Model Tests - Pricing', () => { + describe('getMultiplier', () => { + test('should return correct prompt and completion rates 
for Grok vision models', () => { + const models = ['grok-2-vision-1212', 'grok-2-vision', 'grok-2-vision-latest']; + models.forEach((model) => { + expect(getMultiplier({ model, tokenType: 'prompt' })).toBe(2.0); + expect(getMultiplier({ model, tokenType: 'completion' })).toBe(10.0); + }); + }); + + test('should return correct prompt and completion rates for Grok text models', () => { + const models = ['grok-2-1212', 'grok-2', 'grok-2-latest']; + models.forEach((model) => { + expect(getMultiplier({ model, tokenType: 'prompt' })).toBe(2.0); + expect(getMultiplier({ model, tokenType: 'completion' })).toBe(10.0); + }); + }); + + test('should return correct prompt and completion rates for Grok beta models', () => { + expect(getMultiplier({ model: 'grok-vision-beta', tokenType: 'prompt' })).toBe(5.0); + expect(getMultiplier({ model: 'grok-vision-beta', tokenType: 'completion' })).toBe(15.0); + expect(getMultiplier({ model: 'grok-beta', tokenType: 'prompt' })).toBe(5.0); + expect(getMultiplier({ model: 'grok-beta', tokenType: 'completion' })).toBe(15.0); + }); + }); +}); diff --git a/api/package.json b/api/package.json index e000dd8bd0..f29c8cf43e 100644 --- a/api/package.json +++ b/api/package.json @@ -34,18 +34,18 @@ }, "homepage": "https://librechat.ai", "dependencies": { - "@anthropic-ai/sdk": "^0.32.1", + "@anthropic-ai/sdk": "^0.37.0", "@azure/search-documents": "^12.0.0", - "@google/generative-ai": "^0.21.0", + "@google/generative-ai": "^0.23.0", "@googleapis/youtube": "^20.0.0", "@keyv/mongo": "^2.1.8", "@keyv/redis": "^2.8.1", "@langchain/community": "^0.3.14", - "@langchain/core": "^0.3.37", - "@langchain/google-genai": "^0.1.7", - "@langchain/google-vertexai": "^0.1.8", + "@langchain/core": "^0.3.40", + "@langchain/google-genai": "^0.1.9", + "@langchain/google-vertexai": "^0.2.0", "@langchain/textsplitters": "^0.1.0", - "@librechat/agents": "^2.1.2", + "@librechat/agents": "^2.1.8", "@waylaidwanderer/fetch-event-source": "^3.0.1", "axios": "1.7.8", "bcryptjs": 
"^2.4.3", @@ -57,10 +57,12 @@ "cors": "^2.8.5", "dedent": "^1.5.3", "dotenv": "^16.0.3", + "eventsource": "^3.0.2", "express": "^4.21.2", "express-mongo-sanitize": "^2.2.0", "express-rate-limit": "^7.4.1", "express-session": "^1.18.1", + "express-static-gzip": "^2.2.0", "file-type": "^18.7.0", "firebase": "^11.0.2", "googleapis": "^126.0.1", diff --git a/api/server/controllers/AuthController.js b/api/server/controllers/AuthController.js index 71551ea867..7cdfaa9aaf 100644 --- a/api/server/controllers/AuthController.js +++ b/api/server/controllers/AuthController.js @@ -61,7 +61,7 @@ const refreshController = async (req, res) => { try { const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET); - const user = await getUserById(payload.id, '-password -__v'); + const user = await getUserById(payload.id, '-password -__v -totpSecret'); if (!user) { return res.status(401).redirect('/login'); } diff --git a/api/server/controllers/TwoFactorController.js b/api/server/controllers/TwoFactorController.js new file mode 100644 index 0000000000..f145d69d92 --- /dev/null +++ b/api/server/controllers/TwoFactorController.js @@ -0,0 +1,119 @@ +const { + verifyTOTP, + verifyBackupCode, + generateTOTPSecret, + generateBackupCodes, + getTOTPSecret, +} = require('~/server/services/twoFactorService'); +const { updateUser, getUserById } = require('~/models'); +const { logger } = require('~/config'); +const { encryptV2 } = require('~/server/utils/crypto'); + +const enable2FAController = async (req, res) => { + const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, ''); + + try { + const userId = req.user.id; + const secret = generateTOTPSecret(); + const { plainCodes, codeObjects } = await generateBackupCodes(); + + const encryptedSecret = await encryptV2(secret); + const user = await updateUser(userId, { totpSecret: encryptedSecret, backupCodes: codeObjects }); + + const otpauthUrl = 
`otpauth://totp/${safeAppTitle}:${user.email}?secret=${secret}&issuer=${safeAppTitle}`; + + res.status(200).json({ + otpauthUrl, + backupCodes: plainCodes, + }); + } catch (err) { + logger.error('[enable2FAController]', err); + res.status(500).json({ message: err.message }); + } +}; + +const verify2FAController = async (req, res) => { + try { + const userId = req.user.id; + const { token, backupCode } = req.body; + const user = await getUserById(userId); + if (!user || !user.totpSecret) { + return res.status(400).json({ message: '2FA not initiated' }); + } + + // Retrieve the plain TOTP secret using getTOTPSecret. + const secret = await getTOTPSecret(user.totpSecret); + + if (token && (await verifyTOTP(secret, token))) { + return res.status(200).json(); + } else if (backupCode) { + const verified = await verifyBackupCode({ user, backupCode }); + if (verified) { + return res.status(200).json(); + } + } + + return res.status(400).json({ message: 'Invalid token.' }); + } catch (err) { + logger.error('[verify2FAController]', err); + res.status(500).json({ message: err.message }); + } +}; + +const confirm2FAController = async (req, res) => { + try { + const userId = req.user.id; + const { token } = req.body; + const user = await getUserById(userId); + + if (!user || !user.totpSecret) { + return res.status(400).json({ message: '2FA not initiated' }); + } + + // Retrieve the plain TOTP secret using getTOTPSecret. + const secret = await getTOTPSecret(user.totpSecret); + + if (await verifyTOTP(secret, token)) { + return res.status(200).json(); + } + + return res.status(400).json({ message: 'Invalid token.' 
}); + } catch (err) { + logger.error('[confirm2FAController]', err); + res.status(500).json({ message: err.message }); + } +}; + +const disable2FAController = async (req, res) => { + try { + const userId = req.user.id; + await updateUser(userId, { totpSecret: null, backupCodes: [] }); + res.status(200).json(); + } catch (err) { + logger.error('[disable2FAController]', err); + res.status(500).json({ message: err.message }); + } +}; + +const regenerateBackupCodesController = async (req, res) => { + try { + const userId = req.user.id; + const { plainCodes, codeObjects } = await generateBackupCodes(); + await updateUser(userId, { backupCodes: codeObjects }); + res.status(200).json({ + backupCodes: plainCodes, + backupCodesHash: codeObjects, + }); + } catch (err) { + logger.error('[regenerateBackupCodesController]', err); + res.status(500).json({ message: err.message }); + } +}; + +module.exports = { + enable2FAController, + verify2FAController, + confirm2FAController, + disable2FAController, + regenerateBackupCodesController, +}; diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index 9238bda941..414e90f373 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -20,7 +20,9 @@ const { Transaction } = require('~/models/Transaction'); const { logger } = require('~/config'); const getUserController = async (req, res) => { - res.status(200).send(req.user); + const userData = req.user.toObject != null ? 
req.user.toObject() : { ...req.user }; + delete userData.totpSecret; + res.status(200).send(userData); }; const getTermsStatusController = async (req, res) => { diff --git a/api/server/controllers/agents/callbacks.js b/api/server/controllers/agents/callbacks.js index f43c9db5ba..45beefe7e6 100644 --- a/api/server/controllers/agents/callbacks.js +++ b/api/server/controllers/agents/callbacks.js @@ -1,4 +1,5 @@ -const { Tools, StepTypes, imageGenTools, FileContext } = require('librechat-data-provider'); +const { nanoid } = require('nanoid'); +const { Tools, StepTypes, FileContext } = require('librechat-data-provider'); const { EnvVar, Providers, @@ -242,32 +243,6 @@ function createToolEndCallback({ req, res, artifactPromises }) { return; } - if (imageGenTools.has(output.name)) { - artifactPromises.push( - (async () => { - const fileMetadata = Object.assign(output.artifact, { - messageId: metadata.run_id, - toolCallId: output.tool_call_id, - conversationId: metadata.thread_id, - }); - if (!res.headersSent) { - return fileMetadata; - } - - if (!fileMetadata) { - return null; - } - - res.write(`event: attachment\ndata: ${JSON.stringify(fileMetadata)}\n\n`); - return fileMetadata; - })().catch((error) => { - logger.error('Error processing code output:', error); - return null; - }), - ); - return; - } - if (output.artifact.content) { /** @type {FormattedContent[]} */ const content = output.artifact.content; @@ -278,7 +253,7 @@ function createToolEndCallback({ req, res, artifactPromises }) { const { url } = part.image_url; artifactPromises.push( (async () => { - const filename = `${output.tool_call_id}-image-${new Date().getTime()}`; + const filename = `${output.name}_${output.tool_call_id}_img_${nanoid()}`; const file = await saveBase64Image(url, { req, filename, diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 156424e035..fb2ba6999e 100644 --- a/api/server/controllers/agents/client.js +++ 
b/api/server/controllers/agents/client.js @@ -17,19 +17,21 @@ const { KnownEndpoints, anthropicSchema, isAgentsEndpoint, - bedrockOutputParser, + bedrockInputSchema, removeNullishValues, } = require('librechat-data-provider'); const { formatMessage, + addCacheControl, formatAgentMessages, formatContentStrings, createContextHandlers, } = require('~/app/clients/prompts'); -const { encodeAndFormat } = require('~/server/services/Files/images/encode'); +const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { getBufferString, HumanMessage } = require('@langchain/core/messages'); +const { encodeAndFormat } = require('~/server/services/Files/images/encode'); +const { getCustomEndpointConfig } = require('~/server/services/Config'); const Tokenizer = require('~/server/services/Tokenizer'); -const { spendTokens } = require('~/models/spendTokens'); const BaseClient = require('~/app/clients/BaseClient'); const { createRun } = require('./run'); const { logger } = require('~/config'); @@ -38,10 +40,10 @@ const { logger } = require('~/config'); /** @typedef {import('@langchain/core/runnables').RunnableConfig} RunnableConfig */ const providerParsers = { - [EModelEndpoint.openAI]: openAISchema, - [EModelEndpoint.azureOpenAI]: openAISchema, - [EModelEndpoint.anthropic]: anthropicSchema, - [EModelEndpoint.bedrock]: bedrockOutputParser, + [EModelEndpoint.openAI]: openAISchema.parse, + [EModelEndpoint.azureOpenAI]: openAISchema.parse, + [EModelEndpoint.anthropic]: anthropicSchema.parse, + [EModelEndpoint.bedrock]: bedrockInputSchema.parse, }; const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deepseek]); @@ -186,7 +188,14 @@ class AgentClient extends BaseClient { : {}; if (parseOptions) { - runOptions = parseOptions(this.options.agent.model_parameters); + try { + runOptions = parseOptions(this.options.agent.model_parameters); + } catch (error) { + logger.error( + '[api/server/controllers/agents/client.js #getSaveOptions] Error 
parsing options', + error, + ); + } } return removeNullishValues( @@ -379,15 +388,34 @@ class AgentClient extends BaseClient { if (!collectedUsage || !collectedUsage.length) { return; } - const input_tokens = collectedUsage[0]?.input_tokens || 0; + const input_tokens = + (collectedUsage[0]?.input_tokens || 0) + + (Number(collectedUsage[0]?.input_token_details?.cache_creation) || 0) + + (Number(collectedUsage[0]?.input_token_details?.cache_read) || 0); let output_tokens = 0; let previousTokens = input_tokens; // Start with original input for (let i = 0; i < collectedUsage.length; i++) { const usage = collectedUsage[i]; + if (!usage) { + continue; + } + + const cache_creation = Number(usage.input_token_details?.cache_creation) || 0; + const cache_read = Number(usage.input_token_details?.cache_read) || 0; + + const txMetadata = { + context, + conversationId: this.conversationId, + user: this.user ?? this.options.req.user?.id, + endpointTokenConfig: this.options.endpointTokenConfig, + model: usage.model ?? model ?? this.model ?? this.options.agent.model_parameters.model, + }; + if (i > 0) { // Count new tokens generated (input_tokens minus previous accumulated tokens) - output_tokens += (Number(usage.input_tokens) || 0) - previousTokens; + output_tokens += + (Number(usage.input_tokens) || 0) + cache_creation + cache_read - previousTokens; } // Add this message's output tokens @@ -395,16 +423,26 @@ class AgentClient extends BaseClient { // Update previousTokens to include this message's output previousTokens += Number(usage.output_tokens) || 0; - spendTokens( - { - context, - conversationId: this.conversationId, - user: this.user ?? this.options.req.user?.id, - endpointTokenConfig: this.options.endpointTokenConfig, - model: usage.model ?? model ?? this.model ?? 
this.options.agent.model_parameters.model, - }, - { promptTokens: usage.input_tokens, completionTokens: usage.output_tokens }, - ).catch((err) => { + + if (cache_creation > 0 || cache_read > 0) { + spendStructuredTokens(txMetadata, { + promptTokens: { + input: usage.input_tokens, + write: cache_creation, + read: cache_read, + }, + completionTokens: usage.output_tokens, + }).catch((err) => { + logger.error( + '[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending structured tokens', + err, + ); + }); + } + spendTokens(txMetadata, { + promptTokens: usage.input_tokens, + completionTokens: usage.output_tokens, + }).catch((err) => { logger.error( '[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending tokens', err, @@ -589,7 +627,7 @@ class AgentClient extends BaseClient { * @param {number} [i] * @param {TMessageContentParts[]} [contentData] */ - const runAgent = async (agent, messages, i = 0, contentData = []) => { + const runAgent = async (agent, _messages, i = 0, contentData = []) => { config.configurable.model = agent.model_parameters.model; if (i > 0) { this.model = agent.model_parameters.model; @@ -622,12 +660,21 @@ class AgentClient extends BaseClient { } if (noSystemMessages === true && systemContent?.length) { - let latestMessage = messages.pop().content; + let latestMessage = _messages.pop().content; if (typeof latestMessage !== 'string') { latestMessage = latestMessage[0].text; } latestMessage = [systemContent, latestMessage].join('\n'); - messages.push(new HumanMessage(latestMessage)); + _messages.push(new HumanMessage(latestMessage)); + } + + let messages = _messages; + if ( + agent.model_parameters?.clientOptions?.defaultHeaders?.['anthropic-beta']?.includes( + 'prompt-caching', + ) + ) { + messages = addCacheControl(messages); } run = await createRun({ @@ -756,6 +803,10 @@ class AgentClient extends BaseClient { ); } } catch (err) { + logger.error( + '[api/server/controllers/agents/client.js #sendCompletion] 
Operation aborted', + err, + ); if (!abortController.signal.aborted) { logger.error( '[api/server/controllers/agents/client.js #sendCompletion] Unhandled error type', @@ -763,11 +814,6 @@ class AgentClient extends BaseClient { ); throw err; } - - logger.warn( - '[api/server/controllers/agents/client.js #sendCompletion] Operation aborted', - err, - ); } } @@ -782,14 +828,20 @@ class AgentClient extends BaseClient { throw new Error('Run not initialized'); } const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator(); - const clientOptions = {}; - const providerConfig = this.options.req.app.locals[this.options.agent.provider]; + /** @type {import('@librechat/agents').ClientOptions} */ + const clientOptions = { + maxTokens: 75, + }; + let endpointConfig = this.options.req.app.locals[this.options.agent.endpoint]; + if (!endpointConfig) { + endpointConfig = await getCustomEndpointConfig(this.options.agent.endpoint); + } if ( - providerConfig && - providerConfig.titleModel && - providerConfig.titleModel !== Constants.CURRENT_MODEL + endpointConfig && + endpointConfig.titleModel && + endpointConfig.titleModel !== Constants.CURRENT_MODEL ) { - clientOptions.model = providerConfig.titleModel; + clientOptions.model = endpointConfig.titleModel; } try { const titleResult = await this.run.generateTitle({ diff --git a/api/server/controllers/agents/run.js b/api/server/controllers/agents/run.js index 346b9e6df8..59d1a5f146 100644 --- a/api/server/controllers/agents/run.js +++ b/api/server/controllers/agents/run.js @@ -45,7 +45,10 @@ async function createRun({ /** @type {'reasoning_content' | 'reasoning'} */ let reasoningKey; - if (llmConfig.configuration?.baseURL.includes(KnownEndpoints.openrouter)) { + if ( + llmConfig.configuration?.baseURL?.includes(KnownEndpoints.openrouter) || + (agent.endpoint && agent.endpoint.toLowerCase().includes(KnownEndpoints.openrouter)) + ) { reasoningKey = 'reasoning'; } if (/o1(?!-(?:mini|preview)).*$/.test(llmConfig.model)) { 
diff --git a/api/server/controllers/auth/LoginController.js b/api/server/controllers/auth/LoginController.js index 1b543e9baf..8ab9a99ddb 100644 --- a/api/server/controllers/auth/LoginController.js +++ b/api/server/controllers/auth/LoginController.js @@ -1,3 +1,4 @@ +const { generate2FATempToken } = require('~/server/services/twoFactorService'); const { setAuthTokens } = require('~/server/services/AuthService'); const { logger } = require('~/config'); @@ -7,7 +8,12 @@ const loginController = async (req, res) => { return res.status(400).json({ message: 'Invalid credentials' }); } - const { password: _, __v, ...user } = req.user; + if (req.user.backupCodes != null && req.user.backupCodes.length > 0) { + const tempToken = generate2FATempToken(req.user._id); + return res.status(200).json({ twoFAPending: true, tempToken }); + } + + const { password: _p, totpSecret: _t, __v, ...user } = req.user; user.id = user._id.toString(); const token = await setAuthTokens(req.user._id, res); diff --git a/api/server/controllers/auth/TwoFactorAuthController.js b/api/server/controllers/auth/TwoFactorAuthController.js new file mode 100644 index 0000000000..78c5c0314e --- /dev/null +++ b/api/server/controllers/auth/TwoFactorAuthController.js @@ -0,0 +1,58 @@ +const jwt = require('jsonwebtoken'); +const { verifyTOTP, verifyBackupCode, getTOTPSecret } = require('~/server/services/twoFactorService'); +const { setAuthTokens } = require('~/server/services/AuthService'); +const { getUserById } = require('~/models/userMethods'); +const { logger } = require('~/config'); + +const verify2FA = async (req, res) => { + try { + const { tempToken, token, backupCode } = req.body; + if (!tempToken) { + return res.status(400).json({ message: 'Missing temporary token' }); + } + + let payload; + try { + payload = jwt.verify(tempToken, process.env.JWT_SECRET); + } catch (err) { + return res.status(401).json({ message: 'Invalid or expired temporary token' }); + } + + const user = await 
getUserById(payload.userId); + // Ensure that the user exists and has backup codes (i.e. 2FA enabled) + if (!user || !(user.backupCodes && user.backupCodes.length > 0)) { + return res.status(400).json({ message: '2FA is not enabled for this user' }); + } + + // Use the new getTOTPSecret function to retrieve (and decrypt if necessary) the TOTP secret. + const secret = await getTOTPSecret(user.totpSecret); + + let verified = false; + if (token && (await verifyTOTP(secret, token))) { + verified = true; + } else if (backupCode) { + verified = await verifyBackupCode({ user, backupCode }); + } + + if (!verified) { + return res.status(401).json({ message: 'Invalid 2FA code or backup code' }); + } + + // Prepare user data for response. + // If the user is a plain object (from lean queries), we create a shallow copy. + const userData = user.toObject ? user.toObject() : { ...user }; + // Remove sensitive fields. + delete userData.password; + delete userData.__v; + delete userData.totpSecret; + userData.id = user._id.toString(); + + const authToken = await setAuthTokens(user._id, res); + return res.status(200).json({ token: authToken, user: userData }); + } catch (err) { + logger.error('[verify2FA]', err); + return res.status(500).json({ message: 'Something went wrong' }); + } +}; + +module.exports = { verify2FA }; diff --git a/api/server/controllers/tools.js b/api/server/controllers/tools.js index 9460e66136..1c5330af35 100644 --- a/api/server/controllers/tools.js +++ b/api/server/controllers/tools.js @@ -1,10 +1,17 @@ const { nanoid } = require('nanoid'); const { EnvVar } = require('@librechat/agents'); -const { Tools, AuthType, ToolCallTypes } = require('librechat-data-provider'); +const { + Tools, + AuthType, + Permissions, + ToolCallTypes, + PermissionTypes, +} = require('librechat-data-provider'); const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process'); const { processCodeOutput } = require('~/server/services/Files/Code/process'); -const 
{ loadAuthValues, loadTools } = require('~/app/clients/tools/util'); const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall'); +const { loadAuthValues, loadTools } = require('~/app/clients/tools/util'); +const { checkAccess } = require('~/server/middleware'); const { getMessage } = require('~/models/Message'); const { logger } = require('~/config'); @@ -12,6 +19,10 @@ const fieldsMap = { [Tools.execute_code]: [EnvVar.CODE_API_KEY], }; +const toolAccessPermType = { + [Tools.execute_code]: PermissionTypes.RUN_CODE, +}; + /** * @param {ServerRequest} req - The request object, containing information about the HTTP request. * @param {ServerResponse} res - The response object, used to send back the desired HTTP response. @@ -58,6 +69,7 @@ const verifyToolAuth = async (req, res) => { /** * @param {ServerRequest} req - The request object, containing information about the HTTP request. * @param {ServerResponse} res - The response object, used to send back the desired HTTP response. + * @param {NextFunction} next - The next middleware function to call. * @returns {Promise} A promise that resolves when the function has completed. 
*/ const callTool = async (req, res) => { @@ -83,6 +95,16 @@ const callTool = async (req, res) => { return; } logger.debug(`[${toolId}/call] User: ${req.user.id}`); + let hasAccess = true; + if (toolAccessPermType[toolId]) { + hasAccess = await checkAccess(req.user, toolAccessPermType[toolId], [Permissions.USE]); + } + if (!hasAccess) { + logger.warn( + `[${toolAccessPermType[toolId]}] Forbidden: Insufficient permissions for User ${req.user.id}: ${Permissions.USE}`, + ); + return res.status(403).json({ message: 'Forbidden: Insufficient permissions' }); + } const { loadedTools } = await loadTools({ user: req.user.id, tools: [toolId], diff --git a/api/server/index.js b/api/server/index.js index 30d36d9a9f..4a428789dd 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -22,10 +22,11 @@ const staticCache = require('./utils/staticCache'); const noIndex = require('./middleware/noIndex'); const routes = require('./routes'); -const { PORT, HOST, ALLOW_SOCIAL_LOGIN, DISABLE_COMPRESSION } = process.env ?? {}; +const { PORT, HOST, ALLOW_SOCIAL_LOGIN, DISABLE_COMPRESSION, TRUST_PROXY } = process.env ?? 
{}; const port = Number(PORT) || 3080; const host = HOST || 'localhost'; +const trusted_proxy = Number(TRUST_PROXY) || 1; /* trust first proxy by default */ const startServer = async () => { if (typeof Bun !== 'undefined') { @@ -53,7 +54,7 @@ const startServer = async () => { app.use(staticCache(app.locals.paths.dist)); app.use(staticCache(app.locals.paths.fonts)); app.use(staticCache(app.locals.paths.assets)); - app.set('trust proxy', 1); /* trust first proxy */ + app.set('trust proxy', trusted_proxy); app.use(cors()); app.use(cookieParser()); @@ -145,6 +146,18 @@ process.on('uncaughtException', (err) => { logger.error('There was an uncaught error:', err); } + if (err.message.includes('abort')) { + logger.warn('There was an uncatchable AbortController error.'); + return; + } + + if (err.message.includes('GoogleGenerativeAI')) { + logger.warn( + '\n\n`GoogleGenerativeAI` errors cannot be caught due to an upstream issue, see: https://github.com/google-gemini/generative-ai-js/issues/303', + ); + return; + } + if (err.message.includes('fetch failed')) { if (messageCount === 0) { logger.warn('Meilisearch error, search will be disabled'); diff --git a/api/server/middleware/roles/generateCheckAccess.js b/api/server/middleware/roles/generateCheckAccess.js index ffc0ddc613..0f137c3c84 100644 --- a/api/server/middleware/roles/generateCheckAccess.js +++ b/api/server/middleware/roles/generateCheckAccess.js @@ -1,4 +1,42 @@ const { getRoleByName } = require('~/models/Role'); +const { logger } = require('~/config'); + +/** + * Core function to check if a user has one or more required permissions + * + * @param {object} user - The user object + * @param {PermissionTypes} permissionType - The type of permission to check + * @param {Permissions[]} permissions - The list of specific permissions to check + * @param {Record} [bodyProps] - An optional object where keys are permissions and values are arrays of properties to check + * @param {object} [checkObject] - The object to check 
properties against + * @returns {Promise} Whether the user has the required permissions + */ +const checkAccess = async (user, permissionType, permissions, bodyProps = {}, checkObject = {}) => { + if (!user) { + return false; + } + + const role = await getRoleByName(user.role); + if (role && role[permissionType]) { + const hasAnyPermission = permissions.some((permission) => { + if (role[permissionType][permission]) { + return true; + } + + if (bodyProps[permission] && checkObject) { + return bodyProps[permission].some((prop) => + Object.prototype.hasOwnProperty.call(checkObject, prop), + ); + } + + return false; + }); + + return hasAnyPermission; + } + + return false; +}; /** * Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties. @@ -6,42 +44,35 @@ const { getRoleByName } = require('~/models/Role'); * @param {PermissionTypes} permissionType - The type of permission to check. * @param {Permissions[]} permissions - The list of specific permissions to check. * @param {Record} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check. - * @returns {Function} Express middleware function. + * @returns {(req: ServerRequest, res: ServerResponse, next: NextFunction) => Promise} Express middleware function. 
*/ const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => { return async (req, res, next) => { try { - const { user } = req; - if (!user) { - return res.status(401).json({ message: 'Authorization required' }); - } - - const role = await getRoleByName(user.role); - if (role && role[permissionType]) { - const hasAnyPermission = permissions.some((permission) => { - if (role[permissionType][permission]) { - return true; - } - - if (bodyProps[permission] && req.body) { - return bodyProps[permission].some((prop) => - Object.prototype.hasOwnProperty.call(req.body, prop), - ); - } - - return false; - }); - - if (hasAnyPermission) { - return next(); - } + const hasAccess = await checkAccess( + req.user, + permissionType, + permissions, + bodyProps, + req.body, + ); + + if (hasAccess) { + return next(); } + logger.warn( + `[${permissionType}] Forbidden: Insufficient permissions for User ${req.user.id}: ${permissions.join(', ')}`, + ); return res.status(403).json({ message: 'Forbidden: Insufficient permissions' }); } catch (error) { + logger.error(error); return res.status(500).json({ message: `Server error: ${error.message}` }); } }; }; -module.exports = generateCheckAccess; +module.exports = { + checkAccess, + generateCheckAccess, +}; diff --git a/api/server/middleware/roles/index.js b/api/server/middleware/roles/index.js index 999c36481e..a9fc5b2a08 100644 --- a/api/server/middleware/roles/index.js +++ b/api/server/middleware/roles/index.js @@ -1,7 +1,8 @@ const checkAdmin = require('./checkAdmin'); -const generateCheckAccess = require('./generateCheckAccess'); +const { checkAccess, generateCheckAccess } = require('./generateCheckAccess'); module.exports = { checkAdmin, + checkAccess, generateCheckAccess, }; diff --git a/api/server/routes/auth.js b/api/server/routes/auth.js index 3e86ffd868..03046d903f 100644 --- a/api/server/routes/auth.js +++ b/api/server/routes/auth.js @@ -7,6 +7,13 @@ const { } = require('~/server/controllers/AuthController'); 
const { loginController } = require('~/server/controllers/auth/LoginController'); const { logoutController } = require('~/server/controllers/auth/LogoutController'); +const { verify2FA } = require('~/server/controllers/auth/TwoFactorAuthController'); +const { + enable2FAController, + verify2FAController, + disable2FAController, + regenerateBackupCodesController, confirm2FAController, +} = require('~/server/controllers/TwoFactorController'); const { checkBan, loginLimiter, @@ -50,4 +57,11 @@ router.post( ); router.post('/resetPassword', checkBan, validatePasswordReset, resetPasswordController); +router.get('/2fa/enable', requireJwtAuth, enable2FAController); +router.post('/2fa/verify', requireJwtAuth, verify2FAController); +router.post('/2fa/verify-temp', checkBan, verify2FA); +router.post('/2fa/confirm', requireJwtAuth, confirm2FAController); +router.post('/2fa/disable', requireJwtAuth, disable2FAController); +router.post('/2fa/backup/regenerate', requireJwtAuth, regenerateBackupCodesController); + module.exports = router; diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index 1726ef3460..8e3fc8ae86 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -101,6 +101,7 @@ const initializeAgentOptions = async ({ }); const provider = agent.provider; + agent.endpoint = provider; let getOptions = providerConfigMap[provider]; if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) { agent.provider = provider.toLowerCase(); @@ -112,9 +113,7 @@ const initializeAgentOptions = async ({ } getOptions = initCustom; agent.provider = Providers.OPENAI; - agent.endpoint = provider.toLowerCase(); } - const model_parameters = Object.assign( {}, agent.model_parameters ?? 
{ model: agent.model }, diff --git a/api/server/services/Endpoints/agents/title.js b/api/server/services/Endpoints/agents/title.js index 56fd28668d..f25746582e 100644 --- a/api/server/services/Endpoints/agents/title.js +++ b/api/server/services/Endpoints/agents/title.js @@ -20,10 +20,19 @@ const addTitle = async (req, { text, response, client }) => { const titleCache = getLogStores(CacheKeys.GEN_TITLE); const key = `${req.user.id}-${response.conversationId}`; + const responseText = + response?.content && Array.isArray(response?.content) + ? response.content.reduce((acc, block) => { + if (block?.type === 'text') { + return acc + block.text; + } + return acc; + }, '') + : (response?.content ?? response?.text ?? ''); const title = await client.titleConvo({ text, - responseText: response?.text ?? '', + responseText, conversationId: response.conversationId, }); await titleCache.set(key, title, 120000); diff --git a/api/server/services/Endpoints/anthropic/build.js b/api/server/services/Endpoints/anthropic/build.js index 028da36407..2deab4b975 100644 --- a/api/server/services/Endpoints/anthropic/build.js +++ b/api/server/services/Endpoints/anthropic/build.js @@ -1,4 +1,4 @@ -const { removeNullishValues } = require('librechat-data-provider'); +const { removeNullishValues, anthropicSettings } = require('librechat-data-provider'); const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); const buildOptions = (endpoint, parsedBody) => { @@ -6,8 +6,10 @@ const buildOptions = (endpoint, parsedBody) => { modelLabel, promptPrefix, maxContextTokens, - resendFiles = true, - promptCache = true, + resendFiles = anthropicSettings.resendFiles.default, + promptCache = anthropicSettings.promptCache.default, + thinking = anthropicSettings.thinking.default, + thinkingBudget = anthropicSettings.thinkingBudget.default, iconURL, greeting, spec, @@ -21,6 +23,8 @@ const buildOptions = (endpoint, parsedBody) => { promptPrefix, resendFiles, promptCache, + thinking, + 
thinkingBudget, iconURL, greeting, spec, diff --git a/api/server/services/Endpoints/anthropic/helpers.js b/api/server/services/Endpoints/anthropic/helpers.js new file mode 100644 index 0000000000..04e4efc61c --- /dev/null +++ b/api/server/services/Endpoints/anthropic/helpers.js @@ -0,0 +1,111 @@ +const { EModelEndpoint, anthropicSettings } = require('librechat-data-provider'); +const { matchModelName } = require('~/utils'); +const { logger } = require('~/config'); + +/** + * @param {string} modelName + * @returns {boolean} + */ +function checkPromptCacheSupport(modelName) { + const modelMatch = matchModelName(modelName, EModelEndpoint.anthropic); + if ( + modelMatch.includes('claude-3-5-sonnet-latest') || + modelMatch.includes('claude-3.5-sonnet-latest') + ) { + return false; + } + + if ( + modelMatch === 'claude-3-7-sonnet' || + modelMatch === 'claude-3-5-sonnet' || + modelMatch === 'claude-3-5-haiku' || + modelMatch === 'claude-3-haiku' || + modelMatch === 'claude-3-opus' || + modelMatch === 'claude-3.7-sonnet' || + modelMatch === 'claude-3.5-sonnet' || + modelMatch === 'claude-3.5-haiku' + ) { + return true; + } + + return false; +} + +/** + * Gets the appropriate headers for Claude models with cache control + * @param {string} model The model name + * @param {boolean} supportsCacheControl Whether the model supports cache control + * @returns {AnthropicClientOptions['extendedOptions']['defaultHeaders']|undefined} The headers object or undefined if not applicable + */ +function getClaudeHeaders(model, supportsCacheControl) { + if (!supportsCacheControl) { + return undefined; + } + + if (/claude-3[-.]5-sonnet/.test(model)) { + return { + 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15,prompt-caching-2024-07-31', + }; + } else if (/claude-3[-.]7/.test(model)) { + return { + 'anthropic-beta': + 'token-efficient-tools-2025-02-19,output-128k-2025-02-19,prompt-caching-2024-07-31', + }; + } else { + return { + 'anthropic-beta': 'prompt-caching-2024-07-31', + }; + } 
+} + +/** + * Configures reasoning-related options for Claude models + * @param {AnthropicClientOptions & { max_tokens?: number }} anthropicInput The request options object + * @param {Object} extendedOptions Additional client configuration options + * @param {boolean} extendedOptions.thinking Whether thinking is enabled in client config + * @param {number|null} extendedOptions.thinkingBudget The token budget for thinking + * @returns {Object} Updated request options + */ +function configureReasoning(anthropicInput, extendedOptions = {}) { + const updatedOptions = { ...anthropicInput }; + const currentMaxTokens = updatedOptions.max_tokens ?? updatedOptions.maxTokens; + if ( + extendedOptions.thinking && + updatedOptions?.model && + /claude-3[-.]7/.test(updatedOptions.model) + ) { + updatedOptions.thinking = { + type: 'enabled', + }; + } + + if (updatedOptions.thinking != null && extendedOptions.thinkingBudget != null) { + updatedOptions.thinking = { + ...updatedOptions.thinking, + budget_tokens: extendedOptions.thinkingBudget, + }; + } + + if ( + updatedOptions.thinking != null && + (currentMaxTokens == null || updatedOptions.thinking.budget_tokens > currentMaxTokens) + ) { + const maxTokens = anthropicSettings.maxOutputTokens.reset(updatedOptions.model); + updatedOptions.max_tokens = currentMaxTokens ?? maxTokens; + + logger.warn( + updatedOptions.max_tokens === maxTokens + ? '[AnthropicClient] max_tokens is not defined while thinking is enabled. Setting max_tokens to model default.' + : `[AnthropicClient] thinking budget_tokens (${updatedOptions.thinking.budget_tokens}) exceeds max_tokens (${updatedOptions.max_tokens}). 
Adjusting budget_tokens.`, + ); + + updatedOptions.thinking.budget_tokens = Math.min( + updatedOptions.thinking.budget_tokens, + Math.floor(updatedOptions.max_tokens * 0.9), + ); + } + + return updatedOptions; +} + +module.exports = { checkPromptCacheSupport, getClaudeHeaders, configureReasoning }; diff --git a/api/server/services/Endpoints/anthropic/initialize.js b/api/server/services/Endpoints/anthropic/initialize.js index ffd61441be..6c89eff463 100644 --- a/api/server/services/Endpoints/anthropic/initialize.js +++ b/api/server/services/Endpoints/anthropic/initialize.js @@ -27,6 +27,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio if (anthropicConfig) { clientOptions.streamRate = anthropicConfig.streamRate; + clientOptions.titleModel = anthropicConfig.titleModel; } /** @type {undefined | TBaseEndpoint} */ diff --git a/api/server/services/Endpoints/anthropic/llm.js b/api/server/services/Endpoints/anthropic/llm.js index 301d42712a..9f20b8e61d 100644 --- a/api/server/services/Endpoints/anthropic/llm.js +++ b/api/server/services/Endpoints/anthropic/llm.js @@ -1,5 +1,6 @@ const { HttpsProxyAgent } = require('https-proxy-agent'); const { anthropicSettings, removeNullishValues } = require('librechat-data-provider'); +const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = require('./helpers'); /** * Generates configuration options for creating an Anthropic language model (LLM) instance. @@ -20,6 +21,14 @@ const { anthropicSettings, removeNullishValues } = require('librechat-data-provi * @returns {Object} Configuration options for creating an Anthropic LLM instance, with null and undefined values removed. */ function getLLMConfig(apiKey, options = {}) { + const systemOptions = { + thinking: options.modelOptions.thinking ?? anthropicSettings.thinking.default, + promptCache: options.modelOptions.promptCache ?? anthropicSettings.promptCache.default, + thinkingBudget: options.modelOptions.thinkingBudget ?? 
anthropicSettings.thinkingBudget.default, + }; + for (let key in systemOptions) { + delete options.modelOptions[key]; + } const defaultOptions = { model: anthropicSettings.model.default, maxOutputTokens: anthropicSettings.maxOutputTokens.default, @@ -29,19 +38,34 @@ function getLLMConfig(apiKey, options = {}) { const mergedOptions = Object.assign(defaultOptions, options.modelOptions); /** @type {AnthropicClientOptions} */ - const requestOptions = { + let requestOptions = { apiKey, model: mergedOptions.model, stream: mergedOptions.stream, temperature: mergedOptions.temperature, - topP: mergedOptions.topP, - topK: mergedOptions.topK, stopSequences: mergedOptions.stop, maxTokens: mergedOptions.maxOutputTokens || anthropicSettings.maxOutputTokens.reset(mergedOptions.model), clientOptions: {}, }; + requestOptions = configureReasoning(requestOptions, systemOptions); + + if (!/claude-3[-.]7/.test(mergedOptions.model)) { + requestOptions.topP = mergedOptions.topP; + requestOptions.topK = mergedOptions.topK; + } else if (requestOptions.thinking == null) { + requestOptions.topP = mergedOptions.topP; + requestOptions.topK = mergedOptions.topK; + } + + const supportsCacheControl = + systemOptions.promptCache === true && checkPromptCacheSupport(requestOptions.model); + const headers = getClaudeHeaders(requestOptions.model, supportsCacheControl); + if (headers) { + requestOptions.clientOptions.defaultHeaders = headers; + } + if (options.proxy) { requestOptions.clientOptions.httpAgent = new HttpsProxyAgent(options.proxy); } diff --git a/api/server/services/Endpoints/anthropic/llm.spec.js b/api/server/services/Endpoints/anthropic/llm.spec.js new file mode 100644 index 0000000000..9c453efb92 --- /dev/null +++ b/api/server/services/Endpoints/anthropic/llm.spec.js @@ -0,0 +1,153 @@ +const { anthropicSettings } = require('librechat-data-provider'); +const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm'); + +jest.mock('https-proxy-agent', () => ({ + 
HttpsProxyAgent: jest.fn().mockImplementation((proxy) => ({ proxy })), +})); + +describe('getLLMConfig', () => { + it('should create a basic configuration with default values', () => { + const result = getLLMConfig('test-api-key', { modelOptions: {} }); + + expect(result.llmConfig).toHaveProperty('apiKey', 'test-api-key'); + expect(result.llmConfig).toHaveProperty('model', anthropicSettings.model.default); + expect(result.llmConfig).toHaveProperty('stream', true); + expect(result.llmConfig).toHaveProperty('maxTokens'); + }); + + it('should include proxy settings when provided', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: {}, + proxy: 'http://proxy:8080', + }); + + expect(result.llmConfig.clientOptions).toHaveProperty('httpAgent'); + expect(result.llmConfig.clientOptions.httpAgent).toHaveProperty('proxy', 'http://proxy:8080'); + }); + + it('should include reverse proxy URL when provided', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: {}, + reverseProxyUrl: 'http://reverse-proxy', + }); + + expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'http://reverse-proxy'); + }); + + it('should include topK and topP for non-Claude-3.7 models', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-opus', + topK: 10, + topP: 0.9, + }, + }); + + expect(result.llmConfig).toHaveProperty('topK', 10); + expect(result.llmConfig).toHaveProperty('topP', 0.9); + }); + + it('should include topK and topP for Claude-3.5 models', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-5-sonnet', + topK: 10, + topP: 0.9, + }, + }); + + expect(result.llmConfig).toHaveProperty('topK', 10); + expect(result.llmConfig).toHaveProperty('topP', 0.9); + }); + + it('should NOT include topK and topP for Claude-3-7 models (hyphen notation)', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-7-sonnet', + topK: 10, + 
topP: 0.9, + }, + }); + + expect(result.llmConfig).not.toHaveProperty('topK'); + expect(result.llmConfig).not.toHaveProperty('topP'); + }); + + it('should NOT include topK and topP for Claude-3.7 models (decimal notation)', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3.7-sonnet', + topK: 10, + topP: 0.9, + }, + }); + + expect(result.llmConfig).not.toHaveProperty('topK'); + expect(result.llmConfig).not.toHaveProperty('topP'); + }); + + it('should handle custom maxOutputTokens', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-opus', + maxOutputTokens: 2048, + }, + }); + + expect(result.llmConfig).toHaveProperty('maxTokens', 2048); + }); + + it('should handle promptCache setting', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-5-sonnet', + promptCache: true, + }, + }); + + // We're not checking specific header values since that depends on the actual helper function + // Just verifying that the promptCache setting is processed + expect(result.llmConfig).toBeDefined(); + }); + + it('should include topK and topP for Claude-3.7 models when thinking is not enabled', () => { + // Test with thinking explicitly set to null/undefined + const result = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-7-sonnet', + topK: 10, + topP: 0.9, + thinking: false, + }, + }); + + expect(result.llmConfig).toHaveProperty('topK', 10); + expect(result.llmConfig).toHaveProperty('topP', 0.9); + + // Test with thinking explicitly set to false + const result2 = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-7-sonnet', + topK: 10, + topP: 0.9, + thinking: false, + }, + }); + + expect(result2.llmConfig).toHaveProperty('topK', 10); + expect(result2.llmConfig).toHaveProperty('topP', 0.9); + + // Test with decimal notation as well + const result3 = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3.7-sonnet', + 
topK: 10, + topP: 0.9, + thinking: false, + }, + }); + + expect(result3.llmConfig).toHaveProperty('topK', 10); + expect(result3.llmConfig).toHaveProperty('topP', 0.9); + }); +}); diff --git a/api/server/services/Endpoints/bedrock/build.js b/api/server/services/Endpoints/bedrock/build.js index d6fb0636a9..f5228160fc 100644 --- a/api/server/services/Endpoints/bedrock/build.js +++ b/api/server/services/Endpoints/bedrock/build.js @@ -1,6 +1,5 @@ -const { removeNullishValues, bedrockInputParser } = require('librechat-data-provider'); +const { removeNullishValues } = require('librechat-data-provider'); const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); -const { logger } = require('~/config'); const buildOptions = (endpoint, parsedBody) => { const { @@ -15,12 +14,6 @@ const buildOptions = (endpoint, parsedBody) => { artifacts, ...model_parameters } = parsedBody; - let parsedParams = model_parameters; - try { - parsedParams = bedrockInputParser.parse(model_parameters); - } catch (error) { - logger.warn('Failed to parse bedrock input', error); - } const endpointOption = removeNullishValues({ endpoint, name, @@ -31,7 +24,7 @@ const buildOptions = (endpoint, parsedBody) => { spec, promptPrefix, maxContextTokens, - model_parameters: parsedParams, + model_parameters, }); if (typeof artifacts === 'string') { diff --git a/api/server/services/Endpoints/bedrock/options.js b/api/server/services/Endpoints/bedrock/options.js index 11b33a5357..6740ae882e 100644 --- a/api/server/services/Endpoints/bedrock/options.js +++ b/api/server/services/Endpoints/bedrock/options.js @@ -1,14 +1,16 @@ const { HttpsProxyAgent } = require('https-proxy-agent'); const { - EModelEndpoint, - Constants, AuthType, + Constants, + EModelEndpoint, + bedrockInputParser, + bedrockOutputParser, removeNullishValues, } = require('librechat-data-provider'); const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); const { sleep } = require('~/server/utils'); -const 
getOptions = async ({ req, endpointOption }) => { +const getOptions = async ({ req, overrideModel, endpointOption }) => { const { BEDROCK_AWS_SECRET_ACCESS_KEY, BEDROCK_AWS_ACCESS_KEY_ID, @@ -62,39 +64,44 @@ const getOptions = async ({ req, endpointOption }) => { /** @type {BedrockClientOptions} */ const requestOptions = { - model: endpointOption.model, + model: overrideModel ?? endpointOption.model, region: BEDROCK_AWS_DEFAULT_REGION, - streaming: true, - streamUsage: true, - callbacks: [ - { - handleLLMNewToken: async () => { - if (!streamRate) { - return; - } - await sleep(streamRate); - }, - }, - ], }; - if (credentials) { - requestOptions.credentials = credentials; - } - - if (BEDROCK_REVERSE_PROXY) { - requestOptions.endpointHost = BEDROCK_REVERSE_PROXY; - } - const configOptions = {}; if (PROXY) { /** NOTE: NOT SUPPORTED BY BEDROCK */ configOptions.httpAgent = new HttpsProxyAgent(PROXY); } + const llmConfig = bedrockOutputParser( + bedrockInputParser.parse( + removeNullishValues(Object.assign(requestOptions, endpointOption.model_parameters)), + ), + ); + + if (credentials) { + llmConfig.credentials = credentials; + } + + if (BEDROCK_REVERSE_PROXY) { + llmConfig.endpointHost = BEDROCK_REVERSE_PROXY; + } + + llmConfig.callbacks = [ + { + handleLLMNewToken: async () => { + if (!streamRate) { + return; + } + await sleep(streamRate); + }, + }, + ]; + return { /** @type {BedrockClientOptions} */ - llmConfig: removeNullishValues(Object.assign(requestOptions, endpointOption.model_parameters)), + llmConfig, configOptions, }; }; diff --git a/api/server/services/Endpoints/custom/initialize.js b/api/server/services/Endpoints/custom/initialize.js index fe2beba582..e81b8fca3b 100644 --- a/api/server/services/Endpoints/custom/initialize.js +++ b/api/server/services/Endpoints/custom/initialize.js @@ -141,7 +141,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid }, clientOptions, ); - const options = getLLMConfig(apiKey, clientOptions); + 
const options = getLLMConfig(apiKey, clientOptions, endpoint); if (!customOptions.streamRate) { return options; } diff --git a/api/server/services/Endpoints/google/initialize.js b/api/server/services/Endpoints/google/initialize.js index c157dd8b28..b7419a8a87 100644 --- a/api/server/services/Endpoints/google/initialize.js +++ b/api/server/services/Endpoints/google/initialize.js @@ -5,12 +5,7 @@ const { isEnabled } = require('~/server/utils'); const { GoogleClient } = require('~/app'); const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => { - const { - GOOGLE_KEY, - GOOGLE_REVERSE_PROXY, - GOOGLE_AUTH_HEADER, - PROXY, - } = process.env; + const { GOOGLE_KEY, GOOGLE_REVERSE_PROXY, GOOGLE_AUTH_HEADER, PROXY } = process.env; const isUserProvided = GOOGLE_KEY === 'user_provided'; const { key: expiresAt } = req.body; @@ -43,6 +38,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio if (googleConfig) { clientOptions.streamRate = googleConfig.streamRate; + clientOptions.titleModel = googleConfig.titleModel; } if (allConfig) { diff --git a/api/server/services/Endpoints/openAI/initialize.js b/api/server/services/Endpoints/openAI/initialize.js index 0eb0d566b9..7549a76ce5 100644 --- a/api/server/services/Endpoints/openAI/initialize.js +++ b/api/server/services/Endpoints/openAI/initialize.js @@ -113,6 +113,7 @@ const initializeClient = async ({ if (!isAzureOpenAI && openAIConfig) { clientOptions.streamRate = openAIConfig.streamRate; + clientOptions.titleModel = openAIConfig.titleModel; } /** @type {undefined | TBaseEndpoint} */ diff --git a/api/server/services/Endpoints/openAI/llm.js b/api/server/services/Endpoints/openAI/llm.js index 05b08b284b..025fbd2816 100644 --- a/api/server/services/Endpoints/openAI/llm.js +++ b/api/server/services/Endpoints/openAI/llm.js @@ -23,13 +23,13 @@ const { isEnabled } = require('~/server/utils'); * @param {boolean} [options.streaming] - Whether to use streaming mode. 
* @param {Object} [options.addParams] - Additional parameters to add to the model options. * @param {string[]} [options.dropParams] - Parameters to remove from the model options. + * @param {string|null} [endpoint=null] - The endpoint name * @returns {Object} Configuration options for creating an LLM instance. */ -function getLLMConfig(apiKey, options = {}) { +function getLLMConfig(apiKey, options = {}, endpoint = null) { const { modelOptions = {}, reverseProxyUrl, - useOpenRouter, defaultQuery, headers, proxy, @@ -56,9 +56,14 @@ function getLLMConfig(apiKey, options = {}) { }); } + let useOpenRouter; /** @type {OpenAIClientOptions['configuration']} */ const configOptions = {}; - if (useOpenRouter || reverseProxyUrl.includes(KnownEndpoints.openrouter)) { + if ( + (reverseProxyUrl && reverseProxyUrl.includes(KnownEndpoints.openrouter)) || + (endpoint && endpoint.toLowerCase().includes(KnownEndpoints.openrouter)) + ) { + useOpenRouter = true; llmConfig.include_reasoning = true; configOptions.baseURL = reverseProxyUrl; configOptions.defaultHeaders = Object.assign( @@ -118,6 +123,13 @@ function getLLMConfig(apiKey, options = {}) { llmConfig.organization = process.env.OPENAI_ORGANIZATION; } + if (useOpenRouter && llmConfig.reasoning_effort != null) { + llmConfig.reasoning = { + effort: llmConfig.reasoning_effort, + }; + delete llmConfig.reasoning_effort; + } + return { /** @type {OpenAIClientOptions} */ llmConfig, diff --git a/api/server/services/Files/Code/crud.js b/api/server/services/Files/Code/crud.js index 076a4d9f13..7b26093d62 100644 --- a/api/server/services/Files/Code/crud.js +++ b/api/server/services/Files/Code/crud.js @@ -2,6 +2,7 @@ const axios = require('axios'); const FormData = require('form-data'); const { getCodeBaseURL } = require('@librechat/agents'); +const { logAxiosError } = require('~/utils'); const MAX_FILE_SIZE = 150 * 1024 * 1024; @@ -78,7 +79,11 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = '' return 
`${fileIdentifier}?entity_id=${entity_id}`; } catch (error) { - throw new Error(`Error uploading file: ${error.message}`); + logAxiosError({ + message: `Error uploading code environment file: ${error.message}`, + error, + }); + throw new Error(`Error uploading code environment file: ${error.message}`); } } diff --git a/api/server/services/Files/Code/process.js b/api/server/services/Files/Code/process.js index 2a941a4647..c92e628589 100644 --- a/api/server/services/Files/Code/process.js +++ b/api/server/services/Files/Code/process.js @@ -12,6 +12,7 @@ const { const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { convertImage } = require('~/server/services/Files/images/convert'); const { createFile, getFiles, updateFile } = require('~/models/File'); +const { logAxiosError } = require('~/utils'); const { logger } = require('~/config'); /** @@ -85,7 +86,10 @@ const processCodeOutput = async ({ /** Note: `messageId` & `toolCallId` are not part of file DB schema; message object records associated file ID */ return Object.assign(file, { messageId, toolCallId }); } catch (error) { - logger.error('Error downloading file:', error); + logAxiosError({ + message: 'Error downloading code environment file', + error, + }); } }; @@ -135,7 +139,10 @@ async function getSessionInfo(fileIdentifier, apiKey) { return response.data.find((file) => file.name.startsWith(path))?.lastModified; } catch (error) { - logger.error(`Error fetching session info: ${error.message}`, error); + logAxiosError({ + message: `Error fetching session info: ${error.message}`, + error, + }); return null; } } @@ -202,7 +209,7 @@ const primeFiles = async (options, apiKey) => { const { handleFileUpload: uploadCodeEnvFile } = getStrategyFunctions( FileSources.execute_code, ); - const stream = await getDownloadStream(file.filepath); + const stream = await getDownloadStream(options.req, file.filepath); const fileIdentifier = await uploadCodeEnvFile({ req: options.req, stream, diff 
--git a/api/server/services/Files/Firebase/crud.js b/api/server/services/Files/Firebase/crud.js index 76a6c1d8d4..8319f908ef 100644 --- a/api/server/services/Files/Firebase/crud.js +++ b/api/server/services/Files/Firebase/crud.js @@ -224,10 +224,11 @@ async function uploadFileToFirebase({ req, file, file_id }) { /** * Retrieves a readable stream for a file from Firebase storage. * + * @param {ServerRequest} _req * @param {string} filepath - The filepath. * @returns {Promise} A readable stream of the file. */ -async function getFirebaseFileStream(filepath) { +async function getFirebaseFileStream(_req, filepath) { try { const storage = getFirebaseStorage(); if (!storage) { diff --git a/api/server/services/Files/Local/crud.js b/api/server/services/Files/Local/crud.js index e004eab79e..c2bb75c125 100644 --- a/api/server/services/Files/Local/crud.js +++ b/api/server/services/Files/Local/crud.js @@ -175,6 +175,17 @@ const isValidPath = (req, base, subfolder, filepath) => { return normalizedFilepath.startsWith(normalizedBase); }; +/** + * @param {string} filepath + */ +const unlinkFile = async (filepath) => { + try { + await fs.promises.unlink(filepath); + } catch (error) { + logger.error('Error deleting file:', error); + } +}; + /** * Deletes a file from the filesystem. This function takes a file object, constructs the full path, and * verifies the path's validity before deleting the file. If the path is invalid, an error is thrown. @@ -217,7 +228,7 @@ const deleteLocalFile = async (req, file) => { throw new Error(`Invalid file path: ${file.filepath}`); } - await fs.promises.unlink(filepath); + await unlinkFile(filepath); return; } @@ -233,7 +244,7 @@ const deleteLocalFile = async (req, file) => { throw new Error('Invalid file path'); } - await fs.promises.unlink(filepath); + await unlinkFile(filepath); }; /** @@ -275,11 +286,31 @@ async function uploadLocalFile({ req, file, file_id }) { /** * Retrieves a readable stream for a file from local storage. 
* + * @param {ServerRequest} req - The request object from Express * @param {string} filepath - The filepath. * @returns {ReadableStream} A readable stream of the file. */ -function getLocalFileStream(filepath) { +function getLocalFileStream(req, filepath) { try { + if (filepath.includes('/uploads/')) { + const basePath = filepath.split('/uploads/')[1]; + + if (!basePath) { + logger.warn(`Invalid base path: ${filepath}`); + throw new Error(`Invalid file path: ${filepath}`); + } + + const fullPath = path.join(req.app.locals.paths.uploads, basePath); + const uploadsDir = req.app.locals.paths.uploads; + + const rel = path.relative(uploadsDir, fullPath); + if (rel.startsWith('..') || path.isAbsolute(rel) || rel.includes(`..${path.sep}`)) { + logger.warn(`Invalid relative file path: ${filepath}`); + throw new Error(`Invalid file path: ${filepath}`); + } + + return fs.createReadStream(fullPath); + } return fs.createReadStream(filepath); } catch (error) { logger.error('Error getting local file stream:', error); diff --git a/api/server/services/Files/VectorDB/crud.js b/api/server/services/Files/VectorDB/crud.js index d290eea4b1..37a1e81487 100644 --- a/api/server/services/Files/VectorDB/crud.js +++ b/api/server/services/Files/VectorDB/crud.js @@ -37,7 +37,14 @@ const deleteVectors = async (req, file) => { error, message: 'Error deleting vectors', }); - throw new Error(error.message || 'An error occurred during file deletion.'); + if ( + error.response && + error.response.status !== 404 && + (error.response.status < 200 || error.response.status >= 300) + ) { + logger.warn('Error deleting vectors, file will not be deleted'); + throw new Error(error.message || 'An error occurred during file deletion.'); + } } }; diff --git a/api/server/services/Files/process.js b/api/server/services/Files/process.js index a5d9c8c1e0..8744eb409b 100644 --- a/api/server/services/Files/process.js +++ b/api/server/services/Files/process.js @@ -347,8 +347,8 @@ const uploadImageBuffer = async ({ 
req, context, metadata = {}, resize = true }) req.app.locals.imageOutputType }`; } - - const filepath = await saveBuffer({ userId: req.user.id, fileName: filename, buffer }); + const fileName = `${file_id}-${filename}`; + const filepath = await saveBuffer({ userId: req.user.id, fileName, buffer }); return await createFile( { user: req.user.id, @@ -801,8 +801,7 @@ async function saveBase64Image( { req, file_id: _file_id, filename: _filename, endpoint, context, resolution = 'high' }, ) { const file_id = _file_id ?? v4(); - - let filename = _filename; + let filename = `${file_id}-${_filename}`; const { buffer: inputBuffer, type } = base64ToBuffer(url); if (!path.extname(_filename)) { const extension = mime.getExtension(type); diff --git a/api/server/services/ModelService.js b/api/server/services/ModelService.js index 9630f0bd87..48ae85b663 100644 --- a/api/server/services/ModelService.js +++ b/api/server/services/ModelService.js @@ -129,9 +129,6 @@ const fetchOpenAIModels = async (opts, _models = []) => { // .split('/deployments')[0] // .concat(`/models?api-version=${azure.azureOpenAIApiVersion}`); // apiKey = azureOpenAIApiKey; - } else if (process.env.OPENROUTER_API_KEY) { - reverseProxyUrl = 'https://openrouter.ai/api/v1'; - apiKey = process.env.OPENROUTER_API_KEY; - } if (reverseProxyUrl) { @@ -218,7 +215,7 @@ const getOpenAIModels = async (opts) => { return models; } - if (userProvidedOpenAI && !process.env.OPENROUTER_API_KEY) { + if (userProvidedOpenAI) { return models; } diff --git a/api/server/services/ModelService.spec.js b/api/server/services/ModelService.spec.js index a383db1e3c..1fbe347a00 100644 --- a/api/server/services/ModelService.spec.js +++ b/api/server/services/ModelService.spec.js @@ -161,22 +161,6 @@ describe('getOpenAIModels', () => { expect(models).toEqual(expect.arrayContaining(['openai-model', 'openai-model-2'])); }); - it('attempts to use OPENROUTER_API_KEY if set', async () => { - process.env.OPENROUTER_API_KEY = 'test-router-key'; - const 
expectedModels = ['model-router-1', 'model-router-2']; - - axios.get.mockResolvedValue({ - data: { - data: expectedModels.map((id) => ({ id })), - }, - }); - - const models = await getOpenAIModels({ user: 'user456' }); - - expect(models).toEqual(expect.arrayContaining(expectedModels)); - expect(axios.get).toHaveBeenCalled(); - }); - it('utilizes proxy configuration when PROXY is set', async () => { axios.get.mockResolvedValue({ data: { diff --git a/api/server/services/start/interface.js b/api/server/services/start/interface.js index 98bcb6858e..5365c4af7f 100644 --- a/api/server/services/start/interface.js +++ b/api/server/services/start/interface.js @@ -34,6 +34,8 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol multiConvo: interfaceConfig?.multiConvo ?? defaults.multiConvo, agents: interfaceConfig?.agents ?? defaults.agents, temporaryChat: interfaceConfig?.temporaryChat ?? defaults.temporaryChat, + runCode: interfaceConfig?.runCode ?? defaults.runCode, + customWelcome: interfaceConfig?.customWelcome ?? 
defaults.customWelcome, }); await updateAccessPermissions(roleName, { @@ -41,12 +43,16 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo }, [PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: loadedInterface.runCode }, }); await updateAccessPermissions(SystemRoles.ADMIN, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: loadedInterface.prompts }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo }, [PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: loadedInterface.runCode }, }); let i = 0; diff --git a/api/server/services/start/interface.spec.js b/api/server/services/start/interface.spec.js index 0041246433..7e248d3dfe 100644 --- a/api/server/services/start/interface.spec.js +++ b/api/server/services/start/interface.spec.js @@ -14,6 +14,8 @@ describe('loadDefaultInterface', () => { bookmarks: true, multiConvo: true, agents: true, + temporaryChat: true, + runCode: true, }, }; const configDefaults = { interface: {} }; @@ -25,6 +27,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: true }, }); }); @@ -35,6 +39,8 @@ describe('loadDefaultInterface', () => { bookmarks: false, 
multiConvo: false, agents: false, + temporaryChat: false, + runCode: false, }, }; const configDefaults = { interface: {} }; @@ -46,6 +52,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { [Permissions.USE]: false }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: false }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: false }, }); }); @@ -60,6 +68,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); @@ -70,6 +80,8 @@ describe('loadDefaultInterface', () => { bookmarks: undefined, multiConvo: undefined, agents: undefined, + temporaryChat: undefined, + runCode: undefined, }, }; const configDefaults = { interface: {} }; @@ -81,6 +93,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); @@ -91,6 +105,8 @@ describe('loadDefaultInterface', () => { bookmarks: false, multiConvo: undefined, agents: true, + temporaryChat: undefined, + runCode: false, }, }; const configDefaults = { interface: {} }; @@ -102,6 +118,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: 
undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: false }, }); }); @@ -113,6 +131,8 @@ describe('loadDefaultInterface', () => { bookmarks: true, multiConvo: true, agents: true, + temporaryChat: true, + runCode: true, }, }; @@ -123,6 +143,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: true }, }); }); @@ -137,6 +159,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); @@ -151,6 +175,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); @@ -165,6 +191,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); @@ -175,6 +203,8 @@ describe('loadDefaultInterface', () => { bookmarks: false, multiConvo: true, agents: false, + temporaryChat: true, + runCode: false, }, }; const configDefaults = { interface: {} }; @@ -186,6 +216,8 @@ 
describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: false }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: false }, }); }); @@ -197,6 +229,8 @@ describe('loadDefaultInterface', () => { bookmarks: true, multiConvo: false, agents: undefined, + temporaryChat: undefined, + runCode: undefined, }, }; @@ -207,6 +241,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); }); diff --git a/api/server/services/twoFactorService.js b/api/server/services/twoFactorService.js new file mode 100644 index 0000000000..e48b2ac938 --- /dev/null +++ b/api/server/services/twoFactorService.js @@ -0,0 +1,238 @@ +const { sign } = require('jsonwebtoken'); +const { webcrypto } = require('node:crypto'); +const { hashBackupCode, decryptV2 } = require('~/server/utils/crypto'); +const { updateUser } = require('~/models/userMethods'); + +const BASE32_ALPHABET = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567'; + +/** + * Encodes a Buffer into a Base32 string using the RFC 4648 alphabet. + * + * @param {Buffer} buffer - The buffer to encode. + * @returns {string} The Base32 encoded string. 
+ */ +const encodeBase32 = (buffer) => { + let bits = 0; + let value = 0; + let output = ''; + for (const byte of buffer) { + value = (value << 8) | byte; + bits += 8; + while (bits >= 5) { + output += BASE32_ALPHABET[(value >>> (bits - 5)) & 31]; + bits -= 5; + } + } + if (bits > 0) { + output += BASE32_ALPHABET[(value << (5 - bits)) & 31]; + } + return output; +}; + +/** + * Decodes a Base32-encoded string back into a Buffer. + * + * @param {string} base32Str - The Base32-encoded string. + * @returns {Buffer} The decoded buffer. + */ +const decodeBase32 = (base32Str) => { + const cleaned = base32Str.replace(/=+$/, '').toUpperCase(); + let bits = 0; + let value = 0; + const output = []; + for (const char of cleaned) { + const idx = BASE32_ALPHABET.indexOf(char); + if (idx === -1) { + continue; + } + value = (value << 5) | idx; + bits += 5; + if (bits >= 8) { + output.push((value >>> (bits - 8)) & 0xff); + bits -= 8; + } + } + return Buffer.from(output); +}; + +/** + * Generates a temporary token for 2FA verification. + * The token is signed with the JWT_SECRET and expires in 5 minutes. + * + * @param {string} userId - The unique identifier of the user. + * @returns {string} The signed JWT token. + */ +const generate2FATempToken = (userId) => + sign({ userId, twoFAPending: true }, process.env.JWT_SECRET, { expiresIn: '5m' }); + +/** + * Generates a TOTP secret. + * Creates 10 random bytes using WebCrypto and encodes them into a Base32 string. + * + * @returns {string} A Base32-encoded secret for TOTP. + */ +const generateTOTPSecret = () => { + const randomArray = new Uint8Array(10); + webcrypto.getRandomValues(randomArray); + return encodeBase32(Buffer.from(randomArray)); +}; + +/** + * Generates a Time-based One-Time Password (TOTP) based on the provided secret and time. + * This implementation uses a 30-second time step and produces a 6-digit code. + * + * @param {string} secret - The Base32-encoded TOTP secret. 
+ * @param {number} [forTime=Date.now()] - The time (in milliseconds) for which to generate the TOTP. + * @returns {Promise} A promise that resolves to the 6-digit TOTP code. + */ +const generateTOTP = async (secret, forTime = Date.now()) => { + const timeStep = 30; // seconds + const counter = Math.floor(forTime / 1000 / timeStep); + const counterBuffer = new ArrayBuffer(8); + const counterView = new DataView(counterBuffer); + // Write counter into the last 4 bytes (big-endian) + counterView.setUint32(4, counter, false); + + // Decode the secret into an ArrayBuffer + const keyBuffer = decodeBase32(secret); + const keyArrayBuffer = keyBuffer.buffer.slice( + keyBuffer.byteOffset, + keyBuffer.byteOffset + keyBuffer.byteLength, + ); + + // Import the key for HMAC-SHA1 signing + const cryptoKey = await webcrypto.subtle.importKey( + 'raw', + keyArrayBuffer, + { name: 'HMAC', hash: 'SHA-1' }, + false, + ['sign'], + ); + + // Generate HMAC signature + const signatureBuffer = await webcrypto.subtle.sign('HMAC', cryptoKey, counterBuffer); + const hmac = new Uint8Array(signatureBuffer); + + // Dynamic truncation as per RFC 4226 + const offset = hmac[hmac.length - 1] & 0xf; + const slice = hmac.slice(offset, offset + 4); + const view = new DataView(slice.buffer, slice.byteOffset, slice.byteLength); + const binaryCode = view.getUint32(0, false) & 0x7fffffff; + const code = (binaryCode % 1000000).toString().padStart(6, '0'); + return code; +}; + +/** + * Verifies a provided TOTP token against the secret. + * It allows for a ±1 time-step window to account for slight clock discrepancies. + * + * @param {string} secret - The Base32-encoded TOTP secret. + * @param {string} token - The TOTP token provided by the user. + * @returns {Promise} A promise that resolves to true if the token is valid; otherwise, false. 
+ */ +const verifyTOTP = async (secret, token) => { + const timeStepMS = 30 * 1000; + const currentTime = Date.now(); + for (let offset = -1; offset <= 1; offset++) { + const expected = await generateTOTP(secret, currentTime + offset * timeStepMS); + if (expected === token) { + return true; + } + } + return false; +}; + +/** + * Generates backup codes for two-factor authentication. + * Each backup code is an 8-character hexadecimal string along with its SHA-256 hash. + * The plain codes are returned for one-time download, while the hashed objects are meant for secure storage. + * + * @param {number} [count=10] - The number of backup codes to generate. + * @returns {Promise<{ plainCodes: string[], codeObjects: Array<{ codeHash: string, used: boolean, usedAt: Date | null }> }>} + * A promise that resolves to an object containing both plain backup codes and their corresponding code objects. + */ +const generateBackupCodes = async (count = 10) => { + const plainCodes = []; + const codeObjects = []; + const encoder = new TextEncoder(); + for (let i = 0; i < count; i++) { + const randomArray = new Uint8Array(4); + webcrypto.getRandomValues(randomArray); + const code = Array.from(randomArray) + .map((b) => b.toString(16).padStart(2, '0')) + .join(''); // 8-character hex code + plainCodes.push(code); + + // Compute SHA-256 hash of the code using WebCrypto + const codeBuffer = encoder.encode(code); + const hashBuffer = await webcrypto.subtle.digest('SHA-256', codeBuffer); + const hashArray = Array.from(new Uint8Array(hashBuffer)); + const codeHash = hashArray.map((b) => b.toString(16).padStart(2, '0')).join(''); + codeObjects.push({ codeHash, used: false, usedAt: null }); + } + return { plainCodes, codeObjects }; +}; + +/** + * Verifies a backup code for a user and updates its status as used if valid. + * + * @param {Object} params - The parameters object. + * @param {TUser | undefined} [params.user] - The user object containing backup codes. 
+ * @param {string | undefined} [params.backupCode] - The backup code to verify. + * @returns {Promise} A promise that resolves to true if the backup code is valid and updated; otherwise, false. + */ +const verifyBackupCode = async ({ user, backupCode }) => { + if (!backupCode || !user || !Array.isArray(user.backupCodes)) { + return false; + } + + const hashedInput = await hashBackupCode(backupCode.trim()); + const matchingCode = user.backupCodes.find( + (codeObj) => codeObj.codeHash === hashedInput && !codeObj.used, + ); + + if (matchingCode) { + const updatedBackupCodes = user.backupCodes.map((codeObj) => + codeObj.codeHash === hashedInput && !codeObj.used + ? { ...codeObj, used: true, usedAt: new Date() } + : codeObj, + ); + + await updateUser(user._id, { backupCodes: updatedBackupCodes }); + return true; + } + + return false; +}; + +/** + * Retrieves and, if necessary, decrypts a stored TOTP secret. + * If the secret contains a colon, it is assumed to be in the format "iv:encryptedData" and will be decrypted. + * If the secret is exactly 16 characters long, it is assumed to be a legacy plain secret. + * + * @param {string|null} storedSecret - The stored TOTP secret (which may be encrypted). + * @returns {Promise} A promise that resolves to the plain TOTP secret, or null if none is provided. + */ +const getTOTPSecret = async (storedSecret) => { + if (!storedSecret) { return null; } + // Check for a colon marker (encrypted secrets are stored as "iv:encryptedData") + if (storedSecret.includes(':')) { + return await decryptV2(storedSecret); + } + // If it's exactly 16 characters, assume it's already plain (legacy secret) + if (storedSecret.length === 16) { + return storedSecret; + } + // Fallback in case it doesn't meet our criteria. 
+ return storedSecret; +}; + +module.exports = { + verifyTOTP, + generateTOTP, + getTOTPSecret, + verifyBackupCode, + generateTOTPSecret, + generateBackupCodes, + generate2FATempToken, +}; diff --git a/api/server/utils/crypto.js b/api/server/utils/crypto.js index ea71df51ad..407fad62ac 100644 --- a/api/server/utils/crypto.js +++ b/api/server/utils/crypto.js @@ -112,4 +112,25 @@ async function getRandomValues(length) { return Buffer.from(randomValues).toString('hex'); } -module.exports = { encrypt, decrypt, encryptV2, decryptV2, hashToken, getRandomValues }; +/** + * Computes SHA-256 hash for the given input using WebCrypto + * @param {string} input + * @returns {Promise} - Hex hash string + */ +const hashBackupCode = async (input) => { + const encoder = new TextEncoder(); + const data = encoder.encode(input); + const hashBuffer = await webcrypto.subtle.digest('SHA-256', data); + const hashArray = Array.from(new Uint8Array(hashBuffer)); + return hashArray.map((b) => b.toString(16).padStart(2, '0')).join(''); +}; + +module.exports = { + encrypt, + decrypt, + encryptV2, + decryptV2, + hashToken, + hashBackupCode, + getRandomValues, +}; diff --git a/api/server/utils/staticCache.js b/api/server/utils/staticCache.js index a8001c7e0a..23713ddf6f 100644 --- a/api/server/utils/staticCache.js +++ b/api/server/utils/staticCache.js @@ -1,4 +1,4 @@ -const express = require('express'); +const expressStaticGzip = require('express-static-gzip'); const oneDayInSeconds = 24 * 60 * 60; @@ -6,13 +6,13 @@ const sMaxAge = process.env.STATIC_CACHE_S_MAX_AGE || oneDayInSeconds; const maxAge = process.env.STATIC_CACHE_MAX_AGE || oneDayInSeconds * 2; const staticCache = (staticPath) => - express.static(staticPath, { - setHeaders: (res) => { - if (process.env.NODE_ENV?.toLowerCase() !== 'production') { - return; + expressStaticGzip(staticPath, { + enableBrotli: false, // disable Brotli, only using gzip + orderPreference: ['gz'], + setHeaders: (res, _path) => { + if 
(process.env.NODE_ENV?.toLowerCase() === 'production') { + res.setHeader('Cache-Control', `public, max-age=${maxAge}, s-maxage=${sMaxAge}`); } - - res.setHeader('Cache-Control', `public, max-age=${maxAge}, s-maxage=${sMaxAge}`); }, }); diff --git a/api/strategies/jwtStrategy.js b/api/strategies/jwtStrategy.js index e65b284950..ac19e92ac3 100644 --- a/api/strategies/jwtStrategy.js +++ b/api/strategies/jwtStrategy.js @@ -12,7 +12,7 @@ const jwtLogin = async () => }, async (payload, done) => { try { - const user = await getUserById(payload?.id, '-password -__v'); + const user = await getUserById(payload?.id, '-password -__v -totpSecret'); if (user) { user.id = user._id.toString(); if (!user.role) { diff --git a/api/typedefs.js b/api/typedefs.js index bd97bd93fa..cc7ae41895 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -20,6 +20,12 @@ * @memberof typedefs */ +/** + * @exports NextFunction + * @typedef {import('express').NextFunction} NextFunction + * @memberof typedefs + */ + /** * @exports AgentRun * @typedef {import('@librechat/agents').Run} AgentRun diff --git a/api/utils/axios.js b/api/utils/axios.js index 8b12a5ca99..acd23a184f 100644 --- a/api/utils/axios.js +++ b/api/utils/axios.js @@ -5,40 +5,32 @@ const { logger } = require('~/config'); * * @param {Object} options - The options object. * @param {string} options.message - The custom message to be logged. - * @param {Error} options.error - The Axios error object. + * @param {import('axios').AxiosError} options.error - The Axios error object. */ const logAxiosError = ({ message, error }) => { - const timedOutMessage = 'Cannot read properties of undefined (reading \'status\')'; - if (error.response) { - logger.error( - `${message} The request was made and the server responded with a status code that falls out of the range of 2xx: ${ - error.message ? error.message : '' - }. 
Error response data:\n`, - { - headers: error.response?.headers, - status: error.response?.status, - data: error.response?.data, - }, - ); - } else if (error.request) { - logger.error( - `${message} The request was made but no response was received: ${ - error.message ? error.message : '' - }. Error Request:\n`, - { - request: error.request, - }, - ); - } else if (error?.message?.includes(timedOutMessage)) { - logger.error( - `${message}\nThe request either timed out or was unsuccessful. Error message:\n`, - error, - ); - } else { - logger.error( - `${message}\nSomething happened in setting up the request. Error message:\n`, - error, - ); + try { + if (error.response?.status) { + const { status, headers, data } = error.response; + logger.error(`${message} The server responded with status ${status}: ${error.message}`, { + status, + headers, + data, + }); + } else if (error.request) { + const { method, url } = error.config || {}; + logger.error( + `${message} No response received for ${method ? 
method.toUpperCase() : ''} ${url || ''}: ${error.message}`, + { requestInfo: { method, url } }, + ); + } else if (error?.message?.includes('Cannot read properties of undefined (reading \'status\')')) { + logger.error( + `${message} It appears the request timed out or was unsuccessful: ${error.message}`, + ); + } else { + logger.error(`${message} An error occurred while setting up the request: ${error.message}`); + } + } catch (err) { + logger.error(`Error in logAxiosError: ${err.message}`); } }; diff --git a/api/utils/tokens.js b/api/utils/tokens.js index 0541f4f301..8edfb0a31c 100644 --- a/api/utils/tokens.js +++ b/api/utils/tokens.js @@ -13,6 +13,7 @@ const openAIModels = { 'gpt-4-32k-0613': 32758, // -10 from max 'gpt-4-1106': 127500, // -500 from max 'gpt-4-0125': 127500, // -500 from max + 'gpt-4.5': 127500, // -500 from max 'gpt-4o': 127500, // -500 from max 'gpt-4o-mini': 127500, // -500 from max 'gpt-4o-2024-05-13': 127500, // -500 from max @@ -74,6 +75,7 @@ const anthropicModels = { 'claude-instant': 100000, 'claude-2': 100000, 'claude-2.1': 200000, + 'claude-3': 200000, 'claude-3-haiku': 200000, 'claude-3-sonnet': 200000, 'claude-3-opus': 200000, @@ -81,6 +83,8 @@ const anthropicModels = { 'claude-3-5-haiku': 200000, 'claude-3-5-sonnet': 200000, 'claude-3.5-sonnet': 200000, + 'claude-3-7-sonnet': 200000, + 'claude-3.7-sonnet': 200000, 'claude-3-5-sonnet-latest': 200000, 'claude-3.5-sonnet-latest': 200000, }; @@ -183,7 +187,18 @@ const bedrockModels = { ...amazonModels, }; -const aggregateModels = { ...openAIModels, ...googleModels, ...bedrockModels }; +const xAIModels = { + 'grok-beta': 131072, + 'grok-vision-beta': 8192, + 'grok-2': 131072, + 'grok-2-latest': 131072, + 'grok-2-1212': 131072, + 'grok-2-vision': 32768, + 'grok-2-vision-latest': 32768, + 'grok-2-vision-1212': 32768, +}; + +const aggregateModels = { ...openAIModels, ...googleModels, ...bedrockModels, ...xAIModels }; const maxTokensMap = { [EModelEndpoint.azureOpenAI]: openAIModels, diff 
--git a/api/utils/tokens.spec.js b/api/utils/tokens.spec.js index eb1fd85495..d4dbb30498 100644 --- a/api/utils/tokens.spec.js +++ b/api/utils/tokens.spec.js @@ -103,6 +103,16 @@ describe('getModelMaxTokens', () => { ); }); + test('should return correct tokens for gpt-4.5 matches', () => { + expect(getModelMaxTokens('gpt-4.5')).toBe(maxTokensMap[EModelEndpoint.openAI]['gpt-4.5']); + expect(getModelMaxTokens('gpt-4.5-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4.5'], + ); + expect(getModelMaxTokens('openai/gpt-4.5-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4.5'], + ); + }); + test('should return correct tokens for Anthropic models', () => { const models = [ 'claude-2.1', @@ -116,6 +126,7 @@ describe('getModelMaxTokens', () => { 'claude-3-sonnet', 'claude-3-opus', 'claude-3-5-sonnet', + 'claude-3-7-sonnet', ]; const maxTokens = { @@ -483,3 +494,68 @@ describe('Meta Models Tests', () => { }); }); }); + +describe('Grok Model Tests - Tokens', () => { + describe('getModelMaxTokens', () => { + test('should return correct tokens for Grok vision models', () => { + expect(getModelMaxTokens('grok-2-vision-1212')).toBe(32768); + expect(getModelMaxTokens('grok-2-vision')).toBe(32768); + expect(getModelMaxTokens('grok-2-vision-latest')).toBe(32768); + }); + + test('should return correct tokens for Grok beta models', () => { + expect(getModelMaxTokens('grok-vision-beta')).toBe(8192); + expect(getModelMaxTokens('grok-beta')).toBe(131072); + }); + + test('should return correct tokens for Grok text models', () => { + expect(getModelMaxTokens('grok-2-1212')).toBe(131072); + expect(getModelMaxTokens('grok-2')).toBe(131072); + expect(getModelMaxTokens('grok-2-latest')).toBe(131072); + }); + + test('should handle partial matches for Grok models with prefixes', () => { + // Vision models should match before general models + expect(getModelMaxTokens('openai/grok-2-vision-1212')).toBe(32768); + expect(getModelMaxTokens('openai/grok-2-vision')).toBe(32768); 
+ expect(getModelMaxTokens('openai/grok-2-vision-latest')).toBe(32768); + // Beta models + expect(getModelMaxTokens('openai/grok-vision-beta')).toBe(8192); + expect(getModelMaxTokens('openai/grok-beta')).toBe(131072); + // Text models + expect(getModelMaxTokens('openai/grok-2-1212')).toBe(131072); + expect(getModelMaxTokens('openai/grok-2')).toBe(131072); + expect(getModelMaxTokens('openai/grok-2-latest')).toBe(131072); + }); + }); + + describe('matchModelName', () => { + test('should match exact Grok model names', () => { + // Vision models + expect(matchModelName('grok-2-vision-1212')).toBe('grok-2-vision-1212'); + expect(matchModelName('grok-2-vision')).toBe('grok-2-vision'); + expect(matchModelName('grok-2-vision-latest')).toBe('grok-2-vision-latest'); + // Beta models + expect(matchModelName('grok-vision-beta')).toBe('grok-vision-beta'); + expect(matchModelName('grok-beta')).toBe('grok-beta'); + // Text models + expect(matchModelName('grok-2-1212')).toBe('grok-2-1212'); + expect(matchModelName('grok-2')).toBe('grok-2'); + expect(matchModelName('grok-2-latest')).toBe('grok-2-latest'); + }); + + test('should match Grok model variations with prefixes', () => { + // Vision models should match before general models + expect(matchModelName('openai/grok-2-vision-1212')).toBe('grok-2-vision-1212'); + expect(matchModelName('openai/grok-2-vision')).toBe('grok-2-vision'); + expect(matchModelName('openai/grok-2-vision-latest')).toBe('grok-2-vision-latest'); + // Beta models + expect(matchModelName('openai/grok-vision-beta')).toBe('grok-vision-beta'); + expect(matchModelName('openai/grok-beta')).toBe('grok-beta'); + // Text models + expect(matchModelName('openai/grok-2-1212')).toBe('grok-2-1212'); + expect(matchModelName('openai/grok-2')).toBe('grok-2'); + expect(matchModelName('openai/grok-2-latest')).toBe('grok-2-latest'); + }); + }); +}); diff --git a/client/index.html b/client/index.html index 9bd0363fab..9e300e7365 100644 --- a/client/index.html +++ 
b/client/index.html @@ -6,6 +6,7 @@ + LibreChat @@ -53,6 +54,5 @@
- diff --git a/client/package.json b/client/package.json index 22e9b1dd03..5aa8293f52 100644 --- a/client/package.json +++ b/client/package.json @@ -44,6 +44,7 @@ "@radix-ui/react-icons": "^1.3.0", "@radix-ui/react-label": "^2.0.0", "@radix-ui/react-popover": "^1.0.7", + "@radix-ui/react-progress": "^1.1.2", "@radix-ui/react-radio-group": "^1.1.3", "@radix-ui/react-select": "^2.0.0", "@radix-ui/react-separator": "^1.0.3", @@ -65,6 +66,7 @@ "html-to-image": "^1.11.11", "i18next": "^24.2.2", "i18next-browser-languagedetector": "^8.0.3", + "input-otp": "^1.4.2", "js-cookie": "^3.0.5", "librechat-data-provider": "*", "lodash": "^4.17.21", @@ -84,7 +86,7 @@ "react-i18next": "^15.4.0", "react-lazy-load-image-component": "^1.6.0", "react-markdown": "^9.0.1", - "react-resizable-panels": "^2.1.1", + "react-resizable-panels": "^2.1.7", "react-router-dom": "^6.11.2", "react-speech-recognition": "^3.10.0", "react-textarea-autosize": "^8.4.0", @@ -140,6 +142,7 @@ "typescript": "^5.3.3", "vite": "^6.1.0", "vite-plugin-node-polyfills": "^0.17.0", + "vite-plugin-compression": "^0.5.1", "vite-plugin-pwa": "^0.21.1" } } diff --git a/client/public/assets/apple-touch-icon-180x180.png b/client/public/assets/apple-touch-icon-180x180.png index 91dde5d139..57c4637c93 100644 Binary files a/client/public/assets/apple-touch-icon-180x180.png and b/client/public/assets/apple-touch-icon-180x180.png differ diff --git a/client/public/assets/icon-192x192.png b/client/public/assets/icon-192x192.png new file mode 100644 index 0000000000..b8dfe0eae5 Binary files /dev/null and b/client/public/assets/icon-192x192.png differ diff --git a/client/public/assets/maskable-icon.png b/client/public/assets/maskable-icon.png index 18305a2446..90e48f870b 100644 Binary files a/client/public/assets/maskable-icon.png and b/client/public/assets/maskable-icon.png differ diff --git a/client/public/robots.txt b/client/public/robots.txt new file mode 100644 index 0000000000..376d535e0a --- /dev/null +++ 
b/client/public/robots.txt @@ -0,0 +1,3 @@ +User-agent: * +Disallow: /api/ +Allow: / \ No newline at end of file diff --git a/client/src/components/Auth/AuthLayout.tsx b/client/src/components/Auth/AuthLayout.tsx index a7e890517a..d90f0d3dfe 100644 --- a/client/src/components/Auth/AuthLayout.tsx +++ b/client/src/components/Auth/AuthLayout.tsx @@ -85,7 +85,8 @@ function AuthLayout({ )} {children} - {(pathname.includes('login') || pathname.includes('register')) && ( + {!pathname.includes('2fa') && + (pathname.includes('login') || pathname.includes('register')) && ( )} diff --git a/client/src/components/Auth/LoginForm.tsx b/client/src/components/Auth/LoginForm.tsx index 9f5bb46039..2cd62d08b9 100644 --- a/client/src/components/Auth/LoginForm.tsx +++ b/client/src/components/Auth/LoginForm.tsx @@ -166,9 +166,7 @@ const LoginForm: React.FC = ({ onSubmit, startupConfig, error, type="submit" className=" w-full rounded-2xl bg-green-600 px-4 py-3 text-sm font-medium text-white - transition-colors hover:bg-green-700 focus:outline-none focus:ring-2 - focus:ring-green-500 focus:ring-offset-2 disabled:opacity-50 - disabled:hover:bg-green-600 dark:bg-green-600 dark:hover:bg-green-700 + transition-colors hover:bg-green-700 dark:bg-green-600 dark:hover:bg-green-700 " > {localize('com_auth_continue')} diff --git a/client/src/components/Auth/TwoFactorScreen.tsx b/client/src/components/Auth/TwoFactorScreen.tsx new file mode 100644 index 0000000000..04f89d7cea --- /dev/null +++ b/client/src/components/Auth/TwoFactorScreen.tsx @@ -0,0 +1,176 @@ +import React, { useState, useCallback } from 'react'; +import { useSearchParams } from 'react-router-dom'; +import { useForm, Controller } from 'react-hook-form'; +import { REGEXP_ONLY_DIGITS, REGEXP_ONLY_DIGITS_AND_CHARS } from 'input-otp'; +import { InputOTP, InputOTPGroup, InputOTPSeparator, InputOTPSlot, Label } from '~/components'; +import { useVerifyTwoFactorTempMutation } from '~/data-provider'; +import { useToastContext } from 
'~/Providers'; +import { useLocalize } from '~/hooks'; + +interface VerifyPayload { + tempToken: string; + token?: string; + backupCode?: string; +} + +type TwoFactorFormInputs = { + token?: string; + backupCode?: string; +}; + +const TwoFactorScreen: React.FC = React.memo(() => { + const [searchParams] = useSearchParams(); + const tempTokenRaw = searchParams.get('tempToken'); + const tempToken = tempTokenRaw !== null && tempTokenRaw !== '' ? tempTokenRaw : ''; + + const { + control, + handleSubmit, + formState: { errors }, + } = useForm(); + const localize = useLocalize(); + const { showToast } = useToastContext(); + const [useBackup, setUseBackup] = useState(false); + const [isLoading, setIsLoading] = useState(false); + const { mutate: verifyTempMutate } = useVerifyTwoFactorTempMutation({ + onSuccess: (result) => { + if (result.token != null && result.token !== '') { + window.location.href = '/'; + } + }, + onMutate: () => { + setIsLoading(true); + }, + onError: (error: unknown) => { + setIsLoading(false); + const err = error as { response?: { data?: { message?: unknown } } }; + const errorMsg = + typeof err.response?.data?.message === 'string' + ? err.response.data.message + : 'Error verifying 2FA'; + showToast({ message: errorMsg, status: 'error' }); + }, + }); + + const onSubmit = useCallback( + (data: TwoFactorFormInputs) => { + const payload: VerifyPayload = { tempToken }; + if (useBackup && data.backupCode != null && data.backupCode !== '') { + payload.backupCode = data.backupCode; + } else if (data.token != null && data.token !== '') { + payload.token = data.token; + } + verifyTempMutate(payload); + }, + [tempToken, useBackup, verifyTempMutate], + ); + + const toggleBackupOn = useCallback(() => { + setUseBackup(true); + }, []); + + const toggleBackupOff = useCallback(() => { + setUseBackup(false); + }, []); + + return ( +
+
+ + {!useBackup && ( +
+ ( + + + + + + + + + + + + + + )} + /> + {errors.token && {errors.token.message}} +
+ )} + {useBackup && ( +
+ ( + + + + + + + + + + + + + )} + /> + {errors.backupCode && ( + {errors.backupCode.message} + )} +
+ )} +
+ +
+
+ {!useBackup ? ( + + ) : ( + + )} +
+
+
+ ); +}); + +export default TwoFactorScreen; diff --git a/client/src/components/Auth/index.ts b/client/src/components/Auth/index.ts index cd1ac1adce..afde148015 100644 --- a/client/src/components/Auth/index.ts +++ b/client/src/components/Auth/index.ts @@ -4,3 +4,4 @@ export { default as ResetPassword } from './ResetPassword'; export { default as VerifyEmail } from './VerifyEmail'; export { default as ApiErrorWatcher } from './ApiErrorWatcher'; export { default as RequestPasswordReset } from './RequestPasswordReset'; +export { default as TwoFactorScreen } from './TwoFactorScreen'; diff --git a/client/src/components/Chat/Input/AudioRecorder.tsx b/client/src/components/Chat/Input/AudioRecorder.tsx index 96e29ec502..512c9c9d9c 100644 --- a/client/src/components/Chat/Input/AudioRecorder.tsx +++ b/client/src/components/Chat/Input/AudioRecorder.tsx @@ -81,17 +81,25 @@ export default function AudioRecorder({ return ( - {renderIcon()} - + render={ + + } + /> ); } diff --git a/client/src/components/Chat/Input/Files/FileUpload.tsx b/client/src/components/Chat/Input/Files/FileUpload.tsx index 506f50c01d..723fa32e86 100644 --- a/client/src/components/Chat/Input/Files/FileUpload.tsx +++ b/client/src/components/Chat/Input/Files/FileUpload.tsx @@ -55,7 +55,7 @@ const FileUpload: React.FC = ({ let statusText: string; if (!status) { - statusText = text ?? localize('com_endpoint_import'); + statusText = text ?? localize('com_ui_import'); } else if (status === 'success') { statusText = successText ?? 
localize('com_ui_upload_success'); } else { @@ -72,12 +72,12 @@ const FileUpload: React.FC = ({ )} > - {statusText} + {statusText} diff --git a/client/src/components/Chat/Input/HeaderOptions.tsx b/client/src/components/Chat/Input/HeaderOptions.tsx index 0bd3326b53..5313f43b8d 100644 --- a/client/src/components/Chat/Input/HeaderOptions.tsx +++ b/client/src/components/Chat/Input/HeaderOptions.tsx @@ -1,8 +1,13 @@ import { useRecoilState } from 'recoil'; import { Settings2 } from 'lucide-react'; -import { Root, Anchor } from '@radix-ui/react-popover'; import { useState, useEffect, useMemo } from 'react'; -import { tConvoUpdateSchema, EModelEndpoint, isParamEndpoint } from 'librechat-data-provider'; +import { Root, Anchor } from '@radix-ui/react-popover'; +import { + EModelEndpoint, + isParamEndpoint, + isAgentsEndpoint, + tConvoUpdateSchema, +} from 'librechat-data-provider'; import type { TPreset, TInterfaceConfig } from 'librechat-data-provider'; import { EndpointSettings, SaveAsPresetDialog, AlternativeSettings } from '~/components/Endpoints'; import { PluginStoreDialog, TooltipAnchor } from '~/components'; @@ -42,7 +47,6 @@ export default function HeaderOptions({ if (endpoint && noSettings[endpoint]) { setShowPopover(false); } - // eslint-disable-next-line react-hooks/exhaustive-deps }, [endpoint, noSettings]); const saveAsPreset = () => { @@ -67,7 +71,7 @@ export default function HeaderOptions({
- {interfaceConfig?.modelSelect === true && ( + {interfaceConfig?.modelSelect === true && !isAgentsEndpoint(endpoint) && (
{name}
- {description ? description : localize('com_nav_welcome_message')} + {description || + (typeof startupConfig?.interface?.customWelcome === 'string' + ? startupConfig?.interface?.customWelcome + : localize('com_nav_welcome_message'))}
{/*
-
By Daniel Avila
+
By Daniel Avila
*/}
) : ( diff --git a/client/src/components/Chat/Messages/Content/ContentParts.tsx b/client/src/components/Chat/Messages/Content/ContentParts.tsx index b997060c61..ddf08976eb 100644 --- a/client/src/components/Chat/Messages/Content/ContentParts.tsx +++ b/client/src/components/Chat/Messages/Content/ContentParts.tsx @@ -109,7 +109,9 @@ const ContentParts = memo( return val; }) } - label={isSubmitting ? localize('com_ui_thinking') : localize('com_ui_thoughts')} + label={ + isSubmitting && isLast ? localize('com_ui_thinking') : localize('com_ui_thoughts') + } />
)} diff --git a/client/src/components/Chat/Messages/Content/Image.tsx b/client/src/components/Chat/Messages/Content/Image.tsx index 28910d0315..41ee52453f 100644 --- a/client/src/components/Chat/Messages/Content/Image.tsx +++ b/client/src/components/Chat/Messages/Content/Image.tsx @@ -29,6 +29,7 @@ const Image = ({ height, width, placeholderDimensions, + className, }: { imagePath: string; altText: string; @@ -38,6 +39,7 @@ const Image = ({ height?: string; width?: string; }; + className?: string; }) => { const [isLoaded, setIsLoaded] = useState(false); const containerRef = useRef(null); @@ -57,7 +59,12 @@ const Image = ({ return (
-
+
diff --git a/client/src/components/Nav/SettingsTabs/Account/BackupCodesItem.tsx b/client/src/components/Nav/SettingsTabs/Account/BackupCodesItem.tsx new file mode 100644 index 0000000000..a034e2773a --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Account/BackupCodesItem.tsx @@ -0,0 +1,194 @@ +import React, { useState } from 'react'; +import { RefreshCcw, ShieldX } from 'lucide-react'; +import { motion, AnimatePresence } from 'framer-motion'; +import { TBackupCode, TRegenerateBackupCodesResponse, type TUser } from 'librechat-data-provider'; +import { + OGDialog, + OGDialogContent, + OGDialogTitle, + OGDialogTrigger, + Button, + Label, + Spinner, + TooltipAnchor, +} from '~/components'; +import { useRegenerateBackupCodesMutation } from '~/data-provider'; +import { useAuthContext, useLocalize } from '~/hooks'; +import { useToastContext } from '~/Providers'; +import { useSetRecoilState } from 'recoil'; +import store from '~/store'; + +const BackupCodesItem: React.FC = () => { + const localize = useLocalize(); + const { user } = useAuthContext(); + const { showToast } = useToastContext(); + const setUser = useSetRecoilState(store.user); + const [isDialogOpen, setDialogOpen] = useState(false); + + const { mutate: regenerateBackupCodes, isLoading } = useRegenerateBackupCodesMutation(); + + const fetchBackupCodes = (auto: boolean = false) => { + regenerateBackupCodes(undefined, { + onSuccess: (data: TRegenerateBackupCodesResponse) => { + const newBackupCodes: TBackupCode[] = data.backupCodesHash.map((codeHash) => ({ + codeHash, + used: false, + usedAt: null, + })); + + setUser((prev) => ({ ...prev, backupCodes: newBackupCodes }) as TUser); + showToast({ + message: localize('com_ui_backup_codes_regenerated'), + status: 'success', + }); + + // Trigger file download only when user explicitly clicks the button. 
+ if (!auto && newBackupCodes.length) { + const codesString = data.backupCodes.join('\n'); + const blob = new Blob([codesString], { type: 'text/plain;charset=utf-8' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = 'backup-codes.txt'; + a.click(); + URL.revokeObjectURL(url); + } + }, + onError: () => + showToast({ + message: localize('com_ui_backup_codes_regenerate_error'), + status: 'error', + }), + }); + }; + + const handleRegenerate = () => { + fetchBackupCodes(false); + }; + + return ( + +
+
+ +
+ + + +
+ + + + {localize('com_ui_backup_codes')} + + + + + {Array.isArray(user?.backupCodes) && user?.backupCodes.length > 0 ? ( + <> +
+ {user?.backupCodes.map((code, index) => { + const isUsed = code.used; + const description = `Backup code number ${index + 1}, ${ + isUsed + ? `used on ${code.usedAt ? new Date(code.usedAt).toLocaleDateString() : 'an unknown date'}` + : 'not used yet' + }`; + + return ( + { + const announcement = new CustomEvent('announce', { + detail: { message: description }, + }); + document.dispatchEvent(announcement); + }} + className={`flex flex-col rounded-xl border p-4 backdrop-blur-sm transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-primary ${ + isUsed + ? 'border-red-200 bg-red-50/80 dark:border-red-800 dark:bg-red-900/20' + : 'border-green-200 bg-green-50/80 dark:border-green-800 dark:bg-green-900/20' + } `} + > + + + ); + })} +
+
+ +
+ + ) : ( +
+ +

{localize('com_ui_no_backup_codes')}

+ +
+ )} +
+
+
+
+ ); +}; + +export default React.memo(BackupCodesItem); diff --git a/client/src/components/Nav/SettingsTabs/Account/DeleteAccount.tsx b/client/src/components/Nav/SettingsTabs/Account/DeleteAccount.tsx index 1c1e207d58..b00e7498bc 100644 --- a/client/src/components/Nav/SettingsTabs/Account/DeleteAccount.tsx +++ b/client/src/components/Nav/SettingsTabs/Account/DeleteAccount.tsx @@ -57,7 +57,7 @@ const DeleteAccount = ({ disabled = false }: { title?: string; disabled?: boolea
- + {localize('com_nav_delete_account_confirm')} diff --git a/client/src/components/Nav/SettingsTabs/Account/DisableTwoFactorToggle.tsx b/client/src/components/Nav/SettingsTabs/Account/DisableTwoFactorToggle.tsx new file mode 100644 index 0000000000..5dfad770d3 --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Account/DisableTwoFactorToggle.tsx @@ -0,0 +1,36 @@ +import React from 'react'; +import { motion } from 'framer-motion'; +import { LockIcon, UnlockIcon } from 'lucide-react'; +import { Label, Button } from '~/components'; +import { useLocalize } from '~/hooks'; + +interface DisableTwoFactorToggleProps { + enabled: boolean; + onChange: () => void; + disabled?: boolean; +} + +export const DisableTwoFactorToggle: React.FC = ({ + enabled, + onChange, + disabled, +}) => { + const localize = useLocalize(); + + return ( +
+
+ +
+
+ +
+
+ ); +}; diff --git a/client/src/components/Nav/SettingsTabs/Account/TwoFactorAuthentication.tsx b/client/src/components/Nav/SettingsTabs/Account/TwoFactorAuthentication.tsx new file mode 100644 index 0000000000..bd46e80249 --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Account/TwoFactorAuthentication.tsx @@ -0,0 +1,298 @@ +import React, { useCallback, useState } from 'react'; +import { useSetRecoilState } from 'recoil'; +import { SmartphoneIcon } from 'lucide-react'; +import { motion, AnimatePresence } from 'framer-motion'; +import type { TUser, TVerify2FARequest } from 'librechat-data-provider'; +import { OGDialog, OGDialogContent, OGDialogHeader, OGDialogTitle, Progress } from '~/components'; +import { SetupPhase, QRPhase, VerifyPhase, BackupPhase, DisablePhase } from './TwoFactorPhases'; +import { DisableTwoFactorToggle } from './DisableTwoFactorToggle'; +import { useAuthContext, useLocalize } from '~/hooks'; +import { useToastContext } from '~/Providers'; +import store from '~/store'; +import { + useConfirmTwoFactorMutation, + useDisableTwoFactorMutation, + useEnableTwoFactorMutation, + useVerifyTwoFactorMutation, +} from '~/data-provider'; + +export type Phase = 'setup' | 'qr' | 'verify' | 'backup' | 'disable'; + +const phaseVariants = { + initial: { opacity: 0, scale: 0.95 }, + animate: { opacity: 1, scale: 1, transition: { duration: 0.3, ease: 'easeOut' } }, + exit: { opacity: 0, scale: 0.95, transition: { duration: 0.3, ease: 'easeIn' } }, +}; + +const TwoFactorAuthentication: React.FC = () => { + const localize = useLocalize(); + const { user } = useAuthContext(); + const setUser = useSetRecoilState(store.user); + const { showToast } = useToastContext(); + + const [secret, setSecret] = useState(''); + const [otpauthUrl, setOtpauthUrl] = useState(''); + const [downloaded, setDownloaded] = useState(false); + const [disableToken, setDisableToken] = useState(''); + const [backupCodes, setBackupCodes] = useState([]); + const [isDialogOpen, 
setDialogOpen] = useState(false); + const [verificationToken, setVerificationToken] = useState(''); + const [phase, setPhase] = useState(Array.isArray(user?.backupCodes) && user?.backupCodes.length > 0 ? 'disable' : 'setup'); + + const { mutate: confirm2FAMutate } = useConfirmTwoFactorMutation(); + const { mutate: enable2FAMutate, isLoading: isGenerating } = useEnableTwoFactorMutation(); + const { mutate: verify2FAMutate, isLoading: isVerifying } = useVerifyTwoFactorMutation(); + const { mutate: disable2FAMutate, isLoading: isDisabling } = useDisableTwoFactorMutation(); + + const steps = ['Setup', 'Scan QR', 'Verify', 'Backup']; + const phasesLabel: Record = { + setup: 'Setup', + qr: 'Scan QR', + verify: 'Verify', + backup: 'Backup', + disable: '', + }; + + const currentStep = steps.indexOf(phasesLabel[phase]); + + const resetState = useCallback(() => { + if (Array.isArray(user?.backupCodes) && user?.backupCodes.length > 0 && otpauthUrl) { + disable2FAMutate(undefined, { + onError: () => + showToast({ message: localize('com_ui_2fa_disable_error'), status: 'error' }), + }); + } + + setOtpauthUrl(''); + setSecret(''); + setBackupCodes([]); + setVerificationToken(''); + setDisableToken(''); + setPhase(Array.isArray(user?.backupCodes) && user?.backupCodes.length > 0 ? 
'disable' : 'setup'); + setDownloaded(false); + }, [user, otpauthUrl, disable2FAMutate, localize, showToast]); + + const handleGenerateQRCode = useCallback(() => { + enable2FAMutate(undefined, { + onSuccess: ({ otpauthUrl, backupCodes }) => { + setOtpauthUrl(otpauthUrl); + setSecret(otpauthUrl.split('secret=')[1].split('&')[0]); + setBackupCodes(backupCodes); + setPhase('qr'); + }, + onError: () => showToast({ message: localize('com_ui_2fa_generate_error'), status: 'error' }), + }); + }, [enable2FAMutate, localize, showToast]); + + const handleVerify = useCallback(() => { + if (!verificationToken) { + return; + } + + verify2FAMutate( + { token: verificationToken }, + { + onSuccess: () => { + showToast({ message: localize('com_ui_2fa_verified') }); + confirm2FAMutate( + { token: verificationToken }, + { + onSuccess: () => setPhase('backup'), + onError: () => + showToast({ message: localize('com_ui_2fa_invalid'), status: 'error' }), + }, + ); + }, + onError: () => showToast({ message: localize('com_ui_2fa_invalid'), status: 'error' }), + }, + ); + }, [verificationToken, verify2FAMutate, confirm2FAMutate, localize, showToast]); + + const handleDownload = useCallback(() => { + if (!backupCodes.length) { + return; + } + const blob = new Blob([backupCodes.join('\n')], { type: 'text/plain;charset=utf-8' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = 'backup-codes.txt'; + a.click(); + URL.revokeObjectURL(url); + setDownloaded(true); + }, [backupCodes]); + + const handleConfirm = useCallback(() => { + setDialogOpen(false); + setPhase('disable'); + showToast({ message: localize('com_ui_2fa_enabled') }); + setUser( + (prev) => + ({ + ...prev, + backupCodes: backupCodes.map((code) => ({ + code, + codeHash: code, + used: false, + usedAt: null, + })), + }) as TUser, + ); + }, [setUser, localize, showToast, backupCodes]); + + const handleDisableVerify = useCallback( + (token: string, useBackup: boolean) => { + 
// Validate: if not using backup, ensure token has at least 6 digits; + // if using backup, ensure backup code has at least 8 characters. + if (!useBackup && token.trim().length < 6) { + return; + } + + if (useBackup && token.trim().length < 8) { + return; + } + + const payload: TVerify2FARequest = {}; + if (useBackup) { + payload.backupCode = token.trim(); + } else { + payload.token = token.trim(); + } + + verify2FAMutate(payload, { + onSuccess: () => { + disable2FAMutate(undefined, { + onSuccess: () => { + showToast({ message: localize('com_ui_2fa_disabled') }); + setDialogOpen(false); + setUser( + (prev) => + ({ + ...prev, + totpSecret: '', + backupCodes: [], + }) as TUser, + ); + setPhase('setup'); + setOtpauthUrl(''); + }, + onError: () => + showToast({ message: localize('com_ui_2fa_disable_error'), status: 'error' }), + }); + }, + onError: () => showToast({ message: localize('com_ui_2fa_invalid'), status: 'error' }), + }); + }, + [disableToken, verify2FAMutate, disable2FAMutate, showToast, localize, setUser], + ); + + return ( + { + setDialogOpen(open); + if (!open) { + resetState(); + } + }} + > + 0} + onChange={() => setDialogOpen(true)} + disabled={isVerifying || isDisabling || isGenerating} + /> + + + + + + + + {Array.isArray(user?.backupCodes) && user?.backupCodes.length > 0 ? localize('com_ui_2fa_disable') : localize('com_ui_2fa_setup')} + + {Array.isArray(user?.backupCodes) && user?.backupCodes.length > 0 && phase !== 'disable' && ( +
+ +
+ {steps.map((step, index) => ( + = index ? 'var(--text-primary)' : 'var(--text-tertiary)', + }} + className="font-medium" + > + {step} + + ))} +
+
+ )} +
+ + + {phase === 'setup' && ( + setPhase('qr')} + onError={(error) => showToast({ message: error.message, status: 'error' })} + /> + )} + + {phase === 'qr' && ( + setPhase('verify')} + onError={(error) => showToast({ message: error.message, status: 'error' })} + /> + )} + + {phase === 'verify' && ( + showToast({ message: error.message, status: 'error' })} + /> + )} + + {phase === 'backup' && ( + showToast({ message: error.message, status: 'error' })} + /> + )} + + {phase === 'disable' && ( + showToast({ message: error.message, status: 'error' })} + /> + )} + +
+
+
+
+ ); +}; + +export default React.memo(TwoFactorAuthentication); diff --git a/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/BackupPhase.tsx b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/BackupPhase.tsx new file mode 100644 index 0000000000..67e05a1423 --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/BackupPhase.tsx @@ -0,0 +1,60 @@ +import React from 'react'; +import { motion } from 'framer-motion'; +import { Download } from 'lucide-react'; +import { Button, Label } from '~/components'; +import { useLocalize } from '~/hooks'; + +const fadeAnimation = { + initial: { opacity: 0, y: 20 }, + animate: { opacity: 1, y: 0 }, + exit: { opacity: 0, y: -20 }, + transition: { duration: 0.2 }, +}; + +interface BackupPhaseProps { + onNext: () => void; + onError: (error: Error) => void; + backupCodes: string[]; + onDownload: () => void; + downloaded: boolean; +} + +export const BackupPhase: React.FC = ({ + backupCodes, + onDownload, + downloaded, + onNext, +}) => { + const localize = useLocalize(); + + return ( + + +
+ {backupCodes.map((code, index) => ( + +
+ #{index + 1} + {code} +
+
+ ))} +
+
+ + +
+
+ ); +}; diff --git a/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/DisablePhase.tsx b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/DisablePhase.tsx new file mode 100644 index 0000000000..27422d26c3 --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/DisablePhase.tsx @@ -0,0 +1,88 @@ +import React, { useState } from 'react'; +import { motion } from 'framer-motion'; +import { REGEXP_ONLY_DIGITS, REGEXP_ONLY_DIGITS_AND_CHARS } from 'input-otp'; +import { + Button, + InputOTP, + InputOTPGroup, + InputOTPSlot, + InputOTPSeparator, + Spinner, +} from '~/components'; +import { useLocalize } from '~/hooks'; + +const fadeAnimation = { + initial: { opacity: 0, y: 20 }, + animate: { opacity: 1, y: 0 }, + exit: { opacity: 0, y: -20 }, + transition: { duration: 0.2 }, +}; + +interface DisablePhaseProps { + onSuccess?: () => void; + onError?: (error: Error) => void; + onDisable: (token: string, useBackup: boolean) => void; + isDisabling: boolean; +} + +export const DisablePhase: React.FC = ({ onDisable, isDisabling }) => { + const localize = useLocalize(); + const [token, setToken] = useState(''); + const [useBackup, setUseBackup] = useState(false); + + return ( + +
+ + {useBackup ? ( + + + + + + + + + + + ) : ( + <> + + + + + + + + + + + + + )} + +
+ + +
+ ); +}; diff --git a/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/QRPhase.tsx b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/QRPhase.tsx new file mode 100644 index 0000000000..7a0eccae3f --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/QRPhase.tsx @@ -0,0 +1,66 @@ +import React, { useState } from 'react'; +import { motion } from 'framer-motion'; +import { QRCodeSVG } from 'qrcode.react'; +import { Copy, Check } from 'lucide-react'; +import { Input, Button, Label } from '~/components'; +import { useLocalize } from '~/hooks'; +import { cn } from '~/utils'; + +const fadeAnimation = { + initial: { opacity: 0, y: 20 }, + animate: { opacity: 1, y: 0 }, + exit: { opacity: 0, y: -20 }, + transition: { duration: 0.2 }, +}; + +interface QRPhaseProps { + secret: string; + otpauthUrl: string; + onNext: () => void; + onSuccess?: () => void; + onError?: (error: Error) => void; +} + +export const QRPhase: React.FC = ({ secret, otpauthUrl, onNext }) => { + const localize = useLocalize(); + const [isCopying, setIsCopying] = useState(false); + + const handleCopy = async () => { + await navigator.clipboard.writeText(secret); + setIsCopying(true); + setTimeout(() => setIsCopying(false), 2000); + }; + + return ( + +
+ + + +
+ +
+ + +
+
+
+ +
+ ); +}; diff --git a/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/SetupPhase.tsx b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/SetupPhase.tsx new file mode 100644 index 0000000000..4fd2d1181d --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/SetupPhase.tsx @@ -0,0 +1,42 @@ +import React from 'react'; +import { QrCode } from 'lucide-react'; +import { motion } from 'framer-motion'; +import { Button, Spinner } from '~/components'; +import { useLocalize } from '~/hooks'; + +const fadeAnimation = { + initial: { opacity: 0, y: 20 }, + animate: { opacity: 1, y: 0 }, + exit: { opacity: 0, y: -20 }, + transition: { duration: 0.2 }, +}; + +interface SetupPhaseProps { + onNext: () => void; + onError: (error: Error) => void; + isGenerating: boolean; + onGenerate: () => void; +} + +export const SetupPhase: React.FC = ({ isGenerating, onGenerate, onNext }) => { + const localize = useLocalize(); + + return ( + +
+

+ {localize('com_ui_2fa_account_security')} +

+ +
+
+ ); +}; diff --git a/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/VerifyPhase.tsx b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/VerifyPhase.tsx new file mode 100644 index 0000000000..e872dfa0d2 --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/VerifyPhase.tsx @@ -0,0 +1,58 @@ +import React from 'react'; +import { motion } from 'framer-motion'; +import { Button, InputOTP, InputOTPGroup, InputOTPSeparator, InputOTPSlot } from '~/components'; +import { REGEXP_ONLY_DIGITS } from 'input-otp'; +import { useLocalize } from '~/hooks'; + +const fadeAnimation = { + initial: { opacity: 0, y: 20 }, + animate: { opacity: 1, y: 0 }, + exit: { opacity: 0, y: -20 }, + transition: { duration: 0.2 }, +}; + +interface VerifyPhaseProps { + token: string; + onTokenChange: (value: string) => void; + isVerifying: boolean; + onNext: () => void; + onError: (error: Error) => void; +} + +export const VerifyPhase: React.FC = ({ + token, + onTokenChange, + isVerifying, + onNext, +}) => { + const localize = useLocalize(); + + return ( + +
+ + + {Array.from({ length: 3 }).map((_, i) => ( + + ))} + + + + {Array.from({ length: 3 }).map((_, i) => ( + + ))} + + +
+ +
+ ); +}; diff --git a/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/index.ts b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/index.ts new file mode 100644 index 0000000000..1cc474efef --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/index.ts @@ -0,0 +1,5 @@ +export * from './BackupPhase'; +export * from './QRPhase'; +export * from './VerifyPhase'; +export * from './SetupPhase'; +export * from './DisablePhase'; diff --git a/client/src/components/Nav/SettingsTabs/Data/ImportConversations.tsx b/client/src/components/Nav/SettingsTabs/Data/ImportConversations.tsx index c39e8351e8..e3bafd9152 100644 --- a/client/src/components/Nav/SettingsTabs/Data/ImportConversations.tsx +++ b/client/src/components/Nav/SettingsTabs/Data/ImportConversations.tsx @@ -82,7 +82,7 @@ function ImportConversations() { onClick={handleImportClick} onKeyDown={handleKeyDown} disabled={!allowImport} - aria-label={localize('com_ui_import_conversation')} + aria-label={localize('com_ui_import')} className="btn btn-neutral relative" > {allowImport ? ( @@ -90,7 +90,7 @@ function ImportConversations() { ) : ( )} - {localize('com_ui_import_conversation')} + {localize('com_ui_import')} setIsOpen(true)}> - + +