diff --git a/.env.example b/.env.example
index e63d043660..f02e063566 100644
--- a/.env.example
+++ b/.env.example
@@ -20,6 +20,11 @@ DOMAIN_CLIENT=http://localhost:3080
 DOMAIN_SERVER=http://localhost:3080
 NO_INDEX=true
 
+# Use the address that is at most n hops away from the Express application.
+# req.socket.remoteAddress is the first hop, and the rest are looked for in the X-Forwarded-For header from right to left.
+# A value of 0 means that the first untrusted address would be req.socket.remoteAddress, i.e., there is no reverse proxy.
+# Defaults to 1.
+TRUST_PROXY=1
 
 #===============#
 # JSON Logging  #
@@ -83,7 +88,7 @@ PROXY=
 #============#
 
 ANTHROPIC_API_KEY=user_provided
-# ANTHROPIC_MODELS=claude-3-5-haiku-20241022,claude-3-5-sonnet-20241022,claude-3-5-sonnet-latest,claude-3-5-sonnet-20240620,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k
+# ANTHROPIC_MODELS=claude-3-7-sonnet-latest,claude-3-7-sonnet-20250219,claude-3-5-haiku-20241022,claude-3-5-sonnet-20241022,claude-3-5-sonnet-latest,claude-3-5-sonnet-20240620,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k
 # ANTHROPIC_REVERSE_PROXY=
 
 #============#
@@ -170,7 +175,7 @@ GOOGLE_KEY=user_provided
 #============#
 
 OPENAI_API_KEY=user_provided
-# OPENAI_MODELS=o1,o1-mini,o1-preview,gpt-4o,chatgpt-4o-latest,gpt-4o-mini,gpt-3.5-turbo-0125,gpt-3.5-turbo-0301,gpt-3.5-turbo,gpt-4,gpt-4-0613,gpt-4-vision-preview,gpt-3.5-turbo-0613,gpt-3.5-turbo-16k-0613,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-instruct-0914,gpt-3.5-turbo-16k
+# OPENAI_MODELS=o1,o1-mini,o1-preview,gpt-4o,gpt-4.5-preview,chatgpt-4o-latest,gpt-4o-mini,gpt-3.5-turbo-0125,gpt-3.5-turbo-0301,gpt-3.5-turbo,gpt-4,gpt-4-0613,gpt-4-vision-preview,gpt-3.5-turbo-0613,gpt-3.5-turbo-16k-0613,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-instruct-0914,gpt-3.5-turbo-16k
 
 DEBUG_OPENAI=false
@@ -204,12 +209,6 @@ ASSISTANTS_API_KEY=user_provided
 # More info, including how to enable use of Assistants with Azure here:
 # https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints/azure#using-assistants-with-azure
 
-#============#
-# OpenRouter #
-#============#
-# !!!Warning: Use the variable above instead of this one. Using this one will override the OpenAI endpoint
-# OPENROUTER_API_KEY=
-
 #============#
 #  Plugins   #
 #============#
@@ -249,6 +248,13 @@ AZURE_AI_SEARCH_SEARCH_OPTION_SELECT=
 # DALLE3_AZURE_API_VERSION=
 # DALLE2_AZURE_API_VERSION=
 
+# Flux
+#-----------------
+FLUX_API_BASE_URL=https://api.us1.bfl.ai
+# FLUX_API_BASE_URL=https://api.bfl.ml
+
+# Get your API key at https://api.us1.bfl.ai/auth/profile
+# FLUX_API_KEY=
 
 # Google
 #-----------------
@@ -292,6 +298,10 @@ MEILI_NO_ANALYTICS=true
 MEILI_HOST=http://0.0.0.0:7700
 MEILI_MASTER_KEY=DrhYf7zENyR6AlUCKmnz0eYASOQdl6zxH7s7MKFSfFCt
 
+# Optional: Disable indexing, useful in a multi-node setup
+# where only one instance should perform an index sync.
+# MEILI_NO_SYNC=true
+
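For readers unfamiliar with Express's hop counting, the `TRUST_PROXY` comment added above maps directly onto Express's `trust proxy` setting. A minimal sketch of the semantics in a bare Express app (how LibreChat wires this internally is not shown in this diff, so treat the plumbing as illustrative):

```js
const express = require('express');
const app = express();

// TRUST_PROXY=1 → trust exactly one reverse proxy in front of the app.
// With n trusted hops, req.ip is resolved by walking X-Forwarded-For
// from right to left, starting from req.socket.remoteAddress.
app.set('trust proxy', Number(process.env.TRUST_PROXY ?? 1));

app.get('/ip', (req, res) => {
  // With TRUST_PROXY=0 this is always the socket address;
  // with TRUST_PROXY=1 it is the client address your proxy forwarded.
  res.send(req.ip);
});

app.listen(3080);
```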
 #==================================================#
 #          Speech to Text & Text to Speech         #
 #==================================================#
@@ -389,7 +399,7 @@ FACEBOOK_CALLBACK_URL=/oauth/facebook/callback
 GITHUB_CLIENT_ID=
 GITHUB_CLIENT_SECRET=
 GITHUB_CALLBACK_URL=/oauth/github/callback
-# GitHub Eenterprise
+# GitHub Enterprise
 # GITHUB_ENTERPRISE_BASE_URL=
 # GITHUB_ENTERPRISE_USER_AGENT=
@@ -424,15 +434,19 @@ OPENID_NAME_CLAIM=
 OPENID_BUTTON_LABEL=
 OPENID_IMAGE_URL=
 
+# Set to true to automatically redirect to the OpenID provider when a user visits the login page.
+# This bypasses the login form completely; only use it if OpenID is your only authentication method.
+OPENID_AUTO_REDIRECT=false
 
 # LDAP
 LDAP_URL=
 LDAP_BIND_DN=
 LDAP_BIND_CREDENTIALS=
 LDAP_USER_SEARCH_BASE=
-LDAP_SEARCH_FILTER=mail={{username}}
+#LDAP_SEARCH_FILTER="mail={{username}}"
 LDAP_CA_CERT_PATH=
 # LDAP_TLS_REJECT_UNAUTHORIZED=
+# LDAP_STARTTLS=
 # LDAP_LOGIN_USES_USERNAME=true
 # LDAP_ID=
 # LDAP_USERNAME=
@@ -465,6 +479,24 @@ FIREBASE_STORAGE_BUCKET=
 FIREBASE_MESSAGING_SENDER_ID=
 FIREBASE_APP_ID=
 
+#========================#
+#     S3 AWS Bucket      #
+#========================#
+
+AWS_ENDPOINT_URL=
+AWS_ACCESS_KEY_ID=
+AWS_SECRET_ACCESS_KEY=
+AWS_REGION=
+AWS_BUCKET_NAME=
+
+#========================#
+#   Azure Blob Storage   #
+#========================#
+
+AZURE_STORAGE_CONNECTION_STRING=
+AZURE_STORAGE_PUBLIC_ACCESS=false
+AZURE_CONTAINER_NAME=files
+
 #========================#
 #      Shared Links      #
 #========================#
@@ -497,6 +529,16 @@ HELP_AND_FAQ_URL=https://librechat.ai
 # Google tag manager id
 #ANALYTICS_GTM_ID=user provided google tag manager id
 
+#===============#
+# REDIS Options #
+#===============#
+
+# REDIS_URI=10.10.10.10:6379
+# USE_REDIS=true
+
+# USE_REDIS_CLUSTER=true
+# REDIS_CA=/path/to/ca.crt
+
 #==================================================#
 #                      Others                      #
 #==================================================#
@@ -504,9 +546,6 @@ HELP_AND_FAQ_URL=https://librechat.ai
 
 # NODE_ENV=
 
-# REDIS_URI=
-# USE_REDIS=
-
 # E2E_USER_EMAIL=
 # E2E_USER_PASSWORD=
@@ -529,4 +568,4 @@ HELP_AND_FAQ_URL=https://librechat.ai
 #=====================================================#
 #                     OpenWeather                     #
 #=====================================================#
-OPENWEATHER_API_KEY=
\ No newline at end of file
+OPENWEATHER_API_KEY=
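The new REDIS options above supersede the bare `REDIS_URI`/`USE_REDIS` stubs removed from the "Others" section. As a rough illustration of how such variables typically map onto a Redis client (the exact wiring inside LibreChat is not part of this diff, so this generic ioredis sketch is an assumption):

```js
const Redis = require('ioredis');
const fs = require('fs');

// USE_REDIS_CLUSTER switches between a single node and cluster mode;
// REDIS_CA supplies a CA certificate for TLS verification.
function createRedis() {
  const tls = process.env.REDIS_CA
    ? { ca: fs.readFileSync(process.env.REDIS_CA) }
    : undefined;
  const uri = process.env.REDIS_URI ?? '127.0.0.1:6379';
  if (process.env.USE_REDIS_CLUSTER === 'true') {
    return new Redis.Cluster([uri], { redisOptions: { tls } });
  }
  const [host, port] = uri.split(':');
  return new Redis({ host, port: Number(port), tls });
}
```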
diff --git a/.github/ISSUE_TEMPLATE/LOCIZE_TRANSLATION_ACCESS_REQUEST.yml b/.github/ISSUE_TEMPLATE/LOCIZE_TRANSLATION_ACCESS_REQUEST.yml
new file mode 100644
index 0000000000..49b01a814d
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/LOCIZE_TRANSLATION_ACCESS_REQUEST.yml
@@ -0,0 +1,42 @@
+name: Locize Translation Access Request
+description: Request access to an additional language in Locize for LibreChat translations.
+title: "Locize Access Request: "
+labels: ["🌍 i18n", "🔑 access request"]
+body:
+  - type: markdown
+    attributes:
+      value: |
+        Thank you for your interest in contributing to LibreChat translations!
+        Please fill out the form below to request access to an additional language in **Locize**.
+
+        **🔗 Available Languages:** [View the list here](https://www.librechat.ai/docs/translation)
+
+        **📌 Note:** Ensure that the requested language is supported before submitting your request.
+  - type: input
+    id: account_name
+    attributes:
+      label: Locize Account Name
+      description: Please provide your Locize account name (e.g., John Doe).
+      placeholder: e.g., John Doe
+    validations:
+      required: true
+  - type: input
+    id: language_requested
+    attributes:
+      label: Language Code (ISO 639-1)
+      description: |
+        Enter the **ISO 639-1** language code for the language you want to translate into.
+        Example: `es` for Spanish, `zh-Hant` for Traditional Chinese.
+
+        **🔗 Reference:** [Available Languages](https://www.librechat.ai/docs/translation)
+      placeholder: e.g., es
+    validations:
+      required: true
+  - type: checkboxes
+    id: agreement
+    attributes:
+      label: Agreement
+      description: By submitting this request, you confirm that you will contribute responsibly and adhere to the project guidelines.
+      options:
+        - label: I agree to use my access solely for contributing to LibreChat translations.
+          required: true
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/QUESTION.yml b/.github/ISSUE_TEMPLATE/QUESTION.yml
deleted file mode 100644
index c66e6baa3b..0000000000
--- a/.github/ISSUE_TEMPLATE/QUESTION.yml
+++ /dev/null
@@ -1,50 +0,0 @@
-name: Question
-description: Ask your question
-title: "[Question]: "
-labels: ["❓ question"]
-body:
-  - type: markdown
-    attributes:
-      value: |
-        Thanks for taking the time to fill this!
-  - type: textarea
-    id: what-is-your-question
-    attributes:
-      label: What is your question?
-      description: Please give as many details as possible
-      placeholder: Please give as many details as possible
-    validations:
-      required: true
-  - type: textarea
-    id: more-details
-    attributes:
-      label: More Details
-      description: Please provide more details if needed.
-      placeholder: Please provide more details if needed.
-    validations:
-      required: true
-  - type: dropdown
-    id: browsers
-    attributes:
-      label: What is the main subject of your question?
-      multiple: true
-      options:
-        - Documentation
-        - Installation
-        - UI
-        - Endpoints
-        - User System/OAuth
-        - Other
-  - type: textarea
-    id: screenshots
-    attributes:
-      label: Screenshots
-      description: If applicable, add screenshots to help explain your problem. You can drag and drop, paste images directly here or link to them.
-  - type: checkboxes
-    id: terms
-    attributes:
-      label: Code of Conduct
-      description: By submitting this issue, you agree to follow our [Code of Conduct](https://github.com/danny-avila/LibreChat/blob/main/.github/CODE_OF_CONDUCT.md)
-      options:
-        - label: I agree to follow this project's Code of Conduct
-          required: true
diff --git a/.github/configuration-release.json b/.github/configuration-release.json
new file mode 100644
index 0000000000..68fe80ed8f
--- /dev/null
+++ b/.github/configuration-release.json
@@ -0,0 +1,60 @@
+{
+  "categories": [
+    {
+      "title": "### ✨ New Features",
+      "labels": ["feat"]
+    },
+    {
+      "title": "### 🌍 Internationalization",
+      "labels": ["i18n"]
+    },
+    {
+      "title": "### 👐 Accessibility",
+      "labels": ["a11y"]
+    },
+    {
+      "title": "### 🔧 Fixes",
+      "labels": ["Fix", "fix"]
+    },
+    {
+      "title": "### ⚙️ Other Changes",
+      "labels": ["ci", "style", "docs", "refactor", "chore"]
+    }
+  ],
+  "ignore_labels": [
+    "🔁 duplicate",
+    "📊 analytics",
+    "🌱 good first issue",
+    "🔍 investigation",
+    "🙏 help wanted",
+    "❌ invalid",
+    "❓ question",
+    "🚫 wontfix",
+    "🚀 release",
+    "version"
+  ],
+  "base_branches": ["main"],
+  "sort": {
+    "order": "ASC",
+    "on_property": "mergedAt"
+  },
+  "label_extractor": [
+    {
+      "pattern": "^(?:[^A-Za-z0-9]*)(feat|fix|chore|docs|refactor|ci|style|a11y|i18n)\\s*:",
+      "target": "$1",
+      "flags": "i",
+      "on_property": "title",
+      "method": "match"
+    },
+    {
+      "pattern": "^(?:[^A-Za-z0-9]*)(v\\d+\\.\\d+\\.\\d+(?:-rc\\d+)?).*",
+      "target": "version",
+      "flags": "i",
+      "on_property": "title",
+      "method": "match"
+    }
+  ],
+  "template": "## [#{{TO_TAG}}] - #{{TO_TAG_DATE}}\n\nChanges from #{{FROM_TAG}} to #{{TO_TAG}}.\n\n#{{CHANGELOG}}\n\n[See full release details][release-#{{TO_TAG}}]\n\n[release-#{{TO_TAG}}]: https://github.com/#{{OWNER}}/#{{REPO}}/releases/tag/#{{TO_TAG}}\n\n---",
+  "pr_template": "- #{{TITLE}} by **@#{{AUTHOR}}** in [##{{NUMBER}}](#{{URL}})",
+  "empty_template": "- no changes"
+}
\ No newline at end of file
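The `label_extractor` patterns in this configuration drive how PR titles are bucketed into the categories above. A quick sanity check of the first pattern in Node (the titles are hypothetical; the regex is the one from the config):

```js
const pattern = /^(?:[^A-Za-z0-9]*)(feat|fix|chore|docs|refactor|ci|style|a11y|i18n)\s*:/i;

// Leading emoji/symbols are skipped by the non-alphanumeric prefix group.
console.log('🌍 i18n: add Swahili'.match(pattern)?.[1]); // 'i18n'
console.log('feat: new agent panel'.match(pattern)?.[1]); // 'feat'
console.log('v0.7.7-rc1'.match(pattern)); // null → falls through to the version extractor
```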
"empty_template": "- no changes" +} \ No newline at end of file diff --git a/.github/workflows/backend-review.yml b/.github/workflows/backend-review.yml index 5bc3d3b2db..b7bccecae8 100644 --- a/.github/workflows/backend-review.yml +++ b/.github/workflows/backend-review.yml @@ -39,6 +39,9 @@ jobs: - name: Install MCP Package run: npm run build:mcp + - name: Install Data Schemas Package + run: npm run build:data-schemas + - name: Create empty auth.json file run: | mkdir -p api/data @@ -61,4 +64,7 @@ jobs: run: cd api && npm run test:ci - name: Run librechat-data-provider unit tests - run: cd packages/data-provider && npm run test:ci \ No newline at end of file + run: cd packages/data-provider && npm run test:ci + + - name: Run librechat-mcp unit tests + run: cd packages/mcp && npm run test:ci \ No newline at end of file diff --git a/.github/workflows/data-schemas.yml b/.github/workflows/data-schemas.yml new file mode 100644 index 0000000000..fee72fbe02 --- /dev/null +++ b/.github/workflows/data-schemas.yml @@ -0,0 +1,58 @@ +name: Publish `@librechat/data-schemas` to NPM + +on: + push: + branches: + - main + paths: + - 'packages/data-schemas/package.json' + workflow_dispatch: + inputs: + reason: + description: 'Reason for manual trigger' + required: false + default: 'Manual publish requested' + +jobs: + build-and-publish: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Use Node.js + uses: actions/setup-node@v4 + with: + node-version: '18.x' + + - name: Install dependencies + run: cd packages/data-schemas && npm ci + + - name: Build + run: cd packages/data-schemas && npm run build + + - name: Set up npm authentication + run: echo "//registry.npmjs.org/:_authToken=${{ secrets.PUBLISH_NPM_TOKEN }}" > ~/.npmrc + + - name: Check version change + id: check + working-directory: packages/data-schemas + run: | + PACKAGE_VERSION=$(node -p "require('./package.json').version") + PUBLISHED_VERSION=$(npm view @librechat/data-schemas version 2>/dev/null || echo "0.0.0") + if [ "$PACKAGE_VERSION" = "$PUBLISHED_VERSION" ]; then + echo "No version change, skipping publish" + echo "skip=true" >> $GITHUB_OUTPUT + else + echo "Version changed, proceeding with publish" + echo "skip=false" >> $GITHUB_OUTPUT + fi + + - name: Pack package + if: steps.check.outputs.skip != 'true' + working-directory: packages/data-schemas + run: npm pack + + - name: Publish + if: steps.check.outputs.skip != 'true' + working-directory: packages/data-schemas + run: npm publish *.tgz --access public \ No newline at end of file diff --git a/.github/workflows/generate-release-changelog-pr.yml b/.github/workflows/generate-release-changelog-pr.yml new file mode 100644 index 0000000000..004431e577 --- /dev/null +++ b/.github/workflows/generate-release-changelog-pr.yml @@ -0,0 +1,94 @@ +name: Generate Release Changelog PR + +on: + push: + tags: + - 'v*.*.*' + +jobs: + generate-release-changelog-pr: + permissions: + contents: write # Needed for pushing commits and creating branches. + pull-requests: write + runs-on: ubuntu-latest + steps: + # 1. Checkout the repository (with full history). + - name: Checkout Repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + # 2. Generate the release changelog using our custom configuration. 
diff --git a/.github/workflows/generate-release-changelog-pr.yml b/.github/workflows/generate-release-changelog-pr.yml
new file mode 100644
index 0000000000..004431e577
--- /dev/null
+++ b/.github/workflows/generate-release-changelog-pr.yml
@@ -0,0 +1,94 @@
+name: Generate Release Changelog PR
+
+on:
+  push:
+    tags:
+      - 'v*.*.*'
+
+jobs:
+  generate-release-changelog-pr:
+    permissions:
+      contents: write # Needed for pushing commits and creating branches.
+      pull-requests: write
+    runs-on: ubuntu-latest
+    steps:
+      # 1. Checkout the repository (with full history).
+      - name: Checkout Repository
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      # 2. Generate the release changelog using our custom configuration.
+      - name: Generate Release Changelog
+        id: generate_release
+        uses: mikepenz/release-changelog-builder-action@v5.1.0
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        with:
+          configuration: ".github/configuration-release.json"
+          owner: ${{ github.repository_owner }}
+          repo: ${{ github.event.repository.name }}
+          outputFile: CHANGELOG-release.md
+
+      # 3. Update the main CHANGELOG.md:
+      #    - If it doesn't exist, create it with a basic header.
+      #    - Remove the "Unreleased" section (if present).
+      #    - Prepend the new release changelog above previous releases.
+      #    - Remove all temporary files before committing.
+      - name: Update CHANGELOG.md
+        run: |
+          # Determine the release tag, e.g. "v1.2.3"
+          TAG=${GITHUB_REF##*/}
+          echo "Using release tag: $TAG"
+
+          # Ensure CHANGELOG.md exists; if not, create a basic header.
+          if [ ! -f CHANGELOG.md ]; then
+            echo "# Changelog" > CHANGELOG.md
+            echo "" >> CHANGELOG.md
+            echo "All notable changes to this project will be documented in this file." >> CHANGELOG.md
+            echo "" >> CHANGELOG.md
+          fi
+
+          echo "Updating CHANGELOG.md…"
+
+          # Remove the "Unreleased" section (from "## [Unreleased]" until the first occurrence of '---') if it exists.
+          if grep -q "^## \[Unreleased\]" CHANGELOG.md; then
+            awk '/^## \[Unreleased\]/{flag=1} flag && /^---/{flag=0; next} !flag' CHANGELOG.md > CHANGELOG.cleaned
+          else
+            cp CHANGELOG.md CHANGELOG.cleaned
+          fi
+
+          # Split the cleaned file into:
+          #   - header.md: content before the first release header ("## [v...").
+          #   - tail.md: content from the first release header onward.
+          awk '/^## \[v/{exit} {print}' CHANGELOG.cleaned > header.md
+          awk '/^## \[v/{f=1} f{print}' CHANGELOG.cleaned > tail.md
+
+          # Combine header, the new release changelog, and the tail.
+          echo "Combining updated changelog parts..."
+          cat header.md CHANGELOG-release.md > CHANGELOG.md.new
+          echo "" >> CHANGELOG.md.new
+          cat tail.md >> CHANGELOG.md.new
+
+          mv CHANGELOG.md.new CHANGELOG.md
+
+          # Remove temporary files.
+          rm -f CHANGELOG.cleaned header.md tail.md CHANGELOG-release.md
+
+          echo "Final CHANGELOG.md content:"
+          cat CHANGELOG.md
+
+      # 4. Create (or update) the Pull Request with the updated CHANGELOG.md.
+      - name: Create Pull Request
+        uses: peter-evans/create-pull-request@v7
+        with:
+          token: ${{ secrets.GITHUB_TOKEN }}
+          sign-commits: true
+          commit-message: "chore: update CHANGELOG for release ${{ github.ref_name }}"
+          base: main
+          branch: "changelog/${{ github.ref_name }}"
+          reviewers: danny-avila
+          title: "chore: update CHANGELOG for release ${{ github.ref_name }}"
+          body: |
+            **Description**:
+            - This PR updates the CHANGELOG.md by removing the "Unreleased" section and adding new release notes for release ${{ github.ref_name }} above previous releases.
\ No newline at end of file
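Both changelog workflows rely on the same idiom: cut CHANGELOG.md at the `## [Unreleased]` block or the first `## [v...` release header, then reassemble the pieces around the freshly generated notes. For readers who do not think in awk, a sketch of the equivalent string surgery in Node (not code from the repo):

```js
// Split a changelog into its header and the block of prior releases.
function splitChangelog(text) {
  const lines = text.split('\n');
  const unreleasedAt = lines.findIndex((l) => l.startsWith('## [Unreleased]'));
  const releaseAt = lines.findIndex((l) => l.startsWith('## [v'));
  const cut = unreleasedAt !== -1 ? unreleasedAt : releaseAt !== -1 ? releaseAt : lines.length;
  return {
    header: lines.slice(0, cut),
    releases: releaseAt !== -1 ? lines.slice(releaseAt) : [],
  };
}

// Drop any old Unreleased block and insert new notes above prior releases.
function prependNotes(text, notes) {
  const { header, releases } = splitChangelog(text);
  return [...header, notes, '', ...releases].join('\n');
}
```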
diff --git a/.github/workflows/generate-unreleased-changelog-pr.yml b/.github/workflows/generate-unreleased-changelog-pr.yml
new file mode 100644
index 0000000000..b130e4fb33
--- /dev/null
+++ b/.github/workflows/generate-unreleased-changelog-pr.yml
@@ -0,0 +1,106 @@
+name: Generate Unreleased Changelog PR
+
+on:
+  schedule:
+    - cron: "0 0 * * 1" # Runs every Monday at 00:00 UTC
+
+jobs:
+  generate-unreleased-changelog-pr:
+    permissions:
+      contents: write # Needed for pushing commits and creating branches.
+      pull-requests: write
+    runs-on: ubuntu-latest
+    steps:
+      # 1. Checkout the repository on main.
+      - name: Checkout Repository on Main
+        uses: actions/checkout@v4
+        with:
+          ref: main
+          fetch-depth: 0
+
+      # 2. Get the latest version tag.
+      - name: Get Latest Tag
+        id: get_latest_tag
+        run: |
+          LATEST_TAG=$(git describe --tags $(git rev-list --tags --max-count=1) || echo "none")
+          echo "Latest tag: $LATEST_TAG"
+          echo "tag=$LATEST_TAG" >> $GITHUB_OUTPUT
+
+      # 3. Generate the Unreleased changelog.
+      - name: Generate Unreleased Changelog
+        id: generate_unreleased
+        uses: mikepenz/release-changelog-builder-action@v5.1.0
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        with:
+          configuration: ".github/configuration-unreleased.json"
+          owner: ${{ github.repository_owner }}
+          repo: ${{ github.event.repository.name }}
+          outputFile: CHANGELOG-unreleased.md
+          fromTag: ${{ steps.get_latest_tag.outputs.tag }}
+          toTag: main
+
+      # 4. Update CHANGELOG.md with the new Unreleased section.
+      - name: Update CHANGELOG.md
+        id: update_changelog
+        run: |
+          # Create CHANGELOG.md if it doesn't exist.
+          if [ ! -f CHANGELOG.md ]; then
+            echo "# Changelog" > CHANGELOG.md
+            echo "" >> CHANGELOG.md
+            echo "All notable changes to this project will be documented in this file." >> CHANGELOG.md
+            echo "" >> CHANGELOG.md
+          fi
+
+          echo "Updating CHANGELOG.md…"
+
+          # Extract content before the "## [Unreleased]" (or first version header if missing).
+          if grep -q "^## \[Unreleased\]" CHANGELOG.md; then
+            awk '/^## \[Unreleased\]/{exit} {print}' CHANGELOG.md > CHANGELOG_TMP.md
+          else
+            awk '/^## \[v/{exit} {print}' CHANGELOG.md > CHANGELOG_TMP.md
+          fi
+
+          # Append the generated Unreleased changelog.
+          echo "" >> CHANGELOG_TMP.md
+          cat CHANGELOG-unreleased.md >> CHANGELOG_TMP.md
+          echo "" >> CHANGELOG_TMP.md
+
+          # Append the remainder of the original changelog (starting from the first version header).
+          awk '/^## \[v/{f=1} f{print}' CHANGELOG.md >> CHANGELOG_TMP.md
+
+          # Replace the old file with the updated file.
+          mv CHANGELOG_TMP.md CHANGELOG.md
+
+          # Remove the temporary generated file.
+          rm -f CHANGELOG-unreleased.md
+
+          echo "Final CHANGELOG.md:"
+          cat CHANGELOG.md
+
+      # 5. Check if CHANGELOG.md has any updates.
+      - name: Check for CHANGELOG.md changes
+        id: changelog_changes
+        run: |
+          if git diff --quiet CHANGELOG.md; then
+            echo "has_changes=false" >> $GITHUB_OUTPUT
+          else
+            echo "has_changes=true" >> $GITHUB_OUTPUT
+          fi
+
+      # 6. Create (or update) the Pull Request only if there are changes.
+      - name: Create Pull Request
+        if: steps.changelog_changes.outputs.has_changes == 'true'
+        uses: peter-evans/create-pull-request@v7
+        with:
+          token: ${{ secrets.GITHUB_TOKEN }}
+          base: main
+          branch: "changelog/unreleased-update"
+          sign-commits: true
+          commit-message: "action: update Unreleased changelog"
+          title: "action: update Unreleased changelog"
+          body: |
+            **Description**:
+            - This PR updates the Unreleased section in CHANGELOG.md.
+            - It compares the current main branch with the latest version tag (determined as ${{ steps.get_latest_tag.outputs.tag }}),
+              regenerates the Unreleased changelog, removes any old Unreleased block, and inserts the new content.
\ No newline at end of file
-f "$I18N_FILE" ]]; then @@ -37,7 +38,15 @@ jobs: # Check if each key is used in the source code for KEY in $KEYS; do - if ! grep -r --include=\*.{js,jsx,ts,tsx} -q "$KEY" "$SOURCE_DIR"; then + FOUND=false + for DIR in "${SOURCE_DIRS[@]}"; do + if grep -r --include=\*.{js,jsx,ts,tsx} -q "$KEY" "$DIR"; then + FOUND=true + break + fi + done + + if [[ "$FOUND" == false ]]; then UNUSED_KEYS+=("$KEY") fi done @@ -59,8 +68,8 @@ jobs: run: | PR_NUMBER=$(jq --raw-output .pull_request.number "$GITHUB_EVENT_PATH") - # Format the unused keys list correctly, filtering out empty entries - FILTERED_KEYS=$(echo "$unused_keys" | jq -r '.[]' | grep -v '^\s*$' | sed 's/^/- `/;s/$/`/' ) + # Format the unused keys list as checkboxes for easy manual checking. + FILTERED_KEYS=$(echo "$unused_keys" | jq -r '.[]' | grep -v '^\s*$' | sed 's/^/- [ ] `/;s/$/`/' ) COMMENT_BODY=$(cat < - Locize Logo + Locize Logo

diff --git a/api/app/clients/AnthropicClient.js b/api/app/clients/AnthropicClient.js
index 522b6beb4f..8331e54bfb 100644
--- a/api/app/clients/AnthropicClient.js
+++ b/api/app/clients/AnthropicClient.js
@@ -7,7 +7,7 @@ const {
   getResponseSender,
   validateVisionModel,
 } = require('librechat-data-provider');
-const { encodeAndFormat } = require('~/server/services/Files/images/encode');
+const { SplitStreamHandler: _Handler, GraphEvents } = require('@librechat/agents');
 const {
   truncateText,
   formatMessage,
@@ -16,16 +16,31 @@ const {
   parseParamFromPrompt,
   createContextHandlers,
 } = require('./prompts');
+const {
+  getClaudeHeaders,
+  configureReasoning,
+  checkPromptCacheSupport,
+} = require('~/server/services/Endpoints/anthropic/helpers');
 const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils');
 const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
+const { encodeAndFormat } = require('~/server/services/Files/images/encode');
 const Tokenizer = require('~/server/services/Tokenizer');
+const { logger, sendEvent } = require('~/config');
 const { sleep } = require('~/server/utils');
 const BaseClient = require('./BaseClient');
-const { logger } = require('~/config');
 
 const HUMAN_PROMPT = '\n\nHuman:';
 const AI_PROMPT = '\n\nAssistant:';
 
+class SplitStreamHandler extends _Handler {
+  getDeltaContent(chunk) {
+    return (chunk?.delta?.text ?? chunk?.completion) || '';
+  }
+  getReasoningDelta(chunk) {
+    return chunk?.delta?.thinking || '';
+  }
+}
+
 /** Helper function to introduce a delay before retrying */
 function delayBeforeRetry(attempts, baseDelay = 1000) {
   return new Promise((resolve) => setTimeout(resolve, baseDelay * attempts));
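The subclass above only teaches the shared handler where Anthropic keeps its deltas: `delta.text` or `completion` for visible tokens, and `delta.thinking` for reasoning. A self-contained mock showing the routing (the real `SplitStreamHandler` lives in `@librechat/agents`; this re-implements just the two accessors for illustration):

```js
class MockSplitStreamHandler {
  constructor() {
    this.tokens = [];
    this.reasoningTokens = [];
  }
  getDeltaContent(chunk) {
    return (chunk?.delta?.text ?? chunk?.completion) || '';
  }
  getReasoningDelta(chunk) {
    return chunk?.delta?.thinking || '';
  }
  handle(chunk) {
    const text = this.getDeltaContent(chunk);
    if (text) this.tokens.push(text);
    const thinking = this.getReasoningDelta(chunk);
    if (thinking) this.reasoningTokens.push(thinking);
  }
}

const h = new MockSplitStreamHandler();
h.handle({ delta: { thinking: 'Let me check…' } }); // reasoning stream
h.handle({ delta: { text: 'Hello' } });             // Messages API shape
h.handle({ completion: ' world' });                 // legacy Completions shape
console.log(h.reasoningTokens.join(''), '|', h.tokens.join(''));
```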
@@ -68,6 +83,8 @@
     /** The key for the usage object's output tokens
      * @type {string} */
     this.outputTokensKey = 'output_tokens';
+    /** @type {SplitStreamHandler | undefined} */
+    this.streamHandler;
   }
 
   setOptions(options) {
@@ -97,9 +114,10 @@
     const modelMatch = matchModelName(this.modelOptions.model, EModelEndpoint.anthropic);
     this.isClaude3 = modelMatch.includes('claude-3');
-    this.isLegacyOutput = !modelMatch.includes('claude-3-5-sonnet');
-    this.supportsCacheControl =
-      this.options.promptCache && this.checkPromptCacheSupport(modelMatch);
+    this.isLegacyOutput = !(
+      /claude-3[-.]5-sonnet/.test(modelMatch) || /claude-3[-.]7/.test(modelMatch)
+    );
+    this.supportsCacheControl = this.options.promptCache && checkPromptCacheSupport(modelMatch);
 
     if (
       this.isLegacyOutput &&
@@ -125,7 +143,7 @@
         this.options.endpointType ?? this.options.endpoint,
         this.options.endpointTokenConfig,
       ) ??
-      1500;
+      anthropicSettings.maxOutputTokens.reset(this.modelOptions.model);
 
     this.maxPromptTokens =
       this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
@@ -171,18 +189,9 @@
       options.baseURL = this.options.reverseProxyUrl;
     }
 
-    if (
-      this.supportsCacheControl &&
-      requestOptions?.model &&
-      requestOptions.model.includes('claude-3-5-sonnet')
-    ) {
-      options.defaultHeaders = {
-        'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15,prompt-caching-2024-07-31',
-      };
-    } else if (this.supportsCacheControl) {
-      options.defaultHeaders = {
-        'anthropic-beta': 'prompt-caching-2024-07-31',
-      };
+    const headers = getClaudeHeaders(requestOptions?.model, this.supportsCacheControl);
+    if (headers) {
+      options.defaultHeaders = headers;
     }
 
     return new Anthropic(options);
@@ -668,29 +677,48 @@
    * @returns {Promise} The response from the Anthropic client.
    */
   async createResponse(client, options, useMessages) {
-    return useMessages ?? this.useMessages
+    return (useMessages ?? this.useMessages)
       ? await client.messages.create(options)
       : await client.completions.create(options);
   }
 
+  getMessageMapMethod() {
+    /**
+     * @param {TMessage} msg
+     */
+    return (msg) => {
+      if (msg.text != null && msg.text && msg.text.startsWith(':::thinking')) {
+        msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim();
+      } else if (msg.content != null) {
+        /** @type {import('@librechat/agents').MessageContentComplex} */
+        const newContent = [];
+        for (let part of msg.content) {
+          if (part.think != null) {
+            continue;
+          }
+          newContent.push(part);
+        }
+        msg.content = newContent;
+      }
+
+      return msg;
+    };
+  }
+
   /**
-   * @param {string} modelName
-   * @returns {boolean}
+   * @param {string[]} [intermediateReply]
+   * @returns {string}
    */
-  checkPromptCacheSupport(modelName) {
-    const modelMatch = matchModelName(modelName, EModelEndpoint.anthropic);
-    if (modelMatch.includes('claude-3-5-sonnet-latest')) {
-      return false;
+  getStreamText(intermediateReply) {
+    if (!this.streamHandler) {
+      return intermediateReply?.join('') ?? '';
     }
-    if (
-      modelMatch === 'claude-3-5-sonnet' ||
-      modelMatch === 'claude-3-5-haiku' ||
-      modelMatch === 'claude-3-haiku' ||
-      modelMatch === 'claude-3-opus'
-    ) {
-      return true;
-    }
-    return false;
+
+    const reasoningText = this.streamHandler.reasoningTokens.join('');
+
+    const reasoningBlock =
+      reasoningText.length > 0 ? `:::thinking\n${reasoningText}\n:::\n` : '';
+
+    return `${reasoningBlock}${this.streamHandler.tokens.join('')}`;
   }
 
   async sendCompletion(payload, { onProgress, abortController }) {
@@ -710,7 +738,6 @@
       user_id: this.user,
     };
 
-    let text = '';
     const {
       stream,
       model,
@@ -721,22 +748,34 @@
       topK: top_k,
     } = this.modelOptions;
 
-    const requestOptions = {
+    let requestOptions = {
       model,
       stream: stream || true,
       stop_sequences,
       temperature,
       metadata,
-      top_p,
-      top_k,
     };
 
     if (this.useMessages) {
       requestOptions.messages = payload;
-      requestOptions.max_tokens = maxOutputTokens || legacy.maxOutputTokens.default;
+      requestOptions.max_tokens =
+        maxOutputTokens || anthropicSettings.maxOutputTokens.reset(requestOptions.model);
     } else {
       requestOptions.prompt = payload;
-      requestOptions.max_tokens_to_sample = maxOutputTokens || 1500;
+      requestOptions.max_tokens_to_sample = maxOutputTokens || legacy.maxOutputTokens.default;
     }
+
+    requestOptions = configureReasoning(requestOptions, {
+      thinking: this.options.thinking,
+      thinkingBudget: this.options.thinkingBudget,
+    });
+
+    if (!/claude-3[-.]7/.test(model)) {
+      requestOptions.top_p = top_p;
+      requestOptions.top_k = top_k;
+    } else if (requestOptions.thinking == null) {
+      requestOptions.topP = top_p;
+      requestOptions.topK = top_k;
+    }
 
     if (this.systemMessage && this.supportsCacheControl === true) {
@@ -756,13 +795,17 @@
     }
 
     logger.debug('[AnthropicClient]', { ...requestOptions });
 
+    this.streamHandler = new SplitStreamHandler({
+      accumulate: true,
+      runId: this.responseMessageId,
+      handlers: {
+        [GraphEvents.ON_RUN_STEP]: (event) => sendEvent(this.options.res, event),
+        [GraphEvents.ON_MESSAGE_DELTA]: (event) => sendEvent(this.options.res, event),
+        [GraphEvents.ON_REASONING_DELTA]: (event) => sendEvent(this.options.res, event),
+      },
+    });
 
-    const handleChunk = (currentChunk) => {
-      if (currentChunk) {
-        text += currentChunk;
-        onProgress(currentChunk);
-      }
-    };
+    let intermediateReply = this.streamHandler.tokens;
 
     const maxRetries = 3;
     const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
@@ -783,22 +826,15 @@
         });
 
         for await (const completion of response) {
-          // Handle each completion as before
           const type = completion?.type ?? '';
          if (tokenEventTypes.has(type)) {
             logger.debug(`[AnthropicClient] ${type}`, completion);
             this[type] = completion;
           }
 
-          if (completion?.delta?.text) {
-            handleChunk(completion.delta.text);
-          } else if (completion.completion) {
-            handleChunk(completion.completion);
-          }
-
+          this.streamHandler.handle(completion);
           await sleep(streamRate);
         }
 
-        // Successful processing, exit loop
         break;
       } catch (error) {
         attempts += 1;
@@ -808,6 +844,10 @@
 
         if (attempts < maxRetries) {
           await delayBeforeRetry(attempts, 350);
+        } else if (this.streamHandler && this.streamHandler.reasoningTokens.length) {
+          return this.getStreamText();
+        } else if (intermediateReply.length > 0) {
+          return this.getStreamText(intermediateReply);
         } else {
           throw new Error(`Operation failed after ${maxRetries} attempts: ${error.message}`);
         }
@@ -823,8 +863,7 @@
     }
 
     await processResponse.bind(this)();
-
-    return text.trim();
+    return this.getStreamText(intermediateReply);
   }
 
   getSaveOptions() {
@@ -834,6 +873,8 @@
       promptPrefix: this.options.promptPrefix,
       modelLabel: this.options.modelLabel,
       promptCache: this.options.promptCache,
+      thinking: this.options.thinking,
+      thinkingBudget: this.options.thinkingBudget,
       resendFiles: this.options.resendFiles,
       iconURL: this.options.iconURL,
       greeting: this.options.greeting,
diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js
index ebf3ca12d9..d3077e68f5 100644
--- a/api/app/clients/BaseClient.js
+++ b/api/app/clients/BaseClient.js
@@ -5,10 +5,12 @@ const {
   isAgentsEndpoint,
   isParamEndpoint,
   EModelEndpoint,
+  ContentTypes,
+  excludedKeys,
   ErrorTypes,
   Constants,
 } = require('librechat-data-provider');
-const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
+const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models');
 const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
 const { truncateToolCallOutputs } = require('./prompts');
 const checkBalance = require('~/models/checkBalance');
@@ -55,6 +57,10 @@ class BaseClient {
      * Flag to determine if the client re-submitted the latest assistant message.
      * @type {boolean | undefined} */
     this.continued;
+    /**
+     * Flag to determine if the client has already fetched the conversation while saving new messages.
+     * @type {boolean | undefined} */
+    this.fetchedConvo;
     /** @type {TMessage[]} */
     this.currentMessages = [];
     /** @type {import('librechat-data-provider').VisionModes | undefined} */
@@ -360,17 +366,14 @@
    *   context: TMessage[],
    *   remainingContextTokens: number,
    *   messagesToRefine: TMessage[],
-   *   summaryIndex: number,
-   * }>} An object with four properties: `context`, `summaryIndex`, `remainingContextTokens`, and `messagesToRefine`.
+   * }>} An object with three properties: `context`, `remainingContextTokens`, and `messagesToRefine`.
    * `context` is an array of messages that fit within the token limit.
-   * `summaryIndex` is the index of the first message in the `messagesToRefine` array.
    * `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context.
    * `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit.
   */
  async getMessagesWithinTokenLimit({ messages: _messages, maxContextTokens, instructions }) {
    // Every reply is primed with <|start|>assistant<|message|>, so we
    // start with 3 tokens for the label after all messages have been counted.
-    let summaryIndex = -1;
     let currentTokenCount = 3;
     const instructionsTokenCount = instructions?.tokenCount ?? 0;
     let remainingContextTokens =
@@ -403,14 +406,12 @@
     }
 
     const prunedMemory = messages;
-    summaryIndex = prunedMemory.length - 1;
     remainingContextTokens -= currentTokenCount;
 
     return {
       context: context.reverse(),
       remainingContextTokens,
       messagesToRefine: prunedMemory,
-      summaryIndex,
     };
   }
 
@@ -453,7 +454,7 @@
 
     let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);
 
-    let { context, remainingContextTokens, messagesToRefine, summaryIndex } =
+    let { context, remainingContextTokens, messagesToRefine } =
       await this.getMessagesWithinTokenLimit({
         messages: orderedWithInstructions,
         instructions,
@@ -523,7 +524,7 @@
     }
 
     // Make sure to only continue summarization logic if the summary message was generated
-    shouldSummarize = summaryMessage && shouldSummarize;
+    shouldSummarize = summaryMessage != null && shouldSummarize === true;
 
     logger.debug('[BaseClient] Context Count (2/2)', {
       remainingContextTokens,
@@ -533,17 +534,18 @@
     /** @type {Record | undefined} */
     let tokenCountMap;
     if (buildTokenMap) {
-      tokenCountMap = orderedWithInstructions.reduce((map, message, index) => {
+      const currentPayload = shouldSummarize ? orderedWithInstructions : context;
+      tokenCountMap = currentPayload.reduce((map, message, index) => {
         const { messageId } = message;
         if (!messageId) {
           return map;
         }
 
-        if (shouldSummarize && index === summaryIndex && !usePrevSummary) {
+        if (shouldSummarize && index === messagesToRefine.length - 1 && !usePrevSummary) {
           map.summaryMessage = { ...summaryMessage, messageId, tokenCount: summaryTokenCount };
         }
 
-        map[messageId] = orderedWithInstructions[index].tokenCount;
+        map[messageId] = currentPayload[index].tokenCount;
         return map;
       }, {});
     }
@@ -863,16 +865,39 @@
       return { message: savedMessage };
     }
 
-    const conversation = await saveConvo(
-      this.options.req,
-      {
-        conversationId: message.conversationId,
-        endpoint: this.options.endpoint,
-        endpointType: this.options.endpointType,
-        ...endpointOptions,
-      },
-      { context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo' },
-    );
+    const fieldsToKeep = {
+      conversationId: message.conversationId,
+      endpoint: this.options.endpoint,
+      endpointType: this.options.endpointType,
+      ...endpointOptions,
+    };
+
+    const existingConvo =
+      this.fetchedConvo === true
+        ? null
+        : await getConvo(this.options.req?.user?.id, message.conversationId);
+
+    const unsetFields = {};
+    if (existingConvo != null) {
+      this.fetchedConvo = true;
+      for (const key in existingConvo) {
+        if (!key) {
+          continue;
+        }
+        if (excludedKeys.has(key)) {
+          continue;
+        }
+
+        if (endpointOptions?.[key] === undefined) {
+          unsetFields[key] = 1;
+        }
+      }
+    }
+
+    const conversation = await saveConvo(this.options.req, fieldsToKeep, {
+      context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo',
+      unsetFields,
+    });
 
     return { message: savedMessage, conversation };
   }
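The `unsetFields` map marks conversation keys that exist on the stored document but are no longer present in the incoming endpoint options. A sketch of how such a map would translate into a MongoDB update (how `saveConvo` actually consumes it is not shown in this diff, so the `$unset` wiring below is an assumption; `upsertConvo` is a hypothetical helper name):

```js
// Fields absent from the new endpoint options are removed rather than
// left stale on the conversation document.
async function upsertConvo(Conversation, conversationId, fieldsToKeep, unsetFields) {
  const update = { $set: fieldsToKeep };
  if (Object.keys(unsetFields).length > 0) {
    update.$unset = unsetFields; // e.g. { temperature: 1 }
  }
  return Conversation.findOneAndUpdate({ conversationId }, update, {
    new: true,
    upsert: true,
  });
}
```

This matches the expectation asserted later in `BaseClient.test.js`, where a stored `temperature` that is missing from the new options shows up as `unsetFields: { temperature: 1 }`.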
@@ -993,11 +1018,17 @@
     const processValue = (value) => {
       if (Array.isArray(value)) {
         for (let item of value) {
-          if (!item || !item.type || item.type === 'image_url') {
+          if (
+            !item ||
+            !item.type ||
+            item.type === ContentTypes.THINK ||
+            item.type === ContentTypes.ERROR ||
+            item.type === ContentTypes.IMAGE_URL
+          ) {
             continue;
           }
 
-          if (item.type === 'tool_call' && item.tool_call != null) {
+          if (item.type === ContentTypes.TOOL_CALL && item.tool_call != null) {
             const toolName = item.tool_call?.name || '';
             if (toolName != null && toolName && typeof toolName === 'string') {
               numTokens += this.getTokenCount(toolName);
@@ -1093,9 +1124,13 @@
       return message;
     }
 
-    const files = await getFiles({
-      file_id: { $in: fileIds },
-    });
+    const files = await getFiles(
+      {
+        file_id: { $in: fileIds },
+      },
+      {},
+      {},
+    );
 
     await this.addImageURLs(message, files, this.visionMode);
 
diff --git a/api/app/clients/GoogleClient.js b/api/app/clients/GoogleClient.js
index 03461a6796..58ee783d2a 100644
--- a/api/app/clients/GoogleClient.js
+++ b/api/app/clients/GoogleClient.js
@@ -51,7 +51,7 @@ class GoogleClient extends BaseClient {
     const serviceKey = creds[AuthKeys.GOOGLE_SERVICE_KEY] ?? {};
     this.serviceKey =
-      serviceKey && typeof serviceKey === 'string' ? JSON.parse(serviceKey) : serviceKey ?? {};
+      serviceKey && typeof serviceKey === 'string' ? JSON.parse(serviceKey) : (serviceKey ?? {});
     /** @type {string | null | undefined} */
     this.project_id = this.serviceKey.project_id;
     this.client_email = this.serviceKey.client_email;
@@ -73,6 +73,8 @@
      * @type {string} */
     this.outputTokensKey = 'output_tokens';
     this.visionMode = VisionModes.generative;
+    /** @type {string} */
+    this.systemMessage;
     if (options.skipSetOptions) {
       return;
     }
@@ -184,7 +186,7 @@
     if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
       promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim();
     }
-    this.options.promptPrefix = promptPrefix;
+    this.systemMessage = promptPrefix;
     this.initializeClient();
     return this;
   }
@@ -314,7 +316,7 @@
     }
 
       this.augmentedPrompt = await this.contextHandlers.createContext();
-      this.options.promptPrefix = this.augmentedPrompt + this.options.promptPrefix;
+      this.systemMessage = this.augmentedPrompt + this.systemMessage;
     }
   }
 
@@ -361,8 +363,8 @@
       throw new Error('[GoogleClient] PaLM 2 and Codey models are no longer supported.');
     }
 
-    if (this.options.promptPrefix) {
-      const instructionsTokenCount = this.getTokenCount(this.options.promptPrefix);
+    if (this.systemMessage) {
+      const instructionsTokenCount = this.getTokenCount(this.systemMessage);
 
       this.maxContextTokens = this.maxContextTokens - instructionsTokenCount;
       if (this.maxContextTokens < 0) {
@@ -417,8 +419,8 @@
       ],
     };
 
-    if (this.options.promptPrefix) {
-      payload.instances[0].context = this.options.promptPrefix;
+    if (this.systemMessage) {
+      payload.instances[0].context = this.systemMessage;
     }
 
     logger.debug('[GoogleClient] buildMessages', payload);
@@ -464,7 +466,7 @@
       identityPrefix = `${identityPrefix}\nYou are ${this.options.modelLabel}`;
     }
 
-    let promptPrefix = (this.options.promptPrefix ?? '').trim();
+    let promptPrefix = (this.systemMessage ?? '').trim();
 
     if (identityPrefix) {
       promptPrefix = `${identityPrefix}${promptPrefix}`;
@@ -639,7 +641,7 @@
     let error;
     try {
       if (!EXCLUDED_GENAI_MODELS.test(modelName) && !this.project_id) {
-        /** @type {GenAI} */
+        /** @type {GenerativeModel} */
         const client = this.client;
         /** @type {GenerateContentRequest} */
         const requestOptions = {
@@ -648,7 +650,7 @@
           generationConfig: googleGenConfigSchema.parse(this.modelOptions),
         };
 
-        const promptPrefix = (this.options.promptPrefix ?? '').trim();
+        const promptPrefix = (this.systemMessage ?? '').trim();
         if (promptPrefix.length) {
           requestOptions.systemInstruction = {
             parts: [
@@ -663,7 +665,17 @@
         /** @type {GenAIUsageMetadata} */
         let usageMetadata;
 
-        const result = await client.generateContentStream(requestOptions);
+        abortController.signal.addEventListener(
+          'abort',
+          () => {
+            logger.warn('[GoogleClient] Request was aborted', abortController.signal.reason);
+          },
+          { once: true },
+        );
+
+        const result = await client.generateContentStream(requestOptions, {
+          signal: abortController.signal,
+        });
         for await (const chunk of result.stream) {
           usageMetadata = !usageMetadata
             ? chunk?.usageMetadata
@@ -815,7 +827,8 @@
 
     let reply = '';
     const { abortController } = options;
-    const model = this.modelOptions.modelName ?? this.modelOptions.model ?? '';
+    const model =
+      this.options.titleModel ?? this.modelOptions.modelName ?? this.modelOptions.model ?? '';
     const safetySettings = getSafetySettings(model);
     if (!EXCLUDED_GENAI_MODELS.test(model) && !this.project_id) {
       logger.debug('Identified titling model as GenAI version');
diff --git a/api/app/clients/OllamaClient.js b/api/app/clients/OllamaClient.js
index d86e120f43..77d007580c 100644
--- a/api/app/clients/OllamaClient.js
+++ b/api/app/clients/OllamaClient.js
@@ -2,7 +2,7 @@ const { z } = require('zod');
 const axios = require('axios');
 const { Ollama } = require('ollama');
 const { Constants } = require('librechat-data-provider');
-const { deriveBaseURL } = require('~/utils');
+const { deriveBaseURL, logAxiosError } = require('~/utils');
 const { sleep } = require('~/server/utils');
 const { logger } = require('~/config');
 
@@ -68,7 +68,7 @@ class OllamaClient {
     } catch (error) {
       const logMessage =
         'Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn\'t start with `ollama` (case-insensitive).';
-      logger.error(logMessage, error);
+      logAxiosError({ message: logMessage, error });
       return [];
     }
   }
diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js
index 368e7d6e84..b4477cca8a 100644
--- a/api/app/clients/OpenAIClient.js
+++ b/api/app/clients/OpenAIClient.js
@@ -7,6 +7,7 @@ const {
   ImageDetail,
   EModelEndpoint,
   resolveHeaders,
+  KnownEndpoints,
   openAISettings,
   ImageDetailCost,
   CohereConstants,
@@ -108,18 +109,14 @@ class OpenAIClient extends BaseClient {
     const omniPattern = /\b(o1|o3)\b/i;
     this.isOmni = omniPattern.test(this.modelOptions.model);
 
-    const { OPENROUTER_API_KEY, OPENAI_FORCE_PROMPT } = process.env ?? {};
-    if (OPENROUTER_API_KEY && !this.azure) {
-      this.apiKey = OPENROUTER_API_KEY;
-      this.useOpenRouter = true;
-    }
-
+    const { OPENAI_FORCE_PROMPT } = process.env ?? {};
     const { reverseProxyUrl: reverseProxy } = this.options;
 
     if (
       !this.useOpenRouter &&
-      reverseProxy &&
-      reverseProxy.includes('https://openrouter.ai/api/v1')
+      ((reverseProxy && reverseProxy.includes(KnownEndpoints.openrouter)) ||
+        (this.options.endpoint &&
+          this.options.endpoint.toLowerCase().includes(KnownEndpoints.openrouter)))
     ) {
       this.useOpenRouter = true;
     }
@@ -306,7 +303,9 @@
   }
 
   getEncoding() {
-    return this.model?.includes('gpt-4o') ? 'o200k_base' : 'cl100k_base';
+    return this.modelOptions?.model && /gpt-4[^-\s]/.test(this.modelOptions.model)
+      ? 'o200k_base'
+      : 'cl100k_base';
   }
 
   /**
@@ -613,7 +612,7 @@
   }
 
   initializeLLM({
-    model = 'gpt-4o-mini',
+    model = openAISettings.model.default,
     modelName,
     temperature = 0.2,
     max_tokens,
@@ -714,7 +713,7 @@
 
     const { OPENAI_TITLE_MODEL } = process.env ?? {};
 
-    let model = this.options.titleModel ?? OPENAI_TITLE_MODEL ?? 'gpt-4o-mini';
+    let model = this.options.titleModel ?? OPENAI_TITLE_MODEL ?? openAISettings.model.default;
     if (model === Constants.CURRENT_MODEL) {
       model = this.modelOptions.model;
     }
@@ -907,7 +906,7 @@ ${convo}
 
     let prompt;
 
     // TODO: remove the gpt fallback and make it specific to endpoint
-    const { OPENAI_SUMMARY_MODEL = 'gpt-4o-mini' } = process.env ?? {};
+    const { OPENAI_SUMMARY_MODEL = openAISettings.model.default } = process.env ?? {};
     let model = this.options.summaryModel ?? OPENAI_SUMMARY_MODEL;
     if (model === Constants.CURRENT_MODEL) {
       model = this.modelOptions.model;
     }
@@ -1108,6 +1107,16 @@ ${convo}
     return (msg) => {
       if (msg.text != null && msg.text && msg.text.startsWith(':::thinking')) {
         msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim();
+      } else if (msg.content != null) {
+        /** @type {import('@librechat/agents').MessageContentComplex} */
+        const newContent = [];
+        for (let part of msg.content) {
+          if (part.think != null) {
+            continue;
+          }
+          newContent.push(part);
+        }
+        msg.content = newContent;
       }
 
       return msg;
@@ -1273,6 +1282,29 @@ ${convo}
       });
     }
 
+    /** Note: OpenAI Web Search models do not support any known parameters besides `max_tokens` */
+    if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) {
+      const searchExcludeParams = [
+        'frequency_penalty',
+        'presence_penalty',
+        'temperature',
+        'top_p',
+        'top_k',
+        'stop',
+        'logit_bias',
+        'seed',
+        'response_format',
+        'n',
+        'logprobs',
+        'user',
+      ];
+
+      this.options.dropParams = this.options.dropParams || [];
+      this.options.dropParams = [
+        ...new Set([...this.options.dropParams, ...searchExcludeParams]),
+      ];
+    }
+
     if (this.options.dropParams && Array.isArray(this.options.dropParams)) {
       this.options.dropParams.forEach((param) => {
         delete modelOptions[param];
@@ -1308,8 +1340,12 @@ ${convo}
     ) {
       delete modelOptions.stream;
       delete modelOptions.stop;
-    } else if (!this.isOmni && modelOptions.reasoning_effort != null) {
+    } else if (
+      (!this.isOmni || /^o1-(mini|preview)/i.test(modelOptions.model)) &&
+      modelOptions.reasoning_effort != null
+    ) {
       delete modelOptions.reasoning_effort;
+      delete modelOptions.temperature;
     }
 
     let reasoningKey = 'reasoning_content';
@@ -1317,6 +1353,12 @@
       modelOptions.include_reasoning = true;
       reasoningKey = 'reasoning';
     }
+    if (this.useOpenRouter && modelOptions.reasoning_effort != null) {
+      modelOptions.reasoning = {
+        effort: modelOptions.reasoning_effort,
+      };
+      delete modelOptions.reasoning_effort;
+    }
 
     this.streamHandler = new SplitStreamHandler({
       reasoningKey,
diff --git a/api/app/clients/prompts/addCacheControl.js b/api/app/clients/prompts/addCacheControl.js
index eed5910dc9..6bfd901a65 100644
--- a/api/app/clients/prompts/addCacheControl.js
+++ b/api/app/clients/prompts/addCacheControl.js
@@ -1,7 +1,7 @@
 /**
  * Anthropic API: Adds cache control to the appropriate user messages in the payload.
- * @param {Array} messages - The array of message objects.
- * @returns {Array} - The updated array of message objects with cache control added.
+ * @param {Array<AnthropicMessage | BaseMessage>} messages - The array of message objects.
+ * @returns {Array<AnthropicMessage | BaseMessage>} - The updated array of message objects with cache control added.
  */
 function addCacheControl(messages) {
   if (!Array.isArray(messages) || messages.length < 2) {
@@ -13,7 +13,9 @@ function addCacheControl(messages) {
   for (let i = updatedMessages.length - 1; i >= 0 && userMessagesModified < 2; i--) {
     const message = updatedMessages[i];
-    if (message.role !== 'user') {
+    if (message.getType != null && message.getType() !== 'human') {
+      continue;
+    } else if (message.getType == null && message.role !== 'user') {
       continue;
     }
 
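The new guard lets `addCacheControl` accept both plain Anthropic payload objects (`role: 'user'`) and LangChain message instances (`getType() === 'human'`). A toy run with both shapes, showing that the dual check selects the same messages (the actual `cache_control` placement lives in the unchanged part of the function):

```js
// Plain payload message: the role-based check applies.
const plain = { role: 'user', content: [{ type: 'text', text: 'hi' }] };

// LangChain-style message: getType() takes precedence over role.
const langchain = {
  getType: () => 'human',
  content: [{ type: 'text', text: 'hi again' }],
};

for (const message of [plain, langchain]) {
  const isUser =
    message.getType != null ? message.getType() === 'human' : message.role === 'user';
  console.log(isUser); // true for both shapes
}
```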
diff --git a/api/app/clients/prompts/formatAgentMessages.spec.js b/api/app/clients/prompts/formatAgentMessages.spec.js
index 20731f6984..360fa00a34 100644
--- a/api/app/clients/prompts/formatAgentMessages.spec.js
+++ b/api/app/clients/prompts/formatAgentMessages.spec.js
@@ -282,4 +282,80 @@ describe('formatAgentMessages', () => {
     // Additional check to ensure the consecutive assistant messages were combined
     expect(result[1].content).toHaveLength(2);
   });
+
+  it('should skip THINK type content parts', () => {
+    const payload = [
+      {
+        role: 'assistant',
+        content: [
+          { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Initial response' },
+          { type: ContentTypes.THINK, [ContentTypes.THINK]: 'Reasoning about the problem...' },
+          { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Final answer' },
+        ],
+      },
+    ];
+
+    const result = formatAgentMessages(payload);
+
+    expect(result).toHaveLength(1);
+    expect(result[0]).toBeInstanceOf(AIMessage);
+    expect(result[0].content).toEqual('Initial response\nFinal answer');
+  });
+
+  it('should join TEXT content as string when THINK content type is present', () => {
+    const payload = [
+      {
+        role: 'assistant',
+        content: [
+          { type: ContentTypes.THINK, [ContentTypes.THINK]: 'Analyzing the problem...' },
+          { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'First part of response' },
+          { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Second part of response' },
+          { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Final part of response' },
+        ],
+      },
+    ];
+
+    const result = formatAgentMessages(payload);
+
+    expect(result).toHaveLength(1);
+    expect(result[0]).toBeInstanceOf(AIMessage);
+    expect(typeof result[0].content).toBe('string');
+    expect(result[0].content).toBe(
+      'First part of response\nSecond part of response\nFinal part of response',
+    );
+    expect(result[0].content).not.toContain('Analyzing the problem...');
+  });
+
+  it('should exclude ERROR type content parts', () => {
+    const payload = [
+      {
+        role: 'assistant',
+        content: [
+          { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Hello there' },
+          {
+            type: ContentTypes.ERROR,
+            [ContentTypes.ERROR]:
+              'An error occurred while processing the request: Something went wrong',
+          },
+          { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Final answer' },
+        ],
+      },
+    ];
+
+    const result = formatAgentMessages(payload);
+
+    expect(result).toHaveLength(1);
+    expect(result[0]).toBeInstanceOf(AIMessage);
+    expect(result[0].content).toEqual([
+      { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Hello there' },
+      { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Final answer' },
+    ]);
+
+    // Make sure no error content exists in the result
+    const hasErrorContent = result[0].content.some(
+      (item) =>
+        item.type === ContentTypes.ERROR || JSON.stringify(item).includes('An error occurred'),
+    );
+    expect(hasErrorContent).toBe(false);
+  });
 });
diff --git a/api/app/clients/prompts/formatMessages.js b/api/app/clients/prompts/formatMessages.js
index d84e62cca8..9fa0d40497 100644
--- a/api/app/clients/prompts/formatMessages.js
+++ b/api/app/clients/prompts/formatMessages.js
@@ -153,6 +153,7 @@ const formatAgentMessages = (payload) => {
     let currentContent = [];
     let lastAIMessage = null;
+    let hasReasoning = false;
 
     for (const part of message.content) {
       if (part.type === ContentTypes.TEXT && part.tool_call_ids) {
         /*
@@ -207,11 +208,27 @@ const formatAgentMessages = (payload) => {
             content: output || '',
           }),
         );
+      } else if (part.type === ContentTypes.THINK) {
+        hasReasoning = true;
+        continue;
+      } else if (part.type === ContentTypes.ERROR || part.type === ContentTypes.AGENT_UPDATE) {
+        continue;
       } else {
         currentContent.push(part);
       }
     }
 
+    if (hasReasoning) {
+      currentContent = currentContent
+        .reduce((acc, curr) => {
+          if (curr.type === ContentTypes.TEXT) {
+            return `${acc}${curr[ContentTypes.TEXT]}\n`;
+          }
+          return acc;
+        }, '')
+        .trim();
+    }
+
     if (currentContent.length > 0) {
       messages.push(new AIMessage({ content: currentContent }));
     }
diff --git a/api/app/clients/specs/AnthropicClient.test.js b/api/app/clients/specs/AnthropicClient.test.js
index eef6bb6748..223f3038c0 100644
--- a/api/app/clients/specs/AnthropicClient.test.js
+++ b/api/app/clients/specs/AnthropicClient.test.js
@@ -1,3 +1,4 @@
+const { SplitStreamHandler } = require('@librechat/agents');
 const { anthropicSettings } = require('librechat-data-provider');
 const AnthropicClient = require('~/app/clients/AnthropicClient');
 
@@ -405,4 +406,327 @@ describe('AnthropicClient', () => {
       expect(Number.isNaN(result)).toBe(false);
     });
   });
+
+  describe('maxOutputTokens handling for different models', () => {
+    it('should not cap maxOutputTokens for Claude 3.5 Sonnet models', () => {
+      const client = new AnthropicClient('test-api-key');
+      const highTokenValue = anthropicSettings.legacy.maxOutputTokens.default * 10;
+
+      client.setOptions({
+        modelOptions: {
+          model: 'claude-3-5-sonnet',
+          maxOutputTokens: highTokenValue,
+        },
+      });
+
+      expect(client.modelOptions.maxOutputTokens).toBe(highTokenValue);
+
+      // Test with decimal notation
+      client.setOptions({
+        modelOptions: {
+          model: 'claude-3.5-sonnet',
+          maxOutputTokens: highTokenValue,
+        },
+      });
+
+      expect(client.modelOptions.maxOutputTokens).toBe(highTokenValue);
+    });
+
+    it('should not cap maxOutputTokens for Claude 3.7 models', () => {
+      const client = new AnthropicClient('test-api-key');
+      const highTokenValue = anthropicSettings.legacy.maxOutputTokens.default * 2;
+
+      client.setOptions({
+        modelOptions: {
+          model: 'claude-3-7-sonnet',
+          maxOutputTokens: highTokenValue,
+        },
+      });
+
+      expect(client.modelOptions.maxOutputTokens).toBe(highTokenValue);
+
+      // Test with decimal notation
+      client.setOptions({
+        modelOptions: {
+          model: 'claude-3.7-sonnet',
+          maxOutputTokens: highTokenValue,
+        },
+      });
+
+      expect(client.modelOptions.maxOutputTokens).toBe(highTokenValue);
+    });
+
+    it('should cap maxOutputTokens for Claude 3.5 Haiku models', () => {
+      const client = new AnthropicClient('test-api-key');
+      const highTokenValue = anthropicSettings.legacy.maxOutputTokens.default * 2;
+
+      client.setOptions({
+        modelOptions: {
+          model: 'claude-3-5-haiku',
+          maxOutputTokens: highTokenValue,
+        },
+      });
+
+      expect(client.modelOptions.maxOutputTokens).toBe(
+        anthropicSettings.legacy.maxOutputTokens.default,
+      );
+
+      // Test with decimal notation
+      client.setOptions({
+        modelOptions: {
+          model: 'claude-3.5-haiku',
+          maxOutputTokens: highTokenValue,
+        },
+      });
+
+      expect(client.modelOptions.maxOutputTokens).toBe(
+        anthropicSettings.legacy.maxOutputTokens.default,
+      );
+    });
+
+    it('should cap maxOutputTokens for Claude 3 Haiku and Opus models', () => {
+      const client = new AnthropicClient('test-api-key');
+      const highTokenValue = anthropicSettings.legacy.maxOutputTokens.default * 2;
+
+      // Test haiku
+      client.setOptions({
+        modelOptions: {
+          model: 'claude-3-haiku',
+          maxOutputTokens: highTokenValue,
+        },
+      });
+
+      expect(client.modelOptions.maxOutputTokens).toBe(
+        anthropicSettings.legacy.maxOutputTokens.default,
+      );
+
+      // Test opus
+      client.setOptions({
+        modelOptions: {
+          model: 'claude-3-opus',
+          maxOutputTokens: highTokenValue,
+        },
+      });
+
+      expect(client.modelOptions.maxOutputTokens).toBe(
+        anthropicSettings.legacy.maxOutputTokens.default,
+      );
+    });
+  });
+
+  describe('topK/topP parameters for different models', () => {
+    beforeEach(() => {
+      // Mock the SplitStreamHandler
+      jest.spyOn(SplitStreamHandler.prototype, 'handle').mockImplementation(() => {});
+    });
+
+    afterEach(() => {
+      jest.restoreAllMocks();
+    });
+
+    it('should include top_k and top_p parameters for non-claude-3.7 models', async () => {
+      const client = new AnthropicClient('test-api-key');
+
+      // Create a mock async generator function
+      async function* mockAsyncGenerator() {
+        yield { type: 'message_start', message: { usage: {} } };
+        yield { delta: { text: 'Test response' } };
+        yield { type: 'message_delta', usage: {} };
+      }
+
+      // Mock createResponse to return the async generator
+      jest.spyOn(client, 'createResponse').mockImplementation(() => {
+        return mockAsyncGenerator();
+      });
+
+      client.setOptions({
+        modelOptions: {
+          model: 'claude-3-opus',
+          temperature: 0.7,
+          topK: 10,
+          topP: 0.9,
+        },
+      });
+
+      // Mock getClient to capture the request options
+      let capturedOptions = null;
+      jest.spyOn(client, 'getClient').mockImplementation((options) => {
+        capturedOptions = options;
+        return {};
+      });
+
+      const payload = [{ role: 'user', content: 'Test message' }];
+      await client.sendCompletion(payload, {});
+
+      // Check the options passed to getClient
+      expect(capturedOptions).toHaveProperty('top_k', 10);
+      expect(capturedOptions).toHaveProperty('top_p', 0.9);
+    });
+
+    it('should include top_k and top_p parameters for claude-3-5-sonnet models', async () => {
+      const client = new AnthropicClient('test-api-key');
+
+      // Create a mock async generator function
+      async function* mockAsyncGenerator() {
+        yield { type: 'message_start', message: { usage: {} } };
+        yield { delta: { text: 'Test response' } };
+        yield { type: 'message_delta', usage: {} };
+      }
+
+      // Mock createResponse to return the async generator
+      jest.spyOn(client, 'createResponse').mockImplementation(() => {
+        return mockAsyncGenerator();
+      });
+
+      client.setOptions({
+        modelOptions: {
+          model: 'claude-3-5-sonnet',
+          temperature: 0.7,
+          topK: 10,
+          topP: 0.9,
+        },
+      });
+
+      // Mock getClient to capture the request options
+      let capturedOptions = null;
+      jest.spyOn(client, 'getClient').mockImplementation((options) => {
+        capturedOptions = options;
+        return {};
+      });
+
+      const payload = [{ role: 'user', content: 'Test message' }];
+      await client.sendCompletion(payload, {});
+
+      // Check the options passed to getClient
+      expect(capturedOptions).toHaveProperty('top_k', 10);
+      expect(capturedOptions).toHaveProperty('top_p', 0.9);
+    });
+
+    it('should not include top_k and top_p parameters for claude-3-7-sonnet models', async () => {
+      const client = new AnthropicClient('test-api-key');
+
+      // Create a mock async generator function
+      async function* mockAsyncGenerator() {
+        yield { type: 'message_start', message: { usage: {} } };
+        yield { delta: { text: 'Test response' } };
+        yield { type: 'message_delta', usage: {} };
+      }
+ + // Mock createResponse to return the async generator + jest.spyOn(client, 'createResponse').mockImplementation(() => { + return mockAsyncGenerator(); + }); + + client.setOptions({ + modelOptions: { + model: 'claude-3-7-sonnet', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + }); + + // Mock getClient to capture the request options + let capturedOptions = null; + jest.spyOn(client, 'getClient').mockImplementation((options) => { + capturedOptions = options; + return {}; + }); + + const payload = [{ role: 'user', content: 'Test message' }]; + await client.sendCompletion(payload, {}); + + // Check the options passed to getClient + expect(capturedOptions).not.toHaveProperty('top_k'); + expect(capturedOptions).not.toHaveProperty('top_p'); + }); + + it('should not include top_k and top_p parameters for models with decimal notation (claude-3.7)', async () => { + const client = new AnthropicClient('test-api-key'); + + // Create a mock async generator function + async function* mockAsyncGenerator() { + yield { type: 'message_start', message: { usage: {} } }; + yield { delta: { text: 'Test response' } }; + yield { type: 'message_delta', usage: {} }; + } + + // Mock createResponse to return the async generator + jest.spyOn(client, 'createResponse').mockImplementation(() => { + return mockAsyncGenerator(); + }); + + client.setOptions({ + modelOptions: { + model: 'claude-3.7-sonnet', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + }); + + // Mock getClient to capture the request options + let capturedOptions = null; + jest.spyOn(client, 'getClient').mockImplementation((options) => { + capturedOptions = options; + return {}; + }); + + const payload = [{ role: 'user', content: 'Test message' }]; + await client.sendCompletion(payload, {}); + + // Check the options passed to getClient + expect(capturedOptions).not.toHaveProperty('top_k'); + expect(capturedOptions).not.toHaveProperty('top_p'); + }); + }); + + it('should include top_k and top_p parameters for Claude-3.7 models when thinking is explicitly disabled', async () => { + const client = new AnthropicClient('test-api-key', { + modelOptions: { + model: 'claude-3-7-sonnet', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + thinking: false, + }); + + async function* mockAsyncGenerator() { + yield { type: 'message_start', message: { usage: {} } }; + yield { delta: { text: 'Test response' } }; + yield { type: 'message_delta', usage: {} }; + } + + jest.spyOn(client, 'createResponse').mockImplementation(() => { + return mockAsyncGenerator(); + }); + + let capturedOptions = null; + jest.spyOn(client, 'getClient').mockImplementation((options) => { + capturedOptions = options; + return {}; + }); + + const payload = [{ role: 'user', content: 'Test message' }]; + await client.sendCompletion(payload, {}); + + expect(capturedOptions).toHaveProperty('topK', 10); + expect(capturedOptions).toHaveProperty('topP', 0.9); + + client.setOptions({ + modelOptions: { + model: 'claude-3.7-sonnet', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + thinking: false, + }); + + await client.sendCompletion(payload, {}); + + expect(capturedOptions).toHaveProperty('topK', 10); + expect(capturedOptions).toHaveProperty('topP', 0.9); + }); }); diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index e899449fb9..c9be50d3de 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -30,6 +30,8 @@ jest.mock('~/models', () => ({ updateFileUsage: jest.fn(), })); +const { getConvo, 
saveConvo } = require('~/models'); + jest.mock('@langchain/openai', () => { return { ChatOpenAI: jest.fn().mockImplementation(() => { @@ -162,7 +164,7 @@ describe('BaseClient', () => { const result = await TestClient.getMessagesWithinTokenLimit({ messages }); expect(result.context).toEqual(expectedContext); - expect(result.summaryIndex).toEqual(expectedIndex); + expect(result.messagesToRefine.length - 1).toEqual(expectedIndex); expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens); expect(result.messagesToRefine).toEqual(expectedMessagesToRefine); }); @@ -198,7 +200,7 @@ describe('BaseClient', () => { const result = await TestClient.getMessagesWithinTokenLimit({ messages }); expect(result.context).toEqual(expectedContext); - expect(result.summaryIndex).toEqual(expectedIndex); + expect(result.messagesToRefine.length - 1).toEqual(expectedIndex); expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens); expect(result.messagesToRefine).toEqual(expectedMessagesToRefine); }); @@ -540,10 +542,11 @@ describe('BaseClient', () => { test('saveMessageToDatabase is called with the correct arguments', async () => { const saveOptions = TestClient.getSaveOptions(); - const user = {}; // Mock user + const user = {}; const opts = { user }; + const saveSpy = jest.spyOn(TestClient, 'saveMessageToDatabase'); await TestClient.sendMessage('Hello, world!', opts); - expect(TestClient.saveMessageToDatabase).toHaveBeenCalledWith( + expect(saveSpy).toHaveBeenCalledWith( expect.objectContaining({ sender: expect.any(String), text: expect.any(String), @@ -557,6 +560,157 @@ describe('BaseClient', () => { ); }); + test('should handle existing conversation when getConvo retrieves one', async () => { + const existingConvo = { + conversationId: 'existing-convo-id', + endpoint: 'openai', + endpointType: 'openai', + model: 'gpt-3.5-turbo', + messages: [ + { role: 'user', content: 'Existing message 1' }, + { role: 'assistant', content: 'Existing response 1' }, + ], + temperature: 1, + }; + + const { temperature: _temp, ...newConvo } = existingConvo; + + const user = { + id: 'user-id', + }; + + getConvo.mockResolvedValue(existingConvo); + saveConvo.mockResolvedValue(newConvo); + + TestClient = initializeFakeClient( + apiKey, + { + ...options, + req: { + user, + }, + }, + [], + ); + + const saveSpy = jest.spyOn(TestClient, 'saveMessageToDatabase'); + + const newMessage = 'New message in existing conversation'; + const response = await TestClient.sendMessage(newMessage, { + user, + conversationId: existingConvo.conversationId, + }); + + expect(getConvo).toHaveBeenCalledWith(user.id, existingConvo.conversationId); + expect(TestClient.conversationId).toBe(existingConvo.conversationId); + expect(response.conversationId).toBe(existingConvo.conversationId); + expect(TestClient.fetchedConvo).toBe(true); + + expect(saveSpy).toHaveBeenCalledWith( + expect.objectContaining({ + conversationId: existingConvo.conversationId, + text: newMessage, + }), + expect.any(Object), + expect.any(Object), + ); + + expect(saveConvo).toHaveBeenCalledTimes(2); + expect(saveConvo).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + conversationId: existingConvo.conversationId, + }), + expect.objectContaining({ + context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo', + unsetFields: { + temperature: 1, + }, + }), + ); + + await TestClient.sendMessage('Another message', { + conversationId: existingConvo.conversationId, + }); + expect(getConvo).toHaveBeenCalledTimes(1); + }); + + 
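+    // A minimal illustrative sketch (not part of the original suite): the
+    // `unsetFields` metadata asserted above is expected to map onto a Mongo
+    // `$unset` paired with `$set` in a single update, per the Conversation.js
+    // change later in this diff. Values here are hypothetical.
+    test.skip('sketch: unsetFields pair with $set as a $unset operation', () => {
+      const update = { conversationId: 'existing-convo-id', model: 'gpt-4' };
+      const metadata = { unsetFields: { temperature: 1 } };
+      const updateOperation = { $set: update };
+      if (metadata.unsetFields && Object.keys(metadata.unsetFields).length > 0) {
+        updateOperation.$unset = metadata.unsetFields;
+      }
+      expect(updateOperation).toEqual({
+        $set: { conversationId: 'existing-convo-id', model: 'gpt-4' },
+        $unset: { temperature: 1 },
+      });
+    });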
test('should correctly handle existing conversation and unset fields appropriately', async () => { + const existingConvo = { + conversationId: 'existing-convo-id', + endpoint: 'openai', + endpointType: 'openai', + model: 'gpt-3.5-turbo', + messages: [ + { role: 'user', content: 'Existing message 1' }, + { role: 'assistant', content: 'Existing response 1' }, + ], + title: 'Existing Conversation', + someExistingField: 'existingValue', + anotherExistingField: 'anotherValue', + temperature: 0.7, + modelLabel: 'GPT-3.5', + }; + + getConvo.mockResolvedValue(existingConvo); + saveConvo.mockResolvedValue(existingConvo); + + TestClient = initializeFakeClient( + apiKey, + { + ...options, + modelOptions: { + model: 'gpt-4', + temperature: 0.5, + }, + }, + [], + ); + + const newMessage = 'New message in existing conversation'; + await TestClient.sendMessage(newMessage, { + conversationId: existingConvo.conversationId, + }); + + expect(saveConvo).toHaveBeenCalledTimes(2); + + const saveConvoCall = saveConvo.mock.calls[0]; + const [, savedFields, saveOptions] = saveConvoCall; + + // Instead of checking all excludedKeys, we'll just check specific fields + // that we know should be excluded + expect(savedFields).not.toHaveProperty('messages'); + expect(savedFields).not.toHaveProperty('title'); + + // Only check that someExistingField is in unsetFields + expect(saveOptions.unsetFields).toHaveProperty('someExistingField', 1); + + // Mock saveConvo to return the expected fields + saveConvo.mockImplementation((req, fields) => { + return Promise.resolve({ + ...fields, + endpoint: 'openai', + endpointType: 'openai', + model: 'gpt-4', + temperature: 0.5, + }); + }); + + // Only check the conversationId since that's the only field we can be sure about + expect(savedFields).toHaveProperty('conversationId', 'existing-convo-id'); + + expect(TestClient.fetchedConvo).toBe(true); + + await TestClient.sendMessage('Another message', { + conversationId: existingConvo.conversationId, + }); + + expect(getConvo).toHaveBeenCalledTimes(1); + + const secondSaveConvoCall = saveConvo.mock.calls[1]; + expect(secondSaveConvoCall[2]).toHaveProperty('unsetFields', {}); + }); + test('sendCompletion is called with the correct arguments', async () => { const payload = {}; // Mock payload TestClient.buildMessages.mockReturnValue({ prompt: payload, tokenCountMap: null }); diff --git a/api/app/clients/specs/FakeClient.js b/api/app/clients/specs/FakeClient.js index 7f4b75e1db..a466bb97f9 100644 --- a/api/app/clients/specs/FakeClient.js +++ b/api/app/clients/specs/FakeClient.js @@ -56,7 +56,6 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => { let TestClient = new FakeClient(apiKey); TestClient.options = options; TestClient.abortController = { abort: jest.fn() }; - TestClient.saveMessageToDatabase = jest.fn(); TestClient.loadHistory = jest .fn() .mockImplementation((conversationId, parentMessageId = null) => { @@ -86,7 +85,6 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => { return 'Mock response text'; }); - // eslint-disable-next-line no-unused-vars TestClient.getCompletion = jest.fn().mockImplementation(async (..._args) => { return { choices: [ diff --git a/api/app/clients/specs/OpenAIClient.test.js b/api/app/clients/specs/OpenAIClient.test.js index 2aaec518eb..0e811cf38a 100644 --- a/api/app/clients/specs/OpenAIClient.test.js +++ b/api/app/clients/specs/OpenAIClient.test.js @@ -202,14 +202,6 @@ describe('OpenAIClient', () => { expect(client.modelOptions.temperature).toBe(0.7); }); - it('should set 
apiKey and useOpenRouter if OPENROUTER_API_KEY is present', () => { - process.env.OPENROUTER_API_KEY = 'openrouter-key'; - client.setOptions({}); - expect(client.apiKey).toBe('openrouter-key'); - expect(client.useOpenRouter).toBe(true); - delete process.env.OPENROUTER_API_KEY; // Cleanup - }); - it('should set FORCE_PROMPT based on OPENAI_FORCE_PROMPT or reverseProxyUrl', () => { process.env.OPENAI_FORCE_PROMPT = 'true'; client.setOptions({}); @@ -534,7 +526,6 @@ describe('OpenAIClient', () => { afterEach(() => { delete process.env.AZURE_OPENAI_DEFAULT_MODEL; delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME; - delete process.env.OPENROUTER_API_KEY; }); it('should call getCompletion and fetchEventSource when using a text/instruct model', async () => { diff --git a/api/app/clients/tools/index.js b/api/app/clients/tools/index.js index b8df50c77d..df436fb089 100644 --- a/api/app/clients/tools/index.js +++ b/api/app/clients/tools/index.js @@ -2,9 +2,10 @@ const availableTools = require('./manifest.json'); // Structured Tools const DALLE3 = require('./structured/DALLE3'); +const FluxAPI = require('./structured/FluxAPI'); const OpenWeather = require('./structured/OpenWeather'); -const createYouTubeTools = require('./structured/YouTube'); const StructuredWolfram = require('./structured/Wolfram'); +const createYouTubeTools = require('./structured/YouTube'); const StructuredACS = require('./structured/AzureAISearch'); const StructuredSD = require('./structured/StableDiffusion'); const GoogleSearchAPI = require('./structured/GoogleSearch'); @@ -30,6 +31,7 @@ module.exports = { manifestToolMap, // Structured Tools DALLE3, + FluxAPI, OpenWeather, StructuredSD, StructuredACS, diff --git a/api/app/clients/tools/manifest.json b/api/app/clients/tools/manifest.json index 7cb92b8d87..43be7a4e6c 100644 --- a/api/app/clients/tools/manifest.json +++ b/api/app/clients/tools/manifest.json @@ -164,5 +164,19 @@ "description": "Sign up at OpenWeather, then get your key at API keys." } ] + }, + { + "name": "Flux", + "pluginKey": "flux", + "description": "Generate images using text with the Flux API.", + "icon": "https://blackforestlabs.ai/wp-content/uploads/2024/07/bfl_logo_retraced_blk.png", + "isAuthRequired": "true", + "authConfig": [ + { + "authField": "FLUX_API_KEY", + "label": "Your Flux API Key", + "description": "Provide your Flux API key from your user profile." + } + ] } ] diff --git a/api/app/clients/tools/structured/DALLE3.js b/api/app/clients/tools/structured/DALLE3.js index b604ad4ea4..fc0f1851f6 100644 --- a/api/app/clients/tools/structured/DALLE3.js +++ b/api/app/clients/tools/structured/DALLE3.js @@ -1,14 +1,17 @@ const { z } = require('zod'); const path = require('path'); const OpenAI = require('openai'); +const fetch = require('node-fetch'); const { v4: uuidv4 } = require('uuid'); const { Tool } = require('@langchain/core/tools'); const { HttpsProxyAgent } = require('https-proxy-agent'); -const { FileContext } = require('librechat-data-provider'); +const { FileContext, ContentTypes } = require('librechat-data-provider'); const { getImageBasename } = require('~/server/services/Files/images'); const extractBaseURL = require('~/utils/extractBaseURL'); const { logger } = require('~/config'); +const displayMessage = + 'DALL-E displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. 
The user may download the images by clicking on them, but do not mention anything about downloading to the user.'; class DALLE3 extends Tool { constructor(fields = {}) { super(); @@ -114,10 +117,7 @@ class DALLE3 extends Tool { if (this.isAgent === true && typeof value === 'string') { return [value, {}]; } else if (this.isAgent === true && typeof value === 'object') { - return [ - 'DALL-E displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.', - value, - ]; + return [displayMessage, value]; } return value; @@ -160,6 +160,32 @@ Error Message: ${error.message}`); ); } + if (this.isAgent) { + let fetchOptions = {}; + if (process.env.PROXY) { + fetchOptions.agent = new HttpsProxyAgent(process.env.PROXY); + } + const imageResponse = await fetch(theImageUrl, fetchOptions); + const arrayBuffer = await imageResponse.arrayBuffer(); + const base64 = Buffer.from(arrayBuffer).toString('base64'); + const content = [ + { + type: ContentTypes.IMAGE_URL, + image_url: { + url: `data:image/png;base64,${base64}`, + }, + }, + ]; + + const response = [ + { + type: ContentTypes.TEXT, + text: displayMessage, + }, + ]; + return [response, { content }]; + } + const imageBasename = getImageBasename(theImageUrl); const imageExt = path.extname(imageBasename); diff --git a/api/app/clients/tools/structured/FluxAPI.js b/api/app/clients/tools/structured/FluxAPI.js new file mode 100644 index 0000000000..80f9772200 --- /dev/null +++ b/api/app/clients/tools/structured/FluxAPI.js @@ -0,0 +1,554 @@ +const { z } = require('zod'); +const axios = require('axios'); +const fetch = require('node-fetch'); +const { v4: uuidv4 } = require('uuid'); +const { Tool } = require('@langchain/core/tools'); +const { HttpsProxyAgent } = require('https-proxy-agent'); +const { FileContext, ContentTypes } = require('librechat-data-provider'); +const { logger } = require('~/config'); + +const displayMessage = + 'Flux displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.'; + +/** + * FluxAPI - A tool for generating high-quality images from text prompts using the Flux API. + * Each call generates one image. If multiple images are needed, make multiple consecutive calls with the same or varied prompts. + */ +class FluxAPI extends Tool { + // Pricing constants in USD per image + static PRICING = { + FLUX_PRO_1_1_ULTRA: -0.06, // /v1/flux-pro-1.1-ultra + FLUX_PRO_1_1: -0.04, // /v1/flux-pro-1.1 + FLUX_PRO: -0.05, // /v1/flux-pro + FLUX_DEV: -0.025, // /v1/flux-dev + FLUX_PRO_FINETUNED: -0.06, // /v1/flux-pro-finetuned + FLUX_PRO_1_1_ULTRA_FINETUNED: -0.07, // /v1/flux-pro-1.1-ultra-finetuned + }; + + constructor(fields = {}) { + super(); + + /** @type {boolean} Used to initialize the Tool without necessary variables. */ + this.override = fields.override ?? false; + + this.userId = fields.userId; + this.fileStrategy = fields.fileStrategy; + + /** @type {boolean} **/ + this.isAgent = fields.isAgent; + this.returnMetadata = fields.returnMetadata ?? false; + + if (fields.processFileURL) { + /** @type {processFileURL} Necessary for output to contain all image metadata. 
*/ + this.processFileURL = fields.processFileURL.bind(this); + } + + this.apiKey = fields.FLUX_API_KEY || this.getApiKey(); + + this.name = 'flux'; + this.description = + 'Use Flux to generate images from text descriptions. This tool can generate images and list available finetunes. Each generate call creates one image. For multiple images, make multiple consecutive calls.'; + + this.description_for_model = `// Transform any image description into a detailed, high-quality prompt. Never submit a prompt under 3 sentences. Follow these core rules: + // 1. ALWAYS enhance basic prompts into 5-10 detailed sentences (e.g., "a cat" becomes: "A close-up photo of a sleek Siamese cat with piercing blue eyes. The cat sits elegantly on a vintage leather armchair, its tail curled gracefully around its paws. Warm afternoon sunlight streams through a nearby window, casting gentle shadows across its face and highlighting the subtle variations in its cream and chocolate-point fur. The background is softly blurred, creating a shallow depth of field that draws attention to the cat's expressive features. The overall composition has a peaceful, contemplative mood with a professional photography style.") + // 2. Each prompt MUST be 3-6 descriptive sentences minimum, focusing on visual elements: lighting, composition, mood, and style + // Use action: 'list_finetunes' to see available custom models. When using finetunes, use endpoint: '/v1/flux-pro-finetuned' (default) or '/v1/flux-pro-1.1-ultra-finetuned' for higher quality and aspect ratio.`; + + // Add base URL from environment variable with fallback + this.baseUrl = process.env.FLUX_API_BASE_URL || 'https://api.us1.bfl.ai'; + + // Define the schema for structured input + this.schema = z.object({ + action: z + .enum(['generate', 'list_finetunes', 'generate_finetuned']) + .default('generate') + .describe( + 'Action to perform: "generate" for image generation, "generate_finetuned" for finetuned model generation, "list_finetunes" to get available custom models', + ), + prompt: z + .string() + .optional() + .describe( + 'Text prompt for image generation. Required when action is "generate". Not used for list_finetunes.', + ), + width: z + .number() + .optional() + .describe( + 'Width of the generated image in pixels. Must be a multiple of 32. Default is 1024.', + ), + height: z + .number() + .optional() + .describe( + 'Height of the generated image in pixels. Must be a multiple of 32. Default is 768.', + ), + prompt_upsampling: z + .boolean() + .optional() + .default(false) + .describe('Whether to perform upsampling on the prompt.'), + steps: z + .number() + .int() + .optional() + .describe('Number of steps to run the model for, a number from 1 to 50. Default is 40.'), + seed: z.number().optional().describe('Optional seed for reproducibility.'), + safety_tolerance: z + .number() + .optional() + .default(6) + .describe( + 'Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict.', + ), + endpoint: z + .enum([ + '/v1/flux-pro-1.1', + '/v1/flux-pro', + '/v1/flux-dev', + '/v1/flux-pro-1.1-ultra', + '/v1/flux-pro-finetuned', + '/v1/flux-pro-1.1-ultra-finetuned', + ]) + .optional() + .default('/v1/flux-pro-1.1') + .describe('Endpoint to use for image generation.'), + raw: z + .boolean() + .optional() + .default(false) + .describe( + 'Generate less processed, more natural-looking images. 
Only works for /v1/flux-pro-1.1-ultra.', + ), + finetune_id: z.string().optional().describe('ID of the finetuned model to use'), + finetune_strength: z + .number() + .optional() + .default(1.1) + .describe('Strength of the finetuning effect (typically between 0.1 and 1.2)'), + guidance: z.number().optional().default(2.5).describe('Guidance scale for finetuned models'), + aspect_ratio: z + .string() + .optional() + .default('16:9') + .describe('Aspect ratio for ultra models (e.g., "16:9")'), + }); + } + + getAxiosConfig() { + const config = {}; + if (process.env.PROXY) { + config.httpsAgent = new HttpsProxyAgent(process.env.PROXY); + } + return config; + } + + /** @param {Object|string} value */ + getDetails(value) { + if (typeof value === 'string') { + return value; + } + return JSON.stringify(value, null, 2); + } + + getApiKey() { + const apiKey = process.env.FLUX_API_KEY || ''; + if (!apiKey && !this.override) { + throw new Error('Missing FLUX_API_KEY environment variable.'); + } + return apiKey; + } + + wrapInMarkdown(imageUrl) { + const serverDomain = process.env.DOMAIN_SERVER || 'http://localhost:3080'; + return `![generated image](${serverDomain}${imageUrl})`; + } + + returnValue(value) { + if (this.isAgent === true && typeof value === 'string') { + return [value, {}]; + } else if (this.isAgent === true && typeof value === 'object') { + if (Array.isArray(value)) { + return value; + } + return [displayMessage, value]; + } + return value; + } + + async _call(data) { + const { action = 'generate', ...imageData } = data; + + // Use provided API key for this request if available, otherwise use default + const requestApiKey = this.apiKey || this.getApiKey(); + + // Handle list_finetunes action + if (action === 'list_finetunes') { + return this.getMyFinetunes(requestApiKey); + } + + // Handle finetuned generation + if (action === 'generate_finetuned') { + return this.generateFinetunedImage(imageData, requestApiKey); + } + + // For generate action, ensure prompt is provided + if (!imageData.prompt) { + throw new Error('Missing required field: prompt'); + } + + let payload = { + prompt: imageData.prompt, + prompt_upsampling: imageData.prompt_upsampling || false, + safety_tolerance: imageData.safety_tolerance || 6, + output_format: imageData.output_format || 'png', + }; + + // Add optional parameters if provided + if (imageData.width) { + payload.width = imageData.width; + } + if (imageData.height) { + payload.height = imageData.height; + } + if (imageData.steps) { + payload.steps = imageData.steps; + } + if (imageData.seed !== undefined) { + payload.seed = imageData.seed; + } + if (imageData.raw) { + payload.raw = imageData.raw; + } + + const generateUrl = `${this.baseUrl}${imageData.endpoint || '/v1/flux-pro'}`; + const resultUrl = `${this.baseUrl}/v1/get_result`; + + logger.debug('[FluxAPI] Generating image with payload:', payload); + logger.debug('[FluxAPI] Using endpoint:', generateUrl); + + let taskResponse; + try { + taskResponse = await axios.post(generateUrl, payload, { + headers: { + 'x-key': requestApiKey, + 'Content-Type': 'application/json', + Accept: 'application/json', + }, + ...this.getAxiosConfig(), + }); + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while submitting task:', details); + + return this.returnValue( + `Something went wrong when trying to generate the image. 
The Flux API may be unavailable: + Error Message: ${details}`, + ); + } + + const taskId = taskResponse.data.id; + + // Polling for the result + let status = 'Pending'; + let resultData = null; + while (status !== 'Ready' && status !== 'Error') { + try { + // Wait 2 seconds between polls + await new Promise((resolve) => setTimeout(resolve, 2000)); + const resultResponse = await axios.get(resultUrl, { + headers: { + 'x-key': requestApiKey, + Accept: 'application/json', + }, + params: { id: taskId }, + ...this.getAxiosConfig(), + }); + status = resultResponse.data.status; + + if (status === 'Ready') { + resultData = resultResponse.data.result; + break; + } else if (status === 'Error') { + logger.error('[FluxAPI] Error in task:', resultResponse.data); + return this.returnValue('An error occurred during image generation.'); + } + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while getting result:', details); + return this.returnValue('An error occurred while retrieving the image.'); + } + } + + // If no result data + if (!resultData || !resultData.sample) { + logger.error('[FluxAPI] No image data received from API. Response:', resultData); + return this.returnValue('No image data received from Flux API.'); + } + + // Try saving the image locally + const imageUrl = resultData.sample; + const imageName = `img-${uuidv4()}.png`; + + if (this.isAgent) { + try { + // Fetch the image and convert to base64 + const fetchOptions = {}; + if (process.env.PROXY) { + fetchOptions.agent = new HttpsProxyAgent(process.env.PROXY); + } + const imageResponse = await fetch(imageUrl, fetchOptions); + const arrayBuffer = await imageResponse.arrayBuffer(); + const base64 = Buffer.from(arrayBuffer).toString('base64'); + const content = [ + { + type: ContentTypes.IMAGE_URL, + image_url: { + url: `data:image/png;base64,${base64}`, + }, + }, + ]; + + const response = [ + { + type: ContentTypes.TEXT, + text: displayMessage, + }, + ]; + return [response, { content }]; + } catch (error) { + logger.error('Error processing image for agent:', error); + return this.returnValue(`Failed to process the image. ${error.message}`); + } + } + + try { + logger.debug('[FluxAPI] Saving image:', imageUrl); + const result = await this.processFileURL({ + fileStrategy: this.fileStrategy, + userId: this.userId, + URL: imageUrl, + fileName: imageName, + basePath: 'images', + context: FileContext.image_generation, + }); + + logger.debug('[FluxAPI] Image saved to path:', result.filepath); + + // Calculate cost based on endpoint + /** + * TODO: Cost handling + const endpoint = imageData.endpoint || '/v1/flux-pro'; + const endpointKey = Object.entries(FluxAPI.PRICING).find(([key, _]) => + endpoint.includes(key.toLowerCase().replace(/_/g, '-')), + )?.[0]; + const cost = FluxAPI.PRICING[endpointKey] || 0; + */ + this.result = this.returnMetadata ? result : this.wrapInMarkdown(result.filepath); + return this.returnValue(this.result); + } catch (error) { + const details = this.getDetails(error?.message ?? 'No additional error details.'); + logger.error('Error while saving the image:', details); + return this.returnValue(`Failed to save the image locally. 
${details}`); + } + } + + async getMyFinetunes(apiKey = null) { + const finetunesUrl = `${this.baseUrl}/v1/my_finetunes`; + const detailsUrl = `${this.baseUrl}/v1/finetune_details`; + + try { + const headers = { + 'x-key': apiKey || this.getApiKey(), + 'Content-Type': 'application/json', + Accept: 'application/json', + }; + + // Get list of finetunes + const response = await axios.get(finetunesUrl, { + headers, + ...this.getAxiosConfig(), + }); + const finetunes = response.data.finetunes; + + // Fetch details for each finetune + const finetuneDetails = await Promise.all( + finetunes.map(async (finetuneId) => { + try { + const detailResponse = await axios.get(`${detailsUrl}?finetune_id=${finetuneId}`, { + headers, + ...this.getAxiosConfig(), + }); + return { + id: finetuneId, + ...detailResponse.data, + }; + } catch (error) { + logger.error(`[FluxAPI] Error fetching details for finetune ${finetuneId}:`, error); + return { + id: finetuneId, + error: 'Failed to fetch details', + }; + } + }), + ); + + if (this.isAgent) { + const formattedDetails = JSON.stringify(finetuneDetails, null, 2); + return [`Here are the available finetunes:\n${formattedDetails}`, null]; + } + return JSON.stringify(finetuneDetails); + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while getting finetunes:', details); + const errorMsg = `Failed to get finetunes: ${details}`; + return this.isAgent ? this.returnValue([errorMsg, {}]) : new Error(errorMsg); + } + } + + async generateFinetunedImage(imageData, requestApiKey) { + if (!imageData.prompt) { + throw new Error('Missing required field: prompt'); + } + + if (!imageData.finetune_id) { + throw new Error( + 'Missing required field: finetune_id for finetuned generation. Please supply a finetune_id!', + ); + } + + // Validate endpoint is appropriate for finetuned generation + const validFinetunedEndpoints = ['/v1/flux-pro-finetuned', '/v1/flux-pro-1.1-ultra-finetuned']; + const endpoint = imageData.endpoint || '/v1/flux-pro-finetuned'; + + if (!validFinetunedEndpoints.includes(endpoint)) { + throw new Error( + `Invalid endpoint for finetuned generation. 
Must be one of: ${validFinetunedEndpoints.join(', ')}`, + ); + } + + let payload = { + prompt: imageData.prompt, + prompt_upsampling: imageData.prompt_upsampling || false, + safety_tolerance: imageData.safety_tolerance || 6, + output_format: imageData.output_format || 'png', + finetune_id: imageData.finetune_id, + finetune_strength: imageData.finetune_strength || 1.0, + guidance: imageData.guidance || 2.5, + }; + + // Add optional parameters if provided + if (imageData.width) { + payload.width = imageData.width; + } + if (imageData.height) { + payload.height = imageData.height; + } + if (imageData.steps) { + payload.steps = imageData.steps; + } + if (imageData.seed !== undefined) { + payload.seed = imageData.seed; + } + if (imageData.raw) { + payload.raw = imageData.raw; + } + + const generateUrl = `${this.baseUrl}${endpoint}`; + const resultUrl = `${this.baseUrl}/v1/get_result`; + + logger.debug('[FluxAPI] Generating finetuned image with payload:', payload); + logger.debug('[FluxAPI] Using endpoint:', generateUrl); + + let taskResponse; + try { + taskResponse = await axios.post(generateUrl, payload, { + headers: { + 'x-key': requestApiKey, + 'Content-Type': 'application/json', + Accept: 'application/json', + }, + ...this.getAxiosConfig(), + }); + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while submitting finetuned task:', details); + return this.returnValue( + `Something went wrong when trying to generate the finetuned image. The Flux API may be unavailable: + Error Message: ${details}`, + ); + } + + const taskId = taskResponse.data.id; + + // Polling for the result + let status = 'Pending'; + let resultData = null; + while (status !== 'Ready' && status !== 'Error') { + try { + // Wait 2 seconds between polls + await new Promise((resolve) => setTimeout(resolve, 2000)); + const resultResponse = await axios.get(resultUrl, { + headers: { + 'x-key': requestApiKey, + Accept: 'application/json', + }, + params: { id: taskId }, + ...this.getAxiosConfig(), + }); + status = resultResponse.data.status; + + if (status === 'Ready') { + resultData = resultResponse.data.result; + break; + } else if (status === 'Error') { + logger.error('[FluxAPI] Error in finetuned task:', resultResponse.data); + return this.returnValue('An error occurred during finetuned image generation.'); + } + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while getting finetuned result:', details); + return this.returnValue('An error occurred while retrieving the finetuned image.'); + } + } + + // If no result data + if (!resultData || !resultData.sample) { + logger.error('[FluxAPI] No image data received from API. Response:', resultData); + return this.returnValue('No image data received from Flux API.'); + } + + // Try saving the image locally + const imageUrl = resultData.sample; + const imageName = `img-${uuidv4()}.png`; + + try { + logger.debug('[FluxAPI] Saving finetuned image:', imageUrl); + const result = await this.processFileURL({ + fileStrategy: this.fileStrategy, + userId: this.userId, + URL: imageUrl, + fileName: imageName, + basePath: 'images', + context: FileContext.image_generation, + }); + + logger.debug('[FluxAPI] Finetuned image saved to path:', result.filepath); + + // Calculate cost based on endpoint + const endpointKey = endpoint.includes('ultra') + ? 
'FLUX_PRO_1_1_ULTRA_FINETUNED' + : 'FLUX_PRO_FINETUNED'; + const cost = FluxAPI.PRICING[endpointKey] || 0; + // Return the result based on returnMetadata flag + this.result = this.returnMetadata ? result : this.wrapInMarkdown(result.filepath); + return this.returnValue(this.result); + } catch (error) { + const details = this.getDetails(error?.message ?? 'No additional error details.'); + logger.error('Error while saving the finetuned image:', details); + return this.returnValue(`Failed to save the finetuned image locally. ${details}`); + } + } +} + +module.exports = FluxAPI; diff --git a/api/app/clients/tools/structured/StableDiffusion.js b/api/app/clients/tools/structured/StableDiffusion.js index 6309da35d8..25a9e0abd3 100644 --- a/api/app/clients/tools/structured/StableDiffusion.js +++ b/api/app/clients/tools/structured/StableDiffusion.js @@ -6,10 +6,13 @@ const axios = require('axios'); const sharp = require('sharp'); const { v4: uuidv4 } = require('uuid'); const { Tool } = require('@langchain/core/tools'); -const { FileContext } = require('librechat-data-provider'); +const { FileContext, ContentTypes } = require('librechat-data-provider'); const paths = require('~/config/paths'); const { logger } = require('~/config'); +const displayMessage = + 'Stable Diffusion displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.'; + class StableDiffusionAPI extends Tool { constructor(fields) { super(); @@ -21,6 +24,8 @@ class StableDiffusionAPI extends Tool { this.override = fields.override ?? false; /** @type {boolean} Necessary for output to contain all image metadata. */ this.returnMetadata = fields.returnMetadata ?? false; + /** @type {boolean} */ + this.isAgent = fields.isAgent; if (fields.uploadImageBuffer) { /** @type {uploadImageBuffer} Necessary for output to contain all image metadata. 
*/ this.uploadImageBuffer = fields.uploadImageBuffer.bind(this); @@ -66,6 +71,16 @@ class StableDiffusionAPI extends Tool { return `![generated image](/${imageUrl})`; } + returnValue(value) { + if (this.isAgent === true && typeof value === 'string') { + return [value, {}]; + } else if (this.isAgent === true && typeof value === 'object') { + return [displayMessage, value]; + } + + return value; + } + getServerURL() { const url = process.env.SD_WEBUI_URL || ''; if (!url && !this.override) { @@ -113,6 +128,25 @@ class StableDiffusionAPI extends Tool { } try { + if (this.isAgent) { + const content = [ + { + type: ContentTypes.IMAGE_URL, + image_url: { + url: `data:image/png;base64,${image}`, + }, + }, + ]; + + const response = [ + { + type: ContentTypes.TEXT, + text: displayMessage, + }, + ]; + return [response, { content }]; + } + const buffer = Buffer.from(image.split(',', 1)[0], 'base64'); if (this.returnMetadata && this.uploadImageBuffer && this.req) { const file = await this.uploadImageBuffer({ @@ -154,7 +188,7 @@ class StableDiffusionAPI extends Tool { logger.error('[StableDiffusion] Error while saving the image:', error); } - return this.result; + return this.returnValue(this.result); } } diff --git a/api/app/clients/tools/util/fileSearch.js b/api/app/clients/tools/util/fileSearch.js index 23ba58bb5a..54da483362 100644 --- a/api/app/clients/tools/util/fileSearch.js +++ b/api/app/clients/tools/util/fileSearch.js @@ -106,18 +106,21 @@ const createFileSearchTool = async ({ req, files, entity_id }) => { const formattedResults = validResults .flatMap((result) => - result.data.map(([docInfo, relevanceScore]) => ({ + result.data.map(([docInfo, distance]) => ({ filename: docInfo.metadata.source.split('/').pop(), content: docInfo.page_content, - relevanceScore, + distance, })), ) - .sort((a, b) => b.relevanceScore - a.relevanceScore); + // TODO: results should be sorted by relevance, not distance + .sort((a, b) => a.distance - b.distance) + // TODO: make this configurable + .slice(0, 10); const formattedString = formattedResults .map( (result) => - `File: ${result.filename}\nRelevance: ${result.relevanceScore.toFixed(4)}\nContent: ${ + `File: ${result.filename}\nRelevance: ${(1.0 - result.distance).toFixed(4)}\nContent: ${ result.content }\n`, ) diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index f1dfa24a49..063d6e0327 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -10,6 +10,7 @@ const { GoogleSearchAPI, // Structured Tools DALLE3, + FluxAPI, OpenWeather, StructuredSD, StructuredACS, @@ -20,6 +21,7 @@ const { } = require('../'); const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch'); +const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { createMCPTool } = require('~/server/services/MCP'); const { loadSpecs } = require('./loadSpecs'); const { logger } = require('~/config'); @@ -89,45 +91,6 @@ const validateTools = async (user, tools = []) => { } }; -const loadAuthValues = async ({ userId, authFields, throwError = true }) => { - let authValues = {}; - - /** - * Finds the first non-empty value for the given authentication field, supporting alternate fields. - * @param {string[]} fields Array of strings representing the authentication fields. Supports alternate fields delimited by "||".
- * @returns {Promise<{ authField: string, authValue: string} | null>} An object containing the authentication field and value, or null if not found. - */ - const findAuthValue = async (fields) => { - for (const field of fields) { - let value = process.env[field]; - if (value) { - return { authField: field, authValue: value }; - } - try { - value = await getUserPluginAuthValue(userId, field, throwError); - } catch (err) { - if (field === fields[fields.length - 1] && !value) { - throw err; - } - } - if (value) { - return { authField: field, authValue: value }; - } - } - return null; - }; - - for (let authField of authFields) { - const fields = authField.split('||'); - const result = await findAuthValue(fields); - if (result) { - authValues[result.authField] = result.authValue; - } - } - - return authValues; -}; - /** @typedef {typeof import('@langchain/core/tools').Tool} ToolConstructor */ /** @typedef {import('@langchain/core/tools').Tool} Tool */ @@ -182,6 +145,7 @@ const loadTools = async ({ returnMap = false, }) => { const toolConstructors = { + flux: FluxAPI, calculator: Calculator, google: GoogleSearchAPI, open_weather: OpenWeather, @@ -230,9 +194,10 @@ const loadTools = async ({ }; const toolOptions = { - serpapi: { location: 'Austin,Texas,United States', hl: 'en', gl: 'us' }, + flux: imageGenOptions, dalle: imageGenOptions, 'stable-diffusion': imageGenOptions, + serpapi: { location: 'Austin,Texas,United States', hl: 'en', gl: 'us' }, }; const toolContextMap = {}; @@ -345,7 +310,6 @@ const loadTools = async ({ module.exports = { loadToolWithAuth, - loadAuthValues, validateTools, loadTools, }; diff --git a/api/app/clients/tools/util/index.js b/api/app/clients/tools/util/index.js index 73d10270b6..ea67bb4ced 100644 --- a/api/app/clients/tools/util/index.js +++ b/api/app/clients/tools/util/index.js @@ -1,9 +1,8 @@ -const { validateTools, loadTools, loadAuthValues } = require('./handleTools'); +const { validateTools, loadTools } = require('./handleTools'); const handleOpenAIErrors = require('./handleOpenAIErrors'); module.exports = { handleOpenAIErrors, - loadAuthValues, validateTools, loadTools, }; diff --git a/api/cache/keyvRedis.js b/api/cache/keyvRedis.js index d544b50a11..992e789ae3 100644 --- a/api/cache/keyvRedis.js +++ b/api/cache/keyvRedis.js @@ -1,20 +1,86 @@ +const fs = require('fs'); +const ioredis = require('ioredis'); const KeyvRedis = require('@keyv/redis'); const { isEnabled } = require('~/server/utils'); const logger = require('~/config/winston'); -const { REDIS_URI, USE_REDIS } = process.env; +const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_KEY_PREFIX, REDIS_MAX_LISTENERS } = + process.env; let keyvRedis; +const redis_prefix = REDIS_KEY_PREFIX || ''; +const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40; + +function mapURI(uri) { + const regex = + /^(?:(?<scheme>\w+):\/\/)?(?:(?<user>[^:@]+)(?::(?<password>[^@]+))?@)?(?<host>[\w.-]+)(?::(?<port>\d{1,5}))?$/; + const match = uri.match(regex); + + if (match) { + const { scheme, user, password, host, port } = match.groups; + + return { + scheme: scheme || 'none', + user: user || null, + password: password || null, + host: host || null, + port: port || null, + }; + } else { + const parts = uri.split(':'); + if (parts.length === 2) { + return { + scheme: 'none', + user: null, + password: null, + host: parts[0], + port: parts[1], + }; + } + + return { + scheme: 'none', + user: null, + password: null, + host: uri, + port: null, + }; + } +} if (REDIS_URI && isEnabled(USE_REDIS)) { - keyvRedis = new KeyvRedis(REDIS_URI, { useRedisSets:
false }); + let redisOptions = null; + let keyvOpts = { + useRedisSets: false, + keyPrefix: redis_prefix, + }; + + if (REDIS_CA) { + const ca = fs.readFileSync(REDIS_CA); + redisOptions = { tls: { ca } }; + } + + if (isEnabled(USE_REDIS_CLUSTER)) { + const hosts = REDIS_URI.split(',').map((item) => { + var value = mapURI(item); + + return { + host: value.host, + port: value.port, + }; + }); + const cluster = new ioredis.Cluster(hosts, { redisOptions }); + keyvRedis = new KeyvRedis(cluster, keyvOpts); + } else { + keyvRedis = new KeyvRedis(REDIS_URI, keyvOpts); + } keyvRedis.on('error', (err) => logger.error('KeyvRedis connection error:', err)); - keyvRedis.setMaxListeners(20); + keyvRedis.setMaxListeners(redis_max_listeners); logger.info( - '[Optional] Redis initialized. Note: Redis support is experimental. If you have issues, disable it. Cache needs to be flushed for values to refresh.', + '[Optional] Redis initialized. If you have issues, or seeing older values, disable it or flush cache to refresh values.', ); } else { - logger.info('[Optional] Redis not initialized. Note: Redis support is experimental.'); + logger.info('[Optional] Redis not initialized.'); } module.exports = keyvRedis; diff --git a/api/config/index.js b/api/config/index.js index aaf8bb2764..8f23e404c8 100644 --- a/api/config/index.js +++ b/api/config/index.js @@ -1,3 +1,4 @@ +const axios = require('axios'); const { EventSource } = require('eventsource'); const { Time, CacheKeys } = require('librechat-data-provider'); const logger = require('./winston'); @@ -47,9 +48,46 @@ const sendEvent = (res, event) => { res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); }; +/** + * Creates and configures an Axios instance with optional proxy settings. + * + * @typedef {import('axios').AxiosInstance} AxiosInstance + * @typedef {import('axios').AxiosProxyConfig} AxiosProxyConfig + * + * @returns {AxiosInstance} A configured Axios instance + * @throws {Error} If there's an issue creating the Axios instance or parsing the proxy URL + */ +function createAxiosInstance() { + const instance = axios.create(); + + if (process.env.proxy) { + try { + const url = new URL(process.env.proxy); + + /** @type {AxiosProxyConfig} */ + const proxyConfig = { + host: url.hostname.replace(/^\[|\]$/g, ''), + protocol: url.protocol.replace(':', ''), + }; + + if (url.port) { + proxyConfig.port = parseInt(url.port, 10); + } + + instance.defaults.proxy = proxyConfig; + } catch (error) { + console.error('Error parsing proxy URL:', error); + throw new Error(`Invalid proxy URL: ${process.env.proxy}`); + } + } + + return instance; +} + module.exports = { logger, sendEvent, getMCPManager, + createAxiosInstance, getFlowStateManager, }; diff --git a/api/config/index.spec.js b/api/config/index.spec.js new file mode 100644 index 0000000000..36ed8302f3 --- /dev/null +++ b/api/config/index.spec.js @@ -0,0 +1,126 @@ +const axios = require('axios'); +const { createAxiosInstance } = require('./index'); + +// Mock axios +jest.mock('axios', () => ({ + interceptors: { + request: { use: jest.fn(), eject: jest.fn() }, + response: { use: jest.fn(), eject: jest.fn() }, + }, + create: jest.fn().mockReturnValue({ + defaults: { + proxy: null, + }, + get: jest.fn().mockResolvedValue({ data: {} }), + post: jest.fn().mockResolvedValue({ data: {} }), + put: jest.fn().mockResolvedValue({ data: {} }), + delete: jest.fn().mockResolvedValue({ data: {} }), + }), + get: jest.fn().mockResolvedValue({ data: {} }), + post: jest.fn().mockResolvedValue({ data: {} }), + put: 
jest.fn().mockResolvedValue({ data: {} }), + delete: jest.fn().mockResolvedValue({ data: {} }), + reset: jest.fn().mockImplementation(function () { + this.get.mockClear(); + this.post.mockClear(); + this.put.mockClear(); + this.delete.mockClear(); + this.create.mockClear(); + }), +})); + +describe('createAxiosInstance', () => { + const originalEnv = process.env; + + beforeEach(() => { + // Reset mocks + jest.clearAllMocks(); + // Create a clean copy of process.env + process.env = { ...originalEnv }; + // Default: no proxy + delete process.env.proxy; + }); + + afterAll(() => { + // Restore original process.env + process.env = originalEnv; + }); + + test('creates an axios instance without proxy when no proxy env is set', () => { + const instance = createAxiosInstance(); + + expect(axios.create).toHaveBeenCalledTimes(1); + expect(instance.defaults.proxy).toBeNull(); + }); + + test('configures proxy correctly with hostname and protocol', () => { + process.env.proxy = 'http://example.com'; + + const instance = createAxiosInstance(); + + expect(axios.create).toHaveBeenCalledTimes(1); + expect(instance.defaults.proxy).toEqual({ + host: 'example.com', + protocol: 'http', + }); + }); + + test('configures proxy correctly with hostname, protocol and port', () => { + process.env.proxy = 'https://proxy.example.com:8080'; + + const instance = createAxiosInstance(); + + expect(axios.create).toHaveBeenCalledTimes(1); + expect(instance.defaults.proxy).toEqual({ + host: 'proxy.example.com', + protocol: 'https', + port: 8080, + }); + }); + + test('handles proxy URLs with authentication', () => { + process.env.proxy = 'http://user:pass@proxy.example.com:3128'; + + const instance = createAxiosInstance(); + + expect(axios.create).toHaveBeenCalledTimes(1); + expect(instance.defaults.proxy).toEqual({ + host: 'proxy.example.com', + protocol: 'http', + port: 3128, + // Note: The current implementation doesn't handle auth - if needed, add this functionality + }); + }); + + test('throws error when proxy URL is invalid', () => { + process.env.proxy = 'invalid-url'; + + expect(() => createAxiosInstance()).toThrow('Invalid proxy URL'); + expect(axios.create).toHaveBeenCalledTimes(1); + }); + + // If you want to test the actual URL parsing more thoroughly + test('handles edge case proxy URLs correctly', () => { + // IPv6 address + process.env.proxy = 'http://[::1]:8080'; + + let instance = createAxiosInstance(); + + expect(instance.defaults.proxy).toEqual({ + host: '::1', + protocol: 'http', + port: 8080, + }); + + // URL with path (which should be ignored for proxy config) + process.env.proxy = 'http://proxy.example.com:8080/some/path'; + + instance = createAxiosInstance(); + + expect(instance.defaults.proxy).toEqual({ + host: 'proxy.example.com', + protocol: 'http', + port: 8080, + }); + }); +}); diff --git a/api/lib/db/indexSync.js b/api/lib/db/indexSync.js index 86c909419d..75acd9d231 100644 --- a/api/lib/db/indexSync.js +++ b/api/lib/db/indexSync.js @@ -1,9 +1,11 @@ const { MeiliSearch } = require('meilisearch'); -const Conversation = require('~/models/schema/convoSchema'); -const Message = require('~/models/schema/messageSchema'); +const { Conversation } = require('~/models/Conversation'); +const { Message } = require('~/models/Message'); +const { isEnabled } = require('~/server/utils'); const { logger } = require('~/config'); -const searchEnabled = process.env?.SEARCH?.toLowerCase() === 'true'; +const searchEnabled = isEnabled(process.env.SEARCH); +const indexingDisabled = isEnabled(process.env.MEILI_NO_SYNC); let 
currentTimeout = null; class MeiliSearchClient { @@ -23,8 +25,7 @@ class MeiliSearchClient { } } -// eslint-disable-next-line no-unused-vars -async function indexSync(req, res, next) { +async function indexSync() { if (!searchEnabled) { return; } @@ -33,10 +34,15 @@ async function indexSync(req, res, next) { const client = MeiliSearchClient.getInstance(); const { status } = await client.health(); - if (status !== 'available' || !process.env.SEARCH) { + if (status !== 'available') { throw new Error('Meilisearch not available'); } + if (indexingDisabled === true) { + logger.info('[indexSync] Indexing is disabled, skipping...'); + return; + } + const messageCount = await Message.countDocuments(); const convoCount = await Conversation.countDocuments(); const messages = await client.index('messages').getStats(); @@ -71,7 +77,6 @@ async function indexSync(req, res, next) { logger.info('[indexSync] Meilisearch not configured, search will be disabled.'); } else { logger.error('[indexSync] error', err); - // res.status(500).json({ error: 'Server error' }); } } } diff --git a/api/models/Action.js b/api/models/Action.js index 299b3bf20a..677b4d78df 100644 --- a/api/models/Action.js +++ b/api/models/Action.js @@ -1,5 +1,5 @@ const mongoose = require('mongoose'); -const actionSchema = require('./schema/action'); +const { actionSchema } = require('@librechat/data-schemas'); const Action = mongoose.model('action', actionSchema); diff --git a/api/models/Agent.js b/api/models/Agent.js index 6fa00f56bc..1d3ea5af0c 100644 --- a/api/models/Agent.js +++ b/api/models/Agent.js @@ -9,7 +9,7 @@ const { removeAgentFromAllProjects, } = require('./Project'); const getLogStores = require('~/cache/getLogStores'); -const agentSchema = require('./schema/agent'); +const { agentSchema } = require('@librechat/data-schemas'); const Agent = mongoose.model('agent', agentSchema); @@ -97,11 +97,22 @@ const updateAgent = async (searchParameter, updateData) => { const addAgentResourceFile = async ({ agent_id, tool_resource, file_id }) => { const searchParameter = { id: agent_id }; - // build the update to push or create the file ids set const fileIdsPath = `tool_resources.${tool_resource}.file_ids`; + + await Agent.updateOne( + { + id: agent_id, + [`${fileIdsPath}`]: { $exists: false }, + }, + { + $set: { + [`${fileIdsPath}`]: [], + }, + }, + ); + const updateData = { $addToSet: { [fileIdsPath]: file_id } }; - // return the updated agent or throw if no agent matches const updatedAgent = await updateAgent(searchParameter, updateData); if (updatedAgent) { return updatedAgent; @@ -290,6 +301,7 @@ const updateAgentProjects = async ({ user, agentId, projectIds, removeProjectIds }; module.exports = { + Agent, getAgent, loadAgent, createAgent, diff --git a/api/models/Agent.spec.js b/api/models/Agent.spec.js new file mode 100644 index 0000000000..769eda2bb7 --- /dev/null +++ b/api/models/Agent.spec.js @@ -0,0 +1,160 @@ +const mongoose = require('mongoose'); +const { v4: uuidv4 } = require('uuid'); +const { MongoMemoryServer } = require('mongodb-memory-server'); +const { Agent, addAgentResourceFile, removeAgentResourceFiles } = require('./Agent'); + +describe('Agent Resource File Operations', () => { + let mongoServer; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + await mongoose.connect(mongoUri); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Agent.deleteMany({}); + }); + + 
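+  // A minimal illustrative sketch (not part of the original suite) of the
+  // two-step update in Agent.js above that these tests exercise: initialize
+  // the file_ids array only when it is missing, then $addToSet so concurrent
+  // writers never duplicate entries. Identifiers here are hypothetical.
+  test.skip('sketch: init-then-$addToSet update shape', async () => {
+    const fileIdsPath = 'tool_resources.test_tool.file_ids';
+    // Step 1: create the array atomically only if it does not exist yet
+    await Agent.updateOne(
+      { id: 'agent_sketch', [fileIdsPath]: { $exists: false } },
+      { $set: { [fileIdsPath]: [] } },
+    );
+    // Step 2: append without duplicates; safe under concurrent callers
+    await Agent.updateOne({ id: 'agent_sketch' }, { $addToSet: { [fileIdsPath]: 'file-1' } });
+  });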
const createBasicAgent = async () => { + const agentId = `agent_${uuidv4()}`; + const agent = await Agent.create({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: new mongoose.Types.ObjectId(), + }); + return agent; + }; + + test('should handle concurrent file additions', async () => { + const agent = await createBasicAgent(); + const fileIds = Array.from({ length: 10 }, () => uuidv4()); + + // Concurrent additions + const additionPromises = fileIds.map((fileId) => + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }), + ); + + await Promise.all(additionPromises); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(10); + expect(new Set(updatedAgent.tool_resources.test_tool.file_ids).size).toBe(10); + }); + + test('should handle concurrent additions and removals', async () => { + const agent = await createBasicAgent(); + const initialFileIds = Array.from({ length: 5 }, () => uuidv4()); + + await Promise.all( + initialFileIds.map((fileId) => + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }), + ), + ); + + const newFileIds = Array.from({ length: 5 }, () => uuidv4()); + const operations = [ + ...newFileIds.map((fileId) => + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }), + ), + ...initialFileIds.map((fileId) => + removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ), + ]; + + await Promise.all(operations); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(5); + }); + + test('should initialize array when adding to non-existent tool resource', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + const updatedAgent = await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'new_tool', + file_id: fileId, + }); + + expect(updatedAgent.tool_resources.new_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.new_tool.file_ids).toHaveLength(1); + expect(updatedAgent.tool_resources.new_tool.file_ids[0]).toBe(fileId); + }); + + test('should handle rapid sequential modifications to same tool resource', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + for (let i = 0; i < 10; i++) { + await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: `${fileId}_${i}`, + }); + + if (i % 2 === 0) { + await removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: `${fileId}_${i}` }], + }); + } + } + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(Array.isArray(updatedAgent.tool_resources.test_tool.file_ids)).toBe(true); + }); + + test('should handle multiple tool resources concurrently', async () => { + const agent = await createBasicAgent(); + const toolResources = ['tool1', 'tool2', 'tool3']; + const operations = []; + + toolResources.forEach((tool) => { + const fileIds = Array.from({ length: 5 }, () => uuidv4()); + fileIds.forEach((fileId) => { + operations.push( + addAgentResourceFile({ + 
agent_id: agent.id, + tool_resource: tool, + file_id: fileId, + }), + ); + }); + }); + + await Promise.all(operations); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + toolResources.forEach((tool) => { + expect(updatedAgent.tool_resources[tool].file_ids).toBeDefined(); + expect(updatedAgent.tool_resources[tool].file_ids).toHaveLength(5); + }); + }); +}); diff --git a/api/models/Assistant.js b/api/models/Assistant.js index d0e73ad4e7..a8a5b98157 100644 --- a/api/models/Assistant.js +++ b/api/models/Assistant.js @@ -1,5 +1,5 @@ const mongoose = require('mongoose'); -const assistantSchema = require('./schema/assistant'); +const { assistantSchema } = require('@librechat/data-schemas'); const Assistant = mongoose.model('assistant', assistantSchema); diff --git a/api/models/Balance.js b/api/models/Balance.js index 24d9087b77..f7978d8049 100644 --- a/api/models/Balance.js +++ b/api/models/Balance.js @@ -1,5 +1,5 @@ const mongoose = require('mongoose'); -const balanceSchema = require('./schema/balance'); +const { balanceSchema } = require('@librechat/data-schemas'); const { getMultiplier } = require('./tx'); const { logger } = require('~/config'); diff --git a/api/models/Banner.js b/api/models/Banner.js index 8d439dae28..399a8e72ee 100644 --- a/api/models/Banner.js +++ b/api/models/Banner.js @@ -1,5 +1,9 @@ -const Banner = require('./schema/banner'); +const mongoose = require('mongoose'); const logger = require('~/config/winston'); +const { bannerSchema } = require('@librechat/data-schemas'); + +const Banner = mongoose.model('Banner', bannerSchema); + /** * Retrieves the current active banner. * @returns {Promise} The active banner object or null if no active banner is found. @@ -24,4 +28,4 @@ const getBanner = async (user) => { } }; -module.exports = { getBanner }; +module.exports = { Banner, getBanner }; diff --git a/api/models/Categories.js b/api/models/Categories.js index 0f7f29703f..5da1f4b2da 100644 --- a/api/models/Categories.js +++ b/api/models/Categories.js @@ -1,40 +1,40 @@ const { logger } = require('~/config'); -// const { Categories } = require('./schema/categories'); + const options = [ { - label: 'idea', + label: 'com_ui_idea', value: 'idea', }, { - label: 'travel', + label: 'com_ui_travel', value: 'travel', }, { - label: 'teach_or_explain', + label: 'com_ui_teach_or_explain', value: 'teach_or_explain', }, { - label: 'write', + label: 'com_ui_write', value: 'write', }, { - label: 'shop', + label: 'com_ui_shop', value: 'shop', }, { - label: 'code', + label: 'com_ui_code', value: 'code', }, { - label: 'misc', + label: 'com_ui_misc', value: 'misc', }, { - label: 'roleplay', + label: 'com_ui_roleplay', value: 'roleplay', }, { - label: 'finance', + label: 'com_ui_finance', value: 'finance', }, ]; diff --git a/api/models/Conversation.js b/api/models/Conversation.js index d6365e99ce..dd6ef9bde1 100644 --- a/api/models/Conversation.js +++ b/api/models/Conversation.js @@ -15,19 +15,6 @@ const searchConversation = async (conversationId) => { throw new Error('Error searching conversation'); } }; -/** - * Searches for a conversation by conversationId and returns associated file ids. - * @param {string} conversationId - The conversation's ID. - * @returns {Promise} - */ -const getConvoFiles = async (conversationId) => { - try { - return (await Conversation.findOne({ conversationId }, 'files').lean())?.files ?? 
[]; - } catch (error) { - logger.error('[getConvoFiles] Error getting conversation files', error); - throw new Error('Error getting conversation files'); - } -}; /** * Retrieves a single conversation for a given user and conversation ID. @@ -73,6 +60,20 @@ const deleteNullOrEmptyConversations = async () => { } }; +/** + * Searches for a conversation by conversationId and returns associated file ids. + * @param {string} conversationId - The conversation's ID. + * @returns {Promise} + */ +const getConvoFiles = async (conversationId) => { + try { + return (await Conversation.findOne({ conversationId }, 'files').lean())?.files ?? []; + } catch (error) { + logger.error('[getConvoFiles] Error getting conversation files', error); + throw new Error('Error getting conversation files'); + } +}; + module.exports = { Conversation, getConvoFiles, @@ -104,10 +105,16 @@ module.exports = { update.expiredAt = null; } + /** @type {{ $set: Partial<TConversation>; $unset?: Record<string, number> }} */ + const updateOperation = { $set: update }; + if (metadata && metadata.unsetFields && Object.keys(metadata.unsetFields).length > 0) { + updateOperation.$unset = metadata.unsetFields; + } + /** Note: the resulting Model object is necessary for Meilisearch operations */ const conversation = await Conversation.findOneAndUpdate( { conversationId, user: req.user.id }, - update, + updateOperation, { new: true, upsert: true, diff --git a/api/models/ConversationTag.js b/api/models/ConversationTag.js index 53d144e1f5..f0cac8620e 100644 --- a/api/models/ConversationTag.js +++ b/api/models/ConversationTag.js @@ -1,7 +1,11 @@ -const ConversationTag = require('./schema/conversationTagSchema'); +const mongoose = require('mongoose'); const Conversation = require('./schema/convoSchema'); const logger = require('~/config/winston'); +const { conversationTagSchema } = require('@librechat/data-schemas'); + +const ConversationTag = mongoose.model('ConversationTag', conversationTagSchema); + /** * Retrieves all conversation tags for a user. * @param {string} user - The user ID. diff --git a/api/models/File.js b/api/models/File.js index 17f8506600..0bde258a54 100644 --- a/api/models/File.js +++ b/api/models/File.js @@ -1,5 +1,6 @@ const mongoose = require('mongoose'); -const fileSchema = require('./schema/fileSchema'); +const { fileSchema } = require('@librechat/data-schemas'); +const { logger } = require('~/config'); const File = mongoose.model('File', fileSchema); @@ -7,7 +8,7 @@ const File = mongoose.model('File', fileSchema); * Finds a file by its file_id with additional query options. * @param {string} file_id - The unique identifier of the file. * @param {object} options - Query options for filtering, projection, etc. - * @returns {Promise<MongoFile | null>} A promise that resolves to the file document or null. + * @returns {Promise<IMongoFile | null>} A promise that resolves to the file document or null. */ const findFileById = async (file_id, options = {}) => { return await File.findOne({ file_id, ...options }).lean(); @@ -17,18 +18,46 @@ const findFileById = async (file_id, options = {}) => { * Retrieves files matching a given filter, sorted by the most recently updated. * @param {Object} filter - The filter criteria to apply. * @param {Object} [_sortOptions] - Optional sort parameters. - * @returns {Promise<Array<MongoFile>>} A promise that resolves to an array of file documents. + * @param {Object|String} [selectFields={ text: 0 }] - Fields to include/exclude in the query results. + * Default excludes the 'text' field. + * @returns {Promise<Array<IMongoFile>>} A promise that resolves to an array of file documents.
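+ * @example + * // Illustrative sketch only (a hypothetical `userId` is assumed): + * // await getFiles({ user: userId }); // the default projection { text: 0 } omits the heavy 'text' field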
*/ -const getFiles = async (filter, _sortOptions) => { +const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => { const sortOptions = { updatedAt: -1, ..._sortOptions }; - return await File.find(filter).sort(sortOptions).lean(); + return await File.find(filter).select(selectFields).sort(sortOptions).lean(); +}; + +/** + * Retrieves tool files (files that are embedded or have a fileIdentifier) from an array of file IDs + * @param {string[]} fileIds - Array of file_id strings to search for + * @returns {Promise<Array<IMongoFile>>} Files that match the criteria + */ +const getToolFilesByIds = async (fileIds) => { + if (!fileIds || !fileIds.length) { + return []; + } + + try { + const filter = { + file_id: { $in: fileIds }, + $or: [{ embedded: true }, { 'metadata.fileIdentifier': { $exists: true } }], + }; + + const selectFields = { text: 0 }; + const sortOptions = { updatedAt: -1 }; + + return await getFiles(filter, sortOptions, selectFields); + } catch (error) { + logger.error('[getToolFilesByIds] Error retrieving tool files:', error); + throw new Error('Error retrieving tool files'); + } }; /** * Creates a new file with a TTL of 1 hour. - * @param {MongoFile} data - The file data to be created, must contain file_id. + * @param {IMongoFile} data - The file data to be created, must contain file_id. * @param {boolean} disableTTL - Whether to disable the TTL. - * @returns {Promise<MongoFile>} A promise that resolves to the created file document. + * @returns {Promise<IMongoFile>} A promise that resolves to the created file document. */ const createFile = async (data, disableTTL) => { const fileData = { @@ -48,8 +77,8 @@ const createFile = async (data, disableTTL) => { /** * Updates a file identified by file_id with new data and removes the TTL. - * @param {MongoFile} data - The data to update, must contain file_id. - * @returns {Promise<MongoFile>} A promise that resolves to the updated file document. + * @param {IMongoFile} data - The data to update, must contain file_id. + * @returns {Promise<IMongoFile>} A promise that resolves to the updated file document. */ const updateFile = async (data) => { const { file_id, ...update } = data; @@ -62,8 +91,8 @@ const updateFile = async (data) => { /** * Increments the usage of a file identified by file_id. - * @param {MongoFile} data - The data to update, must contain file_id and the increment value for usage. - * @returns {Promise<MongoFile>} A promise that resolves to the updated file document. + * @param {IMongoFile} data - The data to update, must contain file_id and the increment value for usage. + * @returns {Promise<IMongoFile>} A promise that resolves to the updated file document. */ const updateFileUsage = async (data) => { const { file_id, inc = 1 } = data; @@ -77,7 +106,7 @@ const updateFileUsage = async (data) => { /** * Deletes a file identified by file_id. * @param {string} file_id - The unique identifier of the file to delete. - * @returns {Promise<MongoFile | null>} A promise that resolves to the deleted file document or null. + * @returns {Promise<IMongoFile | null>} A promise that resolves to the deleted file document or null. */ const deleteFile = async (file_id) => { return await File.findOneAndDelete({ file_id }).lean(); @@ -86,7 +115,7 @@ const deleteFile = async (file_id) => { /** * Deletes a file identified by a filter. * @param {object} filter - The filter criteria to apply. - * @returns {Promise<MongoFile | null>} A promise that resolves to the deleted file document or null. + * @returns {Promise<IMongoFile | null>} A promise that resolves to the deleted file document or null.
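+ * @example + * // Hedged sketch with hypothetical filter values: + * // await deleteFileByFilter({ user: userId, conversationId });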
*/ const deleteFileByFilter = async (filter) => { return await File.findOneAndDelete(filter).lean(); @@ -109,6 +138,7 @@ module.exports = { File, findFileById, getFiles, + getToolFilesByIds, createFile, updateFile, updateFileUsage, diff --git a/api/models/Key.js b/api/models/Key.js index 58fb0ac3a9..c69c350a42 100644 --- a/api/models/Key.js +++ b/api/models/Key.js @@ -1,4 +1,4 @@ const mongoose = require('mongoose'); -const keySchema = require('./schema/key'); +const { keySchema } = require('@librechat/data-schemas'); module.exports = mongoose.model('Key', keySchema); diff --git a/api/models/Message.js b/api/models/Message.js index e651b20ad0..58068813ef 100644 --- a/api/models/Message.js +++ b/api/models/Message.js @@ -71,7 +71,42 @@ async function saveMessage(req, params, metadata) { } catch (err) { logger.error('Error saving message:', err); logger.info(`---\`saveMessage\` context: ${metadata?.context}`); - throw err; + + // Check if this is a duplicate key error (MongoDB error code 11000) + if (err.code === 11000 && err.message.includes('duplicate key error')) { + // Log the duplicate key error but don't crash the application + logger.warn(`Duplicate messageId detected: ${params.messageId}. Continuing execution.`); + + try { + // Try to find the existing message with this ID + const existingMessage = await Message.findOne({ + messageId: params.messageId, + user: req.user.id, + }); + + // If we found it, return it + if (existingMessage) { + return existingMessage.toObject(); + } + + // If we can't find it (unlikely but possible in race conditions) + return { + ...params, + messageId: params.messageId, + user: req.user.id, + }; + } catch (findError) { + // If the findOne also fails, log it but don't crash + logger.warn(`Could not retrieve existing message with ID ${params.messageId}: ${findError.message}`); + return { + ...params, + messageId: params.messageId, + user: req.user.id, + }; + } + } + + throw err; // Re-throw other errors } } diff --git a/api/models/Project.js b/api/models/Project.js index 17ef3093a5..43d7263723 100644 --- a/api/models/Project.js +++ b/api/models/Project.js @@ -1,6 +1,6 @@ const { model } = require('mongoose'); const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants; -const projectSchema = require('~/models/schema/projectSchema'); +const { projectSchema } = require('@librechat/data-schemas'); const Project = model('Project', projectSchema); @@ -9,7 +9,7 @@ const Project = model('Project', projectSchema); * * @param {string} projectId - The ID of the project to find and return as a plain object. * @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document. - * @returns {Promise<MongoProject | null>} A plain object representing the project document, or `null` if no project is found. + * @returns {Promise<IMongoProject | null>} A plain object representing the project document, or `null` if no project is found. */ const getProjectById = async function (projectId, fieldsToSelect = null) { const query = Project.findById(projectId); @@ -27,7 +27,7 @@ const getProjectById = async function (projectId, fieldsToSelect = null) { * * @param {string} projectName - The name of the project to find or create. * @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document. - * @returns {Promise<MongoProject>} A plain object representing the project document. + * @returns {Promise<IMongoProject>} A plain object representing the project document.
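+ * @example + * // Illustrative only; `GLOBAL_PROJECT_NAME` is imported above from librechat-data-provider: + * // const project = await getProjectByName(GLOBAL_PROJECT_NAME, ['agentIds']);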
*/ const getProjectByName = async function (projectName, fieldsToSelect = null) { const query = { name: projectName }; @@ -47,7 +47,7 @@ const getProjectByName = async function (projectName, fieldsToSelect = null) { * * @param {string} projectId - The ID of the project to update. * @param {string[]} promptGroupIds - The array of prompt group IDs to add to the project. - * @returns {Promise<MongoProject>} The updated project document. + * @returns {Promise<IMongoProject>} The updated project document. */ const addGroupIdsToProject = async function (projectId, promptGroupIds) { return await Project.findByIdAndUpdate( @@ -62,7 +62,7 @@ const addGroupIdsToProject = async function (projectId, promptGroupIds) { * * @param {string} projectId - The ID of the project to update. * @param {string[]} promptGroupIds - The array of prompt group IDs to remove from the project. - * @returns {Promise<MongoProject>} The updated project document. + * @returns {Promise<IMongoProject>} The updated project document. */ const removeGroupIdsFromProject = async function (projectId, promptGroupIds) { return await Project.findByIdAndUpdate( @@ -87,7 +87,7 @@ const removeGroupFromAllProjects = async (promptGroupId) => { * * @param {string} projectId - The ID of the project to update. * @param {string[]} agentIds - The array of agent IDs to add to the project. - * @returns {Promise<MongoProject>} The updated project document. + * @returns {Promise<IMongoProject>} The updated project document. */ const addAgentIdsToProject = async function (projectId, agentIds) { return await Project.findByIdAndUpdate( @@ -102,7 +102,7 @@ const addAgentIdsToProject = async function (projectId, agentIds) { * * @param {string} projectId - The ID of the project to update. * @param {string[]} agentIds - The array of agent IDs to remove from the project. - * @returns {Promise<MongoProject>} The updated project document. + * @returns {Promise<IMongoProject>} The updated project document.
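+ * @example + * // Hedged sketch (assumes `project` and `agentId` exist in the caller's scope): + * // await removeAgentIdsFromProject(project._id, [agentId]);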
*/ const removeAgentIdsFromProject = async function (projectId, agentIds) { return await Project.findByIdAndUpdate( diff --git a/api/models/Prompt.js b/api/models/Prompt.js index 60456884a8..43dc3ec22b 100644 --- a/api/models/Prompt.js +++ b/api/models/Prompt.js @@ -1,3 +1,4 @@ +const mongoose = require('mongoose'); const { ObjectId } = require('mongodb'); const { SystemRoles, SystemCategories, Constants } = require('librechat-data-provider'); const { @@ -6,10 +7,13 @@ const { removeGroupIdsFromProject, removeGroupFromAllProjects, } = require('./Project'); -const { Prompt, PromptGroup } = require('./schema/promptSchema'); +const { promptGroupSchema, promptSchema } = require('@librechat/data-schemas'); const { escapeRegExp } = require('~/server/utils'); const { logger } = require('~/config'); +const PromptGroup = mongoose.model('PromptGroup', promptGroupSchema); +const Prompt = mongoose.model('Prompt', promptSchema); + /** * Create a pipeline for the aggregation to get prompt groups * @param {Object} query diff --git a/api/models/Role.js b/api/models/Role.js index 9c160512b7..4be5faeadb 100644 --- a/api/models/Role.js +++ b/api/models/Role.js @@ -1,3 +1,4 @@ +const mongoose = require('mongoose'); const { CacheKeys, SystemRoles, @@ -6,13 +7,17 @@ const { removeNullishValues, agentPermissionsSchema, promptPermissionsSchema, + runCodePermissionsSchema, bookmarkPermissionsSchema, multiConvoPermissionsSchema, + temporaryChatPermissionsSchema, } = require('librechat-data-provider'); const getLogStores = require('~/cache/getLogStores'); -const Role = require('~/models/schema/roleSchema'); +const { roleSchema } = require('@librechat/data-schemas'); const { logger } = require('~/config'); +const Role = mongoose.model('Role', roleSchema); + /** * Retrieve a role by name and convert the found role document to a plain object. * If the role with the given name doesn't exist and the name is a system defined role, create it and return the lean version. 
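 * * @example * // Hedged sketch of how the permission-schema map extended in the next hunk is typically consumed; * // the wiring is assumed (illustrative only), with `updateRoleByName` exported from this module: * // const perms = runCodePermissionsSchema.partial().parse({ [Permissions.USE]: true }); * // await updateRoleByName(SystemRoles.USER, { [PermissionTypes.RUN_CODE]: perms });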
@@ -77,6 +82,8 @@ const permissionSchemas = { [PermissionTypes.PROMPTS]: promptPermissionsSchema, [PermissionTypes.BOOKMARKS]: bookmarkPermissionsSchema, [PermissionTypes.MULTI_CONVO]: multiConvoPermissionsSchema, + [PermissionTypes.TEMPORARY_CHAT]: temporaryChatPermissionsSchema, + [PermissionTypes.RUN_CODE]: runCodePermissionsSchema, }; /** @@ -164,6 +171,7 @@ const initializeRoles = async function () { } }; module.exports = { + Role, getRoleByName, initializeRoles, updateRoleByName, diff --git a/api/models/Role.spec.js b/api/models/Role.spec.js index 92386f0fa9..39611f7b95 100644 --- a/api/models/Role.spec.js +++ b/api/models/Role.spec.js @@ -8,7 +8,7 @@ const { } = require('librechat-data-provider'); const { updateAccessPermissions, initializeRoles } = require('~/models/Role'); const getLogStores = require('~/cache/getLogStores'); -const Role = require('~/models/schema/roleSchema'); +const { Role } = require('~/models/Role'); // Mock the cache jest.mock('~/cache/getLogStores', () => { diff --git a/api/models/Session.js b/api/models/Session.js index dbb66ed8ff..38821b77dd 100644 --- a/api/models/Session.js +++ b/api/models/Session.js @@ -1,7 +1,7 @@ const mongoose = require('mongoose'); const signPayload = require('~/server/services/signPayload'); const { hashToken } = require('~/server/utils/crypto'); -const sessionSchema = require('./schema/session'); +const { sessionSchema } = require('@librechat/data-schemas'); const { logger } = require('~/config'); const Session = mongoose.model('Session', sessionSchema); diff --git a/api/models/Share.js b/api/models/Share.js index 041927ec61..a8bfbce7fe 100644 --- a/api/models/Share.js +++ b/api/models/Share.js @@ -1,7 +1,9 @@ +const mongoose = require('mongoose'); const { nanoid } = require('nanoid'); const { Constants } = require('librechat-data-provider'); const { Conversation } = require('~/models/Conversation'); -const SharedLink = require('./schema/shareSchema'); +const { shareSchema } = require('@librechat/data-schemas'); +const SharedLink = mongoose.model('SharedLink', shareSchema); const { getMessages } = require('./Message'); const logger = require('~/config/winston'); diff --git a/api/models/Token.js b/api/models/Token.js index 210666ddd7..c89abb8c84 100644 --- a/api/models/Token.js +++ b/api/models/Token.js @@ -1,6 +1,6 @@ const mongoose = require('mongoose'); const { encryptV2 } = require('~/server/utils/crypto'); -const tokenSchema = require('./schema/tokenSchema'); +const { tokenSchema } = require('@librechat/data-schemas'); const { logger } = require('~/config'); /** @@ -13,6 +13,13 @@ const Token = mongoose.model('Token', tokenSchema); */ async function fixIndexes() { try { + if ( + process.env.NODE_ENV === 'CI' || + process.env.NODE_ENV === 'development' || + process.env.NODE_ENV === 'test' + ) { + return; + } const indexes = await Token.collection.indexes(); logger.debug('Existing Token Indexes:', JSON.stringify(indexes, null, 2)); const unwantedTTLIndexes = indexes.filter( diff --git a/api/models/ToolCall.js b/api/models/ToolCall.js index e1d7b0cc84..7bc0f157dc 100644 --- a/api/models/ToolCall.js +++ b/api/models/ToolCall.js @@ -1,9 +1,11 @@ -const ToolCall = require('./schema/toolCallSchema'); +const mongoose = require('mongoose'); +const { toolCallSchema } = require('@librechat/data-schemas'); +const ToolCall = mongoose.model('ToolCall', toolCallSchema); /** * Create a new tool call - * @param {ToolCallData} toolCallData - The tool call data - * @returns {Promise<ToolCallData>} The created tool call document + * @param {IToolCallData}
toolCallData - The tool call data + * @returns {Promise<IToolCallData>} The created tool call document */ async function createToolCall(toolCallData) { try { @@ -16,7 +18,7 @@ async function createToolCall(toolCallData) { /** * Get a tool call by ID * @param {string} id - The tool call document ID - * @returns {Promise<ToolCallData | null>} The tool call document or null if not found + * @returns {Promise<IToolCallData | null>} The tool call document or null if not found */ async function getToolCallById(id) { try { @@ -44,7 +46,7 @@ async function getToolCallById(id) { * Get tool calls by conversation ID and user * @param {string} conversationId - The conversation ID * @param {string} userId - The user's ObjectId - * @returns {Promise<Array<ToolCallData>>} Array of tool call documents + * @returns {Promise<Array<IToolCallData>>} Array of tool call documents */ async function getToolCallsByConvo(conversationId, userId) { try { @@ -57,8 +59,8 @@ async function getToolCallsByConvo(conversationId, userId) { /** * Update a tool call * @param {string} id - The tool call document ID - * @param {Partial<ToolCallData>} updateData - The data to update - * @returns {Promise<ToolCallData | null>} The updated tool call document or null if not found + * @param {Partial<IToolCallData>} updateData - The data to update + * @returns {Promise<IToolCallData | null>} The updated tool call document or null if not found */ async function updateToolCall(id, updateData) { try { diff --git a/api/models/Transaction.js b/api/models/Transaction.js index 8435a812c4..b1c4c65710 100644 --- a/api/models/Transaction.js +++ b/api/models/Transaction.js @@ -1,6 +1,6 @@ const mongoose = require('mongoose'); const { isEnabled } = require('~/server/utils/handleText'); -const transactionSchema = require('./schema/transaction'); +const { transactionSchema } = require('@librechat/data-schemas'); const { getMultiplier, getCacheMultiplier } = require('./tx'); const { logger } = require('~/config'); const Balance = require('./Balance'); diff --git a/api/models/User.js b/api/models/User.js index 55750b4ae5..f4e8b0ec5b 100644 --- a/api/models/User.js +++ b/api/models/User.js @@ -1,5 +1,5 @@ const mongoose = require('mongoose'); -const userSchema = require('~/models/schema/userSchema'); +const { userSchema } = require('@librechat/data-schemas'); const User = mongoose.model('User', userSchema); diff --git a/api/models/plugins/mongoMeili.js b/api/models/plugins/mongoMeili.js index df96338302..6577370b1e 100644 --- a/api/models/plugins/mongoMeili.js +++ b/api/models/plugins/mongoMeili.js @@ -4,9 +4,28 @@ const { MeiliSearch } = require('meilisearch'); const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc'); const logger = require('~/config/meiliLogger'); +// Environment flags +/** + * Flag to indicate if search is enabled based on environment variables. + * @type {boolean} + */ const searchEnabled = process.env.SEARCH && process.env.SEARCH.toLowerCase() === 'true'; + +/** + * Flag to indicate if MeiliSearch is enabled based on required environment variables. + * @type {boolean} + */ const meiliEnabled = process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY && searchEnabled; +/** + * Validates the required options for configuring the mongoMeili plugin. + * + * @param {Object} options - The configuration options. + * @param {string} options.host - The MeiliSearch host. + * @param {string} options.apiKey - The MeiliSearch API key. + * @param {string} options.indexName - The name of the index. + * @throws {Error} Throws an error if any required option is missing.
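+ * @example + * // e.g., with hypothetical values: + * // validateOptions({ host: 'http://localhost:7700', apiKey: 'masterKey', indexName: 'convos' }); // passes + * // validateOptions({ host: 'http://localhost:7700' }); // throws, since apiKey and indexName are missing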
+ */ const validateOptions = function (options) { const requiredKeys = ['host', 'apiKey', 'indexName']; requiredKeys.forEach((key) => { @@ -16,53 +35,64 @@ const validateOptions = function (options) { }); }; -// const createMeiliMongooseModel = function ({ index, indexName, client, attributesToIndex }) { +/** + * Factory function to create a MeiliMongooseModel class which extends a Mongoose model. + * This class contains static and instance methods to synchronize and manage the MeiliSearch index + * corresponding to the MongoDB collection. + * + * @param {Object} config - Configuration object. + * @param {Object} config.index - The MeiliSearch index object. + * @param {Array} config.attributesToIndex - List of attributes to index. + * @returns {Function} A class definition that will be loaded into the Mongoose schema. + */ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { + // The primary key is assumed to be the first attribute in the attributesToIndex array. const primaryKey = attributesToIndex[0]; - // MeiliMongooseModel is of type Mongoose.Model + class MeiliMongooseModel { /** - * `syncWithMeili`: synchronizes the data between a MongoDB collection and a MeiliSearch index, - * only triggered if there's ever a discrepancy determined by `api\lib\db\indexSync.js`. + * Synchronizes the data between the MongoDB collection and the MeiliSearch index. * - * 1. Fetches all documents from the MongoDB collection and the MeiliSearch index. - * 2. Compares the documents from both sources. - * 3. If a document exists in MeiliSearch but not in MongoDB, it's deleted from MeiliSearch. - * 4. If a document exists in MongoDB but not in MeiliSearch, it's added to MeiliSearch. - * 5. If a document exists in both but has different `text` or `title` fields (depending on the `primaryKey`), it's updated in MeiliSearch. - * 6. After all operations, it updates the `_meiliIndex` field in MongoDB to indicate whether the document is indexed in MeiliSearch. + * The synchronization process involves: + * 1. Fetching all documents from the MongoDB collection and MeiliSearch index. + * 2. Comparing documents from both sources. + * 3. Deleting documents from MeiliSearch that no longer exist in MongoDB. + * 4. Adding documents to MeiliSearch that exist in MongoDB but not in the index. + * 5. Updating documents in MeiliSearch if key fields (such as `text` or `title`) differ. + * 6. Updating the `_meiliIndex` field in MongoDB to indicate the indexing status. * - * Note: This strategy does not use batch operations for Meilisearch as the `index.addDocuments` will discard - * the entire batch if there's an error with one document, and will not throw an error if there's an issue. - * Also, `index.getDocuments` needs an exact limit on the amount of documents to return, so we build the map in batches. + * Note: The function processes documents in batches because MeiliSearch's + * `index.getDocuments` requires an exact limit and `index.addDocuments` does not handle + * partial failures in a batch. * - * @returns {Promise} A promise that resolves when the synchronization is complete. - * - * @throws {Error} Throws an error if there's an issue with adding a document to MeiliSearch. + * @returns {Promise} Resolves when the synchronization is complete. */ static async syncWithMeili() { try { let moreDocuments = true; + // Retrieve all MongoDB documents from the collection as plain JavaScript objects. 
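+ // Note: this loads the whole collection into memory at once; only the MeiliSearch side below is paged, in batches of 1000 documents.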
const mongoDocuments = await this.find().lean(); - const format = (doc) => _.pick(doc, attributesToIndex); - // Prepare for comparison + // Helper function to format a document by selecting only the attributes to index + // and omitting keys starting with '$'. + const format = (doc) => + _.omitBy(_.pick(doc, attributesToIndex), (v, k) => k.startsWith('$')); + + // Build a map of MongoDB documents for quick lookup based on the primary key. const mongoMap = new Map(mongoDocuments.map((doc) => [doc[primaryKey], format(doc)])); const indexMap = new Map(); let offset = 0; const batchSize = 1000; + // Fetch documents from the MeiliSearch index in batches. while (moreDocuments) { const batch = await index.getDocuments({ limit: batchSize, offset }); - if (batch.results.length === 0) { moreDocuments = false; } - for (const doc of batch.results) { indexMap.set(doc[primaryKey], format(doc)); } - offset += batchSize; } @@ -70,13 +100,12 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { const updateOps = []; - // Iterate over Meili index documents + // Process documents present in the MeiliSearch index. for (const [id, doc] of indexMap) { const update = {}; update[primaryKey] = id; if (mongoMap.has(id)) { - // Case: Update - // If document also exists in MongoDB, would be update case + // If document exists in MongoDB, check for discrepancies in key fields. if ( (doc.text && doc.text !== mongoMap.get(id).text) || (doc.title && doc.title !== mongoMap.get(id).title) @@ -92,8 +121,7 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { await index.addDocuments([doc]); } } else { - // Case: Delete - // If document does not exist in MongoDB, its a delete case from meili index + // If the document does not exist in MongoDB, delete it from MeiliSearch. await index.deleteDocument(id); updateOps.push({ updateOne: { filter: update, update: { $set: { _meiliIndex: false } } }, @@ -101,24 +129,25 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { } } - // Iterate over MongoDB documents + // Process documents present in MongoDB. for (const [id, doc] of mongoMap) { const update = {}; update[primaryKey] = id; - // Case: Insert - // If document does not exist in Meili Index, Its an insert case + // If the document is missing in the Meili index, add it. if (!indexMap.has(id)) { await index.addDocuments([doc]); updateOps.push({ updateOne: { filter: update, update: { $set: { _meiliIndex: true } } }, }); } else if (doc._meiliIndex === false) { + // If the document exists but is marked as not indexed, update the flag. updateOps.push({ updateOne: { filter: update, update: { $set: { _meiliIndex: true } } }, }); } } + // Execute bulk update operations in MongoDB to update the _meiliIndex flags. if (updateOps.length > 0) { await this.collection.bulkWrite(updateOps); logger.debug( @@ -132,34 +161,47 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { } } - // Set one or more settings of the meili index + /** + * Updates settings for the MeiliSearch index. + * + * @param {Object} settings - The settings to update on the MeiliSearch index. + * @returns {Promise} Promise resolving to the update result. + */ static async setMeiliIndexSettings(settings) { return await index.updateSettings(settings); } - // Search the index + /** + * Searches the MeiliSearch index and optionally populates the results with data from MongoDB. + * + * @param {string} q - The search query. 
+ * @param {Object} params - Additional search parameters for MeiliSearch. + * @param {boolean} populate - Whether to populate search hits with full MongoDB documents. + * @returns {Promise} The search results with populated hits if requested. + */ static async meiliSearch(q, params, populate) { const data = await index.search(q, params); - // Populate hits with content from mongodb if (populate) { - // Find objects into mongodb matching `objectID` from Meili search + // Build a query using the primary key values from the search hits. const query = {}; - // query[primaryKey] = { $in: _.map(data.hits, primaryKey) }; query[primaryKey] = _.map(data.hits, (hit) => cleanUpPrimaryKeyValue(hit[primaryKey])); - // logger.debug('query', query); - const hitsFromMongoose = await this.find( - query, - _.reduce( - this.schema.obj, - function (results, value, key) { - return { ...results, [key]: 1 }; - }, - { _id: 1, __v: 1 }, - ), - ).lean(); - // Add additional data from mongodb into Meili search hits + // Build a projection object, including only keys that do not start with '$'. + const projection = Object.keys(this.schema.obj).reduce( + (results, key) => { + if (!key.startsWith('$')) { + results[key] = 1; + } + return results; + }, + { _id: 1, __v: 1 }, + ); + + // Retrieve the full documents from MongoDB. + const hitsFromMongoose = await this.find(query, projection).lean(); + + // Merge the MongoDB documents with the search hits. const populatedHits = data.hits.map(function (hit) { const query = {}; query[primaryKey] = hit[primaryKey]; @@ -176,10 +218,21 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { return data; } + /** + * Preprocesses the current document for indexing. + * + * This method: + * - Picks only the defined attributes to index. + * - Omits any keys starting with '$'. + * - Replaces pipe characters ('|') in `conversationId` with '--'. + * - Extracts and concatenates text from an array of content items. + * + * @returns {Object} The preprocessed object ready for indexing. + */ preprocessObjectForIndex() { - const object = _.pick(this.toJSON(), attributesToIndex); - // NOTE: MeiliSearch does not allow | in primary key, so we replace it with - for Bing convoIds - // object.conversationId = object.conversationId.replace(/\|/g, '-'); + const object = _.omitBy(_.pick(this.toJSON(), attributesToIndex), (v, k) => + k.startsWith('$'), + ); if (object.conversationId && object.conversationId.includes('|')) { object.conversationId = object.conversationId.replace(/\|/g, '--'); } @@ -195,32 +248,53 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { return object; } - // Push new document to Meili + /** + * Adds the current document to the MeiliSearch index. + * + * The method preprocesses the document, adds it to MeiliSearch, and then updates + * the MongoDB document's `_meiliIndex` flag to true. + * + * @returns {Promise} + */ async addObjectToMeili() { const object = this.preprocessObjectForIndex(); try { - // logger.debug('Adding document to Meili', object); await index.addDocuments([object]); } catch (error) { - // logger.debug('Error adding document to Meili'); - // logger.error(error); + // Error handling can be enhanced as needed. + logger.error('[addObjectToMeili] Error adding document to Meili', error); } await this.collection.updateMany({ _id: this._id }, { $set: { _meiliIndex: true } }); } - // Update an existing document in Meili + /** + * Updates the current document in the MeiliSearch index. 
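+ * Only the schema attributes collected for indexing are sent, and keys starting with '$' are omitted, mirroring the body below.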
+ * + * @returns {Promise} + */ async updateObjectToMeili() { - const object = _.pick(this.toJSON(), attributesToIndex); + const object = _.omitBy(_.pick(this.toJSON(), attributesToIndex), (v, k) => + k.startsWith('$'), + ); await index.updateDocuments([object]); } - // Delete a document from Meili + /** + * Deletes the current document from the MeiliSearch index. + * + * @returns {Promise} + */ async deleteObjectFromMeili() { await index.deleteDocument(this._id); } - // * schema.post('save') + /** + * Post-save hook to synchronize the document with MeiliSearch. + * + * If the document is already indexed (i.e. `_meiliIndex` is true), it updates it; + * otherwise, it adds the document to the index. + */ postSaveHook() { if (this._meiliIndex) { this.updateObjectToMeili(); @@ -229,14 +303,24 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { } } - // * schema.post('update') + /** + * Post-update hook to update the document in MeiliSearch. + * + * This hook is triggered after a document update, ensuring that changes are + * propagated to the MeiliSearch index if the document is indexed. + */ postUpdateHook() { if (this._meiliIndex) { this.updateObjectToMeili(); } } - // * schema.post('remove') + /** + * Post-remove hook to delete the document from MeiliSearch. + * + * This hook is triggered after a document is removed, ensuring that the document + * is also removed from the MeiliSearch index if it was previously indexed. + */ postRemoveHook() { if (this._meiliIndex) { this.deleteObjectFromMeili(); @@ -247,11 +331,27 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { return MeiliMongooseModel; }; +/** + * Mongoose plugin to synchronize MongoDB collections with a MeiliSearch index. + * + * This plugin: + * - Validates the provided options. + * - Adds a `_meiliIndex` field to the schema to track indexing status. + * - Sets up a MeiliSearch client and creates an index if it doesn't already exist. + * - Loads class methods for syncing, searching, and managing documents in MeiliSearch. + * - Registers Mongoose hooks (post-save, post-update, post-remove, etc.) to maintain index consistency. + * + * @param {mongoose.Schema} schema - The Mongoose schema to which the plugin is applied. + * @param {Object} options - Configuration options. + * @param {string} options.host - The MeiliSearch host. + * @param {string} options.apiKey - The MeiliSearch API key. + * @param {string} options.indexName - The name of the MeiliSearch index. + * @param {string} options.primaryKey - The primary key field for indexing. + */ module.exports = function mongoMeili(schema, options) { - // Vaidate Options for mongoMeili validateOptions(options); - // Add meiliIndex to schema + // Add _meiliIndex field to the schema to track if a document has been indexed in MeiliSearch. schema.add({ _meiliIndex: { type: Boolean, @@ -263,69 +363,77 @@ module.exports = function mongoMeili(schema, options) { const { host, apiKey, indexName, primaryKey } = options; - // Setup MeiliSearch Client + // Setup the MeiliSearch client. const client = new MeiliSearch({ host, apiKey }); - // Asynchronously create the index + // Create the index asynchronously if it doesn't exist. client.createIndex(indexName, { primaryKey }); - // Setup the index to search for this schema + // Setup the MeiliSearch index for this schema. const index = client.index(indexName); + // Collect attributes from the schema that should be indexed. 
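+ // (Every schema path flagged with `meiliIndex: true`; per the factory above, the first collected attribute doubles as the Meili primary key.)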
const attributesToIndex = [ ..._.reduce( schema.obj, function (results, value, key) { return value.meiliIndex ? [...results, key] : results; - // }, []), '_id']; }, [], ), ]; + // Load the class methods into the schema. schema.loadClass(createMeiliMongooseModel({ index, indexName, client, attributesToIndex })); - // Register hooks + // Register Mongoose hooks to synchronize with MeiliSearch. + + // Post-save: synchronize after a document is saved. schema.post('save', function (doc) { doc.postSaveHook(); }); + + // Post-update: synchronize after a document is updated. schema.post('update', function (doc) { doc.postUpdateHook(); }); + + // Post-remove: synchronize after a document is removed. schema.post('remove', function (doc) { doc.postRemoveHook(); }); + // Pre-deleteMany hook: remove corresponding documents from MeiliSearch when multiple documents are deleted. schema.pre('deleteMany', async function (next) { if (!meiliEnabled) { - next(); + return next(); } try { + // Check if the schema has a "messages" field to determine if it's a conversation schema. if (Object.prototype.hasOwnProperty.call(schema.obj, 'messages')) { const convoIndex = client.index('convos'); const deletedConvos = await mongoose.model('Conversation').find(this._conditions).lean(); - let promises = []; - for (const convo of deletedConvos) { - promises.push(convoIndex.deleteDocument(convo.conversationId)); - } + const promises = deletedConvos.map((convo) => + convoIndex.deleteDocument(convo.conversationId), + ); await Promise.all(promises); } + // Check if the schema has a "messageId" field to determine if it's a message schema. if (Object.prototype.hasOwnProperty.call(schema.obj, 'messageId')) { const messageIndex = client.index('messages'); const deletedMessages = await mongoose.model('Message').find(this._conditions).lean(); - let promises = []; - for (const message of deletedMessages) { - promises.push(messageIndex.deleteDocument(message.messageId)); - } + const promises = deletedMessages.map((message) => + messageIndex.deleteDocument(message.messageId), + ); await Promise.all(promises); } return next(); } catch (error) { if (meiliEnabled) { logger.error( - '[MeiliMongooseModel.deleteMany] There was an issue deleting conversation indexes upon deletion, next startup may be slow due to syncing', + '[MeiliMongooseModel.deleteMany] There was an issue deleting conversation indexes upon deletion. Next startup may be slow due to syncing.', error, ); } @@ -333,17 +441,19 @@ module.exports = function mongoMeili(schema, options) { } }); + // Post-findOneAndUpdate hook: update MeiliSearch index after a document is updated via findOneAndUpdate. schema.post('findOneAndUpdate', async function (doc) { if (!meiliEnabled) { return; } + // If the document is unfinished, do not update the index. if (doc.unfinished) { return; } let meiliDoc; - // Doc is a Conversation + // For conversation documents, try to fetch the document from the "convos" index. if (doc.messages) { try { meiliDoc = await client.index('convos').getDocument(doc.conversationId); @@ -356,10 +466,12 @@ module.exports = function mongoMeili(schema, options) { } } + // If the MeiliSearch document exists and the title is unchanged, do nothing. if (meiliDoc && meiliDoc.title === doc.title) { return; } + // Otherwise, trigger a post-save hook to synchronize the document. 
doc.postSaveHook(); }); }; diff --git a/api/models/schema/action.js b/api/models/schema/action.js deleted file mode 100644 index f86a9bfa2d..0000000000 --- a/api/models/schema/action.js +++ /dev/null @@ -1,60 +0,0 @@ -const mongoose = require('mongoose'); - -const { Schema } = mongoose; - -const AuthSchema = new Schema( - { - authorization_type: String, - custom_auth_header: String, - type: { - type: String, - enum: ['service_http', 'oauth', 'none'], - }, - authorization_content_type: String, - authorization_url: String, - client_url: String, - scope: String, - token_exchange_method: { - type: String, - enum: ['default_post', 'basic_auth_header', null], - }, - }, - { _id: false }, -); - -const actionSchema = new Schema({ - user: { - type: mongoose.Schema.Types.ObjectId, - ref: 'User', - index: true, - required: true, - }, - action_id: { - type: String, - index: true, - required: true, - }, - type: { - type: String, - default: 'action_prototype', - }, - settings: Schema.Types.Mixed, - agent_id: String, - assistant_id: String, - metadata: { - api_key: String, // private, encrypted - auth: AuthSchema, - domain: { - type: String, - required: true, - }, - // json_schema: Schema.Types.Mixed, - privacy_policy_url: String, - raw_spec: String, - oauth_client_id: String, // private, encrypted - oauth_client_secret: String, // private, encrypted - }, -}); -// }, { minimize: false }); // Prevent removal of empty objects - -module.exports = actionSchema; diff --git a/api/models/schema/balance.js b/api/models/schema/balance.js deleted file mode 100644 index 8ca8116e09..0000000000 --- a/api/models/schema/balance.js +++ /dev/null @@ -1,17 +0,0 @@ -const mongoose = require('mongoose'); - -const balanceSchema = mongoose.Schema({ - user: { - type: mongoose.Schema.Types.ObjectId, - ref: 'User', - index: true, - required: true, - }, - // 1000 tokenCredits = 1 mill ($0.001 USD) - tokenCredits: { - type: Number, - default: 0, - }, -}); - -module.exports = balanceSchema; diff --git a/api/models/schema/categories.js b/api/models/schema/categories.js deleted file mode 100644 index 3167685667..0000000000 --- a/api/models/schema/categories.js +++ /dev/null @@ -1,19 +0,0 @@ -const mongoose = require('mongoose'); -const Schema = mongoose.Schema; - -const categoriesSchema = new Schema({ - label: { - type: String, - required: true, - unique: true, - }, - value: { - type: String, - required: true, - unique: true, - }, -}); - -const categories = mongoose.model('categories', categoriesSchema); - -module.exports = { Categories: categories }; diff --git a/api/models/schema/conversationTagSchema.js b/api/models/schema/conversationTagSchema.js deleted file mode 100644 index 9b2a98c6d8..0000000000 --- a/api/models/schema/conversationTagSchema.js +++ /dev/null @@ -1,32 +0,0 @@ -const mongoose = require('mongoose'); - -const conversationTagSchema = mongoose.Schema( - { - tag: { - type: String, - index: true, - }, - user: { - type: String, - index: true, - }, - description: { - type: String, - index: true, - }, - count: { - type: Number, - default: 0, - }, - position: { - type: Number, - default: 0, - index: true, - }, - }, - { timestamps: true }, -); - -conversationTagSchema.index({ tag: 1, user: 1 }, { unique: true }); - -module.exports = mongoose.model('ConversationTag', conversationTagSchema); diff --git a/api/models/schema/convoSchema.js b/api/models/schema/convoSchema.js index 7d8beed6a6..89cb9c80b5 100644 --- a/api/models/schema/convoSchema.js +++ b/api/models/schema/convoSchema.js @@ -1,63 +1,18 @@ const mongoose = 
require('mongoose'); const mongoMeili = require('../plugins/mongoMeili'); -const { conversationPreset } = require('./defaults'); -const convoSchema = mongoose.Schema( - { - conversationId: { - type: String, - unique: true, - required: true, - index: true, - meiliIndex: true, - }, - title: { - type: String, - default: 'New Chat', - meiliIndex: true, - }, - user: { - type: String, - index: true, - }, - messages: [{ type: mongoose.Schema.Types.ObjectId, ref: 'Message' }], - // google only - examples: { type: [{ type: mongoose.Schema.Types.Mixed }], default: undefined }, - agentOptions: { - type: mongoose.Schema.Types.Mixed, - }, - ...conversationPreset, - agent_id: { - type: String, - }, - tags: { - type: [String], - default: [], - meiliIndex: true, - }, - files: { - type: [String], - }, - expiredAt: { - type: Date, - }, - }, - { timestamps: true }, -); + +const { convoSchema } = require('@librechat/data-schemas'); if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) { convoSchema.plugin(mongoMeili, { host: process.env.MEILI_HOST, apiKey: process.env.MEILI_MASTER_KEY, - indexName: 'convos', // Will get created automatically if it doesn't exist already + /** Note: Will get created automatically if it doesn't exist already */ + indexName: 'convos', primaryKey: 'conversationId', }); } -// Create TTL index -convoSchema.index({ expiredAt: 1 }, { expireAfterSeconds: 0 }); -convoSchema.index({ createdAt: 1, updatedAt: 1 }); -convoSchema.index({ conversationId: 1, user: 1 }, { unique: true }); - const Conversation = mongoose.models.Conversation || mongoose.model('Conversation', convoSchema); module.exports = Conversation; diff --git a/api/models/schema/fileSchema.js b/api/models/schema/fileSchema.js deleted file mode 100644 index 77c6ff94d4..0000000000 --- a/api/models/schema/fileSchema.js +++ /dev/null @@ -1,111 +0,0 @@ -const { FileSources } = require('librechat-data-provider'); -const mongoose = require('mongoose'); - -/** - * @typedef {Object} MongoFile - * @property {ObjectId} [_id] - MongoDB Document ID - * @property {number} [__v] - MongoDB Version Key - * @property {ObjectId} user - User ID - * @property {string} [conversationId] - Optional conversation ID - * @property {string} file_id - File identifier - * @property {string} [temp_file_id] - Temporary File identifier - * @property {number} bytes - Size of the file in bytes - * @property {string} filename - Name of the file - * @property {string} filepath - Location of the file - * @property {'file'} object - Type of object, always 'file' - * @property {string} type - Type of file - * @property {number} [usage=0] - Number of uses of the file - * @property {string} [context] - Context of the file origin - * @property {boolean} [embedded=false] - Whether or not the file is embedded in vector db - * @property {string} [model] - The model to identify the group region of the file (for Azure OpenAI hosting) - * @property {string} [source] - The source of the file (e.g., from FileSources) - * @property {number} [width] - Optional width of the file - * @property {number} [height] - Optional height of the file - * @property {Object} [metadata] - Metadata related to the file - * @property {string} [metadata.fileIdentifier] - Unique identifier for the file in metadata - * @property {Date} [expiresAt] - Optional expiration date of the file - * @property {Date} [createdAt] - Date when the file was created - * @property {Date} [updatedAt] - Date when the file was updated - */ - -/** @type {MongooseSchema} */ -const fileSchema = mongoose.Schema( - { 
- user: { - type: mongoose.Schema.Types.ObjectId, - ref: 'User', - index: true, - required: true, - }, - conversationId: { - type: String, - ref: 'Conversation', - index: true, - }, - file_id: { - type: String, - // required: true, - index: true, - }, - temp_file_id: { - type: String, - // required: true, - }, - bytes: { - type: Number, - required: true, - }, - filename: { - type: String, - required: true, - }, - filepath: { - type: String, - required: true, - }, - object: { - type: String, - required: true, - default: 'file', - }, - embedded: { - type: Boolean, - }, - type: { - type: String, - required: true, - }, - context: { - type: String, - // required: true, - }, - usage: { - type: Number, - required: true, - default: 0, - }, - source: { - type: String, - default: FileSources.local, - }, - model: { - type: String, - }, - width: Number, - height: Number, - metadata: { - fileIdentifier: String, - }, - expiresAt: { - type: Date, - expires: 3600, // 1 hour in seconds - }, - }, - { - timestamps: true, - }, -); - -fileSchema.index({ createdAt: 1, updatedAt: 1 }); - -module.exports = fileSchema; diff --git a/api/models/schema/messageSchema.js b/api/models/schema/messageSchema.js index be71155295..cf97b84eea 100644 --- a/api/models/schema/messageSchema.js +++ b/api/models/schema/messageSchema.js @@ -1,145 +1,6 @@ const mongoose = require('mongoose'); const mongoMeili = require('~/models/plugins/mongoMeili'); -const messageSchema = mongoose.Schema( - { - messageId: { - type: String, - unique: true, - required: true, - index: true, - meiliIndex: true, - }, - conversationId: { - type: String, - index: true, - required: true, - meiliIndex: true, - }, - user: { - type: String, - index: true, - required: true, - default: null, - }, - model: { - type: String, - default: null, - }, - endpoint: { - type: String, - }, - conversationSignature: { - type: String, - }, - clientId: { - type: String, - }, - invocationId: { - type: Number, - }, - parentMessageId: { - type: String, - }, - tokenCount: { - type: Number, - }, - summaryTokenCount: { - type: Number, - }, - sender: { - type: String, - meiliIndex: true, - }, - text: { - type: String, - meiliIndex: true, - }, - summary: { - type: String, - }, - isCreatedByUser: { - type: Boolean, - required: true, - default: false, - }, - unfinished: { - type: Boolean, - default: false, - }, - error: { - type: Boolean, - default: false, - }, - finish_reason: { - type: String, - }, - _meiliIndex: { - type: Boolean, - required: false, - select: false, - default: false, - }, - files: { type: [{ type: mongoose.Schema.Types.Mixed }], default: undefined }, - plugin: { - type: { - latest: { - type: String, - required: false, - }, - inputs: { - type: [mongoose.Schema.Types.Mixed], - required: false, - default: undefined, - }, - outputs: { - type: String, - required: false, - }, - }, - default: undefined, - }, - plugins: { type: [{ type: mongoose.Schema.Types.Mixed }], default: undefined }, - content: { - type: [{ type: mongoose.Schema.Types.Mixed }], - default: undefined, - meiliIndex: true, - }, - thread_id: { - type: String, - }, - /* frontend components */ - iconURL: { - type: String, - }, - attachments: { type: [{ type: mongoose.Schema.Types.Mixed }], default: undefined }, - /* - attachments: { - type: [ - { - file_id: String, - filename: String, - filepath: String, - expiresAt: Date, - width: Number, - height: Number, - type: String, - conversationId: String, - messageId: { - type: String, - required: true, - }, - toolCallId: String, - }, - ], - default: undefined, - }, 
- */ - expiredAt: { - type: Date, - }, - }, - { timestamps: true }, -); +const { messageSchema } = require('@librechat/data-schemas'); if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) { messageSchema.plugin(mongoMeili, { @@ -149,11 +10,7 @@ if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) { primaryKey: 'messageId', }); } -messageSchema.index({ expiredAt: 1 }, { expireAfterSeconds: 0 }); -messageSchema.index({ createdAt: 1 }); -messageSchema.index({ messageId: 1, user: 1 }, { unique: true }); -/** @type {mongoose.Model} */ const Message = mongoose.models.Message || mongoose.model('Message', messageSchema); module.exports = Message; diff --git a/api/models/schema/pluginAuthSchema.js b/api/models/schema/pluginAuthSchema.js index 4b4251dda3..2066eda4c4 100644 --- a/api/models/schema/pluginAuthSchema.js +++ b/api/models/schema/pluginAuthSchema.js @@ -1,25 +1,5 @@ const mongoose = require('mongoose'); - -const pluginAuthSchema = mongoose.Schema( - { - authField: { - type: String, - required: true, - }, - value: { - type: String, - required: true, - }, - userId: { - type: String, - required: true, - }, - pluginKey: { - type: String, - }, - }, - { timestamps: true }, -); +const { pluginAuthSchema } = require('@librechat/data-schemas'); const PluginAuth = mongoose.models.Plugin || mongoose.model('PluginAuth', pluginAuthSchema); diff --git a/api/models/schema/presetSchema.js b/api/models/schema/presetSchema.js index e1c92ab9c0..6d03803ace 100644 --- a/api/models/schema/presetSchema.js +++ b/api/models/schema/presetSchema.js @@ -1,38 +1,5 @@ const mongoose = require('mongoose'); -const { conversationPreset } = require('./defaults'); -const presetSchema = mongoose.Schema( - { - presetId: { - type: String, - unique: true, - required: true, - index: true, - }, - title: { - type: String, - default: 'New Chat', - meiliIndex: true, - }, - user: { - type: String, - default: null, - }, - defaultPreset: { - type: Boolean, - }, - order: { - type: Number, - }, - // google only - examples: [{ type: mongoose.Schema.Types.Mixed }], - ...conversationPreset, - agentOptions: { - type: mongoose.Schema.Types.Mixed, - default: null, - }, - }, - { timestamps: true }, -); +const { presetSchema } = require('@librechat/data-schemas'); const Preset = mongoose.models.Preset || mongoose.model('Preset', presetSchema); diff --git a/api/models/schema/projectSchema.js b/api/models/schema/projectSchema.js deleted file mode 100644 index dfa68a06c2..0000000000 --- a/api/models/schema/projectSchema.js +++ /dev/null @@ -1,35 +0,0 @@ -const { Schema } = require('mongoose'); - -/** - * @typedef {Object} MongoProject - * @property {ObjectId} [_id] - MongoDB Document ID - * @property {string} name - The name of the project - * @property {ObjectId[]} promptGroupIds - Array of PromptGroup IDs associated with the project - * @property {Date} [createdAt] - Date when the project was created (added by timestamps) - * @property {Date} [updatedAt] - Date when the project was last updated (added by timestamps) - */ - -const projectSchema = new Schema( - { - name: { - type: String, - required: true, - index: true, - }, - promptGroupIds: { - type: [Schema.Types.ObjectId], - ref: 'PromptGroup', - default: [], - }, - agentIds: { - type: [String], - ref: 'Agent', - default: [], - }, - }, - { - timestamps: true, - }, -); - -module.exports = projectSchema; diff --git a/api/models/schema/promptSchema.js b/api/models/schema/promptSchema.js deleted file mode 100644 index 5464caf639..0000000000 --- a/api/models/schema/promptSchema.js +++ 
/dev/null @@ -1,118 +0,0 @@ -const mongoose = require('mongoose'); -const { Constants } = require('librechat-data-provider'); -const Schema = mongoose.Schema; - -/** - * @typedef {Object} MongoPromptGroup - * @property {ObjectId} [_id] - MongoDB Document ID - * @property {string} name - The name of the prompt group - * @property {ObjectId} author - The author of the prompt group - * @property {ObjectId} [projectId=null] - The project ID of the prompt group - * @property {ObjectId} [productionId=null] - The project ID of the prompt group - * @property {string} authorName - The name of the author of the prompt group - * @property {number} [numberOfGenerations=0] - Number of generations the prompt group has - * @property {string} [oneliner=''] - Oneliner description of the prompt group - * @property {string} [category=''] - Category of the prompt group - * @property {string} [command] - Command for the prompt group - * @property {Date} [createdAt] - Date when the prompt group was created (added by timestamps) - * @property {Date} [updatedAt] - Date when the prompt group was last updated (added by timestamps) - */ - -const promptGroupSchema = new Schema( - { - name: { - type: String, - required: true, - index: true, - }, - numberOfGenerations: { - type: Number, - default: 0, - }, - oneliner: { - type: String, - default: '', - }, - category: { - type: String, - default: '', - index: true, - }, - projectIds: { - type: [Schema.Types.ObjectId], - ref: 'Project', - index: true, - }, - productionId: { - type: Schema.Types.ObjectId, - ref: 'Prompt', - required: true, - index: true, - }, - author: { - type: Schema.Types.ObjectId, - ref: 'User', - required: true, - index: true, - }, - authorName: { - type: String, - required: true, - }, - command: { - type: String, - index: true, - validate: { - validator: function (v) { - return v === undefined || v === null || v === '' || /^[a-z0-9-]+$/.test(v); - }, - message: (props) => - `${props.value} is not a valid command. 
Only lowercase alphanumeric characters and hyphens (-) are allowed.`, - }, - maxlength: [ - Constants.COMMANDS_MAX_LENGTH, - `Command cannot be longer than ${Constants.COMMANDS_MAX_LENGTH} characters`, - ], - }, - }, - { - timestamps: true, - }, -); - -const PromptGroup = mongoose.model('PromptGroup', promptGroupSchema); - -const promptSchema = new Schema( - { - groupId: { - type: Schema.Types.ObjectId, - ref: 'PromptGroup', - required: true, - index: true, - }, - author: { - type: Schema.Types.ObjectId, - ref: 'User', - required: true, - }, - prompt: { - type: String, - required: true, - }, - type: { - type: String, - enum: ['text', 'chat'], - required: true, - }, - }, - { - timestamps: true, - }, -); - -const Prompt = mongoose.model('Prompt', promptSchema); - -promptSchema.index({ createdAt: 1, updatedAt: 1 }); -promptGroupSchema.index({ createdAt: 1, updatedAt: 1 }); - -module.exports = { Prompt, PromptGroup }; diff --git a/api/models/schema/roleSchema.js b/api/models/schema/roleSchema.js deleted file mode 100644 index 36e9d3f7b6..0000000000 --- a/api/models/schema/roleSchema.js +++ /dev/null @@ -1,55 +0,0 @@ -const { PermissionTypes, Permissions } = require('librechat-data-provider'); -const mongoose = require('mongoose'); - -const roleSchema = new mongoose.Schema({ - name: { - type: String, - required: true, - unique: true, - index: true, - }, - [PermissionTypes.BOOKMARKS]: { - [Permissions.USE]: { - type: Boolean, - default: true, - }, - }, - [PermissionTypes.PROMPTS]: { - [Permissions.SHARED_GLOBAL]: { - type: Boolean, - default: false, - }, - [Permissions.USE]: { - type: Boolean, - default: true, - }, - [Permissions.CREATE]: { - type: Boolean, - default: true, - }, - }, - [PermissionTypes.AGENTS]: { - [Permissions.SHARED_GLOBAL]: { - type: Boolean, - default: false, - }, - [Permissions.USE]: { - type: Boolean, - default: true, - }, - [Permissions.CREATE]: { - type: Boolean, - default: true, - }, - }, - [PermissionTypes.MULTI_CONVO]: { - [Permissions.USE]: { - type: Boolean, - default: true, - }, - }, -}); - -const Role = mongoose.model('Role', roleSchema); - -module.exports = Role; diff --git a/api/models/schema/session.js b/api/models/schema/session.js deleted file mode 100644 index ccda43573d..0000000000 --- a/api/models/schema/session.js +++ /dev/null @@ -1,20 +0,0 @@ -const mongoose = require('mongoose'); - -const sessionSchema = mongoose.Schema({ - refreshTokenHash: { - type: String, - required: true, - }, - expiration: { - type: Date, - required: true, - expires: 0, - }, - user: { - type: mongoose.Schema.Types.ObjectId, - ref: 'User', - required: true, - }, -}); - -module.exports = sessionSchema; diff --git a/api/models/schema/toolCallSchema.js b/api/models/schema/toolCallSchema.js deleted file mode 100644 index 2af4c67c1b..0000000000 --- a/api/models/schema/toolCallSchema.js +++ /dev/null @@ -1,54 +0,0 @@ -const mongoose = require('mongoose'); - -/** - * @typedef {Object} ToolCallData - * @property {string} conversationId - The ID of the conversation - * @property {string} messageId - The ID of the message - * @property {string} toolId - The ID of the tool - * @property {string | ObjectId} user - The user's ObjectId - * @property {unknown} [result] - Optional result data - * @property {TAttachment[]} [attachments] - Optional attachments data - * @property {number} [blockIndex] - Optional code block index - * @property {number} [partIndex] - Optional part index - */ - -/** @type {MongooseSchema} */ -const toolCallSchema = mongoose.Schema( - { - conversationId: { - type:
String, - required: true, - }, - messageId: { - type: String, - required: true, - }, - toolId: { - type: String, - required: true, - }, - user: { - type: mongoose.Schema.Types.ObjectId, - ref: 'User', - required: true, - }, - result: { - type: mongoose.Schema.Types.Mixed, - }, - attachments: { - type: mongoose.Schema.Types.Mixed, - }, - blockIndex: { - type: Number, - }, - partIndex: { - type: Number, - }, - }, - { timestamps: true }, -); - -toolCallSchema.index({ messageId: 1, user: 1 }); -toolCallSchema.index({ conversationId: 1, user: 1 }); - -module.exports = mongoose.model('ToolCall', toolCallSchema); diff --git a/api/models/schema/userSchema.js b/api/models/schema/userSchema.js deleted file mode 100644 index f586553367..0000000000 --- a/api/models/schema/userSchema.js +++ /dev/null @@ -1,140 +0,0 @@ -const mongoose = require('mongoose'); -const { SystemRoles } = require('librechat-data-provider'); - -/** - * @typedef {Object} MongoSession - * @property {string} [refreshToken] - The refresh token - */ - -/** - * @typedef {Object} MongoUser - * @property {ObjectId} [_id] - MongoDB Document ID - * @property {string} [name] - The user's name - * @property {string} [username] - The user's username, in lowercase - * @property {string} email - The user's email address - * @property {boolean} emailVerified - Whether the user's email is verified - * @property {string} [password] - The user's password, trimmed with 8-128 characters - * @property {string} [avatar] - The URL of the user's avatar - * @property {string} provider - The provider of the user's account (e.g., 'local', 'google') - * @property {string} [role='USER'] - The role of the user - * @property {string} [googleId] - Optional Google ID for the user - * @property {string} [facebookId] - Optional Facebook ID for the user - * @property {string} [openidId] - Optional OpenID ID for the user - * @property {string} [ldapId] - Optional LDAP ID for the user - * @property {string} [githubId] - Optional GitHub ID for the user - * @property {string} [discordId] - Optional Discord ID for the user - * @property {string} [appleId] - Optional Apple ID for the user - * @property {Array} [plugins=[]] - List of plugins used by the user - * @property {Array.} [refreshToken] - List of sessions with refresh tokens - * @property {Date} [expiresAt] - Optional expiration date of the file - * @property {Date} [createdAt] - Date when the user was created (added by timestamps) - * @property {Date} [updatedAt] - Date when the user was last updated (added by timestamps) - */ - -/** @type {MongooseSchema} */ -const Session = mongoose.Schema({ - refreshToken: { - type: String, - default: '', - }, -}); - -/** @type {MongooseSchema} */ -const userSchema = mongoose.Schema( - { - name: { - type: String, - }, - username: { - type: String, - lowercase: true, - default: '', - }, - email: { - type: String, - required: [true, 'can\'t be blank'], - lowercase: true, - unique: true, - match: [/\S+@\S+\.\S+/, 'is invalid'], - index: true, - }, - emailVerified: { - type: Boolean, - required: true, - default: false, - }, - password: { - type: String, - trim: true, - minlength: 8, - maxlength: 128, - }, - avatar: { - type: String, - required: false, - }, - provider: { - type: String, - required: true, - default: 'local', - }, - role: { - type: String, - default: SystemRoles.USER, - }, - googleId: { - type: String, - unique: true, - sparse: true, - }, - facebookId: { - type: String, - unique: true, - sparse: true, - }, - openidId: { - type: String, - unique: true, - sparse: 
true, - }, - ldapId: { - type: String, - unique: true, - sparse: true, - }, - githubId: { - type: String, - unique: true, - sparse: true, - }, - discordId: { - type: String, - unique: true, - sparse: true, - }, - appleId: { - type: String, - unique: true, - sparse: true, - }, - plugins: { - type: Array, - default: [], - }, - refreshToken: { - type: [Session], - }, - expiresAt: { - type: Date, - expires: 604800, // 7 days in seconds - }, - termsAccepted: { - type: Boolean, - default: false, - }, - }, - - { timestamps: true }, -); - -module.exports = userSchema; diff --git a/api/models/tx.js b/api/models/tx.js index 05412430c7..67301d0c49 100644 --- a/api/models/tx.js +++ b/api/models/tx.js @@ -61,6 +61,7 @@ const bedrockValues = { 'amazon.nova-micro-v1:0': { prompt: 0.035, completion: 0.14 }, 'amazon.nova-lite-v1:0': { prompt: 0.06, completion: 0.24 }, 'amazon.nova-pro-v1:0': { prompt: 0.8, completion: 3.2 }, + 'deepseek.r1': { prompt: 1.35, completion: 5.4 }, }; /** @@ -79,6 +80,7 @@ const tokenValues = Object.assign( 'o1-mini': { prompt: 1.1, completion: 4.4 }, 'o1-preview': { prompt: 15, completion: 60 }, o1: { prompt: 15, completion: 60 }, + 'gpt-4.5': { prompt: 75, completion: 150 }, 'gpt-4o-mini': { prompt: 0.15, completion: 0.6 }, 'gpt-4o': { prompt: 2.5, completion: 10 }, 'gpt-4o-2024-05-13': { prompt: 5, completion: 15 }, @@ -88,6 +90,8 @@ const tokenValues = Object.assign( 'claude-3-sonnet': { prompt: 3, completion: 15 }, 'claude-3-5-sonnet': { prompt: 3, completion: 15 }, 'claude-3.5-sonnet': { prompt: 3, completion: 15 }, + 'claude-3-7-sonnet': { prompt: 3, completion: 15 }, + 'claude-3.7-sonnet': { prompt: 3, completion: 15 }, 'claude-3-5-haiku': { prompt: 0.8, completion: 4 }, 'claude-3.5-haiku': { prompt: 0.8, completion: 4 }, 'claude-3-haiku': { prompt: 0.25, completion: 1.25 }, @@ -110,6 +114,14 @@ const tokenValues = Object.assign( 'gemini-1.5': { prompt: 2.5, completion: 10 }, 'gemini-pro-vision': { prompt: 0.5, completion: 1.5 }, gemini: { prompt: 0.5, completion: 1.5 }, + 'grok-2-vision-1212': { prompt: 2.0, completion: 10.0 }, + 'grok-2-vision-latest': { prompt: 2.0, completion: 10.0 }, + 'grok-2-vision': { prompt: 2.0, completion: 10.0 }, + 'grok-vision-beta': { prompt: 5.0, completion: 15.0 }, + 'grok-2-1212': { prompt: 2.0, completion: 10.0 }, + 'grok-2-latest': { prompt: 2.0, completion: 10.0 }, + 'grok-2': { prompt: 2.0, completion: 10.0 }, + 'grok-beta': { prompt: 5.0, completion: 15.0 }, }, bedrockValues, ); @@ -121,6 +133,8 @@ const tokenValues = Object.assign( * @type {Object.} */ const cacheTokenValues = { + 'claude-3.7-sonnet': { write: 3.75, read: 0.3 }, + 'claude-3-7-sonnet': { write: 3.75, read: 0.3 }, 'claude-3.5-sonnet': { write: 3.75, read: 0.3 }, 'claude-3-5-sonnet': { write: 3.75, read: 0.3 }, 'claude-3.5-haiku': { write: 1, read: 0.08 }, @@ -155,6 +169,8 @@ const getValueKey = (model, endpoint) => { return 'o1-mini'; } else if (modelName.includes('o1')) { return 'o1'; + } else if (modelName.includes('gpt-4.5')) { + return 'gpt-4.5'; } else if (modelName.includes('gpt-4o-2024-05-13')) { return 'gpt-4o-2024-05-13'; } else if (modelName.includes('gpt-4o-mini')) { diff --git a/api/models/tx.spec.js b/api/models/tx.spec.js index d77973a7f5..f612e222bb 100644 --- a/api/models/tx.spec.js +++ b/api/models/tx.spec.js @@ -50,6 +50,16 @@ describe('getValueKey', () => { expect(getValueKey('gpt-4-0125')).toBe('gpt-4-1106'); }); + it('should return "gpt-4.5" for model type of "gpt-4.5"', () => { + expect(getValueKey('gpt-4.5-preview')).toBe('gpt-4.5'); + 
expect(getValueKey('gpt-4.5-2024-08-06')).toBe('gpt-4.5'); + expect(getValueKey('gpt-4.5-2024-08-06-0718')).toBe('gpt-4.5'); + expect(getValueKey('openai/gpt-4.5')).toBe('gpt-4.5'); + expect(getValueKey('openai/gpt-4.5-2024-08-06')).toBe('gpt-4.5'); + expect(getValueKey('gpt-4.5-turbo')).toBe('gpt-4.5'); + expect(getValueKey('gpt-4.5-0125')).toBe('gpt-4.5'); + }); + it('should return "gpt-4o" for model type of "gpt-4o"', () => { expect(getValueKey('gpt-4o-2024-08-06')).toBe('gpt-4o'); expect(getValueKey('gpt-4o-2024-08-06-0718')).toBe('gpt-4o'); @@ -80,6 +90,20 @@ describe('getValueKey', () => { expect(getValueKey('chatgpt-4o-latest-0718')).toBe('gpt-4o'); }); + it('should return "claude-3-7-sonnet" for model type of "claude-3-7-sonnet-"', () => { + expect(getValueKey('claude-3-7-sonnet-20240620')).toBe('claude-3-7-sonnet'); + expect(getValueKey('anthropic/claude-3-7-sonnet')).toBe('claude-3-7-sonnet'); + expect(getValueKey('claude-3-7-sonnet-turbo')).toBe('claude-3-7-sonnet'); + expect(getValueKey('claude-3-7-sonnet-0125')).toBe('claude-3-7-sonnet'); + }); + + it('should return "claude-3.7-sonnet" for model type of "claude-3.7-sonnet-"', () => { + expect(getValueKey('claude-3.7-sonnet-20240620')).toBe('claude-3.7-sonnet'); + expect(getValueKey('anthropic/claude-3.7-sonnet')).toBe('claude-3.7-sonnet'); + expect(getValueKey('claude-3.7-sonnet-turbo')).toBe('claude-3.7-sonnet'); + expect(getValueKey('claude-3.7-sonnet-0125')).toBe('claude-3.7-sonnet'); + }); + it('should return "claude-3-5-sonnet" for model type of "claude-3-5-sonnet-"', () => { expect(getValueKey('claude-3-5-sonnet-20240620')).toBe('claude-3-5-sonnet'); expect(getValueKey('anthropic/claude-3-5-sonnet')).toBe('claude-3-5-sonnet'); @@ -264,7 +288,7 @@ describe('AWS Bedrock Model Tests', () => { }); describe('Deepseek Model Tests', () => { - const deepseekModels = ['deepseek-chat', 'deepseek-coder', 'deepseek-reasoner']; + const deepseekModels = ['deepseek-chat', 'deepseek-coder', 'deepseek-reasoner', 'deepseek.r1']; it('should return the correct prompt multipliers for all models', () => { const results = deepseekModels.map((model) => { @@ -458,3 +482,30 @@ describe('Google Model Tests', () => { }); }); }); + +describe('Grok Model Tests - Pricing', () => { + describe('getMultiplier', () => { + test('should return correct prompt and completion rates for Grok vision models', () => { + const models = ['grok-2-vision-1212', 'grok-2-vision', 'grok-2-vision-latest']; + models.forEach((model) => { + expect(getMultiplier({ model, tokenType: 'prompt' })).toBe(2.0); + expect(getMultiplier({ model, tokenType: 'completion' })).toBe(10.0); + }); + }); + + test('should return correct prompt and completion rates for Grok text models', () => { + const models = ['grok-2-1212', 'grok-2', 'grok-2-latest']; + models.forEach((model) => { + expect(getMultiplier({ model, tokenType: 'prompt' })).toBe(2.0); + expect(getMultiplier({ model, tokenType: 'completion' })).toBe(10.0); + }); + }); + + test('should return correct prompt and completion rates for Grok beta models', () => { + expect(getMultiplier({ model: 'grok-vision-beta', tokenType: 'prompt' })).toBe(5.0); + expect(getMultiplier({ model: 'grok-vision-beta', tokenType: 'completion' })).toBe(15.0); + expect(getMultiplier({ model: 'grok-beta', tokenType: 'prompt' })).toBe(5.0); + expect(getMultiplier({ model: 'grok-beta', tokenType: 'completion' })).toBe(15.0); + }); + }); +}); diff --git a/api/package.json b/api/package.json index 8d5a997e6e..b04b049ce8 100644 --- a/api/package.json +++ 
b/api/package.json @@ -1,6 +1,6 @@ { "name": "@librechat/backend", - "version": "v0.7.7-rc1", + "version": "v0.7.7", "description": "", "scripts": { "start": "echo 'please run this from the root directory'", @@ -34,20 +34,25 @@ }, "homepage": "https://librechat.ai", "dependencies": { - "@anthropic-ai/sdk": "^0.32.1", + "@anthropic-ai/sdk": "^0.37.0", + "@aws-sdk/client-s3": "^3.758.0", + "@aws-sdk/s3-request-presigner": "^3.758.0", + "@azure/identity": "^4.7.0", "@azure/search-documents": "^12.0.0", - "@google/generative-ai": "^0.21.0", + "@azure/storage-blob": "^12.26.0", + "@google/generative-ai": "^0.23.0", "@googleapis/youtube": "^20.0.0", "@keyv/mongo": "^2.1.8", "@keyv/redis": "^2.8.1", - "@langchain/community": "^0.3.14", - "@langchain/core": "^0.3.37", - "@langchain/google-genai": "^0.1.7", - "@langchain/google-vertexai": "^0.1.8", + "@langchain/community": "^0.3.34", + "@langchain/core": "^0.3.40", + "@langchain/google-genai": "^0.1.11", + "@langchain/google-vertexai": "^0.2.2", "@langchain/textsplitters": "^0.1.0", - "@librechat/agents": "^2.0.4", + "@librechat/agents": "^2.3.8", + "@librechat/data-schemas": "*", "@waylaidwanderer/fetch-event-source": "^3.0.1", - "axios": "1.7.8", + "axios": "^1.8.2", "bcryptjs": "^2.4.3", "cohere-ai": "^7.9.1", "compression": "^1.7.4", @@ -57,21 +62,23 @@ "cors": "^2.8.5", "dedent": "^1.5.3", "dotenv": "^16.0.3", + "eventsource": "^3.0.2", "express": "^4.21.2", "express-mongo-sanitize": "^2.2.0", "express-rate-limit": "^7.4.1", "express-session": "^1.18.1", + "express-static-gzip": "^2.2.0", "file-type": "^18.7.0", "firebase": "^11.0.2", "googleapis": "^126.0.1", "handlebars": "^4.7.7", + "https-proxy-agent": "^7.0.6", "ioredis": "^5.3.2", "js-yaml": "^4.1.0", "jsonwebtoken": "^9.0.0", "keyv": "^4.5.4", "keyv-file": "^0.2.0", "klona": "^2.0.6", - "langchain": "^0.2.19", "librechat-data-provider": "*", "librechat-mcp": "*", "lodash": "^4.17.21", @@ -79,7 +86,7 @@ "memorystore": "^1.6.7", "mime": "^3.0.0", "module-alias": "^2.2.3", - "mongoose": "^8.9.5", + "mongoose": "^8.12.1", "multer": "^1.4.5-lts.1", "nanoid": "^3.3.7", "nodemailer": "^6.9.15", @@ -96,6 +103,7 @@ "passport-jwt": "^4.0.1", "passport-ldapauth": "^3.0.1", "passport-local": "^1.0.0", + "rate-limit-redis": "^4.2.0", "sharp": "^0.32.6", "tiktoken": "^1.0.15", "traverse": "^0.6.7", diff --git a/api/server/controllers/AskController.js b/api/server/controllers/AskController.js index 55fe2fa717..2df6f34ede 100644 --- a/api/server/controllers/AskController.js +++ b/api/server/controllers/AskController.js @@ -150,11 +150,11 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { } catch (error) { const partialText = getText && getText(); handleAbortError(res, req, error, { + sender, partialText, conversationId, - sender, messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId ?? 
parentMessageId, }).catch((err) => { logger.error('[AskController] Error in `handleAbortError`', err); }); diff --git a/api/server/controllers/AuthController.js b/api/server/controllers/AuthController.js index 71551ea867..7cdfaa9aaf 100644 --- a/api/server/controllers/AuthController.js +++ b/api/server/controllers/AuthController.js @@ -61,7 +61,7 @@ const refreshController = async (req, res) => { try { const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET); - const user = await getUserById(payload.id, '-password -__v'); + const user = await getUserById(payload.id, '-password -__v -totpSecret'); if (!user) { return res.status(401).redirect('/login'); } diff --git a/api/server/controllers/EditController.js b/api/server/controllers/EditController.js index 2a2f8c28de..1de9725722 100644 --- a/api/server/controllers/EditController.js +++ b/api/server/controllers/EditController.js @@ -135,11 +135,11 @@ const EditController = async (req, res, next, initializeClient) => { } catch (error) { const partialText = getText(); handleAbortError(res, req, error, { + sender, partialText, conversationId, - sender, messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId, }).catch((err) => { logger.error('[EditController] Error in `handleAbortError`', err); }); diff --git a/api/server/controllers/ModelController.js b/api/server/controllers/ModelController.js index 79dc81d6b0..ad120c2c83 100644 --- a/api/server/controllers/ModelController.js +++ b/api/server/controllers/ModelController.js @@ -1,6 +1,7 @@ const { CacheKeys } = require('librechat-data-provider'); const { loadDefaultModels, loadConfigModels } = require('~/server/services/Config'); const { getLogStores } = require('~/cache'); +const { logger } = require('~/config'); /** * @param {ServerRequest} req @@ -36,8 +37,13 @@ async function loadModels(req) { } async function modelController(req, res) { - const modelConfig = await loadModels(req); - res.send(modelConfig); + try { + const modelConfig = await loadModels(req); + res.send(modelConfig); + } catch (error) { + logger.error('Error fetching models:', error); + res.status(500).send({ error: error.message }); + } } module.exports = { modelController, loadModels, getModelsConfig }; diff --git a/api/server/controllers/TwoFactorController.js b/api/server/controllers/TwoFactorController.js new file mode 100644 index 0000000000..f5783f45ad --- /dev/null +++ b/api/server/controllers/TwoFactorController.js @@ -0,0 +1,138 @@ +const { + generateTOTPSecret, + generateBackupCodes, + verifyTOTP, + verifyBackupCode, + getTOTPSecret, +} = require('~/server/services/twoFactorService'); +const { updateUser, getUserById } = require('~/models'); +const { logger } = require('~/config'); +const { encryptV3 } = require('~/server/utils/crypto'); + +const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, ''); + +/** + * Enable 2FA for the user by generating a new TOTP secret and backup codes. + * The secret is encrypted and stored, and 2FA is marked as disabled until confirmed. + */ +const enable2FA = async (req, res) => { + try { + const userId = req.user.id; + const secret = generateTOTPSecret(); + const { plainCodes, codeObjects } = await generateBackupCodes(); + + // Encrypt the secret with v3 encryption before saving. + const encryptedSecret = encryptV3(secret); + + // Update the user record: store the secret & backup codes and set twoFactorEnabled to false. 
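At this point enable2FA has only staged 2FA: the TOTP secret is generated, encrypted with encryptV3, and persisted with twoFactorEnabled still false, so a user who abandons setup is never locked out. The secret is a standard RFC 6238 (TOTP) shared secret, and the verify/confirm handlers below simply prove the user's authenticator app holds the same secret. A hedged sketch of that round trip, using the third-party otplib package as a stand-in for the project's own twoFactorService:

```js
// Illustration only: LibreChat implements these primitives in its own
// twoFactorService; otplib is an assumed stand-in for the same RFC 6238 flow.
const { authenticator } = require('otplib');

const secret = authenticator.generateSecret(); // base32, like generateTOTPSecret()

// The otpauth:// URI handed to the client (usually rendered as a QR code)
// encodes the issuer, account label, and secret, as in enable2FA below:
const otpauthUrl = authenticator.keyuri('user@example.com', 'LibreChat', secret);

// During verify2FA/confirm2FA, the submitted 6-digit code is recomputed from
// the shared secret and the current 30-second time step, then compared:
const code = authenticator.generate(secret);
console.log(otpauthUrl, authenticator.check(code, secret)); // ... true
```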
+ const user = await updateUser(userId, { + totpSecret: encryptedSecret, + backupCodes: codeObjects, + twoFactorEnabled: false, + }); + + const otpauthUrl = `otpauth://totp/${safeAppTitle}:${user.email}?secret=${secret}&issuer=${safeAppTitle}`; + + return res.status(200).json({ otpauthUrl, backupCodes: plainCodes }); + } catch (err) { + logger.error('[enable2FA]', err); + return res.status(500).json({ message: err.message }); + } +}; + +/** + * Verify a 2FA code (either TOTP or backup code) during setup. + */ +const verify2FA = async (req, res) => { + try { + const userId = req.user.id; + const { token, backupCode } = req.body; + const user = await getUserById(userId); + + if (!user || !user.totpSecret) { + return res.status(400).json({ message: '2FA not initiated' }); + } + + const secret = await getTOTPSecret(user.totpSecret); + let isVerified = false; + + if (token) { + isVerified = await verifyTOTP(secret, token); + } else if (backupCode) { + isVerified = await verifyBackupCode({ user, backupCode }); + } + + if (isVerified) { + return res.status(200).json(); + } + return res.status(400).json({ message: 'Invalid token or backup code.' }); + } catch (err) { + logger.error('[verify2FA]', err); + return res.status(500).json({ message: err.message }); + } +}; + +/** + * Confirm and enable 2FA after a successful verification. + */ +const confirm2FA = async (req, res) => { + try { + const userId = req.user.id; + const { token } = req.body; + const user = await getUserById(userId); + + if (!user || !user.totpSecret) { + return res.status(400).json({ message: '2FA not initiated' }); + } + + const secret = await getTOTPSecret(user.totpSecret); + if (await verifyTOTP(secret, token)) { + await updateUser(userId, { twoFactorEnabled: true }); + return res.status(200).json(); + } + return res.status(400).json({ message: 'Invalid token.' }); + } catch (err) { + logger.error('[confirm2FA]', err); + return res.status(500).json({ message: err.message }); + } +}; + +/** + * Disable 2FA by clearing the stored secret and backup codes. + */ +const disable2FA = async (req, res) => { + try { + const userId = req.user.id; + await updateUser(userId, { totpSecret: null, backupCodes: [], twoFactorEnabled: false }); + return res.status(200).json(); + } catch (err) { + logger.error('[disable2FA]', err); + return res.status(500).json({ message: err.message }); + } +}; + +/** + * Regenerate backup codes for the user. + */ +const regenerateBackupCodes = async (req, res) => { + try { + const userId = req.user.id; + const { plainCodes, codeObjects } = await generateBackupCodes(); + await updateUser(userId, { backupCodes: codeObjects }); + return res.status(200).json({ + backupCodes: plainCodes, + backupCodesHash: codeObjects, + }); + } catch (err) { + logger.error('[regenerateBackupCodes]', err); + return res.status(500).json({ message: err.message }); + } +}; + +module.exports = { + enable2FA, + verify2FA, + confirm2FA, + disable2FA, + regenerateBackupCodes, +}; diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index 17089e8fdc..a331b8daae 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -19,7 +19,9 @@ const { Transaction } = require('~/models/Transaction'); const { logger } = require('~/config'); const getUserController = async (req, res) => { - res.status(200).send(req.user); + const userData = req.user.toObject != null ? 
req.user.toObject() : { ...req.user }; + delete userData.totpSecret; + res.status(200).send(userData); }; const getTermsStatusController = async (req, res) => { diff --git a/api/server/controllers/agents/callbacks.js b/api/server/controllers/agents/callbacks.js index 33fe585f42..6622ec3815 100644 --- a/api/server/controllers/agents/callbacks.js +++ b/api/server/controllers/agents/callbacks.js @@ -1,4 +1,5 @@ -const { Tools, StepTypes, imageGenTools, FileContext } = require('librechat-data-provider'); +const { nanoid } = require('nanoid'); +const { Tools, StepTypes, FileContext } = require('librechat-data-provider'); const { EnvVar, Providers, @@ -9,8 +10,8 @@ const { ChatModelStreamHandler, } = require('@librechat/agents'); const { processCodeOutput } = require('~/server/services/Files/Code/process'); +const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { saveBase64Image } = require('~/server/services/Files/process'); -const { loadAuthValues } = require('~/app/clients/tools/util'); const { logger, sendEvent } = require('~/config'); /** @typedef {import('@librechat/agents').Graph} Graph */ @@ -199,6 +200,22 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU aggregateContent({ event, data }); }, }, + [GraphEvents.ON_REASONING_DELTA]: { + /** + * Handle ON_REASONING_DELTA event. + * @param {string} event - The event name. + * @param {StreamEventData} data - The event data. + * @param {GraphRunnableConfig['configurable']} [metadata] The runnable metadata. + */ + handle: (event, data, metadata) => { + if (metadata?.last_agent_index === metadata?.agent_index) { + sendEvent(res, { event, data }); + } else if (!metadata?.hide_sequential_outputs) { + sendEvent(res, { event, data }); + } + aggregateContent({ event, data }); + }, + }, }; return handlers; @@ -226,32 +243,6 @@ function createToolEndCallback({ req, res, artifactPromises }) { return; } - if (imageGenTools.has(output.name)) { - artifactPromises.push( - (async () => { - const fileMetadata = Object.assign(output.artifact, { - messageId: metadata.run_id, - toolCallId: output.tool_call_id, - conversationId: metadata.thread_id, - }); - if (!res.headersSent) { - return fileMetadata; - } - - if (!fileMetadata) { - return null; - } - - res.write(`event: attachment\ndata: ${JSON.stringify(fileMetadata)}\n\n`); - return fileMetadata; - })().catch((error) => { - logger.error('Error processing code output:', error); - return null; - }), - ); - return; - } - if (output.artifact.content) { /** @type {FormattedContent[]} */ const content = output.artifact.content; @@ -262,7 +253,7 @@ function createToolEndCallback({ req, res, artifactPromises }) { const { url } = part.image_url; artifactPromises.push( (async () => { - const filename = `${output.tool_call_id}-image-${new Date().getTime()}`; + const filename = `${output.name}_${output.tool_call_id}_img_${nanoid()}`; const file = await saveBase64Image(url, { req, filename, diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index a8e9ad82f7..4b995bb06a 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -7,7 +7,16 @@ // validateVisionModel, // mapModelToAzureConfig, // } = require('librechat-data-provider'); -const { Callback, createMetadataAggregator } = require('@librechat/agents'); +require('events').EventEmitter.defaultMaxListeners = 100; +const { + Callback, + GraphEvents, + formatMessage, + formatAgentMessages, + formatContentStrings, + 
getTokenCountForMessage, + createMetadataAggregator, +} = require('@librechat/agents'); const { Constants, VisionModes, @@ -17,36 +26,28 @@ const { KnownEndpoints, anthropicSchema, isAgentsEndpoint, - bedrockOutputParser, + AgentCapabilities, + bedrockInputSchema, removeNullishValues, } = require('librechat-data-provider'); -const { - extractBaseURL, - // constructAzureURL, - // genAzureChatCompletion, -} = require('~/utils'); -const { - formatMessage, - formatAgentMessages, - formatContentStrings, - createContextHandlers, -} = require('~/app/clients/prompts'); -const { encodeAndFormat } = require('~/server/services/Files/images/encode'); +const { getCustomEndpointConfig, checkCapability } = require('~/server/services/Config'); +const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts'); +const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { getBufferString, HumanMessage } = require('@langchain/core/messages'); +const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const Tokenizer = require('~/server/services/Tokenizer'); -const { spendTokens } = require('~/models/spendTokens'); const BaseClient = require('~/app/clients/BaseClient'); +const { logger, sendEvent } = require('~/config'); const { createRun } = require('./run'); -const { logger } = require('~/config'); /** @typedef {import('@librechat/agents').MessageContentComplex} MessageContentComplex */ /** @typedef {import('@langchain/core/runnables').RunnableConfig} RunnableConfig */ const providerParsers = { - [EModelEndpoint.openAI]: openAISchema, - [EModelEndpoint.azureOpenAI]: openAISchema, - [EModelEndpoint.anthropic]: anthropicSchema, - [EModelEndpoint.bedrock]: bedrockOutputParser, + [EModelEndpoint.openAI]: openAISchema.parse, + [EModelEndpoint.azureOpenAI]: openAISchema.parse, + [EModelEndpoint.anthropic]: anthropicSchema.parse, + [EModelEndpoint.bedrock]: bedrockInputSchema.parse, }; const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deepseek]); @@ -102,6 +103,8 @@ class AgentClient extends BaseClient { this.outputTokensKey = 'output_tokens'; /** @type {UsageMetadata} */ this.usage; + /** @type {Record} */ + this.indexTokenCountMap = {}; } /** @@ -191,7 +194,14 @@ class AgentClient extends BaseClient { : {}; if (parseOptions) { - runOptions = parseOptions(this.options.agent.model_parameters); + try { + runOptions = parseOptions(this.options.agent.model_parameters); + } catch (error) { + logger.error( + '[api/server/controllers/agents/client.js #getSaveOptions] Error parsing options', + error, + ); + } } return removeNullishValues( @@ -219,14 +229,23 @@ class AgentClient extends BaseClient { }; } + /** + * + * @param {TMessage} message + * @param {Array} attachments + * @returns {Promise>>} + */ async addImageURLs(message, attachments) { - const { files, image_urls } = await encodeAndFormat( + const { files, text, image_urls } = await encodeAndFormat( this.options.req, attachments, this.options.agent.provider, VisionModes.agents, ); message.image_urls = image_urls.length ? 
image_urls : undefined; + if (text && text.length) { + message.ocr = text; + } return files; } @@ -304,7 +323,21 @@ class AgentClient extends BaseClient { assistantName: this.options?.modelLabel, }); - const needsTokenCount = this.contextStrategy && !orderedMessages[i].tokenCount; + if (message.ocr && i !== orderedMessages.length - 1) { + if (typeof formattedMessage.content === 'string') { + formattedMessage.content = message.ocr + '\n' + formattedMessage.content; + } else { + const textPart = formattedMessage.content.find((part) => part.type === 'text'); + textPart + ? (textPart.text = message.ocr + '\n' + textPart.text) + : formattedMessage.content.unshift({ type: 'text', text: message.ocr }); + } + } else if (message.ocr && i === orderedMessages.length - 1) { + systemContent = [systemContent, message.ocr].join('\n'); + } + + const needsTokenCount = + (this.contextStrategy && !orderedMessages[i].tokenCount) || message.ocr; /* If tokens were never counted, or, is a Vision request and the message has files, count again */ if (needsTokenCount || (this.isVisionModel && (message.image_urls || message.files))) { @@ -350,6 +383,10 @@ class AgentClient extends BaseClient { })); } + for (let i = 0; i < messages.length; i++) { + this.indexTokenCountMap[i] = messages[i].tokenCount; + } + const result = { tokenCountMap, prompt: payload, @@ -384,15 +421,34 @@ class AgentClient extends BaseClient { if (!collectedUsage || !collectedUsage.length) { return; } - const input_tokens = collectedUsage[0]?.input_tokens || 0; + const input_tokens = + (collectedUsage[0]?.input_tokens || 0) + + (Number(collectedUsage[0]?.input_token_details?.cache_creation) || 0) + + (Number(collectedUsage[0]?.input_token_details?.cache_read) || 0); let output_tokens = 0; let previousTokens = input_tokens; // Start with original input for (let i = 0; i < collectedUsage.length; i++) { const usage = collectedUsage[i]; + if (!usage) { + continue; + } + + const cache_creation = Number(usage.input_token_details?.cache_creation) || 0; + const cache_read = Number(usage.input_token_details?.cache_read) || 0; + + const txMetadata = { + context, + conversationId: this.conversationId, + user: this.user ?? this.options.req.user?.id, + endpointTokenConfig: this.options.endpointTokenConfig, + model: usage.model ?? model ?? this.model ?? this.options.agent.model_parameters.model, + }; + if (i > 0) { // Count new tokens generated (input_tokens minus previous accumulated tokens) - output_tokens += (Number(usage.input_tokens) || 0) - previousTokens; + output_tokens += + (Number(usage.input_tokens) || 0) + cache_creation + cache_read - previousTokens; } // Add this message's output tokens @@ -400,16 +456,26 @@ class AgentClient extends BaseClient { // Update previousTokens to include this message's output previousTokens += Number(usage.output_tokens) || 0; - spendTokens( - { - context, - conversationId: this.conversationId, - user: this.user ?? this.options.req.user?.id, - endpointTokenConfig: this.options.endpointTokenConfig, - model: usage.model ?? model ?? this.model ?? 
this.options.agent.model_parameters.model, - }, - { promptTokens: usage.input_tokens, completionTokens: usage.output_tokens }, - ).catch((err) => { + + if (cache_creation > 0 || cache_read > 0) { + spendStructuredTokens(txMetadata, { + promptTokens: { + input: usage.input_tokens, + write: cache_creation, + read: cache_read, + }, + completionTokens: usage.output_tokens, + }).catch((err) => { + logger.error( + '[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending structured tokens', + err, + ); + }); + } + spendTokens(txMetadata, { + promptTokens: usage.input_tokens, + completionTokens: usage.output_tokens, + }).catch((err) => { logger.error( '[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending tokens', err, @@ -477,19 +543,6 @@ class AgentClient extends BaseClient { abortController = new AbortController(); } - const baseURL = extractBaseURL(this.completionsUrl); - logger.debug('[api/server/controllers/agents/client.js] chatCompletion', { - baseURL, - payload, - }); - - // if (this.useOpenRouter) { - // opts.defaultHeaders = { - // 'HTTP-Referer': 'https://librechat.ai', - // 'X-Title': 'LibreChat', - // }; - // } - // if (this.options.headers) { // opts.defaultHeaders = { ...opts.defaultHeaders, ...this.options.headers }; // } @@ -579,6 +632,9 @@ class AgentClient extends BaseClient { // }); // } + /** @type {TCustomConfig['endpoints']['agents']} */ + const agentsEConfig = this.options.req.app.locals[EModelEndpoint.agents]; + /** @type {Partial & { version: 'v1' | 'v2'; run_id?: string; streamMode: string }} */ const config = { configurable: { @@ -586,19 +642,30 @@ class AgentClient extends BaseClient { last_agent_index: this.agentConfigs?.size ?? 0, hide_sequential_outputs: this.options.agent.hide_sequential_outputs, }, - recursionLimit: this.options.req.app.locals[EModelEndpoint.agents]?.recursionLimit, + recursionLimit: agentsEConfig?.recursionLimit, signal: abortController.signal, streamMode: 'values', version: 'v2', }; - const initialMessages = formatAgentMessages(payload); + const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name)); + let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages( + payload, + this.indexTokenCountMap, + toolSet, + ); if (legacyContentEndpoints.has(this.options.agent.endpoint)) { - formatContentStrings(initialMessages); + initialMessages = formatContentStrings(initialMessages); } /** @type {ReturnType} */ let run; + const countTokens = ((text) => this.getTokenCount(text)).bind(this); + + /** @type {(message: BaseMessage) => number} */ + const tokenCounter = (message) => { + return getTokenCountForMessage(message, countTokens); + }; /** * @@ -606,12 +673,23 @@ class AgentClient extends BaseClient { * @param {BaseMessage[]} messages * @param {number} [i] * @param {TMessageContentParts[]} [contentData] + * @param {Record} [currentIndexCountMap] */ - const runAgent = async (agent, messages, i = 0, contentData = []) => { + const runAgent = async (agent, _messages, i = 0, contentData = [], _currentIndexCountMap) => { config.configurable.model = agent.model_parameters.model; + const currentIndexCountMap = _currentIndexCountMap ?? 
indexTokenCountMap; if (i > 0) { this.model = agent.model_parameters.model; } + if (agent.recursion_limit && typeof agent.recursion_limit === 'number') { + config.recursionLimit = agent.recursion_limit; + } + if ( + agentsEConfig?.maxRecursionLimit && + config.recursionLimit > agentsEConfig?.maxRecursionLimit + ) { + config.recursionLimit = agentsEConfig?.maxRecursionLimit; + } config.configurable.agent_id = agent.id; config.configurable.name = agent.name; config.configurable.agent_index = i; @@ -626,7 +704,7 @@ class AgentClient extends BaseClient { let systemContent = [ systemMessage, agent.instructions ?? '', - i !== 0 ? agent.additional_instructions ?? '' : '', + i !== 0 ? (agent.additional_instructions ?? '') : '', ] .join('\n') .trim(); @@ -640,12 +718,21 @@ class AgentClient extends BaseClient { } if (noSystemMessages === true && systemContent?.length) { - let latestMessage = messages.pop().content; + let latestMessage = _messages.pop().content; if (typeof latestMessage !== 'string') { latestMessage = latestMessage[0].text; } latestMessage = [systemContent, latestMessage].join('\n'); - messages.push(new HumanMessage(latestMessage)); + _messages.push(new HumanMessage(latestMessage)); + } + + let messages = _messages; + if ( + agent.model_parameters?.clientOptions?.defaultHeaders?.['anthropic-beta']?.includes( + 'prompt-caching', + ) + ) { + messages = addCacheControl(messages); } run = await createRun({ @@ -665,11 +752,29 @@ class AgentClient extends BaseClient { } if (contentData.length) { + const agentUpdate = { + type: ContentTypes.AGENT_UPDATE, + [ContentTypes.AGENT_UPDATE]: { + index: contentData.length, + runId: this.responseMessageId, + agentId: agent.id, + }, + }; + const streamData = { + event: GraphEvents.ON_AGENT_UPDATE, + data: agentUpdate, + }; + this.options.aggregateContent(streamData); + sendEvent(this.options.res, streamData); + contentData.push(agentUpdate); run.Graph.contentData = contentData; } await run.processStream({ messages }, config, { keepContent: i !== 0, + tokenCounter, + indexTokenCountMap: currentIndexCountMap, + maxContextTokens: agent.maxContextTokens, callbacks: { [Callback.TOOL_ERROR]: (graph, error, toolId) => { logger.error( @@ -683,9 +788,13 @@ class AgentClient extends BaseClient { }; await runAgent(this.options.agent, initialMessages); - let finalContentStart = 0; - if (this.agentConfigs && this.agentConfigs.size > 0) { + if ( + this.agentConfigs && + this.agentConfigs.size > 0 && + (await checkCapability(this.options.req, AgentCapabilities.chain)) + ) { + const windowSize = 5; let latestMessage = initialMessages.pop().content; if (typeof latestMessage !== 'string') { latestMessage = latestMessage[0].text; @@ -693,7 +802,16 @@ class AgentClient extends BaseClient { let i = 1; let runMessages = []; - const lastFiveMessages = initialMessages.slice(-5); + const windowIndexCountMap = {}; + const windowMessages = initialMessages.slice(-windowSize); + let currentIndex = 4; + for (let i = initialMessages.length - 1; i >= 0; i--) { + windowIndexCountMap[currentIndex] = indexTokenCountMap[i]; + currentIndex--; + if (currentIndex < 0) { + break; + } + } for (const [agentId, agent] of this.agentConfigs) { if (abortController.signal.aborted === true) { break; @@ -728,7 +846,9 @@ class AgentClient extends BaseClient { } try { const contextMessages = []; - for (const message of lastFiveMessages) { + const runIndexCountMap = {}; + for (let i = 0; i < windowMessages.length; i++) { + const message = windowMessages[i]; const messageType = message._getType(); if 
( (!agent.tools || agent.tools.length === 0) && @@ -736,11 +856,13 @@ class AgentClient extends BaseClient { ) { continue; } - + runIndexCountMap[contextMessages.length] = windowIndexCountMap[i]; contextMessages.push(message); } - const currentMessages = [...contextMessages, new HumanMessage(bufferString)]; - await runAgent(agent, currentMessages, i, contentData); + const bufferMessage = new HumanMessage(bufferString); + runIndexCountMap[contextMessages.length] = tokenCounter(bufferMessage); + const currentMessages = [...contextMessages, bufferMessage]; + await runAgent(agent, currentMessages, i, contentData, runIndexCountMap); } catch (err) { logger.error( `[api/server/controllers/agents/client.js #chatCompletion] Error running agent ${agentId} (${i})`, @@ -751,6 +873,7 @@ class AgentClient extends BaseClient { } } + /** Note: not implemented */ if (config.configurable.hide_sequential_outputs !== true) { finalContentStart = 0; } @@ -774,18 +897,20 @@ class AgentClient extends BaseClient { ); } } catch (err) { + logger.error( + '[api/server/controllers/agents/client.js #sendCompletion] Operation aborted', + err, + ); if (!abortController.signal.aborted) { logger.error( '[api/server/controllers/agents/client.js #sendCompletion] Unhandled error type', err, ); - throw err; + this.contentParts.push({ + type: ContentTypes.ERROR, + [ContentTypes.ERROR]: `An error occurred while processing the request${err?.message ? `: ${err.message}` : ''}`, + }); } - - logger.warn( - '[api/server/controllers/agents/client.js #sendCompletion] Operation aborted', - err, - ); } } @@ -800,14 +925,20 @@ class AgentClient extends BaseClient { throw new Error('Run not initialized'); } const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator(); - const clientOptions = {}; - const providerConfig = this.options.req.app.locals[this.options.agent.provider]; + /** @type {import('@librechat/agents').ClientOptions} */ + const clientOptions = { + maxTokens: 75, + }; + let endpointConfig = this.options.req.app.locals[this.options.agent.endpoint]; + if (!endpointConfig) { + endpointConfig = await getCustomEndpointConfig(this.options.agent.endpoint); + } if ( - providerConfig && - providerConfig.titleModel && - providerConfig.titleModel !== Constants.CURRENT_MODEL + endpointConfig && + endpointConfig.titleModel && + endpointConfig.titleModel !== Constants.CURRENT_MODEL ) { - clientOptions.model = providerConfig.titleModel; + clientOptions.model = endpointConfig.titleModel; } try { const titleResult = await this.run.generateTitle({ diff --git a/api/server/controllers/agents/request.js b/api/server/controllers/agents/request.js index 288ae8f37f..91277d5bc4 100644 --- a/api/server/controllers/agents/request.js +++ b/api/server/controllers/agents/request.js @@ -142,7 +142,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { conversationId, sender, messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId ?? 
parentMessageId, }).catch((err) => { logger.error('[api/server/controllers/agents/request] Error in `handleAbortError`', err); }); diff --git a/api/server/controllers/agents/run.js b/api/server/controllers/agents/run.js index 0fcc58a379..45fbf56b49 100644 --- a/api/server/controllers/agents/run.js +++ b/api/server/controllers/agents/run.js @@ -1,5 +1,5 @@ const { Run, Providers } = require('@librechat/agents'); -const { providerEndpointMap } = require('librechat-data-provider'); +const { providerEndpointMap, KnownEndpoints } = require('librechat-data-provider'); /** * @typedef {import('@librechat/agents').t} t @@ -7,6 +7,7 @@ const { providerEndpointMap } = require('librechat-data-provider'); * @typedef {import('@librechat/agents').StreamEventData} StreamEventData * @typedef {import('@librechat/agents').EventHandler} EventHandler * @typedef {import('@librechat/agents').GraphEvents} GraphEvents + * @typedef {import('@librechat/agents').LLMConfig} LLMConfig * @typedef {import('@librechat/agents').IState} IState */ @@ -32,6 +33,7 @@ async function createRun({ streamUsage = true, }) { const provider = providerEndpointMap[agent.provider] ?? agent.provider; + /** @type {LLMConfig} */ const llmConfig = Object.assign( { provider, @@ -41,6 +43,14 @@ async function createRun({ agent.model_parameters, ); + /** @type {'reasoning_content' | 'reasoning'} */ + let reasoningKey; + if ( + llmConfig.configuration?.baseURL?.includes(KnownEndpoints.openrouter) || + (agent.endpoint && agent.endpoint.toLowerCase().includes(KnownEndpoints.openrouter)) + ) { + reasoningKey = 'reasoning'; + } if (/o1(?!-(?:mini|preview)).*$/.test(llmConfig.model)) { llmConfig.streaming = false; llmConfig.disableStreaming = true; @@ -50,6 +60,7 @@ async function createRun({ const graphConfig = { signal, llmConfig, + reasoningKey, tools: agent.tools, instructions: agent.instructions, additional_instructions: agent.additional_instructions, @@ -57,7 +68,7 @@ async function createRun({ }; // TEMPORARY FOR TESTING - if (agent.provider === Providers.ANTHROPIC) { + if (agent.provider === Providers.ANTHROPIC || agent.provider === Providers.BEDROCK) { graphConfig.streamBuffer = 2000; } diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js index 08327ec61c..731dee69a2 100644 --- a/api/server/controllers/agents/v1.js +++ b/api/server/controllers/agents/v1.js @@ -1,10 +1,11 @@ const fs = require('fs').promises; const { nanoid } = require('nanoid'); const { - FileContext, - Constants, Tools, + Constants, + FileContext, SystemRoles, + EToolResources, actionDelimiter, } = require('librechat-data-provider'); const { @@ -203,14 +204,21 @@ const duplicateAgentHandler = async (req, res) => { } const { - _id: __id, id: _id, + _id: __id, author: _author, createdAt: _createdAt, updatedAt: _updatedAt, + tool_resources: _tool_resources = {}, ...cloneData } = agent; + if (_tool_resources?.[EToolResources.ocr]) { + cloneData.tool_resources = { + [EToolResources.ocr]: _tool_resources[EToolResources.ocr], + }; + } + const newAgentId = `agent_${nanoid()}`; const newAgentData = Object.assign(cloneData, { id: newAgentId, diff --git a/api/server/controllers/auth/LoginController.js b/api/server/controllers/auth/LoginController.js index 1b543e9baf..226b5605cc 100644 --- a/api/server/controllers/auth/LoginController.js +++ b/api/server/controllers/auth/LoginController.js @@ -1,3 +1,4 @@ +const { generate2FATempToken } = require('~/server/services/twoFactorService'); const { setAuthTokens } = require('~/server/services/AuthService'); 
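The LoginController hunk below makes password login two-phase when twoFactorEnabled is set: instead of minting auth tokens it responds with { twoFAPending: true, tempToken }, and verify2FAWithTempToken (added further down) exchanges that short-lived token plus a TOTP or backup code for real credentials. A minimal client-side sketch of the handshake, with the verification route name as an assumption rather than something taken from this diff:

```js
// Hypothetical client flow; '/api/auth/2fa/verify-temp' is an assumed path.
async function loginWith2FA(email, password, promptForCode) {
  const res = await fetch('/api/auth/login', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ email, password }),
  });
  const data = await res.json();
  if (!data.twoFAPending) {
    return data; // 2FA disabled: normal { token, user } response
  }
  // Password accepted, but a second factor is required:
  const code = await promptForCode();
  const verified = await fetch('/api/auth/2fa/verify-temp', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ tempToken: data.tempToken, token: code }),
  });
  return verified.json(); // { token, user } with sensitive fields stripped
}
```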
const { logger } = require('~/config'); @@ -7,7 +8,12 @@ const loginController = async (req, res) => { return res.status(400).json({ message: 'Invalid credentials' }); } - const { password: _, __v, ...user } = req.user; + if (req.user.twoFactorEnabled) { + const tempToken = generate2FATempToken(req.user._id); + return res.status(200).json({ twoFAPending: true, tempToken }); + } + + const { password: _p, totpSecret: _t, __v, ...user } = req.user; user.id = user._id.toString(); const token = await setAuthTokens(req.user._id, res); diff --git a/api/server/controllers/auth/TwoFactorAuthController.js b/api/server/controllers/auth/TwoFactorAuthController.js new file mode 100644 index 0000000000..15cde8122a --- /dev/null +++ b/api/server/controllers/auth/TwoFactorAuthController.js @@ -0,0 +1,60 @@ +const jwt = require('jsonwebtoken'); +const { + verifyTOTP, + verifyBackupCode, + getTOTPSecret, +} = require('~/server/services/twoFactorService'); +const { setAuthTokens } = require('~/server/services/AuthService'); +const { getUserById } = require('~/models/userMethods'); +const { logger } = require('~/config'); + +/** + * Verifies the 2FA code during login using a temporary token. + */ +const verify2FAWithTempToken = async (req, res) => { + try { + const { tempToken, token, backupCode } = req.body; + if (!tempToken) { + return res.status(400).json({ message: 'Missing temporary token' }); + } + + let payload; + try { + payload = jwt.verify(tempToken, process.env.JWT_SECRET); + } catch (err) { + return res.status(401).json({ message: 'Invalid or expired temporary token' }); + } + + const user = await getUserById(payload.userId); + if (!user || !user.twoFactorEnabled) { + return res.status(400).json({ message: '2FA is not enabled for this user' }); + } + + const secret = await getTOTPSecret(user.totpSecret); + let isVerified = false; + if (token) { + isVerified = await verifyTOTP(secret, token); + } else if (backupCode) { + isVerified = await verifyBackupCode({ user, backupCode }); + } + + if (!isVerified) { + return res.status(401).json({ message: 'Invalid 2FA code or backup code' }); + } + + // Prepare user data to return (omit sensitive fields). + const userData = user.toObject ? 
user.toObject() : { ...user }; + delete userData.password; + delete userData.__v; + delete userData.totpSecret; + userData.id = user._id.toString(); + + const authToken = await setAuthTokens(user._id, res); + return res.status(200).json({ token: authToken, user: userData }); + } catch (err) { + logger.error('[verify2FAWithTempToken]', err); + return res.status(500).json({ message: 'Something went wrong' }); + } +}; + +module.exports = { verify2FAWithTempToken }; diff --git a/api/server/controllers/tools.js b/api/server/controllers/tools.js index 9460e66136..b37b6fcb8c 100644 --- a/api/server/controllers/tools.js +++ b/api/server/controllers/tools.js @@ -1,10 +1,18 @@ const { nanoid } = require('nanoid'); const { EnvVar } = require('@librechat/agents'); -const { Tools, AuthType, ToolCallTypes } = require('librechat-data-provider'); +const { + Tools, + AuthType, + Permissions, + ToolCallTypes, + PermissionTypes, +} = require('librechat-data-provider'); const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process'); const { processCodeOutput } = require('~/server/services/Files/Code/process'); -const { loadAuthValues, loadTools } = require('~/app/clients/tools/util'); const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall'); +const { loadAuthValues } = require('~/server/services/Tools/credentials'); +const { loadTools } = require('~/app/clients/tools/util'); +const { checkAccess } = require('~/server/middleware'); const { getMessage } = require('~/models/Message'); const { logger } = require('~/config'); @@ -12,6 +20,10 @@ const fieldsMap = { [Tools.execute_code]: [EnvVar.CODE_API_KEY], }; +const toolAccessPermType = { + [Tools.execute_code]: PermissionTypes.RUN_CODE, +}; + /** * @param {ServerRequest} req - The request object, containing information about the HTTP request. * @param {ServerResponse} res - The response object, used to send back the desired HTTP response. @@ -58,6 +70,7 @@ const verifyToolAuth = async (req, res) => { /** * @param {ServerRequest} req - The request object, containing information about the HTTP request. * @param {ServerResponse} res - The response object, used to send back the desired HTTP response. + * @param {NextFunction} next - The next middleware function to call. * @returns {Promise} A promise that resolves when the function has completed. */ const callTool = async (req, res) => { @@ -83,6 +96,16 @@ const callTool = async (req, res) => { return; } logger.debug(`[${toolId}/call] User: ${req.user.id}`); + let hasAccess = true; + if (toolAccessPermType[toolId]) { + hasAccess = await checkAccess(req.user, toolAccessPermType[toolId], [Permissions.USE]); + } + if (!hasAccess) { + logger.warn( + `[${toolAccessPermType[toolId]}] Forbidden: Insufficient permissions for User ${req.user.id}: ${Permissions.USE}`, + ); + return res.status(403).json({ message: 'Forbidden: Insufficient permissions' }); + } const { loadedTools } = await loadTools({ user: req.user.id, tools: [toolId], diff --git a/api/server/index.js b/api/server/index.js index 30d36d9a9f..4a428789dd 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -22,10 +22,11 @@ const staticCache = require('./utils/staticCache'); const noIndex = require('./middleware/noIndex'); const routes = require('./routes'); -const { PORT, HOST, ALLOW_SOCIAL_LOGIN, DISABLE_COMPRESSION } = process.env ?? {}; +const { PORT, HOST, ALLOW_SOCIAL_LOGIN, DISABLE_COMPRESSION, TRUST_PROXY } = process.env ?? 
{}; const port = Number(PORT) || 3080; const host = HOST || 'localhost'; +const trusted_proxy = Number(TRUST_PROXY) || 1; /* trust first proxy by default */ const startServer = async () => { if (typeof Bun !== 'undefined') { @@ -53,7 +54,7 @@ const startServer = async () => { app.use(staticCache(app.locals.paths.dist)); app.use(staticCache(app.locals.paths.fonts)); app.use(staticCache(app.locals.paths.assets)); - app.set('trust proxy', 1); /* trust first proxy */ + app.set('trust proxy', trusted_proxy); app.use(cors()); app.use(cookieParser()); @@ -145,6 +146,18 @@ process.on('uncaughtException', (err) => { logger.error('There was an uncaught error:', err); } + if (err.message.includes('abort')) { + logger.warn('There was an uncatchable AbortController error.'); + return; + } + + if (err.message.includes('GoogleGenerativeAI')) { + logger.warn( + '\n\n`GoogleGenerativeAI` errors cannot be caught due to an upstream issue, see: https://github.com/google-gemini/generative-ai-js/issues/303', + ); + return; + } + if (err.message.includes('fetch failed')) { if (messageCount === 0) { logger.warn('Meilisearch error, search will be disabled'); diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index 2137523efe..0053f2bde6 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -120,7 +120,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => { { promptTokens, completionTokens }, ); - saveMessage( + await saveMessage( req, { ...responseMessage, user }, { context: 'api/server/middleware/abortMiddleware.js' }, diff --git a/api/server/middleware/buildEndpointOption.js b/api/server/middleware/buildEndpointOption.js index a0ce754a1c..041864b025 100644 --- a/api/server/middleware/buildEndpointOption.js +++ b/api/server/middleware/buildEndpointOption.js @@ -10,7 +10,6 @@ const openAI = require('~/server/services/Endpoints/openAI'); const agents = require('~/server/services/Endpoints/agents'); const custom = require('~/server/services/Endpoints/custom'); const google = require('~/server/services/Endpoints/google'); -const { getConvoFiles } = require('~/models/Conversation'); const { handleError } = require('~/server/utils'); const buildFunction = { @@ -87,16 +86,8 @@ async function buildEndpointOption(req, res, next) { // TODO: use `getModelsConfig` only when necessary const modelsConfig = await getModelsConfig(req); - const { resendFiles = true } = req.body.endpointOption; req.body.endpointOption.modelsConfig = modelsConfig; - if (isAgents && resendFiles && req.body.conversationId) { - const fileIds = await getConvoFiles(req.body.conversationId); - const requestFiles = req.body.files ?? []; - if (requestFiles.length || fileIds.length) { - req.body.endpointOption.attachments = processFiles(requestFiles, fileIds); - } - } else if (req.body.files) { - // hold the promise + if (req.body.files && !isAgents) { req.body.endpointOption.attachments = processFiles(req.body.files); } next(); diff --git a/api/server/middleware/checkBan.js b/api/server/middleware/checkBan.js index c397ca7d1a..67540bb009 100644 --- a/api/server/middleware/checkBan.js +++ b/api/server/middleware/checkBan.js @@ -41,7 +41,7 @@ const banResponse = async (req, res) => { * @function * @param {Object} req - Express request object. * @param {Object} res - Express response object. - * @param {Function} next - Next middleware function. + * @param {import('express').NextFunction} next - Next middleware function. 
* * @returns {Promise} - Returns a Promise which when resolved calls next middleware if user or source IP is not banned. Otherwise calls `banResponse()` and sets ban details in `banCache`. */ diff --git a/api/server/middleware/concurrentLimiter.js b/api/server/middleware/concurrentLimiter.js index 58ff689a0b..21b3a86903 100644 --- a/api/server/middleware/concurrentLimiter.js +++ b/api/server/middleware/concurrentLimiter.js @@ -21,7 +21,7 @@ const { * @function * @param {Object} req - Express request object containing user information. * @param {Object} res - Express response object. - * @param {function} next - Express next middleware function. + * @param {import('express').NextFunction} next - Next middleware function. * @throws {Error} Throws an error if the user exceeds the concurrent request limit. */ const concurrentLimiter = async (req, res, next) => { diff --git a/api/server/middleware/index.js b/api/server/middleware/index.js index 3da9e06bd6..789ec6a82d 100644 --- a/api/server/middleware/index.js +++ b/api/server/middleware/index.js @@ -14,6 +14,7 @@ const checkInviteUser = require('./checkInviteUser'); const requireJwtAuth = require('./requireJwtAuth'); const validateModel = require('./validateModel'); const moderateText = require('./moderateText'); +const logHeaders = require('./logHeaders'); const setHeaders = require('./setHeaders'); const validate = require('./validate'); const limiters = require('./limiters'); @@ -31,6 +32,7 @@ module.exports = { checkBan, uaParser, setHeaders, + logHeaders, moderateText, validateModel, requireJwtAuth, diff --git a/api/server/middleware/limiters/importLimiters.js b/api/server/middleware/limiters/importLimiters.js index a21fa6453e..5e50046a30 100644 --- a/api/server/middleware/limiters/importLimiters.js +++ b/api/server/middleware/limiters/importLimiters.js @@ -1,6 +1,11 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); +const { RedisStore } = require('rate-limit-redis'); const { ViolationTypes } = require('librechat-data-provider'); const logViolation = require('~/cache/logViolation'); +const { isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); +const { logger } = require('~/config'); const getEnvironmentVariables = () => { const IMPORT_IP_MAX = parseInt(process.env.IMPORT_IP_MAX) || 100; @@ -48,21 +53,39 @@ const createImportLimiters = () => { const { importIpWindowMs, importIpMax, importUserWindowMs, importUserMax } = getEnvironmentVariables(); - const importIpLimiter = rateLimit({ + const ipLimiterOptions = { windowMs: importIpWindowMs, max: importIpMax, handler: createImportHandler(), - }); - - const importUserLimiter = rateLimit({ + }; + const userLimiterOptions = { windowMs: importUserWindowMs, max: importUserMax, handler: createImportHandler(false), keyGenerator: function (req) { return req.user?.id; // Use the user ID or NULL if not available }, - }); + }; + if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for import rate limiters.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const ipStore = new RedisStore({ + sendCommand, + prefix: 'import_ip_limiter:', + }); + const userStore = new RedisStore({ + sendCommand, + prefix: 'import_user_limiter:', + }); + ipLimiterOptions.store = ipStore; + userLimiterOptions.store = userStore; + } + + const importIpLimiter = rateLimit(ipLimiterOptions); + const importUserLimiter = rateLimit(userLimiterOptions); 
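Every limiter converted in this batch follows the same recipe: build the express-rate-limit options as a plain object (per-process, in-memory counting by default), and only when USE_REDIS is enabled attach a RedisStore from rate-limit-redis, so violation counts are shared across horizontally scaled instances and survive restarts. Boiled down to its essentials, assuming a bare ioredis client in place of the app's keyvRedis wrapper:

```js
// Sketch of the shared pattern, assuming a plain ioredis connection.
const Redis = require('ioredis');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');

const client = new Redis(process.env.REDIS_URI);

const limiter = rateLimit({
  windowMs: 60 * 1000, // one-minute window
  max: 100, // allow 100 hits per key per window
  store: new RedisStore({
    // rate-limit-redis only needs a function that issues raw Redis commands:
    sendCommand: (...args) => client.call(...args),
    prefix: 'example_limiter:', // namespaced per limiter, as in the diff
  }),
});

// app.use('/api/', limiter); every node sharing this Redis shares the counts.
```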
return { importIpLimiter, importUserLimiter }; }; diff --git a/api/server/middleware/limiters/loginLimiter.js b/api/server/middleware/limiters/loginLimiter.js index 937723e859..8cf10ccb12 100644 --- a/api/server/middleware/limiters/loginLimiter.js +++ b/api/server/middleware/limiters/loginLimiter.js @@ -1,6 +1,10 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); -const { removePorts } = require('~/server/utils'); +const { RedisStore } = require('rate-limit-redis'); +const { removePorts, isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); const { logViolation } = require('~/cache'); +const { logger } = require('~/config'); const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env; const windowMs = LOGIN_WINDOW * 60 * 1000; @@ -20,11 +24,25 @@ const handler = async (req, res) => { return res.status(429).json({ message }); }; -const loginLimiter = rateLimit({ +const limiterOptions = { windowMs, max, handler, keyGenerator: removePorts, -}); +}; + +if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for login rate limiter.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const store = new RedisStore({ + sendCommand, + prefix: 'login_limiter:', + }); + limiterOptions.store = store; +} + +const loginLimiter = rateLimit(limiterOptions); module.exports = loginLimiter; diff --git a/api/server/middleware/limiters/messageLimiters.js b/api/server/middleware/limiters/messageLimiters.js index c84db1043c..fe4f75a9c6 100644 --- a/api/server/middleware/limiters/messageLimiters.js +++ b/api/server/middleware/limiters/messageLimiters.js @@ -1,6 +1,11 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); +const { RedisStore } = require('rate-limit-redis'); const denyRequest = require('~/server/middleware/denyRequest'); +const { isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); const { logViolation } = require('~/cache'); +const { logger } = require('~/config'); const { MESSAGE_IP_MAX = 40, @@ -41,25 +46,49 @@ const createHandler = (ip = true) => { }; /** - * Message request rate limiter by IP + * Message request rate limiters */ -const messageIpLimiter = rateLimit({ +const ipLimiterOptions = { windowMs: ipWindowMs, max: ipMax, handler: createHandler(), -}); +}; -/** - * Message request rate limiter by userId - */ -const messageUserLimiter = rateLimit({ +const userLimiterOptions = { windowMs: userWindowMs, max: userMax, handler: createHandler(false), keyGenerator: function (req) { return req.user?.id; // Use the user ID or NULL if not available }, -}); +}; + +if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for message rate limiters.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const ipStore = new RedisStore({ + sendCommand, + prefix: 'message_ip_limiter:', + }); + const userStore = new RedisStore({ + sendCommand, + prefix: 'message_user_limiter:', + }); + ipLimiterOptions.store = ipStore; + userLimiterOptions.store = userStore; +} + +/** + * Message request rate limiter by IP + */ +const messageIpLimiter = rateLimit(ipLimiterOptions); + +/** + * Message request rate limiter by userId + */ +const messageUserLimiter = rateLimit(userLimiterOptions); module.exports = { messageIpLimiter, diff --git 
a/api/server/middleware/limiters/registerLimiter.js b/api/server/middleware/limiters/registerLimiter.js index b069798b03..f9bf1215cd 100644 --- a/api/server/middleware/limiters/registerLimiter.js +++ b/api/server/middleware/limiters/registerLimiter.js @@ -1,6 +1,10 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); -const { removePorts } = require('~/server/utils'); +const { RedisStore } = require('rate-limit-redis'); +const { removePorts, isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); const { logViolation } = require('~/cache'); +const { logger } = require('~/config'); const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env; const windowMs = REGISTER_WINDOW * 60 * 1000; @@ -20,11 +24,25 @@ const handler = async (req, res) => { return res.status(429).json({ message }); }; -const registerLimiter = rateLimit({ +const limiterOptions = { windowMs, max, handler, keyGenerator: removePorts, -}); +}; + +if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for register rate limiter.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const store = new RedisStore({ + sendCommand, + prefix: 'register_limiter:', + }); + limiterOptions.store = store; +} + +const registerLimiter = rateLimit(limiterOptions); module.exports = registerLimiter; diff --git a/api/server/middleware/limiters/resetPasswordLimiter.js b/api/server/middleware/limiters/resetPasswordLimiter.js index 5d2deb0282..9f56bd7949 100644 --- a/api/server/middleware/limiters/resetPasswordLimiter.js +++ b/api/server/middleware/limiters/resetPasswordLimiter.js @@ -1,7 +1,11 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); +const { RedisStore } = require('rate-limit-redis'); const { ViolationTypes } = require('librechat-data-provider'); -const { removePorts } = require('~/server/utils'); +const { removePorts, isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); const { logViolation } = require('~/cache'); +const { logger } = require('~/config'); const { RESET_PASSWORD_WINDOW = 2, @@ -25,11 +29,25 @@ const handler = async (req, res) => { return res.status(429).json({ message }); }; -const resetPasswordLimiter = rateLimit({ +const limiterOptions = { windowMs, max, handler, keyGenerator: removePorts, -}); +}; + +if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for reset password rate limiter.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const store = new RedisStore({ + sendCommand, + prefix: 'reset_password_limiter:', + }); + limiterOptions.store = store; +} + +const resetPasswordLimiter = rateLimit(limiterOptions); module.exports = resetPasswordLimiter; diff --git a/api/server/middleware/limiters/sttLimiters.js b/api/server/middleware/limiters/sttLimiters.js index 76f2944f0a..f9304637c4 100644 --- a/api/server/middleware/limiters/sttLimiters.js +++ b/api/server/middleware/limiters/sttLimiters.js @@ -1,6 +1,11 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); +const { RedisStore } = require('rate-limit-redis'); const { ViolationTypes } = require('librechat-data-provider'); const logViolation = require('~/cache/logViolation'); +const { isEnabled } = require('~/server/utils'); +const keyvRedis = 
require('~/cache/keyvRedis'); +const { logger } = require('~/config'); const getEnvironmentVariables = () => { const STT_IP_MAX = parseInt(process.env.STT_IP_MAX) || 100; @@ -47,20 +52,40 @@ const createSTTHandler = (ip = true) => { const createSTTLimiters = () => { const { sttIpWindowMs, sttIpMax, sttUserWindowMs, sttUserMax } = getEnvironmentVariables(); - const sttIpLimiter = rateLimit({ + const ipLimiterOptions = { windowMs: sttIpWindowMs, max: sttIpMax, handler: createSTTHandler(), - }); + }; - const sttUserLimiter = rateLimit({ + const userLimiterOptions = { windowMs: sttUserWindowMs, max: sttUserMax, handler: createSTTHandler(false), keyGenerator: function (req) { return req.user?.id; // Use the user ID or NULL if not available }, - }); + }; + + if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for STT rate limiters.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const ipStore = new RedisStore({ + sendCommand, + prefix: 'stt_ip_limiter:', + }); + const userStore = new RedisStore({ + sendCommand, + prefix: 'stt_user_limiter:', + }); + ipLimiterOptions.store = ipStore; + userLimiterOptions.store = userStore; + } + + const sttIpLimiter = rateLimit(ipLimiterOptions); + const sttUserLimiter = rateLimit(userLimiterOptions); return { sttIpLimiter, sttUserLimiter }; }; diff --git a/api/server/middleware/limiters/toolCallLimiter.js b/api/server/middleware/limiters/toolCallLimiter.js index 47dcaeabb4..7a867b5bcd 100644 --- a/api/server/middleware/limiters/toolCallLimiter.js +++ b/api/server/middleware/limiters/toolCallLimiter.js @@ -1,25 +1,46 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); +const { RedisStore } = require('rate-limit-redis'); const { ViolationTypes } = require('librechat-data-provider'); const logViolation = require('~/cache/logViolation'); +const { isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); +const { logger } = require('~/config'); -const toolCallLimiter = rateLimit({ +const handler = async (req, res) => { + const type = ViolationTypes.TOOL_CALL_LIMIT; + const errorMessage = { + type, + max: 1, + limiter: 'user', + windowInMinutes: 1, + }; + + await logViolation(req, res, type, errorMessage, 0); + res.status(429).json({ message: 'Too many tool call requests. Try again later' }); +}; + +const limiterOptions = { windowMs: 1000, max: 1, - handler: async (req, res) => { - const type = ViolationTypes.TOOL_CALL_LIMIT; - const errorMessage = { - type, - max: 1, - limiter: 'user', - windowInMinutes: 1, - }; - - await logViolation(req, res, type, errorMessage, 0); - res.status(429).json({ message: 'Too many tool call requests. 
Try again later' }); - }, + handler, keyGenerator: function (req) { return req.user?.id; }, -}); +}; + +if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for tool call rate limiter.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const store = new RedisStore({ + sendCommand, + prefix: 'tool_call_limiter:', + }); + limiterOptions.store = store; +} + +const toolCallLimiter = rateLimit(limiterOptions); module.exports = toolCallLimiter; diff --git a/api/server/middleware/limiters/ttsLimiters.js b/api/server/middleware/limiters/ttsLimiters.js index 5619a49b63..e13aaf48c3 100644 --- a/api/server/middleware/limiters/ttsLimiters.js +++ b/api/server/middleware/limiters/ttsLimiters.js @@ -1,6 +1,11 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); +const { RedisStore } = require('rate-limit-redis'); const { ViolationTypes } = require('librechat-data-provider'); const logViolation = require('~/cache/logViolation'); +const { isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); +const { logger } = require('~/config'); const getEnvironmentVariables = () => { const TTS_IP_MAX = parseInt(process.env.TTS_IP_MAX) || 100; @@ -47,20 +52,40 @@ const createTTSHandler = (ip = true) => { const createTTSLimiters = () => { const { ttsIpWindowMs, ttsIpMax, ttsUserWindowMs, ttsUserMax } = getEnvironmentVariables(); - const ttsIpLimiter = rateLimit({ + const ipLimiterOptions = { windowMs: ttsIpWindowMs, max: ttsIpMax, handler: createTTSHandler(), - }); + }; - const ttsUserLimiter = rateLimit({ + const userLimiterOptions = { windowMs: ttsUserWindowMs, max: ttsUserMax, handler: createTTSHandler(false), keyGenerator: function (req) { return req.user?.id; // Use the user ID or NULL if not available }, - }); + }; + + if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for TTS rate limiters.'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + const sendCommand = (...args) => client.call(...args); + const ipStore = new RedisStore({ + sendCommand, + prefix: 'tts_ip_limiter:', + }); + const userStore = new RedisStore({ + sendCommand, + prefix: 'tts_user_limiter:', + }); + ipLimiterOptions.store = ipStore; + userLimiterOptions.store = userStore; + } + + const ttsIpLimiter = rateLimit(ipLimiterOptions); + const ttsUserLimiter = rateLimit(userLimiterOptions); return { ttsIpLimiter, ttsUserLimiter }; }; diff --git a/api/server/middleware/limiters/uploadLimiters.js b/api/server/middleware/limiters/uploadLimiters.js index 71af164fde..9fffface61 100644 --- a/api/server/middleware/limiters/uploadLimiters.js +++ b/api/server/middleware/limiters/uploadLimiters.js @@ -1,6 +1,11 @@ +const Keyv = require('keyv'); const rateLimit = require('express-rate-limit'); +const { RedisStore } = require('rate-limit-redis'); const { ViolationTypes } = require('librechat-data-provider'); const logViolation = require('~/cache/logViolation'); +const { isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); +const { logger } = require('~/config'); const getEnvironmentVariables = () => { const FILE_UPLOAD_IP_MAX = parseInt(process.env.FILE_UPLOAD_IP_MAX) || 100; @@ -52,20 +57,40 @@ const createFileLimiters = () => { const { fileUploadIpWindowMs, fileUploadIpMax, fileUploadUserWindowMs, fileUploadUserMax } = getEnvironmentVariables(); - const fileUploadIpLimiter = rateLimit({ + const 
ipLimiterOptions = {
     windowMs: fileUploadIpWindowMs,
     max: fileUploadIpMax,
     handler: createFileUploadHandler(),
-  });
+  };
 
-  const fileUploadUserLimiter = rateLimit({
+  const userLimiterOptions = {
     windowMs: fileUploadUserWindowMs,
     max: fileUploadUserMax,
     handler: createFileUploadHandler(false),
     keyGenerator: function (req) {
       return req.user?.id; // Use the user ID or NULL if not available
     },
-  });
+  };
+
+  if (isEnabled(process.env.USE_REDIS)) {
+    logger.debug('Using Redis for file upload rate limiters.');
+    const keyv = new Keyv({ store: keyvRedis });
+    const client = keyv.opts.store.redis;
+    const sendCommand = (...args) => client.call(...args);
+    const ipStore = new RedisStore({
+      sendCommand,
+      prefix: 'file_upload_ip_limiter:',
+    });
+    const userStore = new RedisStore({
+      sendCommand,
+      prefix: 'file_upload_user_limiter:',
+    });
+    ipLimiterOptions.store = ipStore;
+    userLimiterOptions.store = userStore;
+  }
+
+  const fileUploadIpLimiter = rateLimit(ipLimiterOptions);
+  const fileUploadUserLimiter = rateLimit(userLimiterOptions);
 
   return { fileUploadIpLimiter, fileUploadUserLimiter };
 };
diff --git a/api/server/middleware/limiters/verifyEmailLimiter.js b/api/server/middleware/limiters/verifyEmailLimiter.js
index 770090dba5..0b245afbd1 100644
--- a/api/server/middleware/limiters/verifyEmailLimiter.js
+++ b/api/server/middleware/limiters/verifyEmailLimiter.js
@@ -1,7 +1,11 @@
+const Keyv = require('keyv');
 const rateLimit = require('express-rate-limit');
+const { RedisStore } = require('rate-limit-redis');
 const { ViolationTypes } = require('librechat-data-provider');
-const { removePorts } = require('~/server/utils');
+const { removePorts, isEnabled } = require('~/server/utils');
+const keyvRedis = require('~/cache/keyvRedis');
 const { logViolation } = require('~/cache');
+const { logger } = require('~/config');
 
 const {
   VERIFY_EMAIL_WINDOW = 2,
@@ -25,11 +29,25 @@ const handler = async (req, res) => {
   return res.status(429).json({ message });
 };
 
-const verifyEmailLimiter = rateLimit({
+const limiterOptions = {
   windowMs,
   max,
   handler,
   keyGenerator: removePorts,
-});
+};
+
+if (isEnabled(process.env.USE_REDIS)) {
+  logger.debug('Using Redis for verify email rate limiter.');
+  const keyv = new Keyv({ store: keyvRedis });
+  const client = keyv.opts.store.redis;
+  const sendCommand = (...args) => client.call(...args);
+  const store = new RedisStore({
+    sendCommand,
+    prefix: 'verify_email_limiter:',
+  });
+  limiterOptions.store = store;
+}
+
+const verifyEmailLimiter = rateLimit(limiterOptions);
 
 module.exports = verifyEmailLimiter;
diff --git a/api/server/middleware/logHeaders.js b/api/server/middleware/logHeaders.js
new file mode 100644
index 0000000000..26ca04da38
--- /dev/null
+++ b/api/server/middleware/logHeaders.js
@@ -0,0 +1,32 @@
+const { logger } = require('~/config');
+
+/**
+ * Middleware to log X-Forwarded headers on incoming requests
+ * @function
+ * @param {ServerRequest} req - Express request object.
+ * @param {ServerResponse} res - Express response object.
+ * @param {import('express').NextFunction} next - Next middleware function.
+ */
+const logHeaders = (req, res, next) => {
+  try {
+    const forwardedHeaders = {};
+    if (req.headers['x-forwarded-for']) {
+      forwardedHeaders['x-forwarded-for'] = req.headers['x-forwarded-for'];
+    }
+    if (req.headers['x-forwarded-host']) {
+      forwardedHeaders['x-forwarded-host'] = req.headers['x-forwarded-host'];
+    }
+    if (req.headers['x-forwarded-proto']) {
+      forwardedHeaders['x-forwarded-proto'] = req.headers['x-forwarded-proto'];
+    }
+    if (Object.keys(forwardedHeaders).length > 0) {
+      logger.debug('X-Forwarded headers detected in OAuth request:', forwardedHeaders);
+    }
+  } catch (error) {
+    logger.error('Error logging X-Forwarded headers:', error);
+  }
+  next();
+};
+
+module.exports = logHeaders;
diff --git a/api/server/middleware/requireLocalAuth.js b/api/server/middleware/requireLocalAuth.js
index 8319baf345..a71bd6c5b0 100644
--- a/api/server/middleware/requireLocalAuth.js
+++ b/api/server/middleware/requireLocalAuth.js
@@ -1,32 +1,18 @@
 const passport = require('passport');
-const DebugControl = require('../../utils/debug.js');
-
-function log({ title, parameters }) {
-  DebugControl.log.functionName(title);
-  if (parameters) {
-    DebugControl.log.parameters(parameters);
-  }
-}
+const { logger } = require('~/config');
 
 const requireLocalAuth = (req, res, next) => {
   passport.authenticate('local', (err, user, info) => {
     if (err) {
-      log({
-        title: '(requireLocalAuth) Error at passport.authenticate',
-        parameters: [{ name: 'error', value: err }],
-      });
+      logger.error('[requireLocalAuth] Error at passport.authenticate:', err);
       return next(err);
     }
     if (!user) {
-      log({
-        title: '(requireLocalAuth) Error: No user',
-      });
+      logger.debug('[requireLocalAuth] Error: No user');
       return res.status(404).send(info);
     }
     if (info && info.message) {
-      log({
-        title: '(requireLocalAuth) Error: ' + info.message,
-      });
+      logger.debug('[requireLocalAuth] Error: ' + info.message);
       return res.status(422).send({ message: info.message });
     }
     req.user = user;
diff --git a/api/server/middleware/roles/generateCheckAccess.js b/api/server/middleware/roles/generateCheckAccess.js
index ffc0ddc613..0f137c3c84 100644
--- a/api/server/middleware/roles/generateCheckAccess.js
+++ b/api/server/middleware/roles/generateCheckAccess.js
@@ -1,4 +1,42 @@
 const { getRoleByName } = require('~/models/Role');
+const { logger } = require('~/config');
+
+/**
+ * Core function to check if a user has one or more required permissions
+ *
+ * @param {object} user - The user object
+ * @param {PermissionTypes} permissionType - The type of permission to check
+ * @param {Permissions[]} permissions - The list of specific permissions to check
+ * @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of properties to check
+ * @param {object} [checkObject] - The object to check properties against
+ * @returns {Promise<boolean>} Whether the user has the required permissions
+ */
+const checkAccess = async (user, permissionType, permissions, bodyProps = {}, checkObject = {}) => {
+  if (!user) {
+    return false;
+  }
+
+  const role = await getRoleByName(user.role);
+  if (role && role[permissionType]) {
+    const hasAnyPermission = permissions.some((permission) => {
+      if (role[permissionType][permission]) {
+        return true;
+      }
+
+      if (bodyProps[permission] && checkObject) {
+        return bodyProps[permission].some((prop) =>
+          Object.prototype.hasOwnProperty.call(checkObject, prop),
+        );
+      }
+
+      return false;
+    });
+
+    return hasAnyPermission;
+  }
+
+  return false;
+};
 
 /**
  * Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties.
@@ -6,42 +44,35 @@ const { getRoleByName } = require('~/models/Role');
  * @param {PermissionTypes} permissionType - The type of permission to check.
  * @param {Permissions[]} permissions - The list of specific permissions to check.
  * @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check.
- * @returns {Function} Express middleware function.
+ * @returns {(req: ServerRequest, res: ServerResponse, next: NextFunction) => Promise<void>} Express middleware function.
  */
 const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => {
   return async (req, res, next) => {
     try {
-      const { user } = req;
-      if (!user) {
-        return res.status(401).json({ message: 'Authorization required' });
-      }
-
-      const role = await getRoleByName(user.role);
-      if (role && role[permissionType]) {
-        const hasAnyPermission = permissions.some((permission) => {
-          if (role[permissionType][permission]) {
-            return true;
-          }
-
-          if (bodyProps[permission] && req.body) {
-            return bodyProps[permission].some((prop) =>
-              Object.prototype.hasOwnProperty.call(req.body, prop),
-            );
-          }
-
-          return false;
-        });
-
-        if (hasAnyPermission) {
-          return next();
-        }
+      const hasAccess = await checkAccess(
+        req.user,
+        permissionType,
+        permissions,
+        bodyProps,
+        req.body,
+      );
+
+      if (hasAccess) {
+        return next();
       }
 
+      logger.warn(
+        `[${permissionType}] Forbidden: Insufficient permissions for User ${req.user?.id}: ${permissions.join(', ')}`,
+      );
       return res.status(403).json({ message: 'Forbidden: Insufficient permissions' });
     } catch (error) {
+      logger.error(error);
       return res.status(500).json({ message: `Server error: ${error.message}` });
     }
   };
 };
 
-module.exports = generateCheckAccess;
+module.exports = {
+  checkAccess,
+  generateCheckAccess,
+};
diff --git a/api/server/middleware/roles/index.js b/api/server/middleware/roles/index.js
index 999c36481e..a9fc5b2a08 100644
--- a/api/server/middleware/roles/index.js
+++ b/api/server/middleware/roles/index.js
@@ -1,7 +1,8 @@
 const checkAdmin = require('./checkAdmin');
-const generateCheckAccess = require('./generateCheckAccess');
+const { checkAccess, generateCheckAccess } = require('./generateCheckAccess');
 
 module.exports = {
   checkAdmin,
+  checkAccess,
   generateCheckAccess,
 };
diff --git a/api/server/routes/__tests__/config.spec.js b/api/server/routes/__tests__/config.spec.js
index 13af53f299..0bb80bb9ee 100644
--- a/api/server/routes/__tests__/config.spec.js
+++ b/api/server/routes/__tests__/config.spec.js
@@ -18,6 +18,7 @@ afterEach(() => {
   delete process.env.OPENID_ISSUER;
   delete process.env.OPENID_SESSION_SECRET;
   delete process.env.OPENID_BUTTON_LABEL;
+  delete process.env.OPENID_AUTO_REDIRECT;
   delete process.env.OPENID_AUTH_URL;
   delete process.env.GITHUB_CLIENT_ID;
   delete process.env.GITHUB_CLIENT_SECRET;
diff --git a/api/server/routes/auth.js b/api/server/routes/auth.js
index 3e86ffd868..2d9fae7ae7 100644
--- a/api/server/routes/auth.js
+++ b/api/server/routes/auth.js
@@ -7,8 +7,17 @@ const {
 } = require('~/server/controllers/AuthController');
 const { loginController } = require('~/server/controllers/auth/LoginController');
 const { logoutController } = require('~/server/controllers/auth/LogoutController');
+const { verify2FAWithTempToken } = require('~/server/controllers/auth/TwoFactorAuthController');
+const {
+  enable2FA,
+  verify2FA,
+  disable2FA,
+  regenerateBackupCodes,
+  confirm2FA,
+} = require('~/server/controllers/TwoFactorController');
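// Editor's note (illustrative sketch, not part of the patch): exporting `checkAccess`
// alongside the middleware factory above means permission checks can now run outside
// an Express handler, e.g. from a service layer. `PermissionTypes` and `Permissions`
// come from librechat-data-provider as elsewhere in the codebase; the wrapper
// function below is hypothetical.
//
//   const { PermissionTypes, Permissions } = require('librechat-data-provider');
//   const { checkAccess } = require('~/server/middleware/roles');
//
//   // Resolves false for a missing user or role, mirroring the middleware's 403 path.
//   const canSharePromptsGlobally = async (user) =>
//     checkAccess(user, PermissionTypes.PROMPTS, [Permissions.SHARED_GLOBAL]);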
const { checkBan, + logHeaders, loginLimiter, requireJwtAuth, checkInviteUser, @@ -27,6 +36,7 @@ const ldapAuth = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE; router.post('/logout', requireJwtAuth, logoutController); router.post( '/login', + logHeaders, loginLimiter, checkBan, ldapAuth ? requireLdapAuth : requireLocalAuth, @@ -50,4 +60,11 @@ router.post( ); router.post('/resetPassword', checkBan, validatePasswordReset, resetPasswordController); +router.get('/2fa/enable', requireJwtAuth, enable2FA); +router.post('/2fa/verify', requireJwtAuth, verify2FA); +router.post('/2fa/verify-temp', checkBan, verify2FAWithTempToken); +router.post('/2fa/confirm', requireJwtAuth, confirm2FA); +router.post('/2fa/disable', requireJwtAuth, disable2FA); +router.post('/2fa/backup/regenerate', requireJwtAuth, regenerateBackupCodes); + module.exports = router; diff --git a/api/server/routes/config.js b/api/server/routes/config.js index 2dbcca6d3b..4c16a651c7 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -47,16 +47,17 @@ router.get('/', async function (req, res) { githubLoginEnabled: !!process.env.GITHUB_CLIENT_ID && !!process.env.GITHUB_CLIENT_SECRET, googleLoginEnabled: !!process.env.GOOGLE_CLIENT_ID && !!process.env.GOOGLE_CLIENT_SECRET, appleLoginEnabled: - !!process.env.APPLE_CLIENT_ID && - !!process.env.APPLE_TEAM_ID && - !!process.env.APPLE_KEY_ID && - !!process.env.APPLE_PRIVATE_KEY_PATH, + !!process.env.APPLE_CLIENT_ID && + !!process.env.APPLE_TEAM_ID && + !!process.env.APPLE_KEY_ID && + !!process.env.APPLE_PRIVATE_KEY_PATH, openidLoginEnabled: !!process.env.OPENID_ENABLED && !!process.env.OPENID_SESSION_SECRET, openidMultiTenantEnabled: !!process.env.OPENID_MULTI_TENANT, openidLabel: process.env.OPENID_BUTTON_LABEL || 'Continue with OpenID', openidImageUrl: process.env.OPENID_IMAGE_URL, + openidAutoRedirect: isEnabled(process.env.OPENID_AUTO_REDIRECT), serverDomain: process.env.DOMAIN_SERVER || 'http://localhost:3080', emailLoginEnabled, registrationEnabled: !ldap?.enabled && isEnabled(process.env.ALLOW_REGISTRATION), @@ -79,6 +80,7 @@ router.get('/', async function (req, res) { publicSharedLinksEnabled, analyticsGtmId: process.env.ANALYTICS_GTM_ID, instanceProjectId: instanceProject._id.toString(), + bundlerURL: process.env.SANDPACK_BUNDLER_URL, }; if (ldap) { diff --git a/api/server/routes/files/files.js b/api/server/routes/files/files.js index c320f7705b..c371b8e28e 100644 --- a/api/server/routes/files/files.js +++ b/api/server/routes/files/files.js @@ -16,7 +16,7 @@ const { } = require('~/server/services/Files/process'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { getOpenAIClient } = require('~/server/controllers/assistants/helpers'); -const { loadAuthValues } = require('~/app/clients/tools/util'); +const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { getAgent } = require('~/models/Agent'); const { getFiles } = require('~/models/File'); const { logger } = require('~/config'); diff --git a/api/server/routes/oauth.js b/api/server/routes/oauth.js index 1ca49e4ebc..fbb8ec395b 100644 --- a/api/server/routes/oauth.js +++ b/api/server/routes/oauth.js @@ -1,7 +1,7 @@ // file deepcode ignore NoRateLimitingForLogin: Rate limiting is handled by the `loginLimiter` middleware const express = require('express'); const passport = require('passport'); -const { loginLimiter, checkBan, checkDomainAllowed } = require('~/server/middleware'); +const { loginLimiter, logHeaders, checkBan, 
checkDomainAllowed } = require('~/server/middleware');
 const { setAuthTokens } = require('~/server/services/AuthService');
 const { logger } = require('~/config');
 const { chooseOpenIdStrategy } = require('~/server/utils/openidHelper');
@@ -13,6 +13,7 @@ const domains = {
   server: process.env.DOMAIN_SERVER,
 };
 
+router.use(logHeaders);
 router.use(loginLimiter);
 
 const oauthHandler = async (req, res) => {
@@ -31,8 +32,10 @@ const oauthHandler = async (req, res) => {
 
 router.get('/error', (req, res) => {
   // A single error message is pushed by passport when authentication fails.
   logger.error('Error in OAuth authentication:', { message: req.session?.messages?.pop() });
-  res.redirect(`${domains.client}/login`);
+
+  // Redirect to the login page with redirect=false to prevent infinite redirect loops
+  res.redirect(`${domains.client}/login?redirect=false`);
 });
 
 /**
diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js
index 660e7aeb0d..c332cdfcf1 100644
--- a/api/server/services/ActionService.js
+++ b/api/server/services/ActionService.js
@@ -161,9 +161,9 @@ async function createActionTool({
   if (metadata.auth && metadata.auth.type !== AuthTypeEnum.None) {
     try {
-      const action_id = action.action_id;
-      const identifier = `${req.user.id}:${action.action_id}`;
       if (metadata.auth.type === AuthTypeEnum.OAuth && metadata.auth.authorization_url) {
+        const action_id = action.action_id;
+        const identifier = `${req.user.id}:${action.action_id}`;
         const requestLogin = async () => {
           const { args: _args, stepId, ...toolCall } = config.toolCall ?? {};
           if (!stepId) {
diff --git a/api/server/services/AppService.js b/api/server/services/AppService.js
index d194d31a6b..3fdae6ac10 100644
--- a/api/server/services/AppService.js
+++ b/api/server/services/AppService.js
@@ -1,7 +1,15 @@
-const { FileSources, EModelEndpoint, getConfigDefaults } = require('librechat-data-provider');
+const {
+  FileSources,
+  EModelEndpoint,
+  loadOCRConfig,
+  processMCPEnv,
+  getConfigDefaults,
+} = require('librechat-data-provider');
 const { checkVariables, checkHealth, checkConfig, checkAzureVariables } = require('./start/checks');
 const { azureAssistantsDefaults, assistantsConfigSetup } = require('./start/assistants');
+const { initializeAzureBlobService } = require('./Files/Azure/initialize');
 const { initializeFirebase } = require('./Files/Firebase/initialize');
+const { initializeS3 } = require('./Files/S3/initialize');
 const loadCustomConfig = require('./Config/loadCustomConfig');
 const handleRateLimits = require('./Config/handleRateLimits');
 const { loadDefaultInterface } = require('./start/interface');
@@ -25,6 +33,7 @@ const AppService = async (app) => {
   const config = (await loadCustomConfig()) ?? {};
   const configDefaults = getConfigDefaults();
 
+  const ocr = loadOCRConfig(config.ocr);
   const filteredTools = config.filteredTools;
   const includedTools = config.includedTools;
   const fileStrategy = config.fileStrategy ??
configDefaults.fileStrategy; @@ -37,6 +46,10 @@ const AppService = async (app) => { if (fileStrategy === FileSources.firebase) { initializeFirebase(); + } else if (fileStrategy === FileSources.azure) { + initializeAzureBlobService(); + } else if (fileStrategy === FileSources.s3) { + initializeS3(); } /** @type {Record { if (config.mcpServers != null) { const mcpManager = await getMCPManager(); - await mcpManager.initializeMCP(config.mcpServers); + await mcpManager.initializeMCP(config.mcpServers, processMCPEnv); await mcpManager.mapAvailableTools(availableTools); } @@ -57,6 +70,7 @@ const AppService = async (app) => { const interfaceConfig = await loadDefaultInterface(config, configDefaults); const defaultLocals = { + ocr, paths, fileStrategy, socialLogins, diff --git a/api/server/services/AppService.spec.js b/api/server/services/AppService.spec.js index 61ac80fc6c..e47bfe7d5d 100644 --- a/api/server/services/AppService.spec.js +++ b/api/server/services/AppService.spec.js @@ -120,6 +120,7 @@ describe('AppService', () => { }, }, paths: expect.anything(), + ocr: expect.anything(), imageOutputType: expect.any(String), fileConfig: undefined, secureImageLinks: undefined, @@ -588,4 +589,33 @@ describe('AppService updating app.locals and issuing warnings', () => { ); }); }); + + it('should not parse environment variable references in OCR config', async () => { + // Mock custom configuration with env variable references in OCR config + const mockConfig = { + ocr: { + apiKey: '${OCR_API_KEY_CUSTOM_VAR_NAME}', + baseURL: '${OCR_BASEURL_CUSTOM_VAR_NAME}', + strategy: 'mistral_ocr', + mistralModel: 'mistral-medium', + }, + }; + + require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve(mockConfig)); + + // Set actual environment variables with different values + process.env.OCR_API_KEY_CUSTOM_VAR_NAME = 'actual-api-key'; + process.env.OCR_BASEURL_CUSTOM_VAR_NAME = 'https://actual-ocr-url.com'; + + // Initialize app + const app = { locals: {} }; + await AppService(app); + + // Verify that the raw string references were preserved and not interpolated + expect(app.locals.ocr).toBeDefined(); + expect(app.locals.ocr.apiKey).toEqual('${OCR_API_KEY_CUSTOM_VAR_NAME}'); + expect(app.locals.ocr.baseURL).toEqual('${OCR_BASEURL_CUSTOM_VAR_NAME}'); + expect(app.locals.ocr.strategy).toEqual('mistral_ocr'); + expect(app.locals.ocr.mistralModel).toEqual('mistral-medium'); + }); }); diff --git a/api/server/services/Config/getEndpointsConfig.js b/api/server/services/Config/getEndpointsConfig.js index 4f8bde68ad..016f5f7445 100644 --- a/api/server/services/Config/getEndpointsConfig.js +++ b/api/server/services/Config/getEndpointsConfig.js @@ -72,4 +72,15 @@ async function getEndpointsConfig(req) { return endpointsConfig; } -module.exports = { getEndpointsConfig }; +/** + * @param {ServerRequest} req + * @param {import('librechat-data-provider').AgentCapabilities} capability + * @returns {Promise} + */ +const checkCapability = async (req, capability) => { + const endpointsConfig = await getEndpointsConfig(req); + const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? 
[]; + return capabilities.includes(capability); +}; + +module.exports = { getEndpointsConfig, checkCapability }; diff --git a/api/server/services/Config/loadConfigModels.js b/api/server/services/Config/loadConfigModels.js index 0df811468b..fc255b8c47 100644 --- a/api/server/services/Config/loadConfigModels.js +++ b/api/server/services/Config/loadConfigModels.js @@ -47,7 +47,7 @@ async function loadConfigModels(req) { ); /** - * @type {Record} + * @type {Record>} * Map for promises keyed by unique combination of baseURL and apiKey */ const fetchPromisesMap = {}; /** @@ -102,7 +102,7 @@ async function loadConfigModels(req) { for (const name of associatedNames) { const endpoint = endpointsMap[name]; - modelsConfig[name] = !modelData?.length ? endpoint.models.default ?? [] : modelData; + modelsConfig[name] = !modelData?.length ? (endpoint.models.default ?? []) : modelData; } } diff --git a/api/server/services/Config/loadDefaultModels.js b/api/server/services/Config/loadDefaultModels.js index 82db356841..db91c4101b 100644 --- a/api/server/services/Config/loadDefaultModels.js +++ b/api/server/services/Config/loadDefaultModels.js @@ -5,8 +5,8 @@ const { getGoogleModels, getBedrockModels, getAnthropicModels, - getChatGPTBrowserModels, } = require('~/server/services/ModelService'); +const { logger } = require('~/config'); /** * Loads the default models for the application. @@ -15,31 +15,68 @@ const { * @param {Express.Request} req - The Express request object. */ async function loadDefaultModels(req) { - const google = getGoogleModels(); - const openAI = await getOpenAIModels({ user: req.user.id }); - const anthropic = getAnthropicModels(); - const chatGPTBrowser = getChatGPTBrowserModels(); - const azureOpenAI = await getOpenAIModels({ user: req.user.id, azure: true }); - const gptPlugins = await getOpenAIModels({ - user: req.user.id, - azure: useAzurePlugins, - plugins: true, - }); - const assistants = await getOpenAIModels({ assistants: true }); - const azureAssistants = await getOpenAIModels({ azureAssistants: true }); + try { + const [ + openAI, + anthropic, + azureOpenAI, + gptPlugins, + assistants, + azureAssistants, + google, + bedrock, + ] = await Promise.all([ + getOpenAIModels({ user: req.user.id }).catch((error) => { + logger.error('Error fetching OpenAI models:', error); + return []; + }), + getAnthropicModels({ user: req.user.id }).catch((error) => { + logger.error('Error fetching Anthropic models:', error); + return []; + }), + getOpenAIModels({ user: req.user.id, azure: true }).catch((error) => { + logger.error('Error fetching Azure OpenAI models:', error); + return []; + }), + getOpenAIModels({ user: req.user.id, azure: useAzurePlugins, plugins: true }).catch( + (error) => { + logger.error('Error fetching Plugin models:', error); + return []; + }, + ), + getOpenAIModels({ assistants: true }).catch((error) => { + logger.error('Error fetching OpenAI Assistants API models:', error); + return []; + }), + getOpenAIModels({ azureAssistants: true }).catch((error) => { + logger.error('Error fetching Azure OpenAI Assistants API models:', error); + return []; + }), + Promise.resolve(getGoogleModels()).catch((error) => { + logger.error('Error getting Google models:', error); + return []; + }), + Promise.resolve(getBedrockModels()).catch((error) => { + logger.error('Error getting Bedrock models:', error); + return []; + }), + ]); - return { - [EModelEndpoint.openAI]: openAI, - [EModelEndpoint.agents]: openAI, - [EModelEndpoint.google]: google, - [EModelEndpoint.anthropic]: anthropic, - 
[EModelEndpoint.gptPlugins]: gptPlugins, - [EModelEndpoint.azureOpenAI]: azureOpenAI, - [EModelEndpoint.chatGPTBrowser]: chatGPTBrowser, - [EModelEndpoint.assistants]: assistants, - [EModelEndpoint.azureAssistants]: azureAssistants, - [EModelEndpoint.bedrock]: getBedrockModels(), - }; + return { + [EModelEndpoint.openAI]: openAI, + [EModelEndpoint.agents]: openAI, + [EModelEndpoint.google]: google, + [EModelEndpoint.anthropic]: anthropic, + [EModelEndpoint.gptPlugins]: gptPlugins, + [EModelEndpoint.azureOpenAI]: azureOpenAI, + [EModelEndpoint.assistants]: assistants, + [EModelEndpoint.azureAssistants]: azureAssistants, + [EModelEndpoint.bedrock]: bedrock, + }; + } catch (error) { + logger.error('Error fetching default models:', error); + throw new Error(`Failed to load default models: ${error.message}`); + } } module.exports = loadDefaultModels; diff --git a/api/server/services/Endpoints/agents/build.js b/api/server/services/Endpoints/agents/build.js index 027937e7fd..999cdc16be 100644 --- a/api/server/services/Endpoints/agents/build.js +++ b/api/server/services/Endpoints/agents/build.js @@ -2,15 +2,8 @@ const { loadAgent } = require('~/models/Agent'); const { logger } = require('~/config'); const buildOptions = (req, endpoint, parsedBody) => { - const { - spec, - iconURL, - agent_id, - instructions, - maxContextTokens, - resendFiles = true, - ...model_parameters - } = parsedBody; + const { spec, iconURL, agent_id, instructions, maxContextTokens, ...model_parameters } = + parsedBody; const agentPromise = loadAgent({ req, agent_id, @@ -24,7 +17,6 @@ const buildOptions = (req, endpoint, parsedBody) => { iconURL, endpoint, agent_id, - resendFiles, instructions, maxContextTokens, model_parameters, diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index 3e03a45125..04cd20d072 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -2,6 +2,7 @@ const { createContentAggregator, Providers } = require('@librechat/agents'); const { EModelEndpoint, getResponseSender, + AgentCapabilities, providerEndpointMap, } = require('librechat-data-provider'); const { @@ -15,36 +16,61 @@ const initCustom = require('~/server/services/Endpoints/custom/initialize'); const initGoogle = require('~/server/services/Endpoints/google/initialize'); const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); const { getCustomEndpointConfig } = require('~/server/services/Config'); +const { processFiles } = require('~/server/services/Files/process'); const { loadAgentTools } = require('~/server/services/ToolService'); const AgentClient = require('~/server/controllers/agents/client'); +const { getConvoFiles } = require('~/models/Conversation'); +const { getToolFilesByIds } = require('~/models/File'); const { getModelMaxTokens } = require('~/utils'); const { getAgent } = require('~/models/Agent'); +const { getFiles } = require('~/models/File'); const { logger } = require('~/config'); const providerConfigMap = { + [Providers.XAI]: initCustom, + [Providers.OLLAMA]: initCustom, + [Providers.DEEPSEEK]: initCustom, + [Providers.OPENROUTER]: initCustom, [EModelEndpoint.openAI]: initOpenAI, + [EModelEndpoint.google]: initGoogle, [EModelEndpoint.azureOpenAI]: initOpenAI, [EModelEndpoint.anthropic]: initAnthropic, [EModelEndpoint.bedrock]: getBedrockOptions, - [EModelEndpoint.google]: initGoogle, - [Providers.OLLAMA]: initCustom, }; /** - * + * @param {ServerRequest} req * @param 
{Promise<Array<MongoFile>> | undefined} _attachments
 * @param {AgentToolResources | undefined} _tool_resources
 * @returns {Promise<{ attachments: Array<MongoFile> | undefined, tool_resources: AgentToolResources | undefined }>}
 */
-const primeResources = async (_attachments, _tool_resources) => {
+const primeResources = async (req, _attachments, _tool_resources) => {
   try {
+    /** @type {Array<MongoFile> | undefined} */
+    let attachments;
+    const tool_resources = _tool_resources ?? {};
+    const isOCREnabled = (req.app.locals?.[EModelEndpoint.agents]?.capabilities ?? []).includes(
+      AgentCapabilities.ocr,
+    );
+    if (tool_resources.ocr?.file_ids && isOCREnabled) {
+      const context = await getFiles(
+        {
+          file_id: { $in: tool_resources.ocr.file_ids },
+        },
+        {},
+        {},
+      );
+      attachments = (attachments ?? []).concat(context);
+    }
     if (!_attachments) {
-      return { attachments: undefined, tool_resources: _tool_resources };
+      return { attachments, tool_resources };
     }
     /** @type {Array<MongoFile> | undefined} */
     const files = await _attachments;
-    const attachments = [];
-    const tool_resources = _tool_resources ?? {};
+    if (!attachments) {
+      /** @type {Array<MongoFile>} */
+      attachments = [];
+    }
 
     for (const file of files) {
       if (!file) {
@@ -79,7 +105,6 @@
  * @param {ServerResponse} params.res
  * @param {Agent} params.agent
  * @param {object} [params.endpointOption]
- * @param {AgentToolResources} [params.tool_resources]
  * @param {boolean} [params.isInitialAgent]
  * @returns {Promise}
  */
@@ -88,9 +113,30 @@ const initializeAgentOptions = async ({
   res,
   agent,
   endpointOption,
-  tool_resources,
   isInitialAgent = false,
 }) => {
+  let currentFiles;
+  /** @type {Array} */
+  const requestFiles = req.body.files ?? [];
+  if (
+    isInitialAgent &&
+    req.body.conversationId != null &&
+    (agent.model_parameters?.resendFiles ?? true) === true
+  ) {
+    const fileIds = (await getConvoFiles(req.body.conversationId)) ?? [];
+    const toolFiles = await getToolFilesByIds(fileIds);
+    if (requestFiles.length || toolFiles.length) {
+      currentFiles = await processFiles(requestFiles.concat(toolFiles));
+    }
+  } else if (isInitialAgent && requestFiles.length) {
+    currentFiles = await processFiles(requestFiles);
+  }
+
+  const { attachments, tool_resources } = await primeResources(
+    req,
+    currentFiles,
+    agent.tool_resources,
+  );
 
   const { tools, toolContextMap } = await loadAgentTools({
     req,
     res,
@@ -99,18 +145,19 @@
   });
 
   const provider = agent.provider;
+  agent.endpoint = provider;
   let getOptions = providerConfigMap[provider];
-
-  if (!getOptions) {
+  if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) {
+    agent.provider = provider.toLowerCase();
+    getOptions = providerConfigMap[agent.provider];
+  } else if (!getOptions) {
     const customEndpointConfig = await getCustomEndpointConfig(provider);
     if (!customEndpointConfig) {
       throw new Error(`Provider ${provider} not supported`);
     }
     getOptions = initCustom;
     agent.provider = Providers.OPENAI;
-    agent.endpoint = provider.toLowerCase();
   }
-
   const model_parameters = Object.assign(
     {},
     agent.model_parameters ??
{ model: agent.model }, @@ -134,6 +181,7 @@ const initializeAgentOptions = async ({ agent.provider = options.provider; } + /** @type {import('@librechat/agents').ClientOptions} */ agent.model_parameters = Object.assign(model_parameters, options.llmConfig); if (options.configOptions) { agent.model_parameters.configuration = options.configOptions; @@ -152,15 +200,18 @@ const initializeAgentOptions = async ({ const tokensModel = agent.provider === EModelEndpoint.azureOpenAI ? agent.model : agent.model_parameters.model; - + const maxTokens = agent.model_parameters.maxOutputTokens ?? agent.model_parameters.maxTokens ?? 0; + const maxContextTokens = + agent.model_parameters.maxContextTokens ?? + agent.max_context_tokens ?? + getModelMaxTokens(tokensModel, providerEndpointMap[provider]) ?? + 4096; return { ...agent, tools, + attachments, toolContextMap, - maxContextTokens: - agent.max_context_tokens ?? - getModelMaxTokens(tokensModel, providerEndpointMap[provider]) ?? - 4000, + maxContextTokens: (maxContextTokens - maxTokens) * 0.9, }; }; @@ -193,11 +244,6 @@ const initializeClient = async ({ req, res, endpointOption }) => { throw new Error('Agent not found'); } - const { attachments, tool_resources } = await primeResources( - endpointOption.attachments, - primaryAgent.tool_resources, - ); - const agentConfigs = new Map(); // Handle primary agent @@ -206,7 +252,6 @@ const initializeClient = async ({ req, res, endpointOption }) => { res, agent: primaryAgent, endpointOption, - tool_resources, isInitialAgent: true, }); @@ -236,18 +281,21 @@ const initializeClient = async ({ req, res, endpointOption }) => { const client = new AgentClient({ req, - agent: primaryConfig, + res, sender, - attachments, contentParts, + agentConfigs, eventHandlers, collectedUsage, + aggregateContent, artifactPromises, + agent: primaryConfig, spec: endpointOption.spec, iconURL: endpointOption.iconURL, - agentConfigs, endpoint: EModelEndpoint.agents, + attachments: primaryConfig.attachments, maxContextTokens: primaryConfig.maxContextTokens, + resendFiles: primaryConfig.model_parameters?.resendFiles ?? true, }); return { client }; diff --git a/api/server/services/Endpoints/agents/title.js b/api/server/services/Endpoints/agents/title.js index 56fd28668d..f25746582e 100644 --- a/api/server/services/Endpoints/agents/title.js +++ b/api/server/services/Endpoints/agents/title.js @@ -20,10 +20,19 @@ const addTitle = async (req, { text, response, client }) => { const titleCache = getLogStores(CacheKeys.GEN_TITLE); const key = `${req.user.id}-${response.conversationId}`; + const responseText = + response?.content && Array.isArray(response?.content) + ? response.content.reduce((acc, block) => { + if (block?.type === 'text') { + return acc + block.text; + } + return acc; + }, '') + : (response?.content ?? response?.text ?? ''); const title = await client.titleConvo({ text, - responseText: response?.text ?? 
'', + responseText, conversationId: response.conversationId, }); await titleCache.set(key, title, 120000); diff --git a/api/server/services/Endpoints/anthropic/build.js b/api/server/services/Endpoints/anthropic/build.js index 028da36407..2deab4b975 100644 --- a/api/server/services/Endpoints/anthropic/build.js +++ b/api/server/services/Endpoints/anthropic/build.js @@ -1,4 +1,4 @@ -const { removeNullishValues } = require('librechat-data-provider'); +const { removeNullishValues, anthropicSettings } = require('librechat-data-provider'); const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); const buildOptions = (endpoint, parsedBody) => { @@ -6,8 +6,10 @@ const buildOptions = (endpoint, parsedBody) => { modelLabel, promptPrefix, maxContextTokens, - resendFiles = true, - promptCache = true, + resendFiles = anthropicSettings.resendFiles.default, + promptCache = anthropicSettings.promptCache.default, + thinking = anthropicSettings.thinking.default, + thinkingBudget = anthropicSettings.thinkingBudget.default, iconURL, greeting, spec, @@ -21,6 +23,8 @@ const buildOptions = (endpoint, parsedBody) => { promptPrefix, resendFiles, promptCache, + thinking, + thinkingBudget, iconURL, greeting, spec, diff --git a/api/server/services/Endpoints/anthropic/helpers.js b/api/server/services/Endpoints/anthropic/helpers.js new file mode 100644 index 0000000000..04e4efc61c --- /dev/null +++ b/api/server/services/Endpoints/anthropic/helpers.js @@ -0,0 +1,111 @@ +const { EModelEndpoint, anthropicSettings } = require('librechat-data-provider'); +const { matchModelName } = require('~/utils'); +const { logger } = require('~/config'); + +/** + * @param {string} modelName + * @returns {boolean} + */ +function checkPromptCacheSupport(modelName) { + const modelMatch = matchModelName(modelName, EModelEndpoint.anthropic); + if ( + modelMatch.includes('claude-3-5-sonnet-latest') || + modelMatch.includes('claude-3.5-sonnet-latest') + ) { + return false; + } + + if ( + modelMatch === 'claude-3-7-sonnet' || + modelMatch === 'claude-3-5-sonnet' || + modelMatch === 'claude-3-5-haiku' || + modelMatch === 'claude-3-haiku' || + modelMatch === 'claude-3-opus' || + modelMatch === 'claude-3.7-sonnet' || + modelMatch === 'claude-3.5-sonnet' || + modelMatch === 'claude-3.5-haiku' + ) { + return true; + } + + return false; +} + +/** + * Gets the appropriate headers for Claude models with cache control + * @param {string} model The model name + * @param {boolean} supportsCacheControl Whether the model supports cache control + * @returns {AnthropicClientOptions['extendedOptions']['defaultHeaders']|undefined} The headers object or undefined if not applicable + */ +function getClaudeHeaders(model, supportsCacheControl) { + if (!supportsCacheControl) { + return undefined; + } + + if (/claude-3[-.]5-sonnet/.test(model)) { + return { + 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15,prompt-caching-2024-07-31', + }; + } else if (/claude-3[-.]7/.test(model)) { + return { + 'anthropic-beta': + 'token-efficient-tools-2025-02-19,output-128k-2025-02-19,prompt-caching-2024-07-31', + }; + } else { + return { + 'anthropic-beta': 'prompt-caching-2024-07-31', + }; + } +} + +/** + * Configures reasoning-related options for Claude models + * @param {AnthropicClientOptions & { max_tokens?: number }} anthropicInput The request options object + * @param {Object} extendedOptions Additional client configuration options + * @param {boolean} extendedOptions.thinking Whether thinking is enabled in client config + * @param {number|null} 
extendedOptions.thinkingBudget The token budget for thinking + * @returns {Object} Updated request options + */ +function configureReasoning(anthropicInput, extendedOptions = {}) { + const updatedOptions = { ...anthropicInput }; + const currentMaxTokens = updatedOptions.max_tokens ?? updatedOptions.maxTokens; + if ( + extendedOptions.thinking && + updatedOptions?.model && + /claude-3[-.]7/.test(updatedOptions.model) + ) { + updatedOptions.thinking = { + type: 'enabled', + }; + } + + if (updatedOptions.thinking != null && extendedOptions.thinkingBudget != null) { + updatedOptions.thinking = { + ...updatedOptions.thinking, + budget_tokens: extendedOptions.thinkingBudget, + }; + } + + if ( + updatedOptions.thinking != null && + (currentMaxTokens == null || updatedOptions.thinking.budget_tokens > currentMaxTokens) + ) { + const maxTokens = anthropicSettings.maxOutputTokens.reset(updatedOptions.model); + updatedOptions.max_tokens = currentMaxTokens ?? maxTokens; + + logger.warn( + updatedOptions.max_tokens === maxTokens + ? '[AnthropicClient] max_tokens is not defined while thinking is enabled. Setting max_tokens to model default.' + : `[AnthropicClient] thinking budget_tokens (${updatedOptions.thinking.budget_tokens}) exceeds max_tokens (${updatedOptions.max_tokens}). Adjusting budget_tokens.`, + ); + + updatedOptions.thinking.budget_tokens = Math.min( + updatedOptions.thinking.budget_tokens, + Math.floor(updatedOptions.max_tokens * 0.9), + ); + } + + return updatedOptions; +} + +module.exports = { checkPromptCacheSupport, getClaudeHeaders, configureReasoning }; diff --git a/api/server/services/Endpoints/anthropic/initialize.js b/api/server/services/Endpoints/anthropic/initialize.js index ffd61441be..6c89eff463 100644 --- a/api/server/services/Endpoints/anthropic/initialize.js +++ b/api/server/services/Endpoints/anthropic/initialize.js @@ -27,6 +27,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio if (anthropicConfig) { clientOptions.streamRate = anthropicConfig.streamRate; + clientOptions.titleModel = anthropicConfig.titleModel; } /** @type {undefined | TBaseEndpoint} */ diff --git a/api/server/services/Endpoints/anthropic/llm.js b/api/server/services/Endpoints/anthropic/llm.js index 301d42712a..9f20b8e61d 100644 --- a/api/server/services/Endpoints/anthropic/llm.js +++ b/api/server/services/Endpoints/anthropic/llm.js @@ -1,5 +1,6 @@ const { HttpsProxyAgent } = require('https-proxy-agent'); const { anthropicSettings, removeNullishValues } = require('librechat-data-provider'); +const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = require('./helpers'); /** * Generates configuration options for creating an Anthropic language model (LLM) instance. @@ -20,6 +21,14 @@ const { anthropicSettings, removeNullishValues } = require('librechat-data-provi * @returns {Object} Configuration options for creating an Anthropic LLM instance, with null and undefined values removed. */ function getLLMConfig(apiKey, options = {}) { + const systemOptions = { + thinking: options.modelOptions.thinking ?? anthropicSettings.thinking.default, + promptCache: options.modelOptions.promptCache ?? anthropicSettings.promptCache.default, + thinkingBudget: options.modelOptions.thinkingBudget ?? 
anthropicSettings.thinkingBudget.default, + }; + for (let key in systemOptions) { + delete options.modelOptions[key]; + } const defaultOptions = { model: anthropicSettings.model.default, maxOutputTokens: anthropicSettings.maxOutputTokens.default, @@ -29,19 +38,34 @@ function getLLMConfig(apiKey, options = {}) { const mergedOptions = Object.assign(defaultOptions, options.modelOptions); /** @type {AnthropicClientOptions} */ - const requestOptions = { + let requestOptions = { apiKey, model: mergedOptions.model, stream: mergedOptions.stream, temperature: mergedOptions.temperature, - topP: mergedOptions.topP, - topK: mergedOptions.topK, stopSequences: mergedOptions.stop, maxTokens: mergedOptions.maxOutputTokens || anthropicSettings.maxOutputTokens.reset(mergedOptions.model), clientOptions: {}, }; + requestOptions = configureReasoning(requestOptions, systemOptions); + + if (!/claude-3[-.]7/.test(mergedOptions.model)) { + requestOptions.topP = mergedOptions.topP; + requestOptions.topK = mergedOptions.topK; + } else if (requestOptions.thinking == null) { + requestOptions.topP = mergedOptions.topP; + requestOptions.topK = mergedOptions.topK; + } + + const supportsCacheControl = + systemOptions.promptCache === true && checkPromptCacheSupport(requestOptions.model); + const headers = getClaudeHeaders(requestOptions.model, supportsCacheControl); + if (headers) { + requestOptions.clientOptions.defaultHeaders = headers; + } + if (options.proxy) { requestOptions.clientOptions.httpAgent = new HttpsProxyAgent(options.proxy); } diff --git a/api/server/services/Endpoints/anthropic/llm.spec.js b/api/server/services/Endpoints/anthropic/llm.spec.js new file mode 100644 index 0000000000..9c453efb92 --- /dev/null +++ b/api/server/services/Endpoints/anthropic/llm.spec.js @@ -0,0 +1,153 @@ +const { anthropicSettings } = require('librechat-data-provider'); +const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm'); + +jest.mock('https-proxy-agent', () => ({ + HttpsProxyAgent: jest.fn().mockImplementation((proxy) => ({ proxy })), +})); + +describe('getLLMConfig', () => { + it('should create a basic configuration with default values', () => { + const result = getLLMConfig('test-api-key', { modelOptions: {} }); + + expect(result.llmConfig).toHaveProperty('apiKey', 'test-api-key'); + expect(result.llmConfig).toHaveProperty('model', anthropicSettings.model.default); + expect(result.llmConfig).toHaveProperty('stream', true); + expect(result.llmConfig).toHaveProperty('maxTokens'); + }); + + it('should include proxy settings when provided', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: {}, + proxy: 'http://proxy:8080', + }); + + expect(result.llmConfig.clientOptions).toHaveProperty('httpAgent'); + expect(result.llmConfig.clientOptions.httpAgent).toHaveProperty('proxy', 'http://proxy:8080'); + }); + + it('should include reverse proxy URL when provided', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: {}, + reverseProxyUrl: 'http://reverse-proxy', + }); + + expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'http://reverse-proxy'); + }); + + it('should include topK and topP for non-Claude-3.7 models', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-opus', + topK: 10, + topP: 0.9, + }, + }); + + expect(result.llmConfig).toHaveProperty('topK', 10); + expect(result.llmConfig).toHaveProperty('topP', 0.9); + }); + + it('should include topK and topP for Claude-3.5 models', () => { + const result = 
getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-5-sonnet', + topK: 10, + topP: 0.9, + }, + }); + + expect(result.llmConfig).toHaveProperty('topK', 10); + expect(result.llmConfig).toHaveProperty('topP', 0.9); + }); + + it('should NOT include topK and topP for Claude-3-7 models (hyphen notation)', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-7-sonnet', + topK: 10, + topP: 0.9, + }, + }); + + expect(result.llmConfig).not.toHaveProperty('topK'); + expect(result.llmConfig).not.toHaveProperty('topP'); + }); + + it('should NOT include topK and topP for Claude-3.7 models (decimal notation)', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3.7-sonnet', + topK: 10, + topP: 0.9, + }, + }); + + expect(result.llmConfig).not.toHaveProperty('topK'); + expect(result.llmConfig).not.toHaveProperty('topP'); + }); + + it('should handle custom maxOutputTokens', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-opus', + maxOutputTokens: 2048, + }, + }); + + expect(result.llmConfig).toHaveProperty('maxTokens', 2048); + }); + + it('should handle promptCache setting', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-5-sonnet', + promptCache: true, + }, + }); + + // We're not checking specific header values since that depends on the actual helper function + // Just verifying that the promptCache setting is processed + expect(result.llmConfig).toBeDefined(); + }); + + it('should include topK and topP for Claude-3.7 models when thinking is not enabled', () => { + // Test with thinking explicitly set to null/undefined + const result = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-7-sonnet', + topK: 10, + topP: 0.9, + thinking: false, + }, + }); + + expect(result.llmConfig).toHaveProperty('topK', 10); + expect(result.llmConfig).toHaveProperty('topP', 0.9); + + // Test with thinking explicitly set to false + const result2 = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-7-sonnet', + topK: 10, + topP: 0.9, + thinking: false, + }, + }); + + expect(result2.llmConfig).toHaveProperty('topK', 10); + expect(result2.llmConfig).toHaveProperty('topP', 0.9); + + // Test with decimal notation as well + const result3 = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3.7-sonnet', + topK: 10, + topP: 0.9, + thinking: false, + }, + }); + + expect(result3.llmConfig).toHaveProperty('topK', 10); + expect(result3.llmConfig).toHaveProperty('topP', 0.9); + }); +}); diff --git a/api/server/services/Endpoints/bedrock/build.js b/api/server/services/Endpoints/bedrock/build.js index d6fb0636a9..f5228160fc 100644 --- a/api/server/services/Endpoints/bedrock/build.js +++ b/api/server/services/Endpoints/bedrock/build.js @@ -1,6 +1,5 @@ -const { removeNullishValues, bedrockInputParser } = require('librechat-data-provider'); +const { removeNullishValues } = require('librechat-data-provider'); const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); -const { logger } = require('~/config'); const buildOptions = (endpoint, parsedBody) => { const { @@ -15,12 +14,6 @@ const buildOptions = (endpoint, parsedBody) => { artifacts, ...model_parameters } = parsedBody; - let parsedParams = model_parameters; - try { - parsedParams = bedrockInputParser.parse(model_parameters); - } catch (error) { - logger.warn('Failed to parse bedrock input', error); - } const endpointOption = 
removeNullishValues({ endpoint, name, @@ -31,7 +24,7 @@ const buildOptions = (endpoint, parsedBody) => { spec, promptPrefix, maxContextTokens, - model_parameters: parsedParams, + model_parameters, }); if (typeof artifacts === 'string') { diff --git a/api/server/services/Endpoints/bedrock/initialize.js b/api/server/services/Endpoints/bedrock/initialize.js index 3ffa03393d..4d9ba361cf 100644 --- a/api/server/services/Endpoints/bedrock/initialize.js +++ b/api/server/services/Endpoints/bedrock/initialize.js @@ -23,8 +23,9 @@ const initializeClient = async ({ req, res, endpointOption }) => { const agent = { id: EModelEndpoint.bedrock, name: endpointOption.name, - instructions: endpointOption.promptPrefix, provider: EModelEndpoint.bedrock, + endpoint: EModelEndpoint.bedrock, + instructions: endpointOption.promptPrefix, model: endpointOption.model_parameters.model, model_parameters: endpointOption.model_parameters, }; @@ -54,6 +55,7 @@ const initializeClient = async ({ req, res, endpointOption }) => { const client = new AgentClient({ req, + res, agent, sender, // tools, diff --git a/api/server/services/Endpoints/bedrock/options.js b/api/server/services/Endpoints/bedrock/options.js index 11b33a5357..6740ae882e 100644 --- a/api/server/services/Endpoints/bedrock/options.js +++ b/api/server/services/Endpoints/bedrock/options.js @@ -1,14 +1,16 @@ const { HttpsProxyAgent } = require('https-proxy-agent'); const { - EModelEndpoint, - Constants, AuthType, + Constants, + EModelEndpoint, + bedrockInputParser, + bedrockOutputParser, removeNullishValues, } = require('librechat-data-provider'); const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); const { sleep } = require('~/server/utils'); -const getOptions = async ({ req, endpointOption }) => { +const getOptions = async ({ req, overrideModel, endpointOption }) => { const { BEDROCK_AWS_SECRET_ACCESS_KEY, BEDROCK_AWS_ACCESS_KEY_ID, @@ -62,39 +64,44 @@ const getOptions = async ({ req, endpointOption }) => { /** @type {BedrockClientOptions} */ const requestOptions = { - model: endpointOption.model, + model: overrideModel ?? 
endpointOption.model, region: BEDROCK_AWS_DEFAULT_REGION, - streaming: true, - streamUsage: true, - callbacks: [ - { - handleLLMNewToken: async () => { - if (!streamRate) { - return; - } - await sleep(streamRate); - }, - }, - ], }; - if (credentials) { - requestOptions.credentials = credentials; - } - - if (BEDROCK_REVERSE_PROXY) { - requestOptions.endpointHost = BEDROCK_REVERSE_PROXY; - } - const configOptions = {}; if (PROXY) { /** NOTE: NOT SUPPORTED BY BEDROCK */ configOptions.httpAgent = new HttpsProxyAgent(PROXY); } + const llmConfig = bedrockOutputParser( + bedrockInputParser.parse( + removeNullishValues(Object.assign(requestOptions, endpointOption.model_parameters)), + ), + ); + + if (credentials) { + llmConfig.credentials = credentials; + } + + if (BEDROCK_REVERSE_PROXY) { + llmConfig.endpointHost = BEDROCK_REVERSE_PROXY; + } + + llmConfig.callbacks = [ + { + handleLLMNewToken: async () => { + if (!streamRate) { + return; + } + await sleep(streamRate); + }, + }, + ]; + return { /** @type {BedrockClientOptions} */ - llmConfig: removeNullishValues(Object.assign(requestOptions, endpointOption.model_parameters)), + llmConfig, configOptions, }; }; diff --git a/api/server/services/Endpoints/custom/initialize.js b/api/server/services/Endpoints/custom/initialize.js index fe2beba582..e98ec71980 100644 --- a/api/server/services/Endpoints/custom/initialize.js +++ b/api/server/services/Endpoints/custom/initialize.js @@ -141,7 +141,8 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid }, clientOptions, ); - const options = getLLMConfig(apiKey, clientOptions); + clientOptions.modelOptions.user = req.user.id; + const options = getLLMConfig(apiKey, clientOptions, endpoint); if (!customOptions.streamRate) { return options; } diff --git a/api/server/services/Endpoints/google/initialize.js b/api/server/services/Endpoints/google/initialize.js index c157dd8b28..b7419a8a87 100644 --- a/api/server/services/Endpoints/google/initialize.js +++ b/api/server/services/Endpoints/google/initialize.js @@ -5,12 +5,7 @@ const { isEnabled } = require('~/server/utils'); const { GoogleClient } = require('~/app'); const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => { - const { - GOOGLE_KEY, - GOOGLE_REVERSE_PROXY, - GOOGLE_AUTH_HEADER, - PROXY, - } = process.env; + const { GOOGLE_KEY, GOOGLE_REVERSE_PROXY, GOOGLE_AUTH_HEADER, PROXY } = process.env; const isUserProvided = GOOGLE_KEY === 'user_provided'; const { key: expiresAt } = req.body; @@ -43,6 +38,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio if (googleConfig) { clientOptions.streamRate = googleConfig.streamRate; + clientOptions.titleModel = googleConfig.titleModel; } if (allConfig) { diff --git a/api/server/services/Endpoints/openAI/initialize.js b/api/server/services/Endpoints/openAI/initialize.js index 0eb0d566b9..4d358cef1a 100644 --- a/api/server/services/Endpoints/openAI/initialize.js +++ b/api/server/services/Endpoints/openAI/initialize.js @@ -113,6 +113,7 @@ const initializeClient = async ({ if (!isAzureOpenAI && openAIConfig) { clientOptions.streamRate = openAIConfig.streamRate; + clientOptions.titleModel = openAIConfig.titleModel; } /** @type {undefined | TBaseEndpoint} */ @@ -134,12 +135,10 @@ const initializeClient = async ({ } if (optionsOnly) { - clientOptions = Object.assign( - { - modelOptions: endpointOption.model_parameters, - }, - clientOptions, - ); + const modelOptions = endpointOption.model_parameters; + modelOptions.model = 
modelName; + clientOptions = Object.assign({ modelOptions }, clientOptions); + clientOptions.modelOptions.user = req.user.id; const options = getLLMConfig(apiKey, clientOptions); if (!clientOptions.streamRate) { return options; diff --git a/api/server/services/Endpoints/openAI/llm.js b/api/server/services/Endpoints/openAI/llm.js index 2587b242c9..a8aeeb5b9d 100644 --- a/api/server/services/Endpoints/openAI/llm.js +++ b/api/server/services/Endpoints/openAI/llm.js @@ -1,4 +1,5 @@ const { HttpsProxyAgent } = require('https-proxy-agent'); +const { KnownEndpoints } = require('librechat-data-provider'); const { sanitizeModelName, constructAzureURL } = require('~/utils'); const { isEnabled } = require('~/server/utils'); @@ -8,6 +9,7 @@ const { isEnabled } = require('~/server/utils'); * @param {Object} options - Additional options for configuring the LLM. * @param {Object} [options.modelOptions] - Model-specific options. * @param {string} [options.modelOptions.model] - The name of the model to use. + * @param {string} [options.modelOptions.user] - The user ID * @param {number} [options.modelOptions.temperature] - Controls randomness in output generation (0-2). * @param {number} [options.modelOptions.top_p] - Controls diversity via nucleus sampling (0-1). * @param {number} [options.modelOptions.frequency_penalty] - Reduces repetition of token sequences (-2 to 2). @@ -22,13 +24,13 @@ const { isEnabled } = require('~/server/utils'); * @param {boolean} [options.streaming] - Whether to use streaming mode. * @param {Object} [options.addParams] - Additional parameters to add to the model options. * @param {string[]} [options.dropParams] - Parameters to remove from the model options. + * @param {string|null} [endpoint=null] - The endpoint name * @returns {Object} Configuration options for creating an LLM instance. 
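+ *
+ * @example
+ * // Minimal usage sketch; the key source, model, and dropParams values here are
+ * // illustrative assumptions, not defaults shipped with this function:
+ * const { llmConfig, configOptions } = getLLMConfig(process.env.OPENAI_API_KEY, {
+ *   modelOptions: { model: 'gpt-4o', temperature: 0.7, user: 'user-id' },
+ *   dropParams: ['frequency_penalty'],
+ * });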
*/ -function getLLMConfig(apiKey, options = {}) { - const { +function getLLMConfig(apiKey, options = {}, endpoint = null) { + let { modelOptions = {}, reverseProxyUrl, - useOpenRouter, defaultQuery, headers, proxy, @@ -48,19 +50,45 @@ function getLLMConfig(apiKey, options = {}) { if (addParams && typeof addParams === 'object') { Object.assign(llmConfig, addParams); } + /** Note: OpenAI Web Search models do not support any known parameters besides `max_tokens` */ + if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) { + const searchExcludeParams = [ + 'frequency_penalty', + 'presence_penalty', + 'temperature', + 'top_p', + 'top_k', + 'stop', + 'logit_bias', + 'seed', + 'response_format', + 'n', + 'logprobs', + 'user', + ]; + + dropParams = dropParams || []; + dropParams = [...new Set([...dropParams, ...searchExcludeParams])]; + } if (dropParams && Array.isArray(dropParams)) { dropParams.forEach((param) => { - delete llmConfig[param]; + if (llmConfig[param]) { + llmConfig[param] = undefined; + } }); } + let useOpenRouter; /** @type {OpenAIClientOptions['configuration']} */ const configOptions = {}; - - // Handle OpenRouter or custom reverse proxy - if (useOpenRouter || reverseProxyUrl === 'https://openrouter.ai/api/v1') { - configOptions.baseURL = 'https://openrouter.ai/api/v1'; + if ( + (reverseProxyUrl && reverseProxyUrl.includes(KnownEndpoints.openrouter)) || + (endpoint && endpoint.toLowerCase().includes(KnownEndpoints.openrouter)) + ) { + useOpenRouter = true; + llmConfig.include_reasoning = true; + configOptions.baseURL = reverseProxyUrl; configOptions.defaultHeaders = Object.assign( { 'HTTP-Referer': 'https://librechat.ai', @@ -118,6 +146,13 @@ function getLLMConfig(apiKey, options = {}) { llmConfig.organization = process.env.OPENAI_ORGANIZATION; } + if (useOpenRouter && llmConfig.reasoning_effort != null) { + llmConfig.reasoning = { + effort: llmConfig.reasoning_effort, + }; + delete llmConfig.reasoning_effort; + } + return { /** @type {OpenAIClientOptions} */ llmConfig, configOptions, diff --git a/api/server/services/Files/Azure/crud.js b/api/server/services/Files/Azure/crud.js new file mode 100644 index 0000000000..638da34b27 --- /dev/null +++ b/api/server/services/Files/Azure/crud.js @@ -0,0 +1,196 @@ +const fs = require('fs'); +const path = require('path'); +const axios = require('axios'); +const fetch = require('node-fetch'); +const { logger } = require('~/config'); +const { getAzureContainerClient } = require('./initialize'); + +const defaultBasePath = 'images'; + +/** + * Uploads a buffer to Azure Blob Storage. + * + * Files will be stored at the path: {basePath}/{userId}/{fileName} within the container. + * + * @param {Object} params + * @param {string} params.userId - The user's id. + * @param {Buffer} params.buffer - The buffer to upload. + * @param {string} params.fileName - The name of the file. + * @param {string} [params.basePath='images'] - The base folder within the container. + * @param {string} [params.containerName] - The Azure Blob container name. + * @returns {Promise} The URL of the uploaded blob. + */ +async function saveBufferToAzure({ + userId, + buffer, + fileName, + basePath = defaultBasePath, + containerName, +}) { + try { + const containerClient = getAzureContainerClient(containerName); + // Create the container if it doesn't exist. This is done per operation. + await containerClient.createIfNotExists({ + access: process.env.AZURE_STORAGE_PUBLIC_ACCESS ?
'blob' : undefined, + }); + const blobPath = `${basePath}/${userId}/${fileName}`; + const blockBlobClient = containerClient.getBlockBlobClient(blobPath); + await blockBlobClient.uploadData(buffer); + return blockBlobClient.url; + } catch (error) { + logger.error('[saveBufferToAzure] Error uploading buffer:', error); + throw error; + } +} + +/** + * Saves a file from a URL to Azure Blob Storage. + * + * @param {Object} params + * @param {string} params.userId - The user's id. + * @param {string} params.URL - The URL of the file. + * @param {string} params.fileName - The name of the file. + * @param {string} [params.basePath='images'] - The base folder within the container. + * @param {string} [params.containerName] - The Azure Blob container name. + * @returns {Promise} The URL of the uploaded blob. + */ +async function saveURLToAzure({ + userId, + URL, + fileName, + basePath = defaultBasePath, + containerName, +}) { + try { + const response = await fetch(URL); + const buffer = await response.buffer(); + return await saveBufferToAzure({ userId, buffer, fileName, basePath, containerName }); + } catch (error) { + logger.error('[saveURLToAzure] Error uploading file from URL:', error); + throw error; + } +} + +/** + * Retrieves a blob URL from Azure Blob Storage. + * + * @param {Object} params + * @param {string} params.fileName - The file name. + * @param {string} [params.basePath='images'] - The base folder used during upload. + * @param {string} [params.userId] - If files are stored in a user-specific directory. + * @param {string} [params.containerName] - The Azure Blob container name. + * @returns {Promise} The blob's URL. + */ +async function getAzureURL({ fileName, basePath = defaultBasePath, userId, containerName }) { + try { + const containerClient = getAzureContainerClient(containerName); + const blobPath = userId ? `${basePath}/${userId}/${fileName}` : `${basePath}/${fileName}`; + const blockBlobClient = containerClient.getBlockBlobClient(blobPath); + return blockBlobClient.url; + } catch (error) { + logger.error('[getAzureURL] Error retrieving blob URL:', error); + throw error; + } +} + +/** + * Deletes a blob from Azure Blob Storage. + * + * @param {Object} params + * @param {string} params.fileName - The name of the file. + * @param {string} [params.basePath='images'] - The base folder where the file is stored. + * @param {string} params.userId - The user's id. + * @param {string} [params.containerName] - The Azure Blob container name. + */ +async function deleteFileFromAzure({ + fileName, + basePath = defaultBasePath, + userId, + containerName, +}) { + try { + const containerClient = getAzureContainerClient(containerName); + const blobPath = `${basePath}/${userId}/${fileName}`; + const blockBlobClient = containerClient.getBlockBlobClient(blobPath); + await blockBlobClient.delete(); + logger.debug('[deleteFileFromAzure] Blob deleted successfully from Azure Blob Storage'); + } catch (error) { + logger.error('[deleteFileFromAzure] Error deleting blob:', error.message); + if (error.statusCode === 404) { + return; + } + throw error; + } +} + +/** + * Uploads a file from the local file system to Azure Blob Storage. + * + * This function reads the file from disk and then uploads it to Azure Blob Storage + * at the path: {basePath}/{userId}/{fileName}. + * + * @param {Object} params + * @param {object} params.req - The Express request object. + * @param {Express.Multer.File} params.file - The file object. + * @param {string} params.file_id - The file id. 
+ * @param {string} [params.basePath='images'] - The base folder within the container. + * @param {string} [params.containerName] - The Azure Blob container name. + * @returns {Promise<{ filepath: string, bytes: number }>} An object containing the blob URL and its byte size. + */ +async function uploadFileToAzure({ + req, + file, + file_id, + basePath = defaultBasePath, + containerName, +}) { + try { + const inputFilePath = file.path; + const inputBuffer = await fs.promises.readFile(inputFilePath); + const bytes = Buffer.byteLength(inputBuffer); + const userId = req.user.id; + const fileName = `${file_id}__${path.basename(inputFilePath)}`; + const fileURL = await saveBufferToAzure({ + userId, + buffer: inputBuffer, + fileName, + basePath, + containerName, + }); + await fs.promises.unlink(inputFilePath); + return { filepath: fileURL, bytes }; + } catch (error) { + logger.error('[uploadFileToAzure] Error uploading file:', error); + throw error; + } +} + +/** + * Retrieves a readable stream for a blob from Azure Blob Storage. + * + * @param {object} _req - The Express request object. + * @param {string} fileURL - The URL of the blob. + * @returns {Promise} A readable stream of the blob. + */ +async function getAzureFileStream(_req, fileURL) { + try { + const response = await axios({ + method: 'get', + url: fileURL, + responseType: 'stream', + }); + return response.data; + } catch (error) { + logger.error('[getAzureFileStream] Error getting blob stream:', error); + throw error; + } +} + +module.exports = { + saveBufferToAzure, + saveURLToAzure, + getAzureURL, + deleteFileFromAzure, + uploadFileToAzure, + getAzureFileStream, +}; diff --git a/api/server/services/Files/Azure/images.js b/api/server/services/Files/Azure/images.js new file mode 100644 index 0000000000..a83b700af3 --- /dev/null +++ b/api/server/services/Files/Azure/images.js @@ -0,0 +1,124 @@ +const fs = require('fs'); +const path = require('path'); +const sharp = require('sharp'); +const { resizeImageBuffer } = require('../images/resize'); +const { updateUser } = require('~/models/userMethods'); +const { updateFile } = require('~/models/File'); +const { logger } = require('~/config'); +const { saveBufferToAzure } = require('./crud'); + +/** + * Uploads an image file to Azure Blob Storage. + * It resizes and converts the image in the same way as the Firebase implementation. + * + * @param {Object} params + * @param {object} params.req - The Express request object. + * @param {Express.Multer.File} params.file - The file object. + * @param {string} params.file_id - The file id. + * @param {EModelEndpoint} params.endpoint - The endpoint parameters. + * @param {string} [params.resolution='high'] - The image resolution. + * @param {string} [params.basePath='images'] - The base folder within the container. + * @param {string} [params.containerName] - The Azure Blob container name.
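+ * @example
+ * // Sketch of a typical call from an upload route; `file` is assumed to be a
+ * // Multer file object and `endpoint` an EModelEndpoint value:
+ * const { filepath, bytes, width, height } = await uploadImageToAzure({
+ *   req, file, file_id, endpoint, resolution: 'high',
+ * });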
+ * @returns {Promise<{ filepath: string, bytes: number, width: number, height: number }>} + */ +async function uploadImageToAzure({ + req, + file, + file_id, + endpoint, + resolution = 'high', + basePath = 'images', + containerName, +}) { + try { + const inputFilePath = file.path; + const inputBuffer = await fs.promises.readFile(inputFilePath); + const { + buffer: resizedBuffer, + width, + height, + } = await resizeImageBuffer(inputBuffer, resolution, endpoint); + const extension = path.extname(inputFilePath); + const userId = req.user.id; + let webPBuffer; + let fileName = `${file_id}__${path.basename(inputFilePath)}`; + const targetExtension = `.${req.app.locals.imageOutputType}`; + + if (extension.toLowerCase() === targetExtension) { + webPBuffer = resizedBuffer; + } else { + webPBuffer = await sharp(resizedBuffer).toFormat(req.app.locals.imageOutputType).toBuffer(); + const extRegExp = new RegExp(path.extname(fileName) + '$'); + fileName = fileName.replace(extRegExp, targetExtension); + if (!path.extname(fileName)) { + fileName += targetExtension; + } + } + const downloadURL = await saveBufferToAzure({ + userId, + buffer: webPBuffer, + fileName, + basePath, + containerName, + }); + await fs.promises.unlink(inputFilePath); + const bytes = Buffer.byteLength(webPBuffer); + return { filepath: downloadURL, bytes, width, height }; + } catch (error) { + logger.error('[uploadImageToAzure] Error uploading image:', error); + throw error; + } +} + +/** + * Prepares the image URL and updates the file record. + * + * @param {object} req - The Express request object. + * @param {MongoFile} file - The file object. + * @returns {Promise<[MongoFile, string]>} + */ +async function prepareAzureImageURL(req, file) { + const { filepath } = file; + const promises = []; + promises.push(updateFile({ file_id: file.file_id })); + promises.push(filepath); + return await Promise.all(promises); +} + +/** + * Uploads and processes a user's avatar to Azure Blob Storage. + * + * @param {Object} params + * @param {Buffer} params.buffer - The avatar image buffer. + * @param {string} params.userId - The user's id. + * @param {string} params.manual - Flag to indicate manual update. + * @param {string} [params.basePath='images'] - The base folder within the container. + * @param {string} [params.containerName] - The Azure Blob container name. + * @returns {Promise} The URL of the avatar. 
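+ *
+ * @example
+ * // Sketch; `buffer` is assumed to hold the avatar image bytes:
+ * const url = await processAzureAvatar({ buffer, userId: req.user.id, manual: 'true' });
+ * // The returned URL ends with `?manual=true`, and the user's avatar field is updated.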
+ */ +async function processAzureAvatar({ buffer, userId, manual, basePath = 'images', containerName }) { + try { + const downloadURL = await saveBufferToAzure({ + userId, + buffer, + fileName: 'avatar.png', + basePath, + containerName, + }); + const isManual = manual === 'true'; + const url = `${downloadURL}?manual=${isManual}`; + if (isManual) { + await updateUser(userId, { avatar: url }); + } + return url; + } catch (error) { + logger.error('[processAzureAvatar] Error uploading profile picture to Azure:', error); + throw error; + } +} + +module.exports = { + uploadImageToAzure, + prepareAzureImageURL, + processAzureAvatar, +}; diff --git a/api/server/services/Files/Azure/index.js b/api/server/services/Files/Azure/index.js new file mode 100644 index 0000000000..27ad97a852 --- /dev/null +++ b/api/server/services/Files/Azure/index.js @@ -0,0 +1,9 @@ +const crud = require('./crud'); +const images = require('./images'); +const initialize = require('./initialize'); + +module.exports = { + ...crud, + ...images, + ...initialize, +}; diff --git a/api/server/services/Files/Azure/initialize.js b/api/server/services/Files/Azure/initialize.js new file mode 100644 index 0000000000..56df24d04a --- /dev/null +++ b/api/server/services/Files/Azure/initialize.js @@ -0,0 +1,55 @@ +const { BlobServiceClient } = require('@azure/storage-blob'); +const { logger } = require('~/config'); + +let blobServiceClient = null; +let azureWarningLogged = false; + +/** + * Initializes the Azure Blob Service client. + * This function establishes a connection by checking if a connection string is provided. + * If available, the connection string is used; otherwise, Managed Identity (via DefaultAzureCredential) is utilized. + * Note: Container creation (and its public access settings) is handled later in the CRUD functions. + * @returns {BlobServiceClient|null} The initialized client, or null if the required configuration is missing. + */ +const initializeAzureBlobService = () => { + if (blobServiceClient) { + return blobServiceClient; + } + const connectionString = process.env.AZURE_STORAGE_CONNECTION_STRING; + if (connectionString) { + blobServiceClient = BlobServiceClient.fromConnectionString(connectionString); + logger.info('Azure Blob Service initialized using connection string'); + } else { + const { DefaultAzureCredential } = require('@azure/identity'); + const accountName = process.env.AZURE_STORAGE_ACCOUNT_NAME; + if (!accountName) { + if (!azureWarningLogged) { + logger.error( + '[initializeAzureBlobService] Azure Blob Service not initialized. Connection string missing and AZURE_STORAGE_ACCOUNT_NAME not provided.', + ); + azureWarningLogged = true; + } + return null; + } + const url = `https://${accountName}.blob.core.windows.net`; + const credential = new DefaultAzureCredential(); + blobServiceClient = new BlobServiceClient(url, credential); + logger.info('Azure Blob Service initialized using Managed Identity'); + } + return blobServiceClient; +}; + +/** + * Retrieves the Azure ContainerClient for the given container name. + * @param {string} [containerName=process.env.AZURE_CONTAINER_NAME || 'files'] - The container name. + * @returns {ContainerClient|null} The Azure ContainerClient. + */ +const getAzureContainerClient = (containerName = process.env.AZURE_CONTAINER_NAME || 'files') => { + const serviceClient = initializeAzureBlobService(); + return serviceClient ? 
serviceClient.getContainerClient(containerName) : null; +}; + +module.exports = { + initializeAzureBlobService, + getAzureContainerClient, +}; diff --git a/api/server/services/Files/Code/crud.js b/api/server/services/Files/Code/crud.js index 076a4d9f13..1360cccadb 100644 --- a/api/server/services/Files/Code/crud.js +++ b/api/server/services/Files/Code/crud.js @@ -1,7 +1,9 @@ -// Code Files -const axios = require('axios'); const FormData = require('form-data'); const { getCodeBaseURL } = require('@librechat/agents'); +const { createAxiosInstance } = require('~/config'); +const { logAxiosError } = require('~/utils'); + +const axios = createAxiosInstance(); const MAX_FILE_SIZE = 150 * 1024 * 1024; @@ -15,7 +17,8 @@ const MAX_FILE_SIZE = 150 * 1024 * 1024; async function getCodeOutputDownloadStream(fileIdentifier, apiKey) { try { const baseURL = getCodeBaseURL(); - const response = await axios({ + /** @type {import('axios').AxiosRequestConfig} */ + const options = { method: 'get', url: `${baseURL}/download/${fileIdentifier}`, responseType: 'stream', @@ -24,10 +27,15 @@ async function getCodeOutputDownloadStream(fileIdentifier, apiKey) { 'X-API-Key': apiKey, }, timeout: 15000, - }); + }; + const response = await axios(options); return response; } catch (error) { + logAxiosError({ + message: `Error downloading code environment file stream: ${error.message}`, + error, + }); throw new Error(`Error downloading file: ${error.message}`); } } @@ -53,7 +61,8 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = '' form.append('file', stream, filename); const baseURL = getCodeBaseURL(); - const response = await axios.post(`${baseURL}/upload`, form, { + /** @type {import('axios').AxiosRequestConfig} */ + const options = { headers: { ...form.getHeaders(), 'Content-Type': 'multipart/form-data', @@ -63,7 +72,9 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = '' }, maxContentLength: MAX_FILE_SIZE, maxBodyLength: MAX_FILE_SIZE, - }); + }; + + const response = await axios.post(`${baseURL}/upload`, form, options); /** @type {{ message: string; session_id: string; files: Array<{ fileId: string; filename: string }> }} */ const result = response.data; @@ -78,7 +89,11 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = '' return `${fileIdentifier}?entity_id=${entity_id}`; } catch (error) { - throw new Error(`Error uploading file: ${error.message}`); + logAxiosError({ + message: `Error uploading code environment file: ${error.message}`, + error, + }); + throw new Error(`Error uploading code environment file: ${error.message}`); } } diff --git a/api/server/services/Files/Code/process.js b/api/server/services/Files/Code/process.js index 2a941a4647..c92e628589 100644 --- a/api/server/services/Files/Code/process.js +++ b/api/server/services/Files/Code/process.js @@ -12,6 +12,7 @@ const { const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { convertImage } = require('~/server/services/Files/images/convert'); const { createFile, getFiles, updateFile } = require('~/models/File'); +const { logAxiosError } = require('~/utils'); const { logger } = require('~/config'); /** @@ -85,7 +86,10 @@ const processCodeOutput = async ({ /** Note: `messageId` & `toolCallId` are not part of file DB schema; message object records associated file ID */ return Object.assign(file, { messageId, toolCallId }); } catch (error) { - logger.error('Error downloading file:', error); + logAxiosError({ + message: 'Error downloading 
code environment file', + error, + }); } }; @@ -135,7 +139,10 @@ async function getSessionInfo(fileIdentifier, apiKey) { return response.data.find((file) => file.name.startsWith(path))?.lastModified; } catch (error) { - logger.error(`Error fetching session info: ${error.message}`, error); + logAxiosError({ + message: `Error fetching session info: ${error.message}`, + error, + }); return null; } } @@ -202,7 +209,7 @@ const primeFiles = async (options, apiKey) => { const { handleFileUpload: uploadCodeEnvFile } = getStrategyFunctions( FileSources.execute_code, ); - const stream = await getDownloadStream(file.filepath); + const stream = await getDownloadStream(options.req, file.filepath); const fileIdentifier = await uploadCodeEnvFile({ req: options.req, stream, diff --git a/api/server/services/Files/Firebase/crud.js b/api/server/services/Files/Firebase/crud.js index 76a6c1d8d4..8319f908ef 100644 --- a/api/server/services/Files/Firebase/crud.js +++ b/api/server/services/Files/Firebase/crud.js @@ -224,10 +224,11 @@ async function uploadFileToFirebase({ req, file, file_id }) { /** * Retrieves a readable stream for a file from Firebase storage. * + * @param {ServerRequest} _req * @param {string} filepath - The filepath. * @returns {Promise} A readable stream of the file. */ -async function getFirebaseFileStream(filepath) { +async function getFirebaseFileStream(_req, filepath) { try { const storage = getFirebaseStorage(); if (!storage) { diff --git a/api/server/services/Files/Local/crud.js b/api/server/services/Files/Local/crud.js index e004eab79e..c2bb75c125 100644 --- a/api/server/services/Files/Local/crud.js +++ b/api/server/services/Files/Local/crud.js @@ -175,6 +175,17 @@ const isValidPath = (req, base, subfolder, filepath) => { return normalizedFilepath.startsWith(normalizedBase); }; +/** + * @param {string} filepath + */ +const unlinkFile = async (filepath) => { + try { + await fs.promises.unlink(filepath); + } catch (error) { + logger.error('Error deleting file:', error); + } +}; + /** * Deletes a file from the filesystem. This function takes a file object, constructs the full path, and * verifies the path's validity before deleting the file. If the path is invalid, an error is thrown. @@ -217,7 +228,7 @@ const deleteLocalFile = async (req, file) => { throw new Error(`Invalid file path: ${file.filepath}`); } - await fs.promises.unlink(filepath); + await unlinkFile(filepath); return; } @@ -233,7 +244,7 @@ const deleteLocalFile = async (req, file) => { throw new Error('Invalid file path'); } - await fs.promises.unlink(filepath); + await unlinkFile(filepath); }; /** @@ -275,11 +286,31 @@ async function uploadLocalFile({ req, file, file_id }) { /** * Retrieves a readable stream for a file from local storage. * + * @param {ServerRequest} req - The request object from Express * @param {string} filepath - The filepath. * @returns {ReadableStream} A readable stream of the file. 
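+ *
+ * @example
+ * // Sketch; the filepath below is an assumed example under the app's uploads directory:
+ * const stream = getLocalFileStream(req, '/uploads/user123/report.pdf');
+ * stream.pipe(res);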
*/ -function getLocalFileStream(filepath) { +function getLocalFileStream(req, filepath) { try { + if (filepath.includes('/uploads/')) { + const basePath = filepath.split('/uploads/')[1]; + + if (!basePath) { + logger.warn(`Invalid base path: ${filepath}`); + throw new Error(`Invalid file path: ${filepath}`); + } + + const fullPath = path.join(req.app.locals.paths.uploads, basePath); + const uploadsDir = req.app.locals.paths.uploads; + + const rel = path.relative(uploadsDir, fullPath); + if (rel.startsWith('..') || path.isAbsolute(rel) || rel.includes(`..${path.sep}`)) { + logger.warn(`Invalid relative file path: ${filepath}`); + throw new Error(`Invalid file path: ${filepath}`); + } + + return fs.createReadStream(fullPath); + } return fs.createReadStream(filepath); } catch (error) { logger.error('Error getting local file stream:', error); diff --git a/api/server/services/Files/MistralOCR/crud.js b/api/server/services/Files/MistralOCR/crud.js new file mode 100644 index 0000000000..cef8297519 --- /dev/null +++ b/api/server/services/Files/MistralOCR/crud.js @@ -0,0 +1,207 @@ +// ~/server/services/Files/MistralOCR/crud.js +const fs = require('fs'); +const path = require('path'); +const FormData = require('form-data'); +const { FileSources, envVarRegex, extractEnvVariable } = require('librechat-data-provider'); +const { loadAuthValues } = require('~/server/services/Tools/credentials'); +const { logger, createAxiosInstance } = require('~/config'); +const { logAxiosError } = require('~/utils'); + +const axios = createAxiosInstance(); + +/** + * Uploads a document to Mistral API using file streaming to avoid loading the entire file into memory + * + * @param {Object} params Upload parameters + * @param {string} params.filePath The path to the file on disk + * @param {string} [params.fileName] Optional filename to use (defaults to the name from filePath) + * @param {string} params.apiKey Mistral API key + * @param {string} [params.baseURL=https://api.mistral.ai/v1] Mistral API base URL + * @returns {Promise} The response from Mistral API + */ +async function uploadDocumentToMistral({ + filePath, + fileName = '', + apiKey, + baseURL = 'https://api.mistral.ai/v1', +}) { + const form = new FormData(); + form.append('purpose', 'ocr'); + const actualFileName = fileName || path.basename(filePath); + const fileStream = fs.createReadStream(filePath); + form.append('file', fileStream, { filename: actualFileName }); + + return axios + .post(`${baseURL}/files`, form, { + headers: { + Authorization: `Bearer ${apiKey}`, + ...form.getHeaders(), + }, + maxBodyLength: Infinity, + maxContentLength: Infinity, + }) + .then((res) => res.data) + .catch((error) => { + logger.error('Error uploading document to Mistral:', error.message); + throw error; + }); +} + +async function getSignedUrl({ + apiKey, + fileId, + expiry = 24, + baseURL = 'https://api.mistral.ai/v1', +}) { + return axios + .get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, { + headers: { + Authorization: `Bearer ${apiKey}`, + }, + }) + .then((res) => res.data) + .catch((error) => { + logger.error('Error fetching signed URL:', error.message); + throw error; + }); +} + +/** + * @param {Object} params + * @param {string} params.apiKey + * @param {string} params.documentUrl + * @param {string} [params.baseURL] + * @returns {Promise} + */ +async function performOCR({ + apiKey, + documentUrl, + model = 'mistral-ocr-latest', + baseURL = 'https://api.mistral.ai/v1', +}) { + return axios + .post( + `${baseURL}/ocr`, + { + model, + include_image_base64: 
false, + document: { + type: 'document_url', + document_url: documentUrl, + }, + }, + { + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${apiKey}`, + }, + }, + ) + .then((res) => res.data) + .catch((error) => { + logger.error('Error performing OCR:', error.message); + throw error; + }); +} + +function extractVariableName(str) { + const match = str.match(envVarRegex); + return match ? match[1] : null; +} + +const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => { + try { + /** @type {TCustomConfig['ocr']} */ + const ocrConfig = req.app.locals?.ocr; + + const apiKeyConfig = ocrConfig.apiKey || ''; + const baseURLConfig = ocrConfig.baseURL || ''; + + const isApiKeyEnvVar = envVarRegex.test(apiKeyConfig); + const isBaseURLEnvVar = envVarRegex.test(baseURLConfig); + + const isApiKeyEmpty = !apiKeyConfig.trim(); + const isBaseURLEmpty = !baseURLConfig.trim(); + + let apiKey, baseURL; + + if (isApiKeyEnvVar || isBaseURLEnvVar || isApiKeyEmpty || isBaseURLEmpty) { + const apiKeyVarName = isApiKeyEnvVar ? extractVariableName(apiKeyConfig) : 'OCR_API_KEY'; + const baseURLVarName = isBaseURLEnvVar ? extractVariableName(baseURLConfig) : 'OCR_BASEURL'; + + const authValues = await loadAuthValues({ + userId: req.user.id, + authFields: [baseURLVarName, apiKeyVarName], + optional: new Set([baseURLVarName]), + }); + + apiKey = authValues[apiKeyVarName]; + baseURL = authValues[baseURLVarName]; + } else { + apiKey = apiKeyConfig; + baseURL = baseURLConfig; + } + + const mistralFile = await uploadDocumentToMistral({ + filePath: file.path, + fileName: file.originalname, + apiKey, + baseURL, + }); + + const modelConfig = ocrConfig.mistralModel || ''; + const model = envVarRegex.test(modelConfig) + ? extractEnvVariable(modelConfig) + : modelConfig.trim() || 'mistral-ocr-latest'; + + const signedUrlResponse = await getSignedUrl({ + apiKey, + baseURL, + fileId: mistralFile.id, + }); + + const ocrResult = await performOCR({ + apiKey, + baseURL, + model, + documentUrl: signedUrlResponse.url, + }); + + let aggregatedText = ''; + const images = []; + ocrResult.pages.forEach((page, index) => { + if (ocrResult.pages.length > 1) { + aggregatedText += `# PAGE ${index + 1}\n`; + } + + aggregatedText += page.markdown + '\n\n'; + + if (page.images && page.images.length > 0) { + page.images.forEach((image) => { + if (image.image_base64) { + images.push(image.image_base64); + } + }); + } + }); + + return { + filename: file.originalname, + bytes: aggregatedText.length * 4, + filepath: FileSources.mistral_ocr, + text: aggregatedText, + images, + }; + } catch (error) { + const message = 'Error uploading document to Mistral OCR API'; + logAxiosError({ error, message }); + throw new Error(message); + } +}; + +module.exports = { + uploadDocumentToMistral, + uploadMistralOCR, + getSignedUrl, + performOCR, +}; diff --git a/api/server/services/Files/MistralOCR/crud.spec.js b/api/server/services/Files/MistralOCR/crud.spec.js new file mode 100644 index 0000000000..80ac6f73a4 --- /dev/null +++ b/api/server/services/Files/MistralOCR/crud.spec.js @@ -0,0 +1,737 @@ +const fs = require('fs'); + +const mockAxios = { + interceptors: { + request: { use: jest.fn(), eject: jest.fn() }, + response: { use: jest.fn(), eject: jest.fn() }, + }, + create: jest.fn().mockReturnValue({ + defaults: { + proxy: null, + }, + get: jest.fn().mockResolvedValue({ data: {} }), + post: jest.fn().mockResolvedValue({ data: {} }), + put: jest.fn().mockResolvedValue({ data: {} }), + delete: jest.fn().mockResolvedValue({ data: 
{} }), + }), + get: jest.fn().mockResolvedValue({ data: {} }), + post: jest.fn().mockResolvedValue({ data: {} }), + put: jest.fn().mockResolvedValue({ data: {} }), + delete: jest.fn().mockResolvedValue({ data: {} }), + reset: jest.fn().mockImplementation(function () { + this.get.mockClear(); + this.post.mockClear(); + this.put.mockClear(); + this.delete.mockClear(); + this.create.mockClear(); + }), +}; + +jest.mock('axios', () => mockAxios); +jest.mock('fs'); +jest.mock('~/utils', () => ({ + logAxiosError: jest.fn(), +})); +jest.mock('~/config', () => ({ + logger: { + error: jest.fn(), + }, + createAxiosInstance: () => mockAxios, +})); +jest.mock('~/server/services/Tools/credentials', () => ({ + loadAuthValues: jest.fn(), +})); + +const { uploadDocumentToMistral, uploadMistralOCR, getSignedUrl, performOCR } = require('./crud'); + +describe('MistralOCR Service', () => { + afterEach(() => { + mockAxios.reset(); + jest.clearAllMocks(); + }); + + describe('uploadDocumentToMistral', () => { + beforeEach(() => { + // Create a more complete mock for file streams that FormData can work with + const mockReadStream = { + on: jest.fn().mockImplementation(function (event, handler) { + // Simulate immediate 'end' event to make FormData complete processing + if (event === 'end') { + handler(); + } + return this; + }), + pipe: jest.fn().mockImplementation(function () { + return this; + }), + pause: jest.fn(), + resume: jest.fn(), + emit: jest.fn(), + once: jest.fn(), + destroy: jest.fn(), + }; + + fs.createReadStream = jest.fn().mockReturnValue(mockReadStream); + + // Mock FormData's append to avoid actual stream processing + jest.mock('form-data', () => { + const mockFormData = function () { + return { + append: jest.fn(), + getHeaders: jest + .fn() + .mockReturnValue({ 'content-type': 'multipart/form-data; boundary=---boundary' }), + getBuffer: jest.fn().mockReturnValue(Buffer.from('mock-form-data')), + getLength: jest.fn().mockReturnValue(100), + }; + }; + return mockFormData; + }); + }); + + it('should upload a document to Mistral API using file streaming', async () => { + const mockResponse = { data: { id: 'file-123', purpose: 'ocr' } }; + mockAxios.post.mockResolvedValueOnce(mockResponse); + + const result = await uploadDocumentToMistral({ + filePath: '/path/to/test.pdf', + fileName: 'test.pdf', + apiKey: 'test-api-key', + }); + + // Check that createReadStream was called with the correct file path + expect(fs.createReadStream).toHaveBeenCalledWith('/path/to/test.pdf'); + + // Since we're mocking FormData, we'll just check that axios was called correctly + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://api.mistral.ai/v1/files', + expect.anything(), + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: 'Bearer test-api-key', + }), + maxBodyLength: Infinity, + maxContentLength: Infinity, + }), + ); + expect(result).toEqual(mockResponse.data); + }); + + it('should handle errors during document upload', async () => { + const errorMessage = 'API error'; + mockAxios.post.mockRejectedValueOnce(new Error(errorMessage)); + + await expect( + uploadDocumentToMistral({ + filePath: '/path/to/test.pdf', + fileName: 'test.pdf', + apiKey: 'test-api-key', + }), + ).rejects.toThrow(); + + const { logger } = require('~/config'); + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining('Error uploading document to Mistral:'), + expect.any(String), + ); + }); + }); + + describe('getSignedUrl', () => { + it('should fetch signed URL from Mistral API', async () => { + 
const mockResponse = { data: { url: 'https://document-url.com' } }; + mockAxios.get.mockResolvedValueOnce(mockResponse); + + const result = await getSignedUrl({ + fileId: 'file-123', + apiKey: 'test-api-key', + }); + + expect(mockAxios.get).toHaveBeenCalledWith( + 'https://api.mistral.ai/v1/files/file-123/url?expiry=24', + { + headers: { + Authorization: 'Bearer test-api-key', + }, + }, + ); + expect(result).toEqual(mockResponse.data); + }); + + it('should handle errors when fetching signed URL', async () => { + const errorMessage = 'API error'; + mockAxios.get.mockRejectedValueOnce(new Error(errorMessage)); + + await expect( + getSignedUrl({ + fileId: 'file-123', + apiKey: 'test-api-key', + }), + ).rejects.toThrow(); + + const { logger } = require('~/config'); + expect(logger.error).toHaveBeenCalledWith('Error fetching signed URL:', errorMessage); + }); + }); + + describe('performOCR', () => { + it('should perform OCR using Mistral API', async () => { + const mockResponse = { + data: { + pages: [{ markdown: 'Page 1 content' }, { markdown: 'Page 2 content' }], + }, + }; + mockAxios.post.mockResolvedValueOnce(mockResponse); + + const result = await performOCR({ + apiKey: 'test-api-key', + documentUrl: 'https://document-url.com', + model: 'mistral-ocr-latest', + }); + + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://api.mistral.ai/v1/ocr', + { + model: 'mistral-ocr-latest', + include_image_base64: false, + document: { + type: 'document_url', + document_url: 'https://document-url.com', + }, + }, + { + headers: { + 'Content-Type': 'application/json', + Authorization: 'Bearer test-api-key', + }, + }, + ); + expect(result).toEqual(mockResponse.data); + }); + + it('should handle errors during OCR processing', async () => { + const errorMessage = 'OCR processing error'; + mockAxios.post.mockRejectedValueOnce(new Error(errorMessage)); + + await expect( + performOCR({ + apiKey: 'test-api-key', + documentUrl: 'https://document-url.com', + }), + ).rejects.toThrow(); + + const { logger } = require('~/config'); + expect(logger.error).toHaveBeenCalledWith('Error performing OCR:', errorMessage); + }); + }); + + describe('uploadMistralOCR', () => { + beforeEach(() => { + const mockReadStream = { + on: jest.fn().mockImplementation(function (event, handler) { + if (event === 'end') { + handler(); + } + return this; + }), + pipe: jest.fn().mockImplementation(function () { + return this; + }), + pause: jest.fn(), + resume: jest.fn(), + emit: jest.fn(), + once: jest.fn(), + destroy: jest.fn(), + }; + + fs.createReadStream = jest.fn().mockReturnValue(mockReadStream); + }); + + it('should process OCR for a file with standard configuration', async () => { + // Setup mocks + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + loadAuthValues.mockResolvedValue({ + OCR_API_KEY: 'test-api-key', + OCR_BASEURL: 'https://api.mistral.ai/v1', + }); + + // Mock file upload response + mockAxios.post.mockResolvedValueOnce({ + data: { id: 'file-123', purpose: 'ocr' }, + }); + + // Mock signed URL response + mockAxios.get.mockResolvedValueOnce({ + data: { url: 'https://signed-url.com' }, + }); + + // Mock OCR response with text and images + mockAxios.post.mockResolvedValueOnce({ + data: { + pages: [ + { + markdown: 'Page 1 content', + images: [{ image_base64: 'base64image1' }], + }, + { + markdown: 'Page 2 content', + images: [{ image_base64: 'base64image2' }], + }, + ], + }, + }); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + // Use environment variable syntax to 
ensure loadAuthValues is called + apiKey: '${OCR_API_KEY}', + baseURL: '${OCR_BASEURL}', + mistralModel: 'mistral-medium', + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'document.pdf', + }; + + const result = await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + expect(loadAuthValues).toHaveBeenCalledWith({ + userId: 'user123', + authFields: ['OCR_BASEURL', 'OCR_API_KEY'], + optional: expect.any(Set), + }); + + // Verify OCR result + expect(result).toEqual({ + filename: 'document.pdf', + bytes: expect.any(Number), + filepath: 'mistral_ocr', + text: expect.stringContaining('# PAGE 1'), + images: ['base64image1', 'base64image2'], + }); + }); + + it('should process variable references in configuration', async () => { + // Setup mocks with environment variables + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + loadAuthValues.mockResolvedValue({ + CUSTOM_API_KEY: 'custom-api-key', + CUSTOM_BASEURL: 'https://custom-api.mistral.ai/v1', + }); + + // Mock API responses + mockAxios.post.mockResolvedValueOnce({ + data: { id: 'file-123', purpose: 'ocr' }, + }); + mockAxios.get.mockResolvedValueOnce({ + data: { url: 'https://signed-url.com' }, + }); + mockAxios.post.mockResolvedValueOnce({ + data: { + pages: [{ markdown: 'Content from custom API' }], + }, + }); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + apiKey: '${CUSTOM_API_KEY}', + baseURL: '${CUSTOM_BASEURL}', + mistralModel: '${CUSTOM_MODEL}', + }, + }, + }, + }; + + // Set environment variable for model + process.env.CUSTOM_MODEL = 'mistral-large'; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'document.pdf', + }; + + const result = await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + // Verify that custom environment variables were extracted and used + expect(loadAuthValues).toHaveBeenCalledWith({ + userId: 'user123', + authFields: ['CUSTOM_BASEURL', 'CUSTOM_API_KEY'], + optional: expect.any(Set), + }); + + // Check that mistral-large was used in the OCR API call + expect(mockAxios.post).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ + model: 'mistral-large', + }), + expect.anything(), + ); + + expect(result.text).toEqual('Content from custom API\n\n'); + }); + + it('should fall back to default values when variables are not properly formatted', async () => { + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + loadAuthValues.mockResolvedValue({ + OCR_API_KEY: 'default-api-key', + OCR_BASEURL: undefined, // Testing optional parameter + }); + + mockAxios.post.mockResolvedValueOnce({ + data: { id: 'file-123', purpose: 'ocr' }, + }); + mockAxios.get.mockResolvedValueOnce({ + data: { url: 'https://signed-url.com' }, + }); + mockAxios.post.mockResolvedValueOnce({ + data: { + pages: [{ markdown: 'Default API result' }], + }, + }); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + // Use environment variable syntax to ensure loadAuthValues is called + apiKey: '${INVALID_FORMAT}', // Using valid env var format but with an invalid name + baseURL: '${OCR_BASEURL}', // Using valid env var format + mistralModel: 'mistral-ocr-latest', // Plain string value + }, + }, + }, + }; + + const file = { + path: 
'/tmp/upload/file.pdf', + originalname: 'document.pdf', + }; + + await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + // Should use the default values + expect(loadAuthValues).toHaveBeenCalledWith({ + userId: 'user123', + authFields: ['OCR_BASEURL', 'INVALID_FORMAT'], + optional: expect.any(Set), + }); + + // Should use the default model when not using environment variable format + expect(mockAxios.post).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ + model: 'mistral-ocr-latest', + }), + expect.anything(), + ); + }); + + it('should handle API errors during OCR process', async () => { + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + loadAuthValues.mockResolvedValue({ + OCR_API_KEY: 'test-api-key', + }); + + // Mock file upload to fail + mockAxios.post.mockRejectedValueOnce(new Error('Upload failed')); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + apiKey: 'OCR_API_KEY', + baseURL: 'OCR_BASEURL', + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'document.pdf', + }; + + await expect( + uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }), + ).rejects.toThrow('Error uploading document to Mistral OCR API'); + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + const { logAxiosError } = require('~/utils'); + expect(logAxiosError).toHaveBeenCalled(); + }); + + it('should handle single page documents without page numbering', async () => { + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + loadAuthValues.mockResolvedValue({ + OCR_API_KEY: 'test-api-key', + OCR_BASEURL: 'https://api.mistral.ai/v1', // Make sure this is included + }); + + // Clear all previous mocks + mockAxios.post.mockClear(); + mockAxios.get.mockClear(); + + // 1. First mock: File upload response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), + ); + + // 2. Second mock: Signed URL response + mockAxios.get.mockImplementationOnce(() => + Promise.resolve({ data: { url: 'https://signed-url.com' } }), + ); + + // 3. Third mock: OCR response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ + data: { + pages: [{ markdown: 'Single page content' }], + }, + }), + ); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + apiKey: 'OCR_API_KEY', + baseURL: 'OCR_BASEURL', + mistralModel: 'mistral-ocr-latest', + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'single-page.pdf', + }; + + const result = await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + // Verify that single page documents don't include page numbering + expect(result.text).not.toContain('# PAGE'); + expect(result.text).toEqual('Single page content\n\n'); + }); + + it('should use literal values in configuration when provided directly', async () => { + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + // We'll still mock this but it should not be used for literal values + loadAuthValues.mockResolvedValue({}); + + // Clear all previous mocks + mockAxios.post.mockClear(); + mockAxios.get.mockClear(); + + // 1. 
First mock: File upload response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), + ); + + // 2. Second mock: Signed URL response + mockAxios.get.mockImplementationOnce(() => + Promise.resolve({ data: { url: 'https://signed-url.com' } }), + ); + + // 3. Third mock: OCR response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ + data: { + pages: [{ markdown: 'Processed with literal config values' }], + }, + }), + ); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + // Direct values that should be used as-is, without variable substitution + apiKey: 'actual-api-key-value', + baseURL: 'https://direct-api-url.mistral.ai/v1', + mistralModel: 'mistral-direct-model', + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'direct-values.pdf', + }; + + const result = await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + // Verify the correct URL was used with the direct baseURL value + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://direct-api-url.mistral.ai/v1/files', + expect.any(Object), + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: 'Bearer actual-api-key-value', + }), + }), + ); + + // Check the OCR call was made with the direct model value + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://direct-api-url.mistral.ai/v1/ocr', + expect.objectContaining({ + model: 'mistral-direct-model', + }), + expect.any(Object), + ); + + // Verify the result + expect(result.text).toEqual('Processed with literal config values\n\n'); + + // Verify loadAuthValues was never called since we used direct values + expect(loadAuthValues).not.toHaveBeenCalled(); + }); + + it('should handle empty configuration values and use defaults', async () => { + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + // Set up the mock values to be returned by loadAuthValues + loadAuthValues.mockResolvedValue({ + OCR_API_KEY: 'default-from-env-key', + OCR_BASEURL: 'https://default-from-env.mistral.ai/v1', + }); + + // Clear all previous mocks + mockAxios.post.mockClear(); + mockAxios.get.mockClear(); + + // 1. First mock: File upload response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), + ); + + // 2. Second mock: Signed URL response + mockAxios.get.mockImplementationOnce(() => + Promise.resolve({ data: { url: 'https://signed-url.com' } }), + ); + + // 3. 
Third mock: OCR response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ + data: { + pages: [{ markdown: 'Content from default configuration' }], + }, + }), + ); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + // Empty string values - should fall back to defaults + apiKey: '', + baseURL: '', + mistralModel: '', + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'empty-config.pdf', + }; + + const result = await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + // Verify loadAuthValues was called with the default variable names + expect(loadAuthValues).toHaveBeenCalledWith({ + userId: 'user123', + authFields: ['OCR_BASEURL', 'OCR_API_KEY'], + optional: expect.any(Set), + }); + + // Verify the API calls used the default values from loadAuthValues + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://default-from-env.mistral.ai/v1/files', + expect.any(Object), + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: 'Bearer default-from-env-key', + }), + }), + ); + + // Verify the OCR model defaulted to mistral-ocr-latest + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://default-from-env.mistral.ai/v1/ocr', + expect.objectContaining({ + model: 'mistral-ocr-latest', + }), + expect.any(Object), + ); + + // Check result + expect(result.text).toEqual('Content from default configuration\n\n'); + }); + }); +}); diff --git a/api/server/services/Files/MistralOCR/index.js b/api/server/services/Files/MistralOCR/index.js new file mode 100644 index 0000000000..a6223d1ee5 --- /dev/null +++ b/api/server/services/Files/MistralOCR/index.js @@ -0,0 +1,5 @@ +const crud = require('./crud'); + +module.exports = { + ...crud, +}; diff --git a/api/server/services/Files/S3/crud.js b/api/server/services/Files/S3/crud.js new file mode 100644 index 0000000000..06f9116b69 --- /dev/null +++ b/api/server/services/Files/S3/crud.js @@ -0,0 +1,163 @@ +const fs = require('fs'); +const path = require('path'); +const fetch = require('node-fetch'); +const { PutObjectCommand, GetObjectCommand, DeleteObjectCommand } = require('@aws-sdk/client-s3'); +const { getSignedUrl } = require('@aws-sdk/s3-request-presigner'); +const { initializeS3 } = require('./initialize'); +const { logger } = require('~/config'); + +const bucketName = process.env.AWS_BUCKET_NAME; +const defaultBasePath = 'images'; + +/** + * Constructs the S3 key based on the base path, user ID, and file name. + */ +const getS3Key = (basePath, userId, fileName) => `${basePath}/${userId}/${fileName}`; + +/** + * Uploads a buffer to S3 and returns a signed URL. + * + * @param {Object} params + * @param {string} params.userId - The user's unique identifier. + * @param {Buffer} params.buffer - The buffer containing file data. + * @param {string} params.fileName - The file name to use in S3. + * @param {string} [params.basePath='images'] - The base path in the bucket. + * @returns {Promise} Signed URL of the uploaded file. 
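+ *
+ * @example
+ * // Sketch; the bucket comes from AWS_BUCKET_NAME, and the values here are illustrative:
+ * const url = await saveBufferToS3({
+ *   userId: 'user123',
+ *   buffer: Buffer.from('hello world'),
+ *   fileName: 'hello.txt',
+ * });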
+ */ +async function saveBufferToS3({ userId, buffer, fileName, basePath = defaultBasePath }) { + const key = getS3Key(basePath, userId, fileName); + const params = { Bucket: bucketName, Key: key, Body: buffer }; + + try { + const s3 = initializeS3(); + await s3.send(new PutObjectCommand(params)); + return await getS3URL({ userId, fileName, basePath }); + } catch (error) { + logger.error('[saveBufferToS3] Error uploading buffer to S3:', error.message); + throw error; + } +} + +/** + * Retrieves a signed URL for a file stored in S3. + * + * @param {Object} params + * @param {string} params.userId - The user's unique identifier. + * @param {string} params.fileName - The file name in S3. + * @param {string} [params.basePath='images'] - The base path in the bucket. + * @returns {Promise} A signed URL valid for 24 hours. + */ +async function getS3URL({ userId, fileName, basePath = defaultBasePath }) { + const key = getS3Key(basePath, userId, fileName); + const params = { Bucket: bucketName, Key: key }; + + try { + const s3 = initializeS3(); + return await getSignedUrl(s3, new GetObjectCommand(params), { expiresIn: 86400 }); + } catch (error) { + logger.error('[getS3URL] Error getting signed URL from S3:', error.message); + throw error; + } +} + +/** + * Saves a file from a given URL to S3. + * + * @param {Object} params + * @param {string} params.userId - The user's unique identifier. + * @param {string} params.URL - The source URL of the file. + * @param {string} params.fileName - The file name to use in S3. + * @param {string} [params.basePath='images'] - The base path in the bucket. + * @returns {Promise} Signed URL of the uploaded file. + */ +async function saveURLToS3({ userId, URL, fileName, basePath = defaultBasePath }) { + try { + const response = await fetch(URL); + const buffer = await response.buffer(); + // Optionally you can call getBufferMetadata(buffer) if needed. + return await saveBufferToS3({ userId, buffer, fileName, basePath }); + } catch (error) { + logger.error('[saveURLToS3] Error uploading file from URL to S3:', error.message); + throw error; + } +} + +/** + * Deletes a file from S3. + * + * @param {Object} params + * @param {string} params.userId - The user's unique identifier. + * @param {string} params.fileName - The file name in S3. + * @param {string} [params.basePath='images'] - The base path in the bucket. + * @returns {Promise} + */ +async function deleteFileFromS3({ userId, fileName, basePath = defaultBasePath }) { + const key = getS3Key(basePath, userId, fileName); + const params = { Bucket: bucketName, Key: key }; + + try { + const s3 = initializeS3(); + await s3.send(new DeleteObjectCommand(params)); + logger.debug('[deleteFileFromS3] File deleted successfully from S3'); + } catch (error) { + logger.error('[deleteFileFromS3] Error deleting file from S3:', error.message); + // If the file is not found, we can safely return (AWS SDK v3 surfaces the error code on error.name). + if (error.name === 'NoSuchKey') { + return; + } + throw error; + } +} + +/** + * Uploads a local file to S3. + * + * @param {Object} params + * @param {import('express').Request} params.req - The Express request (must include user). + * @param {Express.Multer.File} params.file - The file object from Multer. + * @param {string} params.file_id - Unique file identifier. + * @param {string} [params.basePath='images'] - The base path in the bucket.
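+ * @example
+ * // Sketch; `file` is assumed to be the Multer file object from the request:
+ * const { filepath, bytes } = await uploadFileToS3({ req, file, file_id: 'abc-123' });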
+ * @returns {Promise<{ filepath: string, bytes: number }>} + */ +async function uploadFileToS3({ req, file, file_id, basePath = defaultBasePath }) { + try { + const inputFilePath = file.path; + const inputBuffer = await fs.promises.readFile(inputFilePath); + const bytes = Buffer.byteLength(inputBuffer); + const userId = req.user.id; + const fileName = `${file_id}__${path.basename(inputFilePath)}`; + const fileURL = await saveBufferToS3({ userId, buffer: inputBuffer, fileName, basePath }); + await fs.promises.unlink(inputFilePath); + return { filepath: fileURL, bytes }; + } catch (error) { + logger.error('[uploadFileToS3] Error uploading file to S3:', error.message); + throw error; + } +} + +/** + * Retrieves a readable stream for a file stored in S3. + * + * @param {string} filePath - The S3 key of the file. + * @returns {Promise} + */ +async function getS3FileStream(filePath) { + const params = { Bucket: bucketName, Key: filePath }; + try { + const s3 = initializeS3(); + const data = await s3.send(new GetObjectCommand(params)); + return data.Body; // Returns a Node.js ReadableStream. + } catch (error) { + logger.error('[getS3FileStream] Error retrieving S3 file stream:', error.message); + throw error; + } +} + +module.exports = { + saveBufferToS3, + saveURLToS3, + getS3URL, + deleteFileFromS3, + uploadFileToS3, + getS3FileStream, +}; diff --git a/api/server/services/Files/S3/images.js b/api/server/services/Files/S3/images.js new file mode 100644 index 0000000000..378212cb5e --- /dev/null +++ b/api/server/services/Files/S3/images.js @@ -0,0 +1,118 @@ +const fs = require('fs'); +const path = require('path'); +const sharp = require('sharp'); +const { resizeImageBuffer } = require('../images/resize'); +const { updateUser } = require('~/models/userMethods'); +const { saveBufferToS3 } = require('./crud'); +const { updateFile } = require('~/models/File'); +const { logger } = require('~/config'); + +const defaultBasePath = 'images'; + +/** + * Resizes, converts, and uploads an image file to S3. + * + * @param {Object} params + * @param {import('express').Request} params.req - Express request (expects user and app.locals.imageOutputType). + * @param {Express.Multer.File} params.file - File object from Multer. + * @param {string} params.file_id - Unique file identifier. + * @param {any} params.endpoint - Endpoint identifier used in image processing. + * @param {string} [params.resolution='high'] - Desired image resolution. + * @param {string} [params.basePath='images'] - Base path in the bucket. 
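+ * @example
+ * // Usage sketch; assumes req.app.locals.imageOutputType is set (e.g. 'webp'):
+ * // const { filepath, width, height } = await uploadImageToS3({ req, file, file_id, endpoint });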
+ * @returns {Promise<{ filepath: string, bytes: number, width: number, height: number }>} + */ +async function uploadImageToS3({ + req, + file, + file_id, + endpoint, + resolution = 'high', + basePath = defaultBasePath, +}) { + try { + const inputFilePath = file.path; + const inputBuffer = await fs.promises.readFile(inputFilePath); + const { + buffer: resizedBuffer, + width, + height, + } = await resizeImageBuffer(inputBuffer, resolution, endpoint); + const extension = path.extname(inputFilePath); + const userId = req.user.id; + + let processedBuffer; + let fileName = `${file_id}__${path.basename(inputFilePath)}`; + const targetExtension = `.${req.app.locals.imageOutputType}`; + + if (extension.toLowerCase() === targetExtension) { + processedBuffer = resizedBuffer; + } else { + processedBuffer = await sharp(resizedBuffer) + .toFormat(req.app.locals.imageOutputType) + .toBuffer(); + fileName = fileName.replace(new RegExp(path.extname(fileName) + '$'), targetExtension); + if (!path.extname(fileName)) { + fileName += targetExtension; + } + } + + const downloadURL = await saveBufferToS3({ + userId, + buffer: processedBuffer, + fileName, + basePath, + }); + await fs.promises.unlink(inputFilePath); + const bytes = Buffer.byteLength(processedBuffer); + return { filepath: downloadURL, bytes, width, height }; + } catch (error) { + logger.error('[uploadImageToS3] Error uploading image to S3:', error.message); + throw error; + } +} + +/** + * Updates a file record and returns its signed URL. + * + * @param {import('express').Request} req - Express request. + * @param {Object} file - File metadata. + * @returns {Promise<[Promise, string]>} + */ +async function prepareImageURLS3(req, file) { + try { + const updatePromise = updateFile({ file_id: file.file_id }); + return Promise.all([updatePromise, file.filepath]); + } catch (error) { + logger.error('[prepareImageURLS3] Error preparing image URL:', error.message); + throw error; + } +} + +/** + * Processes a user's avatar image by uploading it to S3 and updating the user's avatar URL if required. + * + * @param {Object} params + * @param {Buffer} params.buffer - Avatar image buffer. + * @param {string} params.userId - User's unique identifier. + * @param {string} params.manual - 'true' or 'false' flag for manual update. + * @param {string} [params.basePath='images'] - Base path in the bucket. + * @returns {Promise} Signed URL of the uploaded avatar. 
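+ * @example
+ * // Usage sketch; note `manual` is the string 'true' or 'false', not a boolean:
+ * // const url = await processS3Avatar({ buffer, userId: 'user123', manual: 'true' });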
+ */ +async function processS3Avatar({ buffer, userId, manual, basePath = defaultBasePath }) { + try { + const downloadURL = await saveBufferToS3({ userId, buffer, fileName: 'avatar.png', basePath }); + if (manual === 'true') { + await updateUser(userId, { avatar: downloadURL }); + } + return downloadURL; + } catch (error) { + logger.error('[processS3Avatar] Error processing S3 avatar:', error.message); + throw error; + } +} + +module.exports = { + uploadImageToS3, + prepareImageURLS3, + processS3Avatar, +}; diff --git a/api/server/services/Files/S3/index.js b/api/server/services/Files/S3/index.js new file mode 100644 index 0000000000..27ad97a852 --- /dev/null +++ b/api/server/services/Files/S3/index.js @@ -0,0 +1,9 @@ +const crud = require('./crud'); +const images = require('./images'); +const initialize = require('./initialize'); + +module.exports = { + ...crud, + ...images, + ...initialize, +}; diff --git a/api/server/services/Files/S3/initialize.js b/api/server/services/Files/S3/initialize.js new file mode 100644 index 0000000000..2daec25235 --- /dev/null +++ b/api/server/services/Files/S3/initialize.js @@ -0,0 +1,53 @@ +const { S3Client } = require('@aws-sdk/client-s3'); +const { logger } = require('~/config'); + +let s3 = null; + +/** + * Initializes and returns an instance of the AWS S3 client. + * + * If AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are provided, they will be used. + * Otherwise, the AWS SDK's default credentials chain (including IRSA) is used. + * + * If AWS_ENDPOINT_URL is provided, it will be used as the endpoint. + * + * @returns {S3Client|null} An instance of S3Client if the region is provided; otherwise, null. + */ +const initializeS3 = () => { + if (s3) { + return s3; + } + + const region = process.env.AWS_REGION; + if (!region) { + logger.error('[initializeS3] AWS_REGION is not set. Cannot initialize S3.'); + return null; + } + + // Read the custom endpoint if provided. + const endpoint = process.env.AWS_ENDPOINT_URL; + const accessKeyId = process.env.AWS_ACCESS_KEY_ID; + const secretAccessKey = process.env.AWS_SECRET_ACCESS_KEY; + + const config = { + region, + // Conditionally add the endpoint if it is provided + ...(endpoint ? { endpoint } : {}), + }; + + if (accessKeyId && secretAccessKey) { + s3 = new S3Client({ + ...config, + credentials: { accessKeyId, secretAccessKey }, + }); + logger.info('[initializeS3] S3 initialized with provided credentials.'); + } else { + // When using IRSA, credentials are automatically provided via the IAM Role attached to the ServiceAccount. 
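+ // For example, on EKS the default provider chain resolves the AWS_ROLE_ARN and
+ // AWS_WEB_IDENTITY_TOKEN_FILE variables injected for the pod's service account,
+ // so no static keys are needed.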
+ s3 = new S3Client(config); + logger.info('[initializeS3] S3 initialized using default credentials (IRSA).'); + } + + return s3; +}; + +module.exports = { initializeS3 }; diff --git a/api/server/services/Files/VectorDB/crud.js b/api/server/services/Files/VectorDB/crud.js index d290eea4b1..37a1e81487 100644 --- a/api/server/services/Files/VectorDB/crud.js +++ b/api/server/services/Files/VectorDB/crud.js @@ -37,7 +37,14 @@ const deleteVectors = async (req, file) => { error, message: 'Error deleting vectors', }); - throw new Error(error.message || 'An error occurred during file deletion.'); + if ( + error.response && + error.response.status !== 404 && + (error.response.status < 200 || error.response.status >= 300) + ) { + logger.warn('Error deleting vectors, file will not be deleted'); + throw new Error(error.message || 'An error occurred during file deletion.'); + } } }; diff --git a/api/server/services/Files/images/encode.js b/api/server/services/Files/images/encode.js index 94153ffc64..707632fb6a 100644 --- a/api/server/services/Files/images/encode.js +++ b/api/server/services/Files/images/encode.js @@ -49,6 +49,7 @@ async function encodeAndFormat(req, files, endpoint, mode) { const promises = []; const encodingMethods = {}; const result = { + text: '', files: [], image_urls: [], }; @@ -59,6 +60,9 @@ async function encodeAndFormat(req, files, endpoint, mode) { for (let file of files) { const source = file.source ?? FileSources.local; + if (source === FileSources.text && file.text) { + result.text += `${!result.text ? 'Attached document(s):\n```md' : '\n\n---\n\n'}# "${file.filename}"\n${file.text}\n`; + } if (!file.height) { promises.push([file, null]); @@ -85,6 +89,10 @@ async function encodeAndFormat(req, files, endpoint, mode) { promises.push(preparePayload(req, file)); } + if (result.text) { + result.text += '\n```'; + } + const detail = req.body.imageDetail ?? ImageDetail.auto; /** @type {Array<[MongoFile, string]>} */ diff --git a/api/server/services/Files/process.js b/api/server/services/Files/process.js index a5d9c8c1e0..78a4976e2f 100644 --- a/api/server/services/Files/process.js +++ b/api/server/services/Files/process.js @@ -28,8 +28,8 @@ const { addResourceFileId, deleteResourceFileId } = require('~/server/controller const { addAgentResourceFile, removeAgentResourceFiles } = require('~/models/Agent'); const { getOpenAIClient } = require('~/server/controllers/assistants/helpers'); const { createFile, updateFileUsage, deleteFiles } = require('~/models/File'); -const { getEndpointsConfig } = require('~/server/services/Config'); -const { loadAuthValues } = require('~/app/clients/tools/util'); +const { loadAuthValues } = require('~/server/services/Tools/credentials'); +const { checkCapability } = require('~/server/services/Config'); const { LB_QueueAsyncCall } = require('~/server/utils/queue'); const { getStrategyFunctions } = require('./strategies'); const { determineFileType } = require('~/server/utils'); @@ -162,7 +162,6 @@ const processDeleteRequest = async ({ req, files }) => { for (const file of files) { const source = file.source ?? 
FileSources.local; - if (req.body.agent_id && req.body.tool_resource) { agentFiles.push({ tool_resource: req.body.tool_resource, @@ -170,6 +169,11 @@ const processDeleteRequest = async ({ req, files }) => { }); } + if (source === FileSources.text) { + resolvedFileIds.push(file.file_id); + continue; + } + if (checkOpenAIStorage(source) && !client[source]) { await initializeClients(); } @@ -347,8 +351,8 @@ const uploadImageBuffer = async ({ req, context, metadata = {}, resize = true }) req.app.locals.imageOutputType }`; } - - const filepath = await saveBuffer({ userId: req.user.id, fileName: filename, buffer }); + const fileName = `${file_id}-${filename}`; + const filepath = await saveBuffer({ userId: req.user.id, fileName, buffer }); return await createFile( { user: req.user.id, @@ -453,17 +457,6 @@ const processFileUpload = async ({ req, res, metadata }) => { res.status(200).json({ message: 'File uploaded and processed successfully', ...result }); }; -/** - * @param {ServerRequest} req - * @param {AgentCapabilities} capability - * @returns {Promise} - */ -const checkCapability = async (req, capability) => { - const endpointsConfig = await getEndpointsConfig(req); - const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []; - return capabilities.includes(capability); -}; - /** * Applies the current strategy for file uploads. * Saves file metadata to the database with an expiry TTL. @@ -521,6 +514,52 @@ const processAgentFileUpload = async ({ req, res, metadata }) => { if (!isFileSearchEnabled) { throw new Error('File search is not enabled for Agents'); } + } else if (tool_resource === EToolResources.ocr) { + const isOCREnabled = await checkCapability(req, AgentCapabilities.ocr); + if (!isOCREnabled) { + throw new Error('OCR capability is not enabled for Agents'); + } + + const { handleFileUpload } = getStrategyFunctions( + req.app.locals?.ocr?.strategy ?? FileSources.mistral_ocr, + ); + const { file_id, temp_file_id } = metadata; + + const { + text, + bytes, + // TODO: OCR images support? + images, + filename, + filepath: ocrFileURL, + } = await handleFileUpload({ req, file, file_id, entity_id: agent_id }); + + const fileInfo = removeNullishValues({ + text, + bytes, + file_id, + temp_file_id, + user: req.user.id, + type: file.mimetype, + filepath: ocrFileURL, + source: FileSources.text, + filename: filename ?? file.originalname, + model: messageAttachment ? undefined : req.body.model, + context: messageAttachment ? FileContext.message_attachment : FileContext.agents, + }); + + if (!messageAttachment && tool_resource) { + await addAgentResourceFile({ + req, + file_id, + agent_id, + tool_resource, + }); + } + const result = await createFile(fileInfo, true); + return res + .status(200) + .json({ message: 'Agent file uploaded and processed successfully', ...result }); } const source = @@ -801,8 +840,7 @@ async function saveBase64Image( { req, file_id: _file_id, filename: _filename, endpoint, context, resolution = 'high' }, ) { const file_id = _file_id ?? 
v4(); - - let filename = _filename; + let filename = `${file_id}-${_filename}`; const { buffer: inputBuffer, type } = base64ToBuffer(url); if (!path.extname(_filename)) { const extension = mime.getExtension(type); diff --git a/api/server/services/Files/strategies.js b/api/server/services/Files/strategies.js index ddfdd57469..d05ea03728 100644 --- a/api/server/services/Files/strategies.js +++ b/api/server/services/Files/strategies.js @@ -21,9 +21,32 @@ const { processLocalAvatar, getLocalFileStream, } = require('./Local'); +const { + getS3URL, + saveURLToS3, + saveBufferToS3, + getS3FileStream, + uploadImageToS3, + prepareImageURLS3, + deleteFileFromS3, + processS3Avatar, + uploadFileToS3, +} = require('./S3'); +const { + saveBufferToAzure, + saveURLToAzure, + getAzureURL, + deleteFileFromAzure, + uploadFileToAzure, + getAzureFileStream, + uploadImageToAzure, + prepareAzureImageURL, + processAzureAvatar, +} = require('./Azure'); const { uploadOpenAIFile, deleteOpenAIFile, getOpenAIFileStream } = require('./OpenAI'); const { getCodeOutputDownloadStream, uploadCodeEnvFile } = require('./Code'); const { uploadVectors, deleteVectors } = require('./VectorDB'); +const { uploadMistralOCR } = require('./MistralOCR'); /** * Firebase Storage Strategy Functions @@ -57,6 +80,38 @@ const localStrategy = () => ({ getDownloadStream: getLocalFileStream, }); +/** + * S3 Storage Strategy Functions + * + * */ +const s3Strategy = () => ({ + handleFileUpload: uploadFileToS3, + saveURL: saveURLToS3, + getFileURL: getS3URL, + deleteFile: deleteFileFromS3, + saveBuffer: saveBufferToS3, + prepareImagePayload: prepareImageURLS3, + processAvatar: processS3Avatar, + handleImageUpload: uploadImageToS3, + getDownloadStream: getS3FileStream, +}); + +/** + * Azure Blob Storage Strategy Functions + * + * */ +const azureStrategy = () => ({ + handleFileUpload: uploadFileToAzure, + saveURL: saveURLToAzure, + getFileURL: getAzureURL, + deleteFile: deleteFileFromAzure, + saveBuffer: saveBufferToAzure, + prepareImagePayload: prepareAzureImageURL, + processAvatar: processAzureAvatar, + handleImageUpload: uploadImageToAzure, + getDownloadStream: getAzureFileStream, +}); + /** * VectorDB Storage Strategy Functions * @@ -127,6 +182,26 @@ const codeOutputStrategy = () => ({ getDownloadStream: getCodeOutputDownloadStream, }); +const mistralOCRStrategy = () => ({ + /** @type {typeof saveFileFromURL | null} */ + saveURL: null, + /** @type {typeof getLocalFileURL | null} */ + getFileURL: null, + /** @type {typeof saveLocalBuffer | null} */ + saveBuffer: null, + /** @type {typeof processLocalAvatar | null} */ + processAvatar: null, + /** @type {typeof uploadLocalImage | null} */ + handleImageUpload: null, + /** @type {typeof prepareImagesLocal | null} */ + prepareImagePayload: null, + /** @type {typeof deleteLocalFile | null} */ + deleteFile: null, + /** @type {typeof getLocalFileStream | null} */ + getDownloadStream: null, + handleFileUpload: uploadMistralOCR, +}); + // Strategy Selector const getStrategyFunctions = (fileSource) => { if (fileSource === FileSources.firebase) { @@ -136,11 +211,15 @@ const getStrategyFunctions = (fileSource) => { } else if (fileSource === FileSources.openai) { return openAIStrategy(); } else if (fileSource === FileSources.azure) { - return openAIStrategy(); + return azureStrategy(); } else if (fileSource === FileSources.vectordb) { return vectorStrategy(); + } else if (fileSource === FileSources.s3) { + return s3Strategy(); } else if (fileSource === FileSources.execute_code) { return codeOutputStrategy(); + 
} else if (fileSource === FileSources.mistral_ocr) { + return mistralOCRStrategy(); } else { throw new Error('Invalid file source'); } diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index f934f9d519..9b8ce30875 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -37,11 +37,19 @@ async function createMCPTool({ req, toolKey, provider }) { } const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter); - /** @type {(toolInput: Object | string) => Promise} */ - const _call = async (toolInput) => { + /** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise} */ + const _call = async (toolArguments, config) => { try { const mcpManager = await getMCPManager(); - const result = await mcpManager.callTool(serverName, toolName, provider, toolInput); + const result = await mcpManager.callTool({ + serverName, + toolName, + provider, + toolArguments, + options: { + signal: config?.signal, + }, + }); if (isAssistantsEndpoint(provider) && Array.isArray(result)) { return result[0]; } diff --git a/api/server/services/ModelService.js b/api/server/services/ModelService.js index 1394a5d697..a1ccd7643b 100644 --- a/api/server/services/ModelService.js +++ b/api/server/services/ModelService.js @@ -1,9 +1,12 @@ const axios = require('axios'); +const { Providers } = require('@librechat/agents'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider'); const { inputSchema, logAxiosError, extractBaseURL, processModelData } = require('~/utils'); const { OllamaClient } = require('~/app/clients/OllamaClient'); +const { isUserProvided } = require('~/server/utils'); const getLogStores = require('~/cache/getLogStores'); +const { logger } = require('~/config'); /** * Splits a string by commas and trims each resulting value. 
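 * @example
 * // Per the spec changes below: splitAndTrim('claude-1, claude-2 ') => ['claude-1', 'claude-2']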
@@ -41,7 +44,7 @@ const fetchModels = async ({ user, apiKey, baseURL, - name = 'OpenAI', + name = EModelEndpoint.openAI, azure = false, userIdQuery = false, createTokenConfig = true, @@ -57,18 +60,25 @@ const fetchModels = async ({ return models; } - if (name && name.toLowerCase().startsWith('ollama')) { + if (name && name.toLowerCase().startsWith(Providers.OLLAMA)) { return await OllamaClient.fetchModels(baseURL); } try { const options = { - headers: { - Authorization: `Bearer ${apiKey}`, - }, + headers: {}, timeout: 5000, }; + if (name === EModelEndpoint.anthropic) { + options.headers = { + 'x-api-key': apiKey, + 'anthropic-version': process.env.ANTHROPIC_VERSION || '2023-06-01', + }; + } else { + options.headers.Authorization = `Bearer ${apiKey}`; + } + if (process.env.PROXY) { options.httpsAgent = new HttpsProxyAgent(process.env.PROXY); } @@ -128,9 +138,6 @@ const fetchOpenAIModels = async (opts, _models = []) => { // .split('/deployments')[0] // .concat(`/models?api-version=${azure.azureOpenAIApiVersion}`); // apiKey = azureOpenAIApiKey; - } else if (process.env.OPENROUTER_API_KEY) { - reverseProxyUrl = 'https://openrouter.ai/api/v1'; - apiKey = process.env.OPENROUTER_API_KEY; } if (reverseProxyUrl) { @@ -150,7 +157,7 @@ const fetchOpenAIModels = async (opts, _models = []) => { baseURL, azure: opts.azure, user: opts.user, - name: baseURL, + name: EModelEndpoint.openAI, }); } @@ -159,7 +166,7 @@ const fetchOpenAIModels = async (opts, _models = []) => { } if (baseURL === openaiBaseURL) { - const regex = /(text-davinci-003|gpt-|o\d+-)/; + const regex = /(text-davinci-003|gpt-|o\d+)/; const excludeRegex = /audio|realtime/; models = models.filter((model) => regex.test(model) && !excludeRegex.test(model)); const instructModels = models.filter((model) => model.includes('instruct')); @@ -217,7 +224,7 @@ const getOpenAIModels = async (opts) => { return models; } - if (userProvidedOpenAI && !process.env.OPENROUTER_API_KEY) { + if (userProvidedOpenAI) { return models; } @@ -233,13 +240,71 @@ const getChatGPTBrowserModels = () => { return models; }; -const getAnthropicModels = () => { +/** + * Fetches models from the Anthropic API. + * @async + * @function + * @param {object} opts - The options for fetching the models. + * @param {string} opts.user - The user ID to send to the API. + * @param {string[]} [_models=[]] - The models to use as a fallback. + */ +const fetchAnthropicModels = async (opts, _models = []) => { + let models = _models.slice() ?? 
[]; + let apiKey = process.env.ANTHROPIC_API_KEY; + const anthropicBaseURL = 'https://api.anthropic.com/v1'; + let baseURL = anthropicBaseURL; + let reverseProxyUrl = process.env.ANTHROPIC_REVERSE_PROXY; + + if (reverseProxyUrl) { + baseURL = extractBaseURL(reverseProxyUrl); + } + + if (!apiKey) { + return models; + } + + const modelsCache = getLogStores(CacheKeys.MODEL_QUERIES); + + const cachedModels = await modelsCache.get(baseURL); + if (cachedModels) { + return cachedModels; + } + + if (baseURL) { + models = await fetchModels({ + apiKey, + baseURL, + user: opts.user, + name: EModelEndpoint.anthropic, + tokenKey: EModelEndpoint.anthropic, + }); + } + + if (models.length === 0) { + return _models; + } + + await modelsCache.set(baseURL, models); + return models; +}; + +const getAnthropicModels = async (opts = {}) => { let models = defaultModels[EModelEndpoint.anthropic]; if (process.env.ANTHROPIC_MODELS) { models = splitAndTrim(process.env.ANTHROPIC_MODELS); + return models; } - return models; + if (isUserProvided(process.env.ANTHROPIC_API_KEY)) { + return models; + } + + try { + return await fetchAnthropicModels(opts, models); + } catch (error) { + logger.error('Error fetching Anthropic models:', error); + return models; + } }; const getGoogleModels = () => { diff --git a/api/server/services/ModelService.spec.js b/api/server/services/ModelService.spec.js index a383db1e3c..fb4481f840 100644 --- a/api/server/services/ModelService.spec.js +++ b/api/server/services/ModelService.spec.js @@ -161,22 +161,6 @@ describe('getOpenAIModels', () => { expect(models).toEqual(expect.arrayContaining(['openai-model', 'openai-model-2'])); }); - it('attempts to use OPENROUTER_API_KEY if set', async () => { - process.env.OPENROUTER_API_KEY = 'test-router-key'; - const expectedModels = ['model-router-1', 'model-router-2']; - - axios.get.mockResolvedValue({ - data: { - data: expectedModels.map((id) => ({ id })), - }, - }); - - const models = await getOpenAIModels({ user: 'user456' }); - - expect(models).toEqual(expect.arrayContaining(expectedModels)); - expect(axios.get).toHaveBeenCalled(); - }); - it('utilizes proxy configuration when PROXY is set', async () => { axios.get.mockResolvedValue({ data: { @@ -368,15 +352,15 @@ describe('splitAndTrim', () => { }); describe('getAnthropicModels', () => { - it('returns default models when ANTHROPIC_MODELS is not set', () => { + it('returns default models when ANTHROPIC_MODELS is not set', async () => { delete process.env.ANTHROPIC_MODELS; - const models = getAnthropicModels(); + const models = await getAnthropicModels(); expect(models).toEqual(defaultModels[EModelEndpoint.anthropic]); }); - it('returns models from ANTHROPIC_MODELS when set', () => { + it('returns models from ANTHROPIC_MODELS when set', async () => { process.env.ANTHROPIC_MODELS = 'claude-1, claude-2 '; - const models = getAnthropicModels(); + const models = await getAnthropicModels(); expect(models).toEqual(['claude-1', 'claude-2']); }); }); diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index f3e4efb6e3..969ca8d8ff 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -362,7 +362,12 @@ async function processRequiredActions(client, requiredActions) { continue; } - tool = await createActionTool({ action: actionSet, requestBuilder }); + tool = await createActionTool({ + req: client.req, + res: client.res, + action: actionSet, + requestBuilder, + }); if (!tool) { logger.warn( `Invalid action: user: ${client.req.user.id} | 
thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id} | toolName: ${currentAction.tool}`, diff --git a/api/server/services/Tools/credentials.js b/api/server/services/Tools/credentials.js new file mode 100644 index 0000000000..b50a2460d4 --- /dev/null +++ b/api/server/services/Tools/credentials.js @@ -0,0 +1,56 @@ +const { getUserPluginAuthValue } = require('~/server/services/PluginService'); + +/** + * + * @param {Object} params + * @param {string} params.userId + * @param {string[]} params.authFields + * @param {Set} [params.optional] + * @param {boolean} [params.throwError] + * @returns + */ +const loadAuthValues = async ({ userId, authFields, optional, throwError = true }) => { + let authValues = {}; + + /** + * Finds the first non-empty value for the given authentication field, supporting alternate fields. + * @param {string[]} fields Array of strings representing the authentication fields. Supports alternate fields delimited by "||". + * @returns {Promise<{ authField: string, authValue: string} | null>} An object containing the authentication field and value, or null if not found. + */ + const findAuthValue = async (fields) => { + for (const field of fields) { + let value = process.env[field]; + if (value) { + return { authField: field, authValue: value }; + } + try { + value = await getUserPluginAuthValue(userId, field, throwError); + } catch (err) { + if (optional && optional.has(field)) { + return { authField: field, authValue: undefined }; + } + if (field === fields[fields.length - 1] && !value) { + throw err; + } + } + if (value) { + return { authField: field, authValue: value }; + } + } + return null; + }; + + for (let authField of authFields) { + const fields = authField.split('||'); + const result = await findAuthValue(fields); + if (result) { + authValues[result.authField] = result.authValue; + } + } + + return authValues; +}; + +module.exports = { + loadAuthValues, +}; diff --git a/api/server/services/start/interface.js b/api/server/services/start/interface.js index 98bcb6858e..5365c4af7f 100644 --- a/api/server/services/start/interface.js +++ b/api/server/services/start/interface.js @@ -34,6 +34,8 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol multiConvo: interfaceConfig?.multiConvo ?? defaults.multiConvo, agents: interfaceConfig?.agents ?? defaults.agents, temporaryChat: interfaceConfig?.temporaryChat ?? defaults.temporaryChat, + runCode: interfaceConfig?.runCode ?? defaults.runCode, + customWelcome: interfaceConfig?.customWelcome ?? 
defaults.customWelcome, }); await updateAccessPermissions(roleName, { @@ -41,12 +43,16 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo }, [PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: loadedInterface.runCode }, }); await updateAccessPermissions(SystemRoles.ADMIN, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: loadedInterface.prompts }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo }, [PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: loadedInterface.runCode }, }); let i = 0; diff --git a/api/server/services/start/interface.spec.js b/api/server/services/start/interface.spec.js index 0041246433..7e248d3dfe 100644 --- a/api/server/services/start/interface.spec.js +++ b/api/server/services/start/interface.spec.js @@ -14,6 +14,8 @@ describe('loadDefaultInterface', () => { bookmarks: true, multiConvo: true, agents: true, + temporaryChat: true, + runCode: true, }, }; const configDefaults = { interface: {} }; @@ -25,6 +27,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: true }, }); }); @@ -35,6 +39,8 @@ describe('loadDefaultInterface', () => { bookmarks: false, multiConvo: false, agents: false, + temporaryChat: false, + runCode: false, }, }; const configDefaults = { interface: {} }; @@ -46,6 +52,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { [Permissions.USE]: false }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: false }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: false }, }); }); @@ -60,6 +68,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); @@ -70,6 +80,8 @@ describe('loadDefaultInterface', () => { bookmarks: undefined, multiConvo: undefined, agents: undefined, + temporaryChat: undefined, + runCode: undefined, }, }; const configDefaults = { interface: {} }; @@ -81,6 +93,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); @@ -91,6 +105,8 @@ describe('loadDefaultInterface', () => { 
bookmarks: false, multiConvo: undefined, agents: true, + temporaryChat: undefined, + runCode: false, }, }; const configDefaults = { interface: {} }; @@ -102,6 +118,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: false }, }); }); @@ -113,6 +131,8 @@ describe('loadDefaultInterface', () => { bookmarks: true, multiConvo: true, agents: true, + temporaryChat: true, + runCode: true, }, }; @@ -123,6 +143,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: true }, }); }); @@ -137,6 +159,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); @@ -151,6 +175,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); @@ -165,6 +191,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); @@ -175,6 +203,8 @@ describe('loadDefaultInterface', () => { bookmarks: false, multiConvo: true, agents: false, + temporaryChat: true, + runCode: false, }, }; const configDefaults = { interface: {} }; @@ -186,6 +216,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: false }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: false }, }); }); @@ -197,6 +229,8 @@ describe('loadDefaultInterface', () => { bookmarks: true, multiConvo: false, agents: undefined, + temporaryChat: undefined, + runCode: undefined, }, }; @@ -207,6 +241,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); }); diff --git a/api/server/services/twoFactorService.js b/api/server/services/twoFactorService.js new file mode 100644 index 0000000000..d000c8fcfc --- /dev/null +++ b/api/server/services/twoFactorService.js @@ -0,0 
+1,224 @@ +const { webcrypto } = require('node:crypto'); +const { decryptV3, decryptV2 } = require('../utils/crypto'); +const { hashBackupCode } = require('~/server/utils/crypto'); + +// Base32 alphabet for TOTP secret encoding. +const BASE32_ALPHABET = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567'; + +/** + * Encodes a Buffer into a Base32 string. + * @param {Buffer} buffer + * @returns {string} + */ +const encodeBase32 = (buffer) => { + let bits = 0; + let value = 0; + let output = ''; + for (const byte of buffer) { + value = (value << 8) | byte; + bits += 8; + while (bits >= 5) { + output += BASE32_ALPHABET[(value >>> (bits - 5)) & 31]; + bits -= 5; + } + } + if (bits > 0) { + output += BASE32_ALPHABET[(value << (5 - bits)) & 31]; + } + return output; +}; + +/** + * Decodes a Base32 string into a Buffer. + * @param {string} base32Str + * @returns {Buffer} + */ +const decodeBase32 = (base32Str) => { + const cleaned = base32Str.replace(/=+$/, '').toUpperCase(); + let bits = 0; + let value = 0; + const output = []; + for (const char of cleaned) { + const idx = BASE32_ALPHABET.indexOf(char); + if (idx === -1) { + continue; + } + value = (value << 5) | idx; + bits += 5; + if (bits >= 8) { + output.push((value >>> (bits - 8)) & 0xff); + bits -= 8; + } + } + return Buffer.from(output); +}; + +/** + * Generates a new TOTP secret (Base32 encoded). + * @returns {string} + */ +const generateTOTPSecret = () => { + const randomArray = new Uint8Array(10); + webcrypto.getRandomValues(randomArray); + return encodeBase32(Buffer.from(randomArray)); +}; + +/** + * Generates a TOTP code based on the secret and time. + * Uses a 30-second time step and produces a 6-digit code. + * @param {string} secret + * @param {number} [forTime=Date.now()] + * @returns {Promise} + */ +const generateTOTP = async (secret, forTime = Date.now()) => { + const timeStep = 30; // seconds + const counter = Math.floor(forTime / 1000 / timeStep); + const counterBuffer = new ArrayBuffer(8); + const counterView = new DataView(counterBuffer); + counterView.setUint32(4, counter, false); + + const keyBuffer = decodeBase32(secret); + const keyArrayBuffer = keyBuffer.buffer.slice( + keyBuffer.byteOffset, + keyBuffer.byteOffset + keyBuffer.byteLength, + ); + + const cryptoKey = await webcrypto.subtle.importKey( + 'raw', + keyArrayBuffer, + { name: 'HMAC', hash: 'SHA-1' }, + false, + ['sign'], + ); + const signatureBuffer = await webcrypto.subtle.sign('HMAC', cryptoKey, counterBuffer); + const hmac = new Uint8Array(signatureBuffer); + + // Dynamic truncation per RFC 4226. + const offset = hmac[hmac.length - 1] & 0xf; + const slice = hmac.slice(offset, offset + 4); + const view = new DataView(slice.buffer, slice.byteOffset, slice.byteLength); + const binaryCode = view.getUint32(0, false) & 0x7fffffff; + const code = (binaryCode % 1000000).toString().padStart(6, '0'); + return code; +}; + +/** + * Verifies a TOTP token by checking a ±1 time step window. + * @param {string} secret + * @param {string} token + * @returns {Promise} + */ +const verifyTOTP = async (secret, token) => { + const timeStepMS = 30 * 1000; + const currentTime = Date.now(); + for (let offset = -1; offset <= 1; offset++) { + const expected = await generateTOTP(secret, currentTime + offset * timeStepMS); + if (expected === token) { + return true; + } + } + return false; +}; + +/** + * Generates backup codes (default count: 10). + * Each code is an 8-character hexadecimal string and stored with its SHA-256 hash. 
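+ * The plaintext codes are returned for one-time display; only the hashed
+ * codeObjects are meant to be persisted on the user record.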
+ * @param {number} [count=10] + * @returns {Promise<{ plainCodes: string[], codeObjects: Array<{ codeHash: string, used: boolean, usedAt: Date | null }> }>} + */ +const generateBackupCodes = async (count = 10) => { + const plainCodes = []; + const codeObjects = []; + const encoder = new TextEncoder(); + + for (let i = 0; i < count; i++) { + const randomArray = new Uint8Array(4); + webcrypto.getRandomValues(randomArray); + const code = Array.from(randomArray) + .map((b) => b.toString(16).padStart(2, '0')) + .join(''); + plainCodes.push(code); + + const codeBuffer = encoder.encode(code); + const hashBuffer = await webcrypto.subtle.digest('SHA-256', codeBuffer); + const hashArray = Array.from(new Uint8Array(hashBuffer)); + const codeHash = hashArray.map((b) => b.toString(16).padStart(2, '0')).join(''); + codeObjects.push({ codeHash, used: false, usedAt: null }); + } + return { plainCodes, codeObjects }; +}; + +/** + * Verifies a backup code and, if valid, marks it as used. + * @param {Object} params + * @param {Object} params.user + * @param {string} params.backupCode + * @returns {Promise} + */ +const verifyBackupCode = async ({ user, backupCode }) => { + if (!backupCode || !user || !Array.isArray(user.backupCodes)) { + return false; + } + + const hashedInput = await hashBackupCode(backupCode.trim()); + const matchingCode = user.backupCodes.find( + (codeObj) => codeObj.codeHash === hashedInput && !codeObj.used, + ); + + if (matchingCode) { + const updatedBackupCodes = user.backupCodes.map((codeObj) => + codeObj.codeHash === hashedInput && !codeObj.used + ? { ...codeObj, used: true, usedAt: new Date() } + : codeObj, + ); + // Update the user record with the marked backup code. + const { updateUser } = require('~/models'); + await updateUser(user._id, { backupCodes: updatedBackupCodes }); + return true; + } + return false; +}; + +/** + * Retrieves and decrypts a stored TOTP secret. + * - Uses decryptV3 if the secret has a "v3:" prefix. + * - Falls back to decryptV2 for colon-delimited values. + * - Assumes a 16-character secret is already plain. + * @param {string|null} storedSecret + * @returns {Promise} + */ +const getTOTPSecret = async (storedSecret) => { + if (!storedSecret) { + return null; + } + if (storedSecret.startsWith('v3:')) { + return decryptV3(storedSecret); + } + if (storedSecret.includes(':')) { + return await decryptV2(storedSecret); + } + if (storedSecret.length === 16) { + return storedSecret; + } + return storedSecret; +}; + +/** + * Generates a temporary JWT token for 2FA verification that expires in 5 minutes. 
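+ * The payload is { userId, twoFAPending: true }, signed with process.env.JWT_SECRET.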
+ * @param {string} userId + * @returns {string} + */ +const generate2FATempToken = (userId) => { + const { sign } = require('jsonwebtoken'); + return sign({ userId, twoFAPending: true }, process.env.JWT_SECRET, { expiresIn: '5m' }); +}; + +module.exports = { + generateTOTPSecret, + generateTOTP, + verifyTOTP, + generateBackupCodes, + verifyBackupCode, + getTOTPSecret, + generate2FATempToken, +}; diff --git a/api/server/socialLogins.js b/api/server/socialLogins.js index 88947c7940..a6ab1a4788 100644 --- a/api/server/socialLogins.js +++ b/api/server/socialLogins.js @@ -1,4 +1,4 @@ -const Redis = require('ioredis'); +const Keyv = require('keyv'); const passport = require('passport'); const session = require('express-session'); const MemoryStore = require('memorystore')(session); @@ -12,12 +12,15 @@ const { appleLogin, } = require('~/strategies'); const { isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); const { logger } = require('~/config'); /** * @param {Express.Application} app */ const configureSocialLogins = (app) => { + logger.info('Configuring social logins...'); + if (process.env.GOOGLE_CLIENT_ID && process.env.GOOGLE_CLIENT_SECRET) { passport.use(googleLogin()); } @@ -37,18 +40,17 @@ const configureSocialLogins = (app) => { process.env.OPENID_ENABLED && process.env.OPENID_SESSION_SECRET ) { + logger.info('Configuring OpenID Connect...'); const sessionOptions = { secret: process.env.OPENID_SESSION_SECRET, resave: false, saveUninitialized: false, }; if (isEnabled(process.env.USE_REDIS)) { - const client = new Redis(process.env.REDIS_URI); - client - .on('error', (err) => logger.error('ioredis error:', err)) - .on('ready', () => logger.info('ioredis successfully initialized.')) - .on('reconnecting', () => logger.info('ioredis reconnecting...')); - sessionOptions.store = new RedisStore({ client, prefix: 'librechat' }); + logger.debug('Using Redis for session storage in OpenID...'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.redis; + sessionOptions.store = new RedisStore({ client, prefix: 'openid_session' }); } else { sessionOptions.store = new MemoryStore({ checkPeriod: 86400000, // prune expired entries every 24h @@ -57,7 +59,9 @@ const configureSocialLogins = (app) => { app.use(session(sessionOptions)); app.use(passport.session()); setupOpenId(); + + logger.info('OpenID Connect configured.'); } }; -module.exports = configureSocialLogins; \ No newline at end of file +module.exports = configureSocialLogins; diff --git a/api/server/utils/crypto.js b/api/server/utils/crypto.js index ea71df51ad..333cd7573a 100644 --- a/api/server/utils/crypto.js +++ b/api/server/utils/crypto.js @@ -1,27 +1,25 @@ require('dotenv').config(); +const crypto = require('node:crypto'); +const { webcrypto } = crypto; -const { webcrypto } = require('node:crypto'); +// Use hex decoding for both key and IV for legacy methods. 
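+// CREDS_KEY is expected to be a 64-character hex string (32 bytes; the v3 helpers
+// below enforce this) and CREDS_IV a 32-character hex string (16 bytes).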
const key = Buffer.from(process.env.CREDS_KEY, 'hex'); const iv = Buffer.from(process.env.CREDS_IV, 'hex'); const algorithm = 'AES-CBC'; +// --- Legacy v1/v2 Setup: AES-CBC with fixed key and IV --- + async function encrypt(value) { const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [ 'encrypt', ]); - const encoder = new TextEncoder(); const data = encoder.encode(value); - const encryptedBuffer = await webcrypto.subtle.encrypt( - { - name: algorithm, - iv: iv, - }, + { name: algorithm, iv: iv }, cryptoKey, data, ); - return Buffer.from(encryptedBuffer).toString('hex'); } @@ -29,73 +27,85 @@ async function decrypt(encryptedValue) { const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [ 'decrypt', ]); - const encryptedBuffer = Buffer.from(encryptedValue, 'hex'); - const decryptedBuffer = await webcrypto.subtle.decrypt( - { - name: algorithm, - iv: iv, - }, + { name: algorithm, iv: iv }, cryptoKey, encryptedBuffer, ); - const decoder = new TextDecoder(); return decoder.decode(decryptedBuffer); } -// Programmatically generate iv +// --- v2: AES-CBC with a random IV per encryption --- + async function encryptV2(value) { const gen_iv = webcrypto.getRandomValues(new Uint8Array(16)); - const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [ 'encrypt', ]); - const encoder = new TextEncoder(); const data = encoder.encode(value); - const encryptedBuffer = await webcrypto.subtle.encrypt( - { - name: algorithm, - iv: gen_iv, - }, + { name: algorithm, iv: gen_iv }, cryptoKey, data, ); - return Buffer.from(gen_iv).toString('hex') + ':' + Buffer.from(encryptedBuffer).toString('hex'); } async function decryptV2(encryptedValue) { const parts = encryptedValue.split(':'); - // Already decrypted from an earlier invocation if (parts.length === 1) { return parts[0]; } const gen_iv = Buffer.from(parts.shift(), 'hex'); const encrypted = parts.join(':'); - const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [ 'decrypt', ]); - const encryptedBuffer = Buffer.from(encrypted, 'hex'); - const decryptedBuffer = await webcrypto.subtle.decrypt( - { - name: algorithm, - iv: gen_iv, - }, + { name: algorithm, iv: gen_iv }, cryptoKey, encryptedBuffer, ); - const decoder = new TextDecoder(); return decoder.decode(decryptedBuffer); } +// --- v3: AES-256-CTR using Node's crypto functions --- +const algorithm_v3 = 'aes-256-ctr'; + +/** + * Encrypts a value using AES-256-CTR. + * Note: AES-256 requires a 32-byte key. Ensure that process.env.CREDS_KEY is a 64-character hex string. + * + * @param {string} value - The plaintext to encrypt. + * @returns {string} The encrypted string with a "v3:" prefix. 
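+ * @example
+ * // Output shape (values differ on every call because the IV is random):
+ * // encryptV3('my-secret') => 'v3:<iv hex>:<ciphertext hex>'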
+ */ +function encryptV3(value) { + if (key.length !== 32) { + throw new Error(`Invalid key length: expected 32 bytes, got ${key.length} bytes`); + } + const iv_v3 = crypto.randomBytes(16); + const cipher = crypto.createCipheriv(algorithm_v3, key, iv_v3); + const encrypted = Buffer.concat([cipher.update(value, 'utf8'), cipher.final()]); + return `v3:${iv_v3.toString('hex')}:${encrypted.toString('hex')}`; +} + +function decryptV3(encryptedValue) { + const parts = encryptedValue.split(':'); + if (parts[0] !== 'v3') { + throw new Error('Not a v3 encrypted value'); + } + const iv_v3 = Buffer.from(parts[1], 'hex'); + const encryptedText = Buffer.from(parts.slice(2).join(':'), 'hex'); + const decipher = crypto.createDecipheriv(algorithm_v3, key, iv_v3); + const decrypted = Buffer.concat([decipher.update(encryptedText), decipher.final()]); + return decrypted.toString('utf8'); +} + async function hashToken(str) { const data = new TextEncoder().encode(str); const hashBuffer = await webcrypto.subtle.digest('SHA-256', data); @@ -106,10 +116,32 @@ async function getRandomValues(length) { if (!Number.isInteger(length) || length <= 0) { throw new Error('Length must be a positive integer'); } - const randomValues = new Uint8Array(length); webcrypto.getRandomValues(randomValues); return Buffer.from(randomValues).toString('hex'); } -module.exports = { encrypt, decrypt, encryptV2, decryptV2, hashToken, getRandomValues }; +/** + * Computes SHA-256 hash for the given input. + * @param {string} input + * @returns {Promise} + */ +async function hashBackupCode(input) { + const encoder = new TextEncoder(); + const data = encoder.encode(input); + const hashBuffer = await webcrypto.subtle.digest('SHA-256', data); + const hashArray = Array.from(new Uint8Array(hashBuffer)); + return hashArray.map((b) => b.toString(16).padStart(2, '0')).join(''); +} + +module.exports = { + encrypt, + decrypt, + encryptV2, + decryptV2, + encryptV3, + decryptV3, + hashToken, + hashBackupCode, + getRandomValues, +}; diff --git a/api/server/utils/handleText.js b/api/server/utils/handleText.js index 8c681d8f4e..f593d6c866 100644 --- a/api/server/utils/handleText.js +++ b/api/server/utils/handleText.js @@ -203,6 +203,8 @@ function generateConfig(key, baseURL, endpoint) { AgentCapabilities.artifacts, AgentCapabilities.actions, AgentCapabilities.tools, + AgentCapabilities.ocr, + AgentCapabilities.chain, ]; } diff --git a/api/server/utils/staticCache.js b/api/server/utils/staticCache.js index a8001c7e0a..23713ddf6f 100644 --- a/api/server/utils/staticCache.js +++ b/api/server/utils/staticCache.js @@ -1,4 +1,4 @@ -const express = require('express'); +const expressStaticGzip = require('express-static-gzip'); const oneDayInSeconds = 24 * 60 * 60; @@ -6,13 +6,13 @@ const sMaxAge = process.env.STATIC_CACHE_S_MAX_AGE || oneDayInSeconds; const maxAge = process.env.STATIC_CACHE_MAX_AGE || oneDayInSeconds * 2; const staticCache = (staticPath) => - express.static(staticPath, { - setHeaders: (res) => { - if (process.env.NODE_ENV?.toLowerCase() !== 'production') { - return; + expressStaticGzip(staticPath, { + enableBrotli: false, // disable Brotli, only using gzip + orderPreference: ['gz'], + setHeaders: (res, _path) => { + if (process.env.NODE_ENV?.toLowerCase() === 'production') { + res.setHeader('Cache-Control', `public, max-age=${maxAge}, s-maxage=${sMaxAge}`); } - - res.setHeader('Cache-Control', `public, max-age=${maxAge}, s-maxage=${sMaxAge}`); }, }); diff --git a/api/strategies/googleStrategy.js b/api/strategies/googleStrategy.js index 
ab8a268953..fd65823327 100644 --- a/api/strategies/googleStrategy.js +++ b/api/strategies/googleStrategy.js @@ -6,7 +6,7 @@ const getProfileDetails = ({ profile }) => ({ id: profile.id, avatarUrl: profile.photos[0].value, username: profile.name.givenName, - name: `${profile.name.givenName} ${profile.name.familyName}`, + name: `${profile.name.givenName}${profile.name.familyName ? ` ${profile.name.familyName}` : ''}`, emailVerified: profile.emails[0].verified, }); diff --git a/api/strategies/jwtStrategy.js b/api/strategies/jwtStrategy.js index e65b284950..ac19e92ac3 100644 --- a/api/strategies/jwtStrategy.js +++ b/api/strategies/jwtStrategy.js @@ -12,7 +12,7 @@ const jwtLogin = async () => }, async (payload, done) => { try { - const user = await getUserById(payload?.id, '-password -__v'); + const user = await getUserById(payload?.id, '-password -__v -totpSecret'); if (user) { user.id = user._id.toString(); if (!user.role) { diff --git a/api/strategies/ldapStrategy.js b/api/strategies/ldapStrategy.js index 4a2c1b827b..5ec279b982 100644 --- a/api/strategies/ldapStrategy.js +++ b/api/strategies/ldapStrategy.js @@ -18,6 +18,7 @@ const { LDAP_USERNAME, LDAP_EMAIL, LDAP_TLS_REJECT_UNAUTHORIZED, + LDAP_STARTTLS, } = process.env; // Check required environment variables @@ -50,6 +51,7 @@ if (LDAP_EMAIL) { searchAttributes.push(LDAP_EMAIL); } const rejectUnauthorized = isEnabled(LDAP_TLS_REJECT_UNAUTHORIZED); +const startTLS = isEnabled(LDAP_STARTTLS); const ldapOptions = { server: { @@ -72,6 +74,7 @@ const ldapOptions = { })(), }, }), + ...(startTLS && { starttls: true }), }, usernameField: 'email', passwordField: 'password', diff --git a/api/test/__mocks__/logger.js b/api/test/__mocks__/logger.js index caeb004e39..549c57d5a4 100644 --- a/api/test/__mocks__/logger.js +++ b/api/test/__mocks__/logger.js @@ -39,7 +39,10 @@ jest.mock('winston-daily-rotate-file', () => { }); jest.mock('~/config', () => { + const actualModule = jest.requireActual('~/config'); return { + sendEvent: actualModule.sendEvent, + createAxiosInstance: actualModule.createAxiosInstance, logger: { info: jest.fn(), warn: jest.fn(), diff --git a/api/typedefs.js b/api/typedefs.js index bd97bd93fa..21c4f1fecc 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -20,6 +20,12 @@ * @memberof typedefs */ +/** + * @exports NextFunction + * @typedef {import('express').NextFunction} NextFunction + * @memberof typedefs + */ + /** * @exports AgentRun * @typedef {import('@librechat/agents').Run} AgentRun @@ -760,36 +766,6 @@ * @memberof typedefs */ -/** - * @exports MongoFile - * @typedef {import('~/models/schema/fileSchema.js').MongoFile} MongoFile - * @memberof typedefs - */ - -/** - * @exports ToolCallData - * @typedef {import('~/models/schema/toolCallSchema.js').ToolCallData} ToolCallData - * @memberof typedefs - */ - -/** - * @exports MongoUser - * @typedef {import('~/models/schema/userSchema.js').MongoUser} MongoUser - * @memberof typedefs - */ - -/** - * @exports MongoProject - * @typedef {import('~/models/schema/projectSchema.js').MongoProject} MongoProject - * @memberof typedefs - */ - -/** - * @exports MongoPromptGroup - * @typedef {import('~/models/schema/promptSchema.js').MongoPromptGroup} MongoPromptGroup - * @memberof typedefs - */ - /** * @exports uploadImageBuffer * @typedef {import('~/server/services/Files/process').uploadImageBuffer} uploadImageBuffer @@ -1811,3 +1787,51 @@ * @typedef {Promise<{ message: TMessage, conversation: TConversation }> | undefined} ClientDatabaseSavePromise * @memberof typedefs */ + +/** + * @exports 
OCRImage
+ * @typedef {Object} OCRImage
+ * @property {string} id - The identifier of the image.
+ * @property {number} top_left_x - X-coordinate of the top left corner of the image.
+ * @property {number} top_left_y - Y-coordinate of the top left corner of the image.
+ * @property {number} bottom_right_x - X-coordinate of the bottom right corner of the image.
+ * @property {number} bottom_right_y - Y-coordinate of the bottom right corner of the image.
+ * @property {string} image_base64 - Base64-encoded image data.
+ * @memberof typedefs
+ */
+
+/**
+ * @exports PageDimensions
+ * @typedef {Object} PageDimensions
+ * @property {number} dpi - The dots per inch resolution of the page.
+ * @property {number} height - The height of the page in pixels.
+ * @property {number} width - The width of the page in pixels.
+ * @memberof typedefs
+ */
+
+/**
+ * @exports OCRPage
+ * @typedef {Object} OCRPage
+ * @property {number} index - The index of the page in the document.
+ * @property {string} markdown - The extracted text content of the page in markdown format.
+ * @property {OCRImage[]} images - Array of images found on the page.
+ * @property {PageDimensions} dimensions - The dimensions of the page.
+ * @memberof typedefs
+ */
+
+/**
+ * @exports OCRUsageInfo
+ * @typedef {Object} OCRUsageInfo
+ * @property {number} pages_processed - Number of pages processed in the document.
+ * @property {number} doc_size_bytes - Size of the document in bytes.
+ * @memberof typedefs
+ */
+
+/**
+ * @exports OCRResult
+ * @typedef {Object} OCRResult
+ * @property {OCRPage[]} pages - Array of pages extracted from the document.
+ * @property {string} model - The model used for OCR processing.
+ * @property {OCRUsageInfo} usage_info - Usage information for the OCR operation.
+ * @memberof typedefs
+ */
diff --git a/api/utils/axios.js b/api/utils/axios.js
index 8b12a5ca99..acd23a184f 100644
--- a/api/utils/axios.js
+++ b/api/utils/axios.js
@@ -5,40 +5,32 @@ const { logger } = require('~/config');
  *
  * @param {Object} options - The options object.
  * @param {string} options.message - The custom message to be logged.
- * @param {Error} options.error - The Axios error object.
+ * @param {import('axios').AxiosError} options.error - The Axios error object.
  */
 const logAxiosError = ({ message, error }) => {
-  const timedOutMessage = 'Cannot read properties of undefined (reading \'status\')';
-  if (error.response) {
-    logger.error(
-      `${message} The request was made and the server responded with a status code that falls out of the range of 2xx: ${
-        error.message ? error.message : ''
-      }. Error response data:\n`,
-      {
-        headers: error.response?.headers,
-        status: error.response?.status,
-        data: error.response?.data,
-      },
-    );
-  } else if (error.request) {
-    logger.error(
-      `${message} The request was made but no response was received: ${
-        error.message ? error.message : ''
-      }. Error Request:\n`,
-      {
-        request: error.request,
-      },
-    );
-  } else if (error?.message?.includes(timedOutMessage)) {
-    logger.error(
-      `${message}\nThe request either timed out or was unsuccessful. Error message:\n`,
-      error,
-    );
-  } else {
-    logger.error(
-      `${message}\nSomething happened in setting up the request. Error message:\n`,
-      error,
-    );
+  try {
+    if (error.response?.status) {
+      const { status, headers, data } = error.response;
+      logger.error(`${message} The server responded with status ${status}: ${error.message}`, {
+        status,
+        headers,
+        data,
+      });
+    } else if (error.request) {
+      const { method, url } = error.config || {};
+      logger.error(
+        `${message} No response received for ${method ? method.toUpperCase() : ''} ${url || ''}: ${error.message}`,
+        { requestInfo: { method, url } },
+      );
+    } else if (error?.message?.includes('Cannot read properties of undefined (reading \'status\')')) {
+      logger.error(
+        `${message} It appears the request timed out or was unsuccessful: ${error.message}`,
+      );
+    } else {
+      logger.error(`${message} An error occurred while setting up the request: ${error.message}`);
+    }
+  } catch (err) {
+    logger.error(`Error in logAxiosError: ${err.message}`);
   }
 };
diff --git a/api/utils/debug.js b/api/utils/debug.js
deleted file mode 100644
index 68599eea38..0000000000
--- a/api/utils/debug.js
+++ /dev/null
@@ -1,56 +0,0 @@
-const levels = {
-  NONE: 0,
-  LOW: 1,
-  MEDIUM: 2,
-  HIGH: 3,
-};
-
-let level = levels.HIGH;
-
-module.exports = {
-  levels,
-  setLevel: (l) => (level = l),
-  log: {
-    parameters: (parameters) => {
-      if (levels.HIGH > level) {
-        return;
-      }
-      console.group();
-      parameters.forEach((p) => console.log(`${p.name}:`, p.value));
-      console.groupEnd();
-    },
-    functionName: (name) => {
-      if (levels.MEDIUM > level) {
-        return;
-      }
-      console.log(`\nEXECUTING: ${name}\n`);
-    },
-    flow: (flow) => {
-      if (levels.LOW > level) {
-        return;
-      }
-      console.log(`\n\n\nBEGIN FLOW: ${flow}\n\n\n`);
-    },
-    variable: ({ name, value }) => {
-      if (levels.HIGH > level) {
-        return;
-      }
-      console.group();
-      console.group();
-      console.log(`VARIABLE ${name}:`, value);
-      console.groupEnd();
-      console.groupEnd();
-    },
-    request: () => (req, res, next) => {
-      if (levels.HIGH > level) {
-        return next();
-      }
-      console.log('Hit URL', req.url, 'with following:');
-      console.group();
-      console.log('Query:', req.query);
-      console.log('Body:', req.body);
-      console.groupEnd();
-      return next();
-    },
-  },
-};
diff --git a/api/utils/tokens.js b/api/utils/tokens.js
index 0541f4f301..58aaf7051b 100644
--- a/api/utils/tokens.js
+++ b/api/utils/tokens.js
@@ -13,6 +13,7 @@ const openAIModels = {
   'gpt-4-32k-0613': 32758, // -10 from max
   'gpt-4-1106': 127500, // -500 from max
   'gpt-4-0125': 127500, // -500 from max
+  'gpt-4.5': 127500, // -500 from max
   'gpt-4o': 127500, // -500 from max
   'gpt-4o-mini': 127500, // -500 from max
   'gpt-4o-2024-05-13': 127500, // -500 from max
@@ -74,6 +75,7 @@ const anthropicModels = {
   'claude-instant': 100000,
   'claude-2': 100000,
   'claude-2.1': 200000,
+  'claude-3': 200000,
   'claude-3-haiku': 200000,
   'claude-3-sonnet': 200000,
   'claude-3-opus': 200000,
@@ -81,6 +83,8 @@
   'claude-3-5-haiku': 200000,
   'claude-3-5-sonnet': 200000,
   'claude-3.5-sonnet': 200000,
+  'claude-3-7-sonnet': 200000,
+  'claude-3.7-sonnet': 200000,
   'claude-3-5-sonnet-latest': 200000,
   'claude-3.5-sonnet-latest': 200000,
 };
@@ -88,6 +92,7 @@ const anthropicModels = {
 const deepseekModels = {
   'deepseek-reasoner': 63000, // -1000 from max (API)
   deepseek: 63000, // -1000 from max (API)
+  'deepseek.r1': 127500,
 };
 
 const metaModels = {
@@ -183,7 +188,18 @@ const bedrockModels = {
   ...amazonModels,
 };
 
-const aggregateModels = { ...openAIModels, ...googleModels, ...bedrockModels };
+const xAIModels = {
+  'grok-beta': 131072,
+  'grok-vision-beta': 8192,
+  'grok-2': 131072,
+  'grok-2-latest': 131072,
+  'grok-2-1212': 131072,
+  'grok-2-vision': 32768,
+  'grok-2-vision-latest': 32768,
+  'grok-2-vision-1212': 32768,
+};
+
+const aggregateModels = { ...openAIModels, ...googleModels, ...bedrockModels, ...xAIModels };
 
 const maxTokensMap = {
   [EModelEndpoint.azureOpenAI]: openAIModels,
diff --git a/api/utils/tokens.spec.js b/api/utils/tokens.spec.js
index eb1fd85495..e5ae21b646 100644
--- a/api/utils/tokens.spec.js
+++ b/api/utils/tokens.spec.js
@@ -103,6 +103,16 @@ describe('getModelMaxTokens', () => {
     );
   });
 
+  test('should return correct tokens for gpt-4.5 matches', () => {
+    expect(getModelMaxTokens('gpt-4.5')).toBe(maxTokensMap[EModelEndpoint.openAI]['gpt-4.5']);
+    expect(getModelMaxTokens('gpt-4.5-preview')).toBe(
+      maxTokensMap[EModelEndpoint.openAI]['gpt-4.5'],
+    );
+    expect(getModelMaxTokens('openai/gpt-4.5-preview')).toBe(
+      maxTokensMap[EModelEndpoint.openAI]['gpt-4.5'],
+    );
+  });
+
   test('should return correct tokens for Anthropic models', () => {
     const models = [
       'claude-2.1',
@@ -116,6 +126,7 @@
       'claude-3-sonnet',
       'claude-3-opus',
       'claude-3-5-sonnet',
+      'claude-3-7-sonnet',
     ];
 
     const maxTokens = {
@@ -412,6 +423,9 @@ describe('Meta Models Tests', () => {
     expect(getModelMaxTokens('deepseek-reasoner')).toBe(
       maxTokensMap[EModelEndpoint.openAI]['deepseek-reasoner'],
     );
+    expect(getModelMaxTokens('deepseek.r1')).toBe(
+      maxTokensMap[EModelEndpoint.openAI]['deepseek.r1'],
+    );
   });
 });
 
@@ -483,3 +497,68 @@ describe('Meta Models Tests', () => {
     });
   });
 });
+
+describe('Grok Model Tests - Tokens', () => {
+  describe('getModelMaxTokens', () => {
+    test('should return correct tokens for Grok vision models', () => {
+      expect(getModelMaxTokens('grok-2-vision-1212')).toBe(32768);
+      expect(getModelMaxTokens('grok-2-vision')).toBe(32768);
+      expect(getModelMaxTokens('grok-2-vision-latest')).toBe(32768);
+    });
+
+    test('should return correct tokens for Grok beta models', () => {
+      expect(getModelMaxTokens('grok-vision-beta')).toBe(8192);
+      expect(getModelMaxTokens('grok-beta')).toBe(131072);
+    });
+
+    test('should return correct tokens for Grok text models', () => {
+      expect(getModelMaxTokens('grok-2-1212')).toBe(131072);
+      expect(getModelMaxTokens('grok-2')).toBe(131072);
+      expect(getModelMaxTokens('grok-2-latest')).toBe(131072);
+    });
+
+    test('should handle partial matches for Grok models with prefixes', () => {
+      // Vision models should match before general models
+      expect(getModelMaxTokens('openai/grok-2-vision-1212')).toBe(32768);
+      expect(getModelMaxTokens('openai/grok-2-vision')).toBe(32768);
+      expect(getModelMaxTokens('openai/grok-2-vision-latest')).toBe(32768);
+      // Beta models
+      expect(getModelMaxTokens('openai/grok-vision-beta')).toBe(8192);
+      expect(getModelMaxTokens('openai/grok-beta')).toBe(131072);
+      // Text models
+      expect(getModelMaxTokens('openai/grok-2-1212')).toBe(131072);
+      expect(getModelMaxTokens('openai/grok-2')).toBe(131072);
+      expect(getModelMaxTokens('openai/grok-2-latest')).toBe(131072);
+    });
+  });
+
+  describe('matchModelName', () => {
+    test('should match exact Grok model names', () => {
+      // Vision models
+      expect(matchModelName('grok-2-vision-1212')).toBe('grok-2-vision-1212');
+      expect(matchModelName('grok-2-vision')).toBe('grok-2-vision');
+      expect(matchModelName('grok-2-vision-latest')).toBe('grok-2-vision-latest');
+      // Beta models
+      expect(matchModelName('grok-vision-beta')).toBe('grok-vision-beta');
+      expect(matchModelName('grok-beta')).toBe('grok-beta');
+      // Text models
+      expect(matchModelName('grok-2-1212')).toBe('grok-2-1212');
+      expect(matchModelName('grok-2')).toBe('grok-2');
+      expect(matchModelName('grok-2-latest')).toBe('grok-2-latest');
+    });
+
+    test('should match Grok model variations with prefixes', () => {
+      // Vision models should match before general models
+      expect(matchModelName('openai/grok-2-vision-1212')).toBe('grok-2-vision-1212');
+      expect(matchModelName('openai/grok-2-vision')).toBe('grok-2-vision');
+      expect(matchModelName('openai/grok-2-vision-latest')).toBe('grok-2-vision-latest');
+      // Beta models
+      expect(matchModelName('openai/grok-vision-beta')).toBe('grok-vision-beta');
+      expect(matchModelName('openai/grok-beta')).toBe('grok-beta');
+      // Text models
+      expect(matchModelName('openai/grok-2-1212')).toBe('grok-2-1212');
+      expect(matchModelName('openai/grok-2')).toBe('grok-2');
+      expect(matchModelName('openai/grok-2-latest')).toBe('grok-2-latest');
+    });
+  });
+});
diff --git a/client/index.html b/client/index.html
index 9bd0363fab..9e300e7365 100644
--- a/client/index.html
+++ b/client/index.html
@@ -6,6 +6,7 @@
+
 LibreChat
@@ -53,6 +54,5 @@
-
diff --git a/client/package.json b/client/package.json
index 993cf30071..96b402e747 100644
--- a/client/package.json
+++ b/client/package.json
@@ -1,6 +1,6 @@
 {
   "name": "@librechat/frontend",
-  "version": "v0.7.7-rc1",
+  "version": "v0.7.7",
   "description": "",
   "type": "module",
   "scripts": {
@@ -28,7 +28,8 @@
   },
   "homepage": "https://librechat.ai",
   "dependencies": {
-    "@ariakit/react": "^0.4.11",
+    "@ariakit/react": "^0.4.15",
+    "@ariakit/react-core": "^0.4.15",
     "@codesandbox/sandpack-react": "^2.19.10",
     "@dicebear/collection": "^7.0.4",
     "@dicebear/core": "^7.0.4",
@@ -43,6 +44,7 @@
     "@radix-ui/react-icons": "^1.3.0",
     "@radix-ui/react-label": "^2.0.0",
     "@radix-ui/react-popover": "^1.0.7",
+    "@radix-ui/react-progress": "^1.1.2",
     "@radix-ui/react-radio-group": "^1.1.3",
     "@radix-ui/react-select": "^2.0.0",
     "@radix-ui/react-separator": "^1.0.3",
@@ -63,6 +65,8 @@
     "framer-motion": "^11.5.4",
     "html-to-image": "^1.11.11",
     "i18next": "^24.2.2",
+    "i18next-browser-languagedetector": "^8.0.3",
+    "input-otp": "^1.4.2",
     "js-cookie": "^3.0.5",
     "librechat-data-provider": "*",
     "lodash": "^4.17.21",
@@ -82,7 +86,7 @@
     "react-i18next": "^15.4.0",
     "react-lazy-load-image-component": "^1.6.0",
     "react-markdown": "^9.0.1",
-    "react-resizable-panels": "^2.1.1",
+    "react-resizable-panels": "^2.1.7",
     "react-router-dom": "^6.11.2",
     "react-speech-recognition": "^3.10.0",
     "react-textarea-autosize": "^8.4.0",
@@ -138,6 +142,7 @@
     "typescript": "^5.3.3",
     "vite": "^6.1.0",
     "vite-plugin-node-polyfills": "^0.17.0",
+    "vite-plugin-compression": "^0.5.1",
     "vite-plugin-pwa": "^0.21.1"
   }
 }
diff --git a/client/public/assets/apple-touch-icon-180x180.png b/client/public/assets/apple-touch-icon-180x180.png
index 91dde5d139..57c4637c93 100644
Binary files a/client/public/assets/apple-touch-icon-180x180.png and b/client/public/assets/apple-touch-icon-180x180.png differ
diff --git a/client/public/assets/icon-192x192.png b/client/public/assets/icon-192x192.png
new file mode 100644
index 0000000000..b8dfe0eae5
Binary files /dev/null and b/client/public/assets/icon-192x192.png differ
diff --git a/client/public/assets/maskable-icon.png b/client/public/assets/maskable-icon.png
index 18305a2446..90e48f870b 100644
Binary files a/client/public/assets/maskable-icon.png and b/client/public/assets/maskable-icon.png differ
diff --git a/client/public/robots.txt b/client/public/robots.txt
new file mode 100644
index 0000000000..376d535e0a
--- /dev/null
+++ b/client/public/robots.txt
@@ -0,0 +1,3 @@
+User-agent: *
+Disallow: /api/
+Allow: /
\ No newline at end of file
diff --git a/client/src/@types/i18next.d.ts b/client/src/@types/i18next.d.ts
new file mode 100644
index 0000000000..2d50f5a3cd
--- /dev/null
+++ b/client/src/@types/i18next.d.ts
@@ -0,0 +1,9 @@
+import { defaultNS, resources } from '~/locales/i18n';
+
+declare module 'i18next' {
+  interface CustomTypeOptions {
+    defaultNS: typeof defaultNS;
+    resources: typeof resources.en;
+    strictKeyChecks: true
+  }
+}
\ No newline at end of file
diff --git a/client/src/common/agents-types.ts b/client/src/common/agents-types.ts
index a9c24106bc..982cbfdb17 100644
--- a/client/src/common/agents-types.ts
+++ b/client/src/common/agents-types.ts
@@ -5,6 +5,7 @@ import type { OptionWithIcon, ExtendedFile } from './types';
 export type TAgentOption = OptionWithIcon & Agent & {
     knowledge_files?: Array<[string, ExtendedFile]>;
+    context_files?: Array<[string, ExtendedFile]>;
     code_files?: Array<[string, ExtendedFile]>;
   };
 
@@ -27,4 +28,5 @@ export type AgentForm = {
   provider?: AgentProvider | OptionWithIcon;
agent_ids?: string[]; [AgentCapabilities.artifacts]?: ArtifactModes | string; + recursion_limit?: number; } & TAgentCapabilities; diff --git a/client/src/common/types.ts b/client/src/common/types.ts index 3d61eccb1c..118cefce16 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -106,7 +106,7 @@ export type IconsRecord = { export type AgentIconMapProps = IconMapProps & { agentName?: string }; export type NavLink = { - title: string; + title: TranslationKeys; label?: string; icon: LucideIcon | React.FC; Component?: React.ComponentType; @@ -131,6 +131,7 @@ export interface DataColumnMeta { } export enum Panel { + advanced = 'advanced', builder = 'builder', actions = 'actions', model = 'model', @@ -181,6 +182,7 @@ export type AgentPanelProps = { activePanel?: string; action?: t.Action; actions?: t.Action[]; + createMutation: UseMutationResult; setActivePanel: React.Dispatch>; setAction: React.Dispatch>; endpointsConfig?: t.TEndpointsConfig; @@ -370,12 +372,12 @@ export type TDangerButtonProps = { showText?: boolean; mutation?: UseMutationResult; onClick: () => void; - infoTextCode: string; - actionTextCode: string; + infoTextCode: TranslationKeys; + actionTextCode: TranslationKeys; dataTestIdInitial: string; dataTestIdConfirm: string; - infoDescriptionCode?: string; - confirmActionTextCode?: string; + infoDescriptionCode?: TranslationKeys; + confirmActionTextCode?: TranslationKeys; }; export type TDialogProps = { @@ -399,7 +401,7 @@ export type TAuthContext = { isAuthenticated: boolean; error: string | undefined; login: (data: t.TLoginUser) => void; - logout: () => void; + logout: (redirect?: string) => void; setError: React.Dispatch>; roles?: Record; }; @@ -483,6 +485,7 @@ export interface ExtendedFile { attached?: boolean; embedded?: boolean; tool_resource?: string; + metadata?: t.TFile['metadata']; } export type ContextType = { navVisible: boolean; setNavVisible: (visible: boolean) => void }; diff --git a/client/src/components/Artifacts/ArtifactCodeEditor.tsx b/client/src/components/Artifacts/ArtifactCodeEditor.tsx index 51f4645637..af0f7bdc8e 100644 --- a/client/src/components/Artifacts/ArtifactCodeEditor.tsx +++ b/client/src/components/Artifacts/ArtifactCodeEditor.tsx @@ -8,8 +8,8 @@ import { import { SandpackProviderProps } from '@codesandbox/sandpack-react/unstyled'; import type { CodeEditorRef } from '@codesandbox/sandpack-react'; import type { ArtifactFiles, Artifact } from '~/common'; +import { useEditArtifact, useGetStartupConfig } from '~/data-provider'; import { sharedFiles, sharedOptions } from '~/utils/artifacts'; -import { useEditArtifact } from '~/data-provider'; import { useEditorContext } from '~/Providers'; const createDebouncedMutation = ( @@ -124,6 +124,17 @@ export const ArtifactCodeEditor = memo(function ({ sharedProps: Partial; editorRef: React.MutableRefObject; }) { + const { data: config } = useGetStartupConfig(); + const options: typeof sharedOptions = useMemo(() => { + if (!config) { + return sharedOptions; + } + return { + ...sharedOptions, + bundlerURL: config.bundlerURL, + }; + }, [config]); + if (Object.keys(files).length === 0) { return null; } @@ -135,7 +146,7 @@ export const ArtifactCodeEditor = memo(function ({ ...files, ...sharedFiles, }} - options={{ ...sharedOptions }} + options={options} {...sharedProps} template={template} > diff --git a/client/src/components/Artifacts/ArtifactPreview.tsx b/client/src/components/Artifacts/ArtifactPreview.tsx index 9cb06d413c..d3d147929f 100644 --- 
a/client/src/components/Artifacts/ArtifactPreview.tsx +++ b/client/src/components/Artifacts/ArtifactPreview.tsx @@ -7,6 +7,7 @@ import { import type { SandpackPreviewRef } from '@codesandbox/sandpack-react/unstyled'; import type { ArtifactFiles } from '~/common'; import { sharedFiles, sharedOptions } from '~/utils/artifacts'; +import { useGetStartupConfig } from '~/data-provider'; import { useEditorContext } from '~/Providers'; export const ArtifactPreview = memo(function ({ @@ -23,6 +24,8 @@ export const ArtifactPreview = memo(function ({ previewRef: React.MutableRefObject; }) { const { currentCode } = useEditorContext(); + const { data: config } = useGetStartupConfig(); + const artifactFiles = useMemo(() => { if (Object.keys(files).length === 0) { return files; @@ -38,6 +41,17 @@ export const ArtifactPreview = memo(function ({ }, }; }, [currentCode, files, fileKey]); + + const options: typeof sharedOptions = useMemo(() => { + if (!config) { + return sharedOptions; + } + return { + ...sharedOptions, + bundlerURL: config.bundlerURL, + }; + }, [config]); + if (Object.keys(artifactFiles).length === 0) { return null; } @@ -48,7 +62,7 @@ export const ArtifactPreview = memo(function ({ ...artifactFiles, ...sharedFiles, }} - options={{ ...sharedOptions }} + options={options} {...sharedProps} template={template} > diff --git a/client/src/components/Auth/AuthLayout.tsx b/client/src/components/Auth/AuthLayout.tsx index a7e890517a..d90f0d3dfe 100644 --- a/client/src/components/Auth/AuthLayout.tsx +++ b/client/src/components/Auth/AuthLayout.tsx @@ -85,7 +85,8 @@ function AuthLayout({ )} {children} - {(pathname.includes('login') || pathname.includes('register')) && ( + {!pathname.includes('2fa') && + (pathname.includes('login') || pathname.includes('register')) && ( )} diff --git a/client/src/components/Auth/Login.tsx b/client/src/components/Auth/Login.tsx index a332553701..48cbfe1a91 100644 --- a/client/src/components/Auth/Login.tsx +++ b/client/src/components/Auth/Login.tsx @@ -1,16 +1,78 @@ -import { useOutletContext } from 'react-router-dom'; +import { useOutletContext, useSearchParams } from 'react-router-dom'; +import { useEffect, useState } from 'react'; import { useAuthContext } from '~/hooks/AuthContext'; import type { TLoginLayoutContext } from '~/common'; import { ErrorMessage } from '~/components/Auth/ErrorMessage'; import { getLoginError } from '~/utils'; import { useLocalize } from '~/hooks'; import LoginForm from './LoginForm'; +import SocialButton from '~/components/Auth/SocialButton'; +import { OpenIDIcon } from '~/components'; function Login() { const localize = useLocalize(); const { error, setError, login } = useAuthContext(); const { startupConfig } = useOutletContext(); + const [searchParams, setSearchParams] = useSearchParams(); + // Determine if auto-redirect should be disabled based on the URL parameter + const disableAutoRedirect = searchParams.get('redirect') === 'false'; + + // Persist the disable flag locally so that once detected, auto-redirect stays disabled. + const [isAutoRedirectDisabled, setIsAutoRedirectDisabled] = useState(disableAutoRedirect); + + // Once the disable flag is detected, update local state and remove the parameter from the URL. 
+  useEffect(() => {
+    if (disableAutoRedirect) {
+      setIsAutoRedirectDisabled(true);
+      const newParams = new URLSearchParams(searchParams);
+      newParams.delete('redirect');
+      setSearchParams(newParams, { replace: true });
+    }
+  }, [disableAutoRedirect, searchParams, setSearchParams]);
+
+  // Determine whether we should auto-redirect to OpenID.
+  const shouldAutoRedirect =
+    startupConfig?.openidLoginEnabled &&
+    startupConfig?.openidAutoRedirect &&
+    startupConfig?.serverDomain &&
+    !isAutoRedirectDisabled;
+
+  useEffect(() => {
+    if (shouldAutoRedirect) {
+      console.log('Auto-redirecting to OpenID provider...');
+      window.location.href = `${startupConfig.serverDomain}/oauth/openid`;
+    }
+  }, [shouldAutoRedirect, startupConfig]);
+
+  // Render fallback UI if auto-redirect is active.
+  if (shouldAutoRedirect) {
+    return (
+
+

+ {localize('com_ui_redirecting_to_provider', { 0: startupConfig.openidLabel })} +

+
+ + startupConfig.openidImageUrl ? ( + OpenID Logo + ) : ( + + ) + } + label={startupConfig.openidLabel} + id="openid" + /> +
+
+ ); + } + return ( <> {error != null && {localize(getLoginError(error))}} diff --git a/client/src/components/Auth/LoginForm.tsx b/client/src/components/Auth/LoginForm.tsx index 9f5bb46039..2cd62d08b9 100644 --- a/client/src/components/Auth/LoginForm.tsx +++ b/client/src/components/Auth/LoginForm.tsx @@ -166,9 +166,7 @@ const LoginForm: React.FC = ({ onSubmit, startupConfig, error, type="submit" className=" w-full rounded-2xl bg-green-600 px-4 py-3 text-sm font-medium text-white - transition-colors hover:bg-green-700 focus:outline-none focus:ring-2 - focus:ring-green-500 focus:ring-offset-2 disabled:opacity-50 - disabled:hover:bg-green-600 dark:bg-green-600 dark:hover:bg-green-700 + transition-colors hover:bg-green-700 dark:bg-green-600 dark:hover:bg-green-700 " > {localize('com_auth_continue')} diff --git a/client/src/components/Auth/TwoFactorScreen.tsx b/client/src/components/Auth/TwoFactorScreen.tsx new file mode 100644 index 0000000000..04f89d7cea --- /dev/null +++ b/client/src/components/Auth/TwoFactorScreen.tsx @@ -0,0 +1,176 @@ +import React, { useState, useCallback } from 'react'; +import { useSearchParams } from 'react-router-dom'; +import { useForm, Controller } from 'react-hook-form'; +import { REGEXP_ONLY_DIGITS, REGEXP_ONLY_DIGITS_AND_CHARS } from 'input-otp'; +import { InputOTP, InputOTPGroup, InputOTPSeparator, InputOTPSlot, Label } from '~/components'; +import { useVerifyTwoFactorTempMutation } from '~/data-provider'; +import { useToastContext } from '~/Providers'; +import { useLocalize } from '~/hooks'; + +interface VerifyPayload { + tempToken: string; + token?: string; + backupCode?: string; +} + +type TwoFactorFormInputs = { + token?: string; + backupCode?: string; +}; + +const TwoFactorScreen: React.FC = React.memo(() => { + const [searchParams] = useSearchParams(); + const tempTokenRaw = searchParams.get('tempToken'); + const tempToken = tempTokenRaw !== null && tempTokenRaw !== '' ? tempTokenRaw : ''; + + const { + control, + handleSubmit, + formState: { errors }, + } = useForm(); + const localize = useLocalize(); + const { showToast } = useToastContext(); + const [useBackup, setUseBackup] = useState(false); + const [isLoading, setIsLoading] = useState(false); + const { mutate: verifyTempMutate } = useVerifyTwoFactorTempMutation({ + onSuccess: (result) => { + if (result.token != null && result.token !== '') { + window.location.href = '/'; + } + }, + onMutate: () => { + setIsLoading(true); + }, + onError: (error: unknown) => { + setIsLoading(false); + const err = error as { response?: { data?: { message?: unknown } } }; + const errorMsg = + typeof err.response?.data?.message === 'string' + ? err.response.data.message + : 'Error verifying 2FA'; + showToast({ message: errorMsg, status: 'error' }); + }, + }); + + const onSubmit = useCallback( + (data: TwoFactorFormInputs) => { + const payload: VerifyPayload = { tempToken }; + if (useBackup && data.backupCode != null && data.backupCode !== '') { + payload.backupCode = data.backupCode; + } else if (data.token != null && data.token !== '') { + payload.token = data.token; + } + verifyTempMutate(payload); + }, + [tempToken, useBackup, verifyTempMutate], + ); + + const toggleBackupOn = useCallback(() => { + setUseBackup(true); + }, []); + + const toggleBackupOff = useCallback(() => { + setUseBackup(false); + }, []); + + return ( +
+
+ + {!useBackup && ( +
+ ( + + + + + + + + + + + + + + )} + /> + {errors.token && {errors.token.message}} +
+ )} + {useBackup && ( +
+ ( + + + + + + + + + + + + + )} + /> + {errors.backupCode && ( + {errors.backupCode.message} + )} +
+ )} +
+ +
+
+ {!useBackup ? ( + + ) : ( + + )} +
+
+
+ ); +}); + +export default TwoFactorScreen; diff --git a/client/src/components/Auth/index.ts b/client/src/components/Auth/index.ts index cd1ac1adce..afde148015 100644 --- a/client/src/components/Auth/index.ts +++ b/client/src/components/Auth/index.ts @@ -4,3 +4,4 @@ export { default as ResetPassword } from './ResetPassword'; export { default as VerifyEmail } from './VerifyEmail'; export { default as ApiErrorWatcher } from './ApiErrorWatcher'; export { default as RequestPasswordReset } from './RequestPasswordReset'; +export { default as TwoFactorScreen } from './TwoFactorScreen'; diff --git a/client/src/components/Chat/Input/AudioRecorder.tsx b/client/src/components/Chat/Input/AudioRecorder.tsx index 96e29ec502..512c9c9d9c 100644 --- a/client/src/components/Chat/Input/AudioRecorder.tsx +++ b/client/src/components/Chat/Input/AudioRecorder.tsx @@ -81,17 +81,25 @@ export default function AudioRecorder({ return ( - {renderIcon()} - + render={ + + } + /> ); } diff --git a/client/src/components/Chat/Input/Files/AttachFileMenu.tsx b/client/src/components/Chat/Input/Files/AttachFileMenu.tsx index 54a8a595c4..8841a0ae51 100644 --- a/client/src/components/Chat/Input/Files/AttachFileMenu.tsx +++ b/client/src/components/Chat/Input/Files/AttachFileMenu.tsx @@ -1,7 +1,7 @@ import * as Ariakit from '@ariakit/react'; import React, { useRef, useState, useMemo } from 'react'; -import { FileSearch, ImageUpIcon, TerminalSquareIcon } from 'lucide-react'; import { EToolResources, EModelEndpoint } from 'librechat-data-provider'; +import { FileSearch, ImageUpIcon, TerminalSquareIcon, FileType2Icon } from 'lucide-react'; import { FileUpload, TooltipAnchor, DropdownPopup } from '~/components/ui'; import { useGetEndpointsQuery } from '~/data-provider'; import { AttachmentIcon } from '~/components/svg'; @@ -49,6 +49,17 @@ const AttachFile = ({ isRTL, disabled, handleFileChange }: AttachFileProps) => { }, ]; + if (capabilities.includes(EToolResources.ocr)) { + items.push({ + label: localize('com_ui_upload_ocr_text'), + onClick: () => { + setToolResource(EToolResources.ocr); + handleUploadClick(); + }, + icon: , + }); + } + if (capabilities.includes(EToolResources.file_search)) { items.push({ label: localize('com_ui_upload_file_search'), diff --git a/client/src/components/Chat/Input/Files/DragDropModal.tsx b/client/src/components/Chat/Input/Files/DragDropModal.tsx index b252ae1a93..2abc15a45b 100644 --- a/client/src/components/Chat/Input/Files/DragDropModal.tsx +++ b/client/src/components/Chat/Input/Files/DragDropModal.tsx @@ -1,6 +1,6 @@ import React, { useMemo } from 'react'; import { EModelEndpoint, EToolResources } from 'librechat-data-provider'; -import { FileSearch, ImageUpIcon, TerminalSquareIcon } from 'lucide-react'; +import { FileSearch, ImageUpIcon, FileType2Icon, TerminalSquareIcon } from 'lucide-react'; import OGDialogTemplate from '~/components/ui/OGDialogTemplate'; import { useGetEndpointsQuery } from '~/data-provider'; import useLocalize from '~/hooks/useLocalize'; @@ -50,6 +50,12 @@ const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragD value: EToolResources.execute_code, icon: , }); + } else if (capability === EToolResources.ocr) { + _options.push({ + label: localize('com_ui_upload_ocr_text'), + value: EToolResources.ocr, + icon: , + }); } } diff --git a/client/src/components/Chat/Input/Files/FilePreview.tsx b/client/src/components/Chat/Input/Files/FilePreview.tsx index 80933b8503..02851119af 100644 --- a/client/src/components/Chat/Input/Files/FilePreview.tsx +++ 
b/client/src/components/Chat/Input/Files/FilePreview.tsx @@ -19,7 +19,7 @@ const FilePreview = ({ }; className?: string; }) => { - const radius = 55; // Radius of the SVG circle + const radius = 55; const circumference = 2 * Math.PI * radius; const progress = useProgress( file?.['progress'] ?? 1, @@ -27,16 +27,15 @@ const FilePreview = ({ (file as ExtendedFile | undefined)?.size ?? 1, ); - // Calculate the offset based on the loading progress const offset = circumference - progress * circumference; const circleCSSProperties = { transition: 'stroke-dashoffset 0.5s linear', }; return ( -
+
- + {progress < 1 && ( = ({ let statusText: string; if (!status) { - statusText = text ?? localize('com_endpoint_import'); + statusText = text ?? localize('com_ui_import'); } else if (status === 'success') { statusText = successText ?? localize('com_ui_upload_success'); } else { @@ -72,12 +72,12 @@ const FileUpload: React.FC = ({ )} > - {statusText} + {statusText} diff --git a/client/src/components/Chat/Input/Files/SourceIcon.tsx b/client/src/components/Chat/Input/Files/SourceIcon.tsx index 9638936e9e..c3b2a4423c 100644 --- a/client/src/components/Chat/Input/Files/SourceIcon.tsx +++ b/client/src/components/Chat/Input/Files/SourceIcon.tsx @@ -1,3 +1,4 @@ +import { Terminal, Type, Database } from 'lucide-react'; import { EModelEndpoint, FileSources } from 'librechat-data-provider'; import { MinimalIcon } from '~/components/Endpoints'; import { cn } from '~/utils'; @@ -6,9 +7,13 @@ const sourceToEndpoint = { [FileSources.openai]: EModelEndpoint.openAI, [FileSources.azure]: EModelEndpoint.azureOpenAI, }; + const sourceToClassname = { [FileSources.openai]: 'bg-white/75 dark:bg-black/65', [FileSources.azure]: 'azure-bg-color opacity-85', + [FileSources.execute_code]: 'bg-black text-white opacity-85', + [FileSources.text]: 'bg-blue-500 dark:bg-blue-900 opacity-85 text-white', + [FileSources.vectordb]: 'bg-yellow-700 dark:bg-yellow-900 opacity-85 text-white', }; const defaultClassName = @@ -16,13 +21,41 @@ const defaultClassName = export default function SourceIcon({ source, + isCodeFile, className = defaultClassName, }: { source?: FileSources; + isCodeFile?: boolean; className?: string; }) { - if (source === FileSources.local || source === FileSources.firebase) { - return null; + if (isCodeFile === true) { + return ( +
+ + + +
+ ); + } + + if (source === FileSources.text) { + return ( +
+ + + +
+ ); + } + + if (source === FileSources.vectordb) { + return ( +
+ + + +
+ ); } const endpoint = sourceToEndpoint[source ?? '']; @@ -31,7 +64,7 @@ export default function SourceIcon({ return null; } return ( - +
); } diff --git a/client/src/components/Chat/Input/Files/Table/Columns.tsx b/client/src/components/Chat/Input/Files/Table/Columns.tsx index 3ca28bad8a..8b8f52d8e7 100644 --- a/client/src/components/Chat/Input/Files/Table/Columns.tsx +++ b/client/src/components/Chat/Input/Files/Table/Columns.tsx @@ -1,4 +1,4 @@ -/* eslint-disable react-hooks/rules-of-hooks */ + import { ArrowUpDown, Database } from 'lucide-react'; import { FileSources, FileContext } from 'librechat-data-provider'; import type { ColumnDef } from '@tanstack/react-table'; @@ -7,10 +7,10 @@ import { Button, Checkbox, OpenAIMinimalIcon, AzureMinimalIcon } from '~/compone import ImagePreview from '~/components/Chat/Input/Files/ImagePreview'; import FilePreview from '~/components/Chat/Input/Files/FilePreview'; import { SortFilterHeader } from './SortFilterHeader'; -import { useLocalize, useMediaQuery } from '~/hooks'; +import { TranslationKeys, useLocalize, useMediaQuery } from '~/hooks'; import { formatDate, getFileType } from '~/utils'; -const contextMap = { +const contextMap: Record = { [FileContext.avatar]: 'com_ui_avatar', [FileContext.unknown]: 'com_ui_unknown', [FileContext.assistants]: 'com_ui_assistants', @@ -127,8 +127,8 @@ export const columns: ColumnDef[] = [ ), }} valueMap={{ - [FileSources.azure]: 'Azure', - [FileSources.openai]: 'OpenAI', + [FileSources.azure]: 'com_ui_azure', + [FileSources.openai]: 'com_ui_openai', [FileSources.local]: 'com_ui_host', }} /> @@ -182,7 +182,7 @@ export const columns: ColumnDef[] = [ const localize = useLocalize(); return (
-          {localize(contextMap[context ?? FileContext.unknown] ?? 'com_ui_unknown')}
+          {localize(contextMap[context ?? FileContext.unknown])}
); }, @@ -212,4 +212,4 @@ export const columns: ColumnDef[] = [ return `${value}${suffix}`; }, }, -]; +]; \ No newline at end of file diff --git a/client/src/components/Chat/Input/Files/Table/SortFilterHeader.tsx b/client/src/components/Chat/Input/Files/Table/SortFilterHeader.tsx index 1b9a0cbe42..bb9247c15a 100644 --- a/client/src/components/Chat/Input/Files/Table/SortFilterHeader.tsx +++ b/client/src/components/Chat/Input/Files/Table/SortFilterHeader.tsx @@ -16,7 +16,7 @@ interface SortFilterHeaderProps extends React.HTMLAttributes; filters?: Record; - valueMap?: Record; + valueMap?: Record; } export function SortFilterHeader({ @@ -82,7 +82,7 @@ export function SortFilterHeader({ const translationKey = valueMap?.[value ?? '']; const filterValue = translationKey != null && translationKey.length - ? localize(translationKey as TranslationKeys) + ? localize(translationKey) : String(value); if (!filterValue) { return null; diff --git a/client/src/components/Chat/Input/HeaderOptions.tsx b/client/src/components/Chat/Input/HeaderOptions.tsx index 0bd3326b53..5313f43b8d 100644 --- a/client/src/components/Chat/Input/HeaderOptions.tsx +++ b/client/src/components/Chat/Input/HeaderOptions.tsx @@ -1,8 +1,13 @@ import { useRecoilState } from 'recoil'; import { Settings2 } from 'lucide-react'; -import { Root, Anchor } from '@radix-ui/react-popover'; import { useState, useEffect, useMemo } from 'react'; -import { tConvoUpdateSchema, EModelEndpoint, isParamEndpoint } from 'librechat-data-provider'; +import { Root, Anchor } from '@radix-ui/react-popover'; +import { + EModelEndpoint, + isParamEndpoint, + isAgentsEndpoint, + tConvoUpdateSchema, +} from 'librechat-data-provider'; import type { TPreset, TInterfaceConfig } from 'librechat-data-provider'; import { EndpointSettings, SaveAsPresetDialog, AlternativeSettings } from '~/components/Endpoints'; import { PluginStoreDialog, TooltipAnchor } from '~/components'; @@ -42,7 +47,6 @@ export default function HeaderOptions({ if (endpoint && noSettings[endpoint]) { setShowPopover(false); } - // eslint-disable-next-line react-hooks/exhaustive-deps }, [endpoint, noSettings]); const saveAsPreset = () => { @@ -67,7 +71,7 @@ export default function HeaderOptions({
- {interfaceConfig?.modelSelect === true && ( + {interfaceConfig?.modelSelect === true && !isAgentsEndpoint(endpoint) && ( ; commandChar?: string; - placeholder?: string; + placeholder?: TranslationKeys; includeAssistants?: boolean; }) { const localize = useLocalize(); @@ -162,7 +162,7 @@ export default function Mention({
{name}
- {description ? description : localize('com_nav_welcome_message')} + {description || + (typeof startupConfig?.interface?.customWelcome === 'string' + ? startupConfig?.interface?.customWelcome + : localize('com_nav_welcome_message'))}
{/*
-
By Daniel Avila
+
By Daniel Avila
*/}
) : ( diff --git a/client/src/components/Chat/Menus/Models/ModelSpec.tsx b/client/src/components/Chat/Menus/Models/ModelSpec.tsx index 44cf51a976..617680946a 100644 --- a/client/src/components/Chat/Menus/Models/ModelSpec.tsx +++ b/client/src/components/Chat/Menus/Models/ModelSpec.tsx @@ -75,7 +75,7 @@ const MenuItem: FC = ({ {showIconInMenu && }
{title} -
{description}
+
{description}
diff --git a/client/src/components/Chat/Messages/Content/ContentParts.tsx b/client/src/components/Chat/Messages/Content/ContentParts.tsx index b997060c61..3805e0bb41 100644 --- a/client/src/components/Chat/Messages/Content/ContentParts.tsx +++ b/client/src/components/Chat/Messages/Content/ContentParts.tsx @@ -109,7 +109,9 @@ const ContentParts = memo( return val; }) } - label={isSubmitting ? localize('com_ui_thinking') : localize('com_ui_thoughts')} + label={ + isSubmitting && isLast ? localize('com_ui_thinking') : localize('com_ui_thoughts') + } />
)} @@ -137,6 +139,7 @@ const ContentParts = memo( isSubmitting={isSubmitting} key={`part-${messageId}-${idx}`} isCreatedByUser={isCreatedByUser} + isLast={idx === content.length - 1} showCursor={idx === content.length - 1 && isLast} /> diff --git a/client/src/components/Chat/Messages/Content/Image.tsx b/client/src/components/Chat/Messages/Content/Image.tsx index 28910d0315..41ee52453f 100644 --- a/client/src/components/Chat/Messages/Content/Image.tsx +++ b/client/src/components/Chat/Messages/Content/Image.tsx @@ -29,6 +29,7 @@ const Image = ({ height, width, placeholderDimensions, + className, }: { imagePath: string; altText: string; @@ -38,6 +39,7 @@ const Image = ({ height?: string; width?: string; }; + className?: string; }) => { const [isLoaded, setIsLoaded] = useState(false); const containerRef = useRef(null); @@ -57,7 +59,12 @@ const Image = ({ return (
-
+
diff --git a/client/src/components/Nav/SettingsTabs/Account/BackupCodesItem.tsx b/client/src/components/Nav/SettingsTabs/Account/BackupCodesItem.tsx new file mode 100644 index 0000000000..a034e2773a --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Account/BackupCodesItem.tsx @@ -0,0 +1,194 @@ +import React, { useState } from 'react'; +import { RefreshCcw, ShieldX } from 'lucide-react'; +import { motion, AnimatePresence } from 'framer-motion'; +import { TBackupCode, TRegenerateBackupCodesResponse, type TUser } from 'librechat-data-provider'; +import { + OGDialog, + OGDialogContent, + OGDialogTitle, + OGDialogTrigger, + Button, + Label, + Spinner, + TooltipAnchor, +} from '~/components'; +import { useRegenerateBackupCodesMutation } from '~/data-provider'; +import { useAuthContext, useLocalize } from '~/hooks'; +import { useToastContext } from '~/Providers'; +import { useSetRecoilState } from 'recoil'; +import store from '~/store'; + +const BackupCodesItem: React.FC = () => { + const localize = useLocalize(); + const { user } = useAuthContext(); + const { showToast } = useToastContext(); + const setUser = useSetRecoilState(store.user); + const [isDialogOpen, setDialogOpen] = useState(false); + + const { mutate: regenerateBackupCodes, isLoading } = useRegenerateBackupCodesMutation(); + + const fetchBackupCodes = (auto: boolean = false) => { + regenerateBackupCodes(undefined, { + onSuccess: (data: TRegenerateBackupCodesResponse) => { + const newBackupCodes: TBackupCode[] = data.backupCodesHash.map((codeHash) => ({ + codeHash, + used: false, + usedAt: null, + })); + + setUser((prev) => ({ ...prev, backupCodes: newBackupCodes }) as TUser); + showToast({ + message: localize('com_ui_backup_codes_regenerated'), + status: 'success', + }); + + // Trigger file download only when user explicitly clicks the button. + if (!auto && newBackupCodes.length) { + const codesString = data.backupCodes.join('\n'); + const blob = new Blob([codesString], { type: 'text/plain;charset=utf-8' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = 'backup-codes.txt'; + a.click(); + URL.revokeObjectURL(url); + } + }, + onError: () => + showToast({ + message: localize('com_ui_backup_codes_regenerate_error'), + status: 'error', + }), + }); + }; + + const handleRegenerate = () => { + fetchBackupCodes(false); + }; + + return ( + +
+
+ +
+ + + +
+ + + + {localize('com_ui_backup_codes')} + + + + + {Array.isArray(user?.backupCodes) && user?.backupCodes.length > 0 ? ( + <> +
+ {user?.backupCodes.map((code, index) => { + const isUsed = code.used; + const description = `Backup code number ${index + 1}, ${ + isUsed + ? `used on ${code.usedAt ? new Date(code.usedAt).toLocaleDateString() : 'an unknown date'}` + : 'not used yet' + }`; + + return ( + { + const announcement = new CustomEvent('announce', { + detail: { message: description }, + }); + document.dispatchEvent(announcement); + }} + className={`flex flex-col rounded-xl border p-4 backdrop-blur-sm transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-primary ${ + isUsed + ? 'border-red-200 bg-red-50/80 dark:border-red-800 dark:bg-red-900/20' + : 'border-green-200 bg-green-50/80 dark:border-green-800 dark:bg-green-900/20' + } `} + > + + + ); + })} +
+
+ +
+ + ) : ( +
+ +

{localize('com_ui_no_backup_codes')}

+ +
+ )} +
+
+
+
+ ); +}; + +export default React.memo(BackupCodesItem); diff --git a/client/src/components/Nav/SettingsTabs/Account/DeleteAccount.tsx b/client/src/components/Nav/SettingsTabs/Account/DeleteAccount.tsx index 86538669db..b00e7498bc 100644 --- a/client/src/components/Nav/SettingsTabs/Account/DeleteAccount.tsx +++ b/client/src/components/Nav/SettingsTabs/Account/DeleteAccount.tsx @@ -14,6 +14,7 @@ import { useDeleteUserMutation } from '~/data-provider'; import { useAuthContext } from '~/hooks/AuthContext'; import { useLocalize } from '~/hooks'; import { cn } from '~/utils'; +import { LocalizeFunction } from '~/common'; const DeleteAccount = ({ disabled = false }: { title?: string; disabled?: boolean }) => { const localize = useLocalize(); @@ -56,7 +57,7 @@ const DeleteAccount = ({ disabled = false }: { title?: string; disabled?: boolea
- + {localize('com_nav_delete_account_confirm')} @@ -103,7 +104,7 @@ const renderDeleteButton = ( handleDeleteUser: () => void, isDeleting: boolean, isLocked: boolean, - localize: (key: string) => string, + localize: LocalizeFunction, ) => ( +
+ + ); +}; diff --git a/client/src/components/Nav/SettingsTabs/Account/TwoFactorAuthentication.tsx b/client/src/components/Nav/SettingsTabs/Account/TwoFactorAuthentication.tsx new file mode 100644 index 0000000000..eb88b594ce --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Account/TwoFactorAuthentication.tsx @@ -0,0 +1,302 @@ +import React, { useCallback, useState } from 'react'; +import { useSetRecoilState } from 'recoil'; +import { SmartphoneIcon } from 'lucide-react'; +import { motion, AnimatePresence } from 'framer-motion'; +import type { TUser, TVerify2FARequest } from 'librechat-data-provider'; +import { OGDialog, OGDialogContent, OGDialogHeader, OGDialogTitle, Progress } from '~/components'; +import { SetupPhase, QRPhase, VerifyPhase, BackupPhase, DisablePhase } from './TwoFactorPhases'; +import { DisableTwoFactorToggle } from './DisableTwoFactorToggle'; +import { useAuthContext, useLocalize } from '~/hooks'; +import { useToastContext } from '~/Providers'; +import store from '~/store'; +import { + useConfirmTwoFactorMutation, + useDisableTwoFactorMutation, + useEnableTwoFactorMutation, + useVerifyTwoFactorMutation, +} from '~/data-provider'; + +export type Phase = 'setup' | 'qr' | 'verify' | 'backup' | 'disable'; + +const phaseVariants = { + initial: { opacity: 0, scale: 0.95 }, + animate: { opacity: 1, scale: 1, transition: { duration: 0.3, ease: 'easeOut' } }, + exit: { opacity: 0, scale: 0.95, transition: { duration: 0.3, ease: 'easeIn' } }, +}; + +const TwoFactorAuthentication: React.FC = () => { + const localize = useLocalize(); + const { user } = useAuthContext(); + const setUser = useSetRecoilState(store.user); + const { showToast } = useToastContext(); + + const [secret, setSecret] = useState(''); + const [otpauthUrl, setOtpauthUrl] = useState(''); + const [downloaded, setDownloaded] = useState(false); + const [disableToken, setDisableToken] = useState(''); + const [backupCodes, setBackupCodes] = useState([]); + const [isDialogOpen, setDialogOpen] = useState(false); + const [verificationToken, setVerificationToken] = useState(''); + const [phase, setPhase] = useState(user?.twoFactorEnabled ? 'disable' : 'setup'); + + const { mutate: confirm2FAMutate } = useConfirmTwoFactorMutation(); + const { mutate: enable2FAMutate, isLoading: isGenerating } = useEnableTwoFactorMutation(); + const { mutate: verify2FAMutate, isLoading: isVerifying } = useVerifyTwoFactorMutation(); + const { mutate: disable2FAMutate, isLoading: isDisabling } = useDisableTwoFactorMutation(); + + const steps = ['Setup', 'Scan QR', 'Verify', 'Backup']; + const phasesLabel: Record = { + setup: 'Setup', + qr: 'Scan QR', + verify: 'Verify', + backup: 'Backup', + disable: '', + }; + + const currentStep = steps.indexOf(phasesLabel[phase]); + + const resetState = useCallback(() => { + if (user?.twoFactorEnabled && otpauthUrl) { + disable2FAMutate(undefined, { + onError: () => + showToast({ message: localize('com_ui_2fa_disable_error'), status: 'error' }), + }); + } + + setOtpauthUrl(''); + setSecret(''); + setBackupCodes([]); + setVerificationToken(''); + setDisableToken(''); + setPhase(user?.twoFactorEnabled ? 
'disable' : 'setup'); + setDownloaded(false); + }, [user, otpauthUrl, disable2FAMutate, localize, showToast]); + + const handleGenerateQRCode = useCallback(() => { + enable2FAMutate(undefined, { + onSuccess: ({ otpauthUrl, backupCodes }) => { + setOtpauthUrl(otpauthUrl); + setSecret(otpauthUrl.split('secret=')[1].split('&')[0]); + setBackupCodes(backupCodes); + setPhase('qr'); + }, + onError: () => showToast({ message: localize('com_ui_2fa_generate_error'), status: 'error' }), + }); + }, [enable2FAMutate, localize, showToast]); + + const handleVerify = useCallback(() => { + if (!verificationToken) { + return; + } + + verify2FAMutate( + { token: verificationToken }, + { + onSuccess: () => { + showToast({ message: localize('com_ui_2fa_verified') }); + confirm2FAMutate( + { token: verificationToken }, + { + onSuccess: () => setPhase('backup'), + onError: () => + showToast({ message: localize('com_ui_2fa_invalid'), status: 'error' }), + }, + ); + }, + onError: () => showToast({ message: localize('com_ui_2fa_invalid'), status: 'error' }), + }, + ); + }, [verificationToken, verify2FAMutate, confirm2FAMutate, localize, showToast]); + + const handleDownload = useCallback(() => { + if (!backupCodes.length) { + return; + } + const blob = new Blob([backupCodes.join('\n')], { type: 'text/plain;charset=utf-8' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = 'backup-codes.txt'; + a.click(); + URL.revokeObjectURL(url); + setDownloaded(true); + }, [backupCodes]); + + const handleConfirm = useCallback(() => { + setDialogOpen(false); + setPhase('disable'); + showToast({ message: localize('com_ui_2fa_enabled') }); + setUser( + (prev) => + ({ + ...prev, + backupCodes: backupCodes.map((code) => ({ + code, + codeHash: code, + used: false, + usedAt: null, + })), + twoFactorEnabled: true, + }) as TUser, + ); + }, [setUser, localize, showToast, backupCodes]); + + const handleDisableVerify = useCallback( + (token: string, useBackup: boolean) => { + // Validate: if not using backup, ensure token has at least 6 digits; + // if using backup, ensure backup code has at least 8 characters. + if (!useBackup && token.trim().length < 6) { + return; + } + + if (useBackup && token.trim().length < 8) { + return; + } + + const payload: TVerify2FARequest = {}; + if (useBackup) { + payload.backupCode = token.trim(); + } else { + payload.token = token.trim(); + } + + verify2FAMutate(payload, { + onSuccess: () => { + disable2FAMutate(undefined, { + onSuccess: () => { + showToast({ message: localize('com_ui_2fa_disabled') }); + setDialogOpen(false); + setUser( + (prev) => + ({ + ...prev, + totpSecret: '', + backupCodes: [], + twoFactorEnabled: false, + }) as TUser, + ); + setPhase('setup'); + setOtpauthUrl(''); + }, + onError: () => + showToast({ message: localize('com_ui_2fa_disable_error'), status: 'error' }), + }); + }, + onError: () => showToast({ message: localize('com_ui_2fa_invalid'), status: 'error' }), + }); + }, + [verify2FAMutate, disable2FAMutate, showToast, localize, setUser], + ); + + return ( + { + setDialogOpen(open); + if (!open) { + resetState(); + } + }} + > + setDialogOpen(true)} + disabled={isVerifying || isDisabling || isGenerating} + /> + + + + + + + + {user?.twoFactorEnabled + ? localize('com_ui_2fa_disable') + : localize('com_ui_2fa_setup')} + + {user?.twoFactorEnabled && phase !== 'disable' && ( +
+ +
+ {steps.map((step, index) => ( + = index ? 'var(--text-primary)' : 'var(--text-tertiary)', + }} + className="font-medium" + > + {step} + + ))} +
+
+ )} +
+ + + {phase === 'setup' && ( + setPhase('qr')} + onError={(error) => showToast({ message: error.message, status: 'error' })} + /> + )} + + {phase === 'qr' && ( + setPhase('verify')} + onError={(error) => showToast({ message: error.message, status: 'error' })} + /> + )} + + {phase === 'verify' && ( + showToast({ message: error.message, status: 'error' })} + /> + )} + + {phase === 'backup' && ( + showToast({ message: error.message, status: 'error' })} + /> + )} + + {phase === 'disable' && ( + showToast({ message: error.message, status: 'error' })} + /> + )} + +
+
+
+
+ ); +}; + +export default React.memo(TwoFactorAuthentication); diff --git a/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/BackupPhase.tsx b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/BackupPhase.tsx new file mode 100644 index 0000000000..67e05a1423 --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/BackupPhase.tsx @@ -0,0 +1,60 @@ +import React from 'react'; +import { motion } from 'framer-motion'; +import { Download } from 'lucide-react'; +import { Button, Label } from '~/components'; +import { useLocalize } from '~/hooks'; + +const fadeAnimation = { + initial: { opacity: 0, y: 20 }, + animate: { opacity: 1, y: 0 }, + exit: { opacity: 0, y: -20 }, + transition: { duration: 0.2 }, +}; + +interface BackupPhaseProps { + onNext: () => void; + onError: (error: Error) => void; + backupCodes: string[]; + onDownload: () => void; + downloaded: boolean; +} + +export const BackupPhase: React.FC = ({ + backupCodes, + onDownload, + downloaded, + onNext, +}) => { + const localize = useLocalize(); + + return ( + + +
+ {backupCodes.map((code, index) => ( + +
+ #{index + 1} + {code} +
+
+ ))} +
+
+ + +
+
+ ); +}; diff --git a/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/DisablePhase.tsx b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/DisablePhase.tsx new file mode 100644 index 0000000000..fc444b918c --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/DisablePhase.tsx @@ -0,0 +1,88 @@ +import React, { useState } from 'react'; +import { motion } from 'framer-motion'; +import { REGEXP_ONLY_DIGITS, REGEXP_ONLY_DIGITS_AND_CHARS } from 'input-otp'; +import { + Button, + InputOTP, + InputOTPGroup, + InputOTPSlot, + InputOTPSeparator, + Spinner, +} from '~/components'; +import { useLocalize } from '~/hooks'; + +const fadeAnimation = { + initial: { opacity: 0, y: 20 }, + animate: { opacity: 1, y: 0 }, + exit: { opacity: 0, y: -20 }, + transition: { duration: 0.2 }, +}; + +interface DisablePhaseProps { + onSuccess?: () => void; + onError?: (error: Error) => void; + onDisable: (token: string, useBackup: boolean) => void; + isDisabling: boolean; +} + +export const DisablePhase: React.FC = ({ onDisable, isDisabling }) => { + const localize = useLocalize(); + const [token, setToken] = useState(''); + const [useBackup, setUseBackup] = useState(false); + + return ( + +
+ + {useBackup ? ( + + + + + + + + + + + ) : ( + <> + + + + + + + + + + + + + )} + +
+ + +
+ ); +}; diff --git a/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/QRPhase.tsx b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/QRPhase.tsx new file mode 100644 index 0000000000..7a0eccae3f --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/QRPhase.tsx @@ -0,0 +1,66 @@ +import React, { useState } from 'react'; +import { motion } from 'framer-motion'; +import { QRCodeSVG } from 'qrcode.react'; +import { Copy, Check } from 'lucide-react'; +import { Input, Button, Label } from '~/components'; +import { useLocalize } from '~/hooks'; +import { cn } from '~/utils'; + +const fadeAnimation = { + initial: { opacity: 0, y: 20 }, + animate: { opacity: 1, y: 0 }, + exit: { opacity: 0, y: -20 }, + transition: { duration: 0.2 }, +}; + +interface QRPhaseProps { + secret: string; + otpauthUrl: string; + onNext: () => void; + onSuccess?: () => void; + onError?: (error: Error) => void; +} + +export const QRPhase: React.FC = ({ secret, otpauthUrl, onNext }) => { + const localize = useLocalize(); + const [isCopying, setIsCopying] = useState(false); + + const handleCopy = async () => { + await navigator.clipboard.writeText(secret); + setIsCopying(true); + setTimeout(() => setIsCopying(false), 2000); + }; + + return ( + +
+ + + +
+ +
+ + +
+
+
+ +
+ ); +}; diff --git a/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/SetupPhase.tsx b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/SetupPhase.tsx new file mode 100644 index 0000000000..feca4d3254 --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/SetupPhase.tsx @@ -0,0 +1,42 @@ +import React from 'react'; +import { QrCode } from 'lucide-react'; +import { motion } from 'framer-motion'; +import { Button, Spinner } from '~/components'; +import { useLocalize } from '~/hooks'; + +const fadeAnimation = { + initial: { opacity: 0, y: 20 }, + animate: { opacity: 1, y: 0 }, + exit: { opacity: 0, y: -20 }, + transition: { duration: 0.2 }, +}; + +interface SetupPhaseProps { + onNext: () => void; + onError: (error: Error) => void; + isGenerating: boolean; + onGenerate: () => void; +} + +export const SetupPhase: React.FC = ({ isGenerating, onGenerate }) => { + const localize = useLocalize(); + + return ( + +
+

+ {localize('com_ui_2fa_account_security')} +

+ +
+
+ ); +}; diff --git a/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/VerifyPhase.tsx b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/VerifyPhase.tsx new file mode 100644 index 0000000000..e872dfa0d2 --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/VerifyPhase.tsx @@ -0,0 +1,58 @@ +import React from 'react'; +import { motion } from 'framer-motion'; +import { Button, InputOTP, InputOTPGroup, InputOTPSeparator, InputOTPSlot } from '~/components'; +import { REGEXP_ONLY_DIGITS } from 'input-otp'; +import { useLocalize } from '~/hooks'; + +const fadeAnimation = { + initial: { opacity: 0, y: 20 }, + animate: { opacity: 1, y: 0 }, + exit: { opacity: 0, y: -20 }, + transition: { duration: 0.2 }, +}; + +interface VerifyPhaseProps { + token: string; + onTokenChange: (value: string) => void; + isVerifying: boolean; + onNext: () => void; + onError: (error: Error) => void; +} + +export const VerifyPhase: React.FC = ({ + token, + onTokenChange, + isVerifying, + onNext, +}) => { + const localize = useLocalize(); + + return ( + +
+ + + {Array.from({ length: 3 }).map((_, i) => ( + + ))} + + + + {Array.from({ length: 3 }).map((_, i) => ( + + ))} + + +
+ +
+ ); +}; diff --git a/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/index.ts b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/index.ts new file mode 100644 index 0000000000..1cc474efef --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Account/TwoFactorPhases/index.ts @@ -0,0 +1,5 @@ +export * from './BackupPhase'; +export * from './QRPhase'; +export * from './VerifyPhase'; +export * from './SetupPhase'; +export * from './DisablePhase'; diff --git a/client/src/components/Nav/SettingsTabs/Commands/PlusCommandSwitch.tsx b/client/src/components/Nav/SettingsTabs/Commands/PlusCommandSwitch.tsx index 12cf428ae1..2351633945 100644 --- a/client/src/components/Nav/SettingsTabs/Commands/PlusCommandSwitch.tsx +++ b/client/src/components/Nav/SettingsTabs/Commands/PlusCommandSwitch.tsx @@ -18,7 +18,6 @@ export default function PlusCommandSwitch() { id="plusCommand" checked={plusCommand} onCheckedChange={handleCheckedChange} - f className="ml-4" data-testid="plusCommand" /> diff --git a/client/src/components/Nav/SettingsTabs/Commands/SlashCommandSwitch.tsx b/client/src/components/Nav/SettingsTabs/Commands/SlashCommandSwitch.tsx index 2051fd033a..18c24583fa 100644 --- a/client/src/components/Nav/SettingsTabs/Commands/SlashCommandSwitch.tsx +++ b/client/src/components/Nav/SettingsTabs/Commands/SlashCommandSwitch.tsx @@ -18,7 +18,6 @@ export default function SlashCommandSwitch() { id="slashCommand" checked={slashCommand} onCheckedChange={handleCheckedChange} - f data-testid="slashCommand" /> diff --git a/client/src/components/Nav/SettingsTabs/Data/ImportConversations.tsx b/client/src/components/Nav/SettingsTabs/Data/ImportConversations.tsx index c39e8351e8..e3bafd9152 100644 --- a/client/src/components/Nav/SettingsTabs/Data/ImportConversations.tsx +++ b/client/src/components/Nav/SettingsTabs/Data/ImportConversations.tsx @@ -82,7 +82,7 @@ function ImportConversations() { onClick={handleImportClick} onKeyDown={handleKeyDown} disabled={!allowImport} - aria-label={localize('com_ui_import_conversation')} + aria-label={localize('com_ui_import')} className="btn btn-neutral relative" > {allowImport ? ( @@ -90,7 +90,7 @@ function ImportConversations() { ) : ( )} - {localize('com_ui_import_conversation')} + {localize('com_ui_import')} setIsOpen(true)}> - + +

- {user?.role === SystemRoles.ADMIN && } - {/* Context Button */} -
- - {(agent?.author === user?.id || user?.role === SystemRoles.ADMIN) && - hasAccessToShareAgents && ( - - )} - {agent && agent.author === user?.id && } - {/* Submit Button */} - -
& { + updateMutation: ReturnType; +}) { + const localize = useLocalize(); + const { user } = useAuthContext(); + + const methods = useFormContext(); + + const { control } = methods; + const agent = useWatch({ control, name: 'agent' }); + const agent_id = useWatch({ control, name: 'id' }); + + const hasAccessToShareAgents = useHasAccess({ + permissionType: PermissionTypes.AGENTS, + permission: Permissions.SHARED_GLOBAL, + }); + + const renderSaveButton = () => { + if (createMutation.isLoading || updateMutation.isLoading) { + return